# nn_utils/form_embedder.py
import torch
import torch.nn as nn
import numpy as np
import mist_cf.common as common
class IntFeaturizer(nn.Module):
"""
Base class for mapping integers to a vector representation (primarily to be used as a "richer" embedding for NNs
processing integers).
Subclasses should define `self.int_to_feat_matrix`, a matrix where each row is the vector representation for that
integer, i.e. to get a vector representation for `5`, one could call `self.int_to_feat_matrix[5]`.
Note that this class takes care of creating a fixed number (`self.NUM_EXTRA_EMBEDDINGS` to be precise) of extra
"learned" embeddings these will be concatenated after the integer embeddings in the forward pass,
be learned, and be used for extra non-integer tokens such as the "to be confirmed token" (i.e., pad) token.
They are indexed starting from `self.MAX_COUNT_INT`.
"""
    MAX_COUNT_INT = 255  # number of distinct integers we expect to see as a "count", i.e. 0 to MAX_COUNT_INT - 1
NUM_EXTRA_EMBEDDINGS = 1 # Number of extra embeddings to learn -- one for the "to be confirmed" embedding.
def __init__(self, embedding_dim):
super().__init__()
weights = torch.zeros(self.NUM_EXTRA_EMBEDDINGS, embedding_dim)
self._extra_embeddings = nn.Parameter(weights, requires_grad=True)
nn.init.normal_(self._extra_embeddings, 0.0, 1.0)
self.embedding_dim = embedding_dim
def forward(self, tensor):
"""
        Convert the integer `tensor` into its new representation -- note that the embeddings are stacked along the
        final dimension.
"""
        # todo(jab): copied this code from the binarizer embedder originally built into the class.
        # Very similar to F.embedding, but we want to put the embedding into the final dimension -- could ask Sam
        # why...
orig_shape = tensor.shape
out_tensor = torch.empty(
(*orig_shape, self.embedding_dim), device=tensor.device
)
extra_embed = tensor >= self.MAX_COUNT_INT
tensor = tensor.long()
norm_embeds = self.int_to_feat_matrix[tensor[~extra_embed]]
extra_embeds = self._extra_embeddings[tensor[extra_embed] - self.MAX_COUNT_INT]
out_tensor[~extra_embed] = norm_embeds
out_tensor[extra_embed] = extra_embeds
        # fold the per-integer embeddings into the final dimension
        return out_tensor.reshape(*orig_shape[:-1], -1)
@property
def num_dim(self):
return self.int_to_feat_matrix.shape[1]
@property
def full_dim(self):
return self.num_dim * common.NORM_VEC.shape[0]
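# A minimal usage sketch of the index convention above (assumes the defaults
# MAX_COUNT_INT = 255 and NUM_EXTRA_EMBEDDINGS = 1): counts 0..254 are looked up
# in `int_to_feat_matrix`, while index 255 selects the learned "extra" (pad)
# embedding. For an input of shape (batch, n_elements) the forward pass returns
# shape (batch, n_elements * embedding_dim):
#
#     featurizer = OneHotFeaturizer()         # any concrete subclass below works
#     counts = torch.tensor([[0, 3, 255]])    # 255 -> the learned pad embedding
#     out = featurizer(counts)                # shape: (1, 3 * featurizer.embedding_dim)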
class Binarizer(IntFeaturizer):
    """
    Represents each integer count by its fixed (non-learned) binary expansion from `common.num_to_binary`.
    """
def __init__(self):
super().__init__(embedding_dim=len(common.num_to_binary(0)))
int_to_binary_repr = np.vstack(
[common.num_to_binary(i) for i in range(self.MAX_COUNT_INT)]
)
int_to_binary_repr = torch.from_numpy(int_to_binary_repr)
        self.int_to_feat_matrix = nn.Parameter(int_to_binary_repr.float(), requires_grad=False)
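# Sanity-check sketch (by construction, row i of the matrix is just the binary
# representation of i as produced by `common.num_to_binary`):
#
#     b = Binarizer()
#     row_5 = b.int_to_feat_matrix[5]   # fixed binary features for a count of 5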
class FourierFeaturizer(IntFeaturizer):
"""
Inspired by:
Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R.,
Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional
Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739.
Some notes:
* we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the
Binarizer quite closely but be a bit smoother.
"""
def __init__(self):
num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
# ^ need at least this many to ensure that the whole input range can be represented on the half circle.
freqs = 0.5 ** torch.arange(num_freqs, dtype=torch.float32)
freqs_time_2pi = 2 * np.pi * freqs
super().__init__(
embedding_dim=2 * freqs_time_2pi.shape[0]
) # 2 for cosine and sine
# we will define the features at this frequency up front (as we only will ever see a fixed number of counts):
combo_of_sinusoid_args = (
torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
* freqs_time_2pi[None, :]
)
all_features = torch.cat(
[torch.cos(combo_of_sinusoid_args), torch.sin(combo_of_sinusoid_args)],
dim=1,
)
# ^ shape: MAX_COUNT_INT x 2 * num_freqs
        self.int_to_feat_matrix = nn.Parameter(all_features.float(), requires_grad=False)
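# Worked form of the features (read off the construction above): with K = num_freqs
# and freq_k = 0.5 ** k, row n of `int_to_feat_matrix` is
#
#     [cos(2*pi*n/2^0), ..., cos(2*pi*n/2^(K-1)), sin(2*pi*n/2^0), ..., sin(2*pi*n/2^(K-1))]
#
# With the default MAX_COUNT_INT = 255, K = 10 and the slowest component has
# period 2^9 = 512 >= 2 * 255, so the whole count range 0..254 stays on the half
# circle, as the comment on num_freqs above intends.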
class FourierFeaturizerSines(IntFeaturizer):
"""
    Like `FourierFeaturizer`, but with sine features only.
Inspired by:
Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R.,
Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional
Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739.
Some notes:
* we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the
Binarizer quite closely but be a bit smoother.
"""
def __init__(self):
num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
# ^ need at least this many to ensure that the whole input range can be represented on the half circle.
        freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:]  # drop the two highest frequencies
freqs_time_2pi = 2 * np.pi * freqs
super().__init__(embedding_dim=freqs_time_2pi.shape[0])
# we will define the features at this frequency up front (as we only will ever see a fixed number of counts):
combo_of_sinusoid_args = (
torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
* freqs_time_2pi[None, :]
)
        # ^ shape: MAX_COUNT_INT x (num_freqs - 2)
        self.int_to_feat_matrix = nn.Parameter(
            torch.sin(combo_of_sinusoid_args).float(), requires_grad=False
        )
class FourierFeaturizerAbsoluteSines(IntFeaturizer):
"""
    Like `FourierFeaturizerSines`, but with the absolute value of each sine taken.
Inspired by:
Tancik, M., Srinivasan, P.P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R.,
Barron, J.T. and Ng, R. (2020) ‘Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional
Domains’, arXiv [cs.CV]. Available at: http://arxiv.org/abs/2006.10739.
Some notes:
* we'll put the frequencies at powers of 1/2 rather than random Gaussian samples; this means it will match the
Binarizer quite closely but be a bit smoother.
"""
def __init__(self):
num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
        freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:]  # drop the two highest frequencies
freqs_time_2pi = 2 * np.pi * freqs
super().__init__(embedding_dim=freqs_time_2pi.shape[0])
# we will define the features at this frequency up front (as we only will ever see a fixed number of counts):
combo_of_sinusoid_args = (
torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
* freqs_time_2pi[None, :]
)
        # ^ shape: MAX_COUNT_INT x (num_freqs - 2)
        self.int_to_feat_matrix = nn.Parameter(
            torch.abs(torch.sin(combo_of_sinusoid_args)).float(), requires_grad=False
        )
class RBFFeaturizer(IntFeaturizer):
"""
A featurizer that puts radial basis functions evenly between 0 and max_count-1. These will have a width of
(max_count-1) / (num_funcs) to decay to about 0.6 of its original height at reaching the next func.
"""
def __init__(self, num_funcs=32):
"""
        :param num_funcs: number of radial basis functions to use; their width is chosen automatically -- see the
            class docstring.
"""
super().__init__(embedding_dim=num_funcs)
width = (self.MAX_COUNT_INT - 1) / num_funcs
centers = torch.linspace(0, self.MAX_COUNT_INT - 1, num_funcs)
pre_exponential_terms = (
-0.5
* ((torch.arange(self.MAX_COUNT_INT)[:, None] - centers[None, :]) / width)
** 2
)
# ^ shape: MAX_COUNT_INT x num_funcs
feats = torch.exp(pre_exponential_terms)
        self.int_to_feat_matrix = nn.Parameter(feats.float(), requires_grad=False)
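# Arithmetic behind the ~0.6 decay claimed in the docstring: neighboring centers
# are spaced (MAX_COUNT_INT - 1) / (num_funcs - 1) apart while the width is
# (MAX_COUNT_INT - 1) / num_funcs, so at the next center each Gaussian has value
#
#     exp(-0.5 * (num_funcs / (num_funcs - 1)) ** 2) ~= exp(-0.53) ~= 0.59
#
# for the default num_funcs = 32.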
class OneHotFeaturizer(IntFeaturizer):
"""
    A featurizer that turns integers into their one-hot encoding.
Represents:
- 0 as 1000000000...
- 1 as 0100000000...
- 2 as 0010000000...
and so on.
"""
def __init__(self):
super().__init__(embedding_dim=self.MAX_COUNT_INT)
feats = torch.eye(self.MAX_COUNT_INT)
        self.int_to_feat_matrix = nn.Parameter(feats.float(), requires_grad=False)
class LearnedFeaturizer(IntFeaturizer):
"""
Learns the features for the different integers.
    Essentially `nn.Embedding`, but it reuses the forward of the superclass, which behaves a bit differently.
"""
def __init__(self, feature_dim=32):
super().__init__(embedding_dim=feature_dim)
weights = torch.zeros(self.MAX_COUNT_INT, feature_dim)
self.int_to_feat_matrix = nn.Parameter(weights, requires_grad=True)
nn.init.normal_(self.int_to_feat_matrix, 0.0, 1.0)
class FloatFeaturizer(IntFeaturizer):
    """
    Keeps counts as scalars and normalizes each one by the corresponding entry of `common.NORM_VEC`.
    """
    def __init__(self):
        # embedding_dim=1 is a placeholder: counts stay scalar rather than being embedded.
        super().__init__(embedding_dim=1)
        norm_vec = torch.from_numpy(common.NORM_VEC).float()
        self.norm_vec = nn.Parameter(norm_vec, requires_grad=False)
def forward(self, tensor):
"""
Convert the integer `tensor` into its new representation -- note that it gets stacked along final dimension.
"""
tens_shape = tensor.shape
out_shape = [1] * (len(tens_shape) - 1) + [-1]
return tensor / self.norm_vec.reshape(*out_shape)
@property
def num_dim(self):
return 1
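# A minimal sketch of the normalization (assumes `common.NORM_VEC` is a 1-D
# array with one constant per element-count slot, so the input's final
# dimension must match its length):
#
#     f = FloatFeaturizer()
#     counts = torch.ones(2, f.norm_vec.shape[0])
#     normed = f(counts)   # elementwise counts / NORM_VEC, same shape as counts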
def get_embedder(embedder):
    """Map an embedder name to the corresponding featurizer instance."""
if embedder == "binary":
embedder = Binarizer()
elif embedder == "fourier":
embedder = FourierFeaturizer()
elif embedder == "rbf":
embedder = RBFFeaturizer()
elif embedder == "one-hot":
embedder = OneHotFeaturizer()
elif embedder == "learnt":
embedder = LearnedFeaturizer()
elif embedder == "float":
embedder = FloatFeaturizer()
elif embedder == "fourier-sines":
embedder = FourierFeaturizerSines()
elif embedder == "abs-sines":
embedder = FourierFeaturizerAbsoluteSines()
else:
        raise NotImplementedError(f"Unknown embedder: {embedder}")
return embedder
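# Example usage (a minimal sketch; any of the names handled above works):
#
#     embedder = get_embedder("fourier")
#     counts = torch.randint(0, embedder.MAX_COUNT_INT, (8, 4))
#     feats = embedder(counts)   # shape: (8, 4 * embedder.embedding_dim)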