anzheCheng commited on
Commit
588e364
·
verified ·
1 Parent(s): 19ccf99

Upload eigen_moe.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eigen_moe.py +12 -12
eigen_moe.py CHANGED
@@ -390,13 +390,13 @@ class HFEigenMoE(nn.Module, PyTorchModelHubMixin):
390
  vit_model_name: str,
391
  ) -> str:
392
  base = Path(model_dir)
393
- candidates = []
394
  if checkpoint_filename:
395
- candidates.append(checkpoint_filename)
396
- default_name = default_hub_checkpoint_filename(vit_model_name)
397
- if default_name:
398
- candidates.append(default_name)
399
- candidates.extend(["model.safetensors", "pytorch_model.bin"])
 
400
 
401
  for filename in candidates:
402
  path = base / filename
@@ -429,13 +429,13 @@ class HFEigenMoE(nn.Module, PyTorchModelHubMixin):
429
  if hf_hub_download is None:
430
  raise ImportError("huggingface_hub is required to download checkpoints from the Hub.")
431
 
432
- candidates = []
433
  if checkpoint_filename:
434
- candidates.append(checkpoint_filename)
435
- default_name = default_hub_checkpoint_filename(vit_model_name)
436
- if default_name:
437
- candidates.append(default_name)
438
- candidates.extend(["model.safetensors", "pytorch_model.bin"])
 
439
 
440
  seen = set()
441
  unique_candidates = []
 
390
  vit_model_name: str,
391
  ) -> str:
392
  base = Path(model_dir)
 
393
  if checkpoint_filename:
394
+ candidates = [checkpoint_filename]
395
+ else:
396
+ candidates = ["model.safetensors", "pytorch_model.bin"]
397
+ default_name = default_hub_checkpoint_filename(vit_model_name)
398
+ if default_name:
399
+ candidates.append(default_name)
400
 
401
  for filename in candidates:
402
  path = base / filename
 
429
  if hf_hub_download is None:
430
  raise ImportError("huggingface_hub is required to download checkpoints from the Hub.")
431
 
 
432
  if checkpoint_filename:
433
+ candidates = [checkpoint_filename]
434
+ else:
435
+ candidates = ["model.safetensors", "pytorch_model.bin"]
436
+ default_name = default_hub_checkpoint_filename(vit_model_name)
437
+ if default_name:
438
+ candidates.append(default_name)
439
 
440
  seen = set()
441
  unique_candidates = []