| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass, replace |
| from enum import Enum |
| from typing import Final, cast |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
# Public surface of this module (what ``from <module> import *`` exposes).
# Entries are kept in plain ASCII sort order.
__all__ = [
    # Modules, configs and enums.
    "AxialRoPE2D",
    "AxialRoPE2DAlphaWarpConfig",
    "AxialRoPE2DBetaWarpConfig",
    "AxialRoPE2DConfig",
    "AxialRoPE2DCoordMode",
    "AxialRoPE2DDimLayout",
    "AxialRoPE2DDyPE",
    "AxialRoPE2DDyPEConfig",
    "AxialRoPE2DFrequencyAwareConfig",
    "AxialRoPE2DNormalizeCoords",
    "DyPERoPEMethod",
    # Builder / helper functions.
    "build_axial_rope2d_dype",
    "build_axial_rope2d_inference_warp_with_strength",
    "build_axial_rope2d_with_lumina_frequency_warp",
    "lumina_frequency_aware_periods_for_axis",
    "set_axial_rope2d_dype_noise_time",
]
|
|
|
|
class AxialRoPE2DNormalizeCoords(Enum):
    """Divisor choice for normalizing patch coordinates (DINOv3-style).

    The member selects which length the per-axis patch positions are divided
    by when normalized coordinates are built:

    - ``MIN``: both axes are divided by ``min(H, W)``.
    - ``MAX``: both axes are divided by ``max(H, W)``.
    - ``SEPARATE``: rows are divided by ``H`` and columns by ``W``.
    """

    MIN = "min"
    MAX = "max"
    SEPARATE = "separate"
|
|
|
|
class AxialRoPE2DCoordMode(Enum):
    """Which coordinate grid feeds the axial 2D RoPE angles.

    Members:
        DINOV3_NORMALIZED: DINOv3-style patch-centre coordinates that have
            been normalized into ``[-1, 1]``.
        PATCH_INDICES: Plain, unnormalized patch-grid coordinates measured in
            patch units (e.g. ``x in [0, W-1]`` and ``y in [0, H-1]``).
    """

    DINOV3_NORMALIZED = "dinov3_normalized"
    PATCH_INDICES = "patch_indices"
|
|
|
|
class AxialRoPE2DDimLayout(Enum):
    """How angles are arranged along the head dimension.

    The chosen layout has to agree with the rotation convention that is used
    when RoPE is applied to Q/K.

    Members:
        HALF_SPLIT: LLaMA-style layout, compatible with
            ``common.rope.rotate_half`` (the last dim is split into two
            halves).
        PAIR_INTERLEAVED: EVA-02 / SpeedrunDiT-style layout, compatible with
            an adjacent-pair rotate_half (consecutive dims form a pair).

    TODO(refactor): Standardize on ``PAIR_INTERLEAVED`` throughout DiT to
    reduce complexity and avoid layout mismatches, then delete ``HALF_SPLIT``
    and any related branching once the migration is complete.
    """

    HALF_SPLIT = "half_split"
    PAIR_INTERLEAVED = "pair_interleaved"
|
|
|
|
class DyPERoPEMethod(Enum):
    """Selects the dynamic position-extrapolation rule used for inference RoPE.

    Each member's value is the string identifier of the method.
    """

    VISION_YARN = "vision_yarn"
    DY_YARN = "dy_yarn"
    DY_NTK = "dy_ntk"
|
|
|
|
@dataclass(frozen=True)
class AxialRoPE2DDyPEConfig:
    """Immutable DyPE settings for axial RoPE (inference only).

    Attributes:
        method: Dynamic extrapolation rule to apply.
        ref_h_tokens: Training/reference token height.
        ref_w_tokens: Training/reference token width.
        lambda_s: Dynamic extrapolation magnitude.
        lambda_t: Dynamic extrapolation noise-time exponent.
        yarn_beta_0: YaRN first-ramp high rotation threshold.
        yarn_beta_1: YaRN first-ramp low rotation threshold.
        yarn_gamma_0: YaRN base-blend high rotation threshold.
        yarn_gamma_1: YaRN base-blend low rotation threshold.
        yarn_attention_scale: Apply YaRN's static attention magnitude
            correction.

    Raises:
        TypeError: If ``method`` is not a ``DyPERoPEMethod`` or
            ``yarn_attention_scale`` is not a ``bool``.
        ValueError: If a reference length is not positive, or any of the
            lambda / YaRN scalars is non-finite or not strictly positive.
    """

    method: DyPERoPEMethod
    ref_h_tokens: int
    ref_w_tokens: int
    lambda_s: float = 2.0
    lambda_t: float = 2.0
    yarn_beta_0: float = 1.25
    yarn_beta_1: float = 0.75
    yarn_gamma_0: float = 16.0
    yarn_gamma_1: float = 2.0
    yarn_attention_scale: bool = True

    def __post_init__(self) -> None:
        """Validate field types and numeric ranges right after construction."""
        if not isinstance(self.method, DyPERoPEMethod):
            raise TypeError("method must be a DyPERoPEMethod")

        # Lazy generator + any() keeps the original short-circuit order.
        if any(int(n) <= 0 for n in (self.ref_h_tokens, self.ref_w_tokens)):
            raise ValueError("ref_h_tokens and ref_w_tokens must be positive")

        strictly_positive_fields = (
            "lambda_s",
            "lambda_t",
            "yarn_beta_0",
            "yarn_beta_1",
            "yarn_gamma_0",
            "yarn_gamma_1",
        )
        for field_name in strictly_positive_fields:
            as_float = float(getattr(self, field_name))
            if not (math.isfinite(as_float) and as_float > 0.0):
                raise ValueError(f"{field_name} must be finite and > 0")

        if not isinstance(self.yarn_attention_scale, bool):
            raise TypeError("yarn_attention_scale must be a bool")
|
|
|
|
@dataclass(frozen=True)
class AxialRoPE2DFrequencyAwareConfig:
    """Settings for Lumina/Next-DiT-style frequency-aware period warping.

    The warp acts per axis and per RoPE band, and is driven by how the runtime
    axis length ``L`` compares with a reference length ``L_ref``.

    Axis scale
        ``s = L / L_ref``.

    Band wavelength
        RoPE is parameterized by *periods* ``period[d]`` (wavelengths in
        tokens). In this module's axial parameterization (with patch-index
        coordinates) the angle for coordinate ``p`` and band ``d`` is::

            angle(p, d) = 2π * p / period[d]

        so band ``d`` has a wavelength of exactly ``period[d]`` tokens.

    Boundary wavelength
        ``L_boundary`` (in tokens) is a trainable multiplier around the
        reference length::

            L_boundary = L_ref * exp(boundary_log_multiplier)

        The scalar ``boundary_log_multiplier`` is shared across the H/W axes
        and is initialized from this config.

    Boundary band index
        ``d*`` is the (possibly fractional) band index whose wavelength equals
        the boundary::

            period(d*) = L_boundary

        It is computed by linear interpolation in log-period space (periods
        are geometric for both supported period parametrizations).

    Exponent ramp
        The Lumina/Next-DiT implicit exponent ramp is::

            alpha[d] = clamp(d / d*, 0, 1)

        - high-frequency bands (small ``d``) get alpha≈0 (extrapolation-like);
        - low-frequency bands (large ``d``) get alpha→1 (interpolation-like);
        - alpha is capped at 1 so no band is ever compressed more than plain
          position interpolation would compress it.

    Period warp
        Per axis::

            period'[d] = period[d] * s ** alpha[d]

        or, in terms of angular frequency::

            omega'[d] = omega[d] / s ** alpha[d]

    Notes
    -----
    - The warp is only meaningful for patch-index coordinates
      (``AxialRoPE2DCoordMode.PATCH_INDICES``). Mixing it with normalized
      coordinates would create an implicit "gauge switch"; we fail fast.
    - The boundary multiplier is trainable by construction (it is stored as an
      nn.Parameter inside AxialRoPE2D when this config is present).

    Attributes:
        ref_h_tokens: Reference token length along H (must be positive).
        ref_w_tokens: Reference token length along W (must be positive).
        boundary_log_multiplier_init: Initial value of the shared
            ``boundary_log_multiplier`` (must be finite).

    Raises:
        ValueError: If a reference length is not positive or the initial log
            multiplier is not finite.
    """

    ref_h_tokens: int
    ref_w_tokens: int
    boundary_log_multiplier_init: float

    def __post_init__(self) -> None:
        """Reject non-positive reference lengths and a non-finite init value."""
        if any(int(n) <= 0 for n in (self.ref_h_tokens, self.ref_w_tokens)):
            raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
        if not math.isfinite(float(self.boundary_log_multiplier_init)):
            raise ValueError("boundary_log_multiplier_init must be finite")
