File size: 13,472 Bytes
1327f34 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 | # Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for functions in box_utils.py."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
from scenic.model_lib.base_models import box_utils
from shapely import geometry
def sample_cxcywh_bbox(key, batch_shape):
  """Draws random axis-aligned boxes in [cx, cy, w, h] format.

  All four coordinates start as uniform draws scaled into [0, 0.8). Any width
  or height that would push a box across an image border is then shrunk, so
  every box ends up inside the [0, 1] x [0, 1] square.

  Args:
    key: JAX PRNG key used for the uniform draw.
    batch_shape: Leading dimensions of the result.

  Returns:
    Array of shape (*batch_shape, 4) holding [cx, cy, w, h] boxes.
  """
  max_coord = 0.8
  shrink = max_coord * 2.
  draws = jax.random.uniform(key, shape=(*batch_shape, 4)) * max_coord
  center_x, center_y, width, height = jnp.split(draws, 4, axis=-1)

  def fit_extent(center, extent):
    # Clip against the far border (right / top) first ...
    extent = jnp.where(center + extent / 2. >= 1., shrink * (1. - center),
                       extent)
    # ... and then against the near border (left / bottom).
    return jnp.where(center - extent / 2. <= 0., shrink * center, extent)

  return jnp.concatenate(
      [center_x, center_y,
       fit_extent(center_x, width), fit_extent(center_y, height)],
      axis=-1)
class BoxUtilsTest(parameterized.TestCase):
  """Tests all the bounding box related utilities."""

  @staticmethod
  def _uniform(shape, seed=0):
    """Returns U[0, 1) samples of `shape` as a jnp array.

    A seeded NumPy generator is used instead of the global `np.random` state
    so that a failure in the randomized tests can be reproduced exactly.
    """
    return jnp.array(np.random.default_rng(seed).uniform(size=shape))

  def test_box_cxcywh_to_xyxy(self):
    """Test for correctness of the box_cxcywh_to_xyxy operation."""
    cxcywh = jnp.array([[[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.2, 0.4]],
                        [[0.3, 0.2, 0.1, 0.4], [0.3, 0.2, 0.1, 0.4]]],
                       dtype=jnp.float32)
    expected = jnp.array([[[0.0, 0.1, 0.2, 0.5], [0.0, 0.1, 0.2, 0.5]],
                          [[0.25, 0.0, 0.35, 0.4], [0.25, 0.0, 0.35, 0.4]]],
                         dtype=jnp.float32)
    output = box_utils.box_cxcywh_to_xyxy(cxcywh)
    self.assertSequenceAlmostEqual(
        expected.flatten(), output.flatten(), places=5)

    # An input whose last axis has 5 entries is not a box array and must be
    # rejected. Build it outside the assertRaises context so that only the
    # call under test can satisfy the expected ValueError.
    not_boxes = self._uniform((2, 3, 5))
    with self.assertRaises(ValueError):
      _ = box_utils.box_cxcywh_to_xyxy(not_boxes)

  @parameterized.parameters([((3, 1, 4),), ((4, 6, 4),)])
  def test_box_cxcywh_to_xyxy_shape(self, input_shape):
    """Test whether the shape is correct for box_cxcywh_to_xyxy."""
    cxcywh = self._uniform(input_shape)
    xyxy = box_utils.box_cxcywh_to_xyxy(cxcywh)
    self.assertEqual(xyxy.shape, cxcywh.shape)

  @parameterized.parameters([((2, 5, 4),), ((1, 3, 4),)])
  def test_box_cxcy_to_xyxy_box_xyxy_to_cxcy(self, input_shape):
    """Test both box conversion functions as they are inverses of each other."""
    cxcywh = self._uniform(input_shape)
    xyxy = box_utils.box_cxcywh_to_xyxy(cxcywh)
    cxcywh_loop = box_utils.box_xyxy_to_cxcywh(xyxy)
    self.assertSequenceAlmostEqual(
        cxcywh_loop.flatten(), cxcywh.flatten(), places=5)
