Shubham170793 committed
Commit 41ac7b0 · verified · 1 Parent(s): 66bfc48

Update src/qa.py

Files changed (1)
  1. src/qa.py +124 -98
src/qa.py CHANGED
@@ -1,22 +1,23 @@
  """
- qa.py — Phi-2 Fast + Smart Reasoning Mode (CPU-only Stable)
- -----------------------------------------------------------
- Uses intfloat/e5-small-v2 for embeddings
- Uses microsoft/phi-2 (CPU-only, no GPU / quantization)
- Reasoning Mode toggle integrated cleanly
- Retrieval and chunking unchanged
+ qa.py — Phi-2 FAST + RERANKED RETRIEVAL
+ --------------------------------------
+ Uses:
+     intfloat/e5-small-v2 — embeddings
+     microsoft/phi-2 — generation
+ Optimized for: speed, factual accuracy, and semantic retrieval on Hugging Face Spaces
  """

  import os
  import numpy as np
- import torch
  from sentence_transformers import SentenceTransformer
+ from sklearn.metrics.pairwise import cosine_similarity
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ import torch

- print("✅ qa.py (Phi-2 CPU) loaded from:", __file__)
+ print("✅ qa.py (Phi-2 FAST + ReRank) loaded from:", __file__)

  # ==========================================================
- # 1️⃣ Cache Setup
+ # 1️⃣ Cache Setup (Hugging Face /tmp cache)
  # ==========================================================
  CACHE_DIR = "/tmp/hf_cache"
  os.makedirs(CACHE_DIR, exist_ok=True)
@@ -26,6 +27,7 @@ os.environ.update({
      "HF_DATASETS_CACHE": CACHE_DIR,
      "HF_MODULES_CACHE": CACHE_DIR
  })
+ print(f"✅ Using Hugging Face cache at {CACHE_DIR}")

  # ==========================================================
  # 2️⃣ Embedding Model
@@ -34,119 +36,145 @@ try:
      _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
      print("✅ Loaded embedding model: intfloat/e5-small-v2")
  except Exception as e:
-     print(f"⚠️ Fallback to MiniLM due to {e}")
+     print(f"⚠️ Embedding load failed ({e}), falling back to MiniLM")
      _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)

  # ==========================================================
- # 3️⃣ Phi-2 Model (CPU-only, no quantization)
+ # 3️⃣ Phi-2 LLM Setup
  # ==========================================================
