# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A common dataset reader."""

import dataclasses
import random
from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union

from absl import logging
import tensorflow as tf, tf_keras
import tensorflow_datasets as tfds

from official.core import config_definitions as cfg


def _get_random_integer():
  return random.randint(0, (1 << 31) - 1)


def _maybe_map_fn(dataset: tf.data.Dataset,
                  fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
  """Calls dataset.map if a valid function is passed in."""
  return dataset if fn is None else dataset.map(
      fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
  """Matches files from an input_path."""
  matched_files = []
  # Read dataset from files.
  usage = ('`input_path` should be either (1) a str indicating a file '
           'path/pattern, (2) a str indicating multiple file paths/patterns '
           'separated by commas (e.g. "a, b, c" or, without spaces, "a,b,c"), '
           'or (3) a list of str, each of which is a file path/pattern or '
           'multiple file paths/patterns separated by commas, but got: %s')
  if isinstance(input_path, str):
    input_path_list = [input_path]
  elif isinstance(input_path, (list, tuple)):
    if any(not isinstance(x, str) for x in input_path):
      raise ValueError(usage % input_path)
    input_path_list = input_path
  else:
    raise ValueError(usage % input_path)

  for input_path in input_path_list:
    input_patterns = input_path.strip().split(',')
    for input_pattern in input_patterns:
      input_pattern = input_pattern.strip()
      if not input_pattern:
        continue
      if '*' in input_pattern or '?' in input_pattern:
        tmp_matched_files = tf.io.gfile.glob(input_pattern)
        if not tmp_matched_files:
          raise ValueError('%s does not match any files.' % input_pattern)
        matched_files.extend(tmp_matched_files)
      else:
        matched_files.append(input_pattern)

  if not matched_files:
    raise ValueError('%s does not match any files.' % input_path)

  return matched_files
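
# Illustrative sketch (not part of the library): `match_files` accepts any of
# the three forms described in the `usage` string above. The file paths below
# are made-up examples.
#
#   match_files('/data/train-*.tfrecord')                  # (1) glob pattern
#   match_files('/data/a.tfrecord, /data/b.tfrecord')      # (2) comma-separated
#   match_files(['/data/train-*', '/data/val.tfrecord'])   # (3) list of str
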
def _read_files_then_shard(matched_files: List[str],
                           dataset_fn,
                           input_context: Optional[
                               tf.distribute.InputContext] = None,
                           sharding: bool = False,
                           repeat: bool = False) -> tf.data.Dataset:
  """Sends all data files to every worker and then shards by data."""
  dataset = dataset_fn(matched_files)

  # When `input_file` is a path to a single file or the number of files is
  # less than the number of input pipelines, disable auto sharding
  # so that the same input file is sent to all workers.
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.OFF)
  dataset = dataset.with_options(options)
  # Do not enable sharding if tf.data service is enabled, as sharding will be
  # handled inside tf.data service.
  if sharding and input_context and (input_context.num_input_pipelines > 1):
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  if repeat:
    dataset = dataset.repeat()
  return dataset

def _shard_files_then_read(matched_files: List[str],
                           dataset_fn,
                           input_context: Optional[
                               tf.distribute.InputContext] = None,
                           seed: Optional[Union[int, tf.Tensor]] = None,
                           is_training: bool = False,
                           sharding: bool = False,
                           cache: bool = False,
                           cycle_length: Optional[int] = None,
                           block_length: Optional[int] = None,
                           deterministic: bool = False) -> tf.data.Dataset:
  """Shards the data files and then sends a split to every worker to read."""
  dataset = tf.data.Dataset.from_tensor_slices(matched_files)

  # Shuffle and repeat at file level.
  # If cache is enabled, `reshuffle_each_iteration` is set to False,
  # because we will read the same cached data in every iteration anyway.
  if is_training:
    # We need a seed to shuffle the files so that when each TPU worker gets
    # its own shard, the files do not overlap.
    if sharding and seed is None:
      seed = _get_random_integer()
    dataset = dataset.shuffle(
        len(matched_files),
        seed=seed,
        reshuffle_each_iteration=True if not cache else False)

  # Do not enable sharding if tf.data service is enabled, as sharding will be
  # handled inside tf.data service.
  if sharding and input_context and (input_context.num_input_pipelines > 1):
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  # If cache is enabled, we will call `repeat()` later after `cache()`.
  if is_training and not cache:
    dataset = dataset.repeat()

  dataset = dataset.interleave(
      map_func=dataset_fn,
      cycle_length=cycle_length,
      block_length=block_length,
      num_parallel_calls=(cycle_length
                          if cycle_length else tf.data.experimental.AUTOTUNE),
      deterministic=deterministic)
  return dataset
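
# Illustrative sketch (not part of the library): the two helpers above trade
# off differently. With enough files, sharding the file list first avoids
# every worker reading everything; with one file (or too few), every worker
# reads all files and shards by record. `ctx` and the paths are hypothetical.
#
#   files = match_files('/data/train-*.tfrecord')
#   if len(files) >= ctx.num_input_pipelines:
#     ds = _shard_files_then_read(files, tf.data.TFRecordDataset, ctx,
#                                 is_training=True, sharding=True)
#   else:
#     ds = _read_files_then_shard(files, tf.data.TFRecordDataset, ctx,
#                                 sharding=True, repeat=True)
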
def _read_tfds(tfds_name: Text,
               tfds_data_dir: Text,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
  """Reads a dataset from tfds."""
  repeat_filenames = is_training and not cache
  read_config = tfds.ReadConfig(
      interleave_cycle_length=cycle_length,
      interleave_block_length=block_length,
      input_context=input_context,
      shuffle_seed=seed,
      repeat_filenames=repeat_filenames,
      # Only assert cardinality when we have a finite dataset.
      assert_cardinality=not repeat_filenames,
      skip_prefetch=True)

  decoders = {}
  if tfds_skip_decoding_feature:
    for skip_feature in tfds_skip_decoding_feature.split(','):
      decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()

  if tfds_name.startswith('mldataset.'):
    dataset = tfds.load(name=tfds_name,
                        split=tfds_split,
                        as_supervised=tfds_as_supervised,
                        decoders=decoders if decoders else None,
                        read_config=read_config)
  else:
    builder = tfds.builder(tfds_name, data_dir=tfds_data_dir)
    if builder.info.splits:
      num_shards = len(builder.info.splits[tfds_split].file_instructions)
    else:
      # The tfds mock path often does not provide splits.
      num_shards = 1
    load_kwargs = dict(
        name=tfds_name, download=True, split=tfds_split,
        shuffle_files=is_training, as_supervised=tfds_as_supervised,
        decoders=decoders if decoders else None)
    if tfds_data_dir:
      load_kwargs.update({'data_dir': tfds_data_dir})

    if input_context and num_shards < input_context.num_input_pipelines:
      # The number of files in the dataset split is smaller than the number of
      # input pipelines. We read the entire dataset first and then shard in
      # the host memory.
      read_config = dataclasses.replace(read_config, input_context=None)
      load_kwargs.update({'read_config': read_config})
      dataset = tfds.load(**load_kwargs)
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    else:
      load_kwargs.update({'read_config': read_config})
      dataset = tfds.load(**load_kwargs)
  return dataset
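
# Illustrative sketch (not part of the library): a typical call reads a public
# TFDS dataset; the dataset name 'mnist' is an assumption for the example.
#
#   ds = _read_tfds(tfds_name='mnist', tfds_data_dir='', tfds_split='train',
#                   tfds_skip_decoding_feature='', tfds_as_supervised=False,
#                   is_training=True)
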
class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""

  # A static random number which is the same across different InputReader
  # instances.
  static_randnum = _get_random_integer()

  def __init__(
      self,
      params: cfg.DataConfig,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn: Optional[Callable[..., Any]] = None,
      combine_fn: Optional[Callable[..., Any]] = None,
      sample_fn: Optional[Callable[..., Any]] = None,
      parser_fn: Optional[Callable[..., Any]] = None,
      filter_fn: Optional[Callable[..., tf.Tensor]] = None,
      transform_and_batch_fn: Optional[
          Callable[
              [tf.data.Dataset, Optional[tf.distribute.InputContext]],
              tf.data.Dataset,
          ]
      ] = None,
      postprocess_fn: Optional[Callable[..., Any]] = None,
  ):
    """Initializes an InputReader instance.

    Args:
      params: A config_definitions.DataConfig object.
      dataset_fn: A `tf.data.Dataset` that consumes the input files. For
        example, it can be `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
      combine_fn: An optional `callable` that takes a dictionary of
        `tf.data.Dataset` objects as input and outputs a combined dataset. It
        will be executed after the decoder_fn and before the sample_fn.
      sample_fn: An optional `callable` that takes a `tf.data.Dataset` object
        as input and outputs the transformed dataset. It performs sampling on
        the decoded raw tensors dict before the parser_fn.
      parser_fn: An optional `callable` that takes the decoded raw tensors
        dict and parses them into a dictionary of tensors that can be consumed
        by the model. It will be executed after decoder_fn.
      filter_fn: An optional `callable` mapping a dataset element to a
        boolean. It will be executed after parser_fn.
      transform_and_batch_fn: An optional `callable` that takes a
        `tf.data.Dataset` object and an optional `tf.distribute.InputContext`
        as input, and returns a `tf.data.Dataset` object. It will be executed
        after `parser_fn` to transform and batch the dataset; if None, after
        `parser_fn` is executed, the dataset will be batched into per-replica
        batch size.
      postprocess_fn: An optional `callable` that processes batched tensors.
        It will be executed after batching.
    """
    if params.input_path and params.tfds_name:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' %
                       (params.input_path, params.tfds_name))

    if (isinstance(params.input_path, cfg.base_config.Config) or
        isinstance(params.tfds_name, cfg.base_config.Config)
       ) and combine_fn is None:
      raise ValueError(
          'A combine_fn is required if `input_path` or `tfds_name` is a dict.')

    self._tfds_name = params.tfds_name
    self._tfds_data_dir = params.tfds_data_dir
    self._matched_files = None
    if not params.input_path:
      # Read dataset from TFDS.
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
    else:
      self._matched_files = self.get_files(params.input_path)

    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
    self._drop_remainder = params.drop_remainder
    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._deterministic = params.deterministic
    self._sharding = params.sharding
    self._tfds_split = params.tfds_split
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._combine_fn = combine_fn
    self._sample_fn = sample_fn
    self._parser_fn = parser_fn
    self._transform_and_batch_fn = transform_and_batch_fn
    self._postprocess_fn = postprocess_fn
    self._filter_fn = filter_fn
    self._seed = params.seed
    self._prefetch_buffer_size = (
        params.prefetch_buffer_size or tf.data.experimental.AUTOTUNE)
    self._autotune_algorithm = params.autotune_algorithm

    # When tf.data service is enabled, each data service worker should get
    # different random seeds. Thus, we set `seed` to None.
    # Sharding should also be disabled because tf.data service handles how
    # each worker shards data with `processing_mode` in the distribute method.
    if params.enable_tf_data_service:
      self._seed = None
      self._sharding = False

    self._enable_tf_data_service = (
        params.enable_tf_data_service and params.tf_data_service_address)
    self._tf_data_service_address = params.tf_data_service_address
    self._enable_shared_tf_data_service_between_parallel_trainers = (
        params.enable_shared_tf_data_service_between_parallel_trainers)
    self._apply_tf_data_service_before_batching = (
        params.apply_tf_data_service_before_batching)
    self._trainer_id = params.trainer_id
    if self._enable_tf_data_service:
      # Add a random seed as the tf.data service job name suffix, so tf.data
      # service doesn't reuse the previous state if a TPU worker gets
      # preempted. It's necessary to add the global batch size into the
      # tf.data service job name because, when tuning the batch size with
      # vizier while tf.data service is also enabled, the job name should be
      # different for different vizier trials: once the batch size changes,
      # the dataset is a different instance from the tf.data perspective, and
      # a different job name should be used for tf.data service. Otherwise,
      # the model would read tensors from the incorrect tf.data service job,
      # which would cause a dimension mismatch on the batch size dimension.
      self._tf_data_service_job_name = (
          f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
          f'{self.static_randnum}')
      self._enable_round_robin_tf_data_service = params.get(
          'enable_round_robin_tf_data_service', False)
      if self._enable_shared_tf_data_service_between_parallel_trainers:
        # When shared tf.data service is enabled, only a single tf.data
        # service instance should be created and shared between parallel
        # trainers. If the global batch size is different across trainers,
        # params.apply_tf_data_service_before_batching should be set to true
        # because tf.data service pipelines with different batch sizes will
        # be considered separate tf.data service instances.
        self._tf_data_service_job_name = (
            f'{params.tf_data_service_job_name}_{self.static_randnum}')
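
  # Illustrative sketch (not part of the library): constructing a reader for
  # TFRecord files. The config values, `my_decode_fn`, and `my_parse_fn` below
  # are assumptions for the example.
  #
  #   params = cfg.DataConfig(input_path='/data/train-*.tfrecord',
  #                           global_batch_size=64, is_training=True)
  #   reader = InputReader(params, dataset_fn=tf.data.TFRecordDataset,
  #                        decoder_fn=my_decode_fn, parser_fn=my_parse_fn)
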
  def get_files(self, input_path):
    """Gets matched files. Can be overridden by subclasses."""
    if not input_path:
      return None
    # We want to combine / mix datasets.
    if isinstance(input_path, cfg.base_config.Config):
      matched_files = {}
      for k, v in input_path.as_dict().items():
        matched_files[k] = match_files(v)
    # Single dataset.
    else:
      matched_files = match_files(input_path)
    return matched_files

  def _read_data_source(
      self,
      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      input_context: Optional[tf.distribute.InputContext] = None,
  ):
    """Reads the data source (files/tfds) to a dataset."""

    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
      if len(files) > 1:
        if input_context and (len(files) < input_context.num_input_pipelines):
          logging.warning(
              (
                  'The number of files %d is less than the number of input '
                  'pipelines %d. We will send all input files to every '
                  'worker. Please consider sharding your data into more '
                  'files.'
              ),
              len(files),
              input_context.num_input_pipelines,
          )
          return _read_files_then_shard(
              files,
              dataset_fn,
              input_context,
              sharding=self._sharding,
              repeat=self._is_training and not self._cache)
        else:
          return _shard_files_then_read(
              files,
              dataset_fn,
              input_context,
              seed=self._seed,
              is_training=self._is_training,
              sharding=self._sharding,
              cache=self._cache,
              cycle_length=self._cycle_length,
              block_length=self._block_length,
              deterministic=self._deterministic)
      elif len(files) == 1:
        return _read_files_then_shard(
            files,
            dataset_fn,
            input_context,
            sharding=self._sharding,
            repeat=self._is_training and not self._cache)
      else:
        raise ValueError('Cannot read data: no `tfds_name` was given and '
                         '`files` is empty.')

    if self._tfds_name:
      if isinstance(self._tfds_name, cfg.base_config.Config):
        dataset = {}
        for k, tfds_name in self._tfds_name.as_dict().items():
          dataset[k] = _read_tfds(
              tfds_name=tfds_name,
              tfds_data_dir=self._tfds_data_dir,
              tfds_split=self._tfds_split,
              tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
              tfds_as_supervised=self._tfds_as_supervised,
              input_context=input_context,
              seed=self._seed,
              is_training=self._is_training,
              cache=self._cache,
              cycle_length=self._cycle_length,
              block_length=self._block_length)
      else:
        dataset = _read_tfds(
            tfds_name=self._tfds_name,
            tfds_data_dir=self._tfds_data_dir,
            tfds_split=self._tfds_split,
            tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
            tfds_as_supervised=self._tfds_as_supervised,
            input_context=input_context,
            seed=self._seed,
            is_training=self._is_training,
            cache=self._cache,
            cycle_length=self._cycle_length,
            block_length=self._block_length)
    elif isinstance(matched_files, (list, tuple)):
      dataset = _files_to_dataset(matched_files)
    elif isinstance(matched_files, dict):
      dataset = {}
      for k, fs in matched_files.items():
        dataset[k] = _files_to_dataset(fs)
    else:
      raise ValueError('`matched_files` should be a list or dict.')

    return dataset

  def _decode_and_parse_dataset(
      self,
      dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
      batch_size: int,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Returns a dataset after shuffling, decoding, and parsing."""

    def _shuffle_and_decode(ds):
      # If cache is enabled, we will call `shuffle()` later after `cache()`.
      if self._is_training and not self._cache:
        ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
      # Decode.
      ds = _maybe_map_fn(ds, self._decoder_fn)
      return ds

    dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
    if tf.nest.is_nested(dataset):
      dataset = self._combine_fn(dataset)

    if self._sample_fn is not None:
      dataset = dataset.apply(self._sample_fn)
    dataset = _maybe_map_fn(dataset, self._parser_fn)

    if self._filter_fn is not None:
      dataset = dataset.filter(self._filter_fn)

    if self._cache:
      dataset = dataset.cache()
      if self._is_training:
        dataset = dataset.repeat()
        dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)

    # Applies tf.data service before batching operations. This is useful when
    # tf.data service is shared between parallel trainers and the batch size
    # differs between them: if tf.data service is applied after batching,
    # pipelines with different batch sizes are considered different service
    # instances, which makes them hard to share between parallel trainers.
    # However, if there are additional expensive operations in
    # self._transform_and_batch_fn or self._postprocess_fn, the entire tf.data
    # pipeline could be slowed down. In that case, try to move these dataset
    # operations into earlier stages if possible.
    if (self._enable_shared_tf_data_service_between_parallel_trainers and
        self._apply_tf_data_service_before_batching):
      dataset = self._maybe_apply_data_service(dataset, input_context)

    if self._transform_and_batch_fn is not None:
      dataset = self._transform_and_batch_fn(dataset, input_context)
    else:
      per_replica_batch_size = input_context.get_per_replica_batch_size(
          batch_size) if input_context else batch_size
      dataset = dataset.batch(
          per_replica_batch_size, drop_remainder=self._drop_remainder)

    return dataset
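
  # Illustrative summary (not part of the library): the per-element stage
  # order implemented above, for a single (non-dict) dataset:
  #
  #   shuffle (skipped when caching) -> decoder_fn -> combine_fn (dicts only)
  #   -> sample_fn -> parser_fn -> filter_fn -> cache/repeat/shuffle
  #   -> [tf.data service, if applied before batching] -> batch
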
  def _maybe_apply_data_service(
      self,
      dataset: tf.data.Dataset,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Potentially distributes a dataset."""
    if self._enable_tf_data_service and input_context:
      if self._enable_round_robin_tf_data_service:
        replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
            input_context.num_input_pipelines)
        base_consumer_index = input_context.input_pipeline_id * (
            replicas_per_input_pipeline)
        num_consumers = input_context.num_input_pipelines * (
            replicas_per_input_pipeline)
        range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
        tfds_kwargs = {
            'processing_mode': 'parallel_epochs',
            'service': self._tf_data_service_address,
            'job_name': self._tf_data_service_job_name,
            'num_consumers': num_consumers
        }
        if self._enable_shared_tf_data_service_between_parallel_trainers:
          raise ValueError('Shared tf.data service does not support round-'
                           'robin tf.data service.')
        dataset = range_dataset.map(lambda i: dataset.apply(  # pylint: disable=g-long-lambda
            tf.data.experimental.service.distribute(
                consumer_index=base_consumer_index + i, **tfds_kwargs)))
        # Use parallel interleave to read multiple batches from a tf.data
        # service worker in parallel.
        dataset = dataset.interleave(
            lambda x: x,
            cycle_length=replicas_per_input_pipeline,
            num_parallel_calls=replicas_per_input_pipeline,
            deterministic=True)
      else:
        tfds_kwargs = {
            'processing_mode': 'parallel_epochs',
            'service': self._tf_data_service_address,
            'job_name': self._tf_data_service_job_name,
        }
        if self._enable_shared_tf_data_service_between_parallel_trainers:
          tfds_kwargs.update({
              'processing_mode':
                  tf.data.experimental.service.ShardingPolicy.OFF,
              'cross_trainer_cache':
                  tf.data.experimental.service.CrossTrainerCache(
                      trainer_id=self._trainer_id)
          })
        dataset = dataset.apply(
            tf.data.experimental.service.distribute(**tfds_kwargs))
    return dataset

  def read(self,
           input_context: Optional[tf.distribute.InputContext] = None,
           dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""
    if dataset is None:
      dataset = self._read_data_source(self._matched_files, self._dataset_fn,
                                       input_context)
    dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
                                             input_context)
    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
    if not (self._enable_shared_tf_data_service_between_parallel_trainers and
            self._apply_tf_data_service_before_batching):
      dataset = self._maybe_apply_data_service(dataset, input_context)

    if self._deterministic is not None:
      options = tf.data.Options()
      options.deterministic = self._deterministic
      dataset = dataset.with_options(options)
    if self._autotune_algorithm:
      options = tf.data.Options()
      options.autotune.autotune_algorithm = (
          tf.data.experimental.AutotuneAlgorithm[self._autotune_algorithm])
      dataset = dataset.with_options(options)
    return dataset.prefetch(self._prefetch_buffer_size)
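
# Illustrative sketch (not part of the library): distributing the reader's
# output across replicas. `params`, `my_decode_fn`, and `my_parse_fn` are
# hypothetical; `read` accepts the per-worker input context that
# `distribute_datasets_from_function` passes in.
#
#   reader = InputReader(params, decoder_fn=my_decode_fn,
#                        parser_fn=my_parse_fn)
#   strategy = tf.distribute.MirroredStrategy()
#   dist_ds = strategy.distribute_datasets_from_function(reader.read)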