| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.nn import LazyConv3d, MaxPool3d, BatchNorm3d |
|
|
| from torch.nn.modules import Module |
| from torch.nn.modules import ReLU |
| from torch.nn.modules.dropout import Dropout |
| from torch.nn.modules.instancenorm import InstanceNorm3d |
| from custom_modules import LazyConvDropoutNormNonlinCat, ModularConvLayers, LazyConvBottleneckLayer |
|
|
|
|
class modular_hdunet_encoder(Module):
    """HDUnet encoder with modular parameters.

    Every stage applies a ``ModularConvLayers`` block and then down-samples its output through two parallel
    branches: a pooling layer and a strided lazy convolution followed by the non-linearity and a 3D batch
    normalization. Both branches are concatenated along dimension 1 to build the input of the next stage.
    """

    def __init__(self, base_num_filter, num_blocks_per_stage, num_stages, pool_kernel_sizes, conv_kernel_sizes,
                 padding='same', conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
                 pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
        """Object creation

        :param base_num_filter: base number of filters (output channels).
        :param num_blocks_per_stage: number of convolutional blocks per stage. Either a single value (used for
            every stage) or a list/tuple with exactly ``num_stages`` entries.
        :param num_stages: number of stages.
        :param pool_kernel_sizes: one entry per stage; used as the kernel size of the pooling layer and as both
            the kernel size and the stride of the strided convolution of that stage.
            Please note that this parameter is retrieved in our modular decoder (``stage_pool_kernel_size``).
        :param conv_kernel_sizes: convolution kernel size, one entry per stage.
        :param padding: padding used, default is 'same'.
        :param conv_type: type of convolution handed to ``ModularConvLayers``, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
            Stage ``s`` requests ``round(expansion_rate ** s * base_num_filter)`` output channels.
        :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
        :param pooling_kernel_size: kernel size of the pooling layer, default is (2, 2, 2).
            NOTE: this value is only stored on the instance; the pooling layers are built from ``pool_kernel_sizes``.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet_encoder, self).__init__()

        assert len(pool_kernel_sizes) == len(conv_kernel_sizes) == num_stages, \
            "pool_kernel_sizes and conv_kernel_sizes must both contain num_stages entries"

        # A scalar number of blocks is broadcast to every stage.
        if not isinstance(num_blocks_per_stage, (list, tuple)):
            num_blocks_per_stage = [num_blocks_per_stage] * num_stages
        else:
            assert len(num_blocks_per_stage) == num_stages, \
                "num_blocks_per_stage must contain num_stages entries"

        self.base_num_filter = base_num_filter
        self.num_blocks_per_stage = num_blocks_per_stage
        self.num_stages = num_stages
        self.pool_kernel_sizes = pool_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes
        self.padding = padding
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.nonlin = nonlin
        self.expansion_rate = expansion_rate
        self.pooling_type = pooling_type
        self.pooling_kernel_size = pooling_kernel_size

        # Per-stage bookkeeping that the decoder reads back from this object.
        self.stage_output_features = []
        self.stage_pool_kernel_size = []
        self.stage_conv_kernel_size = []

        stages = []
        pooling_stages = []
        end_stages = []
        current_out_channels = 0

        for stage in range(num_stages):
            # np.round returns a NumPy float; channel counts must be plain ints for the torch layers below.
            current_out_channels = int(np.round((expansion_rate ** stage) * self.base_num_filter))
            current_pool_kernel_size = pool_kernel_sizes[stage]
            current_kernel_size = conv_kernel_sizes[stage]

            stages.append(
                ModularConvLayers(output_channels=current_out_channels,
                                  num_conv_layers=num_blocks_per_stage[stage],
                                  kernel_size=current_kernel_size, padding=padding, conv_type=conv_type,
                                  norm_type=norm_type, dropout_type=dropout_type, dropout_rate=dropout_rate,
                                  nonlin=self.nonlin))

            # First down-sampling branch: plain pooling.
            pooling_stages.append(pooling_type(kernel_size=current_pool_kernel_size))

            # Second down-sampling branch: strided convolution -> non-linearity -> batch normalization.
            end_stages.append(nn.Sequential(
                LazyConv3d(out_channels=current_out_channels, kernel_size=current_pool_kernel_size,
                           stride=current_pool_kernel_size, padding=0),
                nonlin(),
                BatchNorm3d(current_out_channels)
            ))

            self.stage_output_features.append(current_out_channels)
            self.stage_pool_kernel_size.append(current_pool_kernel_size)
            self.stage_conv_kernel_size.append(current_kernel_size)

        # ModuleList registers the sub-modules so that their parameters are tracked by torch.
        self.stages = nn.ModuleList(stages)
        self.pooling_stages = nn.ModuleList(pooling_stages)
        self.end_stages = nn.ModuleList(end_stages)
        # Channel count requested for the last stage (0 when there is no stage at all).
        self.output_features = current_out_channels

    def forward(self, x):
        """Forward inputs through the layer

        :param x: the input to forward.
        :return: a list containing the output of each stage (before down-sampling and concatenation), which will
            be used in the decoder later on. The last value of the list is the very last value produced by the
            encoder (after concatenation) and will be used in the bottleneck. Therefore, with n stages the list
            holds n + 1 values.
        """
        skips = []

        for i, stage in enumerate(self.stages):
            x = stage(x)
            skips.append(x)
            pooled = self.pooling_stages[i](x)
            strided = self.end_stages[i](x)
            # Both down-sampled branches are stacked along dimension 1.
            x = torch.cat([strided, pooled], dim=1)

        skips.append(x)
        return skips
|
|
|
|
class modular_hdunet_bottleneck(Module):
    """HDUnet bottleneck with modular parameters.

    The bottleneck is a chain of ``num_steps_bottleneck`` layers of type ``conv_type``. After each layer its
    output is concatenated (along dimension 1) with the input it received, and the result feeds the next layer.
    """

    def __init__(self, base_num_filter, num_stages, conv_kernel_sizes, padding='same', num_steps_bottleneck=4,
                 conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
        """Object creation

        :param base_num_filter: base number of filters (output channels).
        :param num_stages: number of stages of the encoder.
        :param conv_kernel_sizes: kernel sizes of the bottleneck, one entry per bottleneck step
            (must contain exactly ``num_steps_bottleneck`` entries).
        :param padding: padding used, default is 'same'.
        :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
            Every step requests ``round(expansion_rate ** num_stages * base_num_filter)`` output channels.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet_bottleneck, self).__init__()

        assert len(conv_kernel_sizes) == num_steps_bottleneck, \
            "conv_kernel_sizes must contain num_steps_bottleneck entries"

        self.base_num_filter = base_num_filter
        self.conv_kernel_sizes = conv_kernel_sizes
        self.padding = padding
        self.num_steps_bottleneck = num_steps_bottleneck
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.expansion_rate = expansion_rate
        self.nonlin = nonlin

        # Rounded and cast to a plain int (same convention as the encoder/decoder) so that a fractional
        # expansion rate never yields a float channel count.
        encoder_output_features = int(np.round(expansion_rate ** num_stages * base_num_filter))

        # Kernel size actually used by each step, in step order.
        self.step_conv_kernel_size = []
        steps = []

        for step in range(num_steps_bottleneck):
            current_kernel_size = conv_kernel_sizes[step]
            steps.append(
                conv_type(output_channels=encoder_output_features, kernel_size=current_kernel_size, padding=padding,
                          norm_type=norm_type, dropout_type=dropout_type,
                          dropout_rate=dropout_rate, nonlin=self.nonlin)
            )
            self.step_conv_kernel_size.append(current_kernel_size)

        # ModuleList registers the sub-modules so that their parameters are tracked by torch.
        self.stages = nn.ModuleList(steps)

    def forward(self, x):
        """Forward inputs through the layer

        :param x: the input to forward. At each step the input is concatenated with
            its result in order to produce the input of the next bottleneck layer.
        :return: the input forwarded through the layer.
        """
        for stage in self.stages:
            step_output = stage(x)
            # The step output comes first, followed by everything accumulated so far.
            x = torch.cat([step_output, x], dim=1)
        return x
