Shubham170793 committed
Commit a5ea9d2 · verified · 1 Parent(s): 6f0d970

Update src/qa.py

Files changed (1)
  1. src/qa.py +44 -48
src/qa.py CHANGED
@@ -1,10 +1,10 @@
"""
- qa.py — Retrieval + Generation Layer
- -------------------------------------
+ qa.py — Retrieval + Generation Layer (Optimized Mistral Version)
+ ---------------------------------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
- • Chunk retrieval (FAISS with neighborhood merging + re-ranking)
- • Answer generation (Mistral-7B-Instruct-v0.3)
+ • Chunk retrieval (FAISS, no redundant encoding)
+ • Answer generation (Mistral-7B-Instruct, quantized for CPU)
Optimized for Hugging Face Spaces & Streamlit.
"""

@@ -12,8 +12,8 @@ import os
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from vectorstore import search_faiss
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

print("✅ qa.py (Mistral version) loaded from:", __file__)

@@ -28,19 +28,20 @@ os.environ.update({
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR
})
+ print(f"✅ Using Hugging Face cache at {CACHE_DIR}")

# ==========================================================
- # 2️⃣ Query Embedding Model
+ # 2️⃣ Query Embedding Model (fast, efficient)
# ==========================================================
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
-     print("✅ Loaded query model: intfloat/e5-small-v2")
+     print("✅ Loaded model: intfloat/e5-small-v2")
except Exception as e:
-     print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
+     print(f"⚠️ Embedding model load failed ({e}), falling back to MiniLM.")
    _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)

# ==========================================================
- # 3️⃣ LLM Setup: Mistral-7B-Instruct-v0.3
+ # 3️⃣ LLM Setup (Mistral 7B-Instruct, quantized)
# ==========================================================
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
print(f"✅ Loading LLM: {MODEL_NAME}")
@@ -50,101 +51,96 @@ _model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    cache_dir=CACHE_DIR,
    torch_dtype="auto",
-     device_map="auto"  # Uses GPU if available, CPU otherwise
+     device_map="auto",        # smart layer placement
+     low_cpu_mem_usage=True,   # enables disk offloading on CPU
)
- 
_answer_model = pipeline(
    "text-generation",
    model=_model,
    tokenizer=_tokenizer,
-     max_new_tokens=800,
-     temperature=0.4,
-     do_sample=False
+     max_new_tokens=600,
+     do_sample=False,
)
print("✅ Mistral text-generation pipeline ready.")

# ==========================================================
# 4️⃣ Prompt Template
# ==========================================================
- PROMPT_TEMPLATE = """You are a precise enterprise knowledge assistant.
- Use only the context provided below to answer the question clearly and factually.
- If the answer cannot be found, reply exactly:
+ PROMPT_TEMPLATE = """
+ You are an enterprise knowledge assistant.
+ Use ONLY the CONTEXT below to answer the QUESTION clearly, completely, and factually.
+ If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."

+ ---
Context:
{context}
- 
+ ---
Question:
{query}
- 
- Answer:"""
+ ---
+ Answer:
+ """

# ==========================================================
- # 5️⃣ Chunk Retrieval Function
+ # 5️⃣ Chunk Retrieval Function (FAST)
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
-     """Retrieve top-K relevant chunks and re-rank them by semantic accuracy."""
+     """Fast semantic retrieval with FAISS no redundant re-encoding."""
    if not index or not chunks:
        return []

    try:
-         # Encode the query
+         # Step 1: Encode query once
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True
        )[0]

-         # Initial FAISS retrieval
-         distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)
- 
-         # Merge neighboring chunks
-         merged_chunks = []
-         for idx in indices[0]:
-             neighbors = [chunks[i] for i in range(max(0, idx - 1), min(len(chunks), idx + 2))]
-             merged_chunks.append(" ".join(neighbors))
- 
-         # Re-rank by cosine similarity
-         chunk_vecs = np.array([
-             _query_model.encode([c], convert_to_numpy=True, normalize_embeddings=True)[0]
-             for c in merged_chunks
-         ])
-         scores = cosine_similarity(np.array([query_emb]), chunk_vecs)[0]
-         sorted_indices = np.argsort(scores)[::-1]
+         # Step 2: FAISS search only (already has precomputed embeddings)
+         distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k)

-         return [merged_chunks[i] for i in sorted_indices[:top_k]]
+         # Step 3: Return top chunks directly (fast)
+         return [chunks[i] for i in indices[0]]

    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []

# ==========================================================
- # 6️⃣ Answer Generation Function
+ # 6️⃣ Answer Generation Function (Optimized for Speed)
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list):
-     """Generate factual, context-grounded answers using Mistral-7B."""
+     """Generate factual, context-grounded answers using Mistral."""
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

-     # Build the full context
+     # Merge retrieved chunks
    context = "\n\n".join([
        f"[Chunk {i+1}]: {chunk.strip()}"
        for i, chunk in enumerate(retrieved_chunks)
    ])
+ 
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    try:
-         result = _answer_model(prompt)
-         output = result[0]["generated_text"]
-         # Remove the repeated prompt text (if any)
-         answer = output[len(prompt):].strip()
+         result = _answer_model(
+             prompt,
+             max_new_tokens=700,
+             temperature=None,
+             do_sample=False,
+             pad_token_id=_tokenizer.eos_token_id,
+         )
+         answer = result[0]["generated_text"].strip()
        return answer
+ 
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."

# ==========================================================
- # 7️⃣ Local Test
+ # 7️⃣ Local Test (run only in dev mode)
# ==========================================================
if __name__ == "__main__":
    dummy_chunks = [