Spaces:
Sleeping
Sleeping
| # Copyright 2021 University College London. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Array manipulation operations.""" | |
| import numpy as np | |
| import tensorflow as tf | |
| import tensorflow.experimental.numpy as tnp | |
| from tensorflow.python.ops.numpy_ops import np_array_ops | |
def broadcast_static_shapes(*shapes):
  """Computes the shape of a broadcast given known shapes.

  Like `tf.broadcast_static_shape`, but accepts any number of shapes.

  Args:
    *shapes: Zero or more `TensorShapes`.

  Returns:
    A `TensorShape` representing the broadcasted shape. If no shapes are
    given, the scalar shape `[]` is returned, since it is the identity element
    of broadcasting. If a single shape is given, it is returned unchanged.
  """
  if not shapes:
    # Broadcasting nothing yields a scalar, mirroring `np.broadcast_shapes()`.
    return tf.TensorShape([])

  # Fold the pairwise broadcast over all shapes, from left to right.
  bcast_shape = shapes[0]
  for shape in shapes[1:]:
    bcast_shape = tf.broadcast_static_shape(bcast_shape, shape)
  return bcast_shape
def broadcast_dynamic_shapes(*shapes):
  """Computes the shape of a broadcast given symbolic shapes.

  Like `tf.broadcast_dynamic_shape`, but accepts any number of shapes.

  Args:
    *shapes: Zero or more rank-1 integer `Tensors` representing the input
      shapes.

  Returns:
    A rank-1 integer `Tensor` representing the broadcasted shape. If no shapes
    are given, an empty `int32` tensor (the shape of a scalar) is returned,
    since it is the identity element of broadcasting. If a single shape is
    given, it is returned unchanged.
  """
  if not shapes:
    # Broadcasting nothing yields a scalar, mirroring `np.broadcast_shapes()`.
    return tf.zeros([0], dtype=tf.int32)

  # Fold the pairwise broadcast over all shapes, from left to right.
  bcast_shape = shapes[0]
  for shape in shapes[1:]:
    bcast_shape = tf.broadcast_dynamic_shape(bcast_shape, shape)
  return bcast_shape
def cartesian_product(*args):
  """Builds the Cartesian product of the input tensors.

  Args:
    *args: `Tensors` with rank 1.

  Returns:
    A `Tensor` of shape `[M, N]`, where `N` is the number of tensors in `args`
    and `M` is the product of the sizes of all the tensors in `args`. Each row
    holds one combination of elements, one taken from each input.
  """
  # The stacked grid has shape `[M1, ..., Mn, N]`; collapsing every grid
  # dimension leaves one row per combination.
  grid = meshgrid(*args)
  num_inputs = len(args)
  return tf.reshape(grid, [-1, num_inputs])
def meshgrid(*args):
  """Returns a stacked coordinate grid built from coordinate vectors.

  Makes N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
  fields over N-D grids, given one-dimensional coordinate arrays
  `x1, x2, ..., xn`.

  .. note::
    Similar to `tf.meshgrid`, but uses matrix indexing and returns a stacked
    tensor (along axis -1) instead of a list of tensors.

  Args:
    *args: `Tensors` with rank 1.

  Returns:
    A `Tensor` of shape `[M1, M2, ..., Mn, N]`, where `N` is the number of
    tensors in `args` and `Mi = tf.size(args[i])`.
  """
  # One coordinate tensor per input, using matrix ('ij') indexing.
  coords = tf.meshgrid(*args, indexing='ij')
  # Put the N coordinates of each grid point along a new trailing axis.
  return tf.stack(coords, axis=-1)
def ravel_multi_index(multi_indices, dims):
  """Flattens multi-indices into flat (row-major) indices.

  Args:
    multi_indices: A `Tensor` of shape `[..., N]` containing multi-indices into
      an `N`-dimensional tensor.
    dims: A `Tensor` of shape `[N]`. The shape of the tensor that
      `multi_indices` indexes into.

  Returns:
    A `Tensor` of shape `[...]` containing flat indices equivalent to
    `multi_indices`.
  """
  # The stride of each dimension is the product of the sizes of all the
  # dimensions that follow it (exclusive, reversed cumulative product).
  dim_strides = tf.math.cumprod(dims, exclusive=True, reverse=True)  # pylint:disable=no-value-for-parameter
  # The flat index is the stride-weighted sum of the index components.
  weighted = multi_indices * dim_strides
  return tf.math.reduce_sum(weighted, axis=-1)
