Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Integration tests for the full RASP -> transformer compilation.""" | |
| from absl.testing import absltest | |
| from absl.testing import parameterized | |
| import jax | |
| import numpy as np | |
| from tracr.compiler import compiling | |
| from tracr.compiler import lib | |
| from tracr.compiler import test_cases | |
| from tracr.craft import tests_common | |
| from tracr.rasp import rasp | |
# Ask JAX for full float32 matmul precision so results do not depend on a
# faster, lower-precision accelerator default (e.g. on TPU).
jax.config.update("jax_default_matmul_precision", "float32")

# Special tokens handed to the compiler. The names are specific to this test
# module so they cannot be confused with ordinary vocabulary entries.
_COMPILER_PAD = "rasp_to_transformer_integration_test_PAD"
_COMPILER_BOS = "rasp_to_transformer_integration_test_BOS"
class CompilerIntegrationTest(tests_common.VectorFnTestCase):
  """Checks that compiled transformers reproduce their RASP programs.

  Each test compiles a RASP program with `compiling.compile_rasp_to_model`,
  applies the assembled model to a BOS-prefixed input and compares the decoded
  output (with the BOS position dropped) against a reference.

  NOTE(review): the test methods take arguments beyond `self`, so they
  presumably rely on parameterization (e.g. `parameterized` decorators) that is
  not visible in this excerpt -- confirm before running.
  """

  def assertSequenceEqualWhenExpectedIsNotNone(self, actual_seq, expected_seq):
    """Asserts two sequences match, skipping positions where expected is None.

    Fails the test if the sequences have different lengths, or if they differ
    at any position whose expected value is not `None`.

    Args:
      actual_seq: Sequence of values to check.
      expected_seq: Sequence of reference values. A `None` entry means the
        value at that position is not compared.
    """
    actual_items = list(actual_seq)
    expected_items = list(expected_seq)
    # zip() stops at the shorter input, so without this check missing or
    # surplus trailing elements would go unnoticed.
    if len(actual_items) != len(expected_items):
      self.fail(f"{actual_seq} does not match (ignoring Nones) "
                f"{expected_seq=}: lengths differ "
                f"({len(actual_items)} != {len(expected_items)})")
    for actual, expected in zip(actual_items, expected_items):
      if expected is not None and actual != expected:
        self.fail(f"{actual_seq} does not match (ignoring Nones) "
                  f"{expected_seq=}")

  def _assert_decoded_matches(self, decoded, expected_output,
                              ignore_nones=True):
    """Compares decoded model output against the expected output.

    If the first expected element is an int or float, the sequences are
    compared numerically with a tolerance. Otherwise they are compared for
    equality. An empty `expected_output` takes the equality path instead of
    raising `IndexError` on the `[0]` lookup.

    Args:
      decoded: Decoded model output with BOS (and any padding) removed.
      expected_output: Reference output sequence.
      ignore_nones: If True, `None` entries in `expected_output` are skipped
        during the equality comparison; if False, `assertEqual` is used.
    """
    is_numeric = (
        len(expected_output) > 0 and
        isinstance(expected_output[0], (int, float)))
    if is_numeric:
      np.testing.assert_allclose(
          decoded, expected_output, atol=1e-7, rtol=0.005)
    elif ignore_nones:
      self.assertSequenceEqualWhenExpectedIsNotNone(decoded, expected_output)
    else:
      self.assertEqual(decoded, expected_output)

  def test_rasp_program_and_transformer_produce_same_output(self, program):
    """Compares model and RASP outputs on every single-token input.

    Args:
      program: RASP program to compile and to evaluate directly.
    """
    vocab = {0, 1, 2}
    max_seq_len = 3
    assembled_model = compiling.compile_rasp_to_model(
        program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS)

    # Collect all outputs first, then report each vocabulary value as its own
    # subtest so one mismatch does not hide the others.
    test_outputs = {}
    rasp_outputs = {}
    for val in vocab:
      test_outputs[val] = assembled_model.apply([_COMPILER_BOS, val]).decoded[1]
      rasp_outputs[val] = program([val])[0]

    for val in sorted(vocab):
      with self.subTest(val=val):
        self.assertEqual(test_outputs[val], rasp_outputs[val])

  def test_compiled_models_produce_expected_output(self, program, vocab,
                                                   test_input, expected_output,
                                                   max_seq_len, **kwargs):
    """Checks the compiled model's decoded output against `expected_output`.

    Args:
      program: RASP program to compile.
      vocab: Vocabulary passed to the compiler.
      test_input: List of input tokens (without BOS).
      expected_output: Expected decoded output for `test_input`.
      max_seq_len: Maximum sequence length passed to the compiler.
      **kwargs: Additional test-case fields; ignored.
    """
    del kwargs
    assembled_model = compiling.compile_rasp_to_model(
        program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS)
    test_output = assembled_model.apply([_COMPILER_BOS] + test_input)
    self._assert_decoded_matches(test_output.decoded[1:], expected_output)

  def test_compiled_causal_models_produce_expected_output(
      self, program, vocab, test_input, expected_output, max_seq_len, **kwargs):
    """Same check as above, with the model compiled using `causal=True`.

    Args:
      program: RASP program to compile.
      vocab: Vocabulary passed to the compiler.
      test_input: List of input tokens (without BOS).
      expected_output: Expected decoded output for `test_input`.
      max_seq_len: Maximum sequence length passed to the compiler.
      **kwargs: Additional test-case fields; ignored.
    """
    del kwargs
    assembled_model = compiling.compile_rasp_to_model(
        program,
        vocab,
        max_seq_len,
        causal=True,
        compiler_bos=_COMPILER_BOS,
        compiler_pad=_COMPILER_PAD)
    test_output = assembled_model.apply([_COMPILER_BOS] + test_input)
    self._assert_decoded_matches(test_output.decoded[1:], expected_output)

  def test_compiled_models_produce_expected_output_with_padding(
      self, program, vocab, test_input, expected_output, max_seq_len, **kwargs):
    """Checks the output when the input is padded up to `max_seq_len`.

    The input is extended with `_COMPILER_PAD` tokens, the first decoded
    element must be the BOS token, and the decoded output with BOS and the
    padded tail removed must match `expected_output`.

    Args:
      program: RASP program to compile.
      vocab: Vocabulary passed to the compiler.
      test_input: List of input tokens (without BOS or padding).
      expected_output: Expected decoded output for `test_input`.
      max_seq_len: Maximum sequence length passed to the compiler.
      **kwargs: Additional test-case fields; ignored.
    """
    del kwargs
    assembled_model = compiling.compile_rasp_to_model(
        program,
        vocab,
        max_seq_len,
        compiler_bos=_COMPILER_BOS,
        compiler_pad=_COMPILER_PAD)

    pad_len = max_seq_len - len(test_input)
    padded_input = [_COMPILER_BOS] + test_input + [_COMPILER_PAD] * pad_len
    output = assembled_model.apply(padded_input).decoded
    # Drop the leading BOS position and the trailing padded positions.
    output_stripped = output[1:len(output) - pad_len]

    self.assertEqual(output[0], _COMPILER_BOS)
    # The non-numeric comparison here is strict: None entries are not skipped.
    self._assert_decoded_matches(
        output_stripped, expected_output, ignore_nones=False)
# Hand control to the absl test runner when executed as a script.
if __name__ == "__main__":
  absltest.main()