English
File size: 449 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch


# Public API of this module: the only name exported by a star-import.
__all__ = [
    "dropout",
]


def dropout(a, p=0.5, dim=1, inplace=False, to_mean=False):
    """Randomly drop whole slices of ``a`` along ``dim``.

    Each of the ``a.shape[dim]`` slices is dropped independently with
    probability ``p`` (one ``torch.rand`` draw per slice, on ``a.device``).
    Unlike ``torch.nn.functional.dropout``, surviving values are NOT
    rescaled by ``1 / (1 - p)``.

    Args:
        a: Input tensor.
        p: Probability of dropping each slice. ``p <= 0`` drops nothing and
            ``p >= 1`` drops every slice.
        dim: Axis whose slices are dropped. Negative values are accepted.
        inplace: If True, modify ``a`` and return it. Otherwise operate on
            a clone and leave ``a`` untouched.
        to_mean: If False, dropped slices are filled with 0. If True, they
            are filled with the mean of ``a`` over dim 0, broadcast into
            the dropped positions. For a 2-D input with ``dim=1`` this is
            per-column mean imputation; with ``dim=0`` every dropped row
            becomes the mean row. Requires a floating-point ``a``
            (``torch.mean``).

    Returns:
        The tensor with dropped slices filled (``a`` itself when
        ``inplace`` is True).
    """
    n = a.shape[dim]
    # Indices (along ``dim``) of the slices to drop.
    to_drop = torch.where(torch.rand(n, device=a.device) < p)[0]
    out = a if inplace else a.clone()

    if not to_mean:
        out.index_fill_(dim, to_drop, 0)
        return out

    # Build the fill values from ``a`` before writing: with inplace=True,
    # ``out`` and ``a`` are the same tensor. ``index_select`` returns a fresh
    # copy, so the write below cannot alias it. keepdim + expand_as makes the
    # dim-0 mean broadcastable to ``a``, so a single code path works for any
    # ``dim`` (the old version mis-indexed every dim other than 0 and 1).
    fill = a.mean(dim=0, keepdim=True).expand_as(a).index_select(dim, to_drop)
    out.index_copy_(dim, to_drop, fill)
    return out