# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#


import asyncio
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import hydra
import submitit
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict, read_write
from stopes.core import Requirements, StopesModule

from lcm.train.common import get_trainer
from lcm.utils.common import setup_conf

setup_conf()


class TrainModule(StopesModule):
    def requirements(self) -> Requirements:
        return self.config.requirements

    def run(self, iteration_value: Optional[Any] = None, iteration_index: int = 0):
        # Append this module's name to the config's log_folder
        with read_write(self.config):
            self.config.log_folder = Path(self.config.log_folder) / self.name()

        trainer = get_trainer(self.config)

        # The trainer returned by get_trainer is expected to expose a run() method
        trainer.run()

    def should_retry(
        self,
        ex: Exception,
        attempt: int,
        iteration_value: Optional[Any] = None,
        iteration_index: int = 0,
    ) -> bool:
        # Before retrying a failed training run, clean the environment so that
        # fairseq2's ProcessGroupGang can be set up again without raising an
        # error if the previous gang was not torn down reliably
        with submitit.helpers.clean_env():
            return "ValueError" not in str(ex)

    def name(self):
        """
        implement this if you want to give a fancy name to your job
        """
        name = self.config.get(
            "experiment_name", f"{self.__class__.__name__}_{self.sha_key()[:10]}"
        )
        return name


@dataclass
class TrainingConfig:
    trainer: DictConfig
    launcher: DictConfig
    dry_run: bool = False


async def run(config: TrainingConfig):
    # Dump the full resolved config to the config dump directory
    dump_dir = Path(config.launcher.config_dump_dir)
    dump_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.resolve(config)  # type: ignore
    # XXX: do we want to promote dataset configs from their names to the final params?
    OmegaConf.save(
        config=config,
        f=str(dump_dir / "all_config.yaml"),
    )

    train_config = config.trainer

    # If launcher.cluster == "debug", set debug=True in the trainer config
    with open_dict(train_config):
        if config.launcher.cluster == "debug":
            train_config.debug = True
        train_config.log_folder = config.launcher.log_folder

    if getattr(config, "dry_run", False):
        trainer = get_trainer(train_config)
        print(f"Trainer: {trainer}")
        print(f"Train config: {getattr(trainer, 'config')}")

        return

    launcher = hydra.utils.instantiate(config.launcher)

    train_module = TrainModule(train_config)
    wait_on = launcher.schedule(train_module)

    await wait_on


@hydra.main(
    version_base="1.2",
    config_path="../../recipes/train",
    config_name="defaults.yaml",
)
def main(config: TrainingConfig) -> None:
    """
    Launch a train module from the CLI.

    Example:

    ```sh
    python -m lcm.train +pretrain=mse
    ```

    In this example, `pretrain` is a folder under the `recipes` directory and `mse`
    is a yaml file with the trainer configuration.
    This yaml file must be in the `trainer` package (i.e. start with the `# @package trainer`
    hydra directive).
    It must contain a `__trainer__` entry defining the constructor for the trainer.
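
    A minimal sketch of such a recipe file (the constructor path and the fields
    other than `__trainer__` are illustrative placeholders, not actual recipe
    content):

    ```yaml
    # @package trainer
    # Hypothetical sketch: point __trainer__ at your trainer's constructor
    # and add the fields that constructor expects.
    __trainer__: my_package.trainers.prepare_my_trainer
    experiment_name: my_experiment
    ```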

    You can use `-c job` to print the resolved configuration without running anything, and
    `dry_run=true` to instantiate the trainer from the configuration (to check that it is
    valid) without launching the actual training. To debug jobs locally, use
    `launcher.cluster=debug`.
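
    For example, reusing the `mse` recipe above:

    ```sh
    # Print the resolved configuration without running anything
    python -m lcm.train +pretrain=mse -c job

    # Instantiate the trainer to validate the configuration, without training
    python -m lcm.train +pretrain=mse dry_run=true

    # Run the training locally through the debug launcher
    python -m lcm.train +pretrain=mse launcher.cluster=debug
    ```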
    """
    asyncio.run(run(config))


if __name__ == "__main__":
    main()