import os
import sys
from typing import Optional, Tuple

import numpy as np
import scipy.spatial
import torch
import torch.nn.functional as F
from scipy.special import logsumexp
from scipy.spatial import cKDTree
from torch import Tensor

try:
    import segmentator
except ImportError:
    segmentator = None

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, parent_dir)


# -----------------------------------------------------------------------------
# label utilities
# -----------------------------------------------------------------------------
def num_to_natural_numpy(group_ids, void_number=-1):
    """
    Convert group IDs to contiguous natural numbers, preserving the void label.

    Non-void IDs are ranked in sorted order. With ``void_number == -1`` they map
    to ``0..n-1``; with ``void_number == 0`` they map to ``1..n``.

    Args:
        group_ids (array-like): Input group labels, e.g., [-1, 0, 3, 4, 0, 6].
        void_number (int): The label for 'void' class (only supports -1 or 0).

    Returns:
        array_ids (np.ndarray): Mapped IDs with contiguous natural numbers and
            voids untouched.

    Raises:
        ValueError: If ``void_number`` is neither -1 nor 0.
    """
    if void_number not in (-1, 0):
        raise ValueError("void_number must be either -1 or 0.")

    group_ids = np.asarray(group_ids, dtype=int)
    void_mask = group_ids == void_number
    if void_mask.all():
        return group_ids.copy()  # all voids (or empty input)

    # Ranking through np.unique avoids a lookup table sized by the largest ID
    # and stays correct for negative or very large non-void IDs.
    first_id = 1 if void_number == 0 else 0
    _, inverse = np.unique(group_ids[~void_mask], return_inverse=True)
    result = np.full(group_ids.shape, void_number, dtype=int)
    result[~void_mask] = inverse.reshape(-1) + first_id
    return result


def num_to_natural_torch(group_ids, void_number=-1):
    """
    Torch counterpart of ``num_to_natural_numpy``.

    Args:
        group_ids (Tensor): Integer-valued labels.
        void_number (int): Void label, either -1 or 0.

    Returns:
        Tensor: int64 labels where non-void IDs are replaced by their sorted
            rank (starting at 0 for void -1, at 1 for void 0). If every entry
            is void the ``.long()`` view of the input is returned as is.

    Raises:
        ValueError: If ``void_number`` is neither -1 nor 0.
    """
    if void_number not in (-1, 0):
        raise ValueError("void_number must be -1 or 0")

    ids = group_ids.long()
    void_mask = ids == void_number
    if bool(void_mask.all()):
        return ids

    first_id = 1 if void_number == 0 else 0
    _, inverse = torch.unique(ids[~void_mask], return_inverse=True)
    array_ids = torch.full_like(ids, void_number)
    array_ids[~void_mask] = inverse + first_id
    return array_ids


# -----------------------------------------------------------------------------
# optional legacy segmentator
# -----------------------------------------------------------------------------
def gen_superpoints(points, normals, k=50, kThresh=0.01, segMinVerts=20):
    """
    Segment a point cloud with the optional ``segmentator`` extension.

    A k-NN graph is built with ``torch_cluster.knn_graph`` and passed, together
    with the points and normals, to ``segmentator.segment_point``.

    Raises:
        ImportError: If ``segmentator`` or ``torch_cluster`` is not installed.
    """
    # Import lazily and turn a missing package into the documented error; a
    # bare import here would raise before the availability check could run.
    try:
        from torch_cluster import knn_graph
    except ImportError:
        knn_graph = None

    if segmentator is None or knn_graph is None:
        raise ImportError("segmentator and torch_cluster are required for gen_superpoints().")

    edges = knn_graph(points, k=k).T
    superpoint = segmentator.segment_point(points, normals, edges, kThresh, segMinVerts)
    return superpoint


def _normalize_xyz_np(xyz: np.ndarray) -> np.ndarray:
    """Center ``xyz`` on its mean and divide by the bounding-box diagonal."""
    xyz = np.asarray(xyz, dtype=np.float32)
    center = xyz.mean(axis=0, keepdims=True)
    xyz0 = xyz - center
    bbmin = xyz0.min(axis=0)
    bbmax = xyz0.max(axis=0)
    diag = np.linalg.norm(bbmax - bbmin)
    if diag < 1e-12:
        diag = 1.0  # degenerate cloud: avoid dividing by ~0
    return xyz0 / diag


def _normalize_normals_np(normals: np.ndarray) -> np.ndarray:
    """Scale each row of ``normals`` to unit length (norm floored at 1e-12)."""
    normals = np.asarray(normals, dtype=np.float32)
    nrm = np.linalg.norm(normals, axis=1, keepdims=True)
    nrm = np.clip(nrm, 1e-12, None)
    return normals / nrm


