# instruct-particulate / instruct_particulate/utils/postprocessing_utils.py
# Snapshot of commit 13116e0 ("Clean unused demo logic", verified) by rayli; 62.4 kB.
import os
import time
import heapq
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import trimesh
from numba import njit, prange
from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.spatial import cKDTree
# Default proximity threshold, in the same units as the mesh vertex
# coordinates: faces whose exact triangle-triangle distance is strictly below
# it are treated as adjacent. Used as the default ``distance_threshold`` of
# ``build_face_distance_adjacency``.
DEFAULT_FACE_GROUP_DISTANCE_THRESHOLD = 1e-4
def _get_face_centroids(mesh):
    """Fetch per-face centroids as a float32 array for nearest-neighbour queries.

    Args:
        mesh: object exposing ``triangles_center`` (e.g. a trimesh mesh).

    Returns:
        ``np.ndarray`` of dtype float32 holding one centroid per face.
    """
    centers = mesh.triangles_center
    return np.asarray(centers, dtype=np.float32)
def _query_nearest(tree, query_points):
    """Return ``(distances, indices)`` of the single nearest neighbour per query.

    ``workers=-1`` asks SciPy's ``cKDTree.query`` to spread the search over
    every available CPU core.
    """
    nearest = tree.query(query_points, k=1, workers=-1)
    return nearest
def _mesh_face_connected_components(mesh):
    """List the edge-connected face components of ``mesh``.

    Uses ``mesh.face_adjacency`` as graph edges over all face indices, with
    ``min_len=1`` so single-face components are kept as well.

    Returns:
        List of int64 arrays, one per component, holding face indices.
    """
    all_faces = np.arange(len(mesh.faces))
    groups = trimesh.graph.connected_components(
        edges=mesh.face_adjacency,
        nodes=all_faces,
        min_len=1,
    )
    dense = []
    for group in groups:
        dense.append(np.array(list(group), dtype=np.int64))
    return dense
def assign_undefined_faces_to_nearest_defined(mesh, face_part_ids):
    """Copy ``face_part_ids`` and replace each ``-1`` with the label of the closest labelled face.

    Closeness is measured between face centroids, using a KD-tree built over
    the labelled faces only. If no face is undefined, or no face is defined,
    the copy is returned unchanged.

    Args:
        mesh: mesh providing ``triangles_center``.
        face_part_ids: per-face labels where ``-1`` marks an undefined face.

    Returns:
        A filled copy of ``face_part_ids``; the input itself is not modified.
    """
    filled = face_part_ids.copy()
    missing_faces = np.flatnonzero(filled == -1)
    if missing_faces.size == 0:
        return filled
    labelled_faces = np.flatnonzero(filled != -1)
    if labelled_faces.size == 0:
        return filled

    centroids = _get_face_centroids(mesh)
    labelled_tree = cKDTree(centroids[labelled_faces])
    _, hit = _query_nearest(labelled_tree, centroids[missing_faces])
    # Map KD-tree positions (into labelled_faces) back to face indices.
    donor_faces = labelled_faces[np.atleast_1d(hit)]
    filled[missing_faces] = filled[donor_faces]
    return filled
def refine_part_ids_strict(mesh, face_part_ids):
    """
    Refine face part IDs by treating each edge-connected component (CC) independently.

    For each CC:
    - If it has any defined labels (anything other than -1), all of its faces are
      overwritten with the label that covers the largest total face area among the
      defined faces of that CC. Area ties go to the smallest label, because
      ``np.unique`` sorts labels and ``np.argmax`` returns the first maximum.
    - If all of its faces are undefined (-1), every face takes the dominant label
      of the defined CC that owns the face closest (by centroid distance) to any
      face of this CC.

    If no CC has a defined label, fully undefined CCs stay at -1.

    Args:
        mesh: trimesh object (``area_faces``, ``faces`` and centroid access are used)
        face_part_ids: part ID for each face [num_faces]

    Returns:
        refined_face_part_ids: int32 copy with the refined part ID for each face [num_faces]
    """
    face_part_ids = np.asarray(face_part_ids, dtype=np.int32).copy()
    mesh_components = _mesh_face_connected_components(mesh)
    component_dominant_part_id = {}
    undefined_components = []
    # For each connected component, find the dominant part ID by surface area.
    for comp_idx, component in enumerate(mesh_components):
        if len(component) == 0:
            continue
        component_part_ids = face_part_ids[component]
        valid_mask = component_part_ids != -1
        if not np.any(valid_mask):
            # Entirely unlabeled; resolved after all labeled CCs are known.
            undefined_components.append(comp_idx)
            continue
        valid_part_ids = component_part_ids[valid_mask].astype(np.int64)
        valid_face_areas = mesh.area_faces[component[valid_mask]]
        # Sum face areas per distinct label inside this component.
        unique_part_ids, inverse = np.unique(valid_part_ids, return_inverse=True)
        part_area_sums = np.bincount(inverse, weights=valid_face_areas)
        dominant_part_id = int(unique_part_ids[np.argmax(part_area_sums)])
        component_dominant_part_id[comp_idx] = dominant_part_id
        face_part_ids[component] = dominant_part_id
    # Components that are entirely undefined are assigned from the nearest
    # component that has a defined dominant part label.
    if undefined_components and component_dominant_part_id:
        centroids = _get_face_centroids(mesh)
        # face -> owning component id (only filled for faces used below), so
        # KD-tree hits on defined faces can be mapped back to their component.
        face_to_component = np.full(len(mesh.faces), -1, dtype=np.int32)
        defined_face_chunks = []
        undefined_face_chunks = []
        for comp_idx in component_dominant_part_id.keys():
            comp_faces = mesh_components[comp_idx]
            face_to_component[comp_faces] = comp_idx
            defined_face_chunks.append(comp_faces)
        for comp_idx in undefined_components:
            comp_faces = mesh_components[comp_idx]
            face_to_component[comp_faces] = comp_idx
            undefined_face_chunks.append(comp_faces)
        defined_faces = np.concatenate(defined_face_chunks, axis=0)
        undefined_faces = np.concatenate(undefined_face_chunks, axis=0)
        # Nearest defined face (by centroid) for every undefined face.
        tree = cKDTree(centroids[defined_faces])
        nearest_dist, nearest_local = _query_nearest(tree, centroids[undefined_faces])
        nearest_local = np.atleast_1d(nearest_local)
        nearest_dist = np.atleast_1d(nearest_dist)
        undefined_face_components = face_to_component[undefined_faces]
        nearest_defined_faces = defined_faces[nearest_local]
        nearest_defined_components = face_to_component[nearest_defined_faces]
        # Stable-sort undefined faces by their component so each component's
        # faces form one contiguous run [group_start, group_end).
        order = np.argsort(undefined_face_components, kind="mergesort")
        sorted_undefined_comps = undefined_face_components[order]
        sorted_dists = nearest_dist[order]
        sorted_nearest_defined_comps = nearest_defined_components[order]
        unique_undefined_comps, group_start = np.unique(sorted_undefined_comps, return_index=True)
        group_end = np.concatenate([group_start[1:], np.array([len(sorted_undefined_comps)])])
        for comp_idx, start, end in zip(unique_undefined_comps, group_start, group_end):
            # The component's single closest face decides which defined CC it copies.
            best_local = start + int(np.argmin(sorted_dists[start:end]))
            nearest_defined_comp = int(sorted_nearest_defined_comps[best_local])
            face_part_ids[mesh_components[int(comp_idx)]] = component_dominant_part_id[nearest_defined_comp]
    return face_part_ids
def _majority_vote_face_part_ids(mesh, part_ids, face_indices):
    """Assign each sampled face the most frequent label among its query points.

    Faces that never received a sampled query point remain ``-1``. When several
    labels tie for the highest count on a face, the smallest label wins.

    The vote is vectorised: distinct (face, label) pairs are counted with
    ``np.unique`` rather than a per-point Python loop. Negative labels are
    counted like any other label.

    Args:
        mesh: trimesh object; ``len(mesh.faces)`` sets the output length.
        part_ids: integer part ID for each sampled point [num_points].
        face_indices: face each point lies on; ``-1`` means the point is
            ignored (e.g. it lies on an edge) [num_points].

    Returns:
        ``np.int32`` face labels with unresolved faces marked as ``-1``.
    """
    num_faces = len(mesh.faces)
    face_part_ids = np.full(num_faces, -1, dtype=np.int32)
    face_indices = np.asarray(face_indices, dtype=np.int64).reshape(-1)
    # Labels beyond len(face_indices) have no face to vote for and are ignored.
    part_ids = np.asarray(part_ids, dtype=np.int64).reshape(-1)[: len(face_indices)]
    sampled = face_indices != -1
    if not np.any(sampled):
        return face_part_ids

    # Count each distinct (face, label) combination once.
    pairs = np.stack((face_indices[sampled], part_ids[sampled]), axis=1)
    unique_pairs, pair_counts = np.unique(pairs, axis=0, return_counts=True)
    # Order rows by face, then by descending count, then by ascending label,
    # so the first row of every face's run holds that face's winning label.
    order = np.lexsort((unique_pairs[:, 1], -pair_counts, unique_pairs[:, 0]))
    ranked = unique_pairs[order]
    run_start = np.ones(len(ranked), dtype=bool)
    run_start[1:] = ranked[1:, 0] != ranked[:-1, 0]
    face_part_ids[ranked[run_start, 0]] = ranked[run_start, 1]
    return face_part_ids
def find_unrefined_part_ids_for_faces(mesh, part_ids, face_indices):
    """Compute per-face labels prior to any connected-component refinement.

    Two steps are applied in order:
      1. each sampled face takes the majority label of the query points on it;
      2. faces that received no query point copy the label of the nearest
         labelled face (by centroid distance).

    Args:
        mesh: trimesh object
        part_ids: part IDs for each sampled point [num_points]
        face_indices: which face each point lies on (-1 means on edge) [num_points]

    Returns:
        Face labels after majority vote plus nearest-face fill.
    """
    voted = _majority_vote_face_part_ids(mesh, part_ids, face_indices)
    filled = assign_undefined_faces_to_nearest_defined(mesh, voted)
    return filled
