#!/usr/bin/env python3
import itertools
from pathlib import Path
from typing import Dict, Optional

import argbind
import torch
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup

from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import (
    Accelerator,
    BatchProcessor,
    TrainingTracker,
    build_dataloader,
    load_audio_text_datasets,
)


@argbind.bind(without_prefix=True)
def train(
    pretrained_path: str,
    train_manifest: str,
    val_manifest: str = "",
    sample_rate: int = 16_000,
    batch_size: int = 1,
    grad_accum_steps: int = 1,
    num_workers: int = 2,
    num_iters: int = 100_000,
    log_interval: int = 100,
    valid_interval: int = 1_000,
    save_interval: int = 10_000,
    learning_rate: float = 1e-4,
    weight_decay: float = 1e-2,
    warmup_steps: int = 1_000,
    max_steps: int = 100_000,
    max_batch_tokens: int = 0,
    save_path: str = "checkpoints",
    tensorboard: str = "",
    lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
    lora: Optional[dict] = None,
    config_path: str = "",
):
    _ = config_path  # only consumed by the __main__ block to locate a YAML config
    accelerator = Accelerator(amp=True)

    save_dir = Path(save_path)
    save_dir.mkdir(parents=True, exist_ok=True)
    tb_dir = Path(tensorboard) if tensorboard else save_dir / "logs"
    tb_dir.mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
    tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)

    base_model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=False,
        training=True,
        lora_config=LoRAConfig(**lora) if lora else None,
    )
    tokenizer = base_model.text_tokenizer

    train_ds, val_ds = load_audio_text_datasets(
        train_manifest=train_manifest,
        val_manifest=val_manifest,
        sample_rate=sample_rate,
    )

    def tokenize(batch):
        text_list = batch["text"]
        text_ids = [tokenizer(text) for text in text_list]
        return {"text_ids": text_ids}

    train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
    if val_ds is not None:
        val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])

    dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
    num_train_samples = len(train_ds)

    # ------------------------------------------------------------------ #
    # Optional: filter out overly long samples by estimated token count so
    # a single sample cannot blow up GPU memory. Enabled when
    # max_batch_tokens > 0:
    #   max length per sample = max_batch_tokens // batch_size
    # Samples longer than that are dropped (train_ds is filtered).
    # ------------------------------------------------------------------ #
    if max_batch_tokens and max_batch_tokens > 0:
        from voxcpm.training.data import compute_sample_lengths

        est_lengths = compute_sample_lengths(
            train_ds,
            audio_vae_fps=25,
            patch_size=base_model.config.patch_size,
        )
        max_sample_len = max_batch_tokens // batch_size if batch_size > 0 else max(est_lengths)
        keep_indices = [i for i, length in enumerate(est_lengths) if length <= max_sample_len]
        if len(keep_indices) < len(train_ds) and accelerator.rank == 0:
            tracker.print(
                f"Filtering {len(train_ds) - len(keep_indices)} / {len(train_ds)} "
                f"training samples longer than {max_sample_len} tokens "
                f"(max_batch_tokens={max_batch_tokens})."
            )
        train_ds = train_ds.select(keep_indices)

    train_loader = build_dataloader(
        train_ds,
        accelerator=accelerator,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = (
        build_dataloader(
            val_ds,
            accelerator=accelerator,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=False,
        )
        if val_ds is not None
        else None
    )

    model = accelerator.prepare_model(base_model)
    unwrapped_model = accelerator.unwrap(model)
    unwrapped_model.train()

    batch_processor = BatchProcessor(
        config=unwrapped_model.config,
        audio_vae=unwrapped_model.audio_vae,
        dataset_cnt=dataset_cnt,
        device=accelerator.device,
    )

    # Print which parameters are trainable (handy for verifying a LoRA setup).
    for name, param in model.named_parameters():
        print(name, param.requires_grad)

    optimizer = AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    # Cosine schedule with warmup from transformers:
    # - num_warmup_steps: number of warmup steps
    # - num_training_steps: planned total number of training steps (counted in outer steps)
    total_training_steps = max_steps if max_steps > 0 else num_iters
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_training_steps,
    )

    train_iter = iter(itertools.cycle(train_loader))
    grad_accum_steps = max(int(grad_accum_steps), 1)

    with tracker.live():
        for step in range(num_iters):
            tracker.step = step
            optimizer.zero_grad(set_to_none=True)

            # Gradient accumulation: accumulate gradients over several
            # micro-batches, then take a single optimizer step.
            loss_dict = {}
            for _micro_step in range(grad_accum_steps):
                batch = next(train_iter)
                processed = batch_processor(batch)
                with accelerator.autocast(dtype=torch.bfloat16):
                    outputs = model(
                        processed["text_tokens"],
                        processed["text_mask"],
                        processed["audio_feats"],
                        processed["audio_mask"],
                        processed["loss_mask"],
                        processed["position_ids"],
                        processed["labels"],
                        progress=step / max(1, num_iters),
                    )

                total_loss = 0.0
                for key, value in outputs.items():
                    if key.startswith("loss/"):
                        weight = lambdas.get(key, 1.0)
                        loss_value = value * weight / grad_accum_steps
                        total_loss = total_loss + loss_value
                        # Keep the raw loss of the last micro-batch for logging.
                        loss_dict[key] = value.detach()

                # Accumulate gradients for this micro-batch (already normalized by grad_accum_steps).
                accelerator.backward(total_loss)

            # After all micro-batches have been backpropagated, do a single unscale / grad-norm / step.
            scaler = getattr(accelerator, "scaler", None)
            if scaler is not None:
                scaler.unscale_(optimizer)
            # Reuse clip_grad_norm_ with a huge max_norm: this only reports the gradient
            # norm and does not actually clip.
            grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)
            accelerator.step(optimizer)
            accelerator.update()
            scheduler.step()

            if step % log_interval == 0:
                loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
                loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
                # Approximate current epoch: samples seen / training-set size
                # (accounts for gradient accumulation and batch_size).
                epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
                loss_values["epoch"] = float(epoch)
                loss_values["grad_norm"] = float(grad_norm)
                tracker.log_metrics(loss_values, split="train")

            if val_loader is not None and step % valid_interval == 0 and step != 0:
                validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)

            if step % save_interval == 0 and accelerator.rank == 0:
                save_checkpoint(model, optimizer, scheduler, save_dir, step)

    if accelerator.rank == 0:
        save_checkpoint(model, optimizer, scheduler, save_dir, num_iters)
    if writer:
        writer.close()


def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
    model.eval()
    losses = []
    with torch.no_grad():
        for batch in itertools.islice(val_loader, 0, 10):
            processed = batch_processor(batch)
            with accelerator.autocast(dtype=torch.bfloat16):
                outputs = model(
                    processed["text_tokens"],
                    processed["text_mask"],
                    processed["audio_feats"],
                    processed["audio_mask"],
                    processed["loss_mask"],
                    processed["position_ids"],
                    processed["labels"],
                    progress=0.0,
                    sample_generate=False,
                )
            total = 0.0
            for key, value in outputs.items():
                if key.startswith("loss/"):
                    total += lambdas.get(key, 1.0) * value
            losses.append(total.detach())
    if losses:
        mean_loss = torch.stack(losses).mean()
        tracker.log_metrics({"loss": mean_loss.item()}, split="val")
    model.train()


def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int):
    save_dir.mkdir(parents=True, exist_ok=True)
    tag = "latest" if step == 0 else f"step_{step:07d}"
    folder = save_dir / tag
    folder.mkdir(parents=True, exist_ok=True)

    # Decide which weights to save depending on whether LoRA is enabled:
    # - with LoRA, save only the LoRA parameters (lora_A / lora_B)
    # - otherwise, save the full model state dict
    unwrapped = model.module if hasattr(model, "module") else model
    full_state = unwrapped.state_dict()
    lora_cfg = unwrapped.lora_config
    if lora_cfg is not None:
        state_dict = {k: v for k, v in full_state.items() if ("lora_A" in k or "lora_B" in k)}
    else:
        state_dict = full_state

    torch.save({"state_dict": state_dict}, folder / "generator.pth")
    torch.save(optimizer.state_dict(), folder / "optimizer.pth")
    torch.save(scheduler.state_dict(), folder / "scheduler.pth")


if __name__ == "__main__":
    from voxcpm.training.config import load_yaml_config

    args = argbind.parse_args()
    config_file = args.get("config_path")
    # If a YAML config file is given, call train() with the YAML arguments directly.
    if config_file:
        yaml_args = load_yaml_config(config_file)
        train(**yaml_args)
    else:
        # Otherwise call train() with the command-line arguments parsed by argbind.
        with argbind.scope(args):
            train()
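

# Example invocations (illustrative only; the config and manifest paths below are
# placeholders, and the flag names assume argbind exposes train()'s bound arguments
# as top-level CLI options because of @argbind.bind(without_prefix=True)):
#
#   python train.py --config_path configs/finetune.yaml
#   python train.py --pretrained_path /path/to/VoxCPM --train_manifest data/train.jsonl \
#       --batch_size 2 --grad_accum_steps 4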