File size: 449 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
__all__ = ['dropout']


def dropout(a, p=0.5, dim=1, inplace=False, to_mean=False):
    """Randomly drop whole slices of ``a`` along dimension ``dim``.

    Each of the ``a.shape[dim]`` slices along ``dim`` is dropped independently
    with probability ``p``. A dropped slice is either filled with zeros or,
    when ``to_mean`` is true, overwritten with the mean of ``a`` taken over
    dimension 0 at the same positions. Unlike ``torch.nn.functional.dropout``,
    kept values are NOT rescaled by ``1 / (1 - p)``.

    Args:
        a: Input tensor. With ``to_mean=True`` its dtype must support
            ``Tensor.mean`` (e.g. floating point).
        p: Drop probability per slice. ``p <= 0`` drops nothing and
            ``p >= 1`` drops every slice, because the draws lie in ``[0, 1)``.
        dim: Dimension whose slices are dropped; negative values count
            from the last dimension.
        inplace: If true, ``a`` itself is modified and returned; otherwise
            a clone is modified.
        to_mean: If true, dropped slices get the dim-0 mean instead of zero.

    Returns:
        The tensor with the dropped slices overwritten (``a`` itself when
        ``inplace`` is true).
    """
    n = a.shape[dim]
    # One uniform draw per slice; torch.rand does not require grad, so no
    # detach is needed.
    to_drop = torch.where(torch.rand(n, device=a.device) < p)[0]
    out = a if inplace else a.clone()
    if to_drop.numel() == 0:
        return out
    if not to_mean:
        return out.index_fill_(dim, to_drop, 0)
    # Keep the reduced axis so the mean broadcasts back to a's full shape;
    # selecting the dropped indices along ``dim`` then yields a source of
    # exactly the right shape for any ``dim`` (including negative ones).
    # The mean is taken from ``a`` before ``out`` is written, which matters
    # when ``inplace`` makes ``out`` the same tensor as ``a``.
    mean = a.mean(dim=0, keepdim=True).expand_as(a)
    return out.index_copy_(dim, to_drop, mean.index_select(dim, to_drop))
|