Shubham170793 committed · Commit d4d8027 · verified · 1 Parent(s): 54ca62a

Update src/qa.py

Files changed (1)
  1. src/qa.py +44 -74
src/qa.py CHANGED
@@ -4,7 +4,7 @@ qa.py — Retrieval + Generation Layer
 Handles:
 • Query embedding (SentenceTransformer / E5-compatible)
 • Chunk retrieval (FAISS with neighborhood merging + re-ranking)
- • Answer generation (OpenAI GPT-4o-mini → FLAN-T5 fallback)
+ • Answer generation (Mistral-7B-Instruct-v0.3)
 Optimized for Hugging Face Spaces & Streamlit.
 """

@@ -12,9 +12,10 @@ import os
 import numpy as np
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from vectorstore import search_faiss

- print("✅ qa.py loaded from:", __file__)
+ print("✅ qa.py (Mistral version) loaded from:", __file__)

 # ==========================================================
 # 1️⃣ Hugging Face Cache Setup
@@ -39,76 +40,71 @@ except Exception as e:
 _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)

 # ==========================================================
- # 3️⃣ LLM Setup: OpenAI (primary) + FLAN (fallback)
+ # 3️⃣ LLM Setup: Mistral-7B-Instruct-v0.3
 # ==========================================================
- USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
- _answer_model = None  # ensures it's always defined
-
- if USE_OPENAI:
-     try:
-         from openai import OpenAI
-         client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-         print("✅ Using OpenAI GPT-4o-mini for answer generation")
-     except Exception as e:
-         print(f"⚠️ Failed to initialize OpenAI client: {e}")
-         USE_OPENAI = False
-
- # Always prepare fallback safely
- try:
-     from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
-     MODEL_NAME = "google/flan-t5-base"
-     _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-     _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-     _answer_model = pipeline("text2text-generation", model=_model, tokenizer=_tokenizer, device=-1)
-     print("💡 Fallback FLAN-T5 ready.")
- except Exception as e:
-     print(f"⚠️ Could not initialize FLAN fallback: {e}")
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
+ print(f"✅ Loading LLM: {MODEL_NAME}")
+
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+ _model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     cache_dir=CACHE_DIR,
+     torch_dtype="auto",
+     device_map="auto"  # Uses GPU if available, CPU otherwise
+ )
+
+ _answer_model = pipeline(
+     "text-generation",
+     model=_model,
+     tokenizer=_tokenizer,
+     max_new_tokens=800,
+     temperature=0.4,
+     do_sample=False
+ )
+ print("✅ Mistral text-generation pipeline ready.")

 # ==========================================================
 # 4️⃣ Prompt Template
 # ==========================================================
- PROMPT_TEMPLATE = """
- You are an enterprise knowledge assistant.
- Use ONLY the CONTEXT below to answer the QUESTION clearly, completely, and factually.
- If the context doesn’t contain the answer, reply exactly:
+ PROMPT_TEMPLATE = """You are a precise enterprise knowledge assistant.
+ Use only the context provided below to answer the question clearly and factually.
+ If the answer cannot be found, reply exactly:
 "I don't know based on the provided document."

- ---
 Context:
 {context}
- ---
+
 Question:
 {query}
- ---
- Answer:
- """
+
+ Answer:"""

 # ==========================================================
 # 5️⃣ Chunk Retrieval Function
 # ==========================================================
 def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
-     """Retrieve top-K relevant chunks, merge nearby ones, and re-rank by semantic similarity."""
+     """Retrieve top-K relevant chunks and re-rank them by semantic similarity."""
     if not index or not chunks:
         return []

     try:
-         # Step 1: Encode the query
+         # Encode the query
         query_emb = _query_model.encode(
             [f"query: {query.strip()}"],
             convert_to_numpy=True,
             normalize_embeddings=True
         )[0]

-         # Step 2: Initial FAISS retrieval
+         # Initial FAISS retrieval
         distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)

-         # Step 3: Merge neighboring chunks
+         # Merge neighboring chunks
         merged_chunks = []
         for idx in indices[0]:
             neighbors = [chunks[i] for i in range(max(0, idx - 1), min(len(chunks), idx + 2))]
             merged_chunks.append(" ".join(neighbors))

-         # Step 4: Re-rank using cosine similarity
+         # Re-rank by cosine similarity
         chunk_vecs = np.array([
             _query_model.encode([c], convert_to_numpy=True, normalize_embeddings=True)[0]
             for c in merged_chunks
@@ -116,62 +112,36 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
         scores = cosine_similarity(np.array([query_emb]), chunk_vecs)[0]
         sorted_indices = np.argsort(scores)[::-1]

-         # Step 5: Return top-ranked merged chunks
         return [merged_chunks[i] for i in sorted_indices[:top_k]]

     except Exception as e:
         print(f"⚠️ Retrieval error: {e}")
         return []

-
 # ==========================================================
 # 6️⃣ Answer Generation Function
 # ==========================================================
 def generate_answer(query: str, retrieved_chunks: list):
-     """Generate factual, context-grounded answers using OpenAI or fallback FLAN-T5."""
+     """Generate factual, context-grounded answers using Mistral-7B."""
     if not retrieved_chunks:
         return "Sorry, I couldn’t find relevant information in the document."

-     # Build full context
+     # Build the full context
     context = "\n\n".join([
         f"[Chunk {i+1}]: {chunk.strip()}"
         for i, chunk in enumerate(retrieved_chunks)
     ])
     prompt = PROMPT_TEMPLATE.format(context=context, query=query)

-     # --- Try OpenAI first ---
-     if USE_OPENAI:
-         try:
-             response = client.chat.completions.create(
-                 model="gpt-4o-mini",
-                 messages=[
-                     {"role": "system", "content": "You are a precise enterprise document assistant."},
-                     {"role": "user", "content": prompt},
-                 ],
-                 temperature=0.4,
-                 max_tokens=800,
-             )
-             return response.choices[0].message.content.strip()
-
-         except Exception as e:
-             print(f"⚠️ OpenAI generation failed: {e}. Switching to fallback...")
-
-     # --- Fallback to FLAN-T5 ---
     try:
-         if _answer_model:
-             result = _answer_model(
-                 prompt,
-                 max_new_tokens=600,
-                 do_sample=False,
-                 temperature=0.3
-             )
-             return result[0]["generated_text"].strip()
-         else:
-             return "⚠️ Error: Fallback model not available."
+         result = _answer_model(prompt)
+         output = result[0]["generated_text"]
+         # Remove the repeated prompt text (if any)
+         answer = output[len(prompt):].strip()
+         return answer
     except Exception as e:
-         print(f"⚠️ Fallback model failed: {e}")
-         return "⚠️ Error: Both OpenAI and fallback generation failed."
-
+         print(f"⚠️ Generation failed: {e}")
+         return "⚠️ Error: Could not generate an answer at the moment."

 # ==========================================================
 # 7️⃣ Local Test
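
Two remarks on the new LLM setup (section 3️⃣): device_map="auto" relies on the accelerate package being installed, and because do_sample=False selects greedy decoding, the temperature=0.4 argument never takes effect (recent transformers releases warn about exactly this combination). Below is an illustrative sketch, not part of the commit, showing the two self-consistent configurations; it reuses the _model and _tokenizer objects defined in the diff.

# Illustrative sketch only: self-consistent decoding setups for the
# pipeline in section 3️⃣. Assumes _model and _tokenizer from qa.py.
from transformers import pipeline

# Greedy decoding (what the committed config effectively runs);
# temperature is ignored when do_sample=False, so it is omitted here.
greedy = pipeline(
    "text-generation",
    model=_model,
    tokenizer=_tokenizer,
    max_new_tokens=800,
    do_sample=False,
)

# Sampled decoding: required if temperature=0.4 should actually apply.
sampled = pipeline(
    "text-generation",
    model=_model,
    tokenizer=_tokenizer,
    max_new_tokens=800,
    do_sample=True,
    temperature=0.4,
)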
 
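On the retrieval side (section 5️⃣), each FAISS hit idx is widened to the slice [max(0, idx - 1), min(len(chunks), idx + 2)): the hit plus one neighbor on each side, clipped at the document edges. A standalone illustration with invented chunk labels:

# Standalone illustration of the neighborhood-merge window used in
# retrieve_chunks(); the chunk labels are invented for the example.
chunks = ["c0", "c1", "c2", "c3", "c4"]
for idx in (0, 2, 4):
    lo, hi = max(0, idx - 1), min(len(chunks), idx + 2)
    print(idx, chunks[lo:hi])
# 0 ['c0', 'c1']        <- no left neighbor at the start
# 2 ['c1', 'c2', 'c3']  <- full three-chunk window
# 4 ['c3', 'c4']        <- no right neighbor at the end

Note that adjacent hits yield overlapping merged strings; the cosine re-ranking step scores each merged string independently, so near-duplicates can occupy several of the top-K slots.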
 
 
 
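Finally, a minimal end-to-end check of the module's public surface after this commit, in the spirit of its 7️⃣ Local Test section. Everything here is an assumption for illustration: the index construction mirrors the module's embedding settings (normalized MiniLM vectors in a flat FAISS index; the real project builds its index in vectorstore.py), the chunk texts are invented, src/ is assumed to be on the import path, and importing qa triggers the Mistral-7B download at module load.

# Hedged local-test sketch; not part of the commit. Builds a tiny FAISS
# index over normalized MiniLM embeddings and runs both public functions.
import faiss
from sentence_transformers import SentenceTransformer

from qa import retrieve_chunks, generate_answer  # loads Mistral-7B on import

chunks = [
    "Travel must be booked through the approved portal.",
    "Economy class is required for flights under six hours.",
    "Expense reports are due within 30 days of the trip.",
    "Remote work requires prior manager approval.",
]
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
vecs = embedder.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)

index = faiss.IndexFlatL2(vecs.shape[1])
index.add(vecs.astype("float32"))

query = "When are expense reports due?"
top = retrieve_chunks(query, index, chunks, top_k=2)  # searches top_k * 2 hits
print(generate_answer(query, top))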