"""Common utilities used in evaluators."""
import math

import jax
import tensorflow as tf
import tensorflow_datasets as tfds


def get_jax_process_dataset(dataset, split, global_batch_size, pp_fn,
                            dataset_dir=None, cache=True, add_tfds_id=False):
  """Returns the dataset to be processed by the current jax process.

  The dataset is sharded across processes and padded with zero-valued
  examples so that every process yields the same number of batches. The
  first two dimensions of each dataset element are
  [local_device_count, device_batch_size].

  Args:
    dataset: dataset name.
    split: dataset split.
    global_batch_size: batch size processed per iteration over the dataset.
    pp_fn: preprocessing function to apply per example.
    dataset_dir: path for tfds to find the prepared data.
    cache: whether to cache the dataset after batching.
    add_tfds_id: whether to add the unique `tfds_id` string to each example.

  Returns:
    A `tf.data.Dataset` with this process's shard of `split`, preprocessed,
    padded and batched as described above.
  """
  assert global_batch_size % jax.device_count() == 0
  # Derive the number of batches from the full split so that every process,
  # regardless of its shard size, yields the same number of batches.
  total_examples = tfds.load(
      dataset, split=split, data_dir=dataset_dir).cardinality()
  num_batches = math.ceil(total_examples / global_batch_size)

  # Each process reads only its own (roughly equal) shard of the split.
  process_split = tfds.even_splits(
      split, n=jax.process_count(), drop_remainder=False)[jax.process_index()]
  data = tfds.load(
      dataset,
      split=process_split,
      data_dir=dataset_dir,
      read_config=tfds.ReadConfig(add_tfds_id=add_tfds_id)).map(pp_fn)
  # An infinite stream of all-zero examples with the same structure as the
  # real data, used to pad the shard up to `num_batches` full batches.
  pad_data = tf.data.Dataset.from_tensors(
      jax.tree_map(lambda x: tf.zeros(x.shape, x.dtype), data.element_spec)
  ).repeat()

  data = data.concatenate(pad_data)
  # Batch twice so elements have shape [local_device_count, device_batch_size].
  data = data.batch(global_batch_size // jax.device_count())
  data = data.batch(jax.local_device_count())
  data = data.take(num_batches)
  if cache:
    # Caching keeps the already batched (and padded) dataset in memory, so
    # repeated evaluation passes skip re-reading and re-preprocessing it.
    data = data.cache()
  return data
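

# Usage sketch (illustrative only; not referenced elsewhere in this module).
# The dataset name, preprocessing function, and counting loop below are
# assumptions made for the example; a real evaluator would run a (typically
# pmapped) model step on each batch instead.
def _example_usage():
  """Shows how an evaluator typically consumes the padded, sharded dataset."""

  def pp_fn(example):
    # Attach a `mask` of 1 to every real example. Padding examples are
    # all-zero copies of the element spec, so their mask is 0 and they can
    # be excluded when aggregating metrics.
    return {
        'image': tf.cast(example['image'], tf.float32) / 255.0,
        'label': example['label'],
        'mask': tf.ones([], tf.float32),
    }

  ds = get_jax_process_dataset(
      dataset='cifar10', split='test', global_batch_size=1024, pp_fn=pp_fn)

  num_real_examples = 0.0
  for batch in ds.as_numpy_iterator():
    # Every array has shape [local_device_count, device_batch_size, ...] and
    # can be fed directly to a pmapped eval step; summing `mask` counts only
    # the real (non-padded) examples seen by this process.
    num_real_examples += batch['mask'].sum()
  return num_real_examples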