|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| r"""Downloads and converts a particular dataset.
|
|
|
| Usage:
|
| ```shell
|
|
|
| $ python download_and_convert_data.py \
|
| --dataset_name=flowers \
|
| --dataset_dir=/tmp/flowers
|
|
|
| $ python download_and_convert_data.py \
|
| --dataset_name=cifar10 \
|
| --dataset_dir=/tmp/cifar10
|
|
|
| $ python download_and_convert_data.py \
|
| --dataset_name=mnist \
|
| --dataset_dir=/tmp/mnist
|
|
|
| $ python download_and_convert_data.py \
|
| --dataset_name=visualwakewords \
|
| --dataset_dir=/tmp/visualwakewords
|
|
|
| ```
|
| """
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import tensorflow.compat.v1 as tf
|
|
|
| from datasets import download_and_convert_cifar10
|
| from datasets import download_and_convert_flowers
|
| from datasets import download_and_convert_mnist
|
| from datasets import download_and_convert_visualwakewords
|
|
|
# Shared handle to the command-line flags; main() reads its values from here.
FLAGS = tf.app.flags.FLAGS

# All flags are registered through the same `tf.app.flags` accessor that
# provides FLAGS above, instead of mixing it with the `tf.flags` alias.

# No default: main() rejects an empty value.
tf.app.flags.DEFINE_string(
    'dataset_name',
    None,
    'The name of the dataset to convert, one of "flowers", "cifar10", "mnist", '
    '"visualwakewords"')

# No default: main() rejects an empty value.
tf.app.flags.DEFINE_string(
    'dataset_dir',
    None,
    'The directory where the output TFRecords and temporary files are saved.')

# The two flags below are only forwarded to the visualwakewords converter.
tf.app.flags.DEFINE_float(
    'small_object_area_threshold', 0.005,
    'For --dataset_name=visualwakewords only. Threshold of fraction of image '
    'area below which small objects are filtered')

tf.app.flags.DEFINE_string(
    'foreground_class_of_interest', 'person',
    'For --dataset_name=visualwakewords only. Build a binary classifier based '
    'on the presence or absence of this object in the image.')
|
|
|
|
|
def main(_):
  """Checks the required flags and runs the converter for the chosen dataset.

  Args:
    _: Unused positional argument supplied by the app runner.

  Raises:
    ValueError: If --dataset_name or --dataset_dir is missing or empty, or if
      --dataset_name does not match a supported dataset.
  """
  dataset_name = FLAGS.dataset_name
  if not dataset_name:
    raise ValueError('You must supply the dataset name with --dataset_name')

  dataset_dir = FLAGS.dataset_dir
  if not dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  # Converters whose run() takes only the output directory.
  single_arg_converters = {
      'flowers': download_and_convert_flowers,
      'cifar10': download_and_convert_cifar10,
      'mnist': download_and_convert_mnist,
  }

  if dataset_name in single_arg_converters:
    single_arg_converters[dataset_name].run(dataset_dir)
    return

  if dataset_name == 'visualwakewords':
    # The extra flags are read only on this path, where they are used.
    download_and_convert_visualwakewords.run(
        dataset_dir,
        FLAGS.small_object_area_threshold,
        FLAGS.foreground_class_of_interest)
    return

  raise ValueError(
      'dataset_name [%s] was not recognized.' % dataset_name)
|
|
|
# Script entry point: hand main() to the TensorFlow app runner explicitly.
if __name__ == '__main__':
  tf.app.run(main=main)
|
|
|