Spaces:
Sleeping
Sleeping
# Third-party imports (NOTE(review): these are external packages, not the
# Python standard library, despite the original comment saying so).
import tensorflow as tf
import numpy as np
# Import custom modules from this package
from . import layer_util
# NOTE(review): module-level side effect — merely importing this module fixes
# TensorFlow's *global* random seed for the whole process. Confirm this is
# intentional; it affects every other TF user in the same process.
tf.random.set_seed(489154)
class unet3plus:
    """
    Builder for a U-Net3+ model.

    Construct an instance with the model hyper-parameters, then call
    ``outputs()`` to build the Keras layer graph and obtain the output
    tensor(s).
    """
    def __init__(self,
                 inputs,
                 filters = (32, 64, 128, 256, 512),
                 rank = 2,
                 out_channels = 1,
                 kernel_initializer=tf.keras.initializers.HeNormal(seed=0),
                 bias_initializer=tf.keras.initializers.Zeros(),
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 add_dropout = False,
                 padding = 'same',
                 dropout_rate = 0.5,
                 kernel_size = 3,
                 out_kernel_size = 3,
                 pool_size = 2,
                 encoder_block_depth = 2,
                 decoder_block_depth = 1,
                 batch_norm = True,
                 activation = 'relu',
                 out_activation = None,
                 skip_batch_norm = True,
                 skip_type = 'encoder',
                 CGM = False,
                 deep_supervision = True):
        """
        Initialise the U-Net3+ model.

        Args:
            inputs: Input tensor.
            filters: Sequence of filter sizes, one per UNet level. (Changed
                from a mutable list default to a tuple — it is only indexed
                and measured with len(), so this is backward compatible.)
            rank: Number of spatial dimensions (2 for 2D, 3 for 3D).
            out_channels: Number of output channels (for segmentation this is
                the number of distinct masks).
            kernel_initializer: Initialiser for the convolutional layers.
                NOTE(review): the default instance is created once at class
                definition time and shared by every model built with the
                default — confirm this sharing is intended.
            bias_initializer: Initialiser for the bias terms.
            kernel_regularizer: Regulariser for the convolutional layers.
            bias_regularizer: Regulariser for the bias terms.
            add_dropout: Whether to add a dropout layer at the bottleneck.
            padding: Padding type for the convolutional layers.
            dropout_rate: Dropout rate.
            kernel_size: Kernel size for the convolutional layers (int or
                per-dimension tuple).
            out_kernel_size: Kernel size for the final output convolutions
                (int or per-dimension tuple).
            pool_size: Max-pool size (int or per-dimension tuple).
            encoder_block_depth: Conv units per encoder-level conv block.
            decoder_block_depth: Conv units per decoder-level conv block.
            batch_norm: Whether to use batch normalisation in conv blocks.
            activation: Activation function for the convolutional layers.
            out_activation: Activation for the output head ('sigmoid' or
                'softmax' for segmentation; None to emit raw logits).
            skip_batch_norm: Whether skip connections use a full conv block
                (with batch norm) instead of a bare convolution.
            skip_type: Skip-connection scheme: 'encoder', 'decoder' or
                'standard_unet'.
            CGM: Classification Guided Module flag. NOTE(review): stored but
                never read anywhere in this file — the CGM head is not built.
            deep_supervision: Whether to emit a prediction at every decoder
                level instead of only the final one.
        """
        # Assign parameters
        self.inputs = inputs
        self.filters = filters
        self.levels = len(filters)
        self.rank = rank
        self.out_channels = out_channels
        self.encoder_block_depth = encoder_block_depth
        self.decoder_block_depth = decoder_block_depth
        self.add_dropout = add_dropout
        self.dropout_rate = dropout_rate
        self.skip_type = skip_type
        self.skip_batch_norm = skip_batch_norm
        self.batch_norm = batch_norm
        self.activation = activation
        self.out_activation = out_activation
        self.CGM = CGM
        self.deep_supervision = deep_supervision

        # Broadcast a scalar size to a per-dimension tuple of `rank` entries;
        # tuples are passed through unchanged. (Replaces three copy-pasted
        # if/else branches; also removes the redundant early assignment of
        # self.kernel_size that was immediately overwritten.)
        def _per_dim(value):
            return value if isinstance(value, tuple) else (value,) * rank

        self.pool_size = _per_dim(pool_size)
        self.kernel_size = _per_dim(kernel_size)
        self.out_kernel_size = _per_dim(out_kernel_size)

        # Keyword arguments shared by every convolution in the network.
        common_conv_kwargs = dict(padding = padding,
                                  kernel_initializer = kernel_initializer,
                                  bias_initializer = bias_initializer,
                                  kernel_regularizer = kernel_regularizer,
                                  bias_regularizer = bias_regularizer)
        self.conv_config = dict(kernel_size = self.kernel_size, **common_conv_kwargs)
        # Fix: use the normalised tuple here (the original used the raw
        # `out_kernel_size` argument). Behaviour is unchanged — Keras
        # broadcasts an int kernel_size to all dimensions — but now both
        # config dicts are built consistently.
        self.out_conv_config = dict(kernel_size = self.out_kernel_size, **common_conv_kwargs)
