Spaces:
Sleeping
Sleeping
| import torch | |
def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between every pair of points.

    The pairwise distance is computed through the expansion
        dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
             = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst
    which needs a single batched matrix product instead of materialising
    every pairwise difference.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]; every entry is >= 0
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    # Cross term -2 * <src_n, dst_m> for all (n, m) pairs: [B, N, M].
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    # Add the squared norms, broadcast along the opposite point axis.
    dist += torch.sum(src**2, -1).view(B, N, 1)
    dist += torch.sum(dst**2, -1).view(B, 1, M)
    # The expansion subtracts nearly equal numbers when two points
    # (almost) coincide, so rounding error can leave a tiny negative value.
    # A squared distance is never negative; clamping keeps a later sqrt()
    # from producing NaN and keeps coincident points at exactly zero.
    return torch.clamp(dist, min=0)
def knn_point(nsample, xyz, new_xyz):
    """
    Look up, for each query point, the indices of its nearest points in xyz.

    Input:
        nsample: max sample number in local region (the k passed to top-k)
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]; the nsample
            entries of each row come back in no guaranteed order
    """
    # Squared distance from each query point to each candidate point.
    pairwise_sq = square_distance(new_xyz, xyz)
    # largest=False picks the smallest distances along the last axis;
    # sorted=False skips ordering the selected neighbours.
    nearest = torch.topk(
        pairwise_sq, k=nsample, dim=-1, largest=False, sorted=False
    )
    return nearest.indices