import numpy as np
import torch
import torch.nn as nn

from src.utils import fourier_position_encoder
from src.nn.fusion import fusion_factory
from src.nn.mlp import FFN
| |
|
| |
|
# Public API of this module: positional-injection layers
__all__ = [
    'CatInjection',
    'AdditiveInjection',
    'AdditiveMLPInjection',
    'FourierInjection',
    'LearnableFourierInjection',
]
| |
|
| |
|
class BasePositionalInjection(nn.Module):
    def __init__(self, dim=None, x_dim=None, fusion='additive', **kwargs):
        """Base class for positional information injection. Takes care of
        fusion with a potential embedding vector.

        Child classes are expected to overwrite the `_encode` method.

        :param dim: int
            Positional encoding dimension
        :param x_dim: int
            If provided, the input feature embeddings will undergo a
            linear projection 'x_dim -> dim' before being fused with the
            positional encodings
        :param fusion: str
            Fusion mechanism to merge positional encodings with feature
            embeddings
        """
        super().__init__()
        self.dim = dim

        # Optional projection of the feature embeddings to the positional
        # encoding dimension. Falls back to identity when either dimension
        # is unknown
        self.proj = nn.Identity() if x_dim is None or dim is None \
            else nn.Linear(x_dim, dim)

        # Build the fusion module from its string descriptor. (A previous
        # dead assignment of the raw string to `self.fusion`, immediately
        # overwritten here, has been removed)
        self.fusion = fusion_factory(fusion)

    def _encode(self, pos):
        """Compute positional encodings from raw positions `pos`.

        Must be overridden by child classes; the base implementation
        raises. The signature now matches how `forward` calls it, so the
        base class fails with NotImplementedError rather than TypeError.
        """
        raise NotImplementedError

    def forward(self, pos, x):
        """Encode positions and fuse them with the feature embeddings.

        :param pos: raw positions, passed to `_encode`
        :param x: feature embeddings or None. When provided, they are
            first projected by `self.proj` before fusion
        """
        if x is not None:
            x = self.proj(x)
        return self.fusion(self._encode(pos), x)
| |
|
| |
|
class CatInjection(BasePositionalInjection):
    """Simple child class of BasePositionalInjection equivalent to a
    CatFusion: positions are concatenated to the features untouched.
    """

    def __init__(self, **kwargs):
        # No encoding dimension nor feature projection is needed here
        super().__init__(dim=None, x_dim=None, fusion='cat')

    def _encode(self, pos):
        # Identity encoding: positions pass through unchanged
        return pos
| |
|
| |
|
class AdditiveInjection(BasePositionalInjection):
    """Simple child class of BasePositionalInjection equivalent to an
    AdditiveFusion: raw positions are added to the features.
    """

    def __init__(self, **kwargs):
        # No encoding dimension nor feature projection is needed here
        super().__init__(dim=None, x_dim=None, fusion='additive')

    def _encode(self, pos):
        # Identity encoding: positions pass through unchanged
        return pos
| |
|
| |
|
class AdditiveMLPInjection(BasePositionalInjection):
    def __init__(self, dim=None, in_dim=3, **kwargs):
        """Simple child class of BasePositionalInjection equivalent to
        an MLP followed by AdditiveFusion.

        :param dim: int
            Positional encoding dimension (output size of the MLP)
        :param in_dim: int
            Dimension of the input positions. Defaults to 3 (the value
            previously hard-coded), so existing callers are unaffected
        """
        super().__init__(dim=dim, x_dim=None, fusion='additive')

        # MLP lifting raw in_dim-dimensional positions to `dim` channels
        self.ffn = FFN(in_dim, out_dim=self.dim, activation=nn.LeakyReLU())

    def _encode(self, pos):
        # Learned encoding: pass raw positions through the MLP
        return self.ffn(pos)
| |
|
| |
|
class FourierInjection(BasePositionalInjection):
    def __init__(
            self, dim=None, x_dim=None, fusion='additive', f_min=1e-1,
            f_max=1e1, **kwargs):
        """Convert [N, M] M-dimensional positions into [N, dim] encodings
        using sine and cosine decomposition along each axis. Expects dim
        to be a multiple of 2*M, for each of the M-dimensions to have
        access to the same number of encoding dimensions.

        Input positions are expected to be normalized in [-1, 1] before
        encoding. This operation is important, since passing positions
        outside this range will result in ambiguities where two distinct
        positions have the same encoding.

        :param dim: int
            Positional encoding dimension. Required
        :param x_dim: int
            If provided, feature embeddings are projected 'x_dim -> dim'
            before fusion (see BasePositionalInjection)
        :param fusion: str
            Fusion mechanism to merge encodings with feature embeddings
        :param f_min: float
            Minimum frequency of the Fourier decomposition
        :param f_max: float
            Maximum frequency of the Fourier decomposition
        """
        # Explicit validation instead of `assert`, which is silently
        # stripped when Python runs with -O
        if dim is None:
            raise ValueError("`dim` must be provided for FourierInjection")
        super().__init__(dim=dim, x_dim=x_dim, fusion=fusion, **kwargs)
        self.f_min = f_min
        self.f_max = f_max

    def _encode(self, pos):
        # Deterministic sine/cosine encoding within [f_min, f_max]
        return fourier_position_encoder(
            pos, self.dim, f_min=self.f_min, f_max=self.f_max)
| |
|
| |
|
class LearnableFourierInjection(BasePositionalInjection):
    def __init__(self, M: int, F_dim: int, H_dim: int, D: int, gamma: float):
        """Learnable Fourier Features from:
        https://arxiv.org/pdf/2106.02795.pdf (Algorithm 1)

        Implements Algorithm 1: the Fourier-feature positional encoding
        of a multi-dimensional position, for tensors of shape [N, M].

        :param M: each point has M-dimensional positional values
        :param F_dim: depth of the Fourier feature dimension
        :param H_dim: hidden layer dimension
        :param D: positional encoding dimension
        :param gamma: parameter to initialize Wr
        """
        super().__init__()
        self.M = M
        self.F_dim = F_dim
        self.H_dim = H_dim
        self.D = D
        self.gamma = gamma

        # Random projection mapping M-dimensional positions to
        # F_dim // 2 frequencies (cos + sin later restore F_dim)
        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)

        # MLP head turning Fourier features into the final encoding
        self.ffn = FFN(
            self.F_dim, hidden_dim=self.H_dim, out_dim=self.D,
            activation=nn.GELU(), drop=None)

        self.init_weights()

    def init_weights(self):
        # Initialize Wr ~ N(0, gamma^-2), per the paper's prescription
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """Produce positional encodings from x.

        :param x: tensor of shape [N, G, M] representing N positions,
            each made of G positional groups with M-dimensional values.
            Positions in different positional groups are independent

        :return: positional encoding for x
        """
        # Unpacking enforces a rank-3 input
        N, G, M = x.shape

        # Project onto learned frequencies, expand into cosine/sine
        # components, and scale by 1 / sqrt(F_dim)
        freqs = self.Wr(x)
        fourier_feats = 1 / np.sqrt(self.F_dim) * torch.cat(
            [torch.cos(freqs), torch.sin(freqs)], dim=-1)

        # MLP head, then flatten groups into the final [N, D] encoding.
        # NOTE(review): this reshape assumes G * out_dim == D; with
        # out_dim=D it only holds for G == 1 — confirm with callers
        encoded = self.ffn(fourier_feats)
        return encoded.reshape((N, self.D))
| |
|