|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from src.utils import fourier_position_encoder |
|
|
from src.nn.fusion import fusion_factory |
|
|
from src.nn.mlp import FFN |
|
|
|
|
|
|
|
|
# Public names exported by this module, one per line
__all__ = [
    'CatInjection',
    'AdditiveInjection',
    'AdditiveMLPInjection',
    'FourierInjection',
    'LearnableFourierInjection',
]
|
|
|
|
|
|
|
|
class BasePositionalInjection(nn.Module):
    def __init__(self, dim=None, x_dim=None, fusion='additive', **kwargs):
        """Base class for positional information injection. Takes care
        of the fusion with a potential embedding vector.

        Child classes are expected to overwrite the `_encode` method.

        :param dim: int
            Positional encoding dimension
        :param x_dim: int
            If provided along with `dim`, the input feature embeddings
            will undergo a linear projection `x_dim -> dim` before being
            fused with the positional encodings. If either `x_dim` or
            `dim` is None, the features are left untouched
        :param fusion: str
            Fusion mechanism to merge positional encodings with feature
            embeddings. Passed as-is to `fusion_factory`
        :param kwargs:
            Extra keyword arguments are accepted and ignored
        """
        super().__init__()
        self.dim = dim

        # The projection is only instantiated when both the input and
        # output dimensions are known
        self.proj = nn.Identity() if x_dim is None or dim is None \
            else nn.Linear(x_dim, dim)

        self.fusion = fusion_factory(fusion)

    def _encode(self, pos):
        """Convert positions into positional encodings. Must be
        overwritten by child classes.

        The signature accepts `pos` to match how `forward` calls this
        hook, so that a child class missing the override fails with
        NotImplementedError rather than a TypeError.

        :param pos: positions to encode
        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError

    def forward(self, pos, x):
        """Encode the positions and fuse them with the features.

        :param pos: positions, passed to `_encode`
        :param x: feature embeddings, or None. When not None, the
            features go through `self.proj` before the fusion. When
            None, None is passed to the fusion module
        :return: output of the fusion module called on the positional
            encodings and the (optionally projected) features
        """
        if x is not None:
            x = self.proj(x)
        return self.fusion(self._encode(pos), x)
|
|
|
|
|
|
|
|
class CatInjection(BasePositionalInjection):
    def __init__(self, **kwargs):
        """Positional injection which hands the raw positions to a
        'cat' fusion, making it equivalent to a CatFusion.

        No encoding dimension nor feature dimension is forwarded to the
        parent class. Any keyword argument received here is ignored.
        """
        super().__init__(fusion='cat', x_dim=None, dim=None)

    def _encode(self, pos):
        """Positions are used as-is as positional encodings."""
        encoding = pos
        return encoding
|
|
|
|
|
|
|
|
class AdditiveInjection(BasePositionalInjection):
    def __init__(self, **kwargs):
        """Positional injection which hands the raw positions to an
        'additive' fusion, making it equivalent to an AdditiveFusion.

        No encoding dimension nor feature dimension is forwarded to the
        parent class. Any keyword argument received here is ignored.
        """
        super().__init__(fusion='additive', x_dim=None, dim=None)

    def _encode(self, pos):
        """Positions are used as-is as positional encodings."""
        encoding = pos
        return encoding
|
|
|
|
|
|
|
|
class AdditiveMLPInjection(BasePositionalInjection):
    def __init__(self, dim=None, pos_dim=3, **kwargs):
        """Simple child class of BasePositionalInjection equivalent to
        an MLP followed by AdditiveFusion.

        :param dim: int
            Positional encoding dimension, used as `out_dim` of the MLP
        :param pos_dim: int
            Dimension of the input positions fed to the MLP. Defaults
            to 3, the value that used to be hard-coded
        :param kwargs:
            Extra keyword arguments are accepted and ignored
        """
        super().__init__(dim=dim, x_dim=None, fusion='additive')

        # MLP converting `pos_dim`-dimensional positions into encodings
        self.ffn = FFN(pos_dim, out_dim=self.dim, activation=nn.LeakyReLU())

    def _encode(self, pos):
        """Encode the positions with the MLP.

        :param pos: positions, passed as-is to the MLP
        :return: output of the MLP
        """
        return self.ffn(pos)
|
|
|
|
|
|
|
|
class FourierInjection(BasePositionalInjection):
    def __init__(
            self, dim=None, x_dim=None, fusion='additive', f_min=1e-1,
            f_max=1e1, **kwargs):
        """Turn [N, M] M-dimensional positions into [N, dim] encodings
        through a sine and cosine decomposition along each axis. For
        each of the M dimensions to be given the same number of encoding
        dimensions, dim is expected to be a multiple of 2*M.

        Positions are expected to be normalized in [-1, 1] beforehand.
        This matters: positions outside of this range lead to
        ambiguities where two distinct positions share the same
        encoding.

        :param dim: positional encoding dimension. Must not be None
        :param x_dim: forwarded to the parent class
        :param fusion: forwarded to the parent class
        :param f_min: forwarded as `f_min` to `fourier_position_encoder`
        :param f_max: forwarded as `f_max` to `fourier_position_encoder`
        """
        assert dim is not None
        super().__init__(dim=dim, x_dim=x_dim, fusion=fusion, **kwargs)
        self.f_min, self.f_max = f_min, f_max

    def _encode(self, pos):
        """Encode the positions with `fourier_position_encoder`, using
        the stored encoding dimension and frequency bounds.
        """
        frequency_bounds = dict(f_min=self.f_min, f_max=self.f_max)
        return fourier_position_encoder(pos, self.dim, **frequency_bounds)
|
|
|
|
|
|
|
|
class LearnableFourierInjection(BasePositionalInjection):
    def __init__(self, M: int, F_dim: int, H_dim: int, D: int, gamma: float):
        """Learnable Fourier Features from:
        https://arxiv.org/pdf/2106.02795.pdf (Algorithm 1)

        Implementation of Algorithm 1: compute the Fourier feature
        positional encoding of a multi-dimensional position. Computes
        the positional encoding of a tensor of shape [N, G, M].

        Note that, unlike the parent class, `forward` here only takes
        positions as input and does not fuse them with features.

        :param M: each point has a M-dimensional positional values
        :param F_dim: depth of the Fourier feature dimension. Must be
            even, since half of the features are cosines and the other
            half are sines
        :param H_dim: hidden layer dimension
        :param D: positional encoding dimension
        :param gamma: parameter to initialize Wr
        :raises ValueError: if `F_dim` is odd
        """
        # `Wr` outputs F_dim // 2 features which are then doubled by the
        # cos/sin concatenation. An odd F_dim would produce F_dim - 1
        # features while the FFN is built for F_dim inputs, so fail early
        if F_dim % 2 != 0:
            raise ValueError(
                f"F_dim must be even to be split into cosine and sine "
                f"features, got F_dim={F_dim}")

        super().__init__()
        self.M = M
        self.F_dim = F_dim
        self.H_dim = H_dim
        self.D = D
        self.gamma = gamma

        # Learnable projection of the positions, bias-free
        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)

        # MLP converting the Fourier features into positional encodings
        self.ffn = FFN(
            self.F_dim, hidden_dim=self.H_dim, out_dim=self.D,
            activation=nn.GELU(), drop=None)
        self.init_weights()

    def init_weights(self):
        """Initialize the weights of `Wr` from a zero-mean normal
        distribution.
        """
        # NOTE(review): the paper initializes Wr ~ N(0, gamma^-2). If
        # gamma^-2 denotes the variance there, the std should be
        # gamma ** -1. Kept as-is to preserve the existing
        # initialization - to be confirmed
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Produce positional encodings from x.

        :param x: tensor of shape [N, G, M] that represents N positions
            where each position is in the shape of [G, M], where G is
            the positional group and each group has M-dimensional
            positional values. Positions in different positional groups
            are independent
        :return: positional encoding for x, with the groups flattened
            into the last dimension. For a single group (G=1), this is
            a tensor of shape [N, D]
        """
        # Unpacking also enforces that x is a 3D tensor
        N, G, _ = x.shape

        # Project the positions, then build the cos/sin Fourier features
        projected = self.Wr(x)
        scale = 1 / np.sqrt(self.F_dim)
        features = scale * torch.cat(
            [torch.cos(projected), torch.sin(projected)], dim=-1)

        # Convert the features of each group into an encoding
        Y = self.ffn(features)

        # Merge the per-group encodings of each position. Reshaping to
        # (N, D) would only be valid for G=1, for which G * D == D
        return Y.reshape(N, G * self.D)
|
|
|