"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
• Chunk retrieval (FAISS)
• Answer generation (Flan-T5)
Optimized for Hugging Face Spaces & Streamlit.
"""
import os
# ==========================================================
# 1️⃣ Hugging Face Cache Setup (Safe for Spaces)
# ==========================================================
# Set the cache variables BEFORE importing transformers / sentence_transformers,
# since both libraries read them at import time.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR
})
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from vectorstore import search_faiss
print("✅ qa.py loaded from:", __file__)
# ==========================================================
# 2️⃣ Query Embedding Model
# ==========================================================
# Use E5-small-v2 for retrieval consistency with embeddings.py
try:
    _query_model = SentenceTransformer(
        "intfloat/e5-small-v2",
        cache_folder=CACHE_DIR
    )
    print("✅ Loaded query model: intfloat/e5-small-v2")
except Exception as e:
    # NOTE: the fallback only makes sense if the FAISS index was also built
    # with MiniLM embeddings — mixing embedding models degrades retrieval.
    print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
    _query_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR
    )
    print("✅ Loaded fallback model: all-MiniLM-L6-v2")
# ==========================================================
# 3️⃣ LLM for Answer Generation (FLAN-T5)
# ==========================================================
MODEL_NAME = "google/flan-t5-base" # switch to 'large' if RAM allows
print(f"✅ Loading LLM: {MODEL_NAME}")
_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
_answer_model = pipeline(
    "text2text-generation",
    model=_model,
    tokenizer=_tokenizer,
    device=-1  # CPU-safe for Spaces; set device=0 to use the first GPU
)
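# The pipeline returns a list with one dict per input, e.g.
#   _answer_model("Summarize: ...")  ->  [{"generated_text": "..."}]
# which is why generate_answer() below reads result[0]["generated_text"].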
# ==========================================================
# 4️⃣ Prompt Template (concise and factual)
# ==========================================================
PROMPT_TEMPLATE = """
You are an expert enterprise assistant.
Using ONLY the CONTEXT below, answer the QUESTION clearly and factually.
If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."
---
Context:
{context}
---
Question:
{query}
---
Answer:
"""
# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """
    Encodes the user query and retrieves the top-k relevant chunks via FAISS.
    Uses the 'query:' prefix (E5 training style) for semantic alignment.
    """
    if not index or not chunks:
        return []
    try:
        # E5 expects a 'query:' prefix for better retrieval accuracy
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True
        )[0]
        return search_faiss(query_emb, index, chunks, top_k)
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []
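# Usage sketch (illustrative — assumes `index` and `chunks` come from the
# ingestion side, e.g. vectorstore.build_faiss_index over 'passage:' embeddings):
#
#   hits = retrieve_chunks("What is SAP Ariba used for?", index, chunks, top_k=3)
#   for rank, text in enumerate(hits, start=1):
#       print(f"{rank}. {text[:80]}")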
# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list):
    """
    Generates an answer using FLAN-T5 with the retrieved chunks as context.
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."
    # Merge the retrieved chunks into one labeled context block
    context = "\n\n".join(
        f"[Chunk {i + 1}]: {chunk}" for i, chunk in enumerate(retrieved_chunks)
    )
    # Build the structured prompt
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)
    try:
        result = _answer_model(
            prompt,
            max_new_tokens=350,      # allow longer, more complete answers
            do_sample=True,          # enable sampling for natural flow
            temperature=0.7,         # slightly higher = more expressive responses
            top_p=0.95,              # nucleus sampling for coherence
            repetition_penalty=1.2   # discourages repetitive phrasing
        )
        answer = result[0]["generated_text"].strip()
        # 🧩 If the model's output is very short, preface it gracefully
        if len(answer.split()) < 8:
            answer = (
                "The document mentions this briefly. Based on the context, here's what it suggests: "
                + answer
            )
        return answer
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."
# ==========================================================
# 7️⃣ Optional Local Test
# ==========================================================
if __name__ == "__main__":
    import numpy as np
    from vectorstore import build_faiss_index
    dummy_chunks = [
        "SAP Ariba is a cloud-based procurement solution.",
        "It helps companies manage suppliers and sourcing processes efficiently.",
        "Integration with SAP ERP allows for seamless data synchronization."
    ]
    # Embed the passages with the E5 'passage:' prefix, mirroring the index
    # side, and stack them into a 2-D float array (the shape FAISS expects).
    passage_embs = np.vstack([
        _query_model.encode([f"passage: {chunk}"], convert_to_numpy=True, normalize_embeddings=True)[0]
        for chunk in dummy_chunks
    ])
    index = build_faiss_index(passage_embs)
    query = "What is SAP Ariba used for?"
    retrieved = retrieve_chunks(query, index, dummy_chunks)
    print("🔍 Retrieved:", retrieved)
    print("💬 Answer:", generate_answer(query, retrieved))