| """两阶段训练器 + 梯度监控。 |
| |
| Stage 1 (Dense): |
| - MoE 全部专家加权(dense 模式); |
| - 路由温度初始 < 1(锐化),训练中线性升到 1; |
| - DINOv3 冻结; |
| - 中期开启运动学/内外参扰动,监督校准网络; |
| - GradNorm 启用。 |
| |
| Stage 2 (Sparse): |
| - MoE 切 Top-3; |
| - 路由温度退火完成; |
| - DINOv3 解冻并采用 1/100 主干 LR; |
| - GradNorm + PCGrad 同时启用。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import math |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Sequence |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
|
|
| from ..losses import ( |
| HungarianMatcher, |
| action_nll, |
| calibration_regularization, |
| detection_losses, |
| ego_traj_nll, |
| moe_load_balance_and_boundary, |
| object_traj_nll, |
| ) |
| from ..model import E2EAVModel, E2EOutput |
| from .multitask import MultiTaskOptimizer, MultiTaskOptimizerConfig |
| from .schedule import build_scheduler |
|
|
# Module-level logger shared by every function and class in this module.
log: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
| class _NullContext: |
| """空 context manager,用于 AMP 关闭时占位 autocast。""" |
|
|
| def __enter__(self): |
| return self |
|
|
| def __exit__(self, exc_type, exc, tb): |
| return False |
|
|
|
|
@dataclass
class TrainerConfig:
    """Trainer hyper-parameters (aligned with ``configs/default.yaml``).

    Fields are positional in declaration order, so new fields must be appended
    at the end. ``__post_init__`` rejects values the trainer cannot use.

    Raises:
        ValueError: if ``mixed_precision`` is not one of ``"fp32"``, ``"bf16"``,
            ``"fp16"``, or if ``grad_accum_steps < 1``.
    """

    # Optimisation and LR schedule.
    total_steps: int = 100000
    warmup_steps: int = 1000
    base_lr: float = 2.0e-4
    min_lr: float = 1.0e-6
    weight_decay: float = 0.05
    grad_clip: float = 1.0  # max total grad-norm passed to clip_grad_norm_
    log_interval: int = 20  # log every N optimizer steps
    ckpt_interval: int = 1000  # checkpoint every N optimizer steps; <= 0 disables
    # Stage schedule: Stage 2 starts once global_step >= stage1_steps; the
    # kinematic perturbation is active in Stage 1 from stage1_perturb_start on.
    stage1_steps: int = 60000
    stage1_perturb_start: int = 20000
    # Per-module grad-norms below this value are logged as warnings.
    grad_monitor_threshold: float = 1e-7

    # Autocast dtype: "fp32" (AMP off), "bf16" or "fp16" (fp16 adds loss scaling).
    mixed_precision: str = "bf16"
    # Micro-batches whose gradients are accumulated per optimizer step.
    grad_accum_steps: int = 1

    # MoE auxiliary-loss weights and router temperature (annealed init -> final
    # over Stage 1).
    moe_load_balance_weight: float = 0.01
    moe_boundary_weight: float = 0.001
    router_temp_init: float = 0.5
    router_temp_final: float = 1.0

    # Loss-term weights.
    loss_giou_weight: float = 0.5
    loss_calib_weight: float = 0.1

    # Multi-task gradient balancing.
    enable_gradnorm: bool = True
    enable_pcgrad: bool = True

    # DINOv3 LR in Stage 2 is base_lr * dinov3_lr_mult_stage2 (0 in Stage 1).
    dinov3_lr_mult_stage2: float = 0.01

    unfreeze_dinov3_at_stage2: bool = True
    # Per-parameter-group LR multipliers applied to base_lr.
    backbone_lr_mult: float = 1.0
    calibration_lr_mult: float = 0.1
    head_lr_mult: float = 1.0
    gate_lr_mult: float = 0.1

    # Checkpoint directory (None disables saving) and optional Hub upload target.
    output_dir: str | None = None
    hub_repo_id: str | None = None
    hub_repo_type: str = "model"

    def __post_init__(self) -> None:
        # Fail at construction with a clear message instead of a bare KeyError
        # (or a silently degenerate accumulation setting) inside the Trainer.
        allowed_precisions = ("fp32", "bf16", "fp16")
        if self.mixed_precision not in allowed_precisions:
            raise ValueError(
                f"mixed_precision must be one of {allowed_precisions}, "
                f"got {self.mixed_precision!r}"
            )
        if self.grad_accum_steps < 1:
            raise ValueError(
                f"grad_accum_steps must be >= 1, got {self.grad_accum_steps}"
            )
|
|
|
|
| def _is_gate_param(name: str) -> bool: |
| return ".gate." in name or name.endswith(".gate_proj.weight") or name.endswith(".gate_proj.bias") |
|
|
|
|
def build_param_groups(model: E2EAVModel, base_lr: float, cfg: TrainerConfig, stage: int) -> list[dict]:
    """Partition the model's parameters into named AdamW parameter groups.

    Routing is by parameter-name prefix:

    * ``dinov3.*``                            -> ``dinov3``
    * ``backbone.*`` / ``calib.*``            -> ``gate`` when ``_is_gate_param``
      matches, otherwise ``backbone`` / ``calibration``
    * ``det_traj_head.*`` / ``ctrl_head.*``   -> ``head``
    * anything else                           -> ``other``

    When ``stage == 1``, parameters with ``requires_grad=False`` are left out.
    The ``dinov3`` group's LR is ``base_lr * cfg.dinov3_lr_mult_stage2`` when
    ``stage == 2`` and 0 otherwise. The other groups use ``base_lr`` times
    their ``cfg.*_lr_mult``; ``other`` uses ``base_lr`` as is.

    Returns:
        Six dicts with keys ``params``, ``lr``, ``name``, in the fixed order
        dinov3, backbone, calibration, head, gate, other. Groups may be empty.
    """

    def route(param_name: str) -> str:
        # Name-prefix dispatch; gate params are pulled out of backbone/calib.
        if param_name.startswith("dinov3."):
            return "dinov3"
        if param_name.startswith("backbone."):
            return "gate" if _is_gate_param(param_name) else "backbone"
        if param_name.startswith("calib."):
            return "gate" if _is_gate_param(param_name) else "calibration"
        if param_name.startswith(("det_traj_head.", "ctrl_head.")):
            return "head"
        return "other"

    group_order = ("dinov3", "backbone", "calibration", "head", "gate", "other")
    buckets: dict[str, list[nn.Parameter]] = {key: [] for key in group_order}

    drop_frozen = stage == 1
    for param_name, param in model.named_parameters():
        if drop_frozen and not param.requires_grad:
            continue
        buckets[route(param_name)].append(param)

    dinov3_scale = cfg.dinov3_lr_mult_stage2 if stage == 2 else 0.0
    group_lr = {
        "dinov3": base_lr * dinov3_scale,
        "backbone": base_lr * cfg.backbone_lr_mult,
        "calibration": base_lr * cfg.calibration_lr_mult,
        "head": base_lr * cfg.head_lr_mult,
        "gate": base_lr * cfg.gate_lr_mult,
        "other": base_lr,
    }
    return [
        {"params": buckets[key], "lr": group_lr[key], "name": key}
        for key in group_order
    ]
|
|
|
|
def grad_norm_per_module(model: nn.Module, threshold: float) -> dict[str, float]:
    """Compute the L2 gradient norm of each direct child module of ``model``.

    Meant for logging/alerting after ``backward()`` (and after any
    ``GradScaler.unscale_``). A child is skipped when it

    * owns no parameters at all (e.g. a buffer-only module),
    * owns no parameter with ``requires_grad=True`` (e.g. a frozen encoder), or
    * has no parameter with a populated ``.grad``.

    A warning is logged when a module's norm is below ``threshold``. An error
    is logged when the norm is NaN/Inf.

    Args:
        model: module whose ``named_children()`` are inspected.
        threshold: norms strictly below this value trigger a warning.

    Returns:
        Mapping ``child_name -> grad_norm`` for every child that was not skipped.
    """
    summary: dict[str, float] = {}
    for name, child in model.named_children():
        params = list(child.parameters())
        if not params or not any(p.requires_grad for p in params):
            continue
        grads = [p.grad.detach() for p in params if p.grad is not None]
        if not grads:
            continue
        # Reduce on the device and sync once per module. The previous
        # per-parameter ``.item()`` forced one host-device sync per tensor on
        # every optimizer step.
        per_param = [g.norm().float() for g in grads]
        device = per_param[0].device
        n = float(torch.stack([x.to(device) for x in per_param]).norm().item())
        summary[name] = n
        if n < threshold:
            log.warning("[grad_monitor] %s grad_norm=%.3e < %.3e", name, n, threshold)
        if not math.isfinite(n):
            log.error("[grad_monitor] %s grad_norm is %s (NaN/Inf)", name, n)
    return summary
|
|
|
|
def compute_all_losses(
    model_out: E2EOutput,
    batch: dict,
    matcher: HungarianMatcher,
    num_classes: int,
    cfg: TrainerConfig,
    perturbation_residual: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the 8 training loss terms and return them keyed by name.

    Returned keys:

    * ``L_cls``, ``L_isdyn``: detection classification / is-dynamic losses.
    * ``L_box``: box NLL + ``cfg.loss_giou_weight`` * GIoU loss.
    * ``L_traj_obj``: object-trajectory NLL on the matched detections.
    * ``L_traj_ego``: ego-trajectory NLL.
    * ``L_ctrl``: action NLL + ``L_traj_ego``.
    * ``L_moe``: MoE load-balance + boundary loss.
    * ``L_calib``: calibration-residual regulariser, plus (when
      ``perturbation_residual`` is given) the MSE between the predicted
      residuals and that target. The sum is scaled by ``cfg.loss_calib_weight``.

    Args:
        model_out: model forward output.
        batch: reads ``targets`` and ``ego_future``, and optionally
            ``ego_future_valid`` and ``action_target``.
        matcher: Hungarian matcher used by the detection losses.
        num_classes: number of detection classes.
        cfg: trainer config (loss weights).
        perturbation_residual: ground-truth residual used during perturbation
            training. It is compared against
            ``cat([ego_residual.flatten(1), intr_residual, extr_residual], -1)``.
            ``None`` in normal training.
    """
    targets = batch["targets"]
    det_out = model_out.detection
    ctrl_out = model_out.control
    calib = model_out.calibration

    # Detection losses; also yields the matched indices reused below.
    det_losses = detection_losses(
        cls_logits=det_out.cls_logits,
        box_mu=det_out.box3d_mu,
        box_log_sigma=det_out.box3d_log_sigma,
        isdyn_logit=det_out.is_dynamic_logit,
        targets=targets,
        matcher=matcher,
        num_classes=num_classes,
    )

    L_traj_obj = object_traj_nll(
        det_out.traj_mu,
        det_out.traj_log_sigma,
        det_losses.matched_indices,
        targets,
    )

    L_traj_ego = ego_traj_nll(
        ctrl_out.ego_traj_mu,
        ctrl_out.ego_traj_log_sigma,
        batch["ego_future"],
        valid=batch.get("ego_future_valid"),
    )

    action_target = batch.get("action_target")
    if action_target is None:
        # NOTE(review): without a label the action head is pulled towards an
        # all-zero action; confirm this fallback is intended.
        action_target = torch.zeros_like(ctrl_out.action_mu)
    # L_traj_ego is also returned as its own main-task entry, so it is counted
    # both there and inside L_ctrl.
    L_ctrl = action_nll(
        ctrl_out.action_mu, ctrl_out.action_log_sigma, action_target
    ) + L_traj_ego

    L_moe = moe_load_balance_and_boundary(
        model_out.backbone_out.moe_stats,
        load_balance_weight=cfg.moe_load_balance_weight,
        boundary_weight=cfg.moe_boundary_weight,
    )

    L_calib = calibration_regularization(
        calib.ego_residual, calib.intr_residual, calib.extr_residual,
        l2_weight=1.0,
    )
    if perturbation_residual is not None:
        # Supervise the predicted residuals against the known perturbation.
        predicted = torch.cat(
            [calib.ego_residual.flatten(1), calib.intr_residual, calib.extr_residual],
            dim=-1,
        )
        L_calib = L_calib + (predicted - perturbation_residual).pow(2).mean()
    # Apply the configured weight, mirroring how loss_giou_weight enters L_box.
    L_calib = cfg.loss_calib_weight * L_calib

    return {
        "L_cls": det_losses.cls_loss,
        "L_box": det_losses.box_nll + cfg.loss_giou_weight * det_losses.giou_loss,
        "L_isdyn": det_losses.isdyn_loss,
        "L_traj_obj": L_traj_obj,
        "L_traj_ego": L_traj_ego,
        "L_ctrl": L_ctrl,
        "L_moe": L_moe,
        "L_calib": L_calib,
    }
