Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Tests for compiler.expr_to_craft_graph.""" | |
| from absl.testing import absltest | |
| from absl.testing import parameterized | |
| from tracr.compiler import basis_inference | |
| from tracr.compiler import expr_to_craft_graph | |
| from tracr.compiler import lib | |
| from tracr.compiler import nodes | |
| from tracr.compiler import rasp_to_graph | |
| from tracr.craft import bases | |
| from tracr.craft import transformers | |
| from tracr.rasp import rasp | |
class ExprToCraftGraphTest(parameterized.TestCase):
  """Tests for `expr_to_craft_graph.add_craft_components_to_rasp_graph`."""

  def _check_block_types_are_correct(self, graph):
    """Asserts that each Map/SequenceMap got an MLP and each Aggregate a head.

    Args:
      graph: RASP graph on which `add_craft_components_to_rasp_graph` has
        already been run.
    """
    for _, node in graph.nodes.items():
      expr = node[nodes.EXPR]
      if isinstance(expr, rasp.SOp):
        block = node[nodes.MODEL_BLOCK]
        if isinstance(expr, (rasp.Map, rasp.SequenceMap)):
          self.assertIsInstance(block, transformers.MLP)
        elif isinstance(expr, rasp.Aggregate):
          self.assertIsInstance(block, transformers.AttentionHead)
        # Any other kind of SOp is not checked here.

  def _get_input_space_from_node(self, node):
    """Returns the vector space that a node's craft block reads from.

    For an MLP this is the input space of its first layer (`fst`). For an
    attention head it is the join of both sides of its QK matrix and the
    input space of its OV matrix.

    Args:
      node: Node-attribute dict from the graph, holding `nodes.MODEL_BLOCK`.

    Returns:
      The input vector space, or None if the block is neither an
      `transformers.MLP` nor an `transformers.AttentionHead`.
    """
    block = node[nodes.MODEL_BLOCK]
    if isinstance(block, transformers.MLP):
      return block.fst.input_space
    elif isinstance(block, transformers.AttentionHead):
      return bases.join_vector_spaces(block.w_qk.left_space,
                                      block.w_qk.right_space,
                                      block.w_ov.input_space)
    else:
      return None

  def _check_spaces_are_consistent(self, graph):
    """Checks every SOp-to-SOp edge for space consistency.

    Along each such edge (u, v), the output space of u must be a subspace of
    the input space of v's craft block. Edges where either endpoint is not
    an SOp are skipped.

    Args:
      graph: RASP graph on which `add_craft_components_to_rasp_graph` has
        already been run.
    """
    for u, v in graph.edges:
      u_node, v_node = graph.nodes[u], graph.nodes[v]
      if not (isinstance(u_node[nodes.EXPR], rasp.SOp) and
              isinstance(v_node[nodes.EXPR], rasp.SOp)):
        continue
      u_out_space = bases.VectorSpaceWithBasis(u_node[nodes.OUTPUT_BASIS])
      v_in_space = self._get_input_space_from_node(v_node)
      # Fail with a readable message rather than handing None to
      # issubspace() when the consumer has no MLP/AttentionHead block.
      self.assertIsNotNone(
          v_in_space, msg=f"Node {v} has no MLP or AttentionHead block.")
      self.assertTrue(
          u_out_space.issubspace(v_in_space),
          msg=f"Output space of {u} is not a subspace of the input of {v}.")

  # Without this decorator the test runner would call the method with no
  # `program` argument and the test would error instead of running. Each
  # program is extracted and compiled into its own, independent graph.
  @parameterized.parameters(
      dict(program=lib.make_length()),
      dict(program=lib.make_reverse(rasp.tokens)),
      dict(program=lib.make_hist()),
      dict(program=lib.make_sort_unique(rasp.tokens, rasp.tokens)),
  )
  def test_compiling_rasp_programs(self, program):
    """Compiles a library program and checks block types and spaces."""
    vocab = {0, 1, 2}
    extracted = rasp_to_graph.extract_rasp_graph(program)
    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=3,
    )
    expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
    self._check_block_types_are_correct(extracted.graph)
    self._check_spaces_are_consistent(extracted.graph)

  def test_add_craft_components_raises_value_error_if_called_before_basis_inference(
      self):
    """Adding craft components before basis inference must raise."""
    program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)
    with self.assertRaisesRegex(
        ValueError,
        r"^.*Craft components can only be added after basis inference.*$"):
      expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)

  def test_add_craft_components_raises_value_error_if_called_twice(self):
    """A second call on an already-populated graph must raise."""
    vocab = {0, 1, 2}
    program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)
    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=1,
    )
    expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
    with self.assertRaisesRegex(
        ValueError, r"^.*Input graph cannot have model blocks set already.*$"):
      expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
if __name__ == "__main__":
  # Run this module's tests through absl's test runner when executed directly.
  absltest.main()