"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
  • Query embedding (SentenceTransformer / E5-compatible)
  • Chunk retrieval (FAISS)
  • Answer generation (Flan-T5)

Optimized for Hugging Face Spaces & Streamlit.
"""

import os

# ==========================================================
# 1️⃣ Hugging Face Cache Setup (Safe for Spaces)
# ==========================================================
# These variables must be exported BEFORE transformers / sentence_transformers
# (and huggingface_hub underneath them) are imported: those libraries read
# HF_HOME, TRANSFORMERS_CACHE and HF_MODULES_CACHE once, at import time, to
# compute their default cache locations. Assigning them after the import would
# leave the library defaults pointing at the (possibly read-only) home dir.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR,
})

# Imports intentionally placed after the environment setup above.
from sentence_transformers import SentenceTransformer  # noqa: E402
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline  # noqa: E402

from vectorstore import search_faiss  # noqa: E402

print("✅ qa.py loaded from:", __file__)

# ==========================================================
# 2️⃣ Query Embedding Model
# ==========================================================
# Try the E5 model first; fall back to MiniLM if it cannot be loaded.
# NOTE(review): MiniLM was not trained with E5's "query: "/"passage: "
# prefixes and does not share E5's embedding space — confirm the FAISS index
# was built with whichever model actually loaded here.
try:
    _query_model = SentenceTransformer(
        "intfloat/e5-small-v2",
        cache_folder=CACHE_DIR,
    )
    print("✅ Loaded query model: intfloat/e5-small-v2")
except Exception as e:
    print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
    _query_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR,
    )
    print("✅ Loaded fallback model: all-MiniLM-L6-v2")

# ==========================================================
# 3️⃣ LLM for Answer Generation (FLAN-T5)
# ==========================================================
MODEL_NAME = "google/flan-t5-base"  # Switch to 'large' if enough memory
print(f"✅ Loading LLM: {MODEL_NAME}")

_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

_answer_model = pipeline(
    "text2text-generation",
    model=_model,
    tokenizer=_tokenizer,
    device=-1,  # CPU-safe (Hugging Face Spaces)
)
# ==========================================================
# 4️⃣ Prompt Template
# ==========================================================
# Refusal sentence the prompt asks the model to emit verbatim when the context
# lacks the answer; also returned when the model generates nothing at all.
NO_ANSWER_TEXT = "I don't know based on the provided document."

PROMPT_TEMPLATE = """
You are an expert enterprise knowledge assistant.
Use ONLY the CONTEXT below to answer the QUESTION clearly, factually, and completely.
If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."
---
Context:
{context}
---
Question:
{query}
---
Answer:
"""


# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """
    Encode the user query and retrieve the top-k relevant chunks via FAISS.

    The query is prefixed with ``"query: "`` (E5 training style; the demo
    below embeds passages with ``"passage: "``) and normalised before being
    passed to ``vectorstore.search_faiss``.

    Args:
        query: The user's question. ``None``, non-string, or blank queries
            return ``[]`` without touching the model.
        index: Search index handed through to ``search_faiss``.
        chunks: Chunk texts handed through to ``search_faiss``
            (presumably aligned with the index rows — confirm in vectorstore).
        top_k: Number of results requested from ``search_faiss``.

    Returns:
        Whatever ``search_faiss`` returns, or ``[]`` when inputs are empty or
        retrieval raises (the error is printed, not propagated).
    """
    if not index or not chunks:
        return []
    # A blank query would embed to a bare "query: " vector and return
    # arbitrary chunks, so short-circuit instead.
    if not isinstance(query, str) or not query.strip():
        return []

    try:
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        return search_faiss(query_emb, index, chunks, top_k)
    except Exception as e:
        # Best-effort at the UI boundary: report and degrade to "no results".
        print(f"⚠️ Retrieval error: {e}")
        return []


# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def generate_answer(
    query: str,
    retrieved_chunks: list,
    *,
    max_new_tokens: int = 400,
    do_sample: bool = True,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float = 1.15,
):
    """
    Generate an answer with FLAN-T5 using the retrieved chunks as context.

    Args:
        query: The user's question, inserted verbatim into the prompt.
        retrieved_chunks: Chunk strings to use as context. Empty or
            whitespace-only entries are ignored.
        max_new_tokens: Upper bound on generated tokens.
        do_sample: Sample (True, the default) or decode greedily (False, for
            reproducible answers).
        temperature: Sampling temperature; only used when ``do_sample``.
        top_p: Nucleus-sampling mass; only used when ``do_sample``.
        repetition_penalty: Penalty discouraging repeated tokens.

    Returns:
        The generated answer text. Returns a fixed "couldn't find" message
        when there is no usable context, ``NO_ANSWER_TEXT`` when the model
        produces an empty string, and an error message if generation raises
        (the exception is printed, not propagated).
    """
    # Drop empty / whitespace-only chunks so they don't add blank
    # "[Chunk N]:" entries to the prompt.
    usable_chunks = [
        chunk.strip()
        for chunk in (retrieved_chunks or [])
        if chunk and chunk.strip()
    ]
    if not usable_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    # Merge retrieved chunks into one numbered context block.
    context = "\n\n".join(
        f"[Chunk {i}]: {chunk}" for i, chunk in enumerate(usable_chunks, start=1)
    )
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
    }
    if do_sample:
        # temperature/top_p only affect sampling; passing them with greedy
        # decoding just triggers a transformers warning.
        gen_kwargs.update(temperature=temperature, top_p=top_p)

    try:
        result = _answer_model(prompt, **gen_kwargs)
        answer = result[0]["generated_text"].strip()
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."

    # Treat an empty generation like the model's own refusal rather than
    # returning "" to the caller.
    return answer or NO_ANSWER_TEXT


# ==========================================================
# 7️⃣ Optional Local Test (runs only in dev mode)
# ==========================================================
if __name__ == "__main__":
    from vectorstore import build_faiss_index

    dummy_chunks = [
        "SAP Ariba is a cloud-based procurement solution.",
        "It helps companies manage suppliers and sourcing processes efficiently.",
        "Integration with SAP ERP allows for seamless data synchronization.",
    ]

    # Passages carry the E5 "passage: " prefix to mirror the "query: "
    # prefix used in retrieve_chunks; encode them in one batch.
    passage_embs = _query_model.encode(
        [f"passage: {chunk}" for chunk in dummy_chunks],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    index = build_faiss_index(list(passage_embs))

    query = "What is SAP Ariba used for?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))