Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Integration tests for the RASP -> craft stages of the compiler.""" | |
| import unittest | |
| from absl.testing import absltest | |
| from absl.testing import parameterized | |
| import numpy as np | |
| from tracr.compiler import basis_inference | |
| from tracr.compiler import craft_graph_to_model | |
| from tracr.compiler import expr_to_craft_graph | |
| from tracr.compiler import nodes | |
| from tracr.compiler import rasp_to_graph | |
| from tracr.compiler import test_cases | |
| from tracr.craft import bases | |
| from tracr.craft import tests_common | |
| from tracr.rasp import rasp | |
| _BOS_DIRECTION = "rasp_to_transformer_integration_test_BOS" | |
| _ONE_DIRECTION = "rasp_to_craft_integration_test_ONE" | |
| def _make_input_space(vocab, max_seq_len): | |
| tokens_space = bases.VectorSpaceWithBasis.from_values("tokens", vocab) | |
| indices_space = bases.VectorSpaceWithBasis.from_values( | |
| "indices", range(max_seq_len)) | |
| one_space = bases.VectorSpaceWithBasis.from_names([_ONE_DIRECTION]) | |
| bos_space = bases.VectorSpaceWithBasis.from_names([_BOS_DIRECTION]) | |
| input_space = bases.join_vector_spaces(tokens_space, indices_space, one_space, | |
| bos_space) | |
| return input_space | |
| def _embed_input(input_seq, input_space): | |
| bos_vec = input_space.vector_from_basis_direction( | |
| bases.BasisDirection(_BOS_DIRECTION)) | |
| one_vec = input_space.vector_from_basis_direction( | |
| bases.BasisDirection(_ONE_DIRECTION)) | |
| embedded_input = [bos_vec + one_vec] | |
| for i, val in enumerate(input_seq): | |
| i_vec = input_space.vector_from_basis_direction( | |
| bases.BasisDirection("indices", i)) | |
| val_vec = input_space.vector_from_basis_direction( | |
| bases.BasisDirection("tokens", val)) | |
| embedded_input.append(i_vec + val_vec + one_vec) | |
| return bases.VectorInBasis.stack(embedded_input) | |
| def _embed_output(output_seq, output_space, categorical_output): | |
| embedded_output = [] | |
| output_label = output_space.basis[0].name | |
| for x in output_seq: | |
| if x is None: | |
| out_vec = output_space.null_vector() | |
| elif categorical_output: | |
| out_vec = output_space.vector_from_basis_direction( | |
| bases.BasisDirection(output_label, x)) | |
| else: | |
| out_vec = x * output_space.vector_from_basis_direction( | |
| output_space.basis[0]) | |
| embedded_output.append(out_vec) | |
| return bases.VectorInBasis.stack(embedded_output) | |
| class CompilerIntegrationTest(tests_common.VectorFnTestCase): | |
| def test_rasp_program_and_craft_model_produce_same_output(self, program): | |
| vocab = {0, 1, 2} | |
| max_seq_len = 3 | |
| extracted = rasp_to_graph.extract_rasp_graph(program) | |
| basis_inference.infer_bases( | |
| extracted.graph, | |
| extracted.sink, | |
| vocab, | |
| max_seq_len=max_seq_len, | |
| ) | |
| expr_to_craft_graph.add_craft_components_to_rasp_graph( | |
| extracted.graph, | |
| bos_dir=bases.BasisDirection(_BOS_DIRECTION), | |
| one_dir=bases.BasisDirection(_ONE_DIRECTION), | |
| ) | |
| model = craft_graph_to_model.craft_graph_to_model(extracted.graph, | |
| extracted.sources) | |
| input_space = _make_input_space(vocab, max_seq_len) | |
| output_space = bases.VectorSpaceWithBasis( | |
| extracted.sink[nodes.OUTPUT_BASIS]) | |
| for val in vocab: | |
| test_input = _embed_input([val], input_space) | |
| rasp_output = program([val]) | |
| expected_output = _embed_output( | |
| output_seq=rasp_output, | |
| output_space=output_space, | |
| categorical_output=True) | |
| test_output = model.apply(test_input).project(output_space) | |
| self.assertVectorAllClose( | |
| tests_common.strip_bos_token(test_output), expected_output) | |
| def test_compiled_models_produce_expected_output(self, program, vocab, | |
| test_input, expected_output, | |
| max_seq_len, **kwargs): | |
| del kwargs | |
| categorical_output = rasp.is_categorical(program) | |
| extracted = rasp_to_graph.extract_rasp_graph(program) | |
| basis_inference.infer_bases( | |
| extracted.graph, | |
| extracted.sink, | |
| vocab, | |
| max_seq_len=max_seq_len, | |
| ) | |
| expr_to_craft_graph.add_craft_components_to_rasp_graph( | |
| extracted.graph, | |
| bos_dir=bases.BasisDirection(_BOS_DIRECTION), | |
| one_dir=bases.BasisDirection(_ONE_DIRECTION), | |
| ) | |
| model = craft_graph_to_model.craft_graph_to_model(extracted.graph, | |
| extracted.sources) | |
| input_space = _make_input_space(vocab, max_seq_len) | |
| output_space = bases.VectorSpaceWithBasis( | |
| extracted.sink[nodes.OUTPUT_BASIS]) | |
| if not categorical_output: | |
| self.assertLen(output_space.basis, 1) | |
| test_input_vector = _embed_input(test_input, input_space) | |
| expected_output_vector = _embed_output( | |
| output_seq=expected_output, | |
| output_space=output_space, | |
| categorical_output=categorical_output) | |
| test_output = model.apply(test_input_vector).project(output_space) | |
| self.assertVectorAllClose( | |
| tests_common.strip_bos_token(test_output), expected_output_vector) | |
| def test_setting_default_values_can_lead_to_wrong_outputs_in_compiled_model( | |
| self, program): | |
| # This is an example program in which setting a default value for aggregate | |
| # writes a value to the bos token position, which interfers with a later | |
| # aggregate operation causing the compiled model to have the wrong output. | |
| vocab = {"a", "b"} | |
| test_input = ["a"] | |
| max_seq_len = 2 | |
| # RASP: [False, True] | |
| # compiled: [False, False, True] | |
| not_a = rasp.Map(lambda x: x != "a", rasp.tokens) | |
| # RASP: | |
| # [[True, False], | |
| # [False, False]] | |
| # compiled: | |
| # [[False,True, False], | |
| # [True, False, False]] | |
| sel1 = rasp.Select(rasp.tokens, rasp.tokens, | |
| lambda k, q: k == "a" and q == "a") | |
| # RASP: [False, True] | |
| # compiled: [True, False, True] | |
| agg1 = rasp.Aggregate(sel1, not_a, default=True) | |
| # RASP: | |
| # [[False, True] | |
| # [True, True]] | |
| # compiled: | |
| # [[True, False, False] | |
| # [True, False, False]] | |
| # because pre-softmax we get | |
| # [[1.5, 1, 1] | |
| # [1.5, 1, 1]] | |
| # instead of | |
| # [[0.5, 1, 1] | |
| # [0.5, 1, 1]] | |
| # Because agg1 = True is stored on the BOS token position | |
| sel2 = rasp.Select(agg1, agg1, lambda k, q: k or q) | |
| # RASP: [1, 0.5] | |
| # compiled | |
| # [1, 1, 1] | |
| program = rasp.numerical( | |
| rasp.Aggregate(sel2, rasp.numerical(not_a), default=1)) | |
| expected_output = [1, 0.5] | |
| # RASP program gives the correct output | |
| program_output = program(test_input) | |
| np.testing.assert_allclose(program_output, expected_output) | |
| extracted = rasp_to_graph.extract_rasp_graph(program) | |
| basis_inference.infer_bases( | |
| extracted.graph, | |
| extracted.sink, | |
| vocab, | |
| max_seq_len=max_seq_len, | |
| ) | |
| expr_to_craft_graph.add_craft_components_to_rasp_graph( | |
| extracted.graph, | |
| bos_dir=bases.BasisDirection(_BOS_DIRECTION), | |
| one_dir=bases.BasisDirection(_ONE_DIRECTION), | |
| ) | |
| model = craft_graph_to_model.craft_graph_to_model(extracted.graph, | |
| extracted.sources) | |
| input_space = _make_input_space(vocab, max_seq_len) | |
| output_space = bases.VectorSpaceWithBasis( | |
| extracted.sink[nodes.OUTPUT_BASIS]) | |
| test_input_vector = _embed_input(test_input, input_space) | |
| expected_output_vector = _embed_output( | |
| output_seq=expected_output, | |
| output_space=output_space, | |
| categorical_output=True) | |
| compiled_model_output = model.apply(test_input_vector).project(output_space) | |
| # Compiled craft model gives correct output | |
| self.assertVectorAllClose( | |
| tests_common.strip_bos_token(compiled_model_output), | |
| expected_output_vector) | |
| if __name__ == "__main__": | |
| absltest.main() | |