Spaces:
Sleeping
Sleeping
| from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
| import torch | |
| from datasets import load_dataset | |
| import pandas as pd | |
| import re | |
class ChatBot:
    """GPT-2 style conversational bot: builds a prompt from chat history
    and generates the next AI reply with `model.generate()`."""

    def __init__(self, dir, tokenizer, model, device):
        """Store components and move the model to the target device.

        Args:
            dir: Working directory path (stored as `self.directory`; not
                otherwise used in this class). NOTE(review): shadows the
                `dir` builtin — name kept for caller compatibility.
            tokenizer: GPT-2 style tokenizer exposing `eos_token` and
                `eos_token_id`.
            model: Causal LM with `.to()` and `.generate()`.
            device: torch device (e.g. "cuda" or "cpu").
        """
        self.directory = dir
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.model.to(self.device)

    def generate_response(self, history):
        """Generate the next AI reply for `history`.

        `history` is expected to expose parallel lists `history.user` and
        `history.ai`, with the user list one message longer than the AI list
        (the trailing user message awaits its reply).
        NOTE(review): if the two lists are ever equal length, the last user
        message is included twice — confirm the invariant against the caller.

        Returns:
            The first sentence (text up to the first '.') of the newly
            generated reply.
        """
        # BUGFIX: was `self.tokenizer.eos_token_id` — an int — which
        # interpolated e.g. "50256" into the prompt text instead of the
        # actual end-of-text token string.
        eos = self.tokenizer.eos_token
        combined_prompt = ""

        # Keep the prompt short: at most the 7 most recent user turns and
        # the 6 AI replies that answer the first 6 of them.
        if len(history.user) > 7:
            history.user = history.user[-7:]
            history.ai = history.ai[-6:]

        # Interleave past user/AI turns, separated by the EOS token.
        for user_message, ai_message in zip(history.user, history.ai):
            combined_prompt += f"<user> {user_message}{eos}<AI> {ai_message}{eos}"

        # Append the pending user message and an open <AI> tag so the model
        # continues with the assistant's reply.
        if history.user:
            combined_prompt += f"<user> {history.user[-1]}{eos}<AI>"

        inputs = self.tokenizer.encode(combined_prompt, return_tensors="pt").to(self.device)
        attention_mask = torch.ones(inputs.shape, device=self.device)
        outputs = self.model.generate(
            inputs,
            max_new_tokens=20,  # short, single-sentence replies
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
            temperature=0.7,  # NOTE: sampling args (temperature/top_k/top_p)
            top_k=50,         # are ignored by beam search without
            top_p=0.95,       # do_sample=True; kept from the original call.
            pad_token_id=self.tokenizer.eos_token_id,
            attention_mask=attention_mask,
            repetition_penalty=1.2,
        )
        # BUGFIX: decode only the newly generated tokens. The old code
        # decoded the full sequence with skip_special_tokens=True and then
        # did `response.replace(combined_prompt, "")` — which never matched
        # (special tokens were stripped from the decode), so the prompt echo
        # leaked into the returned response.
        generated_tokens = outputs[0][inputs.shape[-1]:]
        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return response.split(".")[0]