import torch
import torch.nn as nn
import numpy as np

import mist_cf.common as common


class IntFeaturizer(nn.Module):
    """
    Base class for mapping integers to a vector representation (primarily to be used as a "richer" embedding for
    NNs processing integers).

    Subclasses should define `self.int_to_feat_matrix`, a matrix where each row is the vector representation for
    that integer, i.e. to get a vector representation for `5`, one could call `self.int_to_feat_matrix[5]`.

    Note that this class takes care of creating a fixed number (`self.NUM_EXTRA_EMBEDDINGS`, to be precise) of
    extra "learned" embeddings. These will be concatenated after the integer embeddings in the forward pass, be
    learned, and be used for extra non-integer tokens such as the "to be confirmed" (i.e., pad) token.
    They are indexed starting from `self.MAX_COUNT_INT`.
    """

    MAX_COUNT_INT = 255  # the maximum number of integers we will see as a "count", i.e. 0 to MAX_COUNT_INT - 1
    NUM_EXTRA_EMBEDDINGS = 1  # number of extra embeddings to learn -- one for the "to be confirmed" embedding

    def __init__(self, embedding_dim):
        super().__init__()
        weights = torch.zeros(self.NUM_EXTRA_EMBEDDINGS, embedding_dim)
        self._extra_embeddings = nn.Parameter(weights, requires_grad=True)
        nn.init.normal_(self._extra_embeddings, 0.0, 1.0)
        self.embedding_dim = embedding_dim

    def forward(self, tensor):
        """
        Convert the integer `tensor` into its new representation -- note that the embeddings are stacked along the
        final dimension.
        """
        # todo(jab): copied from the original binarizer embedder built into the class.
        # Very similar to F.embedding, but we want to put the embedding into the final dimension -- could ask Sam
        # why...
        orig_shape = tensor.shape
        out_tensor = torch.empty(
            (*orig_shape, self.embedding_dim), device=tensor.device
        )
        # Indices >= MAX_COUNT_INT select the extra learned embeddings; the rest use the fixed lookup table.
        extra_embed = tensor >= self.MAX_COUNT_INT
        tensor = tensor.long()
        norm_embeds = self.int_to_feat_matrix[tensor[~extra_embed]]
        extra_embeds = self._extra_embeddings[tensor[extra_embed] - self.MAX_COUNT_INT]
        out_tensor[~extra_embed] = norm_embeds
        out_tensor[extra_embed] = extra_embeds
        return out_tensor.reshape(*orig_shape[:-1], -1)

    def num_dim(self):
        return self.int_to_feat_matrix.shape[1]

    def full_dim(self):
        # `num_dim` is a method, so it must be called (the original multiplied the bound method itself,
        # which raises a TypeError).
        return self.num_dim() * common.NORM_VEC.shape[0]
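

# Illustrative sketch (not part of the original module): a minimal concrete subclass showing the forward-pass
# contract of `IntFeaturizer`. `_ToyFeaturizer` is a hypothetical name used only for this example.
def _demo_int_featurizer():
    class _ToyFeaturizer(IntFeaturizer):
        def __init__(self):
            super().__init__(embedding_dim=4)
            # One fixed 4-d feature row per integer in 0..MAX_COUNT_INT - 1.
            self.int_to_feat_matrix = nn.Parameter(
                torch.randn(self.MAX_COUNT_INT, 4), requires_grad=False
            )

    feat = _ToyFeaturizer()
    # The last entry is >= MAX_COUNT_INT, so it hits the learned "to be confirmed" (pad) embedding.
    counts = torch.tensor([[0, 5, feat.MAX_COUNT_INT]])
    out = feat(counts)
    assert out.shape == (1, 3 * 4)  # embeddings are stacked along the final dimension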


class Binarizer(IntFeaturizer):
    """
    Maps each integer count to its fixed binary (bit-vector) representation.
    """

    def __init__(self):
        super().__init__(embedding_dim=len(common.num_to_binary(0)))
        int_to_binary_repr = np.vstack(
            [common.num_to_binary(i) for i in range(self.MAX_COUNT_INT)]
        )
        int_to_binary_repr = torch.from_numpy(int_to_binary_repr)
        self.int_to_feat_matrix = nn.Parameter(int_to_binary_repr.float())
        self.int_to_feat_matrix.requires_grad = False
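

# Illustrative sketch (not part of the original module): row i of the Binarizer's lookup table is just the
# binary encoding of i (assuming `common.num_to_binary` returns a flat 0/1 numpy-compatible vector).
def _demo_binarizer():
    binarizer = Binarizer()
    expected = torch.from_numpy(np.asarray(common.num_to_binary(5))).float()
    assert torch.equal(binarizer.int_to_feat_matrix[5], expected)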


class FourierFeaturizer(IntFeaturizer):
    """
    Inspired by:
    Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R.,
    Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional
    Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739.

    Some notes:
    * We put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the
      Binarizer quite closely but be a bit smoother.
    """

    def __init__(self):
        num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
        # ^ need at least this many to ensure that the whole input range can be represented on the half circle.
        freqs = 0.5 ** torch.arange(num_freqs, dtype=torch.float32)
        freqs_time_2pi = 2 * np.pi * freqs
        super().__init__(
            embedding_dim=2 * freqs_time_2pi.shape[0]
        )  # 2 for cosine and sine
        # We can define the features up front, as we will only ever see a fixed number of counts:
        combo_of_sinusoid_args = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * freqs_time_2pi[None, :]
        )
        all_features = torch.cat(
            [torch.cos(combo_of_sinusoid_args), torch.sin(combo_of_sinusoid_args)],
            dim=1,
        )
        # ^ shape: MAX_COUNT_INT x (2 * num_freqs)
        self.int_to_feat_matrix = nn.Parameter(all_features.float())
        self.int_to_feat_matrix.requires_grad = False
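

# Illustrative sketch (not part of the original module): the Fourier lookup table is made of bounded
# sinusoids, so every feature lies in [-1, 1] and there is one row per representable count.
def _demo_fourier_featurizer():
    fourier = FourierFeaturizer()
    feats = fourier.int_to_feat_matrix
    assert feats.shape == (fourier.MAX_COUNT_INT, fourier.embedding_dim)
    assert feats.min().item() >= -1.0 and feats.max().item() <= 1.0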


class FourierFeaturizerSines(IntFeaturizer):
    """
    Like the other Fourier features, but sines only.

    Inspired by:
    Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R.,
    Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional
    Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739.

    Some notes:
    * We put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the
      Binarizer quite closely but be a bit smoother.
    """

    def __init__(self):
        num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
        # ^ need at least this many to ensure that the whole input range can be represented on the half circle.
        freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:]
        # ^ the two highest frequencies are dropped: sin(2*pi*n) and sin(pi*n) are identically zero at integer n.
        freqs_time_2pi = 2 * np.pi * freqs
        super().__init__(embedding_dim=freqs_time_2pi.shape[0])
        # We can define the features up front, as we will only ever see a fixed number of counts:
        combo_of_sinusoid_args = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * freqs_time_2pi[None, :]
        )
        # ^ shape: MAX_COUNT_INT x (num_freqs - 2), sines only
        self.int_to_feat_matrix = nn.Parameter(
            torch.sin(combo_of_sinusoid_args).float()
        )
        self.int_to_feat_matrix.requires_grad = False


class FourierFeaturizerAbsoluteSines(IntFeaturizer):
    """
    Like the other Fourier features, but sines only, with absolute values applied.

    Inspired by:
    Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R.,
    Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional
    Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739.

    Some notes:
    * We put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the
      Binarizer quite closely but be a bit smoother.
    """

    def __init__(self):
        num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
        freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:]
        # ^ the two highest frequencies are dropped: sin(2*pi*n) and sin(pi*n) are identically zero at integer n.
        freqs_time_2pi = 2 * np.pi * freqs
        super().__init__(embedding_dim=freqs_time_2pi.shape[0])
        # We can define the features up front, as we will only ever see a fixed number of counts:
        combo_of_sinusoid_args = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * freqs_time_2pi[None, :]
        )
        # ^ shape: MAX_COUNT_INT x (num_freqs - 2), absolute sines only
        self.int_to_feat_matrix = nn.Parameter(
            torch.abs(torch.sin(combo_of_sinusoid_args)).float()
        )
        self.int_to_feat_matrix.requires_grad = False


class RBFFeaturizer(IntFeaturizer):
    """
    A featurizer that places radial basis functions evenly between 0 and max_count - 1. Each has a width of
    (max_count - 1) / num_funcs, so it decays to about 0.6 of its peak height by the next function's center.
    """

    def __init__(self, num_funcs=32):
        """
        :param num_funcs: number of radial basis functions to use; their width is chosen automatically -- see the
            class docstring.
        """
        super().__init__(embedding_dim=num_funcs)
        width = (self.MAX_COUNT_INT - 1) / num_funcs
        centers = torch.linspace(0, self.MAX_COUNT_INT - 1, num_funcs)
        pre_exponential_terms = (
            -0.5
            * ((torch.arange(self.MAX_COUNT_INT)[:, None] - centers[None, :]) / width)
            ** 2
        )
        # ^ shape: MAX_COUNT_INT x num_funcs
        feats = torch.exp(pre_exponential_terms)
        self.int_to_feat_matrix = nn.Parameter(feats.float())
        self.int_to_feat_matrix.requires_grad = False
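

# Illustrative sketch (not part of the original module): numerically checks the docstring's decay claim --
# adjacent centers are roughly one width apart, and exp(-0.5) ≈ 0.61.
def _demo_rbf_featurizer():
    rbf = RBFFeaturizer(num_funcs=32)
    assert rbf.int_to_feat_matrix.shape == (rbf.MAX_COUNT_INT, 32)
    width = (rbf.MAX_COUNT_INT - 1) / 32
    spacing = (rbf.MAX_COUNT_INT - 1) / (32 - 1)  # gap between adjacent centers
    value_at_neighbor = float(np.exp(-0.5 * (spacing / width) ** 2))
    assert 0.55 < value_at_neighbor < 0.65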


class OneHotFeaturizer(IntFeaturizer):
    """
    A featurizer that turns integers into their one-hot encoding.

    Represents:
    - 0 as 1000000000...
    - 1 as 0100000000...
    - 2 as 0010000000...
    and so on.
    """

    def __init__(self):
        super().__init__(embedding_dim=self.MAX_COUNT_INT)
        feats = torch.eye(self.MAX_COUNT_INT)
        self.int_to_feat_matrix = nn.Parameter(feats.float())
        self.int_to_feat_matrix.requires_grad = False


class LearnedFeaturizer(IntFeaturizer):
    """
    Learns the features for the different integers.

    Pretty much `nn.Embedding`, but we get to use the forward of the superclass, which behaves a bit differently.
    """

    def __init__(self, feature_dim=32):
        super().__init__(embedding_dim=feature_dim)
        weights = torch.zeros(self.MAX_COUNT_INT, feature_dim)
        self.int_to_feat_matrix = nn.Parameter(weights, requires_grad=True)
        nn.init.normal_(self.int_to_feat_matrix, 0.0, 1.0)
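

# Illustrative sketch (not part of the original module): unlike the fixed featurizers above, the learned
# lookup table receives gradients through the shared forward pass.
def _demo_learned_featurizer():
    learned = LearnedFeaturizer(feature_dim=8)
    out = learned(torch.tensor([[1, 2, 3]]))
    out.sum().backward()
    assert learned.int_to_feat_matrix.grad is not None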


class FloatFeaturizer(IntFeaturizer):
    """
    Does not embed the integers at all; instead normalizes each count by the matching entry of `common.NORM_VEC`.
    """

    def __init__(self):
        # `embedding_dim=1` is a placeholder: the superclass forward is overridden below, so the embedding
        # machinery is unused.
        super().__init__(embedding_dim=1)
        self.norm_vec = torch.from_numpy(common.NORM_VEC).float()
        self.norm_vec = nn.Parameter(self.norm_vec)
        self.norm_vec.requires_grad = False

    def forward(self, tensor):
        """
        Normalize `tensor` elementwise: each position along the final dimension is divided by the corresponding
        entry of the norm vector, so the output keeps the input's shape.
        """
        tens_shape = tensor.shape
        out_shape = [1] * (len(tens_shape) - 1) + [-1]
        return tensor / self.norm_vec.reshape(*out_shape)

    def num_dim(self):
        return 1
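

# Illustrative sketch (not part of the original module): FloatFeaturizer keeps the input shape; the final
# dimension must therefore have length len(common.NORM_VEC).
def _demo_float_featurizer():
    float_feat = FloatFeaturizer()
    counts = torch.ones(2, common.NORM_VEC.shape[0])
    out = float_feat(counts)
    assert out.shape == counts.shape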


def get_embedder(embedder):
    if embedder == "binary":
        embedder = Binarizer()
    elif embedder == "fourier":
        embedder = FourierFeaturizer()
    elif embedder == "rbf":
        embedder = RBFFeaturizer()
    elif embedder == "one-hot":
        embedder = OneHotFeaturizer()
    elif embedder == "learnt":
        embedder = LearnedFeaturizer()
    elif embedder == "float":
        embedder = FloatFeaturizer()
    elif embedder == "fourier-sines":
        embedder = FourierFeaturizerSines()
    elif embedder == "abs-sines":
        embedder = FourierFeaturizerAbsoluteSines()
    else:
        raise NotImplementedError(f"Unknown embedder: {embedder}")
    return embedder
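

# Illustrative sketch (not part of the original module): typical use of the factory with a batch of counts.
if __name__ == "__main__":
    embedder = get_embedder("fourier")
    counts = torch.randint(0, IntFeaturizer.MAX_COUNT_INT, (4, 10))
    embedded = embedder(counts)
    print(embedded.shape)  # (4, 10 * embedder.embedding_dim): features stacked along the final dimension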