|
|
|
|
|
|
|
|
|
|
| import math
|
| from typing import Literal
|
|
|
| import numpy as np
|
| import torch
|
| from torch import Tensor, nn
|
|
|
|
|
|
|
|
|
| class RopePositionEmbedding(nn.Module):
|
| def __init__(
|
| self,
|
| embed_dim: int,
|
| *,
|
| num_heads: int,
|
| base: float | None = 100.0,
|
| min_period: float | None = None,
|
| max_period: float | None = None,
|
| normalize_coords: Literal["min", "max", "separate"] = "separate",
|
| shift_coords: float | None = None,
|
| jitter_coords: float | None = None,
|
| rescale_coords: float | None = None,
|
| dtype: torch.dtype | None = None,
|
| device: torch.device | None = None,
|
| ):
|
| super().__init__()
|
| assert embed_dim % (4 * num_heads) == 0
|
| both_periods = min_period is not None and max_period is not None
|
| if (base is None and not both_periods) or (base is not None and both_periods):
|
| raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
|
|
|
| D_head = embed_dim // num_heads
|
| self.base = base
|
| self.min_period = min_period
|
| self.max_period = max_period
|
| self.D_head = D_head
|
| self.normalize_coords = normalize_coords
|
| self.shift_coords = shift_coords
|
| self.jitter_coords = jitter_coords
|
| self.rescale_coords = rescale_coords
|
|
|
|
|
| self.dtype = dtype
|
| self.register_buffer(
|
| "periods",
|
| torch.empty(D_head // 4, device=device, dtype=dtype),
|
| persistent=True,
|
| )
|
| self._init_weights()
|
|
|
| def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
|
| device = self.periods.device
|
| dtype = self.dtype
|
| dd = {"device": device, "dtype": dtype}
|
|
|
|
|
| if self.normalize_coords == "max":
|
| max_HW = max(H, W)
|
| coords_h = torch.arange(0.5, H, **dd) / max_HW
|
| coords_w = torch.arange(0.5, W, **dd) / max_HW
|
| elif self.normalize_coords == "min":
|
| min_HW = min(H, W)
|
| coords_h = torch.arange(0.5, H, **dd) / min_HW
|
| coords_w = torch.arange(0.5, W, **dd) / min_HW
|
| elif self.normalize_coords == "separate":
|
| coords_h = torch.arange(0.5, H, **dd) / H
|
| coords_w = torch.arange(0.5, W, **dd) / W
|
| else:
|
| raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
|
| coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
| coords = coords.flatten(0, 1)
|
| coords = 2.0 * coords - 1.0
|
|
|
|
|
| if self.training and self.shift_coords is not None:
|
| shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
|
| coords += shift_hw[None, :]
|
|
|
|
|
| if self.training and self.jitter_coords is not None:
|
| jitter_max = np.log(self.jitter_coords)
|
| jitter_min = -jitter_max
|
| jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
|
| coords *= jitter_hw[None, :]
|
|
|
|
|
| if self.training and self.rescale_coords is not None:
|
| rescale_max = np.log(self.rescale_coords)
|
| rescale_min = -rescale_max
|
| rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
|
| coords *= rescale_hw
|
|
|
|
|
| angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
|
| angles = angles.flatten(1, 2)
|
| angles = angles.tile(2)
|
| cos = torch.cos(angles)
|
| sin = torch.sin(angles)
|
|
|
| return (sin, cos)
|
|
|
| def _init_weights(self):
|
| device = self.periods.device
|
| dtype = self.dtype
|
| if self.base is not None:
|
| periods = self.base ** (
|
| 2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
|
| )
|
| else:
|
| base = self.max_period / self.min_period
|
| exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)
|
| periods = base**exponents
|
| periods = periods / base
|
| periods = periods * self.max_period
|
| self.periods.data = periods
|
|
|