| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import torch |
| from nnunet.network_architecture.custom_modules.conv_blocks import StackedConvLayers |
| from nnunet.network_architecture.generic_UNet import Upsample |
| from nnunet.network_architecture.neural_network import SegmentationNetwork |
| from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss |
| from torch import nn |
| import numpy as np |
| from torch.optim import SGD |
|
|
| """ |
| The idea behind this modular U-Net is that we decouple encoder and decoder and thus a) make them much easier to |
| combine and b) enable easy swapping between the segmentation and classification modes of the same architecture. |
| """ |
|
|
|
|
def get_default_network_config(dim=2, dropout_p=None, nonlin="LeakyReLU", norm_type="bn"):
    """
    Build the default network configuration: a dictionary holding the conv, dropout, norm and nonlinearity
    classes together with the keyword arguments to instantiate them with.

    :param dim: dimensionality of the ops, 2 or 3
    :param dropout_p: dropout probability. If None, 'dropout_op' is None and the stored probability is 0
    :param nonlin: "LeakyReLU" or "ReLU"
    :param norm_type: "bn" (batch norm) or "in" (instance norm)
    :return: dict with keys conv_op, dropout_op, norm_op, norm_op_kwargs, dropout_op_kwargs, conv_op_kwargs,
        nonlin, nonlin_kwargs
    :raises NotImplementedError: if dim is neither 2 nor 3, or norm_type is neither "bn" nor "in"
    :raises ValueError: if nonlin is neither "LeakyReLU" nor "ReLU"
    """
    # resolve all dimensionality-dependent ops in one place
    if dim == 2:
        conv_cls, dropout_cls = nn.Conv2d, nn.Dropout2d
        batch_norm_cls, instance_norm_cls = nn.BatchNorm2d, nn.InstanceNorm2d
    elif dim == 3:
        conv_cls, dropout_cls = nn.Conv3d, nn.Dropout3d
        batch_norm_cls, instance_norm_cls = nn.BatchNorm3d, nn.InstanceNorm3d
    else:
        raise NotImplementedError

    if norm_type == "bn":
        norm_cls = batch_norm_cls
    elif norm_type == "in":
        norm_cls = instance_norm_cls
    else:
        raise NotImplementedError

    if nonlin == "LeakyReLU":
        nonlin_cls = nn.LeakyReLU
        nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
    elif nonlin == "ReLU":
        nonlin_cls = nn.ReLU
        nonlin_kwargs = {'inplace': True}
    else:
        raise ValueError

    # no dropout requested: drop the op but still expose a kwargs dict (with p=0)
    dropout_disabled = dropout_p is None

    return {
        'conv_op': conv_cls,
        'dropout_op': None if dropout_disabled else dropout_cls,
        'norm_op': norm_cls,
        'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
        'dropout_op_kwargs': {'p': 0 if dropout_disabled else dropout_p, 'inplace': True},
        'conv_op_kwargs': {'stride': 1, 'dilation': 1, 'bias': True},
        'nonlin': nonlin_cls,
        'nonlin_kwargs': nonlin_kwargs,
    }
