from __future__ import annotations import math from dataclasses import dataclass, replace from enum import Enum from typing import Final, cast import torch from torch import Tensor, nn __all__ = [ "AxialRoPE2D", "AxialRoPE2DAlphaWarpConfig", "AxialRoPE2DBetaWarpConfig", "AxialRoPE2DConfig", "AxialRoPE2DCoordMode", "AxialRoPE2DDimLayout", "AxialRoPE2DDyPE", "AxialRoPE2DDyPEConfig", "AxialRoPE2DFrequencyAwareConfig", "AxialRoPE2DNormalizeCoords", "DyPERoPEMethod", "build_axial_rope2d_dype", "build_axial_rope2d_inference_warp_with_strength", "build_axial_rope2d_with_lumina_frequency_warp", "lumina_frequency_aware_periods_for_axis", "set_axial_rope2d_dype_noise_time", ] class AxialRoPE2DNormalizeCoords(Enum): """Coordinate normalization strategy for axial 2D RoPE (DINOv3-style).""" MIN = "min" MAX = "max" SEPARATE = "separate" class AxialRoPE2DCoordMode(Enum): """Coordinate grid mode for axial 2D RoPE. - ``DINOV3_NORMALIZED``: DINOv3-style normalized patch-centre coordinates in ``[-1, 1]`` (after normalization). - ``PATCH_INDICES``: Standard unnormalized patch-grid coordinates in patch units (e.g., ``x in [0, W-1]``, ``y in [0, H-1]``). """ DINOV3_NORMALIZED = "dinov3_normalized" PATCH_INDICES = "patch_indices" class AxialRoPE2DDimLayout(Enum): """Layout of angles along the head-dimension. The layout must match the rotation convention used when applying RoPE to Q/K. - ``HALF_SPLIT``: LLaMA-style layout compatible with ``common.rope.rotate_half`` (splits last dim into two halves). - ``PAIR_INTERLEAVED``: EVA-02 / SpeedrunDiT-style layout compatible with an adjacent-pair rotate_half (pairs consecutive dims). TODO(refactor): Standardize on ``PAIR_INTERLEAVED`` throughout DiT to reduce complexity and avoid layout mismatches, then delete ``HALF_SPLIT`` and any related branching once the migration is complete. """ HALF_SPLIT = "half_split" PAIR_INTERLEAVED = "pair_interleaved" class DyPERoPEMethod(Enum): """Dynamic position extrapolation method applied to inference RoPE.""" VISION_YARN = "vision_yarn" DY_YARN = "dy_yarn" DY_NTK = "dy_ntk" @dataclass(frozen=True) class AxialRoPE2DDyPEConfig: """Inference-only DyPE controls for axial RoPE. Args: method: Dynamic extrapolation rule to apply. ref_h_tokens: Training/reference token height. ref_w_tokens: Training/reference token width. lambda_s: Dynamic extrapolation magnitude. lambda_t: Dynamic extrapolation noise-time exponent. yarn_beta_0: YaRN first-ramp high rotation threshold. yarn_beta_1: YaRN first-ramp low rotation threshold. yarn_gamma_0: YaRN base-blend high rotation threshold. yarn_gamma_1: YaRN base-blend low rotation threshold. yarn_attention_scale: Apply YaRN's static attention magnitude correction. """ method: DyPERoPEMethod ref_h_tokens: int ref_w_tokens: int lambda_s: float = 2.0 lambda_t: float = 2.0 yarn_beta_0: float = 1.25 yarn_beta_1: float = 0.75 yarn_gamma_0: float = 16.0 yarn_gamma_1: float = 2.0 yarn_attention_scale: bool = True def __post_init__(self) -> None: if not isinstance(self.method, DyPERoPEMethod): raise TypeError("method must be a DyPERoPEMethod") if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0: raise ValueError("ref_h_tokens and ref_w_tokens must be positive") for name, value in ( ("lambda_s", self.lambda_s), ("lambda_t", self.lambda_t), ("yarn_beta_0", self.yarn_beta_0), ("yarn_beta_1", self.yarn_beta_1), ("yarn_gamma_0", self.yarn_gamma_0), ("yarn_gamma_1", self.yarn_gamma_1), ): v = float(value) if not math.isfinite(v) or v <= 0.0: raise ValueError(f"{name} must be finite and > 0") if not isinstance(self.yarn_attention_scale, bool): raise TypeError("yarn_attention_scale must be a bool") @dataclass(frozen=True) class AxialRoPE2DFrequencyAwareConfig: """Lumina/Next-DiT-style frequency-aware RoPE warping for one token grid. This config implements a per-axis, per-band frequency warp that depends on the input axis length ``L`` relative to a reference length ``L_ref``: - Define the axis scale ``s = L / L_ref``. - RoPE is parameterized by *periods* (wavelengths in tokens) ``period[d]``. In this module's axial parameterization (with patch-index coordinates), the angle for coordinate ``p`` and band ``d`` is: angle(p, d) = 2π * p / period[d] so the wavelength of band ``d`` is exactly ``period[d]`` tokens. - Pick a *boundary wavelength* ``L_boundary`` (in tokens), expressed as a trainable multiplier around the reference length: L_boundary = L_ref * exp(boundary_log_multiplier) The scalar ``boundary_log_multiplier`` is shared across H/W axes (and initialized by this config). - Define a (possibly fractional) boundary band index ``d*`` as the band whose wavelength equals ``L_boundary``: period(d*) = L_boundary In practice we compute ``d*`` by linear interpolation in log-period space (periods are geometric for both supported period parametrizations). - The Lumina/Next-DiT implicit exponent ramp is then: alpha[d] = clamp(d / d*, 0, 1) where: - high-frequency bands (small d) have alpha≈0 (extrapolation-like), - low-frequency bands (large d) have alpha→1 (interpolation-like), - alpha is capped at 1 to ensure we never compress a band more than plain position interpolation would. - Finally, warp the periods per axis: period'[d] = period[d] * s ** alpha[d] Equivalently, angular frequencies warp as: omega'[d] = omega[d] / s ** alpha[d] Notes ----- - This warp is only meaningful for patch-index coordinates (``AxialRoPE2DCoordMode.PATCH_INDICES``). Mixing it with normalized coordinates would create an implicit "gauge switch"; we fail fast. - The boundary multiplier is trainable by construction (it is stored as an nn.Parameter inside AxialRoPE2D when this config is present). """ ref_h_tokens: int ref_w_tokens: int boundary_log_multiplier_init: float def __post_init__(self) -> None: if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0: raise ValueError("ref_h_tokens and ref_w_tokens must be positive") init = float(self.boundary_log_multiplier_init) if not math.isfinite(init): raise ValueError("boundary_log_multiplier_init must be finite") @dataclass(frozen=True) class AxialRoPE2DBetaWarpConfig: """Trainable beta-curve warping for axial 2D RoPE periods (per token grid). This config defines a per-axis period warp that depends on the runtime axis length ``L`` relative to a reference length ``L_ref``: s = L / L_ref period'[d] = period[d] * s ** beta[d] where the per-band exponent curve beta(d) is parameterized by three trainable u-space scalars (shared across H/W axes): beta_hi = beta_max * tanh(beta_hi_u) (high-frequency endpoint, d=0) beta_lo = beta_max * tanh(beta_lo_u) (low-frequency endpoint, d=qtr-1) beta_bend = beta_max * tanh(beta_bend_u) (mid-band bump amplitude) and the per-band curve is: t = d / (qtr - 1) in [0, 1] beta(t) = lerp(beta_hi, beta_lo, t) + beta_bend * 4*t*(1-t) Interpretation -------------- - ``beta(d) == 0``: identity / "extrapolation-like" (no warping; periods do not change with axis length). - ``beta(d) == 1``: position-interpolation-like for that band (``period'[d] = period[d] * s`` so ``omega'[d] = omega[d] / s``). This parameterization provides strong and smooth control over the effective scaling of each frequency band, including allowing beta<0 (increasing frequencies when s>1), which can be important for unnormalized RoPE bases (e.g. base=10_000) where some very low-frequency bands barely rotate on practical token grids. Notes: - This warp requires patch-index coordinates (coord_mode=PATCH_INDICES). - The u parameters are stored as nn.Parameter inside AxialRoPE2D when this config is present. """ ref_h_tokens: int ref_w_tokens: int beta_max: float beta_hi_u_init: float beta_lo_u_init: float beta_bend_u_init: float def __post_init__(self) -> None: if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0: raise ValueError("ref_h_tokens and ref_w_tokens must be positive") bmax = float(self.beta_max) if not math.isfinite(bmax) or bmax <= 0.0: raise ValueError("beta_max must be finite and > 0") for name, value in ( ("beta_hi_u_init", self.beta_hi_u_init), ("beta_lo_u_init", self.beta_lo_u_init), ("beta_bend_u_init", self.beta_bend_u_init), ): v = float(value) if not math.isfinite(v): raise ValueError(f"{name} must be finite") @dataclass(frozen=True) class AxialRoPE2DAlphaWarpConfig: """Per-band power-law warping of axial 2D RoPE frequencies (shared across axes). This config warps RoPE frequencies per band using a learned exponent vector ``alpha[d]`` shared across H/W axes: f'[d] = f[d] * s ** alpha[d] where s = L / L_ref Since this module parameterizes angles via periods ``period[d]`` with ``f[d] ∝ 1 / period[d]``, the equivalent period warp implemented in AxialRoPE2D is: period'[d] = period[d] / s ** alpha[d] Notes: - This warp requires patch-index coordinates (coord_mode=PATCH_INDICES). - ``alpha`` is stored as an unconstrained nn.Parameter vector of length Q (bands per axis), initialized to ``alpha_init`` for all bands. """ ref_h_tokens: int ref_w_tokens: int alpha_init: float def __post_init__(self) -> None: if int(self.ref_h_tokens) <= 0 or int(self.ref_w_tokens) <= 0: raise ValueError("ref_h_tokens and ref_w_tokens must be positive") init = float(self.alpha_init) if not math.isfinite(init): raise ValueError("alpha_init must be finite") @dataclass(frozen=True) class AxialRoPE2DConfig: """Configuration for axial 2D RoPE sin/cos generation. This module supports two coordinate conventions via ``coord_mode``: - ``DINOV3_NORMALIZED``: DINOv3-style normalized patch-centre coordinates in ``[-1, 1]`` (after normalization). - ``PATCH_INDICES``: Standard unnormalized patch-grid coordinates in patch units (e.g., ``x in [0, W-1]``). Period parametrization ---------------------- The periods parametrization matches DINOv3: - Provide either `base` (and leave `min_period/max_period` unset), or - Provide both `min_period` and `max_period` (and set `base=None`). """ base: float | None = 100.0 min_period: float | None = None max_period: float | None = None coord_mode: AxialRoPE2DCoordMode = AxialRoPE2DCoordMode.DINOV3_NORMALIZED normalize_coords: AxialRoPE2DNormalizeCoords = AxialRoPE2DNormalizeCoords.MAX dim_layout: AxialRoPE2DDimLayout = AxialRoPE2DDimLayout.HALF_SPLIT angle_multiplier: float = 2.0 * float(math.pi) coord_offset: float = 0.5 frequency_aware: AxialRoPE2DFrequencyAwareConfig | None = None beta_warp: AxialRoPE2DBetaWarpConfig | None = None alpha_warp: AxialRoPE2DAlphaWarpConfig | None = None def __post_init__(self) -> None: both_periods = self.min_period is not None and self.max_period is not None if (self.base is None and not both_periods) or ( self.base is not None and both_periods ): raise ValueError( "AxialRoPE2DConfig requires either base!=None, or both min_period and max_period." ) if self.base is not None and float(self.base) <= 0.0: raise ValueError("AxialRoPE2DConfig.base must be positive when provided") if self.min_period is not None and float(self.min_period) <= 0.0: raise ValueError( "AxialRoPE2DConfig.min_period must be positive when provided" ) if self.max_period is not None and float(self.max_period) <= 0.0: raise ValueError( "AxialRoPE2DConfig.max_period must be positive when provided" ) if self.min_period is not None and self.max_period is not None: if float(self.max_period) <= float(self.min_period): raise ValueError("AxialRoPE2DConfig.max_period must be > min_period") if not isinstance(self.coord_mode, AxialRoPE2DCoordMode): raise TypeError( "AxialRoPE2DConfig.coord_mode must be an AxialRoPE2DCoordMode" ) if not isinstance(self.normalize_coords, AxialRoPE2DNormalizeCoords): raise TypeError( "AxialRoPE2DConfig.normalize_coords must be an AxialRoPE2DNormalizeCoords" ) if not isinstance(self.dim_layout, AxialRoPE2DDimLayout): raise TypeError( "AxialRoPE2DConfig.dim_layout must be an AxialRoPE2DDimLayout" ) mult = float(self.angle_multiplier) if not math.isfinite(mult) or mult <= 0.0: raise ValueError( "AxialRoPE2DConfig.angle_multiplier must be finite and > 0" ) off = float(self.coord_offset) if not math.isfinite(off): raise ValueError("AxialRoPE2DConfig.coord_offset must be finite") if self.frequency_aware is not None and not isinstance( self.frequency_aware, AxialRoPE2DFrequencyAwareConfig ): raise TypeError( "AxialRoPE2DConfig.frequency_aware must be an AxialRoPE2DFrequencyAwareConfig" ) if self.beta_warp is not None and not isinstance( self.beta_warp, AxialRoPE2DBetaWarpConfig ): raise TypeError( "AxialRoPE2DConfig.beta_warp must be an AxialRoPE2DBetaWarpConfig" ) if self.alpha_warp is not None and not isinstance( self.alpha_warp, AxialRoPE2DAlphaWarpConfig ): raise TypeError( "AxialRoPE2DConfig.alpha_warp must be an AxialRoPE2DAlphaWarpConfig" ) warp_count = ( int(self.frequency_aware is not None) + int(self.beta_warp is not None) + int(self.alpha_warp is not None) ) if warp_count > 1: raise ValueError( "AxialRoPE2DConfig requires at most one of frequency_aware, beta_warp, or alpha_warp" ) if self.frequency_aware is not None and ( self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES ): raise ValueError( "AxialRoPE2D frequency-aware warping requires coord_mode=PATCH_INDICES" ) if self.beta_warp is not None and ( self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES ): raise ValueError("AxialRoPE2D beta warp requires coord_mode=PATCH_INDICES") if self.alpha_warp is not None and ( self.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES ): raise ValueError("AxialRoPE2D alpha warp requires coord_mode=PATCH_INDICES") _AXIAL_COORDS_CACHE: dict[ tuple[ int, int, torch.device, AxialRoPE2DCoordMode, AxialRoPE2DNormalizeCoords, float ], Tensor, ] = {} def _get_dinov3_normalized_coords( H: int, W: int, *, device: torch.device, normalize: AxialRoPE2DNormalizeCoords, offset: float, ) -> Tensor: """Return DINOv3-style flattened coords in [-1, 1] with shape [HW, 2].""" if H <= 0 or W <= 0: raise ValueError("H and W must be positive for axial RoPE coords") key = ( int(H), int(W), device, AxialRoPE2DCoordMode.DINOV3_NORMALIZED, normalize, float(offset), ) cached = _AXIAL_COORDS_CACHE.get(key) if cached is not None: return cached start = float(offset) end_h = start + float(int(H)) end_w = start + float(int(W)) match normalize: case AxialRoPE2DNormalizeCoords.MAX: denom = float(max(int(H), int(W))) coords_h = ( torch.arange(start, end_h, device=device, dtype=torch.float32) / denom ) coords_w = ( torch.arange(start, end_w, device=device, dtype=torch.float32) / denom ) case AxialRoPE2DNormalizeCoords.MIN: denom = float(min(int(H), int(W))) coords_h = ( torch.arange(start, end_h, device=device, dtype=torch.float32) / denom ) coords_w = ( torch.arange(start, end_w, device=device, dtype=torch.float32) / denom ) case AxialRoPE2DNormalizeCoords.SEPARATE: coords_h = torch.arange( start, end_h, device=device, dtype=torch.float32 ) / float(int(H)) coords_w = torch.arange( start, end_w, device=device, dtype=torch.float32 ) / float(int(W)) case _ as unreachable: # pragma: no cover - defensive raise RuntimeError(f"Unsupported normalize_coords: {unreachable}") coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) coords = coords.flatten(0, 1) coords = 2.0 * coords - 1.0 # torch.compile cannot trace `torch.is_inference_mode_enabled()` and should # not record Python-side cache mutations in the graph. if torch.compiler.is_compiling(): return coords if torch.is_inference_mode_enabled(): return coords _AXIAL_COORDS_CACHE[key] = coords return coords def _get_patch_index_coords( H: int, W: int, *, device: torch.device, offset: float, ) -> Tensor: """Return unnormalized patch-grid coords with shape [HW, 2] and (y, x) columns.""" if H <= 0 or W <= 0: raise ValueError("H and W must be positive for axial RoPE coords") key = ( int(H), int(W), device, AxialRoPE2DCoordMode.PATCH_INDICES, AxialRoPE2DNormalizeCoords.MAX, float(offset), ) cached = _AXIAL_COORDS_CACHE.get(key) if cached is not None: return cached start = float(offset) end_h = start + float(int(H)) end_w = start + float(int(W)) coords_h = torch.arange(start, end_h, device=device, dtype=torch.float32) coords_w = torch.arange(start, end_w, device=device, dtype=torch.float32) coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) coords = coords.flatten(0, 1) if torch.compiler.is_compiling(): return coords if torch.is_inference_mode_enabled(): return coords _AXIAL_COORDS_CACHE[key] = coords return coords def _lumina_boundary_band_index( *, periods: Tensor, boundary_wavelength: Tensor, ) -> Tensor: """Return the fractional boundary band index d* for a given boundary wavelength. This implements the Lumina/Next-DiT definition: period(d*) = boundary_wavelength We compute d* by linear interpolation in log-period space. For the supported period parameterizations, periods are geometric and log(period) is linear in band index. Args: periods: 1D float tensor of length Q containing monotonically increasing periods in tokens. boundary_wavelength: Scalar positive float tensor giving the desired boundary wavelength in tokens. Returns: Scalar float32 tensor giving the (possibly fractional) boundary index d*. Raises: ValueError: If periods are invalid or the boundary is outside valid range for a well-defined positive d*. """ if periods.dim() != 1: raise ValueError("periods must be 1D for boundary band index") if int(periods.numel()) < 2: raise ValueError("periods must have length >= 2 for boundary band index") if boundary_wavelength.dim() != 0: raise ValueError("boundary_wavelength must be a scalar tensor") if not torch.isfinite(boundary_wavelength).item(): raise ValueError("boundary_wavelength must be finite") if float(boundary_wavelength.item()) <= 0.0: raise ValueError("boundary_wavelength must be > 0") periods_f = periods.to(dtype=torch.float32) if not torch.isfinite(periods_f).all().item(): raise ValueError("periods must be finite for boundary band index") if float(periods_f[0].item()) <= 0.0: raise ValueError("periods must be positive for boundary band index") if not (periods_f[1:] > periods_f[:-1]).all().item(): raise ValueError("periods must be strictly increasing for boundary band index") log_p0 = torch.log(periods_f[0]) log_p1 = torch.log(periods_f[-1]) denom = log_p1 - log_p0 if float(denom.item()) <= 0.0: raise ValueError("Invalid periods range for boundary band index") log_boundary = torch.log(boundary_wavelength.to(dtype=torch.float32)) q = int(periods_f.numel()) d_star = (float(q - 1) * (log_boundary - log_p0)) / denom if not torch.isfinite(d_star).item(): raise ValueError("Computed non-finite boundary band index d*") if float(d_star.item()) <= 0.0: raise ValueError( "Boundary wavelength implies d* <= 0; increase the boundary wavelength " "(or its multiplier) to be >= the wavelength of the first non-zero band." ) return d_star def _lumina_alpha_ramp( *, qtr: int, d_star: Tensor, device: torch.device, ) -> Tensor: """Return alpha[d] = clamp(d / d*, 0, 1) for d in [0, qtr). Args: qtr: Number of RoPE bands per axis (Q). d_star: Scalar positive float tensor boundary index d*. device: Device for the returned alpha tensor. Returns: Float32 tensor of shape [Q] with values in [0, 1]. """ if int(qtr) <= 0: raise ValueError("qtr must be positive for alpha ramp") if d_star.dim() != 0: raise ValueError("d_star must be a scalar tensor for alpha ramp") if float(d_star.item()) <= 0.0: raise ValueError("d_star must be > 0 for alpha ramp") d = torch.arange(int(qtr), device=device, dtype=torch.float32) alpha = d / d_star.to(device=device, dtype=torch.float32) return torch.clamp(alpha, min=0.0, max=1.0) def lumina_frequency_aware_periods_for_axis( *, periods: Tensor, axis_len: int, ref_axis_len: int, boundary_log_multiplier: Tensor, angle_multiplier: float, ) -> Tensor: """Return Lumina/Next-DiT frequency-aware warped periods for one axis. Implements: s = axis_len / ref_axis_len L_boundary = ref_axis_len * exp(boundary_log_multiplier) d* = boundary band index where period(d*) = L_boundary alpha[d] = clamp(d / d*, 0, 1) period'[d] = period[d] * s**alpha[d] Notes on ``angle_multiplier`` ----------------------------- This module parameterizes angles as: angle(p, d) = angle_multiplier * p / period[d] The *wavelength* (period in tokens) is the delta in ``p`` that increases the angle by ``2π``: wavelength[d] = 2π * period[d] / angle_multiplier Lumina/Next-DiT define the boundary by matching *wavelength* to the reference axis length. We therefore convert the boundary wavelength ``L_boundary`` into a boundary period via: period_boundary = (angle_multiplier / 2π) * L_boundary When ``angle_multiplier == 2π`` (the DINOv3-style parameterization), this reduces to ``period_boundary == L_boundary``. Args: periods: Base periods ``[Q]`` in tokens (wavelengths). axis_len: Input axis length ``L`` in tokens. ref_axis_len: Reference axis length ``L_ref`` in tokens. boundary_log_multiplier: Scalar tensor; shared trainable log-multiplier. angle_multiplier: RoPE angle multiplier used when converting periods to physical wavelengths in tokens. Returns: Warped periods ``[Q]`` as float32. Raises: ValueError: If inputs are malformed or imply an invalid boundary index. """ if int(axis_len) <= 0: raise ValueError("axis_len must be positive for frequency-aware periods") if int(ref_axis_len) <= 0: raise ValueError("ref_axis_len must be positive for frequency-aware periods") if boundary_log_multiplier.dim() != 0: raise ValueError("boundary_log_multiplier must be a scalar tensor") if not torch.isfinite(boundary_log_multiplier).item(): raise ValueError("boundary_log_multiplier must be finite") mult = float(angle_multiplier) if not math.isfinite(mult) or mult <= 0.0: raise ValueError("angle_multiplier must be finite and > 0") device = periods.device qtr = int(periods.numel()) s = float(int(axis_len)) / float(int(ref_axis_len)) if not math.isfinite(s) or s <= 0.0: raise ValueError("axis_len/ref_axis_len must be finite and > 0") boundary_wavelength = float(int(ref_axis_len)) * torch.exp( boundary_log_multiplier.to(device=device, dtype=torch.float32) ) boundary_period = (mult / (2.0 * float(math.pi))) * boundary_wavelength d_star = _lumina_boundary_band_index( periods=periods, boundary_wavelength=boundary_period ) alpha = _lumina_alpha_ramp(qtr=qtr, d_star=d_star, device=device) scale = torch.pow(torch.tensor(s, device=device, dtype=torch.float32), alpha) return periods.to(device=device, dtype=torch.float32) * scale def build_axial_rope2d_with_lumina_frequency_warp( base: AxialRoPE2D, *, ref_h_tokens: int, ref_w_tokens: int, boundary_log_multiplier: float | None, boundary_band_multiplier: float | None, ) -> AxialRoPE2D: """Return an AxialRoPE2D module that applies Lumina-style frequency warping. This helper is intended for inference-time experimentation on checkpoints that were trained without frequency-aware warping (e.g. ``position_encoding=ROPE_2D_AXIAL_UNNORMALIZED``). It constructs a new AxialRoPE2D instance that: - Keeps the base RoPE periods and layout identical to ``base``. - Applies Lumina/Next-DiT per-axis warping based on the runtime token lengths ``H`` and ``W`` relative to reference lengths. - Uses a fixed (non-trainable) scalar boundary multiplier for inference. Args: base: Existing AxialRoPE2D instance from a loaded model. ref_h_tokens: Reference H token length (L_ref,h). ref_w_tokens: Reference W token length (L_ref,w). boundary_log_multiplier: Optional log multiplier applied to reference lengths to define the boundary wavelength. Use 0.0 for "boundary at L_ref". Mutually exclusive with boundary_band_multiplier. boundary_band_multiplier: Optional multiplier that directly selects the boundary band index d* relative to the lowest-frequency band index (qtr-1). Concretely, with qtr bands per axis: d* = boundary_band_multiplier * (qtr - 1) This lets you move the transition point in frequency space: - smaller values => more bands become PI-like (more interpolation) - larger values => fewer bands become PI-like (more extrapolation) When provided, we compute the implied boundary wavelength and store it as boundary_log_multiplier for the module. Returns: New AxialRoPE2D instance on the same device as ``base``. Raises: TypeError: If base is not an AxialRoPE2D. ValueError: If base uses incompatible coordinates for Lumina warping. """ if not isinstance(base, AxialRoPE2D): raise TypeError("base must be an AxialRoPE2D") if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES: raise ValueError( "Lumina frequency-aware warping requires coord_mode=PATCH_INDICES" ) if (boundary_log_multiplier is None) == (boundary_band_multiplier is None): raise ValueError( "Provide exactly one of boundary_log_multiplier or boundary_band_multiplier" ) resolved_log_multiplier: float if boundary_band_multiplier is not None: if int(ref_h_tokens) != int(ref_w_tokens): raise ValueError( "boundary_band_multiplier requires ref_h_tokens == ref_w_tokens when using a shared scalar boundary" ) mult = float(boundary_band_multiplier) if not math.isfinite(mult) or mult <= 0.0: raise ValueError("boundary_band_multiplier must be finite and > 0") qtr = int(base.periods.numel()) if qtr < 2: raise ValueError( "AxialRoPE2D periods length must be >= 2 for boundary band selection" ) # Solve for the boundary wavelength implied by choosing d* directly. # # We use geometric interpolation in log-period space: # log(period(d*)) = log(period0) + (d*/(qtr-1)) * (log(period_max) - log(period0)) # with: # d* = boundary_band_multiplier * (qtr-1) # # This allows d* outside the trained band range (multiplier > 1), which # corresponds to pushing the transition beyond the lowest-frequency band. with torch.no_grad(): periods_f = base.periods.to(dtype=torch.float32, device=torch.device("cpu")) if not (periods_f[1:] > periods_f[:-1]).all().item(): raise ValueError( "base.periods must be strictly increasing for boundary band selection" ) log_p0 = float(torch.log(periods_f[0]).item()) log_p1 = float(torch.log(periods_f[-1]).item()) d_star = mult * float(qtr - 1) log_boundary_period = log_p0 + (d_star / float(qtr - 1)) * (log_p1 - log_p0) boundary_period = math.exp(log_boundary_period) angle_mult = float(base.cfg.angle_multiplier) if not math.isfinite(angle_mult) or angle_mult <= 0.0: raise ValueError("base.cfg.angle_multiplier must be finite and > 0") boundary_wavelength = (2.0 * float(math.pi) / angle_mult) * boundary_period resolved_log_multiplier = math.log( boundary_wavelength / float(int(ref_h_tokens)) ) else: if boundary_log_multiplier is None: # pragma: no cover - validated above raise RuntimeError("boundary_log_multiplier missing despite validation") resolved_log_multiplier = float(boundary_log_multiplier) freq_cfg = AxialRoPE2DFrequencyAwareConfig( ref_h_tokens=int(ref_h_tokens), ref_w_tokens=int(ref_w_tokens), boundary_log_multiplier_init=resolved_log_multiplier, ) cfg = replace(base.cfg, frequency_aware=freq_cfg, beta_warp=None, alpha_warp=None) device = base.periods.device warped = AxialRoPE2D(head_dim=int(base.head_dim), cfg=cfg).to(device=device) with torch.no_grad(): warped.periods.copy_(base.periods.to(device=device, dtype=torch.float32)) if warped.boundary_log_multiplier is None: # pragma: no cover - defensive raise RuntimeError("Expected boundary_log_multiplier to be initialized") warped.boundary_log_multiplier.copy_( torch.tensor(resolved_log_multiplier, device=device, dtype=torch.float32) ) warped.boundary_log_multiplier.requires_grad_(False) return warped def build_axial_rope2d_inference_warp_with_strength( base: AxialRoPE2D, *, ref_h_tokens: int, ref_w_tokens: int, beta_hi_u: float, beta_lo_u: float, beta_bend_u: float, beta_max: float, ) -> AxialRoPE2D: """Build an inference-only RoPE warp parameterized by a 3-knob beta(t) curve. This helper is meant for notebook experimentation on checkpoints trained with patch-index axial RoPE (e.g. ``position_encoding=ROPE_2D_AXIAL_UNNORMALIZED``). We warp per-axis RoPE periods (wavelengths, in tokens) as: period'[d] = period[d] * s ** beta[d] where s = L / L_ref with a smooth exponent curve beta(d) over bands. Unlike a strict interpolation-only exponent (0..1), beta is allowed to be negative or > 1, which is important for unnormalized RoPE (e.g. base=10_000) where some very low-frequency bands are effectively "dead" on practical token grids unless their frequencies can be increased (beta < 0). Knobs (bounded via u-space) --------------------------- We use three unconstrained parameters (u-space) which map to bounded beta values via tanh: beta_hi = beta_max * tanh(beta_hi_u) (high-frequency endpoint, d=0) beta_lo = beta_max * tanh(beta_lo_u) (low-frequency endpoint, d=qtr-1) beta_bend = beta_max * tanh(beta_bend_u) ("bump" amplitude in the middle) Then define the per-band curve over t in [0,1] (high -> low frequency): t = d / (qtr - 1) beta(t) = lerp(beta_hi, beta_lo, t) + beta_bend * 4*t*(1-t) The bump term is 0 at the endpoints and peaks at 1 at t=0.5. Notes ----- - This wrapper is inference-only: it is not saved in checkpoints. - It requires patch-index coordinates (no normalized "gauge"). - It preserves the base module's periods and layout exactly. Args: base: Existing AxialRoPE2D instance from a loaded model. ref_h_tokens: Reference H token length (L_ref,h). ref_w_tokens: Reference W token length (L_ref,w). beta_hi_u: Unconstrained u for beta_hi. beta_lo_u: Unconstrained u for beta_lo. beta_bend_u: Unconstrained u for beta_bend (mid-band bump). beta_max: Maximum absolute beta value (> 0). Higher increases control. Returns: An AxialRoPE2D instance whose forward applies the inference-only warp. """ if not isinstance(base, AxialRoPE2D): raise TypeError("base must be an AxialRoPE2D") if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES: raise ValueError( "Inference freq-warp requires base.cfg.coord_mode=PATCH_INDICES" ) if int(ref_h_tokens) <= 0 or int(ref_w_tokens) <= 0: raise ValueError("ref_h_tokens and ref_w_tokens must be positive") hi_u = float(beta_hi_u) lo_u = float(beta_lo_u) bend_u = float(beta_bend_u) if not math.isfinite(hi_u): raise ValueError("beta_hi_u must be finite") if not math.isfinite(lo_u): raise ValueError("beta_lo_u must be finite") if not math.isfinite(bend_u): raise ValueError("beta_bend_u must be finite") bmax = float(beta_max) if not math.isfinite(bmax) or bmax <= 0.0: raise ValueError("beta_max must be finite and > 0") class _AxialRoPE2DInferenceWarp(AxialRoPE2D): """Inference-only axial RoPE variant with beta-curve knobs.""" def __init__(self, *, device: torch.device) -> None: super().__init__(head_dim=int(base.head_dim), cfg=base.cfg) self.ref_h_tokens: Final[int] = int(ref_h_tokens) self.ref_w_tokens: Final[int] = int(ref_w_tokens) # Store as buffers so the notebook can mutate by replacing the module. self.register_buffer( "beta_hi_u", torch.tensor(float(hi_u), dtype=torch.float32), persistent=False, ) self.register_buffer( "beta_lo_u", torch.tensor(float(lo_u), dtype=torch.float32), persistent=False, ) self.register_buffer( "beta_bend_u", torch.tensor(float(bend_u), dtype=torch.float32), persistent=False, ) self.register_buffer( "beta_max", torch.tensor(float(bmax), dtype=torch.float32), persistent=False, ) self.to(device=device) with torch.no_grad(): self.periods.copy_( base.periods.detach().to(device=device, dtype=torch.float32) ) def forward( self, *, H: int, W: int, scales: Tensor | None, ) -> tuple[Tensor, Tensor]: if scales is not None: raise ValueError("Inference freq-warp does not support dilation scales") if int(H) <= 0 or int(W) <= 0: raise ValueError("H and W must be positive for axial RoPE") device = self.periods.device offset = float(self.cfg.coord_offset) coords = _get_patch_index_coords( int(H), int(W), device=device, offset=offset ) if coords.dim() != 2 or coords.shape[1] != 2: raise RuntimeError("Axial RoPE coords must have shape [HW, 2]") qtr = int(self.periods.numel()) if qtr <= 0: raise RuntimeError("Axial RoPE periods length must be positive") beta_max_t = cast("Tensor", self.beta_max).to( device=device, dtype=torch.float32 ) beta_hi = beta_max_t * torch.tanh( cast("Tensor", self.beta_hi_u).to(device=device, dtype=torch.float32) ) beta_lo = beta_max_t * torch.tanh( cast("Tensor", self.beta_lo_u).to(device=device, dtype=torch.float32) ) beta_bend = beta_max_t * torch.tanh( cast("Tensor", self.beta_bend_u).to(device=device, dtype=torch.float32) ) if qtr == 1: beta = beta_hi[None] else: t = torch.arange(int(qtr), device=device, dtype=torch.float32) / float( qtr - 1 ) bump = 4.0 * t * (1.0 - t) beta = (1.0 - t) * beta_hi + t * beta_lo + beta_bend * bump s_h = float(int(H)) / float(int(self.ref_h_tokens)) s_w = float(int(W)) / float(int(self.ref_w_tokens)) if ( not math.isfinite(s_h) or s_h <= 0.0 or not math.isfinite(s_w) or s_w <= 0.0 ): raise ValueError( "H/ref_h_tokens and W/ref_w_tokens must be finite and > 0" ) periods_h = self.periods * torch.pow( torch.tensor(s_h, device=device, dtype=torch.float32), beta ) periods_w = self.periods * torch.pow( torch.tensor(s_w, device=device, dtype=torch.float32), beta ) axis_periods = torch.stack([periods_h, periods_w], dim=0) # [2, Q] angles = ( float(self.cfg.angle_multiplier) * coords[:, :, None].to(dtype=torch.float32) / axis_periods[None, :, :].to(dtype=torch.float32) ) match self.cfg.dim_layout: case AxialRoPE2DDimLayout.HALF_SPLIT: angles = angles.flatten(1, 2).repeat(1, 2) case AxialRoPE2DDimLayout.PAIR_INTERLEAVED: angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2) case _ as unreachable: # pragma: no cover - defensive raise RuntimeError(f"Unsupported dim_layout: {unreachable}") if angles.shape != (int(H) * int(W), int(self.head_dim)): raise RuntimeError( "Unexpected angles shape in inference freq-warp: " f"{tuple(angles.shape)} for H={int(H)} W={int(W)}" ) return torch.sin(angles), torch.cos(angles) return _AxialRoPE2DInferenceWarp(device=base.periods.device) class AxialRoPE2D(nn.Module): """DINOv3-style axial 2D RoPE sin/cos generator. The base periods are fixed by ``AxialRoPE2DConfig``. Optionally, this module can include learnable scalar parameters when using: - ``frequency_aware`` (boundary_log_multiplier), or - ``beta_warp`` (beta_hi_u/beta_lo_u/beta_bend_u), or - ``alpha_warp`` (alpha per-band exponents). """ periods: Tensor def __init__(self, *, head_dim: int, cfg: AxialRoPE2DConfig) -> None: super().__init__() if int(head_dim) <= 0: raise ValueError("head_dim must be positive for AxialRoPE2D") if int(head_dim) % 4 != 0: raise ValueError( "AxialRoPE2D requires head_dim % 4 == 0 (DINOv3 constraint); " f"got head_dim={int(head_dim)}" ) if not isinstance(cfg, AxialRoPE2DConfig): raise TypeError("cfg must be an AxialRoPE2DConfig for AxialRoPE2D") self.head_dim: Final[int] = int(head_dim) self.cfg: Final[AxialRoPE2DConfig] = cfg self._d_head: Final[int] = self.head_dim self.register_buffer( "periods", torch.empty(self._d_head // 4, dtype=torch.float32), persistent=True, ) if cfg.frequency_aware is None: self.register_parameter("boundary_log_multiplier", None) else: init = float(cfg.frequency_aware.boundary_log_multiplier_init) self.boundary_log_multiplier = nn.Parameter( torch.tensor(init, dtype=torch.float32), requires_grad=True, ) if cfg.beta_warp is None: self.register_parameter("beta_hi_u", None) self.register_parameter("beta_lo_u", None) self.register_parameter("beta_bend_u", None) else: beta = cfg.beta_warp self.beta_hi_u = nn.Parameter( torch.tensor(float(beta.beta_hi_u_init), dtype=torch.float32), requires_grad=True, ) self.beta_lo_u = nn.Parameter( torch.tensor(float(beta.beta_lo_u_init), dtype=torch.float32), requires_grad=True, ) self.beta_bend_u = nn.Parameter( torch.tensor(float(beta.beta_bend_u_init), dtype=torch.float32), requires_grad=True, ) if cfg.alpha_warp is None: self.register_parameter("alpha", None) else: qtr = int(self._d_head) // 4 if qtr <= 0: # pragma: no cover - defensive raise RuntimeError("AxialRoPE2D periods length must be positive") init = float(cfg.alpha_warp.alpha_init) if not math.isfinite(init): raise RuntimeError("alpha_init must be finite for alpha-warp RoPE") self.alpha = nn.Parameter( torch.full((int(qtr),), init, dtype=torch.float32), requires_grad=True, ) self._init_periods() def _apply(self, fn): # type: ignore[override] out = super()._apply(fn) with torch.no_grad(): self.periods.data = self.periods.data.to(dtype=torch.float32) if self.boundary_log_multiplier is not None: self.boundary_log_multiplier.data = ( self.boundary_log_multiplier.data.to(dtype=torch.float32) ) if self.beta_hi_u is not None: self.beta_hi_u.data = self.beta_hi_u.data.to(dtype=torch.float32) if self.beta_lo_u is not None: self.beta_lo_u.data = self.beta_lo_u.data.to(dtype=torch.float32) if self.beta_bend_u is not None: self.beta_bend_u.data = self.beta_bend_u.data.to(dtype=torch.float32) if self.alpha is not None: self.alpha.data = self.alpha.data.to(dtype=torch.float32) return out def _init_periods(self) -> None: """Initialize per-dimension periods using DINOv3 formulas.""" device: torch.device = self.periods.device dtype: torch.dtype = self.periods.dtype d_head = int(self._d_head) qtr = d_head // 4 if qtr <= 0: raise RuntimeError("AxialRoPE2D periods length must be positive") if self.cfg.base is not None: base = float(self.cfg.base) exponents = ( 2.0 * torch.arange(int(qtr), device=device, dtype=dtype) / float(d_head // 2) ) periods = torch.tensor(base, device=device, dtype=dtype) ** exponents else: if self.cfg.min_period is None or self.cfg.max_period is None: raise RuntimeError( "AxialRoPE2DConfig must provide min_period and max_period when base is None" ) min_p = float(self.cfg.min_period) max_p = float(self.cfg.max_period) base = max_p / min_p exponents = torch.linspace(0.0, 1.0, int(qtr), device=device, dtype=dtype) periods = torch.tensor(base, device=device, dtype=dtype) ** exponents periods = periods / torch.tensor(base, device=device, dtype=dtype) periods = periods * torch.tensor(max_p, device=device, dtype=dtype) self.periods.data = periods def forward( self, *, H: int, W: int, scales: Tensor | None, ) -> tuple[Tensor, Tensor]: """Return (sin, cos) buffers for axial 2D RoPE. Args: H: Patch-grid height. W: Patch-grid width. scales: Optional per-batch dilation scale (scalar tensor). When None, returns shared sin/cos shaped ``[HW, head_dim]``. When provided, applies the scalar dilation and still returns shared sin/cos shaped ``[HW, head_dim]``. """ if int(H) <= 0 or int(W) <= 0: raise ValueError("H and W must be positive for AxialRoPE2D forward") device = self.periods.device offset = float(self.cfg.coord_offset) coords: Tensor match self.cfg.coord_mode: case AxialRoPE2DCoordMode.DINOV3_NORMALIZED: coords = _get_dinov3_normalized_coords( int(H), int(W), device=device, normalize=self.cfg.normalize_coords, offset=offset, ) case AxialRoPE2DCoordMode.PATCH_INDICES: coords = _get_patch_index_coords( int(H), int(W), device=device, offset=offset ) case _ as unreachable: # pragma: no cover - defensive raise RuntimeError(f"Unsupported coord_mode: {unreachable}") if coords.dim() != 2 or coords.shape[1] != 2: raise RuntimeError("AxialRoPE2D coords must have shape [HW, 2]") if self.cfg.frequency_aware is not None: if scales is not None: raise ValueError( "frequency-aware axial RoPE does not support dilation scales" ) if self.boundary_log_multiplier is None: raise RuntimeError( "boundary_log_multiplier parameter missing for frequency-aware RoPE" ) ref_h = int(self.cfg.frequency_aware.ref_h_tokens) ref_w = int(self.cfg.frequency_aware.ref_w_tokens) periods_h = lumina_frequency_aware_periods_for_axis( periods=self.periods, axis_len=int(H), ref_axis_len=ref_h, boundary_log_multiplier=self.boundary_log_multiplier, angle_multiplier=float(self.cfg.angle_multiplier), ) periods_w = lumina_frequency_aware_periods_for_axis( periods=self.periods, axis_len=int(W), ref_axis_len=ref_w, boundary_log_multiplier=self.boundary_log_multiplier, angle_multiplier=float(self.cfg.angle_multiplier), ) axis_periods = torch.stack([periods_h, periods_w], dim=0) # [2, Q] elif self.cfg.beta_warp is not None: if scales is not None: raise ValueError( "beta-warp axial RoPE does not support dilation scales" ) if ( self.beta_hi_u is None or self.beta_lo_u is None or self.beta_bend_u is None ): raise RuntimeError("beta warp parameters missing for beta-warp RoPE") beta_cfg = self.cfg.beta_warp ref_h = int(beta_cfg.ref_h_tokens) ref_w = int(beta_cfg.ref_w_tokens) qtr = int(self.periods.numel()) if qtr <= 0: # pragma: no cover - defensive (checked elsewhere) raise RuntimeError("AxialRoPE2D periods length must be positive") beta_max = float(beta_cfg.beta_max) if not math.isfinite(beta_max) or beta_max <= 0.0: raise RuntimeError("beta_max must be finite and > 0") beta_max_t = torch.tensor(beta_max, device=device, dtype=torch.float32) beta_hi = beta_max_t * torch.tanh(self.beta_hi_u.to(dtype=torch.float32)) beta_lo = beta_max_t * torch.tanh(self.beta_lo_u.to(dtype=torch.float32)) beta_bend = beta_max_t * torch.tanh( self.beta_bend_u.to(dtype=torch.float32) ) if qtr == 1: beta = beta_hi[None] else: t = torch.arange(int(qtr), device=device, dtype=torch.float32) / float( qtr - 1 ) bump = 4.0 * t * (1.0 - t) beta = (1.0 - t) * beta_hi + t * beta_lo + beta_bend * bump s_h = float(int(H)) / float(ref_h) s_w = float(int(W)) / float(ref_w) if ( not math.isfinite(s_h) or s_h <= 0.0 or not math.isfinite(s_w) or s_w <= 0.0 ): raise RuntimeError( "Computed invalid axis scale factors for beta-warp RoPE" ) periods_h = self.periods.to(dtype=torch.float32) * torch.pow( torch.tensor(s_h, device=device, dtype=torch.float32), beta ) periods_w = self.periods.to(dtype=torch.float32) * torch.pow( torch.tensor(s_w, device=device, dtype=torch.float32), beta ) axis_periods = torch.stack([periods_h, periods_w], dim=0) # [2, Q] elif self.cfg.alpha_warp is not None: if scales is not None: raise ValueError( "alpha-warp axial RoPE does not support dilation scales" ) if self.alpha is None: raise RuntimeError("alpha parameter missing for alpha-warp RoPE") alpha_cfg = self.cfg.alpha_warp ref_h = int(alpha_cfg.ref_h_tokens) ref_w = int(alpha_cfg.ref_w_tokens) qtr = int(self.periods.numel()) if int(self.alpha.numel()) != qtr: raise RuntimeError( "alpha length must match RoPE periods length for alpha-warp RoPE" ) s_h = float(int(H)) / float(ref_h) s_w = float(int(W)) / float(ref_w) if ( not math.isfinite(s_h) or s_h <= 0.0 or not math.isfinite(s_w) or s_w <= 0.0 ): raise RuntimeError( "Computed invalid axis scale factors for alpha-warp RoPE" ) alpha = self.alpha.to(device=device, dtype=torch.float32) scale_h = torch.pow( torch.tensor(s_h, device=device, dtype=torch.float32), alpha ) scale_w = torch.pow( torch.tensor(s_w, device=device, dtype=torch.float32), alpha ) periods_h = self.periods.to(dtype=torch.float32) / scale_h periods_w = self.periods.to(dtype=torch.float32) / scale_w axis_periods = torch.stack([periods_h, periods_w], dim=0) # [2, Q] else: axis_periods = self.periods[None, :].expand(2, -1).to(dtype=torch.float32) # Angles: angle_multiplier * coords / periods, flattened and tiled. angles = ( float(self.cfg.angle_multiplier) * coords[:, :, None].to(dtype=torch.float32) / axis_periods[None, :, :].to(dtype=torch.float32) ) match self.cfg.dim_layout: case AxialRoPE2DDimLayout.HALF_SPLIT: angles = angles.flatten(1, 2).repeat(1, 2) case AxialRoPE2DDimLayout.PAIR_INTERLEAVED: angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2) case _ as unreachable: # pragma: no cover - defensive raise RuntimeError(f"Unsupported dim_layout: {unreachable}") if angles.shape != (int(H) * int(W), int(self._d_head)): raise RuntimeError( "Unexpected angles shape in AxialRoPE2D: " f"{tuple(angles.shape)} for H={int(H)} W={int(W)}" ) if scales is not None: if scales.dim() != 0: raise ValueError( "AxialRoPE2D scales must be a scalar tensor for per-batch dilation; " "per-sample dilation is not supported" ) angles = angles * scales.to(device=device, dtype=torch.float32) cos = torch.cos(angles) sin = torch.sin(angles) return sin, cos def _dy_ntk_periods_for_axis( *, periods: Tensor, axis_len: int, ref_axis_len: int, noise_time: Tensor, lambda_s: float, lambda_t: float, ) -> Tensor: """Return Dy-NTK periods for one spatial axis. Raises: ValueError: If token lengths or scheduler values are invalid. """ if int(axis_len) <= 0 or int(ref_axis_len) <= 0: raise ValueError("axis_len and ref_axis_len must be positive for Dy-NTK") qtr = int(periods.numel()) if qtr <= 0: raise ValueError("periods must be non-empty for Dy-NTK") scale = float(int(axis_len)) / float(int(ref_axis_len)) if not math.isfinite(scale) or scale <= 0.0: raise ValueError("Dy-NTK axis scale must be finite and > 0") return _dy_ntk_periods_for_scale( periods=periods, scale=scale, noise_time=noise_time, lambda_s=float(lambda_s), lambda_t=float(lambda_t), ) def _dy_ntk_periods_for_scale( *, periods: Tensor, scale: float, noise_time: Tensor, lambda_s: float, lambda_t: float, ) -> Tensor: """Return Dy-NTK periods for a precomputed axis scale.""" axis_scale = float(scale) if not math.isfinite(axis_scale) or axis_scale <= 0.0: raise ValueError("Dy-NTK scale must be finite and > 0") qtr = int(periods.numel()) if qtr <= 0: raise ValueError("periods must be non-empty for Dy-NTK") if scale <= 1.0: return periods.to(dtype=torch.float32) if qtr == 1: exponent = torch.zeros((1,), device=periods.device, dtype=torch.float32) else: exponent = torch.arange(qtr, device=periods.device, dtype=torch.float32) / ( float(qtr - 1) ) kappa = float(lambda_s) * torch.pow( noise_time.to(device=periods.device, dtype=torch.float32), float(lambda_t), ) return periods.to(dtype=torch.float32) * torch.pow( torch.tensor(axis_scale, device=periods.device, dtype=torch.float32), kappa * exponent, ) def _dype_dynamic_exponent( *, noise_time: float, lambda_s: float, lambda_t: float ) -> float: """Return Comfy/DyPE-style dynamic magnitude for normalized noise time.""" noise = float(noise_time) if not math.isfinite(noise): raise ValueError("DyPE noise_time must be finite") noise = max(0.0, min(1.0, noise)) scale = float(lambda_s) exponent = float(lambda_t) if not math.isfinite(scale) or scale <= 0.0: raise ValueError("DyPE lambda_s must be finite and > 0") if not math.isfinite(exponent) or exponent <= 0.0: raise ValueError("DyPE lambda_t must be finite and > 0") return scale * (noise**exponent) def _dype_correction_factor( *, periods: Tensor, rotations: float, ref_axis_len: int, angle_multiplier: float, ) -> float: """Return fractional band index whose wavelength makes ``rotations`` turns.""" if int(ref_axis_len) <= 0: raise ValueError("ref_axis_len must be positive for DyPE correction") rot = float(rotations) if not math.isfinite(rot) or rot <= 0.0: raise ValueError("rotations must be finite and > 0") mult = float(angle_multiplier) if not math.isfinite(mult) or mult <= 0.0: raise ValueError("angle_multiplier must be finite and > 0") if int(periods.numel()) < 2: return 0.0 periods_cpu = periods.detach().to(device=torch.device("cpu"), dtype=torch.float32) p0 = float(periods_cpu[0].item()) p1 = float(periods_cpu[-1].item()) if p0 <= 0.0 or p1 <= p0: raise ValueError("periods must be positive and strictly increasing for DyPE") boundary_wavelength = float(int(ref_axis_len)) / rot boundary_period = (mult / (2.0 * float(math.pi))) * boundary_wavelength log_p0 = math.log(p0) log_p1 = math.log(p1) return float(periods.numel() - 1) * ( (math.log(boundary_period) - log_p0) / (log_p1 - log_p0) ) def _dype_ramp_mask( *, periods: Tensor, threshold_high_rotations: float, threshold_low_rotations: float, ref_axis_len: int, angle_multiplier: float, ) -> Tensor: """Return YaRN's high-to-low band mask for one dynamic threshold pair.""" qtr = int(periods.numel()) if qtr <= 0: raise ValueError("periods must be non-empty for DyPE ramp mask") device = periods.device if qtr == 1: return torch.ones((1,), device=device, dtype=torch.float32) low = math.floor( _dype_correction_factor( periods=periods, rotations=float(threshold_high_rotations), ref_axis_len=int(ref_axis_len), angle_multiplier=float(angle_multiplier), ) ) high = math.ceil( _dype_correction_factor( periods=periods, rotations=float(threshold_low_rotations), ref_axis_len=int(ref_axis_len), angle_multiplier=float(angle_multiplier), ) ) low = max(0, min(qtr - 1, int(low))) high = max(0, min(qtr, int(high))) if low == high: high = min(qtr, low + 1) band = torch.arange(qtr, device=device, dtype=torch.float32) ramp = (band - float(low)) / float(high - low) return 1.0 - torch.clamp(ramp, min=0.0, max=1.0) def _dy_yarn_periods_for_axis( *, periods: Tensor, linear_scale: float, ntk_scale: float, ref_axis_len: int, noise_time: float, lambda_s: float, cfg: AxialRoPE2DDyPEConfig, angle_multiplier: float, ) -> Tensor: """Return Dy-YaRN periods for one spatial axis.""" if int(ref_axis_len) <= 0: raise ValueError("ref_axis_len must be positive for Dy-YaRN") linear_s = float(linear_scale) ntk_s = float(ntk_scale) if ( not math.isfinite(linear_s) or linear_s <= 0.0 or not math.isfinite(ntk_s) or ntk_s <= 0.0 ): raise ValueError("Dy-YaRN axis scales must be finite and > 0") periods_f = periods.to(dtype=torch.float32) if max(linear_s, ntk_s) <= 1.0: return periods_f kappa = _dype_dynamic_exponent( noise_time=float(noise_time), lambda_s=float(lambda_s), lambda_t=float(cfg.lambda_t), ) if kappa <= 1e-6: return periods_f freq_base = float(angle_multiplier) / periods_f freq_linear = float(angle_multiplier) / (periods_f * max(1.0, linear_s)) periods_ntk = _dy_ntk_periods_for_scale( periods=periods_f, scale=max(1.0, ntk_s), noise_time=torch.ones((), device=periods.device, dtype=torch.float32), lambda_s=1.0, lambda_t=1.0, ) freq_ntk = float(angle_multiplier) / periods_ntk beta_mask = _dype_ramp_mask( periods=periods_f, threshold_high_rotations=float(cfg.yarn_beta_0) ** kappa, threshold_low_rotations=float(cfg.yarn_beta_1) ** kappa, ref_axis_len=int(ref_axis_len), angle_multiplier=float(angle_multiplier), ) freq = freq_linear * (1.0 - beta_mask) + freq_ntk * beta_mask gamma_mask = _dype_ramp_mask( periods=periods_f, threshold_high_rotations=float(cfg.yarn_gamma_0) ** kappa, threshold_low_rotations=float(cfg.yarn_gamma_1) ** kappa, ref_axis_len=int(ref_axis_len), angle_multiplier=float(angle_multiplier), ) freq = freq * (1.0 - gamma_mask) + freq_base * gamma_mask return float(angle_multiplier) / freq class AxialRoPE2DDyPE(AxialRoPE2D): """Inference-only axial RoPE wrapper using dynamic position extrapolation.""" dype_cfg: AxialRoPE2DDyPEConfig dype_noise_time: Tensor dype_noise_time_values: list[float] def __init__(self, *, base: AxialRoPE2D, cfg: AxialRoPE2DDyPEConfig) -> None: if not isinstance(base, AxialRoPE2D): raise TypeError("base must be an AxialRoPE2D") if not isinstance(cfg, AxialRoPE2DDyPEConfig): raise TypeError("cfg must be an AxialRoPE2DDyPEConfig") if base.cfg.coord_mode is not AxialRoPE2DCoordMode.PATCH_INDICES: raise ValueError("DyPE requires patch-index axial RoPE coordinates") super().__init__(head_dim=int(base.head_dim), cfg=base.cfg) self.dype_cfg = cfg # ty: ignore[unresolved-attribute] self.register_buffer( "dype_noise_time", torch.tensor(1.0, dtype=torch.float32), persistent=False, ) self.dype_noise_time_values: list[float] = [1.0] with torch.no_grad(): self.periods.copy_(base.periods.detach().to(dtype=torch.float32)) def set_dype_noise_time(self, noise_time: float) -> None: """Set the current normalized diffusion noise time in ``[0, 1]``.""" t = float(noise_time) if not math.isfinite(t) or t < 0.0 or t > 1.0: raise ValueError("DyPE noise_time must be finite and within [0, 1]") self.dype_noise_time.fill_(t) self.dype_noise_time_values[0] = t def _dype_axis_periods( self, *, axis_len: int, ref_axis_len: int, global_scale: float, lambda_s: float, ) -> Tensor: """Return method-specific periods for one spatial axis.""" cfg = self.dype_cfg axis_scale = float(int(axis_len)) / float(int(ref_axis_len)) shared_scale = float(global_scale) if ( not math.isfinite(axis_scale) or axis_scale <= 0.0 or not math.isfinite(shared_scale) or shared_scale <= 0.0 ): raise ValueError("DyPE axis and global scales must be finite and > 0") match cfg.method: case DyPERoPEMethod.DY_NTK: return _dy_ntk_periods_for_scale( periods=self.periods, scale=shared_scale, noise_time=self.dype_noise_time, lambda_s=float(lambda_s), lambda_t=float(cfg.lambda_t), ) case DyPERoPEMethod.VISION_YARN: return _dy_yarn_periods_for_axis( periods=self.periods, linear_scale=axis_scale, ntk_scale=shared_scale, ref_axis_len=int(ref_axis_len), noise_time=float(self.dype_noise_time_values[0]), lambda_s=float(lambda_s), cfg=cfg, angle_multiplier=float(self.cfg.angle_multiplier), ) case DyPERoPEMethod.DY_YARN: return _dy_yarn_periods_for_axis( periods=self.periods, linear_scale=shared_scale, ntk_scale=shared_scale, ref_axis_len=int(ref_axis_len), noise_time=float(self.dype_noise_time_values[0]), lambda_s=float(lambda_s), cfg=cfg, angle_multiplier=float(self.cfg.angle_multiplier), ) case _ as unreachable: raise RuntimeError(f"Unsupported DyPE method: {unreachable}") def forward( self, *, H: int, W: int, scales: Tensor | None, ) -> tuple[Tensor, Tensor]: """Return timestep-aware DyPE sin/cos buffers.""" if scales is not None: raise ValueError("DyPE axial RoPE does not support dilation scales") if int(H) <= 0 or int(W) <= 0: raise ValueError("H and W must be positive for DyPE axial RoPE") device = self.periods.device coords = _get_patch_index_coords( int(H), int(W), device=device, offset=float(self.cfg.coord_offset) ) scale_h = float(int(H)) / float(int(self.dype_cfg.ref_h_tokens)) scale_w = float(int(W)) / float(int(self.dype_cfg.ref_w_tokens)) global_scale = max(scale_h, scale_w) periods_h = self._dype_axis_periods( axis_len=int(H), ref_axis_len=int(self.dype_cfg.ref_h_tokens), global_scale=global_scale, lambda_s=float(self.dype_cfg.lambda_s), ) periods_w = self._dype_axis_periods( axis_len=int(W), ref_axis_len=int(self.dype_cfg.ref_w_tokens), global_scale=global_scale, lambda_s=float(self.dype_cfg.lambda_s), ) axis_periods = torch.stack([periods_h, periods_w], dim=0) angles = ( float(self.cfg.angle_multiplier) * coords[:, :, None].to(dtype=torch.float32) / axis_periods[None, :, :].to(dtype=torch.float32) ) match self.cfg.dim_layout: case AxialRoPE2DDimLayout.HALF_SPLIT: angles = angles.flatten(1, 2).repeat(1, 2) case AxialRoPE2DDimLayout.PAIR_INTERLEAVED: angles = angles.repeat_interleave(2, dim=-1).flatten(1, 2) case _ as unreachable: raise RuntimeError(f"Unsupported dim_layout: {unreachable}") expected_shape = (int(H) * int(W), int(self.head_dim)) if angles.shape != expected_shape: raise RuntimeError( "Unexpected angles shape in DyPE axial RoPE: " f"{tuple(angles.shape)} for expected {expected_shape}" ) sin = torch.sin(angles) cos = torch.cos(angles) if ( self.dype_cfg.method in (DyPERoPEMethod.VISION_YARN, DyPERoPEMethod.DY_YARN) and bool(self.dype_cfg.yarn_attention_scale) and global_scale > 1.0 ): match self.dype_cfg.method: case DyPERoPEMethod.VISION_YARN: mscale_start = 0.1 * math.log(global_scale) + 1.0 kappa = _dype_dynamic_exponent( noise_time=float(self.dype_noise_time_values[0]), lambda_s=1.0, lambda_t=float(self.dype_cfg.lambda_t), ) mscale = 1.0 + (mscale_start - 1.0) * kappa case DyPERoPEMethod.DY_YARN: mscale = 1.0 + 0.1 * math.log(global_scale) / math.sqrt( global_scale ) case _ as unreachable: # pragma: no cover - guarded above raise RuntimeError( f"Unsupported YaRN attention scale: {unreachable}" ) if mscale > 1.0: sin = sin * float(mscale) cos = cos * float(mscale) return sin, cos def build_axial_rope2d_dype( *, base: AxialRoPE2D, cfg: AxialRoPE2DDyPEConfig ) -> AxialRoPE2DDyPE: """Build an inference-only DyPE wrapper for an existing axial RoPE.""" return AxialRoPE2DDyPE(base=base, cfg=cfg).to(device=base.periods.device) def set_axial_rope2d_dype_noise_time(module: nn.Module, *, noise_time: float) -> bool: """Set DyPE noise time on all axial DyPE modules inside ``module``.""" updated = False for child in module.modules(): match child: case AxialRoPE2DDyPE() as dype: dype.set_dype_noise_time(float(noise_time)) updated = True case _: pass return updated