# Scraped from a Hugging Face file view of repo "SPT_GridNet-HD_baseline"
# (uploaded by user Shanci via huggingface_hub; commit 26225c5, "verified";
# page labels: "Other", "English").
# Path: src/dependencies/FRNN/tests/frnn_validation_2D_simple.py
import glob
import csv
import torch
import frnn
# Number of points in every random cloud (used for both pc1 and pc2).
num_points_fixed_query = 10_000
class TestFRNN:
    """Sanity check for ``frnn.frnn_grid_points`` on random 2D point clouds.

    Two batches of uniformly random points in [0, 1)^2 are generated. After
    the grid search, ``grid.sorted_points2`` is compared with ``pc2``
    re-ordered by ``grid.sorted_points2_idxs``. The two should match exactly.
    """

    def __init__(self, num_pcs=1, K=5, r=0.1, device="cuda"):
        """Create the random point clouds and their per-cloud lengths.

        Args:
            num_pcs: Number of point clouds in the batch.
            K: Forwarded as ``K`` to ``frnn.frnn_grid_points`` (presumably the
                maximum neighbours per query point - TODO confirm).
            r: Forwarded as the search radius to ``frnn.frnn_grid_points``.
            device: Device the tensors are moved to. ``"cuda"`` (the
                default) is the current CUDA device, exactly like ``.cuda()``.
                NOTE(review): frnn presumably requires CUDA tensors; confirm
                before passing a CPU device.
        """
        shape = (num_pcs, num_points_fixed_query, 2)
        # Sample on the CPU RNG and then move to the device, so that seeded
        # runs produce the same points as before.
        self.pc1 = torch.rand(shape, dtype=torch.float).to(device)
        self.pc2 = torch.rand(shape, dtype=torch.float).to(device)
        self.num_pcs = num_pcs
        self.r = r
        self.K = K
        self.num_points = num_points_fixed_query
        # Every cloud is fully populated, so each length is the point count.
        # The attribute names keep their "_cuda" suffix for compatibility.
        self.lengths1_cuda = torch.full(
            (num_pcs,), num_points_fixed_query, dtype=torch.long, device=device
        )
        self.lengths2_cuda = torch.full(
            (num_pcs,), num_points_fixed_query, dtype=torch.long, device=device
        )

    def frnn_2d(self):
        """Run the grid search and verify the grid's sorting of ``pc2``.

        Returns:
            bool: ``True`` when ``grid.sorted_points2`` equals ``pc2``
            gathered with ``grid.sorted_points2_idxs``. The result is also
            printed, as in the original script.
        """
        # Only the grid object is inspected. The distances, indices and
        # third return value are not needed for this check.
        _, _, _, grid = frnn.frnn_grid_points(
            self.pc1,
            self.pc2,
            self.lengths1_cuda,
            self.lengths2_cuda,
            self.K,
            self.r,
            radius_cell_ratio=1.0,
        )
        sorted_points2 = grid.sorted_points2
        # sorted_points2_idxs is indexed as rank-2 (batch, point). Repeat each
        # index across the coordinate axis so that gather along dim 1 copies
        # whole points.
        coord_dim = self.pc2.shape[-1]
        index = grid.sorted_points2_idxs[:, :, None].long().expand(-1, -1, coord_dim)
        gathered = torch.gather(self.pc2, 1, index)
        ok = torch.allclose(sorted_points2, gathered)
        print(ok)
        return ok
if __name__ == "__main__":
    # Default fixture: one cloud, K=5, r=0.1. frnn_2d prints whether the
    # grid's sorted points match pc2 gathered by the sort indices.
    tester = TestFRNN()
    tester.frnn_2d()