File size: 2,544 Bytes
"""
Module for the autoencoder model.
"""
import torch
import torch.nn as nn
class AutoEncoder(nn.Module):
    """
    3D convolutional autoencoder for single-channel clips.

    The model consumes and produces tensors laid out as (B, T, C, H, W) with
    C == 1. Internally the clip is processed as (B, C, T, H, W) by a stack of
    ``Conv3d`` / ``ConvTranspose3d`` layers, each followed by ``GroupNorm`` and
    ``LeakyReLU``; the final layer is followed by ``Tanh`` instead, so outputs
    are bounded to [-1, 1].

    The encoder shrinks T by a factor of 4 and H/W by a factor of 8 (rounding
    up at every strided convolution); the decoder expands by the same
    factors. ``forward`` crops the decoded volume back to the input size, so
    the output shape always equals the input shape, including for inputs
    whose T is not a multiple of 4 or whose H/W are not multiples of 8.
    """

    def __init__(self):
        super().__init__()
        # NOTE: the layer order (and therefore the Sequential indices) defines
        # the state_dict keys; keep it stable to stay checkpoint-compatible.
        self.network = nn.Sequential(
            # Encoder layers: the first conv keeps T and halves H/W, the next
            # two halve T, H and W.
            nn.Conv3d(in_channels=1, out_channels=16, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=1),
            nn.GroupNorm(num_groups=8, num_channels=16),
            nn.LeakyReLU(),
            nn.Conv3d(in_channels=16, out_channels=32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=1),
            nn.GroupNorm(num_groups=8, num_channels=32),
            nn.LeakyReLU(),
            nn.Conv3d(in_channels=32, out_channels=64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=1),
            nn.GroupNorm(num_groups=8, num_channels=64),
            nn.LeakyReLU(),
            # Bottleneck: same resolution, channels reduced from 64 to 16.
            nn.Conv3d(in_channels=64, out_channels=16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=1),
            nn.GroupNorm(num_groups=8, num_channels=16),
            nn.LeakyReLU(),
            # Decoder layers: output_padding=1 makes each stride-2 transposed
            # conv exactly double T, H and W.
            nn.ConvTranspose3d(in_channels=16, out_channels=32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=1, output_padding=(1, 1, 1)),
            nn.GroupNorm(num_groups=8, num_channels=32),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(in_channels=32, out_channels=16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=1, output_padding=(1, 1, 1)),
            nn.GroupNorm(num_groups=8, num_channels=16),
            nn.LeakyReLU(),
            # Output layer: keeps T, doubles H/W, maps back to one channel.
            nn.ConvTranspose3d(in_channels=16, out_channels=1, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=1, output_padding=(0, 1, 1)),
            nn.Tanh(),
        )

    def forward(self, x):
        """
        Reconstruct a clip.

        Args:
            x: Tensor of shape (B, T, C, H, W) with C == 1, the layout the
                dataloader provides.

        Returns:
            Tensor with the same shape as ``x`` and values in [-1, 1].
        """
        # Permute to match the layout Conv3d expects.
        x = x.permute(0, 2, 1, 3, 4)  # (B,T,C,H,W) -> (B,C,T,H,W)
        frames, height, width = x.shape[2:]
        x = self.network(x)
        # Strided convs round sizes up and the decoder doubles them back, so
        # the decoded volume is never smaller than the input and can be larger
        # when the input is not stride-aligned. Crop so out.shape == in.shape;
        # for aligned inputs this slice is a no-op.
        x = x[:, :, :frames, :height, :width]
        x = x.permute(0, 2, 1, 3, 4)  # backwards: (B,C,T,H,W) -> (B,T,C,H,W)
        return x
if __name__ == "__main__":
    # Smoke test: check that the output shape matches the input shape.
    net = AutoEncoder()
    sample = torch.randn(2, 16, 1, 128, 128)

    # Debug trace: permute a separate copy by hand and push it through the
    # layers one at a time, printing the shape after each layer.
    trace = torch.permute(sample, (0, 2, 1, 3, 4))
    for stage in net.network.children():
        trace = stage(trace)
        if not isinstance(trace, torch.Tensor):
            continue
        print(stage.__class__.__name__, tuple(trace.size()))

    # Real forward pass through the model's own forward().
    result = net(sample)
    print("out:", tuple(result.shape))
    assert result.shape == sample.shape