Shubham170793 committed on
Commit
d73a9dd
·
verified ·
1 Parent(s): d511cfa

Update src/qa.py

Browse files
Files changed (1) hide show
  1. src/qa.py +42 -21
src/qa.py CHANGED
@@ -2,20 +2,21 @@
2
  qa.py — GPT-4o (SAP Gen AI Hub) + ReRank Retrieval
3
  --------------------------------------------------
4
  ✅ Semantic retrieval (FAISS + cosine re-rank + neighbor fill)
 
5
  ✅ Smart factual mode (fast)
6
  ✅ Deep reasoning mode (ChatGPT-like)
7
  """
8
 
9
  import os
 
10
  import json
11
  import numpy as np
12
  from sentence_transformers import SentenceTransformer
13
  from sklearn.metrics.pairwise import cosine_similarity
14
  from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client
15
  from gen_ai_hub.proxy.langchain.openai import ChatOpenAI
16
- from vectorstore import build_faiss_index
17
 
18
- print("βœ… qa.py (GPT-4o via Gen AI Hub + ReRank) loaded from:", __file__)
19
 
20
  # ==========================================================
21
  # 1️⃣ Hugging Face Cache
@@ -34,7 +35,7 @@ os.environ.update({
34
  # ==========================================================
35
  try:
36
  _query_model = SentenceTransformer(
37
- "intfloat/e5-small-v2",
38
  cache_folder=CACHE_DIR
39
  )
40
  print("βœ… Loaded embedding model: intfloat/e5-small-v2 (fast mode)")
@@ -83,7 +84,6 @@ STRICT_PROMPT = (
83
  "If the answer cannot be found even after considering all chunks, say exactly:\n"
84
  "'I don't know based on the provided document.'\n\n"
85
  "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
86
-
87
  )
88
 
89
  REASONING_PROMPT = (
@@ -97,13 +97,16 @@ REASONING_PROMPT = (
97
  )
98
 
99
  # ==========================================================
100
- # 5️⃣ Retrieval — FAISS + Re-rank + Neighbor Fill
101
  # ==========================================================
 
 
102
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
103
  min_similarity: float = 0.6, candidate_multiplier: int = 3,
104
  embeddings: list = None):
105
  """
106
  Re-rank and optionally fill with neighbors for context continuity.
 
107
  Auto-detects and rebuilds FAISS index if dimension mismatch occurs.
