# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer.compressed_model."""

from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from tracr.transformer import compressed_model
from tracr.transformer import model


class CompressedTransformerTest(parameterized.TestCase):

  def _check_layer_naming(self, params):
    # Modules should be named as follows, for example:
    #   For MLPs: "compressed_transformer/layer_{i}/mlp/linear_1"
    #   For Attention: "compressed_transformer/layer_{i}/attn/key"
    #   For Layer Norm: "compressed_transformer/layer_{i}/layer_norm"
    for key in params.keys():
      levels = key.split("/")
      self.assertEqual(levels[0], "compressed_transformer")
      if len(levels) == 1:
        self.assertEqual(list(params[key].keys()), ["w_emb"])
        continue
      if levels[1].startswith("layer_norm"):
        continue  # output layer norm
      self.assertStartsWith(levels[1], "layer")
      if levels[2] == "mlp":
        self.assertIn(levels[3], {"linear_1", "linear_2"})
      elif levels[2] == "attn":
        self.assertIn(levels[3], {"key", "query", "value", "linear"})
      else:
        self.assertStartsWith(levels[2], "layer_norm")

  def _zero_mlps(self, params):
    # Zero out every parameter of the MLP modules so that only the attention
    # blocks and skip connections contribute to the output.
    for module in params:
      if "mlp" in module:
        for param in params[module]:
          params[module][param] = jnp.zeros_like(params[module][param])
    return params
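
  # The parameterized decorator below is an assumed reconstruction: the test
  # method takes a `layer_norm` argument, so it has to be driven by absl's
  # parameterized runner. The True/False cases follow from the test's own
  # branches.
  @parameterized.parameters(dict(layer_norm=True), dict(layer_norm=False))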
  def test_layer_norm(self, layer_norm):
    # input = [1, 1, 1, 1]
    # If layer norm is used, this should give all-0 output for a freshly
    # initialized model because LN will subtract the mean after each layer.
    # Else we expect non-zero outputs.

    # hk.transform is required so that `forward` exposes .init and .apply.
    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              layer_norm=layer_norm))
      return transformer(emb, mask).output

    seq_len = 4
    emb = jnp.ones((1, seq_len, 1))
    mask = jnp.ones((1, seq_len))

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    out = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)
    if layer_norm:
      np.testing.assert_allclose(out, 0)
    else:
      self.assertFalse(np.allclose(out, 0))
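
  # As above, a parameterized decorator is assumed so that `causal` is
  # supplied by the runner; both branches of the test are exercised.
  @parameterized.parameters(dict(causal=True), dict(causal=False))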
  def test_causal_attention(self, causal):
    # input = [0, random, random, random]
    # mask = [1, 0, 1, 1]
    # For causal attention the second token can only attend to the first one,
    # so it should stay unchanged. For non-causal attention all tokens change.

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              layer_norm=False,
              causal=causal))
      return transformer(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    emb[:, 0, :] = 0
    mask = np.array([[1, 0, 1, 1]])
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    params = self._zero_mlps(params)
    out = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)
    if causal:
      self.assertEqual(0, out[0, 0, 0])
      self.assertEqual(emb[0, 1, 0], out[0, 1, 0])
    else:
      self.assertNotEqual(0, out[0, 0, 0])
      self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0])
    self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0])
    self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0])

  def test_setting_activation_function_to_zero(self):
    # An activation function that always returns zeros should result in the
    # same model output as setting all MLP weights to zero.

    @hk.transform
    def forward_zero(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jnp.zeros_like))
      return transformer(emb, mask).output

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jax.nn.gelu))
      return transformer(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    params_no_mlps = self._zero_mlps(params)

    out_zero_activation = forward_zero.apply(params, next(rng), emb, mask)
    out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask)

    self._check_layer_naming(params)
    np.testing.assert_allclose(out_zero_activation, out_no_mlps)
    self.assertFalse(np.allclose(out_zero_activation, 0))

  def test_not_setting_embedding_size_produces_same_output_as_default_model(
      self):
    config = model.TransformerConfig(
        num_heads=2,
        num_layers=2,
        key_size=5,
        mlp_hidden_size=64,
        dropout_rate=0.,
        causal=False,
        layer_norm=False)

    # hk.without_apply_rng lets us call .apply without an rng key, matching
    # the apply calls below.
    @hk.without_apply_rng
    @hk.transform
    def forward_model(emb, mask):
      return model.Transformer(config)(emb, mask).output

    @hk.without_apply_rng
    @hk.transform
    def forward_superposition(emb, mask):
      return compressed_model.CompressedTransformer(config)(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward_model.init(next(rng), emb, mask)
    params_superposition = {
        k.replace("transformer", "compressed_transformer"): v
        for k, v in params.items()
    }

    out_model = forward_model.apply(params, emb, mask)
    out_superposition = forward_superposition.apply(params_superposition, emb,
                                                    mask)

    self._check_layer_naming(params_superposition)
    np.testing.assert_allclose(out_model, out_superposition)
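
  # Assumed parameterization: the test needs `embedding_size` and
  # `unembed_at_every_layer`. The sizes 2 and 6 are illustrative choices; any
  # embedding size smaller than the model size of 16 used below would do.
  @parameterized.parameters(
      dict(embedding_size=2, unembed_at_every_layer=True),
      dict(embedding_size=2, unembed_at_every_layer=False),
      dict(embedding_size=6, unembed_at_every_layer=True),
      dict(embedding_size=6, unembed_at_every_layer=False))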
  def test_embedding_size_produces_correct_shape_of_residuals_and_layer_outputs(
      self, embedding_size, unembed_at_every_layer):

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False))
      return transformer(
          emb,
          mask,
          embedding_size=embedding_size,
          unembed_at_every_layer=unembed_at_every_layer,
      )

    seq_len = 4
    model_size = 16

    emb = np.random.random((1, seq_len, model_size))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    activations = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)

    for residual in activations.residuals:
      self.assertEqual(residual.shape, (1, seq_len, embedding_size))
    for layer_output in activations.layer_outputs:
      self.assertEqual(layer_output.shape, (1, seq_len, model_size))
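
  # Assumed parameterization: `model_size` and `unembed_at_every_layer` must
  # be supplied by the runner. The sizes 2 and 6 are illustrative choices.
  @parameterized.parameters(
      dict(model_size=2, unembed_at_every_layer=True),
      dict(model_size=2, unembed_at_every_layer=False),
      dict(model_size=6, unembed_at_every_layer=True),
      dict(model_size=6, unembed_at_every_layer=False))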
  def test_identity_embedding_produces_same_output_as_standard_model(
      self, model_size, unembed_at_every_layer):
    config = model.TransformerConfig(
        num_heads=2,
        num_layers=2,
        key_size=5,
        mlp_hidden_size=64,
        dropout_rate=0.,
        causal=False,
        layer_norm=False)

    @hk.without_apply_rng
    @hk.transform
    def forward_model(emb, mask):
      return model.Transformer(config)(emb, mask).output

    @hk.without_apply_rng
    @hk.transform
    def forward_superposition(emb, mask):
      return compressed_model.CompressedTransformer(config)(
          emb,
          mask,
          embedding_size=model_size,
          unembed_at_every_layer=unembed_at_every_layer).output

    seq_len = 4
    emb = np.random.random((1, seq_len, model_size))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward_model.init(next(rng), emb, mask)
    params_superposition = {
        k.replace("transformer", "compressed_transformer"): v
        for k, v in params.items()
    }
    # Use an identity embedding matrix so the compression step is a no-op.
    params_superposition["compressed_transformer"] = {
        "w_emb": jnp.identity(model_size)
    }

    out_model = forward_model.apply(params, emb, mask)
    out_superposition = forward_superposition.apply(params_superposition, emb,
                                                    mask)

    self._check_layer_naming(params_superposition)
    np.testing.assert_allclose(out_model, out_superposition)


if __name__ == "__main__":
  absltest.main()