WJAD / src /wjad /train /trainer.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
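"""Two-stage trainer with gradient monitoring.

Stage 1 (Dense):
- MoE weights all experts (dense mode);
- router temperature starts below 1 (sharpened) and is raised linearly
  towards 1 during training;
- DINOv3 is frozen;
- kinematic / intrinsic / extrinsic perturbation is switched on mid-stage
  to supervise the calibration network;
- GradNorm enabled.
Stage 2 (Sparse):
- MoE switches to Top-3;
- router temperature annealing is complete;
- DINOv3 is unfrozen and trained at 1/100 of the base LR;
- GradNorm + PCGrad both enabled.

Note: GradNorm and PCGrad are controlled by ``TrainerConfig.enable_gradnorm``
and ``TrainerConfig.enable_pcgrad``; with the default flags both are active
in Stage 1 as well as Stage 2.
"""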
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ..losses import (
HungarianMatcher,
action_nll,
calibration_regularization,
detection_losses,
ego_traj_nll,
moe_load_balance_and_boundary,
object_traj_nll,
)
from ..model import E2EAVModel, E2EOutput
from .multitask import MultiTaskOptimizer, MultiTaskOptimizerConfig
from .schedule import build_scheduler
# Module-level logger named after this module; used by the gradient monitor
# and the Trainer for progress, warnings and errors.
log = logging.getLogger(__name__)
class _NullContext:
"""空 context manager,用于 AMP 关闭时占位 autocast。"""
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
@dataclass
class TrainerConfig:
    """Trainer hyper-parameters (kept aligned with ``configs/default.yaml``)."""

    # --- schedule / optimisation ---
    total_steps: int = 100_000
    warmup_steps: int = 1_000
    base_lr: float = 2e-4
    min_lr: float = 1e-6
    weight_decay: float = 0.05
    grad_clip: float = 1.0
    log_interval: int = 20
    ckpt_interval: int = 1_000
    # Step at which training moves from Stage 1 to Stage 2.
    stage1_steps: int = 60_000
    # Step from which Stage 1 starts injecting kinematic perturbations.
    stage1_perturb_start: int = 20_000
    # Per-module gradient norms below this value trigger a warning.
    grad_monitor_threshold: float = 1e-7

    # --- AMP / mixed precision ---
    # One of "fp32" / "bf16" / "fp16". bf16 is the default (recommended on
    # H100/A100; it needs no GradScaler).
    mixed_precision: str = "bf16"
    grad_accum_steps: int = 1

    # --- MoE ---
    moe_load_balance_weight: float = 0.01
    moe_boundary_weight: float = 0.001
    router_temp_init: float = 0.5
    router_temp_final: float = 1.0

    # --- initial loss weights (GradNorm adapts main tasks 1-6) ---
    loss_giou_weight: float = 0.5
    loss_calib_weight: float = 0.1

    # --- multi-task optimisation ---
    # GradNorm + PCGrad stay enabled through both Stage 1 and Stage 2: the six
    # main tasks suffer from scale imbalance and gradient conflicts in both
    # stages, so PCGrad should not be postponed until Stage 2.
    enable_gradnorm: bool = True
    enable_pcgrad: bool = True

    # --- parameter groups ---
    dinov3_lr_mult_stage2: float = 0.01
    # On memory-constrained devices (e.g. a10g-small) the Stage 2 DINOv3
    # unfreeze can be disabled so the encoder stays frozen.
    unfreeze_dinov3_at_stage2: bool = True
    backbone_lr_mult: float = 1.0
    calibration_lr_mult: float = 0.1
    head_lr_mult: float = 1.0
    gate_lr_mult: float = 0.1

    # --- checkpoints: local directory + optional sync to a Hub model repo ---
    output_dir: str | None = None
    hub_repo_id: str | None = None
    hub_repo_type: str = "model"
def _is_gate_param(name: str) -> bool:
return ".gate." in name or name.endswith(".gate_proj.weight") or name.endswith(".gate_proj.bias")
def build_param_groups(model: E2EAVModel, base_lr: float, cfg: TrainerConfig, stage: int) -> list[dict]:
    """Bucket model parameters into optimizer groups with per-module LRs.

    Parameters are assigned by the prefix of their qualified name:
    ``dinov3.`` -> "dinov3"; ``backbone.`` / ``calib.`` -> "backbone" /
    "calibration", except gating parameters which go to "gate";
    ``det_traj_head.`` / ``ctrl_head.`` -> "head"; everything else -> "other".

    In Stage 1 parameters with ``requires_grad=False`` are left out and the
    "dinov3" group gets lr = 0. In any other stage frozen parameters are kept,
    and in Stage 2 the "dinov3" LR is ``base_lr * cfg.dinov3_lr_mult_stage2``.

    Returns:
        Six dicts (``params`` / ``lr`` / ``name``) in the fixed order
        dinov3, backbone, calibration, head, gate, other. A group's
        ``params`` list may be empty.
    """
    group_order = ("dinov3", "backbone", "calibration", "head", "gate", "other")
    buckets: dict[str, list[nn.Parameter]] = {key: [] for key in group_order}

    drop_frozen = stage == 1
    for param_name, param in model.named_parameters():
        if drop_frozen and not param.requires_grad:
            continue
        if param_name.startswith("dinov3."):
            key = "dinov3"
        elif param_name.startswith("backbone."):
            key = "gate" if _is_gate_param(param_name) else "backbone"
        elif param_name.startswith("calib."):
            key = "gate" if _is_gate_param(param_name) else "calibration"
        elif param_name.startswith(("det_traj_head.", "ctrl_head.")):
            key = "head"
        else:
            key = "other"
        buckets[key].append(param)

    # DINOv3 only receives a non-zero learning rate in Stage 2.
    dinov3_mult = cfg.dinov3_lr_mult_stage2 if stage == 2 else 0.0
    lr_table = {
        "dinov3": base_lr * dinov3_mult,
        "backbone": base_lr * cfg.backbone_lr_mult,
        "calibration": base_lr * cfg.calibration_lr_mult,
        "head": base_lr * cfg.head_lr_mult,
        "gate": base_lr * cfg.gate_lr_mult,
        "other": base_lr,
    }
    return [
        {"params": buckets[key], "lr": lr_table[key], "name": key}
        for key in group_order
    ]
def grad_norm_per_module(model: nn.Module, threshold: float) -> dict[str, float]:
    """Compute the gradient L2 norm of every top-level child module.

    Skipped children:
      - those with no parameters at all;
      - those with no ``requires_grad=True`` parameter (fully frozen);
      - those where no parameter currently holds a gradient.

    A warning is logged when a norm falls below ``threshold`` and an error is
    logged when a norm is NaN or infinite.

    Returns:
        Mapping from child-module name to its gradient norm.
    """
    norms: dict[str, float] = {}
    for child_name, module in model.named_children():
        module_params = list(module.parameters())
        # Nothing to monitor for empty or entirely frozen modules.
        if not module_params or not any(p.requires_grad for p in module_params):
            continue

        grads = [p.grad for p in module_params if p.grad is not None]
        if not grads:
            continue

        # Combine per-parameter norms: sqrt(sum_i ||g_i||^2).
        squared_sum = 0.0
        for grad in grads:
            squared_sum += float(grad.detach().norm().item()) ** 2
        module_norm = math.sqrt(squared_sum)
        norms[child_name] = module_norm

        if module_norm < threshold:
            log.warning("[grad_monitor] %s grad_norm=%.3e < %.3e", child_name, module_norm, threshold)
        if not math.isfinite(module_norm):
            log.error("[grad_monitor] %s grad_norm is %s (NaN/Inf)", child_name, module_norm)
    return norms
def compute_all_losses(
    model_out: E2EOutput,
    batch: dict,
    matcher: HungarianMatcher,
    num_classes: int,
    cfg: TrainerConfig,
    perturbation_residual: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the eight loss terms and return them keyed by name.

    Args:
        model_out: Model output; its ``detection``, ``control``,
            ``calibration`` and ``backbone_out.moe_stats`` fields are read.
        batch: Batch dict. ``targets`` and ``ego_future`` are required;
            ``ego_future_valid`` and ``action_target`` are optional.
        matcher: Matcher forwarded to ``detection_losses``.
        num_classes: Number of detection classes.
        cfg: Supplies the MoE regularisation weights and the GIoU weight.
        perturbation_residual: Ground-truth residual supplied during
            perturbation training to add extra supervision on the calibration
            network; ``None`` during normal training.

    Returns:
        Dict with keys ``L_cls``, ``L_box``, ``L_isdyn``, ``L_traj_obj``,
        ``L_traj_ego``, ``L_ctrl``, ``L_moe`` and ``L_calib``.
    """
    gt_targets = batch["targets"]
    detection = model_out.detection
    control = model_out.control
    calibration = model_out.calibration

    # Detection terms; the matching found here is reused for object trajectories.
    det = detection_losses(
        cls_logits=detection.cls_logits,
        box_mu=detection.box3d_mu,
        box_log_sigma=detection.box3d_log_sigma,
        isdyn_logit=detection.is_dynamic_logit,
        targets=gt_targets,
        matcher=matcher,
        num_classes=num_classes,
    )
    obj_traj_loss = object_traj_nll(
        detection.traj_mu,
        detection.traj_log_sigma,
        det.matched_indices,
        gt_targets,
    )
    ego_traj_loss = ego_traj_nll(
        control.ego_traj_mu,
        control.ego_traj_log_sigma,
        batch["ego_future"],
        valid=batch.get("ego_future_valid"),
    )

    # A global action ground truth is usually unavailable, so zeros are used
    # as a placeholder (real datasets need to provide it).
    action_gt = batch.get("action_target")
    if action_gt is None:
        action_gt = torch.zeros_like(control.action_mu)
    action_loss = action_nll(control.action_mu, control.action_log_sigma, action_gt)
    # Control loss = action NLL + ego-trajectory NLL (convenience bundling;
    # the ego-trajectory term is also reported separately as ``L_traj_ego``).
    ctrl_loss = action_loss + ego_traj_loss

    # MoE and calibration regularisers.
    moe_loss = moe_load_balance_and_boundary(
        model_out.backbone_out.moe_stats,
        load_balance_weight=cfg.moe_load_balance_weight,
        boundary_weight=cfg.moe_boundary_weight,
    )
    calib_loss = calibration_regularization(
        calibration.ego_residual,
        calibration.intr_residual,
        calibration.extr_residual,
        l2_weight=1.0,
    )
    if perturbation_residual is not None:
        # Perturbation training: MSE between the residual predicted by the
        # calibration network and the ground-truth residual.
        predicted_residual = torch.cat(
            [
                calibration.ego_residual.flatten(1),
                calibration.intr_residual,
                calibration.extr_residual,
            ],
            dim=-1,
        )
        residual_mse = (predicted_residual - perturbation_residual).pow(2).mean()
        calib_loss = calib_loss + 1.0 * residual_mse

    box_loss = det.box_nll + cfg.loss_giou_weight * det.giou_loss
    return {
        "L_cls": det.cls_loss,
        "L_box": box_loss,
        "L_isdyn": det.isdyn_loss,
        "L_traj_obj": obj_traj_loss,
        "L_traj_ego": ego_traj_loss,
        "L_ctrl": ctrl_loss,
        "L_moe": moe_loss,
        "L_calib": calib_loss,
    }
# Loss keys stacked (in this order) into the main-task vector that the
# Trainer hands to the multi-task optimizer.
MAIN_TASK_KEYS = ["L_cls", "L_box", "L_isdyn", "L_traj_obj", "L_traj_ego", "L_ctrl"]
# Loss keys summed into a single auxiliary (regularisation) term.
AUX_TASK_KEYS = ["L_moe", "L_calib"]
class Trainer:
    """End-to-end two-stage trainer.

    Stage 1 runs the MoE backbone in dense mode with a linearly annealed
    router temperature; once ``cfg.stage1_steps`` optimizer steps have been
    taken, the trainer switches to Stage 2 (sparse MoE, optional DINOv3
    unfreeze) and rebuilds the optimizer.
    """

    def __init__(
        self,
        model: E2EAVModel,
        cfg: TrainerConfig,
        num_classes: int = 22,
        device: str = "cuda",
    ) -> None:
        """Move the model to ``device`` and set up optimizer, AMP and MoE state.

        Raises:
            ValueError: if ``cfg.mixed_precision`` is not one of
                ``"fp32"``, ``"bf16"`` or ``"fp16"``.
        """
        # Validate the precision flag first so a typo fails fast with a clear
        # message, before the model is moved or the optimizer is built.
        amp_dtype_map = {
            "fp32": None,
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
        }
        if cfg.mixed_precision not in amp_dtype_map:
            raise ValueError(
                f"Unsupported mixed_precision {cfg.mixed_precision!r}; "
                f"expected one of {sorted(amp_dtype_map)}"
            )

        self.model = model.to(device)
        self.cfg = cfg
        self.num_classes = num_classes
        self.device = device
        self.matcher = HungarianMatcher()
        self.global_step = 0
        self._micro_step = 0  # micro-batch counter for gradient accumulation
        self._stage = 1
        self._build_optimizer()

        # === AMP configuration ===
        # autocast is only enabled on CUDA devices (bf16 also runs on CPU but
        # brings very little benefit there).
        self.amp_dtype = amp_dtype_map[cfg.mixed_precision]
        self.amp_enabled = self.amp_dtype is not None and "cuda" in str(device)
        # A GradScaler is only needed for fp16; bf16 has enough dynamic range.
        self.scaler = (
            torch.amp.GradScaler("cuda")
            if (self.amp_enabled and self.amp_dtype == torch.float16)
            else None
        )

        # MoE starts in dense mode; Stage 2 switches to sparse.
        self.model.backbone.set_moe_mode("dense")
        self.model.backbone.set_router_temperature(cfg.router_temp_init)

    # ---------- optimizer construction ----------
    def _build_optimizer(self) -> None:
        """(Re)create the AdamW optimizer, LR scheduler and multi-task optimizer."""
        cfg = self.cfg
        groups = build_param_groups(self.model, cfg.base_lr, cfg, stage=self._stage)
        self.optimizer = torch.optim.AdamW(groups, weight_decay=cfg.weight_decay, betas=(0.9, 0.95))
        self.scheduler = build_scheduler(
            self.optimizer,
            warmup_steps=cfg.warmup_steps,
            total_steps=cfg.total_steps,
            base_lr=cfg.base_lr,
            min_lr=cfg.min_lr,
        )

        # PCGrad shared parameters = the "shared bottleneck" at the end of the
        # backbone: final_norm + the last MoE block. Including all of
        # DINOv3/Calib/Backbone would need N tasks x flattened full-gradient
        # copies and instantly OOM on a10g-small (~600M params x 6 tasks x
        # 2 copies ~= 28 GB). Earlier layers still get GradNorm adaptive
        # weighting and the usual summed multi-task gradient.
        last_moe = self.model.backbone.moe_layers[-1]
        shared: list[nn.Parameter] = [
            p for p in self.model.backbone.final_norm.parameters() if p.requires_grad
        ]
        shared.extend(p for p in last_moe.parameters() if p.requires_grad)

        # GradNorm proxy parameter: the weight of the backbone's final norm.
        proxy = self.model.backbone.final_norm.weight
        mt_cfg = MultiTaskOptimizerConfig(
            enable_gradnorm=cfg.enable_gradnorm,
            enable_pcgrad=cfg.enable_pcgrad,
            gradnorm_alpha=1.5,
            gradnorm_lr=0.025,
            pcgrad_shuffle=True,
        )
        self.mto = MultiTaskOptimizer(
            num_main_tasks=len(MAIN_TASK_KEYS),
            shared_params=shared,
            gradnorm_proxy_param=proxy,
            cfg=mt_cfg,
        )
        # GradNormBalancer is an nn.Module: its raw_weights / initial_losses
        # buffers must live on the model's device, otherwise they would not
        # match the (cuda) losses.
        if self.mto.gradnorm is not None:
            self.mto.gradnorm.to(self.device)

    def _maybe_save_checkpoint(self) -> None:
        """Save a checkpoint every ``ckpt_interval`` steps and optionally upload it.

        Does nothing when ``output_dir`` is unset, ``ckpt_interval`` is not
        positive, or the current step is not a positive multiple of the
        interval. The Hub upload is best-effort: failures are logged as
        warnings and never interrupt training.
        """
        cfg = self.cfg
        if not cfg.output_dir or cfg.ckpt_interval <= 0:
            return
        if self.global_step <= 0 or self.global_step % cfg.ckpt_interval != 0:
            return

        out_dir = Path(cfg.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        ckpt_path = out_dir / f"checkpoint-step{self.global_step}.pt"
        torch.save(
            {
                "step": self.global_step,
                "stage": self._stage,
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            },
            ckpt_path,
        )
        log.info("[Trainer] checkpoint %s", ckpt_path)

        if not cfg.hub_repo_id:
            return
        try:
            # Imported lazily so huggingface_hub is only needed when syncing.
            from huggingface_hub import HfApi, create_repo

            create_repo(cfg.hub_repo_id, repo_type=cfg.hub_repo_type, exist_ok=True)
            api = HfApi()
            rel = f"checkpoints/{ckpt_path.name}"
            api.upload_file(
                path_or_fileobj=str(ckpt_path),
                path_in_repo=rel,
                repo_id=cfg.hub_repo_id,
                repo_type=cfg.hub_repo_type,
                commit_message=f"checkpoint step {self.global_step}",
            )
            log.info("[Trainer] uploaded %s -> %s", rel, cfg.hub_repo_id)
        except Exception as e:  # deliberate best-effort boundary
            log.warning("[Trainer] Hub upload failed: %s", e)

    # ---------- stage switching ----------
    def maybe_switch_stage(self) -> None:
        """Switch from Stage 1 to Stage 2 once ``stage1_steps`` is reached."""
        cfg = self.cfg
        if self._stage != 1 or self.global_step < cfg.stage1_steps:
            return

        log.info("[Trainer] -> Stage 2 (sparse MoE + DINOv3 finetune + PCGrad)")
        self._stage = 2
        # 1) MoE switches to sparse routing.
        self.model.backbone.set_moe_mode("sparse")
        # 2) Router temperature annealing is finished.
        self.model.backbone.set_router_temperature(cfg.router_temp_final)
        # 3) Unfreeze DINOv3 (can be disabled on low-memory devices).
        if cfg.unfreeze_dinov3_at_stage2:
            self.model.dinov3.unfreeze()
        # 4) Rebuild the optimizer so that it covers the DINOv3 parameters.
        # NOTE(review): this also creates a fresh scheduler, AdamW state and
        # MultiTaskOptimizer, so the LR schedule appears to restart from
        # step 0 here — confirm whether that restart is intended.
        self._build_optimizer()

    # ---------- single step helpers ----------
    def _batch_to_device(self, batch: dict) -> dict:
        """Return a copy of ``batch`` with tensors moved to ``self.device``.

        Tensors directly in the batch are moved, as are tensors inside the
        dicts of ``batch["targets"]`` when that entry is a list.
        """

        def move(value):
            return value.to(self.device) if isinstance(value, torch.Tensor) else value

        moved = {key: move(value) for key, value in batch.items()}
        if isinstance(moved.get("targets"), list):
            moved["targets"] = [
                {key: move(value) for key, value in target.items()}
                for target in moved["targets"]
            ]
        return moved

    def _maybe_perturb(self, batch: dict, rng: np.random.Generator):
        """Pick model inputs, injecting a kinematic perturbation at random.

        From ``stage1_perturb_start`` onwards during Stage 1, a perturbation
        is applied with probability 0.5.

        Returns:
            ``(ego_input, intr_input, extr_input, perturb_residual)`` where
            ``perturb_residual`` is ``None`` when no perturbation was applied.
        """
        cfg = self.cfg
        ego_input = batch["ego_6d"]
        intr_input = batch["intr_vec"]
        extr_input = batch["extr_6d"]
        perturb_residual = None
        if (
            self._stage == 1
            and self.global_step >= cfg.stage1_perturb_start
            and rng.uniform() < 0.5
        ):
            from ..data.transforms import perturb_kinematics

            # NOTE(review): the perturbed intrinsics/extrinsics returned here
            # are discarded (the model receives unperturbed clones below),
            # while ``delta`` is still used as the supervision target for the
            # whole batch. The original comment says only sample 0 was meant
            # to be perturbed and that production code should be batched —
            # confirm the intended behaviour against perturb_kinematics.
            ego_input, _intr_perturbed, _extr_perturbed, delta = perturb_kinematics(
                ego_input.cpu().clone(),
                intr_input.cpu().clone()[0],
                extr_input.cpu().clone()[0],
                translation_std_m=0.1,
                rotation_std_deg=0.5,
                intrinsic_std=0.005,
                extrinsic_std=0.005,
                rng=rng,
            )
            ego_input = ego_input.to(self.device)
            intr_input = batch["intr_vec"].clone()
            extr_input = batch["extr_6d"].clone()
            # GT residual: -delta in symlog space (symlog is non-linear, so
            # this is a linear approximation).
            perturb_residual = (
                -delta.to(self.device).unsqueeze(0).expand(ego_input.shape[0], -1)
            )
        return ego_input, intr_input, extr_input, perturb_residual

    def _build_info(self, total, weights, losses: dict, grad_norms: dict) -> dict:
        """Assemble the per-step info dict of plain Python numbers."""
        info = {
            "step": self.global_step,
            "stage": self._stage,
            "total_loss": float(total),
            "weights": [float(w) for w in weights],
            "grad_norms": grad_norms,
        }
        for key, value in losses.items():
            info[key] = float(value.detach())
        return info

    # ---------- single step ----------
    def train_step(self, batch: dict, rng: np.random.Generator) -> dict:
        """Run one micro-batch: forward, losses, backward and maybe an optimizer step.

        With gradient accumulation, the optimizer only steps (and
        ``global_step`` only advances) on every ``grad_accum_steps``-th call;
        the other calls return an info dict with empty ``grad_norms``.

        Returns:
            Dict with ``step``, ``stage``, ``total_loss``, ``weights``,
            ``grad_norms`` and one float entry per loss term.
        """
        cfg = self.cfg
        self.maybe_switch_stage()

        batch = self._batch_to_device(batch)
        # Perturbation injection (switched on mid Stage 1).
        ego_input, intr_input, extr_input, perturb_residual = self._maybe_perturb(batch, rng)

        # Forward pass. autocast wraps only the forward, matching and loss
        # computation.
        ac_ctx = (
            torch.autocast(device_type="cuda", dtype=self.amp_dtype)
            if self.amp_enabled
            else _NullContext()
        )
        with ac_ctx:
            out = self.model(
                images=batch["images"],
                ego_6d_raw=ego_input,
                intr_raw=intr_input,
                extr_6d_raw=extr_input,
            )
            losses = compute_all_losses(
                out, batch, self.matcher, self.num_classes, cfg,
                perturbation_residual=perturb_residual,
            )

        # Promote the losses to fp32 for numerically stable GradNorm/PCGrad.
        main = torch.stack([losses[k].float() for k in MAIN_TASK_KEYS])
        aux = sum(losses[k].float() for k in AUX_TASK_KEYS)
        # Gradient accumulation: average over the accumulated micro-batches.
        if cfg.grad_accum_steps > 1:
            main = main / cfg.grad_accum_steps
            aux = aux / cfg.grad_accum_steps

        # === backward ===
        if self._micro_step == 0:
            self.optimizer.zero_grad(set_to_none=True)
        all_params = [p for p in self.model.parameters() if p.requires_grad]
        if self.scaler is not None:
            # fp16 path: GradScaler does not directly support PCGrad (it would
            # need manual scheduling), so fall back to a plain summed
            # backward. The recommended bf16 path has no such limitation.
            total = main.sum() + aux
            self.scaler.scale(total).backward()
            weights = torch.ones_like(main)
        else:
            total, weights = self.mto.backward(main, aux, all_params)

        self._micro_step += 1
        if self._micro_step < cfg.grad_accum_steps:
            # Still accumulating: report losses but do not step.
            return self._build_info(total, weights, losses, {})

        # === gradient monitoring + clipping + optimizer step ===
        if self.scaler is not None:
            # Unscale first so monitoring and clipping see true gradients.
            self.scaler.unscale_(self.optimizer)
        grad_summary = grad_norm_per_module(self.model, cfg.grad_monitor_threshold)
        torch.nn.utils.clip_grad_norm_(all_params, max_norm=cfg.grad_clip)
        if self.scaler is not None:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.scheduler.step()
        self._micro_step = 0

        # Linear router-temperature annealing during Stage 1.
        if self._stage == 1:
            ratio = min(1.0, self.global_step / max(1, cfg.stage1_steps))
            temperature = cfg.router_temp_init + ratio * (
                cfg.router_temp_final - cfg.router_temp_init
            )
            self.model.backbone.set_router_temperature(temperature)

        self.global_step += 1
        self._maybe_save_checkpoint()
        return self._build_info(total, weights, losses, grad_summary)

    def fit(self, loader: DataLoader, max_steps: int | None = None) -> None:
        """Simplified training loop.

        Calls ``train_step`` ``max_steps`` times (``cfg.total_steps`` when
        ``max_steps`` is ``None``), restarting the loader whenever it is
        exhausted. ``max_steps=0`` runs no steps.

        Raises:
            RuntimeError: if the loader yields no batches at all.
        """
        rng = np.random.default_rng(0)
        # ``is None`` rather than ``or`` so an explicit 0 is honoured.
        steps = self.cfg.total_steps if max_steps is None else max_steps
        it = iter(loader)
        for _ in range(steps):
            try:
                batch = next(it)
            except StopIteration:
                it = iter(loader)
                try:
                    batch = next(it)
                except StopIteration:
                    raise RuntimeError(
                        "loader yielded no batches; cannot continue training"
                    ) from None

            step_before = self.global_step
            info = self.train_step(batch, rng)
            # Only log after a real optimizer step; accumulation micro-steps
            # leave ``global_step`` unchanged and would repeat the same line.
            stepped = self.global_step != step_before
            if stepped and info["step"] % self.cfg.log_interval == 0:
                log.info(
                    "step=%d stage=%d total=%.4f cls=%.4f box=%.4f isdyn=%.4f traj_obj=%.4f traj_ego=%.4f ctrl=%.4f moe=%.4f calib=%.4f",
                    info["step"], info["stage"], info["total_loss"],
                    info["L_cls"], info["L_box"], info["L_isdyn"],
                    info["L_traj_obj"], info["L_traj_ego"], info["L_ctrl"],
                    info["L_moe"], info["L_calib"],
                )