Shubham170793 committed on
Commit
f384f96
·
verified ·
1 Parent(s): 8f19216

Update src/qa.py

Browse files
Files changed (1) hide show
  1. src/qa.py +41 -25
src/qa.py CHANGED
@@ -1,10 +1,11 @@
1
  """
2
- qa.py — FAST Phi-2 with Re-ranking + Similarity Threshold
3
- ---------------------------------------------------------
4
  ✅ Optimized for Hugging Face Spaces & Streamlit
5
- Uses intfloat/e5-small-v2 embeddings
6
- Uses microsoft/phi-2 for generation (fast CPU-optimized)
7
- Includes re-ranking and semantic similarity filtering
 
8
  """
9
 
10
  import os
@@ -14,7 +15,7 @@ from sklearn.metrics.pairwise import cosine_similarity
14
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
15
  import torch
16
 
17
- print("✅ qa.py (Phi-2 Fast + Rerank + Similarity Filter) loaded from:", __file__)
18
 
19
  # ==========================================================
20
  # 1️⃣ Hugging Face Cache Setup
@@ -66,9 +67,9 @@ except Exception as e:
66
  _answer_model = None
67
 
68
  # ==========================================================
69
- # 4️⃣ Prompt Template (Concise & Factual)
70
  # ==========================================================
71
- PROMPT_TEMPLATE = (
72
  "You are an assistant for enterprise documentation.\n"
73
  "Answer the question based ONLY on the context below.\n"
74
  "If the answer is not in the context, reply exactly:\n"
@@ -76,22 +77,30 @@ PROMPT_TEMPLATE = (
76
  "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
77
  )
78
 
 
 
 
 
 
 
 
 
79
  # ==========================================================
80
- # 5️⃣ Retrieve Chunks — FAISS + Rerank + Similarity Filter
81
  # ==========================================================
82
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5, min_similarity: float = 0.6):
83
  """
84
  Retrieves top-K relevant chunks with re-ranking and similarity threshold filtering.
85
  Steps:
86
  1️⃣ Use FAISS to get approximate top candidates.
87
- 2️⃣ Re-rank those by cosine similarity.
88
- 3️⃣ Filter out low-similarity chunks (below min_similarity).
89
  """
90
  if not index or not chunks:
91
  return []
92
 
93
  try:
94
- # --- Encode the query ---
95
  q_emb = _query_model.encode(
96
  [f"query: {query.strip()}"],
97
  convert_to_numpy=True,
@@ -127,25 +136,26 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5, min_similar
127
  return []
128
 
129
  # ==========================================================
130
- # 6️⃣ Answer Generation (Fast & Deterministic)
131
  # ==========================================================
132
- def generate_answer(query: str, retrieved_chunks: list):
133
  """
134
- Generates a concise factual answer using Phi-2.
135
- Retrieval should already be clean & re-ranked.
 
136
  """
137
  if not retrieved_chunks:
138
  return "Sorry, I couldn’t find relevant information in the document."
139
 
140
  context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
141
- prompt = PROMPT_TEMPLATE.format(context=context, query=query)
142
 
143
  try:
144
  result = _answer_model(
145
  prompt,
146
- max_new_tokens=120, # Keep short for speed
147
- temperature=0.2,
148
- do_sample=False, # Deterministic for strict mode
149
  pad_token_id=_tokenizer.eos_token_id,
150
  )
151
  answer = result[0]["generated_text"].strip()
@@ -164,20 +174,26 @@ def generate_answer(query: str, retrieved_chunks: list):
164
  # ==========================================================
165
  if __name__ == "__main__":
166
  from vectorstore import build_faiss_index
 
 
167
  dummy_chunks = [
168
  "Step 1: Open the dashboard and navigate to reports.",
169
  "Step 2: Click 'Export' to download a CSV summary.",
170
  "Step 3: Review the generated report in your downloads folder."
171
  ]
 
172
  embeddings = [
173
  _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
174
  for chunk in dummy_chunks
175
  ]
176
- import faiss
177
- index = faiss.IndexFlatL2(embeddings[0].shape[0])
 
178
  index.add(np.array(embeddings).astype("float32"))
179
 
180
- query = "How do I export a report?"
181
  retrieved = retrieve_chunks(query, index, dummy_chunks, top_k=3, min_similarity=0.6)
182
- print("🔍 Retrieved:", retrieved)
183
- print("💬 Answer:", generate_answer(query, retrieved))
 
 
 
1
  """
2
+ qa.py — Phi-2 Hybrid (Fast + Reasoning) with Rerank & Similarity Filtering
3
+ --------------------------------------------------------------------------
4
  ✅ Optimized for Hugging Face Spaces & Streamlit
5
+ ✅ intfloat/e5-small-v2 for embeddings
6
+ ✅ microsoft/phi-2 for generation (fast CPU-optimized)
7
+ ✅ Re-ranking + minimum similarity threshold for clean retrieval
8
+ ✅ reasoning_mode toggle for deeper answers
9
  """
10
 
11
  import os
 
15
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
16
  import torch
17
 
18
+ print("✅ qa.py (Phi-2 Hybrid + Rerank + Similarity Filter) loaded from:", __file__)
19
 
20
  # ==========================================================
21
  # 1️⃣ Hugging Face Cache Setup
 
67
  _answer_model = None
68
 
69
  # ==========================================================
70
+ # 4️⃣ Prompt Templates
71
  # ==========================================================
72
+ STRICT_PROMPT = (
73
  "You are an assistant for enterprise documentation.\n"
74
  "Answer the question based ONLY on the context below.\n"
75
  "If the answer is not in the context, reply exactly:\n"
 
77
  "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
78
  )
79
 
80
+ REASONING_PROMPT = (
81
+ "You are an expert enterprise assistant.\n"
82
+ "Carefully reason about the following context and provide a detailed, step-by-step answer.\n"
83
+ "If the context does not provide enough information, you may make cautious inferences based on logical reasoning.\n"
84
+ "However, always note when you are inferring beyond the text.\n\n"
85
+ "Context:\n{context}\n\nQuestion: {query}\n\nReasoning and Answer:"
86
+ )
87
+
88
  # ==========================================================
89
+ # 5️⃣ Retrieve Chunks — FAISS + Re-rank + Similarity Filter
90
  # ==========================================================
91
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5, min_similarity: float = 0.6):
92
  """
93
  Retrieves top-K relevant chunks with re-ranking and similarity threshold filtering.
94
  Steps:
95
  1️⃣ Use FAISS to get approximate top candidates.
96
+ 2️⃣ Re-rank them by cosine similarity with the query.
97
+ 3️⃣ Filter out low-similarity chunks below min_similarity.
98
  """
