File size: 8,102 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""3D RoPE(仅作用于视觉 token)。

12 头按 4+4+4 拆为三组:
  - 头 0-3:射线 RoPE,编码自车系下的单位射线方向 ``(dx, dy, dz)``。
  - 头 4-7:H/W/T RoPE,编码归一化的空间-时间索引 ``(h_norm, w_norm, t_norm)``。
  - 头 8-11:零频段 RoPE,cos=1 / sin=0 → 旋转矩阵恒为 I(identity)。

为减少分支与显存通信,全部 12 头统一走同一份 RoPE 算子(不写 if/else),
零频段头自然变为恒等映射。

将 ``head_dim=64`` 切成 32 个 (cos, sin) 对(两两一组旋转)。每组头内部再按
3 个分量(dx,dy,dz 或 h,w,t)平均分配 32/3 ≈ 10 对(最后 2 对补 0 频)。
"""

from __future__ import annotations

import torch
import torch.nn as nn


def _split_head_dim_for_components(half: int, num_components: int) -> list[int]:
    """把 head_dim/2 个旋转对均匀分给若干个分量;剩余补 0 频。

    返回每个分量分到的旋转对数,最后一项是 ``half - sum(其它)``。
    若 ``num_components == 0``(零频段头),则返回 ``[0, 0, ..., half]``,最后
    一项视为"零频段"——它的频率会被显式置为 0。
    """
    if num_components == 0:
        return [0, half]
    base = half // num_components
    splits = [base] * num_components
    splits[-1] += half - base * num_components  # 余数全归到最后一个分量
    return splits


def build_rope_freqs(
    rays: torch.Tensor,
    hwt_grid: torch.Tensor,
    num_heads: int = 12,
    head_dim: int = 64,
    rope_theta: float = 10000.0,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """构造 3D RoPE 的 cos / sin 表。

    参数
    ----
    rays : Tensor, shape ``[B, N_v, 3]``
        每个视觉 token 在自车系下的单位射线方向 ``(dx, dy, dz)``。
    hwt_grid : Tensor, shape ``[B, N_v, 3]``
        归一化的空间-时间坐标 ``(h_norm, w_norm, t_norm)`` ∈ [-1, 1]。
    num_heads : int
        总头数(默认 12)。
    head_dim : int
        每头维度(默认 64,必须为偶数)。

    返回
    ----
    cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]``
        每个旋转对的 cos / sin 值,已就绪可送入 ``apply_rope``。
    """
    assert head_dim % 2 == 0, "head_dim 必须为偶数"
    assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组"

    half = head_dim // 2
    heads_per_group = num_heads // 3
    bsz, n_v, _ = rays.shape
    if device is None:
        device = rays.device

    # === 三组分量值 ===
    # group 0: rays (3 components)
    # group 1: hwt  (3 components)
    # group 2: zero (0 components -> 全部 half 视为零频段)
    splits_g0 = _split_head_dim_for_components(half, 3)  # 用于 rays
    splits_g1 = _split_head_dim_for_components(half, 3)  # 用于 hwt
    splits_g2 = _split_head_dim_for_components(half, 0)  # [0, half]

    # === 频率向量(沿 head_dim 半轴)===
    # 经典 RoPE: theta_i = base ^ (-2i / d)
    #   这里我们对每个分量独立排布频率
    def _freqs(num_pairs: int) -> torch.Tensor:
        # 前 num_pairs 个用 RoPE 频率,剩余补 0
        idx = torch.arange(num_pairs, device=device, dtype=dtype)
        freqs = rope_theta ** (-2.0 * idx / head_dim)
        return freqs  # [num_pairs]

    # 把分量值与频率张量逐头展开为 [B, N_v, num_heads, half]
    angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype)

    # ---- 第 0 组(4 头):射线 ----
    base_offset = 0
    h0_start = 0
    h0_end = h0_start + heads_per_group
    cursor = 0
    for c in range(3):  # dx, dy, dz
        n_pairs = splits_g0[c]
        if n_pairs > 0:
            f = _freqs(n_pairs)  # [n_pairs]
            comp_val = rays[..., c : c + 1]  # [B, N_v, 1]
            ang = comp_val * f  # 广播 -> [B, N_v, n_pairs]
            angles[:, :, h0_start:h0_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
        cursor += n_pairs
    # 余数(splits_g0 最后一项的"补足"部分由 _split 已并入最后分量),无需置 0

    # ---- 第 1 组(4 头):HWT ----
    h1_start = heads_per_group
    h1_end = h1_start + heads_per_group
    cursor = 0
    for c in range(3):  # h, w, t
        n_pairs = splits_g1[c]
        if n_pairs > 0:
            f = _freqs(n_pairs)
            comp_val = hwt_grid[..., c : c + 1]
            ang = comp_val * f
            angles[:, :, h1_start:h1_end, cursor : cursor + n_pairs] = ang.unsqueeze(2)
        cursor += n_pairs

    # ---- 第 2 组(4 头):零频段 ----
    # 角度恒为 0 → cos=1, sin=0 → 等价 identity;不需要再赋值(已是零)

    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return cos, sin


def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """对 ``q`` ``k`` 的视觉 token 部分应用 3D RoPE。

    所有 12 头一视同仁地走同一段代码(零频段头 cos=1/sin=0 → identity)。

    参数
    ----
    q, k : Tensor, shape ``[B, H, N_v, head_dim]``
    cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]``

    返回
    ----
    旋转后的 q, k,形状不变。
    """
    # 把 cos/sin 转成 [B, H, N_v, half]
    cos_e = cos.permute(0, 2, 1, 3)
    sin_e = sin.permute(0, 2, 1, 3)

    # 把 head_dim 维度按 (even, odd) 拆开成 [..., half]
    q_even = q[..., 0::2]
    q_odd = q[..., 1::2]
    k_even = k[..., 0::2]
    k_odd = k[..., 1::2]

    q_rot_even = q_even * cos_e - q_odd * sin_e
    q_rot_odd = q_even * sin_e + q_odd * cos_e
    k_rot_even = k_even * cos_e - k_odd * sin_e
    k_rot_odd = k_even * sin_e + k_odd * cos_e

    q_out = torch.empty_like(q)
    k_out = torch.empty_like(k)
    q_out[..., 0::2] = q_rot_even
    q_out[..., 1::2] = q_rot_odd
    k_out[..., 0::2] = k_rot_even
    k_out[..., 1::2] = k_rot_odd
    return q_out, k_out


class RoPE3D(nn.Module):
    """3D RoPE 工具模块:缓存 hwt_grid(视觉 token 网格上不变),动态计算 rays。

    使用方式:
        rope = RoPE3D(num_heads=12, head_dim=64, T=4, H=12, W=32)
        cos, sin = rope.compute_freqs(rays)            # rays: [B, N_v, 3]
        q, k = apply_rope(q_visual_only, k_visual_only, cos, sin)
    """

    def __init__(
        self,
        num_heads: int = 12,
        head_dim: int = 64,
        time_size: int = 4,
        height_size: int = 12,
        width_size: int = 32,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.T = time_size
        self.H = height_size
        self.W = width_size

        # 预计算并缓存归一化 H/W/T 网格 [N_v, 3],N_v = T*H*W
        t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1)
        h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1)
        w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1)
        # 顺序:t -> h -> w(与 Conv3D 输出展平顺序一致)
        T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij")
        hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3)  # [N_v, 3]
        self.register_buffer("hwt_grid", hwt, persistent=False)

    @property
    def num_visual_tokens(self) -> int:
        return self.T * self.H * self.W

    def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """根据每 token 的射线方向计算 cos/sin。

        ``rays`` shape: ``[B, N_v, 3]``。
        """
        bsz = rays.shape[0]
        hwt = self.hwt_grid.unsqueeze(0).expand(bsz, -1, -1)  # [B, N_v, 3]
        return build_rope_freqs(
            rays=rays,
            hwt_grid=hwt,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            rope_theta=self.rope_theta,
            dtype=rays.dtype,
        )