# SanDP/utils/nn/modules/sparse.py
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn.functional import embedding

from ..init import assign_tensor


class Embedding(nn.Module):
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices.
The input to the module is a list of indices, and the output is the corresponding
word embeddings.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
init_embedding (Tensor): If given, the embedding will be initialized with the given tensor.
freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
padding_idx (int, optional): If given, pads the output with zeros whenever it encounters the index.
max_norm (float, optional): If given, will renormalize the embeddings to always have a norm lesser than this
norm_type (float, optional): The p of the p-norm to compute for the max_norm option
scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the frequency of
the words in the mini-batch.
sparse (boolean, optional): if True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for
more details regarding sparse gradients.
Attributes:
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
Shape:
- Input: LongTensor `(N1, N2, ...,Nm, W)`, N = mini-batch, W = number of indices to extract per mini-batch
- Output: `(N1, N2, ..., Nm, W, embedding_dim)`
Notes:
Keep in mind that only a limited number of optimizers support
sparse gradients: currently it's `optim.SGD` (`cuda` and `cpu`),
and `optim.Adagrad` (`cpu`)
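
    Examples:
        A minimal sketch (the weights are randomly initialized, so only shapes are shown)::

            >>> embed = Embedding(10, 4, padding_idx=0)
            >>> indices = torch.LongTensor([[1, 2, 0], [3, 4, 5]])
            >>> embed(indices).size()
            torch.Size([2, 3, 4])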
"""

    def __init__(self, num_embeddings, embedding_dim, init_embedding=None, freeze=False, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False):
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))
        self.frozen = freeze
        self.sparse = sparse
        self.reset_parameters(init_embedding)

    def reset_parameters(self, init_embedding):
        if init_embedding is None:
            # Uniform init in [-sqrt(3/dim), sqrt(3/dim)]: each entry has
            # variance 1/dim, so each embedding vector has expected squared norm 1.
            scale = np.sqrt(3.0 / self.embedding_dim)
            self.weight.data.uniform_(-scale, scale)
        else:
            assign_tensor(self.weight, init_embedding)
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].fill_(0)
        if self.frozen:
            if init_embedding is None:
                warnings.warn('Freezing embeddings that are randomly initialized.')
            self.weight.requires_grad = False

    def freeze(self):
        self.weight.requires_grad = False
        self.frozen = True

    def forward(self, input):
        input_size = input.size()
        if input.dim() > 2:
            # Flatten all leading dimensions so the lookup sees a 2D index tensor,
            # then restore the original shape (plus embedding_dim) afterwards.
            num_inputs = int(np.prod(input_size[:-1]))
            input = input.view(num_inputs, input_size[-1])
        output_size = input_size + (self.embedding_dim,)
        # Pass padding_idx through unchanged: torch.nn.functional.embedding
        # accepts None, whereas the old -1 placeholder (a leftover from the
        # legacy backend API) would wrongly mark the last row of the weight
        # matrix as padding.
        return embedding(input, self.weight, self.padding_idx, self.max_norm,
                         self.norm_type, self.scale_grad_by_freq,
                         self.sparse).view(output_size)

    def __repr__(self):
        s = '{name}({num_embeddings}, {embedding_dim}'
        if self.padding_idx is not None:
            s += ', padding_idx={padding_idx}'
        if self.max_norm is not None:
            s += ', max_norm={max_norm}'
        if self.norm_type != 2:
            s += ', norm_type={norm_type}'
        if self.scale_grad_by_freq is not False:
            s += ', scale_grad_by_freq={scale_grad_by_freq}'
        if self.sparse is not False:
            s += ', sparse=True'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)
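

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original module).
    # Because of the relative import above, run it from the package root, e.g.
    # `python -m SanDP.utils.nn.modules.sparse`, rather than as a standalone script.
    # With sparse=True the gradient w.r.t. the weight matrix is a sparse tensor,
    # so only optimizers that support sparse gradients (such as optim.SGD) apply.
    embed = Embedding(10, 4, padding_idx=0, sparse=True)
    optimizer = torch.optim.SGD(embed.parameters(), lr=0.1)
    indices = torch.LongTensor([[1, 2, 0], [3, 4, 5]])
    loss = embed(indices).sum()
    loss.backward()
    assert embed.weight.grad.is_sparse  # sparse=True yields a sparse gradient
    optimizer.step()
    print(embed)  # __repr__ reports the non-default options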