# Source: 3DCine/utils/layer_util.py (uploaded by MarkWrobel, commit f8317f9).
# Copyright 2021 University College London. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer utilities."""
import tensorflow as tf
# from tensorflow_mri.python.layers import convolutional
# from tensorflow_mri.python.layers import signal_layers
def get_nd_layer(name, rank):
  """Look up the Keras layer class registered for a name and rank.

  Args:
    name: A `str`. The rank-agnostic layer name, e.g. `'Conv'` or `'MaxPool'`.
    rank: An `int`. The spatial rank of the layer (1, 2 or 3).

  Returns:
    The layer class (not an instance) registered in `_ND_LAYERS` under
    `(name, rank)`, e.g. `tf.keras.layers.Conv2D` for `('Conv', 2)`.

  Raises:
    ValueError: If no layer is registered for the `(name, rank)` pair.
  """
  key = (name, rank)
  try:
    layer_cls = _ND_LAYERS[key]
  except KeyError as lookup_error:
    # Re-raise as ValueError for callers, keeping the KeyError as the cause.
    raise ValueError(
        f"Could not find a layer with name '{name}' and rank {rank}."
    ) from lookup_error
  return layer_cls
# Registry mapping `(name, rank)` pairs to Keras layer classes; queried by
# `get_nd_layer`. Values are the layer classes themselves, not instances.
_ND_LAYERS = {
    ('AveragePooling', 1): tf.keras.layers.AveragePooling1D,
    ('AveragePooling', 2): tf.keras.layers.AveragePooling2D,
    ('AveragePooling', 3): tf.keras.layers.AveragePooling3D,
    ('Conv', 1): tf.keras.layers.Conv1D,
    ('Conv', 2): tf.keras.layers.Conv2D,
    ('Conv', 3): tf.keras.layers.Conv3D,
    ('ConvLSTM', 1): tf.keras.layers.ConvLSTM1D,
    ('ConvLSTM', 2): tf.keras.layers.ConvLSTM2D,
    ('ConvLSTM', 3): tf.keras.layers.ConvLSTM3D,
    ('ConvTranspose', 1): tf.keras.layers.Conv1DTranspose,
    ('ConvTranspose', 2): tf.keras.layers.Conv2DTranspose,
    ('ConvTranspose', 3): tf.keras.layers.Conv3DTranspose,
    ('Cropping', 1): tf.keras.layers.Cropping1D,
    ('Cropping', 2): tf.keras.layers.Cropping2D,
    ('Cropping', 3): tf.keras.layers.Cropping3D,
    ('DepthwiseConv', 1): tf.keras.layers.DepthwiseConv1D,
    ('DepthwiseConv', 2): tf.keras.layers.DepthwiseConv2D,
    # NOTE: The DWT/IDWT entries require `signal_layers` from tensorflow_mri,
    # whose import is commented out at the top of this file.
    # ('DWT', 1): signal_layers.DWT1D,
    # ('DWT', 2): signal_layers.DWT2D,
    # ('DWT', 3): signal_layers.DWT3D,
    ('GlobalAveragePooling', 1): tf.keras.layers.GlobalAveragePooling1D,
    ('GlobalAveragePooling', 2): tf.keras.layers.GlobalAveragePooling2D,
    ('GlobalAveragePooling', 3): tf.keras.layers.GlobalAveragePooling3D,
    ('GlobalMaxPool', 1): tf.keras.layers.GlobalMaxPool1D,
    ('GlobalMaxPool', 2): tf.keras.layers.GlobalMaxPool2D,
    ('GlobalMaxPool', 3): tf.keras.layers.GlobalMaxPool3D,
    # ('IDWT', 1): signal_layers.IDWT1D,
    # ('IDWT', 2): signal_layers.IDWT2D,
    # ('IDWT', 3): signal_layers.IDWT3D,
    # `LocallyConnected1D/2D` no longer exist in Keras 3 (the default
    # `tf.keras` from TF 2.16 on). Accessing them unconditionally would raise
    # AttributeError at import time and break every lookup, so only register
    # them when the running Keras still provides them.
    **({
        ('LocallyConnected', 1): tf.keras.layers.LocallyConnected1D,
        ('LocallyConnected', 2): tf.keras.layers.LocallyConnected2D,
    } if hasattr(tf.keras.layers, 'LocallyConnected1D') else {}),
    ('MaxPool', 1): tf.keras.layers.MaxPool1D,
    ('MaxPool', 2): tf.keras.layers.MaxPool2D,
    ('MaxPool', 3): tf.keras.layers.MaxPool3D,
    ('SeparableConv', 1): tf.keras.layers.SeparableConv1D,
    ('SeparableConv', 2): tf.keras.layers.SeparableConv2D,
    ('SpatialDropout', 1): tf.keras.layers.SpatialDropout1D,
    ('SpatialDropout', 2): tf.keras.layers.SpatialDropout2D,
    ('SpatialDropout', 3): tf.keras.layers.SpatialDropout3D,
    ('UpSampling', 1): tf.keras.layers.UpSampling1D,
    ('UpSampling', 2): tf.keras.layers.UpSampling2D,
    ('UpSampling', 3): tf.keras.layers.UpSampling3D,
    ('ZeroPadding', 1): tf.keras.layers.ZeroPadding1D,
    ('ZeroPadding', 2): tf.keras.layers.ZeroPadding2D,
    ('ZeroPadding', 3): tf.keras.layers.ZeroPadding3D,
}