Shubham170793 commited on
Commit
874e5e3
·
verified ·
1 Parent(s): d1ca01c

Update src/qa.py

Browse files
Files changed (1) hide show
  1. src/qa.py +36 -25
src/qa.py CHANGED
@@ -1,19 +1,20 @@
1
  """
2
- qa.py — Retrieval + Generation (Phi-2 Fast Reasoning)
3
- -----------------------------------------------------
4
  Uses:
5
- - intfloat/e5-small-v2 for embeddings
6
- - microsoft/phi-2 as main LLM (fast, strong reasoning)
7
- - Optional fallback: google/flan-t5-base
8
- Optimized for CPU inference (Hugging Face Spaces / Streamlit)
9
  """
10
 
11
  import os
12
  import numpy as np
13
  from sentence_transformers import SentenceTransformer
14
  from sklearn.metrics.pairwise import cosine_similarity
 
 
15
 
16
- print("✅ qa.py (Phi-2 optimized) loaded from:", __file__)
17
 
18
  # ==========================================================
19
  # 1️⃣ Cache Setup
@@ -38,29 +39,33 @@ except Exception as e:
38
  _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
39
 
40
  # ==========================================================
41
- # 3️⃣ Phi-2 LLM Setup
42
  # ==========================================================
43
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
44
-
45
  try:
46
  MODEL_NAME = "microsoft/phi-2"
47
- print(f"✅ Loading LLM: {MODEL_NAME}")
 
48
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
 
 
49
  _model = AutoModelForCausalLM.from_pretrained(
50
  MODEL_NAME,
51
  cache_dir=CACHE_DIR,
52
- torch_dtype="auto",
53
  low_cpu_mem_usage=True,
54
- )
 
 
55
  _answer_model = pipeline(
56
  "text-generation",
57
  model=_model,
58
  tokenizer=_tokenizer,
59
  device=-1,
60
- max_new_tokens=250,
61
- do_sample=False,
62
  )
63
- print("✅ Phi-2 generation pipeline ready.")
 
 
64
  except Exception as e:
65
  print(f"⚠️ Phi-2 load failed: {e}")
66
  _answer_model = None
@@ -71,16 +76,16 @@ except Exception as e:
71
  PROMPT_TEMPLATE = (
72
  "You are an expert assistant for enterprise document understanding.\n"
73
  "Use ONLY the context below to answer the question clearly and factually.\n"
74
- "If the context doesn’t contain the answer, reply: "
75
  "'I don't know based on the provided document.'\n\n"
76
  "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
77
  )
78
 
79
  # ==========================================================
80
- # 5️⃣ Retrieval Function
81
  # ==========================================================
82
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
83
- """Fast FAISS retrieval with E5 embeddings."""
84
  if not index or not chunks:
85
  return []
86
 
@@ -88,7 +93,6 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
88
  q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
89
  distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
90
 
91
- # Merge nearby chunks for continuity
92
  selected = set()
93
  for idx in indices[0]:
94
  for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
@@ -101,10 +105,10 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
101
  return []
102
 
103
  # ==========================================================
104
- # 6️⃣ Answer Generation Function
105
  # ==========================================================
106
  def generate_answer(query: str, retrieved_chunks: list):
107
- """Generate grounded answers using Phi-2."""
108
  if not retrieved_chunks:
109
  return "Sorry, I couldn’t find relevant information in the document."
110
 
@@ -112,21 +116,28 @@ def generate_answer(query: str, retrieved_chunks: list):
112
  prompt = PROMPT_TEMPLATE.format(context=context, query=query)
113
 
114
  try:
 
115
  result = _answer_model(
116
  prompt,
117
- max_new_tokens=250,
118
  do_sample=False,
119
  early_stopping=True,
120
  pad_token_id=_tokenizer.eos_token_id,
121
  )
122
  answer = result[0]["generated_text"].strip()
 
 
 
 
 
123
  return answer
 
124
  except Exception as e:
125
  print(f"⚠️ Generation failed: {e}")
126
  return "⚠️ Error: Could not generate an answer at the moment."
127
 
128
  # ==========================================================
129
- # 7️⃣ Local Test (optional)
130
  # ==========================================================
131
  if __name__ == "__main__":
132
  from vectorstore import build_faiss_index
@@ -135,12 +146,12 @@ if __name__ == "__main__":
135
  "Step 2: Click 'Export' to download a CSV summary.",
136
  "Step 3: Review the generated report in your downloads folder."
137
  ]
138
-
139
  embeddings = [
140
  _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
141
  for chunk in dummy_chunks
142
  ]
143
  index = build_faiss_index(embeddings)
 
144
  query = "What are the steps to export a report?"
145
  retrieved = retrieve_chunks(query, index, dummy_chunks)
