Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Utils used to manipulate tensor shapes.""" | |
| import tensorflow as tf, tf_keras | |
def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If every dimension of both shapes is a Python int (the shapes are fully
  static), the comparison is done immediately in Python and a ValueError is
  raised on mismatch. Otherwise a tf.assert_equal() op is returned, which
  raises a tf InvalidArgumentError when the shapes mismatch.

  Args:
    shape_a: a list (or other sequence) containing the shape of the first
      tensor.
    shape_b: a list (or other sequence) containing the shape of the second
      tensor.

  Returns:
    A tf.no_op() when both shapes are static and equal, or a tf.assert_equal()
    op when at least one dimension of either shape is not a Python int.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  # The `and` short-circuits on purpose: shape_b is only iterated when
  # shape_a turned out to be fully static.
  if (all(isinstance(dim, int) for dim in shape_a) and
      all(isinstance(dim, int) for dim in shape_b)):
    # Compare as lists so that equal shapes passed as different sequence
    # types (e.g. a list and a tuple) are not reported as unequal.
    if list(shape_a) != list(shape_b):
      raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
    return tf.no_op()
  return tf.assert_equal(shape_a, shape_b)
def combined_static_and_dynamic_shape(tensor):
  """Returns the shape of `tensor`, preferring static dimension sizes.

  Every dimension whose size is known statically is reported as that size.
  Every other dimension is reported as the matching element of
  tf.shape(tensor), which is a scalar tensor. Keeping the static sizes where
  they exist helps preserve static shape information through operations such
  as reshape.

  Args:
    tensor: A tensor of any type.

  Returns:
    A list of length tensor.shape.ndims whose entries are integers or scalar
    tensors.
  """
  known_sizes = tensor.shape.as_list()
  runtime_shape = tf.shape(input=tensor)
  return [
      runtime_shape[axis] if size is None else size
      for axis, size in enumerate(known_sizes)
  ]
def pad_or_clip_nd(tensor, output_shape):
  """Pads or clips `tensor` so that it matches `output_shape`.

  Each dimension is handled independently.
  - A dimension larger than its target size is truncated, keeping the
    leading elements.
  - A dimension smaller than its target size is padded at the end using
    tf.pad's default constant padding.

  Args:
    tensor: Input tensor to pad or clip.
    output_shape: A list with one entry per dimension of `tensor`. Each entry
      is an integer, a scalar tensor, or None. A None entry leaves that
      dimension as it is.

  Returns:
    The input tensor padded and clipped to the output shape. Entries of
    `output_shape` that are not tf.Tensor objects are also applied as the
    static shape of the result.
  """
  input_dims = tf.shape(input=tensor)

  # A slice size of -1 keeps the whole dimension. A dimension larger than its
  # target keeps only its first `target` elements.
  slice_sizes = []
  for axis, target in enumerate(output_shape):
    if target is None:
      slice_sizes.append(-1)
    else:
      slice_sizes.append(tf.where(input_dims[axis] - target > 0, target, -1))
  clipped = tf.slice(
      tensor,
      begin=tf.zeros(len(slice_sizes), dtype=tf.int32),
      size=slice_sizes)

  # After clipping no dimension exceeds its target, so any remaining gap is
  # closed by padding at the end of that dimension.
  clipped_dims = tf.shape(input=clipped)
  pad_after = []
  for axis, target in enumerate(output_shape):
    pad_after.append(0 if target is None else target - clipped_dims[axis])
  pad_before = tf.zeros(len(pad_after), dtype=tf.int32)
  result = tf.pad(
      tensor=clipped, paddings=tf.stack([pad_before, pad_after], axis=1))

  # Tensor-valued targets are unknown at graph-construction time.
  result.set_shape(
      [None if isinstance(dim, tf.Tensor) else dim for dim in output_shape])
  return result