Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import wandb | |
def version_to_int(artifact) -> int:
    """Return the numeric part of a W&B artifact version string.

    The version is expected to look like ``"v<N>"``: the first character is
    dropped and the remainder is parsed as an integer, e.g. ``"v12"`` -> ``12``.
    """
    number_text = artifact.version[1:]
    return int(number_text)
def download_checkpoint(
    run_id: str,
    download_dir: Path,
    version: str | None,
) -> Path:
    """Download a model checkpoint artifact logged by a W&B run.

    Args:
        run_id: Run path passed to ``wandb.Api().run`` (callers in this file
            pass ``"<project>/<run_id>"``).
        download_dir: Base directory for downloads; created if missing. The
            artifact is downloaded into ``download_dir / run_id``.
        version: Artifact version to fetch (e.g. ``"v3"``), compared verbatim
            with ``artifact.version``. ``None`` selects the committed model
            artifact with the highest version number.

    Returns:
        ``download_dir / run_id / "model.ckpt"``. This assumes the artifact
        stores its checkpoint under that file name; the path is not checked.

    Raises:
        ValueError: If the run has no committed ``"model"`` artifact matching
            ``version``.
    """
    api = wandb.Api()
    run = api.run(run_id)

    # Only fully committed artifacts of type "model" are candidates.
    chosen = None
    for artifact in run.logged_artifacts():
        if artifact.type != "model" or artifact.state != "COMMITTED":
            continue
        if version is None:
            # No version requested: keep the highest-numbered one seen so far.
            if chosen is None or version_to_int(artifact) > version_to_int(chosen):
                chosen = artifact
        elif artifact.version == version:
            # Specific version requested: stop at the first exact match.
            chosen = artifact
            break

    # Without this check the download below would fail with an opaque
    # AttributeError on None; report what was actually missing instead.
    if chosen is None:
        wanted = "any version" if version is None else f"version {version!r}"
        raise ValueError(
            f"No committed model artifact ({wanted}) found for run {run_id!r}."
        )

    # Download the checkpoint.
    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_id
    chosen.download(root=root)
    return root / "model.ckpt"
| def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: | |
| if path is None: | |
| return None | |
| if not str(path).startswith("wandb://"): | |
| return Path(path) | |
| run_id, *version = path[len("wandb://") :].split(":") | |
| if len(version) == 0: | |
| version = None | |
| elif len(version) == 1: | |
| version = version[0] | |
| else: | |
| raise ValueError("Invalid version specifier!") | |
| project = wandb_cfg["project"] | |
| return download_checkpoint( | |
| f"{project}/{run_id}", | |
| Path("checkpoints"), | |
| version, | |
| ) | |