Shubham170793 commited on
Commit
f86b15f
·
verified ·
1 Parent(s): 5f3c646

Update src/qa.py

Browse files
Files changed (1) hide show
  1. src/qa.py +76 -56
src/qa.py CHANGED
@@ -3,6 +3,7 @@ qa.py — GPT-4o (SAP Gen AI Hub) + ReRank Retrieval
3
  --------------------------------------------------
4
  ✅ Semantic retrieval (FAISS + cosine re-rank + neighbor fill)
5
  ✅ Bullet-aware similarity boost for procedural chunks
 
6
  ✅ Smart factual mode (fast)
7
  ✅ Deep reasoning mode (ChatGPT-like)
8
  """
@@ -10,16 +11,18 @@ qa.py — GPT-4o (SAP Gen AI Hub) + ReRank Retrieval
10
  import os
11
  import re
12
  import json
 
 
13
  import numpy as np
14
  from sentence_transformers import SentenceTransformer
15
  from sklearn.metrics.pairwise import cosine_similarity
16
  from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client
17
  from gen_ai_hub.proxy.langchain.openai import ChatOpenAI
18
 
19
- print("✅ qa.py (GPT-4o via Gen AI Hub + Bullet-Aware Retrieval) loaded from:", __file__)
20
 
21
  # ==========================================================
22
- # 1️⃣ Hugging Face Cache
23
  # ==========================================================
24
  CACHE_DIR = "/tmp/hf_cache"
25
  os.makedirs(CACHE_DIR, exist_ok=True)
@@ -35,7 +38,7 @@ os.environ.update({
35
  # ==========================================================
36
  try:
37
  _query_model = SentenceTransformer(
38
- "intfloat/e5-small-v2", # ⚡ Faster, 384-dim embeddings
39
  cache_folder=CACHE_DIR
40
  )
41
  print("✅ Loaded embedding model: intfloat/e5-small-v2 (fast mode)")
@@ -74,7 +77,52 @@ except Exception as e:
74
  chat_llm = None
75
 
76
  # ==========================================================
77
- # 4️⃣ Prompt Templates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  # ==========================================================
79
  STRICT_PROMPT = (
80
  "You are an enterprise documentation assistant.\n"
@@ -97,7 +145,7 @@ REASONING_PROMPT = (
97
  )
98
 
99
  # ==========================================================
100
- # 5️⃣ Retrieval — FAISS + Bullet-Aware Re-rank + Neighbor Fill
101
  # ==========================================================
102
  from vectorstore import build_faiss_index
103
 
@@ -105,9 +153,8 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
105
  min_similarity: float = 0.6, candidate_multiplier: int = 3,
106
  embeddings: list = None):
107
  """
