# Copyright 2021 University College London. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer utilities."""

import tensorflow as tf

from .array_ops import resize_with_crop_or_pad


def get_nd_layer(name, rank):
  """Get an N-D layer class.

  Looks up the Keras layer class registered under `(name, rank)` in the
  module-level `_ND_LAYERS` table.

  Args:
    name: A `str`. The name of the requested layer, without the rank suffix
      (e.g., `'Conv'` or `'MaxPool'`).
    rank: An `int`. The rank of the requested layer (e.g., `2` selects the
      `*2D` variant).

  Returns:
    A `tf.keras.layers.Layer` subclass (the class itself, not an instance).

  Raises:
    ValueError: If the requested layer is unknown to TFMRI.
  """
  try:
    return _ND_LAYERS[(name, rank)]
  except KeyError as err:
    raise ValueError(
        f"Could not find a layer with name '{name}' and rank {rank}.") from err


# Maps `(name, rank)` to the corresponding Keras layer class. Combinations that
# Keras does not provide (e.g. a 3D depthwise or separable convolution) are
# intentionally absent, so `get_nd_layer` raises `ValueError` for them.
_ND_LAYERS = {
    ('AveragePooling', 1): tf.keras.layers.AveragePooling1D,
    ('AveragePooling', 2): tf.keras.layers.AveragePooling2D,
    ('AveragePooling', 3): tf.keras.layers.AveragePooling3D,
    ('Conv', 1): tf.keras.layers.Conv1D,
    ('Conv', 2): tf.keras.layers.Conv2D,
    ('Conv', 3): tf.keras.layers.Conv3D,
    ('ConvLSTM', 1): tf.keras.layers.ConvLSTM1D,
    ('ConvLSTM', 2): tf.keras.layers.ConvLSTM2D,
    ('ConvLSTM', 3): tf.keras.layers.ConvLSTM3D,
    ('ConvTranspose', 1): tf.keras.layers.Conv1DTranspose,
    ('ConvTranspose', 2): tf.keras.layers.Conv2DTranspose,
    ('ConvTranspose', 3): tf.keras.layers.Conv3DTranspose,
    ('Cropping', 1): tf.keras.layers.Cropping1D,
    ('Cropping', 2): tf.keras.layers.Cropping2D,
    ('Cropping', 3): tf.keras.layers.Cropping3D,
    ('DepthwiseConv', 1): tf.keras.layers.DepthwiseConv1D,
    ('DepthwiseConv', 2): tf.keras.layers.DepthwiseConv2D,
    ('GlobalAveragePooling', 1): tf.keras.layers.GlobalAveragePooling1D,
    ('GlobalAveragePooling', 2): tf.keras.layers.GlobalAveragePooling2D,
    ('GlobalAveragePooling', 3): tf.keras.layers.GlobalAveragePooling3D,
    ('GlobalMaxPool', 1): tf.keras.layers.GlobalMaxPool1D,
    ('GlobalMaxPool', 2): tf.keras.layers.GlobalMaxPool2D,
    ('GlobalMaxPool', 3): tf.keras.layers.GlobalMaxPool3D,
    ('MaxPool', 1): tf.keras.layers.MaxPool1D,
    ('MaxPool', 2): tf.keras.layers.MaxPool2D,
    ('MaxPool', 3): tf.keras.layers.MaxPool3D,
    ('SeparableConv', 1): tf.keras.layers.SeparableConv1D,
    ('SeparableConv', 2): tf.keras.layers.SeparableConv2D,
    ('SpatialDropout', 1): tf.keras.layers.SpatialDropout1D,
    ('SpatialDropout', 2): tf.keras.layers.SpatialDropout2D,
    ('SpatialDropout', 3): tf.keras.layers.SpatialDropout3D,
    ('UpSampling', 1): tf.keras.layers.UpSampling1D,
    ('UpSampling', 2): tf.keras.layers.UpSampling2D,
    ('UpSampling', 3): tf.keras.layers.UpSampling3D,
    ('ZeroPadding', 1): tf.keras.layers.ZeroPadding1D,
    ('ZeroPadding', 2): tf.keras.layers.ZeroPadding2D,
    ('ZeroPadding', 3): tf.keras.layers.ZeroPadding3D
}


