File size: 12,274 Bytes
"""
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
template [repo](https://github.com/buoyancy99/research-template).
By its MIT license, you must keep the above sentence in `README.md`
and the `LICENSE` file to credit the author.
Main file for the project. This will create and run new experiments and load checkpoints from wandb.
Borrowed part of the code from David Charatan and wandb.
"""
import os
import sys
import subprocess
import time
import re
from pathlib import Path
import hydra
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from utils.print_utils import cyan
from utils.ckpt_utils import download_latest_checkpoint, is_run_id
from utils.cluster_utils import submit_slurm_job
from utils.distributed_utils import is_rank_zero
# Name of the marker file, kept directly inside a run's output directory,
# whose text content is that run's W&B run id.
WANDB_RUN_ID_FILE = ".wandb_run_id"
def get_latest_checkpoint(checkpoint_folder: Path, pattern: str = "*.ckpt"):
    """Pick the checkpoint file to resume from inside ``checkpoint_folder``.

    Only regular, non-empty files matching ``pattern`` are considered.
    ``last.ckpt`` wins whenever it is among them; otherwise the file with the
    largest step number parsed from its name is chosen, with modification
    time breaking ties (files without a step number sort first).

    Returns:
        The selected path, or ``None`` when the folder is missing or holds no
        usable checkpoint.
    """
    if not checkpoint_folder.exists():
        return None

    candidates = []
    for entry in checkpoint_folder.glob(pattern):
        # Zero-byte files are skipped rather than offered for resuming.
        if entry.is_file() and entry.stat().st_size > 0:
            candidates.append(entry)
    if not candidates:
        return None

    preferred = checkpoint_folder / "last.ckpt"
    if preferred in candidates:
        return preferred

    def ordering(entry: Path):
        found = re.search(r"step[=_-]?(\d+)", entry.stem)
        step_number = int(found.group(1)) if found else -1
        return step_number, entry.stat().st_mtime

    return max(candidates, key=ordering)
def validate_resume_checkpoint(checkpoint_path: Path) -> Path:
    """Check that ``checkpoint_path`` is a usable checkpoint and return it.

    Raises:
        FileNotFoundError: the path does not exist.
        ValueError: the path is not a regular file, does not end in
            ``.ckpt``, or is zero bytes long.
    """
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Resume checkpoint does not exist: {checkpoint_path}")

    # The checks run in this order so the first failing one decides the message.
    if not checkpoint_path.is_file():
        problem = ValueError(f"Resume checkpoint is not a file: {checkpoint_path}")
    elif checkpoint_path.suffix != ".ckpt":
        problem = ValueError(f"Resume checkpoint must be a .ckpt file: {checkpoint_path}")
    elif checkpoint_path.stat().st_size == 0:
        problem = ValueError(f"Resume checkpoint is empty: {checkpoint_path}")
    else:
        return checkpoint_path
    raise problem
def discover_wandb_run_id(output_dir: Path):
    """Find the W&B run id already associated with ``output_dir``, if any.

    The run id marker file is authoritative when present. Otherwise the most
    recently modified ``run-*`` / ``offline-run-*`` directory under
    ``output_dir / "wandb"`` is used, and the id is taken from the text after
    the last ``-`` in its name.

    Returns:
        The run id string, or ``None`` when nothing identifies a run.

    Raises:
        ValueError: the marker file exists but holds only whitespace.
    """
    marker = output_dir / WANDB_RUN_ID_FILE
    if marker.exists():
        stored = marker.read_text().strip()
        if stored:
            return stored
        raise ValueError(f"W&B run id file is empty: {marker}")

    wandb_root = output_dir / "wandb"
    if not wandb_root.exists():
        return None

    candidates = [
        entry
        for entry in wandb_root.iterdir()
        if entry.is_dir() and re.match(r"(offline-run|run)-.+-[A-Za-z0-9]+$", entry.name)
    ]
    if not candidates:
        return None

    newest = max(candidates, key=lambda entry: entry.stat().st_mtime)
    return newest.name.rsplit("-", 1)[-1]
def get_process_rank() -> int:
    """Return this process's rank as advertised by the launcher environment.

    ``RANK``, ``SLURM_PROCID`` and ``LOCAL_RANK`` are consulted in that order
    and the first one that is set is converted with ``int``. When none is set
    the process is treated as rank 0.
    """
    advertised = (os.environ.get(name) for name in ("RANK", "SLURM_PROCID", "LOCAL_RANK"))
    first_set = next((value for value in advertised if value is not None), None)
    return 0 if first_set is None else int(first_set)
