# Shubham170793's picture
# Update src/qa.py
# cd6e69b verified
# raw
# history blame
# 7.33 kB
"""
qa.py — Phi-2 Fast + Smart Reasoning Mode (Hybrid)
-------------------------------------------------
✅ Uses intfloat/e5-small-v2 for embeddings
✅ Uses microsoft/phi-2 on CPU (16-bit weights; this file applies no quantization)
✅ Reasoning Mode toggle integrated cleanly
✅ Retrieval unaffected by reasoning mode
"""
import os

# ==========================================================
# 1️⃣ Cache Setup
# ==========================================================
# This section deliberately runs BEFORE the Hugging Face imports below.
# Those libraries resolve their cache locations from the environment when
# they are first imported, so exporting the variables afterwards has no
# effect on the current process.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR,
})

# Imports that must follow the cache setup above (hence the E402 waivers).
import numpy as np  # noqa: E402
import torch  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline  # noqa: E402

# Import-time banner showing which copy of the module was picked up.
print("✅ qa.py (Phi-2 Hybrid) loaded from:", __file__)
# ==========================================================
# 2️⃣ Embedding Model
# ==========================================================
def _load_query_model():
    """Build the sentence embedder used for query encoding.

    Tries ``intfloat/e5-small-v2`` first. If constructing it (or printing the
    success banner) raises, the error is printed and
    ``sentence-transformers/all-MiniLM-L6-v2`` is loaded instead. A failure
    of the fallback itself is not handled and propagates to the importer.

    NOTE(review): ``retrieve_chunks`` always prepends ``"query: "`` to the
    text it encodes, whichever model ends up loaded here — confirm that is
    intended for the MiniLM fallback.
    """
    try:
        embedder = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
        print("✅ Loaded embedding model: intfloat/e5-small-v2")
    except Exception as load_error:
        print(f"⚠️ Fallback to MiniLM due to {load_error}")
        embedder = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            cache_folder=CACHE_DIR,
        )
    return embedder


# Module-level embedder shared by retrieval and the local smoke test.
_query_model = _load_query_model()
# ==========================================================
# 3️⃣ Phi-2 LLM Setup (Optimized for CPU)
# ==========================================================
MODEL_NAME = "microsoft/phi-2"

# Pre-bind the handles so later code can test them against None instead of
# hitting a NameError when loading fails before they are assigned.
_tokenizer = None
_model = None
_answer_model = None

try:
    print(f"✅ Loading LLM: {MODEL_NAME} (optimized for reasoning)")
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        # The model is always moved to the CPU below, so the dtype must not
        # depend on whether a GPU happens to be visible: float16 would end
        # up on the CPU on CUDA hosts. bfloat16 is what CPU-only hosts
        # already received.
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    ).to("cpu")
    # `model` is an already-instantiated object, so no loading kwargs are
    # passed to the pipeline; device=-1 keeps inference on the CPU.
    _answer_model = pipeline(
        "text-generation",
        model=_model,
        tokenizer=_tokenizer,
        device=-1,
    )
    print("✅ Phi-2 text-generation pipeline ready.")
except Exception as e:
    # Best-effort load: the module stays importable and generate_answer
    # reports an error string when the pipeline is unavailable.
    print(f"⚠️ Phi-2 load failed: {e}")
    _answer_model = None
# ==========================================================
# 4️⃣ Prompt Templates
# ==========================================================
# Module-level prompt templates. Both expect `{context}` and `{query}` to be
# filled in with str.format.
# NOTE: generate_answer builds its prompts from templates defined inside the
# function body, not from these constants.

# Context-only answering with an explicit "I don't know" escape hatch.
STRICT_PROMPT = "\n".join([
    "Answer based ONLY on the context below.",
    "If the answer isn’t in the context, say: 'I don't know based on the provided document.'",
    "",
    "Context:",
    "{context}",
    "",
    "Question: {query}",
    "Answer:",
])

