j-js committed on
Commit
496d977
·
verified ·
1 Parent(s): ba0bf0a

Update retrieval_engine.py

Browse files
Files changed (1) hide show
  1. retrieval_engine.py +67 -44
retrieval_engine.py CHANGED
@@ -114,52 +114,75 @@ class RetrievalEngine:
114
  return bonus
115
 
116
  def search(
117
- self,
118
- query: str,
119
- topic: str = "",
120
- intent: str = "answer",
121
- k: int = 3,
122
- ) -> List[RetrievedChunk]:
123
- if not self.rows:
124
- return []
125
-
126
- combined_query = clean_math_text(query)
127
- scores = []
128
-
129
- if self.encoder is not None and self.embeddings is not None and np is not None:
130
- try:
131
- q = self.encoder.encode(
132
- [combined_query],
133
- convert_to_numpy=True,
134
- normalize_embeddings=True,
135
- )[0]
136
- semantic_scores = self.embeddings @ q
137
-
138
- for row, sem in zip(self.rows, semantic_scores.tolist()):
139
- lexical = score_token_overlap(combined_query, row["text"])
140
- bonus = self._topic_bonus(topic, row["topic"], intent)
141
- total = 0.7 * sem + 0.3 * lexical + bonus
142
- scores.append((total, row))
143
- except Exception:
144
- scores = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- if not scores:
147
- for row in self.rows:
148
  lexical = score_token_overlap(combined_query, row["text"])
149
  bonus = self._topic_bonus(topic, row["topic"], intent)
150
- scores.append((lexical + bonus, row))
151
-
152
- scores.sort(key=lambda x: x[0], reverse=True)
153
-
154
- results: List[RetrievedChunk] = []
155
- for score, row in scores[:k]:
156
- results.append(
157
- RetrievedChunk(
158
- text=row["text"],
159
- topic=row["topic"],
160
- source=row["source"],
161
- score=float(score),
162
- )
 
 
 
 
 
 
 
 
163
  )
 
164
 
165
- return results
 
114
  return bonus
115
 
116
  def search(
117
+ self,
118
+ query: str,
119
+ topic: str = "",
120
+ intent: str = "answer",
121
+ k: int = 3,
122
+ ) -> List[RetrievedChunk]:
123
+ if not self.rows:
124
+ return []
125
+
126
+ combined_query = clean_math_text(query)
127
+ normalized_topic = (topic or "").strip().lower()
128
+
129
+ # First narrow the pool when we have a specific topic.
130
+ candidate_rows = self.rows
131
+ if normalized_topic:
132
+ exact_topic_rows = [
133
+ row for row in self.rows
134
+ if (row.get("topic") or "").strip().lower() == normalized_topic
135
+ ]
136
+ if exact_topic_rows:
137
+ candidate_rows = exact_topic_rows
138
+
139
+ scores = []
140
+
141
+ if self.encoder is not None and self.embeddings is not None and np is not None:
142
+ try:
143
+ q = self.encoder.encode(
144
+ [combined_query],
145
+ convert_to_numpy=True,
146
+ normalize_embeddings=True,
147
+ )[0]
148
+
149
+ # If we filtered rows, we must also filter embeddings to the same indices.
150
+ if candidate_rows is self.rows:
151
+ candidate_embeddings = self.embeddings
152
+ else:
153
+ candidate_indices = [
154
+ i for i, row in enumerate(self.rows)
155
+ if (row.get("topic") or "").strip().lower() == normalized_topic
156
+ ]
157
+ candidate_embeddings = self.embeddings[candidate_indices]
158
+
159
+ semantic_scores = candidate_embeddings @ q
160
 
161
+ for row, sem in zip(candidate_rows, semantic_scores.tolist()):
 
162
  lexical = score_token_overlap(combined_query, row["text"])
163
  bonus = self._topic_bonus(topic, row["topic"], intent)
164
+ total = 0.7 * sem + 0.3 * lexical + bonus
165
+ scores.append((total, row))
166
+ except Exception:
167
+ scores = []
168
+
169
+ if not scores:
170
+ for row in candidate_rows:
171
+ lexical = score_token_overlap(combined_query, row["text"])
172
+ bonus = self._topic_bonus(topic, row["topic"], intent)
173
+ scores.append((lexical + bonus, row))
174
+
175
+ scores.sort(key=lambda x: x[0], reverse=True)
176
+
177
+ results: List[RetrievedChunk] = []
178
+ for score, row in scores[:k]:
179
+ results.append(
180
+ RetrievedChunk(
181
+ text=row["text"],
182
+ topic=row["topic"],
183
+ source=row["source"],
184
+ score=float(score),
185
  )
186
+ )
187
 
188
+ return results