Update src/qa.py
Browse files
src/qa.py
CHANGED
|
@@ -112,27 +112,45 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
|
|
| 112 |
return []
|
| 113 |
|
| 114 |
# ==========================================================
|
| 115 |
-
# 6️⃣ Answer Generation (
|
| 116 |
# ==========================================================
|
| 117 |
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
|
| 118 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
if not retrieved_chunks:
|
| 120 |
return "Sorry, I couldn’t find relevant information in the document."
|
| 121 |
|
|
|
|
| 122 |
context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
|
| 123 |
-
prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
try:
|
|
|
|
| 126 |
result = _answer_model(
|
| 127 |
prompt,
|
| 128 |
-
max_new_tokens=
|
| 129 |
-
temperature=0.
|
| 130 |
-
|
|
|
|
|
|
|
| 131 |
pad_token_id=_tokenizer.eos_token_id,
|
| 132 |
)
|
|
|
|
|
|
|
| 133 |
answer = result[0]["generated_text"].strip()
|
| 134 |
if "Answer:" in answer:
|
| 135 |
answer = answer.split("Answer:")[-1].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
return answer
|
| 137 |
|
| 138 |
except Exception as e:
|
|
|
|
| 112 |
return []
|
| 113 |
|
| 114 |
# ==========================================================
|
| 115 |
+
# 6️⃣ Answer Generation (Enhanced — Balanced Reasoning + Speed)
|
| 116 |
# ==========================================================
|
| 117 |
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
|
| 118 |
+
"""
|
| 119 |
+
Generates answers with or without reasoning.
|
| 120 |
+
- reasoning_mode=True → uses world knowledge + logic (slower, more explanatory)
|
| 121 |
+
- reasoning_mode=False → sticks to chunks for fast factual accuracy
|
| 122 |
+
"""
|
| 123 |
if not retrieved_chunks:
|
| 124 |
return "Sorry, I couldn’t find relevant information in the document."
|
| 125 |
|
| 126 |
+
# Build the prompt with selected mode
|
| 127 |
context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
|
| 128 |
+
prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
|
| 129 |
+
context=context,
|
| 130 |
+
query=query
|
| 131 |
+
)
|
| 132 |
|
| 133 |
try:
|
| 134 |
+
# ⚙️ Tuned generation settings for balance
|
| 135 |
result = _answer_model(
|
| 136 |
prompt,
|
| 137 |
+
max_new_tokens=180 if reasoning_mode else 120, # let reasoning finish sentences
|
| 138 |
+
temperature=0.6 if reasoning_mode else 0.2, # more creative but controlled
|
| 139 |
+
top_p=0.9 if reasoning_mode else 0.8, # smooth probability cutoff
|
| 140 |
+
do_sample=reasoning_mode, # only sample when reasoning
|
| 141 |
+
early_stopping=True,
|
| 142 |
pad_token_id=_tokenizer.eos_token_id,
|
| 143 |
)
|
| 144 |
+
|
| 145 |
+
# Clean answer text
|
| 146 |
answer = result[0]["generated_text"].strip()
|
| 147 |
if "Answer:" in answer:
|
| 148 |
answer = answer.split("Answer:")[-1].strip()
|
| 149 |
+
|
| 150 |
+
# ✅ Prevents mid-sentence cutoffs
|
| 151 |
+
if answer.endswith(("and", "or", ",")):
|
| 152 |
+
answer += " ..."
|
| 153 |
+
|
| 154 |
return answer
|
| 155 |
|
| 156 |
except Exception as e:
|