| import os |
| import sys |
| from typing import Optional, Tuple |
|
|
| import numpy as np |
| import scipy.spatial |
| import torch |
| import torch.nn.functional as F |
| from scipy.special import logsumexp |
| from scipy.spatial import cKDTree |
| from torch import Tensor |
|
|
| try: |
| import segmentator |
| except ImportError: |
| segmentator = None |
|
|
|
|
| parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
| sys.path.insert(0, parent_dir) |
|
|
|
|
| |
| |
| |
def num_to_natural_numpy(group_ids, void_number=-1):
    """
    Convert group IDs to contiguous natural numbers, preserving the void label.

    Non-void IDs are ranked by their sorted value. With ``void_number=-1`` they
    become ``0..K-1``; with ``void_number=0`` they become ``1..K`` so that 0
    stays reserved for void.

    Args:
        group_ids (array-like): Input group labels, e.g., [-1, 0, 3, 4, 0, 6].
        void_number (int): The label for 'void' class (only supports -1 or 0).

    Returns:
        np.ndarray: Integer array shaped like ``group_ids`` holding the
        remapped IDs, with void entries left untouched.

    Raises:
        ValueError: If ``void_number`` is neither -1 nor 0.
    """
    group_ids = np.asarray(group_ids, dtype=int)

    if void_number not in (-1, 0):
        raise ValueError("void_number must be either -1 or 0.")

    valid_mask = group_ids != void_number
    result = group_ids.copy()
    if not valid_mask.any():
        return result

    # Rank the IDs with np.unique instead of a lookup table indexed by raw ID:
    # a table breaks on negative non-void IDs (wrap-around indexing) and needs
    # O(max_id) memory.
    _, inverse = np.unique(group_ids[valid_mask], return_inverse=True)
    first_id = 0 if void_number == -1 else 1
    result[valid_mask] = inverse.reshape(-1) + first_id
    return result
|
|
|
|
def num_to_natural_torch(group_ids, void_number=-1):
    """
    Torch counterpart of ``num_to_natural_numpy``.

    Non-void IDs are ranked by their sorted value. With ``void_number=-1`` they
    become ``0..K-1``; with ``void_number=0`` they become ``1..K``. Void
    entries keep their value.

    Args:
        group_ids (torch.Tensor): Integer-valued label tensor.
        void_number (int): The label for 'void' class (only supports -1 or 0).

    Returns:
        torch.Tensor: New ``long`` tensor on the same device with remapped IDs.

    Raises:
        ValueError: If ``void_number`` is neither -1 nor 0.
    """
    if void_number not in (-1, 0):
        raise ValueError("void_number must be -1 or 0")

    ids = group_ids.long()
    valid_mask = ids != void_number
    # Always hand back a fresh tensor so the caller's input is never aliased.
    result = ids.clone()
    if not bool(valid_mask.any()):
        return result

    # Rank via torch.unique rather than a table indexed by raw ID; the table
    # let negative non-void IDs wrap around and collide with real entries.
    _, inverse = torch.unique(ids[valid_mask], return_inverse=True)
    first_id = 0 if void_number == -1 else 1
    result[valid_mask] = inverse + first_id
    return result
|
|
|
|
| |
| |
| |
def gen_superpoints(points, normals, k=50, kThresh=0.01, segMinVerts=20):
    """
    Segment a point cloud into superpoints with the optional ``segmentator``.

    A k-nearest-neighbour graph is built with ``torch_cluster.knn_graph`` and
    its transposed edge list is passed to ``segmentator.segment_point``.

    Args:
        points: Point coordinates, forwarded to ``knn_graph`` and
            ``segment_point``.
        normals: Per-point normals, forwarded to ``segment_point``.
        k (int): Number of neighbours for the kNN graph.
        kThresh (float): Forwarded unchanged to ``segment_point``.
        segMinVerts (int): Forwarded unchanged to ``segment_point``.

    Returns:
        Whatever ``segmentator.segment_point`` returns
        (presumably one superpoint id per point — verify against segmentator).

    Raises:
        ImportError: If ``segmentator`` or ``torch_cluster`` is not installed.
    """
    message = "segmentator and torch_cluster are required for gen_superpoints()."
    # The import itself raises when torch_cluster is missing, so catch it here
    # to report the intended message; a later `is None` test can never fire.
    try:
        from torch_cluster import knn_graph
    except ImportError as err:
        raise ImportError(message) from err
    if segmentator is None:
        raise ImportError(message)

    edges = knn_graph(points, k=k).T
    superpoint = segmentator.segment_point(points, normals, edges, kThresh, segMinVerts)
    return superpoint
|
|
|
|
|
|
|
|
| def _normalize_xyz_np(xyz: np.ndarray) -> np.ndarray: |
| xyz = np.asarray(xyz, dtype=np.float32) |
| center = xyz.mean(axis=0, keepdims=True) |
| xyz0 = xyz - center |
| bbmin = xyz0.min(axis=0) |
| bbmax = xyz0.max(axis=0) |
| diag = np.linalg.norm(bbmax - bbmin) |
| if diag < 1e-12: |
| diag = 1.0 |
| return xyz0 / diag |
|
|
|
|
| def _normalize_normals_np(normals: np.ndarray) -> np.ndarray: |
| normals = np.asarray(normals, dtype=np.float32) |
| nrm = np.linalg.norm(normals, axis=1, keepdims=True) |
| nrm = np.clip(nrm, 1e-12, None) |
| return normals / nrm |
|
|
|
|
| def _build_knn_np(xyz: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]: |
| n = xyz.shape[0] |
| if n <= 1: |
| return ( |
| np.empty((n, 0), dtype=np.float32), |
| np.empty((n, 0), dtype=np.int64), |
| ) |
|
|
| k_eff = min(k + 1, n) |
| tree = cKDTree(xyz) |
| dists, inds = tree.query(xyz, k=k_eff, workers=-1) |
| return dists[:, 1:], inds[:, 1:] |
|
|
|
|
def _local_geom_features_chunked_np(
    xyz: np.ndarray,
    k_feat: int = 10,
    chunk_size: int = 8192,
) -> np.ndarray:
    """
    Compute five per-point geometric descriptors from kNN covariance.

    Output columns: linearity, planarity, scattering, verticality, elevation.
    The first four come from the eigen-decomposition of each point's
    neighbourhood covariance, processed ``chunk_size`` points at a time.
    Elevation is the z coordinate min-max scaled over the whole cloud.

    Returns:
        ``(n, 5)`` float32 array. When no neighbours are available the first
        four columns are set to ``0, 0, 1, 1``.
    """
    tiny = 1e-12
    _, neighbor_idx = _build_knn_np(xyz, k_feat)
    num_points = xyz.shape[0]
    num_nbrs = neighbor_idx.shape[1]

    out = np.empty((num_points, 5), dtype=np.float32)

    # Elevation: z rescaled to [0, 1]; a flat cloud gets all zeros.
    height = xyz[:, 2]
    lowest, highest = height.min(), height.max()
    if highest - lowest < tiny:
        out[:, 4] = np.zeros(num_points, dtype=np.float32)
    else:
        out[:, 4] = ((height - lowest) / (highest - lowest)).astype(np.float32)

    if num_nbrs == 0:
        out[:, 0] = 0.0
        out[:, 1] = 0.0
        out[:, 2] = 1.0
        out[:, 3] = 1.0
        return out

    for start in range(0, num_points, chunk_size):
        stop = min(start + chunk_size, num_points)
        nbr_pts = xyz[neighbor_idx[start:stop]]
        centered = nbr_pts - nbr_pts.mean(axis=1, keepdims=True)

        cov = np.matmul(centered.transpose(0, 2, 1), centered) / float(max(num_nbrs, 1))
        # eigh returns eigenvalues in ascending order.
        eigvals, eigvecs = np.linalg.eigh(cov.astype(np.float64))
        eigvals = np.clip(eigvals, tiny, None)

        lam_small = eigvals[:, 0]
        lam_mid = eigvals[:, 1]
        lam_big = eigvals[:, 2]
        scale = np.maximum(lam_big, tiny)

        out[start:stop, 0] = ((lam_big - lam_mid) / scale).astype(np.float32)
        out[start:stop, 1] = ((lam_mid - lam_small) / scale).astype(np.float32)
        out[start:stop, 2] = (lam_small / scale).astype(np.float32)
        # Eigenvector of the smallest eigenvalue serves as the local normal.
        out[start:stop, 3] = (1.0 - np.abs(eigvecs[:, 2, 0])).astype(np.float32)

    return out
