# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of data preprocessing ops.
All preprocessing ops should return a data processing functors. A data
is represented as a dictionary of tensors, where field "image" is reserved
for 3D images (height x width x channels). The functors output dictionary with
field "image" being modified. Potentially, other fields can also be modified
or added.
"""
from typing import Tuple
import numpy as np
from scenic.dataset_lib.big_transfer.preprocessing import autoaugment
from scenic.dataset_lib.big_transfer.preprocessing import utils
from scenic.dataset_lib.big_transfer.registry import Registry
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from tensorflow_addons import image as image_utils
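
# Usage sketch (illustrative only; assumes the decorators default to the
# "image" field, as the module docstring and the internal `{"image": ...}`
# calls below suggest). Each `get_*` factory returns a functor mapping a data
# dictionary to a data dictionary, so ops compose by chaining:
#
#   resize_fn = get_resize(256)
#   crop_fn = get_random_crop(224)
#   data = {"image": tf.zeros([300, 400, 3], tf.uint8)}
#   data = crop_fn(resize_fn(data))  # data["image"].shape == (224, 224, 3)
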
@Registry.register("preprocess_ops.color_distort", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_color_distortion():
"""Applies random brigthness/saturation/hue/contrast transformations."""
def _color_distortion(image):
image = tf.image.random_brightness(image, max_delta=128. / 255.)
image = tf.image.random_saturation(image, lower=0.1, upper=2.0)
image = tf.image.random_hue(image, max_delta=0.5)
image = tf.image.random_contrast(image, lower=0.1, upper=2.0)
return image
return _color_distortion
@Registry.register("preprocess_ops.random_brightness", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_brightness(max_delta=0.1):
"""Applies random brigthness transformations."""
# A random value in [-max_delta, +max_delta] is added to the image values.
# Small max_delta <1.0 assumes that the image values are within [0, 1].
def _random_brightness(image):
return tf.image.random_brightness(image, max_delta)
return _random_brightness
@Registry.register("preprocess_ops.random_saturation", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_saturation(lower=0.5, upper=2.0):
"""Applies random saturation transformations."""
  # Multiplies the saturation channel in HSV (converting from/to RGB) by a
  # random float value in [lower, upper].
def _random_saturation(image):
return tf.image.random_saturation(image, lower=lower, upper=upper)
return _random_saturation
@Registry.register("preprocess_ops.random_hue", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_hue(max_delta=0.1):
"""Applies random hue transformations."""
  # Adds a random offset in [-max_delta, +max_delta] to the hue channel in HSV
  # (converting from/to RGB).
def _random_hue(image):
return tf.image.random_hue(image, max_delta=max_delta)
return _random_hue
@Registry.register("preprocess_ops.random_contrast", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_contrast(lower=0.5, upper=2.0):
"""Applies random contrast transformations."""
# Stretches/shrinks value stddev (per channel) by multiplying with a random
# value in [lower, upper].
def _random_contrast(image):
return tf.image.random_contrast(image, lower=lower, upper=upper)
return _random_contrast
@Registry.register("preprocess_ops.decode", "function")
@utils.InKeyOutKey()
def get_decode(channels=3):
"""Decode an encoded image string, see tf.io.decode_image."""
def _decode(image): # pylint: disable=missing-docstring
    # tf.io.decode_image does not set the shape correctly, so we use
    # tf.io.decode_jpeg, which also works for png, see
    # https://github.com/tensorflow/tensorflow/issues/8551
return tf.io.decode_jpeg(image, channels=channels)
return _decode
@Registry.register("preprocess_ops.decode_grayscale", "function")
@utils.InKeyOutKey()
def get_decode_grayscale(channels=1):
"""Decode an encoded image string, see tf.io.decode_image."""
def _decode_gray(image): # pylint: disable=missing-docstring
    # tf.io.decode_image does not set the shape correctly, so we use
    # tf.io.decode_jpeg, which also works for png, see
    # https://github.com/tensorflow/tensorflow/issues/8551
return tf.io.decode_jpeg(image, channels=channels)
return _decode_gray
@Registry.register("preprocess_ops.pad", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_pad(pad_size):
"""Pads an image.
Args:
pad_size: either an integer u giving verticle and horizontal pad sizes u, or
a list or tuple [u, v] of integers where u and v are vertical and
horizontal pad sizes.
Returns:
A function for padding an image.
"""
pad_size = utils.maybe_repeat(pad_size, 2)
def _pad(image):
return tf.pad(
image, [[pad_size[0], pad_size[0]], [pad_size[1], pad_size[1]], [0, 0]])
return _pad
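
# Example (a hedged sketch, not from the original source): the docstring above
# says an integer pad_size pads both axes, so with pad_size=4, four rows and
# columns of zeros are added on every side and a (32, 32, 3) image becomes
# (40, 40, 3):
#
#   padded = get_pad(4)({"image": tf.zeros([32, 32, 3])})["image"]
#   # padded.shape == (40, 40, 3)
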
@Registry.register("preprocess_ops.resize", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_resize(resize_size, method=tf2.image.ResizeMethod.BILINEAR,
antialias=False):
"""Resizes image to a given size.
Args:
resize_size: either an integer H, where H is both the new height and width
of the resized image, or a list or tuple [H, W] of integers, where H and W
are new image"s height and width respectively.
method: The type of interpolation to apply when resizing.
antialias: Whether to use an anti-aliasing filter when downsampling an
image.
Returns:
A function for resizing an image.
"""
resize_size = utils.maybe_repeat(resize_size, 2)
def _resize(image):
"""Resizes image to a given size."""
    # Note: use the TF-2 version of tf.image.resize, as the TF-1 version is
    # buggy: https://github.com/tensorflow/tensorflow/issues/6720.
    # In particular, it was not equivariant with rotation and led the network
    # to learn a shortcut in the self-supervised rotation task if rotation was
    # applied after resizing.
dtype = image.dtype
image = tf2.image.resize(
images=image, size=resize_size, method=method, antialias=antialias)
return tf.cast(image, dtype)
return _resize
@Registry.register("preprocess_ops.resize_small", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_resize_small(smaller_size, method="area", antialias=True):
"""Resizes the smaller side to `smaller_size` keeping aspect ratio.
Args:
smaller_size: an integer, that represents a new size of the smaller side of
an input image.
method: the resize method. `area` is a meaningful, bwd-compat default.
antialias: See TF's image.resize method.
Returns:
A function, that resizes an image and preserves its aspect ratio.
"""
def _resize_small(image): # pylint: disable=missing-docstring
h, w = tf.shape(image)[0], tf.shape(image)[1]
# Figure out the necessary h/w.
ratio = (
tf.cast(smaller_size, tf.float32) /
tf.cast(tf.minimum(h, w), tf.float32))
h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32)
w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32)
dtype = image.dtype
image = tf2.image.resize(image, (h, w), method, antialias)
return tf.cast(image, dtype)
return _resize_small
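
# Worked example (illustrative): with smaller_size=256, a 480x640 image has
# ratio 256/480, so it is resized to 256 x round(640 * 256/480) = 256 x 341,
# preserving the aspect ratio.
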
@Registry.register("preprocess_ops.inception_crop", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_inception_crop(resize_size=None, area_min=5, area_max=100,
resize_method=tf2.image.ResizeMethod.BILINEAR,
resize_antialias=False):
"""Makes inception-style image crop.
Inception-style crop is a random image crop (its size and aspect ratio are
random) that was used for training Inception models, see
https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf.
Args:
resize_size: Resize image to [resize_size, resize_size] after crop.
area_min: minimal crop area.
area_max: maximal crop area.
resize_method: The type of interpolation to apply when resizing. Valid
values those accepted by tf.image.resize.
resize_antialias: Whether to use an anti-aliasing filter when downsampling
an image.
Returns:
A function, that applies inception crop.
"""
def _inception_crop(image): # pylint: disable=missing-docstring
begin, size, _ = tf.image.sample_distorted_bounding_box(
tf.shape(image),
tf.zeros([0, 0, 4], tf.float32),
area_range=(area_min / 100, area_max / 100),
min_object_covered=0, # Don't enforce a minimum area.
use_image_if_no_bounding_boxes=True)
crop = tf.slice(image, begin, size)
# Unfortunately, the above operation loses the depth-dimension. So we need
# to restore it the manual way.
crop.set_shape([None, None, image.shape[-1]])
if resize_size:
crop = get_resize(
[resize_size, resize_size], resize_method, resize_antialias)(
{"image": crop})["image"]
return crop
return _inception_crop
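
# Example (a hedged sketch): area_min=5 and area_max=100 are divided by 100
# above, so the sampled crop covers between 5% and 100% of the input area.
# Assuming `image` is a 3D RGB tensor:
#
#   crop = get_inception_crop(resize_size=224)({"image": image})["image"]
#   # crop.shape == (224, 224, 3)
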
@Registry.register("preprocess_ops.decode_jpeg_and_inception_crop", "function")
@utils.InKeyOutKey()
def get_decode_jpeg_and_inception_crop(
resize_size=None,
area_min=5,
area_max=100,
aspect_ratio_range=None,
resize_method=tf2.image.ResizeMethod.BILINEAR,
resize_antialias=False):
"""Decode jpeg string and make inception-style image crop.
Inception-style crop is a random image crop (its size and aspect ratio are
random) that was used for training Inception models, see
https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf.
Args:
resize_size: Resize image to [resize_size, resize_size] after crop.
area_min: minimal crop area.
area_max: maximal crop area.
aspect_ratio_range: An optional list of floats. Defaults to [0.75, 1.33].
The cropped area of the image must have an aspect ratio = width / height
within this range.
resize_method: The type of interpolation to apply when resizing. Valid
values those accepted by tf.image.resize.
resize_antialias: Whether to use an anti-aliasing filter when downsampling
an image.
Returns:
A function, that applies inception crop.
"""
def _inception_crop(image_data): # pylint: disable=missing-docstring
shape = tf.image.extract_jpeg_shape(image_data)
begin, size, _ = tf.image.sample_distorted_bounding_box(
shape,
tf.zeros([0, 0, 4], tf.float32),
area_range=(area_min / 100, area_max / 100),
min_object_covered=0, # Don't enforce a minimum area.
aspect_ratio_range=aspect_ratio_range,
use_image_if_no_bounding_boxes=True)
# Crop the image to the specified bounding box.
offset_y, offset_x, _ = tf.unstack(begin)
target_height, target_width, _ = tf.unstack(size)
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3)
if resize_size:
image = get_resize(
[resize_size, resize_size], resize_method, resize_antialias)(
{"image": image})["image"]
return image
return _inception_crop
@Registry.register("preprocess_ops.decode_jpeg_and_center_crop", "function")
@utils.InKeyOutKey()
def get_decode_jpeg_and_center_crop(crop_size=None):
"""Decode jpeg string and make a center image crop.
Args:
crop_size: Crop image to [crop_size, crop_size].
Returns:
A function that applies center crop.
"""
crop_size = utils.maybe_repeat(crop_size, 2)
def _decode_and_center_crop(image_data): # pylint: disable=missing-docstring
shape = tf.image.extract_jpeg_shape(image_data)
target_height, target_width = crop_size
offset_y = (shape[0] - target_height) // 2
offset_x = (shape[1] - target_width) // 2
# Crop the image to the specified bounding box.
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3)
image.set_shape([target_height, target_width, 3])
return image
return _decode_and_center_crop
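
# Worked example (illustrative): for a 300x400 JPEG and crop_size=224, the
# crop window starts at offset_y = (300 - 224) // 2 = 38 and
# offset_x = (400 - 224) // 2 = 88, yielding a centered 224x224 RGB image
# without ever decoding the full frame.
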
@Registry.register("preprocess_ops.decode_jpeg_and_random_crop", "function")
@utils.InKeyOutKey()
def get_decode_jpeg_and_random_crop(crop_size=None):
"""Decode jpeg string and make a center image crop.
Args:
crop_size: Crop image to [crop_size, crop_size].
Returns:
A function that applies center crop.
"""
crop_size = utils.maybe_repeat(crop_size, 2)
def _decode_and_random_crop(image_data): # pylint: disable=missing-docstring
shape = tf.image.extract_jpeg_shape(image_data)[:2]
target_height, target_width = crop_size
limit = shape - crop_size + 1
offset = tf.random.uniform([2], 0, tf.int32.max, dtype=tf.int32) % limit
# Crop the image to the specified bounding box.
crop_window = tf.stack([offset[0], offset[1], target_height, target_width])
image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3)
image.set_shape([target_height, target_width, 3])
return image
return _decode_and_random_crop
@Registry.register("preprocess_ops.random_crop", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_crop(crop_size):
"""Makes a random crop of a given size.
Args:
crop_size: either an integer H, where H is both the height and width of the
random crop, or a list or tuple [H, W] of integers, where H and W are
height and width of the random crop respectively.
Returns:
A function, that applies random crop.
"""
crop_size = utils.maybe_repeat(crop_size, 2)
def _crop(image):
return tf.random_crop(image, [crop_size[0], crop_size[1], image.shape[-1]])
return _crop
@Registry.register("preprocess_ops.central_crop", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_central_crop(crop_size):
"""Makes central crop of a given size.
Args:
crop_size: either an integer H, where H is both the height and width of the
central crop, or a list or tuple [H, W] of integers, where H and W are
height and width of the central crop respectively.
Returns:
A function, that applies central crop.
"""
crop_size = utils.maybe_repeat(crop_size, 2)
def _crop(image):
h, w = crop_size[0], crop_size[1]
dy = (tf.shape(image)[0] - h) // 2
dx = (tf.shape(image)[1] - w) // 2
return tf.image.crop_to_bounding_box(image, dy, dx, h, w)
return _crop
@Registry.register("preprocess_ops.central_crop_longer", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_central_crop_longer():
"""Center crop the longer side so that the image becomes a square.
Args:
Returns:
A function, that applies central crop.
"""
def _crop(image):
shape = tf.shape(image)
h, w = shape[0], shape[1]
crop_fn = tf.image.crop_to_bounding_box
return tf.cond(
h > w,
lambda: crop_fn(image, h // 2 - w // 2, 0, w, w),
lambda: crop_fn(image, 0, w // 2 - h // 2, h, h))
return _crop
@Registry.register("preprocess_ops.flip_lr", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_flip_lr():
"""Flips an image horizontally with probability 50%."""
def _random_flip_lr_pp(image):
return tf.image.random_flip_left_right(image)
return _random_flip_lr_pp
@Registry.register("preprocess_ops.flip_ud", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_flip_ud():
"""Flips an image vertically with probability 50%."""
def _random_flip_ud_pp(image):
return tf.image.random_flip_up_down(image)
return _random_flip_ud_pp
@Registry.register("preprocess_ops.random_rotate", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_rotation(min_angle=0, max_angle=360):
"""Randomly rotate an image."""
  if min_angle > max_angle:
    raise ValueError("min_angle (%f) must not be greater than max_angle (%f)" %
                     (min_angle, max_angle))
# Convert to radians.
min_angle = np.radians(min_angle)
max_angle = np.radians(max_angle)
def _random_rotation(image):
"""Rotation function."""
num_dims = len(image.shape)
if num_dims in [3, 4]:
batch_size = tf.shape(image)[0] if num_dims == 4 else 1
else:
raise ValueError("Tensor \"image\" should have 3 or 4 dimensions.")
random_angles = tf.random.uniform(
shape=(batch_size,), minval=min_angle, maxval=max_angle)
return image_utils.rotate(images=image, angles=random_angles)
return _random_rotation
@Registry.register("preprocess_ops.random_rotate90", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_rotation90():
"""Randomly rotate an image by multiples of 90 degrees."""
def _random_rotation90(image):
"""Rotation function."""
num_rotations = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
return tf.image.rot90(image, k=num_rotations)
return _random_rotation90
@Registry.register("preprocess_ops.rotate", "function")
def get_rotate(create_labels=None):
"""Returns a function that does 90deg rotations and sets according labels.
Args:
create_labels: create new labels to the default label field in the input
dictionary. It should be set to one of ['rotation', 'supervised', None].
Returns:
A function, that applies rotation preprocess.
"""
def _four_rots(img):
"""Rotates an image four times, with 90 degrees between each rotation."""
return tf.stack([
img,
tf.transpose(tf.reverse_v2(img, [1]), [1, 0, 2]),
tf.reverse_v2(img, [0, 1]),
tf.reverse_v2(tf.transpose(img, [1, 0, 2]), [1]),
])
def _rotate_pp(data):
"""Rotate preprocessing function applied on data dictionary input."""
assert create_labels in [
"rotation", "supervised", None
], ("create_labels:{} must be one of ['rotation', 'supervised', None]."
.format(create_labels))
# Creates labels in the same structure as images.
if create_labels == "rotation":
data["label"] = tf.constant([0, 1, 2, 3])
# Duplicates the original supervised label four times.
elif create_labels == "supervised":
if "label" in data:
data["label"] = tf.stack(tf.tile([data["label"]], [4]))
# Creates rotated images and rot labels.
data["image"] = _four_rots(data["image"])
data["rot_label"] = tf.constant([0, 1, 2, 3])
return data
return _rotate_pp
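
# Example (a hedged sketch): with create_labels="rotation", a single (H, W, C)
# image becomes a (4, H, W, C) stack of 0/90/180/270-degree rotations, and
# data["label"] == data["rot_label"] == [0, 1, 2, 3]:
#
#   data = get_rotate("rotation")({"image": tf.zeros([8, 8, 3])})
#   # data["image"].shape == (4, 8, 8, 3)
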
@Registry.register("preprocess_ops.value_range", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing(output_dtype=tf.float32)
def get_value_range(vmin=-1, vmax=1, in_min=0, in_max=255.0, clip_values=False):
"""Transforms a [in_min,in_max] image to [vmin,vmax] range.
Input ranges in_min/in_max can be equal-size lists to rescale the invidudal
channels independently.
Args:
vmin: A scalar. Output max value.
vmax: A scalar. Output min value.
in_min: A scalar or a list of input min values to scale. If a list, the
length should match to the number of channels in the image.
in_max: A scalar or a list of input max values to scale. If a list, the
length should match to the number of channels in the image.
clip_values: Whether to clip the output values to the provided ranges.
Returns:
A function to rescale the values.
"""
def _value_range(image):
"""Scales values in given range."""
in_min_t = tf.constant(in_min, tf.float32)
in_max_t = tf.constant(in_max, tf.float32)
image = tf.cast(image, tf.float32)
image = (image - in_min_t) / (in_max_t - in_min_t)
image = vmin + image * (vmax - vmin)
if clip_values:
image = tf.clip_by_value(image, vmin, vmax)
return image
return _value_range
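
# Worked example (illustrative): with the defaults vmin=-1, vmax=1, in_min=0,
# in_max=255, a pixel value of 255 maps to (255 - 0) / 255 = 1.0 and then to
# -1 + 1.0 * 2 = 1.0, while 0 maps to -1.0; i.e. a [0, 255] uint8 image is
# rescaled to [-1, 1] floats.
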
@Registry.register("preprocess_ops.value_range_mc", "function")
def get_value_range_mc(vmin, vmax, *args):
"""Independent multi-channel rescaling."""
  if len(args) % 2:
    raise ValueError("Additional args must be a list of even length giving "
                     "`in_min` and `in_max` concatenated.")
num_channels = len(args) // 2
in_min = args[:num_channels]
in_max = args[-num_channels:]
return get_value_range(vmin, vmax, in_min, in_max)
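
# Example (a hedged sketch): rescaling a 2-channel image where channel 0 is in
# [0, 255] and channel 1 is in [0, 1] to the common range [-1, 1]:
#
#   fn = get_value_range_mc(-1, 1, 0, 0, 255, 1)
#   # Here the extra args split into in_min == (0, 0) and in_max == (255, 1).
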
@Registry.register("preprocess_ops.delete_field", "function")
def get_delete_field(key):
  """Deletes the field `key` from the data dictionary, if present."""
  def _delete_field(datum):
if key in datum:
del datum[key]
return datum
return _delete_field
@Registry.register("preprocess_ops.replicate", "function")
@utils.InKeyOutKey()
def get_replicate(num_replicas=2):
"""Replicates an image `num_replicas` times along a new batch dimension."""
def _replicate(image):
tiles = [num_replicas] + [1] * len(image.shape)
return tf.tile(image[None], tiles)
return _replicate
@Registry.register("preprocess_ops.standardize", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing(output_dtype=tf.float32)
def get_standardize(mean, std):
"""Standardize an image with the given mean and standard deviation."""
def _standardize(image):
image = tf.cast(image, dtype=tf.float32)
return (image - mean) / std
return _standardize
@Registry.register("preprocess_ops.select_channels", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_select_channels(channels):
"""Returns function to select specified channels."""
def _select_channels(image):
"""Returns a subset of available channels."""
return tf.gather(image, channels, axis=-1)
return _select_channels
@Registry.register("preprocess_ops.extract_patches", "function")
@utils.InKeyOutKey()
def get_extract_patches(patch_size, stride):
"""Extracts image patches.
Args:
patch_size: patch size.
stride: patches stride.
Returns:
A function for extracting patches.
"""
def _extract_patches(image):
"""Extracts image patches."""
h, w, c = image.get_shape().as_list()
count_h = h // stride
count_w = w // stride
# pyformat: disable
image = tf.extract_image_patches(image[None],
[1, patch_size, patch_size, 1],
[1, stride, stride, 1],
[1, 1, 1, 1],
padding="VALID")
# pyformat: enable
return tf.reshape(image, [count_h * count_w, patch_size, patch_size, c])
return _extract_patches
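
# Worked example (illustrative): for a (224, 224, 3) image with patch_size=16
# and stride=16, count_h = count_w = 224 // 16 = 14, so the output is a
# (196, 16, 16, 3) tensor of non-overlapping patches.
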
@Registry.register("preprocess_ops.onehot", "function")
def get_onehot(depth,
key="labels",
key_result=None,
multi=True,
on=1.0,
off=0.0):
"""One-hot encodes the input.
Args:
depth: Length of the one-hot vector (how many classes).
key: Key of the data to be one-hot encoded.
key_result: Key under which to store the result (same as `key` if None).
multi: If there are multiple labels, whether to merge them into the same
"multi-hot" vector (True) or keep them as an extra dimension (False).
on: Value to fill in for the positive label (default: 1).
off: Value to fill in for negative labels (default: 0).
Returns:
Data dictionary.
"""
def _onehot(data):
# When there's more than one label, this is significantly more efficient
# than using tf.one_hot followed by tf.reduce_max; we tested.
labels = data[key]
if labels.shape.rank > 0 and multi:
# Currently, the assertion below is only used for datasets with single
# labels. In a multi-label dataset either `on` or `off` should be computed
# dynamically to yield the correct sum, when the number of labels varies.
x = tf.scatter_nd(labels[:, None], tf.ones(tf.shape(labels)[0]), (depth,))
x = tf.clip_by_value(x, 0, 1) * (on - off) + off
else:
assert np.isclose(on + off * (depth - 1), 1), (
"All on and off values must sum to 1")
x = tf.one_hot(labels, depth, on_value=on, off_value=off)
data[key_result or key] = x
return data
return _onehot
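
# Example (a hedged sketch): with depth=5 and multi=True, the label vector
# [1, 3] is scattered into a single multi-hot vector:
#
#   data = get_onehot(5)({"labels": tf.constant([1, 3])})
#   # data["labels"] == [0., 1., 0., 1., 0.]
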
@Registry.register("preprocess_ops.keep", "function")
def get_keep(*keys):
"""Keeps only the given keys."""
def _keep(data):
return {k: v for k, v in data.items() if k in keys}
return _keep
@Registry.register("preprocess_ops.drop", "function")
def get_drop(*keys):
"""Drops the given keys."""
def _drop(data):
return {k: v for k, v in data.items() if k not in keys}
return _drop
@Registry.register("preprocess_ops.copy", "function")
def get_copy(inkey, outkey):
"""Copies value of `inkey` into `outkey`."""
def _copy(data):
data[outkey] = data[inkey]
return data
return _copy
@Registry.register("preprocess_ops.randaug", "function")
@utils.InKeyOutKey()
def get_randaug(num_layers: int = 2, magnitude: int = 10):
"""Creates a function that applies RandAugment.
RandAugment is from the paper https://arxiv.org/abs/1909.13719,
Args:
num_layers: Integer, the number of augmentation transformations to apply
sequentially to an image. Represented as (N) in the paper. Usually best
values will be in the range [1, 3].
magnitude: Integer, shared magnitude across all augmentation operations.
Represented as (M) in the paper. Usually best values are in the range [5,
30].
Returns:
A function that applies RandAugment.
"""
def _randaug(image):
return autoaugment.distort_image_with_randaugment(
image=image,
num_layers=num_layers,
magnitude=magnitude,
)
return _randaug
@Registry.register("preprocess_ops.patchify", "function")
@utils.InKeyOutKey()
def patchify(patch_size: Tuple[int, int], stride: Tuple[int, int]):
"""Patchifies image.
If image is of size (h, w, c), patchify it into (h//p*w//p, p*p*c)
Args:
patch_size: Integer.
stride: Integer.
Returns:
A function that applies RandAugment.
"""
def _extract_patches(image):
"""Extracts image patches."""
h, w, _ = image.get_shape().as_list()
count_h = h // stride[0]
count_w = w // stride[1]
# pyformat: disable
image = tf.extract_image_patches(image[None],
[1, patch_size[0], patch_size[1], 1],
[1, stride[0], stride[1], 1],
[1, 1, 1, 1],
padding="VALID")
# pyformat: enable
return tf.reshape(image, [count_h * count_w, -1])
return _extract_patches
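
# Worked example (illustrative): for a (224, 224, 3) image with
# patch_size=(16, 16) and stride=(16, 16), count_h = count_w = 14 and each
# patch is flattened, so the output shape is (196, 16*16*3) = (196, 768).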