File size: 12,274 Bytes
b47a1ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
"""
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research 
template [repo](https://github.com/buoyancy99/research-template). 
By its MIT license, you must keep the above sentence in `README.md` 
and the `LICENSE` file to credit the author.

Main file for the project. This will create and run new experiments and load checkpoints from wandb. 
Borrowed part of the code from David Charatan and wandb.
"""

import os
import sys
import subprocess
import time
import re
from pathlib import Path

import hydra
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict

from utils.print_utils import cyan
from utils.ckpt_utils import download_latest_checkpoint, is_run_id
from utils.cluster_utils import submit_slurm_job
from utils.distributed_utils import is_rank_zero

# Name of the file, inside a run's output directory, in which rank 0 records the
# W&B run id; other ranks poll for it and later resumptions read it back so every
# process (and every restart) logs to the same W&B run.
WANDB_RUN_ID_FILE = ".wandb_run_id"


def get_latest_checkpoint(checkpoint_folder: Path, pattern: str = "*.ckpt"):
    """Return the newest non-empty checkpoint file in ``checkpoint_folder``.

    Only regular files matching ``pattern`` with a size above zero are considered.
    A ``last.ckpt`` among them always wins. Otherwise the file with the largest
    ``step`` number parsed from its stem is returned, ties broken by modification
    time; files without a step number rank below all numbered ones.

    Returns ``None`` when the folder does not exist or holds no usable file.
    """
    if not checkpoint_folder.exists():
        return None

    candidates = []
    for candidate in checkpoint_folder.glob(pattern):
        if candidate.is_file() and candidate.stat().st_size > 0:
            candidates.append(candidate)
    if len(candidates) == 0:
        return None

    preferred = checkpoint_folder / "last.ckpt"
    if preferred in candidates:
        return preferred

    def _rank(candidate: Path):
        # (step, mtime): step number first, modification time as tie-breaker.
        found = re.search(r"step[=_-]?(\d+)", candidate.stem)
        step_number = -1 if found is None else int(found.group(1))
        return step_number, candidate.stat().st_mtime

    return max(candidates, key=_rank)


def validate_resume_checkpoint(checkpoint_path: Path) -> Path:
    """Check that ``checkpoint_path`` is a usable resume checkpoint and return it.

    Raises:
        FileNotFoundError: the path does not exist.
        ValueError: the path is not a regular file, lacks a ``.ckpt`` suffix,
            or is zero bytes long (checked in that order).
    """
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Resume checkpoint does not exist: {checkpoint_path}")

    problem = None
    if not checkpoint_path.is_file():
        problem = f"Resume checkpoint is not a file: {checkpoint_path}"
    elif checkpoint_path.suffix != ".ckpt":
        problem = f"Resume checkpoint must be a .ckpt file: {checkpoint_path}"
    elif checkpoint_path.stat().st_size == 0:
        problem = f"Resume checkpoint is empty: {checkpoint_path}"

    if problem is not None:
        raise ValueError(problem)
    return checkpoint_path


def discover_wandb_run_id(output_dir: Path):
    """Find the W&B run id previously used for ``output_dir``, if any.

    The run-id file written into ``output_dir`` takes precedence; an empty file
    raises ``ValueError``. Without it, the most recently modified
    ``wandb/run-*-<id>`` or ``wandb/offline-run-*-<id>`` directory supplies the
    id (the text after its last ``-``). Returns ``None`` when neither exists.
    """
    id_file = output_dir / WANDB_RUN_ID_FILE
    if id_file.exists():
        stored_id = id_file.read_text().strip()
        if stored_id:
            return stored_id
        raise ValueError(f"W&B run id file is empty: {id_file}")

    wandb_root = output_dir / "wandb"
    if not wandb_root.exists():
        return None

    run_dir_pattern = re.compile(r"(offline-run|run)-.+-[A-Za-z0-9]+$")
    matching = [
        entry
        for entry in wandb_root.iterdir()
        if entry.is_dir() and run_dir_pattern.match(entry.name)
    ]
    if not matching:
        return None
    newest = max(matching, key=lambda entry: entry.stat().st_mtime)
    return newest.name.rsplit("-", 1)[-1]


def get_process_rank() -> int:
    """Return this process's rank from the environment, defaulting to 0.

    The first of ``RANK``, ``SLURM_PROCID`` and ``LOCAL_RANK`` that is set is
    parsed as an int; if none is set the process is treated as rank 0.
    """
    env_values = (os.environ.get(name) for name in ("RANK", "SLURM_PROCID", "LOCAL_RANK"))
    first_set = next((value for value in env_values if value is not None), None)
    return 0 if first_set is None else int(first_set)


