Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import joblib | |
| import pandas as pd | |
| from models.model1.Custom_class import TextPreprocessor | |
| from pathlib import Path | |
| import sys | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import time | |
# Resolve the project root (the parent of the directory that contains this
# script) and the `models` directory beneath it.
project_root = Path(__file__).resolve().parents[1]
models_path = project_root / 'models'
# Streamlit re-executes this script on every widget interaction while sys.path
# persists for the whole process, so add the entry only once instead of
# appending a duplicate on every rerun.
if str(models_path) not in sys.path:
    sys.path.append(str(models_path))
| from models.model1.lstm_preprocessor import TextPreprocessorWord2Vec | |
| from models.model1.lstm_model import LSTMConcatAttention | |
# Load the trained TF-IDF + logistic-regression pipeline.
@st.cache_resource
def load_logreg_pipeline():
    """Return the pickled scikit-learn pipeline for the TF-IDF baseline.

    The path is anchored at ``models_path`` (like the LSTM artefacts) instead
    of the current working directory. The result is cached by Streamlit so
    the pickle is not read again on every script rerun.
    """
    return joblib.load(models_path / 'model1' / 'logistic_regression_pipeline.pkl')


pipeline = load_logreg_pipeline()

# Streamlit application
st.title('Классификация отзывов на русском языке')
input_text = st.text_area('Введите текст отзыва')
# Torch device used when loading and running the LSTM model.
device = 'cpu'
# Load the LSTM model and its vocabulary.
@st.cache_resource
def load_lstm_model():
    """Build the LSTM + concat-attention classifier and load trained weights.

    Weights are read from ``models/model1/lstm_weights`` and mapped onto
    ``device``. The model is moved to that device and switched to evaluation
    mode. Cached by Streamlit, so this runs once rather than on every rerun.

    Returns:
        The ready-to-use ``LSTMConcatAttention`` instance.
    """
    model = LSTMConcatAttention()
    weights_path = models_path / 'model1' / 'lstm_weights'
    # NOTE(review): the file is loaded as a plain state_dict; consider
    # weights_only=True if the installed torch version supports it.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


lstm_model = load_lstm_model()
@st.cache_resource
def load_int_to_vocab():
    """Return the inverse LSTM vocabulary as a ``{token_id: token}`` dict.

    Reads the pickled ``vocab_to_int`` mapping from
    ``models/model1/lstm_vocab_to_int.pkl`` and inverts it. Cached by
    Streamlit so the pickle is read once instead of on every rerun.
    """
    vocab_path = models_path / 'model1' / 'lstm_vocab_to_int.pkl'
    vocab_to_int = joblib.load(vocab_path)
    return {index: token for token, index in vocab_to_int.items()}


int_to_vocab = load_int_to_vocab()
def plot_and_predict_lstm(input_text):
    """Classify *input_text* with the LSTM model and plot its attention scores.

    Args:
        input_text: raw review text, passed to ``TextPreprocessorWord2Vec``.

    Returns:
        A ``(lstm_pred, figure)`` tuple.
        - ``lstm_pred`` is the sigmoid of the model output as a Python float.
        - ``figure`` is a matplotlib ``Figure`` with one horizontal bar per
          non-padding token, the highest attention score at the top.

        The figure is not registered with pyplot, so repeated calls do not
        accumulate open figures. Render it with ``st.pyplot(figure)``.
    """
    from matplotlib.figure import Figure

    preprocessor_lstm = TextPreprocessorWord2Vec()
    preprocessed = preprocessor_lstm.transform(input_text)
    # NOTE(review): assumes `transform` returns a 1-D tensor of token ids — confirm.
    token_ids = preprocessed.long()
    lstm_model.eval()
    with torch.inference_mode():
        pred, att_scores = lstm_model(token_ids.unsqueeze(0))
        lstm_pred = pred.sigmoid().item()

    # Keep only real tokens: id 0 and the "<pad>" token are padding.
    # `and` short-circuits, so id 0 is never looked up in the vocabulary.
    ids = token_ids.tolist()
    valid_indices = [
        i for i, tok in enumerate(ids)
        if tok != 0 and int_to_vocab[tok] != "<pad>"
    ]
    valid_labels = [int_to_vocab[ids[i]] for i in valid_indices]
    # Attention scores are indexed [sample][position]; take the single sample.
    valid_att_scores = att_scores.detach().cpu().numpy()[0][valid_indices]

    # Ascending sort: barh draws position 0 at the bottom, so the token with
    # the highest attention ends up at the top of the chart.
    order = np.argsort(valid_att_scores)
    sorted_labels = [valid_labels[i] for i in order]
    sorted_att_scores = valid_att_scores[order]

    # A standalone Figure avoids pyplot's global state; plt.figure() would
    # leave one open figure behind per prediction.
    fig = Figure(figsize=(4, 8))
    ax = fig.subplots()
    positions = np.arange(len(order))
    ax.barh(positions, sorted_att_scores)
    ax.set_yticks(positions)
    ax.set_yticklabels(sorted_labels)
    return lstm_pred, fig
