"""Target construction for object detection and future trajectories.

Per the Cosmos-Drive-Dreams dataset README, the ``all_object_info`` JSON is
keyed by ``tracking_id`` and stores
``{object_to_world: 4x4, object_lwh: [l, w, h], object_is_moving: bool,
object_type: str}``.

Construction steps:
    1. Bring every object's ``object_to_world`` into the ego ("self") frame at
       time t: ``object_to_self = inv(vehicle_pose_t) @ object_to_world``.
    2. Drop objects whose centre is farther than ``max_distance_m``.
    3. Project the centre into the current frame and require the pixel to lie
       inside the image.
    4. Coarse occlusion culling by comparison against LIDAR depth.
    5. For frames t+1..t+24, look the object up by ``tracking_id``, express its
       pose in the ego frame at t and take ``(dx, dy, dyaw)`` relative to its
       pose at t as the future-trajectory ground truth. ``dx`` / ``dy`` are
       symlog-normalised, ``dyaw`` is a wrapped raw angle; frames where the
       object is missing get ``valid = 0``.
       NOTE(review): the original design note scoped this step to dynamic
       objects, but the implementation computes it for every kept object;
       ``is_dynamic`` is returned alongside so callers can tell them apart.

To line up with the head outputs, the final format is::

    {"labels": [N], "boxes": [N, 7], "is_dynamic": [N],
     "future_traj": [N, 24, 3], "future_valid": [N, 24]}
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import torch

from ..modules.normalization import symlog
from ..modules.rays import FThetaCamera
from .ftheta_proj import project_points_ftheta
from .se3 import invert_se3

# Minimum cosine between a LIDAR return and the object ray for the return to
# be considered "on the ray" (cos(5 deg) ~= 0.9962).
_OCCLUSION_MIN_COS = 0.996


@dataclass
class ObjectTrackInfo:
    """Simplified record of a single object in one frame."""

    tracking_id: str
    object_to_world: torch.Tensor  # [4, 4]
    lwh: torch.Tensor  # [3]
    is_moving: bool
    object_type: str


def _yaw_from_rotation_matrix(R: torch.Tensor) -> torch.Tensor:
    """Extract the yaw (rotation about the z axis) from a 3x3 rotation matrix.

    Computed as ``atan2(R[1, 0], R[0, 0])``; leading batch dimensions are
    preserved.
    """
    return torch.atan2(R[..., 1, 0], R[..., 0, 0])


def _make_class_index(
    object_type: str,
    dynamic_classes: list[str],
    structured_classes: list[str],
    background_idx: int = 0,
) -> tuple[int, int]:
    """Map an ``object_type`` string to ``(class_index, is_dynamic)``.

    Dynamic classes occupy indices ``1..len(dynamic_classes)``, structured
    classes follow them, and any unknown type maps to ``background_idx``.
    A type present in both lists is treated as dynamic.
    """
    if object_type in dynamic_classes:
        # +1 keeps index 0 free for background.
        return dynamic_classes.index(object_type) + 1, 1
    if object_type in structured_classes:
        return len(dynamic_classes) + structured_classes.index(object_type) + 1, 0
    return background_idx, 0  # unknown types are treated as background


def _symlog_scalar(value: float) -> float:
    """Apply ``symlog`` to a Python scalar and return a Python float."""
    return float(symlog(torch.tensor(value)).item())


def _wrap_angle(angle: float) -> float:
    """Wrap an angle in radians into ``[-pi, pi)``."""
    return (angle + np.pi) % (2 * np.pi) - np.pi


def _is_lidar_occluded(
    center_self: torch.Tensor,
    dist: float,
    lidar_points_self: torch.Tensor,
    depth_tolerance: float,
) -> bool:
    """Return True when a LIDAR return along the object ray is clearly nearer.

    Only returns within a narrow cone around the ray from the ego origin to
    the object centre (cosine above ``_OCCLUSION_MIN_COS``) and in front of
    the origin are considered. The object counts as occluded when the nearest
    such return, measured along the ray, is closer than ``dist`` by more than
    ``depth_tolerance``.

    NOTE(review): ``dist`` is the distance to the box *centre*, so returns
    from the object's own near surface that are more than ``depth_tolerance``
    in front of the centre will also trigger this test — confirm whether the
    tolerance should account for the object's extent.
    """
    ray = center_self / (center_self.norm() + 1e-6)
    along_ray = lidar_points_self @ ray  # [P]
    unit_points = lidar_points_self / (
        lidar_points_self.norm(dim=-1, keepdim=True) + 1e-6
    )
    cosang = unit_points @ ray
    in_cone = (cosang > _OCCLUSION_MIN_COS) & (along_ray > 0)
    if not in_cone.any():
        return False
    nearest_depth = along_ray[in_cone].min().item()
    return nearest_depth + depth_tolerance < dist


def _index_by_tracking_id(
    frame_objects: list[ObjectTrackInfo],
) -> dict[str, ObjectTrackInfo]:
    """Index one frame's objects by ``tracking_id`` (first occurrence wins)."""
    index: dict[str, ObjectTrackInfo] = {}
    for candidate in frame_objects:
        # setdefault keeps the first match, mirroring a front-to-back scan.
        index.setdefault(candidate.tracking_id, candidate)
    return index


