# Source: sunf / flow_matching/train.py
# Uploaded by anhtunguyen98 ("Upload folder using huggingface_hub")
# Commit 4698bfc (verified)
"""
FlowMatchingTTS training entry point.

Mirrors cosyvoice/bin/train.py structure.

Single-GPU:
    python -m flow_matching.train --config config/flow_matching.yaml

Multi-GPU (torchrun):
    torchrun --nproc_per_node=8 -m flow_matching.train \\
        --config config/flow_matching.yaml
"""
from __future__ import print_function
import argparse
import datetime
import logging
import os
import sys
logging.getLogger('matplotlib').setLevel(logging.WARNING)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from omegaconf import OmegaConf
from flow_matching.utils.scheduler import WarmupLR, NoamHoldAnnealing
from flow_matching.dataset import TTSDataset, collate_fn
from flow_matching.model import FlowMatchingTTS
from flow_matching.speaker_encoder import SpeakerEncoder
from flow_matching.executor import Executor
def get_args(argv=None):
    """Parse the command-line options of a FlowMatchingTTS training run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process command line is parsed, which is the
            behaviour existing callers rely on.

    Returns:
        argparse.Namespace with the attributes ``config`` (required),
        ``checkpoint``, ``model_dir`` (both default ``None``) and
        ``timeout`` (int, default 30).
    """
    parser = argparse.ArgumentParser(description='Train FlowMatchingTTS')
    parser.add_argument('--config', required=True, help='YAML config file')
    parser.add_argument('--checkpoint', default=None, help='Resume from checkpoint')
    parser.add_argument('--model_dir', default=None, help='Override output dir')
    parser.add_argument('--timeout', default=30, type=int,
                        help='monitored_barrier timeout (seconds)')
    # parse_args(None) falls back to sys.argv[1:], so get_args() is unchanged.
    return parser.parse_args(argv)
def _conf_to_dict(section):
    """Return an optional OmegaConf section as a plain container, {} if absent."""
    # OmegaConf.to_container() only accepts config objects; handing it a plain
    # `{}` default raises ValueError, so the missing case is handled here.
    if section is None:
        return {}
    return OmegaConf.to_container(section)


def _make_dataset(data_cfg, paths):
    """Build a TTSDataset over `paths` using the audio settings in `data_cfg`."""
    return TTSDataset(
        dataset_path=paths,
        sample_rate=data_cfg.sample_rate,
        n_mels=data_cfg.n_mels,
        n_fft=data_cfg.n_fft,
        hop_length=data_cfg.hop_length,
        max_duration=data_cfg.max_duration,
    )


def _build_optimizer(model, train_cfg):
    """Create the optimizer selected by `train_cfg.optim` ('adam', else AdamW)."""
    optim_conf = _conf_to_dict(train_cfg.get('optim_conf', None))
    if not optim_conf:
        # No explicit optimizer section: fall back to the flat lr/weight_decay keys.
        optim_conf = {
            'lr': train_cfg.lr,
            'weight_decay': train_cfg.weight_decay,
        }
    # Any name other than 'adam' selects AdamW (same as the original branch).
    if train_cfg.get('optim', 'adamw') == 'adam':
        return torch.optim.Adam(model.parameters(), **optim_conf)
    return torch.optim.AdamW(model.parameters(), **optim_conf)


def _build_scheduler(optimizer, train_cfg, steps_per_epoch):
    """Create the LR scheduler named by `train_cfg.scheduler`.

    Raises:
        ValueError: if the scheduler name is not 'warmuplr' or
            'NoamHoldAnnealing'.
    """
    sched_name = train_cfg.get('scheduler', 'warmuplr')
    sched_conf = _conf_to_dict(train_cfg.get('scheduler_conf', None))
    if not sched_conf:
        sched_conf = {'warmup_steps': 4000}
    if sched_name == 'warmuplr':
        return WarmupLR(optimizer, **sched_conf)
    if sched_name == 'NoamHoldAnnealing':
        # NOTE(review): max_steps counts loader batches and ignores accum_grad;
        # confirm this matches how the executor steps the scheduler.
        total_steps = steps_per_epoch * train_cfg.num_epochs
        return NoamHoldAnnealing(optimizer, max_steps=total_steps, **sched_conf)
    raise ValueError('Unknown scheduler: {}'.format(sched_name))


