j-js commited on
Commit
2ed1ad1
·
verified ·
1 Parent(s): 90ad83b

Update retrieval_engine.py

Browse files
Files changed (1) hide show
  1. retrieval_engine.py +12 -4
retrieval_engine.py CHANGED
@@ -69,6 +69,7 @@ class RetrievalEngine:
69
  ),
70
  }
71
  )
 
72
 
73
  def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
74
  desired_topic = (desired_topic or "").lower()
@@ -136,7 +137,6 @@ class RetrievalEngine:
136
  combined_query = clean_math_text(query)
137
  normalized_topic = (topic or "").strip().lower()
138
 
139
- # Narrow search pool by topic if possible
140
  candidate_rows = self.rows
141
  candidate_indices = None
142
 
@@ -145,9 +145,17 @@ class RetrievalEngine:
145
  (i, row) for i, row in enumerate(self.rows)
146
  if (row.get("topic") or "").strip().lower() == normalized_topic
147
  ]
148
- if exact_topic_rows:
149
- candidate_indices = [i for i, _ in exact_topic_rows]
150
- candidate_rows = [row for _, row in exact_topic_rows]
 
 
 
 
 
 
 
 
151
 
152
  scores = []
153
 
 
69
  ),
70
  }
71
  )
72
+ return rows
73
 
74
  def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
75
  desired_topic = (desired_topic or "").lower()
 
137
  combined_query = clean_math_text(query)
138
  normalized_topic = (topic or "").strip().lower()
139
 
 
140
  candidate_rows = self.rows
141
  candidate_indices = None
142
 
 
145
  (i, row) for i, row in enumerate(self.rows)
146
  if (row.get("topic") or "").strip().lower() == normalized_topic
147
  ]
148
+
149
+ partial_topic_rows = [
150
+ (i, row) for i, row in enumerate(self.rows)
151
+ if normalized_topic in (row.get("topic") or "").strip().lower()
152
+ or (row.get("topic") or "").strip().lower() in normalized_topic
153
+ ]
154
+
155
+ chosen_rows = exact_topic_rows or partial_topic_rows
156
+ if chosen_rows:
157
+ candidate_indices = [i for i, _ in chosen_rows]
158
+ candidate_rows = [row for _, row in chosen_rows]
159
 
160
  scores = []
161