File size: 8,680 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""检测 / 自车未来轨迹的目标构建。

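Based on the Cosmos-Drive-Dreams dataset README:
    the all_object_info JSON is keyed by ``tracking_id`` and stores
    ``{object_to_world: 4x4, object_lwh: [l,w,h], object_is_moving: bool, object_type: str}``.

Construction steps:
1. Transform each object's ``object_to_world`` into the ego frame at time t:
       object_to_self = inv(vehicle_pose_t) @ object_to_world
2. Keep only objects with distance ``<= max_distance_m``;
3. Project the box centre to a pixel of the current frame and require it to fall
   inside the image;
4. Cull occluded objects with a coarse comparison against LIDAR depth;
5. For every kept object (the code does not restrict this to dynamic classes), look up
   its ``object_to_world`` in frames t+1..t+24, transform it into the ego frame at t and
   take (dx, dy, dyaw) relative to the object's own pose at t as the future-trajectory GT.
   dx and dy are symlog-normalised, dyaw is only wrapped to [-pi, pi); a missing frame
   gets ``valid=0``.

To line up with the head outputs, the final output format is:
    {"labels": [N], "boxes": [N, 7], "is_dynamic": [N],
     "future_traj": [N, 24, 3], "future_valid": [N, 24]}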
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import torch

from ..modules.normalization import symlog
from ..modules.rays import FThetaCamera
from .ftheta_proj import project_points_ftheta
from .se3 import invert_se3


@dataclass
class ObjectTrackInfo:
    """Simplified record of a single object in a single frame."""

    # Identifier used to find the same object again in other frames.
    tracking_id: str
    # Object-to-world transform; annotated in the original as a [4, 4] tensor.
    object_to_world: torch.Tensor
    # Box extents in (length, width, height) order; annotated as a [3] tensor.
    lwh: torch.Tensor
    # Motion flag carried over from the annotation record.
    is_moving: bool
    # Raw category string; mapped to a class index elsewhere in this module.
    object_type: str


def _yaw_from_rotation_matrix(R: torch.Tensor) -> torch.Tensor:
    """Return the yaw angle (rotation about the z axis) of a 3x3 rotation matrix.

    The angle is ``atan2(R[1, 0], R[0, 0])``. Only the last two dimensions are
    indexed, so any leading batch dimensions of ``R`` are kept in the result.
    """
    sin_like = R[..., 1, 0]
    cos_like = R[..., 0, 0]
    return torch.atan2(sin_like, cos_like)


def _make_class_index(object_type: str, dynamic_classes: list[str], structured_classes: list[str], background_idx: int = 0) -> tuple[int, int]:
    """Map an ``object_type`` string to ``(class_index, is_dynamic)``.

    Dynamic classes are numbered ``1..len(dynamic_classes)`` (index 0 is left
    for background) and structured classes continue right after them. A type
    found in neither list yields ``(background_idx, 0)``.
    """
    # Dynamic classes are checked first, so they win if a name is in both lists.
    for class_idx, name in enumerate(dynamic_classes, start=1):
        if name == object_type:
            return class_idx, 1

    first_structured_idx = len(dynamic_classes) + 1
    for class_idx, name in enumerate(structured_classes, start=first_structured_idx):
        if name == object_type:
            return class_idx, 0

    # Unknown types are treated as background.
    return background_idx, 0


def _is_lidar_occluded(
    center_self: torch.Tensor,
    lidar_points_self: torch.Tensor,
    depth_tolerance: float,
) -> bool:
    """Return True when LIDAR returns along the object's ray are clearly nearer than it."""
    center_norm = center_self.norm()
    ray = center_self / (center_norm + 1e-6)
    # Signed distance of every LIDAR point along the ray towards the object.
    proj = lidar_points_self @ ray  # [P]
    unit_points = lidar_points_self / (lidar_points_self.norm(dim=-1, keepdim=True) + 1e-6)
    cosang = unit_points @ ray
    # Keep points in front of the origin and inside a narrow cone around the ray
    # (0.996 is the cosine of about 5 degrees).
    in_cone = (cosang > 0.996) & (proj > 0)
    if not in_cone.any():
        return False
    nearest_depth = proj[in_cone].min().item()
    # A return that is nearer than the object by more than the tolerance means
    # something in front of it was hit.
    return nearest_depth + depth_tolerance < float(center_norm.item())


def _index_by_tracking_id(frame_objects: list[ObjectTrackInfo]) -> dict[str, ObjectTrackInfo]:
    """Index one frame's objects by ``tracking_id``, keeping the first entry per id."""
    by_id: dict[str, ObjectTrackInfo] = {}
    for candidate in frame_objects:
        by_id.setdefault(candidate.tracking_id, candidate)
    return by_id


def build_detection_targets(
    objects_t: list[ObjectTrackInfo],
    objects_future: list[list[ObjectTrackInfo]],   # len = future_horizon, one object list per frame
    vehicle_pose_t: torch.Tensor,                  # [4, 4], vehicle to world
    vehicle_pose_future: list[torch.Tensor],       # one 4x4 per frame
    cam_intrinsic: FThetaCamera,
    cam2vehicle: torch.Tensor,                     # [4, 4]
    image_h: int,
    image_w: int,
    max_distance_m: float = 48.0,
    occlusion_depth_tolerance: float = 0.5,
    lidar_points_self: torch.Tensor | None = None,  # [P, 3] in self frame, for coarse occlusion
    dynamic_classes: list[str] | None = None,
    structured_classes: list[str] | None = None,
    future_horizon: int = 24,
) -> dict:
    """Build the detection and future-trajectory labels for one sample.

    An object from ``objects_t`` is kept when all of the following hold:
    its type maps to a non-background class, its centre is within
    ``max_distance_m`` of the ego origin, the centre has positive z in the
    camera frame and projects inside the ``image_w`` x ``image_h`` image, and
    the coarse LIDAR test does not flag it as occluded.

    Args:
        objects_t: Objects annotated at the current frame t.
        objects_future: Per-frame object lists for the frames after t; entry
            ``k`` is searched for the same ``tracking_id``.
        vehicle_pose_t: Vehicle-to-world pose at t.
        vehicle_pose_future: Future vehicle poses. Only its length is used
            here: steps at or beyond it are marked invalid.
        cam_intrinsic: Camera model handed to ``project_points_ftheta``.
        cam2vehicle: Camera-to-vehicle transform.
        image_h: Image height in pixels used for the bounds check.
        image_w: Image width in pixels used for the bounds check.
        max_distance_m: Maximum centre distance from the ego origin.
        occlusion_depth_tolerance: Slack added to the nearest LIDAR depth
            before it is compared with the object distance.
        lidar_points_self: Optional LIDAR points in the ego frame; the
            occlusion test is skipped when it is ``None`` or empty.
        dynamic_classes: Class names labelled as dynamic (``None`` -> empty).
        structured_classes: Class names labelled as static (``None`` -> empty).
        future_horizon: Number of future steps per object.

    Returns:
        A dict with ``labels`` ``[N]`` (long), ``boxes`` ``[N, 7]`` (float32,
        symlog x/y/z, symlog l/w/h, raw yaw), ``is_dynamic`` ``[N]`` (long),
        ``future_traj`` ``[N, future_horizon, 3]`` (float32, symlog dx/dy and
        wrapped dyaw relative to the object's pose at t) and ``future_valid``
        ``[N, future_horizon]`` (bool). Future entries are produced for every
        kept object, dynamic or not.
    """
    if dynamic_classes is None:
        dynamic_classes = []
    if structured_classes is None:
        structured_classes = []

    inv_pose_t = invert_se3(vehicle_pose_t)
    vehicle2cam = invert_se3(cam2vehicle)
    check_occlusion = lidar_points_self is not None and lidar_points_self.numel() > 0

    # Steps beyond the shortest of the future inputs can never be valid.
    n_future = max(0, min(future_horizon, len(objects_future), len(vehicle_pose_future)))
    # Per-frame tracking_id lookups, filled on first use so nothing is built
    # when no object survives the filters.
    future_index: list = [None] * n_future

    labels: list[int] = []
    boxes: list[list[float]] = []
    is_dynamic: list[int] = []
    future_traj: list[list[list[float]]] = []
    future_valid: list[list[int]] = []

    for obj in objects_t:
        # Class mapping first: it is the cheapest filter and background objects
        # are dropped regardless of geometry.
        cls_idx, is_dyn = _make_class_index(obj.object_type, dynamic_classes, structured_classes)
        if cls_idx == 0:
            continue

        T_obj_self = inv_pose_t @ obj.object_to_world  # [4, 4]
        center_self = T_obj_self[:3, 3]

        dist = float(center_self.norm().item())
        if dist > max_distance_m:
            continue

        # Frustum check: move the centre into the camera frame, then to pixels.
        # new_ones keeps the homogeneous 1 on the same dtype/device as the centre.
        center_h = torch.cat([center_self, center_self.new_ones((1,))])
        center_cam = (vehicle2cam @ center_h)[:3]
        if center_cam[2].item() <= 0:
            continue
        uv, _depth = project_points_ftheta(center_cam.unsqueeze(0), cam_intrinsic)
        u, v = uv[0, 0].item(), uv[0, 1].item()
        if not (0 <= u < image_w and 0 <= v < image_h):
            continue

        if check_occlusion and _is_lidar_occluded(
            center_self, lidar_points_self, occlusion_depth_tolerance
        ):
            continue

        labels.append(cls_idx)
        is_dynamic.append(is_dyn)

        yaw = _yaw_from_rotation_matrix(T_obj_self[:3, :3]).item()
        length, width, height = obj.lwh.tolist()
        # Box centre and extents are symlog-normalised; yaw is stored as-is.
        boxes.append([
            float(symlog(center_self[0]).item()),
            float(symlog(center_self[1]).item()),
            float(symlog(center_self[2]).item()),
            float(symlog(torch.tensor(length)).item()),
            float(symlog(torch.tensor(width)).item()),
            float(symlog(torch.tensor(height)).item()),
            yaw,
        ])

        # Future track: (dx, dy, dyaw) in the ego frame at t, relative to the
        # object's own pose at t.
        x0, y0 = center_self[0].item(), center_self[1].item()
        track: list[list[float]] = []
        track_valid: list[int] = []
        for k in range(future_horizon):
            match = None
            if k < n_future:
                by_id = future_index[k]
                if by_id is None:
                    by_id = future_index[k] = _index_by_tracking_id(objects_future[k])
                # Entry k holds frame t+k+1.
                match = by_id.get(obj.tracking_id)
            if match is None:
                track.append([0.0, 0.0, 0.0])
                track_valid.append(0)
                continue
            # Reuse the inverse pose computed once above.
            T_obj_self_future = inv_pose_t @ match.object_to_world
            dx = T_obj_self_future[0, 3].item() - x0
            dy = T_obj_self_future[1, 3].item() - y0
            dyaw = _yaw_from_rotation_matrix(T_obj_self_future[:3, :3]).item() - yaw
            # Wrap the angle difference into [-pi, pi).
            dyaw = (dyaw + np.pi) % (2 * np.pi) - np.pi
            track.append([
                float(symlog(torch.tensor(dx)).item()),
                float(symlog(torch.tensor(dy)).item()),
                float(dyaw),
            ])
            track_valid.append(1)
        future_traj.append(track)
        future_valid.append(track_valid)

    if not labels:
        # torch.tensor([]) would lose the trailing dimensions, so the empty
        # result is spelled out with explicit shapes.
        return {
            "labels": torch.zeros(0, dtype=torch.long),
            "boxes": torch.zeros(0, 7),
            "is_dynamic": torch.zeros(0, dtype=torch.long),
            "future_traj": torch.zeros(0, future_horizon, 3),
            "future_valid": torch.zeros(0, future_horizon, dtype=torch.bool),
        }
    return {
        "labels": torch.tensor(labels, dtype=torch.long),
        "boxes": torch.tensor(boxes, dtype=torch.float32),
        "is_dynamic": torch.tensor(is_dynamic, dtype=torch.long),
        "future_traj": torch.tensor(future_traj, dtype=torch.float32),
        "future_valid": torch.tensor(future_valid, dtype=torch.bool),
    }


def build_ego_future_target(
    vehicle_pose_t: torch.Tensor,
    vehicle_pose_future: list[torch.Tensor],
    horizon: int = 24,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the ego vehicle's future trajectory in the ego frame at time t.

    Each available future pose is made relative to ``vehicle_pose_t``. Row ``k``
    of the first output holds ``(symlog(x), symlog(y), yaw)``; yaw is stored
    without symlog. Steps with no pose in ``vehicle_pose_future`` stay zero and
    are marked invalid.

    Returns:
        ``(traj, valid)`` where ``traj`` is a ``[horizon, 3]`` tensor and
        ``valid`` is a ``[horizon]`` bool tensor.
    """
    world_to_ego_t = invert_se3(vehicle_pose_t)
    traj = torch.zeros(horizon, 3)
    mask = torch.zeros(horizon, dtype=torch.bool)

    # Only the steps that actually have a future pose are filled in.
    n_steps = min(horizon, len(vehicle_pose_future))
    for step in range(n_steps):
        rel_pose = world_to_ego_t @ vehicle_pose_future[step]
        x_rel = rel_pose[0, 3].item()
        y_rel = rel_pose[1, 3].item()
        heading = _yaw_from_rotation_matrix(rel_pose[:3, :3]).item()
        traj[step, 0] = symlog(torch.tensor(x_rel))
        traj[step, 1] = symlog(torch.tensor(y_rel))
        traj[step, 2] = heading
        mask[step] = True
    return traj, mask