WJAD / src /wjad /train /runner_local.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
"""本地训练入口。
支持两种模式:
- ``--tiny`` : 用 1 个 clip / 极少步数验证训练循环可跑通;
- 否则默认按 configs/default.yaml 训练(需要数据集解压完成)。
"""
from __future__ import annotations
import argparse
import logging
import os
import sys
from pathlib import Path
import torch
import yaml
from torch.utils.data import DataLoader
from ..data.cosmos_dataset import CosmosDriveDreamsDataset, build_clip_index, collate_samples
from ..model import E2EAVModel
from .trainer import Trainer, TrainerConfig
def _load_config(path: str) -> dict:
    """Load a YAML config file and return its top-level mapping.

    An empty file yields ``{}`` rather than ``None``. This makes an empty
    overrides file a no-op when passed to ``_deep_merge``, instead of crashing
    on ``None.items()``.

    Args:
        path: Path to a UTF-8 encoded YAML file.

    Returns:
        The parsed top-level mapping (``{}`` for an empty document).

    Raises:
        TypeError: If the document's top level is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if data is None:
        # safe_load returns None for an empty document.
        return {}
    if not isinstance(data, dict):
        raise TypeError(
            f"config {path!r} must contain a YAML mapping at top level, "
            f"got {type(data).__name__}"
        )
    return data
def _deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict with ``override`` recursively layered on top of ``base``.

    When both sides hold a dict under the same key, the two dicts are merged
    recursively. Otherwise the value from ``override`` wins. Neither input is
    mutated, although untouched nested values are shared with the inputs,
    not copied.
    """
    merged = {**base}
    for key, value in override.items():
        existing = merged.get(key)
        if isinstance(existing, dict) and isinstance(value, dict):
            value = _deep_merge(existing, value)
        merged[key] = value
    return merged
