# Source: ViT-Patch-PCA-Visualisation / hf_src/layers/rope_position_encoding.py
# Last commit: a10ce46 "add: missing files" (Tenbatsu24)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import math
from typing import Literal
import numpy as np
import torch
from torch import Tensor, nn
# RoPE positional embedding with no mixing of coordinates (axial) and no learnable weights
# Supports two parametrizations of the rope parameters: either using `base` or `min_period` and `max_period`.
class RopePositionEmbedding(nn.Module):
def __init__(
self,
embed_dim: int,
*,
num_heads: int,
base: float | None = 100.0,
min_period: float | None = None,
max_period: float | None = None,
normalize_coords: Literal["min", "max", "separate"] = "separate",
shift_coords: float | None = None,
jitter_coords: float | None = None,
rescale_coords: float | None = None,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
):
super().__init__()
assert embed_dim % (4 * num_heads) == 0
both_periods = min_period is not None and max_period is not None
if (base is None and not both_periods) or (base is not None and both_periods):
raise ValueError(
"Either `base` or `min_period`+`max_period` must be provided."
)
D_head = embed_dim // num_heads
self.base = base
self.min_period = min_period
self.max_period = max_period
self.D_head = D_head
self.normalize_coords = normalize_coords
self.shift_coords = shift_coords
self.jitter_coords = jitter_coords
self.rescale_coords = rescale_coords
# Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
self.dtype = dtype # Don't rely on self.periods.dtype
self.register_buffer(
"periods",
torch.empty(D_head // 4, device=device, dtype=dtype),
persistent=True,
)
self._init_weights()
def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
device = self.periods.device
dtype = self.dtype
dd = {"device": device, "dtype": dtype}
# Prepare coords in range [-1, +1]
if self.normalize_coords == "max":
max_HW = max(H, W)
coords_h = torch.arange(0.5, H, **dd) / max_HW # [H]
coords_w = torch.arange(0.5, W, **dd) / max_HW # [W]
elif self.normalize_coords == "min":
min_HW = min(H, W)
coords_h = torch.arange(0.5, H, **dd) / min_HW # [H]
coords_w = torch.arange(0.5, W, **dd) / min_HW # [W]
elif self.normalize_coords == "separate":
coords_h = torch.arange(0.5, H, **dd) / H # [H]
coords_w = torch.arange(0.5, W, **dd) / W # [W]
else:
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
coords = torch.stack(
torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1
) # [H, W, 2]
coords = coords.flatten(0, 1) # [HW, 2]
coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1]
# Shift coords by adding a uniform value in [-shift, shift]
if self.training and self.shift_coords is not None:
shift_hw = torch.empty(2, **dd).uniform_(
-self.shift_coords, self.shift_coords
)
coords += shift_hw[None, :]
# Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
if self.training and self.jitter_coords is not None:
jitter_max = np.log(self.jitter_coords)
jitter_min = -jitter_max
jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
coords *= jitter_hw[None, :]
# Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
if self.training and self.rescale_coords is not None:
rescale_max = np.log(self.rescale_coords)
rescale_min = -rescale_max
rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
coords *= rescale_hw
# Prepare angles and sin/cos
angles = (
2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
) # [HW, 2, D//4]
angles = angles.flatten(1, 2) # [HW, D//2]
angles = angles.tile(2) # [HW, D]
cos = torch.cos(angles) # [HW, D]
sin = torch.sin(angles) # [HW, D]
return sin, cos # 2 * [HW, D]
def _init_weights(self):
device = self.periods.device
dtype = self.dtype
if self.base is not None:
periods = self.base ** (
2
* torch.arange(self.D_head // 4, device=device, dtype=dtype)
/ (self.D_head // 2)
) # [D//4]
else:
base = self.max_period / self.min_period
exponents = torch.linspace(
0, 1, self.D_head // 4, device=device, dtype=dtype
) # [D//4] range [0, 1]
periods = base**exponents # range [1, max_period / min_period]
periods = periods / base # range [min_period / max_period, 1]
periods = periods * self.max_period # range [min_period, max_period]
self.periods.data = periods
if __name__ == "__main__":
    # Demo: plot one RoPE sine channel along the height axis for a few
    # (base, grid size) combinations. `torch` and `np` come from the
    # module-level imports; matplotlib is only needed for this demo.
    import matplotlib.pyplot as plt

    def get_rope_values(H, W, embed_dim, num_heads, base, period_index=3):
        """Return an ``[H, W]`` array of ``sin(2*pi*h / period)`` values.

        The periods follow the same ``base`` formula as
        ``RopePositionEmbedding._init_weights`` and the coordinates use the
        "separate" normalisation mapped to [-1, 1].

        Args:
            H: Number of grid rows.
            W: Number of grid columns.
            embed_dim: Total embedding width.
            num_heads: Number of attention heads.
            base: Base of the geometric period progression.
            period_index: Which of the ``D_head // 4`` periods to visualise.
                Index 0 is the shortest period when ``base > 1``; the default
                of 3 selects the fourth one.
        """
        D_head = embed_dim // num_heads
        # Debug output: number of periods, half head dim, and their ratio.
        print(D_head // 4, D_head // 2, (D_head // 4) / (D_head // 2))

        periods = base ** (2 * torch.arange(D_head // 4) / (D_head // 2))
        period = periods[period_index]

        # Patch-centre coordinates, normalised per axis to [0, 1].
        coords_h = torch.arange(0.5, H) / H
        coords_w = torch.arange(0.5, W) / W
        # Only the height grid is visualised; W just sets the image width.
        grid_h, _ = torch.meshgrid(coords_h, coords_w, indexing="ij")

        # Convert to [-1, 1]
        grid_h = 2.0 * grid_h - 1.0

        # Formula: sin(2 * pi * coord / period)
        vals = torch.sin(2 * np.pi * grid_h / period)
        return vals.numpy()

    # Settings
    embed_dim = 768
    num_heads = 12
    bases = [100, 10000]
    sizes = [14, 28]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    for i, base in enumerate(bases):
        for j, size in enumerate(sizes):
            vals = get_rope_values(size, size, embed_dim, num_heads, base)
            ax = axes[i, j]
            im = ax.imshow(vals, cmap="RdBu", extent=[-1, 1, -1, 1])
            ax.set_title(f"Base: {base} | Grid: {size}x{size}")
            ax.set_xlabel("Width (Normalized)")
            ax.set_ylabel("Height (Normalized)")
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()