|
|
|
|
|
|
class PlainConvUNetEncoder(nn.Module):
    """
    Encoder of the modular plain convolutional U-Net: a sequence of StackedConvLayers stages.
    The last stage is the bottleneck.

    The per-stage metadata (stage_output_features, stage_pool_kernel_size, stage_conv_op_kernel_size,
    num_blocks_per_stage) and props are exposed as attributes so that a decoder can be built from an
    encoder instance.
    """

    def __init__(self, input_channels, base_num_features, num_blocks_per_stage, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, default_return_skips=True,
                 max_num_features=480):
        """
        This one includes the bottleneck layer!

        :param input_channels: number of channels of the network input
        :param base_num_features: number of output features of the first stage
        :param num_blocks_per_stage: number of blocks per stage. Either a list/tuple with one entry per stage or
            a single value that is used for all stages
        :param feat_map_mul_on_downscale: factor by which the number of features grows from stage to stage
        :param pool_op_kernel_sizes: one entry per stage; must have the same length as conv_kernel_sizes
        :param conv_kernel_sizes: one conv kernel size per stage; its length defines the number of stages
        :param props: network configuration that is handed to StackedConvLayers
        :param default_return_skips: value used for return_skips in forward when it is not given
        :param max_num_features: upper bound for the number of features of a stage
        """
        super(PlainConvUNetEncoder, self).__init__()

        self.default_return_skips = default_return_skips
        self.props = props

        assert len(pool_op_kernel_sizes) == len(conv_kernel_sizes)
        num_stages = len(conv_kernel_sizes)

        if not isinstance(num_blocks_per_stage, (list, tuple)):
            num_blocks_per_stage = [num_blocks_per_stage] * num_stages
        else:
            assert len(num_blocks_per_stage) == num_stages
        self.num_blocks_per_stage = num_blocks_per_stage

        stages = []
        self.stage_output_features = []
        self.stage_pool_kernel_size = []
        self.stage_conv_op_kernel_size = []

        current_input_features = input_channels
        for stage in range(num_stages):
            # features grow geometrically with the stage index and are capped at max_num_features
            current_output_features = min(int(base_num_features * feat_map_mul_on_downscale ** stage),
                                          max_num_features)
            current_kernel_size = conv_kernel_sizes[stage]
            current_pool_kernel_size = pool_op_kernel_sizes[stage]

            # NOTE(review): the pool kernel size is passed as the last positional argument of StackedConvLayers,
            # presumably as the stride of the first conv of the stage - confirm against StackedConvLayers
            stages.append(StackedConvLayers(current_input_features, current_output_features,
                                            current_kernel_size, props, num_blocks_per_stage[stage],
                                            current_pool_kernel_size))

            self.stage_output_features.append(current_output_features)
            self.stage_conv_op_kernel_size.append(current_kernel_size)
            self.stage_pool_kernel_size.append(current_pool_kernel_size)

            current_input_features = current_output_features

        self.stages = nn.ModuleList(stages)
        self.output_features = current_output_features

    def forward(self, x, return_skips=None):
        """
        Run x through all stages.

        :param x: network input
        :param return_skips: if None then self.default_return_skips is used
        :return: list with the output of every stage (the last entry is the bottleneck output) if return_skips
            is truthy, otherwise only the output of the last stage
        """
        # resolve the flag before the loop so that an explicit return_skips=True also works on an encoder that
        # was built with default_return_skips=False (the outputs used to be collected based on the default only,
        # which produced an empty list in that case)
        if return_skips is None:
            return_skips = self.default_return_skips

        skips = []
        for s in self.stages:
            x = s(x)
            if return_skips:
                skips.append(x)

        if return_skips:
            return skips
        return x

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, pool_op_kernel_sizes, num_blocks_per_stage_encoder,
                                        feat_map_mul_on_downscale, batch_size):
        """
        Sum up (number of blocks * number of voxels * number of features) over all stages, add the size of the
        input (num_modalities * number of voxels) and multiply the result by batch_size.

        :param patch_size: spatial size of the network input
        :param base_num_features: number of features of the first stage
        :param max_num_features: upper bound for the number of features of a stage
        :param num_modalities: number of input channels
        :param pool_op_kernel_sizes: one entry per stage. The first entry is not applied to the shape
        :param num_blocks_per_stage_encoder: sequence with the number of blocks of each stage
        :param feat_map_mul_on_downscale: factor by which the number of features grows from stage to stage
        :param batch_size: batch size
        :return: the accumulated value (a number, not bytes)
        """
        npool = len(pool_op_kernel_sizes) - 1

        current_shape = np.array(patch_size)

        # first stage (full resolution) plus the input itself
        tmp = num_blocks_per_stage_encoder[0] * np.prod(current_shape) * base_num_features \
              + num_modalities * np.prod(current_shape)

        num_feat = base_num_features

        for p in range(1, npool + 1):
            current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
            num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
            num_convs = num_blocks_per_stage_encoder[p]
            print(p, num_feat, num_convs, current_shape)
            tmp += num_convs * np.prod(current_shape) * num_feat
        return tmp * batch_size
