Shubham170793 committed
Commit 235a5b5 · verified · 1 Parent(s): 197e569

Update src/qa.py

Files changed (1): src/qa.py (+57 -62)
src/qa.py CHANGED
@@ -85,71 +85,65 @@ REASONING_PROMPT = (
 # ==========================================================
 # 5️⃣ Retrieve Chunks (FAISS + Rerank + Neighbor Expansion)
 # ==========================================================
-def retrieve_chunks(
-    query: str,
-    index,
-    chunks: list,
-    top_k: int = 3,
-    topn_candidates: int = 20,
-    neighbor_threshold: float = 0.68,
-    expansion_window: int = 1,
-    max_context_chunks: int = 6,
-):
-    """Retrieve semantically relevant chunks with reranking and neighbor expansion."""
+def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5, min_similarity: float = 0.6):
+    """
+    Hybrid retrieval:
+    1️⃣ Get the semantic top-k chunks via FAISS.
+    2️⃣ Re-rank by cosine similarity and apply a minimum-similarity filter.
+    3️⃣ If fewer than top_k survive, fill the remaining slots with adjacent chunks (±1) for continuity.
+    """
     if not index or not chunks:
         return []
 
-    # 1️⃣ Encode query (normalized)
-    query_emb = _query_model.encode(
-        [f"query: {query.strip()}"],
-        convert_to_numpy=True,
-        normalize_embeddings=True
-    )[0].astype("float32")
-
-    # 2️⃣ FAISS search (initial candidates)
-    topn_candidates = min(topn_candidates, getattr(index, "ntotal", topn_candidates))
-    _, candidate_ids = index.search(np.array([query_emb]).astype("float32"), topn_candidates)
-    candidate_ids = [int(i) for i in candidate_ids[0] if i != -1]
-
-    # 3️⃣ Re-encode candidate chunks and compute cosine similarities
-    candidate_texts = [chunks[i] for i in candidate_ids]
-    candidate_vecs = np.array([
-        _query_model.encode([t], convert_to_numpy=True, normalize_embeddings=True)[0]
-        for t in candidate_texts
-    ])
-    sims = cosine_similarity([query_emb], candidate_vecs)[0]
-    sorted_idx = np.argsort(sims)[::-1]
-    reranked_ids = [candidate_ids[i] for i in sorted_idx]
-
-    # 4️⃣ Select top-k base chunks
-    selected, selected_set = [], set()
-    for rid in reranked_ids:
-        if len(selected) >= top_k:
-            break
-        selected.append(rid)
-        selected_set.add(rid)
-
-    # 5️⃣ Conditional neighbor expansion
-    final_order = list(selected)
-    for base_id in selected:
-        if len(final_order) >= max_context_chunks:
-            break
-        for offset in range(1, expansion_window + 1):
-            for neighbor in (base_id - offset, base_id + offset):
-                if neighbor < 0 or neighbor >= len(chunks) or neighbor in selected_set:
-                    continue
-                # Check semantic closeness
-                neighbor_vec = _query_model.encode([chunks[neighbor]], convert_to_numpy=True, normalize_embeddings=True)[0]
-                sim = float(cosine_similarity([query_emb], [neighbor_vec])[0][0])
-                if sim >= neighbor_threshold:
-                    final_order.append(neighbor)
-                    selected_set.add(neighbor)
-                if len(final_order) >= max_context_chunks:
-                    break
-            if len(final_order) >= max_context_chunks:
-                break
-
-    return [chunks[i] for i in final_order]
+    try:
+        # Encode query (normalized; "query: " prefix to match the embedding model)
+        q_emb = _query_model.encode(
+            [f"query: {query.strip()}"],
+            convert_to_numpy=True,
+            normalize_embeddings=True
+        )[0]
+
+        # Step 1️⃣ FAISS initial retrieval (fetch 2x top_k candidates; drop -1 padding)
+        distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
+        retrieved_indices = [int(i) for i in indices[0] if i != -1]
+
+        # Step 2️⃣ Compute cosine similarity for re-ranking
+        retrieved_texts = [chunks[i] for i in retrieved_indices]
+        doc_embs = _query_model.encode(
+            [f"passage: {c}" for c in retrieved_texts],
+            convert_to_numpy=True,
+            normalize_embeddings=True
+        )
+
+        sims = cosine_similarity([q_emb], doc_embs)[0]
+        ranked = sorted(zip(retrieved_indices, sims), key=lambda x: x[1], reverse=True)
+
+        # Step 3️⃣ Apply the minimum similarity filter
+        filtered_indices = [idx for idx, score in ranked if score >= min_similarity]
+
+        # Step 4️⃣ If too few survive, add ±1 neighbors for continuity
+        if len(filtered_indices) < top_k:
+            expanded_indices = set(filtered_indices)
+            for idx in filtered_indices:
+                for neighbor in (idx - 1, idx + 1):
+                    if 0 <= neighbor < len(chunks):
+                        expanded_indices.add(neighbor)
+                    if len(expanded_indices) >= top_k:
+                        break
+                if len(expanded_indices) >= top_k:
+                    break
+            filtered_indices = sorted(expanded_indices)[:top_k]
+
+        # Step 5️⃣ Build the final ordered list of chunks
+        final_chunks = [chunks[i] for i in filtered_indices]
+
+        print(f"✅ Retrieved {len(final_chunks)} chunks (semantic + neighbor fill).")
+        return final_chunks
+
+    except Exception as e:
+        print(f"⚠️ Retrieval error: {e}")
+        return []
+
 
 # ==========================================================
 # 6️⃣ Answer Generation
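
For anyone who wants to exercise the new retrieval path locally, below is a minimal smoke-test sketch. Everything outside retrieve_chunks is an assumption for illustration: the import path, the model name, and the sample chunks are placeholders, not taken from this repo. The sketch only relies on what the diff itself implies, namely that _query_model is a SentenceTransformer-style encoder using "query: " / "passage: " prefixes and that the FAISS index holds normalized passage embeddings scored by inner product.

# Hypothetical smoke test for retrieve_chunks. Placeholders: the import path,
# model name, and sample chunks are illustrative, not from this repo.
import faiss
from sentence_transformers import SentenceTransformer

from src.qa import retrieve_chunks  # assumed import path

# Must match the encoder qa.py loads as _query_model, or scores will be off.
encoder = SentenceTransformer("intfloat/e5-small-v2")  # assumed E5-style model

chunks = [
    "FAISS is a library for efficient similarity search over dense vectors.",
    "It provides exact (flat) and approximate nearest-neighbor indexes.",
    "On normalized embeddings, inner product equals cosine similarity.",
]

# Build the index the way retrieve_chunks expects: normalized "passage: "
# embeddings in an inner-product index, so FAISS scores are cosine similarities.
doc_embs = encoder.encode(
    [f"passage: {c}" for c in chunks],
    convert_to_numpy=True,
    normalize_embeddings=True,
).astype("float32")
index = faiss.IndexFlatIP(doc_embs.shape[1])
index.add(doc_embs)

print(retrieve_chunks("What does FAISS do?", index, chunks, top_k=2))

With a corpus this tiny, the ±1 neighbor fill kicks in whenever the 0.6 similarity floor leaves fewer than top_k candidates, which is exactly the continuity behavior the new docstring describes.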