| |
| import argparse |
| import os |
| import sys |
| import math |
| import time |
| import json |
| import logging |
| from datetime import timedelta |
| from pathlib import Path |
| from typing import Dict, Optional, List |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.utils.data import DataLoader |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent)) |
|
|
| from src.model.modeling_atlas import AtlasForCausalLM |
| from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens |
| from src.model.streampetr_adapter import extract_streampetr_topk_tokens |
| from src.dataset.atlas_dataset import ( |
| AtlasDataset, make_atlas_collate_fn, load_tokenizer, |
| ) |
| from src.dataset.scene_sampler import ( |
| SceneSequentialSampler, |
| SceneUnitTaskBalancedSampler, |
| ) |
| from src.prompting import PLANNING_TABLE3_MODES |
|
|
| logger = logging.getLogger("train_atlas") |
| TASK_LOSS_WEIGHT_KEYS = ("detection", "planning", "caption", "lane") |
|
|
|
|
def parse_args():
    """Build and parse the training CLI.

    Returns the argparse Namespace. Note: ``args.task_loss_weights`` is
    post-processed from its raw CSV string into a validated ``{task: float}``
    dict via ``_parse_task_loss_weights`` before returning.
    """
    p = argparse.ArgumentParser()
    # Model / token sizes.
    p.add_argument("--llm_model", default="lmsys/vicuna-7b-v1.5")
    p.add_argument("--visual_hidden_size", type=int, default=256)
    p.add_argument("--num_det_queries", type=int, default=256)
    p.add_argument("--num_map_queries", type=int, default=256)
    # Frozen encoder configs/checkpoints (required when --visual_token_mode online).
    p.add_argument("--streampetr_config", default=None)
    p.add_argument("--streampetr_ckpt", default=None)
    p.add_argument("--topomlp_config", default=None)
    p.add_argument("--topomlp_ckpt", default=None)
    # Data locations.
    p.add_argument("--data_json", required=True)
    p.add_argument("--data_root", default="/mnt/data/nuscenes")
    p.add_argument("--openlane_root", default=None,
                   help="Root for OpenLane relative image paths (lane task). "
                   "Falls back to --data_root if not set.")
    p.add_argument("--max_length", type=int, default=4096)
    p.add_argument("--output_dir", default="work_dirs/atlas")
    # Optimization hyper-parameters.
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--weight_decay", type=float, default=1e-4)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--epochs", type=int, default=8)
    p.add_argument("--warmup_ratio", type=float, default=0.03)
    p.add_argument("--gradient_accumulation_steps", type=int, default=2)
    p.add_argument("--max_grad_norm", type=float, default=1.0)
    # LoRA / quantization.
    p.add_argument("--use_lora", action="store_true")
    p.add_argument("--lora_r", type=int, default=64)
    p.add_argument("--lora_alpha", type=int, default=64)
    p.add_argument("--lora_dropout", type=float, default=0.1)
    p.add_argument("--load_in_4bit", action="store_true")
    # Checkpointing / logging cadence.
    p.add_argument("--save_steps", type=int, default=0)
    p.add_argument("--save_epochs", type=int, default=1)
    p.add_argument("--log_steps", type=int, default=10)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--resume", default=None)
    # Distributed / mixed precision; --local_rank default comes from the launcher env.
    p.add_argument("--local_rank", "--local-rank", type=int, default=int(os.environ.get("LOCAL_RANK", -1)))
    p.add_argument("--fp16", action="store_true")
    p.add_argument("--bf16", action="store_true")
    p.add_argument("--image_path_remap", default=None,
                   help="old=new path remap, e.g. /mnt/data=/local/data")
    # Offline (precomputed) visual-token sources.
    p.add_argument("--precomputed_det_tokens", default=None,
                   help="[offline only] Dir with precomputed det tokens (.pt files)")
    p.add_argument("--precomputed_map_tokens", default=None,
                   help="[offline only] Dir with precomputed TopoMLP map tokens (.pt files)")
    p.add_argument("--visual_token_mode", choices=("online", "offline"), default="online",
                   help="Visual token source: online=live frozen encoders (default), offline=read *_offline dirs")
    p.add_argument("--deepspeed", default=None,
                   help="Path to DeepSpeed config JSON (enables ZeRO)")
    p.add_argument("--keep_last_n_ckpts", type=int, default=0,
                   help="Keep only the N most recent epoch checkpoints (0=keep all)")
    p.add_argument(
        "--task_balance_mode",
        choices=("none", "scene_unit_111"),
        default="none",
        help=(
            "Online sampler mode. "
            "none=legacy scene-sequential sampling; "
            "scene_unit_111=scene-sequential unit-level 1:1:1 balance over "
            "detection/planning/caption timestamps (caption raw expands to 6 views)."
        ),
    )
    p.add_argument(
        "--task_loss_weights",
        default="",
        help=(
            "Static task weights applied to the scalar LM loss. "
            "Format: detection=0.35,planning=1.0,caption=0.05. "
            "Unspecified tasks default to 1.0."
        ),
    )
    p.add_argument(
        "--planning_table3_mode",
        choices=PLANNING_TABLE3_MODES,
        default="atlas_base",
        help=(
            "Planning prompt variant matching Atlas Table 3: "
            "atlas_base=no command/no explicit ego state; "
            "atlas_high_level=requires top-level route_command "
            "(this repo uses a UniAD-style future-GT-derived command); "
            "atlas_high_level_ego=requires top-level route_command plus "
            "velocity/acceleration bins."
        ),
    )
    args = p.parse_args()
    # Replace the raw CSV spec with a validated {task: weight} dict (raises on bad input).
    args.task_loss_weights = _parse_task_loss_weights(args.task_loss_weights)
    return args
|
|
|
|
def _parse_task_loss_weights(raw: str) -> Dict[str, float]:
    """Parse a 'task=weight,task=weight' spec into a full per-task weight map.

    Tasks absent from *raw* default to 1.0. Raises ValueError for malformed
    entries, unknown task names, non-float values, or non-positive weights.
    An empty/None *raw* yields all-1.0 weights.
    """
    parsed = dict.fromkeys(TASK_LOSS_WEIGHT_KEYS, 1.0)
    if raw is None:
        return parsed
    spec = str(raw).strip()
    if not spec:
        return parsed
    for chunk in spec.split(","):
        chunk = chunk.strip()
        if not chunk:
            # Tolerate stray commas / empty segments.
            continue
        if "=" not in chunk:
            raise ValueError(
                "task_loss_weights must use key=value pairs separated by commas. "
                f"Invalid entry: {chunk!r}"
            )
        name, _, raw_value = chunk.partition("=")
        name = name.strip().lower()
        if name not in parsed:
            raise ValueError(
                f"Unsupported task in task_loss_weights: {name!r}. "
                f"Expected one of {TASK_LOSS_WEIGHT_KEYS}."
            )
        try:
            weight = float(raw_value)
        except Exception as exc:
            raise ValueError(
                f"Invalid float in task_loss_weights for task {name!r}: {raw_value!r}"
            ) from exc
        if weight <= 0.0:
            raise ValueError(
                f"task_loss_weights values must be > 0. Got {weight} for task {name!r}."
            )
        parsed[name] = weight
    return parsed