def refine_face_part_ids(mesh, face_part_ids, strict=False):
    """Run the base face-label post-processing stage.

    Args:
        mesh: trimesh object (only used when ``strict`` is True).
        face_part_ids: face labels produced by the unrefined majority-vote stage.
        strict: when True, apply ``refine_part_ids_strict`` (per connected
            component relabelling); when False, return the labels unchanged.

    Returns:
        An int32 array of per-face part IDs (always a new array).
    """
    if not strict:
        # Pass-through mode still hands back an independent int32 copy.
        return np.asarray(face_part_ids, dtype=np.int32).copy()
    return refine_part_ids_strict(mesh, face_part_ids)
def point_to_triangle_distance_batch(points_batch, tri_verts_batch):
    """Compute squared distances from batched points to batched triangles.

    Args:
        points_batch: ``[M, N, 3]`` query points; the ``N`` points in row ``m``
            are measured against triangle ``m`` only. Read-only broadcast
            views (e.g. from ``np.broadcast_to``) are accepted.
        tri_verts_batch: ``[M, 3, 3]`` triangle vertex positions.

    Returns:
        ``[M, N]`` squared Euclidean distances.

    A point whose orthogonal projection lands inside a non-degenerate triangle
    gets its squared distance to the triangle's plane. Every other point gets
    the minimum squared distance to the triangle's three edges, which is also
    exact for degenerate (collinear or zero-area) triangles.

    Degeneracy tests are relative to each triangle's own edge lengths, so
    small-but-valid triangles are treated exactly like large ones.
    """
    work_dtype = np.result_type(points_batch, tri_verts_batch, np.float32)
    # Degeneracy tolerance on sin^2 of the angle between the two spanning
    # edges. It is loosened for low-precision inputs, where `denom` below (a
    # difference of nearly equal products) loses more digits to cancellation.
    rel_eps = max(1e-12, 64.0 * float(np.finfo(work_dtype).eps))

    def _edge_distance_sq(seg_start, seg_end):
        """Squared distance from every point in row m to segment m."""
        edge = seg_end - seg_start
        edge_len_sq = np.einsum("mk,mk->m", edge, edge)
        has_length = edge_len_sq > 0.0
        ap = points_batch - seg_start[:, np.newaxis, :]
        along = np.einsum("mnk,mk->mn", ap, edge)
        # Divide by the true squared length (no absolute floor, which would
        # mis-project onto short edges); zero-length edges use t = 0.
        t = along / np.where(has_length, edge_len_sq, 1.0)[:, np.newaxis]
        t = np.clip(np.where(has_length[:, np.newaxis], t, 0.0), 0.0, 1.0)
        proj = seg_start[:, np.newaxis, :] + t[:, :, np.newaxis] * edge[:, np.newaxis, :]
        diff = points_batch - proj
        return np.einsum("mnk,mnk->mn", diff, diff)

    v0 = tri_verts_batch[:, 0, :]
    v1 = tri_verts_batch[:, 1, :]
    v2 = tri_verts_batch[:, 2, :]
    edge0 = v1 - v0
    edge1 = v2 - v0
    d00 = np.einsum("mk,mk->m", edge0, edge0)
    d01 = np.einsum("mk,mk->m", edge0, edge1)
    d11 = np.einsum("mk,mk->m", edge1, edge1)
    # denom == |edge0 x edge1|^2 (Lagrange identity). Comparing it with the
    # product of squared edge lengths makes the test scale-free; an absolute
    # cut-off would reject every triangle below a fixed area.
    denom = d00 * d11 - d01 * d01
    valid_tri = denom > rel_eps * (d00 * d11)

    normals = np.cross(edge0, edge1)
    normal_norms = np.linalg.norm(normals, axis=1, keepdims=True)
    normals = normals / np.where(normal_norms > 0.0, normal_norms, 1.0)

    to_points = points_batch - v0[:, np.newaxis, :]
    dist_to_plane = np.einsum("mnk,mk->mn", to_points, normals)
    points_on_plane = (
        points_batch - dist_to_plane[:, :, np.newaxis] * normals[:, np.newaxis, :]
    )
    v = points_on_plane - v0[:, np.newaxis, :]
    d20 = np.einsum("mnk,mk->mn", v, edge0)
    d21 = np.einsum("mnk,mk->mn", v, edge1)
    safe_denom = np.where(valid_tri, denom, 1.0)[:, np.newaxis]
    bary_v = (d11[:, np.newaxis] * d20 - d01[:, np.newaxis] * d21) / safe_denom
    bary_w = (d00[:, np.newaxis] * d21 - d01[:, np.newaxis] * d20) / safe_denom
    bary_u = 1.0 - bary_v - bary_w
    inside_mask = (
        (bary_u >= -1e-10)
        & (bary_v >= -1e-10)
        & (bary_w >= -1e-10)
        & valid_tri[:, np.newaxis]
    )

    distances_sq = dist_to_plane * dist_to_plane
    outside_mask = ~inside_mask
    if np.any(outside_mask):
        min_edge_dist_sq = np.minimum(
            _edge_distance_sq(v0, v1),
            np.minimum(_edge_distance_sq(v1, v2), _edge_distance_sq(v2, v0)),
        )
        distances_sq = np.where(outside_mask, min_edge_dist_sq, distances_sq)
    return distances_sq
def resolve_point_prompt_face_ids(
    mesh,
    point_prompts,
    *,
    exact_batch_size=8192,
):
    """Resolve each point prompt to the nearest mesh face in the same coordinate frame.

    This is used for backward compatibility when saved prompt-face IDs are
    unavailable. The search is exact (up to float64 rounding) but keeps the
    work bounded by:

    - seeding each prompt's best candidate with the face whose centroid is
      nearest (one batched KD-tree query for all prompts);
    - pruning exact point-triangle checks with face AABB lower bounds: a face
      is only measured if its bounding box is closer than the current best.

    Args:
        mesh: mesh providing ``triangles`` and ``triangles_center``.
        point_prompts: ``(num_prompts, 3)`` array-like of prompt positions.
        exact_batch_size: number of candidate faces per exact-distance batch.

    Returns:
        ``(num_prompts,)`` int64 array of face indices.

    Raises:
        ValueError: if ``point_prompts`` has the wrong shape, ``exact_batch_size``
            is not positive, or the mesh has no faces.
    """
    # Work in float64 from the start. A float32 round-trip would perturb
    # float64 prompts and can flip the winner between faces that share an edge.
    point_prompts = np.asarray(point_prompts, dtype=np.float64)
    if point_prompts.ndim != 2 or point_prompts.shape[1] != 3:
        raise ValueError(
            "point_prompts must have shape (num_prompts, 3), "
            f"got {point_prompts.shape}"
        )
    if point_prompts.shape[0] == 0:
        return np.zeros((0,), dtype=np.int64)
    batch_size = int(exact_batch_size)
    if batch_size <= 0:
        raise ValueError(f"exact_batch_size must be positive, got {exact_batch_size}")

    face_verts = np.asarray(mesh.triangles, dtype=np.float64)
    if face_verts.shape[0] == 0:
        raise ValueError("cannot resolve point-prompt faces on a mesh with zero faces")
    bbox_mins = np.min(face_verts, axis=1)
    bbox_maxs = np.max(face_verts, axis=1)
    face_centroids = np.asarray(mesh.triangles_center, dtype=np.float64)
    centroid_tree = cKDTree(face_centroids)
    # One KD-tree query for all prompts; each row is independent.
    _, seed_faces = _query_nearest(centroid_tree, point_prompts)
    seed_faces = np.atleast_1d(seed_faces)

    all_face_ids = np.arange(len(face_verts), dtype=np.int64)
    face_ids = np.zeros((point_prompts.shape[0],), dtype=np.int64)
    for prompt_idx, point_prompt in enumerate(point_prompts):
        best_face_id = int(seed_faces[prompt_idx])
        best_sq = float(
            point_to_triangle_distance_batch(
                np.broadcast_to(point_prompt, (1, 1, 3)),
                face_verts[[best_face_id]],
            )[0, 0]
        )
        # The squared distance from the prompt to a face's AABB never exceeds
        # the exact distance to that face, so it is a safe pruning bound.
        axis_gap = np.maximum(
            np.maximum(bbox_mins - point_prompt, point_prompt - bbox_maxs),
            0.0,
        )
        lower_bound_sq = np.einsum("ij,ij->i", axis_gap, axis_gap)
        candidate_face_ids = all_face_ids[lower_bound_sq < best_sq]
        for start in range(0, len(candidate_face_ids), batch_size):
            batch_face_ids = candidate_face_ids[start:start + batch_size]
            batch_dist_sq = point_to_triangle_distance_batch(
                np.broadcast_to(point_prompt, (len(batch_face_ids), 1, 3)),
                face_verts[batch_face_ids],
            )[:, 0]
            batch_best_local = int(np.argmin(batch_dist_sq))
            batch_best_sq = float(batch_dist_sq[batch_best_local])
            if batch_best_sq < best_sq:
                best_sq = batch_best_sq
                best_face_id = int(batch_face_ids[batch_best_local])
        face_ids[prompt_idx] = best_face_id
    return face_ids
