| """ |
| This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research |
| template [repo](https://github.com/buoyancy99/research-template). |
| By its MIT license, you must keep the above sentence in `README.md` |
| and the `LICENSE` file to credit the author. |
| |
| Main file for the project. This will create and run new experiments and load checkpoints from wandb. |
| Borrowed part of the code from David Charatan and wandb. |
| """ |
|
|
| import os |
| import sys |
| import subprocess |
| import time |
| import re |
| from pathlib import Path |
|
|
| import hydra |
| from omegaconf import DictConfig, OmegaConf |
| from omegaconf.omegaconf import open_dict |
|
|
| from utils.print_utils import cyan |
| from utils.ckpt_utils import download_latest_checkpoint, is_run_id |
| from utils.cluster_utils import submit_slurm_job |
| from utils.distributed_utils import is_rank_zero |
|
|
# Name of the file, placed directly inside a run's output directory, that
# stores the W&B run id for that directory.
WANDB_RUN_ID_FILE = ".wandb_run_id"
|
|
|
|
def get_latest_checkpoint(checkpoint_folder: Path, pattern: str = "*.ckpt"):
    """Pick the checkpoint file to resume from inside ``checkpoint_folder``.

    Only regular, non-empty files matching ``pattern`` are considered.
    ``last.ckpt`` wins whenever it is among them. Otherwise the file with the
    largest step number parsed from its stem is chosen, and modification time
    breaks ties between equal step numbers.

    Args:
        checkpoint_folder: Directory to search (not searched recursively).
        pattern: Glob pattern applied inside ``checkpoint_folder``.

    Returns:
        The selected checkpoint path, or ``None`` when the folder does not
        exist or contains no usable match.
    """
    if not checkpoint_folder.exists():
        return None

    candidates = [
        entry
        for entry in checkpoint_folder.glob(pattern)
        if entry.is_file() and entry.stat().st_size > 0
    ]
    if len(candidates) == 0:
        return None

    preferred = checkpoint_folder / "last.ckpt"
    if preferred in candidates:
        return preferred

    step_pattern = re.compile(r"step[=_-]?(\d+)")

    def ranking(candidate: Path):
        # Stems without a step number sort below every numbered checkpoint.
        found = step_pattern.search(candidate.stem)
        step = -1 if found is None else int(found.group(1))
        return (step, candidate.stat().st_mtime)

    return max(candidates, key=ranking)
|
|
|
|
def validate_resume_checkpoint(checkpoint_path: Path) -> Path:
    """Check that ``checkpoint_path`` is a usable checkpoint and return it.

    Args:
        checkpoint_path: Path of the checkpoint to resume from.

    Returns:
        ``checkpoint_path`` unchanged, once every check has passed.

    Raises:
        FileNotFoundError: The path does not exist.
        ValueError: The path is not a regular file, does not end in ``.ckpt``,
            or is a zero-byte file.
    """
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Resume checkpoint does not exist: {checkpoint_path}")

    # The checks run in this fixed order; the first failing one decides the message.
    if not checkpoint_path.is_file():
        complaint = f"Resume checkpoint is not a file: {checkpoint_path}"
    elif checkpoint_path.suffix != ".ckpt":
        complaint = f"Resume checkpoint must be a .ckpt file: {checkpoint_path}"
    elif checkpoint_path.stat().st_size == 0:
        complaint = f"Resume checkpoint is empty: {checkpoint_path}"
    else:
        return checkpoint_path
    raise ValueError(complaint)
