Shubham170793 committed
Commit 743f89e · verified · Parent(s): 09c2f03

Update src/qa.py

Files changed (1): src/qa.py (+36 -33)
src/qa.py CHANGED
@@ -31,7 +31,6 @@ os.environ.update({
 # ==========================================================
 # 2️⃣ Query Embedding Model
 # ==========================================================
-# Use E5-small-v2 for retrieval consistency with embeddings.py
 try:
     _query_model = SentenceTransformer(
         "intfloat/e5-small-v2",
@@ -49,7 +48,7 @@ except Exception as e:
 # ==========================================================
 # 3️⃣ LLM for Answer Generation (FLAN-T5)
 # ==========================================================
-MODEL_NAME = "google/flan-t5-base"  # switch to 'large' if RAM allows
+MODEL_NAME = "google/flan-t5-base"  # Switch to 'large' if enough memory
 print(f"✅ Loading LLM: {MODEL_NAME}")
 
 _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
@@ -59,15 +58,15 @@ _answer_model = pipeline(
     "text2text-generation",
     model=_model,
     tokenizer=_tokenizer,
-    device=-1  # CPU-safe for Spaces
+    device=-1  # CPU-safe (Hugging Face Spaces)
 )
 
 # ==========================================================
-# 4️⃣ Prompt Template (concise and factual)
+# 4️⃣ Prompt Template
 # ==========================================================
 PROMPT_TEMPLATE = """
-You are an expert enterprise assistant.
-Using ONLY the CONTEXT below, answer the QUESTION clearly and factually.
+You are an expert enterprise knowledge assistant.
+Use ONLY the CONTEXT below to answer the QUESTION clearly, factually, and completely.
 If the context doesn’t contain the answer, reply exactly:
 "I don't know based on the provided document."
 
@@ -93,7 +92,6 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
         return []
 
     try:
-        # E5 expects 'query:' prefix for better retrieval accuracy
         query_emb = _query_model.encode(
             [f"query: {query.strip()}"],
             convert_to_numpy=True,
@@ -114,45 +112,47 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
 def generate_answer(query: str, retrieved_chunks: list):
     """
     Generates an answer using FLAN-T5 and retrieved chunks as context.
+    Includes dynamic length, sampling for expressiveness, and fallback logic.
     """
     if not retrieved_chunks:
         return "Sorry, I couldn’t find relevant information in the document."
 
-    # Merge retrieved chunks for context
-    context = "\n\n".join([f"[Chunk {i+1}]: {chunk}" for i, chunk in enumerate(retrieved_chunks)])
+    # Merge retrieved chunks into one coherent context
+    context = "\n\n".join([
+        f"[Chunk {i+1}]: {chunk.strip()}"
+        for i, chunk in enumerate(retrieved_chunks)
+    ])
 
-    # Build structured prompt
     prompt = PROMPT_TEMPLATE.format(context=context, query=query)
 
     try:
-        result = _answer_model(
-            prompt,
-            max_new_tokens=350,      # allow longer, more complete answers
-            do_sample=True,          # enable sampling for natural flow
-            temperature=0.7,         # slightly higher = more expressive responses
-            top_p=0.95,              # nucleus sampling for coherence
-            repetition_penalty=1.2   # discourages repetitive phrasing
-        )
-
-        answer = result[0]["generated_text"].strip()
-
-        # 🧩 If the model outputs something too short, expand gracefully
-        if len(answer.split()) < 8:
-            answer = (
-                "The document mentions this briefly. Based on the context, here's what it suggests: "
-                + answer
+        result = _answer_model(
+            prompt,
+            max_new_tokens=400,       # allow more elaborate responses
+            do_sample=True,           # enable natural variability
+            temperature=0.7,          # creativity balance
+            top_p=0.9,                # nucleus sampling for relevance
+            repetition_penalty=1.15   # discourage repetition
        )
 
-        return answer
+        answer = result[0]["generated_text"].strip()
 
-    except Exception as e:
-        print(f"⚠️ Generation failed: {e}")
-        return "⚠️ Error: Could not generate an answer at the moment."
+        # 🧩 Handle overly short answers
+        if len(answer.split()) < 8:
+            answer = (
+                "The document briefly mentions this. Based on the context, here's what it implies: "
+                + answer
+            )
 
+        return answer
+
+    except Exception as e:
+        print(f"⚠️ Generation failed: {e}")
+        return "⚠️ Error: Could not generate an answer at the moment."
 
 
 # ==========================================================
-# 7️⃣ Optional Local Test
+# 7️⃣ Optional Local Test (runs only in dev mode)
 # ==========================================================
 if __name__ == "__main__":
     dummy_chunks = [
@@ -161,10 +161,13 @@ if __name__ == "__main__":
         "Integration with SAP ERP allows for seamless data synchronization."
     ]
     from vectorstore import build_faiss_index
-    import numpy as np
 
     index = build_faiss_index([
-        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
+        _query_model.encode(
+            [f"passage: {chunk}"],
+            convert_to_numpy=True,
+            normalize_embeddings=True
+        )[0]
         for chunk in dummy_chunks
     ])
 
 
 
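For context, the functions touched by this commit form a small retrieve-then-generate pipeline. Below is a minimal usage sketch mirroring the __main__ block in the diff. It assumes qa.py and vectorstore.py are importable (e.g. run from src/), that build_faiss_index accepts a list of normalized embedding vectors as the test block implies, and that the sample chunks and question are hypothetical stand-ins.

# Usage sketch (hypothetical data; run from src/ so qa.py and
# vectorstore.py are on the import path).
from qa import retrieve_chunks, generate_answer, _query_model
from vectorstore import build_faiss_index

chunks = [
    "Enterprise plans include SSO and audit logging.",
    "Integration with SAP ERP allows for seamless data synchronization.",
]

# E5 is asymmetric: passages are embedded with a 'passage: ' prefix,
# while queries get a 'query: ' prefix (added inside retrieve_chunks).
index = build_faiss_index([
    _query_model.encode(
        [f"passage: {chunk}"],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]
    for chunk in chunks
])

question = "Which ERP systems does the platform integrate with?"
top_chunks = retrieve_chunks(question, index, chunks, top_k=2)
print(generate_answer(question, top_chunks))

Since both query and passage embeddings are normalized, build_faiss_index presumably builds an inner-product (cosine) index; that detail lives in vectorstore.py and is outside this diff.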