Shubham170793 committed
Commit 28eda6f · verified · 1 parent: 960fe58

Update src/qa.py

Files changed (1)
src/qa.py  +148 −102
src/qa.py CHANGED
@@ -1,14 +1,11 @@
  """
- qa.py — Phi-2 FAST + RERANKED RETRIEVAL + INTENT WEIGHTING (with Debug)
- -----------------------------------------------------------------------
- Uses:
- intfloat/e5-small-v2 — embeddings
- microsoft/phi-2 — generation
- Optimized for: speed, factual accuracy, and semantic retrieval on Hugging Face Spaces
- Now includes:
- • Intent-weighted query embedding
- • Intent-aware prompting (LLM focuses on “how”, “what”, “why”)
- • Debug printout showing detected query intent for verification
  """
 
  import os
@@ -18,11 +15,11 @@ from sklearn.metrics.pairwise import cosine_similarity
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import torch
 
- print("✅ qa.py (Phi-2 FAST + ReRank + Intent + Debug) loaded from:", __file__)
 
- # ==========================================================
- # 1️⃣ Cache Setup (Hugging Face /tmp cache)
- # ==========================================================
  CACHE_DIR = "/tmp/hf_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)
  os.environ.update({
@@ -33,9 +30,9 @@ os.environ.update({
  })
  print(f"✅ Using Hugging Face cache at {CACHE_DIR}")
 
- # ==========================================================
- # 2️⃣ Embedding Model
- # ==========================================================
  try:
      _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
      print("✅ Loaded embedding model: intfloat/e5-small-v2")
@@ -43,9 +40,9 @@ except Exception as e:
      print(f"⚠️ Embedding load failed ({e}), falling back to MiniLM")
      _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
 
- # ==========================================================
- # 3️⃣ Phi-2 LLM Setup
- # ==========================================================
  MODEL_NAME = "microsoft/phi-2"
  print(f"✅ Loading LLM: {MODEL_NAME}")
 
@@ -66,145 +63,194 @@ _answer_model = pipeline(
  )
  print("✅ Phi-2 text-generation pipeline ready (optimized).")
 
- # ==========================================================
- # 4️⃣ Prompt Templates (intent-aware)
- # ==========================================================
  STRICT_PROMPT = (
      "You are an enterprise documentation assistant.\n"
-     "Understand the intent of the question before answering:\n"
-     "If it asks 'how', focus only on step-by-step or procedural instructions.\n"
-     "If it asks 'what', provide definitions or factual explanations.\n"
-     "If it asks 'why', explain reasons or purposes.\n"
-     "Use ONLY the provided context below to answer factually.\n"
-     "If the answer isn’t present, reply exactly:\n"
-     "'I don't know based on the provided document.'\n\n"
      "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
  )
 
  REASONING_PROMPT = (
-     "You are an expert enterprise assistant with reasoning ability.\n"
-     "Think carefully about the context and question intent.\n"
-     "If it's procedural, outline steps clearly.\n"
-     "If it's conceptual, explain in detail.\n"
-     "Prefer factual accuracy but you may infer if clearly implied.\n"
      "If the document lacks the answer, say:\n"
-     "'I don't know based on the provided document.'\n\n"
      "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
  )
 
- # ==========================================================
- # 5️⃣ Retrieve Chunks (FAISS + Rerank + Intent-weighting + Debug)
- # ==========================================================
- def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5, min_similarity: float = 0.6):
      """
-     Hybrid retrieval:
-     1️⃣ Detect query intent and embed accordingly.
-     2️⃣ Get semantic top-K chunks via FAISS.
-     3️⃣ Re-rank by cosine similarity and apply a minimum similarity filter.
-     4️⃣ If fewer than top_k remain, fill remaining seats with adjacent chunks (±1) for continuity.
      """
      if not index or not chunks:
          return []
 
      try:
-         # 🔍 Detect and encode query intent
-         intent_hint = ""
-         query_type = "factual"
-         if any(kw in query.lower() for kw in ["how", "create", "steps", "procedure", "setup", "configure"]):
-             query_type = "procedural"
-             intent_hint = " This is an instructional query; focus on procedure and step-by-step instructions."
-         elif any(kw in query.lower() for kw in ["why", "reason", "purpose", "benefit"]):
-             query_type = "conceptual"
-             intent_hint = " This is a conceptual query; focus on rationale and explanation."
-
-         print(f"🧩 Detected query type: {query_type}")
-
          q_emb = _query_model.encode(
-             [f"query: {query.strip()}{intent_hint}"],
              convert_to_numpy=True,
              normalize_embeddings=True
          )[0]
 
-         # Step 1️⃣ — FAISS initial retrieval
-         distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
-         retrieved_indices = list(indices[0])
 
-         # Step 2️⃣ Compute cosine similarity for re-ranking
-         retrieved_texts = [chunks[i] for i in retrieved_indices]
          doc_embs = _query_model.encode(
-             [f"passage: {c}" for c in retrieved_texts],
              convert_to_numpy=True,
              normalize_embeddings=True
          )
-
          sims = cosine_similarity([q_emb], doc_embs)[0]
-         ranked = sorted(zip(retrieved_indices, sims), key=lambda x: x[1], reverse=True)
-
-         # Step 3️⃣ Apply minimum similarity filter
-         filtered_indices = [idx for idx, score in ranked if score >= min_similarity]
-
-         # Step 4️⃣ If not enough, add ±1 neighbors for continuity
-         if len(filtered_indices) < top_k:
-             expanded_indices = set(filtered_indices)
-             for idx in filtered_indices:
-                 for neighbor in [idx - 1, idx + 1]:
-                     if 0 <= neighbor < len(chunks):
-                         expanded_indices.add(neighbor)
-                     if len(expanded_indices) >= top_k:
-                         break
-                 if len(expanded_indices) >= top_k:
                      break
-             filtered_indices = list(sorted(expanded_indices))[:top_k]
 
-         # Step 5️⃣ Build final ordered list of chunks
-         final_chunks = [chunks[i] for i in filtered_indices]
 
-         print(f"✅ Retrieved {len(final_chunks)} chunks (intent-weighted, semantic + neighbor fill).")
          return final_chunks
 
      except Exception as e:
          print(f"⚠️ Retrieval error: {e}")
          return []
 
- # ==========================================================
- # 6️⃣ Answer Generation
- # ==========================================================
  def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
-     """Generate concise, factual or reasoning-based answers using Phi-2."""
      if not retrieved_chunks:
          return "Sorry, I couldn’t find relevant information in the document."
 
-     context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
      prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
          context=context, query=query
      )
 
      try:
-         # 🧠 Adaptive length for factual mode (based on question complexity)
          if reasoning_mode:
-             max_tokens = 180  # keep reasoning slightly longer
          else:
-             max_tokens = 120 if len(query.split()) < 6 else 180  # short factual queries stay fast
 
          result = _answer_model(
              prompt,
-             max_new_tokens=max_tokens,
-             temperature=0.6 if reasoning_mode else 0.3,
-             do_sample=reasoning_mode,
              early_stopping=True,
              pad_token_id=_tokenizer.eos_token_id,
          )
 
-         text = result[0]["generated_text"].strip()
-         return text.split("Answer:")[-1].strip() if "Answer:" in text else text
 
      except Exception as e:
          print(f"⚠️ Generation failed: {e}")
          return "⚠️ Error: Could not generate an answer."
 
- # ==========================================================
- # 7️⃣ Local Test
- # ==========================================================
  if __name__ == "__main__":
      from vectorstore import build_faiss_index
 
@@ -221,6 +267,6 @@ if __name__ == "__main__":
      index = build_faiss_index(embeddings)
 
      query = "How do I create a communication user?"
-     retrieved = retrieve_chunks(query, index, dummy_chunks)
      print("🔍 Retrieved:", retrieved)
-     print("💬 Answer:", generate_answer(query, retrieved))
 
  """
+ qa.py — Phi-2 FAST + ReRank (stable): prefer semantic ranking, with neighbor-fill as a last resort
+ --------------------------------------------------------------------------------------------------
+ - Uses intfloat/e5-small-v2 for embeddings
+ - Uses microsoft/phi-2 for generation
+ - Re-ranks the candidate pool from FAISS, then picks top_k by true cosine similarity
+ - Neighbor expansion only if there are not enough high-similarity items
+ - Logs chunk indices + similarity scores for debugging
  """
 
  import os
 
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  import torch
 
+ print("✅ qa.py (Phi-2 FAST + ReRank stable) loaded from:", __file__)
 
+ # ---------------------------
+ # Cache
+ # ---------------------------
  CACHE_DIR = "/tmp/hf_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)
  os.environ.update({
 
  })
  print(f"✅ Using Hugging Face cache at {CACHE_DIR}")
 
+ # ---------------------------
+ # Embeddings
+ # ---------------------------
  try:
      _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
      print("✅ Loaded embedding model: intfloat/e5-small-v2")
 
      print(f"⚠️ Embedding load failed ({e}), falling back to MiniLM")
      _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
 
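A note on the embedding model used above: e5-small-v2 is trained with asymmetric "query: " / "passage: " prefixes, which is why retrieve_chunks below prepends them before encoding. A minimal standalone sketch of the convention (the sample strings are invented):

from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("intfloat/e5-small-v2")

# E5 expects role prefixes; omitting them degrades retrieval quality.
q_emb = model.encode(["query: how do I create a communication user?"], normalize_embeddings=True)
p_emb = model.encode(["passage: To create a communication user, open the admin app..."], normalize_embeddings=True)

# With normalized vectors, the dot product equals cosine similarity.
print(float(np.dot(q_emb[0], p_emb[0])))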
+ # ---------------------------
+ # Phi-2 model
+ # ---------------------------
  MODEL_NAME = "microsoft/phi-2"
  print(f"✅ Loading LLM: {MODEL_NAME}")
 
  )
  print("✅ Phi-2 text-generation pipeline ready (optimized).")
 
+ # ---------------------------
+ # Prompts
+ # ---------------------------
  STRICT_PROMPT = (
      "You are an enterprise documentation assistant.\n"
+     "Use ONLY the CONTEXT chunks below to answer the QUESTION.\n"
+     "Cite the chunk number(s) you used, e.g. [Chunk 3].\n"
+     "If the document does not contain the answer, reply exactly:\n"
+     "\"I don't know based on the provided document.\"\n\n"
      "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
  )
 
  REASONING_PROMPT = (
+     "You are an expert enterprise assistant with reasoning capacity.\n"
+     "Prefer the provided CONTEXT but you may cautiously infer when reasonable.\n"
+     "If you infer, say so and prefer facts from the document.\n"
      "If the document lacks the answer, say:\n"
+     "\"I don't know based on the provided document.\"\n\n"
      "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
  )
 
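To make the prompt contract concrete, this is roughly what the strict prompt looks like once formatted with the [Chunk i] context markers that generate_answer builds further down (a standalone sketch with an abridged template and invented chunk texts):

strict_prompt = (
    "You are an enterprise documentation assistant.\n"
    "Use ONLY the CONTEXT chunks below to answer the QUESTION.\n"
    "Cite the chunk number(s) you used, e.g. [Chunk 3].\n"
    "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
)
context = "\n".join([
    "[Chunk 1]: Open the Maintain Communication Users app.",
    "[Chunk 2]: Choose New, then enter a user name and password.",
])
print(strict_prompt.format(context=context, query="How do I create a communication user?"))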
+ # ---------------------------
+ # Retrieval: FAISS -> rerank -> neighbor fill (last resort)
+ # ---------------------------
+ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3, min_similarity: float = 0.55, candidate_multiplier: int = 4):
      """
+     Steps:
+     1. Encode query (E5 style).
+     2. Run FAISS search for top_k * candidate_multiplier candidates.
+     3. Re-embed those candidate texts and compute cosine similarity with the query embedding.
+     4. Sort by similarity and pick top_k where similarity >= min_similarity.
+     5. If fewer than top_k passed the threshold, fill remaining slots by
+        selecting neighboring chunks around the *highest-scoring* chunk(s),
+        but only if absolutely necessary (keeps noise low).
+     Returns: ordered list of chunks (strings).
+     Also prints indices + similarity scores for debugging.
      """
+
      if not index or not chunks:
          return []
 
      try:
+         # 1. encode query
          q_emb = _query_model.encode(
+             [f"query: {query.strip()}"],
              convert_to_numpy=True,
              normalize_embeddings=True
          )[0]
 
+         # 2. FAISS initial retrieval (get a larger candidate pool)
+         num_candidates = max(top_k * candidate_multiplier, top_k + 2)
+         distances, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)
+         candidate_indices = [int(i) for i in indices[0] if i >= 0]
+
+         # protective dedupe (preserve order, unique)
+         candidate_indices = list(dict.fromkeys(candidate_indices))
 
+         # 3. Re-embed candidate texts and compute true cosine similarity
+         candidate_texts = [chunks[i] for i in candidate_indices]
+         # Encode passages (the passage prefix helps alignment)
          doc_embs = _query_model.encode(
+             [f"passage: {c}" for c in candidate_texts],
              convert_to_numpy=True,
              normalize_embeddings=True
          )
          sims = cosine_similarity([q_emb], doc_embs)[0]
+
+         # Pair up indices and sims and sort descending
+         paired = [(candidate_indices[i], float(sims[i])) for i in range(len(candidate_indices))]
+         paired_sorted = sorted(paired, key=lambda x: x[1], reverse=True)
+
+         # Debug print: top candidates and their similarity
+         print("🔎 Candidate ranking (index : sim):")
+         for idx, sim in paired_sorted[: min(len(paired_sorted), top_k * 3)]:
+             print(f" - Chunk {idx} : {sim:.4f}")
+
+         # 4. Pick those meeting the threshold, preserving similarity order
+         selected = [idx for idx, sim in paired_sorted if sim >= min_similarity]
+         selected = selected[:top_k]
+
+         # 5. If not enough, fill by neighbors around the highest-scoring items
+         if len(selected) < top_k:
+             needed = top_k - len(selected)
+             # pick highest scoring indices as anchor(s)
+             anchors = [idx for idx, _ in paired_sorted[:3]]  # top 3 anchors
+             expanded = []
+             for a in anchors:
+                 # neighbors ordered by proximity: a, a-1, a+1, a-2, a+2 ...
+                 if a not in expanded:
+                     expanded.append(a)
+                 offset = 1
+                 while len(expanded) < top_k and offset < 5:
+                     for cand in (a - offset, a + offset):
+                         if 0 <= cand < len(chunks) and cand not in expanded:
+                             expanded.append(cand)
+                         if len(expanded) >= top_k:
+                             break
+                     offset += 1
+                 if len(expanded) >= top_k:
                      break
+             # final selection: keep previously selected first, then add neighbors from expanded, preserving order
+             final_order = []
+             for idx, _sim in paired_sorted:
+                 if idx in selected and idx not in final_order:
+                     final_order.append(idx)
+             for idx in expanded:
+                 if idx not in final_order:
+                     final_order.append(idx)
+             selected = final_order[:top_k]
 
+         # final chunk strings (ordered by the selected list)
+         final_chunks = [chunks[i] for i in selected]
 
+         print(f"✅ retrieve_chunks: returning {len(final_chunks)} chunks (top_k={top_k}, min_sim={min_similarity})")
+         print(f"   chunk indices: {selected}")
+
+         # The indices could also be returned if the UI needs to display chunk numbers
          return final_chunks
 
      except Exception as e:
          print(f"⚠️ Retrieval error: {e}")
          return []
 
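The neighbor fill above visits indices in proximity order around each anchor (a, then a-1, a+1, a-2, a+2, and so on, clamped to valid positions). The same ordering isolated as a hypothetical helper, just to make the traversal easy to verify (not part of the commit):

def proximity_order(anchor: int, n_chunks: int, max_offset: int = 5):
    """Yield the anchor, then its neighbors by increasing distance, within bounds."""
    if 0 <= anchor < n_chunks:
        yield anchor
    for offset in range(1, max_offset):
        for cand in (anchor - offset, anchor + offset):
            if 0 <= cand < n_chunks:
                yield cand

print(list(proximity_order(anchor=7, n_chunks=10)))  # -> [7, 6, 8, 5, 9, 4, 3]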
+
+ # ---------------------------
+ # Answer generation
+ # ---------------------------
  def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
+     """
+     - reasoning_mode=False => strict factual, deterministic
+     - reasoning_mode=True  => allow cautious inference (slower / longer)
+     """
      if not retrieved_chunks:
          return "Sorry, I couldn’t find relevant information in the document."
 
+     # Add chunk headings so model can cite them if needed
+     context_lines = []
+     for i, chunk in enumerate(retrieved_chunks, start=1):
+         # Use [Chunk i] markers — LLM will echo them when asked to cite sources
+         context_lines.append(f"[Chunk {i}]: {chunk.strip()}")
+     context = "\n".join(context_lines)
+
      prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
          context=context, query=query
      )
 
      try:
+         # deterministic in strict mode
          if reasoning_mode:
+             max_new_tokens = 220
+             temp = 0.6
+             do_sample = True
          else:
+             max_new_tokens = 140
+             temp = 0.0
+             do_sample = False
 
          result = _answer_model(
              prompt,
+             max_new_tokens=max_new_tokens,
+             temperature=temp,
+             do_sample=do_sample,
              early_stopping=True,
              pad_token_id=_tokenizer.eos_token_id,
          )
 
+         text = result[0].get("generated_text", "").strip()
+         # remove the prompt echo if present
+         if "Answer:" in text:
+             out = text.split("Answer:")[-1].strip()
+         else:
+             out = text
+
+         # Enforce exact fallback phrase if the model paraphrases a missing answer
+         if not reasoning_mode and ("i don't know" in out.lower() or "not present" in out.lower()):
+             return "I don't know based on the provided document."
+
+         return out
 
      except Exception as e:
          print(f"⚠️ Generation failed: {e}")
          return "⚠️ Error: Could not generate an answer."
 
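The "Answer:" split above strips the prompt echo because transformers' text-generation pipeline returns the prompt plus the completion by default; the pipeline also supports dropping the echo itself via return_full_text=False. A minimal standalone sketch of that variant under strict-mode decoding settings (illustrative, not the committed code):

from transformers import pipeline

gen = pipeline("text-generation", model="microsoft/phi-2")
result = gen(
    "Question: What is a FAISS index?\nAnswer:",
    max_new_tokens=40,
    do_sample=False,         # greedy, like strict mode above
    return_full_text=False,  # the pipeline drops the prompt echo itself
)
print(result[0]["generated_text"].strip())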
+ # ---------------------------
+ # Local debug main
+ # ---------------------------
  if __name__ == "__main__":
      from vectorstore import build_faiss_index
 
      index = build_faiss_index(embeddings)
 
      query = "How do I create a communication user?"
+     retrieved = retrieve_chunks(query, index, dummy_chunks, top_k=3, min_similarity=0.55)
      print("🔍 Retrieved:", retrieved)
+     print("💬 Answer:", generate_answer(query, retrieved, reasoning_mode=False))
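The local test imports build_faiss_index from vectorstore, which this diff does not show. For context, a plausible minimal sketch, assuming an exact inner-product index over the normalized embeddings (hypothetical; the real vectorstore.py may differ):

import faiss
import numpy as np

def build_faiss_index(embeddings: np.ndarray):
    # With L2-normalized vectors, inner product equals cosine similarity,
    # so IndexFlatIP gives exact cosine-ranked neighbors.
    embs = np.ascontiguousarray(embeddings, dtype="float32")
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    return index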