|
|
|
|
def discover_wandb_run_id(output_dir: Path):
    """Look up a W&B run id already associated with ``output_dir``.

    The run id file inside ``output_dir`` is consulted first. When it is
    absent, the most recently modified ``run-*`` / ``offline-run-*``
    directory under ``output_dir / "wandb"`` is used and the part of its name
    after the last ``-`` is taken as the id.

    Args:
        output_dir: Output directory of the run.

    Returns:
        The run id as a string, or ``None`` when nothing is found.

    Raises:
        ValueError: The run id file exists but contains only whitespace.
    """
    id_file = output_dir / WANDB_RUN_ID_FILE
    if id_file.exists():
        stored_id = id_file.read_text().strip()
        if stored_id:
            return stored_id
        raise ValueError(f"W&B run id file is empty: {id_file}")

    wandb_root = output_dir / "wandb"
    if not wandb_root.exists():
        return None

    run_dir_pattern = re.compile(r"(offline-run|run)-.+-[A-Za-z0-9]+$")
    candidates = [
        entry
        for entry in wandb_root.iterdir()
        if entry.is_dir() and run_dir_pattern.match(entry.name)
    ]
    if not candidates:
        return None

    newest = max(candidates, key=lambda entry: entry.stat().st_mtime)
    _, _, trailing_id = newest.name.rpartition("-")
    return trailing_id
|
|
|
|
def get_process_rank() -> int:
    """Return this process's rank as read from the environment.

    ``RANK``, ``SLURM_PROCID`` and ``LOCAL_RANK`` are consulted in that order;
    the first one that is set is converted with ``int``. When none is set the
    rank is 0.

    Raises:
        ValueError: The first variable that is set does not parse as an int.
    """
    lookups = (os.environ.get(name) for name in ("RANK", "SLURM_PROCID", "LOCAL_RANK"))
    first_set = next((value for value in lookups if value is not None), None)
    return 0 if first_set is None else int(first_set)
|
|
|
|
def wait_for_wandb_run_id(output_dir: Path, timeout_s: float = 300.0, poll_interval_s: float = 0.5):
    """Block until the run id file in ``output_dir`` holds a non-empty id.

    The file is always read at least once, so an id that is already present
    is returned even when ``timeout_s`` is zero or negative.

    Args:
        output_dir: Output directory in which the run id file is expected.
        timeout_s: Maximum number of seconds to keep polling.
        poll_interval_s: Seconds to sleep between two reads of the file.

    Returns:
        The stripped contents of the run id file.

    Raises:
        TimeoutError: No non-empty id appeared within ``timeout_s`` seconds.
    """
    run_id_file = output_dir / WANDB_RUN_ID_FILE
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            run_id = run_id_file.read_text().strip()
        except FileNotFoundError:
            # Not written yet (or removed between polls): keep waiting.
            run_id = ""
        if run_id:
            return run_id

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"Timed out waiting for rank 0 to create W&B run id file: {run_id_file}")
        # Never sleep past the deadline.
        time.sleep(min(poll_interval_s, remaining))
|
|
|
|
def create_wandb_run_id_on_rank_zero(output_dir: Path, requested_run_id=None):
    """Decide the W&B run id for ``output_dir`` and persist it in the run id file.

    Precedence: ``requested_run_id`` if truthy, else whatever
    ``discover_wandb_run_id`` finds for ``output_dir``, else a fresh id from
    ``wandb.util.generate_id()``.

    Args:
        output_dir: Output directory of the run; created if missing.
        requested_run_id: Run id explicitly asked for by the caller, if any.
            It is normalised with ``str`` so that it compares equal to the
            text stored in the run id file.

    Returns:
        The run id as a string.

    Raises:
        ValueError: The run id file already holds a different, non-empty id.
            ``discover_wandb_run_id`` may also raise it for an empty file
            when no id was requested.
    """
    if requested_run_id:
        # A config value such as an all-digit id may arrive as an int; the
        # file always stores text, so compare and return text.
        run_id = str(requested_run_id)
    else:
        run_id = discover_wandb_run_id(output_dir)
        if run_id is None:
            import wandb

            run_id = wandb.util.generate_id()

    output_dir.mkdir(parents=True, exist_ok=True)
    run_id_file = output_dir / WANDB_RUN_ID_FILE
    existing_run_id = run_id_file.read_text().strip() if run_id_file.exists() else ""
    if existing_run_id and existing_run_id != run_id:
        raise ValueError(
            f"Output directory already belongs to W&B run id {existing_run_id}, "
            f"but {run_id} was requested. Use a different output_dir or resume id."
        )
    if existing_run_id == run_id:
        # Already published: do not truncate a file other ranks may be polling.
        return run_id

    # Write to a sibling temp file and rename it into place, so a polling
    # reader sees either no file or the complete id, never a partial write.
    temp_file = run_id_file.with_name(f"{run_id_file.name}.{os.getpid()}.tmp")
    temp_file.write_text(f"{run_id}\n")
    os.replace(temp_file, run_id_file)
    return run_id
