# Page residue from the Hugging Face file view: author Shubham170793,
# commit 09c2f03 "Update src/qa.py" (verified), file size 5.87 kB.
"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
• Chunk retrieval (FAISS)
• Answer generation (Flan-T5)
Optimized for Hugging Face Spaces & Streamlit.
"""
import os
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from vectorstore import search_faiss
print("✅ qa.py loaded from:", __file__)

# ==========================================================
# 1️⃣ Hugging Face Cache Setup (Safe for Spaces)
# ==========================================================
# A single directory under /tmp backs every Hugging Face cache; presumably
# /tmp is chosen because it is writable on Spaces — confirm for other hosts.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# NOTE(review): sentence_transformers / transformers are imported above this
# point, so any setting those libraries read at import time may not see these
# values. The explicit cache_folder= / cache_dir= arguments used when loading
# models below are what reliably route downloads into CACHE_DIR.
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "HF_MODULES_CACHE"),
        CACHE_DIR,
    )
)
# ==========================================================
# 2️⃣ Query Embedding Model
# ==========================================================
# The preferred encoder is E5-small-v2, kept consistent with embeddings.py.
# If loading it raises any exception, MiniLM is loaded instead so that this
# module still imports.
# NOTE(review): vectors from the MiniLM fallback are probably not comparable
# with an index built from E5 embeddings (different model, possibly different
# dimension) — confirm what embeddings.py does when the fallback triggers.
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
    print("✅ Loaded query model: intfloat/e5-small-v2")
except Exception as load_error:
    print(f"⚠️ Query model load failed ({load_error}), falling back to MiniLM.")
    _query_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR
    )
    print("✅ Loaded fallback model: all-MiniLM-L6-v2")
# ==========================================================
# 3️⃣ LLM for Answer Generation (FLAN-T5)
# ==========================================================
MODEL_NAME = "google/flan-t5-base"  # switch to 'large' if RAM allows
print(f"✅ Loading LLM: {MODEL_NAME}")

# Tokenizer first, then the seq2seq weights, both fetched into CACHE_DIR.
_tokenizer, _model = (
    AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR),
    AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR),
)

# Wrap both in a text2text-generation pipeline; device=-1 pins it to CPU.
_answer_model = pipeline(
    task="text2text-generation",
    model=_model,
    tokenizer=_tokenizer,
    device=-1,  # CPU-safe for Spaces
)
# ==========================================================
# 4️⃣ Prompt Template (concise and factual)
# ==========================================================
# Filled with str.format(context=..., query=...) in generate_answer(), so any
# literal braces added to this text must be doubled ("{{" / "}}").
# The quoted refusal sentence is the text the model is told to emit when the
# context lacks the answer; keep it in sync with any code that checks for it.
PROMPT_TEMPLATE = """
You are an expert enterprise assistant.
Using ONLY the CONTEXT below, answer the QUESTION clearly and factually.
If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."
---
Context:
{context}
---
Question:
{query}
---
Answer:
"""
# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """
    Embed *query* and return the top-k most similar chunks via FAISS.

    The query is prefixed with ``"query: "`` (the E5 convention for queries;
    passages use ``"passage: "``) and L2-normalised before being passed to
    ``search_faiss``.

    Args:
        query: The user's question. A blank or non-string query returns ``[]``.
        index: Index object handed to ``search_faiss``; a falsy value (e.g.
            ``None``) returns ``[]``.
        chunks: Chunk texts, handed to ``search_faiss`` together with *index*;
            an empty list returns ``[]``.
        top_k: Number of results requested from ``search_faiss``.

    Returns:
        Whatever ``search_faiss`` returns. Returns ``[]`` when inputs are missing,
        and also ``[]`` on any encoding or search error, after printing a warning.
    """
    # Truthiness check: for index objects that define neither __len__ nor
    # __bool__ this is effectively an `is None` test.
    if not index or not chunks:
        return []

    # A blank query would embed as the bare "query: " prefix and retrieve
    # arbitrary chunks, so treat it as "nothing to retrieve".
    if not isinstance(query, str) or not query.strip():
        return []

    try:
        # E5 expects the 'query:' prefix for better retrieval accuracy.
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        return search_faiss(query_emb, index, chunks, top_k)
    except Exception as e:
        # Best-effort boundary: retrieval failures degrade to "no results".
        print(f"⚠️ Retrieval error: {e}")
        return []
# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def _finalize_answer(answer: str) -> str:
    """
    Post-process stripped model output before it is returned to the user.

    * Empty output falls back to the refusal sentence the prompt asks for.
    * Refusals ("I don't know ..." / "I do not know ...") are returned
      unchanged, so they never get a lead-in claiming the document covers
      the topic.
    * Any other answer shorter than 8 words gets an explanatory lead-in.
    """
    if not answer:
        return "I don't know based on the provided document."

    # Normalise the typographic apostrophe so "don’t" matches "don't".
    normalized = answer.lower().replace("\u2019", "'")
    if "don't know" in normalized or "do not know" in normalized:
        return answer

    if len(answer.split()) < 8:
        return (
            "The document mentions this briefly. Based on the context, "
            "here's what it suggests: " + answer
        )
    return answer


def generate_answer(query: str, retrieved_chunks: list):
    """
    Generate an answer to *query* with FLAN-T5, using *retrieved_chunks* as context.

    Each chunk is rendered as ``[Chunk N]: <chunk>`` (1-based), joined with
    blank lines, and substituted into ``PROMPT_TEMPLATE``. Decoding uses
    sampling, so repeated calls may return different answers.

    Args:
        query: The user's question.
        retrieved_chunks: Items from ``retrieve_chunks``. Each one is
            formatted with an f-string.

    Returns:
        The post-processed model answer. If *retrieved_chunks* is empty,
        returns a "Sorry ..." message. If generation fails, returns an error
        message. Errors raised during generation are printed, not re-raised.
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    # Merge retrieved chunks into a numbered context block.
    context = "\n\n".join(
        f"[Chunk {i}]: {chunk}" for i, chunk in enumerate(retrieved_chunks, start=1)
    )
    # NOTE(review): FLAN-T5 is typically used with inputs of about 512 tokens,
    # and the prompt is not truncated here. Confirm that answer quality holds
    # for long chunks or a larger top_k.
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    try:
        result = _answer_model(
            prompt,
            max_new_tokens=350,      # allow longer, more complete answers
            do_sample=True,          # enable sampling for natural flow
            temperature=0.7,         # slightly higher = more expressive responses
            top_p=0.95,              # nucleus sampling for coherence
            repetition_penalty=1.2,  # discourages repetitive phrasing
        )
        answer = result[0]["generated_text"].strip()
    except Exception as e:
        # Best-effort boundary: report the failure and return a user-facing message.
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."

    return _finalize_answer(answer)
# ==========================================================
# 7️⃣ Optional Local Test
# ==========================================================
if __name__ == "__main__":
    from vectorstore import build_faiss_index

    dummy_chunks = [
        "SAP Ariba is a cloud-based procurement solution.",
        "It helps companies manage suppliers and sourcing processes efficiently.",
        "Integration with SAP ERP allows for seamless data synchronization.",
    ]

    # Passages carry the E5 "passage: " prefix; retrieve_chunks adds "query: ".
    # All chunks are encoded in one batched call, then handed to
    # build_faiss_index as a list of 1-D vectors, one per chunk.
    passage_embs = _query_model.encode(
        [f"passage: {chunk}" for chunk in dummy_chunks],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    index = build_faiss_index(list(passage_embs))

    query = "What is SAP Ariba used for?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))