Shubham170793 committed
Commit d14744d · verified · 1 Parent(s): cd86419

Update src/qa.py

Files changed (1)
  1. src/qa.py +32 -48
src/qa.py CHANGED
@@ -1,11 +1,9 @@
  """
- qa.py — Fast Phi-2 Retrieval + Generation (Final Optimized Version)
+ qa.py — Optimized Phi-2 Retrieval + Generation (Stable Fast Baseline)
  -------------------------------------------------------------------
- Uses:
- intfloat/e5-small-v2 for embeddings
- microsoft/phi-2 (quantized for CPU)
- • Reasoning toggle support (ON/OFF)
- Optimized for: speed + stability on Streamlit / Hugging Face Spaces
+ ✅ Best balance of speed + accuracy
+ Works perfectly on CPU (quantized)
+ Non-hallucinating (document-strict)
  """
 
  import os
@@ -13,9 +11,8 @@ import numpy as np
  import torch
  from sentence_transformers import SentenceTransformer
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
- from sklearn.metrics.pairwise import cosine_similarity
 
- print("✅ qa.py (Final Fast Phi-2) loaded from:", __file__)
+ print("✅ qa.py (FAST BASELINE) loaded from:", __file__)
 
  # ==========================================================
  # 1️⃣ Cache Setup
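The dropped sklearn import was only needed by the old cosine re-ranking, which this commit also removes in the retrieval hunk further down. Because the embeddings are produced with normalize_embeddings=True, the same scores reduce to a plain dot product; a small illustrative sketch with made-up unit vectors:

import numpy as np

# On normalized vectors, cosine similarity is just the inner product, so no sklearn is needed.
q_emb = np.array([0.6, 0.8], dtype="float32")                     # illustrative query embedding
cand_vecs = np.array([[1.0, 0.0], [0.6, 0.8]], dtype="float32")   # illustrative candidate embeddings
sims = cand_vecs @ q_emb                                           # cosine scores: [0.6, 1.0]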
@@ -28,10 +25,9 @@ os.environ.update({
      "HF_DATASETS_CACHE": CACHE_DIR,
      "HF_MODULES_CACHE": CACHE_DIR
  })
- print(f"✅ Using cache dir: {CACHE_DIR}")
 
  # ==========================================================
- # 2️⃣ Embedding Model (fast + reliable)
+ # 2️⃣ Embedding Model
  # ==========================================================
  try:
      _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
@@ -41,17 +37,18 @@ except Exception as e:
      _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
 
  # ==========================================================
- # 3️⃣ Phi-2 LLM Setup (Quantized + CPU Optimized)
+ # 3️⃣ Phi-2 LLM Setup (Quantized for CPU)
  # ==========================================================
  try:
      MODEL_NAME = "microsoft/phi-2"
      print(f"✅ Loading LLM: {MODEL_NAME} (quantized, CPU-optimized)")
 
      _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+
      _model = AutoModelForCausalLM.from_pretrained(
          MODEL_NAME,
          cache_dir=CACHE_DIR,
-         torch_dtype=torch.bfloat16 if not torch.cuda.is_available() else torch.float16,
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16,
          low_cpu_mem_usage=True,
      ).to("cpu")
 
@@ -63,34 +60,28 @@ try:
          model_kwargs={"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True},
      )
 
-     print("✅ Phi-2 pipeline ready (optimized).")
+     print("✅ Phi-2 text-generation pipeline ready (optimized).")
+
  except Exception as e:
      print(f"⚠️ Phi-2 load failed: {e}")
      _answer_model = None
 
  # ==========================================================
- # 4️⃣ Prompt Templates
+ # 4️⃣ Prompt Template
  # ==========================================================
- STRICT_PROMPT = (
-     "You are an expert enterprise assistant.\n"
+ PROMPT_TEMPLATE = (
+     "You are an expert assistant for enterprise document understanding.\n"
      "Use ONLY the context below to answer the question clearly and factually.\n"
-     "If the answer isn’t found in the context, reply exactly:\n"
+     "If the context doesn’t contain the answer, reply exactly:\n"
      "'I don't know based on the provided document.'\n\n"
      "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
  )
 
- REASONING_PROMPT = (
-     "You are a reasoning-enabled enterprise assistant.\n"
-     "Use the CONTEXT below and your own reasoning ability to explain the answer clearly and logically.\n"
-     "If the answer isn’t explicit, infer based on context and domain understanding.\n\n"
-     "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
- )
-
  # ==========================================================
- # 5️⃣ Retrieve Top-K Chunks (Balanced speed)
+ # 5️⃣ Retrieve Top-K Chunks (Simple + Fast)
  # ==========================================================
- def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
-     """Retrieve top-K relevant chunks and re-rank by cosine similarity for better precision."""
+ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
+     """Efficient FAISS retrieval using cosine similarity."""
      if not index or not chunks:
          return []
 
@@ -98,46 +89,40 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
          q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
          distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
 
-         # Compute similarity scores for re-ranking
-         candidates = [chunks[i] for i in indices[0]]
-         cand_vecs = _query_model.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
-         sims = cosine_similarity([q_emb], cand_vecs)[0]
+         selected = set()
+         for idx in indices[0]:
+             for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
+                 selected.add(i)
 
-         # Return top-K most semantically aligned
-         top_indices = np.argsort(sims)[::-1][:top_k]
-         return [candidates[i] for i in top_indices]
+         ordered_chunks = [chunks[i] for i in sorted(selected)]
+         return ordered_chunks
 
      except Exception as e:
          print(f"⚠️ Retrieval error: {e}")
          return []
 
-
  # ==========================================================
- # 6️⃣ Generate Answer (Reasoning or Strict Mode)
+ # 6️⃣ Generate Answer (Fast)
  # ==========================================================
- def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = True):
-     """Generate concise, context-grounded answers using Phi-2."""
+ def generate_answer(query: str, retrieved_chunks: list):
+     """Generate concise, grounded answers using Phi-2."""
      if not retrieved_chunks:
          return "Sorry, I couldn’t find relevant information in the document."
 
-     # Keep short context for faster inference
-     context = "\n".join(chunk.strip() for chunk in retrieved_chunks[:5])
-     prompt_template = REASONING_PROMPT if reasoning_mode else STRICT_PROMPT
-     prompt = prompt_template.format(context=context, query=query)
+     context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
+     prompt = PROMPT_TEMPLATE.format(context=context, query=query)
 
      try:
          result = _answer_model(
              prompt,
-             max_new_tokens=140,  # fast but coherent answers
+             max_new_tokens=120,  # lower for faster completion
              do_sample=False,
              early_stopping=True,
              pad_token_id=_tokenizer.eos_token_id,
          )
-
          answer = result[0]["generated_text"].strip()
          if "Answer:" in answer:
              answer = answer.split("Answer:")[-1].strip()
-
          return answer
 
      except Exception as e:
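Phi-2 is a base model, and by default the text-generation pipeline returns the prompt together with the completion, which is why generate_answer splits on the last "Answer:". An equivalent option (not what this commit does) is to ask the pipeline to drop the echoed prompt; a sketch using the objects defined in qa.py:

# Alternative to splitting on "Answer:": let the pipeline return only the newly generated text.
result = _answer_model(
    prompt,
    max_new_tokens=120,
    do_sample=False,
    return_full_text=False,                 # supported by transformers text-generation pipelines
    pad_token_id=_tokenizer.eos_token_id,
)
answer = result[0]["generated_text"].strip()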
@@ -145,11 +130,10 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = True):
          return "⚠️ Error: Could not generate an answer at the moment."
 
  # ==========================================================
- # 7️⃣ Local Test (Optional)
+ # 7️⃣ Local Test
  # ==========================================================
  if __name__ == "__main__":
      from vectorstore import build_faiss_index
-
      dummy_chunks = [
          "Step 1: Open the dashboard and navigate to reports.",
          "Step 2: Click 'Export' to download a CSV summary.",
@@ -164,4 +148,4 @@ if __name__ == "__main__":
164
  query = "What are the steps to export a report?"
165
  retrieved = retrieve_chunks(query, index, dummy_chunks)
166
  print("🔍 Retrieved:", retrieved)
167
- print("💬 Answer:", generate_answer(query, retrieved, reasoning_mode=True))
 