99
  if not index or not chunks:
100
  return []
101
 
102
  try:
103
+ # --- Encode query ---
104
  q_emb = _query_model.encode(
105
  [f"query: {query.strip()}"],
106
  convert_to_numpy=True,
 
136
  return []
137
 
138
  # ==========================================================
139
+ # 6️⃣ Answer Generation (Fast / Reasoning Hybrid)
140
  # ==========================================================
141
+ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
142
  """
143
+ Generates concise or reasoning-rich answers using Phi-2.
144
+ reasoning_mode=True → longer, more explanatory (slower)
145
+ reasoning_mode=False → short factual (fast)
146
  """
147
  if not retrieved_chunks:
148
  return "Sorry, I couldn’t find relevant information in the document."
149
 
150
  context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
151
+ prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)
152
 
153
  try:
154
  result = _answer_model(
155
  prompt,
156
+ max_new_tokens=200 if reasoning_mode else 120,
157
+ temperature=0.6 if reasoning_mode else 0.2,
158
+ do_sample=reasoning_mode,
159
  pad_token_id=_tokenizer.eos_token_id,
160
  )
161
  answer = result[0]["generated_text"].strip()
 
174
  # ==========================================================
175
  if __name__ == "__main__":
176
  from vectorstore import build_faiss_index
177
+ import faiss
178
+
179
  dummy_chunks = [
180
  "Step 1: Open the dashboard and navigate to reports.",
181
  "Step 2: Click 'Export' to download a CSV summary.",
182
  "Step 3: Review the generated report in your downloads folder."
183
  ]
184
+
185
  embeddings = [
186
  _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
187
  for chunk in dummy_chunks
188
  ]
189
+
190
+ dim = embeddings[0].shape[0]
191
+ index = faiss.IndexFlatL2(dim)
192
  index.add(np.array(embeddings).astype("float32"))
193
 
194
+ query = "How to export a report?"
195
  retrieved = retrieve_chunks(query, index, dummy_chunks, top_k=3, min_similarity=0.6)
196
+
197
+ print("\n🔍 Retrieved chunks:", retrieved)
198
+ print("\n💬 FAST Answer:", generate_answer(query, retrieved, reasoning_mode=False))
199
+ print("\n🧠 REASONING Answer:", generate_answer(query, retrieved, reasoning_mode=True))