Spaces:
Runtime error
| import logging | |
| from telegram import Update | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| from telegram.ext import ( | |
| CallbackContext, | |
| ) | |
NAME = "Conversation"
DESCRIPTION = """
Useful for building up conversation.
Input: A normal chat text
Output: A text
"""
# Conversation state identifier. With a single state the `A, B = range(2)`
# idiom degenerates to one value; the previous `GET_CON = range(1)` (without
# tuple-unpacking) bound a `range` object instead of the int state the
# handlers' `-> int` annotations expect.
GET_CON = 0
class Conversation:
    """Chat with users through Microsoft's DialoGPT-medium model.

    The tokenizer and model are loaded once, when this class body executes
    (i.e. when the module is imported), and are shared by every instance.
    Each instance keeps a separate token history per chat so that different
    users' conversations do not bleed into each other.
    """

    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

    # Maximum number of tokens (history + new message) fed to the model.
    # `generate` below caps prompt + reply at 1000 tokens, so an unbounded
    # history would eventually leave no room for a reply; the oldest tokens
    # are dropped first.
    max_prompt_tokens = 512

    def __init__(self):
        # Maps a chat key (Telegram chat id, or None for callers that do not
        # pass one) to the token tensor returned by the last `generate` call.
        self._chat_histories = {}

    async def talk(self, message: str, chat_id=None) -> str:
        """Generate a reply to *message*, continuing the history for *chat_id*.

        Args:
            message: The user's chat text.
            chat_id: Key selecting which conversation history to continue and
                update. Defaults to None (a single shared history), which keeps
                the original one-argument call working.

        Returns:
            The model's reply, decoded with special tokens stripped. May be an
            empty string if the model produces no new tokens.
        """
        logging.info("%s", message)
        # Use the received text itself; the previous `input(...)` call blocked
        # on stdin instead of using the Telegram message.
        new_user_input_ids = self.tokenizer.encode(
            message + self.tokenizer.eos_token, return_tensors="pt"
        )

        history = self._chat_histories.get(chat_id)
        if history is None:
            bot_input_ids = new_user_input_ids
        else:
            bot_input_ids = torch.cat([history, new_user_input_ids], dim=-1)

        # Keep only the most recent tokens so the prompt fits within the
        # generation budget.
        if bot_input_ids.shape[-1] > self.max_prompt_tokens:
            bot_input_ids = bot_input_ids[:, -self.max_prompt_tokens:]

        chat_history_ids = self.model.generate(
            bot_input_ids,
            # Every prompt token is real text (no padding); passing the mask
            # explicitly avoids transformers guessing it from pad == eos.
            attention_mask=torch.ones_like(bot_input_ids),
            max_length=1000,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        self._chat_histories[chat_id] = chat_history_ids

        # The generated sequence echoes the prompt; decode only the new tokens.
        reply_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
        return self.tokenizer.decode(reply_ids, skip_special_tokens=True)

    async def process_conversation(self, update: Update, context: CallbackContext) -> None:
        """Reply to an incoming Telegram text message with a generated answer.

        Updates that carry no text message (e.g. edited messages or media
        without text) are ignored. Always returns None.

        NOTE(review): `talk` runs the model synchronously and so blocks the
        event loop while generating; consider offloading if latency matters.
        """
        incoming = update.message
        if incoming is None or not incoming.text:
            return None

        text = await self.talk(incoming.text, chat_id=incoming.chat_id)
        if not text:
            # Telegram rejects empty message bodies, so skip rather than fail.
            logging.warning("Model produced an empty reply; nothing sent.")
            return None
        await incoming.reply_text(f'{text}')
        return None