| import torch.nn as nn | |
| """ | |
| downsampling blocks | |
| (first half of the 'U' in UNet) | |
| [ENCODER] | |
| """ | |
class EncoderLayer(nn.Module):
    """One encoder stage: a stack of ``Conv2d(3x3) -> ReLU -> BatchNorm2d`` units.

    Args:
        in_channels: Channel count of the input; only the first conv uses it.
        out_channels: Channel count produced by every conv in the stack.
        n_layers: Number of conv -> ReLU -> BatchNorm units to stack. With 0
            the stage is an empty ``nn.Sequential`` and passes input through.
        all_padding: If True, every conv uses ``padding=1``. A 3x3 kernel then
            keeps H and W unchanged. If False, the first two convs use
            ``padding=0``, so each shrinks H and W by 2. Any further convs
            still use ``padding=1``.
        maxpool: Flag stored on the instance. ``Encoder`` reads it to decide
            whether to max-pool this stage's output. This module never pools
            by itself.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=64,
        n_layers=2,
        all_padding=False,
        maxpool=True,
    ):
        super().__init__()
        # Each unit gets the caller's in_channels only when it is the first
        # unit. Later units consume the previous unit's out_channels. Units 0
        # and 1 are unpadded unless all_padding is set. Deeper units always
        # pad by 1.
        self.layer = nn.Sequential(
            *(
                self._conv_relu_layer(
                    in_channels=in_channels if idx == 0 else out_channels,
                    out_channels=out_channels,
                    padding=1 if idx >= 2 or all_padding else 0,
                )
                for idx in range(n_layers)
            )
        )
        self.maxpool = maxpool

    def _conv_relu_layer(self, in_channels, out_channels, padding=0):
        """Build one unit: 3x3 ``Conv2d`` -> ``ReLU`` -> ``BatchNorm2d``.

        Note that the normalisation is applied after the activation, in that
        order.

        Args:
            in_channels: Input channels of the conv.
            out_channels: Output channels of the conv and of the BatchNorm.
            padding: Zero-padding passed to the conv (stride is the default 1).

        Returns:
            An ``nn.Sequential`` with submodules indexed 0 (conv), 1 (ReLU),
            and 2 (BatchNorm).
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=padding,
            ),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Run ``x`` through every stacked unit and return the result.

        ``x`` must be a tensor accepted by ``nn.Conv2d``. It is presumably
        (N, in_channels, H, W); confirm against the caller.
        """
        return self.layer(x)
class Encoder(nn.Module):
    """Downsampling half of the UNet: an ordered chain of ``EncoderLayer`` stages.

    Args:
        config: Mapping of stage name -> stage settings. Each settings entry
            must provide the keys "in_channels", "out_channels", "n_layers",
            "all_padding" and "maxpool". These are forwarded unchanged to
            ``EncoderLayer``. Stages run in the mapping's iteration order,
            which ``nn.ModuleDict`` preserves.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = nn.ModuleDict(
            {
                name: EncoderLayer(
                    in_channels=block["in_channels"],
                    out_channels=block["out_channels"],
                    n_layers=block["n_layers"],
                    all_padding=block["all_padding"],
                    maxpool=block["maxpool"],
                )
                for name, block in config.items()
            }
        )
        # One 2x2 max-pool module is reused after every stage whose maxpool
        # flag is set.
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        """Run ``x`` through every stage in order.

        Returns:
            A tuple ``(x, output)``:

            - ``x`` is the final stage's activation, max-pooled if that stage's
              ``maxpool`` flag is set.
            - ``output`` maps each stage name to that stage's activation
              *before* pooling. These are presumably consumed by the decoder
              as skip connections; confirm against the caller.
        """
        output = {}
        for block_name, block in self.encoder.items():
            x = block(x)
            output[block_name] = x
            if block.maxpool:
                x = self.maxpool(x)
        return x, output