|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class SharedDropout(nn.Module):
    """Dropout whose mask is sampled once and reused along one dimension.

    A Bernoulli mask is drawn for a single slice of the input (``x[:, 0]``
    when ``batch_first`` is true, ``x[0]`` otherwise) and broadcast over the
    dimension that was sliced away, so every position along that dimension
    drops the same elements. Kept elements are scaled by ``1 / (1 - p)``.

    Args:
        p: Probability of zeroing an element. Default: 0.5.
        batch_first: If true, the mask is shared along dimension 1;
            otherwise it is shared along dimension 0. Default: True.
    """

    def __init__(self, p=0.5, batch_first=True):
        super(SharedDropout, self).__init__()

        self.p = p
        self.batch_first = batch_first

    def extra_repr(self):
        """Return the options shown in the module's printed representation."""
        s = f"p={self.p}"
        # batch_first is only reported when it is enabled.
        if self.batch_first:
            s += f", batch_first={self.batch_first}"

        return s

    def forward(self, x):
        """Apply the shared mask to ``x`` in training mode.

        Args:
            x: Input tensor. It is indexed as ``x[:, 0]`` or ``x[0]``, so it
                needs at least two (respectively one) dimensions.

        Returns:
            ``x`` itself when the module is in eval mode; otherwise a new
            tensor holding ``x`` multiplied by the broadcast mask. The input
            is never modified.
        """
        if not self.training:
            return x

        if self.batch_first:
            # Re-insert the sliced-away dimension so the mask broadcasts
            # over dimension 1.
            mask = self.get_mask(x[:, 0], self.p).unsqueeze(1)
        else:
            mask = self.get_mask(x[0], self.p)

        # Multiply out of place: an in-place ``x *= mask`` would overwrite
        # the caller's tensor and raises for leaf tensors requiring grad.
        return x * mask

    @staticmethod
    def get_mask(x, p):
        """Sample a scaled keep-mask with the shape, dtype and device of ``x``.

        Args:
            x: Tensor whose shape, dtype and device the mask copies.
            p: Probability of zeroing an element.

        Returns:
            A tensor whose entries are ``1 / (1 - p)`` with probability
            ``1 - p`` and ``0`` otherwise.
        """
        if p == 1:
            # Everything is dropped; avoid computing 0 / 0 = NaN below.
            return x.new_zeros(x.shape)

        keep = 1 - p
        mask = x.new_empty(x.shape).bernoulli_(keep)
        mask = mask / keep

        return mask
|
|
|
|
|
|
|
|
class IndependentDropout(nn.Module):
    """Dropout that masks several inputs independently and rescales them jointly.

    For every input a separate Bernoulli mask of shape ``item.shape[:2]`` is
    drawn, so a whole trailing vector is kept or dropped at once. The masks
    are then multiplied by ``len(items) / max(kept, 1)``, where ``kept`` is
    the number of inputs that survived at a given position.

    Args:
        p: Probability of dropping an input at a position. Default: 0.5.
    """

    def __init__(self, p=0.5):
        super().__init__()

        self.p = p

    def extra_repr(self):
        """Return the options shown in the module's printed representation."""
        return f"p={self.p}"

    def forward(self, *items):
        """Drop and rescale the given tensors in training mode.

        Args:
            *items: Tensors to mask. Presumably all share their first two
                dimensions (e.g. ``[batch, seq, dim]``) -- the per-item masks
                are summed element-wise, which requires compatible shapes.

        Returns:
            The untouched ``items`` tuple in eval mode; in training mode a
            list with one masked tensor per input, in the same order.
        """
        if not self.training:
            return items

        keep_prob = 1 - self.p

        # One mask per input, sampled in argument order.
        masks = []
        for tensor in items:
            masks.append(tensor.new_empty(tensor.shape[:2]).bernoulli_(keep_prob))

        # Number of inputs kept at each position; clamped to at least one so
        # positions where everything was dropped do not divide by zero.
        total = sum(masks)
        scale = len(items) / total.max(torch.ones_like(total))

        dropped = []
        for tensor, mask in zip(items, masks):
            scaled_mask = mask * scale
            # Add a trailing axis so the mask broadcasts over the last dim.
            dropped.append(tensor * scaled_mask.unsqueeze(dim=-1))

        return dropped
|
|
|