| import torch |
|
|
def knn(x, k, add_one_to_k=False):
    """
    Select the nearest neighbours of every point by squared Euclidean distance.

    Input:
        x: point features, [B, C, N] (dim 1 is reduced as the feature axis)
        k: number of neighbours to select
        add_one_to_k: when True, k + 1 neighbours are selected instead of k
    Return:
        idx: neighbour indices, [B, N, k] (or [B, N, k + 1])
    """
    num_neighbors = k + 1 if add_one_to_k else k
    x_t = x.transpose(2, 1).contiguous()
    cross = -2 * torch.matmul(x_t, x)
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)
    # Negated squared distance: the largest values are the closest points,
    # so topk returns nearest neighbours first.
    neg_sq_dist = -sq_norm - cross - sq_norm.transpose(2, 1).contiguous()
    _, idx = neg_sq_dist.topk(k=num_neighbors, dim=-1)
    return idx
|
|
def pc_normalize(pc):
    """
    Centre a point cloud on its centroid and scale it into the unit sphere.

    Input:
        pc: point cloud array, [N, C]
    Return:
        pc: centred points divided by the largest point norm, [N, C].
            If every point coincides with the centroid (largest norm is 0)
            the centred cloud is returned unscaled instead of dividing 0 by 0.
    """
    # Imported locally: the visible module-level imports only bring in torch.
    import numpy as np

    pc = pc - np.mean(pc, axis=0)
    scale = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    # `!= 0` (rather than `> 0`) keeps NaN propagation identical to a plain
    # division while skipping only the exact-zero case.
    if scale != 0:
        pc = pc / scale
    return pc
|
|
def square_distance(src, dst):
    """
    Squared Euclidean distance between every source/target point pair.

    Uses the expansion
        |s - d|^2 = |s|^2 + |d|^2 - 2 * <s, d>
    so the pairwise term is a single batched matrix product.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    cross = torch.matmul(src, dst.permute(0, 2, 1))
    src_sq = torch.sum(src ** 2, -1).view(B, N, 1)
    dst_sq = torch.sum(dst ** 2, -1).view(B, 1, M)
    # Summation order: cross term first, then source norms, then target norms.
    return -2 * cross + src_sq + dst_sq
|
|
def index_points(points, idx):
    """
    Gather points per batch element according to an index tensor.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (extra trailing index dims are kept)
    Return:
        new_points: indexed points data, [B, S, C]
    """
    idx_shape = list(idx.shape)
    trailing_dims = len(idx_shape) - 1
    # Batch ids shaped [B, 1, ..., 1] and then tiled to idx's shape, so every
    # entry of idx is paired with the batch element it belongs to.
    batch_ids = torch.arange(points.shape[0], dtype=torch.long).to(points.device)
    batch_ids = batch_ids.view(idx_shape[:1] + [1] * trailing_dims)
    batch_ids = batch_ids.repeat([1] + idx_shape[1:])
    return points[batch_ids, idx, :]
|
|
def farthest_point_sample(xyz, npoint, start_with_first_point=False):
    """
    Iteratively pick the point farthest from all points selected so far.

    Input:
        xyz: pointcloud data, [B, N, C] (any C, floating point)
        npoint: number of samples
        start_with_first_point: when True every batch element starts from
            point 0 (deterministic, draws no random numbers); otherwise the
            start index is drawn uniformly at random per batch element
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running minimum squared distance from each point to the selected set.
    # Starts at +inf and shares xyz's dtype so float64/float16 inputs work.
    distance = torch.full((B, N), float("inf"), dtype=xyz.dtype, device=device)
    if start_with_first_point:
        farthest = torch.zeros(B, dtype=torch.long, device=device)
    else:
        # Drawn on the default (CPU) generator, then moved to the target device.
        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the actual channel count instead of assuming 3-D points.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.where(dist < distance, dist, distance)
        farthest = torch.max(distance, -1)[1]
    return centroids
|
|
def knn_point(k, pos1, pos2):
    """
    Brute-force k-nearest-neighbour search of query points in a point set.

    Input:
        k: number of neighbours to return per query point
        pos1: (batch_size, ndataset, c) input points
        pos2: (batch_size, npoint, c) query points
    Output:
        val: (batch_size, npoint, k) L2 distances, nearest first
        idx: (batch_size, npoint, k) indices into pos1 (integer tensor as
             returned by topk)
    """
    B, N, _ = pos1.shape
    M = pos2.shape[1]
    # Broadcasting [B, 1, N, c] against [B, M, 1, c] yields the same
    # [B, M, N, c] differences without first materialising two tiled copies.
    diff = pos1.view(B, 1, N, -1) - pos2.view(B, M, 1, -1)
    # Negated so that topk (largest first) returns the smallest distances.
    neg_sq_dist = torch.sum(-(diff ** 2), -1)
    val, idx = neg_sq_dist.topk(k=k, dim=-1)
    return torch.sqrt(-val), idx
|
|
def query_ball_point(radius, nsample, xyz, new_xyz, get_cnt=False):
    """
    Collect up to nsample point indices inside a ball around each query point.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
        get_cnt: when True, also return the number of points inside each ball
    Return:
        group_idx: grouped points index, [B, S, nsample]; slots beyond the
            points found are filled with the first index of that group
        cnt: (only if get_cnt) in-radius point count per query, [B, S],
            counted before truncation to nsample
    """
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    point_ids = torch.arange(N, dtype=torch.long).to(xyz.device)
    group_idx = point_ids.view(1, 1, N).repeat([B, S, 1])

    # N is used as an out-of-ball sentinel: it is larger than any valid index,
    # so sorting pushes all out-of-ball entries to the end.
    outside = square_distance(new_xyz, xyz) > radius ** 2
    group_idx[outside] = N

    cnt = (group_idx != N).sum(dim=-1) if get_cnt else None

    sorted_idx, _ = group_idx.sort(dim=-1)
    group_idx = sorted_idx[:, :, :nsample]
    first_in_group = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    padding = group_idx == N
    group_idx[padding] = first_in_group[padding]

    if get_cnt:
        return group_idx, cnt
    return group_idx
|
|
def get_graph_feature(x, k=20, device=None):
    """
    Gather each point's k-nearest-neighbour features and pair them with the
    point's own features.

    Input:
        x: point features, viewed as [B, C, N] (only the first three dim
           sizes are kept)
        k: number of neighbours per point
        device: device for the flat index offsets; defaults to the device
            of x so indices and features always live on the same device
    Return:
        feature: [B, 2 * C, N, k]; along dim 1 the neighbour features come
            first, followed by the point's own features repeated k times
    """
    x = x.view(*x.size()[:3])
    idx = knn(x, k=k)
    batch_size, num_points, _ = idx.size()

    if device is None:
        # Follow the input tensor rather than guessing from CUDA availability:
        # a CPU tensor on a CUDA-capable machine must stay on the CPU.
        device = x.device

    # Offset per-batch neighbour indices so they address the flattened
    # [B * N, C] feature matrix.
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
    return feature