import torch
import torch.nn as nn
import torch.nn.functional as F

from feature_extractor_models.base import modules as md

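# This module sketches the decoder side of an MAnet-style segmentation network:
# PAB applies position-wise (self-attention style) reweighting to the deepest
# encoder feature, MFAB fuses upsampled high-level features with skip features
# via two squeeze-and-excitation style branches, and DecoderBlock is a plain
# upsample-and-convolve block used when no skip connection is available.
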
class PAB(nn.Module):
    def __init__(self, in_channels, out_channels, pab_channels=64):
        super(PAB, self).__init__()
        # Series of 1x1 conv to generate attention feature maps
        self.pab_channels = pab_channels
        self.in_channels = in_channels
        self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
        self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.map_softmax = nn.Softmax(dim=1)
        self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        bsize = x.size()[0]
        h = x.size()[2]
        w = x.size()[3]
        x_top = self.top_conv(x)
        x_center = self.center_conv(x)
        x_bottom = self.bottom_conv(x)

        # Flatten spatial dims: top becomes (B, C', HW); center and bottom are
        # transposed to (B, HW, C') and (B, HW, C) respectively.
        x_top = x_top.flatten(2)
        x_center = x_center.flatten(2).transpose(1, 2)
        x_bottom = x_bottom.flatten(2).transpose(1, 2)

        # Position-wise attention map of shape (B, HW, HW), softmax-normalized over
        # all entries, used to aggregate the bottom branch and added as a residual.
        sp_map = torch.matmul(x_center, x_top)
        sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h * w, h * w)
        sp_map = torch.matmul(sp_map, x_bottom)
        sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
        x = x + sp_map
        x = self.out_conv(x)
        return x

class MFAB(nn.Module):
    def __init__(
        self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16
    ):
        # MFAB is just a modified version of SE-blocks, one for skip, one for input
        super(MFAB, self).__init__()
        self.hl_conv = nn.Sequential(
            md.Conv2dReLU(
                in_channels,
                in_channels,
                kernel_size=3,
                padding=1,
                use_batchnorm=use_batchnorm,
            ),
            md.Conv2dReLU(
                in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm
            ),
        )
        reduced_channels = max(1, skip_channels // reduction)
        self.SE_ll = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(skip_channels, reduced_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, skip_channels, 1),
            nn.Sigmoid(),
        )
        self.SE_hl = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(skip_channels, reduced_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, skip_channels, 1),
            nn.Sigmoid(),
        )
        self.conv1 = md.Conv2dReLU(
            skip_channels
            + skip_channels,  # we transform C-prime from the high level to C from the skip connection
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

    def forward(self, x, skip=None):
        x = self.hl_conv(x)
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        attention_hl = self.SE_hl(x)
        if skip is not None:
            # Sum channel attention from the high-level and low-level (skip) paths,
            # reweight the upsampled features, then concatenate with the skip.
            attention_ll = self.SE_ll(skip)
            attention_hl = attention_hl + attention_ll
            x = x * attention_hl
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True):
        super().__init__()
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = md.Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class MAnetDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        reduction=16,
        use_batchnorm=True,
        pab_channels=64,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provided `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

        self.center = PAB(head_channels, head_channels, pab_channels=pab_channels)

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm)  # no attention type here
        blocks = [
            MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
            if skip_ch > 0
            else DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        # the last block has no skip connection -> use a simple decoder block
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features):
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)

        return x
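
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the decoder itself). It assumes a
# hypothetical encoder that returns six feature maps for a 256x256 input, one
# per stride (1, 2, 4, 8, 16, 32), with the channel counts listed below; any
# encoder matching `encoder_channels` behaves the same way here. It also
# assumes that `feature_extractor_models.base.modules.Conv2dReLU` is
# importable in this environment, since the blocks above depend on it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical encoder output channels, deepest stage last.
    encoder_channels = (3, 64, 256, 512, 1024, 2048)
    decoder_channels = (256, 128, 64, 32, 16)

    # Fake encoder features: stride-1 map first, then one map per downsampling.
    features = [
        torch.randn(2, ch, 256 // (2 ** i), 256 // (2 ** i))
        for i, ch in enumerate(encoder_channels)
    ]

    # PAB keeps the spatial size and channel count of its input.
    pab = PAB(encoder_channels[-1], encoder_channels[-1])
    assert pab(features[-1]).shape == features[-1].shape

    # Five decoder blocks upsample the 8x8 head back to 256x256 with 16 channels.
    decoder = MAnetDecoder(encoder_channels, decoder_channels, n_blocks=5)
    out = decoder(*features)
    assert out.shape == (2, 16, 256, 256)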