108
  """
109
 
@@ -119,37 +122,53 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
119
  normalize_embeddings=True
120
  )[0]
121
 
122
- # ✅ Check dimension match
123
  if hasattr(index, "d") and q_emb.shape[0] != index.d:
124
  print(f"⚠️ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
125
  if embeddings:
126
  print("πŸ”„ Rebuilding FAISS index to match embedding dimensions...")
127
  index = build_faiss_index(embeddings)
128
- q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
 
 
 
 
 
 
129
  else:
 
130
  return []
131
 
132
  # Step 1️⃣ — Initial FAISS retrieval
133
  num_candidates = max(top_k * candidate_multiplier, top_k + 2)
134
  distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
135
  candidate_indices = [int(i) for i in indices[0] if i >= 0]
136
- candidate_indices = list(dict.fromkeys(candidate_indices))
137
 
138
- # Step 2️⃣ — Re-rank by cosine similarity
139
  doc_embs = _query_model.encode(
140
  [f"passage: {chunks[i]}" for i in candidate_indices],
141
  convert_to_numpy=True,
142
  normalize_embeddings=True,
143
  )
144
  sims = cosine_similarity([q_emb], doc_embs)[0]
145
- ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)
146
 
147
- # Step 3️⃣ — Filter by similarity
 
 
 
 
 
 
 
 
 
 
148
  filtered = [idx for idx, sim in ranked if sim >= min_similarity]
149
  if len(filtered) > top_k:
150
  filtered = filtered[:top_k]
151
 
152
- # Step 4️⃣ — Include ±1 neighbors for continuity
153
  neighbors = set()
154
  for idx in filtered:
155
  for n in [idx - 1, idx + 1]:
@@ -159,7 +178,7 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
159
 
160
  # Step 5️⃣ — Build final chunk list
161
  final_chunks = [chunks[i] for i in filtered]
162
- print(f"βœ… Retrieved {len(final_chunks)} chunks (semantic + neighbor fill).")
163
  return final_chunks
164
 
165
  except Exception as e:
@@ -179,6 +198,7 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
179
  if chat_llm is None:
180
  return "⚠️ GPT-4o not initialized. Check credentials or rebuild the Space."
181
 
 
182
  context = "\n".join(f"[Chunk {i+1}] {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks))
183
  prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)
184
 
@@ -189,8 +209,8 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
189
  "You are an expert enterprise documentation assistant. "
190
  "When reasoning_mode is off, stay strictly factual and concise. "
191
  "When reasoning_mode is on, combine insights across chunks logically "
192
- "and explain the reasoning briefly."
193
- "If answer not in document, say exactly: "
194
  "'I don't know based on the provided document.'"
195
  ),
196
  },
@@ -204,16 +224,17 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
204
  print(f"⚠️ GPT-4o generation failed: {e}")
205
  return "⚠️ Error: Could not generate an answer."
206
 
207
-
208
  # ==========================================================
209
  # 7️⃣ Local Test
210
  # ==========================================================
211
  if __name__ == "__main__":
 
 
212
  dummy_chunks = [
213
- "Step 1: Open the dashboard and navigate to reports.",
214
- "Step 2: Click 'Export' to download a CSV summary.",
215
- "Step 3: Review the generated report in your downloads folder.",
216
- "Appendix: Communication user creation steps are explained later in this guide."
217
  ]
218
  embeddings = [
219
  _query_model.encode([f"passage: {c}"], convert_to_numpy=True, normalize_embeddings=True)[0]
@@ -221,7 +242,7 @@ if __name__ == "__main__":
221
  ]
222
  index = build_faiss_index(embeddings)
223
 
224
- query = "How do I create a communication user?"
225
  retrieved = retrieve_chunks(query, index, dummy_chunks)
226
  print("πŸ” Retrieved:", retrieved)
227
  print("πŸ’¬ Answer:", generate_answer(query, retrieved, reasoning_mode=False))
 
2
  qa.py — GPT-4o (SAP Gen AI Hub) + ReRank Retrieval
3
  --------------------------------------------------
4
  ✅ Semantic retrieval (FAISS + cosine re-rank + neighbor fill)
5
+ ✅ Bullet-aware similarity boost for procedural chunks
6
  ✅ Smart factual mode (fast)
7
  ✅ Deep reasoning mode (ChatGPT-like)
8
  """
9
 
10
  import os
11
+ import re
12
  import json
13
  import numpy as np
14
  from sentence_transformers import SentenceTransformer
15
  from sklearn.metrics.pairwise import cosine_similarity
16
  from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client
17
  from gen_ai_hub.proxy.langchain.openai import ChatOpenAI
 
18
 
19
+ print("βœ… qa.py (GPT-4o via Gen AI Hub + Bullet-Aware Retrieval) loaded from:", __file__)
20
 
21
  # ==========================================================
22
  # 1️⃣ Hugging Face Cache
 
35
  # ==========================================================
36
  try:
37
  _query_model = SentenceTransformer(
38
+ "intfloat/e5-small-v2", # ⚑ Faster, 384-dim embeddings
39
  cache_folder=CACHE_DIR
40
  )
41
  print("βœ… Loaded embedding model: intfloat/e5-small-v2 (fast mode)")
 
84
  "If the answer cannot be found even after considering all chunks, say exactly:\n"
85
  "'I don't know based on the provided document.'\n\n"
86
  "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
 
87
  )
88
 
89
  REASONING_PROMPT = (
 
97
  )
98
 
99
  # ==========================================================
100
+ # 5️⃣ Retrieval — FAISS + Bullet-Aware Re-rank + Neighbor Fill
101
  # ==========================================================
102
+ from vectorstore import build_faiss_index
103
+
104
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
105
  min_similarity: float = 0.6, candidate_multiplier: int = 3,
106
  embeddings: list = None):
107
  """
108
  Re-rank and optionally fill with neighbors for context continuity.
109
+ Adds small similarity boost for bullet-style or step-based chunks.
110
  Auto-detects and rebuilds FAISS index if dimension mismatch occurs.
111
  """
112
 
 
122
  normalize_embeddings=True
123
  )[0]
124
 
125
+ # ✅ Sanity check: dimension match between query and FAISS index
126
  if hasattr(index, "d") and q_emb.shape[0] != index.d:
