Shubham170793 committed
Commit c7133f4 · verified · 1 Parent(s): 5491531

Update src/qa.py

Files changed (1):
  1. src/qa.py +72 -46
src/qa.py CHANGED
@@ -3,8 +3,8 @@ qa.py — Retrieval + Generation Layer
 -------------------------------------
 Handles:
 • Query embedding (SentenceTransformer / E5-compatible)
-• Chunk retrieval (FAISS + cosine re-ranking)
-• Answer generation (OpenAI GPT-4o-mini or FLAN-T5 fallback)
+• Chunk retrieval (FAISS with neighborhood merging + re-ranking)
+• Answer generation (OpenAI GPT-4o-mini with FLAN-T5 fallback)
 Optimized for Hugging Face Spaces & Streamlit.
 """
 
@@ -14,23 +14,10 @@ from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 from vectorstore import search_faiss
 
-# ==========================================================
-# 1️⃣ Load OpenAI if key available
-# ==========================================================
-USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
-
-if USE_OPENAI:
-    from openai import OpenAI
-    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-    print("✅ Using OpenAI GPT-4o-mini for answer generation")
-else:
-    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
-    print("⚙️ No OpenAI key found — using fallback FLAN-T5 model")
-
-print("✅ qa.py loaded successfully")
+print("✅ qa.py loaded from:", __file__)
 
 # ==========================================================
-# 2️⃣ Hugging Face Cache Setup (Safe for Spaces)
+# 1️⃣ Hugging Face Cache Setup
 # ==========================================================
 CACHE_DIR = "/tmp/hf_cache"
 os.makedirs(CACHE_DIR, exist_ok=True)
@@ -42,32 +29,48 @@ os.environ.update({
 })
 
 # ==========================================================
-# 3️⃣ Embedding Model (E5 for better retrieval)
+# 2️⃣ Query Embedding Model
 # ==========================================================
 try:
     _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
-    print("✅ Loaded embedding model: intfloat/e5-small-v2")
+    print("✅ Loaded query model: intfloat/e5-small-v2")
 except Exception as e:
-    print(f"⚠️ Failed to load e5-small-v2 ({e}), switching to MiniLM.")
+    print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
     _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
-    print("✅ Loaded fallback: all-MiniLM-L6-v2")
 
 # ==========================================================
-# 4️⃣ Fallback Model (FLAN-T5)
+# 3️⃣ LLM Setup: OpenAI (primary) + FLAN (fallback)
 # ==========================================================
-if not USE_OPENAI:
+USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
+_answer_model = None  # ensures it's always defined
+
+if USE_OPENAI:
+    try:
+        from openai import OpenAI
+        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+        print("✅ Using OpenAI GPT-4o-mini for answer generation")
+    except Exception as e:
+        print(f"⚠️ Failed to initialize OpenAI client: {e}")
+        USE_OPENAI = False
+
+# Always prepare fallback safely
+try:
+    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
     MODEL_NAME = "google/flan-t5-base"
     _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
     _answer_model = pipeline("text2text-generation", model=_model, tokenizer=_tokenizer, device=-1)
+    print("💡 Fallback FLAN-T5 ready.")
+except Exception as e:
+    print(f"⚠️ Could not initialize FLAN fallback: {e}")
 
 # ==========================================================
-# 5️⃣ Prompt Template
+# 4️⃣ Prompt Template
 # ==========================================================
 PROMPT_TEMPLATE = """
 You are an enterprise knowledge assistant.
-Use ONLY the context below to answer the question clearly, precisely, and factually.
-If the context doesn’t contain the answer, say exactly:
+Use ONLY the CONTEXT below to answer the QUESTION clearly, completely, and factually.
+If the context doesn’t contain the answer, reply exactly:
 "I don't know based on the provided document."
 
 ---
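Reviewer note: `intfloat/e5-small-v2` is trained with asymmetric `"query: "` / `"passage: "` prefixes, which `retrieve_chunks` and the local test below rely on; if the MiniLM fallback loads instead, the prefixes are harmless but carry no special meaning. A minimal standalone sketch of the convention, with a made-up query/passage pair:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/e5-small-v2")

# E5 expects "query: " on questions and "passage: " on documents.
q = model.encode(["query: how much paid leave do employees get?"],
                 convert_to_numpy=True, normalize_embeddings=True)
p = model.encode(["passage: Employees accrue 20 days of paid leave per year."],
                 convert_to_numpy=True, normalize_embeddings=True)

print((q @ p.T).item())  # cosine similarity, since both vectors are unit-normalized
```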
@@ -81,28 +84,31 @@ Answer:
 """
 
 # ==========================================================
-# 6️⃣ Chunk Retrieval
+# 5️⃣ Chunk Retrieval Function
 # ==========================================================
 def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
-    """Retrieve top-K relevant chunks and re-rank by semantic similarity."""
+    """Retrieve top-K relevant chunks, merge nearby ones, and re-rank by semantic similarity."""
     if not index or not chunks:
         return []
 
     try:
+        # Step 1: Encode the query
         query_emb = _query_model.encode(
             [f"query: {query.strip()}"],
             convert_to_numpy=True,
             normalize_embeddings=True
         )[0]
 
-        # Retrieve more and then re-rank
+        # Step 2: Initial FAISS retrieval
         distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)
+
+        # Step 3: Merge neighboring chunks
         merged_chunks = []
         for idx in indices[0]:
             neighbors = [chunks[i] for i in range(max(0, idx - 1), min(len(chunks), idx + 2))]
             merged_chunks.append(" ".join(neighbors))
 
-        # Re-rank by cosine similarity
+        # Step 4: Re-rank using cosine similarity
         chunk_vecs = np.array([
             _query_model.encode([c], convert_to_numpy=True, normalize_embeddings=True)[0]
             for c in merged_chunks
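Reviewer note: one edge case Step 3 leaves open is that when the index holds fewer than `top_k * 2` vectors, FAISS pads `indices` with `-1`, and adjacent hits can produce identical merged windows. A hedged guard sketch reusing the loop's own names (not part of this commit):

```python
# Skip FAISS's -1 "no result" padding and de-duplicate neighborhood windows.
seen_windows = set()
merged_chunks = []
for idx in indices[0]:
    if idx < 0:  # FAISS returns -1 when fewer results exist than requested
        continue
    window = (max(0, idx - 1), min(len(chunks), idx + 2))
    if window in seen_windows:  # nearby hits can yield the same merged window
        continue
    seen_windows.add(window)
    merged_chunks.append(" ".join(chunks[window[0]:window[1]]))
```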
@@ -110,32 +116,36 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
         scores = cosine_similarity(np.array([query_emb]), chunk_vecs)[0]
         sorted_indices = np.argsort(scores)[::-1]
 
+        # Step 5: Return top-ranked merged chunks
         return [merged_chunks[i] for i in sorted_indices[:top_k]]
 
     except Exception as e:
         print(f"⚠️ Retrieval error: {e}")
         return []
 
+
 # ==========================================================
-# 7️⃣ Answer Generation
+# 6️⃣ Answer Generation Function
 # ==========================================================
 def generate_answer(query: str, retrieved_chunks: list):
-    """Generate factual answer using OpenAI GPT-4o-mini (preferred) or FLAN fallback."""
+    """Generate factual, context-grounded answers using OpenAI or fallback FLAN-T5."""
     if not retrieved_chunks:
         return "Sorry, I couldn’t find relevant information in the document."
 
-    # Merge retrieved chunks
-    context = "\n\n".join(
-        [f"[Chunk {i+1}]: {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks)]
-    )
+    # Build full context
+    context = "\n\n".join([
+        f"[Chunk {i+1}]: {chunk.strip()}"
+        for i, chunk in enumerate(retrieved_chunks)
+    ])
     prompt = PROMPT_TEMPLATE.format(context=context, query=query)
 
-    try:
-        if USE_OPENAI:
+    # --- Try OpenAI first ---
+    if USE_OPENAI:
+        try:
             response = client.chat.completions.create(
                 model="gpt-4o-mini",
                 messages=[
-                    {"role": "system", "content": "You are a precise enterprise assistant that answers only from the provided context."},
+                    {"role": "system", "content": "You are a precise enterprise document assistant."},
                     {"role": "user", "content": prompt},
                 ],
                 temperature=0.4,
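Reviewer note: since every vector here is encoded with `normalize_embeddings=True`, cosine similarity reduces to a dot product, so the per-chunk `encode` loop and the sklearn call in Step 4 could be collapsed into one batched call. An equivalent drop-in sketch reusing the function's names, not what the commit ships:

```python
# One batched encode instead of a Python loop; cosine == dot for unit vectors.
chunk_vecs = _query_model.encode(
    merged_chunks, convert_to_numpy=True, normalize_embeddings=True
)
scores = chunk_vecs @ query_emb
sorted_indices = np.argsort(scores)[::-1]
```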
@@ -143,16 +153,28 @@ def generate_answer(query: str, retrieved_chunks: list):
             )
             return response.choices[0].message.content.strip()
 
-        else:
-            result = _answer_model(prompt, max_new_tokens=600, do_sample=False, temperature=0.3)
-            return result[0]["generated_text"].strip()
+        except Exception as e:
+            print(f"⚠️ OpenAI generation failed: {e}. Switching to fallback...")
 
+    # --- Fallback to FLAN-T5 ---
+    try:
+        if _answer_model:
+            result = _answer_model(
+                prompt,
+                max_new_tokens=600,
+                do_sample=False,
+                temperature=0.3
+            )
+            return result[0]["generated_text"].strip()
+        else:
+            return "⚠️ Error: Fallback model not available."
     except Exception as e:
-        print(f"⚠️ Generation failed: {e}")
-        return "⚠️ Error: Could not generate an answer at the moment."
+        print(f"⚠️ Fallback model failed: {e}")
+        return "⚠️ Error: Both OpenAI and fallback generation failed."
+
 
 # ==========================================================
-# 8️⃣ Local Test
+# 7️⃣ Local Test
 # ==========================================================
 if __name__ == "__main__":
     dummy_chunks = [
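Reviewer note: in the fallback call, `do_sample=False` makes decoding greedy, so `temperature=0.3` has no effect (recent `transformers` releases warn about exactly this combination); dropping it is behavior-preserving. Worth knowing too that a prompt built from five merged chunks can exceed flan-t5-base's 512-token training length, so expect tokenizer length warnings on long documents.

```python
# Equivalent fallback call without the unused sampling parameter.
result = _answer_model(prompt, max_new_tokens=600, do_sample=False)
```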
@@ -163,7 +185,11 @@ if __name__ == "__main__":
     from vectorstore import build_faiss_index
 
     index = build_faiss_index([
-        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+        _query_model.encode(
+            [f"passage: {chunk}"],
+            convert_to_numpy=True,
+            normalize_embeddings=True
+        )[0]
         for chunk in dummy_chunks
     ])
 
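For completeness, a minimal end-to-end sketch of how the two public functions compose. It mirrors the local test above; the two chunks and the question are made up, and `build_faiss_index` is assumed to take a list of vectors, as it does in that test:

```python
from vectorstore import build_faiss_index
from qa import _query_model, retrieve_chunks, generate_answer

chunks = [
    "Employees accrue 20 days of paid leave per year.",
    "Unused leave may be carried over once, capped at 10 days.",
]
index = build_faiss_index([
    _query_model.encode([f"passage: {c}"], convert_to_numpy=True,
                        normalize_embeddings=True)[0]
    for c in chunks
])

question = "How much leave can be carried over?"
print(generate_answer(question, retrieve_chunks(question, index, chunks)))
```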