# Telegram bot utilities: news-topic classification with a frozen
# cointegrated/rubert-tiny2 encoder plus a trained linear head.
import torch
import joblib
import telebot
import numpy as np
from transformers import AutoTokenizer, AutoModel
import re
# Locations of the trained artifacts and inference configuration.
weights_path = 'bot/utils/ruBERTcls_weights.pth'
tokenizer_path = 'bot/utils/tokenizer.joblib'
label_encoder_path = 'bot/utils/label_encoder.joblib'
pretrained_weights = 'cointegrated/rubert-tiny2'
MAX_LEN = 924  # tokenizer truncation limit and fixed padding length
# Fall back to CPU so the bot still starts on machines without a GPU
# (hard-coding 'cuda' crashes model.to(device) on CPU-only hosts).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class ruBERTClassifier(torch.nn.Module):
    """Frozen rubert-tiny2 encoder with a trainable linear classification head.

    All encoder parameters are frozen; only the final linear layer is trained.
    """

    def __init__(self, hidden_size=312, num_classes=5):
        """
        Args:
            hidden_size: width of the encoder's [CLS] embedding
                (312 for cointegrated/rubert-tiny2).
            num_classes: number of topic labels the head predicts.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_weights)
        # Freeze the encoder — only the classification head is fine-tuned.
        for param in self.bert.parameters():
            param.requires_grad = False
        self.linear = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x, attention_mask=None):
        """Return raw per-class logits for a batch of token-id tensors.

        The first-position ([CLS]) hidden state is used as the sequence
        representation.
        """
        cls_embedding = self.bert(input_ids=x, attention_mask=attention_mask)[0][:, 0, :]
        return self.linear(cls_embedding)
# Build the classifier and restore the trained head weights; map_location lets
# the checkpoint load even if it was saved on a different device.
ruBERTcls = ruBERTClassifier()
ruBERTcls.load_state_dict(torch.load(weights_path, map_location=device))
ruBERTcls.to(device)
# NOTE(review): joblib.load unpickles arbitrary objects — only load these
# artifacts from a trusted source.
tokenizer = joblib.load(tokenizer_path)
label_encoder = joblib.load(label_encoder_path)
def preprocess_text(text):
    """Return *text* with emoticon characters (U+1F600–U+1F64F) removed."""
    emoticon_re = re.compile(r'[\U0001F600-\U0001F64F]')
    return emoticon_re.sub('', text)
def predict_topic(text, model, tokenizer, label_encoder, max_len=MAX_LEN, device=device):
    """Predict the topic label for a piece of news text.

    Args:
        text: raw input text; emoticons are stripped before tokenization.
        model: classifier returning per-class logits for
            (input_ids, attention_mask).
        tokenizer: HF-style tokenizer exposing ``encode``.
        label_encoder: sklearn-style encoder mapping class index -> label.
        max_len: truncation and padding length in tokens.
        device: torch device the input tensors are moved to (defaults to the
            module-level ``device`` instead of a second hard-coded 'cuda').

    Returns:
        The predicted topic label string.
    """
    text = preprocess_text(text)
    tokens = tokenizer.encode(text, add_special_tokens=True,
                              truncation=True, max_length=max_len)
    # Right-pad with 0 and mask the padding; assumes pad token id is 0 —
    # true for BERT-family vocabularies, verify if the tokenizer changes.
    tokens_padded = np.array(tokens + [0] * (max_len - len(tokens)))
    attention_mask = np.where(tokens_padded != 0, 1, 0)
    # Explicit dtype: np.array of Python ints is int32 on Windows, and an
    # int32 tensor crashes the embedding lookup (which requires int64).
    tokens_tensor = torch.tensor([tokens_padded], dtype=torch.long).to(device)
    attention_mask_tensor = torch.tensor([attention_mask], dtype=torch.long).to(device)
    model.eval()
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = model(tokens_tensor, attention_mask=attention_mask_tensor)
    prediction = torch.argmax(outputs, dim=1).cpu().numpy()[0]
    return label_encoder.inverse_transform([prediction])[0]
# SECURITY(review): a real bot token was committed on this line — revoke it in
# BotFather and load it from an environment variable instead.
# bot = telebot.TeleBot(os.environ['BOT_TOKEN'])
# @bot.message_handler(commands=['start'])
# def send_welcome(message):
# bot.reply_to(message, "Hello! Send me a news text and I'll predict its topic.")
# @bot.message_handler(func=lambda message: True)
# def predict_message_topic(message):
# text = message.text
# predicted_topic = predict_topic(text, ruBERTcls, tokenizer, label_encoder)
# bot.reply_to(message, f"The predicted topic is: {predicted_topic}")
# bot.polling()