|
|
|
|
def _build_adj_graph_np(
    xyz: np.ndarray,
    k_adj: int = 10,
    mutual: bool = False,
    undirected: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build a kNN adjacency edge list as two parallel uint32 arrays.

    Args:
        xyz: Point coordinates passed to the kNN query.
        k_adj: Number of neighbours per point.
        mutual: Keep only edges whose reverse edge is also present.
        undirected: Append the reversed copy of every remaining edge.

    Returns:
        ``(src, dst)`` uint32 arrays; both empty when no neighbours exist.
    """
    _, neighbor_idx = _build_knn_np(xyz, k_adj)
    num_points, num_nbrs = neighbor_idx.shape

    if num_nbrs == 0:
        return np.empty((0,), dtype=np.uint32), np.empty((0,), dtype=np.uint32)

    heads = np.repeat(np.arange(num_points, dtype=np.uint32), num_nbrs)
    tails = neighbor_idx.reshape(-1).astype(np.uint32)

    # Drop self loops.
    not_loop = heads != tails
    heads, tails = heads[not_loop], tails[not_loop]

    if mutual:
        # Encode each directed edge as one integer so that reverse edges can
        # be matched with a set-membership test.
        stride = np.uint64(num_points)
        forward = heads.astype(np.uint64) * stride + tails.astype(np.uint64)
        backward = tails.astype(np.uint64) * stride + heads.astype(np.uint64)
        reciprocal = np.isin(forward, backward, assume_unique=False)
        heads, tails = heads[reciprocal], tails[reciprocal]

    if undirected:
        heads, tails = (
            np.concatenate([heads, tails], axis=0),
            np.concatenate([tails, heads], axis=0),
        )

    return heads.astype(np.uint32, copy=False), tails.astype(np.uint32, copy=False)
|
|
|
|
| def _edge_weights_chunked_np( |
| Y: np.ndarray, |
| src: np.ndarray, |
| dst: np.ndarray, |
| lam: float = 5.0, |
| sigma: float = 0.5, |
| chunk_size: int = 1_000_000, |
| ) -> np.ndarray: |
| sigma2 = max(sigma * sigma, 1e-12) |
| num_edges = src.shape[0] |
| ew = np.empty(num_edges, dtype=np.float32) |
| f = Y.T |
|
|
| for s in range(0, num_edges, chunk_size): |
| e = min(s + chunk_size, num_edges) |
| diff = f[src[s:e]] - f[dst[s:e]] |
| dist2 = np.sum(diff * diff, axis=1) |
| ew[s:e] = lam * np.exp(-dist2 / sigma2) |
|
|
| return ew |
|
|
|
|
| def _edges_to_forward_star( |
| n: int, |
| src: np.ndarray, |
| dst: np.ndarray, |
| ew: np.ndarray, |
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: |
| if src.size == 0: |
| first_edge = np.zeros(n + 1, dtype=np.uint32) |
| adj_vertices = np.empty((0,), dtype=np.uint32) |
| edge_weights = np.empty((0,), dtype=np.float32) |
| return first_edge, adj_vertices, edge_weights |
|
|
| order = np.argsort(src, kind="stable") |
| src = src[order] |
| dst = dst[order] |
| ew = ew[order] |
|
|
| counts = np.bincount(src.astype(np.int64), minlength=n) |
| first_edge = np.zeros(n + 1, dtype=np.uint32) |
| first_edge[1:] = np.cumsum(counts, dtype=np.uint32) |
|
|
| return ( |
| first_edge, |
| dst.astype(np.uint32, copy=False), |
| ew.astype(np.float32, copy=False), |
| ) |
|
|
|
|
| def _relabel_contiguous_np(labels: np.ndarray) -> np.ndarray: |
| _, inv = np.unique(labels, return_inverse=True) |
| return inv.astype(np.int32) |
|
|
|
|
| |
| |
| |
def superpoint_fps(
    xyz: Tensor,
    labels: Tensor,
    k: int,
    min_segment_points: int = 10,
    gamma: float = 0.7,
    eps: float = 1e-4,
    base_scale: float = 1.0,
    deterministic_start: bool = False,
    return_probs: bool = False,
):
    """
    Farthest-point sampling biased towards points of small segments.

    Each FPS step picks ``argmax(dist * weight)``: ``dist`` is the squared
    distance to the nearest already-selected point, and ``weight`` grows as the
    point's segment gets smaller. Points whose segment has fewer than
    ``min_segment_points`` members are ineligible and are never selected.

    Args:
        xyz: ``(B, N, C)`` point coordinates.
        labels: ``(B, N)`` non-negative segment ids.
        k: Number of points to sample per batch item (``1 <= k <= N``).
        min_segment_points: Minimum segment size for its points to be eligible.
        gamma: Exponent applied to the inverse segment size when building the
            bias weights.
        eps: Small constant added to counts and used as a clamp floor.
        base_scale: Scales the minimum weight given to the largest segments.
        deterministic_start: Start from the highest-probability point instead
            of sampling the first point.
        return_probs: Also return the per-point start probabilities.

    Returns:
        ``idx`` of shape ``(B, k)``, or ``(idx, probs)`` if ``return_probs``.

    Raises:
        ValueError: On a shape mismatch, negative labels, ``k`` out of range,
            or fewer than ``k`` eligible points in some batch item.
    """
    B, N, _ = xyz.shape
    if labels.shape != (B, N):
        raise ValueError(f"labels shape must be (B, N) = {(B, N)}, got {labels.shape}")
    if (labels < 0).any():
        raise ValueError("labels must be non-negative")
    if not (0 < k <= N):
        raise ValueError(f"k must be in 1…N={N}, got {k}")

    device, dtype = xyz.device, xyz.dtype
    long_labels = labels.long()

    # Per-batch histogram of segment sizes (eps keeps later divisions finite).
    max_id_val = long_labels.max().item() if N else 0
    counts = torch.zeros(B, max_id_val + 1, device=device, dtype=torch.float)
    counts.scatter_add_(1, long_labels, torch.ones_like(long_labels, dtype=torch.float))
    counts = counts + eps

    seg_large_enough = counts >= float(min_segment_points)
    eligible_mask = seg_large_enough.gather(1, long_labels)

    if (eligible_mask.sum(dim=1) < k).any():
        raise ValueError(
            "At least one batch item has fewer than k eligible points "
            f"(threshold={min_segment_points})."
        )

    # Start probabilities: inversely proportional to the point's segment size.
    label_counts = counts.gather(1, long_labels)
    inv_probs = torch.where(
        eligible_mask,
        1.0 / label_counts,
        torch.zeros_like(label_counts),
    )
    inv_sum = inv_probs.sum(dim=1, keepdim=True).clamp(min=eps)
    probs = inv_probs / inv_sum

    # Ratio between the smallest and largest present segment sets the floor of
    # the per-point weight, so large segments are damped but never ignored.
    present_mask = counts > (eps + 1e-9)
    counts_for_min = torch.where(
        present_mask, counts, torch.full_like(counts, float("inf"))
    )
    raw_min = counts_for_min.min(1, keepdim=True).values
    min_counts_b = torch.where(torch.isinf(raw_min), torch.ones_like(raw_min), raw_min)
    counts_for_max = torch.where(present_mask, counts, torch.zeros_like(counts))
    max_counts_b = counts_for_max.max(1, keepdim=True).values.clamp(min=eps)
    size_ratio_b = (min_counts_b / max_counts_b).clamp(1e-3, 1.0)
    min_weight_b = (base_scale * size_ratio_b).clamp(min=1e-5)

    score_bias_weights = inv_probs ** gamma
    max_bias = score_bias_weights.max(1, keepdim=True)[0].clamp(min=eps)
    score_bias_weights = score_bias_weights / max_bias
    prob_weights = (1.0 - min_weight_b) * score_bias_weights + min_weight_b
    prob_weights = torch.where(eligible_mask, prob_weights, torch.zeros_like(prob_weights))

    # Ineligible points get distance 0 so their score can never win an argmax.
    dist = torch.full((B, N), float("inf"), device=device, dtype=dtype)
    dist = torch.where(eligible_mask, dist, torch.zeros_like(dist))
    idx = torch.zeros(B, k, dtype=torch.long, device=device)
    batch_idx = torch.arange(B, device=device)

    if deterministic_start:
        idx[:, 0] = probs.argmax(1)
    else:
        # Smooth only the eligible entries. Adding eps to every entry would
        # give ineligible points a non-zero chance of being the first sample.
        safe_probs = torch.where(eligible_mask, probs + eps, torch.zeros_like(probs))
        safe_probs = safe_probs / safe_probs.sum(1, keepdim=True)
        idx[:, 0] = torch.multinomial(safe_probs, 1).squeeze(1)

    selected_xyz = torch.zeros(B, k, xyz.shape[-1], device=device, dtype=dtype)
    selected_xyz[batch_idx, 0] = xyz[batch_idx, idx[:, 0]]
    dist = torch.minimum(
        dist,
        ((xyz - selected_xyz[:, 0:1]).pow(2)).sum(-1),
    )

    for i in range(1, k):
        scores = dist * prob_weights
        idx[:, i] = scores.argmax(1)
        selected_xyz[batch_idx, i] = xyz[batch_idx, idx[:, i]]
        dist = torch.minimum(
            dist,
            ((xyz - selected_xyz[:, i:i + 1]).pow(2)).sum(-1),
        )
        dist = torch.where(eligible_mask, dist, torch.zeros_like(dist))

    return (idx, probs) if return_probs else idx
|
|
|
|
| |
| |
| |
def get_spt_centers(
    x: torch.Tensor,
    spts: torch.Tensor,
    reduce: str = "mean",
):
    """
    Pool per-point features into per-superpoint features for a padded batch.

    Superpoint ids are offset per batch item so the whole batch can be reduced
    with one flat scatter. The result is then laid out as ``(B, K, C)``, where
    ``K`` is the largest per-item superpoint count (``max id + 1``).

    Args:
        x: ``(B, N, C)`` point features.
        spts: ``(B, N)`` superpoint id of every point (cast to ``long``).
        reduce: ``"mean"`` or ``"max"``.

    Returns:
        Tuple ``(pooled_BKC, mask, sp_ids_local, counts_BK)``:
        pooled features ``(B, K, C)`` (zero where ``mask`` is False);
        boolean ``(B, K)`` mask of slots that received at least one point;
        local superpoint id per slot, ``-1`` for padding; and the number of
        points per slot.

    Raises:
        ValueError: If ``reduce`` is not ``"mean"`` or ``"max"``.
    """
    B, N, C = x.shape
    dev = x.device
    spts = spts.long()

    # Exclusive cumulative sum of per-item superpoint counts gives the offset
    # that makes ids unique across the batch.
    sp_counts = spts.amax(dim=1) + 1
    offsets = torch.cat(
        [torch.zeros(1, dtype=sp_counts.dtype, device=dev), sp_counts.cumsum(0)]
    )[:-1]
    spts_global = spts + offsets[:, None]

    x_flat = x.reshape(-1, C)
    spts_flat = spts_global.reshape(-1)
    tot_sp = int((offsets[-1] + sp_counts[-1]).item())

    pooled = torch.zeros(tot_sp, C, dtype=x.dtype, device=dev)
    counts = torch.zeros(tot_sp, 1, dtype=x.dtype, device=dev)

    if reduce == "mean":
        pooled.index_add_(0, spts_flat, x_flat)
        counts.index_add_(0, spts_flat, torch.ones_like(x_flat[:, :1]))
        pooled = pooled / counts.clamp(min=1e-6)
    elif reduce == "max":
        pooled.fill_(-float("inf"))
        pooled = pooled.index_reduce(0, spts_flat, x_flat, reduce="amax")
        counts.index_add_(0, spts_flat, torch.ones_like(x_flat[:, :1]))
    else:
        raise ValueError(f"Unsupported reduction: {reduce}")

    K = int(sp_counts.max().item())
    row_ids = torch.arange(K, device=dev).unsqueeze(0).expand(B, -1)
    valid_mask = row_ids < sp_counts.unsqueeze(1)
    # Padding slots point at an extra all-zero row appended below.
    gather_ix = torch.where(
        valid_mask,
        offsets[:, None] + row_ids,
        torch.full_like(row_ids, fill_value=tot_sp),
    )

    pooled_ext = torch.cat([pooled, torch.zeros(1, C, device=dev, dtype=x.dtype)], dim=0)
    counts_ext = torch.cat([counts, torch.zeros(1, 1, device=dev, dtype=x.dtype)], dim=0)

    pooled_BKC = pooled_ext[gather_ix]
    counts_BK = counts_ext[gather_ix].squeeze(-1)
    mask = counts_BK > 0
    # Select rather than multiply: for reduce="max" an unused id still holds
    # -inf, and -inf * 0 would turn into NaN.
    pooled_BKC = torch.where(mask.unsqueeze(-1), pooled_BKC, torch.zeros_like(pooled_BKC))

    sp_ids_local = torch.where(
        valid_mask,
        row_ids,
        torch.full_like(row_ids, -1, dtype=torch.long),
    )

    return pooled_BKC, mask, sp_ids_local, counts_BK
|
|
|
|
| |
| |
| |
def masked_pairwise_distance(pc1, pc2, mask1, mask2, invalid_val=1e6):
    """
    Batched Euclidean distance matrix with invalid pairs overwritten.

    Distances come from the expansion ``|a|^2 - 2 a.b + |b|^2`` (clamped at 0
    before the square root). Entry ``[b, i, j]`` is replaced by
    ``invalid_val`` whenever ``mask1[b, i] * mask2[b, j]`` is zero.
    """
    sq_norm_1 = (pc1 ** 2).sum(dim=2, keepdim=True)
    sq_norm_2 = (pc2 ** 2).sum(dim=2).unsqueeze(1)
    cross_term = torch.bmm(pc1, pc2.transpose(1, 2))
    dists = torch.clamp(sq_norm_1 - 2 * cross_term + sq_norm_2, min=0).sqrt()

    pair_valid = (mask1.unsqueeze(2) * mask2.unsqueeze(1)).bool()
    return torch.where(pair_valid, dists, torch.full_like(dists, invalid_val))
|
|
|
|
def coverage_ratio(samples: Tensor, labels: Tensor) -> float:
    """
    Fraction of distinct labels that are hit by the sampled indices.

    ``samples`` holds indices into dimension 1 of ``labels``. Uniqueness is
    taken over the whole tensor, not per batch row.
    """
    sampled_labels = torch.gather(labels, 1, samples)
    num_covered = torch.unique(sampled_labels).numel()
    num_present = torch.unique(labels).numel()
    return num_covered / num_present
|
|
|
|
| |
| |
| |
@torch.no_grad()
def fps(xyz: torch.Tensor, k: int) -> torch.LongTensor:
    """
    Farthest-point sampling on a single point set.

    Starts from a uniformly random point and repeatedly adds the point whose
    squared distance to the already-selected set is largest.

    Args:
        xyz: ``(N, C)`` point coordinates.
        k: Number of indices to select.

    Returns:
        ``(k,)`` long tensor of selected point indices (empty when ``k == 0``).
    """
    N, dev = xyz.size(0), xyz.device
    sel = torch.empty(k, dtype=torch.long, device=dev)
    if k == 0:
        return sel

    sel[0] = torch.randint(0, N, (1,), device=dev)
    # Start at +inf rather than a finite sentinel: a finite cap silently
    # truncates squared distances above it and corrupts the argmax for clouds
    # with large extents.
    dist2 = torch.full((N,), float("inf"), device=dev)
    for i in range(1, k):
        d = ((xyz - xyz[sel[i - 1]]) ** 2).sum(1)
        dist2 = torch.minimum(dist2, d)
        sel[i] = torch.argmax(dist2)
    return sel
|
|
|
|
def fps_np(point, npoint):
    """
    Farthest-point sampling with NumPy.

    Only the first three columns of ``point`` are used as coordinates; any
    further columns are carried along in the returned rows.

    Args:
        point: ``(N, D)`` array.
        npoint: Number of points to select.

    Returns:
        ``(sampled, ids)``: the selected rows of ``point`` and their int32
        indices, in selection order.
    """
    num_points = point.shape[0]
    xyz = point[:, :3]
    ids = np.zeros((npoint,), dtype=np.int32)
    # +inf instead of a finite sentinel, so squared distances are never capped
    # (a cap breaks the argmax on clouds with large extents).
    distance = np.full((num_points,), np.inf)
    farthest = np.random.randint(0, num_points)
    for i in range(npoint):
        ids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        closer = dist < distance
        distance[closer] = dist[closer]
        farthest = np.argmax(distance, -1)
    return point[ids], ids
|
|
|
|
| |
| |
| |
def calc_ppf_np(points, point_normals, patches, patch_normals):
    """
    Point-pair features between each reference point and its patch points.

    For reference point ``p`` with normal ``n`` and patch point ``q`` with
    normal ``m`` the feature is ``[|d|, ang(n, d), ang(m, d), ang(n, m)]``
    with ``d = q - p``. Angles are ``atan2(|a x b|, a . b) / pi``.

    Returns:
        Array with the leading shape of ``patches`` and 4 features last.
    """
    def _angle(a, b):
        # Angle between two vector fields, scaled by 1/pi.
        cosine = np.sum(a * b, axis=-1, keepdims=True)
        sine = np.linalg.norm(np.cross(a, b), axis=-1, keepdims=True)
        return np.arctan2(sine, cosine) / np.pi

    _, num_samples, _ = patches.shape

    # Tile every reference point / normal across its patch.
    anchors = np.repeat(np.expand_dims(points, axis=1), num_samples, axis=1)
    anchor_normals = np.repeat(np.expand_dims(point_normals, axis=1), num_samples, axis=1)

    offset = patches - anchors
    length = np.linalg.norm(offset, axis=-1, keepdims=True)

    return np.concatenate(
        [
            length,
            _angle(anchor_normals, offset),
            _angle(patch_normals, offset),
            _angle(anchor_normals, patch_normals),
        ],
        axis=-1,
    )
|
|
|
|
def calc_ppf_gpu(points, point_normals, patches, patch_normals):
    """
    Torch version of the point-pair features for ``(N, S, 3)`` patches.

    Returns ``[|d|, ang(n, d), ang(m, d), ang(n, m)]`` along the last axis,
    where ``d`` is the offset from the reference point to the patch point and
    each angle is ``atan2(|a x b|, a . b) / pi``.
    """
    def _angle(a, b):
        cosine = torch.sum(a * b, dim=-1, keepdim=True)
        cross = torch.cross(a, b, dim=-1)
        sine = torch.sqrt(torch.sum(cross ** 2, dim=-1, keepdim=True))
        return torch.atan2(sine, cosine) / np.pi

    num_samples = patches.shape[1]
    anchors = torch.unsqueeze(points, dim=1).expand(-1, num_samples, -1)
    anchor_normals = torch.unsqueeze(point_normals, dim=1).expand(-1, num_samples, -1)

    offset = patches - anchors
    length = torch.sqrt(torch.sum(offset ** 2, dim=-1, keepdim=True))

    features = [
        length,
        _angle(anchor_normals, offset),
        _angle(patch_normals, offset),
        _angle(anchor_normals, patch_normals),
    ]
    return torch.cat(features, dim=-1)
|
|
|
|
def calc_ppf_batch(points, point_normals, patches, patch_normals):
    """
    Batched point-pair features for ``(B, N, S, 3)`` patches.

    Returns ``[|d|, ang(n, d), ang(m, d), ang(n, m)]`` along the last axis,
    where ``d`` is the offset from the reference point to the patch point and
    each angle is ``atan2(|a x b|, a . b) / pi``.
    """
    def _angle(a, b):
        cosine = torch.sum(a * b, dim=-1, keepdim=True)
        sine = torch.norm(torch.cross(a, b, dim=-1), dim=-1, keepdim=True)
        return torch.atan2(sine, cosine) / np.pi

    _, _, num_samples, _ = patches.shape

    anchors = points.unsqueeze(2).expand(-1, -1, num_samples, -1)
    anchor_normals = point_normals.unsqueeze(2).expand(-1, -1, num_samples, -1)

    offset = patches - anchors
    length = torch.norm(offset, dim=-1, keepdim=True)

    features = [
        length,
        _angle(anchor_normals, offset),
        _angle(patch_normals, offset),
        _angle(anchor_normals, patch_normals),
    ]
    return torch.cat(features, dim=-1)
|
|
|
|
| |
| |
| |
def calc_patch_scale(xyz):
    """
    Largest nearest-neighbour distance of every cloud in a batch.

    For each batch item, the distance from each point to its closest other
    point is computed and the maximum of those is returned, shape ``(B,)``.
    """
    pairwise = torch.cdist(xyz, xyz)
    _, num_points, _ = pairwise.shape
    # Exclude each point's zero distance to itself.
    self_pairs = torch.eye(num_points, device=xyz.device).unsqueeze(0).bool()
    pairwise.masked_fill_(self_pairs, float("inf"))
    nearest = pairwise.min(dim=-1).values
    return nearest.max(dim=-1).values
|
|
|
|
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Ball query: for each query point gather up to ``nsample`` point indices
    lying within ``radius``.

    Args:
        radius: Scalar, or a tensor with one radius per batch item.
        nsample: Maximum number of indices returned per query.
        xyz: ``(B, N, C)`` reference points.
        new_xyz: ``(B, S, C)`` query points.

    Returns:
        ``(B, S, nsample)`` long tensor. In-radius indices come first in
        ascending index order; remaining slots repeat the first index.
    """
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    device = xyz.device

    if torch.is_tensor(radius):
        assert radius.shape[0] == B, "radius must be (B,)"
        limit = radius.to(device).view(B, 1, 1)
    else:
        limit = float(radius)

    too_far = torch.cdist(new_xyz, xyz) > limit
    # Out-of-radius entries get the sentinel N so that sorting pushes them to
    # the end of each row.
    candidates = torch.arange(N, dtype=torch.long, device=device)
    candidates = candidates.view(1, 1, N).expand(B, S, N).clone()
    candidates[too_far] = N

    group_idx = candidates.sort(dim=-1)[0][..., :nsample]
    filler = group_idx[..., 0].unsqueeze(-1).expand(-1, -1, nsample)
    padded = group_idx == N
    group_idx[padded] = filler[padded]

    return group_idx
|
|
|
|
| |
| |
| |
def compute_geo_feats_np(
    P: np.ndarray,
    n_d_xyz: int = 2048,
    r: float = 0.1,
    eps: float = 1e-4,
    chunk_p: int = 64_000,
    chunk_q: int = 256_000,
) -> np.ndarray:
    """
    Eigenvalue-based shape features from a radius neighbourhood (NumPy).

    ``n_d_xyz`` points ``Q`` are drawn from ``P`` by farthest-point sampling.
    For each point ``p`` of ``P`` the second-moment matrix of ``q - p`` over
    all ``q`` in ``Q`` with ``|q - p| < r`` is formed, regularised with
    ``eps * I``, and its eigenvalues ``l1 >= l2 >= l3`` give the features
    ``[(l1 - l2) / l1, (l2 - l3) / l1, l3 / l1]``.

    Args:
        P: ``(N, 3)`` points (cast to float32).
        n_d_xyz: Number of FPS samples used as neighbourhood support.
        r: Neighbourhood radius.
        eps: Diagonal regulariser added to every moment matrix.
        chunk_p: Number of rows of ``P`` processed at once.
        chunk_q: Number of rows of ``Q`` processed at once.

    Returns:
        ``(N, 3)`` float32 feature array. Points without any neighbour get
        ``[0, 0, 1]``.
    """
    P = np.asarray(P, dtype=np.float32)
    Q = fps_np(P, n_d_xyz)[0]
    Q = np.asarray(Q, dtype=np.float32)

    N = P.shape[0]
    eye3 = np.eye(3, dtype=P.dtype)
    r2 = r * r

    feats_all = []

    for p0 in range(0, N, chunk_p):
        p1 = min(p0 + chunk_p, N)
        P_blk = P[p0:p1]
        p = P_blk.shape[0]

        # Running neighbour count, sum of q and sum of q q^T per point.
        cnt = np.zeros((p, 1), dtype=P.dtype)
        sum1 = np.zeros((p, 3), dtype=P.dtype)
        sum2 = np.zeros((p, 3, 3), dtype=P.dtype)

        for q0 in range(0, Q.shape[0], chunk_q):
            q1 = min(q0 + chunk_q, Q.shape[0])
            Q_blk = Q[q0:q1]

            diff = P_blk[:, None, :] - Q_blk[None, :, :]
            d2 = np.sum(diff * diff, axis=2)

            mask = d2 < r2
            if not mask.any():
                continue

            w = mask.astype(P.dtype)
            cnt += np.sum(w, axis=1, keepdims=True)
            sum1 += w @ Q_blk

            Q_outer = Q_blk[:, :, None] * Q_blk[:, None, :]
            sum2 += np.einsum("pq,qkl->pkl", w, Q_outer)

        # Only the divisor is floored at 1. The true count must stay in the
        # numerator, otherwise a point with no neighbours gets P P^T instead
        # of a zero matrix.
        denom = np.maximum(cnt, 1.0)

        # sum_q (q - p)(q - p)^T expanded in terms of the running sums.
        C = (
            sum2
            - P_blk[:, :, None] * sum1[:, None, :]
            - sum1[:, :, None] * P_blk[:, None, :]
            + cnt[:, :, None] * (P_blk[:, :, None] * P_blk[:, None, :])
        ) / denom[:, :, None]

        C += eps * eye3
        # eigvalsh is ascending; flip to get l1 >= l2 >= l3.
        lamb = np.flip(np.linalg.eigvalsh(C), axis=1)

        lamb1, lamb2, lamb3 = lamb[:, 0], lamb[:, 1], lamb[:, 2]
        feats_blk = np.stack(
            [(lamb1 - lamb2) / lamb1, (lamb2 - lamb3) / lamb1, lamb3 / lamb1],
            axis=1,
        )

        feats_all.append(feats_blk)

    return np.concatenate(feats_all, axis=0).astype(P.dtype)
|
|
|
|
@torch.no_grad()
def compute_geo_feats(
    P: torch.Tensor,
    Q: torch.Tensor,
    r: float = 0.1,
    eps: float = 1e-4,
    chunk_p: int = 64_000,
    chunk_q: int = 256_000,
) -> torch.Tensor:
    """
    Eigenvalue-based shape features from a radius neighbourhood (torch).

    For each point ``p`` of ``P`` the second-moment matrix of ``q - p`` over
    all ``q`` in ``Q`` with ``|q - p| < r`` is formed, regularised with
    ``eps * I``, and its eigenvalues ``l1 >= l2 >= l3`` give the features
    ``[(l1 - l2) / l1, (l2 - l3) / l1, l3 / l1]``.

    Args:
        P: ``(N, 3)`` query points.
        Q: ``(M, 3)`` support points (same dtype and device as ``P``).
        r: Neighbourhood radius.
        eps: Diagonal regulariser added to every moment matrix.
        chunk_p: Number of rows of ``P`` processed at once.
        chunk_q: Number of rows of ``Q`` processed at once.

    Returns:
        ``(N, 3)`` feature tensor in the dtype of ``P``. Points without any
        neighbour get ``[0, 0, 1]``.
    """
    device = P.device
    dtype = P.dtype
    N = P.size(0)
    eye3 = torch.eye(3, device=device, dtype=dtype)

    feats_chunks = []

    for p0 in range(0, N, chunk_p):
        p1 = min(p0 + chunk_p, N)
        P_blk = P[p0:p1]
        p = p1 - p0

        # Running neighbour count, sum of q and sum of q q^T per point.
        cnt = torch.zeros(p, 1, device=device, dtype=dtype)
        sum1 = torch.zeros(p, 3, device=device, dtype=dtype)
        sum2 = torch.zeros(p, 3, 3, device=device, dtype=dtype)

        for q0 in range(0, Q.size(0), chunk_q):
            q1 = min(q0 + chunk_q, Q.size(0))
            Q_blk = Q[q0:q1]

            # cdist yields plain (not squared) distances, hence `< r`.
            dist = torch.cdist(P_blk, Q_blk, p=2)
            mask = dist < r

            if not mask.any():
                continue

            # Cast to the working dtype so float64 inputs do not hit a
            # float32 @ float64 matmul.
            w = mask.to(dtype)
            cnt += w.sum(1, keepdim=True)
            sum1 += w @ Q_blk

            Q_outer = Q_blk.unsqueeze(2) * Q_blk.unsqueeze(1)
            sum2 += torch.einsum("ij,jkl->ikl", w, Q_outer)

        # Only the divisor is floored at 1. The true count must stay in the
        # numerator, otherwise a point with no neighbours gets P P^T instead
        # of a zero matrix.
        denom = cnt.clamp(min=1.0)

        # sum_q (q - p)(q - p)^T expanded in terms of the running sums.
        C = (
            sum2
            - P_blk.unsqueeze(2) * sum1.unsqueeze(1)
            - sum1.unsqueeze(2) * P_blk.unsqueeze(1)
            + cnt.unsqueeze(2) * (P_blk.unsqueeze(2) * P_blk.unsqueeze(1))
        )

        C = C / denom.unsqueeze(2) + eps * eye3

        # eigvalsh is ascending; flip to get l1 >= l2 >= l3.
        lamb = torch.flip(
            torch.linalg.eigvalsh(C).clamp_min(1e-9),
            dims=[1],
        )

        l1, l2, l3 = lamb[:, 0], lamb[:, 1], lamb[:, 2]

        blk_feat = torch.stack(
            [(l1 - l2) / l1, (l2 - l3) / l1, l3 / l1],
            dim=1,
        )
        feats_chunks.append(blk_feat)

    return torch.cat(feats_chunks, dim=0)
|
|
|
|
| |
| |
| |
@torch.no_grad()
def sinkhorn_log_torch(scores, xi=0.05, iters=3, eps=1e-12):
    """
    Log-domain Sinkhorn normalisation of a ``(B, K)`` score matrix.

    Works on the transposed ``(K, B)`` matrix ``scores / xi``. Each iteration
    rescales rows to sum to ``1/K`` and then columns to sum to ``1/B``. The
    result is transposed back, clamped to ``[1e-6, 1]`` and renormalised so
    every one of the ``B`` rows sums to (almost) one.
    """
    num_samples, num_protos = scores.shape
    log_num_protos = torch.log(
        torch.tensor(float(num_protos), device=scores.device, dtype=scores.dtype)
    )
    log_num_samples = torch.log(
        torch.tensor(float(num_samples), device=scores.device, dtype=scores.dtype)
    )

    log_q = (scores / xi).T
    # Shift by the max for numerical stability, then normalise total mass.
    log_q = log_q - log_q.max()
    log_q = log_q - torch.logsumexp(log_q, dim=(0, 1), keepdim=True)

    for _ in range(iters):
        log_q = log_q - (torch.logsumexp(log_q, dim=1, keepdim=True) + log_num_protos)
        log_q = log_q - (torch.logsumexp(log_q, dim=0, keepdim=True) + log_num_samples)
    log_q = log_q + log_num_samples

    prob = torch.clamp(torch.exp(log_q.T), 1e-6, 1.0)
    return prob / (prob.sum(1, keepdim=True) + eps)
|
|
|
|
@torch.no_grad()
def sinkhorn_chunked_log(
    rows_fn,
    N: int,
    K: int,
    tau: float = 0.05,
    iters: int = 60,
    chunk_p: int = 131_072,
    eps: float = 1e-9,
    dtype=torch.float32,
    device: torch.device = torch.device("cuda"),
):
    """
    Memory-bounded log-domain Sinkhorn between ``N`` rows and ``K`` columns.

    The ``(N, K)`` similarity matrix is never stored: ``rows_fn(i0, i1)`` must
    return rows ``i0:i1`` on demand. The kernel is ``exp(sims / tau)``, the row
    marginal is uniform ``1/N`` and the column marginal is uniform ``1/K``.

    Args:
        rows_fn: Callable ``(i0, i1) -> (i1 - i0, K)`` similarity block.
        N: Number of rows.
        K: Number of columns.
        tau: Temperature dividing the similarities.
        iters: Number of Sinkhorn iterations.
        chunk_p: Rows requested per ``rows_fn`` call.
        eps: Floor applied to similarities; also added to the returned mass.
        dtype: Dtype of the scaling vectors.
        device: Device of the scaling vectors.

    Returns:
        ``(labels, mass)``: per-row argmax column of the transport plan
        ``(N,)`` and per-column transported mass ``(K,)``.
    """
    log_r = -torch.log(torch.tensor(float(N), dtype=dtype, device=device))
    log_c = -torch.log(torch.tensor(float(K), dtype=dtype, device=device))
    log_v = torch.zeros(K, dtype=dtype, device=device)

    for _ in range(iters):
        log_col_sum = torch.full((K,), -float("inf"), dtype=dtype, device=device)
        for i0 in range(0, N, chunk_p):
            # Out-of-place clamp: rows_fn may hand back a view of caller data.
            sims = rows_fn(i0, min(i0 + chunk_p, N)).clamp_min(eps)
            log_K = sims / tau
            # Row scaling u = r / (K v), recomputed on the fly per chunk.
            log_u = log_r - torch.logsumexp(log_K + log_v, dim=1)
            blk = log_u[:, None] + log_K + log_v
            log_col_sum = torch.logaddexp(log_col_sum, torch.logsumexp(blk, dim=0))
        # Column scaling v <- v * c / colsum(P). log_col_sum already contains
        # the current log_v, so the old value must be kept in the update;
        # dropping it does not converge to the requested column marginal.
        log_v = log_v + log_c - log_col_sum

    labels = torch.empty(N, dtype=torch.long, device=device)
    mass = torch.zeros(K, dtype=dtype, device=device)

    for i0 in range(0, N, chunk_p):
        sims = rows_fn(i0, min(i0 + chunk_p, N)).clamp_min(eps)
        log_K = sims / tau
        log_u = log_r - torch.logsumexp(log_K + log_v, dim=1)
        log_Y = log_u[:, None] + log_K + log_v
        Y = torch.exp(log_Y)
        labels[i0: i0 + Y.size(0)] = torch.argmax(Y, dim=1)
        mass += Y.sum(0)

    mass += eps
    return labels, mass
|
|
|
|
def sinkhorn_chunked_log_np(
    rows_fn,
    N: int,
    K: int,
    tau: float = 0.05,
    iters: int = 60,
    chunk_p: int = 131_072,
    eps: float = 1e-9,
    dtype=np.float32,
):
    """
    NumPy version of the memory-bounded log-domain Sinkhorn.

    ``rows_fn(i0, i1)`` must return rows ``i0:i1`` of the ``(N, K)``
    similarity matrix. The kernel is ``exp(sims / tau)`` with uniform row
    marginal ``1/N`` and uniform column marginal ``1/K``.

    Args:
        rows_fn: Callable ``(i0, i1) -> (i1 - i0, K)`` similarity block.
        N: Number of rows.
        K: Number of columns.
        tau: Temperature dividing the similarities.
        iters: Number of Sinkhorn iterations.
        chunk_p: Rows requested per ``rows_fn`` call.
        eps: Floor applied to similarities; also added to the returned mass.
        dtype: Working dtype.

    Returns:
        ``(labels, mass)``: int64 per-row argmax column ``(N,)`` and
        per-column transported mass ``(K,)``.
    """
    log_r = -np.log(float(N)).astype(dtype)
    log_c = -np.log(float(K)).astype(dtype)
    log_v = np.zeros((K,), dtype=dtype)

    for _ in range(iters):
        log_col_sum = np.full((K,), -np.inf, dtype=dtype)
        for i0 in range(0, N, chunk_p):
            sims = rows_fn(i0, min(i0 + chunk_p, N)).astype(dtype)
            sims = np.clip(sims, eps, None)

            log_K = sims / tau
            # Row scaling u = r / (K v), recomputed on the fly per chunk.
            log_u = log_r - logsumexp(log_K + log_v, axis=1)
            blk = log_u[:, None] + log_K + log_v

            log_col_sum = np.logaddexp(log_col_sum, logsumexp(blk, axis=0))
        # Column scaling v <- v * c / colsum(P). log_col_sum already contains
        # the current log_v, so the old value must be kept in the update;
        # dropping it does not converge to the requested column marginal.
        log_v = log_v + log_c - log_col_sum

    labels = np.empty((N,), dtype=np.int64)
    mass = np.zeros((K,), dtype=dtype)

    for i0 in range(0, N, chunk_p):
        sims = rows_fn(i0, min(i0 + chunk_p, N)).astype(dtype)
        sims = np.clip(sims, eps, None)
        log_K = sims / tau
        log_u = log_r - logsumexp(log_K + log_v, axis=1)
        log_Y = log_u[:, None] + log_K + log_v
        Y = np.exp(log_Y)

        labels[i0: i0 + Y.shape[0]] = np.argmax(Y, axis=1)
        mass += Y.sum(axis=0)

    mass += eps
    return labels, mass
|
|
|
|
| |
| |
| |
@torch.no_grad()
def ot_fps_cluster_large(
    xyz: torch.Tensor,
    feats: torch.Tensor,
    K: int = 64,
    r_spatial: float = 0.05,
    tau: float = 0.05,
    sinkhorn_iters: int = 60,
    outer_iters: int = 3,
    chunk_P: int = 128,
):
    """
    Cluster points into ``K`` groups by alternating OT assignment and
    centroid updates.

    Centres are initialised by farthest-point sampling. Each outer iteration
    (1) scores every point against every centre with a spatially gated
    Gaussian times the absolute cosine similarity of the features,
    (2) assigns points with chunked Sinkhorn, (3) re-seeds centres that
    received almost no mass at the points farthest from the live centres, and
    (4) recomputes each centre from its assigned points within ``r_spatial``.

    Args:
        xyz: ``(N, C)`` point coordinates.
        feats: ``(N, D)`` point features (L2-normalised internally).
        K: Number of clusters.
        r_spatial: Radius gating both the similarity and the centroid update.
        tau: Sinkhorn temperature.
        sinkhorn_iters: Sinkhorn iterations per outer iteration.
        outer_iters: Number of assignment/update rounds (must be >= 1 for a
            result to exist).
        chunk_P: Rows per Sinkhorn chunk.

    Returns:
        ``(N,)`` long tensor of cluster labels from the last round.
    """
    device, N = xyz.device, xyz.size(0)
    feats = F.normalize(feats, dim=1)

    # Mean squared distance to the cloud centroid: Gaussian kernel bandwidth.
    sigma2 = ((xyz - xyz.mean(0, keepdim=True)) ** 2).sum(1).mean()
    EPS_TIE = 1e-6
    tie_break = EPS_TIE * torch.arange(K, device=device)

    centre_idx = fps(xyz, K)
    c_xyz = xyz[centre_idx].clone()
    c_fea = feats[centre_idx].clone()

    for _ in range(outer_iters):
        def rows_fn(i0: int, i1: int) -> torch.Tensor:
            pts = xyz[i0:i1]
            fea = feats[i0:i1]
            dist = torch.cdist(pts, c_xyz)
            mask = dist <= r_spatial
            # Points with no centre in range still get their nearest one.
            empty = mask.logical_not().all(dim=1)
            if empty.any():
                nearest = torch.argmin(dist[empty], dim=1)
                mask[empty, nearest] = True
            # Squared distance over sigma2 (a variance), matching the NumPy
            # variant; cdist itself returns un-squared distances.
            s_geo = torch.exp(-dist.pow(2) / sigma2) * mask
            s_fea = torch.abs(torch.einsum("id,kd->ik", fea, c_fea))
            return s_geo * s_fea + tie_break

        labels, mass = sinkhorn_chunked_log(
            rows_fn,
            N,
            K,
            tau=tau,
            iters=sinkhorn_iters,
            chunk_p=chunk_P,
            dtype=xyz.dtype,
            device=device,
        )

        # Re-seed centres that attracted (almost) no mass.
        dead = mass < 1e-4
        if dead.any():
            live_xyz = c_xyz[~dead]
            dist_live = torch.cdist(xyz, live_xyz)
            far_idx = torch.topk(dist_live.min(1).values, dead.sum().item()).indices
            c_xyz[dead] = xyz[far_idx]
            c_fea[dead] = feats[far_idx]
            mass[dead] = 1.0

        dist_to_c = torch.norm(xyz - c_xyz[labels], dim=1)
        w = (dist_to_c <= r_spatial).to(xyz.dtype)

        sum_xyz = torch.zeros_like(c_xyz).index_add_(0, labels, xyz * w.unsqueeze(1))
        sum_fea = torch.zeros_like(c_fea).index_add_(0, labels, feats * w.unsqueeze(1))
        mass = torch.zeros_like(mass).index_add_(0, labels, w)

        # Update only centres that own in-radius points. The others (including
        # the ones just re-seeded) keep their position instead of collapsing
        # to the origin with a zero feature.
        live = mass > 0
        c_xyz[live] = sum_xyz[live] / mass[live, None]
        c_fea[live] = F.normalize(sum_fea[live] / mass[live, None], dim=1)

    return labels
|
|
|
|
def ot_fps_cluster_large_np(
    xyz: np.ndarray,
    feats: np.ndarray,
    K: int = 64,
    r_spatial: float = 0.05,
    tau: float = 0.05,
    sinkhorn_iters: int = 10,
    outer_iters: int = 3,
    chunk_P: int = 128,
    use_scipy_cdist: bool = True,
):
    """
    NumPy version of the OT + FPS clustering.

    Centres are initialised by farthest-point sampling. Each outer iteration
    scores points against centres (spatially gated Gaussian times absolute
    feature cosine), assigns them with chunked Sinkhorn, re-seeds centres that
    received almost no mass, and recomputes each centre from its assigned
    points within ``r_spatial``.

    Args:
        xyz: ``(N, C)`` point coordinates.
        feats: ``(N, D)`` point features (L2-normalised internally).
        K: Number of clusters.
        r_spatial: Radius gating both the similarity and the centroid update.
        tau: Sinkhorn temperature.
        sinkhorn_iters: Sinkhorn iterations per outer iteration.
        outer_iters: Number of assignment/update rounds (must be >= 1 for a
            result to exist).
        chunk_P: Rows per Sinkhorn chunk.
        use_scipy_cdist: Use ``scipy.spatial.distance.cdist`` for squared
            distances instead of NumPy broadcasting.

    Returns:
        ``(N,)`` int64 array of cluster labels from the last round.
    """
    N = xyz.shape[0]
    feats = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-9)

    # Mean squared distance to the cloud centroid: Gaussian kernel bandwidth.
    sigma2 = np.mean(np.linalg.norm(xyz - xyz.mean(0), axis=1) ** 2)
    EPS_TIE = 1e-6

    c_xyz, c_idx = fps_np(xyz, K)
    c_fea = feats[c_idx].copy()

    for _ in range(outer_iters):
        def rows_fn(i0, i1):
            pts = xyz[i0:i1]
            fea = feats[i0:i1]

            if use_scipy_cdist:
                d2 = scipy.spatial.distance.cdist(pts, c_xyz, metric="sqeuclidean")
            else:
                d2 = ((pts[:, None, :] - c_xyz[None, :, :]) ** 2).sum(-1)

            mask = d2 <= r_spatial ** 2
            # Points with no centre in range still get their nearest one.
            empty = np.all(~mask, axis=1)
            if np.any(empty):
                nearest = np.argmin(d2[empty], axis=1)
                mask[empty, nearest] = True

            s_geo = np.exp(-d2 / sigma2) * mask
            s_fea = np.abs(fea @ c_fea.T)
            return s_geo * s_fea + EPS_TIE * np.arange(K)

        labels, mass = sinkhorn_chunked_log_np(
            rows_fn,
            N,
            K,
            tau=tau,
            iters=sinkhorn_iters,
            chunk_p=chunk_P,
            dtype=xyz.dtype,
        )

        # Re-seed centres that attracted (almost) no mass.
        dead = mass < 1e-4
        if np.any(dead):
            num_dead = int(dead.sum())
            live_xyz = c_xyz[~dead]
            dist2_live = ((xyz[:, None, :] - live_xyz[None, :, :]) ** 2).sum(-1)
            far_idx = np.argpartition(dist2_live.min(axis=1), -num_dead)[-num_dead:]
            c_xyz[dead] = xyz[far_idx]
            c_fea[dead] = feats[far_idx]
            mass[dead] = 1.0

        dist_to_c = np.linalg.norm(xyz - c_xyz[labels], axis=1)
        w = (dist_to_c <= r_spatial).astype(xyz.dtype)

        # Update only centres that own in-radius points. The others (including
        # the ones just re-seeded) keep their position instead of collapsing
        # to the origin with a zero feature.
        for k in range(K):
            sel = labels == k
            w_k = w[sel]
            mass_k = w_k.sum()
            mass[k] = mass_k
            if mass_k > 0:
                c_xyz[k] = (xyz[sel] * w_k[:, None]).sum(0) / mass_k
                fea_k = (feats[sel] * w_k[:, None]).sum(0)
                c_fea[k] = fea_k / (np.linalg.norm(fea_k) + 1e-9)

    return labels.astype(np.int64)
