# sunf / flow_matching / executor.py
# anhtunguyen98's picture
# Upload folder using huggingface_hub
# 4698bfc verified
"""
Training executor for FlowMatchingTTS – mirrors cosyvoice/utils/executor.py.
Key additions over cosyvoice's Executor:
  • _extract_speaker_emb: on-the-fly CAM++ extraction per batch
  • optional cv_loader (skip CV when no validation set)
  • single-GPU safe (_barrier / model_context guards)
"""
import logging
import os
from contextlib import nullcontext
import torch
import torch.distributed as dist
import tqdm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
# ── helpers ───────────────────────────────────────────────────────────────────
def _world_size() -> int:
return int(os.environ.get('WORLD_SIZE', 1))
def _rank() -> int:
return int(os.environ.get('RANK', 0))
def _barrier():
    """Synchronise all ranks with ``dist.barrier()``.

    Does nothing when the world size is 1 (or ``WORLD_SIZE`` is unset), so
    single-process runs never touch ``torch.distributed``.
    """
    if _world_size() <= 1:
        return
    dist.barrier()
def _extract_speaker_emb(batch: dict, spk_enc, device) -> dict:
"""Run CAM++ on wav_16k and store result as batch['embedding']."""
wav_16k = batch['wav_16k'].to(device)
with torch.no_grad():
feats = spk_enc.fbank(wav_16k) # (B, T_frames, 80)
batch['embedding'] = spk_enc(feats) # (B, 192) L2-normalised
return batch
def batch_forward(model, batch: dict, info_dict: dict) -> dict:
    """Run one forward pass and record the losses it returns.

    The model is invoked as ``model(batch, device)``, where ``device`` is the
    integer value of the ``LOCAL_RANK`` environment variable (0 when unset).

    Args:
        model: Callable taking ``(batch, device)``.
        batch: Batch passed through to the model unchanged.
        info_dict: Mutable run state; ``'loss_dict'`` is overwritten with the
            model's return value.

    Returns:
        The same ``info_dict`` object.
    """
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    losses = model(batch, local_rank)
    info_dict['loss_dict'] = losses
    return info_dict
def batch_backward(model, info_dict: dict) -> dict:
    """Back-propagate the loss, scaled for gradient accumulation.

    Args:
        model: Unused; kept so the signature stays parallel to the other
            step helpers.
        info_dict: Mutable run state. Reads ``info_dict['loss_dict']['loss']``
            and the optional ``'accum_grad'`` (default 1).

    Returns:
        The same ``info_dict``; ``loss_dict['loss']`` now holds the loss
        divided by ``accum_grad`` (the value that was back-propagated).
    """
    accum_steps = info_dict.get('accum_grad', 1)
    loss_dict = info_dict['loss_dict']
    scaled_loss = loss_dict['loss'] / accum_steps
    scaled_loss.backward()
    loss_dict['loss'] = scaled_loss
    return info_dict
def update_parameter_and_lr(model, optimizer, scheduler, info_dict: dict) -> dict:
    """Clip gradients and advance optimizer/scheduler at accumulation boundaries.

    A boundary is reached when ``(batch_idx + 1) % accum_grad == 0``. On a
    boundary the gradients are clipped to ``info_dict['grad_clip']``; the
    optimizer steps only if the resulting total norm is finite. Gradients are
    zeroed and the scheduler is stepped on every boundary, finite or not.

    Args:
        model: Module whose ``parameters()`` are clipped.
        optimizer: Optimizer to step and zero.
        scheduler: Learning-rate scheduler to step.
        info_dict: Mutable run state. Reads ``'batch_idx'``, ``'grad_clip'``
            and the optional ``'accum_grad'`` (default 1).

    Returns:
        The same ``info_dict`` with ``'lr'`` (first param group's learning
        rate) and ``'grad_norm'`` (float; 0.0 off-boundary) updated.
    """
    grad_norm = 0.0
    accum_grad = info_dict.get('accum_grad', 1)
    if (info_dict['batch_idx'] + 1) % accum_grad == 0:
        grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
        if torch.isfinite(grad_norm):
            optimizer.step()
        else:
            # Skipping the update keeps NaN/Inf out of the weights, but make
            # the skip visible so recurring bad batches can be diagnosed.
            logging.warning(
                'Non-finite grad_norm %s at batch %d; optimizer step skipped.',
                float(grad_norm), info_dict['batch_idx'] + 1,
            )
        optimizer.zero_grad()
        scheduler.step()
    info_dict['lr'] = optimizer.param_groups[0]['lr']
    info_dict['grad_norm'] = float(grad_norm)
    return info_dict
