import pickle
import time

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.serialization import safe_globals
from transformers import AutoTokenizer, BertForSequenceClassification

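# Streamlit demo that compares three classifiers of medical-facility reviews:
# a pickled scikit-learn pipeline, an LSTM with additive attention, and a
# fine-tuned BERT. Model artifacts are expected under models/ (see load_models).
# Run with: streamlit run <this_file>.py
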

def preprocess_text(text):
    """Lowercase a review and collapse line breaks; return '' for non-string input."""
    if not isinstance(text, str):
        return ""
    return text.lower().replace('\n', ' ').replace('\r', ' ').strip()


class ClassicalML:
    """Pickled scikit-learn pipeline plus the label encoder fitted with it."""

    def __init__(self):
        self.pipeline = None
        self.label_encoder = None

    def predict(self, X):
        start_time = time.time()
        preds = self.pipeline.predict(X)
        return self.label_encoder.inverse_transform(preds), time.time() - start_time


class Attention(nn.Module):
    """Additive attention over LSTM outputs: score each timestep, return the weighted sum."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, lstm_output):
        # lstm_output: (batch, seq_len, hidden_dim)
        attention_weights = torch.softmax(self.attention(lstm_output).squeeze(-1), dim=1)
        # Weighted sum over timesteps -> (batch, hidden_dim)
        context = torch.bmm(attention_weights.unsqueeze(1), lstm_output).squeeze(1)
        return context


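# The LSTM checkpoint at models/lstm/model.pt is a plain dict holding
# 'vocab' (token -> index), a fitted 'label_encoder', the 'embed_dim' and
# 'hidden_dim' hyperparameters, and 'model_state_dict'. Because it contains
# pickled Python objects, loading it needs weights_only=False (plus
# safe_globals for LabelEncoder on newer PyTorch).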
class LSTMTrainer:
    """LSTM-with-attention classifier restored from a saved checkpoint."""

    def __init__(self):
        self.model = None
        self.vocab = None
        self.label_encoder = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def predict(self, X):
        self.model.eval()
        preds = []
        start_time = time.time()
        with torch.no_grad():
            for text in X:
                tokens = preprocess_text(text).split()
                # Map tokens to vocab indices; 0 serves as the out-of-vocabulary index.
                seq = [self.vocab.get(token, 0) for token in tokens]
                if not seq:
                    seq = [0]
                text_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(self.device)
                # pack_padded_sequence expects lengths on the CPU.
                length_tensor = torch.tensor([len(seq)], dtype=torch.long)
                output = self.model(text_tensor, length_tensor)
                preds.append(torch.argmax(output).item())
        return self.label_encoder.inverse_transform(preds), time.time() - start_time

    @classmethod
    def load(cls, path='models/lstm'):
        checkpoint = torch.load(
            f'{path}/model.pt',
            map_location=torch.device('cpu'),
            weights_only=False
        )

        model = cls()
        model.vocab = checkpoint['vocab']
        model.label_encoder = checkpoint['label_encoder']
        model.model = LSTMModel(
            len(model.vocab),
            checkpoint['embed_dim'],
            checkpoint['hidden_dim'],
            len(model.label_encoder.classes_)
        ).to(model.device)

        # Rename parameters saved as 'attention.attention.*' to
        # 'attention.attention.0.*' (the naming used when the attention layer
        # is wrapped in nn.Sequential); with strict=False, keys that do not
        # match the current architecture are skipped.
        state_dict = checkpoint['model_state_dict']
        new_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith('attention.attention.'):
                new_state_dict[key.replace('attention.attention.', 'attention.attention.0.')] = value
            else:
                new_state_dict[key] = value

        model.model.load_state_dict(new_state_dict, strict=False)
        return model


class BERTClassifier:
    """Fine-tuned BERT sequence classifier with its tokenizer and label encoder."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = None
        self.model = None
        self.label_encoder = None

    def predict(self, X):
        self.model.eval()
        preds = []
        start_time = time.time()
        with torch.no_grad():
            for text in X:
                inputs = self.tokenizer(
                    text,
                    padding=True,
                    truncation=True,
                    max_length=128,
                    return_tensors="pt"
                ).to(self.device)
                outputs = self.model(**inputs)
                preds.append(torch.argmax(outputs.logits).item())
        return self.label_encoder.inverse_transform(preds), time.time() - start_time


def plot_attention(text, model, tokenizer):
    """Heatmap of the last BERT layer's attention, averaged over heads."""
    # Move the encoded inputs to the model's device so CPU/GPU setups both work.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(model.device)
    outputs = model(**inputs, output_attentions=True)
    # Last layer, single batch element, mean over attention heads.
    attention = outputs.attentions[-1].squeeze(0).mean(dim=0)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(attention.detach().cpu().numpy(),
                xticklabels=tokens,
                yticklabels=tokens,
                cmap="YlGnBu")
    plt.title("Attention Scores")
    st.pyplot(fig)


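# load_models reads the following artifacts, produced offline by training:
#   models/classical_ml/pipeline.pkl and models/classical_ml/label_encoder.pkl
#   models/lstm/model.pt
#   models/bert/ (saved tokenizer and model) plus models/bert/label_encoder.pkl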
@st.cache_resource
def load_models():
    # Classical ML: unpickle the sklearn pipeline and its label encoder.
    classical_ml = ClassicalML()
    with open('models/classical_ml/pipeline.pkl', 'rb') as f:
        classical_ml.pipeline = pickle.load(f)
    with open('models/classical_ml/label_encoder.pkl', 'rb') as f:
        classical_ml.label_encoder = pickle.load(f)

    # LSTM: try the safe weights-only load first; fall back to full unpickling
    # (allow-listing LabelEncoder) for checkpoints that store Python objects.
    lstm = LSTMTrainer()
    try:
        checkpoint = torch.load(
            'models/lstm/model.pt',
            map_location=torch.device('cpu'),
            weights_only=True
        )
    except Exception:
        with safe_globals([LabelEncoder]):
            checkpoint = torch.load(
                'models/lstm/model.pt',
                map_location=torch.device('cpu'),
                weights_only=False
            )

    lstm.vocab = checkpoint['vocab']
    lstm.label_encoder = checkpoint['label_encoder']
    lstm.model = LSTMModel(
        len(lstm.vocab),
        checkpoint['embed_dim'],
        checkpoint['hidden_dim'],
        len(lstm.label_encoder.classes_)
    ).to(lstm.device)
    lstm.model.load_state_dict(checkpoint['model_state_dict'])

    # BERT: load the fine-tuned model and tokenizer from the local directory.
    bert = BERTClassifier()
    bert.tokenizer = AutoTokenizer.from_pretrained('models/bert')
    bert.model = BertForSequenceClassification.from_pretrained('models/bert')
    bert.model.to(bert.device)
    with open('models/bert/label_encoder.pkl', 'rb') as f:
        bert.label_encoder = pickle.load(f)

    return classical_ml, lstm, bert


class LSTMModel(nn.Module):
    """Embedding -> LSTM -> attention pooling -> dropout -> linear classifier."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, text, lengths):
        # text: (batch, seq_len) token indices; lengths: (batch,) true lengths.
        embedded = self.embedding(text)
        packed = pack_padded_sequence(
            embedded,
            lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        packed_output, (hidden, cell) = self.lstm(packed)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        context = self.attention(output)
        return self.fc(self.dropout(context))


def main():
    # UI strings are in Russian; the title reads "Analysis of medical facility reviews".
    st.title("Анализ отзывов медицинских учреждений")

    classical_ml, lstm, bert = load_models()

    # Offline evaluation results, hardcoded for the comparison table below.
    metrics = {
        'Classical ML': {'f1_macro': 0.85, 'inference_time': 0.01},
        'LSTM': {'f1_macro': 0.87, 'inference_time': 0.12},
        'BERT': {'f1_macro': 0.92, 'inference_time': 0.05}
    }
    metrics_df = pd.DataFrame.from_dict(metrics, orient='index')

    user_input = st.text_area("Введите ваш отзыв:", "Очень хорошая клиника, внимательные врачи")

    if st.button("Проанализировать отзыв"):
        if user_input:
            # Prepend the facility category, matching the format used in training data.
            input_with_category = f"Поликлиники стоматологические {user_input}"

            with st.spinner('Обработка...'):
                ml_pred, ml_time = classical_ml.predict([input_with_category])
                lstm_pred, lstm_time = lstm.predict([input_with_category])
                bert_pred, bert_time = bert.predict([input_with_category])

            col1, col2, col3 = st.columns(3)

            with col1:
                st.subheader("Classical ML")
                st.metric("Предсказание", ml_pred[0])
                st.metric("Время (сек)", f"{ml_time:.4f}")

            with col2:
                st.subheader("LSTM")
                st.metric("Предсказание", lstm_pred[0])
                st.metric("Время (сек)", f"{lstm_time:.4f}")

            with col3:
                st.subheader("BERT")
                st.metric("Предсказание", bert_pred[0])
                st.metric("Время (сек)", f"{bert_time:.4f}")

            st.header("Attention-механизм BERT")
            plot_attention(user_input, bert.model, bert.tokenizer)

    st.header("Сравнение моделей")
    st.dataframe(metrics_df.style.highlight_max(axis=0))

    st.header("Визуализация метрик")
    fig, ax = plt.subplots(1, 2, figsize=(15, 5))

    metrics_df['f1_macro'].plot(kind='bar', ax=ax[0], color='skyblue')
    ax[0].set_title('F1-macro score')
    ax[0].set_ylabel('Score')

    metrics_df['inference_time'].plot(kind='bar', ax=ax[1], color='salmon')
    ax[1].set_title('Время предсказания (сек)')
    ax[1].set_ylabel('Seconds')

    st.pyplot(fig)


if __name__ == "__main__":
    main()