"""Transformer class.""" # from torch_geometric.nn import GATv2Conv import math import torch from torch import nn from typing import Tuple def _pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32): # Maximum initial frequency is 1 return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0) # https://github.com/cvg/LightGlue/blob/b1cd942fc4a3a824b6aedff059d84f5c31c297f6/lightglue/lightglue.py#L51 def _rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotate pairs of scalars as 2d vectors by pi/2. Refer to eq 34 in https://arxiv.org/pdf/2104.09864.pdf. """ x = x.unflatten(-1, (-1, 2)) x1, x2 = x.unbind(dim=-1) return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) class RotaryPositionalEncoding(nn.Module): def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)): """Rotary positional encoding with given cutoff and number of frequencies for each dimension. number of dimension is inferred from the length of cutoffs and n_pos. see https://arxiv.org/pdf/2104.09864.pdf """ super().__init__() assert len(cutoffs) == len(n_pos) if not all(n % 2 == 0 for n in n_pos): raise ValueError("n_pos must be even") self._n_dim = len(cutoffs) # theta in RoFormer https://arxiv.org/pdf/2104.09864.pdf self.freqs = nn.ParameterList([ nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2)) for cutoff, n in zip(cutoffs, n_pos) ]) def get_co_si(self, coords: torch.Tensor): _B, _N, D = coords.shape assert D == len(self.freqs) co = torch.cat( tuple( torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq)) for x, freq in zip(coords.moveaxis(-1, 0), self.freqs) ), axis=-1, ) si = torch.cat( tuple( torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq)) for x, freq in zip(coords.moveaxis(-1, 0), self.freqs) ), axis=-1, ) return co, si def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor): _B, _N, D = coords.shape _B, _H, _N, _C = q.shape if not D == self._n_dim: raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}") co, si = self.get_co_si(coords) co = co.unsqueeze(1).repeat_interleave(2, dim=-1) si = si.unsqueeze(1).repeat_interleave(2, dim=-1) q2 = q * co + _rotate_half(q) * si k2 = k * co + _rotate_half(k) * si return q2, k2 if __name__ == "__main__": model = RotaryPositionalEncoding((256, 256), (32, 32)) x = 100 * torch.rand(1, 17, 2) q = torch.rand(1, 4, 17, 64) k = torch.rand(1, 4, 17, 64) q1, k1 = model(q, k, x) A1 = q1[:, :, 0] @ k1[:, :, 0].transpose(-1, -2) q2, k2 = model(q, k, x + 10) A2 = q2[:, :, 0] @ k2[:, :, 0].transpose(-1, -2)