| |
| """Generate Atlas format detection data from nuScenes.""" |
|
|
| import json |
| import os |
| import sys |
| import argparse |
| from pathlib import Path |
| from typing import List, Dict, Tuple, Optional |
| import numpy as np |
| from tqdm import tqdm |
|
|
| |
| |
# Vertical range (meters) used when quantizing a box's Z coordinate into bins.
# The same numbers are the default Z limits of filter_boxes_in_range.
Z_MIN = -5.0
Z_MAX = 3.0
|
|
# The ten detection classes. A class's position in this list is its integer
# class ID (the values stored in NUSCENES_CATEGORY_MAP).
NUSCENES_CLASSES = [
    'car',
    'truck',
    'construction_vehicle',
    'bus',
    'trailer',
    'barrier',
    'motorcycle',
    'bicycle',
    'pedestrian',
    'traffic_cone',
]
|
|
| |
# Category name -> class ID (index into NUSCENES_CLASSES). Both the short
# detection-class names and the full nuScenes category paths are accepted.
# All keys are lower-case, because get_category_id lower-cases its input
# before looking it up.
#
# NOTE(review): the nuScenes devkit's detection mapping omits
# human.pedestrian.wheelchair / .stroller / .personal_mobility, whereas here
# they count as pedestrians. Confirm this deviation is intended.
NUSCENES_CATEGORY_MAP = {
    # Short detection-class names.
    'car': 0,
    'truck': 1,
    'construction_vehicle': 2,
    'bus': 3,
    'trailer': 4,
    'barrier': 5,
    'motorcycle': 6,
    'bicycle': 7,
    'pedestrian': 8,
    'traffic_cone': 9,
    # Full nuScenes vehicle categories.
    'vehicle.car': 0,
    'vehicle.truck': 1,
    'vehicle.construction': 2,
    'vehicle.bus.bendy': 3,
    'vehicle.bus.rigid': 3,
    'vehicle.trailer': 4,
    'vehicle.motorcycle': 6,
    'vehicle.bicycle': 7,
    # Every pedestrian sub-category collapses to 'pedestrian'.
    'human.pedestrian.adult': 8,
    'human.pedestrian.child': 8,
    'human.pedestrian.construction_worker': 8,
    'human.pedestrian.police_officer': 8,
    'human.pedestrian.wheelchair': 8,
    'human.pedestrian.stroller': 8,
    'human.pedestrian.personal_mobility': 8,
    # Movable objects; both spellings of "traffic cone" are accepted.
    'movable_object.barrier': 5,
    'movable_object.trafficcone': 9,
    'movable_object.traffic_cone': 9,
}
|
|
|
|
def get_category_id(category_name: str) -> int:
    """Return the integer class ID for a nuScenes category name.

    The name is lower-cased and stripped of surrounding whitespace, then
    looked up in NUSCENES_CATEGORY_MAP. Names missing from the map fall back
    to two prefix rules: any ``human.pedestrian.*`` is a pedestrian (8), and
    any ``vehicle.bus.*`` is a bus (3).

    Args:
        category_name: A short class name (e.g. ``'car'``) or a full nuScenes
            category path (e.g. ``'vehicle.bus.rigid'``).

    Returns:
        int: Class ID (0-9) if recognised, otherwise -1.
    """
    key = category_name.lower().strip()

    class_id = NUSCENES_CATEGORY_MAP.get(key)
    if class_id is not None:
        return class_id

    # Prefix rules catch sub-categories that are not listed explicitly.
    for prefix, fallback_id in (('human.pedestrian.', 8), ('vehicle.bus.', 3)):
        if key.startswith(prefix):
            return fallback_id

    # Anything else (e.g. 'animal') is not one of the ten detection classes.
    return -1
|
|
|
|
def coord_to_bin(
    value: float,
    min_val: float = -51.2,
    max_val: float = 51.2,
    num_bins: int = 1000,
) -> int:
    """Quantize a scalar coordinate into one of ``num_bins`` integer bins.

    ``value`` is clamped to ``[min_val, max_val]`` and mapped linearly onto
    ``0 .. num_bins - 1``, so ``min_val`` maps to 0 and ``max_val`` to
    ``num_bins - 1``. Rounding uses Python's ``round`` (round-half-to-even).
    Requires ``max_val > min_val``; equal bounds divide by zero.

    Returns:
        int: Bin index, clipped to ``[0, num_bins - 1]``.
    """
    last_bin = num_bins - 1
    clamped = min(max(float(value), min_val), max_val)
    fraction = (clamped - min_val) / (max_val - min_val)
    bin_index = int(round(fraction * last_bin))
    return min(max(bin_index, 0), last_bin)
