| """Basic layers for composite models""" | |
| import warnings | |
| import torch | |
| from torch import _VF, nn | |
class ImageEmbedding(nn.Module):
    """Embedding layer that attaches a learned per-ID image to its input as an extra channel."""

    def __init__(self, num_embeddings, sequence_length, image_size_pixels, **kwargs):
        """Create the embedding table backing the layer.

        Every ID owns one flat vector of `image_size_pixels * image_size_pixels` values. In
        `forward` that vector is reshaped into a square 2D image, repeated `sequence_length`
        times along dimension 2 and concatenated onto the input along dimension 1.

        Args:
            num_embeddings: Size of the dictionary of embeddings
            sequence_length: Number of times the embedding image is repeated along dimension 2
                (presumably the time axis of the data).
            image_size_pixels: Side length of the embedding image. The image is square.
            **kwargs: Forwarded unchanged to `torch.nn.Embedding`.
        """
        super().__init__()
        self.sequence_length = sequence_length
        self.image_size_pixels = image_size_pixels
        # One row per ID; each row holds a flattened square image.
        self._embed = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=image_size_pixels * image_size_pixels,
            **kwargs,
        )

    def forward(self, x, id):
        """Concatenate the ID's embedding image onto `x` as one additional channel.

        Args:
            x: Tensor to extend. It must be concatenable along dim 1 with a tensor of shape
                (N, 1, sequence_length, image_size_pixels, image_size_pixels) - presumably
                laid out as (batch, channel, time, height, width).
            id: Indices into the embedding table (presumably one per batch element).

        Returns:
            `x` with the repeated embedding image appended along dimension 1.
        """
        side = self.image_size_pixels
        # Flat embedding vector -> a single-channel, single-step square image per ID.
        id_image = self._embed(id).reshape((-1, 1, 1, side, side))
        # Copy that image across every step of dimension 2.
        id_frames = id_image.repeat(1, 1, self.sequence_length, 1, 1)
        return torch.cat((x, id_frames), dim=1)
class CompleteDropoutNd(nn.Module):
    """Dropout layer that zeroes whole N-dimensional samples instead of individual elements.

    On every forward call each sample is zeroed out in full, independently of the others, with
    probability `p` using samples from a Bernoulli distribution.
    """

    __constants__ = ["p", "inplace", "n_dim"]
    p: float
    inplace: bool
    n_dim: int

    def __init__(self, n_dim, p=0.5, inplace=False):
        """Configure the layer.

        Args:
            n_dim: Number of dimensions of each sample not including channels. E.g. a sample with
                shape (channel, time, height, width) would use `n_dim=3`.
            p: Probability of a whole sample being zeroed. Default: 0.5
            inplace: If set to `True`, will do this operation in-place. Default: `False`

        Raises:
            ValueError: If `p` is smaller than 0 or larger than 1.
        """
        super().__init__()
        out_of_range = p < 0 or p > 1
        if out_of_range:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.n_dim = n_dim
        self.inplace = inplace
        self.p = p

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply whole-sample dropout to `input`.

        Args:
            input: A single sample with `n_dim + 1` dimensions or a batch of samples with
                `n_dim + 2` dimensions. Any other rank triggers a warning and is then handled
                like a single sample.

        Returns:
            A tensor with the same shape as `input`. When `inplace` is set, `input` itself is
            modified.
        """
        single_rank = self.n_dim + 1
        batched_rank = self.n_dim + 2
        rank = input.dim()
        if rank != single_rank and rank != batched_rank:
            warnings.warn(
                f"CompleteDropoutNd: Received a {rank}-D input. Expected either a single sample"
                f" with {single_rank} dimensions, or a batch of samples with {batched_rank}"
                " dimensions."
            )
        needs_batch_dim = rank != batched_rank

        # A leading batch axis is added for single samples, then a singleton axis is inserted at
        # position 1 so that each complete sample sits behind one "feature" slot; the feature
        # dropout kernel is relied upon to drop such a slot as a unit (torch semantics, not
        # visible in this file). Both extra axes are removed again afterwards.
        if self.inplace:
            if needs_batch_dim:
                input = input.unsqueeze_(0)
            input = input.unsqueeze_(1)
            result = _VF.feature_dropout_(input, self.p, self.training)
            result = result.squeeze_(1)
            if needs_batch_dim:
                result = result.squeeze_(0)
            return result

        if needs_batch_dim:
            input = input.unsqueeze(0)
        input = input.unsqueeze(1)
        result = _VF.feature_dropout(input, self.p, self.training)
        result = result.squeeze(1)
        if needs_batch_dim:
            result = result.squeeze(0)
        return result