Commit b5a0bec (Lexa): Converted .pt files to safetensors, then (dirtily) patched fairseq to enable loading of safetensor files
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#

import asyncio
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import hydra
import submitit
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict, read_write
from stopes.core import Requirements, StopesModule

from lcm.train.common import get_trainer
from lcm.utils.common import setup_conf

setup_conf()
class TrainModule(StopesModule):
    def requirements(self) -> Requirements:
        return self.config.requirements

    def run(self, iteration_value: Optional[Any] = None, iteration_index: int = 0):
        # Add module.name to the config's log_folder
        with read_write(self.config):
            self.config.log_folder = Path(self.config.log_folder) / self.name()
        trainer = get_trainer(self.config)
        # trainer should have a run() method
        trainer.run()
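
    # For reference, a minimal sketch of the trainer interface that
    # get_trainer() is assumed to return here (illustrative, not from the
    # source; the `config` attribute is implied by the dry_run branch below):
    #
    #   class Trainer(Protocol):
    #       config: DictConfig
    #       def run(self) -> None: ...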
    def should_retry(
        self,
        ex: Exception,
        attempt: int,
        iteration_value: Optional[Any] = None,
        iteration_index: int = 0,
    ) -> bool:
        # Before retrying the failed train run, clean the environment so that
        # the fairseq2 ProcessGroupGang can be set up again without raising an
        # error if the gang was not torn down reliably
        with submitit.helpers.clean_env():
            return "ValueError" not in str(ex)
    def name(self):
        """
        Override this if you want to give a custom name to your job.
        """
        return self.config.get(
            "experiment_name", f"{self.__class__.__name__}_{self.sha_key()[:10]}"
        )
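
    # e.g. without an `experiment_name` in the config, this yields something
    # like "TrainModule_0a1b2c3d4e" (hash shown is illustrative).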
@dataclass
class TrainingConfig:
    trainer: DictConfig
    launcher: DictConfig
    dry_run: bool = False
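
# For illustration, a config matching this schema might look like the sketch
# below; the launcher keys are the ones this module reads, but the trainer
# payload and exact values are assumptions, not from the source:
#
#   trainer:
#     __trainer__: <constructor for the trainer>
#     requirements: ...
#   launcher:
#     cluster: debug
#     log_folder: /tmp/lcm_logs
#     config_dump_dir: /tmp/lcm_logs/config
#   dry_run: false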
async def run(config: TrainingConfig):
    # Dump the full config to the output config log
    dump_dir = Path(config.launcher.config_dump_dir)
    dump_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.resolve(config)  # type: ignore
    # XXX: do we want to promote dataset configs from their names to the final params?
    OmegaConf.save(
        config=config,
        f=str(dump_dir / "all_config.yaml"),
    )
    train_config = config.trainer
    # If launcher.cluster == "debug", set debug to True in the trainer config
    with open_dict(train_config):
        if config.launcher.cluster == "debug":
            train_config.debug = True
        train_config.log_folder = config.launcher.log_folder
    if getattr(config, "dry_run", False):
        trainer = get_trainer(train_config)
        print(f"Trainer: {trainer}")
        print(f"Train config: {getattr(trainer, 'config')}")
        return
    launcher = hydra.utils.instantiate(config.launcher)
    train_module = TrainModule(train_config)
    wait_on = launcher.schedule(train_module)
    await wait_on
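
# Note: `launcher.schedule(...)` returns an awaitable; presumably it resolves
# when the underlying (e.g. submitit/SLURM) job completes, and runs the module
# inline when `launcher.cluster=debug` is used.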
# NOTE: `main()` is called without arguments below, so this entry point needs a
# `@hydra.main` decorator to build the config; the config_path/config_name
# values here are assumptions, adjust them to the repo's recipes layout.
@hydra.main(version_base=None, config_path="../../recipes", config_name="train")
def main(config: TrainingConfig) -> None:
    """
    Launch a train module from the CLI.

    Example:
    ```sh
    python -m lcm.train +pretrain=mse
    ```
    In this example, `pretrain` is a folder under the `recipes` directory and
    `mse` is a yaml file with the trainer configuration.
    This yaml file must be in the `trainer` package (i.e. start with the
    `# @package trainer` hydra directive) and must contain a `__trainer__` entry
    defining the constructor for the trainer.

    You can use `-c job` to see the configuration without running anything, and
    `dry_run=true` to initialize the trainer from the configuration and make
    sure it is correct without running the actual training. To debug jobs, use
    `launcher.cluster=debug`.
    """
    asyncio.run(run(config))
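
# For illustration, a recipe matching the docstring's example could live at
# `recipes/pretrain/mse.yaml`; the `__trainer__` target and the trainer fields
# are assumptions, only the two required elements come from the docstring:
#
#   # @package trainer
#   __trainer__: <dotted path to the trainer constructor>
#   requirements:
#     ...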
if __name__ == "__main__":
    main()