Shubham170793 committed on
Commit 49c4268 · verified · 1 Parent(s): a5ea9d2

Update src/qa.py

Files changed (1)
  1. src/qa.py +49 -45
src/qa.py CHANGED
@@ -1,21 +1,21 @@
 """
-qa.py — Retrieval + Generation Layer (Optimized Mistral Version)
----------------------------------------------------------------
+qa.py — Retrieval + Generation Layer (Mistral Optimized v2)
+-----------------------------------------------------------
 Handles:
-• Query embedding (SentenceTransformer / E5-compatible)
-• Chunk retrieval (FAISS, no redundant encoding)
-• Answer generation (Mistral-7B-Instruct, quantized for CPU)
-Optimized for Hugging Face Spaces & Streamlit.
+• Query embedding (SentenceTransformer / E5)
+• Fast FAISS retrieval with context merging
+• Answer generation via Mistral-7B-Instruct (optimized for CPU)
+-----------------------------------------------------------
+Built for Hugging Face Spaces / Streamlit apps.
 """

 import os
 import numpy as np
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
-from vectorstore import search_faiss
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

-print("✅ qa.py (Mistral version) loaded from:", __file__)
+print("✅ qa.py (Mistral Optimized v2) loaded from:", __file__)

 # ==========================================================
 # 1️⃣ Hugging Face Cache Setup
@@ -31,19 +31,19 @@ os.environ.update({
 print(f"✅ Using Hugging Face cache at {CACHE_DIR}")

 # ==========================================================
-# 2️⃣ Query Embedding Model (fast, efficient)
+# 2️⃣ Query Embedding Model (E5-small, lightweight)
 # ==========================================================
 try:
     _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
-    print("✅ Loaded model: intfloat/e5-small-v2")
+    print("✅ Loaded query model: intfloat/e5-small-v2")
 except Exception as e:
-    print(f"⚠️ Embedding model load failed ({e}), falling back to MiniLM.")
+    print(f"⚠️ Embedding model load failed ({e}), using MiniLM fallback.")
     _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)

 # ==========================================================
-# 3️⃣ LLM Setup (Mistral 7B-Instruct, quantized)
+# 3️⃣ LLM Setup: Mistral-7B-Instruct (quantized + optimized)
 # ==========================================================
-MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
+MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # slightly faster and stable
 print(f"✅ Loading LLM: {MODEL_NAME}")

 _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
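A note on the "quantized" wording in the section header: neither torch_dtype="auto" nor low_cpu_mem_usage=True in the next hunk actually quantizes anything; they only control the load dtype and weight staging. For reference, a minimal sketch of what real 4-bit quantized loading looks like with transformers' BitsAndBytesConfig. This is not part of the commit, and bitsandbytes 4-bit requires a CUDA GPU, so on a CPU-only Space a GGUF build served by llama-cpp-python would be the more realistic route:

    # Illustrative only (not in this commit): true 4-bit quantized loading.
    # Requires a CUDA GPU and the bitsandbytes package.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,                       # quantize weights to 4-bit
        bnb_4bit_compute_dtype=torch.bfloat16,   # dtype used for the actual matmuls
    )
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2",
        quantization_config=bnb_config,
        device_map="auto",
    )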
@@ -51,8 +51,8 @@ _model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
     cache_dir=CACHE_DIR,
     torch_dtype="auto",
-    device_map="auto",  # smart layer placement
-    low_cpu_mem_usage=True,  # enables disk offloading on CPU
+    device_map="auto",
+    low_cpu_mem_usage=True,
 )
 _answer_model = pipeline(
     "text-generation",
@@ -64,29 +64,23 @@ _answer_model = pipeline(
 print("✅ Mistral text-generation pipeline ready.")

 # ==========================================================
-# 4️⃣ Prompt Template
+# 4️⃣ Prompt Template (compact + efficient)
 # ==========================================================
-PROMPT_TEMPLATE = """
-You are an enterprise knowledge assistant.
-Use ONLY the CONTEXT below to answer the QUESTION clearly, completely, and factually.
-If the context doesn’t contain the answer, reply exactly:
-"I don't know based on the provided document."
-
----
-Context:
-{context}
----
-Question:
-{query}
----
-Answer:
-"""
+PROMPT_TEMPLATE = (
+    "Answer the question using only the document context below. "
+    "If the answer isn’t clearly in the document, say: "
+    "'I don't know based on the provided document.'\n\n"
+    "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
+)

 # ==========================================================
-# 5️⃣ Chunk Retrieval Function (FAST)
+# 5️⃣ Fast Chunk Retrieval with Context Merging
 # ==========================================================
-def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
-    """Fast semantic retrieval with FAISS — no redundant re-encoding."""
+def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5, merge_window: int = 1):
+    """
+    Fast semantic retrieval with lightweight neighborhood expansion.
+    Retrieves top-K relevant chunks, then merges nearby ones for context continuity.
+    """
     if not index or not chunks:
         return []

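Because the rewritten PROMPT_TEMPLATE is now a single format string, the assembled prompt can be sanity-checked in isolation, which is handy when tuning the wording. A quick check with dummy values:

    # Quick sanity check of the compact template (dummy values, run in a REPL):
    context = "Mistral-7B is a 7-billion-parameter language model."
    query = "How many parameters does Mistral-7B have?"
    print(PROMPT_TEMPLATE.format(context=context, query=query))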
@@ -98,18 +92,25 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
             normalize_embeddings=True
         )[0]

-        # Step 2: FAISS search only (already has precomputed embeddings)
-        distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k)
+        # Step 2: Retrieve top-K*2 candidates
+        distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)

-        # Step 3: Return top chunks directly (fast)
-        return [chunks[i] for i in indices[0]]
+        # Step 3: Expand retrieval to nearby chunks
+        selected = set()
+        for idx in indices[0]:
+            for n in range(max(0, idx - merge_window), min(len(chunks), idx + merge_window + 1)):
+                selected.add(n)
+
+        # Step 4: Preserve order (important for sequential text like steps)
+        ordered = [chunks[i] for i in sorted(selected)]
+        return ordered

     except Exception as e:
         print(f"⚠️ Retrieval error: {e}")
         return []

 # ==========================================================
-# 6️⃣ Answer Generation Function (Optimized for Speed)
+# 6️⃣ Answer Generation Function (Faster + Cleaner Output)
 # ==========================================================
 def generate_answer(query: str, retrieved_chunks: list):
     """Generate factual, context-grounded answers using Mistral."""
@@ -117,11 +118,7 @@ def generate_answer(query: str, retrieved_chunks: list):
         return "Sorry, I couldn’t find relevant information in the document."

     # Merge retrieved chunks
-    context = "\n\n".join([
-        f"[Chunk {i+1}]: {chunk.strip()}"
-        for i, chunk in enumerate(retrieved_chunks)
-    ])
-
+    context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
     prompt = PROMPT_TEMPLATE.format(context=context, query=query)

     try:
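Returning to the retrieval change above: the Step 3 neighborhood expansion is easiest to see with concrete numbers. A self-contained run of the same loop on dummy values:

    # Same expansion logic as Step 3 in retrieve_chunks, run on dummy data.
    chunks = [f"chunk {i}" for i in range(12)]
    hits = [4, 9]        # pretend FAISS returned these row indices
    merge_window = 1

    selected = set()
    for idx in hits:
        for n in range(max(0, idx - merge_window), min(len(chunks), idx + merge_window + 1)):
            selected.add(n)

    print(sorted(selected))  # [3, 4, 5, 8, 9, 10]: each hit plus its immediate neighbors

Two things worth keeping in mind. With top_k * 2 candidates and merge_window=1, up to 6 * top_k chunks can reach the prompt, so the compact template is carrying a larger context than before. And when the index holds fewer than top_k * 2 vectors, FAISS pads the indices array with -1, which this loop can silently map to chunk 0; filtering out idx < 0 before expanding would be a cheap guard.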
@@ -133,6 +130,13 @@ def generate_answer(query: str, retrieved_chunks: list):
             pad_token_id=_tokenizer.eos_token_id,
         )
         answer = result[0]["generated_text"].strip()
+
+        # Cleanup redundant prompt echo
+        if "Question:" in answer:
+            answer = answer.split("Question:")[-1].strip()
+        if answer.startswith(query):
+            answer = answer[len(query):].strip()
+
         return answer

     except Exception as e:
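The echo cleanup above exists because text-generation pipelines return the prompt together with the completion by default. The pipeline API has a switch for this: passing return_full_text=False (not used in this commit) makes it return only the newly generated text, which would make the string surgery unnecessary. A sketch, with max_new_tokens as an assumed value since the real setting is outside the hunks shown:

    # Alternative to the manual echo-stripping above (not in this commit):
    result = _answer_model(
        prompt,
        max_new_tokens=256,                   # assumed value
        pad_token_id=_tokenizer.eos_token_id,
        return_full_text=False,               # return only the completion, no prompt echo
    )
    answer = result[0]["generated_text"].strip()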
@@ -140,7 +144,7 @@ def generate_answer(query: str, retrieved_chunks: list):
         return "⚠️ Error: Could not generate an answer at the moment."

 # ==========================================================
-# 7️⃣ Local Test (run only in dev mode)
+# 7️⃣ Local Dev Test (optional)
 # ==========================================================
 if __name__ == "__main__":
     dummy_chunks = [
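For anyone wanting to exercise the new retrieval path locally, a minimal end-to-end sketch. The FAISS index construction and the passage prefix are assumptions about the rest of the repo, since only qa.py changes in this commit; E5 models are trained with query/passage prefixes, and importing qa also loads the 7B model, so run this only where that is feasible:

    import faiss                       # pip install faiss-cpu
    import numpy as np
    from sentence_transformers import SentenceTransformer
    from qa import retrieve_chunks, generate_answer   # assumes src/ is on sys.path

    encoder = SentenceTransformer("intfloat/e5-small-v2")
    chunks = ["first passage ...", "second passage ...", "third passage ..."]
    emb = encoder.encode([f"passage: {c}" for c in chunks], normalize_embeddings=True)

    index = faiss.IndexFlatIP(emb.shape[1])   # inner product = cosine on normalized vectors
    index.add(np.asarray(emb, dtype="float32"))

    hits = retrieve_chunks("what does the second passage say?", index, chunks, top_k=2)
    print(generate_answer("what does the second passage say?", hits))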
 