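# NOTE: hosting-page residue kept for provenance (not part of the module):
#   uploader: fcxfcx; commit message: "Upload 2446 files"
#   commit: 1327f34 (verified)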
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for functions in box_utils.py."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
from scenic.model_lib.base_models import box_utils
from shapely import geometry
def sample_cxcywh_bbox(key, batch_shape):
  """Draws random boxes in the normalized [cx, cy, w, h] format.

  All four components are first drawn uniformly and scaled by 0.8. Widths and
  heights are then shrunk wherever the box would otherwise touch or cross an
  image border.

  Args:
    key: JAX PRNG key used for sampling.
    batch_shape: Iterable of ints giving the leading dimensions of the result.

  Returns:
    Array of shape `(*batch_shape, 4)` with [cx, cy, w, h] in the last axis.
  """
  frac = 0.8
  # Replacement extent per unit of distance between the center and a border.
  shrink = frac * 2.

  raw = jax.random.uniform(key, shape=(*batch_shape, 4)) * frac
  # Keep a trailing singleton axis on every component for the concatenate.
  cx, cy, w, h = (raw[..., i:i + 1] for i in range(4))

  # Right and top borders: shrink boxes that reach x == 1 or y == 1.
  w = jnp.where(cx + w / 2. >= 1., shrink * (1. - cx), w)
  h = jnp.where(cy + h / 2. >= 1., shrink * (1. - cy), h)

  # Left and bottom borders: shrink boxes that reach x == 0 or y == 0.
  w = jnp.where(cx - w / 2. <= 0., shrink * cx, w)
  h = jnp.where(cy - h / 2. <= 0., shrink * cy, h)

  return jnp.concatenate([cx, cy, w, h], axis=-1)
class BoxUtilsTest(parameterized.TestCase):
  """Checks the axis-aligned box format conversions in box_utils."""

  def test_box_cxcywh_to_xyxy(self):
    """box_cxcywh_to_xyxy matches hand-computed corners and rejects non-boxes."""
    boxes_cxcywh = jnp.array(
        [
            [[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.2, 0.4]],
            [[0.3, 0.2, 0.1, 0.4], [0.3, 0.2, 0.1, 0.4]],
        ],
        dtype=jnp.float32)
    want_xyxy = jnp.array(
        [
            [[0.0, 0.1, 0.2, 0.5], [0.0, 0.1, 0.2, 0.5]],
            [[0.25, 0.0, 0.35, 0.4], [0.25, 0.0, 0.35, 0.4]],
        ],
        dtype=jnp.float32)

    got_xyxy = box_utils.box_cxcywh_to_xyxy(boxes_cxcywh)
    self.assertSequenceAlmostEqual(
        want_xyxy.flatten(), got_xyxy.flatten(), places=5)

    # An input whose last dimension is 5 is not a box; a ValueError is expected.
    with self.assertRaises(ValueError):
      not_boxes = jnp.array(np.random.uniform(size=(2, 3, 5)))
      _ = box_utils.box_cxcywh_to_xyxy(not_boxes)

  @parameterized.parameters([((3, 1, 4),), ((4, 6, 4),)])
  def test_box_cxcywh_to_xyxy_shape(self, input_shape):
    """box_cxcywh_to_xyxy keeps the shape of its input."""
    boxes_cxcywh = jnp.array(np.random.uniform(size=input_shape))
    boxes_xyxy = box_utils.box_cxcywh_to_xyxy(boxes_cxcywh)
    self.assertEqual(boxes_xyxy.shape, boxes_cxcywh.shape)

  @parameterized.parameters([((2, 5, 4),), ((1, 3, 4),)])
  def test_box_cxcy_to_xyxy_box_xyxy_to_cxcy(self, input_shape):
    """Converting cxcywh -> xyxy -> cxcywh recovers the original boxes."""
    boxes_cxcywh = jnp.array(np.random.uniform(size=input_shape))
    round_trip = box_utils.box_xyxy_to_cxcywh(
        box_utils.box_cxcywh_to_xyxy(boxes_cxcywh))
    self.assertSequenceAlmostEqual(
        round_trip.flatten(), boxes_cxcywh.flatten(), places=5)
def sample_cxcywha(key, batch_shape):
  """Draws random rotated boxes [cx, cy, w, h, a] with `a` in radians.

  Uniform samples are scaled and shifted per component: centers are scaled by
  0.3 and offset by 0.35, widths and heights are scaled by 0.5, and the angle
  is left unscaled.

  Args:
    key: JAX PRNG key used for sampling.
    batch_shape: Iterable of ints giving the leading dimensions of the result.

  Returns:
    Array of shape `(*batch_shape, 5)`.
  """
  span = jnp.array([0.3, 0.3, 0.5, 0.5, 1.0])
  lower = jnp.array([0.35, 0.35, 0, 0, 0])
  unit_samples = jax.random.uniform(key, shape=(*batch_shape, 5))
  return unit_samples * span + lower
class RBoxUtilsTest(parameterized.TestCase):
  """Checks the rotated bounding box helpers in box_utils."""

  def test_convert_cxcywha_to_corners(self):
    """cxcywha_to_corners yields (..., 4, 2) corners inside the unit square."""
    rboxes = sample_cxcywha(jax.random.PRNGKey(0), batch_shape=(300, 200))
    self.assertEqual(rboxes.shape, (300, 200, 5))

    corners = box_utils.cxcywha_to_corners(rboxes)
    self.assertEqual(corners.shape, (300, 200, 4, 2))
    # These bounds rely on the sampler placing boxes within the unit square.
    self.assertTrue(jnp.all(corners >= 0))
    self.assertTrue(jnp.all(corners <= 1))

  def test_convert_corners_to_cxcywha(self):
    """corners_to_cxcywha inverts cxcywha_to_corners on sampled boxes."""
    rboxes = sample_cxcywha(jax.random.PRNGKey(0), batch_shape=(3, 2))
    self.assertEqual(rboxes.shape, (3, 2, 5))

    recovered = box_utils.corners_to_cxcywha(
        box_utils.cxcywha_to_corners(rboxes))
    np.testing.assert_allclose(recovered, rboxes, atol=1e-6)

  def test_convert_cxcywha_to_corners_single_rotated(self):
    """A sqrt(2)-sided square at (1, 1) rotated by 45 degrees has known corners."""
    side = jnp.sqrt(2)
    rbox = jnp.array([1, 1, side, side, 45. * jnp.pi / 180.])
    want_corners = [[1, 0], [2, 1], [1, 2], [0, 1]]
    got_corners = box_utils.cxcywha_to_corners(rbox)
    np.testing.assert_allclose(got_corners, want_corners, atol=1e-7)

  def test_intersect_line_segments(self):
    """intersect_line_segments agrees with shapely on random segment pairs."""
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    lines1 = jax.random.uniform(subkey, (100, 2, 2))
    lines2 = jax.random.uniform(key, (100, 2, 2))

    batched_intersect = jax.jit(jax.vmap(box_utils.intersect_line_segments))
    intersections = batched_intersect(lines1, lines2)
    self.assertEqual(intersections.shape, (100, 2))

    def reference_intersection(segment_a, segment_b):
      """Returns shapely's intersection point, or NaNs when it is not a point."""
      hit = geometry.LineString(segment_a).intersection(
          geometry.LineString(segment_b))
      if isinstance(hit, geometry.Point):
        return hit.coords[0]
      return jnp.asarray([jnp.nan] * 2)

    expected_intersections = [
        reference_intersection(segment_a, segment_b)
        for segment_a, segment_b in zip(lines1, lines2)
    ]
    np.testing.assert_allclose(intersections, expected_intersections, atol=1e-7)

  def test_intersect_rbox_edges_same_box(self):
    """Intersecting a unit box with itself gives each corner exactly twice."""
    rbox = jnp.array([0.5, 0.5, 1.0, 1.0, 0])
    corners1 = box_utils.cxcywha_to_corners(rbox)
    corners2 = box_utils.cxcywha_to_corners(rbox)

    it_points = box_utils.intersect_rbox_edges(corners1, corners2)
    self.assertEqual(it_points.shape, (4, 4, 2))

    # Drop the NaN entries and keep only the points that have both coordinates.
    has_nan = jnp.any(jnp.isnan(it_points), -1)
    found = sorted((x, y) for x, y in np.array(it_points[~has_nan]))
    wanted = sorted([(0, 0), (0, 1), (1, 0), (1, 1)] * 2)
    self.assertSequenceEqual(found, wanted)

  def test_intersect_rbox_edges_rotated_box(self):
    """Test one rbox inscribed in the other at a 45 degree angle."""
    inner = jnp.array([1.0, 1.0, 1.0, 1.0, 0])
    outer = jnp.array(
        [1.0, 1.0, jnp.sqrt(2), jnp.sqrt(2), 45. * np.pi / 180.])
    corners1 = box_utils.cxcywha_to_corners(inner)
    corners2 = box_utils.cxcywha_to_corners(outer)

    it_points = box_utils.intersect_rbox_edges(corners1, corners2)
    has_nan = jnp.any(jnp.isnan(it_points), -1)
    # Round so the comparison below can use exact equality.
    rounded = jnp.round(it_points[~has_nan], decimals=4)
    found = sorted((x, y) for x, y in np.array(rounded))

    # Expect intersection at unrotated box vertices.
    wanted = sorted([(1.5, 1.5), (1.5, 0.5), (0.5, 0.5), (0.5, 1.5)] * 2)
    self.assertSequenceEqual(found, wanted)
class IoUTest(parameterized.TestCase):
  """Exercises box_utils.box_iou and box_utils.generalized_box_iou."""

  def test_box_iou_values(self):
    """Checks that 0 <= IoU <= 1, union >= 0 and -1 <= gIoU <= 1."""
    # Fake predictions and targets, drawn from two different PRNG keys.
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    pred_xyxy = box_utils.box_cxcywh_to_xyxy(
        sample_cxcywh_bbox(key, batch_shape=(4, 100)))
    tgt_xyxy = box_utils.box_cxcywh_to_xyxy(
        sample_cxcywh_bbox(subkey, batch_shape=(4, 63)))

    iou, union = box_utils.box_iou(pred_xyxy, tgt_xyxy, all_pairs=True)
    self.assertTrue(jnp.all(iou >= 0))
    self.assertTrue(jnp.all(iou <= 1.))
    self.assertTrue(jnp.all(union >= 0.))

    giou = box_utils.generalized_box_iou(pred_xyxy, tgt_xyxy, all_pairs=True)
    self.assertTrue(jnp.all(giou >= -1.))
    self.assertTrue(jnp.all(giou <= 1.))

  def test_box_iou(self):
    """Compares box_iou(all_pairs=False) against hand-computed values."""
    boxes_a = jnp.array(
        [
            [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0], [0.1, 0.2, 0.5, 0.8]],
            [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8], [0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
        ],
        dtype=jnp.float32)
    boxes_b = jnp.array(
        [
            [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.7, 0.8]],
            [[0.7, 0.4, 0.8, 0.6], [0.8, 0.6, 0.7, 0.4], [0.1, 0.1, 0.2, 0.2]],
            [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
        ],
        dtype=jnp.float32)
    want_iou = jnp.array(
        [[0.0, 0.125, 0.125], [0.0625, 0.0, 0.0], [0.0, 0.0, 1.0]],
        dtype=jnp.float32)

    got_iou, _ = box_utils.box_iou(boxes_a, boxes_b, all_pairs=False)
    self.assertSequenceAlmostEqual(
        got_iou.flatten(), want_iou.flatten(), places=3)

  @classmethod
  def _get_method_fn(cls, method):
    """Maps the method name 'iou' or 'giou' to a callable returning the metric.

    Args:
      method: Either 'iou' or 'giou'.

    Returns:
      A callable taking two box arrays plus keyword arguments.

    Raises:
      ValueError: If `method` is neither 'iou' nor 'giou'.
    """
    if method == 'giou':
      return box_utils.generalized_box_iou
    if method == 'iou':
      # box_iou returns a pair; only its first element (the IoU) is needed.
      return lambda x, y, **kwargs: box_utils.box_iou(x, y, **kwargs)[0]
    raise ValueError(f'Unknown method {method}')

  @parameterized.parameters('iou', 'giou')
  def test_all_pairs_true_false(self, method):
    """Validates all_pairs=True using two all_pairs=False evaluations."""
    method_fn = self._get_method_fn(method)

    in1 = jnp.array(
        [  # [2, 2, 4] tensor.
            [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0]],
            [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8]],
        ],
        dtype=jnp.float32)
    in2 = jnp.array(
        [
            [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8]],
            [[0.7, 0.4, 0.8, 0.6], [0.1, 0.5, 0.7, 0.7]],
        ],
        dtype=jnp.float32)
    # Simulate all_pairs=True by swapping the two boxes of each in1 batch
    # entry, so every in1 box also gets paired with the "other" in2 box.
    in1_swapped = in1[:, ::-1]  # [2, 2, 4]

    aligned = method_fn(in1, in2, all_pairs=False)  # [2, 2]
    crossed = method_fn(in1_swapped, in2, all_pairs=False)  # [2, 2]
    out_all = method_fn(in1, in2, all_pairs=True)  # [2, 2, 2]

    # Assemble the expected all-pairs result from the two aligned results.
    # Writing "i-j" for "in1 box i vs. in2 box j" (boxes numbered 0..3):
    #   aligned = [[0-0, 1-1], [2-2, 3-3]]
    #   crossed = [[1-0, 0-1], [3-2, 2-3]]
    #   out_all = [[[0-0, 0-1], [1-0, 1-1]], [[2-2, 2-3], [3-2, 3-3]]]
    expected_all = jnp.array(
        [
            [[aligned[b, 0], crossed[b, 1]], [crossed[b, 0], aligned[b, 1]]]
            for b in range(2)
        ],
        dtype=jnp.float32)
    self.assertSequenceAlmostEqual(out_all.flatten(), expected_all.flatten())

  def test_generalized_box_iou(self):
    """Compares generalized_box_iou(all_pairs=False) with hand-made targets."""
    boxes_a = jnp.array(
        [
            [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0], [0.1, 0.2, 0.5, 0.8]],
            [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8], [0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
        ],
        dtype=jnp.float32)
    boxes_b = jnp.array(
        [
            [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.7, 0.8]],
            [[0.7, 0.4, 0.8, 0.6], [0.4, 0.4, 0.8, 0.6], [0.1, 0.1, 0.2, 0.2]],
            [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
        ],
        dtype=jnp.float32)

    # The expected gIoU is written as the IoU plus a correction term.
    iou_part = jnp.array(
        [[0.0, 0.125, 0.125], [0.0625, 1. / 7., 0.0], [0.0, 0.0, 1.0]],
        dtype=jnp.float32)
    correction = jnp.array(
        [[-2. / 3., 0.0, -1. / 9.], [0.0, -2. / 9., -3. / 4.], [0.0, 0.0, 0.0]],
        dtype=jnp.float32)
    want_giou = iou_part + correction

    got_giou = box_utils.generalized_box_iou(boxes_a, boxes_b, all_pairs=False)
    self.assertSequenceAlmostEqual(
        got_giou.flatten(), want_giou.flatten(), places=3)

    # if the boxes are invalid it should raise an AssertionError
    # TODO(b/166344282): uncomment these after enabling the assertions
    # in1 = jnp.array([[[0.1, 0.2, 0.3, 0.4],],], dtype=jnp.float32)
    # in2 = jnp.array([[[0.3, 0.4, 0.1, 0.2],],], dtype=jnp.float32)
    # with self.assertRaises(AssertionError):
    #   _ = box_utils.generalized_box_iou(in1, in2, all_pairs=False)

  @parameterized.parameters('iou', 'giou')
  def test_backward(self, method):
    """Checks that the *box_iou methods can be differentiated by jax.grad."""
    method_fn = self._get_method_fn(method)

    def loss_fn(x, y, all_pairs):
      """Reduces the metric to a scalar so jax.grad can differentiate it."""
      return method_fn(x, y, all_pairs=all_pairs).sum()

    grad_fn = jax.grad(loss_fn)

    in1 = jnp.array(
        [  # [2, 2, 4] tensor.
            [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0]],
            [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8]],
        ],
        dtype=jnp.float32)
    in2 = jnp.array(
        [
            [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8]],
            [[0.7, 0.4, 0.8, 0.6], [0.1, 0.5, 0.7, 0.7]],
        ],
        dtype=jnp.float32)

    # The gradient w.r.t. in1 must have in1's shape in both pairing modes.
    for all_pairs in (True, False):
      grad_in1 = grad_fn(in1, in2, all_pairs=all_pairs)
      self.assertSequenceEqual(grad_in1.shape, in1.shape)
if __name__ == '__main__':
  # Run the tests through absl's runner when executed as a script.
  absltest.main()