"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
• Chunk retrieval (FAISS with neighborhood merging + re-ranking)
• Answer generation (OpenAI GPT-4o-mini → FLAN-T5 fallback)
Optimized for Hugging Face Spaces & Streamlit.
"""

import os

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# NOTE: search_faiss is not referenced in this module (retrieval calls
# index.search() directly); left in place in case callers import it via qa.
from vectorstore import search_faiss

print("✅ qa.py loaded from:", __file__)
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR,
})
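
# --- Query embedding model -------------------------------------------------
# E5 models are trained with "query: " / "passage: " prefixes; the MiniLM
# fallback simply embeds the prefixed text as-is, so the same calling code
# works for both.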
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
    print("✅ Loaded query model: intfloat/e5-small-v2")
except Exception as e:
    print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
    _query_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR
    )
USE_OPENAI = bool(os.getenv("OPENAI_API_KEY"))
_answer_model = None

if USE_OPENAI:
    try:
        from openai import OpenAI

        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        print("✅ Using OpenAI GPT-4o-mini for answer generation")
    except Exception as e:
        print(f"⚠️ Failed to initialize OpenAI client: {e}")
        USE_OPENAI = False

try:
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

    MODEL_NAME = "google/flan-t5-base"
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    # device=-1 pins the pipeline to CPU, which is what free Spaces hardware offers.
    _answer_model = pipeline(
        "text2text-generation", model=_model, tokenizer=_tokenizer, device=-1
    )
    print("💡 Fallback FLAN-T5 ready.")
except Exception as e:
    print(f"⚠️ Could not initialize FLAN fallback: {e}")
PROMPT_TEMPLATE = """
You are an enterprise knowledge assistant.
Use ONLY the CONTEXT below to answer the QUESTION clearly, completely, and factually.
If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."

---
Context:
{context}
---
Question:
{query}
---
Answer:
"""
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
    """Retrieve top-K relevant chunks, merge neighboring ones, and re-rank by semantic similarity."""
    if index is None or not chunks:
        return []

    try:
        # E5-style models expect queries to carry a "query: " prefix.
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]

        # Over-fetch (top_k * 2) so the re-ranking step has candidates to discard.
        _, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2)

        # Merge each hit with its immediate neighbors to restore local context,
        # skipping the -1 padding FAISS returns when it has fewer hits than asked.
        merged_chunks = []
        seen = set()
        for idx in indices[0]:
            if idx < 0:
                continue
            merged = " ".join(
                chunks[i] for i in range(max(0, idx - 1), min(len(chunks), idx + 2))
            )
            if merged not in seen:  # adjacent hits can yield identical merges
                seen.add(merged)
                merged_chunks.append(merged)

        if not merged_chunks:
            return []

        # Re-rank merged candidates by cosine similarity to the query embedding.
        chunk_vecs = _query_model.encode(
            merged_chunks, convert_to_numpy=True, normalize_embeddings=True
        )
        scores = cosine_similarity(np.array([query_emb]), chunk_vecs)[0]
        sorted_indices = np.argsort(scores)[::-1]
        return [merged_chunks[i] for i in sorted_indices[:top_k]]

    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []
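

# --- Generation ----------------------------------------------------------------
# Tries OpenAI first (when configured), then falls back to the local FLAN-T5 pipeline.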
def generate_answer(query: str, retrieved_chunks: list):
    """Generate factual, context-grounded answers using OpenAI or fallback FLAN-T5."""
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    # Label each chunk so the model can refer back to its sources.
    context = "\n\n".join(
        f"[Chunk {i + 1}]: {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks)
    )
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    if USE_OPENAI:
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a precise enterprise document assistant."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.4,
                max_tokens=800,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"⚠️ OpenAI generation failed: {e}. Switching to fallback...")

    try:
        if _answer_model is not None:
            # Greedy decoding: temperature has no effect with do_sample=False,
            # so it is omitted. truncation=True keeps long prompts within
            # FLAN-T5's 512-token input window instead of erroring out.
            result = _answer_model(
                prompt,
                max_new_tokens=600,
                do_sample=False,
                truncation=True,
            )
            return result[0]["generated_text"].strip()
        return "⚠️ Error: Fallback model not available."
    except Exception as e:
        print(f"⚠️ Fallback model failed: {e}")
        return "⚠️ Error: Both OpenAI and fallback generation failed."
if __name__ == "__main__":
    # Minimal self-test: build a tiny in-memory index and run one query.
    dummy_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
        "Step 2: Click 'Export' to download a CSV summary.",
        "Step 3: Review the generated report in your downloads folder.",
    ]
    from vectorstore import build_faiss_index

    # E5-style models expect passages to carry a "passage: " prefix.
    index = build_faiss_index([
        _query_model.encode(
            [f"passage: {chunk}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        for chunk in dummy_chunks
    ])

    query = "What are the steps to export a report?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))