def unravel_index(indices, dims):
  """Converts an array of flat indices into an array of multi-indices.

  Unlike `tf.unravel_index`, which only accepts 0-D or 1-D `indices`, this
  function accepts `indices` of any rank and places the multi-index components
  along the last axis of the output.

  Args:
    indices: A `Tensor` of shape `[...]` containing flat indices into an
      `N`-dimensional tensor.
    dims: A `Tensor` of shape `[N]`. The shape of the tensor that
      `indices` indexes into.

  Returns:
    A `Tensor` of shape `[..., N]` containing multi-indices equivalent to flat
    indices.
  """
  indices = tf.convert_to_tensor(indices)
  # `tf.unravel_index` only supports scalar or vector indices, so collapse all
  # the leading dimensions into one before calling it.
  flat_indices = tf.reshape(indices, [-1])
  # Shape `[N, M]` -> `[M, N]`, with `M` the total number of indices.
  multi_indices = tf.transpose(tf.unravel_index(flat_indices, dims))
  # Restore the original leading dimensions. The size of the last axis is
  # given explicitly (rather than as -1) so that empty inputs reshape cleanly.
  output_shape = tf.concat([tf.shape(indices), tf.shape(dims)], 0)
  return tf.reshape(multi_indices, output_shape)
def central_crop(tensor, shape):
  """Crop the central region of a tensor.

  Args:
    tensor: A `Tensor`.
    shape: A `Tensor`. The shape of the region to crop. The length of `shape`
      must be equal to or less than the rank of `tensor`. If the length of
      `shape` is less than the rank of tensor, the operation is applied along
      the last `len(shape)` dimensions of `tensor`. Any component of `shape` can
      be set to the special value -1 to leave the corresponding dimension
      unchanged.

  Returns:
    A `Tensor`. Has the same type as `tensor`. The centrally cropped tensor.

  Raises:
    ValueError: If `shape` has a rank other than 1.
  """
  tensor = tf.convert_to_tensor(tensor)
  input_shape_tensor = tf.shape(tensor)
  target_shape_tensor = tf.convert_to_tensor(shape)

  # Static checks.
  if target_shape_tensor.shape.rank != 1:
    raise ValueError(f"`shape` must have rank 1. Received: {shape}")

  # Dynamic check: `shape` cannot have more entries than `tensor` has
  # dimensions. This must be evaluated against the length of `shape` as given
  # by the caller, i.e. before it is padded up to the rank of `tensor` below.
  rank_check = tf.debugging.assert_greater_equal(
      tf.rank(tensor), tf.size(target_shape_tensor), message=(
          "The length of `shape` cannot be greater than the rank of "
          "`tensor`."))

  # Support a target shape with less dimensions than input. In that case, the
  # target shape applies to the last dimensions of input, and the leading
  # dimensions are marked with -1 (unchanged).
  if not isinstance(shape, tf.Tensor):
    shape = [-1] * (tensor.shape.rank - len(shape)) + list(shape)
  with tf.control_dependencies([rank_check]):
    target_shape_tensor = tf.concat([
        tf.tile([-1], [tf.rank(tensor) - tf.size(target_shape_tensor)]),
        target_shape_tensor], 0)

  # Dynamic check: the crop region must fit within the input. Entries equal to
  # -1 pass this check trivially.
  size_check = tf.debugging.assert_less_equal(
      target_shape_tensor, input_shape_tensor, message=(
          "Target shape cannot be greater than input shape."))
  with tf.control_dependencies([size_check]):
    tensor = tf.identity(tensor)

  # Crop the tensor. Along dimensions to be cropped, the slice starts at half
  # of the excess size; along unchanged dimensions (-1), the slice starts at 0
  # and has size -1, which `tf.slice` interprets as "all remaining elements".
  slice_begin = tf.where(
      target_shape_tensor >= 0,
      tf.math.maximum(input_shape_tensor - target_shape_tensor, 0) // 2,
      0)
  slice_size = tf.where(
      target_shape_tensor >= 0,
      tf.math.minimum(input_shape_tensor, target_shape_tensor),
      -1)
  tensor = tf.slice(tensor, slice_begin, slice_size)

  # Set static shape, if possible.
  static_shape = _compute_static_output_shape(tensor.shape, shape)
  if static_shape is not None:
    tensor = tf.ensure_shape(tensor, static_shape)

  return tensor
