""" Lightweight model distillation from Kaggle datasets. """ import json import os import csv import logging from typing import List, Dict, Tuple, Optional from datetime import datetime from pathlib import Path logger = logging.getLogger(__name__) class KnowledgeDistiller: """Distills datasets into lightweight domain models.""" def __init__(self, data_dir: Optional[str] = None): if data_dir is None: from app.config import DATA_DIR self.data_dir = Path(DATA_DIR) else: self.data_dir = Path(data_dir) self.models_dir = self.data_dir / "distilled_models" self.models_dir.mkdir(parents=True, exist_ok=True) def distill_dataset_to_model( self, dataset_path: str, domain: str, model_name: str, max_size_kb: int = 500 ) -> Dict: """Extract and compress dataset into lightweight domain model.""" logger.info(f"Distilling {dataset_path} for {domain}...") # 1. Extract QA pairs qa_pairs = self._extract_qa_pairs(dataset_path, domain) if not qa_pairs: logger.warning(f"No QA pairs extracted from {dataset_path}") return {} # 2. Rank by relevance ranked_qa = self._rank_qa_pairs(qa_pairs, domain) # 3. Select within size constraint compressed_qa = self._compress_to_size_limit(ranked_qa, max_size_kb) # 4. Create model model = { "name": model_name, "domain": domain, "created_at": datetime.now().isoformat(), "qa_pairs": compressed_qa, "metadata": { "total_extracted": len(qa_pairs), "selected_pairs": len(compressed_qa), "avg_relevance": sum(p.get("relevance", 0) for p in compressed_qa) / len(compressed_qa) if compressed_qa else 0, "size_kb": self._estimate_size_kb(compressed_qa), } } # 5. Save model model_path = self.models_dir / f"{domain}_primary.json" with open(model_path, 'w') as f: json.dump(model, f, separators=(',', ':')) logger.info(f"✓ Model saved to {model_path} ({model['metadata']['size_kb']} KB)") return model["metadata"] def load_model(self, domain: str) -> Optional[Dict]: """Load distilled model from disk.""" model_path = self.models_dir / f"{domain}_primary.json" if not model_path.exists(): return None try: with open(model_path) as f: return json.load(f) except Exception as e: logger.error(f"Failed to load model {domain}: {e}") return None def query_model(self, model: Dict, query: str, top_k: int = 3) -> List[str]: """Query a distilled model for relevant insights.""" qa_pairs = model.get("qa_pairs", []) if not qa_pairs: return [] query_words = set(query.lower().split()) stop_words = {"what", "is", "the", "how", "does", "of", "in", "for", "a", "an", "to", "and", "or", "on", "with", "are", "do", "you", "tell", "me", "about"} query_words = query_words - stop_words if not query_words: return [] scored = [] for pair in qa_pairs: q_text = pair.get("question", "").lower() q_words = set(q_text.split()) # Simple keyword overlap (excluding stop words from QA as well) overlap = len(query_words & (q_words - stop_words)) if overlap > 0: # Weight by overlap and relevance score = overlap * pair.get("relevance", 0.5) scored.append((pair.get("answer"), score)) # Sort and return top unique answers scored.sort(key=lambda x: x[1], reverse=True) seen = set() results = [] for ans, _ in scored: if ans not in seen: results.append(ans) seen.add(ans) if len(results) >= top_k: break return results def _extract_qa_pairs(self, dataset_path: str, domain: str) -> List[Dict]: """Walk through files and extract QA pairs.""" qa_pairs = [] for root, _, files in os.walk(dataset_path): for file in files: file_path = os.path.join(root, file) try: if file.endswith('.csv'): qa_pairs.extend(self._extract_from_csv(file_path)) elif file.endswith('.json'): qa_pairs.extend(self._extract_from_json(file_path)) except Exception as e: logger.debug(f"Skipping {file}: {e}") return qa_pairs def _extract_from_csv(self, path: str) -> List[Dict]: pairs = [] with open(path, encoding='utf-8', errors='ignore') as f: reader = csv.DictReader(f) # Find columns that look like Q&A or key metrics cols = reader.fieldnames or [] q_col = next((c for c in cols if any(k in c.lower() for k in ['question', 'title', 'name', 'indicator'])), None) a_col = next((c for c in cols if any(k in c.lower() for k in ['answer', 'desc', 'value', 'price'])), None) if q_col and a_col: for row in reader: q, a = row.get(q_col), row.get(a_col) if q and a and len(str(q)) > 5: pairs.append({"question": str(q), "answer": str(a)}) return pairs def _extract_from_json(self, path: str) -> List[Dict]: pairs = [] with open(path, encoding='utf-8', errors='ignore') as f: data = json.load(f) if isinstance(data, list): for item in data: if isinstance(item, dict): q = item.get('question') or item.get('q') or item.get('title') a = item.get('answer') or item.get('a') or item.get('content') if q and a: pairs.append({"question": str(q), "answer": str(a)}) return pairs def _rank_qa_pairs(self, pairs: List[Dict], domain: str) -> List[Dict]: keywords = { "finance": ["stock", "price", "market", "revenue", "earnings", "valuation", "ratio", "dividend"], "tech": ["software", "algorithm", "platform", "cloud", "ai", "latency", "architecture"], "healthcare": ["drug", "efficacy", "trial", "patient", "disease", "treatment", "medical"], }.get(domain, []) for p in pairs: text = (p['question'] + " " + p['answer']).lower() matches = sum(1 for k in keywords if k in text) p["relevance"] = min(1.0, 0.2 + (matches * 0.2)) return sorted(pairs, key=lambda x: x["relevance"], reverse=True) def _compress_to_size_limit(self, pairs: List[Dict], max_kb: int) -> List[Dict]: selected = [] current_size = 0 for p in pairs: # Estimate size: roughly length of JSON string size = len(json.dumps(p)) / 1024 if current_size + size <= max_kb: selected.append(p) current_size += size else: break return selected def _estimate_size_kb(self, pairs: List[Dict]) -> float: return len(json.dumps(pairs).encode('utf-8')) / 1024