""" qa.py — Phi-2 Fast + Smart Reasoning Mode (Hybrid) ------------------------------------------------- ✅ Uses intfloat/e5-small-v2 for embeddings ✅ Uses microsoft/phi-2 (fast CPU quantized) ✅ Reasoning Mode toggle integrated cleanly ✅ Retrieval unaffected by reasoning mode """ import os import numpy as np import torch from sentence_transformers import SentenceTransformer from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline print("✅ qa.py (Phi-2 Hybrid) loaded from:", __file__) # ========================================================== # 1️⃣ Cache Setups # ========================================================== CACHE_DIR = "/tmp/hf_cache" os.makedirs(CACHE_DIR, exist_ok=True) os.environ.update({ "HF_HOME": CACHE_DIR, "TRANSFORMERS_CACHE": CACHE_DIR, "HF_DATASETS_CACHE": CACHE_DIR, "HF_MODULES_CACHE": CACHE_DIR }) # ========================================================== # 2️⃣ Embedding Model # ========================================================== try: _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR) print("✅ Loaded embedding model: intfloat/e5-small-v2") except Exception as e: print(f"⚠️ Fallback to MiniLM due to {e}") _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR) # ========================================================== # 3️⃣ Phi-2 LLM Setup (Optimized for CPU) # ========================================================== try: MODEL_NAME = "microsoft/phi-2" print(f"✅ Loading LLM: {MODEL_NAME} (optimized for reasoning)") _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR) _model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, cache_dir=CACHE_DIR, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16, low_cpu_mem_usage=True, ).to("cpu") _answer_model = pipeline( "text-generation", model=_model, tokenizer=_tokenizer, device=-1, model_kwargs={"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True}, ) print("✅ Phi-2 text-generation pipeline ready.") except Exception as e: print(f"⚠️ Phi-2 load failed: {e}") _answer_model = None # ========================================================== # 4️⃣ Prompt Templates # ========================================================== STRICT_PROMPT = ( "Answer based ONLY on the context below.\n" "If the answer isn’t in the context, say: 'I don't know based on the provided document.'\n\n" "Context:\n{context}\n\nQuestion: {query}\nAnswer:" ) REASONING_PROMPT = ( "You are an expert assistant. Use the context and your reasoning ability to form a clear, step-by-step answer.\n" "Be concise yet complete. 
# ==========================================================
# 5️⃣ Chunk Retrieval (Unchanged — Fast)
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
    """Fast FAISS retrieval using cosine similarity over normalized embeddings."""
    if not index or not chunks:
        return []
    try:
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)

        # Expand each hit to include its immediate neighbours for extra context.
        selected = set()
        for idx in indices[0]:
            if idx < 0:  # FAISS pads with -1 when fewer results are available
                continue
            for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
                selected.add(i)

        return [chunks[i] for i in sorted(selected)]
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []

# ==========================================================
# 6️⃣ Answer Generation (Enhanced — Balanced Reasoning + Speed)
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
    """
    Generate answers with Phi-2.
    - reasoning_mode=False → strict factual, fast
    - reasoning_mode=True  → analytical, richer reasoning (slower)
    """
    if _answer_model is None:
        return "⚠️ Error: Phi-2 is not loaded, cannot generate an answer."
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    context = "\n".join(chunk.strip() for chunk in retrieved_chunks)
    prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
        context=context, query=query
    )

    try:
        if reasoning_mode:
            # 🧩 The “brainy” config that produced the great long answer
            result = _answer_model(
                prompt,
                max_new_tokens=180,
                do_sample=False,  # reasoning but deterministic (greedy decoding)
                pad_token_id=_tokenizer.eos_token_id,
            )
        else:
            # ⚡ Fast factual config
            result = _answer_model(
                prompt,
                max_new_tokens=120,
                do_sample=False,
                pad_token_id=_tokenizer.eos_token_id,
            )

        raw = result[0]["generated_text"].strip()
        if "Answer:" in raw:
            raw = raw.split("Answer:")[-1].strip()
        return raw
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer."
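# ==========================================================
# Reference sketch: FAISS index construction
# ==========================================================
# vectorstore.build_faiss_index is imported in the local test below but not
# shown in this file. A minimal sketch consistent with the normalized
# embeddings + inner-product (cosine) search used by retrieve_chunks() might
# look like this; the real implementation may differ.
def _build_faiss_index_sketch(embeddings: list):
    """Build an inner-product FAISS index over normalized float32 embeddings."""
    import faiss  # local import so the sketch stays self-contained

    mat = np.asarray(embeddings, dtype="float32")
    index = faiss.IndexFlatIP(mat.shape[1])  # inner product == cosine on unit vectors
    index.add(mat)
    return index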
# ==========================================================
# 7️⃣ Local Test
# ==========================================================
if __name__ == "__main__":
    from vectorstore import build_faiss_index

    dummy_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
        "Step 2: Click 'Export' to download a CSV summary.",
        "Step 3: Review the generated report in your downloads folder.",
    ]
    embeddings = [
        _query_model.encode(
            [f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True
        )[0]
        for chunk in dummy_chunks
    ]
    index = build_faiss_index(embeddings)

    query = "What are the steps to export a report?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)

    print("\n--- Strict Mode ---")
    print(generate_answer(query, retrieved, reasoning_mode=False))

    print("\n--- Reasoning Mode ---")
    print(generate_answer(query, retrieved, reasoning_mode=True))
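# ----------------------------------------------------------
# Hypothetical app-layer usage (sketch only; names outside this module,
# e.g. split_document / uploaded_file / toggle_state, are assumptions):
#
#   chunks = split_document(uploaded_file)               # your own chunking step
#   index = build_faiss_index(embed_chunks(chunks))       # see sketches above
#   hits = retrieve_chunks(user_question, index, chunks)
#   answer = generate_answer(user_question, hits, reasoning_mode=toggle_state)
# ----------------------------------------------------------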