File size: 17,583 Bytes
b03d3b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# Import standard python modules
import tensorflow as tf
import numpy as np

# Import custom modules
from . import layer_util

tf.random.set_seed(489154)

class unet3plus:
   """
   Class for building a U-Net3+ model.
   """

   def __init__(self,
                inputs,
                filters = [32,64,128,256,512],
                rank = 2,
                out_channels = 1,
                kernel_initializer=tf.keras.initializers.HeNormal(seed=0),
                bias_initializer=tf.keras.initializers.Zeros(),
                kernel_regularizer=None,
                bias_regularizer=None,
                add_dropout = False,
                padding = 'same',
                dropout_rate = 0.5,
                kernel_size = 3,
                out_kernel_size = 3,
                pool_size = 2,
                encoder_block_depth = 2,
                decoder_block_depth = 1,
                batch_norm = True,
                activation = 'relu',
                out_activation = None,
                skip_batch_norm = True,
                skip_type = 'encoder',
                CGM = False,
                deep_supervision = True):
       
       """
       Initialise the U-Net3+ model.
       Args:
           inputs: Input tensor.
           filters: List of filter sizes for each UNet level.
           rank: Number of dimensions (2D or 3D).
           out_channels: Number of output channels (for segmentation this shall be the number of distinct masks).
           kernel_initializer: Initialiser for the convolutional layers.
           bias_initializer: Initialiser for the bias terms.
           kernel_regularizer: Regulariser for the convolutional layers.
           bias_regularizer: Regulariser for the bias terms in convolutional layers.
           add_dropout: Whether to add dropout layers.
           padding: Padding type for the convolutional layers.
           dropout_rate: Dropout rate.
           kernel_size: Kernel size for the convolutional layers.
           out_kernel_size: Kernel size for the final convolutional layers of the network.
           pool_size: Pooling size for the max pooling layers. This can be a tuple specifing the max pool size for each dimension of the input, or a single integer specifying the same size for all dimensions.
           encoder_block_depth: Number of convolutional blocks in each level of the encoding arm.
           decoder_block_depth: Number of convolutional blocks in each level of the decoding arm.
           batch_norm: Whether to use batch normalization.
           activation: Activation function for the convolutional layers.
           out_activation: Activation function for the output layer. For binary segmentation this shall be 'sigmoid' or 'softmax'.
           skip_batch_norm: Whether to use batch normalization in the skip connections.
           skip_type: Type of skip connections to use in the model ('encoder', 'decoder', or 'standard_unet').
           CGM: Whether to use CGM in the model for segmentation (Classification Guided Module).
           deep_supervision: Whether to use deep supervision.
        """
        # Assign parameters
       self.inputs = inputs
       self.filters = filters
       self.levels = len(filters)
       self.rank = rank
       self.out_channels = out_channels
       self.encoder_block_depth = encoder_block_depth
       self.decoder_block_depth = decoder_block_depth
       self.kernel_size = kernel_size
       self.add_dropout = add_dropout
       self.dropout_rate = dropout_rate
       self.skip_type = skip_type  
       self.skip_batch_norm = skip_batch_norm
       self.batch_norm = batch_norm
       self.activation = activation
       self.out_activation = out_activation
       self.CGM = CGM
       self.deep_supervision = deep_supervision
       # Assign pool size based on given tuple, or if single integer is provided, assign the same value to all dimensions using the rank as a guide for the number of dimensions
       if isinstance(pool_size,tuple):
           self.pool_size = pool_size
       else:
           self.pool_size = tuple([pool_size for _ in range(rank)])
        # Assign kernel sizes based on given tuple, or if single integer is provided, assign the same value to all dimensions using the rank as a guide for the number of dimensions
       if isinstance(kernel_size,tuple):
           self.kernel_size = kernel_size
       else:
           self.kernel_size = tuple([kernel_size for _ in range(rank)])
       if isinstance(out_kernel_size,tuple):
           self.out_kernel_size = out_kernel_size
       else:
           self.out_kernel_size = tuple([out_kernel_size for _ in range(rank)])
       # Create the conv and out conv config dictionaries for the conv and out conv layers
       self.conv_config = dict(kernel_size = self.kernel_size,
                          padding = padding,
                          kernel_initializer = kernel_initializer,
                          bias_initializer = bias_initializer,
                          kernel_regularizer = kernel_regularizer,
                          bias_regularizer = bias_regularizer)
       self.out_conv_config = dict(kernel_size = out_kernel_size,
                          padding = padding,
                          kernel_initializer = kernel_initializer,
                          bias_initializer = bias_initializer,
                          kernel_regularizer = kernel_regularizer,
                          bias_regularizer = bias_regularizer)
   
   def aggregate_and_decode(self, input_list, level):
    """
    Aggregates the inputs for the decoder levels and applies convolution to get the output of the decoder level.

    Args:
        input_list: List of inputs to the decoder to be aggregated.
        level: Current decoder level.
    """
    X = layer_util.ResizeAndConcatenate(name = f'D{level}_input', axis = -1)(input_list) # Takes the various inputs to a decoder level, resizes them to the 1st input size in the list and the concatenates them all.
    X = self.conv_block(X, self.filters[level-1], block_depth = self.decoder_block_depth, conv_block_purpose = 'Decoder', level=level) # Performs a decoder block convolution of the concatenated input (i.e. the concatenated list of filters)
    return X
   
   def deep_sup(self, inputs, level):
    """
    If deep supervision is used, then the network will output a prediction at each level of the decoder.
    This function upsamples the output of a decoder level, convolves it and then applies the output activation function (i.e. to get to the final output).
    If deep supervision is not used, then the network will only output a prediction at the final level of the decoder.

    Args:
        inputs: Input tensor.
        level: Current decoder level.
    """
    conv = layer_util.get_nd_layer('Conv', self.rank) # gets a convolutional layer of the specified rank (2D or 3D)
    upsamp = layer_util.get_nd_layer('UpSampling', self.rank) # gets an upsampling layer of the specified rank (2D or 3D)
    size = tuple(np.array(self.pool_size)** (abs(level-1))) # This specifies the amount of upsampling needed to get to the correct final output size. It is the maxpool size to the power of the current decoder level minus one.
    if self.rank == 2:
        upsamp_config = dict(size=size, interpolation='bilinear') # use bilinear interpolation for 2D upsampling
    else:
        upsamp_config = dict(size=size) # for 3D upsampling, you cannot do bilinear interpolation, so this just uses the default upsampling method.
    X = inputs  
    X = conv(self.out_channels, activation = None, **self.out_conv_config, name = f'deepsup_conv_{level}_1')(X) # Convolves the input to get the correct number of output channels
    if level != 1:
        X = upsamp(**upsamp_config, name = f'deepsup_upsamp_{level}')(X) # Upsamples the convolved input to the correct size for the final output
    X = conv(self.out_channels, activation = None, **self.out_conv_config, name = f'deepsup_conv_{level}_2')(X) # Convolves the upsampled input to get the correct number of output channels (e.g. to correct artifacts due to upsampling)
    if self.out_activation:
        X = tf.keras.layers.Activation(activation = self.out_activation, name = f'deepsup_activation_{level}')(X) # Applies the output activation function to get the final output
    return X
       
       
       
   def skip_connection(self, inputs, to_level, from_level):
    """
    This function takes an input tensor and processes it as a skip connection to the decoder level.

    Args:
        inputs: Input tensor.
        to_level: Current decoder level.
        from_level: Level of UNet the input tensor is from.    
    """
    conv = layer_util.get_nd_layer('Conv', self.rank) # gets a convolutional layer of the specified rank (2D or 3D)
    level_diff = from_level - to_level  # difference between level of decoder and level of UNet the input tensor is from
    size = tuple(np.array(self.pool_size)** (abs(level_diff))) # This specifies the amount of upsampling needed to get to the correct size for decoder level. It is the maxpool size to the power of the level difference.
    maxpool = layer_util.get_nd_layer('MaxPool', self.rank) # gets a maxpool layer of the specified rank (2D or 3D)
    upsamp = layer_util.get_nd_layer('UpSampling', self.rank) # gets an upsampling layer of the specified rank (2D or 3D)
    if self.rank == 2:
        upsamp_config = dict(size=size, interpolation='bilinear') # use bilinear interpolation for 2D upsampling
    else:
        upsamp_config = dict(size=size) # for 3D upsampling, you cannot do bilinear interpolation, so this just uses the default upsampling method.
    
    X = inputs        
    if to_level < from_level: # If coming from a deeper level of the UNet, then we need to upsample the input tensor to the correct size for the decoder level
        X = upsamp(**upsamp_config, name = f'Skip_Upsample_{from_level}_{to_level}')(X)
    elif to_level > from_level: # If coming from a shallower level of the UNet, then we need to maxpool the input tensor to the correct size for the decoder level
        X = maxpool(pool_size = size, name = f'Skip_Maxpool_{from_level}_{to_level}')(X)
    
    if self.skip_batch_norm: # If using batch normalization in the skip connections, then apply it within the conv block
        X = self.conv_block(X, self.filters[to_level-1], block_depth = self.decoder_block_depth, conv_block_purpose ='Skip', level = f'{from_level}_{to_level}') # applies conv block to the upsampled/maxpooled input tensor (with batch normalization)
    else:
        X = conv(self.filters[to_level-1],**self.conv_config, name = f'Skip_Conv_{from_level}_{to_level}')(X)  # applies conv layer to the upsampled/maxpooled input tensor (without batch normalization)
        
    return X # note: returns the output of a single skip connection, but does not yet concatenate the output to the other skip outputs or existing decoder level filters. 
       
   def conv_block(self, inputs, filters, block_depth, conv_block_purpose, level):
       """
       This function creates a convolutional block with the specified number of stacks and filters.
         Args:
                inputs: Input tensor.
                filters: Number of filters for the convolutional layers.
                block_depth: Number of convolutional stacks in the block.
                conv_block_purpose: Type of conv block (Encoder, Decoder, Skip).
                level: Current level level.
       """
       conv = layer_util.get_nd_layer('Conv', self.rank) # gets a convolutional layer of the specified rank (2D or 3D)
       X = inputs
       for i in range(block_depth): # replicate the conv block, depth number of times
           X = conv(filters, **self.conv_config, name = f'{conv_block_purpose}{level}_Conv_{i+1}')(X) # applies conv layer to the input tensor
           if self.batch_norm: # If using batch normalization, then apply it after the conv layer
               X = tf.keras.layers.BatchNormalization(axis=-1, name = f'{conv_block_purpose}{level}_BN_{i+1}')(X) 
           if self.activation: # If using an activation function, then apply it after the conv layer
            X = tf.keras.layers.Activation(activation = self.activation, name = f'{conv_block_purpose}{level}_Activation_{i+1}')(X)
       return X
   
   
   def encode(self, inputs, level, block_depth):
       """
       Creates the encoding block of the U-Net3+ model.

         Args:
                inputs: Input tensor.
                level: Current level level.
                block_depth: Number of convolutional stacks in the block.
       """
       maxpool = layer_util.get_nd_layer('MaxPool', self.rank) # gets a maxpool layer of the specified rank (2D or 3D)
       level -= 1 # python indexing
       filters = self.filters[level] # get the number of filters for the current level
       X = inputs
       if level != 0: # 0 is the input level, so we do not need to maxpool it
           X = maxpool(pool_size=self.pool_size, name = f'encoding_{level}_maxpool')(X) # maxpool the input tensor to the correct size for the next level
       X = self.conv_block(X, filters, block_depth, conv_block_purpose = 'Encoder', level = level+1) # applies conv block to the maxpooled input tensor
       if level == (self.levels-1) and self.add_dropout: # Check if level is the bottom level of the UNet, and if so, apply dropout if specified
           X = tf.keras.layers.Dropout(rate = self.dropout_rate, name = f'Encoder{level+1}_dropout')(X)
       return X
       
   def outputs(self):
       """
       This is the build function for the U-Net3+ model. 

       """
       XE  = [self.inputs] # This is a list of encoder level outputs, starting with the input tensor
       for i in range(self.levels): # for each level of the UNet, we apply an encoding block to the output of the previous level
           XE.append(self.encode(XE[i], level = i+1, block_depth = self.encoder_block_depth))
       XD = [XE[-1]] # This is a list of decoder level outputs, starting with the output of the last encoder level
       if self.skip_type == 'encoder': 
           # If using encoder-type skip connections, then we apply skip connections from every encoder level to the current decoder level - except the encoder level one deeper. For this level, we use the output of the last decoder level.
           for decoder_level in range(self.levels-1,0,-1): # build the decoder levels in reverse order
               input_contributions = []
               for unet_level in range(1,self.levels+1):
                   if unet_level == decoder_level+1: # If the unet level is one deeper than the decoder level, then we get a skip connection from the output of the last decoder level
                       input_contributions.append(self.skip_connection(XD[-1], decoder_level, unet_level))
                   else: # Otherwise we get a skip connection from the output of the encoder level
                       input_contributions.append(self.skip_connection(XE[unet_level], decoder_level, unet_level))
               XD.append(self.aggregate_and_decode(input_contributions,decoder_level)) # aggregate and conv the skip connections to the current decoder level. This gives the output of the decoder level. Append this to the list of decoder level outputs.
       elif self.skip_type == 'decoder':
           # If using decoder-type skip connections, then 
           for decoder_level in range(self.levels-1,0,-1):
               skip_contributions = []
               # Append skips from encoder
               for encoder_level in range(1,decoder_level+1): # All encoders shallower or equal to the decoder level contribute a skip connection
                   skip_contributions.append(self.skip_connection(XE[encoder_level], decoder_level, encoder_level))
               # Append skips from decoder
               for i in range(len(XD)-1,-1,-1): # All decoders deeper than the current decoder level contribute a skip connection (note: XD is build iteratively in a loop from the deepest level upwards. Therefore at each stage of the loop, XD grows and deeper decoder levels contribute skip connections to the current decoder level)
                   skip_contributions.append(self.skip_connection(XD[i], decoder_level, (self.levels-i)))
               XD.append(self.aggregate_and_decode(skip_contributions,decoder_level)) # aggregate and conv the skip connections to the current decoder level. This gives the output of the decoder level. Append this to the list of decoder level outputs.
       elif self.skip_type == 'standard_unet':
           # If standard_unet type skips, then at each decoder level, we get a skip connection from the corresponding encoder level 
           for decoder_level in range(self.levels-1,0,-1): 
               skip_contributions = [XE[decoder_level],self.skip_connection(XD[-1],decoder_level,decoder_level+1)]
               XD.append(self.aggregate_and_decode(skip_contributions,decoder_level)) # aggregate and conv the skip connections to the current decoder level.
       else:
           raise ValueError(f"Invalid skip_type")
       if self.deep_supervision == True:
           XD = [self.deep_sup(xd, self.levels-i) for i,xd in enumerate(XD)] # If deep supervision is used, then we apply deep supervision to each decoder level output
           return XD
       else:
           XD[-1] = self.deep_sup(XD[-1],1) # If deep supervision is not used, then we only apply deep supervision to the final decoder level output
           return XD[-1]