dpss-exp3-TTS / VoxCPM /scripts /train_voxcpm_finetune.py
lglg666's picture
Upload folder using huggingface_hub
6766eda verified
#!/usr/bin/env python3
import itertools
from pathlib import Path
from typing import Dict, Optional
import argbind
import torch
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import (
Accelerator,
BatchProcessor,
TrainingTracker,
build_dataloader,
load_audio_text_datasets,
)
def _iter_batches_forever(loader):
    """Yield batches from ``loader`` endlessly, re-iterating it on every pass.

    ``itertools.cycle`` would keep a copy of every batch produced during the
    first pass and replay those copies forever. This generator instead calls
    ``iter(loader)`` again for each pass, so no batches are cached and the
    loader gets the chance to produce a new ordering (if it shuffles).

    Raises
    ------
    RuntimeError
        If a full pass over ``loader`` yields no batch at all.
    """
    epoch = 0
    while True:
        # PyTorch convention: samplers exposing ``set_epoch`` (for example
        # DistributedSampler) use it to vary their shuffle between epochs.
        # NOTE(review): whether build_dataloader attaches such a sampler is not
        # visible here, hence the defensive getattr lookups.
        sampler = getattr(loader, "sampler", None)
        set_epoch = getattr(sampler, "set_epoch", None)
        if callable(set_epoch):
            set_epoch(epoch)

        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # Without this guard an empty loader would spin forever.
            raise RuntimeError(
                "Training dataloader produced no batches; check the dataset size, "
                "batch_size and max_batch_tokens settings."
            )
        epoch += 1


