""" FlowMatchingTTS training entry point. Mirrors cosyvoice/bin/train.py structure. Single-GPU: python -m flow_matching.train --config config/flow_matching.yaml Multi-GPU (torchrun): torchrun --nproc_per_node=8 -m flow_matching.train \\ --config config/flow_matching.yaml """ from __future__ import print_function import argparse import datetime import logging import os import sys logging.getLogger('matplotlib').setLevel(logging.WARNING) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import torch import torch.distributed as dist from torch.distributed.elastic.multiprocessing.errors import record from torch.utils.data import DataLoader, DistributedSampler from torch.utils.tensorboard import SummaryWriter from omegaconf import OmegaConf from flow_matching.utils.scheduler import WarmupLR, NoamHoldAnnealing from flow_matching.dataset import TTSDataset, collate_fn from flow_matching.model import FlowMatchingTTS from flow_matching.speaker_encoder import SpeakerEncoder from flow_matching.executor import Executor def get_args(): parser = argparse.ArgumentParser(description='Train FlowMatchingTTS') parser.add_argument('--config', required=True, help='YAML config file') parser.add_argument('--checkpoint', default=None, help='Resume from checkpoint') parser.add_argument('--model_dir', default=None, help='Override output dir') parser.add_argument('--timeout', default=30, type=int, help='monitored_barrier timeout (seconds)') return parser.parse_args() @record def main(): args = get_args() logging.basicConfig( level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s', ) cfg = OmegaConf.load(args.config) # ── distributed setup ───────────────────────────────────────────────── world_size = int(os.environ.get('WORLD_SIZE', 1)) local_rank = int(os.environ.get('LOCAL_RANK', 0)) rank = int(os.environ.get('RANK', 0)) logging.info( 'rank {} / {} local_rank {}'.format(rank, world_size, local_rank) ) if world_size > 1: dist.init_process_group('nccl') torch.cuda.set_device(local_rank) # ── output directory ────────────────────────────────────────────────── model_dir = args.model_dir or cfg.data.output_dir if rank == 0: os.makedirs(model_dir, exist_ok=True) # ── datasets ────────────────────────────────────────────────────────── train_dataset = TTSDataset( dataset_path=list(cfg.data.dataset_paths), sample_rate=cfg.data.sample_rate, n_mels=cfg.data.n_mels, n_fft=cfg.data.n_fft, hop_length=cfg.data.hop_length, max_duration=cfg.data.max_duration, ) _val_cfg = cfg.data.get('val_dataset_paths', None) val_paths = OmegaConf.to_container(_val_cfg) if _val_cfg is not None else None cv_dataset = None if val_paths: cv_dataset = TTSDataset( dataset_path=val_paths, sample_rate=cfg.data.sample_rate, n_mels=cfg.data.n_mels, n_fft=cfg.data.n_fft, hop_length=cfg.data.hop_length, max_duration=cfg.data.max_duration, ) train_sampler = DistributedSampler(train_dataset) if world_size > 1 else None train_loader = DataLoader( train_dataset, batch_size=cfg.train.batch_size, sampler=train_sampler, shuffle=(train_sampler is None), collate_fn=collate_fn, num_workers=cfg.train.num_workers, pin_memory=True, drop_last=True, ) cv_loader = None if cv_dataset is not None: cv_loader = DataLoader( cv_dataset, batch_size=cfg.train.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=cfg.train.num_workers, pin_memory=True, ) # ── model ───────────────────────────────────────────────────────────── model = FlowMatchingTTS(cfg) if args.checkpoint is not None: model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')) if rank == 0: logging.info('Loaded checkpoint: {}'.format(args.checkpoint)) model.cuda() if world_size > 1: model = torch.nn.parallel.DistributedDataParallel( model, find_unused_parameters=True ) if rank == 0: n_params = sum(p.numel() for p in (model.module if world_size > 1 else model).parameters()) logging.info('Model parameters: {:.1f}M'.format(n_params / 1e6)) # ── frozen speaker encoder ───────────────────────────────────────────── spk_enc = SpeakerEncoder( cfg.data.campplus_ckpt, device='cuda:{}'.format(local_rank), ).cuda().eval() for p in spk_enc.parameters(): p.requires_grad_(False) # ── optimizer ───────────────────────────────────────────────────────── optim_name = cfg.train.get('optim', 'adamw') optim_conf = OmegaConf.to_container(cfg.train.get('optim_conf', {})) if not optim_conf: optim_conf = { 'lr': cfg.train.lr, 'weight_decay': cfg.train.weight_decay, } if optim_name == 'adam': optimizer = torch.optim.Adam(model.parameters(), **optim_conf) else: optimizer = torch.optim.AdamW(model.parameters(), **optim_conf) # ── scheduler ───────────────────────────────────────────────────────── sched_name = cfg.train.get('scheduler', 'warmuplr') sched_conf = OmegaConf.to_container(cfg.train.get('scheduler_conf', {})) if not sched_conf: sched_conf = {'warmup_steps': 4000} if sched_name == 'warmuplr': scheduler = WarmupLR(optimizer, **sched_conf) elif sched_name == 'NoamHoldAnnealing': total_steps = len(train_loader) * cfg.train.num_epochs scheduler = NoamHoldAnnealing( optimizer, max_steps=total_steps, **sched_conf ) else: raise ValueError('Unknown scheduler: {}'.format(sched_name)) # ── tensorboard ─────────────────────────────────────────────────────── writer = None if rank == 0: writer = SummaryWriter(os.path.join(model_dir, 'tb')) # ── shared info dict (mirrors cosyvoice train_conf) ─────────────────── info_dict = { 'model_dir': model_dir, 'grad_clip': cfg.train.grad_clip, 'accum_grad': int(cfg.train.get('accum_grad', 1)), 'save_per_step': int(cfg.train.get('save_per_step', -1)), 'log_interval': int(cfg.train.get('log_interval', 100)), 'max_epoch': cfg.train.num_epochs, 'batch_size': cfg.train.batch_size, 'dtype': 'fp32', 'train_engine': 'torch_ddp', 'group_timeout': datetime.timedelta(seconds=int(cfg.train.get('timeout', args.timeout))), } # ── executor ────────────────────────────────────────────────────────── executor = Executor() for epoch in range(cfg.train.num_epochs): executor.epoch = epoch if train_sampler is not None: train_sampler.set_epoch(epoch) if world_size > 1: dist.barrier() timeout = datetime.timedelta(seconds=int(cfg.train.get('timeout', args.timeout))) group_join = ( dist.new_group(backend='gloo', timeout=timeout) if world_size > 1 else None ) executor.train_one_epoch( model, optimizer, scheduler, train_loader, cv_loader, writer, info_dict, spk_enc, group_join, ) if world_size > 1: dist.destroy_process_group(group_join) if writer is not None: writer.close() if world_size > 1: dist.destroy_process_group() if __name__ == '__main__': main()