j-js committed on
Commit
55b1e0c
·
verified ·
1 Parent(s): ba68b33

Update retrieval_engine.py

Browse files
Files changed (1) hide show
  1. retrieval_engine.py +56 -58
retrieval_engine.py CHANGED
@@ -120,70 +120,68 @@ class RetrievalEngine:
120
  intent: str = "answer",
121
  k: int = 3,
122
  ) -> List[RetrievedChunk]:
123
-
124
  if not self.rows:
125
  return []
126
 
127
- combined_query = clean_math_text(query)
128
- normalized_topic = (topic or "").strip().lower()
129
-
130
- # First narrow the pool when we have a specific topic.
131
- candidate_rows = self.rows
132
- if normalized_topic:
133
- exact_topic_rows = [
134
- row for row in self.rows
135
- if (row.get("topic") or "").strip().lower() == normalized_topic
136
- ]
137
- if exact_topic_rows:
138
- candidate_rows = exact_topic_rows
139
-
140
- scores = []
141
-
142
- if self.encoder is not None and self.embeddings is not None and np is not None:
143
- try:
144
- q = self.encoder.encode(
145
- [combined_query],
146
- convert_to_numpy=True,
147
- normalize_embeddings=True,
148
- )[0]
149
-
150
- # If we filtered rows, we must also filter embeddings to the same indices.
151
- if candidate_rows is self.rows:
152
- candidate_embeddings = self.embeddings
153
- else:
154
- candidate_indices = [
155
- i for i, row in enumerate(self.rows)
156
- if (row.get("topic") or "").strip().lower() == normalized_topic
157
- ]
158
- candidate_embeddings = self.embeddings[candidate_indices]
159
 
160
- semantic_scores = candidate_embeddings @ q
161
 
162
- for row, sem in zip(candidate_rows, semantic_scores.tolist()):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  lexical = score_token_overlap(combined_query, row["text"])
164
  bonus = self._topic_bonus(topic, row["topic"], intent)
165
- total = 0.7 * sem + 0.3 * lexical + bonus
166
- scores.append((total, row))
167
- except Exception:
168
- scores = []
169
-
170
- if not scores:
171
- for row in candidate_rows:
172
- lexical = score_token_overlap(combined_query, row["text"])
173
- bonus = self._topic_bonus(topic, row["topic"], intent)
174
- scores.append((lexical + bonus, row))
175
-
176
- scores.sort(key=lambda x: x[0], reverse=True)
177
-
178
- results: List[RetrievedChunk] = []
179
- for score, row in scores[:k]:
180
- results.append(
181
- RetrievedChunk(
182
- text=row["text"],
183
- topic=row["topic"],
184
- source=row["source"],
185
- score=float(score),
186
  )
187
- )
188
 
189
- return results
 
120
  intent: str = "answer",
121
  k: int = 3,
122
  ) -> List[RetrievedChunk]:
123
+
124
  if not self.rows:
125
  return []
126
 
127
+ combined_query = clean_math_text(query)
128
+ normalized_topic = (topic or "").strip().lower()
129
+
130
+ # Narrow search pool by topic if possible
131
+ candidate_rows = self.rows
132
+ candidate_indices = None
133
+
134
+ if normalized_topic:
135
+ exact_topic_rows = [
136
+ (i, row) for i, row in enumerate(self.rows)
137
+ if (row.get("topic") or "").strip().lower() == normalized_topic
138
+ ]
139
+ if exact_topic_rows:
140
+ candidate_indices = [i for i, _ in exact_topic_rows]
141
+ candidate_rows = [row for _, row in exact_topic_rows]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ scores = []
144
 
145
+ if self.encoder is not None and self.embeddings is not None and np is not None:
146
+ try:
147
+ q = self.encoder.encode(
148
+ [combined_query],
149
+ convert_to_numpy=True,
150
+ normalize_embeddings=True,
151
+ )[0]
152
+
153
+ if candidate_indices is None:
154
+ candidate_embeddings = self.embeddings
155
+ else:
156
+ candidate_embeddings = self.embeddings[candidate_indices]
157
+
158
+ semantic_scores = candidate_embeddings @ q
159
+
160
+ for row, sem in zip(candidate_rows, semantic_scores.tolist()):
161
+ lexical = score_token_overlap(combined_query, row["text"])
162
+ bonus = self._topic_bonus(topic, row["topic"], intent)
163
+ total = 0.7 * sem + 0.3 * lexical + bonus
164
+ scores.append((total, row))
165
+ except Exception:
166
+ scores = []
167
+
168
+ if not scores:
169
+ for row in candidate_rows:
170
  lexical = score_token_overlap(combined_query, row["text"])
171
  bonus = self._topic_bonus(topic, row["topic"], intent)
172
+ scores.append((lexical + bonus, row))
173
+
174
+ scores.sort(key=lambda x: x[0], reverse=True)
175
+
176
+ results: List[RetrievedChunk] = []
177
+ for score, row in scores[:k]:
178
+ results.append(
179
+ RetrievedChunk(
180
+ text=row["text"],
181
+ topic=row["topic"],
182
+ source=row["source"],
183
+ score=float(score),
184
+ )
 
 
 
 
 
 
 
 
185
  )
 
186
 
187
+ return results