# Atlas-online / scripts / gen_atlas_full_data.py
# (Hugging Face upload-page header preserved as a comment: "guoyb0's picture",
#  "Add files using upload-large-folder tool", commit 7dfc72e verified)
#!/usr/bin/env python3
"""Generate Atlas format detection data from nuScenes."""
import json
import os
import sys
import argparse
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import numpy as np
from tqdm import tqdm
# Z range aligned with StreamPETR point_cloud_range [-5, 3]
# This ensures bin utilization matches the actual detection range
Z_MIN, Z_MAX = -5.0, 3.0
# Short class names, indexed by detection class ID (0-9).
NUSCENES_CLASSES = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]
# Full mapping of nuScenes category names to class IDs.
NUSCENES_CATEGORY_MAP = {
    # Short base names
    'car': 0, 'truck': 1, 'construction_vehicle': 2, 'bus': 3, 'trailer': 4,
    'barrier': 5, 'motorcycle': 6, 'bicycle': 7, 'pedestrian': 8, 'traffic_cone': 9,
    # Full nuScenes category names - vehicles
    'vehicle.car': 0, 'vehicle.truck': 1, 'vehicle.construction': 2,
    'vehicle.bus.bendy': 3, 'vehicle.bus.rigid': 3, 'vehicle.trailer': 4,
    'vehicle.motorcycle': 6, 'vehicle.bicycle': 7,
    # Full nuScenes category names - pedestrians (all subtypes)
    'human.pedestrian.adult': 8, 'human.pedestrian.child': 8,
    'human.pedestrian.construction_worker': 8, 'human.pedestrian.police_officer': 8,
    'human.pedestrian.wheelchair': 8, 'human.pedestrian.stroller': 8,
    'human.pedestrian.personal_mobility': 8,
    # Full nuScenes category names - movable objects
    'movable_object.barrier': 5, 'movable_object.trafficcone': 9,
    'movable_object.traffic_cone': 9,
}
def get_category_id(category_name: str) -> int:
    """Map a nuScenes category name to a detection class ID.

    Returns:
        int: Class ID (0-9) if found, -1 if unknown category.
    """
    key = category_name.lower().strip()
    # Exact lookup first: covers both short names and full nuScenes names.
    try:
        return NUSCENES_CATEGORY_MAP[key]
    except KeyError:
        pass
    # Prefix match: every human.pedestrian.* subtype collapses to pedestrian.
    if key.startswith('human.pedestrian.'):
        return 8  # pedestrian
    # Prefix match: every vehicle.bus.* subtype collapses to bus.
    if key.startswith('vehicle.bus.'):
        return 3  # bus
    # Deliberately no broad substring matching, so e.g.
    # static_object.bicycle_rack is NOT mistaken for bicycle.
    # Unknown categories yield -1; callers must filter or handle it.
    return -1
def coord_to_bin(value: float, min_val: float = -51.2, max_val: float = 51.2, num_bins: int = 1000) -> int:
    """Quantize a continuous coordinate into a discrete bin index.

    The value is clamped to [min_val, max_val] and mapped linearly onto
    the index range 0 .. num_bins - 1.
    """
    clamped = max(min_val, min(max_val, float(value)))
    frac = (clamped - min_val) / (max_val - min_val)
    bin_idx = int(round(frac * (num_bins - 1)))
    # Guard against any rounding spill at the extremes.
    return max(0, min(num_bins - 1, bin_idx))
def nuscenes_to_paper_xy(x_fwd: float, y_left: float) -> Tuple[float, float]:
    """Convert nuScenes ego XY (x forward, y left) into the paper frame
    (x right, y forward): paper_x = -y_left, paper_y = x_fwd."""
    paper_x = -float(y_left)
    paper_y = float(x_fwd)
    return (paper_x, paper_y)
def _clamp(v: float, lo: float, hi: float) -> float:
    """Return v limited to the closed interval [lo, hi]."""
    return min(hi, max(lo, v))
def yaw_nuscenes_to_paper(yaw_n: float) -> float:
    """Convert a nuScenes-frame yaw to the paper frame.

    Applies the +90-degree rotation between the two frames, then wraps
    the result into [-pi, pi).
    """
    shifted = float(yaw_n) + np.pi / 2.0
    wrapped = np.mod(shifted + np.pi, 2.0 * np.pi) - np.pi
    return float(wrapped)