|
|
|
|
| |
| |
| |
def compute_superpoints(
    xyz,
    feats=None,
    n_d_xyz=2048,
    n_clus=64,
    r_geo=0.1,
    r_clus=0.05,
    device=torch.device("cuda"),
    method="ot",
    normals=None,
    pycut_kwargs=None,
):
    """
    Compute superpoint labels for a point cloud.

    method:
        'pycut' -> L0-cut-pursuit superpoints via libcp
        anything else (default 'ot') -> original OT/FPS clustering

    Args:
        xyz: ``(N, 3)`` point coordinates (NumPy array or torch tensor).
        feats: Optional precomputed per-point features for the OT path (NumPy
            array or torch tensor). Computed with ``compute_geo_feats`` when
            omitted. Ignored by the 'pycut' path.
        n_d_xyz: Number of FPS support points for the geometric features.
        n_clus: Number of OT clusters.
        r_geo: Neighbourhood radius of the geometric features.
        r_clus: Spatial radius used by the OT clustering.
        device: Device on which the returned tensors live.
        normals: Optional per-point normals, used by 'pycut' as extra features.
        pycut_kwargs: Optional overrides for the 'pycut' defaults listed below.

    Returns:
        ``(labels, feats)`` as torch tensors on ``device``.
    """
    if method == "pycut":
        kw = dict(
            k_feat=10, k_adj=10, chunk_size=8192,
            use_input_normals=True, use_xyz=False,
            xyz_scale=0.10, normal_scale=0.25,
            lam=0.03, sigma=0.5,
            mutual=False, undirected=True,
            min_comp_weight=20, weight_decay=0.7,
            verbose=False,
        )
        if pycut_kwargs is not None:
            kw.update(pycut_kwargs)

        xyz_f = np.asarray(xyz, dtype=np.float32)
        xyz_norm = _normalize_xyz_np(xyz_f)

        # Feature matrix: geometric descriptors, optionally scaled normals
        # and scaled coordinates.
        geom_feat = _local_geom_features_chunked_np(
            xyz_norm, k_feat=kw["k_feat"], chunk_size=kw["chunk_size"],
        )
        feat_parts = [geom_feat]
        if kw["use_input_normals"] and normals is not None:
            unit_normals = _normalize_normals_np(np.asarray(normals, dtype=np.float32))
            feat_parts.append(unit_normals * kw["normal_scale"])
        if kw["use_xyz"]:
            feat_parts.append(xyz_norm * kw["xyz_scale"])
        Y = np.hstack(feat_parts).astype(np.float32)

        src, dst = _build_adj_graph_np(
            xyz_norm, k_adj=kw["k_adj"],
            mutual=kw["mutual"], undirected=kw["undirected"],
        )
        # The weight helper transposes its input again, hence Y.T here. The
        # edge weights use unit scale; kw["lam"] goes to cutpursuit below.
        ew = _edge_weights_chunked_np(Y.T, src, dst, lam=1.0, sigma=kw["sigma"])

        try:
            import libcp
        except ImportError:
            # Fall back to a locally built copy next to this file.
            libcp_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "_cut_pursuit", "build", "src",
            )
            sys.path.insert(0, libcp_dir)
            import libcp

        components, in_component = libcp.cutpursuit(
            Y,
            src.astype(np.uint32),
            dst.astype(np.uint32),
            ew.astype(np.float32),
            float(kw["lam"]),
            int(kw["min_comp_weight"]),
            0,
            float(kw["weight_decay"]),
        )
        lbl_np = _relabel_contiguous_np(np.asarray(in_component, dtype=np.int32))
        lbl = torch.from_numpy(lbl_np.astype(np.int64)).to(device)
        feats_out = torch.from_numpy(Y).to(device)
        return lbl, feats_out

    # OT path: accept both tensors and arrays for xyz and feats.
    if isinstance(xyz, torch.Tensor):
        P = xyz.to(device)
    else:
        P = torch.from_numpy(np.asarray(xyz)).to(device)

    if feats is None:
        Q = P[fps(P, n_d_xyz)]
        feats = compute_geo_feats(P, Q, r=r_geo, chunk_p=4096, chunk_q=256)
    elif isinstance(feats, torch.Tensor):
        feats = feats.to(device)
    else:
        feats = torch.from_numpy(np.asarray(feats)).to(device)

    lbl = ot_fps_cluster_large(P, feats, K=n_clus, r_spatial=r_clus)
    return lbl, feats
