Shubham170793 committed
Commit 6718956 · verified · 1 Parent(s): 74cc3b2

Update src/qa.py

Files changed (1):
  1. src/qa.py +76 -91
src/qa.py CHANGED
@@ -1,23 +1,25 @@
 """
-qa.py — Optimized Phi-2 Retrieval + Generation
-----------------------------------------------
-Uses:
-    intfloat/e5-small-v2 for embeddings
-    microsoft/phi-2 for reasoning-rich generation (fast on CPU)
-Optimized for: speed + stability in Streamlit / Hugging Face Spaces
+qa.py — Phi-2 Hybrid Mode (Reasoning + Strict)
+-------------------------------------
+Handles:
+    Query embedding (SentenceTransformer / E5-small-v2)
+    Chunk retrieval (FAISS)
+    Answer generation (Phi-2, with toggleable reasoning)
+Optimized for Hugging Face Spaces & Streamlit.
 """
 
 import os
 import numpy as np
+import torch
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-import torch
+from vectorstore import search_faiss
 
-print("✅ qa.py (Phi-2 optimized fast) loaded from:", __file__)
+print("✅ qa.py (Phi-2 Hybrid Mode) loaded from:", __file__)
 
 # ==========================================================
-# 1️⃣ Cache Setup
+# 1️⃣ Hugging Face Cache Setup
 # ==========================================================
 CACHE_DIR = "/tmp/hf_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
@@ -27,132 +29,115 @@ os.environ.update({
     "HF_DATASETS_CACHE": CACHE_DIR,
     "HF_MODULES_CACHE": CACHE_DIR
 })
+print(f"✅ Using Hugging Face cache at {CACHE_DIR}")
+
+# ==========================================================
+# 2️⃣ Speed Tweaks for CPU
+# ==========================================================
+torch.set_num_threads(2)  # Limit CPU threads for faster execution
 
 # ==========================================================
-# 2️⃣ Embedding Model
+# 3️⃣ Query Embedding Model
 # ==========================================================
 try:
     _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
     print("✅ Loaded embedding model: intfloat/e5-small-v2")
 except Exception as e:
-    print(f"⚠️ Fallback to MiniLM due to {e}")
+    print(f"⚠️ Query model load failed ({e}), using fallback MiniLM.")
     _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
 
 # ==========================================================
-# 3️⃣ Phi-2 LLM Setup (Quantized for CPU)
+# 4️⃣ LLM Setup Phi-2 (Optimized)
 # ==========================================================
-try:
-    MODEL_NAME = "microsoft/phi-2"
-    print(f"✅ Loading LLM: {MODEL_NAME} (quantized, CPU-optimized)")
-
-    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-
-    # ✅ Load model in mixed precision for 4–6× faster inference
-    _model = AutoModelForCausalLM.from_pretrained(
-        MODEL_NAME,
-        cache_dir=CACHE_DIR,
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16,
-        low_cpu_mem_usage=True,
-    ).to("cpu")
-
-    # ✅ Create generation pipeline (keep in memory)
-    _answer_model = pipeline(
-        "text-generation",
-        model=_model,
-        tokenizer=_tokenizer,
-        device=-1,
-        model_kwargs={"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True},
-    )
-
-    print("✅ Phi-2 text-generation pipeline ready (optimized).")
-
-except Exception as e:
-    print(f"⚠️ Phi-2 load failed: {e}")
-    _answer_model = None
+MODEL_NAME = "microsoft/phi-2"
+print(f"✅ Loading LLM: {MODEL_NAME}")
+
+_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+_model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    cache_dir=CACHE_DIR,
+    torch_dtype="auto",
+    device_map="auto"
+)
+_answer_model = pipeline("text-generation", model=_model, tokenizer=_tokenizer, device_map="auto")
+print("✅ Phi-2 generation pipeline ready.")
 
 # ==========================================================
-# 4️⃣ Prompt Template (Balanced Mode — quality + speed)
+# 5️⃣ Prompt Templates (Two Modes)
 # ==========================================================
