Lukas Folle commited on
Commit
94e4f13
·
1 Parent(s): 0c23c71

Add compatibility function for huggingface_hub download

Browse files
Files changed (2) hide show
  1. DummyModel.py +10 -4
  2. Model.py +10 -4
DummyModel.py CHANGED
@@ -4,13 +4,19 @@ import torch.nn
4
  from huggingface_hub import hf_hub_download
5
 
6
 
 
 
 
 
 
 
 
 
7
  def load_dummy_model(DEBUG):
8
  model = DummyModel()
9
  if not DEBUG:
10
- file_path = hf_hub_download(
11
- "lfolle/DeepNAPSIModel",
12
- "dummy_model.pth",
13
- use_auth_token=os.environ["DeepNAPSIModel"],
14
  )
15
  model.load_state_dict(torch.load(file_path))
16
  return model
 
4
  from huggingface_hub import hf_hub_download
5
 
6
 
7
+ def _hf_hub_download_compat(repo_id: str, filename: str, token: str) -> str:
8
+ try:
9
+ return hf_hub_download(repo_id, filename, token=token)
10
+ except TypeError:
11
+ # Backward compatibility for older huggingface_hub releases.
12
+ return hf_hub_download(repo_id, filename, use_auth_token=token)
13
+
14
+
15
  def load_dummy_model(DEBUG):
16
  model = DummyModel()
17
  if not DEBUG:
18
+ file_path = _hf_hub_download_compat(
19
+ "lfolle/DeepNAPSIModel", "dummy_model.pth", os.environ["DeepNAPSIModel"]
 
 
20
  )
21
  model.load_state_dict(torch.load(file_path))
22
  return model
Model.py CHANGED
@@ -3,6 +3,14 @@ from huggingface_hub import hf_hub_download
3
  from nail_classification.inference import Inference
4
 
5
 
 
 
 
 
 
 
 
 
6
  class Model:
7
  def __init__(self, DEBUG):
8
  if DEBUG:
@@ -10,10 +18,8 @@ class Model:
10
  file_paths = [os.path.join(base, f"version_{v}") for v in range(10, 15)]
11
  else:
12
  file_paths = [
13
- hf_hub_download(
14
- "lfolle/DeepNAPSIModel",
15
- f"version_{v}.ckpt",
16
- use_auth_token=os.environ["DeepNAPSIModel"],
17
  )
18
  for v in [10, 11, 12, 13, 14]
19
  ]
 
3
  from nail_classification.inference import Inference
4
 
5
 
6
+ def _hf_hub_download_compat(repo_id: str, filename: str, token: str) -> str:
7
+ try:
8
+ return hf_hub_download(repo_id, filename, token=token)
9
+ except TypeError:
10
+ # Backward compatibility for older huggingface_hub releases.
11
+ return hf_hub_download(repo_id, filename, use_auth_token=token)
12
+
13
+
14
  class Model:
15
  def __init__(self, DEBUG):
16
  if DEBUG:
 
18
  file_paths = [os.path.join(base, f"version_{v}") for v in range(10, 15)]
19
  else:
20
  file_paths = [
21
+ _hf_hub_download_compat(
22
+ "lfolle/DeepNAPSIModel", f"version_{v}.ckpt", os.environ["DeepNAPSIModel"]
 
 
23
  )
24
  for v in [10, 11, 12, 13, 14]
25
  ]