|
|
|
|
| |
| |
| |
| def sp_constrained_fps(xyz: torch.Tensor, sp: torch.Tensor, L: int): |
| device, N = xyz.device, xyz.size(0) |
| K = sp.max().item() + 1 |
| assert L >= K |
|
|
| perm = torch.randperm(N, device=device) |
| anchors = torch.full((K,), N, dtype=torch.long, device=device) |
| anchors.scatter_reduce_(0, sp[perm], perm, reduce="amin") |
|
|
| idx = torch.empty(L, dtype=torch.long, device=device) |
| idx[:K] = anchors |
|
|
| dist2 = torch.cdist(xyz, xyz[anchors]).pow(2).min(dim=1).values |
|
|
| for i in range(K, L): |
| nxt = torch.argmax(dist2) |
| idx[i] = nxt |
| dist2 = torch.minimum(dist2, ((xyz - xyz[nxt]) ** 2).sum(-1)) |
|
|
| return idx |
|
|
|
|
def batch_random_anchor(sp: torch.Tensor):
    """
    Pick one random point index per superpoint for every batch row.

    Args:
        sp: ``(B, N)`` int64 superpoint labels, non-negative.

    Returns:
        ``(B, K)`` long tensor with ``K = sp.max() + 1``. Entry ``[b, k]`` is
        the index of a randomly chosen point of row ``b`` carrying label ``k``,
        or ``N`` when label ``k`` does not occur in that row.
    """
    device = sp.device
    B, N = sp.shape
    K = sp.max().item() + 1

    # Independent random permutation per row, and the labels in that order.
    perm = torch.argsort(torch.rand(B, N, device=device), dim=1)
    labels_perm = sp.gather(1, perm)

    # Earliest position of each label in the shuffled order. Reducing the
    # shuffled indices directly would ignore the shuffle and always yield the
    # smallest index of each superpoint, i.e. nothing random.
    positions = torch.arange(N, device=device).repeat(B, 1)
    first_pos = torch.full((B, K), N, dtype=torch.long, device=device)
    first_pos.scatter_reduce_(1, labels_perm, positions, reduce="amin")

    # Map positions back to point indices; keep N as the "label absent" marker.
    present = first_pos < N
    anchors = perm.gather(1, first_pos.clamp(max=N - 1))
    return anchors.masked_fill(~present, N)
