Update src/qa.py
Browse files
src/qa.py
CHANGED
|
@@ -94,16 +94,30 @@ REASONING_PROMPT = (
|
|
| 94 |
|
| 95 |
|
| 96 |
# ==========================================================
|
| 97 |
-
#
|
| 98 |
# ==========================================================
|
| 99 |
from vectorstore import build_faiss_index
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
|
| 102 |
min_similarity: float = 0.6, candidate_multiplier: int = 3,
|
| 103 |
-
embeddings: list = None):
|
| 104 |
"""
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
| 107 |
"""
|
| 108 |
|
| 109 |
if not index or not chunks:
|
|
@@ -111,52 +125,67 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
|
|
| 111 |
return []
|
| 112 |
|
| 113 |
try:
|
| 114 |
-
#
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
[f"query: {query.strip()}"],
|
| 117 |
convert_to_numpy=True,
|
| 118 |
normalize_embeddings=True
|
| 119 |
)[0]
|
| 120 |
-
|
| 121 |
-
# β
Sanity check: dimension match between query and FAISS index
|
| 122 |
-
if hasattr(index, "d") and q_emb.shape[0] != index.d:
|
| 123 |
-
print(f"β οΈ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
|
| 124 |
-
if embeddings:
|
| 125 |
-
print("π Rebuilding FAISS index to match embedding dimensions...")
|
| 126 |
-
index = build_faiss_index(embeddings)
|
| 127 |
-
print("β
FAISS index successfully rebuilt.")
|
| 128 |
-
|
| 129 |
-
# β
Regenerate query embedding now that we have a matching index
|
| 130 |
-
q_emb = _query_model.encode(
|
| 131 |
-
[f"query: {query.strip()}"],
|
| 132 |
-
convert_to_numpy=True,
|
| 133 |
-
normalize_embeddings=True
|
| 134 |
-
)[0]
|
| 135 |
-
else:
|
| 136 |
-
print("β No embeddings available to rebuild FAISS index.")
|
| 137 |
-
return []
|
| 138 |
-
|
| 139 |
-
# Step 1οΈβ£ β Initial FAISS retrieval
|
| 140 |
-
num_candidates = max(top_k * candidate_multiplier, top_k + 2)
|
| 141 |
-
distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
|
| 142 |
-
candidate_indices = [int(i) for i in indices[0] if i >= 0]
|
| 143 |
-
candidate_indices = list(dict.fromkeys(candidate_indices)) # de-dupe
|
| 144 |
-
|
| 145 |
-
# Step 2οΈβ£ β Re-rank by cosine similarity
|
| 146 |
doc_embs = _query_model.encode(
|
| 147 |
[f"passage: {chunks[i]}" for i in candidate_indices],
|
| 148 |
convert_to_numpy=True,
|
| 149 |
normalize_embeddings=True,
|
| 150 |
)
|
| 151 |
-
sims = cosine_similarity([
|
| 152 |
ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)
|
| 153 |
|
| 154 |
-
# Step 3
|
| 155 |
-
filtered = [idx for idx, sim in ranked if sim >=
|
| 156 |
-
if
|
| 157 |
-
filtered =
|
| 158 |
|
| 159 |
-
# Step 4
|
| 160 |
if len(filtered) < top_k:
|
| 161 |
expanded = set(filtered)
|
| 162 |
for idx in filtered:
|
|
@@ -167,11 +196,25 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
|
|
| 167 |
break
|
| 168 |
if len(expanded) >= top_k:
|
| 169 |
break
|
| 170 |
-
filtered = sorted(expanded)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
-
# Step
|
| 173 |
final_chunks = [chunks[i] for i in filtered]
|
| 174 |
-
print(f"β
Retrieved {len(final_chunks)} chunks (
|
| 175 |
return final_chunks
|
| 176 |
|
| 177 |
except Exception as e:
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
# ==========================================================
|
| 97 |
+
# π Improved Retrieval β Multi-Span Query + Adaptive Similarity + Context Expansion
|
| 98 |
# ==========================================================
|
| 99 |
from vectorstore import build_faiss_index
|
| 100 |
|
| 101 |
+
def _split_query(query: str):
|
| 102 |
+
"""
|
| 103 |
+
Breaks long or compound questions into smaller sub-queries for richer retrieval coverage.
|
| 104 |
+
"""
|
| 105 |
+
separators = [".", "?", "and", "then", "also", ",", ";"]
|
| 106 |
+
for sep in separators:
|
| 107 |
+
query = query.replace(sep, "|")
|
| 108 |
+
parts = [q.strip() for q in query.split("|") if len(q.strip()) > 3]
|
| 109 |
+
return parts[:3] if parts else [query.strip()]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
|
| 113 |
min_similarity: float = 0.6, candidate_multiplier: int = 3,
|
| 114 |
+
embeddings: list = None, token_budget: int = 3500):
|
| 115 |
"""
|
| 116 |
+
Enhanced retrieval:
|
| 117 |
+
β
Handles large / multi-part questions
|
| 118 |
+
β
Dynamically adjusts similarity threshold
|
| 119 |
+
β
Expands context until token budget is reached
|
| 120 |
+
β
Keeps neighbor fill for continuity
|
| 121 |
"""
|
| 122 |
|
| 123 |
if not index or not chunks:
|
|
|
|
| 125 |
return []
|
| 126 |
|
| 127 |
try:
|
| 128 |
+
# πΉ Step 0 β Split into sub-queries
|
| 129 |
+
sub_queries = _split_query(query)
|
| 130 |
+
dynamic_min_sim = max(0.45, min(0.6, 0.6 - 0.02 * len(sub_queries)))
|
| 131 |
+
print(f"π§© Sub-queries: {sub_queries} | Dynamic min_similarity={dynamic_min_sim:.2f}")
|
| 132 |
+
|
| 133 |
+
# πΉ Step 1 β Embed all sub-queries and gather candidate indices
|
| 134 |
+
all_candidates = set()
|
| 135 |
+
for sub_q in sub_queries:
|
| 136 |
+
q_emb = _query_model.encode(
|
| 137 |
+
[f"query: {sub_q.strip()}"],
|
| 138 |
+
convert_to_numpy=True,
|
| 139 |
+
normalize_embeddings=True
|
| 140 |
+
)[0]
|
| 141 |
+
|
| 142 |
+
# β
Auto-heal FAISS index dimension mismatch
|
| 143 |
+
if hasattr(index, "d") and q_emb.shape[0] != index.d:
|
| 144 |
+
print(f"β οΈ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
|
| 145 |
+
if embeddings:
|
| 146 |
+
print("π Rebuilding FAISS index to match embedding dimensions...")
|
| 147 |
+
index = build_faiss_index(embeddings)
|
| 148 |
+
print("β
FAISS index successfully rebuilt.")
|
| 149 |
+
q_emb = _query_model.encode(
|
| 150 |
+
[f"query: {sub_q.strip()}"],
|
| 151 |
+
convert_to_numpy=True,
|
| 152 |
+
normalize_embeddings=True
|
| 153 |
+
)[0]
|
| 154 |
+
else:
|
| 155 |
+
print("β No embeddings available to rebuild FAISS index.")
|
| 156 |
+
continue
|
| 157 |
+
|
| 158 |
+
# Initial retrieval for each sub-query
|
| 159 |
+
num_candidates = max(top_k * candidate_multiplier, top_k + 2)
|
| 160 |
+
distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
|
| 161 |
+
all_candidates.update([int(i) for i in indices[0] if i >= 0])
|
| 162 |
+
|
| 163 |
+
if not all_candidates:
|
| 164 |
+
print("β οΈ No retrieval candidates found.")
|
| 165 |
+
return []
|
| 166 |
+
|
| 167 |
+
candidate_indices = list(all_candidates)
|
| 168 |
+
|
| 169 |
+
# πΉ Step 2 β Re-rank by cosine similarity
|
| 170 |
+
q_emb_global = _query_model.encode(
|
| 171 |
[f"query: {query.strip()}"],
|
| 172 |
convert_to_numpy=True,
|
| 173 |
normalize_embeddings=True
|
| 174 |
)[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
doc_embs = _query_model.encode(
|
| 176 |
[f"passage: {chunks[i]}" for i in candidate_indices],
|
| 177 |
convert_to_numpy=True,
|
| 178 |
normalize_embeddings=True,
|
| 179 |
)
|
| 180 |
+
sims = cosine_similarity([q_emb_global], doc_embs)[0]
|
| 181 |
ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)
|
| 182 |
|
| 183 |
+
# πΉ Step 3 β Dynamic filtering
|
| 184 |
+
filtered = [idx for idx, sim in ranked if sim >= dynamic_min_sim]
|
| 185 |
+
if not filtered:
|
| 186 |
+
filtered = [idx for idx, _ in ranked[:top_k]]
|
| 187 |
|
| 188 |
+
# πΉ Step 4 β Neighbor fill for continuity
|
| 189 |
if len(filtered) < top_k:
|
| 190 |
expanded = set(filtered)
|
| 191 |
for idx in filtered:
|
|
|
|
| 196 |
break
|
| 197 |
if len(expanded) >= top_k:
|
| 198 |
break
|
| 199 |
+
filtered = sorted(expanded)
|
| 200 |
+
|
| 201 |
+
# πΉ Step 5 β Context expansion (token-budget-aware)
|
| 202 |
+
context_limit = token_budget # approx. by word count
|
| 203 |
+
context_accum, current_len = [], 0
|
| 204 |
+
for idx, sim in ranked:
|
| 205 |
+
if idx not in filtered:
|
| 206 |
+
filtered.append(idx)
|
| 207 |
+
chunk_len = len(chunks[idx].split())
|
| 208 |
+
if current_len + chunk_len > context_limit:
|
| 209 |
+
break
|
| 210 |
+
context_accum.append(idx)
|
| 211 |
+
current_len += chunk_len
|
| 212 |
+
|
| 213 |
+
filtered = sorted(set(context_accum or filtered))[: max(top_k, len(filtered))]
|
| 214 |
|
| 215 |
+
# πΉ Step 6 β Final context prep
|
| 216 |
final_chunks = [chunks[i] for i in filtered]
|
| 217 |
+
print(f"β
Retrieved {len(final_chunks)} chunks (multi-span + adaptive threshold).")
|
| 218 |
return final_chunks
|
| 219 |
|
| 220 |
except Exception as e:
|