"""本地训练入口。 支持两种模式: - ``--tiny`` : 用 1 个 clip / 极少步数验证训练循环可跑通; - 否则默认按 configs/default.yaml 训练(需要数据集解压完成)。 """ from __future__ import annotations import argparse import logging import os import sys from pathlib import Path import torch import yaml from torch.utils.data import DataLoader from ..data.cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index, collate_samples from ..model import E2EAVModel from .trainer import Trainer, TrainerConfig def _load_config(path: str) -> dict: with open(path, "r", encoding="utf-8") as f: return yaml.safe_load(f) def _deep_merge(base: dict, override: dict) -> dict: out = dict(base) for k, v in override.items(): if k in out and isinstance(out[k], dict) and isinstance(v, dict): out[k] = _deep_merge(out[k], v) else: out[k] = v return out def _make_model_from_cfg(cfg: dict, dinov3_path: str) -> E2EAVModel: """根据配置创建模型。""" return E2EAVModel( dinov3_path=dinov3_path, backbone_dim=cfg["backbone"]["hidden_size"], num_heads=cfg["backbone"]["num_heads"], num_dense_layers=cfg["backbone"]["num_dense_layers"], num_moe_layers=cfg["backbone"]["num_moe_layers"], num_routed_experts=cfg["moe"]["num_routed_experts"], num_shared_experts=cfg["moe"]["num_shared_experts"], topk_experts=cfg["moe"]["topk"], ffn_mult=cfg["backbone"]["ffn_mult"], num_history_frames=cfg["input"]["num_history_frames"], num_detection_tokens=cfg["tokens"]["num_detection"], num_control_tokens=cfg["tokens"]["num_control"], num_ego_tokens=cfg["tokens"]["num_ego"], num_extra_tokens=cfg["tokens"]["num_extra"], image_h=cfg["input"]["image_height"], image_w=cfg["input"]["image_width"], patch_size=cfg["dinov3"]["patch_size"], num_classes=cfg["det_traj_head"]["num_classes"], traj_horizon=cfg["det_traj_head"]["traj_horizon"], det_head_hidden=cfg["det_traj_head"]["hidden_size"], ctrl_head_hidden=cfg["control_head"]["hidden_size"], calib_dim=cfg["calibration"]["hidden_size"], calib_num_query=cfg["calibration"]["num_query_tokens"], calib_num_blocks=cfg["calibration"]["num_blocks"], calib_num_self_per_block=cfg["calibration"]["num_self_attn_per_block"], calib_num_heads=cfg["calibration"]["num_heads"], calib_residual_range=cfg["calibration"]["residual_range"], calib_intr_dim=cfg["calibration"]["intr_vec_dim"], freeze_dinov3=cfg["dinov3"]["freeze_in_stage1"], attn_implementation=cfg["dinov3"]["attn_implementation"], ) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--config", default=str(Path(__file__).resolve().parents[3] / "configs" / "default.yaml")) parser.add_argument( "--config_overrides", default=None, help="可选第二份 YAML,与 --config 深度合并(如 configs/jobs_overrides.yaml)", ) parser.add_argument("--data_root", default=None, help="覆盖 config 中的 data.root") parser.add_argument("--dinov3_path", default=str(Path(__file__).resolve().parents[3] / "dinov3-vitb16-pretrain-lvd1689m")) parser.add_argument("--tiny", action="store_true", help="用极少样本验证训练循环") parser.add_argument( "--max_steps", type=int, default=None, help="默认:tiny=50;全量训练=None 使用配置文件 total_steps", ) parser.add_argument("--device", default="cpu") parser.add_argument("--output_dir", default=None, help="检查点目录;也可用环境变量 WJAD_OUTPUT_DIR") parser.add_argument("--hub_repo", default=None, help="Hub model repo(推送 checkpoint);或 WJAD_HUB_REPO") args = parser.parse_args() logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") cfg = _load_config(args.config) if args.config_overrides: cfg = _deep_merge(cfg, _load_config(args.config_overrides)) data_root = args.data_root or cfg["data"]["root"] samples = build_clip_index( data_root, weathers=cfg["data"]["weather"], camera_name=cfg["input"]["camera_name"], ) if args.tiny: samples = samples[:8] or samples print(f"[runner_local] 找到 {len(samples)} 个样本于 {data_root}") if not samples: print("[runner_local] 没有数据。请先运行 scripts/download_data.py 准备数据集。") sys.exit(0) dataset = CosmosDriveDreamsDataset( data_root=data_root, samples=samples, camera_name=cfg["input"]["camera_name"], image_h=cfg["input"]["image_height"], image_w=cfg["input"]["image_width"], num_history=cfg["input"]["num_history_frames"], future_horizon=cfg["input"]["num_future_frames"], max_distance_m=cfg["detection"]["max_distance_m"], occlusion_tol=cfg["detection"]["occlusion_depth_tolerance"], ) bs = cfg["train"]["batch_size"] if os.environ.get("WJAD_BATCH_SIZE"): bs = int(os.environ["WJAD_BATCH_SIZE"]) deploy = cfg.get("deploy") or {} output_dir = args.output_dir or os.environ.get("WJAD_OUTPUT_DIR") hub_repo = args.hub_repo or os.environ.get("WJAD_HUB_REPO") or deploy.get("hf_weights_repo") loader = DataLoader( dataset, batch_size=bs, shuffle=True, num_workers=cfg["data"]["num_workers"] if not args.tiny else 0, collate_fn=collate_samples, pin_memory=cfg["data"]["pin_memory"], ) model = _make_model_from_cfg(cfg, args.dinov3_path) if cfg.get("gradient_checkpointing", False): model.backbone.set_gradient_checkpointing(True) tcfg = TrainerConfig( total_steps=cfg["train"]["total_steps"], warmup_steps=cfg["train"]["warmup_steps"], base_lr=cfg["train"]["base_lr"], min_lr=cfg["train"]["min_lr"], weight_decay=cfg["train"]["weight_decay"], grad_clip=cfg["train"]["grad_clip"], log_interval=cfg["train"]["log_interval"], ckpt_interval=cfg["train"]["ckpt_interval"], stage1_steps=cfg["train"]["stage1_steps"], stage1_perturb_start=cfg["train"]["stage1_perturb_start"], grad_monitor_threshold=cfg["train"]["grad_monitor_threshold"], moe_load_balance_weight=cfg["moe"]["load_balance_weight"], moe_boundary_weight=cfg["moe"]["boundary_weight"], router_temp_init=cfg["moe"]["router_temperature_init"], router_temp_final=cfg["moe"]["router_temperature_final"], loss_giou_weight=cfg["loss"]["giou_weight"], loss_calib_weight=cfg["loss"]["calib_weight"], enable_gradnorm=cfg["multitask"]["enable_gradnorm"], enable_pcgrad=cfg["multitask"]["enable_pcgrad"], mixed_precision=cfg["mixed_precision"], grad_accum_steps=cfg["train"]["grad_accum_steps"], dinov3_lr_mult_stage2=cfg["dinov3"]["finetune_lr_ratio"], backbone_lr_mult=cfg["train"]["param_groups"]["backbone_lr_mult"], calibration_lr_mult=cfg["train"]["param_groups"]["calibration_lr_mult"], head_lr_mult=cfg["train"]["param_groups"]["head_lr_mult"], gate_lr_mult=cfg["train"]["param_groups"]["gate_lr_mult"], output_dir=output_dir, hub_repo_id=hub_repo, ) trainer = Trainer(model, tcfg, num_classes=cfg["det_traj_head"]["num_classes"], device=args.device) if args.max_steps is not None: max_steps = args.max_steps else: max_steps = 50 if args.tiny else None trainer.fit(loader, max_steps=max_steps) if __name__ == "__main__": main()