@record
def main():
    """Train FlowMatchingTTS from a YAML config, single-process or under DDP.

    Reads WORLD_SIZE / LOCAL_RANK / RANK from the environment (defaults 1/0/0).
    When WORLD_SIZE > 1 an NCCL process group is initialised, the training set
    is sharded with a DistributedSampler and the model is wrapped in
    DistributedDataParallel. Rank 0 creates the output directory and the
    TensorBoard writer. Each epoch is delegated to Executor.train_one_epoch.

    Raises:
        ValueError: if `train.scheduler` names an unknown scheduler.
    """
    args = get_args()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(message)s',
    )
    cfg = OmegaConf.load(args.config)

    # ── distributed setup ─────────────────────────────────────────────────
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    distributed = world_size > 1
    logging.info(
        'rank {} / {} local_rank {}'.format(rank, world_size, local_rank)
    )
    if distributed:
        # Select this rank's GPU before the NCCL group is created.
        torch.cuda.set_device(local_rank)
        dist.init_process_group('nccl')

    # ── output directory ──────────────────────────────────────────────────
    model_dir = args.model_dir or cfg.data.output_dir
    if rank == 0:
        os.makedirs(model_dir, exist_ok=True)

    # ── datasets ──────────────────────────────────────────────────────────
    train_dataset = _make_dataset(cfg.data, list(cfg.data.dataset_paths))

    _val_cfg = cfg.data.get('val_dataset_paths', None)
    val_paths = OmegaConf.to_container(_val_cfg) if _val_cfg is not None else None
    cv_dataset = _make_dataset(cfg.data, val_paths) if val_paths else None

    train_sampler = DistributedSampler(train_dataset) if distributed else None
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.train.batch_size,
        sampler=train_sampler,
        # shuffle and sampler are mutually exclusive in DataLoader.
        shuffle=(train_sampler is None),
        collate_fn=collate_fn,
        num_workers=cfg.train.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    cv_loader = None
    if cv_dataset is not None:
        # No sampler here: every rank iterates the full validation set.
        cv_loader = DataLoader(
            cv_dataset,
            batch_size=cfg.train.batch_size,
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=cfg.train.num_workers,
            pin_memory=True,
        )

    # ── model ─────────────────────────────────────────────────────────────
    model = FlowMatchingTTS(cfg)
    if args.checkpoint is not None:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources. Only model weights are restored here.
        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
        if rank == 0:
            logging.info('Loaded checkpoint: {}'.format(args.checkpoint))
    model.cuda()
    if rank == 0:
        # Counted before DDP wrapping, so no `.module` indirection is needed.
        n_params = sum(p.numel() for p in model.parameters())
        logging.info('Model parameters: {:.1f}M'.format(n_params / 1e6))
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True
        )

    # ── frozen speaker encoder ─────────────────────────────────────────────
    spk_enc = SpeakerEncoder(
        cfg.data.campplus_ckpt,
        device='cuda:{}'.format(local_rank),
    ).cuda().eval()
    for p in spk_enc.parameters():
        p.requires_grad_(False)

    # ── optimizer / scheduler ─────────────────────────────────────────────
    optimizer = _build_optimizer(model, cfg.train)
    scheduler = _build_scheduler(optimizer, cfg.train, len(train_loader))

    # The config value, when present, takes precedence over --timeout.
    group_timeout = datetime.timedelta(
        seconds=int(cfg.train.get('timeout', args.timeout))
    )

    # ── shared info dict (mirrors cosyvoice train_conf) ───────────────────
    info_dict = {
        'model_dir': model_dir,
        'grad_clip': cfg.train.grad_clip,
        'accum_grad': int(cfg.train.get('accum_grad', 1)),
        'save_per_step': int(cfg.train.get('save_per_step', -1)),
        'log_interval': int(cfg.train.get('log_interval', 100)),
        'max_epoch': cfg.train.num_epochs,
        'batch_size': cfg.train.batch_size,
        'dtype': 'fp32',
        'train_engine': 'torch_ddp',
        'group_timeout': group_timeout,
    }

    # ── executor ──────────────────────────────────────────────────────────
    executor = Executor()

    # ── tensorboard (rank 0 only) ─────────────────────────────────────────
    writer = SummaryWriter(os.path.join(model_dir, 'tb')) if rank == 0 else None
    try:
        for epoch in range(cfg.train.num_epochs):
            executor.epoch = epoch
            if train_sampler is not None:
                # Reseed the sampler so each epoch gets a different shuffle.
                train_sampler.set_epoch(epoch)
            group_join = None
            if distributed:
                dist.barrier()
                # Fresh gloo group per epoch, passed to the executor.
                group_join = dist.new_group(backend='gloo', timeout=group_timeout)
            executor.train_one_epoch(
                model, optimizer, scheduler,
                train_loader, cv_loader,
                writer, info_dict,
                spk_enc, group_join,
            )
            if distributed:
                dist.destroy_process_group(group_join)
    finally:
        # Close the event file even when training raises.
        if writer is not None:
            writer.close()

    if distributed:
        dist.destroy_process_group()
# Start training only when run as a script (e.g. `python -m flow_matching.train`
# or via torchrun); importing the module has no training side effects.
if __name__ == "__main__":
    main()