|
|
|
|
def nuscenes_to_paper_xy(x_fwd: float, y_left: float) -> Tuple[float, float]:
    """Convert ego-frame XY from nuScenes axes to the paper's BEV axes.

    The nuScenes ego frame has +x forward and +y left. The paper frame (as
    described in the prompt) has +x right and +y forward. Hence
    ``x_paper = -y_left`` and ``y_paper = x_fwd``.
    """
    right = -float(y_left)
    forward = float(x_fwd)
    return right, forward
|
|
|
|
| def _clamp(v: float, lo: float, hi: float) -> float: |
| if v < lo: |
| return lo |
| if v > hi: |
| return hi |
| return v |
|
|
|
|
def yaw_nuscenes_to_paper(yaw_n: float) -> float:
    """Convert a heading angle from the nuScenes ego frame to the paper frame.

    The paper axes are the nuScenes axes rotated by +90 degrees about Z
    (see ``nuscenes_to_paper_xy``), so headings gain pi/2. The result is
    wrapped into [-pi, pi).
    """
    quarter_turn = np.pi / 2.0
    full_turn = 2.0 * np.pi
    rotated = float(yaw_n) + quarter_turn
    wrapped = (rotated + np.pi) % full_turn - np.pi
    return float(wrapped)
|
|
|
|
def get_prompt_template() -> str:
    """Return the human-side prompt text for the detection task.

    The function first tries ``src.prompting.sample_prompt("detection")``,
    importing it from the directory above this script. If that import fails,
    it returns a fixed prompt. The fixed prompt describes the six surround-view
    cameras and the BEV axis convention (+x right, +y forward).
    """
    project_root = os.path.join(os.path.dirname(__file__), "..")
    # This is called once per generated sample. Inserting unconditionally grows
    # sys.path by one duplicate entry per call, which also makes every later
    # import search (including the failing one below) progressively slower.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    try:
        from src.prompting import sample_prompt
        return sample_prompt("detection")
    except ImportError:
        return (
            "There are six images captured by the surround view cameras in driving vehicle. "
            "They are uniformly represented as queries embeddings<query>. "
            "Define the positive y-axis as the forward direction and the positive x-axis as the right direction. "
            "Please complete the visual detection task under the Bird's Eye View (BEV) perspective. "
            "Ensure that the detection range does not exceed 50 meters."
        )
|
|
|
|
def filter_boxes_in_range(
    boxes: List,
    max_range: float = 50.0,
    z_min: float = -5.0,
    z_max: float = 3.0,
) -> List:
    """Keep the boxes whose center lies in a vertical cylinder around the origin.

    A box is kept when the XY distance of ``box.center`` from the origin is at
    most ``max_range``. Its Z coordinate must also lie in ``[z_min, z_max]``.
    All bounds are inclusive, and input order is preserved.

    Args:
        boxes: Objects exposing a 3-element ``.center`` (x, y, z).
        max_range: Maximum XY distance from the origin, in meters.
        z_min: Minimum Z coordinate in meters. The default matches StreamPETR's
            point_cloud_range (same value as Z_MIN).
        z_max: Maximum Z coordinate in meters. The default matches StreamPETR's
            point_cloud_range (same value as Z_MAX).

    Returns:
        A new list containing only the kept boxes.
    """
    def _inside(box):
        cx, cy, cz = box.center
        return np.sqrt(cx ** 2 + cy ** 2) <= max_range and z_min <= cz <= z_max

    return [box for box in boxes if _inside(box)]
|
|
|
|
def _to_short_name(category: str) -> str:
    """Return the short detection-class name for a nuScenes category.

    If ``category`` does not resolve to a valid index into NUSCENES_CLASSES,
    it is returned unchanged.
    """
    class_id = get_category_id(category)
    is_known = 0 <= class_id < len(NUSCENES_CLASSES)
    return NUSCENES_CLASSES[class_id] if is_known else category
