Spaces:
Sleeping
Sleeping
File size: 4,448 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 | import os
from pathlib import Path
import wandb
from pytorch_lightning.loggers.wandb import WandbLogger
from omegaconf import OmegaConf
from optgs.misc.LocalLogger import LocalLogger
from optgs.paths import DEBUG
def version_to_int(artifact) -> int:
    """Return the numeric part of an artifact's version tag.

    The tag is expected to look like ``"v<number>"``: its first character is
    dropped and the remainder is parsed, so ``"v12"`` yields ``12``.
    """
    number_text = artifact.version[1:]
    return int(number_text)
def download_checkpoint(
    run_id: str,
    download_dir: Path,
    version: str | None,
) -> Path:
    """Download a model checkpoint artifact logged by a wandb run.

    Args:
        run_id: Run path handed to ``wandb.Api().run`` (callers in this file
            pass ``"<project>/<run id>"``).
        download_dir: Base directory; the artifact is downloaded into
            ``download_dir / run_id``.
        version: Artifact version to fetch, compared verbatim against
            ``artifact.version`` (e.g. ``"v3"``). ``None`` selects the
            highest committed version.

    Returns:
        Path to ``model.ckpt`` inside the download directory (the artifact is
        assumed to contain a file of that name).

    Raises:
        ValueError: If the run has no committed ``model`` artifact matching
            ``version``.
    """
    api = wandb.Api()
    run = api.run(run_id)

    # Only committed model artifacts are candidates; lazily filtered so that a
    # specific-version lookup stops at the first match.
    committed_models = (
        artifact
        for artifact in run.logged_artifacts()
        if artifact.type == "model" and artifact.state == "COMMITTED"
    )
    if version is None:
        # Highest version wins; on ties max() keeps the first one seen.
        chosen = max(committed_models, key=version_to_int, default=None)
    else:
        chosen = next(
            (artifact for artifact in committed_models if artifact.version == version),
            None,
        )

    if chosen is None:
        wanted = "any version" if version is None else f"version {version!r}"
        raise ValueError(
            f"No committed model checkpoint ({wanted}) found for wandb run {run_id!r}."
        )

    # Download the checkpoint.
    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_id
    chosen.download(root=root)
    return root / "model.ckpt"
def _latest_wandb_run_id(output_dir) -> str | None:
    """Return the run id found in ``<output_dir>/wandb/latest-run``, or None.

    The id is taken from the single ``run-<id>.wandb`` file in that directory.

    Raises:
        RuntimeError: If more than one ``run-*.wandb`` file is present.
    """
    wandb_dir = Path(output_dir) / "wandb" / "latest-run"
    wandb_files = list(wandb_dir.glob("run-*.wandb"))
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if len(wandb_files) > 1:
        raise RuntimeError("Multiple wandb files found in the latest run directory.")
    if not wandb_files:
        return None
    # Strip only the leading "run-" so an id that itself contains "-" survives.
    return wandb_files[0].stem.split("-", 1)[1]


def setup_wandb_logger(cfg, cfg_dict) -> WandbLogger | LocalLogger:
    """Create the experiment logger for the current run.

    Returns a ``LocalLogger`` when ``cfg_dict.wandb.mode`` is ``"disabled"``
    or ``cfg.mode`` is not ``"train"``. Otherwise it returns a ``WandbLogger``
    configured from ``cfg_dict.wandb``.

    When ``cfg_dict.checkpointing.resume`` is set, the run is resumed with
    ``resume="must"``. The id comes from ``cfg_dict.wandb.id`` if set, and
    otherwise from the run file in ``<output_dir>/wandb/latest-run``. If no
    such file exists, a new run is started.

    Side effects:
        - May set ``tags`` (adding ``"debug"`` when ``DEBUG``) and ``entity``
          (from ``$WANDB_ENTITY``) in ``cfg_dict.wandb``.
        - Logs the ``optgs`` code directory.
        - If ``$WANDB_ID_FILE`` is set, writes the run id to that file so a
          requeued SLURM job can resume the same run.

    Args:
        cfg: Config object; only ``cfg.mode`` is read here.
        cfg_dict: OmegaConf config providing ``wandb``, ``checkpointing``,
            ``output_dir`` and ``log_slurm_id``.
    """
    if cfg_dict.wandb.mode == "disabled" or cfg.mode != "train":
        return LocalLogger()

    wandb_extra_kwargs = {}
    # Detect the wandb run id to resume.
    if cfg_dict.checkpointing.resume:
        if cfg_dict.wandb.id is None:
            print("Resuming wandb run without id, using latest run in output directory.")
            wandb_id = _latest_wandb_run_id(cfg_dict.output_dir)
            if wandb_id is not None:
                wandb_extra_kwargs.update({"id": wandb_id, "resume": "must"})
            else:
                print("No previous wandb run found in output directory, starting a new run.")
        else:
            print(f"Setting wandb run with id from cfg {cfg_dict.wandb.id}.")
            wandb_extra_kwargs.update({"id": cfg_dict.wandb.id, "resume": "must"})

    # Path.name ignores a trailing slash, unlike os.path.basename.
    run_name = Path(cfg_dict.output_dir).name
    if cfg_dict.log_slurm_id:
        hostname = os.uname().nodename
        job_id = os.environ.get("SLURM_JOB_ID", "local run: " + hostname)
        run_name += f" ({job_id})"

    # If debugging, mark both the run name and the tags.
    if DEBUG:
        run_name += " DEBUG"
        # Append rather than overwrite so user-configured tags are kept.
        tags = list(cfg_dict.wandb.get("tags", None) or [])
        if "debug" not in tags:
            tags.append("debug")
        cfg_dict.wandb.update({"tags": tags})

    entity = os.environ.get("WANDB_ENTITY")
    if entity is not None:
        cfg_dict.wandb.update({"entity": entity})

    logger = WandbLogger(
        entity=cfg_dict.wandb.entity,
        project=cfg_dict.wandb.project,
        mode=cfg_dict.wandb.mode,
        name=run_name,
        tags=cfg_dict.wandb.get("tags", None),
        log_model=False,
        save_dir=cfg_dict.output_dir,
        config=OmegaConf.to_container(cfg_dict),
        **wandb_extra_kwargs,
    )

    if logger.experiment is not None:
        # Log code
        logger.experiment.log_code("optgs")
        # Log notes
        if cfg_dict.wandb.notes is not None:
            logger.experiment.notes = cfg_dict.wandb.notes
        # Write wandb run ID to file for SLURM requeue resume
        wandb_id_file = os.environ.get("WANDB_ID_FILE")
        if wandb_id_file:
            with open(wandb_id_file, "w", encoding="utf-8") as f:
                f.write(logger.experiment.id)
            print(f"Wrote wandb run ID {logger.experiment.id} to {wandb_id_file}")

    return logger
def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None:
if path is None:
return None
if not str(path).startswith("wandb://"):
return Path(path)
run_id, *version = path[len("wandb://") :].split(":")
if len(version) == 0:
version = None
elif len(version) == 1:
version = version[0]
else:
raise ValueError("Invalid version specifier!")
project = wandb_cfg["project"]
return download_checkpoint(
f"{project}/{run_id}",
Path("checkpoints"),
version,
)
|