# GameAI / retrieval_engine.py
# j-js's picture
# Update retrieval_engine.py
# 2ed1ad1 verified
from __future__ import annotations

import json
import logging
import os
from typing import List

try:
    import numpy as np
except Exception:
    np = None

try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SentenceTransformer = None

from models import RetrievedChunk
from utils import clean_math_text, score_token_overlap
_logger = logging.getLogger(__name__)


class RetrievalEngine:
    """Rank chunks of a local JSONL corpus against a query.

    Rows are loaded once at construction time. When ``sentence_transformers``
    is importable the rows are embedded up front; ``search`` then blends the
    embedding similarity with a token-overlap score (this path additionally
    requires ``numpy``). If the semantic path is unavailable or fails,
    ranking falls back to token overlap alone. A topic/intent bonus is added
    in both modes.
    """

    # Embedding model loaded by ``__init__``.
    _MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

    # Weights of the semantic and lexical parts of the blended score.
    _SEMANTIC_WEIGHT = 0.7
    _LEXICAL_WEIGHT = 0.3

    # Row topics that earn the extra bonus when the desired topic is
    # "algebra" (exact membership, not substring matching).
    _ALGEBRA_TOPICS = frozenset({"algebra", "linear equations", "equations"})

    # Desired topic -> substrings of the row topic that earn the extra bonus.
    _TOPIC_KEYWORDS = {
        "percent": ("percent",),
        "number_theory": ("number", "divisible", "remainder", "prime", "factor"),
        "number_properties": ("number", "divisible", "remainder", "prime", "factor"),
        "geometry": ("geometry", "circle", "triangle", "area", "perimeter"),
        "probability": ("probability",),
        "statistics": ("statistics", "mean", "median", "average", "distribution"),
    }

    # Intents that receive a small extra bonus on the row topics listed in
    # ``_METHOD_TOPIC_KEYWORDS``.
    _METHOD_INTENTS = frozenset(
        {"method", "step_by_step", "full_working", "hint", "walkthrough", "instruction"}
    )
    _METHOD_TOPIC_KEYWORDS = (
        "algebra",
        "percent",
        "fractions",
        "word_problems",
        "general",
        "ratio",
        "probability",
        "statistics",
    )

    def __init__(self, data_path: str = "data/gmat_hf_chunks.jsonl"):
        """Load the corpus and, when possible, pre-compute row embeddings.

        Args:
            data_path: Path to a JSONL file with one JSON object per line.

        Any failure while loading the model or encoding the corpus is logged
        and leaves ``encoder`` and ``embeddings`` set to ``None``, so that
        ``search`` uses the lexical fallback instead of raising.
        """
        self.data_path = data_path
        self.rows = self._load_rows(data_path)
        self.encoder = None
        self.embeddings = None
        if SentenceTransformer is not None and self.rows:
            try:
                self.encoder = SentenceTransformer(self._MODEL_NAME)
                # Embeddings are requested normalised so that the dot product
                # taken in ``search`` can serve as the similarity score.
                self.embeddings = self.encoder.encode(
                    [r["text"] for r in self.rows],
                    convert_to_numpy=True,
                    normalize_embeddings=True,
                )
            except Exception:
                # Deliberate best-effort: semantic search is optional.
                _logger.warning(
                    "Semantic encoder unavailable; using lexical scoring only.",
                    exc_info=True,
                )
                self.encoder = None
                self.embeddings = None

    @staticmethod
    def _first_text(item: dict, keys, default: str) -> str:
        """Return the first truthy ``item[key]`` as a string, else *default*."""
        for key in keys:
            value = item.get(key)
            if value:
                return value if isinstance(value, str) else str(value)
        return default

    def _load_rows(self, data_path: str) -> List[dict]:
        """Read the JSONL corpus into a list of normalised row dicts.

        Blank lines, lines that are not valid JSON, and lines whose JSON
        value is not an object are skipped. A path that is missing or is not
        a regular file yields an empty list.

        Args:
            data_path: Path to the JSONL file.

        Returns:
            A list of dicts with string-valued keys ``"text"``, ``"topic"``
            and ``"source"``.
        """
        rows: List[dict] = []
        if not os.path.isfile(data_path):
            return rows
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except ValueError:
                    # json.JSONDecodeError is a ValueError subclass.
                    continue
                if not isinstance(item, dict):
                    # Valid JSON but not an object (list, number, string...).
                    continue
                rows.append(
                    {
                        "text": self._first_text(item, ("text",), ""),
                        "topic": self._first_text(
                            item, ("topic", "topic_guess", "section"), "general"
                        ),
                        "source": self._first_text(
                            item,
                            ("source", "source_name", "source_file"),
                            "local_corpus",
                        ),
                    }
                )
        return rows

    def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
        """Compute the additive score bonus for a row.

        All three arguments are stripped and lower-cased before comparison;
        ``None`` is treated as an empty string.

        Args:
            desired_topic: Topic requested by the caller.
            row_topic: Topic stored on the corpus row.
            intent: Intent label of the request.

        Returns:
            The sum of: 1.25 when the desired topic is a substring of the row
            topic; 1.0 when the row topic matches the keyword family of the
            desired topic; 0.25 when the intent is one of the method-style
            intents and the row topic contains one of the method keywords.
        """
        desired_topic = (desired_topic or "").strip().lower()
        row_topic = (row_topic or "").strip().lower()
        intent = (intent or "").strip().lower()

        bonus = 0.0
        if desired_topic and desired_topic in row_topic:
            bonus += 1.25

        if desired_topic == "algebra":
            if row_topic in self._ALGEBRA_TOPICS:
                bonus += 1.0
        elif any(
            keyword in row_topic
            for keyword in self._TOPIC_KEYWORDS.get(desired_topic, ())
        ):
            bonus += 1.0

        if intent in self._METHOD_INTENTS and any(
            keyword in row_topic for keyword in self._METHOD_TOPIC_KEYWORDS
        ):
            bonus += 0.25
        return bonus

    def _select_candidates(
        self, normalized_topic: str
    ) -> tuple[List[int] | None, List[dict]]:
        """Narrow the corpus to rows whose topic matches *normalized_topic*.

        Exact topic matches are preferred; otherwise rows where either topic
        contains the other are used. Rows whose topic is empty after
        stripping never count as a partial match.

        Returns:
            ``(indices, rows)``. ``indices`` is ``None`` when nothing matched,
            in which case ``rows`` is the whole corpus.
        """
        exact: List[int] = []
        partial: List[int] = []
        for index, row in enumerate(self.rows):
            row_topic = (row.get("topic") or "").strip().lower()
            if row_topic == normalized_topic:
                exact.append(index)
            if row_topic and (
                normalized_topic in row_topic or row_topic in normalized_topic
            ):
                partial.append(index)
        chosen = exact or partial
        if not chosen:
            return None, self.rows
        return chosen, [self.rows[index] for index in chosen]

    def search(
        self,
        query: str,
        topic: str = "",
        intent: str = "answer",
        k: int = 3,
    ) -> List[RetrievedChunk]:
        """Return up to *k* corpus chunks ranked for *query*.

        Args:
            query: Query text; passed through ``clean_math_text`` before
                scoring.
            topic: Optional topic used to narrow candidates and to compute
                the topic bonus.
            intent: Intent label forwarded to the bonus computation.
            k: Maximum number of results. Zero or a negative value yields an
                empty list.

        Returns:
            ``RetrievedChunk`` objects sorted by descending score. Empty when
            the corpus is empty.
        """
        if not self.rows:
            return []
        if k is not None and k <= 0:
            # A negative k would otherwise slice from the end of the ranking.
            return []

        combined_query = clean_math_text(query)
        normalized_topic = (topic or "").strip().lower()

        candidate_indices = None
        candidate_rows = self.rows
        if normalized_topic:
            candidate_indices, candidate_rows = self._select_candidates(normalized_topic)

        scores = []
        if self.encoder is not None and self.embeddings is not None and np is not None:
            try:
                q = self.encoder.encode(
                    [combined_query],
                    convert_to_numpy=True,
                    normalize_embeddings=True,
                )[0]
                if candidate_indices is None:
                    candidate_embeddings = self.embeddings
                else:
                    candidate_embeddings = self.embeddings[candidate_indices]
                semantic_scores = (candidate_embeddings @ q).tolist()
                scores = [
                    (
                        self._SEMANTIC_WEIGHT * sem
                        + self._LEXICAL_WEIGHT
                        * score_token_overlap(combined_query, row["text"])
                        + self._topic_bonus(normalized_topic, row["topic"], intent),
                        row,
                    )
                    for row, sem in zip(candidate_rows, semantic_scores)
                ]
            except Exception:
                # Deliberate best-effort: fall through to lexical scoring.
                _logger.warning(
                    "Semantic scoring failed; falling back to lexical scoring.",
                    exc_info=True,
                )
                scores = []

        if not scores:
            scores = [
                (
                    score_token_overlap(combined_query, row["text"])
                    + self._topic_bonus(normalized_topic, row["topic"], intent),
                    row,
                )
                for row in candidate_rows
            ]

        # list.sort is stable, so ties keep corpus order.
        scores.sort(key=lambda pair: pair[0], reverse=True)
        return [
            RetrievedChunk(
                text=row["text"],
                topic=row["topic"],
                source=row["source"],
                score=float(score),
            )
            for score, row in scores[:k]
        ]