chitlchow commited on
Commit
643a0fa
·
1 Parent(s): eeb35fc

Changed code for loading model

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -1,13 +1,17 @@
1
  import gradio as gr
2
  import torch
3
  from src.model import BertClassifier, RobertaClassifier
 
 
 
 
4
  try:
5
- device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
6
- device = torch.device(device_name)
7
- model = RobertaClassifier(model_name='roberta-base', dropout_rate=0.5)
8
- model.to(device)
9
- model.load_state_dict(torch.load('models/roberta-wiki.pth', map_location=device))
10
- print('Model loaded')
11
  except:
12
  print('Model cannot be loaded')
13
 
@@ -18,4 +22,4 @@ def ai_text_classifier(text: str):
18
  }
19
 
20
  demo = gr.Interface(fn=ai_text_classifier, inputs="text", outputs="label")
21
- demo.launch(share=True)
 
1
  import gradio as gr
2
  import torch
3
  from src.model import BertClassifier, RobertaClassifier
4
+ from transformers import BertTokenizer
5
+ from datetime import datetime
6
+
7
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
  try:
9
+ model_name = 'bert-base-uncased'
10
+ model = BertClassifier(model_name, 0.5)
11
+ model.load_state_dict(torch.load('bert-all-data.pth', map_location=device))
12
+ tokenizer = BertTokenizer.from_pretrained(model_name)
13
+ print(f"[{datetime.now()}] Model loaded ")
14
+
15
  except:
16
  print('Model cannot be loaded')
17
 
 
22
  }
23
 
24
  demo = gr.Interface(fn=ai_text_classifier, inputs="text", outputs="label")
25
+ demo.launch()