Spaces:
Sleeping
Sleeping
File size: 5,060 Bytes
961ee03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import streamlit as st
import joblib
import pandas as pd
from models.model1.Custom_class import TextPreprocessor
from pathlib import Path
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import time
project_root = Path(__file__).resolve().parents[1]
models_path = project_root / 'models'
sys.path.append(str(models_path))
from models.model1.lstm_preprocessor import TextPreprocessorWord2Vec
from models.model1.lstm_model import LSTMConcatAttention
# Load the trained pipeline
pipeline = joblib.load('models/model1/logistic_regression_pipeline.pkl')
# Streamlit application
st.title('Классификация отзывов на русском языке')
input_text = st.text_area('Введите текст отзыва')
device = 'cpu'
# Загрузка модели LSTM и словаря
@st.cache_resource
def load_lstm_model():
model = LSTMConcatAttention()
weights_path = models_path / 'model1' / 'lstm_weights'
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
lstm_model = load_lstm_model()
@st.cache_resource
def load_int_to_vocab():
vocab_path = models_path / 'model1' / 'lstm_vocab_to_int.pkl'
vocab_to_int = joblib.load(vocab_path)
int_to_vocab = {j:i for i, j in vocab_to_int.items()}
return int_to_vocab
int_to_vocab = load_int_to_vocab()
def plot_and_predict_lstm(input_text):
preprocessor_lstm = TextPreprocessorWord2Vec()
preprocessed = preprocessor_lstm.transform(input_text)
lstm_model.eval()
with torch.inference_mode():
pred, att_scores = lstm_model(preprocessed.long().unsqueeze(0))
lstm_pred = pred.sigmoid().item()
# Получить индексы слов, которые не равны <pad> и не имеют индекс 0
valid_indices = [i for i, x in enumerate(preprocessed) if x.item() != 0 and int_to_vocab[x.item()] != "<pad>"]
# Получить соответствующие оценки внимания и метки слов
valid_att_scores = att_scores.detach().cpu().numpy()[0][valid_indices]
valid_labels = [int_to_vocab[preprocessed[i].item()] for i in valid_indices]
# Упорядочить метки и оценки внимания по убыванию веса смысла
sorted_indices = np.argsort(valid_att_scores)
sorted_labels = [valid_labels[i] for i in sorted_indices]
sorted_att_scores = valid_att_scores[sorted_indices]
# Построить график с учетом только валидных меток
plt.figure(figsize=(4, 8))
plt.barh(np.arange(len(sorted_indices)), sorted_att_scores)
plt.yticks(ticks=np.arange(len(sorted_indices)), labels=sorted_labels)
return lstm_pred, plt
if st.button('Предсказать'):
start_time_lr = time.time()
prediction = pipeline.predict(pd.Series([input_text]))
pred_probe = pipeline.predict_proba(pd.Series([input_text]))
pred_proba_rounded = np.round(pred_probe, 2).flatten()
if prediction[0] == 0:
predicted_class = "POSITIVE"
else:
predicted_class = "NEGATIVE"
st.subheader('Предсказанный класс с помощью логистической регрессии и tf-idf')
end_time_lr = time.time()
time_lr = end_time_lr - start_time_lr
st.write(f'**{predicted_class}** с вероятностью {pred_proba_rounded[0]}')
st.write(f'Время выполнения расчетов {time_lr:.4f} секунд')
start_time_lstm = time.time()
lstm_pred, lstm_plot = plot_and_predict_lstm(input_text)
if lstm_pred > 0.5:
predicted_lstm_class = "POSITIVE"
else:
predicted_lstm_class = "NEGATIVE"
st.subheader('Предсказанный класс с помощью LSTM + Word2Vec + BahdanauAttention:')
end_time_lstm = time.time()
time_lstm = end_time_lstm - start_time_lstm
st.write(f'**{predicted_lstm_class}** с вероятностью {round(lstm_pred, 3)}')
st.write(f'Время выполнения расчетов {time_lstm:.4f} секунд')
st.pyplot(lstm_plot)
st.write("# Информация об обучении модели логистической регрессии и tf-idf:")
st.image(str(project_root / 'images/pipeline_logreg.png'))
st.write("Модель обучалась на предсказание 1 класса")
st.write("Размер датасета - 70597 текстов отзывов")
st.write("Проведена предобработка текста")
st.write("Метрики:")
st.image(str(project_root / 'images/log_reg_metrics.png'))
st.write("# Информация об обучении модели LSTM + Word2Vec + BahdanauAttention:")
st.write("Время обучения модели - 10 эпох")
st.write("Метрики на 10 эпохе:")
st.write("Train f1: 0.95, Val f1: 0.93")
st.write("Train accuracy: 0.94, Val accuracy: 0.92")
|