| | |
| | |
| | |
| | |
| | |
| |
|
| | import unittest |
| |
|
| | import torch |
| | from pytorch3d.ops import perspective_n_points |
| | from pytorch3d.transforms import rotation_conversions |
| |
|
| | from .common_testing import TestCaseMixin |
| |
|
| |
|
| | def reproj_error(x_world, y, R, T, weight=None): |
| | |
| | y_hat = torch.matmul(x_world, R) + T[:, None, :] |
| | y_hat = y_hat / y_hat[..., 2:] |
| | if weight is None: |
| | weight = y.new_ones((1, 1)) |
| | return (((weight[:, :, None] * (y - y_hat[..., :2])) ** 2).sum(dim=-1) ** 0.5).mean( |
| | dim=-1 |
| | ) |
| |
|
| |
|
| | class TestPerspectiveNPoints(TestCaseMixin, unittest.TestCase): |
| | def setUp(self) -> None: |
| | super().setUp() |
| | torch.manual_seed(42) |
| |
|
| | @classmethod |
| | def _generate_epnp_test_from_2d(cls, y): |
| | """ |
| | Instantiate random x_world, x_cam, R, T given a set of input |
| | 2D projections y. |
| | """ |
| | batch_size = y.shape[0] |
| | x_cam = torch.cat((y, torch.rand_like(y[:, :, :1]) * 2.0 + 3.5), dim=2) |
| | x_cam[:, :, :2] *= x_cam[:, :, 2:] |
| | R = rotation_conversions.random_rotations(batch_size).to(y) |
| | T = torch.randn_like(R[:, :1, :]) |
| | T[:, :, 2] = (T[:, :, 2] + 3.0).clamp(2.0) |
| | x_world = torch.matmul(x_cam - T, R.transpose(1, 2)) |
| | return x_cam, x_world, R, T |
| |
|
| | def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False): |
| | sol = perspective_n_points.efficient_pnp( |
| | x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q |
| | ) |
| |
|
| | err_2d = reproj_error(x_world, y, sol.R, sol.T) |
| | R_est_quat = rotation_conversions.matrix_to_quaternion(sol.R) |
| | R_quat = rotation_conversions.matrix_to_quaternion(R) |
| |
|
| | num_pts = x_world.shape[-2] |
| | if check_output: |
| | assert_msg = ( |
| | f"test_perspective_n_points assertion failure for " |
| | f"n_points={num_pts}, " |
| | f"skip_quadratic={skip_q}, " |
| | f"no noise." |
| | ) |
| |
|
| | self.assertClose(err_2d, sol.err_2d, msg=assert_msg) |
| | self.assertTrue((err_2d < 1e-3).all(), msg=assert_msg) |
| |
|
| | def norm_fn(t): |
| | return t.norm(dim=-1) |
| |
|
| | self.assertNormsClose( |
| | T, sol.T[:, None, :], rtol=4e-3, norm_fn=norm_fn, msg=assert_msg |
| | ) |
| | self.assertNormsClose( |
| | R_quat, R_est_quat, rtol=3e-3, norm_fn=norm_fn, msg=assert_msg |
| | ) |
| |
|
| | if print_stats: |
| | torch.set_printoptions(precision=5, sci_mode=False) |
| | for err_2d, err_3d, R_gt, T_gt in zip( |
| | sol.err_2d, |
| | sol.err_3d, |
| | torch.cat((sol.R, R), dim=-1), |
| | torch.stack((sol.T, T[:, 0, :]), dim=-1), |
| | ): |
| | print("2D Error: %1.4f" % err_2d.item()) |
| | print("3D Error: %1.4f" % err_3d.item()) |
| | print("R_hat | R_gt\n", R_gt) |
| | print("T_hat | T_gt\n", T_gt) |
| |
|
| | def _testcase_from_2d( |
| | self, y, print_stats, benchmark, skip_q=False, skip_check_thresh=5 |
| | ): |
| | """ |
| | In case num_pts < 6, EPnP gets unstable, so we check it doesn't crash |
| | """ |
| | x_cam, x_world, R, T = TestPerspectiveNPoints._generate_epnp_test_from_2d( |
| | y[None].repeat(16, 1, 1) |
| | ) |
| |
|
| | if print_stats: |
| | print("Run without noise") |
| |
|
| | if benchmark: |
| | torch.cuda.synchronize() |
| |
|
| | def result(): |
| | self._run_and_print(x_world, y, R, T, False, skip_q) |
| | torch.cuda.synchronize() |
| |
|
| | return result |
| |
|
| | self._run_and_print( |
| | x_world, |
| | y, |
| | R, |
| | T, |
| | print_stats, |
| | skip_q, |
| | check_output=True if y.shape[1] > skip_check_thresh else False, |
| | ) |
| |
|
| | |
| | if print_stats: |
| | print("Run with noise") |
| | x_world += torch.randn_like(x_world) * 0.1 |
| | self._run_and_print(x_world, y, R, T, print_stats, skip_q) |
| |
|
| | def case_with_gaussian_points( |
| | self, batch_size=10, num_pts=20, print_stats=False, benchmark=True, skip_q=False |
| | ): |
| | return self._testcase_from_2d( |
| | torch.randn((num_pts, 2)).cuda() / 3.0, |
| | print_stats=print_stats, |
| | benchmark=benchmark, |
| | skip_q=skip_q, |
| | ) |
| |
|
| | def test_perspective_n_points(self, print_stats=False): |
| | if print_stats: |
| | print("RUN ON A DENSE GRID") |
| | u = torch.linspace(-1.0, 1.0, 20) |
| | v = torch.linspace(-1.0, 1.0, 15) |
| | for skip_q in [False, True]: |
| | self._testcase_from_2d( |
| | torch.cartesian_prod(u, v).cuda(), print_stats, False, skip_q |
| | ) |
| |
|
| | for num_pts in range(6, 3, -1): |
| | for skip_q in [False, True]: |
| | if print_stats: |
| | print(f"RUN ON {num_pts} points; skip_quadratic: {skip_q}") |
| |
|
| | self.case_with_gaussian_points( |
| | num_pts=num_pts, |
| | print_stats=print_stats, |
| | benchmark=False, |
| | skip_q=skip_q, |
| | ) |
| |
|
| | def test_weighted_perspective_n_points(self, batch_size=16, num_pts=200): |
| | |
| | y = torch.randn((batch_size, num_pts, 2)).cuda() / 3.0 |
| | x_cam, x_world, R, T = TestPerspectiveNPoints._generate_epnp_test_from_2d(y) |
| |
|
| | |
| | weights = (torch.rand_like(x_world[:, :, 0]) > 0.5).float() |
| |
|
| | |
| | weights[:, :6] = 1.0 |
| |
|
| | |
| | |
| | y = y + (1 - weights[:, :, None]) * 100.0 |
| |
|
| | def norm_fn(t): |
| | return t.norm(dim=-1) |
| |
|
| | for skip_quadratic_eq in (True, False): |
| | |
| | sol = perspective_n_points.efficient_pnp( |
| | x_world, y, skip_quadratic_eq=skip_quadratic_eq, weights=weights |
| | ) |
| | sol_R_quat = rotation_conversions.matrix_to_quaternion(sol.R) |
| | sol_T = sol.T |
| |
|
| | |
| | |
| | for i in range(batch_size): |
| | ok = weights[i] > 0 |
| | x_world_ok = x_world[i, ok][None] |
| | y_ok = y[i, ok][None] |
| | sol_ok = perspective_n_points.efficient_pnp( |
| | x_world_ok, y_ok, skip_quadratic_eq=False |
| | ) |
| | R_est_quat_ok = rotation_conversions.matrix_to_quaternion(sol_ok.R) |
| |
|
| | self.assertNormsClose(sol_T[i], sol_ok.T[0], rtol=3e-3, norm_fn=norm_fn) |
| | self.assertNormsClose( |
| | sol_R_quat[i], R_est_quat_ok[0], rtol=3e-4, norm_fn=norm_fn |
| | ) |
| |
|