Spaces:
Sleeping
Sleeping
import logging
import os
from functools import lru_cache

import torch
from aiogram import Bot, Dispatcher
from aiogram.filters.command import Command
from aiogram.types import Message, ReplyKeyboardMarkup, KeyboardButton, ReplyKeyboardRemove
from transformers import AutoTokenizer

from model import BERTClassifier
from preprocess_text import TextPreprocessorBERT
# Inference runs on CPU; no GPU is assumed on the bot host.
device = 'cpu'

# Object initialization.
# SECURITY: a live bot token was hard-coded here and committed to source.
# It must be considered leaked — revoke/rotate it with @BotFather.  Read the
# token from the environment; the old literal remains only as a
# backward-compatible fallback until rotation is done.
TOKEN = os.environ.get('BOT_TOKEN', '6864353709:AAHM-J59cETYpxWzJFdHpm9QyV7rE2FL_KU')
bot = Bot(token=TOKEN)
dp = Dispatcher()
# File logging at INFO level captures the per-user events emitted by handlers.
logging.basicConfig(filename="mylog.log", level=logging.INFO)
# Single-button reply keyboard offering the /start command; the keyboard is
# resized to fit its one row instead of taking the default full height.
_start_button = KeyboardButton(text="/start")
start_keyboard = ReplyKeyboardMarkup(keyboard=[[_start_button]], resize_keyboard=True)
def load_model(weights_path='bot/model_weights_new.pth'):
    """Build the BERT classifier and load its trained weights for inference.

    Args:
        weights_path: Path to the saved ``state_dict``.  The default keeps
            the previously hard-coded location, so existing callers are
            unaffected.

    Returns:
        A ``BERTClassifier`` moved to ``device`` and switched to eval mode.
    """
    model = BERTClassifier()
    # map_location lets CUDA-saved weights load on the CPU-only host.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disable dropout / batch-norm updates for inference
    return model
def load_tokenizer(model_name='cointegrated/rubert-tiny-toxicity'):
    """Load the Hugging Face tokenizer paired with the toxicity model.

    Args:
        model_name: Hub id of the pretrained tokenizer.  The default is the
            previously hard-coded checkpoint, preserving existing behavior.
    """
    return AutoTokenizer.from_pretrained(model_name)
# Module-level singletons: loaded once at import time so every incoming
# message reuses the same model and tokenizer instead of reloading them.
model = load_model()
tokenizer = load_tokenizer()
# Handling of the /start command.
@dp.message(Command('start'))
async def proccess_command_start(message: Message):
    """Greet the user by name and log that they started the bot.

    Fix: the handler was defined but never registered with the dispatcher
    (``Command`` was imported yet unused), so the bot did not react to
    /start.  The decorator restores the registration.
    """
    user_name = message.from_user.full_name
    user_id = message.from_user.id
    text = f'Привет, {user_name}! Я помогу тебе оценить токсичность сообщений 😃'
    logging.info(f'{user_name} {user_id} запустил бота')
    # Remove any reply keyboard left over from a previous session.
    await bot.send_message(chat_id=user_id, text=text, reply_markup=ReplyKeyboardRemove())
# Adding the "Start" button at startup.
async def send_welcome(message: Message):
    # Prompts the user to press /start and shows the one-button start keyboard.
    # NOTE(review): this handler is never registered with the dispatcher
    # anywhere in this file — confirm whether registration happens elsewhere
    # or was lost; as written it is dead code.
    user_id = message.from_user.id
    await bot.send_message(chat_id=user_id, text="Нажмите кнопку /start для начала работы", reply_markup=start_keyboard)
@dp.message()
async def predict_sentence(message: Message):
    """Classify an incoming message's toxicity and reply with the verdict.

    Fixes:
    * The handler was never registered; ``@dp.message()`` registers it as
      the catch-all text handler (declared after the /start handler so the
      command filter wins first).
    * The non-toxic branch previously reported P(toxic) as its confidence;
      it now reports ``1 - P(toxic)``, matching the commented-out original.
    """
    user_name = message.from_user.full_name
    user_id = message.from_user.id
    text = message.text
    # Message preprocessing (project-specific cleaning before tokenization).
    preprocessor = TextPreprocessorBERT()
    preprocessed_text = preprocessor.transform(text)
    tokens = tokenizer.encode_plus(
        preprocessed_text,
        add_special_tokens=True,
        truncation=True,
        max_length=100,
        padding='max_length',
        return_tensors='pt'
    )
    # Pull input_ids and attention_mask out of the encoded batch.
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)
    # Prediction — no gradients needed for inference.
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
    # Interpret the result: sigmoid maps the single logit to P(toxic).
    prediction = torch.sigmoid(output).item()
    if prediction > 0.5:
        predicted_class = "ТОКСИК!!!"
        response_text = f'{predicted_class} c вероятностью {round(prediction, 3)}'
        sticker_id = 'CAACAgIAAxkBAAMrZll5jPH6HJ3j7kSLDEQU8NKDjR0AAhQAA5KfHhEGBsTRjH5zHDUE'
    else:
        predicted_class = 'Не токсик)'
        # Confidence of the *non-toxic* verdict is the complement of P(toxic).
        response_text = f'{predicted_class} c вероятностью {round(1 - prediction, 3)}'
        sticker_id = 'CAACAgIAAxkBAAMtZll5udV6ScWrGUMhkJIFmvazQicAAlgAA5KfHhFUuZt-mMSZyTUE'
    # Send the answer back to the user.
    logging.info(f'{user_name} {user_id}: {text}')
    await bot.send_message(chat_id=user_id, text=response_text)
    await bot.send_sticker(chat_id=user_id, sticker=sticker_id)
if __name__ == '__main__':
    # Start long polling; this call blocks until the process is stopped.
    dp.run_polling(bot)