File size: 8,102 Bytes
"""3D RoPE (applied to visual tokens only).

The 12 heads are split 4+4+4 into three groups:

- Heads 0-3: ray RoPE, encoding the unit ray direction ``(dx, dy, dz)`` in
  the ego-vehicle frame.
- Heads 4-7: H/W/T RoPE, encoding the normalized space-time indices
  ``(h_norm, w_norm, t_norm)``.
- Heads 8-11: zero-frequency RoPE; cos=1 / sin=0, so the rotation matrix is
  always the identity ``I``.

To reduce branching and GPU-memory traffic, all 12 heads go through the same
RoPE operator (no if/else); the zero-frequency heads simply become an
identity map.

``head_dim=64`` is cut into 32 (cos, sin) pairs (each pair is rotated
together). Inside the ray and H/W/T groups the 32 pairs are divided among the
3 components (dx, dy, dz or h, w, t): 32 // 3 = 10 pairs each, and the 2
leftover pairs are folded into the last component (10 + 10 + 12).
"""
from __future__ import annotations
import torch
import torch.nn as nn
def _split_head_dim_for_components(half: int, num_components: int) -> list[int]:
    """Divide ``half`` rotation pairs among ``num_components`` components.

    Every component receives ``half // num_components`` pairs; whatever is
    left over after that even split is added to the last component, so the
    returned counts always sum to ``half``.

    Special case: when ``num_components == 0`` the result is ``[0, half]``,
    i.e. no pairs for any real component and all ``half`` pairs in a trailing
    slot that callers are expected to treat as zero-frequency.

    Returns:
        A list with one pair count per component (or ``[0, half]`` for the
        zero-component case).
    """
    if num_components == 0:
        return [0, half]

    counts = [half // num_components for _ in range(num_components)]
    # The last component absorbs the remainder of the integer division.
    counts[-1] = half - sum(counts[:-1])
    return counts
def build_rope_freqs(
    rays: torch.Tensor,
    hwt_grid: torch.Tensor,
    num_heads: int = 12,
    head_dim: int = 64,
    rope_theta: float = 10000.0,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the cos / sin tables for the 3D RoPE.

    The heads are split into three equal groups of ``num_heads // 3``:

    * group 0 rotates by the ray components ``(dx, dy, dz)``,
    * group 1 rotates by the grid components ``(h, w, t)``,
    * group 2 keeps angle 0 (cos=1, sin=0), i.e. an identity rotation.

    Inside groups 0 and 1 the ``head_dim // 2`` rotation pairs are divided
    among the three components by ``_split_head_dim_for_components``. Each
    component uses its own frequency ladder
    ``rope_theta ** (-2 * i / head_dim)`` with ``i`` restarting at 0, and the
    angle of a pair is ``component_value * frequency``.

    Parameters
    ----------
    rays : Tensor, shape ``[B, N_v, 3]``
        Unit ray direction ``(dx, dy, dz)`` of every visual token in the
        ego-vehicle frame.
    hwt_grid : Tensor, shape ``[B, N_v, 3]``
        Normalized space-time coordinates ``(h_norm, w_norm, t_norm)``,
        documented as lying in [-1, 1].
    num_heads : int
        Total number of heads (default 12); must be divisible by 3.
    head_dim : int
        Per-head dimension (default 64); must be even.
    rope_theta : float
        Base of the RoPE frequency ladder.
    device : torch.device, optional
        Device of the returned tables; defaults to ``rays.device``. Both
        inputs are moved to this device before use.
    dtype : torch.dtype
        Dtype of the angle table and therefore of the returned cos / sin.

    Returns
    -------
    cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]``
        Per-pair cos / sin values, ready to be passed to ``apply_rope``.

    Raises
    ------
    AssertionError
        If ``head_dim`` is odd or ``num_heads`` is not divisible by 3.
    """
    assert head_dim % 2 == 0, "head_dim 必须为偶数"
    assert num_heads % 3 == 0, "num_heads 需被 3 整除以便 4+4+4 分组"

    half = head_dim // 2
    heads_per_group = num_heads // 3
    bsz, n_v, _ = rays.shape
    if device is None:
        device = rays.device

    # Make sure both coordinate sources live where the tables are built, so an
    # explicit ``device`` (or a grid cached on another device) cannot trigger a
    # cross-device multiply. This is a no-op when the devices already match.
    rays = rays.to(device=device)
    hwt_grid = hwt_grid.to(device=device)

    # Both coordinate groups have three components and share one pair layout.
    pair_counts = _split_head_dim_for_components(half, 3)

    # The frequency ladder only depends on the pair count, so build it once per
    # distinct count instead of once per component and group.
    freqs_by_count = {
        count: rope_theta
        ** (-2.0 * torch.arange(count, device=device, dtype=dtype) / head_dim)
        for count in set(pair_counts)
        if count > 0
    }

    # Angles start at zero; the heads of the third group are never written and
    # therefore end up with cos=1 / sin=0 (identity rotation).
    angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype)

    for group_idx, coords in enumerate((rays, hwt_grid)):
        head_lo = group_idx * heads_per_group
        head_hi = head_lo + heads_per_group
        cursor = 0  # first rotation pair owned by the current component
        for comp_idx, n_pairs in enumerate(pair_counts):
            if n_pairs > 0:
                comp_val = coords[..., comp_idx : comp_idx + 1]  # [B, N_v, 1]
                ang = comp_val * freqs_by_count[n_pairs]  # [B, N_v, n_pairs]
                # Every head of the group shares the same angles.
                angles[:, :, head_lo:head_hi, cursor : cursor + n_pairs] = (
                    ang.unsqueeze(2)
                )
            cursor += n_pairs

    return torch.cos(angles), torch.sin(angles)
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate the visual-token part of ``q`` and ``k`` with the 3D RoPE tables.

    Every head runs through the same code path; heads whose table holds
    cos=1 / sin=0 are left unchanged by the rotation.

    Parameters
    ----------
    q, k : Tensor, shape ``[B, H, N_v, head_dim]``
    cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]``

    Returns
    -------
    The rotated ``q`` and ``k``, each with the shape of its input.
    """
    # Reorder the tables to [B, H, N_v, half] so they line up with q / k.
    cos_bhnd = cos.permute(0, 2, 1, 3)
    sin_bhnd = sin.permute(0, 2, 1, 3)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Channels (2i, 2i + 1) form one rotation pair.
        first = x[..., 0::2]
        second = x[..., 1::2]
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = first * cos_bhnd - second * sin_bhnd
        rotated[..., 1::2] = first * sin_bhnd + second * cos_bhnd
        return rotated

    return _rotate(q), _rotate(k)
class RoPE3D(nn.Module):
    """Helper module for the 3D RoPE.

    The normalized H/W/T grid is computed once in the constructor and cached
    as a non-persistent buffer; the ray directions are supplied per call.

    Usage::

        rope = RoPE3D(num_heads=12, head_dim=64,
                      time_size=4, height_size=12, width_size=32)
        cos, sin = rope.compute_freqs(rays)  # rays: [B, N_v, 3]
        q, k = apply_rope(q_visual_only, k_visual_only, cos, sin)
    """

    def __init__(
        self,
        num_heads: int = 12,
        head_dim: int = 64,
        time_size: int = 4,
        height_size: int = 12,
        width_size: int = 32,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.T = time_size
        self.H = height_size
        self.W = width_size

        def _axis(steps: int) -> torch.Tensor:
            # Coordinates evenly spaced over [-1, 1]; an axis that does not
            # have more than one step collapses to the single coordinate 0.
            if steps > 1:
                return torch.linspace(-1.0, 1.0, steps=steps)
            return torch.zeros(1)

        t_axis, h_axis, w_axis = (
            _axis(size) for size in (time_size, height_size, width_size)
        )
        # Grid axes are ordered t -> h -> w, so the flattened token index
        # varies fastest along w. The original author states this matches the
        # flattening order of the Conv3D output -- verify against the caller.
        grid_t, grid_h, grid_w = torch.meshgrid(
            t_axis, h_axis, w_axis, indexing="ij"
        )
        # Columns are stored as (h, w, t); rows are the flattened tokens.
        hwt = torch.stack([grid_h, grid_w, grid_t], dim=-1).reshape(-1, 3)
        self.register_buffer("hwt_grid", hwt, persistent=False)

    @property
    def num_visual_tokens(self) -> int:
        """Number of visual tokens, ``T * H * W``."""
        return self.T * self.H * self.W

    def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the cos / sin tables from the per-token ray directions.

        ``rays`` has shape ``[B, N_v, 3]``. The cached grid is expanded along
        the batch dimension and the tables are built in ``rays.dtype``.
        """
        batch = rays.shape[0]
        grid = self.hwt_grid[None].expand(batch, -1, -1)  # [B, N_v, 3]
        return build_rope_freqs(
            rays=rays,
            hwt_grid=grid,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            rope_theta=self.rope_theta,
            dtype=rays.dtype,
        )
|