Shubham170793 committed
Commit 7b7e367 · verified · 1 Parent(s): 7e98078

Update src/qa.py

Files changed (1): src/qa.py (+38 -65)
src/qa.py CHANGED
@@ -1,5 +1,5 @@
 """
-qa.py — GPT-4o (SAP Gen AI Hub) + ReRank Retrieval
+qa.py — GPT-4o (SAP Gen AI Hub) + ReRank Retrieval + PRF Query Expansion
 --------------------------------------------------
 ✅ Semantic retrieval (FAISS + cosine re-rank + neighbor fill)
 ✅ Bullet-aware similarity boost for procedural chunks
@@ -7,6 +7,7 @@ qa.py — GPT-4o (SAP Gen AI Hub) + ReRank Retrieval
 ✅ Smart factual mode (fast)
 ✅ Deep reasoning mode (ChatGPT-like)
 ✅ genai_generate() helper for suggestions
+✅ NEW: Lightweight PRF query expansion to fix synonym-based retrieval misses
 """
 
 import os
@@ -15,12 +16,13 @@ import json
 import pickle
 import hashlib
 import numpy as np
+from collections import Counter
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client
 from gen_ai_hub.proxy.langchain.openai import ChatOpenAI
 
-print("✅ qa.py (GPT-4o via Gen AI Hub + Bullet-Aware Retrieval + Cache) loaded from:", __file__)
+print("✅ qa.py (GPT-4o via Gen AI Hub + Bullet-Aware Retrieval + PRF) loaded from:", __file__)
 
 # ==========================================================
 # 🧱 Permanent Embeddings Cache Directory
@@ -28,7 +30,6 @@ print("✅ qa.py (GPT-4o via Gen AI Hub + Bullet-Aware Retrieval + Cache) loaded
 CACHE_EMB_DIR = os.path.join(os.path.dirname(__file__), "embed_cache")
 os.makedirs(CACHE_EMB_DIR, exist_ok=True)
 
-# Verify write permission
 try:
     test_file = os.path.join(CACHE_EMB_DIR, "test_write.tmp")
     with open(test_file, "w") as f:
@@ -57,10 +58,7 @@ os.environ.update({
 # 2️⃣ Embedding Model (E5-small-v2)
 # ==========================================================
 try:
-    _query_model = SentenceTransformer(
-        "intfloat/e5-small-v2",  # ⚡ Faster, 384-dim embeddings
-        cache_folder=CACHE_DIR
-    )
+    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
     print("✅ Loaded embedding model: intfloat/e5-small-v2 (fast mode)")
 except Exception as e:
     print(f"⚠️ Embedding load failed ({e}), using MiniLM fallback")
@@ -69,21 +67,15 @@ except Exception as e:
 # ==========================================================
 # 3️⃣ GPT-4o via SAP Gen AI Hub — Lazy / On-demand initialization
 # ==========================================================
-
 CRED_PATH = os.path.join(os.path.dirname(__file__), "GEN AI HUB PROXY.json")
-_chat_llm = None  # cached instance
+_chat_llm = None
 
 def get_chat_llm(model_name: str = "gpt-4o", temperature: float = 0.3, max_tokens: int = 1500):
-    """
-    Lazily initializes ChatOpenAI via Gen AI Hub proxy.
-    Only runs when first needed; cached afterward.
-    """
     global _chat_llm
     if _chat_llm is not None:
        return _chat_llm
 
    try:
-        # Optional: set environment variables from service key if present
        if os.path.exists(CRED_PATH):
            with open(CRED_PATH, "r") as key_file:
                svcKey = json.load(key_file)
@@ -109,15 +101,10 @@ def get_chat_llm(model_name: str = "gpt-4o", temperature: float = 0.3, max_token
         _chat_llm = None
         raise
 
-
 # ==========================================================
 # 4️⃣ Embedding Generator (batch-optimized)
 # ==========================================================
 def embed_chunks(chunks, batch_size: int = 32):
-    """
-    Batch-encode text chunks using the global embedding model.
-    Normalized 384-dim embeddings for FAISS retrieval.
-    """
     if not chunks:
         return np.array([])
 
@@ -135,18 +122,13 @@ def embed_chunks(chunks, batch_size: int = 32):
     return np.array(all_embeddings)
 
 # ==========================================================
-# 5️⃣ Embedding Cache Manager (Chunk-Aware + Auto-Cleanup)
+# 5️⃣ Embedding Cache Manager
 # ==========================================================
-CACHE_EMB_DIR = "/tmp/embed_cache"
-os.makedirs(CACHE_EMB_DIR, exist_ok=True)
-
 def _hash_name(file_name: str, chunk_size: int, overlap: int, num_chunks: int):
-    """Generate unique short hash for a file + chunking configuration."""
     combo = f"{file_name}_{chunk_size}_{overlap}_{num_chunks}"
     return hashlib.md5(combo.encode()).hexdigest()[:8]
 
 def _clean_old_caches(base_name: str, keep_latest: int = 5):
-    """Keep only latest few embedding caches for each document."""
     files = [
         (os.path.getmtime(os.path.join(CACHE_EMB_DIR, f)), f)
         for f in os.listdir(CACHE_EMB_DIR)
@@ -162,7 +144,6 @@ def _clean_old_caches(base_name: str, keep_latest: int = 5):
             pass
 
 def cache_embeddings(file_name: str, chunks, embed_func, chunk_size: int = None, overlap: int = None):
-    """Load or create embeddings cache (chunk size + overlap aware)."""
     cache_key = _hash_name(file_name, chunk_size or 1000, overlap or 100, len(chunks))
     cache_file = f"{os.path.basename(file_name)}_cs{chunk_size}_ov{overlap}_{cache_key}.pkl"
     cache_path = os.path.join(CACHE_EMB_DIR, cache_file)
@@ -182,9 +163,8 @@ def cache_embeddings(file_name: str, chunks, embed_func, chunk_size: int = None,
     return embeddings
 
 # ==========================================================
-# 6️⃣ Prompt Templates (Enhanced for Structured Formatting + Clean Output)
+# 6️⃣ Prompt Templates
 # ==========================================================
-
 STRICT_PROMPT = (
     "You are an enterprise documentation assistant.\n"
     "Use all relevant information from the CONTEXT below.\n"
@@ -198,8 +178,6 @@ STRICT_PROMPT = (
     "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
 )
 
-
-
 REASONING_PROMPT = (
     "You are an expert enterprise assistant capable of reasoning.\n"
    "Think step by step and synthesize information even if scattered across chunks.\n"
@@ -211,6 +189,30 @@ REASONING_PROMPT = (
     "Context:\n{context}\n\nQuestion: {query}\nLet's reason step-by-step:\nAnswer:"
 )
 
+# ==========================================================
+# 🔹 NEW: Lightweight PRF Query Expansion
+# ==========================================================
+def expand_query_embedding(query, model, index, chunks, topN=40, alpha=0.75):
+    """
+    Expands the query embedding slightly using top candidate chunks (PRF-style).
+    Helps when query wording differs from document phrasing.
+    """
+    q_emb = model.encode([f"query: {query}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+    try:
+        D, I = index.search(np.array([q_emb]).astype("float32"), topN)
+        texts = " ".join(chunks[i] for i in I[0] if i >= 0)
+        words = re.findall(r"[A-Za-z]{4,}", texts)
+        common = [w for w, _ in Counter(words).most_common(6) if w.lower() not in query.lower()]
+        if not common:
+            return q_emb
+        e_emb = model.encode([f"passage: {' '.join(common)}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+        combined = alpha * q_emb + (1 - alpha) * e_emb
+        combined /= np.linalg.norm(combined)
+        print(f"🔍 Query expanded with: {common}")
+        return combined
+    except Exception as e:
+        print(f"⚠️ Query expansion skipped due to error: {e}")
+        return q_emb
 
 # ==========================================================
 # 7️⃣ Retrieval — FAISS + Bullet-Aware Re-rank + Neighbor Fill
@@ -220,24 +222,14 @@ from vectorstore import build_faiss_index
 def retrieve_chunks(query: str, index, chunks: list, top_k: int = 7,
                     min_similarity: float = 0.6, candidate_multiplier: int = 3,
                     embeddings: list = None):
-    """
-    Retrieves the most relevant chunks using FAISS similarity + reranking.
-    Includes bullet-aware similarity boost and a fallback mechanism if
-    similarity threshold isn't met — ensuring predictable, complete retrieval.
-    """
     if not index or not chunks:
         print("⚠️ No FAISS index or chunks provided — returning empty result.")
         return []
 
     try:
-        # --- Encode query
-        q_emb = _query_model.encode(
-            [f"query: {query.strip()}"],
-            convert_to_numpy=True,
-            normalize_embeddings=True
-        )[0]
+        # --- PRF-enhanced query embedding
+        q_emb = expand_query_embedding(query, _query_model, index, chunks)
 
-        # --- Rebuild index if mismatch occurs
         if hasattr(index, "d") and q_emb.shape[0] != index.d:
             print(f"⚠️ FAISS dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
             if embeddings:
@@ -246,46 +238,35 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 7,
             else:
                 return []
 
-        # --- Retrieve top candidate chunks
         num_candidates = max(top_k * candidate_multiplier, top_k + 2)
         distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
-        candidate_indices = [int(i) for i in indices[0] if i >= 0]
-        candidate_indices = list(dict.fromkeys(candidate_indices))  # remove duplicates
+        candidate_indices = list(dict.fromkeys([int(i) for i in indices[0] if i >= 0]))
 
-        # --- Re-rank using cosine similarity
         doc_embs = _query_model.encode(
             [f"passage: {chunks[i]}" for i in candidate_indices],
             convert_to_numpy=True,
             normalize_embeddings=True,
         )
         sims = cosine_similarity([q_emb], doc_embs)[0]
-
         boosted_sims = []
         for idx, sim in zip(candidate_indices, sims):
             text = chunks[idx].strip()
             if re.match(r"^[-•\d]+[\.\s]", text):
-                sim += 0.05  # slight boost for procedural bullets
+                sim += 0.05
             boosted_sims.append((idx, sim))
 
         ranked = sorted(boosted_sims, key=lambda x: x[1], reverse=True)
-
-        # --- Filter based on similarity threshold
         filtered = [idx for idx, sim in ranked if sim >= min_similarity][:top_k]
-
-        # --- Fallback: if no matches above threshold, pick top_k anyway
         if not filtered:
             print(f"⚠️ No chunks ≥ {min_similarity:.2f} — using top {top_k} ranked chunks instead.")
             filtered = [idx for idx, sim in ranked[:top_k]]
 
-        # --- Neighbor continuity: include nearby chunks
         neighbors = set()
         for idx in filtered:
             for n in [idx - 1, idx + 1]:
                 if 0 <= n < len(chunks):
                     neighbors.add(n)
         filtered = sorted(set(filtered) | neighbors)
-
-        # --- Return final chunk set
         final_chunks = [chunks[i] for i in filtered]
         avg_sim = np.mean([s for _, s in ranked[:top_k]])
         print(f"✅ Retrieved {len(final_chunks)} chunks | avg_sim={avg_sim:.3f} | threshold={min_similarity:.2f}")
@@ -295,21 +276,18 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 7,
         print(f"⚠️ Retrieval error: {repr(e)}")
         return []
 
-
 # ==========================================================
-# 8️⃣ Answer Generation (Lazy GPT-4o Initialization)
+# 8️⃣ Answer Generation
 # ==========================================================
 def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
     if not retrieved_chunks:
         return "Sorry, I couldn’t find relevant information in the document."
 
-    # Try lazy initialization
     try:
         chat_llm_local = get_chat_llm()
     except Exception:
         return "⚠️ GPT-4o not initialized. Check credentials or rebuild the Space."
 
-    # Build context and prompt
     context = "\n".join(f"[Chunk {i+1}] {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks))
     prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)
 
@@ -323,8 +301,6 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
          "'I don't know based on the provided document.'"},
         {"role": "user", "content": prompt},
     ]
-
-    # Invoke GPT-4o
     try:
         response = chat_llm_local.invoke(messages)
         return response.content.strip()
@@ -332,12 +308,10 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
         print(f"⚠️ GPT-4o generation failed: {e}")
         return "⚠️ Error: Could not generate an answer."
 
-
 # ==========================================================
-# 9️⃣ Generic Text Generation Helper (for AI suggestions)
+# 9️⃣ Generic Text Generation Helper
 # ==========================================================
 def genai_generate(prompt: str) -> str:
-    # Try lazy initialization
     try:
         chat_llm_local = get_chat_llm()
     except Exception:
@@ -370,7 +344,6 @@ if __name__ == "__main__":
 
     embeddings = embed_chunks(dummy_chunks)
     index = build_faiss_index(embeddings)
-
     query = "What are the prerequisites for commerce automation?"
     retrieved = retrieve_chunks(query, index, dummy_chunks)
     print("🔍 Retrieved:", retrieved)
 
 
 
 
127
  def _hash_name(file_name: str, chunk_size: int, overlap: int, num_chunks: int):
 
128
  combo = f"{file_name}_{chunk_size}_{overlap}_{num_chunks}"
129
  return hashlib.md5(combo.encode()).hexdigest()[:8]
130
 
131
  def _clean_old_caches(base_name: str, keep_latest: int = 5):
 
132
  files = [
133
  (os.path.getmtime(os.path.join(CACHE_EMB_DIR, f)), f)
134
  for f in os.listdir(CACHE_EMB_DIR)
 
144
  pass
145
 
146
  def cache_embeddings(file_name: str, chunks, embed_func, chunk_size: int = None, overlap: int = None):
 
147
  cache_key = _hash_name(file_name, chunk_size or 1000, overlap or 100, len(chunks))
148
  cache_file = f"{os.path.basename(file_name)}_cs{chunk_size}_ov{overlap}_{cache_key}.pkl"
149
  cache_path = os.path.join(CACHE_EMB_DIR, cache_file)
 
163
  return embeddings
164
 
165
  # ==========================================================
166
+ # 6️⃣ Prompt Templates
167
  # ==========================================================
 
168
  STRICT_PROMPT = (
169
  "You are an enterprise documentation assistant.\n"
170
  "Use all relevant information from the CONTEXT below.\n"
 
178
  "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
179
  )
180
 
 
 
181
  REASONING_PROMPT = (
182
  "You are an expert enterprise assistant capable of reasoning.\n"
183
  "Think step by step and synthesize information even if scattered across chunks.\n"
 
189
  "Context:\n{context}\n\nQuestion: {query}\nLet's reason step-by-step:\nAnswer:"
190
  )
191
 
192
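Neighbor fill is equally mechanical: every retained index pulls in its immediate predecessor and successor, so procedures split across chunk boundaries stay intact. A toy run with hypothetical indices:

chunks = [f"chunk {i}" for i in range(12)]
filtered = [5, 9]
neighbors = {n for idx in filtered for n in (idx - 1, idx + 1) if 0 <= n < len(chunks)}
print(sorted(set(filtered) | neighbors))  # [4, 5, 6, 8, 9, 10]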