"""Topic classification for Russian news texts.

Loads a frozen rubert-tiny2 encoder with a trained 5-way linear head plus the
matching tokenizer and label encoder, and exposes `predict_topic` for use by
the (currently commented-out) Telegram bot handlers.
"""

import re

import joblib
import numpy as np
import telebot
import torch
from transformers import AutoModel, AutoTokenizer

# Paths to the fine-tuned artifacts (relative to the project root).
weights_path = 'bot/utils/ruBERTcls_weights.pth'
tokenizer_path = 'bot/utils/tokenizer.joblib'
label_encoder_path = 'bot/utils/label_encoder.joblib'
pretrained_weights = 'cointegrated/rubert-tiny2'

MAX_LEN = 924  # token budget; presumably matches training — TODO confirm
# FIX: was hard-coded to 'cuda', which raises on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class ruBERTClassifier(torch.nn.Module):
    """Frozen rubert-tiny2 encoder followed by a trainable linear head."""

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_weights)
        # The encoder stays frozen; only the linear head carries gradients.
        for param in self.bert.parameters():
            param.requires_grad = False
        # 312 is rubert-tiny2's hidden size; 5 is the number of topics.
        self.linear = torch.nn.Linear(312, 5)

    def forward(self, x, attention_mask=None):
        """Return raw logits of shape (batch, 5) for token-id batch `x`."""
        # [CLS] embedding: first token of the last hidden state.
        bert_out = self.bert(input_ids=x, attention_mask=attention_mask)[0][:, 0, :]
        out = self.linear(bert_out)
        return out


# Module-level side effects: load the trained head and the serialized
# tokenizer / label encoder once at import time.
ruBERTcls = ruBERTClassifier()
ruBERTcls.load_state_dict(torch.load(weights_path, map_location=device))
ruBERTcls.to(device)
ruBERTcls.eval()  # inference-only module: set once here, not per call

tokenizer = joblib.load(tokenizer_path)
label_encoder = joblib.load(label_encoder_path)


def preprocess_text(text):
    """Strip emoticon code points (U+1F600-U+1F64F) from `text`."""
    return re.sub(r'[\U0001F600-\U0001F64F]', '', text)


def predict_topic(text, model, tokenizer, label_encoder, max_len=MAX_LEN,
                  device=device):
    """Classify `text` and return the decoded topic label.

    Args:
        text: Raw input string (emoticons are stripped first).
        model: A `ruBERTClassifier` (or compatible) module.
        tokenizer: HuggingFace-style tokenizer with an `encode` method.
        label_encoder: sklearn-style encoder with `inverse_transform`.
        max_len: Truncation / padding length in tokens.
        device: Torch device string; defaults to the module-level device
            (CUDA when available, else CPU) instead of a hard-coded 'cuda'.

    Returns:
        The predicted topic label as produced by `label_encoder`.
    """
    text = preprocess_text(text)
    tokens = tokenizer.encode(text, add_special_tokens=True, truncation=True,
                              max_length=max_len)
    # Right-pad with 0; the attention mask is 1 exactly on non-pad positions.
    # NOTE(review): assumes the tokenizer's PAD id is 0 — verify artifact.
    tokens_padded = np.array(tokens + [0] * (max_len - len(tokens)))
    attention_mask = np.where(tokens_padded != 0, 1, 0)
    tokens_tensor = torch.tensor([tokens_padded]).to(device)
    attention_mask_tensor = torch.tensor([attention_mask]).to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(tokens_tensor, attention_mask=attention_mask_tensor)
    prediction = torch.argmax(outputs, dim=1).cpu().numpy()[0]
    topic = label_encoder.inverse_transform([prediction])[0]
    return topic


# SECURITY: a live bot token was hard-coded below; it has been redacted and
# should be revoked, then loaded from an environment variable when enabled.
# bot = telebot.TeleBot(os.environ['TELEGRAM_BOT_TOKEN'])
#
# @bot.message_handler(commands=['start'])
# def send_welcome(message):
#     bot.reply_to(message, "Hello! Send me a news text and I'll predict its topic.")
#
# @bot.message_handler(func=lambda message: True)
# def predict_message_topic(message):
#     text = message.text
#     predicted_topic = predict_topic(text, ruBERTcls, tokenizer, label_encoder)
#     bot.reply_to(message, f"The predicted topic is: {predicted_topic}")
#
# bot.polling()