| | """TensorFlow Datasets as data source for big_vision.""" |
| | import functools |
| |
|
| | import big_vision.datasets.core as ds_core |
| | import jax |
| | import numpy as np |
| | import overrides |
| | import tensorflow as tf |
| | import tensorflow_datasets as tfds |
| |
|
| |
|
class DataSource(ds_core.DataSource):
  """Use TFDS as a data source."""

  def __init__(self, name, split, data_dir=None, skip_decode=("image",)):
    self.builder = _get_builder(name, data_dir)
    self.split = split
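    # Shard the split evenly across JAX processes, so that each process
    # reads only its own fixed, disjoint subset of the examples.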
    process_splits = tfds.even_splits(split, jax.process_count())
    self.process_split = process_splits[jax.process_index()]
    self.skip_decode = skip_decode

  @overrides.overrides
  def get_tfdata(
      self, ordered=False, *, process_split=True, allow_cache=True, **kw):
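    # allow_cache chooses between a functools.cache-memoized dataset
    # constructor and building a fresh tf.data pipeline on every call.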
    return (_cached_get_dataset if allow_cache else _get_dataset)(
        self.builder, self.skip_decode,
        split=self.process_split if process_split else self.split,
        shuffle_files=not ordered,
        **kw)

  @property
  @overrides.overrides
  def total_examples(self):
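    # Count over the full split, independent of the per-process sharding.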
    return self.builder.info.splits[self.split].num_examples

  @overrides.overrides
  def num_examples_per_process(self):
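    # Recompute the even sharding to report each process's shard size.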
    splits = tfds.even_splits(self.split, jax.process_count())
    return [self.builder.info.splits[s].num_examples for s in splits]


@functools.cache
def _get_builder(dataset, data_dir):
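  # The special name "from_data_dir" loads whatever dataset was prepared
  # in `data_dir`, instead of looking the name up in the TFDS registry.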
  if dataset == "from_data_dir":
    return tfds.builder_from_directory(data_dir)
  else:
    return tfds.builder(dataset, data_dir=data_dir, try_gcs=True)


def _get_dataset(builder, skip_decode, **kw):
  """Returns the tf.data.Dataset to be used."""
  rckw = {k: kw.pop(k) for k in ("shuffle_seed",) if k in kw}
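  # The ReadConfig disables TFDS's own prefetching and auto-caching (the
  # surrounding input pipeline is expected to manage both) and attaches a
  # unique string "tfds_id" to every example. Features listed in
  # `skip_decode` (typically "image") are returned undecoded.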
  ds = builder.as_dataset(
      read_config=tfds.ReadConfig(
          skip_prefetch=True,
          try_autocache=False,
          add_tfds_id=True,
          **rckw,
      ),
      decoders={
          f: tfds.decode.SkipDecoding()
          for f in skip_decode if f in builder.info.features
      },
      **kw)

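  # Hash each example's string "tfds_id" into a fixed-size int32 "_id"
  # field, which, unlike a string, can be carried onto accelerator devices.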
  def _hash_tfds_id(example):
    id_ = tf.strings.to_hash_bucket_strong(
        example["tfds_id"],
        np.iinfo(np.uint32).max,
        [3714561454027272724, 8800639020734831960])
    example["_id"] = tf.bitcast(id_, tf.int32)[0]
    return example

  return ds.map(_hash_tfds_id)


_cached_get_dataset = functools.cache(_get_dataset)

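# Minimal usage sketch (illustrative, not part of the library; assumes the
# TFDS dataset "mnist" is available in the default data_dir):
#
#   source = DataSource("mnist", split="train")
#   ds = source.get_tfdata(ordered=False, shuffle_seed=0)
#   n_total = source.total_examples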