@argbind.bind(without_prefix=True)
def train(
    pretrained_path: str,
    train_manifest: str,
    val_manifest: str = "",
    sample_rate: int = 16_000,
    batch_size: int = 1,
    grad_accum_steps: int = 1,
    num_workers: int = 2,
    num_iters: int = 100_000,
    log_interval: int = 100,
    valid_interval: int = 1_000,
    save_interval: int = 10_000,
    learning_rate: float = 1e-4,
    weight_decay: float = 1e-2,
    warmup_steps: int = 1_000,
    max_steps: int = 100_000,
    max_batch_tokens: int = 0,
    save_path: str = "checkpoints",
    tensorboard: str = "",
    lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
    lora: dict = None,
    config_path: str = "",
):
    """Fine-tune a VoxCPM model on an audio/text manifest.

    Parameters
    ----------
    pretrained_path : str
        Local path handed to ``VoxCPMModel.from_local``.
    train_manifest : str
        Training manifest handed to ``load_audio_text_datasets``.
    val_manifest : str
        Validation manifest handed to ``load_audio_text_datasets``;
        validation is skipped when the returned validation set is ``None``.
    sample_rate : int
        Sample rate handed to ``load_audio_text_datasets``.
    batch_size : int
        Batch size of both dataloaders.
    grad_accum_steps : int
        Micro-batches accumulated per optimizer step (clamped to at least 1).
    num_workers : int
        Worker count of both dataloaders.
    num_iters : int
        Number of optimizer steps to run.
    log_interval : int
        Log training metrics every this many steps.
    valid_interval : int
        Run validation every this many steps (never at step 0).
    save_interval : int
        Save a checkpoint every this many steps (rank 0 only).
    learning_rate : float
        AdamW learning rate.
    weight_decay : float
        AdamW weight decay.
    warmup_steps : int
        Warmup steps of the cosine schedule.
    max_steps : int
        Total steps of the cosine schedule; ``num_iters`` is used when this
        is not positive.
    max_batch_tokens : int
        When positive, training samples whose estimated length exceeds
        ``max_batch_tokens // batch_size`` are dropped.
    save_path : str
        Directory receiving checkpoints and ``train.log``.
    tensorboard : str
        TensorBoard log directory; defaults to ``<save_path>/logs``.
    lambdas : Dict[str, float]
        Weight per ``loss/...`` output key; keys not listed weigh 1.0.
        The default dict is shared between calls and is only ever read.
    lora : dict
        Keyword arguments for ``LoRAConfig``; falsy disables LoRA.
    config_path : str
        Not used by this function; the script entry point reads it to
        locate a YAML configuration file.

    Raises
    ------
    RuntimeError
        If the training dataloader yields no batches.
    """
    _ = config_path

    accelerator = Accelerator(amp=True)

    save_dir = Path(save_path)
    save_dir.mkdir(parents=True, exist_ok=True)
    tb_dir = Path(tensorboard) if tensorboard else save_dir / "logs"
    tb_dir.mkdir(parents=True, exist_ok=True)

    # Only rank 0 owns a TensorBoard writer.
    writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
    tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)

    base_model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=False,
        training=True,
        lora_config=LoRAConfig(**lora) if lora else None,
    )
    tokenizer = base_model.text_tokenizer

    train_ds, val_ds = load_audio_text_datasets(
        train_manifest=train_manifest,
        val_manifest=val_manifest,
        sample_rate=sample_rate,
    )

    def tokenize(batch):
        text_list = batch["text"]
        text_ids = [tokenizer(text) for text in text_list]
        return {"text_ids": text_ids}

    train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
    if val_ds is not None:
        val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])

    dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1

    # ------------------------------------------------------------------ #
    # Optional: drop over-long samples by estimated token count so that a
    # single sample cannot exhaust GPU memory.
    # Enabled when max_batch_tokens > 0:
    #   per-sample length limit = max_batch_tokens // batch_size
    #   training samples above that limit are removed from train_ds
    # ------------------------------------------------------------------ #
    if max_batch_tokens and max_batch_tokens > 0:
        from voxcpm.training.data import compute_sample_lengths

        est_lengths = compute_sample_lengths(
            train_ds,
            audio_vae_fps=25,
            patch_size=base_model.config.patch_size,
        )
        max_sample_len = max_batch_tokens // batch_size if batch_size > 0 else max(est_lengths)
        keep_indices = [i for i, length in enumerate(est_lengths) if length <= max_sample_len]

        if len(keep_indices) < len(train_ds) and accelerator.rank == 0:
            tracker.print(
                f"Filtering {len(train_ds) - len(keep_indices)} / {len(train_ds)} "
                f"training samples longer than {max_sample_len} tokens "
                f"(max_batch_tokens={max_batch_tokens})."
            )
        train_ds = train_ds.select(keep_indices)

    # Counted after filtering so the logged epoch estimate reflects the data
    # that is actually trained on.
    num_train_samples = len(train_ds)

    train_loader = build_dataloader(
        train_ds,
        accelerator=accelerator,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = (
        build_dataloader(
            val_ds,
            accelerator=accelerator,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=False,
        )
        if val_ds is not None
        else None
    )

    model = accelerator.prepare_model(base_model)
    unwrapped_model = accelerator.unwrap(model)
    unwrapped_model.train()

    batch_processor = BatchProcessor(
        config=unwrapped_model.config,
        audio_vae=unwrapped_model.audio_vae,
        dataset_cnt=dataset_cnt,
        device=accelerator.device,
    )

    for name, param in model.named_parameters():
        print(name, param.requires_grad)

    # Only trainable parameters are handed to the optimizer.
    optimizer = AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    # Cosine schedule with warmup from transformers:
    # - num_warmup_steps: number of warmup steps
    # - num_training_steps: planned total number of (outer) training steps
    total_training_steps = max_steps if max_steps > 0 else num_iters
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_training_steps,
    )

    # Endless batch stream that re-iterates the loader instead of caching it.
    train_iter = _iter_batches_forever(train_loader)
    grad_accum_steps = max(int(grad_accum_steps), 1)

    try:
        with tracker.live():
            for step in range(num_iters):
                tracker.step = step
                optimizer.zero_grad(set_to_none=True)

                # Gradient accumulation: accumulate gradients over several
                # micro-batches, then take a single optimizer step.
                loss_dict = {}
                for micro_step in range(grad_accum_steps):
                    batch = next(train_iter)
                    processed = batch_processor(batch)

                    with accelerator.autocast(dtype=torch.bfloat16):
                        outputs = model(
                            processed["text_tokens"],
                            processed["text_mask"],
                            processed["audio_feats"],
                            processed["audio_mask"],
                            processed["loss_mask"],
                            processed["position_ids"],
                            processed["labels"],
                            progress=step / max(1, num_iters),
                        )

                    total_loss = 0.0
                    for key, value in outputs.items():
                        if key.startswith("loss/"):
                            weight = lambdas.get(key, 1.0)
                            loss_value = value * weight / grad_accum_steps
                            total_loss = total_loss + loss_value
                            # Keep the raw loss of the last micro-batch for logging.
                            loss_dict[key] = value.detach()

                    # Accumulate this micro-batch's gradients (already
                    # normalized by grad_accum_steps).
                    accelerator.backward(total_loss)

                # After every micro-batch has been back-propagated: unscale,
                # measure the gradient norm, then step.
                scaler = getattr(accelerator, "scaler", None)
                if scaler is not None:
                    scaler.unscale_(optimizer)
                # A huge max_norm reuses the clipping routine purely to
                # measure the gradient norm without actually clipping.
                grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)

                accelerator.step(optimizer)
                accelerator.update()
                scheduler.step()

                if step % log_interval == 0:
                    loss_values = {
                        k: v.item() if isinstance(v, torch.Tensor) else float(v)
                        for k, v in loss_dict.items()
                    }
                    loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
                    # Approximate epoch: samples seen / training-set size
                    # (accounts for gradient accumulation and batch_size).
                    epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
                    loss_values["epoch"] = float(epoch)
                    loss_values["grad_norm"] = float(grad_norm)
                    tracker.log_metrics(loss_values, split="train")

                if val_loader is not None and step % valid_interval == 0 and step != 0:
                    validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)

                if step % save_interval == 0 and accelerator.rank == 0:
                    save_checkpoint(model, optimizer, scheduler, save_dir, step)

        # Final checkpoint once every iteration has completed.
        if accelerator.rank == 0:
            save_checkpoint(model, optimizer, scheduler, save_dir, num_iters)
    finally:
        # Close the writer even when training fails or is interrupted so
        # that buffered TensorBoard events are not lost.
        if writer:
            writer.close()
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas, max_batches: int = 10):
    """Evaluate the model on the first validation batches and log the mean loss.

    The model is put in eval mode, at most ``max_batches`` batches are run
    under ``torch.no_grad()``, and the mean of the per-batch weighted loss is
    logged as ``{"loss": ...}`` with ``split="val"``. The model is switched
    back to train mode on exit, including when an exception is raised.

    Args:
        model: Callable model exposing ``eval()`` and ``train()``; its output
            is a mapping whose ``loss/...`` entries are tensors.
        val_loader: Iterable of raw validation batches.
        batch_processor: Callable turning a raw batch into the dict of model
            inputs read below.
        accelerator: Object providing the ``autocast(dtype=...)`` context.
        tracker: Object providing ``log_metrics(metrics, split=...)``.
        lambdas: Weight per ``loss/...`` key; keys not listed weigh 1.0.
        max_batches: Maximum number of batches to evaluate (non-negative, or
            ``None`` for the whole loader). Defaults to 10.
    """
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for batch in itertools.islice(val_loader, max_batches):
                processed = batch_processor(batch)
                with accelerator.autocast(dtype=torch.bfloat16):
                    outputs = model(
                        processed["text_tokens"],
                        processed["text_mask"],
                        processed["audio_feats"],
                        processed["audio_mask"],
                        processed["loss_mask"],
                        processed["position_ids"],
                        processed["labels"],
                        progress=0.0,
                        sample_generate=False,
                    )

                weighted = [
                    lambdas.get(key, 1.0) * value
                    for key, value in outputs.items()
                    if key.startswith("loss/")
                ]
                if not weighted:
                    # No loss terms in this output: nothing to average, and
                    # a plain number would have no .detach().
                    continue
                losses.append(sum(weighted).detach())

        if losses:
            mean_loss = torch.stack(losses).mean()
            tracker.log_metrics({"loss": mean_loss.item()}, split="val")
    finally:
        # Always hand the model back in training mode, even after a failure.
        model.train()
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int):
    """Write model, optimizer and scheduler state into a per-step folder.

    The folder is ``save_dir / "latest"`` when ``step`` is 0 and
    ``save_dir / "step_<7-digit step>"`` otherwise. It receives
    ``generator.pth`` (a dict with a single ``"state_dict"`` entry),
    ``optimizer.pth`` and ``scheduler.pth``.

    Args:
        model: Model, or a wrapper exposing the real model as ``.module``.
            The underlying model must provide ``state_dict()`` and a
            ``lora_config`` attribute.
        optimizer: Object providing ``state_dict()``.
        scheduler: Object providing ``state_dict()``.
        save_dir: Root checkpoint directory (created if missing).
        step: Training step used to name the checkpoint folder.
    """
    save_dir.mkdir(parents=True, exist_ok=True)

    if step == 0:
        ckpt_name = "latest"
    else:
        ckpt_name = f"step_{step:07d}"
    ckpt_dir = save_dir / ckpt_name
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Look through a wrapper that exposes the real model as ``.module``.
    core_model = getattr(model, "module", model)

    # Which weights are written depends on whether LoRA is enabled:
    # - with LoRA, only the LoRA parameters (lora_A / lora_B) are kept
    # - otherwise the full model weights are saved
    weights = core_model.state_dict()
    if core_model.lora_config is not None:
        weights = {
            name: tensor
            for name, tensor in weights.items()
            if "lora_A" in name or "lora_B" in name
        }

    torch.save({"state_dict": weights}, ckpt_dir / "generator.pth")
    torch.save(optimizer.state_dict(), ckpt_dir / "optimizer.pth")
    torch.save(scheduler.state_dict(), ckpt_dir / "scheduler.pth")
if __name__ == "__main__":
    from voxcpm.training.config import load_yaml_config

    cli_args = argbind.parse_args()
    yaml_path = cli_args.get("config_path")

    if not yaml_path:
        # No YAML config given: run train() with the arguments argbind
        # parsed from the command line.
        with argbind.scope(cli_args):
            train()
    else:
        # A YAML config was given: call train() directly with its values.
        yaml_args = load_yaml_config(yaml_path)
        train(**yaml_args)