Spaces:
Sleeping
Sleeping
"""Rotary positional encoding (RoPE) for multi-dimensional coordinates."""
| # from torch_geometric.nn import GATv2Conv | |
| import math | |
| import torch | |
| from torch import nn | |
| from typing import Tuple | |
def _pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32):
    """Return ``n`` geometrically spaced initial frequencies, shape ``(1, 1, n)``.

    The log-frequencies are evenly spaced on ``[0, -log(cutoff)]``, so the
    first (largest) frequency is exactly 1 and the last is ``1 / cutoff``.
    """
    log_freqs = torch.linspace(0, -math.log(cutoff), n)
    freqs = torch.exp(log_freqs)
    # Two leading singleton axes so the result broadcasts against (B, N, 1).
    return freqs[None, None, :]
# https://github.com/cvg/LightGlue/blob/b1cd942fc4a3a824b6aedff059d84f5c31c297f6/lightglue/lightglue.py#L51
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate each adjacent channel pair ``(a, b)`` of ``x`` by pi/2 to ``(-b, a)``.

    The last dimension of ``x`` is read as consecutive 2d vectors
    ``(x[..., 0], x[..., 1]), (x[..., 2], x[..., 3]), ...``; its size must be even.
    Refer to eq 34 in https://arxiv.org/pdf/2104.09864.pdf.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first = pairs[..., 0]
    second = pairs[..., 1]
    rotated = torch.stack((second.neg(), first), dim=-1)
    return rotated.flatten(-2)
class RotaryPositionalEncoding(nn.Module):
    """Rotary positional encoding (RoPE) over multi-dimensional coordinates.

    Every coordinate dimension ``d`` owns ``n_pos[d] // 2`` learnable
    frequencies, initialised geometrically from 1 toward ``1 / cutoffs[d]``.
    Each frequency rotates one adjacent pair of query/key channels by the
    angle ``0.5 * pi * coord_d * freq``. Because queries and keys are rotated
    the same way, their dot product depends on the coordinates only through
    their difference.

    See https://arxiv.org/pdf/2104.09864.pdf
    """

    def __init__(
        self,
        cutoffs: Tuple[float, ...] = (256,),
        n_pos: Tuple[int, ...] = (32,),
    ):
        """Rotary positional encoding with given cutoff and number of frequencies for each dimension.

        The number of coordinate dimensions is inferred from ``len(cutoffs)``.

        Args:
            cutoffs: Per-dimension cutoff controlling the smallest initial
                frequency (``1 / cutoff``).
            n_pos: Per-dimension number of channels to rotate; each must be even.

        Raises:
            ValueError: If ``cutoffs`` and ``n_pos`` differ in length, or any
                entry of ``n_pos`` is odd.
        """
        super().__init__()
        if len(cutoffs) != len(n_pos):
            raise ValueError(
                "cutoffs and n_pos must have the same length, "
                f"got {len(cutoffs)} and {len(n_pos)}"
            )
        if not all(n % 2 == 0 for n in n_pos):
            raise ValueError("n_pos must be even")
        self._n_dim = len(cutoffs)
        # theta in RoFormer https://arxiv.org/pdf/2104.09864.pdf
        # One (1, 1, n // 2) parameter per coordinate dimension.
        self.freqs = nn.ParameterList([
            nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2))
            for cutoff, n in zip(cutoffs, n_pos)
        ])

    def get_co_si(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` of the rotation angles for ``coords``.

        Args:
            coords: Tensor of shape ``(B, N, D)`` with ``D`` equal to the number
                of coordinate dimensions.

        Returns:
            Two tensors of shape ``(B, N, sum(n_pos) // 2)``. The per-dimension
            frequency blocks are concatenated in dimension order.

        Raises:
            ValueError: If ``D`` does not match the number of frequency sets.
        """
        _B, _N, D = coords.shape
        if D != len(self.freqs):
            raise ValueError(f"coords must have {len(self.freqs)} dimensions, got {D}")
        # (B, N, 1) * (1, 1, F_d) -> (B, N, F_d) per dimension, concatenated to
        # (B, N, F). The angle is computed once and shared by cos and sin.
        angles = torch.cat(
            [
                0.5 * math.pi * x.unsqueeze(-1) * freq
                for x, freq in zip(coords.unbind(-1), self.freqs)
            ],
            dim=-1,
        )
        # cos/sin are deliberately unscaled, so q * cos + rot(q) * sin below is
        # a pure, norm-preserving rotation of each channel pair.
        return torch.cos(angles), torch.sin(angles)

    def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
        """Apply the coordinate-dependent rotation to queries and keys.

        Args:
            q: Queries of shape ``(B, H, N, C)`` with ``C == sum(n_pos)``.
            k: Keys with the same ``(B, ., N, C)`` layout as ``q``.
            coords: Positions of shape ``(B, N, D)``.

        Returns:
            The rotated ``(q, k)``.

        Raises:
            ValueError: If ``D`` does not match the number of coordinate
                dimensions, or the channel count of ``q``/``k`` is not
                ``sum(n_pos)``.
        """
        _B, _N, D = coords.shape
        _B, _H, _N, C = q.shape
        if D != self._n_dim:
            raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}")
        co, si = self.get_co_si(coords)
        # (B, N, F) -> (B, 1, N, 2F): broadcast over heads and give both
        # channels of a pair the same angle, matching the adjacent-pair
        # convention of _rotate_half.
        co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
        si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
        n_channels = co.shape[-1]
        if C != n_channels or k.shape[-1] != n_channels:
            raise ValueError(
                f"q and k must have {n_channels} channels (sum of n_pos), "
                f"got {C} and {k.shape[-1]}"
            )
        q2 = q * co + _rotate_half(q) * si
        k2 = k * co + _rotate_half(k) * si
        return q2, k2
if __name__ == "__main__":
    # Sanity check: with RoPE, the attention logits q_i . k_j depend only on
    # the relative offset coords_j - coords_i. Shifting every coordinate by
    # the same amount must therefore leave all (N, N) logits unchanged.
    # float64 keeps the comparison free of float32 rounding at large angles.
    model = RotaryPositionalEncoding((256, 256), (32, 32)).double()
    x = 100 * torch.rand(1, 17, 2, dtype=torch.float64)
    q = torch.rand(1, 4, 17, 64, dtype=torch.float64)
    k = torch.rand(1, 4, 17, 64, dtype=torch.float64)
    q1, k1 = model(q, k, x)
    # (B, H, N, N): logits between every pair of positions, per head.
    A1 = q1 @ k1.transpose(-1, -2)
    q2, k2 = model(q, k, x + 10)
    A2 = q2 @ k2.transpose(-1, -2)
    max_diff = (A1 - A2).abs().max().item()
    print(f"max |A1 - A2| under translation: {max_diff:.3e}")
    assert torch.allclose(A1, A2), "attention logits are not translation invariant"