File size: 6,716 Bytes
e9462cd
 
 
fa70564
a214825
e9462cd
fa70564
 
e9462cd
 
fa70564
e9462cd
fa70564
e9462cd
 
fa70564
e9462cd
 
 
 
 
fa70564
 
 
 
e9462cd
a214825
fa70564
 
 
a214825
 
 
 
 
fa70564
 
 
 
a214825
 
fa70564
 
a214825
fa70564
 
 
 
 
 
 
 
 
a214825
 
 
 
138a82b
 
 
 
 
 
 
 
 
 
 
 
a214825
 
2ed1ad1
fa70564
 
 
 
 
a214825
fa70564
a214825
fa70564
 
a214825
fa70564
 
a214825
fa70564
 
a214825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa70564
a214825
fa70564
 
0b3cc65
 
 
 
 
 
 
55b1e0c
0b3cc65
4ff13e6
496d977
55b1e0c
 
 
 
 
 
 
 
 
 
 
2ed1ad1
 
 
 
 
 
 
 
 
 
 
496d977
55b1e0c
fa70564
55b1e0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa70564
 
55b1e0c
 
 
 
 
 
 
 
 
 
 
 
 
a214825
 
55b1e0c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from __future__ import annotations

import json
import os
from typing import List

from models import RetrievedChunk
from utils import clean_math_text, score_token_overlap

try:
    import numpy as np
except Exception:
    np = None

try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SentenceTransformer = None