146
  print("🔍 Retrieved:", retrieved)
 
1
  """
2
+ qa.py — Optimized Phi-2 Retrieval + Generation
3
+ ----------------------------------------------
4
  Uses:
5
+ intfloat/e5-small-v2 for embeddings
6
+ microsoft/phi-2 for reasoning-rich generation (fast on CPU)
7
+ Optimized for: speed + stability in Streamlit / Hugging Face Spaces
 
8
  """
9
 
10
  import os
11
  import numpy as np
12
  from sentence_transformers import SentenceTransformer
13
  from sklearn.metrics.pairwise import cosine_similarity
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
15
+ import torch
16
 
17
+ print("✅ qa.py (Phi-2 optimized fast) loaded from:", __file__)
18
 
19
  # ==========================================================
20
  # 1️⃣ Cache Setup
 
39
  _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
40
 
41
  # ==========================================================
42
+ # 3️⃣ Phi-2 LLM Setup (Quantized for CPU)
43
  # ==========================================================
 
 
44
  try:
45
  MODEL_NAME = "microsoft/phi-2"
46
+ print(f"✅ Loading LLM: {MODEL_NAME} (quantized, CPU-optimized)")
47
+
48
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
49
+
50
+ # ✅ Load model in mixed precision for 4–6× faster inference
51
  _model = AutoModelForCausalLM.from_pretrained(
52
  MODEL_NAME,
53
  cache_dir=CACHE_DIR,
54
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16,
55
  low_cpu_mem_usage=True,
56
+ ).to("cpu")
57
+
58
+ # ✅ Create generation pipeline (keep in memory)
59
  _answer_model = pipeline(
60
  "text-generation",
61
  model=_model,
62
  tokenizer=_tokenizer,
63
  device=-1,
64
+ model_kwargs={"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True},
 
65
  )
66
+
67
+ print("✅ Phi-2 text-generation pipeline ready (optimized).")
68
+
69
  except Exception as e:
70
  print(f"⚠️ Phi-2 load failed: {e}")
71
  _answer_model = None
 
76
  PROMPT_TEMPLATE = (
77
  "You are an expert assistant for enterprise document understanding.\n"
78
  "Use ONLY the context below to answer the question clearly and factually.\n"
79
+ "If the context doesn’t contain the answer, reply exactly:\n"
80
  "'I don't know based on the provided document.'\n\n"
81
  "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
82
  )
83
 
84
  # ==========================================================
85
+ # 5️⃣ Retrieve Top-K Chunks
86
  # ==========================================================
87
  def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
88
+ """Efficient FAISS retrieval using cosine similarity."""
89
  if not index or not chunks:
90
  return []
91
 
 
93
  q_emb = _query_model.encode([f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True)[0]
94
  distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)
95
 
 
96
  selected = set()
97
  for idx in indices[0]:
98
  for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
 
105
  return []
106
 
107
  # ==========================================================
108
+ # 6️⃣ Answer Generation (fast)
109
  # ==========================================================
110
  def generate_answer(query: str, retrieved_chunks: list):
111
+ """Generate concise, grounded answers using Phi-2."""
112
  if not retrieved_chunks:
113
  return "Sorry, I couldn’t find relevant information in the document."
114
 
 
116
  prompt = PROMPT_TEMPLATE.format(context=context, query=query)
117
 
118
  try:
119
+ # ✅ Limit tokens to speed up inference
120
  result = _answer_model(
121
  prompt,
122
+ max_new_tokens=120, # reduced for faster completion
123
  do_sample=False,
124
  early_stopping=True,
125
  pad_token_id=_tokenizer.eos_token_id,
126
  )
127
  answer = result[0]["generated_text"].strip()
128
+
129
+ # Clean excessive prompt echo
130
+ if "Answer:" in answer:
131
+ answer = answer.split("Answer:")[-1].strip()
132
+
133
  return answer
134
+
135
  except Exception as e:
136
  print(f"⚠️ Generation failed: {e}")
137
  return "⚠️ Error: Could not generate an answer at the moment."
138
 
139
  # ==========================================================
140
+ # 7️⃣ Local Test
141
  # ==========================================================
142
  if __name__ == "__main__":
143
  from vectorstore import build_faiss_index
 
146
  "Step 2: Click 'Export' to download a CSV summary.",
147
  "Step 3: Review the generated report in your downloads folder."
148
  ]
 
149
  embeddings = [
150
  _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
151
  for chunk in dummy_chunks
152
  ]
153
  index = build_faiss_index(embeddings)
154
+
155
  query = "What are the steps to export a report?"
156
  retrieved = retrieve_chunks(query, index, dummy_chunks)
157
  print("🔍 Retrieved:", retrieved)