""" qa.py — Retrieval + Generation Layer ------------------------------------- Handles: • Query embedding (SentenceTransformer / E5-compatible) • Chunk retrieval (FAISS) • Answer generation (Flan-T5) Optimized for Hugging Face Spaces & Streamlit. """ import os from sentence_transformers import SentenceTransformer from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline from vectorstore import search_faiss print("✅ qa.py loaded from:", __file__) # ========================================================== # 1️⃣ Hugging Face Cache Setup (Safe for Spaces) # ========================================================== CACHE_DIR = "/tmp/hf_cache" os.makedirs(CACHE_DIR, exist_ok=True) os.environ.update({ "HF_HOME": CACHE_DIR, "TRANSFORMERS_CACHE": CACHE_DIR, "HF_DATASETS_CACHE": CACHE_DIR, "HF_MODULES_CACHE": CACHE_DIR }) # ========================================================== # 2️⃣ Query Embedding Model # ========================================================== # Use E5-small-v2 for retrieval consistency with embeddings.py try: _query_model = SentenceTransformer( "intfloat/e5-small-v2", cache_folder=CACHE_DIR ) print("✅ Loaded query model: intfloat/e5-small-v2") except Exception as e: print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.") _query_model = SentenceTransformer( "sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR ) print("✅ Loaded fallback model: all-MiniLM-L6-v2") # ========================================================== # 3️⃣ LLM for Answer Generation (FLAN-T5) # ========================================================== MODEL_NAME = "google/flan-t5-base" # switch to 'large' if RAM allows print(f"✅ Loading LLM: {MODEL_NAME}") _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR) _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR) _answer_model = pipeline( "text2text-generation", model=_model, tokenizer=_tokenizer, device=-1 # CPU-safe for Spaces ) # 
# ==========================================================
# 4️⃣ Prompt Template (concise and factual)
# ==========================================================
PROMPT_TEMPLATE = """
You are an expert enterprise assistant.
Using ONLY the CONTEXT below, answer the QUESTION clearly and factually.
If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."

---
Context:
{context}
---
Question:
{query}
---
Answer:
"""

# User-facing messages returned by generate_answer().
_NO_CONTEXT_MESSAGE = "Sorry, I couldn’t find relevant information in the document."
_UNKNOWN_ANSWER = "I don't know based on the provided document."
_GENERATION_ERROR_MESSAGE = "⚠️ Error: Could not generate an answer at the moment."
_SHORT_ANSWER_PREFIX = (
    "The document mentions this briefly. Based on the context, here's what it suggests: "
)
# Answers with fewer words than this get the explanatory prefix above.
_MIN_ANSWER_WORDS = 8


# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """
    Encode the user query and retrieve the top-k relevant chunks via FAISS.

    The query is embedded with the 'query:' prefix (E5 training style) and
    passed to ``search_faiss`` together with the index and chunk list.

    Args:
        query: The user's question. Blank or non-string queries yield ``[]``.
        index: The vector index handed to ``search_faiss``; ``None`` yields ``[]``.
        chunks: The text chunks the index was built from; empty yields ``[]``.
        top_k: Maximum number of chunks to return (clamped to ``len(chunks)``).

    Returns:
        Whatever ``search_faiss`` returns for the query, or ``[]`` when the
        inputs are unusable or retrieval fails (the error is printed).
    """
    # Identity check rather than truthiness: the index is an opaque object
    # whose boolean value is not something this module should depend on.
    if index is None or not chunks:
        return []

    # A blank query would be embedded as the bare "query: " prefix and return
    # arbitrary chunks, so treat it as "nothing to retrieve".
    if not isinstance(query, str) or not query.strip():
        return []

    try:
        # Never ask for more neighbours than there are chunks.
        k = min(int(top_k), len(chunks))
        if k <= 0:
            return []

        # E5 expects 'query:' prefix for better retrieval accuracy
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        return search_faiss(query_emb, index, chunks, k)
    except Exception as e:
        # Deliberate best-effort boundary: retrieval problems degrade to
        # "no results" instead of crashing the UI.
        print(f"⚠️ Retrieval error: {e}")
        return []


# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def _build_prompt(query: str, retrieved_chunks: list) -> str:
    """Merge the retrieved chunks into a labelled context and fill the template."""
    context = "\n\n".join(
        f"[Chunk {i}]: {chunk}" for i, chunk in enumerate(retrieved_chunks, start=1)
    )
    return PROMPT_TEMPLATE.format(context=context, query=query)


def _polish_answer(answer: str) -> str:
    """Post-process the raw model text into the final user-facing answer."""
    # Empty generation: return the canonical "don't know" reply instead of a
    # prefix that introduces nothing.
    if not answer:
        return _UNKNOWN_ANSWER

    # A short "I don't know" reply must not be dressed up as if the document
    # had said something about the question.
    if answer.replace("’", "'").lower().startswith("i don't know"):
        return answer

    # 🧩 If the model outputs something too short, expand gracefully
    if len(answer.split()) < _MIN_ANSWER_WORDS:
        return _SHORT_ANSWER_PREFIX + answer
    return answer


def generate_answer(query: str, retrieved_chunks: list):
    """
    Generate an answer using FLAN-T5 with the retrieved chunks as context.

    Args:
        query: The user's question, inserted verbatim into the prompt.
        retrieved_chunks: Chunks returned by ``retrieve_chunks``; each is
            rendered into the prompt as ``[Chunk N]: <chunk>``.

    Returns:
        The generated answer string. Returns a fixed apology when no chunks
        are supplied, a fixed "I don't know" reply when the model produces no
        text, and a fixed error message when generation raises.
    """
    if not retrieved_chunks:
        return _NO_CONTEXT_MESSAGE

    # NOTE(review): the context is not truncated here — confirm that prompts
    # built from the retrieved chunks fit within the model's input limit.
    prompt = _build_prompt(query, retrieved_chunks)

    try:
        result = _answer_model(
            prompt,
            max_new_tokens=350,      # allow longer, more complete answers
            do_sample=True,          # enable sampling for natural flow
            temperature=0.7,         # slightly higher = more expressive responses
            top_p=0.95,              # nucleus sampling for coherence
            repetition_penalty=1.2,  # discourages repetitive phrasing
        )
        answer = result[0]["generated_text"].strip()
    except Exception as e:
        # Best-effort boundary: report the failure and return a fixed message.
        print(f"⚠️ Generation failed: {e}")
        return _GENERATION_ERROR_MESSAGE

    return _polish_answer(answer)


# ==========================================================
# 7️⃣ Optional Local Test
# ==========================================================
if __name__ == "__main__":
    dummy_chunks = [
        "SAP Ariba is a cloud-based procurement solution.",
        "It helps companies manage suppliers and sourcing processes efficiently.",
        "Integration with SAP ERP allows for seamless data synchronization.",
    ]

    from vectorstore import build_faiss_index

    # Passages use the 'passage:' prefix, mirroring the 'query:' prefix used
    # at retrieval time.
    index = build_faiss_index([
        _query_model.encode(
            [f"passage: {chunk}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        for chunk in dummy_chunks
    ])

    query = "What is SAP Ariba used for?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))