class RetrievalEngine:
    """Rank chunks from a local JSONL corpus against a free-text query.

    Rows are loaded once at construction time. When ``SentenceTransformer``
    was importable and the corpus is non-empty, every row's text is embedded
    up front and ``search`` blends the dot product of normalized embeddings
    with a lexical-overlap score; otherwise ``search`` ranks on lexical
    overlap alone. A topic/intent bonus is added in both modes.
    """

    _EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    # JSON keys probed, in priority order, for a row's topic and source.
    _TOPIC_KEYS = ("topic", "topic_guess", "section")
    _SOURCE_KEYS = ("source", "source_name", "source_file")

    # Row topics that count as an algebra match (exact match, not substring).
    _ALGEBRA_TOPICS = frozenset({"algebra", "linear equations", "equations"})

    # Requested topic -> substrings that earn the related-topic bonus when
    # found in the row's topic.
    _RELATED_KEYWORDS = {
        "percent": ("percent",),
        "number_theory": ("number", "divisible", "remainder", "prime", "factor"),
        "number_properties": ("number", "divisible", "remainder", "prime", "factor"),
        "geometry": ("geometry", "circle", "triangle", "area", "perimeter"),
        "probability": ("probability",),
        "statistics": ("statistics", "mean", "median", "average", "distribution"),
    }

    # Intents that ask for an explanation rather than a bare answer.
    _METHOD_INTENTS = frozenset(
        {"method", "step_by_step", "full_working", "hint", "walkthrough", "instruction"}
    )

    # Row-topic substrings that earn the small bonus for method-style intents.
    _METHOD_TOPIC_KEYWORDS = (
        "algebra",
        "percent",
        "fractions",
        "word_problems",
        "general",
        "ratio",
        "probability",
        "statistics",
    )

    def __init__(self, data_path: str = "data/gmat_hf_chunks.jsonl"):
        """Load the corpus and, when possible, pre-compute row embeddings.

        Args:
            data_path: Path to a JSONL file with one chunk object per line.
                A missing file yields an empty engine rather than an error.
        """
        self.data_path = data_path
        self.rows = self._load_rows(data_path)
        self.encoder = None
        self.embeddings = None

        if SentenceTransformer is not None and self.rows:
            try:
                self.encoder = SentenceTransformer(self._EMBEDDING_MODEL)
                self.embeddings = self.encoder.encode(
                    [r["text"] for r in self.rows],
                    convert_to_numpy=True,
                    normalize_embeddings=True,
                )
            except Exception:
                # Deliberate best-effort: if the model cannot be loaded or
                # the corpus cannot be encoded, run in lexical-only mode.
                self.encoder = None
                self.embeddings = None

    @staticmethod
    def _first_text(item: dict, keys, default: str) -> str:
        """Return the first truthy ``item[key]`` as a string, else ``default``."""
        for key in keys:
            value = item.get(key)
            if value:
                return value if isinstance(value, str) else str(value)
        return default

    def _load_rows(self, data_path: str) -> List[dict]:
        """Read the JSONL corpus into ``{"text", "topic", "source"}`` dicts.

        Blank lines, lines that are not valid JSON, and JSON values that are
        not objects are skipped. A path that is not a regular file (missing,
        or a directory) yields an empty list.

        Args:
            data_path: Path to the JSONL corpus.

        Returns:
            One dict per usable line, in file order. ``text`` is always a
            string (``""`` when absent or null); ``topic`` defaults to
            ``"general"`` and ``source`` to ``"local_corpus"``.
        """
        rows: List[dict] = []
        if not os.path.isfile(data_path):
            return rows

        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except (ValueError, RecursionError):
                    # Malformed or pathologically nested line: skip it.
                    continue
                if not isinstance(item, dict):
                    # Valid JSON but not an object (list, string, number...);
                    # it has no fields to read.
                    continue

                text = item.get("text", "")
                if not isinstance(text, str):
                    # Keep downstream scoring/encoding on plain strings.
                    text = "" if text is None else str(text)

                rows.append(
                    {
                        "text": text,
                        "topic": self._first_text(item, self._TOPIC_KEYS, "general"),
                        "source": self._first_text(
                            item, self._SOURCE_KEYS, "local_corpus"
                        ),
                    }
                )
        return rows

    def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
        """Compute the additive score bonus for a row's topic.

        All three inputs are compared case-insensitively with surrounding
        whitespace ignored, matching how ``search`` filters rows by topic.

        Args:
            desired_topic: Topic requested by the caller (may be empty).
            row_topic: Topic recorded on the candidate row.
            intent: Caller intent, e.g. ``"answer"`` or ``"hint"``.

        Returns:
            Sum of: 1.25 when the requested topic is a substring of the row
            topic; 1.0 when the row topic is related to the requested topic
            (algebra by exact row-topic match, other topics by keyword
            substring); 0.25 when the intent is method-style and the row
            topic contains one of the method keywords.
        """
        desired = (desired_topic or "").strip().lower()
        row = (row_topic or "").strip().lower()
        intent_key = (intent or "").strip().lower()

        bonus = 0.0

        if desired and desired in row:
            bonus += 1.25

        if desired == "algebra" and row in self._ALGEBRA_TOPICS:
            bonus += 1.0

        related = self._RELATED_KEYWORDS.get(desired, ())
        if any(keyword in row for keyword in related):
            bonus += 1.0

        if intent_key in self._METHOD_INTENTS and any(
            keyword in row for keyword in self._METHOD_TOPIC_KEYWORDS
        ):
            bonus += 0.25

        return bonus

    def _semantic_scores(self, cleaned_query: str, candidate_indices, expected: int):
        """Return one embedding similarity per candidate, or ``None``.

        ``None`` means the semantic path is unavailable or failed and the
        caller should rank lexically instead.

        Args:
            cleaned_query: Query text, already passed through the cleaner.
            candidate_indices: Indices into ``self.rows`` to score, or
                ``None`` to score every row.
            expected: Number of candidate rows; a result of any other length
                is rejected so scores never pair with the wrong rows.
        """
        if self.encoder is None or self.embeddings is None or np is None:
            return None
        try:
            query_vec = self.encoder.encode(
                [cleaned_query],
                convert_to_numpy=True,
                normalize_embeddings=True,
            )[0]
            if candidate_indices is None:
                matrix = self.embeddings
            else:
                matrix = self.embeddings[candidate_indices]
            similarities = (matrix @ query_vec).tolist()
            if len(similarities) != expected:
                return None
            return similarities
        except Exception:
            # Deliberate best-effort: any encoder/array failure downgrades
            # this query to lexical ranking instead of failing the search.
            return None

    def search(
        self,
        query: str,
        topic: str = "",
        intent: str = "answer",
        k: int = 3,
    ) -> List[RetrievedChunk]:
        """Return the ``k`` best-scoring chunks for ``query``.

        When ``topic`` is given, candidates are narrowed to rows whose topic
        equals it; failing that, to rows whose topic contains it or is
        contained in it; failing that, the whole corpus is searched.

        Each candidate scores ``0.7 * semantic + 0.3 * lexical + bonus``
        when embeddings are usable, otherwise ``lexical + bonus``.

        Args:
            query: Raw user query; it is passed through ``clean_math_text``
                before scoring.
            topic: Optional topic filter, matched case-insensitively with
                surrounding whitespace ignored.
            intent: Caller intent forwarded to the topic bonus.
            k: Maximum number of results. Zero or a negative value yields an
                empty list.

        Returns:
            Up to ``k`` ``RetrievedChunk`` objects, highest score first.
            Ties keep corpus order.
        """
        if not self.rows:
            return []
        if k is not None and k <= 0:
            # A negative slice bound would silently drop results from the
            # tail instead of returning nothing.
            return []

        combined_query = clean_math_text(query)
        normalized_topic = (topic or "").strip().lower()

        candidate_rows = self.rows
        candidate_indices = None

        if normalized_topic:
            exact_matches = []
            partial_matches = []
            for index, row in enumerate(self.rows):
                row_topic = (row.get("topic") or "").strip().lower()
                if row_topic == normalized_topic:
                    exact_matches.append(index)
                elif row_topic and (
                    normalized_topic in row_topic or row_topic in normalized_topic
                ):
                    # The non-empty guard matters: "" is a substring of
                    # every topic and would otherwise match everything.
                    partial_matches.append(index)

            chosen = exact_matches or partial_matches
            if chosen:
                candidate_indices = chosen
                candidate_rows = [self.rows[index] for index in chosen]

        # Lexical overlap and bonus are needed in both scoring modes, so
        # compute them once per candidate.
        lexical_scores = [
            score_token_overlap(combined_query, row["text"]) for row in candidate_rows
        ]
        bonuses = [
            self._topic_bonus(normalized_topic, row["topic"], intent)
            for row in candidate_rows
        ]

        semantic_scores = self._semantic_scores(
            combined_query, candidate_indices, len(candidate_rows)
        )
        if semantic_scores is not None:
            totals = [
                0.7 * sem + 0.3 * lexical + bonus
                for sem, lexical, bonus in zip(semantic_scores, lexical_scores, bonuses)
            ]
        else:
            totals = [
                lexical + bonus for lexical, bonus in zip(lexical_scores, bonuses)
            ]

        # sorted() is stable, so equal scores stay in corpus order.
        ranked = sorted(
            zip(totals, candidate_rows), key=lambda pair: pair[0], reverse=True
        )

        return [
            RetrievedChunk(
                text=row["text"],
                topic=row["topic"],
                source=row["source"],
                score=float(score),
            )
            for score, row in ranked[:k]
        ]