127
  print(f"⚠️ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
128
  if embeddings:
129
  print("πŸ”„ Rebuilding FAISS index to match embedding dimensions...")
130
  index = build_faiss_index(embeddings)
131
+ print("βœ… FAISS index successfully rebuilt.")
132
+
133
+ q_emb = _query_model.encode(
134
+ [f"query: {query.strip()}"],
135
+ convert_to_numpy=True,
136
+ normalize_embeddings=True
137
+ )[0]
138
  else:
139
+ print("❌ No embeddings available to rebuild FAISS index.")
140
  return []
141
 
142
  # Step 1️⃣ — Initial FAISS retrieval
143
  num_candidates = max(top_k * candidate_multiplier, top_k + 2)
144
  distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
145
  candidate_indices = [int(i) for i in indices[0] if i >= 0]
146
+ candidate_indices = list(dict.fromkeys(candidate_indices)) # de-dupe
147
 
148
+ # Step 2️⃣ — Compute similarities
149
  doc_embs = _query_model.encode(
150
  [f"passage: {chunks[i]}" for i in candidate_indices],
151
  convert_to_numpy=True,
152
  normalize_embeddings=True,
153
  )
154
  sims = cosine_similarity([q_emb], doc_embs)[0]
 
155
 
156
+ # 🔹 NEW: Boost similarity for bullet-style or step-based chunks
157
+ boosted_sims = []
158
+ for idx, sim in zip(candidate_indices, sims):
159
+ chunk_text = chunks[idx].strip()
160
+ if re.match(r"^[-β€’\d]+[\.\s]", chunk_text): # bullet or numbered
161
+ sim += 0.05 # small procedural context boost
162
+ boosted_sims.append((idx, sim))
163
+
164
+ ranked = sorted(boosted_sims, key=lambda x: x[1], reverse=True)
165
+
166
+ # Step 3️⃣ — Filter by similarity threshold
167
  filtered = [idx for idx, sim in ranked if sim >= min_similarity]
168
  if len(filtered) > top_k:
169
  filtered = filtered[:top_k]
170
 
171
+ # Step 4️⃣ — Neighbor fill (context continuity)
172
  neighbors = set()
173
  for idx in filtered:
174
  for n in [idx - 1, idx + 1]:
 
178
 
179
  # Step 5️⃣ — Build final chunk list
180
  final_chunks = [chunks[i] for i in filtered]
181
+ print(f"βœ… Retrieved {len(final_chunks)} chunks (bullet-aware + continuity).")
182
  return final_chunks
183
 
184
  except Exception as e:
 
198
  if chat_llm is None:
199
  return "⚠️ GPT-4o not initialized. Check credentials or rebuild the Space."
200
 
201
+ # Combine chunks with markers
202
  context = "\n".join(f"[Chunk {i+1}] {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks))
203
  prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)
204
 
 
209
  "You are an expert enterprise documentation assistant. "
210
  "When reasoning_mode is off, stay strictly factual and concise. "
211
  "When reasoning_mode is on, combine insights across chunks logically "
212
+ "and explain the reasoning briefly. "
213
+ "If the answer is not in the document, reply exactly: "
214
  "'I don't know based on the provided document.'"
215
  ),
216
  },
 
224
  print(f"⚠️ GPT-4o generation failed: {e}")
225
  return "⚠️ Error: Could not generate an answer."
226
 
 
227
  # ==========================================================
228
  # 7️⃣ Local Test
229
  # ==========================================================
230
  if __name__ == "__main__":
231
+ from vectorstore import build_faiss_index
232
+
233
  dummy_chunks = [
234
+ "- Step 1: Enable order confirmation capability.",
235
+ "- Step 2: Configure supplier email.",
236
+ "Setup instructions and configuration details.",
237
+ "Prerequisites for automation are described here."
238
  ]
239
  embeddings = [
240
  _query_model.encode([f"passage: {c}"], convert_to_numpy=True, normalize_embeddings=True)[0]
 
242
  ]
243
  index = build_faiss_index(embeddings)
244
 
245
+ query = "What are the prerequisites for commerce automation?"
246
  retrieved = retrieve_chunks(query, index, dummy_chunks)
247
  print("πŸ” Retrieved:", retrieved)
248
  print("πŸ’¬ Answer:", generate_answer(query, retrieved, reasoning_mode=False))