def wait_for_wandb_run_id(output_dir: Path, timeout_s: float = 300.0):
    """Block until the W&B run id marker file in ``output_dir`` has content.

    The file is polled every 0.5 seconds. It is always checked at least once,
    so an id that is already on disk is returned even when ``timeout_s`` is
    zero or negative.

    Args:
        output_dir: Directory in which the marker file is expected.
        timeout_s: Maximum number of seconds to keep polling.

    Returns:
        The stripped, non-empty content of the marker file.

    Raises:
        TimeoutError: no non-empty marker file appeared before the deadline.
    """
    run_id_file = output_dir / WANDB_RUN_ID_FILE
    # A monotonic clock keeps the timeout immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            run_id = run_id_file.read_text().strip()
        except FileNotFoundError:
            run_id = ""
        # An empty read means the file is missing or not written yet: retry.
        if run_id:
            return run_id
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Timed out waiting for rank 0 to create W&B run id file: {run_id_file}")
        time.sleep(0.5)
def create_wandb_run_id_on_rank_zero(output_dir: Path, requested_run_id=None):
    """Decide the W&B run id for ``output_dir`` and persist it to the marker file.

    The id is, in order of preference: ``requested_run_id``, an id found by
    ``discover_wandb_run_id``, or a freshly generated one. ``output_dir`` is
    created if needed and the id is written to the marker file followed by a
    newline.

    Args:
        output_dir: Directory that owns the run.
        requested_run_id: Id to use, if any. It is converted to ``str`` and
            stripped, so a value that arrives as a non-string (for example an
            all-digit id parsed as an integer) still compares equal to the
            text stored on disk.

    Returns:
        The run id as a string.

    Raises:
        ValueError: the marker file already holds a different, non-empty id.
    """
    requested = str(requested_run_id).strip() if requested_run_id else ""
    if requested:
        run_id = requested
    else:
        run_id = discover_wandb_run_id(output_dir)
        if run_id is None:
            # Imported lazily so wandb is only loaded when an id must be generated.
            import wandb

            run_id = wandb.util.generate_id()

    output_dir.mkdir(parents=True, exist_ok=True)
    run_id_file = output_dir / WANDB_RUN_ID_FILE
    if run_id_file.exists():
        existing_run_id = run_id_file.read_text().strip()
        if existing_run_id and existing_run_id != run_id:
            raise ValueError(
                f"Output directory already belongs to W&B run id {existing_run_id}, "
                f"but {run_id} was requested. Use a different output_dir or resume id."
            )
    run_id_file.write_text(f"{run_id}\n")
    return run_id
def get_or_create_wandb_run_id(output_dir: Path, requested_run_id=None):
    """Return the W&B run id for ``output_dir``.

    The process whose rank is 0 decides and writes the id; every other
    process waits for the marker file to appear and reads it back.
    """
    on_rank_zero = get_process_rank() == 0
    if not on_rank_zero:
        return wait_for_wandb_run_id(output_dir)
    return create_wandb_run_id_on_rank_zero(output_dir, requested_run_id=requested_run_id)
