Spaces:
Running
Running
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Unit tests for `losses.py`.""" | |
| from typing import Generator | |
| from absl.testing import absltest | |
| from absl.testing import parameterized | |
| from clrs._src import dataset | |
| from clrs._src import losses | |
| from clrs._src import probing | |
| from clrs._src import samplers | |
| from clrs._src import specs | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
# Short type aliases used by the tests below.
_Array = np.ndarray
_Location = specs.Location
def _make_sampler(algo: str, nb_nodes: int) -> samplers.Sampler:
  """Build a validation-split sampler for `algo` on `nb_nodes`-node instances."""
  val_config = samplers.CLRS30['val']
  built, _ = samplers.build_sampler(
      algo,
      seed=val_config['seed'],
      num_samples=val_config['num_samples'],
      length=nb_nodes,
  )
  return built
def _make_iterable_sampler(
    algo: str, batch_size: int,
    nb_nodes: int) -> Generator[samplers.Feedback, None, None]:
  """Yield an endless stream of `batch_size`-sized feedback batches."""
  batch_source = _make_sampler(algo, nb_nodes)
  while True:
    yield batch_source.next(batch_size)
def _as_pred_data(x, nb_nodes, seed, batch_axis):
  """Fake a prediction from a data point.

  Permutes `x.data` along the batch axis so the "prediction" differs from
  the ground truth; pointer-typed data is additionally expanded to one-hot
  over the `nb_nodes` candidate targets.
  """
  shuffled = jax.random.permutation(
      jax.random.PRNGKey(seed), x.data, axis=batch_axis)
  if x.type_ != specs.Type.POINTER:
    return shuffled
  return jax.nn.one_hot(shuffled, nb_nodes)
def _mask_datapoint(x, seed, t_axis=None):
  """Return a copy of `x` where ~20% of entries are set to MASKED.

  For MASK data, entries are masked elementwise; for CATEGORICAL/MASK_ONE
  data, all categories of an entry are masked together.  When `t_axis` is
  given, the mask is broadcast along that (time) axis.
  """
  rng = jax.random.PRNGKey(seed)
  masked = x.data
  if x.type_ == specs.Type.MASK:
    # Elementwise masking at random positions.
    shape = list(masked.shape)
    if t_axis is not None:
      shape[t_axis] = 1
    drop = jax.random.uniform(rng, tuple(shape)) < 0.2
    masked = jnp.where(drop, specs.OutputClass.MASKED, masked)
  elif x.type_ in [specs.Type.CATEGORICAL, specs.Type.MASK_ONE]:
    # Mask whole category vectors (last axis) together.
    shape = list(masked.shape)[:-1]
    if t_axis is not None:
      shape[t_axis] = 1
    drop = jax.random.uniform(rng, tuple(shape)) < 0.2
    masked = jnp.where(drop[..., None], specs.OutputClass.MASKED, masked)
  return probing.DataPoint(
      name=x.name, location=x.location, type_=x.type_, data=masked)
| def _rand_diff(seed, shape): | |
| return 2.0 * jax.random.uniform(jax.random.PRNGKey(seed), shape) - 1.0 | |
| def _rand_mask(seed, shape, p=0.5): | |
| return (jax.random.uniform(jax.random.PRNGKey(seed), shape) > p).astype(float) | |
def invert(d):
  """Dict of lists -> list of dicts (returns None for an empty/falsy dict)."""
  if not d:
    return None
  keys = list(d)
  return [dict(zip(keys, values)) for values in zip(*d.values())]
def _create_data(algo, nb_nodes):
  """Return a (full, chunked) pair of samples for `algo`."""
  batch_size = 8
  full_sample = next(_make_iterable_sampler(algo, batch_size, nb_nodes))
  # Chunk length equals the full trajectory length, so the chunked sample
  # covers exactly the same steps as the full one.
  chunk_length = full_sample.features.lengths[0].astype(int)
  chunked_iter = dataset.chunkify(
      _make_iterable_sampler(algo, batch_size, nb_nodes), chunk_length)
  return full_sample, next(chunked_iter)
class FullVsChunkLossesTest(parameterized.TestCase):
  """Test that the full and chunked versions of the losses match."""

  # Test two algorithms with fixed length, covering all data types.
  # Fix: the methods take an `algo` argument, so they must be parameterized;
  # without the decorator absltest invokes them with no argument and fails.
  # NOTE(review): 'dfs' and 'floyd_warshall' chosen to jointly cover the
  # data types -- confirm against `specs.SPECS` for the intended coverage.
  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_output_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)
    # Calculate output loss in both chunked and full form; they must agree.
    for truth_full, truth_chunked in zip(full_sample.outputs,
                                         chunk_sample.outputs):
      chunk_output_loss = losses.output_loss_chunked(
          truth=_mask_datapoint(truth_chunked, seed=0),
          pred=_as_pred_data(truth_chunked, nb_nodes, 0, 1),
          is_last=chunk_sample.features.is_last,
          nb_nodes=nb_nodes,
      )
      full_output_loss = losses.output_loss(
          truth=_mask_datapoint(truth_full, seed=0),
          pred=_as_pred_data(truth_full, nb_nodes, 0, 0),
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunk_output_loss, full_output_loss,
                                 rtol=1e-4)

  @parameterized.parameters('dfs', 'floyd_warshall')
  def test_hint_loss(self, algo):
    nb_nodes = 16
    full_sample, chunk_sample = _create_data(algo, nb_nodes)
    for truth_full, truth_chunked in zip(full_sample.features.hints,
                                         chunk_sample.features.hints):
      np.testing.assert_array_equal(truth_full.data, truth_chunked.data)
      pred = _as_pred_data(truth_chunked, nb_nodes, 0, 1)
      chunk_hint_loss = losses.hint_loss_chunked(
          truth=_mask_datapoint(truth_chunked, seed=1, t_axis=0),
          pred=pred,
          is_first=chunk_sample.features.is_first,
          nb_nodes=nb_nodes,
      )
      # The full version predicts hints for steps 1..T from steps 0..T-1,
      # so drop the first step of the "prediction".
      full_preds = pred[1:]
      full_hint_loss = losses.hint_loss(
          truth=_mask_datapoint(truth_full, 1, t_axis=0),
          preds=full_preds,
          lengths=full_sample.features.lengths,
          nb_nodes=nb_nodes,
      )
      np.testing.assert_allclose(chunk_hint_loss, full_hint_loss, rtol=1e-4)
# Run the absl test harness when executed as a script.
if __name__ == '__main__':
  absltest.main()