"""两阶段训练器 + 梯度监控。 Stage 1 (Dense): - MoE 全部专家加权(dense 模式); - 路由温度初始 < 1(锐化),训练中线性升到 1; - DINOv3 冻结; - 中期开启运动学/内外参扰动,监督校准网络; - GradNorm 启用。 Stage 2 (Sparse): - MoE 切 Top-3; - 路由温度退火完成; - DINOv3 解冻并采用 1/100 主干 LR; - GradNorm + PCGrad 同时启用。 """ from __future__ import annotations import logging import math from dataclasses import dataclass from pathlib import Path from typing import Sequence import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader from ..losses import ( HungarianMatcher, action_nll, calibration_regularization, detection_losses, ego_traj_nll, moe_load_balance_and_boundary, object_traj_nll, ) from ..model import E2EAVModel, E2EOutput from .multitask import MultiTaskOptimizer, MultiTaskOptimizerConfig from .schedule import build_scheduler log = logging.getLogger(__name__) class _NullContext: """空 context manager,用于 AMP 关闭时占位 autocast。""" def __enter__(self): return self def __exit__(self, exc_type, exc, tb): return False @dataclass class TrainerConfig: """Trainer 超参数(与 ``configs/default.yaml`` 对齐)。""" total_steps: int = 100000 warmup_steps: int = 1000 base_lr: float = 2.0e-4 min_lr: float = 1.0e-6 weight_decay: float = 0.05 grad_clip: float = 1.0 log_interval: int = 20 ckpt_interval: int = 1000 stage1_steps: int = 60000 stage1_perturb_start: int = 20000 grad_monitor_threshold: float = 1e-7 # === AMP / 混合精度 === # "fp32" / "bf16" / "fp16"。默认 bf16(H100/A100 推荐,无需 GradScaler)。 mixed_precision: str = "bf16" grad_accum_steps: int = 1 # MoE moe_load_balance_weight: float = 0.01 moe_boundary_weight: float = 0.001 router_temp_init: float = 0.5 router_temp_final: float = 1.0 # 损失初始权重(GradNorm 自适应主任务 1-6) loss_giou_weight: float = 0.5 loss_calib_weight: float = 0.1 # MultiTask(GradNorm + PCGrad 在 Stage1/Stage2 全程启用—— # 两阶段的 6 项主任务都存在尺度不均与梯度冲突,PCGrad 不应延迟到 Stage2) enable_gradnorm: bool = True enable_pcgrad: bool = True # 参数组 dinov3_lr_mult_stage2: float = 0.01 # 显存吃紧的设备(如 a10g-small)上可关闭 Stage2 DINOv3 解冻,保持冻结 unfreeze_dinov3_at_stage2: bool = True backbone_lr_mult: float = 1.0 calibration_lr_mult: float = 0.1 head_lr_mult: float = 1.0 gate_lr_mult: float = 0.1 # 检查点:目录 + 可选同步到 Hub model repo output_dir: str | None = None hub_repo_id: str | None = None hub_repo_type: str = "model" def _is_gate_param(name: str) -> bool: return ".gate." in name or name.endswith(".gate_proj.weight") or name.endswith(".gate_proj.bias") def build_param_groups(model: E2EAVModel, base_lr: float, cfg: TrainerConfig, stage: int) -> list[dict]: """按模块归类参数为不同 LR 组。Stage1 时 DINOv3 lr=0。""" groups: dict[str, list[nn.Parameter]] = { "dinov3": [], "backbone": [], "calibration": [], "head": [], "gate": [], "other": [], } for name, p in model.named_parameters(): if not p.requires_grad and stage == 1: continue if name.startswith("dinov3."): groups["dinov3"].append(p) elif name.startswith("backbone."): if _is_gate_param(name): groups["gate"].append(p) else: groups["backbone"].append(p) elif name.startswith("calib."): if _is_gate_param(name): groups["gate"].append(p) else: groups["calibration"].append(p) elif name.startswith("det_traj_head.") or name.startswith("ctrl_head."): groups["head"].append(p) else: groups["other"].append(p) dinov3_lr = base_lr * (cfg.dinov3_lr_mult_stage2 if stage == 2 else 0.0) return [ {"params": groups["dinov3"], "lr": dinov3_lr, "name": "dinov3"}, {"params": groups["backbone"], "lr": base_lr * cfg.backbone_lr_mult, "name": "backbone"}, {"params": groups["calibration"], "lr": base_lr * cfg.calibration_lr_mult, "name": "calibration"}, {"params": groups["head"], "lr": base_lr * cfg.head_lr_mult, "name": "head"}, {"params": groups["gate"], "lr": base_lr * cfg.gate_lr_mult, "name": "gate"}, {"params": groups["other"], "lr": base_lr, "name": "other"}, ] def grad_norm_per_module(model: nn.Module, threshold: float) -> dict[str, float]: """统计各顶层模块的 grad-norm,返回 dict(用于日志/告警)。 跳过: - 没有任何 ``requires_grad=True`` 参数的模块(如冻结的 DINOv3、纯 buffer 模块 RoPE); - 空模块(参数计数为 0)。 """ summary: dict[str, float] = {} for name, child in model.named_children(): params = list(child.parameters()) if not params: continue if not any(p.requires_grad for p in params): # 整个模块被冻结 -> 不监控 continue total = 0.0 seen = 0 for p in params: if p.grad is not None: total += float(p.grad.detach().norm().item()) ** 2 seen += 1 if seen == 0: continue n = math.sqrt(total) summary[name] = n if n < threshold: log.warning("[grad_monitor] %s grad_norm=%.3e < %.3e", name, n, threshold) if not math.isfinite(n): log.error("[grad_monitor] %s grad_norm is %s (NaN/Inf)", name, n) return summary def compute_all_losses( model_out: E2EOutput, batch: dict, matcher: HungarianMatcher, num_classes: int, cfg: TrainerConfig, perturbation_residual: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: """计算 8 项损失,返回字典。 ``perturbation_residual``:扰动训练时给定的 ground-truth 残差,用于额外 监督校准网络;正常训练为 None。 """ targets = batch["targets"] det_out = model_out.detection ctrl_out = model_out.control calib = model_out.calibration det_losses = detection_losses( cls_logits=det_out.cls_logits, box_mu=det_out.box3d_mu, box_log_sigma=det_out.box3d_log_sigma, isdyn_logit=det_out.is_dynamic_logit, targets=targets, matcher=matcher, num_classes=num_classes, ) L_traj_obj = object_traj_nll( det_out.traj_mu, det_out.traj_log_sigma, det_losses.matched_indices, targets, ) L_traj_ego = ego_traj_nll( ctrl_out.ego_traj_mu, ctrl_out.ego_traj_log_sigma, batch["ego_future"], valid=batch.get("ego_future_valid"), ) # 全局动作 GT 通常没有;此处用 0 做占位(实际数据集需补齐) action_target = batch.get("action_target") if action_target is None: action_target = torch.zeros_like(ctrl_out.action_mu) L_ctrl = action_nll( ctrl_out.action_mu, ctrl_out.action_log_sigma, action_target ) + L_traj_ego # 控制损失 = action + ego_traj 复用同一项的便利封装;trainer 视情况拆分 # MoE / 校准 正则 L_moe = moe_load_balance_and_boundary( model_out.backbone_out.moe_stats, load_balance_weight=cfg.moe_load_balance_weight, boundary_weight=cfg.moe_boundary_weight, ) L_calib_reg = calibration_regularization( calib.ego_residual, calib.intr_residual, calib.extr_residual, l2_weight=1.0, ) if perturbation_residual is not None: # 扰动训练:计算校准网络应该预测的 GT 残差与实际残差的 MSE actual = torch.cat( [calib.ego_residual.flatten(1), calib.intr_residual, calib.extr_residual], dim=-1, ) L_calib_reg = L_calib_reg + 1.0 * (actual - perturbation_residual).pow(2).mean() return { "L_cls": det_losses.cls_loss, "L_box": det_losses.box_nll + cfg.loss_giou_weight * det_losses.giou_loss, "L_isdyn": det_losses.isdyn_loss, "L_traj_obj": L_traj_obj, "L_traj_ego": L_traj_ego, "L_ctrl": L_ctrl, "L_moe": L_moe, "L_calib": L_calib_reg, } MAIN_TASK_KEYS = ["L_cls", "L_box", "L_isdyn", "L_traj_obj", "L_traj_ego", "L_ctrl"] AUX_TASK_KEYS = ["L_moe", "L_calib"] class Trainer: """端到端训练器。""" def __init__( self, model: E2EAVModel, cfg: TrainerConfig, num_classes: int = 22, device: str = "cuda", ) -> None: self.model = model.to(device) self.cfg = cfg self.num_classes = num_classes self.device = device self.matcher = HungarianMatcher() self.global_step = 0 self._micro_step = 0 # 用于 grad_accum self._stage = 1 self._build_optimizer() # === AMP 配置 === # 仅在 device 为 cuda 时启用 autocast(CPU 上 bf16 也能跑但收益极小)。 amp_dtype_map = { "fp32": None, "bf16": torch.bfloat16, "fp16": torch.float16, } self.amp_dtype = amp_dtype_map[cfg.mixed_precision] self.amp_enabled = self.amp_dtype is not None and "cuda" in str(device) # GradScaler 仅 fp16 需要;bf16 数值范围大无需 scaler self.scaler = ( torch.amp.GradScaler("cuda") if (self.amp_enabled and self.amp_dtype == torch.float16) else None ) # MoE 初始模式 = dense;Stage2 切 sparse self.model.backbone.set_moe_mode("dense") self.model.backbone.set_router_temperature(cfg.router_temp_init) # ---------- 优化器构建 ---------- def _build_optimizer(self) -> None: cfg = self.cfg groups = build_param_groups(self.model, cfg.base_lr, cfg, stage=self._stage) self.optimizer = torch.optim.AdamW(groups, weight_decay=cfg.weight_decay, betas=(0.9, 0.95)) self.scheduler = build_scheduler( self.optimizer, warmup_steps=cfg.warmup_steps, total_steps=cfg.total_steps, base_lr=cfg.base_lr, min_lr=cfg.min_lr, ) # PCGrad 共享参数 = 主干最后的“共享瓶颈”:final_norm + 最后 1 层 MoE block。 # 不把全部 DINOv3/Calib/Backbone 都纳入,否则 N 个任务 × full-grad 扁平副本会 # 在 a10g-small 上瞬间 OOM(~600M 参数 × 6 任务 × 2 副本 ≈ 28 GB)。 # 较前的层仍享受 GradNorm 自适应加权 + 共同求和的标准多任务训练。 shared: list[nn.Parameter] = [] last_moe = self.model.backbone.moe_layers[-1] for p in self.model.backbone.final_norm.parameters(): if p.requires_grad: shared.append(p) for p in last_moe.parameters(): if p.requires_grad: shared.append(p) # GradNorm 代理参数:取主干最后 LayerNorm 的 weight proxy = self.model.backbone.final_norm.weight mt_cfg = MultiTaskOptimizerConfig( enable_gradnorm=cfg.enable_gradnorm, enable_pcgrad=cfg.enable_pcgrad, gradnorm_alpha=1.5, gradnorm_lr=0.025, pcgrad_shuffle=True, ) self.mto = MultiTaskOptimizer( num_main_tasks=len(MAIN_TASK_KEYS), shared_params=shared, gradnorm_proxy_param=proxy, cfg=mt_cfg, ) # GradNormBalancer 是 nn.Module,需要把 raw_weights / initial_losses 缓冲 # 移到 model 所在 device,否则与 losses (cuda) 设备不匹配。 if self.mto.gradnorm is not None: self.mto.gradnorm.to(self.device) def _maybe_save_checkpoint(self) -> None: cfg = self.cfg if not cfg.output_dir or cfg.ckpt_interval <= 0: return if self.global_step <= 0 or self.global_step % cfg.ckpt_interval != 0: return od = Path(cfg.output_dir) od.mkdir(parents=True, exist_ok=True) ckpt_path = od / f"checkpoint-step{self.global_step}.pt" torch.save( { "step": self.global_step, "stage": self._stage, "model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), }, ckpt_path, ) log.info("[Trainer] checkpoint %s", ckpt_path) if not cfg.hub_repo_id: return try: from huggingface_hub import HfApi, create_repo create_repo(cfg.hub_repo_id, repo_type=cfg.hub_repo_type, exist_ok=True) api = HfApi() rel = f"checkpoints/{ckpt_path.name}" api.upload_file( path_or_fileobj=str(ckpt_path), path_in_repo=rel, repo_id=cfg.hub_repo_id, repo_type=cfg.hub_repo_type, commit_message=f"checkpoint step {self.global_step}", ) log.info("[Trainer] uploaded %s -> %s", rel, cfg.hub_repo_id) except Exception as e: log.warning("[Trainer] Hub upload failed: %s", e) # ---------- 阶段切换 ---------- def maybe_switch_stage(self) -> None: cfg = self.cfg if self._stage == 1 and self.global_step >= cfg.stage1_steps: log.info("[Trainer] -> Stage 2 (sparse MoE + DINOv3 finetune + PCGrad)") self._stage = 2 # 1) MoE 切 sparse self.model.backbone.set_moe_mode("sparse") # 2) 路由温度退火完成 self.model.backbone.set_router_temperature(cfg.router_temp_final) # 3) DINOv3 解冻(小显存设备可禁用) if cfg.unfreeze_dinov3_at_stage2: self.model.dinov3.unfreeze() # 4) 重建优化器(包含 DINOv3 参数)+ 启用 PCGrad self._build_optimizer() # ---------- 单步 ---------- def train_step(self, batch: dict, rng: np.random.Generator) -> dict: cfg = self.cfg self.maybe_switch_stage() # 移到 device batch = {k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} # targets 是 list of dict,里面的 tensor 也移到 device if "targets" in batch and isinstance(batch["targets"], list): new_targets = [] for t in batch["targets"]: new_targets.append({k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in t.items()}) batch["targets"] = new_targets # 扰动注入(Stage1 中期开启) perturb_residual = None ego_input = batch["ego_6d"] intr_input = batch["intr_vec"] extr_input = batch["extr_6d"] if ( self._stage == 1 and self.global_step >= cfg.stage1_perturb_start and rng.uniform() < 0.5 ): from ..data.transforms import perturb_kinematics ego_input, intr_input, extr_input, delta = perturb_kinematics( ego_input.cpu().clone(), intr_input.cpu().clone()[0], extr_input.cpu().clone()[0], translation_std_m=0.1, rotation_std_deg=0.5, intrinsic_std=0.005, extrinsic_std=0.005, rng=rng, ) ego_input = ego_input.to(self.device) # intr/extr 是 [B,...] 而 perturb_kinematics 是单样本;这里为简洁仅扰动第 0 个样本 # 实际生产中应 batched 实现 intr_input = batch["intr_vec"].clone() extr_input = batch["extr_6d"].clone() # GT 残差:在 symlog 空间 = -delta(symlog 是非线性,这里用线性近似) perturb_residual = -delta.to(self.device).unsqueeze(0).expand(ego_input.shape[0], -1) # 前向(AMP autocast 仅包住 forward 与匹配/损失,反传由 PyTorch # 在 fp32 主梯度下完成;GradNorm/PCGrad 内的 autograd.grad 也在 fp32) ac_ctx = ( torch.autocast(device_type="cuda", dtype=self.amp_dtype) if self.amp_enabled else _NullContext() ) with ac_ctx: out = self.model( images=batch["images"], ego_6d_raw=ego_input, intr_raw=intr_input, extr_6d_raw=extr_input, ) losses = compute_all_losses( out, batch, self.matcher, self.num_classes, cfg, perturbation_residual=perturb_residual, ) # === 把损失提升到 fp32 以保证后续 GradNorm/PCGrad 数值稳定 === main = torch.stack([losses[k].float() for k in MAIN_TASK_KEYS]) aux = sum(losses[k].float() for k in AUX_TASK_KEYS) # 梯度累积:对累积步数取平均 if cfg.grad_accum_steps > 1: main = main / cfg.grad_accum_steps aux = aux / cfg.grad_accum_steps # === 反传 === if self._micro_step == 0: self.optimizer.zero_grad(set_to_none=True) all_params = [p for p in self.model.parameters() if p.requires_grad] if self.scaler is not None: # fp16 路径:GradScaler 不直接支持 PCGrad(需手动调度);这里 # 退化为标准 sum-backward。bf16 推荐路径无此限制。 total = main.sum() + aux self.scaler.scale(total).backward() weights = torch.ones_like(main) else: total, weights = self.mto.backward(main, aux, all_params) self._micro_step += 1 do_step = self._micro_step >= cfg.grad_accum_steps if not do_step: info_partial = { "step": self.global_step, "stage": self._stage, "total_loss": float(total), "weights": [float(w) for w in weights], "grad_norms": {}, } for k, v in losses.items(): info_partial[k] = float(v.detach()) return info_partial # === 梯度裁剪 + 监控 + step === if self.scaler is not None: self.scaler.unscale_(self.optimizer) grad_summary = grad_norm_per_module(self.model, cfg.grad_monitor_threshold) torch.nn.utils.clip_grad_norm_(all_params, max_norm=cfg.grad_clip) if self.scaler is not None: self.scaler.step(self.optimizer) self.scaler.update() else: self.optimizer.step() self.scheduler.step() self._micro_step = 0 # 路由温度线性退火 if self._stage == 1: ratio = min(1.0, self.global_step / max(1, cfg.stage1_steps)) t = cfg.router_temp_init + ratio * (cfg.router_temp_final - cfg.router_temp_init) self.model.backbone.set_router_temperature(t) self.global_step += 1 self._maybe_save_checkpoint() info = { "step": self.global_step, "stage": self._stage, "total_loss": float(total), "weights": [float(w) for w in weights], "grad_norms": grad_summary, } for k, v in losses.items(): info[k] = float(v.detach()) return info def fit(self, loader: DataLoader, max_steps: int | None = None) -> None: """简化训练循环。""" rng = np.random.default_rng(0) steps = max_steps or self.cfg.total_steps it = iter(loader) for _ in range(steps): try: batch = next(it) except StopIteration: it = iter(loader) batch = next(it) info = self.train_step(batch, rng) if info["step"] % self.cfg.log_interval == 0: log.info( "step=%d stage=%d total=%.4f cls=%.4f box=%.4f isdyn=%.4f traj_obj=%.4f traj_ego=%.4f ctrl=%.4f moe=%.4f calib=%.4f", info["step"], info["stage"], info["total_loss"], info["L_cls"], info["L_box"], info["L_isdyn"], info["L_traj_obj"], info["L_traj_ego"], info["L_ctrl"], info["L_moe"], info["L_calib"], )