biptv3 / code /superpoint_ops /lib_geo.py
YYYYYYUUU's picture
Add core reproduction code (binarization layers, PTv3, superpoint ops, min-repro pack)
7b95dc2 verified
Raw
History Blame Contribute Delete
42.8 kB
import os
import sys
from typing import Optional, Tuple
import numpy as np
import scipy.spatial
import torch
import torch.nn.functional as F
from scipy.special import logsumexp
from scipy.spatial import cKDTree
from torch import Tensor
try:
import segmentator
except ImportError:
segmentator = None
# Import-time side effect: prepend the parent of this file's directory to
# sys.path (presumably so modules living there resolve by absolute name —
# confirm against the package layout).
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, parent_dir)
# -----------------------------------------------------------------------------
# label utilities
# -----------------------------------------------------------------------------
def num_to_natural_numpy(group_ids, void_number=-1):
"""
Convert group IDs to contiguous natural numbers, preserving the void label.
Args:
group_ids (array-like): Input group labels, e.g., [-1, 0, 3, 4, 0, 6].
void_number (int): The label for 'void' class (only supports -1 or 0).
Returns:
array_ids (np.ndarray): Mapped IDs with contiguous natural numbers and voids untouched.
"""
group_ids = np.asarray(group_ids, dtype=int)
if void_number not in (-1, 0):
raise ValueError("void_number must be either -1 or 0.")
void_mask = group_ids == void_number
valid_ids = group_ids[~void_mask]
if valid_ids.size == 0:
return group_ids.copy() # all voids
unique_ids = np.unique(valid_ids)
remap = np.zeros(unique_ids.max() + 1, dtype=int)
if void_number == -1:
remap = np.full(unique_ids.max() + 1, -1, dtype=int)
remap[unique_ids] = np.arange(len(unique_ids))
result = remap[group_ids.clip(min=0)]
result[void_mask] = -1
else:
remap[unique_ids] = np.arange(1, len(unique_ids) + 1)
result = remap[group_ids]
return result
def num_to_natural_torch(group_ids, void_number=-1):
group_ids_tensor = group_ids.long()
device = group_ids_tensor.device
if void_number == -1:
if torch.all(group_ids_tensor == -1):
return group_ids_tensor
array_ids = group_ids_tensor.clone()
unique_values = torch.unique(array_ids[array_ids != -1])
mapping = torch.full(
(torch.max(unique_values) + 2,), -1, dtype=torch.long, device=device
)
mapping[unique_values + 1] = torch.arange(
len(unique_values), dtype=torch.long, device=device
)
array_ids = mapping[array_ids + 1]
elif void_number == 0:
if torch.all(group_ids_tensor == 0):
return group_ids_tensor
array_ids = group_ids_tensor.clone()
unique_values = torch.unique(array_ids[array_ids != 0])
mapping = torch.full(
(torch.max(unique_values) + 2,), 0, dtype=torch.long, device=device
)
mapping[unique_values] = (
torch.arange(len(unique_values), dtype=torch.long, device=device) + 1
)
array_ids = mapping[array_ids]
else:
raise ValueError("void_number must be -1 or 0")
return array_ids
# -----------------------------------------------------------------------------
# optional legacy segmentator
# -----------------------------------------------------------------------------
def gen_superpoints(points, normals, k=50, kThresh=0.01, segMinVerts=20):
    """Run the optional legacy ``segmentator`` over a kNN graph of ``points``.

    Args:
        points: coordinates passed to ``torch_cluster.knn_graph`` and to
            ``segmentator.segment_point``.
        normals: per-point normals forwarded to ``segment_point``.
        k: number of neighbours used to build the kNN graph.
        kThresh, segMinVerts: forwarded unchanged to ``segment_point``.

    Returns:
        Whatever ``segmentator.segment_point`` returns.

    Raises:
        ImportError: if ``segmentator`` or ``torch_cluster`` is unavailable.
    """
    try:
        from torch_cluster import knn_graph
    except ImportError:
        # Fall through to the combined, descriptive error below instead of
        # surfacing a bare "No module named 'torch_cluster'".
        knn_graph = None
    if segmentator is None or knn_graph is None:
        raise ImportError("segmentator and torch_cluster are required for gen_superpoints().")
    # Transposed edge index; presumably [2, E] -> [E, 2] as segment_point
    # expects — confirm against the segmentator extension.
    edges = knn_graph(points, k=k).T
    superpoint = segmentator.segment_point(points, normals, edges, kThresh, segMinVerts)
    return superpoint
