Shubham170793 committed
Commit 2f0b456 · verified · 1 Parent(s): a0dee9a

Update src/qa.py

Files changed (1)
  1. src/qa.py +22 -6
src/qa.py CHANGED
@@ -104,23 +104,39 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
 def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = True):
     """
     Generates answers using Phi-2.
-    reasoning_mode=True → reasoning + external knowledge
+    reasoning_mode=True → reasoning + external knowledge
     reasoning_mode=False → strict chunk-only factual mode
     """
     if not retrieved_chunks:
         return "Sorry, I couldn’t find relevant information in the document."
 
+    # Merge retrieved context
     context = "\n".join([chunk.strip() for chunk in retrieved_chunks])
-    prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(context=context, query=query)
+
+    # Select prompt based on mode
+    prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
+        context=context, query=query
+    )
 
     try:
+        # ⚡ Speed-optimized generation
         result = _answer_model(
             prompt,
-            max_new_tokens=180,
-            temperature=0.4 if reasoning_mode else 0.2,
-            do_sample=False,
+            max_new_tokens=140 if reasoning_mode else 100,  # ⏱ shorter output = faster
+            temperature=0.3 if reasoning_mode else 0.1,     # balanced creativity
+            do_sample=False,                                # ✅ greedy decoding = fastest
+            repetition_penalty=1.1,                         # avoids repetitive phrasing
         )
-        return result[0]["generated_text"].split("ANSWER:")[-1].strip()
+
+        # Cleanly extract the answer
+        answer = result[0]["generated_text"].split("ANSWER:")[-1].strip()
+
+        # Safety: truncate overly long rambles
+        if len(answer.split()) > 150:
+            answer = " ".join(answer.split()[:150]) + "..."
+
+        return answer
+
     except Exception as e:
         print(f"⚠️ Generation failed: {e}")
         return "⚠️ Error: Could not generate an answer."