|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
|
|
import torch |
|
|
from lightning.pytorch import Trainer |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
from nemo.collections.speechlm2 import DataModule, DuplexS2SDataset, DuplexS2SSpeechDecoderModel |
|
|
from nemo.core.config import hydra_runner |
|
|
from nemo.utils.exp_manager import exp_manager |
|
|
from nemo.utils.trainer_utils import resolve_trainer_cfg |
|
|
|
|
|
# Bind this process to the CUDA device indexed by its node-local rank.
# This runs at import time. LOCAL_RANK is expected to be exported by a
# distributed launcher (e.g. torchrun); if it is missing, the lookup below
# raises KeyError before any config is parsed.
torch.cuda.set_device(device=int(os.environ["LOCAL_RANK"]))
|
|
|
|
|
|
|
|
@hydra_runner(config_path="conf", config_name="s2s_duplex_speech_decoder")
def train(cfg):
    """Train a DuplexS2SSpeechDecoderModel from a Hydra/OmegaConf config.

    The config must provide ``trainer``, ``model`` and ``data`` sections, and may
    provide ``exp_manager``. The ``data`` section supplies ``frame_length``,
    ``source_sample_rate``, ``target_sample_rate``, ``input_roles`` and
    ``output_roles`` for the dataset, and is also handed whole to ``DataModule``.

    Args:
        cfg: The config object injected by ``hydra_runner``.
    """
    # Resolve interpolations in place so the saved exp_config.yaml and every
    # consumer below see concrete values.
    OmegaConf.resolve(cfg)

    # Create the default NCCL process group only once: calling
    # init_process_group a second time raises (e.g. when train() runs more than
    # once in the same process, or the group was already created elsewhere).
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")

    # Trade float32 precision for speed: "medium" lets float32 matmuls use
    # reduced-precision internal math, and TF32 is enabled for cuDNN.
    torch.set_float32_matmul_precision("medium")
    torch.backends.cudnn.allow_tf32 = True

    trainer = Trainer(**resolve_trainer_cfg(cfg.trainer))

    log_dir = exp_manager(trainer, cfg.get("exp_manager", None))
    # exp_manager returns an optional path; it can be None (e.g. when no
    # exp_manager section is configured), in which case there is no log
    # directory to store the resolved config in.
    if log_dir is not None:
        OmegaConf.save(cfg, log_dir / "exp_config.yaml")

    # Instantiate the model under the trainer's init context so its tensors are
    # created according to the trainer's device/precision settings.
    with trainer.init_module():
        model = DuplexS2SSpeechDecoderModel(OmegaConf.to_container(cfg.model, resolve=True))

    # The dataset and data module share the model's tokenizer.
    dataset = DuplexS2SDataset(
        tokenizer=model.tokenizer,
        frame_length=cfg.data.frame_length,
        source_sample_rate=cfg.data.source_sample_rate,
        target_sample_rate=cfg.data.target_sample_rate,
        input_roles=cfg.data.input_roles,
        output_roles=cfg.data.output_roles,
    )
    datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)

    trainer.fit(model, datamodule)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # train() is called without arguments: the hydra_runner decorator is
    # responsible for building ``cfg`` (from conf/s2s_duplex_speech_decoder and
    # any overrides) and passing it in.
    train()
|
|
|