# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A common dataset reader."""

from typing import Any, Callable, List, Optional

import tensorflow as tf
import tensorflow_datasets as tfds

from official.modeling.hyperparams import config_definitions as cfg
class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""

  def __init__(self,
               params: cfg.DataConfig,
               shards: Optional[List[str]] = None,
               dataset_fn=tf.data.TFRecordDataset,
               decoder_fn: Optional[Callable[..., Any]] = None,
               parser_fn: Optional[Callable[..., Any]] = None,
               dataset_transform_fn: Optional[Callable[[tf.data.Dataset],
                                                       tf.data.Dataset]] = None,
               postprocess_fn: Optional[Callable[..., Any]] = None):
    """Initializes an InputReader instance.

    Args:
      params: A config_definitions.DataConfig object.
      shards: A list of files to be read. If given, read from these files.
        Otherwise, read from params.input_path.
      dataset_fn: A `tf.data.Dataset` that consumes the input files. For
        example, it can be `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
      parser_fn: An optional `callable` that takes the decoded raw tensors dict
        and parse them into a dictionary of tensors that can be consumed by the
        model. It will be executed after decoder_fn.
      dataset_transform_fn: An optional `callable` that takes a
        `tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be
        executed after parser_fn.
      postprocess_fn: A optional `callable` that processes batched tensors. It
        will be executed after batching.

    Raises:
      ValueError: If both `input_path` and `tfds_name` are specified, if an
        input pattern matches no files, or if `tfds_name` is given without
        `tfds_split`.
    """
    if params.input_path and params.tfds_name:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' % (
                           params.input_path, params.tfds_name))
    self._shards = shards
    self._tfds_builder = None
    if self._shards:
      # Explicit file list wins over both `input_path` and `tfds_name`.
      self._num_files = len(self._shards)
    elif not params.tfds_name:
      # Comma-separated glob patterns; every pattern must match something.
      self._input_patterns = params.input_path.strip().split(',')
      self._num_files = 0
      for input_pattern in self._input_patterns:
        input_pattern = input_pattern.strip()
        if not input_pattern:
          continue
        matched_files = tf.io.gfile.glob(input_pattern)
        if not matched_files:
          raise ValueError('%s does not match any files.' % input_pattern)
        self._num_files += len(matched_files)
      if self._num_files == 0:
        raise ValueError('%s does not match any files.' % params.input_path)
    else:
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
      self._tfds_builder = tfds.builder(
          params.tfds_name, data_dir=params.tfds_data_dir)

    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
    self._drop_remainder = params.drop_remainder
    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._sharding = params.sharding
    self._examples_consume = params.examples_consume
    self._tfds_split = params.tfds_split
    self._tfds_download = params.tfds_download
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._parser_fn = parser_fn
    self._dataset_transform_fn = dataset_transform_fn
    self._postprocess_fn = postprocess_fn

  def _read_sharded_files(
      self,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Reads a dataset interleaved from multiple input files.

    Args:
      input_context: An optional `tf.distribute.InputContext` used to shard
        the file list across input pipelines when sharding is enabled.

    Returns:
      A `tf.data.Dataset` of (still-serialized) records.
    """
    # Read from `self._shards` if it is provided.
    if self._shards:
      dataset = tf.data.Dataset.from_tensor_slices(self._shards)
    else:
      dataset = tf.data.Dataset.list_files(
          self._input_patterns, shuffle=self._is_training)
    # Shard at the file level so each input pipeline opens only its subset.
    if self._sharding and input_context and (
        input_context.num_input_pipelines > 1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    # Repeat the file list (not the records) so training never runs dry.
    if self._is_training:
      dataset = dataset.repeat()

    dataset = dataset.interleave(
        map_func=self._dataset_fn,
        cycle_length=self._cycle_length,
        block_length=self._block_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

  def _read_single_file(
      self,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Reads a dataset from a single file.

    Args:
      input_context: An optional `tf.distribute.InputContext` used for
        record-level sharding when sharding is enabled.

    Returns:
      A `tf.data.Dataset` of (still-serialized) records.
    """
    # Read from `self._shards` if it is provided.
    dataset = self._dataset_fn(self._shards or self._input_patterns)

    # When `input_file` is a path to a single file, disable auto sharding
    # so that same input file is sent to all workers.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF)
    dataset = dataset.with_options(options)
    # With one file, sharding has to happen at the record level instead.
    if self._sharding and input_context and (
        input_context.num_input_pipelines > 1):
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    if self._is_training:
      dataset = dataset.repeat()
    return dataset

  def _read_tfds(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Reads a dataset from tfds.

    Args:
      input_context: An optional `tf.distribute.InputContext`, forwarded to
        tfds via `tfds.ReadConfig` so tfds handles sharding itself.

    Returns:
      A `tf.data.Dataset` produced by the tfds builder.
    """
    if self._tfds_download:
      self._tfds_builder.download_and_prepare()

    read_config = tfds.ReadConfig(
        interleave_cycle_length=self._cycle_length,
        interleave_block_length=self._block_length,
        input_context=input_context)
    decoders = {}
    # Skip tfds' automatic decoding for the listed (comma-separated) features
    # so a custom decoder_fn can handle them.
    if self._tfds_skip_decoding_feature:
      for skip_feature in self._tfds_skip_decoding_feature.split(','):
        decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = self._tfds_builder.as_dataset(
        split=self._tfds_split,
        shuffle_files=self._is_training,
        as_supervised=self._tfds_as_supervised,
        decoders=decoders,
        read_config=read_config)
    return dataset

  def tfds_info(self) -> tfds.core.DatasetInfo:
    """Returns TFDS dataset info, if available.

    Raises:
      ValueError: If the dataset is not loaded from tfds.
    """
    if self._tfds_builder:
      return self._tfds_builder.info
    else:
      raise ValueError('tfds_info is not available, because the dataset '
                       'is not loaded from tfds.')

  def read(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object.

    Pipeline order: read -> cache -> shuffle -> take -> decode -> parse ->
    transform -> batch -> postprocess -> prefetch.

    Args:
      input_context: An optional `tf.distribute.InputContext` for
        per-pipeline sharding and per-replica batch size computation.

    Returns:
      A batched, prefetched `tf.data.Dataset`.
    """
    if self._tfds_builder:
      dataset = self._read_tfds(input_context)
    elif self._num_files > 1:
      dataset = self._read_sharded_files(input_context)
    else:
      # __init__ guarantees at least one matched file, and >1 takes the
      # sharded branch above, so exactly one file remains here.
      assert self._num_files == 1
      dataset = self._read_single_file(input_context)

    if self._cache:
      dataset = dataset.cache()

    if self._is_training:
      dataset = dataset.shuffle(self._shuffle_buffer_size)

    # Optionally truncate the stream, e.g. for debugging or evaluation caps.
    if self._examples_consume > 0:
      dataset = dataset.take(self._examples_consume)

    def maybe_map_fn(dataset, fn):
      # Apply `fn` via `map` only when it is provided.
      return dataset if fn is None else dataset.map(
          fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = maybe_map_fn(dataset, self._decoder_fn)
    dataset = maybe_map_fn(dataset, self._parser_fn)

    if self._dataset_transform_fn is not None:
      dataset = self._dataset_transform_fn(dataset)

    # Batch with the per-replica size when running under a distribution
    # strategy; otherwise use the global batch size directly.
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size) if input_context else self._global_batch_size

    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._drop_remainder)
    dataset = maybe_map_fn(dataset, self._postprocess_fn)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)