| """ |
| FlowMatchingTTS training entry point. |
| Mirrors cosyvoice/bin/train.py structure. |
| |
| Single-GPU: |
| python -m flow_matching.train --config config/flow_matching.yaml |
| |
| Multi-GPU (torchrun): |
| torchrun --nproc_per_node=8 -m flow_matching.train \\ |
| --config config/flow_matching.yaml |
| """ |
| from __future__ import print_function |
|
|
| import argparse |
| import datetime |
| import logging |
| import os |
| import sys |
|
|
# Only let WARNING-and-above records from the 'matplotlib' logger through.
_mpl_logger = logging.getLogger('matplotlib')
_mpl_logger.setLevel(logging.WARNING)

# Prepend the parent of this file's directory to sys.path; presumably this is
# what lets the `flow_matching.*` imports below resolve when the file is run
# directly rather than as an installed package.
_package_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_package_dir))
|
|
| import torch |
| import torch.distributed as dist |
| from torch.distributed.elastic.multiprocessing.errors import record |
| from torch.utils.data import DataLoader, DistributedSampler |
| from torch.utils.tensorboard import SummaryWriter |
| from omegaconf import OmegaConf |
|
|
| from flow_matching.utils.scheduler import WarmupLR, NoamHoldAnnealing |
| from flow_matching.dataset import TTSDataset, collate_fn |
| from flow_matching.model import FlowMatchingTTS |
| from flow_matching.speaker_encoder import SpeakerEncoder |
| from flow_matching.executor import Executor |
|
|
|
|
def get_args(argv=None):
    """Parse the command-line options of the training entry point.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour; passing a list allows programmatic use.

    Returns:
        argparse.Namespace with the attributes ``config`` (required),
        ``checkpoint`` (default ``None``), ``model_dir`` (default ``None``)
        and ``timeout`` (int, default ``30``).
    """
    parser = argparse.ArgumentParser(description='Train FlowMatchingTTS')
    parser.add_argument('--config', required=True, help='YAML config file')
    parser.add_argument('--checkpoint', default=None, help='Resume from checkpoint')
    parser.add_argument('--model_dir', default=None, help='Override output dir')
    parser.add_argument('--timeout', default=30, type=int,
                        help='monitored_barrier timeout (seconds)')
    return parser.parse_args(argv)
|
|
|
|
def _optional_section(section, key):
    """Return ``section[key]`` as a plain dict, or ``{}`` when it is missing/null.

    ``OmegaConf.to_container`` only accepts OmegaConf containers, so a plain
    ``{}`` fallback must never be routed through it.
    """
    node = section.get(key, None)
    if node is None:
        return {}
    return OmegaConf.to_container(node)