|
|
|
|
@dataclass(frozen=True)
class AxialRoPE2DBetaWarpConfig:
    """Settings for the trainable beta-curve period warp (per token grid).

    For each axis the periods are rescaled according to the runtime axis
    length ``L`` relative to a reference length ``L_ref``::

        s = L / L_ref
        period'[d] = period[d] * s ** beta[d]

    The per-band exponent curve ``beta(d)`` is controlled by three trainable
    u-space scalars that are shared across the H/W axes::

        beta_hi   = beta_max * tanh(beta_hi_u)    (high-frequency endpoint, d=0)
        beta_lo   = beta_max * tanh(beta_lo_u)    (low-frequency endpoint, d=qtr-1)
        beta_bend = beta_max * tanh(beta_bend_u)  (mid-band bump amplitude)

    and the curve itself is::

        t = d / (qtr - 1)    in [0, 1]
        beta(t) = lerp(beta_hi, beta_lo, t) + beta_bend * 4*t*(1-t)

    Interpretation
    --------------
    - ``beta(d) == 0``: identity / "extrapolation-like" (no warping; periods
      do not change with axis length).
    - ``beta(d) == 1``: position-interpolation-like for that band
      (``period'[d] = period[d] * s`` so ``omega'[d] = omega[d] / s``).

    The parameterization gives strong, smooth control over the effective
    scaling of every frequency band. It also permits beta<0 (frequencies
    increase when s>1), which can be important for unnormalized RoPE bases
    (e.g. base=10_000) where some very low-frequency bands barely rotate on
    practical token grids.

    Notes:
        - This warp requires patch-index coordinates
          (coord_mode=PATCH_INDICES).
        - The u parameters are stored as nn.Parameter inside AxialRoPE2D when
          this config is present.

    Attributes:
        ref_h_tokens: Reference token length along H (must be positive).
        ref_w_tokens: Reference token length along W (must be positive).
        beta_max: Bound on ``|beta|`` (finite and > 0).
        beta_hi_u_init: Initial u-space value for ``beta_hi`` (finite).
        beta_lo_u_init: Initial u-space value for ``beta_lo`` (finite).
        beta_bend_u_init: Initial u-space value for ``beta_bend`` (finite).

    Raises:
        ValueError: If any field violates the constraints listed above.
    """

    ref_h_tokens: int
    ref_w_tokens: int
    beta_max: float
    beta_hi_u_init: float
    beta_lo_u_init: float
    beta_bend_u_init: float

    def __post_init__(self) -> None:
        """Check reference lengths, ``beta_max`` and the three u-space inits."""
        if any(int(n) <= 0 for n in (self.ref_h_tokens, self.ref_w_tokens)):
            raise ValueError("ref_h_tokens and ref_w_tokens must be positive")

        ceiling = float(self.beta_max)
        if not (math.isfinite(ceiling) and ceiling > 0.0):
            raise ValueError("beta_max must be finite and > 0")

        for field_name in ("beta_hi_u_init", "beta_lo_u_init", "beta_bend_u_init"):
            if not math.isfinite(float(getattr(self, field_name))):
                raise ValueError(f"{field_name} must be finite")
|
|
|
|
@dataclass(frozen=True)
class AxialRoPE2DAlphaWarpConfig:
    """Settings for the per-band power-law frequency warp (shared across axes).

    RoPE frequencies are warped band by band with a learned exponent vector
    ``alpha[d]`` that both the H and W axes share::

        f'[d] = f[d] * s ** alpha[d]    where s = L / L_ref

    This module expresses angles through periods ``period[d]`` with
    ``f[d] ∝ 1 / period[d]``, so the equivalent period warp implemented in
    AxialRoPE2D is::

        period'[d] = period[d] / s ** alpha[d]

    Notes:
        - This warp requires patch-index coordinates
          (coord_mode=PATCH_INDICES).
        - ``alpha`` is stored as an unconstrained nn.Parameter vector of
          length Q (bands per axis), initialized to ``alpha_init`` for all
          bands.

    Attributes:
        ref_h_tokens: Reference token length along H (must be positive).
        ref_w_tokens: Reference token length along W (must be positive).
        alpha_init: Initial exponent used for every band (must be finite).

    Raises:
        ValueError: If a reference length is not positive or ``alpha_init``
            is not finite.
    """

    ref_h_tokens: int
    ref_w_tokens: int
    alpha_init: float

    def __post_init__(self) -> None:
        """Reject non-positive reference lengths and a non-finite exponent."""
        if any(int(n) <= 0 for n in (self.ref_h_tokens, self.ref_w_tokens)):
            raise ValueError("ref_h_tokens and ref_w_tokens must be positive")
        if not math.isfinite(float(self.alpha_init)):
            raise ValueError("alpha_init must be finite")
|
|
|
|
@dataclass(frozen=True)
class AxialRoPE2DConfig:
    """Configuration for axial 2D RoPE sin/cos generation.

    Coordinate conventions (``coord_mode``)
    ---------------------------------------
    - ``DINOV3_NORMALIZED``: DINOv3-style normalized patch-centre coordinates
      in ``[-1, 1]`` (after normalization).
    - ``PATCH_INDICES``: Standard unnormalized patch-grid coordinates in patch
      units (e.g., ``x in [0, W-1]``).

    Period parametrization
    ----------------------
    The periods parametrization matches DINOv3:
    - Provide either `base` (and leave `min_period/max_period` unset), or
    - Provide both `min_period` and `max_period` (and set `base=None`).

    Warps
    -----
    At most one of ``frequency_aware``, ``beta_warp`` and ``alpha_warp`` may
    be set, and any of them requires ``coord_mode=PATCH_INDICES``.

    Attributes:
        base: Period base; positive and finite when provided.
        min_period: Smallest period; positive and finite when provided.
        max_period: Largest period; positive, finite and ``> min_period``
            when both are provided.
        coord_mode: Coordinate grid convention.
        normalize_coords: Normalization divisor choice for normalized coords.
        dim_layout: Angle layout along the head dimension.
        angle_multiplier: Factor applied when turning ``coord / period`` into
            an angle; finite and > 0 (defaults to 2π).
        coord_offset: Offset added to integer patch positions; finite.
        frequency_aware: Optional Lumina-style frequency-aware warp config.
        beta_warp: Optional beta-curve warp config.
        alpha_warp: Optional per-band alpha warp config.

    Raises:
        ValueError: If the period parametrization is inconsistent, a numeric
            field is out of range or non-finite, more than one warp is set,
            or a warp is combined with a non-patch-index coordinate mode.
        TypeError: If an enum or warp field has the wrong type.
    """

    base: float | None = 100.0
    min_period: float | None = None
    max_period: float | None = None
    coord_mode: AxialRoPE2DCoordMode = AxialRoPE2DCoordMode.DINOV3_NORMALIZED
    normalize_coords: AxialRoPE2DNormalizeCoords = AxialRoPE2DNormalizeCoords.MAX
    dim_layout: AxialRoPE2DDimLayout = AxialRoPE2DDimLayout.HALF_SPLIT
    angle_multiplier: float = 2.0 * float(math.pi)
    coord_offset: float = 0.5
    frequency_aware: AxialRoPE2DFrequencyAwareConfig | None = None
    beta_warp: AxialRoPE2DBetaWarpConfig | None = None
    alpha_warp: AxialRoPE2DAlphaWarpConfig | None = None

    def __post_init__(self) -> None:
        """Validate the whole configuration; see the class ``Raises`` section."""
        # Exactly one period parametrization: `base` XOR (min_period, max_period).
        both_periods = self.min_period is not None and self.max_period is not None
        if (self.base is None) != both_periods:
            raise ValueError(
                "AxialRoPE2DConfig requires either base!=None, or both min_period and max_period."
            )

        for label, raw in (
            ("base", self.base),
            ("min_period", self.min_period),
            ("max_period", self.max_period),
        ):
            if raw is None:
                continue
            value = float(raw)
            if value <= 0.0:
                raise ValueError(
                    f"AxialRoPE2DConfig.{label} must be positive when provided"
                )
            # NaN and +inf slip through the `<= 0.0` comparison above, so they
            # need an explicit check (same standard as angle_multiplier).
            if not math.isfinite(value):
                raise ValueError(
                    f"AxialRoPE2DConfig.{label} must be finite when provided"
                )

        if both_periods and float(self.max_period) <= float(self.min_period):
            raise ValueError("AxialRoPE2DConfig.max_period must be > min_period")

        for label, expected in (
            ("coord_mode", AxialRoPE2DCoordMode),
            ("normalize_coords", AxialRoPE2DNormalizeCoords),
            ("dim_layout", AxialRoPE2DDimLayout),
        ):
            if not isinstance(getattr(self, label), expected):
                raise TypeError(
                    f"AxialRoPE2DConfig.{label} must be an {expected.__name__}"
                )

        mult = float(self.angle_multiplier)
        if not math.isfinite(mult) or mult <= 0.0:
            raise ValueError(
                "AxialRoPE2DConfig.angle_multiplier must be finite and > 0"
            )
        if not math.isfinite(float(self.coord_offset)):
            raise ValueError("AxialRoPE2DConfig.coord_offset must be finite")

        # (field name, required type, wording used in the coord-mode error).
        warp_specs = (
            ("frequency_aware", AxialRoPE2DFrequencyAwareConfig, "frequency-aware warping"),
            ("beta_warp", AxialRoPE2DBetaWarpConfig, "beta warp"),
            ("alpha_warp", AxialRoPE2DAlphaWarpConfig, "alpha warp"),
        )
        for label, expected, _ in warp_specs:
            warp = getattr(self, label)
            if warp is not None and not isinstance(warp, expected):
                raise TypeError(
                    f"AxialRoPE2DConfig.{label} must be an {expected.__name__}"
                )

        active = [spec for spec in warp_specs if getattr(self, spec[0]) is not None]
        if len(active) > 1:
            raise ValueError(
                "AxialRoPE2DConfig requires at most one of frequency_aware, beta_warp, or alpha_warp"
            )
        if active and self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES:
            raise ValueError(
                f"AxialRoPE2D {active[0][2]} requires coord_mode=PATCH_INDICES"
            )