def get_prompt_template() -> str:
    """Return the detection prompt, preferring the project's sampler.

    Tries src.prompting.sample_prompt("detection"); falls back to a fixed
    prompt string when the project package is not importable.
    """
    # Make the project root importable, but only insert it once —
    # the previous version re-inserted the path on every call, growing
    # sys.path unboundedly, and shadow-imported sys under an alias even
    # though sys is already imported at module level.
    project_root = os.path.join(os.path.dirname(__file__), "..")
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    try:
        from src.prompting import sample_prompt
        return sample_prompt("detection")
    except ImportError:
        # Fallback prompt: kept byte-identical to the training prompt.
        return (
            "There are six images captured by the surround view cameras in driving vehicle. "
            "They are uniformly represented as queries embeddings<query>. "
            "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
            "Please complete the visual detection task under the Bird's Eye View (BEV) perspective. "
            "Ensure that the detection range does not exceed 50 meters."
        )
def filter_boxes_in_range(
    boxes: List,
    max_range: float = 50.0,
    z_min: float = -5.0,
    z_max: float = 3.0,
) -> List:
    """Keep only boxes whose center lies inside the detection volume.

    Args:
        boxes: Box-like objects exposing a .center (x, y, z) triple.
        max_range: Maximum XY distance from origin (meters), inclusive.
        z_min: Minimum Z coordinate (meters), default matches StreamPETR point_cloud_range
        z_max: Maximum Z coordinate (meters), default matches StreamPETR point_cloud_range
    """
    def _inside(box) -> bool:
        x, y, z = box.center
        return np.sqrt(x ** 2 + y ** 2) <= max_range and z_min <= z <= z_max

    return [box for box in boxes if _inside(box)]
def _to_short_name(category: str) -> str:
    """Collapse a (possibly full) nuScenes category name to its short form.

    Unknown categories are returned unchanged.
    """
    class_id = get_category_id(category)
    known = 0 <= class_id < len(NUSCENES_CLASSES)
    return NUSCENES_CLASSES[class_id] if known else category
def format_detection_answer(boxes_with_labels: List[Tuple[str, List[int]]]) -> str:
    """Format detections as paper Figure 5: category: [pt], [pt]; category: [pt]."""
    if not boxes_with_labels:
        return "No objects detected within range."
    from collections import OrderedDict
    by_category: OrderedDict = OrderedDict()
    for raw_name, bins in boxes_with_labels:
        by_category.setdefault(_to_short_name(raw_name), []).append(bins)
    segments = []
    for short_name, points in by_category.items():
        # Nearest-first within each category: sort by squared bin distance
        # to the ego center bin (500, 500).
        points.sort(key=lambda p: (p[0] - 500) ** 2 + (p[1] - 500) ** 2)
        rendered = ", ".join(f"[{p[0]}, {p[1]}, {p[2]}]" for p in points)
        segments.append(f"{short_name}: {rendered}")
    return "; ".join(segments) + "."