|
|
|
|
| def sp_constrained_fps_batch( |
| xyz: torch.Tensor, |
| sp: torch.Tensor, |
| L: int, |
| ) -> torch.LongTensor: |
| device = xyz.device |
| B, N, _ = xyz.shape |
| K = sp.max().item() + 1 |
| assert L >= K |
|
|
| perm = torch.argsort(torch.rand(B, N, device=device), dim=1) |
| anchors = torch.full((B, K), N, dtype=torch.long, device=device) |
| anchors.scatter_reduce_(1, sp.gather(1, perm), perm, reduce="amin") |
| idx = torch.empty(B, L, dtype=torch.long, device=device) |
| idx[:, :K] = anchors |
|
|
| c_xyz = xyz.gather(1, anchors[..., None].expand(-1, -1, 3)) |
| dist2 = torch.cdist(xyz, c_xyz, p=2).min(dim=2).values |
|
|
| b_ids = torch.arange(B, device=device) |
|
|
| for i in range(K, L): |
| farthest = dist2.max(dim=1).indices |
| idx[:, i] = farthest |
| new_c = xyz[b_ids, farthest] |
| dist2 = torch.minimum(dist2, ((xyz - new_c[:, None, :]) ** 2).sum(-1)) |
|
|
| return idx |
|
|
|
|
def memory_efficient_fps_prob(
    xyz: Tensor,
    prob: Tensor,
    k: int,
    gamma: float = 0.5,
    eps: float = 1e-12,
    chunk_size: int = 32,
) -> Tensor:
    """
    Probability-weighted farthest-point sampling, evaluated in point chunks.

    The first sample of each row is ``prob.argmax``. Every later sample
    maximises ``d * (prob + eps) ** -gamma``, where ``d`` is the squared
    distance from a point to the closest sample chosen so far.

    Args:
        xyz: ``(B, N, 3)`` coordinates.
        prob: ``(B, N)`` per-point weights.
        k: Number of samples per row.
        gamma: Exponent applied (negated) to ``prob + eps``.
        eps: Offset added to ``prob`` before exponentiation.
        chunk_size: Number of points processed per distance update.

    Returns:
        ``(B, k)`` long tensor of selected point indices.
    """
    B, N, _ = xyz.shape
    device = xyz.device

    picked = torch.zeros(B, k, dtype=torch.long, device=device)
    picked[:, 0] = prob.argmax(dim=1)

    # Running squared distance to the nearest selected point (inf = none yet).
    nearest = torch.full((B, N), float("inf"), device=device)

    for step in range(1, k):
        prev = picked[:, step - 1:step]
        prev_xyz = xyz.gather(1, prev.unsqueeze(-1).expand(-1, -1, 3))

        # Fold the newest sample into the running minimum, one chunk at a time.
        refreshed = torch.full((B, N), float("inf"), device=device)
        for lo in range(0, N, chunk_size):
            hi = min(lo + chunk_size, N)
            sq_to_prev = (xyz[:, lo:hi, :] - prev_xyz).pow(2).sum(-1)
            refreshed[:, lo:hi] = torch.minimum(nearest[:, lo:hi], sq_to_prev)
        nearest = refreshed

        weighted = nearest * (prob + eps) ** -gamma
        picked[:, step] = weighted.argmax(dim=1)

    return picked
