import streamlit as st import joblib import pandas as pd from models.model1.Custom_class import TextPreprocessor from pathlib import Path import sys import torch import numpy as np import matplotlib.pyplot as plt import time project_root = Path(__file__).resolve().parents[1] models_path = project_root / 'models' sys.path.append(str(models_path)) from models.model1.lstm_preprocessor import TextPreprocessorWord2Vec from models.model1.lstm_model import LSTMConcatAttention # Load the trained pipeline pipeline = joblib.load('models/model1/logistic_regression_pipeline.pkl') # Streamlit application st.title('Классификация отзывов на русском языке') input_text = st.text_area('Введите текст отзыва') device = 'cpu' # Загрузка модели LSTM и словаря @st.cache_resource def load_lstm_model(): model = LSTMConcatAttention() weights_path = models_path / 'model1' / 'lstm_weights' state_dict = torch.load(weights_path, map_location=device) model.load_state_dict(state_dict) model.to(device) model.eval() return model lstm_model = load_lstm_model() @st.cache_resource def load_int_to_vocab(): vocab_path = models_path / 'model1' / 'lstm_vocab_to_int.pkl' vocab_to_int = joblib.load(vocab_path) int_to_vocab = {j:i for i, j in vocab_to_int.items()} return int_to_vocab int_to_vocab = load_int_to_vocab() def plot_and_predict_lstm(input_text): preprocessor_lstm = TextPreprocessorWord2Vec() preprocessed = preprocessor_lstm.transform(input_text) lstm_model.eval() with torch.inference_mode(): pred, att_scores = lstm_model(preprocessed.long().unsqueeze(0)) lstm_pred = pred.sigmoid().item() # Получить индексы слов, которые не равны и не имеют индекс 0 valid_indices = [i for i, x in enumerate(preprocessed) if x.item() != 0 and int_to_vocab[x.item()] != ""] # Получить соответствующие оценки внимания и метки слов valid_att_scores = att_scores.detach().cpu().numpy()[0][valid_indices] valid_labels = [int_to_vocab[preprocessed[i].item()] for i in valid_indices] # Упорядочить метки и оценки внимания по убыванию веса смысла sorted_indices = np.argsort(valid_att_scores) sorted_labels = [valid_labels[i] for i in sorted_indices] sorted_att_scores = valid_att_scores[sorted_indices] # Построить график с учетом только валидных меток plt.figure(figsize=(4, 8)) plt.barh(np.arange(len(sorted_indices)), sorted_att_scores) plt.yticks(ticks=np.arange(len(sorted_indices)), labels=sorted_labels) return lstm_pred, plt if st.button('Предсказать'): start_time_lr = time.time() prediction = pipeline.predict(pd.Series([input_text])) pred_probe = pipeline.predict_proba(pd.Series([input_text])) pred_proba_rounded = np.round(pred_probe, 2).flatten() if prediction[0] == 0: predicted_class = "POSITIVE" else: predicted_class = "NEGATIVE" st.subheader('Предсказанный класс с помощью логистической регрессии и tf-idf') end_time_lr = time.time() time_lr = end_time_lr - start_time_lr st.write(f'**{predicted_class}** с вероятностью {pred_proba_rounded[0]}') st.write(f'Время выполнения расчетов {time_lr:.4f} секунд') start_time_lstm = time.time() lstm_pred, lstm_plot = plot_and_predict_lstm(input_text) if lstm_pred > 0.5: predicted_lstm_class = "POSITIVE" else: predicted_lstm_class = "NEGATIVE" st.subheader('Предсказанный класс с помощью LSTM + Word2Vec + BahdanauAttention:') end_time_lstm = time.time() time_lstm = end_time_lstm - start_time_lstm st.write(f'**{predicted_lstm_class}** с вероятностью {round(lstm_pred, 3)}') st.write(f'Время выполнения расчетов {time_lstm:.4f} секунд') st.pyplot(lstm_plot) st.write("# Информация об обучении модели логистической регрессии и tf-idf:") st.image(str(project_root / 'images/pipeline_logreg.png')) st.write("Модель обучалась на предсказание 1 класса") st.write("Размер датасета - 70597 текстов отзывов") st.write("Проведена предобработка текста") st.write("Метрики:") st.image(str(project_root / 'images/log_reg_metrics.png')) st.write("# Информация об обучении модели LSTM + Word2Vec + BahdanauAttention:") st.write("Время обучения модели - 10 эпох") st.write("Метрики на 10 эпохе:") st.write("Train f1: 0.95, Val f1: 0.93") st.write("Train accuracy: 0.94, Val accuracy: 0.92")