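"""Local training entry point.

Two modes are supported:
- ``--tiny``: smoke-test the training loop on a handful of samples (the first
  8 index entries at most) for a small number of steps (50 unless
  ``--max_steps`` is given);
- otherwise: train according to configs/default.yaml (requires the dataset to
  be downloaded and extracted).
"""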
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import logging |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import yaml |
| from torch.utils.data import DataLoader |
|
|
| from ..data.cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index, collate_samples |
| from ..model import E2EAVModel |
| from .trainer import Trainer, TrainerConfig |
|
|
|
|
def _load_config(path: str) -> dict:
    """Load a YAML configuration file into a dictionary.

    Args:
        path: Path of the YAML file; it is read as UTF-8.

    Returns:
        The parsed top-level mapping. An empty document, which
        ``yaml.safe_load`` parses to ``None``, yields an empty dict so that
        callers can index or merge the result without a ``None`` check.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
        TypeError: If the top-level YAML node is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    if loaded is None:
        # Empty file (or only comments): behave like an empty config.
        return {}
    if not isinstance(loaded, dict):
        raise TypeError(
            f"config file {path!r} must contain a YAML mapping at the top level, "
            f"got {type(loaded).__name__}"
        )
    return loaded
|
|
|
|
def _deep_merge(base: dict | None, override: dict | None) -> dict:
    """Recursively merge ``override`` on top of ``base`` and return a new dict.

    When both sides hold a dict under the same key the two dicts are merged
    recursively; in every other case the value from ``override`` replaces the
    one from ``base``. Neither input is modified, but values that are not
    merged are shared with the inputs (only the dicts along merged paths are
    copied).

    Args:
        base: Mapping providing the default values. ``None`` is treated as an
            empty mapping.
        override: Mapping whose entries take precedence. ``None`` (for example
            the result of parsing an empty YAML file) is treated as an empty
            mapping.

    Returns:
        A new dictionary containing the merged result.
    """
    merged = dict(base) if base else {}
    if not override:
        return merged
    for key, value in override.items():
        current = merged.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            merged[key] = _deep_merge(current, value)
        else:
            merged[key] = value
    return merged
|
|
|
|
def _make_model_from_cfg(cfg: dict, dinov3_path: str) -> E2EAVModel:
    """Instantiate ``E2EAVModel`` from the nested configuration mapping.

    Args:
        cfg: Configuration dictionary; it must provide the ``backbone``,
            ``moe``, ``input``, ``tokens``, ``dinov3``, ``det_traj_head``,
            ``control_head`` and ``calibration`` sections with the keys listed
            in the table below.
        dinov3_path: Value forwarded unchanged as the ``dinov3_path``
            constructor argument.

    Returns:
        The constructed model.

    Raises:
        KeyError: If a required section or key is missing from ``cfg``.
    """
    # (constructor keyword, config section, key inside that section), listed in
    # the order in which the values are read from ``cfg``.
    arg_sources = (
        ("backbone_dim", "backbone", "hidden_size"),
        ("num_heads", "backbone", "num_heads"),
        ("num_dense_layers", "backbone", "num_dense_layers"),
        ("num_moe_layers", "backbone", "num_moe_layers"),
        ("num_routed_experts", "moe", "num_routed_experts"),
        ("num_shared_experts", "moe", "num_shared_experts"),
        ("topk_experts", "moe", "topk"),
        ("ffn_mult", "backbone", "ffn_mult"),
        ("num_history_frames", "input", "num_history_frames"),
        ("num_detection_tokens", "tokens", "num_detection"),
        ("num_control_tokens", "tokens", "num_control"),
        ("num_ego_tokens", "tokens", "num_ego"),
        ("num_extra_tokens", "tokens", "num_extra"),
        ("image_h", "input", "image_height"),
        ("image_w", "input", "image_width"),
        ("patch_size", "dinov3", "patch_size"),
        ("num_classes", "det_traj_head", "num_classes"),
        ("traj_horizon", "det_traj_head", "traj_horizon"),
        ("det_head_hidden", "det_traj_head", "hidden_size"),
        ("ctrl_head_hidden", "control_head", "hidden_size"),
        ("calib_dim", "calibration", "hidden_size"),
        ("calib_num_query", "calibration", "num_query_tokens"),
        ("calib_num_blocks", "calibration", "num_blocks"),
        ("calib_num_self_per_block", "calibration", "num_self_attn_per_block"),
        ("calib_num_heads", "calibration", "num_heads"),
        ("calib_residual_range", "calibration", "residual_range"),
        ("calib_intr_dim", "calibration", "intr_vec_dim"),
        ("freeze_dinov3", "dinov3", "freeze_in_stage1"),
        ("attn_implementation", "dinov3", "attn_implementation"),
    )
    model_kwargs = {
        keyword: cfg[section][key] for keyword, section, key in arg_sources
    }
    return E2EAVModel(dinov3_path=dinov3_path, **model_kwargs)
|
|
|
|
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the local training entry point.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated argparse namespace.
    """
    # Default paths are anchored three directories above this file's directory
    # (presumably the repository root -- confirm). If the file lives closer to
    # the filesystem root than that, fall back to the topmost ancestor instead
    # of raising IndexError before the user's explicit paths are even read.
    ancestors = Path(__file__).resolve().parents
    anchor = ancestors[min(3, len(ancestors) - 1)]

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=str(anchor / "configs" / "default.yaml"))
    parser.add_argument(
        "--config_overrides",
        default=None,
        help="可选第二份 YAML,与 --config 深度合并(如 configs/jobs_overrides.yaml)",
    )
    parser.add_argument("--data_root", default=None, help="覆盖 config 中的 data.root")
    parser.add_argument("--dinov3_path", default=str(anchor / "dinov3-vitb16-pretrain-lvd1689m"))
    parser.add_argument("--tiny", action="store_true", help="用极少样本验证训练循环")
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="默认:tiny=50;全量训练=None 使用配置文件 total_steps",
    )
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--output_dir", default=None, help="检查点目录;也可用环境变量 WJAD_OUTPUT_DIR")
    parser.add_argument("--hub_repo", default=None, help="Hub model repo(推送 checkpoint);或 WJAD_HUB_REPO")
    return parser.parse_args(argv)


def _build_trainer_config(cfg: dict, output_dir: str | None, hub_repo: str | None) -> TrainerConfig:
    """Translate the nested configuration mapping into a ``TrainerConfig``.

    Args:
        cfg: Configuration dictionary; reads the ``train``, ``moe``, ``loss``,
            ``multitask``, ``dinov3`` sections and the top-level
            ``mixed_precision`` key.
        output_dir: Checkpoint directory forwarded as ``output_dir``.
        hub_repo: Repository id forwarded as ``hub_repo_id``.

    Raises:
        KeyError: If a required section or key is missing from ``cfg``.
    """
    train = cfg["train"]
    moe = cfg["moe"]
    lr_mults = train["param_groups"]
    return TrainerConfig(
        total_steps=train["total_steps"],
        warmup_steps=train["warmup_steps"],
        base_lr=train["base_lr"],
        min_lr=train["min_lr"],
        weight_decay=train["weight_decay"],
        grad_clip=train["grad_clip"],
        log_interval=train["log_interval"],
        ckpt_interval=train["ckpt_interval"],
        stage1_steps=train["stage1_steps"],
        stage1_perturb_start=train["stage1_perturb_start"],
        grad_monitor_threshold=train["grad_monitor_threshold"],
        moe_load_balance_weight=moe["load_balance_weight"],
        moe_boundary_weight=moe["boundary_weight"],
        router_temp_init=moe["router_temperature_init"],
        router_temp_final=moe["router_temperature_final"],
        loss_giou_weight=cfg["loss"]["giou_weight"],
        loss_calib_weight=cfg["loss"]["calib_weight"],
        enable_gradnorm=cfg["multitask"]["enable_gradnorm"],
        enable_pcgrad=cfg["multitask"]["enable_pcgrad"],
        mixed_precision=cfg["mixed_precision"],
        grad_accum_steps=train["grad_accum_steps"],
        dinov3_lr_mult_stage2=cfg["dinov3"]["finetune_lr_ratio"],
        backbone_lr_mult=lr_mults["backbone_lr_mult"],
        calibration_lr_mult=lr_mults["calibration_lr_mult"],
        head_lr_mult=lr_mults["head_lr_mult"],
        gate_lr_mult=lr_mults["gate_lr_mult"],
        output_dir=output_dir,
        hub_repo_id=hub_repo,
    )


def main(argv: list[str] | None = None) -> None:
    """Run local training: parse options, build data/model/trainer, then fit.

    Precedence of the tunables read here:
    - data root: ``--data_root``, else ``data.root`` from the config;
    - batch size: ``WJAD_BATCH_SIZE`` environment variable, else
      ``train.batch_size``;
    - checkpoint directory: ``--output_dir``, else ``WJAD_OUTPUT_DIR``;
    - hub repository: ``--hub_repo``, else ``WJAD_HUB_REPO``, else
      ``deploy.hf_weights_repo`` from the config;
    - step limit passed to ``Trainer.fit``: ``--max_steps``, else 50 in
      ``--tiny`` mode, else ``None``.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, so ``main()`` keeps working as
            the script entry point.
    """
    args = _parse_args(argv)

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    cfg = _load_config(args.config)
    if args.config_overrides:
        cfg = _deep_merge(cfg, _load_config(args.config_overrides))
    data_root = args.data_root or cfg["data"]["root"]

    samples = build_clip_index(
        data_root,
        weathers=cfg["data"]["weather"],
        camera_name=cfg["input"]["camera_name"],
    )
    if args.tiny:
        # Smoke-test mode: keep only the first few index entries.
        samples = samples[:8]
    print(f"[runner_local] 找到 {len(samples)} 个样本于 {data_root}")
    if not samples:
        print("[runner_local] 没有数据。请先运行 scripts/download_data.py 准备数据集。")
        # NOTE(review): exits with status 0 although no training happened;
        # kept as-is in case callers rely on it -- confirm whether a non-zero
        # status would be more appropriate for job schedulers.
        sys.exit(0)

    dataset = CosmosDriveDreamsDataset(
        data_root=data_root,
        samples=samples,
        camera_name=cfg["input"]["camera_name"],
        image_h=cfg["input"]["image_height"],
        image_w=cfg["input"]["image_width"],
        num_history=cfg["input"]["num_history_frames"],
        future_horizon=cfg["input"]["num_future_frames"],
        max_distance_m=cfg["detection"]["max_distance_m"],
        occlusion_tol=cfg["detection"]["occlusion_depth_tolerance"],
    )
    bs = cfg["train"]["batch_size"]
    if os.environ.get("WJAD_BATCH_SIZE"):
        bs = int(os.environ["WJAD_BATCH_SIZE"])

    # ``deploy`` may be absent or null in the config; treat both as empty.
    deploy = cfg.get("deploy") or {}
    output_dir = args.output_dir or os.environ.get("WJAD_OUTPUT_DIR")
    hub_repo = args.hub_repo or os.environ.get("WJAD_HUB_REPO") or deploy.get("hf_weights_repo")

    loader = DataLoader(
        dataset,
        batch_size=bs,
        shuffle=True,
        # Tiny mode loads data in the main process (no worker subprocesses).
        num_workers=cfg["data"]["num_workers"] if not args.tiny else 0,
        collate_fn=collate_samples,
        pin_memory=cfg["data"]["pin_memory"],
    )

    model = _make_model_from_cfg(cfg, args.dinov3_path)
    if cfg.get("gradient_checkpointing", False):
        model.backbone.set_gradient_checkpointing(True)

    tcfg = _build_trainer_config(cfg, output_dir, hub_repo)
    trainer = Trainer(model, tcfg, num_classes=cfg["det_traj_head"]["num_classes"], device=args.device)
    if args.max_steps is not None:
        max_steps = args.max_steps
    else:
        max_steps = 50 if args.tiny else None
    trainer.fit(loader, max_steps=max_steps)
|
|
|
|
# Script entry point. This module uses relative imports (``..data``,
# ``..model``, ``.trainer``), so it has to be executed as part of its package
# (for example with ``python -m``) rather than directly by file path.
if __name__ == "__main__":
    main()
|
|