def process_nuscenes_sample(nusc, sample_token: str, data_root: Path) -> Optional[Dict]:
    """Convert one nuScenes sample into an Atlas conversation record.

    Loads the six surround-view image paths, transforms all annotation
    boxes from the global frame into the ego frame, filters them to the
    detection volume, bins their centers, and builds the prompt/answer
    conversation pair plus raw 3D ground-truth boxes.

    Returns:
        The Atlas data dict, or None when the sample is unusable
        (missing cameras) or any step raises.

    NOTE(review): data_root is currently unused; image paths stay
    relative to the dataset root.
    """
    try:
        from nuscenes.utils.data_classes import Box
        # NOTE(review): transform_matrix is imported but never used below.
        from nuscenes.utils.geometry_utils import transform_matrix
        from pyquaternion import Quaternion
        sample = nusc.get('sample', sample_token)
        # Fixed camera order; downstream consumers assume this layout.
        camera_names = [
            'CAM_FRONT',
            'CAM_FRONT_RIGHT',
            'CAM_FRONT_LEFT',
            'CAM_BACK',
            'CAM_BACK_LEFT',
            'CAM_BACK_RIGHT'
        ]
        image_paths = []
        for cam_name in camera_names:
            if cam_name in sample['data']:
                cam_token = sample['data'][cam_name]
                cam_data = nusc.get('sample_data', cam_token)
                # Normalize Windows-style separators in stored filenames.
                img_path = cam_data['filename'].replace('\\', '/')
                image_paths.append(img_path)
        # Require the full 6-camera rig; skip partial samples.
        if len(image_paths) != 6:
            return None
        # Ego pose is taken at the LIDAR_TOP timestamp.
        lidar_token = sample['data']['LIDAR_TOP']
        lidar_data = nusc.get('sample_data', lidar_token)
        ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
        ego_translation = ego_pose['translation']
        ego_rotation = Quaternion(ego_pose['rotation'])
        boxes = []
        for ann_token in sample['anns']:
            ann = nusc.get('sample_annotation', ann_token)
            box_global = Box(
                ann['translation'],
                ann['size'],
                Quaternion(ann['rotation']),
                name=ann['category_name'],
                token=ann['token']
            )
            # Transform box from the global frame into the ego frame:
            # translate to the ego origin, then undo the ego rotation.
            box_global.translate(-np.array(ego_translation))
            box_global.rotate(ego_rotation.inverse)
            boxes.append(box_global)
        boxes_in_range = filter_boxes_in_range(boxes, max_range=50.0)
        boxes_with_labels = []
        gt_boxes_3d = []
        for box in boxes_in_range:
            category = box.name
            # Skip unknown categories (objects outside the 10-class
            # detection task, e.g. animal).
            cat_id = get_category_id(category)
            if cat_id == -1:
                continue
            x_n, y_n, z = box.center
            x, y = nuscenes_to_paper_xy(x_n, y_n)
            # Z is already filtered by filter_boxes_in_range, no clamping needed
            x_bin = coord_to_bin(x, -51.2, 51.2, 1000)
            y_bin = coord_to_bin(y, -51.2, 51.2, 1000)
            z_bin = coord_to_bin(z, Z_MIN, Z_MAX, 1000)
            boxes_with_labels.append((category, [x_bin, y_bin, z_bin]))
            # Size/yaw extraction is best-effort: fall back to unit box
            # and zero yaw if the Box object lacks these attributes.
            try:
                w, l, h = [float(v) for v in box.wlh]
                yaw_n = float(box.orientation.yaw_pitch_roll[0])
                yaw_p = yaw_nuscenes_to_paper(yaw_n)
            except Exception:
                w, l, h = 1.0, 1.0, 1.0
                yaw_n, yaw_p = 0.0, 0.0
            # Keep both paper-frame and ego-frame boxes for traceability.
            gt_boxes_3d.append(
                {
                    "category_name": category,
                    "category_id": cat_id,
                    "box": [float(x), float(y), float(z), w, l, h, yaw_p],
                    "box_frame": "paper",
                    "box_nuscenes": [float(x_n), float(y_n), float(z), w, l, h, yaw_n],
                    "box_nuscenes_frame": "nuscenes_ego",
                }
            )
        prompt = get_prompt_template()
        answer = format_detection_answer(boxes_with_labels)
        data_item = {
            "id": sample_token,
            "image_paths": image_paths,
            "num_map_queries": 0,
            "task": "detection",
            "segment_id": sample.get("scene_token", ""),
            "timestamp": sample.get("timestamp", None),
            "gt_boxes_3d": gt_boxes_3d,
            "conversations": [
                {
                    "from": "human",
                    "value": prompt
                },
                {
                    "from": "gpt",
                    "value": answer
                }
            ]
        }
        return data_item
    except Exception as e:
        # Best-effort conversion: any failure marks the sample invalid;
        # the caller counts Nones as skipped samples.
        return None
