Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Tests for transformers.""" | |
| from absl.testing import absltest | |
| from absl.testing import parameterized | |
| import numpy as np | |
| from tracr.craft import bases | |
| from tracr.craft import tests_common | |
| from tracr.craft import transformers | |
| from tracr.craft import vectorspace_fns as vs_fns | |
| # This makes it easier to use comments to annotate dimensions in arrays | |
| # pylint: disable=g-no-space-after-comment | |
class AttentionHeadTest(tests_common.VectorFnTestCase):
  """Tests for `transformers.AttentionHead`."""

  # The test takes `with_residual_stream`, so it must be parameterized.
  # Without this decorator the runner would invoke it with no argument and the
  # test would error out with a TypeError instead of exercising the head.
  @parameterized.parameters([False, True])
  def test_attention_head(self, with_residual_stream):
    """Checks that a sharp query/key match routes i-values into o-space.

    Position N carries qN (query) and pN (key); w_qk scores qN against pN with
    weight 100 (and 0 otherwise), so each position attends almost entirely to
    itself. w_ov maps iN -> oN, so the expected output at position N is oN.

    Args:
      with_residual_stream: if True, the head is given `rs` as its residual
        space; otherwise `residual_space` is None.
    """
    i = bases.VectorSpaceWithBasis.from_values("i", [1, 2])
    o = bases.VectorSpaceWithBasis.from_values("o", [1, 2])
    q = bases.VectorSpaceWithBasis.from_values("q", [1, 2])
    # Key space is labelled "p" (see the p1/p2 columns below).
    k = bases.VectorSpaceWithBasis.from_values("p", [1, 2])
    rs = bases.direct_sum(i, o, q, k)

    seq = bases.VectorInBasis(
        rs.basis,
        np.array([
            #i1 i2 o1 o2 q1 q2 p1 p2
            [1, 0, 0, 0, 1, 0, 1, 0],
            [0, 1, 0, 0, 0, 1, 0, 1],
        ]))

    head = transformers.AttentionHead(
        # Large diagonal weight makes the softmax effectively a hard match.
        w_qk=vs_fns.ScalarBilinear(q, k, np.eye(2) * 100),
        w_ov=vs_fns.Linear(i, o, np.eye(2)),
        residual_space=rs if with_residual_stream else None,
        causal=False,
    )

    self.assertVectorAllClose(
        head.apply(seq),
        bases.VectorInBasis(
            rs.basis,
            np.array([
                #i1 i2 o1 o2 q1 q2 p1 p2
                [0, 0, 1, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0],
            ])),
    )
class MLPTest(tests_common.VectorFnTestCase):
  """Tests for `transformers.MLP`."""

  # The test takes two flags, so it must be parameterized. Without this the
  # runner would call it with no arguments and fail with a TypeError.
  @parameterized.parameters(
      dict(with_residual_stream=False, same_in_out=False),
      dict(with_residual_stream=False, same_in_out=True),
      dict(with_residual_stream=True, same_in_out=False),
      dict(with_residual_stream=True, same_in_out=True),
  )
  def test_mlp(self, with_residual_stream, same_in_out):
    """Checks an identity-weight MLP maps i -> h -> o, clipping negatives.

    Both layers use identity weights, so the only transformation is the MLP's
    hidden non-linearity. The expected results encode that negative
    pre-activations become 0 ([1, -1] -> [1, 0], [-1, 1] -> [0, 1]).

    Args:
      with_residual_stream: if True, the MLP is given `rs` as its residual
        space; otherwise `residual_space` is None.
      same_in_out: if True, the output space and residual space are the input
        space itself; otherwise a separate "o" space is used and the residual
        space is the direct sum of "i" and "o".
    """
    i = bases.VectorSpaceWithBasis.from_values("i", [1, 2])
    if same_in_out:
      o, rs = i, i
      expected_result = np.array([
          #o1 o2
          [1, 0],
          [0, 1],
      ])
    else:
      o = bases.VectorSpaceWithBasis.from_values("o", [1, 2])
      rs = bases.direct_sum(i, o)
      # Only the o-components are written; the i-components of the output
      # are zero.
      expected_result = np.array([
          #i1 i2 o1 o2
          [0, 0, 1, 0],
          [0, 0, 0, 1],
      ])
    h = bases.VectorSpaceWithBasis.from_values("p", [1, 2])

    seq = bases.VectorInBasis(
        i.basis,
        np.array([
            #i1 i2
            [1, -1],
            [-1, 1],
        ])).project(rs)

    mlp = transformers.MLP(
        fst=vs_fns.Linear(i, h, np.eye(2)),
        snd=vs_fns.Linear(h, o, np.eye(2)),
        residual_space=rs if with_residual_stream else None,
    )

    self.assertEqual(
        mlp.apply(seq),
        bases.VectorInBasis(rs.basis, expected_result),
    )

  def test_combining_mlps(self):
    """Checks that MLPs combined in parallel add their separate outputs.

    mlp1 reads in2 and writes out1; mlp2 reads in3 and writes out2. Applying
    the combined MLP should therefore yield [in2, in3] in the out1/out2
    coordinates (with the in1/in4 inputs ignored).
    """
    in12 = bases.VectorSpaceWithBasis.from_values("in", [1, 2])
    in34 = bases.VectorSpaceWithBasis.from_values("in", [3, 4])
    out12 = bases.VectorSpaceWithBasis.from_values("out", [1, 2])
    residual_space = bases.join_vector_spaces(in12, in34, out12)

    h1 = bases.VectorSpaceWithBasis.from_values("h", [1])
    h2 = bases.VectorSpaceWithBasis.from_values("h", [2])

    # MLP1 maps in2 -> h1 -> out1
    mlp1 = transformers.MLP(
        fst=vs_fns.Linear(in12, h1, np.array([[0], [1]])),
        snd=vs_fns.Linear(h1, out12, np.array([[1, 0]])))

    # MLP2 maps in3 -> h2 -> out2
    mlp2 = transformers.MLP(
        fst=vs_fns.Linear(in34, h2, np.array([[1], [0]])),
        snd=vs_fns.Linear(h2, out12, np.array([[0, 1]])))

    mlp = transformers.MLP.combine_in_parallel([mlp1, mlp2])

    seq = bases.VectorInBasis(
        bases.direct_sum(in12, in34).basis,
        np.array([
            #i1 i2 i3 i4
            [1, 2, 0, 0],
            [0, 2, 3, 4],
        ])).project(residual_space)

    expected_result = bases.VectorInBasis(
        out12.basis,
        np.array([
            #o1 o2
            [2, 0],
            [2, 3],
        ]))

    self.assertEqual(
        mlp.apply(seq).project(out12),
        expected_result,
    )
if __name__ == "__main__":
  # Script entry point: hand control to absl's test runner.
  absltest.main()