| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Unit tests for functions in model_utils.py.""" |
| | import itertools |
| |
|
| | from absl.testing import absltest |
| | from absl.testing import parameterized |
| | from flax.training import common_utils |
| | import jax |
| | import jax.numpy as jnp |
| | import numpy as np |
| | from scenic.model_lib.base_models import model_utils |
| |
|
| |
|
class SimpleGatherTest(parameterized.TestCase):
  """Test simple_gather()."""

  def test_simple_gather_ndarray(self):
    """Test against manually specified target when idx is a nd-array."""
    inputs = jnp.array(np.random.normal(size=(2, 3, 5)), dtype=jnp.float32)
    index_rows = [[1, 0, 2], [2, 1, 0]]
    indices = jnp.array(index_rows, dtype=jnp.int32)

    gathered = model_utils.simple_gather(inputs, indices)

    # Build the expected output row by row: for each batch row, pick the
    # sub-arrays named by that row of the index matrix.
    expected = jnp.stack([
        jnp.stack([inputs[row][col] for col in cols])
        for row, cols in enumerate(index_rows)
    ])

    self.assertSequenceAlmostEqual(gathered.flatten(), expected.flatten())
|
class LossTest(parameterized.TestCase):
  """Test various loss functions in model_utils."""

  def test_weighted_l1_loss(self):
    """Test weighted_l1_loss against a manually specified target."""
    x = jnp.array([[0.1, 0.3], [-1.0, 0.2]], dtype=jnp.float32)
    y = jnp.array([[0.5, -1.3], [0.9, 1.2]], dtype=jnp.float32)

    # reduction=None: elementwise |x - y|, no aggregation.
    out1 = model_utils.weighted_l1_loss(x, y, reduction=None)
    out1_target = jnp.array([[0.4, 1.6], [1.9, 1.0]], dtype=jnp.float32)
    self.assertSequenceAlmostEqual(
        out1.flatten(), out1_target.flatten(), places=5)

    # reduction='mean': sum of the four absolute differences (4.9) over 4.
    out2 = model_utils.weighted_l1_loss(x, y, reduction='mean').item()
    out2_target = 4.9 / 4
    self.assertAlmostEqual(out2, out2_target, places=5)

  def test_weighted_box_l1_loss(self):
    """Test weighted_box_l1_loss against manually specified targets."""
    # Single box per batch; coordinates compared elementwise below.
    x1 = jnp.array([[0.1, 0.3, 0.9, 0.8]], dtype=jnp.float32)
    y1 = jnp.array([[0.5, 0.1, 0.9, 0.7]], dtype=jnp.float32)

    # Default (tight) variant: plain per-coordinate absolute error.
    out1 = model_utils.weighted_box_l1_loss(x1, y1)
    out1_target = jnp.array([[0.4, 0.2, 0, 0.1]], dtype=jnp.float32)
    self.assertSequenceAlmostEqual(
        out1.flatten(), out1_target.flatten(), places=5)

    # 'mean' reduction must equal the mean of the unreduced target.
    out2 = model_utils.weighted_box_l1_loss(x1, y1, reduction='mean').item()
    out2_target = jnp.mean(out1_target).item()
    self.assertAlmostEqual(out2, out2_target, places=5)

    # tight=False zeroes the second coordinate's error (0.2 -> 0.0);
    # presumably this variant forgives errors where the predicted box
    # already covers the target box — confirm against
    # model_utils.weighted_box_l1_loss.
    out3 = model_utils.weighted_box_l1_loss(x1, y1, tight=False)
    out3_target = jnp.array([[0.4, 0.0, 0.0, 0.1]], dtype=jnp.float32)
    self.assertSequenceAlmostEqual(
        out3.flatten(), out3_target.flatten(), places=5)

  def test_weighted_sigmoid_cross_entropy(self):
    """Tests weighted_sigmoid_cross_entropy."""

    logits = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
    labels = jnp.array([[0, 1, 1], [1, 0, 1]], dtype=jnp.float32)
    sigmoid = jax.nn.sigmoid
    log = jnp.log

    # No weights: per-element sigmoid cross-entropy, normalized by the
    # number of examples (product of all but the class dimension).
    loss = model_utils.weighted_sigmoid_cross_entropy(logits, labels)
    gt_loss = jnp.array([[
        -log(1 - sigmoid(1.)), -log(sigmoid(2.)), -log(sigmoid(3.))
    ], [-log(sigmoid(4.)), -log(1 - sigmoid(5.)), -log(sigmoid(6.))]
                        ]) / np.prod(labels.shape[:-1])
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # Per-example weights: second example fully masked out; normalizer is
    # the weight sum.
    # NOTE(review): due to operator precedence, 1e-9 is added to the
    # quotient rather than to the denominator — presumably intended as a
    # divide-by-zero guard; negligible at places=3 either way.
    example_weights = jnp.array([1., 0.])
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, weights=example_weights)
    gt_loss = jnp.array([[
        -log(1 - sigmoid(1.)), -log(sigmoid(2.)), -log(sigmoid(3.))
    ], [0., 0., 0.]]) / example_weights.sum() + 1e-9
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # Per-class label weights scale each class's loss term (1x, 2x, 3x).
    label_weights = jnp.array([1., 2., 3.])
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, label_weights=label_weights)
    gt_loss = jnp.array([[
        -log(1 - sigmoid(1.)), -2 * log(sigmoid(2.)), -3 * log(sigmoid(3.))
    ], [-log(sigmoid(4.)), -2 * log(1 - sigmoid(5.)), -3 * log(sigmoid(6.))]
                        ]) / np.prod(labels.shape[:-1])
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # Example weights and label weights combined.
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, weights=example_weights, label_weights=label_weights)
    gt_loss = jnp.array([[
        -log(1 - sigmoid(1.)), -2 * log(sigmoid(2.)), -3 * log(sigmoid(3.))
    ], [0., 0., 0.]]) / example_weights.sum() + 1e-9
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # Full per-example, per-class label-weight matrix; the second row is
    # irrelevant because that example's weight is zero.
    label_weights = jnp.array([[1., 2., 3.], [4., 5., 6.]])
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, weights=example_weights, label_weights=label_weights)
    gt_loss = jnp.array([[
        -log(1 - sigmoid(1.)), -2 * log(sigmoid(2.)), -3 * log(sigmoid(3.))
    ], [0., 0., 0.]]) / example_weights.sum() + 1e-9
    self.assertSequenceAlmostEqual(
        loss.flatten(), gt_loss.sum().flatten(), places=3)

    # label_weights whose shape cannot match the class dimension must raise.
    with self.assertRaises(ValueError):
      label_weights = jnp.array([1., 2., 3., 4.])
      loss = model_utils.weighted_sigmoid_cross_entropy(
          logits, labels, label_weights=label_weights)

  def test_focal_sigmoid_cross_entropy(self):
    """Tests focal_sigmoid_cross_entropy."""
    logits = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
    labels = jnp.array([[0, 1, 1], [1, 0, 1]], dtype=jnp.float32)
    sigmoid = jax.nn.sigmoid
    log = jnp.log

    a = 0.25  # Alpha: positive/negative balancing factor.
    g = 2.  # Gamma: focusing exponent.
    loss = model_utils.focal_sigmoid_cross_entropy(
        logits, labels, alpha=a, gamma=g)

    # Plain per-element sigmoid cross-entropy...
    gt_loss = jnp.array(
        [[-log(1 - sigmoid(1.)), -log(sigmoid(2.)), -log(sigmoid(3.))],
         [-log(sigmoid(4.)), -log(1 - sigmoid(5.)), -log(sigmoid(6.))]])
    # ...modulated elementwise by the focal factor: a * (1 - p)**g on
    # positive labels, (1 - a) * p**g on negatives (note
    # sigmoid(-x) == 1 - sigmoid(x)).
    focal_factor = jnp.array([[
        (1 - a) * sigmoid(1.)**g, a * sigmoid(-2.)**g, a * sigmoid(-3.)**g
    ], [a * sigmoid(-4.)**g, (1 - a) * sigmoid(5.)**g, a * sigmoid(-6.)**g]])
    self.assertSequenceAlmostEqual(
        loss.flatten(), (gt_loss * focal_factor).flatten(), places=3)

  def test_dice_loss(self):
    """Tests the correctness of the segmentation dice loss."""
    # Two disjoint square ground-truth masks at full (h, w) resolution;
    # predictions are given at a stride-2 resolution.
    batch, num_objects, h, w = 1, 2, 128, 128
    stride = 2
    targets = np.zeros((batch, num_objects, h, w), dtype=np.float32)
    targets[0, 0, :64, :64] = 1.0
    targets[0, 1, 64:, 64:] = 1.0
    input_shape = batch, num_objects, h // stride, w // stride

    # Case 1: perfect prediction -> zero loss for both objects. Scaling to
    # +-5e5 saturates the sigmoid so masks are effectively hard 0/1.
    inputs = np.zeros(input_shape, dtype=np.float32)
    inputs[0, 0, :64 // stride, :64 // stride] = 1.0
    inputs[0, 1, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest')
    np.testing.assert_array_almost_equal(loss, [[0.0, 0.0]], decimal=3)

    # Case 2: object 0 predicted shifted by 32 px (half overlap) -> 0.5.
    inputs = np.zeros(input_shape, dtype=np.float32)
    inputs[0, 0, 32 // stride:(32 + 64) // stride, :64 // stride] = 1.0
    inputs[0, 1, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest')
    np.testing.assert_array_almost_equal(loss, [[0.5, 0.0]], decimal=3)

    # Case 3: object 0 predicted with no overlap at all -> 1.0.
    inputs = np.zeros(input_shape, dtype=np.float32)
    inputs[0, 0, 64 // stride:, 64 // stride:] = 1.0
    inputs[0, 1, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest')
    np.testing.assert_array_almost_equal(loss, [[1.0, 0.0]], decimal=3)

    # Case 4: all_pairs=True yields the loss between every prediction (3)
    # and every target (2), shape (batch, num_preds, num_targets).
    inputs = np.zeros((batch, 3, h // stride, w // stride), dtype=np.float32)
    inputs[0, 0, :64 // stride, :64 // stride] = 1.0
    inputs[0, 1, 32 // stride:(32 + 64) // stride, :64 // stride] = 1.0
    inputs[0, 2, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest',
        all_pairs=True)
    self.assertTupleEqual(loss.shape, (1, 3, 2))
    np.testing.assert_array_almost_equal(loss, [[[0.0, 1.0],
                                                 [0.5, 1.0],
                                                 [1.0, 0.0]]], decimal=3)

  def test_weighted_square_error(self):
    """Tests implementation of squared error."""

    predictions = jnp.array([
        [
            [1.0, 3.0, 5.0, 6.0],
            [3.0, 5.0, 11.0, 10.0],
            [9.0, 10.0, 11.0, 12.0],
            [14.0, 13.0, 14.0, 17.0],
        ],
        [
            [17.0, 18.0, 21.0, 22.0],
            [20.0, 19.0, 24.0, 25.0],
            [27.0, 29.0, 30.0, 32.0],
            [27.0, 28.0, 33.0, 32.0],
        ],
    ])
    # Targets count 1..32 reshaped to match predictions' (2, 4, 4) shape.
    targets = jnp.arange(1, 33).reshape(2, 4, 4)

    # Default reduction: mean over examples of per-example squared-error
    # sums (38 and 70 here).
    loss = model_utils.weighted_mean_squared_error(predictions, targets)
    expected_loss = jnp.mean(jnp.array([38.0, 70.0]))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    # Reducing explicitly over both trailing axes — in either order, with
    # positive or negative indices — must match the default.
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=(1, 2))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=(-1, -2))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=(2, 1))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    # Reducing over the last axis only: per-row squared-error sums.
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=-1)
    expected_loss = jnp.mean(jnp.array([[9, 25, 0, 4],
                                        [8, 12, 38, 12]]))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=2)
    self.assertAlmostEqual(loss, expected_loss, places=5)

    # Reducing over axis 1: per-column squared-error sums.
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=1)
    expected_loss = jnp.mean(jnp.array([[5, 3, 21, 9],
                                        [9, 22, 18, 21]]))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    # Weights mask out individual rows; the mean runs over the surviving
    # per-row sums only.
    weights = jnp.array([[1, 1, 1, 0], [0, 1, 1, 0]])
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   weights, axis=-1)
    expected_loss = jnp.mean(jnp.array([9, 25, 12, 38, 0]))
    self.assertAlmostEqual(loss, expected_loss, places=5)

    # Per-example weights: the whole second example is masked out.
    weights = jnp.array([1, 0])
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   weights, axis=-1)
    expected_loss = jnp.mean(jnp.array([9, 25, 0, 4]))
    self.assertAlmostEqual(loss, expected_loss, places=5)
|
class MetricTest(parameterized.TestCase):
  """Tests the metric computation related utilities."""

  def is_valid(self, t, value_name):
    """Helper function to assert that tensor `t` does not have `nan`, `inf`."""
    self.assertFalse(
        jnp.isnan(t).any(), msg=f'Found nan\'s in {t} for {value_name}')
    self.assertFalse(
        jnp.isinf(t).any(), msg=f'Found inf\'s in {t} for {value_name}')

  def test_weighted_topk_correctly_classified(self):
    """Tests the topk accuracy computation."""
    batch_size = 512
    num_of_classes = 100
    # Random logits and integer labels; correctness is checked against a
    # NumPy reference below rather than fixed values.
    logits = jnp.array(
        np.random.normal(size=(batch_size, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(np.random.randint(num_of_classes, size=(batch_size,)))

    # k=1 must agree with the plain top-1 accuracy metric.
    one_hot_targets = common_utils.onehot(labels, logits.shape[-1])
    classification_accuracy = model_utils.weighted_correctly_classified(
        logits, one_hot_targets)
    top_one_accuracy = model_utils.weighted_topk_correctly_classified(
        logits, one_hot_targets, k=1)
    self.assertSequenceAlmostEqual(
        classification_accuracy.flatten(), top_one_accuracy.flatten())

    # k == num_classes: every example is trivially correct.
    top_n_accuracy = model_utils.weighted_topk_correctly_classified(
        logits, one_hot_targets, k=num_of_classes)
    self.assertEqual(jnp.mean(top_n_accuracy), 1)

    # k=5 checked against an explicit argsort-based NumPy reference.
    top_5_accuracy = model_utils.weighted_topk_correctly_classified(
        logits, one_hot_targets, k=5)
    top5_pred = np.argsort(
        np.reshape(logits, [-1, num_of_classes]), axis=1)[:, -5:]
    y_true = np.array(labels)
    top5_pred = np.reshape(top5_pred, [-1, 5])
    y_true = np.reshape(y_true, [-1])
    np_top_accuracy = np.array(
        [y_true[i] in top5_pred[i, :] for i in range(len(y_true))])
    self.assertSequenceAlmostEqual(top_5_accuracy.flatten(),
                                   np_top_accuracy.flatten())

  def test_weighted_recall(self):
    """Tests the topk recall computation."""

    # Multi-hot labels with a varying number of positives per row (2, 2, 1,
    # 1, 0); logits rank the classes per row.
    logits = np.array([[[2, 3, 4],
                        [4, 3, 2],
                        [4, 2, 3],
                        [3, 2, 4],
                        [4, 2, 3],
                        ]])
    labels = np.array([[[1, 1, 0],
                        [1, 1, 0],
                        [1, 0, 0],
                        [1, 0, 0],
                        [0, 0, 0]
                        ]])

    # Tile the same example across the batch; expected recall repeats.
    batch_size = 8
    logits = jnp.tile(logits, [batch_size, 1, 1])
    labels = jnp.tile(labels, [batch_size, 1, 1])

    recall = model_utils.weighted_recall(logits, labels)
    recall_expected = np.array([[1/2, 1., 1., 0., 0.]] * batch_size)
    self.assertSequenceAlmostEqual(
        recall.flatten(), recall_expected.flatten())

  @parameterized.parameters(itertools.product([1., 0.], [1., 0.]))
  def test_weighted_top_one_correctly_classified(self, label_multiplier,
                                                 weight_multiplier):
    """Tests the top1 correct computation."""
    batch_size = 512
    num_of_classes = 100
    logits = jnp.array(np.random.normal(
        size=(batch_size, 50, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(np.random.randint(
        0, 2, size=(batch_size, 50, num_of_classes)))
    # A zero multiplier exercises the all-labels-zero edge case.
    labels *= label_multiplier

    # A zero multiplier exercises the all-examples-masked edge case.
    weights = jnp.ones(shape=(batch_size,), dtype=jnp.float32)
    weights *= weight_multiplier

    # weighted_top_one_correctly_classified must agree elementwise (and in
    # total) with the generic top-k metric at k=1.
    is_correct_array = model_utils.weighted_top_one_correctly_classified(
        logits, labels, weights=weights)
    num_correct = jnp.sum(is_correct_array)
    is_correct_array_ref = model_utils.weighted_topk_correctly_classified(
        logits, labels, weights, k=1)

    np.testing.assert_array_almost_equal(
        is_correct_array, is_correct_array_ref)
    np.testing.assert_equal(np.sum(is_correct_array),
                            np.sum(is_correct_array_ref))

    # Metric must stay finite even when labels/weights are all zero.
    self.is_valid(num_correct, 'Number of correctly classified')

  @parameterized.parameters(itertools.product([1., 0.], [1., 0.]))
  def test_weighted_unnormalized_sigmoid_cross_entropy(self, label_multiplier,
                                                       weight_multiplier):
    """Tests the unnormalized sigmoid cross entropy computation."""
    batch_size = 512
    num_of_classes = 100
    logits = jnp.array(
        np.random.normal(size=(batch_size, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(np.random.randint(0, 2,
                                         size=(batch_size, num_of_classes)))
    # Zero multipliers exercise the all-zero-labels / all-masked cases.
    labels *= label_multiplier

    weights = jnp.ones(shape=(batch_size,), dtype=jnp.float32)
    weights *= weight_multiplier

    # Only numerical stability is asserted here (no nan/inf), not the value.
    loss_array = model_utils.weighted_unnormalized_sigmoid_cross_entropy(
        logits, labels, weights=weights)
    loss_sum = jnp.sum(loss_array)

    self.is_valid(loss_sum, 'Loss value')

  @parameterized.parameters(itertools.product([1., 0.], [1., 0.]))
  def test_weighted_unnormalized_softmax_cross_entropy(self, label_multiplier,
                                                       weight_multiplier):
    """Tests the unnormalized softmax cross entropy computation."""
    batch_size = 512
    num_of_classes = 100
    logits = jnp.array(
        np.random.normal(size=(batch_size, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(
        np.random.randint(0, 2, size=(batch_size, num_of_classes)))
    # Zero multipliers exercise the all-zero-labels / all-masked cases.
    labels *= label_multiplier

    weights = jnp.ones(shape=(batch_size,), dtype=jnp.float32)
    weights *= weight_multiplier

    # Only numerical stability is asserted here (no nan/inf), not the value.
    loss_array = model_utils.weighted_unnormalized_softmax_cross_entropy(
        logits, labels, weights=weights)
    loss_sum = jnp.sum(loss_array)

    self.is_valid(loss_sum, 'Loss value')
|
if __name__ == '__main__':
  # Run every test case defined in this module via absl's test runner.
  absltest.main()
| |
|