import torch


class LockedDropout(torch.nn.Module):
    """Locked (a.k.a. variational) dropout.

    One dropout mask is sampled per sequence and reused at every position along
    the sequence axis, so the same embedding dimensions are dropped for the
    whole sequence. Kept values are rescaled by ``1 / (1 - dropout_rate)``
    (inverted dropout) so the expected activation is unchanged.

    :param dropout_rate: fraction of units to drop, expected in [0, 1].
        A rate of 1 zeroes the input.
    :param batch_first: if True the input is laid out as ``(batch, seq, ...)``;
        otherwise as ``(seq, batch, ...)``. The mask is shared along the
        sequence axis selected by this flag.
    :param inplace: stored and shown in ``extra_repr`` (mirroring
        ``torch.nn.Dropout``), but not currently used by ``forward``.
    """

    def __init__(self, dropout_rate=0.5, batch_first=True, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.batch_first = batch_first
        self.inplace = inplace

    def forward(self, x):
        """Return ``x`` with locked dropout applied.

        Identity in eval mode or when ``dropout_rate`` is falsy (e.g. 0).
        """
        if not self.training or not self.dropout_rate:
            return x

        # The mask has size 1 along the sequence axis so it broadcasts over it.
        seq_dim = 1 if self.batch_first else 0
        mask_shape = list(x.shape)
        mask_shape[seq_dim] = 1

        keep_prob = 1 - self.dropout_rate
        # new_empty keeps x's dtype/device and does not require grad, so
        # gradients flow only through x.
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0:
            # Inverted-dropout scaling; skipped at rate 1 to avoid 0/0 -> NaN.
            mask.div_(keep_prob)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)


class WordDropout(torch.nn.Module):
    """Word dropout.

    Randomly zeroes entire words (or characters): one Bernoulli draw is made
    per position over every axis except the last, and that draw is broadcast
    over the last (embedding) axis. Unlike ``LockedDropout``, kept values are
    NOT rescaled.

    :param dropout_rate: probability of dropping each word, expected in [0, 1].
    :param inplace: stored and shown in ``extra_repr`` (mirroring
        ``torch.nn.Dropout``), but not currently used by ``forward``.
    """

    def __init__(self, dropout_rate=0.05, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.inplace = inplace

    def forward(self, x):
        """Return ``x`` with whole positions zeroed at random.

        Identity in eval mode or when ``dropout_rate`` is falsy (e.g. 0).
        """
        if not self.training or not self.dropout_rate:
            return x

        # Size 1 on the last axis: the whole embedding vector is kept or dropped.
        mask = x.new_empty((*x.shape[:-1], 1)).bernoulli_(1 - self.dropout_rate)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)