108
- Re-rank and optionally fill with neighbors for context continuity.
109
- Adds small similarity boost for bullet-style or step-based chunks.
110
- Auto-detects and rebuilds FAISS index if dimension mismatch occurs.
111
  """
112
 
113
  if not index or not chunks:
@@ -115,60 +162,45 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
115
  return []
116
 
117
  try:
118
- # Encode query embedding
119
  q_emb = _query_model.encode(
120
  [f"query: {query.strip()}"],
121
  convert_to_numpy=True,
122
  normalize_embeddings=True
123
  )[0]
124
 
125
- # ✅ Sanity check: dimension match between query and FAISS index
126
  if hasattr(index, "d") and q_emb.shape[0] != index.d:
127
- print(f"⚠️ FAISS index dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
128
  if embeddings:
129
- print("🔄 Rebuilding FAISS index to match embedding dimensions...")
130
  index = build_faiss_index(embeddings)
131
- print("✅ FAISS index successfully rebuilt.")
132
-
133
- q_emb = _query_model.encode(
134
- [f"query: {query.strip()}"],
135
- convert_to_numpy=True,
136
- normalize_embeddings=True
137
- )[0]
138
  else:
139
- print("❌ No embeddings available to rebuild FAISS index.")
140
  return []
141
 
142
  # Step 1️⃣ — Initial FAISS retrieval
143
  num_candidates = max(top_k * candidate_multiplier, top_k + 2)
144
  distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
145
  candidate_indices = [int(i) for i in indices[0] if i >= 0]
146
- candidate_indices = list(dict.fromkeys(candidate_indices)) # de-dupe
147
 
148
- # Step 2️⃣ — Compute similarities
149
  doc_embs = _query_model.encode(
150
  [f"passage: {chunks[i]}" for i in candidate_indices],
151
  convert_to_numpy=True,
152
  normalize_embeddings=True,
153
  )
154
  sims = cosine_similarity([q_emb], doc_embs)[0]
155
-
156
- # 🔹 NEW: Boost similarity for bullet-style or step-based chunks
157
  boosted_sims = []
158
  for idx, sim in zip(candidate_indices, sims):
159
- chunk_text = chunks[idx].strip()
160
- if re.match(r"^[-•\d]+[\.\s]", chunk_text): # bullet or numbered
161
- sim += 0.05 # small procedural context boost
162
  boosted_sims.append((idx, sim))
163
 
164
  ranked = sorted(boosted_sims, key=lambda x: x[1], reverse=True)
 
165
 
166
- # Step 3️⃣ — Filter by similarity threshold
167
- filtered = [idx for idx, sim in ranked if sim >= min_similarity]
168
- if len(filtered) > top_k:
169
- filtered = filtered[:top_k]
170
-
171
- # Step 4️⃣ — Neighbor fill (context continuity)
172
  neighbors = set()
173
  for idx in filtered:
174
  for n in [idx - 1, idx + 1]:
@@ -176,7 +208,6 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
176
  neighbors.add(n)
177
  filtered = sorted(set(filtered) | neighbors)
178
 
179
- # Step 5️⃣ — Build final chunk list
180
  final_chunks = [chunks[i] for i in filtered]
181
  print(f"✅ Retrieved {len(final_chunks)} chunks (bullet-aware + continuity).")
182
  return final_chunks
@@ -186,34 +217,25 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
186
  return []
187
 
188
  # ==========================================================
189
- # 6️⃣ Answer Generation — GPT-4o via Gen AI Hub
190
  # ==========================================================
191
  def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
192
- """
193
- reasoning_mode=False → strict factual mode (fast)
194
- reasoning_mode=True → deep reasoning mode (ChatGPT-like)
195
- """
196
  if not retrieved_chunks:
197
  return "Sorry, I couldn’t find relevant information in the document."
198
  if chat_llm is None:
199
  return "⚠️ GPT-4o not initialized. Check credentials or rebuild the Space."
200
 
201
- # Combine chunks with markers
202
  context = "\n".join(f"[Chunk {i+1}] {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks))
203
  prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)
204
 
205
  messages = [
206
- {
207
- "role": "system",
208
- "content": (
209
- "You are an expert enterprise documentation assistant. "
210
- "When reasoning_mode is off, stay strictly factual and concise. "
211
- "When reasoning_mode is on, combine insights across chunks logically "
212
- "and explain the reasoning briefly. "
213
- "If the answer is not in the document, reply exactly: "
214
- "'I don't know based on the provided document.'"
215
- ),
216
- },
217
  {"role": "user", "content": prompt},
218
  ]
219
 
@@ -225,7 +247,7 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
225
  return "⚠️ Error: Could not generate an answer."
226
 
227
  # ==========================================================
228
- # 7️⃣ Local Test
229
  # ==========================================================
230
  if __name__ == "__main__":
231
  from vectorstore import build_faiss_index
@@ -236,10 +258,8 @@ if __name__ == "__main__":
236
  "Setup instructions and configuration details.",
237
  "Prerequisites for automation are described here."
238
  ]
239
- embeddings = [
240
- _query_model.encode([f"passage: {c}"], convert_to_numpy=True, normalize_embeddings=True)[0]
241
- for c in dummy_chunks
242
- ]
243
  index = build_faiss_index(embeddings)
244
 
245
  query = "What are the prerequisites for commerce automation?"
 
3
  --------------------------------------------------
4
  ✅ Semantic retrieval (FAISS + cosine re-rank + neighbor fill)
5
  ✅ Bullet-aware similarity boost for procedural chunks
6
+ ✅ Embedding caching (per PDF)
7
  ✅ Smart factual mode (fast)
8
  ✅ Deep reasoning mode (ChatGPT-like)
9
  """
 