def resize_with_crop_or_pad(tensor, shape, padding_mode='constant'):
  """Crops and/or pads a tensor to a target shape.

  Pads symmetrically or crops centrally the input tensor as necessary to achieve
  the requested shape.

  Args:
    tensor: A `Tensor`.
    shape: A `Tensor`. The shape of the output tensor. The length of `shape`
      must be equal to or less than the rank of `tensor`. If the length of
      `shape` is less than the rank of tensor, the operation is applied along
      the last `len(shape)` dimensions of `tensor`. Any component of `shape` can
      be set to the special value -1 to leave the corresponding dimension
      unchanged.
    padding_mode: A `str`. Must be one of `'constant'`, `'reflect'` or
      `'symmetric'`.

  Returns:
    A `Tensor`. Has the same type as `tensor`. The symmetrically padded/cropped
    tensor.
  """
  tensor = tf.convert_to_tensor(tensor)
  input_shape = tensor.shape
  input_shape_tensor = tf.shape(tensor)
  target_shape = shape
  target_shape_tensor = tf.convert_to_tensor(shape)

  # Dynamic check: `shape` cannot have more entries than `tensor` has
  # dimensions. This must be evaluated against the length of `shape` as given
  # by the caller, i.e. before it is padded up to the rank of `tensor` below.
  rank_check = tf.debugging.assert_greater_equal(
      tf.rank(tensor), tf.size(target_shape_tensor), message=(
          "The length of `shape` cannot be greater than the rank of "
          "`tensor`."))

  # Support a target shape with less dimensions than input. In that case, the
  # target shape applies to the last dimensions of input, and the leading
  # dimensions are marked with -1 (unchanged).
  if not isinstance(target_shape, tf.Tensor):
    target_shape = [-1] * (input_shape.rank - len(shape)) + list(shape)
  # Every op below consumes the padded target shape, so attaching the check
  # here makes all of them depend on it.
  with tf.control_dependencies([rank_check]):
    target_shape_tensor = tf.concat([
        tf.tile([-1], [tf.rank(tensor) - tf.size(target_shape_tensor)]),
        target_shape_tensor], 0)

  # Pad the tensor. The total padding along each dimension is the shortfall
  # with respect to the target size (0 if none, or if the dimension is marked
  # as unchanged). When the shortfall is odd, the extra element goes to the
  # right.
  pad_left = tf.where(
      target_shape_tensor >= 0,
      tf.math.maximum(target_shape_tensor - input_shape_tensor, 0) // 2,
      0)
  pad_right = tf.where(
      target_shape_tensor >= 0,
      (tf.math.maximum(target_shape_tensor - input_shape_tensor, 0) + 1) // 2,
      0)
  paddings = tf.transpose(tf.stack([pad_left, pad_right]))  # pylint: disable=no-value-for-parameter
  tensor = tf.pad(tensor, paddings, mode=padding_mode)  # pylint: disable=no-value-for-parameter,unexpected-keyword-arg

  # Crop the tensor. Dimensions that were padded already have the target size,
  # so only the dimensions that exceed it are affected.
  tensor = central_crop(tensor, target_shape)

  # Set static shape, if possible.
  static_shape = _compute_static_output_shape(input_shape, target_shape)
  if static_shape is not None:
    tensor = tf.ensure_shape(tensor, static_shape)

  return tensor
def _compute_static_output_shape(input_shape, target_shape):
  """Compute the static output shape of a resize operation.

  Args:
    input_shape: The static shape of the input tensor (a `tf.TensorShape`).
    target_shape: The target shape. Either a `tf.Tensor` or a sequence of
      integers, where negative values mark dimensions that are left unchanged.

  Returns:
    The static output shape as a `tf.TensorShape`, or `None` if `target_shape`
    is a `tf.Tensor` (in which case the output shape cannot be inferred
    statically).
  """
  if isinstance(target_shape, tf.Tensor):
    # If target shape is a tensor, we can't infer the output shape.
    return None

  # Get static target dimensions, after replacing -1 values by `None`.
  target_dims = tf.TensorShape(
      [s if s >= 0 else None for s in target_shape]).as_list()

  # Complete any unspecified target dimensions with those of the input tensor,
  # if known. The test must be `is None` rather than truthiness: a target size
  # of 0 is a valid, fully specified dimension and must not be replaced by the
  # input size.
  return tf.TensorShape([
      s_input if s_target is None else s_target
      for s_target, s_input in zip(target_dims, input_shape.as_list())])
def update_tensor(tensor, slices, value):
  """Returns a copy of a tensor with the specified slices overwritten.

  This operator performs slice assignment.

  .. note::
    Equivalent to `tensor[slices] = value`.

  .. warning::
    TensorFlow does not support slice assignment because tensors are immutable.
    This operator works around this limitation by creating a new tensor, which
    may have performance implications.

  Args:
    tensor: A `tf.Tensor`.
    slices: The indices or slices.
    value: A `tf.Tensor`.

  Returns:
    An updated `tf.Tensor` with the same shape and type as `tensor`.
  """
  # The update mode comes from a private enum in the TensorFlow NumPy API.
  # pylint: disable=protected-access
  update_method = np_array_ops._UpdateMethod.UPDATE
  return _with_index_update_helper(update_method, tensor, slices, value)
