# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """CLRS dataset.""" | |
| import dataclasses | |
| import functools | |
| from typing import Iterator | |
| from clrs._src import probing | |
| from clrs._src import samplers | |
| from clrs._src import specs | |
| import jax | |
| import numpy as np | |
| import tensorflow as tf | |
| import tensorflow_datasets as tfds | |
def _correct_axis_filtering(tensor, index, name):
  """Select one sample; hint tensors are time-major, so index along axis 1."""
  if 'hint_' in name:
    return tensor[:, index]
  else:
    return tensor[index]
@dataclasses.dataclass
class CLRSConfig(tfds.core.BuilderConfig):
  """Specify the split in the variant because they have different shapes."""
  split: str = ''
DEFAULT_BUILDER_CONFIGS = []


def _build_default_builder_configs():
  for split in ['train', 'val', 'test']:
    for alg in specs.CLRS_30_ALGS:
      DEFAULT_BUILDER_CONFIGS.append(
          CLRSConfig(name=f'{alg}_{split}', split=split))


_build_default_builder_configs()
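# The call above registers one config per algorithm and split, named
# '<algorithm>_<split>' (e.g. 'bfs_train', 'bfs_val', 'bfs_test').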
class CLRSDataset(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the CLRS dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }
  BUILDER_CONFIGS = DEFAULT_BUILDER_CONFIGS

  _instantiated_dataset = None
  _instantiated_dataset_name = ''
  _instantiated_dataset_split = ''

  def _num_samples(self, algorithm_name):
    num_samples = samplers.CLRS30[self._builder_config.split]['num_samples']  # pytype: disable=attribute-error
    if self._builder_config.split != 'train':  # pytype: disable=attribute-error
      # Generate more samples for those algorithms in which the number of
      # signals is small.
      num_samples *= specs.CLRS_30_ALGS_SETTINGS[algorithm_name][
          'num_samples_multiplier']
    return num_samples
  def _create_data(self, single_sample):
    algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
    num_samples = self._num_samples(algorithm_name)
    sampler, _ = samplers.build_sampler(
        algorithm_name,
        seed=samplers.CLRS30[self._builder_config.split]['seed'],  # pytype: disable=attribute-error
        num_samples=num_samples,
        length=samplers.CLRS30[self._builder_config.split]['length'],  # pytype: disable=attribute-error
    )
    sampled_dataset = sampler.next(batch_size=1 if single_sample else None)
    data = {'input_' + t.name: t.data for t in sampled_dataset.features.inputs}
    # All other data points have input_, hint_, and output_ prefixes, so we
    # guarantee that this key is unused.
    data['lengths'] = sampled_dataset.features.lengths
    data.update({'output_' + t.name: t.data for t in sampled_dataset.outputs})
    data.update({
        'hint_' + t.name: t.data for t in sampled_dataset.features.hints})
    self._instantiated_dataset = data
  def _info(self) -> tfds.core.DatasetInfo:
    if tf.io.gfile.exists(self.data_dir):
      info = tfds.core.DatasetInfo(builder=self)
      info.read_from_directory(self.data_dir)
      return info
    if (self._instantiated_dataset_name != self._builder_config.name
        or self._instantiated_dataset_split != self._builder_config.split):  # pytype: disable=attribute-error
      self._create_data(single_sample=True)
    data = {k: _correct_axis_filtering(v, 0, k)
            for k, v in self._instantiated_dataset.items()}
    data_info = {
        k: tfds.features.Tensor(shape=v.shape, dtype=tf.dtypes.as_dtype(
            v.dtype)) for k, v in data.items()}
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict(data_info),
    )
  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Download the data and define splits."""
    if (self._instantiated_dataset_name != self._builder_config.name
        or self._instantiated_dataset_split != self._builder_config.split):  # pytype: disable=attribute-error
      self._create_data(single_sample=False)
      self._instantiated_dataset_name = self._builder_config.name
      self._instantiated_dataset_split = self._builder_config.split  # pytype: disable=attribute-error
    return {self._builder_config.split: self._generate_examples()}  # pytype: disable=attribute-error
  def _generate_examples(self):
    """Generator of examples for each split."""
    algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
    for i in range(self._num_samples(algorithm_name)):
      data = {k: _correct_axis_filtering(v, i, k)
              for k, v in self._instantiated_dataset.items()}
      yield str(i), data
def _get_clrs_file_name():
  return f'CLRS30_v{CLRSDataset.VERSION}.tar.gz'


def get_dataset_gcp_url():
  return f'https://storage.googleapis.com/dm-clrs/{_get_clrs_file_name()}'


def get_clrs_folder():
  return f'CLRS30_v{CLRSDataset.VERSION}'
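# A minimal download sketch for the helpers above (illustrative, not part of
# this module's API; the urllib/tarfile approach and the '/tmp' destination
# are assumptions):
#
#   import tarfile
#   import urllib.request
#
#   file_name, _ = urllib.request.urlretrieve(get_dataset_gcp_url())
#   with tarfile.open(file_name) as tar:
#     tar.extractall('/tmp')  # Extracts into the folder get_clrs_folder().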
def _preprocess(data_point, algorithm=None):
  """Convert sampled inputs into DataPoints."""
  inputs = []
  outputs = []
  hints = []
  lengths = None

  for name, data in data_point.items():
    if name == 'lengths':
      lengths = data
      continue
    data_point_name = name.split('_')
    name = '_'.join(data_point_name[1:])
    (stage, location, dp_type) = specs.SPECS[algorithm][name]
    assert stage == data_point_name[0]
    if stage == specs.Stage.HINT:
      # Batched hints arrive as [batch, time, ...]; swap to the time-major
      # [time, batch, ...] layout expected downstream.
      data = tf.experimental.numpy.swapaxes(data, 0, 1)
    dp = probing.DataPoint(name, location, dp_type, data)
    if stage == specs.Stage.INPUT:
      inputs.append(dp)
    elif stage == specs.Stage.OUTPUT:
      outputs.append(dp)
    else:
      hints.append(dp)

  return samplers.Feedback(
      samplers.Features(tuple(inputs), tuple(hints), lengths), tuple(outputs))
def create_dataset(folder, algorithm, split, batch_size):
  dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
                      data_dir=folder, split=split)
  num_samples = len(dataset)  # Must be read before repeat() makes it infinite.
  dataset = dataset.repeat()
  dataset = dataset.batch(batch_size)
  return (dataset.map(lambda d: _preprocess(d, algorithm=algorithm)),
          num_samples,
          specs.SPECS[algorithm])
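# Example usage (a sketch; the folder is a placeholder and assumes the
# pre-built dataset has already been downloaded and extracted there):
#
#   ds, num_samples, spec = create_dataset(
#       folder='/tmp/CLRS30_v1.0.0', algorithm='bfs',
#       split='train', batch_size=32)
#   feedback = next(ds.as_numpy_iterator())  # One samplers.Feedback batch.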
def _copy_hint(source, dest, i, start_source, start_dest, to_add):
  """Copy from full-sample hint to a hint chunk."""
  assert np.all(dest[start_dest:, i:] == 0)
  assert start_dest < dest.shape[0]
  assert start_dest + to_add <= dest.shape[0]
  assert start_source < source.shape[0]
  assert start_source + to_add <= source.shape[0]
  dest[start_dest:start_dest+to_add, i] = source[
      start_source:start_source+to_add, i]
  return dest


def _copy_io(source, dest, i, start_dest, to_add):
  """Copy from an input or output to an input or output chunk."""
  assert np.all(dest[start_dest:, i:] == 0)
  dest[start_dest:start_dest+to_add, i] = source[i]
  return dest
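# Worked example for the two helpers above (shapes are illustrative): with
# chunk_length == 6, copying to_add == 3 steps for batch element i == 0 at
# start_dest == 2 fills rows 2..4 of column 0 of the chunk. `_copy_hint`
# takes those rows from the matching time slice of the full hint trajectory,
# while `_copy_io` repeats the sample's single (time-less) input or output
# across the rows.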
def chunkify(dataset: Iterator[samplers.Feedback], chunk_length: int):
  """Generator of fixed-length chunks from full-trajectory samples.

  Args:
    dataset: full-sample dataset as numpy iterator.
    chunk_length: time length of chunks.

  Yields:
    Fixed-timelength chunks of data. Each tensor of inputs, hints and outputs
    has dimensions chunk_length x batch_size x ... Samples are not time-padded:
    after the end of one sample immediately comes the next. Since different
    samples can have different time lengths, the beginnings and ends of samples
    within a batch do not need to coincide. For this reason, the chunked
    dataset features include two chunk_length x batch_size int tensors,
    `is_first` and `is_last`, that mark the beginning and end of each sample.
    For example, if `chunk_length`==6 and `batch_size`==2 and the first
    full-sample batch had one sample of length 3 and one of length 5,
    the first chunked batch would have the following `is_first` and
    `is_last` tensors:

    is_first = [[1, 1]   is_last = [[0, 0]  ( sample id [[0 1]
                [0, 0]              [0, 0]               [0 1]
                [0, 0]              [1, 0]               [0 1]
                [1, 0]              [0, 0]               [2 1]
                [0, 0]              [0, 1]               [2 1]
                [0, 1]]             [0, 0]]              [2 3]] )

    while the data in the inputs, outputs and hints tensors would correspond
    to the samples identified by the sample id above, shown for reference.
    Notice that, while inputs and outputs have no time dimension in the
    full-sample dataset, here they do: the input and output tensors are
    simply repeated along each sample's time length.
  """
  def _get_batch():
    d = next(dataset)
    return (d.features.inputs, d.features.hints, d.outputs,
            d.features.lengths.astype(int))

  inputs, hints, outputs, lengths = _get_batch()

  # Infer the batch size from a node- or edge-located input, whose leading
  # axis is the batch axis.
  for inp in inputs:
    if inp.location in [specs.Location.NODE, specs.Location.EDGE]:
      batch_size = inp.data.shape[0]
      break

  io_chunk = lambda x: np.zeros((chunk_length,) + x.shape, dtype=x.dtype)
  chunk_inputs = jax.tree_util.tree_map(io_chunk, inputs)
  chunk_outputs = jax.tree_util.tree_map(io_chunk, outputs)

  hint_chunk = lambda x: np.zeros((chunk_length,) + x.shape[1:], dtype=x.dtype)
  chunk_hints = jax.tree_util.tree_map(hint_chunk, hints)

  inputs = [inputs]
  hints = [hints]
  outputs = [outputs]
  left = [lengths.copy()]
  lengths = [lengths.copy()]
  while True:
    # Create a new empty chunk.
    chunk_inputs = jax.tree_util.tree_map(np.zeros_like, chunk_inputs)
    chunk_hints = jax.tree_util.tree_map(np.zeros_like, chunk_hints)
    chunk_outputs = jax.tree_util.tree_map(np.zeros_like, chunk_outputs)
    start_mark = np.zeros((chunk_length, batch_size), dtype=int)
    end_mark = np.zeros((chunk_length, batch_size), dtype=int)

    # Get enough data batches to fill the new chunk.
    while np.any(np.sum(left, axis=0) < chunk_length):
      inp, hh, out, ll = _get_batch()
      inputs.append(inp)
      hints.append(hh)
      outputs.append(out)
      left.append(ll.copy())
      lengths.append(ll.copy())

    # Fill the chunk, one batch element at a time.
    for i in range(batch_size):
      total, idx = 0, 0
      while total < chunk_length:
        to_add = min(left[idx][i], chunk_length - total)
        if to_add:
          start = lengths[idx][i] - left[idx][i]
          assert start >= 0
          f_io = functools.partial(_copy_io, i=i, start_dest=total,
                                   to_add=to_add)
          chunk_inputs = jax.tree_util.tree_map(f_io, inputs[idx],
                                                chunk_inputs)
          chunk_outputs = jax.tree_util.tree_map(f_io, outputs[idx],
                                                 chunk_outputs)
          f_hint = functools.partial(_copy_hint, i=i, start_source=start,
                                     start_dest=total, to_add=to_add)
          chunk_hints = jax.tree_util.tree_map(f_hint, hints[idx],
                                               chunk_hints)
          if start == 0:
            start_mark[total, i] = 1
          total += to_add
          left[idx][i] -= to_add
          assert left[idx][i] >= 0
          if left[idx][i] == 0:
            end_mark[total - 1, i] = 1
        idx += 1
      assert total == chunk_length

    # Drop batches that have been fully consumed.
    while left and np.all(left[0] == 0):
      inputs.pop(0)
      hints.pop(0)
      outputs.pop(0)
      left.pop(0)
      lengths.pop(0)

    yield samplers.Feedback(
        samplers.FeaturesChunked(chunk_inputs, chunk_hints,
                                 start_mark, end_mark),
        chunk_outputs)
def create_chunked_dataset(folder, algorithm, split, batch_size, chunk_length):
  dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
                      data_dir=folder, split=split)
  dataset = dataset.repeat()
  dataset = dataset.batch(batch_size)
  dataset = dataset.map(lambda d: _preprocess(d, algorithm=algorithm))
  dataset = dataset.as_numpy_iterator()
  return chunkify(dataset, chunk_length), specs.SPECS[algorithm]
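# Example usage (a sketch; folder and sizes are placeholders):
#
#   chunked_ds, spec = create_chunked_dataset(
#       folder='/tmp/CLRS30_v1.0.0', algorithm='bfs',
#       split='train', batch_size=32, chunk_length=16)
#   feedback = next(chunked_ds)  # Feedback with FeaturesChunked features.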