|
|
|
|
# Module-level memo of flattened coordinate grids. Keys are
# (H, W, device, coord mode, normalization, offset); values are the cached
# coordinate tensors returned by the coordinate builders in this module.
_AXIAL_COORDS_CACHE: dict[
    tuple[
        int,
        int,
        torch.device,
        AxialRoPE2DCoordMode,
        AxialRoPE2DNormalizeCoords,
        float,
    ],
    Tensor,
] = dict()
|
|
|
|
| def _get_dinov3_normalized_coords( |
| H: int, |
| W: int, |
| *, |
| device: torch.device, |
| normalize: AxialRoPE2DNormalizeCoords, |
| offset: float, |
| ) -> Tensor: |
| """Return DINOv3-style flattened coords in [-1, 1] with shape [HW, 2].""" |
| if H <= 0 or W <= 0: |
| raise ValueError("H and W must be positive for axial RoPE coords") |
| key = ( |
| int(H), |
| int(W), |
| device, |
| AxialRoPE2DCoordMode.DINOV3_NORMALIZED, |
| normalize, |
| float(offset), |
| ) |
| cached = _AXIAL_COORDS_CACHE.get(key) |
| if cached is not None: |
| return cached |
| start = float(offset) |
| end_h = start + float(int(H)) |
| end_w = start + float(int(W)) |
| match normalize: |
| case AxialRoPE2DNormalizeCoords.MAX: |
| denom = float(max(int(H), int(W))) |
| coords_h = ( |
| torch.arange(start, end_h, device=device, dtype=torch.float32) / denom |
| ) |
| coords_w = ( |
| torch.arange(start, end_w, device=device, dtype=torch.float32) / denom |
| ) |
| case AxialRoPE2DNormalizeCoords.MIN: |
| denom = float(min(int(H), int(W))) |
| coords_h = ( |
| torch.arange(start, end_h, device=device, dtype=torch.float32) / denom |
| ) |
| coords_w = ( |
| torch.arange(start, end_w, device=device, dtype=torch.float32) / denom |
| ) |
| case AxialRoPE2DNormalizeCoords.SEPARATE: |
| coords_h = torch.arange( |
| start, end_h, device=device, dtype=torch.float32 |
| ) / float(int(H)) |
| coords_w = torch.arange( |
| start, end_w, device=device, dtype=torch.float32 |
| ) / float(int(W)) |
| case _ as unreachable: |
| raise RuntimeError(f"Unsupported normalize_coords: {unreachable}") |
| coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) |
| coords = coords.flatten(0, 1) |
| coords = 2.0 * coords - 1.0 |
| |
| |
| if torch.compiler.is_compiling(): |
| return coords |
| if torch.is_inference_mode_enabled(): |
| return coords |
| _AXIAL_COORDS_CACHE[key] = coords |
| return coords |
|
|
|
|
def _get_patch_index_coords(
    H: int,
    W: int,
    *,
    device: torch.device,
    offset: float,
) -> Tensor:
    """Return unnormalized patch-grid coords with shape ``[H*W, 2]``.

    Row ``i * W + j`` holds ``(i + offset, j + offset)``, i.e. the columns are
    ``(y, x)`` in patch units.

    Results are memoized in ``_AXIAL_COORDS_CACHE``; the cache is not written
    while ``torch.compile`` is tracing or while inference mode is enabled. The
    returned tensor may be the shared cache entry, so callers must not modify
    it in place.

    Args:
        H: Number of patch rows (> 0).
        W: Number of patch columns (> 0).
        device: Device on which the coordinates are created.
        offset: Value added to every integer patch index.

    Returns:
        Float32 tensor of shape ``[H*W, 2]`` with ``(y, x)`` columns.

    Raises:
        ValueError: If ``H`` or ``W`` is not positive.
    """
    if H <= 0 or W <= 0:
        raise ValueError("H and W must be positive for axial RoPE coords")

    n_rows, n_cols = int(H), int(W)
    shift = float(offset)
    # The normalization slot of the shared key is unused for patch indices;
    # MAX is stored there as a fixed placeholder.
    key = (
        n_rows,
        n_cols,
        device,
        AxialRoPE2DCoordMode.PATCH_INDICES,
        AxialRoPE2DNormalizeCoords.MAX,
        shift,
    )
    cached = _AXIAL_COORDS_CACHE.get(key)
    if cached is not None:
        return cached

    # `arange(n) + offset` guarantees exactly n entries per axis. The
    # float-endpoint form `arange(offset, offset + n)` derives its length from
    # a rounded `ceil((end - start) / step)` and can be off by one.
    coords_h = torch.arange(n_rows, device=device, dtype=torch.float32) + shift
    coords_w = torch.arange(n_cols, device=device, dtype=torch.float32) + shift

    grid = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
    coords = grid.flatten(0, 1)

    if torch.compiler.is_compiling() or torch.is_inference_mode_enabled():
        return coords
    _AXIAL_COORDS_CACHE[key] = coords
    return coords
|
|
|
|
def _lumina_boundary_band_index(
    *,
    periods: Tensor,
    boundary_wavelength: Tensor,
) -> Tensor:
    """Locate the fractional band index ``d*`` whose period equals the boundary.

    Follows the Lumina/Next-DiT definition::

        period(d*) = boundary_wavelength

    ``d*`` is found by linear interpolation in log-period space between the
    first and last band. For the supported period parameterizations the
    periods are geometric, so ``log(period)`` is linear in the band index.

    Args:
        periods: 1D float tensor of length Q (>= 2) holding strictly
            increasing, positive, finite periods in tokens.
        boundary_wavelength: Scalar, positive, finite float tensor with the
            desired boundary wavelength in tokens.

    Returns:
        Scalar float32 tensor with the (possibly fractional) index ``d*``.

    Raises:
        ValueError: If ``periods`` or ``boundary_wavelength`` are malformed,
            or the boundary does not yield a finite, strictly positive ``d*``.
    """
    if periods.dim() != 1:
        raise ValueError("periods must be 1D for boundary band index")
    band_count = int(periods.numel())
    if band_count < 2:
        raise ValueError("periods must have length >= 2 for boundary band index")
    if boundary_wavelength.dim() != 0:
        raise ValueError("boundary_wavelength must be a scalar tensor")
    if not bool(torch.isfinite(boundary_wavelength).item()):
        raise ValueError("boundary_wavelength must be finite")
    if float(boundary_wavelength.item()) <= 0.0:
        raise ValueError("boundary_wavelength must be > 0")

    p32 = periods.to(dtype=torch.float32)
    if not bool(torch.isfinite(p32).all().item()):
        raise ValueError("periods must be finite for boundary band index")
    shortest, longest = p32[0], p32[-1]
    if float(shortest.item()) <= 0.0:
        raise ValueError("periods must be positive for boundary band index")
    is_increasing = torch.gt(p32[1:], p32[:-1]).all()
    if not bool(is_increasing.item()):
        raise ValueError("periods must be strictly increasing for boundary band index")

    # Interpolate in log space between the two end bands.
    log_shortest = torch.log(shortest)
    log_span = torch.log(longest) - log_shortest
    if float(log_span.item()) <= 0.0:
        raise ValueError("Invalid periods range for boundary band index")
    log_target = torch.log(boundary_wavelength.to(dtype=torch.float32))
    d_star = (float(band_count - 1) * (log_target - log_shortest)) / log_span

    if not bool(torch.isfinite(d_star).item()):
        raise ValueError("Computed non-finite boundary band index d*")
    if float(d_star.item()) <= 0.0:
        raise ValueError(
            "Boundary wavelength implies d* <= 0; increase the boundary wavelength "
            "(or its multiplier) to be >= the wavelength of the first non-zero band."
        )
    return d_star
|
|
|
|
def _lumina_alpha_ramp(
    *,
    qtr: int,
    d_star: Tensor,
    device: torch.device,
) -> Tensor:
    """Build the exponent ramp ``alpha[d] = clamp(d / d*, 0, 1)`` for ``d < qtr``.

    Args:
        qtr: Number of RoPE bands per axis (Q); must be positive.
        d_star: Scalar float tensor holding the boundary index ``d*`` (> 0).
        device: Device on which the ramp is created.

    Returns:
        Float32 tensor of shape ``[Q]`` whose entries lie in ``[0, 1]``.

    Raises:
        ValueError: If ``qtr`` is not positive, or ``d_star`` is not a
            strictly positive scalar tensor.
    """
    band_count = int(qtr)
    if band_count <= 0:
        raise ValueError("qtr must be positive for alpha ramp")
    if d_star.dim() != 0:
        raise ValueError("d_star must be a scalar tensor for alpha ramp")
    if float(d_star.item()) <= 0.0:
        raise ValueError("d_star must be > 0 for alpha ramp")

    band_index = torch.arange(band_count, device=device, dtype=torch.float32)
    boundary = d_star.to(device=device, dtype=torch.float32)
    return (band_index / boundary).clamp(min=0.0, max=1.0)
|
|
|
|
def lumina_frequency_aware_periods_for_axis(
    *,
    periods: Tensor,
    axis_len: int,
    ref_axis_len: int,
    boundary_log_multiplier: Tensor,
    angle_multiplier: float,
) -> Tensor:
    """Warp one axis' periods with the Lumina/Next-DiT frequency-aware rule.

    Steps::

        s          = axis_len / ref_axis_len
        L_boundary = ref_axis_len * exp(boundary_log_multiplier)
        d*         = band index where period(d*) matches L_boundary
        alpha[d]   = clamp(d / d*, 0, 1)
        period'[d] = period[d] * s ** alpha[d]

    Role of ``angle_multiplier``
    ----------------------------
    Angles in this module are::

        angle(p, d) = angle_multiplier * p / period[d]

    The *wavelength* of a band is the change in ``p`` that advances the angle
    by ``2π``::

        wavelength[d] = 2π * period[d] / angle_multiplier

    Lumina/Next-DiT place the boundary by matching *wavelength* against the
    reference axis length, so ``L_boundary`` is first converted to a period::

        period_boundary = (angle_multiplier / 2π) * L_boundary

    For ``angle_multiplier == 2π`` (the DINOv3-style parameterization) this is
    simply ``period_boundary == L_boundary``.

    Args:
        periods: Base periods ``[Q]`` in tokens (wavelengths).
        axis_len: Input axis length ``L`` in tokens.
        ref_axis_len: Reference axis length ``L_ref`` in tokens.
        boundary_log_multiplier: Scalar tensor; shared trainable
            log-multiplier.
        angle_multiplier: RoPE angle multiplier used when converting periods
            to physical wavelengths in tokens.

    Returns:
        Warped periods ``[Q]`` as float32, on the device of ``periods``.

    Raises:
        ValueError: If inputs are malformed or imply an invalid boundary
            index.
    """
    if int(axis_len) <= 0:
        raise ValueError("axis_len must be positive for frequency-aware periods")
    if int(ref_axis_len) <= 0:
        raise ValueError("ref_axis_len must be positive for frequency-aware periods")
    if boundary_log_multiplier.dim() != 0:
        raise ValueError("boundary_log_multiplier must be a scalar tensor")
    if not bool(torch.isfinite(boundary_log_multiplier).item()):
        raise ValueError("boundary_log_multiplier must be finite")
    angle_mult = float(angle_multiplier)
    if not (math.isfinite(angle_mult) and angle_mult > 0.0):
        raise ValueError("angle_multiplier must be finite and > 0")

    target_device = periods.device
    ref_len = float(int(ref_axis_len))
    axis_scale = float(int(axis_len)) / ref_len
    if not (math.isfinite(axis_scale) and axis_scale > 0.0):
        raise ValueError("axis_len/ref_axis_len must be finite and > 0")

    # Boundary wavelength in tokens, then converted to this module's period units.
    log_mult = boundary_log_multiplier.to(device=target_device, dtype=torch.float32)
    wavelength_at_boundary = ref_len * torch.exp(log_mult)
    period_at_boundary = (angle_mult / (2.0 * float(math.pi))) * wavelength_at_boundary

    d_star = _lumina_boundary_band_index(
        periods=periods, boundary_wavelength=period_at_boundary
    )
    alpha = _lumina_alpha_ramp(
        qtr=int(periods.numel()), d_star=d_star, device=target_device
    )

    scale_base = torch.tensor(axis_scale, device=target_device, dtype=torch.float32)
    per_band_scale = torch.pow(scale_base, alpha)
    return periods.to(device=target_device, dtype=torch.float32) * per_band_scale
