Shubham170793 committed
Commit 0671dc0 · verified · 1 Parent(s): b61a150

Update src/qa.py

Files changed (1)
  1. src/qa.py +48 -105
src/qa.py CHANGED
@@ -13,6 +13,7 @@ from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client
 from gen_ai_hub.proxy.langchain.openai import ChatOpenAI
+from vectorstore import build_faiss_index
 
 print("✅ qa.py (GPT-4o via Gen AI Hub + ReRank) loaded from:", __file__)
 
@@ -33,7 +34,7 @@ os.environ.update({
 # ==========================================================
 try:
     _query_model = SentenceTransformer(
-        "intfloat/e5-small-v2",  # ⚡ Faster, 384-dim embeddings
+        "intfloat/e5-small-v2",
         cache_folder=CACHE_DIR
     )
     print("✅ Loaded embedding model: intfloat/e5-small-v2 (fast mode)")
@@ -76,8 +77,10 @@ except Exception as e:
 # ==========================================================
 STRICT_PROMPT = (
     "You are an enterprise documentation assistant.\n"
-    "Answer clearly and factually using ONLY the CONTEXT below.\n"
-    "If the answer is not in the document, reply exactly:\n"
+    "Use all relevant information from the CONTEXT below.\n"
+    "If multiple related points appear across chunks, combine them logically into one clear answer.\n"
+    "Do not invent facts outside the provided content.\n"
+    "If the answer cannot be found even after considering all chunks, say exactly:\n"
     "'I don't know based on the provided document.'\n\n"
     "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
 )
@@ -92,32 +95,15 @@ REASONING_PROMPT = (
     "Context:\n{context}\n\nQuestion: {query}\nLet's reason step-by-step:\nAnswer:"
 )
 
-
 # ==========================================================
-# 🔍 Improved Retrieval — Multi-Span Query + Adaptive Similarity + Context Expansion
+# 5️⃣ Retrieval — FAISS + Re-rank + Neighbor Fill
 # ==========================================================
-from vectorstore import build_faiss_index
-
-def _split_query(query: str):
-    """
-    Breaks long or compound questions into smaller sub-queries for richer retrieval coverage.
-    """
-    separators = [".", "?", "and", "then", "also", ",", ";"]
-    for sep in separators:
-        query = query.replace(sep, "|")
-    parts = [q.strip() for q in query.split("|") if len(q.strip()) > 3]
-    return parts[:3] if parts else [query.strip()]
-
-
 def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
                     min_similarity: float = 0.6, candidate_multiplier: int = 3,
-                    embeddings: list = None, token_budget: int = 3500):
+                    embeddings: list = None):
     """
-    Enhanced retrieval:
-    ✅ Handles large / multi-part questions
-    ✅ Dynamically adjusts similarity threshold
-    ✅ Expands context until token budget is reached
-    ✅ Keeps neighbor fill for continuity
+    Re-rank and optionally fill with neighbors for context continuity.
+    Auto-detects and rebuilds FAISS index if dimension mismatch occurs.
     """
 
     if not index or not chunks:
@@ -125,96 +111,54 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
         return []
 
     try:
-        # 🔹 Step 0 — Split into sub-queries
-        sub_queries = _split_query(query)
-        dynamic_min_sim = max(0.45, min(0.6, 0.6 - 0.02 * len(sub_queries)))
-        print(f"🧩 Sub-queries: {sub_queries} | Dynamic min_similarity={dynamic_min_sim:.2f}")
-
-        # 🔹 Step 1 — Embed all sub-queries and gather candidate indices
-        all_candidates = set()
-        for sub_q in sub_queries:
-            q_emb = _query_model.encode(
-                [f"query: {sub_q.strip()}"],
-                convert_to_numpy=True,
-                normalize_embeddings=True
-            )[0]
-
-            # ✅ Auto-heal FAISS index dimension mismatch
-            if hasattr(index, "d") and q_emb.shape[0] != index.d:
-                print(f"⚠️ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
-                if embeddings:
-                    print("🔄 Rebuilding FAISS index to match embedding dimensions...")
-                    index = build_faiss_index(embeddings)
-                    print("✅ FAISS index successfully rebuilt.")
-                    q_emb = _query_model.encode(
-                        [f"query: {sub_q.strip()}"],
-                        convert_to_numpy=True,
-                        normalize_embeddings=True
-                    )[0]
-                else:
-                    print("❌ No embeddings available to rebuild FAISS index.")
-                    continue
-
-            # Initial retrieval for each sub-query
-            num_candidates = max(top_k * candidate_multiplier, top_k + 2)
-            distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
-            all_candidates.update([int(i) for i in indices[0] if i >= 0])
-
-        if not all_candidates:
-            print("⚠️ No retrieval candidates found.")
-            return []
-
-        candidate_indices = list(all_candidates)
-
-        # 🔹 Step 2 — Re-rank by cosine similarity
-        q_emb_global = _query_model.encode(
+        # Encode query embedding
+        q_emb = _query_model.encode(
             [f"query: {query.strip()}"],
             convert_to_numpy=True,
            normalize_embeddings=True
        )[0]
+
+        # ✅ Check dimension match
+        if hasattr(index, "d") and q_emb.shape[0] != index.d:
+            print(f"⚠️ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
+            if embeddings:
+                print("🔄 Rebuilding FAISS index to match embedding dimensions...")
+                index = build_faiss_index(embeddings)
+                q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+            else:
+                return []
+
+        # Step 1️⃣ — Initial FAISS retrieval
+        num_candidates = max(top_k * candidate_multiplier, top_k + 2)
+        distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
+        candidate_indices = [int(i) for i in indices[0] if i >= 0]
+        candidate_indices = list(dict.fromkeys(candidate_indices))
+
+        # Step 2️⃣ — Re-rank by cosine similarity
         doc_embs = _query_model.encode(
             [f"passage: {chunks[i]}" for i in candidate_indices],
             convert_to_numpy=True,
             normalize_embeddings=True,
         )
-        sims = cosine_similarity([q_emb_global], doc_embs)[0]
+        sims = cosine_similarity([q_emb], doc_embs)[0]
         ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)
 
-        # 🔹 Step 3 — Dynamic filtering
-        filtered = [idx for idx, sim in ranked if sim >= dynamic_min_sim]
-        if not filtered:
-            filtered = [idx for idx, _ in ranked[:top_k]]
+        # Step 3️⃣ — Filter by similarity
+        filtered = [idx for idx, sim in ranked if sim >= min_similarity]
+        if len(filtered) > top_k:
+            filtered = filtered[:top_k]
 
-        # 🔹 Step 4 — Neighbor fill for continuity
-        if len(filtered) < top_k:
-            expanded = set(filtered)
-            for idx in filtered:
-                for neighbor in [idx - 1, idx + 1]:
-                    if 0 <= neighbor < len(chunks):
-                        expanded.add(neighbor)
-                    if len(expanded) >= top_k:
-                        break
-                if len(expanded) >= top_k:
-                    break
-            filtered = sorted(expanded)
+        # Step 4️⃣ — Include ±1 neighbors for continuity
+        neighbors = set()
+        for idx in filtered:
+            for n in [idx - 1, idx + 1]:
+                if 0 <= n < len(chunks):
+                    neighbors.add(n)
+        filtered = sorted(set(filtered) | neighbors)
 
-        # 🔹 Step 5 — Context expansion (token-budget-aware)
-        context_limit = token_budget  # approx. by word count
-        context_accum, current_len = [], 0
-        for idx, sim in ranked:
-            if idx not in filtered:
-                filtered.append(idx)
-            chunk_len = len(chunks[idx].split())
-            if current_len + chunk_len > context_limit:
-                break
-            context_accum.append(idx)
-            current_len += chunk_len
-
-        filtered = sorted(set(context_accum or filtered))[: max(top_k, len(filtered))]
-
-        # 🔹 Step 6 — Final context prep
+        # Step 5️⃣ — Build final chunk list
         final_chunks = [chunks[i] for i in filtered]
-        print(f"✅ Retrieved {len(final_chunks)} chunks (multi-span + adaptive threshold).")
+        print(f"✅ Retrieved {len(final_chunks)} chunks (semantic + neighbor fill).")
         return final_chunks
 
     except Exception as e:
@@ -234,7 +178,6 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
     if chat_llm is None:
         return "⚠️ GPT-4o not initialized. Check credentials or rebuild the Space."
 
-    # Combine chunks with markers
     context = "\n".join(f"[Chunk {i+1}] {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks))
     prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)
 
@@ -243,7 +186,8 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
             "role": "system",
             "content": (
                 "You are an expert enterprise documentation assistant. "
-                "Answer only using provided context; if reasoning_mode is on, explain briefly. "
+                "When reasoning_mode is off, stay strictly factual and concise. "
+                "When on, combine insights across chunks logically. "
                 "If answer not in document, say exactly: "
                 "'I don't know based on the provided document.'"
             ),
@@ -258,12 +202,11 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
         print(f"⚠️ GPT-4o generation failed: {e}")
         return "⚠️ Error: Could not generate an answer."
 
+
 # ==========================================================
 # 7️⃣ Local Test
 # ==========================================================
 if __name__ == "__main__":
-    from vectorstore import build_faiss_index
-
     dummy_chunks = [
         "Step 1: Open the dashboard and navigate to reports.",
         "Step 2: Click 'Export' to download a CSV summary.",
@@ -279,4 +222,4 @@ if __name__ == "__main__":
     query = "How do I create a communication user?"
     retrieved = retrieve_chunks(query, index, dummy_chunks)
     print("🔍 Retrieved:", retrieved)
-    print("💬 Answer:", generate_answer(query, retrieved, reasoning_mode=True))
+    print("💬 Answer:", generate_answer(query, retrieved, reasoning_mode=False))
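Note: this commit moves `from vectorstore import build_faiss_index` to module scope and relies on it for the auto-rebuild path, but vectorstore.py itself is not part of the diff. As a reference point only, here is a minimal sketch of what such a helper might look like, assuming faiss-cpu and the L2-normalized embeddings used above (so inner-product search ranks the same way as the cosine re-ranking in retrieve_chunks); the function name comes from the import, the body is an assumption:

    # Hypothetical sketch; not the repo's actual vectorstore.py.
    import faiss
    import numpy as np

    def build_faiss_index(embeddings):
        """Flat inner-product index; with L2-normalized vectors,
        inner product gives the same ranking as cosine similarity."""
        embs = np.asarray(embeddings, dtype="float32")
        index = faiss.IndexFlatIP(embs.shape[1])  # sets index.d to the embedding width
        index.add(embs)
        return index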
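The dimension-mismatch branch is the main new behavior in retrieve_chunks, and it only triggers when the caller passes `embeddings`. A usage sketch for exercising it, assuming qa.py and a build_faiss_index like the one above are importable (the chunk texts here are made up):

    # Hypothetical exercise of the mismatch branch; names taken from this diff.
    import faiss
    from qa import retrieve_chunks, _query_model

    chunks = [
        "Open the Communication Management app.",
        "Choose 'Maintain Communication Users' and create the user.",
    ]
    embs = _query_model.encode(
        [f"passage: {c}" for c in chunks],
        convert_to_numpy=True,
        normalize_embeddings=True,
    ).tolist()  # a plain list keeps the `if embeddings:` truthiness check unambiguous

    stale_index = faiss.IndexFlatIP(768)  # deliberately wrong width: index.d != 384
    hits = retrieve_chunks("How do I create a communication user?",
                           stale_index, chunks, embeddings=embs)
    print(hits)  # expect the mismatch warning, a rebuild, then retrieval on the new index

Passing `embeddings` is what makes the self-heal possible; without it, the new code returns an empty list instead of searching an index it cannot query.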