File size: 4,376 Bytes
e101805 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import math
import torch
from torch import Tensor, nn
# RoPE positional embedding with no mixing of coordinates (axial) and no learnable weights
# Supports two parametrizations of the rope parameters: either using `base` or `min_period` and `max_period`.
class RopePositionEmbedding(nn.Module):
# NOTE: Modified to index by patch_size instead of dataset-wise max H/W.
# - During __init__, provide patch_size and store as member.
# - In forward(), pass coords tensor with shape [B, N, 2];
# coords are converted to patch indices using (coords + patch_size//2) / patch_size.
def __init__(
self,
embed_dim: int,
*,
num_heads: int,
patch_size: int = 256,
base: float | None = 100.0,
min_period: float | None = None,
max_period: float | None = None,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
):
super().__init__()
assert embed_dim % (4 * num_heads) == 0
both_periods = min_period is not None and max_period is not None
if (base is None and not both_periods) or (base is not None and both_periods):
raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
D_head = embed_dim // num_heads
self.base = base
self.min_period = min_period
self.max_period = max_period
self.D_head = D_head
# Store patch size for converting pixel coords to patch indices
self.patch_size = int(patch_size)
# Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
self.dtype = dtype
self.register_buffer(
"periods",
torch.empty(D_head // 4, device=device, dtype=dtype),
persistent=True,
)
self._init_weights()
def forward(self, *, coords: Tensor) -> tuple[Tensor, Tensor]:
"""Compute RoPE values for given coordinates.
Args:
coords: Tensor of shape [B, N, 2] representing (h, w) pixel coordinates.
Converted to patch indices using (coord + patch_size//2) / patch_size.
Returns:
Tuple (sin, cos):
- Outputs are [B, 1, N, D_head] to broadcast across heads.
"""
device = self.periods.device
dtype = self.dtype
if coords.device != device:
coords = coords.to(device)
if dtype is not None and coords.dtype != dtype:
coords = coords.to(dtype)
# Enforce batched coords for consistent behavior with attention and sampling
assert coords.ndim == 3 and coords.shape[-1] == 2, f"coords must be [B, N, 2], got shape {tuple(coords.shape)}"
# Convert pixel coordinates to patch indices centered at patch centers
# index = (coord + patch_size//2) / patch_size
patch_size_tensor = torch.tensor(self.patch_size, device=device, dtype=dtype) # for broadcasting
center_offset = torch.tensor(self.patch_size // 2, device=device, dtype=dtype)
coords_norm = (coords + center_offset) / patch_size_tensor
# Prepare angles and sin/cos for [B, N, 2]
angles = 2 * math.pi * coords_norm[:, :, :, None] / self.periods[None, None, None, :] # [B, N, 2, D//4]
angles = angles.flatten(2, 3) # [B, N, D//2]
angles = angles.tile((1, 1, 2)) # [B, N, D]
cos = torch.cos(angles) # [B, N, D]
sin = torch.sin(angles) # [B, N, D]
# Expand head dimension to broadcast across heads: [B, 1, N, D]
return (sin.unsqueeze(1), cos.unsqueeze(1))
def _init_weights(self):
device = self.periods.device
dtype = self.dtype
if self.base is not None:
periods = self.base ** (
2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
) # [D//4]
else:
base = self.max_period / self.min_period
exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype) # [D//4] range [0, 1]
periods = base**exponents # range [1, max_period / min_period]
periods = periods / base # range [min_period / max_period, 1]
periods = periods * self.max_period # range [min_period, max_period]
self.periods.data = periods |