import torch

__all__ = ['dropout']


def dropout(a, p=0.5, dim=1, inplace=False, to_mean=False):
    """Randomly drop whole slices of ``a`` along ``dim``.

    Each of the ``a.shape[dim]`` slices along ``dim`` is dropped independently
    with probability ``p``. A dropped slice is either zero-filled or, when
    ``to_mean`` is true, overwritten with the mean of ``a`` over dimension 0.
    Unlike ``torch.nn.functional.dropout``, surviving values are NOT rescaled
    by ``1 / (1 - p)``.

    Args:
        a: Input tensor. With ``to_mean=True`` its dtype must support
            ``Tensor.mean`` (floating point or complex).
        p: Probability of dropping each slice. ``p <= 0`` drops nothing,
            ``p >= 1`` drops every slice.
        dim: Dimension whose slices are dropped. Negative values count from
            the end, as elsewhere in PyTorch.
        inplace: If true, modify ``a`` in place and return it; otherwise work
            on a clone and leave ``a`` untouched.
        to_mean: If false, dropped slices are filled with 0. If true, each
            dropped slice is replaced by the matching slice of the mean of
            ``a`` over dimension 0 (for ``dim == 0`` that is the mean sample
            itself; for other dims it is the per-position batch mean).

    Returns:
        The resulting tensor (``a`` itself when ``inplace`` is true).

    Raises:
        IndexError: If ``dim`` is out of range for ``a``.
    """
    # Indexing the shape first validates ``dim`` with the same IndexError as before.
    n = a.shape[dim]
    # Normalize negative dims so e.g. dim=-1 on a 2-D tensor behaves like dim=1.
    dim %= a.dim()

    # rand() samples from [0, 1), so each slice is dropped with probability p.
    to_drop = torch.where(torch.rand(n, device=a.device) < p)[0]
    out = a if inplace else a.clone()

    if not to_mean:
        out.index_fill_(dim, to_drop, 0)
        return out

    # Compute replacement values before any write, so an in-place call still
    # averages the original data.
    fill = a.mean(dim=0, keepdim=True)
    if dim != 0:
        # Keep only the dropped positions along ``dim``; the leading size-1
        # batch axis then broadcasts over every sample.
        fill = fill.index_select(dim, to_drop)

    # Build ``out[:, ..., to_drop, ..., :]`` with the index tensor on ``dim``;
    # a single advanced index keeps its axis position.
    index = [slice(None)] * a.dim()
    index[dim] = to_drop
    out[tuple(index)] = fill
    return out