Add files using upload-large-folder tool
Browse files- README.md +333 -0
- configs/REPRODUCTION.md +200 -0
- configs/ds_zero2.json +17 -0
- eval_atlas.py +1175 -0
- extract_streampetr_tokens.py +568 -0
- extract_topomlp_tokens.py +381 -0
- scripts/eval_checkpoint_offline.sh +44 -0
- scripts/gen_atlas_caption_dashscope.py +272 -0
- scripts/gen_atlas_caption_qa.py +274 -0
- scripts/gen_atlas_openlane_subsetB_lane_qa.py +251 -0
- scripts/gen_atlas_planning_qa.py +491 -0
- scripts/run_val_extraction.sh +56 -0
- scripts/train_no_caption_baseline.sh +50 -0
- scripts/train_no_caption_baseline_offline.sh +48 -0
- scripts/train_with_caption_balanced.sh +48 -0
- scripts/vis_atlas_lane_gt_pred.py +500 -0
- scripts/vis_atlas_planning_qualitative.py +800 -0
- scripts/vis_traffic_violation.py +516 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/prompting.cpython-310.pyc +0 -0
- src/__pycache__/prompting.cpython-38.pyc +0 -0
- src/audit/__pycache__/__init__.cpython-310.pyc +0 -0
- src/audit/__pycache__/audit_utils.cpython-310.pyc +0 -0
- src/dataset/__pycache__/__init__.cpython-310.pyc +0 -0
- src/dataset/__pycache__/atlas_dataset.cpython-310.pyc +0 -0
- src/dataset/__pycache__/atlas_dataset.cpython-38.pyc +0 -0
- src/dataset/__pycache__/scene_sampler.cpython-310.pyc +0 -0
- src/dataset/__pycache__/scene_sampler.cpython-38.pyc +0 -0
- src/dataset/atlas_dataset.py +1416 -0
- src/dataset/scene_sampler.py +111 -0
- src/eval/__pycache__/__init__.cpython-310.pyc +0 -0
- src/eval/__pycache__/__init__.cpython-38.pyc +0 -0
- src/eval/__pycache__/metrics.cpython-310.pyc +0 -0
- src/eval/__pycache__/metrics.cpython-38.pyc +0 -0
- src/eval/metrics.py +852 -0
- src/model/__init__.py +28 -0
- src/model/__pycache__/__init__.cpython-310.pyc +0 -0
- src/model/__pycache__/__init__.cpython-38.pyc +0 -0
- src/model/__pycache__/configuration_atlas.cpython-310.pyc +0 -0
- src/model/__pycache__/modeling_atlas.cpython-310.pyc +0 -0
- src/model/__pycache__/modeling_atlas.cpython-38.pyc +0 -0
- src/model/__pycache__/streampetr_adapter.cpython-310.pyc +0 -0
- src/model/__pycache__/streampetr_adapter.cpython-38.pyc +0 -0
- src/model/__pycache__/topomlp_adapter.cpython-310.pyc +0 -0
- src/model/__pycache__/topomlp_adapter.cpython-38.pyc +0 -0
- src/model/modeling_atlas.py +549 -0
- src/model/streampetr_adapter.py +110 -0
- src/model/topomlp_adapter.py +88 -0
- src/prompting.py +277 -0
- train_atlas.py +1018 -0
README.md
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- autonomous-driving
|
| 5 |
+
- 3d-detection
|
| 6 |
+
- lane-detection
|
| 7 |
+
- planning
|
| 8 |
+
- multimodal
|
| 9 |
+
- vicuna
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Atlas — 3D-Tokenized LLM for Autonomous Driving
|
| 13 |
+
|
| 14 |
+
基于 [Atlas 论文](https://arxiv.org/abs/2405.18361) 的多模态自动驾驶大语言模型实现。将 **StreamPETR**(3D 目标检测)和 **TopoMLP**(车道线检测)提取的 3D visual tokens 注入 **Vicuna-7B** LLM,实现检测、车道线、规划等多任务统一生成。
|
| 15 |
+
|
| 16 |
+
## 项目结构
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
3dtokenizer-atlas/
|
| 20 |
+
├── train_atlas.py # Atlas LLM 训练入口
|
| 21 |
+
├── eval_atlas.py # Atlas 评估入口
|
| 22 |
+
├── extract_streampetr_tokens.py # 预提取 StreamPETR detection tokens
|
| 23 |
+
├── extract_topomlp_tokens.py # 预提取 TopoMLP lane tokens
|
| 24 |
+
├── train_streampetr.sh # StreamPETR 预训练启动脚本
|
| 25 |
+
├── train_topomlp.sh # TopoMLP 预训练启动脚本
|
| 26 |
+
│
|
| 27 |
+
├── configs/
|
| 28 |
+
│ ├── streampetr_atlas_aligned.py # StreamPETR 配置 (EVA-02 ViT-L, 800x1600)
|
| 29 |
+
│ ├── topomlp_atlas_aligned.py # TopoMLP 配置 (EVA-02 ViT-L, 800x1600)
|
| 30 |
+
│ ├── ds_zero2.json # DeepSpeed ZeRO-2 配置
|
| 31 |
+
│ └── REPRODUCTION.md # 复现文档
|
| 32 |
+
│
|
| 33 |
+
├── src/
|
| 34 |
+
│ ├── model/
|
| 35 |
+
│ │ ├── modeling_atlas.py # AtlasForCausalLM 主模型
|
| 36 |
+
│ │ ├── streampetr_adapter.py # StreamPETR → 检测 token 适配器
|
| 37 |
+
│ │ └── topomlp_adapter.py # TopoMLP → 地图 token 适配器 (Top-K selection)
|
| 38 |
+
│ ├── dataset/
|
| 39 |
+
│ │ ├── atlas_dataset.py # AtlasDataset + Collate
|
| 40 |
+
│ │ └── scene_sampler.py # SceneSequentialSampler (时序采样)
|
| 41 |
+
│ ├── eval/
|
| 42 |
+
│ │ └── metrics.py # 评估指标 (F1/Chamfer/L2/Collision)
|
| 43 |
+
│ └── prompting.py # 多任务 Prompt 模板
|
| 44 |
+
│
|
| 45 |
+
├── scripts/
|
| 46 |
+
│ ├── gen_atlas_full_data.py # nuScenes → 检测 QA JSON
|
| 47 |
+
│ ├── gen_atlas_openlane_subsetB_lane_qa.py # OpenLane-V2 → 车道线 QA JSON
|
| 48 |
+
│ ├── gen_atlas_planning_qa.py # nuScenes → 规划 QA JSON
|
| 49 |
+
│ ├── train_no_caption_baseline.sh # 无 caption 训练脚本
|
| 50 |
+
│ └── train_with_caption_balanced.sh # 含 caption 训练脚本
|
| 51 |
+
│
|
| 52 |
+
├── data/ # 训练/验证数据 (JSON)
|
| 53 |
+
│ ├── atlas_nuscenes_train.json # 检测 (28,130 样本)
|
| 54 |
+
│ ├── atlas_nuscenes_val.json # 检测验证 (6,019 样本)
|
| 55 |
+
│ ├── openlane_subsetB_lane_train_4pt.json # 车道线 (27,968 样本, 4 点/lane)
|
| 56 |
+
│ ├── openlane_subsetB_lane_val_4pt.json # 车道线验证 (6,019 样本)
|
| 57 |
+
│ ├── atlas_planning_train_uniad_command.json # 规划 (23,541 样本, UniAD-style command)
|
| 58 |
+
│ ├── atlas_planning_val_uniad_command.json # 规划验证 (5,037 样本, UniAD-style command)
|
| 59 |
+
│ ├── atlas_caption_train.json # 环境描述 caption
|
| 60 |
+
│ └── atlas_caption_val.json # 环境描述 caption 验证
|
| 61 |
+
│
|
| 62 |
+
├── pretrained/ # 预训练权重
|
| 63 |
+
│ ├── vicuna-7b-v1.5/ # Vicuna-7B-v1.5 LLM
|
| 64 |
+
│ ├── eva02_L_coco_det_sys_o365_remapped_fixed.pth
|
| 65 |
+
│ └── streampetr/
|
| 66 |
+
│ └── streampetr_eva02_ep24.pth
|
| 67 |
+
│
|
| 68 |
+
├── work_dirs/
|
| 69 |
+
│ ├── _quick_eval_cpu.py # 快速检测评估 (CPU, micro-avg F1)
|
| 70 |
+
│ ├── _quick_eval_gpu.py # 快速检测评估 (GPU)
|
| 71 |
+
│ ├── _quick_eval_lane_gpu.py # 快速车道线评估 (GPU)
|
| 72 |
+
│ ├── _quick_eval_plan_gpu.py # 快速规划评估 (GPU, scene-sequential)
|
| 73 |
+
│ ├── precomputed_det_tokens_offline/ # 预提取的 StreamPETR tokens (offline 备选)
|
| 74 |
+
│ │ ├── train/ # 56,099 个 .pt 文件 (det + lane,planning 与 det 共享 ID)
|
| 75 |
+
│ │ └── val/ # 12,039 个 .pt 文件
|
| 76 |
+
│ ├── precomputed_map_tokens_offline/ # 预提取的 TopoMLP tokens (offline 备选)
|
| 77 |
+
│ │ ├── train/ # 51,510 个 .pt 文件 (lane + planning)
|
| 78 |
+
│ │ └── val/ # 11,057 个 .pt 文件
|
| 79 |
+
│ └── topomlp_atlas_aligned/ # TopoMLP 预训练权重
|
| 80 |
+
│ └── epoch_24.pth
|
| 81 |
+
│
|
| 82 |
+
└── external/ # 外部依赖
|
| 83 |
+
├── StreamPETR/
|
| 84 |
+
├── TopoMLP_Repo/
|
| 85 |
+
└── nuscenes-devkit/
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
## 模型架构
|
| 89 |
+
|
| 90 |
+
```
|
| 91 |
+
┌─────────────────────────────────────┐
|
| 92 |
+
6x 环视相机图片 → │ StreamPETR (frozen, EVA-02 ViT-L) │→ det tokens [B, 256, 256]
|
| 93 |
+
│ TopoMLP (frozen, EVA-02 ViT-L) │→ one-to-one lane queries (300) → Top-K → map tokens [B, 256, 256]
|
| 94 |
+
└─────────────────────────────────────┘
|
| 95 |
+
↓
|
| 96 |
+
AtlasUnifiedProjector
|
| 97 |
+
┌────────────────────────────────┐
|
| 98 |
+
│ projector_det: Linear(256→4096) │ ← 单层线性投影
|
| 99 |
+
│ projector_map: Linear(256→4096) │
|
| 100 |
+
│ projector_rp: Linear(3→256) │ ← Reference Point, zero-init
|
| 101 |
+
│ features += projector_rp(ref) │
|
| 102 |
+
└────────────────────────────────┘
|
| 103 |
+
↓
|
| 104 |
+
注入到 <query> token 位置 (256 det + 256 map)
|
| 105 |
+
↓
|
| 106 |
+
┌───────────────────────────────────────┐
|
| 107 |
+
│ Vicuna-7B (当前运行: 全参数微调) │
|
| 108 |
+
│ Causal Language Modeling Loss │
|
| 109 |
+
└───────────────────────────────────────┘
|
| 110 |
+
↓
|
| 111 |
+
多任务文本输出
|
| 112 |
+
(3D 检测 / 车道线 / 规划轨迹)
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
## 训练配置(来源与优先级)
|
| 116 |
+
|
| 117 |
+
为避免“论文描述 / 脚本默认 / 实际运行”混淆,本仓库统一按以下优先级解释配置:
|
| 118 |
+
|
| 119 |
+
1. **论文描述**:以 Atlas 原论文(arXiv:2405.18361)正文 Section 3.2 与 Appendix B.2 为准。
|
| 120 |
+
2. **代码默认**:以 `train_atlas.py` 的 argparse 默认值为准。
|
| 121 |
+
3. **实际运行**:以 `work_dirs/<exp>/args.json` + `work_dirs/*train*.log` 为准(最高优先级)。
|
| 122 |
+
|
| 123 |
+
若三者冲突,请按第 3 条解释实验结果,不以 README 中示例命令覆盖真实运行参数。
|
| 124 |
+
|
| 125 |
+
### 论文原文(Atlas LLM,Section 3.2 + Appendix B.2)
|
| 126 |
+
|
| 127 |
+
| 项目 | 论文描述 |
|
| 128 |
+
|------|----------|
|
| 129 |
+
| Batch Size | 1 per GPU |
|
| 130 |
+
| Learning Rate | 2e-5 |
|
| 131 |
+
| Optimizer | AdamW (weight_decay=1e-4) |
|
| 132 |
+
| LR Schedule | 3% linear warmup + cosine decay |
|
| 133 |
+
| Max Length | 4096 |
|
| 134 |
+
| 硬件/时长 | 8x A100,约 100 小时 |
|
| 135 |
+
| LoRA | 论文附录未显式给出 LoRA 开关 |
|
| 136 |
+
|
| 137 |
+
### 当前实际运行(示例:`work_dirs/atlas_no_caption_v3_linear_warmup`)
|
| 138 |
+
|
| 139 |
+
来源:`work_dirs/atlas_no_caption_v3_linear_warmup/args.json` 与 `work_dirs/train_no_caption_v3_linear_warmup.log`
|
| 140 |
+
|
| 141 |
+
| 参数 | 实际值 |
|
| 142 |
+
|------|--------|
|
| 143 |
+
| LLM | Vicuna-7B-v1.5 |
|
| 144 |
+
| 微调方式 | **全参数微调** (`use_lora=false`) |
|
| 145 |
+
| 可训练参数 | 6,740,530,176 |
|
| 146 |
+
| Learning Rate | 2e-5 |
|
| 147 |
+
| Optimizer | AdamW (`weight_decay=1e-4`, `torch_adam`, `adam_w_mode`) |
|
| 148 |
+
| LR Schedule | WarmupCosineLR(warmup ratio=3%) |
|
| 149 |
+
| Epochs | 10 |
|
| 150 |
+
| Batch Size | 1 per GPU |
|
| 151 |
+
| Gradient Accumulation | 2 |
|
| 152 |
+
| Effective Batch Size | 8 (4 GPU x 1 x 2 accum) |
|
| 153 |
+
| Total Steps | 99,550 |
|
| 154 |
+
| Warmup Steps | 2,986 |
|
| 155 |
+
| Max Sequence Length | 4096 |
|
| 156 |
+
| 分布式 | DeepSpeed ZeRO-2 |
|
| 157 |
+
| GPU | 4x NVIDIA H100 80GB |
|
| 158 |
+
| 精度 | BF16(由 `configs/ds_zero2.json` 启用) |
|
| 159 |
+
| Visual Tokens | **在线** (live frozen StreamPETR + TopoMLP, temporal memory);离线预提取仅作为 fallback |
|
| 160 |
+
|
| 161 |
+
### 训练数据
|
| 162 |
+
|
| 163 |
+
| 任务 | 数据文件 | 样本数 |
|
| 164 |
+
|------|---------|--------|
|
| 165 |
+
| 3D 目标检测 | `atlas_nuscenes_train.json` | 28,130 |
|
| 166 |
+
| 3D 车道线检测 | `openlane_subsetB_lane_train_4pt.json` | 27,968 |
|
| 167 |
+
| 轨迹规划 | `atlas_planning_train_uniad_command.json` | 23,541 |
|
| 168 |
+
| 环境描述 (可选) | `atlas_caption_train.json` | — |
|
| 169 |
+
| **总计 (无 caption)** | | **79,639** |
|
| 170 |
+
|
| 171 |
+
车道线数据使用 4 个采样点/lane(论文 Appendix A.2 要求 four lane points,本仓库实现为均匀采样),不设 lane 数量上限(论文未指定上限),按 BEV 距离近→远排序。实际平均约 25 条 lane/样本,最多约 80 条。所有坐标使用 1000-bin 离散化。规划数据包含 `gt_boxes_3d_per_timestep` 字段用于 ST-P3 对齐的 per-timestep 碰撞评测。
|
| 172 |
+
|
| 173 |
+
三类主任务的 question pool 统一采用“前 3 条按论文 Table 6 / 7 / 9 原话整理,后 2 条为仓库补充的同风格扩展模板”的策略;其中车道线 Table 7 的第 2 条按论文现有文本原样保留。
|
| 174 |
+
|
| 175 |
+
为避免运行时再依赖 prompt 文本猜任务,四类样本的磁盘 JSON 统一显式写入 `task` 字段:`detection` / `lane` / `planning` / `caption`。这是仓库层面的工程化 schema,对论文中的 question-answer 文本格式不做额外解释。
|
| 176 |
+
|
| 177 |
+
caption 数据按论文 Appendix A.3 的单视角设定生成:Table 8 作为 GPT-4V 标注 prompt,human prompt 采用 Figure 5 风格的单模板,并注入具体 `camera_name`。
|
| 178 |
+
|
| 179 |
+
当前仓库不再向 prompt 追加 bins-format hint;detection / lane / caption 默认以磁盘 JSON 中的 `human` 文本作为 prompt 主体语义来源。planning 任务运行时仍会按 `planning_table3_mode` 对磁盘 `human` prompt 做轻量重写,只负责插入/剥离 command 句和 ego-state 句,再统一做 `<query>` 展开、空白归一化和 `USER: ... / ASSISTANT:` 包装。
|
| 180 |
+
|
| 181 |
+
当前 detection 的 canonical answer 格式为:`category: [x_bin, y_bin, z_bin], [x_bin, y_bin, z_bin]; category: [x_bin, y_bin, z_bin].`。当前 lane/map 的 canonical answer 格式为:`Lane: [x_bin, y_bin, z_bin], [x_bin, y_bin, z_bin], ...; [x_bin, y_bin, z_bin], ... .`。旧的 detection flat 文本和 `lane_centerline(id=...)` legacy 文本不再作为受支持协议。
|
| 182 |
+
|
| 183 |
+
planning 的 answer/output protocol 采用 Figure 5 风格表述,但保持论文 Table 9 的二维语义:`Ego car speed value:[vx_bin, vy_bin]. Ego car acceleration value:[ax_bin, ay_bin]. Based on the ego car speed and acceleration you predicted, requeset the ego car planning waypoint in 3-seconds: [x_bin, y_bin], ...`。当前实现不为 planning 引入第三维,也不使用固定 `z=500` 占位。
|
| 184 |
+
|
| 185 |
+
当前 planning JSON 的顶层 `route_command` 采用 **UniAD-style future-GT-derived command**:根据 future planning GT / future waypoints 的最后一个有效 timestep 的横向位移离散为 `turn left` / `turn right` / `go straight`。它不是 raw nuScenes 原生字段,也不是独立导航命令;因此 `atlas_high_level*` 在本仓库中的含义更接近 UniAD 风格条件输入,而不是 Atlas 论文 Table 3 严格意义上的独立 route command。
|
| 186 |
+
|
| 187 |
+
### 3D Tokenizer 预训练 (已完成)
|
| 188 |
+
|
| 189 |
+
| 参数 | StreamPETR | TopoMLP |
|
| 190 |
+
|------|-----------|---------|
|
| 191 |
+
| Backbone | EVA-02 ViT-L (embed_dim=1024) | EVA-02 ViT-L (embed_dim=1024) |
|
| 192 |
+
| Resolution | 800x1600 | 800x1600 |
|
| 193 |
+
| Queries | 256 (detection) | 256 (map, Top-K from 300 one-to-one queries) |
|
| 194 |
+
| Control Points | - | 4 per lane |
|
| 195 |
+
| Epochs | 24 | 24 |
|
| 196 |
+
| 数据集 | nuScenes trainval | OpenLane-V2 subset-B |
|
| 197 |
+
|
| 198 |
+
## 快速开始
|
| 199 |
+
|
| 200 |
+
### 1. 环境
|
| 201 |
+
|
| 202 |
+
```bash
|
| 203 |
+
conda activate streampetr
|
| 204 |
+
# 主要依赖: PyTorch 2.0+, transformers, peft, flash-attn, mmcv 1.7, mmdet3d 1.0
|
| 205 |
+
# DeepSpeed (ZeRO-2): pip install deepspeed
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
### 2. 数据准备
|
| 209 |
+
|
| 210 |
+
```bash
|
| 211 |
+
# nuScenes 数据根目录 (含 v1.0-trainval/ 和 samples/)
|
| 212 |
+
export DATA_ROOT=/path/to/nuscenes
|
| 213 |
+
|
| 214 |
+
# OpenLane-V2 subset-B
|
| 215 |
+
export OPENLANE_ROOT=/path/to/OpenLane-V2/subset_B
|
| 216 |
+
|
| 217 |
+
# 生成检测 QA 数据 (按类别分组, 论文 Figure 5 格式)
|
| 218 |
+
python scripts/gen_atlas_full_data.py \
|
| 219 |
+
--data-root $DATA_ROOT --split train \
|
| 220 |
+
--output data/atlas_nuscenes_train.json
|
| 221 |
+
python scripts/gen_atlas_full_data.py \
|
| 222 |
+
--data-root $DATA_ROOT --split val \
|
| 223 |
+
--output data/atlas_nuscenes_val.json
|
| 224 |
+
|
| 225 |
+
# 生成车道线 QA 数据 (4 点/lane, 无 lane 数量上限, BEV 距离排序)
|
| 226 |
+
python scripts/gen_atlas_openlane_subsetB_lane_qa.py \
|
| 227 |
+
--openlane_root $OPENLANE_ROOT \
|
| 228 |
+
--split train --out_json data/openlane_subsetB_lane_train_4pt.json
|
| 229 |
+
|
| 230 |
+
python scripts/gen_atlas_openlane_subsetB_lane_qa.py \
|
| 231 |
+
--openlane_root $OPENLANE_ROOT \
|
| 232 |
+
--split val --out_json data/openlane_subsetB_lane_val_4pt.json
|
| 233 |
+
|
| 234 |
+
# 生成规划 QA 数据
|
| 235 |
+
# 默认输出:
|
| 236 |
+
# data/atlas_planning_train_uniad_command.json
|
| 237 |
+
# data/atlas_planning_val_uniad_command.json
|
| 238 |
+
# 默认写顶层 route_command(UniAD-style future-GT-derived command)
|
| 239 |
+
# 默认 materialize atlas_high_level human prompt;运行时仍可通过
|
| 240 |
+
# --planning_table3_mode 改写为 atlas_base / atlas_high_level_ego
|
| 241 |
+
python scripts/gen_atlas_planning_qa.py \
|
| 242 |
+
--data-root $DATA_ROOT --split train
|
| 243 |
+
python scripts/gen_atlas_planning_qa.py \
|
| 244 |
+
--data-root $DATA_ROOT --split val
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
### 3. 训练
|
| 248 |
+
|
| 249 |
+
默认使用 **在线模式**(`--visual_token_mode online`),训练时 live 前向 frozen StreamPETR(含 temporal memory)和 TopoMLP,无需预提取 token。
|
| 250 |
+
|
| 251 |
+
```bash
|
| 252 |
+
# ===== 推荐:在线模式训练(默认)=====
|
| 253 |
+
# 无 caption: det + planning + lane
|
| 254 |
+
bash scripts/train_no_caption_baseline.sh
|
| 255 |
+
|
| 256 |
+
# 含 caption: det + planning + lane + caption
|
| 257 |
+
bash scripts/train_with_caption_balanced.sh
|
| 258 |
+
```
|
| 259 |
+
|
| 260 |
+
等效手动命令(以无 caption 为例):
|
| 261 |
+
|
| 262 |
+
```bash
|
| 263 |
+
deepspeed --num_gpus 4 train_atlas.py \
|
| 264 |
+
--llm_model pretrained/vicuna-7b-v1.5 \
|
| 265 |
+
--data_json data/atlas_nuscenes_train.json,data/atlas_planning_train_uniad_command.json,data/openlane_subsetB_lane_train_4pt.json \
|
| 266 |
+
--data_root $DATA_ROOT \
|
| 267 |
+
--visual_token_mode online \
|
| 268 |
+
--streampetr_config configs/streampetr_atlas_aligned.py \
|
| 269 |
+
--streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
|
| 270 |
+
--topomlp_config configs/topomlp_atlas_aligned.py \
|
| 271 |
+
--topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
|
| 272 |
+
--deepspeed configs/ds_zero2.json \
|
| 273 |
+
--output_dir work_dirs/atlas_no_caption_online \
|
| 274 |
+
--lr 2e-5 --weight_decay 1e-4 \
|
| 275 |
+
--batch_size 1 --epochs 10 --gradient_accumulation_steps 2 \
|
| 276 |
+
--warmup_ratio 0.03 --max_grad_norm 1.0 \
|
| 277 |
+
--save_epochs 1 --log_steps 100 \
|
| 278 |
+
--seed 42 --num_workers 4
|
| 279 |
+
```
|
| 280 |
+
|
| 281 |
+
> **离线 fallback**:如需使用预提取 token 训练(速度更快,但 det 无 temporal memory),
|
| 282 |
+
> 先运行预提取脚本,再使用 `bash scripts/train_no_caption_baseline_offline.sh`。
|
| 283 |
+
> 预提取 token 存放于 `work_dirs/precomputed_*_tokens_offline/`。
|
| 284 |
+
|
| 285 |
+
### 4. 评估
|
| 286 |
+
|
| 287 |
+
`eval_atlas.py` 支持两种模式:**在线模式**(默认,live frozen encoder + temporal memory)和离线 fallback。
|
| 288 |
+
|
| 289 |
+
```bash
|
| 290 |
+
# ===== 推荐:在线模式评估(默认)=====
|
| 291 |
+
# 检测
|
| 292 |
+
bash scripts/eval_checkpoint.sh <checkpoint> data/atlas_nuscenes_val.json
|
| 293 |
+
|
| 294 |
+
# 车道线
|
| 295 |
+
bash scripts/eval_checkpoint.sh <checkpoint> data/openlane_subsetB_lane_val_4pt.json
|
| 296 |
+
|
| 297 |
+
# 规划
|
| 298 |
+
bash scripts/eval_checkpoint.sh <checkpoint> data/atlas_planning_val_uniad_command.json
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
等效手动命令:
|
| 302 |
+
|
| 303 |
+
```bash
|
| 304 |
+
python eval_atlas.py \
|
| 305 |
+
--checkpoint work_dirs/atlas_no_caption_online/final/checkpoint.pt \
|
| 306 |
+
--llm_model pretrained/vicuna-7b-v1.5 \
|
| 307 |
+
--visual_token_mode online \
|
| 308 |
+
--streampetr_config configs/streampetr_atlas_aligned.py \
|
| 309 |
+
--streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
|
| 310 |
+
--topomlp_config configs/topomlp_atlas_aligned.py \
|
| 311 |
+
--topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
|
| 312 |
+
--data_json data/atlas_nuscenes_val.json \
|
| 313 |
+
--data_root $DATA_ROOT \
|
| 314 |
+
--batch_size 1 --max_new_tokens 2700 --bf16
|
| 315 |
+
```
|
| 316 |
+
|
| 317 |
+
> **离线 fallback**:使用预提取 token 评估(速度更快,但 det 无 temporal memory):
|
| 318 |
+
> `bash scripts/eval_checkpoint_offline.sh <checkpoint> <data_json>`
|
| 319 |
+
>
|
| 320 |
+
> **快速验证脚本**(`work_dirs/_quick_eval_*.py`)使用离线 token,仅用于开发调试,**不等价于主在线评测,不应用于正式结果**。其口径与主评测存在差异:planning 解析更宽松、不走 live encoders/temporal memory、不自动检测 LoRA。
|
| 321 |
+
|
| 322 |
+
> **评测协议说明**:
|
| 323 |
+
> - **检测**:micro-averaged F1 @ 0.5/1.0/2.0/4.0m,BEV 中心距离匹配。
|
| 324 |
+
> - **车道线**:使用 OpenLane-V2 官方 F-Score 评测器(`openlanev2` 为必需依赖,缺失时直接报错,不再退化为 Chamfer)。
|
| 325 |
+
> - **规划**:L2 误差 + 碰撞率。规划数据含 `gt_boxes_3d_per_timestep` 字段时使用 ST-P3 对齐的 per-timestep 碰撞检测;旧数据自动退化为静态 box 检测。
|
| 326 |
+
> - 在线主评测(`eval_atlas.py`)需要 `mmcv`、`mmdet3d`、`openlanev2` 三个关键依赖,缺失时会在启动前报错。
|
| 327 |
+
|
| 328 |
+
## 参考
|
| 329 |
+
|
| 330 |
+
- [Atlas: Is a 3D-Tokenized LLM the Key to Reliable Autonomous Driving?](https://arxiv.org/abs/2405.18361)
|
| 331 |
+
- [StreamPETR](https://github.com/exiawsh/StreamPETR)
|
| 332 |
+
- [TopoMLP](https://github.com/wudongming97/TopoMLP)
|
| 333 |
+
- [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/)
|
configs/REPRODUCTION.md
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# StreamPETR Atlas-Aligned 复现指南
|
| 2 |
+
|
| 3 |
+
本文档记录了复现 Atlas 论文中 3D Tokenizer (StreamPETR) 所需的完整环境和数据准备信息。
|
| 4 |
+
|
| 5 |
+
## 1. 版本矩阵(官方 StreamPETR 要求)
|
| 6 |
+
|
| 7 |
+
| 依赖 | 版本 | 备注 |
|
| 8 |
+
|------|------|------|
|
| 9 |
+
| Python | >= 3.8 | 推荐 3.8 |
|
| 10 |
+
| CUDA | 11.2 | 或兼容版本 |
|
| 11 |
+
| PyTorch | 1.9.0 | `pip install torch==1.9.0+cu111` |
|
| 12 |
+
| torchvision | 0.10.0 | |
|
| 13 |
+
| mmcv-full | 1.6.0 | `pip install mmcv-full==1.6.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html` |
|
| 14 |
+
| mmdet | 2.28.2 | |
|
| 15 |
+
| mmsegmentation | 0.30.0 | |
|
| 16 |
+
| mmdetection3d | 1.0.0rc6 | `git checkout v1.0.0rc6` |
|
| 17 |
+
| flash-attn | 0.2.2 | 可选,但高分辨率训练需要 |
|
| 18 |
+
|
| 19 |
+
### 注意事项
|
| 20 |
+
- 如果使用 PyTorch >= 1.13,需要对应 flash-attn 0.2.8
|
| 21 |
+
- Tesla V100 可能不兼容 flash-attn,需要注释相关代码
|
| 22 |
+
|
| 23 |
+
## 2. StreamPETR 代码版本
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
# 克隆 StreamPETR
|
| 27 |
+
git clone https://github.com/exiawsh/StreamPETR
|
| 28 |
+
|
| 29 |
+
# 克隆 mmdetection3d 并切换到指定版本
|
| 30 |
+
cd StreamPETR
|
| 31 |
+
git clone https://github.com/open-mmlab/mmdetection3d.git
|
| 32 |
+
cd mmdetection3d
|
| 33 |
+
git checkout v1.0.0rc6
|
| 34 |
+
pip install -e .
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
**当前仓库状态**:`external/StreamPETR` 是直接复制的代码,非 git 仓库,无法校验具体 commit。
|
| 38 |
+
|
| 39 |
+
## 3. EVA-02 预训练权重
|
| 40 |
+
|
| 41 |
+
| 文件 | 来源 | MD5 |
|
| 42 |
+
|------|------|-----|
|
| 43 |
+
| `eva02_L_coco_det_sys_o365_remapped.pth` | [GitHub Release](https://github.com/exiawsh/storage/releases/download/v1.0/eva02_L_coco_det_sys_o365_remapped.pth) | `15c389fe4e987275c3d08405ca9eeb92` |
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
# 下载权重
|
| 47 |
+
mkdir -p /root/autodl-tmp
|
| 48 |
+
wget -O /root/autodl-tmp/eva02_L_coco_det_sys_o365_remapped.pth \
|
| 49 |
+
https://github.com/exiawsh/storage/releases/download/v1.0/eva02_L_coco_det_sys_o365_remapped.pth
|
| 50 |
+
|
| 51 |
+
# 验证 MD5
|
| 52 |
+
md5sum /root/autodl-tmp/eva02_L_coco_det_sys_o365_remapped.pth
|
| 53 |
+
# 应该输出: 15c389fe4e987275c3d08405ca9eeb92
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
## 4. nuScenes 数据准备
|
| 57 |
+
|
| 58 |
+
### 4.1 下载数据集
|
| 59 |
+
|
| 60 |
+
从 [nuScenes 官网](https://www.nuscenes.org/download) 下载:
|
| 61 |
+
- Full dataset (v1.0-trainval): ~330GB
|
| 62 |
+
- Mini dataset (v1.0-mini): ~4GB(调试用)
|
| 63 |
+
|
| 64 |
+
### 4.2 生成 temporal infos pkl
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
# 在本仓库中请使用 external/StreamPETR 目录
|
| 68 |
+
cd external/StreamPETR
|
| 69 |
+
|
| 70 |
+
# 生成 nuscenes2d_temporal_infos_{train,val}.pkl
|
| 71 |
+
python tools/create_data_nusc.py \
|
| 72 |
+
--root-path ./data/nuscenes \
|
| 73 |
+
--out-dir ./data/nuscenes \
|
| 74 |
+
--extra-tag nuscenes2d \
|
| 75 |
+
--version v1.0
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### 4.3 或下载预处理好的 pkl
|
| 79 |
+
|
| 80 |
+
| 文件 | 下载链接 |
|
| 81 |
+
|------|----------|
|
| 82 |
+
| train.pkl | [nuscenes2d_temporal_infos_train.pkl](https://github.com/exiawsh/storage/releases/download/v1.0/nuscenes2d_temporal_infos_train.pkl) |
|
| 83 |
+
| val.pkl | [nuscenes2d_temporal_infos_val.pkl](https://github.com/exiawsh/storage/releases/download/v1.0/nuscenes2d_temporal_infos_val.pkl) |
|
| 84 |
+
| test.pkl | [nuscenes2d_temporal_infos_test.pkl](https://github.com/exiawsh/storage/releases/download/v1.0/nuscenes2d_temporal_infos_test.pkl) |
|
| 85 |
+
|
| 86 |
+
### 4.4 目录结构(官方 StreamPETR 期望)
|
| 87 |
+
|
| 88 |
+
```
|
| 89 |
+
external/StreamPETR/data/nuscenes/
|
| 90 |
+
├── maps/
|
| 91 |
+
├── samples/
|
| 92 |
+
├── sweeps/
|
| 93 |
+
├── v1.0-trainval/
|
| 94 |
+
├── nuscenes2d_temporal_infos_train.pkl
|
| 95 |
+
├── nuscenes2d_temporal_infos_val.pkl
|
| 96 |
+
└── nuscenes2d_temporal_infos_test.pkl
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
## 5. 训练配置
|
| 100 |
+
|
| 101 |
+
### 5.1 3D Tokenizer 预训练设置(Appendix B.1)
|
| 102 |
+
|
| 103 |
+
| 参数 | 设置 | 备注 |
|
| 104 |
+
|------|-----------|------|
|
| 105 |
+
| Backbone | EVA-02 ViT-L | 1024 embed_dim |
|
| 106 |
+
| Resolution | 800×1600 | 高分辨率 |
|
| 107 |
+
| Detection queries | 256 | |
|
| 108 |
+
| Epochs | 24 | |
|
| 109 |
+
| Batch size | 1 per GPU | 高分辨率需要小 batch |
|
| 110 |
+
| GPUs | 8× A100 | 论文使用 |
|
| 111 |
+
| Base LR | 4e-4 | bs=16 (8 GPU × 2) |
|
| 112 |
+
| Backbone LR | 0.1× base | |
|
| 113 |
+
|
| 114 |
+
### 5.2 Atlas LLM 配置(论文原文 Appendix B.2)
|
| 115 |
+
|
| 116 |
+
| 参数 | 论文描述 |
|
| 117 |
+
|------|----------|
|
| 118 |
+
| Batch size | 1 per GPU |
|
| 119 |
+
| Learning rate | 2e-5 |
|
| 120 |
+
| Optimizer | AdamW (weight_decay=1e-4) |
|
| 121 |
+
| LR schedule | 3% linear warmup + cosine |
|
| 122 |
+
| Max length | 4096 |
|
| 123 |
+
| 训练硬件/时长 | 8x A100,约 100 小时 |
|
| 124 |
+
|
| 125 |
+
> 说明:论文附录未显式给出 LoRA 开关。
|
| 126 |
+
> 如需解释某次实验结果,请以该实验的 `work_dirs/<exp>/args.json` 与训练日志为准。
|
| 127 |
+
|
| 128 |
+
### 5.3 当前配置(A100 环境)
|
| 129 |
+
|
| 130 |
+
当前配置使用线性缩放学习率(官方注释: "bs 8: 2e-4 || bs 16: 4e-4"):
|
| 131 |
+
- **8× A100 GPU**
|
| 132 |
+
- **batch_size = 1 per GPU**
|
| 133 |
+
- **Base LR = 2e-4** (effective bs=8,按官方注释使用 2e-4)
|
| 134 |
+
- **配置文件**:`configs/streampetr_atlas_aligned.py`
|
| 135 |
+
|
| 136 |
+
### 5.4 有效 batch size 计算
|
| 137 |
+
|
| 138 |
+
```
|
| 139 |
+
有效 batch size = num_gpus × batch_size_per_gpu × gradient_accumulation
|
| 140 |
+
= 8 × 1 × 1 = 8
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
官方学习率参考: "bs 8: 2e-4 || bs 16: 4e-4"
|
| 144 |
+
当前 effective bs=8,所以使用 lr=2e-4。
|
| 145 |
+
|
| 146 |
+
如需匹配官方 bs=16,可设置 `gradient_accumulation_steps=2`,并将 lr 改为 4e-4。
|
| 147 |
+
|
| 148 |
+
### 5.5 其他硬件适配
|
| 149 |
+
|
| 150 |
+
如果 GPU 数量不同,需要线性缩放学习率:
|
| 151 |
+
```
|
| 152 |
+
调整后 LR = 2e-4 × (实际 effective batch size / 8)
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
例如 6×GPU、bs=1:`LR = 2e-4 × 6/8 = 1.5e-4`
|
| 156 |
+
|
| 157 |
+
## 6. 训练命令(官方 StreamPETR)
|
| 158 |
+
|
| 159 |
+
```bash
|
| 160 |
+
# 推荐:从本仓库根目录启动(已强制使用官方 StreamPETR 配置)
|
| 161 |
+
bash scripts/run_train_streampetr.sh
|
| 162 |
+
|
| 163 |
+
# 或直接使用官方 StreamPETR 工具
|
| 164 |
+
cd external/StreamPETR
|
| 165 |
+
bash tools/dist_train.sh projects/configs/Atlas/atlas_streampetr_eva02_800x1600_256q_24e.py 8
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
> 注意:不要使用本仓库内的 `scripts/train_streampetr.py`,它是非官方实现,仅用于对比与调试。
|
| 169 |
+
|
| 170 |
+
## 7. 验证环境
|
| 171 |
+
|
| 172 |
+
```bash
|
| 173 |
+
# 检查 CUDA
|
| 174 |
+
python -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.version.cuda}')"
|
| 175 |
+
|
| 176 |
+
# 检查 mmdet3d
|
| 177 |
+
python -c "import mmdet3d; print(f'mmdet3d: {mmdet3d.__version__}')"
|
| 178 |
+
|
| 179 |
+
# 检查 flash-attn
|
| 180 |
+
python -c "from flash_attn import flash_attn_func; print('flash-attn OK')"
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
## 8. 参考资料
|
| 184 |
+
|
| 185 |
+
- [StreamPETR 官方仓库](https://github.com/exiawsh/StreamPETR)
|
| 186 |
+
- [Atlas 论文](https://arxiv.org/abs/2405.18361)
|
| 187 |
+
- [nuScenes 数据集](https://www.nuscenes.org/)
|
| 188 |
+
- [EVA-02 论文](https://arxiv.org/abs/2303.11331)
|
| 189 |
+
|
| 190 |
+
---
|
| 191 |
+
|
| 192 |
+
## 缺失项清单(需补齐,否则无法严格复现)
|
| 193 |
+
|
| 194 |
+
1. `external/StreamPETR` 的**准确 commit hash**(当前不是 git 仓库,无法校验版本)。
|
| 195 |
+
2. Atlas 论文中 StreamPETR 预训练的**实际梯度累积与有效 batch size**(官方配置注释为 bs=16,但未显式开启 accum)。
|
| 196 |
+
3. 论文实验所用的**完整依赖版本锁定**(除了版本矩阵以外,未提供 lockfile/requirements 固定依赖)。
|
| 197 |
+
|
| 198 |
+
---
|
| 199 |
+
|
| 200 |
+
**最后更新**: 2026-01-29
|
configs/ds_zero2.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {"enabled": true},
|
| 3 |
+
"zero_optimization": {
|
| 4 |
+
"stage": 2,
|
| 5 |
+
"allgather_partitions": true,
|
| 6 |
+
"allgather_bucket_size": 50000000,
|
| 7 |
+
"overlap_comm": false,
|
| 8 |
+
"reduce_scatter": true,
|
| 9 |
+
"reduce_bucket_size": 50000000,
|
| 10 |
+
"contiguous_gradients": true
|
| 11 |
+
},
|
| 12 |
+
"gradient_accumulation_steps": 2,
|
| 13 |
+
"gradient_clipping": 1.0,
|
| 14 |
+
"train_batch_size": 8,
|
| 15 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 16 |
+
"wall_clock_breakdown": false
|
| 17 |
+
}
|
eval_atlas.py
ADDED
|
@@ -0,0 +1,1175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import re
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Optional
|
| 10 |
+
from collections import Counter, defaultdict
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
| 16 |
+
|
| 17 |
+
from src.model.modeling_atlas import AtlasForCausalLM
|
| 18 |
+
from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
|
| 19 |
+
from src.model.streampetr_adapter import extract_streampetr_topk_tokens
|
| 20 |
+
from src.dataset.atlas_dataset import AtlasDataset, infer_task_type, make_atlas_collate_fn, load_tokenizer
|
| 21 |
+
from src.dataset.scene_sampler import SceneSequentialSampler
|
| 22 |
+
from src.prompting import PLANNING_TABLE3_MODES
|
| 23 |
+
from src.eval.metrics import (
|
| 24 |
+
parse_atlas_output,
|
| 25 |
+
parse_planning_output,
|
| 26 |
+
normalize_ground_truths,
|
| 27 |
+
calculate_detection_f1,
|
| 28 |
+
calculate_multi_threshold_detection_f1,
|
| 29 |
+
calculate_lane_detection_metrics,
|
| 30 |
+
calculate_planning_metrics,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger("eval_atlas")
|
| 34 |
+
|
| 35 |
+
# Matches one "[x, y, z]" coordinate triple (ints or decimals, optional sign).
_DET_POINT_BLOCK_RE = re.compile(
    r"\[\s*[-+]?\d+(?:\.\d+)?\s*,\s*[-+]?\d+(?:\.\d+)?\s*,\s*[-+]?\d+(?:\.\d+)?\s*\]"
)


def summarize_detection_parse(gen_text: str, det_preds: List[Dict]) -> Dict[str, object]:
    """Diagnose how well a generated detection answer was parsed.

    Compares the number of parsed detections (``det_preds``) against the
    number of raw "[x, y, z]" coordinate blocks visible in ``gen_text`` and
    classifies the outcome (clean parse / explicit negative / partial parse /
    outright failure).

    Returns a flat dict of ``detection_*`` diagnostic fields.
    """
    trimmed = gen_text.strip().rstrip(". \t\n")
    n_parsed = len(det_preds)
    n_raw_points = len(_DET_POINT_BLOCK_RE.findall(gen_text))
    # The model's canonical "nothing to report" sentence is not a failure.
    empty_negative = trimmed.lower() == "no objects detected within range"
    # More coordinate triples in the text than parsed objects => the parser
    # likely dropped some of them.
    partial = n_parsed > 0 and n_raw_points > n_parsed

    failed = n_parsed == 0 and not empty_negative
    if failed:
        reason = (
            "coordinates_present_but_unparsed" if n_raw_points > 0
            else "no_detection_pattern"
        )
    elif partial:
        reason = "partial_parse_suspected"
    else:
        reason = ""

    return {
        "detection_parse_failed": failed,
        "detection_parse_failure_reason": reason,
        "detection_partial_parse_suspected": bool(partial),
        "detection_raw_point_count": int(n_raw_points),
        "detection_parsed_count": int(n_parsed),
        "detection_is_empty_negative": empty_negative,
    }
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def summarize_lane_parse(gen_text: str, lane_preds: List[Dict]) -> Dict[str, object]:
    """Diagnose how well a generated lane-centerline answer was parsed.

    Compares parsed lanes/points (``lane_preds``) against the raw "[x, y, z]"
    blocks visible in ``gen_text`` and classifies the outcome, distinguishing
    the known legacy ``lane_centerline(...)`` format from other failures.

    Returns a flat dict of ``lane_*`` diagnostic fields.
    """
    trimmed = gen_text.strip().rstrip(". \t\n")
    lowered = trimmed.lower()
    n_raw_points = len(_DET_POINT_BLOCK_RE.findall(gen_text))
    n_lanes = len(lane_preds)
    n_lane_points = sum(len(lane.get("points", [])) for lane in lane_preds)
    # The canonical "nothing to report" sentence is not a failure.
    empty_negative = lowered == "no lane centerlines detected within range"
    # More coordinate triples in the text than points attributed to lanes
    # => the parser likely dropped some of them.
    partial = n_lanes > 0 and n_raw_points > n_lane_points

    failed = n_lanes == 0 and not empty_negative
    if failed:
        if lowered.startswith("lane_centerline("):
            reason = "legacy_lane_format"
        elif n_raw_points > 0:
            reason = "coordinates_present_but_unparsed"
        elif "lane:" in lowered:
            reason = "lane_prefix_without_valid_points"
        else:
            reason = "no_lane_pattern"
    elif partial:
        reason = "partial_parse_suspected"
    else:
        reason = ""

    return {
        "lane_parse_failed": failed,
        "lane_parse_failure_reason": reason,
        "lane_partial_parse_suspected": bool(partial),
        "lane_raw_point_count": int(n_raw_points),
        "lane_parsed_lane_count": int(n_lanes),
        "lane_parsed_point_count": int(n_lane_points),
        "lane_is_empty_negative": empty_negative,
    }
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _audit_lane_gt_point_counts(dataset, max_samples: int = 100):
    """Spot-check lane GT point counts; warn if median != 4 (non-canonical format).

    Samples up to ``max_samples`` lane-task items from ``dataset`` (reads the
    private ``dataset._task_types`` / ``dataset.data`` fields), parses each
    GT answer with parse_atlas_output(), and logs whether the median per-lane
    point count matches the canonical 4-point format. Pure logging audit —
    never raises, returns None.
    """
    lane_idx = [i for i, task in enumerate(dataset._task_types) if task == "lane"]
    if not lane_idx:
        return
    sampled = lane_idx[:max_samples]

    counts = []
    for i in sampled:
        conversations = dataset.data[i].get("conversations", [])
        # First assistant turn holds the GT answer text.
        gt_answer = next(
            (
                turn.get("value", "")
                for turn in conversations
                if turn.get("from") in ("gpt", "assistant")
            ),
            "",
        )
        counts.extend(
            len(lane.get("points", []))
            for lane in parse_atlas_output(gt_answer)
            if lane.get("points", [])
        )
    if not counts:
        return

    # Upper median for even-length samples; exact median is not needed here.
    median_pts = int(sorted(counts)[len(counts) // 2])
    if median_pts == 4:
        logger.info(
            "Lane GT point-count check OK: median=%d, sampled %d lanes from %d samples",
            median_pts, len(counts), len(sampled),
        )
    else:
        logger.warning(
            "Lane GT point-count median is %d (expected 4 for canonical _4pt format). "
            "Sampled %d lanes from %d samples. "
            "Check that --data_json points to the correct *_4pt.json files.",
            median_pts, len(counts), len(sampled),
        )
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def summarize_lane_gt_parse(gt_answer: str, gt_lanes: List[Dict]) -> Dict[str, object]:
    """Mirror of summarize_lane_parse() but for the GT side.

    The original implementation was a line-for-line copy of
    summarize_lane_parse() with "gt_"-prefixed result keys; keeping two
    copies lets the diagnostics silently drift apart. Delegate instead and
    rename the keys, so both sides always classify text identically.

    Returns the same fields as summarize_lane_parse(), each key prefixed
    with "gt_" (e.g. "gt_lane_parse_failed").
    """
    base = summarize_lane_parse(gt_answer, gt_lanes)
    return {f"gt_{key}": value for key, value in base.items()}
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def parse_args():
    """Build and parse the CLI for Atlas evaluation.

    Groups: trained checkpoint + LLM backbone, frozen encoder configs
    (StreamPETR for detection, TopoMLP for map), eval data, generation /
    loader settings, LoRA / quantization, and the visual-token source
    (online frozen encoders vs. precomputed offline tokens).
    """
    p = argparse.ArgumentParser()
    # --- trained Atlas checkpoint and LLM backbone ---
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--llm_model", default="lmsys/vicuna-7b-v1.5")
    p.add_argument("--visual_hidden_size", type=int, default=256)
    p.add_argument("--num_det_queries", type=int, default=256)
    p.add_argument("--num_map_queries", type=int, default=256)
    # --- frozen encoders (optional; only needed in online token mode) ---
    p.add_argument("--streampetr_config", default=None)
    p.add_argument("--streampetr_ckpt", default=None)
    p.add_argument("--topomlp_config", default=None)
    p.add_argument("--topomlp_ckpt", default=None)
    # --- evaluation data ---
    p.add_argument("--data_json", required=True)
    p.add_argument("--data_root", default="/mnt/data/nuscenes")
    # --- generation / dataloader settings ---
    p.add_argument("--max_length", type=int, default=4096)
    p.add_argument("--max_new_tokens", type=int, default=2700)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--num_workers", type=int, default=4)
    # --- LoRA / quantization (must match how the checkpoint was trained) ---
    p.add_argument("--use_lora", action="store_true")
    p.add_argument("--lora_r", type=int, default=64)
    p.add_argument("--lora_alpha", type=int, default=64)
    p.add_argument("--load_in_4bit", action="store_true")
    # --- output / sampling / precision ---
    p.add_argument("--output_json", default=None)
    p.add_argument("--max_samples", type=int, default=0)  # 0 = evaluate everything
    p.add_argument("--fp16", action="store_true")
    p.add_argument("--bf16", action="store_true")
    # --- visual token source: online encoders vs. precomputed dirs ---
    p.add_argument("--precomputed_det_tokens", default=None,
                   help="[offline only] Dir with precomputed det tokens (.pt files)")
    p.add_argument("--precomputed_map_tokens", default=None,
                   help="[offline only] Dir with precomputed TopoMLP map tokens (.pt files)")
    p.add_argument("--visual_token_mode", choices=("online", "offline"), default="online",
                   help="Visual token source: online=live frozen encoders (default), offline=read *_offline dirs")
    p.add_argument(
        "--planning_table3_mode",
        choices=PLANNING_TABLE3_MODES,
        default="atlas_base",
        help=(
            "Planning prompt variant matching Atlas Table 3: "
            "atlas_base=no command/no explicit ego state; "
            "atlas_high_level=requires top-level route_command "
            "(this repo uses a UniAD-style future-GT-derived command); "
            "atlas_high_level_ego=requires top-level route_command plus "
            "velocity/acceleration bins."
        ),
    )
    return p.parse_args()
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def infer_task(item: Dict) -> str:
    """Return the task-type string for one dataset item.

    Thin local alias for src.dataset.atlas_dataset.infer_task_type so eval
    code has a short module-level name; adds no logic of its own.
    """
    return infer_task_type(item)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def load_frozen_encoder(config_path, ckpt_path, model_type, device):
    """Build a frozen encoder ("streampetr" or "topomlp") from an mm-style
    config + checkpoint.

    Returns None when either path is missing (encoder not requested).
    Registers the matching external plugin repo on sys.path, builds the
    model, loads weights, moves it to ``device`` in eval mode with all
    gradients disabled, and returns it.

    Raises:
        RuntimeError: if mmcv/mmdet3d are absent or the plugin package
            cannot be imported while the encoder was explicitly requested.
    """
    if config_path is None or ckpt_path is None:
        return None
    try:
        from mmcv import Config
        from mmdet3d.models import build_model
        from mmcv.runner import load_checkpoint
    except ImportError:
        raise RuntimeError(
            f"mmcv/mmdet3d not installed but --{model_type}_config and "
            f"--{model_type}_ckpt were explicitly provided. "
            f"Install mmcv/mmdet3d or remove these arguments."
        )

    project_root = Path(__file__).resolve().parent
    if model_type == "streampetr":
        sp_root = str(project_root / "external" / "StreamPETR")
        if sp_root not in sys.path:
            sys.path.insert(0, sp_root)
        try:
            import projects.mmdet3d_plugin  # noqa: F401
        except ImportError:
            raise RuntimeError(
                f"StreamPETR plugin not found under {sp_root}/projects/mmdet3d_plugin. "
                f"Ensure the submodule is checked out, or remove --streampetr_config/--streampetr_ckpt."
            )
    elif model_type == "topomlp":
        tp_root = str(project_root / "external" / "TopoMLP_Repo")
        if tp_root not in sys.path:
            sys.path.insert(0, tp_root)
        try:
            os.environ["ATLAS_TOPOMLP_MODELS_ONLY"] = "1"
            # TopoMLP re-registers names that StreamPETR already registered;
            # temporarily force-register to tolerate the duplicates.
            from mmcv.utils import registry as _reg
            _orig = _reg.Registry._register_module

            def _tolerant_register(self, module, module_name=None, force=False):
                return _orig(self, module, module_name=module_name, force=True)

            _reg.Registry._register_module = _tolerant_register
            try:
                import projects.topomlp  # noqa: F401
            finally:
                # BUGFIX: restore the original register method even when the
                # plugin import fails; previously a failed import left the
                # global mmcv registry permanently force-registering.
                _reg.Registry._register_module = _orig
        except ImportError:
            raise RuntimeError(
                f"TopoMLP plugin not found under {tp_root}/projects/topomlp. "
                f"Ensure the submodule is checked out, or remove --topomlp_config/--topomlp_ckpt."
            )

    cfg = Config.fromfile(config_path)
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, ckpt_path, map_location="cpu")
    model.eval()
    model.to(device)
    for param in model.parameters():
        param.requires_grad_(False)
    logger.info("Loaded frozen %s from %s", model_type, ckpt_path)
    return model
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _run_streampetr_forward(model, imgs, img_metas, batch, device, prev_exists=None):
    """Run frozen StreamPETR forward. prev_exists controls temporal memory.

    Args:
        model: frozen StreamPETR model (needs extract_img_feat,
            prepare_location, forward_roi_head, pts_bbox_head).
        imgs: image tensor; first two dims are (batch, num_cams).
        img_metas: per-sample meta dicts passed through to the heads.
        batch: collated batch; optional calibration/pose keys are read
            ("intrinsics_det", "lidar2img_det", "ego_pose", "ego_pose_inv",
            "timestamp") and identity/zero fallbacks are used when absent.
        device: target device for the fallback tensors and batch tensors.
        prev_exists: per-sample flag tensor for StreamPETR's temporal memory;
            None defaults to zeros (treat every sample as a sequence start).

    Returns:
        Raw output dict of model.pts_bbox_head (shape/keys defined by
        StreamPETR, not inspected here).
    """
    B, N = imgs.shape[:2]

    img_feats = model.extract_img_feat(imgs, 1)

    data = {
        "img": imgs,
        "img_feats": img_feats,
        # new_zeros keeps device/dtype consistent with imgs.
        "prev_exists": prev_exists if prev_exists is not None else imgs.new_zeros(B),
    }

    # Pad the 3x3 camera intrinsics into 4x4 homogeneous form; identity when
    # calibration is missing from the batch.
    if "intrinsics_det" in batch:
        K3 = batch["intrinsics_det"].to(device)
        K4 = torch.zeros(B, N, 4, 4, device=device, dtype=K3.dtype)
        K4[:, :, :3, :3] = K3
        K4[:, :, 3, 3] = 1.0
        data["intrinsics"] = K4
    else:
        data["intrinsics"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()

    if "lidar2img_det" in batch:
        data["lidar2img"] = batch["lidar2img_det"].to(device)
    else:
        data["lidar2img"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()

    if "ego_pose" in batch and batch["ego_pose"] is not None:
        data["ego_pose"] = batch["ego_pose"].to(device)
    else:
        data["ego_pose"] = torch.eye(4, device=device).unsqueeze(0).expand(B, -1, -1).contiguous()

    # Prefer the precomputed inverse; otherwise invert whatever ego_pose we
    # just selected (identity inverts to identity).
    if "ego_pose_inv" in batch and batch["ego_pose_inv"] is not None:
        data["ego_pose_inv"] = batch["ego_pose_inv"].to(device)
    else:
        data["ego_pose_inv"] = torch.inverse(data["ego_pose"])

    if "timestamp" in batch and batch["timestamp"] is not None:
        data["timestamp"] = batch["timestamp"].to(device)
    else:
        data["timestamp"] = torch.zeros(B, device=device)

    # Standard StreamPETR pipeline: position encoding -> RoI head (which
    # selects top-k feature indexes) -> detection head.
    location = model.prepare_location(img_metas, **data)
    outs_roi = model.forward_roi_head(location, **data)
    topk_indexes = outs_roi["topk_indexes"]

    outs = model.pts_bbox_head(location, img_metas, topk_indexes, **data)
    return outs
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def _reconstruct_topomlp_outs(saved: dict, device, dtype):
|
| 338 |
+
"""Convert precomputed .pt dict back to the format adapter.forward() expects."""
|
| 339 |
+
def _restore(t):
|
| 340 |
+
return t.to(device=device, dtype=dtype).unsqueeze(0)
|
| 341 |
+
return {
|
| 342 |
+
"lc_outs_dec_list": [_restore(saved["lc_outs_dec"])],
|
| 343 |
+
"all_lc_cls_scores_list": [_restore(saved["lc_cls_scores"])],
|
| 344 |
+
"all_lc_preds_list": [_restore(saved["lc_preds"])],
|
| 345 |
+
"lc_outs_dec_one2many_list": [_restore(saved["lc_outs_dec_o2m"])],
|
| 346 |
+
"all_lc_cls_scores_one2many_list": [_restore(saved["lc_cls_scores_o2m"])],
|
| 347 |
+
"all_lc_preds_one2many_list": [_restore(saved["lc_preds_o2m"])],
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def extract_visual_tokens(
    streampetr_model, topomlp_model, topomlp_adapter,
    batch, device, num_det_queries, visual_hidden_size,
    visual_token_mode="online",
    streaming_state=None,
    query_token_id=None,
):
    """Extract det + map visual tokens (eval version).

    Mirrors train_atlas.extract_visual_tokens: in online mode with
    streaming_state, StreamPETR temporal memory is scene-aware and
    duplicate physical frames are protected. The needs_map gating
    skips TopoMLP when the current sample only requires det queries.

    Returns a dict with "detection"/"detection_ref_points" and, when a map
    adapter is configured, "map"/"map_ref_points".
    """
    # NOTE(review): the streaming/caching logic assumes batch_size 1 in
    # online mode (only element [0] of sample_id/scene_id/timestamp is
    # consulted) — confirm against the dataloader before raising batch_size.
    B = batch["pixel_values_det"].shape[0]
    N = batch["pixel_values_det"].shape[1]
    vis: Dict[str, torch.Tensor] = {}

    # A prompt with more <query> slots than det queries implies the sample
    # also consumes map tokens.
    needs_map = False
    if query_token_id is not None and "input_ids" in batch:
        n_queries = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item())
        needs_map = n_queries > num_det_queries

    # ---- Detection tokens ----
    if visual_token_mode == "offline" and "precomputed_det" in batch and "precomputed_det_ref" in batch:
        # Offline path: tokens were precomputed by the extraction scripts.
        vis["detection"] = batch["precomputed_det"].to(device)
        vis["detection_ref_points"] = batch["precomputed_det_ref"].to(device)
    elif visual_token_mode == "offline":
        # Fail loudly rather than silently degrading the eval with zeros.
        raise RuntimeError(
            "visual_token_mode=offline but detection precomputed tokens are missing "
            "for the current batch. Refusing to zero-fill."
        )
    elif streampetr_model is not None:
        current_sample_id = batch.get("sample_id", [None])[0]
        current_scene = batch.get("scene_id", ["__atlas_eval__"])[0]
        reuse_cache = False

        if streaming_state is not None:
            # Scene-aware streaming: decide whether this frame continues the
            # previous temporal segment or starts a new one.
            prev_scene = streaming_state.get("prev_scene_token")
            prev_sample_id = streaming_state.get("prev_sample_id")
            ts_tensor = batch.get("timestamp")
            current_ts = float(ts_tensor[0].item()) if ts_tensor is not None else None
            prev_ts = streaming_state.get("prev_timestamp")

            # New segment when the scene changed or time went backwards
            # (non-increasing timestamp => the loader wrapped/reordered).
            is_new_segment = (
                prev_scene is None
                or current_scene != prev_scene
                or (current_ts is not None and prev_ts is not None and current_ts <= prev_ts)
            )

            # Same physical frame as last step (e.g. several QA samples per
            # frame): reuse cached tokens instead of advancing StreamPETR's
            # temporal memory twice.
            if current_sample_id is not None and current_sample_id == prev_sample_id:
                cached = streaming_state.get("cached_det")
                if cached is not None:
                    reuse_cache = True
                    vis["detection"] = cached["detection"]
                    vis["detection_ref_points"] = cached["detection_ref_points"]

            if not reuse_cache:
                if is_new_segment:
                    streampetr_model.pts_bbox_head.reset_memory()
                prev_exists_val = 0.0 if is_new_segment else 1.0
                # new_full inherits device/dtype from the (possibly still-CPU)
                # pixel tensor — presumably the model tolerates this; verify
                # against _run_streampetr_forward's device expectations.
                prev_exists = batch["pixel_values_det"].new_full((B,), prev_exists_val)

                imgs_det = batch["pixel_values_det"].to(device)
                # Fixed StreamPETR input resolution used to build img metas.
                fH, fW = 800, 1600
                img_metas = [{
                    "pad_shape": [(fH, fW, 3)] * N,
                    "img_shape": [(fH, fW, 3)] * N,
                    "scene_token": current_scene,
                } for _ in range(B)]
                if "lidar2img_det" in batch:
                    for b in range(B):
                        img_metas[b]["lidar2img"] = batch["lidar2img_det"][b].cpu().numpy()
                with torch.no_grad():
                    _run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device, prev_exists=prev_exists)
                ego_pose_for_ref = batch.get("ego_pose")
                if ego_pose_for_ref is not None:
                    ego_pose_for_ref = ego_pose_for_ref.to(device)
                # Read top-k query tokens straight out of the head's state.
                det_out = extract_streampetr_topk_tokens(
                    streampetr_model.pts_bbox_head, topk=num_det_queries,
                    ego_pose=ego_pose_for_ref,
                )
                vis["detection"] = det_out["detection"]
                vis["detection_ref_points"] = det_out["detection_ref_points"]

                # Cache for potential duplicate-frame samples that follow.
                streaming_state["cached_det"] = {
                    "detection": vis["detection"],
                    "detection_ref_points": vis["detection_ref_points"],
                }

            # Record where we are for the next call (even on cache reuse).
            streaming_state["prev_scene_token"] = current_scene
            streaming_state["prev_sample_id"] = current_sample_id
            if batch.get("timestamp") is not None:
                streaming_state["prev_timestamp"] = float(batch["timestamp"][0].item())
        else:
            # Stateless path: no streaming_state, so temporal memory is reset
            # every call and prev_exists defaults to zeros.
            imgs_det = batch["pixel_values_det"].to(device)
            fH, fW = 800, 1600
            img_metas = [{
                "pad_shape": [(fH, fW, 3)] * N,
                "img_shape": [(fH, fW, 3)] * N,
                "scene_token": current_scene,
            } for _ in range(B)]
            if "lidar2img_det" in batch:
                for b in range(B):
                    img_metas[b]["lidar2img"] = batch["lidar2img_det"][b].cpu().numpy()
            with torch.no_grad():
                streampetr_model.pts_bbox_head.reset_memory()
                _run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device)
            ego_pose_for_ref = batch.get("ego_pose")
            if ego_pose_for_ref is not None:
                ego_pose_for_ref = ego_pose_for_ref.to(device)
            det_out = extract_streampetr_topk_tokens(
                streampetr_model.pts_bbox_head, topk=num_det_queries,
                ego_pose=ego_pose_for_ref,
            )
            vis["detection"] = det_out["detection"]
            vis["detection_ref_points"] = det_out["detection_ref_points"]
    elif visual_token_mode == "online":
        raise RuntimeError(
            "visual_token_mode=online but StreamPETR model is None. "
            "Provide --streampetr_config and --streampetr_ckpt."
        )
    else:
        # No encoder configured at all: placeholder zeros keep shapes valid.
        vis["detection"] = torch.zeros(B, num_det_queries, visual_hidden_size, device=device)
        vis["detection_ref_points"] = torch.zeros(B, num_det_queries, 3, device=device)

    # ---- Map tokens ----
    num_map_queries = num_det_queries
    if topomlp_adapter is not None:
        num_map_queries = topomlp_adapter.num_map_tokens

    map_filled = False
    if visual_token_mode == "offline" and topomlp_adapter is not None and "precomputed_map" in batch:
        # Match the adapter's own dtype (params first, buffers as fallback).
        _params = list(topomlp_adapter.parameters())
        _bufs = list(topomlp_adapter.buffers())
        adapter_dtype = _params[0].dtype if _params else (_bufs[0].dtype if _bufs else torch.float32)
        if B == 1:
            outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)
        else:
            # Rebuild each sample, then concatenate along the batch dim,
            # preserving the per-layer list structure.
            per_sample = [_reconstruct_topomlp_outs(batch["precomputed_map"][b], device, adapter_dtype) for b in range(B)]
            outs = {}
            for k in per_sample[0]:
                outs[k] = [torch.cat([s[k][i] for s in per_sample], dim=0) for i in range(len(per_sample[0][k]))]
        with torch.no_grad():
            map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True
    elif visual_token_mode == "offline" and topomlp_adapter is not None:
        raise RuntimeError(
            "visual_token_mode=offline but map precomputed tokens are missing "
            "for the current batch. Refusing to zero-fill."
        )
    elif needs_map and topomlp_model is not None and topomlp_adapter is not None:
        # Online TopoMLP forward, only when the prompt actually needs map
        # tokens (needs_map gating computed above).
        imgs_map = batch["pixel_values_map"].to(device)
        img_metas = []
        for b in range(B):
            meta = {"img_shape": [(800, 1600, 3)] * N, "pad_shape": [(800, 1600, 3)] * N}
            meta["scale_factor"] = 1.0
            meta["te_yolov8"] = None
            if "lidar2img_map" in batch:
                meta["lidar2img"] = batch["lidar2img_map"][b].cpu().numpy()
            img_metas.append(meta)
        with torch.no_grad():
            outs = topomlp_model.simple_forward(imgs_map, img_metas)
            map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True

    # Map slot must always exist when an adapter is configured; zero-fill
    # the skipped cases (e.g. det-only samples in online mode).
    if topomlp_adapter is not None and not map_filled:
        vis["map"] = torch.zeros(B, num_map_queries, visual_hidden_size, device=device)
        vis["map_ref_points"] = torch.zeros(B, num_map_queries, 3, device=device)

    return vis
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def parse_gt_from_item(item: Dict, task: str) -> Dict:
    """Extract the ground-truth payload for one dataset item.

    Task-dependent result:
      * "detection": {"detections": normalized GT boxes (category + center)}
      * "lane":      {"lanes": lanes parsed from the first assistant turn}
      * "planning":  {"waypoints", "gt_boxes"[, "gt_boxes_per_timestep"]}
    Unknown tasks yield an empty dict.
    """
    gt: Dict = {}
    if task == "detection":
        raw_anns = item.get("gt_boxes_3d", item.get("annotations", []))
        dets = []
        for ann in raw_anns:
            if not isinstance(ann, dict):
                continue
            # Box center may live under either key depending on the source.
            if "box" in ann:
                center = ann["box"][:3]
            elif "translation" in ann:
                center = ann["translation"][:3]
            else:
                continue
            dets.append({
                "category": ann.get("category_name", ann.get("category", "unknown")),
                "world_coords": list(center),
            })
        gt["detections"] = normalize_ground_truths(dets)
    elif task == "lane":
        # GT lanes are encoded in the first assistant turn's answer text.
        answer = ""
        for turn in item.get("conversations", []):
            if turn.get("from") in ("gpt", "assistant"):
                answer = turn.get("value", "")
                break
        gt["lanes"] = parse_atlas_output(answer)
    elif task == "planning":
        gt["waypoints"] = item.get("ego_motion", {}).get("waypoints", [])
        gt["gt_boxes"] = item.get("gt_boxes_3d", [])
        if "gt_boxes_3d_per_timestep" in item:
            gt["gt_boxes_per_timestep"] = item["gt_boxes_3d_per_timestep"]
    return gt
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def _check_task_dependencies(args, *, has_lane: bool, is_online: bool):
|
| 565 |
+
"""Check dependencies based on actual task distribution in the dataset."""
|
| 566 |
+
missing = []
|
| 567 |
+
if is_online:
|
| 568 |
+
for mod in ("mmcv", "mmdet3d"):
|
| 569 |
+
try:
|
| 570 |
+
__import__(mod)
|
| 571 |
+
except ImportError:
|
| 572 |
+
missing.append(mod)
|
| 573 |
+
if has_lane:
|
| 574 |
+
try:
|
| 575 |
+
__import__("openlanev2")
|
| 576 |
+
except ImportError:
|
| 577 |
+
missing.append("openlanev2 (needed for lane F-Score)")
|
| 578 |
+
if missing:
|
| 579 |
+
raise RuntimeError(
|
| 580 |
+
f"Missing dependencies for this eval run: {', '.join(missing)}. "
|
| 581 |
+
f"Install them before running eval_atlas.py."
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def main():
    """End-to-end Atlas evaluation entry point.

    Pipeline: parse args -> build tokenizer + Atlas model -> load the
    training checkpoint (rebuilding with LoRA if the checkpoint demands
    it) -> validate online/offline token assets -> run greedy generation
    over the eval dataset -> parse per-task predictions (detection /
    lane / planning) -> aggregate metrics -> dump a JSON report.
    """
    args = parse_args()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The <query> placeholder token marks where visual tokens are spliced
    # into the prompt; make sure the tokenizer knows it.
    tokenizer = load_tokenizer(args.llm_model)
    if "<query>" not in tokenizer.get_vocab():
        tokenizer.add_tokens(["<query>"])

    dtype = torch.float32
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16

    # 4-bit loading uses HF automatic device placement; otherwise we move
    # the model to `device` manually below.
    dm = "auto" if args.load_in_4bit else None
    atlas = AtlasForCausalLM(
        llm_model_name=args.llm_model,
        visual_hidden_size=args.visual_hidden_size,
        num_queries=args.num_det_queries,
        num_map_queries=args.num_map_queries,
        load_in_4bit=args.load_in_4bit,
        use_flash_attention=True,
        device_map=dm,
        torch_dtype=dtype,
        use_lora=args.use_lora,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
    )
    atlas.resize_token_embeddings(len(tokenizer))
    atlas.set_query_token_id(tokenizer.convert_tokens_to_ids("<query>"))
    if dm is None:
        atlas = atlas.to(device)

    topomlp_adapter = None

    ckpt = torch.load(args.checkpoint, map_location="cpu")

    # Fail fast on wrong/corrupt checkpoint files instead of silently
    # evaluating random weights.
    if "atlas_state_dict" not in ckpt:
        raise RuntimeError(
            f"Checkpoint missing 'atlas_state_dict'. "
            f"Top-level keys: {sorted(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt).__name__}. "
            f"Make sure --checkpoint points to an Atlas training checkpoint (checkpoint.pt), "
            f"not a frozen encoder weight (.pth)."
        )
    atlas_sd = ckpt["atlas_state_dict"]
    if not isinstance(atlas_sd, dict) or len(atlas_sd) == 0:
        raise RuntimeError(
            f"'atlas_state_dict' is empty or not a dict (type={type(atlas_sd).__name__}, "
            f"len={len(atlas_sd) if isinstance(atlas_sd, dict) else 'N/A'}). "
            f"Checkpoint is likely corrupted."
        )

    # Auto-detect LoRA: if checkpoint has LoRA keys but model doesn't have LoRA,
    # rebuild the model with LoRA enabled to prevent silent degradation.
    has_lora_keys = any("lora_" in k for k in atlas_sd)
    if has_lora_keys and not args.use_lora:
        logger.warning(
            "Checkpoint contains LoRA weights but --use_lora was not set. "
            "Auto-enabling LoRA to prevent silent degradation."
        )
        args.use_lora = True
        # Rebuild with the same settings as above, but with LoRA on.
        atlas = AtlasForCausalLM(
            llm_model_name=args.llm_model,
            visual_hidden_size=args.visual_hidden_size,
            num_queries=args.num_det_queries,
            num_map_queries=args.num_map_queries,
            load_in_4bit=args.load_in_4bit,
            use_flash_attention=True,
            device_map=dm,
            torch_dtype=dtype,
            use_lora=True,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
        )
        atlas.resize_token_embeddings(len(tokenizer))
        atlas.set_query_token_id(tokenizer.convert_tokens_to_ids("<query>"))
        if dm is None:
            atlas = atlas.to(device)

    # strict=False so extra keys are tolerated, but MISSING keys are a
    # hard error: they would mean partially-initialized weights.
    missing, unexpected = atlas.load_state_dict(atlas_sd, strict=False)
    if missing:
        raise RuntimeError(
            f"Atlas checkpoint is incomplete: {len(missing)} missing keys "
            f"(first 10: {missing[:10]}). This means the checkpoint does not "
            f"match the current model architecture. Refusing to evaluate with "
            f"partially-initialized weights."
        )
    if unexpected:
        logger.warning("Unexpected keys in checkpoint (possibly ignored): %d keys, first 5: %s",
                       len(unexpected), unexpected[:5])
    logger.info("Loaded Atlas weights from %s (%d keys, 0 missing)", args.checkpoint, len(atlas_sd))
    # Default BEV range; overridden from the TopoMLP config when available.
    _tp_bev_range = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0)
    if args.topomlp_config:
        try:
            from mmcv import Config as _Cfg
            _tp_cfg = _Cfg.fromfile(args.topomlp_config)
            if hasattr(_tp_cfg, "point_cloud_range"):
                _tp_bev_range = tuple(float(v) for v in _tp_cfg.point_cloud_range)
                logger.info("TopoMLP bev_range from config: %s", _tp_bev_range)
        except Exception as e:
            # Best-effort: fall back to the default range on any config error.
            logger.warning("Failed to read point_cloud_range from TopoMLP config: %s. Using default: %s", e, _tp_bev_range)

    # Build the map-token adapter whenever map tokens can be produced
    # (online TopoMLP assets or precomputed map tokens).
    if args.topomlp_config or args.topomlp_ckpt or args.precomputed_map_tokens:
        topomlp_adapter = TopoMLPToAtlasMapTokens(
            num_map_tokens=args.num_map_queries,
            hidden_size=args.visual_hidden_size,
            bev_range=_tp_bev_range,
        ).to(device)
        if "adapter_state_dict" in ckpt:
            topomlp_adapter.load_state_dict(ckpt["adapter_state_dict"], strict=False)
        topomlp_adapter.eval()

    atlas.eval()

    is_online = args.visual_token_mode == "online"
    # Precomputed token dirs are only meaningful in offline mode.
    _precomp_det = args.precomputed_det_tokens if not is_online else None
    _precomp_map = args.precomputed_map_tokens if not is_online else None

    dataset = AtlasDataset(
        json_file=args.data_json,
        image_root=args.data_root,
        tokenizer=tokenizer,
        max_length=args.max_length,
        is_training=False,
        planning_table3_mode=args.planning_table3_mode,
        precomputed_det_tokens=_precomp_det,
        precomputed_map_tokens=_precomp_map,
    )

    # Inspect the actual task mix to decide which assets/deps are required.
    _task_counts = Counter(dataset._task_types)
    has_lane = "lane" in _task_counts
    has_planning = "planning" in _task_counts
    needs_map_tokens = has_lane or has_planning
    logger.info("Task needs: has_lane=%s, has_planning=%s, needs_map=%s",
                has_lane, has_planning, needs_map_tokens)

    if has_lane:
        _audit_lane_gt_point_counts(dataset, max_samples=100)

    _check_task_dependencies(args, has_lane=has_lane, is_online=is_online)

    # Validate mode-specific arguments and assets before doing any work.
    if is_online:
        if args.precomputed_det_tokens or args.precomputed_map_tokens:
            raise RuntimeError(
                "visual_token_mode=online forbids --precomputed_* arguments."
            )
        if args.batch_size != 1:
            # Temporal memory in the streaming encoder assumes one sample
            # per step.
            raise RuntimeError(
                "visual_token_mode=online with temporal memory requires "
                "--batch_size 1. Got: %d" % args.batch_size
            )
        if not args.streampetr_config or not args.streampetr_ckpt:
            raise RuntimeError(
                "online mode requires --streampetr_config and --streampetr_ckpt"
            )
        for p in (args.streampetr_config, args.streampetr_ckpt):
            if not os.path.exists(p):
                raise RuntimeError(f"Required online asset does not exist: {p}")
        if needs_map_tokens:
            if not args.topomlp_config or not args.topomlp_ckpt:
                raise RuntimeError(
                    "online mode with lane/planning tasks requires "
                    "--topomlp_config and --topomlp_ckpt"
                )
            for p in (args.topomlp_config, args.topomlp_ckpt):
                if not os.path.exists(p):
                    raise RuntimeError(f"Required online asset does not exist: {p}")
    else:
        if not _precomp_det:
            raise RuntimeError(
                "offline mode requires --precomputed_det_tokens"
            )
        if needs_map_tokens and not _precomp_map:
            raise RuntimeError(
                "offline mode with lane/planning tasks requires "
                "--precomputed_map_tokens"
            )
        for p in (_precomp_det, _precomp_map):
            if p and not os.path.isdir(p):
                raise RuntimeError(f"Offline token directory does not exist: {p}")

    # NOTE(review): encoders are loaded unconditionally; presumably
    # load_frozen_encoder returns None when its config/ckpt args are
    # empty (offline mode) — confirm against its definition.
    streampetr_model = load_frozen_encoder(
        args.streampetr_config, args.streampetr_ckpt, "streampetr", device,
    )
    topomlp_model = load_frozen_encoder(
        args.topomlp_config, args.topomlp_ckpt, "topomlp", device,
    )

    # Online mode must visit frames scene-by-scene in order so the
    # streaming temporal memory stays consistent.
    if is_online:
        scene_groups = dataset.get_scene_groups()
        sampler = SceneSequentialSampler(scene_groups, shuffle_scenes=False)
        logger.info("Online eval: scene-sequential sampler (%d scenes)", len(scene_groups))
    else:
        sampler = None

    collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1 if is_online else args.batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )

    streaming_state = {} if is_online else None

    # Per-task accumulators for metric computation after the loop.
    task_preds: Dict[str, List] = defaultdict(list)
    task_gts: Dict[str, List] = defaultdict(list)
    all_outputs: List[Dict] = []
    sample_count = 0

    logger.info("Starting evaluation on %d samples (mode=%s)...", len(dataset), args.visual_token_mode)

    if is_online and streampetr_model is not None:
        streampetr_model.pts_bbox_head.reset_memory()

    for batch_idx, batch in enumerate(dataloader):
        if args.max_samples > 0 and sample_count >= args.max_samples:
            break

        B = batch["input_ids"].shape[0]
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        visual_features = extract_visual_tokens(
            streampetr_model, topomlp_model, topomlp_adapter,
            batch, device, args.num_det_queries, args.visual_hidden_size,
            visual_token_mode=args.visual_token_mode,
            streaming_state=streaming_state,
            query_token_id=atlas.query_token_id,
        )

        # Greedy decoding (do_sample=False) for deterministic eval output.
        with torch.no_grad():
            generated_ids = atlas.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                visual_features=visual_features,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
            )

        for b in range(B):
            if args.max_samples > 0 and sample_count >= args.max_samples:
                break

            gen_text_full = tokenizer.decode(generated_ids[b], skip_special_tokens=True)
            # Extract assistant response: take content after the last "ASSISTANT:" tag
            _split_tag = "ASSISTANT:"
            if _split_tag in gen_text_full:
                gen_text = gen_text_full.split(_split_tag)[-1].strip()
            else:
                gen_text = gen_text_full.strip()
            sample_id = batch["sample_id"][b] if "sample_id" in batch else str(sample_count)
            item_idx = int(batch["dataset_idx"][b].item()) if "dataset_idx" in batch else (batch_idx * B + b)
            item = dataset.data[item_idx]
            task = infer_task(item)
            gt = parse_gt_from_item(item, task)

            record = {
                "sample_id": sample_id,
                "task": task,
                "generated_text": gen_text,
            }

            if task == "detection":
                preds = parse_atlas_output(gen_text)
                det_preds = [p for p in preds if p.get("type") == "detection"]
                gt_dets = gt.get("detections", [])
                task_preds["detection"].append(det_preds)
                task_gts["detection"].append(gt_dets)
                record.update(summarize_detection_parse(gen_text, det_preds))
                record["num_preds"] = len(det_preds)
                record["num_gt"] = len(gt_dets)

            elif task == "lane":
                preds = parse_atlas_output(gen_text)
                lane_preds = [p for p in preds if p.get("type") == "lane"]
                gt_lanes = gt.get("lanes", [])
                task_preds["lane"].append(lane_preds)
                task_gts["lane"].append(gt_lanes)
                record.update(summarize_lane_parse(gen_text, lane_preds))
                record["num_preds"] = len(lane_preds)
                record["num_gt"] = len(gt_lanes)

                # Also audit how well the GT answer text itself parses.
                gt_answer_text = ""
                for _turn in item.get("conversations", []):
                    if _turn.get("from") in ("gpt", "assistant"):
                        gt_answer_text = _turn.get("value", "")
                        break
                record.update(summarize_lane_gt_parse(gt_answer_text, gt_lanes))

            elif task == "planning":
                # Loose parse accepts waypoints alone; strict parse requires
                # the full velocity/acceleration/path (VAP) output.
                loose_plan_pred = parse_planning_output(gen_text, require_full_vap=False)
                strict_plan_pred = parse_planning_output(gen_text, require_full_vap=True)
                gt_wps = gt.get("waypoints", [])
                gt_boxes = gt.get("gt_boxes", [])
                has_waypoints = bool(
                    loose_plan_pred is not None and "waypoints" in loose_plan_pred
                )
                has_velocity = bool(
                    loose_plan_pred is not None and "velocity_bins" in loose_plan_pred
                )
                has_acceleration = bool(
                    loose_plan_pred is not None and "acceleration_bins" in loose_plan_pred
                )
                vap_complete = bool(strict_plan_pred is not None)

                require_strict = args.planning_table3_mode == "atlas_high_level_ego"
                accepted_pred = strict_plan_pred if require_strict else loose_plan_pred
                parse_ok = accepted_pred is not None and "waypoints" in accepted_pred

                if parse_ok:
                    plan_pred = accepted_pred
                    record["planning_parse_failed"] = False
                    record["planning_vap_complete"] = vap_complete
                    record["planning_has_velocity"] = has_velocity
                    record["planning_has_acceleration"] = has_acceleration
                    record["planning_has_waypoints"] = True
                    if require_strict:
                        record["planning_parse_failure_reason"] = ""
                    else:
                        record["planning_parse_failure_reason"] = (
                            "" if vap_complete else "missing_velocity_or_acceleration"
                        )
                else:
                    # Unparseable plan: score a stationary fallback trajectory
                    # so the sample still counts against the metrics.
                    plan_pred = {"waypoints": [[0.0, 0.0]] * max(len(gt_wps), 6)}
                    record["planning_parse_failed"] = True
                    record["planning_vap_complete"] = False
                    record["planning_has_velocity"] = has_velocity
                    record["planning_has_acceleration"] = has_acceleration
                    record["planning_has_waypoints"] = has_waypoints
                    if require_strict and has_waypoints and not vap_complete:
                        record["planning_parse_failure_reason"] = (
                            "strict_mode_missing_velocity_or_acceleration"
                        )
                    elif has_waypoints:
                        record["planning_parse_failure_reason"] = "missing_velocity_or_acceleration"
                    else:
                        record["planning_parse_failure_reason"] = "unparseable_waypoints"
                task_preds["planning"].append(plan_pred)
                plan_gt_entry = {
                    "waypoints": gt_wps,
                    "gt_boxes": gt_boxes,
                }
                if "gt_boxes_per_timestep" in gt:
                    plan_gt_entry["gt_boxes_per_timestep"] = gt["gt_boxes_per_timestep"]
                task_gts["planning"].append(plan_gt_entry)
                record["has_plan"] = has_waypoints

            elif task == "caption":
                # Captions are generated but not scored by this script.
                record["skipped"] = True

            else:
                logger.warning("Unknown task %r for sample %s — skipping metrics", task, sample_id)
                record["skipped"] = True

            all_outputs.append(record)
            sample_count += 1

        if (batch_idx + 1) % 50 == 0:
            logger.info("Processed %d / %d samples", sample_count, len(dataset))

    logger.info("Evaluation complete. Total samples: %d", sample_count)

    _skipped = sum(1 for r in all_outputs if r.get("skipped"))
    if _skipped:
        logger.warning(
            "%d samples were not scored (caption or unknown task). "
            "These consumed GPU time but produced no metrics.",
            _skipped,
        )

    results = {}

    # ---- Detection: micro-averaged F1 at center-distance thresholds ----
    if task_preds["detection"] and task_gts["detection"]:
        thresholds = (0.5, 1.0, 2.0, 4.0)
        global_counts = {t: {"tp": 0, "fp": 0, "fn": 0} for t in thresholds}
        for s_preds, s_gts in zip(task_preds["detection"], task_gts["detection"]):
            for t in thresholds:
                m = calculate_detection_f1(s_preds, s_gts, threshold=t)
                global_counts[t]["tp"] += m["tp"]
                global_counts[t]["fp"] += m["fp"]
                global_counts[t]["fn"] += m["fn"]
        det_results = {}
        for t in thresholds:
            tp, fp, fn = global_counts[t]["tp"], global_counts[t]["fp"], global_counts[t]["fn"]
            p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
            det_results[f"F1@{t}m"] = round(f1, 4)
            det_results[f"P@{t}m"] = round(p, 4)
            det_results[f"R@{t}m"] = round(r, 4)
        # Parse-quality diagnostics alongside the accuracy numbers.
        n_det_total = len(task_preds["detection"])
        n_det_failed = sum(
            1
            for r in all_outputs
            if r.get("task") == "detection" and r.get("detection_parse_failed", False)
        )
        n_det_partial = sum(
            1
            for r in all_outputs
            if r.get("task") == "detection" and r.get("detection_partial_parse_suspected", False)
        )
        n_det_empty_negative = sum(
            1
            for r in all_outputs
            if r.get("task") == "detection" and r.get("detection_is_empty_negative", False)
        )
        det_results["num_samples"] = n_det_total
        det_results["parse_fail_count"] = n_det_failed
        det_results["parse_fail_rate"] = n_det_failed / max(n_det_total, 1)
        det_results["partial_parse_count"] = n_det_partial
        det_results["partial_parse_rate"] = n_det_partial / max(n_det_total, 1)
        det_results["empty_negative_count"] = n_det_empty_negative
        det_results["empty_negative_rate"] = n_det_empty_negative / max(n_det_total, 1)
        results["detection"] = det_results
        logger.info("Detection results (micro-averaged):")
        for k, v in sorted(results["detection"].items()):
            if isinstance(v, float):
                logger.info(" %s: %.4f", k, v)

    # ---- Lane: official OpenLane-V2 F-Score ----
    if task_preds["lane"] and task_gts["lane"]:
        try:
            from openlanev2.evaluation.f_score import LaneEval
        except ImportError:
            raise RuntimeError(
                "openlanev2 is required for lane evaluation but could not be imported. "
                "Install it with: pip install openlanev2"
            )
        _lane_evaluator = LaneEval()

        def _lanes_to_ndarray_list(lanes):
            # Convert parsed lane dicts to the (N, 3) float arrays the
            # evaluator expects; lanes with < 2 points are dropped.
            out = []
            for lane in lanes:
                pts = lane.get("points", [])
                if not pts:
                    continue
                rows = []
                for pt in pts:
                    if isinstance(pt, dict):
                        rows.append(pt.get("world_coords", [0, 0, 0])[:3])
                    else:
                        rows.append(list(pt)[:3])
                arr = np.array(rows, dtype=np.float64)
                if arr.shape[0] >= 2:
                    out.append(arr)
            return out

        stats = []
        for pl, gl in zip(task_preds["lane"], task_gts["lane"]):
            pa = _lanes_to_ndarray_list(pl)
            ga = _lanes_to_ndarray_list(gl)
            # All lanes share a single dummy category (1).
            pc = [np.int8(1)] * len(pa)
            gc = [np.int8(1)] * len(ga)
            r, p, c, ng, np_, mn = _lane_evaluator.bench(pa, pc, ga, gc)
            stats.append(np.array([r, p, c, ng, np_, mn]))
        if stats:
            s = np.array(stats)
            tg = np.sum(s[:, 3])
            tp_sum = np.sum(s[:, 4])
            lane_r = float(np.sum(s[:, 0]) / max(tg, 1e-6))
            lane_p = float(np.sum(s[:, 1]) / max(tp_sum, 1e-6))
            lane_f1 = 2 * lane_p * lane_r / (lane_p + lane_r) if (lane_p + lane_r) > 0 else 0.0
        else:
            lane_p = lane_r = lane_f1 = 0.0
        results["lane"] = {
            "lane_precision": round(lane_p, 4),
            "lane_recall": round(lane_r, 4),
            "lane_f1": round(lane_f1, 4),
            "method": "openlanev2_f_score",
        }
        # Parse diagnostics for the model's lane output...
        n_lane_total = len(task_preds["lane"])
        n_lane_failed = sum(
            1
            for r in all_outputs
            if r.get("task") == "lane" and r.get("lane_parse_failed", False)
        )
        n_lane_partial = sum(
            1
            for r in all_outputs
            if r.get("task") == "lane" and r.get("lane_partial_parse_suspected", False)
        )
        n_lane_empty_negative = sum(
            1
            for r in all_outputs
            if r.get("task") == "lane" and r.get("lane_is_empty_negative", False)
        )
        results["lane"]["num_samples"] = n_lane_total
        results["lane"]["parse_fail_count"] = n_lane_failed
        results["lane"]["parse_fail_rate"] = n_lane_failed / max(n_lane_total, 1)
        results["lane"]["partial_parse_count"] = n_lane_partial
        results["lane"]["partial_parse_rate"] = n_lane_partial / max(n_lane_total, 1)
        results["lane"]["empty_negative_count"] = n_lane_empty_negative
        results["lane"]["empty_negative_rate"] = n_lane_empty_negative / max(n_lane_total, 1)

        # ...and for the GT answer text (catches label-format drift).
        n_gt_lane_failed = sum(
            1
            for r in all_outputs
            if r.get("task") == "lane" and r.get("gt_lane_parse_failed", False)
        )
        n_gt_lane_partial = sum(
            1
            for r in all_outputs
            if r.get("task") == "lane" and r.get("gt_lane_partial_parse_suspected", False)
        )
        n_gt_lane_empty_negative = sum(
            1
            for r in all_outputs
            if r.get("task") == "lane" and r.get("gt_lane_is_empty_negative", False)
        )
        results["lane"]["gt_parse_fail_count"] = n_gt_lane_failed
        results["lane"]["gt_parse_fail_rate"] = n_gt_lane_failed / max(n_lane_total, 1)
        results["lane"]["gt_partial_parse_count"] = n_gt_lane_partial
        results["lane"]["gt_partial_parse_rate"] = n_gt_lane_partial / max(n_lane_total, 1)
        results["lane"]["gt_empty_negative_count"] = n_gt_lane_empty_negative
        results["lane"]["gt_empty_negative_rate"] = n_gt_lane_empty_negative / max(n_lane_total, 1)

        logger.info("Lane results (OpenLane-V2 official F-Score):")
        for k, v in sorted(results["lane"].items()):
            if isinstance(v, float):
                logger.info(" %s: %.4f", k, v)
            else:
                logger.info(" %s: %s", k, v)

    # ---- Planning ----
    if task_preds["planning"] and task_gts["planning"]:
        results["planning"] = calculate_planning_metrics(
            task_preds["planning"], task_gts["planning"],
        )
        n_plan_total = len(task_preds["planning"])
        n_plan_failed = sum(
            1
            for r in all_outputs
            if r.get("task") == "planning" and r.get("planning_parse_failed", False)
        )
        n_plan_vap_complete = sum(
            1
            for r in all_outputs
            if r.get("task") == "planning" and r.get("planning_vap_complete", False)
        )
        n_plan_missing_va = sum(
            1
            for r in all_outputs
            if r.get("task") == "planning"
            and (not r.get("planning_parse_failed", False))
            and (not r.get("planning_vap_complete", True))
        )
        results["planning"]["num_samples"] = n_plan_total
        results["planning"]["parse_fail_count"] = n_plan_failed
        results["planning"]["parse_fail_rate"] = (
            n_plan_failed / max(n_plan_total, 1)
        )
        results["planning"]["vap_complete_count"] = n_plan_vap_complete
        results["planning"]["vap_complete_rate"] = (
            n_plan_vap_complete / max(n_plan_total, 1)
        )
        results["planning"]["missing_velocity_or_acceleration_count"] = n_plan_missing_va
        results["planning"]["missing_velocity_or_acceleration_rate"] = (
            n_plan_missing_va / max(n_plan_total, 1)
        )
        logger.info("Planning results:")
        for k, v in sorted(results["planning"].items()):
            if isinstance(v, float):
                logger.info(" %s: %.4f", k, v)
            else:
                logger.info(" %s: %s", k, v)

    # Default report location: next to the checkpoint.
    output_path = args.output_json
    if output_path is None:
        ckpt_dir = Path(args.checkpoint).parent
        output_path = str(ckpt_dir / "eval_results.json")

    # Only the first 100 per-sample records are persisted to keep the
    # report small.
    with open(output_path, "w") as f:
        json.dump({
            "metrics": results,
            "num_samples": sample_count,
            "args": vars(args),
            "predictions": all_outputs[:100],
        }, f, indent=2, ensure_ascii=False)
    logger.info("Results saved to %s", output_path)
|
| 1171 |
+
|
| 1172 |
+
|
| 1173 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
| 1175 |
+
|
extract_streampetr_tokens.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Pre-extract frozen StreamPETR detection tokens for offline Atlas training.
|
| 3 |
+
|
| 4 |
+
Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
|
| 5 |
+
For online training (default), use: bash scripts/train_no_caption_baseline.sh
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
# Safety gate: this offline extractor refuses to run unless the user
# explicitly opts in via the ATLAS_ALLOW_OFFLINE environment variable,
# steering people toward the online training path by default.
if os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower() not in ("1", "true", "yes"):
    print(
        "ERROR: This is an OFFLINE token extraction script.\n"
        "It is isolated by default to prevent accidental use.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For online training use: bash scripts/train_no_caption_baseline.sh",
        file=sys.stderr,
    )
    sys.exit(1)
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import re
|
| 24 |
+
import time
|
| 25 |
+
from collections import defaultdict
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import numpy as np
|
| 30 |
+
|
| 31 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
| 32 |
+
|
| 33 |
+
from src.model.streampetr_adapter import extract_streampetr_topk_tokens
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def parse_args():
    """Parse command-line options for offline StreamPETR token extraction."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--streampetr_config", required=True)
    parser.add_argument("--streampetr_ckpt", required=True)
    parser.add_argument("--data_json", required=True)
    parser.add_argument("--data_root", default="/home/guoyuanbo/autodl-tmp/data/nuscenes")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--topk", type=int, default=256)
    parser.add_argument("--image_path_remap", default=None)
    parser.add_argument("--shard_id", type=int, default=0)
    parser.add_argument("--num_shards", type=int, default=1)
    return parser.parse_args()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_streampetr(config_path, ckpt_path, device):
    """Build a frozen, eval-mode StreamPETR model on ``device``.

    Makes the vendored StreamPETR repo importable, registers its mmdet3d
    plugin, builds the model from ``config_path``, loads ``ckpt_path`` and
    disables gradients on every parameter (extraction only).
    """
    repo_dir = str(Path(__file__).resolve().parent / "external" / "StreamPETR")
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)
    import projects.mmdet3d_plugin  # noqa: F401  (side effect: registers modules)
    from mmcv import Config
    from mmdet3d.models import build_model
    from mmcv.runner import load_checkpoint

    cfg = Config.fromfile(config_path)
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, ckpt_path, map_location="cpu")
    model.eval()
    model.to(device)
    # Freeze everything: this model only produces tokens, it is never trained.
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _parse_openlane_scene_timestamp(item):
|
| 70 |
+
sample_id = str(item.get("id", ""))
|
| 71 |
+
m = re.match(r"openlane_subsetB_(?:train|val)_(.+?)_(\d+)$", sample_id)
|
| 72 |
+
if m is not None:
|
| 73 |
+
return f"openlane_{m.group(1)}", int(m.group(2))
|
| 74 |
+
|
| 75 |
+
image_paths = item.get("image_paths", [])
|
| 76 |
+
if image_paths:
|
| 77 |
+
p0 = str(image_paths[0]).replace("\\", "/")
|
| 78 |
+
m2 = re.search(r"/(?:train|val)/([^/]+)/image/[^/]+/(\d+)\.(?:jpg|jpeg|png)$", p0, flags=re.IGNORECASE)
|
| 79 |
+
if m2 is not None:
|
| 80 |
+
return f"openlane_{m2.group(1)}", int(m2.group(2))
|
| 81 |
+
|
| 82 |
+
return None, None
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def build_scene_order(data_items, data_root):
    """Group dataset items into temporally ordered scenes.

    nuScenes samples are grouped by scene_token and sorted by timestamp using
    sample.json; OpenLane samples fall back to id/path parsing; anything else
    becomes a singleton scene. Returns a list of scenes, each a list of
    indices into ``data_items``.
    """
    root = Path(data_root)
    samples_meta = {}
    for version in ("v1.0-trainval", "v1.0-mini", "v1.0-test"):
        candidate = root / version / "sample.json"
        if candidate.exists():
            with open(candidate) as fh:
                samples_meta = {rec["token"]: rec for rec in json.load(fh)}
            break
    else:
        print("WARNING: sample.json not found — using OpenLane/id-based scene ordering where possible")

    counts = {"nuscenes": 0, "openlane": 0, "unknown": 0}
    scene_map = defaultdict(list)
    for idx, item in enumerate(data_items):
        meta = samples_meta.get(str(item.get("id", "")))
        if meta is not None:
            key = meta.get("scene_token", f"_nus_unknown_{idx}")
            scene_map[key].append((int(meta.get("timestamp", 0)), idx))
            counts["nuscenes"] += 1
            continue

        key, ts = _parse_openlane_scene_timestamp(item)
        if key is not None:
            scene_map[key].append((int(ts), idx))
            counts["openlane"] += 1
        else:
            # Unmatched item: isolate it in its own single-frame scene.
            scene_map[f"_unknown_{idx}"].append((0, idx))
            counts["unknown"] += 1

    scenes = [
        [idx for _, idx in sorted(scene_map[key], key=lambda pair: pair[0])]
        for key in sorted(scene_map)
    ]

    total = sum(len(s) for s in scenes)
    print(
        "Scene grouping: %d scenes, %d samples (nuScenes=%d, OpenLane=%d, unknown=%d)"
        % (len(scenes), total, counts["nuscenes"], counts["openlane"], counts["unknown"])
    )
    return scenes
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def load_and_preprocess_images(item, data_root, image_path_remap, streampetr_conf, image_transform):
    """Load one sample's multi-camera images and build StreamPETR inputs.

    Resizes/crops every camera image to ``streampetr_conf["final_dim"]``,
    adjusts intrinsics accordingly and assembles per-camera lidar2img
    matrices. Camera parameters come from the nuScenes calibration tables
    (``streampetr_conf["_calibration"]``) when a lookup succeeds, otherwise
    from ``item["sensor"]``. Returns batched (B=1) tensors; ``ego_pose``,
    ``ego_pose_inv`` and ``timestamp`` may be None when unrecoverable.
    """
    from PIL import Image
    import torchvision.transforms as transforms  # noqa: F401

    # Target (height, width) that StreamPETR expects.
    fH, fW = streampetr_conf["final_dim"]
    images = []
    intrinsics_list = []
    extrinsics_list = []
    lidar2img_list = []

    # Optional nuScenes lookup tables (see load_nuscenes_calibration).
    calibration = streampetr_conf.get("_calibration")
    # nuScenes camera order; used as a positional fallback when the camera
    # name cannot be found inside the image path.
    cam_names = [
        'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
        'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT',
    ]

    # Filled once, from the first camera whose calibration lookup succeeds.
    ego_pose_out = None
    ego_pose_inv_out = None
    timestamp_out = None

    for i, img_path in enumerate(item["image_paths"]):
        # Identify the camera: prefer a name embedded in the path; longest
        # match first so e.g. CAM_FRONT_LEFT wins over CAM_FRONT.
        camera_name = cam_names[i] if i < len(cam_names) else f"CAM_{i}"
        for cam in sorted(cam_names, key=len, reverse=True):
            if cam in str(img_path):
                camera_name = cam
                break

        # Apply the first matching old-prefix -> new-prefix remap.
        remapped = img_path
        for old_prefix, new_prefix in image_path_remap.items():
            if remapped.startswith(old_prefix):
                remapped = new_prefix + remapped[len(old_prefix):]
                break

        if os.path.isabs(remapped):
            full_path = remapped
        else:
            full_path = os.path.normpath(os.path.join(data_root, remapped))

        img = Image.open(full_path).convert("RGB")
        W_orig, H_orig = img.size

        # Scale so both target dims are covered, then crop: bottom-anchored
        # vertically (crop_h removes the top band), centered horizontally.
        resize = max(fH / H_orig, fW / W_orig)
        rW, rH = int(W_orig * resize), int(H_orig * resize)
        crop_h = rH - fH
        crop_w = max(0, rW - fW) // 2
        if resize != 1.0:
            img = img.resize((rW, rH), Image.BILINEAR)
        img = img.crop((crop_w, crop_h, crop_w + fW, crop_h + fH))

        K = None  # 3x3 camera intrinsics
        E = None  # 4x4 camera->ego extrinsics
        ep_rec = None
        sd_rec = None

        if calibration is not None:
            # NOTE(review): lstrip("./") strips *characters*, so a leading
            # "/" of an absolute path would also be removed — confirm that
            # image paths here are always relative.
            norm_path = str(img_path).replace("\\", "/").lstrip("./")
            sd_rec = None
            for cand in [img_path, norm_path]:
                sd_rec = calibration["sample_data_by_filename"].get(cand)
                if sd_rec:
                    break
            if sd_rec is None:
                # Last resort: match on the "samples/..." / "sweeps/..." suffix.
                for key in ("samples/", "sweeps/"):
                    if key in norm_path:
                        sd_rec = calibration["sample_data_by_filename"].get(norm_path[norm_path.index(key):])
                        if sd_rec:
                            break

            if sd_rec is not None:
                cs = calibration["calibrated_sensor_by_token"].get(sd_rec.get("calibrated_sensor_token"))
                ep_rec = calibration["ego_pose_by_token"].get(sd_rec.get("ego_pose_token"))
                if cs is not None:
                    K = np.array(cs["camera_intrinsic"], dtype=np.float32)
                    q = cs["rotation"]
                    t = cs["translation"]
                    # Quaternion (w, x, y, z) -> rotation matrix.
                    w, x, y, z = q
                    R = np.array([
                        [1 - 2*(y*y + z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                        [2*(x*y + z*w), 1 - 2*(x*x + z*z), 2*(y*z - x*w)],
                        [2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x*x + y*y)],
                    ], dtype=np.float32)
                    E = np.eye(4, dtype=np.float32)
                    E[:3, :3] = R
                    E[:3, 3] = np.array(t, dtype=np.float32)

        # Record the ego pose once, from the first camera that yielded one.
        if ego_pose_out is None and ep_rec is not None:
            q_ep = ep_rec.get("rotation")
            t_ep = ep_rec.get("translation")
            if q_ep is not None and t_ep is not None:
                w, x, y, z = q_ep
                R_ep = np.array([
                    [1 - 2*(y*y + z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
                    [2*(x*y + z*w), 1 - 2*(x*x + z*z), 2*(y*z - x*w)],
                    [2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x*x + y*y)],
                ], dtype=np.float32)
                ego_m = np.eye(4, dtype=np.float32)
                ego_m[:3, :3] = R_ep
                ego_m[:3, 3] = np.array(t_ep, dtype=np.float32)
                ego_pose_out = torch.tensor(ego_m, dtype=torch.float32)
                try:
                    ego_pose_inv_out = torch.tensor(np.linalg.inv(ego_m), dtype=torch.float32)
                except Exception:
                    # Singular pose matrix: leave the inverse unset.
                    ego_pose_inv_out = None

        # Record the timestamp once (nuScenes stores microseconds).
        if timestamp_out is None and sd_rec is not None:
            ts = sd_rec.get("timestamp")
            if ts is not None:
                timestamp_out = torch.tensor(float(ts) * 1e-6, dtype=torch.float32)

        # Fallback camera parameters from the item's own "sensor" block.
        if K is None or E is None:
            sensor = (item or {}).get("sensor", None) if isinstance(item, dict) else None
            if isinstance(sensor, dict) and camera_name in sensor:
                cam_s = sensor[camera_name]
                K = np.array(cam_s["intrinsic"]["K"], dtype=np.float32)
                R = np.array(cam_s["extrinsic"]["rotation"], dtype=np.float32)
                t = np.array(cam_s["extrinsic"]["translation"], dtype=np.float32)
                E = np.eye(4, dtype=np.float32)
                E[:3, :3] = R
                E[:3, 3] = t

        if K is None or E is None:
            raise RuntimeError(f"no camera params for {img_path} (camera={camera_name})")

        # Adjust intrinsics for the resize + crop applied above.
        K_adj = K.copy()
        K_adj[0, 0] *= resize
        K_adj[1, 1] *= resize
        K_adj[0, 2] = K_adj[0, 2] * resize - crop_w
        K_adj[1, 2] = K_adj[1, 2] * resize - crop_h

        images.append(image_transform(img))
        intrinsics_list.append(torch.tensor(K_adj, dtype=torch.float32))
        extrinsics_list.append(torch.tensor(E, dtype=torch.float32))

        cam2ego = E.astype(np.float32)
        ego2cam = np.linalg.inv(cam2ego)
        K4 = np.eye(4, dtype=np.float32)
        K4[:3, :3] = K_adj.astype(np.float32)

        def _qR(qwxyz):
            # Quaternion (w, x, y, z) -> 3x3 rotation matrix.
            ww, xx, yy, zz = qwxyz
            return np.array([
                [1-2*(yy*yy+zz*zz), 2*(xx*yy-zz*ww), 2*(xx*zz+yy*ww)],
                [2*(xx*yy+zz*ww), 1-2*(xx*xx+zz*zz), 2*(yy*zz-xx*ww)],
                [2*(xx*zz-yy*ww), 2*(yy*zz+xx*ww), 1-2*(xx*xx+yy*yy)],
            ], dtype=np.float32)

        def _T(Rm, tv):
            # Assemble a 4x4 homogeneous transform from rotation + translation.
            T = np.eye(4, dtype=np.float32)
            T[:3, :3] = Rm
            T[:3, 3] = np.array(tv, dtype=np.float32)
            return T

        # Default projection treats the lidar frame as the ego frame; refine
        # below with the true lidar->ego chain when full calibration exists.
        lidar2img_mat = K4 @ ego2cam
        if calibration is not None and sd_rec is not None and ep_rec is not None:
            sample_tk = sd_rec.get("sample_token")
            if sample_tk:
                ego2global_c = _T(_qR(ep_rec["rotation"]), ep_rec["translation"])
                global2ego_c = np.linalg.inv(ego2global_c)
                lidar_sd = calibration.get("lidar_sd_by_sample_token", {}).get(str(sample_tk))
                if lidar_sd is not None:
                    lidar_cs = calibration["calibrated_sensor_by_token"].get(lidar_sd.get("calibrated_sensor_token"))
                    lidar_ep = calibration["ego_pose_by_token"].get(lidar_sd.get("ego_pose_token"))
                    if lidar_cs is not None and lidar_ep is not None:
                        lidar2ego = _T(_qR(lidar_cs["rotation"]), lidar_cs["translation"])
                        ego2global_l = _T(_qR(lidar_ep["rotation"]), lidar_ep["translation"])
                        # lidar -> ego(lidar ts) -> global -> ego(cam ts) -> cam
                        lidar2cam = ego2cam @ global2ego_c @ ego2global_l @ lidar2ego
                        lidar2img_mat = K4 @ lidar2cam

        lidar2img_list.append(torch.tensor(lidar2img_mat, dtype=torch.float32))

    # Fallback: if nuScenes calibration lookup failed (e.g. OpenLane samples),
    # recover ego_pose from item["pose"] and timestamp from item["timestamp"].
    if ego_pose_out is None and isinstance(item, dict):
        pose_data = item.get("pose", None)
        if isinstance(pose_data, dict):
            try:
                rot_raw = pose_data.get("rotation", None)
                t_p = pose_data.get("translation", None)
                if rot_raw is not None and t_p is not None:
                    arr = np.array(rot_raw, dtype=np.float32)
                    if arr.shape == (3, 3):
                        R_p = arr
                    elif arr.shape == (4,):
                        # Treated as a quaternion in (w, x, y, z) order.
                        w, x, y, z = arr
                        R_p = np.array([
                            [1-2*(y*y+z*z), 2*(x*y-z*w), 2*(x*z+y*w)],
                            [2*(x*y+z*w), 1-2*(x*x+z*z), 2*(y*z-x*w)],
                            [2*(x*z-y*w), 2*(y*z+x*w), 1-2*(x*x+y*y)],
                        ], dtype=np.float32)
                    else:
                        raise ValueError(f"Unsupported rotation shape: {arr.shape}")
                    T_p = np.eye(4, dtype=np.float32)
                    T_p[:3, :3] = R_p
                    T_p[:3, 3] = np.array(t_p, dtype=np.float32)
                    ego_pose_out = torch.tensor(T_p, dtype=torch.float32)
                    try:
                        ego_pose_inv_out = torch.tensor(np.linalg.inv(T_p), dtype=torch.float32)
                    except Exception:
                        ego_pose_inv_out = None
            except Exception as e:
                # Best-effort: a malformed pose only costs the ego-pose tensors.
                print(f"WARNING: Failed to parse item['pose']: {e}")

    if timestamp_out is None and isinstance(item, dict):
        ts_raw = item.get("timestamp", None)
        if ts_raw is not None:
            try:
                # assumes item["timestamp"] is in microseconds — TODO confirm
                timestamp_out = torch.tensor(float(ts_raw) * 1e-6, dtype=torch.float32)
            except Exception:
                pass

    return {
        "pixel_values_det": torch.stack(images).unsqueeze(0),
        "intrinsics_det": torch.stack(intrinsics_list).unsqueeze(0),
        "lidar2img_det": torch.stack(lidar2img_list).unsqueeze(0),
        "ego_pose": ego_pose_out.unsqueeze(0) if ego_pose_out is not None else None,
        "ego_pose_inv": ego_pose_inv_out.unsqueeze(0) if ego_pose_inv_out is not None else None,
        "timestamp": timestamp_out.unsqueeze(0) if timestamp_out is not None else None,
    }
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def load_nuscenes_calibration(data_root):
    """Load nuScenes calibration tables into lookup dicts, or None if absent.

    Returns sample_data keyed by filename, calibrated_sensor and ego_pose
    keyed by token, and the key-frame LIDAR_TOP sample_data keyed by
    sample_token. Requires the four metadata json files under one of the
    known nuScenes version directories.
    """
    root = Path(data_root)
    version_dir = next(
        (root / v for v in ("v1.0-trainval", "v1.0-mini", "v1.0-test") if (root / v).exists()),
        None,
    )
    if version_dir is None:
        return None

    required = ("sample_data.json", "calibrated_sensor.json", "ego_pose.json", "sample.json")
    if not all((version_dir / name).exists() for name in required):
        return None

    def _read(name):
        # Small helper: parse one metadata table.
        with open(version_dir / name) as fh:
            return json.load(fh)

    sample_data = _read("sample_data.json")
    cs_by_tok = {r["token"]: r for r in _read("calibrated_sensor.json")}
    ep_by_tok = {r["token"]: r for r in _read("ego_pose.json")}
    sd_by_fn = {r["filename"]: r for r in sample_data if "filename" in r}

    # One key-frame LIDAR_TOP record per sample (first wins).
    lidar_sd_by_sample = {}
    for rec in sample_data:
        fname = str(rec.get("filename", "")).replace("\\", "/")
        if "/LIDAR_TOP/" in fname and fname.startswith("samples/") and rec.get("is_key_frame"):
            token = rec.get("sample_token")
            if token:
                lidar_sd_by_sample.setdefault(str(token), rec)

    print(f"Calibration loaded: {len(sd_by_fn)} sample_data, {len(cs_by_tok)} cal_sensor, {len(ep_by_tok)} ego_pose, {len(lidar_sd_by_sample)} lidar_kf")
    return {
        "sample_data_by_filename": sd_by_fn,
        "calibrated_sensor_by_token": cs_by_tok,
        "ego_pose_by_token": ep_by_tok,
        "lidar_sd_by_sample_token": lidar_sd_by_sample,
    }
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
@torch.no_grad()
def run_streampetr_forward_temporal(model, batch, device, prev_exists_val, scene_token="__extract__", final_dim=(800, 1600)):
    """Run one temporal StreamPETR forward pass for a single frame.

    Builds img_metas and the ``data`` dict StreamPETR expects, runs feature
    extraction, the ROI head and the bbox head. Returns nothing: the bbox
    head's internal memory is populated as a side effect and is read out
    afterwards by ``extract_streampetr_topk_tokens``.

    Args:
        model: frozen StreamPETR model (see ``load_streampetr``).
        batch: dict from ``load_and_preprocess_images`` (B=1 tensors; the
            calibration entries may be None, in which case identity
            matrices / zero timestamps are substituted).
        device: target torch device.
        prev_exists_val: 0.0 for the first frame of a scene, else 1.0 —
            controls StreamPETR's temporal memory propagation.
        scene_token: scene identifier written into img_metas.
        final_dim: (H, W) of the preprocessed images. Previously hard-coded
            to 800x1600; parameterized so it can follow
            ``streampetr_conf["final_dim"]``. Default preserves old behavior.
    """
    imgs = batch["pixel_values_det"].to(device)
    B, N = imgs.shape[:2]
    fH, fW = final_dim

    img_metas = [{
        "pad_shape": [(fH, fW, 3)] * N,
        "img_shape": [(fH, fW, 3)] * N,
        "scene_token": str(scene_token),
    } for _ in range(B)]
    if batch.get("lidar2img_det") is not None:
        for b in range(B):
            img_metas[b]["lidar2img"] = batch["lidar2img_det"][b].cpu().numpy()

    img_feats = model.extract_img_feat(imgs, 1)

    data = {
        "img": imgs,
        "img_feats": img_feats,
        "prev_exists": imgs.new_tensor([prev_exists_val]),
    }

    # Intrinsics: embed the 3x3 K into a 4x4 matrix; identity fallback.
    if batch.get("intrinsics_det") is not None:
        K3 = batch["intrinsics_det"].to(device)
        K4 = torch.zeros(B, N, 4, 4, device=device, dtype=K3.dtype)
        K4[:, :, :3, :3] = K3
        K4[:, :, 3, 3] = 1.0
        data["intrinsics"] = K4
    else:
        data["intrinsics"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()

    if batch.get("lidar2img_det") is not None:
        data["lidar2img"] = batch["lidar2img_det"].to(device)
    else:
        data["lidar2img"] = torch.eye(4, device=device).unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()

    # Ego pose (and inverse): identity fallback keeps temporal alignment a no-op.
    if batch.get("ego_pose") is not None:
        data["ego_pose"] = batch["ego_pose"].to(device)
    else:
        data["ego_pose"] = torch.eye(4, device=device).unsqueeze(0).expand(B, -1, -1).contiguous()

    if batch.get("ego_pose_inv") is not None:
        data["ego_pose_inv"] = batch["ego_pose_inv"].to(device)
    else:
        data["ego_pose_inv"] = torch.eye(4, device=device).unsqueeze(0).expand(B, -1, -1).contiguous()

    if batch.get("timestamp") is not None:
        data["timestamp"] = batch["timestamp"].to(device)
    else:
        data["timestamp"] = torch.zeros(B, device=device)

    location = model.prepare_location(img_metas, **data)
    outs_roi = model.forward_roi_head(location, **data)
    topk_indexes = outs_roi["topk_indexes"]

    # Side effect only: populates the head's memory for later token readout.
    model.pts_bbox_head(location, img_metas, topk_indexes, **data)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def main():
    """Extract frozen StreamPETR top-k tokens for every sample, scene by scene.

    Scenes are processed in temporal order so StreamPETR's memory propagates
    across frames; each sample's tokens are saved to ``<output_dir>/<id>.pt``
    and an ``index.json`` is written at the end.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Parse "--image_path_remap old1=new1,old2=new2" into a prefix map.
    image_path_remap = {}
    if args.image_path_remap:
        for pair in args.image_path_remap.split(","):
            if "=" in pair:
                old, new = pair.split("=", 1)
                image_path_remap[old] = new

    # --data_json may be a comma-separated list of json files.
    paths = [p.strip() for p in args.data_json.split(",") if p.strip()]
    data_items_raw = []
    for p in paths:
        with open(p) as f:
            chunk = json.load(f)
        data_items_raw.extend(chunk)

    # Deduplicate by sample id, keeping first occurrence.
    data_items = []
    seen_ids = set()
    for i, item in enumerate(data_items_raw):
        sid = str(item.get("id", f"__idx_{i}"))
        if sid in seen_ids:
            continue
        seen_ids.add(sid)
        data_items.append(item)
    print(f"Loaded {len(data_items)} unique samples from {len(paths)} file(s) (raw={len(data_items_raw)})")

    calibration = load_nuscenes_calibration(args.data_root)
    all_scenes = build_scene_order(data_items, args.data_root)
    # Shard whole scenes (not frames) so temporal memory stays intact per shard.
    if args.num_shards > 1:
        scenes = [s for i, s in enumerate(all_scenes) if i % args.num_shards == args.shard_id]
        print(f"Shard {args.shard_id}/{args.num_shards}: {len(scenes)}/{len(all_scenes)} scenes")
    else:
        scenes = all_scenes
    model = load_streampetr(args.streampetr_config, args.streampetr_ckpt, device)

    import torchvision.transforms as transforms
    # ImageNet normalization.
    image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    sp_conf = {"final_dim": (800, 1600), "_calibration": calibration}

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total_samples = sum(len(s) for s in scenes)
    num_saved = 0
    num_existed = 0
    t0 = time.time()

    for scene_idx, scene_indices in enumerate(scenes):
        # Fresh temporal memory at every scene boundary.
        model.pts_bbox_head.reset_memory()
        scene_token = f"scene_{scene_idx}"

        for frame_idx, data_idx in enumerate(scene_indices):
            item = data_items[data_idx]
            sample_id = str(item.get("id", data_idx))
            out_path = output_dir / f"{sample_id}.pt"

            already_done = out_path.exists()

            batch = load_and_preprocess_images(
                item, args.data_root, image_path_remap, sp_conf, image_transform
            )

            # The forward pass still runs for already-saved frames so the
            # temporal memory for later frames in the scene stays correct.
            prev_exists_val = 0.0 if frame_idx == 0 else 1.0
            run_streampetr_forward_temporal(
                model, batch, device, prev_exists_val, scene_token=scene_token
            )

            if not already_done:
                ego_pose_for_ref = batch.get("ego_pose")
                if ego_pose_for_ref is not None:
                    ego_pose_for_ref = ego_pose_for_ref.to(device)
                det_out = extract_streampetr_topk_tokens(
                    model.pts_bbox_head, topk=args.topk, ego_pose=ego_pose_for_ref,
                )
                # fp16 on disk to halve storage.
                torch.save({
                    "detection": det_out["detection"][0].cpu().half(),
                    "detection_ref_points": det_out["detection_ref_points"][0].cpu().half(),
                }, out_path)
                num_saved += 1
            else:
                num_existed += 1

            done = num_saved + num_existed
            if done % 200 == 0:
                elapsed = time.time() - t0
                rate = done / max(elapsed, 1)
                eta = (total_samples - done) / max(rate, 0.01)
                print(f" [{done}/{total_samples}] saved={num_saved} existed={num_existed} "
                      f"{elapsed:.0f}s elapsed, ETA {eta:.0f}s")

    elapsed = time.time() - t0
    print(f"Done. saved={num_saved}, existed={num_existed}, total={total_samples}, time={elapsed:.0f}s")
    print(f"Output: {output_dir}")

    # Index every .pt in the directory (includes files from other shards).
    index = {}
    for pt in output_dir.glob("*.pt"):
        index[pt.stem] = pt.name
    with open(output_dir / "index.json", "w") as f:
        json.dump(index, f)
    print(f"Index written: {len(index)} entries")
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
# Script entry point (only reached when ATLAS_ALLOW_OFFLINE is set).
if __name__ == "__main__":
    main()
|
| 568 |
+
|
extract_topomlp_tokens.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Pre-extract frozen TopoMLP raw outputs for offline Atlas training.
|
| 3 |
+
|
| 4 |
+
Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
|
| 5 |
+
For online training (default), use: bash scripts/train_no_caption_baseline.sh
|
| 6 |
+
|
| 7 |
+
Saves the 6 tensors that TopoMLPToAtlasMapTokens.forward() needs from the
|
| 8 |
+
last decoder layer. The adapter performs Top-K selection online during
|
| 9 |
+
training; only the frozen TopoMLP forward pass is pre-computed here.
|
| 10 |
+
|
| 11 |
+
Usage (4-GPU parallel, requires ATLAS_ALLOW_OFFLINE=1):
|
| 12 |
+
for i in 0 1 2 3; do
|
| 13 |
+
ATLAS_ALLOW_OFFLINE=1 CUDA_VISIBLE_DEVICES=$i python extract_topomlp_tokens.py \
|
| 14 |
+
--topomlp_config configs/topomlp_atlas_aligned.py \
|
| 15 |
+
--topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
|
| 16 |
+
--data_json data/openlane_subsetB_lane_train_4pt.json \
|
| 17 |
+
--output_dir work_dirs/precomputed_map_tokens_offline/train \
|
| 18 |
+
--shard_id $i --num_shards 4 &
|
| 19 |
+
done; wait
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
|
| 25 |
+
# Hard gate: refuse to run unless the caller explicitly opts in to offline
# extraction via ATLAS_ALLOW_OFFLINE (accepted values: "1"/"true"/"yes",
# case-insensitive). Exits with status 1 otherwise.
if os.environ.get("ATLAS_ALLOW_OFFLINE", "").lower() not in ("1", "true", "yes"):
    print(
        "ERROR: This is an OFFLINE token extraction script.\n"
        "It is isolated by default to prevent accidental use.\n"
        "If you really need it, set: ATLAS_ALLOW_OFFLINE=1\n"
        "For online training use: bash scripts/train_no_caption_baseline.sh",
        file=sys.stderr,
    )
    sys.exit(1)
|
| 34 |
+
|
| 35 |
+
import argparse
|
| 36 |
+
import json
|
| 37 |
+
import time
|
| 38 |
+
from pathlib import Path
|
| 39 |
+
|
| 40 |
+
import numpy as np
|
| 41 |
+
import torch
|
| 42 |
+
from PIL import Image
|
| 43 |
+
import torchvision.transforms as transforms
|
| 44 |
+
|
| 45 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _load_nuscenes_calibration(data_root):
|
| 49 |
+
nuscenes_root = Path(data_root)
|
| 50 |
+
version_dir = None
|
| 51 |
+
for v in ["v1.0-trainval", "v1.0-mini", "v1.0-test"]:
|
| 52 |
+
if (nuscenes_root / v).exists():
|
| 53 |
+
version_dir = nuscenes_root / v
|
| 54 |
+
break
|
| 55 |
+
if version_dir is None:
|
| 56 |
+
return None
|
| 57 |
+
needed = ["sample_data.json", "calibrated_sensor.json", "ego_pose.json"]
|
| 58 |
+
for n in needed:
|
| 59 |
+
if not (version_dir / n).exists():
|
| 60 |
+
return None
|
| 61 |
+
with open(version_dir / "sample_data.json") as f:
|
| 62 |
+
sample_data = json.load(f)
|
| 63 |
+
with open(version_dir / "calibrated_sensor.json") as f:
|
| 64 |
+
calibrated_sensor = json.load(f)
|
| 65 |
+
with open(version_dir / "ego_pose.json") as f:
|
| 66 |
+
ego_pose_data = json.load(f)
|
| 67 |
+
sd_by_fn = {r["filename"]: r for r in sample_data if "filename" in r}
|
| 68 |
+
cs_by_tok = {r["token"]: r for r in calibrated_sensor}
|
| 69 |
+
ep_by_tok = {r["token"]: r for r in ego_pose_data}
|
| 70 |
+
lidar_sd_by_sample = {}
|
| 71 |
+
for r in sample_data:
|
| 72 |
+
fn = str(r.get("filename", "")).replace("\\", "/")
|
| 73 |
+
if "/LIDAR_TOP/" in fn and fn.startswith("samples/") and r.get("is_key_frame"):
|
| 74 |
+
st = r.get("sample_token")
|
| 75 |
+
if st:
|
| 76 |
+
lidar_sd_by_sample.setdefault(str(st), r)
|
| 77 |
+
print(f"nuScenes calibration: {len(sd_by_fn)} sample_data, {len(cs_by_tok)} cal_sensor, "
|
| 78 |
+
f"{len(ep_by_tok)} ego_pose, {len(lidar_sd_by_sample)} lidar_kf")
|
| 79 |
+
return {
|
| 80 |
+
"sample_data_by_filename": sd_by_fn,
|
| 81 |
+
"calibrated_sensor_by_token": cs_by_tok,
|
| 82 |
+
"ego_pose_by_token": ep_by_tok,
|
| 83 |
+
"lidar_sd_by_sample_token": lidar_sd_by_sample,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def parse_args():
    """Parse command-line options for offline TopoMLP token extraction."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--topomlp_config", required=True)
    parser.add_argument("--topomlp_ckpt", required=True)
    parser.add_argument("--data_json", required=True)
    parser.add_argument("--data_root", default="")
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--image_path_remap", default=None)
    parser.add_argument("--shard_id", type=int, default=0)
    parser.add_argument("--num_shards", type=int, default=1)
    return parser.parse_args()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def load_topomlp(config_path, ckpt_path, device):
    """Build a frozen, eval-mode TopoMLP model on ``device``.

    Makes the vendored TopoMLP and StreamPETR repos importable, imports the
    TopoMLP project plugin under a temporarily "tolerant" mmcv registry (its
    modules may collide with names StreamPETR already registered), then
    builds the model from ``config_path``, loads ``ckpt_path`` and freezes
    all parameters.
    """
    tp_root = str(Path(__file__).resolve().parent / "external" / "TopoMLP_Repo")
    if tp_root not in sys.path:
        sys.path.insert(0, tp_root)

    sp_root = str(Path(__file__).resolve().parent / "external" / "StreamPETR")
    if sp_root not in sys.path:
        sys.path.insert(0, sp_root)
    try:
        import projects.mmdet3d_plugin  # noqa: F401  (side effect: registers modules)
    except ImportError:
        pass

    os.environ["ATLAS_TOPOMLP_MODELS_ONLY"] = "1"
    # Temporarily force-register so TopoMLP's duplicates of already-registered
    # module names do not raise.
    from mmcv.utils import registry as _reg
    _orig = _reg.Registry._register_module

    def _tolerant_register(self, module, module_name=None, force=False):
        return _orig(self, module, module_name=module_name, force=True)

    _reg.Registry._register_module = _tolerant_register
    try:
        import projects.topomlp  # noqa: F401
    finally:
        # BUGFIX: restore the registry method even when the import raises;
        # previously a failed import left the tolerant patch installed,
        # silently force-overwriting all later registrations.
        _reg.Registry._register_module = _orig

    from mmcv import Config
    from mmdet3d.models import build_model
    from mmcv.runner import load_checkpoint

    cfg = Config.fromfile(config_path)
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, ckpt_path, map_location="cpu")
    model.eval()
    model.to(device)
    # Freeze everything: the model only produces tokens, it is never trained.
    for param in model.parameters():
        param.requires_grad_(False)
    print(f"Loaded frozen TopoMLP from {ckpt_path}")
    return model
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _quat_to_rot(q):
|
| 138 |
+
w, x, y, z = q
|
| 139 |
+
return np.array([
|
| 140 |
+
[1 - 2*(y*y + z*z), 2*(x*y - z*w), 2*(x*z + y*w)],
|
| 141 |
+
[2*(x*y + z*w), 1 - 2*(x*x + z*z), 2*(y*z - x*w)],
|
| 142 |
+
[2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x*x + y*y)],
|
| 143 |
+
], dtype=np.float32)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def load_and_preprocess_images(item, data_root, image_path_remap, image_transform,
                               calibration=None):
    """Load the six surround-view images for one sample and build lidar2img mats.

    Each image is resized to 1600x800 and run through `image_transform`.
    Camera intrinsics/extrinsics are taken from `item["sensor"]` when present,
    otherwise looked up in the nuScenes `calibration` tables by filename.
    Intrinsics are rescaled to match the resize; when lidar calibration is
    available the full lidar->ego->global->ego->camera->image chain is used,
    otherwise an ego->image projection (or identity) is the fallback.

    Returns:
        (pixel_values, lidar2img): tensors of shape (1, N, C, H, W) and
        (1, N, 4, 4) respectively.
    """
    tW, tH = 1600, 800
    cam_names = [
        "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
        "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
    ]

    images = []
    lidar2img_list = []
    # Per-camera calibration dict embedded in the sample, if any.
    sensor = item.get("sensor", {})

    for i, img_path in enumerate(item["image_paths"]):
        # Default to positional camera name, then prefer a name found inside
        # the path itself (longest first so CAM_FRONT_LEFT beats CAM_FRONT).
        camera_name = cam_names[i] if i < len(cam_names) else f"CAM_{i}"
        for cam in sorted(cam_names, key=len, reverse=True):
            if cam in str(img_path):
                camera_name = cam
                break

        # Apply the first matching prefix rewrite (e.g. old dataset root).
        remapped = str(img_path)
        if image_path_remap:
            for old_prefix, new_prefix in image_path_remap.items():
                if remapped.startswith(old_prefix):
                    remapped = new_prefix + remapped[len(old_prefix):]
                    break

        if os.path.isabs(remapped):
            full_path = remapped
        elif data_root:
            full_path = os.path.normpath(os.path.join(data_root, remapped))
        else:
            full_path = remapped

        img = Image.open(full_path).convert("RGB")
        W_orig, H_orig = img.size
        # Scale factors needed to adjust the intrinsics after resizing.
        w_scale = tW / W_orig
        h_scale = tH / H_orig
        img = img.resize((tW, tH), Image.BILINEAR)
        images.append(image_transform(img))

        # --- Calibration source 1: per-sample "sensor" block --------------
        K, E = None, None
        if isinstance(sensor, dict) and camera_name in sensor:
            cam_s = sensor[camera_name]
            K = np.array(cam_s["intrinsic"]["K"], dtype=np.float32)
            R = np.array(cam_s["extrinsic"]["rotation"], dtype=np.float32)
            t = np.array(cam_s["extrinsic"]["translation"], dtype=np.float32)
            E = np.eye(4, dtype=np.float32)
            E[:3, :3] = R
            E[:3, 3] = t

        # --- Calibration source 2: nuScenes tables, keyed by filename -----
        sd_rec = None
        ep_rec = None
        if K is None and calibration is not None:
            # NOTE(review): lstrip("./") strips *characters*, not a prefix --
            # a path starting with "../" would also lose its dots. Presumably
            # paths here never start with ".."; verify against the data.
            norm_path = str(img_path).replace("\\", "/").lstrip("./")
            for cand in [img_path, norm_path]:
                sd_rec = calibration["sample_data_by_filename"].get(cand)
                if sd_rec:
                    break
            if sd_rec is None:
                # Fall back to the canonical "samples/..." / "sweeps/..." key.
                for prefix in ("samples/", "sweeps/"):
                    if prefix in norm_path:
                        sd_rec = calibration["sample_data_by_filename"].get(
                            norm_path[norm_path.index(prefix):]
                        )
                        if sd_rec:
                            break
            if sd_rec is not None:
                cs = calibration["calibrated_sensor_by_token"].get(
                    sd_rec.get("calibrated_sensor_token")
                )
                ep_rec = calibration["ego_pose_by_token"].get(
                    sd_rec.get("ego_pose_token")
                )
                if cs is not None:
                    K = np.array(cs["camera_intrinsic"], dtype=np.float32)
                    E = np.eye(4, dtype=np.float32)
                    E[:3, :3] = _quat_to_rot(cs["rotation"])
                    E[:3, 3] = np.array(cs["translation"], dtype=np.float32)

        if K is not None and E is not None:
            # Rescale intrinsics to the resized image resolution.
            K_adj = K.copy()
            K_adj[0, 0] *= w_scale
            K_adj[0, 2] *= w_scale
            K_adj[1, 1] *= h_scale
            K_adj[1, 2] *= h_scale
            ego2cam = np.linalg.inv(E.astype(np.float32))
            K4 = np.eye(4, dtype=np.float32)
            K4[:3, :3] = K_adj

            # Default projection: ego frame -> image (no lidar calibration).
            lidar2img_mat = K4 @ ego2cam
            if calibration is not None and sd_rec is not None and ep_rec is not None:
                sample_tk = sd_rec.get("sample_token")
                if sample_tk:
                    ego2global_c = np.eye(4, dtype=np.float32)
                    ego2global_c[:3, :3] = _quat_to_rot(ep_rec["rotation"])
                    ego2global_c[:3, 3] = np.array(ep_rec["translation"], dtype=np.float32)
                    global2ego_c = np.linalg.inv(ego2global_c)

                    lidar_sd = calibration.get("lidar_sd_by_sample_token", {}).get(str(sample_tk))
                    if lidar_sd is not None:
                        lidar_cs = calibration["calibrated_sensor_by_token"].get(
                            lidar_sd.get("calibrated_sensor_token"))
                        lidar_ep = calibration["ego_pose_by_token"].get(
                            lidar_sd.get("ego_pose_token"))
                        if lidar_cs is not None and lidar_ep is not None:
                            # Full chain: lidar -> ego(lidar ts) -> global ->
                            # ego(camera ts) -> camera, then project with K4.
                            lidar2ego = np.eye(4, dtype=np.float32)
                            lidar2ego[:3, :3] = _quat_to_rot(lidar_cs["rotation"])
                            lidar2ego[:3, 3] = np.array(lidar_cs["translation"], dtype=np.float32)
                            ego2global_l = np.eye(4, dtype=np.float32)
                            ego2global_l[:3, :3] = _quat_to_rot(lidar_ep["rotation"])
                            ego2global_l[:3, 3] = np.array(lidar_ep["translation"], dtype=np.float32)
                            lidar2cam = ego2cam @ global2ego_c @ ego2global_l @ lidar2ego
                            lidar2img_mat = K4 @ lidar2cam

            lidar2img_list.append(torch.tensor(lidar2img_mat, dtype=torch.float32))
        else:
            # No usable calibration: identity keeps tensor shapes consistent.
            lidar2img_list.append(torch.eye(4, dtype=torch.float32))

    return torch.stack(images).unsqueeze(0), torch.stack(lidar2img_list).unsqueeze(0)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
@torch.no_grad()
def run_topomlp_forward(model, pixel_values, lidar2img, device):
    """Run one frozen TopoMLP forward pass.

    Builds the minimal mmdet-style `img_metas` (fixed 800x1600 shapes) for
    each batch element, attaches the per-sample lidar2img matrices when
    provided, and delegates to `model.simple_forward`.
    """
    imgs = pixel_values.to(device)
    batch, num_views = imgs.shape[:2]
    meta_h, meta_w = 800, 1600
    per_view_shape = tuple((meta_h, meta_w, 3) for _ in range(num_views))

    img_metas = []
    for b in range(batch):
        meta = {
            "img_shape": per_view_shape,
            "pad_shape": per_view_shape,
            "scale_factor": 1.0,
            "te_yolov8": None,
        }
        if lidar2img is not None:
            meta["lidar2img"] = lidar2img[b].cpu().numpy()
        img_metas.append(meta)

    return model.simple_forward(imgs, img_metas)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def extract_adapter_inputs(outs):
    """Select the adapter-facing lane-centerline tensors from TopoMLP outputs.

    For each output list, takes the last decoder layer ([-1]) and the first
    batch element ([0]), moved to CPU and cast to float16 for compact storage.
    """
    source_keys = {
        "lc_outs_dec": "lc_outs_dec_list",
        "lc_cls_scores": "all_lc_cls_scores_list",
        "lc_preds": "all_lc_preds_list",
        "lc_outs_dec_o2m": "lc_outs_dec_one2many_list",
        "lc_cls_scores_o2m": "all_lc_cls_scores_one2many_list",
        "lc_preds_o2m": "all_lc_preds_one2many_list",
    }
    return {dst: outs[src][-1][0].cpu().half() for dst, src in source_keys.items()}
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def main():
    """Extract and cache TopoMLP adapter tokens for every sample in data_json.

    Supports resuming (existing .pt files are skipped), sharding across
    processes, and writes an index.json mapping sample id -> token file.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Parse "old=new,old2=new2" prefix remapping for image paths.
    image_path_remap = {}
    if args.image_path_remap:
        for pair in args.image_path_remap.split(","):
            if "=" in pair:
                old, new = pair.split("=", 1)
                image_path_remap[old] = new

    # data_json may be a comma-separated list of JSON files; concatenate them.
    paths = [p.strip() for p in args.data_json.split(",") if p.strip()]
    data_items_raw = []
    for p in paths:
        with open(p) as f:
            data_items_raw.extend(json.load(f))

    # De-duplicate by sample id, preserving first occurrence; items without
    # an id get a synthetic key so they are never dropped.
    data_items = []
    seen_ids = set()
    for i, item in enumerate(data_items_raw):
        sid = str(item.get("id", f"__idx_{i}"))
        if sid in seen_ids:
            continue
        seen_ids.add(sid)
        data_items.append(item)
    print(f"Loaded {len(data_items)} unique samples from {len(paths)} file(s)")

    # Round-robin sharding: this process keeps every num_shards-th item.
    if args.num_shards > 1:
        data_items = [item for i, item in enumerate(data_items) if i % args.num_shards == args.shard_id]
        print(f"Shard {args.shard_id}/{args.num_shards}: {len(data_items)} samples")

    model = load_topomlp(args.topomlp_config, args.topomlp_ckpt, device)

    # Calibration tables are only needed when samples lack embedded "sensor".
    calibration = _load_nuscenes_calibration(args.data_root) if args.data_root else None

    # Standard ImageNet normalization.
    image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(data_items)
    num_saved = 0
    num_existed = 0
    t0 = time.time()

    for idx, item in enumerate(data_items):
        sample_id = str(item.get("id", idx))
        out_path = output_dir / f"{sample_id}.pt"

        # Resume support: a pre-existing token file is trusted and skipped.
        if out_path.exists():
            num_existed += 1
        else:
            pixel_values, lidar2img = load_and_preprocess_images(
                item, args.data_root, image_path_remap, image_transform,
                calibration=calibration,
            )
            outs = run_topomlp_forward(model, pixel_values, lidar2img, device)
            torch.save(extract_adapter_inputs(outs), out_path)
            num_saved += 1

        # Progress/ETA report every 200 processed samples.
        done = num_saved + num_existed
        if done % 200 == 0:
            elapsed = time.time() - t0
            rate = done / max(elapsed, 1)
            eta = (total - done) / max(rate, 0.01)
            print(f" [{done}/{total}] saved={num_saved} existed={num_existed} "
                  f"{elapsed:.0f}s elapsed, ETA {eta:.0f}s")

    elapsed = time.time() - t0
    print(f"Done. saved={num_saved}, existed={num_existed}, total={total}, time={elapsed:.0f}s")
    print(f"Output: {output_dir}")

    # Rebuild the id -> filename index from whatever is on disk (covers files
    # written by other shards or earlier runs as well).
    index = {}
    for pt in output_dir.glob("*.pt"):
        index[pt.stem] = pt.name
    with open(output_dir / "index.json", "w") as f:
        json.dump(index, f)
    print(f"Index written: {len(index)} entries")
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
scripts/eval_checkpoint_offline.sh
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# [OFFLINE fallback] Evaluate using precomputed *_offline visual tokens.
# This is NOT the default. Use scripts/eval_checkpoint.sh for online mode.
# Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
set -e

if [ "${ATLAS_ALLOW_OFFLINE}" != "1" ]; then
  echo "ERROR: This is an OFFLINE fallback script, not the primary online evaluation." >&2
  echo "It is isolated by default to prevent accidental use in experiments." >&2
  echo "If you really need it, set: ATLAS_ALLOW_OFFLINE=1" >&2
  echo "For production evaluation use: bash scripts/eval_checkpoint.sh" >&2
  exit 1
fi

PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$PROJECT_ROOT"

CHECKPOINT="${1:?Usage: $0 <checkpoint> <data_json> [output_json] [max_samples]}"
DATA_JSON="${2:?Usage: $0 <checkpoint> <data_json> [output_json] [max_samples]}"
OUTPUT_JSON="${3:-}"
MAX_SAMPLES="${4:-0}"
PLANNING_TABLE3_MODE="${PLANNING_TABLE3_MODE:-atlas_base}"

# Collect optional flags in a bash array so values containing spaces survive
# quoting (the original relied on unquoted word splitting of a flat string,
# which breaks for paths with spaces).
EXTRA_ARGS=()
if [ -n "$OUTPUT_JSON" ]; then
  EXTRA_ARGS+=(--output_json "$OUTPUT_JSON")
fi
if [ "$MAX_SAMPLES" -gt 0 ] 2>/dev/null; then
  EXTRA_ARGS+=(--max_samples "$MAX_SAMPLES")
fi

python eval_atlas.py \
  --checkpoint "$CHECKPOINT" \
  --llm_model pretrained/vicuna-7b-v1.5 \
  --data_json "$DATA_JSON" \
  --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
  --visual_token_mode offline \
  --planning_table3_mode "$PLANNING_TABLE3_MODE" \
  --precomputed_det_tokens work_dirs/precomputed_det_tokens_offline/val \
  --precomputed_map_tokens work_dirs/precomputed_map_tokens_offline/val \
  --bf16 \
  --batch_size 1 \
  --num_workers 2 \
  "${EXTRA_ARGS[@]}"
|
scripts/gen_atlas_caption_dashscope.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Atlas Caption 数据生成脚本 - Dashscope 版
|
| 3 |
+
|
| 4 |
+
与 gen_atlas_caption_qa.py 完全相同的输出格式,
|
| 5 |
+
支持 --start/--end 指定 keyframe 范围,写入独立文件,最终合并。
|
| 6 |
+
|
| 7 |
+
模型: qwen-vl-max-latest (Dashscope)
|
| 8 |
+
"""
|
| 9 |
+
import asyncio
|
| 10 |
+
import json
|
| 11 |
+
import base64
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import time
|
| 15 |
+
import signal
|
| 16 |
+
from io import BytesIO
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import httpx
|
| 21 |
+
from PIL import Image
|
| 22 |
+
except ImportError:
|
| 23 |
+
print("pip install httpx Pillow")
|
| 24 |
+
sys.exit(1)
|
| 25 |
+
|
| 26 |
+
NUSCENES_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
|
| 27 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 28 |
+
|
| 29 |
+
CAMERAS = [
|
| 30 |
+
"CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
|
| 31 |
+
"CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
GPT4V_PROMPT = (
|
| 35 |
+
"Describe the current traffic conditions. "
|
| 36 |
+
"If there are traffic lights in the image, describe the status of all the traffic lights, "
|
| 37 |
+
"including any countdowns; if there are none, please do not respond. "
|
| 38 |
+
"If there are traffic signs in the picture, identify and explain each one; "
|
| 39 |
+
"if there are none, no explanation is necessary. "
|
| 40 |
+
"If there are other vehicles in the picture, describe them in more detail. "
|
| 41 |
+
"Please ensure the answer does not exceed 600 words. Answers must be in English."
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
TRAIN_PROMPTS = [
|
| 45 |
+
(
|
| 46 |
+
"There are six images captured by the surround view cameras in driving vehicle. "
|
| 47 |
+
"They are uniformly represented as queries embeddings<query>. "
|
| 48 |
+
"Communicate a narrative of the setting within {camera_name} view image."
|
| 49 |
+
),
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
API_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
|
| 53 |
+
MODEL = "qwen-vl-max-latest"
|
| 54 |
+
|
| 55 |
+
MAX_CONCURRENCY = 50
|
| 56 |
+
MAX_RETRIES = 3
|
| 57 |
+
RETRY_DELAY = 3
|
| 58 |
+
TIMEOUT = 60
|
| 59 |
+
CHECKPOINT_INTERVAL = 100
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def image_to_base64(path):
    """Load the image at `path`, re-encode as JPEG (q=80), return base64 str."""
    encoded = BytesIO()
    Image.open(path).save(encoded, format="JPEG", quality=80)
    return base64.b64encode(encoded.getvalue()).decode()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
async def call_api(client, api_key, image_b64, camera_name):
    """POST one captioning request; return (caption, prompt_tokens, completion_tokens).

    Raises httpx.HTTPStatusError on non-2xx responses.
    """
    payload = {
        "model": MODEL,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": f"[{camera_name}] {GPT4V_PROMPT}"},
                {"type": "image_url", "image_url": {
                    "url": f"data:image/jpeg;base64,{image_b64}",
                }},
            ],
        }],
        "max_tokens": 800,
        "temperature": 0.3,
    }
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    response = await client.post(API_URL, json=payload, headers=auth_headers, timeout=TIMEOUT)
    response.raise_for_status()
    body = response.json()
    caption = body["choices"][0]["message"]["content"].strip()
    usage = body.get("usage", {})
    return caption, usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
async def process_one_view(client, api_key, sample, cam_idx, sem, stats):
    """Caption one camera view of one keyframe, with bounded retries.

    Returns a caption QA record dict on success, or None when the image is
    missing or all retries fail. Mutates the shared `stats` counters in place.
    """
    cam = CAMERAS[cam_idx]
    img_path = os.path.join(NUSCENES_ROOT, sample["image_paths"][cam_idx])
    if not os.path.exists(img_path):
        stats["skipped"] += 1
        return None

    img_b64 = image_to_base64(img_path)
    train_prompt = TRAIN_PROMPTS[0].format(camera_name=cam)

    for attempt in range(MAX_RETRIES):
        delay = None  # backoff to apply AFTER releasing the semaphore
        async with sem:
            try:
                caption, in_tok, out_tok = await call_api(client, api_key, img_b64, cam)
                stats["success"] += 1
                stats["total_in"] += in_tok
                stats["total_out"] += out_tok
                return {
                    "id": sample["id"],
                    "image_paths": sample["image_paths"],
                    "num_map_queries": 0,
                    "task": "caption",
                    "camera": cam,
                    "conversations": [
                        {"from": "human", "value": train_prompt},
                        {"from": "gpt", "value": caption},
                    ],
                }
            except httpx.TimeoutException:
                stats["retries"] += 1
                if attempt < MAX_RETRIES - 1:
                    delay = RETRY_DELAY * (attempt + 1)
            except httpx.HTTPStatusError as e:
                stats["retries"] += 1
                if e.response.status_code == 429:
                    # Rate limited: back off harder, even on the last attempt.
                    delay = RETRY_DELAY * (attempt + 2)
                elif attempt < MAX_RETRIES - 1:
                    delay = RETRY_DELAY
                else:
                    stats["failed"] += 1
                    return None
            except Exception:
                stats["retries"] += 1
                if attempt < MAX_RETRIES - 1:
                    delay = RETRY_DELAY
                else:
                    stats["failed"] += 1
                    return None
        # Fix: sleep OUTSIDE `async with sem` so the concurrency slot is
        # released during backoff. The original slept while holding the
        # semaphore, throttling every other in-flight worker during retries.
        if delay is not None:
            await asyncio.sleep(delay)

    stats["failed"] += 1
    return None
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def make_ckpt_key(sample_id, cam_idx):
    """Checkpoint key for one (sample, camera-index) pair."""
    return "{}_{}".format(sample_id, cam_idx)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def load_checkpoint(path):
    """Return the set of completed keys stored at `path` (empty if absent)."""
    if not os.path.exists(path):
        return set()
    with open(path) as handle:
        return set(json.load(handle))
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def save_checkpoint(path, done_keys):
    """Persist the completed keys to `path` as a sorted JSON array."""
    ordered = sorted(done_keys)
    with open(path, "w") as handle:
        json.dump(ordered, handle)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
async def run(split, start, end, dry_run=False, tag="dashscope"):
    """Generate caption QA for keyframes [start:end] of the given split.

    Resumable: completed (sample, camera) keys are tracked in a checkpoint
    file and results are flushed to disk after every batch. SIGINT triggers
    a graceful stop at the next batch boundary.
    """
    api_key = os.environ.get("DASHSCOPE_KEY", "")
    if not api_key:
        print("ERROR: set DASHSCOPE_KEY env var", flush=True)
        sys.exit(1)

    data_file = PROJECT_ROOT / f"data/atlas_nuscenes_{split}.json"
    out_file = PROJECT_ROOT / f"data/atlas_caption_{split}_{tag}.json"
    ckpt_file = PROJECT_ROOT / f"data/.caption_{split}_{tag}_checkpoint.json"

    with open(data_file) as f:
        all_samples = json.load(f)

    # Restrict to the requested keyframe window (enables parallel ranges).
    all_samples = all_samples[start:end]
    print(f"Range: [{start}:{end}] = {len(all_samples)} keyframes", flush=True)

    # Resume: reload completed keys and previously saved results.
    done_keys = load_checkpoint(ckpt_file)
    existing_results = []
    if os.path.exists(out_file) and done_keys:
        with open(out_file) as f:
            existing_results = json.load(f)

    # Work list: every (sample, camera) pair not yet checkpointed.
    todo = []
    for s in all_samples:
        for cam_idx in range(6):
            key = make_ckpt_key(s["id"], cam_idx)
            if key not in done_keys:
                todo.append((s, cam_idx))

    total = len(todo)
    print(f"Split: {split}, Tag: {tag}", flush=True)
    print(f"Total keyframes: {len(all_samples)}", flush=True)
    print(f"Total views to caption: {len(all_samples) * 6}", flush=True)
    print(f"Already done: {len(done_keys)}", flush=True)
    print(f"To process: {total}", flush=True)
    print(f"Model: {MODEL}", flush=True)
    print(f"Concurrency: {MAX_CONCURRENCY}", flush=True)
    if dry_run:
        print("DRY RUN", flush=True)
        return

    stats = {"success": 0, "failed": 0, "skipped": 0, "retries": 0,
             "total_in": 0, "total_out": 0}
    results = list(existing_results)
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    client = httpx.AsyncClient()

    # SIGINT sets a flag; the loop exits at the next batch boundary so the
    # current batch's results are still flushed.
    shutdown = False
    def handle_signal(sig, frame):
        nonlocal shutdown
        shutdown = True
        print("\nGraceful shutdown...", flush=True)
    signal.signal(signal.SIGINT, handle_signal)

    t0 = time.time()
    batch_size = CHECKPOINT_INTERVAL
    for batch_start in range(0, total, batch_size):
        if shutdown:
            break
        batch = todo[batch_start:batch_start + batch_size]
        tasks = [process_one_view(client, api_key, s, ci, sem, stats) for s, ci in batch]
        batch_results = await asyncio.gather(*tasks)

        for (s, ci), r in zip(batch, batch_results):
            if r is not None:
                results.append(r)
                done_keys.add(make_ckpt_key(s["id"], ci))

        # Flush results first, then the checkpoint, once per batch.
        with open(out_file, "w") as f:
            json.dump(results, f, ensure_ascii=False)
        save_checkpoint(ckpt_file, done_keys)

        elapsed = time.time() - t0
        done_n = batch_start + len(batch)
        rps = stats["success"] / elapsed if elapsed > 0 else 0
        eta = (total - done_n) / rps / 3600 if rps > 0 else 0
        pct = done_n / total * 100

        print(
            f" [{pct:5.1f}%] {done_n}/{total} | "
            f"ok={stats['success']} fail={stats['failed']} retry={stats['retries']} | "
            f"{rps:.2f} rps | ETA {eta:.1f}h | "
            f"tok: {stats['total_in']/1e6:.1f}M in + {stats['total_out']/1e6:.1f}M out",
            flush=True,
        )

    await client.aclose()

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
    print(f"Results: {len(results)} captions saved to {out_file}", flush=True)
    print(f"Stats: {json.dumps(stats)}", flush=True)
    # Rough cost estimate in CNY based on Dashscope per-1K-token pricing.
    total_tok = stats["total_in"] + stats["total_out"]
    cost_in = stats["total_in"] / 1000 * 0.003
    cost_out = stats["total_out"] / 1000 * 0.009
    print(f"Total tokens: {total_tok:,} | Cost: ¥{cost_in + cost_out:.1f}", flush=True)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# CLI entry point: parse the keyframe range and launch the async pipeline.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", default="train", choices=["train", "val"])
    parser.add_argument("--start", type=int, required=True, help="Start keyframe index (inclusive)")
    parser.add_argument("--end", type=int, required=True, help="End keyframe index (exclusive)")
    parser.add_argument("--tag", default="dashscope", help="Output file tag")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--concurrency", type=int, default=50)
    args = parser.parse_args()
    # Rebinds the module-level constant before run() creates its semaphore.
    MAX_CONCURRENCY = args.concurrency
    asyncio.run(run(args.split, args.start, args.end, args.dry_run, args.tag))
|
scripts/gen_atlas_caption_qa.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Atlas Caption 数据生成脚本 (论文对齐版)
|
| 3 |
+
|
| 4 |
+
与论文 Appendix A.3 完全对齐:
|
| 5 |
+
- 每个 keyframe 的 6 个摄像头各自独立生成 caption
|
| 6 |
+
- 使用论文 Table 8 中的 GPT-4V prompt
|
| 7 |
+
- human prompt 使用论文 Figure 5 风格的单视角模板
|
| 8 |
+
- 输出样本显式写入 `task="caption"` 与 `camera`
|
| 9 |
+
- 每个 keyframe 产出 6 条 QA,总计 ~204K 条 (34K x 6)
|
| 10 |
+
- 训练 prompt 与 src/prompting.py 中的 CAPTION_PROMPTS 保持一致
|
| 11 |
+
|
| 12 |
+
支持: 异步并发、断点续传、自动重试
|
| 13 |
+
"""
|
| 14 |
+
import asyncio
|
| 15 |
+
import json
|
| 16 |
+
import base64
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import time
|
| 20 |
+
import signal
|
| 21 |
+
from io import BytesIO
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
import httpx
|
| 26 |
+
from PIL import Image
|
| 27 |
+
except ImportError:
|
| 28 |
+
print("pip install httpx Pillow")
|
| 29 |
+
sys.exit(1)
|
| 30 |
+
|
| 31 |
+
NUSCENES_ROOT = "/home/guoyuanbo/autodl-tmp/data/nuscenes"
|
| 32 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 33 |
+
|
| 34 |
+
CAMERAS = [
|
| 35 |
+
"CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
|
| 36 |
+
"CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
GPT4V_PROMPT = (
|
| 40 |
+
"Describe the current traffic conditions. "
|
| 41 |
+
"If there are traffic lights in the image, describe the status of all the traffic lights, "
|
| 42 |
+
"including any countdowns; if there are none, please do not respond. "
|
| 43 |
+
"If there are traffic signs in the picture, identify and explain each one; "
|
| 44 |
+
"if there are none, no explanation is necessary. "
|
| 45 |
+
"If there are other vehicles in the picture, describe them in more detail. "
|
| 46 |
+
"Please ensure the answer does not exceed 600 words. Answers must be in English."
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
TRAIN_PROMPTS = [
|
| 50 |
+
(
|
| 51 |
+
"There are six images captured by the surround view cameras in driving vehicle. "
|
| 52 |
+
"They are uniformly represented as queries embeddings<query>. "
|
| 53 |
+
"Communicate a narrative of the setting within {camera_name} view image."
|
| 54 |
+
),
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
API_URL = "https://openrouter.fans/v1/chat/completions"
|
| 58 |
+
MODEL = "Qwen/Qwen3-VL-235B-A22B-Instruct"
|
| 59 |
+
|
| 60 |
+
MAX_CONCURRENCY = 30
|
| 61 |
+
MAX_RETRIES = 3
|
| 62 |
+
RETRY_DELAY = 5
|
| 63 |
+
TIMEOUT = 90
|
| 64 |
+
CHECKPOINT_INTERVAL = 100
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def image_to_base64(path):
    """Read an image file and return its JPEG (q=80) bytes, base64-encoded."""
    jpeg_buffer = BytesIO()
    source = Image.open(path)
    source.save(jpeg_buffer, format="JPEG", quality=80)
    raw = jpeg_buffer.getvalue()
    return base64.b64encode(raw).decode()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
async def call_api(client, api_key, image_b64, camera_name):
    """Send one vision-captioning request; return (caption, in_tokens, out_tokens).

    Raises httpx.HTTPStatusError on non-2xx responses.
    """
    text_part = {"type": "text", "text": f"[{camera_name}] {GPT4V_PROMPT}"}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
    }
    request_body = {
        "model": MODEL,
        "messages": [{"role": "user", "content": [text_part, image_part]}],
        "max_tokens": 800,
        "temperature": 0.3,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    resp = await client.post(API_URL, json=request_body, headers=headers, timeout=TIMEOUT)
    resp.raise_for_status()
    data = resp.json()
    usage = data.get("usage", {})
    caption = data["choices"][0]["message"]["content"].strip()
    return caption, usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
async def process_one_view(client, api_key, sample, cam_idx, sem, stats):
    """Caption one camera view of one keyframe, with bounded retries.

    Returns a caption QA record (including segment_id/timestamp metadata)
    on success, or None when the image is missing or all retries fail.
    Mutates the shared `stats` counters in place.
    """
    cam = CAMERAS[cam_idx]
    img_path = os.path.join(NUSCENES_ROOT, sample["image_paths"][cam_idx])
    if not os.path.exists(img_path):
        stats["skipped"] += 1
        return None

    img_b64 = image_to_base64(img_path)
    train_prompt = TRAIN_PROMPTS[0].format(camera_name=cam)

    for attempt in range(MAX_RETRIES):
        delay = None  # backoff to apply AFTER releasing the semaphore
        async with sem:
            try:
                caption, in_tok, out_tok = await call_api(client, api_key, img_b64, cam)
                stats["success"] += 1
                stats["total_in"] += in_tok
                stats["total_out"] += out_tok
                return {
                    "id": sample["id"],
                    "image_paths": sample["image_paths"],
                    "num_map_queries": 0,
                    "task": "caption",
                    "camera": cam,
                    "segment_id": sample.get("segment_id", ""),
                    "timestamp": sample.get("timestamp", None),
                    "conversations": [
                        {"from": "human", "value": train_prompt},
                        {"from": "gpt", "value": caption},
                    ],
                }
            except httpx.TimeoutException:
                stats["retries"] += 1
                if attempt < MAX_RETRIES - 1:
                    delay = RETRY_DELAY * (attempt + 1)
            except httpx.HTTPStatusError as e:
                stats["retries"] += 1
                if e.response.status_code == 429:
                    # Rate limited: back off harder, even on the last attempt.
                    delay = RETRY_DELAY * (attempt + 2)
                elif attempt < MAX_RETRIES - 1:
                    delay = RETRY_DELAY
                else:
                    stats["failed"] += 1
                    return None
            except Exception:
                stats["retries"] += 1
                if attempt < MAX_RETRIES - 1:
                    delay = RETRY_DELAY
                else:
                    stats["failed"] += 1
                    return None
        # Fix: sleep OUTSIDE `async with sem` so the concurrency slot is
        # released during backoff. The original slept while holding the
        # semaphore, throttling every other in-flight worker during retries.
        if delay is not None:
            await asyncio.sleep(delay)

    stats["failed"] += 1
    return None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def make_ckpt_key(sample_id, cam_idx):
    """Build the checkpoint key identifying one (sample, camera) pair."""
    return "{}_{}".format(sample_id, cam_idx)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def load_checkpoint(path):
    """Load the set of completed (sample, camera) keys from a checkpoint file.

    Returns an empty set when no checkpoint has been written yet.
    """
    if not os.path.exists(path):
        return set()
    with open(path) as fh:
        return set(json.load(fh))
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def save_checkpoint(path, done_keys):
    """Persist completed keys as a sorted JSON list (sorted for stable diffs)."""
    payload = json.dumps(sorted(done_keys))
    with open(path, "w") as fh:
        fh.write(payload)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
async def run(split, dry_run=False, limit=None):
    """Caption every camera view of a nuScenes split via the remote API.

    Resumable: a checkpoint file records finished (sample, camera) keys and
    both results and checkpoint are flushed after every batch, so an
    interrupted run continues where it left off. Requires the OPENROUTER_KEY
    environment variable.

    Args:
        split: dataset split name ("train" / "val"); selects the input/output files.
        dry_run: if True, only print the work summary and exit.
        limit: optional cap on the number of keyframes processed (smoke tests).
    """
    api_key = os.environ.get("OPENROUTER_KEY", "")
    if not api_key:
        print("ERROR: set OPENROUTER_KEY env var", flush=True)
        sys.exit(1)

    data_file = PROJECT_ROOT / f"data/atlas_nuscenes_{split}.json"
    out_file = PROJECT_ROOT / f"data/atlas_caption_{split}.json"
    ckpt_file = PROJECT_ROOT / f"data/.caption_{split}_checkpoint.json"

    with open(data_file) as f:
        all_samples = json.load(f)

    if limit:
        all_samples = all_samples[:limit]

    # Resume: reload prior captions only when a checkpoint confirms they exist.
    done_keys = load_checkpoint(ckpt_file)
    existing_results = []
    if os.path.exists(out_file) and done_keys:
        with open(out_file) as f:
            existing_results = json.load(f)

    # One work item per (keyframe, camera) pair not yet captioned.
    todo = []
    for s in all_samples:
        for cam_idx in range(6):
            key = make_ckpt_key(s["id"], cam_idx)
            if key not in done_keys:
                todo.append((s, cam_idx))

    total = len(todo)
    print(f"Split: {split}", flush=True)
    print(f"Total keyframes: {len(all_samples)}", flush=True)
    print(f"Total views to caption: {len(all_samples) * 6}", flush=True)
    print(f"Already done: {len(done_keys)}", flush=True)
    print(f"To process: {total}", flush=True)
    if dry_run:
        print("DRY RUN", flush=True)
        return

    stats = {"success": 0, "failed": 0, "skipped": 0, "retries": 0,
             "total_in": 0, "total_out": 0}
    results = list(existing_results)
    # Caps the number of in-flight API calls across all tasks.
    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    client = httpx.AsyncClient()

    # SIGINT sets a flag so the current batch finishes and gets checkpointed
    # before exiting, instead of killing in-flight requests.
    shutdown = False
    def handle_signal(sig, frame):
        nonlocal shutdown
        shutdown = True
        print("\nGraceful shutdown...", flush=True)
    signal.signal(signal.SIGINT, handle_signal)

    t0 = time.time()
    batch_size = CHECKPOINT_INTERVAL
    for batch_start in range(0, total, batch_size):
        if shutdown:
            break
        batch = todo[batch_start:batch_start + batch_size]
        tasks = [process_one_view(client, api_key, s, ci, sem, stats) for s, ci in batch]
        batch_results = await asyncio.gather(*tasks)

        for (s, ci), r in zip(batch, batch_results):
            if r is not None:
                results.append(r)
                done_keys.add(make_ckpt_key(s["id"], ci))

        # Persist after every batch so progress survives crashes/SIGINT.
        with open(out_file, "w") as f:
            json.dump(results, f, ensure_ascii=False)
        save_checkpoint(ckpt_file, done_keys)

        elapsed = time.time() - t0
        done_n = batch_start + len(batch)
        rps = stats["success"] / elapsed if elapsed > 0 else 0
        eta = (total - done_n) / rps / 3600 if rps > 0 else 0
        pct = done_n / total * 100

        print(
            f" [{pct:5.1f}%] {done_n}/{total} | "
            f"ok={stats['success']} fail={stats['failed']} retry={stats['retries']} | "
            f"{rps:.2f} rps | ETA {eta:.1f}h | "
            f"tok: {stats['total_in']/1e6:.1f}M in + {stats['total_out']/1e6:.1f}M out",
            flush=True,
        )

    await client.aclose()

    elapsed = time.time() - t0
    print(f"\nDone in {elapsed:.0f}s ({elapsed/60:.1f}min)", flush=True)
    print(f"Results: {len(results)} captions saved to {out_file}", flush=True)
    print(f"Stats: {json.dumps(stats)}", flush=True)
    total_tok = stats["total_in"] + stats["total_out"]
    # NOTE(review): cost estimate assumes ~¥40 per 50M tokens — confirm pricing.
    cost_rmb = total_tok / 50e6 * 40
    print(f"Total tokens: {total_tok:,} | Est cost: ¥{cost_rmb:.1f}", flush=True)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    # Dataset split to caption; selects data/atlas_nuscenes_{split}.json.
    parser.add_argument("--split", default="train", choices=["train", "val"])
    # Print the work summary and exit without calling the API.
    parser.add_argument("--dry-run", action="store_true")
    # Optional cap on the number of keyframes (useful for smoke tests).
    parser.add_argument("--limit", type=int, default=None)
    # Max in-flight API requests; overrides the module-level default.
    parser.add_argument("--concurrency", type=int, default=30)
    args = parser.parse_args()
    MAX_CONCURRENCY = args.concurrency
    asyncio.run(run(args.split, args.dry_run, args.limit))
|
scripts/gen_atlas_openlane_subsetB_lane_qa.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Convert OpenLane-V2 subset-B data into Atlas-style lane QA JSON."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, List, Tuple, Optional, Iterable
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Fixed six-view camera ordering used when building each sample's image_paths.
# NOTE(review): presumably this must match the order the Atlas dataset loader
# expects (front triplet first, then back triplet) — confirm against the loader.
CAM_ORDER = [
    "CAM_FRONT",
    "CAM_FRONT_RIGHT",
    "CAM_FRONT_LEFT",
    "CAM_BACK",
    "CAM_BACK_LEFT",
    "CAM_BACK_RIGHT",
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _val_to_bin(value: float, min_val: float, max_val: float, num_bins: int = 1000) -> int:
|
| 24 |
+
v = float(value)
|
| 25 |
+
if v < min_val:
|
| 26 |
+
v = min_val
|
| 27 |
+
if v > max_val:
|
| 28 |
+
v = max_val
|
| 29 |
+
t = (v - min_val) / (max_val - min_val)
|
| 30 |
+
idx = int(round(t * (num_bins - 1)))
|
| 31 |
+
if idx < 0:
|
| 32 |
+
idx = 0
|
| 33 |
+
if idx > (num_bins - 1):
|
| 34 |
+
idx = num_bins - 1
|
| 35 |
+
return idx
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _openlane_to_paper_xy(x_fwd: float, y_left: float) -> Tuple[float, float]:
|
| 39 |
+
return (-float(y_left), float(x_fwd))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _uniform_sample_points(pts: List[List[float]], k: int) -> List[List[float]]:
|
| 43 |
+
if k <= 0 or len(pts) <= k:
|
| 44 |
+
return pts
|
| 45 |
+
if k == 1:
|
| 46 |
+
# 避免 k-1 = 0 导致的除零错误,返回中点
|
| 47 |
+
mid_idx = len(pts) // 2
|
| 48 |
+
return [pts[mid_idx]]
|
| 49 |
+
n = len(pts)
|
| 50 |
+
out = []
|
| 51 |
+
for i in range(k):
|
| 52 |
+
j = int(round(i * (n - 1) / float(k - 1)))
|
| 53 |
+
out.append(pts[j])
|
| 54 |
+
return out
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _lane_bev_distance(lane: Dict) -> float:
|
| 58 |
+
"""Compute BEV distance from ego (origin) to lane centroid for sorting."""
|
| 59 |
+
pts = lane.get("points", [])
|
| 60 |
+
if not pts:
|
| 61 |
+
return float('inf')
|
| 62 |
+
xs, ys = [], []
|
| 63 |
+
for p in pts:
|
| 64 |
+
if isinstance(p, (list, tuple)) and len(p) >= 2:
|
| 65 |
+
xs.append(float(p[0]))
|
| 66 |
+
ys.append(float(p[1]))
|
| 67 |
+
if not xs:
|
| 68 |
+
return float('inf')
|
| 69 |
+
cx, cy = sum(xs) / len(xs), sum(ys) / len(ys)
|
| 70 |
+
return cx * cx + cy * cy
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _lane_answer_from_centerlines(
    lane_centerline: List[Dict],
    max_lanes: int,
    points_per_lane: int,
    xy_range_m: float = 51.2,
    z_min: float = -5.0,
    z_max: float = 3.0,
) -> str:
    """Serialize lane centerlines into the binned Atlas answer string.

    Lanes are ordered nearest-first (squared centroid distance), optionally
    truncated to max_lanes, resampled to points_per_lane points, and every
    surviving point is emitted as 1000-bin indices over the BEV x/y range and
    the [z_min, z_max] height window. Points outside either range are dropped.
    """
    ordered = sorted(lane_centerline or [], key=_lane_bev_distance)
    if max_lanes > 0:
        ordered = ordered[:max_lanes]

    lane_strings: List[str] = []
    for lane in ordered:
        raw_pts = lane.get("points", [])
        if not isinstance(raw_pts, list):
            continue
        valid = [
            list(map(float, p[:3]))
            for p in raw_pts
            if isinstance(p, (list, tuple)) and len(p) >= 3
        ]
        if not valid:
            continue
        sampled = _uniform_sample_points(valid, points_per_lane)

        point_bins: List[str] = []
        for x, y, z in sampled:
            x_p, y_p = _openlane_to_paper_xy(x, y)
            if abs(x_p) > xy_range_m or abs(y_p) > xy_range_m:
                continue
            if z < z_min or z > z_max:
                continue
            xb = _val_to_bin(x_p, -xy_range_m, xy_range_m, 1000)
            yb = _val_to_bin(y_p, -xy_range_m, xy_range_m, 1000)
            zb = _val_to_bin(z, z_min, z_max, 1000)
            point_bins.append(f"[{xb}, {yb}, {zb}]")

        if point_bins:
            lane_strings.append(", ".join(point_bins))

    if not lane_strings:
        return "No lane centerlines detected within range."
    return "Lane: " + "; ".join(lane_strings) + "."
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _prompt_lane_qa() -> str:
|
| 117 |
+
import sys, os
|
| 118 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 119 |
+
try:
|
| 120 |
+
from src.prompting import sample_prompt
|
| 121 |
+
return sample_prompt("lane")
|
| 122 |
+
except ImportError:
|
| 123 |
+
return (
|
| 124 |
+
"There are six images captured by the surround view cameras in driving vehicle. "
|
| 125 |
+
"They are uniformly represented as queries embeddings<query>. "
|
| 126 |
+
"Please complete the centerline detection task under the Bird's Eye View (BEV) perspective. "
|
| 127 |
+
"Ensure that the detection range does not exceed 50 meters."
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def iter_info_jsons(root: Path, split: str) -> Iterable[Path]:
    """Yield every per-frame info JSON under ``<root>/<split>/*/info/``.

    Returns an empty list when the split directory does not exist.
    """
    split_dir = root / split
    return split_dir.glob("*/info/*.json") if split_dir.exists() else []
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def write_json_list_stream(out_path: Path, items: Iterable[Dict]) -> None:
    """Write *items* as one JSON array without materializing the whole list.

    Each element is dumped as it arrives so very large datasets stream to
    disk; the parent directory is created on demand.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as fh:
        fh.write("[\n")
        for idx, item in enumerate(items):
            if idx:
                fh.write(",\n")
            json.dump(item, fh, ensure_ascii=False)
        fh.write("\n]\n")
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def main() -> None:
    """CLI entry: convert OpenLane-V2 subset-B info JSONs into Atlas lane QA.

    Streams one QA record per frame (six camera images, a lane prompt, and a
    binned centerline answer) into --out_json. Frames with unreadable info
    files or missing images are silently skipped.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--openlane_root", type=str, default="/home/guoyuanbo/autodl-tmp/OpenLane-V2")
    # When set, image_paths are stored relative to openlane_root instead of resolved.
    ap.add_argument(
        "--no_absolute_image_paths",
        action="store_true",
        default=False,
    )
    ap.add_argument("--split", type=str, default="train", choices=["train", "val", "test"])
    ap.add_argument("--out_json", type=str, required=True)
    ap.add_argument("--max_samples", type=int, default=0, help="0 means all")
    ap.add_argument("--max_lanes", type=int, default=0,
                    help="Max lanes per sample (0=no limit, paper does not specify a cap)")
    ap.add_argument("--points_per_lane", type=int, default=4)
    # Attach the raw centerline annotation to each record (for debugging/eval).
    ap.add_argument("--include_raw_lane", action="store_true", default=False)
    ap.add_argument("--num_map_queries", type=int, default=256)
    args = ap.parse_args()

    root = Path(args.openlane_root)
    out_path = Path(args.out_json)

    def _gen():
        # Generator of QA records, consumed by write_json_list_stream below.
        n = 0
        for p in iter_info_jsons(root, args.split):
            if args.max_samples and n >= int(args.max_samples):
                break
            try:
                d = json.loads(p.read_text(encoding="utf-8", errors="replace"))
            except Exception:
                # Unparseable info file: skip the frame.
                continue

            sensor = d.get("sensor", None)
            ann = d.get("annotation", {})
            lane_centerline = ann.get("lane_centerline", [])

            image_paths: List[str] = []
            use_absolute = not args.no_absolute_image_paths
            if isinstance(sensor, dict) and all(cam in sensor for cam in CAM_ORDER):
                # Preferred: image paths recorded in the info file itself.
                for cam in CAM_ORDER:
                    rp = str(sensor[cam].get("image_path"))
                    if use_absolute:
                        image_paths.append(str((root / rp).resolve()))
                    else:
                        image_paths.append(rp)
            else:
                # Fallback: derive paths from the directory layout
                # <split>/<segment>/image/<cam>/<timestamp>.jpg.
                seq_dir = p.parent.parent
                ts = p.stem
                for cam in CAM_ORDER:
                    rp = str((Path(args.split) / seq_dir.name / "image" / cam / f"{ts}.jpg").as_posix())
                    if use_absolute:
                        image_paths.append(str((root / rp).resolve()))
                    else:
                        image_paths.append(rp)

            # Drop frames where any of the six views is missing on disk.
            if use_absolute:
                missing = [ip for ip in image_paths if not Path(ip).exists()]
            else:
                missing = [ip for ip in image_paths if not (root / ip).exists()]
            if missing:
                continue

            answer = _lane_answer_from_centerlines(
                lane_centerline=lane_centerline,
                max_lanes=int(args.max_lanes),
                points_per_lane=int(args.points_per_lane),
            )

            prompt = _prompt_lane_qa()

            sample_id = f"openlane_subsetB_{args.split}_{d.get('segment_id','seg')}_{d.get('timestamp','ts')}"
            it: Dict = {
                "id": sample_id,
                "image_paths": image_paths,
                "num_map_queries": int(args.num_map_queries),
                "task": "lane",
                "sensor": sensor,
                "pose": d.get("pose", None),
                "timestamp": d.get("timestamp", None),
                "segment_id": d.get("segment_id", None),
                "meta_data": d.get("meta_data", None),
                "conversations": [
                    {"from": "human", "value": prompt},
                    {"from": "gpt", "value": answer},
                ],
            }
            if args.include_raw_lane:
                it["openlane_lane_centerline"] = lane_centerline
            n += 1
            if n % 1000 == 0:
                print(f"[progress] wrote_samples={n}")
            yield it

    print(f"[start] openlane_root={root} split={args.split} out_json={out_path}")
    write_json_list_stream(out_path, _gen())
    print("[done]")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
| 250 |
+
|
| 251 |
+
|
scripts/gen_atlas_planning_qa.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import math
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
from collections import Counter
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 15 |
+
from src.prompting import PLANNING_TABLE3_MODES, rewrite_planning_prompt_for_table3
|
| 16 |
+
|
| 17 |
+
# Height window (metres, ego frame) accepted when binning z coordinates.
Z_MIN, Z_MAX = -5.0, 3.0
# Clipping range used when binning ego velocity (m/s) and acceleration (m/s^2).
VEL_ACC_RANGE = (-50.0, 50.0)
# BEV extent (metres) used when binning x/y coordinates.
XY_RANGE = (-51.2, 51.2)
# All scalar values are discretized into this many uniform bins.
NUM_BINS = 1000
# Future trajectory: NUM_WAYPOINTS waypoints every WAYPOINT_DT seconds (3 s horizon).
WAYPOINT_DT = 0.5
NUM_WAYPOINTS = 6
# Official UniAD get_sdc_planning_label() uses the terminal lateral offset
# (RIGHT if x >= 2, LEFT if x <= -2, else FORWARD). Our waypoints are already
# in Atlas paper frame, where x is lateral-right and y is forward.
UNIAD_COMMAND_X_THRESHOLD = 2.0

# nuScenes surround-view camera keys, in the order images are listed per sample.
CAMERA_NAMES = [
    'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
    'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _val_to_bin(value: float, min_val: float, max_val: float, num_bins: int = NUM_BINS) -> int:
    """Map a scalar onto a bin index in [0, num_bins - 1], clipping out-of-range input."""
    clipped = float(np.clip(value, min_val, max_val))
    fraction = (clipped - min_val) / (max_val - min_val)
    bin_index = int(round(fraction * (num_bins - 1)))
    return int(np.clip(bin_index, 0, num_bins - 1))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _nuscenes_to_paper_xy(x_fwd: float, y_left: float) -> Tuple[float, float]:
|
| 42 |
+
return (-float(y_left), float(x_fwd))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _derive_uniad_style_command(
    waypoints: List[List[float]],
    lateral_threshold: float = UNIAD_COMMAND_X_THRESHOLD,
) -> str:
    """Derive a coarse 3-way planning command from future GT waypoints.

    Mirrors the semantics of UniAD's `get_sdc_planning_label()`: the lateral
    offset of the last valid future position selects "turn right" /
    "turn left" / "go straight".
    """
    finite_pts: List[Tuple[float, float]] = []
    for wp in waypoints:
        if not isinstance(wp, (list, tuple)) or len(wp) < 2:
            continue
        lateral, forward = float(wp[0]), float(wp[1])
        if np.isfinite(lateral) and np.isfinite(forward):
            finite_pts.append((lateral, forward))

    if not finite_pts:
        return "go straight"

    final_lateral = float(finite_pts[-1][0])
    if final_lateral >= lateral_threshold:
        return "turn right"
    if final_lateral <= -lateral_threshold:
        return "turn left"
    return "go straight"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _compute_velocity(nusc, sample) -> Tuple[float, float]:
    """Ego velocity (x, y) in the current ego frame, via finite differences.

    Uses the LIDAR_TOP sample_data of this keyframe and its previous sweep:
    velocity = (current ego translation - previous) / dt, rotated from the
    global frame into the current ego frame. Returns nuScenes ego-frame
    components (x forward, y left). Best-effort: any failure (no previous
    sweep, missing pyquaternion, malformed record) yields (0.0, 0.0).
    """
    try:
        lidar_token = sample['data']['LIDAR_TOP']
        lidar_data = nusc.get('sample_data', lidar_token)
        from pyquaternion import Quaternion

        ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
        ego_t = np.array(ego_pose['translation'])
        ego_q = Quaternion(ego_pose['rotation'])

        prev_token = lidar_data.get('prev', '')
        if prev_token:
            prev_data = nusc.get('sample_data', prev_token)
            prev_ego = nusc.get('ego_pose', prev_data['ego_pose_token'])
            prev_t = np.array(prev_ego['translation'])
            # Timestamps are microseconds; convert the delta to seconds.
            dt = (lidar_data['timestamp'] - prev_data['timestamp']) * 1e-6
            if dt > 0:
                vel_global = (ego_t - prev_t) / dt
                vel_ego = ego_q.inverse.rotate(vel_global)
                return float(vel_ego[0]), float(vel_ego[1])
    except Exception:
        # Deliberate best-effort: fall through to the zero default.
        pass
    return 0.0, 0.0
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _compute_acceleration(nusc, sample) -> Tuple[float, float]:
    """Ego acceleration (x, y) in the current ego frame, via double differencing.

    Needs two previous LIDAR_TOP sweeps: velocities over the last two
    intervals are differenced and divided by the most recent dt, then rotated
    into the current ego frame. Returns nuScenes ego-frame components
    (x forward, y left). Best-effort: returns (0.0, 0.0) whenever the sweep
    history is too short or any lookup fails.
    """
    try:
        lidar_token = sample['data']['LIDAR_TOP']
        lidar_data = nusc.get('sample_data', lidar_token)
        from pyquaternion import Quaternion

        ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
        ego_q = Quaternion(ego_pose['rotation'])

        prev_token = lidar_data.get('prev', '')
        if not prev_token:
            return 0.0, 0.0
        prev_data = nusc.get('sample_data', prev_token)
        # Timestamps are microseconds; deltas converted to seconds.
        dt1 = (lidar_data['timestamp'] - prev_data['timestamp']) * 1e-6
        if dt1 <= 0:
            return 0.0, 0.0

        prev2_token = prev_data.get('prev', '')
        if not prev2_token:
            return 0.0, 0.0
        prev2_data = nusc.get('sample_data', prev2_token)
        dt2 = (prev_data['timestamp'] - prev2_data['timestamp']) * 1e-6
        if dt2 <= 0:
            return 0.0, 0.0

        def _ego_vel(sd1, sd2, dt_val):
            # Global-frame velocity between two sample_data records.
            e1 = nusc.get('ego_pose', sd1['ego_pose_token'])
            e2 = nusc.get('ego_pose', sd2['ego_pose_token'])
            t1 = np.array(e1['translation'])
            t2 = np.array(e2['translation'])
            return (t1 - t2) / dt_val

        v1_global = _ego_vel(lidar_data, prev_data, dt1)
        v0_global = _ego_vel(prev_data, prev2_data, dt2)
        acc_global = (v1_global - v0_global) / dt1
        acc_ego = ego_q.inverse.rotate(acc_global)
        return float(acc_ego[0]), float(acc_ego[1])
    except Exception:
        return 0.0, 0.0
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _get_future_waypoints(nusc, sample) -> Optional[List[List[float]]]:
    """Interpolated future ego waypoints in the paper frame, or None.

    Walks the LIDAR_TOP sample_data chain forward, collects ego poses, and
    linearly interpolates the ego position at NUM_WAYPOINTS target times
    spaced WAYPOINT_DT seconds apart. Each waypoint is rotated into the
    current ego frame and converted to paper coordinates (x right, y forward).
    Returns None when the pose history does not fully cover the 3 s horizon
    or on any lookup error.
    """
    try:
        from pyquaternion import Quaternion

        lidar_token = sample['data']['LIDAR_TOP']
        lidar_data = nusc.get('sample_data', lidar_token)
        ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
        ego_t = np.array(ego_pose['translation'])
        ego_q = Quaternion(ego_pose['rotation'])

        # Target timestamps in microseconds, WAYPOINT_DT apart.
        current_ts = lidar_data['timestamp']
        target_times = [current_ts + int(WAYPOINT_DT * (i + 1) * 1e6) for i in range(NUM_WAYPOINTS)]

        # Collect sweeps until ~1 s past the last target time.
        all_sd = []
        sd_token = lidar_token
        while sd_token:
            sd = nusc.get('sample_data', sd_token)
            all_sd.append(sd)
            sd_token = sd.get('next', '')
            if not sd_token:
                break
            if sd['timestamp'] > target_times[-1] + 1e6:
                break

        if len(all_sd) < 2:
            return None

        timestamps = np.array([s['timestamp'] for s in all_sd])
        poses = []
        for s in all_sd:
            ep = nusc.get('ego_pose', s['ego_pose_token'])
            poses.append(np.array(ep['translation']))
        poses = np.array(poses)

        waypoints = []
        for tt in target_times:
            # Require full coverage: give up if any target falls outside
            # the recorded pose interval.
            if tt > timestamps[-1] or tt < timestamps[0]:
                return None
            # Bracketing segment for linear interpolation.
            idx = np.searchsorted(timestamps, tt, side='right') - 1
            idx = max(0, min(idx, len(timestamps) - 2))
            dt_seg = timestamps[idx + 1] - timestamps[idx]
            if dt_seg <= 0:
                return None
            alpha = (tt - timestamps[idx]) / dt_seg
            pos_global = poses[idx] * (1 - alpha) + poses[idx + 1] * alpha
            # Express the future position in the current ego frame, then
            # convert to the paper frame (x lateral-right, y forward).
            pos_ego = ego_q.inverse.rotate(pos_global - ego_t)
            x_p, y_p = _nuscenes_to_paper_xy(pos_ego[0], pos_ego[1])
            waypoints.append([float(x_p), float(y_p)])

        return waypoints
    except Exception:
        return None
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _format_planning_answer(
    vx: float, vy: float, ax: float, ay: float,
    waypoints: List[List[float]],
) -> str:
    """Render the planning GT answer with binned speed, acceleration, waypoints."""
    speed = [_val_to_bin(component, *VEL_ACC_RANGE) for component in (vx, vy)]
    accel = [_val_to_bin(component, *VEL_ACC_RANGE) for component in (ax, ay)]
    waypoint_text = ", ".join(
        f"[{_val_to_bin(wp[0], *XY_RANGE)}, {_val_to_bin(wp[1], *XY_RANGE)}]"
        for wp in waypoints
    )
    return (
        f"Ego car speed value:[{speed[0]}, {speed[1]}]. "
        f"Ego car acceleration value:[{accel[0]}, {accel[1]}]. "
        "Based on the ego car speed and acceleration you predicted, "
        f"request the ego car planning waypoint in 3-seconds: {waypoint_text}"
    )
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _collect_gt_boxes_ego(nusc, sample) -> List[Dict]:
    """All GT annotation boxes of this keyframe, expressed in the paper ego frame.

    Each box carries paper-frame center coordinates, size (w/l/h as stored in
    the nuScenes `size` field), a paper-frame yaw, and the category name.
    """
    from pyquaternion import Quaternion

    lidar_token = sample['data']['LIDAR_TOP']
    lidar_data = nusc.get('sample_data', lidar_token)
    ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
    ego_t = np.array(ego_pose['translation'])
    ego_q = Quaternion(ego_pose['rotation'])

    # NOTE(review): calibrated_sensor is fetched but unused below — the boxes
    # are transformed only global -> ego, not into the lidar frame. Confirm
    # whether the lidar extrinsics were meant to be applied.
    cs_record = nusc.get('calibrated_sensor', lidar_data['calibrated_sensor_token'])
    cs_t = np.array(cs_record['translation'])
    cs_q = Quaternion(cs_record['rotation'])

    boxes = []
    for ann_token in sample['anns']:
        ann = nusc.get('sample_annotation', ann_token)
        center_global = np.array(ann['translation'])
        center_ego = ego_q.inverse.rotate(center_global - ego_t)
        x_p, y_p = _nuscenes_to_paper_xy(center_ego[0], center_ego[1])
        yaw_global = Quaternion(ann['rotation'])
        yaw_ego = ego_q.inverse * yaw_global
        # _nuscenes_to_paper_xy applies a 90° CCW rotation:
        #   x_paper = -y_ego, y_paper = x_ego
        # Yaw must be rotated by the same +π/2 to stay consistent.
        yaw_angle = float(yaw_ego.yaw_pitch_roll[0]) + math.pi / 2.0
        w = float(ann['size'][0])
        l = float(ann['size'][1])
        h = float(ann['size'][2])
        boxes.append({
            "world_coords": [float(x_p), float(y_p), float(center_ego[2])],
            "w": w,
            "l": l,
            "h": h,
            "yaw": yaw_angle,
            "category": ann['category_name'],
        })
    return boxes
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _collect_gt_boxes_per_timestep(nusc, sample, num_timesteps=NUM_WAYPOINTS) -> List[List[Dict]]:
    """Collect GT boxes for each future keyframe, transformed to current ego frame.

    ST-P3 protocol: at each future timestep t, collision is checked against
    the actual positions of other agents at time t, not their positions at t=0.
    nuScenes keyframes are ~0.5s apart, matching the waypoint interval.

    Returns a list of num_timesteps box lists; when the scene ends early, the
    last available keyframe's boxes are repeated (or an empty list when there
    is no future keyframe at all).
    """
    from pyquaternion import Quaternion

    lidar_token = sample['data']['LIDAR_TOP']
    lidar_data = nusc.get('sample_data', lidar_token)
    ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
    # Reference pose: all future boxes are expressed in the *current* ego frame.
    ref_ego_t = np.array(ego_pose['translation'])
    ref_ego_q = Quaternion(ego_pose['rotation'])

    per_timestep_boxes: List[List[Dict]] = []
    cur_sample = sample
    for _ in range(num_timesteps):
        next_token = cur_sample.get('next', '')
        if not next_token:
            # Scene ended: pad by repeating the last collected timestep.
            per_timestep_boxes.append(per_timestep_boxes[-1] if per_timestep_boxes else [])
            continue

        cur_sample = nusc.get('sample', next_token)
        boxes = []
        for ann_token in cur_sample['anns']:
            ann = nusc.get('sample_annotation', ann_token)
            center_global = np.array(ann['translation'])
            center_ego = ref_ego_q.inverse.rotate(center_global - ref_ego_t)
            x_p, y_p = _nuscenes_to_paper_xy(center_ego[0], center_ego[1])

            yaw_global = Quaternion(ann['rotation'])
            yaw_ego = ref_ego_q.inverse * yaw_global
            # +π/2 keeps yaw consistent with the paper-frame axis rotation.
            yaw_angle = float(yaw_ego.yaw_pitch_roll[0]) + math.pi / 2.0

            w = float(ann['size'][0])
            l = float(ann['size'][1])
            h = float(ann['size'][2])
            boxes.append({
                "world_coords": [float(x_p), float(y_p), float(center_ego[2])],
                "w": w, "l": l, "h": h,
                "yaw": yaw_angle,
                "category": ann['category_name'],
            })
        per_timestep_boxes.append(boxes)

    return per_timestep_boxes
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def process_sample(
    nusc,
    sample_token: str,
    data_root: Path,
    planning_table3_mode: str,
) -> Optional[Dict]:
    """Build one planning QA item for a single nuScenes sample.

    Returns a dict containing the six camera image paths, ego motion in the
    paper's XY convention, GT boxes, and a human/gpt conversation pair — or
    None when the sample is unusable (missing cameras, no future waypoints)
    or when any step raises.
    """
    try:
        # NOTE(review): Quaternion is imported here but never referenced in
        # this function body — presumably kept for parity with sibling helpers.
        from pyquaternion import Quaternion
        from src.prompting import sample_prompt

        sample = nusc.get('sample', sample_token)

        # Collect surround-camera image paths, normalizing path separators.
        image_paths = []
        for cam_name in CAMERA_NAMES:
            if cam_name in sample['data']:
                cam_token = sample['data'][cam_name]
                cam_data = nusc.get('sample_data', cam_token)
                image_paths.append(cam_data['filename'].replace('\\', '/'))

        # Require a complete 6-camera rig; skip partial samples.
        if len(image_paths) != 6:
            return None

        # Ego velocity/acceleration in nuScenes axes, then remapped into the
        # paper's XY convention.
        vx_n, vy_n = _compute_velocity(nusc, sample)
        ax_n, ay_n = _compute_acceleration(nusc, sample)

        vx_p, vy_p = _nuscenes_to_paper_xy(vx_n, vy_n)
        ax_p, ay_p = _nuscenes_to_paper_xy(ax_n, ay_n)

        waypoints = _get_future_waypoints(nusc, sample)
        if waypoints is None:
            return None

        # Discretize the ego state for the prompt template.
        vx_bin = _val_to_bin(vx_p, *VEL_ACC_RANGE)
        vy_bin = _val_to_bin(vy_p, *VEL_ACC_RANGE)
        ax_bin = _val_to_bin(ax_p, *VEL_ACC_RANGE)
        ay_bin = _val_to_bin(ay_p, *VEL_ACC_RANGE)
        route_command = _derive_uniad_style_command(waypoints)

        prompt = sample_prompt(
            "planning",
            vx_bin=vx_bin, vy_bin=vy_bin,
            ax_bin=ax_bin, ay_bin=ay_bin,
            command=route_command,
        )
        # Rewrite the human prompt into the requested Table-3 variant; the
        # top-level route_command stays UniAD-style regardless of mode.
        prompt = rewrite_planning_prompt_for_table3(
            prompt,
            mode=planning_table3_mode,
            command=route_command,
            velocity_bins=(vx_bin, vy_bin),
            acceleration_bins=(ax_bin, ay_bin),
        )
        answer = _format_planning_answer(vx_p, vy_p, ax_p, ay_p, waypoints)

        # GT boxes for the current frame and per future timestep.
        gt_boxes = _collect_gt_boxes_ego(nusc, sample)
        gt_boxes_per_ts = _collect_gt_boxes_per_timestep(nusc, sample)

        item = {
            "id": sample_token,
            "image_paths": image_paths,
            "num_map_queries": 256,
            "task": "planning",
            "segment_id": sample.get("scene_token", ""),
            "timestamp": sample.get("timestamp", None),
            "ego_motion": {
                "velocity": [vx_p, vy_p],
                "acceleration": [ax_p, ay_p],
                "waypoints": waypoints,
            },
            "gt_boxes_3d": gt_boxes,
            "gt_boxes_3d_per_timestep": gt_boxes_per_ts,
            "conversations": [
                {"from": "human", "value": prompt},
                {"from": "gpt", "value": answer},
            ],
            "route_command": route_command,
        }
        return item
    except Exception:
        # Best-effort per-sample processing: any failure skips the sample.
        # NOTE(review): this also hides real bugs — consider logging here.
        return None
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _audit_results(results: List[Dict], planning_table3_mode: str) -> None:
|
| 387 |
+
total = int(len(results))
|
| 388 |
+
if total == 0:
|
| 389 |
+
print("[AUDIT] No planning samples were generated.")
|
| 390 |
+
return
|
| 391 |
+
|
| 392 |
+
route_commands = [item.get("route_command") for item in results]
|
| 393 |
+
route_command_coverage = sum(isinstance(cmd, str) and bool(cmd) for cmd in route_commands)
|
| 394 |
+
route_command_dist = Counter(route_commands)
|
| 395 |
+
legacy_ego_motion_command = sum(
|
| 396 |
+
1
|
| 397 |
+
for item in results
|
| 398 |
+
if isinstance(item.get("ego_motion"), dict) and "command" in item["ego_motion"]
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
prompt_with_command = 0
|
| 402 |
+
prompt_with_state = 0
|
| 403 |
+
for item in results:
|
| 404 |
+
conv = item.get("conversations", [])
|
| 405 |
+
if not conv:
|
| 406 |
+
continue
|
| 407 |
+
prompt_text = str(conv[0].get("value", ""))
|
| 408 |
+
if "The ego car will " in prompt_text:
|
| 409 |
+
prompt_with_command += 1
|
| 410 |
+
if "The current speed value of the ego car is [" in prompt_text:
|
| 411 |
+
prompt_with_state += 1
|
| 412 |
+
|
| 413 |
+
print(
|
| 414 |
+
"[AUDIT] planning route_command "
|
| 415 |
+
f"mode={planning_table3_mode} "
|
| 416 |
+
f"coverage={route_command_coverage}/{total} "
|
| 417 |
+
f"legacy_ego_motion_command={legacy_ego_motion_command}/{total} "
|
| 418 |
+
f"prompt_with_command={prompt_with_command}/{total} "
|
| 419 |
+
f"prompt_with_state={prompt_with_state}/{total}"
|
| 420 |
+
)
|
| 421 |
+
print(f"[AUDIT] planning route_command distribution={dict(route_command_dist)}")
|
| 422 |
+
print(
|
| 423 |
+
"[AUDIT] route_command semantics: UniAD-style future-GT-derived "
|
| 424 |
+
f"(terminal lateral x threshold={UNIAD_COMMAND_X_THRESHOLD:.1f}m)."
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def main():
    """CLI entry: generate Atlas planning QA samples for one nuScenes split.

    Writes a JSON list of conversation items (one per usable sample) to the
    chosen output path and prints audit statistics afterwards.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', type=str, default='v1.0-trainval')
    parser.add_argument('--split', type=str, default='train', choices=['train', 'val'])
    parser.add_argument('--data-root', type=str, default='/mnt/data/nuscenes')
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument(
        '--planning-table3-mode',
        type=str,
        choices=PLANNING_TABLE3_MODES,
        default='atlas_high_level',
        help=(
            'Human prompt variant to materialize in the generated JSON. '
            'route_command is always written as a top-level UniAD-style '
            'future-GT-derived command.'
        ),
    )
    args = parser.parse_args()

    data_root = Path(args.data_root)
    # Default output lives under <repo>/data, one level above this scripts/ dir.
    script_dir = Path(__file__).parent.absolute()
    project_root = script_dir.parent

    if args.output:
        output_file = Path(args.output)
    else:
        output_file = project_root / "data" / f"atlas_planning_{args.split}_uniad_command.json"

    # Imported lazily so argument parsing works without the nuscenes devkit.
    from nuscenes.nuscenes import NuScenes
    from nuscenes.utils.splits import create_splits_scenes

    nusc = NuScenes(version=args.version, dataroot=str(data_root), verbose=True)

    # Restrict processing to samples whose scene is in the official split.
    splits = create_splits_scenes()
    split_scenes = set(splits[args.split])
    scene_tokens = set()
    for scene in nusc.scene:
        if scene['name'] in split_scenes:
            scene_tokens.add(scene['token'])
    samples_to_process = [s for s in nusc.sample if s['scene_token'] in scene_tokens]

    print(f"Processing {len(samples_to_process)} samples for planning...")
    results = []
    for sample in tqdm(samples_to_process):
        item = process_sample(
            nusc,
            sample['token'],
            data_root,
            planning_table3_mode=args.planning_table3_mode,
        )
        # process_sample returns None for unusable samples (missing cameras,
        # no future waypoints) or on any internal error.
        if item is not None:
            results.append(item)

    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    _audit_results(results, args.planning_table3_mode)
    print(f"Saved {len(results)} planning samples to {output_file}")
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
# Script entry point; guarded so importing this module has no side effects.
if __name__ == "__main__":
    main()
|
| 491 |
+
|
scripts/run_val_extraction.sh
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Offline token extraction orchestrator.
# Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
#
# Phase 1: wait for already-running StreamPETR detection-token extraction.
# Phase 2: launch 4 TopoMLP map-token extraction shards (one per GPU) and
#          fail loudly if any shard exits non-zero.
set -e

if [ "${ATLAS_ALLOW_OFFLINE}" != "1" ]; then
    echo "ERROR: This is an OFFLINE extraction orchestrator." >&2
    echo "It is isolated by default to prevent accidental use." >&2
    echo "If you really need it, set: ATLAS_ALLOW_OFFLINE=1" >&2
    echo "For online training use: bash scripts/train_no_caption_baseline.sh" >&2
    exit 1
fi

cd /home/guoyuanbo/3dtokenizer-atlas
export LD_LIBRARY_PATH=/home/guoyuanbo/3dtokenizer/envs/streampetr/lib:${LD_LIBRARY_PATH:-}
PY=/home/guoyuanbo/3dtokenizer/envs/streampetr/bin/python

LOG_DIR=work_dirs
DET_OUT=work_dirs/precomputed_det_tokens_offline/val
MAP_OUT=work_dirs/precomputed_map_tokens_offline/val

echo "[$(date)] === Waiting for Phase 1 (StreamPETR det) to finish ==="

# Poll until every extract_streampetr_tokens process has exited.
while pgrep -f "extract_streampetr_tokens.py" > /dev/null 2>&1; do
    DET_COUNT=$(find "$DET_OUT" -name "*.pt" 2>/dev/null | wc -l)
    echo "[$(date)] Phase 1 running... det files: $DET_COUNT / ~12038"
    sleep 300
done

DET_FINAL=$(find "$DET_OUT" -name "*.pt" 2>/dev/null | wc -l)
echo "[$(date)] Phase 1 DONE. Total det files: $DET_FINAL"

echo "[$(date)] === Starting Phase 2 (TopoMLP map) ==="
mkdir -p "$MAP_OUT"

# Launch one shard per GPU and record each PID.
# FIX: the previous bare `wait` discarded non-zero shard exit statuses, so a
# failed shard went unnoticed even under `set -e`; wait on each PID instead.
PIDS=()
for i in 0 1 2 3; do
    CUDA_VISIBLE_DEVICES=$i $PY extract_topomlp_tokens.py \
        --topomlp_config configs/topomlp_atlas_aligned.py \
        --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
        --data_json "data/atlas_planning_val_uniad_command.json,data/openlane_subsetB_lane_val_4pt.json" \
        --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
        --output_dir "$MAP_OUT" \
        --shard_id $i --num_shards 4 \
        > "$LOG_DIR/extract_map_val_shard_${i}.log" 2>&1 &
    PIDS+=("$!")
    echo "[$(date)] Phase 2 shard $i launched (PID=$!)"
done

echo "[$(date)] Waiting for Phase 2 to complete..."
FAILED=0
for pid in "${PIDS[@]}"; do
    if ! wait "$pid"; then
        echo "[$(date)] ERROR: Phase 2 worker PID=$pid exited non-zero (see $LOG_DIR/extract_map_val_shard_*.log)" >&2
        FAILED=1
    fi
done

MAP_FINAL=$(find "$MAP_OUT" -name "*.pt" 2>/dev/null | wc -l)
echo "[$(date)] Phase 2 DONE. Total map files: $MAP_FINAL"
echo "[$(date)] === All extraction complete ==="
echo "  Det tokens: $DET_FINAL files in $DET_OUT"
echo "  Map tokens: $MAP_FINAL files in $MAP_OUT"
exit "$FAILED"
|
scripts/train_no_caption_baseline.sh
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Atlas online training: detection + planning + lane (no caption).
# Default: visual_token_mode=online, live frozen StreamPETR + TopoMLP.
# Data: 28k det + 24k plan + 28k lane = ~80k samples/epoch
#
# Usage:
#   bash scripts/train_no_caption_baseline.sh
#   RESUME_CKPT=work_dirs/atlas_no_caption/epoch-4/checkpoint.pt bash scripts/train_no_caption_baseline.sh
#   NUM_GPUS=8 bash scripts/train_no_caption_baseline.sh
#
# For offline mode (read precomputed *_offline token dirs):
#   bash scripts/train_no_caption_baseline_offline.sh
set -e

# Run from the repository root so all relative paths below resolve.
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$PROJECT_ROOT"

# Environment-overridable knobs.
NUM_GPUS=${NUM_GPUS:-4}
PLANNING_TABLE3_MODE=${PLANNING_TABLE3_MODE:-atlas_high_level_ego}

# Pass --resume only when a checkpoint is supplied via RESUME_CKPT.
EXTRA_ARGS=""
if [ -n "$RESUME_CKPT" ]; then
    EXTRA_ARGS="--resume $RESUME_CKPT"
fi

# $EXTRA_ARGS is intentionally unquoted: it must expand to zero-or-more args.
deepspeed --num_gpus "$NUM_GPUS" train_atlas.py \
    --llm_model pretrained/vicuna-7b-v1.5 \
    --data_json data/atlas_nuscenes_train.json,data/atlas_planning_train_uniad_command.json,data/openlane_subsetB_lane_train_4pt.json \
    --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
    --visual_token_mode online \
    --planning_table3_mode "$PLANNING_TABLE3_MODE" \
    --streampetr_config configs/streampetr_atlas_aligned.py \
    --streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
    --topomlp_config configs/topomlp_atlas_aligned.py \
    --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
    --deepspeed configs/ds_zero2.json \
    --output_dir work_dirs/atlas_no_caption_online \
    --epochs 10 \
    --lr 2e-5 \
    --weight_decay 1e-4 \
    --batch_size 1 \
    --gradient_accumulation_steps 2 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --log_steps 100 \
    --save_epochs 1 \
    --keep_last_n_ckpts 3 \
    --seed 42 \
    --num_workers 4 \
    $EXTRA_ARGS
|
scripts/train_no_caption_baseline_offline.sh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# [OFFLINE fallback] Atlas training using precomputed *_offline visual tokens.
# This is NOT the default. Use scripts/train_no_caption_baseline.sh for online mode.
# Isolated by default. Set ATLAS_ALLOW_OFFLINE=1 to run.
set -e

# Safety gate: refuse to run unless explicitly enabled via the environment.
if [ "${ATLAS_ALLOW_OFFLINE}" != "1" ]; then
    echo "ERROR: This is an OFFLINE fallback script, not the primary online training." >&2
    echo "It is isolated by default to prevent accidental use in experiments." >&2
    echo "If you really need it, set: ATLAS_ALLOW_OFFLINE=1" >&2
    echo "For production training use: bash scripts/train_no_caption_baseline.sh" >&2
    exit 1
fi

# Run from the repository root so all relative paths below resolve.
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$PROJECT_ROOT"

# Environment-overridable knobs.
NUM_GPUS=${NUM_GPUS:-4}
PLANNING_TABLE3_MODE=${PLANNING_TABLE3_MODE:-atlas_base}

# Pass --resume only when a checkpoint is supplied via RESUME_CKPT.
EXTRA_ARGS=""
if [ -n "$RESUME_CKPT" ]; then
    EXTRA_ARGS="--resume $RESUME_CKPT"
fi

# Offline mode reads precomputed det/map token dirs instead of running the
# frozen StreamPETR/TopoMLP models live.
deepspeed --num_gpus "$NUM_GPUS" train_atlas.py \
    --llm_model pretrained/vicuna-7b-v1.5 \
    --data_json data/atlas_nuscenes_train.json,data/atlas_planning_train_uniad_command.json,data/openlane_subsetB_lane_train_4pt.json \
    --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
    --visual_token_mode offline \
    --planning_table3_mode "$PLANNING_TABLE3_MODE" \
    --precomputed_det_tokens work_dirs/precomputed_det_tokens_offline/train \
    --precomputed_map_tokens work_dirs/precomputed_map_tokens_offline/train \
    --deepspeed configs/ds_zero2.json \
    --output_dir work_dirs/atlas_no_caption_offline \
    --epochs 10 \
    --lr 2e-5 \
    --weight_decay 1e-4 \
    --batch_size 1 \
    --gradient_accumulation_steps 2 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --log_steps 100 \
    --save_epochs 1 \
    --keep_last_n_ckpts 3 \
    --seed 42 \
    --num_workers 4 \
    $EXTRA_ARGS
|
scripts/train_with_caption_balanced.sh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Atlas online training: detection + planning + lane + caption.
# Default: visual_token_mode=online, live frozen StreamPETR + TopoMLP.
# Data: 28k det + 24k plan + 28k lane + 169k caption = ~249k samples/epoch
# WARNING: caption data is 6x larger than other tasks — consider downsampling.
#
# Usage:
#   bash scripts/train_with_caption_balanced.sh
#   RESUME_CKPT=work_dirs/atlas_with_caption/epoch-4/checkpoint.pt bash scripts/train_with_caption_balanced.sh
#   NUM_GPUS=8 bash scripts/train_with_caption_balanced.sh
set -e

# Run from the repository root so all relative paths below resolve.
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$PROJECT_ROOT"

# Environment-overridable knobs.
NUM_GPUS=${NUM_GPUS:-4}
PLANNING_TABLE3_MODE=${PLANNING_TABLE3_MODE:-atlas_high_level_ego}

# Pass --resume only when a checkpoint is supplied via RESUME_CKPT.
EXTRA_ARGS=""
if [ -n "$RESUME_CKPT" ]; then
    EXTRA_ARGS="--resume $RESUME_CKPT"
fi

# Same as train_no_caption_baseline.sh but with the caption JSON appended
# and fewer epochs (8 vs 10) to compensate for the larger epoch size.
deepspeed --num_gpus "$NUM_GPUS" train_atlas.py \
    --llm_model pretrained/vicuna-7b-v1.5 \
    --data_json data/atlas_nuscenes_train.json,data/atlas_planning_train_uniad_command.json,data/openlane_subsetB_lane_train_4pt.json,data/atlas_caption_train.json \
    --data_root /home/guoyuanbo/autodl-tmp/data/nuscenes \
    --visual_token_mode online \
    --planning_table3_mode "$PLANNING_TABLE3_MODE" \
    --streampetr_config configs/streampetr_atlas_aligned.py \
    --streampetr_ckpt pretrained/streampetr/streampetr_eva02_ep24.pth \
    --topomlp_config configs/topomlp_atlas_aligned.py \
    --topomlp_ckpt work_dirs/topomlp_atlas_aligned/epoch_24.pth \
    --deepspeed configs/ds_zero2.json \
    --output_dir work_dirs/atlas_with_caption_online \
    --epochs 8 \
    --lr 2e-5 \
    --weight_decay 1e-4 \
    --batch_size 1 \
    --gradient_accumulation_steps 2 \
    --warmup_ratio 0.03 \
    --max_grad_norm 1.0 \
    --log_steps 100 \
    --save_epochs 1 \
    --keep_last_n_ckpts 3 \
    --seed 42 \
    --num_workers 4 \
    $EXTRA_ARGS
|
scripts/vis_atlas_lane_gt_pred.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Visualize one Atlas lane sample with denser diagnostics.
|
| 4 |
+
|
| 5 |
+
Compared with the earlier plot, this version:
|
| 6 |
+
- keeps the original 4-point supervision visible,
|
| 7 |
+
- densifies those 4 points into easier-to-read curves,
|
| 8 |
+
- optionally overlays raw OpenLane GT centerlines when available,
|
| 9 |
+
- keeps all BEV views on the same axes for fair visual comparison.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
# Optional SciPy dependency: cubic resampling in `_resample_lane` needs
# `interp1d`; fall back to piecewise-linear `np.interp` when SciPy is absent.
try:
    from scipy.interpolate import interp1d

    SCIPY_AVAILABLE = True
except Exception:
    interp1d = None
    SCIPY_AVAILABLE = False

# Make the repository root importable so project modules resolve when this
# script is executed directly from the scripts/ directory.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _load_json(path: Path):
|
| 36 |
+
with path.open("r", encoding="utf-8") as f:
|
| 37 |
+
return json.load(f)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _find_eval_record(eval_obj: Dict, sample_id: str) -> Dict:
|
| 41 |
+
for rec in eval_obj.get("predictions", []):
|
| 42 |
+
if rec.get("sample_id") == sample_id:
|
| 43 |
+
return rec
|
| 44 |
+
raise KeyError(f"sample_id not found in eval_json: {sample_id}")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _find_data_item(data_list: List[Dict], sample_id: str) -> Dict:
|
| 48 |
+
for it in data_list:
|
| 49 |
+
if it.get("id") == sample_id:
|
| 50 |
+
return it
|
| 51 |
+
raise KeyError(f"sample_id not found in data_json: {sample_id}")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _extract_gt_answer(item: Dict) -> str:
|
| 55 |
+
for turn in item.get("conversations", []) or []:
|
| 56 |
+
if isinstance(turn, dict) and turn.get("from") in ("gpt", "assistant"):
|
| 57 |
+
return str(turn.get("value", ""))
|
| 58 |
+
return ""
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _openlane_to_paper_xy(x_fwd: float, y_left: float) -> Tuple[float, float]:
|
| 62 |
+
return (-float(y_left), float(x_fwd))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _extract_xy_rows(points: Sequence) -> np.ndarray:
|
| 66 |
+
rows: List[List[float]] = []
|
| 67 |
+
for pt in points or []:
|
| 68 |
+
if isinstance(pt, dict):
|
| 69 |
+
wc = pt.get("world_coords", None)
|
| 70 |
+
if wc is None:
|
| 71 |
+
continue
|
| 72 |
+
rows.append([float(wc[0]), float(wc[1])])
|
| 73 |
+
elif isinstance(pt, (list, tuple)) and len(pt) >= 2:
|
| 74 |
+
rows.append([float(pt[0]), float(pt[1])])
|
| 75 |
+
if not rows:
|
| 76 |
+
return np.zeros((0, 2), dtype=np.float64)
|
| 77 |
+
return np.asarray(rows, dtype=np.float64)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _lanes_as_arrays(parsed: List[Dict]) -> List[np.ndarray]:
    """Convert parsed objects of type "lane" into non-empty (N, 2) XY arrays."""
    lane_objs = (obj for obj in parsed if obj.get("type") == "lane")
    arrays = (_extract_xy_rows(obj.get("points", []) or []) for obj in lane_objs)
    return [arr for arr in arrays if len(arr) >= 1]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _bounds_xy(lanes: Iterable[np.ndarray]) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
| 92 |
+
lanes = [ln for ln in lanes if len(ln) > 0]
|
| 93 |
+
if not lanes:
|
| 94 |
+
return None
|
| 95 |
+
allp = np.concatenate(lanes, axis=0)
|
| 96 |
+
return allp.min(axis=0), allp.max(axis=0)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _lane_bev_distance(lane: np.ndarray) -> float:
|
| 100 |
+
if len(lane) == 0:
|
| 101 |
+
return float("inf")
|
| 102 |
+
centroid = lane.mean(axis=0)
|
| 103 |
+
return float(centroid[0] ** 2 + centroid[1] ** 2)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _select_closest_lanes(lanes: List[np.ndarray], max_lanes: int) -> List[np.ndarray]:
    """Keep at most *max_lanes* lanes, preferring centroids nearest the ego origin.

    A non-positive *max_lanes* disables the cap and returns a shallow copy.
    """
    if max_lanes <= 0 or len(lanes) <= max_lanes:
        return list(lanes)
    ranked = sorted(lanes, key=_lane_bev_distance)
    return ranked[:max_lanes]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _candidate_openlane_paths(openlane_root: Path, item: Dict) -> List[Path]:
|
| 114 |
+
segment_id = str(item.get("segment_id", "")).strip()
|
| 115 |
+
timestamp = str(item.get("timestamp", "")).strip()
|
| 116 |
+
if not segment_id or not timestamp:
|
| 117 |
+
return []
|
| 118 |
+
rel = Path("val") / segment_id / "info" / f"{timestamp}.json"
|
| 119 |
+
return [
|
| 120 |
+
openlane_root / "subset_B" / rel,
|
| 121 |
+
openlane_root / rel,
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _load_raw_openlane_gt_lanes(openlane_root: Path, item: Dict) -> Tuple[List[np.ndarray], Optional[Path]]:
    """Load raw OpenLane GT centerlines (paper XY) from the first existing info file.

    Returns (lanes, info_path); lanes with fewer than 2 valid points are
    dropped. When no candidate file exists, returns ([], None).
    """
    for candidate in _candidate_openlane_paths(openlane_root, item):
        if not candidate.exists():
            continue
        payload = _load_json(candidate)
        annotation = payload.get("annotation", {}) or {}
        lanes: List[np.ndarray] = []
        for centerline in annotation.get("lane_centerline", []) or []:
            xy_rows = [
                list(_openlane_to_paper_xy(float(pt[0]), float(pt[1])))
                for pt in centerline.get("points", []) or []
                if isinstance(pt, (list, tuple)) and len(pt) >= 2
            ]
            if len(xy_rows) >= 2:
                lanes.append(np.asarray(xy_rows, dtype=np.float64))
        return lanes, candidate
    return [], None
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _resample_lane(lane: np.ndarray, num: int = 41) -> np.ndarray:
    """Resample a polyline to `num` points spaced uniformly in arc length.

    Only densifies: when the lane already has >= `num` points, or has fewer
    than two points, a copy is returned unchanged. Uses cubic interpolation
    when SciPy is available and there are >= 4 distinct points; otherwise
    falls back to piecewise-linear `np.interp`.
    """
    # Degenerate input or already dense enough — return unchanged copy.
    if len(lane) <= 1 or num <= len(lane):
        return lane.copy()

    # Drop near-duplicate consecutive points so arc lengths are strictly
    # increasing (interp1d requires a monotonically increasing x sequence).
    diffs = np.diff(lane, axis=0)
    seg_len = np.linalg.norm(diffs, axis=1)
    keep = np.concatenate(([True], seg_len > 1e-8))
    lane = lane[keep]
    if len(lane) <= 1:
        return lane.copy()

    # Cumulative arc-length parameterization of the de-duplicated lane.
    diffs = np.diff(lane, axis=0)
    seg_len = np.linalg.norm(diffs, axis=1)
    arc = np.concatenate(([0.0], np.cumsum(seg_len)))
    total = float(arc[-1])
    if total <= 1e-8:
        return lane.copy()

    target = np.linspace(0.0, total, num=num)
    # Prefer a smooth cubic fit; silently fall back to linear on any failure.
    if SCIPY_AVAILABLE and len(lane) >= 4:
        try:
            fx = interp1d(arc, lane[:, 0], kind="cubic")
            fy = interp1d(arc, lane[:, 1], kind="cubic")
            dense = np.stack([fx(target), fy(target)], axis=1)
            return dense.astype(np.float64)
        except Exception:
            pass

    # Piecewise-linear fallback along the same arc-length targets.
    x = np.interp(target, arc, lane[:, 0])
    y = np.interp(target, arc, lane[:, 1])
    return np.stack([x, y], axis=1).astype(np.float64)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _draw_ego(ax, *, color: str = "k"):
    """Draw the ego vehicle footprint, heading arrow, and origin marker on *ax*."""
    import matplotlib.patches as patches

    width, length = 1.85, 4.084
    footprint = patches.Rectangle(
        (-width / 2.0, -length / 2.0),
        width,
        length,
        linewidth=1.2,
        edgecolor=color,
        facecolor="none",
        zorder=8,
    )
    ax.add_patch(footprint)
    # Heading arrow along +Y, the forward axis in this BEV convention.
    ax.arrow(0.0, 0.0, 0.0, length * 0.7, head_width=0.6, head_length=0.8, fc=color, ec=color, zorder=9)
    ax.scatter([0.0], [0.0], s=18, c=color, zorder=10)
    ax.text(0.2, -0.5, "EGO", color=color, fontsize=8, zorder=10)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _set_bev_axes(ax, *, title: str, xlim, ylim):
    """Apply the shared BEV-panel styling (title, labels, grid, limits).

    Equal aspect plus identical xlim/ylim across panels is what keeps the
    side-by-side GT/pred comparison visually fair.
    """
    ax.set_title(title, fontsize=10)
    # BEV convention here: X points right, Y points forward.
    ax.set_xlabel("X (m) — right", fontsize=8)
    ax.set_ylabel("Y (m) — forward", fontsize=8)
    ax.grid(True, linewidth=0.4, alpha=0.4)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _plot_colorful_panel(
    ax,
    lanes: List[np.ndarray],
    *,
    title: str,
    xlim,
    ylim,
    raw_reference_lanes: Optional[List[np.ndarray]] = None,
    show_control_points: bool = True,
    dense_points: int = 41,
):
    """Draw one BEV panel with each lane in a distinct tab20 color.

    Optionally underlays raw reference lanes in gray and marks the original
    control points of each densified lane. The ego footprint is drawn last.
    """
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    _set_bev_axes(ax, title=title, xlim=xlim, ylim=ylim)
    cmap = plt.get_cmap("tab20")

    # Gray dashed underlay for the raw reference lanes (lowest z-order).
    if raw_reference_lanes:
        for ln in raw_reference_lanes:
            ax.plot(ln[:, 0], ln[:, 1], "--", linewidth=1.0, color="0.60", alpha=0.50, zorder=1)

    for i, lane in enumerate(lanes):
        # Cycle through the 20 tab20 colors per lane index.
        color = cmap(i % 20)
        dense = _resample_lane(lane, num=dense_points)
        ax.plot(dense[:, 0], dense[:, 1], "-", linewidth=2.0, color=color, alpha=0.95, zorder=3)
        if show_control_points:
            # Hollow markers at the original (undensified) control points.
            ax.scatter(
                lane[:, 0],
                lane[:, 1],
                s=20,
                facecolors="white",
                edgecolors=[color],
                linewidths=0.9,
                zorder=4,
            )

    # Legend uses proxy artists since lanes are drawn in per-lane colors.
    handles = []
    if raw_reference_lanes:
        handles.append(Line2D([0], [0], color="0.60", linestyle="--", linewidth=1.2, label="Raw GT"))
    handles.append(Line2D([0], [0], color="black", linewidth=2.0, label="Densified 4pt"))
    if show_control_points:
        handles.append(
            Line2D(
                [0],
                [0],
                marker="o",
                color="black",
                markerfacecolor="white",
                markersize=5,
                linewidth=0.0,
                label="Control points",
            )
        )
    ax.legend(handles=handles, loc="upper left", fontsize=7, frameon=True)
    _draw_ego(ax, color="k")
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def _plot_overlay_panel(
    ax,
    *,
    gt_lanes: List[np.ndarray],
    pred_lanes: List[np.ndarray],
    raw_gt_lanes: Optional[List[np.ndarray]],
    xlim,
    ylim,
    dense_points: int = 41,
):
    """Draw the GT-vs-prediction overlay panel.

    Raw GT is gray dashed (bottom), densified 4-point GT is green with hollow
    circle control points, and densified predictions are red with "x" control
    points (top). The ego footprint is drawn last.
    """
    from matplotlib.lines import Line2D

    _set_bev_axes(
        ax,
        title="Overlay: raw GT(gray), 4pt GT(green), Pred(red)",
        xlim=xlim,
        ylim=ylim,
    )

    # Raw GT underlay at the lowest z-order.
    if raw_gt_lanes:
        for lane in raw_gt_lanes:
            ax.plot(lane[:, 0], lane[:, 1], "--", linewidth=1.1, color="0.55", alpha=0.45, zorder=1)

    # Densified GT curves plus their original control points.
    for lane in gt_lanes:
        dense = _resample_lane(lane, num=dense_points)
        ax.plot(dense[:, 0], dense[:, 1], "-", linewidth=2.0, color="green", alpha=0.78, zorder=3)
        ax.scatter(lane[:, 0], lane[:, 1], s=18, facecolors="white", edgecolors="green", linewidths=0.8, zorder=4)

    # Densified predictions drawn above GT so disagreements stay visible.
    for lane in pred_lanes:
        dense = _resample_lane(lane, num=dense_points)
        ax.plot(dense[:, 0], dense[:, 1], "-", linewidth=2.0, color="red", alpha=0.78, zorder=5)
        ax.scatter(lane[:, 0], lane[:, 1], s=18, marker="x", color="red", linewidths=0.9, zorder=6)

    # Legend proxies (actual artists are per-lane).
    handles = [
        Line2D([0], [0], color="0.55", linestyle="--", linewidth=1.1, label="Raw GT"),
        Line2D([0], [0], color="green", linewidth=2.0, label="Densified GT (4pt)"),
        Line2D([0], [0], marker="o", color="green", markerfacecolor="white", markersize=5, linewidth=0.0, label="GT control"),
        Line2D([0], [0], color="red", linewidth=2.0, label="Densified Pred (4pt)"),
        Line2D([0], [0], marker="x", color="red", markersize=5, linewidth=0.0, label="Pred control"),
    ]
    ax.legend(handles=handles, loc="upper left", fontsize=7, frameon=True)
    _draw_ego(ax, color="k")
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _load_images(data_root: Path, image_paths: List[str]):
    """Open every path in *image_paths* as an RGB PIL image.

    Relative paths are resolved under *data_root*; absolute paths are used as-is.
    """
    from PIL import Image

    def _resolve(rel: str) -> Path:
        cand = Path(rel)
        return cand if cand.is_absolute() else data_root / rel

    return [Image.open(_resolve(rel)).convert("RGB") for rel in image_paths]
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def parse_args():
    """CLI arguments for the lane-diagnostics visualization script."""
    parser = argparse.ArgumentParser()
    # Input artifacts: eval predictions, dataset items, raw data roots.
    parser.add_argument("--eval_json", type=str, default="work_dirs/eval_nocap_ep2_lane50.json")
    parser.add_argument("--data_json", type=str, default="data/openlane_subsetB_lane_val_4pt.json")
    parser.add_argument("--data_root", type=str, default="/home/guoyuanbo/autodl-tmp/data/nuscenes")
    parser.add_argument("--openlane_root", type=str, default="/home/guoyuanbo/autodl-tmp/OpenLane-V2")
    parser.add_argument("--sample_id", type=str, required=True)
    # Output figure and rendering controls.
    parser.add_argument("--out_png", type=str, default="work_dirs/atlas_lane_real_prediction_dense.png")
    parser.add_argument("--xlim", type=float, nargs=2, default=[-30.0, 30.0])
    parser.add_argument("--ylim", type=float, nargs=2, default=[-55.0, 55.0])
    parser.add_argument("--dpi", type=int, default=170)
    parser.add_argument("--dense_points", type=int, default=41)
    return parser.parse_args()
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def main():
    """Render a lane-diagnostics figure for one sample.

    Layout: a 2x3 camera mosaic (left), GT and Pred BEV panels (right column),
    a GT-vs-Pred overlay (bottom), and a monospace text summary panel.
    """
    args = parse_args()
    # Relative JSON/PNG paths are resolved against the repository root
    # (the grandparent of this script file).
    repo = Path(__file__).resolve().parent.parent
    eval_path = (repo / args.eval_json).resolve()
    data_path = (repo / args.data_json).resolve()
    data_root = Path(args.data_root).resolve()
    openlane_root = Path(args.openlane_root).resolve()
    out_png = (repo / args.out_png).resolve()

    eval_obj = _load_json(eval_path)
    data_list = _load_json(data_path)

    # Pair the eval prediction record with its source dataset item.
    rec = _find_eval_record(eval_obj, args.sample_id)
    item = _find_data_item(data_list, args.sample_id)

    pred_text = str(rec.get("generated_text", ""))
    gt_text = _extract_gt_answer(item)

    from src.eval.metrics import parse_atlas_output

    # Parse the free-form model/GT text into structured lane outputs.
    pred_parsed = parse_atlas_output(pred_text)
    gt_parsed = parse_atlas_output(gt_text)

    pred_lanes = _lanes_as_arrays(pred_parsed)
    gt_lanes_all = _lanes_as_arrays(gt_parsed)
    raw_gt_lanes_all, raw_gt_path = _load_raw_openlane_gt_lanes(openlane_root, item)

    # `num_gt` from the eval record limits the visualized GT lane subset;
    # malformed/missing values fall back to 0 (the text panel then reports
    # "uses full GT" — presumably _select_closest_lanes keeps all for 0; confirm).
    expected_gt_count = rec.get("num_gt", 0)
    try:
        expected_gt_count = int(expected_gt_count)
    except Exception:
        expected_gt_count = 0
    gt_lanes = _select_closest_lanes(gt_lanes_all, expected_gt_count)
    raw_gt_lanes = _select_closest_lanes(raw_gt_lanes_all, expected_gt_count)

    xlim = (float(args.xlim[0]), float(args.xlim[1]))
    ylim = (float(args.ylim[0]), float(args.ylim[1]))

    # Mean XY shift between prediction and GT control points (coarse bias check).
    delta = None
    if pred_lanes and gt_lanes:
        p_all = np.concatenate(pred_lanes, axis=0)
        g_all = np.concatenate(gt_lanes, axis=0)
        delta = (p_all.mean(axis=0) - g_all.mean(axis=0)).tolist()

    imgs = _load_images(data_root, item.get("image_paths", []) or [])
    if len(imgs) != 6:
        raise RuntimeError(f"expected 6 images, got {len(imgs)}")

    # Reorder the data-json camera order into the mosaic layout:
    # front-left/front/front-right on the top row, back-* below.
    order = [2, 0, 1, 4, 3, 5]
    cam_titles = ["FRONT_LEFT", "FRONT", "FRONT_RIGHT", "BACK_LEFT", "BACK", "BACK_RIGHT"]
    imgs = [imgs[i] for i in order]

    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec

    fig = plt.figure(figsize=(15.5, 9.4), dpi=args.dpi)
    fig.suptitle("Atlas Lane Diagnostics — Control Points, Densified Curves, and Raw GT", fontsize=14, y=0.985)

    # 3x4 grid: rows 0-1 hold the camera mosaic (cols 0-2) and the GT/Pred BEV
    # panels (col 3); row 2 holds the overlay (cols 0-2) and the text panel.
    gs = GridSpec(
        3,
        4,
        figure=fig,
        width_ratios=[1.0, 1.0, 1.0, 1.10],
        height_ratios=[1.0, 1.0, 1.08],
        wspace=0.24,
        hspace=0.33,
    )

    for i in range(6):
        r = 0 if i < 3 else 1
        c = i % 3
        ax = fig.add_subplot(gs[r, c])
        ax.imshow(imgs[i])
        ax.set_title(cam_titles[i], fontsize=9)
        ax.axis("off")

    ax_gt = fig.add_subplot(gs[0, 3])
    _plot_colorful_panel(
        ax_gt,
        gt_lanes,
        title=f"GT: 4pt densified ({len(gt_lanes)} lanes)",
        xlim=xlim,
        ylim=ylim,
        raw_reference_lanes=raw_gt_lanes or None,
        dense_points=int(args.dense_points),
    )

    ax_pr = fig.add_subplot(gs[1, 3])
    _plot_colorful_panel(
        ax_pr,
        pred_lanes,
        title=f"Pred: 4pt densified ({len(pred_lanes)} lanes)",
        xlim=xlim,
        ylim=ylim,
        raw_reference_lanes=None,
        dense_points=int(args.dense_points),
    )

    ax_ov = fig.add_subplot(gs[2, 0:3])
    _plot_overlay_panel(
        ax_ov,
        gt_lanes=gt_lanes,
        pred_lanes=pred_lanes,
        raw_gt_lanes=raw_gt_lanes or None,
        xlim=xlim,
        ylim=ylim,
        dense_points=int(args.dense_points),
    )

    # Text summary panel (bottom-right of the grid).
    ax_txt = fig.add_subplot(gs[2, 3])
    ax_txt.axis("off")

    lane_metrics = eval_obj.get("metrics", {}).get("lane", {})
    # Model name is derived from the checkpoint's parent directory; "atlas" fallback.
    model_name = Path(str(eval_obj.get("args", {}).get("checkpoint", ""))).parent.name or "atlas"
    lines = [
        "Atlas Lane Diagnostics",
        f"Model: {model_name}",
        f"Sample: {args.sample_id}",
        "",
        f"GT 4pt lanes: {len(gt_lanes)} vis / {len(gt_lanes_all)} full, {sum(len(x) for x in gt_lanes)} ctrl pts",
        f"Pred 4pt lanes: {len(pred_lanes)} lanes, {sum(len(x) for x in pred_lanes)} ctrl pts",
        f"Raw OpenLane GT: {len(raw_gt_lanes)} vis / {len(raw_gt_lanes_all)} full, {sum(len(x) for x in raw_gt_lanes)} pts",
        f"Curve densifier: {'cubic spline' if SCIPY_AVAILABLE else 'linear'} ({args.dense_points} pts/lane)",
        "",
        "Interpretation:",
        " - white circles: GT 4pt control points",
        " - red x: pred 4pt control points",
        " - gray dashed: raw OpenLane GT centerlines",
        " - green/red solid: densified 4pt curves",
        f" - vis GT subset matches eval num_gt={expected_gt_count}" if expected_gt_count > 0 else " - vis GT subset uses full GT",
        "",
        "Eval metrics (from eval_json):",
    ]
    for k in ("lane_f1", "method", "num_samples"):
        if k in lane_metrics:
            lines.append(f" {k}: {lane_metrics[k]}")

    lines.append("")
    if delta is not None:
        lines.append("Mean shift (Pred - GT 4pt):")
        lines.append(f" dx={delta[0]:+.3f} m, dy={delta[1]:+.3f} m")
        lines.append("")

    # XY extents of each lane set, for quick sanity-checking of coordinate frames.
    gb = _bounds_xy(gt_lanes)
    pb = _bounds_xy(pred_lanes)
    rb = _bounds_xy(raw_gt_lanes)
    if gb is not None:
        lines.append(f"GT4 bounds: x[{gb[0][0]:+.1f},{gb[1][0]:+.1f}] y[{gb[0][1]:+.1f},{gb[1][1]:+.1f}]")
    if pb is not None:
        lines.append(f"Pred bounds: x[{pb[0][0]:+.1f},{pb[1][0]:+.1f}] y[{pb[0][1]:+.1f},{pb[1][1]:+.1f}]")
    if rb is not None:
        lines.append(f"Raw bounds: x[{rb[0][0]:+.1f},{rb[1][0]:+.1f}] y[{rb[0][1]:+.1f},{rb[1][1]:+.1f}]")
    if raw_gt_path is not None:
        lines.append("")
        lines.append(f"Raw GT source: {raw_gt_path}")

    ax_txt.text(0.0, 1.0, "\n".join(lines), va="top", ha="left", fontsize=8, family="monospace")

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)

    print(f"[saved] {out_png}")
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
# CLI entry point.
if __name__ == "__main__":
    main()
|
scripts/vis_atlas_planning_qualitative.py
ADDED
|
@@ -0,0 +1,800 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Paper-style qualitative planning visualization (Figure-4-like):
|
| 4 |
+
|
| 5 |
+
- Left: 6-camera mosaic (2x3)
|
| 6 |
+
- Right: BEV panel with planned trajectory
|
| 7 |
+
|
| 8 |
+
Optionally overlays the planned trajectory onto CAM_FRONT using *fixed* nuScenes
|
| 9 |
+
camera calibration (loaded from v1.0-trainval sensor/calibrated_sensor tables).
|
| 10 |
+
This avoids scanning the 1.3GB sample_data.json.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import json
|
| 17 |
+
import math
|
| 18 |
+
import mmap
|
| 19 |
+
import pickle
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Dict, Iterable, List, Optional, Tuple
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
# Make the repository root importable so `src.*` modules resolve when this
# file is run as a standalone script (main imports src.eval.metrics).
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Camera channel order as the data JSON stores image paths — presumably
# matches each item's `image_paths` field; confirm against the dataset builder.
CAM_ORDER_DATAJSON = [
    "CAM_FRONT",
    "CAM_FRONT_RIGHT",
    "CAM_FRONT_LEFT",
    "CAM_BACK",
    "CAM_BACK_LEFT",
    "CAM_BACK_RIGHT",
]

# Camera order used by the paper-style 2x3 mosaic (front row, then back row).
CAM_ORDER_PAPER = [
    "CAM_FRONT_LEFT",
    "CAM_FRONT",
    "CAM_FRONT_RIGHT",
    "CAM_BACK_LEFT",
    "CAM_BACK",
    "CAM_BACK_RIGHT",
]

# Index permutation mapping CAM_ORDER_DATAJSON positions to CAM_ORDER_PAPER.
_IDX_REORDER = [2, 0, 1, 4, 3, 5]  # data_json -> paper layout
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _load_json(path: Path):
|
| 53 |
+
with path.open("r", encoding="utf-8") as f:
|
| 54 |
+
return json.load(f)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _load_pickle(path: Path):
|
| 58 |
+
with path.open("rb") as f:
|
| 59 |
+
return pickle.load(f)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _extract_results_list_for_token(mm: "mmap.mmap", token: str) -> List[Dict]:
|
| 63 |
+
"""
|
| 64 |
+
Extract `results[token]` list from a nuScenes detection results JSON mmap.
|
| 65 |
+
|
| 66 |
+
The file is a minified JSON like:
|
| 67 |
+
{"meta": {...}, "results": {"<token>": [ {...}, ... ], ...}}
|
| 68 |
+
"""
|
| 69 |
+
pat = (f"\"{token}\":").encode("utf-8")
|
| 70 |
+
idx = mm.find(pat)
|
| 71 |
+
if idx < 0:
|
| 72 |
+
return []
|
| 73 |
+
j = idx + len(pat)
|
| 74 |
+
# Skip whitespace
|
| 75 |
+
while j < len(mm) and mm[j] in b" \t\r\n":
|
| 76 |
+
j += 1
|
| 77 |
+
if j >= len(mm) or mm[j : j + 1] != b"[":
|
| 78 |
+
return []
|
| 79 |
+
|
| 80 |
+
start = j
|
| 81 |
+
depth = 0
|
| 82 |
+
in_str = False
|
| 83 |
+
esc = False
|
| 84 |
+
k = start
|
| 85 |
+
end = None
|
| 86 |
+
while k < len(mm):
|
| 87 |
+
c = mm[k]
|
| 88 |
+
if in_str:
|
| 89 |
+
if esc:
|
| 90 |
+
esc = False
|
| 91 |
+
elif c == 0x5C: # backslash
|
| 92 |
+
esc = True
|
| 93 |
+
elif c == 0x22: # quote
|
| 94 |
+
in_str = False
|
| 95 |
+
else:
|
| 96 |
+
if c == 0x22:
|
| 97 |
+
in_str = True
|
| 98 |
+
elif c == 0x5B: # [
|
| 99 |
+
depth += 1
|
| 100 |
+
elif c == 0x5D: # ]
|
| 101 |
+
depth -= 1
|
| 102 |
+
if depth == 0:
|
| 103 |
+
end = k + 1
|
| 104 |
+
break
|
| 105 |
+
k += 1
|
| 106 |
+
|
| 107 |
+
if end is None:
|
| 108 |
+
return []
|
| 109 |
+
try:
|
| 110 |
+
return json.loads(mm[start:end].decode("utf-8"))
|
| 111 |
+
except Exception:
|
| 112 |
+
return []
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _load_ego_pose_map(nuscenes_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
    """
    Build a map: sample_token -> (R_ego2global(3x3), t_ego2global(3,)).

    Reads nuscenes_infos_val.pkl under *nuscenes_root*; entries with a missing
    token or malformed quaternion are skipped.
    """
    info_pkl = nuscenes_root / "nuscenes_infos_val.pkl"
    if not info_pkl.exists():
        raise FileNotFoundError(f"Missing {info_pkl} (required to transform detector global boxes to ego frame).")
    loaded = _load_pickle(info_pkl)
    # The pkl may be either {"infos": [...]} or a bare list of records.
    records = loaded["infos"] if isinstance(loaded, dict) and "infos" in loaded else loaded

    pose_map: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}
    for record in records:
        token = str(record.get("token", ""))
        if not token:
            continue
        trans = np.asarray(record.get("ego2global_translation", [0.0, 0.0, 0.0]), dtype=np.float64).reshape(3)
        quat = record.get("ego2global_rotation", [1.0, 0.0, 0.0, 0.0])
        if not (isinstance(quat, (list, tuple)) and len(quat) == 4):
            continue
        rot = _quat_to_rotmat(float(quat[0]), float(quat[1]), float(quat[2]), float(quat[3]))
        pose_map[token] = (rot, trans)
    return pose_map
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _quat_to_rotmat(qw: float, qx: float, qy: float, qz: float) -> np.ndarray:
|
| 139 |
+
# nuScenes quaternions are in (w, x, y, z)
|
| 140 |
+
n = math.sqrt(qw * qw + qx * qx + qy * qy + qz * qz)
|
| 141 |
+
if n < 1e-12:
|
| 142 |
+
return np.eye(3, dtype=np.float64)
|
| 143 |
+
qw, qx, qy, qz = qw / n, qx / n, qy / n, qz / n
|
| 144 |
+
# Standard quaternion -> rotation matrix
|
| 145 |
+
xx, yy, zz = qx * qx, qy * qy, qz * qz
|
| 146 |
+
xy, xz, yz = qx * qy, qx * qz, qy * qz
|
| 147 |
+
wx, wy, wz = qw * qx, qw * qy, qw * qz
|
| 148 |
+
return np.array(
|
| 149 |
+
[
|
| 150 |
+
[1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)],
|
| 151 |
+
[2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)],
|
| 152 |
+
[2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)],
|
| 153 |
+
],
|
| 154 |
+
dtype=np.float64,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _load_fixed_cam_front_calib(nuscenes_root: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Return (R_c2e, t_c2e, K) for CAM_FRONT.

    Uses the first calibrated_sensor record whose sensor_token matches the
    first CAM_FRONT entry in sensor.json.
    """
    meta_root = nuscenes_root / "v1.0-trainval"
    sensors = _load_json(meta_root / "sensor.json")
    calibs = _load_json(meta_root / "calibrated_sensor.json")

    # sensor.json is a list of dicts; find the CAM_FRONT channel token.
    cam_token = next((rec.get("token") for rec in sensors if rec.get("channel") == "CAM_FRONT"), None)
    if cam_token is None:
        raise RuntimeError("CAM_FRONT not found in sensor.json")

    cam_calib = next((rec for rec in calibs if rec.get("sensor_token") == cam_token), None)
    if cam_calib is None:
        raise RuntimeError("No calibrated_sensor record found for CAM_FRONT")

    t_c2e = np.asarray(cam_calib.get("translation", [0.0, 0.0, 0.0]), dtype=np.float64).reshape(3)
    quat = cam_calib.get("rotation", [1.0, 0.0, 0.0, 0.0])
    if not (isinstance(quat, (list, tuple)) and len(quat) == 4):
        raise RuntimeError(f"Unexpected CAM_FRONT rotation quaternion: {quat}")
    R_c2e = _quat_to_rotmat(float(quat[0]), float(quat[1]), float(quat[2]), float(quat[3]))

    K = np.asarray(cam_calib.get("camera_intrinsic", np.eye(3).tolist()), dtype=np.float64)
    if K.shape != (3, 3):
        raise RuntimeError(f"Unexpected CAM_FRONT camera_intrinsic shape: {K.shape}")
    return R_c2e, t_c2e, K
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _load_fixed_cam_calibs(nuscenes_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Load fixed (R_c2e, t_c2e, K) for every camera channel found in the
    v1.0-trainval sensor/calibrated_sensor tables, keyed by channel name.

    Records without a camera_intrinsic, with a malformed quaternion, or with a
    non-3x3 intrinsic are skipped silently.
    """
    meta_root = nuscenes_root / "v1.0-trainval"
    sensors = _load_json(meta_root / "sensor.json")
    calibs = _load_json(meta_root / "calibrated_sensor.json")

    # channel -> sensor token (later duplicates overwrite earlier ones).
    channel_to_token: Dict[str, str] = {
        rec["channel"]: rec["token"]
        for rec in sensors
        if isinstance(rec.get("channel"), str) and isinstance(rec.get("token"), str)
    }
    # sensor token -> calibrated_sensor record.
    calib_by_token: Dict[str, Dict] = {
        rec["sensor_token"]: rec for rec in calibs if isinstance(rec.get("sensor_token"), str)
    }

    result: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
    for channel, token in channel_to_token.items():
        rec = calib_by_token.get(token)
        if not rec or "camera_intrinsic" not in rec:
            continue
        trans = np.asarray(rec.get("translation", [0.0, 0.0, 0.0]), dtype=np.float64).reshape(3)
        quat = rec.get("rotation", [1.0, 0.0, 0.0, 0.0])
        if not (isinstance(quat, (list, tuple)) and len(quat) == 4):
            continue
        rot = _quat_to_rotmat(float(quat[0]), float(quat[1]), float(quat[2]), float(quat[3]))
        intr = np.asarray(rec.get("camera_intrinsic", np.eye(3).tolist()), dtype=np.float64)
        if intr.shape != (3, 3):
            continue
        result[channel] = (rot, trans, intr)
    return result
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def _paper_xy_to_nuscenes_ego_xyz(x_right: float, y_fwd: float, z_up: float = 0.0) -> np.ndarray:
|
| 235 |
+
# nuScenes ego: x forward, y left, z up
|
| 236 |
+
x_fwd = float(y_fwd)
|
| 237 |
+
y_left = float(-x_right)
|
| 238 |
+
return np.array([x_fwd, y_left, float(z_up)], dtype=np.float64)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _project_ego_points_to_cam(
|
| 242 |
+
pts_ego_xyz: np.ndarray, # (N,3) in nuScenes ego
|
| 243 |
+
R_c2e: np.ndarray,
|
| 244 |
+
t_c2e: np.ndarray,
|
| 245 |
+
K: np.ndarray,
|
| 246 |
+
) -> np.ndarray:
|
| 247 |
+
"""
|
| 248 |
+
Project nuScenes ego-frame 3D points to pixel coords in CAM_FRONT.
|
| 249 |
+
Returns (M,2) pixels for points with z_cam > 0.
|
| 250 |
+
"""
|
| 251 |
+
# ego -> cam: p_cam = R_c2e^T (p_ego - t_c2e)
|
| 252 |
+
R_e2c = R_c2e.T
|
| 253 |
+
pts_cam = (R_e2c @ (pts_ego_xyz - t_c2e[None, :]).T).T # (N,3)
|
| 254 |
+
z = pts_cam[:, 2]
|
| 255 |
+
keep = z > 1e-3
|
| 256 |
+
pts_cam = pts_cam[keep]
|
| 257 |
+
if pts_cam.shape[0] == 0:
|
| 258 |
+
return np.zeros((0, 2), dtype=np.float64)
|
| 259 |
+
x = pts_cam[:, 0] / pts_cam[:, 2]
|
| 260 |
+
y = pts_cam[:, 1] / pts_cam[:, 2]
|
| 261 |
+
fx, fy = float(K[0, 0]), float(K[1, 1])
|
| 262 |
+
cx, cy = float(K[0, 2]), float(K[1, 2])
|
| 263 |
+
u = fx * x + cx
|
| 264 |
+
v = fy * y + cy
|
| 265 |
+
return np.stack([u, v], axis=1)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _project_one_ego_point_to_cam(
|
| 269 |
+
pt_ego_xyz: np.ndarray, # (3,)
|
| 270 |
+
R_c2e: np.ndarray,
|
| 271 |
+
t_c2e: np.ndarray,
|
| 272 |
+
K: np.ndarray,
|
| 273 |
+
) -> Optional[Tuple[float, float, float]]:
|
| 274 |
+
"""Return (u, v, z_cam) or None if behind camera."""
|
| 275 |
+
pt_ego_xyz = np.asarray(pt_ego_xyz, dtype=np.float64).reshape(3)
|
| 276 |
+
R_e2c = R_c2e.T
|
| 277 |
+
pt_cam = R_e2c @ (pt_ego_xyz - t_c2e)
|
| 278 |
+
zc = float(pt_cam[2])
|
| 279 |
+
if zc <= 1e-3:
|
| 280 |
+
return None
|
| 281 |
+
fx, fy = float(K[0, 0]), float(K[1, 1])
|
| 282 |
+
cx, cy = float(K[0, 2]), float(K[1, 2])
|
| 283 |
+
u = fx * float(pt_cam[0] / zc) + cx
|
| 284 |
+
v = fy * float(pt_cam[1] / zc) + cy
|
| 285 |
+
return float(u), float(v), zc
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _draw_ego_bev(ax, *, color: str = "k"):
    """Draw the ego vehicle on a BEV axes: outlined footprint rectangle,
    forward-pointing arrow, and an "EGO" label."""
    import matplotlib.patches as patches

    width, length = 1.85, 4.084
    footprint = patches.Rectangle(
        (-width / 2.0, -length / 2.0),
        width,
        length,
        linewidth=1.2,
        edgecolor=color,
        facecolor="none",
        zorder=10,
    )
    ax.add_patch(footprint)
    ax.arrow(0.0, 0.0, 0.0, length * 0.8, head_width=0.6, head_length=0.8, fc=color, ec=color, zorder=11)
    ax.text(0.2, -0.6, "EGO", color=color, fontsize=8, zorder=12)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _box_corners_xy(cx: float, cy: float, w: float, l: float, yaw: float) -> np.ndarray:
|
| 307 |
+
"""
|
| 308 |
+
Oriented rectangle corners in XY (paper coords).
|
| 309 |
+
yaw is treated as radians in the same XY frame (best-effort).
|
| 310 |
+
|
| 311 |
+
IMPORTANT: In our planning eval JSON, `yaw` behaves like a standard heading
|
| 312 |
+
angle measured from +X (right) axis (yaw-from-x). So:
|
| 313 |
+
- yaw = 0 => vehicle length points to +X (right)
|
| 314 |
+
- yaw = +pi/2 => vehicle length points to +Y (forward)
|
| 315 |
+
"""
|
| 316 |
+
c, s = math.cos(yaw), math.sin(yaw)
|
| 317 |
+
center = np.array([cx, cy], dtype=np.float64)
|
| 318 |
+
|
| 319 |
+
# Length axis (heading) and width axis (perpendicular), yaw-from-x
|
| 320 |
+
d_len = np.array([c, s], dtype=np.float64) * (l / 2.0)
|
| 321 |
+
d_wid = np.array([-s, c], dtype=np.float64) * (w / 2.0)
|
| 322 |
+
|
| 323 |
+
corners = np.stack(
|
| 324 |
+
[
|
| 325 |
+
center + d_len + d_wid,
|
| 326 |
+
center + d_len - d_wid,
|
| 327 |
+
center - d_len - d_wid,
|
| 328 |
+
center - d_len + d_wid,
|
| 329 |
+
],
|
| 330 |
+
axis=0,
|
| 331 |
+
)
|
| 332 |
+
return corners
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _title_case_command(cmd: str) -> str:
|
| 336 |
+
c = (cmd or "").strip().lower()
|
| 337 |
+
if c == "turn left":
|
| 338 |
+
return "Turn Left"
|
| 339 |
+
if c == "turn right":
|
| 340 |
+
return "Turn Right"
|
| 341 |
+
if c == "go straight":
|
| 342 |
+
return "Go Straight"
|
| 343 |
+
return cmd
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def _short_cat(cat: str) -> str:
|
| 347 |
+
c = (cat or "").strip()
|
| 348 |
+
if not c:
|
| 349 |
+
return "obj"
|
| 350 |
+
tail = c.split(".")[-1]
|
| 351 |
+
mapping = {
|
| 352 |
+
"car": "car",
|
| 353 |
+
"truck": "truck",
|
| 354 |
+
"trailer": "trailer",
|
| 355 |
+
"bus": "bus",
|
| 356 |
+
"construction": "cveh",
|
| 357 |
+
"construction_vehicle": "cveh",
|
| 358 |
+
"pedestrian": "ped",
|
| 359 |
+
"trafficcone": "cone",
|
| 360 |
+
"traffic_cone": "cone",
|
| 361 |
+
"barrier": "barrier",
|
| 362 |
+
"motorcycle": "moto",
|
| 363 |
+
"bicycle": "bike",
|
| 364 |
+
}
|
| 365 |
+
return mapping.get(tail, tail[:10])
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def _select_default_samples(data_items: List[Dict]) -> List[str]:
|
| 369 |
+
# prefer (turn right, go straight) like the paper figure
|
| 370 |
+
by_cmd: Dict[str, str] = {}
|
| 371 |
+
for it in data_items:
|
| 372 |
+
cmd = (it.get("ego_motion", {}) or {}).get("command")
|
| 373 |
+
if cmd and cmd not in by_cmd and it.get("id"):
|
| 374 |
+
by_cmd[str(cmd)] = str(it["id"])
|
| 375 |
+
out = []
|
| 376 |
+
if "turn right" in by_cmd:
|
| 377 |
+
out.append(by_cmd["turn right"])
|
| 378 |
+
if "go straight" in by_cmd:
|
| 379 |
+
out.append(by_cmd["go straight"])
|
| 380 |
+
if not out:
|
| 381 |
+
out = [str(it.get("id")) for it in data_items[:2] if it.get("id")]
|
| 382 |
+
return out[:2]
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def parse_args():
    """CLI arguments for the qualitative planning visualization."""
    parser = argparse.ArgumentParser()
    # Inputs and output.
    parser.add_argument("--eval_json", type=str, default="work_dirs/eval_nocap_ep2_plan50.json")
    parser.add_argument("--data_json", type=str, default="data/_eval_ep2_plan50.json")
    parser.add_argument("--data_root", type=str, default="/home/guoyuanbo/autodl-tmp/data/nuscenes")
    parser.add_argument("--out_png", type=str, default="work_dirs/atlas_planning_qualitative.png")
    parser.add_argument("--sample_ids", type=str, nargs="*", default=None, help="Provide 1-2 sample ids; default picks turn right & go straight.")
    # BEV extents and figure resolution.
    parser.add_argument("--bev_xlim", type=float, nargs=2, default=[-30.0, 30.0])
    parser.add_argument("--bev_ylim", type=float, nargs=2, default=[-10.0, 55.0])
    parser.add_argument("--dpi", type=int, default=200)
    # Paired on/off toggles (each --no_* flag clears the matching default-True dest).
    parser.add_argument("--draw_gt_boxes", action="store_true", default=True)
    parser.add_argument("--no_draw_gt_boxes", action="store_false", dest="draw_gt_boxes")
    parser.add_argument("--overlay_on_front_cam", action="store_true", default=True)
    parser.add_argument("--no_overlay_on_front_cam", action="store_false", dest="overlay_on_front_cam")
    parser.add_argument("--highlight_front_visible_boxes", action="store_true", default=True)
    parser.add_argument("--no_highlight_front_visible_boxes", action="store_false", dest="highlight_front_visible_boxes")
    parser.add_argument("--max_front_labels", type=int, default=8, help="Max GT box labels shown on CAM_FRONT and BEV.")
    parser.add_argument("--max_cam_labels", type=int, default=4, help="Max GT labels per camera view (all 6 views).")
    parser.add_argument("--bev_only_visible", action="store_true", default=True, help="If set, BEV draws only FRONT* visible GT boxes.")
    parser.add_argument("--bev_show_all", action="store_false", dest="bev_only_visible", help="Draw all GT boxes faintly in BEV (plus highlighted).")
    # Detector overlay (StreamPETR results in nuScenes format).
    parser.add_argument(
        "--det_results_json",
        type=str,
        default="external/StreamPETR/val/work_dirs/streampetr_atlas_aligned/Tue_Feb_10_23_13_58_2026/pts_bbox/results_nusc.json",
        help="nuScenes detection results JSON (StreamPETR).",
    )
    parser.add_argument("--draw_det_pred_boxes", action="store_true", default=True)
    parser.add_argument("--no_draw_det_pred_boxes", action="store_false", dest="draw_det_pred_boxes")
    parser.add_argument("--det_score_thresh", type=float, default=0.3)
    parser.add_argument("--det_max_boxes", type=int, default=30)
    parser.add_argument("--det_label_topk", type=int, default=6, help="Label top-k detector boxes in BEV.")
    return parser.parse_args()
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def main():
    """Render the paper-style qualitative planning figure.

    For up to two samples, draws one row each: a 2x3 six-camera mosaic on the
    left and a BEV panel on the right that shows GT boxes, optional StreamPETR
    detector boxes, and GT vs. predicted ego trajectories.

    Side effects: reads eval/data JSON files and nuScenes images/metadata,
    writes the composed PNG to ``--out_png``.
    """
    args = parse_args()
    repo = _REPO_ROOT
    eval_path = (repo / args.eval_json).resolve()
    data_path = (repo / args.data_json).resolve()
    nuscenes_root = Path(args.data_root).resolve()
    out_png = (repo / args.out_png).resolve()

    eval_obj = _load_json(eval_path)
    data_items = _load_json(data_path)

    # Build lookup maps
    # sample_id -> raw generated planning text from the eval dump.
    pred_text_by_id: Dict[str, str] = {}
    for rec in eval_obj.get("predictions", []):
        sid = str(rec.get("sample_id", ""))
        if not sid:
            continue
        pred_text_by_id[sid] = str(rec.get("generated_text", ""))

    item_by_id: Dict[str, Dict] = {str(it.get("id")): it for it in data_items if it.get("id")}

    sample_ids = list(args.sample_ids) if args.sample_ids else _select_default_samples(data_items)
    sample_ids = [sid for sid in sample_ids if sid in item_by_id]
    if not sample_ids:
        raise RuntimeError("No valid sample_ids found. Provide --sample_ids explicitly.")
    # The figure layout supports at most two rows.
    if len(sample_ids) > 2:
        sample_ids = sample_ids[:2]

    # Calibration for projecting onto CAM_FRONT
    cam_calibs: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
    if args.overlay_on_front_cam:
        cam_calibs = _load_fixed_cam_calibs(nuscenes_root)

    from src.eval.metrics import parse_planning_output

    # Detector predictions (StreamPETR nuScenes results) are in GLOBAL coords.
    # We need ego pose to transform them into ego/paper coords for BEV plotting.
    det_mm = None
    det_f = None
    ego_pose_map = None
    det_results_path = (repo / args.det_results_json).resolve()
    if args.draw_det_pred_boxes and det_results_path.exists():
        try:
            ego_pose_map = _load_ego_pose_map(nuscenes_root)
            # mmap the (potentially huge) results JSON rather than parsing it whole.
            det_f = det_results_path.open("rb")
            det_mm = mmap.mmap(det_f.fileno(), 0, access=mmap.ACCESS_READ)
        except Exception:
            # Detector overlay is best-effort/optional: fall back silently.
            det_mm = None
            if det_f is not None:
                det_f.close()
                det_f = None

    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

    nrows = len(sample_ids)
    fig = plt.figure(figsize=(12.5, 6.5 * nrows), dpi=args.dpi)
    fig.suptitle("Atlas Planning — 6-Camera + BEV (Paper-style)", fontsize=16, y=0.99)

    outer = GridSpec(nrows, 2, figure=fig, width_ratios=[3.2, 1.4], wspace=0.10, hspace=0.18)

    from PIL import Image

    for r, sid in enumerate(sample_ids):
        item = item_by_id[sid]
        pred_text = pred_text_by_id.get(sid, "")
        plan = parse_planning_output(pred_text) if pred_text else None
        pred_wps = (plan or {}).get("waypoints", []) if plan else []
        gt_wps = (item.get("ego_motion", {}) or {}).get("waypoints", []) or []
        cmd = (item.get("ego_motion", {}) or {}).get("command", "")

        # Load images
        rel_paths: List[str] = list(item.get("image_paths", []) or [])
        if len(rel_paths) != 6:
            raise RuntimeError(f"sample {sid}: expected 6 image_paths, got {len(rel_paths)}")
        imgs = []
        for rp in rel_paths:
            p = Path(rp)
            if not p.is_absolute():
                p = nuscenes_root / rp
            imgs.append(Image.open(p).convert("RGB"))
        # Reorder the dataset's camera order into the paper's 2x3 layout order.
        imgs = [imgs[i] for i in _IDX_REORDER]
        front_w, front_h = imgs[1].size  # CAM_FRONT in paper order

        # Left: 2x3 image mosaic
        left = GridSpecFromSubplotSpec(2, 3, subplot_spec=outer[r, 0], wspace=0.02, hspace=0.06)
        ax_imgs = []
        for i in range(6):
            ax = fig.add_subplot(left[i // 3, i % 3])
            ax.imshow(imgs[i])
            # Lock image axis limits so later overlays don't autoscale-shrink the image.
            w_i, h_i = imgs[i].size
            ax.set_xlim(0, w_i)
            ax.set_ylim(h_i, 0)
            ax.set_title(CAM_ORDER_PAPER[i], fontsize=9)
            ax.axis("off")
            ax_imgs.append(ax)

        # Collect GT boxes that are visible in each FRONT* camera (by projected centers).
        boxes = item.get("gt_boxes_3d", []) or []

        def _visible_for_channel(channel: str, img_w: int, img_h: int):
            # Return (distance, box_index, category, x, y, z, u, v) tuples for GT
            # boxes whose projected center lands inside this camera's image,
            # sorted near-to-far. Empty when highlighting/overlay is disabled.
            if not (args.highlight_front_visible_boxes and args.overlay_on_front_cam):
                return []
            if channel not in cam_calibs:
                return []
            R_c2e, t_c2e, K = cam_calibs[channel]
            out = []
            for bi, b in enumerate(boxes):
                if not isinstance(b, dict) or "world_coords" not in b:
                    continue
                wc = b.get("world_coords", [0.0, 0.0, 0.0])
                if not (isinstance(wc, (list, tuple)) and len(wc) >= 3):
                    continue
                x_r, y_f, z_u = float(wc[0]), float(wc[1]), float(wc[2])
                pt_ego = _paper_xy_to_nuscenes_ego_xyz(x_r, y_f, z_u)
                uv = _project_one_ego_point_to_cam(pt_ego, R_c2e, t_c2e, K)
                if uv is None:
                    continue
                u, v, _zc = uv
                if 0.0 <= u <= float(img_w) and 0.0 <= v <= float(img_h):
                    d = float(math.hypot(x_r, y_f))
                    out.append((d, bi, str(b.get("category", "")), x_r, y_f, z_u, u, v))
            out.sort(key=lambda t: t[0])
            return out

        vis_by_ch = {
            "CAM_FRONT_LEFT": _visible_for_channel("CAM_FRONT_LEFT", imgs[0].size[0], imgs[0].size[1]),
            "CAM_FRONT": _visible_for_channel("CAM_FRONT", imgs[1].size[0], imgs[1].size[1]),
            "CAM_FRONT_RIGHT": _visible_for_channel("CAM_FRONT_RIGHT", imgs[2].size[0], imgs[2].size[1]),
            "CAM_BACK_LEFT": _visible_for_channel("CAM_BACK_LEFT", imgs[3].size[0], imgs[3].size[1]),
            "CAM_BACK": _visible_for_channel("CAM_BACK", imgs[4].size[0], imgs[4].size[1]),
            "CAM_BACK_RIGHT": _visible_for_channel("CAM_BACK_RIGHT", imgs[5].size[0], imgs[5].size[1]),
        }
        # Union of visible boxes across all cameras, de-duplicated by box index.
        visible_union_all = []
        if any(vis_by_ch.get(ch) for ch in vis_by_ch.keys()):
            seen = set()
            for ch in ("CAM_FRONT_LEFT", "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_BACK_LEFT", "CAM_BACK", "CAM_BACK_RIGHT"):
                for tup in vis_by_ch.get(ch, []):
                    bi = tup[1]
                    if bi in seen:
                        continue
                    seen.add(bi)
                    visible_union_all.append(tup)
        visible_union_all.sort(key=lambda t: t[0])
        visible_union = visible_union_all[: max(int(args.max_front_labels), 0)]

        # Overlay trajectory on CAM_FRONT (middle of top row in paper order)
        if args.overlay_on_front_cam and pred_wps and "CAM_FRONT" in cam_calibs:
            R_c2e, t_c2e, K = cam_calibs["CAM_FRONT"]
            ax_front = ax_imgs[1]
            # Pred
            pts_pred = np.array([_paper_xy_to_nuscenes_ego_xyz(x, y, 0.0) for x, y in pred_wps], dtype=np.float64)
            uv_pred = _project_ego_points_to_cam(pts_pred, R_c2e, t_c2e, K)
            if uv_pred.shape[0] >= 2:
                ax_front.plot(uv_pred[:, 0], uv_pred[:, 1], "-o", color="#00C2FF", linewidth=2.0, markersize=4.0, alpha=0.95)
            # GT (dashed)
            if gt_wps:
                pts_gt = np.array([_paper_xy_to_nuscenes_ego_xyz(x, y, 0.0) for x, y in gt_wps], dtype=np.float64)
                uv_gt = _project_ego_points_to_cam(pts_gt, R_c2e, t_c2e, K)
                if uv_gt.shape[0] >= 2:
                    ax_front.plot(uv_gt[:, 0], uv_gt[:, 1], "--o", color="white", linewidth=2.0, markersize=3.5, alpha=0.85)

        # Highlight GT centers on the three FRONT* cameras (helps validate BEV↔camera consistency).
        if args.highlight_front_visible_boxes and args.overlay_on_front_cam:
            cam_axes = {
                "CAM_FRONT_LEFT": ax_imgs[0],
                "CAM_FRONT": ax_imgs[1],
                "CAM_FRONT_RIGHT": ax_imgs[2],
                "CAM_BACK_LEFT": ax_imgs[3],
                "CAM_BACK": ax_imgs[4],
                "CAM_BACK_RIGHT": ax_imgs[5],
            }
            for ch, ax in cam_axes.items():
                vis = (vis_by_ch.get(ch, []) or [])[: max(int(args.max_cam_labels), 0)]
                if not vis:
                    continue
                for _d, _bi, cat, _x, _y, _z, u, v in vis:
                    ax.scatter([u], [v], s=26, c="#FF4D4D", edgecolors="black", linewidths=0.6, zorder=20)
                    ax.text(
                        u + 6.0,
                        v - 6.0,
                        _short_cat(cat),
                        color="white",
                        fontsize=8,
                        fontweight="bold",
                        ha="left",
                        va="bottom",
                        zorder=21,
                        bbox=dict(boxstyle="round,pad=0.15", facecolor="black", edgecolor="none", alpha=0.6),
                    )

        # Right: BEV
        ax_bev = fig.add_subplot(outer[r, 1])
        ax_bev.set_title("BEV", fontsize=12)
        ax_bev.set_xlabel("X (m) — right", fontsize=9)
        ax_bev.set_ylabel("Y (m) — forward", fontsize=9)
        ax_bev.set_aspect("equal", adjustable="box")
        ax_bev.grid(True, linewidth=0.4, alpha=0.35)
        ax_bev.set_xlim(float(args.bev_xlim[0]), float(args.bev_xlim[1]))
        ax_bev.set_ylim(float(args.bev_ylim[0]), float(args.bev_ylim[1]))

        _draw_ego_bev(ax_bev, color="black")

        # Draw GT boxes (paper-style context)
        if args.draw_gt_boxes:
            boxes = item.get("gt_boxes_3d", []) or []

            # When debugging camera↔BEV consistency, draw only the boxes that are visible
            # in the shown camera views to avoid confusing off-screen objects.
            if args.bev_only_visible and visible_union_all:
                boxes_to_draw = [boxes[bi] for _d, bi, _cat, *_rest in visible_union_all if 0 <= bi < len(boxes)]
            else:
                boxes_to_draw = boxes

            for b in boxes_to_draw:
                if not isinstance(b, dict) or "world_coords" not in b:
                    continue
                wc = b.get("world_coords", [0.0, 0.0, 0.0])
                cx, cy = float(wc[0]), float(wc[1])
                w = float(b.get("w", 1.8))
                l = float(b.get("l", 4.0))
                yaw = float(b.get("yaw", 0.0))
                corners = _box_corners_xy(cx, cy, w, l, yaw)
                poly = np.vstack([corners, corners[0:1]])
                ax_bev.plot(poly[:, 0], poly[:, 1], color="#FF5A5A", linewidth=0.8, alpha=0.50)

            # Re-draw FRONT* visible boxes thicker + label them in BEV.
            if args.highlight_front_visible_boxes and visible_union:
                for _d, bi, cat, x_r, y_f, _z, _u, _v in visible_union:
                    if bi >= len(boxes):
                        continue
                    b = boxes[bi]
                    wc = b.get("world_coords", [0.0, 0.0, 0.0])
                    cx, cy = float(wc[0]), float(wc[1])
                    w = float(b.get("w", 1.8))
                    l = float(b.get("l", 4.0))
                    yaw = float(b.get("yaw", 0.0))
                    corners = _box_corners_xy(cx, cy, w, l, yaw)
                    poly = np.vstack([corners, corners[0:1]])
                    ax_bev.plot(poly[:, 0], poly[:, 1], color="#FF2E2E", linewidth=1.8, alpha=0.95)
                    ax_bev.text(
                        cx,
                        cy,
                        _short_cat(cat),
                        fontsize=7.5,
                        color="#FF2E2E",
                        ha="center",
                        va="center",
                        zorder=30,
                        bbox=dict(boxstyle="round,pad=0.12", facecolor="white", edgecolor="none", alpha=0.55),
                    )

        # Draw detector predicted boxes (StreamPETR) in BEV.
        det_handle = None
        if args.draw_det_pred_boxes and det_mm is not None and ego_pose_map is not None:
            R_e2g_t = None
            t_e2g_t = None
            if sid in ego_pose_map:
                R_e2g_t, t_e2g_t = ego_pose_map[sid]
            if R_e2g_t is not None and t_e2g_t is not None:
                dets = _extract_results_list_for_token(det_mm, sid)
                # Filter & transform
                det_plot = []
                xlo, xhi = float(args.bev_xlim[0]), float(args.bev_xlim[1])
                ylo, yhi = float(args.bev_ylim[0]), float(args.bev_ylim[1])
                for d in dets:
                    try:
                        score = float(d.get("detection_score", 0.0))
                    except Exception:
                        score = 0.0
                    if score < float(args.det_score_thresh):
                        continue
                    tr = d.get("translation", None)
                    sz = d.get("size", None)
                    rot = d.get("rotation", None)
                    if not (isinstance(tr, (list, tuple)) and len(tr) >= 3):
                        continue
                    if not (isinstance(sz, (list, tuple)) and len(sz) >= 3):
                        continue
                    if not (isinstance(rot, (list, tuple)) and len(rot) == 4):
                        continue

                    p_g = np.asarray([float(tr[0]), float(tr[1]), float(tr[2])], dtype=np.float64)
                    # global -> ego
                    p_e = (R_e2g_t.T @ (p_g - t_e2g_t)).reshape(3)
                    # ego -> paper
                    x_p = float(-p_e[1])
                    y_p = float(p_e[0])
                    z_p = float(p_e[2])
                    if not (xlo <= x_p <= xhi and ylo <= y_p <= yhi):
                        continue

                    # orientation: global -> ego -> paper yaw
                    R_box_g = _quat_to_rotmat(float(rot[0]), float(rot[1]), float(rot[2]), float(rot[3]))
                    R_box_e = R_e2g_t.T @ R_box_g
                    yaw_e = float(math.atan2(R_box_e[1, 0], R_box_e[0, 0]))  # yaw-from-x in ego
                    yaw_p = yaw_e + math.pi / 2.0  # convert to paper yaw-from-x

                    w = float(sz[0])
                    l = float(sz[1])
                    cat = str(d.get("detection_name", "obj"))
                    det_plot.append((score, cat, x_p, y_p, z_p, w, l, yaw_p))

                # Keep only the highest-scoring boxes.
                det_plot.sort(key=lambda t: -t[0])
                det_plot = det_plot[: max(int(args.det_max_boxes), 0)]

                det_color = "#00A65A"  # green
                for score, cat, x_p, y_p, _z, w, l, yaw_p in det_plot:
                    corners = _box_corners_xy(x_p, y_p, w, l, yaw_p)
                    poly = np.vstack([corners, corners[0:1]])
                    ax_bev.plot(poly[:, 0], poly[:, 1], color=det_color, linewidth=1.1, alpha=0.65, linestyle="--")

                # Label top-k predictions to show "front vehicles"
                for score, cat, x_p, y_p, _z, w, l, yaw_p in det_plot[: max(int(args.det_label_topk), 0)]:
                    ax_bev.text(
                        x_p,
                        y_p,
                        _short_cat(cat),
                        fontsize=7.2,
                        color=det_color,
                        ha="center",
                        va="center",
                        zorder=31,
                        bbox=dict(boxstyle="round,pad=0.12", facecolor="white", edgecolor="none", alpha=0.55),
                    )

                # Legend handle for detector boxes
                from matplotlib.lines import Line2D
                det_handle = Line2D([0], [0], color=det_color, lw=1.5, linestyle="--", label="Det Pred boxes")

        # Plot trajectories
        if gt_wps:
            gt_arr = np.asarray(gt_wps, dtype=np.float64)
            ax_bev.plot(gt_arr[:, 0], gt_arr[:, 1], "--o", color="black", linewidth=1.8, markersize=4.0, alpha=0.85, label="GT traj")
        if pred_wps:
            pr_arr = np.asarray(pred_wps, dtype=np.float64)
            ax_bev.plot(pr_arr[:, 0], pr_arr[:, 1], "-o", color="#00C2FF", linewidth=2.4, markersize=4.5, alpha=0.95, label="Pred traj")

        # Legend: add GT boxes handle to avoid confusion (boxes are NOT predicted here).
        handles, _labels = ax_bev.get_legend_handles_labels()
        if args.draw_gt_boxes:
            import matplotlib.patches as mpatches
            handles.append(mpatches.Patch(facecolor="none", edgecolor="#FF5A5A", linewidth=1.2, label="GT boxes"))
        if det_handle is not None:
            handles.append(det_handle)
        ax_bev.legend(handles=handles, loc="upper left", fontsize=9, frameon=True)

        # Command label (like paper: bottom-left)
        ax_bev.text(
            0.02,
            0.02,
            _title_case_command(str(cmd)),
            transform=ax_bev.transAxes,
            fontsize=12,
            fontweight="bold",
            ha="left",
            va="bottom",
            color="black",
            bbox=dict(boxstyle="round,pad=0.25", facecolor="white", edgecolor="none", alpha=0.65),
        )

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)
    print(f"[saved] {out_png}")

    # Release the mmapped detector results (best-effort cleanup).
    if det_mm is not None:
        try:
            det_mm.close()
        except Exception:
            pass
    if det_f is not None:
        try:
            det_f.close()
        except Exception:
            pass
| 796 |
+
|
| 797 |
+
|
| 798 |
+
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
| 800 |
+
|
scripts/vis_traffic_violation.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Figure 11 — Violation of traffic regulations.
|
| 4 |
+
|
| 5 |
+
This script focuses on the construction-blocking case shown in the paper:
|
| 6 |
+
the road ahead is fenced by barriers / traffic cones, but Atlas still outputs
|
| 7 |
+
a "go straight" trajectory that cuts through the blocked area.
|
| 8 |
+
|
| 9 |
+
Style:
|
| 10 |
+
- Left: 6 camera views, only CAM_FRONT overlays the trajectory
|
| 11 |
+
- Right: clean BEV with black road boundaries, red construction boxes,
|
| 12 |
+
green/blue trajectory, and "Go Straight" label
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import math
|
| 20 |
+
import pickle
|
| 21 |
+
import sys
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Dict, List, Optional, Tuple
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
_REPO = Path(__file__).resolve().parent.parent
|
| 28 |
+
if str(_REPO) not in sys.path:
|
| 29 |
+
sys.path.insert(0, str(_REPO))
|
| 30 |
+
|
| 31 |
+
CAM_ORDER_PAPER = [
|
| 32 |
+
"CAM_FRONT_LEFT",
|
| 33 |
+
"CAM_FRONT",
|
| 34 |
+
"CAM_FRONT_RIGHT",
|
| 35 |
+
"CAM_BACK_LEFT",
|
| 36 |
+
"CAM_BACK",
|
| 37 |
+
"CAM_BACK_RIGHT",
|
| 38 |
+
]
|
| 39 |
+
_IDX_REORDER = [2, 0, 1, 4, 3, 5]
|
| 40 |
+
|
| 41 |
+
LOCATION_TO_MAP = {
|
| 42 |
+
"singapore-onenorth": "53992ee3023e5494b90c316c183be829.png",
|
| 43 |
+
"boston-seaport": "36092f0b03a857c6a3403e25b4b7aab3.png",
|
| 44 |
+
"singapore-queenstown": "93406b464a165eaba6d9de76ca09f5da.png",
|
| 45 |
+
"singapore-hollandvillage": "37819e65e09e5547b8a3ceaefba56bb2.png",
|
| 46 |
+
}
|
| 47 |
+
MAP_RES = 0.1 # meters per pixel
|
| 48 |
+
DEFAULT_FIG11_SAMPLE = "856ccc626a4a4c0aaac1e62335050ac0"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _load_json(path: Path):
|
| 52 |
+
with path.open("r", encoding="utf-8") as f:
|
| 53 |
+
return json.load(f)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _load_pickle(path: Path):
|
| 57 |
+
with path.open("rb") as f:
|
| 58 |
+
return pickle.load(f)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _quat_to_rotmat(qw, qx, qy, qz):
|
| 62 |
+
n = math.sqrt(qw * qw + qx * qx + qy * qy + qz * qz)
|
| 63 |
+
if n < 1e-12:
|
| 64 |
+
return np.eye(3, dtype=np.float64)
|
| 65 |
+
qw, qx, qy, qz = qw / n, qx / n, qy / n, qz / n
|
| 66 |
+
xx, yy, zz = qx * qx, qy * qy, qz * qz
|
| 67 |
+
xy, xz, yz = qx * qy, qx * qz, qy * qz
|
| 68 |
+
wx, wy, wz = qw * qx, qw * qy, qw * qz
|
| 69 |
+
return np.array(
|
| 70 |
+
[
|
| 71 |
+
[1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
|
| 72 |
+
[2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
|
| 73 |
+
[2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)],
|
| 74 |
+
],
|
| 75 |
+
dtype=np.float64,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _quat_to_yaw(q):
|
| 80 |
+
w, x, y, z = [float(v) for v in q]
|
| 81 |
+
return math.atan2(2 * (w * z + x * y), 1 - 2 * (y * y + z * z))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _paper_xy_to_ego(x_right: float, y_fwd: float, z_up: float = 0.0) -> np.ndarray:
|
| 85 |
+
return np.array([float(y_fwd), float(-x_right), float(z_up)], dtype=np.float64)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _project_batch(pts_ego: np.ndarray, R_c2e: np.ndarray, t_c2e: np.ndarray, K: np.ndarray) -> np.ndarray:
|
| 89 |
+
R_e2c = R_c2e.T
|
| 90 |
+
pts_cam = (R_e2c @ (pts_ego - t_c2e[None, :]).T).T
|
| 91 |
+
z = pts_cam[:, 2]
|
| 92 |
+
keep = z > 1e-3
|
| 93 |
+
pts_cam = pts_cam[keep]
|
| 94 |
+
if pts_cam.shape[0] == 0:
|
| 95 |
+
return np.zeros((0, 2), dtype=np.float64)
|
| 96 |
+
x = pts_cam[:, 0] / pts_cam[:, 2]
|
| 97 |
+
y = pts_cam[:, 1] / pts_cam[:, 2]
|
| 98 |
+
fx, fy = float(K[0, 0]), float(K[1, 1])
|
| 99 |
+
cx, cy = float(K[0, 2]), float(K[1, 2])
|
| 100 |
+
return np.stack([fx * x + cx, fy * y + cy], axis=1)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _load_cam_calibs(nusc_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Load per-channel camera calibration from nuScenes v1.0-trainval metadata.

    Returns a mapping {channel (e.g. "CAM_FRONT"): (R_cam2ego (3x3),
    t_cam2ego (3,), K (3x3))}. Channels without a valid 4-element rotation
    quaternion or a 3x3 intrinsic matrix are skipped.

    NOTE(review): one calibrated_sensor record is kept per sensor token (the
    last one wins), so this assumes calibration is effectively constant across
    logs — confirm before using across multiple vehicles/logs.
    """
    meta = nusc_root / "v1.0-trainval"
    sensor = _load_json(meta / "sensor.json")
    calib = _load_json(meta / "calibrated_sensor.json")

    # channel name -> sensor token
    sensor_token_by_channel: Dict[str, str] = {}
    for rec in sensor:
        ch = rec.get("channel")
        tok = rec.get("token")
        if isinstance(ch, str) and isinstance(tok, str):
            sensor_token_by_channel[ch] = tok

    # sensor token -> calibrated_sensor record
    calib_by_sensor_token: Dict[str, Dict] = {}
    for rec in calib:
        st = rec.get("sensor_token")
        if isinstance(st, str):
            calib_by_sensor_token[st] = rec

    out: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
    for ch, st in sensor_token_by_channel.items():
        rec = calib_by_sensor_token.get(st)
        if not rec or "camera_intrinsic" not in rec:
            # Non-camera sensors (lidar/radar) have no intrinsics — skip them.
            continue
        t = np.asarray(rec.get("translation", [0, 0, 0]), dtype=np.float64).reshape(3)
        q = rec.get("rotation", [1, 0, 0, 0])
        K = np.asarray(rec.get("camera_intrinsic", np.eye(3).tolist()), dtype=np.float64)
        if not (isinstance(q, (list, tuple)) and len(q) == 4):
            continue
        if K.shape != (3, 3):
            continue
        R = _quat_to_rotmat(*[float(x) for x in q])
        out[ch] = (R, t, K)
    return out
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _load_ego_poses(nusc_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, float]]:
    """Load ego->global poses from nuscenes_infos_val.pkl.

    Returns {sample_token: (R_ego2global (3x3), t_ego2global (3,), yaw)}.
    Entries without a token or without a 4-element rotation quaternion are
    skipped.
    """
    pkl = nusc_root / "nuscenes_infos_val.pkl"
    obj = _load_pickle(pkl)
    # The pickle may be either {"infos": [...]} or a bare list of info dicts.
    infos = obj["infos"] if isinstance(obj, dict) and "infos" in obj else obj
    out = {}
    for it in infos:
        tok = str(it.get("token", ""))
        if not tok:
            continue
        t = np.asarray(it.get("ego2global_translation", [0, 0, 0]), dtype=np.float64).reshape(3)
        q = it.get("ego2global_rotation", [1, 0, 0, 0])
        if isinstance(q, (list, tuple)) and len(q) == 4:
            out[tok] = (_quat_to_rotmat(*[float(x) for x in q]), t, _quat_to_yaw(q))
    return out
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _get_sample_location(nusc_root: Path, sample_token: str) -> str:
    """Return the log location (e.g. "boston-seaport") for a sample token.

    Resolves sample -> scene -> log via the v1.0-trainval JSON tables and
    returns "" when the token is not found.
    """
    samples = _load_json(nusc_root / "v1.0-trainval" / "sample.json")
    scenes = _load_json(nusc_root / "v1.0-trainval" / "scene.json")
    logs = _load_json(nusc_root / "v1.0-trainval" / "log.json")
    scene_map = {s["token"]: s for s in scenes}
    log_map = {l["token"]: l for l in logs}
    # Linear scan over sample.json; acceptable for a one-off figure script.
    for s in samples:
        if s["token"] == sample_token:
            sc = scene_map.get(s.get("scene_token", ""), {})
            lg = log_map.get(sc.get("log_token", ""), {})
            return str(lg.get("location", ""))
    return ""
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _smooth_map(bev_map: np.ndarray) -> np.ndarray:
|
| 169 |
+
acc = bev_map.copy()
|
| 170 |
+
acc += np.roll(bev_map, 1, axis=0)
|
| 171 |
+
acc += np.roll(bev_map, -1, axis=0)
|
| 172 |
+
acc += np.roll(bev_map, 1, axis=1)
|
| 173 |
+
acc += np.roll(bev_map, -1, axis=1)
|
| 174 |
+
acc += np.roll(np.roll(bev_map, 1, axis=0), 1, axis=1)
|
| 175 |
+
acc += np.roll(np.roll(bev_map, 1, axis=0), -1, axis=1)
|
| 176 |
+
acc += np.roll(np.roll(bev_map, -1, axis=0), 1, axis=1)
|
| 177 |
+
acc += np.roll(np.roll(bev_map, -1, axis=0), -1, axis=1)
|
| 178 |
+
return acc / 9.0
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _build_bev_map(
    nusc_root: Path,
    location: str,
    ego_xy: np.ndarray,
    ego_yaw: float,
    bev_xlim: Tuple[float, float],
    bev_ylim: Tuple[float, float],
    bev_res: float = 0.1,
) -> Optional[np.ndarray]:
    """Sample an ego-centric BEV crop from the rasterized nuScenes map mask.

    Args:
        nusc_root: nuScenes dataset root (expects a ``maps/`` subfolder).
        location: log location key used to pick the map PNG via LOCATION_TO_MAP.
        ego_xy: ego position in global map coordinates (meters).
        ego_yaw: ego heading in radians (global frame).
        bev_xlim/bev_ylim: BEV extent in paper coords (x right, y forward), meters.
        bev_res: BEV output resolution in meters per pixel.

    Returns:
        Float32 (ny, nx) array in [0, 1], smoothed by _smooth_map, or None if
        the location has no known map file or the file is missing.
    """
    import PIL.Image
    # Map rasters are enormous; disable PIL's decompression-bomb size guard.
    PIL.Image.MAX_IMAGE_PIXELS = None
    from PIL import Image

    map_fn = LOCATION_TO_MAP.get(location)
    if not map_fn:
        return None
    map_path = nusc_root / "maps" / map_fn
    if not map_path.exists():
        return None

    map_img = Image.open(map_path)
    mw, mh = map_img.size
    # Map image y axis points down; map_max_y converts to an up-pointing axis.
    map_max_y = mh * MAP_RES
    map_arr = np.asarray(map_img, dtype=np.float32) / 255.0

    ex, ey = float(ego_xy[0]), float(ego_xy[1])
    c_yaw, s_yaw = math.cos(ego_yaw), math.sin(ego_yaw)

    x0, x1 = bev_xlim
    y0, y1 = bev_ylim
    nx = int((x1 - x0) / bev_res)
    ny = int((y1 - y0) / bev_res)
    bev = np.zeros((ny, nx), dtype=np.float32)

    # BEV pixel grid in paper coords; row 0 corresponds to the far edge (y1).
    px_arr = np.linspace(x0, x1, nx)
    py_arr = np.linspace(y1, y0, ny)
    PX, PY = np.meshgrid(px_arr, py_arr)

    # Rotate/translate the grid into global map coordinates using the ego pose.
    GX = ex + PY * c_yaw + PX * s_yaw
    GY = ey + PY * s_yaw - PX * c_yaw

    # Global meters -> integer map pixel indices; sample only in-bounds cells.
    MX = (GX / MAP_RES).astype(np.int32)
    MY = ((map_max_y - GY) / MAP_RES).astype(np.int32)
    valid = (MX >= 0) & (MX < mw) & (MY >= 0) & (MY < mh)
    bev[valid] = map_arr[MY[valid], MX[valid]]
    return _smooth_map(bev)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _box_corners(cx: float, cy: float, w: float, l: float, yaw: float) -> np.ndarray:
|
| 230 |
+
c, s = math.cos(yaw), math.sin(yaw)
|
| 231 |
+
center = np.array([cx, cy], dtype=np.float64)
|
| 232 |
+
d_len = np.array([c, s], dtype=np.float64) * (l / 2.0)
|
| 233 |
+
d_wid = np.array([-s, c], dtype=np.float64) * (w / 2.0)
|
| 234 |
+
return np.stack(
|
| 235 |
+
[
|
| 236 |
+
center + d_len + d_wid,
|
| 237 |
+
center + d_len - d_wid,
|
| 238 |
+
center - d_len - d_wid,
|
| 239 |
+
center - d_len + d_wid,
|
| 240 |
+
],
|
| 241 |
+
axis=0,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def _is_barrier_like(cat: str) -> bool:
|
| 246 |
+
return ("barrier" in cat) or ("traffic_cone" in cat) or ("trafficcone" in cat)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _is_context_like(cat: str) -> bool:
|
| 250 |
+
return ("construction" in cat) or ("pedestrian" in cat) or ("vehicle.car" in cat)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _title_cmd(cmd: str) -> str:
|
| 254 |
+
c = (cmd or "").strip().lower()
|
| 255 |
+
return {
|
| 256 |
+
"turn left": "Turn Left",
|
| 257 |
+
"turn right": "Turn Right",
|
| 258 |
+
"go straight": "Go Straight",
|
| 259 |
+
}.get(c, cmd)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _blocked_score(item: Dict) -> float:
    """Heuristic score for "road ahead blocked by construction" samples.

    Only samples with a "go straight" route command are eligible (others get
    -1e9). Barrier/cone GT boxes in a wide frontal zone and directly on the
    ego's center band raise the score, as do GT waypoints passing through the
    blocked center band.
    """
    command = str(item.get("route_command", "")).strip().lower()
    if command != "go straight":
        return -1e9

    gt_boxes = item.get("gt_boxes_3d", []) or []
    waypoints = (item.get("ego_motion", {}) or {}).get("waypoints", []) or []

    n_block_front = 0   # barrier-like boxes anywhere in the wide frontal zone
    n_block_center = 0  # barrier-like boxes directly on the ego's path
    n_barrier = 0
    n_cone = 0
    for box in gt_boxes:
        category = str(box.get("category", ""))
        coords = box.get("world_coords", [0, 0, 0])
        if not (isinstance(coords, (list, tuple)) and len(coords) >= 2):
            continue
        bx, by = float(coords[0]), float(coords[1])
        if not _is_barrier_like(category):
            continue
        if "barrier" in category:
            n_barrier += 1
        else:
            n_cone += 1
        if 0 < by < 25 and abs(bx) < 10:
            n_block_front += 1
        if 2 < by < 20 and abs(bx) < 4:
            n_block_center += 1

    through_center = sum(
        1 for wx, wy in waypoints if 2 < float(wy) < 20 and abs(float(wx)) < 4
    )
    return n_block_center * 8 + n_block_front * 3 + through_center * 4 + n_barrier * 1.5 + n_cone
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def parse_args():
    """Parse CLI arguments for the traffic-violation qualitative figure."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--eval_json", default="work_dirs/eval_final_plan100.json")
    add("--data_json", default="data/atlas_planning_val_uniad_command.json")
    add("--data_root", default="/home/guoyuanbo/autodl-tmp/data/nuscenes")
    add("--sample_id", default=None)
    add("--out_png", default="work_dirs/atlas_traffic_violation.png")
    add("--dpi", type=int, default=200)
    add("--bev_xlim", type=float, nargs=2, default=[-7.5, 10.5])
    add("--bev_ylim", type=float, nargs=2, default=[-1.5, 30.5])
    return parser.parse_args()
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def main():
    """Render the qualitative "blocked-road" planning figure.

    Loads model predictions (--eval_json) and GT planning samples
    (--data_json), picks one sample (explicit --sample_id, a curated default,
    or the highest `_blocked_score` candidate), then draws the six camera
    views with the GT/predicted trajectories projected onto CAM_FRONT plus a
    BEV panel, and saves the composite as a PNG.
    """
    args = parse_args()
    repo = _REPO  # repository root resolved elsewhere in this module
    nusc_root = Path(args.data_root).resolve()
    out_png = (repo / args.out_png).resolve()

    eval_obj = _load_json((repo / args.eval_json).resolve())
    data_items = _load_json((repo / args.data_json).resolve())
    # sample_id -> raw generated text from the evaluation dump.
    pred_by_id = {
        str(r.get("sample_id", "")): str(r.get("generated_text", ""))
        for r in eval_obj.get("predictions", [])
        if r.get("sample_id")
    }
    item_by_id = {str(it["id"]): it for it in data_items if it.get("id")}

    from src.eval.metrics import parse_planning_output

    # Sample selection priority: CLI override > curated default > best score.
    sid = args.sample_id
    if not sid:
        if DEFAULT_FIG11_SAMPLE in item_by_id and parse_planning_output(pred_by_id.get(DEFAULT_FIG11_SAMPLE, "")):
            sid = DEFAULT_FIG11_SAMPLE
        else:
            candidates = []
            for item in data_items:
                item_sid = str(item.get("id", ""))
                pred_text = pred_by_id.get(item_sid, "")
                if not pred_text:
                    continue
                plan = parse_planning_output(pred_text)
                if not plan or not plan.get("waypoints"):
                    continue
                candidates.append((_blocked_score(item), item_sid))
            # Tuples sort by score first; take the highest-scoring sample id.
            candidates.sort(reverse=True)
            if not candidates:
                raise RuntimeError("No valid construction-blocked sample found.")
            sid = candidates[0][1]
    if sid not in item_by_id:
        raise RuntimeError(f"sample_id {sid} not found in data_json")

    item = item_by_id[sid]
    pred_text = pred_by_id.get(sid, "")
    plan = parse_planning_output(pred_text) if pred_text else None
    if not plan or not plan.get("waypoints"):
        raise RuntimeError(f"sample_id {sid} has no parseable planning output in {args.eval_json}")

    pred_wps = np.asarray(plan["waypoints"], dtype=np.float64)
    gt_wps = np.asarray((item.get("ego_motion", {}) or {}).get("waypoints", []) or [], dtype=np.float64)
    cmd = str(item.get("route_command", ""))
    boxes = item.get("gt_boxes_3d", []) or []
    location = _get_sample_location(nusc_root, sid)
    ego_poses = _load_ego_poses(nusc_root)
    ego_info = ego_poses.get(sid)
    if ego_info is None:
        raise RuntimeError(f"Missing ego pose for sample {sid}")
    # assumes ego_info is (token, translation, yaw) — TODO confirm against _load_ego_poses
    ego_xy = ego_info[1][:2]
    ego_yaw = ego_info[2]

    print(f"[sample] {sid}")
    print(f" location: {location}")
    print(f" pred: {pred_text}")

    bev_map = _build_bev_map(
        nusc_root,
        location,
        ego_xy,
        ego_yaw,
        tuple(args.bev_xlim),
        tuple(args.bev_ylim),
        bev_res=0.1,
    )

    from PIL import Image

    rel_paths = list(item.get("image_paths", []) or [])
    if len(rel_paths) != 6:
        raise RuntimeError(f"Expected 6 images, got {len(rel_paths)}")
    imgs = []
    for rp in rel_paths:
        p = Path(rp)
        if not p.is_absolute():
            p = nusc_root / rp
        imgs.append(Image.open(p).convert("RGB"))
    # Reorder cameras into the paper's 2x3 panel layout.
    imgs = [imgs[i] for i in _IDX_REORDER]

    cam_calibs = _load_cam_calibs(nusc_root)

    import matplotlib
    matplotlib.use("Agg")  # headless rendering
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

    fig = plt.figure(figsize=(14.8, 4.1), dpi=args.dpi)
    # Left cell: 2x3 camera grid; right cell: BEV panel.
    gs = GridSpec(1, 2, figure=fig, width_ratios=[3.0, 1.1], wspace=0.03)
    gs_cam = GridSpecFromSubplotSpec(2, 3, subplot_spec=gs[0, 0], wspace=0.01, hspace=0.01)

    # Camera panels with a small name tag in each corner.
    ax_imgs = []
    for i in range(6):
        ax = fig.add_subplot(gs_cam[i // 3, i % 3])
        ax.imshow(imgs[i])
        w_i, h_i = imgs[i].size
        ax.set_xlim(0, w_i)
        ax.set_ylim(h_i, 0)  # image coordinates: y grows downward
        ax.axis("off")
        ax.text(
            6,
            14,
            CAM_ORDER_PAPER[i],
            color="white",
            fontsize=7,
            ha="left",
            va="top",
            bbox=dict(boxstyle="square,pad=0.12", facecolor="black", edgecolor="none", alpha=0.55),
        )
        ax_imgs.append(ax)

    # Project GT (green) and predicted (blue) trajectories onto CAM_FRONT.
    if "CAM_FRONT" in cam_calibs:
        R_c2e, t_c2e, K = cam_calibs["CAM_FRONT"]
        # presumably CAM_FRONT lands at index 1 after _IDX_REORDER — verify
        ax_front = ax_imgs[1]
        if gt_wps.shape[0] >= 2:
            uv_gt = _project_batch(
                np.array([_paper_xy_to_ego(x, y) for x, y in gt_wps], dtype=np.float64),
                R_c2e,
                t_c2e,
                K,
            )
            if uv_gt.shape[0] >= 2:
                ax_front.plot(uv_gt[:, 0], uv_gt[:, 1], color="#34c759", linewidth=4.0, alpha=0.95, zorder=18)
        uv_pred = _project_batch(
            np.array([_paper_xy_to_ego(x, y) for x, y in pred_wps], dtype=np.float64),
            R_c2e,
            t_c2e,
            K,
        )
        if uv_pred.shape[0] >= 2:
            ax_front.plot(uv_pred[:, 0], uv_pred[:, 1], color="#1f5cff", linewidth=2.2, alpha=0.98, zorder=20)
            ax_front.scatter(uv_pred[:, 0], uv_pred[:, 1], color="#1f5cff", s=8, zorder=21)

    # BEV panel setup: white background, fixed window, no ticks.
    ax_bev = fig.add_subplot(gs[0, 1])
    ax_bev.set_facecolor("white")
    ax_bev.set_xlim(*args.bev_xlim)
    ax_bev.set_ylim(*args.bev_ylim)
    ax_bev.set_aspect("equal", adjustable="box")
    ax_bev.set_xticks([])
    ax_bev.set_yticks([])
    for spine in ax_bev.spines.values():
        spine.set_linewidth(1.3)
        spine.set_color("black")

    # Map outline: draw the 0.5 iso-contour of the rasterized BEV map.
    if bev_map is not None:
        ny, nx = bev_map.shape
        xs = np.linspace(args.bev_xlim[0], args.bev_xlim[1], nx)
        ys = np.linspace(args.bev_ylim[0], args.bev_ylim[1], ny)
        ax_bev.contour(xs, ys, bev_map, levels=[0.5], colors="black", linewidths=0.8, zorder=1)

    # Collect barrier-like boxes near the plotted BEV window (2 m margin).
    construction_boxes = []
    for b in boxes:
        cat = str(b.get("category", ""))
        wc = b.get("world_coords", [0, 0, 0])
        if not (isinstance(wc, (list, tuple)) and len(wc) >= 2):
            continue
        cx, cy = float(wc[0]), float(wc[1])
        if not (args.bev_xlim[0] - 2 <= cx <= args.bev_xlim[1] + 2 and args.bev_ylim[0] - 2 <= cy <= args.bev_ylim[1] + 2):
            continue
        box_item = (cx, cy, float(b.get("w", 1.8)), float(b.get("l", 4.0)), float(b.get("yaw", 0.0)))
        if _is_barrier_like(cat):
            construction_boxes.append(box_item)

    # Draw each obstacle as a closed polygon outline.
    for cx, cy, w, l, yaw in construction_boxes:
        poly = _box_corners(cx, cy, w, l, yaw)
        poly = np.vstack([poly, poly[0:1]])  # repeat first corner to close the loop
        ax_bev.plot(poly[:, 0], poly[:, 1], color="#cf7a6b", linewidth=0.9, alpha=0.95, zorder=3)

    # Trajectories: GT in green, prediction in blue on top.
    if gt_wps.shape[0] >= 2:
        ax_bev.plot(gt_wps[:, 0], gt_wps[:, 1], color="#34c759", linewidth=3.4, alpha=0.95, zorder=5)
    ax_bev.plot(pred_wps[:, 0], pred_wps[:, 1], color="#1f5cff", linewidth=1.8, alpha=0.98, zorder=6)

    # Ego marker: small square at the origin.
    ego_rect = patches.Rectangle(
        (-0.45, -0.45),
        0.9,
        0.9,
        linewidth=1.4,
        edgecolor="#34c759",
        facecolor="none",
        zorder=7,
    )
    ax_bev.add_patch(ego_rect)

    ax_bev.text(0.03, 0.98, "BEV", transform=ax_bev.transAxes, ha="left", va="top", fontsize=9, fontweight="bold")
    ax_bev.text(
        0.03,
        0.03,
        _title_cmd(cmd),
        transform=ax_bev.transAxes,
        ha="left",
        va="bottom",
        fontsize=9,
        fontweight="bold",
    )

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, bbox_inches="tight", facecolor="white", pad_inches=0.02)
    plt.close(fig)

    print(f"[saved] {out_png}")
    print(f" construction boxes: {len(construction_boxes)}")


if __name__ == "__main__":
    main()
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (177 Bytes). View file
|
|
|
src/__pycache__/prompting.cpython-310.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
src/__pycache__/prompting.cpython-38.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
src/audit/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (240 Bytes). View file
|
|
|
src/audit/__pycache__/audit_utils.cpython-310.pyc
ADDED
|
Binary file (998 Bytes). View file
|
|
|
src/dataset/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (316 Bytes). View file
|
|
|
src/dataset/__pycache__/atlas_dataset.cpython-310.pyc
ADDED
|
Binary file (40 kB). View file
|
|
|
src/dataset/__pycache__/atlas_dataset.cpython-38.pyc
ADDED
|
Binary file (40.6 kB). View file
|
|
|
src/dataset/__pycache__/scene_sampler.cpython-310.pyc
ADDED
|
Binary file (3.88 kB). View file
|
|
|
src/dataset/__pycache__/scene_sampler.cpython-38.pyc
ADDED
|
Binary file (3.89 kB). View file
|
|
|
src/dataset/atlas_dataset.py
ADDED
|
@@ -0,0 +1,1416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from typing import Dict, List, Optional, Tuple
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
try:
|
| 12 |
+
import torchvision.transforms.v2 as _v2
|
| 13 |
+
_HAS_V2 = hasattr(_v2, "ToImage") and hasattr(_v2, "ToDtype")
|
| 14 |
+
except (ImportError, AttributeError):
|
| 15 |
+
_HAS_V2 = False
|
| 16 |
+
if _HAS_V2:
|
| 17 |
+
transforms = _v2
|
| 18 |
+
|
| 19 |
+
from src.prompting import (
|
| 20 |
+
PLANNING_TABLE3_MODES,
|
| 21 |
+
build_prompt,
|
| 22 |
+
rewrite_planning_prompt_for_table3,
|
| 23 |
+
)
|
| 24 |
+
from src.audit.audit_utils import audit_enabled, audit_check
|
| 25 |
+
|
| 26 |
+
NUM_DETECTION_QUERIES = 256
|
| 27 |
+
NUM_MAP_QUERIES = 256
|
| 28 |
+
PLANNING_STATE_RANGE = (-50.0, 50.0)
|
| 29 |
+
PLANNING_NUM_BINS = 1000
|
| 30 |
+
|
| 31 |
+
AVAILABLE_COMMANDS = ["turn left", "turn right", "go straight"]
|
| 32 |
+
|
| 33 |
+
# Z range aligned with StreamPETR point_cloud_range [-5, 3]
|
| 34 |
+
Z_MIN, Z_MAX = -5.0, 3.0
|
| 35 |
+
|
| 36 |
+
# nuScenes 10-class detection 类别映射
|
| 37 |
+
NUSCENES_CLASSES = [
|
| 38 |
+
'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
|
| 39 |
+
'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
# 完整的 nuScenes 类别名映射到基础类别
|
| 43 |
+
NUSCENES_CATEGORY_MAP = {
|
| 44 |
+
# 基础类别名
|
| 45 |
+
'car': 0, 'truck': 1, 'construction_vehicle': 2, 'bus': 3, 'trailer': 4,
|
| 46 |
+
'barrier': 5, 'motorcycle': 6, 'bicycle': 7, 'pedestrian': 8, 'traffic_cone': 9,
|
| 47 |
+
# 完整 nuScenes 类别名 - 车辆
|
| 48 |
+
'vehicle.car': 0, 'vehicle.truck': 1, 'vehicle.construction': 2,
|
| 49 |
+
'vehicle.bus.bendy': 3, 'vehicle.bus.rigid': 3, 'vehicle.trailer': 4,
|
| 50 |
+
'vehicle.motorcycle': 6, 'vehicle.bicycle': 7,
|
| 51 |
+
# 完整 nuScenes 类别名 - 行人
|
| 52 |
+
'human.pedestrian.adult': 8, 'human.pedestrian.child': 8,
|
| 53 |
+
'human.pedestrian.construction_worker': 8, 'human.pedestrian.police_officer': 8,
|
| 54 |
+
'human.pedestrian.wheelchair': 8, 'human.pedestrian.stroller': 8,
|
| 55 |
+
'human.pedestrian.personal_mobility': 8,
|
| 56 |
+
# 完整 nuScenes 类别名 - 可移动物体
|
| 57 |
+
'movable_object.barrier': 5, 'movable_object.trafficcone': 9,
|
| 58 |
+
'movable_object.traffic_cone': 9,
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def nuscenes_to_paper_coords(x_nuscenes: float, y_nuscenes: float) -> Tuple[float, float]:
    """Convert a nuScenes ego-frame (x, y) point to the paper's plotting frame.

    The mapping is (x, y) -> (-y, x), i.e. a 90-degree axis swap with negation.
    """
    paper_x = -y_nuscenes
    paper_y = x_nuscenes
    return paper_x, paper_y
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def planning_state_to_bin(
    value: float,
    min_val: float = PLANNING_STATE_RANGE[0],
    max_val: float = PLANNING_STATE_RANGE[1],
    num_bins: int = PLANNING_NUM_BINS,
) -> int:
    """Discretize a planning-state value into one of ``num_bins`` uniform bins.

    The value is clamped to [min_val, max_val] before binning, so the result
    is always a valid bin index in [0, num_bins - 1].
    """
    clamped = float(np.clip(value, min_val, max_val))
    fraction = (clamped - min_val) / (max_val - min_val)
    bin_idx = int(round(fraction * (num_bins - 1)))
    return int(np.clip(bin_idx, 0, num_bins - 1))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def normalize_route_command(command: object) -> Optional[str]:
    """Canonicalize a route command string.

    Returns None for non-string input. Recognized synonyms are collapsed to
    "turn left" / "turn right" / "go straight"; anything else raises
    ValueError.
    """
    if not isinstance(command, str):
        return None
    key = command.strip().lower()
    synonyms = {
        "turn left": "turn left",
        "left": "turn left",
        "turn right": "turn right",
        "right": "turn right",
        "go straight": "go straight",
        "straight": "go straight",
        "keep straight": "go straight",
        "forward": "go straight",
    }
    canonical = synonyms.get(key)
    if canonical is not None:
        return canonical
    raise ValueError(f"Unsupported route command: {command!r}")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
CAMERA_NAMES = [
|
| 98 |
+
'CAM_FRONT',
|
| 99 |
+
'CAM_FRONT_RIGHT',
|
| 100 |
+
'CAM_FRONT_LEFT',
|
| 101 |
+
'CAM_BACK',
|
| 102 |
+
'CAM_BACK_LEFT',
|
| 103 |
+
'CAM_BACK_RIGHT',
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
VALID_TASK_TYPES = {"detection", "lane", "planning", "caption"}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _normalize_task_type(task: object) -> Optional[str]:
    """Return the canonical lowercase task name, or None for non-strings.

    Strings outside VALID_TASK_TYPES raise ValueError.
    """
    if not isinstance(task, str):
        return None
    normalized = task.strip().lower()
    if normalized not in VALID_TASK_TYPES:
        raise ValueError(f"Unsupported task type: {task!r}")
    return normalized
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _infer_task_type_from_structure(item: dict) -> Optional[str]:
|
| 120 |
+
if not isinstance(item, dict):
|
| 121 |
+
return None
|
| 122 |
+
|
| 123 |
+
try:
|
| 124 |
+
num_map_queries = int(item.get("num_map_queries", -1))
|
| 125 |
+
except Exception:
|
| 126 |
+
num_map_queries = -1
|
| 127 |
+
|
| 128 |
+
if "ego_motion" in item or "gt_boxes_3d_per_timestep" in item:
|
| 129 |
+
return "planning"
|
| 130 |
+
if isinstance(item.get("camera"), str) and item.get("camera") and num_map_queries == 0:
|
| 131 |
+
return "caption"
|
| 132 |
+
if item.get("sensor") is not None or "openlane_lane_centerline" in item:
|
| 133 |
+
return "lane"
|
| 134 |
+
if "gt_boxes_3d" in item or num_map_queries == 0:
|
| 135 |
+
return "detection"
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def infer_task_type(item: dict) -> str:
    """Resolve a sample's task type.

    Prefers an explicit 'task' field; falls back to structural inference.
    Raises ValueError when neither yields a task.
    """
    explicit = _normalize_task_type(item.get("task"))
    if explicit:
        return explicit

    structural = _infer_task_type_from_structure(item)
    if structural:
        return structural

    sample_id = item.get("id", "<unknown>")
    raise ValueError(
        f"Unable to infer task type for sample {sample_id!r}. "
        "Please add an explicit 'task' field to the dataset JSON."
    )
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def load_tokenizer(model_name: str = "lmsys/vicuna-7b-v1.5"):
    """Load a slow (non-fast) Hugging Face tokenizer with a mirror fallback.

    On the first failure (e.g. no direct access to huggingface.co) the
    HF_ENDPOINT environment variable is pointed at hf-mirror.com and the
    download is retried once. Ensures a pad token is set before returning.
    """
    from transformers import AutoTokenizer

    print(f"Loading tokenizer: {model_name}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=False,
            trust_remote_code=True,
        )
    except Exception as e:
        # Retry via the mirror endpoint; a second failure propagates.
        print(f"Failed to load from {model_name}: {e}")
        os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=False,
            trust_remote_code=True,
        )

    if tokenizer.pad_token is None:
        # Models like LLaMA/Vicuna ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class AtlasDataset(Dataset):
|
| 181 |
+
    def __init__(
        self,
        json_file: str,
        image_root: str,
        tokenizer,
        max_length: int = 4096,
        image_size: Tuple[int, int] = (800, 1600),
        use_nuscenes_calibration: bool = True,
        is_training: bool = True,
        num_detection_queries: int = NUM_DETECTION_QUERIES,
        num_map_queries: int = NUM_MAP_QUERIES,
        planning_table3_mode: str = "atlas_base",
        image_path_remap: Optional[str] = None,
        precomputed_det_tokens: Optional[str] = None,
        precomputed_map_tokens: Optional[str] = None,
    ):
        """Build the multi-task Atlas dataset.

        Args:
            json_file: Path to a dataset JSON list, or several paths joined by
                commas (all lists are concatenated).
            image_root: nuScenes data root; also used to locate calibration
                metadata when ``use_nuscenes_calibration`` is True.
            tokenizer: Tokenizer; must already contain the ``<query>`` special
                token and a pad token.
            max_length: Maximum tokenized sequence length.
            image_size: Target (H, W); a bare int is expanded to a square.
            use_nuscenes_calibration: Load nuScenes calibration tables when True.
            is_training: Training-mode flag.
            num_detection_queries: Number of detection query tokens.
            num_map_queries: Number of map query tokens.
            planning_table3_mode: One of PLANNING_TABLE3_MODES (prompt ablation).
            image_path_remap: Optional "old=new" prefix rewrite for image paths.
            precomputed_det_tokens: Optional dir of cached detection tokens.
            precomputed_map_tokens: Optional dir of cached map tokens.

        Raises:
            ValueError: for an unsupported ``planning_table3_mode``.
            RuntimeError: for empty/non-list JSON input or a tokenizer missing
                the ``<query>`` token or a pad token id.
        """
        self.json_file = json_file
        self.image_root = image_root
        self.image_path_remap = None
        if image_path_remap:
            # "old=new": split only on the first '=' so paths may contain '='.
            old, new = image_path_remap.split("=", 1)
            self.image_path_remap = (old, new)
        self.precomputed_det_dir = precomputed_det_tokens
        self.precomputed_map_dir = precomputed_map_tokens
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_training = is_training
        self.num_detection_queries = int(num_detection_queries)
        self.num_map_queries = int(num_map_queries)
        if planning_table3_mode not in PLANNING_TABLE3_MODES:
            raise ValueError(
                f"Unsupported planning_table3_mode: {planning_table3_mode}. "
                f"Expected one of {PLANNING_TABLE3_MODES}."
            )
        self.planning_table3_mode = planning_table3_mode
        self.use_nuscenes_calibration = bool(use_nuscenes_calibration)

        # Accept a bare int as a square image size.
        if isinstance(image_size, int):
            self.image_size = (image_size, image_size)
        else:
            self.image_size = image_size

        # One or more comma-separated JSON files, each a list of samples.
        paths = [p.strip() for p in str(json_file).split(",") if p.strip()]
        if len(paths) == 0:
            raise RuntimeError("json_file is empty")
        self.data = []
        for p in paths:
            with open(p, "r", encoding="utf-8") as f:
                chunk = json.load(f)
            if not isinstance(chunk, list):
                raise RuntimeError(f"JSON must be a list: {p}")
            self.data.extend(chunk)

        if len(paths) == 1:
            print(f"Loaded {len(self.data)} samples from {paths[0]}")
        else:
            print(f"Loaded {len(self.data)} samples from {len(paths)} json files")
        print(f"Image size: {self.image_size[0]}x{self.image_size[1]} (HxW)")

        # Resolve each sample's task up front; fails fast on unknown schemas.
        self._task_types = [infer_task_type(item) for item in self.data]
        from collections import Counter
        print(f"Task distribution: {dict(Counter(self._task_types))}")
        self._audit_planning_route_command_schema()

        # torchvision v2 pipeline when available, else the legacy transforms.
        # Both normalize with the standard ImageNet mean/std.
        if _HAS_V2:
            self.image_transform = transforms.Compose([
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        else:
            self.image_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

        # StreamPETR-style image preprocessing configuration.
        self.streampetr_conf = {
            "H": 900, "W": 1600,
            "final_dim": (800, 1600),
            "resize_lim": (1.0, 1.2),
            "bot_pct_lim": (0.0, 0.0),
        }

        # TopoMLP-style preprocessing configuration (W, H target size).
        self.topomlp_conf = {
            "target_size": (1600, 800),
        }

        self.calibration = None
        if self.use_nuscenes_calibration:
            self.calibration = self._load_nuscenes_calibration()
            if self.calibration is None:
                print("[WARN] nuScenes calibration metadata not found, will use per-item sensor field.")

        # The <query> placeholder token must already exist in the vocabulary.
        self.query_token = "<query>"
        vocab = tokenizer.get_vocab()
        if self.query_token not in vocab:
            raise RuntimeError(f"Tokenizer missing required special token: {self.query_token}")
        if tokenizer.pad_token_id is None:
            raise RuntimeError("tokenizer.pad_token_id is None")

        self.query_token_id = tokenizer.convert_tokens_to_ids(self.query_token)
        print(f"<query> token ID: {self.query_token_id}")
|
| 283 |
+
|
| 284 |
+
    def _load_nuscenes_calibration(self) -> Optional[Dict]:
        """Load nuScenes calibration/pose metadata tables from ``image_root``.

        Searches for a v1.0-trainval / v1.0-mini / v1.0-test metadata folder
        and loads sample_data, calibrated_sensor, ego_pose and sample JSON
        tables, indexing them for fast lookup. Also indexes LIDAR_TOP
        keyframes by sample token.

        Returns:
            Dict of lookup tables, or None when the metadata cannot be found
            or loaded (callers fall back to per-item sensor fields).
        """
        try:
            nuscenes_root = Path(self.image_root)
            version_dir = None
            # First matching version wins.
            for v in ["v1.0-trainval", "v1.0-mini", "v1.0-test"]:
                if (nuscenes_root / v).exists():
                    version_dir = nuscenes_root / v
                    break
            if version_dir is None:
                print("nuScenes metadata folder not found under image_root")
                return None

            sample_data_file = version_dir / "sample_data.json"
            calibrated_sensor_file = version_dir / "calibrated_sensor.json"
            ego_pose_file = version_dir / "ego_pose.json"
            sample_file = version_dir / "sample.json"
            if (not sample_data_file.exists()) or (not calibrated_sensor_file.exists()) or (not ego_pose_file.exists()) or (not sample_file.exists()):
                print(f"nuScenes metadata missing under {version_dir}")
                return None

            with open(sample_data_file, "r") as f:
                sample_data = json.load(f)
            with open(calibrated_sensor_file, "r") as f:
                calibrated_sensor = json.load(f)
            with open(ego_pose_file, "r") as f:
                ego_pose = json.load(f)
            with open(sample_file, "r") as f:
                sample = json.load(f)

            # Build O(1) lookup maps keyed the way downstream code queries them.
            sample_data_by_filename = {rec["filename"]: rec for rec in sample_data if "filename" in rec}
            calibrated_sensor_by_token = {rec["token"]: rec for rec in calibrated_sensor if "token" in rec}
            ego_pose_by_token = {rec["token"]: rec for rec in ego_pose if "token" in rec}
            sample_by_token = {rec["token"]: rec for rec in sample if "token" in rec}

            # Index LIDAR_TOP keyframe records by their sample token.
            lidar_sd_by_sample_token: Dict[str, Dict] = {}
            for rec in sample_data:
                if not isinstance(rec, dict):
                    continue
                fn = str(rec.get("filename", "")).replace("\\", "/")
                if "/LIDAR_TOP/" not in fn:
                    continue
                # Keyframes live under samples/ (sweeps/ holds non-key frames).
                if not fn.startswith("samples/"):
                    continue
                if not bool(rec.get("is_key_frame", False)):
                    continue
                st = rec.get("sample_token", None)
                if st is None:
                    continue
                # setdefault keeps the first keyframe seen per sample token.
                lidar_sd_by_sample_token.setdefault(str(st), rec)

            print(f"Loaded nuScenes metadata from {version_dir.name}:")
            print(f"  sample_data: {len(sample_data_by_filename)}")
            print(f"  calibrated_sensor: {len(calibrated_sensor_by_token)}")
            print(f"  ego_pose: {len(ego_pose_by_token)}")
            print(f"  sample: {len(sample_by_token)}")
            print(f"  lidar_keyframes: {len(lidar_sd_by_sample_token)}")

            return {
                "sample_data_by_filename": sample_data_by_filename,
                "calibrated_sensor_by_token": calibrated_sensor_by_token,
                "ego_pose_by_token": ego_pose_by_token,
                "sample_by_token": sample_by_token,
                "lidar_sd_by_sample_token": lidar_sd_by_sample_token,
            }
        except Exception as e:
            # Best-effort loader: any failure degrades to per-item calibration.
            print(f"Failed to load nuScenes calibration: {e}")
            return None
|
| 351 |
+
|
| 352 |
+
def __len__(self) -> int:
|
| 353 |
+
return len(self.data)
|
| 354 |
+
|
| 355 |
+
def _audit_planning_route_command_schema(self) -> None:
|
| 356 |
+
planning_indices = [i for i, task in enumerate(self._task_types) if task == "planning"]
|
| 357 |
+
if not planning_indices:
|
| 358 |
+
return
|
| 359 |
+
|
| 360 |
+
total = len(planning_indices)
|
| 361 |
+
top_level_count = 0
|
| 362 |
+
legacy_ego_motion_command = 0
|
| 363 |
+
route_command_dist: Dict[str, int] = {}
|
| 364 |
+
|
| 365 |
+
for idx in planning_indices:
|
| 366 |
+
item = self.data[idx]
|
| 367 |
+
try:
|
| 368 |
+
command = self._resolve_route_command(item)
|
| 369 |
+
except Exception as exc:
|
| 370 |
+
sample_id = item.get("id", f"planning_idx_{idx}")
|
| 371 |
+
print(f"[WARN] invalid planning route_command for sample {sample_id}: {exc}")
|
| 372 |
+
command = None
|
| 373 |
+
if command is not None:
|
| 374 |
+
top_level_count += 1
|
| 375 |
+
route_command_dist[command] = route_command_dist.get(command, 0) + 1
|
| 376 |
+
ego = item.get("ego_motion")
|
| 377 |
+
if isinstance(ego, dict) and "command" in ego:
|
| 378 |
+
legacy_ego_motion_command += 1
|
| 379 |
+
|
| 380 |
+
print(
|
| 381 |
+
"[Planning route_command audit] "
|
| 382 |
+
f"mode={self.planning_table3_mode} "
|
| 383 |
+
f"top_level_coverage={top_level_count}/{total} "
|
| 384 |
+
f"legacy_ego_motion_command={legacy_ego_motion_command}/{total} "
|
| 385 |
+
f"distribution={route_command_dist}"
|
| 386 |
+
)
|
| 387 |
+
if self.planning_table3_mode != "atlas_base" and top_level_count < total:
|
| 388 |
+
print(
|
| 389 |
+
"[WARN] planning high-level mode requested but top-level "
|
| 390 |
+
"route_command coverage is incomplete."
|
| 391 |
+
)
|
| 392 |
+
if legacy_ego_motion_command > 0:
|
| 393 |
+
print(
|
| 394 |
+
"[WARN] legacy planning schema detected: ego_motion.command "
|
| 395 |
+
"is still present in the loaded JSON."
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Build one training/inference sample.

        Loads the multi-camera images and calibration, rewrites/expands the
        conversation prompt, tokenizes prompt and answer, builds the label
        mask, and attaches GT boxes, ego motion, and optional precomputed
        detection/map tokens. Raises ValueError when the number of <query>
        tokens in the tokenized input does not match the expected count.
        """
        item = self.data[idx]
        task_type = self._task_types[idx]

        # Optionally remap a path prefix (e.g. data moved between machines).
        img_paths = item["image_paths"]
        if self.image_path_remap:
            old, new = self.image_path_remap
            img_paths = [p.replace(old, new) for p in img_paths]

        # Load images plus per-camera intrinsics/extrinsics/projection
        # matrices, and (when calibration is available) the ego pose.
        cam_out = self._load_images_with_cameras(
            img_paths,
            item=item,
        )
        pixel_values = cam_out["pixel_values"]
        intrinsics = cam_out["intrinsics"]
        extrinsics = cam_out["extrinsics"]
        lidar2img = cam_out["lidar2img"]
        ego_pose = cam_out.get("ego_pose")
        ego_pose_inv = cam_out.get("ego_pose_inv")
        timestamp = cam_out.get("timestamp")

        # Prompt construction: planning prompts get route-command / ego-state
        # rewriting, then <query> placeholders are expanded to full count.
        prompt_raw, answer_raw = self._extract_conversation(item)
        if task_type == "planning":
            prompt_raw = self._rewrite_planning_prompt(prompt_raw, item)
        expected_num_queries = self._infer_expected_query_count(prompt_raw, item=item)
        prompt_text = self._expand_query_placeholders(
            prompt_raw,
            expected_num_queries,
            item=item,
        )

        prompt_str = build_prompt(prompt_text, mode="train" if self.is_training else "infer")
        # Leading space so the answer tokenizes as a continuation.
        answer_str = f" {answer_raw}"

        prompt_ids = self.tokenizer(prompt_str, add_special_tokens=False)["input_ids"]
        answer_ids = self.tokenizer(answer_str, add_special_tokens=False)["input_ids"]

        bos = [self.tokenizer.bos_token_id] if self.tokenizer.bos_token_id is not None else []
        eos = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []

        # Training sequences include the answer + EOS; inference stops at the prompt.
        if self.is_training:
            input_ids_full = bos + prompt_ids + answer_ids + eos
        else:
            input_ids_full = bos + prompt_ids

        input_ids = input_ids_full
        if len(input_ids) > self.max_length:
            input_ids = input_ids[: self.max_length]
        attention_mask = [1] * len(input_ids)

        # <query> tokens must survive tokenization and truncation intact.
        num_query_tokens = sum(1 for t in input_ids if t == self.query_token_id)
        if num_query_tokens != expected_num_queries:
            raise ValueError(f"<query> mismatch: expected {expected_num_queries}, got {num_query_tokens}")

        if self.is_training:
            # Mask the prompt portion so loss is computed on the answer only.
            labels = input_ids.copy()
            prompt_len = len(bos) + len(prompt_ids)
            labels[:prompt_len] = [-100] * prompt_len

            first_nonmasked_index = -1
            for i, t in enumerate(labels):
                if t != -100:
                    first_nonmasked_index = i
                    break
            labels_nonmasked_count = sum(1 for t in labels if t != -100)

            # Sanity checks: labels align with inputs, the answer starts
            # right after the prompt, and truncation left some answer tokens.
            assert len(labels) == len(input_ids), "labels/input_ids length mismatch"
            if labels_nonmasked_count > 0:
                assert first_nonmasked_index == prompt_len, (
                    f"first_nonmasked_index={first_nonmasked_index} != prompt_len={prompt_len}"
                )
            assert labels_nonmasked_count > 0, (
                f"all labels are -100: len_full={len(input_ids_full)} len_trunc={len(input_ids)} max_length={self.max_length}"
            )
            if audit_enabled():
                # Track running truncation rate and flag samples whose answer
                # was truncated below the configured minimum token count.
                if not hasattr(self, "_audit_trunc_total"):
                    self._audit_trunc_total = 0
                    self._audit_trunc_hits = 0
                self._audit_trunc_total += 1
                truncated = int(len(input_ids_full) > self.max_length)
                self._audit_trunc_hits += truncated
                trunc_rate = float(self._audit_trunc_hits) / float(self._audit_trunc_total)
                min_ans = int(os.getenv("ATLAS_MIN_ANSWER_TOKENS", "16"))
                ok = (labels_nonmasked_count >= min_ans)
                audit_check(
                    "A6",
                    ok,
                    once=False,
                    truncated=truncated,
                    trunc_rate=trunc_rate,
                    labels_nonmasked_count=labels_nonmasked_count,
                    min_answer_tokens=min_ans,
                )
        else:
            # Inference: no supervision targets.
            labels = [-100] * len(input_ids)
            prompt_len = len(input_ids)
            first_nonmasked_index = -1
            labels_nonmasked_count = 0
        scene_id = self._get_scene_id(item)
        sample_id = str(item.get("id", idx))

        gt_boxes, gt_labels = self._load_gt_boxes(item)

        result = {
            "pixel_values": pixel_values,
            "pixel_values_det": cam_out["pixel_values_det"],
            "pixel_values_map": cam_out["pixel_values_map"],
            "intrinsics": intrinsics,
            "intrinsics_det": cam_out["intrinsics_det"],
            "intrinsics_map": cam_out["intrinsics_map"],
            "extrinsics": extrinsics,
            "lidar2img": lidar2img,
            "lidar2img_det": cam_out["lidar2img_det"],
            "lidar2img_map": cam_out["lidar2img_map"],
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "scene_id": scene_id,
            "sample_id": sample_id,
            "dataset_idx": torch.tensor(idx, dtype=torch.long),
            "task_type": task_type,
            # audit_* entries carry per-sample bookkeeping for offline audits.
            "audit_prompt_len": torch.tensor(prompt_len, dtype=torch.long),
            "audit_answer_len": torch.tensor(len(answer_ids), dtype=torch.long),
            "audit_labels_nonmasked_count": torch.tensor(labels_nonmasked_count, dtype=torch.long),
            "audit_first_nonmasked_index": torch.tensor(first_nonmasked_index, dtype=torch.long),
            "audit_num_query_tokens_in_input_ids": torch.tensor(num_query_tokens, dtype=torch.long),
            "audit_expected_num_queries": torch.tensor(expected_num_queries, dtype=torch.long),
            "audit_truncated": torch.tensor(int(len(input_ids_full) > self.max_length), dtype=torch.long),
        }

        # Optional pose/timestamp outputs, only present with calibration.
        if ego_pose is not None:
            result["ego_pose"] = ego_pose
        if ego_pose_inv is not None:
            result["ego_pose_inv"] = ego_pose_inv
        if timestamp is not None:
            result["timestamp"] = timestamp
        # Best-effort: attach ego velocity when the item carries ego motion.
        try:
            ego_motion = self._get_ego_motion_data(item)
            if isinstance(ego_motion, dict) and "velocity" in ego_motion:
                result["velocity"] = torch.tensor(ego_motion["velocity"], dtype=torch.float32)
        except Exception:
            pass

        if gt_boxes is not None:
            result["gt_boxes"] = gt_boxes
            result["gt_labels"] = gt_labels

        # Optional precomputed detection/map token bundles from disk.
        if self.precomputed_det_dir:
            pt = self._load_precomputed(self.precomputed_det_dir, item)
            if pt is not None:
                result["precomputed_det"] = pt["detection"]
                result["precomputed_det_ref"] = pt["detection_ref_points"]

        if self.precomputed_map_dir:
            mpt = self._load_precomputed(self.precomputed_map_dir, item)
            if mpt is not None:
                result["precomputed_map"] = mpt

        # Verbose one-shot audit print for the first few samples.
        if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
            max_samples = int(os.getenv("ATLAS_AUDIT_MAX_SAMPLES", "1"))
            if idx < max_samples:
                print(
                    "[ATLAS_AUDIT][A1/A3] "
                    f"idx={idx} "
                    f"prompt_len={prompt_len} answer_len={len(answer_ids)} "
                    f"first_nonmasked_index={first_nonmasked_index} labels_nonmasked_count={labels_nonmasked_count} "
                    f"num_query_tokens_in_input_ids={num_query_tokens} expected_num_queries={expected_num_queries} "
                    f"seq_len={len(input_ids)} truncated={int(len(input_ids_full) > self.max_length)}"
                )
        return result
|
| 568 |
+
|
| 569 |
+
def _resolve_category_id(self, label) -> int:
|
| 570 |
+
"""将类别标签转换为类别 ID,支持整数、浮点数和字符串类别名"""
|
| 571 |
+
if isinstance(label, (int, float)):
|
| 572 |
+
return int(label)
|
| 573 |
+
if isinstance(label, str):
|
| 574 |
+
label_lower = label.lower().strip()
|
| 575 |
+
if label_lower in NUSCENES_CATEGORY_MAP:
|
| 576 |
+
return NUSCENES_CATEGORY_MAP[label_lower]
|
| 577 |
+
# 对 human.pedestrian.* 子类使用前缀匹配
|
| 578 |
+
if label_lower.startswith('human.pedestrian.'):
|
| 579 |
+
return 8 # pedestrian
|
| 580 |
+
# 对 vehicle.bus.* 子类使用前缀匹配
|
| 581 |
+
if label_lower.startswith('vehicle.bus.'):
|
| 582 |
+
return 3 # bus
|
| 583 |
+
return 0
|
| 584 |
+
|
| 585 |
+
def _load_gt_boxes(self, item: Dict) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 586 |
+
annotations = item.get("annotations", item.get("gt_boxes_3d", None))
|
| 587 |
+
|
| 588 |
+
if annotations is None:
|
| 589 |
+
return None, None
|
| 590 |
+
|
| 591 |
+
if not annotations:
|
| 592 |
+
return torch.zeros(0, 7), torch.zeros(0, dtype=torch.long)
|
| 593 |
+
|
| 594 |
+
boxes_list = []
|
| 595 |
+
labels_list = []
|
| 596 |
+
bev_range = 51.2
|
| 597 |
+
|
| 598 |
+
for ann in annotations:
|
| 599 |
+
if isinstance(ann, dict):
|
| 600 |
+
if "translation" in ann:
|
| 601 |
+
pos_x, pos_y, pos_z = ann["translation"][:3]
|
| 602 |
+
size_w, size_l, size_h = ann.get("size", [1, 1, 1])[:3]
|
| 603 |
+
yaw = ann.get("rotation", 0)
|
| 604 |
+
if isinstance(yaw, (list, tuple)):
|
| 605 |
+
if len(yaw) == 1:
|
| 606 |
+
yaw = float(yaw[0])
|
| 607 |
+
elif len(yaw) == 4:
|
| 608 |
+
# Quaternion (w, x, y, z) -> yaw
|
| 609 |
+
qw, qx, qy, qz = [float(v) for v in yaw]
|
| 610 |
+
t0 = 2.0 * (qw * qz + qx * qy)
|
| 611 |
+
t1 = 1.0 - 2.0 * (qy * qy + qz * qz)
|
| 612 |
+
yaw = math.atan2(t0, t1)
|
| 613 |
+
else:
|
| 614 |
+
yaw = 0.0
|
| 615 |
+
# 支持多种类别字段: category_id, category, category_name
|
| 616 |
+
label = ann.get("category_id", ann.get("category", ann.get("category_name", 0)))
|
| 617 |
+
x, y, z = pos_x, pos_y, pos_z
|
| 618 |
+
w, l, h = size_w, size_l, size_h
|
| 619 |
+
|
| 620 |
+
x_norm = (x + bev_range) / (2 * bev_range)
|
| 621 |
+
y_norm = (y + bev_range) / (2 * bev_range)
|
| 622 |
+
z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
|
| 623 |
+
|
| 624 |
+
boxes_list.append([x_norm, y_norm, z_norm, w, l, h, yaw])
|
| 625 |
+
labels_list.append(self._resolve_category_id(label))
|
| 626 |
+
|
| 627 |
+
elif "box" in ann:
|
| 628 |
+
box = ann["box"]
|
| 629 |
+
x, y, z = box[:3]
|
| 630 |
+
w, l, h = box[3:6] if len(box) >= 6 else (1, 1, 1)
|
| 631 |
+
yaw = box[6] if len(box) >= 7 else 0
|
| 632 |
+
label = ann.get("category_id", ann.get("label", ann.get("category_name", 0)))
|
| 633 |
+
|
| 634 |
+
x_norm = (x + bev_range) / (2 * bev_range)
|
| 635 |
+
y_norm = (y + bev_range) / (2 * bev_range)
|
| 636 |
+
z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
|
| 637 |
+
|
| 638 |
+
boxes_list.append([x_norm, y_norm, z_norm, w, l, h, yaw])
|
| 639 |
+
labels_list.append(self._resolve_category_id(label))
|
| 640 |
+
|
| 641 |
+
elif "world_coords" in ann:
|
| 642 |
+
wc = ann["world_coords"]
|
| 643 |
+
x, y, z = wc[0], wc[1], wc[2] if len(wc) > 2 else 0.0
|
| 644 |
+
w = ann.get("w", 1.0)
|
| 645 |
+
l = ann.get("l", 1.0)
|
| 646 |
+
h = ann.get("h", 1.0)
|
| 647 |
+
yaw = ann.get("yaw", 0.0)
|
| 648 |
+
label = ann.get("category_id", ann.get("category", ann.get("category_name", 0)))
|
| 649 |
+
|
| 650 |
+
x_norm = (x + bev_range) / (2 * bev_range)
|
| 651 |
+
y_norm = (y + bev_range) / (2 * bev_range)
|
| 652 |
+
z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
|
| 653 |
+
|
| 654 |
+
boxes_list.append([x_norm, y_norm, z_norm, w, l, h, yaw])
|
| 655 |
+
labels_list.append(self._resolve_category_id(label))
|
| 656 |
+
|
| 657 |
+
elif isinstance(ann, (list, tuple)) and len(ann) >= 3:
|
| 658 |
+
x, y, z = ann[:3]
|
| 659 |
+
w, l, h = ann[3:6] if len(ann) >= 6 else (1, 1, 1)
|
| 660 |
+
yaw = ann[6] if len(ann) >= 7 else 0
|
| 661 |
+
|
| 662 |
+
x_norm = (x + bev_range) / (2 * bev_range)
|
| 663 |
+
y_norm = (y + bev_range) / (2 * bev_range)
|
| 664 |
+
z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
|
| 665 |
+
|
| 666 |
+
boxes_list.append([x_norm, y_norm, z_norm, w, l, h, yaw])
|
| 667 |
+
labels_list.append(0)
|
| 668 |
+
|
| 669 |
+
if not boxes_list:
|
| 670 |
+
return torch.zeros(0, 7), torch.zeros(0, dtype=torch.long)
|
| 671 |
+
|
| 672 |
+
if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
|
| 673 |
+
if not hasattr(self, "_audit_gt_calls"):
|
| 674 |
+
self._audit_gt_calls = 0
|
| 675 |
+
max_calls = int(os.getenv("ATLAS_AUDIT_MAX_GT", "1"))
|
| 676 |
+
if self._audit_gt_calls < max_calls:
|
| 677 |
+
yaws = [float(b[6]) for b in boxes_list if len(b) >= 7]
|
| 678 |
+
if yaws:
|
| 679 |
+
y_min = float(min(yaws))
|
| 680 |
+
y_max = float(max(yaws))
|
| 681 |
+
y_abs = float(max(abs(y) for y in yaws))
|
| 682 |
+
print(f"[ATLAS_AUDIT][E2/E3] gt_yaw_min={y_min:.3e} gt_yaw_max={y_max:.3e} gt_yaw_absmax={y_abs:.3e}")
|
| 683 |
+
if y_abs > 10.0:
|
| 684 |
+
print("[ATLAS_AUDIT][E2] yaw_absmax>10 (possible degrees instead of radians)")
|
| 685 |
+
self._audit_gt_calls += 1
|
| 686 |
+
|
| 687 |
+
gt_boxes = torch.tensor(boxes_list, dtype=torch.float32)
|
| 688 |
+
gt_labels = torch.tensor(labels_list, dtype=torch.long)
|
| 689 |
+
|
| 690 |
+
return gt_boxes, gt_labels
|
| 691 |
+
|
| 692 |
+
def _rewrite_planning_prompt(self, prompt_text: str, item: Dict) -> str:
|
| 693 |
+
ego_motion = item.get("ego_motion", {})
|
| 694 |
+
if not isinstance(ego_motion, dict):
|
| 695 |
+
ego_motion = {}
|
| 696 |
+
|
| 697 |
+
route_command = self._resolve_route_command(item)
|
| 698 |
+
velocity = ego_motion.get("velocity")
|
| 699 |
+
acceleration = ego_motion.get("acceleration")
|
| 700 |
+
|
| 701 |
+
velocity_bins = None
|
| 702 |
+
acceleration_bins = None
|
| 703 |
+
if velocity is not None:
|
| 704 |
+
if not isinstance(velocity, (list, tuple)) or len(velocity) < 2:
|
| 705 |
+
raise RuntimeError("planning ego_motion.velocity must be a 2D vector")
|
| 706 |
+
velocity_bins = (
|
| 707 |
+
planning_state_to_bin(float(velocity[0])),
|
| 708 |
+
planning_state_to_bin(float(velocity[1])),
|
| 709 |
+
)
|
| 710 |
+
if acceleration is not None:
|
| 711 |
+
if not isinstance(acceleration, (list, tuple)) or len(acceleration) < 2:
|
| 712 |
+
raise RuntimeError("planning ego_motion.acceleration must be a 2D vector")
|
| 713 |
+
acceleration_bins = (
|
| 714 |
+
planning_state_to_bin(float(acceleration[0])),
|
| 715 |
+
planning_state_to_bin(float(acceleration[1])),
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
return rewrite_planning_prompt_for_table3(
|
| 719 |
+
prompt_text,
|
| 720 |
+
mode=self.planning_table3_mode,
|
| 721 |
+
command=route_command,
|
| 722 |
+
velocity_bins=velocity_bins,
|
| 723 |
+
acceleration_bins=acceleration_bins,
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
def _resolve_route_command(self, item: Dict) -> Optional[str]:
|
| 727 |
+
candidates = [
|
| 728 |
+
item.get("route_command"),
|
| 729 |
+
item.get("nav_command"),
|
| 730 |
+
item.get("high_level_command"),
|
| 731 |
+
item.get("navigation_command"),
|
| 732 |
+
]
|
| 733 |
+
meta = item.get("meta_data")
|
| 734 |
+
if isinstance(meta, dict):
|
| 735 |
+
candidates.extend([
|
| 736 |
+
meta.get("route_command"),
|
| 737 |
+
meta.get("nav_command"),
|
| 738 |
+
meta.get("high_level_command"),
|
| 739 |
+
meta.get("navigation_command"),
|
| 740 |
+
])
|
| 741 |
+
for candidate in candidates:
|
| 742 |
+
normalized = normalize_route_command(candidate)
|
| 743 |
+
if normalized is not None:
|
| 744 |
+
return normalized
|
| 745 |
+
return None
|
| 746 |
+
|
| 747 |
+
def _extract_conversation(self, item: Dict) -> Tuple[str, str]:
|
| 748 |
+
conv = item.get("conversations", None)
|
| 749 |
+
if not conv or not isinstance(conv, list):
|
| 750 |
+
raise RuntimeError("missing conversations field")
|
| 751 |
+
prompt = None
|
| 752 |
+
answer = None
|
| 753 |
+
for turn in conv:
|
| 754 |
+
if not isinstance(turn, dict):
|
| 755 |
+
continue
|
| 756 |
+
if turn.get("from") in ("human", "user") and prompt is None:
|
| 757 |
+
prompt = turn.get("value")
|
| 758 |
+
if turn.get("from") in ("gpt", "assistant") and answer is None:
|
| 759 |
+
answer = turn.get("value")
|
| 760 |
+
if prompt is None or answer is None:
|
| 761 |
+
raise RuntimeError("conversations missing human/gpt pair")
|
| 762 |
+
|
| 763 |
+
return str(prompt), str(answer)
|
| 764 |
+
|
| 765 |
+
def _infer_expected_query_count(self, prompt_text: str, item: Optional[Dict] = None) -> int:
|
| 766 |
+
if isinstance(item, dict):
|
| 767 |
+
if "num_map_queries" in item:
|
| 768 |
+
try:
|
| 769 |
+
num_map = int(item.get("num_map_queries", 0))
|
| 770 |
+
if num_map < 0:
|
| 771 |
+
num_map = 0
|
| 772 |
+
return self.num_detection_queries + num_map
|
| 773 |
+
except Exception:
|
| 774 |
+
pass
|
| 775 |
+
use_map = item.get("use_map_queries", None)
|
| 776 |
+
if use_map is not None:
|
| 777 |
+
return self.num_detection_queries + (self.num_map_queries if bool(use_map) else 0)
|
| 778 |
+
p = prompt_text.lower()
|
| 779 |
+
if "map query" in p:
|
| 780 |
+
return self.num_detection_queries + self.num_map_queries
|
| 781 |
+
return self.num_detection_queries
|
| 782 |
+
|
| 783 |
+
def _expand_query_placeholders(
|
| 784 |
+
self,
|
| 785 |
+
prompt_text: str,
|
| 786 |
+
expected_num_queries: int,
|
| 787 |
+
item: Optional[Dict] = None,
|
| 788 |
+
) -> str:
|
| 789 |
+
cnt = prompt_text.count(self.query_token)
|
| 790 |
+
if cnt == 1:
|
| 791 |
+
placeholder = " ".join([self.query_token] * expected_num_queries)
|
| 792 |
+
out = prompt_text.replace(self.query_token, f" {placeholder} ")
|
| 793 |
+
return " ".join(out.split())
|
| 794 |
+
if cnt == 2 and expected_num_queries > self.num_detection_queries:
|
| 795 |
+
num_map = expected_num_queries - self.num_detection_queries
|
| 796 |
+
if isinstance(item, dict) and "num_map_queries" in item:
|
| 797 |
+
try:
|
| 798 |
+
num_map = int(item.get("num_map_queries", num_map))
|
| 799 |
+
except Exception:
|
| 800 |
+
pass
|
| 801 |
+
num_map = max(0, min(num_map, expected_num_queries))
|
| 802 |
+
num_det = expected_num_queries - num_map
|
| 803 |
+
parts = prompt_text.split(self.query_token)
|
| 804 |
+
if len(parts) == 3 and num_det > 0 and num_map > 0:
|
| 805 |
+
# Latest planning prompts use two placeholders: the first stands
|
| 806 |
+
# for detection query slots and the second for map query slots.
|
| 807 |
+
det_placeholder = " ".join([self.query_token] * num_det)
|
| 808 |
+
map_placeholder = " ".join([self.query_token] * num_map)
|
| 809 |
+
out = (
|
| 810 |
+
f"{parts[0]} {det_placeholder} "
|
| 811 |
+
f"{parts[1]} {map_placeholder} "
|
| 812 |
+
f"{parts[2]}"
|
| 813 |
+
)
|
| 814 |
+
return " ".join(out.split())
|
| 815 |
+
if cnt == expected_num_queries:
|
| 816 |
+
return " ".join(prompt_text.split())
|
| 817 |
+
if expected_num_queries > self.num_detection_queries:
|
| 818 |
+
raise ValueError(
|
| 819 |
+
f"<query> count mismatch: got {cnt}, expected 1, 2, or {expected_num_queries}"
|
| 820 |
+
)
|
| 821 |
+
raise ValueError(f"<query> count mismatch: got {cnt}, expected 1 or {expected_num_queries}")
|
| 822 |
+
|
| 823 |
+
def _load_precomputed(self, directory: str, item: Dict) -> Optional[Dict]:
|
| 824 |
+
item_id = str(item.get("id", ""))
|
| 825 |
+
pt_path = os.path.join(directory, f"{item_id}.pt")
|
| 826 |
+
if os.path.isfile(pt_path):
|
| 827 |
+
try:
|
| 828 |
+
return torch.load(pt_path, map_location="cpu")
|
| 829 |
+
except Exception:
|
| 830 |
+
return None
|
| 831 |
+
meta = item.get("meta_data", {})
|
| 832 |
+
if isinstance(meta, dict):
|
| 833 |
+
source_id = meta.get("source_id")
|
| 834 |
+
if source_id:
|
| 835 |
+
pt_path2 = os.path.join(directory, f"{source_id}.pt")
|
| 836 |
+
if os.path.isfile(pt_path2):
|
| 837 |
+
try:
|
| 838 |
+
return torch.load(pt_path2, map_location="cpu")
|
| 839 |
+
except Exception:
|
| 840 |
+
pass
|
| 841 |
+
return None
|
| 842 |
+
|
| 843 |
+
def _get_scene_id(self, item: Dict) -> str:
|
| 844 |
+
if "segment_id" in item and item["segment_id"]:
|
| 845 |
+
return str(item["segment_id"])
|
| 846 |
+
try:
|
| 847 |
+
p0 = item.get("image_paths", [None])[0]
|
| 848 |
+
if not p0:
|
| 849 |
+
return "unknown"
|
| 850 |
+
fname = os.path.basename(p0)
|
| 851 |
+
parts = fname.split("__")
|
| 852 |
+
if len(parts) > 1:
|
| 853 |
+
return parts[0]
|
| 854 |
+
return "unknown"
|
| 855 |
+
except Exception:
|
| 856 |
+
return "unknown"
|
| 857 |
+
|
| 858 |
+
def _get_timestamp(self, item: Dict) -> int:
|
| 859 |
+
if "timestamp" in item and item["timestamp"]:
|
| 860 |
+
return int(item["timestamp"])
|
| 861 |
+
try:
|
| 862 |
+
p0 = item.get("image_paths", [None])[0]
|
| 863 |
+
fname = os.path.basename(p0)
|
| 864 |
+
parts = fname.replace(".jpg", "").split("__")
|
| 865 |
+
return int(parts[-1]) if parts else 0
|
| 866 |
+
except Exception:
|
| 867 |
+
return 0
|
| 868 |
+
|
| 869 |
+
def get_scene_groups(self) -> Dict[str, List[int]]:
|
| 870 |
+
groups: Dict[str, List[int]] = {}
|
| 871 |
+
for idx, item in enumerate(self.data):
|
| 872 |
+
sid = self._get_scene_id(item)
|
| 873 |
+
groups.setdefault(sid, []).append(idx)
|
| 874 |
+
for sid in groups:
|
| 875 |
+
groups[sid].sort(key=lambda i: self._get_timestamp(self.data[i]))
|
| 876 |
+
return groups
|
| 877 |
+
|
| 878 |
+
def _get_ego_motion_data(self, item: Dict) -> Dict:
|
| 879 |
+
route_cmd = self._resolve_route_command(item) or "go straight"
|
| 880 |
+
|
| 881 |
+
if 'ego_motion' in item:
|
| 882 |
+
ego = item['ego_motion']
|
| 883 |
+
|
| 884 |
+
# JSON already stores paper-frame values (gen_atlas_planning_qa.py
|
| 885 |
+
# applies _nuscenes_to_paper_xy before writing). Do NOT transform again.
|
| 886 |
+
vel = ego.get('velocity', [0.0, 0.0])
|
| 887 |
+
acc = ego.get('acceleration', [0.0, 0.0])
|
| 888 |
+
vx, vy = float(vel[0]), float(vel[1])
|
| 889 |
+
ax, ay = float(acc[0]), float(acc[1])
|
| 890 |
+
|
| 891 |
+
if os.getenv("ATLAS_AUDIT", "0") not in ("", "0", "false", "False"):
|
| 892 |
+
if not hasattr(self, "_audit_ego_calls"):
|
| 893 |
+
self._audit_ego_calls = 0
|
| 894 |
+
max_calls = int(os.getenv("ATLAS_AUDIT_MAX_EGO", "1"))
|
| 895 |
+
if self._audit_ego_calls < max_calls:
|
| 896 |
+
print(f"[ATLAS_AUDIT][ego] vel=({vx:.3e},{vy:.3e}) acc=({ax:.3e},{ay:.3e}) [paper-frame, no transform]")
|
| 897 |
+
self._audit_ego_calls += 1
|
| 898 |
+
|
| 899 |
+
if 'waypoints' in ego:
|
| 900 |
+
waypoints_raw = ego['waypoints']
|
| 901 |
+
waypoints = [[float(wp[0]), float(wp[1])] for wp in waypoints_raw]
|
| 902 |
+
else:
|
| 903 |
+
waypoints = self._generate_waypoints(route_cmd, [vx, vy], [ax, ay])
|
| 904 |
+
|
| 905 |
+
return {
|
| 906 |
+
'velocity': [vx, vy],
|
| 907 |
+
'acceleration': [ax, ay],
|
| 908 |
+
'waypoints': waypoints,
|
| 909 |
+
}
|
| 910 |
+
|
| 911 |
+
if not hasattr(self, "_warned_missing_ego_motion"):
|
| 912 |
+
self._warned_missing_ego_motion = True
|
| 913 |
+
print("[WARN] ego_motion missing, using stationary default (velocity=[0,0])")
|
| 914 |
+
|
| 915 |
+
velocity = [0.0, 0.0]
|
| 916 |
+
acceleration = [0.0, 0.0]
|
| 917 |
+
waypoints = self._generate_waypoints(route_cmd, velocity, acceleration)
|
| 918 |
+
|
| 919 |
+
return {
|
| 920 |
+
'velocity': velocity,
|
| 921 |
+
'acceleration': acceleration,
|
| 922 |
+
'waypoints': waypoints,
|
| 923 |
+
}
|
| 924 |
+
|
| 925 |
+
def _generate_waypoints(
|
| 926 |
+
self,
|
| 927 |
+
command: str,
|
| 928 |
+
velocity: List[float] = None,
|
| 929 |
+
acceleration: List[float] = None,
|
| 930 |
+
) -> List[List[float]]:
|
| 931 |
+
if velocity is None:
|
| 932 |
+
velocity = [0.0, 5.0]
|
| 933 |
+
if acceleration is None:
|
| 934 |
+
acceleration = [0.0, 0.0]
|
| 935 |
+
|
| 936 |
+
vx, vy = velocity
|
| 937 |
+
ax, ay = acceleration
|
| 938 |
+
waypoints = []
|
| 939 |
+
|
| 940 |
+
for i in range(1, 7):
|
| 941 |
+
t = i * 0.5
|
| 942 |
+
|
| 943 |
+
x = vx * t + 0.5 * ax * t * t
|
| 944 |
+
y = vy * t + 0.5 * ay * t * t
|
| 945 |
+
|
| 946 |
+
if command == "turn left":
|
| 947 |
+
curvature = -0.3 * t * t
|
| 948 |
+
x += curvature
|
| 949 |
+
elif command == "turn right":
|
| 950 |
+
curvature = 0.3 * t * t
|
| 951 |
+
x += curvature
|
| 952 |
+
|
| 953 |
+
waypoints.append([round(x, 2), round(y, 2)])
|
| 954 |
+
|
| 955 |
+
return waypoints
|
| 956 |
+
|
| 957 |
+
def _preprocess_streampetr(self, img_pil: Image.Image, K: np.ndarray):
|
| 958 |
+
W, H = img_pil.size
|
| 959 |
+
conf = self.streampetr_conf
|
| 960 |
+
fH, fW = conf["final_dim"]
|
| 961 |
+
resize = max(fH / H, fW / W)
|
| 962 |
+
rW, rH = int(W * resize), int(H * resize)
|
| 963 |
+
crop_h = int(rH) - fH
|
| 964 |
+
crop_w = max(0, rW - fW) // 2
|
| 965 |
+
|
| 966 |
+
if resize != 1.0:
|
| 967 |
+
img_pil = img_pil.resize((rW, rH), Image.BILINEAR)
|
| 968 |
+
img_pil = img_pil.crop((crop_w, crop_h, crop_w + fW, crop_h + fH))
|
| 969 |
+
|
| 970 |
+
K_new = K.copy()
|
| 971 |
+
K_new[0, 0] *= resize
|
| 972 |
+
K_new[1, 1] *= resize
|
| 973 |
+
K_new[0, 2] = K_new[0, 2] * resize - crop_w
|
| 974 |
+
K_new[1, 2] = K_new[1, 2] * resize - crop_h
|
| 975 |
+
return img_pil, K_new
|
| 976 |
+
|
| 977 |
+
def _preprocess_topomlp(self, img_pil: Image.Image, K: np.ndarray):
|
| 978 |
+
W, H = img_pil.size
|
| 979 |
+
tW, tH = self.topomlp_conf["target_size"]
|
| 980 |
+
w_scale = tW / W
|
| 981 |
+
h_scale = tH / H
|
| 982 |
+
|
| 983 |
+
img_pil = img_pil.resize((tW, tH), Image.BILINEAR)
|
| 984 |
+
|
| 985 |
+
K_new = K.copy()
|
| 986 |
+
K_new[0, 0] *= w_scale
|
| 987 |
+
K_new[0, 2] *= w_scale
|
| 988 |
+
K_new[1, 1] *= h_scale
|
| 989 |
+
K_new[1, 2] *= h_scale
|
| 990 |
+
return img_pil, K_new
|
| 991 |
+
|
| 992 |
+
def _load_images_with_cameras(
|
| 993 |
+
self,
|
| 994 |
+
image_paths: List[str],
|
| 995 |
+
item: Optional[Dict] = None,
|
| 996 |
+
) -> Dict:
|
| 997 |
+
images_det = []
|
| 998 |
+
images_map = []
|
| 999 |
+
intrinsics_det_list = []
|
| 1000 |
+
intrinsics_map_list = []
|
| 1001 |
+
extrinsics_list = []
|
| 1002 |
+
lidar2img_det_list = []
|
| 1003 |
+
lidar2img_map_list = []
|
| 1004 |
+
ego_pose_out = None
|
| 1005 |
+
ego_pose_inv_out = None
|
| 1006 |
+
timestamp_out = None
|
| 1007 |
+
|
| 1008 |
+
for i, img_path in enumerate(image_paths):
|
| 1009 |
+
camera_name = CAMERA_NAMES[i] if i < len(CAMERA_NAMES) else f"CAM_{i}"
|
| 1010 |
+
for cam in sorted(CAMERA_NAMES, key=len, reverse=True):
|
| 1011 |
+
if cam in img_path:
|
| 1012 |
+
camera_name = cam
|
| 1013 |
+
break
|
| 1014 |
+
|
| 1015 |
+
def _normalize_path(p: str) -> str:
|
| 1016 |
+
return str(p).replace("\\", "/").lstrip("./")
|
| 1017 |
+
|
| 1018 |
+
def _lookup_sample_data(path: str) -> Optional[Dict]:
|
| 1019 |
+
if self.calibration is None:
|
| 1020 |
+
return None
|
| 1021 |
+
by_name = self.calibration["sample_data_by_filename"]
|
| 1022 |
+
candidates = []
|
| 1023 |
+
raw = str(path)
|
| 1024 |
+
norm = _normalize_path(raw)
|
| 1025 |
+
candidates.append(raw)
|
| 1026 |
+
candidates.append(norm)
|
| 1027 |
+
for key in ("samples/", "sweeps/"):
|
| 1028 |
+
if key in norm:
|
| 1029 |
+
candidates.append(norm[norm.index(key):])
|
| 1030 |
+
for cand in candidates:
|
| 1031 |
+
rec = by_name.get(cand, None)
|
| 1032 |
+
if rec is not None:
|
| 1033 |
+
return rec
|
| 1034 |
+
return None
|
| 1035 |
+
|
| 1036 |
+
full_path = os.path.normpath(os.path.join(self.image_root, img_path))
|
| 1037 |
+
if not os.path.isabs(img_path):
|
| 1038 |
+
full_path = os.path.normpath(os.path.join(self.image_root, img_path))
|
| 1039 |
+
else:
|
| 1040 |
+
full_path = img_path
|
| 1041 |
+
try:
|
| 1042 |
+
img = Image.open(full_path).convert("RGB")
|
| 1043 |
+
except Exception as e:
|
| 1044 |
+
sample_id = "unknown"
|
| 1045 |
+
task_type = "unknown"
|
| 1046 |
+
if isinstance(item, dict):
|
| 1047 |
+
sample_id = str(item.get("id", item.get("sample_id", "unknown")))
|
| 1048 |
+
try:
|
| 1049 |
+
task_type = infer_task_type(item)
|
| 1050 |
+
except Exception:
|
| 1051 |
+
task_type = str(item.get("task_type", item.get("task", "unknown")))
|
| 1052 |
+
raise RuntimeError(
|
| 1053 |
+
"failed to load image for AtlasDataset: "
|
| 1054 |
+
f"sample_id={sample_id} task_type={task_type} "
|
| 1055 |
+
f"camera_name={camera_name} image_path={img_path} "
|
| 1056 |
+
f"full_path={full_path} image_root={self.image_root}"
|
| 1057 |
+
) from e
|
| 1058 |
+
|
| 1059 |
+
K = None
|
| 1060 |
+
E = None
|
| 1061 |
+
ep = None
|
| 1062 |
+
sd = None
|
| 1063 |
+
|
| 1064 |
+
if self.calibration is not None:
|
| 1065 |
+
sd = _lookup_sample_data(img_path)
|
| 1066 |
+
if sd is not None:
|
| 1067 |
+
cs_token = sd.get("calibrated_sensor_token", None)
|
| 1068 |
+
if cs_token is None:
|
| 1069 |
+
raise RuntimeError(f"sample_data missing calibrated_sensor_token: {img_path}")
|
| 1070 |
+
cs = self.calibration["calibrated_sensor_by_token"].get(cs_token, None)
|
| 1071 |
+
if cs is None:
|
| 1072 |
+
raise RuntimeError(f"calibrated_sensor not found: {cs_token}")
|
| 1073 |
+
|
| 1074 |
+
ep_token = sd.get("ego_pose_token", None)
|
| 1075 |
+
if ep_token is None:
|
| 1076 |
+
raise RuntimeError(f"sample_data missing ego_pose_token: {img_path}")
|
| 1077 |
+
ep = self.calibration["ego_pose_by_token"].get(ep_token, None)
|
| 1078 |
+
if ep is None:
|
| 1079 |
+
raise RuntimeError(f"ego_pose not found: {ep_token}")
|
| 1080 |
+
|
| 1081 |
+
K = np.array(cs["camera_intrinsic"], dtype=np.float32)
|
| 1082 |
+
q = cs["rotation"]
|
| 1083 |
+
t = cs["translation"]
|
| 1084 |
+
w, x, y, z = q
|
| 1085 |
+
R = np.array(
|
| 1086 |
+
[
|
| 1087 |
+
[1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
|
| 1088 |
+
[2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
|
| 1089 |
+
[2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
|
| 1090 |
+
],
|
| 1091 |
+
dtype=np.float32,
|
| 1092 |
+
)
|
| 1093 |
+
E = np.eye(4, dtype=np.float32)
|
| 1094 |
+
E[:3, :3] = R
|
| 1095 |
+
E[:3, 3] = np.array(t, dtype=np.float32)
|
| 1096 |
+
|
| 1097 |
+
if ego_pose_out is None and self.calibration is not None:
|
| 1098 |
+
_item_id = str(item.get("id", "")) if isinstance(item, dict) else ""
|
| 1099 |
+
_lidar_sd = self.calibration.get("lidar_sd_by_sample_token", {}).get(_item_id)
|
| 1100 |
+
if _lidar_sd is not None:
|
| 1101 |
+
_lidar_cs = self.calibration["calibrated_sensor_by_token"].get(
|
| 1102 |
+
_lidar_sd.get("calibrated_sensor_token"), None)
|
| 1103 |
+
_lidar_ep = self.calibration["ego_pose_by_token"].get(
|
| 1104 |
+
_lidar_sd.get("ego_pose_token"), None)
|
| 1105 |
+
if _lidar_cs is not None and _lidar_ep is not None:
|
| 1106 |
+
def _q2R(q):
|
| 1107 |
+
ww, xx, yy, zz = q
|
| 1108 |
+
return np.array([
|
| 1109 |
+
[1-2*(yy*yy+zz*zz), 2*(xx*yy-zz*ww), 2*(xx*zz+yy*ww)],
|
| 1110 |
+
[2*(xx*yy+zz*ww), 1-2*(xx*xx+zz*zz), 2*(yy*zz-xx*ww)],
|
| 1111 |
+
[2*(xx*zz-yy*ww), 2*(yy*zz+xx*ww), 1-2*(xx*xx+yy*yy)],
|
| 1112 |
+
], dtype=np.float32)
|
| 1113 |
+
_l2e = np.eye(4, dtype=np.float32)
|
| 1114 |
+
_l2e[:3, :3] = _q2R(_lidar_cs["rotation"])
|
| 1115 |
+
_l2e[:3, 3] = np.array(_lidar_cs["translation"], dtype=np.float32)
|
| 1116 |
+
_e2g = np.eye(4, dtype=np.float32)
|
| 1117 |
+
_e2g[:3, :3] = _q2R(_lidar_ep["rotation"])
|
| 1118 |
+
_e2g[:3, 3] = np.array(_lidar_ep["translation"], dtype=np.float32)
|
| 1119 |
+
_lidar2global = (_e2g @ _l2e).astype(np.float32)
|
| 1120 |
+
ego_pose_out = torch.tensor(_lidar2global, dtype=torch.float32)
|
| 1121 |
+
try:
|
| 1122 |
+
ego_pose_inv_out = torch.tensor(np.linalg.inv(_lidar2global), dtype=torch.float32)
|
| 1123 |
+
except Exception:
|
| 1124 |
+
ego_pose_inv_out = None
|
| 1125 |
+
_lidar_ts = _lidar_sd.get("timestamp", None)
|
| 1126 |
+
if _lidar_ts is not None:
|
| 1127 |
+
timestamp_out = torch.tensor(float(_lidar_ts) * 1e-6, dtype=torch.float32)
|
| 1128 |
+
if ego_pose_out is None and ep is not None:
|
| 1129 |
+
q_ep = ep.get("rotation", None)
|
| 1130 |
+
t_ep = ep.get("translation", None)
|
| 1131 |
+
if q_ep is not None and t_ep is not None:
|
| 1132 |
+
w, x, y, z = q_ep
|
| 1133 |
+
R_ep = np.array(
|
| 1134 |
+
[
|
| 1135 |
+
[1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
|
| 1136 |
+
[2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
|
| 1137 |
+
[2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
|
| 1138 |
+
],
|
| 1139 |
+
dtype=np.float32,
|
| 1140 |
+
)
|
| 1141 |
+
ego_pose_m = np.eye(4, dtype=np.float32)
|
| 1142 |
+
ego_pose_m[:3, :3] = R_ep
|
| 1143 |
+
ego_pose_m[:3, 3] = np.array(t_ep, dtype=np.float32)
|
| 1144 |
+
ego_pose_out = torch.tensor(ego_pose_m, dtype=torch.float32)
|
| 1145 |
+
try:
|
| 1146 |
+
ego_pose_inv_out = torch.tensor(np.linalg.inv(ego_pose_m), dtype=torch.float32)
|
| 1147 |
+
except Exception:
|
| 1148 |
+
ego_pose_inv_out = None
|
| 1149 |
+
if timestamp_out is None and sd is not None:
|
| 1150 |
+
ts = sd.get("timestamp", None)
|
| 1151 |
+
if ts is not None:
|
| 1152 |
+
timestamp_out = torch.tensor(float(ts) * 1e-6, dtype=torch.float32)
|
| 1153 |
+
|
| 1154 |
+
if K is None or E is None:
|
| 1155 |
+
sensor = (item or {}).get("sensor", None) if isinstance(item, dict) else None
|
| 1156 |
+
if not isinstance(sensor, dict):
|
| 1157 |
+
raise RuntimeError(f"no camera params for {img_path}")
|
| 1158 |
+
if camera_name not in sensor:
|
| 1159 |
+
raise RuntimeError(f"sensor missing camera {camera_name}")
|
| 1160 |
+
cam_s = sensor[camera_name]
|
| 1161 |
+
try:
|
| 1162 |
+
K = np.array(cam_s["intrinsic"]["K"], dtype=np.float32)
|
| 1163 |
+
R = np.array(cam_s["extrinsic"]["rotation"], dtype=np.float32)
|
| 1164 |
+
t = np.array(cam_s["extrinsic"]["translation"], dtype=np.float32)
|
| 1165 |
+
E = np.eye(4, dtype=np.float32)
|
| 1166 |
+
E[:3, :3] = R
|
| 1167 |
+
E[:3, 3] = t
|
| 1168 |
+
except Exception as e:
|
| 1169 |
+
raise RuntimeError(f"failed to parse sensor for {camera_name}: {e}")
|
| 1170 |
+
|
| 1171 |
+
img_det, K_det = self._preprocess_streampetr(img.copy(), K.copy())
|
| 1172 |
+
img_map, K_map = self._preprocess_topomlp(img.copy(), K.copy())
|
| 1173 |
+
|
| 1174 |
+
images_det.append(self.image_transform(img_det))
|
| 1175 |
+
images_map.append(self.image_transform(img_map))
|
| 1176 |
+
|
| 1177 |
+
intrinsics_det_list.append(torch.tensor(K_det, dtype=torch.float32))
|
| 1178 |
+
intrinsics_map_list.append(torch.tensor(K_map, dtype=torch.float32))
|
| 1179 |
+
extrinsics_list.append(torch.tensor(E, dtype=torch.float32))
|
| 1180 |
+
|
| 1181 |
+
def _quat_wxyz_to_R(qwxyz):
|
| 1182 |
+
ww, xx, yy, zz = qwxyz
|
| 1183 |
+
return np.array(
|
| 1184 |
+
[
|
| 1185 |
+
[1 - 2 * (yy * yy + zz * zz), 2 * (xx * yy - zz * ww), 2 * (xx * zz + yy * ww)],
|
| 1186 |
+
[2 * (xx * yy + zz * ww), 1 - 2 * (xx * xx + zz * zz), 2 * (yy * zz - xx * ww)],
|
| 1187 |
+
[2 * (xx * zz - yy * ww), 2 * (yy * zz + xx * ww), 1 - 2 * (xx * xx + yy * yy)],
|
| 1188 |
+
],
|
| 1189 |
+
dtype=np.float32,
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
def _T_from_Rt(Rm, tv):
|
| 1193 |
+
T = np.eye(4, dtype=np.float32)
|
| 1194 |
+
T[:3, :3] = Rm
|
| 1195 |
+
T[:3, 3] = np.array(tv, dtype=np.float32)
|
| 1196 |
+
return T
|
| 1197 |
+
|
| 1198 |
+
def _compute_lidar2img(K_adj, E_mat, sd_rec, ep_rec):
|
| 1199 |
+
cam2ego = E_mat.astype(np.float32)
|
| 1200 |
+
ego2cam = np.linalg.inv(cam2ego)
|
| 1201 |
+
K4 = np.eye(4, dtype=np.float32)
|
| 1202 |
+
K4[:3, :3] = K_adj.astype(np.float32)
|
| 1203 |
+
|
| 1204 |
+
if sd_rec is None or ep_rec is None:
|
| 1205 |
+
return K4 @ ego2cam
|
| 1206 |
+
|
| 1207 |
+
sample_tk = sd_rec.get("sample_token", None)
|
| 1208 |
+
if sample_tk is None:
|
| 1209 |
+
return K4 @ ego2cam
|
| 1210 |
+
|
| 1211 |
+
ego2global_c = _T_from_Rt(_quat_wxyz_to_R(ep_rec["rotation"]), ep_rec["translation"])
|
| 1212 |
+
global2ego_c = np.linalg.inv(ego2global_c)
|
| 1213 |
+
|
| 1214 |
+
lidar_sd_rec = self.calibration.get("lidar_sd_by_sample_token", {}).get(str(sample_tk), None)
|
| 1215 |
+
if lidar_sd_rec is None:
|
| 1216 |
+
return K4 @ ego2cam
|
| 1217 |
+
|
| 1218 |
+
lidar_cs_rec = self.calibration["calibrated_sensor_by_token"].get(
|
| 1219 |
+
lidar_sd_rec.get("calibrated_sensor_token"), None)
|
| 1220 |
+
lidar_ep_rec = self.calibration["ego_pose_by_token"].get(
|
| 1221 |
+
lidar_sd_rec.get("ego_pose_token"), None)
|
| 1222 |
+
if lidar_cs_rec is None or lidar_ep_rec is None:
|
| 1223 |
+
return K4 @ ego2cam
|
| 1224 |
+
|
| 1225 |
+
lidar2ego = _T_from_Rt(_quat_wxyz_to_R(lidar_cs_rec["rotation"]), lidar_cs_rec["translation"])
|
| 1226 |
+
ego2global_lidar = _T_from_Rt(_quat_wxyz_to_R(lidar_ep_rec["rotation"]), lidar_ep_rec["translation"])
|
| 1227 |
+
lidar2cam = ego2cam @ global2ego_c @ ego2global_lidar @ lidar2ego
|
| 1228 |
+
return K4 @ lidar2cam
|
| 1229 |
+
|
| 1230 |
+
lidar2img_det_list.append(
|
| 1231 |
+
torch.tensor(_compute_lidar2img(K_det, E, sd, ep), dtype=torch.float32))
|
| 1232 |
+
lidar2img_map_list.append(
|
| 1233 |
+
torch.tensor(_compute_lidar2img(K_map, E, sd, ep), dtype=torch.float32))
|
| 1234 |
+
|
| 1235 |
+
# Fallback: if nuScenes calibration lookup failed (e.g. OpenLane samples),
|
| 1236 |
+
# try to recover ego_pose from item["pose"] and timestamp from item["timestamp"].
|
| 1237 |
+
if ego_pose_out is None and isinstance(item, dict):
|
| 1238 |
+
pose_data = item.get("pose", None)
|
| 1239 |
+
if isinstance(pose_data, dict):
|
| 1240 |
+
try:
|
| 1241 |
+
rot_raw = pose_data.get("rotation", None)
|
| 1242 |
+
t_p = pose_data.get("translation", None)
|
| 1243 |
+
if rot_raw is not None and t_p is not None:
|
| 1244 |
+
arr = np.array(rot_raw, dtype=np.float32)
|
| 1245 |
+
if arr.shape == (3, 3):
|
| 1246 |
+
R_p = arr
|
| 1247 |
+
elif arr.shape == (4,):
|
| 1248 |
+
w, x, y, z = arr
|
| 1249 |
+
R_p = np.array([
|
| 1250 |
+
[1-2*(y*y+z*z), 2*(x*y-z*w), 2*(x*z+y*w)],
|
| 1251 |
+
[2*(x*y+z*w), 1-2*(x*x+z*z), 2*(y*z-x*w)],
|
| 1252 |
+
[2*(x*z-y*w), 2*(y*z+x*w), 1-2*(x*x+y*y)],
|
| 1253 |
+
], dtype=np.float32)
|
| 1254 |
+
else:
|
| 1255 |
+
raise ValueError(f"Unsupported rotation shape: {arr.shape}")
|
| 1256 |
+
T_p = np.eye(4, dtype=np.float32)
|
| 1257 |
+
T_p[:3, :3] = R_p
|
| 1258 |
+
T_p[:3, 3] = np.array(t_p, dtype=np.float32)
|
| 1259 |
+
ego_pose_out = torch.tensor(T_p, dtype=torch.float32)
|
| 1260 |
+
try:
|
| 1261 |
+
ego_pose_inv_out = torch.tensor(np.linalg.inv(T_p), dtype=torch.float32)
|
| 1262 |
+
except Exception:
|
| 1263 |
+
ego_pose_inv_out = None
|
| 1264 |
+
except Exception as e:
|
| 1265 |
+
print(f"WARNING: Failed to parse item['pose']: {e}")
|
| 1266 |
+
|
| 1267 |
+
if timestamp_out is None and isinstance(item, dict):
|
| 1268 |
+
ts_raw = item.get("timestamp", None)
|
| 1269 |
+
if ts_raw is not None:
|
| 1270 |
+
try:
|
| 1271 |
+
timestamp_out = torch.tensor(float(ts_raw) * 1e-6, dtype=torch.float32)
|
| 1272 |
+
except Exception:
|
| 1273 |
+
pass
|
| 1274 |
+
|
| 1275 |
+
result = {
|
| 1276 |
+
"pixel_values_det": torch.stack(images_det, dim=0),
|
| 1277 |
+
"pixel_values_map": torch.stack(images_map, dim=0),
|
| 1278 |
+
"intrinsics_det": torch.stack(intrinsics_det_list, dim=0),
|
| 1279 |
+
"intrinsics_map": torch.stack(intrinsics_map_list, dim=0),
|
| 1280 |
+
"extrinsics": torch.stack(extrinsics_list, dim=0),
|
| 1281 |
+
"lidar2img_det": torch.stack(lidar2img_det_list, dim=0),
|
| 1282 |
+
"lidar2img_map": torch.stack(lidar2img_map_list, dim=0),
|
| 1283 |
+
"ego_pose": ego_pose_out,
|
| 1284 |
+
"ego_pose_inv": ego_pose_inv_out,
|
| 1285 |
+
"timestamp": timestamp_out,
|
| 1286 |
+
}
|
| 1287 |
+
result["pixel_values"] = result["pixel_values_det"]
|
| 1288 |
+
result["intrinsics"] = result["intrinsics_det"]
|
| 1289 |
+
result["lidar2img"] = result["lidar2img_det"]
|
| 1290 |
+
return result
|
| 1291 |
+
|
| 1292 |
+
|
| 1293 |
+
def atlas_collate_fn(
    batch: List[Dict[str, torch.Tensor]],
    pad_token_id: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """Collate Atlas dataset items into a training batch.

    Stacks the mandatory per-sample camera tensors, right-pads the token
    sequences to the longest sequence in the batch (``input_ids`` with
    ``pad_token_id``, ``attention_mask`` with 0, ``labels`` with -100), and
    passes optional keys through (lidar2img variants, ego pose, timestamp,
    velocity, precomputed features, audit counters, GT boxes) only when they
    are present and stackable.

    Args:
        batch: list of per-sample dicts produced by the Atlas dataset.
        pad_token_id: tokenizer pad id; must be supplied explicitly (use
            ``make_atlas_collate_fn`` to bind it for a DataLoader).

    Returns:
        Dict of batched tensors plus list-valued metadata fields
        (``scene_id``, ``sample_id``, ``task_type``).

    Raises:
        RuntimeError: if ``pad_token_id`` is None, or padded positions do
            not carry exactly the expected pad values.
    """
    # Mandatory per-sample tensors: stacking fails loudly if any is missing.
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    pixel_values_det = torch.stack([item['pixel_values_det'] for item in batch])
    pixel_values_map = torch.stack([item['pixel_values_map'] for item in batch])
    intrinsics = torch.stack([item['intrinsics'] for item in batch])
    intrinsics_det = torch.stack([item['intrinsics_det'] for item in batch])
    intrinsics_map = torch.stack([item['intrinsics_map'] for item in batch])
    extrinsics = torch.stack([item['extrinsics'] for item in batch])

    def _try_stack(key):
        # Stack an optional key only when every item has it and the tensors
        # are stackable; otherwise the key is simply omitted from the result.
        if all(key in item for item in batch):
            try:
                return torch.stack([item[key] for item in batch])
            except Exception:
                pass
        return None

    lidar2img = _try_stack("lidar2img")
    lidar2img_det = _try_stack("lidar2img_det")
    lidar2img_map = _try_stack("lidar2img_map")
    ego_pose = _try_stack("ego_pose")
    ego_pose_inv = _try_stack("ego_pose_inv")
    timestamp = _try_stack("timestamp")

    max_length = max(len(item['input_ids']) for item in batch)
    batch_size = len(batch)

    if pad_token_id is None:
        raise RuntimeError("atlas_collate_fn requires explicit pad_token_id")

    # Right-pad every sequence to max_length with the canonical pad values.
    input_ids = torch.full((batch_size, max_length), fill_value=pad_token_id, dtype=torch.long)
    attention_mask = torch.full((batch_size, max_length), fill_value=0, dtype=torch.long)
    labels = torch.full((batch_size, max_length), fill_value=-100, dtype=torch.long)

    for i, item in enumerate(batch):
        seq_len = len(item['input_ids'])
        input_ids[i, :seq_len] = item['input_ids']
        attention_mask[i, :seq_len] = item['attention_mask']
        labels[i, :seq_len] = item['labels']

    # Sanity check: padded positions must carry exactly the pad values,
    # otherwise a sample's own tokens leaked into the padding region.
    pad_mask = attention_mask == 0
    if pad_mask.any():
        if not torch.all(input_ids[pad_mask] == pad_token_id):
            raise RuntimeError("padding inconsistent: input_ids")
        if not torch.all(labels[pad_mask] == -100):
            raise RuntimeError("padding inconsistent: labels")

    result = {
        'pixel_values': pixel_values,
        'pixel_values_det': pixel_values_det,
        'pixel_values_map': pixel_values_map,
        'intrinsics': intrinsics,
        'intrinsics_det': intrinsics_det,
        'intrinsics_map': intrinsics_map,
        'extrinsics': extrinsics,
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'scene_id': [item.get('scene_id', 'unknown') for item in batch],
        'sample_id': [item.get('sample_id', 'unknown') for item in batch],
        'dataset_idx': torch.stack([item['dataset_idx'] for item in batch]),
        'task_type': [item.get('task_type', 'detection') for item in batch],
    }
    for k, v in [("lidar2img", lidar2img), ("lidar2img_det", lidar2img_det),
                 ("lidar2img_map", lidar2img_map), ("ego_pose", ego_pose),
                 ("ego_pose_inv", ego_pose_inv), ("timestamp", timestamp)]:
        if v is not None:
            result[k] = v

    velocity = None
    if all("velocity" in item for item in batch):
        try:
            velocity = torch.stack([item["velocity"] for item in batch])
        except Exception:
            velocity = None
    if velocity is not None:
        result["velocity"] = velocity

    # Precomputed detection tokens: items lacking them get zero placeholders
    # so the batch stays rectangular.
    # Fix vs. the previous version: an item must carry BOTH keys before it is
    # used (the old code indexed item["precomputed_det_ref"] while only
    # checking "precomputed_det", which could raise KeyError), and the zero
    # placeholders use zeros_like so dtype/device match the real tensors
    # instead of defaulting to float32 on CPU.
    proto = next(
        (item for item in batch
         if "precomputed_det" in item and "precomputed_det_ref" in item),
        None,
    )
    if proto is not None:
        dets, refs = [], []
        for item in batch:
            if "precomputed_det" in item and "precomputed_det_ref" in item:
                dets.append(item["precomputed_det"])
                refs.append(item["precomputed_det_ref"])
            else:
                dets.append(torch.zeros_like(proto["precomputed_det"]))
                refs.append(torch.zeros_like(proto["precomputed_det_ref"]))
        result["precomputed_det"] = torch.stack(dets)
        result["precomputed_det_ref"] = torch.stack(refs)

    # Map features may have heterogeneous shapes, so they stay a plain list.
    if any("precomputed_map" in item for item in batch):
        result["precomputed_map"] = [item.get("precomputed_map") for item in batch]

    audit_keys = [
        "audit_prompt_len",
        "audit_answer_len",
        "audit_labels_nonmasked_count",
        "audit_first_nonmasked_index",
        "audit_num_query_tokens_in_input_ids",
        "audit_expected_num_queries",
        "audit_truncated",
    ]
    for k in audit_keys:
        if all(k in item for item in batch):
            try:
                result[k] = torch.stack([item[k] for item in batch])
            except Exception:
                pass
    has_gt = all('gt_boxes' in item for item in batch)
    if has_gt:
        result['gt_boxes'] = [item['gt_boxes'] for item in batch]
        result['gt_labels'] = [item['gt_labels'] for item in batch]

    return result
|
| 1412 |
+
|
| 1413 |
+
|
| 1414 |
+
def make_atlas_collate_fn(pad_token_id: int):
    """Return a DataLoader-ready collate callable with *pad_token_id* bound."""
    import functools

    return functools.partial(atlas_collate_fn, pad_token_id=pad_token_id)
|
src/dataset/scene_sampler.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DDP-safe scene-sequential sampler for online temporal training.
|
| 2 |
+
|
| 3 |
+
Guarantees:
|
| 4 |
+
1. Within each scene, frames are yielded in strict timestamp order.
|
| 5 |
+
2. Scene order is shuffled per-epoch for training diversity.
|
| 6 |
+
3. All ranks yield exactly the same number of micro-steps per epoch
|
| 7 |
+
(balanced by greedy scene assignment + deterministic replay padding).
|
| 8 |
+
4. Epoch boundaries and replay-scene starts are detectable by the caller
|
| 9 |
+
via timestamp regression, so StreamPETR memory can be reset correctly.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import random
|
| 13 |
+
from typing import Dict, Iterator, List, Sequence
|
| 14 |
+
|
| 15 |
+
from torch.utils.data import Sampler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SceneSequentialSampler(Sampler[int]):
    """Distributed temporal sampler with equal-step guarantee.

    Yields dataset indices scene-by-scene so that frames within a scene stay
    in their stored (timestamp) order, while balancing whole scenes across
    DDP ranks and padding every rank to the same number of steps by
    deterministically replaying scenes already assigned to that rank.
    """

    def __init__(
        self,
        scene_groups: Dict[str, List[int]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        # scene_groups: scene_id -> ordered list of dataset indices for that
        # scene (assumed already sorted temporally by the caller — TODO confirm).
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        if not scene_groups:
            raise ValueError("scene_groups must not be empty")

        self.scene_groups = scene_groups
        # Sorted scene ids give every rank an identical starting order before
        # the (shared-seed) shuffle, which keeps ranks in lockstep.
        self.scene_ids = sorted(scene_groups.keys())
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        # Cache of the most recently built index list and its padded length.
        self._cached: List[int] = []
        self._target_len = 0

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch (changes the shuffle) and invalidate the cache."""
        self.epoch = int(epoch)
        self._cached = []
        self._target_len = 0

    def _scene_len(self, sid: str) -> int:
        # Number of frames in scene `sid`.
        return len(self.scene_groups[sid])

    def _build_indices(self) -> List[int]:
        """Build this rank's index list for the current epoch.

        All ranks run the identical scene-assignment phase (same seed, same
        order), so they agree on the per-rank scene split and on the common
        padded length without any communication.
        """
        # Epoch-seeded RNG shared by all ranks: identical scene_order everywhere.
        rng = random.Random(self.seed + self.epoch)
        scene_order = list(self.scene_ids)
        if self.shuffle_scenes:
            rng.shuffle(scene_order)

        per_rank_scenes: List[List[str]] = [[] for _ in range(self.num_replicas)]
        per_rank_counts = [0] * self.num_replicas

        # Greedy longest-first bin packing: each scene goes to the currently
        # lightest rank, keeping per-rank frame counts as even as possible.
        for sid in sorted(scene_order, key=self._scene_len, reverse=True):
            target = min(range(self.num_replicas), key=lambda r: per_rank_counts[r])
            per_rank_scenes[target].append(sid)
            per_rank_counts[target] += self._scene_len(sid)

        # Guarantee every rank owns at least one scene (possible only when
        # there are fewer scenes than replicas).
        for rid in range(self.num_replicas):
            if not per_rank_scenes[rid]:
                fallback = scene_order[rid % len(scene_order)]
                per_rank_scenes[rid].append(fallback)
                per_rank_counts[rid] += self._scene_len(fallback)

        # Every rank pads up to the heaviest rank's count, rounded up to
        # pad_to_multiple (e.g. gradient-accumulation step size).
        target_count = max(per_rank_counts)
        if target_count % self.pad_to_multiple != 0:
            target_count = (
                (target_count + self.pad_to_multiple - 1)
                // self.pad_to_multiple
            ) * self.pad_to_multiple
        self._target_len = target_count

        my_scenes = per_rank_scenes[self.rank]
        if self.shuffle_scenes:
            # Rank-local shuffle of scene order only; within-scene frame
            # order is never touched.
            rng2 = random.Random(self.seed + self.epoch + self.rank)
            rng2.shuffle(my_scenes)

        indices: List[int] = []
        for sid in my_scenes:
            indices.extend(self.scene_groups[sid])

        # Pad by replaying own scenes from the start (deterministic), then
        # truncate to the exact target. A replayed scene restarts at its
        # first frame, which the caller detects via timestamp regression.
        if len(indices) < target_count:
            replay_pool = list(my_scenes)
            cursor = 0
            while len(indices) < target_count:
                sid = replay_pool[cursor % len(replay_pool)]
                indices.extend(self.scene_groups[sid])
                cursor += 1
            indices = indices[:target_count]

        return indices

    def __iter__(self) -> Iterator[int]:
        self._cached = self._build_indices()
        return iter(self._cached)

    def __len__(self) -> int:
        # Lazily build once so len() is correct before the first __iter__.
        if self._target_len == 0:
            self._cached = self._build_indices()
        return self._target_len
|
src/eval/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (403 Bytes). View file
|
|
|
src/eval/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (451 Bytes). View file
|
|
|
src/eval/__pycache__/metrics.cpython-310.pyc
ADDED
|
Binary file (23.6 kB). View file
|
|
|
src/eval/__pycache__/metrics.cpython-38.pyc
ADDED
|
Binary file (23.9 kB). View file
|
|
|
src/eval/metrics.py
ADDED
|
@@ -0,0 +1,852 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Atlas evaluation metrics."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import List, Dict, Tuple, Optional
|
| 6 |
+
|
| 7 |
+
# scipy only affects match_lanes() / calculate_lane_detection_metrics(),
|
| 8 |
+
# which are NOT used in the main eval path (eval_atlas.py).
|
| 9 |
+
# Main eval uses: greedy matching for detection, OpenLane-V2 LaneEval.bench() for lanes.
|
| 10 |
+
# Optional dependency guard: lane matching falls back gracefully without scipy.
try:
    from scipy.optimize import linear_sum_assignment
    SCIPY_AVAILABLE = True
except ImportError:
    SCIPY_AVAILABLE = False


# Maps raw nuScenes category strings (both the short base names and the fully
# qualified "vehicle.*" / "human.*" / "movable_object.*" forms) onto the ten
# base detection classes used throughout evaluation.
NUSCENES_CLASS_MAP = {
    # Base class names
    'car': 'car',
    'truck': 'truck',
    'construction_vehicle': 'construction_vehicle',
    'bus': 'bus',
    'trailer': 'trailer',
    'barrier': 'barrier',
    'motorcycle': 'motorcycle',
    'bicycle': 'bicycle',
    'pedestrian': 'pedestrian',
    'traffic_cone': 'traffic_cone',
    # Full nuScenes category names - vehicles
    'vehicle.car': 'car',
    'vehicle.truck': 'truck',
    'vehicle.construction': 'construction_vehicle',
    'vehicle.bus.bendy': 'bus',
    'vehicle.bus.rigid': 'bus',
    'vehicle.trailer': 'trailer',
    'vehicle.motorcycle': 'motorcycle',
    'vehicle.bicycle': 'bicycle',
    # Full nuScenes category names - pedestrians (all subtypes)
    'human.pedestrian.adult': 'pedestrian',
    'human.pedestrian.child': 'pedestrian',
    'human.pedestrian.construction_worker': 'pedestrian',
    'human.pedestrian.police_officer': 'pedestrian',
    'human.pedestrian.wheelchair': 'pedestrian',
    'human.pedestrian.stroller': 'pedestrian',
    'human.pedestrian.personal_mobility': 'pedestrian',
    # Full nuScenes category names - movable objects
    'movable_object.barrier': 'barrier',
    'movable_object.trafficcone': 'traffic_cone',
    'movable_object.traffic_cone': 'traffic_cone',
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def normalize_category(category: str) -> str:
    """Map a (possibly fully-qualified) nuScenes category name to its base class.

    Tries an exact lookup first, then falls back to substring matching in
    either direction; unknown categories come back lower-cased unchanged.
    """
    name = category.lower().strip()
    try:
        return NUSCENES_CLASS_MAP[name]
    except KeyError:
        pass
    # Fuzzy fallback: first map entry that contains, or is contained in, the
    # query (dict insertion order decides ties).
    for known, base in NUSCENES_CLASS_MAP.items():
        if known in name or name in known:
            return base
    return name
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def normalize_ground_truths(ground_truths: List[Dict]) -> List[Dict]:
    """Normalize category names and ensure world_coords in ground truth list.

    Handles multiple GT formats:
    - {"translation": [x, y, z], "category_name": ...} (from regenerate_atlas_with_gt.py)
    - {"box": [x, y, z, w, l, h, yaw], "category_name": ...} (from gen_atlas_full_data.py)
    - {"world_coords": [x, y, z], "category": ...} (already normalized)
    """
    out: List[Dict] = []
    for raw in ground_truths:
        gt = dict(raw)  # shallow copy; never mutate the caller's dicts

        # Normalize the category, preferring 'category' over 'category_name',
        # and keep the original string under 'category_raw'.
        for field in ('category', 'category_name'):
            if field in gt:
                gt['category_raw'] = gt[field]
                gt['category'] = normalize_category(gt[field])
                break

        # Derive world_coords from whichever positional field is present.
        if 'world_coords' not in gt:
            if 'translation' in gt:
                gt['world_coords'] = list(gt['translation'][:3])
            elif 'box' in gt:
                gt['world_coords'] = list(gt['box'][:3])

        out.append(gt)
    return out
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def bin_to_meters(bin_val: int, bin_range: Tuple[float, float] = (-51.2, 51.2), num_bins: int = 1000) -> float:
    """Convert a discrete bin index back to a metric coordinate.

    Bin 0 maps to the low end of *bin_range* and bin ``num_bins - 1`` to the
    high end, with linear interpolation in between.
    """
    lo, hi = bin_range
    frac = bin_val / (num_bins - 1)
    return lo + frac * (hi - lo)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def meters_to_bin(meters: float, bin_range: Tuple[float, float] = (-51.2, 51.2), num_bins: int = 1000) -> int:
    """Quantize a metric coordinate into a bin index in ``[0, num_bins - 1]``.

    Values outside *bin_range* are clamped to the range before quantizing;
    rounding follows Python's round-half-to-even.
    """
    lo, hi = bin_range
    clamped = np.clip(meters, lo, hi)
    frac = (clamped - lo) / (hi - lo)
    idx = round(frac * (num_bins - 1))
    return int(np.clip(idx, 0, num_bins - 1))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _parse_lane_points(points_str: str) -> List[Dict]:
    """Decode every "[x, y, z]" integer-bin triple in *points_str*.

    Returns one dict per triple holding both the raw bin indices and the
    de-quantized metric coordinates (x/y in [-51.2, 51.2] m, z in
    [-5.0, 3.0] m).
    """
    triples = re.findall(r'\[(\d+),\s*(\d+),\s*(\d+)\]', points_str)
    decoded: List[Dict] = []
    for xs, ys, zs in triples:
        xb, yb, zb = int(xs), int(ys), int(zs)
        decoded.append({
            'bin_coords': [xb, yb, zb],
            'world_coords': [
                bin_to_meters(xb, bin_range=(-51.2, 51.2)),
                bin_to_meters(yb, bin_range=(-51.2, 51.2)),
                bin_to_meters(zb, bin_range=(-5.0, 3.0)),
            ],
        })
    return decoded
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def parse_atlas_output(text: str) -> List[Dict]:
    """
    Parse Atlas model output. Supports two canonical formats (checked in order):
    1. Paper lane: Lane: [x, y, z], [x, y, z]; [x, y, z], [x, y, z]; ...
    2. Detection: category: [x, y, z], [x, y, z]; category: [x, y, z].

    Returns a list of dicts: lane entries carry 'type'/'lane_id'/'points',
    detection entries carry 'type'/'category'/'bin_coords'/'world_coords'.
    An unparseable string yields an empty list.
    """
    results = []

    # --- 1. Paper lane format: "Lane: [pts], [pts]; [pts], [pts]; ..." ---
    paper_lane_match = re.search(r'Lane:\s*(.*)', text, re.DOTALL)
    if paper_lane_match:
        # Everything after "Lane:"; trailing period/whitespace stripped.
        content = paper_lane_match.group(1).rstrip('. \t\n')
        # Semicolons separate individual lanes; commas separate points.
        lane_strs = content.split(';')
        for lane_idx, lane_str in enumerate(lane_strs):
            lane_str = lane_str.strip()
            if not lane_str:
                continue
            lane_points = _parse_lane_points(lane_str)
            if lane_points:
                results.append({
                    'type': 'lane',
                    # lane_id is just the positional index within the answer.
                    'lane_id': str(lane_idx),
                    'points': lane_points,
                })
        # Early return: if any lane parsed, do not attempt detection parsing.
        if results:
            return results

    # --- 2. Detection grouped format ---
    # Canonical: "car: [pt1], [pt2]; truck: [pt3]."

    def _make_det(category: str, x_b: int, y_b: int, z_b: int) -> Dict:
        # Build one detection dict, decoding bins to meters (XY BEV range,
        # Z height range) and normalizing the category label.
        return {
            'type': 'detection',
            'category': normalize_category(category),
            'category_raw': category,
            'bin_coords': [x_b, y_b, z_b],
            'world_coords': [
                bin_to_meters(x_b, bin_range=(-51.2, 51.2)),
                bin_to_meters(y_b, bin_range=(-51.2, 51.2)),
                bin_to_meters(z_b, bin_range=(-5.0, 3.0)),
            ],
        }

    # One 3-tuple of integer bins, and a "category: [bins]..." group.
    point_re = re.compile(r'\[(\d+),\s*(\d+),\s*(\d+)\]')
    group_re = re.compile(r'(\S+)\s*:\s*((?:\[\d+,\s*\d+,\s*\d+\][\s,]*)+)')

    stripped = text.strip().rstrip('.')

    # Topology-style outputs are handled elsewhere; bail out explicitly.
    if stripped.startswith('lane_centerline('):
        return []

    if ';' in stripped:
        # Multiple category groups separated by semicolons.
        for seg in stripped.split(';'):
            seg = seg.strip()
            if not seg:
                continue
            gm = group_re.match(seg)
            if gm:
                for x_b, y_b, z_b in point_re.findall(gm.group(2)):
                    results.append(_make_det(gm.group(1), int(x_b), int(y_b), int(z_b)))

    if not results:
        # Fallback: a single category group covering the whole string. Only
        # accepted when the group captured every point in the text, so a
        # partially-matching prefix does not silently drop points.
        gm = group_re.match(stripped)
        if gm:
            pts_in_group = point_re.findall(gm.group(2))
            pts_in_text = point_re.findall(stripped)
            if len(pts_in_group) == len(pts_in_text):
                for x_b, y_b, z_b in pts_in_group:
                    results.append(_make_det(gm.group(1), int(x_b), int(y_b), int(z_b)))

    return results
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def calculate_distance(
    pred_coord: List[float],
    gt_coord: List[float],
    use_2d: bool = False,
) -> float:
    """Distance between a predicted and a ground-truth coordinate.

    Args:
        pred_coord: predicted [x, y, z] in meters
        gt_coord: ground-truth [x, y, z] in meters
        use_2d: when True, measure only in the XY plane (BEV distance,
            ignoring Z) — the more common matching metric for BEV 3D
            detection; otherwise use the full 3D Euclidean distance.
    """
    p = np.array(pred_coord)
    g = np.array(gt_coord)
    delta = (p[:2] - g[:2]) if use_2d else (p - g)
    return float(np.linalg.norm(delta))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def match_detections(
    predictions: List[Dict],
    ground_truths: List[Dict],
    threshold: float = 2.0,
    use_2d_distance: bool = True,
    use_hungarian: bool = False,
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
    """Match predicted and ground-truth detections, per category.

    Args:
        predictions: predicted detection dicts with 'category' and 'world_coords'
        ground_truths: ground-truth detection dicts with the same keys
        threshold: matching distance threshold in meters
        use_2d_distance: if True, use the 2D BEV (XY-plane) distance — the
            standard matching criterion for BEV detection
        use_hungarian: if True, use Hungarian optimal assignment (needs scipy);
            default False uses greedy nearest-first matching (nuScenes style)

    Returns:
        (matches, false_positives, false_negatives) where matches is a list of
        (pred_idx, gt_idx) pairs and the other two are unmatched indices.
    """
    if len(predictions) == 0:
        return [], [], list(range(len(ground_truths)))

    if len(ground_truths) == 0:
        return [], list(range(len(predictions))), []

    # Matching is performed independently within each category.
    all_categories = set(p['category'] for p in predictions) | set(g['category'] for g in ground_truths)

    matched_preds = set()
    matched_gts = set()
    matches = []

    for category in all_categories:
        cat_preds = [(i, p) for i, p in enumerate(predictions) if p['category'] == category]
        cat_gts = [(i, g) for i, g in enumerate(ground_truths) if g['category'] == category]

        if not cat_preds or not cat_gts:
            continue

        # Pairwise cost matrix; pairs at or beyond the threshold stay +inf.
        n_preds = len(cat_preds)
        n_gts = len(cat_gts)
        cost_matrix = np.full((n_preds, n_gts), float('inf'))

        for pi, (pred_idx, pred) in enumerate(cat_preds):
            for gi, (gt_idx, gt) in enumerate(cat_gts):
                dist = calculate_distance(pred['world_coords'], gt['world_coords'], use_2d=use_2d_distance)
                if dist < threshold:
                    cost_matrix[pi, gi] = dist

        if use_hungarian and SCIPY_AVAILABLE and n_preds > 0 and n_gts > 0:
            # Hungarian optimal assignment. scipy raises ValueError when the
            # cost matrix is infeasible (e.g. every entry is inf); treat that
            # as "no matches" — same policy as match_lanes in this module.
            try:
                row_ind, col_ind = linear_sum_assignment(cost_matrix)
            except ValueError:
                row_ind, col_ind = [], []
            for pi, gi in zip(row_ind, col_ind):
                if cost_matrix[pi, gi] < threshold:
                    pred_idx = cat_preds[pi][0]
                    gt_idx = cat_gts[gi][0]
                    matches.append((pred_idx, gt_idx))
                    matched_preds.add(pred_idx)
                    matched_gts.add(gt_idx)
        else:
            # Greedy matching: take candidate pairs in ascending distance
            # order, each pred/gt used at most once.
            distances = []
            for pi, (pred_idx, pred) in enumerate(cat_preds):
                for gi, (gt_idx, gt) in enumerate(cat_gts):
                    dist = cost_matrix[pi, gi]
                    if dist < threshold:
                        distances.append((dist, pred_idx, gt_idx))

            distances.sort(key=lambda x: x[0])
            for dist, pred_idx, gt_idx in distances:
                if pred_idx not in matched_preds and gt_idx not in matched_gts:
                    matches.append((pred_idx, gt_idx))
                    matched_preds.add(pred_idx)
                    matched_gts.add(gt_idx)

    false_positives = [i for i in range(len(predictions)) if i not in matched_preds]
    false_negatives = [i for i in range(len(ground_truths)) if i not in matched_gts]

    return matches, false_positives, false_negatives
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def calculate_detection_f1(
    predictions: List[Dict],
    ground_truths: List[Dict],
    threshold: float = 2.0,
) -> Dict[str, float]:
    """Precision / recall / F1 for detections at one distance threshold."""
    matched, fp_idx, fn_idx = match_detections(
        predictions, ground_truths, threshold
    )

    tp, fp, fn = len(matched), len(fp_idx), len(fn_idx)

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom > 0 else 0.0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': tp,
        'fp': fp,
        'fn': fn,
        'num_predictions': len(predictions),
        'num_ground_truths': len(ground_truths),
    }
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def denormalize_ref_points_01(
    ref_points_01: np.ndarray,
    pc_range: Tuple[float, float, float, float, float, float] = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
) -> np.ndarray:
    """Convert normalized ref points in [0, 1] back to meters.

    Args:
        ref_points_01: array-like [..., 3] in [0, 1] (clipped if outside)
        pc_range: (x_min, y_min, z_min, x_max, y_max, z_max)
    Returns:
        np.ndarray [..., 3] in meters
    """
    lo = np.array(pc_range[:3], dtype=np.float64)
    hi = np.array(pc_range[3:], dtype=np.float64)
    # Guard against a degenerate (zero-extent) range.
    span = np.clip(hi - lo, 1e-6, None)
    unit = np.clip(np.asarray(ref_points_01, dtype=np.float64), 0.0, 1.0)
    return lo + unit * span
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def snap_detections_to_ref_points(
    predictions: List[Dict],
    ref_points_01: np.ndarray,
    pc_range: Tuple[float, float, float, float, float, float] = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
    keep_z: bool = True,
) -> List[Dict]:
    """Snap predicted detection centers onto the nearest reference points (BEV XY).

    Post-processing that constrains predictions to the StreamPETR proposal set
    (ref points), which reduces the sensitivity of the tight 0.5m/1m thresholds
    to free-form numeric drift in generated coordinates.

    Args:
        predictions: detection dicts with 'world_coords' in meters
        ref_points_01: [Q,3] or [B,Q,3] normalized ref points in [0,1]
        pc_range: point cloud range used for denormalization
        keep_z: if True, each prediction keeps its own z; else the ref z is used
    Returns:
        New list of shallow-copied prediction dicts with snapped 'world_coords'
    """
    if not predictions:
        return []

    ref = np.asarray(ref_points_01, dtype=np.float64)
    if ref.ndim == 3:
        # Batched input: use the first sample.
        ref = ref[0]
    if ref.ndim != 2 or ref.shape[1] != 3 or ref.shape[0] == 0:
        return list(predictions)

    ref_meters = denormalize_ref_points_01(ref, pc_range=pc_range)
    ref_xy = ref_meters[:, :2]

    pred_xy = np.array(
        [det.get("world_coords", [0.0, 0.0, 0.0])[:2] for det in predictions],
        dtype=np.float64,
    )
    if pred_xy.ndim != 2 or pred_xy.shape[0] == 0:
        return list(predictions)

    # Nearest reference point per prediction, by squared BEV distance.
    sq_dists = ((pred_xy[:, None, :] - ref_xy[None, :, :]) ** 2).sum(-1)
    nearest = sq_dists.argmin(axis=1)

    snapped = []
    for idx, det in enumerate(predictions):
        updated = dict(det)
        original = list(updated.get("world_coords", [0.0, 0.0, 0.0]))
        target = ref_meters[int(nearest[idx])].tolist()
        if keep_z and len(original) >= 3:
            target[2] = float(original[2])
        updated["world_coords"] = [float(target[0]), float(target[1]), float(target[2])]
        snapped.append(updated)
    return snapped
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def calculate_per_class_metrics(
    predictions: List[Dict],
    ground_truths: List[Dict],
    threshold: float = 2.0,
) -> Dict[str, Dict[str, float]]:
    """Detection P/R/F1 computed independently for every category present."""
    categories = {p['category'] for p in predictions} | {g['category'] for g in ground_truths}

    per_class = {}
    for cat in categories:
        preds_for_cat = [p for p in predictions if p['category'] == cat]
        gts_for_cat = [g for g in ground_truths if g['category'] == cat]
        per_class[cat] = calculate_detection_f1(preds_for_cat, gts_for_cat, threshold)

    return per_class
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def parse_planning_output(text: str, require_full_vap: bool = False) -> Optional[Dict]:
    """Parse a planning answer into velocity/acceleration bins and waypoints.

    Returns None when no waypoints are found, or — with require_full_vap —
    when the speed or acceleration field is missing (Figure 5-style chained
    speed + acceleration + waypoint protocol).
    """
    # The waypoint pattern tolerates the model's "requeset" misspelling.
    vel_pattern = r'ego car speed value:\s*\[(\d+),\s*(\d+)\]\.?'
    acc_pattern = r'ego car acceleration value:\s*\[(\d+),\s*(\d+)\]\.?'
    wp_pattern = (
        r'(?:based on the ego car speed and acceleration you predicted,\s*)?'
        r'(?:requeset|request)\s+the ego car planning waypoint(?:s)? in 3-seconds:\s*'
        r'((?:\[\d+,\s*\d+\](?:,\s*)?)+)\.?'
    )

    parsed = {}

    m = re.search(vel_pattern, text, flags=re.IGNORECASE)
    if m:
        parsed['velocity_bins'] = [int(m.group(1)), int(m.group(2))]

    m = re.search(acc_pattern, text, flags=re.IGNORECASE)
    if m:
        parsed['acceleration_bins'] = [int(m.group(1)), int(m.group(2))]

    m = re.search(wp_pattern, text, flags=re.IGNORECASE)
    if m:
        pairs = re.findall(r'\[(\d+),\s*(\d+)\]', m.group(1))
        parsed['waypoints'] = [
            [
                bin_to_meters(int(xb), bin_range=(-51.2, 51.2)),
                bin_to_meters(int(yb), bin_range=(-51.2, 51.2)),
            ]
            for xb, yb in pairs
        ]

    if 'waypoints' not in parsed or len(parsed['waypoints']) == 0:
        return None

    # The main evaluation path can require all three chained fields.
    if require_full_vap and (
        'velocity_bins' not in parsed or 'acceleration_bins' not in parsed
    ):
        return None
    return parsed
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def _pad_waypoints(waypoints: List[List[float]], target_n: int = 6) -> List[List[float]]:
|
| 474 |
+
"""Pad waypoint list to target_n by repeating last waypoint.
|
| 475 |
+
|
| 476 |
+
This prevents short model outputs from gaming the L2 / collision metrics.
|
| 477 |
+
"""
|
| 478 |
+
if len(waypoints) >= target_n:
|
| 479 |
+
return waypoints[:target_n]
|
| 480 |
+
if len(waypoints) == 0:
|
| 481 |
+
return [[0.0, 0.0]] * target_n
|
| 482 |
+
last = list(waypoints[-1])
|
| 483 |
+
return list(waypoints) + [list(last)] * (target_n - len(waypoints))
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def calculate_planning_l2(
    pred_waypoints: List[List[float]],
    gt_waypoints: List[List[float]],
    timestamps: List[float] = None,
) -> Dict[str, float]:
    """L2 waypoint errors at the 1s/2s/3s horizons plus their average.

    Predictions shorter than GT are padded (repeat-last) before comparison.
    Timestamps default to 0.5s steps; horizon keys are only emitted for
    timestamps within 0.01s of 1/2/3 seconds.
    """
    n = len(gt_waypoints)
    if timestamps is None:
        timestamps = [0.5 * (i + 1) for i in range(n)]

    # Pad to GT length to prevent short-output bias.
    preds = _pad_waypoints(pred_waypoints, target_n=n)

    out = {}
    per_step = []
    for i in range(n):
        err = float(np.linalg.norm(np.array(preds[i][:2]) - np.array(gt_waypoints[i][:2])))
        per_step.append(err)
        t = timestamps[i] if i < len(timestamps) else 0.5 * (i + 1)
        for horizon, key in ((1.0, 'L2_1s'), (2.0, 'L2_2s'), (3.0, 'L2_3s')):
            if abs(t - horizon) < 0.01:
                out[key] = err

    horizon_errs = [v for k, v in out.items() if k in ('L2_1s', 'L2_2s', 'L2_3s')]
    if horizon_errs:
        out['L2_avg'] = float(np.mean(horizon_errs))
    else:
        out['L2_avg'] = float(np.mean(per_step)) if per_step else 0.0

    return out
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def _box_corners_2d(cx: float, cy: float, w: float, l: float, yaw: float) -> np.ndarray:
|
| 520 |
+
"""Build oriented box corners for yaw-from-x headings.
|
| 521 |
+
|
| 522 |
+
In planning eval JSON, yaw is measured from +X (right) axis:
|
| 523 |
+
- yaw = 0 -> vehicle length points to +X
|
| 524 |
+
- yaw = +pi/2 -> vehicle length points to +Y
|
| 525 |
+
This matches the qualitative visualization helper.
|
| 526 |
+
"""
|
| 527 |
+
c = np.cos(yaw)
|
| 528 |
+
s = np.sin(yaw)
|
| 529 |
+
center = np.array([cx, cy], dtype=np.float64)
|
| 530 |
+
|
| 531 |
+
# Heading axis follows the vehicle length, with width perpendicular to it.
|
| 532 |
+
d_len = np.array([c, s], dtype=np.float64) * (l / 2.0)
|
| 533 |
+
d_wid = np.array([-s, c], dtype=np.float64) * (w / 2.0)
|
| 534 |
+
|
| 535 |
+
corners = np.stack([
|
| 536 |
+
center + d_len + d_wid,
|
| 537 |
+
center + d_len - d_wid,
|
| 538 |
+
center - d_len - d_wid,
|
| 539 |
+
center - d_len + d_wid,
|
| 540 |
+
], axis=0)
|
| 541 |
+
return corners
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def _boxes_overlap(box1_corners: np.ndarray, box2_corners: np.ndarray) -> bool:
|
| 545 |
+
for box in [box1_corners, box2_corners]:
|
| 546 |
+
for i in range(4):
|
| 547 |
+
j = (i + 1) % 4
|
| 548 |
+
edge = box[j] - box[i]
|
| 549 |
+
normal = np.array([-edge[1], edge[0]])
|
| 550 |
+
proj1 = box1_corners @ normal
|
| 551 |
+
proj2 = box2_corners @ normal
|
| 552 |
+
if proj1.max() < proj2.min() or proj2.max() < proj1.min():
|
| 553 |
+
return False
|
| 554 |
+
return True
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def _check_collision_at_waypoints(
    waypoints: List[List[float]],
    gt_boxes: List[Dict],
    ego_w: float,
    ego_l: float,
    gt_boxes_per_timestep: Optional[List[List[Dict]]] = None,
) -> List[bool]:
    """Check collision between ego at each waypoint and GT boxes.

    When *gt_boxes_per_timestep* is provided (ST-P3 aligned), each waypoint
    is checked against the boxes at the corresponding future timestep.
    Otherwise falls back to using the same static *gt_boxes* for all waypoints.
    """
    collisions = []
    for i, wp in enumerate(waypoints):
        # Estimate ego heading from the segment to the NEXT waypoint; the
        # last waypoint reuses the incoming segment. Near-zero displacement
        # (< 1e-4 in L1) falls back to yaw 0 to avoid arctan2 noise.
        if i + 1 < len(waypoints):
            dx = waypoints[i + 1][0] - wp[0]
            dy = waypoints[i + 1][1] - wp[1]
            ego_yaw = float(np.arctan2(dy, dx)) if (abs(dx) + abs(dy)) > 1e-4 else 0.0
        elif i > 0:
            dx = wp[0] - waypoints[i - 1][0]
            dy = wp[1] - waypoints[i - 1][1]
            ego_yaw = float(np.arctan2(dy, dx)) if (abs(dx) + abs(dy)) > 1e-4 else 0.0
        else:
            # Single-waypoint trajectory: no motion direction available.
            ego_yaw = 0.0
        ego_corners = _box_corners_2d(wp[0], wp[1], ego_w, ego_l, ego_yaw)

        # Select the obstacle set for this timestep (per-timestep when given,
        # static otherwise).
        boxes_at_t = gt_boxes
        if gt_boxes_per_timestep is not None and i < len(gt_boxes_per_timestep):
            boxes_at_t = gt_boxes_per_timestep[i]

        collided = False
        for box in boxes_at_t:
            # Boxes without a position cannot be tested.
            if 'world_coords' not in box:
                continue
            bx, by = box['world_coords'][0], box['world_coords'][1]
            # Missing extents default to a typical car footprint (2m x 4m).
            bw = box.get('w', 2.0)
            bl = box.get('l', 4.0)
            byaw = box.get('yaw', 0.0)
            obj_corners = _box_corners_2d(bx, by, bw, bl, byaw)
            if _boxes_overlap(ego_corners, obj_corners):
                collided = True
                break
        collisions.append(collided)
    return collisions
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def calculate_collision_rate(
    pred_waypoints: List[List[float]],
    gt_boxes: List[Dict],
    ego_w: float = 1.85,
    ego_l: float = 4.084,
    timestamps: List[float] = None,
    num_waypoints: int = 6,
    gt_waypoints: Optional[List[List[float]]] = None,
    gt_boxes_per_timestep: Optional[List[List[Dict]]] = None,
) -> Dict[str, float]:
    """Collision flags at the 1s/2s/3s horizons plus their average.

    ST-P3 aligned: when gt_waypoints is supplied, timesteps at which the GT
    trajectory itself collides are excluded (reported as non-collisions), so
    only avoidable collisions count against the prediction.
    """
    preds = _pad_waypoints(pred_waypoints, target_n=num_waypoints)
    if timestamps is None:
        timestamps = [0.5 * (i + 1) for i in range(num_waypoints)]

    gt_hits = [False] * num_waypoints
    if gt_waypoints is not None:
        gt_hits = _check_collision_at_waypoints(
            _pad_waypoints(gt_waypoints, target_n=num_waypoints),
            gt_boxes, ego_w, ego_l,
            gt_boxes_per_timestep=gt_boxes_per_timestep,
        )

    pred_hits = _check_collision_at_waypoints(
        preds, gt_boxes, ego_w, ego_l,
        gt_boxes_per_timestep=gt_boxes_per_timestep,
    )

    # Per-timestamp collision flag; GT-colliding timesteps are zeroed out.
    hits_by_time = {}
    for i in range(num_waypoints):
        t = timestamps[i] if i < len(timestamps) else 0.5 * (i + 1)
        hits_by_time[t] = False if gt_hits[i] else pred_hits[i]

    results = {}
    for horizon, key in ((1.0, 'collision_1s'), (2.0, 'collision_2s'), (3.0, 'collision_3s')):
        near = [flag for t, flag in hits_by_time.items() if abs(t - horizon) < 0.01]
        if near:
            results[key] = float(near[0])

    horizon_vals = [v for k, v in results.items() if k in ('collision_1s', 'collision_2s', 'collision_3s')]
    results['collision_avg'] = float(np.mean(horizon_vals)) if horizon_vals else 0.0

    return results
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def calculate_planning_metrics(
    predictions: List[Dict],
    ground_truths: List[Dict],
) -> Dict[str, float]:
    """Aggregate L2 and collision metrics over paired planning samples."""
    l2_acc = {'L2_1s': [], 'L2_2s': [], 'L2_3s': [], 'L2_avg': []}
    col_acc = {'collision_1s': [], 'collision_2s': [], 'collision_3s': [], 'collision_avg': []}

    for pred, gt in zip(predictions, ground_truths):
        pred_wps = pred.get('waypoints', [])
        gt_wps = gt.get('waypoints', [])

        if pred_wps and gt_wps:
            for key, val in calculate_planning_l2(pred_wps, gt_wps).items():
                if key in l2_acc:
                    l2_acc[key].append(val)

        boxes = gt.get('gt_boxes', [])
        boxes_per_ts = gt.get('gt_boxes_per_timestep', None)
        if pred_wps and (boxes or boxes_per_ts):
            col = calculate_collision_rate(
                pred_wps, boxes, gt_waypoints=gt_wps,
                gt_boxes_per_timestep=boxes_per_ts,
            )
            for key, val in col.items():
                if key in col_acc:
                    col_acc[key].append(val)

    summary = {}
    for key, vals in l2_acc.items():
        summary[key] = float(np.mean(vals)) if vals else 0.0
    for key, vals in col_acc.items():
        summary[key] = float(np.mean(vals)) if vals else 0.0

    return summary
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
# Shared symmetric bin range used for decoding velocity and acceleration
# values (presumably m/s and m/s^2 — TODO confirm against the QA generator).
VEL_ACC_RANGE = (-50.0, 50.0)


def vel_acc_bin_to_meters(bin_val: int, num_bins: int = 1000) -> float:
    """Decode a velocity/acceleration bin index using VEL_ACC_RANGE."""
    return bin_to_meters(bin_val, bin_range=VEL_ACC_RANGE, num_bins=num_bins)
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
def chamfer_distance_polyline(
    pred_pts: np.ndarray,
    gt_pts: np.ndarray,
) -> float:
    """Symmetric Chamfer distance between two point sets (polylines).

    For each direction, every point is matched to its nearest neighbour in
    the other set and the distances are averaged; the result is the mean of
    the two directional averages.

    Args:
        pred_pts: [N, D] predicted points
        gt_pts: [M, D] ground-truth points
    Returns:
        Chamfer distance; inf when either set is empty.
    """
    if len(pred_pts) == 0 or len(gt_pts) == 0:
        return float('inf')
    pred_pts = np.asarray(pred_pts, dtype=np.float64)
    gt_pts = np.asarray(gt_pts, dtype=np.float64)
    # Vectorized [N, M] pairwise distance matrix replaces the per-point
    # Python loops — same result, computed at C speed.
    pairwise = np.linalg.norm(pred_pts[:, None, :] - gt_pts[None, :, :], axis=-1)
    d_p2g = pairwise.min(axis=1).mean()
    d_g2p = pairwise.min(axis=0).mean()
    return 0.5 * float(d_p2g + d_g2p)
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def _lane_points_array(lane) -> np.ndarray:
|
| 715 |
+
pts = lane.get('points', [])
|
| 716 |
+
if not pts:
|
| 717 |
+
return np.zeros((0, 3))
|
| 718 |
+
rows = []
|
| 719 |
+
for pt in pts:
|
| 720 |
+
if isinstance(pt, dict):
|
| 721 |
+
rows.append(pt.get('world_coords', [0, 0, 0])[:3])
|
| 722 |
+
else:
|
| 723 |
+
rows.append(list(pt)[:3])
|
| 724 |
+
return np.array(rows, dtype=np.float64)
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
def match_lanes(
    pred_lanes: List[Dict],
    gt_lanes: List[Dict],
    threshold: float = 1.5,
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
    """Match predicted lanes to GT lanes by Chamfer distance.

    Returns (matches, false_positive_indices, false_negative_indices) where
    matches is a list of (pred_idx, gt_idx) pairs. Uses Hungarian assignment
    when scipy is available, otherwise a greedy nearest-first fallback.
    """
    if not pred_lanes:
        return [], [], list(range(len(gt_lanes)))
    if not gt_lanes:
        return [], list(range(len(pred_lanes))), []

    n_p = len(pred_lanes)
    n_g = len(gt_lanes)
    # Pairwise Chamfer-distance cost matrix; pairs at or beyond the
    # threshold (or with an empty point set) remain +inf.
    cost = np.full((n_p, n_g), float('inf'))

    for i, pl in enumerate(pred_lanes):
        p_pts = _lane_points_array(pl)
        if len(p_pts) == 0:
            continue
        for j, gl in enumerate(gt_lanes):
            g_pts = _lane_points_array(gl)
            if len(g_pts) == 0:
                continue
            cd = chamfer_distance_polyline(p_pts, g_pts)
            if cd < threshold:
                cost[i, j] = cd

    matches = []
    matched_p = set()
    matched_g = set()

    # Hungarian assignment only when at least one finite pair exists;
    # scipy raises ValueError on an infeasible (all-inf) matrix, which we
    # treat as "no matches".
    if SCIPY_AVAILABLE and n_p > 0 and n_g > 0 and np.isfinite(cost).any():
        try:
            row_ind, col_ind = linear_sum_assignment(cost)
        except ValueError:
            row_ind, col_ind = [], []
        for pi, gi in zip(row_ind, col_ind):
            # Assignment may pair infeasible entries; re-check the threshold.
            if cost[pi, gi] < threshold:
                matches.append((pi, gi))
                matched_p.add(pi)
                matched_g.add(gi)
    else:
        # Greedy fallback: take candidate pairs in ascending cost order,
        # each lane used at most once.
        pairs = []
        for i in range(n_p):
            for j in range(n_g):
                if cost[i, j] < threshold:
                    pairs.append((cost[i, j], i, j))
        pairs.sort()
        for _, i, j in pairs:
            if i not in matched_p and j not in matched_g:
                matches.append((i, j))
                matched_p.add(i)
                matched_g.add(j)

    fp = [i for i in range(n_p) if i not in matched_p]
    fn = [j for j in range(n_g) if j not in matched_g]
    return matches, fp, fn
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
def calculate_lane_detection_metrics(
    pred_lanes: List[Dict],
    gt_lanes: List[Dict],
    threshold: float = 1.5,
) -> Dict[str, float]:
    """Lane-level precision / recall / F1 under a Chamfer-distance threshold."""
    matched, fp_idx, fn_idx = match_lanes(pred_lanes, gt_lanes, threshold)

    tp, fp, fn = len(matched), len(fp_idx), len(fn_idx)

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom > 0 else 0.0

    return {
        'lane_precision': precision,
        'lane_recall': recall,
        'lane_f1': f1,
        'lane_tp': tp,
        'lane_fp': fp,
        'lane_fn': fn,
    }
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def calculate_multi_threshold_detection_f1(
    predictions: List[Dict],
    ground_truths: List[Dict],
    thresholds: Tuple[float, ...] = (0.5, 1.0, 2.0, 4.0),
) -> Dict[str, float]:
    """Detection P/R/F1 at several distance thresholds plus the mean F1."""
    out = {}
    f1_scores = []
    for thr in thresholds:
        metrics = calculate_detection_f1(predictions, ground_truths, threshold=thr)
        out[f'P@{thr}m'] = metrics['precision']
        out[f'R@{thr}m'] = metrics['recall']
        out[f'F1@{thr}m'] = metrics['f1']
        f1_scores.append(metrics['f1'])
    out['F1_avg'] = float(np.mean(f1_scores)) if f1_scores else 0.0
    return out
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
def evaluate_all(
    task_predictions: Dict[str, List],
    task_ground_truths: Dict[str, List],
) -> Dict[str, Dict[str, float]]:
    """Run each task-specific evaluation for which both preds and GT exist."""
    report = {}

    if 'detection' in task_predictions and 'detection' in task_ground_truths:
        report['detection'] = calculate_multi_threshold_detection_f1(
            task_predictions['detection'],
            task_ground_truths['detection'],
        )

    if 'lane' in task_predictions and 'lane' in task_ground_truths:
        # Lane metrics are computed per sample, then averaged.
        accum = {'lane_precision': [], 'lane_recall': [], 'lane_f1': []}
        for pred_set, gt_set in zip(task_predictions['lane'], task_ground_truths['lane']):
            preds = pred_set if isinstance(pred_set, list) else [pred_set]
            gts = gt_set if isinstance(gt_set, list) else [gt_set]
            sample_metrics = calculate_lane_detection_metrics(preds, gts)
            for key in accum:
                accum[key].append(sample_metrics[key])
        report['lane'] = {k: float(np.mean(v)) for k, v in accum.items() if v}

    if 'planning' in task_predictions and 'planning' in task_ground_truths:
        report['planning'] = calculate_planning_metrics(
            task_predictions['planning'],
            task_ground_truths['planning'],
        )

    return report
|
src/model/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Atlas model module."""
|
| 2 |
+
|
| 3 |
+
from .topomlp_adapter import TopoMLPToAtlasMapTokens
|
| 4 |
+
from .streampetr_adapter import extract_streampetr_topk_tokens
|
| 5 |
+
|
| 6 |
+
# Atlas (LLM) side depends on `transformers`. StreamPETR/TopoMLP pretraining does not.
|
| 7 |
+
try:
|
| 8 |
+
from .configuration_atlas import AtlasConfig
|
| 9 |
+
from .modeling_atlas import AtlasProjector, AtlasForCausalLM
|
| 10 |
+
_ATLAS_AVAILABLE = True
|
| 11 |
+
except Exception:
|
| 12 |
+
AtlasConfig = None # type: ignore[assignment]
|
| 13 |
+
AtlasProjector = None # type: ignore[assignment]
|
| 14 |
+
AtlasForCausalLM = None # type: ignore[assignment]
|
| 15 |
+
_ATLAS_AVAILABLE = False
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"TopoMLPToAtlasMapTokens",
|
| 19 |
+
"extract_streampetr_topk_tokens",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
if _ATLAS_AVAILABLE:
|
| 23 |
+
__all__ += [
|
| 24 |
+
"AtlasConfig",
|
| 25 |
+
"AtlasProjector",
|
| 26 |
+
"AtlasForCausalLM",
|
| 27 |
+
]
|
| 28 |
+
|
src/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (622 Bytes). View file
|
|
|
src/model/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (622 Bytes). View file
|
|
|
src/model/__pycache__/configuration_atlas.cpython-310.pyc
ADDED
|
Binary file (852 Bytes). View file
|
|
|
src/model/__pycache__/modeling_atlas.cpython-310.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
src/model/__pycache__/modeling_atlas.cpython-38.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
src/model/__pycache__/streampetr_adapter.cpython-310.pyc
ADDED
|
Binary file (4.25 kB). View file
|
|
|
src/model/__pycache__/streampetr_adapter.cpython-38.pyc
ADDED
|
Binary file (4.22 kB). View file
|
|
|
src/model/__pycache__/topomlp_adapter.cpython-310.pyc
ADDED
|
Binary file (3.91 kB). View file
|
|
|
src/model/__pycache__/topomlp_adapter.cpython-38.pyc
ADDED
|
Binary file (3.81 kB). View file
|
|
|
src/model/modeling_atlas.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Atlas model (LLM + visual token injection)."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from typing import Optional, Tuple, List, Union, Dict
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoModelForCausalLM,
|
| 9 |
+
AutoConfig,
|
| 10 |
+
BitsAndBytesConfig,
|
| 11 |
+
)
|
| 12 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from peft import (
|
| 16 |
+
LoraConfig,
|
| 17 |
+
get_peft_model,
|
| 18 |
+
prepare_model_for_kbit_training,
|
| 19 |
+
TaskType,
|
| 20 |
+
)
|
| 21 |
+
PEFT_AVAILABLE = True
|
| 22 |
+
except ImportError:
|
| 23 |
+
PEFT_AVAILABLE = False
|
| 24 |
+
from src.audit.audit_utils import audit_enabled, audit_check
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_quantization_config(
    load_in_4bit: bool = True,
    bnb_4bit_compute_dtype: torch.dtype = torch.bfloat16,
    bnb_4bit_quant_type: str = "nf4",
    bnb_4bit_use_double_quant: bool = True,
) -> Optional[BitsAndBytesConfig]:
    """Build a BitsAndBytes config for 4-bit (NF4) weight loading.

    Returns None when ``load_in_4bit`` is False, signalling full-precision
    loading to ``from_pretrained``.
    """
    if load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
            bnb_4bit_quant_type=bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
        )
    return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ReferencePointProjector(nn.Module):
    """Embed 3-D reference points [B, Q, 3] as [B, Q, D] vectors.

    The linear layer is initialized to exactly zero (weight and bias), per the
    paper, so reference-point information is blended in gradually by training
    and contributes nothing at step zero.
    """

    def __init__(self, visual_hidden_size: int = 256):
        super().__init__()
        self.visual_hidden_size = visual_hidden_size
        linear = nn.Linear(3, visual_hidden_size)
        nn.init.zeros_(linear.weight)
        if linear.bias is not None:
            nn.init.zeros_(linear.bias)
        self.projector_rp = linear

    def forward(self, ref_points: torch.Tensor) -> torch.Tensor:
        # Cast to the layer's dtype so fp32 ref points work under bf16/fp16.
        target_dtype = self.projector_rp.weight.dtype
        return self.projector_rp(ref_points.to(target_dtype))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class AtlasProjector(nn.Module):
    """Map visual features [B, Q, Dv] into the LLM hidden space [B, Q, H].

    A single linear layer: Xavier-uniform weight, zero bias.
    """

    def __init__(
        self,
        visual_hidden_size: int = 256,
        llm_hidden_size: int = 4096,
    ):
        super().__init__()
        self.visual_hidden_size = visual_hidden_size
        self.llm_hidden_size = llm_hidden_size
        layer = nn.Linear(visual_hidden_size, llm_hidden_size)
        nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
        self.projector = layer

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        # Cast features to the layer dtype before projecting (mixed precision).
        target_dtype = self.projector.weight.dtype
        return self.projector(visual_features.to(target_dtype))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class AtlasUnifiedProjector(nn.Module):
    """Bundle of detection/map projectors plus a shared ref-point embedder.

    Detection and map tokens each get their own linear projector; both share
    one zero-initialized reference-point embedder whose output is added to
    the features *before* projection.
    """

    def __init__(self, visual_hidden_size: int, llm_hidden_size: int):
        super().__init__()
        self.projector_det = AtlasProjector(visual_hidden_size, llm_hidden_size)
        self.projector_map = AtlasProjector(visual_hidden_size, llm_hidden_size)
        self.projector_rp = ReferencePointProjector(visual_hidden_size)

    def forward(
        self,
        detection_features: torch.Tensor,
        map_features: Optional[torch.Tensor] = None,
        detection_ref_points: Optional[torch.Tensor] = None,
        map_ref_points: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Project detection (and optionally map) tokens into LLM space.

        Returns a dict with key "detection" and, when map features are
        supplied, key "map".
        """
        det_in = detection_features
        if detection_ref_points is not None:
            # Additive ref-point embedding; a no-op at init (zero-init layer).
            det_in = det_in + self.projector_rp(detection_ref_points)

        result: Dict[str, torch.Tensor] = {"detection": self.projector_det(det_in)}

        if map_features is not None:
            map_in = map_features
            if map_ref_points is not None:
                map_in = map_in + self.projector_rp(map_ref_points)
            result["map"] = self.projector_map(map_in)

        return result
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class AtlasForCausalLM(nn.Module):
    """Atlas model: a causal LLM with perception-token injection.

    Detection (and optional map) query features are projected into the LLM
    hidden space and written into the embedding positions of a reserved
    <query> placeholder token before the LLM forward/generate pass.
    """

    def __init__(
        self,
        llm_model_name: str = "lmsys/vicuna-7b-v1.5",
        visual_hidden_size: int = 256,
        num_queries: int = 256,
        num_map_queries: int = 256,
        load_in_4bit: bool = False,
        use_flash_attention: bool = True,
        device_map: Optional[str] = None,
        torch_dtype: torch.dtype = torch.bfloat16,
        use_lora: bool = False,
        lora_r: int = 64,
        lora_alpha: int = 64,
        lora_dropout: float = 0.1,
        lora_target_modules: Optional[List[str]] = None,
    ):
        """Load the base LLM (optionally 4-bit + LoRA) and build the projector.

        Raises:
            ValueError: load_in_4bit without use_lora (4-bit base weights are
                frozen, so nothing would be trainable).
            ImportError: use_lora without `peft`, or use_flash_attention
                without `flash_attn` installed.
        """
        super().__init__()

        self.llm_model_name = llm_model_name
        self.visual_hidden_size = visual_hidden_size
        self.num_queries = num_queries
        self.num_map_queries = num_map_queries
        # Filled in later via set_query_token_id() once the tokenizer exists.
        self.query_token_id = None

        if load_in_4bit and not use_lora:
            raise ValueError(
                "load_in_4bit=True requires use_lora=True (4-bit weights are not trainable)."
            )

        if use_lora and not PEFT_AVAILABLE:
            raise ImportError("LoRA requires peft library")
        # Config is fetched first so the projector can size itself to the
        # LLM hidden dimension before weights finish loading.
        self.llm_config = AutoConfig.from_pretrained(llm_model_name)
        self.llm_hidden_size = self.llm_config.hidden_size

        quantization_config = None
        if load_in_4bit:
            quantization_config = get_quantization_config(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch_dtype,
            )

        attn_implementation = None
        if use_flash_attention:
            # Fail loudly rather than silently falling back to eager attention.
            try:
                import flash_attn  # noqa: F401
                attn_implementation = "flash_attention_2"
            except ImportError:
                raise ImportError(
                    "use_flash_attention=True but flash_attn is not installed. "
                    "Install with: pip install flash-attn --no-build-isolation"
                )

        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            quantization_config=quantization_config,
            device_map=device_map,
            torch_dtype=torch_dtype,
            attn_implementation=attn_implementation,
            trust_remote_code=True,
        )

        if use_lora:
            if load_in_4bit:
                # Casts norms/embeddings and enables input grads for k-bit training.
                self.llm = prepare_model_for_kbit_training(self.llm)
            if lora_target_modules is None:
                lora_target_modules = [
                    "q_proj", "k_proj", "v_proj", "o_proj",  # Attention
                    "gate_proj", "up_proj", "down_proj",  # MLP
                ]

            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=lora_target_modules,
                bias="none",
                task_type=TaskType.CAUSAL_LM,
            )

            self.llm = get_peft_model(self.llm, lora_config)

        self.projector = AtlasUnifiedProjector(
            visual_hidden_size=visual_hidden_size,
            llm_hidden_size=self.llm_hidden_size,
        )

        # Co-locate the projector with the LLM; StopIteration means the LLM
        # has no materialized parameters yet (e.g. fully offloaded).
        try:
            llm_device = next(self.llm.parameters()).device
            self.projector = self.projector.to(llm_device)
        except StopIteration:
            pass

    def set_query_token_id(self, query_token_id: int):
        """Register the tokenizer id of the reserved <query> placeholder."""
        self.query_token_id = query_token_id

    def get_input_embeddings(self):
        """Expose the wrapped LLM's input embedding module."""
        return self.llm.get_input_embeddings()

    def resize_token_embeddings(self, new_num_tokens: int):
        """Resize LLM vocab embeddings (after adding special tokens)."""
        self.llm.resize_token_embeddings(new_num_tokens)

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing on the LLM when supported."""
        if hasattr(self.llm, 'gradient_checkpointing_enable'):
            self.llm.gradient_checkpointing_enable()

    # Parameter-name substrings excluded from weight decay (biases / norms).
    _NO_DECAY_KEYWORDS = {"bias", "LayerNorm.weight", "layernorm.weight",
                          "layer_norm.weight", "norm.weight"}

    def get_trainable_param_groups(self, lr: float, weight_decay: float = 0.0) -> List[Dict]:
        """Build optimizer param groups over projector + LLM trainables.

        Biases and norm weights go into a zero-weight-decay group.

        Raises:
            RuntimeError: if the projector has no trainable parameters
                (a misconfiguration that would silently freeze it).
        """
        decay, no_decay = [], []
        for module in [self.projector, self.llm]:
            for name, param in module.named_parameters():
                if not param.requires_grad:
                    continue
                if any(nd in name for nd in self._NO_DECAY_KEYWORDS):
                    no_decay.append(param)
                else:
                    decay.append(param)

        if not any(p.requires_grad for p in self.projector.parameters()):
            raise RuntimeError("projector has no trainable parameters (requires_grad=False).")

        groups: List[Dict] = []
        if decay:
            groups.append({"params": decay, "lr": lr, "weight_decay": weight_decay})
        if no_decay:
            groups.append({"params": no_decay, "lr": lr, "weight_decay": 0.0})
        return groups

    def get_expected_trainable_param_ids(self, lr: float) -> set:
        """Convenience helper for 'optimizer coverage' hard checks."""
        param_ids: set = set()
        for g in self.get_trainable_param_groups(lr):
            for p in g["params"]:
                param_ids.add(id(p))
        return param_ids

    def parameters(self, recurse: bool = True):
        """Yield projector parameters first, then LLM parameters.

        NOTE(review): overrides nn.Module.parameters; ordering differs from
        the default traversal — confirm no caller relies on default order.
        """
        for param in self.projector.parameters(recurse):
            yield param
        for param in self.llm.parameters(recurse):
            yield param

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        visual_features: Optional[Union[torch.Tensor, Dict]] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        query_token_id: Optional[int] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Training/eval forward: inject visual tokens, then run the LLM.

        `labels` (with -100 masking) are passed through so the LLM computes
        the LM loss itself.
        """
        inputs_embeds, attention_mask, position_ids = self._prepare_llm_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_features=visual_features,
            labels=labels,
            query_token_id=query_token_id,
        )
        # KV cache is useless during training and wastes memory.
        if self.training:
            use_cache = False

        # Forward through LLM
        outputs = self.llm(
            input_ids=None,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return outputs

    def _prepare_llm_inputs(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        visual_features: Optional[Union[torch.Tensor, Dict]] = None,
        labels: Optional[torch.LongTensor] = None,
        query_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """
        Shared path for forward/generate:
        - build inputs_embeds
        - inject visual tokens into <query> positions
        - ensure attention_mask / position_ids are consistent
        """
        # Clone so the in-place per-position writes below don't alias the
        # embedding table output.
        inputs_embeds = self.llm.get_input_embeddings()(input_ids).clone()

        _attention_mask_before_inject = None
        if attention_mask is not None:
            attention_mask = attention_mask.clone()
            # Snapshot kept for the A5/B2 audit: was any <query> slot masked out?
            _attention_mask_before_inject = attention_mask.clone()

        assert inputs_embeds.shape[:2] == input_ids.shape[:2], (
            f"shape mismatch: inputs_embeds={inputs_embeds.shape[:2]} != input_ids={input_ids.shape[:2]}"
        )
        if attention_mask is not None:
            assert attention_mask.shape[:2] == input_ids.shape[:2], (
                f"shape mismatch: attention_mask={attention_mask.shape[:2]} != input_ids={input_ids.shape[:2]}"
            )

        # Explicit argument overrides the id registered via set_query_token_id().
        qid = query_token_id if query_token_id is not None else self.query_token_id

        if visual_features is not None and qid is not None:
            batch_size = input_ids.shape[0]
            device = inputs_embeds.device
            dtype = inputs_embeds.dtype

            if isinstance(visual_features, dict):
                # Dict form: detection features (+ optional map features and
                # ref points) go through the unified projector.
                detection_features = visual_features["detection"]
                map_features = visual_features.get("map", None)
                det_ref_points = visual_features.get("detection_ref_points", None)
                map_ref_points = visual_features.get("map_ref_points", None)

                projected = self.projector(
                    detection_features=detection_features,
                    map_features=map_features,
                    detection_ref_points=det_ref_points,
                    map_ref_points=map_ref_points,
                )
                detection_embeds = projected["detection"].to(dtype=dtype, device=device)
                map_embeds = projected.get("map", None)
                if map_embeds is not None:
                    map_embeds = map_embeds.to(dtype=dtype, device=device)
                num_det = detection_embeds.shape[1]
                num_map = map_embeds.shape[1] if map_embeds is not None else 0
            else:
                # Bare-tensor form: detection tokens only, no ref points.
                detection_embeds = self.projector.projector_det(visual_features).to(dtype=dtype, device=device)
                map_embeds = None
                num_det = detection_embeds.shape[1]
                num_map = 0

            for b in range(batch_size):
                query_positions = torch.where(input_ids[b] == qid)[0]
                num_query_slots = int(query_positions.numel())
                if num_query_slots == 0:
                    # Text-only sample in a mixed batch: nothing to inject.
                    continue

                if labels is not None:
                    # Audit B4: all <query> slots must precede the first
                    # supervised (answer) token.
                    ans_pos = (labels[b] != -100).nonzero(as_tuple=True)[0]
                    if ans_pos.numel() > 0 and audit_enabled():
                        first_answer_token_pos = int(ans_pos[0].item())
                        q_min = int(query_positions.min().item())
                        q_max = int(query_positions.max().item())
                        audit_check(
                            "B4",
                            q_max < first_answer_token_pos,
                            once=True,
                            min_query_pos=q_min,
                            max_query_pos=q_max,
                            first_answer_token_pos=first_answer_token_pos,
                        )

                # Slot-count dispatch: detection tokens fill the leading slots,
                # map tokens (if any) fill the trailing ones.
                if map_embeds is not None and num_query_slots == (num_det + num_map):
                    for i in range(num_det):
                        inputs_embeds[b, int(query_positions[i].item())] = detection_embeds[b, i]
                    for i in range(num_map):
                        inputs_embeds[b, int(query_positions[int(num_det) + i].item())] = map_embeds[b, i]
                    num_injected = int(num_det + num_map)
                    inj_det = int(num_det)
                    inj_map = int(num_map)
                elif num_query_slots == num_det:
                    for i in range(num_det):
                        inputs_embeds[b, int(query_positions[i].item())] = detection_embeds[b, i]
                    num_injected = int(num_det)
                    inj_det = int(num_det)
                    inj_map = 0
                else:
                    raise ValueError(
                        f"<query> slot mismatch: slots={num_query_slots}, "
                        f"det={num_det}, map={num_map if map_embeds is not None else 0}. "
                        f"Ensure visual_features provides the correct number of tokens "
                        f"matching the prompt's <query> placeholders."
                    )

                # Injected positions must be attended to.
                if attention_mask is not None and num_injected > 0:
                    attention_mask[b, query_positions[:num_injected]] = 1

                if audit_enabled():
                    # One-time (budgeted) diagnostics on batch item 0 only.
                    if not hasattr(self, "_audit_forward_calls"):
                        self._audit_forward_calls = 0
                    max_calls = int(os.getenv("ATLAS_AUDIT_MAX_FWD", "1"))
                    if self._audit_forward_calls < max_calls and b == 0:
                        # A3: injected-token accounting.
                        total_injected = int(num_injected)
                        print(
                            "[ATLAS_AUDIT][A3] "
                            f"num_query_tokens_in_input_ids={num_query_slots} "
                            f"num_detection_tokens_injected={inj_det} "
                            f"num_map_tokens_injected={inj_map} "
                            f"total_injected={total_injected}"
                        )
                        # B1: sequence-length consistency across ids/embeds/mask.
                        seq_len_input_ids = int(input_ids.shape[1])
                        seq_len_embeds = int(inputs_embeds.shape[1])
                        seq_len_mask = int(attention_mask.shape[1]) if attention_mask is not None else -1
                        print(f"[ATLAS_AUDIT][B1] seq_len_input_ids={seq_len_input_ids} seq_len_embeds={seq_len_embeds} seq_len_mask={seq_len_mask}")
                        # A4: first slot positions annotated with their source.
                        n_show = min(20, num_query_slots)
                        pos_first = query_positions[:n_show].tolist()
                        src_first = []
                        for i in range(n_show):
                            src_first.append("DET" if i < int(num_det) else "MAP")
                        packed = ",".join([f"{s}@{p}" for s, p in zip(src_first, pos_first)])
                        print(f"[ATLAS_AUDIT][A4] first20={packed}")
                        # A5/B2: the pre-injection mask must already cover slots.
                        if attention_mask is not None and n_show > 0 and _attention_mask_before_inject is not None:
                            q_mask = _attention_mask_before_inject[b, query_positions[:num_injected]]
                            bad = int((q_mask == 0).sum().item())
                            q_min2 = int(q_mask.min().item()) if q_mask.numel() else -1
                            q_max2 = int(q_mask.max().item()) if q_mask.numel() else -1
                            print(f"[ATLAS_AUDIT][A5/B2] query_mask_bad_count={bad} query_mask_min={q_min2} query_mask_max={q_max2}")
                            assert bad == 0, "<query> positions have attention_mask==0"
                        # C1: spot-check that writes landed exactly.
                        n = min(8, num_query_slots, num_injected)
                        if n > 0:
                            diffs = []
                            for i in range(n):
                                pos = int(query_positions[i].item())
                                if i < int(num_det):
                                    ref = detection_embeds[b, i]
                                else:
                                    j = i - int(num_det)
                                    ref = map_embeds[b, j] if map_embeds is not None else detection_embeds[b, min(i, int(num_det) - 1)]
                                diffs.append(float((inputs_embeds[b, pos] - ref).abs().max().item()))
                            max_diff = max(diffs) if diffs else 0.0
                            print(f"[ATLAS_AUDIT][C1] sampled_max_diff={max_diff:.3e} (n={n})")
                            assert max_diff < 1e-5, f"injected embed diff too large: {max_diff}"
                        # C3 audit (post-inject)
                        if attention_mask is not None:
                            text_pos2 = (attention_mask[b] == 1) & (input_ids[b] != qid)
                        else:
                            text_pos2 = (input_ids[b] != qid)
                        text_vec2 = inputs_embeds[b, text_pos2].float()
                        vis_vec2 = inputs_embeds[b, query_positions[:num_injected]].float()
                        if text_vec2.numel() > 0 and vis_vec2.numel() > 0:
                            # Visual/text embedding norm ratio should be sane;
                            # bounds tunable via env vars.
                            text_norm = text_vec2.norm(dim=-1)
                            vis_norm = vis_vec2.norm(dim=-1)
                            t_mean = float(text_norm.mean().item())
                            t_std = float(text_norm.std(unbiased=False).item())
                            v_mean = float(vis_norm.mean().item())
                            v_std = float(vis_norm.std(unbiased=False).item())
                            ratio = v_mean / (t_mean + 1e-8)
                            rmin = float(os.getenv("ATLAS_VIS_TEXT_NORM_RATIO_MIN", "0.1"))
                            rmax = float(os.getenv("ATLAS_VIS_TEXT_NORM_RATIO_MAX", "10.0"))
                            audit_check(
                                "C3",
                                (ratio >= rmin and ratio <= rmax),
                                once=True,
                                ratio=ratio,
                                ratio_min=rmin,
                                ratio_max=rmax,
                                text_norm_mean=t_mean,
                                text_norm_std=t_std,
                                vis_norm_mean=v_mean,
                                vis_norm_std=v_std,
                                dtype=str(inputs_embeds.dtype),
                            )
                        # NOTE(review): counter increments inside the budget
                        # guard, so only audited forwards consume the budget.
                        self._audit_forward_calls += 1

        batch_size, seq_length = inputs_embeds.shape[:2]
        device = inputs_embeds.device
        if attention_mask is not None:
            # Heuristic left-padding detection: some row starts masked while
            # some row ends unmasked. Left-padded batches need cumsum-based
            # position ids; right-padded ones can use plain arange.
            has_token = (attention_mask.sum(dim=1) > 0)
            left_padded = bool(((attention_mask[:, 0] == 0) & has_token).any().item() and ((attention_mask[:, -1] == 1) & has_token).any().item())
            if left_padded:
                position_ids = attention_mask.long().cumsum(dim=1) - 1
                position_ids.masked_fill_(attention_mask == 0, 0)
            else:
                position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
        else:
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)

        return inputs_embeds, attention_mask, position_ids

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        visual_features: Optional[Union[torch.Tensor, Dict]] = None,
        max_new_tokens: int = 256,
        query_token_id: Optional[int] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Greedy/standard generation using the SAME projector+injection+pos/mask path as forward.
        """
        inputs_embeds, attention_mask, position_ids = self._prepare_llm_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            visual_features=visual_features,
            labels=None,
            query_token_id=query_token_id,
        )
        gen_kwargs = dict(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
        )
        # Default pad_token_id from the LLM config unless the caller overrides.
        if "pad_token_id" not in kwargs and hasattr(self.llm.config, "pad_token_id"):
            gen_kwargs["pad_token_id"] = self.llm.config.pad_token_id
        gen_kwargs.update(kwargs)
        return self.llm.generate(**gen_kwargs)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def load_atlas_model(
    llm_model_name: str = "lmsys/vicuna-7b-v1.5",
    visual_hidden_size: int = 256,
    num_queries: int = 256,
    num_map_queries: int = 256,
    load_in_4bit: bool = False,
    use_flash_attention: bool = True,
    use_lora: bool = False,
    lora_r: int = 64,
    lora_alpha: int = 64,
    lora_dropout: float = 0.1,
    lora_target_modules: Optional[List[str]] = None,
) -> AtlasForCausalLM:
    """Factory wrapper around :class:`AtlasForCausalLM`.

    Arguments mirror ``AtlasForCausalLM.__init__``; see that constructor for
    full semantics. ``device_map`` and ``torch_dtype`` intentionally keep the
    constructor defaults here.

    Args:
        llm_model_name: HF model id or path of the base causal LM.
        visual_hidden_size: dimensionality of incoming visual tokens.
        num_queries: number of detection query tokens.
        num_map_queries: number of map query tokens.
        load_in_4bit: load base weights in 4-bit (requires use_lora=True).
        use_flash_attention: require flash_attention_2 (raises if missing).
        use_lora: wrap the LLM with LoRA adapters.
        lora_r / lora_alpha / lora_dropout: LoRA hyperparameters.
        lora_target_modules: module names to adapt; None selects the default
            attention + MLP projections. (Annotation fixed to Optional:
            None is the documented default, not a List[str].)

    Returns:
        A freshly constructed AtlasForCausalLM.
    """
    return AtlasForCausalLM(
        llm_model_name=llm_model_name,
        visual_hidden_size=visual_hidden_size,
        num_queries=num_queries,
        num_map_queries=num_map_queries,
        load_in_4bit=load_in_4bit,
        use_flash_attention=use_flash_attention,
        use_lora=use_lora,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        lora_target_modules=lora_target_modules,
    )
|
src/model/streampetr_adapter.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
StreamPETR -> Atlas detection token adapter WITHOUT modifying StreamPETR source code.
|
| 3 |
+
|
| 4 |
+
Rationale:
|
| 5 |
+
- StreamPETRHead updates internal memory with Top-K proposals (topk_proposals, typically 256).
|
| 6 |
+
- After a forward pass, the head stores:
|
| 7 |
+
- self.memory_embedding: [B, memory_len + topk_proposals, D] (prepend new Top-K)
|
| 8 |
+
- self.memory_reference_point: [B, memory_len + topk_proposals, 3]
|
| 9 |
+
(exact ordering: new Top-K are concatenated in front)
|
| 10 |
+
|
| 11 |
+
IMPORTANT: after post_update_memory, memory_reference_point is in the **GLOBAL**
|
| 12 |
+
coordinate frame (ego_pose applied). We must invert the ego_pose to bring
|
| 13 |
+
ref_points back to ego frame before normalizing with pc_range.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from typing import Any, Dict, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
PC_RANGE = (-51.2, -51.2, -5.0, 51.2, 51.2, 3.0)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _normalize_ref_points(
|
| 26 |
+
ref: torch.Tensor,
|
| 27 |
+
pc_range: Tuple[float, float, float, float, float, float] = PC_RANGE,
|
| 28 |
+
) -> torch.Tensor:
|
| 29 |
+
pc_min = ref.new_tensor(pc_range[:3])
|
| 30 |
+
pc_max = ref.new_tensor(pc_range[3:])
|
| 31 |
+
denom = (pc_max - pc_min).clamp(min=1e-6)
|
| 32 |
+
return ((ref - pc_min) / denom).clamp(0.0, 1.0)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _global_to_ego(ref: torch.Tensor, ego_pose: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
"""Transform reference points from global frame back to ego frame.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
ref: [B, N, 3] in global coordinates
|
| 40 |
+
ego_pose: [B, 4, 4] ego-to-global transform
|
| 41 |
+
Returns:
|
| 42 |
+
[B, N, 3] in ego coordinates
|
| 43 |
+
"""
|
| 44 |
+
B, N, _ = ref.shape
|
| 45 |
+
ones = torch.ones(B, N, 1, device=ref.device, dtype=ref.dtype)
|
| 46 |
+
ref_homo = torch.cat([ref, ones], dim=-1) # [B, N, 4]
|
| 47 |
+
ego_pose_inv = torch.inverse(ego_pose) # [B, 4, 4]
|
| 48 |
+
ref_ego = (ego_pose_inv.unsqueeze(1) @ ref_homo.unsqueeze(-1)).squeeze(-1)[..., :3]
|
| 49 |
+
return ref_ego
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _nuscenes_ego_to_paper(ref: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
"""Convert nuScenes ego coords to Atlas paper frame.
|
| 54 |
+
|
| 55 |
+
nuScenes ego uses x=forward, y=left. Atlas detection QA uses
|
| 56 |
+
x=right, y=forward, so (x_p, y_p) = (-y_n, x_n).
|
| 57 |
+
"""
|
| 58 |
+
ref_paper = ref.clone()
|
| 59 |
+
ref_paper[..., 0] = -ref[..., 1]
|
| 60 |
+
ref_paper[..., 1] = ref[..., 0]
|
| 61 |
+
return ref_paper
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@torch.no_grad()
def extract_streampetr_topk_tokens(
    pts_bbox_head: Any,
    topk: int = 256,
    pc_range: Tuple[float, float, float, float, float, float] = PC_RANGE,
    ego_pose: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
    """
    Export the freshest Top-K StreamPETR memory tokens for Atlas.

    Args:
        pts_bbox_head: the StreamPETRHead instance (model.pts_bbox_head)
        topk: number of tokens to export; should match pts_bbox_head.topk_proposals
        pc_range: point cloud range used by StreamPETR, for normalizing ref_points
        ego_pose: [B, 4, 4] ego-to-global transform. If provided, ref_points are
            transformed back from global to ego frame before normalization.
    Returns:
        dict:
            - detection: [B, topk, D]
            - detection_ref_points: [B, topk, 3] (normalized to [0, 1];
              if ego_pose is provided, aligned to Atlas paper frame)
    """
    has_buffers = hasattr(pts_bbox_head, "memory_embedding") and hasattr(
        pts_bbox_head, "memory_reference_point"
    )
    if not has_buffers:
        raise RuntimeError("pts_bbox_head missing memory buffers; ensure you have run a forward pass first.")

    memory = pts_bbox_head.memory_embedding
    ref_pts = pts_bbox_head.memory_reference_point
    if memory is None or ref_pts is None:
        raise RuntimeError("pts_bbox_head memory is None; ensure you have run a forward pass and prev_exists is set.")

    if memory.ndim != 3 or ref_pts.ndim != 3 or ref_pts.shape[-1] != 3:
        raise RuntimeError(f"unexpected shapes: memory_embedding={getattr(memory,'shape',None)} memory_reference_point={getattr(ref_pts,'shape',None)}")

    if memory.shape[1] < topk or ref_pts.shape[1] < topk:
        raise RuntimeError(f"memory length too small: mem_len={memory.shape[1]} ref_len={ref_pts.shape[1]} topk={topk}")

    # The newest Top-K proposals are prepended, so the leading slice is fresh.
    tokens = memory[:, :topk, :].contiguous()
    token_ref = ref_pts[:, :topk, :].contiguous()

    # post_update_memory transforms ref_points to global frame via ego_pose.
    # We invert this to get ego-frame coordinates, then rotate to Atlas paper
    # frame so projector_rp sees the same XY semantics as detection QA/GT.
    if ego_pose is not None:
        token_ref = _global_to_ego(token_ref, ego_pose)
        token_ref = _nuscenes_ego_to_paper(token_ref)

    token_ref = _normalize_ref_points(token_ref, pc_range)
    return {"detection": tokens, "detection_ref_points": token_ref}
|
src/model/topomlp_adapter.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TopoMLP -> Atlas map token adapter.
|
| 2 |
+
|
| 3 |
+
Paper-aligned: Top-K selection from TopoMLP decoder outputs,
|
| 4 |
+
followed by a single linear projector (handled by AtlasUnifiedProjector).
|
| 5 |
+
No Perceiver resampler -- queries and ref_points are passed through directly.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Dict, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _lane_control_points_to_center_xyz(lane_preds: torch.Tensor) -> torch.Tensor:
|
| 16 |
+
if lane_preds.ndim != 3 or lane_preds.shape[-1] % 3 != 0:
|
| 17 |
+
raise ValueError(f"lane_preds expected [B,Q,3*K], got {tuple(lane_preds.shape)}")
|
| 18 |
+
B, Q, D = lane_preds.shape
|
| 19 |
+
K = D // 3
|
| 20 |
+
pts = lane_preds.view(B, Q, K, 3)
|
| 21 |
+
return pts.mean(dim=2)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _normalize_xyz(xyz: torch.Tensor, xyz_min: torch.Tensor, xyz_max: torch.Tensor) -> torch.Tensor:
|
| 25 |
+
denom = (xyz_max - xyz_min).clamp(min=1e-6)
|
| 26 |
+
out = (xyz - xyz_min) / denom
|
| 27 |
+
return out.clamp(0.0, 1.0)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class TopoMLPToAtlasMapTokens(torch.nn.Module):
    """Select top-K lane queries from TopoMLP and return them with reference points.

    Aligned with Atlas paper Section 3.1:
        "these queries are streamlined through a single linear layer"
    The linear projection itself is in AtlasUnifiedProjector.projector_map.
    This module only does Top-K selection + ref_point computation.
    """

    def __init__(
        self,
        num_map_tokens: int = 256,
        hidden_size: int = 256,
        bev_range: Tuple[float, float, float, float, float, float] = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0),
        **kwargs,
    ):
        super().__init__()
        self.num_map_tokens = int(num_map_tokens)
        self.hidden_size = int(hidden_size)
        self.bev_range = tuple(float(x) for x in bev_range)

        # Non-persistent buffers: they follow .to(device) but are excluded
        # from the state dict.
        lo = torch.tensor(self.bev_range[:3], dtype=torch.float32)
        hi = torch.tensor(self.bev_range[3:], dtype=torch.float32)
        self.register_buffer("_xyz_min", lo, persistent=False)
        self.register_buffer("_xyz_max", hi, persistent=False)

    @torch.no_grad()
    def infer_lane_centers_from_outs(self, outs: Dict) -> torch.Tensor:
        """Return lane centers [B, Q, 3] from the final decoder layer's predictions."""
        final_preds = outs["all_lc_preds_list"][-1]
        return _lane_control_points_to_center_xyz(final_preds)

    def forward(self, outs: Dict) -> Dict[str, torch.Tensor]:
        """Select the top-scoring lane queries and their normalized reference points.

        Returns:
            dict with "map" [B, num_map_tokens, D] and
            "map_ref_points" [B, num_map_tokens, 3] (zero-padded when fewer
            than num_map_tokens queries are available).
        """
        queries = outs["lc_outs_dec_list"][-1]
        # assumes a single-logit classifier head so squeeze(-1) yields [B, Q]
        # -- TODO confirm against the TopoMLP head config.
        scores = outs["all_lc_cls_scores_list"][-1].squeeze(-1)

        centers = self.infer_lane_centers_from_outs(outs)
        ref_norm = _normalize_xyz(centers, self._xyz_min, self._xyz_max)

        batch, n_queries, dim = queries.shape
        if n_queries == 0:
            # Degenerate input: emit fixed-size all-zero tokens/refs.
            zeros_tok = torch.zeros(batch, self.num_map_tokens, dim, dtype=queries.dtype, device=queries.device)
            zeros_ref = torch.zeros(batch, self.num_map_tokens, 3, dtype=ref_norm.dtype, device=ref_norm.device)
            return {"map": zeros_tok, "map_ref_points": zeros_ref}

        keep = min(self.num_map_tokens, n_queries)
        top = torch.topk(scores, k=keep, dim=1, largest=True, sorted=True).indices
        sel_tokens = queries.gather(dim=1, index=top.unsqueeze(-1).expand(-1, -1, dim))
        sel_ref = ref_norm.gather(dim=1, index=top.unsqueeze(-1).expand(-1, -1, 3))

        if keep < self.num_map_tokens:
            # Pad up to the fixed token budget with zeros.
            short = self.num_map_tokens - keep
            pad_tok = torch.zeros(batch, short, dim, dtype=sel_tokens.dtype, device=sel_tokens.device)
            pad_ref = torch.zeros(batch, short, 3, dtype=sel_ref.dtype, device=sel_ref.device)
            sel_tokens = torch.cat([sel_tokens, pad_tok], dim=1)
            sel_ref = torch.cat([sel_ref, pad_ref], dim=1)

        return {"map": sel_tokens, "map_ref_points": sel_ref}
|
src/prompting.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import re
|
| 3 |
+
from typing import Literal, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from src.audit.audit_utils import audit_enabled, h1b_record_prompt
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
PLANNING_TABLE3_MODES = (
|
| 9 |
+
"atlas_base",
|
| 10 |
+
"atlas_high_level",
|
| 11 |
+
"atlas_high_level_ego",
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
_PLANNING_COMMAND_RE = re.compile(
|
| 15 |
+
r"The ego car will (?:turn left|turn right|go straight) in future\.\s*"
|
| 16 |
+
)
|
| 17 |
+
_PLANNING_STATE_RE = re.compile(
|
| 18 |
+
r"The current speed value of the ego car is \[\d+,\s*\d+\]\.\s*"
|
| 19 |
+
r"The current acceleration value of the ego car is \[\d+,\s*\d+\]\.\s*"
|
| 20 |
+
)
|
| 21 |
+
_PLANNING_ACCEL_SENTENCE = (
|
| 22 |
+
"The acceleration of the vehicle is defined as "
|
| 23 |
+
"[acceleration along the x-axis, acceleration along the y-axis]."
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Paper Table 6 prompts 1-3, plus two repository paraphrases.
# Prompt pool for the 3D detection task; one entry is drawn uniformly at
# random by sample_prompt("detection"). Each template contains two literal
# "<query>" placeholders substituted downstream with visual query embeddings.
DETECTION_PROMPTS = [
    # Paper Table 6, prompt 1.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
        "Please complete the visual detection task under the Bird's Eye View (BEV) perspective. "
        "Ensure that the detection range does not exceed 50 meters."
    ),
    # Paper Table 6, prompt 2.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Establish the positive y-axis as the frontward direction and the positive x-axis as the rightward direction. "
        "Kindly execute the visual detection task within the Bird's Eye View (BEV) framework. "
        "Be mindful not to exceed a detection range of 50 meters."
    ),
    # Paper Table 6, prompt 3.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Set the forward direction as the positive y-axis and the right direction as the positive x-axis. "
        "Please carry out the visual detection task within the Bird's Eye View (BEV) context. "
        "Ensure that the detection range remains within 50 meters."
    ),
    # Additional paraphrase variant 4.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Let the positive y-axis denote the forward direction and the positive x-axis denote the right direction. "
        "Please perform the visual detection task from the Bird's Eye View (BEV) perspective. "
        "Keep the detection range within 50 meters."
    ),
    # Additional paraphrase variant 5.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Take the positive y-axis to be the forward direction and the positive x-axis to be the right direction. "
        "Kindly carry out the visual detection task under the Bird's Eye View (BEV) perspective. "
        "Ensure that all detections stay within 50 meters."
    ),
]
|
| 70 |
+
|
| 71 |
+
# Paper Table 7 prompts 1-3, plus two repository paraphrases.
# Prompt pool for the centerline (lane) detection task; sampled uniformly by
# sample_prompt("lane").
LANE_PROMPTS = [
    # Paper Table 7, prompt 1.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Please complete the centerline detection task under the Bird's Eye View (BEV) perspective. "
        "Ensure that the detection range does not exceed 50 meters."
    ),
    # Paper Table 7, prompt 2. The published paper text appears truncated and is kept verbatim.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Kindly execute the centerline detection task within the Bird's Eye View (BEV) framework. "
        "Be mindful not to exceed a detection range of 50 meters."
    ),
    # Paper Table 7, prompt 3.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Could you complete the task of detecting the centerline from the Bird's Eye View (BEV) perspective? "
        "Ensure that the detection range remains within 50 meters."
    ),
    # Additional paraphrase variant 4.
    # NOTE(review): this entry is byte-identical to prompt 2 above, which
    # doubles its sampling weight; it was likely intended to be a distinct
    # paraphrase -- confirm before changing, since prompts affect training.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Kindly execute the centerline detection task within the Bird's Eye View (BEV) framework. "
        "Be mindful not to exceed a detection range of 50 meters."
    ),
    # Additional paraphrase variant 5.
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Please carry out the task of detecting the centerline from the Bird's Eye View (BEV) perspective. "
        "Ensure that the detection range remains within 50 meters."
    ),
]
|
| 109 |
+
|
| 110 |
+
# Paper Table 9 prompts 1-3, plus two repository paraphrases.
# Prompt pool for the planning task; sampled uniformly by
# sample_prompt("planning", command=...). Each template has a "{command}"
# slot plus two "<query>" placeholders (detection + map query embeddings).
# NOTE(review): "continous" is misspelled in the final sentence of every
# variant; it is a runtime prompt string and is left verbatim.
PLANNING_PROMPTS = [
    # Paper Table 9, prompt 1. The paper uses fixed maneuver words; {command} keeps the same slot dynamic.
    (
        "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
        "and map query embeddings<query>. "
        "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
        "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
        "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
        "The ego car will {command} in future. "
        "Kindly furnish suitable waypoints for the vehicle's trajectory based on the provided particulars. "
        "Waypoints ought to adhere to the [x, y] format, with each waypoint spaced at 0.5-second intervals "
        "within a continuous 3.0-second timeframe. "
        "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
    ),
    # Paper Table 9, prompt 2. The paper uses fixed maneuver words; {command} keeps the same slot dynamic.
    (
        "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
        "and map query embeddings<query>. "
        "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
        "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
        "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
        "The ego car will {command} in future. "
        "We request your provision of pertinent waypoints for the vehicle's route in accordance with the given information. "
        "Waypoints should conform to the format [x, y], with spacing set at 0.5-second intervals "
        "over a continuous duration of 3.0 seconds. "
        "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
    ),
    # Paper Table 9, prompt 3. The paper uses fixed maneuver words; {command} keeps the same slot dynamic.
    (
        "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
        "and map query embeddings<query>. "
        "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
        "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
        "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
        "The ego car will {command} in future. "
        "Please submit fitting waypoints for the vehicle's course based on the supplied data. "
        "Ensure waypoints are structured as [x, y] and spaced at intervals of 0.5 seconds across a continuous 3.0-second period. "
        "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
    ),
    # Additional paraphrase variant 4.
    (
        "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
        "and map query embeddings<query>. "
        "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
        "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
        "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
        "The ego car will {command} in future. "
        "Please provide suitable waypoints for the ego car in [x, y] format at 0.5-second intervals "
        "over a continuous 3.0-second period. "
        "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
    ),
    # Additional paraphrase variant 5.
    (
        "The six images include objects that are uniformly represented as 3D detection query embeddings<query> "
        "and map query embeddings<query>. "
        "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
        "The speed of the vehicle is defined as [velocity along the x-axis, velocity along the y-axis]. "
        "The acceleration of the vehicle is defined as [acceleration along the x-axis, acceleration along the y-axis]. "
        "The ego car will {command} in future. "
        "Could you generate appropriate waypoints for the vehicle's trajectory in [x, y] format, "
        "with each waypoint separated by 0.5 seconds across the next 3.0 seconds? "
        "For planning tasks, please pay attention to driving safety and avoid vehicle collisions during driving in continous time."
    ),
]
|
| 175 |
+
|
| 176 |
+
# Figure 5-style single-view caption prompt, parameterized by camera_name.
# Single-template pool; the "{camera_name}" slot is filled via
# sample_prompt("caption", camera_name=...).
CAPTION_PROMPTS = [
    (
        "There are six images captured by the surround view cameras in driving vehicle. "
        "They are uniformly represented as queries embeddings<query>. "
        "Communicate a narrative of the setting within {camera_name} view image."
    ),
]


# Task name -> prompt pool; consumed by get_prompt_pool() / sample_prompt().
_TASK_POOLS = {
    "detection": DETECTION_PROMPTS,
    "lane": LANE_PROMPTS,
    "planning": PLANNING_PROMPTS,
    "caption": CAPTION_PROMPTS,
}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_prompt_pool(task: str):
    """Return the prompt pool for *task*; unknown tasks fall back to detection."""
    if task in _TASK_POOLS:
        return _TASK_POOLS[task]
    return DETECTION_PROMPTS
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def sample_prompt(task: str, **kwargs) -> str:
    """Draw one prompt template for *task* and fill its {slots} from kwargs.

    Best-effort formatting: if the chosen template needs a key that kwargs
    does not supply, the raw (unformatted) template is returned instead of
    raising.
    """
    chosen = random.choice(get_prompt_pool(task))
    if not kwargs:
        return chosen
    try:
        return chosen.format(**kwargs)
    except KeyError:
        return chosen
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def rewrite_planning_prompt_for_table3(
    prompt_text: str,
    mode: str,
    command: Optional[str] = None,
    velocity_bins: Optional[Tuple[int, int]] = None,
    acceleration_bins: Optional[Tuple[int, int]] = None,
) -> str:
    """Rewrite planning prompts to match Atlas Table 3 variants.

    Modes:
    - atlas_base: no high-level command, no explicit ego-state values
    - atlas_high_level: keep high-level command only
    - atlas_high_level_ego: keep high-level command and inject ego-state bins

    In this repository, the top-level route command is a UniAD-style
    future-GT-derived coarse planning command, not a raw nuScenes field.

    Args:
        prompt_text: base planning prompt (any whitespace layout; normalized here).
        mode: one of PLANNING_TABLE3_MODES.
        command: maneuver phrase; required for both high-level modes.
        velocity_bins: (vx, vy) integer bins; required for atlas_high_level_ego.
        acceleration_bins: (ax, ay) integer bins; required for atlas_high_level_ego.

    Returns:
        The rewritten, whitespace-normalized prompt string.

    Raises:
        ValueError: on an unknown mode or missing required arguments.
    """
    if mode not in PLANNING_TABLE3_MODES:
        raise ValueError(f"Unsupported planning mode: {mode}")

    # Collapse all whitespace first so the removal regexes match reliably,
    # then strip any previously injected state/command sentences so the
    # rewrite is safe to apply to an already-rewritten prompt.
    prompt = " ".join(str(prompt_text).split())
    prompt = _PLANNING_STATE_RE.sub("", prompt)
    prompt = _PLANNING_COMMAND_RE.sub("", prompt)

    if mode == "atlas_base":
        # Base variant: command and ego state stay removed.
        return " ".join(prompt.split())

    if not command:
        raise ValueError(
            f"{mode} requires an explicit top-level route command field"
        )

    command_sentence = f"The ego car will {command} in future."

    if mode == "atlas_high_level":
        # Prefer inserting right after the acceleration-definition sentence
        # to preserve the paper's sentence ordering; fall back to prepending
        # when the anchor sentence is absent from this template.
        if _PLANNING_ACCEL_SENTENCE in prompt:
            prompt = prompt.replace(
                _PLANNING_ACCEL_SENTENCE,
                f"{_PLANNING_ACCEL_SENTENCE} {command_sentence}",
                1,
            )
            return " ".join(prompt.split())
        return " ".join(f"{command_sentence} {prompt}".split())

    # Remaining mode: atlas_high_level_ego (command + explicit ego state).
    if velocity_bins is None or acceleration_bins is None:
        raise ValueError(
            "atlas_high_level_ego requires velocity_bins and acceleration_bins"
        )

    state_sentence = (
        f"The current speed value of the ego car is [{int(velocity_bins[0])}, {int(velocity_bins[1])}]. "
        f"The current acceleration value of the ego car is [{int(acceleration_bins[0])}, {int(acceleration_bins[1])}]."
    )
    if _PLANNING_ACCEL_SENTENCE in prompt:
        prompt = prompt.replace(
            _PLANNING_ACCEL_SENTENCE,
            f"{_PLANNING_ACCEL_SENTENCE} {state_sentence} {command_sentence}",
            1,
        )
        return " ".join(prompt.split())

    return " ".join(f"{state_sentence} {command_sentence} {prompt}".split())
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def build_prompt(prompt_text: str, mode: Literal["train", "infer"]) -> str:
    """Wrap *prompt_text* in the chat-style "USER: ...\\nASSISTANT:" frame.

    When auditing is enabled, the exact rendered prompt is recorded under
    the given mode before being returned.
    """
    rendered = f"USER: {prompt_text}\nASSISTANT:"
    if audit_enabled():
        h1b_record_prompt(mode, rendered)
    return rendered
|
train_atlas.py
ADDED
|
@@ -0,0 +1,1018 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import math
|
| 6 |
+
import time
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
from datetime import timedelta
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, Optional, List
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.distributed as dist
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
| 18 |
+
|
| 19 |
+
from src.model.modeling_atlas import AtlasForCausalLM
|
| 20 |
+
from src.model.topomlp_adapter import TopoMLPToAtlasMapTokens
|
| 21 |
+
from src.model.streampetr_adapter import extract_streampetr_topk_tokens
|
| 22 |
+
from src.dataset.atlas_dataset import (
|
| 23 |
+
AtlasDataset, make_atlas_collate_fn, load_tokenizer,
|
| 24 |
+
)
|
| 25 |
+
from src.dataset.scene_sampler import SceneSequentialSampler
|
| 26 |
+
from src.prompting import PLANNING_TABLE3_MODES
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger("train_atlas")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def parse_args():
    """Parse the Atlas training command line into an argparse namespace."""
    p = argparse.ArgumentParser()
    # --- Model / frozen encoder assets ---
    p.add_argument("--llm_model", default="lmsys/vicuna-7b-v1.5")
    p.add_argument("--visual_hidden_size", type=int, default=256)
    p.add_argument("--num_det_queries", type=int, default=256)
    p.add_argument("--num_map_queries", type=int, default=256)
    p.add_argument("--streampetr_config", default=None)
    p.add_argument("--streampetr_ckpt", default=None)
    p.add_argument("--topomlp_config", default=None)
    p.add_argument("--topomlp_ckpt", default=None)
    # --- Dataset ---
    p.add_argument("--data_json", required=True)
    p.add_argument("--data_root", default="/mnt/data/nuscenes")
    p.add_argument("--max_length", type=int, default=4096)
    p.add_argument("--output_dir", default="work_dirs/atlas")
    # --- Optimization ---
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--weight_decay", type=float, default=1e-4)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--epochs", type=int, default=8)
    p.add_argument("--warmup_ratio", type=float, default=0.03)
    p.add_argument("--gradient_accumulation_steps", type=int, default=2)
    p.add_argument("--max_grad_norm", type=float, default=1.0)
    # --- LoRA / quantization ---
    p.add_argument("--use_lora", action="store_true")
    p.add_argument("--lora_r", type=int, default=64)
    p.add_argument("--lora_alpha", type=int, default=64)
    p.add_argument("--lora_dropout", type=float, default=0.1)
    p.add_argument("--load_in_4bit", action="store_true")
    # --- Checkpointing / logging / misc ---
    p.add_argument("--save_steps", type=int, default=0)
    p.add_argument("--save_epochs", type=int, default=1)
    p.add_argument("--log_steps", type=int, default=10)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--resume", default=None)
    # Both spellings accepted: torch.distributed launchers differ on the
    # underscore vs hyphen form; LOCAL_RANK env is the fallback default.
    p.add_argument("--local_rank", "--local-rank", type=int, default=int(os.environ.get("LOCAL_RANK", -1)))
    p.add_argument("--fp16", action="store_true")
    p.add_argument("--bf16", action="store_true")
    p.add_argument("--image_path_remap", default=None,
                   help="old=new path remap, e.g. /mnt/data=/local/data")
    # Offline-mode token sources; validated against --visual_token_mode by
    # _validate_visual_token_mode.
    p.add_argument("--precomputed_det_tokens", default=None,
                   help="[offline only] Dir with precomputed det tokens (.pt files)")
    p.add_argument("--precomputed_map_tokens", default=None,
                   help="[offline only] Dir with precomputed TopoMLP map tokens (.pt files)")
    p.add_argument("--visual_token_mode", choices=("online", "offline"), default="online",
                   help="Visual token source: online=live frozen encoders (default), offline=read *_offline dirs")
    p.add_argument("--deepspeed", default=None,
                   help="Path to DeepSpeed config JSON (enables ZeRO)")
    p.add_argument("--keep_last_n_ckpts", type=int, default=0,
                   help="Keep only the N most recent epoch checkpoints (0=keep all)")
    p.add_argument(
        "--planning_table3_mode",
        choices=PLANNING_TABLE3_MODES,
        default="atlas_base",
        help=(
            "Planning prompt variant matching Atlas Table 3: "
            "atlas_base=no command/no explicit ego state; "
            "atlas_high_level=requires top-level route_command "
            "(this repo uses a UniAD-style future-GT-derived command); "
            "atlas_high_level_ego=requires top-level route_command plus "
            "velocity/acceleration bins."
        ),
    )
    return p.parse_args()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _validate_visual_token_mode(args):
|
| 95 |
+
"""Enforce mode-specific constraints. Fail hard, never silently degrade."""
|
| 96 |
+
if args.visual_token_mode == "online":
|
| 97 |
+
if args.precomputed_det_tokens or args.precomputed_map_tokens:
|
| 98 |
+
raise RuntimeError(
|
| 99 |
+
"visual_token_mode=online forbids --precomputed_det_tokens / "
|
| 100 |
+
"--precomputed_map_tokens. Use --visual_token_mode offline to "
|
| 101 |
+
"read offline token directories."
|
| 102 |
+
)
|
| 103 |
+
missing = []
|
| 104 |
+
if not args.streampetr_config or not args.streampetr_ckpt:
|
| 105 |
+
missing.append("--streampetr_config/--streampetr_ckpt")
|
| 106 |
+
if not args.topomlp_config or not args.topomlp_ckpt:
|
| 107 |
+
missing.append("--topomlp_config/--topomlp_ckpt")
|
| 108 |
+
if missing:
|
| 109 |
+
raise RuntimeError(
|
| 110 |
+
"visual_token_mode=online requires live encoder configs and "
|
| 111 |
+
"checkpoints. Missing: " + ", ".join(missing)
|
| 112 |
+
)
|
| 113 |
+
for p in (args.streampetr_config, args.streampetr_ckpt, args.topomlp_config, args.topomlp_ckpt):
|
| 114 |
+
if not os.path.exists(p):
|
| 115 |
+
raise RuntimeError(f"Required online asset does not exist: {p}")
|
| 116 |
+
if args.batch_size != 1:
|
| 117 |
+
raise RuntimeError(
|
| 118 |
+
"visual_token_mode=online with temporal memory requires "
|
| 119 |
+
"--batch_size 1 (paper-aligned). Got: %d" % args.batch_size
|
| 120 |
+
)
|
| 121 |
+
else:
|
| 122 |
+
if not args.precomputed_det_tokens and not args.precomputed_map_tokens:
|
| 123 |
+
raise RuntimeError(
|
| 124 |
+
"visual_token_mode=offline requires at least one "
|
| 125 |
+
"--precomputed_*_tokens directory."
|
| 126 |
+
)
|
| 127 |
+
for p in (args.precomputed_det_tokens, args.precomputed_map_tokens):
|
| 128 |
+
if p and not os.path.isdir(p):
|
| 129 |
+
raise RuntimeError(f"Offline token directory does not exist: {p}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def set_seed(seed):
    """Seed python, numpy and torch RNGs (incl. every CUDA device) for reproducibility."""
    import random

    import numpy as np

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def setup_distributed(local_rank):
    """Initialise NCCL distributed training when launched with a local rank.

    Returns (device, distributed?, rank, world_size). A local_rank of -1
    means single-process mode: use cuda when available, else cpu.
    """
    if local_rank == -1:
        single = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return single, False, 0, 1
    dist.init_process_group(backend="nccl", timeout=timedelta(seconds=1800))
    torch.cuda.set_device(local_rank)
    return (
        torch.device("cuda", local_rank),
        True,
        dist.get_rank(),
        dist.get_world_size(),
    )
def is_main_process(distributed, rank):
    """True when this process should do logging / checkpoint I/O (rank 0 or non-DDP)."""
    if not distributed:
        return True
    return rank == 0
def load_frozen_encoder(config_path, ckpt_path, model_type, device):
    """Build a frozen (eval-mode, requires_grad=False) encoder from an mmdet3d
    config + checkpoint.

    Args:
        config_path: mm-style config file path, or None to disable the encoder.
        ckpt_path: checkpoint path, or None to disable the encoder.
        model_type: "streampetr" or "topomlp" — selects which external plugin
            tree is added to sys.path and imported.
        device: torch device the frozen model is moved to.

    Returns:
        The frozen model, or None when either path is missing.

    Raises:
        RuntimeError: when mmcv/mmdet3d or the plugin code is absent despite
            encoder paths being explicitly provided (fail hard, no silent degrade).
    """
    if config_path is None or ckpt_path is None:
        return None
    try:
        from mmcv import Config
        from mmdet3d.models import build_model
        from mmcv.runner import load_checkpoint
    except ImportError:
        raise RuntimeError(
            f"mmcv/mmdet3d not installed but --{model_type}_config and "
            f"--{model_type}_ckpt were explicitly provided. "
            f"Install mmcv/mmdet3d or remove these arguments to train without {model_type}."
        )

    if model_type == "streampetr":
        sp_root = str(Path(__file__).resolve().parent / "external" / "StreamPETR")
        if sp_root not in sys.path:
            sys.path.insert(0, sp_root)
        try:
            import projects.mmdet3d_plugin  # noqa: F401
        except ImportError:
            raise RuntimeError(
                f"StreamPETR plugin not found under {sp_root}/projects/mmdet3d_plugin. "
                f"Ensure the submodule is checked out, or remove --streampetr_config/--streampetr_ckpt."
            )
    elif model_type == "topomlp":
        tp_root = str(Path(__file__).resolve().parent / "external" / "TopoMLP_Repo")
        if tp_root not in sys.path:
            sys.path.insert(0, tp_root)
        try:
            os.environ["ATLAS_TOPOMLP_MODELS_ONLY"] = "1"
            from mmcv.utils import registry as _reg

            # TopoMLP re-registers module names StreamPETR may already have
            # claimed; temporarily force-register so duplicates don't abort.
            _orig = _reg.Registry._register_module

            def _tolerant_register(self, module, module_name=None, force=False):
                return _orig(self, module, module_name=module_name, force=True)

            _reg.Registry._register_module = _tolerant_register
            try:
                import projects.topomlp  # noqa: F401
            finally:
                # BUGFIX: restore the registry method even when the plugin
                # import fails; previously a failed import left mmcv's
                # Registry monkeypatched for the rest of the process.
                _reg.Registry._register_module = _orig
        except ImportError:
            raise RuntimeError(
                f"TopoMLP plugin not found under {tp_root}/projects/topomlp. "
                f"Ensure the submodule is checked out, or remove --topomlp_config/--topomlp_ckpt."
            )

    cfg = Config.fromfile(config_path)
    model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    load_checkpoint(model, ckpt_path, map_location="cpu")
    model.eval()
    model.to(device)
    # Freeze: these encoders are feature extractors only, never trained here.
    for param in model.parameters():
        param.requires_grad_(False)
    logger.info("Loaded frozen %s from %s", model_type, ckpt_path)
    return model
def build_img_metas_streampetr(batch, device, idx):
    """Assemble the per-sample img_metas dict StreamPETR's heads expect.

    Image shape is the fixed 800x1600 training resolution; scene_token falls
    back to "__atlas__" when the batch carries no scene ids.
    """
    cam_count = batch["pixel_values_det"].shape[1]
    height, width = 800, 1600
    scene_ids = batch.get("scene_id", ["__atlas__"] * (idx + 1))
    scene_token = scene_ids[idx] if idx < len(scene_ids) else "__atlas__"
    meta = {
        "pad_shape": [(height, width, 3) for _ in range(cam_count)],
        "img_shape": [(height, width, 3) for _ in range(cam_count)],
        "scene_token": scene_token,
    }
    if "lidar2img_det" in batch:
        meta["lidar2img"] = batch["lidar2img_det"][idx].cpu().numpy()
    return meta
def build_img_metas_topomlp(batch, device, idx):
    """Assemble the per-sample img_metas dict TopoMLP expects (fixed 800x1600)."""
    meta = {}
    if "lidar2img_map" in batch:
        meta["lidar2img"] = batch["lidar2img_map"][idx].cpu().numpy()
    height, width = 800, 1600
    cam_count = batch["pixel_values_map"].shape[1]
    per_cam_shape = ((height, width, 3),) * cam_count
    meta["img_shape"] = per_cam_shape
    meta["pad_shape"] = per_cam_shape
    meta["scale_factor"] = 1.0
    # No YOLOv8 traffic-element proposals in this pipeline.
    meta["te_yolov8"] = None
    return meta
@torch.no_grad()
def run_streampetr_forward(model, imgs, img_metas, batch, device, prev_exists=None):
    """Run the frozen StreamPETR pass (img feats -> roi head -> bbox head).

    Any calibration / pose / timestamp entry missing from `batch` is replaced
    by an identity matrix or zeros so the head can always run. Returns the
    bbox-head outputs.
    """
    B, N = imgs.shape[:2]

    def _identity_batch(per_camera):
        # Identity 4x4 expanded to (B, N, 4, 4) or (B, 4, 4).
        eye = torch.eye(4, device=device)
        if per_camera:
            return eye.unsqueeze(0).unsqueeze(0).expand(B, N, -1, -1).contiguous()
        return eye.unsqueeze(0).expand(B, -1, -1).contiguous()

    data = {
        "img": imgs,
        "img_feats": model.extract_img_feat(imgs, 1),
        "prev_exists": imgs.new_zeros(B) if prev_exists is None else prev_exists,
    }

    if "intrinsics_det" in batch:
        # Embed the 3x3 camera intrinsics into a homogeneous 4x4.
        K3 = batch["intrinsics_det"].to(device)
        K4 = torch.zeros(B, N, 4, 4, device=device, dtype=K3.dtype)
        K4[:, :, :3, :3] = K3
        K4[:, :, 3, 3] = 1.0
        data["intrinsics"] = K4
    else:
        data["intrinsics"] = _identity_batch(True)

    if "lidar2img_det" in batch:
        data["lidar2img"] = batch["lidar2img_det"].to(device)
    else:
        data["lidar2img"] = _identity_batch(True)

    ego_pose = batch.get("ego_pose")
    data["ego_pose"] = _identity_batch(False) if ego_pose is None else ego_pose.to(device)

    ego_pose_inv = batch.get("ego_pose_inv")
    if ego_pose_inv is None:
        data["ego_pose_inv"] = torch.inverse(data["ego_pose"])
    else:
        data["ego_pose_inv"] = ego_pose_inv.to(device)

    timestamp = batch.get("timestamp")
    data["timestamp"] = torch.zeros(B, device=device) if timestamp is None else timestamp.to(device)

    location = model.prepare_location(img_metas, **data)
    roi_outs = model.forward_roi_head(location, **data)
    return model.pts_bbox_head(location, img_metas, roi_outs["topk_indexes"], **data)
@torch.no_grad()
def run_topomlp_forward(model, imgs, img_metas):
    """Invoke TopoMLP's lightweight inference entry point under no-grad."""
    outs = model.simple_forward(imgs, img_metas)
    return outs
def _reconstruct_topomlp_outs(saved: dict, device, dtype):
|
| 295 |
+
"""Convert precomputed .pt dict back to the format adapter.forward() expects."""
|
| 296 |
+
def _restore(t):
|
| 297 |
+
return t.to(device=device, dtype=dtype).unsqueeze(0)
|
| 298 |
+
return {
|
| 299 |
+
"lc_outs_dec_list": [_restore(saved["lc_outs_dec"])],
|
| 300 |
+
"all_lc_cls_scores_list": [_restore(saved["lc_cls_scores"])],
|
| 301 |
+
"all_lc_preds_list": [_restore(saved["lc_preds"])],
|
| 302 |
+
"lc_outs_dec_one2many_list": [_restore(saved["lc_outs_dec_o2m"])],
|
| 303 |
+
"all_lc_cls_scores_one2many_list": [_restore(saved["lc_cls_scores_o2m"])],
|
| 304 |
+
"all_lc_preds_one2many_list": [_restore(saved["lc_preds_o2m"])],
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def extract_visual_tokens(
    streampetr_model,
    topomlp_model,
    topomlp_adapter,
    batch,
    device,
    num_det_queries=256,
    visual_hidden_size=256,
    query_token_id=None,
    visual_token_mode="online",
    streaming_state=None,
):
    """Extract det + map visual tokens.

    In online mode with streaming_state, StreamPETR temporal memory is managed
    per-scene and duplicate physical frames are protected: if the current
    sample_id equals the previous one, we reuse cached det tokens and skip the
    StreamPETR forward to avoid pushing the same frame into memory twice.

    Returns a dict with keys "detection", "detection_ref_points", "map",
    "map_ref_points". Offline mode reads precomputed tensors from `batch` and
    refuses to zero-fill anything that is missing; online mode runs the frozen
    encoders live.
    """
    B = batch["pixel_values_det"].shape[0]
    vis: Dict[str, torch.Tensor] = {}

    # Decide whether this batch needs map tokens at all: count the <query>
    # placeholder tokens in the prompt — more placeholders than detection
    # queries means map queries are also expected.
    needs_map = False
    if query_token_id is not None and "input_ids" in batch:
        n_queries = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item())
        needs_map = n_queries > num_det_queries

    # ---- Detection tokens ----
    if visual_token_mode == "offline" and "precomputed_det" in batch and "precomputed_det_ref" in batch:
        vis["detection"] = batch["precomputed_det"].to(device)
        vis["detection_ref_points"] = batch["precomputed_det_ref"].to(device)
    elif visual_token_mode == "offline":
        raise RuntimeError(
            "visual_token_mode=offline but detection precomputed tokens are missing "
            "for the current batch. Refusing to zero-fill."
        )
    elif streampetr_model is not None:
        # Temporal memory is per-sample state, so streaming only supports B=1.
        if B != 1 and streaming_state is not None:
            raise RuntimeError("online temporal det requires batch_size=1")

        current_sample_id = batch.get("sample_id", [None])[0]
        current_scene = batch.get("scene_id", ["__atlas__"])[0]
        reuse_cache = False

        if streaming_state is not None:
            prev_scene = streaming_state.get("prev_scene_token")
            prev_sample_id = streaming_state.get("prev_sample_id")
            ts_tensor = batch.get("timestamp")
            current_ts = float(ts_tensor[0].item()) if ts_tensor is not None else None
            prev_ts = streaming_state.get("prev_timestamp")

            # A new segment (scene change or non-increasing timestamp, i.e. a
            # replay/restart) must reset StreamPETR's temporal memory.
            is_new_segment = (
                prev_scene is None
                or current_scene != prev_scene
                or (current_ts is not None and prev_ts is not None and current_ts <= prev_ts)
            )

            # Same physical frame as last step: reuse cached det tokens so the
            # frame is not pushed into temporal memory twice.
            if current_sample_id is not None and current_sample_id == prev_sample_id:
                cached = streaming_state.get("cached_det")
                if cached is not None:
                    reuse_cache = True
                    vis["detection"] = cached["detection"]
                    vis["detection_ref_points"] = cached["detection_ref_points"]

            if not reuse_cache:
                if is_new_segment:
                    streampetr_model.pts_bbox_head.reset_memory()
                # prev_exists gates memory propagation inside StreamPETR.
                prev_exists_val = 0.0 if is_new_segment else 1.0
                imgs_det = batch["pixel_values_det"].to(device)
                prev_exists = imgs_det.new_full((B,), prev_exists_val)

                img_metas = [build_img_metas_streampetr(batch, device, b) for b in range(B)]
                run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device, prev_exists=prev_exists)
                ego_pose_for_ref = batch.get("ego_pose")
                if ego_pose_for_ref is not None:
                    ego_pose_for_ref = ego_pose_for_ref.to(device)
                # Pull the top-k query embeddings + reference points out of the
                # bbox head after the forward pass populated them.
                det_out = extract_streampetr_topk_tokens(
                    streampetr_model.pts_bbox_head,
                    topk=num_det_queries,
                    ego_pose=ego_pose_for_ref,
                )
                vis["detection"] = det_out["detection"]
                vis["detection_ref_points"] = det_out["detection_ref_points"]

                streaming_state["cached_det"] = {
                    "detection": vis["detection"],
                    "detection_ref_points": vis["detection_ref_points"],
                }

            # Update streaming bookkeeping regardless of cache reuse.
            streaming_state["prev_scene_token"] = current_scene
            streaming_state["prev_sample_id"] = current_sample_id
            if batch.get("timestamp") is not None:
                streaming_state["prev_timestamp"] = float(batch["timestamp"][0].item())
        else:
            # No streaming state: stateless single-frame forward.
            imgs_det = batch["pixel_values_det"].to(device)
            img_metas = [build_img_metas_streampetr(batch, device, b) for b in range(B)]
            run_streampetr_forward(streampetr_model, imgs_det, img_metas, batch, device)
            ego_pose_for_ref = batch.get("ego_pose")
            if ego_pose_for_ref is not None:
                ego_pose_for_ref = ego_pose_for_ref.to(device)
            det_out = extract_streampetr_topk_tokens(
                streampetr_model.pts_bbox_head,
                topk=num_det_queries,
                ego_pose=ego_pose_for_ref,
            )
            vis["detection"] = det_out["detection"]
            vis["detection_ref_points"] = det_out["detection_ref_points"]
    elif visual_token_mode == "online":
        raise RuntimeError(
            "visual_token_mode=online but StreamPETR model is None. "
            "Provide --streampetr_config and --streampetr_ckpt."
        )
    else:
        # No encoder configured at all (non-online mode): zero placeholders.
        vis["detection"] = torch.zeros(B, num_det_queries, visual_hidden_size, device=device)
        vis["detection_ref_points"] = torch.zeros(B, num_det_queries, 3, device=device)

    # ---- Map tokens ----
    num_map_queries = num_det_queries
    if topomlp_adapter is not None:
        num_map_queries = topomlp_adapter.num_map_tokens

    if topomlp_adapter is not None:
        # Match the adapter's parameter (or buffer) dtype so mixed-precision
        # runs don't hit dtype mismatches when feeding saved fp32 tensors.
        _params = list(topomlp_adapter.parameters())
        _bufs = list(topomlp_adapter.buffers())
        adapter_dtype = _params[0].dtype if _params else (_bufs[0].dtype if _bufs else torch.float32)

    map_filled = False
    if visual_token_mode == "offline" and needs_map and "precomputed_map" in batch:
        if B == 1:
            outs = _reconstruct_topomlp_outs(batch["precomputed_map"][0], device, adapter_dtype)
        else:
            # Rebuild each sample then concatenate along the batch dim,
            # preserving the per-key list structure the adapter expects.
            per_sample = [_reconstruct_topomlp_outs(batch["precomputed_map"][b], device, adapter_dtype) for b in range(B)]
            outs = {}
            for k in per_sample[0]:
                outs[k] = [torch.cat([s[k][i] for s in per_sample], dim=0) for i in range(len(per_sample[0][k]))]
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True
    elif visual_token_mode == "offline" and needs_map:
        raise RuntimeError(
            "visual_token_mode=offline but map precomputed tokens are missing "
            "for a batch that requires map queries. Refusing to zero-fill."
        )
    elif needs_map and topomlp_model is not None:
        imgs_map = batch["pixel_values_map"].to(device)
        img_metas = [build_img_metas_topomlp(batch, device, b) for b in range(B)]
        outs = run_topomlp_forward(topomlp_model, imgs_map, img_metas)
        # Cast every tensor (and tensors nested in lists) to the adapter dtype.
        for k, v in outs.items():
            if isinstance(v, torch.Tensor):
                outs[k] = v.to(adapter_dtype)
            elif isinstance(v, list):
                outs[k] = [x.to(adapter_dtype) if isinstance(x, torch.Tensor) else x for x in v]
        map_out = topomlp_adapter(outs)
        vis["map"] = map_out["map"]
        vis["map_ref_points"] = map_out["map_ref_points"]
        map_filled = True
    elif needs_map and visual_token_mode == "online":
        raise RuntimeError(
            "visual_token_mode=online but TopoMLP model is None. "
            "Provide --topomlp_config and --topomlp_ckpt."
        )

    # Batches that never needed map tokens still get zero placeholders so the
    # returned dict always has the same keys/shapes.
    if not map_filled:
        vis["map"] = torch.zeros(B, num_map_queries, visual_hidden_size, device=device)
        vis["map_ref_points"] = torch.zeros(B, num_map_queries, 3, device=device)

    return vis
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.0):
    """Build a LambdaLR: linear warmup to the base LR, then cosine decay
    floored at min_lr_ratio * base_lr."""
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        span = max(1, num_training_steps - num_warmup_steps)
        progress = (step - num_warmup_steps) / span
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return max(min_lr_ratio, cosine)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
def _optimizer_steps_per_epoch(num_batches: int, grad_accum_steps: int) -> int:
|
| 488 |
+
if num_batches <= 0:
|
| 489 |
+
return 0
|
| 490 |
+
return int(math.ceil(float(num_batches) / float(max(1, grad_accum_steps))))
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def _accum_window_size_for_batch(
|
| 494 |
+
batch_idx: int,
|
| 495 |
+
num_batches: int,
|
| 496 |
+
grad_accum_steps: int,
|
| 497 |
+
) -> int:
|
| 498 |
+
"""Return the effective accumulation window size for this batch.
|
| 499 |
+
|
| 500 |
+
Full windows use `grad_accum_steps`. The final partial window uses the
|
| 501 |
+
remainder so that tail batches are not under-scaled when we flush them.
|
| 502 |
+
"""
|
| 503 |
+
grad_accum_steps = max(1, int(grad_accum_steps))
|
| 504 |
+
remainder = int(num_batches % grad_accum_steps)
|
| 505 |
+
tail_start = int(num_batches - remainder)
|
| 506 |
+
if remainder > 0 and batch_idx >= tail_start:
|
| 507 |
+
return remainder
|
| 508 |
+
return grad_accum_steps
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def _is_optimizer_step_batch(
|
| 512 |
+
batch_idx: int,
|
| 513 |
+
num_batches: int,
|
| 514 |
+
grad_accum_steps: int,
|
| 515 |
+
) -> bool:
|
| 516 |
+
grad_accum_steps = max(1, int(grad_accum_steps))
|
| 517 |
+
natural_boundary = ((batch_idx + 1) % grad_accum_steps) == 0
|
| 518 |
+
is_last_batch = (batch_idx + 1) == num_batches
|
| 519 |
+
return natural_boundary or is_last_batch
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def save_checkpoint(path, atlas, adapter, optimizer, scheduler, global_step, epoch, args):
    """Serialize training state (model weights moved to CPU, optimizer,
    optional scheduler/adapter, and run args) to `path`, creating parent
    directories as needed."""
    def _cpu_state(module):
        # Move every tensor to CPU so checkpoints load on any device.
        return {name: tensor.cpu() for name, tensor in module.state_dict().items()}

    payload = {
        "global_step": global_step,
        "epoch": epoch,
        "args": vars(args),
        "atlas_state_dict": _cpu_state(atlas),
        "optimizer": optimizer.state_dict(),
        "scheduler": None if scheduler is None else scheduler.state_dict(),
    }
    if adapter is not None:
        payload["adapter_state_dict"] = _cpu_state(adapter)
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, path)
def cleanup_old_checkpoints(output_dir: Path, keep_n: int):
    """Delete old epoch-* checkpoint dirs, keeping only the most recent *keep_n*."""
    if keep_n <= 0:
        return
    import shutil

    def _epoch_number(ckpt_dir):
        return int(ckpt_dir.name.split("-")[1])

    epoch_dirs = sorted(
        (d for d in output_dir.iterdir() if d.is_dir() and d.name.startswith("epoch-")),
        key=_epoch_number,
    )
    # Everything except the last keep_n entries (by epoch number) is stale.
    for old in epoch_dirs[:-keep_n]:
        shutil.rmtree(old, ignore_errors=True)
        logger.info("Deleted old checkpoint: %s", old)
def main():
|
| 553 |
+
args = parse_args()
|
| 554 |
+
class _FlushHandler(logging.StreamHandler):
|
| 555 |
+
def emit(self, record):
|
| 556 |
+
super().emit(record)
|
| 557 |
+
self.flush()
|
| 558 |
+
logging.root.handlers.clear()
|
| 559 |
+
_h = _FlushHandler(sys.stderr)
|
| 560 |
+
_fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
| 561 |
+
_h.setFormatter(_fmt)
|
| 562 |
+
logging.root.addHandler(_h)
|
| 563 |
+
logging.root.setLevel(logging.INFO)
|
| 564 |
+
|
| 565 |
+
_validate_visual_token_mode(args)
|
| 566 |
+
|
| 567 |
+
device, distributed, rank, world_size = setup_distributed(args.local_rank)
|
| 568 |
+
set_seed(args.seed + rank)
|
| 569 |
+
_main = is_main_process(distributed, rank)
|
| 570 |
+
|
| 571 |
+
is_online = args.visual_token_mode == "online"
|
| 572 |
+
|
| 573 |
+
output_dir = Path(args.output_dir)
|
| 574 |
+
if _main:
|
| 575 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 576 |
+
_fh = logging.FileHandler(str(output_dir / "train.log"), mode="a")
|
| 577 |
+
_fh.setFormatter(_fmt)
|
| 578 |
+
logging.root.addHandler(_fh)
|
| 579 |
+
with open(output_dir / "args.json", "w") as f:
|
| 580 |
+
json.dump(vars(args), f, indent=2)
|
| 581 |
+
|
| 582 |
+
if _main:
|
| 583 |
+
logger.info("Loading tokenizer: %s", args.llm_model)
|
| 584 |
+
tokenizer = load_tokenizer(args.llm_model)
|
| 585 |
+
if "<query>" not in tokenizer.get_vocab():
|
| 586 |
+
tokenizer.add_tokens(["<query>"])
|
| 587 |
+
|
| 588 |
+
_precomp_det = args.precomputed_det_tokens if not is_online else None
|
| 589 |
+
_precomp_map = args.precomputed_map_tokens if not is_online else None
|
| 590 |
+
dataset = AtlasDataset(
|
| 591 |
+
json_file=args.data_json,
|
| 592 |
+
image_root=args.data_root,
|
| 593 |
+
tokenizer=tokenizer,
|
| 594 |
+
max_length=args.max_length,
|
| 595 |
+
is_training=True,
|
| 596 |
+
planning_table3_mode=args.planning_table3_mode,
|
| 597 |
+
image_path_remap=args.image_path_remap,
|
| 598 |
+
precomputed_det_tokens=_precomp_det,
|
| 599 |
+
precomputed_map_tokens=_precomp_map,
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
if is_online:
|
| 603 |
+
scene_groups = dataset.get_scene_groups()
|
| 604 |
+
sampler = SceneSequentialSampler(
|
| 605 |
+
scene_groups,
|
| 606 |
+
num_replicas=world_size,
|
| 607 |
+
rank=rank,
|
| 608 |
+
seed=args.seed,
|
| 609 |
+
pad_to_multiple=args.gradient_accumulation_steps,
|
| 610 |
+
)
|
| 611 |
+
if _main:
|
| 612 |
+
logger.info("Online mode: SceneSequentialSampler (%d scenes, %d samples, world=%d)",
|
| 613 |
+
len(scene_groups), len(dataset), world_size)
|
| 614 |
+
else:
|
| 615 |
+
from torch.utils.data import DistributedSampler
|
| 616 |
+
sampler = DistributedSampler(dataset, shuffle=True) if distributed else None
|
| 617 |
+
|
| 618 |
+
collate_fn = make_atlas_collate_fn(tokenizer.pad_token_id)
|
| 619 |
+
dataloader = DataLoader(
|
| 620 |
+
dataset,
|
| 621 |
+
batch_size=args.batch_size,
|
| 622 |
+
shuffle=(not is_online and sampler is None),
|
| 623 |
+
sampler=sampler,
|
| 624 |
+
num_workers=args.num_workers,
|
| 625 |
+
collate_fn=collate_fn,
|
| 626 |
+
pin_memory=True,
|
| 627 |
+
drop_last=not is_online,
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
streampetr_model = load_frozen_encoder(
|
| 631 |
+
args.streampetr_config, args.streampetr_ckpt, "streampetr", device
|
| 632 |
+
)
|
| 633 |
+
topomlp_model = load_frozen_encoder(
|
| 634 |
+
args.topomlp_config, args.topomlp_ckpt, "topomlp", device
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
topomlp_adapter = None
|
| 638 |
+
if topomlp_model is not None or _precomp_map:
|
| 639 |
+
_tp_bev_range = (-51.2, -25.6, -8.0, 51.2, 25.6, 4.0)
|
| 640 |
+
if args.topomlp_config:
|
| 641 |
+
try:
|
| 642 |
+
from mmcv import Config as _Cfg
|
| 643 |
+
_tp_cfg = _Cfg.fromfile(args.topomlp_config)
|
| 644 |
+
if hasattr(_tp_cfg, "point_cloud_range"):
|
| 645 |
+
_tp_bev_range = tuple(float(v) for v in _tp_cfg.point_cloud_range)
|
| 646 |
+
logger.info("TopoMLP bev_range from config: %s", _tp_bev_range)
|
| 647 |
+
except Exception as e:
|
| 648 |
+
logger.warning("Failed to read point_cloud_range from TopoMLP config: %s. Using default: %s", e, _tp_bev_range)
|
| 649 |
+
topomlp_adapter = TopoMLPToAtlasMapTokens(
|
| 650 |
+
num_map_tokens=args.num_map_queries,
|
| 651 |
+
hidden_size=args.visual_hidden_size,
|
| 652 |
+
bev_range=_tp_bev_range,
|
| 653 |
+
).to(device)
|
| 654 |
+
|
| 655 |
+
dtype = torch.float32
|
| 656 |
+
if args.bf16:
|
| 657 |
+
dtype = torch.bfloat16
|
| 658 |
+
elif args.fp16:
|
| 659 |
+
dtype = torch.float16
|
| 660 |
+
|
| 661 |
+
if args.load_in_4bit:
|
| 662 |
+
dm = {"": device} if distributed else "auto"
|
| 663 |
+
else:
|
| 664 |
+
dm = None
|
| 665 |
+
|
| 666 |
+
_ds_bf16 = False
|
| 667 |
+
_ds_fp16 = False
|
| 668 |
+
if args.deepspeed:
|
| 669 |
+
with open(args.deepspeed) as _f:
|
| 670 |
+
_ds_cfg_peek = json.load(_f)
|
| 671 |
+
_ds_bf16 = _ds_cfg_peek.get("bf16", {}).get("enabled", False)
|
| 672 |
+
_ds_fp16 = _ds_cfg_peek.get("fp16", {}).get("enabled", False)
|
| 673 |
+
_use_half = args.bf16 or args.fp16 or _ds_bf16 or _ds_fp16
|
| 674 |
+
if _use_half and dtype == torch.float32:
|
| 675 |
+
dtype = torch.bfloat16 if (args.bf16 or _ds_bf16) else torch.float16
|
| 676 |
+
|
| 677 |
+
atlas = AtlasForCausalLM(
|
| 678 |
+
llm_model_name=args.llm_model,
|
| 679 |
+
visual_hidden_size=args.visual_hidden_size,
|
| 680 |
+
num_queries=args.num_det_queries,
|
| 681 |
+
num_map_queries=args.num_map_queries,
|
| 682 |
+
load_in_4bit=args.load_in_4bit,
|
| 683 |
+
use_flash_attention=_use_half,
|
| 684 |
+
device_map=dm,
|
| 685 |
+
torch_dtype=dtype,
|
| 686 |
+
use_lora=args.use_lora,
|
| 687 |
+
lora_r=args.lora_r,
|
| 688 |
+
lora_alpha=args.lora_alpha,
|
| 689 |
+
lora_dropout=args.lora_dropout,
|
| 690 |
+
)
|
| 691 |
+
atlas.resize_token_embeddings(len(tokenizer))
|
| 692 |
+
query_token_id = tokenizer.convert_tokens_to_ids("<query>")
|
| 693 |
+
atlas.set_query_token_id(query_token_id)
|
| 694 |
+
if topomlp_adapter is not None:
|
| 695 |
+
atlas.topomlp_adapter = topomlp_adapter
|
| 696 |
+
if dm is None and args.deepspeed is None:
|
| 697 |
+
atlas = atlas.to(device)
|
| 698 |
+
atlas.gradient_checkpointing_enable()
|
| 699 |
+
|
| 700 |
+
num_batches_per_epoch = len(dataloader)
|
| 701 |
+
steps_per_epoch = _optimizer_steps_per_epoch(
|
| 702 |
+
num_batches_per_epoch, args.gradient_accumulation_steps
|
| 703 |
+
)
|
| 704 |
+
total_steps = steps_per_epoch * args.epochs
|
| 705 |
+
warmup_steps = int(total_steps * args.warmup_ratio)
|
| 706 |
+
|
| 707 |
+
global_step = 0
|
| 708 |
+
start_epoch = 0
|
| 709 |
+
|
| 710 |
+
_resume_ckpt = None
|
| 711 |
+
if args.resume:
|
| 712 |
+
_resume_ckpt = torch.load(args.resume, map_location="cpu")
|
| 713 |
+
if "atlas_state_dict" not in _resume_ckpt:
|
| 714 |
+
raise RuntimeError(f"Checkpoint missing 'atlas_state_dict'. Keys: {list(_resume_ckpt.keys())}")
|
| 715 |
+
missing, _ = atlas.load_state_dict(_resume_ckpt["atlas_state_dict"], strict=False)
|
| 716 |
+
if _main and missing:
|
| 717 |
+
logger.warning("Resume: %d missing keys (first 10): %s", len(missing), missing[:10])
|
| 718 |
+
if topomlp_adapter is not None and "adapter_state_dict" in _resume_ckpt:
|
| 719 |
+
_m, _u = topomlp_adapter.load_state_dict(_resume_ckpt["adapter_state_dict"], strict=False)
|
| 720 |
+
if _main and _u:
|
| 721 |
+
logger.info("Adapter resume: ignored %d legacy keys: %s", len(_u), _u[:5])
|
| 722 |
+
global_step = _resume_ckpt.get("global_step", 0)
|
| 723 |
+
start_epoch = _resume_ckpt.get("epoch", 0)
|
| 724 |
+
if _main:
|
| 725 |
+
logger.info("Resumed from %s (step=%d, epoch=%d)", args.resume, global_step, start_epoch)
|
| 726 |
+
|
| 727 |
+
use_deepspeed = args.deepspeed is not None
|
| 728 |
+
if use_deepspeed:
|
| 729 |
+
import deepspeed
|
| 730 |
+
ds_config = json.load(open(args.deepspeed))
|
| 731 |
+
ds_config["optimizer"] = {
|
| 732 |
+
"type": "Adam",
|
| 733 |
+
"params": {
|
| 734 |
+
"lr": args.lr, "weight_decay": args.weight_decay,
|
| 735 |
+
"betas": [0.9, 0.999], "torch_adam": True, "adam_w_mode": True,
|
| 736 |
+
},
|
| 737 |
+
}
|
| 738 |
+
ds_config["scheduler"] = {
|
| 739 |
+
"type": "WarmupCosineLR",
|
| 740 |
+
"params": {
|
| 741 |
+
"total_num_steps": total_steps,
|
| 742 |
+
"warmup_num_steps": warmup_steps,
|
| 743 |
+
"warmup_type": "linear",
|
| 744 |
+
},
|
| 745 |
+
}
|
| 746 |
+
ds_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
|
| 747 |
+
ds_config["train_micro_batch_size_per_gpu"] = args.batch_size
|
| 748 |
+
ds_config["train_batch_size"] = args.batch_size * args.gradient_accumulation_steps * world_size
|
| 749 |
+
|
| 750 |
+
ds_bf16 = ds_config.get("bf16", {}).get("enabled", False)
|
| 751 |
+
ds_fp16 = ds_config.get("fp16", {}).get("enabled", False)
|
| 752 |
+
if ds_bf16:
|
| 753 |
+
atlas.to(device=device, dtype=torch.bfloat16)
|
| 754 |
+
elif ds_fp16:
|
| 755 |
+
atlas.to(device=device, dtype=torch.float16)
|
| 756 |
+
else:
|
| 757 |
+
atlas.to(device)
|
| 758 |
+
|
| 759 |
+
all_params = atlas.get_trainable_param_groups(args.lr, weight_decay=args.weight_decay)
|
| 760 |
+
if topomlp_adapter is not None:
|
| 761 |
+
_adapter_trainable = [p for p in topomlp_adapter.parameters() if p.requires_grad]
|
| 762 |
+
if _adapter_trainable:
|
| 763 |
+
all_params.append({"params": _adapter_trainable, "lr": args.lr, "weight_decay": 0.0})
|
| 764 |
+
|
| 765 |
+
atlas_ddp, optimizer, _, scheduler = deepspeed.initialize(
|
| 766 |
+
model=atlas, model_parameters=all_params,
|
| 767 |
+
config=ds_config, dist_init_required=False,
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
if _resume_ckpt is not None and "optimizer" in _resume_ckpt:
|
| 771 |
+
try:
|
| 772 |
+
optimizer.load_state_dict(_resume_ckpt["optimizer"])
|
| 773 |
+
if _main:
|
| 774 |
+
logger.info("Restored DeepSpeed optimizer state from checkpoint")
|
| 775 |
+
except Exception as e:
|
| 776 |
+
if _main:
|
| 777 |
+
logger.warning("Failed to restore DeepSpeed optimizer state: %s", e)
|
| 778 |
+
|
| 779 |
+
if global_step > 0 and scheduler is not None:
|
| 780 |
+
for _ in range(global_step):
|
| 781 |
+
scheduler.step()
|
| 782 |
+
_ff_lr = scheduler.get_lr()
|
| 783 |
+
if _main:
|
| 784 |
+
logger.info(
|
| 785 |
+
"Fast-forwarded DeepSpeed LR scheduler to step %d (lr=%s)",
|
| 786 |
+
global_step,
|
| 787 |
+
[f"{x:.6e}" for x in _ff_lr] if isinstance(_ff_lr, (list, tuple)) else f"{_ff_lr:.6e}",
|
| 788 |
+
)
|
| 789 |
+
else:
|
| 790 |
+
param_groups = atlas.get_trainable_param_groups(args.lr, weight_decay=args.weight_decay)
|
| 791 |
+
if topomlp_adapter is not None:
|
| 792 |
+
_adapter_trainable = [p for p in topomlp_adapter.parameters() if p.requires_grad]
|
| 793 |
+
if _adapter_trainable:
|
| 794 |
+
param_groups.append({"params": _adapter_trainable, "lr": args.lr, "weight_decay": 0.0})
|
| 795 |
+
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.weight_decay)
|
| 796 |
+
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
|
| 797 |
+
if distributed:
|
| 798 |
+
atlas_ddp = torch.nn.parallel.DistributedDataParallel(
|
| 799 |
+
atlas, device_ids=[args.local_rank], find_unused_parameters=True,
|
| 800 |
+
)
|
| 801 |
+
else:
|
| 802 |
+
atlas_ddp = atlas
|
| 803 |
+
|
| 804 |
+
if _resume_ckpt is not None and not use_deepspeed:
|
| 805 |
+
if "optimizer" in _resume_ckpt:
|
| 806 |
+
try:
|
| 807 |
+
optimizer.load_state_dict(_resume_ckpt["optimizer"])
|
| 808 |
+
except Exception as e:
|
| 809 |
+
if _main:
|
| 810 |
+
logger.warning("Failed to restore optimizer state: %s", e)
|
| 811 |
+
if "scheduler" in _resume_ckpt and _resume_ckpt["scheduler"] is not None:
|
| 812 |
+
try:
|
| 813 |
+
scheduler.load_state_dict(_resume_ckpt["scheduler"])
|
| 814 |
+
except Exception as e:
|
| 815 |
+
if _main:
|
| 816 |
+
logger.warning("Failed to restore scheduler state: %s", e)
|
| 817 |
+
_resume_ckpt = None
|
| 818 |
+
|
| 819 |
+
atlas_ddp.train()
|
| 820 |
+
if topomlp_adapter is not None:
|
| 821 |
+
topomlp_adapter.train()
|
| 822 |
+
|
| 823 |
+
if _main:
|
| 824 |
+
logger.info("=== Training Config ===")
|
| 825 |
+
logger.info(" epochs: %d, lr: %s, batch: %d, accum: %d",
|
| 826 |
+
args.epochs, args.lr, args.batch_size, args.gradient_accumulation_steps)
|
| 827 |
+
logger.info(" total_steps: %d, warmup_steps: %d", total_steps, warmup_steps)
|
| 828 |
+
logger.info(" use_lora: %s, load_in_4bit: %s, fp16: %s, bf16: %s, deepspeed: %s",
|
| 829 |
+
args.use_lora, args.load_in_4bit, args.fp16, args.bf16, use_deepspeed)
|
| 830 |
+
n_trainable = sum(p.numel() for p in atlas.parameters() if p.requires_grad)
|
| 831 |
+
if topomlp_adapter is not None:
|
| 832 |
+
n_trainable += sum(p.numel() for p in topomlp_adapter.parameters() if p.requires_grad)
|
| 833 |
+
logger.info(" trainable params: %s", f"{n_trainable:,}")
|
| 834 |
+
logger.info(" visual_token_mode: %s", args.visual_token_mode)
|
| 835 |
+
logger.info(" streampetr: %s", "online-temporal" if (is_online and streampetr_model) else ("loaded" if streampetr_model else ("precomputed" if _precomp_det else "NONE (should not happen)")))
|
| 836 |
+
logger.info(" topomlp: %s", "online" if (is_online and topomlp_model) else ("loaded" if topomlp_model else ("precomputed" if _precomp_map else "NONE (should not happen)")))
|
| 837 |
+
logger.info("=======================")
|
| 838 |
+
|
| 839 |
+
streaming_state = {} if is_online else None
|
| 840 |
+
|
| 841 |
+
for epoch in range(start_epoch, args.epochs):
|
| 842 |
+
if sampler is not None:
|
| 843 |
+
sampler.set_epoch(epoch)
|
| 844 |
+
|
| 845 |
+
if streaming_state is not None:
|
| 846 |
+
streaming_state.clear()
|
| 847 |
+
if streampetr_model is not None:
|
| 848 |
+
streampetr_model.pts_bbox_head.reset_memory()
|
| 849 |
+
|
| 850 |
+
epoch_loss = 0.0
|
| 851 |
+
num_batches = 0
|
| 852 |
+
t0 = time.time()
|
| 853 |
+
|
| 854 |
+
if not use_deepspeed:
|
| 855 |
+
optimizer.zero_grad()
|
| 856 |
+
|
| 857 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 858 |
+
do_step = _is_optimizer_step_batch(
|
| 859 |
+
batch_idx, num_batches_per_epoch, args.gradient_accumulation_steps
|
| 860 |
+
)
|
| 861 |
+
accum_window_size = _accum_window_size_for_batch(
|
| 862 |
+
batch_idx, num_batches_per_epoch, args.gradient_accumulation_steps
|
| 863 |
+
)
|
| 864 |
+
scaled_loss = None
|
| 865 |
+
if use_deepspeed:
|
| 866 |
+
if not hasattr(atlas_ddp, "set_gradient_accumulation_boundary"):
|
| 867 |
+
raise RuntimeError(
|
| 868 |
+
"DeepSpeed engine is missing set_gradient_accumulation_boundary(); "
|
| 869 |
+
"cannot enforce epoch-tail flush semantics."
|
| 870 |
+
)
|
| 871 |
+
atlas_ddp.set_gradient_accumulation_boundary(do_step)
|
| 872 |
+
|
| 873 |
+
if _main and batch_idx < 5:
|
| 874 |
+
_has_map = "precomputed_map" in batch
|
| 875 |
+
_nq = int((batch["input_ids"] == query_token_id).sum(dim=-1).max().item()) if query_token_id else 0
|
| 876 |
+
logger.info("[DBG] batch_idx=%d nq=%d has_map=%s sid=%s mode=%s",
|
| 877 |
+
batch_idx, _nq, _has_map,
|
| 878 |
+
batch.get("sample_id", ["?"])[0][:20],
|
| 879 |
+
args.visual_token_mode)
|
| 880 |
+
for _handler in logging.root.handlers:
|
| 881 |
+
_handler.flush()
|
| 882 |
+
|
| 883 |
+
input_ids = batch["input_ids"].to(device)
|
| 884 |
+
attention_mask = batch["attention_mask"].to(device)
|
| 885 |
+
labels = batch["labels"].to(device)
|
| 886 |
+
|
| 887 |
+
visual_features = extract_visual_tokens(
|
| 888 |
+
streampetr_model, topomlp_model, topomlp_adapter,
|
| 889 |
+
batch, device, args.num_det_queries, args.visual_hidden_size,
|
| 890 |
+
query_token_id=query_token_id,
|
| 891 |
+
visual_token_mode=args.visual_token_mode,
|
| 892 |
+
streaming_state=streaming_state,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
if _main and batch_idx < 5:
|
| 896 |
+
logger.info("[DBG] vis_keys=%s", list(visual_features.keys()))
|
| 897 |
+
for _handler in logging.root.handlers:
|
| 898 |
+
_handler.flush()
|
| 899 |
+
|
| 900 |
+
if _main and batch_idx < 5:
|
| 901 |
+
logger.info("[DBG] pre-forward batch_idx=%d seqlen=%d", batch_idx, input_ids.shape[1])
|
| 902 |
+
for _handler in logging.root.handlers:
|
| 903 |
+
_handler.flush()
|
| 904 |
+
|
| 905 |
+
outputs = atlas_ddp(
|
| 906 |
+
input_ids=input_ids,
|
| 907 |
+
attention_mask=attention_mask,
|
| 908 |
+
visual_features=visual_features,
|
| 909 |
+
labels=labels,
|
| 910 |
+
)
|
| 911 |
+
loss = outputs.loss
|
| 912 |
+
|
| 913 |
+
if _main and batch_idx < 5:
|
| 914 |
+
logger.info("[DBG] post-forward batch_idx=%d loss=%.4f", batch_idx, loss.item())
|
| 915 |
+
for _handler in logging.root.handlers:
|
| 916 |
+
_handler.flush()
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
if _main and batch_idx < 5:
|
| 921 |
+
logger.info("[DBG] pre-backward batch_idx=%d", batch_idx)
|
| 922 |
+
for _handler in logging.root.handlers:
|
| 923 |
+
_handler.flush()
|
| 924 |
+
|
| 925 |
+
if use_deepspeed:
|
| 926 |
+
scaled_loss = loss / accum_window_size
|
| 927 |
+
atlas_ddp.backward(scaled_loss, scale_wrt_gas=False)
|
| 928 |
+
|
| 929 |
+
if _main and batch_idx < 5:
|
| 930 |
+
logger.info("[DBG] pre-step batch_idx=%d", batch_idx)
|
| 931 |
+
for _handler in logging.root.handlers:
|
| 932 |
+
_handler.flush()
|
| 933 |
+
|
| 934 |
+
atlas_ddp.step()
|
| 935 |
+
else:
|
| 936 |
+
scaled_loss = loss / accum_window_size
|
| 937 |
+
scaled_loss.backward()
|
| 938 |
+
|
| 939 |
+
if not use_deepspeed and distributed and topomlp_adapter is not None:
|
| 940 |
+
for p in topomlp_adapter.parameters():
|
| 941 |
+
if p.requires_grad and p.grad is not None:
|
| 942 |
+
dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
|
| 943 |
+
p.grad.div_(world_size)
|
| 944 |
+
|
| 945 |
+
epoch_loss += loss.item()
|
| 946 |
+
num_batches += 1
|
| 947 |
+
if _main and num_batches <= 3:
|
| 948 |
+
logger.info("batch=%d loss=%.4f", num_batches, loss.item())
|
| 949 |
+
for _handler in logging.root.handlers:
|
| 950 |
+
_handler.flush()
|
| 951 |
+
|
| 952 |
+
if do_step:
|
| 953 |
+
if not use_deepspeed:
|
| 954 |
+
all_params = list(atlas.parameters()) + (
|
| 955 |
+
list(topomlp_adapter.parameters()) if topomlp_adapter is not None else []
|
| 956 |
+
)
|
| 957 |
+
trainable = [p for p in all_params if p.requires_grad]
|
| 958 |
+
torch.nn.utils.clip_grad_norm_(trainable, args.max_grad_norm)
|
| 959 |
+
optimizer.step()
|
| 960 |
+
scheduler.step()
|
| 961 |
+
optimizer.zero_grad()
|
| 962 |
+
global_step += 1
|
| 963 |
+
|
| 964 |
+
if _main and global_step % args.log_steps == 0:
|
| 965 |
+
# DeepSpeed LR scheduler may not expose get_last_lr() before first scheduler.step().
|
| 966 |
+
if use_deepspeed and hasattr(atlas_ddp, "get_lr"):
|
| 967 |
+
try:
|
| 968 |
+
_lrs = atlas_ddp.get_lr()
|
| 969 |
+
if isinstance(_lrs, (list, tuple)) and len(_lrs) > 0:
|
| 970 |
+
lr_now = float(_lrs[0])
|
| 971 |
+
else:
|
| 972 |
+
lr_now = float(_lrs)
|
| 973 |
+
except Exception:
|
| 974 |
+
lr_now = optimizer.param_groups[0]["lr"] if getattr(optimizer, "param_groups", None) else args.lr
|
| 975 |
+
elif hasattr(scheduler, "get_last_lr"):
|
| 976 |
+
try:
|
| 977 |
+
lr_now = scheduler.get_last_lr()[0]
|
| 978 |
+
except Exception:
|
| 979 |
+
lr_now = optimizer.param_groups[0]["lr"] if getattr(optimizer, "param_groups", None) else args.lr
|
| 980 |
+
else:
|
| 981 |
+
lr_now = args.lr
|
| 982 |
+
elapsed = time.time() - t0
|
| 983 |
+
samples_sec = num_batches * args.batch_size / max(elapsed, 1e-6)
|
| 984 |
+
avg_loss = epoch_loss / max(num_batches, 1)
|
| 985 |
+
logger.info(
|
| 986 |
+
"epoch=%d step=%d loss=%.4f lr=%.2e samples/s=%.1f",
|
| 987 |
+
epoch, global_step, avg_loss, lr_now, samples_sec,
|
| 988 |
+
)
|
| 989 |
+
for _handler in logging.root.handlers:
|
| 990 |
+
_handler.flush()
|
| 991 |
+
|
| 992 |
+
if _main and args.save_steps > 0 and global_step % args.save_steps == 0:
|
| 993 |
+
ckpt_path = output_dir / f"checkpoint-{global_step}" / "checkpoint.pt"
|
| 994 |
+
save_checkpoint(ckpt_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, epoch, args)
|
| 995 |
+
logger.info("Saved step checkpoint: %s", ckpt_path)
|
| 996 |
+
|
| 997 |
+
avg_loss = epoch_loss / max(num_batches, 1)
|
| 998 |
+
if _main:
|
| 999 |
+
logger.info("Epoch %d done — avg_loss=%.4f (%.1f min)", epoch, avg_loss, (time.time() - t0) / 60)
|
| 1000 |
+
|
| 1001 |
+
if _main and (epoch + 1) % args.save_epochs == 0:
|
| 1002 |
+
ckpt_path = output_dir / f"epoch-{epoch}" / "checkpoint.pt"
|
| 1003 |
+
save_checkpoint(ckpt_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, epoch + 1, args)
|
| 1004 |
+
logger.info("Saved epoch checkpoint: %s", ckpt_path)
|
| 1005 |
+
if args.keep_last_n_ckpts > 0:
|
| 1006 |
+
cleanup_old_checkpoints(output_dir, args.keep_last_n_ckpts)
|
| 1007 |
+
|
| 1008 |
+
if _main:
|
| 1009 |
+
final_path = output_dir / "final" / "checkpoint.pt"
|
| 1010 |
+
save_checkpoint(final_path, atlas, topomlp_adapter, optimizer, scheduler, global_step, args.epochs, args)
|
| 1011 |
+
logger.info("Training complete. Final checkpoint: %s", final_path)
|
| 1012 |
+
|
| 1013 |
+
if distributed:
|
| 1014 |
+
dist.destroy_process_group()
|
| 1015 |
+
|
| 1016 |
+
|
| 1017 |
+
# Entry point: invoke the training driver only when run as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|