def _build_knn_np(xyz: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Query up to ``k`` nearest neighbours of every point with a cKDTree.

    The first column returned by the tree (normally the query point itself) is
    dropped.

    Returns:
        (dists, inds), each of shape ``(n, min(k, n - 1))``; both have zero
        columns when ``n <= 1`` or ``k <= 0``.
    """
    n = xyz.shape[0]
    # k <= 0 would make tree.query return 1-D arrays and break the slicing.
    if n <= 1 or k <= 0:
        return (
            np.empty((n, 0), dtype=np.float32),
            np.empty((n, 0), dtype=np.int64),
        )

    k_eff = min(k + 1, n)
    tree = cKDTree(xyz)
    dists, inds = tree.query(xyz, k=k_eff, workers=-1)
    return dists[:, 1:], inds[:, 1:]


def _local_geom_features_chunked_np(
    xyz: np.ndarray,
    k_feat: int = 10,
    chunk_size: int = 8192,
) -> np.ndarray:
    """
    SPG-style local geometric features:
    linearity, planarity, scattering, verticality, elevation

    Features are derived from the eigen-decomposition of each point's
    neighbourhood covariance, processed ``chunk_size`` points at a time.

    Returns:
        float32 array of shape ``(n, 5)``.
    """
    _, nbrs = _build_knn_np(xyz, k_feat)
    n = xyz.shape[0]
    k_eff = nbrs.shape[1]
    feat = np.empty((n, 5), dtype=np.float32)
    if n == 0:
        return feat  # nothing to describe; z.min() below would raise

    z = xyz[:, 2]
    zmin, zmax = z.min(), z.max()
    if zmax - zmin < 1e-12:
        elevation = np.zeros(n, dtype=np.float32)
    else:
        elevation = ((z - zmin) / (zmax - zmin)).astype(np.float32)

    eps = 1e-12

    if k_eff == 0:
        # No neighbours available: fall back to fixed shape descriptors.
        feat[:, 0] = 0.0
        feat[:, 1] = 0.0
        feat[:, 2] = 1.0
        feat[:, 3] = 1.0
        feat[:, 4] = elevation
        return feat

    for s in range(0, n, chunk_size):
        e = min(s + chunk_size, n)
        pts = xyz[nbrs[s:e]]  # [B, k, 3]
        mu = pts.mean(axis=1, keepdims=True)
        X = pts - mu
        cov = np.matmul(X.transpose(0, 2, 1), X) / float(max(k_eff, 1))

        # eigh returns eigenvalues in ascending order.
        evals, evecs = np.linalg.eigh(cov.astype(np.float64))
        evals = np.clip(evals, eps, None)
        l3 = evals[:, 0]
        l2 = evals[:, 1]
        l1 = evals[:, 2]
        denom = np.maximum(l1, eps)

        linearity = (l1 - l2) / denom
        planarity = (l2 - l3) / denom
        scattering = l3 / denom

        # Eigenvector of the smallest eigenvalue acts as the local normal.
        n_local = evecs[:, :, 0]
        verticality = 1.0 - np.abs(n_local[:, 2])

        feat[s:e, 0] = linearity.astype(np.float32)
        feat[s:e, 1] = planarity.astype(np.float32)
        feat[s:e, 2] = scattering.astype(np.float32)
        feat[s:e, 3] = verticality.astype(np.float32)
        feat[s:e, 4] = elevation[s:e]

    return feat


def _build_adj_graph_np(
    xyz: np.ndarray,
    k_adj: int = 10,
    mutual: bool = False,
    undirected: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build a k-NN edge list ``(src, dst)`` as uint32 arrays.

    Args:
        mutual: Keep only edges whose reverse edge is also present.
        undirected: Append the reversed copy of every edge. Edges that already
            exist in both directions therefore appear twice.
    """
    _, nbrs = _build_knn_np(xyz, k_adj)
    n, k = nbrs.shape
    if k == 0:
        return np.empty((0,), dtype=np.uint32), np.empty((0,), dtype=np.uint32)

    src = np.repeat(np.arange(n, dtype=np.uint32), k)
    dst = nbrs.reshape(-1).astype(np.uint32)

    # Drop self loops (possible when duplicated points reorder the k-NN result).
    keep = src != dst
    src = src[keep]
    dst = dst[keep]

    if mutual:
        # Encode each directed edge as one integer to test for its reverse.
        code = src.astype(np.uint64) * np.uint64(n) + dst.astype(np.uint64)
        rev_code = dst.astype(np.uint64) * np.uint64(n) + src.astype(np.uint64)
        keep = np.isin(code, rev_code, assume_unique=False)
        src = src[keep]
        dst = dst[keep]

    if undirected:
        src0, dst0 = src, dst
        src = np.concatenate([src0, dst0], axis=0)
        dst = np.concatenate([dst0, src0], axis=0)

    return src.astype(np.uint32, copy=False), dst.astype(np.uint32, copy=False)


def _edge_weights_chunked_np(
    Y: np.ndarray,  # [D, N], Fortran order
    src: np.ndarray,
    dst: np.ndarray,
    lam: float = 5.0,
    sigma: float = 0.5,
    chunk_size: int = 1_000_000,
) -> np.ndarray:
    """
    Gaussian edge weights ``lam * exp(-||f_src - f_dst||^2 / sigma^2)``.

    ``Y`` is indexed as ``[D, N]``; edges are processed ``chunk_size`` at a
    time. Returns a float32 array with one weight per edge.
    """
    sigma2 = max(sigma * sigma, 1e-12)
    num_edges = src.shape[0]
    ew = np.empty(num_edges, dtype=np.float32)
    f = Y.T  # [N, D]

    for s in range(0, num_edges, chunk_size):
        e = min(s + chunk_size, num_edges)
        diff = f[src[s:e]] - f[dst[s:e]]
        dist2 = np.sum(diff * diff, axis=1)
        ew[s:e] = lam * np.exp(-dist2 / sigma2)

    return ew


def _edges_to_forward_star(
    n: int,
    src: np.ndarray,
    dst: np.ndarray,
    ew: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert an edge list into forward-star (CSR-like) arrays.

    Returns:
        first_edge: uint32 ``(n + 1,)`` offsets into the sorted edge arrays.
        adj_vertices: uint32 destination of every edge, grouped by source.
        edge_weights: float32 weights in the same order.
    """
    if src.size == 0:
        first_edge = np.zeros(n + 1, dtype=np.uint32)
        adj_vertices = np.empty((0,), dtype=np.uint32)
        edge_weights = np.empty((0,), dtype=np.float32)
        return first_edge, adj_vertices, edge_weights

    # Stable sort keeps the relative order of edges sharing a source vertex.
    order = np.argsort(src, kind="stable")
    src = src[order]
    dst = dst[order]
    ew = ew[order]

    counts = np.bincount(src.astype(np.int64), minlength=n)
    first_edge = np.zeros(n + 1, dtype=np.uint32)
    first_edge[1:] = np.cumsum(counts, dtype=np.uint32)

    return (
        first_edge,
        dst.astype(np.uint32, copy=False),
        ew.astype(np.float32, copy=False),
    )


def _relabel_contiguous_np(labels: np.ndarray) -> np.ndarray:
    """Replace labels by their sorted rank ``0..n_unique-1`` (int32)."""
    _, inv = np.unique(labels, return_inverse=True)
    return inv.astype(np.int32)


# -----------------------------------------------------------------------------
# superpoint-aware fps
# -----------------------------------------------------------------------------
def superpoint_fps(
    xyz: Tensor,
    labels: Tensor,
    k: int,
    min_segment_points: int = 10,
    gamma: float = 0.7,
    eps: float = 1e-4,
    base_scale: float = 1.0,
    deterministic_start: bool = False,
    return_probs: bool = False,
):
    """
    Farthest point sampling biased towards points of small superpoints.

    Points whose superpoint has fewer than ``min_segment_points`` members are
    ineligible and are never selected on purpose. Every step picks the point
    maximising ``squared_distance_to_selected * weight``, where the weight
    grows as the point's superpoint gets smaller.

    Args:
        xyz: ``(B, N, 3)`` coordinates.
        labels: ``(B, N)`` non-negative superpoint IDs.
        k: Number of samples per batch item, ``1 <= k <= N``.
        min_segment_points: Minimum superpoint size for eligibility.
        gamma: Exponent applied to the inverse-size weight.
        eps: Small constant added to counts and used as a clamp floor.
        base_scale: Scales the minimum weight given to the largest superpoints.
        deterministic_start: Start from the highest-probability point instead
            of sampling the start point.
        return_probs: Also return the per-point start probabilities.

    Returns:
        ``idx`` of shape ``(B, k)``, or ``(idx, probs)`` if ``return_probs``.

    Raises:
        ValueError: On a shape mismatch, negative labels, ``k`` out of range,
            or fewer than ``k`` eligible points in some batch item.
    """
    B, N, _ = xyz.shape

    if labels.shape != (B, N):
        raise ValueError(f"labels shape must be (B, N) = {(B, N)}, got {labels.shape}")
    if (labels < 0).any():
        raise ValueError("labels must be non-negative")
    if not (0 < k <= N):
        raise ValueError(f"k must be in 1…N={N}, got {k}")

    device, dtype = xyz.device, xyz.dtype
    long_labels = labels.long()
    max_id_val = long_labels.max().item() if N else 0

    # Per-batch histogram of superpoint sizes (+eps so it is never exactly 0).
    counts = torch.zeros(B, max_id_val + 1, device=device, dtype=torch.float)
    counts.scatter_add_(1, long_labels, torch.ones_like(long_labels, dtype=torch.float))
    counts = counts + eps

    seg_large_enough = counts >= float(min_segment_points)
    eligible_mask = seg_large_enough.gather(1, long_labels)

    if (eligible_mask.sum(dim=1) < k).any():
        raise ValueError(
            "At least one batch item has fewer than k eligible points "
            f"(threshold={min_segment_points})."
        )

    # Start probabilities: inverse superpoint size, zero for ineligible points.
    label_counts = counts.gather(1, long_labels)
    inv_probs = torch.where(
        eligible_mask,
        1.0 / label_counts,
        torch.zeros_like(label_counts),
    )
    inv_sum = inv_probs.sum(dim=1, keepdim=True).clamp(min=eps)
    probs = inv_probs / inv_sum

    # Ratio between the smallest and the largest superpoint that is present.
    present_mask = counts > (eps + 1e-9)
    counts_for_min = torch.where(
        present_mask, counts, torch.full_like(counts, float("inf"))
    )
    min_counts_b = counts_for_min.min(1, keepdim=True).values
    min_counts_b = torch.where(
        torch.isinf(min_counts_b),
        torch.ones_like(min_counts_b),
        min_counts_b,
    )
    counts_for_max = torch.where(present_mask, counts, torch.zeros_like(counts))
    max_counts_b = counts_for_max.max(1, keepdim=True).values.clamp(min=eps)

    size_ratio_b = (min_counts_b / max_counts_b).clamp(1e-3, 1.0)
    min_weight_b = (base_scale * size_ratio_b).clamp(min=1e-5)

    # Per-point multiplicative weight in [min_weight_b, 1].
    score_bias_weights = inv_probs ** gamma
    max_bias = score_bias_weights.max(1, keepdim=True)[0].clamp(min=eps)
    score_bias_weights = score_bias_weights / max_bias

    prob_weights = (1.0 - min_weight_b) * score_bias_weights + min_weight_b
    prob_weights = torch.where(eligible_mask, prob_weights, torch.zeros_like(prob_weights))

    # Ineligible points are pinned at distance 0 so they never win the argmax.
    dist = torch.full((B, N), float("inf"), device=device, dtype=dtype)
    dist = torch.where(eligible_mask, dist, torch.zeros_like(dist))

    idx = torch.zeros(B, k, dtype=torch.long, device=device)
    batch_idx = torch.arange(B, device=device)

    if deterministic_start:
        idx[:, 0] = probs.argmax(1)
    else:
        # Sample straight from `probs`: adding eps to every entry would give
        # ineligible points a non-zero chance of becoming the start point.
        # Each row has at least k >= 1 eligible (non-zero) entries.
        idx[:, 0] = torch.multinomial(probs, 1).squeeze(1)

    selected_xyz = torch.zeros(B, k, 3, device=device, dtype=dtype)
    selected_xyz[batch_idx, 0] = xyz[batch_idx, idx[:, 0]]

    dist = torch.minimum(
        dist,
        ((xyz - selected_xyz[:, 0:1]).pow(2)).sum(-1),
    )

    for i in range(1, k):
        scores = dist * prob_weights
        idx[:, i] = scores.argmax(1)
        selected_xyz[batch_idx, i] = xyz[batch_idx, idx[:, i]]
        dist = torch.minimum(
            dist,
            ((xyz - selected_xyz[:, i:i + 1]).pow(2)).sum(-1),
        )
        dist = torch.where(eligible_mask, dist, torch.zeros_like(dist))

    return (idx, probs) if return_probs else idx


# -----------------------------------------------------------------------------
# superpoint pooling
# -----------------------------------------------------------------------------
def get_spt_centers(
    x: torch.Tensor,
    spts: torch.Tensor,
    reduce: str = "mean",
):
    """
    Pool per-point features into per-superpoint features.

    Args:
        x: ``(B, N, C)`` features.
        spts: ``(B, N)`` superpoint IDs; item ``b`` uses IDs
            ``0..spts[b].max()``.
        reduce: ``"mean"`` or ``"max"``.

    Returns:
        pooled_BKC: ``(B, K, C)`` pooled features, ``K`` being the largest
            per-item superpoint count; empty or padded slots are zero.
        mask: ``(B, K)`` bool, True where the slot received at least one point.
        sp_ids_local: ``(B, K)`` local superpoint ID, -1 for padded slots.
        counts_BK: ``(B, K)`` number of points per slot.

    Raises:
        ValueError: For an unsupported ``reduce``.
    """
    B, N, C = x.shape
    dev = x.device

    spts = spts.long()
    sp_counts = spts.amax(dim=1) + 1
    # Exclusive prefix sum: where each batch item starts in the flat ID space.
    offsets = torch.cat(
        [torch.zeros(1, dtype=sp_counts.dtype, device=dev), sp_counts.cumsum(0)]
    )[:-1]
    spts_global = spts + offsets[:, None]

    x_flat = x.reshape(-1, C)
    spts_flat = spts_global.reshape(-1)
    tot_sp = int((offsets[-1] + sp_counts[-1]).item())

    pooled = torch.zeros(tot_sp, C, dtype=x.dtype, device=dev)
    counts = torch.zeros(tot_sp, 1, dtype=x.dtype, device=dev)

    if reduce == "mean":
        pooled.index_add_(0, spts_flat, x_flat)
        counts.index_add_(0, spts_flat, torch.ones_like(x_flat[:, :1]))
        pooled = pooled / counts.clamp(min=1e-6)
    elif reduce == "max":
        pooled.fill_(-float("inf"))
        pooled = pooled.index_reduce(0, spts_flat, x_flat, reduce="amax")
        counts.index_add_(0, spts_flat, torch.ones_like(x_flat[:, :1]))
    else:
        raise ValueError(f"Unsupported reduction: {reduce}")

    K = int(sp_counts.max().item())
    row_ids = torch.arange(K, device=dev).unsqueeze(0).expand(B, -1)
    valid_mask = row_ids < sp_counts.unsqueeze(1)

    # Padded slots point at an extra all-zero row appended below.
    gather_ix = torch.where(
        valid_mask,
        offsets[:, None] + row_ids,
        torch.full_like(row_ids, fill_value=tot_sp),
    )

    pooled_ext = torch.cat([pooled, torch.zeros(1, C, device=dev, dtype=x.dtype)], dim=0)
    counts_ext = torch.cat([counts, torch.zeros(1, 1, device=dev, dtype=x.dtype)], dim=0)

    pooled_BKC = pooled_ext[gather_ix]
    counts_BK = counts_ext[gather_ix].squeeze(-1)
    mask = counts_BK > 0

    # Select instead of multiply: with reduce="max" an empty slot holds -inf,
    # and -inf * 0 would turn into NaN.
    pooled_BKC = torch.where(mask.unsqueeze(-1), pooled_BKC, torch.zeros_like(pooled_BKC))

    sp_ids_local = torch.where(
        valid_mask,
        row_ids,
        torch.full_like(row_ids, -1, dtype=torch.long),
    )

    return pooled_BKC, mask, sp_ids_local, counts_BK


# -----------------------------------------------------------------------------
# distances and coverage
# -----------------------------------------------------------------------------
def masked_pairwise_distance(pc1, pc2, mask1, mask2, invalid_val=1e6):
    """
    Batched Euclidean distances between two point sets with validity masks.

    ``pc1`` and ``pc2`` are indexed as ``(B, N1, D)`` and ``(B, N2, D)``; pairs
    where either point is masked out get ``invalid_val``.
    """
    pc1_sq = (pc1 ** 2).sum(dim=2, keepdim=True)
    pc2_sq = (pc2 ** 2).sum(dim=2).unsqueeze(1)
    inner = torch.bmm(pc1, pc2.transpose(1, 2))

    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, clamped against round-off.
    dists = pc1_sq - 2 * inner + pc2_sq
    dists = torch.clamp(dists, min=0).sqrt()

    valid_mask = mask1.unsqueeze(2) * mask2.unsqueeze(1)
    dists = torch.where(valid_mask.bool(), dists, torch.full_like(dists, invalid_val))
    return dists


def coverage_ratio(samples: Tensor, labels: Tensor) -> float:
    """Fraction of distinct labels (over the whole batch) hit by ``samples``."""
    unique_sampled = torch.gather(labels, 1, samples).unique().numel()
    unique_total = labels.unique().numel()
    return unique_sampled / unique_total


# -----------------------------------------------------------------------------
# fps
# -----------------------------------------------------------------------------
@torch.no_grad()
def fps(xyz: torch.Tensor, k: int) -> torch.LongTensor:
    """
    Farthest point sampling of ``k`` indices from ``xyz`` (``(N, D)``).

    The first index is drawn uniformly at random; each following index is the
    point farthest (squared Euclidean) from all points selected so far.
    """
    N, dev = xyz.size(0), xyz.device
    sel = torch.empty(k, dtype=torch.long, device=dev)
    if k <= 0:
        return sel  # nothing to select; sel[0] below would raise

    sel[0] = torch.randint(0, N, (1,), device=dev)
    # Match xyz's dtype and start from +inf so no finite cap limits distances.
    dist2 = torch.full((N,), float("inf"), device=dev, dtype=xyz.dtype)

    for i in range(1, k):
        d = ((xyz - xyz[sel[i - 1]]) ** 2).sum(1)
        dist2 = torch.minimum(dist2, d)
        sel[i] = torch.argmax(dist2)

    return sel


def fps_np(point, npoint):
    """
    NumPy farthest point sampling on the first three columns of ``point``.

    Returns:
        (point[ids], ids): the sampled rows and their int32 indices.
    """
    N, D = point.shape
    xyz = point[:, :3]
    ids = np.zeros((npoint,), dtype=np.int32)
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)

    for i in range(npoint):
        ids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        distance = np.minimum(distance, dist)
        farthest = np.argmax(distance, -1)

    point = point[ids]
    return point, ids


# -----------------------------------------------------------------------------
# ppf
# -----------------------------------------------------------------------------
def calc_ppf_np(points, point_normals, patches, patch_normals):
    """
    Point-pair features between each reference point and its patch points.

    ``points``/``point_normals`` are indexed as ``(N, 3)`` and
    ``patches``/``patch_normals`` as ``(N, nsamples, 3)``.

    Returns:
        ``(N, nsamples, 4)`` array ``[d, angle1, angle2, angle3]`` where every
        angle is ``atan2(|a x b|, a . b) / pi``.
    """
    N, nsamples, _ = patches.shape

    points_expanded = np.broadcast_to(points[:, None, :], (N, nsamples, points.shape[-1]))
    point_normals_expanded = np.broadcast_to(
        point_normals[:, None, :], (N, nsamples, point_normals.shape[-1])
    )

    vec_d = patches - points_expanded
    d = np.linalg.norm(vec_d, axis=-1, keepdims=True)

    # angle(point normal, offset)
    dot1 = np.sum(point_normals_expanded * vec_d, axis=-1, keepdims=True)
    cross1 = np.cross(point_normals_expanded, vec_d)
    norm_cross1 = np.linalg.norm(cross1, axis=-1, keepdims=True)
    angle1 = np.arctan2(norm_cross1, dot1) / np.pi

    # angle(patch normal, offset)
    dot2 = np.sum(patch_normals * vec_d, axis=-1, keepdims=True)
    cross2 = np.cross(patch_normals, vec_d)
    norm_cross2 = np.linalg.norm(cross2, axis=-1, keepdims=True)
    angle2 = np.arctan2(norm_cross2, dot2) / np.pi

    # angle(point normal, patch normal)
    dot3 = np.sum(point_normals_expanded * patch_normals, axis=-1, keepdims=True)
    cross3 = np.cross(point_normals_expanded, patch_normals)
    norm_cross3 = np.linalg.norm(cross3, axis=-1, keepdims=True)
    angle3 = np.arctan2(norm_cross3, dot3) / np.pi

    ppf = np.concatenate([d, angle1, angle2, angle3], axis=-1)
    return ppf
def _ppf_angle(a, b):
    """Return the angle between ``a`` and ``b`` along the last dim, scaled to [0, 1]."""
    sin_term = torch.norm(torch.cross(a, b, dim=-1), dim=-1, keepdim=True)
    cos_term = torch.sum(a * b, dim=-1, keepdim=True)
    return torch.atan2(sin_term, cos_term) / np.pi


def calc_ppf_gpu(points, point_normals, patches, patch_normals):
    """
    Point-pair features for unbatched input.

    ``points``/``point_normals`` are indexed as ``(N, 3)`` and
    ``patches``/``patch_normals`` as ``(N, S, 3)``.

    Returns:
        ``(N, S, 4)`` tensor ``[d, angle1, angle2, angle3]``: offset length,
        angle(point normal, offset), angle(patch normal, offset) and
        angle(point normal, patch normal), angles divided by pi.
    """
    num_samples = patches.shape[1]
    points = torch.unsqueeze(points, dim=1).expand(-1, num_samples, -1)
    point_normals = torch.unsqueeze(point_normals, dim=1).expand(-1, num_samples, -1)

    vec_d = patches - points
    d = torch.norm(vec_d, dim=-1, keepdim=True)

    angle1 = _ppf_angle(point_normals, vec_d)
    angle2 = _ppf_angle(patch_normals, vec_d)
    angle3 = _ppf_angle(point_normals, patch_normals)

    ppf = torch.cat([d, angle1, angle2, angle3], dim=-1)
    return ppf


def calc_ppf_batch(points, point_normals, patches, patch_normals):
    """
    Batched point-pair features.

    ``points``/``point_normals`` are indexed as ``(B, N, 3)`` and
    ``patches``/``patch_normals`` as ``(B, N, S, 3)``.

    Returns:
        ``(B, N, S, 4)`` tensor with the same channels as ``calc_ppf_gpu``.
    """
    B, N, S, _ = patches.shape

    points_exp = points.unsqueeze(2).expand(-1, -1, S, -1)
    normals_exp = point_normals.unsqueeze(2).expand(-1, -1, S, -1)

    vec_d = patches - points_exp
    d = torch.norm(vec_d, dim=-1, keepdim=True)

    angle1 = _ppf_angle(normals_exp, vec_d)
    angle2 = _ppf_angle(patch_normals, vec_d)
    angle3 = _ppf_angle(normals_exp, patch_normals)

    ppf = torch.cat([d, angle1, angle2, angle3], dim=-1)
    return ppf


# -----------------------------------------------------------------------------
# misc geometry
# -----------------------------------------------------------------------------
def calc_patch_scale(xyz):
    """
    Largest nearest-neighbour distance within each cloud.

    ``xyz`` is indexed as ``(B, N, 3)``; returns a ``(B,)`` tensor.
    """
    dist = torch.cdist(xyz, xyz)
    B, N, _ = dist.shape

    # Exclude self-distances before taking the per-point minimum.
    mask = torch.eye(N, device=xyz.device).unsqueeze(0).bool()
    dist.masked_fill_(mask, float("inf"))

    nn_dist, _ = dist.min(dim=-1)
    scales = torch.max(nn_dist, dim=-1)[0]
    return scales


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Ball query: indices of up to ``nsample`` points within ``radius``.

    Args:
        radius: Scalar radius, or a ``(B,)`` tensor with one radius per item.
        nsample: Maximum number of neighbours per query.
        xyz: ``(B, N, 3)`` support points.
        new_xyz: ``(B, S, 3)`` query points.

    Returns:
        ``(B, S, min(nsample, N))`` long tensor. In-radius indices come first
        in ascending order; remaining slots repeat the first in-radius index.
        A query with no point inside the radius is filled with the index of
        its nearest support point, so the result is always a valid index.
    """
    device = xyz.device
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape

    all_idx = torch.arange(N, dtype=torch.long, device=device)

    if torch.is_tensor(radius):
        assert radius.shape[0] == B, "radius must be (B,)"
        radius_val = radius.to(device).view(B, 1, 1)
    else:
        radius_val = float(radius)

    dists = torch.cdist(new_xyz, xyz)

    # Mark out-of-radius points with the sentinel N so they sort last.
    group_idx = all_idx.view(1, 1, N).expand(B, S, N).clone()
    group_idx[dists > radius_val] = N
    group_idx = group_idx.sort(dim=-1)[0][..., :nsample]

    # Padding value: first in-radius index, or the nearest point when the ball
    # is empty (otherwise the sentinel N would leak out as an invalid index).
    first_valid = group_idx[..., :1]
    nearest = dists.argmin(dim=-1, keepdim=True)
    first_valid = torch.where(first_valid == N, nearest, first_valid)

    group_idx = torch.where(group_idx == N, first_valid.expand_as(group_idx), group_idx)
    return group_idx


# -----------------------------------------------------------------------------
# geometric features
# -----------------------------------------------------------------------------
def compute_geo_feats_np(
    P: np.ndarray,
    n_d_xyz: int = 2048,
    r: float = 0.1,
    eps: float = 1e-4,
    chunk_p: int = 64_000,
    chunk_q: int = 256_000,
) -> np.ndarray:
    """
    Eigenvalue features (linearity, planarity, scattering) per point.

    ``n_d_xyz`` reference points are drawn from ``P`` with ``fps_np``; for each
    point the scatter matrix of the reference points within radius ``r`` is
    accumulated in chunks and decomposed.

    Returns:
        float32 array of shape ``(N, 3)``.
    """
    P = np.asarray(P, dtype=np.float32)
    Q = fps_np(P, n_d_xyz)[0]
    Q = np.asarray(Q, dtype=np.float32)

    N = P.shape[0]
    eye3 = np.eye(3, dtype=P.dtype)
    r2 = r * r
    feats_all = []

    for p0 in range(0, N, chunk_p):
        p1 = min(p0 + chunk_p, N)
        P_blk = P[p0:p1]
        p = P_blk.shape[0]

        # Running neighbour count, first and second moments per point.
        cnt = np.zeros((p, 1), dtype=P.dtype)
        sum1 = np.zeros((p, 3), dtype=P.dtype)
        sum2 = np.zeros((p, 3, 3), dtype=P.dtype)

        for q0 in range(0, Q.shape[0], chunk_q):
            q1 = min(q0 + chunk_q, Q.shape[0])
            Q_blk = Q[q0:q1]

            diff = P_blk[:, None, :] - Q_blk[None, :, :]
            d2 = np.sum(diff * diff, axis=2)
            mask = d2 < r2
            if not mask.any():
                continue

            w = mask.astype(P.dtype)
            cnt += np.sum(w, axis=1, keepdims=True)
            sum1 += w @ Q_blk

            Q_outer = Q_blk[:, :, None] * Q_blk[:, None, :]
            sum2 += np.einsum("pq,qkl->pkl", w, Q_outer)

        cnt[cnt == 0] = 1.0

        # sum_q (q - p)(q - p)^T expanded in terms of the accumulated moments.
        C = (
            sum2
            - P_blk[:, :, None] * sum1[:, None, :]
            - sum1[:, :, None] * P_blk[:, None, :]
            + cnt[:, :, None] * (P_blk[:, :, None] * P_blk[:, None, :])
        ) / cnt[:, :, None]
        C += eps * eye3

        # Descending eigenvalues, floored like the torch variant so that the
        # division below stays finite even with eps == 0.
        lamb = np.clip(np.flip(np.linalg.eigvalsh(C), axis=1), 1e-9, None)
        lamb1, lamb2, lamb3 = lamb[:, 0], lamb[:, 1], lamb[:, 2]

        feats_blk = np.stack(
            [(lamb1 - lamb2) / lamb1, (lamb2 - lamb3) / lamb1, lamb3 / lamb1],
            axis=1,
        )
        feats_all.append(feats_blk)

    return np.concatenate(feats_all, axis=0).astype(P.dtype)


@torch.no_grad()
def compute_geo_feats(
    P: torch.Tensor,
    Q: torch.Tensor,
    r: float = 0.1,
    eps: float = 1e-4,
    chunk_p: int = 64_000,
    chunk_q: int = 256_000,
) -> torch.Tensor:
    """
    Torch version of the eigenvalue features.

    For every point of ``P`` (``(N, 3)``) the scatter matrix of the reference
    points ``Q`` (``(M, 3)``) closer than ``r`` is accumulated in chunks.

    Returns:
        ``(N, 3)`` tensor ``[linearity, planarity, scattering]`` in ``P``'s
        dtype.
    """
    device = P.device
    dtype = P.dtype
    N = P.size(0)
    eye3 = torch.eye(3, device=device, dtype=dtype)
    feats_chunks = []

    for p0 in range(0, N, chunk_p):
        p1 = min(p0 + chunk_p, N)
        P_blk = P[p0:p1]
        p = p1 - p0

        cnt = torch.zeros(p, 1, device=device, dtype=dtype)
        sum1 = torch.zeros(p, 3, device=device, dtype=dtype)
        sum2 = torch.zeros(p, 3, 3, device=device, dtype=dtype)

        for q0 in range(0, Q.size(0), chunk_q):
            q1 = min(q0 + chunk_q, Q.size(0))
            Q_blk = Q[q0:q1]

            # cdist returns plain (not squared) distances, hence `< r`.
            dist = torch.cdist(P_blk, Q_blk, p=2)
            mask = dist < r
            if not mask.any():
                continue

            # Cast to the working dtype: a float32 mask cannot be matmul'ed
            # with float64 points.
            w = mask.to(dtype)
            cnt += w.sum(1, keepdim=True)
            sum1 += w @ Q_blk

            Q_outer = Q_blk.unsqueeze(2) * Q_blk.unsqueeze(1)
            sum2 += torch.einsum("ij,jkl->ikl", w, Q_outer)

        cnt.clamp_(min=1.0)

        C = (
            sum2
            - P_blk.unsqueeze(2) * sum1.unsqueeze(1)
            - sum1.unsqueeze(2) * P_blk.unsqueeze(1)
            + cnt.unsqueeze(2) * (P_blk.unsqueeze(2) * P_blk.unsqueeze(1))
        )
        C = C / cnt.unsqueeze(2) + eps * eye3

        # eigvalsh is ascending; flip to get l1 >= l2 >= l3.
        lamb = torch.flip(
            torch.linalg.eigvalsh(C).clamp_min(1e-9),
            dims=[1],
        )
        l1, l2, l3 = lamb[:, 0], lamb[:, 1], lamb[:, 2]

        blk_feat = torch.stack(
            [(l1 - l2) / l1, (l2 - l3) / l1, l3 / l1],
            dim=1,
        )
        feats_chunks.append(blk_feat)

    return torch.cat(feats_chunks, dim=0)
# -----------------------------------------------------------------------------
# sinkhorn
# -----------------------------------------------------------------------------
@torch.no_grad()
def sinkhorn_log_torch(scores, xi=0.05, iters=3, eps=1e-12):
    """
    Log-domain Sinkhorn normalisation of a ``(B, K)`` score matrix.

    Rows and columns of ``exp(scores / xi)`` are alternately normalised for
    ``iters`` rounds; the result is clamped to ``[1e-6, 1]`` and every row is
    renormalised to sum to (almost exactly) one.

    Returns:
        ``(B, K)`` tensor of assignment probabilities.
    """
    B, K = scores.shape
    log_B = torch.log(torch.tensor(float(B), device=scores.device, dtype=scores.dtype))
    log_K = torch.log(torch.tensor(float(K), device=scores.device, dtype=scores.dtype))

    log_q = (scores / xi).T  # (K, B)
    log_q -= log_q.max()
    log_q -= torch.logsumexp(log_q, dim=(0, 1), keepdim=True)

    for _ in range(iters):
        log_q -= torch.logsumexp(log_q, dim=1, keepdim=True) + log_K
        log_q -= torch.logsumexp(log_q, dim=0, keepdim=True) + log_B

    log_q += log_B

    prob = torch.exp(log_q.T)
    prob = torch.clamp(prob, 1e-6, 1.0)
    prob = prob / (prob.sum(1, keepdim=True) + eps)
    return prob


@torch.no_grad()
def sinkhorn_chunked_log(
    rows_fn,
    N: int,
    K: int,
    tau: float = 0.05,
    iters: int = 60,
    chunk_p: int = 131_072,
    eps: float = 1e-9,
    dtype=torch.float32,
    device: torch.device = torch.device("cuda"),
):
    """
    Chunked log-domain Sinkhorn between ``N`` rows and ``K`` columns.

    The ``(N, K)`` similarity matrix is never materialised: ``rows_fn(i0, i1)``
    must return rows ``i0:i1`` as an ``(i1 - i0, K)`` tensor. The returned
    tensor is clamped in place.

    Returns:
        labels: ``(N,)`` argmax column of the transport plan per row.
        mass: ``(K,)`` transported mass per column (plus ``eps``).
    """
    log_r = -torch.log(torch.tensor(float(N), dtype=dtype, device=device))
    log_c = -torch.log(torch.tensor(float(K), dtype=dtype, device=device))
    log_v = torch.zeros(K, dtype=dtype, device=device)

    for _ in range(iters):
        log_col_sum = torch.full((K,), -float("inf"), dtype=dtype, device=device)

        for i0 in range(0, N, chunk_p):
            sims = rows_fn(i0, min(i0 + chunk_p, N)).clamp_min_(eps)
            log_K = sims / tau
            # Row scaling is recomputed per chunk rather than stored.
            log_u = log_r - torch.logsumexp(log_K + log_v, dim=1)
            blk = log_u[:, None] + log_K + log_v
            log_col_sum = torch.logaddexp(log_col_sum, torch.logsumexp(blk, dim=0))

        log_v = log_c - log_col_sum

    labels = torch.empty(N, dtype=torch.long, device=device)
    mass = torch.zeros(K, dtype=dtype, device=device)

    for i0 in range(0, N, chunk_p):
        sims = rows_fn(i0, min(i0 + chunk_p, N)).clamp_min_(eps)
        log_K = sims / tau
        log_u = log_r - torch.logsumexp(log_K + log_v, dim=1)
        log_Y = log_u[:, None] + log_K + log_v
        Y = torch.exp(log_Y)

        labels[i0: i0 + Y.size(0)] = torch.argmax(Y, dim=1)
        mass += Y.sum(0)

    mass += eps
    return labels, mass


def sinkhorn_chunked_log_np(
    rows_fn,
    N: int,
    K: int,
    tau: float = 0.05,
    iters: int = 60,
    chunk_p: int = 131_072,
    eps: float = 1e-9,
    dtype=np.float32,
):
    """
    NumPy version of ``sinkhorn_chunked_log``.

    ``rows_fn(i0, i1)`` must return rows ``i0:i1`` of the similarity matrix as
    an ``(i1 - i0, K)`` array.

    Returns:
        labels: ``(N,)`` int64 argmax column per row.
        mass: ``(K,)`` transported mass per column (plus ``eps``).
    """
    log_r = -np.log(float(N)).astype(dtype)
    log_c = -np.log(float(K)).astype(dtype)
    log_v = np.zeros((K,), dtype=dtype)

    for _ in range(iters):
        log_col_sum = np.full((K,), -np.inf, dtype=dtype)

        for i0 in range(0, N, chunk_p):
            sims = rows_fn(i0, min(i0 + chunk_p, N)).astype(dtype)
            sims = np.clip(sims, eps, None)
            log_K = sims / tau
            log_u = log_r - logsumexp(log_K + log_v, axis=1)
            blk = log_u[:, None] + log_K + log_v
            log_col_sum = np.logaddexp(log_col_sum, logsumexp(blk, axis=0))

        log_v = log_c - log_col_sum

    labels = np.empty((N,), dtype=np.int64)
    mass = np.zeros((K,), dtype=dtype)

    for i0 in range(0, N, chunk_p):
        sims = rows_fn(i0, min(i0 + chunk_p, N)).astype(dtype)
        sims = np.clip(sims, eps, None)
        log_K = sims / tau
        log_u = log_r - logsumexp(log_K + log_v, axis=1)
        log_Y = log_u[:, None] + log_K + log_v
        Y = np.exp(log_Y)

        labels[i0: i0 + Y.shape[0]] = np.argmax(Y, axis=1)
        mass += Y.sum(axis=0)

    mass += eps
    return labels, mass


# -----------------------------------------------------------------------------
# ot-fps clustering
# -----------------------------------------------------------------------------
@torch.no_grad()
def ot_fps_cluster_large(
    xyz: torch.Tensor,
    feats: torch.Tensor,
    K: int = 64,
    r_spatial: float = 0.05,
    tau: float = 0.05,
    sinkhorn_iters: int = 60,
    outer_iters: int = 3,
    chunk_P: int = 128,
):
    """
    Cluster points into ``K`` groups with FPS seeding and Sinkhorn assignment.

    Centres are seeded with ``fps``. Each outer iteration assigns points with a
    chunked Sinkhorn over a similarity combining a spatial Gaussian (restricted
    to ``r_spatial``) with the absolute cosine feature similarity, re-seeds
    centres that received almost no mass, and then recomputes the centres from
    the assigned points lying within ``r_spatial``.

    Args:
        xyz: ``(N, 3)`` coordinates.
        feats: ``(N, D)`` features (L2-normalised internally).

    Returns:
        ``(N,)`` long tensor of cluster labels from the last assignment.
    """
    device, N = xyz.device, xyz.size(0)
    feats = F.normalize(feats, dim=1)
    sigma2 = ((xyz - xyz.mean(0, keepdim=True)) ** 2).sum(1).mean()
    EPS_TIE = 1e-6
    # Tiny per-column offset that breaks ties between identical similarities.
    tie_break = EPS_TIE * torch.arange(K, device=device)

    centre_idx = fps(xyz, K)
    c_xyz = xyz[centre_idx].clone()
    c_fea = feats[centre_idx].clone()

    for _ in range(outer_iters):

        def rows_fn(i0: int, i1: int) -> torch.Tensor:
            pts = xyz[i0:i1]
            fea = feats[i0:i1]

            dist = torch.cdist(pts, c_xyz)
            mask = dist <= r_spatial

            # Points outside every ball stay attached to their nearest centre.
            empty = mask.logical_not().all(dim=1)
            if empty.any():
                nearest = torch.argmin(dist[empty], dim=1)
                mask[empty, nearest] = True

            # sigma2 is a variance, so it must scale the *squared* distance
            # (this also matches ot_fps_cluster_large_np).
            s_geo = torch.exp(-dist.pow(2) / sigma2) * mask
            s_fea = torch.abs(torch.einsum("id,kd->ik", fea, c_fea))
            return s_geo * s_fea + tie_break

        labels, mass = sinkhorn_chunked_log(
            rows_fn,
            N,
            K,
            tau=tau,
            iters=sinkhorn_iters,
            chunk_p=chunk_P,
            dtype=xyz.dtype,
            device=device,
        )

        # Re-seed centres that attracted (almost) no mass at the points that
        # are farthest from the surviving centres.
        dead = mass < 1e-4
        if dead.any():
            live_xyz = c_xyz[~dead]
            dist_live = torch.cdist(xyz, live_xyz)
            far_idx = torch.topk(dist_live.min(1).values, dead.sum().item()).indices
            c_xyz[dead] = xyz[far_idx]
            c_fea[dead] = feats[far_idx]

        # Only points within r_spatial of their centre contribute to the update.
        dist_to_c = torch.norm(xyz - c_xyz[labels], dim=1)
        # Keep the working dtype: index_add_ requires matching dtypes.
        w = (dist_to_c <= r_spatial).to(xyz.dtype)

        c_xyz.zero_()
        c_fea.zero_()
        c_xyz.index_add_(0, labels, xyz * w.unsqueeze(1))
        c_fea.index_add_(0, labels, (feats * w.unsqueeze(1)).to(c_fea.dtype))
        mass = torch.zeros_like(mass).index_add_(0, labels, w.to(mass.dtype))

        live = mass > 0
        c_xyz[live] = c_xyz[live] / mass[live, None].to(c_xyz.dtype)
        c_fea[live] = F.normalize(c_fea[live] / mass[live, None].to(c_fea.dtype), dim=1)

    return labels


def ot_fps_cluster_large_np(
    xyz: np.ndarray,
    feats: np.ndarray,
    K: int = 64,
    r_spatial: float = 0.05,
    tau: float = 0.05,
    sinkhorn_iters: int = 10,
    outer_iters: int = 3,
    chunk_P: int = 128,
    use_scipy_cdist: bool = True,
):
    """
    NumPy version of ``ot_fps_cluster_large``.

    Args:
        xyz: ``(N, 3)`` coordinates.
        feats: ``(N, D)`` features (L2-normalised internally).
        use_scipy_cdist: Compute squared distances with SciPy instead of
            NumPy broadcasting.

    Returns:
        ``(N,)`` int64 array of cluster labels from the last assignment.
    """
    N = xyz.shape[0]
    feats = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-9)
    sigma2 = np.mean(np.linalg.norm(xyz - xyz.mean(0), axis=1) ** 2)
    EPS_TIE = 1e-6

    c_xyz, c_idx = fps_np(xyz, K)
    c_fea = feats[c_idx].copy()

    for _ in range(outer_iters):

        def rows_fn(i0, i1):
            pts = xyz[i0:i1]
            fea = feats[i0:i1]

            if use_scipy_cdist:
                d2 = scipy.spatial.distance.cdist(pts, c_xyz, metric="sqeuclidean")
            else:
                d2 = ((pts[:, None, :] - c_xyz[None, :, :]) ** 2).sum(-1)

            mask = d2 <= r_spatial ** 2
            # Points outside every ball stay attached to their nearest centre.
            empty = np.all(~mask, axis=1)
            if np.any(empty):
                nearest = np.argmin(d2[empty], axis=1)
                mask[empty, nearest] = True

            s_geo = np.exp(-d2 / sigma2) * mask
            s_fea = np.abs(fea @ c_fea.T)
            return s_geo * s_fea + EPS_TIE * np.arange(K)

        labels, mass = sinkhorn_chunked_log_np(
            rows_fn,
            N,
            K,
            tau=tau,
            iters=sinkhorn_iters,
            chunk_p=chunk_P,
            dtype=xyz.dtype,
        )

        # Re-seed centres that attracted (almost) no mass.
        dead = mass < 1e-4
        if np.any(dead):
            n_dead = int(dead.sum())
            live_xyz = c_xyz[~dead]
            dist2_live = ((xyz[:, None, :] - live_xyz[None, :, :]) ** 2).sum(-1)
            far_idx = np.argpartition(dist2_live.min(axis=1), -n_dead)[-n_dead:]
            c_xyz[dead] = xyz[far_idx]
            c_fea[dead] = feats[far_idx]
            mass[dead] = 1.0

        dist_to_c = np.linalg.norm(xyz - c_xyz[labels], axis=1)
        w = (dist_to_c <= r_spatial).astype(xyz.dtype)

        c_xyz[:] = 0
        c_fea[:] = 0
        for k in range(K):
            sel = labels == k
            if sel.any():
                c_xyz[k] = (xyz[sel] * w[sel, None]).sum(0)
                c_fea[k] = (feats[sel] * w[sel, None]).sum(0)
                mass[k] = w[sel].sum()

        live = mass > 0
        c_xyz[live] /= mass[live, None]
        c_fea[live] = c_fea[live] / (np.linalg.norm(c_fea[live], axis=1, keepdims=True) + 1e-9)

    return labels.astype(np.int64)


# -----------------------------------------------------------------------------
# high-level superpoint computation
# -----------------------------------------------------------------------------
def compute_superpoints(
    xyz,
    feats=None,
    n_d_xyz=2048,
    n_clus=64,
    r_geo=0.1,
    r_clus=0.05,
    device=torch.device("cuda"),
    method="ot",
    normals=None,
    pycut_kwargs=None,
):
    """
    method:
        'ot'    -> original OT/FPS clustering
        'pycut' -> L0-cut-pursuit superpoints via libcp

    Args:
        xyz: ``(N, 3)`` point coordinates (array-like; a tensor is also
            accepted on the 'ot' path).
        feats: Optional per-point features for the 'ot' path; computed with
            ``compute_geo_feats`` when omitted.
        n_d_xyz: Number of FPS reference points for the geometric features.
        n_clus: Number of clusters for the 'ot' path.
        r_geo: Neighbourhood radius for the geometric features.
        r_clus: Spatial radius for the 'ot' clustering.
        device: Device of the returned tensors.
        normals: Optional ``(N, 3)`` normals, used by the 'pycut' path.
        pycut_kwargs: Overrides for the 'pycut' defaults listed below.

    Returns:
        (labels, feats): per-point labels and the features used, as tensors
        on ``device``.
    """
    if method == "pycut":
        kw = dict(
            k_feat=10,
            k_adj=10,
            chunk_size=8192,
            use_input_normals=True,
            use_xyz=False,
            xyz_scale=0.10,
            normal_scale=0.25,
            lam=0.03,
            sigma=0.5,
            mutual=False,
            undirected=True,
            min_comp_weight=20,
            weight_decay=0.7,
            verbose=False,
        )
        if pycut_kwargs is not None:
            kw.update(pycut_kwargs)

        xyz_f = np.asarray(xyz, dtype=np.float32)
        xyz_norm = _normalize_xyz_np(xyz_f)

        geom_feat = _local_geom_features_chunked_np(
            xyz_norm,
            k_feat=kw["k_feat"],
            chunk_size=kw["chunk_size"],
        )

        feat_parts = [geom_feat]
        if kw["use_input_normals"] and normals is not None:
            unit_normals = _normalize_normals_np(np.asarray(normals, dtype=np.float32))
            feat_parts.append(unit_normals * kw["normal_scale"])
        if kw["use_xyz"]:
            feat_parts.append(xyz_norm * kw["xyz_scale"])

        Y = np.hstack(feat_parts).astype(np.float32)  # [N, D]

        src, dst = _build_adj_graph_np(
            xyz_norm,
            k_adj=kw["k_adj"],
            mutual=kw["mutual"],
            undirected=kw["undirected"],
        )
        # The helper indexes its first argument as [D, N]; the regularisation
        # strength kw["lam"] is handed to cutpursuit itself, hence lam=1.0.
        ew = _edge_weights_chunked_np(Y.T, src, dst, lam=1.0, sigma=kw["sigma"])

        try:
            import libcp
        except ImportError:
            # Fall back to a build directory located next to this file.
            libcp_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "_cut_pursuit",
                "build",
                "src",
            )
            sys.path.insert(0, libcp_dir)
            import libcp

        components, in_component = libcp.cutpursuit(
            Y,
            src.astype(np.uint32),
            dst.astype(np.uint32),
            ew.astype(np.float32),
            float(kw["lam"]),
            int(kw["min_comp_weight"]),
            0,
            float(kw["weight_decay"]),
        )

        lbl_np = _relabel_contiguous_np(np.asarray(in_component, dtype=np.int32))
        lbl = torch.from_numpy(lbl_np.astype(np.int64)).to(device)
        feats_out = torch.from_numpy(Y).to(device)
        return lbl, feats_out

    # default: OT/FPS clustering
    # Work in float32 throughout: point clouds often arrive as float64 (e.g.
    # from Open3D) and the OT pipeline needs one consistent dtype.
    if torch.is_tensor(xyz):
        P = xyz.to(device=device, dtype=torch.float32)
    else:
        P = torch.as_tensor(np.asarray(xyz, dtype=np.float32)).to(device)

    if feats is None:
        Q = P[fps(P, n_d_xyz)]
        feats = compute_geo_feats(P, Q, r=r_geo, chunk_p=4096, chunk_q=256)
    else:
        feats = torch.as_tensor(feats).to(device=device, dtype=P.dtype)

    lbl = ot_fps_cluster_large(P, feats, K=n_clus, r_spatial=r_clus)
    return lbl, feats


