"""
Training executor for FlowMatchingTTS — mirrors cosyvoice/utils/executor.py.

Key additions over cosyvoice's Executor:
  • _extract_speaker_emb: on-the-fly CAM++ extraction per batch
  • optional cv_loader (skip CV when no validation set)
  • single-GPU safe (_barrier / model_context guards)
"""
| import logging |
| import os |
| from contextlib import nullcontext |
|
|
| import torch |
| import torch.distributed as dist |
| import tqdm |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.nn.utils import clip_grad_norm_ |
|
|
|
|
| |
|
|
def _world_size() -> int:
    """Return the worker count read from the ``WORLD_SIZE`` environment variable.

    Falls back to ``1`` when the variable is not set.
    """
    raw_value = os.environ.get('WORLD_SIZE', 1)
    return int(raw_value)
|
|
|
|
def _rank() -> int:
    """Return this process's rank read from the ``RANK`` environment variable.

    Falls back to ``0`` when the variable is not set.
    """
    raw_value = os.environ.get('RANK', 0)
    return int(raw_value)
|
|
|
|
def _barrier():
    """Synchronise all workers, or do nothing when that is not possible.

    The barrier is skipped when ``torch.distributed`` is unavailable, when no
    default process group has been initialised, or when ``WORLD_SIZE`` is 1.
    Checking the process group first keeps single-process runs safe even if
    ``WORLD_SIZE`` happens to be exported in the environment, because
    ``dist.barrier()`` raises without an initialised group.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if _world_size() > 1:
        dist.barrier()
|
|
|
|
def _extract_speaker_emb(batch: dict, spk_enc, device) -> dict:
    """Compute a speaker embedding for the batch and attach it in place.

    Args:
        batch: Mapping holding the waveform under ``'wav_16k'``; it is mutated
            by adding an ``'embedding'`` entry.
        spk_enc: Speaker encoder exposing ``fbank(wav)`` and being callable on
            the resulting features (CAM++ per the module docstring).
        device: Target passed to ``.to()`` for the waveform.

    Returns:
        The same ``batch`` object, now containing ``batch['embedding']``.
    """
    waveform = batch['wav_16k'].to(device)
    # The encoder is only used as a feature extractor here, so no autograd
    # graph is recorded for it.
    with torch.no_grad():
        embedding = spk_enc(spk_enc.fbank(waveform))
    batch['embedding'] = embedding
    return batch
|
|
|
|
def batch_forward(model, batch: dict, info_dict: dict) -> dict:
    """Run one forward pass and record the model's output in ``info_dict``.

    The model is invoked as ``model(batch, device)`` where ``device`` is the
    integer value of the ``LOCAL_RANK`` environment variable (``0`` when
    unset). Whatever the model returns is stored under
    ``info_dict['loss_dict']``.

    Returns:
        The same ``info_dict`` object, updated in place.
    """
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    loss_dict = model(batch, local_rank)
    info_dict['loss_dict'] = loss_dict
    return info_dict
|
|
|
|
def batch_backward(model, info_dict: dict) -> dict:
    """Back-propagate the current loss, scaled for gradient accumulation.

    The loss at ``info_dict['loss_dict']['loss']`` is divided by
    ``info_dict['accum_grad']`` (default ``1``), ``backward()`` is called on
    the scaled value, and the scaled value replaces the stored loss.

    Args:
        model: Unused; kept so the signature matches the other step helpers.
        info_dict: Step state holding ``'loss_dict'`` and optionally
            ``'accum_grad'``.

    Returns:
        The same ``info_dict`` object, updated in place.
    """
    accum_steps = info_dict.get('accum_grad', 1)
    loss_dict = info_dict['loss_dict']
    scaled_loss = loss_dict['loss'] / accum_steps
    scaled_loss.backward()
    loss_dict['loss'] = scaled_loss
    return info_dict
|
|
|
|
def update_parameter_and_lr(model, optimizer, scheduler, info_dict: dict) -> dict:
    """Apply an optimizer/scheduler step on gradient-accumulation boundaries.

    A boundary is reached when ``(batch_idx + 1)`` is a multiple of
    ``info_dict['accum_grad']`` (default ``1``). On a boundary the gradients
    are clipped to ``info_dict['grad_clip']``, the optimizer steps only if the
    resulting norm is finite, gradients are cleared, and the scheduler steps.
    Off a boundary nothing is updated and the reported norm is ``0.0``.

    Returns:
        The same ``info_dict`` with ``'lr'`` (first param group's learning
        rate) and ``'grad_norm'`` (as a Python float) filled in.
    """
    accum_steps = info_dict.get('accum_grad', 1)
    on_boundary = (info_dict['batch_idx'] + 1) % accum_steps == 0

    total_norm = 0.0
    if on_boundary:
        total_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
        # A NaN/Inf norm means the update would corrupt the weights, so the
        # optimizer step is skipped; gradients are still cleared and the
        # scheduler still advances.
        if torch.isfinite(total_norm):
            optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    info_dict['lr'] = optimizer.param_groups[0]['lr']
    info_dict['grad_norm'] = float(total_norm)
    return info_dict
|
|
|
|
def log_per_step(writer, info_dict: dict):
    """Emit per-batch scalars to ``writer`` and a periodic text log line.

    Scalars (``epoch``, ``lr``, ``grad_norm`` and every entry of
    ``info_dict['loss_dict']``) are written only when a writer is given and
    the batch ends a gradient-accumulation window. A text line is logged every
    ``info_dict['log_interval']`` batches (default ``100``); for the
    ``'TRAIN'`` tag it also includes the learning rate and gradient norm.
    """
    tag = info_dict['tag']
    step = info_dict['step']
    batch_idx = info_dict['batch_idx']
    loss_dict = info_dict['loss_dict']
    accum_grad = info_dict.get('accum_grad', 1)
    rank = _rank()
    batches_seen = batch_idx + 1

    if writer is not None and batches_seen % accum_grad == 0:
        global_step = step + 1
        for key in ('epoch', 'lr', 'grad_norm'):
            writer.add_scalar(f'{tag}/{key}', info_dict.get(key, 0), global_step)
        for key, value in loss_dict.items():
            writer.add_scalar(f'{tag}/{key}', value, global_step)

    if batches_seen % info_dict.get('log_interval', 100) != 0:
        return

    # Assemble the line from fragments and join once.
    fragments = [f'{tag} Epoch {info_dict["epoch"]} Batch {batches_seen} ']
    fragments.extend(
        f'{name} {float(val):.6f} ' for name, val in loss_dict.items()
    )
    if tag == 'TRAIN':
        fragments.append(
            f'lr {info_dict["lr"]:.2e} '
            f'gnorm {info_dict["grad_norm"]:.4f}'
        )
    fragments.append(f' rank {rank}')
    logging.info(''.join(fragments))
|
|
|
|
def log_per_save(writer, info_dict: dict):
    """Log a summary line and scalars at checkpoint / validation time.

    The text line reports epoch, ``step + 1``, tag, rank and every entry of
    ``info_dict['loss_dict']`` formatted to six decimals. When ``writer`` is
    given, ``epoch``, ``lr`` (``0`` if absent) and each loss are also written
    as scalars under ``<tag>/<name>`` at ``step + 1``.
    """
    tag = info_dict['tag']
    global_step = info_dict['step'] + 1
    rank = _rank()
    losses = info_dict['loss_dict']

    loss_summary = ' '.join(f'{k}={v:.6f}' for k, v in losses.items())
    logging.info(
        'Epoch {} Step {} {} rank {} {}'.format(
            info_dict['epoch'], global_step, tag, rank, loss_summary,
        )
    )

    if writer is None:
        return
    for key in ('epoch', 'lr'):
        writer.add_scalar(f'{tag}/{key}', info_dict.get(key, 0), global_step)
    for key, value in losses.items():
        writer.add_scalar(f'{tag}/{key}', value, global_step)
|
|
|
|
def save_model(model, model_name: str, info_dict: dict):
    """Save the model's ``state_dict`` to ``<model_dir>/<model_name>.pt``.

    Only rank 0 writes; other ranks return without touching the filesystem.
    A ``DistributedDataParallel`` wrapper is unwrapped via ``.module`` before
    its state dict is taken.

    Args:
        model: The model, optionally wrapped in ``DistributedDataParallel``.
        model_name: File stem of the checkpoint (``.pt`` is appended).
        info_dict: Must contain ``'model_dir'``, the output directory.
    """
    rank = _rank()
    model_dir = info_dict['model_dir']
    if rank != 0:
        return

    # torch.save does not create missing parent directories, so make sure the
    # target exists. An empty model_dir means the current directory.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    path = os.path.join(model_dir, f'{model_name}.pt')
    m = model.module if isinstance(model, DDP) else model
    torch.save(m.state_dict(), path)
    logging.info(f'[Rank 0] Saved {path}')
|
|
|
|
def cosyvoice_join(group_join, info_dict: dict) -> bool:
    """Tell the caller whether this rank should leave the training loop.

    Uses ``dist.monitored_barrier`` on ``group_join`` to detect uneven batch
    counts across DDP workers. The check is skipped (returning ``False``) when
    no join group is given or on the first batch (``batch_idx == 0``).

    Args:
        group_join: Process group used for the monitored barrier, or ``None``.
        info_dict: Step state; reads ``'batch_idx'`` and the optional
            ``'group_timeout'`` passed as the barrier timeout.

    Returns:
        ``True`` if the barrier raised ``RuntimeError`` (logged as an uneven
        workload), ``False`` otherwise.
    """
    if group_join is None:
        return False
    if info_dict['batch_idx'] == 0:
        return False

    try:
        dist.monitored_barrier(
            group=group_join,
            timeout=info_dict.get('group_timeout'),
        )
    except RuntimeError as err:
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        logging.info(
            'Uneven workload detected: {}\n'
            'rank {}/{} local_rank {} breaking early.'.format(
                err, _rank(), _world_size(), local_rank,
            )
        )
        return True
    return False
|
|
|
|
| |
|
|
class Executor:
    """Drive one epoch of training plus cross-validation and checkpointing.

    Instance state:
        step:   number of completed gradient-accumulation windows; advanced by
                ``train_one_epoch``.
        epoch:  epoch index used in logs and checkpoint names; initialised to
                ``0`` and never advanced inside this class (presumably the
                caller updates it between epochs).
        rank:   value of ``_rank()`` at construction time.
        device: ``cuda:<LOCAL_RANK>`` device handed to the speaker encoder
                helper.
    """

    def __init__(self):
        self.step = 0
        self.epoch = 0
        self.rank = _rank()
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        self.device = torch.device('cuda:{}'.format(local_rank))

    def train_one_epoch(
        self,
        model, optimizer, scheduler,
        train_loader, cv_loader,
        writer, info_dict: dict,
        spk_enc, group_join,
    ):
        """Train for one pass over ``train_loader``, then validate and save.

        Args:
            model: Model to train, optionally wrapped in DDP.
            optimizer: Optimizer stepped on accumulation boundaries.
            scheduler: LR scheduler stepped on accumulation boundaries.
            train_loader: Iterable of training batches (dicts).
            cv_loader: Iterable of validation batches, or ``None`` to skip
                validation metrics.
            writer: Scalar writer exposing ``add_scalar``, or ``None``.
            info_dict: Mutable step state. Reads ``'batch_size'`` (required,
                log only), ``'accum_grad'`` (default ``1``) and
                ``'save_per_step'`` (default ``-1``, meaning no mid-epoch
                checkpoints); other keys are consumed by the step helpers.
            spk_enc: Speaker encoder passed to ``_extract_speaker_emb``.
            group_join: Process group for uneven-workload detection, or
                ``None``.
        """
        accum_grad = info_dict.get('accum_grad', 1)
        save_per_step = info_dict.get('save_per_step', -1)
        is_ddp = isinstance(model, DDP)

        lr = optimizer.param_groups[0]['lr']
        logging.info(
            'Epoch {} TRAIN lr {:.2e} rank {}'.format(self.epoch, lr, self.rank)
        )
        logging.info(
            'Gradient accumulation: effective batch = {} × {}'.format(
                info_dict['batch_size'], accum_grad
            )
        )
        model.train()

        # DDP's join context handles ranks that run out of data early; a
        # plain module needs no such guard.
        model_context = model.join if is_ddp else nullcontext
        with model_context():
            for batch_idx, batch in enumerate(tqdm.tqdm(train_loader)):
                info_dict['tag'] = 'TRAIN'
                info_dict['step'] = self.step
                info_dict['epoch'] = self.epoch
                info_dict['batch_idx'] = batch_idx
                on_boundary = (batch_idx + 1) % accum_grad == 0

                if cosyvoice_join(group_join, info_dict):
                    break

                batch = _extract_speaker_emb(batch, spk_enc, self.device)

                # Inside an accumulation window, skip DDP gradient all-reduce;
                # it only needs to happen on the batch that triggers a step.
                if is_ddp and not on_boundary:
                    sync_ctx = model.no_sync
                else:
                    sync_ctx = nullcontext

                with sync_ctx():
                    info_dict = batch_forward(model, batch, info_dict)
                    info_dict = batch_backward(model, info_dict)

                info_dict = update_parameter_and_lr(
                    model, optimizer, scheduler, info_dict
                )
                log_per_step(writer, info_dict)

                # Mid-epoch checkpoint every `save_per_step` optimizer steps.
                if (save_per_step > 0
                        and (self.step + 1) % save_per_step == 0
                        and on_boundary):
                    _barrier()
                    self.cv(
                        model, cv_loader, writer, info_dict, spk_enc,
                        model_name=f'epoch_{self.epoch}_step_{self.step + 1}',
                        on_batch_end=False,
                    )
                    # cv() switches the model to eval mode.
                    model.train()

                if on_boundary:
                    self.step += 1

        _barrier()
        self.cv(
            model, cv_loader, writer, info_dict, spk_enc,
            model_name=f'epoch_{self.epoch}_whole',
            on_batch_end=True,
        )

    @torch.inference_mode()
    def cv(
        self,
        model, cv_loader,
        writer, info_dict: dict,
        spk_enc,
        model_name: str = 'model',
        on_batch_end: bool = True,
    ):
        """Evaluate on ``cv_loader`` (if any) and save a checkpoint.

        Runs under ``torch.inference_mode()`` and leaves the model in eval
        mode. Each loss returned by the model is averaged over the validation
        set, weighted by the batch size taken from ``batch['mel'].shape[0]``.
        The averages replace ``info_dict['loss_dict']`` and are reported via
        ``log_per_save``.

        When ``cv_loader`` is ``None`` no losses are computed or logged and
        ``info_dict`` is left untouched. ``save_model`` is called in either
        case.

        Args:
            model: Model to evaluate, optionally wrapped in DDP.
            cv_loader: Iterable of validation batches, or ``None``.
            writer: Scalar writer exposing ``add_scalar``, or ``None``.
            info_dict: Mutable step state; ``'tag'``, ``'step'``, ``'epoch'``,
                ``'batch_idx'`` and ``'loss_dict'`` are overwritten during
                validation.
            spk_enc: Speaker encoder passed to ``_extract_speaker_emb``.
            model_name: Checkpoint file stem passed to ``save_model``.
            on_batch_end: Only echoed in the log line.
        """
        logging.info(
            'Epoch {} Step {} on_batch_end={} CV rank {}'.format(
                self.epoch, self.step + 1, on_batch_end, self.rank
            )
        )
        model.eval()

        if cv_loader is not None:
            # Set the tag up front so the summary is reported as CV even if
            # the loader yields no batches.
            info_dict['tag'] = 'CV'
            info_dict['step'] = self.step
            info_dict['epoch'] = self.epoch

            total_utts = 0
            total_loss: dict = {}
            for batch_idx, batch in enumerate(cv_loader):
                info_dict['tag'] = 'CV'
                info_dict['step'] = self.step
                info_dict['epoch'] = self.epoch
                info_dict['batch_idx'] = batch_idx

                batch = _extract_speaker_emb(batch, spk_enc, self.device)
                num_utts = batch['mel'].shape[0]
                total_utts += num_utts

                info_dict = batch_forward(model, batch, info_dict)
                for k, v in info_dict['loss_dict'].items():
                    total_loss.setdefault(k, []).append(float(v) * num_utts)

            # max(..., 1) avoids a division by zero for an empty loader.
            denom = max(total_utts, 1)
            info_dict['loss_dict'] = {
                k: sum(vals) / denom for k, vals in total_loss.items()
            }
            log_per_save(writer, info_dict)

        save_model(model, model_name, info_dict)
|
|