# WJAD: src/wjad/data/hdmap.py (commit 0cfefd2, "Sync WJAD codebase", by fuzirui)
"""HDMap 3D 标签解析(Cosmos-Drive-Dreams 9 类结构化对象)。
输入:clip 标签目录(``labels/{clip_id_full}/``)。
输出:``list[ObjectTrackInfo]``,每个对象给出 ``object_to_world`` 4x4 + ``lwh``,
``object_type`` 取自 ``HDMAP_SOURCES`` 的 9 类,``is_moving=False``。
形状约定(按 README):
- 3d_lanes / lanes.json
labels[i]['labelData']['shape3d']['polylines3d']['polylines'][0/1]['vertices']
- 3d_lanelines / lanelines.json
labels[i]['labelData']['shape3d']['polyline3d']['vertices']
- 3d_road_boundaries / road_boundaries.json 同 polyline3d
- 3d_wait_lines / wait_lines.json 同 polyline3d
- 3d_crosswalks / crosswalks.json
labels[i]['labelData']['shape3d']['surface']['vertices']
- 3d_road_markings / road_markings.json 同 surface
- 3d_poles / poles.json 同 polyline3d
- 3d_traffic_lights / 3d_traffic_lights.json
labels[i]['labelData']['shape3d']['cuboid3d']['vertices'] # 8 角点
- 3d_traffic_signs / 3d_traffic_signs.json 同 cuboid3d
折线 → 7-DoF box:
PCA 主方向作 yaw,主/副/竖三向 min-max 作 ``l/w/h``;过长 polyline 按累计
弧长切成若干 ``segment_len`` 米的小段,每段一个独立 box(车道线一段太长会
超出 max_distance_m,DETR query 也很难一次拟合一整条 100 m 车道线)。
"""
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import torch
from .targets import ObjectTrackInfo
from .label_paths import resolve_clip_file
# Arc length (metres) above which a polyline-type label is cut into several boxes.
POLYLINE_SEGMENT_LEN: float = 10.0
# Lanes are a left/right polyline pair, so they are cut into coarser pieces.
LANE_SEGMENT_LEN: float = 15.0
# Floor applied to every box's (l, w, h) so degenerate geometry keeps a non-zero size.
MIN_LWH: tuple[float, float, float] = (0.2, 0.2, 0.05)

# cls_name -> (label sub-directory, JSON file name, geometry kind).
# ``parse_hdmap_clip`` visits the classes in this insertion order.
HDMAP_SOURCES: dict[str, tuple[str, str, str]] = {
    "lane": ("3d_lanes", "lanes.json", "lane_pair"),
    "laneline": ("3d_lanelines", "lanelines.json", "polyline"),
    "road_boundary": ("3d_road_boundaries", "road_boundaries.json", "polyline"),
    "wait_line": ("3d_wait_lines", "wait_lines.json", "polyline"),
    "crosswalk": ("3d_crosswalks", "crosswalks.json", "surface"),
    "road_marking": ("3d_road_markings", "road_markings.json", "surface"),
    "pole": ("3d_poles", "poles.json", "polyline_short"),
    # On disk the file is named ``{clip_stem}.traffic_lights.json``
    # (not the README's ``3d_*.json``).
    "traffic_light": ("3d_traffic_lights", "traffic_lights.json", "cuboid"),
    "traffic_sign": ("3d_traffic_signs", "traffic_signs.json", "cuboid"),
}
def _load_json_labels(path: Path) -> list:
"""容错读取:JSON 顶层可能是 ``{labels: ...}`` 或 ``{<filename>: {labels: ...}}``。"""
if not path.exists():
return []
try:
data = json.loads(path.read_text(encoding="utf-8"))
except Exception:
return []
if isinstance(data, dict):
if isinstance(data.get("labels"), list):
return data["labels"]
for v in data.values():
if isinstance(v, dict) and isinstance(v.get("labels"), list):
return v["labels"]
return []
def _verts_to_array(verts) -> np.ndarray:
    """Convert a JSON ``vertices`` list into an ``[N, 3]`` float32 array.

    Two entry formats are accepted:
    - ``[x, y, z, ...]`` lists/tuples (only the first three values are used);
    - ``{"x": .., "y": .., "z": ..}`` dicts (a missing key reads as 0.0).

    Any other entry (e.g. a list shorter than 3) is silently skipped. An empty or
    ``None`` input, or one where every entry is skipped, gives a ``(0, 3)`` array.
    """
    rows: list[list[float]] = []
    for vert in verts or ():
        if isinstance(vert, dict):
            rows.append([float(vert.get(axis, 0.0)) for axis in ("x", "y", "z")])
        elif isinstance(vert, (list, tuple)) and len(vert) >= 3:
            rows.append([float(coord) for coord in vert[:3]])
    if not rows:
        return np.zeros((0, 3), dtype=np.float32)
    return np.array(rows, dtype=np.float32)
