Update src/qa.py
Browse files
src/qa.py
CHANGED
|
@@ -89,24 +89,48 @@ REASONING_PROMPT = (
|
|
| 89 |
)
|
| 90 |
|
| 91 |
# ==========================================================
|
| 92 |
-
# 5️⃣ Retrieval — FAISS +
|
| 93 |
# ==========================================================
|
|
|
|
|
|
|
| 94 |
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
|
| 95 |
-
min_similarity: float = 0.6, candidate_multiplier: int = 3
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
if not index or not chunks:
|
|
|
|
| 98 |
return []
|
| 99 |
|
| 100 |
try:
|
|
|
|
| 101 |
q_emb = _query_model.encode(
|
| 102 |
-
[f"query: {query.strip()}"],
|
|
|
|
|
|
|
| 103 |
)[0]
|
| 104 |
|
| 105 |
-
#
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
doc_embs = _query_model.encode(
|
| 111 |
[f"passage: {chunks[i]}" for i in candidate_indices],
|
| 112 |
convert_to_numpy=True,
|
|
@@ -115,28 +139,31 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
|
|
| 115 |
sims = cosine_similarity([q_emb], doc_embs)[0]
|
| 116 |
ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)
|
| 117 |
|
| 118 |
-
# 3️⃣
|
| 119 |
-
filtered = [idx for idx, sim in ranked if sim >= min_similarity]
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
# 4️⃣ Neighbor fill if not enough
|
| 122 |
if len(filtered) < top_k:
|
| 123 |
expanded = set(filtered)
|
| 124 |
for idx in filtered:
|
| 125 |
-
for
|
| 126 |
-
if 0 <=
|
| 127 |
-
expanded.add(
|
| 128 |
if len(expanded) >= top_k:
|
| 129 |
break
|
| 130 |
if len(expanded) >= top_k:
|
| 131 |
break
|
| 132 |
filtered = sorted(expanded)[:top_k]
|
| 133 |
|
|
|
|
| 134 |
final_chunks = [chunks[i] for i in filtered]
|
| 135 |
-
print(f"✅ Retrieved {len(final_chunks)} chunks (semantic + neighbor fill)")
|
| 136 |
return final_chunks
|
| 137 |
|
| 138 |
except Exception as e:
|
| 139 |
-
print(f"⚠️ Retrieval error: {e}")
|
| 140 |
return []
|
| 141 |
|
| 142 |
# ==========================================================
|
|
|
|
| 89 |
)
|
| 90 |
|
| 91 |
# ==========================================================
# 5️⃣ Retrieval — FAISS + Re-rank + Neighbor Fill (Auto-Healing)
# ==========================================================

def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
                    min_similarity: float = 0.6, candidate_multiplier: int = 3,
                    embeddings: list = None):
    """
    Retrieve the chunks most relevant to *query* from a FAISS index.

    Pipeline: FAISS candidate search -> cosine-similarity re-rank ->
    similarity-threshold filter -> neighbor fill (adjacent chunks are pulled
    in for context continuity when too few survive the filter).  If the query
    embedding's dimension does not match the index, the index is rebuilt from
    *embeddings* when available (auto-healing).

    Args:
        query: Natural-language query text.
        index: FAISS index over the chunk embeddings.
        chunks: Chunk texts, positionally aligned with the index vectors.
        top_k: Maximum number of chunks to return.
        min_similarity: Minimum cosine similarity for a chunk to be kept.
        candidate_multiplier: Over-fetch factor for the initial FAISS search.
        embeddings: Optional raw embeddings used to rebuild a mismatched index.

    Returns:
        list[str]: Up to *top_k* chunk texts (empty list on any failure).
    """
    if not index or not chunks:
        print("⚠️ No FAISS index or chunks provided — returning empty result.")
        return []

    try:
        # Encode query embedding (E5-style "query: " prefix, L2-normalized so
        # cosine similarity below is meaningful).
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True
        )[0]

        # ✅ Sanity check: dimension match between query and FAISS index.
        if hasattr(index, "d") and q_emb.shape[0] != index.d:
            print(f"⚠️ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
            # Use an explicit None/len check — a bare `if embeddings:` raises
            # "truth value of an array is ambiguous" for numpy arrays.
            if embeddings is not None and len(embeddings) > 0:
                print("🔄 Rebuilding FAISS index to match embedding dimensions...")
                # Local import: the project module is only needed on this
                # recovery path.
                from vectorstore import build_faiss_index
                index = build_faiss_index(embeddings)
                # NOTE(review): the rebuilt index is local to this call; the
                # caller's stale index object is not replaced — confirm callers
                # rebuild/persist it themselves if mismatches recur.
                print("✅ FAISS index successfully rebuilt.")
            else:
                print("❌ No embeddings available to rebuild FAISS index.")
                return []

        # Step 1️⃣ — Initial FAISS retrieval (over-fetch so the re-rank has
        # candidates to discard).
        num_candidates = max(top_k * candidate_multiplier, top_k + 2)
        _dists, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
        candidate_indices = [int(i) for i in indices[0] if i >= 0]  # FAISS pads misses with -1
        candidate_indices = list(dict.fromkeys(candidate_indices))  # de-dupe, keep rank order
        if not candidate_indices:
            # Guard: encoding an empty passage list would crash the re-rank.
            print("⚠️ FAISS returned no candidates — returning empty result.")
            return []

        # Step 2️⃣ — Re-rank candidates by cosine similarity to the query.
        doc_embs = _query_model.encode(
            [f"passage: {chunks[i]}" for i in candidate_indices],
            convert_to_numpy=True,
            normalize_embeddings=True
        )
        sims = cosine_similarity([q_emb], doc_embs)[0]
        ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)

        # Step 3️⃣ — Keep only candidates above the similarity threshold,
        # capped at top_k.
        filtered = [idx for idx, sim in ranked if sim >= min_similarity]
        if len(filtered) > top_k:
            filtered = filtered[:top_k]

        # Step 4️⃣ — Neighbor fill: pad with chunks adjacent to the survivors
        # for context continuity when the filter left fewer than top_k.
        if len(filtered) < top_k:
            expanded = set(filtered)
            for idx in filtered:
                for neighbor in (idx - 1, idx + 1):
                    if 0 <= neighbor < len(chunks):
                        expanded.add(neighbor)
                    if len(expanded) >= top_k:
                        break
                if len(expanded) >= top_k:
                    break
            # Once neighbors are mixed in, return document order (sorted by
            # index), not similarity order.
            filtered = sorted(expanded)[:top_k]

        # Step 5️⃣ — Materialize the final chunk texts.
        final_chunks = [chunks[i] for i in filtered]
        print(f"✅ Retrieved {len(final_chunks)} chunks (semantic + neighbor fill).")
        return final_chunks

    except Exception as e:
        # Best-effort retrieval: log the failure and degrade to "nothing found"
        # rather than crashing the QA pipeline.
        print(f"⚠️ Retrieval error: {repr(e)}")
        return []
|
| 168 |
|
| 169 |
# ==========================================================
|