Shubham170793 committed on
Commit fbd4778 · verified · 1 Parent(s): 28eda6f

Update src/qa.py

Files changed (1)
  1. src/qa.py +75 -151
src/qa.py CHANGED
@@ -1,11 +1,9 @@
  """
- qa.py — Phi-2 FAST + ReRank (stable) — Prefer semantic ranking, neighbor-fill last-resort
- ---------------------------------------------------------------------------------------
- - Uses intfloat/e5-small-v2 for embeddings
- - Uses microsoft/phi-2 for generation
- - Re-ranks candidate pool from FAISS then picks top_k by true cosine similarity
- - Neighbor expansion only if not enough high-sim items
- - Logs chunk indices + similarity scores for debugging
+ qa.py — Phi-2 FAST + SMART RETRIEVAL (Stable)
+ ---------------------------------------------
+ intfloat/e5-small-v2 embeddings
+ microsoft/phi-2 generation
+ Optimized for: speed, factual accuracy, low hallucination
  """

  import os
@@ -15,11 +13,11 @@ from sklearn.metrics.pairwise import cosine_similarity
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import torch

- print("✅ qa.py (Phi-2 FAST + ReRank stable) loaded from:", __file__)
+ print("✅ qa.py (Phi-2 FAST + Smart Retrieval) loaded from:", __file__)

- # ---------------------------
- # Cache
- # ---------------------------
+ # ==========================================================
+ # 1️⃣ Cache Setup (Hugging Face /tmp cache)
+ # ==========================================================
  CACHE_DIR = "/tmp/hf_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)
  os.environ.update({
@@ -28,21 +26,20 @@ os.environ.update({
      "HF_DATASETS_CACHE": CACHE_DIR,
      "HF_MODULES_CACHE": CACHE_DIR
  })
- print(f"✅ Using Hugging Face cache at {CACHE_DIR}")

- # ---------------------------
- # Embeddings
- # ---------------------------
+ # ==========================================================
+ # 2️⃣ Embedding Model
+ # ==========================================================
  try:
      _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
      print("✅ Loaded embedding model: intfloat/e5-small-v2")
  except Exception as e:
-     print(f"⚠️ Embedding load failed ({e}), falling back to MiniLM")
+     print(f"⚠️ Embedding load failed ({e}), using MiniLM fallback")
      _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)

- # ---------------------------
- # Phi-2 model
- # ---------------------------
+ # ==========================================================
+ # 3️⃣ Phi-2 LLM Setup
+ # ==========================================================
  MODEL_NAME = "microsoft/phi-2"
  print(f"✅ Loading LLM: {MODEL_NAME}")

@@ -63,194 +60,121 @@ _answer_model = pipeline(
  )
  print("✅ Phi-2 text-generation pipeline ready (optimized).")

- # ---------------------------
- # Prompts
- # ---------------------------
+ # ==========================================================
+ # 4️⃣ Prompt Templates
+ # ==========================================================
  STRICT_PROMPT = (
      "You are an enterprise documentation assistant.\n"
-     "Use ONLY the CONTEXT chunks below to answer the QUESTION.\n"
-     "Cite the chunk number(s) you used, e.g. [Chunk 3].\n"
-     "If the document does not contain the answer, reply exactly:\n"
-     "\"I don't know based on the provided document.\"\n\n"
+     "Use ONLY the CONTEXT below to answer the QUESTION.\n"
+     "If the answer isn’t present, reply exactly:\n"
+     "'I don't know based on the provided document.'\n\n"
      "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
  )

  REASONING_PROMPT = (
-     "You are an expert enterprise assistant with reasoning capacity.\n"
-     "Prefer the provided CONTEXT but you may cautiously infer when reasonable.\n"
-     "If you infer, say so and prefer facts from the document.\n"
-     "If the document lacks the answer, say:\n"
-     "\"I don't know based on the provided document.\"\n\n"
+     "You are an enterprise assistant with reasoning ability.\n"
+     "Think carefully, but use the document context first.\n"
+     "If you must infer, say so explicitly.\n"
+     "If answer not in the document, reply exactly:\n"
+     "'I don't know based on the provided document.'\n\n"
      "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
  )

- # ---------------------------
- # Retrieval: FAISS -> rerank -> neighbor fill (last resort)
- # ---------------------------
- def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3, min_similarity: float = 0.55, candidate_multiplier: int = 4):
-     """
-     Steps:
-       1. Encode query (E5 style).
-       2. Run FAISS search for k*candidate_multiplier candidates.
-       3. Re-embed those candidate texts and compute cosine similarity with query embedding.
-       4. Sort by similarity and pick top_k where similarity >= min_similarity.
-       5. If fewer than top_k passed threshold, fill remaining slots by:
-          - selecting neighboring chunks around the *highest-scoring* chunk(s),
-            but only if absolutely necessary (keeps noise low).
-     Returns: ordered list of chunks (strings)
-     Also prints indices + similarity scores for debugging.
-     """
-
+ # ==========================================================
+ # 5️⃣ Smart Retrieval (Re-rank + Neighbor Fill)
+ # ==========================================================
+ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
+                     min_similarity: float = 0.6, candidate_multiplier: int = 3):
+     """FAISS → Re-rank by cosine sim → Filter → Neighbor fill (only if needed)."""
      if not index or not chunks:
          return []

      try:
-         # 1. encode query
+         # 1️⃣ Encode query
          q_emb = _query_model.encode(
              [f"query: {query.strip()}"],
              convert_to_numpy=True,
              normalize_embeddings=True
          )[0]

-         # 2. FAISS initial retrieval (get a larger candidate pool)
-         num_candidates = max(top_k * candidate_multiplier, top_k + 2)
+         # 2️⃣ Initial FAISS retrieval (larger candidate pool)
+         num_candidates = top_k * candidate_multiplier
          distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
-         candidate_indices = [int(i) for i in indices[0] if i >= 0]
-
-         # protective dedupe and clamp
-         candidate_indices = list(dict.fromkeys(candidate_indices))  # preserve order, unique
+         candidate_indices = list(dict.fromkeys(indices[0]))  # dedup, preserve order

-         # 3. Re-embed candidate texts and compute true cosine similarity
+         # 3️⃣ Re-rank by cosine similarity
          candidate_texts = [chunks[i] for i in candidate_indices]
-         # Encode passages (passage prefix helps alignment)
          doc_embs = _query_model.encode(
              [f"passage: {c}" for c in candidate_texts],
              convert_to_numpy=True,
              normalize_embeddings=True
          )
          sims = cosine_similarity([q_emb], doc_embs)[0]
-
-         # Pair up indices and sims and sort descending
-         paired = [(candidate_indices[i], float(sims[i])) for i in range(len(candidate_indices))]
-         paired_sorted = sorted(paired, key=lambda x: x[1], reverse=True)
-
-         # Debug print: top candidates and their similarity
-         print("🔎 Candidate ranking (index : sim):")
-         for idx, sim in paired_sorted[: min(len(paired_sorted), top_k * 3)]:
-             print(f" - Chunk {idx} : {sim:.4f}")
-
-         # 4. Pick those meeting threshold
-         selected = [idx for idx, sim in paired_sorted if sim >= min_similarity]
-
-         # Preserve order by similarity
-         selected = selected[:top_k]
-
-         # 5. If not enough, fill by neighbors around highest-scoring items
-         if len(selected) < top_k:
-             needed = top_k - len(selected)
-             # pick highest scoring indices as anchor(s)
-             anchors = [idx for idx, _ in paired_sorted[:3]]  # top 3 anchors
-             expanded = []
-             for a in anchors:
-                 # neighbors ordered by proximity: a, a-1, a+1, a-2, a+2 ...
-                 if a not in expanded:
-                     expanded.append(a)
-                 offset = 1
-                 while len(expanded) < top_k and offset < 5:
-                     for cand in (a - offset, a + offset):
-                         if 0 <= cand < len(chunks) and cand not in expanded:
-                             expanded.append(cand)
-                         if len(expanded) >= top_k:
-                             break
-                     offset += 1
+         ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)
+
+         # 4️⃣ Filter low-similarity
+         filtered = [idx for idx, sim in ranked if sim >= min_similarity]
+         if len(filtered) > top_k:
+             filtered = filtered[:top_k]
+
+         # 5️⃣ Neighbor fill (only if fewer than top_k)
+         if len(filtered) < top_k:
+             expanded = set(filtered)
+             for idx in filtered:
+                 for neighbor in [idx - 1, idx + 1]:
+                     if 0 <= neighbor < len(chunks):
+                         expanded.add(neighbor)
+                     if len(expanded) >= top_k:
+                         break
                  if len(expanded) >= top_k:
                      break
-             # final selected: first maintain previously selected, then add neighbors from expanded preserving order
-             final_order = []
-             for idx, _sim in paired_sorted:
-                 if idx in selected and idx not in final_order:
-                     final_order.append(idx)
-             for idx in expanded:
-                 if idx not in final_order:
-                     final_order.append(idx)
-             selected = final_order[:top_k]
+             filtered = sorted(expanded)[:top_k]

-         # final chunk strings (ordered by selected list)
-         final_chunks = [chunks[i] for i in selected]
-
-         print(f"✅ retrieve_chunks: returning {len(final_chunks)} chunks (top_k={top_k}, min_sim={min_similarity})")
-         print(f" chunk indices: {selected}")
-
-         # Also return the indices? (if you want to display chunk numbers in UI, you can)
-         return final_chunks
+         print(f"✅ Retrieved {len(filtered)} chunks (top_k={top_k}, min_sim={min_similarity})")
+         return [chunks[i] for i in filtered]

      except Exception as e:
          print(f"⚠️ Retrieval error: {e}")
          return []

-
- # ---------------------------
- # Answer generation
- # ---------------------------
+ # ==========================================================
+ # 6️⃣ Answer Generation
+ # ==========================================================
  def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
-     """
-     - reasoning_mode=False => strict factual, deterministic
-     - reasoning_mode=True  => allow cautious inference (slower / longer)
-     """
+     """Generate concise, factual or reasoning-based answers using Phi-2."""
      if not retrieved_chunks:
          return "Sorry, I couldn’t find relevant information in the document."

-     # Add chunk headings so model can cite them if needed
-     context_lines = []
-     for i, chunk in enumerate(retrieved_chunks, start=1):
-         # Use [Chunk i] markers — LLM will echo them when asked to cite sources
-         context_lines.append(f"[Chunk {i}]: {chunk.strip()}")
-     context = "\n".join(context_lines)
+     # Include [Chunk N] markers
+     context = "\n".join(f"[Chunk {i+1}] {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks))

      prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
          context=context, query=query
      )

      try:
-         # deterministic in strict mode
-         if reasoning_mode:
-             max_new_tokens = 220
-             temp = 0.6
-             do_sample = True
-         else:
-             max_new_tokens = 140
-             temp = 0.0
-             do_sample = False
-
          result = _answer_model(
              prompt,
-             max_new_tokens=max_new_tokens,
-             temperature=temp,
-             do_sample=do_sample,
-             early_stopping=True,
+             max_new_tokens=180 if reasoning_mode else 140,
+             temperature=0.5 if reasoning_mode else 0.2,
+             do_sample=reasoning_mode,
              pad_token_id=_tokenizer.eos_token_id,
+             early_stopping=True,
          )

-         text = result[0].get("generated_text", "").strip()
-         # remove the prompt echo if present
+         text = result[0]["generated_text"].strip()
          if "Answer:" in text:
-             out = text.split("Answer:")[-1].strip()
-         else:
-             out = text
-
-         # Enforce exact fallback phrase if model tries to paraphrase missing-answer
-         if not reasoning_mode and ("i don't know" in out.lower() or "not present" in out.lower()):
-             return "I don't know based on the provided document."
+             text = text.split("Answer:")[-1].strip()

-         return out
+         return text or "⚠️ No answer generated."

      except Exception as e:
          print(f"⚠️ Generation failed: {e}")
          return "⚠️ Error: Could not generate an answer."

- # ---------------------------
- # Local debug main
- # ---------------------------
+ # ==========================================================
+ # 7️⃣ Local Test
+ # ==========================================================
  if __name__ == "__main__":
      from vectorstore import build_faiss_index
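The neighbor-fill branch added to retrieve_chunks is easiest to see with toy data. This standalone walk-through (not part of the commit) repeats the committed logic verbatim and shows that, when only one chunk clears the similarity threshold, its immediate neighbors are pulled in and the result comes back in document order rather than similarity order:

# Toy walk-through of the neighbor-fill branch (illustration only, not in the repo).
chunks = [f"chunk {i}" for i in range(10)]
top_k = 3

# Suppose re-ranking left only chunk 7 above min_similarity:
filtered = [7]

if len(filtered) < top_k:
    expanded = set(filtered)
    for idx in filtered:
        for neighbor in [idx - 1, idx + 1]:
            if 0 <= neighbor < len(chunks):
                expanded.add(neighbor)
            if len(expanded) >= top_k:
                break
        if len(expanded) >= top_k:
            break
    filtered = sorted(expanded)[:top_k]

print(filtered)  # [6, 7, 8]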
256
 
@@ -267,6 +191,6 @@ if __name__ == "__main__":
      index = build_faiss_index(embeddings)

      query = "How do I create a communication user?"
-     retrieved = retrieve_chunks(query, index, dummy_chunks, top_k=3, min_similarity=0.55)
+     retrieved = retrieve_chunks(query, index, dummy_chunks)
      print("🔍 Retrieved:", retrieved)
-     print("💬 Answer:", generate_answer(query, retrieved, reasoning_mode=False))
+     print("💬 Answer:", generate_answer(query, retrieved))
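The retrieval code in this commit encodes queries and passages with e5-small-v2's "query: " / "passage: " prefixes, normalizes the embeddings, and then thresholds cosine similarity. A minimal standalone sketch of that scoring convention (the passage text here is invented); because the vectors are normalized, a plain dot product already equals cosine similarity, so an inner-product FAISS index would rank candidates the same way:

# Standalone scoring sketch (not part of the commit).
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/e5-small-v2")

# E5 models expect asymmetric prefixes for retrieval.
q = model.encode(["query: How do I create a communication user?"],
                 convert_to_numpy=True, normalize_embeddings=True)[0]
p = model.encode(["passage: Open the Maintain Communication Users app and choose New."],
                 convert_to_numpy=True, normalize_embeddings=True)[0]

# Normalized vectors: dot product == cosine similarity,
# the value retrieve_chunks compares against min_similarity.
print(float(np.dot(q, p)))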
 
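For reference, one way to exercise the updated module end to end. This is a sketch under assumptions: build_faiss_index from vectorstore.py is not shown in this commit, so a plain FAISS inner-product index stands in for it; the sample chunks and the import path are invented for illustration:

# End-to-end sketch (assumptions noted above; not taken from the repo).
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

from qa import retrieve_chunks, generate_answer  # assumes src/ is on sys.path

# Invented sample chunks; real input would come from the document loader.
chunks = [
    "To create a communication user, open the Maintain Communication Users app and choose New.",
    "Communication arrangements link a communication system to a communication scenario.",
    "Business users are maintained separately in the Maintain Business Users app.",
]

embedder = SentenceTransformer("intfloat/e5-small-v2")
embeddings = embedder.encode([f"passage: {c}" for c in chunks],
                             convert_to_numpy=True, normalize_embeddings=True)

# Inner product over normalized vectors behaves like cosine similarity.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings.astype("float32"))

query = "How do I create a communication user?"
# candidate_multiplier=1 keeps the candidate pool within this tiny corpus.
retrieved = retrieve_chunks(query, index, chunks, top_k=2, candidate_multiplier=1)
print(generate_answer(query, retrieved))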