""" qa.py — Retrieval + Generation Layer ------------------------------------- Handles: • Query embedding (SentenceTransformer / E5-compatible) • Chunk retrieval (FAISS with neighborhood merging + re-ranking) • Answer generation (OpenAI GPT-4o-mini → FLAN-T5 fallback) Optimized for Hugging Face Spaces & Streamlit. """ import os import numpy as np from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity from vectorstore import search_faiss print("✅ qa.py loaded from:", __file__) # ========================================================== # 1️⃣ Hugging Face Cache Setup # ========================================================== CACHE_DIR = "/tmp/hf_cache" os.makedirs(CACHE_DIR, exist_ok=True) os.environ.update({ "HF_HOME": CACHE_DIR, "TRANSFORMERS_CACHE": CACHE_DIR, "HF_DATASETS_CACHE": CACHE_DIR, "HF_MODULES_CACHE": CACHE_DIR }) # ========================================================== # 2️⃣ Query Embedding Model # ========================================================== try: _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR) print("✅ Loaded query model: intfloat/e5-small-v2") except Exception as e: print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.") _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR) # ========================================================== # 3️⃣ LLM Setup: OpenAI (primary) + FLAN (fallback) # ========================================================== USE_OPENAI = bool(os.getenv("OPENAI_API_KEY")) _answer_model = None # ensures it's always defined if USE_OPENAI: try: from openai import OpenAI client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) print("✅ Using OpenAI GPT-4o-mini for answer generation") except Exception as e: print(f"⚠️ Failed to initialize OpenAI client: {e}") USE_OPENAI = False # Always prepare fallback safely try: from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline MODEL_NAME = "google/flan-t5-base" _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR) _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR) _answer_model = pipeline("text2text-generation", model=_model, tokenizer=_tokenizer, device=-1) print("💡 Fallback FLAN-T5 ready.") except Exception as e: print(f"⚠️ Could not initialize FLAN fallback: {e}") # ========================================================== # 4️⃣ Prompt Template # ========================================================== PROMPT_TEMPLATE = """ You are an enterprise knowledge assistant. Use ONLY the CONTEXT below to answer the QUESTION clearly, completely, and factually. If the context doesn’t contain the answer, reply exactly: "I don't know based on the provided document." --- Context: {context} --- Question: {query} --- Answer: """ # ========================================================== # 5️⃣ Chunk Retrieval Function # ========================================================== def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5): """Retrieve top-K relevant chunks, merge nearby ones, and re-rank by semantic similarity.""" if not index or not chunks: return [] try: # Step 1: Encode the query query_emb = _query_model.encode( [f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True )[0] # Step 2: Initial FAISS retrieval distances, indices = index.search(np.array([query_emb]).astype("float32"), top_k * 2) # Step 3: Merge neighboring chunks merged_chunks = [] for idx in indices[0]: neighbors = [chunks[i] for i in range(max(0, idx - 1), min(len(chunks), idx + 2))] merged_chunks.append(" ".join(neighbors)) # Step 4: Re-rank using cosine similarity chunk_vecs = np.array([ _query_model.encode([c], convert_to_numpy=True, normalize_embeddings=True)[0] for c in merged_chunks ]) scores = cosine_similarity(np.array([query_emb]), chunk_vecs)[0] sorted_indices = np.argsort(scores)[::-1] # Step 5: Return top-ranked merged chunks return [merged_chunks[i] for i in sorted_indices[:top_k]] except Exception as e: print(f"⚠️ Retrieval error: {e}") return [] # ========================================================== # 6️⃣ Answer Generation Function # ========================================================== def generate_answer(query: str, retrieved_chunks: list): """Generate factual, context-grounded answers using OpenAI or fallback FLAN-T5.""" if not retrieved_chunks: return "Sorry, I couldn’t find relevant information in the document." # Build full context context = "\n\n".join([ f"[Chunk {i+1}]: {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks) ]) prompt = PROMPT_TEMPLATE.format(context=context, query=query) # --- Try OpenAI first --- if USE_OPENAI: try: response = client.chat.completions.create( model="gpt-4o-mini", messages=[ {"role": "system", "content": "You are a precise enterprise document assistant."}, {"role": "user", "content": prompt}, ], temperature=0.4, max_tokens=800, ) return response.choices[0].message.content.strip() except Exception as e: print(f"⚠️ OpenAI generation failed: {e}. Switching to fallback...") # --- Fallback to FLAN-T5 --- try: if _answer_model: result = _answer_model( prompt, max_new_tokens=600, do_sample=False, temperature=0.3 ) return result[0]["generated_text"].strip() else: return "⚠️ Error: Fallback model not available." except Exception as e: print(f"⚠️ Fallback model failed: {e}") return "⚠️ Error: Both OpenAI and fallback generation failed." # ========================================================== # 7️⃣ Local Test # ========================================================== if __name__ == "__main__": dummy_chunks = [ "Step 1: Open the dashboard and navigate to reports.", "Step 2: Click 'Export' to download a CSV summary.", "Step 3: Review the generated report in your downloads folder." ] from vectorstore import build_faiss_index index = build_faiss_index([ _query_model.encode( [f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True )[0] for chunk in dummy_chunks ]) query = "What are the steps to export a report?" retrieved = retrieve_chunks(query, index, dummy_chunks) print("🔍 Retrieved:", retrieved) print("💬 Answer:", generate_answer(query, retrieved))