-PROMPT_TEMPLATE = (
-    "You are a helpful enterprise document assistant. "
-    "Use ONLY the following context to answer the question clearly and factually. "
-    "If the information is missing, say exactly: 'I don't know based on the provided document.'\n\n"
-    "Keep your answer concise (2–5 sentences) but ensure it covers all relevant details.\n\n"
-    "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
+STRICT_PROMPT = (
+    "You are a factual assistant. Use ONLY the CONTEXT below to answer. "
+    "If the answer is not explicitly in the context, say exactly: "
+    "'I don't know based on the provided document.'\n\n"
+    "CONTEXT:\n{context}\n\nQUESTION: {query}\nANSWER:"
+)
+
+REASONING_PROMPT = (
+    "You are an intelligent assistant. Use the CONTEXT below and your general knowledge "
+    "to provide the most complete and helpful answer. If unsure, say 'I don't know.'\n\n"
+    "CONTEXT:\n{context}\n\nQUESTION: {query}\nANSWER:"
 )
 
 # ==========================================================
-# 5️⃣ Retrieve Top-K Chunks
+# 6️⃣ Chunk Retrieval Function
 # ==========================================================
 def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
-    """Efficient FAISS retrieval using cosine similarity."""
+    """Retrieve top-K relevant chunks quickly using FAISS."""
     if not index or not chunks:
         return []
 
     try:
-        q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
-        distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
-
-        selected = set()
-        for idx in indices[0]:
-            for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
-                selected.add(i)
-
-        ordered_chunks = [chunks[i] for i in sorted(selected)]
-        return ordered_chunks
+        query_emb = _query_model.encode(
+            [f"query: {query.strip()}"],
+            convert_to_numpy=True,
+            normalize_embeddings=True
+        )[0]
+        distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k)
+        return [chunks[i] for i in indices[0]]
     except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []
 
 # ==========================================================
-# 6️⃣ Answer Generation (fast)
+# 7️⃣ Answer Generation Function (with Mode Toggle)
 # ==========================================================
-def generate_answer(query: str, retrieved_chunks: list):
-    """Generate concise, grounded answers using Phi-2."""
+def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = True):
+    """
+    Generates answers using Phi-2.
+    reasoning_mode=True  → reasoning + external knowledge
+    reasoning_mode=False → strict chunk-only factual mode
+    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."
 
-    context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
-    prompt = PROMPT_TEMPLATE.format(context=context, query=query)
+    context = "\n".join([chunk.strip() for chunk in retrieved_chunks])
+    prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)
 
    try:
-        # ✅ Limit tokens to speed up inference
        result = _answer_model(
            prompt,
-            max_new_tokens=120,  # reduced for faster completion
+            max_new_tokens=180,
+            temperature=0.4 if reasoning_mode else 0.2,
            do_sample=False,
-            early_stopping=True,
-            pad_token_id=_tokenizer.eos_token_id,
        )
-        answer = result[0]["generated_text"].strip()
-
-        # Clean excessive prompt echo
-        if "Answer:" in answer:
-            answer = answer.split("Answer:")[-1].strip()
-
-        return answer
-
+        return result[0]["generated_text"].split("ANSWER:")[-1].strip()
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
-        return "⚠️ Error: Could not generate an answer at the moment."
+        return "⚠️ Error: Could not generate an answer."
 
 # ==========================================================
-# 7️⃣ Local Test
+# 8️⃣ Local Test (Optional)
 # ==========================================================
 if __name__ == "__main__":
-    from vectorstore import build_faiss_index
    dummy_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
-        "Step 2: Click 'Export' to download a CSV summary.",
-        "Step 3: Review the generated report in your downloads folder."
+        "Step 2: Click 'Export' to download a CSV summary."
    ]
-    embeddings = [
-        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
-        for chunk in dummy_chunks
-    ]
-    index = build_faiss_index(embeddings)
-
-    query = "What are the steps to export a report?"
-    retrieved = retrieve_chunks(query, index, dummy_chunks)
-    print("🔍 Retrieved:", retrieved)
-    print("💬 Answer:", generate_answer(query, retrieved))
+    from vectorstore import build_faiss_index
+    index = build_faiss_index([
+        _query_model.encode([f"passage: {c}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+        for c in dummy_chunks
+    ])
+    query = "How to export a report?"
+    print("💬 Strict:", generate_answer(query, dummy_chunks, reasoning_mode=False))
+    print("💬 Reasoning:", generate_answer(query, dummy_chunks, reasoning_mode=True))