def log_per_step(writer, info_dict: dict):
    """Emit per-batch scalars to ``writer`` and a periodic text log line.

    Scalars (``epoch``, ``lr``, ``grad_norm`` and every entry of
    ``loss_dict``) are written only when ``writer`` is not ``None`` and the
    batch closes a gradient-accumulation window. A text line is logged every
    ``log_interval`` batches (default 100); for tag ``'TRAIN'`` it also
    carries the learning rate and gradient norm.

    Args:
        writer: Object with ``add_scalar(name, value, step)``, or ``None``.
        info_dict: Run state; reads ``'tag'``, ``'step'``, ``'batch_idx'``,
            ``'loss_dict'``, ``'epoch'`` and optionally ``'accum_grad'``,
            ``'log_interval'``, ``'lr'``, ``'grad_norm'``.
    """
    tag = info_dict['tag']
    global_step = info_dict['step'] + 1
    batch_no = info_dict['batch_idx'] + 1
    losses = info_dict['loss_dict']
    accum_steps = info_dict.get('accum_grad', 1)
    rank = _rank()

    if writer is not None and batch_no % accum_steps == 0:
        for key in ('epoch', 'lr', 'grad_norm'):
            writer.add_scalar(f'{tag}/{key}', info_dict.get(key, 0), global_step)
        for key, value in losses.items():
            writer.add_scalar(f'{tag}/{key}', value, global_step)

    if batch_no % info_dict.get('log_interval', 100) != 0:
        return

    pieces = [f'{tag} Epoch {info_dict["epoch"]} Batch {batch_no} ']
    pieces.extend(f'{name} {float(val):.6f} ' for name, val in losses.items())
    if tag == 'TRAIN':
        pieces.append(
            f'lr {info_dict["lr"]:.2e} '
            f'gnorm {info_dict["grad_norm"]:.4f}'
        )
    pieces.append(f' rank {rank}')
    logging.info(''.join(pieces))
def log_per_save(writer, info_dict: dict):
    """Log a loss summary at checkpoint time and mirror it to ``writer``.

    Args:
        writer: Object with ``add_scalar(name, value, step)``, or ``None``
            to skip scalar logging.
        info_dict: Run state; reads ``'tag'``, ``'step'``, ``'epoch'``,
            ``'loss_dict'`` (values must support ``:.6f`` formatting) and
            optionally ``'lr'``.
    """
    tag = info_dict['tag']
    step = info_dict['step']
    rank = _rank()
    losses = info_dict['loss_dict']

    epoch = info_dict['epoch']
    global_step = step + 1
    summary = ' '.join(f'{k}={v:.6f}' for k, v in losses.items())
    logging.info(
        'Epoch {} Step {} {} rank {} {}'.format(
            epoch, global_step, tag, rank, summary,
        )
    )

    if writer is None:
        return
    for key in ('epoch', 'lr'):
        writer.add_scalar(f'{tag}/{key}', info_dict.get(key, 0), global_step)
    for key, value in losses.items():
        writer.add_scalar(f'{tag}/{key}', value, global_step)
def save_model(model, model_name: str, info_dict: dict):
    """Save the model's ``state_dict`` to ``<model_dir>/<model_name>.pt``.

    Only rank 0 writes; every other rank returns without touching the disk.
    A DistributedDataParallel wrapper is unwrapped first, so the saved keys
    carry no ``module.`` prefix.

    Args:
        model: Module to save, optionally wrapped in DDP.
        model_name: File stem of the checkpoint (``.pt`` is appended).
        info_dict: Run state; reads ``'model_dir'``.
    """
    rank = _rank()
    model_dir = info_dict['model_dir']
    path = os.path.join(model_dir, f'{model_name}.pt')
    if rank != 0:
        return
    # Create the target directory on demand so a missing folder does not
    # cost a checkpoint. An empty model_dir means the current directory.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    m = model.module if isinstance(model, DDP) else model
    torch.save(m.state_dict(), path)
    logging.info(f'[Rank 0] Saved {path}')