|
|
|
|
# Main-task losses: stacked (in this order) into one vector and handed to the
# multi-task optimizer as separate tasks.
MAIN_TASK_KEYS = [
    "L_cls",
    "L_box",
    "L_isdyn",
    "L_traj_obj",
    "L_traj_ego",
    "L_ctrl",
]
# Auxiliary losses: summed into a single scalar instead of being stacked.
AUX_TASK_KEYS = [
    "L_moe",
    "L_calib",
]
|
|
|
|
class Trainer:
    """End-to-end two-stage trainer.

    Stage 1 (``global_step < cfg.stage1_steps``):

    * dense MoE;
    * router temperature annealed linearly from ``cfg.router_temp_init``
      towards ``cfg.router_temp_final``;
    * random kinematic perturbation from ``cfg.stage1_perturb_start`` on;
    * GradNorm only (if enabled).

    Stage 2:

    * sparse MoE;
    * final router temperature;
    * DINOv3 optionally unfrozen;
    * GradNorm + PCGrad (each if enabled).

    On the switch, the optimizer, the LR scheduler and the multi-task optimizer
    are rebuilt. The new scheduler is advanced to ``global_step``, so warmup and
    decay continue instead of restarting. AdamW moments and multi-task-optimizer
    state start fresh.
    """

    def __init__(
        self,
        model: E2EAVModel,
        cfg: TrainerConfig,
        num_classes: int = 22,
        device: str = "cuda",
    ) -> None:
        self.model = model.to(device)
        self.cfg = cfg
        self.num_classes = num_classes
        self.device = device
        self.matcher = HungarianMatcher()
        self.global_step = 0  # completed optimizer steps
        self._micro_step = 0  # backward passes accumulated since the last optimizer step
        self._stage = 1
        self._build_optimizer()

        # Autocast dtype; "fp32" disables AMP. AMP is only used on CUDA devices.
        amp_dtype_map = {
            "fp32": None,
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
        }
        self.amp_dtype = amp_dtype_map[cfg.mixed_precision]
        self.amp_enabled = self.amp_dtype is not None and "cuda" in str(device)
        # Loss scaling is only used for fp16 autocast.
        self.scaler = (
            torch.amp.GradScaler("cuda")
            if (self.amp_enabled and self.amp_dtype == torch.float16)
            else None
        )

        # Stage 1 starts with dense MoE and the initial router temperature.
        self.model.backbone.set_moe_mode("dense")
        self.model.backbone.set_router_temperature(cfg.router_temp_init)

    def _build_optimizer(self) -> None:
        """(Re)build AdamW, the LR scheduler and the multi-task optimizer for the current stage."""
        cfg = self.cfg
        groups = build_param_groups(self.model, cfg.base_lr, cfg, stage=self._stage)
        self.optimizer = torch.optim.AdamW(groups, weight_decay=cfg.weight_decay, betas=(0.9, 0.95))
        self.scheduler = build_scheduler(
            self.optimizer,
            warmup_steps=cfg.warmup_steps,
            total_steps=cfg.total_steps,
            base_lr=cfg.base_lr,
            min_lr=cfg.min_lr,
        )
        # A freshly built scheduler starts at step 0. When rebuilding mid-run
        # (stage switch), advance it by the optimizer steps already taken, so
        # warmup and the total_steps decay horizon are not restarted.
        if self.global_step > 0:
            import warnings

            with warnings.catch_warnings():
                # torch LR schedulers warn when step() precedes optimizer.step().
                warnings.simplefilter("ignore", UserWarning)
                for _ in range(self.global_step):
                    self.scheduler.step()

        # Parameters shared across tasks for GradNorm/PCGrad: the trainable
        # params of the backbone's final norm and of its last MoE layer.
        last_moe = self.model.backbone.moe_layers[-1]
        shared: list[nn.Parameter] = [
            p for p in self.model.backbone.final_norm.parameters() if p.requires_grad
        ]
        shared.extend(p for p in last_moe.parameters() if p.requires_grad)
        proxy = self.model.backbone.final_norm.weight

        mt_cfg = MultiTaskOptimizerConfig(
            enable_gradnorm=cfg.enable_gradnorm,
            # PCGrad is a Stage-2 feature (Stage 1 uses GradNorm only).
            enable_pcgrad=cfg.enable_pcgrad and self._stage == 2,
            gradnorm_alpha=1.5,
            gradnorm_lr=0.025,
            pcgrad_shuffle=True,
        )
        self.mto = MultiTaskOptimizer(
            num_main_tasks=len(MAIN_TASK_KEYS),
            shared_params=shared,
            gradnorm_proxy_param=proxy,
            cfg=mt_cfg,
        )
        if self.mto.gradnorm is not None:
            self.mto.gradnorm.to(self.device)

    def _maybe_save_checkpoint(self) -> None:
        """Save, and optionally upload, a checkpoint every ``cfg.ckpt_interval`` steps.

        This is a no-op when ``cfg.output_dir`` is unset, when
        ``ckpt_interval <= 0``, or when ``global_step`` is not a positive
        multiple of the interval.

        The file is written under a temporary name and then renamed, so an
        interrupted save never leaves a truncated ``checkpoint-step*.pt``.
        The Hugging Face Hub upload is best-effort: failures are logged, not
        raised.
        """
        cfg = self.cfg
        if not cfg.output_dir or cfg.ckpt_interval <= 0:
            return
        if self.global_step <= 0 or self.global_step % cfg.ckpt_interval != 0:
            return
        out_dir = Path(cfg.output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        ckpt_path = out_dir / f"checkpoint-step{self.global_step}.pt"
        tmp_path = ckpt_path.with_name(ckpt_path.name + ".tmp")
        torch.save(
            {
                "step": self.global_step,
                "stage": self._stage,
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            },
            tmp_path,
        )
        tmp_path.replace(ckpt_path)
        log.info("[Trainer] checkpoint %s", ckpt_path)
        if not cfg.hub_repo_id:
            return
        try:
            from huggingface_hub import HfApi, create_repo

            create_repo(cfg.hub_repo_id, repo_type=cfg.hub_repo_type, exist_ok=True)
            api = HfApi()
            rel = f"checkpoints/{ckpt_path.name}"
            api.upload_file(
                path_or_fileobj=str(ckpt_path),
                path_in_repo=rel,
                repo_id=cfg.hub_repo_id,
                repo_type=cfg.hub_repo_type,
                commit_message=f"checkpoint step {self.global_step}",
            )
            log.info("[Trainer] uploaded %s -> %s", rel, cfg.hub_repo_id)
        except Exception as e:
            log.warning("[Trainer] Hub upload failed: %s", e)

    def maybe_switch_stage(self) -> None:
        """Switch once to Stage 2 when ``global_step >= cfg.stage1_steps``.

        The switch does the following, in order:

        1. switches the MoE to sparse mode;
        2. sets the final router temperature;
        3. unfreezes DINOv3 if configured;
        4. rebuilds the optimizer, scheduler and multi-task optimizer.
        """
        cfg = self.cfg
        if self._stage == 1 and self.global_step >= cfg.stage1_steps:
            log.info("[Trainer] -> Stage 2 (sparse MoE + DINOv3 finetune + PCGrad)")
            self._stage = 2
            self.model.backbone.set_moe_mode("sparse")
            self.model.backbone.set_router_temperature(cfg.router_temp_final)
            if cfg.unfreeze_dinov3_at_stage2:
                self.model.dinov3.unfreeze()
            self._build_optimizer()

    def _step_info(self, total, weights, losses: dict, grad_norms: dict) -> dict:
        """Build the info dict returned by ``train_step``."""
        info = {
            "step": self.global_step,
            "stage": self._stage,
            "total_loss": float(total),
            "weights": [float(w) for w in weights],
            "grad_norms": grad_norms,
        }
        for k, v in losses.items():
            info[k] = float(v.detach())
        return info

    def train_step(self, batch: dict, rng: np.random.Generator) -> dict:
        """Run one micro-step (forward + backward) and, once
        ``cfg.grad_accum_steps`` micro-steps have accumulated, one optimizer step.

        Returns:
            A dict with the following keys:

            * ``step`` and ``stage``;
            * ``total_loss``;
            * ``weights``: per-task weights from the multi-task optimizer,
              all ones on the fp16/GradScaler path;
            * ``grad_norms``: empty on accumulation-only micro-steps;
            * every loss from ``compute_all_losses``, as a float.
        """
        cfg = self.cfg
        self.maybe_switch_stage()

        # Move top-level tensors, and tensors inside each target dict, to the device.
        batch = {k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        if "targets" in batch and isinstance(batch["targets"], list):
            batch["targets"] = [
                {k: (v.to(self.device) if isinstance(v, torch.Tensor) else v) for k, v in t.items()}
                for t in batch["targets"]
            ]

        # Stage 1 after stage1_perturb_start: with probability 0.5, perturb the
        # kinematic inputs and supervise the calibration residuals with the
        # negated perturbation.
        perturb_residual = None
        ego_input = batch["ego_6d"]
        intr_input = batch["intr_vec"]
        extr_input = batch["extr_6d"]
        if (
            self._stage == 1
            and self.global_step >= cfg.stage1_perturb_start
            and rng.uniform() < 0.5
        ):
            from ..data.transforms import perturb_kinematics

            ego_input, intr_input, extr_input, delta = perturb_kinematics(
                ego_input.cpu().clone(),
                intr_input.cpu().clone()[0],
                extr_input.cpu().clone()[0],
                translation_std_m=0.1,
                rotation_std_deg=0.5,
                intrinsic_std=0.005,
                extrinsic_std=0.005,
                rng=rng,
            )
            ego_input = ego_input.to(self.device)
            # Only the ego input is fed perturbed. The perturbed intrinsics and
            # extrinsics (computed from sample 0 only) are discarded, and the
            # originals go to the model.
            # NOTE(review): delta presumably also contains the intr/extr
            # perturbation, yet it is used as the target for the intr/extr
            # residuals in compute_all_losses. Confirm that part should not be
            # zeroed, or that the perturbed values should not be fed.
            intr_input = batch["intr_vec"].clone()
            extr_input = batch["extr_6d"].clone()
            perturb_residual = -delta.to(self.device).unsqueeze(0).expand(ego_input.shape[0], -1)

        autocast_ctx = (
            torch.autocast(device_type="cuda", dtype=self.amp_dtype)
            if self.amp_enabled
            else _NullContext()
        )
        with autocast_ctx:
            out = self.model(
                images=batch["images"],
                ego_6d_raw=ego_input,
                intr_raw=intr_input,
                extr_6d_raw=extr_input,
            )
            losses = compute_all_losses(
                out, batch, self.matcher, self.num_classes, cfg,
                perturbation_residual=perturb_residual,
            )

        # Main tasks stay separate for GradNorm/PCGrad; auxiliary terms are summed.
        main = torch.stack([losses[k].float() for k in MAIN_TASK_KEYS])
        aux = sum(losses[k].float() for k in AUX_TASK_KEYS)
        if cfg.grad_accum_steps > 1:
            main = main / cfg.grad_accum_steps
            aux = aux / cfg.grad_accum_steps

        if self._micro_step == 0:
            self.optimizer.zero_grad(set_to_none=True)

        all_params = [p for p in self.model.parameters() if p.requires_grad]
        if self.scaler is not None:
            # fp16 path: plain scaled backward of the summed loss. GradNorm and
            # PCGrad are bypassed and all task weights are reported as 1.
            total = main.sum() + aux
            self.scaler.scale(total).backward()
            weights = torch.ones_like(main)
        else:
            total, weights = self.mto.backward(main, aux, all_params)

        self._micro_step += 1
        if self._micro_step < cfg.grad_accum_steps:
            # Accumulation-only micro-step: no optimizer update yet.
            return self._step_info(total, weights, losses, grad_norms={})

        # Unscale first so grad monitoring and clipping see true gradients.
        if self.scaler is not None:
            self.scaler.unscale_(self.optimizer)
        grad_summary = grad_norm_per_module(self.model, cfg.grad_monitor_threshold)
        torch.nn.utils.clip_grad_norm_(all_params, max_norm=cfg.grad_clip)

        if self.scaler is not None:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.scheduler.step()
        self._micro_step = 0

        # Stage 1: linear router-temperature anneal over stage1_steps.
        if self._stage == 1:
            ratio = min(1.0, self.global_step / max(1, cfg.stage1_steps))
            temperature = cfg.router_temp_init + ratio * (cfg.router_temp_final - cfg.router_temp_init)
            self.model.backbone.set_router_temperature(temperature)

        self.global_step += 1
        self._maybe_save_checkpoint()

        return self._step_info(total, weights, losses, grad_norms=grad_summary)

    def fit(self, loader: DataLoader, max_steps: int | None = None) -> None:
        """Simple training loop: run ``max_steps`` optimizer steps (default ``cfg.total_steps``).

        The loop cycles through ``loader`` as often as needed. With
        ``cfg.grad_accum_steps > 1``, each optimizer step consumes that many
        batches. Progress is logged every ``cfg.log_interval`` completed
        optimizer steps; ``log_interval <= 0`` disables logging.
        """
        rng = np.random.default_rng(0)
        num_steps = max_steps or self.cfg.total_steps
        # Count optimizer steps (not train_step calls) so gradient accumulation
        # does not cut training short.
        target_step = self.global_step + num_steps
        batches = iter(loader)
        while self.global_step < target_step:
            try:
                batch = next(batches)
            except StopIteration:
                batches = iter(loader)
                batch = next(batches)
            step_before = self.global_step
            info = self.train_step(batch, rng)
            completed = self.global_step != step_before
            if (
                completed
                and self.cfg.log_interval > 0
                and info["step"] % self.cfg.log_interval == 0
            ):
                log.info(
                    "step=%d stage=%d total=%.4f cls=%.4f box=%.4f isdyn=%.4f traj_obj=%.4f traj_ego=%.4f ctrl=%.4f moe=%.4f calib=%.4f",
                    info["step"], info["stage"], info["total_loss"],
                    info["L_cls"], info["L_box"], info["L_isdyn"],
                    info["L_traj_obj"], info["L_traj_ego"], info["L_ctrl"],
                    info["L_moe"], info["L_calib"],
                )
|
|