guoyb0 commited on
Commit
9fe982a
·
verified ·
1 Parent(s): 17200d8

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. README.md +333 -0
  2. configs/REPRODUCTION.md +200 -0
  3. configs/ds_zero2.json +17 -0
  4. eval_atlas.py +1175 -0
  5. extract_streampetr_tokens.py +568 -0
  6. extract_topomlp_tokens.py +381 -0
  7. scripts/eval_checkpoint_offline.sh +44 -0
  8. scripts/gen_atlas_caption_dashscope.py +272 -0
  9. scripts/gen_atlas_caption_qa.py +274 -0
  10. scripts/gen_atlas_openlane_subsetB_lane_qa.py +251 -0
  11. scripts/gen_atlas_planning_qa.py +491 -0
  12. scripts/run_val_extraction.sh +56 -0
  13. scripts/train_no_caption_baseline.sh +50 -0
  14. scripts/train_no_caption_baseline_offline.sh +48 -0
  15. scripts/train_with_caption_balanced.sh +48 -0
  16. scripts/vis_atlas_lane_gt_pred.py +500 -0
  17. scripts/vis_atlas_planning_qualitative.py +800 -0
  18. scripts/vis_traffic_violation.py +516 -0
  19. src/__pycache__/__init__.cpython-310.pyc +0 -0
  20. src/__pycache__/prompting.cpython-310.pyc +0 -0
  21. src/__pycache__/prompting.cpython-38.pyc +0 -0
  22. src/audit/__pycache__/__init__.cpython-310.pyc +0 -0
  23. src/audit/__pycache__/audit_utils.cpython-310.pyc +0 -0
  24. src/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
  25. src/dataset/__pycache__/atlas_dataset.cpython-310.pyc +0 -0
  26. src/dataset/__pycache__/atlas_dataset.cpython-38.pyc +0 -0
  27. src/dataset/__pycache__/scene_sampler.cpython-310.pyc +0 -0
  28. src/dataset/__pycache__/scene_sampler.cpython-38.pyc +0 -0
  29. src/dataset/atlas_dataset.py +1416 -0
  30. src/dataset/scene_sampler.py +111 -0
  31. src/eval/__pycache__/__init__.cpython-310.pyc +0 -0
  32. src/eval/__pycache__/__init__.cpython-38.pyc +0 -0
  33. src/eval/__pycache__/metrics.cpython-310.pyc +0 -0
  34. src/eval/__pycache__/metrics.cpython-38.pyc +0 -0
  35. src/eval/metrics.py +852 -0
  36. src/model/__init__.py +28 -0
  37. src/model/__pycache__/__init__.cpython-310.pyc +0 -0
  38. src/model/__pycache__/__init__.cpython-38.pyc +0 -0
  39. src/model/__pycache__/configuration_atlas.cpython-310.pyc +0 -0
  40. src/model/__pycache__/modeling_atlas.cpython-310.pyc +0 -0
  41. src/model/__pycache__/modeling_atlas.cpython-38.pyc +0 -0
  42. src/model/__pycache__/streampetr_adapter.cpython-310.pyc +0 -0
  43. src/model/__pycache__/streampetr_adapter.cpython-38.pyc +0 -0
  44. src/model/__pycache__/topomlp_adapter.cpython-310.pyc +0 -0
  45. src/model/__pycache__/topomlp_adapter.cpython-38.pyc +0 -0
  46. src/model/modeling_atlas.py +549 -0
  47. src/model/streampetr_adapter.py +110 -0
  48. src/model/topomlp_adapter.py +88 -0
  49. src/prompting.py +277 -0
  50. train_atlas.py +1018 -0
README.md ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - autonomous-driving
5
+ - 3d-detection
6
+ - lane-detection
7
+ - planning
8
+ - multimodal
9
+ - vicuna
10
+ ---
11
+
12
+ # Atlas — 3D-Tokenized LLM for Autonomous Driving
13
+
14
+ 基于 [Atlas 论文](https://arxiv.org/abs/2405.18361) 的多模态自动驾驶大语言模型实现。将 **StreamPETR**(3D 目标检测)和 **TopoMLP**(车道线检测)提取的 3D visual tokens 注入 **Vicuna-7B** LLM,实现检测、车道线、规划等多任务统一生成。
15
+
16
+ ## 项目结构
17
+
18
+ ```
19
+ 3dtokenizer-atlas/
20
+ ├── train_atlas.py # Atlas LLM 训练入口
21
+ ├── eval_atlas.py # Atlas 评估入口
22
+ ├── extract_streampetr_tokens.py # 预提取 StreamPETR detection tokens
23
+ ├── extract_topomlp_tokens.py # 预提取 TopoMLP lane tokens
24
+ ├── train_streampetr.sh # StreamPETR 预训练启动脚本
25
+ ├── train_topomlp.sh # TopoMLP 预训练启动脚本
26
+
27
+ ├── configs/
28
+ │ ├── streampetr_atlas_aligned.py # StreamPETR 配置 (EVA-02 ViT-L, 800x1600)
29
+ │ ├── topomlp_atlas_aligned.py # TopoMLP 配置 (EVA-02 ViT-L, 800x1600)
30
+ │ ├── ds_zero2.json # DeepSpeed ZeRO-2 配置
31
+ │ └── REPRODUCTION.md # 复现文档
32
+
33
+ ├── src/
34
+ │ ├── model/
35
+ │ │ ├── modeling_atlas.py # AtlasForCausalLM 主模型
36
+ │ │ ├── streampetr_adapter.py # StreamPETR → 检测 token 适配器
37
+ │ │ └── topomlp_adapter.py # TopoMLP → 地图 token 适配器 (Top-K selection)
38
+ │ ├── dataset/
39
+ │ │ ├── atlas_dataset.py # AtlasDataset + Collate
40
+ │ │ └── scene_sampler.py # SceneSequentialSampler (时序采样)
41
+ │ ├── eval/
42
+ │ │ └── metrics.py # 评估指标 (F1/Chamfer/L2/Collision)
43
+ │ └── prompting.py # 多任务 Prompt 模板
44
+
45
+ ├── scripts/
46
+ │ ├── gen_atlas_full_data.py # nuScenes → 检测 QA JSON
47
+ │ ├── gen_atlas_openlane_subsetB_lane_qa.py # OpenLane-V2 → 车道线 QA JSON
48
+ │ ├── gen_atlas_planning_qa.py # nuScenes → 规划 QA JSON
49
+ │ ├── train_no_caption_baseline.sh # 无 caption 训练脚本
50
+ │ └── train_with_caption_balanced.sh # 含 caption 训练脚本
51
+
52
+ ├── data/ # 训练/验证数据 (JSON)
53
+ │ ├── atlas_nuscenes_train.json # 检测 (28,130 样本)
54
+ │ ├── atlas_nuscenes_val.json # 检测验证 (6,019 样本)
55
+ │ ├── openlane_subsetB_lane_train_4pt.json # 车道线 (27,968 样本, 4 点/lane)
56
+ │ ├── openlane_subsetB_lane_val_4pt.json # 车道线验证 (6,019 样本)
57
+ │ ├── atlas_planning_train_uniad_command.json # 规划 (23,541 样本, UniAD-style command)
58
+ │ ├── atlas_planning_val_uniad_command.json # 规划验证 (5,037 样本, UniAD-style command)
59
+ │ ├── atlas_caption_train.json # 环境描述 caption
60
+ │ └── atlas_caption_val.json # 环境描述 caption 验证
61
+
62
+ ├── pretrained/ # 预训练权重
63
+ │ ├── vicuna-7b-v1.5/ # Vicuna-7B-v1.5 LLM
64
+ │ ├── eva02_L_coco_det_sys_o365_remapped_fixed.pth
65
+ │ └── streampetr/
66
+ │ └── streampetr_eva02_ep24.pth
67
+
68
+ ├── work_dirs/
69
+ │ ├── _quick_eval_cpu.py # 快速检测评估 (CPU, micro-avg F1)
70
+ │ ├── _quick_eval_gpu.py # 快速检测评估 (GPU)
71
+ │ ├── _quick_eval_lane_gpu.py # 快速车道线评估 (GPU)
72
+ │ ├── _quick_eval_plan_gpu.py # 快速规划评估 (GPU, scene-sequential)
73
+ │ ├── precomputed_det_tokens_offline/ # 预提取的 StreamPETR tokens (offline 备选)
74
+ │ │ ├── train/ # 56,099 个 .pt 文件 (det + lane,planning 与 det 共享 ID)
75
+ │ │ └── val/ # 12,039 个 .pt 文件
76
+ │ ├── precomputed_map_tokens_offline/ # 预提取的 TopoMLP tokens (offline 备选)
77
+ │ │ ├── train/ # 51,510 个 .pt 文件 (lane + planning)
78
+ │ │ └── val/ # 11,057 个 .pt 文件
79
+ │ └── topomlp_atlas_aligned/ # TopoMLP 预训练权重
80
+ │ └── epoch_24.pth
81
+
82
+ └── external/ # 外部依赖
83
+ ├── StreamPETR/
84
+ ├── TopoMLP_Repo/
85
+ └── nuscenes-devkit/
86
+ ```
87
+
88
+ ## 模型架构
89
+
90
+ ```
91
+ ┌─────────────────────────────────────┐
92
+ 6x 环视相机图片 → │ StreamPETR (frozen, EVA-02 ViT-L) │→ det tokens [B, 256, 256]
93
+ │ TopoMLP (frozen, EVA-02 ViT-L) │→ one-to-one lane queries (300) → Top-K → map tokens [B, 256, 256]
94
+ └─────────────────────────────────────┘
95
+
96
+ AtlasUnifiedProjector
97
+ ┌────────────────────────────────┐
98
+ │ projector_det: Linear(256→4096) │ ← 单层线性投影
99
+ │ projector_map: Linear(256→4096) │
100
+ │ projector_rp: Linear(3→256) │ ← Reference Point, zero-init
101
+ │ features += projector_rp(ref) │
102
+ └────────────────────────────────┘
103
+
104
+ 注入到 <query> token 位置 (256 det + 256 map)
105
+
106
+ ┌───────────────────────────────────────┐
107
+ │ Vicuna-7B (当前运行: 全参数微调) │
108
+ │ Causal Language Modeling Loss │
109
+ └───────────────────────────────────────┘
110
+
111
+ 多任务文本输出
112
+ (3D 检测 / 车道线 / 规划轨迹)
113
+ ```
114
+
115
+ ## 训练配置(来源与优先级)
116
+
117
+ 为避免“论文描述 / 脚本默认 / 实际运行”混淆,本仓库统一按以下优先级解释配置:
118
+
119
+ 1. **论文描述**:以 Atlas 原论文(arXiv:2405.18361)正文 Section 3.2 与 Appendix B.2 为准。
120
+ 2. **代码默认**:以 `train_atlas.py` 的 argparse 默认值为准。
121
+ 3. **实际运行**:以 `work_dirs/<exp>/args.json` + `work_dirs/*train*.log` 为准(最高优先级)。
122
+
123
+ 若三者冲突,请按第 3 条解释实验结果,不以 README 中示例命令覆盖真实运行参数。
124
+
125
+ ### 论文原文(Atlas LLM,Section 3.2 + Appendix B.2)
126
+
127
+ | 项目 | 论文描述 |
128
+ |------|----------|
129
+ | Batch Size | 1 per GPU |
130
+ | Learning Rate | 2e-5 |
131
+ | Optimizer | AdamW (weight_decay=1e-4) |
132
+ | LR Schedule | 3% linear warmup + cosine decay |
133
+ | Max Length | 4096 |
134
+ | 硬件/时长 | 8x A100,约 100 小时 |
135
+ | LoRA | 论文附录未显式给出 LoRA 开关 |
136
+
137
+ ### 当前实际运行(示例:`work_dirs/atlas_no_caption_v3_linear_warmup`)
138
+
139
+ 来源:`work_dirs/atlas_no_caption_v3_linear_warmup/args.json` 与 `work_dirs/train_no_caption_v3_linear_warmup.log`
140
+
141
+ | 参数 | 实际值 |
142
+ |------|--------|
143
+ | LLM | Vicuna-7B-v1.5 |
144
+ | 微调方式 | **全参数微调** (`use_lora=false`) |
145
+ | 可训练参数 | 6,740,530,176 |
146
+ | Learning Rate | 2e-5 |
147
+ | Optimizer | AdamW (`weight_decay=1e-4`, `torch_adam`, `adam_w_mode`) |
148
+ | LR Schedule | WarmupCosineLR(warmup ratio=3%) |
149
+ | Epochs | 10 |
150
+ | Batch Size | 1 per GPU |
151
+ | Gradient Accumulation | 2 |
152
+ | Effective Batch Size | 8 (4 GPU x 1 x 2 accum) |
153
+ | Total Steps | 99,550 |
154
+ | Warmup Steps | 2,986 |
155
+ | Max Sequence Length | 4096 |
156
+ | 分布式 | DeepSpeed ZeRO-2 |
157
+ | GPU | 4x NVIDIA H100 80GB |
158
+ | 精度 | BF16(由 `configs/ds_zero2.json` 启用) |
159
+ | Visual Tokens | **在线** (live frozen StreamPETR + TopoMLP, temporal memory);离线预提取仅作为 fallback |
160
+
161
+ ### 训练数据
162
+
163
+ | 任务 | 数据文件 | 样本数 |
164
+ |------|---------|--------|
165
+ | 3D 目标检测 | `atlas_nuscenes_train.json` | 28,130 |
166
+ | 3D 车道线检测 | `openlane_subsetB_lane_train_4pt.json` | 27,968 |
167
+ | 轨迹规划 | `atlas_planning_train_uniad_command.json` | 23,541 |
168
+ | 环境描述 (可选) | `atlas_caption_train.json` | — |
169
+ | **总计 (无 caption)** | | **79,639** |
170
+
171
+ 车道线数据使用 4 个采样点/lane(论文 Appendix A.2 要求 four lane points,本仓库实现为均匀采样),不设 lane 数量上限(论文未指定上限),按 BEV 距离近→远排序。实际平均约 25 条 lane/样本,最多约 80 条。所有坐标使用 1000-bin 离散化。规划数据包含 `gt_boxes_3d_per_timestep` 字段用于 ST-P3 对齐的 per-timestep 碰撞评测。
172
+
173
+ 三类主任务的 question pool 统一采用“前 3 条按论文 Table 6 / 7 / 9 原话整理,后 2 条为仓库补充的同风格扩展模板”的策略;其中车道线 Table 7 的第 2 条按论文现有文本原样保留。
174
+
175
+ 为避免运行时再依赖 prompt 文本猜任务,四类样本的磁盘 JSON 统一显式写入 `task` 字段:`detection` / `lane` / `planning` / `caption`。这是仓库层面的工程化 schema,对论文中的 question-answer 文本格式不做额外解释。
176
+
177
+ caption 数据按论文 Appendix A.3 的单视角设定生成:Table 8 作为 GPT-4V 标注 prompt,human prompt 采用 Figure 5 风格的单模板,并注入具体 `camera_name`。
178
+
179
+ 当前仓库不再向 prompt 追加 bins-format hint;detection / lane / caption 默认以磁盘 JSON 中的 `human` 文本作为 prompt 主体语义来源。planning ��务运行时仍会按 `planning_table3_mode` 对磁盘 `human` prompt 做轻量重写,只负责插入/剥离 command 句和 ego-state 句,再统一做 `<query>` 展开、空白归一化和 `USER: ... / ASSISTANT:` 包装。
180
+
181
+ 当前 detection 的 canonical answer 格式为:`category: [x_bin, y_bin, z_bin], [x_bin, y_bin, z_bin]; category: [x_bin, y_bin, z_bin].`。当前 lane/map 的 canonical answer 格式为:`Lane: [x_bin, y_bin, z_bin], [x_bin, y_bin, z_bin], ...; [x_bin, y_bin, z_bin], ... .`。旧的 detection flat 文本和 `lane_centerline(id=...)` legacy 文本不再作为受支持协议。
182
+
183
+ planning 的 answer/output protocol 采用 Figure 5 风格表述,但保持论文 Table 9 的二维语义:`Ego car speed value:[vx_bin, vy_bin]. Ego car acceleration value:[ax_bin, ay_bin]. Based on the ego car speed and acceleration you predicted, requeset the ego car planning waypoint in 3-seconds: [x_bin, y_bin], ...`。当前实现不为 planning 引入第三维,也不使用固定 `z=500` 占位。
184
+
185
+ 当前 planning JSON 的顶层 `route_command` 采用 **UniAD-style future-GT-derived command**:根据 future planning GT / future waypoints 的最后一个有效 timestep 的横向位移离散为 `turn left` / `turn right` / `go straight`。它不是 raw nuScenes 原生字段,也不是独立导航命令;因此 `atlas_high_level*` 在本仓库中的含义更接近 UniAD 风格条件输入,而不是 Atlas 论文 Table 3 严格意义上的独立 route command。
186
+
187
+ ### 3D Tokenizer 预训练 (已完成)
188
+
189
+ | 参数 | StreamPETR | TopoMLP |
190
+ |------|-----------|---------|
191
+ | Backbone | EVA-02 ViT-L (embed_dim=1024) | EVA-02 ViT-L (embed_dim=1024) |
192
+ | Resolution | 800x1600 | 800x1600 |
193
+ | Queries | 256 (detection) | 256 (map, Top-K from 300 one-to-one queries) |
194
+ | Control Points | - | 4 per lane |
195
+ | Epochs | 24 | 24 |
196
+ | 数据集 | nuScenes trainval | OpenLane-V2 subset-B |
197
+
198
+ ## 快速开始
199
+
200
+ ### 1. 环境
201
+
202
+ ```bash
203
+ conda activate streampetr
204
+ # 主要依赖: PyTorch 2.0+, transformers, peft, flash-attn, mmcv 1.7, mmdet3d 1.0
205
+ # DeepSpeed (ZeRO-2): pip install deepspeed
206
+ ```
207
+
208
+ ### 2. 数据准备
209
+
210
+ ```bash
211
+ # nuScenes 数据根目录 (含 v1.0-trainval/ 和 samples/)
212
+ export DATA_ROOT=/path/to/nuscenes
213
+
214
+ # OpenLane-V2 subset-B
215
+ export OPENLANE_ROOT=/path/to/OpenLane-V2/subset_B
216
+
217
+ # 生成检测 QA 数据 (按类别分组, 论文 Figure 5 格式)
218
+ python scripts/gen_atlas_full_data.py \
219
+ --data-root $DATA_ROOT --split train \
220
+ --output data/atlas_nuscenes_train.json
221
+ python scripts/gen_atlas_full_data.py \
222
+ --data-root $DATA_ROOT --split val \
223
+ --output data/atlas_nuscenes_val.json
224
+
225
+ # 生成车道线 QA 数据 (4 点/lane, 无 lane 数量上限, BEV 距离排序)
226
+ python scripts/gen_atlas_openlane_subsetB_lane_qa.py \
227
+ --openlane_root $OPENLANE_ROOT \
228
+ --split train --out_json data/openlane_subsetB_lane_train_4pt.json
229
+
230
+ python scripts/gen_atlas_openlane_subsetB_lane_qa.py \
231
+ --openlane_root $OPENLANE_ROOT \
232
+ --split val --out_json data/openlane_subsetB_lane_val_4pt.json
233
+
234
+ # 生成规划 QA 数据
235
+ # 默认输出:
236
+ # data/atlas_planning_train_uniad_command.json
237
+ # data/atlas_planning_val_uniad_command.json
238
+ # 默认写顶层 route_command(UniAD-style future-GT-derived command)
239
+ # 默认 materialize atlas_high_level human prompt;运行时仍可通过
240
+ # --planning_table3_mode 改写为 atlas_base / atlas_high_level_ego
241
+ python scripts/gen_atlas_planning_qa.py \
242
+ --data-root $DATA_ROOT --split train
243
+ python scripts/gen_atlas_planning_qa.py \
244
+ --data-root $DATA_ROOT --split val
245
+ ```
246
+
247
+ ### 3. 训练
248
+
249
+ 默认使用 **在线模式**(`--visual_token_mode online`),训练时 live 前向 frozen StreamPETR(含 temporal memory)和 TopoMLP,无需预提取 token。
250
+
251
+ ```bash
252
+ # ===== 推荐:在线模式训练(默认)=====
253
+ # 无 caption: det + planning + lane
254
+ bash scripts/train_no_caption_baseline.sh
255
+
256
+ # 含 caption: det + planning + lane + caption
257
+ bash scripts/train_with_caption_balanced.sh
258
+ ```
259
+
260
+ 等效手动命令(以无 caption 为例):
261
+
262
+ ```bash
263
+ deepspeed --num_gpus 4 train_atlas.py \
264
+ --llm_model pretrained/vicuna-7b-v1.5 \
265
+ --data_json data/atlas_nuscenes_train.json,data/atlas_planning_train_uniad_command.json,data/openlane_subsetB_lane_train_4pt.json \
266
+ --data_root $DATA_ROOT \
267
+ --visual_token_mode online \
268
+ --streampetr_config configs/streampetr_atlas_aligned.py \
269
+ --streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
270
+ --topomlp_config configs/topomlp_atlas_aligned.py \
271
+ --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
272
+ --deepspeed configs/ds_zero2.json \
273
+ --output_dir work_dirs/atlas_no_caption_online \
274
+ --lr 2e-5 --weight_decay 1e-4 \
275
+ --batch_size 1 --epochs 10 --gradient_accumulation_steps 2 \
276
+ --warmup_ratio 0.03 --max_grad_norm 1.0 \
277
+ --save_epochs 1 --log_steps 100 \
278
+ --seed 42 --num_workers 4
279
+ ```
280
+
281
+ > **离线 fallback**:如需使用预提取 token 训练(速度更快,但 det 无 temporal memory),
282
+ > 先运行预提取脚本,再使用 `bash scripts/train_no_caption_baseline_offline.sh`。
283
+ > 预提取 token 存放于 `work_dirs/precomputed_*_tokens_offline/`。
284
+
285
+ ### 4. 评估
286
+
287
+ `eval_atlas.py` 支持两种模式:**在线模式**(默认,live frozen encoder + temporal memory)和离线 fallback。
288
+
289
+ ```bash
290
+ # ===== 推荐:在线模式评估(默认)=====
291
+ # 检测
292
+ bash scripts/eval_checkpoint.sh <checkpoint> data/atlas_nuscenes_val.json
293
+
294
+ # 车道线
295
+ bash scripts/eval_checkpoint.sh <checkpoint> data/openlane_subsetB_lane_val_4pt.json
296
+
297
+ # 规划
298
+ bash scripts/eval_checkpoint.sh <checkpoint> data/atlas_planning_val_uniad_command.json
299
+ ```
300
+
301
+ 等效手动命令:
302
+
303
+ ```bash
304
+ python eval_atlas.py \
305
+ --checkpoint work_dirs/atlas_no_caption_online/final/checkpoint.pt \
306
+ --llm_model pretrained/vicuna-7b-v1.5 \
307
+ --visual_token_mode online \
308
+ --streampetr_config configs/streampetr_atlas_aligned.py \
309
+ --streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
310
+ --topomlp_config configs/topomlp_atlas_aligned.py \
311
+ --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
312
+ --data_json data/atlas_nuscenes_val.json \
313
+ --data_root $DATA_ROOT \
314
+ --batch_size 1 --max_new_tokens 2700 --bf16
315
+ ```
316
+
317
+ > **离线 fallback**:使用预提取 token 评估(速度更快,但 det 无 temporal memory):
318
+ > `bash scripts/eval_checkpoint_offline.sh <checkpoint> <data_json>`
319
+ >
320
+ > **快速验证脚本**(`work_dirs/_quick_eval_*.py`)使用离线 token,仅用于开发调试,**不等价于主在线评测,不应用于正式结果**。其口径与主评测存在差异:planning 解析更宽松、不走 live encoders/temporal memory、不自动检测 LoRA。
321
+
322
+ > **评测协议说明**:
323
+ > - **检测**:micro-averaged F1 @ 0.5/1.0/2.0/4.0m,BEV 中心距离匹配。
324
+ > - **车道线**:使用 OpenLane-V2 官方 F-Score 评测器(`openlanev2` 为必需依赖,缺失时直接报错,不再退化为 Chamfer)。
325
+ > - **规划**:L2 误差 + 碰撞率。规划数据含 `gt_boxes_3d_per_timestep` 字段时使用 ST-P3 对齐的 per-timestep 碰撞检测;旧数据自动退化为静态 box 检测。
326
+ > - 在线主评测(`eval_atlas.py`)需要 `mmcv`、`mmdet3d`、`openlanev2` 三个关键依赖,缺失时会在启动前报错。
327
+
328
+ ## 参考
329
+
330
+ - [Atlas: Is a 3D-Tokenized LLM the Key to Reliable Autonomous Driving?](https://arxiv.org/abs/2405.18361)
331
+ - [StreamPETR](https://github.com/exiawsh/StreamPETR)
332
+ - [TopoMLP](https://github.com/wudongming97/TopoMLP)
333
+ - [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/)
configs/REPRODUCTION.md ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # StreamPETR Atlas-Aligned 复现指南
2
+
3
+ 本文档记录了复现 Atlas 论文中 3D Tokenizer (StreamPETR) 所需的完整环境和数据准备信息。
4
+
5
+ ## 1. 版本矩阵(官方 StreamPETR 要求)
6
+
7
+ | 依赖 | 版本 | 备注 |
8
+ |------|------|------|
9
+ | Python | >= 3.8 | 推荐 3.8 |
10
+ | CUDA | 11.2 | 或兼容版本 |
11
+ | PyTorch | 1.9.0 | `pip install torch==1.9.0+cu111` |
12
+ | torchvision | 0.10.0 | |
13
+ | mmcv-full | 1.6.0 | `pip install mmcv-full==1.6.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html` |
14
+ | mmdet | 2.28.2 | |
15
+ | mmsegmentation | 0.30.0 | |
16
+ | mmdetection3d | 1.0.0rc6 | `git checkout v1.0.0rc6` |
17
+ | flash-attn | 0.2.2 | 可选,但高分辨率训练需要 |
18
+
19
+ ### 注意事项
20
+ - 如果使用 PyTorch >= 1.13,需要对应 flash-attn 0.2.8
21
+ - Tesla V100 可能不兼容 flash-attn,需要注释相关代码
22
+
23
+ ## 2. StreamPETR 代码版本
24
+
25
+ ```bash
26
+ # 克隆 StreamPETR
27
+ git clone https://github.com/exiawsh/StreamPETR
28
+
29
+ # 克隆 mmdetection3d 并切换到指定版本
30
+ cd StreamPETR
31
+ git clone https://github.com/open-mmlab/mmdetection3d.git
32
+ cd mmdetection3d
33
+ git checkout v1.0.0rc6
34
+ pip install -e .
35
+ ```
36
+
37
+ **当前仓库状态**:`external/StreamPETR` 是直接复制的代码,非 git 仓库,无法校验具体 commit。
38
+
39
+ ## 3. EVA-02 预训练权重
40
+
41
+ | 文件 | 来源 | MD5 |
42
+ |------|------|-----|
43
+ | `eva02_L_coco_det_sys_o365_remapped.pth` | [GitHub Release](https://github.com/exiawsh/storage/releases/download/v1.0/eva02_L_coco_det_sys_o365_remapped.pth) | `15c389fe4e987275c3d08405ca9eeb92` |
44
+
45
+ ```bash
46
+ # 下载权重
47
+ mkdir -p /root/autodl-tmp
48
+ wget -O /root/autodl-tmp/eva02_L_coco_det_sys_o365_remapped.pth \
49
+ https://github.com/exiawsh/storage/releases/download/v1.0/eva02_L_coco_det_sys_o365_remapped.pth
50
+
51
+ # 验证 MD5
52
+ md5sum /root/autodl-tmp/eva02_L_coco_det_sys_o365_remapped.pth
53
+ # 应该输出: 15c389fe4e987275c3d08405ca9eeb92
54
+ ```
55
+
56
+ ## 4. nuScenes 数据准备
57
+
58
+ ### 4.1 下载数据集
59
+
60
+ 从 [nuScenes 官网](https://www.nuscenes.org/download) 下载:
61
+ - Full dataset (v1.0-trainval): ~330GB
62
+ - Mini dataset (v1.0-mini): ~4GB(调试用)
63
+
64
+ ### 4.2 生成 temporal infos pkl
65
+
66
+ ```bash
67
+ # 在本仓库中请使用 external/StreamPETR 目录
68
+ cd external/StreamPETR
69
+
70
+ # 生成 nuscenes2d_temporal_infos_{train,val}.pkl
71
+ python tools/create_data_nusc.py \
72
+ --root-path ./data/nuscenes \
73
+ --out-dir ./data/nuscenes \
74
+ --extra-tag nuscenes2d \
75
+ --version v1.0
76
+ ```
77
+
78
+ ### 4.3 或下载预处理好的 pkl
79
+
80
+ | 文件 | 下载链接 |
81
+ |------|----------|
82
+ | train.pkl | [nuscenes2d_temporal_infos_train.pkl](https://github.com/exiawsh/storage/releases/download/v1.0/nuscenes2d_temporal_infos_train.pkl) |
83
+ | val.pkl | [nuscenes2d_temporal_infos_val.pkl](https://github.com/exiawsh/storage/releases/download/v1.0/nuscenes2d_temporal_infos_val.pkl) |
84
+ | test.pkl | [nuscenes2d_temporal_infos_test.pkl](https://github.com/exiawsh/storage/releases/download/v1.0/nuscenes2d_temporal_infos_test.pkl) |
85
+
86
+ ### 4.4 目录结构(官方 StreamPETR 期望)
87
+
88
+ ```
89
+ external/StreamPETR/data/nuscenes/
90
+ ├── maps/
91
+ ├── samples/
92
+ ├── sweeps/
93
+ ├── v1.0-trainval/
94
+ ├── nuscenes2d_temporal_infos_train.pkl
95
+ ├── nuscenes2d_temporal_infos_val.pkl
96
+ └── nuscenes2d_temporal_infos_test.pkl
97
+ ```
98
+
99
+ ## 5. 训练配置
100
+
101
+ ### 5.1 3D Tokenizer 预训练设置(Appendix B.1)
102
+
103
+ | 参数 | 设置 | 备注 |
104
+ |------|-----------|------|
105
+ | Backbone | EVA-02 ViT-L | 1024 embed_dim |
106
+ | Resolution | 800×1600 | 高分辨率 |
107
+ | Detection queries | 256 | |
108
+ | Epochs | 24 | |
109
+ | Batch size | 1 per GPU | 高分辨率需要小 batch |
110
+ | GPUs | 8× A100 | 论文使用 |
111
+ | Base LR | 4e-4 | bs=16 (8 GPU × 2) |
112
+ | Backbone LR | 0.1× base | |
113
+
114
+ ### 5.2 Atlas LLM 配置(论文原文 Appendix B.1)
115
+
116
+ | 参数 | 论文描述 |
117
+ |------|----------|
118
+ | Batch size | 1 per GPU |
119
+ | Learning rate | 2e-5 |
120
+ | Optimizer | AdamW (weight_decay=1e-4) |
121
+ | LR schedule | 3% linear warmup + cosine |
122
+ | Max length | 4096 |
123
+ | 训练硬件/时长 | 8x A100,约 100 小时 |
124
+
125
+ > 说明:论文附录未显式给出 LoRA 开关。
126
+ > 如需解释某次实验结果,请以该实验的 `work_dirs/<exp>/args.json` 与训练日志为准。
127
+
128
+ ### 5.3 当前配置(A100 环境)
129
+
130
+ 当前配置使用线性缩放学习率(官方注释: "bs 8: 2e-4 || bs 16: 4e-4"):
131
+ - **8× A100 GPU**
132
+ - **batch_size = 1 per GPU**
133
+ - **Base LR = 2e-4** (effective bs=8,按官方注释使用 2e-4)
134
+ - **配置文件**:`configs/streampetr_atlas_aligned.py`
135
+
136
+ ### 5.4 有效 batch size 计算
137
+
138
+ ```
139
+ 有效 batch size = num_gpus × batch_size_per_gpu × gradient_accumulation
140
+ = 8 × 1 × 1 = 8
141
+ ```
142
+
143
+ 官方学习率参考: "bs 8: 2e-4 || bs 16: 4e-4"
144
+ 当前 effective bs=8,所以使用 lr=2e-4。
145
+
146
+ 如需匹配官方 bs=16,可设置 `gradient_accumulation_steps=2`,并将 lr 改为 4e-4。
147
+
148
+ ### 5.5 其他硬件适配
149
+
150
+ 如果 GPU 数量不同,需要线性缩放学习率:
151
+ ```
152
+ 调整后 LR = 2e-4 × (实际 effective batch size / 8)
153
+ ```
154
+
155
+ 例如 6×GPU、bs=1:`LR = 2e-4 × 6/8 = 1.5e-4`
156
+
157
+ ## 6. 训练命令(官方 StreamPETR)
158
+
159
+ ```bash
160
+ # 推荐:从本仓库根目录启动(已强制使用官方 StreamPETR 配置)
161
+ bash scripts/run_train_streampetr.sh
162
+
163
+ # 或直接使用官方 StreamPETR 工具
164
+ cd external/StreamPETR
165
+ bash tools/dist_train.sh projects/configs/Atlas/atlas_streampetr_eva02_800x1600_256q_24e.py 8
166
+ ```
167
+
168
+ > 注意:不要使用本仓库内的 `scripts/train_streampetr.py`,它是非官方实现,仅用于对比与调试。
169
+
170
+ ## 7. 验证环境
171
+
172
+ ```bash
173
+ # 检查 CUDA
174
+ python -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.version.cuda}')"
175
+
176
+ # 检查 mmdet3d
177
+ python -c "import mmdet3d; print(f'mmdet3d: {mmdet3d.__version__}')"
178
+
179
+ # 检查 flash-attn
180
+ python -c "from flash_attn import flash_attn_func; print('flash-attn OK')"
181
+ ```
182
+
183
+ ## 8. 参考资料
184
+
185
+ - [StreamPETR 官方仓库](https://github.com/exiawsh/StreamPETR)
186
+ - [Atlas 论文](https://arxiv.org/abs/2405.18361)
187
+ - [nuScenes 数据集](https://www.nuscenes.org/)
188
+ - [EVA-02 论文](https://arxiv.org/abs/2303.11331)
189
+
190
+ ---
191
+
192
+ ## 缺失项清单(需补齐,否则无法严格复现)
193
+
194
+ 1. `external/StreamPETR` 的**准确 commit hash**(当前不是 git 仓库,无法校验版本)。
195
+ 2. Atlas 论文中 StreamPETR 预训练的**实际梯度累积与有效 batch size**(官方配置注释为 bs=16,但未显式开启 accum)。
196
+ 3. 论文实验所用的**完整依赖版本锁定**(除了版本矩阵以外,未提供 lockfile/requirements 固定依赖)。
197
+
198
+ ---
199
+
200
+ **最后更新**: 2026-01-29
configs/ds_zero2.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bf16": {"enabled": true},
3
+ "zero_optimization": {
4
+ "stage": 2,
5
+ "allgather_partitions": true,
6
+ "allgather_bucket_size": 50000000,
7
+ "overlap_comm": false,
8
+ "reduce_scatter": true,
9
+ "reduce_bucket_size": 50000000,
10
+ "contiguous_gradients": true
11
+ },
12
+ "gradient_accumulation_steps": 2,
13
+ "gradient_clipping": 1.0,
14
+ "train_batch_size": 8,
15
+ "train_micro_batch_size_per_gpu": 1,
16
+ "wall_clock_breakdown": false
17
+ }
eval_atlas.py ADDED
@@ -0,0 +1,1175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import json
6
+ import logging
7
+ import re
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional
10
+ from collections import Counter, defaultdict
11
+
12
+ import torch
13
+ import numpy as np
14
+
15
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
16
+
17
+ from src.model.modeling_atlas import AtlasForCausalLM
18
+ from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
19
+ from src.model.streampetr_adapter import extract_streampetr_topk_tokens
20
+ from src.dataset.atlas_dataset import AtlasDataset, infer_task_type, make_atlas_collate_fn, load_tokenizer
21
+ from src.dataset.scene_sampler import SceneSequentialSampler
22
+ from src.prompting import PLANNING_TABLE3_MODES
23
+ from src.eval.metrics import (
24
+ parse_atlas_output,
25
+ parse_planning_output,
26
+ normalize_ground_truths,
27
+ calculate_detection_f1,
28
+ calculate_multi_threshold_detection_f1,
29
+ calculate_lane_detection_metrics,
30
+ calculate_planning_metrics,
31
+ )
32
+
33
+ logger = logging.getLogger("eval_atlas")
34
+
35
# Matches one bracketed 3-number coordinate block, e.g. "[12, -3, 4.5]".
_DET_POINT_BLOCK_RE = re.compile(
    r"\[\s*[-+]?\d+(?:\.\d+)?\s*,\s*[-+]?\d+(?:\.\d+)?\s*,\s*[-+]?\d+(?:\.\d+)?\s*\]"
)


def summarize_detection_parse(gen_text: str, det_preds: List[Dict]) -> Dict[str, object]:
    """Audit how completely a generated detection answer was parsed.

    Compares the number of raw coordinate blocks visible in *gen_text*
    against the number of detections actually recovered in *det_preds*,
    and classifies the outcome (clean parse, explicit empty answer,
    suspected partial parse, or outright parse failure).

    Returns a flat dict of ``detection_*`` diagnostic fields.
    """
    trimmed = gen_text.strip().rstrip(". \t\n")
    n_parsed = int(len(det_preds))
    n_raw_points = int(len(_DET_POINT_BLOCK_RE.findall(gen_text)))
    # The canonical "nothing to report" answer is not a failure.
    empty_negative = trimmed.lower() == "no objects detected within range"
    # More raw coordinate blocks than parsed detections → parser likely
    # dropped part of the output.
    partial_suspected = bool(n_parsed > 0 and n_raw_points > n_parsed)

    failed = n_parsed == 0 and not empty_negative
    if failed:
        reason = (
            "coordinates_present_but_unparsed"
            if n_raw_points > 0
            else "no_detection_pattern"
        )
    elif partial_suspected:
        reason = "partial_parse_suspected"
    else:
        reason = ""

    return {
        "detection_parse_failed": failed,
        "detection_parse_failure_reason": reason,
        "detection_partial_parse_suspected": partial_suspected,
        "detection_raw_point_count": n_raw_points,
        "detection_parsed_count": n_parsed,
        "detection_is_empty_negative": empty_negative,
    }
67
+
68
+
69
def summarize_lane_parse(gen_text: str, lane_preds: List[Dict]) -> Dict[str, object]:
    """Audit how completely a generated lane answer was parsed.

    Counterpart of summarize_detection_parse() for the lane task: compares
    raw coordinate blocks in *gen_text* with the points recovered into
    *lane_preds* and classifies lane-specific failure modes (legacy
    ``lane_centerline(...)`` format, bare ``Lane:`` prefix, etc.).

    Returns a flat dict of ``lane_*`` diagnostic fields.
    """
    trimmed = gen_text.strip().rstrip(". \t\n")
    lowered = trimmed.lower()
    n_raw_points = int(len(_DET_POINT_BLOCK_RE.findall(gen_text)))
    n_lanes = int(len(lane_preds))
    n_points = int(sum(len(lane.get("points", [])) for lane in lane_preds))
    # The canonical "nothing to report" answer is not a failure.
    empty_negative = lowered == "no lane centerlines detected within range"
    # More raw coordinate blocks than parsed points → parser likely dropped
    # part of the output.
    partial_suspected = bool(n_lanes > 0 and n_raw_points > n_points)

    failed = n_lanes == 0 and not empty_negative
    if failed:
        if lowered.startswith("lane_centerline("):
            reason = "legacy_lane_format"
        elif n_raw_points > 0:
            reason = "coordinates_present_but_unparsed"
        elif "lane:" in lowered:
            reason = "lane_prefix_without_valid_points"
        else:
            reason = "no_lane_pattern"
    elif partial_suspected:
        reason = "partial_parse_suspected"
    else:
        reason = ""

    return {
        "lane_parse_failed": failed,
        "lane_parse_failure_reason": reason,
        "lane_partial_parse_suspected": partial_suspected,
        "lane_raw_point_count": n_raw_points,
        "lane_parsed_lane_count": n_lanes,
        "lane_parsed_point_count": n_points,
        "lane_is_empty_negative": empty_negative,
    }
104
+
105
+
106
def _audit_lane_gt_point_counts(dataset, max_samples: int = 100):
    """Spot-check lane GT point counts; warn if median != 4 (non-canonical format).

    Samples up to *max_samples* lane-task items from *dataset*, parses the
    first GT answer of each, and logs a warning when the median per-lane
    point count is not 4 (the canonical ``*_4pt.json`` layout).
    Purely diagnostic: no return value, no mutation of *dataset*.
    """
    lane_idx = [i for i, task in enumerate(dataset._task_types) if task == "lane"]
    if not lane_idx:
        return

    sampled = lane_idx[:max_samples]
    counts = []
    for i in sampled:
        conv = dataset.data[i].get("conversations", [])
        # First assistant/gpt turn carries the GT answer text.
        answer = next(
            (
                turn.get("value", "")
                for turn in conv
                if turn.get("from") in ("gpt", "assistant")
            ),
            "",
        )
        for lane in parse_atlas_output(answer):
            pts = lane.get("points", [])
            if pts:
                counts.append(len(pts))

    if not counts:
        return

    # Upper median of the sampled per-lane point counts.
    median_pts = int(sorted(counts)[len(counts) // 2])
    if median_pts == 4:
        logger.info(
            "Lane GT point-count check OK: median=%d, sampled %d lanes from %d samples",
            median_pts, len(counts), len(sampled),
        )
    else:
        logger.warning(
            "Lane GT point-count median is %d (expected 4 for canonical _4pt format). "
            "Sampled %d lanes from %d samples. "
            "Check that --data_json points to the correct *_4pt.json files.",
            median_pts, len(counts), len(sampled),
        )
141
+
142
+
143
def summarize_lane_gt_parse(gt_answer: str, gt_lanes: List[Dict]) -> Dict[str, object]:
    """Mirror of summarize_lane_parse() but for the GT side."""
    normalized = gt_answer.strip().rstrip(". \t\n")
    lowered = normalized.lower()

    raw_points = int(len(_DET_POINT_BLOCK_RE.findall(gt_answer)))
    n_lanes = int(len(gt_lanes))
    n_points = int(sum(len(lane.get("points", [])) for lane in gt_lanes))
    empty_negative = lowered == "no lane centerlines detected within range"
    # Some coordinates matched by regex never made it into parsed lanes.
    partial_suspected = bool(n_lanes > 0 and raw_points > n_points)

    failed = False
    reason = ""
    if n_lanes == 0 and not empty_negative:
        failed = True
        if lowered.startswith("lane_centerline("):
            reason = "legacy_lane_format"
        elif raw_points > 0:
            reason = "coordinates_present_but_unparsed"
        elif "lane:" in lowered:
            reason = "lane_prefix_without_valid_points"
        else:
            reason = "no_lane_pattern"
    elif partial_suspected:
        reason = "partial_parse_suspected"

    return {
        "gt_lane_parse_failed": failed,
        "gt_lane_parse_failure_reason": reason,
        "gt_lane_partial_parse_suspected": partial_suspected,
        "gt_lane_raw_point_count": raw_points,
        "gt_lane_parsed_lane_count": n_lanes,
        "gt_lane_parsed_point_count": n_points,
        "gt_lane_is_empty_negative": empty_negative,
    }
179
+
180
+
181
def parse_args():
    """Build and parse the command-line arguments for Atlas evaluation."""
    parser = argparse.ArgumentParser()
    # --- model / checkpoint ---
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--llm_model", default="lmsys/vicuna-7b-v1.5")
    parser.add_argument("--visual_hidden_size", type=int, default=256)
    parser.add_argument("--num_det_queries", type=int, default=256)
    parser.add_argument("--num_map_queries", type=int, default=256)
    # --- frozen perception encoders (online mode) ---
    parser.add_argument("--streampetr_config", default=None)
    parser.add_argument("--streampetr_ckpt", default=None)
    parser.add_argument("--topomlp_config", default=None)
    parser.add_argument("--topomlp_ckpt", default=None)
    # --- data ---
    parser.add_argument("--data_json", required=True)
    parser.add_argument("--data_root", default="/mnt/data/nuscenes")
    parser.add_argument("--max_length", type=int, default=4096)
    parser.add_argument("--max_new_tokens", type=int, default=2700)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=4)
    # --- LoRA / quantization ---
    parser.add_argument("--use_lora", action="store_true")
    parser.add_argument("--lora_r", type=int, default=64)
    parser.add_argument("--lora_alpha", type=int, default=64)
    parser.add_argument("--load_in_4bit", action="store_true")
    # --- run control ---
    parser.add_argument("--output_json", default=None)
    parser.add_argument("--max_samples", type=int, default=0)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    # --- offline precomputed tokens ---
    parser.add_argument("--precomputed_det_tokens", default=None,
                        help="[offline only] Dir with precomputed det tokens (.pt files)")
    parser.add_argument("--precomputed_map_tokens", default=None,
                        help="[offline only] Dir with precomputed TopoMLP map tokens (.pt files)")
    parser.add_argument("--visual_token_mode", choices=("online", "offline"), default="online",
                        help="Visual token source: online=live frozen encoders (default), offline=read *_offline dirs")
    parser.add_argument(
        "--planning_table3_mode",
        choices=PLANNING_TABLE3_MODES,
        default="atlas_base",
        help=(
            "Planning prompt variant matching Atlas Table 3: "
            "atlas_base=no command/no explicit ego state; "
            "atlas_high_level=requires top-level route_command "
            "(this repo uses a UniAD-style future-GT-derived command); "
            "atlas_high_level_ego=requires top-level route_command plus "
            "velocity/acceleration bins."
        ),
    )
    return parser.parse_args()
226
+
227
+
228
def infer_task(item: Dict) -> str:
    """Return the task type for a dataset item (e.g. "detection", "lane", "planning", "caption").

    Thin local alias for infer_task_type(), kept so call sites in this module
    read uniformly.
    """
    return infer_task_type(item)
230
+
231
+
232
def load_frozen_encoder(config_path, ckpt_path, model_type, device):
    """Build a frozen perception encoder (StreamPETR or TopoMLP) from an mmdet3d config.

    Args:
        config_path: Path to the mmcv config file, or None to skip loading.
        ckpt_path: Path to the encoder checkpoint (.pth), or None to skip loading.
        model_type: "streampetr" or "topomlp"; selects which external project
            plugin to import before building the model.
        device: Target torch device for the frozen model.

    Returns:
        The eval-mode model with all parameters frozen, or None when either
        config_path or ckpt_path was not provided.

    Raises:
        RuntimeError: When mmcv/mmdet3d or the external project plugin is
            missing even though encoder arguments were explicitly provided.
    """
    if config_path is None or ckpt_path is None:
        return None
    try:
        from mmcv import Config
        from mmdet3d.models import build_model
        from mmcv.runner import load_checkpoint
    except ImportError:
        raise RuntimeError(
            f"mmcv/mmdet3d not installed but --{model_type}_config and "
            f"--{model_type}_ckpt were explicitly provided. "
            f"Install mmcv/mmdet3d or remove these arguments."
        )

    project_root = Path(__file__).resolve().parent
    if model_type == "streampetr":
        sp_root = str(project_root / "external" / "StreamPETR")
        if sp_root not in sys.path:
            sys.path.insert(0, sp_root)
        try:
            import projects.mmdet3d_plugin  # noqa: F401
        except ImportError:
            raise RuntimeError(
                f"StreamPETR plugin not found under {sp_root}/projects/mmdet3d_plugin. "
                f"Ensure the submodule is checked out, or remove --streampetr_config/--streampetr_ckpt."
            )
    elif model_type == "topomlp":
        tp_root = str(project_root / "external" / "TopoMLP_Repo")
        if tp_root not in sys.path:
            sys.path.insert(0, tp_root)
        try:
            os.environ["ATLAS_TOPOMLP_MODELS_ONLY"] = "1"
            from mmcv.utils import registry as _reg
            # TopoMLP registers module names that a previously imported plugin
            # (e.g. StreamPETR) may already have claimed; temporarily make the
            # registry tolerate duplicates by forcing registration.
            _orig = _reg.Registry._register_module

            def _tolerant_register(self, module, module_name=None, force=False):
                return _orig(self, module, module_name=module_name, force=True)

            _reg.Registry._register_module = _tolerant_register
            try:
                import projects.topomlp  # noqa: F401
            finally:
                # BUGFIX: restore the original registry method even when the
                # plugin import fails. Previously a failed import left mmcv's
                # global Registry permanently in force-overwrite mode.
                _reg.Registry._register_module = _orig
        except ImportError:
            raise RuntimeError(
                f"TopoMLP plugin not found under {tp_root}/projects/topomlp. "
                f"Ensure the submodule is checked out, or remove --topomlp_config/--topomlp_ckpt."
            )

    cfg = Config.fromfile(config_path)
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, ckpt_path, map_location="cpu")
    model.eval()
    model.to(device)
    for param in model.parameters():
        param.requires_grad_(False)
    logger.info("Loaded frozen %s from %s", model_type, ckpt_path)
    return model
286
+
287
+
288
def _run_streampetr_forward(model, imgs, img_metas, batch, device, prev_exists=None):
    """Run frozen StreamPETR forward. prev_exists controls temporal memory."""
    B, N = imgs.shape[:2]

    def _batched_eye4():
        # Identity fallback for missing camera geometry, shaped (B, N, 4, 4).
        eye = torch.eye(4, device=device)
        return eye.unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()

    data = {
        "img": imgs,
        "img_feats": model.extract_img_feat(imgs, 1),
        "prev_exists": imgs.new_zeros(B) if prev_exists is None else prev_exists,
    }

    # Camera intrinsics: embed the provided 3x3 into a homogeneous 4x4.
    if "intrinsics_det" in batch:
        k33 = batch["intrinsics_det"].to(device)
        k44 = torch.zeros(B, N, 4, 4, device=device, dtype=k33.dtype)
        k44[:, :, :3, :3] = k33
        k44[:, :, 3, 3] = 1.0
        data["intrinsics"] = k44
    else:
        data["intrinsics"] = _batched_eye4()

    if "lidar2img_det" in batch:
        data["lidar2img"] = batch["lidar2img_det"].to(device)
    else:
        data["lidar2img"] = _batched_eye4()

    ego_pose = batch.get("ego_pose")
    if ego_pose is not None:
        data["ego_pose"] = ego_pose.to(device)
    else:
        data["ego_pose"] = torch.eye(4, device=device).unsqueeze(0).expand(B, -1, -1).contiguous()

    ego_pose_inv = batch.get("ego_pose_inv")
    if ego_pose_inv is not None:
        data["ego_pose_inv"] = ego_pose_inv.to(device)
    else:
        data["ego_pose_inv"] = torch.inverse(data["ego_pose"])

    ts = batch.get("timestamp")
    data["timestamp"] = ts.to(device) if ts is not None else torch.zeros(B, device=device)

    location = model.prepare_location(img_metas, **data)
    outs_roi = model.forward_roi_head(location, **data)
    return model.pts_bbox_head(location, img_metas, outs_roi["topk_indexes"], **data)
335
+
336
+
337
+ def _reconstruct_topomlp_outs(saved: dict, device, dtype):
338
+ """Convert precomputed .pt dict back to the format adapter.forward() expects."""
339
+ def _restore(t):
340
+ return t.to(device=device, dtype=dtype).unsqueeze(0)
341
+ return {
342
+ "lc_outs_dec_list": [_restore(saved["lc_outs_dec"])],
343
+ "all_lc_cls_scores_list": [_restore(saved["lc_cls_scores"])],
344
+ "all_lc_preds_list": [_restore(saved["lc_preds"])],
345
+ "lc_outs_dec_one2many_list": [_restore(saved["lc_outs_dec_o2m"])],
346
+ "all_lc_cls_scores_one2many_list": [_restore(saved["lc_cls_scores_o2m"])],
347
+ "all_lc_preds_one2many_list": [_restore(saved["lc_preds_o2m"])],
348
+ }
349
+
350
+
351
def extract_visual_tokens(
    streampetr_model, topomlp_model, topomlp_adapter,
    batch, device, num_det_queries, visual_hidden_size,
    visual_token_mode="online",
    streaming_state=None,
    query_token_id=None,
):
    """Extract det + map visual tokens (eval version).

    Mirrors train_atlas.extract_visual_tokens: in online mode with
    streaming_state, StreamPETR temporal memory is scene-aware and
    duplicate physical frames are protected. The needs_map gating
    skips TopoMLP when the current sample only requires det queries.

    Returns a dict with keys "detection" and "detection_ref_points", plus
    "map" and "map_ref_points" when a topomlp_adapter is supplied (either
    computed, reconstructed from precomputed tokens, or zero-filled).
    Raises RuntimeError in offline mode when required precomputed tokens
    are absent, and in online mode when the StreamPETR model is missing.
    """
    B = batch["pixel_values_det"].shape[0]
    N = batch["pixel_values_det"].shape[1]
    vis: Dict[str, torch.Tensor] = {}

    # A sample needs map tokens when its prompt contains more <query>
    # placeholders than the number of detection queries.
    needs_map = False
    if query_token_id is not None and "input_ids" in batch:
        n_queries = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item())
        needs_map = n_queries > num_det_queries

    # ---- Detection tokens ----
    if visual_token_mode == "offline" and "precomputed_det" in batch and "precomputed_det_ref" in batch:
        # Offline: tokens were precomputed and collated into the batch.
        vis["detection"] = batch["precomputed_det"].to(device)
        vis["detection_ref_points"] = batch["precomputed_det_ref"].to(device)
    elif visual_token_mode == "offline":
        raise RuntimeError(
            "visual_token_mode=offline but detection precomputed tokens are missing "
            "for the current batch. Refusing to zero-fill."
        )
    elif streampetr_model is not None:
        current_sample_id = batch.get("sample_id", [None])[0]
        current_scene = batch.get("scene_id", ["__atlas_eval__"])[0]
        reuse_cache = False

        if streaming_state is not None:
            # Streaming path: maintain StreamPETR temporal memory across a
            # scene-sequential dataloader (batch_size 1 assumed by caller).
            prev_scene = streaming_state.get("prev_scene_token")
            prev_sample_id = streaming_state.get("prev_sample_id")
            ts_tensor = batch.get("timestamp")
            current_ts = float(ts_tensor[0].item()) if ts_tensor is not None else None
            prev_ts = streaming_state.get("prev_timestamp")

            # A new segment starts on the first sample, on a scene change,
            # or when timestamps go backwards (non-monotonic ordering).
            is_new_segment = (
                prev_scene is None
                or current_scene != prev_scene
                or (current_ts is not None and prev_ts is not None and current_ts <= prev_ts)
            )

            # Same physical frame as the previous sample (e.g. multiple QA
            # items per frame): reuse the cached tokens instead of running
            # the encoder again, which would also corrupt temporal memory.
            if current_sample_id is not None and current_sample_id == prev_sample_id:
                cached = streaming_state.get("cached_det")
                if cached is not None:
                    reuse_cache = True
                    vis["detection"] = cached["detection"]
                    vis["detection_ref_points"] = cached["detection_ref_points"]

            if not reuse_cache:
                if is_new_segment:
                    # Clear temporal memory before the first frame of a segment.
                    streampetr_model.pts_bbox_head.reset_memory()
                prev_exists_val = 0.0 if is_new_segment else 1.0
                prev_exists = batch["pixel_values_det"].new_full((B,), prev_exists_val)

                imgs_det = batch["pixel_values_det"].to(device)
                # Fixed detection input resolution (H, W) used for img_metas.
                fH, fW = 800, 1600
                img_metas = [{
                    "pad_shape": [(fH, fW, 3)] * N,
                    "img_shape": [(fH, fW, 3)] * N,
                    "scene_token": current_scene,
                } for _ in range(B)]
                if "lidar2img_det" in batch:
                    for b in range(B):
                        img_metas[b]["lidar2img"] = batch["lidar2img_det"][b].cpu().numpy()
                with torch.no_grad():
                    _run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device, prev_exists=prev_exists)
                ego_pose_for_ref = batch.get("ego_pose")
                if ego_pose_for_ref is not None:
                    ego_pose_for_ref = ego_pose_for_ref.to(device)
                # Pull top-k query embeddings + reference points from the head.
                det_out = extract_streampetr_topk_tokens(
                    streampetr_model.pts_bbox_head, topk=num_det_queries,
                    ego_pose=ego_pose_for_ref,
                )
                vis["detection"] = det_out["detection"]
                vis["detection_ref_points"] = det_out["detection_ref_points"]

                # Cache for possible duplicate-frame reuse on the next sample.
                streaming_state["cached_det"] = {
                    "detection": vis["detection"],
                    "detection_ref_points": vis["detection_ref_points"],
                }

            # Always advance the streaming bookkeeping, even on cache reuse.
            streaming_state["prev_scene_token"] = current_scene
            streaming_state["prev_sample_id"] = current_sample_id
            if batch.get("timestamp") is not None:
                streaming_state["prev_timestamp"] = float(batch["timestamp"][0].item())
        else:
            # Non-streaming online path: single-frame inference with the
            # temporal memory reset before every forward pass.
            imgs_det = batch["pixel_values_det"].to(device)
            fH, fW = 800, 1600
            img_metas = [{
                "pad_shape": [(fH, fW, 3)] * N,
                "img_shape": [(fH, fW, 3)] * N,
                "scene_token": current_scene,
            } for _ in range(B)]
            if "lidar2img_det" in batch:
                for b in range(B):
                    img_metas[b]["lidar2img"] = batch["lidar2img_det"][b].cpu().numpy()
            with torch.no_grad():
                streampetr_model.pts_bbox_head.reset_memory()
                _run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device)
            ego_pose_for_ref = batch.get("ego_pose")
            if ego_pose_for_ref is not None:
                ego_pose_for_ref = ego_pose_for_ref.to(device)
            det_out = extract_streampetr_topk_tokens(
                streampetr_model.pts_bbox_head, topk=num_det_queries,
                ego_pose=ego_pose_for_ref,
            )
            vis["detection"] = det_out["detection"]
            vis["detection_ref_points"] = det_out["detection_ref_points"]
    elif visual_token_mode == "online":
        raise RuntimeError(
            "visual_token_mode=online but StreamPETR model is None. "
            "Provide --streampetr_config and --streampetr_ckpt."
        )
    else:
        # No encoder and not strictly online: zero-fill detection tokens.
        vis["detection"] = torch.zeros(B, num_det_queries, visual_hidden_size, device=device)
        vis["detection_ref_points"] = torch.zeros(B, num_det_queries, 3, device=device)

    # ---- Map tokens ----
    num_map_queries = num_det_queries
    if topomlp_adapter is not None:
        num_map_queries = topomlp_adapter.num_map_tokens

    map_filled = False
    if visual_token_mode == "offline" and topomlp_adapter is not None and "precomputed_map" in batch:
        # Infer the adapter's dtype from its parameters (or buffers as a
        # fallback) so reconstructed tensors match it.
        _params = list(topomlp_adapter.parameters())
        _bufs = list(topomlp_adapter.buffers())
        adapter_dtype = _params[0].dtype if _params else (_bufs[0].dtype if _bufs else torch.float32)
        if B == 1:
            outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)
        else:
            # Reconstruct each sample, then concatenate along the batch dim
            # list-entry by list-entry.
            per_sample = [_reconstruct_topomlp_outs(batch["precomputed_map"][b], device, adapter_dtype) for b in range(B)]
            outs = {}
            for k in per_sample[0]:
                outs[k] = [torch.cat([s[k][i] for s in per_sample], dim=0) for i in range(len(per_sample[0][k]))]
        with torch.no_grad():
            map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True
    elif visual_token_mode == "offline" and topomlp_adapter is not None:
        raise RuntimeError(
            "visual_token_mode=offline but map precomputed tokens are missing "
            "for the current batch. Refusing to zero-fill."
        )
    elif needs_map and topomlp_model is not None and topomlp_adapter is not None:
        # Online TopoMLP forward, only when the prompt actually uses map tokens.
        imgs_map = batch["pixel_values_map"].to(device)
        img_metas = []
        for b in range(B):
            meta = {"img_shape": [(800, 1600, 3)] * N, "pad_shape": [(800, 1600, 3)] * N}
            meta["scale_factor"] = 1.0
            meta["te_yolov8"] = None
            if "lidar2img_map" in batch:
                meta["lidar2img"] = batch["lidar2img_map"][b].cpu().numpy()
            img_metas.append(meta)
        with torch.no_grad():
            outs = topomlp_model.simple_forward(imgs_map, img_metas)
            map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True

    # An adapter exists but no map tokens were produced (det-only sample,
    # or no TopoMLP model): zero-fill so downstream shapes stay consistent.
    if topomlp_adapter is not None and not map_filled:
        vis["map"] = torch.zeros(B, num_map_queries, visual_hidden_size, device=device)
        vis["map_ref_points"] = torch.zeros(B, num_map_queries, 3, device=device)

    return vis
526
+
527
+
528
def parse_gt_from_item(item: Dict, task: str) -> Dict:
    """Extract the ground-truth payload for one dataset item, keyed by task."""
    gt: Dict = {}
    if task == "detection":
        annotations = item.get("gt_boxes_3d", item.get("annotations", []))
        dets = []
        for ann in annotations:
            if not isinstance(ann, dict):
                continue
            if "box" in ann:
                xyz = ann["box"][:3]
            elif "translation" in ann:
                xyz = ann["translation"][:3]
            else:
                continue  # no usable position for this annotation
            dets.append({
                "category": ann.get("category_name", ann.get("category", "unknown")),
                "world_coords": list(xyz),
            })
        gt["detections"] = normalize_ground_truths(dets)
    elif task == "lane":
        # GT lanes live in the first assistant/gpt conversation turn.
        answer = ""
        for turn in item.get("conversations", []):
            if turn.get("from") in ("gpt", "assistant"):
                answer = turn.get("value", "")
                break
        gt["lanes"] = parse_atlas_output(answer)
    elif task == "planning":
        gt["waypoints"] = item.get("ego_motion", {}).get("waypoints", [])
        gt["gt_boxes"] = item.get("gt_boxes_3d", [])
        if "gt_boxes_3d_per_timestep" in item:
            gt["gt_boxes_per_timestep"] = item["gt_boxes_3d_per_timestep"]
    return gt
562
+
563
+
564
+ def _check_task_dependencies(args, *, has_lane: bool, is_online: bool):
565
+ """Check dependencies based on actual task distribution in the dataset."""
566
+ missing = []
567
+ if is_online:
568
+ for mod in ("mmcv", "mmdet3d"):
569
+ try:
570
+ __import__(mod)
571
+ except ImportError:
572
+ missing.append(mod)
573
+ if has_lane:
574
+ try:
575
+ __import__("openlanev2")
576
+ except ImportError:
577
+ missing.append("openlanev2 (needed for lane F-Score)")
578
+ if missing:
579
+ raise RuntimeError(
580
+ f"Missing dependencies for this eval run: {', '.join(missing)}. "
581
+ f"Install them before running eval_atlas.py."
582
+ )
583
+
584
+
585
+ def main():
586
+ args = parse_args()
587
+ logging.basicConfig(
588
+ level=logging.INFO,
589
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
590
+ )
591
+
592
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
593
+
594
+ tokenizer = load_tokenizer(args.llm_model)
595
+ if "<query>" not in tokenizer.get_vocab():
596
+ tokenizer.add_tokens(["<query>"])
597
+
598
+ dtype = torch.float32
599
+ if args.bf16:
600
+ dtype = torch.bfloat16
601
+ elif args.fp16:
602
+ dtype = torch.float16
603
+
604
+ dm = "auto" if args.load_in_4bit else None
605
+ atlas = AtlasForCausalLM(
606
+ llm_model_name=args.llm_model,
607
+ visual_hidden_size=args.visual_hidden_size,
608
+ num_queries=args.num_det_queries,
609
+ num_map_queries=args.num_map_queries,
610
+ load_in_4bit=args.load_in_4bit,
611
+ use_flash_attention=True,
612
+ device_map=dm,
613
+ torch_dtype=dtype,
614
+ use_lora=args.use_lora,
615
+ lora_r=args.lora_r,
616
+ lora_alpha=args.lora_alpha,
617
+ )
618
+ atlas.resize_token_embeddings(len(tokenizer))
619
+ atlas.set_query_token_id(tokenizer.convert_tokens_to_ids("<query>"))
620
+ if dm is None:
621
+ atlas = atlas.to(device)
622
+
623
+ topomlp_adapter = None
624
+
625
+ ckpt = torch.load(args.checkpoint, map_location="cpu")
626
+
627
+ if "atlas_state_dict" not in ckpt:
628
+ raise RuntimeError(
629
+ f"Checkpoint missing 'atlas_state_dict'. "
630
+ f"Top-level keys: {sorted(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt).__name__}. "
631
+ f"Make sure --checkpoint points to an Atlas training checkpoint (checkpoint.pt), "
632
+ f"not a frozen encoder weight (.pth)."
633
+ )
634
+ atlas_sd = ckpt["atlas_state_dict"]
635
+ if not isinstance(atlas_sd, dict) or len(atlas_sd) == 0:
636
+ raise RuntimeError(
637
+ f"'atlas_state_dict' is empty or not a dict (type={type(atlas_sd).__name__}, "
638
+ f"len={len(atlas_sd) if isinstance(atlas_sd, dict) else 'N/A'}). "
639
+ f"Checkpoint is likely corrupted."
640
+ )
641
+
642
+ # Auto-detect LoRA: if checkpoint has LoRA keys but model doesn't have LoRA,
643
+ # rebuild the model with LoRA enabled to prevent silent degradation.
644
+ has_lora_keys = any("lora_" in k for k in atlas_sd)
645
+ if has_lora_keys and not args.use_lora:
646
+ logger.warning(
647
+ "Checkpoint contains LoRA weights but --use_lora was not set. "
648
+ "Auto-enabling LoRA to prevent silent degradation."
649
+ )
650
+ args.use_lora = True
651
+ atlas = AtlasForCausalLM(
652
+ llm_model_name=args.llm_model,
653
+ visual_hidden_size=args.visual_hidden_size,
654
+ num_queries=args.num_det_queries,
655
+ num_map_queries=args.num_map_queries,
656
+ load_in_4bit=args.load_in_4bit,
657
+ use_flash_attention=True,
658
+ device_map=dm,
659
+ torch_dtype=dtype,
660
+ use_lora=True,
661
+ lora_r=args.lora_r,
662
+ lora_alpha=args.lora_alpha,
663
+ )
664
+ atlas.resize_token_embeddings(len(tokenizer))
665
+ atlas.set_query_token_id(tokenizer.convert_tokens_to_ids("<query>"))
666
+ if dm is None:
667
+ atlas = atlas.to(device)
668
+
669
+ missing, unexpected = atlas.load_state_dict(atlas_sd, strict=False)
670
+ if missing:
671
+ raise RuntimeError(
672
+ f"Atlas checkpoint is incomplete: {len(missing)} missing keys "
673
+ f"(first 10: {missing[:10]}). This means the checkpoint does not "
674
+ f"match the current model architecture. Refusing to evaluate with "
675
+ f"partially-initialized weights."
676
+ )
677
+ if unexpected:
678
+ logger.warning("Unexpected keys in checkpoint (possibly ignored): %d keys, first 5: %s",
679
+ len(unexpected), unexpected[:5])
680
+ logger.info("Loaded Atlas weights from %s (%d keys, 0 missing)", args.checkpoint, len(atlas_sd))
681
+ _tp_bev_range = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0)
682
+ if args.topomlp_config:
683
+ try:
684
+ from mmcv import Config as _Cfg
685
+ _tp_cfg = _Cfg.fromfile(args.topomlp_config)
686
+ if hasattr(_tp_cfg, "point_cloud_range"):
687
+ _tp_bev_range = tuple(float(v) for v in _tp_cfg.point_cloud_range)
688
+ logger.info("TopoMLP bev_range from config: %s", _tp_bev_range)
689
+ except Exception as e:
690
+ logger.warning("Failed to read point_cloud_range from TopoMLP config: %s. Using default: %s", e, _tp_bev_range)
691
+
692
+ if args.topomlp_config or args.topomlp_ckpt or args.precomputed_map_tokens:
693
+ topomlp_adapter = TopoMLPToAtlasMapTokens(
694
+ num_map_tokens=args.num_map_queries,
695
+ hidden_size=args.visual_hidden_size,
696
+ bev_range=_tp_bev_range,
697
+ ).to(device)
698
+ if "adapter_state_dict" in ckpt:
699
+ topomlp_adapter.load_state_dict(ckpt["adapter_state_dict"], strict=False)
700
+ topomlp_adapter.eval()
701
+
702
+ atlas.eval()
703
+
704
+ is_online = args.visual_token_mode == "online"
705
+ _precomp_det = args.precomputed_det_tokens if not is_online else None
706
+ _precomp_map = args.precomputed_map_tokens if not is_online else None
707
+
708
+ dataset = AtlasDataset(
709
+ json_file=args.data_json,
710
+ image_root=args.data_root,
711
+ tokenizer=tokenizer,
712
+ max_length=args.max_length,
713
+ is_training=False,
714
+ planning_table3_mode=args.planning_table3_mode,
715
+ precomputed_det_tokens=_precomp_det,
716
+ precomputed_map_tokens=_precomp_map,
717
+ )
718
+
719
+ _task_counts = Counter(dataset._task_types)
720
+ has_lane = "lane" in _task_counts
721
+ has_planning = "planning" in _task_counts
722
+ needs_map_tokens = has_lane or has_planning
723
+ logger.info("Task needs: has_lane=%s, has_planning=%s, needs_map=%s",
724
+ has_lane, has_planning, needs_map_tokens)
725
+
726
+ if has_lane:
727
+ _audit_lane_gt_point_counts(dataset, max_samples=100)
728
+
729
+ _check_task_dependencies(args, has_lane=has_lane, is_online=is_online)
730
+
731
+ if is_online:
732
+ if args.precomputed_det_tokens or args.precomputed_map_tokens:
733
+ raise RuntimeError(
734
+ "visual_token_mode=online forbids --precomputed_* arguments."
735
+ )
736
+ if args.batch_size != 1:
737
+ raise RuntimeError(
738
+ "visual_token_mode=online with temporal memory requires "
739
+ "--batch_size 1. Got: %d" % args.batch_size
740
+ )
741
+ if not args.streampetr_config or not args.streampetr_ckpt:
742
+ raise RuntimeError(
743
+ "online mode requires --streampetr_config and --streampetr_ckpt"
744
+ )
745
+ for p in (args.streampetr_config, args.streampetr_ckpt):
746
+ if not os.path.exists(p):
747
+ raise RuntimeError(f"Required online asset does not exist: {p}")
748
+ if needs_map_tokens:
749
+ if not args.topomlp_config or not args.topomlp_ckpt:
750
+ raise RuntimeError(
751
+ "online mode with lane/planning tasks requires "
752
+ "--topomlp_config and --topomlp_ckpt"
753
+ )
754
+ for p in (args.topomlp_config, args.topomlp_ckpt):
755
+ if not os.path.exists(p):
756
+ raise RuntimeError(f"Required online asset does not exist: {p}")
757
+ else:
758
+ if not _precomp_det:
759
+ raise RuntimeError(
760
+ "offline mode requires --precomputed_det_tokens"
761
+ )
762
+ if needs_map_tokens and not _precomp_map:
763
+ raise RuntimeError(
764
+ "offline mode with lane/planning tasks requires "
765
+ "--precomputed_map_tokens"
766
+ )
767
+ for p in (_precomp_det, _precomp_map):
768
+ if p and not os.path.isdir(p):
769
+ raise RuntimeError(f"Offline token directory does not exist: {p}")
770
+
771
+ streampetr_model = load_frozen_encoder(
772
+ args.streampetr_config, args.streampetr_ckpt, "streampetr", device,
773
+ )
774
+ topomlp_model = load_frozen_encoder(
775
+ args.topomlp_config, args.topomlp_ckpt, "topomlp", device,
776
+ )
777
+
778
+ if is_online:
779
+ scene_groups = dataset.get_scene_groups()
780
+ sampler = SceneSequentialSampler(scene_groups, shuffle_scenes=False)
781
+ logger.info("Online eval: scene-sequential sampler (%d scenes)", len(scene_groups))
782
+ else:
783
+ sampler = None
784
+
785
+ collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
786
+ dataloader = torch.utils.data.DataLoader(
787
+ dataset,
788
+ batch_size=1 if is_online else args.batch_size,
789
+ shuffle=False,
790
+ sampler=sampler,
791
+ num_workers=args.num_workers,
792
+ collate_fn=collate_fn,
793
+ pin_memory=True,
794
+ )
795
+
796
+ streaming_state = {} if is_online else None
797
+
798
+ task_preds: Dict[str, List] = defaultdict(list)
799
+ task_gts: Dict[str, List] = defaultdict(list)
800
+ all_outputs: List[Dict] = []
801
+ sample_count = 0
802
+
803
+ logger.info("Starting evaluation on %d samples (mode=%s)...", len(dataset), args.visual_token_mode)
804
+
805
+ if is_online and streampetr_model is not None:
806
+ streampetr_model.pts_bbox_head.reset_memory()
807
+
808
+ for batch_idx, batch in enumerate(dataloader):
809
+ if args.max_samples > 0 and sample_count >= args.max_samples:
810
+ break
811
+
812
+ B = batch["input_ids"].shape[0]
813
+ input_ids = batch["input_ids"].to(device)
814
+ attention_mask = batch["attention_mask"].to(device)
815
+
816
+ visual_features = extract_visual_tokens(
817
+ streampetr_model, topomlp_model, topomlp_adapter,
818
+ batch, device, args.num_det_queries, args.visual_hidden_size,
819
+ visual_token_mode=args.visual_token_mode,
820
+ streaming_state=streaming_state,
821
+ query_token_id=atlas.query_token_id,
822
+ )
823
+
824
+ with torch.no_grad():
825
+ generated_ids = atlas.generate(
826
+ input_ids=input_ids,
827
+ attention_mask=attention_mask,
828
+ visual_features=visual_features,
829
+ max_new_tokens=args.max_new_tokens,
830
+ do_sample=False,
831
+ )
832
+
833
+ for b in range(B):
834
+ if args.max_samples > 0 and sample_count >= args.max_samples:
835
+ break
836
+
837
+ gen_text_full = tokenizer.decode(generated_ids[b], skip_special_tokens=True)
838
+ # Extract assistant response: take content after the last "ASSISTANT:" tag
839
+ _split_tag = "ASSISTANT:"
840
+ if _split_tag in gen_text_full:
841
+ gen_text = gen_text_full.split(_split_tag)[-1].strip()
842
+ else:
843
+ gen_text = gen_text_full.strip()
844
+ sample_id = batch["sample_id"][b] if "sample_id" in batch else str(sample_count)
845
+ item_idx = int(batch["dataset_idx"][b].item()) if "dataset_idx" in batch else (batch_idx * B + b)
846
+ item = dataset.data[item_idx]
847
+ task = infer_task(item)
848
+ gt = parse_gt_from_item(item, task)
849
+
850
+ record = {
851
+ "sample_id": sample_id,
852
+ "task": task,
853
+ "generated_text": gen_text,
854
+ }
855
+
856
+ if task == "detection":
857
+ preds = parse_atlas_output(gen_text)
858
+ det_preds = [p for p in preds if p.get("type") == "detection"]
859
+ gt_dets = gt.get("detections", [])
860
+ task_preds["detection"].append(det_preds)
861
+ task_gts["detection"].append(gt_dets)
862
+ record.update(summarize_detection_parse(gen_text, det_preds))
863
+ record["num_preds"] = len(det_preds)
864
+ record["num_gt"] = len(gt_dets)
865
+
866
+ elif task == "lane":
867
+ preds = parse_atlas_output(gen_text)
868
+ lane_preds = [p for p in preds if p.get("type") == "lane"]
869
+ gt_lanes = gt.get("lanes", [])
870
+ task_preds["lane"].append(lane_preds)
871
+ task_gts["lane"].append(gt_lanes)
872
+ record.update(summarize_lane_parse(gen_text, lane_preds))
873
+ record["num_preds"] = len(lane_preds)
874
+ record["num_gt"] = len(gt_lanes)
875
+
876
+ gt_answer_text = ""
877
+ for _turn in item.get("conversations", []):
878
+ if _turn.get("from") in ("gpt", "assistant"):
879
+ gt_answer_text = _turn.get("value", "")
880
+ break
881
+ record.update(summarize_lane_gt_parse(gt_answer_text, gt_lanes))
882
+
883
+ elif task == "planning":
884
+ loose_plan_pred = parse_planning_output(gen_text, require_full_vap=False)
885
+ strict_plan_pred = parse_planning_output(gen_text, require_full_vap=True)
886
+ gt_wps = gt.get("waypoints", [])
887
+ gt_boxes = gt.get("gt_boxes", [])
888
+ has_waypoints = bool(
889
+ loose_plan_pred is not None and "waypoints" in loose_plan_pred
890
+ )
891
+ has_velocity = bool(
892
+ loose_plan_pred is not None and "velocity_bins" in loose_plan_pred
893
+ )
894
+ has_acceleration = bool(
895
+ loose_plan_pred is not None and "acceleration_bins" in loose_plan_pred
896
+ )
897
+ vap_complete = bool(strict_plan_pred is not None)
898
+
899
+ require_strict = args.planning_table3_mode == "atlas_high_level_ego"
900
+ accepted_pred = strict_plan_pred if require_strict else loose_plan_pred
901
+ parse_ok = accepted_pred is not None and "waypoints" in accepted_pred
902
+
903
+ if parse_ok:
904
+ plan_pred = accepted_pred
905
+ record["planning_parse_failed"] = False
906
+ record["planning_vap_complete"] = vap_complete
907
+ record["planning_has_velocity"] = has_velocity
908
+ record["planning_has_acceleration"] = has_acceleration
909
+ record["planning_has_waypoints"] = True
910
+ if require_strict:
911
+ record["planning_parse_failure_reason"] = ""
912
+ else:
913
+ record["planning_parse_failure_reason"] = (
914
+ "" if vap_complete else "missing_velocity_or_acceleration"
915
+ )
916
+ else:
917
+ plan_pred = {"waypoints": [[0.0, 0.0]] * max(len(gt_wps), 6)}
918
+ record["planning_parse_failed"] = True
919
+ record["planning_vap_complete"] = False
920
+ record["planning_has_velocity"] = has_velocity
921
+ record["planning_has_acceleration"] = has_acceleration
922
+ record["planning_has_waypoints"] = has_waypoints
923
+ if require_strict and has_waypoints and not vap_complete:
924
+ record["planning_parse_failure_reason"] = (
925
+ "strict_mode_missing_velocity_or_acceleration"
926
+ )
927
+ elif has_waypoints:
928
+ record["planning_parse_failure_reason"] = "missing_velocity_or_acceleration"
929
+ else:
930
+ record["planning_parse_failure_reason"] = "unparseable_waypoints"
931
+ task_preds["planning"].append(plan_pred)
932
+ plan_gt_entry = {
933
+ "waypoints": gt_wps,
934
+ "gt_boxes": gt_boxes,
935
+ }
936
+ if "gt_boxes_per_timestep" in gt:
937
+ plan_gt_entry["gt_boxes_per_timestep"] = gt["gt_boxes_per_timestep"]
938
+ task_gts["planning"].append(plan_gt_entry)
939
+ record["has_plan"] = has_waypoints
940
+
941
+ elif task == "caption":
942
+ record["skipped"] = True
943
+
944
+ else:
945
+ logger.warning("Unknown task %r for sample %s — skipping metrics", task, sample_id)
946
+ record["skipped"] = True
947
+
948
+ all_outputs.append(record)
949
+ sample_count += 1
950
+
951
+ if (batch_idx + 1) % 50 == 0:
952
+ logger.info("Processed %d / %d samples", sample_count, len(dataset))
953
+
954
+ logger.info("Evaluation complete. Total samples: %d", sample_count)
955
+
956
+ _skipped = sum(1 for r in all_outputs if r.get("skipped"))
957
+ if _skipped:
958
+ logger.warning(
959
+ "%d samples were not scored (caption or unknown task). "
960
+ "These consumed GPU time but produced no metrics.",
961
+ _skipped,
962
+ )
963
+
964
+ results = {}
965
+
966
+ if task_preds["detection"] and task_gts["detection"]:
967
+ thresholds = (0.5, 1.0, 2.0, 4.0)
968
+ global_counts = {t: {"tp": 0, "fp": 0, "fn": 0} for t in thresholds}
969
+ for s_preds, s_gts in zip(task_preds["detection"], task_gts["detection"]):
970
+ for t in thresholds:
971
+ m = calculate_detection_f1(s_preds, s_gts, threshold=t)
972
+ global_counts[t]["tp"] += m["tp"]
973
+ global_counts[t]["fp"] += m["fp"]
974
+ global_counts[t]["fn"] += m["fn"]
975
+ det_results = {}
976
+ for t in thresholds:
977
+ tp, fp, fn = global_counts[t]["tp"], global_counts[t]["fp"], global_counts[t]["fn"]
978
+ p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
979
+ r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
980
+ f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
981
+ det_results[f"F1@{t}m"] = round(f1, 4)
982
+ det_results[f"P@{t}m"] = round(p, 4)
983
+ det_results[f"R@{t}m"] = round(r, 4)
984
+ n_det_total = len(task_preds["detection"])
985
+ n_det_failed = sum(
986
+ 1
987
+ for r in all_outputs
988
+ if r.get("task") == "detection" and r.get("detection_parse_failed", False)
989
+ )
990
+ n_det_partial = sum(
991
+ 1
992
+ for r in all_outputs
993
+ if r.get("task") == "detection" and r.get("detection_partial_parse_suspected", False)
994
+ )
995
+ n_det_empty_negative = sum(
996
+ 1
997
+ for r in all_outputs
998
+ if r.get("task") == "detection" and r.get("detection_is_empty_negative", False)
999
+ )
1000
+ det_results["num_samples"] = n_det_total
1001
+ det_results["parse_fail_count"] = n_det_failed
1002
+ det_results["parse_fail_rate"] = n_det_failed / max(n_det_total, 1)
1003
+ det_results["partial_parse_count"] = n_det_partial
1004
+ det_results["partial_parse_rate"] = n_det_partial / max(n_det_total, 1)
1005
+ det_results["empty_negative_count"] = n_det_empty_negative
1006
+ det_results["empty_negative_rate"] = n_det_empty_negative / max(n_det_total, 1)
1007
+ results["detection"] = det_results
1008
+ logger.info("Detection results (micro-averaged):")
1009
+ for k, v in sorted(results["detection"].items()):
1010
+ if isinstance(v, float):
1011
+ logger.info(" %s: %.4f", k, v)
1012
+
1013
+ if task_preds["lane"] and task_gts["lane"]:
1014
+ try:
1015
+ from openlanev2.evaluation.f_score import LaneEval
1016
+ except ImportError:
1017
+ raise RuntimeError(
1018
+ "openlanev2 is required for lane evaluation but could not be imported. "
1019
+ "Install it with: pip install openlanev2"
1020
+ )
1021
+ _lane_evaluator = LaneEval()
1022
+
1023
+ def _lanes_to_ndarray_list(lanes):
1024
+ out = []
1025
+ for lane in lanes:
1026
+ pts = lane.get("points", [])
1027
+ if not pts:
1028
+ continue
1029
+ rows = []
1030
+ for pt in pts:
1031
+ if isinstance(pt, dict):
1032
+ rows.append(pt.get("world_coords", [0, 0, 0])[:3])
1033
+ else:
1034
+ rows.append(list(pt)[:3])
1035
+ arr = np.array(rows, dtype=np.float64)
1036
+ if arr.shape[0] >= 2:
1037
+ out.append(arr)
1038
+ return out
1039
+
1040
+ stats = []
1041
+ for pl, gl in zip(task_preds["lane"], task_gts["lane"]):
1042
+ pa = _lanes_to_ndarray_list(pl)
1043
+ ga = _lanes_to_ndarray_list(gl)
1044
+ pc = [np.int8(1)] * len(pa)
1045
+ gc = [np.int8(1)] * len(ga)
1046
+ r, p, c, ng, np_, mn = _lane_evaluator.bench(pa, pc, ga, gc)
1047
+ stats.append(np.array([r, p, c, ng, np_, mn]))
1048
+ if stats:
1049
+ s = np.array(stats)
1050
+ tg = np.sum(s[:, 3])
1051
+ tp_sum = np.sum(s[:, 4])
1052
+ lane_r = float(np.sum(s[:, 0]) / max(tg, 1e-6))
1053
+ lane_p = float(np.sum(s[:, 1]) / max(tp_sum, 1e-6))
1054
+ lane_f1 = 2 * lane_p * lane_r / (lane_p + lane_r) if (lane_p + lane_r) > 0 else 0.0
1055
+ else:
1056
+ lane_p = lane_r = lane_f1 = 0.0
1057
+ results["lane"] = {
1058
+ "lane_precision": round(lane_p, 4),
1059
+ "lane_recall": round(lane_r, 4),
1060
+ "lane_f1": round(lane_f1, 4),
1061
+ "method": "openlanev2_f_score",
1062
+ }
1063
+ n_lane_total = len(task_preds["lane"])
1064
+ n_lane_failed = sum(
1065
+ 1
1066
+ for r in all_outputs
1067
+ if r.get("task") == "lane" and r.get("lane_parse_failed", False)
1068
+ )
1069
+ n_lane_partial = sum(
1070
+ 1
1071
+ for r in all_outputs
1072
+ if r.get("task") == "lane" and r.get("lane_partial_parse_suspected", False)
1073
+ )
1074
+ n_lane_empty_negative = sum(
1075
+ 1
1076
+ for r in all_outputs
1077
+ if r.get("task") == "lane" and r.get("lane_is_empty_negative", False)
1078
+ )
1079
+ results["lane"]["num_samples"] = n_lane_total
1080
+ results["lane"]["parse_fail_count"] = n_lane_failed
1081
+ results["lane"]["parse_fail_rate"] = n_lane_failed / max(n_lane_total, 1)
1082
+ results["lane"]["partial_parse_count"] = n_lane_partial
1083
+ results["lane"]["partial_parse_rate"] = n_lane_partial / max(n_lane_total, 1)
1084
+ results["lane"]["empty_negative_count"] = n_lane_empty_negative
1085
+ results["lane"]["empty_negative_rate"] = n_lane_empty_negative / max(n_lane_total, 1)
1086
+
1087
+ n_gt_lane_failed = sum(
1088
+ 1
1089
+ for r in all_outputs
1090
+ if r.get("task") == "lane" and r.get("gt_lane_parse_failed", False)
1091
+ )
1092
+ n_gt_lane_partial = sum(
1093
+ 1
1094
+ for r in all_outputs
1095
+ if r.get("task") == "lane" and r.get("gt_lane_partial_parse_suspected", False)
1096
+ )
1097
+ n_gt_lane_empty_negative = sum(
1098
+ 1
1099
+ for r in all_outputs
1100
+ if r.get("task") == "lane" and r.get("gt_lane_is_empty_negative", False)
1101
+ )
1102
+ results["lane"]["gt_parse_fail_count"] = n_gt_lane_failed
1103
+ results["lane"]["gt_parse_fail_rate"] = n_gt_lane_failed / max(n_lane_total, 1)
1104
+ results["lane"]["gt_partial_parse_count"] = n_gt_lane_partial
1105
+ results["lane"]["gt_partial_parse_rate"] = n_gt_lane_partial / max(n_lane_total, 1)
1106
+ results["lane"]["gt_empty_negative_count"] = n_gt_lane_empty_negative
1107
+ results["lane"]["gt_empty_negative_rate"] = n_gt_lane_empty_negative / max(n_lane_total, 1)
1108
+
1109
+ logger.info("Lane results (OpenLane-V2 official F-Score):")
1110
+ for k, v in sorted(results["lane"].items()):
1111
+ if isinstance(v, float):
1112
+ logger.info(" %s: %.4f", k, v)
1113
+ else:
1114
+ logger.info(" %s: %s", k, v)
1115
+
1116
+ if task_preds["planning"] and task_gts["planning"]:
1117
+ results["planning"] = calculate_planning_metrics(
1118
+ task_preds["planning"], task_gts["planning"],
1119
+ )
1120
+ n_plan_total = len(task_preds["planning"])
1121
+ n_plan_failed = sum(
1122
+ 1
1123
+ for r in all_outputs
1124
+ if r.get("task") == "planning" and r.get("planning_parse_failed", False)
1125
+ )
1126
+ n_plan_vap_complete = sum(
1127
+ 1
1128
+ for r in all_outputs
1129
+ if r.get("task") == "planning" and r.get("planning_vap_complete", False)
1130
+ )
1131
+ n_plan_missing_va = sum(
1132
+ 1
1133
+ for r in all_outputs
1134
+ if r.get("task") == "planning"
1135
+ and (not r.get("planning_parse_failed", False))
1136
+ and (not r.get("planning_vap_complete", True))
1137
+ )
1138
+ results["planning"]["num_samples"] = n_plan_total
1139
+ results["planning"]["parse_fail_count"] = n_plan_failed
1140
+ results["planning"]["parse_fail_rate"] = (
1141
+ n_plan_failed / max(n_plan_total, 1)
1142
+ )
1143
+ results["planning"]["vap_complete_count"] = n_plan_vap_complete
1144
+ results["planning"]["vap_complete_rate"] = (
1145
+ n_plan_vap_complete / max(n_plan_total, 1)
1146
+ )
1147
+ results["planning"]["missing_velocity_or_acceleration_count"] = n_plan_missing_va
1148
+ results["planning"]["missing_velocity_or_acceleration_rate"] = (
1149
+ n_plan_missing_va / max(n_plan_total, 1)
1150
+ )
1151
+ logger.info("Planning results:")
1152
+ for k, v in sorted(results["planning"].items()):
1153
+ if isinstance(v, float):
1154
+ logger.info(" %s: %.4f", k, v)
1155
+ else:
1156
+ logger.info(" %s: %s", k, v)
1157
+
1158
+ output_path = args.output_json
1159
+ if output_path is None:
1160
+ ckpt_dir = Path(args.checkpoint).parent
1161
+ output_path = str(ckpt_dir / "eval_results.json")
1162
+
1163
+ with open(output_path, "w") as f:
1164
+ json.dump({
1165
+ "metrics": results,
1166
+ "num_samples": sample_count,
1167
+ "args": vars(args),
1168
+ "predictions": all_outputs[:100],
1169
+ }, f, indent=2, ensure_ascii=False)
1170
+ logger.info("Results saved to %s", output_path)
1171
+
1172
+
1173
# Script entry point: run the full evaluation when invoked directly.
if __name__ == "__main__":
    main()
1175
+
extract_streampetr_tokens.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Pre-extract frozen StreamPETR detection tokens for offline Atlas training.
3
+
4
+ Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
5
+ For online training (default), use: bash scripts/train_no_caption_baseline.sh
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ if os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower() not in ("1", "true", "yes"):
12
+ print(
13
+ "ERROR: This is an OFFLINE token extraction script.\n"
14
+ "It is isolated by default to prevent accidental use.\n"
15
+ "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
16
+ "For online training use: bash scripts/train_no_caption_baseline.sh",
17
+ file=sys.stderr,
18
+ )
19
+ sys.exit(1)
20
+
21
+ import argparse
22
+ import json
23
+ import re
24
+ import time
25
+ from collections import defaultdict
26
+ from pathlib import Path
27
+
28
+ import torch
29
+ import numpy as np
30
+
31
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
32
+
33
+ from src.model.streampetr_adapter import extract_streampetr_topk_tokens
34
+
35
+
36
def parse_args():
    """Parse command-line arguments for offline StreamPETR token extraction.

    Returns:
        argparse.Namespace with the extraction configuration.
    """
    p = argparse.ArgumentParser(
        description="Pre-extract frozen StreamPETR detection tokens for offline Atlas training."
    )
    p.add_argument("--streampetr_config", required=True,
                   help="Path to the StreamPETR mmcv config file.")
    p.add_argument("--streampetr_ckpt", required=True,
                   help="Path to the StreamPETR checkpoint (.pth).")
    p.add_argument("--data_json", required=True,
                   help="Comma-separated list of Atlas data JSON files.")
    p.add_argument("--data_root", default="/home/guoyuanbo/autodl-tmp/data/nuscenes",
                   help="nuScenes dataset root used for calibration lookup and "
                        "resolving relative image paths.")
    p.add_argument("--output_dir", required=True,
                   help="Directory where per-sample token .pt files are written.")
    p.add_argument("--topk", type=int, default=256,
                   help="Number of top-K detection tokens to keep per sample.")
    p.add_argument("--image_path_remap", default=None,
                   help="Comma-separated old=new path-prefix remappings applied "
                        "to image paths before loading.")
    p.add_argument("--shard_id", type=int, default=0,
                   help="Index of this shard for multi-process extraction.")
    p.add_argument("--num_shards", type=int, default=1,
                   help="Total number of parallel shards.")
    return p.parse_args()
48
+
49
+
50
def load_streampetr(config_path, ckpt_path, device):
    """Build a frozen StreamPETR model from a config and checkpoint.

    The vendored StreamPETR repo under external/StreamPETR is added to
    sys.path so its mmdet3d plugin modules are importable. The returned
    model is in eval mode on *device* with every parameter frozen.
    """
    vendored_repo = str(Path(__file__).resolve().parent / "external" / "StreamPETR")
    if vendored_repo not in sys.path:
        sys.path.insert(0, vendored_repo)
    # Importing the plugin registers StreamPETR's custom modules.
    import projects.mmdet3d_plugin  # noqa: F401
    from mmcv import Config
    from mmdet3d.models import build_model
    from mmcv.runner import load_checkpoint

    cfg = Config.fromfile(config_path)
    net = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(net, ckpt_path, map_location="cpu")
    net.eval()
    net.to(device)
    # Freeze: this model is used purely as a frozen feature extractor.
    for weight in net.parameters():
        weight.requires_grad_(False)
    return net
67
+
68
+
69
+ def _parse_openlane_scene_timestamp(item):
70
+ sample_id = str(item.get("id", ""))
71
+ m = re.match(r"openlane_subsetB_(?:train|val)_(.+?)_(\d+)$", sample_id)
72
+ if m is not None:
73
+ return f"openlane_{m.group(1)}", int(m.group(2))
74
+
75
+ image_paths = item.get("image_paths", [])
76
+ if image_paths:
77
+ p0 = str(image_paths[0]).replace("\\", "/")
78
+ m2 = re.search(r"/(?:train|val)/([^/]+)/image/[^/]+/(\d+)\.(?:jpg|jpeg|png)$", p0, flags=re.IGNORECASE)
79
+ if m2 is not None:
80
+ return f"openlane_{m2.group(1)}", int(m2.group(2))
81
+
82
+ return None, None
83
+
84
+
85
def build_scene_order(data_items, data_root):
    """Group dataset items into temporally ordered scenes.

    nuScenes items are grouped via sample.json metadata under *data_root*;
    OpenLane items via their id / image path; anything else becomes a
    singleton scene. Returns a list of scenes, each a list of indices into
    *data_items* sorted by timestamp.
    """
    root = Path(data_root)
    sample_file = next(
        (root / version / "sample.json"
         for version in ("v1.0-trainval", "v1.0-mini", "v1.0-test")
         if (root / version / "sample.json").exists()),
        None,
    )

    samples_meta = {}
    if sample_file is None:
        print("WARNING: sample.json not found — using OpenLane/id-based scene ordering where possible")
    else:
        with open(sample_file) as f:
            samples_meta = {s["token"]: s for s in json.load(f)}

    counts = {"nuscenes": 0, "openlane": 0, "unknown": 0}
    scene_map = defaultdict(list)
    for idx, item in enumerate(data_items):
        meta = samples_meta.get(str(item.get("id", "")))
        if meta is not None:
            token = meta.get("scene_token", f"_nus_unknown_{idx}")
            scene_map[token].append((int(meta.get("timestamp", 0)), idx))
            counts["nuscenes"] += 1
            continue

        token, ts = _parse_openlane_scene_timestamp(item)
        if token is not None:
            scene_map[token].append((int(ts), idx))
            counts["openlane"] += 1
        else:
            # No grouping information at all: isolate as its own scene.
            scene_map[f"_unknown_{idx}"].append((0, idx))
            counts["unknown"] += 1

    scenes = [
        [idx for _, idx in sorted(scene_map[token], key=lambda pair: pair[0])]
        for token in sorted(scene_map)
    ]

    total = sum(len(s) for s in scenes)
    print(
        "Scene grouping: %d scenes, %d samples (nuScenes=%d, OpenLane=%d, unknown=%d)"
        % (len(scenes), total, counts["nuscenes"], counts["openlane"], counts["unknown"])
    )
    return scenes
135
+
136
+
137
def load_and_preprocess_images(item, data_root, image_path_remap, streampetr_conf, image_transform):
    """Load, resize/crop, and calibrate all camera images for one sample.

    Args:
        item: Sample dict from the data JSON; must contain "image_paths",
            and may carry "sensor", "pose", and "timestamp" fallbacks for
            samples with no nuScenes calibration records.
        data_root: Root directory used to resolve relative image paths.
        image_path_remap: Dict of old-prefix -> new-prefix path rewrites.
        streampetr_conf: Dict with "final_dim" = (fH, fW) target image size
            and an optional "_calibration" bundle from
            load_nuscenes_calibration().
        image_transform: Callable applied to each cropped PIL image
            (presumably ToTensor + Normalize — see caller).

    Returns:
        Dict with batch dimension 1 prepended:
            "pixel_values_det": stacked transformed images,
            "intrinsics_det": per-camera 3x3 intrinsics adjusted for
                resize/crop,
            "lidar2img_det": per-camera 4x4 projection matrices,
            "ego_pose" / "ego_pose_inv" / "timestamp": tensors, or None
                when no source provided them.

    Raises:
        RuntimeError: if neither nuScenes calibration nor the item's own
            "sensor" block yields camera parameters for some image.
    """
    from PIL import Image
    import torchvision.transforms as transforms

    fH, fW = streampetr_conf["final_dim"]
    images = []
    intrinsics_list = []
    # NOTE(review): extrinsics_list is accumulated but never returned —
    # confirm whether it can be dropped.
    extrinsics_list = []
    lidar2img_list = []

    calibration = streampetr_conf.get("_calibration")
    cam_names = [
        'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
        'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT',
    ]

    # Sample-level outputs, filled from the first camera that provides them.
    ego_pose_out = None
    ego_pose_inv_out = None
    timestamp_out = None

    for i, img_path in enumerate(item["image_paths"]):
        # Guess camera name from index, then prefer a name embedded in the
        # path (longest match first so e.g. CAM_FRONT_RIGHT beats CAM_FRONT).
        camera_name = cam_names[i] if i < len(cam_names) else f"CAM_{i}"
        for cam in sorted(cam_names, key=len, reverse=True):
            if cam in str(img_path):
                camera_name = cam
                break

        # Apply at most one prefix remapping (first match wins).
        remapped = img_path
        for old_prefix, new_prefix in image_path_remap.items():
            if remapped.startswith(old_prefix):
                remapped = new_prefix + remapped[len(old_prefix):]
                break

        if os.path.isabs(remapped):
            full_path = remapped
        else:
            full_path = os.path.normpath(os.path.join(data_root, remapped))

        img = Image.open(full_path).convert("RGB")
        W_orig, H_orig = img.size

        # Scale so the image covers (fH, fW), then crop: horizontally
        # centered, vertically keeping the bottom (crop box starts at
        # rH - fH).
        resize = max(fH / H_orig, fW / W_orig)
        rW, rH = int(W_orig * resize), int(H_orig * resize)
        crop_h = rH - fH
        crop_w = max(0, rW - fW) // 2
        if resize != 1.0:
            img = img.resize((rW, rH), Image.BILINEAR)
        img = img.crop((crop_w, crop_h, crop_w + fW, crop_h + fH))

        K = None
        E = None
        ep_rec = None
        sd_rec = None

        if calibration is not None:
            # Locate the nuScenes sample_data record for this image by
            # filename, trying the raw path, a normalized path, and the
            # path suffix starting at samples/ or sweeps/.
            norm_path = str(img_path).replace("\\", "/").lstrip("./")
            sd_rec = None
            for cand in [img_path, norm_path]:
                sd_rec = calibration["sample_data_by_filename"].get(cand)
                if sd_rec:
                    break
            if sd_rec is None:
                for key in ("samples/", "sweeps/"):
                    if key in norm_path:
                        sd_rec = calibration["sample_data_by_filename"].get(norm_path[norm_path.index(key):])
                        if sd_rec:
                            break

            if sd_rec is not None:
                cs = calibration["calibrated_sensor_by_token"].get(sd_rec.get("calibrated_sensor_token"))
                ep_rec = calibration["ego_pose_by_token"].get(sd_rec.get("ego_pose_token"))
                if cs is not None:
                    K = np.array(cs["camera_intrinsic"], dtype=np.float32)
                    q = cs["rotation"]
                    t = cs["translation"]
                    # Quaternion (w, x, y, z) -> 3x3 rotation matrix.
                    w, x, y, z = q
                    R = np.array([
                        [1 - 2*(y*y + z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                        [2*(x*y + z*w), 1 - 2*(x*x + z*z), 2*(y*z - x*w)],
                        [2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x*x + y*y)],
                    ], dtype=np.float32)
                    # Camera-to-ego extrinsic as a 4x4 homogeneous matrix.
                    E = np.eye(4, dtype=np.float32)
                    E[:3, :3] = R
                    E[:3, 3] = np.array(t, dtype=np.float32)

                # First camera with an ego_pose record supplies the
                # sample-level ego pose (and its inverse).
                if ego_pose_out is None and ep_rec is not None:
                    q_ep = ep_rec.get("rotation")
                    t_ep = ep_rec.get("translation")
                    if q_ep is not None and t_ep is not None:
                        w, x, y, z = q_ep
                        R_ep = np.array([
                            [1 - 2*(y*y + z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                            [2*(x*y + z*w), 1 - 2*(x*x + z*z), 2*(y*z - x*w)],
                            [2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x*x + y*y)],
                        ], dtype=np.float32)
                        ego_m = np.eye(4, dtype=np.float32)
                        ego_m[:3, :3] = R_ep
                        ego_m[:3, 3] = np.array(t_ep, dtype=np.float32)
                        ego_pose_out = torch.tensor(ego_m, dtype=torch.float32)
                        try:
                            ego_pose_inv_out = torch.tensor(np.linalg.inv(ego_m), dtype=torch.float32)
                        except Exception:
                            # Singular pose matrix: leave the inverse unset.
                            ego_pose_inv_out = None

                # nuScenes timestamps are in microseconds; convert to
                # seconds (* 1e-6).
                if timestamp_out is None and sd_rec is not None:
                    ts = sd_rec.get("timestamp")
                    if ts is not None:
                        timestamp_out = torch.tensor(float(ts) * 1e-6, dtype=torch.float32)

        # Fallback 1: camera parameters embedded in the item itself
        # (used when nuScenes calibration lookup found nothing).
        if K is None or E is None:
            sensor = (item or {}).get("sensor", None) if isinstance(item, dict) else None
            if isinstance(sensor, dict) and camera_name in sensor:
                cam_s = sensor[camera_name]
                K = np.array(cam_s["intrinsic"]["K"], dtype=np.float32)
                R = np.array(cam_s["extrinsic"]["rotation"], dtype=np.float32)
                t = np.array(cam_s["extrinsic"]["translation"], dtype=np.float32)
                E = np.eye(4, dtype=np.float32)
                E[:3, :3] = R
                E[:3, 3] = t

        if K is None or E is None:
            raise RuntimeError(f"no camera params for {img_path} (camera={camera_name})")

        # Adjust intrinsics for the resize and crop applied above.
        K_adj = K.copy()
        K_adj[0, 0] *= resize
        K_adj[1, 1] *= resize
        K_adj[0, 2] = K_adj[0, 2] * resize - crop_w
        K_adj[1, 2] = K_adj[1, 2] * resize - crop_h

        images.append(image_transform(img))
        intrinsics_list.append(torch.tensor(K_adj, dtype=torch.float32))
        extrinsics_list.append(torch.tensor(E, dtype=torch.float32))

        cam2ego = E.astype(np.float32)
        ego2cam = np.linalg.inv(cam2ego)
        K4 = np.eye(4, dtype=np.float32)
        K4[:3, :3] = K_adj.astype(np.float32)

        # Quaternion (w, x, y, z) -> rotation matrix (local helper).
        def _qR(qwxyz):
            ww, xx, yy, zz = qwxyz
            return np.array([
                [1-2*(yy*yy+zz*zz), 2*(xx*yy-zz*ww), 2*(xx*zz+yy*ww)],
                [2*(xx*yy+zz*ww), 1-2*(xx*xx+zz*zz), 2*(yy*zz-xx*ww)],
                [2*(xx*zz-yy*ww), 2*(yy*zz+xx*ww), 1-2*(xx*xx+yy*yy)],
            ], dtype=np.float32)

        # Rotation + translation -> 4x4 homogeneous transform.
        def _T(Rm, tv):
            T = np.eye(4, dtype=np.float32)
            T[:3, :3] = Rm
            T[:3, 3] = np.array(tv, dtype=np.float32)
            return T

        # Default projection assumes points already live in the camera's
        # ego frame; refined below when full lidar calibration exists.
        lidar2img_mat = K4 @ ego2cam
        if calibration is not None and sd_rec is not None and ep_rec is not None:
            sample_tk = sd_rec.get("sample_token")
            if sample_tk:
                # Chain lidar -> lidar ego -> global -> camera ego -> camera,
                # accounting for the ego motion between the lidar and camera
                # capture times.
                ego2global_c = _T(_qR(ep_rec["rotation"]), ep_rec["translation"])
                global2ego_c = np.linalg.inv(ego2global_c)
                lidar_sd = calibration.get("lidar_sd_by_sample_token", {}).get(str(sample_tk))
                if lidar_sd is not None:
                    lidar_cs = calibration["calibrated_sensor_by_token"].get(lidar_sd.get("calibrated_sensor_token"))
                    lidar_ep = calibration["ego_pose_by_token"].get(lidar_sd.get("ego_pose_token"))
                    if lidar_cs is not None and lidar_ep is not None:
                        lidar2ego = _T(_qR(lidar_cs["rotation"]), lidar_cs["translation"])
                        ego2global_l = _T(_qR(lidar_ep["rotation"]), lidar_ep["translation"])
                        lidar2cam = ego2cam @ global2ego_c @ ego2global_l @ lidar2ego
                        lidar2img_mat = K4 @ lidar2cam

        lidar2img_list.append(torch.tensor(lidar2img_mat, dtype=torch.float32))

    # Fallback: if nuScenes calibration lookup failed (e.g. OpenLane samples),
    # recover ego_pose from item["pose"] and timestamp from item["timestamp"].
    if ego_pose_out is None and isinstance(item, dict):
        pose_data = item.get("pose", None)
        if isinstance(pose_data, dict):
            try:
                rot_raw = pose_data.get("rotation", None)
                t_p = pose_data.get("translation", None)
                if rot_raw is not None and t_p is not None:
                    # Rotation may be a 3x3 matrix or a (w, x, y, z) quaternion.
                    arr = np.array(rot_raw, dtype=np.float32)
                    if arr.shape == (3, 3):
                        R_p = arr
                    elif arr.shape == (4,):
                        w, x, y, z = arr
                        R_p = np.array([
                            [1-2*(y*y+z*z), 2*(x*y-z*w), 2*(x*z+y*w)],
                            [2*(x*y+z*w), 1-2*(x*x+z*z), 2*(y*z-x*w)],
                            [2*(x*z-y*w), 2*(y*z+x*w), 1-2*(x*x+y*y)],
                        ], dtype=np.float32)
                    else:
                        raise ValueError(f"Unsupported rotation shape: {arr.shape}")
                    T_p = np.eye(4, dtype=np.float32)
                    T_p[:3, :3] = R_p
                    T_p[:3, 3] = np.array(t_p, dtype=np.float32)
                    ego_pose_out = torch.tensor(T_p, dtype=torch.float32)
                    try:
                        ego_pose_inv_out = torch.tensor(np.linalg.inv(T_p), dtype=torch.float32)
                    except Exception:
                        ego_pose_inv_out = None
            except Exception as e:
                # Best effort: a malformed pose block only loses the ego pose.
                print(f"WARNING: Failed to parse item['pose']: {e}")

    if timestamp_out is None and isinstance(item, dict):
        ts_raw = item.get("timestamp", None)
        if ts_raw is not None:
            try:
                # assumes item["timestamp"] is in microseconds like nuScenes
                # — TODO confirm for OpenLane sources.
                timestamp_out = torch.tensor(float(ts_raw) * 1e-6, dtype=torch.float32)
            except Exception:
                pass

    # unsqueeze(0) prepends the batch dimension expected downstream.
    return {
        "pixel_values_det": torch.stack(images).unsqueeze(0),
        "intrinsics_det": torch.stack(intrinsics_list).unsqueeze(0),
        "lidar2img_det": torch.stack(lidar2img_list).unsqueeze(0),
        "ego_pose": ego_pose_out.unsqueeze(0) if ego_pose_out is not None else None,
        "ego_pose_inv": ego_pose_inv_out.unsqueeze(0) if ego_pose_inv_out is not None else None,
        "timestamp": timestamp_out.unsqueeze(0) if timestamp_out is not None else None,
    }
355
+
356
+
357
def load_nuscenes_calibration(data_root):
    """Load nuScenes calibration tables into fast lookup dicts.

    Probes the known metadata versions under *data_root*; if one exists
    and contains all required JSON tables, returns a dict of lookup maps
    keyed by filename/token, plus the keyframe LIDAR_TOP sample_data per
    sample token. Returns None when metadata is absent or incomplete.
    """
    root = Path(data_root)
    version_dir = next(
        (root / version for version in ("v1.0-trainval", "v1.0-mini", "v1.0-test")
         if (root / version).exists()),
        None,
    )
    if version_dir is None:
        return None

    required = ("sample_data.json", "calibrated_sensor.json", "ego_pose.json", "sample.json")
    if any(not (version_dir / name).exists() for name in required):
        return None

    def _read(name):
        with open(version_dir / name) as f:
            return json.load(f)

    sample_data = _read("sample_data.json")
    calibrated_sensor = _read("calibrated_sensor.json")
    ego_pose = _read("ego_pose.json")

    sd_by_fn = {rec["filename"]: rec for rec in sample_data if "filename" in rec}
    cs_by_tok = {rec["token"]: rec for rec in calibrated_sensor}
    ep_by_tok = {rec["token"]: rec for rec in ego_pose}

    # Keyframe LIDAR_TOP record per sample token (first occurrence wins).
    lidar_sd_by_sample = {}
    for rec in sample_data:
        fn = str(rec.get("filename", "")).replace("\\", "/")
        if "/LIDAR_TOP/" in fn and fn.startswith("samples/") and rec.get("is_key_frame"):
            st = rec.get("sample_token")
            if st:
                lidar_sd_by_sample.setdefault(str(st), rec)

    print(f"Calibration loaded: {len(sd_by_fn)} sample_data, {len(cs_by_tok)} cal_sensor, {len(ep_by_tok)} ego_pose, {len(lidar_sd_by_sample)} lidar_kf")
    return {
        "sample_data_by_filename": sd_by_fn,
        "calibrated_sensor_by_token": cs_by_tok,
        "ego_pose_by_token": ep_by_tok,
        "lidar_sd_by_sample_token": lidar_sd_by_sample,
    }
398
+
399
+
400
@torch.no_grad()
def run_streampetr_forward_temporal(model, batch, device, prev_exists_val, scene_token="__extract__"):
    """Run one temporal StreamPETR forward pass for a single frame.

    Builds the per-camera metadata and input dict (substituting identity
    matrices / zeros for any calibration tensors missing from *batch*)
    and drives the model's location / ROI-head / bbox-head pipeline.
    The head updates its internal memory as a side effect; nothing is
    returned. *prev_exists_val* should be 0.0 for the first frame of a
    scene and 1.0 otherwise.
    """
    imgs = batch["pixel_values_det"].to(device)
    B, N = imgs.shape[:2]
    fH, fW = 800, 1600

    # Identity-matrix fallbacks for missing calibration inputs.
    def _eye_per_cam():
        return torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()

    def _eye_per_sample():
        return torch.eye(4, device=device).unsqueeze(0).expand(B, -1, -1).contiguous()

    img_metas = [{
        "pad_shape": [(fH, fW, 3)] * N,
        "img_shape": [(fH, fW, 3)] * N,
        "scene_token": str(scene_token),
    } for _ in range(B)]

    lidar2img_src = batch.get("lidar2img_det")
    if lidar2img_src is not None:
        for b, meta in enumerate(img_metas):
            meta["lidar2img"] = lidar2img_src[b].cpu().numpy()

    data = {
        "img": imgs,
        "img_feats": model.extract_img_feat(imgs, 1),
        "prev_exists": imgs.new_tensor([prev_exists_val]),
    }

    intrinsics_src = batch.get("intrinsics_det")
    if intrinsics_src is not None:
        # Embed the 3x3 intrinsics into homogeneous 4x4 matrices.
        K3 = intrinsics_src.to(device)
        K4 = torch.zeros(B, N, 4, 4, device=device, dtype=K3.dtype)
        K4[:, :, :3, :3] = K3
        K4[:, :, 3, 3] = 1.0
        data["intrinsics"] = K4
    else:
        data["intrinsics"] = _eye_per_cam()

    data["lidar2img"] = lidar2img_src.to(device) if lidar2img_src is not None else _eye_per_cam()

    for key in ("ego_pose", "ego_pose_inv"):
        src = batch.get(key)
        data[key] = src.to(device) if src is not None else _eye_per_sample()

    ts_src = batch.get("timestamp")
    data["timestamp"] = ts_src.to(device) if ts_src is not None else torch.zeros(B, device=device)

    location = model.prepare_location(img_metas, **data)
    roi_outs = model.forward_roi_head(location, **data)
    model.pts_bbox_head(location, img_metas, roi_outs["topk_indexes"], **data)
457
+
458
+
459
def main():
    """Extract frozen StreamPETR top-K detection tokens for every sample.

    Pipeline: parse args -> load & dedup data JSONs -> group samples into
    temporally ordered scenes -> run the frozen model scene-by-scene
    (resetting its temporal memory per scene) -> save per-sample token
    tensors as fp16 .pt files, then write an index.json.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Parse "old=new,old2=new2" into a prefix-remap dict.
    image_path_remap = {}
    if args.image_path_remap:
        for pair in args.image_path_remap.split(","):
            if "=" in pair:
                old, new = pair.split("=", 1)
                image_path_remap[old] = new

    # Concatenate all input JSONs, then deduplicate by sample id
    # (first occurrence wins).
    paths = [p.strip() for p in args.data_json.split(",") if p.strip()]
    data_items_raw = []
    for p in paths:
        with open(p) as f:
            chunk = json.load(f)
        data_items_raw.extend(chunk)

    data_items = []
    seen_ids = set()
    for i, item in enumerate(data_items_raw):
        sid = str(item.get("id", f"__idx_{i}"))
        if sid in seen_ids:
            continue
        seen_ids.add(sid)
        data_items.append(item)
    print(f"Loaded {len(data_items)} unique samples from {len(paths)} file(s) (raw={len(data_items_raw)})")

    calibration = load_nuscenes_calibration(args.data_root)
    all_scenes = build_scene_order(data_items, args.data_root)
    # Shard by whole scenes (not samples) so temporal memory stays valid.
    if args.num_shards > 1:
        scenes = [s for i, s in enumerate(all_scenes) if i % args.num_shards == args.shard_id]
        print(f"Shard {args.shard_id}/{args.num_shards}: {len(scenes)}/{len(all_scenes)} scenes")
    else:
        scenes = all_scenes
    model = load_streampetr(args.streampetr_config, args.streampetr_ckpt, device)

    import torchvision.transforms as transforms
    # ImageNet normalization; final_dim must match the model config
    # — TODO confirm 800x1600 matches the chosen StreamPETR config.
    image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    sp_conf = {"final_dim": (800, 1600), "_calibration": calibration}

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total_samples = sum(len(s) for s in scenes)
    num_saved = 0
    num_existed = 0
    t0 = time.time()

    for scene_idx, scene_indices in enumerate(scenes):
        # Fresh temporal memory at every scene boundary.
        model.pts_bbox_head.reset_memory()
        scene_token = f"scene_{scene_idx}"

        for frame_idx, data_idx in enumerate(scene_indices):
            item = data_items[data_idx]
            sample_id = str(item.get("id", data_idx))
            out_path = output_dir / f"{sample_id}.pt"

            already_done = out_path.exists()

            batch = load_and_preprocess_images(
                item, args.data_root, image_path_remap, sp_conf, image_transform
            )

            # The forward pass runs even for already-saved samples so the
            # temporal memory stays consistent for later frames.
            prev_exists_val = 0.0 if frame_idx == 0 else 1.0
            run_streampetr_forward_temporal(
                model, batch, device, prev_exists_val, scene_token=scene_token
            )

            if not already_done:
                ego_pose_for_ref = batch.get("ego_pose")
                if ego_pose_for_ref is not None:
                    ego_pose_for_ref = ego_pose_for_ref.to(device)
                det_out = extract_streampetr_topk_tokens(
                    model.pts_bbox_head, topk=args.topk, ego_pose=ego_pose_for_ref,
                )
                # Saved in fp16 to halve disk usage.
                torch.save({
                    "detection": det_out["detection"][0].cpu().half(),
                    "detection_ref_points": det_out["detection_ref_points"][0].cpu().half(),
                }, out_path)
                num_saved += 1
            else:
                num_existed += 1

            done = num_saved + num_existed
            if done % 200 == 0:
                elapsed = time.time() - t0
                rate = done / max(elapsed, 1)
                eta = (total_samples - done) / max(rate, 0.01)
                print(f" [{done}/{total_samples}] saved={num_saved} existed={num_existed} "
                      f"{elapsed:.0f}s elapsed, ETA {eta:.0f}s")

    elapsed = time.time() - t0
    print(f"Done. saved={num_saved}, existed={num_existed}, total={total_samples}, time={elapsed:.0f}s")
    print(f"Output: {output_dir}")

    # Index every .pt present in output_dir, including files written by
    # earlier runs or other shards.
    index = {}
    for pt in output_dir.glob("*.pt"):
        index[pt.stem] = pt.name
    with open(output_dir / "index.json", "w") as f:
        json.dump(index, f)
    print(f"Index written: {len(index)} entries")
564
+
565
+
566
# Script entry point: run the extraction when invoked directly.
if __name__ == "__main__":
    main()
568
+
extract_topomlp_tokens.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Pre-extract frozen TopoMLP raw outputs for offline Atlas training.
3
+
4
+ Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
5
+ For online training (default), use: bash scripts/train_no_caption_baseline.sh
6
+
7
+ Saves the 6 tensors that TopoMLPToAtlasMapTokens.forward() needs from the
8
+ last decoder layer. The adapter performs Top-K selection online during
9
+ training; only the frozen TopoMLP forward pass is pre-computed here.
10
+
11
+ Usage (4-GPU parallel, requires ATLAS_ALLOW_OFFLINE=1):
12
+ for i in 0 1 2 3; do
13
+ ATLAS_ALLOW_OFFLINE=1 CUDA_VISIBLE_DEVICES=$i python extract_topomlp_tokens.py \
14
+ --topomlp_config configs/topomlp_atlas_aligned.py \
15
+ --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
16
+ --data_json data/openlane_subsetB_lane_train_4pt.json \
17
+ --output_dir work_dirs/precomputed_map_tokens_offline/train \
18
+ --shard_id $i --num_shards 4 &
19
+ done; wait
20
+ """
21
+
22
+ import os
23
+ import sys
24
+
25
+ if os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower() not in ("1", "true", "yes"):
26
+ print(
27
+ "ERROR: This is an OFFLINE token extraction script.\n"
28
+ "It is isolated by default to prevent accidental use.\n"
29
+ "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
30
+ "For online training use: bash scripts/train_no_caption_baseline.sh",
31
+ file=sys.stderr,
32
+ )
33
+ sys.exit(1)
34
+
35
+ import argparse
36
+ import json
37
+ import time
38
+ from pathlib import Path
39
+
40
+ import numpy as np
41
+ import torch
42
+ from PIL import Image
43
+ import torchvision.transforms as transforms
44
+
45
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
46
+
47
+
48
def _load_nuscenes_calibration(data_root):
    """Index the nuScenes calibration tables needed to build lidar2img matrices.

    Looks for a v1.0-* metadata directory under *data_root* and returns a dict
    of lookup tables, or None when the directory or any required table is
    missing.
    """
    root = Path(data_root)
    version_dir = next(
        (root / v for v in ("v1.0-trainval", "v1.0-mini", "v1.0-test")
         if (root / v).exists()),
        None,
    )
    if version_dir is None:
        return None
    required = ("sample_data.json", "calibrated_sensor.json", "ego_pose.json")
    if any(not (version_dir / name).exists() for name in required):
        return None

    def _read(name):
        with open(version_dir / name) as fh:
            return json.load(fh)

    sample_data = _read("sample_data.json")
    calibrated_sensor = _read("calibrated_sensor.json")
    ego_pose_data = _read("ego_pose.json")

    sd_by_fn = {rec["filename"]: rec for rec in sample_data if "filename" in rec}
    cs_by_tok = {rec["token"]: rec for rec in calibrated_sensor}
    ep_by_tok = {rec["token"]: rec for rec in ego_pose_data}

    # Keep the first keyframe LIDAR_TOP record per sample token; setdefault
    # preserves the earliest occurrence in scan order.
    lidar_sd_by_sample = {}
    for rec in sample_data:
        filename = str(rec.get("filename", "")).replace("\\", "/")
        is_lidar_keyframe = (
            "/LIDAR_TOP/" in filename
            and filename.startswith("samples/")
            and rec.get("is_key_frame")
        )
        if is_lidar_keyframe and rec.get("sample_token"):
            lidar_sd_by_sample.setdefault(str(rec.get("sample_token")), rec)

    print(f"nuScenes calibration: {len(sd_by_fn)} sample_data, {len(cs_by_tok)} cal_sensor, "
          f"{len(ep_by_tok)} ego_pose, {len(lidar_sd_by_sample)} lidar_kf")
    return {
        "sample_data_by_filename": sd_by_fn,
        "calibrated_sensor_by_token": cs_by_tok,
        "ego_pose_by_token": ep_by_tok,
        "lidar_sd_by_sample_token": lidar_sd_by_sample,
    }
85
+
86
+
87
def parse_args():
    """Parse the CLI options for the offline TopoMLP token extraction run."""
    parser = argparse.ArgumentParser()
    # Required model/data inputs.
    for flag in ("--topomlp_config", "--topomlp_ckpt", "--data_json"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--data_root", default="")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--image_path_remap", default=None)
    # Sharding controls for multi-GPU parallel extraction.
    for flag, default in (("--shard_id", 0), ("--num_shards", 1)):
        parser.add_argument(flag, type=int, default=default)
    return parser.parse_args()
98
+
99
+
100
def load_topomlp(config_path, ckpt_path, device):
    """Build TopoMLP from an mmcv config, load weights, freeze all parameters,
    and move the model to *device*.

    Adds the vendored TopoMLP and StreamPETR repos to sys.path so their mmcv
    registry plugins can be imported.
    """
    tp_root = str(Path(__file__).resolve().parent / "external" / "TopoMLP_Repo")
    if tp_root not in sys.path:
        sys.path.insert(0, tp_root)

    sp_root = str(Path(__file__).resolve().parent / "external" / "StreamPETR")
    if sp_root not in sys.path:
        sys.path.insert(0, sp_root)
    # StreamPETR plugin is optional here; TopoMLP can load without it.
    try:
        import projects.mmdet3d_plugin  # noqa: F401
    except ImportError:
        pass

    os.environ["ATLAS_TOPOMLP_MODELS_ONLY"] = "1"
    # Temporarily force mmcv registry registration with force=True so that
    # TopoMLP modules whose names collide with already-registered ones (e.g.
    # from the StreamPETR plugin) do not raise KeyError during import.
    from mmcv.utils import registry as _reg
    _orig = _reg.Registry._register_module
    def _tolerant_register(self, module, module_name=None, force=False):
        return _orig(self, module, module_name=module_name, force=True)
    _reg.Registry._register_module = _tolerant_register
    import projects.topomlp  # noqa: F401
    # Restore the original registration behavior immediately after import.
    _reg.Registry._register_module = _orig

    from mmcv import Config
    from mmdet3d.models import build_model
    from mmcv.runner import load_checkpoint

    cfg = Config.fromfile(config_path)
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, ckpt_path, map_location="cpu")
    model.eval()
    model.to(device)
    # Freeze: this model is used purely as a feature extractor.
    for param in model.parameters():
        param.requires_grad_(False)
    print(f"Loaded frozen TopoMLP from {ckpt_path}")
    return model
135
+
136
+
137
def _quat_to_rot(q):
    """Convert a (w, x, y, z) quaternion to a 3x3 float32 rotation matrix."""
    w, x, y, z = q
    row0 = [1 - 2*(y*y + z*z), 2*(x*y - z*w), 2*(x*z + y*w)]
    row1 = [2*(x*y + z*w), 1 - 2*(x*x + z*z), 2*(y*z - x*w)]
    row2 = [2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x*x + y*y)]
    return np.array([row0, row1, row2], dtype=np.float32)
144
+
145
+
146
def load_and_preprocess_images(item, data_root, image_path_remap, image_transform,
                               calibration=None):
    """Load the sample's six camera images and build per-camera lidar2img matrices.

    Returns (pixel_values, lidar2img):
      - pixel_values: stacked, transformed images with a leading batch dim.
      - lidar2img: (1, N, 4, 4) projection matrices; identity for cameras
        whose calibration cannot be resolved.

    Calibration is taken from the sample's own "sensor" dict when present,
    otherwise from the preloaded nuScenes tables in *calibration*.
    """
    # Target resize resolution (width, height) expected by TopoMLP.
    tW, tH = 1600, 800
    cam_names = [
        "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
        "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
    ]

    images = []
    lidar2img_list = []
    sensor = item.get("sensor", {})

    for i, img_path in enumerate(item["image_paths"]):
        # Infer the camera from the path; longest-name-first so e.g.
        # CAM_FRONT_RIGHT is matched before CAM_FRONT.
        camera_name = cam_names[i] if i < len(cam_names) else f"CAM_{i}"
        for cam in sorted(cam_names, key=len, reverse=True):
            if cam in str(img_path):
                camera_name = cam
                break

        # Optional prefix remapping for relocated datasets.
        remapped = str(img_path)
        if image_path_remap:
            for old_prefix, new_prefix in image_path_remap.items():
                if remapped.startswith(old_prefix):
                    remapped = new_prefix + remapped[len(old_prefix):]
                    break

        if os.path.isabs(remapped):
            full_path = remapped
        elif data_root:
            full_path = os.path.normpath(os.path.join(data_root, remapped))
        else:
            full_path = remapped

        img = Image.open(full_path).convert("RGB")
        W_orig, H_orig = img.size
        # Scale factors used below to adjust the intrinsics to the resize.
        w_scale = tW / W_orig
        h_scale = tH / H_orig
        img = img.resize((tW, tH), Image.BILINEAR)
        images.append(image_transform(img))

        # Preferred calibration source: per-sample "sensor" dict
        # (ego-frame extrinsics; K is the 3x3 intrinsic matrix).
        K, E = None, None
        if isinstance(sensor, dict) and camera_name in sensor:
            cam_s = sensor[camera_name]
            K = np.array(cam_s["intrinsic"]["K"], dtype=np.float32)
            R = np.array(cam_s["extrinsic"]["rotation"], dtype=np.float32)
            t = np.array(cam_s["extrinsic"]["translation"], dtype=np.float32)
            E = np.eye(4, dtype=np.float32)
            E[:3, :3] = R
            E[:3, 3] = t

        # Fallback: look the image up in the preloaded nuScenes tables.
        sd_rec = None
        ep_rec = None
        if K is None and calibration is not None:
            norm_path = str(img_path).replace("\\", "/").lstrip("./")
            for cand in [img_path, norm_path]:
                sd_rec = calibration["sample_data_by_filename"].get(cand)
                if sd_rec:
                    break
            if sd_rec is None:
                # Last resort: match on the samples/ or sweeps/ suffix only.
                for prefix in ("samples/", "sweeps/"):
                    if prefix in norm_path:
                        sd_rec = calibration["sample_data_by_filename"].get(
                            norm_path[norm_path.index(prefix):]
                        )
                        if sd_rec:
                            break
            if sd_rec is not None:
                cs = calibration["calibrated_sensor_by_token"].get(
                    sd_rec.get("calibrated_sensor_token")
                )
                ep_rec = calibration["ego_pose_by_token"].get(
                    sd_rec.get("ego_pose_token")
                )
                if cs is not None:
                    K = np.array(cs["camera_intrinsic"], dtype=np.float32)
                    E = np.eye(4, dtype=np.float32)
                    E[:3, :3] = _quat_to_rot(cs["rotation"])
                    E[:3, 3] = np.array(cs["translation"], dtype=np.float32)

        if K is not None and E is not None:
            # Rescale the intrinsics to the resized image resolution.
            K_adj = K.copy()
            K_adj[0, 0] *= w_scale
            K_adj[0, 2] *= w_scale
            K_adj[1, 1] *= h_scale
            K_adj[1, 2] *= h_scale
            ego2cam = np.linalg.inv(E.astype(np.float32))
            K4 = np.eye(4, dtype=np.float32)
            K4[:3, :3] = K_adj

            # Default: project from the ego frame (no lidar pose available).
            lidar2img_mat = K4 @ ego2cam
            if calibration is not None and sd_rec is not None and ep_rec is not None:
                sample_tk = sd_rec.get("sample_token")
                if sample_tk:
                    ego2global_c = np.eye(4, dtype=np.float32)
                    ego2global_c[:3, :3] = _quat_to_rot(ep_rec["rotation"])
                    ego2global_c[:3, 3] = np.array(ep_rec["translation"], dtype=np.float32)
                    global2ego_c = np.linalg.inv(ego2global_c)

                    lidar_sd = calibration.get("lidar_sd_by_sample_token", {}).get(str(sample_tk))
                    if lidar_sd is not None:
                        lidar_cs = calibration["calibrated_sensor_by_token"].get(
                            lidar_sd.get("calibrated_sensor_token"))
                        lidar_ep = calibration["ego_pose_by_token"].get(
                            lidar_sd.get("ego_pose_token"))
                        if lidar_cs is not None and lidar_ep is not None:
                            # Full chain lidar -> ego(lidar ts) -> global ->
                            # ego(camera ts) -> camera, then apply intrinsics.
                            lidar2ego = np.eye(4, dtype=np.float32)
                            lidar2ego[:3, :3] = _quat_to_rot(lidar_cs["rotation"])
                            lidar2ego[:3, 3] = np.array(lidar_cs["translation"], dtype=np.float32)
                            ego2global_l = np.eye(4, dtype=np.float32)
                            ego2global_l[:3, :3] = _quat_to_rot(lidar_ep["rotation"])
                            ego2global_l[:3, 3] = np.array(lidar_ep["translation"], dtype=np.float32)
                            lidar2cam = ego2cam @ global2ego_c @ ego2global_l @ lidar2ego
                            lidar2img_mat = K4 @ lidar2cam

            lidar2img_list.append(torch.tensor(lidar2img_mat, dtype=torch.float32))
        else:
            # No calibration found for this camera: identity placeholder.
            lidar2img_list.append(torch.eye(4, dtype=torch.float32))

    return torch.stack(images).unsqueeze(0), torch.stack(lidar2img_list).unsqueeze(0)
265
+
266
+
267
@torch.no_grad()
def run_topomlp_forward(model, pixel_values, lidar2img, device):
    """Run the frozen TopoMLP forward pass and return its raw output dict.

    Builds one img_metas dict per batch element with the fixed 800x1600
    input shape, attaching lidar2img matrices when provided.
    """
    imgs = pixel_values.to(device)
    batch_size, num_cams = imgs.shape[:2]
    height, width = 800, 1600
    per_cam_shape = tuple([(height, width, 3)] * num_cams)

    img_metas = []
    for b in range(batch_size):
        meta = {
            "img_shape": per_cam_shape,
            "pad_shape": per_cam_shape,
            "scale_factor": 1.0,
            "te_yolov8": None,
        }
        if lidar2img is not None:
            meta["lidar2img"] = lidar2img[b].cpu().numpy()
        img_metas.append(meta)

    return model.simple_forward(imgs, img_metas)
284
+
285
+
286
def extract_adapter_inputs(outs):
    """Collect the six last-decoder-layer tensors the Atlas map-token adapter
    consumes, as half-precision CPU tensors for batch element 0."""
    key_map = {
        "lc_outs_dec": "lc_outs_dec_list",
        "lc_cls_scores": "all_lc_cls_scores_list",
        "lc_preds": "all_lc_preds_list",
        "lc_outs_dec_o2m": "lc_outs_dec_one2many_list",
        "lc_cls_scores_o2m": "all_lc_cls_scores_one2many_list",
        "lc_preds_o2m": "all_lc_preds_one2many_list",
    }
    # [-1] selects the last decoder layer, [0] the first batch element.
    return {dst: outs[src][-1][0].cpu().half() for dst, src in key_map.items()}
295
+
296
+
297
def main():
    """Shardable extraction loop: run frozen TopoMLP on every sample and save
    the adapter-input tensors to <output_dir>/<sample_id>.pt, then write an
    index.json mapping sample ids to filenames."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Parse "old=new,old2=new2" prefix remap pairs for relocated image roots.
    image_path_remap = {}
    if args.image_path_remap:
        for pair in args.image_path_remap.split(","):
            if "=" in pair:
                old, new = pair.split("=", 1)
                image_path_remap[old] = new

    # --data_json may be a comma-separated list of files; concatenate them.
    paths = [p.strip() for p in args.data_json.split(",") if p.strip()]
    data_items_raw = []
    for p in paths:
        with open(p) as f:
            data_items_raw.extend(json.load(f))

    # Deduplicate by sample id (first occurrence wins); items without an id
    # get a synthetic per-index key so they are all kept.
    data_items = []
    seen_ids = set()
    for i, item in enumerate(data_items_raw):
        sid = str(item.get("id", f"__idx_{i}"))
        if sid in seen_ids:
            continue
        seen_ids.add(sid)
        data_items.append(item)
    print(f"Loaded {len(data_items)} unique samples from {len(paths)} file(s)")

    # Round-robin sharding for multi-GPU parallel runs.
    if args.num_shards > 1:
        data_items = [item for i, item in enumerate(data_items) if i % args.num_shards == args.shard_id]
        print(f"Shard {args.shard_id}/{args.num_shards}: {len(data_items)} samples")

    model = load_topomlp(args.topomlp_config, args.topomlp_ckpt, device)

    calibration = _load_nuscenes_calibration(args.data_root) if args.data_root else None

    # ImageNet normalization applied after resize in load_and_preprocess_images.
    image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(data_items)
    num_saved = 0
    num_existed = 0
    t0 = time.time()

    for idx, item in enumerate(data_items):
        sample_id = str(item.get("id", idx))
        out_path = output_dir / f"{sample_id}.pt"

        # Resume support: skip samples whose output file already exists.
        if out_path.exists():
            num_existed += 1
        else:
            pixel_values, lidar2img = load_and_preprocess_images(
                item, args.data_root, image_path_remap, image_transform,
                calibration=calibration,
            )
            outs = run_topomlp_forward(model, pixel_values, lidar2img, device)
            torch.save(extract_adapter_inputs(outs), out_path)
            num_saved += 1

        # Periodic progress report with a simple throughput-based ETA.
        done = num_saved + num_existed
        if done % 200 == 0:
            elapsed = time.time() - t0
            rate = done / max(elapsed, 1)
            eta = (total - done) / max(rate, 0.01)
            print(f" [{done}/{total}] saved={num_saved} existed={num_existed} "
                  f"{elapsed:.0f}s elapsed, ETA {eta:.0f}s")

    elapsed = time.time() - t0
    print(f"Done. saved={num_saved}, existed={num_existed}, total={total}, time={elapsed:.0f}s")
    print(f"Output: {output_dir}")

    # The index covers every .pt in the directory, so with shards it reflects
    # whatever other shards have written so far as well.
    index = {}
    for pt in output_dir.glob("*.pt"):
        index[pt.stem] = pt.name
    with open(output_dir / "index.json", "w") as f:
        json.dump(index, f)
    print(f"Index written: {len(index)} entries")
378
+
379
+
380
# Script entry point: run the shard's extraction loop when executed directly.
if __name__ == "__main__":
    main()
scripts/eval_checkpoint_offline.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# [OFFLINE fallback] Evaluate using precomputed *_offline visual tokens.
# This is NOT the default. Use scripts/eval_checkpoint.sh for online mode.
# Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
set -e

if [ "${ATLAS_ALLOW_OFFLINE}" != "1" ]; then
  echo "ERROR: This is an OFFLINE fallback script, not the primary online evaluation." >&2
  echo "It is isolated by default to prevent accidental use in experiments." >&2
  echo "If you really need it, set: ATLAS_ALLOW_OFFLINE=1" >&2
  echo "For production evaluation use: bash scripts/eval_checkpoint.sh" >&2
  exit 1
fi

PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$PROJECT_ROOT"

# Positional args: checkpoint and data_json are required; the rest optional.
CHECKPOINT="${1:?Usage: $0 <checkpoint> <data_json> [output_json] [max_samples]}"
DATA_JSON="${2:?Usage: $0 <checkpoint> <data_json> [output_json] [max_samples]}"
OUTPUT_JSON="${3:-}"
MAX_SAMPLES="${4:-0}"
PLANNING_TABLE3_MODE="${PLANNING_TABLE3_MODE:-atlas_base}"

# Build optional flags as a bash array so values containing spaces survive
# quoting (a plain string with unquoted expansion would word-split them).
EXTRA_ARGS=()
if [ -n "$OUTPUT_JSON" ]; then
  EXTRA_ARGS+=(--output_json "$OUTPUT_JSON")
fi
if [ "$MAX_SAMPLES" -gt 0 ] 2>/dev/null; then
  EXTRA_ARGS+=(--max_samples "$MAX_SAMPLES")
fi

python eval_atlas.py \
  --checkpoint "$CHECKPOINT" \
  --llm_model pretrained/vicuna-7b-v1.5 \
  --data_json "$DATA_JSON" \
  --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
  --visual_token_mode offline \
  --planning_table3_mode "$PLANNING_TABLE3_MODE" \
  --precomputed_det_tokens work_dirs/precomputed_det_tokens_offline/val \
  --precomputed_map_tokens work_dirs/precomputed_map_tokens_offline/val \
  --bf16 \
  --batch_size 1 \
  --num_workers 2 \
  "${EXTRA_ARGS[@]}"
scripts/gen_atlas_caption_dashscope.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Atlas Caption 数据生成脚本 - Dashscope 版
3
+
4
+ 与 gen_atlas_caption_qa.py 完全相同的输出格式,
5
+ 支持 --start/--end 指定 keyframe 范围,写入独立文件,最终合并。
6
+
7
+ 模型: qwen-vl-max-latest (Dashscope)
8
+ """
9
+ import asyncio
10
+ import json
11
+ import base64
12
+ import os
13
+ import sys
14
+ import time
15
+ import signal
16
+ from io import BytesIO
17
+ from pathlib import Path
18
+
19
+ try:
20
+ import httpx
21
+ from PIL import Image
22
+ except ImportError:
23
+ print("pip install httpx Pillow")
24
+ sys.exit(1)
25
+
26
+ NUSCENES_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
27
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
28
+
29
+ CAMERAS = [
30
+ "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
31
+ "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
32
+ ]
33
+
34
+ GPT4V_PROMPT = (
35
+ "Describe the current traffic conditions. "
36
+ "If there are traffic lights in the image, describe the status of all the traffic lights, "
37
+ "including any countdowns; if there are none, please do not respond. "
38
+ "If there are traffic signs in the picture, identify and explain each one; "
39
+ "if there are none, no explanation is necessary. "
40
+ "If there are other vehicles in the picture, describe them in more detail. "
41
+ "Please ensure the answer does not exceed 600 words. Answers must be in English."
42
+ )
43
+
44
+ TRAIN_PROMPTS = [
45
+ (
46
+ "There are six images captured by the surround view cameras in driving vehicle. "
47
+ "They are uniformly represented as queries embeddings<query>. "
48
+ "Communicate a narrative of the setting within {camera_name} view image."
49
+ ),
50
+ ]
51
+
52
+ API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
53
+ MODEL = "qwen-vl-max-latest"
54
+
55
+ MAX_CONCURRENCY = 50
56
+ MAX_RETRIES = 3
57
+ RETRY_DELAY = 3
58
+ TIMEOUT = 60
59
+ CHECKPOINT_INTERVAL = 100
60
+
61
+
62
def image_to_base64(path):
    """Load an image and return it as a base64-encoded JPEG string.

    Converts to RGB first so palette/RGBA inputs (e.g. PNG screenshots) do
    not make the JPEG encoder raise "cannot write mode RGBA as JPEG".
    """
    img = Image.open(path).convert("RGB")
    buf = BytesIO()
    img.save(buf, format="JPEG", quality=80)
    return base64.b64encode(buf.getvalue()).decode()
67
+
68
+
69
async def call_api(client, api_key, image_b64, camera_name):
    """POST one captioning request to the Dashscope-compatible endpoint.

    Returns (caption, prompt_tokens, completion_tokens).
    """
    request_body = {
        "model": MODEL,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": f"[{camera_name}] {GPT4V_PROMPT}"},
                {"type": "image_url", "image_url": {
                    "url": f"data:image/jpeg;base64,{image_b64}",
                }},
            ],
        }],
        "max_tokens": 800,
        "temperature": 0.3,
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    response = await client.post(API_URL, json=request_body, headers=auth_headers, timeout=TIMEOUT)
    response.raise_for_status()
    body = response.json()
    caption = body["choices"][0]["message"]["content"].strip()
    token_usage = body.get("usage", {})
    return caption, token_usage.get("prompt_tokens", 0), token_usage.get("completion_tokens", 0)
92
+
93
+
94
async def process_one_view(client, api_key, sample, cam_idx, sem, stats):
    """Caption one camera view of one keyframe; return a QA sample dict or None.

    Retries up to MAX_RETRIES times with escalating delays for timeouts and
    HTTP 429; other HTTP errors fail fast on the last fallback branch.
    Mutates *stats* counters in place.
    """
    cam = CAMERAS[cam_idx]
    img_path = os.path.join(NUSCENES_ROOT, sample["image_paths"][cam_idx])
    if not os.path.exists(img_path):
        stats["skipped"] += 1
        return None

    img_b64 = image_to_base64(img_path)
    train_prompt = TRAIN_PROMPTS[0].format(camera_name=cam)

    for attempt in range(MAX_RETRIES):
        # NOTE(review): the retry sleeps below run while the semaphore is
        # still held, so a backing-off task blocks one concurrency slot for
        # the whole delay — confirm this throttling is intended.
        async with sem:
            try:
                caption, in_tok, out_tok = await call_api(client, api_key, img_b64, cam)
                stats["success"] += 1
                stats["total_in"] += in_tok
                stats["total_out"] += out_tok
                return {
                    "id": sample["id"],
                    "image_paths": sample["image_paths"],
                    "num_map_queries": 0,
                    "task": "caption",
                    "camera": cam,
                    "conversations": [
                        {"from": "human", "value": train_prompt},
                        {"from": "gpt", "value": caption},
                    ],
                }
            except httpx.TimeoutException:
                stats["retries"] += 1
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY * (attempt + 1))
            except httpx.HTTPStatusError as e:
                stats["retries"] += 1
                # 429: rate limited — always back off and retry.
                if e.response.status_code == 429:
                    await asyncio.sleep(RETRY_DELAY * (attempt + 2))
                elif attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    stats["failed"] += 1
                    return None
            except Exception:
                stats["retries"] += 1
                if attempt < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    stats["failed"] += 1
                    return None

    # All retry attempts exhausted without a successful response.
    stats["failed"] += 1
    return None
145
+
146
+
147
def make_ckpt_key(sample_id, cam_idx):
    """Build the checkpoint key for one (sample, camera) pair."""
    return "_".join((str(sample_id), str(cam_idx)))
149
+
150
+
151
def load_checkpoint(path):
    """Return the set of completed keys stored at *path* (empty if absent)."""
    if not os.path.exists(path):
        return set()
    with open(path) as fh:
        return set(json.load(fh))
156
+
157
+
158
def save_checkpoint(path, done_keys):
    """Persist *done_keys* to *path* as a sorted JSON list."""
    payload = json.dumps(sorted(done_keys))
    with open(path, "w") as fh:
        fh.write(payload)
161
+
162
+
163
async def run(split, start, end, dry_run=False, tag="dashscope"):
    """Caption the [start:end) keyframe range of *split* via the Dashscope API.

    Resumable: previously completed (sample, camera) keys are read from the
    checkpoint file and skipped; results and the checkpoint are rewritten
    after every batch of CHECKPOINT_INTERVAL views.
    """
    api_key = os.environ.get("DASHSCOPE_KEY", "")
    if not api_key:
        print("ERROR: set DASHSCOPE_KEY env var", flush=True)
        sys.exit(1)

    data_file = PROJECT_ROOT / f"data/atlas_nuscenes_{split}.json"
    out_file = PROJECT_ROOT / f"data/atlas_caption_{split}_{tag}.json"
    ckpt_file = PROJECT_ROOT / f"data/.caption_{split}_{tag}_checkpoint.json"

    with open(data_file) as f:
        all_samples = json.load(f)

    # Restrict to the requested keyframe range (supports parallel workers).
    all_samples = all_samples[start:end]
    print(f"Range: [{start}:{end}] = {len(all_samples)} keyframes", flush=True)

    done_keys = load_checkpoint(ckpt_file)
    # Reload prior results only when a checkpoint exists, so a fresh run
    # does not pick up stale output from an earlier, unrelated invocation.
    existing_results = []
    if os.path.exists(out_file) and done_keys:
        with open(out_file) as f:
            existing_results = json.load(f)

    # Expand to (sample, camera) work units, skipping completed ones.
    todo = []
    for s in all_samples:
        for cam_idx in range(6):
            key = make_ckpt_key(s["id"], cam_idx)
            if key not in done_keys:
                todo.append((s, cam_idx))

    total = len(todo)
    print(f"Split: {split}, Tag: {tag}", flush=True)
    print(f"Total keyframes: {len(all_samples)}", flush=True)
    print(f"Total views to caption: {len(all_samples) * 6}", flush=True)
    print(f"Already done: {len(done_keys)}", flush=True)
    print(f"To process: {total}", flush=True)
    print(f"Model: {MODEL}", flush=True)
    print(f"Concurrency: {MAX_CONCURRENCY}", flush=True)
    if dry_run:
        print("DRY RUN", flush=True)
        return

    stats = {"success": 0, "failed": 0, "skipped": 0, "retries": 0,
             "total_in": 0, "total_out": 0}
    results = list(existing_results)
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    client = httpx.AsyncClient()

    # SIGINT finishes the current batch and then stops, so the checkpoint
    # written at the end of the batch stays consistent with the results file.
    shutdown = False
    def handle_signal(sig, frame):
        nonlocal shutdown
        shutdown = True
        print("\nGraceful shutdown...", flush=True)
    signal.signal(signal.SIGINT, handle_signal)

    t0 = time.time()
    batch_size = CHECKPOINT_INTERVAL
    for batch_start in range(0, total, batch_size):
        if shutdown:
            break
        batch = todo[batch_start:batch_start + batch_size]
        tasks = [process_one_view(client, api_key, s, ci, sem, stats) for s, ci in batch]
        batch_results = await asyncio.gather(*tasks)

        # Only successful views are recorded; failures stay in todo for a rerun.
        for (s, ci), r in zip(batch, batch_results):
            if r is not None:
                results.append(r)
                done_keys.add(make_ckpt_key(s["id"], ci))

        # Rewrite the full results file plus the checkpoint after each batch.
        with open(out_file, "w") as f:
            json.dump(results, f, ensure_ascii=False)
        save_checkpoint(ckpt_file, done_keys)

        elapsed = time.time() - t0
        done_n = batch_start + len(batch)
        rps = stats["success"] / elapsed if elapsed > 0 else 0
        eta = (total - done_n) / rps / 3600 if rps > 0 else 0
        pct = done_n / total * 100

        print(
            f" [{pct:5.1f}%] {done_n}/{total} | "
            f"ok={stats['success']} fail={stats['failed']} retry={stats['retries']} | "
            f"{rps:.2f} rps | ETA {eta:.1f}h | "
            f"tok: {stats['total_in']/1e6:.1f}M in + {stats['total_out']/1e6:.1f}M out",
            flush=True,
        )

    await client.aclose()

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
    print(f"Results: {len(results)} captions saved to {out_file}", flush=True)
    print(f"Stats: {json.dumps(stats)}", flush=True)
    total_tok = stats["total_in"] + stats["total_out"]
    # Cost estimate uses hard-coded per-1K-token prices (CNY) — presumably the
    # qwen-vl-max rate card at the time of writing; verify before relying on it.
    cost_in = stats["total_in"] / 1000 * 0.003
    cost_out = stats["total_out"] / 1000 * 0.009
    print(f"Total tokens: {total_tok:,} | Cost: ¥{cost_in + cost_out:.1f}", flush=True)
259
+
260
+
261
# CLI entry point: caption a keyframe range of the chosen split.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default="train", choices=["train", "val"])
    parser.add_argument("--start", type=int, required=True, help="Start keyframe index (inclusive)")
    parser.add_argument("--end", type=int, required=True, help="End keyframe index (exclusive)")
    parser.add_argument("--tag", default="dashscope", help="Output file tag")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--concurrency", type=int, default=50)
    args = parser.parse_args()
    # Rebinds the module-level MAX_CONCURRENCY read by run(); this works
    # without a `global` declaration because it executes at module scope.
    MAX_CONCURRENCY = args.concurrency
    asyncio.run(run(args.split, args.start, args.end, args.dry_run, args.tag))
scripts/gen_atlas_caption_qa.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Atlas Caption 数据生成脚本 (论文对齐版)
3
+
4
+ 与论文 Appendix A.3 完全对齐:
5
+ - 每个 keyframe 的 6 个摄像头各自独立生成 caption
6
+ - 使用论文 Table 8 中的 GPT-4V prompt
7
+ - human prompt 使用论文 Figure 5 风格的单视角模板
8
+ - 输出样本显式写入 `task="caption"` 与 `camera`
9
+ - 每个 keyframe 产出 6 条 QA,总计 ~204K 条 (34K x 6)
10
+ - 训练 prompt 与 src/prompting.py 中的 CAPTION_PROMPTS 保持一致
11
+
12
+ 支持: 异步并发、断点续传、自动重试
13
+ """
14
+ import asyncio
15
+ import json
16
+ import base64
17
+ import os
18
+ import sys
19
+ import time
20
+ import signal
21
+ from io import BytesIO
22
+ from pathlib import Path
23
+
24
+ try:
25
+ import httpx
26
+ from PIL import Image
27
+ except ImportError:
28
+ print("pip install httpx Pillow")
29
+ sys.exit(1)
30
+
31
+ NUSCENES_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
32
+ PROJECT_ROOT = Path(__file__).resolve().parent.parent
33
+
34
+ CAMERAS = [
35
+ "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
36
+ "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
37
+ ]
38
+
39
+ GPT4V_PROMPT = (
40
+ "Describe the current traffic conditions. "
41
+ "If there are traffic lights in the image, describe the status of all the traffic lights, "
42
+ "including any countdowns; if there are none, please do not respond. "
43
+ "If there are traffic signs in the picture, identify and explain each one; "
44
+ "if there are none, no explanation is necessary. "
45
+ "If there are other vehicles in the picture, describe them in more detail. "
46
+ "Please ensure the answer does not exceed 600 words. Answers must be in English."
47
+ )
48
+
49
+ TRAIN_PROMPTS = [
50
+ (
51
+ "There are six images captured by the surround view cameras in driving vehicle. "
52
+ "They are uniformly represented as queries embeddings<query>. "
53
+ "Communicate a narrative of the setting within {camera_name} view image."
54
+ ),
55
+ ]
56
+
57
+ API_URL = "https://openrouter.fans/v1/chat/completions"
58
+ MODEL = "Qwen/Qwen3-VL-235B-A22B-Instruct"
59
+
60
+ MAX_CONCURRENCY = 30
61
+ MAX_RETRIES = 3
62
+ RETRY_DELAY = 5
63
+ TIMEOUT = 90
64
+ CHECKPOINT_INTERVAL = 100
65
+
66
+
67
+ def image_to_base64(path):
68
+ img = Image.open(path)
69
+ buf = BytesIO()
70
+ img.save(buf, format="JPEG", quality=80)
71
+ return base64.b64encode(buf.getvalue()).decode()
72
+
73
+
74
+ async def call_api(client, api_key, image_b64, camera_name):
75
+ content = [
76
+ {"type": "text", "text": f"[{camera_name}] {GPT4V_PROMPT}"},
77
+ {"type": "image_url", "image_url": {
78
+ "url": f"data:image/jpeg;base64,{image_b64}",
79
+ }},
80
+ ]
81
+ payload = {
82
+ "model": MODEL,
83
+ "messages": [{"role": "user", "content": content}],
84
+ "max_tokens": 800,
85
+ "temperature": 0.3,
86
+ }
87
+ headers = {
88
+ "Authorization": f"Bearer {api_key}",
89
+ "Content-Type": "application/json",
90
+ }
91
+ resp = await client.post(API_URL, json=payload, headers=headers, timeout=TIMEOUT)
92
+ resp.raise_for_status()
93
+ data = resp.json()
94
+ msg = data["choices"][0]["message"]["content"].strip()
95
+ usage = data.get("usage", {})
96
+ return msg, usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0)
97
+
98
+
99
+ async def process_one_view(client, api_key, sample, cam_idx, sem, stats):
100
+ cam = CAMERAS[cam_idx]
101
+ img_path = os.path.join(NUSCENES_ROOT, sample["image_paths"][cam_idx])
102
+ if not os.path.exists(img_path):
103
+ stats["skipped"] += 1
104
+ return None
105
+
106
+ img_b64 = image_to_base64(img_path)
107
+ train_prompt = TRAIN_PROMPTS[0].format(camera_name=cam)
108
+
109
+ for attempt in range(MAX_RETRIES):
110
+ async with sem:
111
+ try:
112
+ caption, in_tok, out_tok = await call_api(client, api_key, img_b64, cam)
113
+ stats["success"] += 1
114
+ stats["total_in"] += in_tok
115
+ stats["total_out"] += out_tok
116
+ return {
117
+ "id": sample["id"],
118
+ "image_paths": sample["image_paths"],
119
+ "num_map_queries": 0,
120
+ "task": "caption",
121
+ "camera": cam,
122
+ "segment_id": sample.get("segment_id", ""),
123
+ "timestamp": sample.get("timestamp", None),
124
+ "conversations": [
125
+ {"from": "human", "value": train_prompt},
126
+ {"from": "gpt", "value": caption},
127
+ ],
128
+ }
129
+ except httpx.TimeoutException:
130
+ stats["retries"] += 1
131
+ if attempt < MAX_RETRIES - 1:
132
+ await asyncio.sleep(RETRY_DELAY * (attempt + 1))
133
+ except httpx.HTTPStatusError as e:
134
+ stats["retries"] += 1
135
+ if e.response.status_code == 429:
136
+ await asyncio.sleep(RETRY_DELAY * (attempt + 2))
137
+ elif attempt < MAX_RETRIES - 1:
138
+ await asyncio.sleep(RETRY_DELAY)
139
+ else:
140
+ stats["failed"] += 1
141
+ return None
142
+ except Exception:
143
+ stats["retries"] += 1
144
+ if attempt < MAX_RETRIES - 1:
145
+ await asyncio.sleep(RETRY_DELAY)
146
+ else:
147
+ stats["failed"] += 1
148
+ return None
149
+
150
+ stats["failed"] += 1
151
+ return None
152
+
153
+
154
def make_ckpt_key(sample_id, cam_idx):
    """Build the checkpoint key identifying one (sample, camera) pair."""
    return "_".join((str(sample_id), str(cam_idx)))
156
+
157
+
158
def load_checkpoint(path):
    """Return the set of completed keys stored at *path* (empty if absent)."""
    if not os.path.exists(path):
        return set()
    with open(path) as fh:
        return set(json.load(fh))
163
+
164
+
165
def save_checkpoint(path, done_keys):
    """Persist *done_keys* to *path* as a sorted JSON array."""
    ordered = sorted(done_keys)
    with open(path, "w") as fh:
        json.dump(ordered, fh)
168
+
169
+
170
async def run(split, dry_run=False, limit=None):
    """Caption every camera view of every keyframe in *split*, resumably.

    Reads the keyframe list, skips (sample, camera) pairs already recorded in
    the checkpoint file, processes the rest in batches of CHECKPOINT_INTERVAL,
    and rewrites the full results JSON plus the checkpoint after each batch so
    the job can be interrupted (SIGINT) and resumed.
    """
    api_key = os.environ.get("OPENROUTER_KEY", "")
    if not api_key:
        print("ERROR: set OPENROUTER_KEY env var", flush=True)
        sys.exit(1)

    data_file = PROJECT_ROOT / f"data/atlas_nuscenes_{split}.json"
    out_file = PROJECT_ROOT / f"data/atlas_caption_{split}.json"
    ckpt_file = PROJECT_ROOT / f"data/.caption_{split}_checkpoint.json"

    with open(data_file) as f:
        all_samples = json.load(f)

    if limit:
        all_samples = all_samples[:limit]

    # Resume support: previously finished keys and previously saved captions.
    # NOTE(review): if out_file was deleted but the checkpoint kept, done keys
    # will be skipped without their captions being present — verify on resume.
    done_keys = load_checkpoint(ckpt_file)
    existing_results = []
    if os.path.exists(out_file) and done_keys:
        with open(out_file) as f:
            existing_results = json.load(f)

    # Work list: every (keyframe, camera index) pair not yet completed.
    todo = []
    for s in all_samples:
        for cam_idx in range(6):
            key = make_ckpt_key(s["id"], cam_idx)
            if key not in done_keys:
                todo.append((s, cam_idx))

    total = len(todo)
    print(f"Split: {split}", flush=True)
    print(f"Total keyframes: {len(all_samples)}", flush=True)
    print(f"Total views to caption: {len(all_samples) * 6}", flush=True)
    print(f"Already done: {len(done_keys)}", flush=True)
    print(f"To process: {total}", flush=True)
    if dry_run:
        print("DRY RUN", flush=True)
        return

    stats = {"success": 0, "failed": 0, "skipped": 0, "retries": 0,
             "total_in": 0, "total_out": 0}
    results = list(existing_results)
    # MAX_CONCURRENCY is a module-level knob (overridable via --concurrency).
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    client = httpx.AsyncClient()

    # Graceful shutdown: first Ctrl-C finishes the current batch, then stops.
    shutdown = False
    def handle_signal(sig, frame):
        nonlocal shutdown
        shutdown = True
        print("\nGraceful shutdown...", flush=True)
    signal.signal(signal.SIGINT, handle_signal)

    t0 = time.time()
    batch_size = CHECKPOINT_INTERVAL
    for batch_start in range(0, total, batch_size):
        if shutdown:
            break
        batch = todo[batch_start:batch_start + batch_size]
        tasks = [process_one_view(client, api_key, s, ci, sem, stats) for s, ci in batch]
        batch_results = await asyncio.gather(*tasks)

        # Only successful captions advance the checkpoint; failures retry
        # automatically on the next run.
        for (s, ci), r in zip(batch, batch_results):
            if r is not None:
                results.append(r)
                done_keys.add(make_ckpt_key(s["id"], ci))

        # Persist results + checkpoint after every batch (full rewrite).
        with open(out_file, "w") as f:
            json.dump(results, f, ensure_ascii=False)
        save_checkpoint(ckpt_file, done_keys)

        # Progress / throughput / ETA report.
        elapsed = time.time() - t0
        done_n = batch_start + len(batch)
        rps = stats["success"] / elapsed if elapsed > 0 else 0
        eta = (total - done_n) / rps / 3600 if rps > 0 else 0
        pct = done_n / total * 100

        print(
            f"  [{pct:5.1f}%] {done_n}/{total} | "
            f"ok={stats['success']} fail={stats['failed']} retry={stats['retries']} | "
            f"{rps:.2f} rps | ETA {eta:.1f}h | "
            f"tok: {stats['total_in']/1e6:.1f}M in + {stats['total_out']/1e6:.1f}M out",
            flush=True,
        )

    await client.aclose()

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
    print(f"Results: {len(results)} captions saved to {out_file}", flush=True)
    print(f"Stats: {json.dumps(stats)}", flush=True)
    total_tok = stats["total_in"] + stats["total_out"]
    # Rough cost estimate: 40 RMB per 50M tokens (provider pricing assumption).
    cost_rmb = total_tok / 50e6 * 40
    print(f"Total tokens: {total_tok:,} | Est cost: ¥{cost_rmb:.1f}", flush=True)
263
+
264
+
265
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default="train", choices=["train", "val"])
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--concurrency", type=int, default=30)
    args = parser.parse_args()
    # Rebind the module-level concurrency knob before run() builds its semaphore.
    MAX_CONCURRENCY = args.concurrency
    asyncio.run(run(args.split, args.dry_run, args.limit))
scripts/gen_atlas_openlane_subsetB_lane_qa.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Convert OpenLane-V2 subset-B data into Atlas-style lane QA JSON."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Dict, List, Tuple, Optional, Iterable
11
+
12
+
13
# Fixed ordering of the six surround-view cameras; image_paths lists below
# follow this order.
CAM_ORDER = [
    "CAM_FRONT",
    "CAM_FRONT_RIGHT",
    "CAM_FRONT_LEFT",
    "CAM_BACK",
    "CAM_BACK_LEFT",
    "CAM_BACK_RIGHT",
]
21
+
22
+
23
+ def _val_to_bin(value: float, min_val: float, max_val: float, num_bins: int = 1000) -> int:
24
+ v = float(value)
25
+ if v < min_val:
26
+ v = min_val
27
+ if v > max_val:
28
+ v = max_val
29
+ t = (v - min_val) / (max_val - min_val)
30
+ idx = int(round(t * (num_bins - 1)))
31
+ if idx < 0:
32
+ idx = 0
33
+ if idx > (num_bins - 1):
34
+ idx = num_bins - 1
35
+ return idx
36
+
37
+
38
+ def _openlane_to_paper_xy(x_fwd: float, y_left: float) -> Tuple[float, float]:
39
+ return (-float(y_left), float(x_fwd))
40
+
41
+
42
+ def _uniform_sample_points(pts: List[List[float]], k: int) -> List[List[float]]:
43
+ if k <= 0 or len(pts) <= k:
44
+ return pts
45
+ if k == 1:
46
+ # 避免 k-1 = 0 导致的除零错误,返回中点
47
+ mid_idx = len(pts) // 2
48
+ return [pts[mid_idx]]
49
+ n = len(pts)
50
+ out = []
51
+ for i in range(k):
52
+ j = int(round(i * (n - 1) / float(k - 1)))
53
+ out.append(pts[j])
54
+ return out
55
+
56
+
57
+ def _lane_bev_distance(lane: Dict) -> float:
58
+ """Compute BEV distance from ego (origin) to lane centroid for sorting."""
59
+ pts = lane.get("points", [])
60
+ if not pts:
61
+ return float('inf')
62
+ xs, ys = [], []
63
+ for p in pts:
64
+ if isinstance(p, (list, tuple)) and len(p) >= 2:
65
+ xs.append(float(p[0]))
66
+ ys.append(float(p[1]))
67
+ if not xs:
68
+ return float('inf')
69
+ cx, cy = sum(xs) / len(xs), sum(ys) / len(ys)
70
+ return cx * cx + cy * cy
71
+
72
+
73
def _lane_answer_from_centerlines(
    lane_centerline: List[Dict],
    max_lanes: int,
    points_per_lane: int,
    xy_range_m: float = 51.2,
    z_min: float = -5.0,
    z_max: float = 3.0,
) -> str:
    """Serialize lane centerlines into the binned Atlas answer string.

    Lanes are sorted near-to-far by squared centroid distance, optionally
    capped at *max_lanes*, down-sampled to *points_per_lane* points each,
    converted to the paper frame, range-filtered, and quantized to 1000
    bins per axis.
    """
    candidates = sorted(lane_centerline or [], key=_lane_bev_distance)
    if max_lanes > 0:
        candidates = candidates[:max_lanes]

    lane_strs: List[str] = []
    for lane in candidates:
        raw_pts = lane.get("points", [])
        if not isinstance(raw_pts, list):
            continue
        triples = [
            list(map(float, p[:3]))
            for p in raw_pts
            if isinstance(p, (list, tuple)) and len(p) >= 3
        ]
        if not triples:
            continue
        triples = _uniform_sample_points(triples, points_per_lane)

        binned: List[str] = []
        for x, y, z in triples:
            x_p, y_p = _openlane_to_paper_xy(x, y)
            if abs(x_p) > xy_range_m or abs(y_p) > xy_range_m:
                continue  # outside the BEV square
            if z < z_min or z > z_max:
                continue  # outside the height band
            xb = _val_to_bin(x_p, -xy_range_m, xy_range_m, 1000)
            yb = _val_to_bin(y_p, -xy_range_m, xy_range_m, 1000)
            zb = _val_to_bin(z, z_min, z_max, 1000)
            binned.append(f"[{xb}, {yb}, {zb}]")

        if binned:
            lane_strs.append(", ".join(binned))

    if not lane_strs:
        return "No lane centerlines detected within range."
    return "Lane: " + "; ".join(lane_strs) + "."
114
+
115
+
116
def _prompt_lane_qa() -> str:
    """Return the human prompt for the lane QA task.

    Prefers the project prompt sampler (src.prompting); falls back to a
    fixed prompt string when that module cannot be imported.
    """
    # NOTE(review): this prepends to sys.path on every call, so repeated
    # calls accumulate duplicate entries.
    import sys, os
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
    try:
        from src.prompting import sample_prompt
        return sample_prompt("lane")
    except ImportError:
        # Static fallback prompt matching the project's lane-task phrasing.
        return (
            "There are six images captured by the surround view cameras in driving vehicle. "
            "They are uniformly represented as queries embeddings<query>. "
            "Please complete the centerline detection task under the Bird's Eye View (BEV) perspective. "
            "Ensure that the detection range does not exceed 50 meters."
        )
129
+
130
+
131
def iter_info_jsons(root: Path, split: str) -> Iterable[Path]:
    """Yield every per-frame info JSON under <root>/<split>/*/info/.

    Returns an empty list when the split directory does not exist.
    """
    split_dir = root / split
    if not split_dir.exists():
        return []
    return split_dir.glob("*/info/*.json")
136
+
137
+
138
def write_json_list_stream(out_path: Path, items: Iterable[Dict]) -> None:
    """Stream *items* to *out_path* as one JSON array without materializing it.

    Elements are separated by ",\\n" so the file stays line-oriented; parent
    directories are created as needed.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write("[\n")
        for idx, item in enumerate(items):
            if idx:
                fh.write(",\n")
            json.dump(item, fh, ensure_ascii=False)
        fh.write("\n]\n")
149
+
150
+
151
def main() -> None:
    """CLI entry point: convert OpenLane-V2 subset-B info JSONs to lane QA items.

    Streams one Atlas-format sample per frame into --out_json; frames whose
    images are missing on disk are skipped.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--openlane_root", type=str, default="/home/guoyuanbo/autodl-tmp/OpenLane-V2")
    ap.add_argument(
        "--no_absolute_image_paths",
        action="store_true",
        default=False,
    )
    ap.add_argument("--split", type=str, default="train", choices=["train", "val", "test"])
    ap.add_argument("--out_json", type=str, required=True)
    ap.add_argument("--max_samples", type=int, default=0, help="0 means all")
    ap.add_argument("--max_lanes", type=int, default=0,
                    help="Max lanes per sample (0=no limit, paper does not specify a cap)")
    ap.add_argument("--points_per_lane", type=int, default=4)
    ap.add_argument("--include_raw_lane", action="store_true", default=False)
    ap.add_argument("--num_map_queries", type=int, default=256)
    args = ap.parse_args()

    root = Path(args.openlane_root)
    out_path = Path(args.out_json)

    def _gen():
        # Generator of output samples, consumed by write_json_list_stream so
        # the whole dataset is never held in memory.
        n = 0
        for p in iter_info_jsons(root, args.split):
            if args.max_samples and n >= int(args.max_samples):
                break
            try:
                d = json.loads(p.read_text(encoding="utf-8", errors="replace"))
            except Exception:
                # Unparseable info file: skip the frame.
                continue

            sensor = d.get("sensor", None)
            ann = d.get("annotation", {})
            lane_centerline = ann.get("lane_centerline", [])

            # Resolve the six camera image paths, either from the sensor block
            # or, as a fallback, from the directory layout convention.
            image_paths: List[str] = []
            use_absolute = not args.no_absolute_image_paths
            if isinstance(sensor, dict) and all(cam in sensor for cam in CAM_ORDER):
                for cam in CAM_ORDER:
                    rp = str(sensor[cam].get("image_path"))
                    if use_absolute:
                        image_paths.append(str((root / rp).resolve()))
                    else:
                        image_paths.append(rp)
            else:
                # Fallback layout: <split>/<segment>/image/<CAM>/<timestamp>.jpg
                seq_dir = p.parent.parent
                ts = p.stem
                for cam in CAM_ORDER:
                    rp = str((Path(args.split) / seq_dir.name / "image" / cam / f"{ts}.jpg").as_posix())
                    if use_absolute:
                        image_paths.append(str((root / rp).resolve()))
                    else:
                        image_paths.append(rp)

            # Drop the frame if any of the six images is missing on disk.
            if use_absolute:
                missing = [ip for ip in image_paths if not Path(ip).exists()]
            else:
                missing = [ip for ip in image_paths if not (root / ip).exists()]
            if missing:
                continue

            answer = _lane_answer_from_centerlines(
                lane_centerline=lane_centerline,
                max_lanes=int(args.max_lanes),
                points_per_lane=int(args.points_per_lane),
            )

            prompt = _prompt_lane_qa()

            sample_id = f"openlane_subsetB_{args.split}_{d.get('segment_id','seg')}_{d.get('timestamp','ts')}"
            it: Dict = {
                "id": sample_id,
                "image_paths": image_paths,
                "num_map_queries": int(args.num_map_queries),
                "task": "lane",
                "sensor": sensor,
                "pose": d.get("pose", None),
                "timestamp": d.get("timestamp", None),
                "segment_id": d.get("segment_id", None),
                "meta_data": d.get("meta_data", None),
                "conversations": [
                    {"from": "human", "value": prompt},
                    {"from": "gpt", "value": answer},
                ],
            }
            if args.include_raw_lane:
                # Optionally keep the raw centerlines for debugging/eval.
                it["openlane_lane_centerline"] = lane_centerline
            n += 1
            if n % 1000 == 0:
                print(f"[progress] wrote_samples={n}")
            yield it

    print(f"[start] openlane_root={root} split={args.split} out_json={out_path}")
    write_json_list_stream(out_path, _gen())
    print("[done]")
246
+
247
+
248
# Script entry point.
if __name__ == "__main__":
    main()
250
+
251
+
scripts/gen_atlas_planning_qa.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import math
3
+ import argparse
4
+ import json
5
+ import os
6
+ import sys
7
+ from collections import Counter
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple
10
+
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
15
+ from src.prompting import PLANNING_TABLE3_MODES, rewrite_planning_prompt_for_table3
16
+
17
# Height band (metres) kept when binning z values.
Z_MIN, Z_MAX = -5.0, 3.0
# Quantization range for velocity / acceleration components.
VEL_ACC_RANGE = (-50.0, 50.0)
# BEV quantization range for x/y waypoints (metres).
XY_RANGE = (-51.2, 51.2)
NUM_BINS = 1000
# Planning horizon: NUM_WAYPOINTS waypoints at WAYPOINT_DT-second spacing (3 s).
WAYPOINT_DT = 0.5
NUM_WAYPOINTS = 6
# Official UniAD get_sdc_planning_label() uses the terminal lateral offset
# (RIGHT if x >= 2, LEFT if x <= -2, else FORWARD). Our waypoints are already
# in Atlas paper frame, where x is lateral-right and y is forward.
UNIAD_COMMAND_X_THRESHOLD = 2.0

# nuScenes surround cameras in the order expected by the Atlas data format.
CAMERA_NAMES = [
    'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
    'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
]
32
+
33
+
34
def _val_to_bin(value: float, min_val: float, max_val: float, num_bins: int = NUM_BINS) -> int:
    """Quantize *value* into [0, num_bins - 1] uniformly over [min_val, max_val].

    Out-of-range inputs are clipped to the bounds before binning.
    """
    clipped = float(np.clip(value, min_val, max_val))
    frac = (clipped - min_val) / (max_val - min_val)
    raw_idx = int(round(frac * (num_bins - 1)))
    return int(np.clip(raw_idx, 0, num_bins - 1))
39
+
40
+
41
+ def _nuscenes_to_paper_xy(x_fwd: float, y_left: float) -> Tuple[float, float]:
42
+ return (-float(y_left), float(x_fwd))
43
+
44
+
45
def _derive_uniad_style_command(
    waypoints: List[List[float]],
    lateral_threshold: float = UNIAD_COMMAND_X_THRESHOLD,
) -> str:
    """Derive a 3-way planning command from future GT waypoints.

    Mirrors the semantics of UniAD's `get_sdc_planning_label()`: the
    lateral (x) offset of the last finite waypoint decides between
    "turn right" / "turn left" / "go straight". Falls back to
    "go straight" when no waypoint is usable.
    """
    finite_xy: List[Tuple[float, float]] = []
    for wp in waypoints:
        if not isinstance(wp, (list, tuple)) or len(wp) < 2:
            continue
        wx, wy = float(wp[0]), float(wp[1])
        if np.isfinite(wx) and np.isfinite(wy):
            finite_xy.append((wx, wy))

    if not finite_xy:
        return "go straight"

    terminal_x = float(finite_xy[-1][0])
    if terminal_x >= lateral_threshold:
        return "turn right"
    if terminal_x <= -lateral_threshold:
        return "turn left"
    return "go straight"
73
+
74
+
75
def _compute_velocity(nusc, sample) -> Tuple[float, float]:
    """Estimate ego velocity (vx, vy) in the ego frame at *sample*.

    Finite difference between the current LIDAR_TOP ego pose and the previous
    sweep's ego pose, rotated into the current ego frame (nuScenes axes).
    Returns (0.0, 0.0) when there is no previous sweep or any lookup fails.
    """
    try:
        lidar_token = sample['data']['LIDAR_TOP']
        lidar_data = nusc.get('sample_data', lidar_token)
        from pyquaternion import Quaternion

        ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
        ego_t = np.array(ego_pose['translation'])
        ego_q = Quaternion(ego_pose['rotation'])

        prev_token = lidar_data.get('prev', '')
        if prev_token:
            prev_data = nusc.get('sample_data', prev_token)
            prev_ego = nusc.get('ego_pose', prev_data['ego_pose_token'])
            prev_t = np.array(prev_ego['translation'])
            # Timestamps are microseconds; convert the delta to seconds.
            dt = (lidar_data['timestamp'] - prev_data['timestamp']) * 1e-6
            if dt > 0:
                vel_global = (ego_t - prev_t) / dt
                # Rotate the global-frame velocity into the ego frame.
                vel_ego = ego_q.inverse.rotate(vel_global)
                return float(vel_ego[0]), float(vel_ego[1])
    except Exception:
        # Best-effort: any malformed record falls back to zero velocity.
        pass
    return 0.0, 0.0
98
+
99
+
100
def _compute_acceleration(nusc, sample) -> Tuple[float, float]:
    """Estimate ego acceleration (ax, ay) in the ego frame at *sample*.

    Second finite difference using the two preceding LIDAR_TOP sweeps:
    velocity over [prev, now] minus velocity over [prev2, prev], divided by
    the most recent interval, rotated into the current ego frame. Returns
    (0.0, 0.0) when fewer than two previous sweeps exist or on any error.
    """
    try:
        lidar_token = sample['data']['LIDAR_TOP']
        lidar_data = nusc.get('sample_data', lidar_token)
        from pyquaternion import Quaternion

        ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
        ego_q = Quaternion(ego_pose['rotation'])

        prev_token = lidar_data.get('prev', '')
        if not prev_token:
            return 0.0, 0.0
        prev_data = nusc.get('sample_data', prev_token)
        # Microseconds -> seconds for both intervals.
        dt1 = (lidar_data['timestamp'] - prev_data['timestamp']) * 1e-6
        if dt1 <= 0:
            return 0.0, 0.0

        prev2_token = prev_data.get('prev', '')
        if not prev2_token:
            return 0.0, 0.0
        prev2_data = nusc.get('sample_data', prev2_token)
        dt2 = (prev_data['timestamp'] - prev2_data['timestamp']) * 1e-6
        if dt2 <= 0:
            return 0.0, 0.0

        def _ego_vel(sd1, sd2, dt_val):
            # Global-frame velocity between two sample_data records.
            e1 = nusc.get('ego_pose', sd1['ego_pose_token'])
            e2 = nusc.get('ego_pose', sd2['ego_pose_token'])
            t1 = np.array(e1['translation'])
            t2 = np.array(e2['translation'])
            return (t1 - t2) / dt_val

        v1_global = _ego_vel(lidar_data, prev_data, dt1)
        v0_global = _ego_vel(prev_data, prev2_data, dt2)
        acc_global = (v1_global - v0_global) / dt1
        acc_ego = ego_q.inverse.rotate(acc_global)
        return float(acc_ego[0]), float(acc_ego[1])
    except Exception:
        # Best-effort: fall back to zero acceleration on any lookup failure.
        return 0.0, 0.0
139
+
140
+
141
def _get_future_waypoints(nusc, sample) -> Optional[List[List[float]]]:
    """Interpolate NUM_WAYPOINTS future ego positions at WAYPOINT_DT spacing.

    Walks the LIDAR_TOP sweep chain forward, linearly interpolates the global
    ego translation at each target timestamp, transforms into the current ego
    frame, and converts to the paper frame (x right, y forward). Returns None
    when the recorded future does not cover the full horizon or on any error.
    """
    try:
        from pyquaternion import Quaternion

        lidar_token = sample['data']['LIDAR_TOP']
        lidar_data = nusc.get('sample_data', lidar_token)
        ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
        ego_t = np.array(ego_pose['translation'])
        ego_q = Quaternion(ego_pose['rotation'])

        current_ts = lidar_data['timestamp']
        # Target timestamps (microseconds): t0 + 0.5s, 1.0s, ..., 3.0s.
        target_times = [current_ts + int(WAYPOINT_DT * (i + 1) * 1e6) for i in range(NUM_WAYPOINTS)]

        # Collect sweeps until ~1 s past the last target (or chain end).
        all_sd = []
        sd_token = lidar_token
        while sd_token:
            sd = nusc.get('sample_data', sd_token)
            all_sd.append(sd)
            sd_token = sd.get('next', '')
            if not sd_token:
                break
            if sd['timestamp'] > target_times[-1] + 1e6:
                break

        if len(all_sd) < 2:
            return None

        timestamps = np.array([s['timestamp'] for s in all_sd])
        poses = []
        for s in all_sd:
            ep = nusc.get('ego_pose', s['ego_pose_token'])
            poses.append(np.array(ep['translation']))
        poses = np.array(poses)

        waypoints = []
        for tt in target_times:
            if tt > timestamps[-1] or tt < timestamps[0]:
                # Horizon not fully recorded -> reject the whole sample.
                return None
            # Index of the segment [idx, idx+1] bracketing tt, clamped valid.
            idx = np.searchsorted(timestamps, tt, side='right') - 1
            idx = max(0, min(idx, len(timestamps) - 2))
            dt_seg = timestamps[idx + 1] - timestamps[idx]
            if dt_seg <= 0:
                return None
            alpha = (tt - timestamps[idx]) / dt_seg
            # Linear interpolation of the global translation at tt.
            pos_global = poses[idx] * (1 - alpha) + poses[idx + 1] * alpha
            pos_ego = ego_q.inverse.rotate(pos_global - ego_t)
            x_p, y_p = _nuscenes_to_paper_xy(pos_ego[0], pos_ego[1])
            waypoints.append([float(x_p), float(y_p)])

        return waypoints
    except Exception:
        return None
193
+
194
+
195
def _format_planning_answer(
    vx: float, vy: float, ax: float, ay: float,
    waypoints: List[List[float]],
) -> str:
    """Format the GT planning answer: binned speed, acceleration, waypoints."""
    speed_bins = [_val_to_bin(c, *VEL_ACC_RANGE) for c in (vx, vy)]
    accel_bins = [_val_to_bin(c, *VEL_ACC_RANGE) for c in (ax, ay)]
    waypoint_text = ", ".join(
        f"[{_val_to_bin(wp[0], *XY_RANGE)}, {_val_to_bin(wp[1], *XY_RANGE)}]"
        for wp in waypoints
    )
    return (
        f"Ego car speed value:[{speed_bins[0]}, {speed_bins[1]}]. "
        f"Ego car acceleration value:[{accel_bins[0]}, {accel_bins[1]}]. "
        "Based on the ego car speed and acceleration you predicted, "
        f"request the ego car planning waypoint in 3-seconds: {waypoint_text}"
    )
215
+
216
+
217
def _collect_gt_boxes_ego(nusc, sample) -> List[Dict]:
    """Collect all GT annotation boxes at *sample* in the paper frame.

    Box centers are rotated from global into the current ego frame and then
    into the paper frame (x right, y forward); yaw gets the matching +pi/2
    offset so headings stay consistent with the rotated axes.
    """
    from pyquaternion import Quaternion

    lidar_token = sample['data']['LIDAR_TOP']
    lidar_data = nusc.get('sample_data', lidar_token)
    ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
    ego_t = np.array(ego_pose['translation'])
    ego_q = Quaternion(ego_pose['rotation'])

    # NOTE(review): the calibrated-sensor transform below is fetched but never
    # applied — boxes stay in the ego frame, not the lidar frame. Confirm
    # whether that is intentional.
    cs_record = nusc.get('calibrated_sensor', lidar_data['calibrated_sensor_token'])
    cs_t = np.array(cs_record['translation'])
    cs_q = Quaternion(cs_record['rotation'])

    boxes = []
    for ann_token in sample['anns']:
        ann = nusc.get('sample_annotation', ann_token)
        center_global = np.array(ann['translation'])
        center_ego = ego_q.inverse.rotate(center_global - ego_t)
        x_p, y_p = _nuscenes_to_paper_xy(center_ego[0], center_ego[1])
        yaw_global = Quaternion(ann['rotation'])
        yaw_ego = ego_q.inverse * yaw_global
        # _nuscenes_to_paper_xy applies a 90° CCW rotation:
        #   x_paper = -y_ego, y_paper = x_ego
        # Yaw must be rotated by the same +π/2 to stay consistent.
        yaw_angle = float(yaw_ego.yaw_pitch_roll[0]) + math.pi / 2.0
        # size indices follow the nuScenes annotation convention assumed here:
        # [0]=width, [1]=length, [2]=height.
        w = float(ann['size'][0])
        l = float(ann['size'][1])
        h = float(ann['size'][2])
        boxes.append({
            "world_coords": [float(x_p), float(y_p), float(center_ego[2])],
            "w": w,
            "l": l,
            "h": h,
            "yaw": yaw_angle,
            "category": ann['category_name'],
        })
    return boxes
254
+
255
+
256
def _collect_gt_boxes_per_timestep(nusc, sample, num_timesteps=NUM_WAYPOINTS) -> List[List[Dict]]:
    """Collect GT boxes for each future keyframe, transformed to current ego frame.

    ST-P3 protocol: at each future timestep t, collision is checked against
    the actual positions of other agents at time t, not their positions at t=0.
    nuScenes keyframes are ~0.5s apart, matching the waypoint interval.
    """
    from pyquaternion import Quaternion

    # Reference pose: all future boxes are expressed relative to the CURRENT
    # ego pose, so they are directly comparable to predicted waypoints.
    lidar_token = sample['data']['LIDAR_TOP']
    lidar_data = nusc.get('sample_data', lidar_token)
    ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
    ref_ego_t = np.array(ego_pose['translation'])
    ref_ego_q = Quaternion(ego_pose['rotation'])

    per_timestep_boxes: List[List[Dict]] = []
    cur_sample = sample
    for _ in range(num_timesteps):
        next_token = cur_sample.get('next', '')
        if not next_token:
            # Past the end of the scene: repeat the previous timestep's boxes
            # (NOTE: this appends the same list object, not a copy — callers
            # must not mutate entries in place).
            per_timestep_boxes.append(per_timestep_boxes[-1] if per_timestep_boxes else [])
            continue

        cur_sample = nusc.get('sample', next_token)
        boxes = []
        for ann_token in cur_sample['anns']:
            ann = nusc.get('sample_annotation', ann_token)
            center_global = np.array(ann['translation'])
            # Transform the future box into the CURRENT ego frame.
            center_ego = ref_ego_q.inverse.rotate(center_global - ref_ego_t)
            x_p, y_p = _nuscenes_to_paper_xy(center_ego[0], center_ego[1])

            yaw_global = Quaternion(ann['rotation'])
            yaw_ego = ref_ego_q.inverse * yaw_global
            # +π/2 matches the paper-frame axis rotation (see _collect_gt_boxes_ego).
            yaw_angle = float(yaw_ego.yaw_pitch_roll[0]) + math.pi / 2.0

            w = float(ann['size'][0])
            l = float(ann['size'][1])
            h = float(ann['size'][2])
            boxes.append({
                "world_coords": [float(x_p), float(y_p), float(center_ego[2])],
                "w": w, "l": l, "h": h,
                "yaw": yaw_angle,
                "category": ann['category_name'],
            })
        per_timestep_boxes.append(boxes)

    return per_timestep_boxes
303
+
304
+
305
def process_sample(
    nusc,
    sample_token: str,
    data_root: Path,
    planning_table3_mode: str,
) -> Optional[Dict]:
    """Build one Atlas planning-QA item for the keyframe *sample_token*.

    Gathers the six camera paths, ego velocity/acceleration, 3-second GT
    waypoints, a UniAD-style route command, the prompt/answer pair, and GT
    boxes (current frame and per future timestep). Returns None when any
    camera or the waypoint horizon is missing.

    NOTE(review): the bare `except Exception: return None` silently drops
    samples on any error — consider at least counting/logging failures.
    """
    try:
        from pyquaternion import Quaternion
        from src.prompting import sample_prompt

        sample = nusc.get('sample', sample_token)

        # Relative image paths for the six surround cameras, normalized to '/'.
        image_paths = []
        for cam_name in CAMERA_NAMES:
            if cam_name in sample['data']:
                cam_token = sample['data'][cam_name]
                cam_data = nusc.get('sample_data', cam_token)
                image_paths.append(cam_data['filename'].replace('\\', '/'))

        if len(image_paths) != 6:
            return None

        # Ego dynamics in nuScenes axes, then rotated to the paper frame.
        vx_n, vy_n = _compute_velocity(nusc, sample)
        ax_n, ay_n = _compute_acceleration(nusc, sample)

        vx_p, vy_p = _nuscenes_to_paper_xy(vx_n, vy_n)
        ax_p, ay_p = _nuscenes_to_paper_xy(ax_n, ay_n)

        waypoints = _get_future_waypoints(nusc, sample)
        if waypoints is None:
            return None

        vx_bin = _val_to_bin(vx_p, *VEL_ACC_RANGE)
        vy_bin = _val_to_bin(vy_p, *VEL_ACC_RANGE)
        ax_bin = _val_to_bin(ax_p, *VEL_ACC_RANGE)
        ay_bin = _val_to_bin(ay_p, *VEL_ACC_RANGE)
        route_command = _derive_uniad_style_command(waypoints)

        # Base planning prompt, then rewritten per the requested Table-3 mode.
        prompt = sample_prompt(
            "planning",
            vx_bin=vx_bin, vy_bin=vy_bin,
            ax_bin=ax_bin, ay_bin=ay_bin,
            command=route_command,
        )
        prompt = rewrite_planning_prompt_for_table3(
            prompt,
            mode=planning_table3_mode,
            command=route_command,
            velocity_bins=(vx_bin, vy_bin),
            acceleration_bins=(ax_bin, ay_bin),
        )
        answer = _format_planning_answer(vx_p, vy_p, ax_p, ay_p, waypoints)

        gt_boxes = _collect_gt_boxes_ego(nusc, sample)
        gt_boxes_per_ts = _collect_gt_boxes_per_timestep(nusc, sample)

        item = {
            "id": sample_token,
            "image_paths": image_paths,
            "num_map_queries": 256,
            "task": "planning",
            "segment_id": sample.get("scene_token", ""),
            "timestamp": sample.get("timestamp", None),
            "ego_motion": {
                "velocity": [vx_p, vy_p],
                "acceleration": [ax_p, ay_p],
                "waypoints": waypoints,
            },
            "gt_boxes_3d": gt_boxes,
            "gt_boxes_3d_per_timestep": gt_boxes_per_ts,
            "conversations": [
                {"from": "human", "value": prompt},
                {"from": "gpt", "value": answer},
            ],
            "route_command": route_command,
        }
        return item
    except Exception:
        return None
384
+
385
+
386
def _audit_results(results: List[Dict], planning_table3_mode: str) -> None:
    """Print coverage/distribution audits over the generated planning items.

    Reports route-command coverage and distribution, legacy ego_motion
    command usage, and how many prompts embed the command / ego-state text.
    Purely informational (stdout only).
    """
    total = int(len(results))
    if total == 0:
        print("[AUDIT] No planning samples were generated.")
        return

    route_commands = [item.get("route_command") for item in results]
    # Coverage: non-empty string commands only.
    route_command_coverage = sum(isinstance(cmd, str) and bool(cmd) for cmd in route_commands)
    route_command_dist = Counter(route_commands)
    # Items still carrying the deprecated ego_motion["command"] field.
    legacy_ego_motion_command = sum(
        1
        for item in results
        if isinstance(item.get("ego_motion"), dict) and "command" in item["ego_motion"]
    )

    # Count which prompt variants actually materialized in the output.
    prompt_with_command = 0
    prompt_with_state = 0
    for item in results:
        conv = item.get("conversations", [])
        if not conv:
            continue
        prompt_text = str(conv[0].get("value", ""))
        if "The ego car will " in prompt_text:
            prompt_with_command += 1
        if "The current speed value of the ego car is [" in prompt_text:
            prompt_with_state += 1

    print(
        "[AUDIT] planning route_command "
        f"mode={planning_table3_mode} "
        f"coverage={route_command_coverage}/{total} "
        f"legacy_ego_motion_command={legacy_ego_motion_command}/{total} "
        f"prompt_with_command={prompt_with_command}/{total} "
        f"prompt_with_state={prompt_with_state}/{total}"
    )
    print(f"[AUDIT] planning route_command distribution={dict(route_command_dist)}")
    print(
        "[AUDIT] route_command semantics: UniAD-style future-GT-derived "
        f"(terminal lateral x threshold={UNIAD_COMMAND_X_THRESHOLD:.1f}m)."
    )
426
+
427
+
428
def main():
    """CLI entry point: generate Atlas planning-QA JSON from nuScenes.

    Iterates the requested split's keyframes, builds one planning item per
    usable sample, writes the JSON list, and prints audit statistics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', type=str, default='v1.0-trainval')
    parser.add_argument('--split', type=str, default='train', choices=['train', 'val'])
    parser.add_argument('--data-root', type=str, default='/mnt/data/nuscenes')
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument(
        '--planning-table3-mode',
        type=str,
        choices=PLANNING_TABLE3_MODES,
        default='atlas_high_level',
        help=(
            'Human prompt variant to materialize in the generated JSON. '
            'route_command is always written as a top-level UniAD-style '
            'future-GT-derived command.'
        ),
    )
    args = parser.parse_args()

    data_root = Path(args.data_root)
    script_dir = Path(__file__).parent.absolute()
    project_root = script_dir.parent

    if args.output:
        output_file = Path(args.output)
    else:
        # Default output location inside the project's data/ directory.
        output_file = project_root / "data" / f"atlas_planning_{args.split}_uniad_command.json"

    # Deferred heavy imports: only needed when actually generating data.
    from nuscenes.nuscenes import NuScenes
    from nuscenes.utils.splits import create_splits_scenes

    nusc = NuScenes(version=args.version, dataroot=str(data_root), verbose=True)

    # Restrict to the scenes belonging to the requested official split.
    splits = create_splits_scenes()
    split_scenes = set(splits[args.split])
    scene_tokens = set()
    for scene in nusc.scene:
        if scene['name'] in split_scenes:
            scene_tokens.add(scene['token'])
    samples_to_process = [s for s in nusc.sample if s['scene_token'] in scene_tokens]

    print(f"Processing {len(samples_to_process)} samples for planning...")
    results = []
    for sample in tqdm(samples_to_process):
        item = process_sample(
            nusc,
            sample['token'],
            data_root,
            planning_table3_mode=args.planning_table3_mode,
        )
        if item is not None:
            results.append(item)

    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    _audit_results(results, args.planning_table3_mode)
    print(f"Saved {len(results)} planning samples to {output_file}")
487
+
488
+
489
# Script entry point.
if __name__ == "__main__":
    main()
491
+
scripts/run_val_extraction.sh ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Offline token extraction orchestrator.
3
+ # Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
4
+ set -e
5
+
6
+ if [ "${ATLAS_ALLOW_OFFLINE}" != "1" ]; then
7
+ echo "ERROR: This is an OFFLINE extraction orchestrator." >&2
8
+ echo "It is isolated by default to prevent accidental use." >&2
9
+ echo "If you really need it, set: ATLAS_ALLOW_OFFLINE=1" >&2
10
+ echo "For online training use: bash scripts/train_no_caption_baseline.sh" >&2
11
+ exit 1
12
+ fi
13
+
14
+ cd /home/guoyuanbo/3dtokenizer-atlas
15
+ export LD_LIBRARY_PATH=/home/guoyuanbo/3dtokenizer/envs/streampetr/lib:${LD_LIBRARY_PATH:-}
16
+ PY=/home/guoyuanbo/3dtokenizer/envs/streampetr/bin/python
17
+
18
+ LOG_DIR=work_dirs
19
+ DET_OUT=work_dirs/precomputed_det_tokens_offline/val
20
+ MAP_OUT=work_dirs/precomputed_map_tokens_offline/val
21
+
22
+ echo "[$(date)] === Waiting for Phase 1 (StreamPETR det) to finish ==="
23
+
24
+ # Wait for all extract_streampetr_tokens processes to finish
25
+ while pgrep -f "extract_streampetr_tokens.py" > /dev/null 2>&1; do
26
+ DET_COUNT=$(find "$DET_OUT" -name "*.pt" 2>/dev/null | wc -l)
27
+ echo "[$(date)] Phase 1 running... det files: $DET_COUNT / ~12038"
28
+ sleep 300
29
+ done
30
+
31
+ DET_FINAL=$(find "$DET_OUT" -name "*.pt" 2>/dev/null | wc -l)
32
+ echo "[$(date)] Phase 1 DONE. Total det files: $DET_FINAL"
33
+
34
+ echo "[$(date)] === Starting Phase 2 (TopoMLP map) ==="
35
+ mkdir -p "$MAP_OUT"
36
+
37
+ for i in 0 1 2 3; do
38
+ CUDA_VISIBLE_DEVICES=$i $PY extract_topomlp_tokens.py \
39
+ --topomlp_config configs/topomlp_atlas_aligned.py \
40
+ --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
41
+ --data_json "data/atlas_planning_val_uniad_command.json,data/openlane_subsetB_lane_val_4pt.json" \
42
+ --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
43
+ --output_dir "$MAP_OUT" \
44
+ --shard_id $i --num_shards 4 \
45
+ > "$LOG_DIR/extract_map_val_shard_${i}.log" 2>&1 &
46
+ echo "[$(date)] Phase 2 shard $i launched (PID=$!)"
47
+ done
48
+
49
+ echo "[$(date)] Waiting for Phase 2 to complete..."
50
+ wait
51
+
52
+ MAP_FINAL=$(find "$MAP_OUT" -name "*.pt" 2>/dev/null | wc -l)
53
+ echo "[$(date)] Phase 2 DONE. Total map files: $MAP_FINAL"
54
+ echo "[$(date)] === All extraction complete ==="
55
+ echo " Det tokens: $DET_FINAL files in $DET_OUT"
56
+ echo " Map tokens: $MAP_FINAL files in $MAP_OUT"
scripts/train_no_caption_baseline.sh ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Atlas online training: detection + planning + lane (no caption).
# Default: visual_token_mode=online, live frozen StreamPETR + TopoMLP.
# Data: 28k det + 24k plan + 28k lane = ~80k samples/epoch
#
# Usage:
#   bash scripts/train_no_caption_baseline.sh
#   RESUME_CKPT=work_dirs/atlas_no_caption/epoch-4/checkpoint.pt bash scripts/train_no_caption_baseline.sh
#   NUM_GPUS=8 bash scripts/train_no_caption_baseline.sh
#
# For offline mode (read precomputed *_offline token dirs):
#   bash scripts/train_no_caption_baseline_offline.sh
set -euo pipefail

PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$PROJECT_ROOT"

NUM_GPUS=${NUM_GPUS:-4}
PLANNING_TABLE3_MODE=${PLANNING_TABLE3_MODE:-atlas_high_level_ego}

# Collect optional flags in an array so a checkpoint path containing spaces
# survives intact (an unquoted $EXTRA_ARGS string would be word-split).
EXTRA_ARGS=()
if [ -n "${RESUME_CKPT:-}" ]; then
    EXTRA_ARGS+=(--resume "$RESUME_CKPT")
fi

# ${EXTRA_ARGS[@]+...} guards against "unbound variable" on an empty array
# under `set -u` with bash < 4.4.
deepspeed --num_gpus "$NUM_GPUS" train_atlas.py \
    --llm_model pretrained/vicuna-7b-v1.5 \
    --data_json data/atlas_nuscenes_train.json,data/atlas_planning_train_uniad_command.json,data/openlane_subsetB_lane_train_4pt.json \
    --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
    --visual_token_mode online \
    --planning_table3_mode "$PLANNING_TABLE3_MODE" \
    --streampetr_config configs/streampetr_atlas_aligned.py \
    --streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
    --topomlp_config configs/topomlp_atlas_aligned.py \
    --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
    --deepspeed configs/ds_zero2.json \
    --output_dir work_dirs/atlas_no_caption_online \
    --epochs 10 \
    --lr 2e-5 \
    --weight_decay 1e-4 \
    --batch_size 1 \
    --gradient_accumulation_steps 2 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --log_steps 100 \
    --save_epochs 1 \
    --keep_last_n_ckpts 3 \
    --seed 42 \
    --num_workers 4 \
    ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}
scripts/train_no_caption_baseline_offline.sh ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# [OFFLINE fallback] Atlas training using precomputed *_offline visual tokens.
# This is NOT the default. Use scripts/train_no_caption_baseline.sh for online mode.
# Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
set -euo pipefail

if [ "${ATLAS_ALLOW_OFFLINE:-}" != "1" ]; then
    echo "ERROR: This is an OFFLINE fallback script, not the primary online training." >&2
    echo "It is isolated by default to prevent accidental use in experiments." >&2
    echo "If you really need it, set: ATLAS_ALLOW_OFFLINE=1" >&2
    echo "For production training use: bash scripts/train_no_caption_baseline.sh" >&2
    exit 1
fi

PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$PROJECT_ROOT"

NUM_GPUS=${NUM_GPUS:-4}
PLANNING_TABLE3_MODE=${PLANNING_TABLE3_MODE:-atlas_base}

# Collect optional flags in an array so values containing spaces are passed
# through intact (an unquoted $EXTRA_ARGS string would be word-split).
EXTRA_ARGS=()
if [ -n "${RESUME_CKPT:-}" ]; then
    EXTRA_ARGS+=(--resume "$RESUME_CKPT")
fi

# ${EXTRA_ARGS[@]+...} guards against "unbound variable" on an empty array
# under `set -u` with bash < 4.4.
deepspeed --num_gpus "$NUM_GPUS" train_atlas.py \
    --llm_model pretrained/vicuna-7b-v1.5 \
    --data_json data/atlas_nuscenes_train.json,data/atlas_planning_train_uniad_command.json,data/openlane_subsetB_lane_train_4pt.json \
    --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
    --visual_token_mode offline \
    --planning_table3_mode "$PLANNING_TABLE3_MODE" \
    --precomputed_det_tokens work_dirs/precomputed_det_tokens_offline/train \
    --precomputed_map_tokens work_dirs/precomputed_map_tokens_offline/train \
    --deepspeed configs/ds_zero2.json \
    --output_dir work_dirs/atlas_no_caption_offline \
    --epochs 10 \
    --lr 2e-5 \
    --weight_decay 1e-4 \
    --batch_size 1 \
    --gradient_accumulation_steps 2 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --log_steps 100 \
    --save_epochs 1 \
    --keep_last_n_ckpts 3 \
    --seed 42 \
    --num_workers 4 \
    ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}
scripts/train_with_caption_balanced.sh ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Atlas online training: detection + planning + lane + caption.
# Default: visual_token_mode=online, live frozen StreamPETR + TopoMLP.
# Data: 28k det + 24k plan + 28k lane + 169k caption = ~249k samples/epoch
# WARNING: caption data is 6x larger than other tasks — consider downsampling.
#
# Usage:
#   bash scripts/train_with_caption_balanced.sh
#   RESUME_CKPT=work_dirs/atlas_with_caption/epoch-4/checkpoint.pt bash scripts/train_with_caption_balanced.sh
#   NUM_GPUS=8 bash scripts/train_with_caption_balanced.sh
set -euo pipefail

PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$PROJECT_ROOT"

NUM_GPUS=${NUM_GPUS:-4}
PLANNING_TABLE3_MODE=${PLANNING_TABLE3_MODE:-atlas_high_level_ego}

# Collect optional flags in an array so a checkpoint path containing spaces
# survives intact (an unquoted $EXTRA_ARGS string would be word-split).
EXTRA_ARGS=()
if [ -n "${RESUME_CKPT:-}" ]; then
    EXTRA_ARGS+=(--resume "$RESUME_CKPT")
fi

# ${EXTRA_ARGS[@]+...} guards against "unbound variable" on an empty array
# under `set -u` with bash < 4.4.
deepspeed --num_gpus "$NUM_GPUS" train_atlas.py \
    --llm_model pretrained/vicuna-7b-v1.5 \
    --data_json data/atlas_nuscenes_train.json,data/atlas_planning_train_uniad_command.json,data/openlane_subsetB_lane_train_4pt.json,data/atlas_caption_train.json \
    --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
    --visual_token_mode online \
    --planning_table3_mode "$PLANNING_TABLE3_MODE" \
    --streampetr_config configs/streampetr_atlas_aligned.py \
    --streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
    --topomlp_config configs/topomlp_atlas_aligned.py \
    --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
    --deepspeed configs/ds_zero2.json \
    --output_dir work_dirs/atlas_with_caption_online \
    --epochs 8 \
    --lr 2e-5 \
    --weight_decay 1e-4 \
    --batch_size 1 \
    --gradient_accumulation_steps 2 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --log_steps 100 \
    --save_epochs 1 \
    --keep_last_n_ckpts 3 \
    --seed 42 \
    --num_workers 4 \
    ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}
scripts/vis_atlas_lane_gt_pred.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Visualize one Atlas lane sample with denser diagnostics.
4
+
5
+ Compared with the earlier plot, this version:
6
+ - keeps the original 4-point supervision visible,
7
+ - densifies those 4 points into easier-to-read curves,
8
+ - optionally overlays raw OpenLane GT centerlines when available,
9
+ - keeps all BEV views on the same axes for fair visual comparison.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import sys
17
+ from pathlib import Path
18
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
19
+
20
+ import numpy as np
21
+
22
+ try:
23
+ from scipy.interpolate import interp1d
24
+
25
+ SCIPY_AVAILABLE = True
26
+ except Exception:
27
+ interp1d = None
28
+ SCIPY_AVAILABLE = False
29
+
30
+ _REPO_ROOT = Path(__file__).resolve().parent.parent
31
+ if str(_REPO_ROOT) not in sys.path:
32
+ sys.path.insert(0, str(_REPO_ROOT))
33
+
34
+
35
+ def _load_json(path: Path):
36
+ with path.open("r", encoding="utf-8") as f:
37
+ return json.load(f)
38
+
39
+
40
+ def _find_eval_record(eval_obj: Dict, sample_id: str) -> Dict:
41
+ for rec in eval_obj.get("predictions", []):
42
+ if rec.get("sample_id") == sample_id:
43
+ return rec
44
+ raise KeyError(f"sample_id not found in eval_json: {sample_id}")
45
+
46
+
47
+ def _find_data_item(data_list: List[Dict], sample_id: str) -> Dict:
48
+ for it in data_list:
49
+ if it.get("id") == sample_id:
50
+ return it
51
+ raise KeyError(f"sample_id not found in data_json: {sample_id}")
52
+
53
+
54
+ def _extract_gt_answer(item: Dict) -> str:
55
+ for turn in item.get("conversations", []) or []:
56
+ if isinstance(turn, dict) and turn.get("from") in ("gpt", "assistant"):
57
+ return str(turn.get("value", ""))
58
+ return ""
59
+
60
+
61
+ def _openlane_to_paper_xy(x_fwd: float, y_left: float) -> Tuple[float, float]:
62
+ return (-float(y_left), float(x_fwd))
63
+
64
+
65
+ def _extract_xy_rows(points: Sequence) -> np.ndarray:
66
+ rows: List[List[float]] = []
67
+ for pt in points or []:
68
+ if isinstance(pt, dict):
69
+ wc = pt.get("world_coords", None)
70
+ if wc is None:
71
+ continue
72
+ rows.append([float(wc[0]), float(wc[1])])
73
+ elif isinstance(pt, (list, tuple)) and len(pt) >= 2:
74
+ rows.append([float(pt[0]), float(pt[1])])
75
+ if not rows:
76
+ return np.zeros((0, 2), dtype=np.float64)
77
+ return np.asarray(rows, dtype=np.float64)
78
+
79
+
80
def _lanes_as_arrays(parsed: List[Dict]) -> List[np.ndarray]:
    """Extract a non-empty (N, 2) point array for every parsed object of type 'lane'."""
    out: List[np.ndarray] = []
    for entry in parsed:
        if entry.get("type") != "lane":
            continue
        pts = _extract_xy_rows(entry.get("points", []) or [])
        if pts.shape[0] >= 1:
            out.append(pts)
    return out
89
+
90
+
91
+ def _bounds_xy(lanes: Iterable[np.ndarray]) -> Optional[Tuple[np.ndarray, np.ndarray]]:
92
+ lanes = [ln for ln in lanes if len(ln) > 0]
93
+ if not lanes:
94
+ return None
95
+ allp = np.concatenate(lanes, axis=0)
96
+ return allp.min(axis=0), allp.max(axis=0)
97
+
98
+
99
+ def _lane_bev_distance(lane: np.ndarray) -> float:
100
+ if len(lane) == 0:
101
+ return float("inf")
102
+ centroid = lane.mean(axis=0)
103
+ return float(centroid[0] ** 2 + centroid[1] ** 2)
104
+
105
+
106
def _select_closest_lanes(lanes: List[np.ndarray], max_lanes: int) -> List[np.ndarray]:
    """Keep at most max_lanes lanes, nearest centroid first; max_lanes <= 0 keeps all."""
    if max_lanes <= 0 or len(lanes) <= max_lanes:
        return list(lanes)
    ranked = sorted(lanes, key=_lane_bev_distance)
    return ranked[:max_lanes]
111
+
112
+
113
+ def _candidate_openlane_paths(openlane_root: Path, item: Dict) -> List[Path]:
114
+ segment_id = str(item.get("segment_id", "")).strip()
115
+ timestamp = str(item.get("timestamp", "")).strip()
116
+ if not segment_id or not timestamp:
117
+ return []
118
+ rel = Path("val") / segment_id / "info" / f"{timestamp}.json"
119
+ return [
120
+ openlane_root / "subset_B" / rel,
121
+ openlane_root / rel,
122
+ ]
123
+
124
+
125
def _load_raw_openlane_gt_lanes(openlane_root: Path, item: Dict) -> Tuple[List[np.ndarray], Optional[Path]]:
    """Load raw OpenLane GT centerlines (converted to paper BEV coords).

    Reads the first existing candidate info file and returns (lanes, source_path);
    returns ([], None) when no candidate exists on disk.
    """
    for candidate in _candidate_openlane_paths(openlane_root, item):
        if not candidate.exists():
            continue
        payload = _load_json(candidate)
        annotation = payload.get("annotation", {}) or {}
        lanes: List[np.ndarray] = []
        for centerline in annotation.get("lane_centerline", []) or []:
            pts = [
                list(_openlane_to_paper_xy(float(p[0]), float(p[1])))
                for p in centerline.get("points", []) or []
                if isinstance(p, (list, tuple)) and len(p) >= 2
            ]
            # A centerline needs at least two points to be drawable.
            if len(pts) >= 2:
                lanes.append(np.asarray(pts, dtype=np.float64))
        return lanes, candidate
    return [], None
142
+
143
+
144
+ def _resample_lane(lane: np.ndarray, num: int = 41) -> np.ndarray:
145
+ if len(lane) <= 1 or num <= len(lane):
146
+ return lane.copy()
147
+
148
+ diffs = np.diff(lane, axis=0)
149
+ seg_len = np.linalg.norm(diffs, axis=1)
150
+ keep = np.concatenate(([True], seg_len > 1e-8))
151
+ lane = lane[keep]
152
+ if len(lane) <= 1:
153
+ return lane.copy()
154
+
155
+ diffs = np.diff(lane, axis=0)
156
+ seg_len = np.linalg.norm(diffs, axis=1)
157
+ arc = np.concatenate(([0.0], np.cumsum(seg_len)))
158
+ total = float(arc[-1])
159
+ if total <= 1e-8:
160
+ return lane.copy()
161
+
162
+ target = np.linspace(0.0, total, num=num)
163
+ if SCIPY_AVAILABLE and len(lane) >= 4:
164
+ try:
165
+ fx = interp1d(arc, lane[:, 0], kind="cubic")
166
+ fy = interp1d(arc, lane[:, 1], kind="cubic")
167
+ dense = np.stack([fx(target), fy(target)], axis=1)
168
+ return dense.astype(np.float64)
169
+ except Exception:
170
+ pass
171
+
172
+ x = np.interp(target, arc, lane[:, 0])
173
+ y = np.interp(target, arc, lane[:, 1])
174
+ return np.stack([x, y], axis=1).astype(np.float64)
175
+
176
+
177
def _draw_ego(ax, *, color: str = "k"):
    """Draw the ego-vehicle footprint, heading arrow, and label at the BEV origin."""
    import matplotlib.patches as patches

    width, length = 1.85, 4.084
    footprint = patches.Rectangle(
        (-width / 2.0, -length / 2.0),
        width,
        length,
        linewidth=1.2,
        edgecolor=color,
        facecolor="none",
        zorder=8,
    )
    ax.add_patch(footprint)
    ax.arrow(0.0, 0.0, 0.0, length * 0.7, head_width=0.6, head_length=0.8, fc=color, ec=color, zorder=9)
    ax.scatter([0.0], [0.0], s=18, c=color, zorder=10)
    ax.text(0.2, -0.5, "EGO", color=color, fontsize=8, zorder=10)
194
+
195
+
196
def _set_bev_axes(ax, *, title: str, xlim, ylim):
    """Apply the shared BEV axis styling: equal aspect, grid, labels, and limits."""
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(True, linewidth=0.4, alpha=0.4)
    ax.set_title(title, fontsize=10)
    ax.set_xlabel("X (m) — right", fontsize=8)
    ax.set_ylabel("Y (m) — forward", fontsize=8)
204
+
205
+
206
def _plot_colorful_panel(
    ax,
    lanes: List[np.ndarray],
    *,
    title: str,
    xlim,
    ylim,
    raw_reference_lanes: Optional[List[np.ndarray]] = None,
    show_control_points: bool = True,
    dense_points: int = 41,
):
    """Draw one BEV panel with each lane in its own color.

    Each lane's control points are densified via ``_resample_lane`` before
    plotting; the original control points are optionally overlaid as white
    circles. ``raw_reference_lanes`` (if given) are drawn underneath in gray
    dashes for visual comparison. Draw order is encoded in zorder:
    raw GT (1) < dense curves (3) < control points (4) < ego marker (8-10).
    """
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    _set_bev_axes(ax, title=title, xlim=xlim, ylim=ylim)
    cmap = plt.get_cmap("tab20")

    # Reference lanes first so they sit below everything else.
    if raw_reference_lanes:
        for ln in raw_reference_lanes:
            ax.plot(ln[:, 0], ln[:, 1], "--", linewidth=1.0, color="0.60", alpha=0.50, zorder=1)

    for i, lane in enumerate(lanes):
        # tab20 provides 20 distinct colors; wrap around for more lanes.
        color = cmap(i % 20)
        dense = _resample_lane(lane, num=dense_points)
        ax.plot(dense[:, 0], dense[:, 1], "-", linewidth=2.0, color=color, alpha=0.95, zorder=3)
        if show_control_points:
            ax.scatter(
                lane[:, 0],
                lane[:, 1],
                s=20,
                facecolors="white",
                edgecolors=[color],
                linewidths=0.9,
                zorder=4,
            )

    # Legend uses proxy artists since the actual lanes are multi-colored.
    handles = []
    if raw_reference_lanes:
        handles.append(Line2D([0], [0], color="0.60", linestyle="--", linewidth=1.2, label="Raw GT"))
    handles.append(Line2D([0], [0], color="black", linewidth=2.0, label="Densified 4pt"))
    if show_control_points:
        handles.append(
            Line2D(
                [0],
                [0],
                marker="o",
                color="black",
                markerfacecolor="white",
                markersize=5,
                linewidth=0.0,
                label="Control points",
            )
        )
    ax.legend(handles=handles, loc="upper left", fontsize=7, frameon=True)
    _draw_ego(ax, color="k")
261
+
262
+
263
def _plot_overlay_panel(
    ax,
    *,
    gt_lanes: List[np.ndarray],
    pred_lanes: List[np.ndarray],
    raw_gt_lanes: Optional[List[np.ndarray]],
    xlim,
    ylim,
    dense_points: int = 41,
):
    """Draw the GT-vs-Pred overlay BEV panel.

    Raw GT centerlines appear as gray dashes (bottom), densified 4pt GT as
    green curves with white-circle control points, and densified 4pt
    predictions as red curves with 'x' control points (top). zorder encodes
    that stacking (1 < 3/4 < 5/6).
    """
    from matplotlib.lines import Line2D

    _set_bev_axes(
        ax,
        title="Overlay: raw GT(gray), 4pt GT(green), Pred(red)",
        xlim=xlim,
        ylim=ylim,
    )

    if raw_gt_lanes:
        for lane in raw_gt_lanes:
            ax.plot(lane[:, 0], lane[:, 1], "--", linewidth=1.1, color="0.55", alpha=0.45, zorder=1)

    for lane in gt_lanes:
        dense = _resample_lane(lane, num=dense_points)
        ax.plot(dense[:, 0], dense[:, 1], "-", linewidth=2.0, color="green", alpha=0.78, zorder=3)
        ax.scatter(lane[:, 0], lane[:, 1], s=18, facecolors="white", edgecolors="green", linewidths=0.8, zorder=4)

    for lane in pred_lanes:
        dense = _resample_lane(lane, num=dense_points)
        ax.plot(dense[:, 0], dense[:, 1], "-", linewidth=2.0, color="red", alpha=0.78, zorder=5)
        ax.scatter(lane[:, 0], lane[:, 1], s=18, marker="x", color="red", linewidths=0.9, zorder=6)

    # Proxy-artist legend describing all three layers.
    handles = [
        Line2D([0], [0], color="0.55", linestyle="--", linewidth=1.1, label="Raw GT"),
        Line2D([0], [0], color="green", linewidth=2.0, label="Densified GT (4pt)"),
        Line2D([0], [0], marker="o", color="green", markerfacecolor="white", markersize=5, linewidth=0.0, label="GT control"),
        Line2D([0], [0], color="red", linewidth=2.0, label="Densified Pred (4pt)"),
        Line2D([0], [0], marker="x", color="red", markersize=5, linewidth=0.0, label="Pred control"),
    ]
    ax.legend(handles=handles, loc="upper left", fontsize=7, frameon=True)
    _draw_ego(ax, color="k")
305
+
306
+
307
def _load_images(data_root: Path, image_paths: List[str]):
    """Open each image as RGB, resolving relative paths against data_root."""
    from PIL import Image

    loaded = []
    for raw_path in image_paths:
        candidate = Path(raw_path)
        if not candidate.is_absolute():
            candidate = data_root / raw_path
        loaded.append(Image.open(candidate).convert("RGB"))
    return loaded
317
+
318
+
319
def parse_args():
    """CLI arguments for the lane-diagnostics renderer."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--eval_json", type=str, default="work_dirs/eval_nocap_ep2_lane50.json")
    add("--data_json", type=str, default="data/openlane_subsetB_lane_val_4pt.json")
    add("--data_root", type=str, default="/home/guoyuanbo/autodl-tmp/data/nuscenes")
    add("--openlane_root", type=str, default="/home/guoyuanbo/autodl-tmp/OpenLane-V2")
    add("--sample_id", type=str, required=True)
    add("--out_png", type=str, default="work_dirs/atlas_lane_real_prediction_dense.png")
    add("--xlim", type=float, nargs=2, default=[-30.0, 30.0])
    add("--ylim", type=float, nargs=2, default=[-55.0, 55.0])
    add("--dpi", type=int, default=170)
    add("--dense_points", type=int, default=41)
    return parser.parse_args()
332
+
333
+
334
def main():
    """Render the full lane-diagnostics figure for one sample and save it as PNG."""
    args = parse_args()
    # All JSON/output paths are resolved relative to the repo root; data_root
    # and openlane_root are absolute dataset locations.
    repo = Path(__file__).resolve().parent.parent
    eval_path = (repo / args.eval_json).resolve()
    data_path = (repo / args.data_json).resolve()
    data_root = Path(args.data_root).resolve()
    openlane_root = Path(args.openlane_root).resolve()
    out_png = (repo / args.out_png).resolve()

    eval_obj = _load_json(eval_path)
    data_list = _load_json(data_path)

    # Pair the model's eval record with the matching dataset item.
    rec = _find_eval_record(eval_obj, args.sample_id)
    item = _find_data_item(data_list, args.sample_id)

    pred_text = str(rec.get("generated_text", ""))
    gt_text = _extract_gt_answer(item)

    # Imported lazily: pulls in the project's eval stack only when actually run.
    from src.eval.metrics import parse_atlas_output

    pred_parsed = parse_atlas_output(pred_text)
    gt_parsed = parse_atlas_output(gt_text)

    pred_lanes = _lanes_as_arrays(pred_parsed)
    gt_lanes_all = _lanes_as_arrays(gt_parsed)
    raw_gt_lanes_all, raw_gt_path = _load_raw_openlane_gt_lanes(openlane_root, item)

    # num_gt from the eval record caps how many GT lanes are visualized so the
    # panel matches what the metric actually scored; 0 (or unparsable) keeps all.
    expected_gt_count = rec.get("num_gt", 0)
    try:
        expected_gt_count = int(expected_gt_count)
    except Exception:
        expected_gt_count = 0
    gt_lanes = _select_closest_lanes(gt_lanes_all, expected_gt_count)
    raw_gt_lanes = _select_closest_lanes(raw_gt_lanes_all, expected_gt_count)

    xlim = (float(args.xlim[0]), float(args.xlim[1]))
    ylim = (float(args.ylim[0]), float(args.ylim[1]))

    # Mean centroid shift between prediction and GT point clouds (reported in the
    # text panel); only computable when both sides have lanes.
    delta = None
    if pred_lanes and gt_lanes:
        p_all = np.concatenate(pred_lanes, axis=0)
        g_all = np.concatenate(gt_lanes, axis=0)
        delta = (p_all.mean(axis=0) - g_all.mean(axis=0)).tolist()

    imgs = _load_images(data_root, item.get("image_paths", []) or [])
    if len(imgs) != 6:
        raise RuntimeError(f"expected 6 images, got {len(imgs)}")

    # Reorder from the data_json camera ordering into the paper's 2x3 layout.
    order = [2, 0, 1, 4, 3, 5]
    cam_titles = ["FRONT_LEFT", "FRONT", "FRONT_RIGHT", "BACK_LEFT", "BACK", "BACK_RIGHT"]
    imgs = [imgs[i] for i in order]

    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec

    fig = plt.figure(figsize=(15.5, 9.4), dpi=args.dpi)
    fig.suptitle("Atlas Lane Diagnostics — Control Points, Densified Curves, and Raw GT", fontsize=14, y=0.985)

    # 3x4 grid: rows 0-1 hold the six camera tiles (cols 0-2) plus the two BEV
    # panels (col 3); row 2 holds the overlay (cols 0-2) and the text panel (col 3).
    gs = GridSpec(
        3,
        4,
        figure=fig,
        width_ratios=[1.0, 1.0, 1.0, 1.10],
        height_ratios=[1.0, 1.0, 1.08],
        wspace=0.24,
        hspace=0.33,
    )

    for i in range(6):
        r = 0 if i < 3 else 1
        c = i % 3
        ax = fig.add_subplot(gs[r, c])
        ax.imshow(imgs[i])
        ax.set_title(cam_titles[i], fontsize=9)
        ax.axis("off")

    ax_gt = fig.add_subplot(gs[0, 3])
    _plot_colorful_panel(
        ax_gt,
        gt_lanes,
        title=f"GT: 4pt densified ({len(gt_lanes)} lanes)",
        xlim=xlim,
        ylim=ylim,
        raw_reference_lanes=raw_gt_lanes or None,
        dense_points=int(args.dense_points),
    )

    ax_pr = fig.add_subplot(gs[1, 3])
    _plot_colorful_panel(
        ax_pr,
        pred_lanes,
        title=f"Pred: 4pt densified ({len(pred_lanes)} lanes)",
        xlim=xlim,
        ylim=ylim,
        raw_reference_lanes=None,
        dense_points=int(args.dense_points),
    )

    ax_ov = fig.add_subplot(gs[2, 0:3])
    _plot_overlay_panel(
        ax_ov,
        gt_lanes=gt_lanes,
        pred_lanes=pred_lanes,
        raw_gt_lanes=raw_gt_lanes or None,
        xlim=xlim,
        ylim=ylim,
        dense_points=int(args.dense_points),
    )

    # Text panel: a monospace summary of counts, metrics, and bounds.
    ax_txt = fig.add_subplot(gs[2, 3])
    ax_txt.axis("off")

    lane_metrics = eval_obj.get("metrics", {}).get("lane", {})
    model_name = Path(str(eval_obj.get("args", {}).get("checkpoint", ""))).parent.name or "atlas"
    lines = [
        "Atlas Lane Diagnostics",
        f"Model: {model_name}",
        f"Sample: {args.sample_id}",
        "",
        f"GT 4pt lanes: {len(gt_lanes)} vis / {len(gt_lanes_all)} full, {sum(len(x) for x in gt_lanes)} ctrl pts",
        f"Pred 4pt lanes: {len(pred_lanes)} lanes, {sum(len(x) for x in pred_lanes)} ctrl pts",
        f"Raw OpenLane GT: {len(raw_gt_lanes)} vis / {len(raw_gt_lanes_all)} full, {sum(len(x) for x in raw_gt_lanes)} pts",
        f"Curve densifier: {'cubic spline' if SCIPY_AVAILABLE else 'linear'} ({args.dense_points} pts/lane)",
        "",
        "Interpretation:",
        " - white circles: GT 4pt control points",
        " - red x: pred 4pt control points",
        " - gray dashed: raw OpenLane GT centerlines",
        " - green/red solid: densified 4pt curves",
        f" - vis GT subset matches eval num_gt={expected_gt_count}" if expected_gt_count > 0 else " - vis GT subset uses full GT",
        "",
        "Eval metrics (from eval_json):",
    ]
    for k in ("lane_f1", "method", "num_samples"):
        if k in lane_metrics:
            lines.append(f" {k}: {lane_metrics[k]}")

    lines.append("")
    if delta is not None:
        lines.append("Mean shift (Pred - GT 4pt):")
        lines.append(f" dx={delta[0]:+.3f} m, dy={delta[1]:+.3f} m")
        lines.append("")

    gb = _bounds_xy(gt_lanes)
    pb = _bounds_xy(pred_lanes)
    rb = _bounds_xy(raw_gt_lanes)
    if gb is not None:
        lines.append(f"GT4 bounds: x[{gb[0][0]:+.1f},{gb[1][0]:+.1f}] y[{gb[0][1]:+.1f},{gb[1][1]:+.1f}]")
    if pb is not None:
        lines.append(f"Pred bounds: x[{pb[0][0]:+.1f},{pb[1][0]:+.1f}] y[{pb[0][1]:+.1f},{pb[1][1]:+.1f}]")
    if rb is not None:
        lines.append(f"Raw bounds: x[{rb[0][0]:+.1f},{rb[1][0]:+.1f}] y[{rb[0][1]:+.1f},{rb[1][1]:+.1f}]")
    if raw_gt_path is not None:
        lines.append("")
        lines.append(f"Raw GT source: {raw_gt_path}")

    ax_txt.text(0.0, 1.0, "\n".join(lines), va="top", ha="left", fontsize=8, family="monospace")

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)

    print(f"[saved] {out_png}")
497
+
498
+
499
+ if __name__ == "__main__":
500
+ main()
scripts/vis_atlas_planning_qualitative.py ADDED
@@ -0,0 +1,800 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Paper-style qualitative planning visualization (Figure-4-like):
4
+
5
+ - Left: 6-camera mosaic (2x3)
6
+ - Right: BEV panel with planned trajectory
7
+
8
+ Optionally overlays the planned trajectory onto CAM_FRONT using *fixed* nuScenes
9
+ camera calibration (loaded from v1.0-trainval sensor/calibrated_sensor tables).
10
+ This avoids scanning the 1.3GB sample_data.json.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import json
17
+ import math
18
+ import mmap
19
+ import pickle
20
+ import sys
21
+ from pathlib import Path
22
+ from typing import Dict, Iterable, List, Optional, Tuple
23
+
24
+ import numpy as np
25
+
26
+ _REPO_ROOT = Path(__file__).resolve().parent.parent
27
+ if str(_REPO_ROOT) not in sys.path:
28
+ sys.path.insert(0, str(_REPO_ROOT))
29
+
30
+
31
+ CAM_ORDER_DATAJSON = [
32
+ "CAM_FRONT",
33
+ "CAM_FRONT_RIGHT",
34
+ "CAM_FRONT_LEFT",
35
+ "CAM_BACK",
36
+ "CAM_BACK_LEFT",
37
+ "CAM_BACK_RIGHT",
38
+ ]
39
+
40
+ CAM_ORDER_PAPER = [
41
+ "CAM_FRONT_LEFT",
42
+ "CAM_FRONT",
43
+ "CAM_FRONT_RIGHT",
44
+ "CAM_BACK_LEFT",
45
+ "CAM_BACK",
46
+ "CAM_BACK_RIGHT",
47
+ ]
48
+
49
+ _IDX_REORDER = [2, 0, 1, 4, 3, 5] # data_json -> paper layout
50
+
51
+
52
+ def _load_json(path: Path):
53
+ with path.open("r", encoding="utf-8") as f:
54
+ return json.load(f)
55
+
56
+
57
+ def _load_pickle(path: Path):
58
+ with path.open("rb") as f:
59
+ return pickle.load(f)
60
+
61
+
62
def _extract_results_list_for_token(mm: "mmap.mmap", token: str) -> List[Dict]:
    """
    Extract `results[token]` list from a nuScenes detection results JSON mmap.

    The file is a minified JSON like:
        {"meta": {...}, "results": {"<token>": [ {...}, ... ], ...}}

    Works by locating the key byte pattern, then scanning forward with a small
    bracket/string state machine to find the matching closing ']' of the value
    array, and json-parsing only that slice. Avoids loading the whole (GB-scale)
    file into memory.
    """
    # NOTE: assumes minified JSON — the pattern has no space between the quoted
    # key and the ':' (a pretty-printed file would not match).
    pat = (f"\"{token}\":").encode("utf-8")
    idx = mm.find(pat)
    if idx < 0:
        return []
    j = idx + len(pat)
    # Skip whitespace
    while j < len(mm) and mm[j] in b" \t\r\n":
        j += 1
    # Value must be a JSON array, otherwise bail out.
    if j >= len(mm) or mm[j : j + 1] != b"[":
        return []

    start = j
    depth = 0          # current [ ] nesting depth
    in_str = False     # inside a JSON string literal
    esc = False        # previous char was a backslash inside a string
    k = start
    end = None
    while k < len(mm):
        c = mm[k]
        if in_str:
            # Inside a string: only track escapes and the closing quote so
            # brackets within strings are ignored.
            if esc:
                esc = False
            elif c == 0x5C:  # backslash
                esc = True
            elif c == 0x22:  # quote
                in_str = False
        else:
            if c == 0x22:
                in_str = True
            elif c == 0x5B:  # [
                depth += 1
            elif c == 0x5D:  # ]
                depth -= 1
                if depth == 0:
                    # Matched the array that opened at `start`.
                    end = k + 1
                    break
        k += 1

    if end is None:
        return []
    try:
        return json.loads(mm[start:end].decode("utf-8"))
    except Exception:
        # Malformed slice (e.g. truncated file) — treat as "token absent".
        return []
113
+
114
+
115
def _load_ego_pose_map(nuscenes_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    """
    Map: sample_token -> (R_ego2global(3x3), t_ego2global(3,))

    Built from nuscenes_infos_val.pkl; entries without a token or with a
    malformed quaternion are skipped.
    """
    pkl = nuscenes_root / "nuscenes_infos_val.pkl"
    if not pkl.exists():
        raise FileNotFoundError(f"Missing {pkl} (required to transform detector global boxes to ego frame).")
    payload = _load_pickle(pkl)
    infos = payload["infos"] if isinstance(payload, dict) and "infos" in payload else payload
    poses: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
    for info in infos:
        token = str(info.get("token", ""))
        if not token:
            continue
        quat = info.get("ego2global_rotation", [1.0, 0.0, 0.0, 0.0])
        if not (isinstance(quat, (list, tuple)) and len(quat) == 4):
            continue
        trans = np.asarray(info.get("ego2global_translation", [0.0, 0.0, 0.0]), dtype=np.float64).reshape(3)
        rot = _quat_to_rotmat(float(quat[0]), float(quat[1]), float(quat[2]), float(quat[3]))
        poses[token] = (rot, trans)
    return poses
136
+
137
+
138
+ def _quat_to_rotmat(qw: float, qx: float, qy: float, qz: float) -> np.ndarray:
139
+ # nuScenes quaternions are in (w, x, y, z)
140
+ n = math.sqrt(qw * qw + qx * qx + qy * qy + qz * qz)
141
+ if n < 1e-12:
142
+ return np.eye(3, dtype=np.float64)
143
+ qw, qx, qy, qz = qw / n, qx / n, qy / n, qz / n
144
+ # Standard quaternion -> rotation matrix
145
+ xx, yy, zz = qx * qx, qy * qy, qz * qz
146
+ xy, xz, yz = qx * qy, qx * qz, qy * qz
147
+ wx, wy, wz = qw * qx, qw * qy, qw * qz
148
+ return np.array(
149
+ [
150
+ [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)],
151
+ [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)],
152
+ [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)],
153
+ ],
154
+ dtype=np.float64,
155
+ )
156
+
157
+
158
def _load_fixed_cam_front_calib(nuscenes_root: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Returns (R_c2e, t_c2e, K) for CAM_FRONT.
    Uses the first calibrated_sensor record matching CAM_FRONT.
    """
    meta_root = nuscenes_root / "v1.0-trainval"
    sensor = _load_json(meta_root / "sensor.json")
    calib = _load_json(meta_root / "calibrated_sensor.json")

    # sensor.json is a list of dicts; take the first CAM_FRONT entry.
    sensor_token = next((rec.get("token") for rec in sensor if rec.get("channel") == "CAM_FRONT"), None)
    if sensor_token is None:
        raise RuntimeError("CAM_FRONT not found in sensor.json")

    calib_rec = next((rec for rec in calib if rec.get("sensor_token") == sensor_token), None)
    if calib_rec is None:
        raise RuntimeError("No calibrated_sensor record found for CAM_FRONT")

    t = np.asarray(calib_rec.get("translation", [0.0, 0.0, 0.0]), dtype=np.float64).reshape(3)
    q = calib_rec.get("rotation", [1.0, 0.0, 0.0, 0.0])
    if not (isinstance(q, (list, tuple)) and len(q) == 4):
        raise RuntimeError(f"Unexpected CAM_FRONT rotation quaternion: {q}")
    R = _quat_to_rotmat(float(q[0]), float(q[1]), float(q[2]), float(q[3]))

    K = np.asarray(calib_rec.get("camera_intrinsic", np.eye(3).tolist()), dtype=np.float64)
    if K.shape != (3, 3):
        raise RuntimeError(f"Unexpected CAM_FRONT camera_intrinsic shape: {K.shape}")
    return R, t, K
194
+
195
+
196
def _load_fixed_cam_calibs(nuscenes_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Load fixed (R_c2e, t_c2e, K) for all nuScenes cameras by channel name.
    """
    meta_root = nuscenes_root / "v1.0-trainval"
    sensors = _load_json(meta_root / "sensor.json")
    calibs = _load_json(meta_root / "calibrated_sensor.json")

    # channel -> sensor token (well-formed string entries only; last one wins).
    token_of_channel = {
        rec.get("channel"): rec.get("token")
        for rec in sensors
        if isinstance(rec.get("channel"), str) and isinstance(rec.get("token"), str)
    }
    # sensor token -> calibrated_sensor record (last one wins).
    calib_of_token = {
        rec.get("sensor_token"): rec
        for rec in calibs
        if isinstance(rec.get("sensor_token"), str)
    }

    result: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
    for channel, token in token_of_channel.items():
        rec = calib_of_token.get(token)
        if not rec or "camera_intrinsic" not in rec:
            # Non-camera sensors (no intrinsics) are skipped.
            continue
        q = rec.get("rotation", [1.0, 0.0, 0.0, 0.0])
        if not (isinstance(q, (list, tuple)) and len(q) == 4):
            continue
        K = np.asarray(rec.get("camera_intrinsic", np.eye(3).tolist()), dtype=np.float64)
        if K.shape != (3, 3):
            continue
        t = np.asarray(rec.get("translation", [0.0, 0.0, 0.0]), dtype=np.float64).reshape(3)
        R = _quat_to_rotmat(float(q[0]), float(q[1]), float(q[2]), float(q[3]))
        result[channel] = (R, t, K)
    return result
232
+
233
+
234
+ def _paper_xy_to_nuscenes_ego_xyz(x_right: float, y_fwd: float, z_up: float = 0.0) -> np.ndarray:
235
+ # nuScenes ego: x forward, y left, z up
236
+ x_fwd = float(y_fwd)
237
+ y_left = float(-x_right)
238
+ return np.array([x_fwd, y_left, float(z_up)], dtype=np.float64)
239
+
240
+
241
+ def _project_ego_points_to_cam(
242
+ pts_ego_xyz: np.ndarray, # (N,3) in nuScenes ego
243
+ R_c2e: np.ndarray,
244
+ t_c2e: np.ndarray,
245
+ K: np.ndarray,
246
+ ) -> np.ndarray:
247
+ """
248
+ Project nuScenes ego-frame 3D points to pixel coords in CAM_FRONT.
249
+ Returns (M,2) pixels for points with z_cam > 0.
250
+ """
251
+ # ego -> cam: p_cam = R_c2e^T (p_ego - t_c2e)
252
+ R_e2c = R_c2e.T
253
+ pts_cam = (R_e2c @ (pts_ego_xyz - t_c2e[None, :]).T).T # (N,3)
254
+ z = pts_cam[:, 2]
255
+ keep = z > 1e-3
256
+ pts_cam = pts_cam[keep]
257
+ if pts_cam.shape[0] == 0:
258
+ return np.zeros((0, 2), dtype=np.float64)
259
+ x = pts_cam[:, 0] / pts_cam[:, 2]
260
+ y = pts_cam[:, 1] / pts_cam[:, 2]
261
+ fx, fy = float(K[0, 0]), float(K[1, 1])
262
+ cx, cy = float(K[0, 2]), float(K[1, 2])
263
+ u = fx * x + cx
264
+ v = fy * y + cy
265
+ return np.stack([u, v], axis=1)
266
+
267
+
268
+ def _project_one_ego_point_to_cam(
269
+ pt_ego_xyz: np.ndarray, # (3,)
270
+ R_c2e: np.ndarray,
271
+ t_c2e: np.ndarray,
272
+ K: np.ndarray,
273
+ ) -> Optional[Tuple[float, float, float]]:
274
+ """Return (u, v, z_cam) or None if behind camera."""
275
+ pt_ego_xyz = np.asarray(pt_ego_xyz, dtype=np.float64).reshape(3)
276
+ R_e2c = R_c2e.T
277
+ pt_cam = R_e2c @ (pt_ego_xyz - t_c2e)
278
+ zc = float(pt_cam[2])
279
+ if zc <= 1e-3:
280
+ return None
281
+ fx, fy = float(K[0, 0]), float(K[1, 1])
282
+ cx, cy = float(K[0, 2]), float(K[1, 2])
283
+ u = fx * float(pt_cam[0] / zc) + cx
284
+ v = fy * float(pt_cam[1] / zc) + cy
285
+ return float(u), float(v), zc
286
+
287
+
288
def _draw_ego_bev(ax, *, color: str = "k"):
    """Draw the ego-vehicle footprint, a forward heading arrow, and an 'EGO' label on a BEV axis."""
    import matplotlib.patches as patches

    width, length = 1.85, 4.084  # footprint in meters, centered at the origin
    footprint = patches.Rectangle(
        (-width / 2.0, -length / 2.0),
        width,
        length,
        linewidth=1.2,
        edgecolor=color,
        facecolor="none",
        zorder=10,
    )
    ax.add_patch(footprint)
    # Heading arrow along +Y (forward in paper BEV coords).
    ax.arrow(0.0, 0.0, 0.0, length * 0.8, head_width=0.6, head_length=0.8, fc=color, ec=color, zorder=11)
    ax.text(0.2, -0.6, "EGO", color=color, fontsize=8, zorder=12)
304
+
305
+
306
+ def _box_corners_xy(cx: float, cy: float, w: float, l: float, yaw: float) -> np.ndarray:
307
+ """
308
+ Oriented rectangle corners in XY (paper coords).
309
+ yaw is treated as radians in the same XY frame (best-effort).
310
+
311
+ IMPORTANT: In our planning eval JSON, `yaw` behaves like a standard heading
312
+ angle measured from +X (right) axis (yaw-from-x). So:
313
+ - yaw = 0 => vehicle length points to +X (right)
314
+ - yaw = +pi/2 => vehicle length points to +Y (forward)
315
+ """
316
+ c, s = math.cos(yaw), math.sin(yaw)
317
+ center = np.array([cx, cy], dtype=np.float64)
318
+
319
+ # Length axis (heading) and width axis (perpendicular), yaw-from-x
320
+ d_len = np.array([c, s], dtype=np.float64) * (l / 2.0)
321
+ d_wid = np.array([-s, c], dtype=np.float64) * (w / 2.0)
322
+
323
+ corners = np.stack(
324
+ [
325
+ center + d_len + d_wid,
326
+ center + d_len - d_wid,
327
+ center - d_len - d_wid,
328
+ center - d_len + d_wid,
329
+ ],
330
+ axis=0,
331
+ )
332
+ return corners
333
+
334
+
335
+ def _title_case_command(cmd: str) -> str:
336
+ c = (cmd or "").strip().lower()
337
+ if c == "turn left":
338
+ return "Turn Left"
339
+ if c == "turn right":
340
+ return "Turn Right"
341
+ if c == "go straight":
342
+ return "Go Straight"
343
+ return cmd
344
+
345
+
346
+ def _short_cat(cat: str) -> str:
347
+ c = (cat or "").strip()
348
+ if not c:
349
+ return "obj"
350
+ tail = c.split(".")[-1]
351
+ mapping = {
352
+ "car": "car",
353
+ "truck": "truck",
354
+ "trailer": "trailer",
355
+ "bus": "bus",
356
+ "construction": "cveh",
357
+ "construction_vehicle": "cveh",
358
+ "pedestrian": "ped",
359
+ "trafficcone": "cone",
360
+ "traffic_cone": "cone",
361
+ "barrier": "barrier",
362
+ "motorcycle": "moto",
363
+ "bicycle": "bike",
364
+ }
365
+ return mapping.get(tail, tail[:10])
366
+
367
+
368
+ def _select_default_samples(data_items: List[Dict]) -> List[str]:
369
+ # prefer (turn right, go straight) like the paper figure
370
+ by_cmd: Dict[str, str] = {}
371
+ for it in data_items:
372
+ cmd = (it.get("ego_motion", {}) or {}).get("command")
373
+ if cmd and cmd not in by_cmd and it.get("id"):
374
+ by_cmd[str(cmd)] = str(it["id"])
375
+ out = []
376
+ if "turn right" in by_cmd:
377
+ out.append(by_cmd["turn right"])
378
+ if "go straight" in by_cmd:
379
+ out.append(by_cmd["go straight"])
380
+ if not out:
381
+ out = [str(it.get("id")) for it in data_items[:2] if it.get("id")]
382
+ return out[:2]
383
+
384
+
385
def parse_args():
    """CLI for the paper-style planning qualitative figure (6-camera mosaic + BEV)."""
    ap = argparse.ArgumentParser()
    # Inputs: model predictions, the eval data items, and the nuScenes root.
    ap.add_argument("--eval_json", type=str, default="work_dirs/eval_nocap_ep2_plan50.json")
    ap.add_argument("--data_json", type=str, default="data/_eval_ep2_plan50.json")
    ap.add_argument("--data_root", type=str, default="/home/guoyuanbo/autodl-tmp/data/nuscenes")
    ap.add_argument("--out_png", type=str, default="work_dirs/atlas_planning_qualitative.png")
    ap.add_argument("--sample_ids", type=str, nargs="*", default=None, help="Provide 1-2 sample ids; default picks turn right & go straight.")
    # BEV window in paper coords (x right, y forward), meters.
    ap.add_argument("--bev_xlim", type=float, nargs=2, default=[-30.0, 30.0])
    ap.add_argument("--bev_ylim", type=float, nargs=2, default=[-10.0, 55.0])
    ap.add_argument("--dpi", type=int, default=200)
    # Paired on/off flags: store_true defaults to True, so only the no_* variant changes anything.
    ap.add_argument("--draw_gt_boxes", action="store_true", default=True)
    ap.add_argument("--no_draw_gt_boxes", action="store_false", dest="draw_gt_boxes")
    ap.add_argument("--overlay_on_front_cam", action="store_true", default=True)
    ap.add_argument("--no_overlay_on_front_cam", action="store_false", dest="overlay_on_front_cam")
    ap.add_argument("--highlight_front_visible_boxes", action="store_true", default=True)
    ap.add_argument("--no_highlight_front_visible_boxes", action="store_false", dest="highlight_front_visible_boxes")
    ap.add_argument("--max_front_labels", type=int, default=8, help="Max GT box labels shown on CAM_FRONT and BEV.")
    ap.add_argument("--max_cam_labels", type=int, default=4, help="Max GT labels per camera view (all 6 views).")
    ap.add_argument("--bev_only_visible", action="store_true", default=True, help="If set, BEV draws only FRONT* visible GT boxes.")
    ap.add_argument("--bev_show_all", action="store_false", dest="bev_only_visible", help="Draw all GT boxes faintly in BEV (plus highlighted).")
    # Optional detector overlay (StreamPETR results in nuScenes submission format).
    ap.add_argument(
        "--det_results_json",
        type=str,
        default="external/StreamPETR/val/work_dirs/streampetr_atlas_aligned/Tue_Feb_10_23_13_58_2026/pts_bbox/results_nusc.json",
        help="nuScenes detection results JSON (StreamPETR).",
    )
    ap.add_argument("--draw_det_pred_boxes", action="store_true", default=True)
    ap.add_argument("--no_draw_det_pred_boxes", action="store_false", dest="draw_det_pred_boxes")
    ap.add_argument("--det_score_thresh", type=float, default=0.3)
    ap.add_argument("--det_max_boxes", type=int, default=30)
    ap.add_argument("--det_label_topk", type=int, default=6, help="Label top-k detector boxes in BEV.")
    return ap.parse_args()
417
+
418
+
419
def main():
    """Render the paper-style planning figure.

    Per selected sample (one row each, max two rows): a 2x3 camera mosaic on
    the left with the predicted/GT trajectory projected onto CAM_FRONT, and a
    BEV panel on the right showing trajectories, GT boxes, and (optionally)
    StreamPETR detector boxes transformed from global to ego/paper coords.
    """
    args = parse_args()
    repo = _REPO_ROOT
    eval_path = (repo / args.eval_json).resolve()
    data_path = (repo / args.data_json).resolve()
    nuscenes_root = Path(args.data_root).resolve()
    out_png = (repo / args.out_png).resolve()

    eval_obj = _load_json(eval_path)
    data_items = _load_json(data_path)

    # Build lookup maps
    pred_text_by_id: Dict[str, str] = {}
    for rec in eval_obj.get("predictions", []):
        sid = str(rec.get("sample_id", ""))
        if not sid:
            continue
        pred_text_by_id[sid] = str(rec.get("generated_text", ""))

    item_by_id: Dict[str, Dict] = {str(it.get("id")): it for it in data_items if it.get("id")}

    sample_ids = list(args.sample_ids) if args.sample_ids else _select_default_samples(data_items)
    sample_ids = [sid for sid in sample_ids if sid in item_by_id]
    if not sample_ids:
        raise RuntimeError("No valid sample_ids found. Provide --sample_ids explicitly.")
    if len(sample_ids) > 2:
        # Figure layout supports at most two rows.
        sample_ids = sample_ids[:2]

    # Calibration for projecting onto CAM_FRONT
    cam_calibs: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
    if args.overlay_on_front_cam:
        cam_calibs = _load_fixed_cam_calibs(nuscenes_root)

    from src.eval.metrics import parse_planning_output

    # Detector predictions (StreamPETR nuScenes results) are in GLOBAL coords.
    # We need ego pose to transform them into ego/paper coords for BEV plotting.
    det_mm = None
    det_f = None
    ego_pose_map = None
    det_results_path = (repo / args.det_results_json).resolve()
    if args.draw_det_pred_boxes and det_results_path.exists():
        try:
            ego_pose_map = _load_ego_pose_map(nuscenes_root)
            # mmap the (potentially huge) results JSON instead of json-loading it whole.
            det_f = det_results_path.open("rb")
            det_mm = mmap.mmap(det_f.fileno(), 0, access=mmap.ACCESS_READ)
        except Exception:
            # Best effort: a missing/unreadable results file just disables the overlay.
            det_mm = None
            if det_f is not None:
                det_f.close()
                det_f = None

    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

    nrows = len(sample_ids)
    fig = plt.figure(figsize=(12.5, 6.5 * nrows), dpi=args.dpi)
    fig.suptitle("Atlas Planning — 6-Camera + BEV (Paper-style)", fontsize=16, y=0.99)

    outer = GridSpec(nrows, 2, figure=fig, width_ratios=[3.2, 1.4], wspace=0.10, hspace=0.18)

    from PIL import Image

    for r, sid in enumerate(sample_ids):
        item = item_by_id[sid]
        pred_text = pred_text_by_id.get(sid, "")
        plan = parse_planning_output(pred_text) if pred_text else None
        pred_wps = (plan or {}).get("waypoints", []) if plan else []
        gt_wps = (item.get("ego_motion", {}) or {}).get("waypoints", []) or []
        cmd = (item.get("ego_motion", {}) or {}).get("command", "")

        # Load images
        rel_paths: List[str] = list(item.get("image_paths", []) or [])
        if len(rel_paths) != 6:
            raise RuntimeError(f"sample {sid}: expected 6 image_paths, got {len(rel_paths)}")
        imgs = []
        for rp in rel_paths:
            p = Path(rp)
            if not p.is_absolute():
                p = nuscenes_root / rp
            imgs.append(Image.open(p).convert("RGB"))
        # Permute dataset image order into the paper panel order (CAM_ORDER_PAPER).
        imgs = [imgs[i] for i in _IDX_REORDER]
        front_w, front_h = imgs[1].size  # CAM_FRONT in paper order

        # Left: 2x3 image mosaic
        left = GridSpecFromSubplotSpec(2, 3, subplot_spec=outer[r, 0], wspace=0.02, hspace=0.06)
        ax_imgs = []
        for i in range(6):
            ax = fig.add_subplot(left[i // 3, i % 3])
            ax.imshow(imgs[i])
            # Lock image axis limits so later overlays don't autoscale-shrink the image.
            w_i, h_i = imgs[i].size
            ax.set_xlim(0, w_i)
            ax.set_ylim(h_i, 0)
            ax.set_title(CAM_ORDER_PAPER[i], fontsize=9)
            ax.axis("off")
            ax_imgs.append(ax)

        # Collect GT boxes that are visible in each FRONT* camera (by projected centers).
        boxes = item.get("gt_boxes_3d", []) or []

        def _visible_for_channel(channel: str, img_w: int, img_h: int):
            # Returns (range, box_idx, category, x_r, y_f, z_u, u, v) tuples, nearest first.
            if not (args.highlight_front_visible_boxes and args.overlay_on_front_cam):
                return []
            if channel not in cam_calibs:
                return []
            R_c2e, t_c2e, K = cam_calibs[channel]
            out = []
            for bi, b in enumerate(boxes):
                if not isinstance(b, dict) or "world_coords" not in b:
                    continue
                wc = b.get("world_coords", [0.0, 0.0, 0.0])
                if not (isinstance(wc, (list, tuple)) and len(wc) >= 3):
                    continue
                x_r, y_f, z_u = float(wc[0]), float(wc[1]), float(wc[2])
                pt_ego = _paper_xy_to_nuscenes_ego_xyz(x_r, y_f, z_u)
                uv = _project_one_ego_point_to_cam(pt_ego, R_c2e, t_c2e, K)
                if uv is None:
                    continue
                u, v, _zc = uv
                if 0.0 <= u <= float(img_w) and 0.0 <= v <= float(img_h):
                    d = float(math.hypot(x_r, y_f))
                    out.append((d, bi, str(b.get("category", "")), x_r, y_f, z_u, u, v))
            out.sort(key=lambda t: t[0])
            return out

        vis_by_ch = {
            "CAM_FRONT_LEFT": _visible_for_channel("CAM_FRONT_LEFT", imgs[0].size[0], imgs[0].size[1]),
            "CAM_FRONT": _visible_for_channel("CAM_FRONT", imgs[1].size[0], imgs[1].size[1]),
            "CAM_FRONT_RIGHT": _visible_for_channel("CAM_FRONT_RIGHT", imgs[2].size[0], imgs[2].size[1]),
            "CAM_BACK_LEFT": _visible_for_channel("CAM_BACK_LEFT", imgs[3].size[0], imgs[3].size[1]),
            "CAM_BACK": _visible_for_channel("CAM_BACK", imgs[4].size[0], imgs[4].size[1]),
            "CAM_BACK_RIGHT": _visible_for_channel("CAM_BACK_RIGHT", imgs[5].size[0], imgs[5].size[1]),
        }
        # Union over cameras with per-box de-duplication (first camera hit wins).
        visible_union_all = []
        if any(vis_by_ch.get(ch) for ch in vis_by_ch.keys()):
            seen = set()
            for ch in ("CAM_FRONT_LEFT", "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_BACK_LEFT", "CAM_BACK", "CAM_BACK_RIGHT"):
                for tup in vis_by_ch.get(ch, []):
                    bi = tup[1]
                    if bi in seen:
                        continue
                    seen.add(bi)
                    visible_union_all.append(tup)
        visible_union_all.sort(key=lambda t: t[0])
        visible_union = visible_union_all[: max(int(args.max_front_labels), 0)]

        # Overlay trajectory on CAM_FRONT (middle of top row in paper order)
        if args.overlay_on_front_cam and pred_wps and "CAM_FRONT" in cam_calibs:
            R_c2e, t_c2e, K = cam_calibs["CAM_FRONT"]
            ax_front = ax_imgs[1]
            # Pred
            pts_pred = np.array([_paper_xy_to_nuscenes_ego_xyz(x, y, 0.0) for x, y in pred_wps], dtype=np.float64)
            uv_pred = _project_ego_points_to_cam(pts_pred, R_c2e, t_c2e, K)
            if uv_pred.shape[0] >= 2:
                ax_front.plot(uv_pred[:, 0], uv_pred[:, 1], "-o", color="#00C2FF", linewidth=2.0, markersize=4.0, alpha=0.95)
            # GT (dashed)
            if gt_wps:
                pts_gt = np.array([_paper_xy_to_nuscenes_ego_xyz(x, y, 0.0) for x, y in gt_wps], dtype=np.float64)
                uv_gt = _project_ego_points_to_cam(pts_gt, R_c2e, t_c2e, K)
                if uv_gt.shape[0] >= 2:
                    ax_front.plot(uv_gt[:, 0], uv_gt[:, 1], "--o", color="white", linewidth=2.0, markersize=3.5, alpha=0.85)

        # Highlight GT centers on the three FRONT* cameras (helps validate BEV↔camera consistency).
        if args.highlight_front_visible_boxes and args.overlay_on_front_cam:
            cam_axes = {
                "CAM_FRONT_LEFT": ax_imgs[0],
                "CAM_FRONT": ax_imgs[1],
                "CAM_FRONT_RIGHT": ax_imgs[2],
                "CAM_BACK_LEFT": ax_imgs[3],
                "CAM_BACK": ax_imgs[4],
                "CAM_BACK_RIGHT": ax_imgs[5],
            }
            for ch, ax in cam_axes.items():
                vis = (vis_by_ch.get(ch, []) or [])[: max(int(args.max_cam_labels), 0)]
                if not vis:
                    continue
                for _d, _bi, cat, _x, _y, _z, u, v in vis:
                    ax.scatter([u], [v], s=26, c="#FF4D4D", edgecolors="black", linewidths=0.6, zorder=20)
                    ax.text(
                        u + 6.0,
                        v - 6.0,
                        _short_cat(cat),
                        color="white",
                        fontsize=8,
                        fontweight="bold",
                        ha="left",
                        va="bottom",
                        zorder=21,
                        bbox=dict(boxstyle="round,pad=0.15", facecolor="black", edgecolor="none", alpha=0.6),
                    )

        # Right: BEV
        ax_bev = fig.add_subplot(outer[r, 1])
        ax_bev.set_title("BEV", fontsize=12)
        ax_bev.set_xlabel("X (m) — right", fontsize=9)
        ax_bev.set_ylabel("Y (m) — forward", fontsize=9)
        ax_bev.set_aspect("equal", adjustable="box")
        ax_bev.grid(True, linewidth=0.4, alpha=0.35)
        ax_bev.set_xlim(float(args.bev_xlim[0]), float(args.bev_xlim[1]))
        ax_bev.set_ylim(float(args.bev_ylim[0]), float(args.bev_ylim[1]))

        _draw_ego_bev(ax_bev, color="black")

        # Draw GT boxes (paper-style context)
        if args.draw_gt_boxes:
            boxes = item.get("gt_boxes_3d", []) or []

            # When debugging camera↔BEV consistency, draw only the boxes that are visible
            # in the shown camera views to avoid confusing off-screen objects.
            if args.bev_only_visible and visible_union_all:
                boxes_to_draw = [boxes[bi] for _d, bi, _cat, *_rest in visible_union_all if 0 <= bi < len(boxes)]
            else:
                boxes_to_draw = boxes

            for b in boxes_to_draw:
                if not isinstance(b, dict) or "world_coords" not in b:
                    continue
                wc = b.get("world_coords", [0.0, 0.0, 0.0])
                cx, cy = float(wc[0]), float(wc[1])
                w = float(b.get("w", 1.8))
                l = float(b.get("l", 4.0))
                yaw = float(b.get("yaw", 0.0))
                corners = _box_corners_xy(cx, cy, w, l, yaw)
                poly = np.vstack([corners, corners[0:1]])
                ax_bev.plot(poly[:, 0], poly[:, 1], color="#FF5A5A", linewidth=0.8, alpha=0.50)

            # Re-draw FRONT* visible boxes thicker + label them in BEV.
            if args.highlight_front_visible_boxes and visible_union:
                for _d, bi, cat, x_r, y_f, _z, _u, _v in visible_union:
                    if bi >= len(boxes):
                        continue
                    b = boxes[bi]
                    wc = b.get("world_coords", [0.0, 0.0, 0.0])
                    cx, cy = float(wc[0]), float(wc[1])
                    w = float(b.get("w", 1.8))
                    l = float(b.get("l", 4.0))
                    yaw = float(b.get("yaw", 0.0))
                    corners = _box_corners_xy(cx, cy, w, l, yaw)
                    poly = np.vstack([corners, corners[0:1]])
                    ax_bev.plot(poly[:, 0], poly[:, 1], color="#FF2E2E", linewidth=1.8, alpha=0.95)
                    ax_bev.text(
                        cx,
                        cy,
                        _short_cat(cat),
                        fontsize=7.5,
                        color="#FF2E2E",
                        ha="center",
                        va="center",
                        zorder=30,
                        bbox=dict(boxstyle="round,pad=0.12", facecolor="white", edgecolor="none", alpha=0.55),
                    )

        # Draw detector predicted boxes (StreamPETR) in BEV.
        det_handle = None
        if args.draw_det_pred_boxes and det_mm is not None and ego_pose_map is not None:
            R_e2g_t = None
            t_e2g_t = None
            if sid in ego_pose_map:
                R_e2g_t, t_e2g_t = ego_pose_map[sid]
            if R_e2g_t is not None and t_e2g_t is not None:
                dets = _extract_results_list_for_token(det_mm, sid)
                # Filter & transform
                det_plot = []
                xlo, xhi = float(args.bev_xlim[0]), float(args.bev_xlim[1])
                ylo, yhi = float(args.bev_ylim[0]), float(args.bev_ylim[1])
                for d in dets:
                    try:
                        score = float(d.get("detection_score", 0.0))
                    except Exception:
                        score = 0.0
                    if score < float(args.det_score_thresh):
                        continue
                    tr = d.get("translation", None)
                    sz = d.get("size", None)
                    rot = d.get("rotation", None)
                    if not (isinstance(tr, (list, tuple)) and len(tr) >= 3):
                        continue
                    if not (isinstance(sz, (list, tuple)) and len(sz) >= 3):
                        continue
                    if not (isinstance(rot, (list, tuple)) and len(rot) == 4):
                        continue

                    p_g = np.asarray([float(tr[0]), float(tr[1]), float(tr[2])], dtype=np.float64)
                    # global -> ego
                    p_e = (R_e2g_t.T @ (p_g - t_e2g_t)).reshape(3)
                    # ego -> paper
                    x_p = float(-p_e[1])
                    y_p = float(p_e[0])
                    z_p = float(p_e[2])
                    if not (xlo <= x_p <= xhi and ylo <= y_p <= yhi):
                        continue

                    # orientation: global -> ego -> paper yaw
                    R_box_g = _quat_to_rotmat(float(rot[0]), float(rot[1]), float(rot[2]), float(rot[3]))
                    R_box_e = R_e2g_t.T @ R_box_g
                    yaw_e = float(math.atan2(R_box_e[1, 0], R_box_e[0, 0]))  # yaw-from-x in ego
                    yaw_p = yaw_e + math.pi / 2.0  # convert to paper yaw-from-x

                    # NOTE(review): size is read as (w, l) = (sz[0], sz[1]) — assumes the
                    # nuScenes (w, l, h) size convention; verify against the results file.
                    w = float(sz[0])
                    l = float(sz[1])
                    cat = str(d.get("detection_name", "obj"))
                    det_plot.append((score, cat, x_p, y_p, z_p, w, l, yaw_p))

                # Keep the highest-scoring boxes only.
                det_plot.sort(key=lambda t: -t[0])
                det_plot = det_plot[: max(int(args.det_max_boxes), 0)]

                det_color = "#00A65A"  # green
                for score, cat, x_p, y_p, _z, w, l, yaw_p in det_plot:
                    corners = _box_corners_xy(x_p, y_p, w, l, yaw_p)
                    poly = np.vstack([corners, corners[0:1]])
                    ax_bev.plot(poly[:, 0], poly[:, 1], color=det_color, linewidth=1.1, alpha=0.65, linestyle="--")

                # Label top-k predictions to show "front vehicles"
                for score, cat, x_p, y_p, _z, w, l, yaw_p in det_plot[: max(int(args.det_label_topk), 0)]:
                    ax_bev.text(
                        x_p,
                        y_p,
                        _short_cat(cat),
                        fontsize=7.2,
                        color=det_color,
                        ha="center",
                        va="center",
                        zorder=31,
                        bbox=dict(boxstyle="round,pad=0.12", facecolor="white", edgecolor="none", alpha=0.55),
                    )

                # Legend handle for detector boxes
                from matplotlib.lines import Line2D
                det_handle = Line2D([0], [0], color=det_color, lw=1.5, linestyle="--", label="Det Pred boxes")

        # Plot trajectories
        if gt_wps:
            gt_arr = np.asarray(gt_wps, dtype=np.float64)
            ax_bev.plot(gt_arr[:, 0], gt_arr[:, 1], "--o", color="black", linewidth=1.8, markersize=4.0, alpha=0.85, label="GT traj")
        if pred_wps:
            pr_arr = np.asarray(pred_wps, dtype=np.float64)
            ax_bev.plot(pr_arr[:, 0], pr_arr[:, 1], "-o", color="#00C2FF", linewidth=2.4, markersize=4.5, alpha=0.95, label="Pred traj")

        # Legend: add GT boxes handle to avoid confusion (boxes are NOT predicted here).
        handles, _labels = ax_bev.get_legend_handles_labels()
        if args.draw_gt_boxes:
            import matplotlib.patches as mpatches
            handles.append(mpatches.Patch(facecolor="none", edgecolor="#FF5A5A", linewidth=1.2, label="GT boxes"))
        if det_handle is not None:
            handles.append(det_handle)
        ax_bev.legend(handles=handles, loc="upper left", fontsize=9, frameon=True)

        # Command label (like paper: bottom-left)
        ax_bev.text(
            0.02,
            0.02,
            _title_case_command(str(cmd)),
            transform=ax_bev.transAxes,
            fontsize=12,
            fontweight="bold",
            ha="left",
            va="bottom",
            color="black",
            bbox=dict(boxstyle="round,pad=0.25", facecolor="white", edgecolor="none", alpha=0.65),
        )

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)
    print(f"[saved] {out_png}")

    # Release the mmap'd detector results file, if it was opened.
    if det_mm is not None:
        try:
            det_mm.close()
        except Exception:
            pass
    if det_f is not None:
        try:
            det_f.close()
        except Exception:
            pass
796
+
797
+
798
# Script entry point.
if __name__ == "__main__":
    main()
800
+
scripts/vis_traffic_violation.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Figure 11 — Violation of traffic regulations.
4
+
5
+ This script focuses on the construction-blocking case shown in the paper:
6
+ the road ahead is fenced by barriers / traffic cones, but Atlas still outputs
7
+ a "go straight" trajectory that cuts through the blocked area.
8
+
9
+ Style:
10
+ - Left: 6 camera views, only CAM_FRONT overlays the trajectory
11
+ - Right: clean BEV with black road boundaries, red construction boxes,
12
+ green/blue trajectory, and "Go Straight" label
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import json
19
+ import math
20
+ import pickle
21
+ import sys
22
+ from pathlib import Path
23
+ from typing import Dict, List, Optional, Tuple
24
+
25
+ import numpy as np
26
+
27
# Make the repository root importable when the script is run directly.
_REPO = Path(__file__).resolve().parent.parent
if str(_REPO) not in sys.path:
    sys.path.insert(0, str(_REPO))

# Camera panel layout used by the paper figure:
# top row FRONT_LEFT / FRONT / FRONT_RIGHT, bottom row BACK_LEFT / BACK / BACK_RIGHT.
CAM_ORDER_PAPER = [
    "CAM_FRONT_LEFT",
    "CAM_FRONT",
    "CAM_FRONT_RIGHT",
    "CAM_BACK_LEFT",
    "CAM_BACK",
    "CAM_BACK_RIGHT",
]
# Permutation applied to the dataset's image list to obtain CAM_ORDER_PAPER
# (presumably the loader's native camera order differs — TODO confirm against the dataset).
_IDX_REORDER = [2, 0, 1, 4, 3, 5]

# Rasterized nuScenes map image per log location (files live under <root>/maps/).
LOCATION_TO_MAP = {
    "singapore-onenorth": "53992ee3023e5494b90c316c183be829.png",
    "boston-seaport": "36092f0b03a857c6a3403e25b4b7aab3.png",
    "singapore-queenstown": "93406b464a165eaba6d9de76ca09f5da.png",
    "singapore-hollandvillage": "37819e65e09e5547b8a3ceaefba56bb2.png",
}
MAP_RES = 0.1  # meters per pixel
# Sample token for the paper's Figure 11 construction-blocking case.
DEFAULT_FIG11_SAMPLE = "856ccc626a4a4c0aaac1e62335050ac0"
49
+
50
+
51
+ def _load_json(path: Path):
52
+ with path.open("r", encoding="utf-8") as f:
53
+ return json.load(f)
54
+
55
+
56
+ def _load_pickle(path: Path):
57
+ with path.open("rb") as f:
58
+ return pickle.load(f)
59
+
60
+
61
+ def _quat_to_rotmat(qw, qx, qy, qz):
62
+ n = math.sqrt(qw * qw + qx * qx + qy * qy + qz * qz)
63
+ if n < 1e-12:
64
+ return np.eye(3, dtype=np.float64)
65
+ qw, qx, qy, qz = qw / n, qx / n, qy / n, qz / n
66
+ xx, yy, zz = qx * qx, qy * qy, qz * qz
67
+ xy, xz, yz = qx * qy, qx * qz, qy * qz
68
+ wx, wy, wz = qw * qx, qw * qy, qw * qz
69
+ return np.array(
70
+ [
71
+ [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
72
+ [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
73
+ [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)],
74
+ ],
75
+ dtype=np.float64,
76
+ )
77
+
78
+
79
+ def _quat_to_yaw(q):
80
+ w, x, y, z = [float(v) for v in q]
81
+ return math.atan2(2 * (w * z + x * y), 1 - 2 * (y * y + z * z))
82
+
83
+
84
+ def _paper_xy_to_ego(x_right: float, y_fwd: float, z_up: float = 0.0) -> np.ndarray:
85
+ return np.array([float(y_fwd), float(-x_right), float(z_up)], dtype=np.float64)
86
+
87
+
88
+ def _project_batch(pts_ego: np.ndarray, R_c2e: np.ndarray, t_c2e: np.ndarray, K: np.ndarray) -> np.ndarray:
89
+ R_e2c = R_c2e.T
90
+ pts_cam = (R_e2c @ (pts_ego - t_c2e[None, :]).T).T
91
+ z = pts_cam[:, 2]
92
+ keep = z > 1e-3
93
+ pts_cam = pts_cam[keep]
94
+ if pts_cam.shape[0] == 0:
95
+ return np.zeros((0, 2), dtype=np.float64)
96
+ x = pts_cam[:, 0] / pts_cam[:, 2]
97
+ y = pts_cam[:, 1] / pts_cam[:, 2]
98
+ fx, fy = float(K[0, 0]), float(K[1, 1])
99
+ cx, cy = float(K[0, 2]), float(K[1, 2])
100
+ return np.stack([fx * x + cx, fy * y + cy], axis=1)
101
+
102
+
103
def _load_cam_calibs(nusc_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Load fixed (R_c2e, t_c2e, K) per camera channel from the nuScenes metadata tables."""
    meta = nusc_root / "v1.0-trainval"
    sensors = _load_json(meta / "sensor.json")
    calibs = _load_json(meta / "calibrated_sensor.json")

    # channel -> sensor token (string entries only; last record wins).
    token_of_channel = {
        rec.get("channel"): rec.get("token")
        for rec in sensors
        if isinstance(rec.get("channel"), str) and isinstance(rec.get("token"), str)
    }
    # sensor token -> calibrated_sensor record (last record wins).
    calib_of_token = {
        rec.get("sensor_token"): rec
        for rec in calibs
        if isinstance(rec.get("sensor_token"), str)
    }

    result: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
    for channel, token in token_of_channel.items():
        rec = calib_of_token.get(token)
        if not rec or "camera_intrinsic" not in rec:
            # Non-camera sensors carry no intrinsics; skip them.
            continue
        q = rec.get("rotation", [1, 0, 0, 0])
        if not (isinstance(q, (list, tuple)) and len(q) == 4):
            continue
        K = np.asarray(rec.get("camera_intrinsic", np.eye(3).tolist()), dtype=np.float64)
        if K.shape != (3, 3):
            continue
        t = np.asarray(rec.get("translation", [0, 0, 0]), dtype=np.float64).reshape(3)
        result[channel] = (_quat_to_rotmat(*[float(x) for x in q]), t, K)
    return result
136
+
137
+
138
def _load_ego_poses(nusc_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, float]]:
    """Map sample token -> (R_ego2global, t_ego2global, ego yaw) from the val info pickle."""
    obj = _load_pickle(nusc_root / "nuscenes_infos_val.pkl")
    # The pickle is either {"infos": [...]} or a bare list of info dicts.
    infos = obj["infos"] if isinstance(obj, dict) and "infos" in obj else obj
    poses: Dict[str, Tuple[np.ndarray, np.ndarray, float]] = {}
    for info in infos:
        token = str(info.get("token", ""))
        if not token:
            continue
        q = info.get("ego2global_rotation", [1, 0, 0, 0])
        if not (isinstance(q, (list, tuple)) and len(q) == 4):
            continue
        t = np.asarray(info.get("ego2global_translation", [0, 0, 0]), dtype=np.float64).reshape(3)
        poses[token] = (_quat_to_rotmat(*[float(x) for x in q]), t, _quat_to_yaw(q))
    return poses
152
+
153
+
154
def _get_sample_location(nusc_root: Path, sample_token: str) -> str:
    """Resolve the map location (e.g. 'boston-seaport') for a sample token; '' when unknown."""
    meta = nusc_root / "v1.0-trainval"
    samples = _load_json(meta / "sample.json")
    scenes = _load_json(meta / "scene.json")
    logs = _load_json(meta / "log.json")
    scene_by_token = {rec["token"]: rec for rec in scenes}
    log_by_token = {rec["token"]: rec for rec in logs}
    # sample -> scene -> log -> location.
    for sample in samples:
        if sample["token"] != sample_token:
            continue
        scene = scene_by_token.get(sample.get("scene_token", ""), {})
        log = log_by_token.get(scene.get("log_token", ""), {})
        return str(log.get("location", ""))
    return ""
166
+
167
+
168
+ def _smooth_map(bev_map: np.ndarray) -> np.ndarray:
169
+ acc = bev_map.copy()
170
+ acc += np.roll(bev_map, 1, axis=0)
171
+ acc += np.roll(bev_map, -1, axis=0)
172
+ acc += np.roll(bev_map, 1, axis=1)
173
+ acc += np.roll(bev_map, -1, axis=1)
174
+ acc += np.roll(np.roll(bev_map, 1, axis=0), 1, axis=1)
175
+ acc += np.roll(np.roll(bev_map, 1, axis=0), -1, axis=1)
176
+ acc += np.roll(np.roll(bev_map, -1, axis=0), 1, axis=1)
177
+ acc += np.roll(np.roll(bev_map, -1, axis=0), -1, axis=1)
178
+ return acc / 9.0
179
+
180
+
181
def _build_bev_map(
    nusc_root: Path,
    location: str,
    ego_xy: np.ndarray,
    ego_yaw: float,
    bev_xlim: Tuple[float, float],
    bev_ylim: Tuple[float, float],
    bev_res: float = 0.1,
) -> Optional[np.ndarray]:
    """
    Crop an ego-centered, ego-aligned window out of the rasterized nuScenes map.

    ego_xy / ego_yaw are the ego pose in GLOBAL map coordinates; bev_xlim /
    bev_ylim define the window in paper BEV coords (x right, y forward, in
    meters) at bev_res meters per output pixel. Returns a (ny, nx) float32
    array in [0, 1], smoothed with a 3x3 blur, or None when no map image is
    available for `location`.
    """
    import PIL.Image
    # Map rasters exceed PIL's default decompression-bomb limit; disable it.
    PIL.Image.MAX_IMAGE_PIXELS = None
    from PIL import Image

    map_fn = LOCATION_TO_MAP.get(location)
    if not map_fn:
        return None
    map_path = nusc_root / "maps" / map_fn
    if not map_path.exists():
        return None

    map_img = Image.open(map_path)
    mw, mh = map_img.size
    # The PNG's row 0 corresponds to the map's maximum y (image y is flipped).
    map_max_y = mh * MAP_RES
    map_arr = np.asarray(map_img, dtype=np.float32) / 255.0

    ex, ey = float(ego_xy[0]), float(ego_xy[1])
    c_yaw, s_yaw = math.cos(ego_yaw), math.sin(ego_yaw)

    x0, x1 = bev_xlim
    y0, y1 = bev_ylim
    nx = int((x1 - x0) / bev_res)
    ny = int((y1 - y0) / bev_res)
    bev = np.zeros((ny, nx), dtype=np.float32)

    # Output grid in paper coords; row 0 is the far (max y) edge.
    px_arr = np.linspace(x0, x1, nx)
    py_arr = np.linspace(y1, y0, ny)
    PX, PY = np.meshgrid(px_arr, py_arr)

    # Rotate paper-frame offsets by ego yaw into global map coordinates
    # (paper forward = ego heading, paper right = heading rotated -90°).
    GX = ex + PY * c_yaw + PX * s_yaw
    GY = ey + PY * s_yaw - PX * c_yaw

    # Nearest-neighbor sample from the map raster; out-of-bounds stays 0.
    MX = (GX / MAP_RES).astype(np.int32)
    MY = ((map_max_y - GY) / MAP_RES).astype(np.int32)
    valid = (MX >= 0) & (MX < mw) & (MY >= 0) & (MY < mh)
    bev[valid] = map_arr[MY[valid], MX[valid]]
    return _smooth_map(bev)
227
+
228
+
229
+ def _box_corners(cx: float, cy: float, w: float, l: float, yaw: float) -> np.ndarray:
230
+ c, s = math.cos(yaw), math.sin(yaw)
231
+ center = np.array([cx, cy], dtype=np.float64)
232
+ d_len = np.array([c, s], dtype=np.float64) * (l / 2.0)
233
+ d_wid = np.array([-s, c], dtype=np.float64) * (w / 2.0)
234
+ return np.stack(
235
+ [
236
+ center + d_len + d_wid,
237
+ center + d_len - d_wid,
238
+ center - d_len - d_wid,
239
+ center - d_len + d_wid,
240
+ ],
241
+ axis=0,
242
+ )
243
+
244
+
245
+ def _is_barrier_like(cat: str) -> bool:
246
+ return ("barrier" in cat) or ("traffic_cone" in cat) or ("trafficcone" in cat)
247
+
248
+
249
+ def _is_context_like(cat: str) -> bool:
250
+ return ("construction" in cat) or ("pedestrian" in cat) or ("vehicle.car" in cat)
251
+
252
+
253
+ def _title_cmd(cmd: str) -> str:
254
+ c = (cmd or "").strip().lower()
255
+ return {
256
+ "turn left": "Turn Left",
257
+ "turn right": "Turn Right",
258
+ "go straight": "Go Straight",
259
+ }.get(c, cmd)
260
+
261
+
262
def _blocked_score(item: Dict) -> float:
    """Heuristic score of how blocked a 'go straight' sample's corridor is.

    Non-straight samples are pushed to -1e9 so they never win selection.
    Weights: obstacles in the narrow center corridor count 8x, the wider
    frontal zone 3x, GT waypoints that cross the center corridor 4x,
    plus 1.5 per barrier and 1 per cone.
    NOTE(review): the corridor counters include boxes of every category,
    not only barrier-like ones — presumably intentional; confirm.
    """
    if str(item.get("route_command", "")).strip().lower() != "go straight":
        return -1e9
    boxes = item.get("gt_boxes_3d", []) or []
    waypoints = (item.get("ego_motion", {}) or {}).get("waypoints", []) or []
    front_hits = 0
    center_hits = 0
    barriers = 0
    cones = 0
    for box in boxes:
        coords = box.get("world_coords", [0, 0, 0])
        if not (isinstance(coords, (list, tuple)) and len(coords) >= 2):
            continue
        x, y = float(coords[0]), float(coords[1])
        category = str(box.get("category", ""))
        if _is_barrier_like(category):
            if "barrier" in category:
                barriers += 1
            else:
                cones += 1
        if 0 < y < 25 and abs(x) < 10:
            front_hits += 1
        if 2 < y < 20 and abs(x) < 4:
            center_hits += 1
    wp_center = sum(
        1 for wx, wy in waypoints if 2 < float(wy) < 20 and abs(float(wx)) < 4
    )
    return center_hits * 8 + front_hits * 3 + wp_center * 4 + barriers * 1.5 + cones
292
+
293
+
294
def parse_args():
    """Parse CLI options for the traffic-violation qualitative figure."""
    ap = argparse.ArgumentParser()
    # Evaluation output (JSON with a "predictions" list) and the planning dataset JSON.
    ap.add_argument("--eval_json", default="work_dirs/eval_final_plan100.json")
    ap.add_argument("--data_json", default="data/atlas_planning_val_uniad_command.json")
    # nuScenes dataset root (images, maps, metadata tables).
    ap.add_argument("--data_root", default="/home/guoyuanbo/autodl-tmp/data/nuscenes")
    # Explicit sample token; when omitted, main() auto-selects a blocked sample.
    ap.add_argument("--sample_id", default=None)
    ap.add_argument("--out_png", default="work_dirs/atlas_traffic_violation.png")
    ap.add_argument("--dpi", type=int, default=200)
    # BEV plot extent in paper-ego coordinates (meters): x lateral, y forward.
    ap.add_argument("--bev_xlim", type=float, nargs=2, default=[-7.5, 10.5])
    ap.add_argument("--bev_ylim", type=float, nargs=2, default=[-1.5, 30.5])
    return ap.parse_args()
305
+
306
+
307
def main():
    """Render the qualitative figure: six camera views plus a BEV panel for a
    planning sample whose 'go straight' route is blocked by construction.

    Sample selection: --sample_id when given; otherwise DEFAULT_FIG11_SAMPLE if
    it has a parseable prediction, else the highest ``_blocked_score`` candidate.
    """
    args = parse_args()
    repo = _REPO
    nusc_root = Path(args.data_root).resolve()
    out_png = (repo / args.out_png).resolve()

    # Load predictions and dataset items, then index both by sample id.
    eval_obj = _load_json((repo / args.eval_json).resolve())
    data_items = _load_json((repo / args.data_json).resolve())
    pred_by_id = {
        str(r.get("sample_id", "")): str(r.get("generated_text", ""))
        for r in eval_obj.get("predictions", [])
        if r.get("sample_id")
    }
    item_by_id = {str(it["id"]): it for it in data_items if it.get("id")}

    from src.eval.metrics import parse_planning_output

    # --- Sample selection ---
    sid = args.sample_id
    if not sid:
        if DEFAULT_FIG11_SAMPLE in item_by_id and parse_planning_output(pred_by_id.get(DEFAULT_FIG11_SAMPLE, "")):
            sid = DEFAULT_FIG11_SAMPLE
        else:
            # Rank samples by how strongly the straight-ahead corridor is blocked.
            candidates = []
            for item in data_items:
                item_sid = str(item.get("id", ""))
                pred_text = pred_by_id.get(item_sid, "")
                if not pred_text:
                    continue
                plan = parse_planning_output(pred_text)
                if not plan or not plan.get("waypoints"):
                    continue
                candidates.append((_blocked_score(item), item_sid))
            candidates.sort(reverse=True)
            if not candidates:
                raise RuntimeError("No valid construction-blocked sample found.")
            sid = candidates[0][1]
    if sid not in item_by_id:
        raise RuntimeError(f"sample_id {sid} not found in data_json")

    # --- Gather GT / prediction / pose data for the chosen sample ---
    item = item_by_id[sid]
    pred_text = pred_by_id.get(sid, "")
    plan = parse_planning_output(pred_text) if pred_text else None
    if not plan or not plan.get("waypoints"):
        raise RuntimeError(f"sample_id {sid} has no parseable planning output in {args.eval_json}")

    pred_wps = np.asarray(plan["waypoints"], dtype=np.float64)
    gt_wps = np.asarray((item.get("ego_motion", {}) or {}).get("waypoints", []) or [], dtype=np.float64)
    cmd = str(item.get("route_command", ""))
    boxes = item.get("gt_boxes_3d", []) or []
    location = _get_sample_location(nusc_root, sid)
    ego_poses = _load_ego_poses(nusc_root)
    ego_info = ego_poses.get(sid)
    if ego_info is None:
        raise RuntimeError(f"Missing ego pose for sample {sid}")
    ego_xy = ego_info[1][:2]
    ego_yaw = ego_info[2]

    print(f"[sample] {sid}")
    print(f"  location: {location}")
    print(f"  pred: {pred_text}")

    # Ego-aligned drivable-area raster for the BEV panel (None when unavailable).
    bev_map = _build_bev_map(
        nusc_root,
        location,
        ego_xy,
        ego_yaw,
        tuple(args.bev_xlim),
        tuple(args.bev_ylim),
        bev_res=0.1,
    )

    from PIL import Image

    # --- Load and reorder the six camera images to the paper's 2x3 layout ---
    rel_paths = list(item.get("image_paths", []) or [])
    if len(rel_paths) != 6:
        raise RuntimeError(f"Expected 6 images, got {len(rel_paths)}")
    imgs = []
    for rp in rel_paths:
        p = Path(rp)
        if not p.is_absolute():
            p = nusc_root / rp
        imgs.append(Image.open(p).convert("RGB"))
    imgs = [imgs[i] for i in _IDX_REORDER]

    cam_calibs = _load_cam_calibs(nusc_root)

    import matplotlib
    matplotlib.use("Agg")  # headless rendering
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

    # Left: 2x3 camera grid; right: BEV panel.
    fig = plt.figure(figsize=(14.8, 4.1), dpi=args.dpi)
    gs = GridSpec(1, 2, figure=fig, width_ratios=[3.0, 1.1], wspace=0.03)
    gs_cam = GridSpecFromSubplotSpec(2, 3, subplot_spec=gs[0, 0], wspace=0.01, hspace=0.01)

    ax_imgs = []
    for i in range(6):
        ax = fig.add_subplot(gs_cam[i // 3, i % 3])
        ax.imshow(imgs[i])
        w_i, h_i = imgs[i].size
        ax.set_xlim(0, w_i)
        ax.set_ylim(h_i, 0)  # image coordinates: y grows downward
        ax.axis("off")
        # Camera-name tag in the top-left corner of each view.
        ax.text(
            6,
            14,
            CAM_ORDER_PAPER[i],
            color="white",
            fontsize=7,
            ha="left",
            va="top",
            bbox=dict(boxstyle="square,pad=0.12", facecolor="black", edgecolor="none", alpha=0.55),
        )
        ax_imgs.append(ax)

    # --- Project GT (green) and predicted (blue) waypoints onto the front view ---
    # NOTE(review): ax_imgs[1] is assumed to be CAM_FRONT after _IDX_REORDER;
    # confirm against CAM_ORDER_PAPER.
    if "CAM_FRONT" in cam_calibs:
        R_c2e, t_c2e, K = cam_calibs["CAM_FRONT"]
        ax_front = ax_imgs[1]
        if gt_wps.shape[0] >= 2:
            uv_gt = _project_batch(
                np.array([_paper_xy_to_ego(x, y) for x, y in gt_wps], dtype=np.float64),
                R_c2e,
                t_c2e,
                K,
            )
            if uv_gt.shape[0] >= 2:
                ax_front.plot(uv_gt[:, 0], uv_gt[:, 1], color="#34c759", linewidth=4.0, alpha=0.95, zorder=18)
        uv_pred = _project_batch(
            np.array([_paper_xy_to_ego(x, y) for x, y in pred_wps], dtype=np.float64),
            R_c2e,
            t_c2e,
            K,
        )
        if uv_pred.shape[0] >= 2:
            ax_front.plot(uv_pred[:, 0], uv_pred[:, 1], color="#1f5cff", linewidth=2.2, alpha=0.98, zorder=20)
            ax_front.scatter(uv_pred[:, 0], uv_pred[:, 1], color="#1f5cff", s=8, zorder=21)

    # --- BEV panel setup ---
    ax_bev = fig.add_subplot(gs[0, 1])
    ax_bev.set_facecolor("white")
    ax_bev.set_xlim(*args.bev_xlim)
    ax_bev.set_ylim(*args.bev_ylim)
    ax_bev.set_aspect("equal", adjustable="box")
    ax_bev.set_xticks([])
    ax_bev.set_yticks([])
    for spine in ax_bev.spines.values():
        spine.set_linewidth(1.3)
        spine.set_color("black")

    # Drivable-area boundary as a 0.5-level contour of the smoothed map.
    if bev_map is not None:
        ny, nx = bev_map.shape
        xs = np.linspace(args.bev_xlim[0], args.bev_xlim[1], nx)
        ys = np.linspace(args.bev_ylim[0], args.bev_ylim[1], ny)
        ax_bev.contour(xs, ys, bev_map, levels=[0.5], colors="black", linewidths=0.8, zorder=1)

    # Collect barrier-like boxes within (a 2 m margin of) the BEV extent.
    construction_boxes = []
    for b in boxes:
        cat = str(b.get("category", ""))
        wc = b.get("world_coords", [0, 0, 0])
        if not (isinstance(wc, (list, tuple)) and len(wc) >= 2):
            continue
        cx, cy = float(wc[0]), float(wc[1])
        if not (args.bev_xlim[0] - 2 <= cx <= args.bev_xlim[1] + 2 and args.bev_ylim[0] - 2 <= cy <= args.bev_ylim[1] + 2):
            continue
        box_item = (cx, cy, float(b.get("w", 1.8)), float(b.get("l", 4.0)), float(b.get("yaw", 0.0)))
        if _is_barrier_like(cat):
            construction_boxes.append(box_item)

    # Draw each obstacle as a closed polygon outline.
    for cx, cy, w, l, yaw in construction_boxes:
        poly = _box_corners(cx, cy, w, l, yaw)
        poly = np.vstack([poly, poly[0:1]])
        ax_bev.plot(poly[:, 0], poly[:, 1], color="#cf7a6b", linewidth=0.9, alpha=0.95, zorder=3)

    # GT trajectory (green) under the predicted trajectory (blue).
    if gt_wps.shape[0] >= 2:
        ax_bev.plot(gt_wps[:, 0], gt_wps[:, 1], color="#34c759", linewidth=3.4, alpha=0.95, zorder=5)
    ax_bev.plot(pred_wps[:, 0], pred_wps[:, 1], color="#1f5cff", linewidth=1.8, alpha=0.98, zorder=6)

    # Ego marker at the origin of the paper-ego frame.
    ego_rect = patches.Rectangle(
        (-0.45, -0.45),
        0.9,
        0.9,
        linewidth=1.4,
        edgecolor="#34c759",
        facecolor="none",
        zorder=7,
    )
    ax_bev.add_patch(ego_rect)

    ax_bev.text(0.03, 0.98, "BEV", transform=ax_bev.transAxes, ha="left", va="top", fontsize=9, fontweight="bold")
    ax_bev.text(
        0.03,
        0.03,
        _title_cmd(cmd),
        transform=ax_bev.transAxes,
        ha="left",
        va="bottom",
        fontsize=9,
        fontweight="bold",
    )

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, bbox_inches="tight", facecolor="white", pad_inches=0.02)
    plt.close(fig)

    print(f"[saved] {out_png}")
    print(f"  construction boxes: {len(construction_boxes)}")
513
+
514
+
515
# Script entry point.
if __name__ == "__main__":
    main()
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (177 Bytes). View file
 
src/__pycache__/prompting.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
src/__pycache__/prompting.cpython-38.pyc ADDED
Binary file (10.6 kB). View file
 
src/audit/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (240 Bytes). View file
 
src/audit/__pycache__/audit_utils.cpython-310.pyc ADDED
Binary file (998 Bytes). View file
 
src/dataset/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (316 Bytes). View file
 
src/dataset/__pycache__/atlas_dataset.cpython-310.pyc ADDED
Binary file (40 kB). View file
 
src/dataset/__pycache__/atlas_dataset.cpython-38.pyc ADDED
Binary file (40.6 kB). View file
 
src/dataset/__pycache__/scene_sampler.cpython-310.pyc ADDED
Binary file (3.88 kB). View file
 
src/dataset/__pycache__/scene_sampler.cpython-38.pyc ADDED
Binary file (3.89 kB). View file
 
src/dataset/atlas_dataset.py ADDED
@@ -0,0 +1,1416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import math
4
+ import torch
5
+ import numpy as np
6
+ from torch.utils.data import Dataset
7
+ from PIL import Image
8
+ from typing import Dict, List, Optional, Tuple
9
+ from pathlib import Path
10
+ import torchvision.transforms as transforms
11
# Prefer torchvision's transforms v2 API when available; fall back to the
# classic transforms module otherwise.
try:
    import torchvision.transforms.v2 as _v2
    _HAS_V2 = hasattr(_v2, "ToImage") and hasattr(_v2, "ToDtype")
except (ImportError, AttributeError):
    _HAS_V2 = False
if _HAS_V2:
    transforms = _v2

from src.prompting import (
    PLANNING_TABLE3_MODES,
    build_prompt,
    rewrite_planning_prompt_for_table3,
)
from src.audit.audit_utils import audit_enabled, audit_check

# Default numbers of detection / map query tokens exposed to the LLM prompt.
NUM_DETECTION_QUERIES = 256
NUM_MAP_QUERIES = 256
# Planning state discretization: values are clipped to this range and mapped
# onto PLANNING_NUM_BINS uniform bins (see planning_state_to_bin).
PLANNING_STATE_RANGE = (-50.0, 50.0)
PLANNING_NUM_BINS = 1000

# Canonical high-level route commands (see normalize_route_command).
AVAILABLE_COMMANDS = ["turn left", "turn right", "go straight"]

# Z range aligned with StreamPETR point_cloud_range [-5, 3]
Z_MIN, Z_MAX = -5.0, 3.0

# nuScenes 10-class detection category mapping
NUSCENES_CLASSES = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

# Full nuScenes category names mapped to the 10 base detection classes
NUSCENES_CATEGORY_MAP = {
    # Base class names
    'car': 0, 'truck': 1, 'construction_vehicle': 2, 'bus': 3, 'trailer': 4,
    'barrier': 5, 'motorcycle': 6, 'bicycle': 7, 'pedestrian': 8, 'traffic_cone': 9,
    # Full nuScenes category names - vehicles
    'vehicle.car': 0, 'vehicle.truck': 1, 'vehicle.construction': 2,
    'vehicle.bus.bendy': 3, 'vehicle.bus.rigid': 3, 'vehicle.trailer': 4,
    'vehicle.motorcycle': 6, 'vehicle.bicycle': 7,
    # Full nuScenes category names - pedestrians
    'human.pedestrian.adult': 8, 'human.pedestrian.child': 8,
    'human.pedestrian.construction_worker': 8, 'human.pedestrian.police_officer': 8,
    'human.pedestrian.wheelchair': 8, 'human.pedestrian.stroller': 8,
    'human.pedestrian.personal_mobility': 8,
    # Full nuScenes category names - movable objects
    'movable_object.barrier': 5, 'movable_object.trafficcone': 9,
    'movable_object.traffic_cone': 9,
}
60
+
61
+
62
def nuscenes_to_paper_coords(x_nuscenes: float, y_nuscenes: float) -> Tuple[float, float]:
    """Convert nuScenes ego coordinates to the paper's convention.

    The paper's frame uses x_paper = -y_nuscenes (right positive) and
    y_paper = x_nuscenes (forward positive).
    """
    paper_x = -y_nuscenes
    paper_y = x_nuscenes
    return paper_x, paper_y
64
+
65
+
66
def planning_state_to_bin(
    value: float,
    min_val: float = PLANNING_STATE_RANGE[0],
    max_val: float = PLANNING_STATE_RANGE[1],
    num_bins: int = PLANNING_NUM_BINS,
) -> int:
    """Quantize a planning state value into one of ``num_bins`` uniform bins.

    The value is clipped to [min_val, max_val], normalized to [0, 1], and
    rounded (banker's rounding, matching Python ``round``) onto the bin grid.
    """
    clipped = float(np.clip(value, min_val, max_val))
    fraction = (clipped - min_val) / (max_val - min_val)
    raw_idx = int(round(fraction * (num_bins - 1)))
    # Defensive clamp; rounding should already keep the index in range.
    return max(0, min(raw_idx, num_bins - 1))
76
+
77
+
78
def normalize_route_command(command: object) -> Optional[str]:
    """Canonicalize a route command to 'turn left' / 'turn right' / 'go straight'.

    Returns None for non-string input; raises ValueError for an unrecognized
    string.
    """
    if not isinstance(command, str):
        return None
    aliases = {
        "turn left": "turn left",
        "left": "turn left",
        "turn right": "turn right",
        "right": "turn right",
        "go straight": "go straight",
        "straight": "go straight",
        "keep straight": "go straight",
        "forward": "go straight",
    }
    key = command.strip().lower()
    if key not in aliases:
        raise ValueError(f"Unsupported route command: {command!r}")
    return aliases[key]
95
+
96
+
97
# Camera channel order expected within each sample's image_paths.
# NOTE(review): front cameras first, then back — presumably matching the
# token-extraction pipeline ordering; confirm against the JSON generators
# before changing.
CAMERA_NAMES = [
    'CAM_FRONT',
    'CAM_FRONT_RIGHT',
    'CAM_FRONT_LEFT',
    'CAM_BACK',
    'CAM_BACK_LEFT',
    'CAM_BACK_RIGHT',
]


# Task labels accepted in an item's explicit "task" field.
VALID_TASK_TYPES = {"detection", "lane", "planning", "caption"}
108
+
109
+
110
def _normalize_task_type(task: object) -> Optional[str]:
    """Validate an explicit task label.

    Returns the lower-cased task name, None for non-string input, and raises
    ValueError for an unknown string.
    """
    if not isinstance(task, str):
        return None
    normalized = task.strip().lower()
    if normalized not in VALID_TASK_TYPES:
        raise ValueError(f"Unsupported task type: {task!r}")
    return normalized
117
+
118
+
119
+ def _infer_task_type_from_structure(item: dict) -> Optional[str]:
120
+ if not isinstance(item, dict):
121
+ return None
122
+
123
+ try:
124
+ num_map_queries = int(item.get("num_map_queries", -1))
125
+ except Exception:
126
+ num_map_queries = -1
127
+
128
+ if "ego_motion" in item or "gt_boxes_3d_per_timestep" in item:
129
+ return "planning"
130
+ if isinstance(item.get("camera"), str) and item.get("camera") and num_map_queries == 0:
131
+ return "caption"
132
+ if item.get("sensor") is not None or "openlane_lane_centerline" in item:
133
+ return "lane"
134
+ if "gt_boxes_3d" in item or num_map_queries == 0:
135
+ return "detection"
136
+ return None
137
+
138
+
139
def infer_task_type(item: dict) -> str:
    """Resolve a sample's task type.

    Prefers the explicit 'task' field; otherwise falls back to structural
    inference. Raises ValueError when neither yields a task.
    """
    explicit = _normalize_task_type(item.get("task"))
    if explicit:
        return explicit
    structural = _infer_task_type_from_structure(item)
    if structural:
        return structural
    sample_id = item.get("id", "<unknown>")
    raise ValueError(
        f"Unable to infer task type for sample {sample_id!r}. "
        "Please add an explicit 'task' field to the dataset JSON."
    )
151
+
152
+
153
+
154
def load_tokenizer(model_name: str = "lmsys/vicuna-7b-v1.5"):
    """Load the slow HF tokenizer for ``model_name``.

    On failure, retries once through the hf-mirror endpoint (best-effort
    fallback for restricted networks). Ensures a pad token exists by reusing
    EOS when absent.
    """
    from transformers import AutoTokenizer

    print(f"Loading tokenizer: {model_name}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=False,
            trust_remote_code=True,
        )
    except Exception as e:
        # Switch to the mirror endpoint and retry once; a second failure
        # propagates to the caller.
        print(f"Failed to load from {model_name}: {e}")
        os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=False,
            trust_remote_code=True,
        )

    if tokenizer.pad_token is None:
        # Downstream batching requires a pad token.
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer
178
+
179
+
180
+ class AtlasDataset(Dataset):
181
    def __init__(
        self,
        json_file: str,
        image_root: str,
        tokenizer,
        max_length: int = 4096,
        image_size: Tuple[int, int] = (800, 1600),
        use_nuscenes_calibration: bool = True,
        is_training: bool = True,
        num_detection_queries: int = NUM_DETECTION_QUERIES,
        num_map_queries: int = NUM_MAP_QUERIES,
        planning_table3_mode: str = "atlas_base",
        image_path_remap: Optional[str] = None,
        precomputed_det_tokens: Optional[str] = None,
        precomputed_map_tokens: Optional[str] = None,
    ):
        """Multi-task dataset (detection / lane / planning / caption) over one
        or more JSON files.

        Args:
            json_file: comma-separated list of JSON files, each a list of items.
            image_root: dataset root used to resolve relative image paths and
                locate nuScenes metadata.
            tokenizer: must already contain the "<query>" special token and a
                pad token (validated below).
            max_length: token-length cap applied to prompt(+answer).
            image_size: (H, W); a bare int is promoted to a square size.
            image_path_remap: optional "old=new" prefix substitution applied to
                every image path.
            precomputed_det_tokens / precomputed_map_tokens: optional dirs with
                cached query tokens.

        Raises:
            ValueError: unknown planning_table3_mode.
            RuntimeError: empty json_file, non-list JSON, or tokenizer missing
                the "<query>" / pad token.
        """
        self.json_file = json_file
        self.image_root = image_root
        # Optional "old=new" path prefix remapping for image_paths.
        self.image_path_remap = None
        if image_path_remap:
            old, new = image_path_remap.split("=", 1)
            self.image_path_remap = (old, new)
        self.precomputed_det_dir = precomputed_det_tokens
        self.precomputed_map_dir = precomputed_map_tokens
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_training = is_training
        self.num_detection_queries = int(num_detection_queries)
        self.num_map_queries = int(num_map_queries)
        if planning_table3_mode not in PLANNING_TABLE3_MODES:
            raise ValueError(
                f"Unsupported planning_table3_mode: {planning_table3_mode}. "
                f"Expected one of {PLANNING_TABLE3_MODES}."
            )
        self.planning_table3_mode = planning_table3_mode
        self.use_nuscenes_calibration = bool(use_nuscenes_calibration)

        # Accept a bare int as a square image size.
        if isinstance(image_size, int):
            self.image_size = (image_size, image_size)
        else:
            self.image_size = image_size

        # json_file may be a comma-separated list; concatenate all items.
        paths = [p.strip() for p in str(json_file).split(",") if p.strip()]
        if len(paths) == 0:
            raise RuntimeError("json_file is empty")
        self.data = []
        for p in paths:
            with open(p, "r", encoding="utf-8") as f:
                chunk = json.load(f)
            if not isinstance(chunk, list):
                raise RuntimeError(f"JSON must be a list: {p}")
            self.data.extend(chunk)

        if len(paths) == 1:
            print(f"Loaded {len(self.data)} samples from {paths[0]}")
        else:
            print(f"Loaded {len(self.data)} samples from {len(paths)} json files")
        print(f"Image size: {self.image_size[0]}x{self.image_size[1]} (HxW)")

        # Resolve each item's task type up front (raises on undecidable items)
        # and audit the planning schema.
        self._task_types = [infer_task_type(item) for item in self.data]
        from collections import Counter
        print(f"Task distribution: {dict(Counter(self._task_types))}")
        self._audit_planning_route_command_schema()

        # ImageNet-normalized tensor transform; v2 API when available.
        if _HAS_V2:
            self.image_transform = transforms.Compose([
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        else:
            self.image_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

        # Preprocessing configs for the two token extractors.
        self.streampetr_conf = {
            "H": 900, "W": 1600,
            "final_dim": (800, 1600),
            "resize_lim": (1.0, 1.2),
            "bot_pct_lim": (0.0, 0.0),
        }

        self.topomlp_conf = {
            "target_size": (1600, 800),
        }

        # Optional nuScenes calibration tables; per-item "sensor" fields are
        # the fallback when metadata is absent.
        self.calibration = None
        if self.use_nuscenes_calibration:
            self.calibration = self._load_nuscenes_calibration()
            if self.calibration is None:
                print("[WARN] nuScenes calibration metadata not found, will use per-item sensor field.")

        # The "<query>" placeholder token must exist in the tokenizer vocab.
        self.query_token = "<query>"
        vocab = tokenizer.get_vocab()
        if self.query_token not in vocab:
            raise RuntimeError(f"Tokenizer missing required special token: {self.query_token}")
        if tokenizer.pad_token_id is None:
            raise RuntimeError("tokenizer.pad_token_id is None")

        self.query_token_id = tokenizer.convert_tokens_to_ids(self.query_token)
        print(f"<query> token ID: {self.query_token_id}")
283
+
284
    def _load_nuscenes_calibration(self) -> Optional[Dict]:
        """Load nuScenes metadata tables needed for camera calibration.

        Searches image_root for a v1.0-{trainval,mini,test} folder and indexes
        sample_data / calibrated_sensor / ego_pose / sample records.

        Returns:
            Dict of lookup tables (by filename / token), plus a map from
            sample token to its LIDAR_TOP keyframe record, or None when any
            metadata file is missing or loading fails (best-effort by design).
        """
        try:
            nuscenes_root = Path(self.image_root)
            version_dir = None
            # First existing version folder wins.
            for v in ["v1.0-trainval", "v1.0-mini", "v1.0-test"]:
                if (nuscenes_root / v).exists():
                    version_dir = nuscenes_root / v
                    break
            if version_dir is None:
                print("nuScenes metadata folder not found under image_root")
                return None

            sample_data_file = version_dir / "sample_data.json"
            calibrated_sensor_file = version_dir / "calibrated_sensor.json"
            ego_pose_file = version_dir / "ego_pose.json"
            sample_file = version_dir / "sample.json"
            if (not sample_data_file.exists()) or (not calibrated_sensor_file.exists()) or (not ego_pose_file.exists()) or (not sample_file.exists()):
                print(f"nuScenes metadata missing under {version_dir}")
                return None

            with open(sample_data_file, "r") as f:
                sample_data = json.load(f)
            with open(calibrated_sensor_file, "r") as f:
                calibrated_sensor = json.load(f)
            with open(ego_pose_file, "r") as f:
                ego_pose = json.load(f)
            with open(sample_file, "r") as f:
                sample = json.load(f)

            # Index records for O(1) lookup during __getitem__.
            sample_data_by_filename = {rec["filename"]: rec for rec in sample_data if "filename" in rec}
            calibrated_sensor_by_token = {rec["token"]: rec for rec in calibrated_sensor if "token" in rec}
            ego_pose_by_token = {rec["token"]: rec for rec in ego_pose if "token" in rec}
            sample_by_token = {rec["token"]: rec for rec in sample if "token" in rec}

            # One LIDAR_TOP keyframe per sample token (first occurrence kept).
            lidar_sd_by_sample_token: Dict[str, Dict] = {}
            for rec in sample_data:
                if not isinstance(rec, dict):
                    continue
                fn = str(rec.get("filename", "")).replace("\\", "/")
                if "/LIDAR_TOP/" not in fn:
                    continue
                # Only keyframe sweeps stored under samples/ (not sweeps/).
                if not fn.startswith("samples/"):
                    continue
                if not bool(rec.get("is_key_frame", False)):
                    continue
                st = rec.get("sample_token", None)
                if st is None:
                    continue
                lidar_sd_by_sample_token.setdefault(str(st), rec)

            print(f"Loaded nuScenes metadata from {version_dir.name}:")
            print(f"  sample_data: {len(sample_data_by_filename)}")
            print(f"  calibrated_sensor: {len(calibrated_sensor_by_token)}")
            print(f"  ego_pose: {len(ego_pose_by_token)}")
            print(f"  sample: {len(sample_by_token)}")
            print(f"  lidar_keyframes: {len(lidar_sd_by_sample_token)}")

            return {
                "sample_data_by_filename": sample_data_by_filename,
                "calibrated_sensor_by_token": calibrated_sensor_by_token,
                "ego_pose_by_token": ego_pose_by_token,
                "sample_by_token": sample_by_token,
                "lidar_sd_by_sample_token": lidar_sd_by_sample_token,
            }
        except Exception as e:
            # Deliberate best-effort: the caller falls back to per-item
            # sensor fields when calibration is unavailable.
            print(f"Failed to load nuScenes calibration: {e}")
            return None
351
+
352
    def __len__(self) -> int:
        """Total number of samples loaded across all JSON files."""
        return len(self.data)
354
+
355
    def _audit_planning_route_command_schema(self) -> None:
        """Print a coverage audit of route_command fields on planning samples.

        Counts how many planning items resolve a top-level route command, how
        many still carry the legacy ``ego_motion.command`` field, and the
        per-command distribution. Emits warnings only; never raises.
        """
        planning_indices = [i for i, task in enumerate(self._task_types) if task == "planning"]
        if not planning_indices:
            return

        total = len(planning_indices)
        top_level_count = 0
        legacy_ego_motion_command = 0
        route_command_dist: Dict[str, int] = {}

        for idx in planning_indices:
            item = self.data[idx]
            try:
                command = self._resolve_route_command(item)
            except Exception as exc:
                # Invalid commands are reported but do not abort the audit.
                sample_id = item.get("id", f"planning_idx_{idx}")
                print(f"[WARN] invalid planning route_command for sample {sample_id}: {exc}")
                command = None
            if command is not None:
                top_level_count += 1
                route_command_dist[command] = route_command_dist.get(command, 0) + 1
            # Legacy schema stored the command under ego_motion.command.
            ego = item.get("ego_motion")
            if isinstance(ego, dict) and "command" in ego:
                legacy_ego_motion_command += 1

        print(
            "[Planning route_command audit] "
            f"mode={self.planning_table3_mode} "
            f"top_level_coverage={top_level_count}/{total} "
            f"legacy_ego_motion_command={legacy_ego_motion_command}/{total} "
            f"distribution={route_command_dist}"
        )
        # High-level planning modes require every item to carry a command.
        if self.planning_table3_mode != "atlas_base" and top_level_count < total:
            print(
                "[WARN] planning high-level mode requested but top-level "
                "route_command coverage is incomplete."
            )
        if legacy_ego_motion_command > 0:
            print(
                "[WARN] legacy planning schema detected: ego_motion.command "
                "is still present in the loaded JSON."
            )
397
+
398
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
399
+ item = self.data[idx]
400
+ task_type = self._task_types[idx]
401
+
402
+ img_paths = item["image_paths"]
403
+ if self.image_path_remap:
404
+ old, new = self.image_path_remap
405
+ img_paths = [p.replace(old, new) for p in img_paths]
406
+
407
+ cam_out = self._load_images_with_cameras(
408
+ img_paths,
409
+ item=item,
410
+ )
411
+ pixel_values = cam_out["pixel_values"]
412
+ intrinsics = cam_out["intrinsics"]
413
+ extrinsics = cam_out["extrinsics"]
414
+ lidar2img = cam_out["lidar2img"]
415
+ ego_pose = cam_out.get("ego_pose")
416
+ ego_pose_inv = cam_out.get("ego_pose_inv")
417
+ timestamp = cam_out.get("timestamp")
418
+
419
+ prompt_raw, answer_raw = self._extract_conversation(item)
420
+ if task_type == "planning":
421
+ prompt_raw = self._rewrite_planning_prompt(prompt_raw, item)
422
+ expected_num_queries = self._infer_expected_query_count(prompt_raw, item=item)
423
+ prompt_text = self._expand_query_placeholders(
424
+ prompt_raw,
425
+ expected_num_queries,
426
+ item=item,
427
+ )
428
+
429
+ prompt_str = build_prompt(prompt_text, mode="train" if self.is_training else "infer")
430
+ answer_str = f" {answer_raw}"
431
+
432
+ prompt_ids = self.tokenizer(prompt_str, add_special_tokens=False)["input_ids"]
433
+ answer_ids = self.tokenizer(answer_str, add_special_tokens=False)["input_ids"]
434
+
435
+ bos = [self.tokenizer.bos_token_id] if self.tokenizer.bos_token_id is not None else []
436
+ eos = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []
437
+
438
+ if self.is_training:
439
+ input_ids_full = bos + prompt_ids + answer_ids + eos
440
+ else:
441
+ input_ids_full = bos + prompt_ids
442
+
443
+ input_ids = input_ids_full
444
+ if len(input_ids) > self.max_length:
445
+ input_ids = input_ids[: self.max_length]
446
+ attention_mask = [1] * len(input_ids)
447
+
448
+ num_query_tokens = sum(1 for t in input_ids if t == self.query_token_id)
449
+ if num_query_tokens != expected_num_queries:
450
+ raise ValueError(f"<query> mismatch: expected {expected_num_queries}, got {num_query_tokens}")
451
+
452
+ if self.is_training:
453
+ labels = input_ids.copy()
454
+ prompt_len = len(bos) + len(prompt_ids)
455
+ labels[:prompt_len] = [-100] * prompt_len
456
+
457
+ first_nonmasked_index = -1
458
+ for i, t in enumerate(labels):
459
+ if t != -100:
460
+ first_nonmasked_index = i
461
+ break
462
+ labels_nonmasked_count = sum(1 for t in labels if t != -100)
463
+
464
+ assert len(labels) == len(input_ids), "labels/input_ids length mismatch"
465
+ if labels_nonmasked_count > 0:
466
+ assert first_nonmasked_index == prompt_len, (
467
+ f"first_nonmasked_index={first_nonmasked_index} != prompt_len={prompt_len}"
468
+ )
469
+ assert labels_nonmasked_count > 0, (
470
+ f"all labels are -100: len_full={len(input_ids_full)} len_trunc={len(input_ids)} max_length={self.max_length}"
471
+ )
472
+ if audit_enabled():
473
+ if not hasattr(self, "_audit_trunc_total"):
474
+ self._audit_trunc_total = 0
475
+ self._audit_trunc_hits = 0
476
+ self._audit_trunc_total += 1
477
+ truncated = int(len(input_ids_full) > self.max_length)
478
+ self._audit_trunc_hits += truncated
479
+ trunc_rate = float(self._audit_trunc_hits) / float(self._audit_trunc_total)
480
+ min_ans = int(os.getenv("ATLAS_MIN_ANSWER_TOKENS", "16"))
481
+ ok = (labels_nonmasked_count >= min_ans)
482
+ audit_check(
483
+ "A6",
484
+ ok,
485
+ once=False,
486
+ truncated=truncated,
487
+ trunc_rate=trunc_rate,
488
+ labels_nonmasked_count=labels_nonmasked_count,
489
+ min_answer_tokens=min_ans,
490
+ )
491
+ else:
492
+ labels = [-100] * len(input_ids)
493
+ prompt_len = len(input_ids)
494
+ first_nonmasked_index = -1
495
+ labels_nonmasked_count = 0
496
+ scene_id = self._get_scene_id(item)
497
+ sample_id = str(item.get("id", idx))
498
+
499
+ gt_boxes, gt_labels = self._load_gt_boxes(item)
500
+
501
+ result = {
502
+ "pixel_values": pixel_values,
503
+ "pixel_values_det": cam_out["pixel_values_det"],
504
+ "pixel_values_map": cam_out["pixel_values_map"],
505
+ "intrinsics": intrinsics,
506
+ "intrinsics_det": cam_out["intrinsics_det"],
507
+ "intrinsics_map": cam_out["intrinsics_map"],
508
+ "extrinsics": extrinsics,
509
+ "lidar2img": lidar2img,
510
+ "lidar2img_det": cam_out["lidar2img_det"],
511
+ "lidar2img_map": cam_out["lidar2img_map"],
512
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
513
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
514
+ "labels": torch.tensor(labels, dtype=torch.long),
515
+ "scene_id": scene_id,
516
+ "sample_id": sample_id,
517
+ "dataset_idx": torch.tensor(idx, dtype=torch.long),
518
+ "task_type": task_type,
519
+ "audit_prompt_len": torch.tensor(prompt_len, dtype=torch.long),
520
+ "audit_answer_len": torch.tensor(len(answer_ids), dtype=torch.long),
521
+ "audit_labels_nonmasked_count": torch.tensor(labels_nonmasked_count, dtype=torch.long),
522
+ "audit_first_nonmasked_index": torch.tensor(first_nonmasked_index, dtype=torch.long),
523
+ "audit_num_query_tokens_in_input_ids": torch.tensor(num_query_tokens, dtype=torch.long),
524
+ "audit_expected_num_queries": torch.tensor(expected_num_queries, dtype=torch.long),
525
+ "audit_truncated": torch.tensor(int(len(input_ids_full) > self.max_length), dtype=torch.long),
526
+ }
527
+
528
+ if ego_pose is not None:
529
+ result["ego_pose"] = ego_pose
530
+ if ego_pose_inv is not None:
531
+ result["ego_pose_inv"] = ego_pose_inv
532
+ if timestamp is not None:
533
+ result["timestamp"] = timestamp
534
+ try:
535
+ ego_motion = self._get_ego_motion_data(item)
536
+ if isinstance(ego_motion, dict) and "velocity" in ego_motion:
537
+ result["velocity"] = torch.tensor(ego_motion["velocity"], dtype=torch.float32)
538
+ except Exception:
539
+ pass
540
+
541
+ if gt_boxes is not None:
542
+ result["gt_boxes"] = gt_boxes
543
+ result["gt_labels"] = gt_labels
544
+
545
+ if self.precomputed_det_dir:
546
+ pt = self._load_precomputed(self.precomputed_det_dir, item)
547
+ if pt is not None:
548
+ result["precomputed_det"] = pt["detection"]
549
+ result["precomputed_det_ref"] = pt["detection_ref_points"]
550
+
551
+ if self.precomputed_map_dir:
552
+ mpt = self._load_precomputed(self.precomputed_map_dir, item)
553
+ if mpt is not None:
554
+ result["precomputed_map"] = mpt
555
+
556
+ if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
557
+ max_samples = int(os.getenv("ATLAS_AUDIT_MAX_SAMPLES", "1"))
558
+ if idx < max_samples:
559
+ print(
560
+ "[ATLAS_AUDIT][A1/A3] "
561
+ f"idx={idx} "
562
+ f"prompt_len={prompt_len} answer_len={len(answer_ids)} "
563
+ f"first_nonmasked_index={first_nonmasked_index} labels_nonmasked_count={labels_nonmasked_count} "
564
+ f"num_query_tokens_in_input_ids={num_query_tokens} expected_num_queries={expected_num_queries} "
565
+ f"seq_len={len(input_ids)} truncated={int(len(input_ids_full) > self.max_length)}"
566
+ )
567
+ return result
568
+
569
+ def _resolve_category_id(self, label) -> int:
570
+ """将类别标签转换为类别 ID,支持整数、浮点数和字符串类别名"""
571
+ if isinstance(label, (int, float)):
572
+ return int(label)
573
+ if isinstance(label, str):
574
+ label_lower = label.lower().strip()
575
+ if label_lower in NUSCENES_CATEGORY_MAP:
576
+ return NUSCENES_CATEGORY_MAP[label_lower]
577
+ # 对 human.pedestrian.* 子类使用前缀匹配
578
+ if label_lower.startswith('human.pedestrian.'):
579
+ return 8 # pedestrian
580
+ # 对 vehicle.bus.* 子类使用前缀匹配
581
+ if label_lower.startswith('vehicle.bus.'):
582
+ return 3 # bus
583
+ return 0
584
+
585
+ def _load_gt_boxes(self, item: Dict) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
586
+ annotations = item.get("annotations", item.get("gt_boxes_3d", None))
587
+
588
+ if annotations is None:
589
+ return None, None
590
+
591
+ if not annotations:
592
+ return torch.zeros(0, 7), torch.zeros(0, dtype=torch.long)
593
+
594
+ boxes_list = []
595
+ labels_list = []
596
+ bev_range = 51.2
597
+
598
+ for ann in annotations:
599
+ if isinstance(ann, dict):
600
+ if "translation" in ann:
601
+ pos_x, pos_y, pos_z = ann["translation"][:3]
602
+ size_w, size_l, size_h = ann.get("size", [1, 1, 1])[:3]
603
+ yaw = ann.get("rotation", 0)
604
+ if isinstance(yaw, (list, tuple)):
605
+ if len(yaw) == 1:
606
+ yaw = float(yaw[0])
607
+ elif len(yaw) == 4:
608
+ # Quaternion (w, x, y, z) -> yaw
609
+ qw, qx, qy, qz = [float(v) for v in yaw]
610
+ t0 = 2.0 * (qw * qz + qx * qy)
611
+ t1 = 1.0 - 2.0 * (qy * qy + qz * qz)
612
+ yaw = math.atan2(t0, t1)
613
+ else:
614
+ yaw = 0.0
615
+ # 支持多种类别字段: category_id, category, category_name
616
+ label = ann.get("category_id", ann.get("category", ann.get("category_name", 0)))
617
+ x, y, z = pos_x, pos_y, pos_z
618
+ w, l, h = size_w, size_l, size_h
619
+
620
+ x_norm = (x + bev_range) / (2 * bev_range)
621
+ y_norm = (y + bev_range) / (2 * bev_range)
622
+ z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
623
+
624
+ boxes_list.append([x_norm, y_norm, z_norm, w, l, h, yaw])
625
+ labels_list.append(self._resolve_category_id(label))
626
+
627
+ elif "box" in ann:
628
+ box = ann["box"]
629
+ x, y, z = box[:3]
630
+ w, l, h = box[3:6] if len(box) >= 6 else (1, 1, 1)
631
+ yaw = box[6] if len(box) >= 7 else 0
632
+ label = ann.get("category_id", ann.get("label", ann.get("category_name", 0)))
633
+
634
+ x_norm = (x + bev_range) / (2 * bev_range)
635
+ y_norm = (y + bev_range) / (2 * bev_range)
636
+ z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
637
+
638
+ boxes_list.append([x_norm, y_norm, z_norm, w, l, h, yaw])
639
+ labels_list.append(self._resolve_category_id(label))
640
+
641
+ elif "world_coords" in ann:
642
+ wc = ann["world_coords"]
643
+ x, y, z = wc[0], wc[1], wc[2] if len(wc) > 2 else 0.0
644
+ w = ann.get("w", 1.0)
645
+ l = ann.get("l", 1.0)
646
+ h = ann.get("h", 1.0)
647
+ yaw = ann.get("yaw", 0.0)
648
+ label = ann.get("category_id", ann.get("category", ann.get("category_name", 0)))
649
+
650
+ x_norm = (x + bev_range) / (2 * bev_range)
651
+ y_norm = (y + bev_range) / (2 * bev_range)
652
+ z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
653
+
654
+ boxes_list.append([x_norm, y_norm, z_norm, w, l, h, yaw])
655
+ labels_list.append(self._resolve_category_id(label))
656
+
657
+ elif isinstance(ann, (list, tuple)) and len(ann) >= 3:
658
+ x, y, z = ann[:3]
659
+ w, l, h = ann[3:6] if len(ann) >= 6 else (1, 1, 1)
660
+ yaw = ann[6] if len(ann) >= 7 else 0
661
+
662
+ x_norm = (x + bev_range) / (2 * bev_range)
663
+ y_norm = (y + bev_range) / (2 * bev_range)
664
+ z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
665
+
666
+ boxes_list.append([x_norm, y_norm, z_norm, w, l, h, yaw])
667
+ labels_list.append(0)
668
+
669
+ if not boxes_list:
670
+ return torch.zeros(0, 7), torch.zeros(0, dtype=torch.long)
671
+
672
+ if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
673
+ if not hasattr(self, "_audit_gt_calls"):
674
+ self._audit_gt_calls = 0
675
+ max_calls = int(os.getenv("ATLAS_AUDIT_MAX_GT", "1"))
676
+ if self._audit_gt_calls < max_calls:
677
+ yaws = [float(b[6]) for b in boxes_list if len(b) >= 7]
678
+ if yaws:
679
+ y_min = float(min(yaws))
680
+ y_max = float(max(yaws))
681
+ y_abs = float(max(abs(y) for y in yaws))
682
+ print(f"[ATLAS_AUDIT][E2/E3] gt_yaw_min={y_min:.3e} gt_yaw_max={y_max:.3e} gt_yaw_absmax={y_abs:.3e}")
683
+ if y_abs > 10.0:
684
+ print("[ATLAS_AUDIT][E2] yaw_absmax>10 (possible degrees instead of radians)")
685
+ self._audit_gt_calls += 1
686
+
687
+ gt_boxes = torch.tensor(boxes_list, dtype=torch.float32)
688
+ gt_labels = torch.tensor(labels_list, dtype=torch.long)
689
+
690
+ return gt_boxes, gt_labels
691
+
692
+ def _rewrite_planning_prompt(self, prompt_text: str, item: Dict) -> str:
693
+ ego_motion = item.get("ego_motion", {})
694
+ if not isinstance(ego_motion, dict):
695
+ ego_motion = {}
696
+
697
+ route_command = self._resolve_route_command(item)
698
+ velocity = ego_motion.get("velocity")
699
+ acceleration = ego_motion.get("acceleration")
700
+
701
+ velocity_bins = None
702
+ acceleration_bins = None
703
+ if velocity is not None:
704
+ if not isinstance(velocity, (list, tuple)) or len(velocity) < 2:
705
+ raise RuntimeError("planning ego_motion.velocity must be a 2D vector")
706
+ velocity_bins = (
707
+ planning_state_to_bin(float(velocity[0])),
708
+ planning_state_to_bin(float(velocity[1])),
709
+ )
710
+ if acceleration is not None:
711
+ if not isinstance(acceleration, (list, tuple)) or len(acceleration) < 2:
712
+ raise RuntimeError("planning ego_motion.acceleration must be a 2D vector")
713
+ acceleration_bins = (
714
+ planning_state_to_bin(float(acceleration[0])),
715
+ planning_state_to_bin(float(acceleration[1])),
716
+ )
717
+
718
+ return rewrite_planning_prompt_for_table3(
719
+ prompt_text,
720
+ mode=self.planning_table3_mode,
721
+ command=route_command,
722
+ velocity_bins=velocity_bins,
723
+ acceleration_bins=acceleration_bins,
724
+ )
725
+
726
+ def _resolve_route_command(self, item: Dict) -> Optional[str]:
727
+ candidates = [
728
+ item.get("route_command"),
729
+ item.get("nav_command"),
730
+ item.get("high_level_command"),
731
+ item.get("navigation_command"),
732
+ ]
733
+ meta = item.get("meta_data")
734
+ if isinstance(meta, dict):
735
+ candidates.extend([
736
+ meta.get("route_command"),
737
+ meta.get("nav_command"),
738
+ meta.get("high_level_command"),
739
+ meta.get("navigation_command"),
740
+ ])
741
+ for candidate in candidates:
742
+ normalized = normalize_route_command(candidate)
743
+ if normalized is not None:
744
+ return normalized
745
+ return None
746
+
747
+ def _extract_conversation(self, item: Dict) -> Tuple[str, str]:
748
+ conv = item.get("conversations", None)
749
+ if not conv or not isinstance(conv, list):
750
+ raise RuntimeError("missing conversations field")
751
+ prompt = None
752
+ answer = None
753
+ for turn in conv:
754
+ if not isinstance(turn, dict):
755
+ continue
756
+ if turn.get("from") in ("human", "user") and prompt is None:
757
+ prompt = turn.get("value")
758
+ if turn.get("from") in ("gpt", "assistant") and answer is None:
759
+ answer = turn.get("value")
760
+ if prompt is None or answer is None:
761
+ raise RuntimeError("conversations missing human/gpt pair")
762
+
763
+ return str(prompt), str(answer)
764
+
765
+ def _infer_expected_query_count(self, prompt_text: str, item: Optional[Dict] = None) -> int:
766
+ if isinstance(item, dict):
767
+ if "num_map_queries" in item:
768
+ try:
769
+ num_map = int(item.get("num_map_queries", 0))
770
+ if num_map < 0:
771
+ num_map = 0
772
+ return self.num_detection_queries + num_map
773
+ except Exception:
774
+ pass
775
+ use_map = item.get("use_map_queries", None)
776
+ if use_map is not None:
777
+ return self.num_detection_queries + (self.num_map_queries if bool(use_map) else 0)
778
+ p = prompt_text.lower()
779
+ if "map query" in p:
780
+ return self.num_detection_queries + self.num_map_queries
781
+ return self.num_detection_queries
782
+
783
    def _expand_query_placeholders(
        self,
        prompt_text: str,
        expected_num_queries: int,
        item: Optional[Dict] = None,
    ) -> str:
        """Expand <query> placeholder(s) so the prompt carries exactly
        `expected_num_queries` query tokens.

        Accepted input shapes:
          * 1 placeholder  -> replaced by all query slots;
          * 2 placeholders -> first expands to detection slots, second to map
            slots (used by the latest planning prompts);
          * already fully expanded (count == expected) -> only whitespace
            normalization is applied.

        Raises:
            ValueError: for any other placeholder count.
        """
        cnt = prompt_text.count(self.query_token)
        if cnt == 1:
            placeholder = " ".join([self.query_token] * expected_num_queries)
            out = prompt_text.replace(self.query_token, f" {placeholder} ")
            # Collapse the doubled spaces introduced around the replacement.
            return " ".join(out.split())
        if cnt == 2 and expected_num_queries > self.num_detection_queries:
            num_map = expected_num_queries - self.num_detection_queries
            # Prefer the per-item map-query count when the item provides one.
            if isinstance(item, dict) and "num_map_queries" in item:
                try:
                    num_map = int(item.get("num_map_queries", num_map))
                except Exception:
                    pass
            # Clamp so the detection share can never go negative.
            num_map = max(0, min(num_map, expected_num_queries))
            num_det = expected_num_queries - num_map
            parts = prompt_text.split(self.query_token)
            if len(parts) == 3 and num_det > 0 and num_map > 0:
                # Latest planning prompts use two placeholders: the first stands
                # for detection query slots and the second for map query slots.
                det_placeholder = " ".join([self.query_token] * num_det)
                map_placeholder = " ".join([self.query_token] * num_map)
                out = (
                    f"{parts[0]} {det_placeholder} "
                    f"{parts[1]} {map_placeholder} "
                    f"{parts[2]}"
                )
                return " ".join(out.split())
        if cnt == expected_num_queries:
            # Already expanded upstream; just normalize whitespace.
            return " ".join(prompt_text.split())
        if expected_num_queries > self.num_detection_queries:
            raise ValueError(
                f"<query> count mismatch: got {cnt}, expected 1, 2, or {expected_num_queries}"
            )
        raise ValueError(f"<query> count mismatch: got {cnt}, expected 1 or {expected_num_queries}")
822
+
823
+ def _load_precomputed(self, directory: str, item: Dict) -> Optional[Dict]:
824
+ item_id = str(item.get("id", ""))
825
+ pt_path = os.path.join(directory, f"{item_id}.pt")
826
+ if os.path.isfile(pt_path):
827
+ try:
828
+ return torch.load(pt_path, map_location="cpu")
829
+ except Exception:
830
+ return None
831
+ meta = item.get("meta_data", {})
832
+ if isinstance(meta, dict):
833
+ source_id = meta.get("source_id")
834
+ if source_id:
835
+ pt_path2 = os.path.join(directory, f"{source_id}.pt")
836
+ if os.path.isfile(pt_path2):
837
+ try:
838
+ return torch.load(pt_path2, map_location="cpu")
839
+ except Exception:
840
+ pass
841
+ return None
842
+
843
+ def _get_scene_id(self, item: Dict) -> str:
844
+ if "segment_id" in item and item["segment_id"]:
845
+ return str(item["segment_id"])
846
+ try:
847
+ p0 = item.get("image_paths", [None])[0]
848
+ if not p0:
849
+ return "unknown"
850
+ fname = os.path.basename(p0)
851
+ parts = fname.split("__")
852
+ if len(parts) > 1:
853
+ return parts[0]
854
+ return "unknown"
855
+ except Exception:
856
+ return "unknown"
857
+
858
+ def _get_timestamp(self, item: Dict) -> int:
859
+ if "timestamp" in item and item["timestamp"]:
860
+ return int(item["timestamp"])
861
+ try:
862
+ p0 = item.get("image_paths", [None])[0]
863
+ fname = os.path.basename(p0)
864
+ parts = fname.replace(".jpg", "").split("__")
865
+ return int(parts[-1]) if parts else 0
866
+ except Exception:
867
+ return 0
868
+
869
+ def get_scene_groups(self) -> Dict[str, List[int]]:
870
+ groups: Dict[str, List[int]] = {}
871
+ for idx, item in enumerate(self.data):
872
+ sid = self._get_scene_id(item)
873
+ groups.setdefault(sid, []).append(idx)
874
+ for sid in groups:
875
+ groups[sid].sort(key=lambda i: self._get_timestamp(self.data[i]))
876
+ return groups
877
+
878
+ def _get_ego_motion_data(self, item: Dict) -> Dict:
879
+ route_cmd = self._resolve_route_command(item) or "go straight"
880
+
881
+ if 'ego_motion' in item:
882
+ ego = item['ego_motion']
883
+
884
+ # JSON already stores paper-frame values (gen_atlas_planning_qa.py
885
+ # applies _nuscenes_to_paper_xy before writing). Do NOT transform again.
886
+ vel = ego.get('velocity', [0.0, 0.0])
887
+ acc = ego.get('acceleration', [0.0, 0.0])
888
+ vx, vy = float(vel[0]), float(vel[1])
889
+ ax, ay = float(acc[0]), float(acc[1])
890
+
891
+ if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
892
+ if not hasattr(self, "_audit_ego_calls"):
893
+ self._audit_ego_calls = 0
894
+ max_calls = int(os.getenv("ATLAS_AUDIT_MAX_EGO", "1"))
895
+ if self._audit_ego_calls < max_calls:
896
+ print(f"[ATLAS_AUDIT][ego] vel=({vx:.3e},{vy:.3e}) acc=({ax:.3e},{ay:.3e}) [paper-frame, no transform]")
897
+ self._audit_ego_calls += 1
898
+
899
+ if 'waypoints' in ego:
900
+ waypoints_raw = ego['waypoints']
901
+ waypoints = [[float(wp[0]), float(wp[1])] for wp in waypoints_raw]
902
+ else:
903
+ waypoints = self._generate_waypoints(route_cmd, [vx, vy], [ax, ay])
904
+
905
+ return {
906
+ 'velocity': [vx, vy],
907
+ 'acceleration': [ax, ay],
908
+ 'waypoints': waypoints,
909
+ }
910
+
911
+ if not hasattr(self, "_warned_missing_ego_motion"):
912
+ self._warned_missing_ego_motion = True
913
+ print("[WARN] ego_motion missing, using stationary default (velocity=[0,0])")
914
+
915
+ velocity = [0.0, 0.0]
916
+ acceleration = [0.0, 0.0]
917
+ waypoints = self._generate_waypoints(route_cmd, velocity, acceleration)
918
+
919
+ return {
920
+ 'velocity': velocity,
921
+ 'acceleration': acceleration,
922
+ 'waypoints': waypoints,
923
+ }
924
+
925
+ def _generate_waypoints(
926
+ self,
927
+ command: str,
928
+ velocity: List[float] = None,
929
+ acceleration: List[float] = None,
930
+ ) -> List[List[float]]:
931
+ if velocity is None:
932
+ velocity = [0.0, 5.0]
933
+ if acceleration is None:
934
+ acceleration = [0.0, 0.0]
935
+
936
+ vx, vy = velocity
937
+ ax, ay = acceleration
938
+ waypoints = []
939
+
940
+ for i in range(1, 7):
941
+ t = i * 0.5
942
+
943
+ x = vx * t + 0.5 * ax * t * t
944
+ y = vy * t + 0.5 * ay * t * t
945
+
946
+ if command == "turn left":
947
+ curvature = -0.3 * t * t
948
+ x += curvature
949
+ elif command == "turn right":
950
+ curvature = 0.3 * t * t
951
+ x += curvature
952
+
953
+ waypoints.append([round(x, 2), round(y, 2)])
954
+
955
+ return waypoints
956
+
957
+ def _preprocess_streampetr(self, img_pil: Image.Image, K: np.ndarray):
958
+ W, H = img_pil.size
959
+ conf = self.streampetr_conf
960
+ fH, fW = conf["final_dim"]
961
+ resize = max(fH / H, fW / W)
962
+ rW, rH = int(W * resize), int(H * resize)
963
+ crop_h = int(rH) - fH
964
+ crop_w = max(0, rW - fW) // 2
965
+
966
+ if resize != 1.0:
967
+ img_pil = img_pil.resize((rW, rH), Image.BILINEAR)
968
+ img_pil = img_pil.crop((crop_w, crop_h, crop_w + fW, crop_h + fH))
969
+
970
+ K_new = K.copy()
971
+ K_new[0, 0] *= resize
972
+ K_new[1, 1] *= resize
973
+ K_new[0, 2] = K_new[0, 2] * resize - crop_w
974
+ K_new[1, 2] = K_new[1, 2] * resize - crop_h
975
+ return img_pil, K_new
976
+
977
+ def _preprocess_topomlp(self, img_pil: Image.Image, K: np.ndarray):
978
+ W, H = img_pil.size
979
+ tW, tH = self.topomlp_conf["target_size"]
980
+ w_scale = tW / W
981
+ h_scale = tH / H
982
+
983
+ img_pil = img_pil.resize((tW, tH), Image.BILINEAR)
984
+
985
+ K_new = K.copy()
986
+ K_new[0, 0] *= w_scale
987
+ K_new[0, 2] *= w_scale
988
+ K_new[1, 1] *= h_scale
989
+ K_new[1, 2] *= h_scale
990
+ return img_pil, K_new
991
+
992
    def _load_images_with_cameras(
        self,
        image_paths: List[str],
        item: Optional[Dict] = None,
    ) -> Dict:
        """Load all camera images for one sample and build per-camera geometry.

        For each path: identify the camera by filename, open the image,
        recover intrinsics/extrinsics (from the nuScenes calibration tables
        when a sample_data record matches, else from item["sensor"]), produce
        the StreamPETR ("det") and TopoMLP ("map") preprocessed variants, and
        compute lidar2img projection matrices. An ego pose is extracted with
        the following priority: lidar-frame pose from the calibration tables,
        then the camera's ego_pose record, then item["pose"]. Timestamps are
        converted from microseconds to seconds.

        Returns a dict of stacked tensors: pixel_values_det/map,
        intrinsics_det/map, extrinsics, lidar2img_det/map, plus ego_pose /
        ego_pose_inv / timestamp (each may be None); "pixel_values",
        "intrinsics" and "lidar2img" alias the det variants.

        Raises:
            RuntimeError: when an image cannot be opened, or when no camera
                parameters can be recovered for a path.
        """
        images_det = []
        images_map = []
        intrinsics_det_list = []
        intrinsics_map_list = []
        extrinsics_list = []
        lidar2img_det_list = []
        lidar2img_map_list = []
        ego_pose_out = None
        ego_pose_inv_out = None
        timestamp_out = None

        for i, img_path in enumerate(image_paths):
            # Positional fallback first; an explicit camera-name substring in
            # the path wins (longest names checked first so e.g.
            # CAM_FRONT_LEFT is not shadowed by CAM_FRONT).
            camera_name = CAMERA_NAMES[i] if i < len(CAMERA_NAMES) else f"CAM_{i}"
            for cam in sorted(CAMERA_NAMES, key=len, reverse=True):
                if cam in img_path:
                    camera_name = cam
                    break

            def _normalize_path(p: str) -> str:
                return str(p).replace("\\", "/").lstrip("./")

            def _lookup_sample_data(path: str) -> Optional[Dict]:
                # Match the path against the calibration index under several
                # spellings (raw, normalized, and suffix from samples|sweeps/).
                if self.calibration is None:
                    return None
                by_name = self.calibration["sample_data_by_filename"]
                candidates = []
                raw = str(path)
                norm = _normalize_path(raw)
                candidates.append(raw)
                candidates.append(norm)
                for key in ("samples/", "sweeps/"):
                    if key in norm:
                        candidates.append(norm[norm.index(key):])
                for cand in candidates:
                    rec = by_name.get(cand, None)
                    if rec is not None:
                        return rec
                return None

            # NOTE(review): this first assignment is dead — both branches of
            # the if/else below immediately overwrite full_path.
            full_path = os.path.normpath(os.path.join(self.image_root, img_path))
            if not os.path.isabs(img_path):
                full_path = os.path.normpath(os.path.join(self.image_root, img_path))
            else:
                full_path = img_path
            try:
                img = Image.open(full_path).convert("RGB")
            except Exception as e:
                # Enrich the error with identifying context before re-raising.
                sample_id = "unknown"
                task_type = "unknown"
                if isinstance(item, dict):
                    sample_id = str(item.get("id", item.get("sample_id", "unknown")))
                    try:
                        task_type = infer_task_type(item)
                    except Exception:
                        task_type = str(item.get("task_type", item.get("task", "unknown")))
                raise RuntimeError(
                    "failed to load image for AtlasDataset: "
                    f"sample_id={sample_id} task_type={task_type} "
                    f"camera_name={camera_name} image_path={img_path} "
                    f"full_path={full_path} image_root={self.image_root}"
                ) from e

            K = None
            E = None
            ep = None
            sd = None

            # Preferred source: nuScenes calibration tables keyed by filename.
            if self.calibration is not None:
                sd = _lookup_sample_data(img_path)
                if sd is not None:
                    cs_token = sd.get("calibrated_sensor_token", None)
                    if cs_token is None:
                        raise RuntimeError(f"sample_data missing calibrated_sensor_token: {img_path}")
                    cs = self.calibration["calibrated_sensor_by_token"].get(cs_token, None)
                    if cs is None:
                        raise RuntimeError(f"calibrated_sensor not found: {cs_token}")

                    ep_token = sd.get("ego_pose_token", None)
                    if ep_token is None:
                        raise RuntimeError(f"sample_data missing ego_pose_token: {img_path}")
                    ep = self.calibration["ego_pose_by_token"].get(ep_token, None)
                    if ep is None:
                        raise RuntimeError(f"ego_pose not found: {ep_token}")

                    K = np.array(cs["camera_intrinsic"], dtype=np.float32)
                    # wxyz quaternion -> rotation matrix for the cam->ego extrinsic.
                    q = cs["rotation"]
                    t = cs["translation"]
                    w, x, y, z = q
                    R = np.array(
                        [
                            [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
                            [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
                            [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
                        ],
                        dtype=np.float32,
                    )
                    E = np.eye(4, dtype=np.float32)
                    E[:3, :3] = R
                    E[:3, 3] = np.array(t, dtype=np.float32)

            # Ego pose, priority 1: lidar->global transform from the lidar
            # sample_data of this sample (also supplies the timestamp).
            if ego_pose_out is None and self.calibration is not None:
                _item_id = str(item.get("id", "")) if isinstance(item, dict) else ""
                _lidar_sd = self.calibration.get("lidar_sd_by_sample_token", {}).get(_item_id)
                if _lidar_sd is not None:
                    _lidar_cs = self.calibration["calibrated_sensor_by_token"].get(
                        _lidar_sd.get("calibrated_sensor_token"), None)
                    _lidar_ep = self.calibration["ego_pose_by_token"].get(
                        _lidar_sd.get("ego_pose_token"), None)
                    if _lidar_cs is not None and _lidar_ep is not None:
                        def _q2R(q):
                            # wxyz quaternion -> 3x3 rotation matrix.
                            ww, xx, yy, zz = q
                            return np.array([
                                [1-2*(yy*yy+zz*zz), 2*(xx*yy-zz*ww), 2*(xx*zz+yy*ww)],
                                [2*(xx*yy+zz*ww), 1-2*(xx*xx+zz*zz), 2*(yy*zz-xx*ww)],
                                [2*(xx*zz-yy*ww), 2*(yy*zz+xx*ww), 1-2*(xx*xx+yy*yy)],
                            ], dtype=np.float32)
                        _l2e = np.eye(4, dtype=np.float32)
                        _l2e[:3, :3] = _q2R(_lidar_cs["rotation"])
                        _l2e[:3, 3] = np.array(_lidar_cs["translation"], dtype=np.float32)
                        _e2g = np.eye(4, dtype=np.float32)
                        _e2g[:3, :3] = _q2R(_lidar_ep["rotation"])
                        _e2g[:3, 3] = np.array(_lidar_ep["translation"], dtype=np.float32)
                        # lidar->global = (ego->global) @ (lidar->ego)
                        _lidar2global = (_e2g @ _l2e).astype(np.float32)
                        ego_pose_out = torch.tensor(_lidar2global, dtype=torch.float32)
                        try:
                            ego_pose_inv_out = torch.tensor(np.linalg.inv(_lidar2global), dtype=torch.float32)
                        except Exception:
                            ego_pose_inv_out = None
                        _lidar_ts = _lidar_sd.get("timestamp", None)
                        if _lidar_ts is not None:
                            # Microseconds -> seconds.
                            timestamp_out = torch.tensor(float(_lidar_ts) * 1e-6, dtype=torch.float32)
            # Ego pose, priority 2: this camera's ego_pose record.
            if ego_pose_out is None and ep is not None:
                q_ep = ep.get("rotation", None)
                t_ep = ep.get("translation", None)
                if q_ep is not None and t_ep is not None:
                    w, x, y, z = q_ep
                    R_ep = np.array(
                        [
                            [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
                            [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
                            [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
                        ],
                        dtype=np.float32,
                    )
                    ego_pose_m = np.eye(4, dtype=np.float32)
                    ego_pose_m[:3, :3] = R_ep
                    ego_pose_m[:3, 3] = np.array(t_ep, dtype=np.float32)
                    ego_pose_out = torch.tensor(ego_pose_m, dtype=torch.float32)
                    try:
                        ego_pose_inv_out = torch.tensor(np.linalg.inv(ego_pose_m), dtype=torch.float32)
                    except Exception:
                        ego_pose_inv_out = None
            # Timestamp fallback: this camera's sample_data record.
            if timestamp_out is None and sd is not None:
                ts = sd.get("timestamp", None)
                if ts is not None:
                    timestamp_out = torch.tensor(float(ts) * 1e-6, dtype=torch.float32)

            # Camera-params fallback: per-item "sensor" dict (non-nuScenes data).
            if K is None or E is None:
                sensor = (item or {}).get("sensor", None) if isinstance(item, dict) else None
                if not isinstance(sensor, dict):
                    raise RuntimeError(f"no camera params for {img_path}")
                if camera_name not in sensor:
                    raise RuntimeError(f"sensor missing camera {camera_name}")
                cam_s = sensor[camera_name]
                try:
                    K = np.array(cam_s["intrinsic"]["K"], dtype=np.float32)
                    R = np.array(cam_s["extrinsic"]["rotation"], dtype=np.float32)
                    t = np.array(cam_s["extrinsic"]["translation"], dtype=np.float32)
                    E = np.eye(4, dtype=np.float32)
                    E[:3, :3] = R
                    E[:3, 3] = t
                except Exception as e:
                    # NOTE(review): no `from e` here, so the original traceback
                    # is only reflected in the message text.
                    raise RuntimeError(f"failed to parse sensor for {camera_name}: {e}")

            # Branch-specific preprocessing; each adjusts its own intrinsics copy.
            img_det, K_det = self._preprocess_streampetr(img.copy(), K.copy())
            img_map, K_map = self._preprocess_topomlp(img.copy(), K.copy())

            images_det.append(self.image_transform(img_det))
            images_map.append(self.image_transform(img_map))

            intrinsics_det_list.append(torch.tensor(K_det, dtype=torch.float32))
            intrinsics_map_list.append(torch.tensor(K_map, dtype=torch.float32))
            extrinsics_list.append(torch.tensor(E, dtype=torch.float32))

            def _quat_wxyz_to_R(qwxyz):
                # wxyz quaternion -> 3x3 rotation matrix.
                ww, xx, yy, zz = qwxyz
                return np.array(
                    [
                        [1 - 2 * (yy * yy + zz * zz), 2 * (xx * yy - zz * ww), 2 * (xx * zz + yy * ww)],
                        [2 * (xx * yy + zz * ww), 1 - 2 * (xx * xx + zz * zz), 2 * (yy * zz - xx * ww)],
                        [2 * (xx * zz - yy * ww), 2 * (yy * zz + xx * ww), 1 - 2 * (xx * xx + yy * yy)],
                    ],
                    dtype=np.float32,
                )

            def _T_from_Rt(Rm, tv):
                # Assemble a 4x4 homogeneous transform from R (3x3) and t (3,).
                T = np.eye(4, dtype=np.float32)
                T[:3, :3] = Rm
                T[:3, 3] = np.array(tv, dtype=np.float32)
                return T

            def _compute_lidar2img(K_adj, E_mat, sd_rec, ep_rec):
                # Projection matrix K @ (lidar->camera). When lidar records are
                # unavailable, degrades to K @ (ego->camera), i.e. treats the
                # ego frame as the point frame.
                cam2ego = E_mat.astype(np.float32)
                ego2cam = np.linalg.inv(cam2ego)
                K4 = np.eye(4, dtype=np.float32)
                K4[:3, :3] = K_adj.astype(np.float32)

                if sd_rec is None or ep_rec is None:
                    return K4 @ ego2cam

                sample_tk = sd_rec.get("sample_token", None)
                if sample_tk is None:
                    return K4 @ ego2cam

                ego2global_c = _T_from_Rt(_quat_wxyz_to_R(ep_rec["rotation"]), ep_rec["translation"])
                global2ego_c = np.linalg.inv(ego2global_c)

                lidar_sd_rec = self.calibration.get("lidar_sd_by_sample_token", {}).get(str(sample_tk), None)
                if lidar_sd_rec is None:
                    return K4 @ ego2cam

                lidar_cs_rec = self.calibration["calibrated_sensor_by_token"].get(
                    lidar_sd_rec.get("calibrated_sensor_token"), None)
                lidar_ep_rec = self.calibration["ego_pose_by_token"].get(
                    lidar_sd_rec.get("ego_pose_token"), None)
                if lidar_cs_rec is None or lidar_ep_rec is None:
                    return K4 @ ego2cam

                # lidar -> lidar-ego -> global -> camera-ego -> camera; the two
                # ego poses may differ because the sensors fire at different times.
                lidar2ego = _T_from_Rt(_quat_wxyz_to_R(lidar_cs_rec["rotation"]), lidar_cs_rec["translation"])
                ego2global_lidar = _T_from_Rt(_quat_wxyz_to_R(lidar_ep_rec["rotation"]), lidar_ep_rec["translation"])
                lidar2cam = ego2cam @ global2ego_c @ ego2global_lidar @ lidar2ego
                return K4 @ lidar2cam

            lidar2img_det_list.append(
                torch.tensor(_compute_lidar2img(K_det, E, sd, ep), dtype=torch.float32))
            lidar2img_map_list.append(
                torch.tensor(_compute_lidar2img(K_map, E, sd, ep), dtype=torch.float32))

        # Fallback: if nuScenes calibration lookup failed (e.g. OpenLane samples),
        # try to recover ego_pose from item["pose"] and timestamp from item["timestamp"].
        if ego_pose_out is None and isinstance(item, dict):
            pose_data = item.get("pose", None)
            if isinstance(pose_data, dict):
                try:
                    rot_raw = pose_data.get("rotation", None)
                    t_p = pose_data.get("translation", None)
                    if rot_raw is not None and t_p is not None:
                        arr = np.array(rot_raw, dtype=np.float32)
                        # Rotation may be a 3x3 matrix or a wxyz quaternion.
                        if arr.shape == (3, 3):
                            R_p = arr
                        elif arr.shape == (4,):
                            w, x, y, z = arr
                            R_p = np.array([
                                [1-2*(y*y+z*z), 2*(x*y-z*w), 2*(x*z+y*w)],
                                [2*(x*y+z*w), 1-2*(x*x+z*z), 2*(y*z-x*w)],
                                [2*(x*z-y*w), 2*(y*z+x*w), 1-2*(x*x+y*y)],
                            ], dtype=np.float32)
                        else:
                            raise ValueError(f"Unsupported rotation shape: {arr.shape}")
                        T_p = np.eye(4, dtype=np.float32)
                        T_p[:3, :3] = R_p
                        T_p[:3, 3] = np.array(t_p, dtype=np.float32)
                        ego_pose_out = torch.tensor(T_p, dtype=torch.float32)
                        try:
                            ego_pose_inv_out = torch.tensor(np.linalg.inv(T_p), dtype=torch.float32)
                        except Exception:
                            ego_pose_inv_out = None
                except Exception as e:
                    # Best-effort fallback: log and continue without a pose.
                    print(f"WARNING: Failed to parse item['pose']: {e}")

        if timestamp_out is None and isinstance(item, dict):
            ts_raw = item.get("timestamp", None)
            if ts_raw is not None:
                try:
                    timestamp_out = torch.tensor(float(ts_raw) * 1e-6, dtype=torch.float32)
                except Exception:
                    pass

        result = {
            "pixel_values_det": torch.stack(images_det, dim=0),
            "pixel_values_map": torch.stack(images_map, dim=0),
            "intrinsics_det": torch.stack(intrinsics_det_list, dim=0),
            "intrinsics_map": torch.stack(intrinsics_map_list, dim=0),
            "extrinsics": torch.stack(extrinsics_list, dim=0),
            "lidar2img_det": torch.stack(lidar2img_det_list, dim=0),
            "lidar2img_map": torch.stack(lidar2img_map_list, dim=0),
            "ego_pose": ego_pose_out,
            "ego_pose_inv": ego_pose_inv_out,
            "timestamp": timestamp_out,
        }
        # Unsuffixed keys alias the detection-branch tensors (same objects).
        result["pixel_values"] = result["pixel_values_det"]
        result["intrinsics"] = result["intrinsics_det"]
        result["lidar2img"] = result["lidar2img_det"]
        return result
1291
+
1292
+
1293
def atlas_collate_fn(
    batch: List[Dict[str, torch.Tensor]],
    pad_token_id: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """Collate Atlas dataset samples into a training batch.

    Stacks the camera/geometry tensors, right-pads the text fields
    (``input_ids`` / ``attention_mask`` / ``labels``) to the longest
    sequence in the batch, and passes through optional per-sample
    metadata (lidar2img matrices, ego poses, timestamps, velocity,
    precomputed query tokens, audit counters, GT boxes) only when every
    sample can supply it consistently.

    Args:
        batch: list of per-sample dicts produced by the dataset.
        pad_token_id: tokenizer pad id used to fill ``input_ids``.
            Must be supplied explicitly; label padding always uses -100.

    Returns:
        Dict of batched tensors plus list-valued metadata fields
        (``scene_id``, ``sample_id``, ``task_type``, ...).

    Raises:
        RuntimeError: if ``pad_token_id`` is None, or if a sample's own
            padding disagrees with the batch-level padding invariants.
    """
    # Fail fast before doing any tensor work.
    if pad_token_id is None:
        raise RuntimeError("atlas_collate_fn requires explicit pad_token_id")

    def _stack(key):
        # Required keys: a missing key raises KeyError, as before.
        return torch.stack([item[key] for item in batch])

    def _try_stack(key):
        # Optional keys: stack only when every sample has the key and the
        # shapes agree; otherwise omit the field from the batch entirely.
        if all(key in item for item in batch):
            try:
                return torch.stack([item[key] for item in batch])
            except Exception:
                pass
        return None

    # --- Text fields: right-pad to the longest sequence in the batch ---
    max_length = max(len(item['input_ids']) for item in batch)
    batch_size = len(batch)

    input_ids = torch.full((batch_size, max_length), fill_value=pad_token_id, dtype=torch.long)
    attention_mask = torch.full((batch_size, max_length), fill_value=0, dtype=torch.long)
    labels = torch.full((batch_size, max_length), fill_value=-100, dtype=torch.long)

    for i, item in enumerate(batch):
        seq_len = len(item['input_ids'])
        input_ids[i, :seq_len] = item['input_ids']
        attention_mask[i, :seq_len] = item['attention_mask']
        labels[i, :seq_len] = item['labels']

    # Sanity check: every masked position must carry pad_token_id / -100.
    pad_mask = attention_mask == 0
    if pad_mask.any():
        if not torch.all(input_ids[pad_mask] == pad_token_id):
            raise RuntimeError("padding inconsistent: input_ids")
        if not torch.all(labels[pad_mask] == -100):
            raise RuntimeError("padding inconsistent: labels")

    result = {
        'pixel_values': _stack('pixel_values'),
        'pixel_values_det': _stack('pixel_values_det'),
        'pixel_values_map': _stack('pixel_values_map'),
        'intrinsics': _stack('intrinsics'),
        'intrinsics_det': _stack('intrinsics_det'),
        'intrinsics_map': _stack('intrinsics_map'),
        'extrinsics': _stack('extrinsics'),
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'scene_id': [item.get('scene_id', 'unknown') for item in batch],
        'sample_id': [item.get('sample_id', 'unknown') for item in batch],
        'dataset_idx': _stack('dataset_idx'),
        'task_type': [item.get('task_type', 'detection') for item in batch],
    }

    # Optional tensor metadata (incl. velocity, now via the same helper).
    for key in ("lidar2img", "lidar2img_det", "lidar2img_map",
                "ego_pose", "ego_pose_inv", "timestamp", "velocity"):
        stacked = _try_stack(key)
        if stacked is not None:
            result[key] = stacked

    # Precomputed StreamPETR query tokens: samples without them get a
    # zero placeholder. Use zeros_like so the placeholder matches the
    # dtype/device of the real tensors — torch.zeros(shape) would default
    # to float32 on CPU and make torch.stack fail on e.g. fp16 batches.
    if any("precomputed_det" in item for item in batch):
        ref_det = next(item["precomputed_det"] for item in batch if "precomputed_det" in item)
        ref_ref = next(item["precomputed_det_ref"] for item in batch if "precomputed_det" in item)
        dets, refs = [], []
        for item in batch:
            if "precomputed_det" in item:
                dets.append(item["precomputed_det"])
                refs.append(item["precomputed_det_ref"])
            else:
                dets.append(torch.zeros_like(ref_det))
                refs.append(torch.zeros_like(ref_ref))
        result["precomputed_det"] = torch.stack(dets)
        result["precomputed_det_ref"] = torch.stack(refs)

    # Map tokens stay per-sample (ragged), so keep them as a list.
    if any("precomputed_map" in item for item in batch):
        result["precomputed_map"] = [item.get("precomputed_map") for item in batch]

    # Tokenization/audit counters: batched only if present on all samples.
    audit_keys = [
        "audit_prompt_len",
        "audit_answer_len",
        "audit_labels_nonmasked_count",
        "audit_first_nonmasked_index",
        "audit_num_query_tokens_in_input_ids",
        "audit_expected_num_queries",
        "audit_truncated",
    ]
    for k in audit_keys:
        if all(k in item for item in batch):
            try:
                result[k] = torch.stack([item[k] for item in batch])
            except Exception:
                pass

    # GT boxes/labels are variable-length, so they stay as lists.
    has_gt = all('gt_boxes' in item for item in batch)
    if has_gt:
        result['gt_boxes'] = [item['gt_boxes'] for item in batch]
        result['gt_labels'] = [item['gt_labels'] for item in batch]

    return result
1412
+
1413
+
1414
def make_atlas_collate_fn(pad_token_id: int):
    """Return a DataLoader-ready collate callable with *pad_token_id* bound."""

    def _collate(batch):
        return atlas_collate_fn(batch, pad_token_id=pad_token_id)

    return _collate
src/dataset/scene_sampler.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DDP-safe scene-sequential sampler for online temporal training.
2
+
3
+ Guarantees:
4
+ 1. Within each scene, frames are yielded in strict timestamp order.
5
+ 2. Scene order is shuffled per-epoch for training diversity.
6
+ 3. All ranks yield exactly the same number of micro-steps per epoch
7
+ (balanced by greedy scene assignment + deterministic replay padding).
8
+ 4. Epoch boundaries and replay-scene starts are detectable by the caller
9
+ via timestamp regression, so StreamPETR memory can be reset correctly.
10
+ """
11
+
12
+ import random
13
+ from typing import Dict, Iterator, List, Sequence
14
+
15
+ from torch.utils.data import Sampler
16
+
17
+
18
class SceneSequentialSampler(Sampler[int]):
    """Distributed temporal sampler with equal-step guarantee.

    Frames within a scene are yielded in their stored order; scenes are
    load-balanced across ranks (greedy longest-first), and every rank is
    padded to the same per-epoch length by deterministically replaying
    whole scenes, so all ranks take the same number of micro-steps.
    """

    def __init__(
        self,
        scene_groups: Dict[str, List[int]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        if not scene_groups:
            raise ValueError("scene_groups must not be empty")

        self.scene_groups = scene_groups
        # Sorted ids give a deterministic base ordering across ranks.
        self.scene_ids = sorted(scene_groups.keys())
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        self._cached: List[int] = []
        self._target_len = 0

    def set_epoch(self, epoch: int) -> None:
        """Advance the shuffling epoch and invalidate any cached indices."""
        self.epoch = int(epoch)
        self._cached = []
        self._target_len = 0

    def _scene_len(self, sid: str) -> int:
        # Number of frames in scene *sid*.
        return len(self.scene_groups[sid])

    def _build_indices(self) -> List[int]:
        # Seeded identically on every rank so scene assignment agrees.
        gen = random.Random(self.seed + self.epoch)
        order = list(self.scene_ids)
        if self.shuffle_scenes:
            gen.shuffle(order)

        # Greedy longest-first bin packing keeps per-rank frame counts close.
        assignments: List[List[str]] = [[] for _ in range(self.num_replicas)]
        loads = [0] * self.num_replicas
        for sid in sorted(order, key=self._scene_len, reverse=True):
            lightest = min(range(self.num_replicas), key=lambda r: loads[r])
            assignments[lightest].append(sid)
            loads[lightest] += self._scene_len(sid)

        # A rank left without a scene borrows one so it can still iterate.
        for rid, scenes in enumerate(assignments):
            if not scenes:
                borrowed = order[rid % len(order)]
                scenes.append(borrowed)
                loads[rid] += self._scene_len(borrowed)

        # Equal-step quota: heaviest rank, rounded up to pad_to_multiple.
        quota = max(loads)
        remainder = quota % self.pad_to_multiple
        if remainder:
            quota += self.pad_to_multiple - remainder
        self._target_len = quota

        mine = assignments[self.rank]
        if self.shuffle_scenes:
            # Rank-local shuffle of scene order (frames stay in order).
            random.Random(self.seed + self.epoch + self.rank).shuffle(mine)

        out: List[int] = []
        for sid in mine:
            out.extend(self.scene_groups[sid])

        # Deterministic replay padding: append whole scenes, then truncate.
        # The caller detects a replay start via timestamp regression.
        replay = 0
        while len(out) < quota:
            out.extend(self.scene_groups[mine[replay % len(mine)]])
            replay += 1
        return out[:quota]

    def __iter__(self) -> Iterator[int]:
        self._cached = self._build_indices()
        return iter(self._cached)

    def __len__(self) -> int:
        # Build lazily so len() is valid before the first iteration.
        if self._target_len == 0:
            self._cached = self._build_indices()
        return self._target_len
src/eval/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (403 Bytes). View file
 
src/eval/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (451 Bytes). View file
 
src/eval/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (23.6 kB). View file
 
src/eval/__pycache__/metrics.cpython-38.pyc ADDED
Binary file (23.9 kB). View file
 
src/eval/metrics.py ADDED
@@ -0,0 +1,852 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Atlas evaluation metrics."""
2
+
3
+ import re
4
+ import numpy as np
5
+ from typing import List, Dict, Tuple, Optional
6
+
7
+ # scipy only affects match_lanes() / calculate_lane_detection_metrics(),
8
+ # which are NOT used in the main eval path (eval_atlas.py).
9
+ # Main eval uses: greedy matching for detection, OpenLane-V2 LaneEval.bench() for lanes.
10
+ try:
11
+ from scipy.optimize import linear_sum_assignment
12
+ SCIPY_AVAILABLE = True
13
+ except ImportError:
14
+ SCIPY_AVAILABLE = False
15
+
16
+
17
# Lookup table from raw nuScenes category strings (both base names and
# full dotted names) to the 10 canonical detection classes.
NUSCENES_CLASS_MAP = {
    # Base class names
    'car': 'car',
    'truck': 'truck',
    'construction_vehicle': 'construction_vehicle',
    'bus': 'bus',
    'trailer': 'trailer',
    'barrier': 'barrier',
    'motorcycle': 'motorcycle',
    'bicycle': 'bicycle',
    'pedestrian': 'pedestrian',
    'traffic_cone': 'traffic_cone',
    # Full nuScenes category names - vehicles
    'vehicle.car': 'car',
    'vehicle.truck': 'truck',
    'vehicle.construction': 'construction_vehicle',
    'vehicle.bus.bendy': 'bus',
    'vehicle.bus.rigid': 'bus',
    'vehicle.trailer': 'trailer',
    'vehicle.motorcycle': 'motorcycle',
    'vehicle.bicycle': 'bicycle',
    # Full nuScenes category names - pedestrians (all subtypes)
    'human.pedestrian.adult': 'pedestrian',
    'human.pedestrian.child': 'pedestrian',
    'human.pedestrian.construction_worker': 'pedestrian',
    'human.pedestrian.police_officer': 'pedestrian',
    'human.pedestrian.wheelchair': 'pedestrian',
    'human.pedestrian.stroller': 'pedestrian',
    'human.pedestrian.personal_mobility': 'pedestrian',
    # Full nuScenes category names - movable objects
    'movable_object.barrier': 'barrier',
    'movable_object.trafficcone': 'traffic_cone',
    'movable_object.traffic_cone': 'traffic_cone',
}


def normalize_category(category: str) -> str:
    """Map a raw nuScenes category string onto its canonical base class.

    Tries an exact (case-insensitive, stripped) lookup first, then a
    fuzzy substring match in either direction; unknown categories are
    returned unchanged (lowercased/stripped).
    """
    needle = category.lower().strip()
    exact = NUSCENES_CLASS_MAP.get(needle)
    if exact is not None:
        return exact
    for known, base in NUSCENES_CLASS_MAP.items():
        if known in needle or needle in known:
            return base
    return needle
62
+
63
+
64
def normalize_ground_truths(ground_truths: List[Dict]) -> List[Dict]:
    """Return copies of *ground_truths* with canonical categories and a
    guaranteed ``world_coords`` entry.

    Accepted input layouts:
    - {"translation": [x, y, z], "category_name": ...} (from regenerate_atlas_with_gt.py)
    - {"box": [x, y, z, w, l, h, yaw], "category_name": ...} (from gen_atlas_full_data.py)
    - {"world_coords": [x, y, z], "category": ...} (already normalized)
    """
    out = []
    for gt in ground_truths:
        entry = dict(gt)

        # Canonicalize the category, preferring 'category' over
        # 'category_name', and keep the raw value for reference.
        for key in ('category', 'category_name'):
            if key in entry:
                entry['category_raw'] = entry[key]
                entry['category'] = normalize_category(entry[key])
                break

        # Derive world_coords from whichever positional field exists.
        if 'world_coords' not in entry:
            if 'translation' in entry:
                entry['world_coords'] = list(entry['translation'][:3])
            elif 'box' in entry:
                entry['world_coords'] = list(entry['box'][:3])

        out.append(entry)
    return out
92
+
93
+
94
def bin_to_meters(bin_val: int, bin_range: Tuple[float, float] = (-51.2, 51.2), num_bins: int = 1000) -> float:
    """Map a discrete bin index back to meters (inverse of meters_to_bin)."""
    lo, hi = bin_range
    # Bin 0 -> lo, bin (num_bins - 1) -> hi, linear in between.
    return lo + (hi - lo) * (bin_val / (num_bins - 1))
99
+
100
+
101
def meters_to_bin(meters: float, bin_range: Tuple[float, float] = (-51.2, 51.2), num_bins: int = 1000) -> int:
    """Quantize a metric coordinate into one of *num_bins* uniform bins.

    Values outside *bin_range* are clipped to the boundary bins.
    """
    lo, hi = bin_range
    clipped = np.clip(meters, lo, hi)
    frac = (clipped - lo) / (hi - lo)
    idx = round(frac * (num_bins - 1))
    return int(np.clip(idx, 0, num_bins - 1))
108
+
109
+
110
def _parse_lane_points(points_str: str) -> List[Dict]:
    """Parse a sequence of ``[x, y, z]`` bin triples into lane point dicts.

    Each point carries both the raw bin indices and their metric
    conversion (XY in [-51.2, 51.2] m, Z in [-5.0, 3.0] m).
    """
    triples = re.findall(r'\[(\d+),\s*(\d+),\s*(\d+)\]', points_str)
    parsed = []
    for xs, ys, zs in triples:
        xb, yb, zb = int(xs), int(ys), int(zs)
        parsed.append({
            'bin_coords': [xb, yb, zb],
            'world_coords': [
                bin_to_meters(xb, bin_range=(-51.2, 51.2)),
                bin_to_meters(yb, bin_range=(-51.2, 51.2)),
                bin_to_meters(zb, bin_range=(-5.0, 3.0)),
            ],
        })
    return parsed
125
+
126
+
127
def parse_atlas_output(text: str) -> List[Dict]:
    """
    Parse Atlas model output. Supports two canonical formats (checked in order):
    1. Paper lane:  Lane: [x, y, z], [x, y, z]; [x, y, z], [x, y, z]; ...
    2. Detection:   category: [x, y, z], [x, y, z]; category: [x, y, z].

    Returns a list of dicts: lane entries have 'type'=='lane' with a
    'points' list; detection entries have 'type'=='detection' with
    'category', 'bin_coords' and 'world_coords' (meters). Returns []
    when neither format matches.
    """
    results = []

    # --- 1. Paper lane format: "Lane: [pts], [pts]; [pts], [pts]; ..." ---
    paper_lane_match = re.search(r'Lane:\s*(.*)', text, re.DOTALL)
    if paper_lane_match:
        # Strip trailing period/whitespace, then split lanes on ';'.
        content = paper_lane_match.group(1).rstrip('. \t\n')
        lane_strs = content.split(';')
        for lane_idx, lane_str in enumerate(lane_strs):
            lane_str = lane_str.strip()
            if not lane_str:
                continue
            lane_points = _parse_lane_points(lane_str)
            if lane_points:
                results.append({
                    'type': 'lane',
                    'lane_id': str(lane_idx),
                    'points': lane_points,
                })
    # Lane format wins outright: if any lane parsed, skip detection parsing.
    if results:
        return results

    # --- 2. Detection grouped format ---
    # Canonical: "car: [pt1], [pt2]; truck: [pt3]."

    def _make_det(category: str, x_b: int, y_b: int, z_b: int) -> Dict:
        # Bin -> meter conversion uses the fixed ranges:
        # XY in [-51.2, 51.2] m, Z in [-5.0, 3.0] m.
        return {
            'type': 'detection',
            'category': normalize_category(category),
            'category_raw': category,
            'bin_coords': [x_b, y_b, z_b],
            'world_coords': [
                bin_to_meters(x_b, bin_range=(-51.2, 51.2)),
                bin_to_meters(y_b, bin_range=(-51.2, 51.2)),
                bin_to_meters(z_b, bin_range=(-5.0, 3.0)),
            ],
        }

    point_re = re.compile(r'\[(\d+),\s*(\d+),\s*(\d+)\]')
    group_re = re.compile(r'(\S+)\s*:\s*((?:\[\d+,\s*\d+,\s*\d+\][\s,]*)+)')

    stripped = text.strip().rstrip('.')

    # Functional lane syntax is not handled by this parser; bail out.
    if stripped.startswith('lane_centerline('):
        return []

    if ';' in stripped:
        # Multiple "category: points" groups separated by ';'.
        for seg in stripped.split(';'):
            seg = seg.strip()
            if not seg:
                continue
            gm = group_re.match(seg)
            if gm:
                for x_b, y_b, z_b in point_re.findall(gm.group(2)):
                    results.append(_make_det(gm.group(1), int(x_b), int(y_b), int(z_b)))

    if not results:
        # Single-group fallback: accept only if the one group consumed
        # every point in the text (guards against partial garbage matches).
        gm = group_re.match(stripped)
        if gm:
            pts_in_group = point_re.findall(gm.group(2))
            pts_in_text = point_re.findall(stripped)
            if len(pts_in_group) == len(pts_in_text):
                for x_b, y_b, z_b in pts_in_group:
                    results.append(_make_det(gm.group(1), int(x_b), int(y_b), int(z_b)))

    return results
198
+
199
+
200
def calculate_distance(
    pred_coord: List[float],
    gt_coord: List[float],
    use_2d: bool = False,
) -> float:
    """Euclidean distance between a predicted and a ground-truth center.

    Args:
        pred_coord: predicted [x, y, z] in meters.
        gt_coord: ground-truth [x, y, z] in meters.
        use_2d: if True, measure only in the XY (BEV) plane and ignore
            the Z axis — the usual convention for BEV 3D detection
            matching.

    Returns:
        The distance in meters as a plain float.
    """
    p = np.array(pred_coord)
    g = np.array(gt_coord)
    if use_2d:
        # BEV distance: drop the Z component.
        p, g = p[:2], g[:2]
    return float(np.linalg.norm(p - g))
225
+
226
+
227
def match_detections(
    predictions: List[Dict],
    ground_truths: List[Dict],
    threshold: float = 2.0,
    use_2d_distance: bool = True,
    use_hungarian: bool = False,
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
    """
    Match predicted detections against ground-truth detections, per class.

    Args:
        predictions: predicted detections (need 'category' and 'world_coords').
        ground_truths: ground-truth detections (same keys).
        threshold: match distance threshold in meters.
        use_2d_distance: if True, use 2D BEV (XY-plane) distance — the
            standard convention for BEV detection matching.
        use_hungarian: if True, use Hungarian optimal assignment (needs
            scipy); default False uses greedy nearest-first matching
            (nuScenes standard).

    Returns:
        (matches, false_positives, false_negatives) where matches is a list
        of (pred_idx, gt_idx) pairs and the other two are index lists into
        predictions / ground_truths respectively.
    """
    if len(predictions) == 0:
        return [], [], list(range(len(ground_truths)))

    if len(ground_truths) == 0:
        return [], list(range(len(predictions))), []

    # Matching is done independently within each category.
    all_categories = set(p['category'] for p in predictions) | set(g['category'] for g in ground_truths)

    matched_preds = set()
    matched_gts = set()
    matches = []

    for category in all_categories:
        cat_preds = [(i, p) for i, p in enumerate(predictions) if p['category'] == category]
        cat_gts = [(i, g) for i, g in enumerate(ground_truths) if g['category'] == category]

        if not cat_preds or not cat_gts:
            continue

        # Pairwise distance matrix; inf marks pairs beyond the threshold.
        n_preds = len(cat_preds)
        n_gts = len(cat_gts)
        cost_matrix = np.full((n_preds, n_gts), float('inf'))

        for pi, (pred_idx, pred) in enumerate(cat_preds):
            for gi, (gt_idx, gt) in enumerate(cat_gts):
                dist = calculate_distance(pred['world_coords'], gt['world_coords'], use_2d=use_2d_distance)
                if dist < threshold:
                    cost_matrix[pi, gi] = dist

        # Choose Hungarian (optimal) or greedy (nuScenes standard) matching.
        if use_hungarian and SCIPY_AVAILABLE and n_preds > 0 and n_gts > 0:
            # Hungarian optimal assignment.
            # NOTE(review): cost_matrix may contain inf rows/columns; scipy's
            # linear_sum_assignment raises "cost matrix is infeasible" when no
            # full feasible assignment exists — confirm inputs avoid that case
            # (this path is off by default and unused in the main eval).
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
            for pi, gi in zip(row_ind, col_ind):
                if cost_matrix[pi, gi] < threshold:
                    pred_idx = cat_preds[pi][0]
                    gt_idx = cat_gts[gi][0]
                    matches.append((pred_idx, gt_idx))
                    matched_preds.add(pred_idx)
                    matched_gts.add(gt_idx)
        else:
            # Greedy matching: sort candidate pairs by distance, then take
            # each pair whose pred and gt are both still unmatched.
            distances = []
            for pi, (pred_idx, pred) in enumerate(cat_preds):
                for gi, (gt_idx, gt) in enumerate(cat_gts):
                    dist = cost_matrix[pi, gi]
                    if dist < threshold:
                        distances.append((dist, pred_idx, gt_idx))

            distances.sort(key=lambda x: x[0])
            for dist, pred_idx, gt_idx in distances:
                if pred_idx not in matched_preds and gt_idx not in matched_gts:
                    matches.append((pred_idx, gt_idx))
                    matched_preds.add(pred_idx)
                    matched_gts.add(gt_idx)

    false_positives = [i for i in range(len(predictions)) if i not in matched_preds]
    false_negatives = [i for i in range(len(ground_truths)) if i not in matched_gts]

    return matches, false_positives, false_negatives
307
+
308
+
309
def calculate_detection_f1(
    predictions: List[Dict],
    ground_truths: List[Dict],
    threshold: float = 2.0,
) -> Dict[str, float]:
    """Precision / recall / F1 for one sample at a center-distance threshold.

    Uses :func:`match_detections` (greedy, per-class) to pair predictions
    with ground truths, then reduces the match counts to the usual metrics.
    """
    matched, fps, fns = match_detections(predictions, ground_truths, threshold)

    tp, fp, fn = len(matched), len(fps), len(fns)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom else 0.0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'num_predictions': len(predictions),
        'num_ground_truths': len(ground_truths),
    }
338
+
339
+
340
def denormalize_ref_points_01(
    ref_points_01: np.ndarray,
    pc_range: Tuple[float, float, float, float, float, float] = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
) -> np.ndarray:
    """Convert normalized ref points in [0, 1] back to meters.

    Args:
        ref_points_01: array-like [..., 3] of values in [0, 1]
            (out-of-range values are clipped).
        pc_range: (x_min, y_min, z_min, x_max, y_max, z_max)
    Returns:
        np.ndarray [..., 3] in meters
    """
    pts = np.clip(np.asarray(ref_points_01, dtype=np.float64), 0.0, 1.0)
    lo = np.asarray(pc_range[:3], dtype=np.float64)
    hi = np.asarray(pc_range[3:], dtype=np.float64)
    # Floor the span to avoid division-by-zero style degeneracy downstream.
    span = np.clip(hi - lo, 1e-6, None)
    return lo + pts * span
358
+
359
+
360
def snap_detections_to_ref_points(
    predictions: List[Dict],
    ref_points_01: np.ndarray,
    pc_range: Tuple[float, float, float, float, float, float] = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
    keep_z: bool = True,
) -> List[Dict]:
    """Snap predicted detection centers to the nearest reference point (BEV XY).

    Post-processing that constrains predictions to lie on the StreamPETR
    proposal set (ref points); this reduces the sensitivity of the small
    metric thresholds (0.5 m / 1 m) to free-form numeric drift.

    Args:
        predictions: detection dicts with 'world_coords' in meters.
        ref_points_01: [Q, 3] or [B, Q, 3] normalized ref points in [0, 1].
        pc_range: point cloud range for denormalization.
        keep_z: if True, keep each prediction's original z; else use ref z.
    Returns:
        New list of predictions (shallow-copied dicts) with snapped
        'world_coords'; the input is returned unchanged (as a new list)
        when the ref points are unusable.
    """
    if not predictions:
        return []

    ref = np.asarray(ref_points_01, dtype=np.float64)
    if ref.ndim == 3:
        # Batched input: take the first sample's proposals.
        ref = ref[0]
    if ref.ndim != 2 or ref.shape[1] != 3 or ref.shape[0] == 0:
        return list(predictions)

    ref_m = denormalize_ref_points_01(ref, pc_range=pc_range)

    pred_xy = np.array([p.get("world_coords", [0.0, 0.0, 0.0])[:2] for p in predictions], dtype=np.float64)
    if pred_xy.ndim != 2 or pred_xy.shape[0] == 0:
        return list(predictions)

    # Nearest-neighbor in the BEV plane via squared distances.
    sq_dists = ((pred_xy[:, None, :] - ref_m[None, :, :2]) ** 2).sum(-1)
    nearest = sq_dists.argmin(axis=1)

    snapped = []
    for pred, j in zip(predictions, nearest):
        copy = dict(pred)
        orig_xyz = list(copy.get("world_coords", [0.0, 0.0, 0.0]))
        new_xyz = ref_m[int(j)].tolist()
        if keep_z and len(orig_xyz) >= 3:
            # Preserve the model's own height estimate.
            new_xyz[2] = float(orig_xyz[2])
        copy["world_coords"] = [float(v) for v in new_xyz]
        snapped.append(copy)
    return snapped
410
+
411
+
412
def calculate_per_class_metrics(
    predictions: List[Dict],
    ground_truths: List[Dict],
    threshold: float = 2.0,
) -> Dict[str, Dict[str, float]]:
    """Compute detection P/R/F1 separately for every category present
    in either the predictions or the ground truths."""
    categories = {p['category'] for p in predictions} | {g['category'] for g in ground_truths}
    return {
        cat: calculate_detection_f1(
            [p for p in predictions if p['category'] == cat],
            [g for g in ground_truths if g['category'] == cat],
            threshold,
        )
        for cat in categories
    }
430
+
431
+
432
def parse_planning_output(text: str, require_full_vap: bool = False) -> Optional[Dict]:
    """Parse a planning answer into velocity / acceleration / waypoints.

    Expected answer shape (chained speed -> acceleration -> waypoints):
        "ego car speed value: [a, b]. ego car acceleration value: [c, d].
         ... request the ego car planning waypoint in 3-seconds: [x, y], ..."

    Args:
        text: raw model output.
        require_full_vap: if True, also require the speed and acceleration
            fields; otherwise parsed waypoints alone are sufficient.

    Returns:
        Dict with 'waypoints' (meters) and, when present, 'velocity_bins' /
        'acceleration_bins' (raw bin pairs); None when no waypoints parse.
    """
    result = {}
    vel_pattern = r'ego car speed value:\s*\[(\d+),\s*(\d+)\]\.?'
    acc_pattern = r'ego car acceleration value:\s*\[(\d+),\s*(\d+)\]\.?'
    # NOTE(review): the pattern deliberately accepts both "requeset" and
    # "request" — presumably the typo exists in the training answers;
    # confirm against the QA generation scripts before removing it.
    wp_pattern = (
        r'(?:based on the ego car speed and acceleration you predicted,\s*)?'
        r'(?:requeset|request)\s+the ego car planning waypoint(?:s)? in 3-seconds:\s*'
        r'((?:\[\d+,\s*\d+\](?:,\s*)?)+)\.?'
    )

    vel_m = re.search(vel_pattern, text, flags=re.IGNORECASE)
    if vel_m:
        result['velocity_bins'] = [int(vel_m.group(1)), int(vel_m.group(2))]

    acc_m = re.search(acc_pattern, text, flags=re.IGNORECASE)
    if acc_m:
        result['acceleration_bins'] = [int(acc_m.group(1)), int(acc_m.group(2))]

    wp_m = re.search(wp_pattern, text, flags=re.IGNORECASE)
    if wp_m:
        # Waypoint bins use the same XY range as detection bins.
        point_pattern = r'\[(\d+),\s*(\d+)\]'
        points = re.findall(point_pattern, wp_m.group(1))
        wps = []
        for xb, yb in points:
            x = bin_to_meters(int(xb), bin_range=(-51.2, 51.2))
            y = bin_to_meters(int(yb), bin_range=(-51.2, 51.2))
            wps.append([x, y])
        result['waypoints'] = wps

    if 'waypoints' not in result or len(result['waypoints']) == 0:
        return None

    # Planning answers use a Figure 5-style chained speed + acceleration +
    # waypoint protocol. The main evaluation path can require all three fields.
    if require_full_vap and (
        'velocity_bins' not in result or 'acceleration_bins' not in result
    ):
        return None
    return result
471
+
472
+
473
+ def _pad_waypoints(waypoints: List[List[float]], target_n: int = 6) -> List[List[float]]:
474
+ """Pad waypoint list to target_n by repeating last waypoint.
475
+
476
+ This prevents short model outputs from gaming the L2 / collision metrics.
477
+ """
478
+ if len(waypoints) >= target_n:
479
+ return waypoints[:target_n]
480
+ if len(waypoints) == 0:
481
+ return [[0.0, 0.0]] * target_n
482
+ last = list(waypoints[-1])
483
+ return list(waypoints) + [list(last)] * (target_n - len(waypoints))
484
+
485
+
486
def calculate_planning_l2(
    pred_waypoints: List[List[float]],
    gt_waypoints: List[List[float]],
    timestamps: Optional[List[float]] = None,
) -> Dict[str, float]:
    """L2 waypoint errors at the 1 s / 2 s / 3 s horizons plus their average.

    Predictions shorter than GT are padded (last waypoint repeated) so a
    truncated output cannot lower its error.

    Args:
        pred_waypoints: predicted [x, y] waypoints in meters.
        gt_waypoints: ground-truth [x, y] waypoints in meters.
        timestamps: per-waypoint times in seconds; defaults to a 0.5 s grid.

    Returns:
        Dict with 'L2_1s' / 'L2_2s' / 'L2_3s' (when those times exist in
        *timestamps*) and 'L2_avg' (mean of the horizon errors when any
        were recorded, else mean over all steps, else 0.0).
    """
    n_gt = len(gt_waypoints)
    if timestamps is None:
        timestamps = [0.5 * (i + 1) for i in range(n_gt)]

    # Pad predictions to match GT length to prevent short-output bias
    pred_padded = _pad_waypoints(pred_waypoints, target_n=n_gt)

    errors = {}
    all_l2 = []
    for i in range(n_gt):
        pred = np.array(pred_padded[i][:2])
        gt = np.array(gt_waypoints[i][:2])
        l2 = float(np.linalg.norm(pred - gt))
        all_l2.append(l2)
        t = timestamps[i] if i < len(timestamps) else 0.5 * (i + 1)
        # Record horizon errors at (approximately) 1 s / 2 s / 3 s.
        if abs(t - 1.0) < 0.01:
            errors['L2_1s'] = l2
        if abs(t - 2.0) < 0.01:
            errors['L2_2s'] = l2
        if abs(t - 3.0) < 0.01:
            errors['L2_3s'] = l2

    # Average over the named horizons when present, else over all steps.
    key_steps = [v for k, v in errors.items() if k in ('L2_1s', 'L2_2s', 'L2_3s')]
    errors['L2_avg'] = float(np.mean(key_steps)) if key_steps else (float(np.mean(all_l2)) if all_l2 else 0.0)

    return errors
517
+
518
+
519
+ def _box_corners_2d(cx: float, cy: float, w: float, l: float, yaw: float) -> np.ndarray:
520
+ """Build oriented box corners for yaw-from-x headings.
521
+
522
+ In planning eval JSON, yaw is measured from +X (right) axis:
523
+ - yaw = 0 -> vehicle length points to +X
524
+ - yaw = +pi/2 -> vehicle length points to +Y
525
+ This matches the qualitative visualization helper.
526
+ """
527
+ c = np.cos(yaw)
528
+ s = np.sin(yaw)
529
+ center = np.array([cx, cy], dtype=np.float64)
530
+
531
+ # Heading axis follows the vehicle length, with width perpendicular to it.
532
+ d_len = np.array([c, s], dtype=np.float64) * (l / 2.0)
533
+ d_wid = np.array([-s, c], dtype=np.float64) * (w / 2.0)
534
+
535
+ corners = np.stack([
536
+ center + d_len + d_wid,
537
+ center + d_len - d_wid,
538
+ center - d_len - d_wid,
539
+ center - d_len + d_wid,
540
+ ], axis=0)
541
+ return corners
542
+
543
+
544
+ def _boxes_overlap(box1_corners: np.ndarray, box2_corners: np.ndarray) -> bool:
545
+ for box in [box1_corners, box2_corners]:
546
+ for i in range(4):
547
+ j = (i + 1) % 4
548
+ edge = box[j] - box[i]
549
+ normal = np.array([-edge[1], edge[0]])
550
+ proj1 = box1_corners @ normal
551
+ proj2 = box2_corners @ normal
552
+ if proj1.max() < proj2.min() or proj2.max() < proj1.min():
553
+ return False
554
+ return True
555
+
556
+
557
def _check_collision_at_waypoints(
    waypoints: List[List[float]],
    gt_boxes: List[Dict],
    ego_w: float,
    ego_l: float,
    gt_boxes_per_timestep: Optional[List[List[Dict]]] = None,
) -> List[bool]:
    """Check collision between ego at each waypoint and GT boxes.

    When *gt_boxes_per_timestep* is provided (ST-P3 aligned), each waypoint
    is checked against the boxes at the corresponding future timestep.
    Otherwise falls back to using the same static *gt_boxes* for all waypoints.

    Args:
        waypoints: ego [x, y] positions per future timestep, in meters.
        gt_boxes: obstacle dicts with 'world_coords' and optional
            'w' / 'l' / 'yaw' (defaults 2.0 m / 4.0 m / 0.0).
        ego_w: ego vehicle width in meters.
        ego_l: ego vehicle length in meters.
        gt_boxes_per_timestep: optional per-timestep obstacle lists.

    Returns:
        One bool per waypoint: True if the ego box overlaps any obstacle.
    """
    collisions = []
    for i, wp in enumerate(waypoints):
        # Estimate ego heading from the forward difference to the next
        # waypoint; fall back to the backward difference at the last step,
        # and to yaw = 0 when there is only one (or a stationary) waypoint.
        if i + 1 < len(waypoints):
            dx = waypoints[i + 1][0] - wp[0]
            dy = waypoints[i + 1][1] - wp[1]
            ego_yaw = float(np.arctan2(dy, dx)) if (abs(dx) + abs(dy)) > 1e-4 else 0.0
        elif i > 0:
            dx = wp[0] - waypoints[i - 1][0]
            dy = wp[1] - waypoints[i - 1][1]
            ego_yaw = float(np.arctan2(dy, dx)) if (abs(dx) + abs(dy)) > 1e-4 else 0.0
        else:
            ego_yaw = 0.0
        ego_corners = _box_corners_2d(wp[0], wp[1], ego_w, ego_l, ego_yaw)

        # Pick the obstacle set for this timestep (static fallback).
        boxes_at_t = gt_boxes
        if gt_boxes_per_timestep is not None and i < len(gt_boxes_per_timestep):
            boxes_at_t = gt_boxes_per_timestep[i]

        collided = False
        for box in boxes_at_t:
            # Boxes without a position cannot be tested; skip them.
            if 'world_coords' not in box:
                continue
            bx, by = box['world_coords'][0], box['world_coords'][1]
            bw = box.get('w', 2.0)
            bl = box.get('l', 4.0)
            byaw = box.get('yaw', 0.0)
            obj_corners = _box_corners_2d(bx, by, bw, bl, byaw)
            if _boxes_overlap(ego_corners, obj_corners):
                collided = True
                break
        collisions.append(collided)
    return collisions
602
+
603
+
604
def calculate_collision_rate(
    pred_waypoints: List[List[float]],
    gt_boxes: List[Dict],
    ego_w: float = 1.85,
    ego_l: float = 4.084,
    timestamps: List[float] = None,
    num_waypoints: int = 6,
    gt_waypoints: Optional[List[List[float]]] = None,
    gt_boxes_per_timestep: Optional[List[List[Dict]]] = None,
) -> Dict[str, float]:
    """Collision indicators for a predicted trajectory at 1s/2s/3s horizons.

    Waypoints are padded/truncated to *num_waypoints*; when *timestamps* is
    omitted a 0.5s step is assumed, placing the horizons at indices 1/3/5.
    ST-P3 alignment: timesteps where the GT trajectory itself collides are
    forced to False, so only avoidable predicted collisions are counted.
    Ego footprint defaults to 1.85m x 4.084m — presumably the nuScenes ego
    vehicle; confirm against the dataset config.

    Returns ``collision_1s/2s/3s`` (0.0 or 1.0, present only when a matching
    timestamp exists) and ``collision_avg`` over the present keys.
    """
    pred_padded = _pad_waypoints(pred_waypoints, target_n=num_waypoints)
    if timestamps is None:
        timestamps = [0.5 * (i + 1) for i in range(num_waypoints)]

    # ST-P3 aligned: exclude timesteps where the GT trajectory itself collides
    gt_collides = [False] * num_waypoints
    if gt_waypoints is not None:
        gt_padded = _pad_waypoints(gt_waypoints, target_n=num_waypoints)
        gt_collides = _check_collision_at_waypoints(
            gt_padded, gt_boxes, ego_w, ego_l,
            gt_boxes_per_timestep=gt_boxes_per_timestep,
        )

    pred_collides = _check_collision_at_waypoints(
        pred_padded, gt_boxes, ego_w, ego_l,
        gt_boxes_per_timestep=gt_boxes_per_timestep,
    )

    # Map each timestep's timestamp to its (GT-masked) collision flag.
    collisions_at_t = {}
    for i in range(num_waypoints):
        t = timestamps[i] if i < len(timestamps) else 0.5 * (i + 1)
        if gt_collides[i]:
            collisions_at_t[t] = False
        else:
            collisions_at_t[t] = pred_collides[i]

    results = {}
    for target_t, key in [(1.0, 'collision_1s'), (2.0, 'collision_2s'), (3.0, 'collision_3s')]:
        # Tolerant timestamp match (10 ms) against the provided horizon grid.
        matched = [v for t, v in collisions_at_t.items() if abs(t - target_t) < 0.01]
        if matched:
            results[key] = float(matched[0])

    key_cols = [v for k, v in results.items() if k in ('collision_1s', 'collision_2s', 'collision_3s')]
    results['collision_avg'] = float(np.mean(key_cols)) if key_cols else 0.0

    return results
650
+
651
+
652
def calculate_planning_metrics(
    predictions: List[Dict],
    ground_truths: List[Dict],
) -> Dict[str, float]:
    """Aggregate per-sample planning L2 errors and collision rates.

    A pair contributes L2 metrics when both trajectories are non-empty, and
    collision metrics when GT boxes (static or per-timestep) are available.
    Metrics with no contributing samples average to 0.0.
    """
    l2_acc: Dict[str, List[float]] = {k: [] for k in ('L2_1s', 'L2_2s', 'L2_3s', 'L2_avg')}
    col_acc: Dict[str, List[float]] = {
        k: [] for k in ('collision_1s', 'collision_2s', 'collision_3s', 'collision_avg')
    }

    for pred, gt in zip(predictions, ground_truths):
        pred_traj = pred.get('waypoints', [])
        gt_traj = gt.get('waypoints', [])

        if pred_traj and gt_traj:
            for name, value in calculate_planning_l2(pred_traj, gt_traj).items():
                if name in l2_acc:
                    l2_acc[name].append(value)

        boxes = gt.get('gt_boxes', [])
        boxes_per_ts = gt.get('gt_boxes_per_timestep', None)
        if pred_traj and (boxes or boxes_per_ts):
            col_metrics = calculate_collision_rate(
                pred_traj, boxes, gt_waypoints=gt_traj,
                gt_boxes_per_timestep=boxes_per_ts,
            )
            for name, value in col_metrics.items():
                if name in col_acc:
                    col_acc[name].append(value)

    summary: Dict[str, float] = {}
    for acc in (l2_acc, col_acc):
        for name, values in acc.items():
            summary[name] = float(np.mean(values)) if values else 0.0
    return summary
686
+
687
+
688
# Shared physical range (in value units) covered by velocity/acceleration bins.
VEL_ACC_RANGE = (-50.0, 50.0)


def vel_acc_bin_to_meters(bin_val: int, num_bins: int = 1000) -> float:
    """Decode a velocity/acceleration token bin back to a physical value.

    Delegates to `bin_to_meters` over VEL_ACC_RANGE (-50..50), so with the
    default 1000 bins each bin spans 0.1 units.
    """
    return bin_to_meters(bin_val, bin_range=VEL_ACC_RANGE, num_bins=num_bins)
693
+
694
+
695
def chamfer_distance_polyline(
    pred_pts: np.ndarray,
    gt_pts: np.ndarray,
) -> float:
    """Symmetric Chamfer distance between two 2-D/3-D point sets.

    Averages, in both directions, the distance from each point to its nearest
    neighbour in the other set, then takes the mean of the two directions.

    Args:
        pred_pts: predicted polyline points, array-like of shape [N, D].
        gt_pts: ground-truth polyline points, array-like of shape [M, D].

    Returns:
        The symmetric Chamfer distance, or +inf if either set is empty.
    """
    if len(pred_pts) == 0 or len(gt_pts) == 0:
        return float('inf')
    pred_pts = np.asarray(pred_pts, dtype=np.float64)
    gt_pts = np.asarray(gt_pts, dtype=np.float64)
    # One vectorized [N, M] pairwise-distance pass via broadcasting instead of
    # a Python-level loop per point (same O(N*M) work, done in C).
    dists = np.linalg.norm(pred_pts[:, None, :] - gt_pts[None, :, :], axis=-1)
    d_p2g = dists.min(axis=1).mean()  # pred -> nearest gt
    d_g2p = dists.min(axis=0).mean()  # gt -> nearest pred
    return float(0.5 * (d_p2g + d_g2p))
712
+
713
+
714
+ def _lane_points_array(lane) -> np.ndarray:
715
+ pts = lane.get('points', [])
716
+ if not pts:
717
+ return np.zeros((0, 3))
718
+ rows = []
719
+ for pt in pts:
720
+ if isinstance(pt, dict):
721
+ rows.append(pt.get('world_coords', [0, 0, 0])[:3])
722
+ else:
723
+ rows.append(list(pt)[:3])
724
+ return np.array(rows, dtype=np.float64)
725
+
726
+
727
def match_lanes(
    pred_lanes: List[Dict],
    gt_lanes: List[Dict],
    threshold: float = 1.5,
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
    """Match predicted lanes to GT lanes by Chamfer distance.

    Builds a cost matrix of Chamfer distances (entries >= *threshold* stay
    infinite), then solves a one-to-one assignment with scipy when available,
    falling back to greedy lowest-cost-first matching otherwise.

    Returns:
        (matches, fp, fn): matched (pred_idx, gt_idx) pairs, unmatched
        prediction indices, and unmatched GT indices.
    """
    if not pred_lanes:
        return [], [], list(range(len(gt_lanes)))
    if not gt_lanes:
        return [], list(range(len(pred_lanes))), []

    n_p = len(pred_lanes)
    n_g = len(gt_lanes)
    cost = np.full((n_p, n_g), float('inf'))

    for i, pl in enumerate(pred_lanes):
        p_pts = _lane_points_array(pl)
        if len(p_pts) == 0:
            continue
        for j, gl in enumerate(gt_lanes):
            g_pts = _lane_points_array(gl)
            if len(g_pts) == 0:
                continue
            cd = chamfer_distance_polyline(p_pts, g_pts)
            if cd < threshold:
                cost[i, j] = cd

    matches = []
    matched_p = set()
    matched_g = set()

    if SCIPY_AVAILABLE and n_p > 0 and n_g > 0 and np.isfinite(cost).any():
        # BUG FIX: linear_sum_assignment raises ValueError on matrices that
        # contain inf, which is the common case (any pair beyond threshold).
        # The previous except-fallback then returned ZERO matches. Substitute
        # a large finite penalty instead; such pairs are discarded below by
        # the `< threshold` filter anyway.
        big_penalty = max(threshold, 1.0) * 1e6
        finite_cost = np.where(np.isfinite(cost), cost, big_penalty)
        try:
            row_ind, col_ind = linear_sum_assignment(finite_cost)
        except ValueError:
            # Defensive: keep the old behavior for any remaining infeasibility.
            row_ind, col_ind = [], []
        for pi, gi in zip(row_ind, col_ind):
            if cost[pi, gi] < threshold:
                matches.append((pi, gi))
                matched_p.add(pi)
                matched_g.add(gi)
    else:
        # Greedy fallback: take admissible pairs in ascending cost order.
        pairs = []
        for i in range(n_p):
            for j in range(n_g):
                if cost[i, j] < threshold:
                    pairs.append((cost[i, j], i, j))
        pairs.sort()
        for _, i, j in pairs:
            if i not in matched_p and j not in matched_g:
                matches.append((i, j))
                matched_p.add(i)
                matched_g.add(j)

    fp = [i for i in range(n_p) if i not in matched_p]
    fn = [j for j in range(n_g) if j not in matched_g]
    return matches, fp, fn
783
+
784
+
785
def calculate_lane_detection_metrics(
    pred_lanes: List[Dict],
    gt_lanes: List[Dict],
    threshold: float = 1.5,
) -> Dict[str, float]:
    """Precision/recall/F1 for lane detection via Chamfer-distance matching."""
    matched, unmatched_pred, unmatched_gt = match_lanes(pred_lanes, gt_lanes, threshold)
    tp, fp, fn = len(matched), len(unmatched_pred), len(unmatched_gt)

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0

    return {
        'lane_precision': precision,
        'lane_recall': recall,
        'lane_f1': f1,
        'lane_tp': tp,
        'lane_fp': fp,
        'lane_fn': fn,
    }
805
+
806
+
807
def calculate_multi_threshold_detection_f1(
    predictions: List[Dict],
    ground_truths: List[Dict],
    thresholds: Tuple[float, ...] = (0.5, 1.0, 2.0, 4.0),
) -> Dict[str, float]:
    """Detection P/R/F1 at several distance thresholds plus the mean F1."""
    out: Dict[str, float] = {}
    per_threshold_f1: List[float] = []
    for thr in thresholds:
        metrics = calculate_detection_f1(predictions, ground_truths, threshold=thr)
        out[f'P@{thr}m'] = metrics['precision']
        out[f'R@{thr}m'] = metrics['recall']
        out[f'F1@{thr}m'] = metrics['f1']
        per_threshold_f1.append(metrics['f1'])
    out['F1_avg'] = float(np.mean(per_threshold_f1)) if per_threshold_f1 else 0.0
    return out
822
+
823
+
824
def evaluate_all(
    task_predictions: Dict[str, List],
    task_ground_truths: Dict[str, List],
) -> Dict[str, Dict[str, float]]:
    """Run every task evaluator for which both predictions and GT exist.

    Supported task keys: 'detection' (multi-threshold F1), 'lane'
    (per-sample P/R/F1 averaged over samples), 'planning' (L2 + collision).
    """
    summary: Dict[str, Dict[str, float]] = {}

    def _both_present(task: str) -> bool:
        # Evaluate a task only when predictions AND ground truths exist.
        return task in task_predictions and task in task_ground_truths

    if _both_present('detection'):
        summary['detection'] = calculate_multi_threshold_detection_f1(
            task_predictions['detection'],
            task_ground_truths['detection'],
        )

    if _both_present('lane'):
        lane_acc = {'lane_precision': [], 'lane_recall': [], 'lane_f1': []}
        for pred_set, gt_set in zip(task_predictions['lane'], task_ground_truths['lane']):
            # Each sample may be a single lane or a list of lanes.
            preds = pred_set if isinstance(pred_set, list) else [pred_set]
            gts = gt_set if isinstance(gt_set, list) else [gt_set]
            sample_metrics = calculate_lane_detection_metrics(preds, gts)
            for name in lane_acc:
                lane_acc[name].append(sample_metrics[name])
        summary['lane'] = {name: float(np.mean(vals)) for name, vals in lane_acc.items() if vals}

    if _both_present('planning'):
        summary['planning'] = calculate_planning_metrics(
            task_predictions['planning'],
            task_ground_truths['planning'],
        )

    return summary
src/model/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Atlas model module."""

from .topomlp_adapter import TopoMLPToAtlasMapTokens
from .streampetr_adapter import extract_streampetr_topk_tokens

# Atlas (LLM) side depends on `transformers`. StreamPETR/TopoMLP pretraining does not.
# Best-effort import: if the LLM-side modules (or their deps) fail to import,
# the adapter exports above remain usable and the LLM symbols degrade to None.
try:
    from .configuration_atlas import AtlasConfig
    from .modeling_atlas import AtlasProjector, AtlasForCausalLM
    _ATLAS_AVAILABLE = True
except Exception:
    AtlasConfig = None  # type: ignore[assignment]
    AtlasProjector = None  # type: ignore[assignment]
    AtlasForCausalLM = None  # type: ignore[assignment]
    _ATLAS_AVAILABLE = False

__all__ = [
    "TopoMLPToAtlasMapTokens",
    "extract_streampetr_topk_tokens",
]

# Advertise LLM-side symbols only when their import succeeded above.
if _ATLAS_AVAILABLE:
    __all__ += [
        "AtlasConfig",
        "AtlasProjector",
        "AtlasForCausalLM",
    ]
+ ]
28
+
src/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (622 Bytes). View file
 
src/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (622 Bytes). View file
 
src/model/__pycache__/configuration_atlas.cpython-310.pyc ADDED
Binary file (852 Bytes). View file
 
src/model/__pycache__/modeling_atlas.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
src/model/__pycache__/modeling_atlas.cpython-38.pyc ADDED
Binary file (15.2 kB). View file
 
src/model/__pycache__/streampetr_adapter.cpython-310.pyc ADDED
Binary file (4.25 kB). View file
 
src/model/__pycache__/streampetr_adapter.cpython-38.pyc ADDED
Binary file (4.22 kB). View file
 
src/model/__pycache__/topomlp_adapter.cpython-310.pyc ADDED
Binary file (3.91 kB). View file
 
src/model/__pycache__/topomlp_adapter.cpython-38.pyc ADDED
Binary file (3.81 kB). View file
 
src/model/modeling_atlas.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Atlas model (LLM + visual token injection)."""
2
+
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional, Tuple, List, Union, Dict
7
+ from transformers import (
8
+ AutoModelForCausalLM,
9
+ AutoConfig,
10
+ BitsAndBytesConfig,
11
+ )
12
+ from transformers.modeling_outputs import CausalLMOutputWithPast
13
+
14
+ try:
15
+ from peft import (
16
+ LoraConfig,
17
+ get_peft_model,
18
+ prepare_model_for_kbit_training,
19
+ TaskType,
20
+ )
21
+ PEFT_AVAILABLE = True
22
+ except ImportError:
23
+ PEFT_AVAILABLE = False
24
+ from src.audit.audit_utils import audit_enabled, audit_check
25
+
26
+
27
def get_quantization_config(
    load_in_4bit: bool = True,
    bnb_4bit_compute_dtype: torch.dtype = torch.bfloat16,
    bnb_4bit_quant_type: str = "nf4",
    bnb_4bit_use_double_quant: bool = True,
) -> Optional[BitsAndBytesConfig]:
    """Build a BitsAndBytes 4-bit quantization config.

    Returns None when *load_in_4bit* is False (no quantization requested).
    """
    if not load_in_4bit:
        return None

    quant_kwargs = dict(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
    )
    return BitsAndBytesConfig(**quant_kwargs)
45
+
46
+
47
class ReferencePointProjector(nn.Module):
    """Embed reference points [B, Q, 3] as features [B, Q, D].

    The linear map is zero-initialized (per paper) so that reference points
    contribute nothing until training updates the weights.
    """

    def __init__(self, visual_hidden_size: int = 256):
        super().__init__()
        self.visual_hidden_size = visual_hidden_size
        linear = nn.Linear(3, visual_hidden_size)
        nn.init.zeros_(linear.weight)
        if linear.bias is not None:
            nn.init.zeros_(linear.bias)
        self.projector_rp = linear

    def forward(self, ref_points: torch.Tensor) -> torch.Tensor:
        # Cast input to the layer's parameter dtype before projecting.
        target_dtype = self.projector_rp.weight.dtype
        return self.projector_rp(ref_points.to(target_dtype))
60
+
61
+
62
class AtlasProjector(nn.Module):
    """Linear projection of visual features [B, Q, Dv] into LLM hidden
    space [B, Q, H] (Xavier weight init, zero bias)."""

    def __init__(
        self,
        visual_hidden_size: int = 256,
        llm_hidden_size: int = 4096,
    ):
        super().__init__()
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size
        linear = nn.Linear(visual_hidden_size, llm_hidden_size)
        nn.init.xavier_uniform_(linear.weight)
        if linear.bias is not None:
            nn.init.zeros_(linear.bias)
        self.projector = linear

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        # Match the projection layer's parameter dtype before applying it.
        target_dtype = self.projector.weight.dtype
        return self.projector(visual_features.to(target_dtype))
80
+
81
+
82
class AtlasUnifiedProjector(nn.Module):
    """Bundle of the Atlas projection heads.

    Separate detection/map projectors into LLM hidden space, plus one shared
    zero-initialized reference-point embedder whose output is added to the
    visual features before projection.
    """

    def __init__(self, visual_hidden_size: int, llm_hidden_size: int):
        super().__init__()
        self.projector_det = AtlasProjector(visual_hidden_size, llm_hidden_size)
        self.projector_map = AtlasProjector(visual_hidden_size, llm_hidden_size)
        self.projector_rp = ReferencePointProjector(visual_hidden_size)

    def forward(
        self,
        detection_features: torch.Tensor,
        map_features: Optional[torch.Tensor] = None,
        detection_ref_points: Optional[torch.Tensor] = None,
        map_ref_points: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        result: Dict[str, torch.Tensor] = {}

        det_in = detection_features
        if detection_ref_points is not None:
            det_in = det_in + self.projector_rp(detection_ref_points)
        result["detection"] = self.projector_det(det_in)

        # Map branch is optional; only projected when features are supplied.
        if map_features is not None:
            map_in = map_features
            if map_ref_points is not None:
                map_in = map_in + self.projector_rp(map_ref_points)
            result["map"] = self.projector_map(map_in)

        return result
108
+
109
+
110
class AtlasForCausalLM(nn.Module):
    """Atlas model: a causal LLM with visual query-token injection.

    Projected detection/map features are written into the <query> placeholder
    positions of the prompt embeddings before the LLM forward/generate call.
    Supports optional 4-bit quantization (requires LoRA) and flash-attention 2.
    """

    def __init__(
        self,
        llm_model_name: str = "lmsys/vicuna-7b-v1.5",
        visual_hidden_size: int = 256,
        num_queries: int = 256,
        num_map_queries: int = 256,
        load_in_4bit: bool = False,
        use_flash_attention: bool = True,
        device_map: Optional[str] = None,
        torch_dtype: torch.dtype = torch.bfloat16,
        use_lora: bool = False,
        lora_r: int = 64,
        lora_alpha: int = 64,
        lora_dropout: float = 0.1,
        lora_target_modules: List[str] = None,
    ):
        super().__init__()

        self.llm_model_name = llm_model_name
        self.visual_hidden_size = visual_hidden_size
        self.num_queries = num_queries
        self.num_map_queries = num_map_queries
        # Set later via set_query_token_id() once the tokenizer is known.
        self.query_token_id = None

        # 4-bit base weights are frozen; training then requires LoRA adapters.
        if load_in_4bit and not use_lora:
            raise ValueError(
                "load_in_4bit=True requires use_lora=True (4-bit weights are not trainable)."
            )

        if use_lora and not PEFT_AVAILABLE:
            raise ImportError("LoRA requires peft library")
        self.llm_config = AutoConfig.from_pretrained(llm_model_name)
        self.llm_hidden_size = self.llm_config.hidden_size

        quantization_config = None
        if load_in_4bit:
            quantization_config = get_quantization_config(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch_dtype,
            )

        # Fail fast if flash-attn was requested but is not installed.
        attn_implementation = None
        if use_flash_attention:
            try:
                import flash_attn  # noqa: F401
                attn_implementation = "flash_attention_2"
            except ImportError:
                raise ImportError(
                    "use_flash_attention=True but flash_attn is not installed. "
                    "Install with: pip install flash-attn --no-build-isolation"
                )

        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            quantization_config=quantization_config,
            device_map=device_map,
            torch_dtype=torch_dtype,
            attn_implementation=attn_implementation,
            trust_remote_code=True,
        )

        if use_lora:
            if load_in_4bit:
                self.llm = prepare_model_for_kbit_training(self.llm)
            if lora_target_modules is None:
                lora_target_modules = [
                    "q_proj", "k_proj", "v_proj", "o_proj",  # Attention
                    "gate_proj", "up_proj", "down_proj",  # MLP
                ]

            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=lora_target_modules,
                bias="none",
                task_type=TaskType.CAUSAL_LM,
            )

            self.llm = get_peft_model(self.llm, lora_config)

        self.projector = AtlasUnifiedProjector(
            visual_hidden_size=visual_hidden_size,
            llm_hidden_size=self.llm_hidden_size,
        )

        # Co-locate the projector with the LLM weights when a device exists.
        try:
            llm_device = next(self.llm.parameters()).device
            self.projector = self.projector.to(llm_device)
        except StopIteration:
            pass

    def set_query_token_id(self, query_token_id: int) -> None:
        """Register the tokenizer id of the <query> placeholder token."""
        self.query_token_id = query_token_id

    def get_input_embeddings(self):
        return self.llm.get_input_embeddings()

    def resize_token_embeddings(self, new_num_tokens: int) -> None:
        self.llm.resize_token_embeddings(new_num_tokens)

    def gradient_checkpointing_enable(self) -> None:
        # Delegate to the wrapped LLM when it supports checkpointing.
        if hasattr(self.llm, 'gradient_checkpointing_enable'):
            self.llm.gradient_checkpointing_enable()

    # Parameter-name substrings that are excluded from weight decay.
    _NO_DECAY_KEYWORDS = {"bias", "LayerNorm.weight", "layernorm.weight",
                          "layer_norm.weight", "norm.weight"}

    def get_trainable_param_groups(self, lr: float, weight_decay: float = 0.0) -> List[Dict]:
        """Split trainable params into decay / no-decay optimizer groups."""
        decay, no_decay = [], []
        for module in [self.projector, self.llm]:
            for name, param in module.named_parameters():
                if not param.requires_grad:
                    continue
                if any(nd in name for nd in self._NO_DECAY_KEYWORDS):
                    no_decay.append(param)
                else:
                    decay.append(param)

        # The projector must always be trainable; a frozen projector would
        # silently detach the visual pathway.
        if not any(p.requires_grad for p in self.projector.parameters()):
            raise RuntimeError("projector has no trainable parameters (requires_grad=False).")

        groups: List[Dict] = []
        if decay:
            groups.append({"params": decay, "lr": lr, "weight_decay": weight_decay})
        if no_decay:
            groups.append({"params": no_decay, "lr": lr, "weight_decay": 0.0})
        return groups

    def get_expected_trainable_param_ids(self, lr: float) -> set:
        """Convenience helper for 'optimizer coverage' hard checks."""
        param_ids: set = set()
        for g in self.get_trainable_param_groups(lr):
            for p in g["params"]:
                param_ids.add(id(p))
        return param_ids

    def parameters(self, recurse: bool = True):
        # Yield projector params first, then LLM params.
        for param in self.projector.parameters(recurse):
            yield param
        for param in self.llm.parameters(recurse):
            yield param

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        visual_features: Optional[Union[torch.Tensor, Dict]] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        query_token_id: Optional[int] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LLM on the prompt with visual tokens already injected.

        Injection happens in `_prepare_llm_inputs`; the LLM is then called
        with `inputs_embeds` (never raw `input_ids`). `use_cache` is forced
        off in training mode.
        """
        inputs_embeds, attention_mask, position_ids = self._prepare_llm_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_features=visual_features,
            labels=labels,
            query_token_id=query_token_id,
        )
        if self.training:
            use_cache = False

        # Forward through LLM
        outputs = self.llm(
            input_ids=None,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return outputs

    def _prepare_llm_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        visual_features: Optional[Union[torch.Tensor, Dict]] = None,
        labels: Optional[torch.LongTensor] = None,
        query_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """
        Shared path for forward/generate:
        - build inputs_embeds
        - inject visual tokens into <query> positions
        - ensure attention_mask / position_ids are consistent

        Expects `visual_features` either as a dict with keys "detection"
        (required), "map", "detection_ref_points", "map_ref_points", or as a
        bare tensor (detection features only). The number of <query> tokens
        per sample must equal det (or det+map) token counts, else raises.
        """
        inputs_embeds = self.llm.get_input_embeddings()(input_ids).clone()

        # Snapshot the pre-injection mask so audits can verify <query>
        # positions were already attended to before injection.
        _attention_mask_before_inject = None
        if attention_mask is not None:
            attention_mask = attention_mask.clone()
            _attention_mask_before_inject = attention_mask.clone()

        assert inputs_embeds.shape[:2] == input_ids.shape[:2], (
            f"shape mismatch: inputs_embeds={inputs_embeds.shape[:2]} != input_ids={input_ids.shape[:2]}"
        )
        if attention_mask is not None:
            assert attention_mask.shape[:2] == input_ids.shape[:2], (
                f"shape mismatch: attention_mask={attention_mask.shape[:2]} != input_ids={input_ids.shape[:2]}"
            )

        # Explicit argument wins over the id registered on the model.
        qid = query_token_id if query_token_id is not None else self.query_token_id

        if visual_features is not None and qid is not None:
            batch_size = input_ids.shape[0]
            device = inputs_embeds.device
            dtype = inputs_embeds.dtype

            if isinstance(visual_features, dict):
                detection_features = visual_features["detection"]
                map_features = visual_features.get("map", None)
                det_ref_points = visual_features.get("detection_ref_points", None)
                map_ref_points = visual_features.get("map_ref_points", None)

                projected = self.projector(
                    detection_features=detection_features,
                    map_features=map_features,
                    detection_ref_points=det_ref_points,
                    map_ref_points=map_ref_points,
                )
                detection_embeds = projected["detection"].to(dtype=dtype, device=device)
                map_embeds = projected.get("map", None)
                if map_embeds is not None:
                    map_embeds = map_embeds.to(dtype=dtype, device=device)
                num_det = detection_embeds.shape[1]
                num_map = map_embeds.shape[1] if map_embeds is not None else 0
            else:
                # Bare tensor: detection-only path through the det projector.
                detection_embeds = self.projector.projector_det(visual_features).to(dtype=dtype, device=device)
                map_embeds = None
                num_det = detection_embeds.shape[1]
                num_map = 0

            for b in range(batch_size):
                query_positions = torch.where(input_ids[b] == qid)[0]
                num_query_slots = int(query_positions.numel())
                if num_query_slots == 0:
                    continue

                # Audit B4: all <query> slots must precede the first
                # supervised (label != -100) answer token.
                if labels is not None:
                    ans_pos = (labels[b] != -100).nonzero(as_tuple=True)[0]
                    if ans_pos.numel() > 0 and audit_enabled():
                        first_answer_token_pos = int(ans_pos[0].item())
                        q_min = int(query_positions.min().item())
                        q_max = int(query_positions.max().item())
                        audit_check(
                            "B4",
                            q_max < first_answer_token_pos,
                            once=True,
                            min_query_pos=q_min,
                            max_query_pos=q_max,
                            first_answer_token_pos=first_answer_token_pos,
                        )

                # Injection: detection tokens fill the first slots, then map
                # tokens; slot count must match exactly, otherwise raise.
                if map_embeds is not None and num_query_slots == (num_det + num_map):
                    for i in range(num_det):
                        inputs_embeds[b, int(query_positions[i].item())] = detection_embeds[b, i]
                    for i in range(num_map):
                        inputs_embeds[b, int(query_positions[int(num_det) + i].item())] = map_embeds[b, i]
                    num_injected = int(num_det + num_map)
                    inj_det = int(num_det)
                    inj_map = int(num_map)
                elif num_query_slots == num_det:
                    for i in range(num_det):
                        inputs_embeds[b, int(query_positions[i].item())] = detection_embeds[b, i]
                    num_injected = int(num_det)
                    inj_det = int(num_det)
                    inj_map = 0
                else:
                    raise ValueError(
                        f"<query> slot mismatch: slots={num_query_slots}, "
                        f"det={num_det}, map={num_map if map_embeds is not None else 0}. "
                        f"Ensure visual_features provides the correct number of tokens "
                        f"matching the prompt's <query> placeholders."
                    )

                # Injected positions must be attended to.
                if attention_mask is not None and num_injected > 0:
                    attention_mask[b, query_positions[:num_injected]] = 1

                # One-shot (up to ATLAS_AUDIT_MAX_FWD) diagnostics on the
                # first batch element: injection counts, sequence lengths,
                # slot/source layout, mask sanity, round-trip embed equality
                # and text-vs-visual norm ratio.
                if audit_enabled():
                    if not hasattr(self, "_audit_forward_calls"):
                        self._audit_forward_calls = 0
                    max_calls = int(os.getenv("ATLAS_AUDIT_MAX_FWD", "1"))
                    if self._audit_forward_calls < max_calls and b == 0:
                        total_injected = int(num_injected)
                        print(
                            "[ATLAS_AUDIT][A3] "
                            f"num_query_tokens_in_input_ids={num_query_slots} "
                            f"num_detection_tokens_injected={inj_det} "
                            f"num_map_tokens_injected={inj_map} "
                            f"total_injected={total_injected}"
                        )
                        seq_len_input_ids = int(input_ids.shape[1])
                        seq_len_embeds = int(inputs_embeds.shape[1])
                        seq_len_mask = int(attention_mask.shape[1]) if attention_mask is not None else -1
                        print(f"[ATLAS_AUDIT][B1] seq_len_input_ids={seq_len_input_ids} seq_len_embeds={seq_len_embeds} seq_len_mask={seq_len_mask}")
                        n_show = min(20, num_query_slots)
                        pos_first = query_positions[:n_show].tolist()
                        src_first = []
                        for i in range(n_show):
                            src_first.append("DET" if i < int(num_det) else "MAP")
                        packed = ",".join([f"{s}@{p}" for s, p in zip(src_first, pos_first)])
                        print(f"[ATLAS_AUDIT][A4] first20={packed}")
                        if attention_mask is not None and n_show > 0 and _attention_mask_before_inject is not None:
                            q_mask = _attention_mask_before_inject[b, query_positions[:num_injected]]
                            bad = int((q_mask == 0).sum().item())
                            q_min2 = int(q_mask.min().item()) if q_mask.numel() else -1
                            q_max2 = int(q_mask.max().item()) if q_mask.numel() else -1
                            print(f"[ATLAS_AUDIT][A5/B2] query_mask_bad_count={bad} query_mask_min={q_min2} query_mask_max={q_max2}")
                            assert bad == 0, "<query> positions have attention_mask==0"
                        n = min(8, num_query_slots, num_injected)
                        if n > 0:
                            diffs = []
                            for i in range(n):
                                pos = int(query_positions[i].item())
                                if i < int(num_det):
                                    ref = detection_embeds[b, i]
                                else:
                                    j = i - int(num_det)
                                    ref = map_embeds[b, j] if map_embeds is not None else detection_embeds[b, min(i, int(num_det) - 1)]
                                diffs.append(float((inputs_embeds[b, pos] - ref).abs().max().item()))
                            max_diff = max(diffs) if diffs else 0.0
                            print(f"[ATLAS_AUDIT][C1] sampled_max_diff={max_diff:.3e} (n={n})")
                            assert max_diff < 1e-5, f"injected embed diff too large: {max_diff}"
                        # C3 audit (post-inject)
                        if attention_mask is not None:
                            text_pos2 = (attention_mask[b] == 1) & (input_ids[b] != qid)
                        else:
                            text_pos2 = (input_ids[b] != qid)
                        text_vec2 = inputs_embeds[b, text_pos2].float()
                        vis_vec2 = inputs_embeds[b, query_positions[:num_injected]].float()
                        if text_vec2.numel() > 0 and vis_vec2.numel() > 0:
                            text_norm = text_vec2.norm(dim=-1)
                            vis_norm = vis_vec2.norm(dim=-1)
                            t_mean = float(text_norm.mean().item())
                            t_std = float(text_norm.std(unbiased=False).item())
                            v_mean = float(vis_norm.mean().item())
                            v_std = float(vis_norm.std(unbiased=False).item())
                            ratio = v_mean / (t_mean + 1e-8)
                            rmin = float(os.getenv("ATLAS_VIS_TEXT_NORM_RATIO_MIN", "0.1"))
                            rmax = float(os.getenv("ATLAS_VIS_TEXT_NORM_RATIO_MAX", "10.0"))
                            audit_check(
                                "C3",
                                (ratio >= rmin and ratio <= rmax),
                                once=True,
                                ratio=ratio,
                                ratio_min=rmin,
                                ratio_max=rmax,
                                text_norm_mean=t_mean,
                                text_norm_std=t_std,
                                vis_norm_mean=v_mean,
                                vis_norm_std=v_std,
                                dtype=str(inputs_embeds.dtype),
                            )
                        self._audit_forward_calls += 1

        # Position ids: cumulative-sum form only for left-padded batches
        # (mask starts with 0 somewhere and ends with 1); plain arange else.
        batch_size, seq_length = inputs_embeds.shape[:2]
        device = inputs_embeds.device
        if attention_mask is not None:
            has_token = (attention_mask.sum(dim=1) > 0)
            left_padded = bool(((attention_mask[:, 0] == 0) & has_token).any().item() and ((attention_mask[:, -1] == 1) & has_token).any().item())
            if left_padded:
                position_ids = attention_mask.long().cumsum(dim=1) - 1
                position_ids.masked_fill_(attention_mask == 0, 0)
            else:
                position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
        else:
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)

        return inputs_embeds, attention_mask, position_ids

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        visual_features: Optional[Union[torch.Tensor, Dict]] = None,
        max_new_tokens: int = 256,
        query_token_id: Optional[int] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Greedy/standard generation using the SAME projector+injection+pos/mask path as forward.
        """
        inputs_embeds, attention_mask, position_ids = self._prepare_llm_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_features=visual_features,
            labels=None,
            query_token_id=query_token_id,
        )
        gen_kwargs = dict(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
        )
        # Propagate the model's pad token unless the caller overrides it.
        if "pad_token_id" not in kwargs and hasattr(self.llm.config, "pad_token_id"):
            gen_kwargs["pad_token_id"] = self.llm.config.pad_token_id
        gen_kwargs.update(kwargs)
        return self.llm.generate(**gen_kwargs)
522
+
523
+
524
def load_atlas_model(
    llm_model_name: str = "lmsys/vicuna-7b-v1.5",
    visual_hidden_size: int = 256,
    num_queries: int = 256,
    num_map_queries: int = 256,
    load_in_4bit: bool = False,
    use_flash_attention: bool = True,
    use_lora: bool = False,
    lora_r: int = 64,
    lora_alpha: int = 64,
    lora_dropout: float = 0.1,
    lora_target_modules: List[str] = None,
) -> AtlasForCausalLM:
    """Factory wrapper: construct an AtlasForCausalLM with the given options."""
    model_kwargs = dict(
        llm_model_name=llm_model_name,
        visual_hidden_size=visual_hidden_size,
        num_queries=num_queries,
        num_map_queries=num_map_queries,
        load_in_4bit=load_in_4bit,
        use_flash_attention=use_flash_attention,
        use_lora=use_lora,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        lora_target_modules=lora_target_modules,
    )
    return AtlasForCausalLM(**model_kwargs)
src/model/streampetr_adapter.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ StreamPETR -> Atlas detection token adapter WITHOUT modifying StreamPETR source code.
3
+
4
+ Rationale:
5
+ - StreamPETRHead updates internal memory with Top-K proposals (topk_proposals, typically 256).
6
+ - After a forward pass, the head stores:
7
+ - self.memory_embedding: [B, memory_len + topk_proposals, D] (prepend new Top-K)
8
+ - self.memory_reference_point: [B, memory_len + topk_proposals, 3]
9
+ (exact ordering: new Top-K are concatenated in front)
10
+
11
+ IMPORTANT: after post_update_memory, memory_reference_point is in the **GLOBAL**
12
+ coordinate frame (ego_pose applied). We must invert the ego_pose to bring
13
+ ref_points back to ego frame before normalizing with pc_range.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Any, Dict, Optional, Tuple
19
+
20
+ import torch
21
+
22
+ PC_RANGE = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0)
23
+
24
+
25
+ def _normalize_ref_points(
26
+ ref: torch.Tensor,
27
+ pc_range: Tuple[float, float, float, float, float, float] = PC_RANGE,
28
+ ) -> torch.Tensor:
29
+ pc_min = ref.new_tensor(pc_range[:3])
30
+ pc_max = ref.new_tensor(pc_range[3:])
31
+ denom = (pc_max - pc_min).clamp(min=1e-6)
32
+ return ((ref - pc_min) / denom).clamp(0.0, 1.0)
33
+
34
+
35
+ def _global_to_ego(ref: torch.Tensor, ego_pose: torch.Tensor) -> torch.Tensor:
36
+ """Transform reference points from global frame back to ego frame.
37
+
38
+ Args:
39
+ ref: [B, N, 3] in global coordinates
40
+ ego_pose: [B, 4, 4] ego-to-global transform
41
+ Returns:
42
+ [B, N, 3] in ego coordinates
43
+ """
44
+ B, N, _ = ref.shape
45
+ ones = torch.ones(B, N, 1, device=ref.device, dtype=ref.dtype)
46
+ ref_homo = torch.cat([ref, ones], dim=-1) # [B, N, 4]
47
+ ego_pose_inv = torch.inverse(ego_pose) # [B, 4, 4]
48
+ ref_ego = (ego_pose_inv.unsqueeze(1) @ ref_homo.unsqueeze(-1)).squeeze(-1)[..., :3]
49
+ return ref_ego
50
+
51
+
52
+ def _nuscenes_ego_to_paper(ref: torch.Tensor) -> torch.Tensor:
53
+ """Convert nuScenes ego coords to Atlas paper frame.
54
+
55
+ nuScenes ego uses x=forward, y=left. Atlas detection QA uses
56
+ x=right, y=forward, so (x_p, y_p) = (-y_n, x_n).
57
+ """
58
+ ref_paper = ref.clone()
59
+ ref_paper[..., 0] = -ref[..., 1]
60
+ ref_paper[..., 1] = ref[..., 0]
61
+ return ref_paper
62
+
63
+
64
@torch.no_grad()
def extract_streampetr_topk_tokens(
    pts_bbox_head: Any,
    topk: int = 256,
    pc_range: Tuple[float, float, float, float, float, float] = PC_RANGE,
    ego_pose: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
    """Export the Top-K StreamPETR memory tokens for Atlas.

    Args:
        pts_bbox_head: the StreamPETRHead instance (``model.pts_bbox_head``).
        topk: number of tokens to export; should match ``topk_proposals``.
        pc_range: point-cloud range used by StreamPETR, for normalization.
        ego_pose: optional [B, 4, 4] ego-to-global transform. When given,
            reference points are mapped back from global to ego frame and
            rotated into the Atlas paper frame before normalization.

    Returns:
        Dict with ``detection`` [B, topk, D] and ``detection_ref_points``
        [B, topk, 3] normalized into [0, 1].
    """
    has_buffers = hasattr(pts_bbox_head, "memory_embedding") and hasattr(
        pts_bbox_head, "memory_reference_point"
    )
    if not has_buffers:
        raise RuntimeError("pts_bbox_head missing memory buffers; ensure you have run a forward pass first.")

    memory = pts_bbox_head.memory_embedding
    ref_pts = pts_bbox_head.memory_reference_point
    if memory is None or ref_pts is None:
        raise RuntimeError("pts_bbox_head memory is None; ensure you have run a forward pass and prev_exists is set.")

    if memory.ndim != 3 or ref_pts.ndim != 3 or ref_pts.shape[-1] != 3:
        raise RuntimeError(f"unexpected shapes: memory_embedding={getattr(memory,'shape',None)} memory_reference_point={getattr(ref_pts,'shape',None)}")

    if memory.shape[1] < topk or ref_pts.shape[1] < topk:
        raise RuntimeError(f"memory length too small: mem_len={memory.shape[1]} ref_len={ref_pts.shape[1]} topk={topk}")

    # New Top-K proposals are concatenated in FRONT of the memory, so slicing
    # the first `topk` entries yields the current frame's tokens.
    tokens = memory[:, :topk, :].contiguous()
    ref_topk = ref_pts[:, :topk, :].contiguous()

    # post_update_memory stores ref_points in the global frame; undo that and
    # rotate into the Atlas paper frame so projector_rp sees the same XY
    # semantics as the detection QA/GT.
    if ego_pose is not None:
        ref_topk = _nuscenes_ego_to_paper(_global_to_ego(ref_topk, ego_pose))

    return {
        "detection": tokens,
        "detection_ref_points": _normalize_ref_points(ref_topk, pc_range),
    }
src/model/topomlp_adapter.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TopoMLP -> Atlas map token adapter.
2
+
3
+ Paper-aligned: Top-K selection from TopoMLP decoder outputs,
4
+ followed by a single linear projector (handled by AtlasUnifiedProjector).
5
+ No Perceiver resampler -- queries and ref_points are passed through directly.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Dict, Optional, Tuple
11
+
12
+ import torch
13
+
14
+
15
+ def _lane_control_points_to_center_xyz(lane_preds: torch.Tensor) -> torch.Tensor:
16
+ if lane_preds.ndim != 3 or lane_preds.shape[-1] % 3 != 0:
17
+ raise ValueError(f"lane_preds expected [B,Q,3*K], got {tuple(lane_preds.shape)}")
18
+ B, Q, D = lane_preds.shape
19
+ K = D // 3
20
+ pts = lane_preds.view(B, Q, K, 3)
21
+ return pts.mean(dim=2)
22
+
23
+
24
+ def _normalize_xyz(xyz: torch.Tensor, xyz_min: torch.Tensor, xyz_max: torch.Tensor) -> torch.Tensor:
25
+ denom = (xyz_max - xyz_min).clamp(min=1e-6)
26
+ out = (xyz - xyz_min) / denom
27
+ return out.clamp(0.0, 1.0)
28
+
29
+
30
+ class TopoMLPToAtlasMapTokens(torch.nn.Module):
31
+ """Select top-K lane queries from TopoMLP and return them with reference points.
32
+
33
+ Aligned with Atlas paper Section 3.1:
34
+ "these queries are streamlined through a single linear layer"
35
+ The linear projection itself is in AtlasUnifiedProjector.projector_map.
36
+ This module only does Top-K selection + ref_point computation.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ num_map_tokens: int = 256,
42
+ hidden_size: int = 256,
43
+ bev_range: Tuple[float, float, float, float, float, float] = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0),
44
+ **kwargs,
45
+ ):
46
+ super().__init__()
47
+ self.num_map_tokens = int(num_map_tokens)
48
+ self.hidden_size = int(hidden_size)
49
+ self.bev_range = tuple(float(x) for x in bev_range)
50
+
51
+ xyz_min = torch.tensor(self.bev_range[:3], dtype=torch.float32)
52
+ xyz_max = torch.tensor(self.bev_range[3:], dtype=torch.float32)
53
+ self.register_buffer("_xyz_min", xyz_min, persistent=False)
54
+ self.register_buffer("_xyz_max", xyz_max, persistent=False)
55
+
56
+ @torch.no_grad()
57
+ def infer_lane_centers_from_outs(self, outs: Dict) -> torch.Tensor:
58
+ one2one_preds = outs["all_lc_preds_list"][-1]
59
+ return _lane_control_points_to_center_xyz(one2one_preds)
60
+
61
+ def forward(self, outs: Dict) -> Dict[str, torch.Tensor]:
62
+ lane_tokens = outs["lc_outs_dec_list"][-1]
63
+ lane_scores = outs["all_lc_cls_scores_list"][-1].squeeze(-1)
64
+
65
+ lane_centers = self.infer_lane_centers_from_outs(outs)
66
+ lane_ref_norm = _normalize_xyz(lane_centers, self._xyz_min, self._xyz_max)
67
+
68
+ B, N, D = lane_tokens.shape
69
+ if N == 0:
70
+ return {
71
+ "map": torch.zeros(B, self.num_map_tokens, D, dtype=lane_tokens.dtype, device=lane_tokens.device),
72
+ "map_ref_points": torch.zeros(B, self.num_map_tokens, 3, dtype=lane_ref_norm.dtype, device=lane_ref_norm.device),
73
+ }
74
+
75
+ k = min(self.num_map_tokens, N)
76
+ topk_idx = torch.topk(lane_scores, k=k, dim=1, largest=True, sorted=True).indices
77
+ tok_idx = topk_idx.unsqueeze(-1).expand(-1, -1, D)
78
+ ref_idx = topk_idx.unsqueeze(-1).expand(-1, -1, 3)
79
+ map_tokens = lane_tokens.gather(dim=1, index=tok_idx)
80
+ map_ref = lane_ref_norm.gather(dim=1, index=ref_idx)
81
+
82
+ if k < self.num_map_tokens:
83
+ pad_t = torch.zeros(B, self.num_map_tokens - k, D, dtype=map_tokens.dtype, device=map_tokens.device)
84
+ pad_r = torch.zeros(B, self.num_map_tokens - k, 3, dtype=map_ref.dtype, device=map_ref.device)
85
+ map_tokens = torch.cat([map_tokens, pad_t], dim=1)
86
+ map_ref = torch.cat([map_ref, pad_r], dim=1)
87
+
88
+ return {"map": map_tokens, "map_ref_points": map_ref}
src/prompting.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import re
3
+ from typing import Literal, Optional, Tuple
4
+
5
+ from src.audit.audit_utils import audit_enabled, h1b_record_prompt
6
+
7
+
8
# Prompt variants for the Atlas paper's Table 3 planning ablation.
PLANNING_TABLE3_MODES = (
    "atlas_base",
    "atlas_high_level",
    "atlas_high_level_ego",
)

# Matches (and strips) a previously injected high-level maneuver sentence.
_PLANNING_COMMAND_RE = re.compile(
    r"The ego car will (?:turn left|turn right|go straight) in future\.\s*"
)
# Matches (and strips) previously injected ego-state sentences.
# NOTE(review): \d+ only matches non-negative integers — if velocity or
# acceleration bins can be negative, an already-injected state sentence would
# not be stripped on a second rewrite. Confirm the bin value range.
_PLANNING_STATE_RE = re.compile(
    r"The current speed value of the ego car is \[\d+,\s*\d+\]\.\s*"
    r"The current acceleration value of the ego car is \[\d+,\s*\d+\]\.\s*"
)
# Exact sentence used as the insertion anchor when re-adding state/command text.
_PLANNING_ACCEL_SENTENCE = (
    "The acceleration of the vehicle is defined as "
    "[acceleration along the x-axis, acceleration along the y-axis]."
)
25
+
26
+
27
+ # Paper Table 6 prompts 1-3, plus two repository paraphrases.
28
+ DETECTION_PROMPTS = [
29
+ # Paper Table 6, prompt 1.
30
+ (
31
+ "There are six images captured by the surround view cameras in driving vehicle. "
32
+ "They are uniformly represented as queries embeddings<query>. "
33
+ "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
34
+ "Please complete the visual detection task under the Bird's Eye View (BEV) perspective. "
35
+ "Ensure that the detection range does not exceed 50 meters."
36
+ ),
37
+ # Paper Table 6, prompt 2.
38
+ (
39
+ "There are six images captured by the surround view cameras in driving vehicle. "
40
+ "They are uniformly represented as queries embeddings<query>. "
41
+ "Establish the positive y-axis as the frontward direction and the positive x-axis as the rightward direction. "
42
+ "Kindly execute the visual detection task within the Bird's Eye View (BEV) framework. "
43
+ "Be mindful not to exceed a detection range of 50 meters."
44
+ ),
45
+ # Paper Table 6, prompt 3.
46
+ (
47
+ "There are six images captured by the surround view cameras in driving vehicle. "
48
+ "They are uniformly represented as queries embeddings<query>. "
49
+ "Set the forward direction as the positive y-axis and the right direction as the positive x-axis. "
50
+ "Please carry out the visual detection task within the Bird's Eye View (BEV) context. "
51
+ "Ensure that the detection range remains within 50 meters."
52
+ ),
53
+ # Additional paraphrase variant 4.
54
+ (
55
+ "There are six images captured by the surround view cameras in driving vehicle. "
56
+ "They are uniformly represented as queries embeddings<query>. "
57
+ "Let the positive y-axis denote the forward direction and the positive x-axis denote the right direction. "
58
+ "Please perform the visual detection task from the Bird's Eye View (BEV) perspective. "
59
+ "Keep the detection range within 50 meters."
60
+ ),
61
+ # Additional paraphrase variant 5.
62
+ (
63
+ "There are six images captured by the surround view cameras in driving vehicle. "
64
+ "They are uniformly represented as queries embeddings<query>. "
65
+ "Take the positive y-axis to be the forward direction and the positive x-axis to be the right direction. "
66
+ "Kindly carry out the visual detection task under the Bird's Eye View (BEV) perspective. "
67
+ "Ensure that all detections stay within 50 meters."
68
+ ),
69
+ ]
70
+
71
+ # Paper Table 7 prompts 1-3, plus two repository paraphrases.
72
+ LANE_PROMPTS = [
73
+ # Paper Table 7, prompt 1.
74
+ (
75
+ "There are six images captured by the surround view cameras in driving vehicle. "
76
+ "They are uniformly represented as queries embeddings<query>. "
77
+ "Please complete the centerline detection task under the Bird's Eye View (BEV) perspective. "
78
+ "Ensure that the detection range does not exceed 50 meters."
79
+ ),
80
+ # Paper Table 7, prompt 2. The published paper text appears truncated and is kept verbatim.
81
+ (
82
+ "There are six images captured by the surround view cameras in driving vehicle. "
83
+ "They are uniformly represented as queries embeddings<query>. "
84
+ "Kindly execute the centerline detection task within the Bird's Eye View (BEV) framework. "
85
+ "Be mindful not to exceed a detection range of 50 meters."
86
+ ),
87
+ # Paper Table 7, prompt 3.
88
+ (
89
+ "There are six images captured by the surround view cameras in driving vehicle. "
90
+ "They are uniformly represented as queries embeddings<query>. "
91
+ "Could you complete the task of detecting the centerline from the Bird's Eye View (BEV) perspective? "
92
+ "Ensure that the detection range remains within 50 meters."
93
+ ),
94
+ # Additional paraphrase variant 4.
95
+ (
96
+ "There are six images captured by the surround view cameras in driving vehicle. "
97
+ "They are uniformly represented as queries embeddings<query>. "
98
+ "Kindly execute the centerline detection task within the Bird's Eye View (BEV) framework. "
99
+ "Be mindful not to exceed a detection range of 50 meters."
100
+ ),
101
+ # Additional paraphrase variant 5.
102
+ (
103
+ "There are six images captured by the surround view cameras in driving vehicle. "
104
+ "They are uniformly represented as queries embeddings<query>. "
105
+ "Please carry out the task of detecting the centerline from the Bird's Eye View (BEV) perspective. "
106
+ "Ensure that the detection range remains within 50 meters."
107
+ ),
108
+ ]
109
+
110
+ # Paper Table 9 prompts 1-3, plus two repository paraphrases.
111
+ PLANNING_PROMPTS = [
112
+ # Paper Table 9, prompt 1. The paper uses fixed maneuver words; {command} keeps the same slot dynamic.
113
+ (
114
+ "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
115
+ "and map query embeddings<query>. "
116
+ "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
117
+ "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
118
+ "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
119
+ "The ego car will {command} in future. "
120
+ "Kindly furnish suitable waypoints for the vehicle's trajectory based on the provided particulars. "
121
+ "Waypoints ought to adhere to the [x, y] format, with each waypoint spaced at 0.5-second intervals "
122
+ "within a continuous 3.0-second timeframe. "
123
+ "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
124
+ ),
125
+ # Paper Table 9, prompt 2. The paper uses fixed maneuver words; {command} keeps the same slot dynamic.
126
+ (
127
+ "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
128
+ "and map query embeddings<query>. "
129
+ "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
130
+ "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
131
+ "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
132
+ "The ego car will {command} in future. "
133
+ "We request your provision of pertinent waypoints for the vehicle's route in accordance with the given information. "
134
+ "Waypoints should conform to the format [x, y], with spacing set at 0.5-second intervals "
135
+ "over a continuous duration of 3.0 seconds. "
136
+ "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
137
+ ),
138
+ # Paper Table 9, prompt 3. The paper uses fixed maneuver words; {command} keeps the same slot dynamic.
139
+ (
140
+ "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
141
+ "and map query embeddings<query>. "
142
+ "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
143
+ "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
144
+ "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
145
+ "The ego car will {command} in future. "
146
+ "Please submit fitting waypoints for the vehicle's course based on the supplied data. "
147
+ "Ensure waypoints are structured as [x, y] and spaced at intervals of 0.5 seconds across a continuous 3.0-second period. "
148
+ "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
149
+ ),
150
+ # Additional paraphrase variant 4.
151
+ (
152
+ "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
153
+ "and map query embeddings<query>. "
154
+ "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
155
+ "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
156
+ "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
157
+ "The ego car will {command} in future. "
158
+ "Please provide suitable waypoints for the ego car in [x, y] format at 0.5-second intervals "
159
+ "over a continuous 3.0-second period. "
160
+ "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
161
+ ),
162
+ # Additional paraphrase variant 5.
163
+ (
164
+ "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
165
+ "and map query embeddings<query>. "
166
+ "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
167
+ "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
168
+ "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
169
+ "The ego car will {command} in future. "
170
+ "Could you generate appropriate waypoints for the vehicle's trajectory in [x, y] format, "
171
+ "with each waypoint separated by 0.5 seconds across the next 3.0 seconds? "
172
+ "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
173
+ ),
174
+ ]
175
+
176
+ # Figure 5-style single-view caption prompt, parameterized by camera_name.
177
+ CAPTION_PROMPTS = [
178
+ (
179
+ "There are six images captured by the surround view cameras in driving vehicle. "
180
+ "They are uniformly represented as queries embeddings<query>. "
181
+ "Communicate a narrative of the setting within {camera_name} view image."
182
+ ),
183
+ ]
184
+
185
+
186
+ _TASK_POOLS = {
187
+ "detection": DETECTION_PROMPTS,
188
+ "lane": LANE_PROMPTS,
189
+ "planning": PLANNING_PROMPTS,
190
+ "caption": CAPTION_PROMPTS,
191
+ }
192
+
193
+
194
+ def get_prompt_pool(task: str):
195
+ return _TASK_POOLS.get(task, DETECTION_PROMPTS)
196
+
197
+
198
+ def sample_prompt(task: str, **kwargs) -> str:
199
+ pool = get_prompt_pool(task)
200
+ template = random.choice(pool)
201
+ if kwargs:
202
+ try:
203
+ return template.format(**kwargs)
204
+ except KeyError:
205
+ return template
206
+ return template
207
+
208
+
209
+ def rewrite_planning_prompt_for_table3(
210
+ prompt_text: str,
211
+ mode: str,
212
+ command: Optional[str] = None,
213
+ velocity_bins: Optional[Tuple[int, int]] = None,
214
+ acceleration_bins: Optional[Tuple[int, int]] = None,
215
+ ) -> str:
216
+ """Rewrite planning prompts to match Atlas Table 3 variants.
217
+
218
+ Modes:
219
+ - atlas_base: no high-level command, no explicit ego-state values
220
+ - atlas_high_level: keep high-level command only
221
+ - atlas_high_level_ego: keep high-level command and inject ego-state bins
222
+
223
+ In this repository, the top-level route command is a UniAD-style
224
+ future-GT-derived coarse planning command, not a raw nuScenes field.
225
+ """
226
+ if mode not in PLANNING_TABLE3_MODES:
227
+ raise ValueError(f"Unsupported planning mode: {mode}")
228
+
229
+ prompt = " ".join(str(prompt_text).split())
230
+ prompt = _PLANNING_STATE_RE.sub("", prompt)
231
+ prompt = _PLANNING_COMMAND_RE.sub("", prompt)
232
+
233
+ if mode == "atlas_base":
234
+ return " ".join(prompt.split())
235
+
236
+ if not command:
237
+ raise ValueError(
238
+ f"{mode} requires an explicit top-level route command field"
239
+ )
240
+
241
+ command_sentence = f"The ego car will {command} in future."
242
+
243
+ if mode == "atlas_high_level":
244
+ if _PLANNING_ACCEL_SENTENCE in prompt:
245
+ prompt = prompt.replace(
246
+ _PLANNING_ACCEL_SENTENCE,
247
+ f"{_PLANNING_ACCEL_SENTENCE} {command_sentence}",
248
+ 1,
249
+ )
250
+ return " ".join(prompt.split())
251
+ return " ".join(f"{command_sentence} {prompt}".split())
252
+
253
+ if velocity_bins is None or acceleration_bins is None:
254
+ raise ValueError(
255
+ "atlas_high_level_ego requires velocity_bins and acceleration_bins"
256
+ )
257
+
258
+ state_sentence = (
259
+ f"The current speed value of the ego car is [{int(velocity_bins[0])}, {int(velocity_bins[1])}]. "
260
+ f"The current acceleration value of the ego car is [{int(acceleration_bins[0])}, {int(acceleration_bins[1])}]."
261
+ )
262
+ if _PLANNING_ACCEL_SENTENCE in prompt:
263
+ prompt = prompt.replace(
264
+ _PLANNING_ACCEL_SENTENCE,
265
+ f"{_PLANNING_ACCEL_SENTENCE} {state_sentence} {command_sentence}",
266
+ 1,
267
+ )
268
+ return " ".join(prompt.split())
269
+
270
+ return " ".join(f"{state_sentence} {command_sentence} {prompt}".split())
271
+
272
+
273
+ def build_prompt(prompt_text: str, mode: Literal["train", "infer"]) -> str:
274
+ s = f"USER: {prompt_text}\nASSISTANT:"
275
+ if audit_enabled():
276
+ h1b_record_prompt(mode, s)
277
+ return s
train_atlas.py ADDED
@@ -0,0 +1,1018 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import math
6
+ import time
7
+ import json
8
+ import logging
9
+ from datetime import timedelta
10
+ from pathlib import Path
11
+ from typing import Dict, Optional, List
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from torch.utils.data import DataLoader
16
+
17
+ sys.path.insert(0, str(Path(__file__).resolve().parent))
18
+
19
+ from src.model.modeling_atlas import AtlasForCausalLM
20
+ from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
21
+ from src.model.streampetr_adapter import extract_streampetr_topk_tokens
22
+ from src.dataset.atlas_dataset import (
23
+ AtlasDataset, make_atlas_collate_fn, load_tokenizer,
24
+ )
25
+ from src.dataset.scene_sampler import SceneSequentialSampler
26
+ from src.prompting import PLANNING_TABLE3_MODES
27
+
28
+ logger = logging.getLogger("train_atlas")
29
+
30
+
31
def parse_args():
    """Build and parse the CLI arguments for Atlas training.

    Argument groups: model/encoder selection, data paths, optimization
    hyperparameters, LoRA / quantization switches, checkpointing and logging
    cadence, distributed-launch plumbing, and visual-token sourcing (live
    frozen encoders vs precomputed offline token directories).
    """
    p = argparse.ArgumentParser()
    # --- Model / frozen encoders ---
    p.add_argument("--llm_model", default="lmsys/vicuna-7b-v1.5")
    p.add_argument("--visual_hidden_size", type=int, default=256)
    p.add_argument("--num_det_queries", type=int, default=256)
    p.add_argument("--num_map_queries", type=int, default=256)
    p.add_argument("--streampetr_config", default=None)
    p.add_argument("--streampetr_ckpt", default=None)
    p.add_argument("--topomlp_config", default=None)
    p.add_argument("--topomlp_ckpt", default=None)
    # --- Data ---
    p.add_argument("--data_json", required=True)
    p.add_argument("--data_root", default="/mnt/data/nuscenes")
    p.add_argument("--max_length", type=int, default=4096)
    p.add_argument("--output_dir", default="work_dirs/atlas")
    # --- Optimization ---
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--weight_decay", type=float, default=1e-4)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--epochs", type=int, default=8)
    p.add_argument("--warmup_ratio", type=float, default=0.03)
    p.add_argument("--gradient_accumulation_steps", type=int, default=2)
    p.add_argument("--max_grad_norm", type=float, default=1.0)
    # --- LoRA / quantization ---
    p.add_argument("--use_lora", action="store_true")
    p.add_argument("--lora_r", type=int, default=64)
    p.add_argument("--lora_alpha", type=int, default=64)
    p.add_argument("--lora_dropout", type=float, default=0.1)
    p.add_argument("--load_in_4bit", action="store_true")
    # --- Checkpointing / logging ---
    p.add_argument("--save_steps", type=int, default=0)
    p.add_argument("--save_epochs", type=int, default=1)
    p.add_argument("--log_steps", type=int, default=10)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--resume", default=None)
    # torchrun supplies LOCAL_RANK via env; accept both flag spellings.
    p.add_argument("--local_rank", "--local-rank", type=int, default=int(os.environ.get("LOCAL_RANK", -1)))
    p.add_argument("--fp16", action="store_true")
    p.add_argument("--bf16", action="store_true")
    p.add_argument("--image_path_remap", default=None,
                   help="old=new path remap, e.g. /mnt/data=/local/data")
    # --- Visual token sourcing (validated later by _validate_visual_token_mode) ---
    p.add_argument("--precomputed_det_tokens", default=None,
                   help="[offline only] Dir with precomputed det tokens (.pt files)")
    p.add_argument("--precomputed_map_tokens", default=None,
                   help="[offline only] Dir with precomputed TopoMLP map tokens (.pt files)")
    p.add_argument("--visual_token_mode", choices=("online", "offline"), default="online",
                   help="Visual token source: online=live frozen encoders (default), offline=read *_offline dirs")
    p.add_argument("--deepspeed", default=None,
                   help="Path to DeepSpeed config JSON (enables ZeRO)")
    p.add_argument("--keep_last_n_ckpts", type=int, default=0,
                   help="Keep only the N most recent epoch checkpoints (0=keep all)")
    p.add_argument(
        "--planning_table3_mode",
        choices=PLANNING_TABLE3_MODES,
        default="atlas_base",
        help=(
            "Planning prompt variant matching Atlas Table 3: "
            "atlas_base=no command/no explicit ego state; "
            "atlas_high_level=requires top-level route_command "
            "(this repo uses a UniAD-style future-GT-derived command); "
            "atlas_high_level_ego=requires top-level route_command plus "
            "velocity/acceleration bins."
        ),
    )
    return p.parse_args()
92
+
93
+
94
+ def _validate_visual_token_mode(args):
95
+ """Enforce mode-specific constraints. Fail hard, never silently degrade."""
96
+ if args.visual_token_mode == "online":
97
+ if args.precomputed_det_tokens or args.precomputed_map_tokens:
98
+ raise RuntimeError(
99
+ "visual_token_mode=online forbids --precomputed_det_tokens / "
100
+ "--precomputed_map_tokens. Use --visual_token_mode offline to "
101
+ "read offline token directories."
102
+ )
103
+ missing = []
104
+ if not args.streampetr_config or not args.streampetr_ckpt:
105
+ missing.append("--streampetr_config/--streampetr_ckpt")
106
+ if not args.topomlp_config or not args.topomlp_ckpt:
107
+ missing.append("--topomlp_config/--topomlp_ckpt")
108
+ if missing:
109
+ raise RuntimeError(
110
+ "visual_token_mode=online requires live encoder configs and "
111
+ "checkpoints. Missing: " + ", ".join(missing)
112
+ )
113
+ for p in (args.streampetr_config, args.streampetr_ckpt, args.topomlp_config, args.topomlp_ckpt):
114
+ if not os.path.exists(p):
115
+ raise RuntimeError(f"Required online asset does not exist: {p}")
116
+ if args.batch_size != 1:
117
+ raise RuntimeError(
118
+ "visual_token_mode=online with temporal memory requires "
119
+ "--batch_size 1 (paper-aligned). Got: %d" % args.batch_size
120
+ )
121
+ else:
122
+ if not args.precomputed_det_tokens and not args.precomputed_map_tokens:
123
+ raise RuntimeError(
124
+ "visual_token_mode=offline requires at least one "
125
+ "--precomputed_*_tokens directory."
126
+ )
127
+ for p in (args.precomputed_det_tokens, args.precomputed_map_tokens):
128
+ if p and not os.path.isdir(p):
129
+ raise RuntimeError(f"Offline token directory does not exist: {p}")
130
+
131
+
132
def set_seed(seed):
    """Seed python, numpy, and torch (incl. all CUDA devices) RNGs."""
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only touch CUDA generators when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
140
+
141
+
142
def setup_distributed(local_rank):
    """Initialize the NCCL process group for distributed training.

    Returns ``(device, is_distributed, rank, world_size)``. A *local_rank*
    of -1 means single-process mode and skips process-group setup.
    """
    if local_rank == -1:
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return dev, False, 0, 1
    dist.init_process_group(backend="nccl", timeout=timedelta(seconds=1800))
    torch.cuda.set_device(local_rank)
    return (
        torch.device("cuda", local_rank),
        True,
        dist.get_rank(),
        dist.get_world_size(),
    )
152
+
153
+
154
def is_main_process(distributed, rank):
    """True on the rank-0 process, or always in non-distributed runs."""
    if not distributed:
        return True
    return rank == 0
156
+
157
+
158
def load_frozen_encoder(config_path, ckpt_path, model_type, device):
    """Build a frozen (eval-mode, requires_grad=False) mm-series encoder.

    Args:
        config_path: mm config file path, or None to skip loading entirely.
        ckpt_path: checkpoint path, or None to skip loading entirely.
        model_type: "streampetr" or "topomlp"; selects which external plugin
            package is put on sys.path and imported before model build.
        device: torch device the frozen model is moved to.

    Returns:
        The frozen model, or None when config/checkpoint were not provided.

    Raises:
        RuntimeError: when mmcv/mmdet3d or the required plugin is missing.
    """
    if config_path is None or ckpt_path is None:
        return None
    try:
        from mmcv import Config
        from mmdet3d.models import build_model
        from mmcv.runner import load_checkpoint
    except ImportError as e:
        raise RuntimeError(
            f"mmcv/mmdet3d not installed but --{model_type}_config and "
            f"--{model_type}_ckpt were explicitly provided. "
            f"Install mmcv/mmdet3d or remove these arguments to train without {model_type}."
        ) from e

    if model_type == "streampetr":
        sp_root = str(Path(__file__).resolve().parent / "external" / "StreamPETR")
        if sp_root not in sys.path:
            sys.path.insert(0, sp_root)
        try:
            import projects.mmdet3d_plugin  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                f"StreamPETR plugin not found under {sp_root}/projects/mmdet3d_plugin. "
                f"Ensure the submodule is checked out, or remove --streampetr_config/--streampetr_ckpt."
            ) from e
    elif model_type == "topomlp":
        tp_root = str(Path(__file__).resolve().parent / "external" / "TopoMLP_Repo")
        if tp_root not in sys.path:
            sys.path.insert(0, tp_root)
        try:
            os.environ["ATLAS_TOPOMLP_MODELS_ONLY"] = "1"
            from mmcv.utils import registry as _reg
            _orig = _reg.Registry._register_module

            # TopoMLP re-registers modules already present in the mmcv
            # registries; temporarily force-register to tolerate duplicates.
            def _tolerant_register(self, module, module_name=None, force=False):
                return _orig(self, module, module_name=module_name, force=True)

            _reg.Registry._register_module = _tolerant_register
            try:
                import projects.topomlp  # noqa: F401
            finally:
                # BUGFIX: restore the registry even if the plugin import
                # fails; previously a failed import left the force-register
                # monkeypatch installed for the rest of the process.
                _reg.Registry._register_module = _orig
        except ImportError as e:
            raise RuntimeError(
                f"TopoMLP plugin not found under {tp_root}/projects/topomlp. "
                f"Ensure the submodule is checked out, or remove --topomlp_config/--topomlp_ckpt."
            ) from e

    cfg = Config.fromfile(config_path)
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, ckpt_path, map_location="cpu")
    model.eval()
    model.to(device)
    for param in model.parameters():
        param.requires_grad_(False)
    logger.info("Loaded frozen %s from %s", model_type, ckpt_path)
    return model
211
+
212
+
213
def build_img_metas_streampetr(batch, device, idx):
    """Build the img_metas entry for sample `idx` in StreamPETR's format."""
    num_cams = batch["pixel_values_det"].shape[1]
    cam_h, cam_w = 800, 1600  # fixed detection input resolution
    scene_ids = batch.get("scene_id", ["__atlas__"] * (idx + 1))
    scene_token = scene_ids[idx] if idx < len(scene_ids) else "__atlas__"
    meta = {
        "pad_shape": [(cam_h, cam_w, 3) for _ in range(num_cams)],
        "img_shape": [(cam_h, cam_w, 3) for _ in range(num_cams)],
        "scene_token": scene_token,
    }
    if "lidar2img_det" in batch:
        meta["lidar2img"] = batch["lidar2img_det"][idx].cpu().numpy()
    return meta
225
+
226
+
227
def build_img_metas_topomlp(batch, device, idx):
    """Build the img_metas entry for sample `idx` in TopoMLP's format."""
    meta = {}
    if "lidar2img_map" in batch:
        meta["lidar2img"] = batch["lidar2img_map"][idx].cpu().numpy()
    num_cams = batch["pixel_values_map"].shape[1]
    per_cam_shape = ((800, 1600, 3),) * num_cams  # fixed map input resolution
    meta["img_shape"] = per_cam_shape
    meta["pad_shape"] = per_cam_shape
    meta["scale_factor"] = 1.0
    meta["te_yolov8"] = None  # no external YOLOv8 traffic-element proposals
    return meta
238
+
239
+
240
@torch.no_grad()
def run_streampetr_forward(model, imgs, img_metas, batch, device, prev_exists=None):
    """Run one frozen StreamPETR forward pass; return pts_bbox_head outputs.

    Missing calibration / pose / timestamp entries in `batch` fall back to
    identity matrices or zeros so the pass never fails on partial metadata.
    """
    batch_size, num_cams = imgs.shape[:2]

    def _identity_per_cam():
        eye = torch.eye(4, device=device)
        return eye.unsqueeze(0).unsqueeze(0).expand(batch_size, num_cams, -1, -1).contiguous()

    if prev_exists is None:
        prev_exists = imgs.new_zeros(batch_size)

    data = {
        "img": imgs,
        "img_feats": model.extract_img_feat(imgs, 1),
        "prev_exists": prev_exists,
    }

    # Camera intrinsics: embed the provided 3x3 K into a homogeneous 4x4.
    if "intrinsics_det" in batch:
        k3 = batch["intrinsics_det"].to(device)
        k4 = torch.zeros(batch_size, num_cams, 4, 4, device=device, dtype=k3.dtype)
        k4[:, :, :3, :3] = k3
        k4[:, :, 3, 3] = 1.0
        data["intrinsics"] = k4
    else:
        data["intrinsics"] = _identity_per_cam()

    if "lidar2img_det" in batch:
        data["lidar2img"] = batch["lidar2img_det"].to(device)
    else:
        data["lidar2img"] = _identity_per_cam()

    ego_pose = batch.get("ego_pose")
    if ego_pose is not None:
        data["ego_pose"] = ego_pose.to(device)
    else:
        data["ego_pose"] = torch.eye(4, device=device).unsqueeze(0).expand(batch_size, -1, -1).contiguous()

    ego_pose_inv = batch.get("ego_pose_inv")
    if ego_pose_inv is not None:
        data["ego_pose_inv"] = ego_pose_inv.to(device)
    else:
        data["ego_pose_inv"] = torch.inverse(data["ego_pose"])

    timestamp = batch.get("timestamp")
    if timestamp is not None:
        data["timestamp"] = timestamp.to(device)
    else:
        data["timestamp"] = torch.zeros(batch_size, device=device)

    location = model.prepare_location(img_metas, **data)
    roi_outs = model.forward_roi_head(location, **data)
    return model.pts_bbox_head(location, img_metas, roi_outs["topk_indexes"], **data)
287
+
288
+
289
@torch.no_grad()
def run_topomlp_forward(model, imgs, img_metas):
    """Thin no-grad wrapper around TopoMLP's simple_forward."""
    outs = model.simple_forward(imgs, img_metas)
    return outs
292
+
293
+
294
+ def _reconstruct_topomlp_outs(saved: dict, device, dtype):
295
+ """Convert precomputed .pt dict back to the format adapter.forward() expects."""
296
+ def _restore(t):
297
+ return t.to(device=device, dtype=dtype).unsqueeze(0)
298
+ return {
299
+ "lc_outs_dec_list": [_restore(saved["lc_outs_dec"])],
300
+ "all_lc_cls_scores_list": [_restore(saved["lc_cls_scores"])],
301
+ "all_lc_preds_list": [_restore(saved["lc_preds"])],
302
+ "lc_outs_dec_one2many_list": [_restore(saved["lc_outs_dec_o2m"])],
303
+ "all_lc_cls_scores_one2many_list": [_restore(saved["lc_cls_scores_o2m"])],
304
+ "all_lc_preds_one2many_list": [_restore(saved["lc_preds_o2m"])],
305
+ }
306
+
307
+
308
def extract_visual_tokens(
    streampetr_model,
    topomlp_model,
    topomlp_adapter,
    batch,
    device,
    num_det_queries=256,
    visual_hidden_size=256,
    query_token_id=None,
    visual_token_mode="online",
    streaming_state=None,
):
    """Extract det + map visual tokens.

    In online mode with streaming_state, StreamPETR temporal memory is managed
    per-scene and duplicate physical frames are protected: if the current
    sample_id equals the previous one, we reuse cached det tokens and skip the
    StreamPETR forward to avoid pushing the same frame into memory twice.

    Args:
        streampetr_model: frozen StreamPETR model, or None.
        topomlp_model: frozen TopoMLP model, or None.
        topomlp_adapter: adapter turning TopoMLP outputs into map tokens;
            also defines `num_map_tokens` and the dtype to cast map inputs to.
        batch: collated batch dict; must contain "pixel_values_det".
        device: target torch device for all produced tensors.
        num_det_queries: number of detection tokens kept (top-k).
        visual_hidden_size: channel width used when zero-filling tokens.
        query_token_id: tokenizer id of "<query>"; map tokens are only needed
            when the prompt has more <query> slots than det tokens.
        visual_token_mode: "online" (run frozen encoders) or "offline"
            (use precomputed tokens from the batch; never zero-fill).
        streaming_state: mutable dict tracking previous scene/sample/timestamp
            for StreamPETR temporal memory; only used in online mode.

    Returns:
        Dict with "detection", "detection_ref_points", "map" and
        "map_ref_points" tensors (batch-first).

    Raises:
        RuntimeError: when required tokens/models are missing for the mode.
    """
    B = batch["pixel_values_det"].shape[0]
    vis: Dict[str, torch.Tensor] = {}

    # Map tokens are needed only when the prompt contains more <query>
    # placeholders than the detection branch can fill.
    needs_map = False
    if query_token_id is not None and "input_ids" in batch:
        n_queries = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item())
        needs_map = n_queries > num_det_queries

    # ---- Detection tokens ----
    if visual_token_mode == "offline" and "precomputed_det" in batch and "precomputed_det_ref" in batch:
        vis["detection"] = batch["precomputed_det"].to(device)
        vis["detection_ref_points"] = batch["precomputed_det_ref"].to(device)
    elif visual_token_mode == "offline":
        raise RuntimeError(
            "visual_token_mode=offline but detection precomputed tokens are missing "
            "for the current batch. Refusing to zero-fill."
        )
    elif streampetr_model is not None:
        # Temporal memory is per-sample; batching would interleave scenes.
        if B != 1 and streaming_state is not None:
            raise RuntimeError("online temporal det requires batch_size=1")

        current_sample_id = batch.get("sample_id", [None])[0]
        current_scene = batch.get("scene_id", ["__atlas__"])[0]
        reuse_cache = False

        if streaming_state is not None:
            prev_scene = streaming_state.get("prev_scene_token")
            prev_sample_id = streaming_state.get("prev_sample_id")
            ts_tensor = batch.get("timestamp")
            current_ts = float(ts_tensor[0].item()) if ts_tensor is not None else None
            prev_ts = streaming_state.get("prev_timestamp")

            # A scene change or a non-increasing timestamp marks a segment
            # boundary; StreamPETR temporal memory must be reset there.
            is_new_segment = (
                prev_scene is None
                or current_scene != prev_scene
                or (current_ts is not None and prev_ts is not None and current_ts <= prev_ts)
            )

            # Same physical frame as last step: reuse cached tokens so the
            # frame is not pushed into temporal memory twice.
            if current_sample_id is not None and current_sample_id == prev_sample_id:
                cached = streaming_state.get("cached_det")
                if cached is not None:
                    reuse_cache = True
                    vis["detection"] = cached["detection"]
                    vis["detection_ref_points"] = cached["detection_ref_points"]

            if not reuse_cache:
                if is_new_segment:
                    streampetr_model.pts_bbox_head.reset_memory()
                # prev_exists=0 tells StreamPETR there is no usable memory.
                prev_exists_val = 0.0 if is_new_segment else 1.0
                imgs_det = batch["pixel_values_det"].to(device)
                prev_exists = imgs_det.new_full((B,), prev_exists_val)

                img_metas = [build_img_metas_streampetr(batch, device, b) for b in range(B)]
                run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device, prev_exists=prev_exists)
                ego_pose_for_ref = batch.get("ego_pose")
                if ego_pose_for_ref is not None:
                    ego_pose_for_ref = ego_pose_for_ref.to(device)
                det_out = extract_streampetr_topk_tokens(
                    streampetr_model.pts_bbox_head,
                    topk=num_det_queries,
                    ego_pose=ego_pose_for_ref,
                )
                vis["detection"] = det_out["detection"]
                vis["detection_ref_points"] = det_out["detection_ref_points"]

                # Cache for possible duplicate-frame reuse on the next step.
                streaming_state["cached_det"] = {
                    "detection": vis["detection"],
                    "detection_ref_points": vis["detection_ref_points"],
                }

            streaming_state["prev_scene_token"] = current_scene
            streaming_state["prev_sample_id"] = current_sample_id
            if batch.get("timestamp") is not None:
                streaming_state["prev_timestamp"] = float(batch["timestamp"][0].item())
        else:
            # Stateless online path: single forward with no temporal memory.
            imgs_det = batch["pixel_values_det"].to(device)
            img_metas = [build_img_metas_streampetr(batch, device, b) for b in range(B)]
            run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device)
            ego_pose_for_ref = batch.get("ego_pose")
            if ego_pose_for_ref is not None:
                ego_pose_for_ref = ego_pose_for_ref.to(device)
            det_out = extract_streampetr_topk_tokens(
                streampetr_model.pts_bbox_head,
                topk=num_det_queries,
                ego_pose=ego_pose_for_ref,
            )
            vis["detection"] = det_out["detection"]
            vis["detection_ref_points"] = det_out["detection_ref_points"]
    elif visual_token_mode == "online":
        raise RuntimeError(
            "visual_token_mode=online but StreamPETR model is None. "
            "Provide --streampetr_config and --streampetr_ckpt."
        )
    else:
        # No encoder and not strictly offline: zero-fill detection tokens.
        vis["detection"] = torch.zeros(B, num_det_queries, visual_hidden_size, device=device)
        vis["detection_ref_points"] = torch.zeros(B, num_det_queries, 3, device=device)

    # ---- Map tokens ----
    num_map_queries = num_det_queries
    if topomlp_adapter is not None:
        num_map_queries = topomlp_adapter.num_map_tokens

    if topomlp_adapter is not None:
        # Infer the adapter's working dtype from its params/buffers so
        # precomputed or TopoMLP outputs can be cast to match.
        _params = list(topomlp_adapter.parameters())
        _bufs = list(topomlp_adapter.buffers())
        adapter_dtype = _params[0].dtype if _params else (_bufs[0].dtype if _bufs else torch.float32)

    map_filled = False
    if visual_token_mode == "offline" and needs_map and "precomputed_map" in batch:
        if B == 1:
            outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)
        else:
            # Rebuild each sample then concatenate along the batch dim,
            # list-entry by list-entry.
            per_sample = [_reconstruct_topomlp_outs(batch["precomputed_map"][b], device, adapter_dtype) for b in range(B)]
            outs = {}
            for k in per_sample[0]:
                outs[k] = [torch.cat([s[k][i] for s in per_sample], dim=0) for i in range(len(per_sample[0][k]))]
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True
    elif visual_token_mode == "offline" and needs_map:
        raise RuntimeError(
            "visual_token_mode=offline but map precomputed tokens are missing "
            "for a batch that requires map queries. Refusing to zero-fill."
        )
    elif needs_map and topomlp_model is not None:
        imgs_map = batch["pixel_values_map"].to(device)
        img_metas = [build_img_metas_topomlp(batch, device, b) for b in range(B)]
        outs = run_topomlp_forward(topomlp_model, imgs_map, img_metas)
        # Cast TopoMLP outputs (tensors or lists of tensors) to adapter dtype.
        for k, v in outs.items():
            if isinstance(v, torch.Tensor):
                outs[k] = v.to(adapter_dtype)
            elif isinstance(v, list):
                outs[k] = [x.to(adapter_dtype) if isinstance(x, torch.Tensor) else x for x in v]
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True
    elif needs_map and visual_token_mode == "online":
        raise RuntimeError(
            "visual_token_mode=online but TopoMLP model is None. "
            "Provide --topomlp_config and --topomlp_ckpt."
        )

    # Prompts that never reference map queries still get placeholder tokens.
    if not map_filled:
        vis["map"] = torch.zeros(B, num_map_queries, visual_hidden_size, device=device)
        vis["map_ref_points"] = torch.zeros(B, num_map_queries, 3, device=device)

    return vis
476
+
477
+
478
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.0):
    """Linear warmup to the base LR, then cosine decay down to min_lr_ratio."""
    def _scale(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        frac = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * frac))
        return max(min_lr_ratio, cosine)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _scale)
485
+
486
+
487
+ def _optimizer_steps_per_epoch(num_batches: int, grad_accum_steps: int) -> int:
488
+ if num_batches <= 0:
489
+ return 0
490
+ return int(math.ceil(float(num_batches) / float(max(1, grad_accum_steps))))
491
+
492
+
493
+ def _accum_window_size_for_batch(
494
+ batch_idx: int,
495
+ num_batches: int,
496
+ grad_accum_steps: int,
497
+ ) -> int:
498
+ """Return the effective accumulation window size for this batch.
499
+
500
+ Full windows use `grad_accum_steps`. The final partial window uses the
501
+ remainder so that tail batches are not under-scaled when we flush them.
502
+ """
503
+ grad_accum_steps = max(1, int(grad_accum_steps))
504
+ remainder = int(num_batches % grad_accum_steps)
505
+ tail_start = int(num_batches - remainder)
506
+ if remainder > 0 and batch_idx >= tail_start:
507
+ return remainder
508
+ return grad_accum_steps
509
+
510
+
511
+ def _is_optimizer_step_batch(
512
+ batch_idx: int,
513
+ num_batches: int,
514
+ grad_accum_steps: int,
515
+ ) -> bool:
516
+ grad_accum_steps = max(1, int(grad_accum_steps))
517
+ natural_boundary = ((batch_idx + 1) % grad_accum_steps) == 0
518
+ is_last_batch = (batch_idx + 1) == num_batches
519
+ return natural_boundary or is_last_batch
520
+
521
+
522
def save_checkpoint(path, atlas, adapter, optimizer, scheduler, global_step, epoch, args):
    """Serialize training state (model, adapter, optimizer, scheduler) to `path`."""
    payload = {
        "global_step": global_step,
        "epoch": epoch,
        "args": vars(args),
        # State dicts are moved to CPU so checkpoints load on any device.
        "atlas_state_dict": {name: tensor.cpu() for name, tensor in atlas.state_dict().items()},
        "optimizer": optimizer.state_dict(),
        "scheduler": None if scheduler is None else scheduler.state_dict(),
    }
    if adapter is not None:
        payload["adapter_state_dict"] = {name: tensor.cpu() for name, tensor in adapter.state_dict().items()}
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, path)
535
+
536
+
537
def cleanup_old_checkpoints(output_dir: Path, keep_n: int):
    """Remove the oldest epoch-* checkpoint dirs so at most *keep_n* remain."""
    if keep_n <= 0:
        return
    import shutil

    def _epoch_num(d):
        return int(d.name.split("-")[1])

    candidates = [d for d in output_dir.iterdir() if d.is_dir() and d.name.startswith("epoch-")]
    candidates.sort(key=_epoch_num)
    n_delete = max(0, len(candidates) - keep_n)
    for stale in candidates[:n_delete]:
        shutil.rmtree(stale, ignore_errors=True)
        logger.info("Deleted old checkpoint: %s", stale)
550
+
551
+
552
+ def main():
553
+ args = parse_args()
554
+ class _FlushHandler(logging.StreamHandler):
555
+ def emit(self, record):
556
+ super().emit(record)
557
+ self.flush()
558
+ logging.root.handlers.clear()
559
+ _h = _FlushHandler(sys.stderr)
560
+ _fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
561
+ _h.setFormatter(_fmt)
562
+ logging.root.addHandler(_h)
563
+ logging.root.setLevel(logging.INFO)
564
+
565
+ _validate_visual_token_mode(args)
566
+
567
+ device, distributed, rank, world_size = setup_distributed(args.local_rank)
568
+ set_seed(args.seed + rank)
569
+ _main = is_main_process(distributed, rank)
570
+
571
+ is_online = args.visual_token_mode == "online"
572
+
573
+ output_dir = Path(args.output_dir)
574
+ if _main:
575
+ output_dir.mkdir(parents=True, exist_ok=True)
576
+ _fh = logging.FileHandler(str(output_dir / "train.log"), mode="a")
577
+ _fh.setFormatter(_fmt)
578
+ logging.root.addHandler(_fh)
579
+ with open(output_dir / "args.json", "w") as f:
580
+ json.dump(vars(args), f, indent=2)
581
+
582
+ if _main:
583
+ logger.info("Loading tokenizer: %s", args.llm_model)
584
+ tokenizer = load_tokenizer(args.llm_model)
585
+ if "<query>" not in tokenizer.get_vocab():
586
+ tokenizer.add_tokens(["<query>"])
587
+
588
+ _precomp_det = args.precomputed_det_tokens if not is_online else None
589
+ _precomp_map = args.precomputed_map_tokens if not is_online else None
590
+ dataset = AtlasDataset(
591
+ json_file=args.data_json,
592
+ image_root=args.data_root,
593
+ tokenizer=tokenizer,
594
+ max_length=args.max_length,
595
+ is_training=True,
596
+ planning_table3_mode=args.planning_table3_mode,
597
+ image_path_remap=args.image_path_remap,
598
+ precomputed_det_tokens=_precomp_det,
599
+ precomputed_map_tokens=_precomp_map,
600
+ )
601
+
602
+ if is_online:
603
+ scene_groups = dataset.get_scene_groups()
604
+ sampler = SceneSequentialSampler(
605
+ scene_groups,
606
+ num_replicas=world_size,
607
+ rank=rank,
608
+ seed=args.seed,
609
+ pad_to_multiple=args.gradient_accumulation_steps,
610
+ )
611
+ if _main:
612
+ logger.info("Online mode: SceneSequentialSampler (%d scenes, %d samples, world=%d)",
613
+ len(scene_groups), len(dataset), world_size)
614
+ else:
615
+ from torch.utils.data import DistributedSampler
616
+ sampler = DistributedSampler(dataset, shuffle=True) if distributed else None
617
+
618
+ collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
619
+ dataloader = DataLoader(
620
+ dataset,
621
+ batch_size=args.batch_size,
622
+ shuffle=(not is_online and sampler is None),
623
+ sampler=sampler,
624
+ num_workers=args.num_workers,
625
+ collate_fn=collate_fn,
626
+ pin_memory=True,
627
+ drop_last=not is_online,
628
+ )
629
+
630
+ streampetr_model = load_frozen_encoder(
631
+ args.streampetr_config, args.streampetr_ckpt, "streampetr", device
632
+ )
633
+ topomlp_model = load_frozen_encoder(
634
+ args.topomlp_config, args.topomlp_ckpt, "topomlp", device
635
+ )
636
+
637
+ topomlp_adapter = None
638
+ if topomlp_model is not None or _precomp_map:
639
+ _tp_bev_range = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0)
640
+ if args.topomlp_config:
641
+ try:
642
+ from mmcv import Config as _Cfg
643
+ _tp_cfg = _Cfg.fromfile(args.topomlp_config)
644
+ if hasattr(_tp_cfg, "point_cloud_range"):
645
+ _tp_bev_range = tuple(float(v) for v in _tp_cfg.point_cloud_range)
646
+ logger.info("TopoMLP bev_range from config: %s", _tp_bev_range)
647
+ except Exception as e:
648
+ logger.warning("Failed to read point_cloud_range from TopoMLP config: %s. Using default: %s", e, _tp_bev_range)
649
+ topomlp_adapter = TopoMLPToAtlasMapTokens(
650
+ num_map_tokens=args.num_map_queries,
651
+ hidden_size=args.visual_hidden_size,
652
+ bev_range=_tp_bev_range,
653
+ ).to(device)
654
+
655
+ dtype = torch.float32
656
+ if args.bf16:
657
+ dtype = torch.bfloat16
658
+ elif args.fp16:
659
+ dtype = torch.float16
660
+
661
+ if args.load_in_4bit:
662
+ dm = {"": device} if distributed else "auto"
663
+ else:
664
+ dm = None
665
+
666
+ _ds_bf16 = False
667
+ _ds_fp16 = False
668
+ if args.deepspeed:
669
+ with open(args.deepspeed) as _f:
670
+ _ds_cfg_peek = json.load(_f)
671
+ _ds_bf16 = _ds_cfg_peek.get("bf16", {}).get("enabled", False)
672
+ _ds_fp16 = _ds_cfg_peek.get("fp16", {}).get("enabled", False)
673
+ _use_half = args.bf16 or args.fp16 or _ds_bf16 or _ds_fp16
674
+ if _use_half and dtype == torch.float32:
675
+ dtype = torch.bfloat16 if (args.bf16 or _ds_bf16) else torch.float16
676
+
677
+ atlas = AtlasForCausalLM(
678
+ llm_model_name=args.llm_model,
679
+ visual_hidden_size=args.visual_hidden_size,
680
+ num_queries=args.num_det_queries,
681
+ num_map_queries=args.num_map_queries,
682
+ load_in_4bit=args.load_in_4bit,
683
+ use_flash_attention=_use_half,
684
+ device_map=dm,
685
+ torch_dtype=dtype,
686
+ use_lora=args.use_lora,
687
+ lora_r=args.lora_r,
688
+ lora_alpha=args.lora_alpha,
689
+ lora_dropout=args.lora_dropout,
690
+ )
691
+ atlas.resize_token_embeddings(len(tokenizer))
692
+ query_token_id = tokenizer.convert_tokens_to_ids("<query>")
693
+ atlas.set_query_token_id(query_token_id)
694
+ if topomlp_adapter is not None:
695
+ atlas.topomlp_adapter = topomlp_adapter
696
+ if dm is None and args.deepspeed is None:
697
+ atlas = atlas.to(device)
698
+ atlas.gradient_checkpointing_enable()
699
+
700
+ num_batches_per_epoch = len(dataloader)
701
+ steps_per_epoch = _optimizer_steps_per_epoch(
702
+ num_batches_per_epoch, args.gradient_accumulation_steps
703
+ )
704
+ total_steps = steps_per_epoch * args.epochs
705
+ warmup_steps = int(total_steps * args.warmup_ratio)
706
+
707
+ global_step = 0
708
+ start_epoch = 0
709
+
710
+ _resume_ckpt = None
711
+ if args.resume:
712
+ _resume_ckpt = torch.load(args.resume, map_location="cpu")
713
+ if "atlas_state_dict" not in _resume_ckpt:
714
+ raise RuntimeError(f"Checkpoint missing 'atlas_state_dict'. Keys: {list(_resume_ckpt.keys())}")
715
+ missing, _ = atlas.load_state_dict(_resume_ckpt["atlas_state_dict"], strict=False)
716
+ if _main and missing:
717
+ logger.warning("Resume: %d missing keys (first 10): %s", len(missing), missing[:10])
718
+ if topomlp_adapter is not None and "adapter_state_dict" in _resume_ckpt:
719
+ _m, _u = topomlp_adapter.load_state_dict(_resume_ckpt["adapter_state_dict"], strict=False)
720
+ if _main and _u:
721
+ logger.info("Adapter resume: ignored %d legacy keys: %s", len(_u), _u[:5])
722
+ global_step = _resume_ckpt.get("global_step", 0)
723
+ start_epoch = _resume_ckpt.get("epoch", 0)
724
+ if _main:
725
+ logger.info("Resumed from %s (step=%d, epoch=%d)", args.resume, global_step, start_epoch)
726
+
727
+ use_deepspeed = args.deepspeed is not None
728
+ if use_deepspeed:
729
+ import deepspeed
730
+ ds_config = json.load(open(args.deepspeed))
731
+ ds_config["optimizer"] = {
732
+ "type": "Adam",
733
+ "params": {
734
+ "lr": args.lr, "weight_decay": args.weight_decay,
735
+ "betas": [0.9, 0.999], "torch_adam": True, "adam_w_mode": True,
736
+ },
737
+ }
738
+ ds_config["scheduler"] = {
739
+ "type": "WarmupCosineLR",
740
+ "params": {
741
+ "total_num_steps": total_steps,
742
+ "warmup_num_steps": warmup_steps,
743
+ "warmup_type": "linear",
744
+ },
745
+ }
746
+ ds_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
747
+ ds_config["train_micro_batch_size_per_gpu"] = args.batch_size
748
+ ds_config["train_batch_size"] = args.batch_size * args.gradient_accumulation_steps * world_size
749
+
750
+ ds_bf16 = ds_config.get("bf16", {}).get("enabled", False)
751
+ ds_fp16 = ds_config.get("fp16", {}).get("enabled", False)
752
+ if ds_bf16:
753
+ atlas.to(device=device, dtype=torch.bfloat16)
754
+ elif ds_fp16:
755
+ atlas.to(device=device, dtype=torch.float16)
756
+ else:
757
+ atlas.to(device)
758
+
759
+ all_params = atlas.get_trainable_param_groups(args.lr, weight_decay=args.weight_decay)
760
+ if topomlp_adapter is not None:
761
+ _adapter_trainable = [p for p in topomlp_adapter.parameters() if p.requires_grad]
762
+ if _adapter_trainable:
763
+ all_params.append({"params": _adapter_trainable, "lr": args.lr, "weight_decay": 0.0})
764
+
765
+ atlas_ddp, optimizer, _, scheduler = deepspeed.initialize(
766
+ model=atlas, model_parameters=all_params,
767
+ config=ds_config, dist_init_required=False,
768
+ )
769
+
770
+ if _resume_ckpt is not None and "optimizer" in _resume_ckpt:
771
+ try:
772
+ optimizer.load_state_dict(_resume_ckpt["optimizer"])
773
+ if _main:
774
+ logger.info("Restored DeepSpeed optimizer state from checkpoint")
775
+ except Exception as e:
776
+ if _main:
777
+ logger.warning("Failed to restore DeepSpeed optimizer state: %s", e)
778
+
779
+ if global_step > 0 and scheduler is not None:
780
+ for _ in range(global_step):
781
+ scheduler.step()
782
+ _ff_lr = scheduler.get_lr()
783
+ if _main:
784
+ logger.info(
785
+ "Fast-forwarded DeepSpeed LR scheduler to step %d (lr=%s)",
786
+ global_step,
787
+ [f"{x:.6e}" for x in _ff_lr] if isinstance(_ff_lr, (list, tuple)) else f"{_ff_lr:.6e}",
788
+ )
789
+ else:
790
+ param_groups = atlas.get_trainable_param_groups(args.lr, weight_decay=args.weight_decay)
791
+ if topomlp_adapter is not None:
792
+ _adapter_trainable = [p for p in topomlp_adapter.parameters() if p.requires_grad]
793
+ if _adapter_trainable:
794
+ param_groups.append({"params": _adapter_trainable, "lr": args.lr, "weight_decay": 0.0})
795
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
796
+ scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
797
+ if distributed:
798
+ atlas_ddp = torch.nn.parallel.DistributedDataParallel(
799
+ atlas, device_ids=[args.local_rank], find_unused_parameters=True,
800
+ )
801
+ else:
802
+ atlas_ddp = atlas
803
+
804
+ if _resume_ckpt is not None and not use_deepspeed:
805
+ if "optimizer" in _resume_ckpt:
806
+ try:
807
+ optimizer.load_state_dict(_resume_ckpt["optimizer"])
808
+ except Exception as e:
809
+ if _main:
810
+ logger.warning("Failed to restore optimizer state: %s", e)
811
+ if "scheduler" in _resume_ckpt and _resume_ckpt["scheduler"] is not None:
812
+ try:
813
+ scheduler.load_state_dict(_resume_ckpt["scheduler"])
814
+ except Exception as e:
815
+ if _main:
816
+ logger.warning("Failed to restore scheduler state: %s", e)
817
+ _resume_ckpt = None
818
+
819
+ atlas_ddp.train()
820
+ if topomlp_adapter is not None:
821
+ topomlp_adapter.train()
822
+
823
+ if _main:
824
+ logger.info("=== Training Config ===")
825
+ logger.info(" epochs: %d, lr: %s, batch: %d, accum: %d",
826
+ args.epochs, args.lr, args.batch_size, args.gradient_accumulation_steps)
827
+ logger.info(" total_steps: %d, warmup_steps: %d", total_steps, warmup_steps)
828
+ logger.info(" use_lora: %s, load_in_4bit: %s, fp16: %s, bf16: %s, deepspeed: %s",
829
+ args.use_lora, args.load_in_4bit, args.fp16, args.bf16, use_deepspeed)
830
+ n_trainable = sum(p.numel() for p in atlas.parameters() if p.requires_grad)
831
+ if topomlp_adapter is not None:
832
+ n_trainable += sum(p.numel() for p in topomlp_adapter.parameters() if p.requires_grad)
833
+ logger.info(" trainable params: %s", f"{n_trainable:,}")
834
+ logger.info(" visual_token_mode: %s", args.visual_token_mode)
835
+ logger.info(" streampetr: %s", "online-temporal" if (is_online and streampetr_model) else ("loaded" if streampetr_model else ("precomputed" if _precomp_det else "NONE (should not happen)")))
836
+ logger.info(" topomlp: %s", "online" if (is_online and topomlp_model) else ("loaded" if topomlp_model else ("precomputed" if _precomp_map else "NONE (should not happen)")))
837
+ logger.info("=======================")
838
+
839
+ streaming_state = {} if is_online else None
840
+
841
+ for epoch in range(start_epoch, args.epochs):
842
+ if sampler is not None:
843
+ sampler.set_epoch(epoch)
844
+
845
+ if streaming_state is not None:
846
+ streaming_state.clear()
847
+ if streampetr_model is not None:
848
+ streampetr_model.pts_bbox_head.reset_memory()
849
+
850
+ epoch_loss = 0.0
851
+ num_batches = 0
852
+ t0 = time.time()
853
+
854
+ if not use_deepspeed:
855
+ optimizer.zero_grad()
856
+
857
+ for batch_idx, batch in enumerate(dataloader):
858
+ do_step = _is_optimizer_step_batch(
859
+ batch_idx, num_batches_per_epoch, args.gradient_accumulation_steps
860
+ )
861
+ accum_window_size = _accum_window_size_for_batch(
862
+ batch_idx, num_batches_per_epoch, args.gradient_accumulation_steps
863
+ )
864
+ scaled_loss = None
865
+ if use_deepspeed:
866
+ if not hasattr(atlas_ddp, "set_gradient_accumulation_boundary"):
867
+ raise RuntimeError(
868
+ "DeepSpeed engine is missing set_gradient_accumulation_boundary(); "
869
+ "cannot enforce epoch-tail flush semantics."
870
+ )
871
+ atlas_ddp.set_gradient_accumulation_boundary(do_step)
872
+
873
+ if _main and batch_idx < 5:
874
+ _has_map = "precomputed_map" in batch
875
+ _nq = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item()) if query_token_id else 0
876
+ logger.info("[DBG] batch_idx=%d nq=%d has_map=%s sid=%s mode=%s",
877
+ batch_idx, _nq, _has_map,
878
+ batch.get("sample_id", ["?"])[0][:20],
879
+ args.visual_token_mode)
880
+ for _handler in logging.root.handlers:
881
+ _handler.flush()
882
+
883
+ input_ids = batch["input_ids"].to(device)
884
+ attention_mask = batch["attention_mask"].to(device)
885
+ labels = batch["labels"].to(device)
886
+
887
+ visual_features = extract_visual_tokens(
888
+ streampetr_model, topomlp_model, topomlp_adapter,
889
+ batch, device, args.num_det_queries, args.visual_hidden_size,
890
+ query_token_id=query_token_id,
891
+ visual_token_mode=args.visual_token_mode,
892
+ streaming_state=streaming_state,
893
+ )
894
+
895
+ if _main and batch_idx < 5:
896
+ logger.info("[DBG] vis_keys=%s", list(visual_features.keys()))
897
+ for _handler in logging.root.handlers:
898
+ _handler.flush()
899
+
900
+ if _main and batch_idx < 5:
901
+ logger.info("[DBG] pre-forward batch_idx=%d seqlen=%d", batch_idx, input_ids.shape[1])
902
+ for _handler in logging.root.handlers:
903
+ _handler.flush()
904
+
905
+ outputs = atlas_ddp(
906
+ input_ids=input_ids,
907
+ attention_mask=attention_mask,
908
+ visual_features=visual_features,
909
+ labels=labels,
910
+ )
911
+ loss = outputs.loss
912
+
913
+ if _main and batch_idx < 5:
914
+ logger.info("[DBG] post-forward batch_idx=%d loss=%.4f", batch_idx, loss.item())
915
+ for _handler in logging.root.handlers:
916
+ _handler.flush()
917
+
918
+
919
+
920
+ if _main and batch_idx < 5:
921
+ logger.info("[DBG] pre-backward batch_idx=%d", batch_idx)
922
+ for _handler in logging.root.handlers:
923
+ _handler.flush()
924
+
925
+ if use_deepspeed:
926
+ scaled_loss = loss / accum_window_size
927
+ atlas_ddp.backward(scaled_loss, scale_wrt_gas=False)
928
+
929
+ if _main and batch_idx < 5:
930
+ logger.info("[DBG] pre-step batch_idx=%d", batch_idx)
931
+ for _handler in logging.root.handlers:
932
+ _handler.flush()
933
+
934
+ atlas_ddp.step()
935
+ else:
936
+ scaled_loss = loss / accum_window_size
937
+ scaled_loss.backward()
938
+
939
+ if not use_deepspeed and distributed and topomlp_adapter is not None:
940
+ for p in topomlp_adapter.parameters():
941
+ if p.requires_grad and p.grad is not None:
942
+ dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
943
+ p.grad.div_(world_size)
944
+
945
+ epoch_loss += loss.item()
946
+ num_batches += 1
947
+ if _main and num_batches <= 3:
948
+ logger.info("batch=%d loss=%.4f", num_batches, loss.item())
949
+ for _handler in logging.root.handlers:
950
+ _handler.flush()
951
+
952
+ if do_step:
953
+ if not use_deepspeed:
954
+ all_params = list(atlas.parameters()) + (
955
+ list(topomlp_adapter.parameters()) if topomlp_adapter is not None else []
956
+ )
957
+ trainable = [p for p in all_params if p.requires_grad]
958
+ torch.nn.utils.clip_grad_norm_(trainable, args.max_grad_norm)
959
+ optimizer.step()
960
+ scheduler.step()
961
+ optimizer.zero_grad()
962
+ global_step += 1
963
+
964
+ if _main and global_step % args.log_steps == 0:
965
+ # DeepSpeed LR scheduler may not expose get_last_lr() before first scheduler.step().
966
+ if use_deepspeed and hasattr(atlas_ddp, "get_lr"):
967
+ try:
968
+ _lrs = atlas_ddp.get_lr()
969
+ if isinstance(_lrs, (list, tuple)) and len(_lrs) > 0:
970
+ lr_now = float(_lrs[0])
971
+ else:
972
+ lr_now = float(_lrs)
973
+ except Exception:
974
+ lr_now = optimizer.param_groups[0]["lr"] if getattr(optimizer, "param_groups", None) else args.lr
975
+ elif hasattr(scheduler, "get_last_lr"):
976
+ try:
977
+ lr_now = scheduler.get_last_lr()[0]
978
+ except Exception:
979
+ lr_now = optimizer.param_groups[0]["lr"] if getattr(optimizer, "param_groups", None) else args.lr
980
+ else:
981
+ lr_now = args.lr
982
+ elapsed = time.time() - t0
983
+ samples_sec = num_batches * args.batch_size / max(elapsed, 1e-6)
984
+ avg_loss = epoch_loss / max(num_batches, 1)
985
+ logger.info(
986
+ "epoch=%d step=%d loss=%.4f lr=%.2e samples/s=%.1f",
987
+ epoch, global_step, avg_loss, lr_now, samples_sec,
988
+ )
989
+ for _handler in logging.root.handlers:
990
+ _handler.flush()
991
+
992
+ if _main and args.save_steps > 0 and global_step % args.save_steps == 0:
993
+ ckpt_path = output_dir / f"checkpoint-{global_step}" / "checkpoint.pt"
994
+ save_checkpoint(ckpt_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, epoch, args)
995
+ logger.info("Saved step checkpoint: %s", ckpt_path)
996
+
997
+ avg_loss = epoch_loss / max(num_batches, 1)
998
+ if _main:
999
+ logger.info("Epoch %d done — avg_loss=%.4f (%.1f min)", epoch, avg_loss, (time.time() - t0) / 60)
1000
+
1001
+ if _main and (epoch + 1) % args.save_epochs == 0:
1002
+ ckpt_path = output_dir / f"epoch-{epoch}" / "checkpoint.pt"
1003
+ save_checkpoint(ckpt_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, epoch + 1, args)
1004
+ logger.info("Saved epoch checkpoint: %s", ckpt_path)
1005
+ if args.keep_last_n_ckpts > 0:
1006
+ cleanup_old_checkpoints(output_dir, args.keep_last_n_ckpts)
1007
+
1008
+ if _main:
1009
+ final_path = output_dir / "final" / "checkpoint.pt"
1010
+ save_checkpoint(final_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, args.epochs, args)
1011
+ logger.info("Training complete. Final checkpoint: %s", final_path)
1012
+
1013
+ if distributed:
1014
+ dist.destroy_process_group()
1015
+
1016
+
1017
+ if __name__ == "__main__":  # entry-point guard: run training only when executed as a script, not on import
1018
+     main()