def main():
    """Generate Atlas detection JSON from a nuScenes installation.

    Parses CLI args, loads the requested nuScenes version/split, converts
    every sample via process_nuscenes_sample, writes the JSON output
    file, and prints summary statistics.
    """
    parser = argparse.ArgumentParser(description="Generate Atlas format data")
    parser.add_argument('--version', type=str, default='v1.0-trainval',
                        choices=['v1.0-mini', 'v1.0-trainval'])
    parser.add_argument('--split', type=str, default='train',
                        choices=['train', 'val', 'trainval'])
    parser.add_argument('--data-root', type=str, default=None)
    parser.add_argument('--output', type=str, default=None)
    args = parser.parse_args()
    # Resolve data root and output path, defaulting to the project layout.
    script_dir = Path(__file__).parent.absolute()
    project_root = script_dir.parent
    if args.data_root:
        data_root = Path(args.data_root)
    else:
        data_root = project_root / "data" / "nuscenes"
    if args.output:
        output_file = Path(args.output)
    else:
        if args.version == 'v1.0-mini':
            output_file = project_root / "data" / "atlas_mini_train.json"
        else:
            output_file = project_root / "data" / f"atlas_{args.split}.json"
    print("=" * 80)
    print("Atlas Data Generation")
    print("=" * 80)
    print(f"\nProject root: {project_root}")
    print(f"Data root: {data_root}")
    print(f"Version: {args.version}")
    print(f"Output: {output_file}")
    version_dir = data_root / args.version
    if not version_dir.exists():
        print(f"\nError: nuScenes {args.version} not found at {version_dir}")
        sys.exit(1)
    try:
        from nuscenes.nuscenes import NuScenes
        print("\nImported nuscenes successfully")
    except ImportError as e:
        print(f"\nFailed to import nuscenes: {e}")
        sys.exit(1)
    print(f"\nLoading nuScenes {args.version}...")
    try:
        nusc = NuScenes(
            version=args.version,
            dataroot=str(data_root),
            verbose=True
        )
        print("Dataset loaded")
        print(f" Scenes: {len(nusc.scene)}")
        print(f" Samples: {len(nusc.sample)}")
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        sys.exit(1)
    # Select samples for the requested split; the mini set has no official
    # train/val scene split, so it is processed whole.
    if args.split == 'trainval' or args.version == 'v1.0-mini':
        samples_to_process = nusc.sample
    else:
        from nuscenes.utils.splits import create_splits_scenes
        splits = create_splits_scenes()
        if args.split == 'train':
            split_scenes = set(splits['train'])
        else:
            split_scenes = set(splits['val'])
        scene_tokens = set()
        for scene in nusc.scene:
            if scene['name'] in split_scenes:
                scene_tokens.add(scene['token'])
        samples_to_process = [s for s in nusc.sample if s['scene_token'] in scene_tokens]
    print(f"\n{args.split} samples: {len(samples_to_process)}")
    print("\nProcessing samples...")
    atlas_data = []
    failed_count = 0
    for sample in tqdm(samples_to_process, desc="Processing samples"):
        data_item = process_nuscenes_sample(nusc, sample['token'], data_root)
        if data_item is not None:
            atlas_data.append(data_item)
        else:
            failed_count += 1
    print(f"\nProcessed {len(atlas_data)} / {len(samples_to_process)} samples")
    if failed_count > 0:
        print(f"Skipped {failed_count} invalid samples")
    output_file.parent.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving to: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(atlas_data, f, indent=2, ensure_ascii=False)
    file_size_mb = output_file.stat().st_size / (1024 * 1024)
    print(f"Saved ({file_size_mb:.2f} MB)")
    print("\n" + "=" * 80)
    print("Statistics")
    print("=" * 80)
    print(f"Total samples: {len(atlas_data)}")
    if len(atlas_data) > 0:
        # BUGFIX: the previous code parsed the answer string with
        # answer.split(", "), but every "[x, y, z]" point itself contains
        # ", ", so object counts were inflated roughly 3x and category
        # names were garbled. Count from the structured gt_boxes_3d list
        # instead, which holds exactly one entry per emitted object.
        total_objects = 0
        category_counts = {}
        for item in atlas_data:
            for gt_box in item.get('gt_boxes_3d', []):
                total_objects += 1
                name = _to_short_name(gt_box.get('category_name', ''))
                category_counts[name] = category_counts.get(name, 0) + 1
        print(f"Total objects: {total_objects:,}")
        print(f"Avg objects per sample: {total_objects / len(atlas_data):.2f}")
        print(f"\nCategory distribution (Top 10):")
        sorted_categories = sorted(category_counts.items(), key=lambda x: x[1], reverse=True)[:10]
        for category, count in sorted_categories:
            print(f" {category}: {count:,}")
    print("\n" + "=" * 80)
    print("Done")
    print("=" * 80 + "\n")
# Script entry point: run the full generation pipeline when executed directly.
if __name__ == "__main__":
    main()