|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Wrapper for providing semantic segmentaion data.
|
|
|
| The Dataset class provides both images and annotations (semantic
|
| segmentation and/or instance segmentation) for TensorFlow. Currently, we
|
| support the following datasets:
|
|
|
| 1. PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/voc2012/).
|
|
|
| PASCAL VOC 2012 semantic segmentation dataset annotates 20 foreground objects
|
| (e.g., bike, person, and so on) and leaves all the other semantic classes as
|
| one background class. The dataset contains 1464, 1449, and 1456 annotated
|
| images for the training, validation, and test sets, respectively.
|
|
|
| 2. Cityscapes dataset (https://www.cityscapes-dataset.com).
|
|
|
| The Cityscapes dataset contains 19 semantic labels (such as road, person, car,
|
| and so on) for urban street scenes.
|
|
|
| 3. ADE20K dataset (http://groups.csail.mit.edu/vision/datasets/ADE20K).
|
|
|
| The ADE20K dataset contains 150 semantic labels for both urban street scenes
|
| and indoor scenes.
|
|
|
| References:
|
| M. Everingham, S. M. A. Eslami, L. V. Gool, C. K. I. Williams, J. Winn,
|
| and A. Zisserman, "The PASCAL visual object classes challenge: A
|
| retrospective," IJCV, 2014.
|
|
|
| M. Cordts, M. Omran, S. Ramos, T. Rehfeld, M. Enzweiler, R. Benenson,
|
| U. Franke, S. Roth, and B. Schiele, "The cityscapes dataset for semantic urban
|
| scene understanding," In Proc. of CVPR, 2016.
|
|
|
| B. Zhou, H. Zhao, X. Puig, S. Fidler, A. Barriuso, A. Torralba, "Scene Parsing
|
| through ADE20K dataset," In Proc. of CVPR, 2017.
|
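| Example usage (a minimal sketch; the dataset directory and the sizes below
|
| are placeholders, and a TF 1.x runtime is assumed):
|
|   dataset = Dataset(
|
|       dataset_name='pascal_voc_seg',
|
|       split_name='train',
|
|       dataset_dir='/path/to/tfrecords',  # Hypothetical path.
|
|       batch_size=8,
|
|       crop_size=[513, 513],
|
|       is_training=True,
|
|       should_shuffle=True,
|
|       should_repeat=True)
|
|   samples = dataset.get_one_shot_iterator().get_next()
|
|   with tf.Session() as sess:
|
|     np_samples = sess.run(samples)
|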
| """
|
|
|
| import collections
|
| import os
|
| import tensorflow as tf
|
| from deeplab import common
|
| from deeplab import input_preprocess
|
|
|
|
|
| DatasetDescriptor = collections.namedtuple(
|
| 'DatasetDescriptor',
|
| [
|
| 'splits_to_sizes',  # Splits of the dataset into training, val and test.
|
| 'num_classes',  # Number of semantic classes, including the background
|
| # class (if exists). For example, there are 20 foreground classes + 1
|
| # background class in the PASCAL VOC 2012 dataset; thus num_classes=21.
|
| 'ignore_label',  # Ignore label value.
|
| ])
|
|
|
| _CITYSCAPES_INFORMATION = DatasetDescriptor(
|
| splits_to_sizes={'train_fine': 2975,
|
| 'train_coarse': 22973,
|
| 'trainval_fine': 3475,
|
| 'trainval_coarse': 23473,
|
| 'val_fine': 500,
|
| 'test_fine': 1525},
|
| num_classes=19,
|
| ignore_label=255,
|
| )
|
|
|
| _PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
|
| splits_to_sizes={
|
| 'train': 1464,
|
| 'train_aug': 10582,
|
| 'trainval': 2913,
|
| 'val': 1449,
|
| },
|
| num_classes=21,
|
| ignore_label=255,
|
| )
|
|
|
| _ADE20K_INFORMATION = DatasetDescriptor(
|
| splits_to_sizes={
|
| 'train': 20210,
|
| 'val': 2000,
|
| },
|
| num_classes=151,
|
| ignore_label=0,
|
| )
|
|
|
| _DATASETS_INFORMATION = {
|
| 'cityscapes': _CITYSCAPES_INFORMATION,
|
| 'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
|
| 'ade20k': _ADE20K_INFORMATION,
|
| }
|
|
|
|
|
| _FILE_PATTERN = '%s-*'
|
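| # For example, with dataset_dir='/datasets/pascal' (a hypothetical path) and
|
| # split_name='train', the glob in _get_all_files() matches sharded TFRecord
|
| # files such as /datasets/pascal/train-00000-of-00004.tfrecord; the exact
|
| # shard names depend on how the conversion script was run.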
|
|
|
|
| def get_cityscapes_dataset_name():
|
| return 'cityscapes'
|
|
|
|
|
| class Dataset(object):
|
| """Represents input dataset for deeplab model."""
|
|
|
| def __init__(self,
|
| dataset_name,
|
| split_name,
|
| dataset_dir,
|
| batch_size,
|
| crop_size,
|
| min_resize_value=None,
|
| max_resize_value=None,
|
| resize_factor=None,
|
| min_scale_factor=1.,
|
| max_scale_factor=1.,
|
| scale_factor_step_size=0,
|
| model_variant=None,
|
| num_readers=1,
|
| is_training=False,
|
| should_shuffle=False,
|
| should_repeat=False):
|
| """Initializes the dataset.
|
|
|
| Args:
|
| dataset_name: Dataset name.
|
| split_name: A train/val split name.
|
| dataset_dir: The directory of the dataset sources.
|
| batch_size: Batch size.
|
| crop_size: The size used to crop the image and label.
|
| min_resize_value: Desired size of the smaller image side.
|
| max_resize_value: Maximum allowed size of the larger image side.
|
| resize_factor: Resized dimensions are multiple of factor plus one.
|
| min_scale_factor: Minimum scale factor value.
|
| max_scale_factor: Maximum scale factor value.
|
| scale_factor_step_size: The step size from min scale factor to max scale
|
| factor. The input is randomly scaled based on the value of
|
| (min_scale_factor, max_scale_factor, scale_factor_step_size).
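|
| For example, (0.5, 2.0, 0.25) yields scales in {0.5, 0.75, ..., 2.0}.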
|
| model_variant: Model variant (string) for choosing how to mean-subtract
|
| the images. See feature_extractor.network_map for supported model
|
| variants.
|
| num_readers: Number of readers for data provider.
|
| is_training: Boolean, whether the dataset is used for training.
|
| should_shuffle: Boolean, whether to shuffle the input data.
|
| should_repeat: Boolean, whether to repeat the input data.
|
|
|
| Raises:
|
| ValueError: If the dataset name or the split name is not supported.
|
| """
|
| if dataset_name not in _DATASETS_INFORMATION:
|
| raise ValueError('The specified dataset is not supported yet.')
|
| self.dataset_name = dataset_name
|
|
|
| splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
|
|
|
| if split_name not in splits_to_sizes:
|
| raise ValueError('data split name %s not recognized' % split_name)
|
|
|
| if model_variant is None:
|
| tf.logging.warning('Please specify a model_variant. See '
|
| 'feature_extractor.network_map for supported model '
|
| 'variants.')
|
|
|
| self.split_name = split_name
|
| self.dataset_dir = dataset_dir
|
| self.batch_size = batch_size
|
| self.crop_size = crop_size
|
| self.min_resize_value = min_resize_value
|
| self.max_resize_value = max_resize_value
|
| self.resize_factor = resize_factor
|
| self.min_scale_factor = min_scale_factor
|
| self.max_scale_factor = max_scale_factor
|
| self.scale_factor_step_size = scale_factor_step_size
|
| self.model_variant = model_variant
|
| self.num_readers = num_readers
|
| self.is_training = is_training
|
| self.should_shuffle = should_shuffle
|
| self.should_repeat = should_repeat
|
|
|
| self.num_of_classes = _DATASETS_INFORMATION[self.dataset_name].num_classes
|
| self.ignore_label = _DATASETS_INFORMATION[self.dataset_name].ignore_label
|
|
|
| def _parse_function(self, example_proto):
|
| """Function to parse the example proto.
|
|
|
| Args:
|
| example_proto: Proto in the format of tf.Example.
|
|
|
| Returns:
|
| A dictionary with parsed image, label, height, width and image name.
|
|
|
| Raises:
|
| ValueError: Label is of wrong shape.
|
| """
|
|
|
| # Only JPEG and PNG inputs are supported. The format is dispatched
|
| # explicitly because tf.image.decode_image does not return a static
|
| # shape, and the shape information is needed to handle the label below.
|
| def _decode_image(content, channels):
|
| return tf.cond(
|
| tf.image.is_jpeg(content),
|
| lambda: tf.image.decode_jpeg(content, channels),
|
| lambda: tf.image.decode_png(content, channels))
|
|
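| # These feature keys are assumed to mirror the ones written by the dataset
|
| # conversion scripts (e.g., deeplab/datasets/build_voc2012_data.py).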
|
| features = {
|
| 'image/encoded':
|
| tf.FixedLenFeature((), tf.string, default_value=''),
|
| 'image/filename':
|
| tf.FixedLenFeature((), tf.string, default_value=''),
|
| 'image/format':
|
| tf.FixedLenFeature((), tf.string, default_value='jpeg'),
|
| 'image/height':
|
| tf.FixedLenFeature((), tf.int64, default_value=0),
|
| 'image/width':
|
| tf.FixedLenFeature((), tf.int64, default_value=0),
|
| 'image/segmentation/class/encoded':
|
| tf.FixedLenFeature((), tf.string, default_value=''),
|
| 'image/segmentation/class/format':
|
| tf.FixedLenFeature((), tf.string, default_value='png'),
|
| }
|
|
|
| parsed_features = tf.parse_single_example(example_proto, features)
|
|
|
| image = _decode_image(parsed_features['image/encoded'], channels=3)
|
|
|
| label = None
|
| if self.split_name != common.TEST_SET:
|
| label = _decode_image(
|
| parsed_features['image/segmentation/class/encoded'], channels=1)
|
|
|
| image_name = parsed_features['image/filename']
|
| if image_name is None:
|
| image_name = tf.constant('')
|
|
|
| sample = {
|
| common.IMAGE: image,
|
| common.IMAGE_NAME: image_name,
|
| common.HEIGHT: parsed_features['image/height'],
|
| common.WIDTH: parsed_features['image/width'],
|
| }
|
|
|
| if label is not None:
|
| if label.get_shape().ndims == 2:
|
| label = tf.expand_dims(label, 2)
|
| elif label.get_shape().ndims == 3 and label.shape.dims[2] == 1:
|
| pass
|
| else:
|
| raise ValueError('Input label shape must be [height, width], or '
|
| '[height, width, 1].')
|
|
|
| label.set_shape([None, None, 1])
|
|
|
| sample[common.LABELS_CLASS] = label
|
|
|
| return sample
|
|
|
| def _preprocess_image(self, sample):
|
| """Preprocesses the image and label.
|
|
|
| Args:
|
| sample: A sample containing image and label.
|
|
|
| Returns:
|
| sample: Sample with preprocessed image and label.
|
|
|
| Raises:
|
| ValueError: Ground truth label not provided during training.
|
| """
|
| image = sample[common.IMAGE]
|
| label = sample[common.LABELS_CLASS]
|
|
|
| original_image, image, label = input_preprocess.preprocess_image_and_label(
|
| image=image,
|
| label=label,
|
| crop_height=self.crop_size[0],
|
| crop_width=self.crop_size[1],
|
| min_resize_value=self.min_resize_value,
|
| max_resize_value=self.max_resize_value,
|
| resize_factor=self.resize_factor,
|
| min_scale_factor=self.min_scale_factor,
|
| max_scale_factor=self.max_scale_factor,
|
| scale_factor_step_size=self.scale_factor_step_size,
|
| ignore_label=self.ignore_label,
|
| is_training=self.is_training,
|
| model_variant=self.model_variant)
|
|
|
| sample[common.IMAGE] = image
|
|
|
| if not self.is_training:
|
| # Original image is only used during visualization.
|
| sample[common.ORIGINAL_IMAGE] = original_image
|
|
|
| if label is not None:
|
| sample[common.LABEL] = label
|
|
|
| # Remove common.LABELS_CLASS key in the sample since it is only used to
|
| # derive label and not used in training and evaluation.
|
| sample.pop(common.LABELS_CLASS, None)
|
|
|
| return sample
|
|
|
| def get_one_shot_iterator(self):
|
| """Gets an iterator that iterates across the dataset once.
|
|
|
| Returns:
|
| An iterator of type tf.data.Iterator.
|
| """
|
|
|
| files = self._get_all_files()
|
|
|
| dataset = (
|
| tf.data.TFRecordDataset(files, num_parallel_reads=self.num_readers)
|
| .map(self._parse_function, num_parallel_calls=self.num_readers)
|
| .map(self._preprocess_image, num_parallel_calls=self.num_readers))
|
|
|
| if self.should_shuffle:
|
| dataset = dataset.shuffle(buffer_size=100)
|
|
|
| if self.should_repeat:
|
| dataset = dataset.repeat()
|
| else:
|
| dataset = dataset.repeat(1)
|
|
|
| dataset = dataset.batch(self.batch_size).prefetch(self.batch_size)
|
| return dataset.make_one_shot_iterator()
|
|
|
| def _get_all_files(self):
|
| """Gets all the files to read data from.
|
|
|
| Returns:
|
| A list of input files.
|
| """
|
| file_pattern = _FILE_PATTERN
|
| file_pattern = os.path.join(self.dataset_dir,
|
| file_pattern % self.split_name)
|
| return tf.gfile.Glob(file_pattern)
|
|
|