# -----------------------------------------------------------------------------
# constrained fps
# -----------------------------------------------------------------------------
def sp_constrained_fps(xyz: torch.Tensor, sp: torch.Tensor, L: int):
    """
    FPS that is guaranteed to sample every superpoint at least once.

    One random anchor point is drawn per superpoint; the remaining ``L - K``
    samples are added by farthest point sampling from those anchors.

    Args:
        xyz: ``(N, 3)`` coordinates.
        sp: ``(N,)`` superpoint IDs covering ``0..K-1`` without gaps.
        L: Total number of samples, ``L >= K``.

    Returns:
        ``(L,)`` long tensor; the first ``K`` entries are the anchors ordered
        by superpoint ID.

    Raises:
        ValueError: If some ID in ``0..K-1`` has no point.
    """
    device, N = xyz.device, xyz.size(0)
    sp = sp.long()
    K = int(sp.max().item()) + 1
    assert L >= K

    # Reduce over *positions* in a random permutation and map back through
    # it: reducing over the point indices themselves would always return the
    # lowest-index point of every superpoint, i.e. nothing random at all.
    perm = torch.randperm(N, device=device)
    first_pos = torch.full((K,), N, dtype=torch.long, device=device)
    first_pos.scatter_reduce_(0, sp[perm], torch.arange(N, device=device), reduce="amin")
    if bool((first_pos == N).any()):
        raise ValueError("superpoint labels must cover 0..K-1 without gaps")
    anchors = perm[first_pos]

    idx = torch.empty(L, dtype=torch.long, device=device)
    idx[:K] = anchors

    dist2 = torch.cdist(xyz, xyz[anchors]).pow(2).min(dim=1).values

    for i in range(K, L):
        nxt = torch.argmax(dist2)
        idx[i] = nxt
        dist2 = torch.minimum(dist2, ((xyz - xyz[nxt]) ** 2).sum(-1))

    return idx


