"""
qa.py — Phi-2 FAST + ReRank (with FULL Reasoning Mode)
-------------------------------------------------------
✅ Semantic retrieval (FAISS + cosine re-rank + neighbor-fill)
✅ Smart factual mode
✅ Deep reasoning mode (ChatGPT-like)
"""
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
print("✅ qa.py (Phi-2 FAST + ReRank + Full Reasoning) loaded from:", __file__)
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    print("❌ OPENAI_API_KEY not found in environment!")
else:
    print("✅ OPENAI_API_KEY loaded successfully (length:", len(api_key), ")")
# ==========================================================
# 1️⃣ Cache Setup
# ==========================================================
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR,
})
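# Note: these variables are set after the imports above, and some HF libraries
# read them at import time; the explicit cache_folder=CACHE_DIR arguments below
# keep the embedding downloads in the writable /tmp location regardless.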
# ==========================================================
# 2️⃣ Embedding Model
# ==========================================================
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
    print("✅ Loaded embedding model: intfloat/e5-small-v2")
except Exception as e:
    print(f"⚠️ Embedding load failed ({e}), using MiniLM fallback")
    _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", cache_folder=CACHE_DIR)
# ==========================================================
# 3️⃣ GPT-4o Model Setup (OpenAI API)
# ==========================================================
from openai import OpenAI
MODEL_NAME = "gpt-4o"
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
print(f"✅ Connected to OpenAI GPT model: {MODEL_NAME}")
# ==========================================================
# 4️⃣ Prompts
# ==========================================================
STRICT_PROMPT = (
    "You are an enterprise documentation assistant.\n"
    "Use ONLY the CONTEXT below to answer the QUESTION clearly and factually.\n"
    "If the answer isn’t in the document, reply exactly:\n"
    "'I don't know based on the provided document.'\n\n"
    "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
)
REASONING_PROMPT = (
    "You are an expert enterprise assistant capable of deep reasoning.\n"
    "Think step by step before answering. Use the CONTEXT below first, but also apply your world knowledge logically.\n"
    "Explain your reasoning concisely if it helps clarity.\n"
    "Avoid hallucination — if the document does not include the answer, say:\n"
    "'I don't know based on the provided document.'\n\n"
    "Context:\n{context}\n\nQuestion: {query}\nLet's reason this out carefully:\nAnswer:"
)
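# Both prompts (and the system message below) share the exact refusal string,
# so callers can detect an "I don't know based on the provided document."
# answer with a simple string match.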
# ==========================================================
# 5️⃣ Retrieval — FAISS + Re-rank + Neighbor Fill
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
                    min_similarity: float = 0.6, candidate_multiplier: int = 3):
    """Re-rank FAISS candidates by cosine similarity, then optionally fill with neighbors for context continuity."""
    if not index or not chunks:
        return []
    try:
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"], convert_to_numpy=True, normalize_embeddings=True
        )[0]
        # Initial FAISS search (over-fetch so the re-ranker has candidates to choose from)
        distances, indices = index.search(np.array([q_emb]).astype("float32"), top_k * candidate_multiplier)
        # Dedup while preserving order; drop the -1 padding FAISS returns when it finds fewer hits
        candidate_indices = [i for i in dict.fromkeys(indices[0]) if i >= 0]
        if not candidate_indices:
            return []
        # Re-rank by cosine similarity
        doc_embs = _query_model.encode(
            [f"passage: {chunks[i]}" for i in candidate_indices],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        sims = cosine_similarity([q_emb], doc_embs)[0]
        ranked = sorted(zip(candidate_indices, sims), key=lambda x: x[1], reverse=True)
        # Filter by min_similarity
        filtered = [idx for idx, sim in ranked if sim >= min_similarity]
        if len(filtered) > top_k:
            filtered = filtered[:top_k]
        # Neighbor fill if needed
        if len(filtered) < top_k:
            expanded = set(filtered)
            for idx in filtered:
                for neighbor in [idx - 1, idx + 1]:
                    if 0 <= neighbor < len(chunks):
                        expanded.add(neighbor)
                    if len(expanded) >= top_k:
                        break
                if len(expanded) >= top_k:
                    break
            filtered = sorted(expanded)[:top_k]
        return [chunks[i] for i in filtered]
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []
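# Example (a minimal sketch — assumes `index` is a FAISS index built over
# normalized "passage: "-prefixed embeddings of `chunks`, as in the __main__
# block at the bottom of this file):
#   top_chunks = retrieve_chunks("How do I export a report?", index, chunks,
#                                top_k=3, min_similarity=0.5)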
# ==========================================================
# 6️⃣ Answer Generation (GPT-4o with Full Reasoning)
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
    """
    Generates answers using GPT-4o.
    - reasoning_mode=False → strict factual mode (fast)
    - reasoning_mode=True  → reasoning-rich mode (longer, more explanatory)
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."
    # Format context with chunk tags
    context = "\n".join(f"[Chunk {i+1}] {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks))
    prompt = (REASONING_PROMPT if reasoning_mode else STRICT_PROMPT).format(
        context=context, query=query
    )
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are an expert enterprise documentation assistant. "
                        "Answer questions precisely using the provided context. "
                        "If reasoning_mode is enabled, provide deeper explanations and step-by-step logic. "
                        "If the document lacks information, respond exactly: "
                        "'I don't know based on the provided document.'"
                    ),
                },
                {"role": "user", "content": prompt},
            ],
            temperature=0.6 if reasoning_mode else 0.2,
            max_tokens=600 if reasoning_mode else 350,
            top_p=0.95,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"⚠️ GPT-4o generation failed: {e}")
        return "⚠️ Error: Could not generate an answer."
# ==========================================================
# 7️⃣ Local Test
# ==========================================================
if __name__ == "__main__":
from vectorstore import build_faiss_index
dummy_chunks = [
"Step 1: Open the dashboard and navigate to reports.",
"Step 2: Click 'Export' to download a CSV summary.",
"Step 3: Review the generated report in your downloads folder.",
"Appendix: Communication user creation steps are explained later in this guide."
]
embeddings = [
_query_model.encode([f"passage: {c}"], convert_to_numpy=True, normalize_embeddings=True)[0]
for c in dummy_chunks
]
index = build_faiss_index(embeddings)
query = "How do I create a communication user?"
retrieved = retrieve_chunks(query, index, dummy_chunks)
print("🔍 Retrieved:", retrieved)
print("💬 Answer:", generate_answer(query, retrieved, reasoning_mode=True))