# dinac_ae/common/rope.py
# Source: uploaded by data-archetype ("Upload DINAC-AE export package", commit 1b703d5).
from __future__ import annotations
import math
from collections.abc import Callable
import torch
from torch import nn
class Rope1D(nn.Module):
"""
Rotary Position Embedding (RoPE) 1D.
Based on the reference LLaMA implementation (Hugging Face
`modeling_llama.py`), adapted to this codebase without behavior changes.
- dim: per-head dimension
- max_position_embeddings: length used to precompute cached cos/sin (not required
by forward)
- base: RoPE base theta
Forward expects:
- x: (B, H, T, D)
- position_ids: (B, T) integer positions
Returns:
- cos, sin: (B, T, D)
"""
inv_freq: torch.Tensor
_cos_cached: torch.Tensor
_sin_cached: torch.Tensor
def __init__(
self,
dim: int,
max_position_embeddings: int = 2048,
base: float = 10000.0,
device: torch.device | None = None,
scaling_factor: float = 1.0,
) -> None:
super().__init__()
if dim % 2 != 0:
raise AssertionError("head_dim must be even for RoPE")
self.scaling_factor: float = float(scaling_factor)
self.dim: int = int(dim)
self.max_position_embeddings: int = int(max_position_embeddings)
self.base: float = float(base)
inv_freq = self._build_inv_freq(device=device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Cached cos/sin (not used in application, but kept for parity with reference)
self.max_seq_len_cached: int = self.max_position_embeddings
cos_cached, sin_cached = self._build_cached_trig(device=device)
self.register_buffer("_cos_cached", cos_cached, persistent=False)
self.register_buffer("_sin_cached", sin_cached, persistent=False)
def _build_inv_freq(self, *, device: torch.device | None) -> torch.Tensor:
"""Return the RoPE inverse-frequency vector in float32."""
return 1.0 / (
self.base
** (
torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
/ float(self.dim)
)
)
def _build_cached_trig(
self, *, device: torch.device | None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return cached RoPE trig tensors in float32."""
inv_freq = self._build_inv_freq(device=device)
t = torch.arange(
self.max_seq_len_cached,
device=device,
dtype=torch.float32,
)
t = t / self.scaling_factor
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos(), emb.sin()
def _apply(
self,
fn: Callable[[torch.Tensor], torch.Tensor],
recurse: bool = True,
) -> Rope1D:
"""Apply module moves/casts while preserving fp32 RoPE buffers."""
out = super()._apply(fn, recurse=recurse)
with torch.no_grad():
device = self.inv_freq.device
self.inv_freq.data = self._build_inv_freq(device=device)
cos_cached, sin_cached = self._build_cached_trig(device=device)
self._cos_cached.data = cos_cached
self._sin_cached.data = sin_cached
return out
@torch.no_grad()
def forward(
self, x: torch.Tensor, position_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
inv_freq_tensor = self._build_inv_freq(device=x.device)
inv_freq_expanded = (
inv_freq_tensor[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float() / self.scaling_factor
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
freqs = (
inv_freq_expanded.float() @ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second half.

    With the last dimension split at ``size // 2`` into ``(a, b)``, the result
    is ``(-b, a)`` concatenated along that dimension.
    """
    split_at = x.shape[-1] // 2
    front, back = torch.tensor_split(x, [split_at], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
def rotate_half_adjacent(x: torch.Tensor) -> torch.Tensor:
    """Rotate each consecutive pair of the last dimension.

    The last dimension is read as pairs ``(x0, x1), (x2, x3), ...`` (the
    common EVA-02 / SpeedrunDiT RoPE convention) and every pair ``(a, b)`` is
    mapped to ``(-b, a)``.

    Raises:
        ValueError: If the last dimension has odd length.
    """
    width = x.shape[-1]
    if width % 2:
        raise ValueError("rotate_half_adjacent requires an even last dimension")
    # Even indices hold the first element of each pair, odd indices the second.
    firsts = x[..., 0::2]
    seconds = x[..., 1::2]
    interleaved = torch.stack((seconds.neg(), firsts), dim=-1)
    return interleaved.reshape(x.shape)
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    *,
    unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply half-split rotary embeddings to a query/key pair.

    ``cos`` and ``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so
    they broadcast against ``q`` and ``k``; each tensor ``t`` is then mapped to
    ``t * cos + rotate_half(t) * sin``.

    Returns:
        The rotated ``(q, k)`` tuple.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
class LearnableRoPE2D(nn.Module):
r"""
Learnable mixed 2D RoPE with axial RoPE2D-compatible initialization.
- Learnable frequency banks for X and Y.
- Frequencies can be shared across groups of attention heads (see
``rope_param_dim``).
- Angle per pair: theta = x * fx[g, i] + y * fy[g, i]
- Initialization matches the axial RoPE2D parameterization used by DiTTrunk
for ``ROPE_2D_AXIAL_FREQ_AWARE`` (AxialRoPE2DConfig(base=100, dim_layout=HALF_SPLIT)):
- Angle multiplier ``2π``.
- Period base ``100`` (DINOv3-style), applied per-axis.
Each head group starts identically (deterministic init) so the learnable
variant is functionally identical to axial RoPE2D at step 0.
- Rotation is implemented with real-valued sin/cos to avoid complex tensors
(torch.compile/inductor cannot codegen complex dtypes).
Shapes:
- Expects q,k of shape (B, H, T, D) with D % 4 == 0.
- Positions xy: (T, 2) or (B, T, 2), any real dtype (cast to float32).
- Parameter `freqs`: (2, G, D//2) in float32; index 0 = x, 1 = y.
Head grouping / parameter budget
-------------------------------
``rope_param_dim`` controls the total number of learned RoPE frequency
parameters (scalars) for this module.
Let:
- ``head_dim = D`` (per-head width)
- ``num_heads = H``
- ``rope_param_dim = P``
Then the module uses:
- ``num_groups = G = P // D``
- ``heads_per_group = H // G``
This is fail-fast: ``P`` must be divisible by ``D`` and ``H`` must be
divisible by ``G``. When ``rope_param_dim`` is None (default), the module
uses the classic per-head parameterization with ``P = H * D``.
"""
def __init__(
self,
head_dim: int,
*,
num_heads: int,
rope_param_dim: int | None = None,
rope_base: float = 100.0,
angle_multiplier: float = 2.0 * float(math.pi),
learnable: bool = True,
persist_buffers: bool = True,
) -> None:
super().__init__()
if head_dim % 4 != 0:
raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE")
self.head_dim: int = int(head_dim)
# Avoid naming collisions with nn.Module.half() (dtype casting helper).
self.half_dim: int = self.head_dim // 2
self.num_heads: int = int(num_heads)
effective_param_dim = (
int(rope_param_dim)
if rope_param_dim is not None
else self.num_heads * self.head_dim
)
if effective_param_dim <= 0:
raise ValueError("rope_param_dim must be positive for LearnableRoPE2D")
self.rope_param_dim: int = int(effective_param_dim)
self._learnable: bool = bool(learnable)
theta = float(rope_base)
mult = float(angle_multiplier)
if not math.isfinite(theta) or theta <= 0.0:
raise ValueError("rope_base must be finite and > 0 for LearnableRoPE2D")
if not math.isfinite(mult) or mult <= 0.0:
raise ValueError(
"angle_multiplier must be finite and > 0 for LearnableRoPE2D"
)
if self.rope_param_dim % self.head_dim != 0:
raise ValueError(
"rope_param_dim must be divisible by head_dim for LearnableRoPE2D "
f"(got rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})"
)
self.num_groups: int = self.rope_param_dim // self.head_dim
if self.num_groups <= 0:
raise RuntimeError("num_groups must be positive for LearnableRoPE2D")
if self.num_heads % self.num_groups != 0:
raise ValueError(
"num_heads must be divisible by (rope_param_dim / head_dim) for LearnableRoPE2D "
f"(got num_heads={self.num_heads}, num_groups={self.num_groups}, "
f"rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})"
)
self.heads_per_group: int = self.num_heads // self.num_groups
if self.heads_per_group <= 0:
raise RuntimeError("heads_per_group must be positive for LearnableRoPE2D")
# Axial-compatible deterministic init:
# - periods match AxialRoPE2DConfig(base=100, dim_layout=HALF_SPLIT)
# - angle = 2π * coord / period
qtr = self.head_dim // 4
exponents = (
2.0
* torch.arange(int(qtr), dtype=torch.float32)
/ float(self.head_dim // 2)
)
periods = torch.tensor(theta, dtype=torch.float32) ** exponents # [qtr]
axis_freqs = (mult / periods).to(dtype=torch.float32) # [qtr]
zeros = torch.zeros_like(axis_freqs)
# Match AxialRoPE2D(HALF_SPLIT) flatten order: [y-axis, x-axis].
# Our xy columns are (x, y), so:
# - x contributes to the second quarter (x-axis part)
# - y contributes to the first quarter (y-axis part)
fx_half = torch.cat((zeros, axis_freqs), dim=0) # [half_dim]
fy_half = torch.cat((axis_freqs, zeros), dim=0) # [half_dim]
freqs_x = fx_half.expand(int(self.num_groups), -1).clone()
freqs_y = fy_half.expand(int(self.num_groups), -1).clone()
freqs = torch.stack([freqs_x, freqs_y], dim=0) # (2, G, half)
if self._learnable:
self.freqs = nn.Parameter(freqs, requires_grad=True)
else:
self.register_buffer("freqs", freqs, persistent=persist_buffers)
def _apply(
self,
fn: Callable[[torch.Tensor], torch.Tensor],
recurse: bool = True,
) -> LearnableRoPE2D:
"""Apply module moves/casts while preserving fp32 frequency tensors."""
out = super()._apply(fn, recurse=recurse)
with torch.no_grad():
self.freqs.data = self.freqs.data.to(dtype=torch.float32)
return out
def _apply_rotary_from_trig(
self,
x: torch.Tensor,
*,
sin: torch.Tensor,
cos: torch.Tensor,
) -> torch.Tensor:
"""Rotate Q/K using precomputed grouped sin/cos buffers (HALF_SPLIT layout).
This matches AxialRoPE2DConfig(dim_layout=HALF_SPLIT) rotation and keeps
the learnable variant identical at initialization when combined with
axial-compatible frequency init.
Args:
x: Tensor shaped ``(B, H, T, D)``.
sin: Sin tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``.
cos: Cos tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``.
Returns:
Tensor with the same shape/dtype/device as ``x``.
"""
if x.dim() != 4:
raise ValueError("x must be shaped (B, H, T, D)")
B, H, T, D = x.shape
if self.num_heads != int(H):
raise ValueError("num_heads mismatch for LearnableRoPE2D")
if self.head_dim != int(D):
raise ValueError("head_dim mismatch for LearnableRoPE2D")
if sin.dim() == 3 and cos.dim() == 3:
sin = sin.unsqueeze(0)
cos = cos.unsqueeze(0)
if sin.dim() != 4 or cos.dim() != 4:
raise RuntimeError("Unexpected sin/cos rank for LearnableRoPE2D")
if int(D) % 2 != 0:
raise RuntimeError("LearnableRoPE2D requires even head_dim for HALF_SPLIT")
half = int(D) // 2
if int(sin.shape[-1]) != half or int(cos.shape[-1]) != half:
raise RuntimeError(
"LearnableRoPE2D expected sin/cos last dim == head_dim//2 "
f"(got sin={tuple(sin.shape)}, cos={tuple(cos.shape)}, head_dim={int(D)})"
)
sin = sin[:, :, None, :, :] # [B, G, 1, T, half]
cos = cos[:, :, None, :, :] # [B, G, 1, T, half]
grouped = x.reshape(
int(B),
int(self.num_groups),
int(self.heads_per_group),
int(T),
int(D),
)
x1 = grouped[..., :half]
x2 = grouped[..., half:]
out1 = x1 * cos - x2 * sin
out2 = x2 * cos + x1 * sin
out = torch.cat((out1, out2), dim=-1).reshape(int(B), int(H), int(T), int(D))
return out.to(dtype=x.dtype)
def _compute_mixed_cis(self, xy: torch.Tensor) -> torch.Tensor:
# Returns complex cis angles with shape (G, T, half) or (B, G, T, half)
if xy.dim() == 2:
# (T, 2) -> (G, T, half)
t_x = xy[:, 0].to(dtype=torch.float32)
t_y = xy[:, 1].to(dtype=torch.float32)
with torch.autocast(device_type=t_x.device.type, enabled=False):
# Memory notes:
# - Avoid materializing both fx and fy; accumulate in-place into angles.
# - Avoid torch.ones_like(angles) (full-size allocation); a scalar
# magnitude broadcasts in torch.polar.
angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
0
) # (T, G, half)
angles.add_(
t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0)
)
angles = angles.permute(1, 0, 2) # (G, T, half)
cis = torch.polar(
torch.ones((), device=angles.device, dtype=angles.dtype), angles
)
return cis
elif xy.dim() == 3:
# (B, T, 2) -> (B, G, T, half)
t_x = xy[..., 0].to(dtype=torch.float32)
t_y = xy[..., 1].to(dtype=torch.float32)
with torch.autocast(device_type=t_x.device.type, enabled=False):
angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
0
).unsqueeze(0)
angles.add_(
t_y.unsqueeze(-1).unsqueeze(-1)
* self.freqs[1].unsqueeze(0).unsqueeze(0)
)
angles = angles.permute(0, 2, 1, 3) # (B, G, T, half)
cis = torch.polar(
torch.ones((), device=angles.device, dtype=angles.dtype), angles
)
return cis
else:
raise ValueError("xy must have shape (T,2) or (B,T,2)")
def _compute_mixed_angles(self, xy: torch.Tensor) -> torch.Tensor:
"""Return mixed RoPE2D angles without applying cis/polar.
Args:
xy: XY positions shaped ``(T, 2)`` or ``(B, T, 2)``.
Returns:
Float tensor of angles shaped ``(G, T, half)`` or ``(B, G, T, half)``.
"""
if xy.dim() == 2:
t_x = xy[:, 0].to(dtype=torch.float32)
t_y = xy[:, 1].to(dtype=torch.float32)
with torch.autocast(device_type=t_x.device.type, enabled=False):
angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(0)
angles.add_(
t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0)
)
return angles.permute(1, 0, 2)
if xy.dim() == 3:
t_x = xy[..., 0].to(dtype=torch.float32)
t_y = xy[..., 1].to(dtype=torch.float32)
with torch.autocast(device_type=t_x.device.type, enabled=False):
angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(
0
).unsqueeze(0)
angles.add_(
t_y.unsqueeze(-1).unsqueeze(-1)
* self.freqs[1].unsqueeze(0).unsqueeze(0)
)
return angles.permute(0, 2, 1, 3)
raise ValueError("xy must have shape (T,2) or (B,T,2)")
def _cos_sin_half_from_xy(
self,
xy: torch.Tensor,
*,
device: torch.device | None = None,
out_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Helper used in tests to build real-valued cos/sin tensors.
cis = self._compute_mixed_cis(xy.to(device=device) if device else xy)
# Convert complex cis to cos/sin (real/imag) with matching shapes
if cis.is_complex():
cos_h = cis.real
sin_h = cis.imag
else:
# Should not happen; torch.polar returns complex64/128
raise RuntimeError("Expected complex cis tensor from polar")
if out_dtype is not None:
cos_h = cos_h.to(dtype=out_dtype)
sin_h = sin_h.to(dtype=out_dtype)
return cos_h, sin_h
def _cos_sin_from_xy(
self,
xy: torch.Tensor,
*,
device: torch.device | None = None,
out_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
cos_h, sin_h = self._cos_sin_half_from_xy(
xy, device=device, out_dtype=out_dtype
)
emb_cos = torch.cat((cos_h, cos_h), dim=-1)
emb_sin = torch.cat((sin_h, sin_h), dim=-1)
return emb_cos, emb_sin
def rotate_qk(
self,
q: torch.Tensor,
k: torch.Tensor,
xy: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if q.dim() != 4 or k.dim() != 4:
raise ValueError("q,k must be shaped (B,H,T,D)")
_, H, _, D = q.shape
if self.num_heads != H:
raise ValueError("num_heads mismatch for LearnableRoPE2D")
if self.head_dim != D:
raise ValueError("head_dim mismatch for LearnableRoPE2D")
if D % 4 != 0:
raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE")
# Use real-valued sin/cos rotation to keep torch.compile/inductor on the
# fast path (inductor cannot codegen complex tensors).
angles = self._compute_mixed_angles(xy.to(device=q.device))
sin = torch.sin(angles)
cos = torch.cos(angles)
q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos)
k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos)
return q_out, k_out
def rotate_qk_with_dilation(
self,
q: torch.Tensor,
k: torch.Tensor,
*,
xy: torch.Tensor,
scales: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Rotate Q/K using mixed 2D RoPE with per-sample isotropic dilation.
This implements dilation by scaling the RoPE angle, i.e.
``theta_dilated = scale * theta_base`` where ``theta_base`` comes from the
undilated XY coordinates.
Args:
q: Query tensor shaped ``(B, H, T, D)``.
k: Key tensor shaped ``(B, H, T, D)``.
xy: Base XY coordinates shaped ``(T, 2)`` or ``(B, T, 2)``.
scales: Per-sample dilation scales shaped ``(B,)``.
Raises:
ValueError: If shapes are inconsistent or scales are not 1D.
"""
if q.dim() != 4 or k.dim() != 4:
raise ValueError("q,k must be shaped (B,H,T,D)")
B, H, T, D = q.shape
if self.num_heads != H:
raise ValueError("num_heads mismatch for LearnableRoPE2D")
if self.head_dim != D:
raise ValueError("head_dim mismatch for LearnableRoPE2D")
if scales.dim() != 1 or scales.shape[0] != B:
raise ValueError("scales must have shape (B,) matching q batch size")
if xy.dim() == 2 and xy.shape[0] != T:
raise ValueError("xy length must match q sequence length")
if xy.dim() == 3 and (xy.shape[0] != B or xy.shape[1] != T):
raise ValueError("xy must have shape (B,T,2) matching q batch/sequence")
if xy.shape[-1] != 2:
raise ValueError("xy must have last dimension 2")
angles = self._compute_mixed_angles(xy.to(device=q.device))
angles = angles * scales.to(device=q.device, dtype=torch.float32).view(
B, 1, 1, 1
)
sin = torch.sin(angles)
cos = torch.cos(angles)
q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos)
k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos)
return q_out, k_out