def batch_random_anchor(sp: torch.Tensor):
    """
    Draw one random anchor point per superpoint and batch item.

    Args:
        sp: ``(B, N)`` superpoint IDs.

    Returns:
        ``(B, K)`` long tensor with ``K = sp.max() + 1``. Entry ``[b, s]`` is
        the index of a uniformly drawn point of superpoint ``s`` in item
        ``b``, or ``N`` if that superpoint has no point in the item.
    """
    device = sp.device
    B, N = sp.shape
    sp = sp.long()
    K = int(sp.max().item()) + 1

    perm = torch.argsort(torch.rand(B, N, device=device), dim=1)
    labels_perm = sp.gather(1, perm)

    # First position of every label inside the random permutation.
    positions = torch.arange(N, device=device).expand(B, N).contiguous()
    first_pos = torch.full((B, K), N, dtype=torch.long, device=device)
    first_pos.scatter_reduce_(1, labels_perm, positions, reduce="amin")

    missing = first_pos == N
    anchors = perm.gather(1, first_pos.clamp(max=N - 1))
    # Keep the documented sentinel N for superpoints absent from an item.
    anchors = torch.where(missing, torch.full_like(anchors, N), anchors)
    return anchors


def sp_constrained_fps_batch(
    xyz: torch.Tensor,
    sp: torch.Tensor,
    L: int,
) -> torch.LongTensor:
    """
    Batched version of ``sp_constrained_fps``.

    Args:
        xyz: ``(B, N, 3)`` coordinates.
        sp: ``(B, N)`` superpoint IDs; every item must contain all IDs
            ``0..K-1`` with ``K = sp.max() + 1``.
        L: Samples per item, ``L >= K``.

    Returns:
        ``(B, L)`` long tensor; the first ``K`` columns are the anchors.

    Raises:
        ValueError: If some batch item lacks one of the IDs ``0..K-1``.
    """
    device = xyz.device
    B, N, _ = xyz.shape
    K = int(sp.max().item()) + 1
    assert L >= K

    anchors = batch_random_anchor(sp)
    if bool((anchors == N).any()):
        raise ValueError("every batch item must contain all superpoint IDs 0..K-1")

    idx = torch.empty(B, L, dtype=torch.long, device=device)
    idx[:, :K] = anchors

    c_xyz = xyz.gather(1, anchors[..., None].expand(-1, -1, 3))
    # Square the anchor distances so they are comparable with the squared
    # distances used by the update inside the loop.
    dist2 = (
        torch.cdist(xyz, c_xyz, p=2, compute_mode="donot_use_mm_for_euclid_dist")
        .pow(2)
        .min(dim=2)
        .values
    )

    b_ids = torch.arange(B, device=device)
    for i in range(K, L):
        farthest = dist2.max(dim=1).indices
        idx[:, i] = farthest
        new_c = xyz[b_ids, farthest]
        dist2 = torch.minimum(dist2, ((xyz - new_c[:, None, :]) ** 2).sum(-1))

    return idx


