Spaces:
Runtime error
Runtime error
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Operations for image patches.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import tensorflow.compat.v1 as tf | |
def get_patch_mask(y, x, patch_size, image_shape):
  """Builds an image-sized boolean mask that is True inside one square patch.

  The patch is positioned so that (y, x) is its center. The center has to lie
  inside the image, but the patch is allowed to overhang the image border; any
  overhanging part is clipped away. For an even patch_size the patch cannot be
  centered exactly, so the extra row/column is placed on the lower-coordinate
  side (above and to the left of the center).

  Args:
    y: An integer or scalar int32 tensor. Row of the patch center. Must be in
      the range [0, image_height).
    x: An integer or scalar int32 tensor. Column of the patch center. Must be
      in the range [0, image_width).
    patch_size: An integer or scalar int32 tensor. Side length of the square
      patch. Must be at least 1.
    image_shape: A list or 1D int32 tensor describing the image the mask is
      for. Only the first two entries (height, width) are used, so both
      [image_height, image_width] and
      [image_height, image_width, image_channels] are accepted.

  Returns:
    Boolean mask tensor of shape [image_height, image_width] that is True on
    the part of the patch lying inside the image and False elsewhere.

  Raises:
    tf.errors.InvalidArgumentError: if x is not in the range [0, image_width), y
      is not in the range [0, image_height), or patch_size is not at least 1.
  """
  height_width = image_shape[:2]
  center_yx = tf.stack([y, x])

  # Input validation. Routing the center through tf.identity under these
  # control dependencies makes every op derived from it wait for the checks.
  input_checks = [
      tf.debugging.assert_greater_equal(
          patch_size, 1,
          message='Patch size must be >= 1'),
      tf.debugging.assert_greater_equal(
          center_yx, 0,
          message='Patch center (y, x) must be >= (0, 0)'),
      tf.debugging.assert_less(
          center_yx, height_width,
          message='Patch center (y, x) must be < image (h, w)'),
  ]
  with tf.control_dependencies(input_checks):
    center_yx = tf.identity(center_yx)

  # Split the patch size around the center: floor(size / 2) cells before it and
  # ceil(size / 2) cells from the center onwards (the end bound is exclusive).
  half_size = tf.cast(patch_size, dtype=tf.float32) / 2
  cells_before = tf.cast(tf.floor(half_size), dtype=tf.int32)
  cells_after = tf.cast(tf.ceil(half_size), dtype=tf.int32)

  # Patch bounds, clipped to the image so an overhanging patch is truncated.
  top_left = tf.maximum(center_yx - cells_before, 0)
  bottom_right = tf.minimum(center_yx + cells_after, height_width)
  top, left = top_left[0], top_left[1]
  bottom, right = bottom_right[0], bottom_right[1]

  # An all-True block the size of the visible patch, padded out to the full
  # image size on every side.
  patch = tf.ones([bottom - top, right - left], dtype=tf.bool)
  paddings = [
      [top, height_width[0] - bottom],
      [left, height_width[1] - right],
  ]
  return tf.pad(patch, paddings)