#!/usr/bin/env python3
import argparse
import os
import sys
import math
import time
import json
import logging
from datetime import timedelta
from pathlib import Path
from typing import Dict, List
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
sys.path.insert(0, str(Path(__file__).resolve().parent))
from src.model.modeling_atlas import AtlasForCausalLM
from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
from src.model.streampetr_adapter import extract_streampetr_topk_tokens
from src.dataset.atlas_dataset import (
AtlasDataset, make_atlas_collate_fn, load_tokenizer,
)
from src.dataset.scene_sampler import (
SceneSequentialSampler,
SceneUnitTaskBalancedSampler,
)
from src.prompting import PLANNING_TABLE3_MODES
logger = logging.getLogger("train_atlas")
TASK_LOSS_WEIGHT_KEYS = ("detection", "planning", "caption", "lane")
def parse_args():
p = argparse.ArgumentParser()
p.add_argument("--llm_model", default="lmsys/vicuna-7b-v1.5")
p.add_argument("--visual_hidden_size", type=int, default=256)
p.add_argument("--num_det_queries", type=int, default=256)
p.add_argument("--num_map_queries", type=int, default=256)
p.add_argument("--streampetr_config", default=None)
p.add_argument("--streampetr_ckpt", default=None)
p.add_argument("--topomlp_config", default=None)
p.add_argument("--topomlp_ckpt", default=None)
p.add_argument("--data_json", required=True)
p.add_argument("--data_root", default="/mnt/data/nuscenes")
p.add_argument("--openlane_root", default=None,
help="Root for OpenLane relative image paths (lane task). "
"Falls back to --data_root if not set.")
p.add_argument("--max_length", type=int, default=4096)
p.add_argument("--output_dir", default="work_dirs/atlas")
p.add_argument("--lr", type=float, default=2e-5)
p.add_argument("--weight_decay", type=float, default=1e-4)
p.add_argument("--batch_size", type=int, default=1)
p.add_argument("--epochs", type=int, default=8)
p.add_argument("--warmup_ratio", type=float, default=0.03)
p.add_argument("--gradient_accumulation_steps", type=int, default=2)
p.add_argument("--max_grad_norm", type=float, default=1.0)
p.add_argument("--use_lora", action="store_true")
p.add_argument("--lora_r", type=int, default=64)
p.add_argument("--lora_alpha", type=int, default=64)
p.add_argument("--lora_dropout", type=float, default=0.1)
p.add_argument("--load_in_4bit", action="store_true")
p.add_argument("--save_steps", type=int, default=0)
p.add_argument("--save_epochs", type=int, default=1)
p.add_argument("--log_steps", type=int, default=10)
p.add_argument("--seed", type=int, default=42)
p.add_argument("--num_workers", type=int, default=4)
p.add_argument("--resume", default=None)
p.add_argument("--local_rank", "--local-rank", type=int, default=int(os.environ.get("LOCAL_RANK", -1)))
p.add_argument("--fp16", action="store_true")
p.add_argument("--bf16", action="store_true")
p.add_argument("--image_path_remap", default=None,
help="old=new path remap, e.g. /mnt/data=/local/data")
p.add_argument("--precomputed_det_tokens", default=None,
help="[offline only] Dir with precomputed det tokens (.pt files)")
p.add_argument("--precomputed_map_tokens", default=None,
help="[offline only] Dir with precomputed TopoMLP map tokens (.pt files)")
p.add_argument("--visual_token_mode", choices=("online", "offline"), default="online",
help="Visual token source: online=live frozen encoders (default), offline=read *_offline dirs")
p.add_argument("--deepspeed", default=None,
help="Path to DeepSpeed config JSON (enables ZeRO)")
p.add_argument("--keep_last_n_ckpts", type=int, default=0,
help="Keep only the N most recent epoch checkpoints (0=keep all)")
p.add_argument(
"--task_balance_mode",
choices=("none", "scene_unit_111"),
default="none",
help=(
"Online sampler mode. "
"none=legacy scene-sequential sampling; "
"scene_unit_111=scene-sequential unit-level 1:1:1 balance over "
"detection/planning/caption timestamps (caption raw expands to 6 views)."
),
)
p.add_argument(
"--task_loss_weights",
default="",
help=(
"Static task weights applied to the scalar LM loss. "
"Format: detection=0.35,planning=1.0,caption=0.05. "
"Unspecified tasks default to 1.0."
),
)
p.add_argument(
"--planning_table3_mode",
choices=PLANNING_TABLE3_MODES,
default="atlas_base",
help=(
"Planning prompt variant matching Atlas Table 3: "
"atlas_base=no command/no explicit ego state; "
"atlas_high_level=requires top-level route_command "
"(this repo uses a UniAD-style future-GT-derived command); "
"atlas_high_level_ego=requires top-level route_command plus "
"velocity/acceleration bins."
),
)
args = p.parse_args()
args.task_loss_weights = _parse_task_loss_weights(args.task_loss_weights)
return args
def _parse_task_loss_weights(raw: str) -> Dict[str, float]:
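    """Parse "task=weight,..." CLI text into a per-task weight dict.

    Unspecified tasks default to 1.0; unknown tasks, non-float values, and
    non-positive weights raise ValueError.
    """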
weights = {task: 1.0 for task in TASK_LOSS_WEIGHT_KEYS}
if raw is None:
return weights
text = str(raw).strip()
if not text:
return weights
for entry in text.split(","):
item = entry.strip()
if not item:
continue
if "=" not in item:
raise ValueError(
"task_loss_weights must use key=value pairs separated by commas. "
f"Invalid entry: {item!r}"
)
key, value = item.split("=", 1)
task = key.strip().lower()
if task not in weights:
raise ValueError(
f"Unsupported task in task_loss_weights: {task!r}. "
f"Expected one of {TASK_LOSS_WEIGHT_KEYS}."
)
try:
numeric = float(value)
except Exception as exc:
raise ValueError(
f"Invalid float in task_loss_weights for task {task!r}: {value!r}"
) from exc
if numeric <= 0.0:
raise ValueError(
f"task_loss_weights values must be > 0. Got {numeric} for task {task!r}."
)
weights[task] = numeric
return weights
def _validate_visual_token_mode(args):
"""Enforce mode-specific constraints. Fail hard, never silently degrade."""
if args.task_balance_mode != "none" and args.visual_token_mode != "online":
raise RuntimeError(
"task_balance_mode requires --visual_token_mode online. "
f"Got: visual_token_mode={args.visual_token_mode!r}"
)
if args.visual_token_mode == "online":
if args.precomputed_det_tokens or args.precomputed_map_tokens:
raise RuntimeError(
"visual_token_mode=online forbids --precomputed_det_tokens / "
"--precomputed_map_tokens. Use --visual_token_mode offline to "
"read offline token directories."
)
missing = []
if not args.streampetr_config or not args.streampetr_ckpt:
missing.append("--streampetr_config/--streampetr_ckpt")
if not args.topomlp_config or not args.topomlp_ckpt:
missing.append("--topomlp_config/--topomlp_ckpt")
if missing:
raise RuntimeError(
"visual_token_mode=online requires live encoder configs and "
"checkpoints. Missing: " + ", ".join(missing)
)
for p in (args.streampetr_config, args.streampetr_ckpt, args.topomlp_config, args.topomlp_ckpt):
if not os.path.exists(p):
raise RuntimeError(f"Required online asset does not exist: {p}")
if args.batch_size != 1:
raise RuntimeError(
"visual_token_mode=online with temporal memory requires "
"--batch_size 1 (paper-aligned). Got: %d" % args.batch_size
)
else:
if not args.precomputed_det_tokens and not args.precomputed_map_tokens:
raise RuntimeError(
"visual_token_mode=offline requires at least one "
"--precomputed_*_tokens directory."
)
for p in (args.precomputed_det_tokens, args.precomputed_map_tokens):
if p and not os.path.isdir(p):
raise RuntimeError(f"Offline token directory does not exist: {p}")
def _format_task_counts(counts: Dict[str, int]) -> str:
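    """Format per-task counts plus a ratio normalized by the smallest nonzero count."""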
ordered = ["detection", "planning", "caption"]
parts = [f"{task}={int(counts.get(task, 0))}" for task in ordered]
nonzero = [int(counts.get(task, 0)) for task in ordered if int(counts.get(task, 0)) > 0]
if nonzero:
base = float(min(nonzero))
ratio = ":".join(f"{int(counts.get(task, 0)) / base:.2f}" for task in ordered)
else:
ratio = "0:0:0"
return f"{', '.join(parts)} | ratio={ratio}"
def _format_task_loss_weights(weights: Dict[str, float]) -> str:
return ", ".join(f"{task}={float(weights.get(task, 1.0)):.4f}" for task in TASK_LOSS_WEIGHT_KEYS)
def _update_task_raw_loss_ema(
ema_state: Dict[str, float],
task: str,
raw_loss_value: float,
beta: float = 0.98,
) -> None:
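    """Update the per-task exponential moving average of the raw LM loss in place."""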
prev = ema_state.get(task)
if prev is None:
ema_state[task] = float(raw_loss_value)
else:
ema_state[task] = float(beta) * float(prev) + (1.0 - float(beta)) * float(raw_loss_value)
def _format_task_raw_loss_ema(ema_state: Dict[str, float]) -> str:
mapping = [
("detection", "det_raw_loss_ema"),
("planning", "planning_raw_loss_ema"),
("caption", "caption_raw_loss_ema"),
]
parts = []
for task, label in mapping:
value = ema_state.get(task)
if value is None:
parts.append(f"{label}=NA")
else:
parts.append(f"{label}={float(value):.4f}")
return " ".join(parts)
def _resolve_batch_task_weight(batch_task_types: List[str], task_weights: Dict[str, float]) -> float:
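    """Return the static loss weight shared by every sample in the batch.

    Batches are expected to be single-task (batch_size=1 / scene-sequential
    sampling); a mixed-task batch is only accepted when all of its tasks map
    to the same weight, otherwise we fail loudly instead of blending weights.
    """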
if not batch_task_types:
return 1.0
tasks = [str(task).strip().lower() for task in batch_task_types if str(task).strip()]
if not tasks:
return 1.0
unique_tasks = sorted(set(tasks))
if len(unique_tasks) == 1:
return float(task_weights.get(unique_tasks[0], 1.0))
unique_weights = {
task: float(task_weights.get(task, 1.0))
for task in unique_tasks
}
rounded_values = {round(val, 12) for val in unique_weights.values()}
if len(rounded_values) != 1:
raise RuntimeError(
"Mixed-task batch encountered with non-uniform task_loss_weights. "
f"batch_tasks={unique_tasks} weights={unique_weights}. "
"Use batch_size=1 / online mode, or set equal weights for mixed-task batches."
)
return float(next(iter(unique_weights.values())))
def set_seed(seed):
import random
import numpy as np
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def setup_distributed(local_rank):
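    """Initialize torch.distributed (NCCL) when launched with a local rank.

    Returns (device, distributed, rank, world_size); single-process runs pass
    local_rank == -1 and fall back to a single CUDA/CPU device.
    """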
if local_rank == -1:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
return device, False, 0, 1
dist.init_process_group(backend="nccl", timeout=timedelta(seconds=1800))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
rank = dist.get_rank()
world_size = dist.get_world_size()
return device, True, rank, world_size
def is_main_process(distributed, rank):
return (not distributed) or (rank == 0)
def load_frozen_encoder(config_path, ckpt_path, model_type, device):
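    """Build a frozen StreamPETR/TopoMLP encoder from an mmdet3d config + checkpoint.

    Returns None when either path is unset; otherwise the model is loaded onto
    `device`, switched to eval mode, and all parameters have grads disabled.
    """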
if config_path is None or ckpt_path is None:
return None
try:
from mmcv import Config
from mmdet3d.models import build_model
from mmcv.runner import load_checkpoint
except ImportError:
raise RuntimeError(
f"mmcv/mmdet3d not installed but --{model_type}_config and "
f"--{model_type}_ckpt were explicitly provided. "
f"Install mmcv/mmdet3d or remove these arguments to train without {model_type}."
)
if model_type == "streampetr":
sp_root = str(Path(__file__).resolve().parent / "external" / "StreamPETR")
if sp_root not in sys.path:
sys.path.insert(0, sp_root)
try:
import projects.mmdet3d_plugin # noqa: F401
except ImportError:
raise RuntimeError(
f"StreamPETR plugin not found under {sp_root}/projects/mmdet3d_plugin. "
f"Ensure the submodule is checked out, or remove --streampetr_config/--streampetr_ckpt."
)
elif model_type == "topomlp":
tp_root = str(Path(__file__).resolve().parent / "external" / "TopoMLP_Repo")
if tp_root not in sys.path:
sys.path.insert(0, tp_root)
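        # mmcv registries raise on duplicate module names; temporarily force
        # registration while importing the TopoMLP plugin so repeated
        # registrations cannot abort the import (original hook restored below).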
try:
os.environ["ATLAS_TOPOMLP_MODELS_ONLY"] = "1"
from mmcv.utils import registry as _reg
_orig = _reg.Registry._register_module
def _tolerant_register(self, module, module_name=None, force=False):
return _orig(self, module, module_name=module_name, force=True)
_reg.Registry._register_module = _tolerant_register
import projects.topomlp # noqa: F401
_reg.Registry._register_module = _orig
except ImportError:
raise RuntimeError(
f"TopoMLP plugin not found under {tp_root}/projects/topomlp. "
f"Ensure the submodule is checked out, or remove --topomlp_config/--topomlp_ckpt."
)
cfg = Config.fromfile(config_path)
model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
load_checkpoint(model, ckpt_path, map_location="cpu")
model.eval()
model.to(device)
for param in model.parameters():
param.requires_grad_(False)
logger.info("Loaded frozen %s from %s", model_type, ckpt_path)
return model
def build_img_metas_streampetr(batch, device, idx):
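    """Build a minimal per-sample img_metas dict for StreamPETR (shapes, scene token, lidar2img)."""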
N = batch["pixel_values_det"].shape[1]
fH, fW = 800, 1600
scene_ids = batch.get("scene_id", ["__atlas__"] * (idx + 1))
meta = {
"pad_shape": [(fH, fW, 3)] * N,
"img_shape": [(fH, fW, 3)] * N,
"scene_token": scene_ids[idx] if idx < len(scene_ids) else "__atlas__",
}
if "lidar2img_det" in batch:
meta["lidar2img"] = batch["lidar2img_det"][idx].cpu().numpy()
return meta
def build_img_metas_topomlp(batch, device, idx):
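    """Build a minimal per-sample img_metas dict for TopoMLP (shapes, scale factor, lidar2img)."""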
meta = {}
if "lidar2img_map" in batch:
meta["lidar2img"] = batch["lidar2img_map"][idx].cpu().numpy()
tH, tW = 800, 1600
N = batch["pixel_values_map"].shape[1]
meta["img_shape"] = tuple([(tH, tW, 3)] * N)
meta["pad_shape"] = tuple([(tH, tW, 3)] * N)
meta["scale_factor"] = 1.0
meta["te_yolov8"] = None
return meta
@torch.no_grad()
def run_streampetr_forward(model, imgs, img_metas, batch, device, prev_exists=None):
B, N = imgs.shape[:2]
img_feats = model.extract_img_feat(imgs, 1)
data = {
"img": imgs,
"img_feats": img_feats,
"prev_exists": prev_exists if prev_exists is not None else imgs.new_zeros(B),
}
if "intrinsics_det" in batch:
K3 = batch["intrinsics_det"].to(device)
K4 = torch.zeros(B, N, 4, 4, device=device, dtype=K3.dtype)
K4[:, :, :3, :3] = K3
K4[:, :, 3, 3] = 1.0
data["intrinsics"] = K4
else:
data["intrinsics"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()
if "lidar2img_det" in batch:
data["lidar2img"] = batch["lidar2img_det"].to(device)
else:
data["lidar2img"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()
if "ego_pose" in batch and batch["ego_pose"] is not None:
data["ego_pose"] = batch["ego_pose"].to(device)
else:
data["ego_pose"] = torch.eye(4, device=device).unsqueeze(0).expand(B, -1, -1).contiguous()
if "ego_pose_inv" in batch and batch["ego_pose_inv"] is not None:
data["ego_pose_inv"] = batch["ego_pose_inv"].to(device)
else:
data["ego_pose_inv"] = torch.inverse(data["ego_pose"])
if "timestamp" in batch and batch["timestamp"] is not None:
data["timestamp"] = batch["timestamp"].to(device)
else:
data["timestamp"] = torch.zeros(B, device=device)
location = model.prepare_location(img_metas, **data)
outs_roi = model.forward_roi_head(location, **data)
topk_indexes = outs_roi["topk_indexes"]
outs = model.pts_bbox_head(location, img_metas, topk_indexes, **data)
return outs
@torch.no_grad()
def run_topomlp_forward(model, imgs, img_metas):
return model.simple_forward(imgs, img_metas)
def _reconstruct_topomlp_outs(saved: dict, device, dtype):
"""Convert precomputed .pt dict back to the format adapter.forward() expects."""
def _restore(t):
return t.to(device=device, dtype=dtype).unsqueeze(0)
return {
"lc_outs_dec_list": [_restore(saved["lc_outs_dec"])],
"all_lc_cls_scores_list": [_restore(saved["lc_cls_scores"])],
"all_lc_preds_list": [_restore(saved["lc_preds"])],
"lc_outs_dec_one2many_list": [_restore(saved["lc_outs_dec_o2m"])],
"all_lc_cls_scores_one2many_list": [_restore(saved["lc_cls_scores_o2m"])],
"all_lc_preds_one2many_list": [_restore(saved["lc_preds_o2m"])],
}
def extract_visual_tokens(
streampetr_model,
topomlp_model,
topomlp_adapter,
batch,
device,
num_det_queries=256,
visual_hidden_size=256,
query_token_id=None,
visual_token_mode="online",
streaming_state=None,
):
"""Extract det + map visual tokens.
In online mode with streaming_state, StreamPETR temporal memory is managed
per-scene and duplicate physical frames are protected: if the current
sample_id equals the previous one, we reuse cached det tokens and skip the
StreamPETR forward to avoid pushing the same frame into memory twice.
"""
B = batch["pixel_values_det"].shape[0]
vis: Dict[str, torch.Tensor] = {}
needs_map = False
if query_token_id is not None and "input_ids" in batch:
n_queries = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item())
needs_map = n_queries > num_det_queries
# ---- Detection tokens ----
if visual_token_mode == "offline" and "precomputed_det" in batch and "precomputed_det_ref" in batch:
vis["detection"] = batch["precomputed_det"].to(device)
vis["detection_ref_points"] = batch["precomputed_det_ref"].to(device)
elif visual_token_mode == "offline":
raise RuntimeError(
"visual_token_mode=offline but detection precomputed tokens are missing "
"for the current batch. Refusing to zero-fill."
)
elif streampetr_model is not None:
if B != 1 and streaming_state is not None:
raise RuntimeError("online temporal det requires batch_size=1")
current_sample_id = batch.get("sample_id", [None])[0]
current_scene = batch.get("scene_id", ["__atlas__"])[0]
reuse_cache = False
if streaming_state is not None:
prev_scene = streaming_state.get("prev_scene_token")
prev_sample_id = streaming_state.get("prev_sample_id")
ts_tensor = batch.get("timestamp")
current_ts = float(ts_tensor[0].item()) if ts_tensor is not None else None
prev_ts = streaming_state.get("prev_timestamp")
is_new_segment = (
prev_scene is None
or current_scene != prev_scene
or (current_ts is not None and prev_ts is not None and current_ts <= prev_ts)
)
if current_sample_id is not None and current_sample_id == prev_sample_id:
cached = streaming_state.get("cached_det")
if cached is not None:
reuse_cache = True
vis["detection"] = cached["detection"]
vis["detection_ref_points"] = cached["detection_ref_points"]
if not reuse_cache:
if is_new_segment:
streampetr_model.pts_bbox_head.reset_memory()
prev_exists_val = 0.0 if is_new_segment else 1.0
imgs_det = batch["pixel_values_det"].to(device)
prev_exists = imgs_det.new_full((B,), prev_exists_val)
img_metas = [build_img_metas_streampetr(batch, device, b) for b in range(B)]
run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device, prev_exists=prev_exists)
ego_pose_for_ref = batch.get("ego_pose")
if ego_pose_for_ref is not None:
ego_pose_for_ref = ego_pose_for_ref.to(device)
det_out = extract_streampetr_topk_tokens(
streampetr_model.pts_bbox_head,
topk=num_det_queries,
ego_pose=ego_pose_for_ref,
)
vis["detection"] = det_out["detection"]
vis["detection_ref_points"] = det_out["detection_ref_points"]
streaming_state["cached_det"] = {
"detection": vis["detection"],
"detection_ref_points": vis["detection_ref_points"],
}
streaming_state["prev_scene_token"] = current_scene
streaming_state["prev_sample_id"] = current_sample_id
if batch.get("timestamp") is not None:
streaming_state["prev_timestamp"] = float(batch["timestamp"][0].item())
else:
imgs_det = batch["pixel_values_det"].to(device)
img_metas = [build_img_metas_streampetr(batch, device, b) for b in range(B)]
run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device)
ego_pose_for_ref = batch.get("ego_pose")
if ego_pose_for_ref is not None:
ego_pose_for_ref = ego_pose_for_ref.to(device)
det_out = extract_streampetr_topk_tokens(
streampetr_model.pts_bbox_head,
topk=num_det_queries,
ego_pose=ego_pose_for_ref,
)
vis["detection"] = det_out["detection"]
vis["detection_ref_points"] = det_out["detection_ref_points"]
elif visual_token_mode == "online":
raise RuntimeError(
"visual_token_mode=online but StreamPETR model is None. "
"Provide --streampetr_config and --streampetr_ckpt."
)
else:
vis["detection"] = torch.zeros(B, num_det_queries, visual_hidden_size, device=device)
vis["detection_ref_points"] = torch.zeros(B, num_det_queries, 3, device=device)
# ---- Map tokens ----
num_map_queries = num_det_queries
    if topomlp_adapter is not None:
        num_map_queries = topomlp_adapter.num_map_tokens
        # Cast everything fed into the adapter to its own parameter/buffer dtype.
        _params = list(topomlp_adapter.parameters())
        _bufs = list(topomlp_adapter.buffers())
        adapter_dtype = _params[0].dtype if _params else (_bufs[0].dtype if _bufs else torch.float32)
map_filled = False
if visual_token_mode == "offline" and needs_map and "precomputed_map" in batch:
if B == 1:
outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)
else:
per_sample = [_reconstruct_topomlp_outs(batch["precomputed_map"][b], device, adapter_dtype) for b in range(B)]
outs = {}
for k in per_sample[0]:
outs[k] = [torch.cat([s[k][i] for s in per_sample], dim=0) for i in range(len(per_sample[0][k]))]
map_out = topomlp_adapter(outs)
vis["map"] = map_out["map"]
vis["map_ref_points"] = map_out["map_ref_points"]
map_filled = True
elif visual_token_mode == "offline" and needs_map:
raise RuntimeError(
"visual_token_mode=offline but map precomputed tokens are missing "
"for a batch that requires map queries. Refusing to zero-fill."
)
elif needs_map and topomlp_model is not None:
imgs_map = batch["pixel_values_map"].to(device)
img_metas = [build_img_metas_topomlp(batch, device, b) for b in range(B)]
outs = run_topomlp_forward(topomlp_model, imgs_map, img_metas)
for k, v in outs.items():
if isinstance(v, torch.Tensor):
outs[k] = v.to(adapter_dtype)
elif isinstance(v, list):
outs[k] = [x.to(adapter_dtype) if isinstance(x, torch.Tensor) else x for x in v]
map_out = topomlp_adapter(outs)
vis["map"] = map_out["map"]
vis["map_ref_points"] = map_out["map_ref_points"]
map_filled = True
elif needs_map and visual_token_mode == "online":
raise RuntimeError(
"visual_token_mode=online but TopoMLP model is None. "
"Provide --topomlp_config and --topomlp_ckpt."
)
if not map_filled:
vis["map"] = torch.zeros(B, num_map_queries, visual_hidden_size, device=device)
vis["map_ref_points"] = torch.zeros(B, num_map_queries, 3, device=device)
return vis
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.0):
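    """Linear warmup followed by cosine decay, floored at min_lr_ratio of the base LR."""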
def lr_lambda(step):
if step < num_warmup_steps:
return float(step) / float(max(1, num_warmup_steps))
progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
def _optimizer_steps_per_epoch(num_batches: int, grad_accum_steps: int) -> int:
if num_batches <= 0:
return 0
return int(math.ceil(float(num_batches) / float(max(1, grad_accum_steps))))
def _accum_window_size_for_batch(
batch_idx: int,
num_batches: int,
grad_accum_steps: int,
) -> int:
"""Return the effective accumulation window size for this batch.
Full windows use `grad_accum_steps`. The final partial window uses the
remainder so that tail batches are not under-scaled when we flush them.
"""
grad_accum_steps = max(1, int(grad_accum_steps))
remainder = int(num_batches % grad_accum_steps)
tail_start = int(num_batches - remainder)
if remainder > 0 and batch_idx >= tail_start:
return remainder
return grad_accum_steps
def _is_optimizer_step_batch(
batch_idx: int,
num_batches: int,
grad_accum_steps: int,
) -> bool:
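    """Return True when this batch ends an accumulation window or is the epoch's last batch."""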
grad_accum_steps = max(1, int(grad_accum_steps))
natural_boundary = ((batch_idx + 1) % grad_accum_steps) == 0
is_last_batch = (batch_idx + 1) == num_batches
return natural_boundary or is_last_batch
def save_checkpoint(path, atlas, adapter, optimizer, scheduler, global_step, epoch, args):
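    """Write Atlas (and adapter) weights on CPU plus optimizer/scheduler state and run args to `path`."""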
save_dict = {
"global_step": global_step,
"epoch": epoch,
"args": vars(args),
"atlas_state_dict": {k: v.cpu() for k, v in atlas.state_dict().items()},
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict() if scheduler is not None else None,
}
if adapter is not None:
save_dict["adapter_state_dict"] = {k: v.cpu() for k, v in adapter.state_dict().items()}
Path(path).parent.mkdir(parents=True, exist_ok=True)
torch.save(save_dict, path)
def cleanup_old_checkpoints(output_dir: Path, keep_n: int):
"""Delete old epoch-* checkpoint dirs, keeping only the most recent *keep_n*."""
if keep_n <= 0:
return
import shutil
epoch_dirs = sorted(
[d for d in output_dir.iterdir() if d.is_dir() and d.name.startswith("epoch-")],
key=lambda d: int(d.name.split("-")[1]),
)
while len(epoch_dirs) > keep_n:
old = epoch_dirs.pop(0)
shutil.rmtree(old, ignore_errors=True)
logger.info("Deleted old checkpoint: %s", old)
def main():
args = parse_args()
class _FlushHandler(logging.StreamHandler):
def emit(self, record):
super().emit(record)
self.flush()
logging.root.handlers.clear()
_h = _FlushHandler(sys.stderr)
_fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
_h.setFormatter(_fmt)
logging.root.addHandler(_h)
logging.root.setLevel(logging.INFO)
_validate_visual_token_mode(args)
device, distributed, rank, world_size = setup_distributed(args.local_rank)
set_seed(args.seed + rank)
_main = is_main_process(distributed, rank)
is_online = args.visual_token_mode == "online"
output_dir = Path(args.output_dir)
if _main:
output_dir.mkdir(parents=True, exist_ok=True)
_fh = logging.FileHandler(str(output_dir / "train.log"), mode="a")
_fh.setFormatter(_fmt)
logging.root.addHandler(_fh)
with open(output_dir / "args.json", "w") as f:
json.dump(vars(args), f, indent=2)
if _main:
logger.info("Loading tokenizer: %s", args.llm_model)
tokenizer = load_tokenizer(args.llm_model)
if "<query>" not in tokenizer.get_vocab():
tokenizer.add_tokens(["<query>"])
_precomp_det = args.precomputed_det_tokens if not is_online else None
_precomp_map = args.precomputed_map_tokens if not is_online else None
dataset = AtlasDataset(
json_file=args.data_json,
image_root=args.data_root,
tokenizer=tokenizer,
max_length=args.max_length,
is_training=True,
planning_table3_mode=args.planning_table3_mode,
image_path_remap=args.image_path_remap,
precomputed_det_tokens=_precomp_det,
precomputed_map_tokens=_precomp_map,
openlane_root=args.openlane_root,
)
if is_online:
if args.task_balance_mode == "scene_unit_111":
scene_unit_groups = dataset.get_scene_unit_groups()
scene_unit_summary = dataset.get_scene_unit_summary()
sampler = SceneUnitTaskBalancedSampler(
scene_unit_groups,
num_replicas=world_size,
rank=rank,
seed=args.seed,
pad_to_multiple=args.gradient_accumulation_steps,
)
if _main:
sampler_summary = sampler.get_summary()
logger.info(
"Online mode: SceneUnitTaskBalancedSampler "
"(balanced_scenes=%d, dataset_samples=%d, world=%d, skipped_scenes=%d)",
int(sampler_summary.get("num_scenes", 0)),
len(dataset),
world_size,
int(sampler_summary.get("skipped_scenes", 0)),
)
logger.info(
"Scene-unit summary: total_scenes=%d units_total=%d all_three=%d",
int(scene_unit_summary.get("scenes_total", 0)),
int(scene_unit_summary.get("units_total", 0)),
int(scene_unit_summary.get("units_with_all_three", 0)),
)
logger.info(
"Balanced sampler unit counts: %s",
_format_task_counts(sampler_summary.get("unit_counts", {})),
)
logger.info(
"Balanced sampler raw counts: %s",
_format_task_counts(sampler_summary.get("raw_counts", {})),
)
else:
scene_groups = dataset.get_scene_groups()
sampler = SceneSequentialSampler(
scene_groups,
num_replicas=world_size,
rank=rank,
seed=args.seed,
pad_to_multiple=args.gradient_accumulation_steps,
)
if _main:
logger.info("Online mode: SceneSequentialSampler (%d scenes, %d samples, world=%d)",
len(scene_groups), len(dataset), world_size)
else:
from torch.utils.data import DistributedSampler
sampler = DistributedSampler(dataset, shuffle=True) if distributed else None
collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=(not is_online and sampler is None),
sampler=sampler,
num_workers=args.num_workers,
collate_fn=collate_fn,
pin_memory=True,
drop_last=not is_online,
)
streampetr_model = load_frozen_encoder(
args.streampetr_config, args.streampetr_ckpt, "streampetr", device
)
topomlp_model = load_frozen_encoder(
args.topomlp_config, args.topomlp_ckpt, "topomlp", device
)
topomlp_adapter = None
if topomlp_model is not None or _precomp_map:
_tp_bev_range = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0)
if args.topomlp_config:
try:
from mmcv import Config as _Cfg
_tp_cfg = _Cfg.fromfile(args.topomlp_config)
if hasattr(_tp_cfg, "point_cloud_range"):
_tp_bev_range = tuple(float(v) for v in _tp_cfg.point_cloud_range)
logger.info("TopoMLP bev_range from config: %s", _tp_bev_range)
except Exception as e:
logger.warning("Failed to read point_cloud_range from TopoMLP config: %s. Using default: %s", e, _tp_bev_range)
topomlp_adapter = TopoMLPToAtlasMapTokens(
num_map_tokens=args.num_map_queries,
hidden_size=args.visual_hidden_size,
bev_range=_tp_bev_range,
).to(device)
dtype = torch.float32
if args.bf16:
dtype = torch.bfloat16
elif args.fp16:
dtype = torch.float16
if args.load_in_4bit:
dm = {"": device} if distributed else "auto"
else:
dm = None
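    # Peek at the DeepSpeed config: precision enabled only there (bf16/fp16)
    # should still select a half-precision torch_dtype and flash attention.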
_ds_bf16 = False
_ds_fp16 = False
if args.deepspeed:
with open(args.deepspeed) as _f:
_ds_cfg_peek = json.load(_f)
_ds_bf16 = _ds_cfg_peek.get("bf16", {}).get("enabled", False)
_ds_fp16 = _ds_cfg_peek.get("fp16", {}).get("enabled", False)
_use_half = args.bf16 or args.fp16 or _ds_bf16 or _ds_fp16
if _use_half and dtype == torch.float32:
dtype = torch.bfloat16 if (args.bf16 or _ds_bf16) else torch.float16
atlas = AtlasForCausalLM(
llm_model_name=args.llm_model,
visual_hidden_size=args.visual_hidden_size,
num_queries=args.num_det_queries,
num_map_queries=args.num_map_queries,
load_in_4bit=args.load_in_4bit,
use_flash_attention=_use_half,
device_map=dm,
torch_dtype=dtype,
use_lora=args.use_lora,
lora_r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
)
atlas.resize_token_embeddings(len(tokenizer))
query_token_id = tokenizer.convert_tokens_to_ids("<query>")
atlas.set_query_token_id(query_token_id)
if topomlp_adapter is not None:
atlas.topomlp_adapter = topomlp_adapter
if dm is None and args.deepspeed is None:
atlas = atlas.to(device)
atlas.gradient_checkpointing_enable()
num_batches_per_epoch = len(dataloader)
steps_per_epoch = _optimizer_steps_per_epoch(
num_batches_per_epoch, args.gradient_accumulation_steps
)
total_steps = steps_per_epoch * args.epochs
warmup_steps = int(total_steps * args.warmup_ratio)
global_step = 0
start_epoch = 0
_resume_ckpt = None
if args.resume:
_resume_ckpt = torch.load(args.resume, map_location="cpu")
if "atlas_state_dict" not in _resume_ckpt:
raise RuntimeError(f"Checkpoint missing 'atlas_state_dict'. Keys: {list(_resume_ckpt.keys())}")
missing, _ = atlas.load_state_dict(_resume_ckpt["atlas_state_dict"], strict=False)
if _main and missing:
logger.warning("Resume: %d missing keys (first 10): %s", len(missing), missing[:10])
if topomlp_adapter is not None and "adapter_state_dict" in _resume_ckpt:
_m, _u = topomlp_adapter.load_state_dict(_resume_ckpt["adapter_state_dict"], strict=False)
if _main and _u:
logger.info("Adapter resume: ignored %d legacy keys: %s", len(_u), _u[:5])
global_step = _resume_ckpt.get("global_step", 0)
start_epoch = _resume_ckpt.get("epoch", 0)
if _main:
logger.info("Resumed from %s (step=%d, epoch=%d)", args.resume, global_step, start_epoch)
use_deepspeed = args.deepspeed is not None
if use_deepspeed:
import deepspeed
        with open(args.deepspeed) as _f:
            ds_config = json.load(_f)
ds_config["optimizer"] = {
"type": "Adam",
"params": {
"lr": args.lr, "weight_decay": args.weight_decay,
"betas": [0.9, 0.999], "torch_adam": True, "adam_w_mode": True,
},
}
ds_config["scheduler"] = {
"type": "WarmupCosineLR",
"params": {
"total_num_steps": total_steps,
"warmup_num_steps": warmup_steps,
"warmup_type": "linear",
},
}
ds_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
ds_config["train_micro_batch_size_per_gpu"] = args.batch_size
ds_config["train_batch_size"] = args.batch_size * args.gradient_accumulation_steps * world_size
ds_bf16 = ds_config.get("bf16", {}).get("enabled", False)
ds_fp16 = ds_config.get("fp16", {}).get("enabled", False)
if ds_bf16:
atlas.to(device=device, dtype=torch.bfloat16)
elif ds_fp16:
atlas.to(device=device, dtype=torch.float16)
else:
atlas.to(device)
all_params = atlas.get_trainable_param_groups(args.lr, weight_decay=args.weight_decay)
if topomlp_adapter is not None:
_adapter_trainable = [p for p in topomlp_adapter.parameters() if p.requires_grad]
if _adapter_trainable:
all_params.append({"params": _adapter_trainable, "lr": args.lr, "weight_decay": 0.0})
atlas_ddp, optimizer, _, scheduler = deepspeed.initialize(
model=atlas, model_parameters=all_params,
config=ds_config, dist_init_required=False,
)
if _resume_ckpt is not None and "optimizer" in _resume_ckpt:
try:
optimizer.load_state_dict(_resume_ckpt["optimizer"])
if _main:
logger.info("Restored DeepSpeed optimizer state from checkpoint")
except Exception as e:
if _main:
logger.warning("Failed to restore DeepSpeed optimizer state: %s", e)
if global_step > 0 and scheduler is not None:
for _ in range(global_step):
scheduler.step()
_ff_lr = scheduler.get_lr()
if _main:
logger.info(
"Fast-forwarded DeepSpeed LR scheduler to step %d (lr=%s)",
global_step,
[f"{x:.6e}" for x in _ff_lr] if isinstance(_ff_lr, (list, tuple)) else f"{_ff_lr:.6e}",
)
else:
param_groups = atlas.get_trainable_param_groups(args.lr, weight_decay=args.weight_decay)
if topomlp_adapter is not None:
_adapter_trainable = [p for p in topomlp_adapter.parameters() if p.requires_grad]
if _adapter_trainable:
param_groups.append({"params": _adapter_trainable, "lr": args.lr, "weight_decay": 0.0})
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
if distributed:
atlas_ddp = torch.nn.parallel.DistributedDataParallel(
atlas, device_ids=[args.local_rank], find_unused_parameters=True,
)
else:
atlas_ddp = atlas
if _resume_ckpt is not None and not use_deepspeed:
if "optimizer" in _resume_ckpt:
try:
optimizer.load_state_dict(_resume_ckpt["optimizer"])
except Exception as e:
if _main:
logger.warning("Failed to restore optimizer state: %s", e)
if "scheduler" in _resume_ckpt and _resume_ckpt["scheduler"] is not None:
try:
scheduler.load_state_dict(_resume_ckpt["scheduler"])
except Exception as e:
if _main:
logger.warning("Failed to restore scheduler state: %s", e)
_resume_ckpt = None
atlas_ddp.train()
if topomlp_adapter is not None:
topomlp_adapter.train()
if _main:
logger.info("=== Training Config ===")
logger.info(" epochs: %d, lr: %s, batch: %d, accum: %d",
args.epochs, args.lr, args.batch_size, args.gradient_accumulation_steps)
logger.info(" total_steps: %d, warmup_steps: %d", total_steps, warmup_steps)
_effective_dtype = next(atlas.parameters()).dtype
logger.info(" use_lora: %s, load_in_4bit: %s, fp16: %s, bf16: %s, deepspeed: %s, effective_dtype: %s",
args.use_lora, args.load_in_4bit, args.fp16, args.bf16, use_deepspeed, _effective_dtype)
n_trainable = sum(p.numel() for p in atlas.parameters() if p.requires_grad)
if topomlp_adapter is not None:
n_trainable += sum(p.numel() for p in topomlp_adapter.parameters() if p.requires_grad)
logger.info(" trainable params: %s", f"{n_trainable:,}")
logger.info(" visual_token_mode: %s", args.visual_token_mode)
logger.info(" task_balance_mode: %s", args.task_balance_mode)
logger.info(" task_loss_weights: %s", _format_task_loss_weights(args.task_loss_weights))
logger.info(" streampetr: %s", "online-temporal" if (is_online and streampetr_model) else ("loaded" if streampetr_model else ("precomputed" if _precomp_det else "NONE (should not happen)")))
logger.info(" topomlp: %s", "online" if (is_online and topomlp_model) else ("loaded" if topomlp_model else ("precomputed" if _precomp_map else "NONE (should not happen)")))
logger.info("=======================")
streaming_state = {} if is_online else None
for epoch in range(start_epoch, args.epochs):
if sampler is not None:
sampler.set_epoch(epoch)
if is_online and args.task_balance_mode == "scene_unit_111" and isinstance(sampler, SceneUnitTaskBalancedSampler):
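            # len() triggers the sampler to build this epoch's balanced index
            # plan so the per-epoch stats read below describe the epoch ahead.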
_ = len(sampler)
epoch_stats = sampler.get_last_epoch_stats()
if _main:
logger.info(
"Epoch %d scene-unit balance: unit_counts=%s",
epoch,
_format_task_counts(epoch_stats.get("unit_counts", {})),
)
logger.info(
"Epoch %d scene-unit raw counts: %s",
epoch,
_format_task_counts(epoch_stats.get("raw_counts", {})),
)
logger.info(
"Epoch %d sampler padding: scenes_rank=%d/%d skipped=%d prepad=%d target=%d replay_extra=%d",
epoch,
int(epoch_stats.get("num_scenes_rank", 0)),
int(epoch_stats.get("num_scenes_total", 0)),
int(epoch_stats.get("num_skipped_scenes", 0)),
int(epoch_stats.get("prepad_len", 0)),
int(epoch_stats.get("target_len", 0)),
int(epoch_stats.get("replay_extra", 0)),
)
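        # Fresh epoch: drop cached det tokens and clear StreamPETR temporal
        # memory so the first frame of the epoch starts a new segment.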
if streaming_state is not None:
streaming_state.clear()
if streampetr_model is not None:
streampetr_model.pts_bbox_head.reset_memory()
epoch_loss = 0.0
epoch_raw_loss = 0.0
num_batches = 0
t0 = time.time()
task_raw_loss_ema: Dict[str, float] = {}
if not use_deepspeed:
optimizer.zero_grad()
for batch_idx, batch in enumerate(dataloader):
do_step = _is_optimizer_step_batch(
batch_idx, num_batches_per_epoch, args.gradient_accumulation_steps
)
accum_window_size = _accum_window_size_for_batch(
batch_idx, num_batches_per_epoch, args.gradient_accumulation_steps
)
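            # Divide each micro-batch loss by the real window size so the
            # partial window at the epoch tail is not under-weighted.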
scaled_loss = None
if use_deepspeed:
if not hasattr(atlas_ddp, "set_gradient_accumulation_boundary"):
raise RuntimeError(
"DeepSpeed engine is missing set_gradient_accumulation_boundary(); "
"cannot enforce epoch-tail flush semantics."
)
atlas_ddp.set_gradient_accumulation_boundary(do_step)
if _main and batch_idx < 5:
                _nq = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item()) if query_token_id is not None else 0
_needs_map = _nq > args.num_det_queries
logger.info("[DBG] batch_idx=%d nq=%d needs_map=%s sid=%s mode=%s",
batch_idx, _nq, _needs_map,
batch.get("sample_id", ["?"])[0][:20],
args.visual_token_mode)
for _handler in logging.root.handlers:
_handler.flush()
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
visual_features = extract_visual_tokens(
streampetr_model, topomlp_model, topomlp_adapter,
batch, device, args.num_det_queries, args.visual_hidden_size,
query_token_id=query_token_id,
visual_token_mode=args.visual_token_mode,
streaming_state=streaming_state,
)
if _main and batch_idx < 5:
logger.info("[DBG] vis_keys=%s", list(visual_features.keys()))
for _handler in logging.root.handlers:
_handler.flush()
if _main and batch_idx < 5:
logger.info("[DBG] pre-forward batch_idx=%d seqlen=%d", batch_idx, input_ids.shape[1])
for _handler in logging.root.handlers:
_handler.flush()
outputs = atlas_ddp(
input_ids=input_ids,
attention_mask=attention_mask,
visual_features=visual_features,
labels=labels,
)
raw_loss = outputs.loss
batch_task_types = [str(task) for task in batch.get("task_type", [])]
task_loss_weight = _resolve_batch_task_weight(batch_task_types, args.task_loss_weights)
loss = raw_loss * task_loss_weight
batch_task = batch_task_types[0].strip().lower() if batch_task_types else ""
if batch_task in ("detection", "planning", "caption"):
_update_task_raw_loss_ema(task_raw_loss_ema, batch_task, raw_loss.item())
if _main and batch_idx < 5:
logger.info(
"[DBG] post-forward batch_idx=%d raw_loss=%.4f weighted_loss=%.4f task=%s weight=%.4f %s",
batch_idx,
raw_loss.item(),
loss.item(),
batch_task_types[0] if batch_task_types else "?",
task_loss_weight,
_format_task_raw_loss_ema(task_raw_loss_ema),
)
for _handler in logging.root.handlers:
_handler.flush()
if _main and batch_idx < 5:
logger.info("[DBG] pre-backward batch_idx=%d", batch_idx)
for _handler in logging.root.handlers:
_handler.flush()
if use_deepspeed:
scaled_loss = loss / accum_window_size
atlas_ddp.backward(scaled_loss, scale_wrt_gas=False)
if _main and batch_idx < 5:
logger.info("[DBG] pre-step batch_idx=%d", batch_idx)
for _handler in logging.root.handlers:
_handler.flush()
atlas_ddp.step()
else:
scaled_loss = loss / accum_window_size
scaled_loss.backward()
if not use_deepspeed and distributed and topomlp_adapter is not None:
for p in topomlp_adapter.parameters():
if p.requires_grad and p.grad is not None:
dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
p.grad.div_(world_size)
epoch_loss += loss.item()
epoch_raw_loss += raw_loss.item()
num_batches += 1
if _main and num_batches <= 3:
logger.info(
"batch=%d raw_loss=%.4f weighted_loss=%.4f task=%s weight=%.4f %s",
num_batches,
raw_loss.item(),
loss.item(),
batch_task_types[0] if batch_task_types else "?",
task_loss_weight,
_format_task_raw_loss_ema(task_raw_loss_ema),
)
for _handler in logging.root.handlers:
_handler.flush()
if do_step:
if not use_deepspeed:
all_params = list(atlas.parameters()) + (
list(topomlp_adapter.parameters()) if topomlp_adapter is not None else []
)
trainable = [p for p in all_params if p.requires_grad]
torch.nn.utils.clip_grad_norm_(trainable, args.max_grad_norm)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
global_step += 1
if _main and global_step % args.log_steps == 0:
# DeepSpeed LR scheduler may not expose get_last_lr() before first scheduler.step().
if use_deepspeed and hasattr(atlas_ddp, "get_lr"):
try:
_lrs = atlas_ddp.get_lr()
if isinstance(_lrs, (list, tuple)) and len(_lrs) > 0:
lr_now = float(_lrs[0])
else:
lr_now = float(_lrs)
except Exception:
lr_now = optimizer.param_groups[0]["lr"] if getattr(optimizer, "param_groups", None) else args.lr
elif hasattr(scheduler, "get_last_lr"):
try:
lr_now = scheduler.get_last_lr()[0]
except Exception:
lr_now = optimizer.param_groups[0]["lr"] if getattr(optimizer, "param_groups", None) else args.lr
else:
lr_now = args.lr
elapsed = time.time() - t0
samples_sec = num_batches * args.batch_size / max(elapsed, 1e-6)
avg_loss = epoch_loss / max(num_batches, 1)
avg_raw_loss = epoch_raw_loss / max(num_batches, 1)
logger.info(
"epoch=%d step=%d raw_loss=%.4f weighted_loss=%.4f lr=%.2e samples/s=%.1f %s",
epoch, global_step, avg_raw_loss, avg_loss, lr_now, samples_sec,
_format_task_raw_loss_ema(task_raw_loss_ema),
)
for _handler in logging.root.handlers:
_handler.flush()
if _main and args.save_steps > 0 and global_step % args.save_steps == 0:
ckpt_path = output_dir / f"checkpoint-{global_step}" / "checkpoint.pt"
save_checkpoint(ckpt_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, epoch, args)
logger.info("Saved step checkpoint: %s", ckpt_path)
avg_loss = epoch_loss / max(num_batches, 1)
avg_raw_loss = epoch_raw_loss / max(num_batches, 1)
if _main:
logger.info(
"Epoch %d done — avg_raw_loss=%.4f avg_weighted_loss=%.4f (%.1f min) %s",
epoch,
avg_raw_loss,
avg_loss,
(time.time() - t0) / 60,
_format_task_raw_loss_ema(task_raw_loss_ema),
)
if _main and (epoch + 1) % args.save_epochs == 0:
ckpt_path = output_dir / f"epoch-{epoch}" / "checkpoint.pt"
save_checkpoint(ckpt_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, epoch + 1, args)
logger.info("Saved epoch checkpoint: %s", ckpt_path)
if args.keep_last_n_ckpts > 0:
cleanup_old_checkpoints(output_dir, args.keep_last_n_ckpts)
if _main:
final_path = output_dir / "final" / "checkpoint.pt"
save_checkpoint(final_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, args.epochs, args)
logger.info("Training complete. Final checkpoint: %s", final_path)
if distributed:
dist.destroy_process_group()
if __name__ == "__main__":
main()