# Copyright 2021 University College London. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer utilities."""

import tensorflow as tf

# from tensorflow_mri.python.layers import convolutional
# from tensorflow_mri.python.layers import signal_layers


def get_nd_layer(name, rank):
  """Get an N-D layer class.

  Looks up the Keras layer class registered under a base name and a rank,
  e.g. `get_nd_layer('Conv', 2)` returns `tf.keras.layers.Conv2D`.

  Args:
    name: A `str`. The base name of the requested layer, without the rank
      suffix (e.g. `'Conv'`, `'MaxPool'`, `'UpSampling'`).
    rank: An `int`. The rank of the requested layer.

  Returns:
    A subclass of `tf.keras.layers.Layer`. This is the class itself, not an
    instance; call it to construct a layer.

  Raises:
    ValueError: If the requested layer is unknown to TFMRI, including when
      `name` or `rank` is unhashable and therefore cannot match any entry.
  """
  try:
    return _ND_LAYERS[(name, rank)]
  except (KeyError, TypeError) as err:
    # KeyError: no entry for this (name, rank) pair.
    # TypeError: `name` or `rank` is unhashable (e.g. a list), so building
    # the lookup key fails; report it the same way as an unknown layer.
    raise ValueError(
        f"Could not find a layer with name '{name}' and rank {rank}.") from err


# Registry mapping (base name, rank) -> Keras layer class. Looked up lazily by
# `get_nd_layer`, so it may be defined after the function.
_ND_LAYERS = {
    ('AveragePooling', 1): tf.keras.layers.AveragePooling1D,
    ('AveragePooling', 2): tf.keras.layers.AveragePooling2D,
    ('AveragePooling', 3): tf.keras.layers.AveragePooling3D,
    ('Conv', 1): tf.keras.layers.Conv1D,
    ('Conv', 2): tf.keras.layers.Conv2D,
    ('Conv', 3): tf.keras.layers.Conv3D,
    ('ConvLSTM', 1): tf.keras.layers.ConvLSTM1D,
    ('ConvLSTM', 2): tf.keras.layers.ConvLSTM2D,
    ('ConvLSTM', 3): tf.keras.layers.ConvLSTM3D,
    ('ConvTranspose', 1): tf.keras.layers.Conv1DTranspose,
    ('ConvTranspose', 2): tf.keras.layers.Conv2DTranspose,
    ('ConvTranspose', 3): tf.keras.layers.Conv3DTranspose,
    ('Cropping', 1): tf.keras.layers.Cropping1D,
    ('Cropping', 2): tf.keras.layers.Cropping2D,
    ('Cropping', 3): tf.keras.layers.Cropping3D,
    ('DepthwiseConv', 1): tf.keras.layers.DepthwiseConv1D,
    ('DepthwiseConv', 2): tf.keras.layers.DepthwiseConv2D,
    # NOTE(review): the DWT/IDWT entries below are disabled together with the
    # `signal_layers` import at the top of the file; re-enable both together.
    # ('DWT', 1): signal_layers.DWT1D,
    # ('DWT', 2): signal_layers.DWT2D,
    # ('DWT', 3): signal_layers.DWT3D,
    ('GlobalAveragePooling', 1): tf.keras.layers.GlobalAveragePooling1D,
    ('GlobalAveragePooling', 2): tf.keras.layers.GlobalAveragePooling2D,
    ('GlobalAveragePooling', 3): tf.keras.layers.GlobalAveragePooling3D,
    ('GlobalMaxPool', 1): tf.keras.layers.GlobalMaxPool1D,
    ('GlobalMaxPool', 2): tf.keras.layers.GlobalMaxPool2D,
    ('GlobalMaxPool', 3): tf.keras.layers.GlobalMaxPool3D,
    # ('IDWT', 1): signal_layers.IDWT1D,
    # ('IDWT', 2): signal_layers.IDWT2D,
    # ('IDWT', 3): signal_layers.IDWT3D,
    ('LocallyConnected', 1): tf.keras.layers.LocallyConnected1D,
    ('LocallyConnected', 2): tf.keras.layers.LocallyConnected2D,
    ('MaxPool', 1): tf.keras.layers.MaxPool1D,
    ('MaxPool', 2): tf.keras.layers.MaxPool2D,
    ('MaxPool', 3): tf.keras.layers.MaxPool3D,
    ('SeparableConv', 1): tf.keras.layers.SeparableConv1D,
    ('SeparableConv', 2): tf.keras.layers.SeparableConv2D,
    ('SpatialDropout', 1): tf.keras.layers.SpatialDropout1D,
    ('SpatialDropout', 2): tf.keras.layers.SpatialDropout2D,
    ('SpatialDropout', 3): tf.keras.layers.SpatialDropout3D,
    ('UpSampling', 1): tf.keras.layers.UpSampling1D,
    ('UpSampling', 2): tf.keras.layers.UpSampling2D,
    ('UpSampling', 3): tf.keras.layers.UpSampling3D,
    ('ZeroPadding', 1): tf.keras.layers.ZeroPadding1D,
    ('ZeroPadding', 2): tf.keras.layers.ZeroPadding2D,
    ('ZeroPadding', 3): tf.keras.layers.ZeroPadding3D,
}