File size: 797 Bytes
cba994c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from torch import Tensor


def get_dropout_mask(
    dropout: float,
    z: Tensor,
    training: bool,
    columnwise: bool = False,
) -> Tensor:
    """Get a broadcastable, rescaled dropout mask for ``z``.

    Only the shape and device of ``z`` are used. The mask is drawn over a
    slice of ``z`` with singleton axes, so multiplying ``z`` by it keeps or
    drops together every entry that shares the same indices on the mask's
    non-singleton axes. Kept entries are scaled by ``1 / (1 - dropout)`` so
    that ``z * mask`` has the same expected value as ``z``.

    Parameters
    ----------
    dropout : float
        The dropout rate; must lie in ``[0, 1]``.
    z : torch.Tensor
        The tensor to apply dropout to. It must have at least 4 dimensions
        (it is indexed along axes 0-3).
    training : bool
        Whether the model is in training mode. When False the effective
        rate is 0 and the returned mask is all ones.
    columnwise : bool, optional
        If True, the mask is built from ``z[:, 0:1, :, 0:1]``. For a 4-D
        ``z`` that gives shape ``(z.shape[0], 1, z.shape[2], 1)``.
        Otherwise it is built from ``z[:, :, 0:1, 0:1]``, giving shape
        ``(z.shape[0], z.shape[1], 1, 1)``.

    Returns
    -------
    torch.Tensor
        The dropout mask, on ``z``'s device, with floating-point dtype
        (torch's default dtype, from promoting a bool tensor with a Python
        float). Entries are ``0`` or ``1 / (1 - dropout)``. The mask is all
        zeros when the effective rate is 1.

    Raises
    ------
    ValueError
        If ``dropout`` is not within ``[0, 1]``.

    """
    if not 0.0 <= dropout <= 1.0:
        raise ValueError(
            f"dropout probability has to be between 0 and 1, but got {dropout}"
        )
    # Multiplying by the bool makes the effective rate 0 outside training.
    dropout = dropout * training
    v = z[:, 0:1, :, 0:1] if columnwise else z[:, :, 0:1, 0:1]
    # The random draw happens unconditionally, so RNG consumption does not
    # depend on ``training`` or on the rate.
    keep = torch.rand(v.shape, dtype=torch.float32, device=v.device) >= dropout
    if dropout >= 1.0:
        # rand() lies in [0, 1), so nothing is kept. Return zeros instead of
        # the 0 / 0 = NaN that the rescaling below would produce.
        return keep * 0.0
    return keep * 1.0 / (1.0 - dropout)