def _with_index_update_helper(update_method, a, slice_spec, updates):  # pylint: disable=missing-param-doc
  """Implementation of ndarray._with_index_*."""
  # Adapted from tensorflow/python/ops/numpy_ops/np_array_ops.py.
  # pylint: disable=protected-access

  def _is_boolean_mask(spec):
    """Returns whether `spec` is a Python bool or a boolean array/tensor."""
    if isinstance(spec, bool):
      return True
    if isinstance(spec, tf.Tensor) and spec.dtype == tf.dtypes.bool:
      return True
    if isinstance(spec, (np.ndarray, tnp.ndarray)) and spec.dtype == np.bool_:
      return True
    return False

  # Boolean masks are converted to the indices of their non-zero entries.
  if _is_boolean_mask(slice_spec):
    slice_spec = tnp.nonzero(slice_spec)

  # The slicing helper expects a tuple of specs.
  if not isinstance(slice_spec, tuple):
    slice_spec = np_array_ops._as_spec_tuple(slice_spec)

  return np_array_ops._slice_helper(a, slice_spec, update_method, updates)
def map_fn(fn, elems, batch_dims=1, **kwargs):
  """Transforms `elems` by applying `fn` to each element.

  .. note::
    Similar to `tf.map_fn`, but it supports unstacking along multiple batch
    dimensions.

  For the parameters, see `tf.map_fn`. The only difference is that there is an
  additional `batch_dims` keyword argument which allows specifying the number
  of batch dimensions. The default is 1, in which case this function is equal
  to `tf.map_fn`.

  Args:
    fn: The callable to be applied, forwarded to `tf.map_fn`.
    elems: A tensor or (possibly nested) structure of tensors to unstack along
      the leading `batch_dims` dimensions.
    batch_dims: The number of leading batch dimensions. Defaults to 1.
    **kwargs: Additional keyword arguments forwarded to `tf.map_fn`.

  Returns:
    The output of `tf.map_fn`, with the leading dimension of each tensor
    reshaped back into the original batch dimensions.
  """
  # Strategy: collapse all the batch dimensions into a single one, run the
  # regular `tf.map_fn` over it, then expand the batch dimensions again.
  static_batch_dims = tf.get_static_value(batch_dims)
  batch_dims_known = static_batch_dims is not None

  def _static_batch_shape(x):
    """Static batch shape of `x`; fully unknown if `batch_dims` is dynamic."""
    if not batch_dims_known:
      return tf.TensorShape(None)
    return x.shape[:static_batch_dims]

  def _dynamic_batch_shape(x):
    """Batch shape of `x` as a tensor."""
    return tf.shape(x)[:batch_dims]

  def _flatten_batch(x):
    """Merges the batch dimensions of `x` into a single leading dimension."""
    flat_shape = tf.concat([[-1], tf.shape(x)[batch_dims:]], 0)
    return tf.reshape(x, flat_shape)

  def _unflatten_batch(x, dynamic_batch_shape):
    """Splits the leading dimension of `x` back into the batch dimensions."""
    full_shape = tf.concat([dynamic_batch_shape, tf.shape(x)[1:]], 0)
    return tf.reshape(x, full_shape)

  def _set_static_batch_shape(x, static_batch_shape):
    """Attaches the statically known batch shape to `x`."""
    return tf.ensure_shape(
        x, static_batch_shape.concatenate(x.shape[static_batch_dims:]))

  # Record the batch shapes before they are flattened away.
  static_batch_shapes = tf.nest.map_structure(_static_batch_shape, elems)
  dynamic_batch_shapes = tf.nest.map_structure(_dynamic_batch_shape, elems)

  # Flatten, map, unflatten.
  elems = tf.nest.map_structure(_flatten_batch, elems)
  output = tf.map_fn(fn, elems, **kwargs)
  output = tf.nest.map_structure(
      _unflatten_batch, output, dynamic_batch_shapes)

  # Set the static batch shapes on the output, if known.
  if batch_dims_known:
    output = tf.nest.map_structure(
        _set_static_batch_shape, output, static_batch_shapes)

  return output
def slice_along_axis(tensor, axis, start, length):
  """Slices a tensor along the specified axis.

  All other axes are kept whole.

  Args:
    tensor: A `Tensor`. Its static rank is read from `tensor.shape.rank`.
    axis: The axis to slice along.
    start: The index along `axis` at which the slice begins.
    length: The number of elements to keep along `axis`.

  Returns:
    The result of `tf.slice` on `tensor`, beginning at `start` and spanning
    `length` elements along `axis`, and covering the full extent of every
    other axis.
  """
  rank = tensor.shape.rank
  axis_index = [[axis]]
  # Begin at zero along every axis except `axis`, where it is `start`.
  slice_begin = tf.scatter_nd(axis_index, [start], [rank])
  # Keep the full size along every axis except `axis`, where it is `length`.
  slice_size = tf.tensor_scatter_nd_update(
      tf.shape(tensor), axis_index, [length])
  return tf.slice(tensor, slice_begin, slice_size)