# Copyright (c) Meta Platforms, Inc. and affiliates. # # This software may be used and distributed in accordance with # the terms of the DINOv3 License Agreement. import math from typing import Literal import numpy as np import torch from torch import Tensor, nn # RoPE positional embedding with no mixing of coordinates (axial) and no learnable weights # Supports two parametrizations of the rope parameters: either using `base` or `min_period` and `max_period`. class RopePositionEmbedding(nn.Module): def __init__( self, embed_dim: int, *, num_heads: int, base: float | None = 100.0, min_period: float | None = None, max_period: float | None = None, normalize_coords: Literal["min", "max", "separate"] = "separate", shift_coords: float | None = None, jitter_coords: float | None = None, rescale_coords: float | None = None, dtype: torch.dtype | None = None, device: torch.device | None = None, ): super().__init__() assert embed_dim % (4 * num_heads) == 0 both_periods = min_period is not None and max_period is not None if (base is None and not both_periods) or (base is not None and both_periods): raise ValueError( "Either `base` or `min_period`+`max_period` must be provided." ) D_head = embed_dim // num_heads self.base = base self.min_period = min_period self.max_period = max_period self.D_head = D_head self.normalize_coords = normalize_coords self.shift_coords = shift_coords self.jitter_coords = jitter_coords self.rescale_coords = rescale_coords # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher self.dtype = dtype # Don't rely on self.periods.dtype self.register_buffer( "periods", torch.empty(D_head // 4, device=device, dtype=dtype), persistent=True, ) self._init_weights() def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]: device = self.periods.device dtype = self.dtype dd = {"device": device, "dtype": dtype} # Prepare coords in range [-1, +1] if self.normalize_coords == "max": max_HW = max(H, W) coords_h = torch.arange(0.5, H, **dd) / max_HW # [H] coords_w = torch.arange(0.5, W, **dd) / max_HW # [W] elif self.normalize_coords == "min": min_HW = min(H, W) coords_h = torch.arange(0.5, H, **dd) / min_HW # [H] coords_w = torch.arange(0.5, W, **dd) / min_HW # [W] elif self.normalize_coords == "separate": coords_h = torch.arange(0.5, H, **dd) / H # [H] coords_w = torch.arange(0.5, W, **dd) / W # [W] else: raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") coords = torch.stack( torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1 ) # [H, W, 2] coords = coords.flatten(0, 1) # [HW, 2] coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1] # Shift coords by adding a uniform value in [-shift, shift] if self.training and self.shift_coords is not None: shift_hw = torch.empty(2, **dd).uniform_( -self.shift_coords, self.shift_coords ) coords += shift_hw[None, :] # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] if self.training and self.jitter_coords is not None: jitter_max = np.log(self.jitter_coords) jitter_min = -jitter_max jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp() coords *= jitter_hw[None, :] # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale] if self.training and self.rescale_coords is not None: rescale_max = np.log(self.rescale_coords) rescale_min = -rescale_max rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp() coords *= rescale_hw # Prepare angles and sin/cos angles = ( 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] ) # [HW, 2, D//4] angles = angles.flatten(1, 2) # [HW, D//2] angles = angles.tile(2) # [HW, D] cos = torch.cos(angles) # [HW, D] sin = torch.sin(angles) # [HW, D] return sin, cos # 2 * [HW, D] def _init_weights(self): device = self.periods.device dtype = self.dtype if self.base is not None: periods = self.base ** ( 2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2) ) # [D//4] else: base = self.max_period / self.min_period exponents = torch.linspace( 0, 1, self.D_head // 4, device=device, dtype=dtype ) # [D//4] range [0, 1] periods = base**exponents # range [1, max_period / min_period] periods = periods / base # range [min_period / max_period, 1] periods = periods * self.max_period # range [min_period, max_period] self.periods.data = periods if __name__ == "__main__": import torch import numpy as np import matplotlib.pyplot as plt def get_rope_values(H, W, embed_dim, num_heads, base): # Setup parameters similar to Repo 1 D_head = embed_dim // num_heads print(D_head // 4, D_head // 2, (D_head // 4) / (D_head // 2)) # We'll pick the first period (the "fastest" one) period = base ** (2 * torch.arange(D_head // 4) / (D_head // 2)) period = period[3] # First period # Normalized coordinates as per Repo 1 coords_h = torch.arange(0.5, H) / H coords_w = torch.arange(0.5, W) / W grid_h, grid_w = torch.meshgrid(coords_h, coords_w, indexing="ij") # Convert to [-1, 1] grid_h = 2.0 * grid_h - 1.0 grid_w = 2.0 * grid_w - 1.0 # Calculate Sine value (using H-coordinate for visualization) # Formula: sin(2 * pi * coord / period) vals = torch.sin(2 * np.pi * grid_h / period) return vals.numpy() # Settings embed_dim = 768 num_heads = 12 bases = [100, 10000] sizes = [14, 28] fig, axes = plt.subplots(2, 2, figsize=(12, 10)) for i, base in enumerate(bases): for j, size in enumerate(sizes): vals = get_rope_values(size, size, embed_dim, num_heads, base) ax = axes[i, j] im = ax.imshow(vals, cmap="RdBu", extent=[-1, 1, -1, 1]) ax.set_title(f"Base: {base} | Grid: {size}x{size}") ax.set_xlabel("Width (Normalized)") ax.set_ylabel("Height (Normalized)") plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) plt.tight_layout() plt.show()