# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

import asyncio
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import hydra
import submitit
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict, read_write
from stopes.core import Requirements, StopesModule

from lcm.train.common import get_trainer
from lcm.utils.common import setup_conf

setup_conf()


class TrainModule(StopesModule):
    def requirements(self) -> Requirements:
        return self.config.requirements

    def run(self, iteration_value: Optional[Any] = None, iteration_index: int = 0):
        # Append the module name to the config's log_folder
        with read_write(self.config):
            self.config.log_folder = Path(self.config.log_folder) / self.name()

        trainer = get_trainer(self.config)
        # The trainer is expected to expose a run() method
        trainer.run()

    def should_retry(
        self,
        ex: Exception,
        attempt: int,
        iteration_value: Optional[Any] = None,
        iteration_index: int = 0,
    ) -> bool:
        # Before retrying the failed train run, clean the environment to make sure
        # the fairseq2 ProcessGroupGang can be set up properly without raising an
        # error if the gang was not set up reliably.
        with submitit.helpers.clean_env():
            return "ValueError" not in str(ex)

    def name(self):
        """
        Implement this if you want to give a fancy name to your job.
        """
        name = self.config.get(
            "experiment_name", f"{self.__class__.__name__}_{self.sha_key()[:10]}"
        )
        return name
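

# A minimal sketch (illustrative only) of the interface that TrainModule.run and the
# dry-run path below assume the object returned by get_trainer provides: a run()
# method and a config attribute. The name TrainerLike is hypothetical, not part of lcm.
#
#     from typing import Protocol
#
#     class TrainerLike(Protocol):
#         config: DictConfig
#
#         def run(self) -> None: ...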


@dataclass
class TrainingConfig:
    trainer: DictConfig
    launcher: DictConfig
    dry_run: bool = False
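

# For orientation, a resolved TrainingConfig roughly follows the layout sketched
# below. This is an assumption pieced together from the fields above, from run()
# and from the main() docstring; the exact keys depend on the recipe, the launcher
# class and the fields of stopes.core.Requirements.
#
#     trainer:
#         __trainer__: ...        # constructor of the trainer (see main() docstring)
#         requirements:           # consumed by TrainModule.requirements()
#             nodes: 1
#             gpus_per_node: 8
#         ...                     # trainer-specific parameters
#     launcher:
#         _target_: ...           # launcher class, built via hydra.utils.instantiate
#         cluster: debug          # "debug" also sets trainer.debug = True
#         log_folder: ...
#         config_dump_dir: ...
#     dry_run: false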


async def run(config: TrainingConfig):
    # Dump the full config to the output config log
    dump_dir = Path(config.launcher.config_dump_dir)
    dump_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.resolve(config)  # type: ignore
    # XXX: do we want to promote dataset configs from their names to the final params
    OmegaConf.save(
        config=config,
        f=str(dump_dir / "all_config.yaml"),
    )

    train_config = config.trainer
    # If launcher.cluster == "debug", set debug to True in the trainer config
    with open_dict(train_config):
        if config.launcher.cluster == "debug":
            train_config.debug = True
        train_config.log_folder = config.launcher.log_folder

    if getattr(config, "dry_run", False):
        # Only instantiate the trainer to validate the configuration, then exit
        trainer = get_trainer(train_config)
        print(f"Trainer: {trainer}")
        print(f"Train config: {getattr(trainer, 'config')}")
        return

    launcher = hydra.utils.instantiate(config.launcher)
    train_module = TrainModule(train_config)
    wait_on = launcher.schedule(train_module)
    await wait_on


@hydra.main(
    version_base="1.2",
    config_path="../../recipes/train",
    config_name="defaults.yaml",
)
def main(config: TrainingConfig) -> None:
    """
    Launch a train module from the CLI.

    Example:

    ```sh
    python -m lcm.train +pretrain=mse
    ```

    In this example, `pretrain` is a folder under the `recipes` directory and `mse`
    is a yaml file with the trainer configuration. This yaml file must be in the
    `trainer` package (i.e. start with the `# @package trainer` hydra directive)
    and must contain a `__trainer__` entry defining the constructor for the trainer.

    You can use `-c job` to see the configuration without running anything. You can
    use `dry_run=true` to initialize the trainer from the configuration and make
    sure it is correct without running the actual training. To debug the jobs, you
    can use `launcher.cluster=debug`.
    """
    asyncio.run(run(config))


if __name__ == "__main__":
    main()
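

# Example invocations, based on the docstring above (the `+pretrain=mse` override
# assumes a `pretrain/mse.yaml` recipe exists under the recipes directory):
#
#     # inspect the composed configuration without running anything
#     python -m lcm.train +pretrain=mse -c job
#
#     # instantiate the trainer to validate the configuration, without training
#     python -m lcm.train +pretrain=mse dry_run=true
#
#     # run the job with the debug launcher
#     python -m lcm.train +pretrain=mse launcher.cluster=debug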