Spaces:
Sleeping
Sleeping
| import torch | |
| import joblib | |
| import telebot | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel | |
| import re | |
# --- Model / artifact configuration -------------------------------------
# Paths to the fine-tuned classifier-head weights and the serialized
# tokenizer / label encoder produced at training time.
weights_path = 'bot/utils/ruBERTcls_weights.pth'
tokenizer_path = 'bot/utils/tokenizer.joblib'
label_encoder_path = 'bot/utils/label_encoder.joblib'
# Base encoder checkpoint the classifier head was trained on.
pretrained_weights = 'cointegrated/rubert-tiny2'
# Maximum token-sequence length used for truncation and padding.
MAX_LEN = 924
# Fix: was hard-coded to 'cuda', which raises on CPU-only hosts;
# fall back to CPU when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class ruBERTClassifier(torch.nn.Module):
    """Frozen ruBERT-tiny2 encoder with a trainable linear head (5 classes)."""

    def __init__(self):
        super().__init__()
        # Load the pretrained encoder and freeze every parameter:
        # only the linear head below receives gradients.
        self.bert = AutoModel.from_pretrained(pretrained_weights)
        for weight in self.bert.parameters():
            weight.requires_grad = False
        # rubert-tiny2 hidden size is 312; the task has 5 target topics.
        self.linear = torch.nn.Linear(312, 5)

    def forward(self, x, attention_mask=None):
        # Use the first ([CLS]) token of the last hidden state as the
        # sentence representation, then project it onto the 5 classes.
        encoder_output = self.bert(input_ids=x, attention_mask=attention_mask)
        cls_embedding = encoder_output[0][:, 0, :]
        return self.linear(cls_embedding)
# Build the classifier and load the fine-tuned weights; map_location keeps
# the load working regardless of the device the checkpoint was saved on.
ruBERTcls = ruBERTClassifier()
ruBERTcls.load_state_dict(torch.load(weights_path, map_location=device))
ruBERTcls.to(device)
# Tokenizer and label encoder serialized at training time.
# NOTE(review): joblib.load unpickles arbitrary objects — only load these
# artifacts from a trusted source.
tokenizer = joblib.load(tokenizer_path)
label_encoder = joblib.load(label_encoder_path)
# Compiled once at module level. Fix: the original pattern removed only the
# "Emoticons" block (U+1F600-U+1F64F), letting all other emoji — transport
# symbols, pictographs, dingbats — reach the tokenizer untouched.
_EMOJI_RE = re.compile(
    '['
    '\U0001F300-\U0001F5FF'   # misc symbols & pictographs
    '\U0001F600-\U0001F64F'   # emoticons (the only range covered before)
    '\U0001F680-\U0001F6FF'   # transport & map symbols
    '\U0001F900-\U0001F9FF'   # supplemental symbols & pictographs
    '\u2600-\u27BF'           # misc symbols and dingbats
    ']'
)


def preprocess_text(text):
    """Strip emoji from *text* before tokenization; all other characters pass through unchanged."""
    return _EMOJI_RE.sub('', text)
def predict_topic(text, model, tokenizer, label_encoder, max_len=MAX_LEN, device='cuda'):
    """Classify *text* and return its decoded topic label.

    The text is emoji-stripped, tokenized and right-padded to ``max_len``
    (pad id assumed to be 0), run through ``model`` without gradients, and
    the argmax class index is decoded via ``label_encoder``.
    """
    cleaned = preprocess_text(text)
    token_ids = tokenizer.encode(
        cleaned, add_special_tokens=True, truncation=True, max_length=max_len
    )
    padded = np.array(token_ids + [0] * (max_len - len(token_ids)))
    # Attention mask: 1 for real tokens, 0 for padding positions.
    mask = np.where(padded != 0, 1, 0)
    input_tensor = torch.tensor([padded]).to(device)
    mask_tensor = torch.tensor([mask]).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(input_tensor, attention_mask=mask_tensor)
    predicted_class = torch.argmax(logits, dim=1).cpu().numpy()[0]
    return label_encoder.inverse_transform([predicted_class])[0]
# SECURITY: a live Telegram bot token was committed here and has been
# redacted — revoke it via @BotFather and load it from an environment
# variable instead.
# bot = telebot.TeleBot('<REDACTED-TELEGRAM-BOT-TOKEN>')
| # @bot.message_handler(commands=['start']) | |
| # def send_welcome(message): | |
| # bot.reply_to(message, "Hello! Send me a news text and I'll predict its topic.") | |
| # @bot.message_handler(func=lambda message: True) | |
| # def predict_message_topic(message): | |
| # text = message.text | |
| # predicted_topic = predict_topic(text, ruBERTcls, tokenizer, label_encoder) | |
| # bot.reply_to(message, f"The predicted topic is: {predicted_topic}") | |
| # bot.polling() | |