File size: 4,448 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
from pathlib import Path

import wandb
from pytorch_lightning.loggers.wandb import WandbLogger
from omegaconf import OmegaConf

from optgs.misc.LocalLogger import LocalLogger
from optgs.paths import DEBUG


def version_to_int(artifact) -> int:
    """Return the numeric part of an artifact's version string.

    The first character of ``artifact.version`` (the ``v`` in ``v12``) is
    dropped and the remainder is parsed as an integer, so ``v12`` becomes 12.
    """
    numeric_part = artifact.version[1:]
    return int(numeric_part)


def download_checkpoint(
    run_id: str,
    download_dir: Path,
    version: str | None,
) -> Path:
    """Download a model checkpoint artifact logged by a wandb run.

    Only artifacts of type ``"model"`` in state ``"COMMITTED"`` are considered.

    Args:
        run_id: Run path handed to ``wandb.Api().run``; also used as the
            sub-directory name under ``download_dir``.
        download_dir: Directory under which the artifact is downloaded. It is
            created (including parents) if it does not exist.
        version: Exact artifact version to fetch (e.g. ``"v3"``), or ``None``
            to fetch the artifact with the highest version number.

    Returns:
        ``download_dir / run_id / "model.ckpt"``.

    Raises:
        ValueError: If the run has no committed model artifact matching the
            request. (Previously this surfaced as an ``AttributeError`` on
            ``None``.)
    """
    api = wandb.Api()
    run = api.run(run_id)

    # Lazy generator: when a specific version is requested we stop consuming
    # the artifact listing as soon as it is found.
    committed_models = (
        artifact
        for artifact in run.logged_artifacts()
        if artifact.type == "model" and artifact.state == "COMMITTED"
    )

    if version is None:
        # If no version is specified, use the latest (first one wins on ties).
        chosen = max(committed_models, key=version_to_int, default=None)
    else:
        # If a specific version is specified, look for it.
        chosen = next(
            (artifact for artifact in committed_models if artifact.version == version),
            None,
        )

    if chosen is None:
        wanted = "any version" if version is None else f"version '{version}'"
        raise ValueError(
            f"No committed model artifact ({wanted}) found for wandb run '{run_id}'."
        )

    # Download the checkpoint.
    download_dir.mkdir(exist_ok=True, parents=True)
    root = download_dir / run_id
    chosen.download(root=root)
    return root / "model.ckpt"


def setup_wandb_logger(cfg, cfg_dict) -> WandbLogger | LocalLogger:
    """Create the logger for a run: a ``WandbLogger``, or a ``LocalLogger`` fallback.

    A ``LocalLogger`` is returned when ``cfg_dict.wandb.mode`` is ``"disabled"``
    or ``cfg.mode`` is not ``"train"``. Otherwise a ``WandbLogger`` is built,
    optionally resuming an earlier run.

    Args:
        cfg: Run configuration; only ``cfg.mode`` is read here.
        cfg_dict: Attribute-style config (it is passed to
            ``OmegaConf.to_container``). Reads ``wandb.mode``, ``wandb.id``,
            ``wandb.entity``, ``wandb.project``, ``wandb.notes``,
            ``wandb.tags``, ``checkpointing.resume``, ``output_dir`` and
            ``log_slurm_id``. Note that ``cfg_dict.wandb`` is mutated in
            place: ``tags`` is overwritten when ``DEBUG`` is set and
            ``entity`` is overwritten when ``WANDB_ENTITY`` is set.

    Environment variables:
        SLURM_JOB_ID: Appended to the run name when ``log_slurm_id`` is set.
        WANDB_ENTITY: Overrides ``cfg_dict.wandb.entity`` when present.
        WANDB_ID_FILE: If set, the wandb run id is written to this file.

    Returns:
        The configured logger.

    Raises:
        AssertionError: If more than one ``run-*.wandb`` file is found in
            ``<output_dir>/wandb/latest-run`` while resuming without an id.
    """
    if cfg_dict.wandb.mode == "disabled" or cfg.mode != "train":
        return LocalLogger()

    wandb_extra_kwargs = {}

    # Detect the wandb run id if resuming without an explicit id in the config.
    if cfg_dict.checkpointing.resume and cfg_dict.wandb.id is None:
        print("Resuming wandb run without id, using latest run in output directory.")
        # Find the latest wandb run id in the output directory.
        # Path(...) keeps this working when output_dir is a plain string.
        wandb_dir = Path(cfg_dict.output_dir) / "wandb" / "latest-run"
        # Look for a file named "run-<id>.wandb" and extract the id.
        wandb_files = list(wandb_dir.glob("run-*.wandb"))
        assert len(wandb_files) <= 1, "Multiple wandb files found in the latest run directory."
        if len(wandb_files) == 1:
            # Strip only the "run-" prefix: splitting on every '-' would
            # truncate ids that themselves contain hyphens.
            wandb_id = wandb_files[0].stem.removeprefix("run-")
            wandb_extra_kwargs.update({'id': wandb_id, 'resume': "must"})

    # An id given in the config takes precedence over any detected one.
    if cfg_dict.wandb.id is not None:
        print(f"Setting wandb run with id from cfg {cfg_dict.wandb.id}.")
        wandb_extra_kwargs.update({'id': cfg_dict.wandb.id, 'resume': "must"})

    run_name = os.path.basename(cfg_dict.output_dir)

    if cfg_dict.log_slurm_id:
        hostname = os.uname().nodename
        # Fall back to the host name when not running under SLURM.
        job_id = os.environ.get('SLURM_JOB_ID', "local run: " + hostname)
        run_name += f" ({job_id})"

    # if debugging, add a tag to the run name
    if DEBUG:
        run_name += "  DEBUG"
        cfg_dict.wandb.update({'tags': ['debug']})
    if os.environ.get('WANDB_ENTITY') is not None:
        cfg_dict.wandb.update({'entity': os.environ.get('WANDB_ENTITY')})

    logger = WandbLogger(
        entity=cfg_dict.wandb.entity,
        project=cfg_dict.wandb.project,
        mode=cfg_dict.wandb.mode,
        name=run_name,
        tags=cfg_dict.wandb.get("tags", None),
        log_model=False,
        save_dir=cfg_dict.output_dir,
        config=OmegaConf.to_container(cfg_dict),
        **wandb_extra_kwargs,
    )

    if logger.experiment is not None:
        # Log code
        logger.experiment.log_code("optgs")
        # Log notes
        if cfg_dict.wandb.notes is not None:
            logger.experiment.notes = cfg_dict.wandb.notes
        # Write wandb run ID to file for SLURM requeue resume
        wandb_id_file = os.environ.get("WANDB_ID_FILE")
        if wandb_id_file:
            with open(wandb_id_file, "w") as f:
                f.write(logger.experiment.id)
            print(f"Wrote wandb run ID {logger.experiment.id} to {wandb_id_file}")

    return logger


def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None:
    if path is None:
        return None

    if not str(path).startswith("wandb://"):
        return Path(path)

    run_id, *version = path[len("wandb://") :].split(":")
    if len(version) == 0:
        version = None
    elif len(version) == 1:
        version = version[0]
    else:
        raise ValueError("Invalid version specifier!")

    project = wandb_cfg["project"]
    return download_checkpoint(
        f"{project}/{run_id}",
        Path("checkpoints"),
        version,
    )