Kirill commited on
Commit
fe5559c
·
1 Parent(s): abc2c55

fix model load

Browse files
Files changed (2) hide show
  1. app.py +6 -2
  2. model_train.ipynb +0 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
 
3
- from transformers import AutoTokenizer, DistilBertForSequenceClassification
4
  import torch
5
  from torch.nn.functional import softmax
6
 
@@ -28,7 +28,11 @@ id_to_description = load_tags_info()
28
 
29
  @st.cache_resource
30
  def load_model():
31
- return DistilBertForSequenceClassification.from_pretrained('./')
 
 
 
 
32
 
33
  def load_tokenizer():
34
  return AutoTokenizer.from_pretrained('distilbert-base-uncased')
 
1
  import streamlit as st
2
 
3
+ from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertConfig
4
  import torch
5
  from torch.nn.functional import softmax
6
 
 
28
 
29
  @st.cache_resource
30
  def load_model():
31
+ config = DistilBertConfig.from_json_file('./config.json')
32
+ model = DistilBertForSequenceClassification(config)
33
+ state_dict = torch.load('./pytorch_model.bin')
34
+ model.load_state_dict(state_dict)
35
+ return model
36
 
37
  def load_tokenizer():
38
  return AutoTokenizer.from_pretrained('distilbert-base-uncased')
model_train.ipynb CHANGED
The diff for this file is too large to render. See raw diff