Spaces:
Runtime error
Runtime error
| import torch | |
class LockedDropout(torch.nn.Module):
    """
    Implementation of locked (or variational) dropout.

    One Bernoulli mask is sampled per sequence and reused at every time step,
    so the same embedding dimensions are dropped along the whole sequence.
    Kept values are rescaled by ``1 / (1 - dropout_rate)``.

    :param dropout_rate: fraction of input units to drop, from 0 to 1.
        A falsy value (e.g. 0) disables dropout entirely.
    :param batch_first: if True, inputs are laid out as (batch, seq, ...) and
        the mask is shared across dimension 1; if False, inputs are laid out
        as (seq, batch, ...) and the mask is shared across dimension 0.
    :param inplace: accepted for API parity with ``torch.nn.Dropout``; it is
        only reported by ``extra_repr`` and ``forward`` always returns a new
        tensor.
    """

    def __init__(self, dropout_rate=0.5, batch_first=True, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.batch_first = batch_first
        self.inplace = inplace

    def forward(self, x):
        """
        Apply locked dropout to ``x``.

        :param x: input tensor, at least 2-D (typically (batch, seq, dim) when
            ``batch_first`` is True, (seq, batch, dim) otherwise).
        :return: ``x`` itself in eval mode or when ``dropout_rate`` is falsy;
            otherwise a new tensor ``mask * x`` of the same shape.
        """
        if not self.training or not self.dropout_rate:
            return x
        keep_prob = 1 - self.dropout_rate
        # The mask has size 1 along the time axis so that a single draw per
        # (sequence, feature) broadcasts over every time step.
        mask_shape = list(x.shape)
        mask_shape[1 if self.batch_first else 0] = 1
        # new_empty keeps x's dtype/device and never requires grad.
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        # With dropout_rate == 1 every unit is dropped; skipping the rescale
        # avoids 0 / 0, which would otherwise fill the output with NaN.
        if keep_prob > 0:
            mask.div_(keep_prob)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)
class WordDropout(torch.nn.Module):
    """
    Implementation of word dropout. Randomly drops out entire words
    (or characters) in embedding space.

    One Bernoulli draw is made per (dim 0, dim 1) position and broadcast over
    the last dimension, so a dropped position has its whole embedding vector
    zeroed. Kept values are NOT rescaled by ``1 / (1 - dropout_rate)``.

    :param dropout_rate: fraction of positions to drop, from 0 to 1.
        A falsy value (e.g. 0) disables dropout entirely.
    :param inplace: accepted for API parity with ``torch.nn.Dropout``; it is
        only reported by ``extra_repr`` and ``forward`` always returns a new
        tensor.
    """

    def __init__(self, dropout_rate=0.05, inplace=False):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.inplace = inplace

    def forward(self, x):
        """
        Apply word dropout to ``x``.

        :param x: input tensor; the mask shape below assumes it is 3-D,
            presumably (batch, seq, embedding_dim) — TODO confirm with callers.
        :return: ``x`` itself in eval mode or when ``dropout_rate`` is falsy;
            otherwise a new tensor ``mask * x`` of the same shape.
        """
        if not self.training or not self.dropout_rate:
            return x
        # Trailing size-1 axis: one keep/drop decision per position, broadcast
        # across the last (embedding) dimension. new_empty keeps x's
        # dtype/device and never requires grad.
        mask = x.new_empty(x.size(0), x.size(1), 1).bernoulli_(1 - self.dropout_rate)
        return mask * x

    def extra_repr(self):
        inplace_str = ", inplace" if self.inplace else ""
        return "p={}{}".format(self.dropout_rate, inplace_str)