codingcoolfun9ed commited on
Commit
b363dbf
·
verified ·
1 Parent(s): 5e99838

Update api/predict.py

Browse files
Files changed (1) hide show
  1. api/predict.py +10 -5
api/predict.py CHANGED
@@ -1,11 +1,8 @@
1
  import torch
2
  import numpy as np
3
  import re
4
- import os
5
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
6
-
7
- scriptDir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8
- modelsDir = os.path.join(scriptDir, "models")
9
 
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
  tokenizer = None
@@ -26,12 +23,20 @@ def load_resources():
26
 
27
  models = []
28
  for i in range(1, 6):
 
 
 
 
 
 
 
 
29
  model = DistilBertForSequenceClassification.from_pretrained(
30
  'distilbert-base-uncased',
31
  num_labels=num_classes,
32
  dropout=dropout
33
  )
34
- model.load_state_dict(torch.load(os.path.join(modelsDir, f"ensemble_model_{i}.pth"), map_location=device))
35
  model = model.to(device)
36
  model.eval()
37
  models.append(model)
 
1
  import torch
2
  import numpy as np
3
  import re
 
4
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
5
+ from huggingface_hub import hf_hub_download
 
 
6
 
7
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
  tokenizer = None
 
23
 
24
  models = []
25
  for i in range(1, 6):
26
+ model_filename = f"ensemble_model_{i}.pth"
27
+
28
+ print(f"downloading {model_filename}...")
29
+ model_path = hf_hub_download(
30
+ repo_id="codingcoolfun9ed/sentinelcheck-models",
31
+ filename=model_filename
32
+ )
33
+
34
  model = DistilBertForSequenceClassification.from_pretrained(
35
  'distilbert-base-uncased',
36
  num_labels=num_classes,
37
  dropout=dropout
38
  )
39
+ model.load_state_dict(torch.load(model_path, map_location=device))
40
  model = model.to(device)
41
  model.eval()
42
  models.append(model)