def cosyvoice_join(group_join, info_dict: dict) -> bool:
    """Tell the caller whether this rank should leave the training loop.

    Runs ``dist.monitored_barrier`` on ``group_join``; a ``RuntimeError``
    from that barrier is treated as a sign of uneven batch counts across
    DDP workers.

    Args:
        group_join: Process group for the barrier, or ``None`` to disable
            the check.
        info_dict: Run state; reads ``'batch_idx'`` and the optional
            ``'group_timeout'`` (passed as the barrier timeout).

    Returns:
        ``True`` if the barrier raised ``RuntimeError`` (break out of the
        loop); ``False`` otherwise, including when the check is disabled or
        this is the first batch (``batch_idx == 0``).
    """
    if group_join is None:
        return False
    if info_dict['batch_idx'] == 0:
        return False

    try:
        dist.monitored_barrier(
            group=group_join,
            timeout=info_dict.get('group_timeout'),
        )
    except RuntimeError as err:
        rank = _rank()
        world = _world_size()
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        logging.info(
            'Uneven workload detected: {}\n'
            'rank {}/{} local_rank {} breaking early.'.format(
                err, rank, world, local_rank,
            )
        )
        return True
    return False
# ── Executor ──────────────────────────────────────────────────────────────────
class Executor:
    """Run training epochs and cross-validation (CV) passes.

    Attributes:
        step: Number of completed gradient-accumulation windows; incremented
            in ``train_one_epoch`` and used as the logging step.
        epoch: Epoch index used in log lines and checkpoint names. It is
            initialised to 0 and never changed inside this class
            (presumably the caller updates it between epochs).
        rank: Global rank as returned by ``_rank()``.
        device: ``cuda:<LOCAL_RANK>`` device that speaker-encoder inputs are
            moved to (``LOCAL_RANK`` defaults to 0).
    """

    def __init__(self):
        self.step = 0
        self.epoch = 0
        self.rank = _rank()
        self.device = torch.device(
            'cuda:{}'.format(int(os.environ.get('LOCAL_RANK', 0)))
        )

    def train_one_epoch(
        self,
        model, optimizer, scheduler,
        train_loader, cv_loader,
        writer, info_dict: dict,
        spk_enc, group_join,
    ):
        """Train over ``train_loader`` once, then run CV and save a checkpoint.

        Args:
            model: Module to train; may be wrapped in DistributedDataParallel.
            optimizer: Optimizer stepped at each accumulation boundary.
            scheduler: LR scheduler stepped together with the optimizer.
            train_loader: Iterable of training batch dicts (each must hold
                ``'wav_16k'``).
            cv_loader: Iterable of validation batches, or ``None``.
            writer: Scalar writer with ``add_scalar``, or ``None``.
            info_dict: Mutable run state/config. Reads ``'batch_size'`` and
                the optional ``'accum_grad'`` (default 1) and
                ``'save_per_step'`` (default -1, i.e. no mid-epoch saves);
                the step helpers read and write further keys.
            spk_enc: Speaker encoder used to fill ``batch['embedding']``.
            group_join: Process group for the uneven-workload check, or
                ``None`` to disable it.
        """
        lr = optimizer.param_groups[0]['lr']
        logging.info(
            'Epoch {} TRAIN lr {:.2e} rank {}'.format(self.epoch, lr, self.rank)
        )
        logging.info(
            'Gradient accumulation: effective batch = {} × {}'.format(
                info_dict['batch_size'], info_dict.get('accum_grad', 1)
            )
        )

        model.train()
        accum_grad = info_dict.get('accum_grad', 1)
        save_per_step = info_dict.get('save_per_step', -1)

        # DDP's join context tolerates ranks that run out of batches early.
        model_context = model.join if isinstance(model, DDP) else nullcontext
        with model_context():
            for batch_idx, batch in enumerate(tqdm.tqdm(train_loader)):
                info_dict['tag'] = 'TRAIN'
                info_dict['step'] = self.step
                info_dict['epoch'] = self.epoch
                info_dict['batch_idx'] = batch_idx

                if cosyvoice_join(group_join, info_dict):
                    break

                # Frozen speaker encoder: extract embeddings on GPU
                batch = _extract_speaker_emb(batch, spk_enc, self.device)

                # Delay DDP gradient sync until the last accumulation step
                if isinstance(model, DDP) and (batch_idx + 1) % accum_grad != 0:
                    sync_ctx = model.no_sync
                else:
                    sync_ctx = nullcontext

                with sync_ctx():
                    info_dict = batch_forward(model, batch, info_dict)
                    info_dict = batch_backward(model, info_dict)

                info_dict = update_parameter_and_lr(
                    model, optimizer, scheduler, info_dict
                )
                log_per_step(writer, info_dict)

                # Mid-epoch checkpoint + CV
                if (save_per_step > 0
                        and (self.step + 1) % save_per_step == 0
                        and (batch_idx + 1) % accum_grad == 0):
                    _barrier()
                    self.cv(
                        model, cv_loader, writer, info_dict, spk_enc,
                        model_name=f'epoch_{self.epoch}_step_{self.step + 1}',
                        on_batch_end=False,
                    )
                    # cv() switches the model to eval mode; switch back.
                    model.train()

                if (batch_idx + 1) % accum_grad == 0:
                    self.step += 1

        # End-of-epoch CV + checkpoint, once every rank has left the loop.
        _barrier()
        self.cv(
            model, cv_loader, writer, info_dict, spk_enc,
            model_name=f'epoch_{self.epoch}_whole',
            on_batch_end=True,
        )

    @torch.inference_mode()
    def cv(
        self,
        model, cv_loader,
        writer, info_dict: dict,
        spk_enc,
        model_name: str = 'model',
        on_batch_end: bool = True,
    ):
        """Evaluate on ``cv_loader``, log averaged losses and save the model.

        Each loss is weighted by the number of utterances in its batch
        (``batch['mel'].shape[0]``) and averaged over all utterances. With
        ``cv_loader=None`` no evaluation runs, but the (empty) summary is
        still logged and the checkpoint is still saved.

        Args:
            model: Module to evaluate; left in eval mode on return.
            cv_loader: Iterable of validation batch dicts, or ``None``.
            writer: Scalar writer with ``add_scalar``, or ``None``.
            info_dict: Mutable run state; ``'tag'``, ``'step'``, ``'epoch'``,
                ``'batch_idx'`` and ``'loss_dict'`` are overwritten.
            spk_enc: Speaker encoder used to fill ``batch['embedding']``.
            model_name: File stem passed to ``save_model``.
            on_batch_end: Only echoed in the log line.
        """
        logging.info(
            'Epoch {} Step {} on_batch_end={} CV rank {}'.format(
                self.epoch, self.step + 1, on_batch_end, self.rank
            )
        )
        model.eval()
        total_utts = 0
        total_loss: dict = {}

        if cv_loader is not None:
            for batch_idx, batch in enumerate(cv_loader):
                info_dict['tag'] = 'CV'
                info_dict['step'] = self.step
                info_dict['epoch'] = self.epoch
                info_dict['batch_idx'] = batch_idx

                batch = _extract_speaker_emb(batch, spk_enc, self.device)
                num_utts = batch['mel'].shape[0]
                total_utts += num_utts

                info_dict = batch_forward(model, batch, info_dict)
                for k, v in info_dict['loss_dict'].items():
                    total_loss.setdefault(k, []).append(float(v) * num_utts)

        # max(..., 1) guards against division by zero for an empty loader.
        for k in total_loss:
            total_loss[k] = sum(total_loss[k]) / max(total_utts, 1)

        info_dict['loss_dict'] = total_loss
        log_per_save(writer, info_dict)
        save_model(model, model_name, info_dict)