nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
raw
history blame contribute delete
763 Bytes
import torch
from torch import Tensor
def get_dropout_mask(
    dropout: float,
    z: Tensor,
    training: bool,
    columnwise: bool = False,
) -> Tensor:
    """Get a shared (row- or column-wise) inverted-dropout mask for ``z``.

    One Bernoulli keep/drop decision is drawn per index of the second
    (row-wise, default) or third (``columnwise=True``) dimension of ``z``.
    It is broadcast over the remaining sliced dimensions. Kept entries are
    scaled by ``1 / (1 - dropout)`` so the expected value of ``mask * z``
    equals ``z``.

    Parameters
    ----------
    dropout : float
        The dropout rate (probability of dropping an entry).
    z : torch.Tensor
        The tensor the mask will be applied to. Only its shape and device are
        used. The slicing below indexes four dimensions, so ``z`` must have
        rank >= 4 (presumably ``(batch, N, N, channels)`` -- confirm with
        callers).
    training : bool
        Whether the model is in training mode. When False the effective
        rate is 0 and the mask is all ones.
    columnwise : bool, optional
        If True, share the mask along dimension 1 (one draw per index of
        dimension 2). Otherwise share it along dimension 2 (one draw per
        index of dimension 1).

    Returns
    -------
    torch.Tensor
        The dropout mask, with shape ``(z.shape[0], 1, z.shape[2], 1, ...)``
        when ``columnwise`` else ``(z.shape[0], z.shape[1], 1, 1, ...)``.
        Entries are ``0`` or ``1 / (1 - dropout)``. It lives on ``z.device``.
        When every entry is dropped (effective rate >= 1), the mask is all
        zeros rather than NaN.
    """
    # Outside training the rate collapses to 0, so every entry is kept.
    dropout = dropout * training
    # Singleton-sized slices so one draw is broadcast across the shared axes.
    v = z[:, 0:1, :, 0:1] if columnwise else z[:, :, 0:1, 0:1]
    if dropout >= 1.0:
        # Nothing is kept. The inverted-dropout rescale below would compute
        # 0 / 0 = NaN at exactly 1.0, so return an explicit zero mask in the
        # same dtype that ``bool_tensor * 1.0`` produces.
        return torch.zeros(v.shape, device=v.device)
    # ``>=`` (not ``>``) guarantees an all-ones mask when dropout == 0,
    # since torch.rand samples from [0, 1).
    d = torch.rand(v.shape, dtype=torch.float32, device=v.device) >= dropout
    d = d * 1.0 / (1.0 - dropout)
    return d