# File: gpu_symbol/engine/backbone/dinov3/layers/rope_position_encoding.py
# Repository page header: uploaded by himipo, commit message "first", commit 11aa70b
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import math
from typing import Literal
import numpy as np
import torch
from torch import Tensor, nn
# RoPE positional embedding with no mixing of coordinates (axial) and no learnable weights
# Supports two parametrizations of the rope parameters: either using `base` or `min_period` and `max_period`.
class RopePositionEmbedding(nn.Module):
def __init__(
self,
embed_dim: int,
*,
num_heads: int,
base: float | None = 100.0,
min_period: float | None = None,
max_period: float | None = None,
normalize_coords: Literal["min", "max", "separate"] = "separate",
shift_coords: float | None = None,
jitter_coords: float | None = None,
rescale_coords: float | None = None,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
):
super().__init__()
assert embed_dim % (4 * num_heads) == 0
both_periods = min_period is not None and max_period is not None
if (base is None and not both_periods) or (base is not None and both_periods):
raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
D_head = embed_dim // num_heads
self.base = base
self.min_period = min_period
self.max_period = max_period
self.D_head = D_head
self.normalize_coords = normalize_coords
self.shift_coords = shift_coords
self.jitter_coords = jitter_coords
self.rescale_coords = rescale_coords
# Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
self.dtype = dtype # Don't rely on self.periods.dtype
self.register_buffer(
"periods",
torch.empty(D_head // 4, device=device, dtype=dtype),
persistent=True,
)
self._init_weights()
def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
device = self.periods.device
dtype = self.dtype
dd = {"device": device, "dtype": dtype}
# Prepare coords in range [-1, +1]
if self.normalize_coords == "max":
max_HW = max(H, W)
coords_h = torch.arange(0.5, H, **dd) / max_HW # [H]
coords_w = torch.arange(0.5, W, **dd) / max_HW # [W]
elif self.normalize_coords == "min":
min_HW = min(H, W)
coords_h = torch.arange(0.5, H, **dd) / min_HW # [H]
coords_w = torch.arange(0.5, W, **dd) / min_HW # [W]
elif self.normalize_coords == "separate":
coords_h = torch.arange(0.5, H, **dd) / H # [H]
coords_w = torch.arange(0.5, W, **dd) / W # [W]
else:
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # [H, W, 2]
coords = coords.flatten(0, 1) # [HW, 2]
coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1]
# Shift coords by adding a uniform value in [-shift, shift]
if self.training and self.shift_coords is not None:
shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
coords += shift_hw[None, :]
# Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
if self.training and self.jitter_coords is not None:
jitter_max = np.log(self.jitter_coords)
jitter_min = -jitter_max
jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
coords *= jitter_hw[None, :]
# Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
if self.training and self.rescale_coords is not None:
rescale_max = np.log(self.rescale_coords)
rescale_min = -rescale_max
rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
coords *= rescale_hw
# Prepare angles and sin/cos
angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] # [HW, 2, D//4]
angles = angles.flatten(1, 2) # [HW, D//2]
angles = angles.tile(2) # [HW, D]
cos = torch.cos(angles) # [HW, D]
sin = torch.sin(angles) # [HW, D]
return (sin, cos) # 2 * [HW, D]
def _init_weights(self):
device = self.periods.device
dtype = self.dtype
if self.base is not None:
periods = self.base ** (
2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
) # [D//4]
else:
base = self.max_period / self.min_period
exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype) # [D//4] range [0, 1]
periods = base**exponents # range [1, max_period / min_period]
periods = periods / base # range [min_period / max_period, 1]
periods = periods * self.max_period # range [min_period, max_period]
self.periods.data = periods