# File size: 763 Bytes
import torch
from torch import Tensor
def get_dropout_mask(
    dropout: float,
    z: Tensor,
    training: bool,
    columnwise: bool = False,
) -> Tensor:
    """Build a rescaled dropout mask that broadcasts against ``z``.

    One keep/drop decision is sampled per entry of a thin slice of ``z``:
    ``z[:, :, 0:1, 0:1]`` by default, or ``z[:, 0:1, :, 0:1]`` when
    ``columnwise`` is set. Kept entries hold ``1 / (1 - dropout)`` and
    dropped entries hold ``0``, so multiplying ``z`` by the mask preserves
    its expected value.

    Parameters
    ----------
    dropout : float
        The dropout rate (probability of zeroing an entry). Expected to lie
        in ``[0, 1]``; a rate of ``1`` (or more) yields an all-zero mask.
    z : torch.Tensor
        The tensor to apply dropout to. It is indexed along its first four
        dimensions, so it must have at least four.
    training : bool
        Whether the model is in training mode. When false the effective
        rate is zero and the mask is all ones.
    columnwise : bool, optional
        If true, the mask varies along dimensions 0 and 2 of ``z`` and is
        shared along dimension 1; otherwise it varies along dimensions 0
        and 1 and is shared along dimension 2.

    Returns
    -------
    torch.Tensor
        The floating-point dropout mask, located on ``z``'s device and
        shaped like the selected slice of ``z``.
    """
    # Dropout is only active in training mode.
    rate = dropout if training else 0.0

    # Only the shape and device of this slice are used.
    if columnwise:
        template = z[:, 0:1, :, 0:1]
    else:
        template = z[:, :, 0:1, 0:1]

    keep = (
        torch.rand(template.shape, dtype=torch.float32, device=template.device)
        >= rate
    )
    mask = keep * 1.0

    # torch.rand samples from [0, 1), so with rate >= 1 nothing is kept and
    # `mask` is already all zeros. Rescaling would compute 0 / 0 and fill the
    # mask with NaN, so return the zeros as they are.
    if rate >= 1.0:
        return mask

    return mask / (1.0 - rate)
|