def _future_track_targets(
    tracking_id: str,
    x0: float,
    y0: float,
    yaw0: float,
    inv_pose_t: torch.Tensor,
    future_lookup: list[dict[str, ObjectTrackInfo]],
    future_horizon: int,
) -> tuple[list[list[float]], list[int]]:
    """Build per-frame ``(dx, dy, dyaw)`` targets and validity flags for one object.

    Offsets are measured in the ego frame at time t, relative to the object's
    own pose ``(x0, y0, yaw0)`` at t. ``dx`` / ``dy`` are symlog-normalised and
    ``dyaw`` is wrapped but otherwise raw. Frames beyond ``future_lookup`` or
    in which the object is absent yield zeros with ``valid = 0``.
    """
    traj: list[list[float]] = []
    valid: list[int] = []
    for k in range(future_horizon):
        match = future_lookup[k].get(tracking_id) if k < len(future_lookup) else None
        if match is None:
            traj.append([0.0, 0.0, 0.0])
            valid.append(0)
            continue
        T_obj_self_future = inv_pose_t @ match.object_to_world
        dx = T_obj_self_future[0, 3].item() - x0
        dy = T_obj_self_future[1, 3].item() - y0
        yaw_future = _yaw_from_rotation_matrix(T_obj_self_future[:3, :3]).item()
        traj.append(
            [
                _symlog_scalar(dx),
                _symlog_scalar(dy),
                float(_wrap_angle(yaw_future - yaw0)),
            ]
        )
        valid.append(1)
    return traj, valid