def sample_cxcywha(key, batch_shape):
  """Samples rotated bounding boxes [cx, cy, w, h, a (radians)].

  Widths and heights are drawn from [0, 0.5) and angles from [0, 1) radians.
  Whatever the rotation, each corner lies at most
  (w * |cos a| + h * |sin a|) / 2 <= sqrt(w**2 + h**2) / 2 < sqrt(0.5) / 2
  ~= 0.3536 from the centre along either axis. Centres are therefore drawn
  from [0.36, 0.64), which keeps every corner strictly inside the unit square.
  (A centre range of [0.35, 0.65) could push corners slightly outside it.)

  Args:
    key: JAX PRNG key used for the uniform draw.
    batch_shape: Leading dimensions of the result.

  Returns:
    Array of shape (*batch_shape, 5) holding [cx, cy, w, h, a] boxes.
  """
  scale = jnp.array([0.28, 0.28, 0.5, 0.5, 1.0])
  offset = jnp.array([0.36, 0.36, 0, 0, 0])
  return jax.random.uniform(key, shape=(*batch_shape, 5)) * scale + offset
class RBoxUtilsTest(parameterized.TestCase):
  """Tests all the rotated bounding box related utilities."""

  @staticmethod
  def _sorted_valid_points(points):
    """Returns the non-NaN rows of a (..., 2) point array as sorted tuples.

    Rows with any NaN coordinate are discarded. The remaining (x, y) pairs are
    sorted so that they can be compared regardless of order.
    """
    valid = points[~jnp.any(jnp.isnan(points), -1)]
    return sorted((x, y) for x, y in np.array(valid))

  def test_convert_cxcywha_to_corners(self):
    """Checks corner shapes and that sampled boxes stay in the unit square."""
    key = jax.random.PRNGKey(0)
    cxcywha = sample_cxcywha(key, batch_shape=(300, 200))
    self.assertEqual(cxcywha.shape, (300, 200, 5))
    corners = box_utils.cxcywha_to_corners(cxcywha)
    self.assertEqual(corners.shape, (300, 200, 4, 2))
    # This criterion relies on sample_cxcywha keeping every box inside the
    # unit square.
    self.assertTrue(jnp.all(corners >= 0))
    self.assertTrue(jnp.all(corners <= 1))

  def test_convert_corners_to_cxcywha(self):
    """Checks that corners_to_cxcywha inverts cxcywha_to_corners."""
    key = jax.random.PRNGKey(0)
    cxcywha = sample_cxcywha(key, batch_shape=(3, 2))
    self.assertEqual(cxcywha.shape, (3, 2, 5))
    corners = box_utils.cxcywha_to_corners(cxcywha)
    cxcywha2 = box_utils.corners_to_cxcywha(corners)
    np.testing.assert_allclose(cxcywha2, cxcywha, atol=1e-6)

  def test_convert_cxcywha_to_corners_single_rotated(self):
    """Checks the corners of a sqrt(2)-sided square rotated by 45 degrees."""
    cxcywha = jnp.array([1, 1, jnp.sqrt(2), jnp.sqrt(2), 45. * jnp.pi / 180.])
    corners = box_utils.cxcywha_to_corners(cxcywha)
    expected_corners = [[1, 0], [2, 1], [1, 2], [0, 1]]
    np.testing.assert_allclose(corners, expected_corners, atol=1e-7)

  def test_intersect_line_segments(self):
    """Test intersect_line_segments against shapely on random segment pairs."""
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    lines1 = jax.random.uniform(subkey, (100, 2, 2))
    lines2 = jax.random.uniform(key, (100, 2, 2))
    intersect_line_segments = jax.jit(
        jax.vmap(box_utils.intersect_line_segments))
    intersections = intersect_line_segments(lines1, lines2)
    self.assertEqual(intersections.shape, (100, 2))

    # Reference values: a single Point from shapely is the expected crossing.
    # Any other result (e.g. an empty geometry when the segments do not cross)
    # is expected to come back as NaN.
    expected_intersections = []
    for segment1, segment2 in zip(np.asarray(lines1), np.asarray(lines2)):
      it = geometry.LineString(segment1).intersection(
          geometry.LineString(segment2))
      if isinstance(it, geometry.Point):
        expected_intersections.append(it.coords[0])
      else:
        expected_intersections.append((np.nan, np.nan))
    np.testing.assert_allclose(intersections, expected_intersections, atol=1e-7)

  def test_intersect_rbox_edges_same_box(self):
    """Test intersect_rbox_edges for an axis-aligned box against itself."""
    rbox1 = jnp.array([0.5, 0.5, 1.0, 1.0, 0])
    rbox2 = rbox1
    corners1 = box_utils.cxcywha_to_corners(rbox1)
    corners2 = box_utils.cxcywha_to_corners(rbox2)
    it_points = box_utils.intersect_rbox_edges(corners1, corners2)
    self.assertEqual(it_points.shape, (4, 4, 2))
    # Each corner of the unit box is expected exactly twice.
    expected_points = sorted([(0, 0), (0, 1), (1, 0), (1, 1)] * 2)
    self.assertSequenceEqual(
        self._sorted_valid_points(it_points), expected_points)

  def test_intersect_rbox_edges_rotated_box(self):
    """Test a box inscribed in another box rotated by 45 degrees."""
    rbox1 = jnp.array([1.0, 1.0, 1.0, 1.0, 0])
    rbox2 = jnp.array([1.0, 1.0, jnp.sqrt(2), jnp.sqrt(2), 45. * np.pi / 180.])
    corners1 = box_utils.cxcywha_to_corners(rbox1)
    corners2 = box_utils.cxcywha_to_corners(rbox2)
    it_points = box_utils.intersect_rbox_edges(corners1, corners2)
    # Rounding leaves NaN rows as NaN, so it can be applied before filtering.
    it_points = self._sorted_valid_points(jnp.round(it_points, decimals=4))
    # Expect intersection at unrotated box vertices.
    expected_pts = sorted([(1.5, 1.5), (1.5, 0.5), (0.5, 0.5), (0.5, 1.5)] * 2)
    self.assertSequenceEqual(it_points, expected_pts)
