| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Unit tests for functions in box_utils.py.""" |
| |
|
| | from absl.testing import absltest |
| | from absl.testing import parameterized |
| | import jax |
| | import jax.numpy as jnp |
| | import numpy as np |
| | from scenic.model_lib.base_models import box_utils |
| | from shapely import geometry |
| |
|
| |
|
def sample_cxcywh_bbox(key, batch_shape):
  """Draws random [cx, cy, w, h] boxes that lie inside the unit square.

  Centers and sizes start out uniform in [0, 0.8). Any size that would put a
  box edge at or beyond 1 (or at or below 0) is shrunk so that the box stays
  within [0, 1].

  Args:
    key: JAX PRNG key used for sampling.
    batch_shape: Leading shape of the returned array.

  Returns:
    Array of shape (*batch_shape, 4) in [cx, cy, w, h] format.
  """
  frac = 0.8
  raw = frac * jax.random.uniform(key, shape=(*batch_shape, 4))

  def _fit(center, size):
    # First shrink sizes whose far edge reaches 1, then those whose near edge
    # reaches 0; each axis is handled independently.
    size = jnp.where(center + size / 2. >= 1., frac * 2. * (1. - center), size)
    return jnp.where(center - size / 2. <= 0., frac * 2. * center, size)

  cx = raw[..., 0:1]
  cy = raw[..., 1:2]
  w = _fit(cx, raw[..., 2:3])
  h = _fit(cy, raw[..., 3:4])
  return jnp.concatenate([cx, cy, w, h], axis=-1)
| |
|
| |
|
class BoxUtilsTest(parameterized.TestCase):
  """Tests all the bounding box related utilities."""

  def setUp(self):
    super().setUp()
    # Seeded generator so random-input tests are reproducible across runs
    # (the global np.random state is not seeded anywhere in this file).
    self._rng = np.random.default_rng(0)

  def test_box_cxcywh_to_xyxy(self):
    """Test for correctness of the box_cxcywh_to_xyxy operation."""
    cxcywh = jnp.array([[[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.2, 0.4]],
                        [[0.3, 0.2, 0.1, 0.4], [0.3, 0.2, 0.1, 0.4]]],
                       dtype=jnp.float32)
    # [cx, cy, w, h] -> [cx - w/2, cy - h/2, cx + w/2, cy + h/2].
    expected = jnp.array([[[0.0, 0.1, 0.2, 0.5], [0.0, 0.1, 0.2, 0.5]],
                          [[0.25, 0.0, 0.35, 0.4], [0.25, 0.0, 0.35, 0.4]]],
                         dtype=jnp.float32)
    output = box_utils.box_cxcywh_to_xyxy(cxcywh)
    self.assertSequenceAlmostEqual(
        expected.flatten(), output.flatten(), places=5)

    # A trailing dimension other than 4 must be rejected.
    with self.assertRaises(ValueError):
      _ = box_utils.box_cxcywh_to_xyxy(jnp.zeros((2, 3, 5)))

  @parameterized.parameters([((3, 1, 4),), ((4, 6, 4),)])
  def test_box_cxcywh_to_xyxy_shape(self, input_shape):
    """Test whether the shape is correct for box_cxcywh_to_xyxy."""
    cxcywh = jnp.array(self._rng.uniform(size=input_shape))
    xyxy = box_utils.box_cxcywh_to_xyxy(cxcywh)
    self.assertEqual(xyxy.shape, cxcywh.shape)

  @parameterized.parameters([((2, 5, 4),), ((1, 3, 4),)])
  def test_box_cxcy_to_xyxy_box_xyxy_to_cxcy(self, input_shape):
    """Test both box conversion functions as they are inverses of each other."""
    cxcywh = jnp.array(self._rng.uniform(size=input_shape))
    xyxy = box_utils.box_cxcywh_to_xyxy(cxcywh)
    cxcywh_loop = box_utils.box_xyxy_to_cxcywh(xyxy)
    self.assertSequenceAlmostEqual(
        cxcywh_loop.flatten(), cxcywh.flatten(), places=5)
| |
|
| |
|
def sample_cxcywha(key, batch_shape):
  """Draws random rotated boxes in [cx, cy, w, h, a] format.

  Each coordinate is uniform over its own interval: cx and cy in
  [0.35, 0.65), w and h in [0, 0.5), and the angle a (radians) in [0, 1).

  Args:
    key: JAX PRNG key used for sampling.
    batch_shape: Leading shape of the returned array.

  Returns:
    Array of shape (*batch_shape, 5).
  """
  low = jnp.array([0.35, 0.35, 0, 0, 0])
  span = jnp.array([0.3, 0.3, 0.5, 0.5, 1.0])
  unit = jax.random.uniform(key, shape=(*batch_shape, 5))
  return unit * span + low
| |
|
| |
|
class RBoxUtilsTest(parameterized.TestCase):
  """Tests all the rotated bounding box related utilities."""

  def test_convert_cxcywha_to_corners(self):
    """Checks the corner array shape and that sampled corners stay in [0, 1]."""
    key = jax.random.PRNGKey(0)
    cxcywha = sample_cxcywha(key, batch_shape=(300, 200))
    self.assertEqual(cxcywha.shape, (300, 200, 5))

    corners = box_utils.cxcywha_to_corners(cxcywha)
    self.assertEqual(corners.shape, (300, 200, 4, 2))
    # NOTE(review): sampled centers are >= 0.35 while the largest sampled
    # half-diagonal is ~0.354, so these bounds depend on the fixed PRNG key
    # rather than holding for every possible sample.
    self.assertTrue(jnp.all(corners >= 0))
    self.assertTrue(jnp.all(corners <= 1))

  def test_convert_corners_to_cxcywha(self):
    """Checks that corners_to_cxcywha inverts cxcywha_to_corners."""
    key = jax.random.PRNGKey(0)
    cxcywha = sample_cxcywha(key, batch_shape=(3, 2))
    self.assertEqual(cxcywha.shape, (3, 2, 5))

    corners = box_utils.cxcywha_to_corners(cxcywha)
    cxcywha2 = box_utils.corners_to_cxcywha(corners)
    np.testing.assert_allclose(cxcywha2, cxcywha, atol=1e-6)

  def test_convert_cxcywha_to_corners_single_rotated(self):
    """Checks the corners of a sqrt(2)-sided square rotated 45 degrees."""
    cxcywha = jnp.array([1, 1, jnp.sqrt(2), jnp.sqrt(2), 45. * jnp.pi / 180.])
    corners = box_utils.cxcywha_to_corners(cxcywha)
    # The rotated square's corners lie one unit from (1, 1) along each axis.
    expected_corners = [[1, 0], [2, 1], [1, 2], [0, 1]]
    np.testing.assert_allclose(corners, expected_corners, atol=1e-7)

  def test_intersect_line_segments(self):
    """Compares intersect_line_segments against shapely on random segments."""
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    lines1 = jax.random.uniform(subkey, (100, 2, 2))
    lines2 = jax.random.uniform(key, (100, 2, 2))
    intersect_line_segments = jax.jit(
        jax.vmap(box_utils.intersect_line_segments))
    intersections = intersect_line_segments(lines1, lines2)
    self.assertEqual(intersections.shape, (100, 2))

    # Reference values: shapely yields a Point for a single crossing; any
    # other result (e.g. no intersection) is expected to come back as NaN.
    expected_intersections = []
    for segment1, segment2 in zip(np.asarray(lines1), np.asarray(lines2)):
      it = geometry.LineString(segment1).intersection(
          geometry.LineString(segment2))
      if isinstance(it, geometry.Point):
        expected_intersections.append(it.coords[0])
      else:
        expected_intersections.append((np.nan, np.nan))

    np.testing.assert_allclose(intersections, expected_intersections, atol=1e-7)

  def test_intersect_rbox_edges_same_box(self):
    """Checks the edge intersections of a box with itself."""
    rbox1 = jnp.array([0.5, 0.5, 1.0, 1.0, 0])
    rbox2 = rbox1
    corners1 = box_utils.cxcywha_to_corners(rbox1)
    corners2 = box_utils.cxcywha_to_corners(rbox2)
    it_points = box_utils.intersect_rbox_edges(corners1, corners2)
    self.assertEqual(it_points.shape, (4, 4, 2))
    it_points = it_points[~jnp.any(jnp.isnan(it_points), -1)]
    it_points = sorted([(x, y) for x, y in np.array(it_points)])
    # Every corner of the unit square is reported twice.
    expected_points = sorted([(0, 0), (0, 1), (1, 0), (1, 1)] * 2)
    self.assertSequenceEqual(it_points, expected_points)

  def test_intersect_rbox_edges_rotated_box(self):
    """Checks a square inscribed in a copy of itself rotated by 45 degrees."""
    rbox1 = jnp.array([1.0, 1.0, 1.0, 1.0, 0])
    rbox2 = jnp.array(
        [1.0, 1.0, jnp.sqrt(2), jnp.sqrt(2), 45. * jnp.pi / 180.])
    corners1 = box_utils.cxcywha_to_corners(rbox1)
    corners2 = box_utils.cxcywha_to_corners(rbox2)
    it_points = box_utils.intersect_rbox_edges(corners1, corners2)
    # Round away float noise so the points can be compared exactly.
    it_points = jnp.round(
        it_points[~jnp.any(jnp.isnan(it_points), -1)], decimals=4)
    it_points = sorted([(x, y) for x, y in np.array(it_points)])

    # The inner square's corners touch the rotated square's edges; each one
    # is reported twice.
    expected_pts = sorted([(1.5, 1.5), (1.5, 0.5), (0.5, 0.5), (0.5, 1.5)] * 2)
    self.assertSequenceEqual(it_points, expected_pts)
| |
|
| |
|
class IoUTest(parameterized.TestCase):
  """Test box_iou and generalized_box_iou functions."""

  @classmethod
  def _get_method_fn(cls, method):
    """Returns method_fn function corresponding to method str."""
    if method == 'iou':
      method_fn = lambda x, y, **kwargs: box_utils.box_iou(x, y, **kwargs)[0]
    elif method == 'giou':
      method_fn = box_utils.generalized_box_iou
    else:
      raise ValueError(f'Unknown method {method}')
    return method_fn

  @staticmethod
  def _paired_inputs():
    """Returns (in1, in2) xyxy box batches of shape (2, 2, 4) for pair tests."""
    in1 = jnp.array(
        [
            [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0]],
            [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8]],
        ],
        dtype=jnp.float32)
    in2 = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.1, 0.5, 0.7, 0.7]],
    ],
                    dtype=jnp.float32)
    return in1, in2

  def test_box_iou_values(self):
    """Tests if 0 <= IoU <= 1 and -1 <= gIoU <=1."""
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    pred_bbox = sample_cxcywh_bbox(key, batch_shape=(4, 100))
    tgt_bbox = sample_cxcywh_bbox(subkey, batch_shape=(4, 63))

    pred_bbox = box_utils.box_cxcywh_to_xyxy(pred_bbox)
    tgt_bbox = box_utils.box_cxcywh_to_xyxy(tgt_bbox)

    iou, union = box_utils.box_iou(pred_bbox, tgt_bbox, all_pairs=True)
    self.assertTrue(jnp.all(iou >= 0))
    self.assertTrue(jnp.all(iou <= 1.))
    self.assertTrue(jnp.all(union >= 0.))

    giou = box_utils.generalized_box_iou(pred_bbox, tgt_bbox, all_pairs=True)
    self.assertTrue(jnp.all(giou >= -1.))
    self.assertTrue(jnp.all(giou <= 1.))

  def test_box_iou(self):
    """Test box_iou using hand designed targets."""
    in1 = jnp.array([
        [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0], [0.1, 0.2, 0.5, 0.8]],
        [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8], [0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ],
                    dtype=jnp.float32)
    in2 = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.7, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.8, 0.6, 0.7, 0.4], [0.1, 0.1, 0.2, 0.2]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ],
                    dtype=jnp.float32)

    # E.g. target[0, 1]: intersection 0.1 * 0.4 = 0.04 over union 0.32.
    # Pairs involving a degenerate (zero or negative extent) box expect 0.
    target = jnp.array(
        [[0.0, 0.125, 0.125], [0.0625, 0.0, 0.0], [0.0, 0.0, 1.0]],
        dtype=jnp.float32)

    output, _ = box_utils.box_iou(in1, in2, all_pairs=False)

    self.assertSequenceAlmostEqual(output.flatten(), target.flatten(), places=3)

  @parameterized.parameters('iou', 'giou')
  def test_all_pairs_true_false(self, method):
    """Use *box_iou(..., all_pairs=False) to test the all_pairs=True case."""
    method_fn = self._get_method_fn(method)
    in1, in2 = self._paired_inputs()

    # Swap the two boxes of every batch element, so that the all_pairs=False
    # call pairs in1[b, 1 - j] with in2[b, j] (the off-diagonal entries).
    in1_swapped = in1[:, ::-1]

    out = method_fn(in1, in2, all_pairs=False)
    out_swapped = method_fn(in1_swapped, in2, all_pairs=False)
    out_all = method_fn(in1, in2, all_pairs=True)

    # expected[b, i, j] = score(in1[b, i], in2[b, j]).
    expected = jnp.array(
        [[[out[b, 0], out_swapped[b, 1]], [out_swapped[b, 0], out[b, 1]]]
         for b in range(in1.shape[0])],
        dtype=jnp.float32)

    self.assertSequenceAlmostEqual(out_all.flatten(), expected.flatten())

  def test_generalized_box_iou(self):
    """Same as test_box_iou but for generalized_box_iou()."""
    in1 = jnp.array([
        [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0], [0.1, 0.2, 0.5, 0.8]],
        [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8], [0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ],
                    dtype=jnp.float32)
    in2 = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.7, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.4, 0.4, 0.8, 0.6], [0.1, 0.1, 0.2, 0.2]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ],
                    dtype=jnp.float32)

    target_iou = jnp.array(
        [[0.0, 0.125, 0.125], [0.0625, 1. / 7., 0.0], [0.0, 0.0, 1.0]],
        dtype=jnp.float32)
    # gIoU penalty -(enclosing - union) / enclosing; e.g. [0, 0]: enclosing
    # box area 0.24, union 0.08 -> -2/3.
    target_extra = jnp.array(
        [[-2. / 3., 0.0, -1. / 9.], [0.0, -2. / 9., -3. / 4.], [0.0, 0.0, 0.0]],
        dtype=jnp.float32)
    target = target_iou + target_extra

    output = box_utils.generalized_box_iou(in1, in2, all_pairs=False)

    self.assertSequenceAlmostEqual(output.flatten(), target.flatten(), places=3)

  @parameterized.parameters('iou', 'giou')
  def test_backward(self, method):
    """Test whether *box_iou methods have a finite, correctly shaped grad."""
    method_fn = self._get_method_fn(method)

    def loss_fn(x, y, all_pairs):
      return method_fn(x, y, all_pairs=all_pairs).sum()

    # Differentiates w.r.t. the first positional argument only.
    grad_fn = jax.grad(loss_fn)
    in1, in2 = self._paired_inputs()

    for all_pairs in (True, False):
      grad_in1 = grad_fn(in1, in2, all_pairs=all_pairs)
      self.assertSequenceEqual(grad_in1.shape, in1.shape)
      # A NaN/inf gradient would make the loss unusable even if shapes match.
      self.assertTrue(jnp.all(jnp.isfinite(grad_in1)))
| |
|
| |
|
if __name__ == '__main__':
  # Hand control to absl's test runner when executed as a script.
  absltest.main()
| |
|