|
|
|
|
def format_detection_answer(boxes_with_labels: List[Tuple[str, List[int]]]) -> str:
    """Render binned detections as answer text (paper Figure 5 style).

    Output looks like ``"car: [x, y, z], [x, y, z]; pedestrian: [x, y, z]."``.
    Categories are converted to short class names and listed in order of
    first appearance. Within a category, points are ordered by squared XY
    distance from bin (500, 500); ties keep their input order.

    Args:
        boxes_with_labels: ``(category_name, [x_bin, y_bin, z_bin])`` pairs.

    Returns:
        The formatted answer, or ``"No objects detected within range."`` for
        empty input.
    """
    if not boxes_with_labels:
        return "No objects detected within range."

    # Plain dicts keep insertion order, so categories appear in first-seen order.
    by_class = {}
    for category, bins in boxes_with_labels:
        by_class.setdefault(_to_short_name(category), []).append(bins)

    def _dist_from_center(pt):
        return (pt[0] - 500) ** 2 + (pt[1] - 500) ** 2

    segments = []
    for class_name, points in by_class.items():
        ordered = sorted(points, key=_dist_from_center)
        coords = ", ".join(f"[{p[0]}, {p[1]}, {p[2]}]" for p in ordered)
        segments.append(f"{class_name}: {coords}")

    return "; ".join(segments) + "."
|
|
|
|
def _collect_camera_paths(nusc, sample: Dict) -> List[str]:
    """Return image file paths for the sample's surround-view cameras.

    Cameras are visited in the fixed order FRONT, FRONT_RIGHT, FRONT_LEFT,
    BACK, BACK_LEFT, BACK_RIGHT. A camera absent from ``sample['data']`` is
    skipped, so fewer than six paths may be returned. Each path is the
    ``sample_data['filename']`` value with backslashes replaced by '/'.
    """
    camera_names = (
        'CAM_FRONT',
        'CAM_FRONT_RIGHT',
        'CAM_FRONT_LEFT',
        'CAM_BACK',
        'CAM_BACK_LEFT',
        'CAM_BACK_RIGHT',
    )
    image_paths = []
    for cam_name in camera_names:
        if cam_name in sample['data']:
            cam_data = nusc.get('sample_data', sample['data'][cam_name])
            image_paths.append(cam_data['filename'].replace('\\', '/'))
    return image_paths


def _annotation_boxes_in_ego_frame(nusc, sample: Dict) -> List:
    """Load the sample's annotations as nuScenes ``Box`` objects in the ego frame.

    Annotation poses are taken as global coordinates. Each box is translated by
    the negative ego translation and then rotated by the inverse ego rotation.
    The ego pose used is the one attached to the sample's LIDAR_TOP sample_data.
    """
    from nuscenes.utils.data_classes import Box
    from pyquaternion import Quaternion

    lidar_data = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
    ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
    ego_offset = -np.array(ego_pose['translation'])
    ego_rotation_inv = Quaternion(ego_pose['rotation']).inverse

    boxes = []
    for ann_token in sample['anns']:
        ann = nusc.get('sample_annotation', ann_token)
        box = Box(
            ann['translation'],
            ann['size'],
            Quaternion(ann['rotation']),
            name=ann['category_name'],
            token=ann['token'],
        )
        box.translate(ego_offset)
        box.rotate(ego_rotation_inv)
        boxes.append(box)
    return boxes


def _encode_box(box, cat_id: int) -> Tuple[List[int], Dict]:
    """Return the answer point and the ``gt_boxes_3d`` entry for one ego-frame box.

    The answer point is ``[x_bin, y_bin, z_bin]``. Paper-frame X and Y are
    quantized over [-51.2, 51.2] m, and Z over [Z_MIN, Z_MAX], each into 1000
    bins. The entry stores ``[x, y, z, w, l, h, yaw]`` in both the paper frame
    and the nuScenes ego frame.
    """
    x_n, y_n, z = box.center
    x, y = nuscenes_to_paper_xy(x_n, y_n)
    bins = [
        coord_to_bin(x, -51.2, 51.2, 1000),
        coord_to_bin(y, -51.2, 51.2, 1000),
        coord_to_bin(z, Z_MIN, Z_MAX, 1000),
    ]

    try:
        w, length, h = [float(v) for v in box.wlh]
        yaw_n = float(box.orientation.yaw_pitch_roll[0])
        yaw_p = yaw_nuscenes_to_paper(yaw_n)
    except Exception:
        # Best-effort fallback: keep the object with a unit size and zero
        # heading instead of dropping it.
        w, length, h = 1.0, 1.0, 1.0
        yaw_n, yaw_p = 0.0, 0.0

    entry = {
        "category_name": box.name,
        "category_id": cat_id,
        "box": [float(x), float(y), float(z), w, length, h, yaw_p],
        "box_frame": "paper",
        "box_nuscenes": [float(x_n), float(y_n), float(z), w, length, h, yaw_n],
        "box_nuscenes_frame": "nuscenes_ego",
    }
    return bins, entry


