FlexiCT-2D / rope_position_encoding.py
ricklisz123's picture
Upload folder using huggingface_hub
fcefac1 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import math
from typing import Literal
import numpy as np
import torch
from torch import Tensor, nn
# RoPE positional embedding with no mixing of coordinates (axial) and no learnable weights
# Supports two parametrizations of the rope parameters: either using `base` or `min_period` and `max_period`.
class RopePositionEmbedding(nn.Module):
def __init__(
self,
embed_dim: int,
*,
num_heads: int,
base: float | None = 100.0,
min_period: float | None = None,
max_period: float | None = None,
normalize_coords: Literal["min", "max", "separate"] = "separate",
shift_coords: float | None = None,
jitter_coords: float | None = None,
rescale_coords: float | None = None,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
):
super().__init__()
assert embed_dim % (4 * num_heads) == 0
both_periods = min_period is not None and max_period is not None
if (base is None and not both_periods) or (base is not None and both_periods):
raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
D_head = embed_dim // num_heads
self.base = base
self.min_period = min_period
self.max_period = max_period
self.D_head = D_head
self.normalize_coords = normalize_coords
self.shift_coords = shift_coords
self.jitter_coords = jitter_coords
self.rescale_coords = rescale_coords
# Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
self.dtype = dtype # Don't rely on self.periods.dtype
self.register_buffer(
"periods",
torch.empty(D_head // 4, device=device, dtype=dtype),
persistent=True,
)
self._init_weights()
def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
device = self.periods.device
dtype = self.dtype
dd = {"device": device, "dtype": dtype}
# Prepare coords in range [-1, +1]
if self.normalize_coords == "max":
max_HW = max(H, W)
coords_h = torch.arange(0.5, H, **dd) / max_HW # [H]
coords_w = torch.arange(0.5, W, **dd) / max_HW # [W]
elif self.normalize_coords == "min":
min_HW = min(H, W)
coords_h = torch.arange(0.5, H, **dd) / min_HW # [H]
coords_w = torch.arange(0.5, W, **dd) / min_HW # [W]
elif self.normalize_coords == "separate":
coords_h = torch.arange(0.5, H, **dd) / H # [H]
coords_w = torch.arange(0.5, W, **dd) / W # [W]
else:
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # [H, W, 2]
coords = coords.flatten(0, 1) # [HW, 2]
coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1]
# Shift coords by adding a uniform value in [-shift, shift]
if self.training and self.shift_coords is not None:
shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
coords += shift_hw[None, :]
# Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
if self.training and self.jitter_coords is not None:
jitter_max = np.log(self.jitter_coords)
jitter_min = -jitter_max
jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
coords *= jitter_hw[None, :]
# Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
if self.training and self.rescale_coords is not None:
rescale_max = np.log(self.rescale_coords)
rescale_min = -rescale_max
rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
coords *= rescale_hw
# Prepare angles and sin/cos
angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] # [HW, 2, D//4]
angles = angles.flatten(1, 2) # [HW, D//2]
angles = angles.tile(2) # [HW, D]
cos = torch.cos(angles) # [HW, D]
sin = torch.sin(angles) # [HW, D]
return (sin, cos) # 2 * [HW, D]
def _init_weights(self):
device = self.periods.device
dtype = self.dtype
if self.base is not None:
periods = self.base ** (
2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
) # [D//4]
else:
base = self.max_period / self.min_period
exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype) # [D//4] range [0, 1]
periods = base**exponents # range [1, max_period / min_period]
periods = periods / base # range [min_period / max_period, 1]
periods = periods * self.max_period # range [min_period, max_period]
self.periods.data = periods
class RopePositionEmbedding3D(nn.Module):
"""
RoPE positional embedding for 3D grids with no mixing across axes (axial),
and no learnable weights.
Supports two parametrizations of the rope parameters: either using `base`
or `min_period` + `max_period`.
Returns (sin, cos) each shaped [D*H*W, D_head], suitable for applying to q/k.
"""
def __init__(
self,
embed_dim: int,
*,
num_heads: int,
base: float | None = 100.0,
min_period: float | None = None,
max_period: float | None = None,
normalize_coords: Literal["min", "max", "separate"] = "separate",
shift_coords: float | None = None,
jitter_coords: float | None = None,
rescale_coords: float | None = None,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
):
super().__init__()
# For 3 axes, per-axis rotary block is D_head//6
assert embed_dim % (6 * num_heads) == 0, \
"For 3D RoPE, embed_dim must be divisible by 6 * num_heads."
both_periods = (min_period is not None) and (max_period is not None)
if (base is None and not both_periods) or (base is not None and both_periods):
raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
D_head = embed_dim // num_heads
self.base = base
self.min_period = min_period
self.max_period = max_period
self.D_head = D_head
self.normalize_coords = normalize_coords
self.shift_coords = shift_coords
self.jitter_coords = jitter_coords
self.rescale_coords = rescale_coords
# Keep a persistent buffer so teacher.load_state_dict(student.state_dict()) works.
self.dtype = dtype # Don't rely on self.periods.dtype
self.register_buffer(
"periods",
torch.empty(D_head // 6, device=device, dtype=dtype),
persistent=True,
)
self._init_weights()
def _init_weights(self):
device = self.periods.device
dtype = self.dtype
per_axis_block = self.D_head // 6 # 3 axes × per_axis_block = D_head//2
if self.base is not None:
# Classic exponential spectrum
periods = self.base ** (
2 * torch.arange(per_axis_block, device=device, dtype=dtype) / (self.D_head // 2)
) # [per_axis_block]
else:
# Linearly spaced in log-period between min_period and max_period
base = self.max_period / self.min_period
exponents = torch.linspace(0, 1, per_axis_block, device=device, dtype=dtype)
periods = base**exponents # [1, base]
periods = periods / base # [1/base, 1]
periods = periods * self.max_period # [min_period, max_period]
self.periods.data = periods
def forward(self, *, D: int, H: int, W: int) -> tuple[Tensor, Tensor]:
"""
Args:
D, H, W: depth, height, width (integers)
Returns:
(sin, cos): two tensors of shape [D*H*W, D_head]
"""
device = self.periods.device
dtype = self.dtype
dd = {"device": device, "dtype": dtype}
# --- Build normalized coords in [-1, +1] for each axis ---
if self.normalize_coords == "max":
m = max(D, H, W)
coords_d = torch.arange(0.5, D, **dd) / m # [D]
coords_h = torch.arange(0.5, H, **dd) / m # [H]
coords_w = torch.arange(0.5, W, **dd) / m # [W]
elif self.normalize_coords == "min":
m = min(D, H, W)
coords_d = torch.arange(0.5, D, **dd) / m
coords_h = torch.arange(0.5, H, **dd) / m
coords_w = torch.arange(0.5, W, **dd) / m
elif self.normalize_coords == "separate":
coords_d = torch.arange(0.5, D, **dd) / D
coords_h = torch.arange(0.5, H, **dd) / H
coords_w = torch.arange(0.5, W, **dd) / W
else:
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
# Meshgrid in (D, H, W) order; last dim stacks (d, h, w)
coords = torch.stack(
torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"),
dim=-1
) # [D, H, W, 3]
coords = coords.flatten(0, 2) # [DHW, 3]
coords = 2.0 * coords - 1.0 # shift [0,1] -> [-1,1]
# --- Optional data-aug on coords (train-time only) ---
if self.training and self.shift_coords is not None:
shift_dhw = torch.empty(3, **dd).uniform_(-self.shift_coords, self.shift_coords)
coords = coords + shift_dhw[None, :]
if self.training and self.jitter_coords is not None:
jitter_max = np.log(self.jitter_coords)
jitter_min = -jitter_max
jitter_dhw = torch.empty(3, **dd).uniform_(jitter_min, jitter_max).exp()
coords = coords * jitter_dhw[None, :]
if self.training and self.rescale_coords is not None:
rescale_max = np.log(self.rescale_coords)
rescale_min = -rescale_max
rescale = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
coords = coords * rescale
# --- Build rotary angles (axial, no mixing) ---
# periods: [P] where P = D_head//6
# coords: [N, 3], broadcast => [N, 3, P]
angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] # [N, 3, P]
angles = angles.flatten(1, 2) # [N, 3*P] == [N, D_head//2]
angles = angles.tile(2) # [N, D_head]
cos = torch.cos(angles)
sin = torch.sin(angles)
return (sin, cos)