import numpy as np
import torch
from torch import nn
from typing import Union, List, Tuple

from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks
from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp
from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder
from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder


class UNetDecoder(nn.Module):
    def __init__(self,
                 encoder: Union[PlainConvEncoder, ResidualEncoder],
                 num_classes: int,
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision: bool,
                 nonlin_first: bool = False):
        """
        This class needs the skips of the encoder as input in its forward.

        The encoder goes all the way down to the bottleneck, which is where the decoder picks up. Decoder stages
        are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
        features and the lowest skip as inputs.
        Each decoder stage has two (three) parts:
        1) a transposed convolution that upsamples the feature maps of the stage below it (or of the bottleneck,
           in the case of the first stage)
        2) n_conv_per_stage conv blocks that merge the upsampled features with the corresponding skip
        3) (optional, if deep_supervision=True) a segmentation output. Todo: enable upsample logits?

        :param encoder: the encoder whose skip connections this decoder consumes
        :param num_classes: number of segmentation output channels
        :param n_conv_per_stage: number of conv blocks per decoder stage (an int is broadcast to all stages)
        :param deep_supervision: whether segmentation outputs at all resolutions are expected (a segmentation
                                 head is built for every stage regardless, see below)
        """
        super().__init__()
        self.deep_supervision = deep_supervision
        self.encoder = encoder
        self.num_classes = num_classes

        n_stages_encoder = len(encoder.output_channels)
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
        assert len(n_conv_per_stage) == n_stages_encoder - 1, \
            "n_conv_per_stage must have one entry per decoder stage (n_stages in encoder - 1 = %d), got %d" \
            % (n_stages_encoder - 1, len(n_conv_per_stage))

        # upsampling uses the transposed counterpart of the encoder's conv op
        transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)

        # stages are built bottom-up: decoder stage s upsamples the features of encoder stage -s and merges
        # them with skip -(s + 1)
        stages = []
        transpconvs = []
        seg_layers = []
        for s in range(1, n_stages_encoder):
            input_features_below = encoder.output_channels[-s]
            input_features_skip = encoder.output_channels[-(s + 1)]
            stride_for_transpconv = encoder.strides[-s]
            transpconvs.append(transpconv_op(
                input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
                bias=encoder.conv_bias
            ))
            # the conv blocks see the concatenation of upsampled features and skip, hence 2 * input_features_skip
            stages.append(StackedConvBlocks(
                n_conv_per_stage[s - 1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
                encoder.kernel_sizes[-(s + 1)], 1, encoder.conv_bias, encoder.norm_op, encoder.norm_op_kwargs,
                encoder.dropout_op, encoder.dropout_op_kwargs, encoder.nonlin, encoder.nonlin_kwargs, nonlin_first
            ))

            # a 1x1 segmentation head is built for every stage so that parameters of models trained with
            # deep supervision can be loaded regardless of the deep_supervision setting used here
            seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))

        self.stages = nn.ModuleList(stages)
        self.transpconvs = nn.ModuleList(transpconvs)
        self.seg_layers = nn.ModuleList(seg_layers)

    def forward(self, skips):
        """
        :param skips: encoder feature maps, ordered from highest to lowest resolution (skips[-1] is the bottleneck)
        :return: list of decoder feature maps, lowest to highest resolution, with the full-resolution
                 segmentation logits appended as the last entry
        """
        lres_input = skips[-1]
        all_feature_maps = []

        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            # concatenate the upsampled features with the matching encoder skip along the channel dimension
            x = torch.cat((x, skips[-(s + 2)]), 1)
            x = self.stages[s](x)
            all_feature_maps.append(x)
            # only the final (full-resolution) segmentation head is applied here; the per-stage heads exist
            # so that deep-supervision checkpoints remain loadable
            if s == (len(self.stages) - 1):
                all_feature_maps.append(self.seg_layers[-1](x))
            lres_input = x
        return all_feature_maps

    def compute_conv_feature_map_size(self, input_size):
        """
        IMPORTANT: input_size is the input_size of the encoder (spatial dimensions only, without batch and
        channel dimensions)!
        :param input_size: spatial size of the encoder input
        :return: total number of elements across all decoder feature maps
        """
        # first compute the skip sizes: the spatial size of the encoder output at each resolution stage
        skip_sizes = []
        for s in range(len(self.encoder.strides) - 1):
            skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
            input_size = skip_sizes[-1]

        assert len(skip_sizes) == len(self.stages)

        # decoder stages run in the opposite order of the encoder, hence the reversed indexing below
        output = np.int64(0)
        for s in range(len(self.stages)):
            # conv blocks
            output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s + 1)])
            # transposed conv
            output += np.prod([self.encoder.output_channels[-(s + 2)], *skip_sizes[-(s + 1)]], dtype=np.int64)
            # segmentation head
            if self.deep_supervision or (s == (len(self.stages) - 1)):
                output += np.prod([self.num_classes, *skip_sizes[-(s + 1)]], dtype=np.int64)
        return output
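

if __name__ == "__main__":
    # Minimal usage sketch. The encoder configuration below is illustrative only (an assumption, not a
    # setup prescribed by this module): a small 2D PlainConvEncoder built with return_skips=True so its
    # output can be fed straight into the decoder's forward.
    encoder = PlainConvEncoder(
        input_channels=1, n_stages=4, features_per_stage=(32, 64, 128, 256), conv_op=nn.Conv2d,
        kernel_sizes=3, strides=(1, 2, 2, 2), n_conv_per_stage=2, norm_op=nn.InstanceNorm2d,
        norm_op_kwargs={'eps': 1e-5, 'affine': True}, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True},
        return_skips=True
    )
    decoder = UNetDecoder(encoder, num_classes=3, n_conv_per_stage=2, deep_supervision=False)
    with torch.no_grad():
        outputs = decoder(encoder(torch.rand(1, 1, 64, 64)))
    # one feature map per decoder stage, followed by the full-resolution segmentation logits
    print([o.shape for o in outputs])
    # the analytic feature map count takes the encoder's input size (spatial dims only)
    print(decoder.compute_conv_feature_map_size((64, 64)))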