Upload eigen_moe.py with huggingface_hub
Browse files- eigen_moe.py +12 -12
eigen_moe.py
CHANGED
|
@@ -390,13 +390,13 @@ class HFEigenMoE(nn.Module, PyTorchModelHubMixin):
|
|
| 390 |
vit_model_name: str,
|
| 391 |
) -> str:
|
| 392 |
base = Path(model_dir)
|
| 393 |
-
candidates = []
|
| 394 |
if checkpoint_filename:
|
| 395 |
-
candidates
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
|
|
|
| 400 |
|
| 401 |
for filename in candidates:
|
| 402 |
path = base / filename
|
|
@@ -429,13 +429,13 @@ class HFEigenMoE(nn.Module, PyTorchModelHubMixin):
|
|
| 429 |
if hf_hub_download is None:
|
| 430 |
raise ImportError("huggingface_hub is required to download checkpoints from the Hub.")
|
| 431 |
|
| 432 |
-
candidates = []
|
| 433 |
if checkpoint_filename:
|
| 434 |
-
candidates
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
|
|
|
| 439 |
|
| 440 |
seen = set()
|
| 441 |
unique_candidates = []
|
|
|
|
| 390 |
vit_model_name: str,
|
| 391 |
) -> str:
|
| 392 |
base = Path(model_dir)
|
|
|
|
| 393 |
if checkpoint_filename:
|
| 394 |
+
candidates = [checkpoint_filename]
|
| 395 |
+
else:
|
| 396 |
+
candidates = ["model.safetensors", "pytorch_model.bin"]
|
| 397 |
+
default_name = default_hub_checkpoint_filename(vit_model_name)
|
| 398 |
+
if default_name:
|
| 399 |
+
candidates.append(default_name)
|
| 400 |
|
| 401 |
for filename in candidates:
|
| 402 |
path = base / filename
|
|
|
|
| 429 |
if hf_hub_download is None:
|
| 430 |
raise ImportError("huggingface_hub is required to download checkpoints from the Hub.")
|
| 431 |
|
|
|
|
| 432 |
if checkpoint_filename:
|
| 433 |
+
candidates = [checkpoint_filename]
|
| 434 |
+
else:
|
| 435 |
+
candidates = ["model.safetensors", "pytorch_model.bin"]
|
| 436 |
+
default_name = default_hub_checkpoint_filename(vit_model_name)
|
| 437 |
+
if default_name:
|
| 438 |
+
candidates.append(default_name)
|
| 439 |
|
| 440 |
seen = set()
|
| 441 |
unique_candidates = []
|