| """ |
| This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research |
| template [repo](https://github.com/buoyancy99/research-template). |
| By its MIT license, you must keep the above sentence in `README.md` |
| and the `LICENSE` file to credit the author. |
| |
| Main file for the project. This will create and run new experiments and load checkpoints from wandb. |
| Borrowed part of the code from David Charatan and wandb. |
| """ |
|
|
| import sys |
| import subprocess |
| import time |
| from pathlib import Path |
|
|
| import hydra |
| from omegaconf import DictConfig, OmegaConf |
| from omegaconf.omegaconf import open_dict |
|
|
| from utils.print_utils import cyan |
| from utils.ckpt_utils import download_latest_checkpoint, is_run_id |
| from utils.cluster_utils import submit_slurm_job |
| from utils.distributed_utils import is_rank_zero |
|
|
def get_latest_checkpoint(checkpoint_folder: Path, pattern: str = '*.ckpt'):
    """Return the most recently modified checkpoint file in *checkpoint_folder*.

    Args:
        checkpoint_folder: Directory to search (not recursed into).
        pattern: Glob pattern for checkpoint files.

    Returns:
        The ``Path`` with the newest modification time, or ``None`` when no
        file matches the pattern.
    """
    return max(
        checkpoint_folder.glob(pattern),
        key=lambda ckpt: ckpt.stat().st_mtime,
        default=None,
    )
|
|
def run_local(cfg: DictConfig):
    """Run the experiment in the current process.

    Resolves the Hydra output directory, sets up the (optional) wandb logger,
    determines which checkpoint to load or resume from, then builds the
    experiment and executes every task in ``cfg.experiment.tasks``.

    Args:
        cfg: The composed Hydra configuration for this run.
    """
    # Imported lazily so that submitting from a login node (run_slurm) does
    # not require the heavy experiment/lightning dependencies.
    from experiments import build_experiment
    from utils.wandb_utils import OfflineWandbLogger, SpaceEfficientWandbLogger
    import lightning.pytorch as pl

    # Seed everything (including dataloader workers) for reproducibility.
    if cfg.get("seed", None) is not None:
        pl.seed_everything(cfg.seed, workers=True)

    # Record which config-group options Hydra selected on the config itself,
    # so downstream code (logging, experiment construction) can read them.
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    with open_dict(cfg):  # temporarily allow adding new keys to the struct
        if cfg_choice["experiment"] is not None:
            cfg.experiment._name = cfg_choice["experiment"]
        if cfg_choice["dataset"] is not None:
            cfg.dataset._name = cfg_choice["dataset"]
        if cfg_choice["algorithm"] is not None:
            cfg.algorithm._name = cfg_choice["algorithm"]

    # An explicit cfg.output_dir overrides Hydra's default run directory.
    output_dir = getattr(cfg, "output_dir", None)
    if output_dir is not None:
        OmegaConf.set_readonly(hydra_cfg, False)
        hydra_cfg.runtime.output_dir = output_dir
        OmegaConf.set_readonly(hydra_cfg, True)

    output_dir = Path(hydra_cfg.runtime.output_dir)
    if not output_dir.exists():
        output_dir.mkdir(parents=True, exist_ok=True)
        if is_rank_zero:
            print(cyan(f"Created output directory: {output_dir}"))

    if is_rank_zero:
        print(cyan(f"Outputs will be saved to:"), output_dir)
        # Keep a "latest-run" symlink two levels above the run directory
        # (presumably the root of Hydra's date/time output tree — TODO confirm).
        (output_dir.parents[1] / "latest-run").unlink(missing_ok=True)
        (output_dir.parents[1] / "latest-run").symlink_to(output_dir, target_is_directory=True)

    if cfg.wandb.mode != "disabled":
        # When resuming, leave the name to wandb (it restores it from the
        # run id passed as `id` below); otherwise derive one from cfg.name.
        resume = cfg.get("resume", None)
        name = f"{cfg.name} ({output_dir.parent.name}/{output_dir.name})" if resume is None else None

        # Offline compute nodes buffer wandb data locally and sync later.
        if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
            logger_cls = OfflineWandbLogger
        else:
            logger_cls = SpaceEfficientWandbLogger

        offline = cfg.wandb.mode != "online"
        logger = logger_cls(
            name=name,
            save_dir=str(output_dir),
            offline=offline,
            entity=cfg.wandb.entity,
            project=cfg.wandb.project,
            log_model=False,
            config=OmegaConf.to_container(cfg),
            id=resume,  # None starts a fresh run; a run id resumes that run
            resume="auto"
        )
    else:
        logger = None

    # Decide which checkpoint (if any) to load:
    #   - `load` that is NOT a wandb run id is treated as a direct file path;
    #   - `resume` (a run id) takes precedence for id-based loading;
    #   - otherwise a `load` that IS a run id is used.
    resume = cfg.get("resume", None)
    load = cfg.get("load", None)
    checkpoint_path = None
    load_id = None
    if load and not is_run_id(load):
        checkpoint_path = load
    if resume:
        load_id = resume
    elif load and is_run_id(load):
        load_id = load
    else:
        load_id = None

    if load_id:
        # NOTE(review): for an id-based load, the newest checkpoint in the
        # local checkpoints dir is used — presumably fetched from wandb
        # beforehand; verify against the checkpoint download path.
        checkpoint_path = get_latest_checkpoint(output_dir / "checkpoints")

    if checkpoint_path and is_rank_zero:
        print(f"Will load checkpoint from {checkpoint_path}")

    experiment = build_experiment(cfg, logger, checkpoint_path)
    for task in cfg.experiment.tasks:
        experiment.exec_task(task)