def segment_segment_distance_sq_batch(p1, q1, p2, q2):
    """Compute squared distances between batched 3D line segments.

    Pair ``i`` is the segment ``p1[i] -> q1[i]`` against ``p2[i] -> q2[i]``;
    all inputs are ``[K, 3]``. This is the classic clamped closest-point
    algorithm (Eberly / Sunday).

    The "nearly parallel" decision is relative. Since
    ``det = |u|^2 |v|^2 sin^2(theta)``, it is compared with
    ``eps * |u|^2 |v|^2``, so the decision depends only on the angle between
    the segments and not on their length.

    Returns:
        ``[K]`` squared distances.
    """
    eps = 1e-12
    u = q1 - p1
    v = q2 - p2
    w = p1 - p2
    a = np.einsum("ij,ij->i", u, u)
    b = np.einsum("ij,ij->i", u, v)
    c = np.einsum("ij,ij->i", v, v)
    d = np.einsum("ij,ij->i", u, w)
    e = np.einsum("ij,ij->i", v, w)
    det = a * c - b * b
    s_n = np.empty_like(det)
    t_n = np.empty_like(det)
    s_d = np.empty_like(det)
    t_d = np.empty_like(det)
    # An absolute threshold would mark every pair of short segments (length
    # below ~1e-3) as parallel. `<=` also routes zero-length segments
    # (a * c == 0, so det == 0) into the parallel branch.
    parallel_mask = det <= eps * a * c
    non_parallel_mask = ~parallel_mask
    # Parallel / degenerate: fix s = 0 and solve for t on segment 2.
    s_n[parallel_mask] = 0.0
    s_d[parallel_mask] = 1.0
    t_n[parallel_mask] = e[parallel_mask]
    t_d[parallel_mask] = c[parallel_mask]
    # General case: closest points of the infinite lines, as numerators over det.
    s_n[non_parallel_mask] = (
        b[non_parallel_mask] * e[non_parallel_mask]
        - c[non_parallel_mask] * d[non_parallel_mask]
    )
    t_n[non_parallel_mask] = (
        a[non_parallel_mask] * e[non_parallel_mask]
        - b[non_parallel_mask] * d[non_parallel_mask]
    )
    s_d[non_parallel_mask] = det[non_parallel_mask]
    t_d[non_parallel_mask] = det[non_parallel_mask]
    # Clamp s to [0, 1]; when clamped, recompute t for that fixed endpoint.
    mask = non_parallel_mask & (s_n < 0.0)
    s_n[mask] = 0.0
    t_n[mask] = e[mask]
    t_d[mask] = c[mask]
    mask = non_parallel_mask & (s_n > s_d)
    s_n[mask] = s_d[mask]
    t_n[mask] = e[mask] + b[mask]
    t_d[mask] = c[mask]
    # Clamp t to [0, 1]; when clamped, recompute and re-clamp s.
    mask = t_n < 0.0
    t_n[mask] = 0.0
    s_n[mask] = -d[mask]
    s_d[mask] = a[mask]
    mask2 = mask & (s_n < 0.0)
    s_n[mask2] = 0.0
    mask2 = mask & (s_n > s_d)
    s_n[mask2] = s_d[mask2]
    mask = t_n > t_d
    t_n[mask] = t_d[mask]
    s_n[mask] = -d[mask] + b[mask]
    s_d[mask] = a[mask]
    mask2 = mask & (s_n < 0.0)
    s_n[mask2] = 0.0
    mask2 = mask & (s_n > s_d)
    s_n[mask2] = s_d[mask2]
    # Every branch above leaves 0 <= s_n <= s_d and 0 <= t_n <= t_d, so a
    # strictly positive denominator is the only guard the division needs
    # (an absolute floor would zero out sc/tc for short segments).
    sc = np.zeros_like(s_n)
    tc = np.zeros_like(t_n)
    valid_s = s_d > 0.0
    valid_t = t_d > 0.0
    sc[valid_s] = s_n[valid_s] / s_d[valid_s]
    tc[valid_t] = t_n[valid_t] / t_d[valid_t]
    delta = w + sc[:, np.newaxis] * u - tc[:, np.newaxis] * v
    return np.einsum("ij,ij->i", delta, delta)
def segment_intersects_triangle_batch(seg_start, seg_end, tri_verts_batch, eps=1e-10):
    """Test batched segment-triangle intersections (Moller-Trumbore).

    Args:
        seg_start: ``[K, 3]`` segment start points.
        seg_end: ``[K, 3]`` segment end points.
        tri_verts_batch: ``[K, 3, 3]`` triangle vertices; segment ``i`` is
            tested against triangle ``i`` only.
        eps: tolerance. For the parallel test it is relative (the
            determinant is compared with ``eps * |dir| * |e1| * |e2|``). For
            the dimensionless barycentric coordinates ``u, v`` and the
            segment parameter ``t`` it is used directly.

    Returns:
        ``[K]`` boolean mask, True where the segment meets the triangle.
        Segments (near-)parallel to the triangle's plane report False.
    """
    direction = seg_end - seg_start
    v0 = tri_verts_batch[:, 0, :]
    v1 = tri_verts_batch[:, 1, :]
    v2 = tri_verts_batch[:, 2, :]
    edge1 = v1 - v0
    edge2 = v2 - v0
    pvec = np.cross(direction, edge2)
    det = np.einsum("ij,ij->i", edge1, pvec)
    # |det| scales with |direction| * |edge1| * |edge2|. Normalising by that
    # product keeps the parallel test independent of mesh scale; an absolute
    # cut-off rejects every small segment/triangle pair.
    length_scale = (
        np.linalg.norm(direction, axis=1)
        * np.linalg.norm(edge1, axis=1)
        * np.linalg.norm(edge2, axis=1)
    )
    non_parallel = np.abs(det) > eps * length_scale
    inv_det = np.zeros_like(det)
    inv_det[non_parallel] = 1.0 / det[non_parallel]
    tvec = seg_start - v0
    u = np.einsum("ij,ij->i", tvec, pvec) * inv_det
    qvec = np.cross(tvec, edge1)
    v = np.einsum("ij,ij->i", direction, qvec) * inv_det
    t = np.einsum("ij,ij->i", edge2, qvec) * inv_det
    return (
        non_parallel
        & (u >= -eps)
        & (v >= -eps)
        & (u + v <= 1.0 + eps)
        & (t >= -eps)
        & (t <= 1.0 + eps)
    )
def triangle_pairs_within_threshold_batch(tri_a_batch, tri_b_batch, threshold_sq):
    """Flag triangle pairs whose exact separation is below the threshold.

    Pairs are settled by progressively more expensive tests. Each stage only
    re-examines pairs that every earlier stage left unresolved:

    1. vertex-vertex squared distances;
    2. vertex-to-triangle squared distances in both directions;
    3. edge-edge squared distances over all 9 edge pairs;
    4. an edge of either triangle passing through the other triangle.

    Args:
        tri_a_batch, tri_b_batch: ``[K, 3, 3]`` vertex arrays; row ``i`` of
            each forms pair ``i``.
        threshold_sq: squared distance threshold (strict ``<`` comparison).

    Returns:
        ``[K]`` boolean mask of pairs closer than the threshold.
    """
    pair_count = len(tri_a_batch)
    if pair_count == 0:
        return np.zeros(0, dtype=bool)
    tri_edges = ((0, 1), (1, 2), (2, 0))

    # Stage 1: closest vertex-vertex distance.
    closest_vv_sq = np.full(pair_count, np.inf, dtype=tri_a_batch.dtype)
    for corner_a in range(3):
        for corner_b in range(3):
            offset = tri_a_batch[:, corner_a, :] - tri_b_batch[:, corner_b, :]
            closest_vv_sq = np.minimum(
                closest_vv_sq, np.einsum("ij,ij->i", offset, offset)
            )
    within = closest_vv_sq < threshold_sq

    # Stage 2: vertex-to-triangle distances, A onto B and B onto A.
    pending = np.flatnonzero(~within)
    if pending.size == 0:
        return within
    sub_a = tri_a_batch[pending]
    sub_b = tri_b_batch[pending]
    vt_sq = np.minimum(
        np.min(point_to_triangle_distance_batch(sub_a, sub_b), axis=1),
        np.min(point_to_triangle_distance_batch(sub_b, sub_a), axis=1),
    )
    within[pending[vt_sq < threshold_sq]] = True

    # Stage 3: all edge-edge distances.
    pending = np.flatnonzero(~within)
    if pending.size == 0:
        return within
    sub_a = tri_a_batch[pending]
    sub_b = tri_b_batch[pending]
    closest_ee_sq = np.full(len(pending), np.inf, dtype=sub_a.dtype)
    for start_a, end_a in tri_edges:
        for start_b, end_b in tri_edges:
            closest_ee_sq = np.minimum(
                closest_ee_sq,
                segment_segment_distance_sq_batch(
                    sub_a[:, start_a, :],
                    sub_a[:, end_a, :],
                    sub_b[:, start_b, :],
                    sub_b[:, end_b, :],
                ),
            )
    within[pending[closest_ee_sq < threshold_sq]] = True

    # Stage 4: an edge of one triangle piercing the other.
    pending = np.flatnonzero(~within)
    if pending.size == 0:
        return within
    sub_a = tri_a_batch[pending]
    sub_b = tri_b_batch[pending]
    pierced = np.zeros(len(pending), dtype=bool)
    for edge_start, edge_end in tri_edges:
        pierced |= segment_intersects_triangle_batch(
            sub_a[:, edge_start, :], sub_a[:, edge_end, :], sub_b
        )
        pierced |= segment_intersects_triangle_batch(
            sub_b[:, edge_start, :], sub_b[:, edge_end, :], sub_a
        )
    within[pending[pierced]] = True
    return within
def _triangle_pair_distance_sq_batch(tri_a_batch, tri_b_batch):
    """Return the exact squared separation of each triangle pair.

    The minimum is taken over the six vertex-to-triangle distances and the
    nine edge-edge distances. Pairs where an edge of either triangle passes
    through the other triangle are reported as ``0``.

    Args:
        tri_a_batch, tri_b_batch: ``[K, 3, 3]`` vertex arrays; row ``i`` of
            each forms pair ``i``.

    Returns:
        ``[K]`` float64 squared distances.
    """
    pair_count = len(tri_a_batch)
    if pair_count == 0:
        return np.zeros(0, dtype=np.float64)
    tri_edges = ((0, 1), (1, 2), (2, 0))

    vertex_a_to_b = point_to_triangle_distance_batch(tri_a_batch, tri_b_batch)
    vertex_b_to_a = point_to_triangle_distance_batch(tri_b_batch, tri_a_batch)
    best_sq = np.full(pair_count, np.inf, dtype=np.float64)
    best_sq = np.minimum(best_sq, np.min(vertex_a_to_b, axis=1))
    best_sq = np.minimum(best_sq, np.min(vertex_b_to_a, axis=1))

    for start_a, end_a in tri_edges:
        for start_b, end_b in tri_edges:
            best_sq = np.minimum(
                best_sq,
                segment_segment_distance_sq_batch(
                    tri_a_batch[:, start_a, :],
                    tri_a_batch[:, end_a, :],
                    tri_b_batch[:, start_b, :],
                    tri_b_batch[:, end_b, :],
                ),
            )

    # Interpenetrating triangles touch even if no vertex/edge is close.
    crossing = np.zeros(pair_count, dtype=bool)
    for edge_start, edge_end in tri_edges:
        crossing |= segment_intersects_triangle_batch(
            tri_a_batch[:, edge_start, :], tri_a_batch[:, edge_end, :], tri_b_batch
        )
        crossing |= segment_intersects_triangle_batch(
            tri_b_batch[:, edge_start, :], tri_b_batch[:, edge_end, :], tri_a_batch
        )
    best_sq[crossing] = 0.0
    return best_sq
