| """ |
| src/models/generative_model.py |
| Fix #1 — Generative fallback using google/flan-t5-base (open-source, ~300 MB). |
| Fires as a fallback when BERT confidence < 40%; can also be called to enrich BERT's extracted spans. |
| """ |
|
|
| import logging |
| from transformers import T5ForConditionalGeneration, T5Tokenizer |
| import torch |
|
|
# Module-level logger; used to report when the model is being loaded.
logger = logging.getLogger(__name__)

# Hugging Face Hub identifier passed to ``from_pretrained`` for both the
# tokenizer and the model.
MODEL_ID = "google/flan-t5-base"

# Upper bound on prompt length, in tokens: passed to the tokenizer as
# ``max_length`` with truncation enabled.
MAX_INPUT_TOKENS = 768
# Upper bound on generated length, in tokens: passed to ``generate`` as
# ``max_new_tokens``.
MAX_OUTPUT_TOKENS = 220
|
|
|
|
class GenerativeModel:
    """Lazily loaded Flan-T5 wrapper that generates product answers.

    Used to synthesise an answer when BERT's extractive confidence is too
    low, or to enrich BERT's exact-span answers with explanation. The prompts
    are engineered for descriptive, salesman-like responses that stay within
    the supplied product information.

    The tokenizer and model are loaded on the first call to ``answer``, so
    constructing an instance does not load any weights.
    """

    # Character cap applied to the context before it is placed in a prompt.
    _CONTEXT_CHAR_LIMIT = 2200

    # Tokens held back when budgeting the context: tokenising the context on
    # its own can differ by a few tokens from tokenising it inside the prompt.
    _TOKEN_SAFETY_MARGIN = 8

    def __init__(self):
        self._tokenizer = None
        self._model = None
        self._device = "cuda" if torch.cuda.is_available() else "cpu"

    def _load(self):
        """Load the tokenizer and model on first use; later calls are no-ops."""
        if self._model is not None:
            return

        logger.info("Loading Flan-T5 generative fallback (%s) on %s…", MODEL_ID, self._device)
        tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
        model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
        model.to(self._device)
        model.eval()

        # Publish only after every step succeeded, so a failure part-way
        # (e.g. moving to the device) is retried on the next call instead of
        # leaving a half-initialised model behind.
        self._tokenizer = tokenizer
        self._model = model

    @staticmethod
    def _build_prompt(question: str, context: str, mode: str = "answer") -> str:
        """Build the instruction prompt sent to the model.

        Args:
            question: The customer's question, inserted verbatim.
            context: Product information; only its first 2200 characters are
                included in the prompt.
            mode: ``"enrich"`` selects the prompt that asks for a detailed
                2-3 sentence answer; any other value selects the standalone
                ``"answer"`` prompt (2-4 sentences, told to say so when the
                information is missing).

        Returns:
            The complete prompt text, ending with the answer cue.
        """
        ctx_preview = context[: GenerativeModel._CONTEXT_CHAR_LIMIT]

        if mode == "enrich":
            return (
                "You are a helpful product expert.\n"
                "Read the product information below and give a clear, detailed answer "
                "to the customer's question in 2-3 sentences. Be specific and quote "
                "facts from the product info. Do not invent details.\n\n"
                f"PRODUCT INFORMATION:\n{ctx_preview}\n\n"
                f"CUSTOMER QUESTION: {question}\n\n"
                "DETAILED ANSWER:"
            )

        return (
            "You are a knowledgeable product assistant. Answer the customer's question "
            "using ONLY the product information provided. Give a complete, descriptive "
            "answer (2-4 sentences). If the answer isn't in the information, say so "
            "clearly instead of guessing.\n\n"
            f"PRODUCT INFORMATION:\n{ctx_preview}\n\n"
            f"QUESTION: {question}\n\n"
            "ANSWER:"
        )

    def _fit_context(self, question: str, context: str, mode: str) -> str:
        """Trim the context so the full prompt fits in ``MAX_INPUT_TOKENS``.

        The tokenizer truncates from the end of the prompt, which is where
        the question and the answer cue sit. Shortening the context up front
        keeps them intact. Requires the tokenizer to be loaded.

        Returns:
            The character-capped context unchanged when it fits the token
            budget; otherwise the text decoded from its leading tokens. If
            the prompt scaffold alone already fills the budget, the
            character-capped context is returned as-is and the tokenizer's
            own truncation remains the backstop.
        """
        clipped = context[: self._CONTEXT_CHAR_LIMIT]

        # Token cost of everything except the context (instructions,
        # question, answer cue and the end-of-sequence token).
        scaffold = self._build_prompt(question, "", mode)
        overhead = len(
            self._tokenizer(
                scaffold,
                max_length=MAX_INPUT_TOKENS,
                truncation=True,
            )["input_ids"]
        )
        budget = MAX_INPUT_TOKENS - overhead - self._TOKEN_SAFETY_MARGIN
        if budget <= 0:
            return clipped

        # Ask for one token more than the budget so an overflow is
        # detectable without tokenising an arbitrarily long string.
        context_ids = self._tokenizer(
            clipped,
            add_special_tokens=False,
            max_length=budget + 1,
            truncation=True,
        )["input_ids"]
        if len(context_ids) <= budget:
            return clipped

        return self._tokenizer.decode(context_ids[:budget], skip_special_tokens=True)

    def answer(self, question: str, context: str, mode: str = "answer") -> str:
        """Generate an answer to ``question`` from ``context``.

        Args:
            question: The customer's question.
            context: Product information the answer should be based on.
            mode: ``"answer"`` for a standalone answer or ``"enrich"`` for a
                more detailed one; see ``_build_prompt``.

        Returns:
            The decoded, stripped model output, or a fixed fallback sentence
            when the model produces an empty string.
        """
        self._load()

        fitted_context = self._fit_context(question, context, mode)
        prompt = self._build_prompt(question, fitted_context, mode)

        inputs = self._tokenizer(
            prompt,
            return_tensors="pt",
            max_length=MAX_INPUT_TOKENS,
            truncation=True,
        ).to(self._device)

        with torch.no_grad():
            # Deterministic beam search. No ``temperature`` is passed: it
            # only affects sampling and is ignored when ``do_sample=False``.
            output_ids = self._model.generate(
                **inputs,
                max_new_tokens=MAX_OUTPUT_TOKENS,
                num_beams=4,
                early_stopping=True,
                no_repeat_ngram_size=3,
                do_sample=False,
            )

        generated = self._tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        return generated or "Could not generate an answer for this question."
|
|