Update retrieval_engine.py
Browse files- retrieval_engine.py +12 -4
retrieval_engine.py
CHANGED
|
@@ -69,6 +69,7 @@ class RetrievalEngine:
|
|
| 69 |
),
|
| 70 |
}
|
| 71 |
)
|
|
|
|
| 72 |
|
| 73 |
def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
|
| 74 |
desired_topic = (desired_topic or "").lower()
|
|
@@ -136,7 +137,6 @@ class RetrievalEngine:
|
|
| 136 |
combined_query = clean_math_text(query)
|
| 137 |
normalized_topic = (topic or "").strip().lower()
|
| 138 |
|
| 139 |
-
# Narrow search pool by topic if possible
|
| 140 |
candidate_rows = self.rows
|
| 141 |
candidate_indices = None
|
| 142 |
|
|
@@ -145,9 +145,17 @@ class RetrievalEngine:
|
|
| 145 |
(i, row) for i, row in enumerate(self.rows)
|
| 146 |
if (row.get("topic") or "").strip().lower() == normalized_topic
|
| 147 |
]
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
scores = []
|
| 153 |
|
|
|
|
| 69 |
),
|
| 70 |
}
|
| 71 |
)
|
| 72 |
+
return rows
|
| 73 |
|
| 74 |
def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
|
| 75 |
desired_topic = (desired_topic or "").lower()
|
|
|
|
| 137 |
combined_query = clean_math_text(query)
|
| 138 |
normalized_topic = (topic or "").strip().lower()
|
| 139 |
|
|
|
|
| 140 |
candidate_rows = self.rows
|
| 141 |
candidate_indices = None
|
| 142 |
|
|
|
|
| 145 |
(i, row) for i, row in enumerate(self.rows)
|
| 146 |
if (row.get("topic") or "").strip().lower() == normalized_topic
|
| 147 |
]
|
| 148 |
+
|
| 149 |
+
partial_topic_rows = [
|
| 150 |
+
(i, row) for i, row in enumerate(self.rows)
|
| 151 |
+
if normalized_topic in (row.get("topic") or "").strip().lower()
|
| 152 |
+
or (row.get("topic") or "").strip().lower() in normalized_topic
|
| 153 |
+
]
|
| 154 |
+
|
| 155 |
+
chosen_rows = exact_topic_rows or partial_topic_rows
|
| 156 |
+
if chosen_rows:
|
| 157 |
+
candidate_indices = [i for i, _ in chosen_rows]
|
| 158 |
+
candidate_rows = [row for _, row in chosen_rows]
|
| 159 |
|
| 160 |
scores = []
|
| 161 |
|