import torch
from torch import Tensor


def get_dropout_mask(
    dropout: float,
    z: Tensor,
    training: bool,
    columnwise: bool = False,
) -> Tensor:
    """Get the dropout mask.

    The mask is an inverted-dropout mask: kept entries are scaled by
    ``1 / (1 - dropout)`` and dropped entries are zero. Multiplying ``z`` by
    it therefore preserves the expected value. The mask has singleton
    dimensions so that it broadcasts against ``z``. One random value is
    shared along dims 1 and 3 (``columnwise=True``) or along dims 2 and 3
    (``columnwise=False``).

    Parameters
    ----------
    dropout : float
        The dropout rate. It is only applied when ``training`` is true.
    z : torch.Tensor
        The tensor to apply dropout to. It is indexed with four indices, so
        it must have at least 4 dimensions.
    training : bool
        Whether the model is in training mode. When false, the effective rate
        is 0 and the returned mask is all ones.
    columnwise : bool, optional
        Whether to apply dropout columnwise.
        If true, the mask shape is ``(z.shape[0], 1, z.shape[2], 1)``.
        Otherwise it is ``(z.shape[0], z.shape[1], 1, 1)``.

    Returns
    -------
    torch.Tensor
        The dropout mask, on ``z``'s device, with a floating-point dtype
        (the result of multiplying a bool tensor by ``1.0``). If the
        effective rate is ``>= 1``, every entry is dropped and the mask is
        all zeros rather than NaN.
    """
    # Disable dropout outside of training (bool * float -> 0.0 or dropout).
    dropout = dropout * training

    # Slice to singleton dims so the mask broadcasts over the shared axes.
    v = z[:, 0:1, :, 0:1] if columnwise else z[:, :, 0:1, 0:1]

    # Always draw the random numbers, even when the result is trivially all
    # ones or all zeros, so the global RNG stream matches the original.
    d = torch.rand(v.shape, dtype=torch.float32, device=v.device) >= dropout

    if dropout >= 1.0:
        # torch.rand samples from [0, 1), so every entry is False here.
        # Rescaling would compute 0 / 0 = NaN (or -0.0 for rates > 1),
        # so return the plain zero mask instead.
        return d * 1.0

    # Inverted-dropout rescaling keeps the expected value of z * mask.
    d = d * 1.0 / (1.0 - dropout)
    return d