|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """A common dataset reader."""
|
| import dataclasses
|
| import random
|
| from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union
|
|
|
| from absl import logging
|
| import tensorflow as tf, tf_keras
|
| import tensorflow_datasets as tfds
|
|
|
| from official.core import config_definitions as cfg
|
|
|
|
|
| def _get_random_integer():
|
| return random.randint(0, (1 << 31) - 1)
|
|
|
|
|
| def _maybe_map_fn(dataset: tf.data.Dataset,
|
| fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
|
| """Calls dataset.map if a valid function is passed in."""
|
| return dataset if fn is None else dataset.map(
|
| fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
|
|
|
|
def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
  """Expands `input_path` into a flat list of file paths.

  Accepts a single path/pattern string, a comma-separated string of
  paths/patterns, or a list/tuple of either. Entries containing `*` or `?`
  are expanded with `tf.io.gfile.glob`; all other entries are kept verbatim.

  Args:
    input_path: A str, or a list/tuple of str, of file paths/patterns.

  Returns:
    The list of matched file paths.

  Raises:
    ValueError: If `input_path` has an unsupported type, or a glob pattern
      matches nothing, or nothing is matched overall.
  """
  usage = ('`input_path` should be either (1) a str indicating a file '
           'path/pattern, or (2) a str indicating multiple file '
           'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
           '"a,b,c", or (3) a list of str, each of which is a file '
           'path/pattern or multiple file paths/patterns separated by '
           'comma, but got: %s')

  if isinstance(input_path, str):
    path_entries = [input_path]
  elif isinstance(input_path, (list, tuple)):
    if not all(isinstance(entry, str) for entry in input_path):
      raise ValueError(usage % input_path)
    path_entries = input_path
  else:
    raise ValueError(usage % input_path)

  matched_files: List[str] = []
  # NOTE: `input_path` is deliberately reused as the loop variable so the
  # final "no files" error reports the last entry processed.
  for input_path in path_entries:
    for raw_pattern in input_path.strip().split(','):
      pattern = raw_pattern.strip()
      if not pattern:
        continue  # Skip empty segments such as "a,,b".
      if '*' in pattern or '?' in pattern:
        globbed = tf.io.gfile.glob(pattern)
        if not globbed:
          raise ValueError('%s does not match any files.' % pattern)
        matched_files.extend(globbed)
      else:
        matched_files.append(pattern)

  if not matched_files:
    raise ValueError('%s does not match any files.' % input_path)

  return matched_files
|
|
|
|
|
def _read_files_then_shard(matched_files: List[str],
                           dataset_fn,
                           input_context: Optional[
                               tf.distribute.InputContext] = None,
                           sharding: bool = False,
                           repeat: bool = False) -> tf.data.Dataset:
  """Sends all data files to every worker and then shard by data."""
  dataset = dataset_fn(matched_files)

  # Every worker reads the full file list; turn off tf.data's automatic
  # sharding so any sharding below is applied explicitly by element.
  data_options = tf.data.Options()
  data_options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.OFF)
  dataset = dataset.with_options(data_options)

  multi_pipeline = bool(
      input_context and input_context.num_input_pipelines > 1)
  if sharding and multi_pipeline:
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  if repeat:
    dataset = dataset.repeat()
  return dataset
|
|
|
|
|
def _shard_files_then_read(matched_files: List[str],
                           dataset_fn,
                           input_context: Optional[
                               tf.distribute.InputContext] = None,
                           seed: Optional[Union[int, tf.Tensor]] = None,
                           is_training: bool = False,
                           sharding: bool = False,
                           cache: bool = False,
                           cycle_length: Optional[int] = None,
                           block_length: Optional[int] = None,
                           deterministic: bool = False) -> tf.data.Dataset:
  """Shards the data files and then sends a split to every worker to read.

  Args:
    matched_files: A list of matched file paths.
    dataset_fn: A callable (e.g. `tf.data.TFRecordDataset`) mapping a file
      path to a `tf.data.Dataset`; used as the `interleave` map function.
    input_context: An optional `tf.distribute.InputContext` for sharding
      files across input pipelines.
    seed: An optional seed for the file-level shuffle. If sharding is on and
      no seed is given, one is drawn so that all workers shuffle the file
      list identically before taking disjoint shards.
    is_training: Whether this is a training pipeline; enables file shuffling
      and (when not caching) repetition.
    sharding: Whether to shard files across input pipelines.
    cache: Whether the dataset will be cached downstream; suppresses
      reshuffling and repetition here.
    cycle_length: `interleave` cycle length; when set it is also used as the
      parallelism, otherwise AUTOTUNE is used.
    block_length: `interleave` block length.
    deterministic: Whether `interleave` produces elements deterministically.

  Returns:
    A `tf.data.Dataset` of elements read from the sharded files.
  """
  dataset = tf.data.Dataset.from_tensor_slices(matched_files)

  if is_training:
    # When sharding, every worker must see the same file order so that the
    # shards are disjoint; a shared random seed guarantees that.
    if sharding and seed is None:
      seed = _get_random_integer()
    dataset = dataset.shuffle(
        len(matched_files),
        seed=seed,
        # With caching, only the first epoch is read, so reshuffling per
        # iteration is pointless; `not cache` replaces the redundant
        # `True if not cache else False`.
        reshuffle_each_iteration=not cache)

  if sharding and input_context and (input_context.num_input_pipelines > 1):
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  # Repeat the file list (not the elements) so interleave cycles forever
  # during training; caching handles repetition downstream instead.
  if is_training and not cache:
    dataset = dataset.repeat()

  dataset = dataset.interleave(
      map_func=dataset_fn,
      cycle_length=cycle_length,
      block_length=block_length,
      num_parallel_calls=(cycle_length
                          if cycle_length else tf.data.experimental.AUTOTUNE),
      deterministic=deterministic)
  return dataset
|
|
|
|
|
def _read_tfds(tfds_name: Text,
               tfds_data_dir: Text,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
  """Reads a dataset from tfds.

  Args:
    tfds_name: The TFDS dataset name.
    tfds_data_dir: Directory to read/write TFDS data from/to; may be empty.
    tfds_split: The split to load (e.g. 'train').
    tfds_skip_decoding_feature: Comma-separated feature names whose decoding
      should be skipped via `tfds.decode.SkipDecoding`.
    tfds_as_supervised: Whether to load the dataset in (input, label) form.
    input_context: An optional `tf.distribute.InputContext` for sharding.
    seed: Optional shuffle seed passed to `tfds.ReadConfig`.
    is_training: Whether this is a training pipeline; enables file shuffling
      and filename repetition (when not caching).
    cache: Whether the dataset will be cached; disables filename repetition.
    cycle_length: Interleave cycle length for the TFDS reader.
    block_length: Interleave block length for the TFDS reader.

  Returns:
    A `tf.data.Dataset`.
  """
  repeat_filenames = is_training and not cache
  read_config = tfds.ReadConfig(
      interleave_cycle_length=cycle_length,
      interleave_block_length=block_length,
      input_context=input_context,
      shuffle_seed=seed,
      repeat_filenames=repeat_filenames,
      # Cardinality can only be asserted for a finite (non-repeating) stream.
      assert_cardinality=not repeat_filenames,
      # Prefetching is applied later by the caller, so skip it here.
      skip_prefetch=True)

  decoders = {}
  if tfds_skip_decoding_feature:
    for skip_feature in tfds_skip_decoding_feature.split(','):
      decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()

  if tfds_name.startswith('mldataset.'):
    # NOTE(review): names with the 'mldataset.' prefix bypass the builder
    # lookup below — presumably an internal dataset namespace; verify.
    dataset = tfds.load(name=tfds_name,
                        split=tfds_split,
                        as_supervised=tfds_as_supervised,
                        decoders=decoders if decoders else None,
                        read_config=read_config)
  else:
    builder = tfds.builder(tfds_name, data_dir=tfds_data_dir)
    if builder.info.splits:
      num_shards = len(builder.info.splits[tfds_split].file_instructions)
    else:
      # Split info unavailable; assume a single shard.
      num_shards = 1
    load_kwargs = dict(
        name=tfds_name, download=True, split=tfds_split,
        shuffle_files=is_training, as_supervised=tfds_as_supervised,
        decoders=decoders if decoders else None)
    if tfds_data_dir:
      load_kwargs.update({'data_dir': tfds_data_dir})

    if input_context and num_shards < input_context.num_input_pipelines:
      # There are fewer files than input pipelines, so file-based sharding
      # via the read config cannot give every pipeline data. Load the full
      # dataset on each pipeline instead, then shard by element.
      read_config = dataclasses.replace(read_config, input_context=None)
      load_kwargs.update({'read_config': read_config})
      dataset = tfds.load(**load_kwargs)
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    else:
      load_kwargs.update({'read_config': read_config})
      dataset = tfds.load(**load_kwargs)
  return dataset
|
|
|
|
|
class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""

  # A fixed random number drawn once per process and shared by all instances.
  # It is folded into tf.data service job names below so that jobs from the
  # same process share a name while distinct processes get distinct names.
  static_randnum = _get_random_integer()

  def __init__(
      self,
      params: cfg.DataConfig,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn: Optional[Callable[..., Any]] = None,
      combine_fn: Optional[Callable[..., Any]] = None,
      sample_fn: Optional[Callable[..., Any]] = None,
      parser_fn: Optional[Callable[..., Any]] = None,
      filter_fn: Optional[Callable[..., tf.Tensor]] = None,
      transform_and_batch_fn: Optional[
          Callable[
              [tf.data.Dataset, Optional[tf.distribute.InputContext]],
              tf.data.Dataset,
          ]
      ] = None,
      postprocess_fn: Optional[Callable[..., Any]] = None,
  ):
    """Initializes an InputReader instance.

    Args:
      params: A config_definitions.DataConfig object.
      dataset_fn: A `tf.data.Dataset` that consumes the input files. For
        example, it can be `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
      combine_fn: An optional `callable` that takes a dictionary of
        `tf.data.Dataset` objects as input and outputs a combined dataset. It
        will be executed after the decoder_fn and before the sample_fn.
      sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
        input and outputs the transformed dataset. It performs sampling on the
        decoded raw tensors dict before the parser_fn.
      parser_fn: An optional `callable` that takes the decoded raw tensors dict
        and parse them into a dictionary of tensors that can be consumed by the
        model. It will be executed after decoder_fn.
      filter_fn: An optional `callable` mapping a dataset element to a boolean.
        It will be executed after parser_fn.
      transform_and_batch_fn: An optional `callable` that takes a
        `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
        input, and returns a `tf.data.Dataset` object. It will be executed after
        `parser_fn` to transform and batch the dataset; if None, after
        `parser_fn` is executed, the dataset will be batched into per-replica
        batch size.
      postprocess_fn: An optional `callable` that processes batched tensors. It
        will be executed after batching.
    """
    # `input_path` (files) and `tfds_name` (TFDS) are mutually exclusive
    # data sources.
    if params.input_path and params.tfds_name:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' %
                       (params.input_path, params.tfds_name))

    # A dict-valued source (a Config of named sub-sources) produces a dict of
    # datasets, which must be merged by a caller-supplied combine_fn.
    if (isinstance(params.input_path, cfg.base_config.Config) or
        isinstance(params.tfds_name, cfg.base_config.Config)
        ) and combine_fn is None:
      raise ValueError(
          'A combine_fn is required if `input_path` or `tfds_name` is a dict.')

    self._tfds_name = params.tfds_name
    self._tfds_data_dir = params.tfds_data_dir
    self._matched_files = None
    if not params.input_path:
      # Reading from TFDS: a split must be specified.
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
    else:
      # Reading from files: resolve the file list up front.
      self._matched_files = self.get_files(params.input_path)

    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
    self._drop_remainder = params.drop_remainder
    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._deterministic = params.deterministic
    self._sharding = params.sharding
    self._tfds_split = params.tfds_split
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._combine_fn = combine_fn
    self._sample_fn = sample_fn
    self._parser_fn = parser_fn
    self._transform_and_batch_fn = transform_and_batch_fn
    self._postprocess_fn = postprocess_fn
    self._filter_fn = filter_fn
    self._seed = params.seed
    self._prefetch_buffer_size = (
        params.prefetch_buffer_size or tf.data.experimental.AUTOTUNE)
    self._autotune_algorithm = params.autotune_algorithm

    # When tf.data service is requested, disable the fixed seed and manual
    # sharding: the service distributes data itself, and a shared seed would
    # make the distributed workers produce overlapping data.
    if params.enable_tf_data_service:
      self._seed = None
      self._sharding = False

    # The service is only effectively enabled when an address is configured.
    self._enable_tf_data_service = (
        params.enable_tf_data_service and params.tf_data_service_address)
    self._tf_data_service_address = params.tf_data_service_address
    self._enable_shared_tf_data_service_between_parallel_trainers = (
        params.enable_shared_tf_data_service_between_parallel_trainers)
    self._apply_tf_data_service_before_batching = (
        params.apply_tf_data_service_before_batching)
    self._trainer_id = params.trainer_id
    if self._enable_tf_data_service:
      # Add a per-process random number and the batch size to the job name so
      # that concurrent jobs with different configs do not collide.
      self._tf_data_service_job_name = (
          f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
          f'{self.static_randnum}')
      self._enable_round_robin_tf_data_service = params.get(
          'enable_round_robin_tf_data_service', False)
      if self._enable_shared_tf_data_service_between_parallel_trainers:
        # Parallel trainers share one job, so the batch size is dropped from
        # the name — trainers may differ in batch size yet share the data.
        self._tf_data_service_job_name = (
            f'{params.tf_data_service_job_name}_{self.static_randnum}')

  def get_files(self, input_path):
    """Gets matched files. Can be overridden by subclasses."""
    if not input_path:
      return None
    # A Config input produces a dict of {name: matched file list}.
    if isinstance(input_path, cfg.base_config.Config):
      matched_files = {}
      for k, v in input_path.as_dict().items():
        matched_files[k] = match_files(v)
    else:
      matched_files = match_files(input_path)
    return matched_files

  def _read_data_source(
      self,
      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      input_context: Optional[tf.distribute.InputContext] = None,
  ):
    """Reads the data source (files/tfds) to a dataset."""

    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
      # Turns a list of file paths into a dataset, choosing between
      # file-level sharding (preferred) and element-level sharding.
      if len(files) > 1:
        if input_context and (len(files) < input_context.num_input_pipelines):
          # Not enough files to give each pipeline at least one; fall back
          # to reading everything everywhere and sharding by element.
          # NOTE(review): logging.warn is deprecated; logging.warning is the
          # modern spelling.
          logging.warn(
              (
                  'The number of files %d is less than the number of input '
                  'pipelines %d. We will send all input files to every worker. '
                  'Please consider sharding your data into more files.'
              ),
              len(files),
              input_context.num_input_pipelines,
          )
          return _read_files_then_shard(
              files,
              dataset_fn,
              input_context,
              sharding=self._sharding,
              repeat=self._is_training and not self._cache)
        else:
          return _shard_files_then_read(
              files,
              dataset_fn,
              input_context,
              seed=self._seed,
              is_training=self._is_training,
              sharding=self._sharding,
              cache=self._cache,
              cycle_length=self._cycle_length,
              block_length=self._block_length,
              deterministic=self._deterministic)
      elif len(files) == 1:
        # A single file cannot be sharded by file; shard by element.
        return _read_files_then_shard(
            files,
            dataset_fn,
            input_context,
            sharding=self._sharding,
            repeat=self._is_training and not self._cache)
      else:
        raise ValueError('It is unexpected that `tfds_builder` is None and '
                         'there is also no `files`.')

    if self._tfds_name:
      # TFDS source; a Config of names yields a dict of datasets.
      if isinstance(self._tfds_name, cfg.base_config.Config):
        dataset = {}
        for k, tfds_name in self._tfds_name.as_dict().items():
          dataset[k] = _read_tfds(
              tfds_name=tfds_name,
              tfds_data_dir=self._tfds_data_dir,
              tfds_split=self._tfds_split,
              tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
              tfds_as_supervised=self._tfds_as_supervised,
              input_context=input_context,
              seed=self._seed,
              is_training=self._is_training,
              cache=self._cache,
              cycle_length=self._cycle_length,
              block_length=self._block_length)
      else:
        dataset = _read_tfds(
            tfds_name=self._tfds_name,
            tfds_data_dir=self._tfds_data_dir,
            tfds_split=self._tfds_split,
            tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
            tfds_as_supervised=self._tfds_as_supervised,
            input_context=input_context,
            seed=self._seed,
            is_training=self._is_training,
            cache=self._cache,
            cycle_length=self._cycle_length,
            block_length=self._block_length)
    elif isinstance(matched_files, (list, tuple)):
      dataset = _files_to_dataset(matched_files)
    elif isinstance(matched_files, dict):
      # Dict of named file lists -> dict of datasets (combined later).
      dataset = {}
      for k, fs in matched_files.items():
        dataset[k] = _files_to_dataset(fs)
    else:
      raise ValueError('`matched_files` should be a list or dict.')

    return dataset

  def _decode_and_parse_dataset(
      self,
      dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
      batch_size: int,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Returns a tf.data.Dataset object after shuffling, decoding, and parsing."""

    def _shuffle_and_decode(ds):
      # When caching, shuffling is deferred until after cache() below so
      # that the cached data is shuffled per epoch.
      if self._is_training and not self._cache:
        ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
      # Decode the serialized examples into raw tensors.
      ds = _maybe_map_fn(ds, self._decoder_fn)
      return ds

    # Applies shuffle+decode to each dataset when `dataset` is a dict.
    dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
    if tf.nest.is_nested(dataset):
      dataset = self._combine_fn(dataset)

    if self._sample_fn is not None:
      dataset = dataset.apply(self._sample_fn)
    dataset = _maybe_map_fn(dataset, self._parser_fn)

    if self._filter_fn is not None:
      dataset = dataset.filter(self._filter_fn)

    if self._cache:
      # Cache parsed examples; repeat and shuffle after the cache so each
      # epoch revisits the cached data in a new order.
      dataset = dataset.cache()
      if self._is_training:
        dataset = dataset.repeat()
        dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)

    # Optionally hand the (still unbatched) dataset to tf.data service so
    # parallel trainers sharing the service cache see unbatched elements.
    if (self._enable_shared_tf_data_service_between_parallel_trainers and
        self._apply_tf_data_service_before_batching):
      dataset = self._maybe_apply_data_service(dataset, input_context)

    if self._transform_and_batch_fn is not None:
      dataset = self._transform_and_batch_fn(dataset, input_context)
    else:
      per_replica_batch_size = input_context.get_per_replica_batch_size(
          batch_size) if input_context else batch_size
      dataset = dataset.batch(
          per_replica_batch_size, drop_remainder=self._drop_remainder)

    return dataset

  def _maybe_apply_data_service(
      self,
      dataset: tf.data.Dataset,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Potentially distributes a dataset."""
    if self._enable_tf_data_service and input_context:
      if self._enable_round_robin_tf_data_service:
        # One service consumer per replica: each input pipeline hosts
        # `replicas_per_input_pipeline` consumers with consecutive indices.
        replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
            input_context.num_input_pipelines)
        base_consumer_index = input_context.input_pipeline_id * (
            replicas_per_input_pipeline)
        num_consumers = input_context.num_input_pipelines * (
            replicas_per_input_pipeline)
        range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
        tfds_kwargs = {
            'processing_mode': 'parallel_epochs',
            'service': self._tf_data_service_address,
            'job_name': self._tf_data_service_job_name,
            'num_consumers': num_consumers
        }
        if self._enable_shared_tf_data_service_between_parallel_trainers:
          raise ValueError('Shared tf.data service does not support round-robin'
                           ' tf.data service.')
        dataset = range_dataset.map(lambda i: dataset.apply(
            tf.data.experimental.service.distribute(
                consumer_index=base_consumer_index + i, **tfds_kwargs)))
        # Interleave the per-consumer datasets deterministically so each
        # replica of this pipeline receives its own consumer's stream.
        dataset = dataset.interleave(
            lambda x: x,
            cycle_length=replicas_per_input_pipeline,
            num_parallel_calls=replicas_per_input_pipeline,
            deterministic=True)
      else:
        tfds_kwargs = {
            'processing_mode': 'parallel_epochs',
            'service': self._tf_data_service_address,
            'job_name': self._tf_data_service_job_name,
        }
        if self._enable_shared_tf_data_service_between_parallel_trainers:
          # Shared mode: disable service-side sharding and enable the
          # cross-trainer cache keyed by trainer_id.
          tfds_kwargs.update({
              'processing_mode':
                  tf.data.experimental.service.ShardingPolicy.OFF,
              'cross_trainer_cache':
                  tf.data.experimental.service.CrossTrainerCache(
                      trainer_id=self._trainer_id)
          })
        dataset = dataset.apply(
            tf.data.experimental.service.distribute(**tfds_kwargs))
    return dataset

  def read(self,
           input_context: Optional[tf.distribute.InputContext] = None,
           dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""
    if dataset is None:
      dataset = self._read_data_source(self._matched_files, self._dataset_fn,
                                       input_context)
    dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
                                             input_context)
    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
    # Skip the service here if it was already applied before batching in
    # _decode_and_parse_dataset.
    if not (self._enable_shared_tf_data_service_between_parallel_trainers and
            self._apply_tf_data_service_before_batching):
      dataset = self._maybe_apply_data_service(dataset, input_context)

    if self._deterministic is not None:
      options = tf.data.Options()
      options.deterministic = self._deterministic
      dataset = dataset.with_options(options)
    if self._autotune_algorithm:
      options = tf.data.Options()
      options.autotune.autotune_algorithm = (
          tf.data.experimental.AutotuneAlgorithm[self._autotune_algorithm])
      dataset = dataset.with_options(options)
    return dataset.prefetch(self._prefetch_buffer_size)
|
|
|