def run_local(cfg: DictConfig):
    """Build the experiment described by ``cfg`` and execute its tasks in this process.

    Steps, in order:

    1. Seed everything when ``cfg.seed`` is set.
    2. Record the chosen Hydra config names as ``_name`` on the ``experiment``,
       ``dataset`` and ``algorithm`` sub-configs.
    3. Resolve the output directory (``cfg.output_dir`` overrides Hydra's) and,
       on rank zero, repoint the ``latest-run`` symlink at it.
    4. When ``"training"`` is among ``cfg.experiment.tasks``, pick a checkpoint
       to resume from: ``cfg.resume_ckpt_path`` if given, otherwise the latest
       one under ``<output_dir>/checkpoints`` unless ``cfg.auto_resume`` is
       false. The outcome is stored in ``cfg._auto_resuming`` and
       ``cfg._resume_checkpoint_path``.
    5. Create a W&B logger unless ``cfg.wandb.mode`` is ``"disabled"``.
    6. Resolve the checkpoint to load from ``cfg.load`` / ``cfg.resume`` when
       step 4 found none.
    7. Build the experiment and run every task in ``cfg.experiment.tasks``.

    Raises:
        FileNotFoundError: a resume checkpoint does not exist, or a run id was
            given via ``resume`` / ``load`` but no checkpoint is present under
            ``<output_dir>/checkpoints``.
        ValueError: a selected checkpoint fails ``validate_resume_checkpoint``.
    """
    # delay some imports in case they are not needed in non-local envs for submission
    from experiments import build_experiment
    from utils.wandb_utils import OfflineWandbLogger, SpaceEfficientWandbLogger
    import lightning.pytorch as pl

    # Set global seed for reproducibility
    if cfg.get("seed", None) is not None:
        pl.seed_everything(cfg.seed, workers=True)

    # Get yaml names
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    # open_dict lets the new `_name` keys be added to the config.
    with open_dict(cfg):
        if cfg_choice["experiment"] is not None:
            cfg.experiment._name = cfg_choice["experiment"]
        if cfg_choice["dataset"] is not None:
            cfg.dataset._name = cfg_choice["dataset"]
        if cfg_choice["algorithm"] is not None:
            cfg.algorithm._name = cfg_choice["algorithm"]

    # Set up the output directory.
    output_dir = getattr(cfg, "output_dir", None)
    if output_dir is not None:
        # The Hydra config is made writable only long enough to override its output_dir.
        OmegaConf.set_readonly(hydra_cfg, False)
        hydra_cfg.runtime.output_dir = output_dir
        OmegaConf.set_readonly(hydra_cfg, True)
    output_dir = Path(hydra_cfg.runtime.output_dir)

    # NOTE(review): `is_rank_zero` is used as a plain truthy value, not called;
    # confirm utils.distributed_utils exports a bool rather than a function.
    if is_rank_zero:
        print(cyan(f"Outputs will be saved to:"), output_dir)
        # NOTE(review): `parents[1]` raises IndexError for an output_dir with fewer
        # than two parent components (e.g. a bare relative name) — confirm that
        # user-supplied `cfg.output_dir` values are always deep enough.
        (output_dir.parents[1] / "latest-run").unlink(missing_ok=True)
        (output_dir.parents[1] / "latest-run").symlink_to(output_dir, target_is_directory=True)

    training_requested = "training" in cfg.experiment.tasks
    checkpoint_dir = output_dir / "checkpoints"
    auto_resume = bool(getattr(cfg, "auto_resume", True))
    explicit_resume_ckpt = getattr(cfg, "resume_ckpt_path", None)
    auto_resume_checkpoint_path = None

    # Resuming is only considered when a training task will actually run.
    if training_requested:
        if explicit_resume_ckpt:
            # An explicitly requested checkpoint must be valid; failure raises.
            auto_resume_checkpoint_path = validate_resume_checkpoint(Path(explicit_resume_ckpt))
        elif auto_resume:
            auto_resume_checkpoint_path = get_latest_checkpoint(checkpoint_dir)
            if auto_resume_checkpoint_path is not None:
                auto_resume_checkpoint_path = validate_resume_checkpoint(auto_resume_checkpoint_path)

    if auto_resume_checkpoint_path and is_rank_zero:
        print(cyan("Auto-resuming training from:"), auto_resume_checkpoint_path)

    # Record the resume decision in the config that is handed to build_experiment.
    with open_dict(cfg):
        cfg._auto_resuming = auto_resume_checkpoint_path is not None
        cfg._resume_checkpoint_path = str(auto_resume_checkpoint_path) if auto_resume_checkpoint_path else None

    # Set up logging with wandb.
    if cfg.wandb.mode != "disabled":
        # If resuming, merge into the existing run on wandb.
        resume = cfg.get("resume", None)
        wandb_run_id = get_or_create_wandb_run_id(output_dir, requested_run_id=resume)
        # No display name is passed when resuming from a checkpoint.
        name = None if auto_resume_checkpoint_path else f"{cfg.name} ({output_dir.parent.name}/{output_dir.name})"

        if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
            logger_cls = OfflineWandbLogger
        else:
            logger_cls = SpaceEfficientWandbLogger

        # Any mode other than "online" is logged offline.
        offline = cfg.wandb.mode != "online"
        logger = logger_cls(
            name=name,
            save_dir=str(output_dir),
            offline=offline,
            entity=cfg.wandb.entity,
            project=cfg.wandb.project,
            log_model=False,
            config=OmegaConf.to_container(cfg),
            id=wandb_run_id,
            resume="auto"
        )
    else:
        logger = None

    # Load ckpt
    resume = cfg.get("resume", None)
    load = cfg.get("load", None)
    # A checkpoint chosen for resuming takes precedence over `load` / `resume`.
    checkpoint_path = auto_resume_checkpoint_path
    load_id = None

    # A `load` value that is not a run id is used directly as the checkpoint path.
    if checkpoint_path is None and load and not is_run_id(load):
        checkpoint_path = load

    # Otherwise a run id may be given: `resume` is preferred over `load`.
    if checkpoint_path is None and resume:
        load_id = resume
    elif checkpoint_path is None and load and is_run_id(load):
        load_id = load
    else:
        load_id = None

    if load_id:
        # The run id only appears in the error message below; the checkpoint is
        # looked up locally under this run's output directory, nothing is downloaded.
        checkpoint_path = get_latest_checkpoint(output_dir / "checkpoints")
        if checkpoint_path is None:
            raise FileNotFoundError(f"No checkpoint found under {output_dir / 'checkpoints'} for run id {load_id}")
        checkpoint_path = validate_resume_checkpoint(checkpoint_path)

    if checkpoint_path and is_rank_zero:
        print(f"Will load checkpoint from {checkpoint_path}")

    # launch experiment
    experiment = build_experiment(cfg, logger, checkpoint_path)
    for task in cfg.experiment.tasks:
        experiment.exec_task(task)