|
|
|
|
def get_or_create_wandb_run_id(output_dir: Path, requested_run_id=None):
    """Return the W&B run id for ``output_dir``, creating it on rank 0 only.

    The process whose ``get_process_rank()`` is 0 decides and persists the id
    through ``create_wandb_run_id_on_rank_zero``. Every other process waits
    for it with ``wait_for_wandb_run_id`` and ignores ``requested_run_id``.

    Args:
        output_dir: Output directory of the run.
        requested_run_id: Run id explicitly asked for, used on rank 0 only.

    Returns:
        The run id.
    """
    if get_process_rank() != 0:
        return wait_for_wandb_run_id(output_dir)
    return create_wandb_run_id_on_rank_zero(output_dir, requested_run_id=requested_run_id)
|
|
|
|
def run_local(cfg: DictConfig):
    """Run every task listed in ``cfg.experiment.tasks`` in the current process.

    Steps, in order: seed (when ``cfg.seed`` is set), record the selected
    Hydra config-group names on the config, settle the output directory,
    choose a checkpoint to resume training from, build the W&B logger (unless
    ``cfg.wandb.mode`` is ``"disabled"``), resolve the checkpoint to load, then
    build the experiment and execute its tasks.

    Config keys read here: ``seed``, ``output_dir``, ``auto_resume``,
    ``resume_ckpt_path``, ``resume``, ``load``, ``name``, ``wandb.*``,
    ``experiment.tasks`` and, on compute nodes,
    ``cluster.is_compute_node_offline``.

    Args:
        cfg: The composed Hydra config. It is mutated: ``_name`` is set on the
            experiment/dataset/algorithm groups, and ``_auto_resuming`` and
            ``_resume_checkpoint_path`` are added at the top level.

    Raises:
        FileNotFoundError: A resume checkpoint path does not exist, or a run
            id was given via ``resume``/``load`` but the output directory has
            no checkpoint.
        ValueError: A resume checkpoint fails validation, or the run id file
            conflicts with the requested id.
    """
    # Imported here rather than at module level; `run_slurm` never uses these names.
    from experiments import build_experiment
    from utils.wandb_utils import OfflineWandbLogger, SpaceEfficientWandbLogger
    import lightning.pytorch as pl

    if cfg.get("seed", None) is not None:
        pl.seed_everything(cfg.seed, workers=True)

    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    # Store which option of each config group was selected under `_name`.
    with open_dict(cfg):
        for group in ("experiment", "dataset", "algorithm"):
            if cfg_choice[group] is not None:
                cfg[group]._name = cfg_choice[group]

    # An `output_dir` entry in the config replaces Hydra's own output directory.
    output_dir_override = getattr(cfg, "output_dir", None)
    if output_dir_override is not None:
        OmegaConf.set_readonly(hydra_cfg, False)
        hydra_cfg.runtime.output_dir = output_dir_override
        OmegaConf.set_readonly(hydra_cfg, True)
    output_dir = Path(hydra_cfg.runtime.output_dir)

    # NOTE(review): `is_rank_zero` is used as a truthy value, not called —
    # confirm it is a bool in utils.distributed_utils.
    if is_rank_zero:
        print(cyan("Outputs will be saved to:"), output_dir)
        # An overridden output_dir may not exist yet; without it (and its
        # parents) the symlink below cannot be created.
        output_dir.mkdir(parents=True, exist_ok=True)
        # Use the absolute path for both the link location and its target so
        # a relative output_dir cannot produce a dangling link.
        resolved_output_dir = output_dir.resolve()
        latest_run_link = resolved_output_dir.parents[1] / "latest-run"
        latest_run_link.unlink(missing_ok=True)
        latest_run_link.symlink_to(resolved_output_dir, target_is_directory=True)

    # Choose the checkpoint to resume training from: an explicit path wins,
    # otherwise the latest checkpoint already present in the output directory.
    training_requested = "training" in cfg.experiment.tasks
    checkpoint_dir = output_dir / "checkpoints"
    auto_resume = bool(getattr(cfg, "auto_resume", True))
    explicit_resume_ckpt = getattr(cfg, "resume_ckpt_path", None)
    auto_resume_checkpoint_path = None
    if training_requested:
        if explicit_resume_ckpt:
            auto_resume_checkpoint_path = validate_resume_checkpoint(Path(explicit_resume_ckpt))
        elif auto_resume:
            auto_resume_checkpoint_path = get_latest_checkpoint(checkpoint_dir)
            if auto_resume_checkpoint_path is not None:
                auto_resume_checkpoint_path = validate_resume_checkpoint(auto_resume_checkpoint_path)

    if auto_resume_checkpoint_path and is_rank_zero:
        print(cyan("Auto-resuming training from:"), auto_resume_checkpoint_path)

    with open_dict(cfg):
        cfg._auto_resuming = auto_resume_checkpoint_path is not None
        cfg._resume_checkpoint_path = str(auto_resume_checkpoint_path) if auto_resume_checkpoint_path else None

    resume = cfg.get("resume", None)
    load = cfg.get("load", None)

    # Set up logging with wandb.
    if cfg.wandb.mode != "disabled":
        wandb_run_id = get_or_create_wandb_run_id(output_dir, requested_run_id=resume)
        # No display name is passed when resuming from a checkpoint.
        name = None if auto_resume_checkpoint_path else f"{cfg.name} ({output_dir.parent.name}/{output_dir.name})"

        if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
            logger_cls = OfflineWandbLogger
        else:
            logger_cls = SpaceEfficientWandbLogger

        offline = cfg.wandb.mode != "online"
        logger = logger_cls(
            name=name,
            save_dir=str(output_dir),
            offline=offline,
            entity=cfg.wandb.entity,
            project=cfg.wandb.project,
            log_model=False,
            config=OmegaConf.to_container(cfg),
            id=wandb_run_id,
            resume="auto",
        )
    else:
        logger = None

    # Resolve the checkpoint to load. The auto-resume checkpoint has priority;
    # otherwise `load` may be a checkpoint path, and `resume` or a run-id
    # `load` means "use the latest checkpoint in the output directory".
    checkpoint_path = auto_resume_checkpoint_path
    load_id = None
    if checkpoint_path is None:
        if load and not is_run_id(load):
            checkpoint_path = load
        elif resume:
            load_id = resume
        elif load:
            load_id = load

    if load_id:
        checkpoint_path = get_latest_checkpoint(checkpoint_dir)
        if checkpoint_path is None:
            raise FileNotFoundError(f"No checkpoint found under {checkpoint_dir} for run id {load_id}")
        checkpoint_path = validate_resume_checkpoint(checkpoint_path)

    if checkpoint_path and is_rank_zero:
        print(f"Will load checkpoint from {checkpoint_path}")

    # Launch the experiment.
    experiment = build_experiment(cfg, logger, checkpoint_path)
    for task in cfg.experiment.tasks:
        experiment.exec_task(task)
