Shubham170793 committed
Commit d7aaa8f (verified)
1 Parent(s): 00be68d

Update src/qa.py

Files changed (1):
  src/qa.py (+61 -83)
src/qa.py CHANGED
@@ -1,24 +1,22 @@
 """
-qa.py — Retrieval + Generation Layer (Mistral Optimized v2)
------------------------------------------------------------
-Handles:
-Query embedding (SentenceTransformer / E5)
-Fast FAISS retrieval with context merging
-Answer generation via Mistral-7B-Instruct (optimized for CPU)
------------------------------------------------------------
-Built for Hugging Face Spaces / Streamlit apps.
+qa.py — Retrieval + Generation (Phi-2 Fast Reasoning)
+-----------------------------------------------------
+Uses:
+- intfloat/e5-small-v2 for embeddings
+- microsoft/phi-2 as main LLM (fast, strong reasoning)
+- Optional fallback: google/flan-t5-base
+Optimized for CPU inference (Hugging Face Spaces / Streamlit)
 """
 
 import os
 import numpy as np
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
-print("✅ qa.py (Mistral Optimized v2) loaded from:", __file__)
+print("✅ qa.py (Phi-2 optimized) loaded from:", __file__)
 
 # ==========================================================
-# 1️⃣ Hugging Face Cache Setup
+# 1️⃣ Cache Setup
 # ==========================================================
 CACHE_DIR = "/tmp/hf_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
@@ -28,141 +26,121 @@ os.environ.update({
     "HF_DATASETS_CACHE": CACHE_DIR,
     "HF_MODULES_CACHE": CACHE_DIR
 })
-print(f"✅ Using Hugging Face cache at {CACHE_DIR}")
 
 # ==========================================================
-# 2️⃣ Query Embedding Model (E5-small, lightweight)
+# 2️⃣ Embedding Model
 # ==========================================================
 try:
     _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
-    print("✅ Loaded query model: intfloat/e5-small-v2")
+    print("✅ Loaded embedding model: intfloat/e5-small-v2")
 except Exception as e:
-    print(f"⚠️ Embedding model load failed ({e}), using MiniLM fallback.")
+    print(f"⚠️ Fallback to MiniLM due to {e}")
     _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
 
 # ==========================================================
-# 3️⃣ LLM Setup: Mistral-7B-Instruct (quantized + optimized)
+# 3️⃣ Phi-2 LLM Setup
 # ==========================================================
-MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # slightly faster and stable
-print(f"✅ Loading LLM: {MODEL_NAME}")
-
-_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-_model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    cache_dir=CACHE_DIR,
-    torch_dtype="auto",
-    device_map="auto",
-    low_cpu_mem_usage=True,
-)
-_answer_model = pipeline(
-    "text-generation",
-    model=_model,
-    tokenizer=_tokenizer,
-    max_new_tokens=600,
-    do_sample=False,
-)
-print("✅ Mistral text-generation pipeline ready.")
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+try:
+    MODEL_NAME = "microsoft/phi-2"
+    print(f"✅ Loading LLM: {MODEL_NAME}")
+    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+    _model = AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME,
+        cache_dir=CACHE_DIR,
+        torch_dtype="auto",
+        low_cpu_mem_usage=True,
+    )
+    _answer_model = pipeline(
+        "text-generation",
+        model=_model,
+        tokenizer=_tokenizer,
+        device=-1,
+        max_new_tokens=250,
+        do_sample=False,
+    )
+    print("✅ Phi-2 generation pipeline ready.")
+except Exception as e:
+    print(f"⚠️ Phi-2 load failed: {e}")
+    _answer_model = None
 
 # ==========================================================
-# 4️⃣ Prompt Template (compact + efficient)
+# 4️⃣ Prompt Template
 # ==========================================================
 PROMPT_TEMPLATE = (
-    "Answer the question using only the document context below. "
-    "If the answer isn’t clearly in the document, say: "
+    "You are an expert assistant for enterprise document understanding.\n"
+    "Use ONLY the context below to answer the question clearly and factually.\n"
+    "If the context doesn’t contain the answer, reply: "
     "'I don't know based on the provided document.'\n\n"
     "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
 )
 
 # ==========================================================
-# 5️⃣ Fast Chunk Retrieval with Context Merging
+# 5️⃣ Retrieval Function
 # ==========================================================
-def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5, merge_window: int = 1):
-    """
-    Fast semantic retrieval with lightweight neighborhood expansion.
-    Retrieves top-K relevant chunks, then merges nearby ones for context continuity.
-    """
+def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
+    """Fast FAISS retrieval with E5 embeddings."""
     if not index or not chunks:
         return []
 
     try:
-        # Step 1: Encode query once
-        query_emb = _query_model.encode(
-            [f"query: {query.strip()}"],
-            convert_to_numpy=True,
-            normalize_embeddings=True
-        )[0]
-
-        # Step 2: Retrieve top-K*2 candidates
-        distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)
+        q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+        distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
 
-        # Step 3: Expand retrieval to nearby chunks
+        # Merge nearby chunks for continuity
         selected = set()
         for idx in indices[0]:
-            for n in range(max(0, idx - merge_window), min(len(chunks), idx + merge_window + 1)):
-                selected.add(n)
-
-        # Step 4: Preserve order (important for sequential text like steps)
-        ordered = [chunks[i] for i in sorted(selected)]
-        return ordered
-
+            for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
+                selected.add(i)
+
+        ordered_chunks = [chunks[i] for i in sorted(selected)]
+        return ordered_chunks
     except Exception as e:
         print(f"⚠️ Retrieval error: {e}")
         return []
 
 # ==========================================================
-# 6️⃣ Answer Generation Function (Faster + Cleaner Output)
+# 6️⃣ Answer Generation Function
 # ==========================================================
 def generate_answer(query: str, retrieved_chunks: list):
-    """Generate factual, context-grounded answers using Mistral."""
+    """Generate grounded answers using Phi-2."""
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."
 
-    # Merge retrieved chunks
     context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
     prompt = PROMPT_TEMPLATE.format(context=context, query=query)
 
     try:
         result = _answer_model(
             prompt,
-            max_new_tokens=700,
-            temperature=None,
+            max_new_tokens=250,
             do_sample=False,
+            early_stopping=True,
             pad_token_id=_tokenizer.eos_token_id,
         )
         answer = result[0]["generated_text"].strip()
-
-        # Cleanup redundant prompt echo
-        if "Question:" in answer:
-            answer = answer.split("Question:")[-1].strip()
-        if answer.startswith(query):
-            answer = answer[len(query):].strip()
-
         return answer
-
     except Exception as e:
         print(f"⚠️ Generation failed: {e}")
         return "⚠️ Error: Could not generate an answer at the moment."
 
 # ==========================================================
-# 7️⃣ Local Dev Test (optional)
+# 7️⃣ Local Test (optional)
 # ==========================================================
 if __name__ == "__main__":
+    from vectorstore import build_faiss_index
     dummy_chunks = [
         "Step 1: Open the dashboard and navigate to reports.",
         "Step 2: Click 'Export' to download a CSV summary.",
         "Step 3: Review the generated report in your downloads folder."
     ]
-    from vectorstore import build_faiss_index
 
-    index = build_faiss_index([
-        _query_model.encode(
-            [f"passage: {chunk}"],
-            convert_to_numpy=True,
-            normalize_embeddings=True
-        )[0]
+    embeddings = [
+        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
         for chunk in dummy_chunks
-    ])
-
+    ]
+    index = build_faiss_index(embeddings)
     query = "What are the steps to export a report?"
     retrieved = retrieve_chunks(query, index, dummy_chunks)
     print("🔍 Retrieved:", retrieved)
 
 
40
  # ==========================================================
41
+ # 3️⃣ Phi-2 LLM Setup
42
  # ==========================================================
43
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
44
+
45
+ try:
46
+ MODEL_NAME = "microsoft/phi-2"
47
+ print(f"✅ Loading LLM: {MODEL_NAME}")
48
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
49
+ _model = AutoModelForCausalLM.from_pretrained(
50
+ MODEL_NAME,
51
+ cache_dir=CACHE_DIR,
52
+ torch_dtype="auto",
53
+ low_cpu_mem_usage=True,
54
+ )
55
+ _answer_model = pipeline(
56
+ "text-generation",
57
+ model=_model,
58
+ tokenizer=_tokenizer,
59
+ device=-1,
60
+ max_new_tokens=250,
61
+ do_sample=False,
62
+ )
63
+ print("✅ Phi-2 generation pipeline ready.")
64
+ except Exception as e:
65
+ print(f"⚠️ Phi-2 load failed: {e}")
66
+ _answer_model = None
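The module docstring advertises google/flan-t5-base as an optional fallback, but the except branch only sets _answer_model = None, which later crashes generate_answer with a TypeError if Phi-2 fails to load. A hypothetical sketch of the advertised fallback (not in this commit); flan-t5 is an encoder-decoder model, so it needs the text2text-generation task:

    from transformers import pipeline

    def load_answer_model(cache_dir: str = "/tmp/hf_cache"):
        """Try Phi-2 first; otherwise fall back to flan-t5-base."""
        try:
            return pipeline(
                "text-generation", model="microsoft/phi-2",
                device=-1, max_new_tokens=250, do_sample=False,
                model_kwargs={"cache_dir": cache_dir},
            )
        except Exception as e:
            print(f"⚠️ Phi-2 load failed: {e}; falling back to flan-t5-base")
            # flan-t5 is seq2seq, hence text2text-generation; its output
            # contains only the answer, with no prompt echo.
            return pipeline(
                "text2text-generation", model="google/flan-t5-base",
                device=-1, max_new_tokens=250,
                model_kwargs={"cache_dir": cache_dir},
            )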
 
79
  # ==========================================================
80
+ # 5️⃣ Retrieval Function
81
  # ==========================================================
82
+ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
83
+ """Fast FAISS retrieval with E5 embeddings."""
 
 
 
84
  if not index or not chunks:
85
  return []
86
 
87
  try:
88
+ q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
89
+ distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
 
 
 
 
 
 
 
90
 
91
+ # Merge nearby chunks for continuity
92
  selected = set()
93
  for idx in indices[0]:
94
+ for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
95
+ selected.add(i)
 
 
 
 
96
 
97
+ ordered_chunks = [chunks[i] for i in sorted(selected)]
98
+ return ordered_chunks
99
  except Exception as e:
100
  print(f"⚠️ Retrieval error: {e}")
101
  return []
102
 
103
  # ==========================================================
104
+ # 6️⃣ Answer Generation Function
105
  # ==========================================================
106
  def generate_answer(query: str, retrieved_chunks: list):
107
+ """Generate grounded answers using Phi-2."""
108
  if not retrieved_chunks:
109
  return "Sorry, I couldn’t find relevant information in the document."
110
 
 
111
  context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
112
  prompt = PROMPT_TEMPLATE.format(context=context, query=query)
113
 
114
  try:
115
  result = _answer_model(
116
  prompt,
117
+ max_new_tokens=250,
 
118
  do_sample=False,
119
+ early_stopping=True,
120
  pad_token_id=_tokenizer.eos_token_id,
121
  )
122
  answer = result[0]["generated_text"].strip()
 
 
 
 
 
 
 
123
  return answer
 
124
  except Exception as e:
125
  print(f"⚠️ Generation failed: {e}")
126
  return "⚠️ Error: Could not generate an answer at the moment."