|
|
""" |
|
|
qa.py — Phi-2 Fast + Smart Reasoning Mode (Hybrid) |
|
|
------------------------------------------------- |
|
|
✅ Uses intfloat/e5-small-v2 for embeddings |
|
|
✅ Uses microsoft/phi-2 (fast CPU quantized) |
|
|
✅ Reasoning Mode toggle integrated cleanly |
|
|
✅ Retrieval unaffected by reasoning mode |
|
|
""" |
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
import torch |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
|
|
|
print("✅ qa.py (Phi-2 Hybrid) loaded from:", __file__)

# All HuggingFace downloads go to a writable tmp cache (container-friendly).
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Point every HF cache-related env var at the same directory.
for _env_var in (
    "HF_HOME",
    "TRANSFORMERS_CACHE",
    "HF_DATASETS_CACHE",
    "HF_MODULES_CACHE",
):
    os.environ[_env_var] = CACHE_DIR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Embedding model used for both queries and passages.
# Preferred: intfloat/e5-small-v2 (expects "query: "/"passage: " prefixes,
# which retrieve_chunks and the __main__ demo below apply).
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
    print("✅ Loaded embedding model: intfloat/e5-small-v2")
except Exception as e:
    # Broad catch is deliberate best-effort: any load failure (network,
    # cache, disk) downgrades to the smaller MiniLM encoder rather than
    # aborting import.
    print(f"⚠️ Fallback to MiniLM due to {e}")
    _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the Phi-2 generation model. On any failure, _answer_model is set to
# None and generate_answer() degrades gracefully.
try:
    MODEL_NAME = "microsoft/phi-2"
    print(f"✅ Loading LLM: {MODEL_NAME} (optimized for reasoning)")

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

    # The model is unconditionally run on CPU below, so load in bfloat16
    # regardless of CUDA availability — float16 matmuls are poorly supported
    # on CPU, and the previous cuda-conditional dtype was dead logic.
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    ).to("cpu")

    # model_kwargs are omitted: they only apply when the pipeline itself
    # instantiates the model, and we pass an already-loaded instance.
    _answer_model = pipeline(
        "text-generation",
        model=_model,
        tokenizer=_tokenizer,
        device=-1,  # force CPU
    )

    print("✅ Phi-2 text-generation pipeline ready.")

except Exception as e:
    print(f"⚠️ Phi-2 load failed: {e}")
    _answer_model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default (fast) prompt: the model must stay strictly inside the context.
STRICT_PROMPT = (
    "Answer based ONLY on the context below.\nIf the answer isn’t in the "
    "context, say: 'I don't know based on the provided document.'\n\n"
    "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
)
|
|
|
|
|
# Reasoning-mode prompt: step-by-step answers, still grounded in the context.
REASONING_PROMPT = (
    "You are an expert assistant. Use the context and your reasoning ability "
    "to form a clear, step-by-step answer.\nBe concise yet complete. If the "
    "context doesn’t contain the answer, say: 'I don't know based on the "
    "provided document.'\n\n"
    "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5):
    """Retrieve context chunks for *query* via FAISS nearest-neighbour search.

    The query is embedded with the e5 ``"query: "`` prefix (normalized, so
    inner-product search behaves as cosine similarity), the top ``top_k * 2``
    hits are fetched, and each hit is expanded to include its immediate
    neighbours so answers spanning chunk boundaries are not cut off.

    Args:
        query: Natural-language question.
        index: FAISS index whose vectors correspond 1:1 to ``chunks``.
        chunks: The text chunks that were indexed.
        top_k: Base number of hits (before neighbour expansion).

    Returns:
        Selected chunks in document order, or ``[]`` on any failure.
    """
    if not index or not chunks:
        return []

    try:
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * 2)

        selected = set()
        for idx in indices[0]:
            # FAISS pads results with -1 when the index holds fewer vectors
            # than requested; skip them instead of spuriously adding chunk 0.
            if idx < 0:
                continue
            # Pull in the neighbouring chunks for extra context.
            for i in range(max(0, idx - 1), min(len(chunks), idx + 2)):
                selected.add(i)

        return [chunks[i] for i in sorted(selected)]
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
    """Generate an answer with Phi-2 over the retrieved context.

    Args:
        query: The user's question.
        retrieved_chunks: Context chunks from :func:`retrieve_chunks`.
        reasoning_mode: ``False`` → strict, context-only factual answer
            (fast); ``True`` → step-by-step analytical answer (slower,
            larger token budget).

    Returns:
        The generated answer string, or an apology/error message when no
        context was found, the model failed to load, or generation raised.
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    # Phi-2 may have failed to load at import time (_answer_model is None);
    # fail gracefully instead of raising TypeError on a None "pipeline".
    if _answer_model is None:
        return "⚠️ Error: Could not generate an answer."

    context = "\n".join(chunk.strip() for chunk in retrieved_chunks)

    # NOTE: these local templates intentionally differ from the module-level
    # prompt constants — the reasoning variant here is looser (it may blend
    # general knowledge with the context). Lower-cased to make clear they are
    # locals, not the module constants.
    reasoning_prompt = (
        "You are an expert assistant with strong reasoning skills.\n"
        "Think step by step and form a detailed, logical answer.\n"
        "You can combine hints from the context with your general understanding.\n"
        "If the context doesn't mention the answer, acknowledge that.\n\n"
        "Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    )
    strict_prompt = (
        "Answer based ONLY on the context below.\n"
        "If the answer isn’t in the context, say: 'I don't know based on the provided document.'\n\n"
        "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
    )

    prompt = (reasoning_prompt if reasoning_mode else strict_prompt).format(
        context=context, query=query
    )

    try:
        # Decoding is greedy either way (do_sample=False), so temperature is
        # omitted — passing it with sampling off has no effect and triggers a
        # warning. Only the output-token budget differs between modes.
        result = _answer_model(
            prompt,
            max_new_tokens=180 if reasoning_mode else 120,
            do_sample=False,
            pad_token_id=_tokenizer.eos_token_id,
        )

        raw = result[0]["generated_text"].strip()
        # The pipeline echoes the prompt; keep only the text after the final
        # "Answer:" marker.
        if "Answer:" in raw:
            raw = raw.split("Answer:")[-1].strip()
        return raw

    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: index three toy chunks, retrieve, and answer in both modes.
    from vectorstore import build_faiss_index

    dummy_chunks = [
        "Step 1: Open the dashboard and navigate to reports.",
        "Step 2: Click 'Export' to download a CSV summary.",
        "Step 3: Review the generated report in your downloads folder.",
    ]

    # Embed each chunk with the e5 "passage: " prefix, one at a time.
    embeddings = []
    for chunk in dummy_chunks:
        vec = _query_model.encode(
            [f"passage: {chunk}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        embeddings.append(vec)
    index = build_faiss_index(embeddings)

    query = "What are the steps to export a report?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)

    print("\n--- Strict Mode ---")
    print(generate_answer(query, retrieved, reasoning_mode=False))

    print("\n--- Reasoning Mode ---")
    print(generate_answer(query, retrieved, reasoning_mode=True))
|
|
|