English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
raw
history blame contribute delete
449 Bytes
import torch
# Public API: the only name exported by a star-import of this module.
__all__ = ['dropout']
def dropout(a, p=0.5, dim=1, inplace=False, to_mean=False):
    """Randomly drop whole slices of ``a`` along ``dim``.

    Every index along ``dim`` is selected independently with probability
    ``p``. The selected slices are either zeroed or, when ``to_mean`` is set,
    overwritten with the mean of ``a`` taken over dimension 0. The slices
    that are kept are not rescaled.

    Args:
        a: Input tensor.
        p: Probability of dropping each index along ``dim``.
        dim: Dimension whose indices are dropped. Negative values count from
            the last dimension.
        inplace: If true, modify and return ``a`` itself; otherwise operate
            on a clone and leave ``a`` untouched.
        to_mean: If true, fill dropped slices with the mean over dimension 0
            instead of zeros.

    Returns:
        The tensor with the selected slices replaced (``a`` itself when
        ``inplace`` is true).
    """
    n = a.shape[dim]
    # One uniform draw per index along ``dim``; indices whose draw is < p drop.
    to_drop = torch.where(torch.rand(n, device=a.device) < p)[0]
    out = a if inplace else a.clone()

    if not to_mean:
        out.index_fill_(dim, to_drop, 0)
        return out

    # Normalise so that e.g. dim=-1 on a 2-D tensor is treated as dim=1
    # instead of silently falling through to the row-wise branch.
    dim = dim % a.dim()

    if dim == 0:
        # Dropped rows become the mean row; broadcasting fills each of them.
        out[to_drop] = a.mean(dim=0)
        return out

    # keepdim retains a size-1 leading axis so the same index tuple can
    # address both the mean and the output, for any ``dim`` >= 1. The mean is
    # computed from ``a`` before the write, which matters when ``inplace``.
    mean = a.mean(dim=0, keepdim=True)
    index = [slice(None)] * a.dim()
    index[dim] = to_drop
    index = tuple(index)
    out[index] = mean[index]
    return out