Spaces:
Sleeping
Sleeping
| import os | |
| from pathlib import Path | |
| import wandb | |
| from pytorch_lightning.loggers.wandb import WandbLogger | |
| from omegaconf import OmegaConf | |
| from optgs.misc.LocalLogger import LocalLogger | |
| from optgs.paths import DEBUG | |
def version_to_int(artifact) -> int:
    """Return the numeric part of an artifact's version string.

    Versions are expected in the form ``vX``: the leading character is
    dropped and the remainder is parsed as an int (``"v12"`` -> ``12``).
    """
    numeric_part = artifact.version[1:]
    return int(numeric_part)
def download_checkpoint(
    run_id: str,
    download_dir: Path,
    version: str | None,
) -> Path:
    """Download a committed model-checkpoint artifact logged by a wandb run.

    Args:
        run_id: Run path passed to ``wandb.Api().run`` (callers in this file
            pass ``"<project>/<run_id>"``).
        download_dir: Local directory under which the artifact is stored;
            created (with parents) if it does not exist.
        version: Artifact version string such as ``"v3"``. ``None`` selects the
            committed ``model`` artifact with the highest version number.

    Returns:
        Path to ``model.ckpt`` inside ``download_dir / run_id``.

    Raises:
        ValueError: If the run has no committed ``model`` artifact matching
            ``version``.
    """
    api = wandb.Api()
    run = api.run(run_id)

    chosen = None
    for artifact in run.logged_artifacts():
        # Only fully committed model artifacts are candidates.
        if artifact.type != "model" or artifact.state != "COMMITTED":
            continue
        if version is None:
            # No version requested: keep the highest-numbered one seen so far.
            if chosen is None or version_to_int(artifact) > version_to_int(chosen):
                chosen = artifact
        elif version == artifact.version:
            # A specific version was requested: stop at the first match.
            chosen = artifact
            break

    # Previously this fell through to ``None.download(...)`` and crashed with
    # an unhelpful AttributeError.
    if chosen is None:
        wanted = "any version" if version is None else f"version {version!r}"
        raise ValueError(
            f"No committed 'model' artifact ({wanted}) found for wandb run {run_id!r}."
        )

    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_id
    chosen.download(root=root)
    return root / "model.ckpt"
def _resume_kwargs(cfg_dict) -> dict:
    """Return the extra ``WandbLogger`` kwargs needed to resume a previous wandb run.

    Returns an empty dict when ``cfg_dict.checkpointing.resume`` is falsy or no
    previous run can be located, in which case a new run is started.
    An explicit ``cfg_dict.wandb.id`` takes precedence. Otherwise the id is
    recovered from the single ``run-<id>.wandb`` file in
    ``<output_dir>/wandb/latest-run``.

    Raises:
        RuntimeError: If more than one ``run-*.wandb`` file is found.
    """
    if not cfg_dict.checkpointing.resume:
        return {}

    if cfg_dict.wandb.id is not None:
        print(f"Setting wandb run with id from cfg {cfg_dict.wandb.id}.")
        return {"id": cfg_dict.wandb.id, "resume": "must"}

    print("Resuming wandb run without id, using latest run in output directory.")
    # Path(...) so this also works when output_dir is stored as a plain string.
    latest_run_dir = Path(cfg_dict.output_dir) / "wandb" / "latest-run"
    run_files = list(latest_run_dir.glob("run-*.wandb"))
    if len(run_files) > 1:
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        raise RuntimeError("Multiple wandb files found in the latest run directory.")
    if not run_files:
        print(f"No previous wandb run found in {latest_run_dir}; starting a new run.")
        return {}
    # File stem is "run-<id>"; split once so ids containing "-" stay intact.
    wandb_id = run_files[0].stem.split("-", 1)[1]
    return {"id": wandb_id, "resume": "must"}


def setup_wandb_logger(cfg, cfg_dict) -> WandbLogger | LocalLogger:
    """Create the experiment logger for a run.

    Returns a ``LocalLogger`` when ``cfg_dict.wandb.mode == "disabled"`` or
    ``cfg.mode != "train"``. Otherwise it builds a ``WandbLogger``:

    - When ``cfg_dict.checkpointing.resume`` is set, it resumes the run given by
      ``cfg_dict.wandb.id`` or by the latest local run in the output directory.
    - The run name is the output directory's name. The SLURM job id (or the
      local hostname) is appended when ``cfg_dict.log_slurm_id`` is set, and
      " DEBUG" is appended when ``DEBUG`` is set.

    Side effects:
        - Mutates ``cfg_dict.wandb``: adds a ``"debug"`` tag when ``DEBUG`` is
          set, and overrides ``entity`` from ``$WANDB_ENTITY`` when present.
        - Logs the ``optgs`` code directory and sets run notes from
          ``cfg_dict.wandb.notes``.
        - Writes the run id to the file named by ``$WANDB_ID_FILE``, if set,
          so a requeued SLURM job can resume the same run.

    Raises:
        RuntimeError: If resuming finds more than one local wandb run file.
    """
    if cfg_dict.wandb.mode == "disabled" or cfg.mode != "train":
        return LocalLogger()

    wandb_extra_kwargs = _resume_kwargs(cfg_dict)

    # Path(...).name also handles a trailing slash, unlike os.path.basename.
    run_name = Path(cfg_dict.output_dir).name
    if cfg_dict.log_slurm_id:
        job_id = os.environ.get("SLURM_JOB_ID")
        if job_id is None:
            job_id = "local run: " + os.uname().nodename
        run_name += f" ({job_id})"

    if DEBUG:
        run_name += " DEBUG"
        # Append rather than overwrite, so tags configured by the user survive.
        tags = list(cfg_dict.wandb.get("tags", None) or [])
        if "debug" not in tags:
            tags.append("debug")
        cfg_dict.wandb.update({"tags": tags})

    entity = os.environ.get("WANDB_ENTITY")
    if entity is not None:
        cfg_dict.wandb.update({"entity": entity})

    logger = WandbLogger(
        entity=cfg_dict.wandb.entity,
        project=cfg_dict.wandb.project,
        mode=cfg_dict.wandb.mode,
        name=run_name,
        tags=cfg_dict.wandb.get("tags", None),
        log_model=False,
        save_dir=cfg_dict.output_dir,
        config=OmegaConf.to_container(cfg_dict),
        **wandb_extra_kwargs,
    )

    experiment = logger.experiment
    if experiment is not None:
        experiment.log_code("optgs")
        if cfg_dict.wandb.notes is not None:
            experiment.notes = cfg_dict.wandb.notes
        # Persist the run id so a requeued SLURM job can resume the same run.
        wandb_id_file = os.environ.get("WANDB_ID_FILE")
        if wandb_id_file:
            with open(wandb_id_file, "w", encoding="utf-8") as f:
                f.write(experiment.id)
            print(f"Wrote wandb run ID {experiment.id} to {wandb_id_file}")

    return logger
| def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: | |
| if path is None: | |
| return None | |
| if not str(path).startswith("wandb://"): | |
| return Path(path) | |
| run_id, *version = path[len("wandb://") :].split(":") | |
| if len(version) == 0: | |
| version = None | |
| elif len(version) == 1: | |
| version = version[0] | |
| else: | |
| raise ValueError("Invalid version specifier!") | |
| project = wandb_cfg["project"] | |
| return download_checkpoint( | |
| f"{project}/{run_id}", | |
| Path("checkpoints"), | |
| version, | |
| ) | |