Spaces:
Sleeping
Sleeping
| # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Tests for transformer.model.""" | |
| from absl.testing import absltest | |
| from absl.testing import parameterized | |
| import haiku as hk | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from tracr.transformer import model | |
| class TransformerTest(parameterized.TestCase): | |
| def _check_layer_naming(self, params): | |
| # Modules should be named for example | |
| # For MLPs: "transformer/layer_{i}/mlp/linear_1" | |
| # For Attention: "transformer/layer_{i}/attn/key" | |
| # For Layer Norm: "transformer/layer_{i}/layer_norm" | |
| for key in params.keys(): | |
| levels = key.split("/") | |
| self.assertEqual(levels[0], "transformer") | |
| if levels[1].startswith("layer_norm"): | |
| continue # output layer norm | |
| self.assertStartsWith(levels[1], "layer") | |
| if levels[2] == "mlp": | |
| self.assertIn(levels[3], {"linear_1", "linear_2"}) | |
| elif levels[2] == "attn": | |
| self.assertIn(levels[3], {"key", "query", "value", "linear"}) | |
| else: | |
| self.assertStartsWith(levels[2], "layer_norm") | |
| def _zero_mlps(self, params): | |
| for module in params: | |
| if "mlp" in module: | |
| for param in params[module]: | |
| params[module][param] = jnp.zeros_like(params[module][param]) | |
| return params | |
| def test_layer_norm(self, layer_norm): | |
| # input = [1, 1, 1, 1] | |
| # If layer norm is used, this should give all-0 output for a freshly | |
| # initialized model because LN will subtract the mean after each layer. | |
| # Else we expect non-zero outputs. | |
| def forward(emb, mask): | |
| transformer = model.Transformer( | |
| model.TransformerConfig( | |
| num_heads=2, | |
| num_layers=2, | |
| key_size=5, | |
| mlp_hidden_size=64, | |
| dropout_rate=0., | |
| layer_norm=layer_norm)) | |
| return transformer(emb, mask).output | |
| seq_len = 4 | |
| emb = jnp.ones((1, seq_len, 1)) | |
| mask = jnp.ones((1, seq_len)) | |
| rng = hk.PRNGSequence(1) | |
| params = forward.init(next(rng), emb, mask) | |
| out = forward.apply(params, next(rng), emb, mask) | |
| self._check_layer_naming(params) | |
| if layer_norm: | |
| np.testing.assert_allclose(out, 0) | |
| else: | |
| self.assertFalse(np.allclose(out, 0)) | |
| def test_causal_attention(self, causal): | |
| # input = [0, random, random, random] | |
| # mask = [1, 0, 1, 1] | |
| # For causal attention the second token can only attend to the first one, so | |
| # it should be the same. For non-causal attention all tokens should change. | |
| def forward(emb, mask): | |
| transformer = model.Transformer( | |
| model.TransformerConfig( | |
| num_heads=2, | |
| num_layers=2, | |
| key_size=5, | |
| mlp_hidden_size=64, | |
| dropout_rate=0., | |
| layer_norm=False, | |
| causal=causal)) | |
| return transformer(emb, mask).output | |
| seq_len = 4 | |
| emb = np.random.random((1, seq_len, 1)) | |
| emb[:, 0, :] = 0 | |
| mask = np.array([[1, 0, 1, 1]]) | |
| emb, mask = jnp.array(emb), jnp.array(mask) | |
| rng = hk.PRNGSequence(1) | |
| params = forward.init(next(rng), emb, mask) | |
| params = self._zero_mlps(params) | |
| out = forward.apply(params, next(rng), emb, mask) | |
| self._check_layer_naming(params) | |
| if causal: | |
| self.assertEqual(0, out[0, 0, 0]) | |
| self.assertEqual(emb[0, 1, 0], out[0, 1, 0]) | |
| else: | |
| self.assertNotEqual(0, out[0, 0, 0]) | |
| self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0]) | |
| self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0]) | |
| self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0]) | |
| def test_setting_activation_function_to_zero(self): | |
| # An activation function that always returns zeros should result in the | |
| # same model output as setting all MLP weights to zero. | |
| def forward_zero(emb, mask): | |
| transformer = model.Transformer( | |
| model.TransformerConfig( | |
| num_heads=2, | |
| num_layers=2, | |
| key_size=5, | |
| mlp_hidden_size=64, | |
| dropout_rate=0., | |
| causal=False, | |
| layer_norm=False, | |
| activation_function=jnp.zeros_like)) | |
| return transformer(emb, mask).output | |
| def forward(emb, mask): | |
| transformer = model.Transformer( | |
| model.TransformerConfig( | |
| num_heads=2, | |
| num_layers=2, | |
| key_size=5, | |
| mlp_hidden_size=64, | |
| dropout_rate=0., | |
| causal=False, | |
| layer_norm=False, | |
| activation_function=jax.nn.gelu)) | |
| return transformer(emb, mask).output | |
| seq_len = 4 | |
| emb = np.random.random((1, seq_len, 1)) | |
| mask = np.ones((1, seq_len)) | |
| emb, mask = jnp.array(emb), jnp.array(mask) | |
| rng = hk.PRNGSequence(1) | |
| params = forward.init(next(rng), emb, mask) | |
| params_no_mlps = self._zero_mlps(params) | |
| out_zero_activation = forward_zero.apply(params, next(rng), emb, mask) | |
| out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask) | |
| self._check_layer_naming(params) | |
| np.testing.assert_allclose(out_zero_activation, out_no_mlps) | |
| self.assertFalse(np.allclose(out_zero_activation, 0)) | |
| class CompiledTransformerModelTest(parameterized.TestCase): | |
| def _get_one_hot_embed_unembed(self, vocab_size, max_seq_len): | |
| # Embeds tokens as one-hot into the first `vocab_size` dimensions | |
| token_embed = hk.Embed( | |
| embedding_matrix=jnp.block( | |
| [jnp.eye(vocab_size), | |
| jnp.zeros((vocab_size, max_seq_len))])) | |
| # Embeds positions as one-hot into the last `max_seq_len` dimensions | |
| position_embed = hk.Embed( | |
| embedding_matrix=jnp.block( | |
| [jnp.zeros((max_seq_len, vocab_size)), | |
| jnp.eye(max_seq_len)])) | |
| class Unembed(hk.Module): | |
| def __call__(self, embeddings): | |
| return jnp.argmax(embeddings[:, :, :vocab_size], axis=-1) | |
| return token_embed, position_embed, Unembed() | |
| def test_embedding_gives_desired_result(self): | |
| tokens = jnp.array([[1, 2, 3]]) | |
| vocab_size, max_seq_len, pad_token = 5, 5, 0 | |
| expected_embeddings = jnp.array([[[0, 1, 0, 0, 0, 1, 0, 0, 0, 0], | |
| [0, 0, 1, 0, 0, 0, 1, 0, 0, 0], | |
| [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]]]) | |
| def embed(tokens): | |
| transformer = model.Transformer( | |
| model.TransformerConfig( | |
| num_heads=2, | |
| num_layers=2, | |
| key_size=5, | |
| mlp_hidden_size=64, | |
| dropout_rate=0., | |
| causal=False, | |
| layer_norm=False, | |
| activation_function=jax.nn.gelu)) | |
| token_embed, position_embed, unembed = self._get_one_hot_embed_unembed( | |
| vocab_size, max_seq_len) | |
| compiled_model = model.CompiledTransformerModel( | |
| transformer=transformer, | |
| token_embed=token_embed, | |
| position_embed=position_embed, | |
| unembed=unembed, | |
| use_unembed_argmax=True, | |
| pad_token=pad_token) | |
| return compiled_model.embed(tokens) | |
| rng = hk.PRNGSequence(1) | |
| params = embed.init(next(rng), tokens) | |
| embeddings = embed.apply(params, next(rng), tokens) | |
| np.testing.assert_allclose(embeddings, expected_embeddings) | |
| def test_embedding_then_unembedding_gives_same_tokens(self): | |
| tokens = jnp.array([[1, 2, 3], [4, 5, 6], [3, 2, 4]]) | |
| vocab_size, max_seq_len, pad_token = 10, 5, 0 | |
| def embed_unembed(tokens): | |
| transformer = model.Transformer( | |
| model.TransformerConfig( | |
| num_heads=2, | |
| num_layers=2, | |
| key_size=5, | |
| mlp_hidden_size=64, | |
| dropout_rate=0., | |
| causal=False, | |
| layer_norm=False, | |
| activation_function=jax.nn.gelu)) | |
| token_embed, position_embed, unembed = self._get_one_hot_embed_unembed( | |
| vocab_size, max_seq_len) | |
| compiled_model = model.CompiledTransformerModel( | |
| transformer=transformer, | |
| token_embed=token_embed, | |
| position_embed=position_embed, | |
| unembed=unembed, | |
| use_unembed_argmax=True, | |
| pad_token=pad_token) | |
| embeddings = compiled_model.embed(tokens) | |
| unembeddings = compiled_model.unembed(embeddings) | |
| return embeddings, unembeddings | |
| rng = hk.PRNGSequence(1) | |
| params = embed_unembed.init(next(rng), tokens) | |
| embeddings, unembeddings = embed_unembed.apply(params, next(rng), tokens) | |
| self.assertEqual( | |
| embeddings.shape, | |
| (tokens.shape[0], tokens.shape[1], vocab_size + max_seq_len)) | |
| np.testing.assert_allclose(unembeddings, tokens) | |
| if __name__ == "__main__": | |
| absltest.main() | |