File size: 3,759 Bytes
cbe6208 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
"""Basic layers for composite models"""
import warnings
import torch
from torch import _VF, nn
class ImageEmbedding(nn.Module):
    """An embedding layer which appends an ID embedding as an extra channel on 3D inputs."""

    def __init__(self, num_embeddings, sequence_length, image_size_pixels, **kwargs):
        """An embedding layer which appends an ID embedding as an extra channel on 3D inputs.

        The learned embedding for each ID is a single 2D image; it is repeated at every step
        of the 1st dimension (assumed to be time) before being concatenated.

        Args:
            num_embeddings: Size of the dictionary of embeddings
            sequence_length: The time sequence length of the data.
            image_size_pixels: The spatial size of the image. Assumed square.
            **kwargs: See `torch.nn.Embedding` for more possible arguments.
        """
        super().__init__()
        self.sequence_length = sequence_length
        self.image_size_pixels = image_size_pixels
        # One flattened (image_size_pixels x image_size_pixels) image per ID.
        self._embed = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=image_size_pixels * image_size_pixels,
            **kwargs,
        )

    def forward(self, x, id):
        """Concatenate the ID embedding onto ``x`` as a new channel."""
        side = self.image_size_pixels
        # Look up the flat embedding and give it image shape: (batch, channel=1, time=1, H, W).
        id_channel = self._embed(id).reshape((-1, 1, 1, side, side))
        # Tile the single image across every time step so it aligns with x.
        id_channel = id_channel.repeat(1, 1, self.sequence_length, 1, 1)
        return torch.cat((x, id_channel), dim=1)
class CompleteDropoutNd(nn.Module):
    """A layer used to completely drop out all elements of a N-dimensional sample.

    Each sample will be zeroed out independently on every forward call with probability `p` using
    samples from a Bernoulli distribution.
    """

    __constants__ = ["p", "inplace", "n_dim"]
    p: float
    inplace: bool
    n_dim: int

    def __init__(self, n_dim, p=0.5, inplace=False):
        """A layer used to completely drop out all elements of a N-dimensional sample.

        Args:
            n_dim: Number of dimensions of each sample not including channels. E.g. a sample with
                shape (channel, time, height, width) would use `n_dim=3`.
            p: probability of a sample to be zeroed. Default: 0.5
            inplace: If set to `True`, will do this operation in-place. Default: `False`

        Raises:
            ValueError: If `p` is not a probability in [0, 1].
        """
        super().__init__()
        # `not 0 <= p <= 1` also rejects NaN, which `p < 0 or p > 1` would let through.
        if not 0 <= p <= 1:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p
        self.inplace = inplace
        self.n_dim = n_dim

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run dropout.

        Accepts either a single sample with `n_dim + 1` dimensions (channels first) or a batch
        with `n_dim + 2` dimensions. Dropout is only applied in training mode; in eval mode the
        input is returned unchanged (standard `feature_dropout` semantics).
        """
        p = self.p
        inp_dim = input.dim()
        if inp_dim not in (self.n_dim + 1, self.n_dim + 2):
            warn_msg = (
                f"CompleteDropoutNd: Received a {inp_dim}-D input. Expected either a single sample"
                f" with {self.n_dim+1} dimensions, or a batch of samples with {self.n_dim+2}"
                " dimensions."
            )
            # stacklevel=2 attributes the warning to the caller's line rather than this one.
            warnings.warn(warn_msg, stacklevel=2)

        is_batched = inp_dim == self.n_dim + 2
        if not is_batched:
            # Promote a single sample to a batch of one.
            input = input.unsqueeze_(0) if self.inplace else input.unsqueeze(0)

        # Insert a fake channel dim of size 1 so that `feature_dropout` (which zeroes entire
        # feature maps) treats each whole sample as one feature map and drops it as a unit.
        input = input.unsqueeze_(1) if self.inplace else input.unsqueeze(1)

        result = (
            _VF.feature_dropout_(input, p, self.training)
            if self.inplace
            else _VF.feature_dropout(input, p, self.training)
        )

        # Undo the shape changes made above.
        result = result.squeeze_(1) if self.inplace else result.squeeze(1)
        if not is_batched:
            result = result.squeeze_(0) if self.inplace else result.squeeze(0)
        return result
|