# NOTE(review): the three lines below are Hugging Face upload-page residue
# accidentally captured into the file; commented out so the script parses.
# guoyb0's picture
# Add files using upload-large-folder tool
# 7dfc72e verified
#!/usr/bin/env python3
"""
Figure 11 — Violation of traffic regulations.
This script focuses on the construction-blocking case shown in the paper:
the road ahead is fenced by barriers / traffic cones, but Atlas still outputs
a "go straight" trajectory that cuts through the blocked area.
Style:
- Left: 6 camera views, only CAM_FRONT overlays the trajectory
- Right: clean BEV with black road boundaries, red construction boxes,
green/blue trajectory, and "Go Straight" label
"""
from __future__ import annotations
import argparse
import json
import math
import pickle
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
# Repository root (two directories above this file); prepended to sys.path so
# that `src.*` imports resolve when the script is run directly.
_REPO = Path(__file__).resolve().parent.parent
if str(_REPO) not in sys.path:
    sys.path.insert(0, str(_REPO))
# Camera panel layout used by the paper figure (2 rows x 3 columns):
# front cameras on the top row, back cameras on the bottom row.
CAM_ORDER_PAPER = [
    "CAM_FRONT_LEFT",
    "CAM_FRONT",
    "CAM_FRONT_RIGHT",
    "CAM_BACK_LEFT",
    "CAM_BACK",
    "CAM_BACK_RIGHT",
]
# Index permutation applied to the dataset's image_paths list to obtain the
# CAM_ORDER_PAPER order above — presumably the dataset stores the cameras in a
# different order; TODO confirm against the data_json image_paths ordering.
_IDX_REORDER = [2, 0, 1, 4, 3, 5]
# nuScenes map location name -> rasterized map image filename under <root>/maps/.
LOCATION_TO_MAP = {
    "singapore-onenorth": "53992ee3023e5494b90c316c183be829.png",
    "boston-seaport": "36092f0b03a857c6a3403e25b4b7aab3.png",
    "singapore-queenstown": "93406b464a165eaba6d9de76ca09f5da.png",
    "singapore-hollandvillage": "37819e65e09e5547b8a3ceaefba56bb2.png",
}
MAP_RES = 0.1  # meters per pixel of the rasterized map images
# Sample token shown in Figure 11 of the paper; used when --sample_id is absent.
DEFAULT_FIG11_SAMPLE = "856ccc626a4a4c0aaac1e62335050ac0"
def _load_json(path: Path):
with path.open("r", encoding="utf-8") as f:
return json.load(f)
def _load_pickle(path: Path):
with path.open("rb") as f:
return pickle.load(f)
def _quat_to_rotmat(qw, qx, qy, qz):
n = math.sqrt(qw * qw + qx * qx + qy * qy + qz * qz)
if n < 1e-12:
return np.eye(3, dtype=np.float64)
qw, qx, qy, qz = qw / n, qx / n, qy / n, qz / n
xx, yy, zz = qx * qx, qy * qy, qz * qz
xy, xz, yz = qx * qy, qx * qz, qy * qz
wx, wy, wz = qw * qx, qw * qy, qw * qz
return np.array(
[
[1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy)],
[2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)],
[2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy)],
],
dtype=np.float64,
)
def _quat_to_yaw(q):
w, x, y, z = [float(v) for v in q]
return math.atan2(2 * (w * z + x * y), 1 - 2 * (y * y + z * z))
def _paper_xy_to_ego(x_right: float, y_fwd: float, z_up: float = 0.0) -> np.ndarray:
return np.array([float(y_fwd), float(-x_right), float(z_up)], dtype=np.float64)
def _project_batch(pts_ego: np.ndarray, R_c2e: np.ndarray, t_c2e: np.ndarray, K: np.ndarray) -> np.ndarray:
R_e2c = R_c2e.T
pts_cam = (R_e2c @ (pts_ego - t_c2e[None, :]).T).T
z = pts_cam[:, 2]
keep = z > 1e-3
pts_cam = pts_cam[keep]
if pts_cam.shape[0] == 0:
return np.zeros((0, 2), dtype=np.float64)
x = pts_cam[:, 0] / pts_cam[:, 2]
y = pts_cam[:, 1] / pts_cam[:, 2]
fx, fy = float(K[0, 0]), float(K[1, 1])
cx, cy = float(K[0, 2]), float(K[1, 2])
return np.stack([fx * x + cx, fy * y + cy], axis=1)
def _load_cam_calibs(nusc_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Load (rotation, translation, intrinsics) for every camera channel.

    Reads sensor.json and calibrated_sensor.json from v1.0-trainval and keeps
    only records that have a well-formed quaternion and a 3x3 intrinsic matrix.
    """
    meta_dir = nusc_root / "v1.0-trainval"
    sensors = _load_json(meta_dir / "sensor.json")
    calibs = _load_json(meta_dir / "calibrated_sensor.json")

    # channel name -> sensor token (later records overwrite earlier ones).
    token_of_channel: Dict[str, str] = {}
    for rec in sensors:
        channel = rec.get("channel")
        token = rec.get("token")
        if isinstance(channel, str) and isinstance(token, str):
            token_of_channel[channel] = token

    # sensor token -> calibration record.
    calib_of_token: Dict[str, Dict] = {
        rec["sensor_token"]: rec for rec in calibs if isinstance(rec.get("sensor_token"), str)
    }

    result: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]] = {}
    for channel, token in token_of_channel.items():
        rec = calib_of_token.get(token)
        if not rec or "camera_intrinsic" not in rec:
            # Non-camera sensors (lidar/radar) have no intrinsics; skip them.
            continue
        translation = np.asarray(rec.get("translation", [0, 0, 0]), dtype=np.float64).reshape(3)
        quat = rec.get("rotation", [1, 0, 0, 0])
        intrinsic = np.asarray(rec.get("camera_intrinsic", np.eye(3).tolist()), dtype=np.float64)
        if not (isinstance(quat, (list, tuple)) and len(quat) == 4):
            continue
        if intrinsic.shape != (3, 3):
            continue
        rotation = _quat_to_rotmat(*[float(v) for v in quat])
        result[channel] = (rotation, translation, intrinsic)
    return result
def _load_ego_poses(nusc_root: Path) -> Dict[str, Tuple[np.ndarray, np.ndarray, float]]:
    """Map each sample token to its ego2global pose as (rotation, translation, yaw).

    Reads nuscenes_infos_val.pkl from the dataset root; the pickle is either a
    dict with an "infos" list or the list itself.
    """
    payload = _load_pickle(nusc_root / "nuscenes_infos_val.pkl")
    records = payload["infos"] if isinstance(payload, dict) and "infos" in payload else payload
    poses: Dict[str, Tuple[np.ndarray, np.ndarray, float]] = {}
    for info in records:
        token = str(info.get("token", ""))
        if not token:
            continue
        translation = np.asarray(
            info.get("ego2global_translation", [0, 0, 0]), dtype=np.float64
        ).reshape(3)
        quat = info.get("ego2global_rotation", [1, 0, 0, 0])
        if not (isinstance(quat, (list, tuple)) and len(quat) == 4):
            continue
        poses[token] = (
            _quat_to_rotmat(*[float(v) for v in quat]),
            translation,
            _quat_to_yaw(quat),
        )
    return poses
def _get_sample_location(nusc_root: Path, sample_token: str) -> str:
samples = _load_json(nusc_root / "v1.0-trainval" / "sample.json")
scenes = _load_json(nusc_root / "v1.0-trainval" / "scene.json")
logs = _load_json(nusc_root / "v1.0-trainval" / "log.json")
scene_map = {s["token"]: s for s in scenes}
log_map = {l["token"]: l for l in logs}
for s in samples:
if s["token"] == sample_token:
sc = scene_map.get(s.get("scene_token", ""), {})
lg = log_map.get(sc.get("log_token", ""), {})
return str(lg.get("location", ""))
return ""
def _smooth_map(bev_map: np.ndarray) -> np.ndarray:
acc = bev_map.copy()
acc += np.roll(bev_map, 1, axis=0)
acc += np.roll(bev_map, -1, axis=0)
acc += np.roll(bev_map, 1, axis=1)
acc += np.roll(bev_map, -1, axis=1)
acc += np.roll(np.roll(bev_map, 1, axis=0), 1, axis=1)
acc += np.roll(np.roll(bev_map, 1, axis=0), -1, axis=1)
acc += np.roll(np.roll(bev_map, -1, axis=0), 1, axis=1)
acc += np.roll(np.roll(bev_map, -1, axis=0), -1, axis=1)
return acc / 9.0
def _build_bev_map(
    nusc_root: Path,
    location: str,
    ego_xy: np.ndarray,
    ego_yaw: float,
    bev_xlim: Tuple[float, float],
    bev_ylim: Tuple[float, float],
    bev_res: float = 0.1,
) -> Optional[np.ndarray]:
    """Sample the rasterized nuScenes map around the ego pose into an ego-aligned BEV grid.

    Args:
        nusc_root: nuScenes dataset root (must contain maps/<hash>.png).
        location: log location string, used to pick the map image via LOCATION_TO_MAP.
        ego_xy: ego (x, y) position in global/map coordinates, meters.
        ego_yaw: ego heading in radians (global frame).
        bev_xlim, bev_ylim: BEV extent in ego-relative meters; from the usage
            below, x is the lateral axis and y the forward axis.
        bev_res: BEV grid resolution, meters per cell.

    Returns:
        A smoothed (ny, nx) float array of normalized map intensity, or None
        when the location is unknown or the map image is missing.
    """
    # PIL rejects very large images by default; the nuScenes map rasters exceed
    # that limit, so disable the decompression-bomb cap before opening.
    import PIL.Image
    PIL.Image.MAX_IMAGE_PIXELS = None
    from PIL import Image
    map_fn = LOCATION_TO_MAP.get(location)
    if not map_fn:
        return None
    map_path = nusc_root / "maps" / map_fn
    if not map_path.exists():
        return None
    map_img = Image.open(map_path)
    mw, mh = map_img.size
    # Total map height in meters; used to flip between image rows (top-down)
    # and global y (bottom-up) below.
    map_max_y = mh * MAP_RES
    map_arr = np.asarray(map_img, dtype=np.float32) / 255.0
    ex, ey = float(ego_xy[0]), float(ego_xy[1])
    c_yaw, s_yaw = math.cos(ego_yaw), math.sin(ego_yaw)
    x0, x1 = bev_xlim
    y0, y1 = bev_ylim
    nx = int((x1 - x0) / bev_res)
    ny = int((y1 - y0) / bev_res)
    bev = np.zeros((ny, nx), dtype=np.float32)
    # Row 0 of the BEV grid corresponds to y1 (the far/forward edge): py_arr descends.
    # NOTE(review): main() later hands this array to contour() with an ascending
    # y axis, which would flip it vertically — confirm the rendered orientation.
    px_arr = np.linspace(x0, x1, nx)
    py_arr = np.linspace(y1, y0, ny)
    PX, PY = np.meshgrid(px_arr, py_arr)
    # Rotate ego-relative offsets into the global frame; PY (forward) pairs with
    # cos/sin and PX (lateral) with sin/-cos — presumably matching the dataset's
    # yaw convention (TODO confirm against _quat_to_yaw usage).
    GX = ex + PY * c_yaw + PX * s_yaw
    GY = ey + PY * s_yaw - PX * c_yaw
    MX = (GX / MAP_RES).astype(np.int32)
    # Image rows grow downward while global y grows upward, hence the flip.
    MY = ((map_max_y - GY) / MAP_RES).astype(np.int32)
    valid = (MX >= 0) & (MX < mw) & (MY >= 0) & (MY < mh)
    bev[valid] = map_arr[MY[valid], MX[valid]]
    # Light 3x3 blur to soften aliasing from nearest-neighbor sampling.
    return _smooth_map(bev)
def _box_corners(cx: float, cy: float, w: float, l: float, yaw: float) -> np.ndarray:
c, s = math.cos(yaw), math.sin(yaw)
center = np.array([cx, cy], dtype=np.float64)
d_len = np.array([c, s], dtype=np.float64) * (l / 2.0)
d_wid = np.array([-s, c], dtype=np.float64) * (w / 2.0)
return np.stack(
[
center + d_len + d_wid,
center + d_len - d_wid,
center - d_len - d_wid,
center - d_len + d_wid,
],
axis=0,
)
def _is_barrier_like(cat: str) -> bool:
return ("barrier" in cat) or ("traffic_cone" in cat) or ("trafficcone" in cat)
def _is_context_like(cat: str) -> bool:
return ("construction" in cat) or ("pedestrian" in cat) or ("vehicle.car" in cat)
def _title_cmd(cmd: str) -> str:
c = (cmd or "").strip().lower()
return {
"turn left": "Turn Left",
"turn right": "Turn Right",
"go straight": "Go Straight",
}.get(c, cmd)
def _blocked_score(item: Dict) -> float:
route_command = str(item.get("route_command", "")).strip().lower()
if route_command != "go straight":
return -1e9
boxes = item.get("gt_boxes_3d", []) or []
wps = (item.get("ego_motion", {}) or {}).get("waypoints", []) or []
n_block_front = 0
n_block_center = 0
n_barrier = 0
n_cone = 0
for b in boxes:
cat = str(b.get("category", ""))
wc = b.get("world_coords", [0, 0, 0])
if not (isinstance(wc, (list, tuple)) and len(wc) >= 2):
continue
x, y = float(wc[0]), float(wc[1])
if _is_barrier_like(cat):
if "barrier" in cat:
n_barrier += 1
else:
n_cone += 1
if 0 < y < 25 and abs(x) < 10:
n_block_front += 1
if 2 < y < 20 and abs(x) < 4:
n_block_center += 1
through_center = 0
for x, y in wps:
if 2 < float(y) < 20 and abs(float(x)) < 4:
through_center += 1
return n_block_center * 8 + n_block_front * 3 + through_center * 4 + n_barrier * 1.5 + n_cone
def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line options for the figure script.

    Args:
        argv: Optional explicit argument list for programmatic use and testing;
            ``None`` (the default, matching the previous behavior) parses
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with paths, sample selection, DPI and BEV extents.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--eval_json", default="work_dirs/eval_final_plan100.json")
    ap.add_argument("--data_json", default="data/atlas_planning_val_uniad_command.json")
    ap.add_argument("--data_root", default="/home/guoyuanbo/autodl-tmp/data/nuscenes")
    ap.add_argument("--sample_id", default=None)
    ap.add_argument("--out_png", default="work_dirs/atlas_traffic_violation.png")
    ap.add_argument("--dpi", type=int, default=200)
    # BEV extents in ego-relative meters: x = lateral, y = forward.
    ap.add_argument("--bev_xlim", type=float, nargs=2, default=[-7.5, 10.5])
    ap.add_argument("--bev_ylim", type=float, nargs=2, default=[-1.5, 30.5])
    return ap.parse_args(argv)
def main():
    """Entry point: select a blocked "go straight" sample and render the figure.

    Loads the evaluation predictions and planning dataset, picks a sample
    (CLI --sample_id, the paper's default token, or the best candidate by
    _blocked_score), then draws the six camera views plus a BEV panel and
    saves the composite PNG.
    """
    args = parse_args()
    repo = _REPO
    nusc_root = Path(args.data_root).resolve()
    out_png = (repo / args.out_png).resolve()
    # Model predictions keyed by sample id; dataset items keyed by id.
    eval_obj = _load_json((repo / args.eval_json).resolve())
    data_items = _load_json((repo / args.data_json).resolve())
    pred_by_id = {
        str(r.get("sample_id", "")): str(r.get("generated_text", ""))
        for r in eval_obj.get("predictions", [])
        if r.get("sample_id")
    }
    item_by_id = {str(it["id"]): it for it in data_items if it.get("id")}
    # Imported lazily so the sys.path tweak at module top has already run.
    from src.eval.metrics import parse_planning_output
    # --- sample selection -------------------------------------------------
    sid = args.sample_id
    if not sid:
        # Prefer the fixed paper sample when its prediction parses; otherwise
        # rank every candidate with the construction-blocked heuristic.
        if DEFAULT_FIG11_SAMPLE in item_by_id and parse_planning_output(pred_by_id.get(DEFAULT_FIG11_SAMPLE, "")):
            sid = DEFAULT_FIG11_SAMPLE
        else:
            candidates = []
            for item in data_items:
                item_sid = str(item.get("id", ""))
                pred_text = pred_by_id.get(item_sid, "")
                if not pred_text:
                    continue
                plan = parse_planning_output(pred_text)
                if not plan or not plan.get("waypoints"):
                    continue
                candidates.append((_blocked_score(item), item_sid))
            candidates.sort(reverse=True)
            if not candidates:
                raise RuntimeError("No valid construction-blocked sample found.")
            sid = candidates[0][1]
    if sid not in item_by_id:
        raise RuntimeError(f"sample_id {sid} not found in data_json")
    item = item_by_id[sid]
    pred_text = pred_by_id.get(sid, "")
    plan = parse_planning_output(pred_text) if pred_text else None
    if not plan or not plan.get("waypoints"):
        raise RuntimeError(f"sample_id {sid} has no parseable planning output in {args.eval_json}")
    # Predicted and ground-truth trajectories in the paper frame (x=right, y=forward).
    pred_wps = np.asarray(plan["waypoints"], dtype=np.float64)
    gt_wps = np.asarray((item.get("ego_motion", {}) or {}).get("waypoints", []) or [], dtype=np.float64)
    cmd = str(item.get("route_command", ""))
    boxes = item.get("gt_boxes_3d", []) or []
    location = _get_sample_location(nusc_root, sid)
    ego_poses = _load_ego_poses(nusc_root)
    ego_info = ego_poses.get(sid)
    if ego_info is None:
        raise RuntimeError(f"Missing ego pose for sample {sid}")
    ego_xy = ego_info[1][:2]
    ego_yaw = ego_info[2]
    print(f"[sample] {sid}")
    print(f" location: {location}")
    print(f" pred: {pred_text}")
    bev_map = _build_bev_map(
        nusc_root,
        location,
        ego_xy,
        ego_yaw,
        tuple(args.bev_xlim),
        tuple(args.bev_ylim),
        bev_res=0.1,
    )
    # --- load the six camera images ---------------------------------------
    from PIL import Image
    rel_paths = list(item.get("image_paths", []) or [])
    if len(rel_paths) != 6:
        raise RuntimeError(f"Expected 6 images, got {len(rel_paths)}")
    imgs = []
    for rp in rel_paths:
        p = Path(rp)
        if not p.is_absolute():
            p = nusc_root / rp
        imgs.append(Image.open(p).convert("RGB"))
    # Rearrange into the paper's panel order (see _IDX_REORDER).
    imgs = [imgs[i] for i in _IDX_REORDER]
    cam_calibs = _load_cam_calibs(nusc_root)
    # Headless backend: the script only writes a PNG.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
    fig = plt.figure(figsize=(14.8, 4.1), dpi=args.dpi)
    # Left cell: 2x3 camera grid; right cell: BEV panel.
    gs = GridSpec(1, 2, figure=fig, width_ratios=[3.0, 1.1], wspace=0.03)
    gs_cam = GridSpecFromSubplotSpec(2, 3, subplot_spec=gs[0, 0], wspace=0.01, hspace=0.01)
    ax_imgs = []
    for i in range(6):
        ax = fig.add_subplot(gs_cam[i // 3, i % 3])
        ax.imshow(imgs[i])
        w_i, h_i = imgs[i].size
        ax.set_xlim(0, w_i)
        # Inverted y-limits so pixel row 0 sits at the top of the axes.
        ax.set_ylim(h_i, 0)
        ax.axis("off")
        ax.text(
            6,
            14,
            CAM_ORDER_PAPER[i],
            color="white",
            fontsize=7,
            ha="left",
            va="top",
            bbox=dict(boxstyle="square,pad=0.12", facecolor="black", edgecolor="none", alpha=0.55),
        )
        ax_imgs.append(ax)
    # --- trajectory overlay on CAM_FRONT only (paper style) ----------------
    if "CAM_FRONT" in cam_calibs:
        R_c2e, t_c2e, K = cam_calibs["CAM_FRONT"]
        ax_front = ax_imgs[1]  # CAM_FRONT is the middle panel of the top row
        if gt_wps.shape[0] >= 2:
            uv_gt = _project_batch(
                np.array([_paper_xy_to_ego(x, y) for x, y in gt_wps], dtype=np.float64),
                R_c2e,
                t_c2e,
                K,
            )
            if uv_gt.shape[0] >= 2:
                # Ground truth in green, drawn under the prediction.
                ax_front.plot(uv_gt[:, 0], uv_gt[:, 1], color="#34c759", linewidth=4.0, alpha=0.95, zorder=18)
        uv_pred = _project_batch(
            np.array([_paper_xy_to_ego(x, y) for x, y in pred_wps], dtype=np.float64),
            R_c2e,
            t_c2e,
            K,
        )
        if uv_pred.shape[0] >= 2:
            # Prediction in blue, on top of the GT curve.
            ax_front.plot(uv_pred[:, 0], uv_pred[:, 1], color="#1f5cff", linewidth=2.2, alpha=0.98, zorder=20)
            ax_front.scatter(uv_pred[:, 0], uv_pred[:, 1], color="#1f5cff", s=8, zorder=21)
    # --- BEV panel ----------------------------------------------------------
    ax_bev = fig.add_subplot(gs[0, 1])
    ax_bev.set_facecolor("white")
    ax_bev.set_xlim(*args.bev_xlim)
    ax_bev.set_ylim(*args.bev_ylim)
    ax_bev.set_aspect("equal", adjustable="box")
    ax_bev.set_xticks([])
    ax_bev.set_yticks([])
    for spine in ax_bev.spines.values():
        spine.set_linewidth(1.3)
        spine.set_color("black")
    if bev_map is not None:
        ny, nx = bev_map.shape
        xs = np.linspace(args.bev_xlim[0], args.bev_xlim[1], nx)
        # NOTE(review): ys ascends here while _build_bev_map fills row 0 at the
        # far (max-y) edge, so the contour may be vertically flipped — confirm
        # the rendered map orientation against the camera views.
        ys = np.linspace(args.bev_ylim[0], args.bev_ylim[1], ny)
        # Road boundaries as a black iso-line of the map intensity.
        ax_bev.contour(xs, ys, bev_map, levels=[0.5], colors="black", linewidths=0.8, zorder=1)
    # Collect barrier/cone boxes inside the BEV window (with a 2 m margin).
    construction_boxes = []
    for b in boxes:
        cat = str(b.get("category", ""))
        wc = b.get("world_coords", [0, 0, 0])
        if not (isinstance(wc, (list, tuple)) and len(wc) >= 2):
            continue
        cx, cy = float(wc[0]), float(wc[1])
        if not (args.bev_xlim[0] - 2 <= cx <= args.bev_xlim[1] + 2 and args.bev_ylim[0] - 2 <= cy <= args.bev_ylim[1] + 2):
            continue
        # Nominal car-sized footprint when w/l/yaw are absent from the record.
        box_item = (cx, cy, float(b.get("w", 1.8)), float(b.get("l", 4.0)), float(b.get("yaw", 0.0)))
        if _is_barrier_like(cat):
            construction_boxes.append(box_item)
    for cx, cy, w, l, yaw in construction_boxes:
        poly = _box_corners(cx, cy, w, l, yaw)
        poly = np.vstack([poly, poly[0:1]])  # close the rectangle outline
        ax_bev.plot(poly[:, 0], poly[:, 1], color="#cf7a6b", linewidth=0.9, alpha=0.95, zorder=3)
    # Trajectories: ground truth (green) under the prediction (blue).
    if gt_wps.shape[0] >= 2:
        ax_bev.plot(gt_wps[:, 0], gt_wps[:, 1], color="#34c759", linewidth=3.4, alpha=0.95, zorder=5)
    ax_bev.plot(pred_wps[:, 0], pred_wps[:, 1], color="#1f5cff", linewidth=1.8, alpha=0.98, zorder=6)
    # Ego footprint marker at the BEV origin.
    ego_rect = patches.Rectangle(
        (-0.45, -0.45),
        0.9,
        0.9,
        linewidth=1.4,
        edgecolor="#34c759",
        facecolor="none",
        zorder=7,
    )
    ax_bev.add_patch(ego_rect)
    ax_bev.text(0.03, 0.98, "BEV", transform=ax_bev.transAxes, ha="left", va="top", fontsize=9, fontweight="bold")
    # Route-command label, e.g. "Go Straight", bottom-left of the BEV panel.
    ax_bev.text(
        0.03,
        0.03,
        _title_cmd(cmd),
        transform=ax_bev.transAxes,
        ha="left",
        va="bottom",
        fontsize=9,
        fontweight="bold",
    )
    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, bbox_inches="tight", facecolor="white", pad_inches=0.02)
    plt.close(fig)
    print(f"[saved] {out_png}")
    print(f" construction boxes: {len(construction_boxes)}")
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()