|
|
""" |
|
|
qa.py — GPT-4o (SAP Gen AI Hub) + ReRank Retrieval |
|
|
-------------------------------------------------- |
|
|
✅ Semantic retrieval (FAISS + cosine re-rank + neighbor fill) |
|
|
✅ Bullet-aware similarity boost for procedural chunks |
|
|
✅ Embedding caching (per PDF) |
|
|
✅ Smart factual mode (fast) |
|
|
✅ Deep reasoning mode (ChatGPT-like) |
|
|
✅ genai_generate() helper for suggestions |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import pickle |
|
|
import hashlib |
|
|
import numpy as np |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client |
|
|
from gen_ai_hub.proxy.langchain.openai import ChatOpenAI |
|
|
|
|
|
print("✅ qa.py (GPT-4o via Gen AI Hub + Bullet-Aware Retrieval + Cache) loaded from:", __file__)

# Single writable location for every Hugging Face download/cache.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Point all Hugging Face cache variables at CACHE_DIR.
# NOTE(review): sentence_transformers is imported above, before these variables
# are set; libraries that read them only at import time may ignore them. The
# explicit cache_folder=CACHE_DIR passed to SentenceTransformer does not rely on them.
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "HF_MODULES_CACHE"),
        CACHE_DIR,
    )
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared encoder for queries and passages. The encoding helpers in this file
# add e5-style "query: " / "passage: " prefixes before calling it.
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
    print("✅ Loaded embedding model: intfloat/e5-small-v2 (fast mode)")
except Exception as exc:  # any load failure (network, disk, ...) triggers the fallback
    print(f"⚠️ Embedding load failed ({exc}), using MiniLM fallback")
    # NOTE(review): vectors from the fallback model live in a different embedding
    # space than e5's; caches/indexes built with one must not be queried with the other.
    _query_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("✅ Loading GPT-4o via SAP Gen AI Hub...")

# Service-key JSON expected to sit next to this module.
CRED_PATH = os.path.join(os.path.dirname(__file__), "GEN AI HUB PROXY.json")

try:
    with open(CRED_PATH, "r") as key_file:
        svcKey = json.load(key_file)

    # Every svcKey lookup is evaluated before os.environ is touched, so a
    # malformed key file leaves the environment unchanged.
    os.environ.update(
        AICORE_AUTH_URL=svcKey["url"],
        AICORE_CLIENT_ID=svcKey["clientid"],
        AICORE_CLIENT_SECRET=svcKey["clientsecret"],
        AICORE_RESOURCE_GROUP="default",
        AICORE_BASE_URL=svcKey["serviceurls"]["AI_API_URL"],
    )

    # NOTE(review): get_proxy_client presumably reads the AICORE_* variables set above.
    proxy_client = get_proxy_client("gen-ai-hub")
    chat_llm = ChatOpenAI(
        proxy_model_name="gpt-4o",
        proxy_client=proxy_client,
        temperature=0.3,
        max_tokens=1500,
    )
    print("✅ GPT-4o (via Gen AI Hub) ready for generation.")
except Exception as exc:
    # Best effort: the module still imports; callers check `chat_llm is None`.
    print(f"⚠️ Gen AI Hub setup failed: {exc}")
    chat_llm = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory holding pickled per-document embedding caches.
CACHE_EMB_DIR = "/tmp/embed_cache"
os.makedirs(CACHE_EMB_DIR, exist_ok=True)


def _hash_name(file_name: str):
    """Return a stable hex digest of *file_name*, used as a cache file stem."""
    # MD5 serves only as a filename fingerprint here, not for security.
    return hashlib.md5(file_name.encode()).hexdigest()


def _chunks_digest(chunks) -> str:
    """Return a short, order-sensitive fingerprint of the chunk texts."""
    digest = hashlib.sha256()
    for chunk in chunks:
        data = str(chunk).encode("utf-8", "surrogatepass")
        # Length-prefix each chunk so different splits of the same text differ.
        digest.update(len(data).to_bytes(8, "big"))
        digest.update(data)
    return digest.hexdigest()[:16]


def cache_embeddings(file_name: str, chunks, embed_func):
    """
    Return embeddings for *chunks*, loading them from disk when already cached.

    The cache key combines *file_name* with a fingerprint of the chunk texts.
    Re-uploading a changed PDF under the same name, or re-chunking it, therefore
    produces a new key instead of returning stale embeddings.

    Args:
        file_name: Document identifier (e.g. the uploaded PDF's name).
        chunks: Sequence of chunk texts; it is iterated once for the cache key
            and then passed unchanged to ``embed_func``.
        embed_func: Callable ``embed_func(chunks)`` that computes the embeddings.

    Returns:
        The value returned by ``embed_func`` (or the previously pickled one).
    """
    # NOTE(review): the key does not include the embedding model; if the model
    # changes (e.g. the MiniLM fallback kicks in), clear CACHE_EMB_DIR.
    cache_path = os.path.join(
        CACHE_EMB_DIR, f"{_hash_name(file_name)}_{_chunks_digest(chunks)}.pkl"
    )

    if os.path.exists(cache_path):
        try:
            with open(cache_path, "rb") as f:
                # SECURITY NOTE: pickle.load can execute arbitrary code, and /tmp
                # may be writable by other users. Only safe while CACHE_EMB_DIR
                # is trusted.
                embeddings = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError) as exc:
            # A truncated or corrupt cache file is ignored and regenerated.
            print(f"⚠️ Ignoring unreadable embedding cache for {file_name}: {exc}")
        else:
            print(f"🧠 Loaded cached embeddings for {file_name}")
            return embeddings

    print(f"💡 No cache found for {file_name}. Generating embeddings...")
    embeddings = embed_func(chunks)

    # Write to a temp file, then atomically rename it into place. A crash
    # mid-write can then never leave a partial pickle at cache_path.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(embeddings, f)
        os.replace(tmp_path, cache_path)
    except (OSError, pickle.PicklingError, TypeError) as exc:
        # Caching is best-effort; the freshly computed embeddings are still valid.
        print(f"⚠️ Could not cache embeddings for {file_name}: {exc}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass
    else:
        print(f"💾 Cached embeddings saved for {file_name}")
    return embeddings
|
|
|
|
|
def embed_chunks(chunks, batch_size=32):
    """
    Encode text chunks as normalized passage embeddings.

    Each chunk is prefixed with ``"passage: "``, matching the ``"query: "``
    prefix that retrieve_chunks uses for queries. The chunks are encoded with
    the module-level SentenceTransformer, ``batch_size`` texts per forward pass.

    Args:
        chunks: Iterable of chunk strings.
        batch_size: Number of texts encoded per batch.

    Returns:
        np.ndarray with one row per chunk. For empty input this is an empty
        ``(0, dim)`` float32 array, not a 1-D ``array([])``. The rank then
        matches the non-empty case.
    """
    passages = [f"passage: {c}" for c in chunks]

    if not passages:
        # NOTE(review): presumably build_faiss_index needs a 2-D array; confirm.
        dim = _query_model.get_sentence_embedding_dimension() or 0
        print(f"⚡ Embedded 0 chunks in batches of {batch_size}")
        return np.empty((0, dim), dtype=np.float32)

    # encode() already batches internally, so one call replaces the manual loop.
    embeddings = _query_model.encode(
        passages,
        batch_size=batch_size,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    print(f"⚡ Embedded {len(embeddings)} chunks in batches of {batch_size}")
    return np.asarray(embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prompt template for the default factual mode: answer only from the supplied
# context, otherwise reply with the fixed "I don't know" sentence.
# Placeholders {context} and {query} are filled with str.format() in
# generate_answer(), so any literal braces added here must be doubled ({{ }}).
STRICT_PROMPT = (
    "You are an enterprise documentation assistant.\n"
    "Use all relevant information from the CONTEXT below.\n"
    "If multiple related points appear across chunks, combine them logically into one clear answer.\n"
    "Keep the answer concise but complete. Do not invent facts outside the provided content.\n"
    "If the answer cannot be found even after considering all chunks, say exactly:\n"
    "'I don't know based on the provided document.'\n\n"
    "Context:\n{context}\n\nQuestion: {query}\nAnswer:"
)


# Prompt template for reasoning mode: step-by-step synthesis across chunks,
# explicitly allowed to fill gaps with general knowledge.
# Same {context}/{query} placeholders and str.format() caveat as STRICT_PROMPT.
# NOTE(review): allowing general knowledge here conflicts with the system
# message in generate_answer(), which demands the "I don't know" reply when the
# answer is not in the document. Decide which rule should win.
REASONING_PROMPT = (
    "You are an expert enterprise assistant capable of reasoning.\n"
    "Think step by step and synthesize information even if scattered across chunks.\n"
    "Base your answer primarily on the CONTEXT, but if multiple partial clues exist, combine them logically.\n"
    "You may fill reasonable gaps with general knowledge to form a complete answer.\n"
    "If absolutely nothing in the document relates, say exactly:\n"
    "'I don't know based on the provided document.'\n\n"
    "Context:\n{context}\n\nQuestion: {query}\nLet's reason step-by-step:\nAnswer:"
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from vectorstore import build_faiss_index |
|
|
|
|
|
def _is_procedural_chunk(text: str) -> bool:
    """Return True if *text* starts like a bullet or numbered list item."""
    # Heuristic: a leading run of '-', '•' or digits followed by '.' or whitespace.
    # NOTE(review): this also matches prose such as "2024 was ..." or "3.5 GB".
    return re.match(r"^[-•\d]+[\.\s]", text.strip()) is not None


def _with_neighbors(indices, n_chunks: int) -> list:
    """Return *indices* plus each one's immediate neighbours, deduplicated and sorted."""
    expanded = set(indices)
    for idx in indices:
        expanded.update(n for n in (idx - 1, idx + 1) if 0 <= n < n_chunks)
    return sorted(expanded)


def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
                    min_similarity: float = 0.6, candidate_multiplier: int = 3,
                    embeddings: list = None, bullet_boost: float = 0.05):
    """
    Retrieve the chunks most relevant to *query*, preserving context continuity.

    Pipeline:
      1. Run a FAISS search for ``max(top_k * candidate_multiplier, top_k + 2)``
         candidates.
      2. Re-encode the candidates and re-rank them by cosine similarity, adding
         ``bullet_boost`` to chunks that look like bullet or numbered items.
      3. Keep the best ``top_k`` whose boosted score is at least
         ``min_similarity``.
      4. Add each kept chunk's immediate neighbours and return everything in
         document order.

    Args:
        query: User question (encoded with the ``"query: "`` prefix).
        index: FAISS-style index exposing ``search`` (and optionally ``d``).
        chunks: Chunk texts; index ids are positions in this list.
        top_k: Maximum number of chunks selected before neighbour fill.
        min_similarity: Minimum (boosted) cosine similarity for a chunk to be kept.
        candidate_multiplier: Over-fetch factor for the FAISS search.
        embeddings: Optional chunk embeddings used to rebuild ``index`` when its
            dimension does not match the query model. May be a list or a NumPy
            array (embed_chunks returns an array).
        bullet_boost: Score bonus for chunks that look like list items.

    Returns:
        List of chunk strings, possibly empty. Errors are printed, not raised.
    """
    if not index or not chunks:
        print("⚠️ No FAISS index or chunks provided — returning empty result.")
        return []

    try:
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        dim = q_emb.shape[0]

        # Guard against an index built with a different embedding size.
        if hasattr(index, "d") and index.d != dim:
            print(f"⚠️ FAISS dimension mismatch: index={index.d}, query={dim}")
            # `if embeddings:` raises for NumPy arrays ("truth value ... is
            # ambiguous"), so test for absence/emptiness explicitly.
            if embeddings is None or len(embeddings) == 0:
                return []
            print("🔄 Rebuilding FAISS index...")
            # NOTE: the rebuilt index is local to this call; the caller keeps the
            # stale one, so the rebuild repeats on every query.
            index = build_faiss_index(embeddings)
            # Rebuilding only helps if `embeddings` came from the current model.
            if hasattr(index, "d") and index.d != dim:
                print(f"⚠️ Rebuilt index still mismatched: index={index.d}, query={dim}")
                return []

        n_chunks = len(chunks)
        num_candidates = max(top_k * candidate_multiplier, top_k + 2)
        _, found = index.search(np.array([q_emb]).astype("float32"), num_candidates)
        # Drop negative ids (FAISS pads missing hits with -1) and ids beyond
        # `chunks` (index/chunk list out of sync). dict.fromkeys dedupes in order.
        candidate_indices = list(dict.fromkeys(
            int(i) for i in found[0] if 0 <= i < n_chunks
        ))
        if not candidate_indices:
            print("⚠️ FAISS returned no usable candidates.")
            return []

        # Re-rank candidates using freshly computed passage embeddings.
        doc_embs = _query_model.encode(
            [f"passage: {chunks[i]}" for i in candidate_indices],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        sims = cosine_similarity([q_emb], doc_embs)[0]
        scored = [
            (idx, sim + bullet_boost if _is_procedural_chunk(chunks[idx]) else sim)
            for idx, sim in zip(candidate_indices, sims)
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        selected = [idx for idx, score in scored if score >= min_similarity][:top_k]

        final_chunks = [chunks[i] for i in _with_neighbors(selected, n_chunks)]
        print(f"✅ Retrieved {len(final_chunks)} chunks (bullet-aware + continuity).")
        return final_chunks

    except Exception as e:
        # Boundary for the whole retrieval pipeline: report and degrade to "no context".
        print(f"⚠️ Retrieval error: {repr(e)}")
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
    """
    Answer *query* with GPT-4o, grounded in the retrieved document chunks.

    Args:
        query: The user's question.
        retrieved_chunks: Chunk texts (e.g. from retrieve_chunks); each is
            labelled ``[Chunk i]`` (1-based) in the prompt.
        reasoning_mode: Use REASONING_PROMPT (step-by-step synthesis) instead
            of STRICT_PROMPT (strictly factual).

    Returns:
        The model's stripped reply. Returns a fixed user-facing message instead
        when there are no chunks, the LLM is not initialized, or the call fails.
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."
    if chat_llm is None:
        return "⚠️ GPT-4o not initialized. Check credentials or rebuild the Space."

    context = "\n".join(
        f"[Chunk {i}] {chunk.strip()}" for i, chunk in enumerate(retrieved_chunks, start=1)
    )
    template = REASONING_PROMPT if reasoning_mode else STRICT_PROMPT
    prompt = template.format(context=context, query=query)

    # The system message branches on "reasoning_mode is on/off", so it must also
    # say which mode is active. Without that, the model cannot follow it.
    mode = "on" if reasoning_mode else "off"
    system_content = (
        "You are an expert enterprise documentation assistant. "
        "When reasoning_mode is off, stay strictly factual and concise. "
        "When reasoning_mode is on, combine insights across chunks logically "
        "and explain briefly. "
        "If the answer is not in the document, reply exactly: "
        "'I don't know based on the provided document.' "
        f"reasoning_mode is currently {mode}."
    )
    messages = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": prompt},
    ]

    try:
        response = chat_llm.invoke(messages)
        return response.content.strip()
    except Exception as e:
        print(f"⚠️ GPT-4o generation failed: {e}")
        return "⚠️ Error: Could not generate an answer."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def genai_generate(prompt: str) -> str:
    """
    Run a single-turn GPT-4o completion (e.g. query suggestions, summaries).

    Shares the module-level ``chat_llm`` (SAP Gen AI Hub connection) with the
    main assistant.

    Raises:
        RuntimeError: if ``chat_llm`` was not initialized at import time.

    Returns:
        The stripped model reply, or a fixed warning string if the call fails.
    """
    # chat_llm is only read here, so no `global` declaration is needed.
    if chat_llm is None:
        raise RuntimeError("⚠️ GPT-4o not initialized. Check credentials or rebuild the Space.")

    system_msg = {"role": "system", "content": "You are a concise, intelligent text generator."}
    user_msg = {"role": "user", "content": prompt.strip()}

    try:
        return chat_llm.invoke([system_msg, user_msg]).content.strip()
    except Exception as exc:
        print(f"⚠️ genai_generate() failed: {exc}")
        return "⚠️ Unable to generate response."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: embed a few toy chunks, index them, then retrieve and answer.
    # build_faiss_index is already bound by the module-level vectorstore import.
    sample_chunks = [
        "- Step 1: Enable order confirmation capability.",
        "- Step 2: Configure supplier email.",
        "Setup instructions and configuration details.",
        "Prerequisites for automation are described here.",
    ]
    sample_index = build_faiss_index(embed_chunks(sample_chunks))

    question = "What are the prerequisites for commerce automation?"
    hits = retrieve_chunks(question, sample_index, sample_chunks)
    print("🔍 Retrieved:", hits)
    print("💬 Answer:", generate_answer(question, hits, reasoning_mode=False))
|
|
|