AshenR commited on
Commit
73337c1
·
verified ·
1 Parent(s): 38be73f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -11,10 +11,14 @@ bert_out_address = 'model/'
11
 
12
  # Load the configuration file
13
  config = BertConfig.from_json_file(os.path.join(bert_out_address, "config.json"))
 
14
 
15
  # Load the pre-trained model's weights for sequence classification
16
  model = BertForSequenceClassification(config)
17
- model.load_state_dict(torch.load(os.path.join(bert_out_address, "pytorch_model.bin")))
 
 
 
18
 
19
  # Load the tokenizer
20
  tokenizer = BertTokenizer.from_pretrained(bert_out_address)
 
11
 
12
  # Load the configuration file
13
  config = BertConfig.from_json_file(os.path.join(bert_out_address, "config.json"))
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
 
16
  # Load the pre-trained model's weights for sequence classification
17
  model = BertForSequenceClassification(config)
18
+ # model.load_state_dict(torch.load(os.path.join(bert_out_address, "pytorch_model.bin")))
19
+
20
+ model.load_state_dict(torch.load(os.path.join(bert_out_address, "pytorch_model.bin"), map_location=torch.device(device)))
21
+
22
 
23
  # Load the tokenizer
24
  tokenizer = BertTokenizer.from_pretrained(bert_out_address)