def memory_efficient_fps_prob(
    xyz: Tensor,
    prob: Tensor,
    k: int,
    gamma: float = 0.5,
    eps: float = 1e-12,
    chunk_size: int = 32,
) -> Tensor:
    """
    Probability-weighted farthest point sampling.

    Starts from the highest-probability point; every following sample
    maximises ``squared_distance * (prob + eps) ** -gamma``, so less probable
    points are favoured.

    Args:
        xyz: ``(B, N, 3)`` coordinates.
        prob: ``(B, N)`` per-point probabilities.
        k: Number of samples.
        chunk_size: Number of points whose distances are updated at once.

    Returns:
        ``(B, k)`` long tensor of sampled indices.
    """
    B, N, _ = xyz.shape
    device = xyz.device

    dist = torch.full((B, N), float("inf"), device=device, dtype=xyz.dtype)
    idx = torch.zeros(B, k, dtype=torch.long, device=device)
    idx[:, 0] = prob.argmax(dim=1)

    for i in range(1, k):
        last_xyz = xyz.gather(1, idx[:, i - 1:i].unsqueeze(-1).expand(-1, -1, 3))
        min_dist = torch.full((B, N), float("inf"), device=device, dtype=xyz.dtype)

        for chunk_start in range(0, N, chunk_size):
            chunk_end = min(chunk_start + chunk_size, N)
            chunk_xyz = xyz[:, chunk_start:chunk_end, :]
            chunk_dist = ((chunk_xyz - last_xyz) ** 2).sum(-1)
            min_dist[:, chunk_start:chunk_end] = torch.minimum(
                dist[:, chunk_start:chunk_end],
                chunk_dist,
            )

        dist = min_dist
        score = dist * (prob + eps) ** -gamma
        idx[:, i] = score.argmax(dim=1)

    return idx


