Shubham170793 committed
Commit c28ff15 · verified · 1 Parent(s): 16a0f13

Update src/qa.py

Files changed (1)
  1. src/qa.py +68 -118
src/qa.py CHANGED
@@ -47,147 +47,97 @@ except Exception as e:
     print("✅ Loaded fallback model: all-MiniLM-L6-v2")
 
 # ==========================================================
-# 3️⃣ LLM for Answer Generation (FLAN-T5)
+# 3️⃣ LLM for Answer Generation (OpenAI GPT with Flan fallback)
 # ==========================================================
-MODEL_NAME = "google/flan-t5-base"  # Switch to 'large' if you have more memory
-print(f"✅ Loading LLM: {MODEL_NAME}")
-
-_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
-
-_answer_model = pipeline(
-    "text2text-generation",
-    model=_model,
-    tokenizer=_tokenizer,
-    device=-1  # CPU-safe
-)
-
-# ==========================================================
-# 4️⃣ Prompt Template
-# ==========================================================
-PROMPT_TEMPLATE = """
-You are an expert enterprise knowledge assistant.
-Use ONLY the CONTEXT below to answer the QUESTION clearly, completely, and factually.
-If the context doesn’t contain the answer, reply exactly:
-"I don't know based on the provided document."
-
----
-Context:
-{context}
----
-Question:
-{query}
----
-Answer:
-"""
-
-# ==========================================================
-# 5️⃣ Chunk Retrieval Function (Improved for Large Docs)
-# ==========================================================
-def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
-    """
-    Retrieve top-K relevant chunks and merge nearby ones for context continuity.
-    Re-ranks using cosine similarity to improve semantic precision.
-    """
-    if not index or not chunks:
-        return []
-
-    try:
-        # Step 1: Encode query
-        query_emb = _query_model.encode(
-            [f"query: {query.strip()}"],
-            convert_to_numpy=True,
-            normalize_embeddings=True
-        )[0]
-
-        # Step 2: Initial FAISS retrieval
-        distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)
-
-        # Step 3: Merge neighbors for more complete context
-        merged_chunks = []
-        for idx in indices[0]:
-            neighbors = [chunks[i] for i in range(max(0, idx - 1), min(len(chunks), idx + 2))]
-            merged_chunks.append(" ".join(neighbors))
-
-        # Step 4: Re-rank results with cosine similarity
-        chunk_vecs = np.array([
-            _query_model.encode([c], convert_to_numpy=True, normalize_embeddings=True)[0]
-            for c in merged_chunks
-        ])
-        scores = cosine_similarity(np.array([query_emb]), chunk_vecs)[0]
-        sorted_indices = np.argsort(scores)[::-1]
-
-        # Step 5: Return top ranked chunks
-        return [merged_chunks[i] for i in sorted_indices[:top_k]]
-
-    except Exception as e:
-        print(f"⚠️ Retrieval error: {e}")
-        return []
+from openai import OpenAI
+client = None
+
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+if OPENAI_API_KEY:
+    client = OpenAI(api_key=OPENAI_API_KEY)
+    LLM_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
+    print(f"✅ Using OpenAI model: {LLM_MODEL}")
+else:
+    # Fallback to Flan if no API key is provided
+    MODEL_NAME = "google/flan-t5-base"
+    print(f"⚠️ No OpenAI key found. Using fallback model: {MODEL_NAME}")
+    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+    _answer_model = pipeline(
+        "text2text-generation",
+        model=_model,
+        tokenizer=_tokenizer,
+        device=-1
+    )
 
 
 # ==========================================================
-# 6️⃣ Answer Generation Function (Long + Structured)
+# 6️⃣ Answer Generation Function (GPT or Flan fallback)
 # ==========================================================
 def generate_answer(query: str, retrieved_chunks: list):
     """
-    Generates a well-structured answer using FLAN-T5.
-    - Supports multi-step reasoning (if context mentions steps or procedures)
-    - Ensures completeness for large-document answers
+    Generates grounded, context-only answers.
+    Uses GPT (preferred) or Flan-T5 (fallback) for response synthesis.
     """
     if not retrieved_chunks:
         return "Sorry, I couldn’t find relevant information in the document."
 
-    # Merge retrieved chunks
+    # Combine retrieved chunks
     context = "\n\n".join([
-        f"[Chunk {i+1}]: {chunk.strip()}"
-        for i, chunk in enumerate(retrieved_chunks)
+        f"[Chunk {i+1}]: {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks)
     ])
 
-    prompt = PROMPT_TEMPLATE.format(context=context, query=query)
+    # --- PROMPT TEMPLATE ---
+    system_prompt = """You are an enterprise knowledge assistant.
+Use ONLY the provided context to answer the user's question accurately.
+If the answer is not explicitly in the context, reply exactly:
+"I don't know based on the provided document."
+Be factual, concise, and structured when relevant.
+"""
+
+    user_prompt = f"""
+Context:
+{context}
+
+Question:
+{query}
+
+Answer:
+"""
 
+    # --- Use OpenAI GPT if key available ---
+    if client:
+        try:
+            response = client.chat.completions.create(
+                model=LLM_MODEL,
+                messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt},
+                ],
+                temperature=0.2,  # factual, low creativity
+                max_tokens=500,
+                presence_penalty=0,
+                frequency_penalty=0
+            )
+            answer = response.choices[0].message.content.strip()
+            return answer
+        except Exception as e:
+            print(f"⚠️ OpenAI generation failed: {e}")
+            return "⚠️ Error: Could not generate an answer at the moment."
+
+    # --- Otherwise, use Flan-T5 fallback ---
     try:
         result = _answer_model(
-            prompt,
-            max_new_tokens=600,   # allow multi-step responses
-            do_sample=False,      # deterministic for factual QA
+            system_prompt + user_prompt,  # reuse the prompts built above (PROMPT_TEMPLATE was removed)
+            max_new_tokens=600,
+            do_sample=False,
             temperature=0.3,
             repetition_penalty=1.1
         )
-
         answer = result[0]["generated_text"].strip()
-
-        # Safety filter: ensure the model doesn’t hallucinate
         if "I don't know" in answer:
             return "I don't know based on the provided document."
-
        return answer
-
     except Exception as e:
-        print(f"⚠️ Generation failed: {e}")
+        print(f"⚠️ Flan generation failed: {e}")
         return "⚠️ Error: Could not generate an answer at the moment."
-
-
-# ==========================================================
-# 7️⃣ Optional Local Test
-# ==========================================================
-if __name__ == "__main__":
-    dummy_chunks = [
-        "Step 1: Open the dashboard and navigate to reports.",
-        "Step 2: Click 'Export' to download a CSV summary.",
-        "Step 3: Review the generated report in your downloads folder."
-    ]
-    from vectorstore import build_faiss_index
-
-    index = build_faiss_index([
-        _query_model.encode(
-            [f"passage: {chunk}"],
-            convert_to_numpy=True,
-            normalize_embeddings=True
-        )[0]
-        for chunk in dummy_chunks
-    ])
-
-    query = "What are the steps to export a report?"
-    retrieved = retrieve_chunks(query, index, dummy_chunks)
-    print("🔍 Retrieved:", retrieved)
-    print("💬 Answer:", generate_answer(query, retrieved))