|
|
|
|
| def _validate_visual_token_mode(args): |
| """Enforce mode-specific constraints. Fail hard, never silently degrade.""" |
| if args.task_balance_mode != "none" and args.visual_token_mode != "online": |
| raise RuntimeError( |
| "task_balance_mode requires --visual_token_mode online. " |
| f"Got: visual_token_mode={args.visual_token_mode!r}" |
| ) |
| if args.visual_token_mode == "online": |
| if args.precomputed_det_tokens or args.precomputed_map_tokens: |
| raise RuntimeError( |
| "visual_token_mode=online forbids --precomputed_det_tokens / " |
| "--precomputed_map_tokens. Use --visual_token_mode offline to " |
| "read offline token directories." |
| ) |
| missing = [] |
| if not args.streampetr_config or not args.streampetr_ckpt: |
| missing.append("--streampetr_config/--streampetr_ckpt") |
| if not args.topomlp_config or not args.topomlp_ckpt: |
| missing.append("--topomlp_config/--topomlp_ckpt") |
| if missing: |
| raise RuntimeError( |
| "visual_token_mode=online requires live encoder configs and " |
| "checkpoints. Missing: " + ", ".join(missing) |
| ) |
| for p in (args.streampetr_config, args.streampetr_ckpt, args.topomlp_config, args.topomlp_ckpt): |
| if not os.path.exists(p): |
| raise RuntimeError(f"Required online asset does not exist: {p}") |
| if args.batch_size != 1: |
| raise RuntimeError( |
| "visual_token_mode=online with temporal memory requires " |
| "--batch_size 1 (paper-aligned). Got: %d" % args.batch_size |
| ) |
| else: |
| if not args.precomputed_det_tokens and not args.precomputed_map_tokens: |
| raise RuntimeError( |
| "visual_token_mode=offline requires at least one " |
| "--precomputed_*_tokens directory." |
| ) |
| for p in (args.precomputed_det_tokens, args.precomputed_map_tokens): |
| if p and not os.path.isdir(p): |
| raise RuntimeError(f"Offline token directory does not exist: {p}") |
|
|
|
|
| def _format_task_counts(counts: Dict[str, int]) -> str: |
| ordered = ["detection", "planning", "caption"] |
| parts = [f"{task}={int(counts.get(task, 0))}" for task in ordered] |
| nonzero = [int(counts.get(task, 0)) for task in ordered if int(counts.get(task, 0)) > 0] |
| if nonzero: |
| base = float(min(nonzero)) |
| ratio = ":".join(f"{int(counts.get(task, 0)) / base:.2f}" for task in ordered) |
| else: |
| ratio = "0:0:0" |
| return f"{', '.join(parts)} | ratio={ratio}" |
|
|
|
|
def _format_task_loss_weights(weights: Dict[str, float]) -> str:
    """One-line 'task=weight' summary over all known tasks (missing tasks shown as 1.0)."""
    rendered = []
    for task in TASK_LOSS_WEIGHT_KEYS:
        rendered.append(f"{task}={float(weights.get(task, 1.0)):.4f}")
    return ", ".join(rendered)
|
|
|
|
| def _update_task_raw_loss_ema( |
| ema_state: Dict[str, float], |
| task: str, |
| raw_loss_value: float, |
| beta: float = 0.98, |
| ) -> None: |
| prev = ema_state.get(task) |
| if prev is None: |
| ema_state[task] = float(raw_loss_value) |
| else: |
| ema_state[task] = float(beta) * float(prev) + (1.0 - float(beta)) * float(raw_loss_value) |
|
|
|
|
| def _format_task_raw_loss_ema(ema_state: Dict[str, float]) -> str: |
| mapping = [ |
| ("detection", "det_raw_loss_ema"), |
| ("planning", "planning_raw_loss_ema"), |
| ("caption", "caption_raw_loss_ema"), |
| ] |
| parts = [] |
| for task, label in mapping: |
| value = ema_state.get(task) |
| if value is None: |
| parts.append(f"{label}=NA") |
| else: |
| parts.append(f"{label}={float(value):.4f}") |
| return " ".join(parts) |
|
|
|
|
| def _resolve_batch_task_weight(batch_task_types: List[str], task_weights: Dict[str, float]) -> float: |
| if not batch_task_types: |
| return 1.0 |
| tasks = [str(task).strip().lower() for task in batch_task_types if str(task).strip()] |
| if not tasks: |
| return 1.0 |
| unique_tasks = sorted(set(tasks)) |
| if len(unique_tasks) == 1: |
| return float(task_weights.get(unique_tasks[0], 1.0)) |
|
|
| unique_weights = { |
| task: float(task_weights.get(task, 1.0)) |
| for task in unique_tasks |
| } |
| rounded_values = {round(val, 12) for val in unique_weights.values()} |
| if len(rounded_values) != 1: |
| raise RuntimeError( |
| "Mixed-task batch encountered with non-uniform task_loss_weights. " |
| f"batch_tasks={unique_tasks} weights={unique_weights}. " |
| "Use batch_size=1 / online mode, or set equal weights for mixed-task batches." |
| ) |
| return float(next(iter(unique_weights.values()))) |
|
|
|
|
def set_seed(seed):
    """Seed Python, NumPy and torch RNGs (including all CUDA devices) for reproducibility."""
    import random
    import numpy as np
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
def setup_distributed(local_rank):
    """Initialise optional NCCL distributed training.

    Returns ``(device, distributed, rank, world_size)``. A *local_rank* of -1
    means single-process mode on the default CUDA device (or CPU).
    """
    if local_rank == -1:
        single_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return single_device, False, 0, 1
    dist.init_process_group(backend="nccl", timeout=timedelta(seconds=1800))
    torch.cuda.set_device(local_rank)
    return (
        torch.device("cuda", local_rank),
        True,
        dist.get_rank(),
        dist.get_world_size(),
    )
|
|
|
|
def is_main_process(distributed, rank):
    """True on rank 0, or always when not running distributed."""
    if not distributed:
        return True
    return rank == 0
|
|
|
|
def load_frozen_encoder(config_path, ckpt_path, model_type, device):
    """Build a frozen (eval-mode, grad-disabled) mmdet3d encoder from config + checkpoint.

    Parameters
    ----------
    config_path, ckpt_path:
        mmcv config file and checkpoint path. If either is None the encoder
        is considered disabled and None is returned.
    model_type:
        "streampetr" or "topomlp"; selects which external plugin repo is put
        on sys.path and imported so its modules register with mmdet3d.
    device:
        Device the frozen model is moved to.

    Returns the model with ``requires_grad`` disabled on all parameters, or
    None when the encoder is disabled. Raises RuntimeError when mmcv/mmdet3d
    or the plugin code is missing.
    """
    if config_path is None or ckpt_path is None:
        return None
    try:
        from mmcv import Config
        from mmdet3d.models import build_model
        from mmcv.runner import load_checkpoint
    except ImportError:
        raise RuntimeError(
            f"mmcv/mmdet3d not installed but --{model_type}_config and "
            f"--{model_type}_ckpt were explicitly provided. "
            f"Install mmcv/mmdet3d or remove these arguments to train without {model_type}."
        )

    if model_type == "streampetr":
        sp_root = str(Path(__file__).resolve().parent / "external" / "StreamPETR")
        if sp_root not in sys.path:
            sys.path.insert(0, sp_root)
        try:
            import projects.mmdet3d_plugin
        except ImportError:
            raise RuntimeError(
                f"StreamPETR plugin not found under {sp_root}/projects/mmdet3d_plugin. "
                f"Ensure the submodule is checked out, or remove --streampetr_config/--streampetr_ckpt."
            )
    elif model_type == "topomlp":
        tp_root = str(Path(__file__).resolve().parent / "external" / "TopoMLP_Repo")
        if tp_root not in sys.path:
            sys.path.insert(0, tp_root)
        os.environ["ATLAS_TOPOMLP_MODELS_ONLY"] = "1"
        # mmcv presence was verified above, so mmcv.utils is importable here.
        from mmcv.utils import registry as _reg
        _orig = _reg.Registry._register_module

        def _tolerant_register(self, module, module_name=None, force=False):
            # TopoMLP re-registers names already claimed by other plugins;
            # forcing the registration avoids duplicate-name errors on import.
            return _orig(self, module, module_name=module_name, force=True)

        _reg.Registry._register_module = _tolerant_register
        try:
            import projects.topomlp
        except ImportError:
            raise RuntimeError(
                f"TopoMLP plugin not found under {tp_root}/projects/topomlp. "
                f"Ensure the submodule is checked out, or remove --topomlp_config/--topomlp_ckpt."
            )
        finally:
            # BUGFIX: always restore the registry monkeypatch, even when the
            # plugin import fails. Previously a failed import leaked the
            # tolerant _register_module for the rest of the process, silently
            # masking later duplicate-registration errors.
            _reg.Registry._register_module = _orig

    cfg = Config.fromfile(config_path)
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, ckpt_path, map_location="cpu")
    model.eval()
    model.to(device)
    for param in model.parameters():
        param.requires_grad_(False)
    logger.info("Loaded frozen %s from %s", model_type, ckpt_path)
    return model