# -----------------------------------------------------------------------------
# superpoint
# -----------------------------------------------------------------------------
# superpoint smoothing
# -----------------------------------------------------------------------------
def _superpoint_pool(feat: torch.Tensor, spts: torch.Tensor) -> torch.Tensor:
    """Replace each point feature by the mean feature of its superpoint.

    Pooling is done per batch row: labels are shifted by ``b * K`` (with
    ``K = spts.max() + 1``) so that equal labels in different rows land in
    different buckets.

    Args:
        feat: ``(B, N, D)`` point features.
        spts: ``(B, N)`` non-negative integer superpoint labels.

    Returns:
        ``(B, N, D)`` tensor in which every point carries the average feature
        of all points sharing its label within the same batch row.

    Raises:
        ValueError: If ``spts`` contains a negative label. Such a label would
            otherwise be pooled into the previous row's last bucket.
    """
    B, N, D = feat.shape
    if B == 0 or N == 0:
        # Nothing to pool; `spts.max()` is undefined on an empty tensor.
        return feat.clone()

    if int(spts.min().item()) < 0:
        raise ValueError("_superpoint_pool expects non-negative superpoint labels")
    K = int(spts.max().item()) + 1

    # Give every batch row its own contiguous range of K buckets.
    offsets = (torch.arange(B, device=feat.device) * K).view(B, 1)
    sp_offset = (spts + offsets).reshape(-1)
    tot_sp = B * K

    feat_sum = torch.zeros(tot_sp, D, device=feat.device, dtype=feat.dtype)
    feat_sum.index_add_(0, sp_offset, feat.reshape(B * N, D))

    cnt_sum = torch.zeros(tot_sp, 1, device=feat.device, dtype=feat.dtype)
    cnt_sum.index_add_(
        0,
        sp_offset,
        torch.ones(B * N, 1, device=feat.device, dtype=feat.dtype),
    )

    # Buckets that are gathered below always hold at least one point, so a
    # clamp only protects the unused ones from 0/0 and, unlike an additive
    # epsilon, leaves the real means unbiased.
    feat_avg = feat_sum / cnt_sum.clamp(min=1)
    return feat_avg[sp_offset].view(B, N, D)


