# Source: owlv2/scenic/model_lib/base_models/tests/test_model_utils.py
# Uploaded by fcxfcx in commit 1327f34 ("Upload 2446 files", verified).
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for functions in model_utils.py."""
import itertools
from absl.testing import absltest
from absl.testing import parameterized
from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
from scenic.model_lib.base_models import model_utils
class SimpleGatherTest(parameterized.TestCase):
  """Tests for model_utils.simple_gather()."""

  def test_simple_gather_ndarray(self):
    """Tests simple_gather against a hand-built target for an nd-array idx.

    The expected output is assembled explicitly: for leading index `b`, its
    j-th row is `x[b, idx[b, j]]`.
    """
    # A locally seeded generator keeps the inputs identical across runs, so a
    # failure can be reproduced (the global np.random state is not seeded).
    rng = np.random.default_rng(0)
    x = jnp.array(rng.normal(size=(2, 3, 5)), dtype=jnp.float32)
    idx = jnp.array([[1, 0, 2], [2, 1, 0]], dtype=jnp.int32)

    y = model_utils.simple_gather(x, idx)

    # Rows are listed in the order given by `idx` above.
    y_target = jnp.stack([
        jnp.stack([x[0, 1], x[0, 0], x[0, 2]]),
        jnp.stack([x[1, 2], x[1, 1], x[1, 0]]),
    ])
    self.assertSequenceAlmostEqual(y.flatten(), y_target.flatten())
class LossTest(parameterized.TestCase):
  """Tests the loss functions in model_utils against hand-computed targets."""

  def test_weighted_l1_loss(self):
    """Tests weighted_l1_loss with no reduction and with 'mean' reduction."""
    x = jnp.array([[0.1, 0.3], [-1.0, 0.2]], dtype=jnp.float32)
    y = jnp.array([[0.5, -1.3], [0.9, 1.2]], dtype=jnp.float32)

    # Without reduction the target is the element-wise |x - y|.
    out1 = model_utils.weighted_l1_loss(x, y, reduction=None)
    out1_target = jnp.array([[0.4, 1.6], [1.9, 1.0]], dtype=jnp.float32)
    self.assertSequenceAlmostEqual(
        out1.flatten(), out1_target.flatten(), places=5)

    # With 'mean' the target is the average of the four element-wise errors:
    # (0.4 + 1.6 + 1.9 + 1.0) / 4.
    out2 = model_utils.weighted_l1_loss(x, y, reduction='mean').item()
    out2_target = 4.9 / 4
    self.assertAlmostEqual(out2, out2_target, places=5)

  def test_weighted_box_l1_loss(self):
    """Tests weighted_box_l1_loss against manually specified targets.

    Covers the default call, 'mean' reduction and `tight=False`.
    """
    x1 = jnp.array([[0.1, 0.3, 0.9, 0.8]], dtype=jnp.float32)
    y1 = jnp.array([[0.5, 0.1, 0.9, 0.7]], dtype=jnp.float32)

    # Default arguments: the target equals the element-wise |x1 - y1|.
    out1 = model_utils.weighted_box_l1_loss(x1, y1)
    out1_target = jnp.array([[0.4, 0.2, 0, 0.1]], dtype=jnp.float32)
    self.assertSequenceAlmostEqual(
        out1.flatten(), out1_target.flatten(), places=5)

    # 'mean' reduction: the target is the mean of the unreduced target.
    out2 = model_utils.weighted_box_l1_loss(x1, y1, reduction='mean').item()
    out2_target = jnp.mean(out1_target).item()
    self.assertAlmostEqual(out2, out2_target, places=5)

    # tight=False: the second coordinate is expected to contribute 0 instead
    # of 0.2; the remaining coordinates are unchanged.
    out3 = model_utils.weighted_box_l1_loss(x1, y1, tight=False)
    out3_target = jnp.array([[0.4, 0.0, 0.0, 0.1]], dtype=jnp.float32)
    self.assertSequenceAlmostEqual(
        out3.flatten(), out3_target.flatten(), places=5)

  def test_weighted_sigmoid_cross_entropy(self):
    """Tests weighted_sigmoid_cross_entropy.

    Exercises the unweighted call, per-example `weights`, per-label
    `label_weights`, both combined, `label_weights` given with the full shape
    of the logits, and a `label_weights` of incompatible shape, which is
    expected to raise ValueError.
    """
    logits = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
    labels = jnp.array([[0, 1, 1], [1, 0, 1]], dtype=jnp.float32)
    sigmoid = jax.nn.sigmoid
    log = jnp.log

    # No weights: the summed per-element terms are divided by the number of
    # examples (product of all but the last dimension of `labels`).
    loss = model_utils.weighted_sigmoid_cross_entropy(logits, labels)
    gt_loss = jnp.array([
        [-log(1 - sigmoid(1.)), -log(sigmoid(2.)), -log(sigmoid(3.))],
        [-log(sigmoid(4.)), -log(1 - sigmoid(5.)), -log(sigmoid(6.))],
    ]) / np.prod(labels.shape[:-1])
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # Per-example weights: the second example has weight 0, so its row is
    # zero. The epsilon is parenthesised with the weight sum so that it is
    # added to the normaliser rather than to every element of the target.
    example_weights = jnp.array([1., 0.])
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, weights=example_weights)
    gt_loss = jnp.array([
        [-log(1 - sigmoid(1.)), -log(sigmoid(2.)), -log(sigmoid(3.))],
        [0., 0., 0.],
    ]) / (example_weights.sum() + 1e-9)
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # Per-label weights: each column is scaled by its label weight.
    label_weights = jnp.array([1., 2., 3.])
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, label_weights=label_weights)
    gt_loss = jnp.array([
        [-log(1 - sigmoid(1.)), -2 * log(sigmoid(2.)), -3 * log(sigmoid(3.))],
        [-log(sigmoid(4.)), -2 * log(1 - sigmoid(5.)), -3 * log(sigmoid(6.))],
    ]) / np.prod(labels.shape[:-1])
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # Per-example weights and per-label weights combined.
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, weights=example_weights, label_weights=label_weights)
    gt_loss = jnp.array([
        [-log(1 - sigmoid(1.)), -2 * log(sigmoid(2.)), -3 * log(sigmoid(3.))],
        [0., 0., 0.],
    ]) / (example_weights.sum() + 1e-9)
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # Label weights can actually be any shape that is broadcastable to the
    # shape of logits. The second row of label weights has no effect on the
    # target because the second example has weight 0.
    label_weights = jnp.array([[1., 2., 3.], [4., 5., 6.]])
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, weights=example_weights, label_weights=label_weights)
    gt_loss = jnp.array([
        [-log(1 - sigmoid(1.)), -2 * log(sigmoid(2.)), -3 * log(sigmoid(3.))],
        [0., 0., 0.],
    ]) / (example_weights.sum() + 1e-9)
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # Four label weights do not match the three labels per example. Only the
    # call under test sits inside the context manager, so a ValueError raised
    # by the setup cannot satisfy the assertion by accident.
    label_weights = jnp.array([1., 2., 3., 4.])
    with self.assertRaises(ValueError):
      model_utils.weighted_sigmoid_cross_entropy(
          logits, labels, label_weights=label_weights)

  def test_focal_sigmoid_cross_entropy(self):
    """Tests focal_sigmoid_cross_entropy against a hand-computed target.

    The target is the per-element sigmoid cross-entropy multiplied by a focal
    factor: `a * sigmoid(-z)**g` where the label is 1 and
    `(1 - a) * sigmoid(z)**g` where the label is 0.
    """
    logits = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
    labels = jnp.array([[0, 1, 1], [1, 0, 1]], dtype=jnp.float32)
    sigmoid = jax.nn.sigmoid
    log = jnp.log
    a = 0.25
    g = 2.

    loss = model_utils.focal_sigmoid_cross_entropy(
        logits, labels, alpha=a, gamma=g)

    gt_loss = jnp.array([
        [-log(1 - sigmoid(1.)), -log(sigmoid(2.)), -log(sigmoid(3.))],
        [-log(sigmoid(4.)), -log(1 - sigmoid(5.)), -log(sigmoid(6.))],
    ])
    focal_factor = jnp.array([
        [(1 - a) * sigmoid(1.)**g, a * sigmoid(-2.)**g, a * sigmoid(-3.)**g],
        [a * sigmoid(-4.)**g, (1 - a) * sigmoid(5.)**g, a * sigmoid(-6.)**g],
    ])
    self.assertSequenceAlmostEqual(
        loss.flatten(), (gt_loss * focal_factor).flatten(), places=3)

  def test_dice_loss(self):
    """Tests the correctness of the segmentation dice loss.

    Targets are 128x128 masks; predictions are given at half that resolution
    (stride 2) and the loss is called with `interpolation='nearest'`. Binary
    prediction masks are turned into large-magnitude logits so that the
    sigmoid applied to them saturates to 0 or 1.
    """
    # Create test targets:
    batch, num_objects, h, w = 1, 2, 128, 128
    stride = 2
    targets = np.zeros((batch, num_objects, h, w), dtype=np.float32)
    targets[0, 0, :64, :64] = 1.0  # Add object in top left of image.
    targets[0, 1, 64:, 64:] = 1.0  # Add object in bottom right of image.

    input_shape = batch, num_objects, h // stride, w // stride

    # Test perfect predictions:
    inputs = np.zeros(input_shape, dtype=np.float32)
    inputs[0, 0, :64 // stride, :64 // stride] = 1.0
    inputs[0, 1, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6  # Inputs will be passed through sigmoid.
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest')
    np.testing.assert_array_almost_equal(loss, [[0.0, 0.0]], decimal=3)

    # Test one half-overlapping prediction (object 0 shifted down by 32 rows):
    inputs = np.zeros(input_shape, dtype=np.float32)
    inputs[0, 0, 32 // stride:(32 + 64) // stride, :64 // stride] = 1.0
    inputs[0, 1, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6  # Inputs will be passed through sigmoid.
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest')
    np.testing.assert_array_almost_equal(loss, [[0.5, 0.0]], decimal=3)

    # Test one non-overlapping prediction (object 0 predicted bottom right):
    inputs = np.zeros(input_shape, dtype=np.float32)
    inputs[0, 0, 64 // stride:, 64 // stride:] = 1.0
    inputs[0, 1, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6  # Inputs will be passed through sigmoid.
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest')
    np.testing.assert_array_almost_equal(loss, [[1.0, 0.0]], decimal=3)

    # Test all-pairs with different instance numbers (3 predictions against
    # the 2 targets):
    inputs = np.zeros((batch, 3, h // stride, w // stride), dtype=np.float32)
    inputs[0, 0, :64 // stride, :64 // stride] = 1.0
    inputs[0, 1, 32 // stride:(32 + 64) // stride, :64 // stride] = 1.0
    inputs[0, 2, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6  # Inputs will be passed through sigmoid.
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest',
        all_pairs=True)
    self.assertTupleEqual(loss.shape, (1, 3, 2))  # [b, n_pred, n_true]
    np.testing.assert_array_almost_equal(loss, [[[0.0, 1.0],
                                                 [0.5, 1.0],
                                                 [1.0, 0.0]]], decimal=3)

  def test_weighted_square_error(self):
    """Tests weighted_mean_squared_error for several `axis` and weight setups.

    `targets` holds the integers 1..32 reshaped to (2, 4, 4). The expected
    values are sums of squared differences over the reduced axes, averaged
    over whatever remains.
    """
    predictions = jnp.array([
        [
            [1.0, 3.0, 5.0, 6.0],
            [3.0, 5.0, 11.0, 10.0],
            [9.0, 10.0, 11.0, 12.0],
            [14.0, 13.0, 14.0, 17.0],
        ],
        [
            [17.0, 18.0, 21.0, 22.0],
            [20.0, 19.0, 24.0, 25.0],
            [27.0, 29.0, 30.0, 32.0],
            [27.0, 28.0, 33.0, 32.0],
        ],
    ])
    targets = jnp.arange(1, 33).reshape(2, 4, 4)

    # Without specifying axis, this will be over the last two dimensions.
    loss = model_utils.weighted_mean_squared_error(predictions, targets)
    expected_loss = jnp.mean(jnp.array([38.0, 70.0]))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    # Test by specifying axes as a tuple. The following are all equivalent to
    # the previous test.
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=(1, 2))
    self.assertAlmostEqual(loss, expected_loss, places=5)
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=(-1, -2))
    self.assertAlmostEqual(loss, expected_loss, places=5)
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=(2, 1))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    # Test by computing loss over a single axis only.
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=-1)
    expected_loss = jnp.mean(jnp.array([[9, 25, 0, 4],
                                        [8, 12, 38, 12]]))
    self.assertAlmostEqual(loss, expected_loss, places=5)
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=2)
    self.assertAlmostEqual(loss, expected_loss, places=5)

    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=1)
    expected_loss = jnp.mean(jnp.array([[5, 3, 21, 9],
                                        [9, 22, 18, 21]]))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    # Test with loss weights. Only the five per-row errors whose weight is 1
    # enter the expected mean.
    weights = jnp.array([[1, 1, 1, 0], [0, 1, 1, 0]])
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   weights, axis=-1)
    expected_loss = jnp.mean(jnp.array([9, 25, 12, 38, 0]))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    # Per-example weights: only the first example's rows enter the mean.
    weights = jnp.array([1, 0])
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   weights, axis=-1)
    expected_loss = jnp.mean(jnp.array([9, 25, 0, 4]))
    self.assertAlmostEqual(loss, expected_loss, places=5)
class MetricTest(parameterized.TestCase):
  """Tests the metric computation related utilities."""

  def is_valid(self, t, value_name):
    """Asserts that tensor `t` contains neither `nan` nor `inf` entries.

    Args:
      t: Array to check.
      value_name: Name of the checked value, used in the failure message.
    """
    self.assertFalse(
        jnp.isnan(t).any(), msg=f'Found nan\'s in {t} for {value_name}')
    self.assertFalse(
        jnp.isinf(t).any(), msg=f'Found inf\'s in {t} for {value_name}')

  def test_weighted_topk_correctly_classified(self):
    """Tests the topk accuracy computation.

    Checks that k=1 agrees with weighted_correctly_classified, that
    k=num_of_classes marks every example as correct, and that k=5 agrees with
    a reference computed from an argsort of the logits.
    """
    batch_size = 512
    num_of_classes = 100
    # A locally seeded generator keeps the inputs identical across runs, so a
    # failure can be reproduced (the global np.random state is not seeded).
    rng = np.random.default_rng(0)
    logits = jnp.array(
        rng.normal(size=(batch_size, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(rng.integers(num_of_classes, size=(batch_size,)))
    one_hot_targets = common_utils.onehot(labels, logits.shape[-1])

    classification_accuracy = model_utils.weighted_correctly_classified(
        logits, one_hot_targets)
    top_one_accuracy = model_utils.weighted_topk_correctly_classified(
        logits, one_hot_targets, k=1)
    self.assertSequenceAlmostEqual(
        classification_accuracy.flatten(), top_one_accuracy.flatten())

    # With k equal to the number of classes every label is in the top k.
    top_n_accuracy = model_utils.weighted_topk_correctly_classified(
        logits, one_hot_targets, k=num_of_classes)
    self.assertEqual(jnp.mean(top_n_accuracy), 1)

    # computes using numpy
    top_5_accuracy = model_utils.weighted_topk_correctly_classified(
        logits, one_hot_targets, k=5)
    # argsort is ascending, so the last five columns are the top-5 classes.
    top5_pred = np.argsort(
        np.reshape(logits, [-1, num_of_classes]), axis=1)[:, -5:]
    y_true = np.array(labels)
    top5_pred = np.reshape(top5_pred, [-1, 5])
    y_true = np.reshape(y_true, [-1])
    np_top_accuracy = np.array(
        [label in preds for label, preds in zip(y_true, top5_pred)])
    self.assertSequenceAlmostEqual(top_5_accuracy.flatten(),
                                   np_top_accuracy.flatten())

  def test_weighted_recall(self):
    """Tests weighted_recall against hand-specified per-example targets.

    NOTE(review): the targets are consistent with recall measured over the
    top-k logits of each example, with k being that example's number of
    positive labels -- confirm against the implementation in model_utils.
    """
    logits = np.array([[[2, 3, 4],
                        [4, 3, 2],
                        [4, 2, 3],
                        [3, 2, 4],
                        [4, 2, 3],
                        ]])
    labels = np.array([[[1, 1, 0],
                        [1, 1, 0],
                        [1, 0, 0],
                        [1, 0, 0],
                        [0, 0, 0]
                        ]])
    # Repeat the single hand-written case along the leading (batch) axis.
    batch_size = 8
    logits = jnp.tile(logits, [batch_size, 1, 1])
    labels = jnp.tile(labels, [batch_size, 1, 1])

    recall = model_utils.weighted_recall(logits, labels)
    recall_expected = np.array([[1/2, 1., 1., 0., 0.]] * batch_size)
    self.assertSequenceAlmostEqual(
        recall.flatten(), recall_expected.flatten())

  @parameterized.parameters(itertools.product([1., 0.], [1., 0.]))
  def test_weighted_top_one_correctly_classified(self, label_multiplier,
                                                 weight_multiplier):
    """Tests the top1 correct computation.

    Compares weighted_top_one_correctly_classified with
    weighted_topk_correctly_classified at k=1 and checks that the number of
    correct predictions is finite.

    Args:
      label_multiplier: Factor applied to the random labels; 0. yields
        all-zero labels.
      weight_multiplier: Factor applied to the all-ones example weights; 0.
        yields all-zero weights.
    """
    batch_size = 512
    num_of_classes = 100
    # Seeded locally so that every run sees the same inputs.
    rng = np.random.default_rng(0)
    logits = jnp.array(
        rng.normal(size=(batch_size, 50, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(
        rng.integers(0, 2, size=(batch_size, 50, num_of_classes)))
    labels *= label_multiplier

    weights = jnp.ones(shape=(batch_size,), dtype=jnp.float32)
    weights *= weight_multiplier

    is_correct_array = model_utils.weighted_top_one_correctly_classified(
        logits, labels, weights=weights)
    num_correct = jnp.sum(is_correct_array)
    is_correct_array_ref = model_utils.weighted_topk_correctly_classified(
        logits, labels, weights, k=1)
    np.testing.assert_array_almost_equal(
        is_correct_array, is_correct_array_ref)
    np.testing.assert_equal(np.sum(is_correct_array),
                            np.sum(is_correct_array_ref))

    self.is_valid(num_correct, 'Number of correctly classified')

  @parameterized.parameters(itertools.product([1., 0.], [1., 0.]))
  def test_weighted_unnormalized_sigmoid_cross_entropy(self, label_multiplier,
                                                       weight_multiplier):
    """Tests the unnormalized sigmoid cross entropy computation.

    Only checks that the summed loss contains no `nan` or `inf`.

    Args:
      label_multiplier: Factor applied to the random labels; 0. yields
        all-zero labels.
      weight_multiplier: Factor applied to the all-ones example weights; 0.
        yields all-zero weights.
    """
    batch_size = 512
    num_of_classes = 100
    # Seeded locally so that every run sees the same inputs.
    rng = np.random.default_rng(0)
    logits = jnp.array(
        rng.normal(size=(batch_size, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(rng.integers(0, 2, size=(batch_size, num_of_classes)))
    labels *= label_multiplier

    weights = jnp.ones(shape=(batch_size,), dtype=jnp.float32)
    weights *= weight_multiplier

    loss_array = model_utils.weighted_unnormalized_sigmoid_cross_entropy(
        logits, labels, weights=weights)
    loss_sum = jnp.sum(loss_array)

    self.is_valid(loss_sum, 'Loss value')

  @parameterized.parameters(itertools.product([1., 0.], [1., 0.]))
  def test_weighted_unnormalized_softmax_cross_entropy(self, label_multiplier,
                                                       weight_multiplier):
    """Tests the unnormalized softmax cross entropy computation.

    Only checks that the summed loss contains no `nan` or `inf`.

    Args:
      label_multiplier: Factor applied to the random labels; 0. yields
        all-zero labels.
      weight_multiplier: Factor applied to the all-ones example weights; 0.
        yields all-zero weights.
    """
    batch_size = 512
    num_of_classes = 100
    # Seeded locally so that every run sees the same inputs.
    rng = np.random.default_rng(0)
    logits = jnp.array(
        rng.normal(size=(batch_size, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(rng.integers(0, 2, size=(batch_size, num_of_classes)))
    labels *= label_multiplier

    weights = jnp.ones(shape=(batch_size,), dtype=jnp.float32)
    weights *= weight_multiplier

    loss_array = model_utils.weighted_unnormalized_softmax_cross_entropy(
        logits, labels, weights=weights)
    loss_sum = jnp.sum(loss_array)

    self.is_valid(loss_sum, 'Loss value')
# Run every test case defined in this module when it is executed as a script.
if __name__ == '__main__':
  absltest.main()