"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
• Chunk retrieval (FAISS with neighborhood merging + re-ranking)
• Answer generation (OpenAI GPT-4o-mini → FLAN-T5 fallback)
Optimized for Hugging Face Spaces & Streamlit.
"""
import os

# ==========================================================
# 1️⃣ Hugging Face Cache Setup
# ==========================================================
# The cache env vars must be exported BEFORE transformers /
# sentence_transformers are imported, or the libraries resolve their
# default cache paths first. /tmp is used because it is writable on
# Hugging Face Spaces, where the default ~/.cache may not be.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR,
})

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

print("✅ qa.py loaded from:", __file__)
# ==========================================================
# 2️⃣ Query Embedding Model
# ==========================================================
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
    print("✅ Loaded query model: intfloat/e5-small-v2")
except Exception as e:
    print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
    _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
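# NOTE: E5-family models are trained with "query: " / "passage: " input
# prefixes and lose retrieval quality without them; the MiniLM fallback
# simply tolerates the prefixes, so one code path serves both models.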
# ==========================================================
# 3️⃣ LLM Setup: OpenAI (primary) + FLAN (fallback)
# ==========================================================
USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
_answer_model = None  # ensures the fallback name is always defined

if USE_OPENAI:
    try:
        from openai import OpenAI
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        print("✅ Using OpenAI GPT-4o-mini for answer generation")
    except Exception as e:
        print(f"⚠️ Failed to initialize OpenAI client: {e}")
        USE_OPENAI = False

# Always prepare the fallback safely
try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

    MODEL_NAME = "google/flan-t5-base"
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    _answer_model = pipeline(
        "text2text-generation", model=_model, tokenizer=_tokenizer, device=-1  # device=-1 runs on CPU
    )
    print("💡 Fallback FLAN-T5 ready.")
except Exception as e:
    print(f"⚠️ Could not initialize FLAN fallback: {e}")
# ==========================================================
# 4️⃣ Prompt Template
# ==========================================================
PROMPT_TEMPLATE = """
You are an enterprise knowledge assistant.
Use ONLY the CONTEXT below to answer the QUESTION clearly, completely, and factually.
If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."
---
Context:
{context}
---
Question:
{query}
---
Answer:
"""
# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
    """Retrieve top-K relevant chunks, merge nearby ones, and re-rank by semantic similarity."""
    if index is None or not chunks:
        return []
    try:
        # Step 1: Encode the query (E5 models expect the "query: " prefix)
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        # Step 2: Initial FAISS retrieval (over-fetch 2x to give re-ranking headroom)
        distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)
        # Step 3: Merge each hit with its immediate neighbors to restore local context
        merged_chunks = []
        for idx in indices[0]:
            if idx < 0:
                continue  # FAISS pads results with -1 when the index holds fewer entries
            neighbors = [chunks[i] for i in range(max(0, idx - 1), min(len(chunks), idx + 2))]
            merged_chunks.append(" ".join(neighbors))
        if not merged_chunks:
            return []
        # Step 4: Re-rank by cosine similarity, batching the encode call and
        # using the E5 "passage: " prefix to match how chunks were indexed
        chunk_vecs = _query_model.encode(
            [f"passage: {c}" for c in merged_chunks],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        scores = cosine_similarity(query_emb.reshape(1, -1), chunk_vecs)[0]
        sorted_indices = np.argsort(scores)[::-1]
        # Step 5: Return the top-ranked merged chunks
        return [merged_chunks[i] for i in sorted_indices[:top_k]]
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []
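# Illustrative sketch (not wired in above): Step 4 re-encodes every merged
# chunk on each query. If normalized passage embeddings were cached at
# index-build time, re-ranking could reuse them; `passage_embs` below is a
# hypothetical (n_chunks, dim) float32 array of those cached vectors.
def _rerank_cached(query_emb, candidate_ids, passage_embs, top_k=5):
    """Re-rank candidate chunk ids against cached, normalized embeddings."""
    ids = [i for i in candidate_ids if i >= 0]  # drop FAISS's -1 padding
    # Cosine similarity equals the dot product for unit-normalized vectors.
    scores = passage_embs[ids] @ query_emb
    order = np.argsort(scores)[::-1][:top_k]
    return [ids[i] for i in order]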
# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list):
    """Generate factual, context-grounded answers using OpenAI or fallback FLAN-T5."""
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."
    # Build the full context, labelling each chunk for traceability
    context = "\n\n".join(
        f"[Chunk {i + 1}]: {chunk.strip()}"
        for i, chunk in enumerate(retrieved_chunks)
    )
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)
    # --- Try OpenAI first ---
    if USE_OPENAI:
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a precise enterprise document assistant."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.4,
                max_tokens=800,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"⚠️ OpenAI generation failed: {e}. Switching to fallback...")
    # --- Fallback to FLAN-T5 ---
    try:
        if _answer_model:
            result = _answer_model(
                prompt,
                max_new_tokens=600,
                do_sample=False,  # greedy decoding; a temperature setting would be ignored here
                truncation=True,  # guard against prompts beyond the tokenizer's 512-token limit
            )
            return result[0]["generated_text"].strip()
        return "⚠️ Error: Fallback model not available."
    except Exception as e:
        print(f"⚠️ Fallback model failed: {e}")
        return "⚠️ Error: Both OpenAI and fallback generation failed."
# ==========================================================
# 7️⃣ Local Test
# ==========================================================
if __name__ == "__main__":
    dummy_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
        "Step 2: Click 'Export' to download a CSV summary.",
        "Step 3: Review the generated report in your downloads folder.",
    ]
    from vectorstore import build_faiss_index

    # Embed the dummy passages with the same E5 "passage: " prefix used at indexing time
    index = build_faiss_index([
        _query_model.encode(
            [f"passage: {chunk}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        for chunk in dummy_chunks
    ])
    query = "What are the steps to export a report?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))