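"""3D RoPE (applied to visual tokens only).

The 12 heads are split 4+4+4 into three groups:
- Heads 0-3: ray RoPE, encoding the unit ray direction ``(dx, dy, dz)`` in the
  ego-vehicle frame.
- Heads 4-7: H/W/T RoPE, encoding the normalized space-time index
  ``(h_norm, w_norm, t_norm)``.
- Heads 8-11: zero-frequency RoPE, cos=1 / sin=0, so the rotation matrix is
  always I (identity).

To reduce branching and GPU memory traffic, all 12 heads go through the same
RoPE operator (no if/else); the zero-frequency heads simply become an identity
mapping.

``head_dim=64`` is cut into 32 (cos, sin) pairs (dimensions are rotated two at
a time). Inside each of the first two head groups the pairs are divided among
the 3 components (dx, dy, dz or h, w, t): 32 // 3 = 10 pairs each, and the 2
leftover pairs are added to the last component (10 + 10 + 12). As implemented,
the leftover pairs are therefore NOT zero-frequency.
"""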
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
def _split_head_dim_for_components(half: int, num_components: int) -> list[int]:
    """Distribute ``half`` rotation pairs across ``num_components`` components.

    Every component receives ``half // num_components`` pairs; whatever is left
    over is added to the last component, so the returned counts always sum to
    ``half`` (e.g. ``half=32, num_components=3`` gives ``[10, 10, 12]``).

    ``num_components == 0`` is special-cased and returns ``[0, half]``: no
    pairs for a real component and all ``half`` pairs in the trailing entry,
    which is meant to stand for the zero-frequency band.
    """
    if num_components == 0:
        return [0, half]
    per_component, leftover = divmod(half, num_components)
    counts = [per_component] * num_components
    # The remainder is not spread out; it all lands on the final component.
    counts[-1] += leftover
    return counts
|
|
|
|
def build_rope_freqs(
    rays: torch.Tensor,
    hwt_grid: torch.Tensor,
    num_heads: int = 12,
    head_dim: int = 64,
    rope_theta: float = 10000.0,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the cos / sin tables of the 3D RoPE.

    The heads are split into three equal groups. The first group rotates by
    angles derived from ``rays``, the second by angles derived from
    ``hwt_grid``, and the third keeps an angle of zero (cos=1, sin=0), which
    makes its rotation the identity.

    Within the first two groups the ``head_dim // 2`` rotation pairs are
    divided among the three coordinate components according to
    ``_split_head_dim_for_components(half, 3)``. Each component uses its own
    frequency ladder ``rope_theta ** (-2 * i / head_dim)`` that restarts at
    ``i = 0``, and the angle of a pair is ``coordinate * frequency``.

    Parameters
    ----------
    rays : Tensor, shape ``[B, N_v, 3]``
        Unit ray direction ``(dx, dy, dz)`` of every visual token in the
        ego-vehicle frame.
    hwt_grid : Tensor, shape ``[B, N_v, 3]``
        Normalized space-time coordinates ``(h_norm, w_norm, t_norm)``,
        expected to lie in [-1, 1].
    num_heads : int
        Total number of heads (default 12); must be divisible by 3.
    head_dim : int
        Per-head dimension (default 64); must be even.
    rope_theta : float
        Base of the frequency ladder.
    device : torch.device or None
        Device of the returned tables. Defaults to ``rays.device``. Inputs
        living on another device are moved to it.
    dtype : torch.dtype
        dtype of the frequency ladder and of the returned tables.

    Returns
    -------
    cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]``
        cos / sin of every rotation pair, ready to be passed to ``apply_rope``.

    Raises
    ------
    AssertionError
        If ``head_dim`` is odd or ``num_heads`` is not divisible by 3.
    """
    assert head_dim % 2 == 0, "head_dim 必须为偶数"
    assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组"

    half = head_dim // 2
    heads_per_group = num_heads // 3
    bsz, n_v, _ = rays.shape
    if device is None:
        device = rays.device
    # The frequency ladder is created on `device`; bring both inputs along so an
    # explicit `device` different from the inputs' device does not fail at the
    # multiplication below. `.to` is a no-op when the device already matches.
    rays = rays.to(device)
    hwt_grid = hwt_grid.to(device)

    def _freqs(num_pairs: int) -> torch.Tensor:
        # Frequency ladder for one component, restarting at exponent 0.
        idx = torch.arange(num_pairs, device=device, dtype=dtype)
        return rope_theta ** (-2.0 * idx / head_dim)

    # Both coordinate-driven groups share the same layout, so the per-pair
    # frequency and the component each pair reads from are built only once.
    splits = _split_head_dim_for_components(half, 3)
    pair_freqs = torch.cat([_freqs(n) for n in splits])
    pair_component = torch.tensor(
        [c for c, n in enumerate(splits) for _ in range(n)],
        device=device,
        dtype=torch.long,
    )

    # Heads that are never written (the third group) keep angle 0.
    angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype)
    for group, coords in enumerate((rays, hwt_grid)):
        start = group * heads_per_group
        # [B, N_v, half]: the coordinate selected for each pair times its frequency.
        group_angles = coords[..., pair_component] * pair_freqs
        # Every head of the group shares the same angles (broadcast over heads).
        angles[:, :, start : start + heads_per_group] = group_angles.unsqueeze(2)

    return torch.cos(angles), torch.sin(angles)
|
|
|
|
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate the visual-token part of ``q`` and ``k`` with the 3D RoPE tables.

    All heads run through the same code path; heads whose table holds
    cos=1 / sin=0 come out unchanged.

    Parameters
    ----------
    q, k : Tensor, shape ``[B, H, N_v, head_dim]``
    cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]``

    Returns
    -------
    The rotated ``q`` and ``k``, each with the shape and dtype of its input.
    """
    # Reorder the tables to the [B, H, N_v, head_dim // 2] layout of q / k.
    cos_bh = cos.permute(0, 2, 1, 3)
    sin_bh = sin.permute(0, 2, 1, 3)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Adjacent channels (2i, 2i + 1) form one 2-D pair that is rotated by
        # the angle stored in slot i of the tables.
        first = x[..., 0::2]
        second = x[..., 1::2]
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = first * cos_bh - second * sin_bh
        rotated[..., 1::2] = first * sin_bh + second * cos_bh
        return rotated

    return _rotate(q), _rotate(k)
|
|
|
|
class RoPE3D(nn.Module):
    """3D RoPE helper: caches the static ``hwt_grid`` and builds cos/sin per call.

    The normalized (h, w, t) grid only depends on the token layout, so it is
    built once in ``__init__``; the ray directions are supplied on every call
    to ``compute_freqs``.

    Usage:
        rope = RoPE3D(num_heads=12, head_dim=64,
                      time_size=4, height_size=12, width_size=32)
        cos, sin = rope.compute_freqs(rays)  # rays: [B, N_v, 3]
        q, k = apply_rope(q_visual_only, k_visual_only, cos, sin)
    """

    def __init__(
        self,
        num_heads: int = 12,
        head_dim: int = 64,
        time_size: int = 4,
        height_size: int = 12,
        width_size: int = 32,
        rope_theta: float = 10000.0,
    ) -> None:
        """Store the configuration and precompute the normalized token grid.

        Parameters
        ----------
        num_heads, head_dim, rope_theta
            Forwarded unchanged to ``build_rope_freqs`` by ``compute_freqs``.
        time_size, height_size, width_size
            Number of grid positions along the T / H / W axes.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.T = time_size
        self.H = height_size
        self.W = width_size

        # Each axis is spread evenly over [-1, 1]; an axis of size 1 is placed at 0.
        t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1)
        h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1)
        w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1)

        # Grid of shape [T, H, W]; after flattening, t varies slowest and w fastest.
        T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij")
        # Components are stored in (h, w, t) order, one row per token: [T*H*W, 3].
        hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3)
        # Non-persistent: follows the module across devices but is not saved
        # in the state dict.
        self.register_buffer("hwt_grid", hwt, persistent=False)

    @property
    def num_visual_tokens(self) -> int:
        """Number of grid tokens, ``T * H * W``."""
        return self.T * self.H * self.W

    def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute cos/sin from the per-token ray directions.

        ``rays`` shape: ``[B, N_v, 3]``. ``N_v`` presumably has to equal
        ``num_visual_tokens`` so that it lines up with the cached grid; this is
        not checked here. The tables are produced in ``rays.dtype`` on
        ``rays.device``.
        """
        bsz = rays.shape[0]
        # Follow the device of `rays` so a module that was left on another
        # device still works; `.to` is a no-op when the devices already match.
        hwt = self.hwt_grid.to(rays.device).unsqueeze(0).expand(bsz, -1, -1)
        return build_rope_freqs(
            rays=rays,
            hwt_grid=hwt,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            rope_theta=self.rope_theta,
            dtype=rays.dtype,
        )
|
|