|
|
|
|
def build_axial_rope2d_with_lumina_frequency_warp(
    base: AxialRoPE2D,
    *,
    ref_h_tokens: int,
    ref_w_tokens: int,
    boundary_log_multiplier: float | None,
    boundary_band_multiplier: float | None,
) -> AxialRoPE2D:
    """Return an AxialRoPE2D module that applies Lumina-style frequency warping.

    Intended for inference-time experimentation on checkpoints that were
    trained without frequency-aware warping (e.g.
    ``position_encoding=ROPE_2D_AXIAL_UNNORMALIZED``). The returned module:

    - keeps the base RoPE periods and layout identical to ``base``;
    - applies Lumina/Next-DiT per-axis warping based on the runtime token
      lengths ``H`` and ``W`` relative to the reference lengths;
    - uses a fixed (non-trainable) scalar boundary multiplier.

    Args:
        base: Existing AxialRoPE2D instance from a loaded model.
        ref_h_tokens: Reference H token length (L_ref,h); must be positive.
        ref_w_tokens: Reference W token length (L_ref,w); must be positive.
        boundary_log_multiplier: Optional log multiplier applied to reference
            lengths to define the boundary wavelength. Use 0.0 for "boundary
            at L_ref". Mutually exclusive with boundary_band_multiplier.
        boundary_band_multiplier: Optional multiplier that directly selects
            the boundary band index d* relative to the lowest-frequency band
            index (qtr-1). Concretely, with qtr bands per axis::

                d* = boundary_band_multiplier * (qtr - 1)

            This moves the transition point in frequency space:
            - smaller values => more bands become PI-like (more interpolation)
            - larger values => fewer bands become PI-like (more extrapolation)

            When provided, the implied boundary wavelength is computed and
            stored as boundary_log_multiplier for the module. Requires
            ``ref_h_tokens == ref_w_tokens`` because the boundary scalar is
            shared by both axes.

    Returns:
        New AxialRoPE2D instance on the same device as ``base.periods``.

    Raises:
        TypeError: If base is not an AxialRoPE2D.
        ValueError: If base uses incompatible coordinates for Lumina warping,
            the two multipliers are not given exclusively, a reference length
            is not positive, or the band-multiplier inputs are invalid.
    """
    if not isinstance(base, AxialRoPE2D):
        raise TypeError("base must be an AxialRoPE2D")
    if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES:
        raise ValueError(
            "Lumina frequency-aware warping requires coord_mode=PATCH_INDICES"
        )
    if (boundary_log_multiplier is None) == (boundary_band_multiplier is None):
        raise ValueError(
            "Provide exactly one of boundary_log_multiplier or boundary_band_multiplier"
        )

    # Validate the reference lengths up front. Otherwise the band-multiplier
    # path divides by ref_h_tokens below and fails with ZeroDivisionError
    # before the config's own validation ever runs.
    ref_h = int(ref_h_tokens)
    ref_w = int(ref_w_tokens)
    if ref_h <= 0 or ref_w <= 0:
        raise ValueError("ref_h_tokens and ref_w_tokens must be positive")

    resolved_log_multiplier: float
    if boundary_band_multiplier is not None:
        if ref_h != ref_w:
            raise ValueError(
                "boundary_band_multiplier requires ref_h_tokens == ref_w_tokens when using a shared scalar boundary"
            )
        mult = float(boundary_band_multiplier)
        if not math.isfinite(mult) or mult <= 0.0:
            raise ValueError("boundary_band_multiplier must be finite and > 0")
        qtr = int(base.periods.numel())
        if qtr < 2:
            raise ValueError(
                "AxialRoPE2D periods length must be >= 2 for boundary band selection"
            )

        # Read the end-band log-periods as plain floats (no autograd involved).
        with torch.no_grad():
            periods_f = base.periods.to(
                dtype=torch.float32, device=torch.device("cpu")
            )
            if not (periods_f[1:] > periods_f[:-1]).all().item():
                raise ValueError(
                    "base.periods must be strictly increasing for boundary band selection"
                )
            log_p0 = float(torch.log(periods_f[0]).item())
            log_p1 = float(torch.log(periods_f[-1]).item())

        # Invert the log-linear band-index interpolation: d* -> period(d*).
        d_star = mult * float(qtr - 1)
        log_boundary_period = log_p0 + (d_star / float(qtr - 1)) * (log_p1 - log_p0)
        boundary_period = math.exp(log_boundary_period)

        angle_mult = float(base.cfg.angle_multiplier)
        if not math.isfinite(angle_mult) or angle_mult <= 0.0:
            raise ValueError("base.cfg.angle_multiplier must be finite and > 0")
        # Period -> physical wavelength in tokens, then express the wavelength
        # as a log multiplier of the (shared) reference length.
        boundary_wavelength = (2.0 * float(math.pi) / angle_mult) * boundary_period
        resolved_log_multiplier = math.log(boundary_wavelength / float(ref_h))
    else:
        # Unreachable after the exclusivity check; kept for type narrowing.
        if boundary_log_multiplier is None:
            raise RuntimeError("boundary_log_multiplier missing despite validation")
        resolved_log_multiplier = float(boundary_log_multiplier)

    freq_cfg = AxialRoPE2DFrequencyAwareConfig(
        ref_h_tokens=ref_h,
        ref_w_tokens=ref_w,
        boundary_log_multiplier_init=resolved_log_multiplier,
    )
    # Only one warp may be active, so clear the other two on the copied config.
    cfg = replace(base.cfg, frequency_aware=freq_cfg, beta_warp=None, alpha_warp=None)
    device = base.periods.device
    warped = AxialRoPE2D(head_dim=int(base.head_dim), cfg=cfg).to(device=device)

    with torch.no_grad():
        warped.periods.copy_(base.periods.to(device=device, dtype=torch.float32))
        if warped.boundary_log_multiplier is None:
            raise RuntimeError("Expected boundary_log_multiplier to be initialized")
        warped.boundary_log_multiplier.copy_(
            torch.tensor(resolved_log_multiplier, device=device, dtype=torch.float32)
        )
        # Freeze the boundary: this module is meant for inference only.
        warped.boundary_log_multiplier.requires_grad_(False)
    return warped