|
|
|
|
def build_img_metas_streampetr(batch, device, idx):
    """Build the per-sample img_metas dict StreamPETR expects.

    Uses a fixed 800x1600 pad/img shape per camera view; the scene token
    falls back to "__atlas__" when no scene_id is available for *idx*.
    """
    num_views = batch["pixel_values_det"].shape[1]
    height, width = 800, 1600
    view_shape = (height, width, 3)
    scene_ids = batch.get("scene_id", ["__atlas__"] * (idx + 1))
    scene_token = scene_ids[idx] if idx < len(scene_ids) else "__atlas__"
    meta = {
        # Separate list objects so downstream mutation of one cannot leak into the other.
        "pad_shape": [view_shape for _ in range(num_views)],
        "img_shape": [view_shape for _ in range(num_views)],
        "scene_token": scene_token,
    }
    if "lidar2img_det" in batch:
        meta["lidar2img"] = batch["lidar2img_det"][idx].cpu().numpy()
    return meta
|
|
|
|
def build_img_metas_topomlp(batch, device, idx):
    """Build the per-sample img_metas dict TopoMLP expects (fixed 800x1600 view shapes)."""
    meta = {}
    if "lidar2img_map" in batch:
        meta["lidar2img"] = batch["lidar2img_map"][idx].cpu().numpy()
    height, width = 800, 1600
    num_views = batch["pixel_values_map"].shape[1]
    view_shapes = tuple((height, width, 3) for _ in range(num_views))
    meta["img_shape"] = view_shapes
    meta["pad_shape"] = view_shapes
    meta["scale_factor"] = 1.0
    # No YOLOv8 traffic-element proposals are supplied in this pipeline.
    meta["te_yolov8"] = None
    return meta
