rayli's picture
Cleanup demo code paths
2f3ab6d verified
Raw
History Blame Contribute Delete
61.5 kB
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn.functional as F
from instruct_particulate.model import Particulate2ArticulationModel
from instruct_particulate.utils.data_utils import (
load_trimesh,
normalize_mesh,
up_dir_rotation_matrix,
reorient_mesh_to_z_up,
)
from instruct_particulate.utils.export_utils import export_urdf
from instruct_particulate.utils.inference_utils import (
axis_point_to_plucker_torch,
denormalize_motion_parameters,
estimate_prismatic_limit_torch,
estimate_revolute_limit_torch,
fit_axis_to_closest_points_torch,
motion_arrays_from_model_output,
prismatic_directions_from_plucker,
run_joint_refit_from_face_seg,
write_json,
)
from instruct_particulate.utils.inference_visualization_utils import (
print_inference_summary,
save_joint_overparam_visualization_from_model_output,
save_segmented_visualizations,
select_visualized_link_point_prompts,
)
from instruct_particulate.utils.postprocessing_utils import (
find_unrefined_part_ids_for_faces,
refine_face_part_ids_for_inference,
)
@dataclass(frozen=True)
class PreparedMeshGeometry:
    """Immutable bundle of the mesh variants and transforms used for inference.

    Mesh fields are typed as ``Any``; the concrete mesh type is not visible
    here (presumably trimesh objects given the ``load_trimesh`` import —
    TODO confirm).
    """

    # Mesh as loaded, before any normalization (presumably; verify at caller).
    original_mesh: Any
    # Mesh after normalization into model space.
    normalized_mesh: Any
    # Normalization center; `denormalize_points` maps model-space points back
    # via `points / scale + center`, presumably with these values.
    center: np.ndarray
    # Normalization scale factor (see note on `center`).
    scale: float
    # Rotation associated with the up-direction reorientation
    # (assumed 3x3 — TODO confirm).
    up_dir_rotation: np.ndarray
    # Mesh used for rendering / visualization.
    render_mesh: Any
    # Rotation taking render-space coordinates into model space
    # (inferred from the name — verify against caller).
    render_to_model_rotation: np.ndarray