def wait_for_wandb_run_id(output_dir: Path, timeout_s: float = 300.0, poll_interval_s: float = 0.5):
    """Block until rank 0 has written the W&B run id into ``output_dir``; return it.

    Args:
        output_dir: run output directory that will contain the run-id file.
        timeout_s: maximum time to wait, in seconds.
        poll_interval_s: delay between polls, in seconds.

    Raises:
        TimeoutError: no non-empty run id appeared within ``timeout_s``.
    """
    run_id_file = output_dir / WANDB_RUN_ID_FILE
    # Monotonic clock: a wall-clock jump (NTP, manual change) must not shorten or
    # stretch the wait.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        # EAFP: the file may appear (or be replaced) between any check and the read.
        try:
            run_id = run_id_file.read_text().strip()
        except FileNotFoundError:
            run_id = ""
        if run_id:
            return run_id
        time.sleep(poll_interval_s)
    raise TimeoutError(f"Timed out waiting for rank 0 to create W&B run id file: {run_id_file}")


def create_wandb_run_id_on_rank_zero(output_dir: Path, requested_run_id=None):
    """Decide the W&B run id for ``output_dir`` and persist it (rank 0 only).

    The id is ``requested_run_id`` when given. Otherwise it is rediscovered from
    ``output_dir``, or else a fresh id is generated via ``wandb.util.generate_id``.

    Raises:
        ValueError: ``output_dir`` already records a different, non-empty run id.
    """
    if requested_run_id:
        run_id = requested_run_id
    else:
        run_id = discover_wandb_run_id(output_dir)
        if run_id is None:
            import wandb
            run_id = wandb.util.generate_id()

    output_dir.mkdir(parents=True, exist_ok=True)
    run_id_file = output_dir / WANDB_RUN_ID_FILE
    if run_id_file.exists():
        existing_run_id = run_id_file.read_text().strip()
        if existing_run_id and existing_run_id != run_id:
            raise ValueError(
                f"Output directory already belongs to W&B run id {existing_run_id}, "
                f"but {run_id} was requested. Use a different output_dir or resume id."
            )

    # Other ranks poll this file (wait_for_wandb_run_id), so publish it atomically:
    # write a temp file in the same directory, then os.replace it into place.
    # Readers then see either no file or the complete id, never a truncated one.
    tmp_file = run_id_file.with_name(f"{run_id_file.name}.{os.getpid()}.tmp")
    try:
        tmp_file.write_text(f"{run_id}\n")
        os.replace(tmp_file, run_id_file)
    finally:
        tmp_file.unlink(missing_ok=True)
    return run_id


def get_or_create_wandb_run_id(output_dir: Path, requested_run_id=None):
    """Return the W&B run id shared by all ranks of this run.

    Rank 0 decides and records the id; every other rank waits for that record.
    """
    if get_process_rank() != 0:
        return wait_for_wandb_run_id(output_dir)
    return create_wandb_run_id_on_rank_zero(output_dir, requested_run_id=requested_run_id)


def _update_latest_run_symlink(output_dir: Path):
    """Point ``<grandparent of output_dir>/latest-run`` at ``output_dir`` (best effort)."""
    if len(output_dir.parents) < 2:
        # A custom output_dir this shallow has no grandparent to hold the link.
        return
    link = output_dir.parents[1] / "latest-run"
    if link.is_dir() and not link.is_symlink():
        print("Not updating 'latest-run': a real directory already exists at", link)
        return
    try:
        link.unlink(missing_ok=True)
        # Absolute target: a relative one would be resolved relative to the link's
        # own directory, not the current working directory, and dangle.
        link.symlink_to(output_dir.resolve(), target_is_directory=True)
    except OSError as err:
        # The link is only a convenience pointer; never abort a run over it.
        print(f"Could not update 'latest-run' link at {link}: {err}")


def _find_resume_checkpoint(cfg: DictConfig, checkpoint_dir: Path):
    """Pick the checkpoint to continue training from, or ``None``.

    Applies only when "training" is one of ``cfg.experiment.tasks``. An explicit
    ``resume_ckpt_path`` wins. Otherwise, unless ``auto_resume`` is false, the
    latest checkpoint in ``checkpoint_dir`` is used. The chosen file is validated.
    """
    if "training" not in cfg.experiment.tasks:
        return None
    explicit_resume_ckpt = getattr(cfg, "resume_ckpt_path", None)
    if explicit_resume_ckpt:
        return validate_resume_checkpoint(Path(explicit_resume_ckpt))
    if not bool(getattr(cfg, "auto_resume", True)):
        return None
    latest = get_latest_checkpoint(checkpoint_dir)
    return validate_resume_checkpoint(latest) if latest is not None else None