class IoUTest(parameterized.TestCase):
  """Test box_iou and generalized_box_iou functions."""

  @staticmethod
  def _boxes_3x3():
    """Returns the [3, 3, 4] xyxy `in1` boxes of the hand-designed tests."""
    return jnp.array([
        [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0], [0.1, 0.2, 0.5, 0.8]],
        [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8], [0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ], dtype=jnp.float32)

  @staticmethod
  def _boxes_2x2_pair():
    """Returns the (in1, in2) pair of [2, 2, 4] xyxy boxes."""
    first = jnp.array([
        [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0]],
        [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8]],
    ], dtype=jnp.float32)
    second = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.1, 0.5, 0.7, 0.7]],
    ], dtype=jnp.float32)
    return first, second

  def test_box_iou_values(self):
    """Tests if 0 <= IoU <= 1 and -1 <= gIoU <=1."""
    pred_key, tgt_key = jax.random.split(jax.random.PRNGKey(0))
    # Random in-image boxes: a batch of 4, with 100 predictions and 63
    # targets per batch element.
    pred_bbox = box_utils.box_cxcywh_to_xyxy(
        sample_cxcywh_bbox(pred_key, batch_shape=(4, 100)))
    tgt_bbox = box_utils.box_cxcywh_to_xyxy(
        sample_cxcywh_bbox(tgt_key, batch_shape=(4, 63)))

    iou, union = box_utils.box_iou(pred_bbox, tgt_bbox, all_pairs=True)
    self.assertTrue(jnp.all(iou >= 0))
    self.assertTrue(jnp.all(iou <= 1.))
    self.assertTrue(jnp.all(union >= 0.))

    giou = box_utils.generalized_box_iou(pred_bbox, tgt_bbox, all_pairs=True)
    self.assertTrue(jnp.all(giou >= -1.))
    self.assertTrue(jnp.all(giou <= 1.))

  def test_box_iou(self):
    """Test box_iou using hand designed targets."""
    in1 = self._boxes_3x3()
    in2 = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.7, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.8, 0.6, 0.7, 0.4], [0.1, 0.1, 0.2, 0.2]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ], dtype=jnp.float32)
    target = jnp.array(
        [[0.0, 0.125, 0.125], [0.0625, 0.0, 0.0], [0.0, 0.0, 1.0]],
        dtype=jnp.float32)
    iou, _ = box_utils.box_iou(in1, in2, all_pairs=False)
    self.assertSequenceAlmostEqual(iou.flatten(), target.flatten(), places=3)

  @classmethod
  def _get_method_fn(cls, method):
    """Returns the callable computing the 'iou' or 'giou' matrix."""
    if method == 'iou':
      return lambda x, y, **kwargs: box_utils.box_iou(x, y, **kwargs)[0]
    if method == 'giou':
      return box_utils.generalized_box_iou
    raise ValueError(f'Unknown method {method}')

  @parameterized.parameters('iou', 'giou')
  def test_all_pairs_true_false(self, method):
    """Use *box_iou(..., all_pairs=False) to test the all_pairs=True case."""
    method_fn = self._get_method_fn(method)
    in1, in2 = self._boxes_2x2_pair()
    # Reversing the two boxes of every batch element lets all_pairs=False
    # produce the cross (off-diagonal) comparisons.
    in1_swapped = in1[:, ::-1]

    out = method_fn(in1, in2, all_pairs=False)  # [2, 2]
    out_swapped = method_fn(in1_swapped, in2, all_pairs=False)  # [2, 2]
    out_all = method_fn(in1, in2, all_pairs=True)  # [2, 2, 2]

    # With boxes numbered 0-3 in both inputs:
    #   out         = [[0-0, 1-1], [2-2, 3-3]]
    #   out_swapped = [[1-0, 0-1], [3-2, 2-3]]
    #   out_all     = [[[0-0, 0-1], [1-0, 1-1]], [[2-2, 2-3], [3-2, 3-3]]]
    # i.e. out_all[b, i, j] is out[b, i] on the diagonal and
    # out_swapped[b, j] off it.
    expected_all = jnp.array(
        [[[out[b, i] if i == j else out_swapped[b, j] for j in range(2)]
          for i in range(2)]
         for b in range(2)],
        dtype=jnp.float32)
    self.assertSequenceAlmostEqual(out_all.flatten(), expected_all.flatten())

  def test_generalized_box_iou(self):
    """Same as test_box_iou but for generalized_box_iou()."""
    in1 = self._boxes_3x3()
    in2 = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.7, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.4, 0.4, 0.8, 0.6], [0.1, 0.1, 0.2, 0.2]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ], dtype=jnp.float32)
    target_iou = jnp.array(
        [[0.0, 0.125, 0.125], [0.0625, 1. / 7., 0.0], [0.0, 0.0, 1.0]],
        dtype=jnp.float32)
    target_extra = jnp.array(
        [[-2. / 3., 0.0, -1. / 9.], [0.0, -2. / 9., -3. / 4.], [0.0, 0.0, 0.0]],
        dtype=jnp.float32)
    target = target_iou + target_extra
    giou = box_utils.generalized_box_iou(in1, in2, all_pairs=False)
    self.assertSequenceAlmostEqual(giou.flatten(), target.flatten(), places=3)
    # if the boxes are invalid it should raise an AssertionError
    # TODO(b/166344282): uncomment these after enabling the assertions
    # in1 = jnp.array([[[0.1, 0.2, 0.3, 0.4],],], dtype=jnp.float32)
    # in2 = jnp.array([[[0.3, 0.4, 0.1, 0.2],],], dtype=jnp.float32)
    # with self.assertRaises(AssertionError):
    #   _ = box_utils.generalized_box_iou(in1, in2, all_pairs=False)

  @parameterized.parameters('iou', 'giou')
  def test_backward(self, method):
    """Test whether *box_iou methods have a grad."""
    method_fn = self._get_method_fn(method)

    def loss_fn(x, y, all_pairs):
      return method_fn(x, y, all_pairs=all_pairs).sum()

    grad_fn = jax.grad(loss_fn)
    in1, in2 = self._boxes_2x2_pair()
    for all_pairs in (True, False):
      grad_in1 = grad_fn(in1, in2, all_pairs=all_pairs)
      self.assertSequenceEqual(grad_in1.shape, in1.shape)
if __name__ == '__main__':
  # Script entry point: delegate to absltest.main() when the file is executed
  # directly rather than imported.
  absltest.main()
|