#!/usr/bin/env python3
"""
Figure 11 — Violation of traffic regulations.

This script focuses on the construction-blocking case shown in the paper:
the road ahead is fenced by barriers / traffic cones, but Atlas still
outputs a "go straight" trajectory that cuts through the blocked area.

Style:
- Left: 6 camera views, only CAM_FRONT overlays the trajectory
- Right: clean BEV with black road boundaries, red construction boxes,
  green/blue trajectory, and "Go Straight" label
"""
from __future__ import annotations

import argparse
import json
import math
import pickle
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

# Make the repository root importable (needed for `src.eval.metrics` below).
_REPO = Path(__file__).resolve().parent.parent
if str(_REPO) not in sys.path:
    sys.path.insert(0, str(_REPO))

# Camera panel order used by the paper figure (2 rows x 3 columns).
CAM_ORDER_PAPER = [
    "CAM_FRONT_LEFT",
    "CAM_FRONT",
    "CAM_FRONT_RIGHT",
    "CAM_BACK_LEFT",
    "CAM_BACK",
    "CAM_BACK_RIGHT",
]
# Index mapping from the dataset's `image_paths` order to CAM_ORDER_PAPER.
_IDX_REORDER = [2, 0, 1, 4, 3, 5]

# nuScenes log location -> rasterized map mask filename under <root>/maps.
LOCATION_TO_MAP = {
    "singapore-onenorth": "53992ee3023e5494b90c316c183be829.png",
    "boston-seaport": "36092f0b03a857c6a3403e25b4b7aab3.png",
    "singapore-queenstown": "93406b464a165eaba6d9de76ca09f5da.png",
    "singapore-hollandvillage": "37819e65e09e5547b8a3ceaefba56bb2.png",
}
MAP_RES = 0.1  # meters per pixel of the rasterized map masks

# Hand-picked sample token used in the paper figure; we fall back to a
# heuristic search (_blocked_score) when it is unavailable.
DEFAULT_FIG11_SAMPLE = "856ccc626a4a4c0aaac1e62335050ac0"

# Explicit export list. The leading-underscore helpers are exported
# deliberately so they can be unit-tested via `from <module> import *`.
__all__ = [
    "CAM_ORDER_PAPER",
    "LOCATION_TO_MAP",
    "MAP_RES",
    "DEFAULT_FIG11_SAMPLE",
    "parse_args",
    "main",
    "_load_json",
    "_load_pickle",
    "_quat_to_rotmat",
    "_quat_to_yaw",
    "_paper_xy_to_ego",
    "_project_batch",
    "_load_cam_calibs",
    "_load_ego_poses",
    "_get_sample_location",
    "_smooth_map",
    "_build_bev_map",
    "_box_corners",
    "_is_barrier_like",
    "_is_context_like",
    "_title_cmd",
    "_blocked_score",
]


def _load_json(path: Path):
    """Load a UTF-8 JSON file and return the parsed object."""
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def _load_pickle(path: Path):
    """Load a pickle file.

    NOTE: `pickle.load` executes arbitrary code on malicious input — only
    use on trusted, locally generated files (here: nuscenes_infos_val.pkl).
    """
    with path.open("rb") as f:
        return pickle.load(f)


