File size: 3,759 Bytes
cbe6208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Basic layers for composite models"""

import warnings

import torch
from torch import _VF, nn


class ImageEmbedding(nn.Module):
    """Embedding layer that adds a learned per-ID image to a 5D input as an extra channel."""

    def __init__(self, num_embeddings, sequence_length, image_size_pixels, **kwargs):
        """Create the lookup table mapping each ID to a flattened square image.

        Every ID owns one learned 2D image of `image_size_pixels` x `image_size_pixels`. In
        `forward` that image is repeated `sequence_length` times along dimension 2 (assumed to
        be time) and concatenated onto the input along dimension 1 (the channel dimension).

        Args:
            num_embeddings: Size of the dictionary of embeddings
            sequence_length: The time sequence length of the data.
            image_size_pixels: The spatial size of the image. Assumed square.
            **kwargs: Forwarded to `torch.nn.Embedding`; see it for the available options.
        """
        super().__init__()
        self.sequence_length = sequence_length
        self.image_size_pixels = image_size_pixels

        # One embedding vector per ID, long enough to be folded into a square image.
        flat_image_len = image_size_pixels * image_size_pixels
        self._embed = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=flat_image_len,
            **kwargs,
        )

    def forward(self, x, id):
        """Concatenate the embedding image looked up for `id` onto `x` as a new channel.

        Args:
            x: 5D tensor whose dimension 2 has length `sequence_length` and whose last two
                dimensions both have length `image_size_pixels` (required for the concatenation
                to succeed).
            id: Integer tensor of IDs, one per sample in `x`.

        Returns:
            `x` with one additional entry along dimension 1 holding the ID image.
        """
        side = self.image_size_pixels

        # Fold each embedding vector into a (1, 1, side, side) image per sample ...
        id_image = self._embed(id).reshape((-1, 1, 1, side, side))
        # ... and copy it to every step along dimension 2.
        id_channel = id_image.repeat(1, 1, self.sequence_length, 1, 1)

        return torch.cat((x, id_channel), dim=1)


class CompleteDropoutNd(nn.Module):
    """A layer used to completely drop out all elements of a N-dimensional sample.

    Each sample will be zeroed out independently on every forward call with probability `p` using
    samples from a Bernoulli distribution. Dropout is only applied while the module is in
    training mode; in evaluation mode the input passes through unchanged. As with the other
    torch dropout layers, samples which are kept are rescaled by `1 / (1 - p)`.

    """

    __constants__ = ["p", "inplace", "n_dim"]
    p: float
    inplace: bool
    n_dim: int

    def __init__(self, n_dim: int, p: float = 0.5, inplace: bool = False):
        """A layer used to completely drop out all elements of a N-dimensional sample.

        Args:
            n_dim: Number of dimensions of each sample not including channels. E.g. a sample with
                shape (channel, time, height, width) would use `n_dim=3`.
            p: probability of a whole sample being zeroed. Default: 0.5
            inplace: If set to `True`, will do this operation in-place. Default: `False`

        Raises:
            ValueError: If `p` is outside the closed interval [0, 1].
        """
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p
        self.inplace = inplace
        self.n_dim = n_dim

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run dropout

        Args:
            input: A single sample with `n_dim + 1` dimensions (channels first) or a batch of
                samples with `n_dim + 2` dimensions. Any other rank triggers a warning and the
                input is then treated as a single sample.

        Returns:
            A tensor with the same shape as `input`. When `inplace` is `True` this is `input`
            itself, with its values overwritten.
        """
        p = self.p
        inp_dim = input.dim()

        if inp_dim not in (self.n_dim + 1, self.n_dim + 2):
            warn_msg = (
                f"CompleteDropoutNd: Received a {inp_dim}-D input. Expected either a single sample"
                f" with {self.n_dim+1} dimensions, or a batch of samples with {self.n_dim+2}"
                " dimensions."
            )
            warnings.warn(warn_msg)

        is_batched = inp_dim == self.n_dim + 2

        # Feature dropout zeroes whole entries of dimension 1 for each entry of dimension 0.
        # Inserting a singleton dimension 1 therefore makes each whole sample a single
        # "feature". `unsqueeze` returns views which share storage with `input`, so the
        # caller's tensor never has its own shape modified - even if the kernel below raises.
        batch = input if is_batched else input.unsqueeze(0)
        as_feature = batch.unsqueeze(1)

        if self.inplace:
            # Writing through the view updates the caller's data in-place.
            _VF.feature_dropout_(as_feature, p, self.training)
            return input

        result = _VF.feature_dropout(as_feature, p, self.training).squeeze(1)
        if not is_batched:
            result = result.squeeze(0)

        return result