# Step-by-step answering that still falls back to "I don't know".
REASONING_PROMPT = "\n".join([
    "You are an expert assistant. Use the context and your reasoning ability to form a clear, step-by-step answer.",
    "Be concise yet complete. If the context doesn’t contain the answer, say: 'I don't know based on the provided document.'",
    "",
    "Context:",
    "{context}",
    "",
    "Question: {query}",
    "Answer:",
])
# ==========================================================
# 5️⃣ Chunk Retrieval (Unchanged — Fast)
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
    """Return the chunks nearest to ``query`` plus their immediate neighbours.

    The query is embedded with the module-level ``_query_model`` (with a
    ``"query: "`` prefix and normalised embeddings), ``index.search`` is asked
    for ``top_k * 2`` hits, and each hit is widened to a window of one chunk
    on either side so the answer keeps some surrounding context.

    Args:
        query: Question text; surrounding whitespace is stripped.
        index: Object exposing ``search(vectors, k) -> (distances, indices)``;
            presumably a FAISS index whose row ids match positions in
            ``chunks`` — verify against the code that builds it.
        chunks: Text chunks, in document order.
        top_k: Base hit count; twice this many neighbours are requested.

    Returns:
        The selected chunks in document order, without duplicates. An empty
        list is returned when ``index`` or ``chunks`` is falsy, or when any
        step raises (the error is printed, not propagated).
    """
    if not index or not chunks:
        return []
    try:
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        _, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)

        n_chunks = len(chunks)
        selected = set()
        for idx in indices[0]:
            hit = int(idx)
            # FAISS pads the result with -1 when the index holds fewer than k
            # vectors. Without this guard the window for -1 collapses to
            # range(0, 1) and chunk 0 is returned even though it never matched.
            if not 0 <= hit < n_chunks:
                continue
            selected.update(range(max(0, hit - 1), min(n_chunks, hit + 2)))
        return [chunks[i] for i in sorted(selected)]
    except Exception as e:
        # Deliberate best-effort: retrieval problems degrade to "no results".
        print(f"⚠️ Retrieval error: {e}")
        return []
# ==========================================================
# 6️⃣ Answer Generation (Enhanced — Balanced Reasoning + Speed)
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
    """Generate an answer to ``query`` from the retrieved chunks with Phi-2.

    Args:
        query: The user's question.
        retrieved_chunks: Context strings; each is stripped and they are
            joined with newlines to form the prompt context.
        reasoning_mode: ``False`` uses the strict, context-only prompt with a
            120-token budget; ``True`` uses the step-by-step prompt with a
            180-token budget.

    Returns:
        The generated answer text. A fixed apology string is returned when
        ``retrieved_chunks`` is empty, and a fixed error string when
        generation fails for any reason (the error is printed).
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    context = "\n".join(chunk.strip() for chunk in retrieved_chunks)

    # 🧠 Reasoning prompt: encourages explanation, not just lookup
    reasoning_template = (
        "You are an expert assistant with strong reasoning skills.\n"
        "Think step by step and form a detailed, logical answer.\n"
        "You can combine hints from the context with your general understanding.\n"
        "If the context doesn't mention the answer, acknowledge that.\n\n"
        "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    )
    # ⚡ Strict factual prompt
    strict_template = (
        "Answer based ONLY on the context below.\n"
        "If the answer isn’t in the context, say: 'I don't know based on the provided document.'\n\n"
        "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
    )
    template = reasoning_template if reasoning_mode else strict_template
    prompt = template.format(context=context, query=query)

    try:
        if _answer_model is None:
            raise RuntimeError("text-generation pipeline is not loaded")

        # Decoding is greedy (do_sample=False) in both modes, so no
        # temperature is passed: it would be ignored and only trigger
        # warnings. The modes differ in prompt and token budget only.
        result = _answer_model(
            prompt,
            max_new_tokens=180 if reasoning_mode else 120,
            do_sample=False,
            pad_token_id=_tokenizer.eos_token_id,
        )
        generated = result[0]["generated_text"]

        # The pipeline echoes the prompt in front of the completion. Cut the
        # prompt off by position rather than splitting on the last "Answer:",
        # which would discard the real answer whenever the model itself
        # writes "Answer:" (e.g. "... Final Answer: ...").
        if generated.startswith(prompt):
            return generated[len(prompt):].strip()

        # Fallback for output that does not begin with the prompt.
        generated = generated.strip()
        if "Answer:" in generated:
            generated = generated.split("Answer:")[-1].strip()
        return generated
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer."
# ==========================================================
# 7️⃣ Local Test
# ==========================================================
if __name__ == "__main__":
    # Manual smoke test: index three toy passages, retrieve for one question,
    # then print the answer produced in each generation mode.
    from vectorstore import build_faiss_index

    sample_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
        "Step 2: Click 'Export' to download a CSV summary.",
        "Step 3: Review the generated report in your downloads folder.",
    ]

    # Passages are encoded one at a time with the "passage: " prefix, the
    # counterpart of the "query: " prefix applied in retrieve_chunks.
    passage_vectors = []
    for text in sample_chunks:
        encoded = _query_model.encode(
            [f"passage: {text}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        passage_vectors.append(encoded[0])

    demo_index = build_faiss_index(passage_vectors)
    question = "What are the steps to export a report?"
    hits = retrieve_chunks(question, demo_index, sample_chunks)

    for title, use_reasoning in (("Strict Mode", False), ("Reasoning Mode", True)):
        print(f"\n--- {title} ---")
        print(generate_answer(question, hits, reasoning_mode=use_reasoning))