import numpy as np
import torch
from torch import nn
from typing import Union, List, Tuple
from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks
from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp
from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder
from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
class UNetDecoder(nn.Module):
def __init__(self,
encoder: Union[PlainConvEncoder, ResidualEncoder],
num_classes: int,
n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision: bool, nonlin_first: bool = False):
"""
This class needs the skips of the encoder as input in its forward.
the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder
are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
features and the lowest skip as inputs
the decoder has two (three) parts in each stage:
1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage)
2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge
3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits?
:param encoder:
:param num_classes:
:param n_conv_per_stage:
:param deep_supervision:
"""
super().__init__()
self.deep_supervision = deep_supervision
self.encoder = encoder
self.num_classes = num_classes
n_stages_encoder = len(encoder.output_channels)
if isinstance(n_conv_per_stage, int):
n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
"resolution stages - 1 (n_stages in encoder - 1), " \
"here: %d" % n_stages_encoder
transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)
        # we start with the bottleneck and work our way up
stages = []
transpconvs = []
seg_layers = []
for s in range(1, n_stages_encoder):
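            # s counts from the bottom of the network: -s indexes the stage below (the bottleneck for s=1),
            # -(s + 1) the matching skip one resolution level above it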
input_features_below = encoder.output_channels[-s]
input_features_skip = encoder.output_channels[-(s + 1)]
stride_for_transpconv = encoder.strides[-s]
transpconvs.append(transpconv_op(
input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
bias=encoder.conv_bias
))
# input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output)
stages.append(StackedConvBlocks(
n_conv_per_stage[s-1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
encoder.kernel_sizes[-(s + 1)], 1, encoder.conv_bias, encoder.norm_op, encoder.norm_op_kwargs,
encoder.dropout_op, encoder.dropout_op_kwargs, encoder.nonlin, encoder.nonlin_kwargs, nonlin_first
))
# we always build the deep supervision outputs so that we can always load parameters. If we don't do this
# then a model trained with deep_supervision=True could not easily be loaded at inference time where
# deep supervision is not needed. It's just a convenience thing
seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))
self.stages = nn.ModuleList(stages)
self.transpconvs = nn.ModuleList(transpconvs)
self.seg_layers = nn.ModuleList(seg_layers)
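    # The commented-out forward below is the previous implementation: it returns segmentation logits (a list of
    # them when deep_supervision=True). It is kept for reference; the active forward further down instead returns
    # the feature maps of every decoder stage with the final segmentation logits appended.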
# def forward(self, skips):
# lres_input = skips[-1]
# seg_outputs = []
# for s in range(len(self.stages)):
# x = self.transpconvs[s](lres_input)
# x = torch.cat((x, skips[-(s+2)]), 1)
# x = self.stages[s](x)
# if self.deep_supervision:
# seg_outputs.append(self.seg_layers[s](x))
# elif s == (len(self.stages) - 1):
# seg_outputs.append(self.seg_layers[-1](x))
# lres_input = x
# # invert seg outputs so that the largest segmentation prediction is returned first
# seg_outputs = seg_outputs[::-1]
# if not self.deep_supervision:
# r = seg_outputs[0]
# else:
# r = seg_outputs
# return r
def forward(self, skips):
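        """
        skips: list of encoder feature maps, ordered from highest resolution down to the bottleneck (skips[-1]).
        Returns the output feature map of every decoder stage, ordered from lowest to highest resolution, with
        the final full-resolution segmentation logits appended as the last element.
        """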
lres_input = skips[-1]
all_feature_maps = []
for s in range(len(self.stages)):
x = self.transpconvs[s](lres_input)
x = torch.cat((x, skips[-(s+2)]), 1)
x = self.stages[s](x)
all_feature_maps.append(x)
if s == (len(self.stages) - 1):
seg_output = self.seg_layers[-1](x)
all_feature_maps.append(seg_output)
lres_input = x
return all_feature_maps
def compute_conv_feature_map_size(self, input_size):
"""
IMPORTANT: input_size is the input_size of the encoder!
:param input_size:
:return:
"""
        # First compute the skip sizes. The bottleneck is excluded because every feature map produced by the
        # decoder ops is at least as large as the skip one level above it (hence the -1).
skip_sizes = []
for s in range(len(self.encoder.strides) - 1):
skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
input_size = skip_sizes[-1]
# print(skip_sizes)
assert len(skip_sizes) == len(self.stages)
# our ops are the other way around, so let's match things up
output = np.int64(0)
for s in range(len(self.stages)):
# print(skip_sizes[-(s+1)], self.encoder.output_channels[-(s+2)])
# conv blocks
output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)])
# trans conv
output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64)
# segmentation
if self.deep_supervision or (s == (len(self.stages) - 1)):
output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64)
return output
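# Minimal usage sketch (not part of the original module). It assumes the PlainConvEncoder constructor from
# dynamic_network_architectures accepts the keyword arguments shown below and that return_skips=True makes the
# encoder return its per-stage feature maps; adjust to the installed version if the signature differs.
if __name__ == "__main__":
    encoder = PlainConvEncoder(
        input_channels=1, n_stages=4, features_per_stage=(32, 64, 128, 256),
        conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2), n_conv_per_stage=2,
        conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={'eps': 1e-5, 'affine': True},
        dropout_op=None, dropout_op_kwargs=None, nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True},
        return_skips=True
    )
    decoder = UNetDecoder(encoder, num_classes=3, n_conv_per_stage=2, deep_supervision=False)
    skips = encoder(torch.rand(2, 1, 64, 64))
    outputs = decoder(skips)
    # one feature map per decoder stage plus the final segmentation logits as the last element
    print([o.shape for o in outputs])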