| from __future__ import annotations |
|
|
| import math |
| from collections.abc import Callable |
|
|
| import torch |
| from torch import nn |
|
|
|
|
| class Rope1D(nn.Module): |
| """ |
| Rotary Position Embedding (RoPE) 1D. |
| |
| Based on the reference LLaMA implementation (Hugging Face |
| `modeling_llama.py`), adapted to this codebase without behavior changes. |
| |
| - dim: per-head dimension |
| - max_position_embeddings: length used to precompute cached cos/sin (not required |
| by forward) |
| - base: RoPE base theta |
| |
| Forward expects: |
| - x: (B, H, T, D) |
| - position_ids: (B, T) integer positions |
| Returns: |
| - cos, sin: (B, T, D) |
| """ |
|
|
| inv_freq: torch.Tensor |
| _cos_cached: torch.Tensor |
| _sin_cached: torch.Tensor |
|
|
| def __init__( |
| self, |
| dim: int, |
| max_position_embeddings: int = 2048, |
| base: float = 10000.0, |
| device: torch.device | None = None, |
| scaling_factor: float = 1.0, |
| ) -> None: |
| super().__init__() |
| if dim % 2 != 0: |
| raise AssertionError("head_dim must be even for RoPE") |
| self.scaling_factor: float = float(scaling_factor) |
| self.dim: int = int(dim) |
| self.max_position_embeddings: int = int(max_position_embeddings) |
| self.base: float = float(base) |
| inv_freq = self._build_inv_freq(device=device) |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
| |
| self.max_seq_len_cached: int = self.max_position_embeddings |
| cos_cached, sin_cached = self._build_cached_trig(device=device) |
| self.register_buffer("_cos_cached", cos_cached, persistent=False) |
| self.register_buffer("_sin_cached", sin_cached, persistent=False) |
|
|
| def _build_inv_freq(self, *, device: torch.device | None) -> torch.Tensor: |
| """Return the RoPE inverse-frequency vector in float32.""" |
|
|
| return 1.0 / ( |
| self.base |
| ** ( |
| torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) |
| / float(self.dim) |
| ) |
| ) |
|
|
| def _build_cached_trig( |
| self, *, device: torch.device | None |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Return cached RoPE trig tensors in float32.""" |
|
|
| inv_freq = self._build_inv_freq(device=device) |
| t = torch.arange( |
| self.max_seq_len_cached, |
| device=device, |
| dtype=torch.float32, |
| ) |
| t = t / self.scaling_factor |
| freqs = torch.outer(t, inv_freq) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| return emb.cos(), emb.sin() |
|
|
| def _apply( |
| self, |
| fn: Callable[[torch.Tensor], torch.Tensor], |
| recurse: bool = True, |
| ) -> Rope1D: |
| """Apply module moves/casts while preserving fp32 RoPE buffers.""" |
|
|
| out = super()._apply(fn, recurse=recurse) |
| with torch.no_grad(): |
| device = self.inv_freq.device |
| self.inv_freq.data = self._build_inv_freq(device=device) |
| cos_cached, sin_cached = self._build_cached_trig(device=device) |
| self._cos_cached.data = cos_cached |
| self._sin_cached.data = sin_cached |
| return out |
|
|
| @torch.no_grad() |
| def forward( |
| self, x: torch.Tensor, position_ids: torch.Tensor |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| inv_freq_tensor = self._build_inv_freq(device=x.device) |
| inv_freq_expanded = ( |
| inv_freq_tensor[None, :, None].float().expand(position_ids.shape[0], -1, 1) |
| ) |
| position_ids_expanded = position_ids[:, None, :].float() / self.scaling_factor |
| device_type = x.device.type |
| device_type = ( |
| device_type |
| if isinstance(device_type, str) and device_type != "mps" |
| else "cpu" |
| ) |
| with torch.autocast(device_type=device_type, enabled=False): |
| freqs = ( |
| inv_freq_expanded.float() @ position_ids_expanded.float() |
| ).transpose(1, 2) |
| emb = torch.cat((freqs, freqs), dim=-1) |
| cos = emb.cos() |
| sin = emb.sin() |
| return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the two halves of the last dimension: ``(a, b) -> (-b, a)``.

    The split point is ``x.shape[-1] // 2``; for an odd width the second half
    carries the extra element.
    """
    split_at = x.shape[-1] // 2
    front, back = x.split([split_at, x.shape[-1] - split_at], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
def rotate_half_adjacent(x: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive pairs in the last dimension: ``(a, b) -> (-b, a)``.

    This matches the common EVA-02 / SpeedrunDiT RoPE convention where the last
    dimension is interpreted as pairs ``(x0, x1), (x2, x3), ...``.

    Raises:
        ValueError: if the last dimension is odd.
    """
    width = x.shape[-1]
    if width % 2:
        raise ValueError("rotate_half_adjacent requires an even last dimension")
    evens, odds = x.reshape(*x.shape[:-1], width // 2, 2).unbind(dim=-1)
    rotated = torch.stack((odds.neg(), evens), dim=-1)
    return rotated.reshape_as(x)
|
|
|
|
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    *,
    unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``q`` and ``k`` with precomputed RoPE ``cos``/``sin`` tables.

    ``cos`` and ``sin`` receive a singleton axis at ``unsqueeze_dim`` so they
    broadcast against ``q``/``k`` (e.g. ``(B, T, D)`` tables against
    ``(B, H, T, D)`` inputs with the default ``unsqueeze_dim=1``). Channel
    pairing follows ``rotate_half`` (first half paired with second half).

    Returns:
        ``(q_embed, k_embed)`` where each is ``t * cos + rotate_half(t) * sin``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
| class LearnableRoPE2D(nn.Module): |
| r""" |
| Learnable mixed 2D RoPE with axial RoPE2D-compatible initialization. |
| |
| - Learnable frequency banks for X and Y. |
| - Frequencies can be shared across groups of attention heads (see |
| ``rope_param_dim``). |
| - Angle per pair: theta = x * fx[g, i] + y * fy[g, i] |
| - Initialization matches the axial RoPE2D parameterization used by DiTTrunk |
| for ``ROPE_2D_AXIAL_FREQ_AWARE`` (AxialRoPE2DConfig(base=100, dim_layout=HALF_SPLIT)): |
| - Angle multiplier ``2π``. |
| - Period base ``100`` (DINOv3-style), applied per-axis. |
| Each head group starts identically (deterministic init) so the learnable |
| variant is functionally identical to axial RoPE2D at step 0. |
| - Rotation is implemented with real-valued sin/cos to avoid complex tensors |
| (torch.compile/inductor cannot codegen complex dtypes). |
| |
| Shapes: |
| - Expects q,k of shape (B, H, T, D) with D % 4 == 0. |
| - Positions xy: (T, 2) or (B, T, 2), any real dtype (cast to float32). |
| - Parameter `freqs`: (2, G, D//2) in float32; index 0 = x, 1 = y. |
| |
| Head grouping / parameter budget |
| ------------------------------- |
| ``rope_param_dim`` controls the total number of learned RoPE frequency |
| parameters (scalars) for this module. |
| |
| Let: |
| - ``head_dim = D`` (per-head width) |
| - ``num_heads = H`` |
| - ``rope_param_dim = P`` |
| |
| Then the module uses: |
| - ``num_groups = G = P // D`` |
| - ``heads_per_group = H // G`` |
| |
| This is fail-fast: ``P`` must be divisible by ``D`` and ``H`` must be |
| divisible by ``G``. When ``rope_param_dim`` is None (default), the module |
| uses the classic per-head parameterization with ``P = H * D``. |
| """ |
|
|
| def __init__( |
| self, |
| head_dim: int, |
| *, |
| num_heads: int, |
| rope_param_dim: int | None = None, |
| rope_base: float = 100.0, |
| angle_multiplier: float = 2.0 * float(math.pi), |
| learnable: bool = True, |
| persist_buffers: bool = True, |
| ) -> None: |
| super().__init__() |
| if head_dim % 4 != 0: |
| raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE") |
| self.head_dim: int = int(head_dim) |
| |
| self.half_dim: int = self.head_dim // 2 |
| self.num_heads: int = int(num_heads) |
| effective_param_dim = ( |
| int(rope_param_dim) |
| if rope_param_dim is not None |
| else self.num_heads * self.head_dim |
| ) |
| if effective_param_dim <= 0: |
| raise ValueError("rope_param_dim must be positive for LearnableRoPE2D") |
| self.rope_param_dim: int = int(effective_param_dim) |
| self._learnable: bool = bool(learnable) |
| theta = float(rope_base) |
| mult = float(angle_multiplier) |
| if not math.isfinite(theta) or theta <= 0.0: |
| raise ValueError("rope_base must be finite and > 0 for LearnableRoPE2D") |
| if not math.isfinite(mult) or mult <= 0.0: |
| raise ValueError( |
| "angle_multiplier must be finite and > 0 for LearnableRoPE2D" |
| ) |
|
|
| if self.rope_param_dim % self.head_dim != 0: |
| raise ValueError( |
| "rope_param_dim must be divisible by head_dim for LearnableRoPE2D " |
| f"(got rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})" |
| ) |
| self.num_groups: int = self.rope_param_dim // self.head_dim |
| if self.num_groups <= 0: |
| raise RuntimeError("num_groups must be positive for LearnableRoPE2D") |
| if self.num_heads % self.num_groups != 0: |
| raise ValueError( |
| "num_heads must be divisible by (rope_param_dim / head_dim) for LearnableRoPE2D " |
| f"(got num_heads={self.num_heads}, num_groups={self.num_groups}, " |
| f"rope_param_dim={self.rope_param_dim}, head_dim={self.head_dim})" |
| ) |
| self.heads_per_group: int = self.num_heads // self.num_groups |
| if self.heads_per_group <= 0: |
| raise RuntimeError("heads_per_group must be positive for LearnableRoPE2D") |
|
|
| |
| |
| |
| qtr = self.head_dim // 4 |
| exponents = ( |
| 2.0 |
| * torch.arange(int(qtr), dtype=torch.float32) |
| / float(self.head_dim // 2) |
| ) |
| periods = torch.tensor(theta, dtype=torch.float32) ** exponents |
| axis_freqs = (mult / periods).to(dtype=torch.float32) |
|
|
| zeros = torch.zeros_like(axis_freqs) |
| |
| |
| |
| |
| fx_half = torch.cat((zeros, axis_freqs), dim=0) |
| fy_half = torch.cat((axis_freqs, zeros), dim=0) |
|
|
| freqs_x = fx_half.expand(int(self.num_groups), -1).clone() |
| freqs_y = fy_half.expand(int(self.num_groups), -1).clone() |
| freqs = torch.stack([freqs_x, freqs_y], dim=0) |
| if self._learnable: |
| self.freqs = nn.Parameter(freqs, requires_grad=True) |
| else: |
| self.register_buffer("freqs", freqs, persistent=persist_buffers) |
|
|
| def _apply( |
| self, |
| fn: Callable[[torch.Tensor], torch.Tensor], |
| recurse: bool = True, |
| ) -> LearnableRoPE2D: |
| """Apply module moves/casts while preserving fp32 frequency tensors.""" |
|
|
| out = super()._apply(fn, recurse=recurse) |
| with torch.no_grad(): |
| self.freqs.data = self.freqs.data.to(dtype=torch.float32) |
| return out |
|
|
| def _apply_rotary_from_trig( |
| self, |
| x: torch.Tensor, |
| *, |
| sin: torch.Tensor, |
| cos: torch.Tensor, |
| ) -> torch.Tensor: |
| """Rotate Q/K using precomputed grouped sin/cos buffers (HALF_SPLIT layout). |
| |
| This matches AxialRoPE2DConfig(dim_layout=HALF_SPLIT) rotation and keeps |
| the learnable variant identical at initialization when combined with |
| axial-compatible frequency init. |
| |
| Args: |
| x: Tensor shaped ``(B, H, T, D)``. |
| sin: Sin tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``. |
| cos: Cos tensor shaped ``(G, T, D//2)`` or ``(B, G, T, D//2)``. |
| |
| Returns: |
| Tensor with the same shape/dtype/device as ``x``. |
| """ |
| if x.dim() != 4: |
| raise ValueError("x must be shaped (B, H, T, D)") |
| B, H, T, D = x.shape |
| if self.num_heads != int(H): |
| raise ValueError("num_heads mismatch for LearnableRoPE2D") |
| if self.head_dim != int(D): |
| raise ValueError("head_dim mismatch for LearnableRoPE2D") |
|
|
| if sin.dim() == 3 and cos.dim() == 3: |
| sin = sin.unsqueeze(0) |
| cos = cos.unsqueeze(0) |
| if sin.dim() != 4 or cos.dim() != 4: |
| raise RuntimeError("Unexpected sin/cos rank for LearnableRoPE2D") |
| if int(D) % 2 != 0: |
| raise RuntimeError("LearnableRoPE2D requires even head_dim for HALF_SPLIT") |
| half = int(D) // 2 |
| if int(sin.shape[-1]) != half or int(cos.shape[-1]) != half: |
| raise RuntimeError( |
| "LearnableRoPE2D expected sin/cos last dim == head_dim//2 " |
| f"(got sin={tuple(sin.shape)}, cos={tuple(cos.shape)}, head_dim={int(D)})" |
| ) |
|
|
| sin = sin[:, :, None, :, :] |
| cos = cos[:, :, None, :, :] |
|
|
| grouped = x.reshape( |
| int(B), |
| int(self.num_groups), |
| int(self.heads_per_group), |
| int(T), |
| int(D), |
| ) |
| x1 = grouped[..., :half] |
| x2 = grouped[..., half:] |
| out1 = x1 * cos - x2 * sin |
| out2 = x2 * cos + x1 * sin |
| out = torch.cat((out1, out2), dim=-1).reshape(int(B), int(H), int(T), int(D)) |
| return out.to(dtype=x.dtype) |
|
|
| def _compute_mixed_cis(self, xy: torch.Tensor) -> torch.Tensor: |
| |
| if xy.dim() == 2: |
| |
| t_x = xy[:, 0].to(dtype=torch.float32) |
| t_y = xy[:, 1].to(dtype=torch.float32) |
| with torch.autocast(device_type=t_x.device.type, enabled=False): |
| |
| |
| |
| |
| angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze( |
| 0 |
| ) |
| angles.add_( |
| t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0) |
| ) |
| angles = angles.permute(1, 0, 2) |
| cis = torch.polar( |
| torch.ones((), device=angles.device, dtype=angles.dtype), angles |
| ) |
| return cis |
| elif xy.dim() == 3: |
| |
| t_x = xy[..., 0].to(dtype=torch.float32) |
| t_y = xy[..., 1].to(dtype=torch.float32) |
| with torch.autocast(device_type=t_x.device.type, enabled=False): |
| angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze( |
| 0 |
| ).unsqueeze(0) |
| angles.add_( |
| t_y.unsqueeze(-1).unsqueeze(-1) |
| * self.freqs[1].unsqueeze(0).unsqueeze(0) |
| ) |
| angles = angles.permute(0, 2, 1, 3) |
| cis = torch.polar( |
| torch.ones((), device=angles.device, dtype=angles.dtype), angles |
| ) |
| return cis |
| else: |
| raise ValueError("xy must have shape (T,2) or (B,T,2)") |
|
|
| def _compute_mixed_angles(self, xy: torch.Tensor) -> torch.Tensor: |
| """Return mixed RoPE2D angles without applying cis/polar. |
| |
| Args: |
| xy: XY positions shaped ``(T, 2)`` or ``(B, T, 2)``. |
| |
| Returns: |
| Float tensor of angles shaped ``(G, T, half)`` or ``(B, G, T, half)``. |
| """ |
| if xy.dim() == 2: |
| t_x = xy[:, 0].to(dtype=torch.float32) |
| t_y = xy[:, 1].to(dtype=torch.float32) |
| with torch.autocast(device_type=t_x.device.type, enabled=False): |
| angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze(0) |
| angles.add_( |
| t_y.unsqueeze(-1).unsqueeze(-1) * self.freqs[1].unsqueeze(0) |
| ) |
| return angles.permute(1, 0, 2) |
| if xy.dim() == 3: |
| t_x = xy[..., 0].to(dtype=torch.float32) |
| t_y = xy[..., 1].to(dtype=torch.float32) |
| with torch.autocast(device_type=t_x.device.type, enabled=False): |
| angles = t_x.unsqueeze(-1).unsqueeze(-1) * self.freqs[0].unsqueeze( |
| 0 |
| ).unsqueeze(0) |
| angles.add_( |
| t_y.unsqueeze(-1).unsqueeze(-1) |
| * self.freqs[1].unsqueeze(0).unsqueeze(0) |
| ) |
| return angles.permute(0, 2, 1, 3) |
| raise ValueError("xy must have shape (T,2) or (B,T,2)") |
|
|
| def _cos_sin_half_from_xy( |
| self, |
| xy: torch.Tensor, |
| *, |
| device: torch.device | None = None, |
| out_dtype: torch.dtype | None = None, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| |
| cis = self._compute_mixed_cis(xy.to(device=device) if device else xy) |
| |
| if cis.is_complex(): |
| cos_h = cis.real |
| sin_h = cis.imag |
| else: |
| |
| raise RuntimeError("Expected complex cis tensor from polar") |
| if out_dtype is not None: |
| cos_h = cos_h.to(dtype=out_dtype) |
| sin_h = sin_h.to(dtype=out_dtype) |
| return cos_h, sin_h |
|
|
| def _cos_sin_from_xy( |
| self, |
| xy: torch.Tensor, |
| *, |
| device: torch.device | None = None, |
| out_dtype: torch.dtype | None = None, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| cos_h, sin_h = self._cos_sin_half_from_xy( |
| xy, device=device, out_dtype=out_dtype |
| ) |
| emb_cos = torch.cat((cos_h, cos_h), dim=-1) |
| emb_sin = torch.cat((sin_h, sin_h), dim=-1) |
| return emb_cos, emb_sin |
|
|
| def rotate_qk( |
| self, |
| q: torch.Tensor, |
| k: torch.Tensor, |
| xy: torch.Tensor, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| if q.dim() != 4 or k.dim() != 4: |
| raise ValueError("q,k must be shaped (B,H,T,D)") |
| _, H, _, D = q.shape |
| if self.num_heads != H: |
| raise ValueError("num_heads mismatch for LearnableRoPE2D") |
| if self.head_dim != D: |
| raise ValueError("head_dim mismatch for LearnableRoPE2D") |
| if D % 4 != 0: |
| raise AssertionError("head_dim must be divisible by 4 for mixed 2D RoPE") |
|
|
| |
| |
| angles = self._compute_mixed_angles(xy.to(device=q.device)) |
| sin = torch.sin(angles) |
| cos = torch.cos(angles) |
| q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos) |
| k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos) |
| return q_out, k_out |
|
|
| def rotate_qk_with_dilation( |
| self, |
| q: torch.Tensor, |
| k: torch.Tensor, |
| *, |
| xy: torch.Tensor, |
| scales: torch.Tensor, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Rotate Q/K using mixed 2D RoPE with per-sample isotropic dilation. |
| |
| This implements dilation by scaling the RoPE angle, i.e. |
| ``theta_dilated = scale * theta_base`` where ``theta_base`` comes from the |
| undilated XY coordinates. |
| |
| Args: |
| q: Query tensor shaped ``(B, H, T, D)``. |
| k: Key tensor shaped ``(B, H, T, D)``. |
| xy: Base XY coordinates shaped ``(T, 2)`` or ``(B, T, 2)``. |
| scales: Per-sample dilation scales shaped ``(B,)``. |
| |
| Raises: |
| ValueError: If shapes are inconsistent or scales are not 1D. |
| """ |
| if q.dim() != 4 or k.dim() != 4: |
| raise ValueError("q,k must be shaped (B,H,T,D)") |
| B, H, T, D = q.shape |
| if self.num_heads != H: |
| raise ValueError("num_heads mismatch for LearnableRoPE2D") |
| if self.head_dim != D: |
| raise ValueError("head_dim mismatch for LearnableRoPE2D") |
| if scales.dim() != 1 or scales.shape[0] != B: |
| raise ValueError("scales must have shape (B,) matching q batch size") |
| if xy.dim() == 2 and xy.shape[0] != T: |
| raise ValueError("xy length must match q sequence length") |
| if xy.dim() == 3 and (xy.shape[0] != B or xy.shape[1] != T): |
| raise ValueError("xy must have shape (B,T,2) matching q batch/sequence") |
| if xy.shape[-1] != 2: |
| raise ValueError("xy must have last dimension 2") |
|
|
| angles = self._compute_mixed_angles(xy.to(device=q.device)) |
| angles = angles * scales.to(device=q.device, dtype=torch.float32).view( |
| B, 1, 1, 1 |
| ) |
| sin = torch.sin(angles) |
| cos = torch.cos(angles) |
| q_out = self._apply_rotary_from_trig(q, sin=sin, cos=cos) |
| k_out = self._apply_rotary_from_trig(k, sin=sin, cos=cos) |
| return q_out, k_out |
|
|