Shubham170793's picture
Update src/qa.py
43cd83d verified
raw
history blame
5.43 kB
"""
qa.py — Fast, Reasoning-Enabled Phi-2 Version
----------------------------------------------
• Uses SentenceTransformer (E5-small) for embeddings
• Uses microsoft/phi-2 for generation
• Retains reasoning vs factual modes
• Optimized for speed and low memory usage on CPU
"""
import os

# ==========================================================
# Hugging Face Cache Setup
# ==========================================================
# The cache location is exported BEFORE the Hugging Face libraries are
# imported: they resolve their default cache paths from the environment when
# they are first imported, so setting these variables afterwards would not
# change those defaults.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR,
})

# Third-party imports intentionally follow the environment setup above.
import numpy as np  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline  # noqa: E402
from sklearn.metrics.pairwise import cosine_similarity  # noqa: E402

print("✅ qa.py (Phi-2 optimized) loaded from:", __file__)
print(f"✅ Using Hugging Face cache at {CACHE_DIR}")
# ==========================================================
# Query Embedding Model
# ==========================================================
def _load_query_model():
    """Load the query embedding model.

    Tries ``intfloat/e5-small-v2`` first; if anything goes wrong while loading
    it (or while reporting success), the failure is printed and
    ``sentence-transformers/all-MiniLM-L6-v2`` is loaded instead.  An error
    raised by the fallback load is not caught.
    """
    try:
        encoder = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
        print("✅ Loaded embedding model: intfloat/e5-small-v2")
        return encoder
    except Exception as load_error:
        print(f"⚠️ Fallback to MiniLM due to {load_error}")
        return SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2",
            cache_folder=CACHE_DIR,
        )


_query_model = _load_query_model()
# ==========================================================
# Phi-2 Model (Causal LM)
# ==========================================================
MODEL_NAME = "microsoft/phi-2"
print(f"✅ Loading LLM: {MODEL_NAME}")

# Tokenizer and weights are both downloaded into the shared cache directory.
_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

# Options forwarded to from_pretrained when loading the causal LM weights.
_model_load_options = {
    "cache_dir": CACHE_DIR,
    "torch_dtype": "auto",
    "low_cpu_mem_usage": True,
}
_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **_model_load_options)

# Text-generation pipeline used by generate_answer; device=-1 keeps it on CPU.
_answer_model = pipeline(
    "text-generation",
    model=_model,
    tokenizer=_tokenizer,
    device=-1,
)
print("✅ Phi-2 generation pipeline ready.")
# ==========================================================
# Prompt Templates
# ==========================================================
# Prompt used when reasoning mode is on: the model is told it may combine the
# CONTEXT with its general understanding and briefly explain its reasoning.
# The {context} and {query} placeholders are filled in with str.format.
REASONING_PROMPT = """
You are an intelligent enterprise assistant.
Use the CONTEXT below and your general understanding to answer the QUESTION logically and clearly.
Explain your reasoning briefly if helpful.
---
CONTEXT:
{context}
---
QUESTION:
{query}
---
ANSWER:
"""
# Prompt used when reasoning mode is off: the model is told to answer from the
# CONTEXT only and to reply with a fixed "I don't know" sentence otherwise.
# Same {context} / {query} placeholders as the reasoning prompt.
STRICT_PROMPT = """
You are an enterprise document assistant.
Use ONLY the CONTEXT below to answer the QUESTION clearly and factually.
If the answer is not found in the context, reply exactly:
"I don't know based on the provided document."
---
CONTEXT:
{context}
---
QUESTION:
{query}
---
ANSWER:
"""
# ==========================================================
# Retrieve Chunks
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """Retrieve the top-K most relevant chunks for ``query`` (no re-ranking).

    Args:
        query: The user's question; surrounding whitespace is ignored.
        index: Vector index exposing ``search(vectors, k)`` that returns
            ``(distances, indices)`` — presumably the FAISS index built from
            the chunk embeddings (confirm against the caller).  Positions in
            the index are assumed to line up with positions in ``chunks``.
        chunks: The text chunks, in the order they were added to ``index``.
        top_k: Maximum number of chunks to return.

    Returns:
        Up to ``top_k`` chunks, in the order the index ranked them.  An empty
        list when there is no index, no chunks, or ``top_k`` is not positive.
    """
    if not index or not chunks or top_k <= 0:
        return []

    # The "query: " prefix mirrors the "passage: " prefix used when the
    # chunks are embedded.
    query_emb = _query_model.encode(
        [f"query: {query.strip()}"],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]

    # Never ask for more neighbours than there are chunks to return.
    k = min(top_k, len(chunks))
    _, indices = index.search(np.asarray([query_emb], dtype="float32"), k)

    # FAISS pads missing results with -1; used as a list index that would
    # silently return the LAST chunk.  Keep only positions that really exist.
    return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
# ==========================================================
# Generate Answer (Phi-2)
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = True):
    """Generate an answer to ``query`` from the retrieved context using Phi-2.

    Args:
        query: The user's question.
        retrieved_chunks: Context passages, typically from ``retrieve_chunks``.
            Empty or whitespace-only entries are ignored.
        reasoning_mode: ``True`` uses the reasoning prompt (context plus
            general knowledge); ``False`` uses the strict, context-only prompt.

    Returns:
        The answer text.  A fixed apology message is returned when there is no
        usable context, and a fixed error message when generation fails.

    Decoding is greedy (``do_sample=False``), so no ``temperature`` is passed:
    it has no effect without sampling and only triggers a Transformers
    warning.  ``early_stopping`` is likewise omitted because it only applies
    to beam search.
    """
    no_context_reply = "Sorry, I couldn’t find relevant information in the document."
    if not retrieved_chunks:
        return no_context_reply

    # Skip blank chunks so the model is never prompted with an empty context.
    passages = [chunk.strip() for chunk in retrieved_chunks if chunk and chunk.strip()]
    if not passages:
        return no_context_reply

    template = REASONING_PROMPT if reasoning_mode else STRICT_PROMPT
    prompt = template.format(context="\n".join(passages), query=query)

    try:
        result = _answer_model(
            prompt,
            max_new_tokens=180,  # keeps output short & fast
            do_sample=False,  # greedy decoding: deterministic
            num_beams=1,  # no beam search for speed
            return_full_text=False,  # only the continuation, not the echoed prompt
        )
        generated = result[0]["generated_text"]
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer."

    # A causal LM may keep imitating the prompt layout and invent further
    # CONTEXT/QUESTION/ANSWER sections; keep only the text before the first one.
    answer = generated
    for marker in ("\n---", "\nCONTEXT:", "\nQUESTION:", "\nANSWER:"):
        answer = answer.split(marker, 1)[0]
    answer = answer.strip()

    # If trimming removed everything, fall back to the raw continuation.
    return answer if answer else generated.strip()
# ==========================================================
# Local Test (optional)
# ==========================================================
if __name__ == "__main__":
    # Smoke test: index three sample passages, retrieve for one question,
    # then generate an answer from what was retrieved.
    from vectorstore import build_faiss_index

    dummy_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
        "Step 2: Click 'Export' to download a CSV summary.",
        "Step 3: Review the generated report in your downloads folder."
    ]

    # Embed each passage on its own, using the "passage: " prefix.
    passage_vectors = []
    for passage in dummy_chunks:
        encoded = _query_model.encode(
            [f"passage: {passage}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        passage_vectors.append(encoded[0])
    index = build_faiss_index(passage_vectors)

    query = "What are the steps to export a report?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))