|
|
|
|
class modular_hdunet_decoder(Module):
    """HDUnet decoder with modular parameters.

    The decoder mirrors an already-built encoder: it reads the per-stage kernel sizes from it (in reverse
    order) and builds one up-sampling ``ModularConvLayers`` block per encoder stage. The skip connections are
    not passed to ``forward``; they must be provided beforehand through ``set_skips``.
    """

    def __init__(self, previous, base_num_filter, num_blocks_per_stage=None, padding='same',
                 conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
        """Object creation

        :param previous: the encoder which was previously used in the model. Its ``stages``,
            ``stage_output_features``, ``stage_pool_kernel_size``, ``stage_conv_kernel_size`` and
            ``num_blocks_per_stage`` attributes are read here.
        :param base_num_filter: base number of filters (output channels).
        :param num_blocks_per_stage: number of convolutional blocks per stage. Either a single value (used for
            every stage) or a list/tuple with one entry per stage.
            If set to None, it will be the same as the encoder (reversed).
        :param padding: padding used, default is 'same'.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
            Decoder stage ``s`` requests ``round(expansion_rate ** (num_stages - s) * base_num_filter)`` channels.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet_decoder, self).__init__()
        self.base_num_filter = base_num_filter
        self.padding = padding
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.expansion_rate = expansion_rate
        self.nonlin = nonlin

        # Filled by set_skips() before every forward pass.
        self.skips = []

        # The decoder has exactly as many stages as the encoder.
        self.num_stages = len(previous.stages)

        # Default: same number of blocks as the encoder, walked in the opposite direction.
        if num_blocks_per_stage is None:
            num_blocks_per_stage = previous.num_blocks_per_stage[::-1]

        # A scalar number of blocks is broadcast to every stage.
        if not isinstance(num_blocks_per_stage, (list, tuple)):
            num_blocks_per_stage = [num_blocks_per_stage] * self.num_stages
        else:
            assert len(num_blocks_per_stage) == self.num_stages, \
                "num_blocks_per_stage must contain one entry per encoder stage"

        assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage), \
            "encoder and decoder must have the same number of stages"

        self.num_blocks_per_stage = num_blocks_per_stage

        # Copy of the encoder's per-stage channel counts, kept in encoder order (not reversed).
        # A copy is taken so that later changes to this list cannot alter the encoder's own list.
        self.stage_output_features = list(previous.stage_output_features)
        # Kernel sizes are reversed because the decoder walks the stages from the deepest to the shallowest.
        self.stage_pool_kernel_size = previous.stage_pool_kernel_size[::-1]
        self.stage_conv_kernel_size = previous.stage_conv_kernel_size[::-1]

        stages = []

        for stage in range(self.num_stages):
            # Equivalent to the exponent 2 * (num_stages + 1) - (stage + num_stages + 1) - 1.
            # np.round returns a NumPy float; the channel count is cast to a plain int.
            current_out_channels = int(np.round(
                (expansion_rate ** (self.num_stages - stage)) * self.base_num_filter))
            stages.append(
                ModularConvLayers(output_channels=current_out_channels,
                                  kernel_size=self.stage_conv_kernel_size[stage],
                                  padding=padding, pool_size=self.stage_pool_kernel_size[stage],
                                  conv_type=conv_type, norm_type=norm_type, dropout_type=dropout_type,
                                  dropout_rate=dropout_rate, num_conv_layers=self.num_blocks_per_stage[stage],
                                  nonlin=self.nonlin, upsampling=True))

        # ModuleList registers the sub-modules so that their parameters are tracked by torch.
        self.stages = nn.ModuleList(stages)

    def forward(self, x):
        """Forward inputs through the layer

        Stage ``i`` receives the current value together with ``self.skips[i + 1]``; ``set_skips`` must therefore
        have been called with at least ``num_stages + 1`` entries beforehand.

        :param x: the input to forward.
        :return: the input forwarded through the layer.
        """
        for i, stage in enumerate(self.stages):
            x = stage(x, self.skips[i + 1])
        return x

    def set_skips(self, skips):
        """Store the skip connections used by the next call to ``forward``

        :param skips: indexable collection of skip values; entry ``i + 1`` is handed to decoder stage ``i``.
        """
        self.skips = skips
|
|
|
|
| |
| |
| |
| |
| |
| |
class modular_hdunet(Module):
    """HDUnet model with modular parameters.

    The model chains an encoder, a bottleneck and a decoder, and ends with a lazy 3D convolution producing a
    single output channel (kernel (3, 3, 3), 'same' padding) followed by the non-linearity.
    """

    def __init__(self, base_num_filter, num_blocks_per_stage_encoder, num_stages,
                 pool_kernel_sizes, conv_kernel_sizes, conv_bottleneck_kernel_sizes, num_blocks_per_stage_decoder=None,
                 padding='same', num_steps_bottleneck=4, conv_type: Module = LazyConvDropoutNormNonlinCat,
                 bottleneck_conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
                 pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
        """Object creation

        :param base_num_filter: base number of filters (output channels).
        :param num_blocks_per_stage_encoder: number of convolutional blocks per stage for the encoder
            (can be different for each stage).
        :param num_stages: number of stages.
        :param pool_kernel_sizes: last convolutional layer of the encoder is strided => we use this parameter
            to set its kernel size and stride (can be different for each stage).
        :param conv_kernel_sizes: kernel size for the encoder and decoder (can be different for each stage).
        :param conv_bottleneck_kernel_sizes: kernel size for the bottleneck (can be different for each step).
        :param num_blocks_per_stage_decoder: number of convolutional blocks per stage for the decoder
            (can be different for each stage). Default is None (it will be the same as the encoder).
        :param padding: padding used, default is 'same'.
        :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
        :param conv_type: type of convolution used in the encoder and decoder, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param bottleneck_conv_type: type of convolution used in the bottleneck, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
        :param pooling_kernel_size: kernel size of the pooling layer, default is (2, 2, 2).
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet, self).__init__()
        self.nonlin = nonlin

        # Settings shared verbatim by the encoder, the bottleneck and the decoder.
        shared = dict(base_num_filter=base_num_filter, padding=padding, norm_type=norm_type,
                      dropout_type=dropout_type, dropout_rate=dropout_rate,
                      expansion_rate=expansion_rate, nonlin=self.nonlin)

        self.encoder = modular_hdunet_encoder(num_blocks_per_stage=num_blocks_per_stage_encoder,
                                              num_stages=num_stages, pool_kernel_sizes=pool_kernel_sizes,
                                              conv_kernel_sizes=conv_kernel_sizes, conv_type=conv_type,
                                              pooling_type=pooling_type, pooling_kernel_size=pooling_kernel_size,
                                              **shared)

        self.bottleNeck = modular_hdunet_bottleneck(num_stages=num_stages,
                                                    conv_kernel_sizes=conv_bottleneck_kernel_sizes,
                                                    num_steps_bottleneck=num_steps_bottleneck,
                                                    conv_type=bottleneck_conv_type, **shared)

        # The decoder derives its number of stages and kernel sizes from the encoder built above.
        self.decoder = modular_hdunet_decoder(previous=self.encoder,
                                              num_blocks_per_stage=num_blocks_per_stage_decoder,
                                              conv_type=conv_type, **shared)

        self.last_block = nn.Sequential(
            LazyConv3d(out_channels=1, kernel_size=(3, 3, 3), padding='same'),
            nonlin()
        )

    def forward(self, x):
        """Forward inputs through the layer
        (using the forward functions of the encoder/bottleneck/decoder)

        :param x: the input to forward.
        :return: the input forwarded through the layer.
        """
        encoder_outputs = self.encoder(x)
        bottleneck_output = self.bottleNeck(encoder_outputs[-1])

        # The last encoder value is replaced by the bottleneck output, then the list is reversed so that
        # the decoder starts from the deepest level: [bottleneck, deepest stage output, ..., first stage output].
        skips = [bottleneck_output] + encoder_outputs[:-1][::-1]

        self.decoder.set_skips(skips)
        try:
            decoded = self.decoder(skips[0])
        finally:
            # Drop the references so the skip tensors are not kept on the decoder between calls.
            self.decoder.set_skips([])

        return self.last_block(decoded)
|
|