| """ |
| Module for autoencoder model. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class AutoEncoder(nn.Module):
    """
    3-D convolutional auto encoder.

    The model is a single ``nn.Sequential`` stored in ``self.network``:

    * three strided ``Conv3d`` stages (16, 32 and 64 output channels) that
      shrink the depth axis by 4 and the height/width axes by 8,
    * one stride-1 ``Conv3d`` that reduces 64 channels to 16,
    * three ``ConvTranspose3d`` stages that undo the down-sampling and map
      back to ``in_channels``, followed by ``Tanh``.

    All stages except the last are followed by ``GroupNorm`` (8 groups) and
    ``LeakyReLU``.

    Args:
        in_channels: Number of channels of the input; the reconstruction has
            the same number of channels. Defaults to 1, which reproduces the
            original fixed single-channel architecture (same layer order and
            same ``state_dict`` keys).

    Raises:
        ValueError: If ``in_channels`` is smaller than 1.
    """

    # Every convolution uses a 3x3x3 kernel with padding 1.
    _KERNEL = (3, 3, 3)
    _PADDING = 1
    _NUM_GROUPS = 8

    def __init__(self, in_channels: int = 1) -> None:
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")
        super().__init__()
        self.in_channels = in_channels

        # The layers are kept in one flat Sequential (indices 0..19) so that
        # parameter names such as ``network.0.weight`` stay stable.
        self.network = nn.Sequential(
            # Down-sampling: the first stage keeps depth, the next two halve it.
            *self._down(in_channels, 16, stride=(1, 2, 2)),
            *self._down(16, 32, stride=(2, 2, 2)),
            *self._down(32, 64, stride=(2, 2, 2)),
            # Stride-1 stage: same spatial size, 64 -> 16 channels.
            *self._down(64, 16, stride=(1, 1, 1)),
            # Up-sampling: mirror of the down-sampling strides.
            *self._up(16, 32, stride=(2, 2, 2)),
            *self._up(32, 16, stride=(2, 2, 2)),
            # Output stage: no normalisation, Tanh bounds values to [-1, 1].
            self._transposed_conv(16, in_channels, stride=(1, 2, 2)),
            nn.Tanh(),
        )

    @staticmethod
    def _down(in_channels: int, out_channels: int, stride: tuple) -> list:
        """Return ``[Conv3d, GroupNorm, LeakyReLU]`` for one convolution stage."""
        return [
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=AutoEncoder._KERNEL,
                stride=stride,
                padding=AutoEncoder._PADDING,
            ),
            nn.GroupNorm(num_groups=AutoEncoder._NUM_GROUPS, num_channels=out_channels),
            nn.LeakyReLU(),
        ]

    @staticmethod
    def _transposed_conv(in_channels: int, out_channels: int, stride: tuple) -> nn.ConvTranspose3d:
        """Return a ``ConvTranspose3d`` that scales each axis by exactly ``stride``."""
        # With kernel 3 and padding 1 the output size is
        # (n - 1) * s - 2 + 3 + output_padding, so output_padding = s - 1
        # yields exactly n * s along every axis.
        return nn.ConvTranspose3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=AutoEncoder._KERNEL,
            stride=stride,
            padding=AutoEncoder._PADDING,
            output_padding=tuple(s - 1 for s in stride),
        )

    @staticmethod
    def _up(in_channels: int, out_channels: int, stride: tuple) -> list:
        """Return ``[ConvTranspose3d, GroupNorm, LeakyReLU]`` for one up-sampling stage."""
        return [
            AutoEncoder._transposed_conv(in_channels, out_channels, stride),
            nn.GroupNorm(num_groups=AutoEncoder._NUM_GROUPS, num_channels=out_channels),
            nn.LeakyReLU(),
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct ``x``.

        Args:
            x: 5-D tensor laid out as ``(batch, depth, channels, height,
                width)`` with ``channels == in_channels``.

        Returns:
            Tensor in the same layout with values in ``[-1, 1]``. Its shape
            equals ``x.shape`` when depth is a multiple of 4 and height and
            width are multiples of 8; other sizes are rounded up by the
            strided stages, so the output is larger than the input.
        """
        # Conv3d expects (batch, channels, depth, H, W): swap axes 1 and 2.
        x = x.permute(0, 2, 1, 3, 4)
        x = self.network(x)
        # Swap back so the caller gets the layout it passed in.
        return x.permute(0, 2, 1, 3, 4)
| |
|
|
if __name__ == "__main__":
    # Smoke test: build the model and feed it one random batch.
    net = AutoEncoder()
    sample = torch.randn(2, 16, 1, 128, 128)

    # Step through the layers one by one, printing the shape after each,
    # starting from the same axis order that forward() hands to the network.
    hidden = sample.permute(0, 2, 1, 3, 4)
    for stage in net.network:
        hidden = stage(hidden)
        if not isinstance(hidden, torch.Tensor):
            continue
        print(type(stage).__name__, tuple(hidden.shape))

    # End-to-end pass: the reconstruction must have the input's shape.
    reconstruction = net(sample)
    print("out:", tuple(reconstruction.shape))
    assert reconstruction.shape == sample.shape