import numpy as np
import torch
import torch.nn as nn
from src.utils import fourier_position_encoder
from src.nn.fusion import fusion_factory
from src.nn.mlp import FFN


__all__ = [
    'CatInjection', 'AdditiveInjection', 'AdditiveMLPInjection',
    'FourierInjection', 'LearnableFourierInjection']


class BasePositionalInjection(nn.Module):
    def __init__(self, dim=None, x_dim=None, fusion='additive', **kwargs):
        """Base class for positional information injection. Takes care
        fusion with a potential embedding vector.

        Child classes are expected to overwrite the `_encode` method.

        :param dim: int
            Positional encoding dimension
        :param x_dim: int
            If provided, the input feature embeddings will undergo a
            linear projection `x_dim -> dim` before being fused with the
            positional encodings
        :param fusion: str
            Fusion mechanism to merge positional encodings with feature
            embeddings
        """
        super().__init__()
        self.dim = dim
        self.proj = nn.Identity() if x_dim is None or dim is None \
            else nn.Linear(x_dim, dim)

        # Fusion operator
        self.fusion = fusion_factory(fusion)

    def _encode(self, pos):
        raise NotImplementedError

    def forward(self, pos, x):
        if x is not None:
            x = self.proj(x)
        return self.fusion(self._encode(pos), x)
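
# A child class only needs to implement `_encode(pos)`: the base `forward` projects
# the feature embeddings `x` (when `x_dim` is given) and fuses them with the encoded
# positions. Minimal usage sketch with the `CatInjection` child defined below
# (illustrative shapes; assumes the 'cat' fusion returned by `fusion_factory`
# concatenates its inputs along the last dimension):
#
#   inj = CatInjection()
#   pos = torch.rand(1024, 3)      # [N, 3] positions
#   x = torch.rand(1024, 61)       # [N, 61] feature embeddings
#   out = inj(pos, x)              # [N, 64] = [pos | x]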


class CatInjection(BasePositionalInjection):
    def __init__(self, **kwargs):
        """Simple child class of BasePositionalInjection equivalent to
        a CatFusion.
        """
        super().__init__(dim=None, x_dim=None, fusion='cat')

    def _encode(self, pos):
        return pos


class AdditiveInjection(BasePositionalInjection):
    def __init__(self, **kwargs):
        """Simple child class of BasePositionalInjection equivalent to
        an AdditiveFusion.
        """
        super().__init__(dim=None, x_dim=None, fusion='additive')

    def _encode(self, pos):
        return pos


class AdditiveMLPInjection(BasePositionalInjection):
    def __init__(self, dim=None, **kwargs):
        """Simple child class of BasePositionalInjection equivalent to
        an MLP followed by AdditiveFusion.
        """
        super().__init__(dim=dim, x_dim=None, fusion='additive')

        # FFN mapping raw 3D positions to `dim`-dimensional encodings
        self.ffn = FFN(3, out_dim=self.dim, activation=nn.LeakyReLU())

    def _encode(self, pos):
        return self.ffn(pos)
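
# Usage sketch for AdditiveMLPInjection (illustrative; assumes 3D input positions to
# match the hard-coded FFN input size, feature embeddings already of size `dim`
# since no `x_dim` projection is used, and an 'additive' fusion that sums its
# inputs element-wise):
#
#   inj = AdditiveMLPInjection(dim=64)
#   pos = torch.rand(2048, 3)      # [N, 3] positions
#   x = torch.rand(2048, 64)       # [N, 64] feature embeddings
#   out = inj(pos, x)              # [N, 64] = FFN(pos) + x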


class FourierInjection(BasePositionalInjection):
    def __init__(
            self, dim=None, x_dim=None, fusion='additive', f_min=1e-1,
            f_max=1e1, **kwargs):
        """Convert [N, M] M-dimensional positions into [N, dim] encodings
        using sine and cosine decomposition along each axis. Expects dim
        to be a multiple of 2*M, for each of the M-dimensions to have
        access to the same number of encoding dimensions.

        Input positions are expected to be normalized in [-1, 1] before
        encoding. This operation is important, since passing positions
        outside this range will result in ambiguities where two distinct
        positions have the same encoding.

        :param dim: int
            Positional encoding dimension
        :param f_min: float
            Lowest frequency used in the sine/cosine decomposition
        :param f_max: float
            Highest frequency used in the sine/cosine decomposition
        """
        assert dim is not None
        super().__init__(dim=dim, x_dim=x_dim, fusion=fusion, **kwargs)
        self.f_min = f_min
        self.f_max = f_max

    def _encode(self, pos):
        return fourier_position_encoder(
            pos, self.dim, f_min=self.f_min, f_max=self.f_max)
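
# Usage sketch for FourierInjection (illustrative; `dim=36` is a multiple of
# 2 * M = 6 for 3D positions, positions are normalized to [-1, 1] beforehand, and
# `x_dim=128` triggers a linear projection of the embeddings to `dim` before the
# additive fusion):
#
#   inj = FourierInjection(dim=36, x_dim=128, fusion='additive')
#   pos = torch.rand(1024, 3) * 2 - 1    # [N, 3] positions in [-1, 1]
#   x = torch.rand(1024, 128)            # [N, 128] feature embeddings
#   out = inj(pos, x)                    # [N, 36] position-aware features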


class LearnableFourierInjection(BasePositionalInjection):
    def __init__(self, M: int, F_dim: int, H_dim: int, D: int, gamma: float):
        """Learnable Fourier Features from:
            https://arxiv.org/pdf/2106.02795.pdf (Algorithm 1)

        Implementation of Algorithm 1: compute the Fourier feature
        positional encoding of a multi-dimensional position. Computes the
        positional encoding of a tensor of shape [N, G, M].

        :param M: each point has an M-dimensional positional value
        :param F_dim: depth of the Fourier feature dimension
        :param H_dim: hidden layer dimension
        :param D: positional encoding dimension
        :param gamma: parameter to initialize Wr
        """
        super().__init__()
        self.M = M
        self.F_dim = F_dim
        self.H_dim = H_dim
        self.D = D
        self.gamma = gamma

        # Projection matrix on learned lines (used in eq. 2)
        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
        # MLP: GELU(F @ W1 + B1) @ W2 + B2 (eq. 6)
        self.ffn = FFN(
            self.F_dim, hidden_dim=self.H_dim, out_dim=self.D,
            activation=nn.GELU(), drop=None)
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """Produce positional encodings from x.

        :param x: tensor of shape [N, G, M] representing N positions, where
            each position has shape [G, M], G is the number of positional
            groups, and each group holds M-dimensional positional values.
            Positions in different positional groups are independent.

        :return: positional encoding for x
        """
        N, G, M = x.shape
        # Step 1. Compute Fourier features (eq. 2)
        projected = self.Wr(x)
        cosines = torch.cos(projected)
        sines = torch.sin(projected)
        F = 1 / np.sqrt(self.F_dim) * torch.cat([cosines, sines], dim=-1)
        # Step 2. Compute projected Fourier features (eq. 6)
        Y = self.ffn(F)
        # Step 3. Reshape to [N, D] positional encodings
        PEx = Y.reshape((N, self.D))
        return PEx
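

# Usage sketch for LearnableFourierInjection (illustrative values; note that the
# final reshape to [N, D] requires Y to hold N * D elements, which is the case for
# a single positional group G = 1, since the FFN already outputs D features per
# group):
#
#   pe = LearnableFourierInjection(M=3, F_dim=128, H_dim=256, D=64, gamma=1.0)
#   x = torch.rand(1024, 1, 3)     # [N, G=1, M=3] positions
#   out = pe(x)                    # [N, 64] positional encodings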