| from pathlib import Path |
| import wandb |
|
|
|
|
def is_run_id(run_id: str) -> bool:
    """Heuristically decide whether ``run_id`` looks like a run ID.

    A string qualifies when it is exactly 8 characters long and every
    character is alphanumeric. Note that ``str.isalnum`` is Unicode-aware, so
    non-ASCII letters/digits are accepted as well.
    """
    if len(run_id) != 8:
        return False
    return run_id.isalnum()
|
|
|
|
def version_to_int(artifact) -> int:
    """Return the numeric part of an artifact's version string, e.g. ``"v12"`` -> ``12``.

    The first character of ``artifact.version`` (expected to be the ``v``
    prefix) is discarded unconditionally and the remainder is parsed with
    ``int``; a non-integer remainder raises ``ValueError``.
    """
    version_label = artifact.version
    numeric_part = version_label[1:]
    return int(numeric_part)
|
|
|
|
def download_latest_checkpoint(run_path: str, download_dir: Path) -> Path:
    """Download the highest-versioned committed model artifact of a W&B run.

    Args:
        run_path: Run path handed to ``wandb.Api().run`` (typically
            ``"entity/project/run_id"``).
        download_dir: Local base directory; the artifact's files are
            downloaded into ``download_dir / run_path``.

    Returns:
        Path to ``model.ckpt`` inside the download directory. The file's
        existence is not checked here; this assumes the artifact contains a
        file with that name.

    Raises:
        ValueError: If the run has no logged artifact with type ``"model"``
            in state ``"COMMITTED"``.
    """
    api = wandb.Api()
    run = api.run(run_path)

    # Only model artifacts in the COMMITTED state are considered (presumably
    # anything else may not be fully uploaded yet).
    candidates = [
        artifact
        for artifact in run.logged_artifacts()
        if artifact.type == "model" and artifact.state == "COMMITTED"
    ]
    if not candidates:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(
            f"Run {run_path!r} has no committed 'model' artifact to download"
        )

    # max() returns the first of several equally-versioned artifacts, which
    # matches the previous strict ">" comparison.
    # NOTE(review): versions are compared across *all* model artifacts of the
    # run; if a run logs several differently-named model artifacts their
    # version counters are independent -- confirm only one is logged.
    latest = max(candidates, key=version_to_int)

    root = download_dir / run_path
    # Creates download_dir as well as the nested run_path directories.
    root.mkdir(parents=True, exist_ok=True)
    latest.download(root=root)
    return root / "model.ckpt"
|
|