11
  import os
12
  import re
13
  import json
14
+ import pickle
15
+ import hashlib
16
  import numpy as np
17
  from sentence_transformers import SentenceTransformer
18
  from sklearn.metrics.pairwise import cosine_similarity
19
  from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client
20
  from gen_ai_hub.proxy.langchain.openai import ChatOpenAI
21
 
22
+ print("✅ qa.py (GPT-4o via Gen AI Hub + Bullet-Aware Retrieval + Cache) loaded from:", __file__)
23
 
24
  # ==========================================================
25
+ # 1️⃣ Hugging Face Cache Setup
26
  # ==========================================================
27
  CACHE_DIR = "/tmp/hf_cache"
28
  os.makedirs(CACHE_DIR, exist_ok=True)
 
38
  # ==========================================================
39
  try:
40
  _query_model = SentenceTransformer(
41
+ "intfloat/e5-small-v2", # ⚡ Faster, 384-dim embeddings
42
  cache_folder=CACHE_DIR
43
  )
44
  print("✅ Loaded embedding model: intfloat/e5-small-v2 (fast mode)")
 
77
  chat_llm = None
78
 
79
  # ==========================================================
80
+ # 4️⃣ Embedding Cache Manager
81
+ # ==========================================================
82
+ CACHE_EMB_DIR = "/tmp/embed_cache"
83
+ os.makedirs(CACHE_EMB_DIR, exist_ok=True)
84
+
85
+ def _hash_name(file_name: str):
86
+ """Generate unique hash for PDF file name."""
87
+ return hashlib.md5(file_name.encode()).hexdigest()
88
+
89
+ def cache_embeddings(file_name: str, chunks, embed_func):
90
+ """
91
+ Checks if cached embeddings exist for a PDF; if not, compute and save.
92
+ """
93
+ cache_path = os.path.join(CACHE_EMB_DIR, f"{_hash_name(file_name)}.pkl")
94
+
95
+ if os.path.exists(cache_path):
96
+ print(f"🧠 Loaded cached embeddings for {file_name}")
97
+ with open(cache_path, "rb") as f:
98
+ return pickle.load(f)
99
+
100
+ print(f"💡 No cache found for {file_name}. Generating embeddings...")
101
+ embeddings = embed_func(chunks)
102
+ with open(cache_path, "wb") as f:
103
+ pickle.dump(embeddings, f)
104
+ print(f"💾 Cached embeddings saved for {file_name}")
105
+ return embeddings
106
+
107
+ def embed_chunks(chunks, batch_size=32):
108
+ """
109
+ Batch-encode text chunks for speed.
110
+ """
111
+ all_embeddings = []
112
+ for i in range(0, len(chunks), batch_size):
113
+ batch = [f"passage: {c}" for c in chunks[i:i+batch_size]]
114
+ batch_embs = _query_model.encode(
115
+ batch,
116
+ convert_to_numpy=True,
117
+ normalize_embeddings=True,
118
+ show_progress_bar=False
119
+ )
120
+ all_embeddings.extend(batch_embs)
121
+ print(f"⚡ Embedded {len(all_embeddings)} chunks in batches of {batch_size}")
122
+ return np.array(all_embeddings)
123
+
124
+ # ==========================================================
125
+ # 5️⃣ Prompt Templates
126
  # ==========================================================
127
  STRICT_PROMPT = (
128
  "You are an enterprise documentation assistant.\n"
 
145
  )
146
 
147
  # ==========================================================
148
+ # 6️⃣ Retrieval — FAISS + Bullet-Aware Re-rank + Neighbor Fill
149
  # ==========================================================
150
  from vectorstore import build_faiss_index
151
 
 
153
  min_similarity: float = 0.6, candidate_multiplier: int = 3,
154
  embeddings: list = None):
