# Copyright 2021 University College London. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Array manipulation operations.""" import numpy as np import tensorflow as tf import tensorflow.experimental.numpy as tnp from tensorflow.python.ops.numpy_ops import np_array_ops def broadcast_static_shapes(*shapes): """Computes the shape of a broadcast given known shapes. Like `tf.broadcast_static_shape`, but accepts any number of shapes. Args: *shapes: Two or more `TensorShapes`. Returns: A `TensorShape` representing the broadcasted shape. """ bcast_shape = shapes[0] for shape in shapes[1:]: bcast_shape = tf.broadcast_static_shape(bcast_shape, shape) return bcast_shape def broadcast_dynamic_shapes(*shapes): """Computes the shape of a broadcast given symbolic shapes. Like `tf.broadcast_dynamic_shape`, but accepts any number of shapes. Args: shapes: Two or more rank-1 integer `Tensors` representing the input shapes. Returns: A rank-1 integer `Tensor` representing the broadcasted shape. """ bcast_shape = shapes[0] for shape in shapes[1:]: bcast_shape = tf.broadcast_dynamic_shape(bcast_shape, shape) return bcast_shape def cartesian_product(*args): """Cartesian product of input tensors. Args: *args: `Tensors` with rank 1. Returns: A `Tensor` of shape `[M, N]`, where `N` is the number of tensors in `args` and `M` is the product of the sizes of all the tensors in `args`. """ return tf.reshape(meshgrid(*args), [-1, len(args)]) def meshgrid(*args): """Return coordinate matrices from coordinate vectors. Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector fields over N-D grids, given one-dimensional coordinate arrays `x1, x2, ..., xn`. .. note:: Similar to `tf.meshgrid`, but uses matrix indexing and returns a stacked tensor (along axis -1) instead of a list of tensors. Args: *args: `Tensors` with rank 1. Returns: A `Tensor` of shape `[M1, M2, ..., Mn, N]`, where `N` is the number of tensors in `args` and `Mi = tf.size(args[i])`. """ return tf.stack(tf.meshgrid(*args, indexing='ij'), axis=-1) def ravel_multi_index(multi_indices, dims): """Converts an array of multi-indices into an array of flat indices. Args: multi_indices: A `Tensor` of shape `[..., N]` containing multi-indices into an `N`-dimensional tensor. dims: A `Tensor` of shape `[N]`. The shape of the tensor that `multi_indices` indexes into. Returns: A `Tensor` of shape `[...]` containing flat indices equivalent to `multi_indices`. """ strides = tf.math.cumprod(dims, exclusive=True, reverse=True) # pylint:disable=no-value-for-parameter return tf.math.reduce_sum(multi_indices * strides, axis=-1) def unravel_index(indices, dims): """Converts an array of flat indices into an array of multi-indices. Args: indices: A `Tensor` of shape `[...]` containing flat indices into an `N`-dimensional tensor. dims: A `Tensor` of shape `[N]`. The shape of the tensor that `indices` indexes into. Returns: A `Tensor` of shape `[..., N]` containing multi-indices equivalent to flat indices. """ return tf.transpose(tf.unravel_index(indices, dims)) def central_crop(tensor, shape): """Crop the central region of a tensor. Args: tensor: A `Tensor`. shape: A `Tensor`. The shape of the region to crop. The length of `shape` must be equal to or less than the rank of `tensor`. If the length of `shape` is less than the rank of tensor, the operation is applied along the last `len(shape)` dimensions of `tensor`. Any component of `shape` can be set to the special value -1 to leave the corresponding dimension unchanged. Returns: A `Tensor`. Has the same type as `tensor`. The centrally cropped tensor. Raises: ValueError: If `shape` has a rank other than 1. """ tensor = tf.convert_to_tensor(tensor) input_shape_tensor = tf.shape(tensor) target_shape_tensor = tf.convert_to_tensor(shape) # Static checks. if target_shape_tensor.shape.rank != 1: raise ValueError(f"`shape` must have rank 1. Received: {shape}") # Support a target shape with less dimensions than input. In that case, the # target shape applies to the last dimensions of input. if not isinstance(shape, tf.Tensor): shape = [-1] * (tensor.shape.rank - len(shape)) + list(shape) target_shape_tensor = tf.concat([ tf.tile([-1], [tf.rank(tensor) - tf.size(target_shape_tensor)]), target_shape_tensor], 0) # Dynamic checks. checks = [ tf.debugging.assert_greater_equal(tf.rank(tensor), tf.size(shape)), tf.debugging.assert_less_equal( target_shape_tensor, tf.shape(tensor), message=( "Target shape cannot be greater than input shape.")) ] with tf.control_dependencies(checks): tensor = tf.identity(tensor) # Crop the tensor. slice_begin = tf.where( target_shape_tensor >= 0, tf.math.maximum(input_shape_tensor - target_shape_tensor, 0) // 2, 0) slice_size = tf.where( target_shape_tensor >= 0, tf.math.minimum(input_shape_tensor, target_shape_tensor), -1) tensor = tf.slice(tensor, slice_begin, slice_size) # Set static shape, if possible. static_shape = _compute_static_output_shape(tensor.shape, shape) if static_shape is not None: tensor = tf.ensure_shape(tensor, static_shape) return tensor def resize_with_crop_or_pad(tensor, shape, padding_mode='constant'): """Crops and/or pads a tensor to a target shape. Pads symmetrically or crops centrally the input tensor as necessary to achieve the requested shape. Args: tensor: A `Tensor`. shape: A `Tensor`. The shape of the output tensor. The length of `shape` must be equal to or less than the rank of `tensor`. If the length of `shape` is less than the rank of tensor, the operation is applied along the last `len(shape)` dimensions of `tensor`. Any component of `shape` can be set to the special value -1 to leave the corresponding dimension unchanged. padding_mode: A `str`. Must be one of `'constant'`, `'reflect'` or `'symmetric'`. Returns: A `Tensor`. Has the same type as `tensor`. The symmetrically padded/cropped tensor. """ tensor = tf.convert_to_tensor(tensor) input_shape = tensor.shape input_shape_tensor = tf.shape(tensor) target_shape = shape target_shape_tensor = tf.convert_to_tensor(shape) # Support a target shape with less dimensions than input. In that case, the # target shape applies to the last dimensions of input. if not isinstance(target_shape, tf.Tensor): target_shape = [-1] * (input_shape.rank - len(shape)) + list(shape) target_shape_tensor = tf.concat([ tf.tile([-1], [tf.rank(tensor) - tf.size(shape)]), target_shape_tensor], 0) # Dynamic checks. checks = [ tf.debugging.assert_greater_equal(tf.rank(tensor), tf.size(target_shape_tensor)), ] with tf.control_dependencies(checks): tensor = tf.identity(tensor) # Pad the tensor. pad_left = tf.where( target_shape_tensor >= 0, tf.math.maximum(target_shape_tensor - input_shape_tensor, 0) // 2, 0) pad_right = tf.where( target_shape_tensor >= 0, (tf.math.maximum(target_shape_tensor - input_shape_tensor, 0) + 1) // 2, 0) tensor = tf.pad(tensor, tf.transpose(tf.stack([pad_left, pad_right])), # pylint: disable=no-value-for-parameter,unexpected-keyword-arg mode=padding_mode) # Crop the tensor. tensor = central_crop(tensor, target_shape) static_shape = _compute_static_output_shape(input_shape, target_shape) if static_shape is not None: tensor = tf.ensure_shape(tensor, static_shape) return tensor def _compute_static_output_shape(input_shape, target_shape): """Compute the static output shape of a resize operation. Args: input_shape: The static shape of the input tensor. target_shape: The target shape. Returns: The static output shape. """ output_shape = None if isinstance(target_shape, tf.Tensor): # If target shape is a tensor, we can't infer the output shape. return None # Get static tensor shape, after replacing -1 values by `None`. output_shape = tf.TensorShape( [s if s >= 0 else None for s in target_shape]) # Complete any unspecified target dimensions with those of the # input tensor, if known. output_shape = tf.TensorShape( [s_target or s_input for (s_target, s_input) in zip( output_shape.as_list(), input_shape.as_list())]) return output_shape def update_tensor(tensor, slices, value): """Updates the values of a tensor at the specified slices. This operator performs slice assignment. .. note:: Equivalent to `tensor[slices] = value`. .. warning:: TensorFlow does not support slice assignment because tensors are immutable. This operator works around this limitation by creating a new tensor, which may have performance implications. Args: tensor: A `tf.Tensor`. slices: The indices or slices. value: A `tf.Tensor`. Returns: An updated `tf.Tensor` with the same shape and type as `tensor`. """ # Using a private implementation in the TensorFlow NumPy API. # pylint: disable=protected-access return _with_index_update_helper(np_array_ops._UpdateMethod.UPDATE, tensor, slices, value) def _with_index_update_helper(update_method, a, slice_spec, updates): # pylint: disable=missing-param-doc """Implementation of ndarray._with_index_*.""" # Adapted from tensorflow/python/ops/numpy_ops/np_array_ops.py. # pylint: disable=protected-access if (isinstance(slice_spec, bool) or (isinstance(slice_spec, tf.Tensor) and slice_spec.dtype == tf.dtypes.bool) or (isinstance(slice_spec, (np.ndarray, tnp.ndarray)) and slice_spec.dtype == np.bool_)): slice_spec = tnp.nonzero(slice_spec) if not isinstance(slice_spec, tuple): slice_spec = np_array_ops._as_spec_tuple(slice_spec) return np_array_ops._slice_helper(a, slice_spec, update_method, updates) def map_fn(fn, elems, batch_dims=1, **kwargs): """Transforms `elems` by applying `fn` to each element. .. note:: Similar to `tf.map_fn`, but it supports unstacking along multiple batch dimensions. For the parameters, see `tf.map_fn`. The only difference is that there is an additional `batch_dims` keyword argument which allows specifying the number of batch dimensions. The default is 1, in which case this function is equal to `tf.map_fn`. """ # This function works by reshaping any number of batch dimensions into a # single batch dimension, calling the original `tf.map_fn`, and then # restoring the original batch dimensions. static_batch_dims = tf.get_static_value(batch_dims) # Get batch shapes. if static_batch_dims is None: # We don't know how many batch dimensions there are statically, so we can't # get the batch shape statically. static_batch_shapes = tf.nest.map_structure( lambda _: tf.TensorShape(None), elems) else: static_batch_shapes = tf.nest.map_structure( lambda x: x.shape[:static_batch_dims], elems) dynamic_batch_shapes = tf.nest.map_structure( lambda x: tf.shape(x)[:batch_dims], elems) # Flatten the batch dimensions. elems = tf.nest.map_structure( lambda x: tf.reshape( x, tf.concat([[-1], tf.shape(x)[batch_dims:]], 0)), elems) # Process each batch. output = tf.map_fn(fn, elems, **kwargs) # Unflatten the batch dimensions. output = tf.nest.map_structure( lambda x, dynamic_batch_shape: tf.reshape( x, tf.concat([dynamic_batch_shape, tf.shape(x)[1:]], 0)), output, dynamic_batch_shapes) # Set the static batch shapes on the output, if known. if static_batch_dims is not None: output = tf.nest.map_structure( lambda x, static_batch_shape: tf.ensure_shape( x, static_batch_shape.concatenate(x.shape[static_batch_dims:])), output, static_batch_shapes) return output def slice_along_axis(tensor, axis, start, length): """Slices a tensor along the specified axis.""" begin = tf.scatter_nd([[axis]], [start], [tensor.shape.rank]) size = tf.tensor_scatter_nd_update(tf.shape(tensor), [[axis]], [length]) return tf.slice(tensor, begin, size)