def run_slurm(cfg: DictConfig):
    """Submit this run as a SLURM job and tell the user how to follow its logs.

    The command-line arguments of the current invocation are forwarded to the
    job with ``+_on_compute_node=True`` appended. The project root is the
    nearest ancestor of the working directory that contains ``.git``.

    When the cluster config marks compute nodes as offline and W&B is in
    ``"online"`` mode, a ``wandb-osh`` sync process is started in the
    background. Failing to start it is reported but does not abort, because
    the job has already been submitted by then.

    Finally the function waits until a ``*.out`` or ``*.err`` file shows up in
    the job's log directory and prints a ``tail`` command; Ctrl+C skips the wait.

    Raises:
        Exception: no ancestor of the working directory contains ``.git``.
    """
    # NOTE(review): arguments are joined with plain spaces and not shell-quoted;
    # confirm how submit_slurm_job embeds them before passing values with spaces.
    python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True"

    project_root = Path.cwd()
    while not (project_root / ".git").exists():
        # A filesystem root is its own parent. Comparing against that, rather
        # than a literal "/", terminates on every platform.
        if project_root.parent == project_root:
            raise Exception("Could not find repo directory!")
        project_root = project_root.parent

    slurm_log_dir = submit_slurm_job(
        cfg,
        python_args,
        project_root,
    )

    if "cluster" in cfg and cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
        print("Job submitted to a compute node without internet. This requires manual syncing on login node.")
        osh_command_dir = project_root / ".wandb_osh_command_dir"

        try:
            osh_proc = subprocess.Popen(["wandb-osh", "--command-dir", osh_command_dir])
        except FileNotFoundError:
            # Keep going so the user still gets the manual command and the log hint below.
            print("Could not start wandb-osh (is it installed and on PATH?). No sync loop is running.")
        else:
            print(f"Running wandb-osh in background... PID: {osh_proc.pid}")
            print(f"To kill the sync process, run 'kill {osh_proc.pid}' in the terminal.")
        print(
            "You can manually start a sync loop later by running the following:",
            cyan(f"wandb-osh --command-dir {osh_command_dir}"),
        )

    print(
        "Once the job gets allocated and starts running, we will print a command below "
        "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)"
    )
    msg = f"tail -f {slurm_log_dir}/* \n"
    try:
        # The first log file appearing is taken as the sign that the job has started.
        while not list(slurm_log_dir.glob("*.out")) and not list(slurm_log_dir.glob("*.err")):
            time.sleep(1)
        print(cyan("To trace the outputs and errors, run the following command:"), msg)
    except KeyboardInterrupt:
        print("Keyboard interrupt detected. Exiting...")
        print(
            cyan("To trace the outputs and errors, manually wait for the job to start and run the following command:"),
            msg,
        )
@hydra.main(
    version_base=None,
    config_path="configurations",
    config_name="training",
)
def run(cfg: DictConfig):
    """Validate the composed config and dispatch to a SLURM submission or a local run.

    A run goes to SLURM when the config has a ``cluster`` section and the
    process is not already on a compute node (``_on_compute_node`` absent);
    otherwise it executes locally.

    Raises:
        ValueError: ``name`` is missing from the config, or ``wandb.entity``
            is unset or empty.
    """
    on_compute_node = "_on_compute_node" in cfg

    # Compute nodes without internet cannot log online, so fall back to offline mode.
    if on_compute_node and cfg.cluster.is_compute_node_offline:
        with open_dict(cfg):
            if cfg.wandb.mode == "online":
                cfg.wandb.mode = "offline"

    if "name" not in cfg:
        raise ValueError("must specify a name for the run with command line argument '+name=[name]'")

    entity = cfg.wandb.get("entity", None)
    if not entity:
        raise ValueError(
            "must specify wandb entity in 'configurations/config.yaml' or with command line"
            " argument 'wandb.entity=[entity]' \n An entity is your wandb user name or group"
            " name. This is used for logging. If you don't have an wandb account, please signup at https://wandb.ai/"
        )

    # Default the W&B project to the name of the directory holding this file.
    if cfg.wandb.project is None:
        cfg.wandb.project = str(Path(__file__).parent.name)

    submit_to_cluster = "cluster" in cfg and not on_compute_node
    if not submit_to_cluster:
        run_local(cfg)
        return

    print(cyan("Slurm detected, submitting to compute node instead of running locally..."))
    run_slurm(cfg)
# Script entry point. `run` is called without arguments because the
# `hydra.main` decorator supplies `cfg`; the pylint pragma silences the
# resulting missing-argument warning.
if __name__ == "__main__":
    run()  # pylint: disable=no-value-for-parameter
|