# Atlas-online / scripts / gen_atlas_planning_qa.py
# Provenance: added by guoyb0 via the upload-large-folder tool (commit 9fe982a, verified).
#!/usr/bin/env python3
import math
import argparse
import json
import os
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from src.prompting import PLANNING_TABLE3_MODES, rewrite_planning_prompt_for_table3
Z_MIN, Z_MAX = -5.0, 3.0
VEL_ACC_RANGE = (-50.0, 50.0)
XY_RANGE = (-51.2, 51.2)
NUM_BINS = 1000
WAYPOINT_DT = 0.5
NUM_WAYPOINTS = 6
# Official UniAD get_sdc_planning_label() uses the terminal lateral offset
# (RIGHT if x >= 2, LEFT if x <= -2, else FORWARD). Our waypoints are already
# in Atlas paper frame, where x is lateral-right and y is forward.
UNIAD_COMMAND_X_THRESHOLD = 2.0
CAMERA_NAMES = [
'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
]
def _val_to_bin(value: float, min_val: float, max_val: float, num_bins: int = NUM_BINS) -> int:
v = float(np.clip(value, min_val, max_val))
t = (v - min_val) / (max_val - min_val)
idx = int(round(t * (num_bins - 1)))
return int(np.clip(idx, 0, num_bins - 1))
def _nuscenes_to_paper_xy(x_fwd: float, y_left: float) -> Tuple[float, float]:
return (-float(y_left), float(x_fwd))
def _derive_uniad_style_command(
waypoints: List[List[float]],
lateral_threshold: float = UNIAD_COMMAND_X_THRESHOLD,
) -> str:
"""Derive a 3-way planning command from future GT waypoints.
This intentionally matches the semantics of UniAD's
`get_sdc_planning_label()`: the final valid future position determines a
coarse RIGHT / LEFT / FORWARD command based on lateral displacement.
"""
valid_waypoints: List[Tuple[float, float]] = []
for wp in waypoints:
if not isinstance(wp, (list, tuple)) or len(wp) < 2:
continue
x = float(wp[0])
y = float(wp[1])
if np.isfinite(x) and np.isfinite(y):
valid_waypoints.append((x, y))
if not valid_waypoints:
return "go straight"
target_x = float(valid_waypoints[-1][0])
if target_x >= lateral_threshold:
return "turn right"
if target_x <= -lateral_threshold:
return "turn left"
return "go straight"
def _compute_velocity(nusc, sample) -> Tuple[float, float]:
try:
lidar_token = sample['data']['LIDAR_TOP']
lidar_data = nusc.get('sample_data', lidar_token)
from pyquaternion import Quaternion
ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
ego_t = np.array(ego_pose['translation'])
ego_q = Quaternion(ego_pose['rotation'])
prev_token = lidar_data.get('prev', '')
if prev_token:
prev_data = nusc.get('sample_data', prev_token)
prev_ego = nusc.get('ego_pose', prev_data['ego_pose_token'])
prev_t = np.array(prev_ego['translation'])
dt = (lidar_data['timestamp'] - prev_data['timestamp']) * 1e-6
if dt > 0:
vel_global = (ego_t - prev_t) / dt
vel_ego = ego_q.inverse.rotate(vel_global)
return float(vel_ego[0]), float(vel_ego[1])
except Exception:
pass
return 0.0, 0.0
def _compute_acceleration(nusc, sample) -> Tuple[float, float]:
try:
lidar_token = sample['data']['LIDAR_TOP']
lidar_data = nusc.get('sample_data', lidar_token)
from pyquaternion import Quaternion
ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
ego_q = Quaternion(ego_pose['rotation'])
prev_token = lidar_data.get('prev', '')
if not prev_token:
return 0.0, 0.0
prev_data = nusc.get('sample_data', prev_token)
dt1 = (lidar_data['timestamp'] - prev_data['timestamp']) * 1e-6
if dt1 <= 0:
return 0.0, 0.0
prev2_token = prev_data.get('prev', '')
if not prev2_token:
return 0.0, 0.0
prev2_data = nusc.get('sample_data', prev2_token)
dt2 = (prev_data['timestamp'] - prev2_data['timestamp']) * 1e-6
if dt2 <= 0:
return 0.0, 0.0
def _ego_vel(sd1, sd2, dt_val):
e1 = nusc.get('ego_pose', sd1['ego_pose_token'])
e2 = nusc.get('ego_pose', sd2['ego_pose_token'])
t1 = np.array(e1['translation'])
t2 = np.array(e2['translation'])
return (t1 - t2) / dt_val
v1_global = _ego_vel(lidar_data, prev_data, dt1)
v0_global = _ego_vel(prev_data, prev2_data, dt2)
acc_global = (v1_global - v0_global) / dt1
acc_ego = ego_q.inverse.rotate(acc_global)
return float(acc_ego[0]), float(acc_ego[1])
except Exception:
return 0.0, 0.0
def _get_future_waypoints(nusc, sample) -> Optional[List[List[float]]]:
try:
from pyquaternion import Quaternion
lidar_token = sample['data']['LIDAR_TOP']
lidar_data = nusc.get('sample_data', lidar_token)
ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
ego_t = np.array(ego_pose['translation'])
ego_q = Quaternion(ego_pose['rotation'])
current_ts = lidar_data['timestamp']
target_times = [current_ts + int(WAYPOINT_DT * (i + 1) * 1e6) for i in range(NUM_WAYPOINTS)]
all_sd = []
sd_token = lidar_token
while sd_token:
sd = nusc.get('sample_data', sd_token)
all_sd.append(sd)
sd_token = sd.get('next', '')
if not sd_token:
break
if sd['timestamp'] > target_times[-1] + 1e6:
break
if len(all_sd) < 2:
return None
timestamps = np.array([s['timestamp'] for s in all_sd])
poses = []
for s in all_sd:
ep = nusc.get('ego_pose', s['ego_pose_token'])
poses.append(np.array(ep['translation']))
poses = np.array(poses)
waypoints = []
for tt in target_times:
if tt > timestamps[-1] or tt < timestamps[0]:
return None
idx = np.searchsorted(timestamps, tt, side='right') - 1
idx = max(0, min(idx, len(timestamps) - 2))
dt_seg = timestamps[idx + 1] - timestamps[idx]
if dt_seg <= 0:
return None
alpha = (tt - timestamps[idx]) / dt_seg
pos_global = poses[idx] * (1 - alpha) + poses[idx + 1] * alpha
pos_ego = ego_q.inverse.rotate(pos_global - ego_t)
x_p, y_p = _nuscenes_to_paper_xy(pos_ego[0], pos_ego[1])
waypoints.append([float(x_p), float(y_p)])
return waypoints
except Exception:
return None
def _format_planning_answer(
vx: float, vy: float, ax: float, ay: float,
waypoints: List[List[float]],
) -> str:
vx_bin = _val_to_bin(vx, *VEL_ACC_RANGE)
vy_bin = _val_to_bin(vy, *VEL_ACC_RANGE)
ax_bin = _val_to_bin(ax, *VEL_ACC_RANGE)
ay_bin = _val_to_bin(ay, *VEL_ACC_RANGE)
wp_strs = []
for wp in waypoints:
xb = _val_to_bin(wp[0], *XY_RANGE)
yb = _val_to_bin(wp[1], *XY_RANGE)
wp_strs.append(f"[{xb}, {yb}]")
return (
f"Ego car speed value:[{vx_bin}, {vy_bin}]. "
f"Ego car acceleration value:[{ax_bin}, {ay_bin}]. "
"Based on the ego car speed and acceleration you predicted, "
f"request the ego car planning waypoint in 3-seconds: {', '.join(wp_strs)}"
)
def _collect_gt_boxes_ego(nusc, sample) -> List[Dict]:
from pyquaternion import Quaternion
lidar_token = sample['data']['LIDAR_TOP']
lidar_data = nusc.get('sample_data', lidar_token)
ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
ego_t = np.array(ego_pose['translation'])
ego_q = Quaternion(ego_pose['rotation'])
cs_record = nusc.get('calibrated_sensor', lidar_data['calibrated_sensor_token'])
cs_t = np.array(cs_record['translation'])
cs_q = Quaternion(cs_record['rotation'])
boxes = []
for ann_token in sample['anns']:
ann = nusc.get('sample_annotation', ann_token)
center_global = np.array(ann['translation'])
center_ego = ego_q.inverse.rotate(center_global - ego_t)
x_p, y_p = _nuscenes_to_paper_xy(center_ego[0], center_ego[1])
yaw_global = Quaternion(ann['rotation'])
yaw_ego = ego_q.inverse * yaw_global
# _nuscenes_to_paper_xy applies a 90° CCW rotation:
# x_paper = -y_ego, y_paper = x_ego
# Yaw must be rotated by the same +π/2 to stay consistent.
yaw_angle = float(yaw_ego.yaw_pitch_roll[0]) + math.pi / 2.0
w = float(ann['size'][0])
l = float(ann['size'][1])
h = float(ann['size'][2])
boxes.append({
"world_coords": [float(x_p), float(y_p), float(center_ego[2])],
"w": w,
"l": l,
"h": h,
"yaw": yaw_angle,
"category": ann['category_name'],
})
return boxes
def _collect_gt_boxes_per_timestep(nusc, sample, num_timesteps=NUM_WAYPOINTS) -> List[List[Dict]]:
"""Collect GT boxes for each future keyframe, transformed to current ego frame.
ST-P3 protocol: at each future timestep t, collision is checked against
the actual positions of other agents at time t, not their positions at t=0.
nuScenes keyframes are ~0.5s apart, matching the waypoint interval.
"""
from pyquaternion import Quaternion
lidar_token = sample['data']['LIDAR_TOP']
lidar_data = nusc.get('sample_data', lidar_token)
ego_pose = nusc.get('ego_pose', lidar_data['ego_pose_token'])
ref_ego_t = np.array(ego_pose['translation'])
ref_ego_q = Quaternion(ego_pose['rotation'])
per_timestep_boxes: List[List[Dict]] = []
cur_sample = sample
for _ in range(num_timesteps):
next_token = cur_sample.get('next', '')
if not next_token:
per_timestep_boxes.append(per_timestep_boxes[-1] if per_timestep_boxes else [])
continue
cur_sample = nusc.get('sample', next_token)
boxes = []
for ann_token in cur_sample['anns']:
ann = nusc.get('sample_annotation', ann_token)
center_global = np.array(ann['translation'])
center_ego = ref_ego_q.inverse.rotate(center_global - ref_ego_t)
x_p, y_p = _nuscenes_to_paper_xy(center_ego[0], center_ego[1])
yaw_global = Quaternion(ann['rotation'])
yaw_ego = ref_ego_q.inverse * yaw_global
yaw_angle = float(yaw_ego.yaw_pitch_roll[0]) + math.pi / 2.0
w = float(ann['size'][0])
l = float(ann['size'][1])
h = float(ann['size'][2])
boxes.append({
"world_coords": [float(x_p), float(y_p), float(center_ego[2])],
"w": w, "l": l, "h": h,
"yaw": yaw_angle,
"category": ann['category_name'],
})
per_timestep_boxes.append(boxes)
return per_timestep_boxes
def process_sample(
nusc,
sample_token: str,
data_root: Path,
planning_table3_mode: str,
) -> Optional[Dict]:
try:
from pyquaternion import Quaternion
from src.prompting import sample_prompt
sample = nusc.get('sample', sample_token)
image_paths = []
for cam_name in CAMERA_NAMES:
if cam_name in sample['data']:
cam_token = sample['data'][cam_name]
cam_data = nusc.get('sample_data', cam_token)
image_paths.append(cam_data['filename'].replace('\\', '/'))
if len(image_paths) != 6:
return None
vx_n, vy_n = _compute_velocity(nusc, sample)
ax_n, ay_n = _compute_acceleration(nusc, sample)
vx_p, vy_p = _nuscenes_to_paper_xy(vx_n, vy_n)
ax_p, ay_p = _nuscenes_to_paper_xy(ax_n, ay_n)
waypoints = _get_future_waypoints(nusc, sample)
if waypoints is None:
return None
vx_bin = _val_to_bin(vx_p, *VEL_ACC_RANGE)
vy_bin = _val_to_bin(vy_p, *VEL_ACC_RANGE)
ax_bin = _val_to_bin(ax_p, *VEL_ACC_RANGE)
ay_bin = _val_to_bin(ay_p, *VEL_ACC_RANGE)
route_command = _derive_uniad_style_command(waypoints)
prompt = sample_prompt(
"planning",
vx_bin=vx_bin, vy_bin=vy_bin,
ax_bin=ax_bin, ay_bin=ay_bin,
command=route_command,
)
prompt = rewrite_planning_prompt_for_table3(
prompt,
mode=planning_table3_mode,
command=route_command,
velocity_bins=(vx_bin, vy_bin),
acceleration_bins=(ax_bin, ay_bin),
)
answer = _format_planning_answer(vx_p, vy_p, ax_p, ay_p, waypoints)
gt_boxes = _collect_gt_boxes_ego(nusc, sample)
gt_boxes_per_ts = _collect_gt_boxes_per_timestep(nusc, sample)
item = {
"id": sample_token,
"image_paths": image_paths,
"num_map_queries": 256,
"task": "planning",
"segment_id": sample.get("scene_token", ""),
"timestamp": sample.get("timestamp", None),
"ego_motion": {
"velocity": [vx_p, vy_p],
"acceleration": [ax_p, ay_p],
"waypoints": waypoints,
},
"gt_boxes_3d": gt_boxes,
"gt_boxes_3d_per_timestep": gt_boxes_per_ts,
"conversations": [
{"from": "human", "value": prompt},
{"from": "gpt", "value": answer},
],
"route_command": route_command,
}
return item
except Exception:
return None
def _audit_results(results: List[Dict], planning_table3_mode: str) -> None:
total = int(len(results))
if total == 0:
print("[AUDIT] No planning samples were generated.")
return
route_commands = [item.get("route_command") for item in results]
route_command_coverage = sum(isinstance(cmd, str) and bool(cmd) for cmd in route_commands)
route_command_dist = Counter(route_commands)
legacy_ego_motion_command = sum(
1
for item in results
if isinstance(item.get("ego_motion"), dict) and "command" in item["ego_motion"]
)
prompt_with_command = 0
prompt_with_state = 0
for item in results:
conv = item.get("conversations", [])
if not conv:
continue
prompt_text = str(conv[0].get("value", ""))
if "The ego car will " in prompt_text:
prompt_with_command += 1
if "The current speed value of the ego car is [" in prompt_text:
prompt_with_state += 1
print(
"[AUDIT] planning route_command "
f"mode={planning_table3_mode} "
f"coverage={route_command_coverage}/{total} "
f"legacy_ego_motion_command={legacy_ego_motion_command}/{total} "
f"prompt_with_command={prompt_with_command}/{total} "
f"prompt_with_state={prompt_with_state}/{total}"
)
print(f"[AUDIT] planning route_command distribution={dict(route_command_dist)}")
print(
"[AUDIT] route_command semantics: UniAD-style future-GT-derived "
f"(terminal lateral x threshold={UNIAD_COMMAND_X_THRESHOLD:.1f}m)."
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--version', type=str, default='v1.0-trainval')
parser.add_argument('--split', type=str, default='train', choices=['train', 'val'])
parser.add_argument('--data-root', type=str, default='/mnt/data/nuscenes')
parser.add_argument('--output', type=str, default=None)
parser.add_argument(
'--planning-table3-mode',
type=str,
choices=PLANNING_TABLE3_MODES,
default='atlas_high_level',
help=(
'Human prompt variant to materialize in the generated JSON. '
'route_command is always written as a top-level UniAD-style '
'future-GT-derived command.'
),
)
args = parser.parse_args()
data_root = Path(args.data_root)
script_dir = Path(__file__).parent.absolute()
project_root = script_dir.parent
if args.output:
output_file = Path(args.output)
else:
output_file = project_root / "data" / f"atlas_planning_{args.split}_uniad_command.json"
from nuscenes.nuscenes import NuScenes
from nuscenes.utils.splits import create_splits_scenes
nusc = NuScenes(version=args.version, dataroot=str(data_root), verbose=True)
splits = create_splits_scenes()
split_scenes = set(splits[args.split])
scene_tokens = set()
for scene in nusc.scene:
if scene['name'] in split_scenes:
scene_tokens.add(scene['token'])
samples_to_process = [s for s in nusc.sample if s['scene_token'] in scene_tokens]
print(f"Processing {len(samples_to_process)} samples for planning...")
results = []
for sample in tqdm(samples_to_process):
item = process_sample(
nusc,
sample['token'],
data_root,
planning_table3_mode=args.planning_table3_mode,
)
if item is not None:
results.append(item)
output_file.parent.mkdir(parents=True, exist_ok=True)
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
_audit_results(results, args.planning_table3_mode)
print(f"Saved {len(results)} planning samples to {output_file}")
if __name__ == "__main__":
main()