|
|
|
|
@torch.no_grad()
def run_streampetr_forward(model, imgs, img_metas, batch, device, prev_exists=None):
    """Run one frozen StreamPETR forward pass and return the bbox-head outputs.

    Builds the ``data`` kwargs dict the StreamPETR heads expect, filling any
    geometry/timing entries missing from *batch* with identity/zero defaults.

    Args:
        model: frozen StreamPETR model (provides extract_img_feat,
            prepare_location, forward_roi_head, pts_bbox_head).
        imgs: multi-view image tensor; assumed (B, N, ...) — first two dims
            are batch and camera count.
        img_metas: list of per-sample meta dicts (see build_img_metas_streampetr).
        batch: collated batch dict; optional keys intrinsics_det, lidar2img_det,
            ego_pose, ego_pose_inv, timestamp are consumed here.
        device: target device for tensors constructed/moved here.
        prev_exists: optional (B,) flag tensor — nonzero marks "previous frame
            exists" for temporal memory; defaults to all zeros (fresh segment).

    Returns:
        The raw pts_bbox_head output dict.
    """
    B, N = imgs.shape[:2]

    img_feats = model.extract_img_feat(imgs, 1)

    data = {
        "img": imgs,
        "img_feats": img_feats,
        "prev_exists": prev_exists if prev_exists is not None else imgs.new_zeros(B),
    }

    # Camera intrinsics: embed the provided 3x3 K into a 4x4 homogeneous
    # matrix; otherwise fall back to identity per view.
    if "intrinsics_det" in batch:
        K3 = batch["intrinsics_det"].to(device)
        K4 = torch.zeros(B, N, 4, 4, device=device, dtype=K3.dtype)
        K4[:, :, :3, :3] = K3
        K4[:, :, 3, 3] = 1.0
        data["intrinsics"] = K4
    else:
        data["intrinsics"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()

    # Lidar-to-image projection, identity when absent.
    if "lidar2img_det" in batch:
        data["lidar2img"] = batch["lidar2img_det"].to(device)
    else:
        data["lidar2img"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()

    # Ego pose (and its inverse, derived if not supplied), identity when absent.
    if "ego_pose" in batch and batch["ego_pose"] is not None:
        data["ego_pose"] = batch["ego_pose"].to(device)
    else:
        data["ego_pose"] = torch.eye(4, device=device).unsqueeze(0).expand(B, -1, -1).contiguous()

    if "ego_pose_inv" in batch and batch["ego_pose_inv"] is not None:
        data["ego_pose_inv"] = batch["ego_pose_inv"].to(device)
    else:
        data["ego_pose_inv"] = torch.inverse(data["ego_pose"])

    # Timestamp feeds the temporal memory; zeros when absent.
    if "timestamp" in batch and batch["timestamp"] is not None:
        data["timestamp"] = batch["timestamp"].to(device)
    else:
        data["timestamp"] = torch.zeros(B, device=device)

    # StreamPETR pipeline: location prior -> ROI head (top-k selection) -> bbox head.
    location = model.prepare_location(img_metas, **data)
    outs_roi = model.forward_roi_head(location, **data)
    topk_indexes = outs_roi["topk_indexes"]

    outs = model.pts_bbox_head(location, img_metas, topk_indexes, **data)
    return outs
|
|
|
|
@torch.no_grad()
def run_topomlp_forward(model, imgs, img_metas):
    """Run the frozen TopoMLP model's inference-only forward pass (gradients disabled)."""
    outs = model.simple_forward(imgs, img_metas)
    return outs
|
|
|
|
| def _reconstruct_topomlp_outs(saved: dict, device, dtype): |
| """Convert precomputed .pt dict back to the format adapter.forward() expects.""" |
| def _restore(t): |
| return t.to(device=device, dtype=dtype).unsqueeze(0) |
| return { |
| "lc_outs_dec_list": [_restore(saved["lc_outs_dec"])], |
| "all_lc_cls_scores_list": [_restore(saved["lc_cls_scores"])], |
| "all_lc_preds_list": [_restore(saved["lc_preds"])], |
| "lc_outs_dec_one2many_list": [_restore(saved["lc_outs_dec_o2m"])], |
| "all_lc_cls_scores_one2many_list": [_restore(saved["lc_cls_scores_o2m"])], |
| "all_lc_preds_one2many_list": [_restore(saved["lc_preds_o2m"])], |
| } |
|
|
|
|
def extract_visual_tokens(
    streampetr_model,
    topomlp_model,
    topomlp_adapter,
    batch,
    device,
    num_det_queries=256,
    visual_hidden_size=256,
    query_token_id=None,
    visual_token_mode="online",
    streaming_state=None,
):
    """Extract det + map visual tokens.

    In online mode with streaming_state, StreamPETR temporal memory is managed
    per-scene and duplicate physical frames are protected: if the current
    sample_id equals the previous one, we reuse cached det tokens and skip the
    StreamPETR forward to avoid pushing the same frame into memory twice.

    Returns a dict with "detection"/"detection_ref_points" always set, and
    "map"/"map_ref_points" set either from TopoMLP (live or precomputed) or
    zero-filled when no sample in the batch needs map tokens.

    Raises RuntimeError when the configured mode cannot be satisfied
    (missing precomputed tokens in offline mode, missing encoder in online
    mode, or batch_size > 1 with temporal streaming).
    """
    B = batch["pixel_values_det"].shape[0]
    vis: Dict[str, torch.Tensor] = {}

    # A batch needs map tokens when the prompt contains more <query>
    # placeholders than there are detection queries.
    needs_map = False
    if query_token_id is not None and "input_ids" in batch:
        n_queries = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item())
        needs_map = n_queries > num_det_queries

    # ---- Detection tokens ----
    if visual_token_mode == "offline" and "precomputed_det" in batch and "precomputed_det_ref" in batch:
        vis["detection"] = batch["precomputed_det"].to(device)
        vis["detection_ref_points"] = batch["precomputed_det_ref"].to(device)
    elif visual_token_mode == "offline":
        raise RuntimeError(
            "visual_token_mode=offline but detection precomputed tokens are missing "
            "for the current batch. Refusing to zero-fill."
        )
    elif streampetr_model is not None:
        # Temporal streaming keeps per-scene memory inside StreamPETR, which
        # only works when exactly one sample is processed per step.
        if B != 1 and streaming_state is not None:
            raise RuntimeError("online temporal det requires batch_size=1")

        current_sample_id = batch.get("sample_id", [None])[0]
        current_scene = batch.get("scene_id", ["__atlas__"])[0]
        reuse_cache = False

        if streaming_state is not None:
            prev_scene = streaming_state.get("prev_scene_token")
            prev_sample_id = streaming_state.get("prev_sample_id")
            ts_tensor = batch.get("timestamp")
            current_ts = float(ts_tensor[0].item()) if ts_tensor is not None else None
            prev_ts = streaming_state.get("prev_timestamp")

            # A new temporal segment starts on the first frame, on a scene
            # change, or when the timestamp does not advance (e.g. replay),
            # in which case StreamPETR's memory must be reset.
            is_new_segment = (
                prev_scene is None
                or current_scene != prev_scene
                or (current_ts is not None and prev_ts is not None and current_ts <= prev_ts)
            )

            # Same physical frame as last step: reuse cached det tokens so the
            # frame is not pushed into temporal memory twice.
            if current_sample_id is not None and current_sample_id == prev_sample_id:
                cached = streaming_state.get("cached_det")
                if cached is not None:
                    reuse_cache = True
                    vis["detection"] = cached["detection"]
                    vis["detection_ref_points"] = cached["detection_ref_points"]

            if not reuse_cache:
                if is_new_segment:
                    streampetr_model.pts_bbox_head.reset_memory()
                # prev_exists=1.0 tells StreamPETR the memory holds the prior frame.
                prev_exists_val = 0.0 if is_new_segment else 1.0
                imgs_det = batch["pixel_values_det"].to(device)
                prev_exists = imgs_det.new_full((B,), prev_exists_val)

                img_metas = [build_img_metas_streampetr(batch, device, b) for b in range(B)]
                run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device, prev_exists=prev_exists)
                ego_pose_for_ref = batch.get("ego_pose")
                if ego_pose_for_ref is not None:
                    ego_pose_for_ref = ego_pose_for_ref.to(device)
                det_out = extract_streampetr_topk_tokens(
                    streampetr_model.pts_bbox_head,
                    topk=num_det_queries,
                    ego_pose=ego_pose_for_ref,
                )
                vis["detection"] = det_out["detection"]
                vis["detection_ref_points"] = det_out["detection_ref_points"]

                # Cache so an immediately repeated sample_id can skip the forward.
                streaming_state["cached_det"] = {
                    "detection": vis["detection"],
                    "detection_ref_points": vis["detection_ref_points"],
                }

            # Record stream position for the next call's segment detection.
            streaming_state["prev_scene_token"] = current_scene
            streaming_state["prev_sample_id"] = current_sample_id
            if batch.get("timestamp") is not None:
                streaming_state["prev_timestamp"] = float(batch["timestamp"][0].item())
        else:
            # Stateless online path: single forward, no temporal memory handling.
            imgs_det = batch["pixel_values_det"].to(device)
            img_metas = [build_img_metas_streampetr(batch, device, b) for b in range(B)]
            run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device)
            ego_pose_for_ref = batch.get("ego_pose")
            if ego_pose_for_ref is not None:
                ego_pose_for_ref = ego_pose_for_ref.to(device)
            det_out = extract_streampetr_topk_tokens(
                streampetr_model.pts_bbox_head,
                topk=num_det_queries,
                ego_pose=ego_pose_for_ref,
            )
            vis["detection"] = det_out["detection"]
            vis["detection_ref_points"] = det_out["detection_ref_points"]
    elif visual_token_mode == "online":
        raise RuntimeError(
            "visual_token_mode=online but StreamPETR model is None. "
            "Provide --streampetr_config and --streampetr_ckpt."
        )
    else:
        # Neither offline tokens nor a live encoder: zero-fill detection tokens.
        vis["detection"] = torch.zeros(B, num_det_queries, visual_hidden_size, device=device)
        vis["detection_ref_points"] = torch.zeros(B, num_det_queries, 3, device=device)

    # ---- Map tokens ----
    num_map_queries = num_det_queries
    if topomlp_adapter is not None:
        num_map_queries = topomlp_adapter.num_map_tokens

    if topomlp_adapter is not None:
        # Match the adapter's parameter (or buffer) dtype when feeding TopoMLP outputs.
        _params = list(topomlp_adapter.parameters())
        _bufs = list(topomlp_adapter.buffers())
        adapter_dtype = _params[0].dtype if _params else (_bufs[0].dtype if _bufs else torch.float32)

    map_filled = False
    if visual_token_mode == "offline" and needs_map and "precomputed_map" in batch:
        if B == 1:
            outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)
        else:
            # Rebuild each sample then concatenate along the batch dim, per key and list slot.
            per_sample = [_reconstruct_topomlp_outs(batch["precomputed_map"][b], device, adapter_dtype) for b in range(B)]
            outs = {}
            for k in per_sample[0]:
                outs[k] = [torch.cat([s[k][i] for s in per_sample], dim=0) for i in range(len(per_sample[0][k]))]
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True
    elif visual_token_mode == "offline" and needs_map:
        raise RuntimeError(
            "visual_token_mode=offline but map precomputed tokens are missing "
            "for a batch that requires map queries. Refusing to zero-fill."
        )
    elif needs_map and topomlp_model is not None:
        imgs_map = batch["pixel_values_map"].to(device)
        img_metas = [build_img_metas_topomlp(batch, device, b) for b in range(B)]
        outs = run_topomlp_forward(topomlp_model, imgs_map, img_metas)
        # Cast TopoMLP outputs (tensors, or lists of tensors) to the adapter dtype.
        for k, v in outs.items():
            if isinstance(v, torch.Tensor):
                outs[k] = v.to(adapter_dtype)
            elif isinstance(v, list):
                outs[k] = [x.to(adapter_dtype) if isinstance(x, torch.Tensor) else x for x in v]
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True
    elif needs_map and visual_token_mode == "online":
        raise RuntimeError(
            "visual_token_mode=online but TopoMLP model is None. "
            "Provide --topomlp_config and --topomlp_ckpt."
        )

    # Batches that need no map tokens still get zero-filled placeholders so
    # the returned dict always has a consistent shape.
    if not map_filled:
        vis["map"] = torch.zeros(B, num_map_queries, visual_hidden_size, device=device)
        vis["map_ref_points"] = torch.zeros(B, num_map_queries, 3, device=device)

    return vis
