wanhanisah commited on
Commit
b03d3b9
·
verified ·
1 Parent(s): 321897c

Upload 2 files

Browse files
Files changed (2) hide show
  1. utils/layer_util.py +132 -0
  2. utils/unet3plus.py +268 -0
utils/layer_util.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 University College London. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Layer utilities."""
16
+
17
+ import tensorflow as tf
18
+ from .array_ops import resize_with_crop_or_pad
19
+
20
+
21
def get_nd_layer(name, rank):
  """Look up the N-D Keras layer class for a rank-agnostic layer name.

  Args:
    name: A `str`. The name of the requested layer.
    rank: An `int`. The rank of the requested layer.

  Returns:
    A `tf.keras.layers.Layer` object.

  Raises:
    ValueError: If the requested layer is unknown to TFMRI.
  """
  key = (name, rank)
  try:
    layer_cls = _ND_LAYERS[key]
  except KeyError as err:
    # Surface an informative error instead of a bare KeyError, keeping the
    # original lookup failure attached as the cause.
    raise ValueError(
        f"Could not find a layer with name '{name}' and rank {rank}.") from err
  return layer_cls
39
+
40
+
41
# Registry mapping (rank-agnostic layer name, spatial rank) -> Keras layer
# class. This backs `get_nd_layer`; combinations not listed here (e.g.
# ('DepthwiseConv', 3) or ('SeparableConv', 3)) are not available through
# this registry.
_ND_LAYERS = {
    ('AveragePooling', 1): tf.keras.layers.AveragePooling1D,
    ('AveragePooling', 2): tf.keras.layers.AveragePooling2D,
    ('AveragePooling', 3): tf.keras.layers.AveragePooling3D,
    ('Conv', 1): tf.keras.layers.Conv1D,
    ('Conv', 2): tf.keras.layers.Conv2D,
    ('Conv', 3): tf.keras.layers.Conv3D,
    ('ConvLSTM', 1): tf.keras.layers.ConvLSTM1D,
    ('ConvLSTM', 2): tf.keras.layers.ConvLSTM2D,
    ('ConvLSTM', 3): tf.keras.layers.ConvLSTM3D,
    ('ConvTranspose', 1): tf.keras.layers.Conv1DTranspose,
    ('ConvTranspose', 2): tf.keras.layers.Conv2DTranspose,
    ('ConvTranspose', 3): tf.keras.layers.Conv3DTranspose,
    ('Cropping', 1): tf.keras.layers.Cropping1D,
    ('Cropping', 2): tf.keras.layers.Cropping2D,
    ('Cropping', 3): tf.keras.layers.Cropping3D,
    ('DepthwiseConv', 1): tf.keras.layers.DepthwiseConv1D,
    ('DepthwiseConv', 2): tf.keras.layers.DepthwiseConv2D,
    ('GlobalAveragePooling', 1): tf.keras.layers.GlobalAveragePooling1D,
    ('GlobalAveragePooling', 2): tf.keras.layers.GlobalAveragePooling2D,
    ('GlobalAveragePooling', 3): tf.keras.layers.GlobalAveragePooling3D,
    ('GlobalMaxPool', 1): tf.keras.layers.GlobalMaxPool1D,
    ('GlobalMaxPool', 2): tf.keras.layers.GlobalMaxPool2D,
    ('GlobalMaxPool', 3): tf.keras.layers.GlobalMaxPool3D,
    ('MaxPool', 1): tf.keras.layers.MaxPool1D,
    ('MaxPool', 2): tf.keras.layers.MaxPool2D,
    ('MaxPool', 3): tf.keras.layers.MaxPool3D,
    ('SeparableConv', 1): tf.keras.layers.SeparableConv1D,
    ('SeparableConv', 2): tf.keras.layers.SeparableConv2D,
    ('SpatialDropout', 1): tf.keras.layers.SpatialDropout1D,
    ('SpatialDropout', 2): tf.keras.layers.SpatialDropout2D,
    ('SpatialDropout', 3): tf.keras.layers.SpatialDropout3D,
    ('UpSampling', 1): tf.keras.layers.UpSampling1D,
    ('UpSampling', 2): tf.keras.layers.UpSampling2D,
    ('UpSampling', 3): tf.keras.layers.UpSampling3D,
    ('ZeroPadding', 1): tf.keras.layers.ZeroPadding1D,
    ('ZeroPadding', 2): tf.keras.layers.ZeroPadding2D,
    ('ZeroPadding', 3): tf.keras.layers.ZeroPadding3D
}
80
+
81
+
82
class ResizeAndConcatenate(tf.keras.layers.Layer):
  """Resizes and concatenates a list of inputs.

  Similar to `tf.keras.layers.Concatenate`, but if the inputs have different
  shapes, they are resized (cropped or padded) to match the shape of the
  first input. Only the concatenation axis is left free, so each input keeps
  its own extent there.

  Args:
    axis: An `int`. Axis along which to concatenate. Defaults to -1.
  """
  def __init__(self, axis=-1, **kwargs):
    super().__init__(**kwargs)
    self.axis = axis

  def get_config(self):
    """Returns the layer configuration for serialization."""
    config = super().get_config()
    config.update({
        "axis": self.axis,
    })
    return config

  def call(self, inputs):
    """Resizes `inputs[1:]` to match `inputs[0]` and concatenates them all.

    Args:
      inputs: A `list` or `tuple` of tensors sharing a known rank.

    Returns:
      The concatenated tensor.

    Raises:
      ValueError: If `inputs` is not a list/tuple, if the first input has
        unknown rank, or if `axis` is out of range for that rank.
    """
    if not isinstance(inputs, (list, tuple)):
      raise ValueError(
          f"Layer {self.__class__.__name__} expects a list of inputs. "
          f"Received: {inputs}")

    rank = inputs[0].shape.rank
    if rank is None:
      raise ValueError(
          f"Layer {self.__class__.__name__} expects inputs with known rank. "
          f"Received: {inputs}")
    if self.axis >= rank or self.axis < -rank:
      raise ValueError(
          f"Layer {self.__class__.__name__} expects `axis` to be in the range "
          f"[-{rank}, {rank}) for an input of rank {rank}. "
          f"Received: {self.axis}")
    # Canonical axis (always positive).
    axis = self.axis % rank

    # Target dynamic shape: the first input's shape with the concatenation
    # axis set to -1, so each input keeps its own extent on that axis.
    shape = tf.tensor_scatter_nd_update(tf.shape(inputs[0]), [[axis]], [-1])
    resized = [resize_with_crop_or_pad(tensor, shape)
               for tensor in inputs[1:]]

    # Restore the static shape lost by the dynamic resize: each axis takes
    # the first input's static extent, except the concatenation axis, which
    # keeps the resized tensor's own static extent.
    for i, tensor in enumerate(resized):
      static_shape = inputs[0].shape.as_list()
      static_shape[axis] = inputs[i + 1].shape.as_list()[axis]
      static_shape = tf.TensorShape(static_shape)
      resized[i] = tf.ensure_shape(tensor, static_shape)
    # Fix: `inputs[:1] + resized` raised TypeError when `inputs` was a tuple
    # (tuple + list is invalid), even though tuples pass the isinstance check
    # above. `list(...)` accepts both containers.
    return tf.concat(list(inputs[:1]) + resized, axis=self.axis)  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
utils/unet3plus.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Import standard python modules
import tensorflow as tf
import numpy as np

# Import custom modules
from . import layer_util

# NOTE(review): seeding the TF global RNG at import time is a module-level
# side effect — any program importing this module has its TensorFlow
# randomness pinned to this seed. Consider moving this under caller control.
tf.random.set_seed(489154)
9
+
10
class unet3plus:
    """
    Builder for a U-Net3+ model graph.

    Wires an encoder/decoder network with full-scale skip connections
    (U-Net3+ style), optional dropout at the bottleneck and optional deep
    supervision. Call `outputs()` to build the graph and obtain the output
    tensor(s).
    """

    def __init__(self,
                 inputs,
                 filters=(32, 64, 128, 256, 512),
                 rank=2,
                 out_channels=1,
                 kernel_initializer=tf.keras.initializers.HeNormal(seed=0),
                 bias_initializer=tf.keras.initializers.Zeros(),
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 add_dropout=False,
                 padding='same',
                 dropout_rate=0.5,
                 kernel_size=3,
                 out_kernel_size=3,
                 pool_size=2,
                 encoder_block_depth=2,
                 decoder_block_depth=1,
                 batch_norm=True,
                 activation='relu',
                 out_activation=None,
                 skip_batch_norm=True,
                 skip_type='encoder',
                 CGM=False,
                 deep_supervision=True):
        """
        Initialise the U-Net3+ model.

        Args:
            inputs: Input tensor.
            filters: Sequence of filter counts, one per UNet level. The
                default is a tuple rather than a list to avoid the shared
                mutable-default pitfall; lists are accepted too.
            rank: Number of spatial dimensions (2 for 2D, 3 for 3D).
            out_channels: Number of output channels (for segmentation this is
                the number of distinct masks).
            kernel_initializer: Initialiser for the convolutional kernels.
            bias_initializer: Initialiser for the bias terms.
            kernel_regularizer: Regulariser for the convolutional kernels.
            bias_regularizer: Regulariser for the bias terms.
            add_dropout: Whether to add a dropout layer at the deepest
                encoder level.
            padding: Padding type for the convolutional layers.
            dropout_rate: Dropout rate (used only when `add_dropout`).
            kernel_size: Kernel size for the convolutional layers. Either a
                tuple/list with one entry per spatial dimension, or a single
                integer applied to every dimension.
            out_kernel_size: Kernel size for the output (deep-supervision)
                convolutions. Same format as `kernel_size`.
            pool_size: Pooling size for the max-pooling layers. Same format
                as `kernel_size`.
            encoder_block_depth: Number of conv layers per encoder block.
            decoder_block_depth: Number of conv layers per decoder block.
            batch_norm: Whether to apply batch normalisation after each conv.
            activation: Activation for the convolutional layers.
            out_activation: Activation for the output layer. For binary
                segmentation this is typically 'sigmoid' or 'softmax'.
            skip_batch_norm: Whether to run skip connections through a conv
                block (which normalises) rather than a plain conv layer.
            skip_type: Skip-connection topology: 'encoder', 'decoder' or
                'standard_unet'.
            CGM: Whether to use a Classification Guided Module.
                NOTE(review): stored but not referenced anywhere else in this
                class — the module is not actually built.
            deep_supervision: Whether to emit a prediction at every decoder
                level instead of only the final one.
        """
        self.inputs = inputs
        self.filters = filters
        self.levels = len(filters)
        self.rank = rank
        self.out_channels = out_channels
        self.encoder_block_depth = encoder_block_depth
        self.decoder_block_depth = decoder_block_depth
        self.add_dropout = add_dropout
        self.dropout_rate = dropout_rate
        self.skip_type = skip_type
        self.skip_batch_norm = skip_batch_norm
        self.batch_norm = batch_norm
        self.activation = activation
        self.out_activation = out_activation
        self.CGM = CGM
        self.deep_supervision = deep_supervision
        # Broadcast scalar sizes to one entry per spatial dimension. Unlike a
        # bare `isinstance(x, tuple)` check, list inputs are normalised too
        # (the previous code produced a tuple of lists for a list argument).
        self.pool_size = self._to_nd_tuple(pool_size, rank)
        self.kernel_size = self._to_nd_tuple(kernel_size, rank)
        self.out_kernel_size = self._to_nd_tuple(out_kernel_size, rank)
        # Shared keyword arguments for the interior convolutions.
        self.conv_config = dict(kernel_size=self.kernel_size,
                                padding=padding,
                                kernel_initializer=kernel_initializer,
                                bias_initializer=bias_initializer,
                                kernel_regularizer=kernel_regularizer,
                                bias_regularizer=bias_regularizer)
        # Shared keyword arguments for the output-head convolutions. Uses the
        # normalised tuple (equivalent to the raw int for Keras layers).
        self.out_conv_config = dict(kernel_size=self.out_kernel_size,
                                    padding=padding,
                                    kernel_initializer=kernel_initializer,
                                    bias_initializer=bias_initializer,
                                    kernel_regularizer=kernel_regularizer,
                                    bias_regularizer=bias_regularizer)

    @staticmethod
    def _to_nd_tuple(value, rank):
        """Return `value` as a tuple with one entry per spatial dimension."""
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return tuple(value for _ in range(rank))

    def aggregate_and_decode(self, input_list, level):
        """
        Aggregates the inputs for a decoder level and convolves the result.

        Args:
            input_list: List of tensors feeding the decoder level; the first
                entry defines the target spatial size.
            level: Current decoder level (1-based).
        """
        # Resize every input to the first entry's size and concatenate along
        # the channel axis.
        X = layer_util.ResizeAndConcatenate(name=f'D{level}_input',
                                            axis=-1)(input_list)
        # Convolve the concatenated feature maps with a decoder block.
        X = self.conv_block(X, self.filters[level - 1],
                            block_depth=self.decoder_block_depth,
                            conv_block_purpose='Decoder', level=level)
        return X

    def deep_sup(self, inputs, level):
        """
        Builds an output head from a decoder level's output.

        The decoder output is projected to `out_channels`, upsampled to the
        input resolution (for levels deeper than 1), convolved again to
        reduce upsampling artifacts, and passed through the output
        activation. With deep supervision this runs at every decoder level;
        otherwise only at the final one.

        Args:
            inputs: Input tensor.
            level: Current decoder level (1-based; level 1 is already at
                output resolution and needs no upsampling).
        """
        conv = layer_util.get_nd_layer('Conv', self.rank)
        upsamp = layer_util.get_nd_layer('UpSampling', self.rank)
        # Upsampling factor: pool size raised to the number of pooling steps
        # between this level and the output resolution.
        size = tuple(np.array(self.pool_size) ** (abs(level - 1)))
        if self.rank == 2:
            # Bilinear interpolation is only available for 2-D upsampling.
            upsamp_config = dict(size=size, interpolation='bilinear')
        else:
            upsamp_config = dict(size=size)
        X = inputs
        # Project to the requested number of output channels.
        X = conv(self.out_channels, activation=None, **self.out_conv_config,
                 name=f'deepsup_conv_{level}_1')(X)
        if level != 1:
            # Upsample to the final output size ...
            X = upsamp(**upsamp_config, name=f'deepsup_upsamp_{level}')(X)
            # ... then convolve again to correct upsampling artifacts.
            X = conv(self.out_channels, activation=None,
                     **self.out_conv_config,
                     name=f'deepsup_conv_{level}_2')(X)
        if self.out_activation:
            X = tf.keras.layers.Activation(
                activation=self.out_activation,
                name=f'deepsup_activation_{level}')(X)
        return X

    def skip_connection(self, inputs, to_level, from_level):
        """
        Processes a tensor as a skip connection into a decoder level.

        Args:
            inputs: Input tensor.
            to_level: Decoder level receiving the skip connection.
            from_level: UNet level the tensor comes from.
        """
        conv = layer_util.get_nd_layer('Conv', self.rank)
        maxpool = layer_util.get_nd_layer('MaxPool', self.rank)
        upsamp = layer_util.get_nd_layer('UpSampling', self.rank)
        # Resampling factor: pool size raised to the level difference.
        level_diff = from_level - to_level
        size = tuple(np.array(self.pool_size) ** (abs(level_diff)))
        if self.rank == 2:
            # Bilinear interpolation is only available for 2-D upsampling.
            upsamp_config = dict(size=size, interpolation='bilinear')
        else:
            upsamp_config = dict(size=size)

        X = inputs
        if to_level < from_level:
            # From a deeper level: upsample to the decoder resolution.
            X = upsamp(**upsamp_config,
                       name=f'Skip_Upsample_{from_level}_{to_level}')(X)
        elif to_level > from_level:
            # From a shallower level: downsample via max pooling.
            X = maxpool(pool_size=size,
                        name=f'Skip_Maxpool_{from_level}_{to_level}')(X)

        if self.skip_batch_norm:
            # NOTE(review): conv_block only normalises when self.batch_norm
            # is also True, so skip_batch_norm alone does not guarantee BN.
            X = self.conv_block(X, self.filters[to_level - 1],
                                block_depth=self.decoder_block_depth,
                                conv_block_purpose='Skip',
                                level=f'{from_level}_{to_level}')
        else:
            X = conv(self.filters[to_level - 1], **self.conv_config,
                     name=f'Skip_Conv_{from_level}_{to_level}')(X)

        # The caller concatenates this output with the other decoder-level
        # contributions; no concatenation happens here.
        return X

    def conv_block(self, inputs, filters, block_depth, conv_block_purpose,
                   level):
        """
        Builds a stack of conv (+ optional BN and activation) layers.

        Args:
            inputs: Input tensor.
            filters: Number of filters for each convolutional layer.
            block_depth: Number of conv stacks in the block.
            conv_block_purpose: Label used in layer names (Encoder, Decoder
                or Skip).
            level: Level label used in layer names.
        """
        conv = layer_util.get_nd_layer('Conv', self.rank)
        X = inputs
        for i in range(block_depth):
            X = conv(filters, **self.conv_config,
                     name=f'{conv_block_purpose}{level}_Conv_{i+1}')(X)
            if self.batch_norm:
                X = tf.keras.layers.BatchNormalization(
                    axis=-1, name=f'{conv_block_purpose}{level}_BN_{i+1}')(X)
            if self.activation:
                X = tf.keras.layers.Activation(
                    activation=self.activation,
                    name=f'{conv_block_purpose}{level}_Activation_{i+1}')(X)
        return X

    def encode(self, inputs, level, block_depth):
        """
        Builds one encoder level of the U-Net3+ model.

        Args:
            inputs: Input tensor.
            level: Current encoder level (1-based).
            block_depth: Number of conv stacks in the block.
        """
        maxpool = layer_util.get_nd_layer('MaxPool', self.rank)
        level -= 1  # Convert to 0-based indexing.
        filters = self.filters[level]
        X = inputs
        if level != 0:
            # Downsample every level except the input-resolution one.
            X = maxpool(pool_size=self.pool_size,
                        name=f'encoding_{level}_maxpool')(X)
        X = self.conv_block(X, filters, block_depth,
                            conv_block_purpose='Encoder', level=level + 1)
        if level == (self.levels - 1) and self.add_dropout:
            # Optional dropout at the deepest (bottleneck) level only.
            X = tf.keras.layers.Dropout(
                rate=self.dropout_rate,
                name=f'Encoder{level+1}_dropout')(X)
        return X

    def outputs(self):
        """
        Builds the U-Net3+ graph and returns its output(s).

        Returns:
            A list of output tensors (one per decoder level) when deep
            supervision is enabled, otherwise a single output tensor.

        Raises:
            ValueError: If `skip_type` is not one of 'encoder', 'decoder' or
                'standard_unet'.
        """
        # Encoder: each level consumes the previous level's output; XE[0] is
        # the raw input tensor.
        XE = [self.inputs]
        for i in range(self.levels):
            XE.append(self.encode(XE[i], level=i + 1,
                                  block_depth=self.encoder_block_depth))
        # Decoder outputs, deepest first; seeded with the bottleneck output.
        XD = [XE[-1]]
        if self.skip_type == 'encoder':
            # Full-scale skips from every encoder level, except the level one
            # deeper than the current decoder, which contributes the previous
            # decoder output instead.
            for decoder_level in range(self.levels - 1, 0, -1):
                input_contributions = []
                for unet_level in range(1, self.levels + 1):
                    if unet_level == decoder_level + 1:
                        input_contributions.append(
                            self.skip_connection(XD[-1], decoder_level,
                                                 unet_level))
                    else:
                        input_contributions.append(
                            self.skip_connection(XE[unet_level],
                                                 decoder_level, unet_level))
                XD.append(self.aggregate_and_decode(input_contributions,
                                                    decoder_level))
        elif self.skip_type == 'decoder':
            # Skips from encoders at or above the decoder level, plus skips
            # from every already-built (deeper) decoder level.
            for decoder_level in range(self.levels - 1, 0, -1):
                skip_contributions = []
                for encoder_level in range(1, decoder_level + 1):
                    skip_contributions.append(
                        self.skip_connection(XE[encoder_level], decoder_level,
                                             encoder_level))
                # XD grows as the decoder is built bottom-up, so each deeper
                # decoder level contributes a skip connection here.
                for i in range(len(XD) - 1, -1, -1):
                    skip_contributions.append(
                        self.skip_connection(XD[i], decoder_level,
                                             (self.levels - i)))
                XD.append(self.aggregate_and_decode(skip_contributions,
                                                    decoder_level))
        elif self.skip_type == 'standard_unet':
            # Classic U-Net: one skip from the matching encoder level plus
            # the previous (deeper) decoder output.
            for decoder_level in range(self.levels - 1, 0, -1):
                skip_contributions = [
                    XE[decoder_level],
                    self.skip_connection(XD[-1], decoder_level,
                                         decoder_level + 1)]
                XD.append(self.aggregate_and_decode(skip_contributions,
                                                    decoder_level))
        else:
            raise ValueError(
                f"Invalid skip_type: {self.skip_type!r}. Expected 'encoder', "
                f"'decoder' or 'standard_unet'.")
        if self.deep_supervision:
            # One prediction head per decoder level.
            XD = [self.deep_sup(xd, self.levels - i)
                  for i, xd in enumerate(XD)]
            return XD
        # Single prediction head on the final (full-resolution) level.
        XD[-1] = self.deep_sup(XD[-1], 1)
        return XD[-1]