import torch
import torch.nn as nn
import numpy as np

import mist_cf.common as common


class IntFeaturizer(nn.Module):
    """
    Base class for mapping integers to a vector representation (primarily to be used as a "richer" embedding for NNs
    processing integers).

    Subclasses should define `self.int_to_feat_matrix`, a matrix where each row is the vector representation for that
    integer, i.e. to get a vector representation for `5`, one could call `self.int_to_feat_matrix[5]`.

    Note that this class takes care of creating a fixed number (`self.NUM_EXTRA_EMBEDDINGS` to be precise) of extra
    "learned" embeddings these will be concatenated after the integer embeddings in the forward pass,
    be learned, and be used for extra  non-integer tokens such as the "to be confirmed token" (i.e., pad) token.
    They are indexed starting from `self.MAX_COUNT_INT`.
    """

    MAX_COUNT_INT = 255  # the maximum number of integers that we are going to see as a "count", i.e. 0 to MAX_COUNT_INT-1
    NUM_EXTRA_EMBEDDINGS = 1  # Number of extra embeddings to learn -- one for the "to be confirmed" embedding.

    def __init__(self, embedding_dim):
        super().__init__()
        weights = torch.zeros(self.NUM_EXTRA_EMBEDDINGS, embedding_dim)
        self._extra_embeddings = nn.Parameter(weights, requires_grad=True)
        nn.init.normal_(self._extra_embeddings, 0.0, 1.0)
        self.embedding_dim = embedding_dim

    def forward(self, tensor):
        """
        Convert the integer `tensor` into its new representation -- note that it gets stacked along final dimension.
        """
        # todo(jab): copied this code from the original in-built binarizer embedder in built into the class.
        # very similar to F.embedding but we want to put the embedding into the final dimension -- could ask Sam
        # why...

        orig_shape = tensor.shape
        out_tensor = torch.empty(
            (*orig_shape, self.embedding_dim),
            device=tensor.device,
            dtype=self.int_to_feat_matrix.dtype,
        )
        # Entries >= MAX_COUNT_INT index into the extra learned embeddings.
        extra_embed = tensor >= self.MAX_COUNT_INT

        tensor = tensor.long()
        norm_embeds = self.int_to_feat_matrix[tensor[~extra_embed]]
        extra_embeds = self._extra_embeddings[tensor[extra_embed] - self.MAX_COUNT_INT]

        out_tensor[~extra_embed] = norm_embeds
        out_tensor[extra_embed] = extra_embeds

        # Fold the embedding dimension into the final axis of the input shape.
        return out_tensor.reshape(*orig_shape[:-1], -1)

    @property
    def num_dim(self):
        return self.int_to_feat_matrix.shape[1]

    @property
    def full_dim(self):
        return self.num_dim * common.NORM_VEC.shape[0]
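
# Usage sketch (illustrative only; `Binarizer` below is one concrete subclass, and
# index `MAX_COUNT_INT` is routed to the learned "extra" pad embedding):
#
#   featurizer = Binarizer()
#   counts = torch.tensor([[3, 7, featurizer.MAX_COUNT_INT]])
#   out = featurizer(counts)  # shape: (1, 3 * featurizer.num_dim)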


class Binarizer(IntFeaturizer):
    def __init__(self):
        super().__init__(embedding_dim=len(common.num_to_binary(0)))
        int_to_binary_repr = np.vstack(
            [common.num_to_binary(i) for i in range(self.MAX_COUNT_INT)]
        )
        int_to_binary_repr = torch.from_numpy(int_to_binary_repr)
        self.int_to_feat_matrix = nn.Parameter(int_to_binary_repr.float())
        self.int_to_feat_matrix.requires_grad = False


class FourierFeaturizer(IntFeaturizer):
    """
    Inspired by:
    Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R.,
    Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional
     Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739.

    Some notes:
    * we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the
        Binarizer quite closely but be a bit smoother.
    """

    def __init__(self):

        num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
        # ^ need at least this many to ensure that the whole input range can be represented on the half circle.

        freqs = 0.5 ** torch.arange(num_freqs, dtype=torch.float32)
        freqs_times_2pi = 2 * np.pi * freqs

        super().__init__(
            embedding_dim=2 * freqs_times_2pi.shape[0]
        )  # 2 for cosine and sine

        # we can define the features at these frequencies up front (we will only ever see a fixed number of counts):
        combo_of_sinusoid_args = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * freqs_times_2pi[None, :]
        )
        all_features = torch.cat(
            [torch.cos(combo_of_sinusoid_args), torch.sin(combo_of_sinusoid_args)],
            dim=1,
        )

        # ^ shape: MAX_COUNT_INT x (2 * num_freqs)
        self.int_to_feat_matrix = nn.Parameter(all_features.float())
        self.int_to_feat_matrix.requires_grad = False
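
# In closed form (a sketch of what `int_to_feat_matrix` holds): row n is
#   [cos(2*pi * n / 2**k) for k in range(num_freqs)]
#       + [sin(2*pi * n / 2**k) for k in range(num_freqs)]
# so the lowest frequency completes about half a period over [0, MAX_COUNT_INT).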


class FourierFeaturizerSines(IntFeaturizer):
    """
    Like FourierFeaturizer, but with sine features only.

    Inspired by:
    Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R.,
    Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional
     Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739.

    Some notes:
    * we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the
        Binarizer quite closely but be a bit smoother.
    """

    def __init__(self):

        num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
        # ^ need at least this many to ensure that the whole input range can be represented on the half circle.

        freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:]
        # ^ the two highest frequencies are dropped.
        freqs_times_2pi = 2 * np.pi * freqs

        super().__init__(embedding_dim=freqs_times_2pi.shape[0])

        # we can define the features at these frequencies up front (we will only ever see a fixed number of counts):
        combo_of_sinusoid_args = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * freqs_times_2pi[None, :]
        )
        # ^ shape: MAX_COUNT_INT x (num_freqs - 2)
        self.int_to_feat_matrix = nn.Parameter(
            torch.sin(combo_of_sinusoid_args).float()
        )
        self.int_to_feat_matrix.requires_grad = False


class FourierFeaturizerAbsoluteSines(IntFeaturizer):
    """
    Like FourierFeaturizer, but with sine features only and the absolute value taken.

    Inspired by:
    Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R.,
    Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional
     Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739.

    Some notes:
    * we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the
        Binarizer quite closely but be a bit smoother.
    """

    def __init__(self):

        num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2

        freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:]
        # ^ the two highest frequencies are dropped.
        freqs_times_2pi = 2 * np.pi * freqs

        super().__init__(embedding_dim=freqs_times_2pi.shape[0])

        # we can define the features at these frequencies up front (we will only ever see a fixed number of counts):
        combo_of_sinusoid_args = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * freqs_times_2pi[None, :]
        )
        # ^ shape: MAX_COUNT_INT x (num_freqs - 2)
        self.int_to_feat_matrix = nn.Parameter(
            torch.abs(torch.sin(combo_of_sinusoid_args)).float()
        )
        self.int_to_feat_matrix.requires_grad = False


class RBFFeaturizer(IntFeaturizer):
    """
    A featurizer that puts radial basis functions evenly between 0 and max_count-1. These will have a width of
    (max_count-1) / (num_funcs) to decay to about 0.6 of its original height at reaching the next func.

    """

    def __init__(self, num_funcs=32):
        """
        :param num_funcs: number of radial basis functions to use; their width is chosen automatically -- see the
                            class docstring.
        """
        super().__init__(embedding_dim=num_funcs)
        width = (self.MAX_COUNT_INT - 1) / num_funcs
        centers = torch.linspace(0, self.MAX_COUNT_INT - 1, num_funcs)

        pre_exponential_terms = (
            -0.5
            * ((torch.arange(self.MAX_COUNT_INT)[:, None] - centers[None, :]) / width)
            ** 2
        )
        # ^ shape: MAX_COUNT_INT x num_funcs
        feats = torch.exp(pre_exponential_terms)

        self.int_to_feat_matrix = nn.Parameter(feats.float())
        self.int_to_feat_matrix.requires_grad = False
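
# Arithmetic check (a sketch): `torch.linspace` puts neighbouring centers
# (MAX_COUNT_INT - 1) / (num_funcs - 1) apart, slightly more than `width`, so each
# RBF evaluates to roughly exp(-0.5) ~= 0.61 of its peak at the next center.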


class OneHotFeaturizer(IntFeaturizer):
    """
    A featurizer that turns integers into their one-hot encoding.

    Represents:
     - 0 as 1000000000...
     - 1 as 0100000000...
     - 2 as 0010000000...
     and so on.
    """

    def __init__(self):
        super().__init__(embedding_dim=self.MAX_COUNT_INT)
        feats = torch.eye(self.MAX_COUNT_INT)
        self.int_to_feat_matrix = nn.Parameter(feats.float())
        self.int_to_feat_matrix.requires_grad = False


class LearnedFeaturizer(IntFeaturizer):
    """
    Learns the features for the different integers.

    Pretty much `nn.Embedding`, but we reuse the superclass's forward, which behaves a bit differently
    (e.g. it handles the extra embeddings).
    """

    def __init__(self, feature_dim=32):
        super().__init__(embedding_dim=feature_dim)
        weights = torch.zeros(self.MAX_COUNT_INT, feature_dim)
        self.int_to_feat_matrix = nn.Parameter(weights, requires_grad=True)
        nn.init.normal_(self.int_to_feat_matrix, 0.0, 1.0)


class FloatFeaturizer(IntFeaturizer):
    """
    Norms the features
    """

    def __init__(self):
        # embedding_dim=1 is a placeholder; this featurizer divides by a norm
        # vector rather than looking anything up in `int_to_feat_matrix`.
        super().__init__(embedding_dim=1)
        self.norm_vec = torch.from_numpy(common.NORM_VEC).float()
        self.norm_vec = nn.Parameter(self.norm_vec)
        self.norm_vec.requires_grad = False

    def forward(self, tensor):
        """
        Convert the integer `tensor` into its new representation -- note that it gets stacked along final dimension.
        """
        tens_shape = tensor.shape
        out_shape = [1] * (len(tens_shape) - 1) + [-1]
        return tensor / self.norm_vec.reshape(*out_shape)

    @property
    def num_dim(self):
        return 1


def get_embedder(embedder):
    """Map an embedder name to the corresponding featurizer instance."""
    if embedder == "binary":
        embedder = Binarizer()
    elif embedder == "fourier":
        embedder = FourierFeaturizer()
    elif embedder == "rbf":
        embedder = RBFFeaturizer()
    elif embedder == "one-hot":
        embedder = OneHotFeaturizer()
    elif embedder == "learnt":
        embedder = LearnedFeaturizer()
    elif embedder == "float":
        embedder = FloatFeaturizer()
    elif embedder == "fourier-sines":
        embedder = FourierFeaturizerSines()
    elif embedder == "abs-sines":
        embedder = FourierFeaturizerAbsoluteSines()
    else:
        raise NotImplementedError(f"Unknown embedder: {embedder}")
    return embedder
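

# A minimal smoke test (a sketch, not part of the library API): exercises each
# registered embedder name and prints input/output shapes. It relies only on the
# `mist_cf.common` attributes already imported above (`num_to_binary`, `NORM_VEC`).
if __name__ == "__main__":
    n_counts = common.NORM_VEC.shape[0]
    counts = torch.randint(0, IntFeaturizer.MAX_COUNT_INT, (4, n_counts))
    counts[0, 0] = IntFeaturizer.MAX_COUNT_INT  # pad token -> extra learned embedding
    for name in [
        "binary",
        "fourier",
        "rbf",
        "one-hot",
        "learnt",
        "float",
        "fourier-sines",
        "abs-sines",
    ]:
        feat = get_embedder(name)
        out = feat(counts)
        print(f"{name:>13}: {tuple(counts.shape)} -> {tuple(out.shape)}")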