| def aggregate_and_decode(self, input_list, level): | |
| """ | |
| Aggregates the inputs for the decoder levels and applies convolution to get the output of the decoder level. | |
| Args: | |
| input_list: List of inputs to the decoder to be aggregated. | |
| level: Current decoder level. | |
| """ | |
| X = layer_util.ResizeAndConcatenate(name = f'D{level}_input', axis = -1)(input_list) # Takes the various inputs to a decoder level, resizes them to the 1st input size in the list and the concatenates them all. | |
| X = self.conv_block(X, self.filters[level-1], block_depth = self.decoder_block_depth, conv_block_purpose = 'Decoder', level=level) # Performs a decoder block convolution of the concatenated input (i.e. the concatenated list of filters) | |
| return X | |
| def deep_sup(self, inputs, level): | |
| """ | |
| If deep supervision is used, then the network will output a prediction at each level of the decoder. | |
| This function upsamples the output of a decoder level, convolves it and then applies the output activation function (i.e. to get to the final output). | |
| If deep supervision is not used, then the network will only output a prediction at the final level of the decoder. | |
| Args: | |
| inputs: Input tensor. | |
| level: Current decoder level. | |
| """ | |
| conv = layer_util.get_nd_layer('Conv', self.rank) # gets a convolutional layer of the specified rank (2D or 3D) | |
| upsamp = layer_util.get_nd_layer('UpSampling', self.rank) # gets an upsampling layer of the specified rank (2D or 3D) | |
| size = tuple(np.array(self.pool_size)** (abs(level-1))) # This specifies the amount of upsampling needed to get to the correct final output size. It is the maxpool size to the power of the current decoder level minus one. | |
| if self.rank == 2: | |
| upsamp_config = dict(size=size, interpolation='bilinear') # use bilinear interpolation for 2D upsampling | |
| else: | |
| upsamp_config = dict(size=size) # for 3D upsampling, you cannot do bilinear interpolation, so this just uses the default upsampling method. | |
| X = inputs | |
| X = conv(self.out_channels, activation = None, **self.out_conv_config, name = f'deepsup_conv_{level}_1')(X) # Convolves the input to get the correct number of output channels | |
| if level != 1: | |
| X = upsamp(**upsamp_config, name = f'deepsup_upsamp_{level}')(X) # Upsamples the convolved input to the correct size for the final output | |
| X = conv(self.out_channels, activation = None, **self.out_conv_config, name = f'deepsup_conv_{level}_2')(X) # Convolves the upsampled input to get the correct number of output channels (e.g. to correct artifacts due to upsampling) | |
| if self.out_activation: | |
| X = tf.keras.layers.Activation(activation = self.out_activation, name = f'deepsup_activation_{level}')(X) # Applies the output activation function to get the final output | |
| return X | |
| def skip_connection(self, inputs, to_level, from_level): | |
| """ | |
| This function takes an input tensor and processes it as a skip connection to the decoder level. | |
| Args: | |
| inputs: Input tensor. | |
| to_level: Current decoder level. | |
| from_level: Level of UNet the input tensor is from. | |
| """ | |
| conv = layer_util.get_nd_layer('Conv', self.rank) # gets a convolutional layer of the specified rank (2D or 3D) | |
| level_diff = from_level - to_level # difference between level of decoder and level of UNet the input tensor is from | |
| size = tuple(np.array(self.pool_size)** (abs(level_diff))) # This specifies the amount of upsampling needed to get to the correct size for decoder level. It is the maxpool size to the power of the level difference. | |
| maxpool = layer_util.get_nd_layer('MaxPool', self.rank) # gets a maxpool layer of the specified rank (2D or 3D) | |
| upsamp = layer_util.get_nd_layer('UpSampling', self.rank) # gets an upsampling layer of the specified rank (2D or 3D) | |
| if self.rank == 2: | |
| upsamp_config = dict(size=size, interpolation='bilinear') # use bilinear interpolation for 2D upsampling | |
| else: | |
| upsamp_config = dict(size=size) # for 3D upsampling, you cannot do bilinear interpolation, so this just uses the default upsampling method. | |
| X = inputs | |
| if to_level < from_level: # If coming from a deeper level of the UNet, then we need to upsample the input tensor to the correct size for the decoder level | |
| X = upsamp(**upsamp_config, name = f'Skip_Upsample_{from_level}_{to_level}')(X) | |
| elif to_level > from_level: # If coming from a shallower level of the UNet, then we need to maxpool the input tensor to the correct size for the decoder level | |
| X = maxpool(pool_size = size, name = f'Skip_Maxpool_{from_level}_{to_level}')(X) | |
| if self.skip_batch_norm: # If using batch normalization in the skip connections, then apply it within the conv block | |
| X = self.conv_block(X, self.filters[to_level-1], block_depth = self.decoder_block_depth, conv_block_purpose ='Skip', level = f'{from_level}_{to_level}') # applies conv block to the upsampled/maxpooled input tensor (with batch normalization) | |
| else: | |
| X = conv(self.filters[to_level-1],**self.conv_config, name = f'Skip_Conv_{from_level}_{to_level}')(X) # applies conv layer to the upsampled/maxpooled input tensor (without batch normalization) | |
| return X # note: returns the output of a single skip connection, but does not yet concatenate the output to the other skip outputs or existing decoder level filters. | |
| def conv_block(self, inputs, filters, block_depth, conv_block_purpose, level): | |
| """ | |
| This function creates a convolutional block with the specified number of stacks and filters. | |
| Args: | |
| inputs: Input tensor. | |
| filters: Number of filters for the convolutional layers. | |
| block_depth: Number of convolutional stacks in the block. | |
| conv_block_purpose: Type of conv block (Encoder, Decoder, Skip). | |
| level: Current level level. | |
| """ | |
| conv = layer_util.get_nd_layer('Conv', self.rank) # gets a convolutional layer of the specified rank (2D or 3D) | |
| X = inputs | |
| for i in range(block_depth): # replicate the conv block, depth number of times | |
| X = conv(filters, **self.conv_config, name = f'{conv_block_purpose}{level}_Conv_{i+1}')(X) # applies conv layer to the input tensor | |
| if self.batch_norm: # If using batch normalization, then apply it after the conv layer | |
| X = tf.keras.layers.BatchNormalization(axis=-1, name = f'{conv_block_purpose}{level}_BN_{i+1}')(X) | |
| if self.activation: # If using an activation function, then apply it after the conv layer | |
| X = tf.keras.layers.Activation(activation = self.activation, name = f'{conv_block_purpose}{level}_Activation_{i+1}')(X) | |
| return X | |
| def encode(self, inputs, level, block_depth): | |
| """ | |
| Creates the encoding block of the U-Net3+ model. | |
| Args: | |
| inputs: Input tensor. | |
| level: Current level level. | |
| block_depth: Number of convolutional stacks in the block. | |
| """ | |
| maxpool = layer_util.get_nd_layer('MaxPool', self.rank) # gets a maxpool layer of the specified rank (2D or 3D) | |
| level -= 1 # python indexing | |
| filters = self.filters[level] # get the number of filters for the current level | |
| X = inputs | |
| if level != 0: # 0 is the input level, so we do not need to maxpool it | |
| X = maxpool(pool_size=self.pool_size, name = f'encoding_{level}_maxpool')(X) # maxpool the input tensor to the correct size for the next level | |
| X = self.conv_block(X, filters, block_depth, conv_block_purpose = 'Encoder', level = level+1) # applies conv block to the maxpooled input tensor | |
| if level == (self.levels-1) and self.add_dropout: # Check if level is the bottom level of the UNet, and if so, apply dropout if specified | |
| X = tf.keras.layers.Dropout(rate = self.dropout_rate, name = f'Encoder{level+1}_dropout')(X) | |
| return X | |
| def outputs(self): | |
| """ | |
| This is the build function for the U-Net3+ model. | |
| """ | |
| XE = [self.inputs] # This is a list of encoder level outputs, starting with the input tensor | |
| for i in range(self.levels): # for each level of the UNet, we apply an encoding block to the output of the previous level | |
| XE.append(self.encode(XE[i], level = i+1, block_depth = self.encoder_block_depth)) | |
| XD = [XE[-1]] # This is a list of decoder level outputs, starting with the output of the last encoder level | |
| if self.skip_type == 'encoder': | |
| # If using encoder-type skip connections, then we apply skip connections from every encoder level to the current decoder level - except the encoder level one deeper. For this level, we use the output of the last decoder level. | |
| for decoder_level in range(self.levels-1,0,-1): # build the decoder levels in reverse order | |
| input_contributions = [] | |
| for unet_level in range(1,self.levels+1): | |
| if unet_level == decoder_level+1: # If the unet level is one deeper than the decoder level, then we get a skip connection from the output of the last decoder level | |
| input_contributions.append(self.skip_connection(XD[-1], decoder_level, unet_level)) | |
| else: # Otherwise we get a skip connection from the output of the encoder level | |
| input_contributions.append(self.skip_connection(XE[unet_level], decoder_level, unet_level)) | |
| XD.append(self.aggregate_and_decode(input_contributions,decoder_level)) # aggregate and conv the skip connections to the current decoder level. This gives the output of the decoder level. Append this to the list of decoder level outputs. | |
| elif self.skip_type == 'decoder': | |
| # If using decoder-type skip connections, then | |
| for decoder_level in range(self.levels-1,0,-1): | |
| skip_contributions = [] | |
| # Append skips from encoder | |
| for encoder_level in range(1,decoder_level+1): # All encoders shallower or equal to the decoder level contribute a skip connection | |
| skip_contributions.append(self.skip_connection(XE[encoder_level], decoder_level, encoder_level)) | |
| # Append skips from decoder | |
| for i in range(len(XD)-1,-1,-1): # All decoders deeper than the current decoder level contribute a skip connection (note: XD is build iteratively in a loop from the deepest level upwards. Therefore at each stage of the loop, XD grows and deeper decoder levels contribute skip connections to the current decoder level) | |
| skip_contributions.append(self.skip_connection(XD[i], decoder_level, (self.levels-i))) | |
| XD.append(self.aggregate_and_decode(skip_contributions,decoder_level)) # aggregate and conv the skip connections to the current decoder level. This gives the output of the decoder level. Append this to the list of decoder level outputs. | |
| elif self.skip_type == 'standard_unet': | |
| # If standard_unet type skips, then at each decoder level, we get a skip connection from the corresponding encoder level | |
| for decoder_level in range(self.levels-1,0,-1): | |
| skip_contributions = [XE[decoder_level],self.skip_connection(XD[-1],decoder_level,decoder_level+1)] | |
| XD.append(self.aggregate_and_decode(skip_contributions,decoder_level)) # aggregate and conv the skip connections to the current decoder level. | |
| else: | |
| raise ValueError(f"Invalid skip_type") | |
| if self.deep_supervision == True: | |
| XD = [self.deep_sup(xd, self.levels-i) for i,xd in enumerate(XD)] # If deep supervision is used, then we apply deep supervision to each decoder level output | |
| return XD | |
| else: | |
| XD[-1] = self.deep_sup(XD[-1],1) # If deep supervision is not used, then we only apply deep supervision to the final decoder level output | |
| return XD[-1] |