# Shubham170793's picture
# Update src/qa.py
# 7b609f8 verified
# raw
# history blame
# 5.9 kB
"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
• Chunk retrieval (FAISS)
• Answer generation (Flan-T5)
Optimized for Hugging Face Spaces & Streamlit.
"""
import os

# ==========================================================
# 1️⃣ Hugging Face Cache Setup (Safe for Spaces)
# ==========================================================
# The cache variables are exported BEFORE the Hugging Face libraries are
# imported: huggingface_hub / transformers resolve their cache locations from
# the environment when they are first imported, so exporting them afterwards
# has no effect on those defaults.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR,
})

# These imports intentionally follow the environment setup above.
from sentence_transformers import SentenceTransformer  # noqa: E402
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline  # noqa: E402
from vectorstore import search_faiss  # noqa: E402

print("✅ qa.py loaded from:", __file__)
# ==========================================================
# 2️⃣ Query Embedding Model
# ==========================================================
def _load_embedder(model_id):
    """Build a SentenceTransformer for *model_id*, caching files in CACHE_DIR."""
    return SentenceTransformer(model_id, cache_folder=CACHE_DIR)


# Prefer the E5 encoder. If loading it fails for any reason, report the error
# and load MiniLM instead; a failure of the fallback itself propagates.
try:
    _query_model = _load_embedder("intfloat/e5-small-v2")
    print("✅ Loaded query model: intfloat/e5-small-v2")
except Exception as load_error:
    print(f"⚠️ Query model load failed ({load_error}), falling back to MiniLM.")
    _query_model = _load_embedder("sentence-transformers/all-MiniLM-L6-v2")
    print("✅ Loaded fallback model: all-MiniLM-L6-v2")
# ==========================================================
# 3️⃣ LLM for Answer Generation (FLAN-T5)
# ==========================================================
MODEL_NAME = "google/flan-t5-base"  # Switch to 'large' if enough memory
print(f"✅ Loading LLM: {MODEL_NAME}")

# Tokenizer and seq2seq weights come from the same checkpoint and are both
# cached under CACHE_DIR.
_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

# Text-to-text generation pipeline used by generate_answer().
_answer_model = pipeline(
    task="text2text-generation",
    tokenizer=_tokenizer,
    model=_model,
    device=-1,  # CPU-safe (Hugging Face Spaces)
)
# ==========================================================
# 4️⃣ Prompt Template
# ==========================================================
# Prompt sent to the generator. It has two str.format placeholders:
# {context} (the merged retrieved chunks) and {query} (the user question).
# The text starts with a newline and ends right after "Answer:\n".
PROMPT_TEMPLATE = (
    "\n"
    "You are an expert enterprise knowledge assistant.\n"
    "Use ONLY the CONTEXT below to answer the QUESTION clearly, factually, and completely.\n"
    "If the context doesn’t contain the answer, reply exactly:\n"
    "\"I don't know based on the provided document.\"\n"
    "---\n"
    "Context:\n"
    "{context}\n"
    "---\n"
    "Question:\n"
    "{query}\n"
    "---\n"
    "Answer:\n"
)
# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """
    Encode the user query and retrieve the top-k relevant chunks via FAISS.

    The query is embedded with the module-level query model using the
    'query: ' prefix (E5 training style) with normalized embeddings, then
    passed to ``search_faiss`` together with the index and the chunk list.

    Args:
        query: User question. Blank or non-string values yield no results.
        index: Search index handed to ``search_faiss``; ``None`` yields no
            results.
        chunks: Candidate chunks (presumably aligned with the index entries —
            verify against the index builder).
        top_k: Maximum number of chunks to return. Values below 1 yield no
            results; larger values are capped at ``len(chunks)``.

    Returns:
        The value returned by ``search_faiss``, or ``[]`` when the inputs are
        unusable or retrieval raises (best-effort: the error is printed, not
        propagated).
    """
    # Test for None explicitly instead of relying on the truthiness of an
    # opaque index object.
    if index is None or not chunks:
        return []
    # Embedding an empty question would only return arbitrary neighbours.
    if not isinstance(query, str) or not query.strip():
        return []
    if top_k < 1:
        return []
    # Never ask for more neighbours than there are chunks to map them to.
    top_k = min(top_k, len(chunks))

    try:
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        return search_faiss(query_emb, index, chunks, top_k)
    except Exception as e:
        # Deliberate best-effort boundary: report and degrade to "no results".
        print(f"⚠️ Retrieval error: {e}")
        return []
# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list):
    """
    Generate an answer with FLAN-T5 using the retrieved chunks as context.

    Non-blank chunks are stripped, numbered ("[Chunk 1]: ...") and joined
    into a single context that is inserted, together with the query, into
    PROMPT_TEMPLATE. Generation uses sampling (temperature / top-p) with a
    repetition penalty and up to 400 new tokens.

    Args:
        query: The user question inserted into the prompt.
        retrieved_chunks: Text chunks to use as context. ``None``, an empty
            list, or a list containing only blank entries counts as "nothing
            retrieved".

    Returns:
        The stripped generated text; a fixed "couldn't find" message when
        there is no usable context; or a fixed error message when the
        generation call raises (the error is printed, not propagated).
    """
    # Drop empty / whitespace-only chunks so the model is never prompted with
    # an effectively empty context.
    usable_chunks = [
        chunk.strip()
        for chunk in (retrieved_chunks or [])
        if chunk and chunk.strip()
    ]
    if not usable_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    # Merge retrieved chunks into one coherent, numbered context.
    context = "\n\n".join(
        f"[Chunk {position}]: {text}"
        for position, text in enumerate(usable_chunks, start=1)
    )
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    # NOTE(review): the prompt is passed without truncation; very long
    # contexts may exceed the model's input limit — confirm chunk sizes
    # upstream.
    try:
        result = _answer_model(
            prompt,
            max_new_tokens=400,       # allow more elaborate responses
            do_sample=True,           # enable natural variability
            temperature=0.7,          # creativity balance
            top_p=0.9,                # nucleus sampling for relevance
            repetition_penalty=1.15,  # discourage repetition
        )
        return result[0]["generated_text"].strip()
    except Exception as e:
        # Deliberate best-effort boundary: report and return a user-facing
        # error string instead of crashing the app.
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."
# ==========================================================
# 7️⃣ Optional Local Test (runs only in dev mode)
# ==========================================================
if __name__ == "__main__":
    from vectorstore import build_faiss_index

    # Tiny in-memory corpus for a manual smoke test of retrieval + generation.
    sample_chunks = [
        "SAP Ariba is a cloud-based procurement solution.",
        "It helps companies manage suppliers and sourcing processes efficiently.",
        "Integration with SAP ERP allows for seamless data synchronization."
    ]

    # Embed every passage individually, using the 'passage: ' prefix on the
    # document side (retrieve_chunks uses 'query: ' on the question side).
    passage_vectors = []
    for passage in sample_chunks:
        encoded = _query_model.encode(
            [f"passage: {passage}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        passage_vectors.append(encoded[0])
    demo_index = build_faiss_index(passage_vectors)

    question = "What is SAP Ariba used for?"
    hits = retrieve_chunks(question, demo_index, sample_chunks)
    print("🔍 Retrieved:", hits)
    print("💬 Answer:", generate_answer(question, hits))