|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Functions to read, decode and pre-process input data for the Model.
|
| """
|
| import collections
|
| import functools
|
| import tensorflow as tf
|
| from tensorflow.contrib import slim
|
|
|
| import inception_preprocessing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Bundle of tensors produced by get_data for one batch: the preprocessed
# images, the original (unprocessed) images, the integer labels and their
# one-hot encoding.
InputEndpoints = collections.namedtuple(
    'InputEndpoints', ['images', 'images_orig', 'labels', 'labels_one_hot'])
|
|
|
|
|
|
|
|
|
|
|
|
|
# Parameters forwarded to tf.compat.v1.train.shuffle_batch: the number of
# batching threads, the total queue capacity, and the minimum number of
# elements left in the queue after a dequeue (controls shuffling quality).
ShuffleBatchConfig = collections.namedtuple('ShuffleBatchConfig', [
    'num_batching_threads', 'queue_capacity', 'min_after_dequeue'
])

# Default batching configuration, used by get_data when no explicit
# shuffle_config is supplied.
DEFAULT_SHUFFLE_CONFIG = ShuffleBatchConfig(
    num_batching_threads=8, queue_capacity=3000, min_after_dequeue=1000)
|
|
|
|
|
def augment_image(image):
  """Augmentation the image with a random modification.

  Args:
    image: input Tensor image of rank 3, with the last dimension
           of size 3.

  Returns:
    Distorted Tensor image of the same shape.
  """
  with tf.compat.v1.variable_scope('AugmentImage'):
    out_height = image.get_shape().dims[0].value
    out_width = image.get_shape().dims[1].value

    # Take a random crop covering at least 80% of the image area, with a
    # mildly perturbed aspect ratio.
    crop_begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
        image_size=tf.shape(input=image),
        bounding_boxes=tf.zeros([0, 0, 4]),
        min_object_covered=0.8,
        aspect_ratio_range=[0.8, 1.2],
        area_range=[0.8, 1.0],
        use_image_if_no_bounding_boxes=True)
    distorted = tf.slice(image, crop_begin, crop_size)

    # Resize the crop back to the original static size, picking one of four
    # interpolation methods at random.
    distorted = inception_preprocessing.apply_with_random_selector(
        distorted,
        lambda img, method: tf.image.resize(img, [out_height, out_width],
                                            method),
        num_cases=4)
    distorted.set_shape([out_height, out_width, 3])

    # Apply one of four random color-distortion orderings (slow mode).
    distorted = inception_preprocessing.apply_with_random_selector(
        distorted,
        functools.partial(
            inception_preprocessing.distort_color, fast_mode=False),
        num_cases=4)

    # Color distortion can push values outside the expected range; clamp.
    return tf.clip_by_value(distorted, -1.5, 1.5)
|
|
|
|
|
def central_crop(image, crop_size):
  """Returns a central crop for the specified size of an image.

  Args:
    image: A tensor with shape [height, width, channels]
    crop_size: A tuple (crop_width, crop_height)

  Returns:
    A tensor of shape [crop_height, crop_width, channels].
  """
  with tf.compat.v1.variable_scope('CentralCrop'):
    crop_width, crop_height = crop_size
    img_shape = tf.shape(input=image)
    img_height, img_width = img_shape[0], img_shape[1]
    # Fail at runtime if the image is smaller than the requested crop.
    size_checks = [
        tf.Assert(
            tf.greater_equal(img_height, crop_height),
            ['image_height < target_height', img_height, crop_height]),
        tf.Assert(
            tf.greater_equal(img_width, crop_width),
            ['image_width < target_width', img_width, crop_width]),
    ]
    with tf.control_dependencies(size_checks):
      # Center the crop window within the image.
      offset_y = tf.cast((img_height - crop_height) / 2, tf.int32)
      offset_x = tf.cast((img_width - crop_width) / 2, tf.int32)
      return tf.image.crop_to_bounding_box(image, offset_y, offset_x,
                                           crop_height, crop_width)
|
|
|
|
|
def preprocess_image(image, augment=False, central_crop_size=None,
                     num_towers=4):
  """Normalizes image to have values in a narrow range around zero.

  Args:
    image: a [H x W x 3] uint8 tensor.
    augment: optional, if True do random image distortion.
    central_crop_size: A tuple (crop_width, crop_height).
    num_towers: optional, number of shots of the same image in the input image.

  Returns:
    A float32 tensor of shape [H x W x 3] with RGB values in the required
    range.
  """
  with tf.compat.v1.variable_scope('PreprocessImage'):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    if not (augment or central_crop_size):
      return image

    # The input is a horizontal concatenation of num_towers views; split
    # so cropping/augmentation is applied per view.
    if num_towers == 1:
      views = [image]
    else:
      views = tf.split(value=image, num_or_size_splits=num_towers, axis=1)

    if central_crop_size:
      # The crop width is for the full concatenated image; divide per view.
      per_view_size = (int(central_crop_size[0] / num_towers),
                       central_crop_size[1])
      views = [central_crop(v, per_view_size) for v in views]
    if augment:
      views = [augment_image(v) for v in views]

    # Re-assemble the views side by side.
    return tf.concat(views, 1)
|
|
|
|
|
def get_data(dataset,
             batch_size,
             augment=False,
             central_crop_size=None,
             shuffle_config=None,
             shuffle=True):
  """Wraps calls to DatasetDataProviders and shuffle_batch.

  For more details about supported Dataset objects refer to datasets/fsns.py.

  Args:
    dataset: a slim.data.dataset.Dataset object.
    batch_size: number of samples per batch.
    augment: optional, if True does random image distortion.
    central_crop_size: A tuple (crop_width, crop_height).
    shuffle_config: A namedtuple ShuffleBatchConfig.
    shuffle: if True use data shuffling.

  Returns:
    An InputEndpoints namedtuple holding batched tensors: images
    (preprocessed), images_orig (as read from the dataset), labels, and
    labels_one_hot.
  """
  if not shuffle_config:
    shuffle_config = DEFAULT_SHUFFLE_CONFIG

  # Queue-based reader for individual (image, label) examples; `shuffle`
  # only controls this provider's read order.
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      shuffle=shuffle,
      common_queue_capacity=2 * batch_size,
      common_queue_min=batch_size)
  image_orig, label = provider.get(['image', 'label'])

  image = preprocess_image(
      image_orig, augment, central_crop_size, num_towers=dataset.num_of_views)
  label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes)

  # NOTE(review): batching always goes through shuffle_batch, even when
  # shuffle=False above — confirm that is intended for evaluation runs.
  images, images_orig, labels, labels_one_hot = (tf.compat.v1.train.shuffle_batch(
      [image, image_orig, label, label_one_hot],
      batch_size=batch_size,
      num_threads=shuffle_config.num_batching_threads,
      capacity=shuffle_config.queue_capacity,
      min_after_dequeue=shuffle_config.min_after_dequeue))

  return InputEndpoints(
      images=images,
      images_orig=images_orig,
      labels=labels,
      labels_one_hot=labels_one_hot)
|
|
|