|
|
|
|
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.0):
    """LambdaLR with linear warmup followed by cosine decay floored at *min_lr_ratio*.

    The lambda scales the optimizer's base LR: 0 -> 1 linearly over
    *num_warmup_steps*, then 1 -> min_lr_ratio along a half cosine over the
    remaining steps.
    """
    def scale_at(step):
        if step < num_warmup_steps:
            # Linear ramp up to the base LR.
            return float(step) / float(max(1, num_warmup_steps))
        span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(step - num_warmup_steps) / span
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return max(min_lr_ratio, cosine)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, scale_at)
|
|
|
|
| def _optimizer_steps_per_epoch(num_batches: int, grad_accum_steps: int) -> int: |
| if num_batches <= 0: |
| return 0 |
| return int(math.ceil(float(num_batches) / float(max(1, grad_accum_steps)))) |
|
|
|
|
| def _accum_window_size_for_batch( |
| batch_idx: int, |
| num_batches: int, |
| grad_accum_steps: int, |
| ) -> int: |
| """Return the effective accumulation window size for this batch. |
| |
| Full windows use `grad_accum_steps`. The final partial window uses the |
| remainder so that tail batches are not under-scaled when we flush them. |
| """ |
| grad_accum_steps = max(1, int(grad_accum_steps)) |
| remainder = int(num_batches % grad_accum_steps) |
| tail_start = int(num_batches - remainder) |
| if remainder > 0 and batch_idx >= tail_start: |
| return remainder |
| return grad_accum_steps |
|
|
|
|
| def _is_optimizer_step_batch( |
| batch_idx: int, |
| num_batches: int, |
| grad_accum_steps: int, |
| ) -> bool: |
| grad_accum_steps = max(1, int(grad_accum_steps)) |
| natural_boundary = ((batch_idx + 1) % grad_accum_steps) == 0 |
| is_last_batch = (batch_idx + 1) == num_batches |
| return natural_boundary or is_last_batch |
|
|
|
|
def save_checkpoint(path, atlas, adapter, optimizer, scheduler, global_step, epoch, args):
    """Serialize training state to *path* (state-dict tensors moved to CPU first).

    The adapter state is only included when an adapter exists; the scheduler
    entry is None when no scheduler is in use. Parent directories are created
    as needed.
    """
    payload = {
        "global_step": global_step,
        "epoch": epoch,
        "args": vars(args),
        # CPU copies keep the checkpoint loadable on any device.
        "atlas_state_dict": {name: tensor.cpu() for name, tensor in atlas.state_dict().items()},
        "optimizer": optimizer.state_dict(),
        "scheduler": None if scheduler is None else scheduler.state_dict(),
    }
    if adapter is not None:
        payload["adapter_state_dict"] = {name: tensor.cpu() for name, tensor in adapter.state_dict().items()}
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, path)
|
|
|
|
def cleanup_old_checkpoints(output_dir: Path, keep_n: int):
    """Delete old epoch-* checkpoint dirs, keeping only the most recent *keep_n*."""
    if keep_n <= 0:
        # 0 (or negative) means "keep everything".
        return
    import shutil
    candidates = [
        entry for entry in output_dir.iterdir()
        if entry.is_dir() and entry.name.startswith("epoch-")
    ]
    # Sort by the numeric epoch suffix so "epoch-10" sorts after "epoch-9".
    candidates.sort(key=lambda entry: int(entry.name.split("-")[1]))
    excess = len(candidates) - keep_n
    for stale in candidates[:max(0, excess)]:
        shutil.rmtree(stale, ignore_errors=True)
        logger.info("Deleted old checkpoint: %s", stale)
|
|
|
|
| def main(): |
| args = parse_args() |
| class _FlushHandler(logging.StreamHandler): |
| def emit(self, record): |
| super().emit(record) |
| self.flush() |
| logging.root.handlers.clear() |
| _h = _FlushHandler(sys.stderr) |
| _fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s") |
| _h.setFormatter(_fmt) |
| logging.root.addHandler(_h) |
| logging.root.setLevel(logging.INFO) |
|
|
| _validate_visual_token_mode(args) |
|
|
| device, distributed, rank, world_size = setup_distributed(args.local_rank) |
| set_seed(args.seed + rank) |
| _main = is_main_process(distributed, rank) |
|
|
| is_online = args.visual_token_mode == "online" |
|
|
| output_dir = Path(args.output_dir) |
| if _main: |
| output_dir.mkdir(parents=True, exist_ok=True) |
| _fh = logging.FileHandler(str(output_dir / "train.log"), mode="a") |
| _fh.setFormatter(_fmt) |
| logging.root.addHandler(_fh) |
| with open(output_dir / "args.json", "w") as f: |
| json.dump(vars(args), f, indent=2) |
|
|
| if _main: |
| logger.info("Loading tokenizer: %s", args.llm_model) |
| tokenizer = load_tokenizer(args.llm_model) |
| if "<query>" not in tokenizer.get_vocab(): |
| tokenizer.add_tokens(["<query>"]) |
|
|
| _precomp_det = args.precomputed_det_tokens if not is_online else None |
| _precomp_map = args.precomputed_map_tokens if not is_online else None |
| dataset = AtlasDataset( |
| json_file=args.data_json, |
| image_root=args.data_root, |
| tokenizer=tokenizer, |
| max_length=args.max_length, |
| is_training=True, |
| planning_table3_mode=args.planning_table3_mode, |
| image_path_remap=args.image_path_remap, |
| precomputed_det_tokens=_precomp_det, |
| precomputed_map_tokens=_precomp_map, |
| openlane_root=args.openlane_root, |
| ) |
|
|
| if is_online: |
| if args.task_balance_mode == "scene_unit_111": |
| scene_unit_groups = dataset.get_scene_unit_groups() |
| scene_unit_summary = dataset.get_scene_unit_summary() |
| sampler = SceneUnitTaskBalancedSampler( |
| scene_unit_groups, |
| num_replicas=world_size, |
| rank=rank, |
| seed=args.seed, |
| pad_to_multiple=args.gradient_accumulation_steps, |
| ) |
| if _main: |
| sampler_summary = sampler.get_summary() |
| logger.info( |
| "Online mode: SceneUnitTaskBalancedSampler " |
| "(balanced_scenes=%d, dataset_samples=%d, world=%d, skipped_scenes=%d)", |
| int(sampler_summary.get("num_scenes", 0)), |
| len(dataset), |
| world_size, |
| int(sampler_summary.get("skipped_scenes", 0)), |
| ) |
| logger.info( |
| "Scene-unit summary: total_scenes=%d units_total=%d all_three=%d", |
| int(scene_unit_summary.get("scenes_total", 0)), |
| int(scene_unit_summary.get("units_total", 0)), |
| int(scene_unit_summary.get("units_with_all_three", 0)), |
| ) |
| logger.info( |
| "Balanced sampler unit counts: %s", |
| _format_task_counts(sampler_summary.get("unit_counts", {})), |
| ) |
| logger.info( |
| "Balanced sampler raw counts: %s", |
| _format_task_counts(sampler_summary.get("raw_counts", {})), |
| ) |
| else: |
| scene_groups = dataset.get_scene_groups() |
| sampler = SceneSequentialSampler( |
| scene_groups, |
| num_replicas=world_size, |
| rank=rank, |
| seed=args.seed, |
| pad_to_multiple=args.gradient_accumulation_steps, |
| ) |
| if _main: |
| logger.info("Online mode: SceneSequentialSampler (%d scenes, %d samples, world=%d)", |
| len(scene_groups), len(dataset), world_size) |
| else: |
| from torch.utils.data import DistributedSampler |
| sampler = DistributedSampler(dataset, shuffle=True) if distributed else None |
|
|
| collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id) |
| dataloader = DataLoader( |
| dataset, |
| batch_size=args.batch_size, |
| shuffle=(not is_online and sampler is None), |
| sampler=sampler, |
| num_workers=args.num_workers, |
| collate_fn=collate_fn, |
| pin_memory=True, |
| drop_last=not is_online, |
| ) |
|
|
| streampetr_model = load_frozen_encoder( |
| args.streampetr_config, args.streampetr_ckpt, "streampetr", device |
| ) |
| topomlp_model = load_frozen_encoder( |
| args.topomlp_config, args.topomlp_ckpt, "topomlp", device |
| ) |
|
|
| topomlp_adapter = None |
| if topomlp_model is not None or _precomp_map: |
| _tp_bev_range = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0) |
| if args.topomlp_config: |
| try: |
| from mmcv import Config as _Cfg |
| _tp_cfg = _Cfg.fromfile(args.topomlp_config) |
| if hasattr(_tp_cfg, "point_cloud_range"): |
| _tp_bev_range = tuple(float(v) for v in _tp_cfg.point_cloud_range) |
| logger.info("TopoMLP bev_range from config: %s", _tp_bev_range) |
| except Exception as e: |
| logger.warning("Failed to read point_cloud_range from TopoMLP config: %s. Using default: %s", e, _tp_bev_range) |
| topomlp_adapter = TopoMLPToAtlasMapTokens( |
| num_map_tokens=args.num_map_queries, |
| hidden_size=args.visual_hidden_size, |
| bev_range=_tp_bev_range, |
| ).to(device) |
|
|
| dtype = torch.float32 |
| if args.bf16: |
| dtype = torch.bfloat16 |
| elif args.fp16: |
| dtype = torch.float16 |
|
|
| if args.load_in_4bit: |
| dm = {"": device} if distributed else "auto" |
| else: |
| dm = None |
|
|
    # Peek at the DeepSpeed JSON config so the model dtype agrees with the
    # engine's bf16/fp16 settings even when the CLI flags were not passed.
    _ds_bf16 = False
    _ds_fp16 = False
    if args.deepspeed:
        with open(args.deepspeed) as _f:
            _ds_cfg_peek = json.load(_f)
        _ds_bf16 = _ds_cfg_peek.get("bf16", {}).get("enabled", False)
        _ds_fp16 = _ds_cfg_peek.get("fp16", {}).get("enabled", False)
    _use_half = args.bf16 or args.fp16 or _ds_bf16 or _ds_fp16
    if _use_half and dtype == torch.float32:
        # bf16 wins over fp16 when both are requested anywhere.
        dtype = torch.bfloat16 if (args.bf16 or _ds_bf16) else torch.float16

    # --- Build the Atlas LLM wrapper ---------------------------------------
    # Flash attention is only enabled under half precision (it requires fp16/bf16).
    atlas = AtlasForCausalLM(
        llm_model_name=args.llm_model,
        visual_hidden_size=args.visual_hidden_size,
        num_queries=args.num_det_queries,
        num_map_queries=args.num_map_queries,
        load_in_4bit=args.load_in_4bit,
        use_flash_attention=_use_half,
        device_map=dm,
        torch_dtype=dtype,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )
    # The tokenizer carries added special tokens (e.g. "<query>"), so the
    # embedding table must be resized before the id is registered.
    atlas.resize_token_embeddings(len(tokenizer))
    query_token_id = tokenizer.convert_tokens_to_ids("<query>")
    atlas.set_query_token_id(query_token_id)
    if topomlp_adapter is not None:
        # Attach so checkpoint save/load and .to() moves cover the adapter.
        atlas.topomlp_adapter = topomlp_adapter
    if dm is None and args.deepspeed is None:
        # No device_map sharding and no DeepSpeed: move the whole model now.
        # (DeepSpeed placement happens later, after dtype casting.)
        atlas = atlas.to(device)
    atlas.gradient_checkpointing_enable()