|
|
|
|
class PlainConvUNetDecoder(nn.Module):
    """
    Decoder of the modular plain convolutional U-Net. It is built from an encoder instance (previous) and
    mirrors its stages: transposed convolution -> concatenation with the skip -> StackedConvLayers.

    We assume the bottleneck is part of the encoder, so decoding starts with upsample -> concat.
    """

    def __init__(self, previous, num_classes, num_blocks_per_stage=None, network_props=None, deep_supervision=False,
                 upscale_logits=False):
        """
        :param previous: encoder to build upon. Its stages, stage_output_features, stage_pool_kernel_size,
            stage_conv_op_kernel_size, num_blocks_per_stage and props attributes are read
        :param num_classes: number of output channels of the segmentation layers
        :param num_blocks_per_stage: number of blocks per decoder stage, ordered from the stage next to the
            bottleneck to the full resolution stage. If None, the encoder values (without the bottleneck) are
            used in reversed order
        :param network_props: network configuration. If None, previous.props is used
        :param deep_supervision: if True, an additional 1x1 segmentation layer is created for every decoder
            stage except the full resolution one
        :param upscale_logits: if True, the deep supervision outputs are followed by an Upsample layer
        :raises ValueError: if props['conv_op'] is neither nn.Conv2d nor nn.Conv3d
        """
        super(PlainConvUNetDecoder, self).__init__()
        self.num_classes = num_classes
        self.deep_supervision = deep_supervision

        previous_stages = previous.stages
        previous_stage_output_features = previous.stage_output_features
        previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
        previous_stage_conv_op_kernel_size = previous.stage_conv_op_kernel_size

        if network_props is None:
            self.props = previous.props
        else:
            self.props = network_props

        if self.props['conv_op'] == nn.Conv2d:
            transpconv = nn.ConvTranspose2d
            upsample_mode = "bilinear"
        elif self.props['conv_op'] == nn.Conv3d:
            transpconv = nn.ConvTranspose3d
            upsample_mode = "trilinear"
        else:
            raise ValueError("unknown convolution dimensionality, conv op: %s" % str(self.props['conv_op']))

        if num_blocks_per_stage is None:
            # mirror the encoder: drop the bottleneck, then go from deep to shallow
            num_blocks_per_stage = previous.num_blocks_per_stage[:-1][::-1]

        assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage) - 1

        self.stage_pool_kernel_size = previous_stage_pool_kernel_size
        self.stage_output_features = previous_stage_output_features
        self.stage_conv_op_kernel_size = previous_stage_conv_op_kernel_size

        # the bottleneck has no decoder counterpart
        num_stages = len(previous_stages) - 1

        tus = []
        stages = []
        deep_supervision_outputs = []

        # cumulative product of the pooling kernel sizes over the stages, one row per encoder stage
        cum_upsample = np.cumprod(np.vstack(self.stage_pool_kernel_size), axis=0).astype(int)

        # i counts the decoder stages (deep -> shallow), s is the index of the encoder stage providing the skip
        for i, s in enumerate(np.arange(num_stages)[::-1]):
            features_below = previous_stage_output_features[s + 1]
            features_skip = previous_stage_output_features[s]

            # kernel size == stride == pooling kernel size of the encoder stage below
            tus.append(transpconv(features_below, features_skip, previous_stage_pool_kernel_size[s + 1],
                                  previous_stage_pool_kernel_size[s + 1], bias=False))

            # input has 2 * features_skip channels: upsampled features concatenated with the skip
            stages.append(StackedConvLayers(2 * features_skip, features_skip,
                                            previous_stage_conv_op_kernel_size[s], self.props,
                                            num_blocks_per_stage[i]))

            if deep_supervision and s != 0:
                # positional arguments: kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=False
                seg_layer = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)
                if upscale_logits:
                    # NOTE(review): scale_factor is handed over as a numpy int array - confirm that Upsample
                    # accepts this
                    upsample = Upsample(scale_factor=cum_upsample[s], mode=upsample_mode)
                    deep_supervision_outputs.append(nn.Sequential(seg_layer, upsample))
                else:
                    deep_supervision_outputs.append(seg_layer)

        # features_skip is the feature count of the full resolution stage after the loop
        self.segmentation_output = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)

        self.tus = nn.ModuleList(tus)
        self.stages = nn.ModuleList(stages)
        self.deep_supervision_outputs = nn.ModuleList(deep_supervision_outputs)

    def forward(self, skips, gt=None, loss=None):
        """
        Decode the encoder outputs.

        :param skips: outputs of all encoder stages in encoder order (the bottleneck output comes last)
        :param gt: only used with deep supervision. If given, every output is replaced by loss(output, gt)
        :param loss: callable loss(output, gt). Required if gt is given and deep supervision is enabled
        :return: without deep supervision the output of the final segmentation layer. With deep supervision a
            list of all outputs (or of their losses if gt is given) with the final output first, followed by
            the auxiliary outputs from shallow to deep
        :raises ValueError: if deep supervision is enabled and gt is given without loss
        """
        if self.deep_supervision and gt is not None and loss is None:
            # would otherwise fail below with "'NoneType' object is not callable"
            raise ValueError("gt was given but loss is None. A loss callable is required to compute the "
                             "deep supervision losses")

        # bottleneck first
        skips = skips[::-1]
        seg_outputs = []

        x = skips[0]

        for i in range(len(self.tus)):
            x = self.tus[i](x)
            x = torch.cat((x, skips[i + 1]), dim=1)
            x = self.stages[i](x)
            # the last (full resolution) stage is handled by self.segmentation_output below
            if self.deep_supervision and (i != len(self.tus) - 1):
                tmp = self.deep_supervision_outputs[i](x)
                if gt is not None:
                    tmp = loss(tmp, gt)
                seg_outputs.append(tmp)

        segmentation = self.segmentation_output(x)

        if self.deep_supervision:
            tmp = segmentation
            if gt is not None:
                tmp = loss(tmp, gt)
            seg_outputs.append(tmp)
            # outputs were collected from deep to shallow, return them with the final output first
            return seg_outputs[::-1]
        return segmentation

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_classes, pool_op_kernel_sizes, num_blocks_per_stage_decoder,
                                        feat_map_mul_on_downscale, batch_size):
        """
        This only applies for num_blocks_per_stage and convolutional_upsampling=True
        not real vram consumption. just a constant term to which the vram consumption will be approx proportional
        (+ offset for parameter storage)

        :param patch_size: spatial size of the network input
        :param base_num_features: number of features of the full resolution stage
        :param max_num_features: upper bound for the number of features of a stage
        :param num_classes: number of output channels, counted once at full resolution
        :param pool_op_kernel_sizes: one entry per encoder stage. The first entry is not applied to the shape
        :param num_blocks_per_stage_decoder: number of blocks per decoder stage. The last entry belongs to the
            full resolution stage
        :param feat_map_mul_on_downscale: factor by which the number of features grows from stage to stage
        :param batch_size: batch size
        :return: the accumulated value (a number, not bytes)
        """
        npool = len(pool_op_kernel_sizes) - 1

        current_shape = np.array(patch_size)
        # every stage is counted with one extra block (presumably the transposed convolution - TODO confirm)
        tmp = (num_blocks_per_stage_decoder[-1] + 1) * np.prod(current_shape) * base_num_features \
              + num_classes * np.prod(current_shape)

        num_feat = base_num_features

        # the bottleneck (p == npool) is not part of the decoder
        for p in range(1, npool):
            current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
            num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
            num_convs = num_blocks_per_stage_decoder[-(p + 1)] + 1
            print(p, num_feat, num_convs, current_shape)
            tmp += num_convs * np.prod(current_shape) * num_feat

        return tmp * batch_size
