from __future__ import annotations

import json
import os
from typing import List

from models import RetrievedChunk
from utils import clean_math_text, score_token_overlap

# Optional dependencies: when either is missing, the engine degrades to
# lexical-overlap scoring only instead of failing at import time.
try:
    import numpy as np
except Exception:
    np = None

try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SentenceTransformer = None


class RetrievalEngine:
    """Hybrid semantic + lexical retrieval over a local JSONL chunk corpus.

    At construction the engine loads the corpus and, when the optional
    ``sentence-transformers``/``numpy`` stack is available, pre-computes
    normalized embeddings for every chunk. ``search`` then ranks chunks by a
    weighted blend of cosine similarity, token overlap, and topic/intent
    heuristics, falling back to pure lexical scoring when embeddings are
    unavailable.
    """

    def __init__(self, data_path: str = "data/gmat_hf_chunks.jsonl"):
        """Load the corpus from *data_path* and best-effort build embeddings.

        Any failure while loading the encoder or encoding the corpus (model
        download errors, OOM, ...) is swallowed deliberately: the engine must
        stay usable in lexical-only mode.
        """
        self.data_path = data_path
        self.rows = self._load_rows(data_path)
        self.encoder = None
        self.embeddings = None
        if SentenceTransformer is not None and self.rows:
            try:
                self.encoder = SentenceTransformer(
                    "sentence-transformers/all-MiniLM-L6-v2"
                )
                # normalize_embeddings=True makes a plain dot product equal
                # cosine similarity in search().
                self.embeddings = self.encoder.encode(
                    [r["text"] for r in self.rows],
                    convert_to_numpy=True,
                    normalize_embeddings=True,
                )
            except Exception:
                # Best-effort: drop back to lexical-only scoring.
                self.encoder = None
                self.embeddings = None

    def _load_rows(self, data_path: str) -> List[dict]:
        """Read JSONL rows from *data_path*, tolerating blank/invalid lines.

        Each returned dict has keys ``text``, ``topic`` and ``source``; the
        latter two fall back through several alternative field names so that
        corpora produced by different pipelines all load.
        """
        rows: List[dict] = []
        if not os.path.exists(data_path):
            # Missing corpus is not an error; the engine just returns no hits.
            return rows
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except Exception:
                    # Skip malformed lines rather than abort the whole load.
                    continue
                rows.append(
                    {
                        "text": item.get("text", ""),
                        "topic": (
                            item.get("topic")
                            or item.get("topic_guess")
                            or item.get("section")
                            or "general"
                        ),
                        "source": (
                            item.get("source")
                            or item.get("source_name")
                            or item.get("source_file")
                            or "local_corpus"
                        ),
                    }
                )
        return rows

    def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
        """Return an additive score bonus for topic/intent affinity.

        Heuristic, case-insensitive keyword matching: a direct substring match
        on the desired topic earns the largest bonus (1.25); per-topic keyword
        families each add 1.0; "how-to" style intents add a small 0.25 nudge
        toward instructional topics.
        """
        desired_topic = (desired_topic or "").lower()
        row_topic = (row_topic or "").lower()
        intent = (intent or "").lower()
        bonus = 0.0
        if desired_topic and desired_topic in row_topic:
            bonus += 1.25
        if desired_topic == "algebra" and row_topic in {
            "algebra",
            "linear equations",
            "equations",
        }:
            bonus += 1.0
        if desired_topic == "percent" and "percent" in row_topic:
            bonus += 1.0
        if desired_topic in {"number_theory", "number_properties"} and any(
            k in row_topic for k in ["number", "divisible", "remainder", "prime", "factor"]
        ):
            bonus += 1.0
        if desired_topic == "geometry" and any(
            k in row_topic for k in ["geometry", "circle", "triangle", "area", "perimeter"]
        ):
            bonus += 1.0
        if desired_topic == "probability" and "probability" in row_topic:
            bonus += 1.0
        if desired_topic == "statistics" and any(
            k in row_topic for k in ["statistics", "mean", "median", "average", "distribution"]
        ):
            bonus += 1.0
        if intent in {"method", "step_by_step", "full_working", "hint", "walkthrough", "instruction"}:
            if any(
                k in row_topic
                for k in [
                    "algebra",
                    "percent",
                    "fractions",
                    "word_problems",
                    "general",
                    "ratio",
                    "probability",
                    "statistics",
                ]
            ):
                bonus += 0.25
        return bonus

    def search(
        self,
        query: str,
        topic: str = "",
        intent: str = "answer",
        k: int = 3,
    ) -> List[RetrievedChunk]:
        """Return the top-*k* chunks ranked for *query*.

        Scoring: when embeddings are available, total = 0.7 * cosine
        similarity + 0.3 * token overlap + topic/intent bonus; otherwise
        total = token overlap + bonus. When *topic* is given, candidates are
        first narrowed to rows whose topic matches exactly, or (failing that)
        partially; if neither matches, the full corpus is searched.
        """
        if not self.rows:
            return []
        combined_query = clean_math_text(query)
        normalized_topic = (topic or "").strip().lower()
        candidate_rows = self.rows
        candidate_indices = None
        if normalized_topic:
            exact_topic_rows = [
                (i, row)
                for i, row in enumerate(self.rows)
                if (row.get("topic") or "").strip().lower() == normalized_topic
            ]
            partial_topic_rows = [
                (i, row)
                for i, row in enumerate(self.rows)
                if normalized_topic in (row.get("topic") or "").strip().lower()
                or (row.get("topic") or "").strip().lower() in normalized_topic
            ]
            chosen_rows = exact_topic_rows or partial_topic_rows
            if chosen_rows:
                candidate_indices = [i for i, _ in chosen_rows]
                candidate_rows = [row for _, row in chosen_rows]
        scores = []
        if self.encoder is not None and self.embeddings is not None and np is not None:
            try:
                q = self.encoder.encode(
                    [combined_query],
                    convert_to_numpy=True,
                    normalize_embeddings=True,
                )[0]
                if candidate_indices is None:
                    candidate_embeddings = self.embeddings
                else:
                    # numpy fancy indexing keeps row order aligned with
                    # candidate_rows.
                    candidate_embeddings = self.embeddings[candidate_indices]
                # Embeddings are normalized, so dot product == cosine sim.
                semantic_scores = candidate_embeddings @ q
                for row, sem in zip(candidate_rows, semantic_scores.tolist()):
                    lexical = score_token_overlap(combined_query, row["text"])
                    bonus = self._topic_bonus(topic, row["topic"], intent)
                    total = 0.7 * sem + 0.3 * lexical + bonus
                    scores.append((total, row))
            except Exception:
                # Encoding failed at query time: fall through to lexical path.
                scores = []
        if not scores:
            for row in candidate_rows:
                lexical = score_token_overlap(combined_query, row["text"])
                bonus = self._topic_bonus(topic, row["topic"], intent)
                scores.append((lexical + bonus, row))
        scores.sort(key=lambda x: x[0], reverse=True)
        results: List[RetrievedChunk] = []
        for score, row in scores[:k]:
            results.append(
                RetrievedChunk(
                    text=row["text"],
                    topic=row["topic"],
                    source=row["source"],
                    score=float(score),
                )
            )
        return results