"""
qa.py — Retrieval + Generation Layer
-------------------------------------
Handles:
• Query embedding (SentenceTransformer / E5-compatible)
• Chunk retrieval (FAISS)
• Answer generation (Flan-T5)
Optimized for Hugging Face Spaces & Streamlit.
"""
import os

# ==========================================================
# 1️⃣ Hugging Face Cache Setup (Safe for Spaces)
# ==========================================================
# NOTE: these cache variables must be set *before* transformers /
# sentence_transformers are imported, or the libraries may ignore them.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ.update({
    "HF_HOME": CACHE_DIR,
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_DATASETS_CACHE": CACHE_DIR,
    "HF_MODULES_CACHE": CACHE_DIR,
})

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from vectorstore import search_faiss

print("✅ qa.py loaded from:", __file__)
# ==========================================================
# 2️⃣ Query Embedding Model
# ==========================================================
try:
    _query_model = SentenceTransformer(
        "intfloat/e5-small-v2",
        cache_folder=CACHE_DIR,
    )
    print("✅ Loaded query model: intfloat/e5-small-v2")
except Exception as e:
    print(f"⚠️ Query model load failed ({e}), falling back to MiniLM.")
    _query_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR,
    )
    print("✅ Loaded fallback model: all-MiniLM-L6-v2")
# ==========================================================
# 3️⃣ LLM for Answer Generation (FLAN-T5)
# ==========================================================
MODEL_NAME = "google/flan-t5-base"  # switch to "google/flan-t5-large" if memory allows
print(f"✅ Loading LLM: {MODEL_NAME}")
_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
_answer_model = pipeline(
    "text2text-generation",
    model=_model,
    tokenizer=_tokenizer,
    device=-1,  # CPU-safe (Hugging Face Spaces)
)
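# If a GPU is available (e.g. on an upgraded Space), the pipeline can be
# placed on it instead; an optional sketch, relying on torch (already a
# transformers dependency):
#
#     import torch
#     _answer_model = pipeline(
#         "text2text-generation",
#         model=_model,
#         tokenizer=_tokenizer,
#         device=0 if torch.cuda.is_available() else -1,
#     )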
# ==========================================================
# 4️⃣ Prompt Template
# ==========================================================
PROMPT_TEMPLATE = """
You are an expert enterprise knowledge assistant.
Use ONLY the CONTEXT below to answer the QUESTION clearly, factually, and completely.
If the context doesn’t contain the answer, reply exactly:
"I don't know based on the provided document."
---
Context:
{context}
---
Question:
{query}
---
Answer:
"""
# ==========================================================
# 5️⃣ Chunk Retrieval Function
# ==========================================================
def retrieve_chunks(query: str, index, chunks: list, top_k: int = 3):
    """
    Encodes the user query and retrieves the top-k relevant chunks via FAISS.
    Uses the 'query:' prefix (E5 training convention) for semantic alignment.
    """
    if not index or not chunks:
        return []
    try:
        query_emb = _query_model.encode(
            [f"query: {query.strip()}"],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )[0]
        return search_faiss(query_emb, index, chunks, top_k)
    except Exception as e:
        print(f"⚠️ Retrieval error: {e}")
        return []
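# Retrieval quality depends on the ingestion side embedding document
# chunks with the matching "passage:" prefix (as the local test below
# does); a minimal sketch of that convention (`doc_chunks` is a
# hypothetical variable from the ingestion layer):
#
#     doc_embs = _query_model.encode(
#         [f"passage: {c}" for c in doc_chunks],
#         convert_to_numpy=True,
#         normalize_embeddings=True,
#     )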
# ==========================================================
# 6️⃣ Answer Generation Function
# ==========================================================
def generate_answer(query: str, retrieved_chunks: list):
    """
    Generates an answer using FLAN-T5 with the retrieved chunks as context.
    Uses sampling for more natural phrasing and falls back gracefully on errors.
    """
    if not retrieved_chunks:
        return "Sorry, I couldn’t find relevant information in the document."

    # Merge retrieved chunks into one coherent context
    context = "\n\n".join(
        f"[Chunk {i+1}]: {chunk.strip()}"
        for i, chunk in enumerate(retrieved_chunks)
    )
    prompt = PROMPT_TEMPLATE.format(context=context, query=query)
    try:
        result = _answer_model(
            prompt,
            max_new_tokens=400,       # allow more elaborate responses
            do_sample=True,           # enable natural variability
            temperature=0.7,          # creativity balance
            top_p=0.9,                # nucleus sampling for relevance
            repetition_penalty=1.15,  # discourage repetition
        )
        answer = result[0]["generated_text"].strip()

        # 🧩 Optionally pad overly short answers (disabled for now):
        # if len(answer.split()) < 5:
        #     answer = (
        #         "The document briefly mentions this. Based on the context, "
        #         "here's what it implies: " + answer
        #     )
        return answer
    except Exception as e:
        print(f"⚠️ Generation failed: {e}")
        return "⚠️ Error: Could not generate an answer at the moment."
# ==========================================================
# 7️⃣ Optional Local Test (runs only in dev mode)
# ==========================================================
if __name__ == "__main__":
dummy_chunks = [
"SAP Ariba is a cloud-based procurement solution.",
"It helps companies manage suppliers and sourcing processes efficiently.",
"Integration with SAP ERP allows for seamless data synchronization."
]
from vectorstore import build_faiss_index
index = build_faiss_index([
_query_model.encode(
[f"passage: {chunk}"],
convert_to_numpy=True,
normalize_embeddings=True
)[0]
for chunk in dummy_chunks
])
query = "What is SAP Ariba used for?"
retrieved = retrieve_chunks(query, index, dummy_chunks)
print("🔍 Retrieved:", retrieved)
print("💬 Answer:", generate_answer(query, retrieved))