| import torch | |
| from torch import Tensor | |
def get_dropout_mask(
    dropout: float,
    z: Tensor,
    training: bool,
    columnwise: bool = False,
) -> Tensor:
    """Build a rescaled dropout mask that broadcasts against ``z``.

    The mask is sampled on a slice of ``z`` that keeps the first axis and
    one of the two middle axes, and has size 1 on the remaining axes.
    Kept entries hold ``1 / (1 - rate)`` and dropped entries hold ``0``,
    where ``rate = dropout * training`` (so the mask is all ones when
    ``training`` is false).

    Parameters
    ----------
    dropout : float
        The dropout rate, expected to lie in ``[0, 1]``.
    z : torch.Tensor
        The 4-dimensional tensor the mask is meant for. Only its shape
        and device are used.
    training : bool
        Whether the model is in training mode. When false the effective
        rate is zero and nothing is dropped.
    columnwise : bool, optional
        If true, the mask has shape ``(z.shape[0], 1, z.shape[2], 1)``;
        otherwise ``(z.shape[0], z.shape[1], 1, 1)``.

    Returns
    -------
    torch.Tensor
        The floating-point dropout mask. With an effective rate of 1 (or
        more) the mask is all zeros rather than NaN.
    """
    # Multiplying by the boolean flag zeroes the rate outside of training.
    rate = dropout * training

    # Only the shape/device of this slice matter; its values are never read.
    if columnwise:
        template = z[:, 0:1, :, 0:1]
    else:
        template = z[:, :, 0:1, 0:1]

    # Sample in float32 regardless of z's dtype. torch.rand draws from
    # [0, 1), so ``>= rate`` keeps everything when rate == 0 and nothing
    # when rate >= 1.
    noise = torch.rand(template.shape, dtype=torch.float32, device=template.device)
    mask = (noise >= rate) * 1.0

    if rate >= 1.0:
        # Everything is dropped: the mask is already all zeros, and
        # rescaling by 1 / (1 - rate) would compute 0 * inf = NaN.
        return mask

    # Rescale the kept entries so the expected value of the mask is 1.
    return mask / (1.0 - rate)