def build_detection_targets(
    objects_t: list[ObjectTrackInfo],
    objects_future: list[list[ObjectTrackInfo]],  # one object list per future frame
    vehicle_pose_t: torch.Tensor,  # [4, 4], vehicle to world
    vehicle_pose_future: list[torch.Tensor],  # one 4x4 per future frame
    cam_intrinsic: FThetaCamera,
    cam2vehicle: torch.Tensor,  # [4, 4]
    image_h: int,
    image_w: int,
    max_distance_m: float = 48.0,
    occlusion_depth_tolerance: float = 0.5,
    lidar_points_self: torch.Tensor | None = None,  # [P, 3] in the self frame
    dynamic_classes: list[str] | None = None,
    structured_classes: list[str] | None = None,
    future_horizon: int = 24,
) -> dict:
    """Build the detection and future-trajectory labels for one sample.

    An object from ``objects_t`` is kept only if it passes every filter:
    its centre is within ``max_distance_m`` of the ego origin, its type maps
    to a non-background class, its centre has positive camera-frame z and
    projects inside the ``image_w`` x ``image_h`` image, and (when LIDAR
    points are given) it is not flagged as occluded.

    Args:
        objects_t: Objects present in the current frame.
        objects_future: Per-future-frame object lists, index ``k`` being
            frame ``t + k + 1``.
        vehicle_pose_t: Vehicle-to-world pose at time t.
        vehicle_pose_future: Future vehicle poses. Only the length is used
            here: frames at or beyond it are marked invalid. Future object
            poses are expressed in the ego frame at t, so the poses
            themselves are not needed.
        cam_intrinsic: Camera model passed to ``project_points_ftheta``.
        cam2vehicle: Camera-to-vehicle extrinsic.
        image_h: Image height in pixels.
        image_w: Image width in pixels.
        max_distance_m: Maximum centre distance for an object to be kept.
        occlusion_depth_tolerance: Slack applied in the LIDAR occlusion test.
        lidar_points_self: Optional LIDAR points in the ego frame; the
            occlusion test is skipped when this is None or empty.
        dynamic_classes: Type names mapped to dynamic class indices.
        structured_classes: Type names mapped to static class indices.
        future_horizon: Number of future frames in the trajectory target.

    Returns:
        A dict with ``labels`` [N] (long), ``boxes`` [N, 7] (float32:
        symlog x, y, z, symlog l, w, h, raw yaw), ``is_dynamic`` [N] (long),
        ``future_traj`` [N, future_horizon, 3] (float32) and ``future_valid``
        [N, future_horizon] (bool). All have N = 0 when nothing is kept.
    """
    dynamic_classes = [] if dynamic_classes is None else dynamic_classes
    structured_classes = [] if structured_classes is None else structured_classes

    inv_pose_t = invert_se3(vehicle_pose_t)
    vehicle2cam = invert_se3(cam2vehicle)
    use_lidar = lidar_points_self is not None and lidar_points_self.numel() > 0

    # Future frames are usable only where both an object list and a pose exist.
    usable_frames = max(
        0, min(future_horizon, len(objects_future), len(vehicle_pose_future))
    )
    # Built on first use so the future frames are untouched when nothing is kept.
    future_lookup: list[dict[str, ObjectTrackInfo]] | None = None

    labels: list[int] = []
    boxes: list[list[float]] = []
    is_dynamic: list[int] = []
    future_traj: list[list[list[float]]] = []
    future_valid: list[list[int]] = []

    for obj in objects_t:
        T_obj_self = inv_pose_t @ obj.object_to_world  # [4, 4]
        center_self = T_obj_self[:3, 3]
        dist = float(center_self.norm().item())
        if dist > max_distance_m:
            continue

        # Class mapping is cheap, so reject background before projecting.
        cls_idx, is_dyn = _make_class_index(
            obj.object_type, dynamic_classes, structured_classes
        )
        if cls_idx == 0:
            continue

        # Frustum test: centre -> camera frame -> pixel. new_ones keeps the
        # homogeneous coordinate on the same device/dtype as the centre.
        center_h = torch.cat([center_self, center_self.new_ones(1)])
        center_cam = (vehicle2cam @ center_h.unsqueeze(-1)).squeeze(-1)[:3]
        if center_cam[2].item() <= 0:
            continue
        uv, _depth = project_points_ftheta(center_cam.unsqueeze(0), cam_intrinsic)
        u, v = uv[0, 0].item(), uv[0, 1].item()
        if not (0 <= u < image_w and 0 <= v < image_h):
            continue

        if use_lidar and _is_lidar_occluded(
            center_self, dist, lidar_points_self, occlusion_depth_tolerance
        ):
            continue

        labels.append(cls_idx)
        is_dynamic.append(is_dyn)

        yaw = _yaw_from_rotation_matrix(T_obj_self[:3, :3]).item()
        length, width, height = obj.lwh.tolist()
        # Box position and size are symlog-normalised; yaw is kept raw.
        x_n, y_n, z_n = (float(symlog(coord).item()) for coord in center_self)
        boxes.append(
            [
                x_n,
                y_n,
                z_n,
                _symlog_scalar(length),
                _symlog_scalar(width),
                _symlog_scalar(height),
                yaw,
            ]
        )

        if future_lookup is None:
            future_lookup = [
                _index_by_tracking_id(frame_objects)
                for frame_objects in objects_future[:usable_frames]
            ]
        traj, valid = _future_track_targets(
            obj.tracking_id,
            center_self[0].item(),
            center_self[1].item(),
            yaw,
            inv_pose_t,
            future_lookup,
            future_horizon,
        )
        future_traj.append(traj)
        future_valid.append(valid)

    if not labels:
        # torch.tensor([]) would lose the trailing dimensions, so spell the
        # empty shapes out explicitly.
        return {
            "labels": torch.zeros(0, dtype=torch.long),
            "boxes": torch.zeros(0, 7),
            "is_dynamic": torch.zeros(0, dtype=torch.long),
            "future_traj": torch.zeros(0, future_horizon, 3),
            "future_valid": torch.zeros(0, future_horizon, dtype=torch.bool),
        }

    return {
        "labels": torch.tensor(labels, dtype=torch.long),
        "boxes": torch.tensor(boxes, dtype=torch.float32),
        "is_dynamic": torch.tensor(is_dynamic, dtype=torch.long),
        "future_traj": torch.tensor(future_traj, dtype=torch.float32),
        "future_valid": torch.tensor(future_valid, dtype=torch.bool),
    }


def build_ego_future_target(
    vehicle_pose_t: torch.Tensor,
    vehicle_pose_future: list[torch.Tensor],
    horizon: int = 24,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the ego vehicle's future trajectory in the ego frame at time t.

    Args:
        vehicle_pose_t: Vehicle-to-world pose at time t, [4, 4].
        vehicle_pose_future: Future vehicle-to-world poses, index ``k`` being
            frame ``t + k + 1``.
        horizon: Number of future frames in the target.

    Returns:
        ``(traj, valid)`` where ``traj`` is [horizon, 3] holding
        ``(symlog(x), symlog(y), yaw)`` — yaw is stored raw, not
        symlog-normalised — and ``valid`` is a [horizon] bool mask that is
        True for the frames available in ``vehicle_pose_future``. Rows for
        unavailable frames stay zero.
    """
    inv_t = invert_se3(vehicle_pose_t)
    out = torch.zeros(horizon, 3)
    valid = torch.zeros(horizon, dtype=torch.bool)
    for k, pose_k in enumerate(vehicle_pose_future[:horizon]):
        rel = inv_t @ pose_k
        x, y = rel[0, 3].item(), rel[1, 3].item()
        out[k, 0] = symlog(torch.tensor(x))
        out[k, 1] = symlog(torch.tensor(y))
        out[k, 2] = _yaw_from_rotation_matrix(rel[:3, :3]).item()
        valid[k] = True
    return out, valid