|
|
    # --- Step accounting for the LR schedule --------------------------------
    # steps_per_epoch counts OPTIMIZER steps (batches / grad-accum, tail included).
    num_batches_per_epoch = len(dataloader)
    steps_per_epoch = _optimizer_steps_per_epoch(
        num_batches_per_epoch, args.gradient_accumulation_steps
    )
    total_steps = steps_per_epoch * args.epochs
    warmup_steps = int(total_steps * args.warmup_ratio)

    global_step = 0
    start_epoch = 0

    # --- Resume model weights (optimizer/scheduler state is restored later,
    # after the optimizer exists) -------------------------------------------
    _resume_ckpt = None
    if args.resume:
        _resume_ckpt = torch.load(args.resume, map_location="cpu")
        if "atlas_state_dict" not in _resume_ckpt:
            raise RuntimeError(f"Checkpoint missing 'atlas_state_dict'. Keys: {list(_resume_ckpt.keys())}")
        # strict=False: tolerate architecture drift (e.g. LoRA toggled off/on).
        missing, _ = atlas.load_state_dict(_resume_ckpt["atlas_state_dict"], strict=False)
        if _main and missing:
            logger.warning("Resume: %d missing keys (first 10): %s", len(missing), missing[:10])
        if topomlp_adapter is not None and "adapter_state_dict" in _resume_ckpt:
            _m, _u = topomlp_adapter.load_state_dict(_resume_ckpt["adapter_state_dict"], strict=False)
            if _main and _u:
                logger.info("Adapter resume: ignored %d legacy keys: %s", len(_u), _u[:5])
        global_step = _resume_ckpt.get("global_step", 0)
        start_epoch = _resume_ckpt.get("epoch", 0)
        if _main:
            logger.info("Resumed from %s (step=%d, epoch=%d)", args.resume, global_step, start_epoch)