def process_nuscenes_sample(nusc, sample_token: str, data_root: Path) -> Optional[Dict]:
    """Convert one nuScenes sample into an Atlas detection training item.

    Args:
        nusc: A loaded ``NuScenes`` instance.
        sample_token: Token of the ``sample`` record to convert.
        data_root: Not used by this function; kept for call compatibility.

    Returns:
        On success, a dict containing:

        * the six image paths;
        * the prompt/answer ``conversations``;
        * ``gt_boxes_3d``, with one entry per point in the answer.

        Returns ``None`` if the sample lacks any of the six cameras, or if an
        error occurs while processing it. Errors are reported on stderr.
    """
    try:
        sample = nusc.get('sample', sample_token)

        image_paths = _collect_camera_paths(nusc, sample)
        if len(image_paths) != 6:
            return None

        # NOTE(review): annotations are not filtered by num_lidar_pts /
        # num_radar_pts, although the nuScenes detection benchmark ignores GT
        # boxes with no points. Confirm that keeping them is intended.
        boxes_in_range = filter_boxes_in_range(
            _annotation_boxes_in_ego_frame(nusc, sample), max_range=50.0
        )

        boxes_with_labels = []
        gt_boxes_3d = []
        for box in boxes_in_range:
            category = box.name
            cat_id = get_category_id(category)
            if cat_id == -1:
                # Not one of the ten detection classes (e.g. 'animal').
                continue
            bins, gt_entry = _encode_box(box, cat_id)
            boxes_with_labels.append((category, bins))
            gt_boxes_3d.append(gt_entry)

        return {
            "id": sample_token,
            "image_paths": image_paths,
            "num_map_queries": 0,
            "task": "detection",
            "segment_id": sample.get("scene_token", ""),
            "timestamp": sample.get("timestamp", None),
            "gt_boxes_3d": gt_boxes_3d,
            "conversations": [
                {
                    "from": "human",
                    "value": get_prompt_template(),
                },
                {
                    "from": "gpt",
                    "value": format_detection_answer(boxes_with_labels),
                },
            ],
        }

    except Exception as exc:
        # Best-effort: the caller counts None as a skipped sample. Reporting the
        # reason makes systematic failures (e.g. a missing nuscenes/pyquaternion
        # install) visible instead of silently producing an empty dataset.
        # tqdm.write keeps the caller's progress bar intact.
        tqdm.write(
            f"Warning: skipping sample {sample_token}: {type(exc).__name__}: {exc}",
            file=sys.stderr,
        )
        return None
|
|
|
|
def _select_samples(nusc, version: str, split: str) -> List[Dict]:
    """Return the nuScenes ``sample`` records that belong to ``split``.

    For ``v1.0-mini`` (whatever the split) or ``split == 'trainval'``, every
    sample is returned. Otherwise the scene names from
    ``create_splits_scenes()['train']`` or ``['val']`` select the scenes whose
    samples are kept.
    """
    if split == 'trainval' or version == 'v1.0-mini':
        return nusc.sample

    from nuscenes.utils.splits import create_splits_scenes

    split_key = 'train' if split == 'train' else 'val'
    split_scene_names = set(create_splits_scenes()[split_key])
    scene_tokens = {
        scene['token'] for scene in nusc.scene if scene['name'] in split_scene_names
    }
    samples = [s for s in nusc.sample if s['scene_token'] in scene_tokens]
    print(f"\n{split} samples: {len(samples)}")
    return samples


