Update src/qa.py
src/qa.py
CHANGED
@@ -31,7 +31,6 @@ os.environ.update({
 # ==========================================================
 # 2️⃣ Query Embedding Model
 # ==========================================================
-# Use E5-small-v2 for retrieval consistency with embeddings.py
 try:
     _query_model = SentenceTransformer(
         "intfloat/e5-small-v2",
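Worth noting for reviewers: e5-small-v2 is an asymmetric retrieval model, which is why a `passage:` prefix appears for documents and a `query:` prefix for queries later in this diff. A minimal sketch of that convention, assuming only the model name shown above (the sample strings are illustrative):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/e5-small-v2")

# E5 models are trained with role prefixes; mixing them up degrades recall.
doc_embs = model.encode(
    ["passage: SAP ERP supports real-time data synchronization."],
    convert_to_numpy=True,
    normalize_embeddings=True,  # unit vectors -> inner product == cosine
)
query_emb = model.encode(
    ["query: Does the system integrate with SAP?"],
    convert_to_numpy=True,
    normalize_embeddings=True,
)
score = float(doc_embs[0] @ query_emb[0])  # cosine similarity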
@@ -49,7 +48,7 @@ except Exception as e:
 # ==========================================================
 # 3️⃣ LLM for Answer Generation (FLAN-T5)
 # ==========================================================
-MODEL_NAME = "google/flan-t5-base"  #
+MODEL_NAME = "google/flan-t5-base"  # Switch to 'large' if enough memory
 print(f"✅ Loading LLM: {MODEL_NAME}")

 _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
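The new MODEL_NAME comment suggests swapping in flan-t5-large when memory allows. A hedged sketch of how that swap could be made configurable; the QA_USE_LARGE_LLM environment variable is invented for illustration and is not part of this repo:

import os
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# Illustrative only: choose the checkpoint by an env flag instead of editing code.
MODEL_NAME = (
    "google/flan-t5-large"
    if os.environ.get("QA_USE_LARGE_LLM") == "1"
    else "google/flan-t5-base"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
answerer = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 = CPU; use 0 for the first CUDA GPU
)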
@@ -59,15 +58,15 @@ _answer_model = pipeline(
     "text2text-generation",
     model=_model,
     tokenizer=_tokenizer,
-    device=-1  # CPU-safe
+    device=-1  # CPU-safe (Hugging Face Spaces)
 )

 # ==========================================================
-# 4️⃣ Prompt Template
+# 4️⃣ Prompt Template
 # ==========================================================
 PROMPT_TEMPLATE = """
-You are an expert enterprise assistant.
-
+You are an expert enterprise knowledge assistant.
+Use ONLY the CONTEXT below to answer the QUESTION clearly, factually, and completely.
 If the context doesn’t contain the answer, reply exactly:
 "I don't know based on the provided document."

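The diff truncates PROMPT_TEMPLATE after its opening lines, but the later `PROMPT_TEMPLATE.format(context=..., query=...)` call implies `{context}` and `{query}` placeholders. A plausible full template under that assumption; the CONTEXT/QUESTION/ANSWER section labels are guesses, not from the source:

# Hypothetical reconstruction of the truncated template; only the opening
# lines and the two placeholders are confirmed by this diff.
PROMPT_TEMPLATE = """
You are an expert enterprise knowledge assistant.
Use ONLY the CONTEXT below to answer the QUESTION clearly, factually, and completely.
If the context doesn't contain the answer, reply exactly:
"I don't know based on the provided document."

CONTEXT:
{context}

QUESTION:
{query}

ANSWER:
"""

prompt = PROMPT_TEMPLATE.format(context="[Chunk 1]: ...", query="...")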
@@ -93,7 +92,6 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
         return []

     try:
-        # E5 expects 'query:' prefix for better retrieval accuracy
         query_emb = _query_model.encode(
             [f"query: {query.strip()}"],
             convert_to_numpy=True,
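The body of retrieve_chunks after the encode call is not shown in this diff. Assuming the project uses a FAISS inner-product index over normalized embeddings (consistent with normalize_embeddings=True elsewhere in the file), the search step would typically look like the sketch below; the helper name and signature are illustrative, not the repo's actual code:

import numpy as np
import faiss  # assumed dependency; the real search code is elided from this diff

def search(index: faiss.Index, query_emb: np.ndarray, chunks: list, top_k: int = 3):
    # query_emb: shape (1, dim), already L2-normalized, so inner product == cosine
    scores, ids = index.search(query_emb.astype(np.float32), top_k)
    # FAISS pads with -1 when fewer than top_k results exist
    return [chunks[i] for i in ids[0] if i != -1]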
@@ -114,45 +112,47 @@ def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
 def generate_answer(query: str, retrieved_chunks: list):
     """
     Generates an answer using FLAN-T5 and retrieved chunks as context.
+    Includes dynamic length, sampling for expressiveness, and fallback logic.
     """
     if not retrieved_chunks:
         return "Sorry, I couldn’t find relevant information in the document."

-    # Merge retrieved chunks
-    context = "\n\n".join([
+    # Merge retrieved chunks into one coherent context
+    context = "\n\n".join([
+        f"[Chunk {i+1}]: {chunk.strip()}"
+        for i, chunk in enumerate(retrieved_chunks)
+    ])

-    # Build structured prompt
     prompt = PROMPT_TEMPLATE.format(context=context, query=query)

     try:
-        answer = result[0]["generated_text"].strip()
-
-        # 🧩 If the model outputs something too short, expand gracefully
-        if len(answer.split()) < 8:
-            answer = (
-                "The document mentions this briefly. Based on the context, here's what it suggests: "
-                + answer
-            )
+        result = _answer_model(
+            prompt,
+            max_new_tokens=400,      # allow more elaborate responses
+            do_sample=True,          # enable natural variability
+            temperature=0.7,         # creativity balance
+            top_p=0.9,               # nucleus sampling for relevance
+            repetition_penalty=1.15  # discourage repetition
+        )
+
+        answer = result[0]["generated_text"].strip()
+
+        # 🧩 Handle overly short answers
+        if len(answer.split()) < 8:
+            answer = (
+                "The document briefly mentions this. Based on the context, here's what it implies: "
+                + answer
+            )
+
+        return answer
+
+    except Exception as e:
+        print(f"⚠️ Generation failed: {e}")
+        return "⚠️ Error: Could not generate an answer at the moment."

 # ==========================================================
-# 7️⃣ Optional Local Test
+# 7️⃣ Optional Local Test (runs only in dev mode)
 # ==========================================================
 if __name__ == "__main__":
     dummy_chunks = [
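One design note on the new generation call: do_sample=True with temperature/top_p makes answers non-deterministic, which can be awkward for regression tests on a QA endpoint. A deterministic variant using the same pipeline, sketched under the assumption that reproducibility matters more than variety:

# Hypothetical deterministic alternative to the sampled call above.
result = _answer_model(
    prompt,
    max_new_tokens=400,
    do_sample=False,          # reproducible output for the same prompt
    num_beams=4,              # beam search instead of sampling
    repetition_penalty=1.15,  # still discourage repetition
)
answer = result[0]["generated_text"].strip()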
@@ -161,10 +161,13 @@ if __name__ == "__main__":
|
|
| 161 |
"Integration with SAP ERP allows for seamless data synchronization."
|
| 162 |
]
|
| 163 |
from vectorstore import build_faiss_index
|
| 164 |
-
import numpy as np
|
| 165 |
|
| 166 |
index = build_faiss_index([
|
| 167 |
-
_query_model.encode(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
for chunk in dummy_chunks
|
| 169 |
])
|
| 170 |
|
|
|
|
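Putting the new pieces together, the local test implies an end-to-end flow like the sketch below. It reuses index and dummy_chunks from the test block above, assumes build_faiss_index returns a searchable index as its call site suggests, and the query string is illustrative:

# Assumed end-to-end flow based on the functions shown in this diff.
query = "Does the system integrate with SAP ERP?"
top_chunks = retrieve_chunks(query, index, dummy_chunks, top_k=2)
print(generate_answer(query, top_chunks))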