# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| """Unit tests for `baselines.py`.""" | |
| import copy | |
| import functools | |
| from typing import Generator | |
| from absl.testing import absltest | |
| from absl.testing import parameterized | |
| import chex | |
| from clrs._src import baselines | |
| from clrs._src import dataset | |
| from clrs._src import probing | |
| from clrs._src import processors | |
| from clrs._src import samplers | |
| from clrs._src import specs | |
| import haiku as hk | |
| import jax | |
| import numpy as np | |
| _Array = np.ndarray | |


def _error(x, y):
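  # Total absolute elementwise difference; tree-mapped over parameter pytrees
  # in the tests below to check whether an update step changed the parameters.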
  return np.sum(np.abs(x - y))


def _make_sampler(algo: str, length: int) -> samplers.Sampler:
  sampler, _ = samplers.build_sampler(
      algo,
      seed=samplers.CLRS30['val']['seed'],
      num_samples=samplers.CLRS30['val']['num_samples'],
      length=length,
  )
  return sampler


def _without_permutation(feedback):
  """Replace should-be permutations with pointers."""
  outputs = []
  for x in feedback.outputs:
    if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
      outputs.append(x)
      continue
    assert x.location == specs.Location.NODE
    outputs.append(probing.DataPoint(name=x.name, location=x.location,
                                     type_=specs.Type.POINTER, data=x.data))
  return feedback._replace(outputs=outputs)


def _make_iterable_sampler(
    algo: str, batch_size: int,
    length: int) -> Generator[samplers.Feedback, None, None]:
  sampler = _make_sampler(algo, length)
  while True:
    yield _without_permutation(sampler.next(batch_size))


def _remove_permutation_from_spec(spec):
  """Modify spec to turn permutation type to pointer."""
  new_spec = {}
  for k in spec:
    if (spec[k][1] == specs.Location.NODE and
        spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
      new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
    else:
      new_spec[k] = spec[k]
  return new_spec


class BaselinesTest(parameterized.TestCase):

  def test_full_vs_chunked(self):
    """Test that chunking does not affect gradients."""
    batch_size = 4
    length = 8
    algo = 'insertion_sort'
    spec = _remove_permutation_from_spec(specs.SPECS[algo])
    rng_key = jax.random.PRNGKey(42)

    full_ds = _make_iterable_sampler(algo, batch_size, length)
    chunked_ds = dataset.chunkify(
        _make_iterable_sampler(algo, batch_size, length),
        length)
    double_chunked_ds = dataset.chunkify(
        _make_iterable_sampler(algo, batch_size, length),
        length * 2)
    full_batches = [next(full_ds) for _ in range(2)]
    chunked_batches = [next(chunked_ds) for _ in range(2)]
    double_chunk_batch = next(double_chunked_ds)
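
    # Chunk sizes are chosen so each chunked batch covers the same
    # trajectories as the corresponding full batch, while the double-length
    # chunks cover two consecutive trajectories each.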
    with chex.fake_jit():  # jitting makes test longer
      processor_factory = processors.get_processor_factory(
          'mpnn', use_ln=False, nb_triplet_fts=0)
      common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                         learning_rate=0.01,
                         decode_hints=True, encode_hints=True)

      b_full = baselines.BaselineModel(
          spec, dummy_trajectory=full_batches[0], **common_args)
      b_full.init(full_batches[0].features, seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
      full_params = b_full.params
      full_loss_0 = b_full.feedback(rng_key, full_batches[0])
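      # `feedback` takes one optimizer step, so the initial parameters are
      # restored before the next call to measure each loss from the same
      # starting point.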
      b_full.params = full_params
      full_loss_1 = b_full.feedback(rng_key, full_batches[1])
      new_full_params = b_full.params

      b_chunked = baselines.BaselineModelChunked(
          spec, dummy_trajectory=chunked_batches[0], **common_args)
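      # `BaselineModelChunked.init` takes a doubly nested feature list; the
      # inner level indexes algorithms (cf. `test_multi_algorithm_idx`), so a
      # single algorithm needs the [[...]] wrapping.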
      b_chunked.init([[chunked_batches[0].features]], seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
      chunked_params = b_chunked.params
      jax.tree_util.tree_map(np.testing.assert_array_equal, full_params,
                             chunked_params)
      chunked_loss_0 = b_chunked.feedback(rng_key, chunked_batches[0])
      b_chunked.params = chunked_params
      chunked_loss_1 = b_chunked.feedback(rng_key, chunked_batches[1])
      new_chunked_params = b_chunked.params
      b_chunked.params = chunked_params
      double_chunked_loss = b_chunked.feedback(rng_key, double_chunk_batch)

    # Test that losses match
    np.testing.assert_allclose(full_loss_0, chunked_loss_0, rtol=1e-4)
    np.testing.assert_allclose(full_loss_1, chunked_loss_1, rtol=1e-4)
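    # A double-length chunk spans both trajectories, so its loss should be
    # the mean of the two per-trajectory losses.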
    np.testing.assert_allclose(full_loss_0 + full_loss_1,
                               2 * double_chunked_loss,
                               rtol=1e-4)

    # Test that gradients are the same (parameters changed equally).
    # First check that gradients were not zero, i.e., parameters have changed.
    param_change, _ = jax.tree_util.tree_flatten(
        jax.tree_util.tree_map(_error, full_params, new_full_params))
    self.assertGreater(np.mean(param_change), 0.1)
    # Now check that full and chunked gradients are the same.
    jax.tree_util.tree_map(
        functools.partial(np.testing.assert_allclose, rtol=1e-4),
        new_full_params, new_chunked_params)

  def test_multi_vs_single(self):
    """Test that multi = single when we only train one of the algorithms."""
    batch_size = 4
    length = 16
    algos = ['insertion_sort', 'activity_selector', 'bfs']
    spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
    rng_key = jax.random.PRNGKey(42)

    full_ds = [_make_iterable_sampler(algo, batch_size, length)
               for algo in algos]
    full_batches = [next(ds) for ds in full_ds]
    full_batches_2 = [next(ds) for ds in full_ds]

    with chex.fake_jit():  # jitting makes test longer
      processor_factory = processors.get_processor_factory(
          'mpnn', use_ln=False, nb_triplet_fts=0)
      common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                         learning_rate=0.01,
                         decode_hints=True, encode_hints=True)

      b_single = baselines.BaselineModel(
          spec[0], dummy_trajectory=full_batches[0], **common_args)
      b_multi = baselines.BaselineModel(
          spec, dummy_trajectory=full_batches, **common_args)
      b_single.init(full_batches[0].features, seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
      b_multi.init([f.features for f in full_batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray

      single_params = []
      single_losses = []
      multi_params = []
      multi_losses = []
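
      # Take two training steps on the same data with each model, recording
      # the parameters before and after every step.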
      single_params.append(copy.deepcopy(b_single.params))
      single_losses.append(b_single.feedback(rng_key, full_batches[0]))
      single_params.append(copy.deepcopy(b_single.params))
      single_losses.append(b_single.feedback(rng_key, full_batches_2[0]))
      single_params.append(copy.deepcopy(b_single.params))

      multi_params.append(copy.deepcopy(b_multi.params))
      multi_losses.append(b_multi.feedback(rng_key, full_batches[0],
                                           algorithm_index=0))
      multi_params.append(copy.deepcopy(b_multi.params))
      multi_losses.append(b_multi.feedback(rng_key, full_batches_2[0],
                                           algorithm_index=0))
      multi_params.append(copy.deepcopy(b_multi.params))

    # Test that losses match
    np.testing.assert_array_equal(single_losses, multi_losses)
    # Test that loss decreased
    assert single_losses[1] < single_losses[0]
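
    # The multi-algorithm model shares processor parameters across algorithms
    # and keeps per-algorithm encoders/decoders, so the single-algorithm
    # parameters should form a subset of the multi-algorithm ones.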
    # Test that param changes were the same in single and multi-algorithm
    for single, multi in zip(single_params, multi_params):
      assert hk.data_structures.is_subset(subset=single, superset=multi)
      for module_name, params in single.items():
        jax.tree_util.tree_map(np.testing.assert_array_equal, params,
                               multi[module_name])

    # Test that params change for the trained algorithm, but not the others
    for module_name, params in multi_params[0].items():
      param_changes = jax.tree_util.tree_map(
          lambda a, b: np.sum(np.abs(a - b)),
          params, multi_params[1][module_name])
      param_change = sum(param_changes.values())
      if module_name in single_params[0]:  # params of trained algorithm
        assert param_change > 1e-3
      else:  # params of non-trained algorithms
        assert param_change == 0.0

  @parameterized.parameters(True, False)
  def test_multi_algorithm_idx(self, is_chunked):
| """Test that algorithm selection works as intended.""" | |
| batch_size = 4 | |
| length = 8 | |
| algos = ['insertion_sort', 'activity_selector', 'bfs'] | |
| spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos] | |
| rng_key = jax.random.PRNGKey(42) | |
| if is_chunked: | |
| ds = [dataset.chunkify(_make_iterable_sampler(algo, batch_size, length), | |
| 2 * length) for algo in algos] | |
| else: | |
| ds = [_make_iterable_sampler(algo, batch_size, length) for algo in algos] | |
| batches = [next(d) for d in ds] | |

    processor_factory = processors.get_processor_factory(
        'mpnn', use_ln=False, nb_triplet_fts=0)
    common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                       learning_rate=0.01,
                       decode_hints=True, encode_hints=True)

    if is_chunked:
      baseline = baselines.BaselineModelChunked(
          spec, dummy_trajectory=batches, **common_args)
      baseline.init([[f.features for f in batches]], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
    else:
      baseline = baselines.BaselineModel(
          spec, dummy_trajectory=batches, **common_args)
      baseline.init([f.features for f in batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
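
    # The chunked model is initialized with a doubly nested feature list
    # ([group][algorithm]); the unchunked model takes a flat list with one
    # entry per algorithm.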
    # Find out what parameters change when we train each algorithm
    def _change(x, y):
      changes = {}
      for module_name, params in x.items():
        changes[module_name] = sum(
            jax.tree_util.tree_map(
                lambda a, b: np.sum(np.abs(a - b)), params, y[module_name]
            ).values())
      return changes

    param_changes = []
    for algo_idx in range(len(algos)):
      init_params = copy.deepcopy(baseline.params)
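      # The chunked model is addressed with a (group, algorithm) index pair;
      # the unchunked model takes a plain algorithm index.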
      _ = baseline.feedback(
          rng_key,
          batches[algo_idx],
          algorithm_index=(0, algo_idx) if is_chunked else algo_idx)
      param_changes.append(_change(init_params, baseline.params))

    # Test that non-changing parameters correspond to encoders/decoders
    # associated with the non-trained algorithms
    unchanged = [[k for k in pc if pc[k] == 0] for pc in param_changes]

    def _get_other_algos(algo_idx, modules):
      return set([k for k in modules if '_construct_encoders_decoders' in k
                  and f'algo_{algo_idx}' not in k])

    for algo_idx in range(len(algos)):
      expected_unchanged = _get_other_algos(algo_idx, baseline.params.keys())
      self.assertNotEmpty(expected_unchanged)
      self.assertSetEqual(expected_unchanged, set(unchanged[algo_idx]))


if __name__ == '__main__':
  absltest.main()