Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Tests for compiler.assemble.""" | |
| from absl.testing import absltest | |
| from absl.testing import parameterized | |
| import haiku as hk | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from tracr.compiler import assemble | |
| from tracr.craft import bases | |
class AssembleTest(parameterized.TestCase):
  """Tests for the embedding modules from `assemble._make_embedding_modules`.

  Every test uses the same 7-dimensional residual space, the join of an input
  space (2 values), an indices space (3 values) and an output space (2 values),
  laid out in that order.
  """

  def setUp(self):
    super().setUp()
    from_values = bases.VectorSpaceWithBasis.from_values
    # Prepend numbers to the names to preserve basis order
    # [input, index, output] in the joined residual space.
    self._input_space = from_values("0inp", range(2))
    self._indices_space = from_values("1ind", range(3))
    self._output_space = from_values("2out", range(2))
    self._residual_space = bases.join_vector_spaces(
        self._input_space, self._indices_space, self._output_space)

  def _make_embed_modules(self):
    """Returns the embedding modules for the spaces created in `setUp`.

    Called from within the `hk.transform`-ed functions of the tests below.
    """
    return assemble._make_embedding_modules(
        residual_space=self._residual_space,
        tokens_space=self._input_space,
        indices_space=self._indices_space,
        output_space=self._output_space)

  def test_token_embedding_produces_correct_embedding(self):
    """Checks `token_embed` against hand-written one-hot rows."""
    # Token embeddings should be one-hot embeddings of the input integers
    # into the token subspace of residual_space

    # `hk.transform` supplies the `.init` / `.apply` pair used below;
    # `hk.without_apply_rng` lets `apply` be called without an RNG key.
    @hk.without_apply_rng
    @hk.transform
    def token_pos_embed(tokens):
      return self._make_embed_modules().token_embed(tokens)

    tokens = jnp.array([0, 0, 1])
    expected_token_embeddings = jnp.array([[1, 0, 0, 0, 0, 0, 0],
                                           [1, 0, 0, 0, 0, 0, 0],
                                           [0, 1, 0, 0, 0, 0, 0]])

    params = token_pos_embed.init(jax.random.PRNGKey(0), tokens)
    embeddings = token_pos_embed.apply(params, tokens)
    np.testing.assert_allclose(embeddings, expected_token_embeddings)

  def test_position_embedding_produces_correct_embedding(self):
    """Checks `pos_embed` applied to the positions of a token sequence."""
    # Position embeddings should be one-hot embeddings of the input integers
    # (representing indices) into the indices subspace of residual_space

    @hk.without_apply_rng
    @hk.transform
    def token_pos_embed(tokens):
      # Only the shape of `tokens` matters: the positions along the last axis
      # are what gets embedded, not the token values.
      positions = jnp.indices(tokens.shape)[-1]
      return self._make_embed_modules().pos_embed(positions)

    # Four positions (0..3) are embedded although the indices space only has
    # the values 0..2.
    # NOTE(review): confirm the expected rows below against how `pos_embed`
    # treats each position, in particular the out-of-range position 3.
    tokens = jnp.array([3, 0, 0, 1])
    expected_pos_embeddings = jnp.array([[0, 0, 0, 0, 0, 0, 0],
                                         [0, 0, 1, 0, 0, 0, 0],
                                         [0, 0, 0, 1, 0, 0, 0],
                                         [0, 0, 0, 0, 1, 0, 0]])

    params = token_pos_embed.init(jax.random.PRNGKey(0), tokens)
    embeddings = token_pos_embed.apply(params, tokens)
    np.testing.assert_allclose(embeddings, expected_pos_embeddings)

  def test_unembedding(self):
    """Checks `unembed` with `use_unembed_argmax=True` on one-hot outputs."""

    @hk.without_apply_rng
    @hk.transform
    def unembed(embeddings):
      return self._make_embed_modules().unembed(
          embeddings, use_unembed_argmax=True)

    embeddings = jnp.array([
        # pylint: disable=g-no-space-after-comment
        #inp| indices| out | < spaces
        #0  1  0  1  2  0  1  < values in spaces
        [0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1]
    ])
    expected_tokens = jnp.array([1, 0, 1])

    params = unembed.init(jax.random.PRNGKey(0), embeddings)
    tokens = unembed.apply(params, embeddings)
    np.testing.assert_allclose(tokens, expected_tokens)
# Hand control to absl's test runner when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()