| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""Tests for mixer.py."""
|
|
| import functools |
|
|
| from absl.testing import absltest |
| from absl.testing import parameterized |
| import flax |
| from jax import random |
| from jax.flatten_util import ravel_pytree |
| import jax.numpy as jnp |
| import jax.tree_util |
| import numpy as np |
| from scenic.projects.baselines import mixer |
|
|
# Number of classes given to the classification model as 'num_classes'; the
# model output's trailing dimension is checked against this value.
NUM_OUTPUTS = 5
# Shape of the input batch fed to the model. The leading axis is the batch
# size that the output shape is checked against; the remaining axes are
# presumably (height, width, channels) -- TODO confirm against mixer.py.
INPUT_SHAPE = (10, 32, 32, 3)
|
|
|
|
class MixerTest(parameterized.TestCase):
  """Tests for modules in mixer.py."""

  @parameterized.named_parameters(
      ('without_dropout_without_stochastic_depth', 0.0, 0.0),
      ('with_dropout_without_stochastic_depth', 0.1, 0.0),
      ('without_dropout_with_stochastic_depth', 0.0, 0.1),
      ('with_dropout_with_stochastic_depth', 0.1, 0.1),
      ('with_dropout_stochastic_depth_layer_scale', 0.1, 0.1, 0.1),
  )
  def test_mixer_block(self, dropout_rate, stochastic_depth, layer_scale=None):
    """Tests that MixerBlock returns an output shaped like its input.

    Args:
      dropout_rate: Value forwarded to MixerBlock's `dropout_rate`.
      stochastic_depth: Value forwarded to MixerBlock's `stochastic_depth`.
      layer_scale: Value forwarded to MixerBlock's `layer_scale`; stays None
        for the parameter sets that do not supply it.
    """
    rng = random.PRNGKey(0)
    x = jnp.ones((4, 16, 32))
    mixer_block = functools.partial(
        mixer.MixerBlock,
        channels_mlp_dim=32,
        sequence_mlp_dim=32,
        dropout_rate=dropout_rate,
        stochastic_depth=stochastic_depth,
        layer_scale=layer_scale)
    # Init and apply use separately constructed modules with the same
    # hyper-parameters; both calls are made with deterministic=True.
    mixer_block_vars = mixer_block().init(rng, x, deterministic=True)
    y = mixer_block().apply(mixer_block_vars, x, deterministic=True)

    self.assertEqual(y.shape, x.shape)

  def test_mixer_models(self):
    """Tests the forward pass of the mixer classification model.

    Runs the model once with train=True (mutable 'batch_stats' and a dropout
    rng) and once with train=False, checking the output shape both times.
    """
    rng = random.PRNGKey(0)
    # NOTE(review): config=None presumably makes the model fall back to its
    # own default configuration -- confirm in mixer.py.
    model = mixer.MixerMultiLabelClassificationModel(
        config=None,
        dataset_meta_data={
            'num_classes': NUM_OUTPUTS,
            'target_is_onehot': False,
        })

    # A seeded generator keeps the random input identical across runs, so a
    # failure can be reproduced.
    xs = jnp.array(
        np.random.default_rng(0).normal(
            loc=0.0, scale=10.0, size=INPUT_SHAPE)).astype(jnp.float32)

    rng, init_rng = random.split(rng)
    dummy_input = jnp.zeros(INPUT_SHAPE, jnp.float32)
    # Separate the 'params' collection from any other variable collections
    # created during initialization.
    init_model_state, init_params = flax.core.pop(
        model.flax_model.init(
            init_rng, dummy_input, train=False, debug=False),
        'params')

    # Forward pass with train=True.
    rng, dropout_rng = random.split(rng)
    variables = {'params': init_params, **init_model_state}
    outputs, new_model_state = model.flax_model.apply(
        variables,
        xs,
        mutable=['batch_stats'],
        train=True,
        rngs={'dropout': dropout_rng},
        debug=False)
    self.assertEqual(outputs.shape, (INPUT_SHAPE[0], NUM_OUTPUTS))

    # If initialization produced non-'params' state, the train=True pass is
    # expected to have changed it.
    if init_model_state:
      bflat, _ = ravel_pytree(init_model_state)
      new_bflat, _ = ravel_pytree(new_model_state)
      self.assertFalse(jnp.array_equal(bflat, new_bflat))

    # Forward pass with train=False and no mutable collections.
    outputs = model.flax_model.apply(
        variables, xs, mutable=False, train=False, debug=False)
    self.assertEqual(outputs.shape, (INPUT_SHAPE[0], NUM_OUTPUTS))
|
|
|
|
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
|
|