def _quat_to_rotmat(qw, qx, qy, qz):
    """Convert a (w, x, y, z) quaternion to a 3x3 rotation matrix.

    The quaternion is normalized first; a near-zero-norm quaternion yields
    the identity matrix instead of dividing by ~0.
    """
    n = math.sqrt(qw * qw + qx * qx + qy * qy + qz * qz)
    if n < 1e-12:
        return np.eye(3, dtype=np.float64)
    qw, qx, qy, qz = qw / n, qx / n, qy / n, qz / n
    xx, yy, zz = qx * qx, qy * qy, qz * qz
    xy, xz, yz = qx * qy, qx * qz, qy * qz
    wx, wy, wz = qw * qx, qw * qy, qw * qz
    return np.array(
        [
            [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
            [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
            [2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)],
        ],
        dtype=np.float64,
    )


def _quat_to_yaw(q):
    """Extract the yaw angle (radians) from a (w, x, y, z) quaternion."""
    w, x, y, z = [float(v) for v in q]
    return math.atan2(2 * (w * z + x * y), 1 - 2 * (y * y + z * z))


def _paper_xy_to_ego(x_right: float, y_fwd: float, z_up: float = 0.0) -> np.ndarray:
    """Map paper/plot coordinates (x right, y forward) into the nuScenes
    ego frame (x forward, y left, z up)."""
    return np.array([float(y_fwd), float(-x_right), float(z_up)], dtype=np.float64)


def _project_batch(pts_ego: np.ndarray, R_c2e: np.ndarray, t_c2e: np.ndarray, K: np.ndarray) -> np.ndarray:
    """Project ego-frame 3D points into camera pixel coordinates.

    Args:
        pts_ego: (N, 3) points in the ego frame.
        R_c2e, t_c2e: camera-to-ego rotation (3x3) and translation (3,).
        K: 3x3 camera intrinsic matrix.

    Returns:
        (M, 2) pixel coordinates; points at or behind the image plane
        (depth <= 1e-3) are dropped, so M <= N.
    """
    R_e2c = R_c2e.T
    pts_cam = (R_e2c @ (pts_ego - t_c2e[None, :]).T).T
    z = pts_cam[:, 2]
    keep = z > 1e-3
    pts_cam = pts_cam[keep]
    if pts_cam.shape[0] == 0:
        return np.zeros((0, 2), dtype=np.float64)
    x = pts_cam[:, 0] / pts_cam[:, 2]
    y = pts_cam[:, 1] / pts_cam[:, 2]
    fx, fy = float(K[0, 0]), float(K[1, 1])
    cx, cy = float(K[0, 2]), float(K[1, 2])
    return np.stack([fx * x + cx, fy * y + cy], axis=1)


def _load_cam_calibs(nusc_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Load per-channel camera calibrations from nuScenes v1.0-trainval metadata.

    Returns:
        channel -> (R_cam2ego, t_cam2ego, K). Channels with missing or
        malformed records are silently skipped.

    NOTE(review): `calibrated_sensor.json` may contain several records per
    sensor token (one per log); the dict build below keeps whichever record
    appears last. Verify this is acceptable for the chosen sample's log.
    """
    meta = nusc_root / "v1.0-trainval"
    sensor = _load_json(meta / "sensor.json")
    calib = _load_json(meta / "calibrated_sensor.json")
    sensor_token_by_channel: Dict[str, str] = {}
    for rec in sensor:
        ch = rec.get("channel")
        tok = rec.get("token")
        if isinstance(ch, str) and isinstance(tok, str):
            sensor_token_by_channel[ch] = tok
    calib_by_sensor_token: Dict[str, Dict] = {}
    for rec in calib:
        st = rec.get("sensor_token")
        if isinstance(st, str):
            calib_by_sensor_token[st] = rec
    out: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
    for ch, st in sensor_token_by_channel.items():
        rec = calib_by_sensor_token.get(st)
        if not rec or "camera_intrinsic" not in rec:
            continue  # lidar/radar records carry no intrinsics
        t = np.asarray(rec.get("translation", [0, 0, 0]), dtype=np.float64).reshape(3)
        q = rec.get("rotation", [1, 0, 0, 0])
        K = np.asarray(rec.get("camera_intrinsic", np.eye(3).tolist()), dtype=np.float64)
        if not (isinstance(q, (list, tuple)) and len(q) == 4):
            continue
        if K.shape != (3, 3):
            continue
        R = _quat_to_rotmat(*[float(x) for x in q])
        out[ch] = (R, t, K)
    return out


def _load_ego_poses(nusc_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, float]]:
    """Read ego2global poses from the preprocessed `nuscenes_infos_val.pkl`.

    Returns:
        sample_token -> (R_ego2global, t_ego2global, yaw_radians).
    """
    pkl = nusc_root / "nuscenes_infos_val.pkl"
    obj = _load_pickle(pkl)
    # The pkl is either {"infos": [...]} or a bare list of info dicts.
    infos = obj["infos"] if isinstance(obj, dict) and "infos" in obj else obj
    out: Dict[str, Tuple[np.ndarray, np.ndarray, float]] = {}
    for it in infos:
        tok = str(it.get("token", ""))
        if not tok:
            continue
        t = np.asarray(it.get("ego2global_translation", [0, 0, 0]), dtype=np.float64).reshape(3)
        q = it.get("ego2global_rotation", [1, 0, 0, 0])
        if isinstance(q, (list, tuple)) and len(q) == 4:
            out[tok] = (_quat_to_rotmat(*[float(x) for x in q]), t, _quat_to_yaw(q))
    return out


def _get_sample_location(nusc_root: Path, sample_token: str) -> str:
    """Resolve sample -> scene -> log -> city location name ('' if unknown)."""
    samples = _load_json(nusc_root / "v1.0-trainval" / "sample.json")
    scenes = _load_json(nusc_root / "v1.0-trainval" / "scene.json")
    logs = _load_json(nusc_root / "v1.0-trainval" / "log.json")
    scene_map = {s["token"]: s for s in scenes}
    log_map = {l["token"]: l for l in logs}
    for s in samples:
        if s["token"] == sample_token:
            sc = scene_map.get(s.get("scene_token", ""), {})
            lg = log_map.get(sc.get("log_token", ""), {})
            return str(lg.get("location", ""))
    return ""


def _smooth_map(bev_map: np.ndarray) -> np.ndarray:
    """Apply a 3x3 box blur to the BEV occupancy map.

    Implemented with `np.roll`, so edges wrap around; at the map border
    this bleeds opposite edges together, which is visually negligible here.
    """
    acc = bev_map.copy()
    acc += np.roll(bev_map, 1, axis=0)
    acc += np.roll(bev_map, -1, axis=0)
    acc += np.roll(bev_map, 1, axis=1)
    acc += np.roll(bev_map, -1, axis=1)
    acc += np.roll(np.roll(bev_map, 1, axis=0), 1, axis=1)
    acc += np.roll(np.roll(bev_map, 1, axis=0), -1, axis=1)
    acc += np.roll(np.roll(bev_map, -1, axis=0), 1, axis=1)
    acc += np.roll(np.roll(bev_map, -1, axis=0), -1, axis=1)
    return acc / 9.0


def _build_bev_map(
    nusc_root: Path,
    location: str,
    ego_xy: np.ndarray,
    ego_yaw: float,
    bev_xlim: Tuple[float, float],
    bev_ylim: Tuple[float, float],
    bev_res: float = 0.1,
) -> Optional[np.ndarray]:
    """Sample the rasterized city map into an ego-centric BEV grid.

    The grid covers `bev_xlim` meters to the right x `bev_ylim` meters
    forward around the ego pose. Rows are ordered by *ascending* forward
    distance (row 0 corresponds to bev_ylim[0]) so the array lines up with
    matplotlib `contour(xs, ys, bev)` called with an ascending `ys`.

    Returns None when the location has no known map or the file is missing.
    """
    import PIL.Image

    PIL.Image.MAX_IMAGE_PIXELS = None  # city maps exceed PIL's default size guard
    from PIL import Image

    map_fn = LOCATION_TO_MAP.get(location)
    if not map_fn:
        return None
    map_path = nusc_root / "maps" / map_fn
    if not map_path.exists():
        return None
    # Force a single channel so the scalar assignment into `bev` below
    # cannot fail if the PNG happens to be stored as RGB(A).
    map_img = Image.open(map_path).convert("L")
    mw, mh = map_img.size
    map_max_y = mh * MAP_RES
    map_arr = np.asarray(map_img, dtype=np.float32) / 255.0
    ex, ey = float(ego_xy[0]), float(ego_xy[1])
    c_yaw, s_yaw = math.cos(ego_yaw), math.sin(ego_yaw)
    x0, x1 = bev_xlim
    y0, y1 = bev_ylim
    nx = int((x1 - x0) / bev_res)
    ny = int((y1 - y0) / bev_res)
    bev = np.zeros((ny, nx), dtype=np.float32)
    px_arr = np.linspace(x0, x1, nx)
    # BUGFIX: rows previously ran from y1 down to y0 (image-style order),
    # which vertically flipped the road map when drawn with contour() and
    # an ascending y coordinate array in main(). Fill ascending instead.
    py_arr = np.linspace(y0, y1, ny)
    PX, PY = np.meshgrid(px_arr, py_arr)
    # Ego-frame (PX right, PY forward) -> global map coordinates.
    GX = ex + PY * c_yaw + PX * s_yaw
    GY = ey + PY * s_yaw - PX * c_yaw
    MX = (GX / MAP_RES).astype(np.int32)
    MY = ((map_max_y - GY) / MAP_RES).astype(np.int32)  # map image row 0 = max y
    valid = (MX >= 0) & (MX < mw) & (MY >= 0) & (MY < mh)
    bev[valid] = map_arr[MY[valid], MX[valid]]
    return _smooth_map(bev)


def _box_corners(cx: float, cy: float, w: float, l: float, yaw: float) -> np.ndarray:
    """Return the 4 corners (4, 2) of a rotated box centered at (cx, cy).

    `l` (length) extends along the heading `yaw`; `w` (width) is
    perpendicular to it. Corners are ordered counter-clockwise starting at
    front-left, suitable for closing into a polygon.
    """
    c, s = math.cos(yaw), math.sin(yaw)
    center = np.array([cx, cy], dtype=np.float64)
    d_len = np.array([c, s], dtype=np.float64) * (l / 2.0)
    d_wid = np.array([-s, c], dtype=np.float64) * (w / 2.0)
    return np.stack(
        [
            center + d_len + d_wid,
            center + d_len - d_wid,
            center - d_len - d_wid,
            center - d_len + d_wid,
        ],
        axis=0,
    )


def _is_barrier_like(cat: str) -> bool:
    """True for construction-blocking categories (barriers, traffic cones)."""
    return ("barrier" in cat) or ("traffic_cone" in cat) or ("trafficcone" in cat)


def _is_context_like(cat: str) -> bool:
    """True for scene-context categories (construction vehicles, pedestrians,
    cars). Currently unused by this figure; kept for ad-hoc filtering."""
    return ("construction" in cat) or ("pedestrian" in cat) or ("vehicle.car" in cat)


def _title_cmd(cmd: str) -> str:
    """Normalize a route command to title case; unknown commands pass through."""
    c = (cmd or "").strip().lower()
    return {
        "turn left": "Turn Left",
        "turn right": "Turn Right",
        "go straight": "Go Straight",
    }.get(c, cmd)


def _blocked_score(item: Dict) -> float:
    """Heuristic score for picking a 'go straight through construction' sample.

    Counts barrier/cone boxes ahead of the ego (front band and narrow center
    corridor) plus GT waypoints passing through that corridor. Non-'go
    straight' samples score -1e9 so they always sort last.

    NOTE(review): the collapsed source is ambiguous about whether the
    front/center counts include all boxes or only barrier-like ones; the
    barrier-only reading matches the construction-blocking intent.
    """
    route_command = str(item.get("route_command", "")).strip().lower()
    if route_command != "go straight":
        return -1e9
    boxes = item.get("gt_boxes_3d", []) or []
    wps = (item.get("ego_motion", {}) or {}).get("waypoints", []) or []
    n_block_front = 0
    n_block_center = 0
    n_barrier = 0
    n_cone = 0
    for b in boxes:
        cat = str(b.get("category", ""))
        wc = b.get("world_coords", [0, 0, 0])
        if not (isinstance(wc, (list, tuple)) and len(wc) >= 2):
            continue
        x, y = float(wc[0]), float(wc[1])
        if _is_barrier_like(cat):
            if "barrier" in cat:
                n_barrier += 1
            else:
                n_cone += 1
            if 0 < y < 25 and abs(x) < 10:
                n_block_front += 1
            if 2 < y < 20 and abs(x) < 4:
                n_block_center += 1
    through_center = 0
    for x, y in wps:
        if 2 < float(y) < 20 and abs(float(x)) < 4:
            through_center += 1
    return n_block_center * 8 + n_block_front * 3 + through_center * 4 + n_barrier * 1.5 + n_cone


def parse_args():
    """CLI arguments; relative paths are resolved against the repo root."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--eval_json", default="work_dirs/eval_final_plan100.json")
    ap.add_argument("--data_json", default="data/atlas_planning_val_uniad_command.json")
    ap.add_argument("--data_root", default="/home/guoyuanbo/autodl-tmp/data/nuscenes")
    ap.add_argument("--sample_id", default=None)
    ap.add_argument("--out_png", default="work_dirs/atlas_traffic_violation.png")
    ap.add_argument("--dpi", type=int, default=200)
    ap.add_argument("--bev_xlim", type=float, nargs=2, default=[-7.5, 10.5])
    ap.add_argument("--bev_ylim", type=float, nargs=2, default=[-1.5, 30.5])
    return ap.parse_args()


def main():
    """Render the Figure 11 panel: 6 camera views + a BEV with trajectories."""
    args = parse_args()
    repo = _REPO
    nusc_root = Path(args.data_root).resolve()
    out_png = (repo / args.out_png).resolve()

    eval_obj = _load_json((repo / args.eval_json).resolve())
    data_items = _load_json((repo / args.data_json).resolve())
    pred_by_id = {
        str(r.get("sample_id", "")): str(r.get("generated_text", ""))
        for r in eval_obj.get("predictions", [])
        if r.get("sample_id")
    }
    item_by_id = {str(it["id"]): it for it in data_items if it.get("id")}

    # Project-local import; _REPO was added to sys.path at module load.
    from src.eval.metrics import parse_planning_output

    # --- choose the sample: explicit > paper default > best blocked score ---
    sid = args.sample_id
    if not sid:
        if DEFAULT_FIG11_SAMPLE in item_by_id and parse_planning_output(pred_by_id.get(DEFAULT_FIG11_SAMPLE, "")):
            sid = DEFAULT_FIG11_SAMPLE
        else:
            candidates = []
            for item in data_items:
                item_sid = str(item.get("id", ""))
                pred_text = pred_by_id.get(item_sid, "")
                if not pred_text:
                    continue
                plan = parse_planning_output(pred_text)
                if not plan or not plan.get("waypoints"):
                    continue
                candidates.append((_blocked_score(item), item_sid))
            candidates.sort(reverse=True)
            if not candidates:
                raise RuntimeError("No valid construction-blocked sample found.")
            sid = candidates[0][1]
    if sid not in item_by_id:
        raise RuntimeError(f"sample_id {sid} not found in data_json")

    # --- gather prediction, ground truth, boxes, and ego pose ---
    item = item_by_id[sid]
    pred_text = pred_by_id.get(sid, "")
    plan = parse_planning_output(pred_text) if pred_text else None
    if not plan or not plan.get("waypoints"):
        raise RuntimeError(f"sample_id {sid} has no parseable planning output in {args.eval_json}")
    pred_wps = np.asarray(plan["waypoints"], dtype=np.float64)
    gt_wps = np.asarray((item.get("ego_motion", {}) or {}).get("waypoints", []) or [], dtype=np.float64)
    cmd = str(item.get("route_command", ""))
    boxes = item.get("gt_boxes_3d", []) or []

    location = _get_sample_location(nusc_root, sid)
    ego_poses = _load_ego_poses(nusc_root)
    ego_info = ego_poses.get(sid)
    if ego_info is None:
        raise RuntimeError(f"Missing ego pose for sample {sid}")
    ego_xy = ego_info[1][:2]
    ego_yaw = ego_info[2]

    print(f"[sample] {sid}")
    print(f" location: {location}")
    print(f" pred: {pred_text}")

    bev_map = _build_bev_map(
        nusc_root,
        location,
        ego_xy,
        ego_yaw,
        tuple(args.bev_xlim),
        tuple(args.bev_ylim),
        bev_res=0.1,
    )

    # --- load the 6 camera images in paper order ---
    from PIL import Image

    rel_paths = list(item.get("image_paths", []) or [])
    if len(rel_paths) != 6:
        raise RuntimeError(f"Expected 6 images, got {len(rel_paths)}")
    imgs = []
    for rp in rel_paths:
        p = Path(rp)
        if not p.is_absolute():
            p = nusc_root / rp
        imgs.append(Image.open(p).convert("RGB"))
    imgs = [imgs[i] for i in _IDX_REORDER]

    cam_calibs = _load_cam_calibs(nusc_root)

    import matplotlib

    matplotlib.use("Agg")  # headless rendering
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

    # --- left panel: 2x3 camera grid ---
    fig = plt.figure(figsize=(14.8, 4.1), dpi=args.dpi)
    gs = GridSpec(1, 2, figure=fig, width_ratios=[3.0, 1.1], wspace=0.03)
    gs_cam = GridSpecFromSubplotSpec(2, 3, subplot_spec=gs[0, 0], wspace=0.01, hspace=0.01)
    ax_imgs = []
    for i in range(6):
        ax = fig.add_subplot(gs_cam[i // 3, i % 3])
        ax.imshow(imgs[i])
        w_i, h_i = imgs[i].size
        ax.set_xlim(0, w_i)
        ax.set_ylim(h_i, 0)  # image coordinates: y grows downward
        ax.axis("off")
        ax.text(
            6,
            14,
            CAM_ORDER_PAPER[i],
            color="white",
            fontsize=7,
            ha="left",
            va="top",
            bbox=dict(boxstyle="square,pad=0.12", facecolor="black", edgecolor="none", alpha=0.55),
        )
        ax_imgs.append(ax)

    # Overlay GT (green) and predicted (blue) trajectories on CAM_FRONT only.
    if "CAM_FRONT" in cam_calibs:
        R_c2e, t_c2e, K = cam_calibs["CAM_FRONT"]
        ax_front = ax_imgs[1]  # CAM_FRONT sits at paper-order index 1
        if gt_wps.shape[0] >= 2:
            uv_gt = _project_batch(
                np.array([_paper_xy_to_ego(x, y) for x, y in gt_wps], dtype=np.float64),
                R_c2e,
                t_c2e,
                K,
            )
            if uv_gt.shape[0] >= 2:
                ax_front.plot(uv_gt[:, 0], uv_gt[:, 1], color="#34c759", linewidth=4.0, alpha=0.95, zorder=18)
        uv_pred = _project_batch(
            np.array([_paper_xy_to_ego(x, y) for x, y in pred_wps], dtype=np.float64),
            R_c2e,
            t_c2e,
            K,
        )
        if uv_pred.shape[0] >= 2:
            ax_front.plot(uv_pred[:, 0], uv_pred[:, 1], color="#1f5cff", linewidth=2.2, alpha=0.98, zorder=20)
            ax_front.scatter(uv_pred[:, 0], uv_pred[:, 1], color="#1f5cff", s=8, zorder=21)

    # --- right panel: BEV (x right, y forward) ---
    ax_bev = fig.add_subplot(gs[0, 1])
    ax_bev.set_facecolor("white")
    ax_bev.set_xlim(*args.bev_xlim)
    ax_bev.set_ylim(*args.bev_ylim)
    ax_bev.set_aspect("equal", adjustable="box")
    ax_bev.set_xticks([])
    ax_bev.set_yticks([])
    for spine in ax_bev.spines.values():
        spine.set_linewidth(1.3)
        spine.set_color("black")
    if bev_map is not None:
        # _build_bev_map rows ascend in forward distance, matching this
        # ascending `ys` (see BUGFIX note there).
        ny, nx = bev_map.shape
        xs = np.linspace(args.bev_xlim[0], args.bev_xlim[1], nx)
        ys = np.linspace(args.bev_ylim[0], args.bev_ylim[1], ny)
        ax_bev.contour(xs, ys, bev_map, levels=[0.5], colors="black", linewidths=0.8, zorder=1)

    # Collect barrier/cone boxes that fall inside (or just outside) the BEV.
    construction_boxes = []
    for b in boxes:
        cat = str(b.get("category", ""))
        wc = b.get("world_coords", [0, 0, 0])
        if not (isinstance(wc, (list, tuple)) and len(wc) >= 2):
            continue
        cx, cy = float(wc[0]), float(wc[1])
        if not (args.bev_xlim[0] - 2 <= cx <= args.bev_xlim[1] + 2 and args.bev_ylim[0] - 2 <= cy <= args.bev_ylim[1] + 2):
            continue
        box_item = (cx, cy, float(b.get("w", 1.8)), float(b.get("l", 4.0)), float(b.get("yaw", 0.0)))
        if _is_barrier_like(cat):
            construction_boxes.append(box_item)

    for cx, cy, w, l, yaw in construction_boxes:
        poly = _box_corners(cx, cy, w, l, yaw)
        poly = np.vstack([poly, poly[0:1]])  # close the polygon
        ax_bev.plot(poly[:, 0], poly[:, 1], color="#cf7a6b", linewidth=0.9, alpha=0.95, zorder=3)

    if gt_wps.shape[0] >= 2:
        ax_bev.plot(gt_wps[:, 0], gt_wps[:, 1], color="#34c759", linewidth=3.4, alpha=0.95, zorder=5)
    ax_bev.plot(pred_wps[:, 0], pred_wps[:, 1], color="#1f5cff", linewidth=1.8, alpha=0.98, zorder=6)

    # Ego marker at the origin.
    ego_rect = patches.Rectangle(
        (-0.45, -0.45),
        0.9,
        0.9,
        linewidth=1.4,
        edgecolor="#34c759",
        facecolor="none",
        zorder=7,
    )
    ax_bev.add_patch(ego_rect)
    ax_bev.text(0.03, 0.98, "BEV", transform=ax_bev.transAxes, ha="left", va="top", fontsize=9, fontweight="bold")
    ax_bev.text(
        0.03,
        0.03,
        _title_cmd(cmd),
        transform=ax_bev.transAxes,
        ha="left",
        va="bottom",
        fontsize=9,
        fontweight="bold",
    )

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, bbox_inches="tight", facecolor="white", pad_inches=0.02)
    plt.close(fig)
    print(f"[saved] {out_png}")
    print(f" construction boxes: {len(construction_boxes)}")


if __name__ == "__main__":
    main()