class ResizeAndConcatenate(tf.keras.layers.Layer):
  """Resizes and concatenates a list of inputs.

  Similar to `tf.keras.layers.Concatenate`, but if the inputs have different
  shapes, they are resized to match the shape of the first input. The size of
  each input along the concatenation axis is left unchanged.

  Args:
    axis: Axis along which to concatenate.
  """
  def __init__(self, axis=-1, **kwargs):
    super().__init__(**kwargs)
    self.axis = axis

  def get_config(self):
    """Returns the layer config, extended with `axis` for serialization."""
    config = super().get_config()
    config.update({
        "axis": self.axis,
    })
    return config

  def call(self, inputs):
    """Resizes `inputs[1:]` to match `inputs[0]`, then concatenates them all.

    Args:
      inputs: A non-empty `list` or `tuple` of tensors. All tensors must have
        the same, statically known rank.

    Returns:
      The concatenation of `inputs[0]` and the resized remaining inputs along
      `axis`.

    Raises:
      ValueError: If `inputs` is not a non-empty list or tuple, if the inputs
        do not all have the same known rank, or if `axis` is out of range.
    """
    if not isinstance(inputs, (list, tuple)):
      raise ValueError(
          f"Layer {self.__class__.__name__} expects a list of inputs. "
          f"Received: {inputs}")
    # Fail clearly instead of with an `IndexError` on `inputs[0]` below.
    if not inputs:
      raise ValueError(
          f"Layer {self.__class__.__name__} expects at least one input. "
          f"Received: {inputs}")

    rank = inputs[0].shape.rank
    if rank is None:
      raise ValueError(
          f"Layer {self.__class__.__name__} expects inputs with known rank. "
          f"Received: {inputs}")
    # The static shape fix-up and `tf.concat` below both require every input to
    # share the rank of the first one; check up front for a clear error.
    for index, tensor in enumerate(inputs[1:], start=1):
      if tensor.shape.rank != rank:
        raise ValueError(
            f"Layer {self.__class__.__name__} expects all inputs to have the "
            f"same known rank. Input 0 has rank {rank}, but input {index} has "
            f"rank {tensor.shape.rank}.")

    if self.axis >= rank or self.axis < -rank:
      raise ValueError(
          f"Layer {self.__class__.__name__} expects `axis` to be in the range "
          f"[-{rank}, {rank}) for an input of rank {rank}. "
          f"Received: {self.axis}")
    # Canonical axis (always non-negative).
    axis = self.axis % rank

    # Target shape: the dynamic shape of the first input, with -1 along the
    # concatenation axis so each tensor keeps its own size there (matching the
    # static shape restored below).
    shape = tf.tensor_scatter_nd_update(tf.shape(inputs[0]), [[axis]], [-1])
    resized = [resize_with_crop_or_pad(tensor, shape) for tensor in inputs[1:]]

    # The target shape is a tensor, so static shape information may not be
    # inferable for the resized tensors. Restore it: the first input's static
    # shape everywhere, except the input's own size along the concat axis.
    first_static_shape = inputs[0].shape.as_list()
    for i, (source, tensor) in enumerate(zip(inputs[1:], resized)):
      static_shape = list(first_static_shape)
      static_shape[axis] = source.shape.as_list()[axis]
      resized[i] = tf.ensure_shape(tensor, tf.TensorShape(static_shape))

    # Build a list explicitly: `inputs` may be a tuple, and `tuple + list`
    # would raise `TypeError`.
    return tf.concat([inputs[0], *resized], axis=self.axis) # pylint: disable=unexpected-keyword-arg,no-value-for-parameter