phoebehxf
init
aff3c6f
"""Transformer class."""
# from torch_geometric.nn import GATv2Conv
import math
import torch
from torch import nn
from typing import Tuple
def _pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32):
    """Build log-spaced initial frequencies for a 1-D rotary/Fourier embedding.

    The ``n`` frequencies decay geometrically from ``1`` down to ``1 / cutoff``.

    Args:
        cutoff: ratio between the largest (1) and the smallest initial frequency.
        n: number of frequencies to generate.

    Returns:
        A tensor of shape ``(1, 1, n)``.
    """
    # log-frequencies run from log(1) = 0 to log(1 / cutoff), so the largest is exactly 1.
    log_freqs = torch.linspace(0, -math.log(cutoff), n)
    return torch.exp(log_freqs)[None, None, :]
# https://github.com/cvg/LightGlue/blob/b1cd942fc4a3a824b6aedff059d84f5c31c297f6/lightglue/lightglue.py#L51
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate each adjacent channel pair of the last axis by pi/2.

    Consecutive channels ``(a, b)`` along the last dimension are treated as a
    2-D vector and mapped to ``(-b, a)``. The last dimension must be even.
    Refer to eq 34 in https://arxiv.org/pdf/2104.09864.pdf.
    """
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    # Collapse the trailing (pairs, 2) axes back into the original channel axis.
    return rotated.reshape(x.shape)
class RotaryPositionalEncoding(nn.Module):
    """Rotary positional encoding (RoPE) over multi-dimensional coordinates.

    Each coordinate dimension ``d`` owns ``n_pos[d] // 2`` learnable
    frequencies (``nn.Parameter``), initialised log-spaced between 1 and
    ``1 / cutoffs[d]``. The rotation angle of a channel pair is
    ``0.5 * pi * x_d * freq``. Angles of all dimensions are concatenated along
    the channel axis, so queries/keys must have ``sum(n_pos)`` channels.

    See https://arxiv.org/pdf/2104.09864.pdf (RoFormer).
    """

    def __init__(
        self,
        cutoffs: Tuple[float, ...] = (256,),
        n_pos: Tuple[int, ...] = (32,),
    ):
        """Create the encoding.

        The number of coordinate dimensions is inferred from ``len(cutoffs)``.

        Args:
            cutoffs: per-dimension ratio between the largest and the smallest
                initial frequency.
            n_pos: per-dimension number of channels to rotate; each entry must
                be even.

        Raises:
            ValueError: if ``cutoffs`` and ``n_pos`` differ in length, or any
                entry of ``n_pos`` is odd.
        """
        super().__init__()
        if len(cutoffs) != len(n_pos):
            raise ValueError(
                "cutoffs and n_pos must have the same length, "
                f"got {len(cutoffs)} and {len(n_pos)}"
            )
        if not all(n % 2 == 0 for n in n_pos):
            raise ValueError("n_pos must be even")
        self._n_dim = len(cutoffs)
        # theta in RoFormer: one (1, 1, n // 2) parameter per coordinate dimension.
        self.freqs = nn.ParameterList([
            nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2))
            for cutoff, n in zip(cutoffs, n_pos)
        ])

    def get_co_si(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cosine and sine of the rotation angles for ``coords``.

        Args:
            coords: tensor of shape ``(B, N, D)``, where ``D`` is the number
                of coordinate dimensions.

        Returns:
            ``(cos, sin)``, each of shape ``(B, N, sum(n_pos) // 2)``.

        Raises:
            ValueError: if ``D`` does not match the number of coordinate dims.
        """
        _B, _N, D = coords.shape
        if D != len(self.freqs):
            raise ValueError(f"coords must have {len(self.freqs)} dimensions, got {D}")
        # Compute the angles once and reuse them for cos and sin.
        # (B, N) slice * (1, 1, k) freqs -> (B, N, k) per dimension, concatenated.
        # No normalisation is applied: unit-scaled cos/sin keep the rotation
        # norm-preserving.
        angles = torch.cat(
            [
                0.5 * math.pi * x.unsqueeze(-1) * freq
                for x, freq in zip(coords.unbind(-1), self.freqs)
            ],
            dim=-1,
        )
        return torch.cos(angles), torch.sin(angles)

    def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
        """Rotate queries and keys by their coordinate-dependent angles.

        Args:
            q: queries of shape ``(B, H, N, C)`` with ``C == sum(n_pos)``.
            k: keys; must broadcast against ``(B, 1, N, C)`` the same way as ``q``.
            coords: positions of shape ``(B, N, D)``.

        Returns:
            ``(q_rot, k_rot)`` with the same shapes as ``q`` and ``k``.

        Raises:
            ValueError: if ``D`` or ``C`` do not match the configuration.
        """
        _B, _N, D = coords.shape
        _B, _H, _N, C = q.shape
        if D != self._n_dim:
            raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}")
        # Each frequency rotates one channel pair, so the channel count is 2 * #freqs.
        n_rot = 2 * sum(freq.shape[-1] for freq in self.freqs)
        if C != n_rot:
            raise ValueError(f"q must have {n_rot} channels (sum of n_pos), got {C}")
        co, si = self.get_co_si(coords)
        # (B, N, C/2) -> (B, 1, N, C): broadcast over heads and duplicate each
        # angle for both channels of its pair, matching _rotate_half's layout.
        co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
        si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
        q2 = q * co + _rotate_half(q) * si
        k2 = k * co + _rotate_half(k) * si
        return q2, k2
if __name__ == "__main__":
    # Demo: with RoPE, the score between query i and key j depends only on the
    # *relative* position x_j - x_i. Shifting every coordinate by the same
    # offset should leave the full attention matrix unchanged, up to float32
    # rounding. (Comparing only token 0 with itself would be trivially invariant.)
    torch.manual_seed(0)
    model = RotaryPositionalEncoding((256, 256), (32, 32))
    x = 100 * torch.rand(1, 17, 2)
    q = torch.rand(1, 4, 17, 64)
    k = torch.rand(1, 4, 17, 64)
    with torch.no_grad():
        q1, k1 = model(q, k, x)
        A1 = q1 @ k1.transpose(-1, -2)  # (B, H, N, N)
        q2, k2 = model(q, k, x + 10)
        A2 = q2 @ k2.transpose(-1, -2)
    max_err = (A1 - A2).abs().max().item()
    print(f"max |A1 - A2| after shifting coords by 10: {max_err:.2e}")