Update src/qa.py
Browse files
src/qa.py
CHANGED
|
@@ -85,71 +85,66 @@ REASONING_PROMPT = (
|
|
| 85 |
# ==========================================================
|
| 86 |
# 5️⃣ Retrieve Chunks (FAISS + Rerank + Neighbor Expansion)
|
| 87 |
# ==========================================================
|
| 88 |
-
def retrieve_chunks(
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
chunks
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
expansion_window: int = 1,
|
| 96 |
-
max_context_chunks: int = 6,
|
| 97 |
-
):
|
| 98 |
-
"""Retrieve semantically relevant chunks with reranking and neighbor expansion."""
|
| 99 |
if not index or not chunks:
|
| 100 |
return []
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
return [chunks[i] for i in final_order]
|
| 153 |
|
| 154 |
# ==========================================================
|
| 155 |
# 6️⃣ Answer Generation
|
|
|
|
| 85 |
# ==========================================================
|
| 86 |
# 5️⃣ Retrieve Chunks (FAISS + Rerank + Neighbor Expansion)
|
| 87 |
# ==========================================================
|
| 88 |
+
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5, min_similarity: float = 0.6):
    """
    Hybrid retrieval of context chunks for a query.

    1️⃣ Get semantic candidates via FAISS (over-fetch 2x for reranking).
    2️⃣ Re-rank by cosine similarity and apply a minimum similarity filter.
    3️⃣ If fewer than top_k remain, fill remaining seats with adjacent chunks (±1) for continuity.

    Args:
        query: Natural-language question to retrieve context for.
        index: FAISS index whose vectors are aligned with ``chunks`` (may be None).
        chunks: Chunk texts, parallel to the index vectors.
        top_k: Maximum number of chunks to return.
        min_similarity: Cosine-similarity threshold a candidate must meet.

    Returns:
        At most ``top_k`` chunk strings; an empty list when index/chunks are
        missing, nothing is retrieved, or retrieval raises.
    """
    if not index or not chunks:
        return []

    try:
        # Encode query with the "query:" prefix — assumes _query_model is an
        # asymmetric (E5-style) encoder; verify against the model loaded above.
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]

        # Step 1️⃣ — FAISS initial retrieval (2x over-fetch so the rerank has slack).
        _distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)

        # FAISS pads the result with -1 when the index holds fewer vectors than
        # requested; drop the padding and any duplicate ids, preserving order.
        seen = set()
        retrieved_indices = []
        for i in indices[0]:
            i = int(i)
            if i >= 0 and i not in seen:
                seen.add(i)
                retrieved_indices.append(i)
        if not retrieved_indices:
            return []

        # Step 2️⃣ — Compute cosine similarity for re-ranking.
        retrieved_texts = [chunks[i] for i in retrieved_indices]
        doc_embs = _query_model.encode(
            [f"passage: {c}" for c in retrieved_texts],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )

        sims = cosine_similarity([q_emb], doc_embs)[0]
        ranked = sorted(zip(retrieved_indices, sims), key=lambda x: x[1], reverse=True)

        # Step 3️⃣ — Apply minimum similarity filter, capped at top_k (we
        # over-fetched 2x, so more than top_k candidates can pass the filter).
        filtered_indices = [idx for idx, score in ranked if score >= min_similarity][:top_k]

        # Step 4️⃣ — If not enough survivors, add ±1 neighbors for continuity.
        if len(filtered_indices) < top_k:
            expanded_indices = set(filtered_indices)
            for idx in filtered_indices:
                for neighbor in (idx - 1, idx + 1):
                    if 0 <= neighbor < len(chunks):
                        expanded_indices.add(neighbor)
                    if len(expanded_indices) >= top_k:
                        break
                if len(expanded_indices) >= top_k:
                    break
            # Document order so neighboring chunks read contiguously.
            filtered_indices = sorted(expanded_indices)[:top_k]

        # Step 5️⃣ — Build final ordered list of chunks.
        final_chunks = [chunks[i] for i in filtered_indices]

        print(f"✅ Retrieved {len(final_chunks)} chunks (semantic + neighbor fill).")
        return final_chunks

    except Exception as e:
        # Best-effort retrieval: never crash the QA pipeline on a bad query.
        print(f"⚠️ Retrieval error: {e}")
        return []
|
| 147 |
+
|
|
|
|
| 148 |
|
| 149 |
# ==========================================================
|
| 150 |
# 6️⃣ Answer Generation
|