Upload mixture/train_embedding_router.py with huggingface_hub
Browse files
mixture/train_embedding_router.py
CHANGED
|
@@ -2607,9 +2607,21 @@ def save_complete_checkpoint(
|
|
| 2607 |
def _resolve_repo_path(path_str: str) -> Path:
|
| 2608 |
"""Resolve paths saved inside checkpoints relative to the repository root."""
|
| 2609 |
path = Path(path_str).expanduser()
|
| 2610 |
-
if
|
| 2611 |
-
|
| 2612 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2613 |
|
| 2614 |
|
| 2615 |
def _checkpoint_to_expert_specs(checkpoint: Mapping[str, Any]) -> List[ExpertSpec]:
|
|
|
|
| 2607 |
def _resolve_repo_path(path_str: str) -> Path:
|
| 2608 |
"""Resolve paths saved inside checkpoints relative to the repository root."""
|
| 2609 |
path = Path(path_str).expanduser()
|
| 2610 |
+
if path.is_absolute():
|
| 2611 |
+
if path.exists():
|
| 2612 |
+
return path
|
| 2613 |
+
# Fallback for checkpoints saved with absolute training paths
|
| 2614 |
+
repo_name = REPO_ROOT.name
|
| 2615 |
+
if repo_name in path.parts:
|
| 2616 |
+
try:
|
| 2617 |
+
repo_idx = path.parts.index(repo_name)
|
| 2618 |
+
candidate = REPO_ROOT.joinpath(*path.parts[repo_idx + 1 :])
|
| 2619 |
+
if candidate.exists():
|
| 2620 |
+
return candidate
|
| 2621 |
+
except ValueError:
|
| 2622 |
+
pass
|
| 2623 |
+
return path
|
| 2624 |
+
return (REPO_ROOT / path).resolve()
|
| 2625 |
|
| 2626 |
|
| 2627 |
def _checkpoint_to_expert_specs(checkpoint: Mapping[str, Any]) -> List[ExpertSpec]:
|