""" Training executor for FlowMatchingTTS – mirrors cosyvoice/utils/executor.py. Key additions over cosyvoice's Executor: • _extract_speaker_emb: on-the-fly CAM++ extraction per batch • optional cv_loader (skip CV when no validation set) • single-GPU safe (_barrier / model_context guards) """ import logging import os from contextlib import nullcontext import torch import torch.distributed as dist import tqdm from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ # ── helpers ─────────────────────────────────────────────────────────────────── def _world_size() -> int: return int(os.environ.get('WORLD_SIZE', 1)) def _rank() -> int: return int(os.environ.get('RANK', 0)) def _barrier(): if _world_size() > 1: dist.barrier() def _extract_speaker_emb(batch: dict, spk_enc, device) -> dict: """Run CAM++ on wav_16k and store result as batch['embedding'].""" wav_16k = batch['wav_16k'].to(device) with torch.no_grad(): feats = spk_enc.fbank(wav_16k) # (B, T_frames, 80) batch['embedding'] = spk_enc(feats) # (B, 192) L2-normalised return batch def batch_forward(model, batch: dict, info_dict: dict) -> dict: device = int(os.environ.get('LOCAL_RANK', 0)) info_dict['loss_dict'] = model(batch, device) return info_dict def batch_backward(model, info_dict: dict) -> dict: accum_grad = info_dict.get('accum_grad', 1) loss = info_dict['loss_dict']['loss'] / accum_grad loss.backward() info_dict['loss_dict']['loss'] = loss return info_dict def update_parameter_and_lr(model, optimizer, scheduler, info_dict: dict) -> dict: grad_norm = 0.0 accum_grad = info_dict.get('accum_grad', 1) if (info_dict['batch_idx'] + 1) % accum_grad == 0: grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip']) if torch.isfinite(grad_norm): optimizer.step() optimizer.zero_grad() scheduler.step() info_dict['lr'] = optimizer.param_groups[0]['lr'] info_dict['grad_norm'] = float(grad_norm) return info_dict def log_per_step(writer, info_dict: dict): tag = info_dict['tag'] step = info_dict['step'] batch_idx = info_dict['batch_idx'] loss_dict = info_dict['loss_dict'] accum_grad = info_dict.get('accum_grad', 1) rank = _rank() if writer is not None and (batch_idx + 1) % accum_grad == 0: for k in ['epoch', 'lr', 'grad_norm']: writer.add_scalar(f'{tag}/{k}', info_dict.get(k, 0), step + 1) for k, v in loss_dict.items(): writer.add_scalar(f'{tag}/{k}', v, step + 1) if (batch_idx + 1) % info_dict.get('log_interval', 100) == 0: log_str = f'{tag} Epoch {info_dict["epoch"]} Batch {batch_idx + 1} ' for name, val in loss_dict.items(): log_str += f'{name} {float(val):.6f} ' if tag == 'TRAIN': log_str += (f'lr {info_dict["lr"]:.2e} ' f'gnorm {info_dict["grad_norm"]:.4f}') log_str += f' rank {rank}' logging.info(log_str) def log_per_save(writer, info_dict: dict): tag = info_dict['tag'] step = info_dict['step'] rank = _rank() loss_dict = info_dict['loss_dict'] logging.info( 'Epoch {} Step {} {} rank {} {}'.format( info_dict['epoch'], step + 1, tag, rank, ' '.join(f'{k}={v:.6f}' for k, v in loss_dict.items()), ) ) if writer is not None: for k in ['epoch', 'lr']: writer.add_scalar(f'{tag}/{k}', info_dict.get(k, 0), step + 1) for k, v in loss_dict.items(): writer.add_scalar(f'{tag}/{k}', v, step + 1) def save_model(model, model_name: str, info_dict: dict): rank = _rank() model_dir = info_dict['model_dir'] path = os.path.join(model_dir, f'{model_name}.pt') if rank == 0: m = model.module if isinstance(model, DDP) else model torch.save(m.state_dict(), path) logging.info(f'[Rank 0] Saved {path}') def cosyvoice_join(group_join, info_dict: dict) -> bool: """Return True when this rank should break out of the training loop due to uneven batch counts across DDP workers.""" if group_join is None or info_dict['batch_idx'] == 0: return False try: dist.monitored_barrier( group=group_join, timeout=info_dict.get('group_timeout'), ) return False except RuntimeError as e: logging.info( 'Uneven workload detected: {}\n' 'rank {}/{} local_rank {} breaking early.'.format( e, _rank(), _world_size(), int(os.environ.get('LOCAL_RANK', 0)), ) ) return True # ── Executor ────────────────────────────────────────────────────────────────── class Executor: def __init__(self): self.step = 0 self.epoch = 0 self.rank = _rank() self.device = torch.device( 'cuda:{}'.format(int(os.environ.get('LOCAL_RANK', 0))) ) def train_one_epoch( self, model, optimizer, scheduler, train_loader, cv_loader, writer, info_dict: dict, spk_enc, group_join, ): lr = optimizer.param_groups[0]['lr'] logging.info( 'Epoch {} TRAIN lr {:.2e} rank {}'.format(self.epoch, lr, self.rank) ) logging.info( 'Gradient accumulation: effective batch = {} × {}'.format( info_dict['batch_size'], info_dict.get('accum_grad', 1) ) ) model.train() accum_grad = info_dict.get('accum_grad', 1) save_per_step = info_dict.get('save_per_step', -1) model_context = model.join if isinstance(model, DDP) else nullcontext with model_context(): for batch_idx, batch in enumerate(tqdm.tqdm(train_loader)): info_dict['tag'] = 'TRAIN' info_dict['step'] = self.step info_dict['epoch'] = self.epoch info_dict['batch_idx'] = batch_idx if cosyvoice_join(group_join, info_dict): break # Frozen speaker encoder: extract embeddings on GPU batch = _extract_speaker_emb(batch, spk_enc, self.device) # Delay DDP gradient sync until the last accumulation step if isinstance(model, DDP) and (batch_idx + 1) % accum_grad != 0: sync_ctx = model.no_sync else: sync_ctx = nullcontext with sync_ctx(): info_dict = batch_forward(model, batch, info_dict) info_dict = batch_backward(model, info_dict) info_dict = update_parameter_and_lr( model, optimizer, scheduler, info_dict ) log_per_step(writer, info_dict) # Mid-epoch checkpoint + CV if (save_per_step > 0 and (self.step + 1) % save_per_step == 0 and (batch_idx + 1) % accum_grad == 0): _barrier() self.cv( model, cv_loader, writer, info_dict, spk_enc, model_name=f'epoch_{self.epoch}_step_{self.step + 1}', on_batch_end=False, ) model.train() if (batch_idx + 1) % accum_grad == 0: self.step += 1 _barrier() self.cv( model, cv_loader, writer, info_dict, spk_enc, model_name=f'epoch_{self.epoch}_whole', on_batch_end=True, ) @torch.inference_mode() def cv( self, model, cv_loader, writer, info_dict: dict, spk_enc, model_name: str = 'model', on_batch_end: bool = True, ): logging.info( 'Epoch {} Step {} on_batch_end={} CV rank {}'.format( self.epoch, self.step + 1, on_batch_end, self.rank ) ) model.eval() total_utts = 0 total_loss: dict = {} if cv_loader is not None: for batch_idx, batch in enumerate(cv_loader): info_dict['tag'] = 'CV' info_dict['step'] = self.step info_dict['epoch'] = self.epoch info_dict['batch_idx'] = batch_idx batch = _extract_speaker_emb(batch, spk_enc, self.device) num_utts = batch['mel'].shape[0] total_utts += num_utts info_dict = batch_forward(model, batch, info_dict) for k, v in info_dict['loss_dict'].items(): total_loss.setdefault(k, []).append(float(v) * num_utts) for k in total_loss: total_loss[k] = sum(total_loss[k]) / max(total_utts, 1) info_dict['loss_dict'] = total_loss log_per_save(writer, info_dict) save_model(model, model_name, info_dict)