def _split_polyline(verts: np.ndarray, seg_len: float) -> list[np.ndarray]:
    """Cut a polyline into consecutive pieces of at most ``seg_len`` arc length.

    Args:
        verts: ``[N, D]`` vertex array (``[N, 3]`` for the callers in this module).
        seg_len: Maximum piece length, in the same units as ``verts``.

    Returns:
        ``[]`` if ``N < 2``. ``[verts]`` unchanged if the whole polyline fits into
        one piece, or if ``seg_len <= 0``. Otherwise ``ceil(total / seg_len)`` pieces
        of equal arc length. The cut points are linearly interpolated on the polyline
        and shared by adjacent pieces. Therefore every piece has >= 2 vertices, and
        the pieces cover the whole polyline even when a single edge is longer than
        ``seg_len``. (Previously, windows containing < 2 original vertices were
        dropped, which could lose the entire polyline.)
    """
    if verts.shape[0] < 2:
        return []
    edges = np.linalg.norm(np.diff(verts, axis=0), axis=1)
    cum = np.concatenate([[0.0], np.cumsum(edges)])
    total = float(cum[-1])
    if seg_len <= 0 or total <= seg_len:
        return [verts]
    n = int(np.ceil(total / seg_len))
    bounds = np.linspace(0.0, total, n + 1)
    # Point on the polyline at every cut position, interpolated per coordinate.
    cut_pts = np.stack(
        [np.interp(bounds, cum, verts[:, k]) for k in range(verts.shape[1])], axis=1
    ).astype(verts.dtype, copy=False)
    eps = 1e-6
    chunks: list[np.ndarray] = []
    for i in range(n):
        # Original vertices strictly inside this window. A vertex sitting on a cut
        # is already represented by the interpolated cut point.
        inner = verts[(cum > bounds[i] + eps) & (cum < bounds[i + 1] - eps)]
        chunks.append(
            np.concatenate([cut_pts[i : i + 1], inner, cut_pts[i + 1 : i + 2]], axis=0)
        )
    return chunks
