# WJAD / src /wjad /modules /pos_encoding.py
# fuzirui's picture
# Sync WJAD codebase
# 0cfefd2 verified
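"""3D RoPE (applied to visual tokens only).

The 12 heads are split 4 + 4 + 4 into three groups:

- Heads 0-3: ray RoPE, encoding the unit ray direction ``(dx, dy, dz)`` in the
  ego-vehicle frame.
- Heads 4-7: H/W/T RoPE, encoding the normalized space-time index
  ``(h_norm, w_norm, t_norm)``.
- Heads 8-11: zero-frequency RoPE; cos=1 / sin=0, so the rotation matrix is
  always the identity.

To avoid branching and extra GPU-memory traffic, all 12 heads go through one
shared RoPE operator (no if/else); the zero-frequency heads simply become
identity maps.

``head_dim=64`` is split into 32 (cos, sin) pairs (adjacent elements are rotated
together as a pair). Within each of the first two groups the 32 pairs are divided
among the 3 components (dx, dy, dz or h, w, t): 32 // 3 = 10 pairs each, with the
remaining 2 pairs given to the last component (10 / 10 / 12).
"""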
from __future__ import annotations
import torch
import torch.nn as nn
def _split_head_dim_for_components(half: int, num_components: int) -> list[int]:
"""把 head_dim/2 个旋转对均匀分给若干个分量;剩余补 0 频。
返回每个分量分到的旋转对数,最后一项是 ``half - sum(其它)``。
若 ``num_components == 0``(零频段头),则返回 ``[0, 0, ..., half]``,最后
一项视为"零频段"——它的频率会被显式置为 0。
"""
if num_components == 0:
return [0, half]
base = half // num_components
splits = [base] * num_components
splits[-1] += half - base * num_components # 余数全归到最后一个分量
return splits
def _rope_pair_freqs(
    num_pairs: int,
    head_dim: int,
    rope_theta: float,
    device: torch.device | None,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Return ``rope_theta ** (-2 i / head_dim)`` for ``i = 0 .. num_pairs - 1``."""
    idx = torch.arange(num_pairs, device=device, dtype=dtype)
    return rope_theta ** (-2.0 * idx / head_dim)


def _fill_group_angles(
    angles: torch.Tensor,
    values: torch.Tensor,
    splits: list[int],
    head_start: int,
    head_end: int,
    head_dim: int,
    rope_theta: float,
) -> None:
    """Write the rotation angles of one head group into ``angles`` in place.

    ``values[..., c]`` drives the next ``splits[c]`` consecutive rotation pairs
    of heads ``head_start:head_end``. Each component restarts the frequency
    schedule at ``i = 0``.
    """
    cursor = 0
    for comp, n_pairs in enumerate(splits):
        if n_pairs > 0:
            freqs = _rope_pair_freqs(
                n_pairs, head_dim, rope_theta, angles.device, angles.dtype
            )
            # [..., N_v, 1] * [n_pairs] -> [..., N_v, n_pairs]
            ang = values[..., comp : comp + 1] * freqs
            # unsqueeze(-2) inserts the head axis, so `values` may be
            # [B, N_v, 3] or a batch-shared [N_v, 3]; both broadcast over heads.
            angles[:, :, head_start:head_end, cursor : cursor + n_pairs] = ang.unsqueeze(-2)
        cursor += n_pairs


def build_rope_freqs(
    rays: torch.Tensor,
    hwt_grid: torch.Tensor,
    num_heads: int = 12,
    head_dim: int = 64,
    rope_theta: float = 10000.0,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the cos / sin tables for 3D RoPE.

    Heads are split into three equal groups:

    * group 0 rotates by the ray direction ``(dx, dy, dz)``;
    * group 1 rotates by the normalized grid position ``(h, w, t)``;
    * group 2 keeps angle 0 (cos=1, sin=0), i.e. the identity rotation.

    Within groups 0 and 1 the ``head_dim // 2`` rotation pairs are shared out
    by ``_split_head_dim_for_components`` (remainder to the last component),
    and component value ``v`` gives pair ``i`` the angle
    ``v * rope_theta ** (-2 i / head_dim)``.

    Parameters
    ----------
    rays : Tensor, shape ``[B, N_v, 3]``
        Per-visual-token ray direction ``(dx, dy, dz)`` (expected to be unit
        vectors in the ego frame).
    hwt_grid : Tensor, shape ``[B, N_v, 3]`` or ``[N_v, 3]``
        Normalized space-time coordinates ``(h_norm, w_norm, t_norm)``
        (expected in ``[-1, 1]``).
    num_heads : int
        Total number of heads (default 12); must be divisible by 3.
    head_dim : int
        Per-head dimension (default 64); must be even.
    rope_theta : float
        RoPE frequency base.
    device : torch.device, optional
        Device for the output tables; defaults to ``rays.device``. The inputs
        are moved there, so any device may be requested.
    dtype : torch.dtype
        dtype of the angle / cos / sin tables.

    Returns
    -------
    cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]``
        Per-pair cos / sin values, ready for ``apply_rope``.

    Raises
    ------
    ValueError
        If ``head_dim`` is odd or ``num_heads`` is not divisible by 3.
    """
    if head_dim % 2 != 0:
        raise ValueError("head_dim 必须为偶数")
    if num_heads % 3 != 0:
        raise ValueError("num_heads 需被 3 整除以便 4+4+4 分组")

    half = head_dim // 2
    heads_per_group = num_heads // 3
    bsz, n_v, _ = rays.shape
    if device is None:
        device = rays.device
    # Honor an explicit `device`: without this, inputs on another device would
    # clash with the frequency / angle tensors created on `device`.
    rays = rays.to(device=device)
    hwt_grid = hwt_grid.to(device=device)

    # Zero-initialized, so group 2 (zero-frequency heads) needs no writes:
    # angle 0 -> cos=1, sin=0 -> identity rotation.
    angles = torch.zeros(bsz, n_v, num_heads, half, device=device, dtype=dtype)

    # Group 0: ray-direction heads.
    _fill_group_angles(
        angles,
        rays,
        _split_head_dim_for_components(half, 3),
        0,
        heads_per_group,
        head_dim,
        rope_theta,
    )
    # Group 1: H/W/T grid heads.
    _fill_group_angles(
        angles,
        hwt_grid,
        _split_head_dim_for_components(half, 3),
        heads_per_group,
        2 * heads_per_group,
        head_dim,
        rope_theta,
    )
    # NOTE(review): the largest frequency is rope_theta**0 = 1, so for inputs
    # in [-1, 1] no angle exceeds 1 radian. Changing the schedule would break
    # compatibility with trained checkpoints, so it is kept as-is.
    return torch.cos(angles), torch.sin(angles)
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate visual-token queries and keys with precomputed 3D RoPE tables.

    Every head goes through the same code path. Heads whose table holds
    cos=1 / sin=0 (the zero-frequency group) come out unchanged.

    Parameters
    ----------
    q, k : Tensor, shape ``[B, H, N_v, head_dim]``
    cos, sin : Tensor, shape ``[B, N_v, H, head_dim // 2]``

    Returns
    -------
    The rotated ``(q, k)``. Each output is allocated with ``torch.empty_like``,
    so it keeps its input's shape and dtype.
    """
    # Reorder the tables to the [B, H, N_v, half] layout of q / k.
    cos_bhn = cos.permute(0, 2, 1, 3)
    sin_bhn = sin.permute(0, 2, 1, 3)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Elements (2i, 2i+1) form one 2-D pair rotated by the i-th angle.
        first, second = x[..., 0::2], x[..., 1::2]
        rotated = torch.empty_like(x)
        rotated[..., 0::2] = first * cos_bhn - second * sin_bhn
        rotated[..., 1::2] = first * sin_bhn + second * cos_bhn
        return rotated

    return _rotate(q), _rotate(k)
class RoPE3D(nn.Module):
    """3D RoPE helper module.

    Caches the normalized H/W/T grid as a non-persistent buffer. The grid is
    fixed for a given visual-token layout. The module combines it with the
    per-call ray directions.

    Usage::

        rope = RoPE3D(num_heads=12, head_dim=64, time_size=4, height_size=12, width_size=32)
        cos, sin = rope.compute_freqs(rays)  # rays: [B, N_v, 3]
        q, k = apply_rope(q_visual_only, k_visual_only, cos, sin)
    """

    def __init__(
        self,
        num_heads: int = 12,
        head_dim: int = 64,
        time_size: int = 4,
        height_size: int = 12,
        width_size: int = 32,
        rope_theta: float = 10000.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.T = time_size
        self.H = height_size
        self.W = width_size
        # Precompute the normalized grid [N_v, 3] with N_v = T * H * W.
        # Each axis spans [-1, 1]; an axis of size 1 is pinned to coordinate 0.
        t = torch.linspace(-1.0, 1.0, steps=time_size) if time_size > 1 else torch.zeros(1)
        h = torch.linspace(-1.0, 1.0, steps=height_size) if height_size > 1 else torch.zeros(1)
        w = torch.linspace(-1.0, 1.0, steps=width_size) if width_size > 1 else torch.zeros(1)
        # Flatten order is t -> h -> w (w varies fastest). The original note says
        # this matches the Conv3D output flattening; confirm against the patch
        # embedding. The columns are stored as (h, w, t).
        T_, H_, W_ = torch.meshgrid(t, h, w, indexing="ij")
        hwt = torch.stack([H_, W_, T_], dim=-1).reshape(-1, 3)  # [N_v, 3]
        self.register_buffer("hwt_grid", hwt, persistent=False)

    @property
    def num_visual_tokens(self) -> int:
        """Number of visual tokens, ``T * H * W``."""
        return self.T * self.H * self.W

    def compute_freqs(self, rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the cos / sin tables from per-token ray directions.

        Parameters
        ----------
        rays : Tensor, shape ``[B, N_v, 3]``
            ``N_v`` must equal the number of rows of the cached grid.

        Returns
        -------
        cos, sin : Tensor, shape ``[B, N_v, num_heads, head_dim // 2]``
            The tables are built with ``dtype=rays.dtype``.

        Raises
        ------
        ValueError
            If ``rays`` is not 3-D or its token count does not match the grid.
        """
        n_grid = self.hwt_grid.shape[0]
        # Fail fast. Otherwise a mismatch surfaces deep inside the angle
        # assignment, or broadcasts silently when the grid has a single row.
        if rays.ndim != 3 or rays.shape[1] != n_grid:
            raise ValueError(
                f"rays must have shape [B, {n_grid}, 3], got {tuple(rays.shape)}"
            )
        bsz = rays.shape[0]
        # Follow rays' device so this also works if the module was not moved.
        hwt = self.hwt_grid.to(device=rays.device).unsqueeze(0).expand(bsz, -1, -1)
        return build_rope_freqs(
            rays=rays,
            hwt_grid=hwt,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            rope_theta=self.rope_theta,
            dtype=rays.dtype,
        )