|
|
|
|
| def build_axial_rope2d_inference_warp_with_strength( |
| base: AxialRoPE2D, |
| *, |
| ref_h_tokens: int, |
| ref_w_tokens: int, |
| beta_hi_u: float, |
| beta_lo_u: float, |
| beta_bend_u: float, |
| beta_max: float, |
| ) -> AxialRoPE2D: |
| """Build an inference-only RoPE warp parameterized by a 3-knob beta(t) curve. |
| |
| This helper is meant for notebook experimentation on checkpoints trained |
| with patch-index axial RoPE (e.g. ``position_encoding=ROPE_2D_AXIAL_UNNORMALIZED``). |
| |
| We warp per-axis RoPE periods (wavelengths, in tokens) as: |
| |
| period'[d] = period[d] * s ** beta[d] where s = L / L_ref |
| |
| with a smooth exponent curve beta(d) over bands. Unlike a strict |
| interpolation-only exponent (0..1), beta is allowed to be negative or > 1, |
| which is important for unnormalized RoPE (e.g. base=10_000) where some very |
| low-frequency bands are effectively "dead" on practical token grids unless |
| their frequencies can be increased (beta < 0). |
| |
| Knobs (bounded via u-space) |
| --------------------------- |
| We use three unconstrained parameters (u-space) which map to bounded beta |
| values via tanh: |
| |
| beta_hi = beta_max * tanh(beta_hi_u) (high-frequency endpoint, d=0) |
| beta_lo = beta_max * tanh(beta_lo_u) (low-frequency endpoint, d=qtr-1) |
| beta_bend = beta_max * tanh(beta_bend_u) ("bump" amplitude in the middle) |
| |
| Then define the per-band curve over t in [0,1] (high -> low frequency): |
| |
| t = d / (qtr - 1) |
| beta(t) = lerp(beta_hi, beta_lo, t) + beta_bend * 4*t*(1-t) |
| |
| The bump term is 0 at the endpoints and peaks at 1 at t=0.5. |
| |
| Notes |
| ----- |
| - This wrapper is inference-only: it is not saved in checkpoints. |
| - It requires patch-index coordinates (no normalized "gauge"). |
| - It preserves the base module's periods and layout exactly. |
| |
| Args: |
| base: Existing AxialRoPE2D instance from a loaded model. |
| ref_h_tokens: Reference H token length (L_ref,h). |
| ref_w_tokens: Reference W token length (L_ref,w). |
| beta_hi_u: Unconstrained u for beta_hi. |
| beta_lo_u: Unconstrained u for beta_lo. |
| beta_bend_u: Unconstrained u for beta_bend (mid-band bump). |
| beta_max: Maximum absolute beta value (> 0). Higher increases control. |
| |
| Returns: |
| An AxialRoPE2D instance whose forward applies the inference-only warp. |
| """ |
| if not isinstance(base, AxialRoPE2D): |
| raise TypeError("base must be an AxialRoPE2D") |
| if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES: |
| raise ValueError( |
| "Inference freq-warp requires base.cfg.coord_mode=PATCH_INDICES" |
| ) |
| if int(ref_h_tokens) <= 0 or int(ref_w_tokens) <= 0: |
| raise ValueError("ref_h_tokens and ref_w_tokens must be positive") |
| hi_u = float(beta_hi_u) |
| lo_u = float(beta_lo_u) |
| bend_u = float(beta_bend_u) |
| if not math.isfinite(hi_u): |
| raise ValueError("beta_hi_u must be finite") |
| if not math.isfinite(lo_u): |
| raise ValueError("beta_lo_u must be finite") |
| if not math.isfinite(bend_u): |
| raise ValueError("beta_bend_u must be finite") |
| bmax = float(beta_max) |
| if not math.isfinite(bmax) or bmax <= 0.0: |
| raise ValueError("beta_max must be finite and > 0") |
|
|
| class _AxialRoPE2DInferenceWarp(AxialRoPE2D): |
| """Inference-only axial RoPE variant with beta-curve knobs.""" |
|
|
| def __init__(self, *, device: torch.device) -> None: |
| super().__init__(head_dim=int(base.head_dim), cfg=base.cfg) |
| self.ref_h_tokens: Final[int] = int(ref_h_tokens) |
| self.ref_w_tokens: Final[int] = int(ref_w_tokens) |
| |
| self.register_buffer( |
| "beta_hi_u", |
| torch.tensor(float(hi_u), dtype=torch.float32), |
| persistent=False, |
| ) |
| self.register_buffer( |
| "beta_lo_u", |
| torch.tensor(float(lo_u), dtype=torch.float32), |
| persistent=False, |
| ) |
| self.register_buffer( |
| "beta_bend_u", |
| torch.tensor(float(bend_u), dtype=torch.float32), |
| persistent=False, |
| ) |
| self.register_buffer( |
| "beta_max", |
| torch.tensor(float(bmax), dtype=torch.float32), |
| persistent=False, |
| ) |
| self.to(device=device) |
| with torch.no_grad(): |
| self.periods.copy_( |
| base.periods.detach().to(device=device, dtype=torch.float32) |
| ) |
|
|
| def forward( |
| self, |
| *, |
| H: int, |
| W: int, |
| scales: Tensor | None, |
| ) -> tuple[Tensor, Tensor]: |
| if scales is not None: |
| raise ValueError("Inference freq-warp does not support dilation scales") |
| if int(H) <= 0 or int(W) <= 0: |
| raise ValueError("H and W must be positive for axial RoPE") |
| device = self.periods.device |
| offset = float(self.cfg.coord_offset) |
| coords = _get_patch_index_coords( |
| int(H), int(W), device=device, offset=offset |
| ) |
| if coords.dim() != 2 or coords.shape[1] != 2: |
| raise RuntimeError("Axial RoPE coords must have shape [HW, 2]") |
|
|
| qtr = int(self.periods.numel()) |
| if qtr <= 0: |
| raise RuntimeError("Axial RoPE periods length must be positive") |
|
|
| beta_max_t = cast("Tensor", self.beta_max).to( |
| device=device, dtype=torch.float32 |
| ) |
| beta_hi = beta_max_t * torch.tanh( |
| cast("Tensor", self.beta_hi_u).to(device=device, dtype=torch.float32) |
| ) |
| beta_lo = beta_max_t * torch.tanh( |
| cast("Tensor", self.beta_lo_u).to(device=device, dtype=torch.float32) |
| ) |
| beta_bend = beta_max_t * torch.tanh( |
| cast("Tensor", self.beta_bend_u).to(device=device, dtype=torch.float32) |
| ) |
|
|
| if qtr == 1: |
| beta = beta_hi[None] |
| else: |
| t = torch.arange(int(qtr), device=device, dtype=torch.float32) / float( |
| qtr - 1 |
| ) |
| bump = 4.0 * t * (1.0 - t) |
| beta = (1.0 - t) * beta_hi + t * beta_lo + beta_bend * bump |
|
|
| s_h = float(int(H)) / float(int(self.ref_h_tokens)) |
| s_w = float(int(W)) / float(int(self.ref_w_tokens)) |
| if ( |
| not math.isfinite(s_h) |
| or s_h <= 0.0 |
| or not math.isfinite(s_w) |
| or s_w <= 0.0 |
| ): |
| raise ValueError( |
| "H/ref_h_tokens and W/ref_w_tokens must be finite and > 0" |
| ) |
|
|
| periods_h = self.periods * torch.pow( |
| torch.tensor(s_h, device=device, dtype=torch.float32), beta |
| ) |
| periods_w = self.periods * torch.pow( |
| torch.tensor(s_w, device=device, dtype=torch.float32), beta |
| ) |
| axis_periods = torch.stack([periods_h, periods_w], dim=0) |
|
|
| angles = ( |
| float(self.cfg.angle_multiplier) |
| * coords[:, :, None].to(dtype=torch.float32) |
| / axis_periods[None, :, :].to(dtype=torch.float32) |
| ) |
| match self.cfg.dim_layout: |
| case AxialRoPE2DDimLayout.HALF_SPLIT: |
| angles = angles.flatten(1, 2).repeat(1, 2) |
| case AxialRoPE2DDimLayout.PAIR_INTERLEAVED: |
| angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2) |
| case _ as unreachable: |
| raise RuntimeError(f"Unsupported dim_layout: {unreachable}") |
| if angles.shape != (int(H) * int(W), int(self.head_dim)): |
| raise RuntimeError( |
| "Unexpected angles shape in inference freq-warp: " |
| f"{tuple(angles.shape)} for H={int(H)} W={int(W)}" |
| ) |
| return torch.sin(angles), torch.cos(angles) |
|
|
| return _AxialRoPE2DInferenceWarp(device=base.periods.device) |
|
|
|
|
| class AxialRoPE2D(nn.Module): |
| """DINOv3-style axial 2D RoPE sin/cos generator. |
| |
| The base periods are fixed by ``AxialRoPE2DConfig``. Optionally, this module |
| can include learnable scalar parameters when using: |
| - ``frequency_aware`` (boundary_log_multiplier), or |
| - ``beta_warp`` (beta_hi_u/beta_lo_u/beta_bend_u), or |
| - ``alpha_warp`` (alpha per-band exponents). |
| """ |
|
|
| periods: Tensor |
|
|
| def __init__(self, *, head_dim: int, cfg: AxialRoPE2DConfig) -> None: |
| super().__init__() |
| if int(head_dim) <= 0: |
| raise ValueError("head_dim must be positive for AxialRoPE2D") |
| if int(head_dim) % 4 != 0: |
| raise ValueError( |
| "AxialRoPE2D requires head_dim % 4 == 0 (DINOv3 constraint); " |
| f"got head_dim={int(head_dim)}" |
| ) |
| if not isinstance(cfg, AxialRoPE2DConfig): |
| raise TypeError("cfg must be an AxialRoPE2DConfig for AxialRoPE2D") |
| self.head_dim: Final[int] = int(head_dim) |
| self.cfg: Final[AxialRoPE2DConfig] = cfg |
| self._d_head: Final[int] = self.head_dim |
| self.register_buffer( |
| "periods", |
| torch.empty(self._d_head // 4, dtype=torch.float32), |
| persistent=True, |
| ) |
| if cfg.frequency_aware is None: |
| self.register_parameter("boundary_log_multiplier", None) |
| else: |
| init = float(cfg.frequency_aware.boundary_log_multiplier_init) |
| self.boundary_log_multiplier = nn.Parameter( |
| torch.tensor(init, dtype=torch.float32), |
| requires_grad=True, |
| ) |
| if cfg.beta_warp is None: |
| self.register_parameter("beta_hi_u", None) |
| self.register_parameter("beta_lo_u", None) |
| self.register_parameter("beta_bend_u", None) |
| else: |
| beta = cfg.beta_warp |
| self.beta_hi_u = nn.Parameter( |
| torch.tensor(float(beta.beta_hi_u_init), dtype=torch.float32), |
| requires_grad=True, |
| ) |
| self.beta_lo_u = nn.Parameter( |
| torch.tensor(float(beta.beta_lo_u_init), dtype=torch.float32), |
| requires_grad=True, |
| ) |
| self.beta_bend_u = nn.Parameter( |
| torch.tensor(float(beta.beta_bend_u_init), dtype=torch.float32), |
| requires_grad=True, |
| ) |
| if cfg.alpha_warp is None: |
| self.register_parameter("alpha", None) |
| else: |
| qtr = int(self._d_head) // 4 |
| if qtr <= 0: |
| raise RuntimeError("AxialRoPE2D periods length must be positive") |
| init = float(cfg.alpha_warp.alpha_init) |
| if not math.isfinite(init): |
| raise RuntimeError("alpha_init must be finite for alpha-warp RoPE") |
| self.alpha = nn.Parameter( |
| torch.full((int(qtr),), init, dtype=torch.float32), |
| requires_grad=True, |
| ) |
| self._init_periods() |
|
|
| def _apply(self, fn): |
| out = super()._apply(fn) |
| with torch.no_grad(): |
| self.periods.data = self.periods.data.to(dtype=torch.float32) |
| if self.boundary_log_multiplier is not None: |
| self.boundary_log_multiplier.data = ( |
| self.boundary_log_multiplier.data.to(dtype=torch.float32) |
| ) |
| if self.beta_hi_u is not None: |
| self.beta_hi_u.data = self.beta_hi_u.data.to(dtype=torch.float32) |
| if self.beta_lo_u is not None: |
| self.beta_lo_u.data = self.beta_lo_u.data.to(dtype=torch.float32) |
| if self.beta_bend_u is not None: |
| self.beta_bend_u.data = self.beta_bend_u.data.to(dtype=torch.float32) |
| if self.alpha is not None: |
| self.alpha.data = self.alpha.data.to(dtype=torch.float32) |
| return out |
|
|
| def _init_periods(self) -> None: |
| """Initialize per-dimension periods using DINOv3 formulas.""" |
| device: torch.device = self.periods.device |
| dtype: torch.dtype = self.periods.dtype |
| d_head = int(self._d_head) |
| qtr = d_head // 4 |
| if qtr <= 0: |
| raise RuntimeError("AxialRoPE2D periods length must be positive") |
| if self.cfg.base is not None: |
| base = float(self.cfg.base) |
| exponents = ( |
| 2.0 |
| * torch.arange(int(qtr), device=device, dtype=dtype) |
| / float(d_head // 2) |
| ) |
| periods = torch.tensor(base, device=device, dtype=dtype) ** exponents |
| else: |
| if self.cfg.min_period is None or self.cfg.max_period is None: |
| raise RuntimeError( |
| "AxialRoPE2DConfig must provide min_period and max_period when base is None" |
| ) |
| min_p = float(self.cfg.min_period) |
| max_p = float(self.cfg.max_period) |
| base = max_p / min_p |
| exponents = torch.linspace(0.0, 1.0, int(qtr), device=device, dtype=dtype) |
| periods = torch.tensor(base, device=device, dtype=dtype) ** exponents |
| periods = periods / torch.tensor(base, device=device, dtype=dtype) |
| periods = periods * torch.tensor(max_p, device=device, dtype=dtype) |
| self.periods.data = periods |
|
|
| def forward( |
| self, |
| *, |
| H: int, |
| W: int, |
| scales: Tensor | None, |
| ) -> tuple[Tensor, Tensor]: |
| """Return (sin, cos) buffers for axial 2D RoPE. |
| |
| Args: |
| H: Patch-grid height. |
| W: Patch-grid width. |
| scales: Optional per-batch dilation scale (scalar tensor). When |
| None, returns shared sin/cos shaped ``[HW, head_dim]``. When |
| provided, applies the scalar dilation and still returns shared |
| sin/cos shaped ``[HW, head_dim]``. |
| """ |
| if int(H) <= 0 or int(W) <= 0: |
| raise ValueError("H and W must be positive for AxialRoPE2D forward") |
| device = self.periods.device |
| offset = float(self.cfg.coord_offset) |
| coords: Tensor |
| match self.cfg.coord_mode: |
| case AxialRoPE2DCoordMode.DINOV3_NORMALIZED: |
| coords = _get_dinov3_normalized_coords( |
| int(H), |
| int(W), |
| device=device, |
| normalize=self.cfg.normalize_coords, |
| offset=offset, |
| ) |
| case AxialRoPE2DCoordMode.PATCH_INDICES: |
| coords = _get_patch_index_coords( |
| int(H), int(W), device=device, offset=offset |
| ) |
| case _ as unreachable: |
| raise RuntimeError(f"Unsupported coord_mode: {unreachable}") |
| if coords.dim() != 2 or coords.shape[1] != 2: |
| raise RuntimeError("AxialRoPE2D coords must have shape [HW, 2]") |
| if self.cfg.frequency_aware is not None: |
| if scales is not None: |
| raise ValueError( |
| "frequency-aware axial RoPE does not support dilation scales" |
| ) |
| if self.boundary_log_multiplier is None: |
| raise RuntimeError( |
| "boundary_log_multiplier parameter missing for frequency-aware RoPE" |
| ) |
| ref_h = int(self.cfg.frequency_aware.ref_h_tokens) |
| ref_w = int(self.cfg.frequency_aware.ref_w_tokens) |
| periods_h = lumina_frequency_aware_periods_for_axis( |
| periods=self.periods, |
| axis_len=int(H), |
| ref_axis_len=ref_h, |
| boundary_log_multiplier=self.boundary_log_multiplier, |
| angle_multiplier=float(self.cfg.angle_multiplier), |
| ) |
| periods_w = lumina_frequency_aware_periods_for_axis( |
| periods=self.periods, |
| axis_len=int(W), |
| ref_axis_len=ref_w, |
| boundary_log_multiplier=self.boundary_log_multiplier, |
| angle_multiplier=float(self.cfg.angle_multiplier), |
| ) |
| axis_periods = torch.stack([periods_h, periods_w], dim=0) |
| elif self.cfg.beta_warp is not None: |
| if scales is not None: |
| raise ValueError( |
| "beta-warp axial RoPE does not support dilation scales" |
| ) |
| if ( |
| self.beta_hi_u is None |
| or self.beta_lo_u is None |
| or self.beta_bend_u is None |
| ): |
| raise RuntimeError("beta warp parameters missing for beta-warp RoPE") |
| beta_cfg = self.cfg.beta_warp |
| ref_h = int(beta_cfg.ref_h_tokens) |
| ref_w = int(beta_cfg.ref_w_tokens) |
| qtr = int(self.periods.numel()) |
| if qtr <= 0: |
| raise RuntimeError("AxialRoPE2D periods length must be positive") |
| beta_max = float(beta_cfg.beta_max) |
| if not math.isfinite(beta_max) or beta_max <= 0.0: |
| raise RuntimeError("beta_max must be finite and > 0") |
| beta_max_t = torch.tensor(beta_max, device=device, dtype=torch.float32) |
| beta_hi = beta_max_t * torch.tanh(self.beta_hi_u.to(dtype=torch.float32)) |
| beta_lo = beta_max_t * torch.tanh(self.beta_lo_u.to(dtype=torch.float32)) |
| beta_bend = beta_max_t * torch.tanh( |
| self.beta_bend_u.to(dtype=torch.float32) |
| ) |
| if qtr == 1: |
| beta = beta_hi[None] |
| else: |
| t = torch.arange(int(qtr), device=device, dtype=torch.float32) / float( |
| qtr - 1 |
| ) |
| bump = 4.0 * t * (1.0 - t) |
| beta = (1.0 - t) * beta_hi + t * beta_lo + beta_bend * bump |
|
|
| s_h = float(int(H)) / float(ref_h) |
| s_w = float(int(W)) / float(ref_w) |
| if ( |
| not math.isfinite(s_h) |
| or s_h <= 0.0 |
| or not math.isfinite(s_w) |
| or s_w <= 0.0 |
| ): |
| raise RuntimeError( |
| "Computed invalid axis scale factors for beta-warp RoPE" |
| ) |
| periods_h = self.periods.to(dtype=torch.float32) * torch.pow( |
| torch.tensor(s_h, device=device, dtype=torch.float32), beta |
| ) |
| periods_w = self.periods.to(dtype=torch.float32) * torch.pow( |
| torch.tensor(s_w, device=device, dtype=torch.float32), beta |
| ) |
| axis_periods = torch.stack([periods_h, periods_w], dim=0) |
| elif self.cfg.alpha_warp is not None: |
| if scales is not None: |
| raise ValueError( |
| "alpha-warp axial RoPE does not support dilation scales" |
| ) |
| if self.alpha is None: |
| raise RuntimeError("alpha parameter missing for alpha-warp RoPE") |
| alpha_cfg = self.cfg.alpha_warp |
| ref_h = int(alpha_cfg.ref_h_tokens) |
| ref_w = int(alpha_cfg.ref_w_tokens) |
| qtr = int(self.periods.numel()) |
| if int(self.alpha.numel()) != qtr: |
| raise RuntimeError( |
| "alpha length must match RoPE periods length for alpha-warp RoPE" |
| ) |
| s_h = float(int(H)) / float(ref_h) |
| s_w = float(int(W)) / float(ref_w) |
| if ( |
| not math.isfinite(s_h) |
| or s_h <= 0.0 |
| or not math.isfinite(s_w) |
| or s_w <= 0.0 |
| ): |
| raise RuntimeError( |
| "Computed invalid axis scale factors for alpha-warp RoPE" |
| ) |
| alpha = self.alpha.to(device=device, dtype=torch.float32) |
| scale_h = torch.pow( |
| torch.tensor(s_h, device=device, dtype=torch.float32), alpha |
| ) |
| scale_w = torch.pow( |
| torch.tensor(s_w, device=device, dtype=torch.float32), alpha |
| ) |
| periods_h = self.periods.to(dtype=torch.float32) / scale_h |
| periods_w = self.periods.to(dtype=torch.float32) / scale_w |
| axis_periods = torch.stack([periods_h, periods_w], dim=0) |
| else: |
| axis_periods = self.periods[None, :].expand(2, -1).to(dtype=torch.float32) |
|
|
| |
| angles = ( |
| float(self.cfg.angle_multiplier) |
| * coords[:, :, None].to(dtype=torch.float32) |
| / axis_periods[None, :, :].to(dtype=torch.float32) |
| ) |
| match self.cfg.dim_layout: |
| case AxialRoPE2DDimLayout.HALF_SPLIT: |
| angles = angles.flatten(1, 2).repeat(1, 2) |
| case AxialRoPE2DDimLayout.PAIR_INTERLEAVED: |
| angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2) |
| case _ as unreachable: |
| raise RuntimeError(f"Unsupported dim_layout: {unreachable}") |
| if angles.shape != (int(H) * int(W), int(self._d_head)): |
| raise RuntimeError( |
| "Unexpected angles shape in AxialRoPE2D: " |
| f"{tuple(angles.shape)} for H={int(H)} W={int(W)}" |
| ) |
| if scales is not None: |
| if scales.dim() != 0: |
| raise ValueError( |
| "AxialRoPE2D scales must be a scalar tensor for per-batch dilation; " |
| "per-sample dilation is not supported" |
| ) |
| angles = angles * scales.to(device=device, dtype=torch.float32) |
| cos = torch.cos(angles) |
| sin = torch.sin(angles) |
| return sin, cos |
|
|
|
|
def _dy_ntk_periods_for_axis(
    *,
    periods: Tensor,
    axis_len: int,
    ref_axis_len: int,
    noise_time: Tensor,
    lambda_s: float,
    lambda_t: float,
) -> Tensor:
    """Compute Dy-NTK periods for a single spatial axis.

    The axis scale is ``axis_len / ref_axis_len``; the actual period warp is
    delegated to ``_dy_ntk_periods_for_scale``.

    Raises:
        ValueError: If token lengths or scheduler values are invalid.
    """
    if not (int(axis_len) > 0 and int(ref_axis_len) > 0):
        raise ValueError("axis_len and ref_axis_len must be positive for Dy-NTK")
    if int(periods.numel()) <= 0:
        raise ValueError("periods must be non-empty for Dy-NTK")
    axis_ratio = float(int(axis_len)) / float(int(ref_axis_len))
    if not (math.isfinite(axis_ratio) and axis_ratio > 0.0):
        raise ValueError("Dy-NTK axis scale must be finite and > 0")
    return _dy_ntk_periods_for_scale(
        periods=periods,
        scale=axis_ratio,
        noise_time=noise_time,
        lambda_s=float(lambda_s),
        lambda_t=float(lambda_t),
    )
|
|
|
|
def _dy_ntk_periods_for_scale(
    *,
    periods: Tensor,
    scale: float,
    noise_time: Tensor,
    lambda_s: float,
    lambda_t: float,
) -> Tensor:
    """Compute Dy-NTK periods from an already-computed axis scale.

    For ``scale <= 1`` the float32 periods are returned unchanged. Otherwise
    band ``i`` of ``n`` is multiplied by ``scale ** (kappa * i / (n - 1))``
    with ``kappa = lambda_s * noise_time ** lambda_t``.

    Raises:
        ValueError: If ``scale`` is not finite and > 0, or ``periods`` is empty.
    """
    factor = float(scale)
    if not (math.isfinite(factor) and factor > 0.0):
        raise ValueError("Dy-NTK scale must be finite and > 0")
    band_count = int(periods.numel())
    if band_count <= 0:
        raise ValueError("periods must be non-empty for Dy-NTK")

    periods_f32 = periods.to(dtype=torch.float32)
    if factor <= 1.0:
        # No extrapolation requested: keep the base periods.
        return periods_f32

    device = periods.device
    if band_count == 1:
        band_fraction = torch.zeros((1,), device=device, dtype=torch.float32)
    else:
        band_index = torch.arange(band_count, device=device, dtype=torch.float32)
        band_fraction = band_index / (float(band_count - 1))

    noise = noise_time.to(device=device, dtype=torch.float32)
    kappa = float(lambda_s) * torch.pow(noise, float(lambda_t))
    factor_t = torch.tensor(factor, device=device, dtype=torch.float32)
    return periods_f32 * torch.pow(factor_t, kappa * band_fraction)
|
|
|
|
def _dype_dynamic_exponent(
    *, noise_time: float, lambda_s: float, lambda_t: float
) -> float:
    """Compute the Comfy/DyPE-style dynamic magnitude ``lambda_s * t ** lambda_t``.

    ``noise_time`` is clamped into ``[0, 1]`` before use.

    Raises:
        ValueError: If ``noise_time`` is not finite, or ``lambda_s`` /
            ``lambda_t`` is not finite and > 0.
    """
    t = float(noise_time)
    if not math.isfinite(t):
        raise ValueError("DyPE noise_time must be finite")
    t = max(0.0, min(1.0, t))

    magnitude = float(lambda_s)
    power = float(lambda_t)
    if not (math.isfinite(magnitude) and magnitude > 0.0):
        raise ValueError("DyPE lambda_s must be finite and > 0")
    if not (math.isfinite(power) and power > 0.0):
        raise ValueError("DyPE lambda_t must be finite and > 0")
    return magnitude * (t**power)
|
|
|
|
def _dype_correction_factor(
    *,
    periods: Tensor,
    rotations: float,
    ref_axis_len: int,
    angle_multiplier: float,
) -> float:
    """Locate the fractional band index whose wavelength makes ``rotations`` turns.

    The target period is ``angle_multiplier / (2*pi) * ref_axis_len / rotations``.
    Its position is interpolated in log space between the first and last
    period and scaled to ``[0, n - 1]`` band units (values outside that range
    are possible). With fewer than two bands the result is ``0.0``.

    Raises:
        ValueError: If ``ref_axis_len`` is not positive, ``rotations`` or
            ``angle_multiplier`` is not finite and > 0, or the first/last
            periods are not positive and strictly increasing.
    """
    if not int(ref_axis_len) > 0:
        raise ValueError("ref_axis_len must be positive for DyPE correction")
    turns = float(rotations)
    if not (math.isfinite(turns) and turns > 0.0):
        raise ValueError("rotations must be finite and > 0")
    multiplier = float(angle_multiplier)
    if not (math.isfinite(multiplier) and multiplier > 0.0):
        raise ValueError("angle_multiplier must be finite and > 0")
    if int(periods.numel()) < 2:
        return 0.0

    # Read the two endpoint periods as Python floats on the host.
    host_periods = periods.detach().to(
        device=torch.device("cpu"), dtype=torch.float32
    )
    shortest = float(host_periods[0].item())
    longest = float(host_periods[-1].item())
    if shortest <= 0.0 or longest <= shortest:
        raise ValueError("periods must be positive and strictly increasing for DyPE")

    boundary_wavelength = float(int(ref_axis_len)) / turns
    boundary_period = (multiplier / (2.0 * float(math.pi))) * boundary_wavelength
    log_shortest = math.log(shortest)
    log_longest = math.log(longest)
    log_fraction = (math.log(boundary_period) - log_shortest) / (
        log_longest - log_shortest
    )
    return float(periods.numel() - 1) * log_fraction
|
|
|
|
def _dype_ramp_mask(
    *,
    periods: Tensor,
    threshold_high_rotations: float,
    threshold_low_rotations: float,
    ref_axis_len: int,
    angle_multiplier: float,
) -> Tensor:
    """Build YaRN's high-to-low band mask for one dynamic threshold pair.

    The mask is ``1 - clamp((band - low) / (high - low), 0, 1)``, where
    ``low``/``high`` are the clamped floor/ceil band indices of the two
    rotation thresholds. A single band always yields a mask of ones.

    Raises:
        ValueError: If ``periods`` is empty (or the thresholds are rejected
            by ``_dype_correction_factor``).
    """
    band_count = int(periods.numel())
    if band_count <= 0:
        raise ValueError("periods must be non-empty for DyPE ramp mask")
    device = periods.device
    if band_count == 1:
        return torch.ones((1,), device=device, dtype=torch.float32)

    def band_index_for(rotations: float) -> float:
        # Fractional band index at which the wavelength makes `rotations` turns.
        return _dype_correction_factor(
            periods=periods,
            rotations=float(rotations),
            ref_axis_len=int(ref_axis_len),
            angle_multiplier=float(angle_multiplier),
        )

    ramp_start = math.floor(band_index_for(threshold_high_rotations))
    ramp_stop = math.ceil(band_index_for(threshold_low_rotations))
    ramp_start = max(0, min(band_count - 1, int(ramp_start)))
    ramp_stop = max(0, min(band_count, int(ramp_stop)))
    if ramp_start == ramp_stop:
        # Keep the ramp width non-zero so the division below is defined.
        ramp_stop = min(band_count, ramp_start + 1)

    band_index = torch.arange(band_count, device=device, dtype=torch.float32)
    ramp = (band_index - float(ramp_start)) / float(ramp_stop - ramp_start)
    return 1.0 - torch.clamp(ramp, min=0.0, max=1.0)
|
|
|
|
def _dy_yarn_periods_for_axis(
    *,
    periods: Tensor,
    linear_scale: float,
    ntk_scale: float,
    ref_axis_len: int,
    noise_time: float,
    lambda_s: float,
    cfg: AxialRoPE2DDyPEConfig,
    angle_multiplier: float,
) -> Tensor:
    """Compute Dy-YaRN periods for a single spatial axis.

    Three frequency sets are blended per band:
    - linear-interpolated frequencies (periods times ``max(1, linear_scale)``),
    - NTK-scaled frequencies (``max(1, ntk_scale)`` at full strength), and
    - the unmodified base frequencies.

    The beta ramp mask mixes linear and NTK; the gamma ramp mask then mixes
    that result with the base. Both masks use thresholds raised to the
    dynamic exponent ``kappa``. The float32 base periods are returned as-is
    when neither scale exceeds 1 or ``kappa <= 1e-6``.

    Raises:
        ValueError: If ``ref_axis_len`` is not positive or a scale is not
            finite and > 0.
    """
    if not int(ref_axis_len) > 0:
        raise ValueError("ref_axis_len must be positive for Dy-YaRN")
    linear_factor = float(linear_scale)
    ntk_factor = float(ntk_scale)
    scales_valid = (
        math.isfinite(linear_factor)
        and linear_factor > 0.0
        and math.isfinite(ntk_factor)
        and ntk_factor > 0.0
    )
    if not scales_valid:
        raise ValueError("Dy-YaRN axis scales must be finite and > 0")

    base_periods = periods.to(dtype=torch.float32)
    if max(linear_factor, ntk_factor) <= 1.0:
        return base_periods
    kappa = _dype_dynamic_exponent(
        noise_time=float(noise_time),
        lambda_s=float(lambda_s),
        lambda_t=float(cfg.lambda_t),
    )
    if kappa <= 1e-6:
        return base_periods

    multiplier = float(angle_multiplier)
    freq_base = multiplier / base_periods
    freq_linear = multiplier / (base_periods * max(1.0, linear_factor))
    # Full-strength NTK warp: noise_time=1 with unit lambdas gives kappa=1 there.
    ntk_periods = _dy_ntk_periods_for_scale(
        periods=base_periods,
        scale=max(1.0, ntk_factor),
        noise_time=torch.ones((), device=periods.device, dtype=torch.float32),
        lambda_s=1.0,
        lambda_t=1.0,
    )
    freq_ntk = multiplier / ntk_periods

    def ramp_mask(high_rotations: float, low_rotations: float) -> Tensor:
        # Thresholds are raised to kappa before being turned into a band mask.
        return _dype_ramp_mask(
            periods=base_periods,
            threshold_high_rotations=float(high_rotations) ** kappa,
            threshold_low_rotations=float(low_rotations) ** kappa,
            ref_axis_len=int(ref_axis_len),
            angle_multiplier=multiplier,
        )

    beta_mask = ramp_mask(cfg.yarn_beta_0, cfg.yarn_beta_1)
    blended = freq_linear * (1.0 - beta_mask) + freq_ntk * beta_mask

    gamma_mask = ramp_mask(cfg.yarn_gamma_0, cfg.yarn_gamma_1)
    blended = blended * (1.0 - gamma_mask) + freq_base * gamma_mask
    return multiplier / blended
|
|
|
|
class AxialRoPE2DDyPE(AxialRoPE2D):
    """Axial RoPE wrapper applying dynamic position extrapolation (inference only).

    The current noise time is kept twice: in the non-persistent
    ``dype_noise_time`` buffer (read by the Dy-NTK path) and as a plain float
    in ``dype_noise_time_values[0]`` (read by the YaRN paths).
    """

    dype_cfg: AxialRoPE2DDyPEConfig
    dype_noise_time: Tensor
    dype_noise_time_values: list[float]

    def __init__(self, *, base: AxialRoPE2D, cfg: AxialRoPE2DDyPEConfig) -> None:
        """Copy ``base``'s head_dim, config and periods; start at noise time 1.

        Raises:
            TypeError: If ``base`` or ``cfg`` has the wrong type.
            ValueError: If ``base`` does not use patch-index coordinates.
        """
        if not isinstance(base, AxialRoPE2D):
            raise TypeError("base must be an AxialRoPE2D")
        if not isinstance(cfg, AxialRoPE2DDyPEConfig):
            raise TypeError("cfg must be an AxialRoPE2DDyPEConfig")
        if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES:
            raise ValueError("DyPE requires patch-index axial RoPE coordinates")
        super().__init__(head_dim=int(base.head_dim), cfg=base.cfg)
        self.dype_cfg = cfg
        initial_noise_time = 1.0
        self.register_buffer(
            "dype_noise_time",
            torch.tensor(initial_noise_time, dtype=torch.float32),
            persistent=False,
        )
        self.dype_noise_time_values: list[float] = [initial_noise_time]
        base_periods = base.periods.detach().to(dtype=torch.float32)
        with torch.no_grad():
            self.periods.copy_(base_periods)

    def set_dype_noise_time(self, noise_time: float) -> None:
        """Store the current normalized diffusion noise time (in ``[0, 1]``).

        Raises:
            ValueError: If the value is not finite or lies outside ``[0, 1]``.
        """
        value = float(noise_time)
        if not (math.isfinite(value) and 0.0 <= value <= 1.0):
            raise ValueError("DyPE noise_time must be finite and within [0, 1]")
        self.dype_noise_time.fill_(value)
        self.dype_noise_time_values[0] = value

    def _dype_axis_periods(
        self,
        *,
        axis_len: int,
        ref_axis_len: int,
        global_scale: float,
        lambda_s: float,
    ) -> Tensor:
        """Compute the periods of one spatial axis for the configured method.

        ``DY_NTK`` uses only the shared scale. ``VISION_YARN`` interpolates
        linearly with the per-axis scale, ``DY_YARN`` with the shared scale;
        both use the shared scale for the NTK part.

        Raises:
            ValueError: If the axis or shared scale is not finite and > 0.
            RuntimeError: If the configured method is not supported.
        """
        dype = self.dype_cfg
        per_axis_scale = float(int(axis_len)) / float(int(ref_axis_len))
        shared = float(global_scale)
        scales_valid = (
            math.isfinite(per_axis_scale)
            and per_axis_scale > 0.0
            and math.isfinite(shared)
            and shared > 0.0
        )
        if not scales_valid:
            raise ValueError("DyPE axis and global scales must be finite and > 0")

        method = dype.method
        if method == DyPERoPEMethod.DY_NTK:
            return _dy_ntk_periods_for_scale(
                periods=self.periods,
                scale=shared,
                noise_time=self.dype_noise_time,
                lambda_s=float(lambda_s),
                lambda_t=float(dype.lambda_t),
            )
        if method == DyPERoPEMethod.VISION_YARN:
            linear = per_axis_scale
        elif method == DyPERoPEMethod.DY_YARN:
            linear = shared
        else:
            raise RuntimeError(f"Unsupported DyPE method: {method}")
        return _dy_yarn_periods_for_axis(
            periods=self.periods,
            linear_scale=linear,
            ntk_scale=shared,
            ref_axis_len=int(ref_axis_len),
            noise_time=float(self.dype_noise_time_values[0]),
            lambda_s=float(lambda_s),
            cfg=dype,
            angle_multiplier=float(self.cfg.angle_multiplier),
        )

    def forward(
        self,
        *,
        H: int,
        W: int,
        scales: Tensor | None,
    ) -> tuple[Tensor, Tensor]:
        """Return timestep-aware DyPE ``(sin, cos)`` shaped ``[H*W, head_dim]``.

        For the YaRN methods with ``yarn_attention_scale`` enabled and a
        global scale above 1, both outputs are multiplied by an attention
        magnitude factor when that factor exceeds 1.

        Raises:
            ValueError: If ``scales`` is given or ``H``/``W`` are not positive.
            RuntimeError: On an unsupported layout/method or unexpected shape.
        """
        if scales is not None:
            raise ValueError("DyPE axial RoPE does not support dilation scales")
        if int(H) <= 0 or int(W) <= 0:
            raise ValueError("H and W must be positive for DyPE axial RoPE")
        grid_h, grid_w = int(H), int(W)
        dype = self.dype_cfg
        device = self.periods.device
        coords = _get_patch_index_coords(
            grid_h, grid_w, device=device, offset=float(self.cfg.coord_offset)
        )
        ref_h = int(dype.ref_h_tokens)
        ref_w = int(dype.ref_w_tokens)
        # The larger of the two axis ratios drives the shared (NTK) scaling.
        global_scale = max(
            float(grid_h) / float(ref_h),
            float(grid_w) / float(ref_w),
        )
        per_axis = [
            self._dype_axis_periods(
                axis_len=axis_len,
                ref_axis_len=ref_len,
                global_scale=global_scale,
                lambda_s=float(dype.lambda_s),
            )
            for axis_len, ref_len in ((grid_h, ref_h), (grid_w, ref_w))
        ]
        axis_periods = torch.stack(per_axis, dim=0)

        # coords is [HW, 2] and axis_periods is [2, qtr] -> angles [HW, 2, qtr].
        angles = (
            float(self.cfg.angle_multiplier)
            * coords[:, :, None].to(dtype=torch.float32)
            / axis_periods[None, :, :].to(dtype=torch.float32)
        )
        layout = self.cfg.dim_layout
        if layout == AxialRoPE2DDimLayout.HALF_SPLIT:
            angles = angles.flatten(1, 2).repeat(1, 2)
        elif layout == AxialRoPE2DDimLayout.PAIR_INTERLEAVED:
            angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2)
        else:
            raise RuntimeError(f"Unsupported dim_layout: {layout}")
        expected_shape = (grid_h * grid_w, int(self.head_dim))
        if angles.shape != expected_shape:
            raise RuntimeError(
                "Unexpected angles shape in DyPE axial RoPE: "
                f"{tuple(angles.shape)} for expected {expected_shape}"
            )
        sin = torch.sin(angles)
        cos = torch.cos(angles)

        method = dype.method
        uses_attention_scale = (
            method in (DyPERoPEMethod.VISION_YARN, DyPERoPEMethod.DY_YARN)
            and bool(dype.yarn_attention_scale)
            and global_scale > 1.0
        )
        if not uses_attention_scale:
            return sin, cos

        if method == DyPERoPEMethod.VISION_YARN:
            # Interpolate between 1 and the full-strength factor using kappa.
            full_mscale = 0.1 * math.log(global_scale) + 1.0
            kappa = _dype_dynamic_exponent(
                noise_time=float(self.dype_noise_time_values[0]),
                lambda_s=1.0,
                lambda_t=float(dype.lambda_t),
            )
            mscale = 1.0 + (full_mscale - 1.0) * kappa
        elif method == DyPERoPEMethod.DY_YARN:
            mscale = 1.0 + 0.1 * math.log(global_scale) / math.sqrt(global_scale)
        else:
            raise RuntimeError(f"Unsupported YaRN attention scale: {method}")
        if mscale > 1.0:
            sin = sin * float(mscale)
            cos = cos * float(mscale)
        return sin, cos
|
|
|
|
def build_axial_rope2d_dype(
    *, base: AxialRoPE2D, cfg: AxialRoPE2DDyPEConfig
) -> AxialRoPE2DDyPE:
    """Wrap an existing axial RoPE in an inference-only DyPE module.

    The wrapper is moved to the device of ``base.periods`` before it is
    returned.
    """
    wrapper = AxialRoPE2DDyPE(base=base, cfg=cfg)
    return wrapper.to(device=base.periods.device)
|
|
|
|
def set_axial_rope2d_dype_noise_time(module: nn.Module, *, noise_time: float) -> bool:
    """Push ``noise_time`` to every axial DyPE module inside ``module``.

    ``module`` itself is included in the search.

    Returns:
        True if at least one ``AxialRoPE2DDyPE`` was updated, else False.
    """
    found_any = False
    for submodule in module.modules():
        if not isinstance(submodule, AxialRoPE2DDyPE):
            continue
        submodule.set_dype_noise_time(float(noise_time))
        found_any = True
    return found_any
|
|