def _vertices_to_box(verts: np.ndarray) -> tuple[np.ndarray, np.ndarray, float] | None:
"""[N, 3] -> (center, lwh, yaw)。"""
if verts.shape[0] < 2:
return None
center = verts.mean(0)
centered_xy = verts[:, :2] - center[:2]
if np.allclose(centered_xy, 0.0):
yaw = 0.0
else:
cov = centered_xy.T @ centered_xy / max(verts.shape[0] - 1, 1)
_, eigvecs = np.linalg.eigh(cov)
principal = eigvecs[:, -1]
yaw = float(np.arctan2(principal[1], principal[0]))
c, s = float(np.cos(-yaw)), float(np.sin(-yaw))
rot_xy = centered_xy @ np.array([[c, -s], [s, c]], dtype=np.float32).T
l = float(rot_xy[:, 0].max() - rot_xy[:, 0].min())
w = float(rot_xy[:, 1].max() - rot_xy[:, 1].min())
h = float(verts[:, 2].max() - verts[:, 2].min())
l = max(l, MIN_LWH[0])
w = max(w, MIN_LWH[1])
h = max(h, MIN_LWH[2])
return center.astype(np.float32), np.array([l, w, h], dtype=np.float32), yaw
def _cuboid_to_box(corners: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]:
    """Fit a yaw-only box to the 8 corners of a labelled cuboid.

    The heading is the x-y direction of the edge ``corners[0] -> corners[1]``.
    NOTE(review): this assumes that corner pair spans a horizontal edge; confirm
    against the dataset's corner ordering. The centre is the corner mean, and
    ``lwh`` is the corners' extent in the yaw-aligned frame, floored at ``MIN_LWH``.

    Returns:
        ``(center [3] float32, lwh [3] float32, yaw)``.
    """
    origin = corners.mean(0)
    heading = corners[1] - corners[0]
    yaw = float(np.arctan2(heading[1], heading[0]))
    cos_n = float(np.cos(-yaw))
    sin_n = float(np.sin(-yaw))
    # Transpose of the rotation by -yaw about z: ``offsets @ to_local`` maps world
    # offsets into the box frame.
    to_local = np.array(
        [[cos_n, sin_n, 0.0], [-sin_n, cos_n, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32
    )
    local = (corners - origin) @ to_local
    extent = np.ptp(local, axis=0).astype(np.float32)
    floor = np.array(MIN_LWH, dtype=np.float32)
    return origin.astype(np.float32), np.maximum(extent, floor), yaw
def _build_object(
    center: np.ndarray,
    lwh: np.ndarray,
    yaw: float,
    cls_name: str,
    idx: int,
    sub_idx: int = 0,
) -> ObjectTrackInfo:
    """Wrap one fitted box as a static ``ObjectTrackInfo``.

    Args:
        center: Box centre ``[3]`` in world coordinates.
        lwh: Box size ``[3]``. It is wrapped with ``torch.from_numpy``, so the tensor
            shares memory with this array.
        yaw: Heading about +z, in radians.
        cls_name: HDMap class name, used as ``object_type`` and inside the id.
        idx: Label index within the class's JSON file.
        sub_idx: Piece index when one label is split into several boxes.

    Returns:
        ``ObjectTrackInfo`` with a 4x4 float32 ``object_to_world`` (rotation by
        ``yaw`` about z plus translation ``center``), ``is_moving=False`` and
        ``tracking_id`` ``hdmap_{cls_name}_{idx}_{sub_idx}``.
    """
    cos_y = float(np.cos(yaw))
    sin_y = float(np.sin(yaw))
    rotation = np.array(
        [[cos_y, -sin_y, 0.0], [sin_y, cos_y, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32
    )
    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = rotation
    pose[:3, 3] = center
    track_id = f"hdmap_{cls_name}_{idx}_{sub_idx}"
    return ObjectTrackInfo(
        tracking_id=track_id,
        object_to_world=torch.from_numpy(pose),
        lwh=torch.from_numpy(lwh),
        is_moving=False,
        object_type=cls_name,
    )
def _as_dict(value) -> dict:
    """Return ``value`` if it is a dict, else ``{}`` (tolerates ``null``/list JSON fields)."""
    return value if isinstance(value, dict) else {}


def _arc_lengths(verts: np.ndarray) -> np.ndarray:
    """Cumulative arc length at each vertex of an ``[N, D]`` polyline (first entry 0)."""
    edges = np.linalg.norm(np.diff(verts, axis=0), axis=1)
    return np.concatenate([[0.0], np.cumsum(edges)])


def _equal_arc_pieces(verts: np.ndarray, n: int) -> list[np.ndarray]:
    """Cut a polyline (>= 2 vertices) into ``n`` contiguous pieces of equal arc length.

    The cut points are linearly interpolated and shared by neighbouring pieces, so
    every piece has >= 2 vertices and together the pieces cover the whole polyline.
    """
    cum = _arc_lengths(verts)
    bounds = np.linspace(0.0, float(cum[-1]), n + 1)
    cuts = np.stack(
        [np.interp(bounds, cum, verts[:, k]) for k in range(verts.shape[1])], axis=1
    ).astype(verts.dtype, copy=False)
    pieces: list[np.ndarray] = []
    for i in range(n):
        inner = verts[(cum > bounds[i] + 1e-6) & (cum < bounds[i + 1] - 1e-6)]
        pieces.append(np.concatenate([cuts[i : i + 1], inner, cuts[i + 1 : i + 2]], axis=0))
    return pieces


def _lane_pair_chunks(left: np.ndarray, right: np.ndarray, seg_len: float) -> list[np.ndarray]:
    """Split a lane (both boundaries >= 2 vertices) into chunks holding BOTH sides.

    Each side is cut into the same number of equal-arc pieces, so chunk ``j`` stacks
    the left and right boundary over the same stretch of lane and its box spans the
    full lane width. Concatenating the two polylines end-to-end, as was done before,
    introduced a fake edge from the end of the left side back to the start of the
    right side.
    """
    # Orient the right boundary to run the same way as the left before pairing pieces.
    aligned = np.linalg.norm(left[0] - right[0]) + np.linalg.norm(left[-1] - right[-1])
    crossed = np.linalg.norm(left[0] - right[-1]) + np.linalg.norm(left[-1] - right[0])
    if crossed < aligned:
        right = right[::-1]
    longest = max(float(_arc_lengths(left)[-1]), float(_arc_lengths(right)[-1]))
    if seg_len <= 0 or longest <= seg_len:
        return [np.concatenate([left, right], axis=0)]
    n = int(np.ceil(longest / seg_len))
    return [
        np.concatenate([left_piece, right_piece], axis=0)
        for left_piece, right_piece in zip(
            _equal_arc_pieces(left, n), _equal_arc_pieces(right, n)
        )
    ]


def _label_boxes(
    shape: dict,
    kind: str,
    segment_len: float,
    lane_segment_len: float,
) -> list[tuple[np.ndarray, np.ndarray, float] | None]:
    """Turn one label's ``shape3d`` dict into a list of (center, lwh, yaw) boxes.

    The list index is used as the piece index (``sub_idx``). Entries may be ``None``
    when a piece is too small to fit. Unknown kinds and malformed geometry give ``[]``.
    """
    if kind == "cuboid":
        corners = _verts_to_array(_as_dict(shape.get("cuboid3d")).get("vertices", []))
        return [_cuboid_to_box(corners)] if corners.shape[0] == 8 else []
    if kind == "surface":
        verts = _verts_to_array(_as_dict(shape.get("surface")).get("vertices", []))
        chunks = [verts] if verts.shape[0] >= 3 else []
    elif kind in ("polyline", "polyline_short"):
        verts = _verts_to_array(_as_dict(shape.get("polyline3d")).get("vertices", []))
        if verts.shape[0] < 2:
            return []
        # Poles ("polyline_short") are kept whole; other polylines are cut by arc length.
        chunks = _split_polyline(verts, segment_len) if kind == "polyline" else [verts]
    elif kind == "lane_pair":
        pl_root = _as_dict(shape.get("polylines3d")).get("polylines", [])
        if not isinstance(pl_root, list) or len(pl_root) < 2:
            return []
        left = _verts_to_array(_as_dict(pl_root[0]).get("vertices", []))
        right = _verts_to_array(_as_dict(pl_root[1]).get("vertices", []))
        if left.shape[0] >= 2 and right.shape[0] >= 2:
            chunks = _lane_pair_chunks(left, right, lane_segment_len)
        else:
            # Degenerate lane: treat whatever vertices exist as a single polyline.
            merged = np.concatenate([left, right], axis=0)
            if merged.shape[0] < 2:
                return []
            chunks = _split_polyline(merged, lane_segment_len)
    else:
        return []
    return [_vertices_to_box(chunk) for chunk in chunks]


def parse_hdmap_clip(
    clip_label_dir: Path,
    segment_len: float = POLYLINE_SEGMENT_LEN,
    lane_segment_len: float = LANE_SEGMENT_LEN,
) -> list[ObjectTrackInfo]:
    """Parse a clip's 9 HDMap classes into world-frame ``ObjectTrackInfo`` boxes.

    For every class in ``HDMAP_SOURCES``, the label file is located with
    ``resolve_clip_file``. Classes whose file is missing are skipped. Each label's
    ``labelData.shape3d`` geometry becomes one or more boxes:

    - ``cuboid``: 8 corners -> one box (labels without exactly 8 corners are skipped);
    - ``surface``: >= 3 vertices -> one box;
    - ``polyline``: cut by arc length into pieces of at most ``segment_len``, one box each;
    - ``polyline_short`` (poles): never cut -> one box;
    - ``lane_pair``: left/right boundaries cut into matched pieces of at most
      ``lane_segment_len`` (measured on the longer side). Each box covers both
      boundaries over the same stretch.

    Malformed entries (non-dict labels or fields, too few vertices) are skipped
    rather than raising.

    Args:
        clip_label_dir: The clip's label directory.
        segment_len: Maximum piece length (metres) for ``polyline`` classes.
        lane_segment_len: Maximum piece length (metres) for lanes.

    Returns:
        Boxes in class order with tracking ids ``hdmap_{cls}_{label_idx}_{piece_idx}``.
    """
    out: list[ObjectTrackInfo] = []
    for cls_name, (subdir, json_name, kind) in HDMAP_SOURCES.items():
        try:
            path = resolve_clip_file(clip_label_dir, subdir, json_name)
        except FileNotFoundError:
            continue  # this clip has no file for the class
        for i, lbl in enumerate(_load_json_labels(path)):
            if not isinstance(lbl, dict):
                continue
            shape = _as_dict(lbl.get("labelData")).get("shape3d")
            if not isinstance(shape, dict):
                continue
            boxes = _label_boxes(shape, kind, segment_len, lane_segment_len)
            for j, box in enumerate(boxes):
                if box is not None:
                    out.append(_build_object(*box, cls_name, i, j))
    return out