# -----------------------------------------------------------------------------
# example
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Demo-only dependencies are imported here so that importing this module
    # does not require them.
    from lib_vis import create_labeled_point_cloud
    import open3d as o3d

    pcd = o3d.io.read_point_cloud("s3dis.ply")
    print(pcd)
    xyz = np.array(pcd.points)

    # Normals are optional: they are used only when the file is present.
    normals = None
    normals_path = "../normal.npy"
    if os.path.exists(normals_path):
        normals = np.load(normals_path).astype(np.float32)

    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---------------------------------------------------------
    # Option 1: pycut-pursuit superpoints
    # ---------------------------------------------------------
    # NOTE(review): compute_superpoints, sp_constrained_fps and superpoint_fps
    # are not defined in this part of the file; they are expected to be
    # module-level names defined elsewhere in it.
    lbl, feats = compute_superpoints(
        xyz,
        normals=normals,
        device=device,
        method="pycut",
        pycut_kwargs=dict(
            k_feat=10,
            k_adj=10,
            chunk_size=8192,
            use_input_normals=(normals is not None),
            use_xyz=False,
            xyz_scale=0.10,
            normal_scale=0.25,
            lam=5.0,
            sigma=0.5,
            mutual=False,
            undirected=True,
            cp_it_max=10,
            split_iter_num=2,
            split_damp_ratio=0.7,
            kmpp_init_num=3,
            kmpp_iter_num=3,
            K=2,
            min_comp_weight=20.0,
            verbose=False,
        ),
    )

    # From here on `pcd` is the coordinate tensor, not the open3d point cloud.
    pcd = torch.from_numpy(xyz).to(device)
    idx = sp_constrained_fps(pcd, lbl, 2048).tolist()

    down_pcd = pcd[idx]
    down_lbl = lbl[idx]
    create_labeled_point_cloud(xyz[idx], lbl[idx].tolist(), name="cls_pycut")

    # Second stage: sample 128 of the 2048 retained points.
    down_idx = superpoint_fps(
        down_pcd.unsqueeze(0),
        down_lbl.unsqueeze(0),
        k=128,
        gamma=0.1,
        base_scale=1.5,
    )[0].tolist()

    create_labeled_point_cloud(
        down_pcd[down_idx].cpu().numpy(),
        down_lbl[down_idx].tolist(),
        name="fps_pycut",
    )
    print(torch.unique(down_lbl[down_idx], return_counts=True))