|
|
|
|
def run_slurm(cfg: DictConfig):
    """Submit this run as a SLURM job and tell the user how to follow its logs.

    The current command-line arguments are forwarded with
    ``+_on_compute_node=True`` appended. When the compute node is offline and
    W&B is in online mode, a ``wandb-osh`` sync process is started in the
    background. The function then waits until a ``*.out`` or ``*.err`` file
    shows up in the job's log directory (Ctrl+C stops the wait) and prints a
    ``tail -f`` command.

    Args:
        cfg: The composed Hydra config; ``cfg.cluster`` and ``cfg.wandb.mode``
            are read here and ``cfg`` is passed to ``submit_slurm_job``.

    Raises:
        FileNotFoundError: No directory containing ``.git`` was found at or
            above the current working directory.
    """
    # NOTE(review): arguments are re-joined with plain spaces, so shell quoting
    # from the original command line is lost. If `submit_slurm_job` places this
    # string in a shell script, `shlex.join` may be needed — confirm.
    python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True"

    # Walk from the working directory up to, and including, the filesystem
    # root and take the first directory that contains `.git`.
    start_dir = Path.cwd()
    project_root = next(
        (candidate for candidate in (start_dir, *start_dir.parents) if (candidate / ".git").exists()),
        None,
    )
    if project_root is None:
        raise FileNotFoundError("Could not find repo directory!")

    slurm_log_dir = submit_slurm_job(
        cfg,
        python_args,
        project_root,
    )

    if "cluster" in cfg and cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
        print("Job submitted to a compute node without internet. This requires manual syncing on login node.")
        osh_command_dir = project_root / ".wandb_osh_command_dir"

        try:
            osh_proc = subprocess.Popen(["wandb-osh", "--command-dir", osh_command_dir])
        except FileNotFoundError:
            # The job is already submitted; still print the log-tail hint
            # below instead of crashing here.
            print("Could not start wandb-osh: executable not found.")
        else:
            print(f"Running wandb-osh in background... PID: {osh_proc.pid}")
            print(f"To kill the sync process, run 'kill {osh_proc.pid}' in the terminal.")
        print(
            "You can manually start a sync loop later by running the following:",
            cyan(f"wandb-osh --command-dir {osh_command_dir}"),
        )

    print(
        "Once the job gets allocated and starts running, we will print a command below "
        "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)"
    )
    msg = f"tail -f {slurm_log_dir}/* \n"
    try:
        # The first log file appearing is taken as the sign the job has started.
        while not (any(slurm_log_dir.glob("*.out")) or any(slurm_log_dir.glob("*.err"))):
            time.sleep(1)
        print(cyan("To trace the outputs and errors, run the following command:"), msg)
    except KeyboardInterrupt:
        print("Keyboard interrupt detected. Exiting...")
        print(
            cyan("To trace the outputs and errors, manually wait for the job to start and run the following command:"),
            msg,
        )
