Shubham170793 commited on
Commit
3cf73df
·
verified ·
1 Parent(s): 8ef9d9a

Update src/qa.py

Browse files
Files changed (1) hide show
  1. src/qa.py +13 -3
src/qa.py CHANGED
@@ -90,18 +90,28 @@ REASONING_PROMPT = (
90
  # 5️⃣ Retrieve Top-K Chunks (Balanced speed)
91
  # ==========================================================
92
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
93
- """Retrieve top-K relevant chunks efficiently."""
94
  if not index or not chunks:
95
  return []
96
 
97
  try:
98
  q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
99
- distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k)
100
- return [chunks[i] for i in indices[0]]
 
 
 
 
 
 
 
 
 
101
  except Exception as e:
102
  print(f"⚠️ Retrieval error: {e}")
103
  return []
104
 
 
105
  # ==========================================================
106
  # 6️⃣ Generate Answer (Reasoning or Strict Mode)
107
  # ==========================================================
 
90
  # 5️⃣ Retrieve Top-K Chunks (Balanced speed)
91
  # ==========================================================
92
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
93
+ """Retrieve top-K relevant chunks and re-rank by cosine similarity for better precision."""
94
  if not index or not chunks:
95
  return []
96
 
97
  try:
98
  q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
99
+ distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
100
+
101
+ # Compute similarity scores for re-ranking
102
+ candidates = [chunks[i] for i in indices[0]]
103
+ cand_vecs = _query_model.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
104
+ sims = cosine_similarity([q_emb], cand_vecs)[0]
105
+
106
+ # Return top-K most semantically aligned
107
+ top_indices = np.argsort(sims)[::-1][:top_k]
108
+ return [candidates[i] for i in top_indices]
109
+
110
  except Exception as e:
111
  print(f"⚠️ Retrieval error: {e}")
112
  return []
113
 
114
+
115
  # ==========================================================
116
  # 6️⃣ Generate Answer (Reasoning or Strict Mode)
117
  # ==========================================================