| | |
| | |
| | |
| | |
| | |
| |
|
| | import unittest |
| | from itertools import product |
| |
|
| | import torch |
| | from pytorch3d.ops import sample_points_from_meshes |
| | from pytorch3d.ops.ball_query import ball_query |
| | from pytorch3d.ops.knn import _KNN |
| | from pytorch3d.utils import ico_sphere |
| |
|
| | from .common_testing import get_random_cuda_device, TestCaseMixin |
| |
|
| |
|
| | class TestBallQuery(TestCaseMixin, unittest.TestCase): |
| | def setUp(self) -> None: |
| | super().setUp() |
| | torch.manual_seed(1) |
| |
|
| | @staticmethod |
| | def _ball_query_naive( |
| | p1, p2, lengths1, lengths2, K: int, radius: float |
| | ) -> torch.Tensor: |
| | """ |
| | Naive PyTorch implementation of ball query. |
| | """ |
| | N, P1, D = p1.shape |
| | _N, P2, _D = p2.shape |
| |
|
| | assert N == _N and D == _D |
| |
|
| | if lengths1 is None: |
| | lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device) |
| | if lengths2 is None: |
| | lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device) |
| |
|
| | radius2 = radius * radius |
| | dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device) |
| | idx = torch.full((N, P1, K), fill_value=-1, dtype=torch.int64, device=p1.device) |
| |
|
| | |
| | for n in range(N): |
| | num1 = lengths1[n].item() |
| | num2 = lengths2[n].item() |
| |
|
| | |
| | for i in range(num1): |
| | |
| | count = 0 |
| | for j in range(num2): |
| | dist = p2[n, j] - p1[n, i] |
| | dist2 = (dist * dist).sum() |
| | if dist2 < radius2 and count < K: |
| | dists[n, i, count] = dist2 |
| | idx[n, i, count] = j |
| | count += 1 |
| |
|
| | return _KNN(dists=dists, idx=idx, knn=None) |
| |
|
| | def _ball_query_vs_python_square_helper(self, device): |
| | Ns = [1, 4] |
| | Ds = [3, 5, 8] |
| | P1s = [8, 24] |
| | P2s = [8, 16, 32] |
| | Ks = [1, 5] |
| | Rs = [3, 5] |
| | factors = [Ns, Ds, P1s, P2s, Ks, Rs] |
| | for N, D, P1, P2, K, R in product(*factors): |
| | x = torch.randn(N, P1, D, device=device, requires_grad=True) |
| | x_cuda = x.clone().detach() |
| | x_cuda.requires_grad_(True) |
| | y = torch.randn(N, P2, D, device=device, requires_grad=True) |
| | y_cuda = y.clone().detach() |
| | y_cuda.requires_grad_(True) |
| |
|
| | |
| | out1 = self._ball_query_naive( |
| | x, y, lengths1=None, lengths2=None, K=K, radius=R |
| | ) |
| | out2 = ball_query(x_cuda, y_cuda, K=K, radius=R) |
| |
|
| | |
| | self.assertClose(out1.dists, out2.dists) |
| | |
| | self.assertTrue(torch.all(out1.idx == out2.idx)) |
| |
|
| | |
| | grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device) |
| | loss1 = (out1.dists * grad_dist).sum() |
| | loss1.backward() |
| | loss2 = (out2.dists * grad_dist).sum() |
| | loss2.backward() |
| |
|
| | self.assertClose(x_cuda.grad, x.grad, atol=5e-6) |
| | self.assertClose(y_cuda.grad, y.grad, atol=5e-6) |
| |
|
| | def test_ball_query_vs_python_square_cpu(self): |
| | device = torch.device("cpu") |
| | self._ball_query_vs_python_square_helper(device) |
| |
|
| | def test_ball_query_vs_python_square_cuda(self): |
| | device = get_random_cuda_device() |
| | self._ball_query_vs_python_square_helper(device) |
| |
|
| | def _ball_query_vs_python_ragged_helper(self, device): |
| | Ns = [1, 4] |
| | Ds = [3, 5, 8] |
| | P1s = [8, 24] |
| | P2s = [8, 16, 32] |
| | Ks = [2, 3, 10] |
| | Rs = [1.4, 5] |
| | factors = [Ns, Ds, P1s, P2s, Ks, Rs] |
| | for N, D, P1, P2, K, R in product(*factors): |
| | x = torch.rand((N, P1, D), device=device, requires_grad=True) |
| | y = torch.rand((N, P2, D), device=device, requires_grad=True) |
| | lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device) |
| | lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device) |
| |
|
| | x_csrc = x.clone().detach() |
| | x_csrc.requires_grad_(True) |
| | y_csrc = y.clone().detach() |
| | y_csrc.requires_grad_(True) |
| |
|
| | |
| | out1 = self._ball_query_naive( |
| | x, y, lengths1=lengths1, lengths2=lengths2, K=K, radius=R |
| | ) |
| | out2 = ball_query( |
| | x_csrc, |
| | y_csrc, |
| | lengths1=lengths1, |
| | lengths2=lengths2, |
| | K=K, |
| | radius=R, |
| | ) |
| |
|
| | self.assertClose(out1.idx, out2.idx) |
| | self.assertClose(out1.dists, out2.dists) |
| |
|
| | |
| | grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device) |
| | loss1 = (out1.dists * grad_dist).sum() |
| | loss1.backward() |
| | loss2 = (out2.dists * grad_dist).sum() |
| | loss2.backward() |
| |
|
| | self.assertClose(x_csrc.grad, x.grad, atol=5e-6) |
| | self.assertClose(y_csrc.grad, y.grad, atol=5e-6) |
| |
|
| | def test_ball_query_vs_python_ragged_cpu(self): |
| | device = torch.device("cpu") |
| | self._ball_query_vs_python_ragged_helper(device) |
| |
|
| | def test_ball_query_vs_python_ragged_cuda(self): |
| | device = get_random_cuda_device() |
| | self._ball_query_vs_python_ragged_helper(device) |
| |
|
| | def test_ball_query_output_simple(self): |
| | device = get_random_cuda_device() |
| | N, P1, P2, K = 5, 8, 16, 4 |
| | sphere = ico_sphere(level=2, device=device).extend(N) |
| | points_1 = sample_points_from_meshes(sphere, P1) |
| | points_2 = sample_points_from_meshes(sphere, P2) * 5.0 |
| | radius = 6.0 |
| |
|
| | naive_out = self._ball_query_naive( |
| | points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius |
| | ) |
| | cuda_out = ball_query(points_1, points_2, K=K, radius=radius) |
| |
|
| | |
| | |
| | naive_out_zeros = (naive_out.idx == 0).sum(dim=-1).max() |
| | cuda_out_zeros = (cuda_out.idx == 0).sum(dim=-1).max() |
| | self.assertTrue(naive_out_zeros == 0 or naive_out_zeros == 1) |
| | self.assertTrue(cuda_out_zeros == 0 or cuda_out_zeros == 1) |
| |
|
| | |
| | radius = 0.5 |
| | naive_out = self._ball_query_naive( |
| | points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius |
| | ) |
| | cuda_out = ball_query(points_1, points_2, K=K, radius=radius) |
| | naive_out_allzeros = (naive_out.idx == -1).all() |
| | cuda_out_allzeros = (cuda_out.idx == -1).sum() |
| | self.assertTrue(naive_out_allzeros) |
| | self.assertTrue(cuda_out_allzeros) |
| |
|
| | @staticmethod |
| | def ball_query_square( |
| | N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str |
| | ): |
| | device = torch.device(device) |
| | pts1 = torch.randn(N, P1, D, device=device, requires_grad=True) |
| | pts2 = torch.randn(N, P2, D, device=device, requires_grad=True) |
| | grad_dists = torch.randn(N, P1, K, device=device) |
| | torch.cuda.synchronize() |
| |
|
| | def output(): |
| | out = ball_query(pts1, pts2, K=K, radius=radius) |
| | loss = (out.dists * grad_dists).sum() |
| | loss.backward() |
| | torch.cuda.synchronize() |
| |
|
| | return output |
| |
|
| | @staticmethod |
| | def ball_query_ragged( |
| | N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str |
| | ): |
| | device = torch.device(device) |
| | pts1 = torch.rand((N, P1, D), device=device, requires_grad=True) |
| | pts2 = torch.rand((N, P2, D), device=device, requires_grad=True) |
| | lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device) |
| | lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device) |
| | grad_dists = torch.randn(N, P1, K, device=device) |
| | torch.cuda.synchronize() |
| |
|
| | def output(): |
| | out = ball_query( |
| | pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K, radius=radius |
| | ) |
| | loss = (out.dists * grad_dists).sum() |
| | loss.backward() |
| | torch.cuda.synchronize() |
| |
|
| | return output |
| |
|