|
|
|
|
@hydra.main(
    version_base=None,
    config_path="configurations",
    config_name="training",
)
def run(cfg: DictConfig):
    """Validate the composed config, then run locally or submit to SLURM.

    On an offline compute node, W&B mode ``"online"`` is switched to
    ``"offline"``. A run ``name`` and a W&B ``entity`` are required, and a
    missing W&B project defaults to the name of the directory containing this
    file. When the config has a ``cluster`` section and this process is not
    already on a compute node, the job is submitted through ``run_slurm``;
    otherwise ``run_local`` is called.

    Args:
        cfg: Config composed by Hydra from ``configurations/training``.

    Raises:
        ValueError: ``name`` is missing from the config, or ``wandb.entity``
            is empty.
    """
    on_compute_node = "_on_compute_node" in cfg

    if on_compute_node and cfg.cluster.is_compute_node_offline:
        with open_dict(cfg):
            if cfg.wandb.mode == "online":
                cfg.wandb.mode = "offline"

    if "name" not in cfg:
        raise ValueError("must specify a name for the run with command line argument '+name=[name]'")

    entity = cfg.wandb.get("entity", None)
    if not entity:
        raise ValueError(
            "must specify wandb entity in 'configurations/config.yaml' or with command line"
            " argument 'wandb.entity=[entity]' \n An entity is your wandb user name or group"
            " name. This is used for logging. If you don't have an wandb account, please signup at https://wandb.ai/"
        )

    if cfg.wandb.project is None:
        cfg.wandb.project = Path(__file__).parent.name

    submit_to_cluster = "cluster" in cfg and not on_compute_node
    if not submit_to_cluster:
        run_local(cfg)
        return

    print(cyan("Slurm detected, submitting to compute node instead of running locally..."))
    run_slurm(cfg)
|
|
|
|
# Script entry point. `run` is called without arguments; its `cfg` parameter
# is supplied by the `hydra.main` decorator.
if __name__ == "__main__":
    run()
|
|