import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
)
|
|
class MultilingualChatbot:
    """Chatbot that answers with a per-language causal language model.

    One model and its tokenizer are loaded for every language code listed in
    ``MODEL_NAMES``.  Requests for any other language code are served by the
    ``DEFAULT_LANG`` model.
    """

    # Checkpoint loaded for each supported language code.  A single table keeps
    # every model paired with the tokenizer of the same checkpoint.
    MODEL_NAMES = {
        'en': "microsoft/DialoGPT-medium",
        'es': "DeepESP/gpt2-spanish",
        'fr': "asi/gpt-fr-cased-small",
    }

    # Language whose model answers requests for unsupported language codes.
    DEFAULT_LANG = 'en'

    def __init__(self):
        """Load the model and tokenizer for every entry in ``MODEL_NAMES``."""
        self.models = {
            lang: GPT2LMHeadModel.from_pretrained(name)
            for lang, name in self.MODEL_NAMES.items()
        }
        self.tokenizers = {
            lang: GPT2Tokenizer.from_pretrained(name)
            for lang, name in self.MODEL_NAMES.items()
        }
        for tokenizer in self.tokenizers.values():
            # Reuse the end-of-sequence token as the padding token.
            tokenizer.pad_token = tokenizer.eos_token

    def generate_response(self, prompt, src_lang):
        """Generate a reply to ``prompt`` using the model for ``src_lang``.

        Args:
            prompt: User message.  The tokenizer's EOS token is appended to it
                before encoding.
            src_lang: Language code selecting the model and tokenizer.  Codes
                missing from ``self.models`` / ``self.tokenizers`` fall back
                to ``DEFAULT_LANG``.

        Returns:
            The decoded text of the tokens generated after the prompt, with
            special tokens skipped.
        """
        model = self.models.get(src_lang, self.models[self.DEFAULT_LANG])
        tokenizer = self.tokenizers.get(
            src_lang, self.tokenizers[self.DEFAULT_LANG]
        )

        input_ids = tokenizer.encode(
            prompt + tokenizer.eos_token, return_tensors='pt'
        )
        input_ids = input_ids.to(model.device)

        # The pad token is the EOS token and the prompt itself ends with EOS,
        # so a mask derived from pad positions would be ambiguous.  The prompt
        # is a single unpadded sequence, hence every position is attended to.
        attention_mask = torch.ones_like(input_ids)
        prompt_length = input_ids.shape[-1]

        # NOTE(review): max_length presumably counts the prompt tokens as
        # well, so very long prompts leave little or no room for a reply --
        # confirm against the installed transformers version.
        output_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=1000,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            num_return_sequences=1,
            repetition_penalty=1.2,
        )

        # Only one sequence is requested; strip the echoed prompt from it.
        reply_ids = output_ids[0, prompt_length:]
        return tokenizer.decode(reply_ids, skip_special_tokens=True)
|
|
def initialize_chatbot():
    """Create and return a new ``MultilingualChatbot`` instance."""
    return MultilingualChatbot()


def get_chatbot_response(chatbot, prompt, src_lang):
    """Return ``chatbot``'s reply to ``prompt`` for language ``src_lang``.

    Args:
        chatbot: Object exposing ``generate_response(prompt, src_lang)``.
        prompt: User message forwarded unchanged to the chatbot.
        src_lang: Language code forwarded unchanged to the chatbot.

    Returns:
        Whatever ``chatbot.generate_response`` returns.
    """
    return chatbot.generate_response(prompt, src_lang)