def _find_requested_checkpoint(cfg: DictConfig, output_dir: Path):
    """Resolve the checkpoint named by ``load`` / ``resume``, or ``None`` if neither is set.

    Precedence: ``load`` given as a path, then ``resume`` run id, then ``load`` run id.
    For a run id, the latest checkpoint under ``output_dir/checkpoints`` is used.
    If that folder holds none and W&B is online, the run's latest checkpoint is
    downloaded from W&B.

    Raises:
        FileNotFoundError: a run id was given but no checkpoint could be found.
    """
    resume = cfg.get("resume", None)
    load = cfg.get("load", None)
    if load and not is_run_id(load):
        return load
    load_id = resume or load
    if not load_id:
        return None

    local_dir = output_dir / "checkpoints"
    checkpoint_path = get_latest_checkpoint(local_dir)
    if checkpoint_path is not None:
        return validate_resume_checkpoint(checkpoint_path)

    # Nothing local (e.g. `load=<id>` of another run in a fresh output dir). Fetch it
    # from W&B only when online. `run` switches offline compute nodes to "offline",
    # so they keep failing fast below instead of attempting a network download.
    if cfg.wandb.mode == "online":
        run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}"
        downloaded = download_latest_checkpoint(run_path, Path("outputs/downloaded"))
        if downloaded is not None:
            return Path(downloaded)
    raise FileNotFoundError(f"No checkpoint found under {local_dir} for run id {load_id}")


def run_local(cfg: DictConfig):
    """Run every task in ``cfg.experiment.tasks`` in the current process.

    Steps:
    1. Seed RNGs.
    2. Record the selected experiment/dataset/algorithm config names.
    3. Set up the output directory, honouring ``cfg.output_dir``.
    4. Choose a checkpoint to auto-resume training from.
    5. Create the W&B logger, unless ``wandb.mode`` is "disabled".
    6. Resolve any ``load``/``resume`` checkpoint.
    7. Build and execute the experiment.
    """
    # delay some imports in case they are not needed in non-local envs for submission
    from experiments import build_experiment
    from utils.wandb_utils import OfflineWandbLogger, SpaceEfficientWandbLogger
    import lightning.pytorch as pl

    # Set global seed for reproducibility
    if cfg.get("seed", None) is not None:
        pl.seed_everything(cfg.seed, workers=True)

    # Record which yaml was chosen for each config group.
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    with open_dict(cfg):
        if cfg_choice["experiment"] is not None:
            cfg.experiment._name = cfg_choice["experiment"]
        if cfg_choice["dataset"] is not None:
            cfg.dataset._name = cfg_choice["dataset"]
        if cfg_choice["algorithm"] is not None:
            cfg.algorithm._name = cfg_choice["algorithm"]

    # Set up the output directory. Hydra's runtime config is read-only, so it is
    # unlocked briefly to redirect outputs to a user-provided `output_dir`.
    output_dir = getattr(cfg, "output_dir", None)
    if output_dir is not None:
        OmegaConf.set_readonly(hydra_cfg, False)
        hydra_cfg.runtime.output_dir = output_dir
        OmegaConf.set_readonly(hydra_cfg, True)
    output_dir = Path(hydra_cfg.runtime.output_dir)

    if is_rank_zero:
        print(cyan("Outputs will be saved to:"), output_dir)
        _update_latest_run_symlink(output_dir)

    auto_resume_checkpoint_path = _find_resume_checkpoint(cfg, output_dir / "checkpoints")
    if auto_resume_checkpoint_path and is_rank_zero:
        print(cyan("Auto-resuming training from:"), auto_resume_checkpoint_path)

    with open_dict(cfg):
        cfg._auto_resuming = auto_resume_checkpoint_path is not None
        cfg._resume_checkpoint_path = str(auto_resume_checkpoint_path) if auto_resume_checkpoint_path else None

    # Set up logging with wandb.
    if cfg.wandb.mode != "disabled":
        # If resuming, merge into the existing run on wandb.
        resume = cfg.get("resume", None)
        wandb_run_id = get_or_create_wandb_run_id(output_dir, requested_run_id=resume)
        # Keep the existing run's display name when auto-resuming.
        name = None if auto_resume_checkpoint_path else f"{cfg.name} ({output_dir.parent.name}/{output_dir.name})"

        if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
            logger_cls = OfflineWandbLogger
        else:
            logger_cls = SpaceEfficientWandbLogger

        offline = cfg.wandb.mode != "online"
        logger = logger_cls(
            name=name,
            save_dir=str(output_dir),
            offline=offline,
            entity=cfg.wandb.entity,
            project=cfg.wandb.project,
            log_model=False,
            config=OmegaConf.to_container(cfg),
            id=wandb_run_id,
            resume="auto",
        )
    else:
        logger = None

    # Load ckpt: an auto-resume checkpoint wins over `load` / `resume`.
    checkpoint_path = auto_resume_checkpoint_path
    if checkpoint_path is None:
        checkpoint_path = _find_requested_checkpoint(cfg, output_dir)

    if checkpoint_path and is_rank_zero:
        print(f"Will load checkpoint from {checkpoint_path}")

    # launch experiment
    experiment = build_experiment(cfg, logger, checkpoint_path)
    for task in cfg.experiment.tasks:
        experiment.exec_task(task)