|
|
|
|
| |
| |
| |
| def _superpoint_pool(feat: torch.Tensor, spts: torch.Tensor) -> torch.Tensor: |
| B, N, D = feat.shape |
| feat_flat = feat.reshape(B * N, D) |
|
|
| K = int(spts.max().item()) + 1 |
| offsets = (torch.arange(B, device=feat.device) * K).view(B, 1) |
| sp_offset = (spts + offsets).reshape(-1) |
| tot_sp = B * K |
|
|
| feat_sum = torch.zeros(tot_sp, D, device=feat.device, dtype=feat.dtype) |
| feat_sum.index_add_(0, sp_offset, feat_flat) |
|
|
| cnt_sum = torch.zeros(tot_sp, 1, device=feat.device, dtype=feat.dtype) |
| cnt_sum.index_add_(0, sp_offset, torch.ones(B * N, 1, device=feat.device, dtype=feat.dtype)) |
|
|
| feat_avg = feat_sum / (cnt_sum + 1e-6) |
| feat_denoised = feat_avg[sp_offset].view(B, N, D) |
| return feat_denoised |
|
|
|
|
| |
| |
| |
# Demo / smoke test: build cut-pursuit superpoints for one scan, then
# visualise two levels of superpoint-aware downsampling.
if __name__ == "__main__":
    from lib_vis import create_labeled_point_cloud
    import open3d as o3d
    # Input cloud is read from a path relative to the working directory.
    pcd = o3d.io.read_point_cloud("s3dis.ply")
    print(pcd)
    xyz = np.array(pcd.points)


    # Normals are optional: used only when the file exists.
    normals = None
    normals_path = "../normal.npy"
    if os.path.exists(normals_path):
        normals = np.load(normals_path).astype(np.float32)


    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    # Superpoint extraction with the 'pycut' method.
    # NOTE(review): compute_superpoints, as visible in this file, does not read
    # cp_it_max, split_iter_num, split_damp_ratio, kmpp_init_num,
    # kmpp_iter_num, K or verbose, so those entries have no effect there.
    lbl, feats = compute_superpoints(
        xyz,
        normals=normals,
        device=device,
        method="pycut",
        pycut_kwargs=dict(
            k_feat=10,
            k_adj=10,
            chunk_size=8192,
            use_input_normals=(normals is not None),
            use_xyz=False,
            xyz_scale=0.10,
            normal_scale=0.25,
            lam=5.0,
            sigma=0.5,
            mutual=False,
            undirected=True,
            cp_it_max=10,
            split_iter_num=2,
            split_damp_ratio=0.7,
            kmpp_init_num=3,
            kmpp_iter_num=3,
            K=2,
            min_comp_weight=20.0,
            verbose=False,
        ),
    )


    # First level: 2048 samples. `pcd` is rebound from the open3d cloud to a
    # tensor here. sp_constrained_fps asserts that 2048 >= number of labels.
    pcd = torch.from_numpy(xyz).to(device)
    idx = sp_constrained_fps(pcd, lbl, 2048).tolist()
    down_pcd = pcd[idx]
    down_lbl = lbl[idx]


    create_labeled_point_cloud(xyz[idx], lbl[idx].tolist(), name="cls_pycut")


    # Second level: 128 samples from the first-level subset.
    # NOTE(review): superpoint_fps is not defined in the visible part of this
    # file -- presumably defined earlier; verify its signature.
    down_idx = superpoint_fps(
        down_pcd.unsqueeze(0),
        down_lbl.unsqueeze(0),
        k=128,
        gamma=0.1,
        base_scale=1.5,
    )[0].tolist()


    create_labeled_point_cloud(
        down_pcd[down_idx].cpu().numpy(),
        down_lbl[down_idx].tolist(),
        name="fps_pycut",
    )
    # How many of the final samples fall into each superpoint.
    print(torch.unique(down_lbl[down_idx], return_counts=True))