def _print_statistics(atlas_data: List[Dict]) -> None:
    """Print the sample count, the object count, and the ten most common classes.

    Objects are counted from each item's ``gt_boxes_3d`` list, which
    ``process_nuscenes_sample`` fills with one entry per point in the answer
    text. The answer string itself cannot be split on ", " to count objects,
    because that separator also appears inside every ``[x, y, z]`` point.
    """
    from collections import Counter

    print("\n" + "=" * 80)
    print("Statistics")
    print("=" * 80)
    print(f"Total samples: {len(atlas_data)}")
    if not atlas_data:
        return

    category_counts = Counter()
    for item in atlas_data:
        for gt_box in item.get("gt_boxes_3d", []):
            cat_id = gt_box.get("category_id", -1)
            # Report short class names, matching the names used in the answer text.
            if 0 <= cat_id < len(NUSCENES_CLASSES):
                name = NUSCENES_CLASSES[cat_id]
            else:
                name = gt_box.get("category_name", "unknown")
            category_counts[name] += 1
    total_objects = sum(category_counts.values())

    print(f"Total objects: {total_objects:,}")
    print(f"Avg objects per sample: {total_objects / len(atlas_data):.2f}")
    print("\nCategory distribution (Top 10):")
    for category, count in category_counts.most_common(10):
        print(f" {category}: {count:,}")


def main():
    """Command-line entry point: build Atlas detection items and save them as JSON.

    Steps:

    1. Parse ``--version``, ``--split``, ``--data-root`` and ``--output``.
    2. Load nuScenes.
    3. Convert every sample of the chosen split with ``process_nuscenes_sample``.
    4. Write the resulting list to the output JSON file.
    5. Print summary statistics.

    Exits with status 1 if the dataset directory is missing, or if nuScenes
    cannot be imported or loaded.
    """
    parser = argparse.ArgumentParser(description="Generate Atlas format data")
    parser.add_argument('--version', type=str, default='v1.0-trainval',
                        choices=['v1.0-mini', 'v1.0-trainval'])
    parser.add_argument('--split', type=str, default='train',
                        choices=['train', 'val', 'trainval'])
    parser.add_argument('--data-root', type=str, default=None)
    parser.add_argument('--output', type=str, default=None)
    args = parser.parse_args()

    script_dir = Path(__file__).parent.absolute()
    project_root = script_dir.parent

    if args.data_root:
        data_root = Path(args.data_root)
    else:
        data_root = project_root / "data" / "nuscenes"

    if args.output:
        output_file = Path(args.output)
    elif args.version == 'v1.0-mini':
        # --split is ignored for the mini set (see _select_samples), so the
        # default file name is fixed.
        output_file = project_root / "data" / "atlas_mini_train.json"
    else:
        output_file = project_root / "data" / f"atlas_{args.split}.json"

    print("=" * 80)
    print("Atlas Data Generation")
    print("=" * 80)
    print(f"\nProject root: {project_root}")
    print(f"Data root: {data_root}")
    print(f"Version: {args.version}")
    print(f"Output: {output_file}")

    version_dir = data_root / args.version
    if not version_dir.exists():
        print(f"\nError: nuScenes {args.version} not found at {version_dir}")
        sys.exit(1)

    try:
        from nuscenes.nuscenes import NuScenes
        print("\nImported nuscenes successfully")
    except ImportError as e:
        print(f"\nFailed to import nuscenes: {e}")
        sys.exit(1)

    print(f"\nLoading nuScenes {args.version}...")
    try:
        nusc = NuScenes(
            version=args.version,
            dataroot=str(data_root),
            verbose=True,
        )
        print("Dataset loaded")
        print(f" Scenes: {len(nusc.scene)}")
        print(f" Samples: {len(nusc.sample)}")
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        sys.exit(1)

    samples_to_process = _select_samples(nusc, args.version, args.split)

    print("\nProcessing samples...")
    atlas_data = []
    failed_count = 0
    for sample in tqdm(samples_to_process, desc="Processing samples"):
        data_item = process_nuscenes_sample(nusc, sample['token'], data_root)
        if data_item is None:
            failed_count += 1
        else:
            atlas_data.append(data_item)

    print(f"\nProcessed {len(atlas_data)} / {len(samples_to_process)} samples")
    if failed_count > 0:
        print(f"Skipped {failed_count} invalid samples")

    output_file.parent.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving to: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(atlas_data, f, indent=2, ensure_ascii=False)

    file_size_mb = output_file.stat().st_size / (1024 * 1024)
    print(f"Saved ({file_size_mb:.2f} MB)")

    _print_statistics(atlas_data)

    print("\n" + "=" * 80)
    print("Done")
    print("=" * 80 + "\n")
|
|
|
|
if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()
|
|
|
|