# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for models.py."""
from absl.testing import absltest
from absl.testing import parameterized
import flax
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import jax.tree_util
import numpy as np
from scenic.model_lib import models
NUM_OUTPUTS = 5
INPUT_SHAPE = (10, 32, 32, 3)
# Automatically test all defined classification models.
CLASSIFICATION_KEYS = [
('test_{}'.format(m), m) for m in models.CLASSIFICATION_MODELS.keys()
]
# Automatically test all defined segmentation models.
SEGMENTATION_KEYS = [
('test_{}'.format(m), m) for m in models.SEGMENTATION_MODELS.keys()
]
class ModelsTest(parameterized.TestCase):
"""Tests for all models."""
@parameterized.named_parameters(*CLASSIFICATION_KEYS)
def test_classification_models(self, model_name):
"""Test forward pass of the classification models."""
model_cls = models.get_model_cls(model_name)
rng = jax.random.PRNGKey(0)
model = model_cls(
config=None,
dataset_meta_data={
'num_classes': NUM_OUTPUTS,
'target_is_onehot': False,
})
model_input_dtype = getattr(
jnp, model.config.get('data_dtype_str', 'float32'))
xs = jnp.array(np.random.normal(loc=0.0, scale=10.0,
size=INPUT_SHAPE)).astype(model_input_dtype)
rng, init_rng = jax.random.split(rng)
dummy_input = jnp.zeros(INPUT_SHAPE, model_input_dtype)
init_model_state, init_params = flax.core.pop(model.flax_model.init(
init_rng, dummy_input, train=False, debug=False), 'params')
# Check that the forward pass works with mutated model_state.
rng, dropout_rng = jax.random.split(rng)
variables = {'params': init_params, **init_model_state}
outputs, new_model_state = model.flax_model.apply(
variables,
xs,
mutable=['batch_stats'],
train=True,
rngs={'dropout': dropout_rng},
debug=False)
self.assertEqual(outputs.shape, (INPUT_SHAPE[0], NUM_OUTPUTS))
# If it's a batch norm model check the batch stats changed.
if init_model_state:
bflat, _ = ravel_pytree(init_model_state)
new_bflat, _ = ravel_pytree(new_model_state)
self.assertFalse(jnp.array_equal(bflat, new_bflat))
# Test batch_norm in inference mode.
outputs = model.flax_model.apply(
variables, xs, mutable=False, train=False, debug=False)
self.assertEqual(outputs.shape, (INPUT_SHAPE[0], NUM_OUTPUTS))
@parameterized.named_parameters(*SEGMENTATION_KEYS)
def test_segmentation_models(self, model_name):
"""Test forward pass of the segmentation models."""
model_cls = models.get_model_cls(model_name)
rng = jax.random.PRNGKey(0)
model = model_cls(
config=None,
dataset_meta_data={
'num_classes': NUM_OUTPUTS,
'target_is_onehot': False,
})
model_input_dtype = model.config.get('default_input_dtype', jnp.float32)
xs = jnp.array(np.random.normal(loc=0.0, scale=10.0,
size=INPUT_SHAPE)).astype(model_input_dtype)
rng, init_rng = jax.random.split(rng)
dummy_input = jnp.zeros(INPUT_SHAPE, model_input_dtype)
init_model_state, init_params = flax.core.pop(model.flax_model.init(
init_rng, dummy_input, train=False, debug=False), 'params')
# Check that the forward pass works with mutated model_state.
rng, dropout_rng = jax.random.split(rng)
variables = {'params': init_params, **init_model_state}
outputs, new_model_state = model.flax_model.apply(
variables,
xs,
mutable=['batch_stats'],
train=True,
rngs={'dropout': dropout_rng},
debug=False)
self.assertEqual(outputs.shape, INPUT_SHAPE[:3] + (NUM_OUTPUTS,))
# If it's a batch norm model check the batch stats changed.
if init_model_state:
bflat, _ = ravel_pytree(init_model_state)
new_bflat, _ = ravel_pytree(new_model_state)
self.assertFalse(jnp.array_equal(bflat, new_bflat))
# Test batch_norm in inference mode.
outputs = model.flax_model.apply(
variables, xs, mutable=False, train=False, debug=False)
self.assertEqual(outputs.shape, INPUT_SHAPE[:3] + (NUM_OUTPUTS,))
if __name__ == '__main__':
absltest.main()
|