155
  """
156
+ Retrieves top relevant chunks and preserves context continuity.
157
+ Adds small similarity boost for procedural (bullet or numbered) chunks.
 
158
  """
159
 
160
  if not index or not chunks:
 
162
  return []
163
 
164
  try:
 
165
  q_emb = _query_model.encode(
166
  [f"query: {query.strip()}"],
167
  convert_to_numpy=True,
168
  normalize_embeddings=True
169
  )[0]
170
 
171
+ # ✅ Dimension sanity check
172
  if hasattr(index, "d") and q_emb.shape[0] != index.d:
173
+ print(f"⚠️ FAISS dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
174
  if embeddings:
175
+ print("🔄 Rebuilding FAISS index...")
176
  index = build_faiss_index(embeddings)
 
 
 
 
 
 
 
177
  else:
 
178
  return []
179
 
180
  # Step 1️⃣ — Initial FAISS retrieval
181
  num_candidates = max(top_k * candidate_multiplier, top_k + 2)
182
  distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
183
  candidate_indices = [int(i) for i in indices[0] if i >= 0]
184
+ candidate_indices = list(dict.fromkeys(candidate_indices))
185
 
186
+ # Step 2️⃣ — Re-rank with bullet-aware boost
187
  doc_embs = _query_model.encode(
188
  [f"passage: {chunks[i]}" for i in candidate_indices],
189
  convert_to_numpy=True,
190
  normalize_embeddings=True,
191
  )
192
  sims = cosine_similarity([q_emb], doc_embs)[0]
 
 
193
  boosted_sims = []
194
  for idx, sim in zip(candidate_indices, sims):
195
+ text = chunks[idx].strip()
196
+ if re.match(r"^[-•\d]+[\.\s]", text): # bullet or step pattern
197
+ sim += 0.05
198
  boosted_sims.append((idx, sim))
199
 
200
  ranked = sorted(boosted_sims, key=lambda x: x[1], reverse=True)
201
+ filtered = [idx for idx, sim in ranked if sim >= min_similarity][:top_k]
202
 
203
+ # Step 3️⃣ — Add neighboring chunks for continuity
 
 
 
 
 
204
  neighbors = set()
205
  for idx in filtered:
206
  for n in [idx - 1, idx + 1]:
 
208
  neighbors.add(n)
209
  filtered = sorted(set(filtered) | neighbors)
210
 
 
211
  final_chunks = [chunks[i] for i in filtered]
212
  print(f"✅ Retrieved {len(final_chunks)} chunks (bullet-aware + continuity).")
213
  return final_chunks
 
217
  return []
218
 
219
  # ==========================================================
220
+ # 7️⃣ Answer Generation
221
  # ==========================================================
222
  def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
 
 
 
 
223
  if not retrieved_chunks:
224
  return "Sorry, I couldn’t find relevant information in the document."
225
  if chat_llm is None:
226
  return "⚠️ GPT-4o not initialized. Check credentials or rebuild the Space."
227
 
 
228
  context = "\n".join(f"[Chunk {i+1}] {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks))
229
  prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)
230
 
231
  messages = [
232
+ {"role": "system", "content":
233
+ "You are an expert enterprise documentation assistant. "
234
+ "When reasoning_mode is off, stay strictly factual and concise. "
235
+ "When reasoning_mode is on, combine insights across chunks logically "
236
+ "and explain briefly. "
237
+ "If the answer is not in the document, reply exactly: "
238
+ "'I don't know based on the provided document.'"},
 
 
 
 
239
  {"role": "user", "content": prompt},
240
  ]
241
 
 
247
  return "⚠️ Error: Could not generate an answer."
248
 
249
  # ==========================================================
250
+ # 8️⃣ Local Test
251
  # ==========================================================
252
  if __name__ == "__main__":
253
  from vectorstore import build_faiss_index
 
258
  "Setup instructions and configuration details.",
259
  "Prerequisites for automation are described here."
260
  ]
261
+
262
+ embeddings = embed_chunks(dummy_chunks)
 
 
263
  index = build_faiss_index(embeddings)
264
 
265
  query = "What are the prerequisites for commerce automation?"