varox34's picture
Upload 64 files
366b225 verified
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
class SharedDropout(nn.Module):
    """Dropout that reuses a single mask along one (sequence) dimension.

    In training mode one Bernoulli keep-mask is drawn from the slice at the
    first timestep and broadcast across the shared dimension, so the same
    units are dropped at every position. Kept values are scaled by
    ``1 / (1 - p)``. In eval mode the input is returned unchanged.

    Args:
        p (float): probability of dropping a unit. Default: 0.5.
        batch_first (bool): if ``True`` the mask is shared along dim 1
            (presumably ``x`` is ``(batch, seq, ...)``); otherwise it is
            shared along dim 0 (presumably ``(seq, batch, ...)``).
            Default: ``True``.
    """

    def __init__(self, p=0.5, batch_first=True):
        super().__init__()
        self.p = p
        self.batch_first = batch_first

    def extra_repr(self):
        s = f"p={self.p}"
        if self.batch_first:
            s += f", batch_first={self.batch_first}"
        return s

    def forward(self, x):
        """Apply the shared dropout mask to ``x`` when training.

        Returns a new tensor instead of modifying ``x`` in place, so the
        caller's tensor is left intact and autograd works even when ``x`` is
        a leaf tensor that requires grad.
        """
        if not self.training:
            return x
        if self.batch_first:
            # Mask from the first slice along dim 1, re-inserted as a size-1
            # dim so it broadcasts across every position of dim 1.
            mask = self.get_mask(x[:, 0], self.p).unsqueeze(1)
        else:
            # Mask from the first slice along dim 0; broadcasting over the
            # leading dim happens implicitly.
            mask = self.get_mask(x[0], self.p)
        return x * mask

    @staticmethod
    def get_mask(x, p):
        """Return a scaled Bernoulli keep-mask with the shape/dtype/device of ``x``.

        Each entry is ``1 / (1 - p)`` with probability ``1 - p`` and ``0``
        otherwise. When ``p == 1`` every entry is ``0``; the rescale is
        skipped to avoid a 0/0 that would fill the mask with NaN.
        """
        keep = 1 - p
        mask = x.new_empty(x.shape).bernoulli_(keep)
        if keep > 0:
            mask = mask / keep
        return mask
class IndependentDropout(nn.Module):
    """Drop whole feature vectors of several inputs, independently per input.

    For each input a Bernoulli keep-mask is drawn over every dimension except
    the last, so entire trailing-dim vectors are zeroed. At each position the
    masks are then rescaled by ``len(items) / max(n_kept, 1)``, where
    ``n_kept`` is the number of inputs kept at that position. The kept inputs
    are scaled up to compensate for the dropped ones.

    Args:
        p (float): probability of dropping each vector. Default: 0.5.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def extra_repr(self):
        return f"p={self.p}"

    def forward(self, *items):
        """Apply independent vector dropout to ``items``.

        All inputs are expected to share every dimension except the last
        (their masks are summed elementwise).

        Returns:
            In training mode, a list of new tensors, one per input. In eval
            mode, the inputs themselves, unchanged, as the original tuple.
        """
        if not self.training:
            return items
        if not items:
            return []
        # One mask entry per trailing-dim vector (shape x.shape[:-1]); for
        # 3-D inputs this is the same as x.shape[:2].
        masks = [x.new_empty(x.shape[:-1]).bernoulli_(1 - self.p)
                 for x in items]
        total = sum(masks)
        # Clamp the kept-count at 1 so positions where every input was
        # dropped do not divide by zero (their masks are all 0 anyway).
        scale = len(items) / total.max(torch.ones_like(total))
        return [item * (mask * scale).unsqueeze(-1)
                for item, mask in zip(items, masks)]