|
|
""" |
|
|
qa.py — GPT-4o (SAP Gen AI Hub) + ReRank Retrieval
--------------------------------------------------
✅ Semantic retrieval (FAISS + cosine re-rank + neighbor fill)
✅ Bullet-aware similarity boost for procedural chunks
✅ Embedding caching (per PDF + chunk config aware)
✅ Smart factual mode (fast)
✅ Deep reasoning mode (ChatGPT-like)
✅ genai_generate() helper for suggestions
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import pickle |
|
|
import hashlib |
|
|
import numpy as np |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from gen_ai_hub.proxy.core.proxy_clients import get_proxy_client |
|
|
from gen_ai_hub.proxy.langchain.openai import ChatOpenAI |
|
|
|
|
|
# Announce which copy of the module was imported (helps when several are on the path).
print("✅ qa.py (GPT-4o via Gen AI Hub + Bullet-Aware Retrieval + Cache) loaded from:", __file__)

# Single writable location shared by every Hugging Face cache.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Point all Hugging Face cache variables at CACHE_DIR.
# NOTE(review): these are exported after `sentence_transformers` has already
# been imported at the top of the file; any library that reads them at import
# time may not pick them up — confirm. The model loads in this file pass
# `cache_folder=CACHE_DIR` explicitly.
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "HF_MODULES_CACHE"),
        CACHE_DIR,
    )
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the sentence-embedding model once at import time. If the primary model
# cannot be loaded for any reason, fall back to MiniLM; a failure of the
# fallback itself is not caught and aborts the import.
try:
    _query_model = SentenceTransformer("intfloat/e5-small-v2", cache_folder=CACHE_DIR)
    print("✅ Loaded embedding model: intfloat/e5-small-v2 (fast mode)")
except Exception as load_error:
    print(f"⚠️ Embedding load failed ({load_error}), using MiniLM fallback")
    _query_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("✅ Loading GPT-4o via SAP Gen AI Hub...")

# The service key is expected next to this module.
CRED_PATH = os.path.join(os.path.dirname(__file__), "GEN AI HUB PROXY.json")

# Any failure (missing file, malformed key, proxy/client error) leaves
# `chat_llm` as None; the generation helpers below check for that.
try:
    with open(CRED_PATH, "r") as key_file:
        svcKey = json.load(key_file)

    # Export the service-key fields as AICORE_* environment variables
    # before the proxy client is created.
    os.environ.update(
        AICORE_AUTH_URL=svcKey["url"],
        AICORE_CLIENT_ID=svcKey["clientid"],
        AICORE_CLIENT_SECRET=svcKey["clientsecret"],
        AICORE_RESOURCE_GROUP="default",
        AICORE_BASE_URL=svcKey["serviceurls"]["AI_API_URL"],
    )

    proxy_client = get_proxy_client("gen-ai-hub")
    chat_llm = ChatOpenAI(
        proxy_model_name="gpt-4o",
        proxy_client=proxy_client,
        temperature=0.3,
        max_tokens=1500,
    )
    print("✅ GPT-4o (via Gen AI Hub) ready for generation.")
except Exception as setup_error:
    print(f"⚠️ Gen AI Hub setup failed: {setup_error}")
    chat_llm = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def embed_chunks(chunks, batch_size: int = 32):
    """Encode text chunks with the module-level embedding model.

    Each chunk is prefixed with ``"passage: "`` and encoded in slices of
    ``batch_size`` with normalisation enabled.

    Args:
        chunks: Sequence of text chunks. A falsy value (``None``, ``[]``)
            short-circuits to an empty array.
        batch_size: Number of chunks handed to the model per ``encode`` call.

    Returns:
        ``np.ndarray`` built from one embedding row per chunk, or an empty
        array when there is nothing to embed.
    """
    if not chunks:
        return np.array([])

    encode_options = {
        "convert_to_numpy": True,
        "normalize_embeddings": True,
        "show_progress_bar": False,
    }

    vectors = []
    for start in range(0, len(chunks), batch_size):
        window = chunks[start:start + batch_size]
        prefixed = [f"passage: {text}" for text in window]
        encoded = _query_model.encode(prefixed, **encode_options)
        vectors += list(encoded)

    print(f"⚡ Embedded {len(vectors)} chunks in batches of {batch_size}")
    return np.array(vectors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory holding the pickled per-document embedding caches; created at
# import time so the cache helpers can assume it exists.
CACHE_EMB_DIR = "/tmp/embed_cache"
os.makedirs(name=CACHE_EMB_DIR, exist_ok=True)
|
|
|
|
|
def _hash_name(file_name: str, chunk_size: int, overlap: int, num_chunks: int):
    """Return a short fingerprint for a file plus its chunking configuration.

    The four arguments are joined with underscores and hashed with MD5; only
    the first eight hex characters are kept. MD5 is used purely as a cheap
    fingerprint here, not for security.
    """
    parts = (file_name, chunk_size, overlap, num_chunks)
    # format(x) is what an f-string placeholder applies to each value.
    fingerprint = "_".join(map(format, parts))
    digest = hashlib.md5(fingerprint.encode())
    return digest.hexdigest()[:8]
|
|
|
|
|
def _clean_old_caches(base_name: str, keep_latest: int = 5, cache_dir: str = None):
    """Delete all but the newest embedding caches belonging to one document.

    Cache files written by ``cache_embeddings`` are named
    ``<base_name>_cs<chunk_size>_ov<overlap>_<hash>.pkl``. Matching on that
    full pattern (rather than on the bare ``base_name`` prefix) keeps a
    document called ``doc.pdf`` from claiming, and deleting, the caches of
    ``doc.pdf2`` or ``doc.pdf.bak``.

    Args:
        base_name: Base file name of the document whose caches are pruned.
        keep_latest: How many of the most recently modified caches to keep.
            Negative values are treated as 0.
        cache_dir: Directory to scan. Defaults to ``CACHE_EMB_DIR``.

    Returns:
        None. Removal is best-effort: a file that cannot be deleted is
        reported and skipped.
    """
    directory = CACHE_EMB_DIR if cache_dir is None else cache_dir
    prefix = f"{base_name}_cs"

    dated = []
    for entry in os.listdir(directory):
        if not (entry.startswith(prefix) and entry.endswith(".pkl")):
            continue
        try:
            modified = os.path.getmtime(os.path.join(directory, entry))
        except OSError:
            # The file vanished between listdir() and stat(); nothing to prune.
            continue
        dated.append((modified, entry))

    # Newest first, so everything past the first `keep_latest` entries is stale.
    dated.sort(reverse=True)
    for _, old_file in dated[max(keep_latest, 0):]:
        try:
            os.remove(os.path.join(directory, old_file))
        except OSError as err:
            print(f"⚠️ Could not remove old cache {old_file}: {err}")
        else:
            print(f"🧹 Removed old cache: {old_file}")
|
|
|
|
|
def cache_embeddings(file_name: str, chunks, embed_func, chunk_size: int = None, overlap: int = None):
    """Return embeddings for ``chunks``, reusing an on-disk cache when possible.

    The cache file name combines the document's base name, the chunk size and
    overlap as passed, and a short hash of (file name, chunk size, overlap,
    number of chunks). For the hash only, a missing chunk size / overlap is
    replaced by 1000 / 100.

    NOTE(review): the key does not depend on the chunk *contents* or on the
    embedding model, so a changed document with the same name and chunk count
    will reuse stale embeddings — confirm whether that is acceptable.

    Args:
        file_name: Path or name of the source document.
        chunks: Sequence of text chunks; its length is part of the cache key.
        embed_func: Callable taking ``chunks`` and returning the embeddings.
        chunk_size: Chunk size used when splitting (part of the key).
        overlap: Chunk overlap used when splitting (part of the key).

    Returns:
        Whatever ``embed_func(chunks)`` returned, either freshly computed or
        loaded back from the cache.
    """
    base_name = os.path.basename(file_name)
    cache_key = _hash_name(file_name, chunk_size or 1000, overlap or 100, len(chunks))
    cache_file = f"{base_name}_cs{chunk_size}_ov{overlap}_{cache_key}.pkl"
    cache_path = os.path.join(CACHE_EMB_DIR, cache_file)

    if os.path.exists(cache_path):
        try:
            # NOTE(security): pickle can execute arbitrary code on load and the
            # cache directory lives under /tmp. Only safe while nothing
            # untrusted can write there.
            with open(cache_path, "rb") as cache_fh:
                cached = pickle.load(cache_fh)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError,
                ImportError, IndexError, ValueError) as err:
            # A truncated or corrupt cache must not break retrieval; rebuild it.
            print(f"⚠️ Ignoring unreadable cache {cache_file} ({err!r}). Regenerating embeddings...")
        else:
            print(f"🧠 Loaded cached embeddings for {base_name} ({chunk_size}/{overlap})")
            return cached
    else:
        print(f"💡 No cache found for {base_name} ({chunk_size}/{overlap}). Generating new embeddings...")

    embeddings = embed_func(chunks)

    # Write to a temporary sibling and rename it into place, so an interrupted
    # write can never leave a half-written pickle under the final name.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as cache_fh:
            pickle.dump(embeddings, cache_fh)
        os.replace(tmp_path, cache_path)
    except (OSError, pickle.PicklingError) as err:
        # Caching is only an optimisation: keep the freshly computed result.
        print(f"⚠️ Could not save embedding cache {cache_file}: {err!r}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created or is already gone
        return embeddings

    print(f"💾 Cached embeddings saved as {cache_file}")
    _clean_old_caches(base_name, keep_latest=5)
    return embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prompt template for factual mode: answer only from the supplied context.
# Placeholders: {context} (numbered chunks) and {query} (user question).
STRICT_PROMPT = "\n".join((
    "You are an enterprise documentation assistant.",
    "Use all relevant information from the CONTEXT below.",
    "If multiple related points appear across chunks, combine them logically into one clear answer.",
    "Keep the answer concise but complete. Do not invent facts outside the provided content.",
    "If the answer cannot be found even after considering all chunks, say exactly:",
    "'I don't know based on the provided document.'",
    "",
    "Context:",
    "{context}",
    "",
    "Question: {query}",
    "Answer:",
))
|
|
|
|
|
# Prompt template for reasoning mode: step-by-step synthesis across chunks,
# with permission to fill reasonable gaps from general knowledge.
# Placeholders: {context} (numbered chunks) and {query} (user question).
REASONING_PROMPT = "\n".join((
    "You are an expert enterprise assistant capable of reasoning.",
    "Think step by step and synthesize information even if scattered across chunks.",
    "Base your answer primarily on the CONTEXT, but if multiple partial clues exist, combine them logically.",
    "You may fill reasonable gaps with general knowledge to form a complete answer.",
    "If absolutely nothing in the document relates, say exactly:",
    "'I don't know based on the provided document.'",
    "",
    "Context:",
    "{context}",
    "",
    "Question: {query}",
    "Let's reason step-by-step:",
    "Answer:",
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from vectorstore import build_faiss_index |
|
|
|
|
|
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 5,
                    min_similarity: float = 0.6, candidate_multiplier: int = 3,
                    embeddings: list = None):
    """Retrieve the chunks most relevant to ``query``.

    Pipeline:
      1. Encode the query (``"query: "`` prefix, normalised).
      2. Ask the index for ``max(top_k * candidate_multiplier, top_k + 2)``
         candidates.
      3. Re-encode the candidate chunks and re-rank them by cosine similarity,
         adding 0.05 to chunks that start like a bullet or numbered step.
      4. Keep up to ``top_k`` candidates scoring at least ``min_similarity``.
      5. Add each kept chunk's immediate neighbours for continuity and return
         everything in document order.

    Args:
        query: User question.
        index: Search index exposing ``search(vectors, k)`` and, optionally,
            a ``d`` attribute holding its vector dimension.
        chunks: Text chunks; index ids are used as positions in this list.
        top_k: Maximum number of re-ranked hits kept before neighbour fill.
        min_similarity: Minimum (boosted) cosine similarity for a hit.
        candidate_multiplier: Over-fetch factor for the index search.
        embeddings: Optional chunk embeddings (list or NumPy array) used to
            rebuild the index when its dimension does not match the query.

    Returns:
        List of chunk strings in document order; ``[]`` when nothing
        qualifies or when any error occurs (errors are printed, not raised).
    """
    if not index or not chunks:
        print("⚠️ No FAISS index or chunks provided — returning empty result.")
        return []

    try:
        q_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]

        if hasattr(index, "d") and q_emb.shape[0] != index.d:
            print(f"⚠️ FAISS dimension mismatch: index={index.d}, query={q_emb.shape[0]}")
            # `embeddings` is usually a NumPy array, whose truth value is
            # ambiguous (`if embeddings:` raises ValueError), so test it
            # explicitly.
            if embeddings is None or len(embeddings) == 0:
                return []
            print("🔄 Rebuilding FAISS index...")
            index = build_faiss_index(embeddings)
            if hasattr(index, "d") and q_emb.shape[0] != index.d:
                # The stored embeddings have the wrong dimension too, so
                # searching would only fail further down.
                print(f"⚠️ Rebuilt index still has dimension {index.d}; cannot search.")
                return []

        num_candidates = max(top_k * candidate_multiplier, top_k + 2)
        _, indices = index.search(np.array([q_emb]).astype("float32"), num_candidates)

        # Negative ids are skipped, as are ids beyond `chunks` (an index that
        # is out of sync with the chunk list). Order-preserving de-duplication.
        candidate_indices = list(dict.fromkeys(
            int(i) for i in indices[0] if 0 <= i < len(chunks)
        ))
        if not candidate_indices:
            print("⚠️ Index search returned no usable candidates.")
            return []

        doc_embs = _query_model.encode(
            [f"passage: {chunks[i]}" for i in candidate_indices],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        sims = cosine_similarity([q_emb], doc_embs)[0]

        # Chunks that open with "-", "•" or digits followed by "." or
        # whitespace get a small boost.
        bullet_start = re.compile(r"^[-•\d]+[\.\s]")
        boosted_sims = []
        for idx, sim in zip(candidate_indices, sims):
            if bullet_start.match(chunks[idx].strip()):
                sim += 0.05
            boosted_sims.append((idx, sim))

        ranked = sorted(boosted_sims, key=lambda pair: pair[1], reverse=True)
        filtered = [idx for idx, sim in ranked if sim >= min_similarity][:top_k]

        # Neighbour fill: include the chunk before and after every hit.
        neighbors = set()
        for idx in filtered:
            for n in (idx - 1, idx + 1):
                if 0 <= n < len(chunks):
                    neighbors.add(n)

        filtered = sorted(set(filtered) | neighbors)
        final_chunks = [chunks[i] for i in filtered]
        print(f"✅ Retrieved {len(final_chunks)} chunks (bullet-aware + continuity).")
        return final_chunks

    except Exception as e:
        # Retrieval is best-effort: report the problem and return nothing.
        print(f"⚠️ Retrieval error: {repr(e)}")
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_answer(query: str, retrieved_chunks: list, reasoning_mode: bool = False):
    """Ask the chat model to answer ``query`` from the retrieved chunks.

    Args:
        query: User question.
        retrieved_chunks: Chunk texts to use as context.
        reasoning_mode: ``True`` selects ``REASONING_PROMPT``, ``False``
            selects ``STRICT_PROMPT``.

    Returns:
        The model's answer with surrounding whitespace stripped, or a fixed
        fallback message when there are no chunks, the model is not
        initialised, or generation fails.
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."
    if chat_llm is None:
        return "⚠️ GPT-4o not initialized. Check credentials or rebuild the Space."

    # Number the chunks from 1 so the model can tell them apart.
    numbered = []
    for position, chunk in enumerate(retrieved_chunks, start=1):
        numbered.append(f"[Chunk {position}] {chunk.strip()}")
    context = "\n".join(numbered)

    if reasoning_mode:
        template = REASONING_PROMPT
    else:
        template = STRICT_PROMPT
    prompt = template.format(context=context, query=query)

    system_text = (
        "You are an expert enterprise documentation assistant. "
        "When reasoning_mode is off, stay strictly factual and concise. "
        "When reasoning_mode is on, combine insights across chunks logically "
        "and explain briefly. "
        "If the answer is not in the document, reply exactly: "
        "'I don't know based on the provided document.'"
    )
    messages = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": prompt},
    ]

    try:
        reply = chat_llm.invoke(messages)
        answer = reply.content.strip()
    except Exception as err:
        print(f"⚠️ GPT-4o generation failed: {err}")
        return "⚠️ Error: Could not generate an answer."
    return answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def genai_generate(prompt: str) -> str:
    """Send a free-form prompt to the chat model and return its reply.

    Args:
        prompt: Text sent as the user message (whitespace-stripped).

    Returns:
        The stripped model reply, or a fixed fallback string if the call fails.

    Raises:
        RuntimeError: If the chat model was never initialised.
    """
    if chat_llm is None:
        raise RuntimeError("⚠️ GPT-4o not initialized. Check credentials or rebuild the Space.")

    system_message = {"role": "system", "content": "You are a concise, intelligent text generator."}
    user_message = {"role": "user", "content": prompt.strip()}

    try:
        reply = chat_llm.invoke([system_message, user_message])
        text = reply.content.strip()
    except Exception as err:
        print(f"⚠️ genai_generate() failed: {err}")
        return "⚠️ Unable to generate response."
    return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: embed a few sample chunks, build an index,
    # retrieve for one question and print the generated answer.
    from vectorstore import build_faiss_index

    sample_chunks = [
        "- Step 1: Enable order confirmation capability.",
        "- Step 2: Configure supplier email.",
        "Setup instructions and configuration details.",
        "Prerequisites for automation are described here.",
    ]
    question = "What are the prerequisites for commerce automation?"

    sample_index = build_faiss_index(embed_chunks(sample_chunks))
    hits = retrieve_chunks(question, sample_index, sample_chunks)

    print("🔍 Retrieved:", hits)
    answer = generate_answer(question, hits, reasoning_mode=False)
    print("💬 Answer:", answer)
|
|
|