@njit(cache=True, parallel=True)
def generate_candidate_pairs_sweep_numba(
    order,
    mins_a,
    maxs_a,
    mins_b,
    maxs_b,
    upper_bounds,
    distance_threshold,
):
    """Generate exact bbox candidate pairs using the reviewed parallel sweep-line logic.

    Inputs are indexed by *sorted position* ``k`` (faces pre-sorted by their
    bbox minimum along a sweep axis); ``order[k]`` is the original face id at
    position ``k``. ``upper_bounds[k]`` is an exclusive end position: only
    positions ``k+1 .. upper_bounds[k]-1`` are examined for position ``k``,
    so each unordered pair is visited at most once. In this file it comes
    from ``np.searchsorted`` over the sweep-axis extents in
    ``build_face_distance_adjacency``.

    ``mins_a/maxs_a`` and ``mins_b/maxs_b`` are the (sorted) bbox extents on
    the two remaining axes. A pair is dropped when either axis shows a gap
    ``>= distance_threshold``.

    The work is done in two passes. A parallel pass counts survivors per
    position. A serial prefix sum turns those counts into write offsets. A
    second parallel pass repeats the same tests and writes each position's
    pairs into its own disjoint slice of the output.

    Returns:
        ``(num_candidates, 2)`` int64 array of original face ids, each row
        ordered as ``(smaller_id, larger_id)``.
    """
    n_faces = len(order)
    # Pass 1: count bbox-compatible partners for each sorted position.
    counts = np.zeros(n_faces, dtype=np.int64)
    for i in prange(n_faces):
        upper_bound = upper_bounds[i]
        if upper_bound <= i + 1:
            continue
        min_ai = mins_a[i]
        max_ai = maxs_a[i]
        min_bi = mins_b[i]
        max_bi = maxs_b[i]
        local_count = 0
        for j in range(i + 1, upper_bound):
            if min_ai - maxs_a[j] >= distance_threshold:
                continue
            if mins_a[j] - max_ai >= distance_threshold:
                continue
            if min_bi - maxs_b[j] >= distance_threshold:
                continue
            if mins_b[j] - max_bi >= distance_threshold:
                continue
            local_count += 1
        counts[i] = local_count
    # Exclusive prefix sum: offsets[i] is the first output row owned by position i.
    offsets = np.empty(n_faces, dtype=np.int64)
    total_count = 0
    for i in range(n_faces):
        offsets[i] = total_count
        total_count += counts[i]
    candidate_pairs = np.empty((total_count, 2), dtype=np.int64)
    # Pass 2: same tests as pass 1, now writing the surviving pairs.
    for i in prange(n_faces):
        upper_bound = upper_bounds[i]
        if upper_bound <= i + 1:
            continue
        min_ai = mins_a[i]
        max_ai = maxs_a[i]
        min_bi = mins_b[i]
        max_bi = maxs_b[i]
        face_i = order[i]
        out_idx = offsets[i]
        for j in range(i + 1, upper_bound):
            if min_ai - maxs_a[j] >= distance_threshold:
                continue
            if mins_a[j] - max_ai >= distance_threshold:
                continue
            if min_bi - maxs_b[j] >= distance_threshold:
                continue
            if mins_b[j] - max_bi >= distance_threshold:
                continue
            face_j = order[j]
            # Store each pair with the smaller original face id first.
            if face_i < face_j:
                candidate_pairs[out_idx, 0] = face_i
                candidate_pairs[out_idx, 1] = face_j
            else:
                candidate_pairs[out_idx, 0] = face_j
                candidate_pairs[out_idx, 1] = face_i
            out_idx += 1
    return candidate_pairs
def filter_adjacent_pairs_batch(batch_pairs, verts, faces, threshold_sq):
    """Keep only the candidate face pairs that are closer than the threshold.

    A pair is accepted immediately when the two faces share a vertex index.
    The rest are screened with an axis-aligned bounding-box gap, which is a
    lower bound on their distance. Only the survivors are measured exactly
    with ``triangle_pairs_within_threshold_batch``.

    Args:
        batch_pairs: ``[K, 2]`` integer face-index pairs.
        verts: vertex positions indexed by ``faces``.
        faces: ``[F, 3]`` vertex indices per face.
        threshold_sq: squared distance threshold.

    Returns:
        The rows of ``batch_pairs`` that are adjacent. An empty input is
        returned as-is.
    """
    if len(batch_pairs) == 0:
        return batch_pairs
    vids_i = faces[batch_pairs[:, 0]]
    vids_j = faces[batch_pairs[:, 1]]
    # Any vertex index shared between the two faces means they touch.
    keep = (vids_i[:, :, np.newaxis] == vids_j[:, np.newaxis, :]).any(axis=(1, 2))

    unresolved = np.flatnonzero(~keep)
    if unresolved.size:
        corners_i = verts[vids_i[unresolved]]
        corners_j = verts[vids_j[unresolved]]
        lo_i = corners_i.min(axis=1)
        hi_i = corners_i.max(axis=1)
        lo_j = corners_j.min(axis=1)
        hi_j = corners_j.max(axis=1)
        gap = np.maximum(np.maximum(lo_i - hi_j, lo_j - hi_i), 0.0)
        could_touch = np.einsum("ij,ij->i", gap, gap) < threshold_sq
        if could_touch.any():
            exact = triangle_pairs_within_threshold_batch(
                corners_i[could_touch],
                corners_j[could_touch],
                threshold_sq,
            )
            keep[unresolved[could_touch][exact]] = True
    return batch_pairs[keep]
def build_face_edge_adjacency(mesh):
    """Map each face to the set of faces sharing an edge with it.

    Built from ``mesh.face_adjacency`` (pairs of edge-sharing faces) and stored
    symmetrically. Faces without neighbours are absent from the returned
    ``defaultdict(set)``.
    """
    neighbours = defaultdict(set)
    for first, second in np.asarray(mesh.face_adjacency, dtype=np.int64).tolist():
        neighbours[first].add(second)
        neighbours[second].add(first)
    return neighbours
