| | from typing import * |
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
| | class RotaryPositionEmbedder(nn.Module): |
| | def __init__( |
| | self, |
| | head_dim: int, |
| | dim: int = 3, |
| | rope_freq: Tuple[float, float] = (1.0, 10000.0) |
| | ): |
| | super().__init__() |
| | assert head_dim % 2 == 0, "Head dim must be divisible by 2" |
| | self.head_dim = head_dim |
| | self.dim = dim |
| | self.rope_freq = rope_freq |
| | self.freq_dim = head_dim // 2 // dim |
| | self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim |
| | self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) |
| | |
| | def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: |
| | self.freqs = self.freqs.to(indices.device) |
| | phases = torch.outer(indices, self.freqs) |
| | phases = torch.polar(torch.ones_like(phases), phases) |
| | return phases |
| | |
| | @staticmethod |
| | def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: |
| | x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) |
| | x_rotated = x_complex * phases.unsqueeze(-2) |
| | x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) |
| | return x_embed |
| | |
| | def forward(self, indices: torch.Tensor) -> torch.Tensor: |
| | """ |
| | Args: |
| | indices (torch.Tensor): [..., N, C] tensor of spatial positions |
| | """ |
| | assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" |
| | phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) |
| | if phases.shape[-1] < self.head_dim // 2: |
| | padn = self.head_dim // 2 - phases.shape[-1] |
| | phases = torch.cat([phases, torch.polar( |
| | torch.ones(*phases.shape[:-1], padn, device=phases.device), |
| | torch.zeros(*phases.shape[:-1], padn, device=phases.device) |
| | )], dim=-1) |
| | return phases |