owlv2 / scenic /model_lib /base_models /tests /test_regression_model.py
fcxfcx's picture
Upload 2446 files
1327f34 verified
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for regression_model.py."""
from absl.testing import absltest
from flax import jax_utils
import jax
import jax.numpy as jnp
import ml_collections
from scenic.model_lib.base_models import regression_model
class FakeRegressionModel(regression_model.RegressionModel):
  """Minimal concrete `RegressionModel` used only by the tests in this file.

  The model is constructed with an empty `ConfigDict` and empty dataset
  metadata. The two Flax-related hooks are stubs that return None; the tests
  in this file only call `loss_function` and `get_metrics_fn`.
  """

  def __init__(self):
    # Empty config and empty dataset metadata are passed straight through.
    super().__init__(ml_collections.ConfigDict(), {})

  def build_flax_model(self):
    # Stub: no Flax module is needed by these tests.
    return None

  def default_flax_model_config(self):
    # Stub: no default model config is needed by these tests.
    return None
def get_fake_batch_and_predictions():
  """Builds a small fake batch together with matching fake predictions.

  Returns:
    A `(batch, predictions)` tuple. `batch` is a dict whose 'inputs' entry is
    None and whose 'targets' entry is a (3, 4) jnp array. `predictions` is a
    (3, 4) jnp array. Row by row, predictions minus targets is
    [0, -1, 0, 0], [0, 0, 0, 0] and [-1, 3, 0, 0], so the summed squared
    error over all rows is 11.
  """
  target_rows = [
      [2.0, 1.0, 0.0, 1.0],
      [2.0, 1.0, 0.0, 1.0],
      [5.0, 7.0, 0.0, 1.0],
  ]
  prediction_rows = [
      [2.0, 0.0, 0.0, 1.0],
      [2.0, 1.0, 0.0, 1.0],
      [4.0, 10.0, 0.0, 1.0],
  ]
  fake_batch = dict(inputs=None, targets=jnp.array(target_rows))
  return fake_batch, jnp.array(prediction_rows)
class TestRegressionModel(absltest.TestCase):
  """Tests for a fake regression model."""

  def _replicated_inputs(self):
    """Returns a fake model and its fake inputs, replicated for `jax.pmap`.

    Returns:
      A `(model, batch_replicated, predictions_replicated)` tuple, where the
      batch and predictions come from `get_fake_batch_and_predictions` and are
      passed through `jax_utils.replicate` so they can be fed to a pmapped
      function.
    """
    model = FakeRegressionModel()
    batch, predictions = get_fake_batch_and_predictions()
    return (model, jax_utils.replicate(batch),
            jax_utils.replicate(predictions))

  def test_loss_function(self):
    """Tests loss_function by checking its output's validity."""
    model, batch_replicated, predictions_replicated = self._replicated_inputs()
    # Test loss function in the pmapped setup:
    loss_function_pmapped = jax.pmap(model.loss_function, axis_name='batch')
    total_loss = loss_function_pmapped(predictions_replicated, batch_replicated)
    total_loss = jax_utils.unreplicate(total_loss)
    # Loss = 1/3 * (|[0, 1, 0, 0]|^2 + |[0, 0, 0, 0]|^2 + |[1, 3, 0, 0]|^2)
    # NOTE(review): the raw jax scalar is compared, not float(total_loss).
    # The loss is presumably float32; as a Python float it differs from the
    # float64 11 / 3 by ~8e-8, which rounds to 1e-7 at assertAlmostEqual's
    # default 7 places and would fail.
    self.assertAlmostEqual(total_loss, 11 / 3)

  def test_metric_function(self):
    """Tests metric_function by checking its output's format and validity."""
    model, batch_replicated, predictions_replicated = self._replicated_inputs()
    metrics_fn_pmapped = jax.pmap(model.get_metrics_fn(), axis_name='batch')
    all_metrics = metrics_fn_pmapped(predictions_replicated, batch_replicated)
    expected_metrics_keys = ['mean_squared_error']
    self.assertSameElements(expected_metrics_keys, all_metrics.keys())
    all_metrics = jax_utils.unreplicate(all_metrics)
    self.assertLen(all_metrics, 1)
    # The metric value is indexed as a (sum, count) pair.
    mse_sum_count = all_metrics['mean_squared_error']
    # (|[0, 1, 0, 0]|^2 + |[0, 0, 0, 0]|^2 + |[1, 3, 0, 0]|^2) = 11
    self.assertAlmostEqual(mse_sum_count[0], 11.0)
    self.assertEqual(mse_sum_count[1], 3)
if __name__ == '__main__':
  # Run the absltest runner when this file is executed directly.
  absltest.main()