|
|
| use_deepspeed = args.deepspeed is not None |
| if use_deepspeed: |
| import deepspeed |
| ds_config = json.load(open(args.deepspeed)) |
| ds_config["optimizer"] = { |
| "type": "Adam", |
| "params": { |
| "lr": args.lr, "weight_decay": args.weight_decay, |
| "betas": [0.9, 0.999], "torch_adam": True, "adam_w_mode": True, |
| }, |
| } |
| ds_config["scheduler"] = { |
| "type": "WarmupCosineLR", |
| "params": { |
| "total_num_steps": total_steps, |
| "warmup_num_steps": warmup_steps, |
| "warmup_type": "linear", |
| }, |
| } |
| ds_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps |
| ds_config["train_micro_batch_size_per_gpu"] = args.batch_size |
| ds_config["train_batch_size"] = args.batch_size * args.gradient_accumulation_steps * world_size |
|
|
| ds_bf16 = ds_config.get("bf16", {}).get("enabled", False) |
| ds_fp16 = ds_config.get("fp16", {}).get("enabled", False) |
| if ds_bf16: |
| atlas.to(device=device, dtype=torch.bfloat16) |
| elif ds_fp16: |
| atlas.to(device=device, dtype=torch.float16) |
| else: |
| atlas.to(device) |
|
|
| all_params = atlas.get_trainable_param_groups(args.lr, weight_decay=args.weight_decay) |
| if topomlp_adapter is not None: |
| _adapter_trainable = [p for p in topomlp_adapter.parameters() if p.requires_grad] |
| if _adapter_trainable: |
| all_params.append({"params": _adapter_trainable, "lr": args.lr, "weight_decay": 0.0}) |
|
|
| atlas_ddp, optimizer, _, scheduler = deepspeed.initialize( |
| model=atlas, model_parameters=all_params, |
| config=ds_config, dist_init_required=False, |
| ) |
|
|
| if _resume_ckpt is not None and "optimizer" in _resume_ckpt: |
| try: |
| optimizer.load_state_dict(_resume_ckpt["optimizer"]) |
| if _main: |
| logger.info("Restored DeepSpeed optimizer state from checkpoint") |
| except Exception as e: |
| if _main: |
| logger.warning("Failed to restore DeepSpeed optimizer state: %s", e) |
|
|
| if global_step > 0 and scheduler is not None: |
| for _ in range(global_step): |
| scheduler.step() |
| _ff_lr = scheduler.get_lr() |
| if _main: |
| logger.info( |
| "Fast-forwarded DeepSpeed LR scheduler to step %d (lr=%s)", |
| global_step, |
| [f"{x:.6e}" for x in _ff_lr] if isinstance(_ff_lr, (list, tuple)) else f"{_ff_lr:.6e}", |
| ) |
| else: |
| param_groups = atlas.get_trainable_param_groups(args.lr, weight_decay=args.weight_decay) |
| if topomlp_adapter is not None: |
| _adapter_trainable = [p for p in topomlp_adapter.parameters() if p.requires_grad] |
| if _adapter_trainable: |
| param_groups.append({"params": _adapter_trainable, "lr": args.lr, "weight_decay": 0.0}) |
| optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay) |
| scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps) |
| if distributed: |
| atlas_ddp = torch.nn.parallel.DistributedDataParallel( |
| atlas, device_ids=[args.local_rank], find_unused_parameters=True, |
| ) |
| else: |
| atlas_ddp = atlas |
|
|
| if _resume_ckpt is not None and not use_deepspeed: |
| if "optimizer" in _resume_ckpt: |
| try: |
| optimizer.load_state_dict(_resume_ckpt["optimizer"]) |
| except Exception as e: |
| if _main: |
| logger.warning("Failed to restore optimizer state: %s", e) |
| if "scheduler" in _resume_ckpt and _resume_ckpt["scheduler"] is not None: |
| try: |
| scheduler.load_state_dict(_resume_ckpt["scheduler"]) |
| except Exception as e: |
| if _main: |
| logger.warning("Failed to restore scheduler state: %s", e) |
| _resume_ckpt = None |
|
|
    # Put every trainable module into train mode (the adapter may live outside
    # the DDP/DeepSpeed wrapper, so it is switched explicitly).
    atlas_ddp.train()
    if topomlp_adapter is not None:
        topomlp_adapter.train()

    # One-time configuration summary on the main rank only.
    if _main:
        logger.info("=== Training Config ===")
        logger.info("  epochs: %d, lr: %s, batch: %d, accum: %d",
                    args.epochs, args.lr, args.batch_size, args.gradient_accumulation_steps)
        logger.info("  total_steps: %d, warmup_steps: %d", total_steps, warmup_steps)
        # Sample a parameter to report the dtype the model actually ended up in.
        _effective_dtype = next(atlas.parameters()).dtype
        logger.info("  use_lora: %s, load_in_4bit: %s, fp16: %s, bf16: %s, deepspeed: %s, effective_dtype: %s",
                    args.use_lora, args.load_in_4bit, args.fp16, args.bf16, use_deepspeed, _effective_dtype)
        n_trainable = sum(p.numel() for p in atlas.parameters() if p.requires_grad)
        if topomlp_adapter is not None:
            n_trainable += sum(p.numel() for p in topomlp_adapter.parameters() if p.requires_grad)
        logger.info("  trainable params: %s", f"{n_trainable:,}")
        logger.info("  visual_token_mode: %s", args.visual_token_mode)
        logger.info("  task_balance_mode: %s", args.task_balance_mode)
        logger.info("  task_loss_weights: %s", _format_task_loss_weights(args.task_loss_weights))
        # Report where visual tokens come from: live online-temporal encoder,
        # live offline encoder, or precomputed features.
        logger.info("  streampetr: %s", "online-temporal" if (is_online and streampetr_model) else ("loaded" if streampetr_model else ("precomputed" if _precomp_det else "NONE (should not happen)")))
        logger.info("  topomlp: %s", "online" if (is_online and topomlp_model) else ("loaded" if topomlp_model else ("precomputed" if _precomp_map else "NONE (should not happen)")))
        logger.info("=======================")

    # Mutable per-scene state for online temporal streaming; None when offline.
    streaming_state = {} if is_online else None