if st.button('Предсказать'):
    if not input_text.strip():
        # Nothing to classify: ask for text instead of running both models on
        # an empty review.
        st.warning('Введите текст отзыва')
    else:
        review = pd.Series([input_text])

        # TF-IDF + logistic regression. Only inference is timed, not the
        # Streamlit rendering that follows.
        start_time_lr = time.perf_counter()
        prediction = pipeline.predict(review)[0]
        pred_proba = pipeline.predict_proba(review)[0]
        time_lr = time.perf_counter() - start_time_lr
        predicted_class = "POSITIVE" if prediction == 0 else "NEGATIVE"
        # Report the probability of the predicted class rather than always
        # the class-0 column. For a classifier whose predict() is the argmax
        # of predict_proba (as logistic regression's is), that is the row max.
        predicted_proba = np.round(pred_proba.max(), 2)
        st.subheader('Предсказанный класс с помощью логистической регрессии и tf-idf')
        st.write(f'**{predicted_class}** с вероятностью {predicted_proba}')
        st.write(f'Время выполнения расчетов {time_lr:.4f} секунд')

        # LSTM + Word2Vec + attention. The timing includes building the plot.
        start_time_lstm = time.perf_counter()
        lstm_pred, lstm_plot = plot_and_predict_lstm(input_text)
        time_lstm = time.perf_counter() - start_time_lstm
        # lstm_pred is the probability of POSITIVE; for a NEGATIVE result,
        # show the complementary probability of the class that was predicted.
        if lstm_pred > 0.5:
            predicted_lstm_class = "POSITIVE"
            lstm_proba = lstm_pred
        else:
            predicted_lstm_class = "NEGATIVE"
            lstm_proba = 1 - lstm_pred
        st.subheader('Предсказанный класс с помощью LSTM + Word2Vec + BahdanauAttention:')
        st.write(f'**{predicted_lstm_class}** с вероятностью {round(lstm_proba, 3)}')
        st.write(f'Время выполнения расчетов {time_lstm:.4f} секунд')
        st.pyplot(lstm_plot)
# Static description of how both models were trained; it is rendered below
# the prediction widgets on every run of the script.
st.write("# Информация об обучении модели логистической регрессии и tf-idf:")
st.image(str(project_root / 'images' / 'pipeline_logreg.png'))
for logreg_note in (
    "Модель обучалась на предсказание 1 класса",
    "Размер датасета - 70597 текстов отзывов",
    "Проведена предобработка текста",
    "Метрики:",
):
    st.write(logreg_note)
st.image(str(project_root / 'images' / 'log_reg_metrics.png'))

for lstm_note in (
    "# Информация об обучении модели LSTM + Word2Vec + BahdanauAttention:",
    "Время обучения модели - 10 эпох",
    "Метрики на 10 эпохе:",
    "Train f1: 0.95, Val f1: 0.93",
    "Train accuracy: 0.94, Val accuracy: 0.92",
):
    st.write(lstm_note)