Update api/predict.py
Browse files- api/predict.py +10 -5
api/predict.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
import re
|
| 4 |
-
import os
|
| 5 |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
|
| 6 |
-
|
| 7 |
-
scriptDir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 8 |
-
modelsDir = os.path.join(scriptDir, "models")
|
| 9 |
|
| 10 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 11 |
tokenizer = None
|
|
@@ -26,12 +23,20 @@ def load_resources():
|
|
| 26 |
|
| 27 |
models = []
|
| 28 |
for i in range(1, 6):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
model = DistilBertForSequenceClassification.from_pretrained(
|
| 30 |
'distilbert-base-uncased',
|
| 31 |
num_labels=num_classes,
|
| 32 |
dropout=dropout
|
| 33 |
)
|
| 34 |
-
model.load_state_dict(torch.load(
|
| 35 |
model = model.to(device)
|
| 36 |
model.eval()
|
| 37 |
models.append(model)
|
|
|
|
| 1 |
import torch
|
| 2 |
import numpy as np
|
| 3 |
import re
|
|
|
|
| 4 |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
| 6 |
|
| 7 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 8 |
tokenizer = None
|
|
|
|
| 23 |
|
| 24 |
models = []
|
| 25 |
for i in range(1, 6):
|
| 26 |
+
model_filename = f"ensemble_model_{i}.pth"
|
| 27 |
+
|
| 28 |
+
print(f"downloading {model_filename}...")
|
| 29 |
+
model_path = hf_hub_download(
|
| 30 |
+
repo_id="codingcoolfun9ed/sentinelcheck-models",
|
| 31 |
+
filename=model_filename
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
model = DistilBertForSequenceClassification.from_pretrained(
|
| 35 |
'distilbert-base-uncased',
|
| 36 |
num_labels=num_classes,
|
| 37 |
dropout=dropout
|
| 38 |
)
|
| 39 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 40 |
model = model.to(device)
|
| 41 |
model.eval()
|
| 42 |
models.append(model)
|