|
|
|
|
class PlainConvUNet(SegmentationNetwork):
    """
    Plain convolutional U-Net assembled from a PlainConvUNetEncoder and a PlainConvUNetDecoder.
    """
    # NOTE(review): presumably reference values for the vram estimate of the 2D/3D reference configurations -
    # confirm where these numbers come from
    use_this_for_batch_size_computation_2D = 1167982592.0
    use_this_for_batch_size_computation_3D = 1152286720.0

    def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
                 deep_supervision=False, upscale_logits=False, max_features=512, initializer=None):
        """
        :param input_channels: number of input channels
        :param base_num_features: number of features of the first encoder stage
        :param num_blocks_per_stage_encoder: blocks per encoder stage
        :param feat_map_mul_on_downscale: factor by which the number of features grows from stage to stage
        :param pool_op_kernel_sizes: one entry per encoder stage
        :param conv_kernel_sizes: one conv kernel size per encoder stage
        :param props: network configuration shared by encoder and decoder
        :param num_classes: number of output channels
        :param num_blocks_per_stage_decoder: blocks per decoder stage
        :param deep_supervision: forwarded to the decoder
        :param upscale_logits: forwarded to the decoder
        :param max_features: upper bound for the number of features of an encoder stage
        :param initializer: if not None, it is applied to the network via self.apply
        """
        super(PlainConvUNet, self).__init__()
        self.conv_op = props['conv_op']
        self.num_classes = num_classes

        self.encoder = PlainConvUNetEncoder(
            input_channels,
            base_num_features,
            num_blocks_per_stage_encoder,
            feat_map_mul_on_downscale,
            pool_op_kernel_sizes,
            conv_kernel_sizes,
            props,
            default_return_skips=True,
            max_num_features=max_features,
        )
        self.decoder = PlainConvUNetDecoder(
            self.encoder,
            num_classes,
            num_blocks_per_stage_decoder,
            props,
            deep_supervision,
            upscale_logits,
        )

        if initializer is None:
            return
        self.apply(initializer)

    def forward(self, x):
        """
        Encode x and decode the resulting skips.

        :param x: network input
        :return: whatever the decoder returns for the encoder outputs
        """
        return self.decoder(self.encoder(x))

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, num_classes, pool_op_kernel_sizes, num_blocks_per_stage_encoder,
                                        num_blocks_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
        """
        Sum of the encoder and the decoder estimates.

        :return: PlainConvUNetEncoder.compute_approx_vram_consumption(...)
            + PlainConvUNetDecoder.compute_approx_vram_consumption(...)
        """
        encoder_estimate = PlainConvUNetEncoder.compute_approx_vram_consumption(
            patch_size, base_num_features, max_num_features, num_modalities, pool_op_kernel_sizes,
            num_blocks_per_stage_encoder, feat_map_mul_on_downscale, batch_size)
        decoder_estimate = PlainConvUNetDecoder.compute_approx_vram_consumption(
            patch_size, base_num_features, max_num_features, num_classes, pool_op_kernel_sizes,
            num_blocks_per_stage_decoder, feat_map_mul_on_downscale, batch_size)
        return encoder_estimate + decoder_estimate

    @staticmethod
    def compute_reference_for_vram_consumption_3d():
        """
        Estimate for the 3D reference configuration: patch size (160, 128, 128), 6 encoder stages with 2 blocks
        each, 32 base features, at most 512 features, 4 modalities, 3 classes, batch size 2.
        """
        patch_size = (160, 128, 128)
        # first stage without pooling, followed by five poolings by 2
        pool_op_kernel_sizes = ((1, 1, 1),) + ((2, 2, 2),) * 5
        conv_per_stage_encoder = (2,) * 6
        conv_per_stage_decoder = (2,) * 5

        return PlainConvUNet.compute_approx_vram_consumption(patch_size, 32, 512, 4, 3, pool_op_kernel_sizes,
                                                             conv_per_stage_encoder, conv_per_stage_decoder, 2, 2)

    @staticmethod
    def compute_reference_for_vram_consumption_2d():
        """
        Estimate for the 2D reference configuration: patch size (256, 256), 7 encoder stages with 2 blocks each,
        32 base features, at most 512 features, 4 modalities, 3 classes, batch size 56.
        """
        patch_size = (256, 256)
        # first stage without pooling, followed by six poolings by 2
        pool_op_kernel_sizes = ((1, 1),) + ((2, 2),) * 6
        conv_per_stage_encoder = (2,) * 7
        conv_per_stage_decoder = (2,) * 6

        return PlainConvUNet.compute_approx_vram_consumption(patch_size, 32, 512, 4, 3, pool_op_kernel_sizes,
                                                             conv_per_stage_encoder, conv_per_stage_decoder, 2, 56)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: builds a 2D PlainConvUNet, runs one forward/backward/optimizer step on random data and
    # renders the network graph. Needs a CUDA device and the hiddenlayer package.
    conv_op_kernel_sizes = ((3, 3),) * 7
    pool_op_kernel_sizes = ((1, 1),) + ((2, 2),) * 6
    patch_size = (256, 256)
    batch_size = 56

    network_props = get_default_network_config(2, dropout_p=None)
    unet = PlainConvUNet(4, 32, (2,) * 7, 2, pool_op_kernel_sizes, conv_op_kernel_sizes,
                         network_props, 4, (2,) * 6, False, False, max_features=512).cuda()
    optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)

    # return values are not used here, the calls print the per-stage numbers
    unet.compute_reference_for_vram_consumption_3d()
    unet.compute_reference_for_vram_consumption_2d()

    dummy_input = torch.rand((batch_size, 4, *patch_size)).cuda()
    # random integer labels in [0, 3]
    dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * 4).round().clamp_(0, 3).cuda().long()

    optimizer.zero_grad()
    skips = unet.encoder(dummy_input)
    print([skip.shape for skip in skips])
    loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'smooth_in_nom': True,
                           'do_bg': False, 'rebalance_weights': None, 'background_weight': 1}, {})
    output = unet.decoder(skips)

    l = loss(output, dummy_gt)
    l.backward()

    optimizer.step()

    import hiddenlayer as hl
    g = hl.build_graph(unet, dummy_input)
    g.save("/home/fabian/test.pdf")

    # Disabled 3D variant of the smoke test above, kept for reference:
    #
    # conv_op_kernel_sizes = ((3, 3, 3),
    #                         (3, 3, 3),
    #                         (3, 3, 3),
    #                         (3, 3, 3),
    #                         (3, 3, 3),
    #                         (3, 3, 3))
    # pool_op_kernel_sizes = ((1, 1, 1),
    #                         (2, 2, 2),
    #                         (2, 2, 2),
    #                         (2, 2, 2),
    #                         (2, 2, 2),
    #                         (2, 2, 2))
    # patch_size = (160, 128, 128)
    # unet = PlainConvUNet(4, 32, (2, 2, 2, 2, 2, 2), 2, pool_op_kernel_sizes, conv_op_kernel_sizes,
    #                      get_default_network_config(3, dropout_p=None), 4, (2, 2, 2, 2, 2), False, False, max_features=512).cuda()
    # optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)
    #
    # unet.compute_reference_for_vram_consumption_3d()
    # unet.compute_reference_for_vram_consumption_2d()
    #
    # dummy_input = torch.rand((2, 4, *patch_size)).cuda()
    # dummy_gt = (torch.rand((2, 1, *patch_size)) * 4).round().clamp_(0, 3).cuda().long()
    #
    # optimizer.zero_grad()
    # skips = unet.encoder(dummy_input)
    # print([i.shape for i in skips])
    # loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'smooth_in_nom': True,
    #                        'do_bg': False, 'rebalance_weights': None, 'background_weight': 1}, {})
    # output = unet.decoder(skips)
    #
    # l = loss(output, dummy_gt)
    # l.backward()
    #
    # optimizer.step()
    #
    # import hiddenlayer as hl
    # g = hl.build_graph(unet, dummy_input)
    # g.save("/home/fabian/test.pdf")
|
|