| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Tests for models.py.""" |
|
|
| from absl.testing import absltest |
| from absl.testing import parameterized |
| import flax |
| from jax.flatten_util import ravel_pytree |
| import jax.numpy as jnp |
| import jax.tree_util |
| import numpy as np |
| from scenic.model_lib import models |
|
|
# Number of output classes used for every model under test.
NUM_OUTPUTS = 5
# Dummy input batch shape: (batch, height, width, channels).
INPUT_SHAPE = (10, 32, 32, 3)
|
|
|
|
| |
# (test_name, model_name) pairs for parameterizing the classification tests.
CLASSIFICATION_KEYS = [(f'test_{name}', name)
                       for name in models.CLASSIFICATION_MODELS]
|
|
| |
# (test_name, model_name) pairs for parameterizing the segmentation tests.
SEGMENTATION_KEYS = [(f'test_{name}', name)
                     for name in models.SEGMENTATION_MODELS]
|
|
|
|
class ModelsTest(parameterized.TestCase):
  """Tests for all models.

  Both test methods run the same forward-pass checks; they differ only in
  how the input dtype is read from the model config and in the expected
  output shape, so the shared logic lives in `_run_forward_pass_test`.
  """

  def _run_forward_pass_test(self, model_name, get_input_dtype,
                             expected_output_shape):
    """Initializes a model and checks train/eval forward passes.

    Args:
      model_name: Registered name of the model to instantiate via
        `models.get_model_cls`.
      get_input_dtype: Callable taking the constructed model and returning
        the jnp dtype to use for the dummy input batch.
      expected_output_shape: Shape tuple the model outputs must match in
        both train and eval mode.
    """
    model_cls = models.get_model_cls(model_name)
    rng = jax.random.PRNGKey(0)
    model = model_cls(
        config=None,
        dataset_meta_data={
            'num_classes': NUM_OUTPUTS,
            'target_is_onehot': False,
        })

    model_input_dtype = get_input_dtype(model)
    xs = jnp.array(np.random.normal(loc=0.0, scale=10.0,
                                    size=INPUT_SHAPE)).astype(model_input_dtype)

    # Initialize parameters and (possibly empty) model state, e.g.
    # batch-norm statistics, with a dummy input.
    rng, init_rng = jax.random.split(rng)
    dummy_input = jnp.zeros(INPUT_SHAPE, model_input_dtype)
    init_model_state, init_params = flax.core.pop(model.flax_model.init(
        init_rng, dummy_input, train=False, debug=False), 'params')

    # Train-mode forward pass: batch stats are mutable and dropout gets
    # its own RNG stream.
    rng, dropout_rng = jax.random.split(rng)
    variables = {'params': init_params, **init_model_state}
    outputs, new_model_state = model.flax_model.apply(
        variables,
        xs,
        mutable=['batch_stats'],
        train=True,
        rngs={'dropout': dropout_rng},
        debug=False)
    self.assertEqual(outputs.shape, expected_output_shape)

    # If the model carries state (e.g. batch stats), the train-mode pass
    # must have updated it.
    if init_model_state:
      bflat, _ = ravel_pytree(init_model_state)
      new_bflat, _ = ravel_pytree(new_model_state)
      self.assertFalse(jnp.array_equal(bflat, new_bflat))

    # Eval-mode forward pass: nothing mutable, no dropout RNG needed.
    outputs = model.flax_model.apply(
        variables, xs, mutable=False, train=False, debug=False)
    self.assertEqual(outputs.shape, expected_output_shape)

  @parameterized.named_parameters(*CLASSIFICATION_KEYS)
  def test_classification_models(self, model_name):
    """Test forward pass of the classification models."""
    # Classification models store the input dtype as a string under
    # 'data_dtype_str' (default 'float32').
    self._run_forward_pass_test(
        model_name,
        get_input_dtype=lambda model: getattr(
            jnp, model.config.get('data_dtype_str', 'float32')),
        expected_output_shape=(INPUT_SHAPE[0], NUM_OUTPUTS))

  @parameterized.named_parameters(*SEGMENTATION_KEYS)
  def test_segmentation_models(self, model_name):
    """Test forward pass of the segmentation models."""
    # Segmentation models store the dtype object itself under
    # 'default_input_dtype', and predict per-pixel class logits.
    self._run_forward_pass_test(
        model_name,
        get_input_dtype=lambda model: model.config.get(
            'default_input_dtype', jnp.float32),
        expected_output_shape=INPUT_SHAPE[:3] + (NUM_OUTPUTS,))
|
|
|
|
# Standard absltest entry point: run the tests when executed as a script.
if __name__ == '__main__':
  absltest.main()
|
|