Shubham170793 committed on
Commit 43cd83d · verified · 1 Parent(s): 885d81f

Update src/qa.py

Files changed (1):
  1. src/qa.py +82 -84
src/qa.py CHANGED
@@ -1,25 +1,22 @@
  """
- qa.py — Phi-2 Hybrid Mode (Reasoning + Strict)
- -------------------------------------
- Handles:
-   Query embedding (SentenceTransformer / E5-small-v2)
-   Chunk retrieval (FAISS)
-   Answer generation (Phi-2, with toggleable reasoning)
- Optimized for Hugging Face Spaces & Streamlit.
  """

  import os
  import numpy as np
- import torch
  from sentence_transformers import SentenceTransformer
- from sklearn.metrics.pairwise import cosine_similarity
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
- from vectorstore import search_faiss

- print("✅ qa.py (Phi-2 Hybrid Mode) loaded from:", __file__)

  # ==========================================================
- # 1️⃣ Hugging Face Cache Setup
  # ==========================================================
  CACHE_DIR = "/tmp/hf_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)
@@ -32,22 +29,17 @@ os.environ.update({
  print(f"✅ Using Hugging Face cache at {CACHE_DIR}")

  # ==========================================================
- # 2️⃣ Speed Tweaks for CPU
- # ==========================================================
- torch.set_num_threads(2)  # Limit CPU threads for faster execution
-
- # ==========================================================
- # 3️⃣ Query Embedding Model
  # ==========================================================
  try:
      _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
      print("✅ Loaded embedding model: intfloat/e5-small-v2")
  except Exception as e:
-     print(f"⚠️ Query model load failed ({e}), using fallback MiniLM.")
      _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)

  # ==========================================================
- # 4️⃣ LLM Setup — Phi-2 (Optimized)
  # ==========================================================
  MODEL_NAME = "microsoft/phi-2"
  print(f"✅ Loading LLM: {MODEL_NAME}")
@@ -57,103 +49,109 @@ _model = AutoModelForCausalLM.from_pretrained(
      MODEL_NAME,
      cache_dir=CACHE_DIR,
      torch_dtype="auto",
-     device_map="auto"
  )
- _answer_model = pipeline("text-generation", model=_model, tokenizer=_tokenizer, device_map="auto")
  print("✅ Phi-2 generation pipeline ready.")

  # ==========================================================
- # 5️⃣ Prompt Templates (Two Modes)
  # ==========================================================
- STRICT_PROMPT = (
-     "You are a factual assistant. Use ONLY the CONTEXT below to answer. "
-     "If the answer is not explicitly in the context, say exactly: "
-     "'I don't know based on the provided document.'\n\n"
-     "CONTEXT:\n{context}\n\nQUESTION: {query}\nANSWER:"
- )

- REASONING_PROMPT = (
-     "You are an intelligent assistant. Use the CONTEXT below and your general knowledge "
-     "to provide the most complete and helpful answer. If unsure, say 'I don't know.'\n\n"
-     "CONTEXT:\n{context}\n\nQUESTION: {query}\nANSWER:"
- )

  # ==========================================================
- # 6️⃣ Chunk Retrieval Function
  # ==========================================================
- def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
-     """Retrieve top-K relevant chunks quickly using FAISS."""
      if not index or not chunks:
          return []
-
-     try:
-         query_emb = _query_model.encode(
-             [f"query: {query.strip()}"],
-             convert_to_numpy=True,
-             normalize_embeddings=True
-         )[0]
-         distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k)
-         return [chunks[i] for i in indices[0]]
-     except Exception as e:
-         print(f"⚠️ Retrieval error: {e}")
-         return []

  # ==========================================================
- # 7️⃣ Answer Generation Function (with Mode Toggle)
  # ==========================================================
  def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = True):
-     """
-     Generates answers using Phi-2.
-     reasoning_mode=True → reasoning + external knowledge
-     reasoning_mode=False → strict chunk-only factual mode
-     """
      if not retrieved_chunks:
          return "Sorry, I couldn’t find relevant information in the document."

-     # Merge retrieved context
      context = "\n".join([chunk.strip() for chunk in retrieved_chunks])
-
-     # Select prompt based on mode
-     prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
-         context=context, query=query
-     )

      try:
-         # ⚡ Speed-optimized generation
          result = _answer_model(
              prompt,
-             max_new_tokens=140 if reasoning_mode else 100,  # shorter output = faster
-             temperature=0.3 if reasoning_mode else 0.1,  # balanced creativity
-             do_sample=False,  # ✅ greedy decoding = fastest
-             repetition_penalty=1.1,  # avoids repetitive phrasing
          )
-
-         # Cleanly extract the answer
-         answer = result[0]["generated_text"].split("ANSWER:")[-1].strip()
-
-         # Safety: truncate overly long rambles
-         if len(answer.split()) > 150:
-             answer = " ".join(answer.split()[:150]) + "..."
-
-         return answer
-
      except Exception as e:
          print(f"⚠️ Generation failed: {e}")
          return "⚠️ Error: Could not generate an answer."

  # ==========================================================
- # 8️⃣ Local Test (Optional)
  # ==========================================================
  if __name__ == "__main__":
      dummy_chunks = [
          "Step 1: Open the dashboard and navigate to reports.",
-         "Step 2: Click 'Export' to download a CSV summary."
      ]
-     from vectorstore import build_faiss_index
      index = build_faiss_index([
-         _query_model.encode([f"passage: {c}"], convert_to_numpy=True, normalize_embeddings=True)[0]
-         for c in dummy_chunks
      ])
-     query = "How to export a report?"
-     print("💬 Strict:", generate_answer(query, dummy_chunks, reasoning_mode=False))
-     print("💬 Reasoning:", generate_answer(query, dummy_chunks, reasoning_mode=True))

  """
+ qa.py — Fast, Reasoning-Enabled Phi-2 Version
+ ----------------------------------------------
+ • Uses SentenceTransformer (E5-small) for embeddings
+ • Uses microsoft/phi-2 for generation
+ • Retains reasoning vs. strict factual modes
+ • Optimized for speed and low memory use on CPU
  """

  import os
  import numpy as np
  from sentence_transformers import SentenceTransformer
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ from sklearn.metrics.pairwise import cosine_similarity

+ print("✅ qa.py (Phi-2 optimized) loaded from:", __file__)

  # ==========================================================
+ # Hugging Face Cache Setup
  # ==========================================================
  CACHE_DIR = "/tmp/hf_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)
 
  print(f"✅ Using Hugging Face cache at {CACHE_DIR}")

  # ==========================================================
+ # Query Embedding Model
  # ==========================================================
  try:
      _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
      print("✅ Loaded embedding model: intfloat/e5-small-v2")
  except Exception as e:
+     print(f"⚠️ Fallback to MiniLM due to {e}")
      _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
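Note: E5-family models are trained with asymmetric prefixes, which is why this file embeds queries as "query: …" and documents as "passage: …" (see retrieve_chunks and the __main__ block). A minimal illustrative sketch of that convention, not part of the commit:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("intfloat/e5-small-v2")
    # Both vectors are unit-normalized, so a dot product equals cosine similarity.
    q = model.encode(["query: how do I export a report?"], normalize_embeddings=True)
    p = model.encode(["passage: Click 'Export' to download a CSV summary."], normalize_embeddings=True)
    print((q @ p.T).item())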
 
  # ==========================================================
+ # Phi-2 Model (Causal LM)
  # ==========================================================
  MODEL_NAME = "microsoft/phi-2"
  print(f"✅ Loading LLM: {MODEL_NAME}")

  _model = AutoModelForCausalLM.from_pretrained(
      MODEL_NAME,
      cache_dir=CACHE_DIR,
      torch_dtype="auto",
+     low_cpu_mem_usage=True
+ )
+ _answer_model = pipeline(
+     "text-generation",
+     model=_model,
+     tokenizer=_tokenizer,
+     device=-1  # -1 runs the pipeline on CPU
+ )

  print("✅ Phi-2 generation pipeline ready.")

  # ==========================================================
+ # Prompt Templates
  # ==========================================================
+ REASONING_PROMPT = """
+ You are an intelligent enterprise assistant.
+ Use the CONTEXT below and your general understanding to answer the QUESTION logically and clearly.
+ Explain your reasoning briefly if helpful.
+
+ ---
+ CONTEXT:
+ {context}
+ ---
+ QUESTION:
+ {query}
+ ---
+ ANSWER:
+ """
+
+ STRICT_PROMPT = """
+ You are an enterprise document assistant.
+ Use ONLY the CONTEXT below to answer the QUESTION clearly and factually.
+ If the answer is not found in the context, reply exactly:
+ "I don't know based on the provided document."
+
+ ---
+ CONTEXT:
+ {context}
+ ---
+ QUESTION:
+ {query}
+ ---
+ ANSWER:
+ """
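Both templates end with "ANSWER:", which is what generate_answer() below relies on: a text-generation pipeline returns the prompt plus the completion, so splitting on the last "ANSWER:" isolates the model's new text. An illustrative sketch, with a stand-in string where the model output would be:

    context = "Step 2: Click 'Export' to download a CSV summary."
    prompt = STRICT_PROMPT.format(context=context, query="How do I export a report?")
    generated = prompt + " Click 'Export' on the reports page."  # stand-in for pipeline output
    print(generated.split("ANSWER:")[-1].strip())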
 
  # ==========================================================
+ # Retrieve Chunks
  # ==========================================================
+ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
+     """Retrieve top-K most relevant chunks quickly (no re-ranking for speed)."""
      if not index or not chunks:
          return []
+     query_emb = _query_model.encode(
+         [f"query: {query.strip()}"],
+         convert_to_numpy=True,
+         normalize_embeddings=True
+     )[0]
+     distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k)
+     return [chunks[i] for i in indices[0]]
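retrieve_chunks() expects a FAISS index built over the same normalized, "passage: "-prefixed embeddings. vectorstore.py is not part of this commit, so purely as an assumption, a compatible build_faiss_index could look like this sketch using an inner-product index (equivalent to cosine similarity on unit vectors):

    import faiss
    import numpy as np

    def build_faiss_index(embeddings):
        # Hypothetical helper: index unit-normalized vectors by inner product.
        mat = np.asarray(embeddings, dtype="float32")
        index = faiss.IndexFlatIP(mat.shape[1])
        index.add(mat)
        return index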
 
  # ==========================================================
+ # Generate Answer (Phi-2)
  # ==========================================================
  def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = True):
+     """Generate answers using Phi-2. Supports reasoning or strict factual modes."""
      if not retrieved_chunks:
          return "Sorry, I couldn’t find relevant information in the document."

      context = "\n".join([chunk.strip() for chunk in retrieved_chunks])
+     prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)

      try:
          result = _answer_model(
              prompt,
+             max_new_tokens=180,  # keeps output short & fast
+             temperature=0.4 if reasoning_mode else 0.2,  # note: has no effect while do_sample=False
+             do_sample=False,  # greedy decoding, deterministic
+             num_beams=1,  # no beam search, for speed
+             early_stopping=True,
          )
+         text = result[0]["generated_text"].split("ANSWER:")[-1].strip()
+         return text
      except Exception as e:
          print(f"⚠️ Generation failed: {e}")
          return "⚠️ Error: Could not generate an answer."

  # ==========================================================
+ # Local Test (optional)
  # ==========================================================
  if __name__ == "__main__":
+     from vectorstore import build_faiss_index
+
      dummy_chunks = [
          "Step 1: Open the dashboard and navigate to reports.",
+         "Step 2: Click 'Export' to download a CSV summary.",
+         "Step 3: Review the generated report in your downloads folder."
      ]
+
      index = build_faiss_index([
+         _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+         for chunk in dummy_chunks
      ])
+
+     query = "What are the steps to export a report?"
+     retrieved = retrieve_chunks(query, index, dummy_chunks)
+     print("🔍 Retrieved:", retrieved)
+     print("💬 Answer:", generate_answer(query, retrieved))