@record
def main():
    """Train FlowMatchingTTS on a single GPU or under torchrun DDP.

    Reads the CLI options via ``get_args`` and the YAML config via OmegaConf,
    builds the datasets/loaders, the model, a frozen speaker encoder, the
    optimizer and the LR scheduler, then calls ``Executor.train_one_epoch``
    once per epoch.

    Raises:
        ValueError: if ``cfg.train.scheduler`` is neither ``'warmuplr'`` nor
            ``'NoamHoldAnnealing'``.
    """
    args = get_args()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(message)s',
    )

    cfg = OmegaConf.load(args.config)

    # Distributed topology from the environment (set by torchrun); the
    # defaults describe a plain single-process run.
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))

    logging.info(
        'rank {} / {} local_rank {}'.format(rank, world_size, local_rank)
    )

    if world_size > 1:
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)

    # Output directory: the CLI override wins over the config value.
    model_dir = args.model_dir or cfg.data.output_dir
    if rank == 0:
        os.makedirs(model_dir, exist_ok=True)

    # Datasets: training is mandatory, validation only if paths are configured.
    train_dataset = TTSDataset(
        dataset_path=list(cfg.data.dataset_paths),
        sample_rate=cfg.data.sample_rate,
        n_mels=cfg.data.n_mels,
        n_fft=cfg.data.n_fft,
        hop_length=cfg.data.hop_length,
        max_duration=cfg.data.max_duration,
    )

    _val_cfg = cfg.data.get('val_dataset_paths', None)
    val_paths = OmegaConf.to_container(_val_cfg) if _val_cfg is not None else None
    cv_dataset = None
    if val_paths:
        cv_dataset = TTSDataset(
            dataset_path=val_paths,
            sample_rate=cfg.data.sample_rate,
            n_mels=cfg.data.n_mels,
            n_fft=cfg.data.n_fft,
            hop_length=cfg.data.hop_length,
            max_duration=cfg.data.max_duration,
        )

    # With a DistributedSampler the sampler owns shuffling, so the loader's
    # own shuffle flag is only enabled in the single-process case.
    train_sampler = DistributedSampler(train_dataset) if world_size > 1 else None
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.train.batch_size,
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        collate_fn=collate_fn,
        num_workers=cfg.train.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    cv_loader = None
    if cv_dataset is not None:
        # No sampler here: every rank iterates the full validation set.
        cv_loader = DataLoader(
            cv_dataset,
            batch_size=cfg.train.batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=cfg.train.num_workers,
            pin_memory=True,
        )

    # Model (optionally restored from a checkpoint before moving to the GPU).
    model = FlowMatchingTTS(cfg)

    if args.checkpoint is not None:
        # NOTE: torch.load unpickles the file; only resume from trusted
        # checkpoints.
        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
        if rank == 0:
            logging.info('Loaded checkpoint: {}'.format(args.checkpoint))

    model.cuda()

    if world_size > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True
        )

    if rank == 0:
        # DDP wraps the network in `.module`; count the underlying parameters.
        n_params = sum(p.numel() for p in
                       (model.module if world_size > 1 else model).parameters())
        logging.info('Model parameters: {:.1f}M'.format(n_params / 1e6))

    # Speaker encoder: eval mode, gradients disabled for all its parameters.
    spk_enc = SpeakerEncoder(
        cfg.data.campplus_ckpt,
        device='cuda:{}'.format(local_rank),
    ).cuda().eval()
    for p in spk_enc.parameters():
        p.requires_grad_(False)

    # Optimizer: an explicit `optim_conf` section wins; otherwise fall back to
    # the flat `lr` / `weight_decay` keys.
    optim_name = cfg.train.get('optim', 'adamw')
    optim_conf = _optional_section(cfg.train, 'optim_conf')
    if not optim_conf:
        optim_conf = {
            'lr': cfg.train.lr,
            'weight_decay': cfg.train.weight_decay,
        }

    # Any name other than 'adam' selects AdamW.
    if optim_name == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), **optim_conf)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), **optim_conf)

    # LR scheduler.
    sched_name = cfg.train.get('scheduler', 'warmuplr')
    sched_conf = _optional_section(cfg.train, 'scheduler_conf')
    if not sched_conf:
        sched_conf = {'warmup_steps': 4000}

    if sched_name == 'warmuplr':
        scheduler = WarmupLR(optimizer, **sched_conf)
    elif sched_name == 'NoamHoldAnnealing':
        # NOTE(review): max_steps counts loader batches; if the executor steps
        # the scheduler once per accumulated optimizer step, this may need
        # dividing by accum_grad - confirm against Executor.
        total_steps = len(train_loader) * cfg.train.num_epochs
        scheduler = NoamHoldAnnealing(
            optimizer, max_steps=total_steps, **sched_conf
        )
    else:
        raise ValueError('Unknown scheduler: {}'.format(sched_name))

    # TensorBoard logging happens on rank 0 only.
    writer = None
    if rank == 0:
        writer = SummaryWriter(os.path.join(model_dir, 'tb'))

    # The config value takes precedence over the --timeout CLI option.
    group_timeout = datetime.timedelta(
        seconds=int(cfg.train.get('timeout', args.timeout))
    )

    # Settings handed to the executor.
    info_dict = {
        'model_dir': model_dir,
        'grad_clip': cfg.train.grad_clip,
        'accum_grad': int(cfg.train.get('accum_grad', 1)),
        'save_per_step': int(cfg.train.get('save_per_step', -1)),
        'log_interval': int(cfg.train.get('log_interval', 100)),
        'max_epoch': cfg.train.num_epochs,
        'batch_size': cfg.train.batch_size,
        'dtype': 'fp32',
        'train_engine': 'torch_ddp',
        'group_timeout': group_timeout,
    }

    executor = Executor()

    try:
        for epoch in range(cfg.train.num_epochs):
            executor.epoch = epoch

            if train_sampler is not None:
                # Tell the sampler which epoch it is so shuffling differs
                # between epochs.
                train_sampler.set_epoch(epoch)

            if world_size > 1:
                dist.barrier()

            # A fresh gloo group per epoch, passed to the executor as
            # `group_join` (presumably for its monitored_barrier - see the
            # --timeout help text).
            group_join = (
                dist.new_group(backend='gloo', timeout=group_timeout)
                if world_size > 1 else None
            )

            executor.train_one_epoch(
                model, optimizer, scheduler,
                train_loader, cv_loader,
                writer, info_dict,
                spk_enc, group_join,
            )

            if world_size > 1:
                dist.destroy_process_group(group_join)
    finally:
        # Flush TensorBoard events even when training raises.
        if writer is not None:
            writer.close()

    if world_size > 1:
        dist.destroy_process_group()
|
|
|
|
if __name__ == '__main__':
    # Start training only when this file is the program entry point
    # (e.g. `python -m flow_matching.train` or torchrun), not on import.
    main()
|
|