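# WJAD/src/wjad/data/targets.py
# Repository page residue: uploaded by fuzirui, commit 0cfefd2 ("Sync WJAD codebase", verified).
"""Target construction for object detection and ego/object future trajectories.

Based on the Cosmos-Drive-Dreams dataset README: the all_object_info JSON is
keyed by ``tracking_id`` and each entry stores
``{object_to_world: 4x4, object_lwh: [l,w,h], object_is_moving: bool, object_type: str}``.

Construction steps:
    1. Transform each object's ``object_to_world`` into the ego frame at time t:
           object_to_self = inv(vehicle_pose_t) @ object_to_world
    2. Keep only objects whose distance is ``<= max_distance_m``.
    3. Project the object centre to pixels of the current frame and require it
       to fall inside the view frustum.
    4. Cull occluded objects by comparing against LIDAR depth (coarse).
    5. For dynamic objects, fetch ``object_to_world`` frame by frame over
       t+1..t+24, transform it into the ego frame at time t, and extract
       (dx, dy, dyaw) as the future-trajectory ground truth; dx and dy are
       symlog-normalised, dyaw is only wrapped to one turn. Missing frames get
       ``valid=0``. (As written, the code computes this for every kept object,
       not only dynamic ones.)

To line up with the head outputs, the final output format is:
    {"labels": [N], "boxes": [N, 7], "is_dynamic": [N],
     "future_traj": [N, 24, 3], "future_valid": [N, 24]}
"""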
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import torch
from ..modules.normalization import symlog
from ..modules.rays import FThetaCamera
from .ftheta_proj import project_points_ftheta
from .se3 import invert_se3
@dataclass
class ObjectTrackInfo:
    """Simplified record of a single tracked object in one frame.

    The field comments below restate the shape hints carried by the original
    inline comments; nothing in this class enforces them.
    """

    # Identifier used to match the same object across frames.
    tracking_id: str
    # Object-to-world transform, annotated as [4, 4].
    object_to_world: torch.Tensor
    # Box size, annotated as [3]; unpacked elsewhere in this file as (l, w, h).
    lwh: torch.Tensor
    # Whether the dataset flags the object as moving.
    is_moving: bool
    # Class name string, mapped to a class index by _make_class_index.
    object_type: str
def _yaw_from_rotation_matrix(R: torch.Tensor) -> torch.Tensor:
"""从 3x3 旋转矩阵提取自车系下绕 z 轴的 yaw 角。
使用 ``atan2(R[1,0], R[0,0])``。
"""
return torch.atan2(R[..., 1, 0], R[..., 0, 0])
def _make_class_index(object_type: str, dynamic_classes: list[str], structured_classes: list[str], background_idx: int = 0) -> tuple[int, int]:
"""根据 object_type 字符串映射到 (class_index, is_dynamic)。"""
if object_type in dynamic_classes:
return dynamic_classes.index(object_type) + 1, 1 # +1 为 background 留 idx 0
if object_type in structured_classes:
return len(dynamic_classes) + structured_classes.index(object_type) + 1, 0
return background_idx, 0 # 未知类型当 background
def _index_by_tracking_id(frame_objects: list[ObjectTrackInfo]) -> dict[str, ObjectTrackInfo]:
    """Map ``tracking_id`` to its record for one frame, keeping the first record on duplicates."""
    index: dict[str, ObjectTrackInfo] = {}
    for record in frame_objects:
        index.setdefault(record.tracking_id, record)
    return index


def _lidar_occludes(
    center_self: torch.Tensor,
    dist: float,
    lidar_points_self: torch.Tensor,
    depth_tolerance: float,
    cone_cos: float = 0.996,
) -> bool:
    """Return True when a LIDAR return inside a narrow cone around the object ray is clearly nearer."""
    ray = center_self / (center_self.norm() + 1e-6)
    along_ray = lidar_points_self @ ray  # [P] signed distance of each point along the ray
    unit_points = lidar_points_self / (lidar_points_self.norm(dim=-1, keepdim=True) + 1e-6)
    # cos > 0.996 keeps points within roughly 5 degrees of the ray direction.
    in_cone = ((unit_points @ ray) > cone_cos) & (along_ray > 0)
    if not in_cone.any():
        return False
    nearest = along_ray[in_cone].min().item()
    # NOTE(review): the comparison is against the distance to the object *centre*,
    # so returns from the object's own near surface may count as occluders when
    # the object is deeper than the tolerance. Confirm this is intended.
    return nearest + depth_tolerance < dist


def build_detection_targets(
    objects_t: list[ObjectTrackInfo],
    objects_future: list[list[ObjectTrackInfo]],  # one object list per future frame
    vehicle_pose_t: torch.Tensor,  # [4, 4], vehicle to world
    vehicle_pose_future: list[torch.Tensor],  # one 4x4 per future frame
    cam_intrinsic: FThetaCamera,
    cam2vehicle: torch.Tensor,  # [4, 4]
    image_h: int,
    image_w: int,
    max_distance_m: float = 48.0,
    occlusion_depth_tolerance: float = 0.5,
    lidar_points_self: torch.Tensor | None = None,  # [P, 3] in the ego frame, for coarse occlusion
    dynamic_classes: list[str] | None = None,
    structured_classes: list[str] | None = None,
    future_horizon: int = 24,
) -> dict:
    """Build detection and future-trajectory labels for one sample.

    An object from ``objects_t`` is kept when all of the following hold:
    its type is listed in ``dynamic_classes`` or ``structured_classes``; its
    centre lies within ``max_distance_m`` of the ego origin; its centre has a
    positive camera-frame z and projects inside the ``image_w`` x ``image_h``
    image; and, when LIDAR points are given, no LIDAR return in a narrow cone
    around the object ray is nearer than the centre by more than
    ``occlusion_depth_tolerance``.

    Args:
        objects_t: Object records for the current frame t.
        objects_future: ``objects_future[k]`` holds the records of frame t+k+1.
        vehicle_pose_t: Vehicle-to-world pose at frame t.
        vehicle_pose_future: Future vehicle poses. Only the list length is
            used here: a step with no pose is marked invalid.
        cam_intrinsic: Camera model passed to ``project_points_ftheta``.
        cam2vehicle: Camera-to-vehicle extrinsic.
        image_h: Image height in pixels.
        image_w: Image width in pixels.
        max_distance_m: Maximum centre distance for a kept object.
        occlusion_depth_tolerance: Slack applied to the LIDAR depth comparison.
        lidar_points_self: Optional LIDAR points in the ego frame; ``None`` or
            an empty tensor disables the occlusion test.
        dynamic_classes: Type names mapped to dynamic class indices.
        structured_classes: Type names mapped to static class indices.
        future_horizon: Number of future steps per object.

    Returns:
        A dict of tensors for the N kept objects:
        ``labels`` [N] long, ``boxes`` [N, 7] float32 holding
        ``(symlog x, y, z, symlog l, w, h, yaw)``, ``is_dynamic`` [N] long,
        ``future_traj`` [N, H, 3] float32 holding
        ``(symlog dx, symlog dy, dyaw)`` relative to the object's own pose at
        t, and ``future_valid`` [N, H] bool, where H is ``future_horizon``.
    """
    if dynamic_classes is None:
        dynamic_classes = []
    if structured_classes is None:
        structured_classes = []
    horizon = max(future_horizon, 0)

    inv_pose_t = invert_se3(vehicle_pose_t)
    vehicle2cam = invert_se3(cam2vehicle)
    # Affine form of the homogeneous transform. It avoids concatenating a
    # default-dtype CPU "1" onto the centre, which fails for other devices/dtypes.
    cam_rot, cam_trans = vehicle2cam[:3, :3], vehicle2cam[:3, 3]
    use_lidar = lidar_points_self is not None and lidar_points_self.numel() > 0

    # A future step is usable only if both an object list and an ego pose
    # exist for it. Index each usable frame once instead of scanning it for
    # every object.
    usable_steps = min(horizon, len(objects_future), len(vehicle_pose_future))
    future_index = [_index_by_tracking_id(objects_future[k]) for k in range(usable_steps)]

    labels: list[int] = []
    boxes: list[list[float]] = []
    is_dynamic: list[int] = []
    future_traj: list[list[list[float]]] = []
    future_valid: list[list[int]] = []

    for obj in objects_t:
        # Cheapest filter first: unknown types map to background (index 0) and are dropped.
        cls_idx, is_dyn = _make_class_index(obj.object_type, dynamic_classes, structured_classes)
        if cls_idx == 0:
            continue

        T_obj_self = inv_pose_t @ obj.object_to_world  # [4, 4]
        center_self = T_obj_self[:3, 3]
        dist = float(center_self.norm().item())
        if dist > max_distance_m:
            continue

        # Frustum test: centre into the camera frame, then onto the image plane.
        center_cam = cam_rot @ center_self + cam_trans
        if center_cam[2].item() <= 0:
            continue
        uv, _ = project_points_ftheta(center_cam.unsqueeze(0), cam_intrinsic)
        u, v = uv[0, 0].item(), uv[0, 1].item()
        if not (0 <= u < image_w and 0 <= v < image_h):
            continue

        if use_lidar and _lidar_occludes(
            center_self, dist, lidar_points_self, occlusion_depth_tolerance
        ):
            continue

        labels.append(cls_idx)
        is_dynamic.append(is_dyn)

        yaw = _yaw_from_rotation_matrix(T_obj_self[:3, :3]).item()
        # Box position and size are symlog-normalised; yaw is stored raw.
        center_n = [float(symlog(c).item()) for c in center_self]
        size_n = [float(symlog(torch.tensor(s)).item()) for s in obj.lwh.tolist()]
        boxes.append([*center_n, *size_n, yaw])

        # Future track in the ego frame at t, relative to the object's own pose at t.
        x0, y0 = center_self[0].item(), center_self[1].item()
        track = [[0.0, 0.0, 0.0] for _ in range(horizon)]
        track_valid = [0] * horizon
        for k, frame_index in enumerate(future_index):
            match = frame_index.get(obj.tracking_id)
            if match is None:
                continue  # object absent in frame t+k+1 -> stays zero / invalid
            T_future = inv_pose_t @ match.object_to_world
            dx = T_future[0, 3].item() - x0
            dy = T_future[1, 3].item() - y0
            dyaw = _yaw_from_rotation_matrix(T_future[:3, :3]).item() - yaw
            # Wrap the heading change into [-pi, pi).
            dyaw = (dyaw + np.pi) % (2 * np.pi) - np.pi
            track[k] = [
                float(symlog(torch.tensor(dx)).item()),
                float(symlog(torch.tensor(dy)).item()),
                float(dyaw),
            ]
            track_valid[k] = 1
        future_traj.append(track)
        future_valid.append(track_valid)

    # Explicit reshapes keep the trailing dimensions when N == 0 or horizon == 0.
    n = len(labels)
    return {
        "labels": torch.tensor(labels, dtype=torch.long),
        "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(n, 7),
        "is_dynamic": torch.tensor(is_dynamic, dtype=torch.long),
        "future_traj": torch.tensor(future_traj, dtype=torch.float32).reshape(n, horizon, 3),
        "future_valid": torch.tensor(future_valid, dtype=torch.bool).reshape(n, horizon),
    }
def build_ego_future_target(
    vehicle_pose_t: torch.Tensor,
    vehicle_pose_future: list[torch.Tensor],
    horizon: int = 24,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the ego future trajectory expressed in the ego frame at time t.

    Each available future pose is mapped through ``inv(vehicle_pose_t)``.
    Column 0 and 1 of the result hold the symlog-normalised x and y; column 2
    holds the raw yaw, which is not symlog-normalised.

    Returns:
        ``(traj, valid)``. ``traj`` has shape ``[horizon, 3]`` and ``valid`` is
        a bool tensor of shape ``[horizon]``. Steps beyond
        ``len(vehicle_pose_future)`` stay zero and invalid.
    """
    world_to_ego_t = invert_se3(vehicle_pose_t)
    traj = torch.zeros(horizon, 3)
    mask = torch.zeros(horizon, dtype=torch.bool)

    # Only the steps that actually have a pose are filled in.
    available = min(horizon, len(vehicle_pose_future))
    for step in range(available):
        rel_pose = world_to_ego_t @ vehicle_pose_future[step]
        heading = _yaw_from_rotation_matrix(rel_pose[:3, :3]).item()
        traj[step, 0] = symlog(torch.tensor(rel_pose[0, 3].item()))
        traj[step, 1] = symlog(torch.tensor(rel_pose[1, 3].item()))
        traj[step, 2] = heading
        mask[step] = True
    return traj, mask