def run_slurm(cfg: DictConfig):
    """Submit the current command line to SLURM and explain how to follow it.

    The same CLI arguments are re-submitted with ``+_on_compute_node=True`` from
    the enclosing git repository root. For offline compute nodes with online W&B,
    a ``wandb-osh`` sync loop is started in the background when available. The
    function then waits for the job's first ``.out``/``.err`` log file and prints
    a ``tail -f`` command; Ctrl+C stops the wait early.
    """
    python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True"

    # Walk up to the directory containing `.git`. A path that is its own parent
    # is the filesystem root; this also terminates on Windows drive roots, where
    # comparing against Path("/") never matches.
    project_root = Path.cwd()
    while not (project_root / ".git").exists():
        if project_root.parent == project_root:
            raise Exception("Could not find repo directory!")
        project_root = project_root.parent

    slurm_log_dir = submit_slurm_job(
        cfg,
        python_args,
        project_root,
    )

    if "cluster" in cfg and cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
        print("Job submitted to a compute node without internet. This requires manual syncing on login node.")
        osh_command_dir = project_root / ".wandb_osh_command_dir"
        try:
            osh_proc = subprocess.Popen(["wandb-osh", "--command-dir", str(osh_command_dir)])
        except FileNotFoundError:
            # The job is already submitted; don't crash before printing the log command.
            print("Could not start 'wandb-osh' (is it installed and on PATH?).")
        else:
            print(f"Running wandb-osh in background... PID: {osh_proc.pid}")
            print(f"To kill the sync process, run 'kill {osh_proc.pid}' in the terminal.")
        print(
            "You can manually start a sync loop later by running the following:",
            cyan(f"wandb-osh --command-dir {osh_command_dir}"),
        )

    print(
        "Once the job gets allocated and starts running, we will print a command below "
        "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)"
    )
    msg = f"tail -f {slurm_log_dir}/* \n"
    try:
        while not any(slurm_log_dir.glob("*.out")) and not any(slurm_log_dir.glob("*.err")):
            time.sleep(1)
        print(cyan("To trace the outputs and errors, run the following command:"), msg)
    except KeyboardInterrupt:
        print("Keyboard interrupt detected. Exiting...")
        print(
            cyan("To trace the outputs and errors, manually wait for the job to start and run the following command:"),
            msg,
        )


@hydra.main(
    version_base=None,
    config_path="configurations",
    config_name="training",
)
def run(cfg: DictConfig):
    """Hydra entry point: validate the config, then run locally or submit to SLURM.

    On an offline compute node, online W&B logging is downgraded to offline.
    A missing ``name`` or W&B entity is rejected. A missing W&B project defaults
    to this file's directory name. With a ``cluster`` config and no
    ``_on_compute_node`` flag, the job is submitted to SLURM; otherwise it runs here.
    """
    on_compute_node = "_on_compute_node" in cfg

    if on_compute_node and cfg.cluster.is_compute_node_offline:
        with open_dict(cfg):
            # No internet on this node, so W&B can only log offline.
            if cfg.wandb.mode == "online":
                cfg.wandb.mode = "offline"

    if "name" not in cfg:
        raise ValueError("must specify a name for the run with command line argument '+name=[name]'")

    if not cfg.wandb.get("entity", None):
        raise ValueError(
            "must specify wandb entity in 'configurations/config.yaml' or with command line"
            " argument 'wandb.entity=[entity]' \n An entity is your wandb user name or group"
            " name. This is used for logging. If you don't have an wandb account, please signup at https://wandb.ai/"
        )

    if cfg.wandb.project is None:
        cfg.wandb.project = str(Path(__file__).parent.name)

    submit_to_cluster = "cluster" in cfg and not on_compute_node
    if not submit_to_cluster:
        run_local(cfg)
        return
    print(cyan("Slurm detected, submitting to compute node instead of running locally..."))
    run_slurm(cfg)


if __name__ == "__main__":
    # `cfg` is injected by the @hydra.main decorator, hence the pylint suppression.
    run()  # pylint: disable=no-value-for-parameter