Spaces:
Sleeping
Sleeping
Lukas Folle commited on
Commit ·
94e4f13
1
Parent(s): 0c23c71
Add compatibility function for huggingface_hub download
Browse files- DummyModel.py +10 -4
- Model.py +10 -4
DummyModel.py
CHANGED
|
@@ -4,13 +4,19 @@ import torch.nn
|
|
| 4 |
from huggingface_hub import hf_hub_download
|
| 5 |
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
def load_dummy_model(DEBUG):
|
| 8 |
model = DummyModel()
|
| 9 |
if not DEBUG:
|
| 10 |
-
file_path =
|
| 11 |
-
"lfolle/DeepNAPSIModel",
|
| 12 |
-
"dummy_model.pth",
|
| 13 |
-
use_auth_token=os.environ["DeepNAPSIModel"],
|
| 14 |
)
|
| 15 |
model.load_state_dict(torch.load(file_path))
|
| 16 |
return model
|
|
|
|
| 4 |
from huggingface_hub import hf_hub_download
|
| 5 |
|
| 6 |
|
| 7 |
+
def _hf_hub_download_compat(repo_id: str, filename: str, token: str) -> str:
|
| 8 |
+
try:
|
| 9 |
+
return hf_hub_download(repo_id, filename, token=token)
|
| 10 |
+
except TypeError:
|
| 11 |
+
# Backward compatibility for older huggingface_hub releases.
|
| 12 |
+
return hf_hub_download(repo_id, filename, use_auth_token=token)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
def load_dummy_model(DEBUG):
|
| 16 |
model = DummyModel()
|
| 17 |
if not DEBUG:
|
| 18 |
+
file_path = _hf_hub_download_compat(
|
| 19 |
+
"lfolle/DeepNAPSIModel", "dummy_model.pth", os.environ["DeepNAPSIModel"]
|
|
|
|
|
|
|
| 20 |
)
|
| 21 |
model.load_state_dict(torch.load(file_path))
|
| 22 |
return model
|
Model.py
CHANGED
|
@@ -3,6 +3,14 @@ from huggingface_hub import hf_hub_download
|
|
| 3 |
from nail_classification.inference import Inference
|
| 4 |
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
class Model:
|
| 7 |
def __init__(self, DEBUG):
|
| 8 |
if DEBUG:
|
|
@@ -10,10 +18,8 @@ class Model:
|
|
| 10 |
file_paths = [os.path.join(base, f"version_{v}") for v in range(10, 15)]
|
| 11 |
else:
|
| 12 |
file_paths = [
|
| 13 |
-
|
| 14 |
-
"lfolle/DeepNAPSIModel",
|
| 15 |
-
f"version_{v}.ckpt",
|
| 16 |
-
use_auth_token=os.environ["DeepNAPSIModel"],
|
| 17 |
)
|
| 18 |
for v in [10, 11, 12, 13, 14]
|
| 19 |
]
|
|
|
|
| 3 |
from nail_classification.inference import Inference
|
| 4 |
|
| 5 |
|
| 6 |
+
def _hf_hub_download_compat(repo_id: str, filename: str, token: str) -> str:
|
| 7 |
+
try:
|
| 8 |
+
return hf_hub_download(repo_id, filename, token=token)
|
| 9 |
+
except TypeError:
|
| 10 |
+
# Backward compatibility for older huggingface_hub releases.
|
| 11 |
+
return hf_hub_download(repo_id, filename, use_auth_token=token)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
class Model:
|
| 15 |
def __init__(self, DEBUG):
|
| 16 |
if DEBUG:
|
|
|
|
| 18 |
file_paths = [os.path.join(base, f"version_{v}") for v in range(10, 15)]
|
| 19 |
else:
|
| 20 |
file_paths = [
|
| 21 |
+
_hf_hub_download_compat(
|
| 22 |
+
"lfolle/DeepNAPSIModel", f"version_{v}.ckpt", os.environ["DeepNAPSIModel"]
|
|
|
|
|
|
|
| 23 |
)
|
| 24 |
for v in [10, 11, 12, 13, 14]
|
| 25 |
]
|