# Lower bound applied to per-link AABB half-extents in
# `_compute_query_link_aabb_parameters_from_source_subset`. Axis points are
# rescaled by these half-extents, so flat or single-point boxes are kept from
# collapsing to a zero-size extent.
_INFER_OVERPARAM_AXIS_AABB_HALF_EXTENT_MIN = 1e-4
def tensor_to_numpy(tensor: torch.Tensor, *, dtype: np.dtype[Any] | type) -> np.ndarray:
"""Converts a tensor to a CPU NumPy array with the requested dtype."""
return tensor.detach().cpu().numpy().astype(dtype, copy=False)
def denormalize_points(
points: np.ndarray,
*,
center: np.ndarray | tuple[float, float, float],
scale: float,
) -> np.ndarray:
"""Converts normalized model-space points back to mesh-space coordinates."""
return (
np.asarray(points, dtype=np.float32) / np.float32(scale)
+ np.asarray(center, dtype=np.float32)
).astype(np.float32, copy=False)
def _confidence_weighted_mean(
values: torch.Tensor,
weights: torch.Tensor,
) -> torch.Tensor:
"""Returns a weighted mean over the leading dimension."""
values = values.float()
weights = weights.float()
if values.shape[0] != weights.shape[0]:
raise ValueError(
"values and weights must agree on the leading dimension, "
f"got {tuple(values.shape)} and {tuple(weights.shape)}"
)
if values.shape[0] == 0:
return values.new_zeros(values.shape[1:])
total_weight = weights.sum()
if float(total_weight.item()) <= 0.0:
return values.mean(dim=0)
view_shape = (weights.shape[0],) + (1,) * (values.ndim - 1)
return (values * weights.view(view_shape)).sum(dim=0) / total_weight
def _weighted_median_1d(
values: torch.Tensor,
weights: torch.Tensor,
) -> torch.Tensor:
"""Returns the weighted median of a 1D tensor."""
values = values.float().reshape(-1)
weights = weights.float().reshape(-1)
if values.shape != weights.shape:
raise ValueError(
"values and weights must have matching shapes, "
f"got {tuple(values.shape)} and {tuple(weights.shape)}"
)
if values.numel() == 0:
return values.new_zeros(())
if float(weights.sum().item()) <= 0.0:
return torch.quantile(values, 0.5)
sorted_values, order = torch.sort(values)
sorted_weights = weights[order]
cumulative_weights = torch.cumsum(sorted_weights, dim=0)
cutoff = 0.5 * sorted_weights.sum()
median_index = torch.searchsorted(cumulative_weights, cutoff, right=False)
median_index = median_index.clamp_max(sorted_values.shape[0] - 1)
return sorted_values[median_index]
def _confidence_weighted_coordinatewise_median(
values: torch.Tensor,
weights: torch.Tensor,
) -> torch.Tensor:
"""Returns a coordinate-wise weighted median over the leading dimension."""
values = values.float()
weights = weights.float().reshape(-1)
if values.ndim != 2:
raise ValueError(f"Expected values with shape (N, D), got {tuple(values.shape)}")
if values.shape[0] != weights.shape[0]:
raise ValueError(
"values and weights must agree on the leading dimension, "
f"got {tuple(values.shape)} and {tuple(weights.shape)}"
)
if values.shape[0] == 0:
return values.new_zeros((values.shape[-1],))
if float(weights.sum().item()) <= 0.0:
return torch.quantile(values, 0.5, dim=0)
return torch.stack(
[
_weighted_median_1d(values[:, dim_idx], weights)
for dim_idx in range(values.shape[1])
],
dim=0,
)
def _weighted_fit_axis_to_closest_points(
    query_points: torch.Tensor,
    closest_axis_points: torch.Tensor,
    weights: torch.Tensor,
    *,
    direction_hint: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fits a weighted least-squares axis from per-query closest-point targets.

    Each query point ``q_i`` comes with a predicted closest point ``c_i`` on the
    axis. With weights normalized to sum to one, the fitted direction ``d``
    minimizes ``d^T (R - C) d``, where:

    - ``R = sum_i w_i (q_i - c_i)(q_i - c_i)^T`` penalizes directions that are
      not perpendicular to the query residuals;
    - ``C = sum_i w_i (c_i - c_mean)(c_i - c_mean)^T`` rewards directions along
      which the closest points spread.

    Args:
        query_points: Query positions, shape ``(N, 3)``.
        closest_axis_points: Per-query closest axis points, shape ``(N, 3)``.
        weights: Per-query weights, shape ``(N,)``. Non-finite and
            non-positive entries are dropped. If none remain, the result of the
            unweighted ``fit_axis_to_closest_points_torch`` is returned instead.
        direction_hint: Optional vector. It breaks ties inside a degenerate
            smallest eigenspace and chooses the sign of the returned
            direction. It is treated as absent when its norm is ``<= 1e-8``.

    Returns:
        ``(axis_direction, axis_point)``, each of shape ``(3,)``.

        - Empty input returns zeros for both.
        - On the weighted path, the direction is unit length.
        - The point is the foot of the perpendicular from the origin onto the
          fitted line through the weighted mean closest point.
        - Weighted-path outputs use the input dtype for float32/float64 inputs
          and float32 otherwise.

    Raises:
        ValueError: If the input shapes are inconsistent.
    """
    if query_points.shape != closest_axis_points.shape:
        raise ValueError(
            "query_points and closest_axis_points must share the same shape, "
            f"got {tuple(query_points.shape)} and {tuple(closest_axis_points.shape)}"
        )
    if query_points.ndim != 2 or query_points.shape[-1] != 3:
        raise ValueError(
            f"Expected query_points and closest_axis_points to have shape (N, 3), got {tuple(query_points.shape)}"
        )
    if weights.ndim != 1 or weights.shape[0] != query_points.shape[0]:
        raise ValueError(
            "weights must have shape (N,), "
            f"got {tuple(weights.shape)} for query_points {tuple(query_points.shape)}"
        )
    if query_points.numel() == 0:
        zero = query_points.new_zeros(3)
        return zero, zero
    weights = weights.float()
    # Keep only finite, positive weights; with nothing left, defer to the
    # unweighted fit.
    valid_mask = torch.isfinite(weights) & (weights > 0)
    if not bool(valid_mask.any().item()):
        return fit_axis_to_closest_points_torch(
            query_points,
            closest_axis_points,
            direction_hint=direction_hint,
        )
    query_points = query_points[valid_mask]
    closest_axis_points = closest_axis_points[valid_mask]
    weights = weights[valid_mask]
    normalized_weights = weights / weights.sum().clamp_min(1e-12)
    # Outputs are cast back to the caller's float dtype (float32 for other
    # dtypes); the solve below may run at higher precision.
    result_dtype = (
        query_points.dtype
        if query_points.dtype in (torch.float32, torch.float64)
        else torch.float32
    )
    # float64 for the eigensolve, except on MPS (presumably because float64
    # is not supported there — TODO confirm).
    solve_dtype = torch.float32 if query_points.device.type == "mps" else torch.float64
    query_points = query_points.to(dtype=solve_dtype)
    closest_axis_points = closest_axis_points.to(dtype=solve_dtype)
    normalized_weights = normalized_weights.to(dtype=solve_dtype)
    if direction_hint is None:
        # A zero hint disables every hint-dependent branch below, since each
        # is gated on the hint norm exceeding 1e-8.
        direction_hint = query_points.new_zeros(3)
    else:
        direction_hint = direction_hint.to(
            device=query_points.device,
            dtype=solve_dtype,
        )
    closest_axis_mean = (closest_axis_points * normalized_weights.unsqueeze(-1)).sum(dim=0)
    centered_closest_axis_points = closest_axis_points - closest_axis_mean.unsqueeze(0)
    query_residuals = query_points - closest_axis_points
    # objective_matrix = R - C (see docstring); both terms are 3x3 weighted
    # scatter matrices.
    objective_matrix = (
        query_residuals.transpose(0, 1)
        @ (query_residuals * normalized_weights.unsqueeze(-1))
        - centered_closest_axis_points.transpose(0, 1)
        @ (centered_closest_axis_points * normalized_weights.unsqueeze(-1))
    )
    direction_hint_norm = torch.linalg.norm(direction_hint)
    normalized_direction_hint = direction_hint / direction_hint_norm.clamp_min(1e-12)
    # eigh returns eigenvalues in ascending order, so column 0 is the
    # minimizing direction.
    eigenvalues, eigenvectors = torch.linalg.eigh(objective_matrix)
    axis_direction = eigenvectors[:, 0]
    if direction_hint_norm > 1e-8:
        # When the smallest eigenvalue is (numerically) repeated, every vector
        # in that eigenspace is equally optimal. Pick the one closest to the
        # hint by projecting the hint onto the eigenspace.
        eigenspace_scale = torch.maximum(
            eigenvalues.abs().max(),
            objective_matrix.abs().max(),
        ).clamp_min(1.0)
        eigenspace_tol = eigenspace_scale * (
            1e-8 if solve_dtype == torch.float64 else 1e-5
        )
        smallest_eigenspace = eigenvectors[
            :,
            eigenvalues <= (eigenvalues[0] + eigenspace_tol),
        ]
        projected_hint = smallest_eigenspace @ (
            smallest_eigenspace.transpose(0, 1) @ normalized_direction_hint
        )
        if torch.linalg.norm(projected_hint) > 1e-8:
            axis_direction = projected_hint
    # Eigenvectors have arbitrary sign; orient the axis to agree with the hint.
    if direction_hint_norm > 1e-8 and torch.dot(axis_direction, normalized_direction_hint) < 0:
        axis_direction = -axis_direction
    axis_direction = F.normalize(axis_direction, dim=-1, eps=1e-8)
    # If the solve produced a non-finite or zero direction, fall back to the
    # hint when one was given, otherwise to +x.
    fallback_direction = torch.where(
        direction_hint_norm > 1e-8,
        normalized_direction_hint,
        query_points.new_tensor([1.0, 0.0, 0.0]),
    )
    if not torch.isfinite(axis_direction).all() or torch.linalg.norm(axis_direction) <= 1e-8:
        axis_direction = fallback_direction
    # Foot of the perpendicular from the origin onto the line through the
    # weighted mean closest point along axis_direction.
    axis_point = closest_axis_mean - axis_direction * torch.dot(axis_direction, closest_axis_mean)
    if not torch.isfinite(axis_point).all():
        axis_point = closest_axis_mean
    return axis_direction.to(dtype=result_dtype), axis_point.to(dtype=result_dtype)
def _confidence_weighted_estimate_prismatic_limit(
current_points: torch.Tensor,
target_points: torch.Tensor,
axis_direction: torch.Tensor,
weights: torch.Tensor,
) -> torch.Tensor:
"""Fits a weighted shared prismatic displacement."""
if current_points.numel() == 0:
return current_points.new_zeros(())
if weights.ndim != 1 or weights.shape[0] != current_points.shape[0]:
raise ValueError(
"weights must have shape (N,), "
f"got {tuple(weights.shape)} for current_points {tuple(current_points.shape)}"
)
valid_mask = torch.isfinite(weights) & (weights > 0)
if not bool(valid_mask.any().item()):
return estimate_prismatic_limit_torch(
current_points,
target_points,
axis_direction,
)
axis_direction = F.normalize(axis_direction.float(), dim=-1, eps=1e-8)
weights = weights[valid_mask].float()
projections = (
(target_points[valid_mask].float() - current_points[valid_mask].float())
* axis_direction.unsqueeze(0)
).sum(dim=-1)
return (projections * weights).sum() / weights.sum().clamp_min(1e-12)
def _confidence_weighted_estimate_revolute_limit(
current_points: torch.Tensor,
target_points: torch.Tensor,
axis_direction: torch.Tensor,
axis_point: torch.Tensor,
weights: torch.Tensor,
) -> torch.Tensor:
"""Fits a weighted shared revolute angle around the given axis."""
if current_points.numel() == 0:
return current_points.new_zeros(())
if weights.ndim != 1 or weights.shape[0] != current_points.shape[0]:
raise ValueError(
"weights must have shape (N,), "
f"got {tuple(weights.shape)} for current_points {tuple(current_points.shape)}"
)
valid_mask = torch.isfinite(weights) & (weights > 0)
if not bool(valid_mask.any().item()):
return estimate_revolute_limit_torch(
current_points,
target_points,
axis_direction,
axis_point,
)
axis_direction = F.normalize(axis_direction.float(), dim=-1, eps=1e-8)
axis_point = axis_point.float()
current_points = current_points[valid_mask].float()
target_points = target_points[valid_mask].float()
weights = weights[valid_mask].float()
def _project_perpendicular(points: torch.Tensor) -> torch.Tensor:
offsets = points - axis_point.unsqueeze(0)
parallel = (offsets * axis_direction.unsqueeze(0)).sum(dim=-1, keepdim=True)
return offsets - parallel * axis_direction.unsqueeze(0)
current_perp = _project_perpendicular(current_points)
target_perp = _project_perpendicular(target_points)
cosine_term = ((current_perp * target_perp).sum(dim=-1) * weights).sum()
sine_term = torch.linalg.cross(current_perp, target_perp, dim=-1)
sine_term = ((sine_term * axis_direction.unsqueeze(0)).sum(dim=-1) * weights).sum()
return torch.atan2(sine_term, cosine_term)
def _confidence_weighted_axis_direction_vote(
predicted_directions: torch.Tensor,
weights: torch.Tensor,
*,
sign_hint: torch.Tensor | None = None,
) -> torch.Tensor:
"""Returns one weighted averaged unit direction while treating flips as equivalent."""
predicted_directions = predicted_directions.float()
weights = weights.float()
if predicted_directions.ndim != 2 or predicted_directions.shape[-1] != 3:
raise ValueError(
f"Expected predicted_directions with shape (N, 3), got {tuple(predicted_directions.shape)}"
)
if weights.ndim != 1 or weights.shape[0] != predicted_directions.shape[0]:
raise ValueError(
"weights must have shape (N,), "
f"got {tuple(weights.shape)} for predicted_directions {tuple(predicted_directions.shape)}"
)
if predicted_directions.numel() == 0:
return predicted_directions.new_zeros(3)
direction_norms = torch.linalg.vector_norm(predicted_directions, dim=-1)
valid_mask = torch.isfinite(weights) & (weights > 0) & (direction_norms > 1e-8)
if not bool(valid_mask.any().item()):
return predicted_directions.new_zeros(3)
unit_directions = predicted_directions[valid_mask] / direction_norms[valid_mask].unsqueeze(-1)
weights = weights[valid_mask]
direction_covariance = unit_directions.transpose(0, 1) @ (
unit_directions * weights.unsqueeze(-1)
)
_, eigenvectors = torch.linalg.eigh(direction_covariance)
anchor_direction = eigenvectors[:, -1]
alignment = torch.sign(unit_directions @ anchor_direction)
alignment = torch.where(
alignment == 0,
torch.ones_like(alignment),
alignment,
)
aligned_mean_direction = (
unit_directions
* alignment.unsqueeze(-1)
* weights.unsqueeze(-1)
).sum(dim=0) / weights.sum().clamp_min(1e-12)
if float(torch.linalg.vector_norm(aligned_mean_direction).item()) <= 1e-8:
axis_direction = F.normalize(anchor_direction, dim=0, eps=1e-8)
else:
axis_direction = F.normalize(aligned_mean_direction, dim=0, eps=1e-8)
if sign_hint is not None:
sign_hint = sign_hint.float()
if float(torch.linalg.vector_norm(sign_hint).item()) > 1e-8:
if float(torch.dot(axis_direction, sign_hint).item()) < 0.0:
axis_direction = -axis_direction
return axis_direction
def _fit_confidence_weighted_revolute_joint_parameters(
query_points: torch.Tensor,
closest_axis_points: torch.Tensor,
low_points: torch.Tensor,
high_points: torch.Tensor,
weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fits one revolute axis and limits from weighted per-query targets."""
query_points = query_points.float()
closest_axis_points = closest_axis_points.float()
low_points = low_points.float()
high_points = high_points.float()
weights = weights.float()
if query_points.numel() == 0:
zero_axis = query_points.new_zeros(6)
zero_range = query_points.new_zeros(2)
return zero_axis, zero_range
direction_terms = (
torch.linalg.cross(
query_points - closest_axis_points,
low_points - closest_axis_points,
dim=-1,
)
+ torch.linalg.cross(
query_points - closest_axis_points,
high_points - closest_axis_points,
dim=-1,
)
+ torch.linalg.cross(
low_points - closest_axis_points,
high_points - closest_axis_points,
dim=-1,
)
)
direction_hint = _confidence_weighted_mean(direction_terms, weights)
axis_direction, axis_point = _weighted_fit_axis_to_closest_points(
query_points,
closest_axis_points,
weights,
direction_hint=direction_hint,
)
revolute_axis = axis_point_to_plucker_torch(axis_direction, axis_point)
low_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
low_points,
axis_direction,
axis_point,
weights,
)
high_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
high_points,
axis_direction,
axis_point,
weights,
)
return revolute_axis, torch.stack((low_limit, high_limit))
def _fit_confidence_weighted_revolute_joint_parameters_with_direction(
query_points: torch.Tensor,
closest_axis_points: torch.Tensor,
low_points: torch.Tensor,
high_points: torch.Tensor,
predicted_axis_directions: torch.Tensor,
weights: torch.Tensor,
*,
direction_vote_query_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fits one revolute axis and limits from weighted per-query directional targets."""
query_points = query_points.float()
closest_axis_points = closest_axis_points.float()
low_points = low_points.float()
high_points = high_points.float()
predicted_axis_directions = predicted_axis_directions.float()
weights = weights.float()
if query_points.numel() == 0:
zero_axis = query_points.new_zeros(6)
zero_range = query_points.new_zeros(2)
return zero_axis, zero_range
direction_sign_terms = (
torch.linalg.cross(
query_points - closest_axis_points,
low_points - closest_axis_points,
dim=-1,
)
+ torch.linalg.cross(
query_points - closest_axis_points,
high_points - closest_axis_points,
dim=-1,
)
+ torch.linalg.cross(
low_points - closest_axis_points,
high_points - closest_axis_points,
dim=-1,
)
)
vote_predicted_axis_directions = predicted_axis_directions
vote_weights = weights
vote_direction_sign_terms = direction_sign_terms
if direction_vote_query_mask is not None:
if bool(direction_vote_query_mask.any().item()):
vote_predicted_axis_directions = predicted_axis_directions[direction_vote_query_mask]
vote_weights = weights[direction_vote_query_mask]
vote_direction_sign_terms = direction_sign_terms[direction_vote_query_mask]
axis_direction = _confidence_weighted_axis_direction_vote(
vote_predicted_axis_directions,
vote_weights,
sign_hint=_confidence_weighted_mean(vote_direction_sign_terms, vote_weights),
)
if float(torch.linalg.vector_norm(axis_direction).item()) <= 1e-8:
return _fit_confidence_weighted_revolute_joint_parameters(
query_points,
closest_axis_points,
low_points,
high_points,
weights,
)
axis_point = _confidence_weighted_coordinatewise_median(
closest_axis_points,
weights,
)
low_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
low_points,
axis_direction,
axis_point,
weights,
)
high_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
high_points,
axis_direction,
axis_point,
weights,
)
if float(low_limit.item()) > float(high_limit.item()):
axis_direction = -axis_direction
low_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
low_points,
axis_direction,
axis_point,
weights,
)
high_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
high_points,
axis_direction,
axis_point,
weights,
)
revolute_axis = axis_point_to_plucker_torch(axis_direction, axis_point)
return revolute_axis, torch.stack((low_limit, high_limit))
def _fit_confidence_weighted_prismatic_joint_parameters(
query_points: torch.Tensor,
closest_axis_points: torch.Tensor,
low_points: torch.Tensor,
high_points: torch.Tensor,
weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fits one prismatic axis and limits from weighted per-query targets."""
query_points = query_points.float()
closest_axis_points = closest_axis_points.float()
low_points = low_points.float()
high_points = high_points.float()
weights = weights.float()
if query_points.numel() == 0:
zero_axis = query_points.new_zeros(6)
zero_range = query_points.new_zeros(2)
return zero_axis, zero_range
axis_direction, axis_point = _weighted_fit_axis_to_closest_points(
query_points,
closest_axis_points,
weights,
direction_hint=_confidence_weighted_mean(high_points - low_points, weights),
)
prismatic_axis = axis_point_to_plucker_torch(axis_direction, axis_point)
low_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
low_points,
axis_direction,
weights,
)
high_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
high_points,
axis_direction,
weights,
)
return prismatic_axis, torch.stack((low_limit, high_limit))
def _fit_confidence_weighted_prismatic_joint_parameters_with_direction(
query_points: torch.Tensor,
closest_axis_points: torch.Tensor,
low_points: torch.Tensor,
high_points: torch.Tensor,
predicted_axis_directions: torch.Tensor,
weights: torch.Tensor,
*,
direction_vote_query_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fits one prismatic axis and limits from weighted per-query directional targets."""
query_points = query_points.float()
closest_axis_points = closest_axis_points.float()
low_points = low_points.float()
high_points = high_points.float()
predicted_axis_directions = predicted_axis_directions.float()
weights = weights.float()
if query_points.numel() == 0:
zero_axis = query_points.new_zeros(6)
zero_range = query_points.new_zeros(2)
return zero_axis, zero_range
vote_predicted_axis_directions = predicted_axis_directions
vote_weights = weights
vote_sign_hint_terms = high_points - low_points
if direction_vote_query_mask is not None:
if bool(direction_vote_query_mask.any().item()):
vote_predicted_axis_directions = predicted_axis_directions[direction_vote_query_mask]
vote_weights = weights[direction_vote_query_mask]
vote_sign_hint_terms = vote_sign_hint_terms[direction_vote_query_mask]
axis_direction = _confidence_weighted_axis_direction_vote(
vote_predicted_axis_directions,
vote_weights,
sign_hint=_confidence_weighted_mean(vote_sign_hint_terms, vote_weights),
)
if float(torch.linalg.vector_norm(axis_direction).item()) <= 1e-8:
return _fit_confidence_weighted_prismatic_joint_parameters(
query_points,
closest_axis_points,
low_points,
high_points,
weights,
)
axis_point = _confidence_weighted_coordinatewise_median(
closest_axis_points,
weights,
)
low_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
low_points,
axis_direction,
weights,
)
high_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
high_points,
axis_direction,
weights,
)
if float(low_limit.item()) > float(high_limit.item()):
axis_direction = -axis_direction
low_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
low_points,
axis_direction,
weights,
)
high_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
high_points,
axis_direction,
weights,
)
prismatic_axis = axis_point_to_plucker_torch(axis_direction, axis_point)
return prismatic_axis, torch.stack((low_limit, high_limit))
def _fit_confidence_weighted_revolute_joint_parameters_with_single_direction(
query_points: torch.Tensor,
closest_axis_points: torch.Tensor,
low_points: torch.Tensor,
high_points: torch.Tensor,
predicted_axis_direction: torch.Tensor,
weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fits one revolute axis and limits from weighted query targets plus one direction."""
query_points = query_points.float()
closest_axis_points = closest_axis_points.float()
low_points = low_points.float()
high_points = high_points.float()
predicted_axis_direction = predicted_axis_direction.float()
weights = weights.float()
if query_points.numel() == 0:
zero_axis = query_points.new_zeros(6)
zero_range = query_points.new_zeros(2)
return zero_axis, zero_range
if float(torch.linalg.vector_norm(predicted_axis_direction).item()) <= 1e-8:
return _fit_confidence_weighted_revolute_joint_parameters(
query_points,
closest_axis_points,
low_points,
high_points,
weights,
)
axis_direction = F.normalize(predicted_axis_direction, dim=0, eps=1e-8)
axis_point = _confidence_weighted_coordinatewise_median(
closest_axis_points,
weights,
)
low_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
low_points,
axis_direction,
axis_point,
weights,
)
high_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
high_points,
axis_direction,
axis_point,
weights,
)
if float(low_limit.item()) > float(high_limit.item()):
axis_direction = -axis_direction
low_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
low_points,
axis_direction,
axis_point,
weights,
)
high_limit = _confidence_weighted_estimate_revolute_limit(
query_points,
high_points,
axis_direction,
axis_point,
weights,
)
revolute_axis = axis_point_to_plucker_torch(axis_direction, axis_point)
return revolute_axis, torch.stack((low_limit, high_limit))
def _fit_confidence_weighted_prismatic_joint_parameters_with_single_direction(
query_points: torch.Tensor,
closest_axis_points: torch.Tensor,
low_points: torch.Tensor,
high_points: torch.Tensor,
predicted_axis_direction: torch.Tensor,
weights: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fits one prismatic axis and limits from weighted query targets plus one direction."""
query_points = query_points.float()
closest_axis_points = closest_axis_points.float()
low_points = low_points.float()
high_points = high_points.float()
predicted_axis_direction = predicted_axis_direction.float()
weights = weights.float()
if query_points.numel() == 0:
zero_axis = query_points.new_zeros(6)
zero_range = query_points.new_zeros(2)
return zero_axis, zero_range
if float(torch.linalg.vector_norm(predicted_axis_direction).item()) <= 1e-8:
return _fit_confidence_weighted_prismatic_joint_parameters(
query_points,
closest_axis_points,
low_points,
high_points,
weights,
)
axis_direction = F.normalize(predicted_axis_direction, dim=0, eps=1e-8)
axis_point = _confidence_weighted_coordinatewise_median(
closest_axis_points,
weights,
)
low_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
low_points,
axis_direction,
weights,
)
high_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
high_points,
axis_direction,
weights,
)
if float(low_limit.item()) > float(high_limit.item()):
axis_direction = -axis_direction
low_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
low_points,
axis_direction,
weights,
)
high_limit = _confidence_weighted_estimate_prismatic_limit(
query_points,
high_points,
axis_direction,
weights,
)
prismatic_axis = axis_point_to_plucker_torch(axis_direction, axis_point)
return prismatic_axis, torch.stack((low_limit, high_limit))
def _apply_joint_decoding_confidence_temperature(
joint_decoding_confidences: torch.Tensor,
*,
temperature: float,
) -> torch.Tensor:
"""Applies a temperature transform to segmentation confidences used for joint refitting."""
if temperature <= 0.0:
raise ValueError("joint decoding confidence temperature must be positive")
confidences = joint_decoding_confidences.to(dtype=torch.float32).clamp_min(0.0)
if temperature == 1.0:
return confidences
return confidences.pow(1.0 / float(temperature))
def _resolve_second_pass_part_aabb_bbox_source_link_ids(
*,
assigned_link_ids: torch.Tensor,
segmentation_logits: torch.Tensor,
) -> torch.Tensor:
"""Returns the second-pass query subset that contributes to each part AABB."""
if segmentation_logits.ndim != 3:
raise ValueError(
"segmentation_logits must have shape (B, Q, num_parts), "
f"got {tuple(segmentation_logits.shape)}"
)
if assigned_link_ids.shape != segmentation_logits.shape[:2]:
raise ValueError(
"assigned_link_ids must match segmentation_logits batch/query dims, "
f"got {tuple(assigned_link_ids.shape)} and {tuple(segmentation_logits.shape)}"
)
predicted_link_ids = segmentation_logits.argmax(dim=-1)
bbox_source_link_ids = torch.full_like(assigned_link_ids, -1)
valid_assigned_mask = assigned_link_ids >= 0
agreement_mask = valid_assigned_mask & (predicted_link_ids == assigned_link_ids)
bbox_source_link_ids[agreement_mask] = assigned_link_ids[agreement_mask]
return bbox_source_link_ids
def _compute_query_link_aabb_parameters_from_source_subset(
*,
query_points: torch.Tensor,
assigned_link_ids: torch.Tensor,
bbox_source_link_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Computes per-link AABB parameters from a filtered query subset.
The AABB for one link is estimated from the subset encoded by
`bbox_source_link_ids`. If that subset is empty for a link, this falls back
to all queries assigned to that link. The resulting center / half-extent is
broadcast to every query assigned to the link.
"""
if assigned_link_ids.shape != query_points.shape[:2]:
raise ValueError(
"assigned_link_ids must match query_points batch/query dims, "
f"got {tuple(assigned_link_ids.shape)} and {tuple(query_points.shape)}"
)
if bbox_source_link_ids.shape != assigned_link_ids.shape:
raise ValueError(
"bbox_source_link_ids must match assigned_link_ids, "
f"got {tuple(bbox_source_link_ids.shape)} and {tuple(assigned_link_ids.shape)}"
)
query_points = query_points.float()
centers = query_points.new_zeros(query_points.shape)
half_extents = query_points.new_ones(query_points.shape)
for batch_idx in range(query_points.shape[0]):
batch_assigned_link_ids = assigned_link_ids[batch_idx]
valid_assigned_mask = batch_assigned_link_ids >= 0
if not bool(valid_assigned_mask.any().item()):
continue
unique_link_ids = torch.unique(batch_assigned_link_ids[valid_assigned_mask])
for link_id in unique_link_ids.tolist():
target_mask = batch_assigned_link_ids == int(link_id)
source_mask = bbox_source_link_ids[batch_idx] == int(link_id)
if not bool(source_mask.any().item()):
source_mask = target_mask
source_query_points = query_points[batch_idx][source_mask]
min_corner = source_query_points.min(dim=0).values
max_corner = source_query_points.max(dim=0).values
centers[batch_idx][target_mask] = 0.5 * (min_corner + max_corner)
half_extents[batch_idx][target_mask] = (
0.5 * (max_corner - min_corner)
).clamp_min(_INFER_OVERPARAM_AXIS_AABB_HALF_EXTENT_MIN)
return centers, half_extents
def _denormalize_overparam_axis_points_from_second_pass_intersection(
    *,
    axis_points: torch.Tensor,
    query_points: torch.Tensor,
    assigned_link_ids: torch.Tensor,
    segmentation_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Map part-AABB-normalized axis points back to query space.

    The per-link boxes come from the queries whose assigned link agrees with
    the second-pass segmentation argmax (falling back to all assigned
    queries). The axis points are then mapped as
    ``axis_points * half_extent + center``.

    Returns:
        ``(denormalized_axis_points, bbox_source_link_ids)``.
    """
    source_ids = _resolve_second_pass_part_aabb_bbox_source_link_ids(
        assigned_link_ids=assigned_link_ids,
        segmentation_logits=segmentation_logits,
    )
    aabb_centers, aabb_half_extents = _compute_query_link_aabb_parameters_from_source_subset(
        query_points=query_points,
        assigned_link_ids=assigned_link_ids,
        bbox_source_link_ids=source_ids,
    )
    denormalized = axis_points.float() * aabb_half_extents + aabb_centers
    return denormalized, source_ids
def _apply_confidence_weighted_overparam_joint_voting(
    model: Particulate2ArticulationModel,
    motion_output: dict[str, Any],
    *,
    query_points: torch.Tensor,
    confidence_temperature: float = 1.0,
) -> dict[str, Any]:
    """Recomputes overparameterized joint parameters with confidence-weighted voting.

    For each joint, the per-query predictions of the queries whose
    ``joint_decoding_link_ids`` equal the joint's child link are fused into a
    single axis and range. Each query is weighted by its temperature-adjusted
    ``joint_decoding_confidences``.

    Args:
        model: Articulation model. ``joint_decode_type`` selects the fitting
            variant. An optional ``overparam_closest_axis_uses_part_aabb``
            attribute (must be exactly ``True``) enables re-denormalizing the
            closest-axis points from second-pass part AABBs.
        motion_output: Decoder output dict. It is returned unchanged when the
            decode type is not overparameterized, or when any required
            per-query entry is missing or ``None``.
        query_points: Query coordinates. They are moved to the device of
            ``joint_decoding_link_ids`` and cast to ``float32``.
        confidence_temperature: Forwarded to
            ``_apply_joint_decoding_confidence_temperature``.

    Returns:
        A shallow copy of ``motion_output``. The following entries are replaced:
        ``revolute_axis`` / ``prismatic_axis`` (shape
        ``(batch, max_joints, 6)``), ``revolute_range`` / ``prismatic_range``
        (shape ``(batch, max_joints, 2)``), and the closest-axis points.
        Entries of joints that are not valid revolute (resp. prismatic)
        joints are zeroed. When part-AABB refit ran, the copy also carries
        ``part_aabb_bbox_source_link_ids``.

    Raises:
        ValueError: If part-AABB refit is enabled without
            ``segmentation_logits``, or if a direction-assisted decode type
            lacks axis-direction predictions.
    """
    if model.joint_decode_type not in {
        "overparametrized",
        "overparam+dir",
        "overparam+singledir",
    }:
        return motion_output
    joint_decoding_link_ids = motion_output.get("joint_decoding_link_ids")
    joint_decoding_confidences = motion_output.get("joint_decoding_confidences")
    revolute_closest_axis_points = motion_output.get("revolute_closest_axis_points")
    revolute_low_points = motion_output.get("revolute_low_points")
    revolute_high_points = motion_output.get("revolute_high_points")
    prismatic_closest_axis_points = motion_output.get("prismatic_closest_axis_points")
    prismatic_low_points = motion_output.get("prismatic_low_points")
    prismatic_high_points = motion_output.get("prismatic_high_points")
    segmentation_logits = motion_output.get("segmentation_logits")
    if (
        joint_decoding_link_ids is None
        or joint_decoding_confidences is None
        or revolute_closest_axis_points is None
        or revolute_low_points is None
        or revolute_high_points is None
        or prismatic_closest_axis_points is None
        or prismatic_low_points is None
        or prismatic_high_points is None
    ):
        return motion_output
    joint_connections = motion_output["joint_connections"]
    joint_valid_flag = motion_output["joint_valid_flag"]
    is_revolute = motion_output["is_revolute"]
    is_prismatic = motion_output["is_prismatic"]
    query_points = query_points.to(
        device=joint_decoding_link_ids.device,
        dtype=torch.float32,
    )
    bbox_source_link_ids: torch.Tensor | None = None
    if getattr(model, "overparam_closest_axis_uses_part_aabb", False) is True:
        if segmentation_logits is None:
            raise ValueError(
                "part_aabb inference-time refit requires second-pass segmentation_logits"
            )
        revolute_closest_axis_points_decoder = motion_output.get(
            "revolute_closest_axis_points_decoder"
        )
        prismatic_closest_axis_points_decoder = motion_output.get(
            "prismatic_closest_axis_points_decoder"
        )
        segmentation_logits = segmentation_logits.to(device=query_points.device)
        if revolute_closest_axis_points_decoder is not None:
            (
                revolute_closest_axis_points,
                bbox_source_link_ids,
            ) = _denormalize_overparam_axis_points_from_second_pass_intersection(
                axis_points=revolute_closest_axis_points_decoder.to(
                    device=query_points.device,
                    dtype=torch.float32,
                ),
                query_points=query_points,
                assigned_link_ids=joint_decoding_link_ids,
                segmentation_logits=segmentation_logits,
            )
        if prismatic_closest_axis_points_decoder is not None:
            (
                prismatic_closest_axis_points,
                bbox_source_link_ids,
            ) = _denormalize_overparam_axis_points_from_second_pass_intersection(
                axis_points=prismatic_closest_axis_points_decoder.to(
                    device=query_points.device,
                    dtype=torch.float32,
                ),
                query_points=query_points,
                assigned_link_ids=joint_decoding_link_ids,
                segmentation_logits=segmentation_logits,
            )
    joint_decoding_confidences = _apply_joint_decoding_confidence_temperature(
        joint_decoding_confidences,
        temperature=confidence_temperature,
    ).to(device=joint_decoding_link_ids.device)
    batch_size, max_joints = joint_connections.shape[:2]
    revolute_axis = query_points.new_zeros((batch_size, max_joints, 6))
    prismatic_axis = query_points.new_zeros((batch_size, max_joints, 6))
    revolute_range = query_points.new_zeros((batch_size, max_joints, 2))
    prismatic_range = query_points.new_zeros((batch_size, max_joints, 2))
    revolute_axis_directions = None
    prismatic_axis_directions = None
    if model.joint_decode_type in {"overparam+dir", "overparam+singledir"}:
        # `.get` (not indexing) so that a missing key reaches the descriptive
        # ValueError below instead of surfacing as a bare KeyError.
        revolute_axis_directions = motion_output.get("revolute_axis_directions")
        prismatic_axis_directions = motion_output.get("prismatic_axis_directions")
        if revolute_axis_directions is None or prismatic_axis_directions is None:
            if model.joint_decode_type == "overparam+dir":
                raise ValueError(
                    "direction-assisted overparam inference requires per-query axis "
                    "direction predictions"
                )
            raise ValueError(
                "overparam+singledir inference requires per-joint axis direction predictions"
            )
    # Invalid joints are zeroed by both output masks below, so any fit for
    # them would be discarded; skip them. A single `.tolist()` avoids one
    # device-to-host sync per joint.
    joint_needs_fit = torch.broadcast_to(
        joint_valid_flag, (batch_size, max_joints)
    ).tolist()
    child_link_ids = joint_connections[..., 1]
    for batch_idx in range(batch_size):
        for joint_idx in range(max_joints):
            if not joint_needs_fit[batch_idx][joint_idx]:
                continue
            child_link_id = child_link_ids[batch_idx, joint_idx]
            query_mask = joint_decoding_link_ids[batch_idx] == child_link_id
            if not bool(query_mask.any().item()):
                continue
            joint_query_points = query_points[batch_idx][query_mask]
            joint_weights = joint_decoding_confidences[batch_idx][query_mask]
            # With part-AABB refit, restrict direction voting to queries whose
            # AABB source link is this joint's child; fall back to all
            # assigned queries when that subset is empty.
            joint_direction_vote_query_mask = None
            if bbox_source_link_ids is not None:
                joint_direction_vote_query_mask = (
                    bbox_source_link_ids[batch_idx][query_mask] == int(child_link_id)
                )
                if not bool(joint_direction_vote_query_mask.any().item()):
                    joint_direction_vote_query_mask = None
            if model.joint_decode_type == "overparam+dir":
                (
                    revolute_axis[batch_idx, joint_idx],
                    revolute_range[batch_idx, joint_idx],
                ) = _fit_confidence_weighted_revolute_joint_parameters_with_direction(
                    joint_query_points,
                    revolute_closest_axis_points[batch_idx][query_mask],
                    revolute_low_points[batch_idx][query_mask],
                    revolute_high_points[batch_idx][query_mask],
                    revolute_axis_directions[batch_idx][query_mask],
                    joint_weights,
                    direction_vote_query_mask=joint_direction_vote_query_mask,
                )
                (
                    prismatic_axis[batch_idx, joint_idx],
                    prismatic_range[batch_idx, joint_idx],
                ) = _fit_confidence_weighted_prismatic_joint_parameters_with_direction(
                    joint_query_points,
                    prismatic_closest_axis_points[batch_idx][query_mask],
                    prismatic_low_points[batch_idx][query_mask],
                    prismatic_high_points[batch_idx][query_mask],
                    prismatic_axis_directions[batch_idx][query_mask],
                    joint_weights,
                    direction_vote_query_mask=joint_direction_vote_query_mask,
                )
            elif model.joint_decode_type == "overparam+singledir":
                # Directions are per joint here, not per query.
                (
                    revolute_axis[batch_idx, joint_idx],
                    revolute_range[batch_idx, joint_idx],
                ) = _fit_confidence_weighted_revolute_joint_parameters_with_single_direction(
                    joint_query_points,
                    revolute_closest_axis_points[batch_idx][query_mask],
                    revolute_low_points[batch_idx][query_mask],
                    revolute_high_points[batch_idx][query_mask],
                    revolute_axis_directions[batch_idx, joint_idx],
                    joint_weights,
                )
                (
                    prismatic_axis[batch_idx, joint_idx],
                    prismatic_range[batch_idx, joint_idx],
                ) = _fit_confidence_weighted_prismatic_joint_parameters_with_single_direction(
                    joint_query_points,
                    prismatic_closest_axis_points[batch_idx][query_mask],
                    prismatic_low_points[batch_idx][query_mask],
                    prismatic_high_points[batch_idx][query_mask],
                    prismatic_axis_directions[batch_idx, joint_idx],
                    joint_weights,
                )
            else:
                (
                    revolute_axis[batch_idx, joint_idx],
                    revolute_range[batch_idx, joint_idx],
                ) = _fit_confidence_weighted_revolute_joint_parameters(
                    joint_query_points,
                    revolute_closest_axis_points[batch_idx][query_mask],
                    revolute_low_points[batch_idx][query_mask],
                    revolute_high_points[batch_idx][query_mask],
                    joint_weights,
                )
                (
                    prismatic_axis[batch_idx, joint_idx],
                    prismatic_range[batch_idx, joint_idx],
                ) = _fit_confidence_weighted_prismatic_joint_parameters(
                    joint_query_points,
                    prismatic_closest_axis_points[batch_idx][query_mask],
                    prismatic_low_points[batch_idx][query_mask],
                    prismatic_high_points[batch_idx][query_mask],
                    joint_weights,
                )
    revolute_mask = (joint_valid_flag & is_revolute).unsqueeze(-1)
    prismatic_mask = (joint_valid_flag & is_prismatic).unsqueeze(-1)
    weighted_motion_output = dict(motion_output)
    weighted_motion_output.update(
        {
            "revolute_closest_axis_points": revolute_closest_axis_points,
            "prismatic_closest_axis_points": prismatic_closest_axis_points,
            "revolute_axis": revolute_axis.masked_fill(~revolute_mask, 0),
            "prismatic_axis": prismatic_axis.masked_fill(~prismatic_mask, 0),
            "revolute_range": revolute_range.masked_fill(~revolute_mask[..., :1], 0),
            "prismatic_range": prismatic_range.masked_fill(~prismatic_mask[..., :1], 0),
        }
    )
    if bbox_source_link_ids is not None:
        weighted_motion_output["part_aabb_bbox_source_link_ids"] = bbox_source_link_ids
    return weighted_motion_output
def build_joint_refit_metadata(
refit_sampling: dict[str, np.ndarray] | None,
*,
num_links: int,
) -> dict[str, Any]:
"""Returns metadata describing the balanced joint-refit sampling stage."""
if refit_sampling is None:
return {
"joint_refit_applied": False,
}
per_link_query_counts = np.zeros((num_links,), dtype=np.int32)
refit_link_ids = np.asarray(refit_sampling["link_ids"], dtype=np.int64)
unique_link_ids, query_counts = np.unique(refit_link_ids, return_counts=True)
per_link_query_counts[unique_link_ids.astype(np.int64, copy=False)] = query_counts.astype(
np.int32,
copy=False,
)
return {
"joint_refit_applied": True,
"joint_refit_strategy": "balanced_surface_queries_from_refined_face_seg",
"joint_refit_unique_part_ids": np.asarray(
refit_sampling["unique_part_ids"],
dtype=np.int32,
).tolist(),
"joint_refit_query_counts_per_part": np.asarray(
refit_sampling["query_counts"],
dtype=np.int32,
).tolist(),
"joint_refit_query_counts_per_link": per_link_query_counts.tolist(),
}
def decode_face_part_ids(
    mesh: Any,
    *,
    point_part_ids: np.ndarray,
    point_part_probabilities: np.ndarray,
    query_face_indices: np.ndarray,
    input_part_ids: np.ndarray,
    strict: bool,
    enforce_connectivity_per_part: bool,
) -> tuple[np.ndarray, np.ndarray]:
    """Turns per-query point labels into per-face labels, before and after refinement.

    Returns:
        ``(refined_face_part_ids, unrefined_face_part_ids)``. The refined ids
        come from ``refine_face_part_ids_for_inference`` applied to the
        unrefined ids.
    """
    raw_face_labels = find_unrefined_part_ids_for_faces(
        mesh, point_part_ids, query_face_indices
    )
    refinement_options = {
        "point_part_probabilities": point_part_probabilities,
        "face_indices": query_face_indices,
        "input_part_ids": input_part_ids,
        "strict": bool(strict),
        "enforce_connectivity_per_part": bool(enforce_connectivity_per_part),
    }
    refined_face_labels = refine_face_part_ids_for_inference(
        mesh, raw_face_labels, **refinement_options
    )
    return refined_face_labels, raw_face_labels
def compute_motion_prediction_artifacts(
    *,
    model: Particulate2ArticulationModel,
    batch: dict[str, Any],
    normalized_mesh: Any,
    face_part_ids: np.ndarray,
    joint_refit_num_query_points: int,
    num_links: int,
    query_batch_size: int,
    no_point_prompt: bool,
    joint_decoding_confidence_temperature: float,
    center: np.ndarray | tuple[float, float, float],
    scale: float,
) -> dict[str, Any]:
    """Runs the joint refit and packages its motion parameters in both coordinate frames.

    Steps:
        1. ``run_joint_refit_from_face_seg`` re-queries the model from the face
           segmentation.
        2. For overparameterized decoders, per-query votes are fused with
           confidence weighting. For plain decoders, axis/range signs are
           canonicalized.
        3. Parameters are mapped from the normalized frame back to the world
           frame using ``center``/``scale``.

    Returns:
        A dict with ``motion_output``, ``joint_refit_sampling``,
        ``motion_arrays_normalized``, ``motion_arrays_world``, and the
        prismatic directions in both frames.
    """
    motion_output, joint_refit_sampling = run_joint_refit_from_face_seg(
        model=model,
        batch=batch,
        mesh=normalized_mesh,
        face_part_ids=face_part_ids,
        num_query_points=int(joint_refit_num_query_points),
        query_batch_size=int(query_batch_size),
        no_point_prompt_for_unique_text=bool(no_point_prompt),
    )
    decode_type = model.joint_decode_type

    overparam_decode_types = {"overparametrized", "overparam+dir", "overparam+singledir"}
    if decode_type in overparam_decode_types:
        refit_points_np = np.asarray(joint_refit_sampling["query_points"], dtype=np.float32)
        # Add a leading batch dimension and match the device of the batch queries.
        refit_points = torch.from_numpy(refit_points_np)[None].to(
            batch["query_points"].device
        )
        motion_output = _apply_confidence_weighted_overparam_joint_voting(
            model,
            motion_output,
            query_points=refit_points,
            confidence_temperature=joint_decoding_confidence_temperature,
        )

    normalized_arrays = motion_arrays_from_model_output(
        motion_output, num_links=int(num_links)
    )
    if decode_type in {"plain", "plain+fm"}:
        normalized_arrays = _canonicalize_plain_motion_arrays(normalized_arrays)

    world_arrays = denormalize_motion_parameters(
        normalized_arrays["revolute_plucker"],
        normalized_arrays["prismatic_plucker"],
        normalized_arrays["revolute_range"],
        normalized_arrays["prismatic_range"],
        center=center,
        scale=scale,
    )
    normalized_prismatic_dirs = prismatic_directions_from_plucker(
        normalized_arrays["prismatic_plucker"]
    )
    world_prismatic_dirs = prismatic_directions_from_plucker(
        world_arrays["prismatic_plucker"]
    )
    return {
        "motion_output": motion_output,
        "joint_refit_sampling": joint_refit_sampling,
        "motion_arrays_normalized": normalized_arrays,
        "motion_arrays_world": world_arrays,
        "prismatic_axis_normalized": normalized_prismatic_dirs,
        "prismatic_axis_world": world_prismatic_dirs,
    }
def _canonicalize_axis_range_pairs(
    axis_parameters: np.ndarray,
    ranges: np.ndarray,
    *,
    valid_mask: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Negates inverted (``low > high``) valid rows so every exported limit is ordered.

    Negating both the axis parameters and the ``[low, high]`` range of a row
    describes the same motion with ``-low <= -high``. Rows that are not valid
    are left untouched. The inputs are copied as ``float32`` and never
    modified in place.

    Returns:
        ``(axis_parameters, ranges)`` after canonicalization.
    """
    axis_out = np.asarray(axis_parameters, dtype=np.float32).copy()
    ranges_out = np.asarray(ranges, dtype=np.float32).copy()
    inverted_rows = ranges_out[:, 0] > ranges_out[:, 1]
    rows_to_flip = np.asarray(valid_mask, dtype=np.bool_) & inverted_rows
    axis_out[rows_to_flip] *= np.float32(-1.0)
    ranges_out[rows_to_flip] *= np.float32(-1.0)
    return axis_out, ranges_out
def _canonicalize_plain_motion_arrays(
    motion_arrays: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
    """Applies the ``low <= high`` sign convention to plain-decoder motion arrays.

    Every array is copied first, so the caller's dict and arrays are left
    untouched. The revolute and prismatic (plucker, range) pairs are then
    canonicalized using their ``*_parameter_valid`` masks.
    """
    result = {name: array.copy() for name, array in motion_arrays.items()}
    pair_keys = (
        ("revolute_plucker", "revolute_range", "revolute_parameter_valid"),
        ("prismatic_plucker", "prismatic_range", "prismatic_parameter_valid"),
    )
    for plucker_key, range_key, valid_key in pair_keys:
        result[plucker_key], result[range_key] = _canonicalize_axis_range_pairs(
            result[plucker_key],
            result[range_key],
            valid_mask=result[valid_key],
        )
    return result
def _save_query_predictions(
    output_path: Path,
    *,
    query_points: np.ndarray,
    face_indices: np.ndarray,
    part_ids: np.ndarray,
    center: np.ndarray,
    scale: float,
    joint_refit_sampling: dict[str, np.ndarray] | None,
    query_normals: np.ndarray | None = None,
    gt_part_ids: np.ndarray | None = None,
) -> None:
    """Stores first-pass query predictions (and optional refit debug data) in an ``.npz``.

    Query points, including the refit query points when present, are
    converted from the normalized frame back to mesh coordinates via
    ``denormalize_points``. Ids are stored as ``int32`` and normals as
    ``float32``.
    """
    payload: dict[str, np.ndarray] = {}
    payload["query_points"] = denormalize_points(query_points, center=center, scale=scale)
    payload["face_indices"] = np.asarray(face_indices, dtype=np.int32)
    payload["part_ids"] = np.asarray(part_ids, dtype=np.int32)

    if query_normals is not None:
        payload["query_normals"] = np.asarray(query_normals, dtype=np.float32)
    if gt_part_ids is not None:
        payload["gt_part_ids"] = np.asarray(gt_part_ids, dtype=np.int32)

    if joint_refit_sampling is not None:
        sampling = joint_refit_sampling
        payload["joint_refit_query_points"] = denormalize_points(
            sampling["query_points"], center=center, scale=scale
        )
        payload["joint_refit_query_normals"] = np.asarray(
            sampling["query_normals"], dtype=np.float32
        )
        payload["joint_refit_face_indices"] = np.asarray(
            sampling["face_indices"], dtype=np.int32
        )
        # Refit samples are keyed by link id; they are exported as part ids.
        payload["joint_refit_part_ids"] = np.asarray(
            sampling["link_ids"], dtype=np.int32
        )

    np.savez(output_path, **payload)
def resolve_visualized_batch_link_point_prompts(
    *,
    batch: dict[str, Any],
    links: list[dict[str, Any]],
    no_point_prompt: bool,
    center: np.ndarray,
    scale: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Selects the first batch item's link point prompts for display, in mesh coordinates.

    When the batch carries no ``link_point_prompts``, empty ``(0, 3)``
    float32 points and ``(0,)`` int64 ids are returned. Otherwise the prompts
    are denormalized and filtered by ``select_visualized_link_point_prompts``.
    The optional dropout-eligibility flags are forwarded to that filter.
    """
    prompts = batch.get("link_point_prompts")
    if prompts is None:
        no_points = np.zeros((0, 3), dtype=np.float32)
        no_ids = np.zeros((0,), dtype=np.int64)
        return no_points, no_ids

    prompts_in_mesh_space = denormalize_points(
        tensor_to_numpy(prompts[0], dtype=np.float32),
        center=center,
        scale=scale,
    )
    eligible_flags = batch.get("link_point_prompt_dropout_eligible")
    eligible_flags_np = None
    if eligible_flags is not None:
        eligible_flags_np = tensor_to_numpy(eligible_flags[0], dtype=np.bool_)

    return select_visualized_link_point_prompts(
        link_point_prompts=prompts_in_mesh_space,
        links=links,
        hide_unique_text_prompts=bool(no_point_prompt),
        link_point_prompt_dropout_eligible=eligible_flags_np,
    )
def build_base_metadata(
*,
mode: str,
input_path: Path,
run_dir: Path,
checkpoint_path: Path,
device: torch.device,
num_shape_points: int,
segmentation_num_query_points: int,
joint_refit_num_query_points: int,
num_query_points_per_face_for_seg: int | None,
query_batch_size: int,
no_point_prompt: bool,
enforce_connectivity_per_part: bool,
joint_decoding_confidence_temperature: float,
sharp_point_ratio: float,
) -> dict[str, Any]:
"""Builds the metadata fields common to mesh and meta-root inference."""
return {
"mode": mode,
"input_path": str(input_path),
"run_dir": str(run_dir),
"checkpoint_path": str(checkpoint_path),
"device": str(device),
"num_shape_points": int(num_shape_points),
"num_query_points": int(joint_refit_num_query_points),
"segmentation_num_query_points": int(segmentation_num_query_points),
"joint_refit_num_query_points": int(joint_refit_num_query_points),
"num_query_points_per_face_for_seg": (
None
if num_query_points_per_face_for_seg is None
else int(num_query_points_per_face_for_seg)
),
"query_batch_size": int(query_batch_size),
"no_point_prompt": bool(no_point_prompt),
"enforce_connectivity_per_part": bool(enforce_connectivity_per_part),
"joint_decoding_confidence_temperature": float(
joint_decoding_confidence_temperature
),
"sharp_point_ratio": float(sharp_point_ratio),
}
def write_metadata_and_summary(
    *,
    output_dir: Path,
    metadata: dict[str, Any],
    checkpoint_path: Path,
    unique_part_ids: np.ndarray,
    visualization_path: Path | None = None,
    overparam_visualization_path: Path | None = None,
    optional_paths: dict[str, Path | None] | None = None,
) -> None:
    """Saves ``metadata.json`` in ``output_dir`` and prints the end-of-run summary.

    Note: each non-``None`` entry of ``optional_paths`` is added to
    ``metadata`` as a string *in place*, so the caller's dict is mutated
    before it is written.
    """
    extra_paths = optional_paths if optional_paths is not None else {}
    metadata.update(
        {key: str(path) for key, path in extra_paths.items() if path is not None}
    )
    write_json(output_dir / "metadata.json", metadata)
    print_inference_summary(
        output_dir=output_dir,
        checkpoint_path=checkpoint_path,
        unique_part_ids=unique_part_ids,
        visualization_path=visualization_path,
        overparam_visualization_path=overparam_visualization_path,
    )
def prepare_mesh_geometry(
    *,
    input_path: Path,
    up_dir: str,
) -> PreparedMeshGeometry:
    """Loads a mesh and derives its model-space and Blender render-space variants.

    The mesh is reoriented to +Z up and then normalized to model space
    (``center``/``scale`` are kept for denormalization). A copy of the
    reoriented mesh, rotated by the export rotation, is used for rendering.
    """
    loaded_mesh = load_trimesh(input_path)
    z_up_mesh, up_rotation = reorient_mesh_to_z_up(loaded_mesh, up_dir)
    model_space_mesh, mesh_center, mesh_scale = normalize_mesh(z_up_mesh)

    # NOTE(review): both rotations are built from the same ("+Y", "+Z")
    # arguments, and `render_to_model_rotation` is the transpose of their
    # product. Confirm that this matches the intended export/import convention.
    export_rotation = up_dir_rotation_matrix("+Y", "+Z")
    import_rotation = up_dir_rotation_matrix("+Y", "+Z")

    blender_mesh = z_up_mesh.copy()
    homogeneous_export = np.eye(4, dtype=np.float32)
    homogeneous_export[:3, :3] = export_rotation
    blender_mesh.apply_transform(homogeneous_export)

    return PreparedMeshGeometry(
        original_mesh=z_up_mesh,
        normalized_mesh=model_space_mesh,
        center=np.asarray(mesh_center, dtype=np.float32),
        scale=float(mesh_scale),
        up_dir_rotation=np.asarray(up_rotation, dtype=np.float32),
        render_mesh=blender_mesh,
        render_to_model_rotation=(import_rotation @ export_rotation).T,
    )
def write_mesh_like_prediction_files(
    output_dir: Path,
    *,
    face_part_ids: np.ndarray,
    query_points: np.ndarray,
    query_face_indices: np.ndarray,
    point_part_ids: np.ndarray,
    center: np.ndarray,
    scale: float,
    joint_refit_sampling: dict[str, np.ndarray] | None,
    query_normals: np.ndarray | None = None,
    gt_part_ids: np.ndarray | None = None,
) -> None:
    """Persists the face segmentation and query-level predictions for a mesh input.

    Writes ``seg.npy`` (per-face part ids as ``int32``) and
    ``query_predictions.npz`` (via ``_save_query_predictions``) into
    ``output_dir``.
    """
    segmentation_path = output_dir / "seg.npy"
    np.save(segmentation_path, np.asarray(face_part_ids, dtype=np.int32))
    predictions_path = output_dir / "query_predictions.npz"
    _save_query_predictions(
        predictions_path,
        query_points=query_points,
        face_indices=query_face_indices,
        part_ids=point_part_ids,
        center=np.asarray(center, dtype=np.float32),
        scale=scale,
        joint_refit_sampling=joint_refit_sampling,
        query_normals=query_normals,
        gt_part_ids=gt_part_ids,
    )
def write_kinematic_and_overparam_visualization(
    output_dir: Path,
    *,
    kinematic_records: dict[str, Any],
    visualization_records: dict[str, Any] | None,
    motion_output: dict[str, Any],
    query_points: np.ndarray,
    point_part_ids: np.ndarray,
    joint_refit_sampling: dict[str, np.ndarray] | None,
    center: np.ndarray,
    scale: float,
) -> Path | None:
    """Saves ``kinematic.json`` and renders the overparameterized joint visualization.

    The visualization uses the joint-refit queries (and their link ids) when
    refit sampling is available. Otherwise it uses the first-pass queries and
    predicted part ids. ``visualization_records`` overrides
    ``kinematic_records`` for the links/joints shown.

    Returns:
        Whatever ``save_joint_overparam_visualization_from_model_output``
        returns.
    """
    write_json(output_dir / "kinematic.json", kinematic_records)
    displayed_records = (
        visualization_records
        if visualization_records is not None
        else kinematic_records
    )

    if joint_refit_sampling is not None:
        source_points = joint_refit_sampling["query_points"]
        source_link_ids = joint_refit_sampling["link_ids"]
    else:
        source_points = query_points
        source_link_ids = point_part_ids
    points_in_mesh_space = denormalize_points(
        source_points,
        center=np.asarray(center, dtype=np.float32),
        scale=scale,
    )
    link_ids_int32 = np.asarray(source_link_ids, dtype=np.int32)

    return save_joint_overparam_visualization_from_model_output(
        output_dir,
        output=motion_output,
        query_points=points_in_mesh_space,
        query_link_ids=link_ids_int32,
        links=displayed_records["links"],
        joints=displayed_records["joints"],
        center=center,
        scale=scale,
    )
def save_articulated_mesh_outputs(
    *,
    output_dir: Path,
    original_mesh: Any,
    face_part_ids: np.ndarray,
    motion_hierarchy: list[tuple[int, int]],
    is_part_revolute: np.ndarray,
    is_part_prismatic: np.ndarray,
    revolute_plucker: np.ndarray,
    revolute_range: np.ndarray,
    prismatic_axis: np.ndarray,
    prismatic_range: np.ndarray,
    animation_frames: int,
    export_urdf_enabled: bool,
    urdf_name: str,
    link_names: list[str],
    urdf_is_part_revolute: np.ndarray | None = None,
    urdf_is_part_prismatic: np.ndarray | None = None,
) -> np.ndarray:
    """Exports the segmented/animated mesh visualizations and, optionally, a URDF.

    When ``export_urdf_enabled`` is set, ``urdf/model.urdf`` is written under
    ``output_dir``. The ``urdf_is_part_*`` overrides, when given, replace the
    joint-type flags used for the URDF only.

    Returns:
        The unique part ids reported by ``save_segmented_visualizations``.
    """
    mesh_parts, unique_part_ids = save_segmented_visualizations(
        output_dir,
        original_mesh,
        face_part_ids,
        motion_hierarchy=motion_hierarchy,
        is_part_revolute=is_part_revolute,
        is_part_prismatic=is_part_prismatic,
        revolute_plucker=revolute_plucker,
        revolute_range=revolute_range,
        prismatic_axis=prismatic_axis,
        prismatic_range=prismatic_range,
        animation_frames=int(animation_frames),
    )
    if not export_urdf_enabled:
        return unique_part_ids

    urdf_revolute_flags = np.asarray(
        is_part_revolute if urdf_is_part_revolute is None else urdf_is_part_revolute,
        dtype=np.bool_,
    )
    urdf_prismatic_flags = np.asarray(
        is_part_prismatic if urdf_is_part_prismatic is None else urdf_is_part_prismatic,
        dtype=np.bool_,
    )
    export_urdf(
        mesh_parts,
        unique_part_ids,
        motion_hierarchy,
        urdf_revolute_flags,
        urdf_prismatic_flags,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range,
        output_path=str(output_dir / "urdf" / "model.urdf"),
        name=urdf_name,
        link_names=link_names,
    )
    return unique_part_ids