Spaces:
Running
Running
| import torch.nn as nn | |
| from feature_extractor_models.base import modules | |
class TransposeX2(nn.Sequential):
    """Upsample a feature map 2x: ConvTranspose2d -> [BatchNorm2d] -> ReLU.

    The transposed convolution uses ``kernel_size=4, stride=2, padding=1``,
    which maps an ``H x W`` input to ``2H x 2W``.

    Args:
        in_channels: Number of input channels of the transposed convolution.
        out_channels: Number of output channels.
        use_batchnorm: If True, a ``BatchNorm2d(out_channels)`` is placed
            between the transposed convolution and the ReLU.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        layers = [
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=4, stride=2, padding=1
            ),
        ]
        if use_batchnorm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        # Initialise nn.Sequential exactly once, with the final layer list.
        # Layer order (and hence state_dict keys "0", "1", "2") matches the
        # previous insert-based construction, so checkpoints stay compatible.
        super().__init__(*layers)
class DecoderBlock(nn.Module):
    """One LinkNet decoder stage with an optional additive skip connection.

    The stage is: 1x1 ``modules.Conv2dReLU`` reducing ``in_channels`` to
    ``in_channels // 4``, then ``TransposeX2`` (2x spatial upsampling), then a
    1x1 ``modules.Conv2dReLU`` projecting to ``out_channels``.
    (``modules.Conv2dReLU`` is a project helper, presumably conv + optional
    BN + ReLU -- confirm against its definition.)

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the stage.
        use_batchnorm: Forwarded to every sub-module.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super().__init__()
        bottleneck = in_channels // 4
        # Sub-modules are created in the same order as before so that
        # parameter initialisation consumes the RNG identically.
        reduce_conv = modules.Conv2dReLU(
            in_channels,
            bottleneck,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )
        upsample = TransposeX2(bottleneck, bottleneck, use_batchnorm=use_batchnorm)
        expand_conv = modules.Conv2dReLU(
            bottleneck,
            out_channels,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )
        self.block = nn.Sequential(reduce_conv, upsample, expand_conv)

    def forward(self, x, skip=None):
        """Run the stage on ``x`` and add ``skip`` element-wise when given.

        NOTE(review): ``skip`` is assumed to match (or broadcast to) the
        stage's output shape -- not checked here.
        """
        decoded = self.block(x)
        return decoded if skip is None else decoded + skip
class LinknetDecoder(nn.Module):
    """LinkNet decoder: a chain of ``DecoderBlock`` stages fed by encoder skips.

    Args:
        encoder_channels: Channel counts of the encoder feature maps, in the
            same order as the features later passed to ``forward``. The first
            entry is dropped, and the rest are reversed so decoding starts
            from the head (last entry) of the encoder.
        prefinal_channels: Output channels of the final stage when
            ``n_blocks`` uses every available stage.
        n_blocks: Number of decoder stages. Must satisfy
            ``0 <= n_blocks <= len(encoder_channels) - 1``.
        use_batchnorm: Forwarded to every ``DecoderBlock``.

    Raises:
        ValueError: If ``n_blocks`` is outside the valid range. (Previously an
            oversized value failed with an opaque ``IndexError``.)
    """

    def __init__(
        self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True
    ):
        super().__init__()

        # remove first skip
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        channels = list(encoder_channels) + [prefinal_channels]
        # Block i maps channels[i] -> channels[i + 1], so at most
        # len(channels) - 1 blocks can be built.
        max_blocks = len(channels) - 1
        if not 0 <= n_blocks <= max_blocks:
            raise ValueError(
                f"n_blocks must be between 0 and {max_blocks} for "
                f"{max_blocks + 1} encoder channels, got {n_blocks}"
            )

        self.blocks = nn.ModuleList(
            [
                DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
                for i in range(n_blocks)
            ]
        )

    def forward(self, *features):
        """Decode encoder ``features`` (same order as ``encoder_channels``).

        The first feature is dropped. The last one is the decoder input, and
        the remaining ones (deepest first) are added as skips after each
        stage. Stages beyond the available skips run without one.
        """
        features = features[1:]  # remove first skip
        features = features[::-1]  # reverse channels to start from head of encoder

        x = features[0]
        skips = features[1:]

        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        return x