ecom-qa / src /models /generative_model.py
rnyx's picture
Initial deploy
f48b219
Raw
History Blame Contribute Delete
3.95 kB
"""
src/models/generative_model.py
Fix #1 — Generative fallback using google/flan-t5-base (open-source, ~300 MB).
Intended to fire when BERT confidence < 40%; the same model can also be asked
to enrich BERT's extracted answers (mode="enrich").
"""
import logging
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
# Hugging Face Hub identifier of the checkpoint used for generation.
MODEL_ID = "google/flan-t5-base"

# Token budgets. The encoded prompt is truncated to MAX_INPUT_TOKENS to keep
# inference latency acceptable (longer inputs slow T5 down), and generation
# is capped at MAX_OUTPUT_TOKENS newly produced tokens.
MAX_OUTPUT_TOKENS = 220
MAX_INPUT_TOKENS = 768

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
class GenerativeModel:
    """
    Lazy wrapper around Flan-T5-base for generative question answering.

    Used to synthesise answers when BERT's extractive confidence is too low
    (mode="answer"), or to produce a more detailed, explanatory answer
    (mode="enrich"). Note that the "enrich" prompt is built from the question
    and context only; the BERT-extracted span itself is not passed in.

    The tokenizer and model weights are loaded on the first call to
    ``answer`` rather than at construction time.
    """

    # Character cap applied to the context before it is placed in the prompt
    # (roughly 550 tokens of typical English text).
    _MAX_CONTEXT_CHARS = 2200

    def __init__(self):
        # Both stay None until _load() succeeds.
        self._tokenizer = None
        self._model = None
        self._device = "cuda" if torch.cuda.is_available() else "cpu"

    def _load(self):
        """Load the tokenizer and model on first use; later calls are no-ops."""
        if self._model is not None:
            return

        logger.info("Loading Flan-T5 generative fallback (%s) on %s…", MODEL_ID, self._device)
        tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
        model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
        model.to(self._device)
        model.eval()

        # Publish only after everything succeeded, so a failed device move
        # cannot leave a half-initialised model behind for the next call.
        self._tokenizer = tokenizer
        self._model = model

    # ── Prompt engineering ────────────────────────────────────────────────────
    @staticmethod
    def _build_prompt(question: str, context: str, mode: str = "answer") -> str:
        """
        Build the instruction prompt fed to the model.

        Two modes:
            mode="answer" → standalone answer when BERT failed
            mode="enrich" → detailed 2-3 sentence answer quoting product facts

        Any other ``mode`` value falls back to the "answer" prompt. The
        context is capped at ``_MAX_CONTEXT_CHARS`` characters; a ``None``
        context is treated as empty instead of raising.
        """
        ctx_preview = (context or "")[: GenerativeModel._MAX_CONTEXT_CHARS]

        if mode == "enrich":
            return (
                "You are a helpful product expert.\n"
                "Read the product information below and give a clear, detailed answer "
                "to the customer's question in 2-3 sentences. Be specific and quote "
                "facts from the product info. Do not invent details.\n\n"
                f"PRODUCT INFORMATION:\n{ctx_preview}\n\n"
                f"CUSTOMER QUESTION: {question}\n\n"
                "DETAILED ANSWER:"
            )

        return (
            "You are a knowledgeable product assistant. Answer the customer's question "
            "using ONLY the product information provided. Give a complete, descriptive "
            "answer (2-4 sentences). If the answer isn't in the information, say so "
            "clearly instead of guessing.\n\n"
            f"PRODUCT INFORMATION:\n{ctx_preview}\n\n"
            f"QUESTION: {question}\n\n"
            "ANSWER:"
        )

    # ── Public interface ──────────────────────────────────────────────────────
    def answer(self, question: str, context: str, mode: str = "answer") -> str:
        """
        Return a generated answer string.

        Args:
            question: The customer's question.
            context: Product information to ground the answer in.
            mode: 'answer' or 'enrich' (see ``_build_prompt``).

        Returns:
            The decoded, stripped model output, or a fixed fallback sentence
            when the model produces an empty string.
        """
        self._load()

        ctx = (context or "")[: self._MAX_CONTEXT_CHARS]
        while True:
            prompt = self._build_prompt(question, ctx, mode)
            inputs = self._tokenizer(
                prompt,
                return_tensors="pt",
                max_length=MAX_INPUT_TOKENS,
                truncation=True,
            )
            # Truncation removes the END of the prompt, which is where the
            # question and the answer cue live. An encoding that fills the
            # whole budget was probably cut, so shrink the context and retry
            # rather than silently dropping the question. Once the context is
            # empty there is nothing left to trim and truncation is accepted.
            if inputs["input_ids"].shape[-1] < MAX_INPUT_TOKENS or not ctx:
                break
            ctx = ctx[: len(ctx) * 3 // 4]

        inputs = inputs.to(self._device)

        with torch.no_grad():
            # Deterministic beam search. No `temperature` is passed: it only
            # affects sampling and is ignored (with a warning in recent
            # transformers versions) when do_sample=False.
            output_ids = self._model.generate(
                **inputs,
                max_new_tokens=MAX_OUTPUT_TOKENS,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,
                do_sample=False,
            )

        text = self._tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        return text if text else "Could not generate an answer for this question."