|
|
import math |
|
|
|
|
|
import torch |
|
|
from torch import Tensor, nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RopePositionEmbedding(nn.Module): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
embed_dim: int, |
|
|
*, |
|
|
num_heads: int, |
|
|
patch_size: int = 256, |
|
|
base: float | None = 100.0, |
|
|
min_period: float | None = None, |
|
|
max_period: float | None = None, |
|
|
dtype: torch.dtype | None = None, |
|
|
device: torch.device | None = None, |
|
|
): |
|
|
super().__init__() |
|
|
assert embed_dim % (4 * num_heads) == 0 |
|
|
both_periods = min_period is not None and max_period is not None |
|
|
if (base is None and not both_periods) or (base is not None and both_periods): |
|
|
raise ValueError("Either `base` or `min_period`+`max_period` must be provided.") |
|
|
|
|
|
D_head = embed_dim // num_heads |
|
|
self.base = base |
|
|
self.min_period = min_period |
|
|
self.max_period = max_period |
|
|
self.D_head = D_head |
|
|
|
|
|
self.patch_size = int(patch_size) |
|
|
|
|
|
|
|
|
self.dtype = dtype |
|
|
self.register_buffer( |
|
|
"periods", |
|
|
torch.empty(D_head // 4, device=device, dtype=dtype), |
|
|
persistent=True, |
|
|
) |
|
|
self._init_weights() |
|
|
|
|
|
def forward(self, *, coords: Tensor) -> tuple[Tensor, Tensor]: |
|
|
"""Compute RoPE values for given coordinates. |
|
|
|
|
|
Args: |
|
|
coords: Tensor of shape [B, N, 2] representing (h, w) pixel coordinates. |
|
|
Converted to patch indices using (coord + patch_size//2) / patch_size. |
|
|
Returns: |
|
|
Tuple (sin, cos): |
|
|
- Outputs are [B, 1, N, D_head] to broadcast across heads. |
|
|
""" |
|
|
device = self.periods.device |
|
|
dtype = self.dtype |
|
|
|
|
|
if coords.device != device: |
|
|
coords = coords.to(device) |
|
|
if dtype is not None and coords.dtype != dtype: |
|
|
coords = coords.to(dtype) |
|
|
|
|
|
assert coords.ndim == 3 and coords.shape[-1] == 2, f"coords must be [B, N, 2], got shape {tuple(coords.shape)}" |
|
|
|
|
|
|
|
|
|
|
|
patch_size_tensor = torch.tensor(self.patch_size, device=device, dtype=dtype) |
|
|
center_offset = torch.tensor(self.patch_size // 2, device=device, dtype=dtype) |
|
|
coords_norm = (coords + center_offset) / patch_size_tensor |
|
|
|
|
|
|
|
|
angles = 2 * math.pi * coords_norm[:, :, :, None] / self.periods[None, None, None, :] |
|
|
angles = angles.flatten(2, 3) |
|
|
angles = angles.tile((1, 1, 2)) |
|
|
cos = torch.cos(angles) |
|
|
sin = torch.sin(angles) |
|
|
|
|
|
return (sin.unsqueeze(1), cos.unsqueeze(1)) |
|
|
|
|
|
def _init_weights(self): |
|
|
device = self.periods.device |
|
|
dtype = self.dtype |
|
|
if self.base is not None: |
|
|
periods = self.base ** ( |
|
|
2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2) |
|
|
) |
|
|
else: |
|
|
base = self.max_period / self.min_period |
|
|
exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype) |
|
|
periods = base**exponents |
|
|
periods = periods / base |
|
|
periods = periods * self.max_period |
|
|
self.periods.data = periods |