| | from pathlib import Path |
| |
|
| | import wandb |
| |
|
| |
|
def version_to_int(artifact) -> int:
    """Return the integer encoded in an artifact's ``version`` string.

    The first character of ``artifact.version`` (the ``"v"`` in ``"v12"``) is
    discarded and the remainder is parsed with ``int``, so ``"v12"`` -> ``12``.
    """
    tag: str = artifact.version
    digits = tag[1:]  # skip the leading version marker character
    return int(digits)
| |
|
| |
|
def download_checkpoint(
    run_id: str,
    download_dir: Path,
    version: str | None,
) -> Path:
    """Download a model checkpoint artifact logged by a W&B run.

    Only artifacts whose ``type`` is ``"model"`` and whose ``state`` is
    ``"COMMITTED"`` are considered. With ``version=None`` the artifact with the
    highest numeric version (as parsed by ``version_to_int``) is chosen;
    otherwise the artifact whose ``version`` string equals ``version`` exactly
    (e.g. ``"v3"``).

    Args:
        run_id: Run path handed to ``wandb.Api().run``.
        download_dir: Parent directory for downloads; created if missing.
        version: Exact artifact version to fetch, or ``None`` for the newest.

    Returns:
        Path to ``model.ckpt`` inside ``download_dir / run_id``.

    Raises:
        ValueError: If the run has no matching committed model artifact.
        FileNotFoundError: If the downloaded artifact contains no
            ``model.ckpt`` file.
    """
    run = wandb.Api().run(run_id)

    # Filter lazily so a lookup for a specific version stops iterating as soon
    # as the match is found.
    models = (
        artifact
        for artifact in run.logged_artifacts()
        if artifact.type == "model" and artifact.state == "COMMITTED"
    )
    if version is None:
        # max() returns the first of several equal maxima, so ties resolve to
        # the earliest-listed artifact.
        chosen = max(models, key=version_to_int, default=None)
    else:
        chosen = next((a for a in models if a.version == version), None)

    if chosen is None:
        # Without this check the failure is an opaque AttributeError on
        # ``None.download``.
        wanted = (
            "a committed model artifact"
            if version is None
            else f"a committed model artifact with version {version!r}"
        )
        raise ValueError(f"Run {run_id!r} has no {wanted}.")

    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_id
    chosen.download(root=str(root))

    checkpoint = root / "model.ckpt"
    if not checkpoint.is_file():
        raise FileNotFoundError(
            f"Artifact {chosen.version!r} of run {run_id!r} was downloaded to "
            f"{root} but contains no 'model.ckpt'."
        )
    return checkpoint
| |
|
| |
|
| | def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: |
| | if path is None: |
| | return None |
| |
|
| | if not str(path).startswith("wandb://"): |
| | return Path(path) |
| |
|
| | run_id, *version = path[len("wandb://") :].split(":") |
| | if len(version) == 0: |
| | version = None |
| | elif len(version) == 1: |
| | version = version[0] |
| | else: |
| | raise ValueError("Invalid version specifier!") |
| |
|
| | project = wandb_cfg["project"] |
| | return download_checkpoint( |
| | f"{project}/{run_id}", |
| | Path("checkpoints"), |
| | version, |
| | ) |
| |
|