"""3D RoPE(仅作用于视觉 token)。 12 头按 4+4+4 拆为三组: - 头 0-3:射线 RoPE,编码自车系下的单位射线方向 ``(dx, dy, dz)``。 - 头 4-7:H/W/T RoPE,编码归一化的空间-时间索引 ``(h_norm, w_norm, t_norm)``。 - 头 8-11:零频段 RoPE,cos=1 / sin=0 → 旋转矩阵恒为 I(identity)。 为减少分支与显存通信,全部 12 头统一走同一份 RoPE 算子(不写 if/else), 零频段头自然变为恒等映射。 将 ``head_dim=64`` 切成 32 个 (cos, sin) 对(两两一组旋转)。每组头内部再按 3 个分量(dx,dy,dz 或 h,w,t)平均分配 32/3 ≈ 10 对(最后 2 对补 0 频)。 """ from __future__ import annotations import torch import torch.nn as nn def _split_head_dim_for_components(half: int, num_components: int) -> list[int]: """把 head_dim/2 个旋转对均匀分给若干个分量;剩余补 0 频。 返回每个分量分到的旋转对数,最后一项是 ``half - sum(其它)``。 若 ``num_components == 0``(零频段头),则返回 ``[0, 0, ..., half]``,最后 一项视为"零频段"——它的频率会被显式置为 0。 """ if num_components == 0: return [0, half] base = half // num_components splits = [base] * num_components splits[-1] += half - base * num_components # 余数全归到最后一个分量 return splits def build_rope_freqs( rays: torch.Tensor, hwt_grid: torch.Tensor, num_heads: int = 12, head_dim: int = 64, rope_theta: float = 10000.0, device: torch.device | None = None, dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: """构造 3D RoPE 的 cos / sin 表。 参数 ---- rays : Tensor, shape ``[B, N_v, 3]`` 每个视觉 token 在自车系下的单位射线方向 ``(dx, dy, dz)``。 hwt_grid : Tensor, shape ``[B, N_v, 3]`` 归一化的空间-时间坐标 ``(h_norm, w_norm, t_norm)`` ∈ [-1, 1]。 num_heads : int 总头数(默认 12)。 head_dim : int 每头维度(默认 64,必须为偶数)。 返回 ---- cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]`` 每个旋转对的 cos / sin 值,已就绪可送入 ``apply_rope``。 """ assert head_dim % 2 == 0, "head_dim 必须为偶数" assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组" half = head_dim // 2 heads_per_group = num_heads // 3 bsz, n_v, _ = rays.shape if device is None: device = rays.device # === 三组分量值 === # group 0: rays (3 components) # group 1: hwt (3 components) # group 2: zero (0 components -> 全部 half 视为零频段) splits_g0 = _split_head_dim_for_components(half, 3) # 用于 rays splits_g1 = _split_head_dim_for_components(half, 3) # 用于 hwt splits_g2 = _split_head_dim_for_components(half, 0) # [0, half] # === 频率向量(沿 head_dim 半轴)=== # 经典 RoPE: theta_i = base ^ (-2i / d) # 这里我们对每个分量独立排布频率 def _freqs(num_pairs: int) -> torch.Tensor: # 前 num_pairs 个用 RoPE 频率,剩余补 0 idx = torch.arange(num_pairs, device=device, dtype=dtype) freqs = rope_theta ** (-2.0 * idx / head_dim) return freqs # [num_pairs] # 把分量值与频率张量逐头展开为 [B, N_v, num_heads, half] angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype) # ---- 第 0 组(4 头):射线 ---- base_offset = 0 h0_start = 0 h0_end = h0_start + heads_per_group cursor = 0 for c in range(3): # dx, dy, dz n_pairs = splits_g0[c] if n_pairs > 0: f = _freqs(n_pairs) # [n_pairs] comp_val = rays[..., c : c + 1] # [B, N_v, 1] ang = comp_val * f # 广播 -> [B, N_v, n_pairs] angles[:, :, h0_start:h0_end, cursor : cursor + n_pairs] = ang.unsqueeze(2) cursor += n_pairs # 余数(splits_g0 最后一项的"补足"部分由 _split 已并入最后分量),无需置 0 # ---- 第 1 组(4 头):HWT ---- h1_start = heads_per_group h1_end = h1_start + heads_per_group cursor = 0 for c in range(3): # h, w, t n_pairs = splits_g1[c] if n_pairs > 0: f = _freqs(n_pairs) comp_val = hwt_grid[..., c : c + 1] ang = comp_val * f angles[:, :, h1_start:h1_end, cursor : cursor + n_pairs] = ang.unsqueeze(2) cursor += n_pairs # ---- 第 2 组(4 头):零频段 ---- # 角度恒为 0 → cos=1, sin=0 → 等价 identity;不需要再赋值(已是零) cos = torch.cos(angles) sin = torch.sin(angles) return cos, sin def apply_rope( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """对 ``q`` ``k`` 的视觉 token 部分应用 3D RoPE。 所有 12 头一视同仁地走同一段代码(零频段头 cos=1/sin=0 → identity)。 参数 ---- q, k : Tensor, shape ``[B, H, N_v, head_dim]`` cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]`` 返回 ---- 旋转后的 q, k,形状不变。 """ # 把 cos/sin 转成 [B, H, N_v, half] cos_e = cos.permute(0, 2, 1, 3) sin_e = sin.permute(0, 2, 1, 3) # 把 head_dim 维度按 (even, odd) 拆开成 [..., half] q_even = q[..., 0::2] q_odd = q[..., 1::2] k_even = k[..., 0::2] k_odd = k[..., 1::2] q_rot_even = q_even * cos_e - q_odd * sin_e q_rot_odd = q_even * sin_e + q_odd * cos_e k_rot_even = k_even * cos_e - k_odd * sin_e k_rot_odd = k_even * sin_e + k_odd * cos_e q_out = torch.empty_like(q) k_out = torch.empty_like(k) q_out[..., 0::2] = q_rot_even q_out[..., 1::2] = q_rot_odd k_out[..., 0::2] = k_rot_even k_out[..., 1::2] = k_rot_odd return q_out, k_out class RoPE3D(nn.Module): """3D RoPE 工具模块:缓存 hwt_grid(视觉 token 网格上不变),动态计算 rays。 使用方式: rope = RoPE3D(num_heads=12, head_dim=64, T=4, H=12, W=32) cos, sin = rope.compute_freqs(rays) # rays: [B, N_v, 3] q, k = apply_rope(q_visual_only, k_visual_only, cos, sin) """ def __init__( self, num_heads: int = 12, head_dim: int = 64, time_size: int = 4, height_size: int = 12, width_size: int = 32, rope_theta: float = 10000.0, ) -> None: super().__init__() self.num_heads = num_heads self.head_dim = head_dim self.rope_theta = rope_theta self.T = time_size self.H = height_size self.W = width_size # 预计算并缓存归一化 H/W/T 网格 [N_v, 3],N_v = T*H*W t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1) h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1) w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1) # 顺序:t -> h -> w(与 Conv3D 输出展平顺序一致) T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij") hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3) # [N_v, 3] self.register_buffer("hwt_grid", hwt, persistent=False) @property def num_visual_tokens(self) -> int: return self.T * self.H * self.W def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """根据每 token 的射线方向计算 cos/sin。 ``rays`` shape: ``[B, N_v, 3]``。 """ bsz = rays.shape[0] hwt = self.hwt_grid.unsqueeze(0).expand(bsz, -1, -1) # [B, N_v, 3] return build_rope_freqs( rays=rays, hwt_grid=hwt, num_heads=self.num_heads, head_dim=self.head_dim, rope_theta=self.rope_theta, dtype=rays.dtype, )