Spaces:
No application file
No application file
import torch
from sklearn.preprocessing import LabelEncoder
from transformers import BertForSequenceClassification, BertTokenizer
# Class names the classifier distinguishes between.
labels = ['мода', 'спорт', 'технологии', 'финансы', 'крипта']

# NOTE: LabelEncoder keeps `classes_` in sorted order, not in the order of
# `labels`, so a predicted index is resolved against the sorted class list.
# Presumably this matches the encoding used during training -- verify
# against the training code.
label_encoder = LabelEncoder()
label_encoder.fit(labels)

# Checkpoint that provides both the base model and the tokenizer.
loaded_model_path = "rubert-base-cased"
# The tokenizer is loaded from a checkpoint name/path, so this must be a
# string; it previously held a loaded model object, which
# BertTokenizer.from_pretrained cannot accept.
loaded_tokenizer_path = loaded_model_path

# Base model and tokenizer. `loaded_model` is only the base checkpoint: its
# classification head is not covered by the fine-tuned weights loaded below.
loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)

# Build the classifier with one output per class. The class constructor
# takes a config object rather than `num_labels`, so the architecture is
# created through `from_pretrained`, which forwards `num_labels` to the config.
model = BertForSequenceClassification.from_pretrained(
    loaded_model_path,
    num_labels=len(labels),
)

# Load the fine-tuned weights from the saved file.
# NOTE: torch.load unpickles the file -- only load weight files you trust.
weights_path = "model_weights_epoch_8.pt"
# Use 'cuda' instead of 'cpu' when running on a GPU.
state_dict = torch.load(weights_path, map_location='cpu')
model.load_state_dict(state_dict)
# Alternative setup: load a model and tokenizer that were saved locally.
# loaded_model_path = "nlp_project/model"
# loaded_tokenizer_path = "nlp_project/tokenizer"
# loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
# loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)


def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128):
    """Classify a piece of text and return the predicted class name.

    The default arguments are bound when this function is defined, so
    `loaded_model`, `loaded_tokenizer` and `label_encoder` must already
    exist at that point.

    Args:
        user_input: Text to classify.
        model: Sequence-classification model whose output exposes `.logits`.
        tokenizer: Tokenizer used to encode `user_input`.
        label_encoder: Fitted encoder; `classes_` maps the predicted index
            to a class name.
        max_length: Length every input is padded or truncated to.

    Returns:
        The predicted class name, or the prompt string "Введите текст"
        when the input is empty or contains only whitespace.
    """
    if not user_input or not user_input.strip():
        return "Введите текст"

    # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`;
    # truncation is requested explicitly so over-long inputs are cut to
    # `max_length` without relying on the tokenizer's fallback behaviour.
    encoded_text = tokenizer(
        user_input,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Keep the inputs on the same device as the model (CPU or GPU).
    device = next(model.parameters()).device
    input_ids = encoded_text['input_ids'].to(device)
    attention_mask = encoded_text['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    predicted_class_index = torch.argmax(outputs.logits, dim=1).item()
    # Translate the predicted index into its class name.
    return label_encoder.classes_[predicted_class_index]


# Example usage of the fine-tuned model. This has to come after the
# definition of predict_class, otherwise the call raises NameError.
user_input = "Ваш текст для классификации"
predicted_class = predict_class(user_input, model=model, tokenizer=loaded_tokenizer, label_encoder=label_encoder)
print(predicted_class)