Spaces:
Sleeping
Sleeping
| # Copyright 2021 University College London. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Layer utilities.""" | |
| import tensorflow as tf | |
| from .array_ops import resize_with_crop_or_pad | |
def get_nd_layer(name, rank):
  """Look up the Keras layer class registered for a name and rank.

  The lookup is performed in the module-level `_ND_LAYERS` table, which is
  keyed by `(name, rank)` tuples.

  Args:
    name: A `str`. The rank-agnostic name of the requested layer, e.g.
      `'Conv'` or `'MaxPool'`.
    rank: An `int`. The rank of the requested layer.

  Returns:
    The `tf.keras.layers.Layer` subclass registered for `(name, rank)`. The
    class itself is returned, not an instance.

  Raises:
    ValueError: If no layer is registered for the requested name and rank.
  """
  key = (name, rank)
  try:
    layer = _ND_LAYERS[key]
  except KeyError as missing:
    # Re-raise as `ValueError` but keep the original `KeyError` as the cause.
    message = f"Could not find a layer with name '{name}' and rank {rank}."
    raise ValueError(message) from missing
  return layer
# Registry used by `get_nd_layer`. Maps `(name, rank)` to a Keras layer class.
# Each family below lists its classes in order of increasing rank, starting at
# rank 1; families with only two entries have no rank-3 variant registered.
_ND_LAYERS = {
    (family, rank): layer
    for family, layers_by_rank in (
        ('AveragePooling', (tf.keras.layers.AveragePooling1D,
                            tf.keras.layers.AveragePooling2D,
                            tf.keras.layers.AveragePooling3D)),
        ('Conv', (tf.keras.layers.Conv1D,
                  tf.keras.layers.Conv2D,
                  tf.keras.layers.Conv3D)),
        ('ConvLSTM', (tf.keras.layers.ConvLSTM1D,
                      tf.keras.layers.ConvLSTM2D,
                      tf.keras.layers.ConvLSTM3D)),
        ('ConvTranspose', (tf.keras.layers.Conv1DTranspose,
                           tf.keras.layers.Conv2DTranspose,
                           tf.keras.layers.Conv3DTranspose)),
        ('Cropping', (tf.keras.layers.Cropping1D,
                      tf.keras.layers.Cropping2D,
                      tf.keras.layers.Cropping3D)),
        ('DepthwiseConv', (tf.keras.layers.DepthwiseConv1D,
                           tf.keras.layers.DepthwiseConv2D)),
        ('GlobalAveragePooling', (tf.keras.layers.GlobalAveragePooling1D,
                                  tf.keras.layers.GlobalAveragePooling2D,
                                  tf.keras.layers.GlobalAveragePooling3D)),
        ('GlobalMaxPool', (tf.keras.layers.GlobalMaxPool1D,
                           tf.keras.layers.GlobalMaxPool2D,
                           tf.keras.layers.GlobalMaxPool3D)),
        ('MaxPool', (tf.keras.layers.MaxPool1D,
                     tf.keras.layers.MaxPool2D,
                     tf.keras.layers.MaxPool3D)),
        ('SeparableConv', (tf.keras.layers.SeparableConv1D,
                           tf.keras.layers.SeparableConv2D)),
        ('SpatialDropout', (tf.keras.layers.SpatialDropout1D,
                            tf.keras.layers.SpatialDropout2D,
                            tf.keras.layers.SpatialDropout3D)),
        ('UpSampling', (tf.keras.layers.UpSampling1D,
                        tf.keras.layers.UpSampling2D,
                        tf.keras.layers.UpSampling3D)),
        ('ZeroPadding', (tf.keras.layers.ZeroPadding1D,
                         tf.keras.layers.ZeroPadding2D,
                         tf.keras.layers.ZeroPadding3D)),
    )
    for rank, layer in enumerate(layers_by_rank, start=1)
}
class ResizeAndConcatenate(tf.keras.layers.Layer):
  """Resizes and concatenates a list of inputs.

  Similar to `tf.keras.layers.Concatenate`, but if the inputs have different
  shapes, they are resized to match the shape of the first input.

  The first input is the reference. Every other input is passed through
  `resize_with_crop_or_pad` with the dynamic shape of the reference, in which
  the entry for the concatenation axis has been replaced by -1 (presumably
  meaning "leave this dimension unchanged" - confirm against
  `resize_with_crop_or_pad`).

  Args:
    axis: Axis along which to concatenate. May be negative. Defaults to -1.
    **kwargs: Keyword arguments forwarded to `tf.keras.layers.Layer`.
  """
  def __init__(self, axis=-1, **kwargs):
    super().__init__(**kwargs)
    self.axis = axis

  def get_config(self):
    """Returns the base layer configuration extended with `axis`."""
    config = super().get_config()
    config.update({
        "axis": self.axis,
    })
    return config

  def call(self, inputs):
    """Resizes all inputs to the first input's shape and concatenates them.

    Args:
      inputs: A non-empty `list` or `tuple` of tensors. The first tensor must
        have a known rank.

    Returns:
      The concatenation along `axis` of the first input and the resized
      remaining inputs.

    Raises:
      ValueError: If `inputs` is not a list or tuple, is empty, if the rank of
        the first input is unknown, or if `axis` is out of range for that
        rank.
    """
    if not isinstance(inputs, (list, tuple)):
      raise ValueError(
          f"Layer {self.__class__.__name__} expects a list of inputs. "
          f"Received: {inputs}")
    if not inputs:
      # Without this check an empty sequence would fail with a bare
      # `IndexError` on `inputs[0]` below.
      raise ValueError(
          f"Layer {self.__class__.__name__} expects a non-empty list of "
          f"inputs. Received: {inputs}")

    # Normalize to a list: tuples are accepted above, but a tuple cannot be
    # added to the `resized` list when assembling the arguments to `tf.concat`.
    inputs = list(inputs)
    reference = inputs[0]

    rank = reference.shape.rank
    if rank is None:
      raise ValueError(
          f"Layer {self.__class__.__name__} expects inputs with known rank. "
          f"Received: {inputs}")
    if self.axis >= rank or self.axis < -rank:
      raise ValueError(
          f"Layer {self.__class__.__name__} expects `axis` to be in the range "
          f"[-{rank}, {rank}) for an input of rank {rank}. "
          f"Received: {self.axis}")

    # Canonical axis (always positive).
    axis = self.axis % rank

    # Resize inputs. The target is the dynamic shape of the reference with the
    # concatenation axis replaced by -1.
    shape = tf.tensor_scatter_nd_update(tf.shape(reference), [[axis]], [-1])
    resized = [resize_with_crop_or_pad(tensor, shape)
               for tensor in inputs[1:]]

    # Set the static shape for each resized tensor: the reference's static
    # shape everywhere except along the concatenation axis, where each tensor
    # keeps its own static size.
    reference_static_shape = reference.shape.as_list()
    for i, (original, tensor) in enumerate(zip(inputs[1:], resized)):
      static_shape = list(reference_static_shape)
      static_shape[axis] = original.shape.as_list()[axis]
      resized[i] = tf.ensure_shape(tensor, tf.TensorShape(static_shape))

    return tf.concat([reference] + resized, axis=self.axis)  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter