mirralz commited on
Commit
9bd8832
·
verified ·
1 Parent(s): e30c664

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -4,9 +4,11 @@ from transformers import BertTokenizer
4
  import gradio as gr
5
  from model import BertClassifier
6
 
7
- # Загрузка модели
8
- model = BertClassifier(dropout_rate=0.3, num_classes=4)
9
- model.load_state_dict(torch.load("bert_mc_dropout.pt", map_location="cpu"))
 
 
10
  model.eval()
11
 
12
  # Токенизатор
 
4
  import gradio as gr
5
  from model import BertClassifier
6
 
7
+ model = BertClassifier()
8
+ state_dict = torch.load("bert_mc_dropout.pt", map_location="cpu")
9
+
10
+ # подгружаем корректно
11
+ model.load_state_dict(state_dict, strict=False)
12
  model.eval()
13
 
14
  # Токенизатор