Shubham170793 committed
Commit 1b878f3 · verified · 1 Parent(s): a5c876d

Update src/qa.py

Files changed (1)
  1. src/qa.py +44 -17
src/qa.py CHANGED
@@ -89,24 +89,48 @@ REASONING_PROMPT = (
 )
 
 # ==========================================================
-# 5️⃣ Retrieval — FAISS + Cosine Re-Rank + Neighbor Fill
+# 5️⃣ Retrieval — FAISS + Re-rank + Neighbor Fill (Auto-Healing)
 # ==========================================================
+from vectorstore import build_faiss_index
+
 def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
-                    min_similarity: float = 0.6, candidate_multiplier: int = 3):
-    """Select top chunks via FAISS, rerank by cosine similarity, fill gaps with neighbors."""
+                    min_similarity: float = 0.6, candidate_multiplier: int = 3,
+                    embeddings: list = None):
+    """
+    Re-rank and optionally fill with neighbors for context continuity.
+    Auto-detects and rebuilds FAISS index if dimension mismatch occurs.
+    """
+
     if not index or not chunks:
+        print("⚠️ No FAISS index or chunks provided — returning empty result.")
         return []
 
     try:
+        # Encode query embedding
         q_emb = _query_model.encode(
-            [f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True
+            [f"query: {query.strip()}"],
+            convert_to_numpy=True,
+            normalize_embeddings=True
         )[0]
 
-        # 1️⃣ Initial FAISS search
-        distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * candidate_multiplier)
-        candidate_indices = list(dict.fromkeys(indices[0]))  # dedup, preserve order
-
-        # 2️⃣ Compute true cosine similarity for rerank
+        # Sanity check: dimension match between query and FAISS index
+        if hasattr(index, "d") and q_emb.shape[0] != index.d:
+            print(f"⚠️ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
+            if embeddings:
+                print("🔄 Rebuilding FAISS index to match embedding dimensions...")
+                index = build_faiss_index(embeddings)
+                print("✅ FAISS index successfully rebuilt.")
+            else:
+                print("❌ No embeddings available to rebuild FAISS index.")
+                return []
+
+        # Step 1️⃣ — Initial FAISS retrieval
+        num_candidates = max(top_k * candidate_multiplier, top_k + 2)
+        distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
+        candidate_indices = [int(i) for i in indices[0] if i >= 0]
+        candidate_indices = list(dict.fromkeys(candidate_indices))  # de-dupe
+
+        # Step 2️⃣ — Re-rank by cosine similarity
         doc_embs = _query_model.encode(
             [f"passage: {chunks[i]}" for i in candidate_indices],
             convert_to_numpy=True,
@@ -115,28 +139,31 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
         sims = cosine_similarity([q_emb], doc_embs)[0]
         ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)
 
-        # 3️⃣ Keep only chunks meeting threshold
-        filtered = [idx for idx, sim in ranked if sim >= min_similarity][:top_k]
+        # Step 3️⃣ Filter by similarity threshold
+        filtered = [idx for idx, sim in ranked if sim >= min_similarity]
+        if len(filtered) > top_k:
+            filtered = filtered[:top_k]
 
-        # 4️⃣ Neighbor fill if not enough
+        # Step 4️⃣ Neighbor fill (if not enough)
         if len(filtered) < top_k:
             expanded = set(filtered)
             for idx in filtered:
-                for nb in [idx - 1, idx + 1]:
-                    if 0 <= nb < len(chunks):
-                        expanded.add(nb)
+                for neighbor in [idx - 1, idx + 1]:
+                    if 0 <= neighbor < len(chunks):
+                        expanded.add(neighbor)
                         if len(expanded) >= top_k:
                             break
                 if len(expanded) >= top_k:
                     break
             filtered = sorted(expanded)[:top_k]
 
+        # Step 5️⃣ — Build final chunk list
         final_chunks = [chunks[i] for i in filtered]
-        print(f"✅ Retrieved {len(final_chunks)} chunks (semantic + neighbor fill)")
+        print(f"✅ Retrieved {len(final_chunks)} chunks (semantic + neighbor fill).")
         return final_chunks
 
     except Exception as e:
-        print(f"⚠️ Retrieval error: {e}")
+        print(f"⚠️ Retrieval error: {repr(e)}")
         return []
 
 # ==========================================================
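
Note on the "query: " / "passage: " prefixes in the encode calls: that convention comes from E5-family embedding models, which expect queries and documents to be marked at encode time. A minimal sketch of what _query_model is assumed to be; the concrete model name is illustrative and not taken from this commit:

from sentence_transformers import SentenceTransformer

# Hypothetical setup: an E5-family model, which expects "query: " and
# "passage: " prefixes on its inputs. The repo's actual model may differ.
_query_model = SentenceTransformer("intfloat/e5-base-v2")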
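
The auto-healing branch assumes vectorstore.build_faiss_index accepts a list of embedding vectors and returns a fresh index; that helper is not shown in this commit. A minimal sketch, assuming the embeddings are L2-normalized so inner product equals cosine similarity:

import numpy as np
import faiss

def build_faiss_index(embeddings):
    # Sketch only: flat inner-product index. With L2-normalized vectors,
    # inner-product search ranks by cosine similarity.
    vecs = np.asarray(embeddings, dtype="float32")
    index = faiss.IndexFlatIP(vecs.shape[1])  # dimension read from the embeddings
    index.add(vecs)
    return index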
 
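To see what the neighbor-fill step contributes, here is a standalone trace of the same loop with hypothetical inputs: 10 chunks, top_k=5, and only indices 3 and 7 surviving the similarity threshold:

# Hypothetical inputs; the loop body mirrors Step 4 of the new code.
chunk_count, top_k = 10, 5
filtered = [3, 7]  # indices that passed the 0.6 threshold
expanded = set(filtered)
for idx in filtered:
    for neighbor in [idx - 1, idx + 1]:
        if 0 <= neighbor < chunk_count:
            expanded.add(neighbor)
            if len(expanded) >= top_k:
                break
    if len(expanded) >= top_k:
        break
filtered = sorted(expanded)[:top_k]
print(filtered)  # [2, 3, 4, 6, 7]: adjacent chunks restore local context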
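
A hypothetical call site for the updated function; index, chunks, and chunk_embeddings are stand-ins for whatever the ingestion pipeline produces, not names from this commit:

# Illustrative usage: passing embeddings enables the auto-rebuild path
# when the index dimension no longer matches the query embedding.
results = retrieve_chunks(
    query="What does the warranty cover?",
    index=index,
    chunks=chunks,
    top_k=5,
    embeddings=chunk_embeddings,
)
for rank, chunk in enumerate(results, 1):
    print(f"[{rank}] {chunk[:80]}")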