Spaces:
Build error
Build error
movie kdnv update
Browse files- app.py +3 -3
- models/bert.pt +3 -0
- models/kdnv_models.py +71 -0
- models/logistic_regression_pipeline.pkl +3 -0
- models/lstm.pt +3 -0
- models/word2int.json +0 -0
- pages/kudinov_films.py +94 -0
app.py
CHANGED
|
@@ -11,11 +11,11 @@ col1, col2, col3 = st.columns(3)
|
|
| 11 |
# st.page_link('pages/chernyshov_learning.py', label='Обучение', icon='💀')
|
| 12 |
|
| 13 |
with col2:
|
| 14 |
-
st.page_link('pages/natasha_model.py', label='
|
| 15 |
# st.page_link('pages/bond_learning.py', label='Обучение', icon='ℹ️')
|
| 16 |
|
| 17 |
with col3:
|
| 18 |
-
st.page_link('pages/kdnv_model.py', label='
|
| 19 |
-
|
| 20 |
|
| 21 |
st.divider()
|
|
|
|
| 11 |
# st.page_link('pages/chernyshov_learning.py', label='Обучение', icon='💀')
|
| 12 |
|
| 13 |
with col2:
|
| 14 |
+
st.page_link('pages/natasha_model.py', label='Токсикметр Наташи', icon='🤬')
|
| 15 |
# st.page_link('pages/bond_learning.py', label='Обучение', icon='ℹ️')
|
| 16 |
|
| 17 |
with col3:
|
| 18 |
+
st.page_link('pages/kdnv_model.py', label='Ночной собутыльник Серёжи', icon='🍻')
|
| 19 |
+
st.page_link('pages/kudinov_films.py', label='Оценщик фильмов Серёжи', icon='🎥')
|
| 20 |
|
| 21 |
st.divider()
|
models/bert.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:177ac13ff0731ee3aa8ba9e4ec0b09f177f558cd9b0c0c596c73470d85dab44b
|
| 3 |
+
size 117119720
|
models/kdnv_models.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
# Architecture hyperparameters baked into the shipped checkpoint (models/lstm.pt);
# they must match the values used at training time or load_state_dict will fail.
HIDDEN_SIZE = 128    # LSTM hidden-state width (also the attention projection width)
EMBEDDING_DIM = 128  # token embedding dimension
VOCAB_SIZE = 1980    # embedding rows — presumably the size of models/word2int.json; verify
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Bandanau(nn.Module):
    """Additive (Bahdanau-style) attention over LSTM outputs.

    The class name keeps the original "Bandanau" spelling: the saved
    checkpoint and importing modules reference it by this name.
    """

    def __init__(self, HIDDEN_SIZE) -> None:
        super().__init__()
        self.hidden_size = HIDDEN_SIZE
        # Attribute names are load-bearing: state-dict keys in lstm.pt use them.
        self.linearwk = nn.Linear(self.hidden_size, self.hidden_size)  # key projection (encoder states)
        self.linearwa = nn.Linear(self.hidden_size, self.hidden_size)  # query projection (final hidden)
        self.linearwv = nn.Linear(self.hidden_size, 1)                 # energy: H -> scalar score per step

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        final_hidden: torch.Tensor   # BATCH_SIZE x HIDDEN_SIZE
    ):
        """Return (context, attention_weights).

        context           -- BATCH_SIZE x HIDDEN_SIZE attended summary
        attention_weights -- BATCH_SIZE x SEQ_LEN softmax weights
        """
        final_hidden = final_hidden.unsqueeze(1)  # -> B x 1 x H, broadcasts against B x S x H

        wk_out = self.linearwk(lstm_outputs)  # B x S x H
        wa_out = self.linearwa(final_hidden)  # B x 1 x H

        # torch.tanh replaces the deprecated F.tanh — identical math.
        plus = torch.tanh(wk_out + wa_out)

        wv_out = self.linearwv(plus)  # B x S x 1 unnormalized energies

        attention_weights = F.softmax(wv_out, dim=1)  # normalize over the sequence axis

        attention_weights = attention_weights.transpose(1, 2)  # -> B x 1 x S for bmm

        # NOTE(review): the context is taken over the *projected* keys (wk_out)
        # rather than the raw lstm_outputs of the textbook formulation. Left
        # unchanged: the shipped checkpoint was trained with this exact forward.
        context = torch.bmm(attention_weights, wk_out)  # B x 1 x H

        context = context.squeeze(1)                      # -> B x H
        attention_weights = attention_weights.squeeze(1)  # -> B x S

        return context, attention_weights
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# %%
|
| 45 |
+
class LSTMConcatAttention(nn.Module):
    """Embedding -> LSTM -> additive attention -> MLP head (3 classes).

    Returns both the class logits and the attention weights so callers can
    inspect which tokens drove the prediction.
    """

    def __init__(self) -> None:
        super().__init__()

        # Submodule attribute names (and Sequential layer indices) are fixed:
        # the state-dict keys in models/lstm.pt depend on them.
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
        self.attn = Bandanau(HIDDEN_SIZE)

        # Classifier head: three (Linear, Dropout, Tanh) stages, then the
        # output projection — indices 0..9 match the original Sequential.
        widths = (HIDDEN_SIZE, 512, 256, 128)
        stages = []
        for w_in, w_out in zip(widths, widths[1:]):
            stages += [nn.Linear(w_in, w_out), nn.Dropout(0.3), nn.Tanh()]
        stages.append(nn.Linear(128, 3))
        self.clf = nn.Sequential(*stages)

    def forward(self, x):
        emb = self.embedding(x)
        seq_states, (last_hidden, _) = self.lstm(emb)
        # Single-layer LSTM: squeeze the num_layers axis of h_n.
        context, att_weights = self.attn(seq_states, last_hidden.squeeze(0))
        logits = self.clf(context)
        return logits, att_weights
|
models/logistic_regression_pipeline.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3613a44d3c9b524869e84b21eb720ac071e1188473def7a5ead3b14fafadde16
|
| 3 |
+
size 5808993
|
models/lstm.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:08546effc4bb9a268e1ec53b9ce41bd721673fb9053cff07a24dd62a87061bba
|
| 3 |
+
size 2602410
|
models/word2int.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pages/kudinov_films.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import streamlit as st
|
| 3 |
+
from models.kdnv_models import LSTMConcatAttention
|
| 4 |
+
from models.kdnv_preprocess import preprocess_single_string, data_preprocessing
|
| 5 |
+
import json
|
| 6 |
+
import transformers
|
| 7 |
+
from torch import nn
|
| 8 |
+
import joblib
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Tokenizer for the same rubert-tiny2 base model whose fine-tuned weights
# are loaded from models/bert.pt in load_model_bert().
autotoken = transformers.AutoTokenizer.from_pretrained(
    "cointegrated/rubert-tiny2"
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@st.cache_resource()
def load_model_lstm():
    """Build the attention-LSTM classifier and load its fine-tuned weights.

    Cached by Streamlit so the checkpoint is read once per process, not on
    every script rerun.

    Returns:
        LSTMConcatAttention in eval mode. The original left the model in
        train mode, so its Dropout(0.3) layers stayed active at inference
        and predictions were stochastic — model.eval() fixes that.
    """
    model = LSTMConcatAttention()
    # map_location forces CPU regardless of the device the checkpoint was
    # saved on; weights_only avoids unpickling arbitrary objects.
    model.load_state_dict(torch.load('models/lstm.pt', map_location=torch.device('cpu'), weights_only=True))
    model.eval()  # disable Dropout for deterministic inference
    return model
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@st.cache_resource()
def load_model_bert():
    """Build rubert-tiny2 with a custom 3-class head and load tuned weights.

    Cached by Streamlit so the checkpoint is read once per process.

    Returns:
        The sequence-classification model in eval mode. The original left
        the model in train mode, so the Dropout(0.5) in the replacement
        classifier stayed active at inference and predictions were
        stochastic — model.eval() fixes that.
    """
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "cointegrated/rubert-tiny2"
    )
    # Replace the stock head before loading: the state dict in bert.pt was
    # saved with this exact classifier architecture.
    model.classifier = nn.Sequential(
        nn.Linear(in_features=312, out_features=256),
        nn.Sigmoid(),
        nn.Dropout(0.5),
        nn.Linear(in_features=256, out_features=3)
    )
    model.load_state_dict(torch.load('models/bert.pt', map_location=torch.device('cpu'), weights_only=True))
    model.eval()  # disable Dropout for deterministic inference
    return model
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Instantiate all three models at import time; the torch models are cached
# across Streamlit reruns by st.cache_resource.
model_bert = load_model_bert()
model_lstm = load_model_lstm()
# scikit-learn pipeline loaded via joblib — NOTE(review): joblib.load
# unpickles arbitrary code; only safe because the file ships with the app.
model_lr = joblib.load('models/logistic_regression_pipeline.pkl')

# Class index -> human-readable label (Russian: negative/neutral/positive);
# shared by all three predictors.
labels_dict = {
    0: 'Негативный',
    1: 'Нейтральный',
    2: 'Позитивный'
}


# word -> integer id vocabulary used by the LSTM preprocessing.
with open('models/word2int.json', 'r') as f:
    vocab2int = json.load(f)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def predict_lstm(text):
    """Classify a review with the attention-LSTM; return the label string."""
    # Fixed sequence length of 32 tokens, mapped through the shared vocabulary.
    token_ids = preprocess_single_string(text, 32, vocab2int).long()
    with torch.no_grad():
        # The model returns (logits, attention_weights); only logits matter here.
        logits, _ = model_lstm(token_ids.unsqueeze(0))
        class_idx = logits.argmax(dim=1).item()
    return labels_dict[class_idx]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def predict_bert(text):
    """Classify a review with the fine-tuned BERT; return the label string."""
    cleaned = data_preprocessing(text)
    encoded = autotoken(cleaned, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        logits = model_bert(**encoded).logits
    return labels_dict[logits.argmax(dim=1).item()]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def predict_lr(text):
    """Classify a review with the scikit-learn pipeline; return the label string."""
    # The pipeline expects an iterable of documents, hence the single-item list.
    class_idx = model_lr.predict([data_preprocessing(text)])[0]
    return labels_dict[class_idx]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
st.title('Аналитик отзывов на фильмы')
st.caption('От Серёжи')
st.divider()


# Submission form: batches the text area and button so the script only
# reruns the predictions on submit, not on every keystroke.
with st.form(key='Отзыв'):
    prompt = st.text_area("Ваш отзыв")
    submit = st.form_submit_button('Оценивай!')

if submit:
    # Score the same review with all three models side by side.
    ans_lstm = predict_lstm(prompt)
    ans_bert = predict_bert(prompt)
    ans_lr = predict_lr(prompt)

    col1, col2, col3 = st.columns(3)

    with col1:
        st.metric(label="LSTM Prediction", value=ans_lstm)
    with col2:
        st.metric(label="BERT Prediction", value=ans_bert)
    with col3:
        st.metric(label="Logistic Regression Prediction", value=ans_lr)
|
| 94 |
+
|