Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from feature_extractor_models.base import modules as md | |
class DecoderBlock(nn.Module):
    """One upsample-and-fuse decoder stage.

    ``forward`` upsamples ``x`` by 2x with nearest-neighbour interpolation. When
    a ``skip`` tensor is given, it concatenates it along dim 1 and applies
    ``attention1``. The result then always goes through
    ``conv1 -> conv2 -> attention2``.

    Args:
        in_channels: Channel count of the tensor passed as ``x``.
        skip_channels: Channel count of the ``skip`` tensor (callers pass 0
            when no skip is used).
        out_channels: Output size given to both ``md.Conv2dReLU`` layers.
        use_batchnorm: Forwarded to ``md.Conv2dReLU``.
        attention_type: Forwarded to ``md.Attention``.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        merged_channels = in_channels + skip_channels
        conv_opts = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        # Attribute names and creation order are kept fixed: they determine
        # state_dict keys, parameter ordering and RNG consumption at init time.
        self.conv1 = md.Conv2dReLU(merged_channels, out_channels, **conv_opts)
        self.attention1 = md.Attention(attention_type, in_channels=merged_channels)
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, **conv_opts)
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        """Upsample ``x`` 2x, optionally fuse ``skip``, then convolve.

        Args:
            x: Feature map to upsample (presumably NCHW, since channels are
                concatenated on dim 1 — TODO confirm).
            skip: Optional tensor concatenated with the upsampled ``x`` on
                dim 1. Its spatial size must match the upsampled ``x``.

        Returns:
            Output of ``attention2`` applied after the two convolutions.
        """
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is None:
            merged = upsampled
        else:
            # attention1 is only applied when a skip tensor has been fused in.
            merged = self.attention1(torch.cat((upsampled, skip), dim=1))
        return self.attention2(self.conv2(self.conv1(merged)))
class CenterBlock(nn.Sequential):
    """Two stacked ``md.Conv2dReLU`` layers (3x3 kernel, padding 1) as a Sequential.

    Args:
        in_channels: Input channel count for the first layer.
        out_channels: Output channel count for both layers.
        use_batchnorm: Forwarded to each ``md.Conv2dReLU``.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        layer_opts = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        # Arguments are evaluated left to right, so the first layer is still
        # constructed before the second (same init order as building them
        # into locals first).
        super().__init__(
            md.Conv2dReLU(in_channels, out_channels, **layer_opts),
            md.Conv2dReLU(out_channels, out_channels, **layer_opts),
        )
class UnetPlusPlusDecoder(nn.Module):
    """UNet++ decoder built from densely connected ``DecoderBlock`` nodes.

    Nodes live in ``self.blocks`` under keys ``"x_{depth_idx}_{layer_idx}"``.
    In ``forward`` (with ``features`` trimmed and reversed so index 0 is the
    head), node ``x_d_l`` works as follows:

    * Its input is ``features[d]`` when ``l == d``, otherwise node ``x_d_{l-1}``.
    * Its skip is the channel-wise concatenation of nodes ``x_i_l`` for
      ``d < i <= l``, followed by ``features[l + 1]``.

    A final node ``x_0_{depth}`` without a skip produces the decoder output.

    Args:
        encoder_channels: Channel counts of the encoder features, in the same
            order as the tensors later passed to ``forward``. The first entry
            is dropped.
        decoder_channels: Output channels of the decoder stages, one per block.
        n_blocks: Number of decoder stages; must equal ``len(decoder_channels)``.
        use_batchnorm: Forwarded to every ``DecoderBlock`` / ``CenterBlock``.
        attention_type: Forwarded to every ``DecoderBlock``.
        center: If true, ``self.center`` is a ``CenterBlock`` on the head
            channels, otherwise ``nn.Identity``. See the note in ``forward``.

    Raises:
        ValueError: If ``n_blocks != len(decoder_channels)`` or ``n_blocks < 1``.
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                f"Model depth is {n_blocks}, but you provide `decoder_channels` "
                f"for {len(decoder_channels)} blocks."
            )
        if n_blocks < 1:
            # Without this check the failure is an opaque IndexError on
            # ``self.out_channels[-1]`` below.
            raise ValueError(f"n_blocks must be at least 1, got {n_blocks}.")

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        self.in_channels = [head_channels] + list(decoder_channels[:-1])
        self.skip_channels = list(encoder_channels[1:]) + [0]
        self.out_channels = decoder_channels

        if center:
            self.center = CenterBlock(
                head_channels, head_channels, use_batchnorm=use_batchnorm
            )
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
        depth = len(self.in_channels) - 1

        # Insertion order is kept identical to the original construction so
        # ModuleDict/parameter ordering and init-time RNG use are unchanged.
        blocks = {}
        for layer_idx in range(depth):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    # Main path: fed by the previous decoder stage (or the head).
                    # Its skip is the layer_idx dense nodes plus the encoder
                    # feature, each with skip_channels[layer_idx] channels.
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    # Nested node: fed by node x_{depth_idx}_{layer_idx-1}.
                    # It outputs at the width of the encoder feature it fuses.
                    out_ch = self.skip_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (
                        layer_idx + 1 - depth_idx
                    )
                    in_ch = self.skip_channels[layer_idx - 1]
                blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(
                    in_ch, skip_ch, out_ch, **kwargs
                )
        # Final node: no skip connection.
        blocks[f"x_0_{depth}"] = DecoderBlock(
            self.in_channels[-1], 0, self.out_channels[-1], **kwargs
        )
        self.blocks = nn.ModuleDict(blocks)
        self.depth = depth

    def forward(self, *features):
        """Run the nested decoder over the encoder feature maps.

        Args:
            *features: Encoder feature maps in the same order as the
                ``encoder_channels`` given to ``__init__``. The first one is
                discarded.

        Returns:
            Output of the final node ``x_0_{depth}``.
        """
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        # NOTE(review): ``self.center`` is constructed in ``__init__`` but is not
        # applied anywhere in this method. With ``center=True`` its parameters
        # therefore get no gradient from this forward pass (unless a caller
        # outside this file uses it — confirm). Applying it here would change
        # the outputs of checkpoints trained with the current code, so it is
        # deliberately left unapplied.

        # start building dense connections
        dense_x = {}
        for layer_idx in range(self.depth):
            for depth_idx in range(self.depth - layer_idx):
                if layer_idx == 0:
                    key = f"x_{depth_idx}_{depth_idx}"
                    dense_x[key] = self.blocks[key](
                        features[depth_idx], features[depth_idx + 1]
                    )
                else:
                    dense_l_i = depth_idx + layer_idx
                    key = f"x_{depth_idx}_{dense_l_i}"
                    # Skip = every deeper node already computed at this
                    # resolution, followed by the matching encoder feature.
                    skip_parts = [
                        dense_x[f"x_{idx}_{dense_l_i}"]
                        for idx in range(depth_idx + 1, dense_l_i + 1)
                    ]
                    skip_parts.append(features[dense_l_i + 1])
                    cat_features = torch.cat(skip_parts, dim=1)
                    dense_x[key] = self.blocks[key](
                        dense_x[f"x_{depth_idx}_{dense_l_i - 1}"], cat_features
                    )

        final_key = f"x_0_{self.depth}"
        # With n_blocks == 1 (depth 0) no dense node exists. The original code
        # looked up "x_0_-1" and raised KeyError; the final block must instead
        # be fed the head feature directly.
        if self.depth > 0:
            last_input = dense_x[f"x_0_{self.depth - 1}"]
        else:
            last_input = features[0]
        dense_x[final_key] = self.blocks[final_key](last_input)
        return dense_x[final_key]