|
|
|
|
def run_slurm(cfg: DictConfig):
    """Submit this run to SLURM and help the user monitor it.

    Re-invokes the current command line on a compute node (tagged with
    ``+_on_compute_node=True``), optionally starts a background wandb-osh
    sync loop for offline nodes, and waits for the job's log files to
    appear so a ``tail -f`` command can be printed.

    Args:
        cfg: The composed Hydra configuration for this run.
    """
    # Re-run the exact same CLI on the compute node, marked so the remote
    # invocation takes the run_local() path.
    python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True"

    # Walk upward from the working directory until the git repo root.
    project_root = Path.cwd()
    while not (project_root / ".git").exists():
        project_root = project_root.parent
        if project_root == Path("/"):
            raise Exception("Could not find repo directory!")

    slurm_log_dir = submit_slurm_job(cfg, python_args, project_root)

    # Offline compute nodes cannot reach wandb directly; spawn a login-node
    # sync loop (wandb-osh) that uploads buffered runs in the background.
    needs_manual_sync = (
        "cluster" in cfg
        and cfg.cluster.is_compute_node_offline
        and cfg.wandb.mode == "online"
    )
    if needs_manual_sync:
        print("Job submitted to a compute node without internet. This requires manual syncing on login node.")
        osh_command_dir = project_root / ".wandb_osh_command_dir"
        sync_proc = subprocess.Popen(["wandb-osh", "--command-dir", osh_command_dir])
        print(f"Running wandb-osh in background... PID: {sync_proc.pid}")
        print(f"To kill the sync process, run 'kill {sync_proc.pid}' in the terminal.")
        print(
            "You can manually start a sync loop later by running the following:",
            cyan(f"wandb-osh --command-dir {osh_command_dir}"),
        )

    print(
        "Once the job gets allocated and starts running, we will print a command below "
        "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)"
    )
    tail_cmd = f"tail -f {slurm_log_dir}/* \n"
    try:
        # Poll until SLURM creates the job's stdout/stderr files.
        while not any(slurm_log_dir.glob("*.out")) and not any(slurm_log_dir.glob("*.err")):
            time.sleep(1)
        print(cyan("To trace the outputs and errors, run the following command:"), tail_cmd)
    except KeyboardInterrupt:
        print("Keyboard interrupt detected. Exiting...")
        print(
            cyan("To trace the outputs and errors, manually wait for the job to start and run the following command:"),
            tail_cmd,
        )
|
|
|
|
@hydra.main(
    version_base=None,
    config_path="configurations",
    config_name="training",
)
def run(cfg: DictConfig):
    """Entry point: validate the config, then run locally or via SLURM.

    Args:
        cfg: The composed Hydra configuration (built from
            ``configurations/training.yaml`` plus CLI overrides).

    Raises:
        ValueError: When ``+name=...`` or ``wandb.entity`` is missing.
    """
    # On an offline compute node, downgrade wandb to offline mode so it
    # buffers locally instead of trying (and failing) to reach the server.
    if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
        with open_dict(cfg):
            if cfg.wandb.mode == "online":
                cfg.wandb.mode = "offline"

    if "name" not in cfg:
        raise ValueError("must specify a name for the run with command line argument '+name=[name]'")

    if not cfg.wandb.get("entity", None):
        raise ValueError(
            "must specify wandb entity in 'configurations/config.yaml' or with command line"
            " argument 'wandb.entity=[entity]' \n An entity is your wandb user name or group"
            " name. This is used for logging. If you don't have an wandb account, please signup at https://wandb.ai/"
        )

    # Default the wandb project to this file's parent directory name.
    if cfg.wandb.project is None:
        cfg.wandb.project = str(Path(__file__).parent.name)

    # A cluster config on a login node means "submit to SLURM"; the
    # resubmitted job carries _on_compute_node and runs locally there.
    if "cluster" in cfg and "_on_compute_node" not in cfg:
        print(cyan("Slurm detected, submitting to compute node instead of running locally..."))
        run_slurm(cfg)
    else:
        run_local(cfg)
|
|
|
|
if __name__ == "__main__":
    # `run` is hydra.main-decorated: it parses CLI overrides, composes the
    # config from `configurations/training.yaml`, and dispatches the run.
    run()
|
|