def _make_model_from_cfg(cfg: dict, dinov3_path: str) -> E2EAVModel:
    """Construct an ``E2EAVModel`` from a parsed config dict.

    Every constructor argument except ``dinov3_path`` is read from
    ``cfg[section][key]``. A missing section or key raises ``KeyError``.

    Args:
        cfg: Parsed config mapping (e.g. the output of ``_load_config``).
        dinov3_path: Location of the DINOv3 weights, forwarded unchanged.
    """
    # (constructor kwarg, config section, key within that section), in the
    # order the values are looked up.
    field_map = (
        ("backbone_dim", "backbone", "hidden_size"),
        ("num_heads", "backbone", "num_heads"),
        ("num_dense_layers", "backbone", "num_dense_layers"),
        ("num_moe_layers", "backbone", "num_moe_layers"),
        ("num_routed_experts", "moe", "num_routed_experts"),
        ("num_shared_experts", "moe", "num_shared_experts"),
        ("topk_experts", "moe", "topk"),
        ("ffn_mult", "backbone", "ffn_mult"),
        ("num_history_frames", "input", "num_history_frames"),
        ("num_detection_tokens", "tokens", "num_detection"),
        ("num_control_tokens", "tokens", "num_control"),
        ("num_ego_tokens", "tokens", "num_ego"),
        ("num_extra_tokens", "tokens", "num_extra"),
        ("image_h", "input", "image_height"),
        ("image_w", "input", "image_width"),
        ("patch_size", "dinov3", "patch_size"),
        ("num_classes", "det_traj_head", "num_classes"),
        ("traj_horizon", "det_traj_head", "traj_horizon"),
        ("det_head_hidden", "det_traj_head", "hidden_size"),
        ("ctrl_head_hidden", "control_head", "hidden_size"),
        ("calib_dim", "calibration", "hidden_size"),
        ("calib_num_query", "calibration", "num_query_tokens"),
        ("calib_num_blocks", "calibration", "num_blocks"),
        ("calib_num_self_per_block", "calibration", "num_self_attn_per_block"),
        ("calib_num_heads", "calibration", "num_heads"),
        ("calib_residual_range", "calibration", "residual_range"),
        ("calib_intr_dim", "calibration", "intr_vec_dim"),
        ("freeze_dinov3", "dinov3", "freeze_in_stage1"),
        ("attn_implementation", "dinov3", "attn_implementation"),
    )
    model_kwargs = {arg: cfg[section][key] for arg, section, key in field_map}
    return E2EAVModel(dinov3_path=dinov3_path, **model_kwargs)
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line for the local training entry point.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Defaults for ``--config`` and ``--dinov3_path`` are resolved against
    ``Path(__file__).resolve().parents[3]`` (presumably the repository root).
    """
    repo_root = Path(__file__).resolve().parents[3]
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=str(repo_root / "configs" / "default.yaml"))
    parser.add_argument(
        "--config_overrides",
        default=None,
        help="可选第二份 YAML,与 --config 深度合并(如 configs/jobs_overrides.yaml)",
    )
    parser.add_argument("--data_root", default=None, help="覆盖 config 中的 data.root")
    parser.add_argument("--dinov3_path", default=str(repo_root / "dinov3-vitb16-pretrain-lvd1689m"))
    parser.add_argument("--tiny", action="store_true", help="用极少样本验证训练循环")
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="默认:tiny=50;全量训练=None 使用配置文件 total_steps",
    )
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--output_dir", default=None, help="检查点目录;也可用环境变量 WJAD_OUTPUT_DIR")
    parser.add_argument("--hub_repo", default=None, help="Hub model repo(推送 checkpoint);或 WJAD_HUB_REPO")
    return parser.parse_args(argv)


def _make_trainer_config(cfg: dict, output_dir: str | None, hub_repo: str | None) -> TrainerConfig:
    """Build a ``TrainerConfig`` from the parsed config plus resolved output/hub targets."""
    train = cfg["train"]
    return TrainerConfig(
        total_steps=train["total_steps"],
        warmup_steps=train["warmup_steps"],
        base_lr=train["base_lr"],
        min_lr=train["min_lr"],
        weight_decay=train["weight_decay"],
        grad_clip=train["grad_clip"],
        log_interval=train["log_interval"],
        ckpt_interval=train["ckpt_interval"],
        stage1_steps=train["stage1_steps"],
        stage1_perturb_start=train["stage1_perturb_start"],
        grad_monitor_threshold=train["grad_monitor_threshold"],
        moe_load_balance_weight=cfg["moe"]["load_balance_weight"],
        moe_boundary_weight=cfg["moe"]["boundary_weight"],
        router_temp_init=cfg["moe"]["router_temperature_init"],
        router_temp_final=cfg["moe"]["router_temperature_final"],
        loss_giou_weight=cfg["loss"]["giou_weight"],
        loss_calib_weight=cfg["loss"]["calib_weight"],
        enable_gradnorm=cfg["multitask"]["enable_gradnorm"],
        enable_pcgrad=cfg["multitask"]["enable_pcgrad"],
        mixed_precision=cfg["mixed_precision"],
        grad_accum_steps=train["grad_accum_steps"],
        dinov3_lr_mult_stage2=cfg["dinov3"]["finetune_lr_ratio"],
        backbone_lr_mult=train["param_groups"]["backbone_lr_mult"],
        calibration_lr_mult=train["param_groups"]["calibration_lr_mult"],
        head_lr_mult=train["param_groups"]["head_lr_mult"],
        gate_lr_mult=train["param_groups"]["gate_lr_mult"],
        output_dir=output_dir,
        hub_repo_id=hub_repo,
    )


def main(argv: list[str] | None = None) -> None:
    """Local training entry point: load config, build data and model, run ``Trainer.fit``.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``), so the
            entry point can also be driven programmatically.

    Environment:
        WJAD_BATCH_SIZE: overrides ``train.batch_size`` when set and non-empty.
        WJAD_OUTPUT_DIR: checkpoint directory if ``--output_dir`` is not given.
        WJAD_HUB_REPO: hub repo if ``--hub_repo`` is not given (falls back to
            ``deploy.hf_weights_repo``).
    """
    args = _parse_args(argv)
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    log = logging.getLogger(__name__)

    cfg = _load_config(args.config)
    if args.config_overrides:
        cfg = _deep_merge(cfg, _load_config(args.config_overrides))

    data_root = args.data_root or cfg["data"]["root"]
    samples = build_clip_index(
        data_root,
        weathers=cfg["data"]["weather"],
        camera_name=cfg["input"]["camera_name"],
    )
    if args.tiny:
        # A handful of clips is enough to smoke-test the training loop.
        samples = samples[:8]
    log.info("[runner_local] 找到 %d 个样本于 %s", len(samples), data_root)
    if not samples:
        log.warning("[runner_local] 没有数据。请先运行 scripts/download_data.py 准备数据集。")
        # Intentionally a clean exit (status 0): missing data is reported, not raised.
        sys.exit(0)

    dataset = CosmosDriveDreamsDataset(
        data_root=data_root,
        samples=samples,
        camera_name=cfg["input"]["camera_name"],
        image_h=cfg["input"]["image_height"],
        image_w=cfg["input"]["image_width"],
        num_history=cfg["input"]["num_history_frames"],
        future_horizon=cfg["input"]["num_future_frames"],
        max_distance_m=cfg["detection"]["max_distance_m"],
        occlusion_tol=cfg["detection"]["occlusion_depth_tolerance"],
    )

    env_bs = os.environ.get("WJAD_BATCH_SIZE")
    bs = int(env_bs) if env_bs else cfg["train"]["batch_size"]
    deploy = cfg.get("deploy") or {}
    output_dir = args.output_dir or os.environ.get("WJAD_OUTPUT_DIR")
    hub_repo = args.hub_repo or os.environ.get("WJAD_HUB_REPO") or deploy.get("hf_weights_repo")

    # Pinned host memory only speeds up host-to-GPU copies; skip it for CPU training.
    pin_memory = bool(cfg["data"]["pin_memory"]) and str(args.device).startswith("cuda")
    loader = DataLoader(
        dataset,
        batch_size=bs,
        shuffle=True,
        # Tiny runs stay single-process so failures surface with a simple traceback.
        num_workers=0 if args.tiny else cfg["data"]["num_workers"],
        collate_fn=collate_samples,
        pin_memory=pin_memory,
    )

    model = _make_model_from_cfg(cfg, args.dinov3_path)
    if cfg.get("gradient_checkpointing", False):
        model.backbone.set_gradient_checkpointing(True)

    tcfg = _make_trainer_config(cfg, output_dir, hub_repo)
    trainer = Trainer(model, tcfg, num_classes=cfg["det_traj_head"]["num_classes"], device=args.device)

    # An explicit --max_steps wins; otherwise tiny runs cap at 50 steps, and
    # None presumably lets the trainer fall back to train.total_steps.
    if args.max_steps is not None:
        max_steps = args.max_steps
    else:
        max_steps = 50 if args.tiny else None
    trainer.fit(loader, max_steps=max_steps)


if __name__ == "__main__":
    main()