"""Point-cloud helpers: neighbour search, sampling, grouping and normalisation."""

import numpy as np
import torch


def knn(x, k, add_one_to_k=False):
    """
    Find the k nearest neighbours of every point (squared Euclidean distance).

    Input:
        x: point features, [B, C, N]
        k: number of neighbours to return per point
        add_one_to_k: if True, search for k + 1 neighbours instead of k
    Return:
        idx: neighbour indices, [B, N, k] (or [B, N, k + 1])
    """
    if add_one_to_k:
        k = k + 1

    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    # Negative squared distance, so that topk (largest first) yields the
    # closest points.
    pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous()

    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx


def pc_normalize(pc):
    """
    Centre a point cloud on its centroid and scale it by its largest norm.

    Input:
        pc: numpy array of points, [N, C]
    Return:
        pc: centred points divided by the largest point norm, [N, C].
            When every point coincides with the centroid (largest norm is 0)
            the centred cloud is returned unscaled instead of dividing by zero.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    # Degenerate cloud: scaling would only produce NaNs.
    if m != 0:
        pc = pc / m
    return pc


def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between each pair of points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """
    Gather points by per-batch indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (extra trailing index dims are allowed,
             e.g. [B, S, K])
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    B = points.shape[0]
    # One batch index per leading entry of idx, shaped so that it broadcasts
    # against idx during advanced indexing (no explicit repeat needed).
    view_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(
        B, dtype=torch.long, device=points.device
    ).view(view_shape)
    return points[batch_indices, idx, :]


def farthest_point_sample(xyz, npoint, start_with_first_point=False):
    """
    Iteratively pick the point farthest from the points already selected.

    Input:
        xyz: pointcloud data, [B, N, C]
        npoint: number of samples
        start_with_first_point: if True, start from index 0 in every batch
            (no random numbers are drawn); otherwise start from a random index
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running distance of every point to its closest already-selected point.
    distance = torch.full((B, N), 1e10, device=device)
    if start_with_first_point:
        farthest = torch.zeros(B, dtype=torch.long, device=device)
    else:
        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)

    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the actual channel count rather than assuming 3-D points.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids


def knn_point(k, pos1, pos2):
    """
    Brute-force k-nearest-neighbour search of query points in a point set.

    Input:
        k: number of neighbours in the k-nn search
        pos1: input points, [B, N, C]
        pos2: query points, [B, M, C]
    Output:
        val: L2 distances to the k nearest input points, [B, M, k]
        idx: indices into pos1 of those points (int64), [B, M, k]
    """
    # Broadcasting yields the [B, M, N, C] differences without first
    # materialising repeated copies of both inputs.
    diff = pos1.unsqueeze(1) - pos2.unsqueeze(2)
    # Negated so that topk (largest first) returns the smallest distances.
    dist = torch.sum(-(diff ** 2), -1)
    val, idx = dist.topk(k=k, dim=-1)
    return torch.sqrt(-val), idx


def query_ball_point(radius, nsample, xyz, new_xyz, get_cnt=False):
    """
    Group, for every query point, the indices of points within a radius.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
        get_cnt: if True, also return the number of points inside each ball
    Return:
        group_idx: grouped points index, [B, S, nsample] ([B, S, N] when
            nsample > N). Slots beyond the number of points found are filled
            with the first index found. If a ball contains no point at all,
            its entries stay equal to N, which is not a valid index.
        cnt: (only when get_cnt is True) points inside each ball, [B, S]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape

    group_idx = (
        torch.arange(N, dtype=torch.long, device=device)
        .view(1, 1, N)
        .repeat([B, S, 1])
    )
    sqrdists = square_distance(new_xyz, xyz)
    # N acts as an "outside the ball" sentinel that sorts after every real index.
    group_idx[sqrdists > radius ** 2] = N

    if get_cnt:
        cnt = (group_idx != N).sum(dim=-1)

    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # Pad unused slots with the first (smallest) index found in each ball.
    group_first = group_idx[:, :, :1].expand_as(group_idx)
    group_idx = torch.where(group_idx == N, group_first, group_idx)

    if get_cnt:
        return group_idx, cnt
    return group_idx


def get_graph_feature(x, k=20, device=None):
    """
    Build k-nn graph features by pairing each point with its neighbours.

    Input:
        x: point features, [B, C, N] (trailing dimensions of size 1 are dropped)
        k: number of neighbours per point
        device: device used for the batch offsets; defaults to the device of x
    Return:
        feature: neighbour features concatenated with the repeated centre
            features along the channel axis, [B, 2 * C, N, k]
    """
    x = x.view(*x.size()[:3])
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()

    if device is None:
        # Follow the input tensor; guessing 'cuda' breaks for CPU inputs on
        # machines where CUDA happens to be available.
        device = x.device

    # Offset per-batch indices so they address the flattened
    # (batch_size * num_points) point list.
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()
    # (batch_size, num_dims, num_points) -> (batch_size, num_points, num_dims)
    x = x.transpose(2, 1).contiguous()

    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k, -1)

    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
    return feature