def build_face_distance_adjacency(
    verts,
    faces,
    distance_threshold=DEFAULT_FACE_GROUP_DISTANCE_THRESHOLD,
    max_distance_workers=None,
    component_labels=None,
    cross_component_only=False,
    log_prefix="",
):
    """Build face adjacency based on exact triangle-to-triangle distances.

    Step 1 (candidates): face bounding boxes are swept along whichever axis
    (x, y or z) gives the fewest estimated comparisons. This keeps pairs whose
    boxes are within ``distance_threshold`` on every axis.

    Step 2 (exact): candidates are checked in batches of 200,000 with
    ``filter_adjacent_pairs_batch``. That accepts pairs sharing a vertex index
    or whose exact distance is strictly below the threshold. Batches run on a
    thread pool when more than one worker is allowed.

    Both steps print a timing line to stdout.

    Args:
        verts: vertex positions, indexable by ``faces``.
        faces: ``[F, 3]`` vertex indices per face.
        distance_threshold: adjacency distance threshold (mesh units).
        max_distance_workers: thread count for step 2. ``None`` uses
            ``min(os.cpu_count(), 64)``; values below 1 are clamped to 1.
        component_labels: per-face component ids; required when
            ``cross_component_only`` is True.
        cross_component_only: keep only pairs whose faces have different
            ``component_labels``.
        log_prefix: text prepended to the printed timing lines.

    Returns:
        ``defaultdict(set)`` mapping a face id to the set of adjacent face ids.
        Faces with no adjacency are absent.

    Raises:
        ValueError: if ``cross_component_only`` is True and
            ``component_labels`` is missing or not one entry per face.
    """
    n_faces = len(faces)
    face_verts_all = verts[faces]
    bbox_mins = np.min(face_verts_all, axis=1)
    bbox_maxs = np.max(face_verts_all, axis=1)
    if cross_component_only:
        if component_labels is None:
            raise ValueError("component_labels is required when cross_component_only=True")
        component_labels = np.asarray(component_labels, dtype=np.int32)
        if component_labels.shape != (n_faces,):
            raise ValueError(
                "component_labels must have one entry per face, "
                f"got shape {component_labels.shape} for {n_faces} faces"
            )
    step1_start = time.time()
    # Total number of pairs that could be compared; only used for the
    # "sparsity" figure in the step-1 log line.
    n_pairs_total = n_faces * (n_faces - 1) // 2
    if cross_component_only:
        component_sizes = np.bincount(component_labels.astype(np.int64, copy=False))
        intra_component_pairs = int(
            np.sum(component_sizes * np.maximum(component_sizes - 1, 0) // 2)
        )
        n_pairs_total -= intra_component_pairs
    idx = np.arange(n_faces, dtype=np.int64)
    # Sweep along x: best_upper[k] is the first sorted position whose bbox min
    # is >= (bbox max of position k) + threshold. window[k] is the number of
    # later positions that must be compared against position k.
    best_axis = 0
    best_order = np.argsort(bbox_mins[:, 0], kind="mergesort")
    mins_axis = bbox_mins[best_order, 0]
    maxs_axis = bbox_maxs[best_order, 0]
    best_upper = np.searchsorted(
        mins_axis,
        maxs_axis + distance_threshold,
        side="left",
    ).astype(np.int64, copy=False)
    window = best_upper.copy()
    window -= idx
    window -= 1
    window[window < 0] = 0
    best_estimated_checks = int(np.sum(window))
    # Repeat for y and z; keep the axis with the fewest estimated comparisons.
    for axis in (1, 2):
        order_axis = np.argsort(bbox_mins[:, axis], kind="mergesort")
        mins_axis = bbox_mins[order_axis, axis]
        maxs_axis = bbox_maxs[order_axis, axis]
        upper_axis = np.searchsorted(
            mins_axis,
            maxs_axis + distance_threshold,
            side="left",
        ).astype(np.int64, copy=False)
        window_axis = upper_axis.copy()
        window_axis -= idx
        window_axis -= 1
        window_axis[window_axis < 0] = 0
        estimated_checks = int(np.sum(window_axis))
        if estimated_checks < best_estimated_checks:
            best_estimated_checks = estimated_checks
            best_axis = axis
            best_order = order_axis
            best_upper = upper_axis
    # The two non-sweep axes are filtered inside the numba kernel.
    other_axes = [axis for axis in (0, 1, 2) if axis != best_axis]
    axis_a, axis_b = other_axes
    sorted_mins = bbox_mins[best_order]
    sorted_maxs = bbox_maxs[best_order]
    candidate_pairs = generate_candidate_pairs_sweep_numba(
        best_order,
        sorted_mins[:, axis_a],
        sorted_maxs[:, axis_a],
        sorted_mins[:, axis_b],
        sorted_maxs[:, axis_b],
        best_upper,
        distance_threshold,
    )
    if cross_component_only and len(candidate_pairs) > 0:
        cross_component_mask = (
            component_labels[candidate_pairs[:, 0]] != component_labels[candidate_pairs[:, 1]]
        )
        candidate_pairs = candidate_pairs[cross_component_mask]
    step1_time = time.time() - step1_start
    axis_name = "xyz"[best_axis]
    sparsity = len(candidate_pairs) / n_pairs_total if n_pairs_total > 0 else 0.0
    prefix = f"{log_prefix} " if log_prefix else ""
    print(
        f"{prefix}Step 1 (Exact sparse candidate generation): {step1_time:.4f}s - "
        f"{best_estimated_checks} axis-{axis_name} sweep checks -> "
        f"{len(candidate_pairs):,} candidates ({sparsity:.8f} of all pairs)"
    )
    step2_start = time.time()
    face_adjacency = defaultdict(set)
    batch_size = 200_000
    threshold_sq = distance_threshold * distance_threshold
    if max_distance_workers is None:
        cpu_count = os.cpu_count() or 1
        max_workers = min(cpu_count, 64)
    else:
        max_workers = max(1, int(max_distance_workers))
    accepted_rows = []
    accepted_cols = []
    total_distance_batches = 0
    if max_workers == 1:
        # Serial path: no thread pool overhead.
        for start in range(0, len(candidate_pairs), batch_size):
            end = min(start + batch_size, len(candidate_pairs))
            adjacent_pairs = filter_adjacent_pairs_batch(
                candidate_pairs[start:end],
                verts,
                faces,
                threshold_sq,
            )
            total_distance_batches += 1
            if len(adjacent_pairs) > 0:
                accepted_rows.append(adjacent_pairs[:, 0].astype(np.int32, copy=False))
                accepted_cols.append(adjacent_pairs[:, 1].astype(np.int32, copy=False))
    else:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for start in range(0, len(candidate_pairs), batch_size):
                end = min(start + batch_size, len(candidate_pairs))
                futures.append(
                    executor.submit(
                        filter_adjacent_pairs_batch,
                        candidate_pairs[start:end],
                        verts,
                        faces,
                        threshold_sq,
                    )
                )
                total_distance_batches += 1
            # Results arrive in completion order; the final adjacency is built
            # as sets below, so batch order does not matter.
            for future in as_completed(futures):
                adjacent_pairs = future.result()
                if len(adjacent_pairs) > 0:
                    accepted_rows.append(adjacent_pairs[:, 0].astype(np.int32, copy=False))
                    accepted_cols.append(adjacent_pairs[:, 1].astype(np.int32, copy=False))
    if accepted_rows:
        # Symmetrize the accepted (i, j) pairs into a sparse matrix, then read
        # each non-empty CSR row back out as that face's neighbour set.
        rows = np.concatenate(accepted_rows, axis=0)
        cols = np.concatenate(accepted_cols, axis=0)
        sym_rows = np.concatenate((rows, cols), axis=0)
        sym_cols = np.concatenate((cols, rows), axis=0)
        data = np.ones(len(sym_rows), dtype=np.uint8)
        adjacency_csr = coo_matrix(
            (data, (sym_rows, sym_cols)),
            shape=(n_faces, n_faces),
            dtype=np.uint8,
        ).tocsr()
        indptr = adjacency_csr.indptr
        indices = adjacency_csr.indices
        non_empty_rows = np.flatnonzero(np.diff(indptr))
        for face_i in non_empty_rows:
            start = indptr[face_i]
            end = indptr[face_i + 1]
            face_adjacency[int(face_i)] = set(indices[start:end].tolist())
    step2_time = time.time() - step2_start
    print(
        f"{prefix}Step 2 (Precise distance computation): {step2_time:.4f}s - "
        f"{len(candidate_pairs):,} candidates across {total_distance_batches:,} "
        f"distance batches; {len(face_adjacency)} faces have adjacencies"
    )
    return face_adjacency
def _component_face_ids_from_labels(component_labels, n_components):
    """Bucket face ids by component label using a single stable sort.

    Args:
        component_labels: per-face non-negative integer component label.
        n_components: number of buckets to return.

    Returns:
        List of length ``n_components``. Entry ``c`` is an integer array of the
        faces labelled ``c``, in ascending face order (the sort is stable).
    """
    labels = np.asarray(component_labels, dtype=np.int64)
    bucket_count = int(n_components)
    faces_by_label = np.argsort(labels, kind="mergesort")
    bucket_ends = np.cumsum(np.bincount(labels, minlength=bucket_count), dtype=np.int64)
    bucket_starts = np.concatenate((np.zeros(1, dtype=np.int64), bucket_ends[:-1]))
    return [
        faces_by_label[bucket_starts[c]:bucket_ends[c]]
        for c in range(bucket_count)
    ]
def _component_bounds_from_face_bounds(
    bbox_mins,
    bbox_maxs,
    component_labels,
    n_components,
):
    """Reduce per-face bounding boxes to one bounding box per component.

    Args:
        bbox_mins: per-face bbox minimum corners, one row of 3 values per face.
        bbox_maxs: per-face bbox maximum corners, one row of 3 values per face.
        component_labels: component index of each face.
        n_components: number of components (rows in the output).

    Returns:
        ``(mins, maxs)`` float64 arrays of shape ``(n_components, 3)``. A
        component with no faces keeps ``+inf`` mins and ``-inf`` maxs.
    """
    component_count = int(n_components)
    lower = np.full((component_count, 3), np.inf, dtype=np.float64)
    upper = np.full((component_count, 3), -np.inf, dtype=np.float64)
    for reducer, target, source in (
        (np.minimum, lower, bbox_mins),
        (np.maximum, upper, bbox_maxs),
    ):
        # Unbuffered scatter-reduce so repeated labels accumulate correctly.
        reducer.at(target, component_labels, source)
    return lower, upper
def find_face_groups(faces, face_labels, adjacency=None, verts=None):
    """Split faces into connected groups that share the same label (link ID).

    Two faces land in the same group when a chain of adjacent faces, all
    carrying the same label, links them.

    Args:
        faces: ``[F, 3]`` face vertex indices.
        face_labels: per-face label (link ID).
        adjacency: mapping face -> iterable of neighbouring faces. When
            ``None``, edge adjacency is derived from ``verts``/``faces`` via trimesh.
        verts: vertex positions; required only when ``adjacency`` is ``None``.

    Returns:
        List of ``(link_id, face_ids)`` tuples, ordered by each group's lowest
        face index; ``face_ids`` lists faces in breadth-first discovery order.

    Raises:
        ValueError: if both ``adjacency`` and ``verts`` are ``None``.
    """
    face_count = len(faces)
    if adjacency is None:
        if verts is None:
            raise ValueError("verts must be provided if adjacency is None")
        mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
        adjacency = defaultdict(set)
        for face_i, face_j in mesh.face_adjacency:
            adjacency[int(face_i)].add(int(face_j))
            adjacency[int(face_j)].add(int(face_i))

    seen = np.zeros(face_count, dtype=bool)
    groups = []
    for seed in range(face_count):
        if seen[seed]:
            continue
        label = int(face_labels[seed])
        seen[seed] = True
        frontier = deque((seed,))
        members = []
        while frontier:
            current = frontier.popleft()
            members.append(current)
            for neighbour in adjacency.get(current, []):
                if seen[neighbour] or int(face_labels[neighbour]) != label:
                    continue
                seen[neighbour] = True
                frontier.append(neighbour)
        groups.append((label, members))
    return groups
def find_connected_components_fast(n_nodes, adjacency):
    """Label the connected components of an undirected sparse graph.

    Args:
        n_nodes: number of nodes; only keys ``0 .. n_nodes-1`` of
            ``adjacency`` are read.
        adjacency: mapping node -> iterable of neighbour nodes. Edges are
            treated as undirected, so the map need not be symmetric.

    Returns:
        ``(n_components, labels)`` where ``labels`` is an int32 array. With
        no edges at all, every node is its own component.
    """
    edges = [
        (node, neighbor)
        for node in range(n_nodes)
        for neighbor in adjacency.get(node, set())
    ]
    if not edges:
        return n_nodes, np.arange(n_nodes, dtype=np.int32)
    sources, targets = zip(*edges)
    graph = csr_matrix(
        (np.ones(len(edges), dtype=bool), (sources, targets)),
        shape=(n_nodes, n_nodes),
    )
    component_count, node_labels = connected_components(graph, directed=False)
    return int(component_count), node_labels.astype(np.int32, copy=False)
def _copy_face_adjacency(adjacency):
    """Deep-copy a face adjacency mapping into a plain ``dict[int, set[int]]``.

    Keys and neighbour ids are converted with ``int``; every neighbour
    collection becomes a fresh ``set``, so the input is never shared.
    """
    copied = {}
    for face_id, neighbors in adjacency.items():
        copied[int(face_id)] = {int(neighbor) for neighbor in neighbors}
    return copied
def _merge_face_adjacency(base_adjacency, extra_adjacency):
    """Return a new adjacency map that is the union of two undirected face maps.

    ``base_adjacency`` is copied via ``_copy_face_adjacency`` first, so neither
    input is mutated. Keys and neighbour ids from ``extra_adjacency`` are
    converted with ``int`` before merging.
    """
    merged = _copy_face_adjacency(base_adjacency)
    for face_id, neighbors in extra_adjacency.items():
        bucket = merged.setdefault(int(face_id), set())
        bucket.update(int(neighbor) for neighbor in neighbors)
    return merged
def _closest_face_pair_between_components(
    face_component_a,
    face_component_b,
    *,
    face_verts,
    bbox_mins,
    bbox_maxs,
    face_centroids,
    upper_bound_sq=np.inf,
):
    """Find the exact closest face pair across two disconnected components.

    The search is seeded with the face pair whose centroids are nearest (via a
    KD-tree over component A's centroids) and then refined by a blocked
    branch-and-bound. Each tile of face pairs is pruned by the squared
    distance between the two faces' axis-aligned bounding boxes. Only pairs
    whose bbox lower bound beats the current best are passed, in ascending
    lower-bound order, to ``_triangle_pair_distance_sq_batch`` for an exact
    squared distance.

    Args:
        face_component_a: face ids of the first component (must be non-empty).
        face_component_b: face ids of the second component (must be non-empty).
        face_verts: per-face vertex coordinates indexed by face id. Presumably
            shape (num_faces, 3, 3); TODO confirm against callers.
        bbox_mins: per-face AABB minimum corner, indexed by face id.
        bbox_maxs: per-face AABB maximum corner, indexed by face id.
        face_centroids: per-face centroid, indexed by face id.
        upper_bound_sq: optional squared-distance pruning bound. Only pairs
            strictly closer than ``min(upper_bound_sq, seed distance)`` can
            replace the seed.

    Returns:
        ``((face_a, face_b), distance_sq)``. ``face_a`` comes from
        ``face_component_a`` and ``face_b`` from ``face_component_b``,
        regardless of which component was internally treated as the smaller
        one. The search stops early once a distance of zero is found.

    Raises:
        ValueError: if either component has no faces.

    NOTE(review): if a finite ``upper_bound_sq`` is smaller than the seed
    pair's exact distance and no pair beats it, the returned ``distance_sq``
    equals ``upper_bound_sq``. The returned pair, however, is still the
    centroid seed, so pair and distance do not correspond in that case.
    """
    face_component_a = np.asarray(face_component_a, dtype=np.int64)
    face_component_b = np.asarray(face_component_b, dtype=np.int64)
    if face_component_a.size == 0 or face_component_b.size == 0:
        raise ValueError("component face lists must be non-empty")
    if face_component_a.size > face_component_b.size:
        # Normalize so component A is the smaller one (the KD-tree and the
        # outer tile loop use A), then flip the pair back into argument order.
        best_pair, best_sq = _closest_face_pair_between_components(
            face_component_b,
            face_component_a,
            face_verts=face_verts,
            bbox_mins=bbox_mins,
            bbox_maxs=bbox_maxs,
            face_centroids=face_centroids,
            upper_bound_sq=upper_bound_sq,
        )
        return best_pair[::-1], best_sq
    # Seed: for every B face, find its nearest A face by centroid, and keep the
    # single closest centroid pair. Its exact triangle distance is a valid
    # upper bound on the true minimum.
    centroid_tree = cKDTree(face_centroids[face_component_a])
    centroid_distances, nearest_local = _query_nearest(
        centroid_tree,
        face_centroids[face_component_b],
    )
    centroid_distances = np.atleast_1d(centroid_distances)
    nearest_local = np.atleast_1d(nearest_local)
    best_b_local = int(np.argmin(centroid_distances))
    best_face_a = int(face_component_a[nearest_local[best_b_local]])
    best_face_b = int(face_component_b[best_b_local])
    best_sq = min(
        float(upper_bound_sq),
        float(
            _triangle_pair_distance_sq_batch(
                face_verts[[best_face_a]],
                face_verts[[best_face_b]],
            )[0]
        ),
    )
    best_pair = (best_face_a, best_face_b)
    if best_sq <= 0.0:
        # Touching/intersecting faces: nothing can be closer.
        return best_pair, best_sq
    # Tile size for the bbox lower-bound matrix (block_size x block_size x 3)
    # and the maximum number of exact triangle-pair distances per batch call.
    block_size = 128
    exact_batch_size = 8_192
    for start_a in range(0, len(face_component_a), block_size):
        face_ids_a = face_component_a[start_a:start_a + block_size]
        mins_a = bbox_mins[face_ids_a]
        maxs_a = bbox_maxs[face_ids_a]
        for start_b in range(0, len(face_component_b), block_size):
            face_ids_b = face_component_b[start_b:start_b + block_size]
            mins_b = bbox_mins[face_ids_b]
            maxs_b = bbox_maxs[face_ids_b]
            # Per-axis gap between every A bbox and every B bbox in the tile;
            # negative gaps (overlap on that axis) are clamped to zero. The
            # squared gap norm lower-bounds the triangle-pair squared distance.
            axis_gap = np.maximum(
                mins_a[:, np.newaxis, :] - maxs_b[np.newaxis, :, :],
                mins_b[np.newaxis, :, :] - maxs_a[:, np.newaxis, :],
            )
            axis_gap = np.maximum(axis_gap, 0.0)
            lower_bound_sq = np.einsum("ijk,ijk->ij", axis_gap, axis_gap)
            candidate_mask = lower_bound_sq < best_sq
            if not np.any(candidate_mask):
                continue
            # Evaluate surviving pairs in ascending lower-bound order (stable
            # sort) so the best distance tightens as early as possible.
            candidate_a_local, candidate_b_local = np.nonzero(candidate_mask)
            candidate_lower_bounds = lower_bound_sq[candidate_a_local, candidate_b_local]
            candidate_order = np.argsort(candidate_lower_bounds, kind="mergesort")
            candidate_a = face_ids_a[candidate_a_local[candidate_order]]
            candidate_b = face_ids_b[candidate_b_local[candidate_order]]
            candidate_lower_bounds = candidate_lower_bounds[candidate_order]
            for batch_start in range(0, len(candidate_a), exact_batch_size):
                batch_end = min(batch_start + exact_batch_size, len(candidate_a))
                # Bounds are sorted: once the first remaining bound cannot beat
                # the current best, no later candidate in this tile can either.
                if candidate_lower_bounds[batch_start] >= best_sq:
                    break
                batch_face_ids_a = candidate_a[batch_start:batch_end]
                batch_face_ids_b = candidate_b[batch_start:batch_end]
                batch_dist_sq = _triangle_pair_distance_sq_batch(
                    face_verts[batch_face_ids_a],
                    face_verts[batch_face_ids_b],
                )
                batch_best_local = int(np.argmin(batch_dist_sq))
                batch_best_sq = float(batch_dist_sq[batch_best_local])
                if batch_best_sq < best_sq:
                    best_sq = batch_best_sq
                    best_pair = (
                        int(batch_face_ids_a[batch_best_local]),
                        int(batch_face_ids_b[batch_best_local]),
                    )
                    if best_sq <= 0.0:
                        return best_pair, best_sq
    return best_pair, best_sq
class _UnionFind:
def __init__(self, n_items):
self.parent = np.arange(int(n_items), dtype=np.int32)
self.rank = np.zeros((int(n_items),), dtype=np.int8)
self.n_sets = int(n_items)
def find(self, item):
item = int(item)
parent = self.parent
while int(parent[item]) != item:
parent[item] = parent[int(parent[item])]
item = int(parent[item])
return item
def union(self, item_a, item_b):
root_a = self.find(item_a)
root_b = self.find(item_b)
if root_a == root_b:
return False
if self.rank[root_a] < self.rank[root_b]:
root_a, root_b = root_b, root_a
self.parent[root_b] = root_a
if self.rank[root_a] == self.rank[root_b]:
self.rank[root_a] += 1
self.n_sets -= 1
return True
def _component_bbox_pair_lower_bound_sq(
component_bbox_mins,
component_bbox_maxs,
component_indices_a,
component_indices_b,
):
"""Return exact AABB lower-bound distances for component bbox pairs."""
mins_a = component_bbox_mins[component_indices_a]
maxs_a = component_bbox_maxs[component_indices_a]
mins_b = component_bbox_mins[component_indices_b]
maxs_b = component_bbox_maxs[component_indices_b]
axis_gap = np.maximum(mins_a - maxs_b, mins_b - maxs_a)
axis_gap = np.maximum(axis_gap, 0.0)
return np.einsum("ij,ij->i", axis_gap, axis_gap)
def _sorted_component_bbox_lower_bound_pairs(
    component_bbox_mins,
    component_bbox_maxs,
):
    """Enumerate every unordered component pair, sorted by bbox lower bound.

    Returns ``(rows, cols, lower_bound_sq)``. ``rows``/``cols`` are int32
    arrays with ``rows < cols``. The order is ascending by ``lower_bound_sq``,
    with ties broken by ``rows`` and then ``cols``. Pairs are filled one row
    at a time so only a single row's worth of temporaries is alive at once.
    """
    count = len(component_bbox_mins)
    total_pairs = count * (count - 1) // 2
    rows_out = np.empty((total_pairs,), dtype=np.int32)
    cols_out = np.empty((total_pairs,), dtype=np.int32)
    bounds_out = np.empty((total_pairs,), dtype=np.float64)
    offset = 0
    for row in range(count - 1):
        partners = np.arange(row + 1, count, dtype=np.int32)
        stop = offset + len(partners)
        rows_out[offset:stop] = row
        cols_out[offset:stop] = partners
        bounds_out[offset:stop] = _component_bbox_pair_lower_bound_sq(
            component_bbox_mins,
            component_bbox_maxs,
            np.full(partners.shape, row, dtype=np.int32),
            partners,
        )
        offset = stop
    # np.lexsort treats its LAST key as primary.
    order = np.lexsort((cols_out, rows_out, bounds_out))
    return rows_out[order], cols_out[order], bounds_out[order]
def _exact_component_bridge_edges(
    component_face_ids,
    *,
    face_verts,
    bbox_mins,
    bbox_maxs,
    face_centroids,
    component_bbox_mins,
    component_bbox_maxs,
):
    """Return exact MST face-pair bridges over disconnected face components.

    Components are connected by the exact closest triangle-to-triangle face pair.
    AABB distances are used only as lower bounds for lazy Kruskal ordering.

    All component pairs are pre-sorted by their component-bbox lower bound.
    Pairs are turned into exact edges (via
    ``_closest_face_pair_between_components``) only while their lower bound is
    below the smallest exact edge already on the heap. The heap minimum is
    therefore never larger than any unevaluated pair's exact distance, so
    popping it is a valid Kruskal step.

    Args:
        component_face_ids: sequence of per-component face-id arrays.
        face_verts, bbox_mins, bbox_maxs, face_centroids: per-face data
            forwarded to ``_closest_face_pair_between_components``.
        component_bbox_mins, component_bbox_maxs: per-component AABB corners,
            indexed by component index.

    Returns:
        A list of ``(face_i, face_j)`` int tuples, one per successful union,
        i.e. ``len(component_face_ids) - 1`` bridges when more than one
        component is given, or ``[]`` otherwise.

    Raises:
        RuntimeError: if sets remain disconnected after every pair has been
            consumed.

    NOTE(review): memory is O(n_components**2) because every component pair's
    lower bound is materialized up front.
    """
    n_components = len(component_face_ids)
    if n_components <= 1:
        return []
    lower_bound_start = time.time()
    pair_rows, pair_cols, pair_lower_bound_sq = _sorted_component_bbox_lower_bound_pairs(
        component_bbox_mins,
        component_bbox_maxs,
    )
    lower_bound_time = time.time() - lower_bound_start
    union_find = _UnionFind(n_components)
    bridge_edges = []
    # Min-heap of (exact_distance_sq, comp_a, comp_b, face_i, face_j).
    exact_edge_heap = []
    pair_cursor = 0
    # Counters/timers below feed only the summary print.
    exact_evaluations = 0
    skipped_same_set_lower_bound_pairs = 0
    discarded_same_set_exact_edges = 0
    exact_eval_time = 0.0
    n_pairs = len(pair_rows)
    while union_find.n_sets > 1:
        # Evaluate exact distances only for pairs whose lower bound could still
        # undercut the best exact edge currently known.
        while (
            pair_cursor < n_pairs
            and (
                not exact_edge_heap
                or float(pair_lower_bound_sq[pair_cursor]) < float(exact_edge_heap[0][0])
            )
        ):
            component_idx_a = int(pair_rows[pair_cursor])
            component_idx_b = int(pair_cols[pair_cursor])
            pair_cursor += 1
            if union_find.find(component_idx_a) == union_find.find(component_idx_b):
                # Already connected; the exact (expensive) distance is not needed.
                skipped_same_set_lower_bound_pairs += 1
                continue
            exact_start = time.time()
            face_pair, distance_sq = _closest_face_pair_between_components(
                component_face_ids[component_idx_a],
                component_face_ids[component_idx_b],
                face_verts=face_verts,
                bbox_mins=bbox_mins,
                bbox_maxs=bbox_maxs,
                face_centroids=face_centroids,
            )
            exact_eval_time += time.time() - exact_start
            exact_evaluations += 1
            heapq.heappush(
                exact_edge_heap,
                (
                    float(distance_sq),
                    component_idx_a,
                    component_idx_b,
                    int(face_pair[0]),
                    int(face_pair[1]),
                ),
            )
        if not exact_edge_heap:
            raise RuntimeError("failed to connect face-distance adjacency components")
        _, component_idx_a, component_idx_b, face_i, face_j = heapq.heappop(
            exact_edge_heap
        )
        # Edges evaluated earlier may now join components already merged by a
        # later union; those are dropped here.
        if union_find.union(component_idx_a, component_idx_b):
            bridge_edges.append((face_i, face_j))
        else:
            discarded_same_set_exact_edges += 1
    print(
        "Exact component bridge search: "
        f"{lower_bound_time:.4f}s bbox lower bounds for {n_pairs:,} component pairs; "
        f"{exact_evaluations:,} exact component-pair evaluations in {exact_eval_time:.4f}s; "
        f"accepted {len(bridge_edges):,} bridges; "
        f"skipped {skipped_same_set_lower_bound_pairs:,} same-set lower-bound pairs; "
        f"discarded {discarded_same_set_exact_edges:,} same-set exact edges"
    )
    return bridge_edges
def ensure_face_adjacency_is_connected(mesh, face_adjacency):
    """Return a copy of ``face_adjacency`` with bridges added between its components.

    The input mapping is never modified. If the copied graph already has a
    single connected component it is returned unchanged. Otherwise the face
    pairs returned by ``_exact_component_bridge_edges`` are inserted as
    undirected edges, and timing/progress lines are printed.
    """
    adjacency = _copy_face_adjacency(face_adjacency)
    n_components, component_labels = find_connected_components_fast(
        len(mesh.faces),
        adjacency,
    )
    if n_components > 1:
        started = time.time()
        print(
            "Face connectivity graph has "
            f"{n_components} connected components after threshold links; "
            "adding nearest component bridges"
        )
        vertices = np.asarray(mesh.vertices, dtype=np.float64)
        face_vertex_ids = np.asarray(mesh.faces, dtype=np.int64)
        triangle_coords = vertices[face_vertex_ids]
        face_lows = triangle_coords.min(axis=1)
        face_highs = triangle_coords.max(axis=1)
        centroids = np.asarray(mesh.triangles_center, dtype=np.float64)
        faces_per_component = _component_face_ids_from_labels(
            component_labels, n_components
        )
        component_lows, component_highs = _component_bounds_from_face_bounds(
            face_lows,
            face_highs,
            component_labels,
            n_components,
        )
        bridges = _exact_component_bridge_edges(
            faces_per_component,
            face_verts=triangle_coords,
            bbox_mins=face_lows,
            bbox_maxs=face_highs,
            face_centroids=centroids,
            component_bbox_mins=component_lows,
            component_bbox_maxs=component_highs,
        )
        for face_i, face_j in bridges:
            adjacency.setdefault(face_i, set()).add(face_j)
            adjacency.setdefault(face_j, set()).add(face_i)
        print(
            "Component connectivity bridge step: "
            f"{time.time() - started:.4f}s - added {len(bridges):,} "
            "nearest component bridges"
        )
    return adjacency
def build_face_connectivity_adjacency_for_inference(
    mesh,
    *,
    distance_threshold=DEFAULT_FACE_GROUP_DISTANCE_THRESHOLD,
    max_distance_workers=None,
):
    """Return shared-edge face adjacency, augmented with cross-component distance links.

    The shared-edge adjacency comes from ``build_face_edge_adjacency``. If it
    already forms one connected component it is returned as-is. Otherwise
    ``build_face_distance_adjacency`` is asked for links only between faces in
    different components (within ``distance_threshold``), and the two maps are
    merged into a new dict.
    """
    edge_adjacency = build_face_edge_adjacency(mesh)
    face_count = len(mesh.faces)
    n_components, component_labels = find_connected_components_fast(
        face_count,
        edge_adjacency,
    )
    if n_components > 1:
        print(
            "Base face-edge adjacency has "
            f"{n_components} connected components across {face_count} faces; "
            "adding face-level cross-component distance links"
        )
        distance_links = build_face_distance_adjacency(
            np.asarray(mesh.vertices, dtype=np.float64),
            np.asarray(mesh.faces, dtype=np.int64),
            distance_threshold=float(distance_threshold),
            component_labels=component_labels,
            cross_component_only=True,
            max_distance_workers=max_distance_workers,
            log_prefix="Cross-component",
        )
        return _merge_face_adjacency(edge_adjacency, distance_links)
    return edge_adjacency
def _compute_face_probability_statistics(num_faces, point_part_probabilities, face_indices):
"""Aggregate point softmax probabilities onto faces."""
point_part_probabilities = np.asarray(point_part_probabilities, dtype=np.float32)
face_indices = np.asarray(face_indices, dtype=np.int64)
if point_part_probabilities.ndim != 2:
raise ValueError(
"point_part_probabilities must have shape (num_points, num_parts), "
f"got {point_part_probabilities.shape}"
)
if face_indices.shape != (point_part_probabilities.shape[0],):
raise ValueError(
"face_indices must have one entry per point, "
f"got {face_indices.shape} for {point_part_probabilities.shape[0]} points"
)
num_parts = point_part_probabilities.shape[1]
face_probability_sums = np.zeros((num_faces, num_parts), dtype=np.float64)
face_probability_counts = np.zeros(num_faces, dtype=np.int64)
valid_mask = face_indices >= 0
if np.any(valid_mask):
valid_face_indices = face_indices[valid_mask]
np.add.at(
face_probability_sums,
valid_face_indices,
point_part_probabilities[valid_mask],
)
np.add.at(
face_probability_counts,
valid_face_indices,
np.ones(valid_face_indices.shape[0], dtype=np.int64),
)
return face_probability_sums, face_probability_counts
def _build_filled_face_probability_means(mesh, face_probability_sums, face_probability_counts):
    """Return per-face mean probabilities, filling unsampled faces from the nearest sampled one.

    Faces with a positive count get ``sums / count``. Faces with no samples
    copy the mean of the sampled face whose centroid is nearest. If no face
    was sampled at all, every entry is ``1 / max(num_parts, 1)``. The result
    is float32 with the same shape as ``face_probability_sums``.
    """
    num_faces, num_parts = face_probability_sums.shape
    means = np.zeros((num_faces, num_parts), dtype=np.float32)
    sampled = face_probability_counts > 0
    if not np.any(sampled):
        # Nothing to propagate from: fall back to a uniform distribution.
        means.fill(1.0 / max(num_parts, 1))
        return means

    sampled_faces = np.flatnonzero(sampled)
    per_face_mean = (
        face_probability_sums[sampled_faces]
        / face_probability_counts[sampled_faces, np.newaxis]
    )
    means[sampled_faces] = per_face_mean.astype(np.float32, copy=False)

    unsampled_faces = np.flatnonzero(~sampled)
    if unsampled_faces.size > 0:
        centroids = _get_face_centroids(mesh)
        sampled_tree = cKDTree(centroids[sampled_faces])
        _, nearest_idx = _query_nearest(sampled_tree, centroids[unsampled_faces])
        donor_faces = sampled_faces[np.atleast_1d(nearest_idx)]
        means[unsampled_faces] = means[donor_faces]
    return means
def _group_confidence_vector(
group_face_ids,
face_probability_sums,
face_probability_counts,
filled_face_probability_means,
):
"""Aggregate point softmax probabilities for one face group."""
group_face_ids = np.asarray(group_face_ids, dtype=np.int64)
group_point_count = int(face_probability_counts[group_face_ids].sum())
if group_point_count > 0:
return (
face_probability_sums[group_face_ids].sum(axis=0)
/ float(group_point_count)
).astype(np.float32, copy=False)
return filled_face_probability_means[group_face_ids].mean(axis=0).astype(
np.float32,
copy=False,
)
def _adjacent_group_indices_by_part_id_for_group(groups, face_adjacency, face_to_group, group_idx):
adjacent_groups_by_part_id = defaultdict(set)
for face_id in groups[group_idx][1]:
for adjacent_face_id in face_adjacency.get(face_id, set()):
adjacent_group_idx = int(face_to_group[int(adjacent_face_id)])
if adjacent_group_idx < 0 or adjacent_group_idx == group_idx:
continue
adjacent_part_id = int(groups[adjacent_group_idx][0])
if adjacent_part_id >= 0:
adjacent_groups_by_part_id[adjacent_part_id].add(adjacent_group_idx)
return adjacent_groups_by_part_id
def _iterative_single_group_reassignment(
    face_part_ids,
    *,
    face_adjacency,
    input_part_ids,
    face_probability_sums,
    face_probability_counts,
    filled_face_probability_means,
):
    """Iteratively enforce one face group per part ID.

    Each iteration works in four steps:

    1. Recompute connected same-label face groups under ``face_adjacency``.
    2. For every non-negative part id with more than one group, keep the
       largest group. Ties are broken by that group's confidence for the
       part, then by the lower group index.
    3. Relabel each other group of that part to the highest-confidence
       candidate among:
       - "safe" adjacent part ids: a part that has exactly one group, or a
         duplicated part whose kept group touches this group;
       - input part ids that currently label no face, each handed out at
         most once per iteration.
    4. Stop when no duplicates remain, when no update is possible, or when
       a labeling repeats.

    Args:
        face_part_ids: per-face part labels; ``-1`` groups are left untouched.
        face_adjacency: mapping ``face -> neighbor faces`` used for grouping.
        input_part_ids: part ids that are expected to appear.
        face_probability_sums, face_probability_counts,
        filled_face_probability_means: per-face statistics used to score
            groups via ``_group_confidence_vector``.

    Returns:
        A new int32 label array. It may still contain duplicated parts if no
        safe replacement exists or a repeated state was detected.

    Raises:
        RuntimeError: if ``max_iterations`` pass without reaching a stopping
            condition.
    """
    input_part_ids = np.asarray(input_part_ids, dtype=np.int64)
    face_part_ids = np.asarray(face_part_ids, dtype=np.int32).copy()
    max_iterations = max(8, 4 * max(1, int(input_part_ids.size)))
    # Labelings already visited; revisiting one means we are cycling.
    seen_states = set()
    for _ in range(max_iterations):
        state_key = face_part_ids.tobytes()
        if state_key in seen_states:
            return face_part_ids
        seen_states.add(state_key)
        # Only ``adjacency`` is used for grouping here, so the faces argument
        # is a placeholder array of the right length.
        groups = find_face_groups(
            np.empty((face_part_ids.shape[0], 3), dtype=np.int64),
            face_part_ids,
            adjacency=face_adjacency,
        )
        groups_by_part_id = defaultdict(list)
        group_confidences = []
        group_sizes = []
        for group_idx, (part_id, group_face_ids) in enumerate(groups):
            groups_by_part_id[int(part_id)].append(group_idx)
            group_face_ids = np.asarray(group_face_ids, dtype=np.int64)
            group_confidences.append(
                _group_confidence_vector(
                    group_face_ids,
                    face_probability_sums,
                    face_probability_counts,
                    filled_face_probability_means,
                )
            )
            group_sizes.append(len(group_face_ids))
        duplicate_part_ids = [
            int(part_id)
            for part_id, group_indices in groups_by_part_id.items()
            if len(group_indices) > 1 and part_id >= 0
        ]
        if not duplicate_part_ids:
            return face_part_ids
        existing_part_ids = {
            int(part_id)
            for part_id in np.unique(face_part_ids)
            if int(part_id) >= 0
        }
        # Expected parts that currently own no faces at all.
        missing_part_ids = sorted(
            int(part_id) for part_id in input_part_ids if int(part_id) not in existing_part_ids
        )
        updates = {}
        groups_to_keep = {}
        for part_id in duplicate_part_ids:
            # Keep the largest group; tie-break by confidence, then lowest index.
            groups_to_keep[part_id] = max(
                groups_by_part_id[part_id],
                key=lambda group_idx: (
                    group_sizes[group_idx],
                    float(group_confidences[group_idx][part_id]),
                    -group_idx,
                ),
            )
        face_to_group = np.full(face_part_ids.shape[0], -1, dtype=np.int32)
        for group_idx, (_, group_face_ids) in enumerate(groups):
            face_to_group[np.asarray(group_face_ids, dtype=np.int64)] = int(group_idx)
        available_missing_part_ids = set(missing_part_ids)
        for part_id in duplicate_part_ids:
            group_to_keep = groups_to_keep[part_id]
            for group_to_update in groups_by_part_id[part_id]:
                if group_to_update == group_to_keep:
                    continue
                adjacent_groups_by_part_id = _adjacent_group_indices_by_part_id_for_group(
                    groups,
                    face_adjacency,
                    face_to_group,
                    group_to_update,
                )
                # An adjacent part is "safe" when relabeling this group to it
                # joins a group that will survive: either the part's only
                # group, or the group chosen to be kept for a duplicated part.
                safe_adjacent_part_ids = set()
                for adjacent_part_id, adjacent_group_indices in adjacent_groups_by_part_id.items():
                    if adjacent_part_id == part_id:
                        continue
                    target_group_indices = groups_by_part_id.get(adjacent_part_id, [])
                    if len(target_group_indices) == 1:
                        safe_adjacent_part_ids.add(adjacent_part_id)
                        continue
                    target_group_to_keep = groups_to_keep.get(adjacent_part_id)
                    if (
                        target_group_to_keep is not None
                        and target_group_to_keep in adjacent_group_indices
                    ):
                        safe_adjacent_part_ids.add(adjacent_part_id)
                replacement_candidates = sorted(
                    safe_adjacent_part_ids | available_missing_part_ids
                )
                if not replacement_candidates:
                    continue
                confidence_vector = group_confidences[group_to_update]
                # Highest confidence wins; ties go to the smaller part id.
                best_replacement_part_id = max(
                    replacement_candidates,
                    key=lambda candidate_part_id: (
                        float(confidence_vector[candidate_part_id]),
                        -int(candidate_part_id),
                    ),
                )
                updates[group_to_update] = int(best_replacement_part_id)
                # A missing part may be claimed by only one group per iteration.
                available_missing_part_ids.discard(int(best_replacement_part_id))
        if not updates:
            return face_part_ids
        updated_face_part_ids = face_part_ids.copy()
        for group_idx, replacement_part_id in updates.items():
            updated_face_part_ids[np.asarray(groups[group_idx][1], dtype=np.int64)] = replacement_part_id
        if np.array_equal(updated_face_part_ids, face_part_ids):
            return face_part_ids
        face_part_ids = updated_face_part_ids
    raise RuntimeError("single-group face post-processing did not converge")
def refine_face_part_ids_for_inference(
    mesh,
    face_part_ids,
    *,
    point_part_probabilities=None,
    face_indices=None,
    input_part_ids=None,
    strict=False,
    enforce_connectivity_per_part=False,
    distance_threshold=DEFAULT_FACE_GROUP_DISTANCE_THRESHOLD,
    max_distance_workers=None,
):
    """Inference-time face post-processing layered on top of the base face pass.

    Always runs ``refine_face_part_ids``. When ``enforce_connectivity_per_part``
    is true, it additionally:
      1. builds a face connectivity graph (shared edges plus cross-component
         links within ``distance_threshold``);
      2. bridges any remaining disconnected pieces;
      3. iteratively relabels extra groups so each part id owns at most one
         connected face group, scored by point probabilities aggregated per
         face.

    Args:
        mesh: mesh providing ``faces``/``vertices`` (as used by the helpers).
        face_part_ids: per-face part labels.
        point_part_probabilities: (num_points, num_parts) probabilities;
            required when enforcing connectivity.
        face_indices: per-point face ids; required when enforcing connectivity.
        input_part_ids: expected part ids, each in ``[0, num_parts)``;
            required when enforcing connectivity.
        strict: forwarded to ``refine_face_part_ids``.
        enforce_connectivity_per_part: enable the one-group-per-part pass.
        distance_threshold: proximity threshold for cross-component links.
        max_distance_workers: forwarded to
            ``build_face_connectivity_adjacency_for_inference`` (default
            ``None`` keeps its previous behavior).

    Returns:
        The base-refined labels when connectivity is not enforced; otherwise
        an int32 label array from the single-group reassignment.

    Raises:
        ValueError: if a required input is missing or malformed when
            ``enforce_connectivity_per_part`` is enabled.
    """
    face_part_ids = np.asarray(face_part_ids, dtype=np.int32)
    if not enforce_connectivity_per_part:
        return refine_face_part_ids(
            mesh,
            face_part_ids,
            strict=bool(strict),
        )
    # Validate the connectivity inputs before doing any mesh work so that
    # bad arguments fail fast.
    if point_part_probabilities is None:
        raise ValueError(
            "point_part_probabilities is required when enforce_connectivity_per_part is enabled"
        )
    if face_indices is None:
        raise ValueError("face_indices is required when enforce_connectivity_per_part is enabled")
    if input_part_ids is None:
        raise ValueError("input_part_ids is required when enforce_connectivity_per_part is enabled")
    point_part_probabilities = np.asarray(point_part_probabilities, dtype=np.float32)
    face_indices = np.asarray(face_indices, dtype=np.int64)
    input_part_ids = np.asarray(input_part_ids, dtype=np.int64)
    if point_part_probabilities.ndim != 2:
        raise ValueError("point_part_probabilities must have shape [num_points, num_parts]")
    if point_part_probabilities.shape[1] <= 0:
        raise ValueError("point_part_probabilities must contain at least one part column")
    if np.any(input_part_ids < 0):
        raise ValueError("input_part_ids must be non-negative")
    if np.any(input_part_ids >= point_part_probabilities.shape[1]):
        raise ValueError(
            "input_part_ids must be within the probability columns, "
            f"got max {int(input_part_ids.max())} for {point_part_probabilities.shape[1]} columns"
        )
    base_face_part_ids = refine_face_part_ids(
        mesh,
        face_part_ids,
        strict=bool(strict),
    )
    face_connectivity_adjacency = build_face_connectivity_adjacency_for_inference(
        mesh,
        distance_threshold=float(distance_threshold),
        max_distance_workers=max_distance_workers,
    )
    face_connectivity_adjacency = ensure_face_adjacency_is_connected(
        mesh,
        face_connectivity_adjacency,
    )
    face_probability_sums, face_probability_counts = _compute_face_probability_statistics(
        len(mesh.faces),
        point_part_probabilities,
        face_indices,
    )
    filled_face_probability_means = _build_filled_face_probability_means(
        mesh,
        face_probability_sums,
        face_probability_counts,
    )
    return _iterative_single_group_reassignment(
        base_face_part_ids,
        face_adjacency=face_connectivity_adjacency,
        input_part_ids=input_part_ids,
        face_probability_sums=face_probability_sums,
        face_probability_counts=face_probability_counts,
        filled_face_probability_means=filled_face_probability_means,
    )