- try:
-     MODEL_NAME = "microsoft/phi-2"
-     print(f"✅ Loading LLM: {MODEL_NAME} (CPU mode)")
-
-     _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-
-     _model = AutoModelForCausalLM.from_pretrained(
-         MODEL_NAME,
-         cache_dir=CACHE_DIR,
-         torch_dtype=torch.float32, # safest for CPU
-         low_cpu_mem_usage=True,
-     ).to("cpu")
-
-     _answer_model = pipeline(
-         "text-generation",
-         model=_model,
-         tokenizer=_tokenizer,
-         device=-1, # Force CPU
-         model_kwargs={"low_cpu_mem_usage": True},
-     )
-
-     print("✅ Phi-2 text-generation pipeline ready (CPU).")
-
- except Exception as e:
-     print(f"⚠️ Phi-2 load failed: {e}")
-     _answer_model = None
+ MODEL_NAME = "microsoft/phi-2"
+ print(f"✅ Loading LLM: {MODEL_NAME}")
+
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+ _model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     cache_dir=CACHE_DIR,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16,
+     low_cpu_mem_usage=True,
+ ).to("cpu")
+
+ _answer_model = pipeline(
+     "text-generation",
+     model=_model,
+     tokenizer=_tokenizer,
+     device=-1,
+     model_kwargs={"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True},
+ )
+ print("✅ Phi-2 text-generation pipeline ready (optimized).")

  # ==========================================================
- # 4️⃣ Prompt Templates
+ # 4️⃣ Prompt Template
  # ==========================================================
  STRICT_PROMPT = (
-     "Answer based ONLY on the context below.\n"
-     "If the answer isn’t in the context, say: 'I don't know based on the provided document.'\n\n"
+     "You are an enterprise documentation assistant.\n"
+     "Answer factually using ONLY the context below.\n"
+     "If the answer isn’t present, reply exactly:\n"
+     "'I don't know based on the provided document.'\n\n"
      "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
  )

  REASONING_PROMPT = (
-     "You are an expert assistant. Use the context and your reasoning ability to form a clear, step-by-step answer.\n"
-     "Be concise yet complete. If the context doesn’t contain the answer, say: 'I don't know based on the provided document.'\n\n"
+     "You are an expert enterprise assistant with reasoning ability.\n"
+     "Think carefully about the context and question.\n"
+     "Use world knowledge and inference if necessary, but prefer factual accuracy.\n"
+     "If the document lacks the answer, say:\n"
+     "'I don't know based on the provided document.'\n\n"
      "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
  )

  # ==========================================================
- # 5️⃣ Retrieval (unchanged)
+ # 5️⃣ Retrieve Chunks (FAISS + Rerank + Neighbor Expansion)
  # ==========================================================
- def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
+ def retrieve_chunks(
+     query: str,
+     index,
+     chunks: list,
+     top_k: int = 3,
+     topn_candidates: int = 20,
+     neighbor_threshold: float = 0.68,
+     expansion_window: int = 1,
+     max_context_chunks: int = 6,
+ ):
+     """Retrieve semantically relevant chunks with reranking and neighbor expansion."""
      if not index or not chunks:
          return []
-     try:
-         q_emb = _query_model.encode(
-             [f"query: {query.strip()}"],
-             convert_to_numpy=True,
-             normalize_embeddings=True
-         )[0]
-         distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
-         selected = set()
-         for idx in indices[0]:
-             for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
-                 selected.add(i)
-         return [chunks[i] for i in sorted(selected)]
-     except Exception as e:
-         print(f"⚠️ Retrieval error: {e}")
-         return []
+
+     # 1️⃣ Encode query (normalized)
+     query_emb = _query_model.encode(
+         [f"query: {query.strip()}"],
+         convert_to_numpy=True,
+         normalize_embeddings=True
+     )[0].astype("float32")
+
+     # 2️⃣ FAISS search (initial candidates)
+     topn_candidates = min(topn_candidates, getattr(index, "ntotal", topn_candidates))
+     _, candidate_ids = index.search(np.array([query_emb]).astype("float32"), topn_candidates)
+     candidate_ids = [int(i) for i in candidate_ids[0] if i != -1]
+
+     # 3️⃣ Re-encode candidate chunks and compute cosine similarities
+     candidate_texts = [chunks[i] for i in candidate_ids]
+     candidate_vecs = np.array([
+         _query_model.encode([t], convert_to_numpy=True, normalize_embeddings=True)[0]
+         for t in candidate_texts
+     ])
+     sims = cosine_similarity([query_emb], candidate_vecs)[0]
+     sorted_idx = np.argsort(sims)[::-1]
+     reranked_ids = [candidate_ids[i] for i in sorted_idx]
+
+     # 4️⃣ Select top-k base chunks
+     selected, selected_set = [], set()
+     for rid in reranked_ids:
+         if len(selected) >= top_k:
+             break
+         selected.append(rid)
+         selected_set.add(rid)
+
+     # 5️⃣ Conditional neighbor expansion
+     final_order = list(selected)
+     for base_id in selected:
+         if len(final_order) >= max_context_chunks:
+             break
+         for offset in range(1, expansion_window + 1):
+             for neighbor in (base_id - offset, base_id + offset):
+                 if neighbor < 0 or neighbor >= len(chunks) or neighbor in selected_set:
+                     continue
+                 # Check semantic closeness
+                 neighbor_vec = _query_model.encode([chunks[neighbor]], convert_to_numpy=True, normalize_embeddings=True)[0]
+                 sim = float(cosine_similarity([query_emb], [neighbor_vec])[0][0])
+                 if sim >= neighbor_threshold:
+                     final_order.append(neighbor)
+                     selected_set.add(neighbor)
+                 if len(final_order) >= max_context_chunks:
+                     break
+             if len(final_order) >= max_context_chunks:
+                 break
+
+     return [chunks[i] for i in final_order]

  # ==========================================================
- # 6️⃣ Answer Generation (CPU Stable)
+ # 6️⃣ Answer Generation
  # ==========================================================
  def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
+     """Generate concise, factual or reasoning-based answers using Phi-2."""
      if not retrieved_chunks:
          return "Sorry, I couldn’t find relevant information in the document."

      context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
-     context = context[:2500] # keep context short to avoid overflow & massive slowdowns
-     print(f"🧩 Context length (chars): {len(context)}, chunks used: {len(retrieved_chunks)}")
-
-
-     reasoning_prompt = (
-         "You are an expert assistant for enterprise document understanding.\n"
-         "Use the context below and your reasoning ability to form a complete, explanatory answer.\n"
-         "If the context doesn’t contain the answer, you can logically infer based on general knowledge, "
-         "but mention that explicitly.\n\n"
-         "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
+     prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
+         context=context, query=query
      )

-     strict_prompt = (
-         "You are an assistant that must answer only using the information in the provided context.\n"
-         "If the context does not contain relevant information, respond exactly:\n"
-         "'I don't know based on the provided document.'\n\n"
-         "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
-     )
-
-     prompt = (reasoning_prompt if reasoning_mode else strict_prompt).format(context=context, query=query)
-
      try:
          result = _answer_model(
              prompt,
-             max_new_tokens=120 if not reasoning_mode else 180,
-             temperature=0.2 if not reasoning_mode else 0.4,
-             do_sample=False,
+             max_new_tokens=180 if reasoning_mode else 120,
+             temperature=0.6 if reasoning_mode else 0.3,
+             do_sample=reasoning_mode,
+             early_stopping=True,
              pad_token_id=_tokenizer.eos_token_id,
          )
-
-         raw = result[0]["generated_text"]
-         if "Answer:" in raw:
-             raw = raw.split("Answer:")[-1].strip()
-         return raw.strip()
-
+         text = result[0]["generated_text"].strip()
+         return text.split("Answer:")[-1].strip() if "Answer:" in text else text
      except Exception as e:
          print(f"⚠️ Generation failed: {e}")
          return "⚠️ Error: Could not generate an answer."
@@ -156,22 +184,20 @@ def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = F
  # ==========================================================
  if __name__ == "__main__":
      from vectorstore import build_faiss_index
+
      dummy_chunks = [
          "Step 1: Open the dashboard and navigate to reports.",
          "Step 2: Click 'Export' to download a CSV summary.",
-         "Step 3: Review the generated report in your downloads folder."
+         "Step 3: Review the generated report in your downloads folder.",
+         "Appendix: Communication user creation steps are explained later in this guide."
      ]
      embeddings = [
-         _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
-         for chunk in dummy_chunks
+         _query_model.encode([f"passage: {c}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+         for c in dummy_chunks
      ]
      index = build_faiss_index(embeddings)

-     query = "What are the steps to export a report?"
+     query = "How do I create a communication user?"
      retrieved = retrieve_chunks(query, index, dummy_chunks)
-
-     print("\n--- Strict Mode ---")
-     print(generate_answer(query, retrieved, reasoning_mode=False))
-
-     print("\n--- Reasoning Mode ---")
-     print(generate_answer(query, retrieved, reasoning_mode=True))
+     print("🔍 Retrieved:", retrieved)
+     print("💬 Answer:", generate_answer(query, retrieved))