Shubham170793 committed
Commit e9faa78 · verified · 1 Parent(s): edaeee6

Update src/qa.py

Files changed (1)
  src/qa.py  +84 -41
src/qa.py CHANGED
@@ -94,16 +94,30 @@ REASONING_PROMPT = (
 
 
 # ==========================================================
-# 5️⃣ Retrieval — FAISS + Re-rank + Neighbor Fill (Auto-Healing)
+# 🔍 Improved Retrieval — Multi-Span Query + Adaptive Similarity + Context Expansion
 # ==========================================================
 from vectorstore import build_faiss_index
 
+def _split_query(query: str):
+    """
+    Breaks long or compound questions into smaller sub-queries for richer retrieval coverage.
+    """
+    separators = [".", "?", "and", "then", "also", ",", ";"]
+    for sep in separators:
+        query = query.replace(sep, "|")
+    parts = [q.strip() for q in query.split("|") if len(q.strip()) > 3]
+    return parts[:3] if parts else [query.strip()]
+
+
 def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
                     min_similarity: float = 0.6, candidate_multiplier: int = 3,
-                    embeddings: list = None):
+                    embeddings: list = None, token_budget: int = 3500):
     """
-    Re-rank and optionally fill with neighbors for context continuity.
-    Auto-detects and rebuilds FAISS index if dimension mismatch occurs.
+    Enhanced retrieval:
+    ✅ Handles large / multi-part questions
+    ✅ Dynamically adjusts similarity threshold
+    ✅ Expands context until token budget is reached
+    ✅ Keeps neighbor fill for continuity
     """
 
     if not index or not chunks:
@@ -111,52 +125,67 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
         return []
 
     try:
-        # Encode query embedding
-        q_emb = _query_model.encode(
+        # 🔹 Step 0 — Split into sub-queries
+        sub_queries = _split_query(query)
+        dynamic_min_sim = max(0.45, min(0.6, 0.6 - 0.02 * len(sub_queries)))
+        print(f"🧩 Sub-queries: {sub_queries} | Dynamic min_similarity={dynamic_min_sim:.2f}")
+
+        # 🔹 Step 1 — Embed all sub-queries and gather candidate indices
+        all_candidates = set()
+        for sub_q in sub_queries:
+            q_emb = _query_model.encode(
+                [f"query: {sub_q.strip()}"],
+                convert_to_numpy=True,
+                normalize_embeddings=True
+            )[0]
+
+            # ✅ Auto-heal FAISS index dimension mismatch
+            if hasattr(index, "d") and q_emb.shape[0] != index.d:
+                print(f"⚠️ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
+                if embeddings:
+                    print("🔄 Rebuilding FAISS index to match embedding dimensions...")
+                    index = build_faiss_index(embeddings)
+                    print("✅ FAISS index successfully rebuilt.")
+                    q_emb = _query_model.encode(
+                        [f"query: {sub_q.strip()}"],
+                        convert_to_numpy=True,
+                        normalize_embeddings=True
+                    )[0]
+                else:
+                    print("❌ No embeddings available to rebuild FAISS index.")
+                    continue
+
+            # Initial retrieval for each sub-query
+            num_candidates = max(top_k * candidate_multiplier, top_k + 2)
+            distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
+            all_candidates.update([int(i) for i in indices[0] if i >= 0])
+
+        if not all_candidates:
+            print("⚠️ No retrieval candidates found.")
+            return []
+
+        candidate_indices = list(all_candidates)
+
+        # 🔹 Step 2 — Re-rank by cosine similarity
+        q_emb_global = _query_model.encode(
             [f"query: {query.strip()}"],
             convert_to_numpy=True,
             normalize_embeddings=True
         )[0]
-
-        # ✅ Sanity check: dimension match between query and FAISS index
-        if hasattr(index, "d") and q_emb.shape[0] != index.d:
-            print(f"⚠️ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
-            if embeddings:
-                print("🔄 Rebuilding FAISS index to match embedding dimensions...")
-                index = build_faiss_index(embeddings)
-                print("✅ FAISS index successfully rebuilt.")
-
-                # ✅ Regenerate query embedding now that we have a matching index
-                q_emb = _query_model.encode(
-                    [f"query: {query.strip()}"],
-                    convert_to_numpy=True,
-                    normalize_embeddings=True
-                )[0]
-            else:
-                print("❌ No embeddings available to rebuild FAISS index.")
-                return []
-
-        # Step 1️⃣ — Initial FAISS retrieval
-        num_candidates = max(top_k * candidate_multiplier, top_k + 2)
-        distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
-        candidate_indices = [int(i) for i in indices[0] if i >= 0]
-        candidate_indices = list(dict.fromkeys(candidate_indices))  # de-dupe
-
-        # Step 2️⃣ — Re-rank by cosine similarity
         doc_embs = _query_model.encode(
             [f"passage: {chunks[i]}" for i in candidate_indices],
            convert_to_numpy=True,
             normalize_embeddings=True,
         )
-        sims = cosine_similarity([q_emb], doc_embs)[0]
+        sims = cosine_similarity([q_emb_global], doc_embs)[0]
         ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)
 
-        # Step 3️⃣ — Filter by similarity threshold
-        filtered = [idx for idx, sim in ranked if sim >= min_similarity]
-        if len(filtered) > top_k:
-            filtered = filtered[:top_k]
+        # 🔹 Step 3 — Dynamic filtering
+        filtered = [idx for idx, sim in ranked if sim >= dynamic_min_sim]
+        if not filtered:
+            filtered = [idx for idx, _ in ranked[:top_k]]
 
-        # Step 4️⃣ — Neighbor fill (if not enough)
+        # 🔹 Step 4 — Neighbor fill for continuity
         if len(filtered) < top_k:
            expanded = set(filtered)
             for idx in filtered:
@@ -167,11 +196,25 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
                         break
                 if len(expanded) >= top_k:
                     break
-            filtered = sorted(expanded)[:top_k]
+            filtered = sorted(expanded)
+
+        # 🔹 Step 5 — Context expansion (token-budget-aware)
+        context_limit = token_budget  # approx. by word count
+        context_accum, current_len = [], 0
+        for idx, sim in ranked:
+            if idx not in filtered:
+                filtered.append(idx)
+            chunk_len = len(chunks[idx].split())
+            if current_len + chunk_len > context_limit:
+                break
+            context_accum.append(idx)
+            current_len += chunk_len
+
+        filtered = sorted(set(context_accum or filtered))[: max(top_k, len(filtered))]
 
-        # Step 5️⃣ — Build final chunk list
+        # 🔹 Step 6 — Final context prep
         final_chunks = [chunks[i] for i in filtered]
-        print(f"✅ Retrieved {len(final_chunks)} chunks (semantic + neighbor fill).")
+        print(f"✅ Retrieved {len(final_chunks)} chunks (multi-span + adaptive threshold).")
         return final_chunks
 
     except Exception as e:
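
For orientation, a minimal sketch of how the updated function might be wired up. It assumes qa.py's module-level pieces (_query_model, np, cosine_similarity) are already set up as in the rest of the file; load_chunks and embed_chunks are hypothetical stand-ins for the app's ingestion step, and the import paths are assumptions:

    from vectorstore import build_faiss_index
    from qa import retrieve_chunks

    chunks = load_chunks("docs/handbook.pdf")   # hypothetical ingestion helper
    embeddings = embed_chunks(chunks)           # hypothetical embedding helper
    index = build_faiss_index(embeddings)

    # Passing `embeddings` lets retrieve_chunks rebuild the index on a
    # dimension mismatch instead of returning nothing.
    context = retrieve_chunks(
        "What is the warranty period, and how do I file a claim?",
        index,
        chunks,
        top_k=5,
        embeddings=embeddings,
    )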
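
One behavior of _split_query worth flagging: it replaces the connectives as bare substrings, so "and" also splits inside words like "brand" or "standard", and "." splits decimals like "3.5". A word-boundary variant is sketched below as a possible follow-up; it is not what this commit ships:

    import re

    def _split_query_wb(query: str):
        # Split on sentence punctuation, or on the connectives only when
        # they stand alone as words (\b stops matches inside other words).
        parts = re.split(r"[.?!,;]|\b(?:and|then|also)\b", query, flags=re.IGNORECASE)
        parts = [p.strip() for p in parts if len(p.strip()) > 3]
        return parts[:3] if parts else [query.strip()]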
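
The adaptive threshold is narrower than the 0.45 floor suggests: dynamic_min_sim = max(0.45, min(0.6, 0.6 - 0.02 * n)) for n sub-queries, and since _split_query caps n at 3, the value only ranges from 0.58 (one sub-query) to 0.54 (three). The 0.45 floor would bind only at n >= 8, so it is effectively unreachable as committed.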
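
Note also that Step 5 measures the budget in whitespace-separated words, not model tokens, so token_budget=3500 "words" will usually overshoot a 3500-token context window. If an exact count is ever needed, a tokenizer-based counter would be a drop-in replacement for the len(chunks[idx].split()) proxy; the sketch below uses tiktoken, which is an assumption, since qa.py does not import it:

    import tiktoken

    _enc = tiktoken.get_encoding("cl100k_base")

    def _count_tokens(text: str) -> int:
        # Exact token count to compare against token_budget, replacing
        # the word-count approximation used in Step 5.
        return len(_enc.encode(text))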