| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import os |
|
|
| import torch |
| from lightning.pytorch import Trainer |
| from omegaconf import OmegaConf |
|
|
| from nemo.collections.speechlm2 import SALM, DataModule, SALMDataset |
| from nemo.core.config import hydra_runner |
| from nemo.utils.exp_manager import exp_manager |
| from nemo.utils.trainer_utils import resolve_trainer_cfg |
|
|
# Bind this process to the CUDA device indexed by LOCAL_RANK before any CUDA
# work happens. LOCAL_RANK is presumably provided by a launcher such as
# torchrun; if it is absent, importing this module raises KeyError.
torch.cuda.set_device(device=int(os.environ["LOCAL_RANK"]))
|
|
|
|
@hydra_runner(config_path="conf", config_name="salm")
def train(cfg):
    """Train a SALM model from a Hydra-composed configuration.

    ``cfg`` is supplied by the ``@hydra_runner`` decorator (config ``conf/salm``
    plus any command-line overrides). It must contain ``trainer``, ``model``
    and ``data`` sections; ``exp_manager`` is optional.

    Args:
        cfg: The composed OmegaConf config. It is resolved in place so that the
            copy saved to the log directory contains no interpolations.
    """
    OmegaConf.resolve(cfg)

    # Only create the process group once; calling init_process_group on an
    # already-initialized default group raises.
    # NOTE(review): assumes the launcher (e.g. torchrun) has set the env://
    # rendezvous variables (MASTER_ADDR, RANK, WORLD_SIZE, ...) — confirm.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    # Trade float32 matmul precision for speed ("medium" may use bfloat16
    # internally where a fast kernel exists).
    torch.set_float32_matmul_precision("medium")

    trainer = Trainer(**resolve_trainer_cfg(cfg.trainer))

    # NeMo's exp_manager returns None when it is disabled (no config given, or
    # fast_dev_run) — in that case there is no log directory to save into.
    log_dir = exp_manager(trainer, cfg.get("exp_manager", None))
    if log_dir is not None:
        OmegaConf.save(cfg, log_dir / "exp_config.yaml")

    # Instantiate the model under init_module so Lightning controls the device
    # and dtype the parameters are created with.
    with trainer.init_module():
        model = SALM(OmegaConf.to_container(cfg.model, resolve=True))

    # The dataset and the datamodule both reuse the model's tokenizer.
    dataset = SALMDataset(tokenizer=model.tokenizer)
    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)

    trainer.fit(model, datamodule)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: called without arguments, because the @hydra_runner
    # decorator on `train` builds and passes the config itself.
    train()
|
|