|
|
    # =========================== Training loop =============================
    for epoch in range(start_epoch, args.epochs):
        if sampler is not None:
            # Re-seed the sampler's shuffle per epoch (DistributedSampler contract).
            sampler.set_epoch(epoch)
        if is_online and args.task_balance_mode == "scene_unit_111" and isinstance(sampler, SceneUnitTaskBalancedSampler):
            # len(sampler) forces the sampler to materialize this epoch's plan
            # so the stats below reflect the CURRENT epoch.
            _ = len(sampler)
            epoch_stats = sampler.get_last_epoch_stats()
            if _main:
                logger.info(
                    "Epoch %d scene-unit balance: unit_counts=%s",
                    epoch,
                    _format_task_counts(epoch_stats.get("unit_counts", {})),
                )
                logger.info(
                    "Epoch %d scene-unit raw counts: %s",
                    epoch,
                    _format_task_counts(epoch_stats.get("raw_counts", {})),
                )
                logger.info(
                    "Epoch %d sampler padding: scenes_rank=%d/%d skipped=%d prepad=%d target=%d replay_extra=%d",
                    epoch,
                    int(epoch_stats.get("num_scenes_rank", 0)),
                    int(epoch_stats.get("num_scenes_total", 0)),
                    int(epoch_stats.get("num_skipped_scenes", 0)),
                    int(epoch_stats.get("prepad_len", 0)),
                    int(epoch_stats.get("target_len", 0)),
                    int(epoch_stats.get("replay_extra", 0)),
                )

        # Reset online temporal state at epoch boundaries so scene memory
        # never leaks across epochs.
        if streaming_state is not None:
            streaming_state.clear()
            if streampetr_model is not None:
                streampetr_model.pts_bbox_head.reset_memory()

        # Per-epoch running stats.
        epoch_loss = 0.0          # accumulated task-weighted loss
        epoch_raw_loss = 0.0      # accumulated unweighted LM loss
        num_batches = 0
        t0 = time.time()
        task_raw_loss_ema: Dict[str, float] = {}  # per-task EMA of raw loss

        if not use_deepspeed:
            optimizer.zero_grad()

        for batch_idx, batch in enumerate(dataloader):
            # do_step: is this the last micro-batch of its accumulation window
            # (including a possibly shorter tail window at epoch end)?
            do_step = _is_optimizer_step_batch(
                batch_idx, num_batches_per_epoch, args.gradient_accumulation_steps
            )
            # Actual window size used to normalize the loss (tail may be < accum).
            accum_window_size = _accum_window_size_for_batch(
                batch_idx, num_batches_per_epoch, args.gradient_accumulation_steps
            )
            scaled_loss = None
            if use_deepspeed:
                # We drive accumulation boundaries manually so the epoch-tail
                # window still flushes; refuse to run if the engine can't do it.
                if not hasattr(atlas_ddp, "set_gradient_accumulation_boundary"):
                    raise RuntimeError(
                        "DeepSpeed engine is missing set_gradient_accumulation_boundary(); "
                        "cannot enforce epoch-tail flush semantics."
                    )
                atlas_ddp.set_gradient_accumulation_boundary(do_step)

            # Debug trace for the first few batches of each epoch (main rank).
            if _main and batch_idx < 5:
                # NOTE(review): truthiness check — if "<query>" ever maps to
                # token id 0 this reports nq=0; `query_token_id is not None`
                # would be safer. Confirm tokenizer id range before changing.
                _nq = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item()) if query_token_id else 0
                _needs_map = _nq > args.num_det_queries
                logger.info("[DBG] batch_idx=%d nq=%d needs_map=%s sid=%s mode=%s",
                            batch_idx, _nq, _needs_map,
                            batch.get("sample_id", ["?"])[0][:20],
                            args.visual_token_mode)
                for _handler in logging.root.handlers:
                    _handler.flush()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Build the visual token dict (detection/map tokens) from the
            # frozen encoders, precomputed features, or streaming state.
            visual_features = extract_visual_tokens(
                streampetr_model, topomlp_model, topomlp_adapter,
                batch, device, args.num_det_queries, args.visual_hidden_size,
                query_token_id=query_token_id,
                visual_token_mode=args.visual_token_mode,
                streaming_state=streaming_state,
            )

            if _main and batch_idx < 5:
                logger.info("[DBG] vis_keys=%s", list(visual_features.keys()))
                for _handler in logging.root.handlers:
                    _handler.flush()

            if _main and batch_idx < 5:
                logger.info("[DBG] pre-forward batch_idx=%d seqlen=%d", batch_idx, input_ids.shape[1])
                for _handler in logging.root.handlers:
                    _handler.flush()

            outputs = atlas_ddp(
                input_ids=input_ids,
                attention_mask=attention_mask,
                visual_features=visual_features,
                labels=labels,
            )
            raw_loss = outputs.loss
            # Task-dependent loss weighting; a batch is assumed homogeneous in
            # task type (weight resolved from the batch's task list).
            batch_task_types = [str(task) for task in batch.get("task_type", [])]
            task_loss_weight = _resolve_batch_task_weight(batch_task_types, args.task_loss_weights)
            loss = raw_loss * task_loss_weight
            batch_task = batch_task_types[0].strip().lower() if batch_task_types else ""
            if batch_task in ("detection", "planning", "caption"):
                _update_task_raw_loss_ema(task_raw_loss_ema, batch_task, raw_loss.item())

            if _main and batch_idx < 5:
                logger.info(
                    "[DBG] post-forward batch_idx=%d raw_loss=%.4f weighted_loss=%.4f task=%s weight=%.4f %s",
                    batch_idx,
                    raw_loss.item(),
                    loss.item(),
                    batch_task_types[0] if batch_task_types else "?",
                    task_loss_weight,
                    _format_task_raw_loss_ema(task_raw_loss_ema),
                )
                for _handler in logging.root.handlers:
                    _handler.flush()

            if _main and batch_idx < 5:
                logger.info("[DBG] pre-backward batch_idx=%d", batch_idx)
                for _handler in logging.root.handlers:
                    _handler.flush()

            # Backward + (engine-managed) step. Loss is normalized by the TRUE
            # window size; scale_wrt_gas=False stops DeepSpeed from dividing by
            # the configured accumulation steps a second time.
            if use_deepspeed:
                scaled_loss = loss / accum_window_size
                atlas_ddp.backward(scaled_loss, scale_wrt_gas=False)

                if _main and batch_idx < 5:
                    logger.info("[DBG] pre-step batch_idx=%d", batch_idx)
                    for _handler in logging.root.handlers:
                        _handler.flush()

                # No-op until the accumulation boundary set above is reached.
                atlas_ddp.step()
            else:
                scaled_loss = loss / accum_window_size
                scaled_loss.backward()

            # The adapter sits outside the DDP wrapper, so its gradients must
            # be all-reduced (averaged) by hand in the non-DeepSpeed path.
            if not use_deepspeed and distributed and topomlp_adapter is not None:
                for p in topomlp_adapter.parameters():
                    if p.requires_grad and p.grad is not None:
                        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                        p.grad.div_(world_size)

            epoch_loss += loss.item()
            epoch_raw_loss += raw_loss.item()
            num_batches += 1
            if _main and num_batches <= 3:
                logger.info(
                    "batch=%d raw_loss=%.4f weighted_loss=%.4f task=%s weight=%.4f %s",
                    num_batches,
                    raw_loss.item(),
                    loss.item(),
                    batch_task_types[0] if batch_task_types else "?",
                    task_loss_weight,
                    _format_task_raw_loss_ema(task_raw_loss_ema),
                )
                for _handler in logging.root.handlers:
                    _handler.flush()

            # ---- Optimizer step at accumulation boundaries ----------------
            if do_step:
                if not use_deepspeed:
                    # Clip across model AND adapter params together so the
                    # global norm matches what the optimizer will apply.
                    all_params = list(atlas.parameters()) + (
                        list(topomlp_adapter.parameters()) if topomlp_adapter is not None else []
                    )
                    trainable = [p for p in all_params if p.requires_grad]
                    torch.nn.utils.clip_grad_norm_(trainable, args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                global_step += 1

                if _main and global_step % args.log_steps == 0:
                    # Current LR: DeepSpeed engine first, then torch scheduler,
                    # falling back to the configured base LR.
                    if use_deepspeed and hasattr(atlas_ddp, "get_lr"):
                        try:
                            _lrs = atlas_ddp.get_lr()
                            if isinstance(_lrs, (list, tuple)) and len(_lrs) > 0:
                                lr_now = float(_lrs[0])
                            else:
                                lr_now = float(_lrs)
                        except Exception:
                            lr_now = optimizer.param_groups[0]["lr"] if getattr(optimizer, "param_groups", None) else args.lr
                    elif hasattr(scheduler, "get_last_lr"):
                        try:
                            lr_now = scheduler.get_last_lr()[0]
                        except Exception:
                            lr_now = optimizer.param_groups[0]["lr"] if getattr(optimizer, "param_groups", None) else args.lr
                    else:
                        lr_now = args.lr
                    elapsed = time.time() - t0
                    samples_sec = num_batches * args.batch_size / max(elapsed, 1e-6)
                    avg_loss = epoch_loss / max(num_batches, 1)
                    avg_raw_loss = epoch_raw_loss / max(num_batches, 1)
                    logger.info(
                        "epoch=%d step=%d raw_loss=%.4f weighted_loss=%.4f lr=%.2e samples/s=%.1f %s",
                        epoch, global_step, avg_raw_loss, avg_loss, lr_now, samples_sec,
                        _format_task_raw_loss_ema(task_raw_loss_ema),
                    )
                    for _handler in logging.root.handlers:
                        _handler.flush()

                # Mid-epoch checkpoint every save_steps optimizer steps.
                # NOTE(review): this saves `epoch` while the epoch-end save
                # below stores `epoch + 1`; on resume a step checkpoint
                # restarts the SAME epoch — confirm that is intended.
                if _main and args.save_steps > 0 and global_step % args.save_steps == 0:
                    ckpt_path = output_dir / f"checkpoint-{global_step}" / "checkpoint.pt"
                    save_checkpoint(ckpt_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, epoch, args)
                    logger.info("Saved step checkpoint: %s", ckpt_path)

        # ---- Epoch summary & epoch checkpoint -----------------------------
        avg_loss = epoch_loss / max(num_batches, 1)
        avg_raw_loss = epoch_raw_loss / max(num_batches, 1)
        if _main:
            logger.info(
                "Epoch %d done — avg_raw_loss=%.4f avg_weighted_loss=%.4f (%.1f min) %s",
                epoch,
                avg_raw_loss,
                avg_loss,
                (time.time() - t0) / 60,
                _format_task_raw_loss_ema(task_raw_loss_ema),
            )

        if _main and (epoch + 1) % args.save_epochs == 0:
            # epoch + 1 stored: resuming from this file starts the NEXT epoch.
            ckpt_path = output_dir / f"epoch-{epoch}" / "checkpoint.pt"
            save_checkpoint(ckpt_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, epoch + 1, args)
            logger.info("Saved epoch checkpoint: %s", ckpt_path)
            if args.keep_last_n_ckpts > 0:
                cleanup_old_checkpoints(output_dir, args.keep_last_n_ckpts)
|
|
    # Final checkpoint (main rank only), then tear down the process group so
    # distributed workers exit cleanly.
    if _main:
        final_path = output_dir / "final" / "checkpoint.pt"
        save_checkpoint(final_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, args.epochs, args)
        logger.info("Training complete. Final checkpoint: %s", final_path)

    if distributed:
        dist.destroy_process_group()
|
|
|
|
# Script entry point: run training when executed directly (not on import).
if __name__ == "__main__":
    main()
|
|