| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Array manipulation operations.""" |
| |
|
| | import numpy as np |
| | import tensorflow as tf |
| | import tensorflow.experimental.numpy as tnp |
| | from tensorflow.python.ops.numpy_ops import np_array_ops |
| |
|
| |
|
def broadcast_static_shapes(*shapes):
  """Returns the static shape obtained by broadcasting several shapes together.

  Generalizes `tf.broadcast_static_shape` to an arbitrary number of inputs by
  applying it pairwise over `shapes`, from left to right.

  Args:
    *shapes: Two or more `TensorShape` objects.

  Returns:
    A `TensorShape` holding the broadcasted shape.
  """
  result, remaining = shapes[0], shapes[1:]
  for other in remaining:
    result = tf.broadcast_static_shape(result, other)
  return result
| |
|
| |
|
def broadcast_dynamic_shapes(*shapes):
  """Returns the symbolic shape obtained by broadcasting several shapes together.

  Generalizes `tf.broadcast_dynamic_shape` to an arbitrary number of inputs by
  applying it pairwise over `shapes`, from left to right.

  Args:
    *shapes: Two or more rank-1 integer `Tensor` objects, each holding a shape.

  Returns:
    A rank-1 integer `Tensor` holding the broadcasted shape.
  """
  result, remaining = shapes[0], shapes[1:]
  for other in remaining:
    result = tf.broadcast_dynamic_shape(result, other)
  return result
| |
|
| |
|
def cartesian_product(*args):
  """Returns the Cartesian product of rank-1 tensors.

  Args:
    *args: Rank-1 `Tensor` objects.

  Returns:
    A `Tensor` of shape `[M, N]`, where `N = len(args)` and `M` is the product
    of the sizes of all the tensors in `args`. Each row is one combination.
  """
  num_factors = len(args)
  # `meshgrid` stacks the coordinate grids along the last axis, so flattening
  # every axis but the last one yields one combination per row.
  stacked_grid = meshgrid(*args)
  return tf.reshape(stacked_grid, [-1, num_factors])
| |
|
| |
|
def meshgrid(*args):
  """Builds stacked coordinate arrays from one-dimensional coordinate vectors.

  Given 1-D coordinate tensors `x1, x2, ..., xn`, creates the N-D coordinate
  arrays used to evaluate N-D scalar or vector fields over the corresponding
  N-D grid.

  .. note::
    Unlike `tf.meshgrid`, matrix (`'ij'`) indexing is always used. The
    coordinate arrays are returned stacked along the last axis in a single
    tensor, rather than as a list of tensors.

  Args:
    *args: Rank-1 `Tensor` objects.

  Returns:
    A `Tensor` of shape `[M1, M2, ..., Mn, N]`, where `N = len(args)` and
    `Mi = tf.size(args[i])`.
  """
  coordinate_grids = tf.meshgrid(*args, indexing='ij')
  return tf.stack(coordinate_grids, axis=-1)
| |
|
| |
|
def ravel_multi_index(multi_indices, dims):
  """Converts an array of multi-indices into an array of flat indices.

  Uses row-major (C) ordering, i.e. the last dimension varies fastest.

  Args:
    multi_indices: A `Tensor` of shape `[..., N]` containing multi-indices into
      an `N`-dimensional tensor.
    dims: A `Tensor` of shape `[N]`. The shape of the tensor that
      `multi_indices` indexes into. If `multi_indices` is a `Tensor` whose
      integer dtype differs from that of `dims`, `dims` is cast to the dtype
      of `multi_indices`.

  Returns:
    A `Tensor` of shape `[...]` containing flat indices equivalent to
    `multi_indices`.
  """
  dims = tf.convert_to_tensor(dims)
  # TensorFlow does not promote integer dtypes implicitly. Without this, e.g.
  # `int64` indices from `tf.where` could not be combined with the `int32`
  # shape returned by `tf.shape`.
  if tf.is_tensor(multi_indices):
    dims = tf.cast(dims, multi_indices.dtype)
  else:
    # Python/NumPy indices adopt the dtype of `dims`, as before.
    multi_indices = tf.convert_to_tensor(multi_indices, dtype=dims.dtype)
  # strides[i] = prod(dims[i + 1:]).
  strides = tf.math.cumprod(dims, exclusive=True, reverse=True)
  return tf.math.reduce_sum(multi_indices * strides, axis=-1)
| |
|
| |
|
def unravel_index(indices, dims):
  """Converts an array of flat indices into an array of multi-indices.

  Args:
    indices: A `Tensor` of shape `[...]` containing flat indices into an
      `N`-dimensional tensor. Any number of batch dimensions is supported.
    dims: A `Tensor` of shape `[N]`. The shape of the tensor that
      `indices` indexes into. If `indices` is a `Tensor` whose integer dtype
      differs from that of `dims`, `dims` is cast to the dtype of `indices`.

  Returns:
    A `Tensor` of shape `[..., N]` containing multi-indices equivalent to flat
    indices.
  """
  dims = tf.convert_to_tensor(dims)
  # `tf.unravel_index` requires `indices` and `dims` to share a dtype.
  if tf.is_tensor(indices):
    dims = tf.cast(dims, indices.dtype)
  else:
    indices = tf.convert_to_tensor(indices, dtype=dims.dtype)
  # `tf.unravel_index` only accepts 0-D or 1-D indices and places the `N` axis
  # first. Flatten the batch, unravel, move `N` last, then restore the batch.
  flat_indices = tf.reshape(indices, [-1])
  flat_multi_indices = tf.transpose(tf.unravel_index(flat_indices, dims))
  output_shape = tf.concat([tf.shape(indices), tf.shape(dims)], 0)
  return tf.reshape(flat_multi_indices, output_shape)
| |
|
| |
|
def central_crop(tensor, shape):
  """Crop the central region of a tensor.

  Args:
    tensor: A `Tensor`.
    shape: A `Tensor` (or a sequence of integers). The shape of the region to
      crop. The length of `shape` must be equal to or less than the rank of
      `tensor`. If the length of `shape` is less than the rank of tensor, the
      operation is applied along the last `len(shape)` dimensions of
      `tensor`. Any component of `shape` can be set to the special value -1 to
      leave the corresponding dimension unchanged.

  Returns:
    A `Tensor`. Has the same type as `tensor`. The centrally cropped tensor.
    When a dimension has an odd number of excess elements, the extra element
    is removed from the end.

  Raises:
    ValueError: If `shape` has a rank other than 1.
    tf.errors.InvalidArgumentError: If `shape` has more entries than `tensor`
      has dimensions, or if a component of `shape` is greater than the
      corresponding dimension of `tensor` (raised when the checks run).
  """
  tensor = tf.convert_to_tensor(tensor)
  input_shape_tensor = tf.shape(tensor)
  target_shape_tensor = tf.convert_to_tensor(shape)

  # Static checks.
  if target_shape_tensor.shape.rank != 1:
    raise ValueError(f"`shape` must have rank 1. Received: {shape}")

  # Shape arithmetic below uses the dtype of `tf.shape` (int32). Other integer
  # dtypes (e.g. NumPy int64 arrays) would otherwise make `tf.concat` fail.
  if target_shape_tensor.dtype.is_integer:
    target_shape_tensor = tf.cast(target_shape_tensor,
                                  input_shape_tensor.dtype)

  # Check the length of `shape` *before* padding it, so that a too-long shape
  # produces this message rather than an obscure `tf.tile` error.
  rank_check = tf.debugging.assert_greater_equal(
      tf.rank(tensor), tf.size(target_shape_tensor), message=(
          "Length of `shape` cannot be greater than the rank of `tensor`."))
  with tf.control_dependencies([rank_check]):
    # A target shape with fewer entries than the rank of `tensor` applies to
    # its last dimensions: left-pad it with -1 (keep dimension). This must
    # also happen when `shape` is a `Tensor`.
    target_shape_tensor = tf.concat([
        tf.tile([-1], [tf.rank(tensor) - tf.size(target_shape_tensor)]),
        target_shape_tensor], 0)

  # Mirror the padding on the Python value, which is used to infer the static
  # output shape. This requires a statically known rank.
  if not isinstance(shape, tf.Tensor) and tensor.shape.rank is not None:
    shape = [-1] * (tensor.shape.rank - len(shape)) + list(shape)

  # Dynamic check. Components equal to -1 always pass.
  size_check = tf.debugging.assert_less_equal(
      target_shape_tensor, input_shape_tensor, message=(
          "Target shape cannot be greater than input shape."))
  with tf.control_dependencies([size_check]):
    tensor = tf.identity(tensor)

  # Crop. Dimensions with target -1 start at 0 and keep their full size.
  slice_begin = tf.where(
      target_shape_tensor >= 0,
      tf.math.maximum(input_shape_tensor - target_shape_tensor, 0) // 2,
      0)
  slice_size = tf.where(
      target_shape_tensor >= 0,
      tf.math.minimum(input_shape_tensor, target_shape_tensor),
      -1)
  tensor = tf.slice(tensor, slice_begin, slice_size)

  # Restore static shape information lost by the dynamic slice, when possible.
  if tensor.shape.rank is not None:
    static_shape = _compute_static_output_shape(tensor.shape, shape)
    if static_shape is not None:
      tensor = tf.ensure_shape(tensor, static_shape)

  return tensor
| |
|
| |
|
def resize_with_crop_or_pad(tensor, shape, padding_mode='constant'):
  """Crops and/or pads a tensor to a target shape.

  Pads symmetrically or crops centrally the input tensor as necessary to
  achieve the requested shape. When a dimension has an odd number of elements
  to add or remove, the extra element is added to or removed from the end.

  Args:
    tensor: A `Tensor`.
    shape: A `Tensor` (or a sequence of integers). The shape of the output
      tensor. The length of `shape` must be equal to or less than the rank of
      `tensor`. If the length of `shape` is less than the rank of tensor, the
      operation is applied along the last `len(shape)` dimensions of
      `tensor`. Any component of `shape` can be set to the special value -1 to
      leave the corresponding dimension unchanged.
    padding_mode: A `str`. Must be one of `'constant'`, `'reflect'` or
      `'symmetric'`. Passed to `tf.pad` as `mode`.

  Returns:
    A `Tensor`. Has the same type as `tensor`. The symmetrically padded/cropped
    tensor.

  Raises:
    tf.errors.InvalidArgumentError: If `shape` has more entries than `tensor`
      has dimensions (raised when the check runs).
  """
  tensor = tf.convert_to_tensor(tensor)
  input_shape = tensor.shape
  input_shape_tensor = tf.shape(tensor)
  target_shape_tensor = tf.convert_to_tensor(shape)

  # Shape arithmetic below uses the dtype of `tf.shape` (int32). Other integer
  # dtypes (e.g. NumPy int64 arrays) would otherwise make `tf.concat` fail.
  if target_shape_tensor.dtype.is_integer:
    target_shape_tensor = tf.cast(target_shape_tensor,
                                  input_shape_tensor.dtype)

  # Check the length of `shape` *before* padding it, so that a too-long shape
  # produces this message rather than an obscure `tf.tile` error.
  rank_check = tf.debugging.assert_greater_equal(
      tf.rank(tensor), tf.size(target_shape_tensor), message=(
          "Length of `shape` cannot be greater than the rank of `tensor`."))
  with tf.control_dependencies([rank_check]):
    # A target shape with fewer entries than the rank of `tensor` applies to
    # its last dimensions: left-pad it with -1 (keep dimension). This must
    # also happen when `shape` is a `Tensor`.
    target_shape_tensor = tf.concat([
        tf.tile([-1], [tf.rank(tensor) - tf.size(target_shape_tensor)]),
        target_shape_tensor], 0)

  # Padded Python copy of `shape`, used for static shape inference. Only
  # available for non-`Tensor` shapes and a statically known input rank.
  static_target_shape = None
  if not isinstance(shape, tf.Tensor) and input_shape.rank is not None:
    static_target_shape = (
        [-1] * (input_shape.rank - len(shape)) + list(shape))

  # Pad the dimensions that are smaller than the target. Dimensions with
  # target -1 are never padded.
  pad_amount = tf.math.maximum(target_shape_tensor - input_shape_tensor, 0)
  pad_left = tf.where(target_shape_tensor >= 0, pad_amount // 2, 0)
  pad_right = tf.where(target_shape_tensor >= 0, (pad_amount + 1) // 2, 0)
  tensor = tf.pad(tensor, tf.transpose(tf.stack([pad_left, pad_right])),
                  mode=padding_mode)

  # Crop the dimensions that are (still) larger than the target. The target
  # passed on is already padded to the full rank.
  if static_target_shape is not None:
    tensor = central_crop(tensor, static_target_shape)
  else:
    tensor = central_crop(tensor, target_shape_tensor)

  # Restore static shape information, when possible.
  if static_target_shape is not None:
    static_shape = _compute_static_output_shape(input_shape,
                                                static_target_shape)
    if static_shape is not None:
      tensor = tf.ensure_shape(tensor, static_shape)

  return tensor
| |
|
| |
|
def _compute_static_output_shape(input_shape, target_shape):
  """Compute the static output shape of a resize operation.

  Args:
    input_shape: The static shape of the input tensor (a `tf.TensorShape` or
      anything accepted by its constructor). May have unknown rank.
    target_shape: The target shape, with one entry per input dimension.
      Non-negative entries are exact output sizes (0 included). Negative
      entries (e.g. -1) keep the input size. If a `tf.Tensor`, its values are
      not known statically.

  Returns:
    A `tf.TensorShape` with the static output shape, or `None` if
    `target_shape` is a `tf.Tensor`.
  """
  if isinstance(target_shape, tf.Tensor):
    # Target values are only known at run time.
    return None

  input_shape = tf.TensorShape(input_shape)
  if input_shape.rank is None:
    # Nothing is known about the input, but explicit targets are still exact.
    input_dims = [None] * len(target_shape)
  else:
    input_dims = input_shape.as_list()

  # Compare with 0 explicitly: a target size of 0 is valid and must not fall
  # back to the input size (as a truthiness test such as `or` would do).
  return tf.TensorShape([
      s_target if s_target >= 0 else s_input
      for s_target, s_input in zip(target_shape, input_dims)])
| |
|
| |
|
def update_tensor(tensor, slices, value):
  """Returns a copy of `tensor` with `value` written at `slices`.

  This is a functional form of slice assignment.

  .. note::
    Equivalent to `tensor[slices] = value`.

  .. warning::
    Tensors are immutable in TensorFlow, so slice assignment is not natively
    supported. This function emulates it by building a new tensor, which may
    have performance implications.

  Args:
    tensor: A `tf.Tensor`.
    slices: The indices or slices to update.
    value: A `tf.Tensor` with the values to write.

  Returns:
    An updated `tf.Tensor` with the same shape and type as `tensor`.
  """
  # Relies on private helpers of TensorFlow's NumPy API (`np_array_ops`).
  update_method = np_array_ops._UpdateMethod.UPDATE  # pylint: disable=protected-access
  return _with_index_update_helper(update_method, tensor, slices, value)
| |
|
| |
|
def _with_index_update_helper(update_method, a, slice_spec, updates):
  """Implementation of ndarray._with_index_*.

  Boolean masks (a Python `bool`, a boolean `tf.Tensor`, or a boolean NumPy /
  `tnp` array) are first turned into integer indices with `tnp.nonzero`. Any
  non-tuple spec is then normalized to a tuple before delegating to
  `np_array_ops._slice_helper` with `update_method`.
  """

  def _is_boolean_mask(spec):
    # Early returns keep the same short-circuit order as a chained `or`.
    if isinstance(spec, bool):
      return True
    if isinstance(spec, tf.Tensor) and spec.dtype == tf.dtypes.bool:
      return True
    return (isinstance(spec, (np.ndarray, tnp.ndarray)) and
            spec.dtype == np.bool_)

  if _is_boolean_mask(slice_spec):
    slice_spec = tnp.nonzero(slice_spec)

  if isinstance(slice_spec, tuple):
    spec_tuple = slice_spec
  else:
    spec_tuple = np_array_ops._as_spec_tuple(slice_spec)  # pylint: disable=protected-access

  return np_array_ops._slice_helper(a, spec_tuple, update_method, updates)  # pylint: disable=protected-access
| |
|
| |
|
def map_fn(fn, elems, batch_dims=1, **kwargs):
  """Transforms `elems` by applying `fn` to each element.

  .. note::
    Similar to `tf.map_fn`, but it supports unstacking along multiple batch
    dimensions.

  For the parameters, see `tf.map_fn`. The only difference is that there is an
  additional `batch_dims` keyword argument which allows specifying the number
  of batch dimensions. The default is 1, in which case this function is equal
  to `tf.map_fn`.

  The leading `batch_dims` dimensions of every tensor in `elems` are merged
  into one and `tf.map_fn` is applied. The batch dimensions are then restored
  on every tensor of the output. All tensors in `elems` must share the same
  batch shape. The output may have a different structure than `elems` (e.g.
  when `fn_output_signature` is passed).
  """
  # `batch_dims` may be a tensor; static shapes can only be set when its value
  # is known at trace time.
  static_batch_dims = tf.get_static_value(batch_dims)

  # All elements share the batch shape, so take it from the first one. The
  # output structure need not match `elems`, so the batch shape cannot be
  # zipped per element with the output.
  reference = tf.nest.flatten(elems)[0]
  dynamic_batch_shape = tf.shape(reference)[:batch_dims]
  if static_batch_dims is None:
    static_batch_shape = tf.TensorShape(None)
  else:
    static_batch_shape = tf.TensorShape(reference.shape)[:static_batch_dims]

  # Merge all batch dimensions into a single leading dimension.
  elems = tf.nest.map_structure(
      lambda x: tf.reshape(
          x, tf.concat([[-1], tf.shape(x)[batch_dims:]], 0)), elems)

  output = tf.map_fn(fn, elems, **kwargs)

  def _restore_batch_dims(x):
    # Split the merged leading dimension back into the batch dimensions.
    x = tf.reshape(x, tf.concat([dynamic_batch_shape, tf.shape(x)[1:]], 0))
    if static_batch_dims is not None:
      x = tf.ensure_shape(
          x, static_batch_shape.concatenate(x.shape[static_batch_dims:]))
    return x

  return tf.nest.map_structure(_restore_batch_dims, output)
| |
|
| |
|
def slice_along_axis(tensor, axis, start, length):
  """Slices a tensor along the specified axis.

  Takes `length` elements starting at `start` along `axis`, keeping every
  other dimension whole.

  Args:
    tensor: A `Tensor` (or a value convertible to one).
    axis: An `int` or scalar integer `Tensor`. The axis to slice along.
      Negative values count from the last dimension.
    start: A scalar integer. The index at which the slice begins.
    length: A scalar integer. The number of elements to take. As in
      `tf.slice`, -1 takes all remaining elements.

  Returns:
    A `Tensor` with the same dtype and rank as `tensor`.
  """
  tensor = tf.convert_to_tensor(tensor)
  input_shape = tf.shape(tensor)
  # Prefer the static rank so the rank of the result stays statically known.
  rank = tensor.shape.rank
  if rank is None:
    rank = tf.rank(tensor)
  # `tf.slice` needs `begin` and `size` of the same integer dtype, so use the
  # dtype of `tf.shape` for every indexing quantity.
  axis = tf.cast(axis, input_shape.dtype)
  start = tf.cast(start, input_shape.dtype)
  length = tf.cast(length, input_shape.dtype)
  # Map a negative axis to its non-negative equivalent.
  axis = tf.where(axis < 0, axis + rank, axis)
  index = tf.reshape(axis, [1, 1])
  begin = tf.scatter_nd(index, tf.reshape(start, [1]), [rank])
  size = tf.tensor_scatter_nd_update(input_shape, index,
                                     tf.reshape(length, [1]))
  return tf.slice(tensor, begin, size)
| |
|