ayush2917 commited on
Commit
3802f97
·
verified ·
1 Parent(s): 532e893

Update src/generation.py

Browse files
Files changed (1) hide show
  1. src/generation.py +14 -4
src/generation.py CHANGED
@@ -2,7 +2,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
2
 
3
  class ResponseGenerator:
4
  def __init__(self, model_name="distilgpt2"):
5
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
6
  self.model = AutoModelForCausalLM.from_pretrained(model_name)
7
  self.prompt_template = """
8
  You are a customer support chatbot for Rupeia, a financial platform. Provide accurate, concise answers about Investments, Goals, Benefits, Gadgets, and News & Blogs. Use the context and history to respond naturally. If unsure, say: "I’m not sure about that. Could you clarify or ask about Rupeia features?"
@@ -15,7 +19,13 @@ Answer: """
15
  def generate_response(self, user_input, context, history):
16
  history_str = "\n".join([f"User: {h[0]}\nBot: {h[1]}" for h in history[-3:]])
17
  prompt = self.prompt_template.format(context=context, history=history_str, user_input=user_input)
18
- inputs = self.tokenizer.encode(prompt, return_tensors="pt")
19
- outputs = self.model.generate(inputs, max_length=200, pad_token_id=self.tokenizer.eos_token_id)
20
- response = self.tokenizer.decode(outputs[:, inputs.shape[-1]:][0], skip_special_tokens=True).strip()
 
 
 
 
 
 
21
  return response if response else context or "I’m not sure about that. Could you clarify or ask about Rupeia features?"
 
2
 
3
  class ResponseGenerator:
4
  def __init__(self, model_name="distilgpt2"):
5
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)
6
+ # Set a distinct pad token
7
+ if self.tokenizer.pad_token is None:
8
+ self.tokenizer.pad_token = self.tokenizer.eos_token
9
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
10
  self.model = AutoModelForCausalLM.from_pretrained(model_name)
11
  self.prompt_template = """
12
  You are a customer support chatbot for Rupeia, a financial platform. Provide accurate, concise answers about Investments, Goals, Benefits, Gadgets, and News & Blogs. Use the context and history to respond naturally. If unsure, say: "I’m not sure about that. Could you clarify or ask about Rupeia features?"
 
19
  def generate_response(self, user_input, context, history):
20
  history_str = "\n".join([f"User: {h[0]}\nBot: {h[1]}" for h in history[-3:]])
21
  prompt = self.prompt_template.format(context=context, history=history_str, user_input=user_input)
22
+ inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
23
+ # Ensure attention mask is passed
24
+ outputs = self.model.generate(
25
+ inputs["input_ids"],
26
+ attention_mask=inputs["attention_mask"],
27
+ max_length=200,
28
+ pad_token_id=self.tokenizer.pad_token_id
29
+ )
30
+ response = self.tokenizer.decode(outputs[:, inputs["input_ids"].shape[-1]:][0], skip_special_tokens=True).strip()
31
  return response if response else context or "I’m not sure about that. Could you clarify or ask about Rupeia features?"