def _normalize_xyz_np(xyz: np.ndarray) -> np.ndarray:
xyz = np.asarray(xyz, dtype=np.float32)
center = xyz.mean(axis=0, keepdims=True)
xyz0 = xyz - center
bbmin = xyz0.min(axis=0)
bbmax = xyz0.max(axis=0)
diag = np.linalg.norm(bbmax - bbmin)
if diag < 1e-12:
diag = 1.0
return xyz0 / diag
def _normalize_normals_np(normals: np.ndarray) -> np.ndarray:
normals = np.asarray(normals, dtype=np.float32)
nrm = np.linalg.norm(normals, axis=1, keepdims=True)
nrm = np.clip(nrm, 1e-12, None)
return normals / nrm
def _build_knn_np(xyz: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
n = xyz.shape[0]
if n <= 1:
return (
np.empty((n, 0), dtype=np.float32),
np.empty((n, 0), dtype=np.int64),
)
k_eff = min(k + 1, n)
tree = cKDTree(xyz)
dists, inds = tree.query(xyz, k=k_eff, workers=-1)
return dists[:, 1:], inds[:, 1:]
def _local_geom_features_chunked_np(
    xyz: np.ndarray,
    k_feat: int = 10,
    chunk_size: int = 8192,
) -> np.ndarray:
    """
    SPG-style local geometric features:
    linearity, planarity, scattering, verticality, elevation

    For each point, the covariance of its neighbours from ``_build_knn_np``
    is eigen-decomposed into eigenvalues l1 >= l2 >= l3 (clipped to >= 1e-12):

    * linearity   = (l1 - l2) / l1
    * planarity   = (l2 - l3) / l1
    * scattering  = l3 / l1
    * verticality = 1 - |n_z|, n = eigenvector of the smallest eigenvalue
    * elevation   = z rescaled to [0, 1] over the whole cloud (0 if flat)

    Args:
        xyz: ``[N, 3]`` coordinates; column 2 is used as height.
        k_feat: neighbours per point.
        chunk_size: points per batched eigen-decomposition.

    Returns:
        ``[N, 5]`` float32 features in the order above. If no neighbours are
        available, every point gets ``(0, 0, 1, 1, elevation)``. Empty input
        yields a ``[0, 5]`` array.
    """
    n = xyz.shape[0]
    if n == 0:
        # z.min() / z.max() below would raise on empty input.
        return np.empty((0, 5), dtype=np.float32)
    _, nbrs = _build_knn_np(xyz, k_feat)
    k_eff = nbrs.shape[1]
    feat = np.empty((n, 5), dtype=np.float32)
    z = xyz[:, 2]
    zmin, zmax = z.min(), z.max()
    if zmax - zmin < 1e-12:
        elevation = np.zeros(n, dtype=np.float32)
    else:
        elevation = ((z - zmin) / (zmax - zmin)).astype(np.float32)
    eps = 1e-12
    if k_eff == 0:
        feat[:, 0] = 0.0
        feat[:, 1] = 0.0
        feat[:, 2] = 1.0
        feat[:, 3] = 1.0
        feat[:, 4] = elevation
        return feat
    for s in range(0, n, chunk_size):
        e = min(s + chunk_size, n)
        pts = xyz[nbrs[s:e]]  # [B, k, 3]
        X = pts - pts.mean(axis=1, keepdims=True)
        cov = np.matmul(X.transpose(0, 2, 1), X) / float(k_eff)
        # eigh: eigenvalues ascending, eigenvectors stored as columns.
        evals, evecs = np.linalg.eigh(cov.astype(np.float64))
        evals = np.clip(evals, eps, None)
        l3, l2, l1 = evals[:, 0], evals[:, 1], evals[:, 2]
        denom = np.maximum(l1, eps)
        n_local = evecs[:, :, 0]
        feat[s:e, 0] = ((l1 - l2) / denom).astype(np.float32)
        feat[s:e, 1] = ((l2 - l3) / denom).astype(np.float32)
        feat[s:e, 2] = (l3 / denom).astype(np.float32)
        feat[s:e, 3] = (1.0 - np.abs(n_local[:, 2])).astype(np.float32)
        feat[s:e, 4] = elevation[s:e]
    return feat
def _build_adj_graph_np(
    xyz: np.ndarray,
    k_adj: int = 10,
    mutual: bool = False,
    undirected: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build a kNN adjacency graph as parallel ``(src, dst)`` uint32 arrays.

    Args:
        xyz: coordinates handed to ``_build_knn_np``.
        k_adj: neighbours per point.
        mutual: keep edge i->j only if j->i is also a kNN edge.
        undirected: append the reversed copy of every edge. This does not
            deduplicate, so a pair already present in both directions (always
            true when ``mutual`` is set) is listed twice per direction.

    Returns:
        ``(src, dst)`` uint32 arrays of equal length, self-loops removed.
    """
    _, nbrs = _build_knn_np(xyz, k_adj)
    num_pts, per_pt = nbrs.shape
    if per_pt == 0:
        return np.empty((0,), dtype=np.uint32), np.empty((0,), dtype=np.uint32)
    src = np.repeat(np.arange(num_pts, dtype=np.uint32), per_pt)
    dst = nbrs.reshape(-1).astype(np.uint32)
    not_loop = src != dst
    src, dst = src[not_loop], dst[not_loop]
    if mutual:
        # One 64-bit key per directed edge makes "is the reverse edge present"
        # a single vectorised membership test.
        width = np.uint64(num_pts)
        fwd_key = src.astype(np.uint64) * width + dst.astype(np.uint64)
        bwd_key = dst.astype(np.uint64) * width + src.astype(np.uint64)
        has_reverse = np.isin(fwd_key, bwd_key, assume_unique=False)
        src, dst = src[has_reverse], dst[has_reverse]
    if undirected:
        src, dst = (
            np.concatenate([src, dst], axis=0),
            np.concatenate([dst, src], axis=0),
        )
    return src.astype(np.uint32, copy=False), dst.astype(np.uint32, copy=False)
def _edge_weights_chunked_np(
Y: np.ndarray, # [D, N], Fortran order
src: np.ndarray,
dst: np.ndarray,
lam: float = 5.0,
sigma: float = 0.5,
chunk_size: int = 1_000_000,
) -> np.ndarray:
sigma2 = max(sigma * sigma, 1e-12)
num_edges = src.shape[0]
ew = np.empty(num_edges, dtype=np.float32)
f = Y.T # [N, D]
for s in range(0, num_edges, chunk_size):
e = min(s + chunk_size, num_edges)
diff = f[src[s:e]] - f[dst[s:e]]
dist2 = np.sum(diff * diff, axis=1)
ew[s:e] = lam * np.exp(-dist2 / sigma2)
return ew
def _edges_to_forward_star(
n: int,
src: np.ndarray,
dst: np.ndarray,
ew: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
if src.size == 0:
first_edge = np.zeros(n + 1, dtype=np.uint32)
adj_vertices = np.empty((0,), dtype=np.uint32)
edge_weights = np.empty((0,), dtype=np.float32)
return first_edge, adj_vertices, edge_weights
order = np.argsort(src, kind="stable")
src = src[order]
dst = dst[order]
ew = ew[order]
counts = np.bincount(src.astype(np.int64), minlength=n)
first_edge = np.zeros(n + 1, dtype=np.uint32)
first_edge[1:] = np.cumsum(counts, dtype=np.uint32)
return (
first_edge,
dst.astype(np.uint32, copy=False),
ew.astype(np.float32, copy=False),
)
def _relabel_contiguous_np(labels: np.ndarray) -> np.ndarray:
_, inv = np.unique(labels, return_inverse=True)
return inv.astype(np.int32)
# -----------------------------------------------------------------------------
# superpoint-aware fps
# -----------------------------------------------------------------------------
def superpoint_fps(
    xyz: Tensor,
    labels: Tensor,
    k: int,
    min_segment_points: int = 10,
    gamma: float = 0.7,
    eps: float = 1e-4,
    base_scale: float = 1.0,
    deterministic_start: bool = False,
    return_probs: bool = False,
):
    """Superpoint-aware farthest point sampling.

    FPS in which each candidate's squared distance is multiplied by a weight
    favouring points from small superpoints: ``(1 / size) ** gamma``, rescaled
    into ``[min_weight, 1]`` per batch item. Points whose superpoint has fewer
    than ``min_segment_points`` members are never selected.

    Args:
        xyz: ``[B, N, 3]`` coordinates.
        labels: ``[B, N]`` non-negative superpoint IDs.
        k: samples per batch item, ``1 <= k <= N``.
        min_segment_points: minimum superpoint size for eligibility.
        gamma: exponent on the inverse segment size in the FPS weight.
        eps: smoothing added to segment counts and used as a floor in several
            normalisations.
        base_scale: scales the minimum weight (relative to the ratio of the
            smallest to largest present segment size).
        deterministic_start: start at the highest-probability point instead
            of sampling the start point.
        return_probs: also return the per-point start distribution.

    Returns:
        ``idx`` (``[B, k]`` long), or ``(idx, probs)`` with ``probs`` of shape
        ``[B, N]`` when ``return_probs`` is set.

    Raises:
        ValueError: bad ``labels`` shape, negative labels, ``k`` outside
            ``1..N``, or fewer than ``k`` eligible points in a batch item.
    """
    B, N, _ = xyz.shape
    if labels.shape != (B, N):
        raise ValueError(f"labels shape must be (B, N) = {(B, N)}, got {labels.shape}")
    if (labels < 0).any():
        raise ValueError("labels must be non-negative")
    if not (0 < k <= N):
        raise ValueError(f"k must be in 1…N={N}, got {k}")
    device, dtype = xyz.device, xyz.dtype
    long_labels = labels.long()
    max_id_val = long_labels.max().item() if N else 0
    # Per-superpoint sizes; +eps keeps absent IDs away from division by zero.
    counts = torch.zeros(B, max_id_val + 1, device=device, dtype=torch.float)
    counts.scatter_add_(1, long_labels, torch.ones_like(long_labels, dtype=torch.float))
    counts = counts + eps
    seg_large_enough = counts >= float(min_segment_points)
    eligible_mask = seg_large_enough.gather(1, long_labels)
    if (eligible_mask.sum(dim=1) < k).any():
        raise ValueError(
            "At least one batch item has fewer than k eligible points "
            f"(threshold={min_segment_points})."
        )
    # Start distribution: inverse segment size over eligible points.
    label_counts = counts.gather(1, long_labels)
    inv_probs = torch.where(
        eligible_mask,
        1.0 / label_counts,
        torch.zeros_like(label_counts),
    )
    inv_sum = inv_probs.sum(dim=1, keepdim=True).clamp(min=eps)
    probs = inv_probs / inv_sum
    # Smallest / largest size among IDs that actually occur sets the floor
    # of the FPS weight.
    present_mask = counts > (eps + 1e-9)
    counts_for_min = torch.where(
        present_mask, counts, torch.full_like(counts, float("inf"))
    )
    smallest_b = counts_for_min.min(1, keepdim=True).values
    min_counts_b = torch.where(
        torch.isinf(smallest_b), torch.ones_like(smallest_b), smallest_b
    )
    counts_for_max = torch.where(present_mask, counts, torch.zeros_like(counts))
    max_counts_b = counts_for_max.max(1, keepdim=True).values.clamp(min=eps)
    size_ratio_b = (min_counts_b / max_counts_b).clamp(1e-3, 1.0)
    min_weight_b = (base_scale * size_ratio_b).clamp(min=1e-5)
    score_bias_weights = inv_probs ** gamma
    max_bias = score_bias_weights.max(1, keepdim=True)[0].clamp(min=eps)
    score_bias_weights = score_bias_weights / max_bias
    prob_weights = (1.0 - min_weight_b) * score_bias_weights + min_weight_b
    prob_weights = torch.where(eligible_mask, prob_weights, torch.zeros_like(prob_weights))
    dist = torch.full((B, N), float("inf"), device=device, dtype=dtype)
    dist = torch.where(eligible_mask, dist, torch.zeros_like(dist))
    idx = torch.zeros(B, k, dtype=torch.long, device=device)
    batch_idx = torch.arange(B, device=device)
    if deterministic_start:
        idx[:, 0] = probs.argmax(1)
    else:
        # Smooth only over eligible points. Adding eps to every point would
        # give ineligible points (and, for large N, a sizeable share of the
        # total mass) a chance to be the first sample.
        safe_probs = torch.where(eligible_mask, probs + eps, torch.zeros_like(probs))
        safe_probs = safe_probs / safe_probs.sum(1, keepdim=True)
        idx[:, 0] = torch.multinomial(safe_probs, 1).squeeze(1)
    selected_xyz = torch.zeros(B, k, 3, device=device, dtype=dtype)
    selected_xyz[batch_idx, 0] = xyz[batch_idx, idx[:, 0]]
    dist = torch.minimum(
        dist,
        ((xyz - selected_xyz[:, 0:1]).pow(2)).sum(-1),
    )
    for i in range(1, k):
        scores = dist * prob_weights
        # Ineligible points score 0 just like exhausted ones; push them below
        # zero so argmax cannot land on them even when every eligible
        # distance has collapsed to 0 (e.g. many duplicate points).
        scores = torch.where(eligible_mask, scores, torch.full_like(scores, -1.0))
        idx[:, i] = scores.argmax(1)
        selected_xyz[batch_idx, i] = xyz[batch_idx, idx[:, i]]
        dist = torch.minimum(
            dist,
            ((xyz - selected_xyz[:, i:i + 1]).pow(2)).sum(-1),
        )
        dist = torch.where(eligible_mask, dist, torch.zeros_like(dist))
    return (idx, probs) if return_probs else idx
# -----------------------------------------------------------------------------
# superpoint pooling
# -----------------------------------------------------------------------------
def get_spt_centers(
    x: torch.Tensor,
    spts: torch.Tensor,
    reduce: str = "mean",
):
    """Pool per-point features into per-superpoint features, padded per batch.

    Args:
        x: ``[B, N, C]`` point features.
        spts: ``[B, N]`` non-negative superpoint IDs, local to each batch item.
        reduce: ``"mean"`` or ``"max"``.

    Returns:
        pooled_BKC: ``[B, K, C]`` pooled features, where K is the largest
            ``max(ID) + 1`` in the batch. Rows for IDs with no points (gaps
            in the IDs or padding) are zero.
        mask: ``[B, K]`` bool, True where the superpoint has at least one point.
        sp_ids_local: ``[B, K]`` local IDs, -1 for padding rows beyond that
            item's ``max(ID) + 1``.
        counts_BK: ``[B, K]`` point count per superpoint (0 when empty/padding).

    Raises:
        ValueError: for an unsupported ``reduce``.
    """
    B, N, C = x.shape
    dev = x.device
    spts = spts.long()
    sp_counts = spts.amax(dim=1) + 1
    # Offset each batch item's IDs so all superpoints share one flat index space.
    offsets = torch.cat(
        [torch.zeros(1, dtype=sp_counts.dtype, device=dev), sp_counts.cumsum(0)]
    )[:-1]
    spts_global = spts + offsets[:, None]
    x_flat = x.reshape(-1, C)
    spts_flat = spts_global.reshape(-1)
    tot_sp = int((offsets[-1] + sp_counts[-1]).item())
    pooled = torch.zeros(tot_sp, C, dtype=x.dtype, device=dev)
    counts = torch.zeros(tot_sp, 1, dtype=x.dtype, device=dev)
    if reduce == "mean":
        pooled.index_add_(0, spts_flat, x_flat)
        counts.index_add_(0, spts_flat, torch.ones_like(x_flat[:, :1]))
        pooled = pooled / counts.clamp(min=1e-6)
    elif reduce == "max":
        pooled.fill_(-float("inf"))
        pooled = pooled.index_reduce(0, spts_flat, x_flat, reduce="amax")
        counts.index_add_(0, spts_flat, torch.ones_like(x_flat[:, :1]))
    else:
        raise ValueError(f"Unsupported reduction: {reduce}")
    # Scatter the flat superpoints back into a padded [B, K] layout; padding
    # rows read an extra all-zero row appended at index tot_sp.
    K = int(sp_counts.max().item())
    row_ids = torch.arange(K, device=dev).unsqueeze(0).expand(B, -1)
    valid_mask = row_ids < sp_counts.unsqueeze(1)
    gather_ix = torch.where(
        valid_mask,
        offsets[:, None] + row_ids,
        torch.full_like(row_ids, fill_value=tot_sp),
    )
    pooled_ext = torch.cat([pooled, torch.zeros(1, C, device=dev, dtype=x.dtype)], dim=0)
    counts_ext = torch.cat([counts, torch.zeros(1, 1, device=dev, dtype=x.dtype)], dim=0)
    pooled_BKC = pooled_ext[gather_ix]
    counts_BK = counts_ext[gather_ix].squeeze(-1)
    mask = counts_BK > 0
    # Select rather than multiply: in "max" mode empty slots still hold -inf,
    # and -inf * 0 would produce NaN.
    pooled_BKC = torch.where(mask.unsqueeze(-1), pooled_BKC, torch.zeros_like(pooled_BKC))
    sp_ids_local = torch.where(
        valid_mask,
        row_ids,
        torch.full_like(row_ids, -1, dtype=torch.long),
    )
    return pooled_BKC, mask, sp_ids_local, counts_BK
# -----------------------------------------------------------------------------
# distances and coverage
# -----------------------------------------------------------------------------
def masked_pairwise_distance(pc1, pc2, mask1, mask2, invalid_val=1e6):
    """Batched Euclidean distance matrix with invalid pairs overwritten.

    Args:
        pc1: ``[B, N, D]`` and pc2: ``[B, M, D]`` point sets.
        mask1: ``[B, N]`` and mask2: ``[B, M]`` validity masks; a pair is valid
            when the product of its two mask entries is non-zero.
        invalid_val: value written into invalid pairs.

    Returns:
        ``[B, N, M]`` distances.
    """
    sq_norm_a = (pc1 ** 2).sum(dim=2, keepdim=True)   # [B, N, 1]
    sq_norm_b = (pc2 ** 2).sum(dim=2)[:, None, :]     # [B, 1, M]
    cross = torch.bmm(pc1, pc2.transpose(1, 2))       # [B, N, M]
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, clamped against round-off.
    dist = (sq_norm_a - 2 * cross + sq_norm_b).clamp(min=0).sqrt()
    pair_ok = (mask1[:, :, None] * mask2[:, None, :]).bool()
    return torch.where(pair_ok, dist, torch.full_like(dist, invalid_val))
def coverage_ratio(samples: Tensor, labels: Tensor) -> float:
    """Fraction of distinct label values hit by the sampled indices.

    Args:
        samples: ``[B, S]`` indices into dim 1 of ``labels``.
        labels: ``[B, N]`` labels. Distinct values are counted over the whole
            tensor, not per batch item.
    """
    hit_labels = labels.gather(1, samples)
    n_hit = hit_labels.unique().numel()
    n_all = labels.unique().numel()
    return n_hit / n_all
# -----------------------------------------------------------------------------
# fps
# -----------------------------------------------------------------------------
@torch.no_grad()
def fps(xyz: torch.Tensor, k: int) -> torch.LongTensor:
    """Farthest point sampling on a single point cloud with a random start.

    Args:
        xyz: ``[N, D]`` coordinates.
        k: number of indices to return; ``k <= 0`` yields an empty tensor.

    Returns:
        ``[k]`` long tensor of indices into ``xyz``.
    """
    N, dev = xyz.size(0), xyz.device
    sel = torch.empty(max(k, 0), dtype=torch.long, device=dev)
    if k <= 0:
        return sel
    sel[0] = torch.randint(0, N, (1,), device=dev)
    # Start from +inf rather than a finite sentinel. With a finite cap, every
    # squared distance above it is clipped to the same value and FPS degrades
    # to picking the lowest index among far points. Keep xyz's float dtype so
    # float64 inputs retain full precision.
    dist_dtype = xyz.dtype if xyz.is_floating_point() else torch.float32
    dist2 = torch.full((N,), float("inf"), device=dev, dtype=dist_dtype)
    for i in range(1, k):
        d = ((xyz - xyz[sel[i - 1]]) ** 2).sum(1)
        dist2 = torch.minimum(dist2, d)
        sel[i] = torch.argmax(dist2)
    return sel
def fps_np(point, npoint):
    """NumPy farthest point sampling with a random start (uses ``np.random``).

    Args:
        point: ``[N, D]`` array; the first three columns are the coordinates.
        npoint: number of points to sample.

    Returns:
        ``(point[ids], ids)`` where ``ids`` is an int32 array of length
        ``npoint``.
    """
    N = point.shape[0]
    xyz = point[:, :3]
    centroids = np.zeros((npoint,), dtype=np.int64)
    # +inf start: a finite cap (formerly 1e10) clips every larger squared
    # distance to the same value and breaks the farthest-point choice.
    distance = np.full((N,), np.inf)
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        dist = np.sum((xyz - xyz[farthest, :]) ** 2, -1)
        closer = dist < distance
        distance[closer] = dist[closer]
        farthest = np.argmax(distance, -1)
    ids = centroids.astype(np.int32)
    return point[ids], ids
# -----------------------------------------------------------------------------
# ppf
# -----------------------------------------------------------------------------
def calc_ppf_np(points, point_normals, patches, patch_normals):
    """Point-pair features between each centre point and its patch samples.

    Args:
        points, point_normals: ``[N, 3]`` centre points and their normals.
        patches, patch_normals: ``[N, S, 3]`` sample points and their normals.

    Returns:
        ``[N, S, 4]`` array ``[d, a1, a2, a3]``: the centre-to-sample distance
        and three angles divided by pi (so in [0, 1]), between (centre normal,
        offset), (sample normal, offset) and (centre normal, sample normal).
    """
    def unsigned_angle(u, v):
        # Angle via atan2(|u x v|, u . v), normalised by pi.
        sin_part = np.linalg.norm(np.cross(u, v), axis=-1, keepdims=True)
        cos_part = np.sum(u * v, axis=-1, keepdims=True)
        return np.arctan2(sin_part, cos_part) / np.pi

    _, n_samples, _ = patches.shape
    centre = np.repeat(np.expand_dims(points, axis=1), n_samples, axis=1)
    centre_n = np.repeat(np.expand_dims(point_normals, axis=1), n_samples, axis=1)
    offset = patches - centre
    dist = np.linalg.norm(offset, axis=-1, keepdims=True)
    return np.concatenate(
        [
            dist,
            unsigned_angle(centre_n, offset),
            unsigned_angle(patch_normals, offset),
            unsigned_angle(centre_n, patch_normals),
        ],
        axis=-1,
    )
def calc_ppf_gpu(points, point_normals, patches, patch_normals):
    """Torch point-pair features between centre points and their patch samples.

    Args:
        points, point_normals: ``[N, 3]`` centre points and normals.
        patches, patch_normals: ``[N, S, 3]`` sample points and normals.

    Returns:
        ``[N, S, 4]`` tensor ``[d, a1, a2, a3]``: the centre-to-sample distance
        and angles / pi between (centre normal, offset), (sample normal,
        offset) and (centre normal, sample normal).
    """
    n_samples = patches.shape[1]

    def unsigned_angle(u, v):
        cr = torch.cross(u, v, dim=-1)
        sin_part = torch.sqrt(torch.sum(cr ** 2, dim=-1, keepdim=True))
        cos_part = torch.sum(u * v, dim=-1, keepdim=True)
        return torch.atan2(sin_part, cos_part) / np.pi

    centre = torch.unsqueeze(points, dim=1).expand(-1, n_samples, -1)
    centre_n = torch.unsqueeze(point_normals, dim=1).expand(-1, n_samples, -1)
    offset = patches - centre
    dist = torch.sqrt(torch.sum(offset ** 2, dim=-1, keepdim=True))
    return torch.cat(
        [
            dist,
            unsigned_angle(centre_n, offset),
            unsigned_angle(patch_normals, offset),
            unsigned_angle(centre_n, patch_normals),
        ],
        dim=-1,
    )
def calc_ppf_batch(points, point_normals, patches, patch_normals):
    """Batched point-pair features.

    Args:
        points, point_normals: ``[B, N, 3]`` centre points and normals.
        patches, patch_normals: ``[B, N, S, 3]`` sample points and normals.

    Returns:
        ``[B, N, S, 4]`` tensor ``[d, a1, a2, a3]``: the centre-to-sample
        distance and angles / pi between (centre normal, offset), (sample
        normal, offset) and (centre normal, sample normal).
    """
    _, _, n_samples, _ = patches.shape

    def unsigned_angle(u, v):
        sin_part = torch.norm(torch.cross(u, v, dim=-1), dim=-1, keepdim=True)
        cos_part = torch.sum(u * v, dim=-1, keepdim=True)
        return torch.atan2(sin_part, cos_part) / np.pi

    centre = points.unsqueeze(2).expand(-1, -1, n_samples, -1)
    centre_n = point_normals.unsqueeze(2).expand(-1, -1, n_samples, -1)
    offset = patches - centre
    dist = torch.norm(offset, dim=-1, keepdim=True)
    return torch.cat(
        [
            dist,
            unsigned_angle(centre_n, offset),
            unsigned_angle(patch_normals, offset),
            unsigned_angle(centre_n, patch_normals),
        ],
        dim=-1,
    )
# -----------------------------------------------------------------------------
# misc geometry
# -----------------------------------------------------------------------------
def calc_patch_scale(xyz):
    """Largest nearest-neighbour distance within each batch item.

    Args:
        xyz: ``[B, N, D]`` points.

    Returns:
        ``[B]`` tensor holding ``max_i min_{j != i} ||x_i - x_j||``; ``inf``
        when ``N == 1``.
    """
    pairwise = torch.cdist(xyz, xyz)
    _, n_pts, _ = pairwise.shape
    self_pair = torch.eye(n_pts, device=xyz.device, dtype=torch.bool)[None]
    pairwise.masked_fill_(self_pair, float("inf"))  # ignore i == j
    nearest = pairwise.min(dim=-1).values
    return nearest.max(dim=-1).values
def query_ball_point(radius, nsample, xyz, new_xyz):
    """Ball query: up to ``nsample`` point indices within ``radius`` of each query.

    Args:
        radius: float, or tensor of shape ``(B,)`` with a per-batch radius.
        nsample: slots returned per query.
        xyz: ``[B, N, 3]`` candidate points.
        new_xyz: ``[B, S, 3]`` query centres.

    Returns:
        ``[B, S, nsample]`` long indices into ``xyz``. Neighbours are the
        lowest-indexed points inside the ball (ascending index order, not
        necessarily the nearest); unused slots repeat the first neighbour,
        including when ``nsample > N``. A query with no point inside the ball
        gets the out-of-range index ``N`` in every slot.
    """
    device = xyz.device
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    all_idx = torch.arange(N, dtype=torch.long, device=device)
    if torch.is_tensor(radius):
        assert radius.shape[0] == B, "radius must be (B,)"
        radius_val = radius.to(device).view(B, 1, 1)
    else:
        radius_val = float(radius)
    dists = torch.cdist(new_xyz, xyz)
    group_idx = all_idx.view(1, 1, N).expand(B, S, N).clone()
    group_idx[dists > radius_val] = N  # N marks "outside the ball"
    group_idx = group_idx.sort(dim=-1)[0][..., :nsample]
    if group_idx.shape[-1] < nsample:
        # Fewer candidates than requested slots: pad with the marker so every
        # query ends up with exactly nsample entries after the fill below.
        pad = group_idx.new_full((B, S, nsample - group_idx.shape[-1]), N)
        group_idx = torch.cat([group_idx, pad], dim=-1)
    first_valid = group_idx[..., :1].expand(-1, -1, nsample)
    mask = group_idx == N
    group_idx[mask] = first_valid[mask]
    return group_idx
# -----------------------------------------------------------------------------
# geometric features
# -----------------------------------------------------------------------------
def compute_geo_feats_np(
    P: np.ndarray,
    n_d_xyz: int = 2048,
    r: float = 0.1,
    eps: float = 1e-4,
    chunk_p: int = 64_000,
    chunk_q: int = 256_000,
) -> np.ndarray:
    """Eigenvalue shape features of every point's neighbourhood.

    ``n_d_xyz`` reference points Q are drawn from P with ``fps_np``. For each
    p in P, the second-moment matrix
    ``C = mean_{q in Q, ||q - p|| < r} (q - p)(q - p)^T`` (taken about p itself,
    not about the neighbourhood centroid) is regularised with ``eps * I``.
    With eigenvalues l1 >= l2 >= l3, the features are
    ``[(l1 - l2) / l1, (l2 - l3) / l1, l3 / l1]``. A point with no
    reference point in range gets ``C = eps * I``, i.e. ``(0, 0, 1)``.

    Args:
        P: ``[N, 3]`` points (cast to float32).
        n_d_xyz: number of FPS reference points.
        r: neighbourhood radius.
        eps: diagonal regulariser.
        chunk_p, chunk_q: block sizes over P and Q. Each block materialises a
            ``[chunk_p, chunk_q, 3]`` difference array.

    Returns:
        ``[N, 3]`` float32 features.
    """
    P = np.asarray(P, dtype=np.float32)
    Q = fps_np(P, n_d_xyz)[0]
    Q = np.asarray(Q, dtype=np.float32)
    N = P.shape[0]
    eye3 = np.eye(3, dtype=P.dtype)
    r2 = r * r
    feats_all = []
    for p0 in range(0, N, chunk_p):
        p1 = min(p0 + chunk_p, N)
        P_blk = P[p0:p1]
        p = P_blk.shape[0]
        # Running neighbour count, sum of q, and sum of q q^T per point.
        cnt = np.zeros((p, 1), dtype=P.dtype)
        sum1 = np.zeros((p, 3), dtype=P.dtype)
        sum2 = np.zeros((p, 3, 3), dtype=P.dtype)
        for q0 in range(0, Q.shape[0], chunk_q):
            q1 = min(q0 + chunk_q, Q.shape[0])
            Q_blk = Q[q0:q1]
            diff = P_blk[:, None, :] - Q_blk[None, :, :]
            d2 = np.sum(diff * diff, axis=2)
            mask = d2 < r2
            if not mask.any():
                continue
            w = mask.astype(P.dtype)
            cnt += np.sum(w, axis=1, keepdims=True)
            sum1 += w @ Q_blk
            Q_outer = Q_blk[:, :, None] * Q_blk[:, None, :]
            sum2 += np.einsum("pq,qkl->pkl", w, Q_outer)
        # sum (q - p)(q - p)^T written in terms of the running sums. The p p^T
        # term must use the true count (0 for isolated points); only the
        # division is guarded. Clamping cnt first would add a spurious p p^T.
        safe_cnt = np.maximum(cnt, 1.0)
        C = (
            sum2
            - P_blk[:, :, None] * sum1[:, None, :]
            - sum1[:, :, None] * P_blk[:, None, :]
            + cnt[:, :, None] * (P_blk[:, :, None] * P_blk[:, None, :])
        ) / safe_cnt[:, :, None]
        C += eps * eye3
        lamb = np.flip(np.linalg.eigvalsh(C), axis=1)  # descending
        lamb1, lamb2, lamb3 = lamb[:, 0], lamb[:, 1], lamb[:, 2]
        feats_blk = np.stack(
            [(lamb1 - lamb2) / lamb1, (lamb2 - lamb3) / lamb1, lamb3 / lamb1],
            axis=1,
        )
        feats_all.append(feats_blk)
    return np.concatenate(feats_all, axis=0).astype(P.dtype)
@torch.no_grad()
def compute_geo_feats(
    P: torch.Tensor,
    Q: torch.Tensor,
    r: float = 0.1,
    eps: float = 1e-4,
    chunk_p: int = 64_000,
    chunk_q: int = 256_000,
) -> torch.Tensor:
    """Eigenvalue shape features of each point's neighbourhood among ``Q``.

    For each p in P, ``C = mean_{q in Q, ||q - p|| < r} (q - p)(q - p)^T``
    (taken about p, not the neighbourhood centroid) plus ``eps * I``. With
    eigenvalues l1 >= l2 >= l3 (clamped to >= 1e-9), the features are
    ``[(l1 - l2) / l1, (l2 - l3) / l1, l3 / l1]``. A point with no neighbour
    in range gets ``C = eps * I``, i.e. ``(0, 0, 1)``.

    Args:
        P: ``[N, 3]`` query points.
        Q: ``[M, 3]`` reference points.
        r: neighbourhood radius.
        eps: diagonal regulariser.
        chunk_p, chunk_q: block sizes over P and Q.

    Returns:
        ``[N, 3]`` features with P's dtype and device.
    """
    device = P.device
    dtype = P.dtype
    N = P.size(0)
    eye3 = torch.eye(3, device=device, dtype=dtype)
    feats_chunks = []
    for p0 in range(0, N, chunk_p):
        p1 = min(p0 + chunk_p, N)
        P_blk = P[p0:p1]
        p = p1 - p0
        # Running neighbour count, sum of q, and sum of q q^T per point.
        cnt = torch.zeros(p, 1, device=device, dtype=dtype)
        sum1 = torch.zeros(p, 3, device=device, dtype=dtype)
        sum2 = torch.zeros(p, 3, 3, device=device, dtype=dtype)
        for q0 in range(0, Q.size(0), chunk_q):
            q1 = min(q0 + chunk_q, Q.size(0))
            Q_blk = Q[q0:q1]
            dist = torch.cdist(P_blk, Q_blk, p=2)  # Euclidean, compared with r
            mask = dist < r
            if not mask.any():
                continue
            # Match P's dtype: a float32 mask would break the matmul for float64 input.
            w = mask.to(dtype)
            cnt += w.sum(1, keepdim=True)
            sum1 += w @ Q_blk
            Q_outer = Q_blk.unsqueeze(2) * Q_blk.unsqueeze(1)
            sum2 += torch.einsum("ij,jkl->ikl", w, Q_outer)
        # sum (q - p)(q - p)^T from the running sums. Use the true count in the
        # p p^T term (0 for isolated points) and guard only the division.
        cnt_safe = cnt.clamp(min=1.0)
        C = (
            sum2
            - P_blk.unsqueeze(2) * sum1.unsqueeze(1)
            - sum1.unsqueeze(2) * P_blk.unsqueeze(1)
            + cnt.unsqueeze(2) * (P_blk.unsqueeze(2) * P_blk.unsqueeze(1))
        )
        C = C / cnt_safe.unsqueeze(2) + eps * eye3
        lamb = torch.flip(
            torch.linalg.eigvalsh(C).clamp_min(1e-9),
            dims=[1],
        )  # descending
        l1, l2, l3 = lamb[:, 0], lamb[:, 1], lamb[:, 2]
        blk_feat = torch.stack(
            [(l1 - l2) / l1, (l2 - l3) / l1, l3 / l1],
            dim=1,
        )
        feats_chunks.append(blk_feat)
    return torch.cat(feats_chunks, dim=0)
# -----------------------------------------------------------------------------
# sinkhorn
# -----------------------------------------------------------------------------
@torch.no_grad()
def sinkhorn_log_torch(scores, xi=0.05, iters=3, eps=1e-12):
    """Log-domain Sinkhorn-Knopp soft assignment with equal cluster mass.

    Args:
        scores: ``[B, K]`` sample-to-cluster scores.
        xi: temperature; scores are divided by it before exponentiation.
        iters: number of alternating row/column normalisation rounds.
        eps: guard added to the final row sums.

    Returns:
        ``[B, K]`` assignment probabilities. Entries are clamped to
        ``[1e-6, 1]`` and each row is then renormalised to sum to ~1.
    """
    n_samples, n_clusters = scores.shape

    def log_of(count):
        return torch.log(torch.tensor(float(count), device=scores.device, dtype=scores.dtype))

    log_q = (scores / xi).T  # [K, B]
    log_q = log_q - log_q.max()
    log_q = log_q - torch.logsumexp(log_q, dim=(0, 1), keepdim=True)
    for _ in range(iters):
        # Each cluster (row) receives total mass 1/K ...
        log_q = log_q - (torch.logsumexp(log_q, dim=1, keepdim=True) + log_of(n_clusters))
        # ... and each sample (column) total mass 1/B.
        log_q = log_q - (torch.logsumexp(log_q, dim=0, keepdim=True) + log_of(n_samples))
    log_q = log_q + log_of(n_samples)  # columns now sum to 1
    prob = torch.exp(log_q.T).clamp(1e-6, 1.0)
    return prob / (prob.sum(1, keepdim=True) + eps)
@torch.no_grad()
def sinkhorn_chunked_log(
    rows_fn,
    N: int,
    K: int,
    tau: float = 0.05,
    iters: int = 60,
    chunk_p: int = 131_072,
    eps: float = 1e-9,
    dtype=torch.float32,
    device: torch.device = torch.device("cuda"),
):
    """Balanced entropic-OT assignment of N rows to K columns, in row chunks.

    Runs log-domain Sinkhorn iterations on the kernel ``exp(sims / tau)`` with
    uniform marginals (row mass 1/N, column mass 1/K). Similarities are
    requested chunk by chunk, so the full ``[N, K]`` matrix is never held.

    Args:
        rows_fn: callable ``(i0, i1) -> [i1 - i0, K]`` similarity tensor for
            rows ``i0..i1-1``. It is called once per chunk per iteration, plus
            once per chunk in the final pass. Its output is clamped in place
            to ``>= eps``.
        N, K: number of rows and columns.
        tau: temperature.
        iters: Sinkhorn iterations.
        chunk_p: rows per chunk.
        eps: similarity floor; also added to the returned column masses.
        dtype, device: dtype and device of the internal tensors.

    Returns:
        ``(labels, mass)``: ``labels`` is an ``[N]`` long tensor with the argmax
        column per row of the final plan; ``mass`` is a ``[K]`` tensor of the
        plan's column sums (+eps).
    """
    log_r = -torch.log(torch.tensor(float(N), dtype=dtype, device=device))
    log_c = -torch.log(torch.tensor(float(K), dtype=dtype, device=device))
    log_v = torch.zeros(K, dtype=dtype, device=device)
    for _ in range(iters):
        log_col_sum = torch.full((K,), -float("inf"), dtype=dtype, device=device)
        for i0 in range(0, N, chunk_p):
            sims = rows_fn(i0, min(i0 + chunk_p, N)).clamp_min_(eps)
            log_K = sims / tau
            log_u = log_r - torch.logsumexp(log_K + log_v, dim=1)  # row scaling
            blk = log_u[:, None] + log_K + log_v
            log_col_sum = torch.logaddexp(log_col_sum, torch.logsumexp(blk, dim=0))
        # log_col_sum holds the column sums of diag(u) K diag(v), i.e. they
        # already include v. The update v <- c / (K^T u) is therefore
        # v <- v * c / col_sum. Without the "+ log_v" term, v oscillates
        # instead of converging to the column marginal.
        log_v = log_v + log_c - log_col_sum
    labels = torch.empty(N, dtype=torch.long, device=device)
    mass = torch.zeros(K, dtype=dtype, device=device)
    for i0 in range(0, N, chunk_p):
        sims = rows_fn(i0, min(i0 + chunk_p, N)).clamp_min_(eps)
        log_K = sims / tau
        log_u = log_r - torch.logsumexp(log_K + log_v, dim=1)
        log_Y = log_u[:, None] + log_K + log_v
        Y = torch.exp(log_Y)
        labels[i0: i0 + Y.size(0)] = torch.argmax(Y, dim=1)
        mass += Y.sum(0)
    mass += eps
    return labels, mass
def sinkhorn_chunked_log_np(
    rows_fn,
    N: int,
    K: int,
    tau: float = 0.05,
    iters: int = 60,
    chunk_p: int = 131_072,
    eps: float = 1e-9,
    dtype=np.float32,
):
    """NumPy balanced entropic-OT assignment of N rows to K columns, in chunks.

    Log-domain Sinkhorn on ``exp(sims / tau)`` with uniform marginals (row mass
    1/N, column mass 1/K). Similarities come from ``rows_fn`` chunk by chunk.

    Args:
        rows_fn: callable ``(i0, i1) -> [i1 - i0, K]`` similarities (cast to
            ``dtype`` and clipped to ``>= eps``).
        N, K: number of rows and columns.
        tau: temperature.
        iters: Sinkhorn iterations.
        chunk_p: rows per chunk.
        eps: similarity floor; also added to the returned column masses.
        dtype: working float dtype.

    Returns:
        ``(labels, mass)``: ``[N]`` int64 argmax column per row of the final
        plan, and ``[K]`` column sums of that plan (+eps).
    """
    log_r = -np.log(float(N)).astype(dtype)
    log_c = -np.log(float(K)).astype(dtype)
    log_v = np.zeros((K,), dtype=dtype)
    for _ in range(iters):
        log_col_sum = np.full((K,), -np.inf, dtype=dtype)
        for i0 in range(0, N, chunk_p):
            sims = rows_fn(i0, min(i0 + chunk_p, N)).astype(dtype)
            sims = np.clip(sims, eps, None)
            log_K = sims / tau
            log_u = log_r - logsumexp(log_K + log_v, axis=1)  # row scaling
            blk = log_u[:, None] + log_K + log_v
            log_col_sum = np.logaddexp(log_col_sum, logsumexp(blk, axis=0))
        # log_col_sum already includes v (column sums of diag(u) K diag(v)), so
        # the Sinkhorn update v <- c / (K^T u) is v <- v * c / col_sum.
        log_v = log_v + log_c - log_col_sum
    labels = np.empty((N,), dtype=np.int64)
    mass = np.zeros((K,), dtype=dtype)
    for i0 in range(0, N, chunk_p):
        sims = rows_fn(i0, min(i0 + chunk_p, N)).astype(dtype)
        sims = np.clip(sims, eps, None)
        log_K = sims / tau
        log_u = log_r - logsumexp(log_K + log_v, axis=1)
        log_Y = log_u[:, None] + log_K + log_v
        Y = np.exp(log_Y)
        labels[i0: i0 + Y.shape[0]] = np.argmax(Y, axis=1)
        mass += Y.sum(axis=0)
    mass += eps
    return labels, mass
# -----------------------------------------------------------------------------
# ot-fps clustering
# -----------------------------------------------------------------------------
@torch.no_grad()
def ot_fps_cluster_large(
    xyz: torch.Tensor,
    feats: torch.Tensor,
    K: int = 64,
    r_spatial: float = 0.05,
    tau: float = 0.05,
    sinkhorn_iters: int = 60,
    outer_iters: int = 3,
    chunk_P: int = 128,
):
    """Cluster points into ``K`` groups with FPS seeding and balanced OT.

    Each outer iteration:

    1. Assigns points to centres with ``sinkhorn_chunked_log``. The similarity
       is ``exp(-dist / sigma2) * |<f, c_f>|``, restricted to centres within
       ``r_spatial``; a point with none in range keeps its nearest centre.
    2. Re-seeds centres with transported mass < 1e-4 at the points farthest
       from the live centres.
    3. Moves each centre to the mean position / normalised mean feature of
       its assigned points lying within ``r_spatial`` of it. Centres without
       such points keep their current position and feature.

    Args:
        xyz: ``[N, 3]`` coordinates.
        feats: ``[N, D]`` features (L2-normalised internally).
        K: number of clusters.
        r_spatial: spatial radius for the similarity and centre updates.
        tau, sinkhorn_iters, chunk_P: forwarded to ``sinkhorn_chunked_log``.
        outer_iters: assignment / update rounds.

    Returns:
        ``[N]`` long cluster labels from the last assignment.
    """
    device, N = xyz.device, xyz.size(0)
    feats = F.normalize(feats, dim=1)
    sigma2 = ((xyz - xyz.mean(0, keepdim=True)) ** 2).sum(1).mean()
    EPS_TIE = 1e-6
    centre_idx = fps(xyz, K)
    c_xyz = xyz[centre_idx].clone()
    c_fea = feats[centre_idx].clone()
    # Tiny per-column offset to break exact ties between centres.
    tie_break = EPS_TIE * torch.arange(K, device=device)

    # c_xyz / c_fea are only ever updated in place, so one closure suffices.
    def rows_fn(i0: int, i1: int) -> torch.Tensor:
        pts = xyz[i0:i1]
        fea = feats[i0:i1]
        # NOTE: plain (not squared) Euclidean distance, both for the radius
        # test and inside the kernel; the NumPy variant uses squared distances.
        dist = torch.cdist(pts, c_xyz)
        mask = dist <= r_spatial
        empty = mask.logical_not().all(dim=1)
        if empty.any():
            nearest = torch.argmin(dist[empty], dim=1)
            mask[empty, nearest] = True
        s_geo = torch.exp(-dist / sigma2) * mask
        s_fea = torch.abs(torch.einsum("id,kd->ik", fea, c_fea))
        return s_geo * s_fea + tie_break

    for _ in range(outer_iters):
        labels, mass = sinkhorn_chunked_log(
            rows_fn,
            N,
            K,
            tau=tau,
            iters=sinkhorn_iters,
            chunk_p=chunk_P,
            dtype=xyz.dtype,
            device=device,
        )
        dead = mass < 1e-4
        if dead.any():
            dist_live = torch.cdist(xyz, c_xyz[~dead])
            far_idx = torch.topk(dist_live.min(1).values, int(dead.sum().item())).indices
            c_xyz[dead] = xyz[far_idx]
            c_fea[dead] = feats[far_idx]
        # Accumulate into fresh buffers so that clusters without in-range
        # members (including just re-seeded ones) keep their centre instead
        # of collapsing to the origin with a zero feature.
        dist_to_c = torch.norm(xyz - c_xyz[labels], dim=1)
        in_range = dist_to_c <= r_spatial
        w = in_range.to(xyz.dtype)
        sum_xyz = torch.zeros_like(c_xyz).index_add_(0, labels, xyz * w.unsqueeze(1))
        sum_fea = torch.zeros_like(c_fea).index_add_(
            0, labels, feats * in_range.to(feats.dtype).unsqueeze(1)
        )
        mass = torch.zeros_like(mass).index_add_(0, labels, w)
        live = mass > 0
        c_xyz[live] = sum_xyz[live] / mass[live, None]
        c_fea[live] = F.normalize(sum_fea[live] / mass[live, None].to(feats.dtype), dim=1)
    return labels
def ot_fps_cluster_large_np(
    xyz: np.ndarray,
    feats: np.ndarray,
    K: int = 64,
    r_spatial: float = 0.05,
    tau: float = 0.05,
    sinkhorn_iters: int = 10,
    outer_iters: int = 3,
    chunk_P: int = 128,
    use_scipy_cdist: bool = True,
):
    """NumPy FPS-seeded balanced-OT clustering into ``K`` groups.

    Each outer iteration assigns points with ``sinkhorn_chunked_log_np``. The
    similarity is ``exp(-d^2 / sigma2) * |<f, c_f>|`` over centres within
    ``r_spatial`` (squared distances; a point with none in range keeps its
    nearest centre). Centres with transported mass < 1e-4 are re-seeded at the
    points farthest from the live centres. Each centre then moves to the mean
    position / normalised summed feature of its assigned points within
    ``r_spatial``; centres without such points keep their current value.

    Args:
        xyz: ``[N, 3]`` coordinates.
        feats: ``[N, D]`` features (L2-normalised internally).
        K: number of clusters.
        r_spatial: spatial radius.
        tau, sinkhorn_iters, chunk_P: forwarded to ``sinkhorn_chunked_log_np``.
        outer_iters: assignment / update rounds.
        use_scipy_cdist: use ``scipy.spatial.distance.cdist`` for the
            point-to-centre distances instead of broadcasting.

    Returns:
        ``[N]`` int64 cluster labels from the last assignment.
    """
    N = xyz.shape[0]
    feats = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-9)
    sigma2 = np.mean(np.linalg.norm(xyz - xyz.mean(0), axis=1) ** 2)
    EPS_TIE = 1e-6
    c_xyz, c_idx = fps_np(xyz, K)
    c_fea = feats[c_idx].copy()
    # Tiny per-column offset to break exact ties between centres.
    tie_break = EPS_TIE * np.arange(K)

    # c_xyz / c_fea are only updated in place, so one closure suffices.
    def rows_fn(i0, i1):
        pts = xyz[i0:i1]
        fea = feats[i0:i1]
        if use_scipy_cdist:
            d2 = scipy.spatial.distance.cdist(pts, c_xyz, metric="sqeuclidean")
        else:
            d2 = ((pts[:, None, :] - c_xyz[None, :, :]) ** 2).sum(-1)
        mask = d2 <= r_spatial ** 2
        empty = np.all(~mask, axis=1)
        if np.any(empty):
            nearest = np.argmin(d2[empty], axis=1)
            mask[empty, nearest] = True
        s_geo = np.exp(-d2 / sigma2) * mask
        s_fea = np.abs(fea @ c_fea.T)
        return s_geo * s_fea + tie_break

    for _ in range(outer_iters):
        labels, mass = sinkhorn_chunked_log_np(
            rows_fn,
            N,
            K,
            tau=tau,
            iters=sinkhorn_iters,
            chunk_p=chunk_P,
            dtype=xyz.dtype,
        )
        dead = mass < 1e-4
        n_dead = int(dead.sum())
        if n_dead:
            # Nearest-live-centre distance via a KD-tree instead of a dense
            # [N, K, 3] difference array.
            nearest_live, _ = cKDTree(c_xyz[~dead]).query(xyz)
            far_idx = np.argpartition(nearest_live, -n_dead)[-n_dead:]
            c_xyz[dead] = xyz[far_idx]
            c_fea[dead] = feats[far_idx]
        d_to_c = np.linalg.norm(xyz - c_xyz[labels], axis=1)
        w = (d_to_c <= r_spatial).astype(xyz.dtype)
        # Fresh accumulators: clusters without in-range members must end up
        # with zero mass so they keep their (possibly re-seeded) centre.
        sum_xyz = np.zeros_like(c_xyz)
        sum_fea = np.zeros_like(c_fea)
        mass = np.zeros_like(mass)
        for k in range(K):
            sel = labels == k
            if sel.any():
                sum_xyz[k] = (xyz[sel] * w[sel, None]).sum(0)
                sum_fea[k] = (feats[sel] * w[sel, None]).sum(0)
                mass[k] = w[sel].sum()
        live = mass > 0
        c_xyz[live] = sum_xyz[live] / mass[live, None]
        fea_norm = np.linalg.norm(sum_fea[live], axis=1, keepdims=True)
        c_fea[live] = sum_fea[live] / (fea_norm + 1e-9)
    return labels.astype(np.int64)
# -----------------------------------------------------------------------------
# high-level superpoint computation
# -----------------------------------------------------------------------------
def _superpoints_pycut(xyz, normals, pycut_kwargs, device):
    """L0 cut-pursuit superpoints via libcp (the ``method="pycut"`` path)."""
    cfg = {
        "k_feat": 10, "k_adj": 10, "chunk_size": 8192,
        "use_input_normals": True, "use_xyz": False,
        "xyz_scale": 0.10, "normal_scale": 0.25,
        "lam": 0.03, "sigma": 0.5,
        "mutual": False, "undirected": True,
        "min_comp_weight": 20, "weight_decay": 0.7,
        "verbose": False,
    }
    if pycut_kwargs is not None:
        cfg.update(pycut_kwargs)
    pts = _normalize_xyz_np(np.asarray(xyz, dtype=np.float32))
    # Per-point feature matrix: geometry, optionally scaled normals / xyz.
    blocks = [
        _local_geom_features_chunked_np(
            pts, k_feat=cfg["k_feat"], chunk_size=cfg["chunk_size"],
        )
    ]
    if cfg["use_input_normals"] and normals is not None:
        unit_normals = _normalize_normals_np(np.asarray(normals, dtype=np.float32))
        blocks.append(unit_normals * cfg["normal_scale"])
    if cfg["use_xyz"]:
        blocks.append(pts * cfg["xyz_scale"])
    Y = np.hstack(blocks).astype(np.float32)
    src, dst = _build_adj_graph_np(
        pts, k_adj=cfg["k_adj"],
        mutual=cfg["mutual"], undirected=cfg["undirected"],
    )
    ew = _edge_weights_chunked_np(Y.T, src, dst, lam=1.0, sigma=cfg["sigma"])
    try:
        import libcp
    except ImportError:
        # Fall back to a locally built extension next to this file.
        sys.path.insert(0, os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "_cut_pursuit", "build", "src",
        ))
        import libcp
    _, in_component = libcp.cutpursuit(
        Y,
        src.astype(np.uint32),
        dst.astype(np.uint32),
        ew.astype(np.float32),
        float(cfg["lam"]),
        int(cfg["min_comp_weight"]),
        0,
        float(cfg["weight_decay"]),
    )
    relabeled = _relabel_contiguous_np(np.asarray(in_component, dtype=np.int32))
    return (
        torch.from_numpy(relabeled.astype(np.int64)).to(device),
        torch.from_numpy(Y).to(device),
    )


def compute_superpoints(
    xyz,
    feats=None,
    n_d_xyz=2048,
    n_clus=64,
    r_geo=0.1,
    r_clus=0.05,
    device=torch.device("cuda"),
    method="ot",
    normals=None,
    pycut_kwargs=None,
):
    """Compute superpoint labels for a point cloud.

    method:
        'ot'    -> original OT/FPS clustering (any value other than 'pycut'
                   takes this path; ``xyz`` must then be a NumPy array)
        'pycut' -> L0-cut-pursuit superpoints via libcp

    Returns:
        ``(labels, feats)``: long labels on ``device``. For 'pycut', ``feats``
        is the float32 feature matrix fed to cut pursuit. For 'ot', it is the
        given ``feats`` or the eigenvalue features from ``compute_geo_feats``.
    """
    if method == "pycut":
        return _superpoints_pycut(xyz, normals, pycut_kwargs, device)
    # Default: OT/FPS clustering.
    P = torch.from_numpy(xyz).to(device)
    if feats is None:
        ref_pts = P[fps(P, n_d_xyz)]
        feats = compute_geo_feats(P, ref_pts, r=r_geo, chunk_p=4096, chunk_q=256)
    labels = ot_fps_cluster_large(P, feats, K=n_clus, r_spatial=r_clus)
    return labels, feats
# -----------------------------------------------------------------------------
# constrained fps
# -----------------------------------------------------------------------------
def sp_constrained_fps(xyz: torch.Tensor, sp: torch.Tensor, L: int):
    """FPS seeded with one random point from every superpoint.

    Args:
        xyz: ``[N, 3]`` coordinates.
        sp: ``[N]`` non-negative superpoint IDs.
        L: total samples; must be at least ``K = sp.max() + 1``.

    Returns:
        ``[L]`` long indices. The first K are anchors (anchor k is drawn
        uniformly from superpoint k; an ID with no points gets a random point
        instead). The remaining ``L - K`` continue farthest point sampling on
        squared distances.

    Raises:
        ValueError: if ``L < K``.
    """
    device, N = xyz.device, xyz.size(0)
    K = sp.max().item() + 1
    if L < K:
        raise ValueError(f"L ({L}) must be >= number of superpoint ids ({K})")
    perm = torch.randperm(N, device=device)
    # Scatter each point's *rank* in the random permutation and keep the
    # smallest rank per superpoint; mapping that rank back through perm
    # yields a uniformly random member. Scattering perm values directly
    # would always return the smallest original index.
    ranks = torch.arange(N, device=device)
    first_rank = torch.full((K,), N, dtype=torch.long, device=device)
    first_rank.scatter_reduce_(0, sp[perm].long(), ranks, reduce="amin")
    # IDs with no points (gaps in sp) still have rank N; give them perm[0]
    # (a random point) so every index stays in range.
    anchors = torch.where(first_rank < N, perm[first_rank.clamp(max=N - 1)], perm[0])
    idx = torch.empty(L, dtype=torch.long, device=device)
    idx[:K] = anchors
    dist2 = torch.cdist(xyz, xyz[anchors]).pow(2).min(dim=1).values
    for i in range(K, L):
        nxt = torch.argmax(dist2)
        idx[i] = nxt
        dist2 = torch.minimum(dist2, ((xyz - xyz[nxt]) ** 2).sum(-1))
    return idx
def batch_random_anchor(sp: torch.Tensor):
    """
    Draw one uniformly random point index per superpoint, independently per row.

    Args:
        sp (Tensor): (B, N) integer (int64) superpoint labels in [0, K),
            K = sp.max() + 1 over the whole batch.

    Returns:
        anchors (LongTensor): (B, K). anchors[b, k] is a random point of superpoint k
        in row b, or the sentinel N when row b has no point with label k.
    """
    device = sp.device
    B, N = sp.shape
    K = int(sp.max().item()) + 1
    # Independent random permutation per row (argsort of uniform noise).
    perm = torch.argsort(torch.rand(B, N, device=device), dim=1)
    labels_perm = sp.gather(1, perm)
    # Earliest position of each label in the shuffled order ...
    positions = torch.arange(N, device=device).repeat(B, 1)
    first_pos = torch.full((B, K), N, dtype=torch.long, device=device)
    first_pos.scatter_reduce_(1, labels_perm, positions, reduce="amin")
    # ... mapped back to an original point index. Reducing over `perm` values
    # directly would always yield the smallest index, defeating the shuffle.
    found = first_pos < N
    anchors = perm.gather(1, first_pos.clamp(max=N - 1))
    return torch.where(found, anchors, torch.full_like(anchors, N))
def sp_constrained_fps_batch(
    xyz: torch.Tensor,
    sp: torch.Tensor,
    L: int,
) -> torch.LongTensor:
    """
    Batched farthest point sampling seeded with one random point per superpoint.

    Args:
        xyz (Tensor): (B, N, 3) point coordinates.
        sp (Tensor): (B, N) integer (int64) superpoint labels in [0, K), where
            K = sp.max() + 1 is taken over the whole batch.
        L (int): Number of indices per batch element; must be >= K.

    Returns:
        idx (LongTensor): (B, L) point indices. For k < K, idx[b, k] is a uniformly
        random point of superpoint k in row b; if row b has no point with label k,
        that slot is filled by farthest point sampling instead. idx[:, K:]
        continues farthest point sampling on squared Euclidean distance.

    Raises:
        ValueError: if L < K.
    """
    device = xyz.device
    B, N, _ = xyz.shape
    K = int(sp.max().item()) + 1
    if L < K:
        raise ValueError(f"L ({L}) must be >= the number of superpoints ({K}).")
    b_ids = torch.arange(B, device=device)

    # One random point per (row, label): earliest position in a per-row
    # shuffle, mapped back through the permutation.
    perm = torch.argsort(torch.rand(B, N, device=device), dim=1)
    first_pos = torch.full((B, K), N, dtype=torch.long, device=device)
    first_pos.scatter_reduce_(
        1,
        sp.gather(1, perm),
        torch.arange(N, device=device).repeat(B, 1),
        reduce="amin",
    )
    present = first_pos < N
    anchors = perm.gather(1, first_pos.clamp(max=N - 1))

    # Squared distance to the nearest anchor. Anchors of labels absent in a row
    # are ignored (inf) so they cannot attract the search.
    c_xyz = xyz.gather(1, anchors[..., None].expand(-1, -1, 3))
    d2 = torch.cdist(xyz, c_xyz, p=2).pow(2)
    d2 = d2.masked_fill(~present[:, None, :], float("inf"))
    dist2 = d2.min(dim=2).values

    # Fill slots of absent labels by FPS, updating only the rows that miss them.
    # Every row with N > 0 has at least one present label, so dist2 is finite.
    for k in (~present).any(dim=0).nonzero().flatten().tolist():
        miss = ~present[:, k]
        farthest = dist2.argmax(dim=1)
        anchors[:, k] = torch.where(miss, farthest, anchors[:, k])
        new_c = xyz[b_ids, farthest]
        updated = torch.minimum(dist2, ((xyz - new_c[:, None, :]) ** 2).sum(-1))
        dist2 = torch.where(miss[:, None], updated, dist2)

    idx = torch.empty(B, L, dtype=torch.long, device=device)
    idx[:, :K] = anchors
    for i in range(K, L):
        farthest = dist2.max(dim=1).indices
        idx[:, i] = farthest
        new_c = xyz[b_ids, farthest]
        dist2 = torch.minimum(dist2, ((xyz - new_c[:, None, :]) ** 2).sum(-1))
    return idx
def memory_efficient_fps_prob(
    xyz: Tensor,
    prob: Tensor,
    k: int,
    gamma: float = 0.5,
    eps: float = 1e-12,
    chunk_size: int = 32,
) -> Tensor:
    """
    Probability-weighted farthest point sampling, processed in point chunks.

    The first pick in each batch row is ``prob.argmax``. Every later pick maximizes
    ``nearest_sq_dist * (prob + eps) ** -gamma``, where ``nearest_sq_dist`` is the
    squared distance to the closest already-picked point. For ``gamma > 0`` this
    boosts points with lower ``prob``; ``gamma = 0`` reduces to plain FPS.

    Args:
        xyz (Tensor): (B, N, 3) coordinates.
        prob (Tensor): per-point weights; assumed (B, N) — TODO confirm with callers.
        k (int): number of points to pick per row (must be >= 1).
        gamma (float): exponent applied to ``prob``.
        eps (float): added to ``prob`` before exponentiation.
        chunk_size (int): number of points per distance-update chunk.

    Returns:
        Tensor: (B, k) long indices of the picked points.
    """
    batch, n_pts, _ = xyz.shape
    dev = xyz.device

    picked = torch.zeros(batch, k, dtype=torch.long, device=dev)
    picked[:, 0] = prob.argmax(dim=1)

    nearest = torch.full((batch, n_pts), float("inf"), device=dev)
    weight = (prob + eps) ** -gamma  # loop-invariant

    for step in range(1, k):
        prev = picked[:, step - 1]
        anchor = xyz.gather(1, prev.view(batch, 1, 1).expand(batch, 1, 3))
        # Each chunk only reads and writes its own slice, so updating in place
        # matches building a fresh buffer.
        for lo in range(0, n_pts, chunk_size):
            hi = lo + chunk_size
            sq = ((xyz[:, lo:hi, :] - anchor) ** 2).sum(-1)
            nearest[:, lo:hi] = torch.minimum(nearest[:, lo:hi], sq)
        picked[:, step] = (nearest * weight).argmax(dim=1)
    return picked
# -----------------------------------------------------------------------------
# superpoint smoothing
# -----------------------------------------------------------------------------
def _superpoint_pool(feat: torch.Tensor, spts: torch.Tensor) -> torch.Tensor:
B, N, D = feat.shape
feat_flat = feat.reshape(B * N, D)
K = int(spts.max().item()) + 1
offsets = (torch.arange(B, device=feat.device) * K).view(B, 1)
sp_offset = (spts + offsets).reshape(-1)
tot_sp = B * K
feat_sum = torch.zeros(tot_sp, D, device=feat.device, dtype=feat.dtype)
feat_sum.index_add_(0, sp_offset, feat_flat)
cnt_sum = torch.zeros(tot_sp, 1, device=feat.device, dtype=feat.dtype)
cnt_sum.index_add_(0, sp_offset, torch.ones(B * N, 1, device=feat.device, dtype=feat.dtype))
feat_avg = feat_sum / (cnt_sum + 1e-6)
feat_denoised = feat_avg[sp_offset].view(B, N, D)
return feat_denoised
# -----------------------------------------------------------------------------
# example
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Demo: cut-pursuit superpoints on a point cloud, superpoint-seeded FPS
    # down-sampling to 2048 points, then `superpoint_fps` (defined elsewhere in
    # this module) on the down-sampled cloud. Results go to
    # create_labeled_point_cloud from lib_vis.
    from lib_vis import create_labeled_point_cloud
    import open3d as o3d

    cloud = o3d.io.read_point_cloud("s3dis.ply")
    print(cloud)
    xyz = np.array(cloud.points)

    normals_path = "../normal.npy"
    normals = (
        np.load(normals_path).astype(np.float32)
        if os.path.exists(normals_path)
        else None
    )

    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- Option 1: pycut-pursuit superpoints ----
    pycut_cfg = {
        "k_feat": 10,
        "k_adj": 10,
        "chunk_size": 8192,
        "use_input_normals": normals is not None,
        "use_xyz": False,
        "xyz_scale": 0.10,
        "normal_scale": 0.25,
        "lam": 5.0,
        "sigma": 0.5,
        "mutual": False,
        "undirected": True,
        "cp_it_max": 10,
        "split_iter_num": 2,
        "split_damp_ratio": 0.7,
        "kmpp_init_num": 3,
        "kmpp_iter_num": 3,
        "K": 2,
        "min_comp_weight": 20.0,
        "verbose": False,
    }
    lbl, feats = compute_superpoints(
        xyz,
        normals=normals,
        device=device,
        method="pycut",
        pycut_kwargs=pycut_cfg,
    )

    # Superpoint-seeded FPS down-sampling to a fixed budget.
    points = torch.from_numpy(xyz).to(device)
    keep = sp_constrained_fps(points, lbl, 2048).tolist()
    sub_xyz = points[keep]
    sub_lbl = lbl[keep]
    create_labeled_point_cloud(xyz[keep], sub_lbl.tolist(), name="cls_pycut")

    # Second-stage sampling on the down-sampled cloud (batch of one).
    picked = superpoint_fps(
        sub_xyz.unsqueeze(0),
        sub_lbl.unsqueeze(0),
        k=128,
        gamma=0.1,
        base_scale=1.5,
    )[0].tolist()
    picked_lbl = sub_lbl[picked]
    create_labeled_point_cloud(
        sub_xyz[picked].cpu().numpy(),
        picked_lbl.tolist(),
        name="fps_pycut",
    )
    print(torch.unique(picked_lbl, return_counts=True))