Update src/qa.py
Browse files
src/qa.py
CHANGED
|
@@ -215,20 +215,27 @@ REASONING_PROMPT = (
|
|
| 215 |
# ==========================================================
|
| 216 |
from vectorstore import build_faiss_index
|
| 217 |
|
| 218 |
-
def retrieve_chunks(query: str, index, chunks: list, top_k: int =
|
| 219 |
-
min_similarity: float = 0.
|
| 220 |
embeddings: list = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
if not index or not chunks:
|
| 222 |
print("⚠️ No FAISS index or chunks provided — returning empty result.")
|
| 223 |
return []
|
| 224 |
|
| 225 |
try:
|
|
|
|
| 226 |
q_emb = _query_model.encode(
|
| 227 |
[f"query: {query.strip()}"],
|
| 228 |
convert_to_numpy=True,
|
| 229 |
normalize_embeddings=True
|
| 230 |
)[0]
|
| 231 |
|
|
|
|
| 232 |
if hasattr(index, "d") and q_emb.shape[0] != index.d:
|
| 233 |
print(f"⚠️ FAISS dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
|
| 234 |
if embeddings:
|
|
@@ -237,41 +244,56 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
|
|
| 237 |
else:
|
| 238 |
return []
|
| 239 |
|
|
|
|
| 240 |
num_candidates = max(top_k * candidate_multiplier, top_k + 2)
|
| 241 |
distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
|
| 242 |
candidate_indices = [int(i) for i in indices[0] if i >= 0]
|
| 243 |
-
candidate_indices = list(dict.fromkeys(candidate_indices))
|
| 244 |
|
|
|
|
| 245 |
doc_embs = _query_model.encode(
|
| 246 |
[f"passage: {chunks[i]}" for i in candidate_indices],
|
| 247 |
convert_to_numpy=True,
|
| 248 |
normalize_embeddings=True,
|
| 249 |
)
|
| 250 |
sims = cosine_similarity([q_emb], doc_embs)[0]
|
|
|
|
| 251 |
boosted_sims = []
|
| 252 |
for idx, sim in zip(candidate_indices, sims):
|
| 253 |
text = chunks[idx].strip()
|
| 254 |
if re.match(r"^[-•\d]+[\.\s]", text):
|
| 255 |
-
sim += 0.05
|
| 256 |
boosted_sims.append((idx, sim))
|
| 257 |
|
| 258 |
ranked = sorted(boosted_sims, key=lambda x: x[1], reverse=True)
|
|
|
|
|
|
|
| 259 |
filtered = [idx for idx, sim in ranked if sim >= min_similarity][:top_k]
|
| 260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
neighbors = set()
|
| 262 |
for idx in filtered:
|
| 263 |
for n in [idx - 1, idx + 1]:
|
| 264 |
if 0 <= n < len(chunks):
|
| 265 |
neighbors.add(n)
|
| 266 |
filtered = sorted(set(filtered) | neighbors)
|
|
|
|
|
|
|
| 267 |
final_chunks = [chunks[i] for i in filtered]
|
| 268 |
-
|
|
|
|
| 269 |
return final_chunks
|
| 270 |
|
| 271 |
except Exception as e:
|
| 272 |
print(f"⚠️ Retrieval error: {repr(e)}")
|
| 273 |
return []
|
| 274 |
|
|
|
|
| 275 |
# ==========================================================
|
| 276 |
# 8️⃣ Answer Generation (Lazy GPT-4o Initialization)
|
| 277 |
# ==========================================================
|
|
|
|
| 215 |
# ==========================================================
|
| 216 |
from vectorstore import build_faiss_index
|
| 217 |
|
| 218 |
+
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 7,
                    min_similarity: float = 0.4, candidate_multiplier: int = 3,
                    embeddings: list = None):
    """
    Retrieve the most relevant text chunks using FAISS similarity + reranking.

    Pipeline: encode the query (E5-style ``"query: "`` prefix, L2-normalized),
    over-fetch ``top_k * candidate_multiplier`` candidates from FAISS, re-rank
    them by cosine similarity (with a small boost for bullet/numbered lines),
    filter by ``min_similarity`` with a top-``top_k`` fallback, then widen the
    selection with adjacent chunks for reading continuity.

    Args:
        query: Natural-language question.
        index: FAISS index built over the chunk embeddings.
        chunks: Text chunks aligned row-for-row with the index.
        top_k: Maximum chunks kept after threshold filtering.
        min_similarity: Cosine-similarity cutoff for ranked candidates.
        candidate_multiplier: Over-fetch factor for the initial FAISS search.
        embeddings: Optional raw embeddings used to rebuild the index when a
            dimension mismatch is detected.

    Returns:
        list[str]: Selected chunk texts, in document order; ``[]`` on missing
        input or any retrieval error.
    """
    if not index or not chunks:
        print("⚠️ No FAISS index or chunks provided — returning empty result.")
        return []

    try:
        # --- Encode query
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True
        )[0]

        # --- Rebuild index if mismatch occurs
        if hasattr(index, "d") and q_emb.shape[0] != index.d:
            print(f"⚠️ FAISS dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
            if embeddings:
                # NOTE(review): the rebuild step was elided in the diff view;
                # assumed to recreate the index from the supplied embeddings
                # via the imported build_faiss_index — confirm against repo.
                index = build_faiss_index(embeddings)
            else:
                return []

        # --- Retrieve top candidate chunks (over-fetch so reranking has slack)
        num_candidates = max(top_k * candidate_multiplier, top_k + 2)
        distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
        candidate_indices = [int(i) for i in indices[0] if i >= 0]
        candidate_indices = list(dict.fromkeys(candidate_indices))  # remove duplicates

        # FIX: guard the empty-candidate case — otherwise encode([]) is called
        # below and np.mean([...]) warns / yields nan in the summary log line.
        if not candidate_indices:
            print("⚠️ FAISS returned no candidates — returning empty result.")
            return []

        # --- Re-rank using cosine similarity
        doc_embs = _query_model.encode(
            [f"passage: {chunks[i]}" for i in candidate_indices],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        sims = cosine_similarity([q_emb], doc_embs)[0]

        boosted_sims = []
        for idx, sim in zip(candidate_indices, sims):
            text = chunks[idx].strip()
            if re.match(r"^[-•\d]+[\.\s]", text):
                sim += 0.05  # slight boost for procedural bullets
            boosted_sims.append((idx, sim))

        ranked = sorted(boosted_sims, key=lambda x: x[1], reverse=True)

        # --- Filter based on similarity threshold
        filtered = [idx for idx, sim in ranked if sim >= min_similarity][:top_k]

        # --- Fallback: if no matches above threshold, pick top_k anyway
        if not filtered:
            print(f"⚠️ No chunks ≥ {min_similarity:.2f} — using top {top_k} ranked chunks instead.")
            filtered = [idx for idx, _ in ranked[:top_k]]

        # --- Neighbor continuity: include chunks adjacent to each hit
        neighbors = set()
        for idx in filtered:
            for n in (idx - 1, idx + 1):
                if 0 <= n < len(chunks):
                    neighbors.add(n)
        filtered = sorted(set(filtered) | neighbors)

        # --- Return final chunk set (ranked is guaranteed non-empty here)
        final_chunks = [chunks[i] for i in filtered]
        avg_sim = np.mean([s for _, s in ranked[:top_k]])
        print(f"✅ Retrieved {len(final_chunks)} chunks | avg_sim={avg_sim:.3f} | threshold={min_similarity:.2f}")
        return final_chunks

    except Exception as e:
        print(f"⚠️ Retrieval error: {repr(e)}")
        return []
|
| 295 |
|
| 296 |
+
|
| 297 |
# ==========================================================
|
| 298 |
# 8️⃣ Answer Generation (Lazy GPT-4o Initialization)
|
| 299 |
# ==========================================================
|