File size: 5,221 Bytes
b03d3b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Copyright 2021 University College London. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer utilities."""
import tensorflow as tf
from .array_ops import resize_with_crop_or_pad
def get_nd_layer(name, rank):
  """Looks up the Keras layer class registered for a layer family and rank.

  Args:
    name: A `str`. The layer family, e.g. `'Conv'` or `'MaxPool'`.
    rank: An `int`. The rank of the requested layer, e.g. `2` selects the
      2-D variant of the family.

  Returns:
    The `tf.keras.layers.Layer` subclass registered for `(name, rank)`. The
    class is returned as-is, not instantiated.

  Raises:
    ValueError: If no layer is registered for the pair `(name, rank)`, i.e.
      the requested layer is unknown to TFMRI.
  """
  lookup_key = (name, rank)
  try:
    layer_class = _ND_LAYERS[lookup_key]
  except KeyError as err:
    # Re-raise as `ValueError` while keeping the original `KeyError` as the
    # cause for debugging.
    message = f"Could not find a layer with name '{name}' and rank {rank}."
    raise ValueError(message) from err
  return layer_class
# Registry used by `get_nd_layer`, mapping `(name, rank)` pairs to Keras layer
# classes (the classes themselves, not instances). `name` is the Keras class
# name without its rank suffix, and `rank` selects the 1-D / 2-D / 3-D variant.
# Pairs missing from this table (e.g. rank 3 for 'DepthwiseConv' and
# 'SeparableConv') make `get_nd_layer` raise `ValueError`.
_ND_LAYERS = {
    ('AveragePooling', 1): tf.keras.layers.AveragePooling1D,
    ('AveragePooling', 2): tf.keras.layers.AveragePooling2D,
    ('AveragePooling', 3): tf.keras.layers.AveragePooling3D,
    ('Conv', 1): tf.keras.layers.Conv1D,
    ('Conv', 2): tf.keras.layers.Conv2D,
    ('Conv', 3): tf.keras.layers.Conv3D,
    ('ConvLSTM', 1): tf.keras.layers.ConvLSTM1D,
    ('ConvLSTM', 2): tf.keras.layers.ConvLSTM2D,
    ('ConvLSTM', 3): tf.keras.layers.ConvLSTM3D,
    # The Keras class names put the rank before "Transpose", so these cannot
    # be derived by simply appending "<rank>D" to the family name.
    ('ConvTranspose', 1): tf.keras.layers.Conv1DTranspose,
    ('ConvTranspose', 2): tf.keras.layers.Conv2DTranspose,
    ('ConvTranspose', 3): tf.keras.layers.Conv3DTranspose,
    ('Cropping', 1): tf.keras.layers.Cropping1D,
    ('Cropping', 2): tf.keras.layers.Cropping2D,
    ('Cropping', 3): tf.keras.layers.Cropping3D,
    # Only ranks 1 and 2 are registered for depthwise convolution.
    ('DepthwiseConv', 1): tf.keras.layers.DepthwiseConv1D,
    ('DepthwiseConv', 2): tf.keras.layers.DepthwiseConv2D,
    ('GlobalAveragePooling', 1): tf.keras.layers.GlobalAveragePooling1D,
    ('GlobalAveragePooling', 2): tf.keras.layers.GlobalAveragePooling2D,
    ('GlobalAveragePooling', 3): tf.keras.layers.GlobalAveragePooling3D,
    ('GlobalMaxPool', 1): tf.keras.layers.GlobalMaxPool1D,
    ('GlobalMaxPool', 2): tf.keras.layers.GlobalMaxPool2D,
    ('GlobalMaxPool', 3): tf.keras.layers.GlobalMaxPool3D,
    ('MaxPool', 1): tf.keras.layers.MaxPool1D,
    ('MaxPool', 2): tf.keras.layers.MaxPool2D,
    ('MaxPool', 3): tf.keras.layers.MaxPool3D,
    # Only ranks 1 and 2 are registered for separable convolution.
    ('SeparableConv', 1): tf.keras.layers.SeparableConv1D,
    ('SeparableConv', 2): tf.keras.layers.SeparableConv2D,
    ('SpatialDropout', 1): tf.keras.layers.SpatialDropout1D,
    ('SpatialDropout', 2): tf.keras.layers.SpatialDropout2D,
    ('SpatialDropout', 3): tf.keras.layers.SpatialDropout3D,
    ('UpSampling', 1): tf.keras.layers.UpSampling1D,
    ('UpSampling', 2): tf.keras.layers.UpSampling2D,
    ('UpSampling', 3): tf.keras.layers.UpSampling3D,
    ('ZeroPadding', 1): tf.keras.layers.ZeroPadding1D,
    ('ZeroPadding', 2): tf.keras.layers.ZeroPadding2D,
    ('ZeroPadding', 3): tf.keras.layers.ZeroPadding3D
}
class ResizeAndConcatenate(tf.keras.layers.Layer):
  """Resizes and concatenates a list of inputs.

  Similar to `tf.keras.layers.Concatenate`, but if the inputs have different
  shapes, every input after the first is resized with
  `resize_with_crop_or_pad` so that its dimensions match those of the first
  input. The exception is the concatenation axis, along which each input
  keeps its own size.

  Args:
    axis: Axis along which to concatenate. Must be in the range
      `[-rank, rank)`, where `rank` is the rank of the first input.
    **kwargs: Additional keyword arguments passed to `tf.keras.layers.Layer`.
  """
  def __init__(self, axis=-1, **kwargs):
    super().__init__(**kwargs)
    self.axis = axis

  def get_config(self):
    """Returns the base layer config extended with `axis`."""
    config = super().get_config()
    config.update({
        "axis": self.axis,
    })
    return config

  def call(self, inputs):
    """Resizes `inputs[1:]` to match `inputs[0]` and concatenates all inputs.

    Args:
      inputs: A non-empty `list` or `tuple` of tensors. The first tensor must
        have a known rank.

    Returns:
      The concatenation of the first input and the resized remaining inputs
      along `axis`.

    Raises:
      ValueError: If `inputs` is not a list or tuple, is empty, if the first
        input has unknown rank, or if `axis` is out of range for that rank.
    """
    if not isinstance(inputs, (list, tuple)):
      raise ValueError(
          f"Layer {self.__class__.__name__} expects a list of inputs. "
          f"Received: {inputs}")
    if not inputs:
      # Without this check, `inputs[0]` below would fail with an opaque
      # `IndexError`.
      raise ValueError(
          f"Layer {self.__class__.__name__} expects at least one input. "
          f"Received: {inputs}")

    reference = inputs[0]
    rank = reference.shape.rank
    if rank is None:
      raise ValueError(
          f"Layer {self.__class__.__name__} expects inputs with known rank. "
          f"Received: {inputs}")
    if self.axis >= rank or self.axis < -rank:
      raise ValueError(
          f"Layer {self.__class__.__name__} expects `axis` to be in the range "
          f"[-{rank}, {rank}) for an input of rank {rank}. "
          f"Received: {self.axis}")
    # Canonical axis (always non-negative).
    axis = self.axis % rank

    # Target shape: the dynamic shape of the first input, with -1 along the
    # concatenation axis. NOTE(review): this assumes `resize_with_crop_or_pad`
    # treats -1 as "leave this dimension unchanged", which is consistent with
    # the static shape restored below.
    target_shape = tf.tensor_scatter_nd_update(
        tf.shape(reference), [[axis]], [-1])
    # Loop-invariant; copied per tensor below because it is modified in place.
    reference_static_shape = reference.shape.as_list()

    resized = []
    for tensor in inputs[1:]:
      output = resize_with_crop_or_pad(tensor, target_shape)
      # The target shape is a tensor, so the static shape of `output` may not
      # be inferable. Restore it: it matches the first input everywhere except
      # along `axis`, where the tensor keeps its own size.
      static_shape = list(reference_static_shape)
      static_shape[axis] = tensor.shape.as_list()[axis]
      resized.append(tf.ensure_shape(output, tf.TensorShape(static_shape)))

    # Build a list explicitly. `inputs[:1] + resized` raises `TypeError` when
    # `inputs` is a tuple, because a tuple cannot be concatenated with a list.
    return tf.concat([reference] + resized, axis=self.axis) # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
|