# Rupeia_Customer / src/generation.py
# (Hub artifact header removed: "ayush2917's picture / Update src/generation.py / 3802f97 verified")
from transformers import AutoModelForCausalLM, AutoTokenizer
class ResponseGenerator:
    """Generates chatbot replies for the Rupeia support assistant.

    Wraps a causal-LM (default: distilgpt2) behind a fixed prompt template
    that injects retrieved context and recent conversation history.
    """

    def __init__(self, model_name="distilgpt2"):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face model id of a causal LM.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
        # GPT-2-family tokenizers ship without a pad token; reuse EOS as pad
        # (standard workaround) so padding/truncation in __call__ works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.prompt_template = """
You are a customer support chatbot for Rupeia, a financial platform. Provide accurate, concise answers about Investments, Goals, Benefits, Gadgets, and News & Blogs. Use the context and history to respond naturally. If unsure, say: "I’m not sure about that. Could you clarify or ask about Rupeia features?"
Context: {context}
History: {history}
User: {user_input}
Answer: """

    def generate_response(self, user_input, context, history):
        """Produce a reply for `user_input` given retrieval `context` and chat `history`.

        Args:
            user_input: The user's current message.
            context: Retrieved knowledge-base text (may be empty).
            history: List of (user_msg, bot_msg) pairs; only the last 3 are used.

        Returns:
            The model's generated answer, or a fallback string if generation
            produced nothing.
        """
        # Keep the prompt bounded: only the 3 most recent exchanges.
        history_str = "\n".join([f"User: {h[0]}\nBot: {h[1]}" for h in history[-3:]])
        prompt = self.prompt_template.format(context=context, history=history_str, user_input=user_input)
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
        outputs = self.model.generate(
            inputs["input_ids"],
            # Pass the attention mask explicitly so pad positions are ignored.
            attention_mask=inputs["attention_mask"],
            # BUGFIX: was max_length=200, which caps prompt + completion
            # combined — any prompt longer than 200 tokens (prompts here may
            # be up to 512) left no room to generate. max_new_tokens bounds
            # only the generated continuation.
            max_new_tokens=150,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Slice off the prompt tokens so only the newly generated answer is decoded.
        response = self.tokenizer.decode(outputs[:, inputs["input_ids"].shape[-1]:][0], skip_special_tokens=True).strip()
        # Fall back to the raw context, then to a canned apology, if the model
        # generated nothing usable.
        return response if response else context or "I’m not sure about that. Could you clarify or ask about Rupeia features?"