"""Helpers for saving inference visualizations.

Covers query-point plots, joint overparameterization plots, and the
segmented / animated GLB exports.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Mapping, Sequence

import numpy as np
import trimesh

from instruct_particulate.datasets.vis_helper import (
    plot_joint_overparam_visualizations,
    plot_query_visualization,
)
from instruct_particulate.utils.export_utils import export_animated_glb_file
from instruct_particulate.utils.visualization_utils import (
    COLORS,
    create_motion_axis_meshes,
    create_textured_mesh_parts,
)

# Default cap on how many query points are drawn in the query visualization.
_DEFAULT_MAX_QUERY_VISUALIZATION_POINTS = 8192

# Model-output keys holding the joint overparameterization predictions. Each
# key is also the name of the matching keyword argument of
# ``save_joint_overparam_visualization``.
_OVERPARAM_POINT_KEYS = (
    "revolute_closest_axis_points",
    "revolute_low_points",
    "revolute_high_points",
    "prismatic_closest_axis_points",
    "prismatic_low_points",
    "prismatic_high_points",
)


def _denormalize_visualization_points(
    points: np.ndarray | None,
    *,
    center: np.ndarray,
    scale: float,
) -> np.ndarray | None:
    """Map points back to the original frame as ``points / scale + center``.

    Args:
        points: Points to transform, or ``None``.
        center: Offset added after the division (broadcast by NumPy).
        scale: Divisor applied to ``points``.

    Returns:
        A ``float32`` array, or ``None`` when ``points`` is ``None``.
    """
    if points is None:
        return None
    return (
        np.asarray(points, dtype=np.float32) / np.float32(scale)
        + np.asarray(center, dtype=np.float32)
    ).astype(np.float32, copy=False)


def _as_per_query_array(
    name: str,
    values: Any,
    *,
    num_queries: int,
    dtype: np.dtype[Any] | type,
) -> np.ndarray:
    """Convert ``values`` to an array and check it has a row per query point.

    The arrays checked here are indexed along axis 0 with query-point indices.
    Without this check a too-short array fails with an ``IndexError`` that only
    shows up when the random sample happens to hit a missing row.

    Args:
        name: Argument name used in the error message.
        values: Array-like to convert.
        num_queries: Number of query points the array must cover.
        dtype: Target dtype of the returned array.

    Returns:
        ``values`` as an array of ``dtype``.

    Raises:
        ValueError: If the array is 0-d or has fewer than ``num_queries`` rows.
    """
    array = np.asarray(values, dtype=dtype)
    if array.ndim == 0 or array.shape[0] < num_queries:
        raise ValueError(
            f"{name} must provide one row per query point, "
            f"got shape {array.shape} for {num_queries} query points"
        )
    return array


def _collect_motion_overparam_visualization(
    *,
    motion_label: str,
    query_points: np.ndarray,
    query_link_ids: np.ndarray,
    query_confidences: np.ndarray | None,
    valid_link_ids: np.ndarray,
    closest_axis_points: np.ndarray | None,
    low_points: np.ndarray | None,
    high_points: np.ndarray | None,
    num_points: int,
) -> dict[str, Any] | None:
    """Build the plotting payload for one motion type.

    Query points whose link id is in ``valid_link_ids`` are selected. When
    there are more than ``num_points`` of them, ``num_points`` are drawn
    without replacement from NumPy's global RNG and kept in ascending index
    order.

    Args:
        motion_label: Label stored in the payload (e.g. ``"Revolute"``).
        query_points: Query point coordinates.
        query_link_ids: Link id of every query point.
        query_confidences: Optional per-point confidences; treated as all ones
            when ``None``.
        valid_link_ids: Link ids that have this motion type.
        closest_axis_points: Per-point predictions, or ``None``.
        low_points: Per-point predictions, or ``None``.
        high_points: Per-point predictions, or ``None``.
        num_points: Maximum number of points to keep.

    Returns:
        A dict consumed by ``plot_joint_overparam_visualizations``, or ``None``
        when any prediction array is missing, ``valid_link_ids`` is empty, or
        no query point is selected.

    Raises:
        ValueError: If a per-point array has fewer rows than there are query
            points.
    """
    if (
        closest_axis_points is None
        or low_points is None
        or high_points is None
        or len(valid_link_ids) == 0
    ):
        return None

    motion_valid_mask = np.isin(query_link_ids, valid_link_ids)
    valid_indices = np.flatnonzero(motion_valid_mask)
    if len(valid_indices) == 0:
        return None

    if len(valid_indices) > num_points:
        sampled_indices = np.random.choice(
            valid_indices,
            size=num_points,
            replace=False,
        ).astype(np.int64, copy=False)
        sampled_indices.sort()
    else:
        sampled_indices = valid_indices.astype(np.int64, copy=False)

    # Only reachable when num_points == 0; the public entry point rejects that.
    if len(sampled_indices) == 0:
        return None

    # Validate alignment once, up front, so a mismatch is reported clearly
    # instead of surfacing as an index error inside the fancy indexing below.
    num_queries = int(motion_valid_mask.size)
    all_query_points = _as_per_query_array(
        "query_points", query_points, num_queries=num_queries, dtype=np.float32
    )
    all_query_link_ids = _as_per_query_array(
        "query_link_ids", query_link_ids, num_queries=num_queries, dtype=np.int64
    )
    all_closest_axis_points = _as_per_query_array(
        "closest_axis_points",
        closest_axis_points,
        num_queries=num_queries,
        dtype=np.float32,
    )
    all_low_points = _as_per_query_array(
        "low_points", low_points, num_queries=num_queries, dtype=np.float32
    )
    all_high_points = _as_per_query_array(
        "high_points", high_points, num_queries=num_queries, dtype=np.float32
    )

    sampled_query_link_ids = all_query_link_ids[sampled_indices]

    if query_confidences is None:
        sampled_query_confidences = np.ones((len(sampled_indices),), dtype=np.float32)
    else:
        sampled_query_confidences = _as_per_query_array(
            "query_confidences",
            query_confidences,
            num_queries=num_queries,
            dtype=np.float32,
        )[sampled_indices]

    target_valid_mask = np.ones((len(sampled_indices),), dtype=np.bool_)

    # Points start light grey; points with a non-negative link id take the
    # palette colour of that link (wrapping around the palette length).
    palette = np.asarray(COLORS, dtype=np.float32) / 255.0
    base_point_colors = np.full((len(sampled_indices), 3), 0.7, dtype=np.float32)
    valid_link_mask = sampled_query_link_ids >= 0
    if np.any(valid_link_mask):
        base_point_colors[valid_link_mask] = palette[
            sampled_query_link_ids[valid_link_mask] % len(palette)
        ]

    # Blend towards white as confidence drops: 1 keeps the base colour,
    # 0 yields pure white.
    point_confidences = np.clip(sampled_query_confidences, 0.0, 1.0).reshape(-1, 1)
    point_colors = 1.0 - point_confidences * (1.0 - base_point_colors)

    return {
        "motion_label": motion_label,
        "query_points": all_query_points[sampled_indices],
        "query_link_ids": sampled_query_link_ids,
        "point_colors": point_colors.astype(np.float32, copy=False),
        "closest_axis_points": all_closest_axis_points[sampled_indices],
        "closest_axis_points_valid": target_valid_mask.copy(),
        "low_points": all_low_points[sampled_indices],
        "low_points_valid": target_valid_mask.copy(),
        "high_points": all_high_points[sampled_indices],
        "high_points_valid": target_valid_mask.copy(),
        "context_query_points": all_query_points[~motion_valid_mask],
    }


def _output_array(
    output: Mapping[str, Any],
    key: str,
    *,
    dtype: np.dtype[Any] | type = np.float32,
) -> np.ndarray | None:
    """Fetch ``output[key]`` as a NumPy array of ``dtype``.

    Values exposing ``detach`` are detached, moved to CPU and converted with
    ``.numpy()`` first. A leading axis of length 1 is then dropped.

    NOTE(review): the dropped axis is presumably a batch axis; the same rule
    also fires for an unbatched value that genuinely has a single row.

    Returns:
        The converted array, or ``None`` when the key is absent or maps to
        ``None``.
    """
    value = output.get(key)
    if value is None:
        return None
    if hasattr(value, "detach"):
        value = value.detach().cpu().numpy()
    value_array = np.asarray(value)
    if value_array.ndim > 0 and value_array.shape[0] == 1:
        value_array = value_array[0]
    return np.asarray(value_array, dtype=dtype)


def print_inference_summary(
    *,
    output_dir: Path,
    checkpoint_path: Path,
    unique_part_ids: np.ndarray,
    visualization_path: Path | None = None,
    overparam_visualization_path: Path | None = None,
) -> None:
    """Print where the inference artifacts were written.

    The two visualization lines are printed only when their path is given.
    """
    print(f"Saved inference outputs to {output_dir}")
    print(f"Checkpoint: {checkpoint_path}")
    if visualization_path is not None:
        print(f"Query visualization: {visualization_path}")
    if overparam_visualization_path is not None:
        print(f"Overparam visualization: {overparam_visualization_path}")
    print(f"Segmented parts present: {unique_part_ids.tolist()}")


def save_segmented_visualizations(
    output_dir: Path,
    mesh,
    face_part_ids: np.ndarray,
    *,
    motion_hierarchy: Sequence[tuple[int, int]],
    is_part_revolute: np.ndarray,
    is_part_prismatic: np.ndarray,
    revolute_plucker: np.ndarray,
    revolute_range: np.ndarray,
    prismatic_axis: np.ndarray,
    prismatic_range: np.ndarray,
    animation_frames: int,
) -> tuple[list[trimesh.Trimesh], np.ndarray]:
    """Split ``mesh`` by part id and export the static and animated GLB files.

    Writes ``mesh_parts_with_axes.glb`` (textured parts plus motion-axis
    meshes) and ``animated_textured.glb`` into ``output_dir``.

    Args:
        output_dir: Directory the two files are written to.
        mesh: Mesh providing ``submesh`` (presumably a ``trimesh.Trimesh``;
            confirm against callers).
        face_part_ids: Part id of every face of ``mesh``.
        motion_hierarchy: Passed to the animated export as a list.
        is_part_revolute: Forwarded to the axis and animation helpers.
        is_part_prismatic: Forwarded to the axis and animation helpers.
        revolute_plucker: Forwarded to the axis and animation helpers.
        revolute_range: Forwarded to the animation helper.
        prismatic_axis: Forwarded to the axis and animation helpers.
        prismatic_range: Forwarded to the animation helper.
        animation_frames: Forwarded to the animation helper.

    Returns:
        The per-part sub-meshes and the sorted unique part ids (``int32``).
    """
    unique_part_ids = np.unique(face_part_ids).astype(np.int32, copy=False)
    mesh_parts_original = [
        mesh.submesh([face_part_ids == part_id], append=True)
        for part_id in unique_part_ids
    ]
    # Texture copies so the untouched originals can be reused for the
    # axis meshes and the animated export.
    mesh_parts_segmented = create_textured_mesh_parts(
        [mesh_part.copy() for mesh_part in mesh_parts_original],
        part_ids=unique_part_ids,
    )
    axes = create_motion_axis_meshes(
        mesh_parts_original,
        unique_part_ids,
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        prismatic_axis,
    )
    trimesh.Scene(mesh_parts_segmented + axes).export(
        output_dir / "mesh_parts_with_axes.glb"
    )
    export_animated_glb_file(
        mesh_parts_original,
        unique_part_ids,
        list(motion_hierarchy),
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range,
        animation_frames,
        str(output_dir / "animated_textured.glb"),
        include_axes=False,
        axes_meshes=None,
    )
    return mesh_parts_original, unique_part_ids


def _resolve_link_point_prompt_inputs(
    link_point_prompts: np.ndarray | None,
    link_point_prompt_ids: np.ndarray | None,
) -> tuple[np.ndarray, np.ndarray]:
    """Normalize optional prompt points and ids into aligned arrays.

    Missing prompts become an empty ``(0, 3)`` array; missing ids default to
    ``0..N-1``.

    Returns:
        ``(prompts, ids)`` as ``float32`` of shape ``(N, 3)`` and ``int64`` of
        shape ``(N,)``.

    Raises:
        ValueError: If the prompts are not ``(N, 3)`` or the ids do not have
            shape ``(N,)``.
    """
    if link_point_prompts is None:
        prompt_array = np.zeros((0, 3), dtype=np.float32)
    else:
        prompt_array = np.asarray(link_point_prompts, dtype=np.float32)
    if prompt_array.ndim != 2 or prompt_array.shape[-1] != 3:
        raise ValueError(
            f"Expected link_point_prompts to have shape (N, 3), got {prompt_array.shape}"
        )

    if link_point_prompt_ids is None:
        prompt_ids = np.arange(len(prompt_array), dtype=np.int64)
    else:
        prompt_ids = np.asarray(link_point_prompt_ids, dtype=np.int64)
    if prompt_ids.shape != (len(prompt_array),):
        raise ValueError(
            "link_point_prompt_ids must align with link_point_prompts, "
            f"got {prompt_ids.shape} and {prompt_array.shape}"
        )
    return prompt_array, prompt_ids


def _sample_query_visualization_inputs(
    *,
    query_points: np.ndarray,
    query_normals: np.ndarray,
    query_link_ids: np.ndarray,
    max_points: int = _DEFAULT_MAX_QUERY_VISUALIZATION_POINTS,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Randomly subsample query points, normals and link ids together.

    At most ``max_points`` points are drawn without replacement from NumPy's
    global RNG. The returned rows are in the (shuffled) order of the drawn
    indices.

    Args:
        query_points: Query point coordinates.
        query_normals: One normal per query point.
        query_link_ids: One link id per query point.
        max_points: Upper bound on the number of sampled points.

    Returns:
        ``(points, normals, link_ids, indices)`` where ``indices`` are the
        ``int64`` positions of the sampled rows in the inputs.

    Raises:
        ValueError: If there are no query points, ``max_points`` is not
            positive, or normals / link ids have fewer rows than there are
            query points.
    """
    if len(query_points) == 0:
        raise ValueError("Point visualization requires at least one query point")
    if max_points <= 0:
        raise ValueError(f"max_points must be positive, got {max_points}")

    query_points = np.asarray(query_points, dtype=np.float32)
    num_queries = len(query_points)
    query_normals = _as_per_query_array(
        "query_normals", query_normals, num_queries=num_queries, dtype=np.float32
    )
    query_link_ids = _as_per_query_array(
        "query_link_ids", query_link_ids, num_queries=num_queries, dtype=np.int64
    )

    sample_size = min(max_points, num_queries)
    sample_indices = np.random.choice(num_queries, size=sample_size, replace=False)
    sample_indices = sample_indices.astype(np.int64, copy=False)
    return (
        query_points[sample_indices],
        query_normals[sample_indices],
        query_link_ids[sample_indices],
        sample_indices,
    )


def save_predicted_point_query_rest_visualization(
    output_dir: Path,
    *,
    query_points: np.ndarray,
    query_normals: np.ndarray,
    predicted_part_ids: np.ndarray,
    link_point_prompts: np.ndarray | None,
    link_point_prompt_ids: np.ndarray | None = None,
    links: Sequence[dict[str, Any]],
    predicted_link_confidences: Mapping[int, float] | None = None,
    max_points: int = _DEFAULT_MAX_QUERY_VISUALIZATION_POINTS,
) -> Path:
    """Plot a random subset of query points coloured by predicted part id.

    Args:
        output_dir: Directory the image path is built in.
        query_points: Query point coordinates.
        query_normals: One normal per query point.
        predicted_part_ids: Predicted part id per query point; passed to the
            plot as the link ids.
        link_point_prompts: Optional ``(N, 3)`` prompt points.
        link_point_prompt_ids: Optional ids for the prompts; default
            ``0..N-1``.
        links: Link descriptions forwarded to the plot as a list.
        predicted_link_confidences: Forwarded to the plot unchanged.
        max_points: Maximum number of query points to draw.

    Returns:
        ``output_dir / "query_visualization.png"``, the path handed to the
        plotting helper.

    Raises:
        ValueError: On empty or misaligned inputs (see the helpers used).
    """
    link_point_prompts, link_point_prompt_ids = _resolve_link_point_prompt_inputs(
        link_point_prompts,
        link_point_prompt_ids,
    )
    (
        visualization_query_points,
        visualization_query_normals,
        visualization_predicted_link_ids,
        _,
    ) = _sample_query_visualization_inputs(
        query_points=query_points,
        query_normals=query_normals,
        query_link_ids=predicted_part_ids,
        max_points=max_points,
    )
    out_path = output_dir / "query_visualization.png"
    plot_query_visualization(
        query_points=visualization_query_points,
        query_normals=visualization_query_normals,
        query_link_ids=visualization_predicted_link_ids,
        link_point_prompts=link_point_prompts,
        link_point_prompt_ids=link_point_prompt_ids,
        links=list(links),
        predicted_link_confidences=predicted_link_confidences,
        out_path=out_path,
        normal_limit=0,
    )
    return out_path


def select_visualized_link_point_prompts(
    *,
    link_point_prompts: np.ndarray,
    links: Sequence[dict[str, Any]],
    hide_unique_text_prompts: bool = False,
    link_point_prompt_dropout_eligible: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Pick which link point prompts are shown, together with their ids.

    Ids are the positions ``0..len(links)-1``. With
    ``hide_unique_text_prompts`` set, prompts flagged in
    ``link_point_prompt_dropout_eligible`` are removed.

    Returns:
        ``(prompts, ids)`` for the prompts that remain visible.

    Raises:
        ValueError: If the prompts are not ``(N, 3)``, do not match
            ``len(links)``, or hiding is requested without a correctly shaped
            eligibility mask.
    """
    link_point_prompts = np.asarray(link_point_prompts, dtype=np.float32)
    if link_point_prompts.ndim != 2 or link_point_prompts.shape[-1] != 3:
        raise ValueError(
            f"Expected link_point_prompts to have shape (N, 3), got {link_point_prompts.shape}"
        )
    if len(link_point_prompts) != len(links):
        raise ValueError(
            "link_point_prompts must align with links, "
            f"got {len(link_point_prompts)} prompts for {len(links)} links"
        )

    link_point_prompt_ids = np.arange(len(links), dtype=np.int64)
    if not hide_unique_text_prompts:
        return link_point_prompts, link_point_prompt_ids

    if link_point_prompt_dropout_eligible is None:
        raise ValueError(
            "hide_unique_text_prompts=True requires link_point_prompt_dropout_eligible"
        )
    link_point_prompt_dropout_eligible = np.asarray(
        link_point_prompt_dropout_eligible,
        dtype=np.bool_,
    )
    if link_point_prompt_dropout_eligible.shape != (len(links),):
        raise ValueError(
            "link_point_prompt_dropout_eligible must align with links, "
            f"got {link_point_prompt_dropout_eligible.shape} for {len(links)} links"
        )

    visible_prompt_mask = ~link_point_prompt_dropout_eligible
    return (
        link_point_prompts[visible_prompt_mask],
        link_point_prompt_ids[visible_prompt_mask],
    )


def save_joint_overparam_visualization(
    output_dir: Path,
    *,
    query_points: np.ndarray,
    query_link_ids: np.ndarray,
    query_confidences: np.ndarray | None,
    links: Sequence[dict[str, Any]],
    joints: Sequence[dict[str, Any]],
    revolute_closest_axis_points: np.ndarray | None,
    revolute_low_points: np.ndarray | None,
    revolute_high_points: np.ndarray | None,
    prismatic_closest_axis_points: np.ndarray | None,
    prismatic_low_points: np.ndarray | None,
    prismatic_high_points: np.ndarray | None,
    num_points: int = 256,
) -> Path | None:
    """Plot the joint overparameterization predictions per motion type.

    A panel payload is built for revolute and for prismatic joints, using the
    ``child_link_id`` of every joint flagged ``is_revolute`` /
    ``is_prismatic``. Motion types with nothing to show are skipped.

    Args:
        output_dir: Directory the image path is built in.
        query_points: ``(N, 3)`` query point coordinates.
        query_link_ids: Link id per query point.
        query_confidences: Optional per-point confidences.
        links: Link descriptions forwarded to the plot as a list.
        joints: Joint dicts providing ``child_link_id`` and the optional
            ``is_revolute`` / ``is_prismatic`` flags.
        revolute_closest_axis_points: Per-point predictions, or ``None``.
        revolute_low_points: Per-point predictions, or ``None``.
        revolute_high_points: Per-point predictions, or ``None``.
        prismatic_closest_axis_points: Per-point predictions, or ``None``.
        prismatic_low_points: Per-point predictions, or ``None``.
        prismatic_high_points: Per-point predictions, or ``None``.
        num_points: Maximum number of points drawn per motion type.

    Returns:
        ``output_dir / "overparam.png"``, or ``None`` when no motion type
        produced a payload (nothing is plotted in that case).

    Raises:
        ValueError: If ``num_points`` is not positive, or the point and link
            id arrays are misaligned.
    """
    if num_points <= 0:
        raise ValueError(f"num_points must be positive, got {num_points}")

    query_points = np.asarray(query_points, dtype=np.float32)
    query_link_ids = np.asarray(query_link_ids, dtype=np.int64)
    if query_points.shape != (len(query_link_ids), 3):
        raise ValueError(
            "query_points and query_link_ids must align, "
            f"got {query_points.shape} and {query_link_ids.shape}"
        )

    # (label, joint flag, closest-axis points, low points, high points)
    motion_specs = (
        (
            "Revolute",
            "is_revolute",
            revolute_closest_axis_points,
            revolute_low_points,
            revolute_high_points,
        ),
        (
            "Prismatic",
            "is_prismatic",
            prismatic_closest_axis_points,
            prismatic_low_points,
            prismatic_high_points,
        ),
    )

    visualizations: list[dict[str, Any]] = []
    for motion_label, joint_flag, closest_axis_points, low_points, high_points in motion_specs:
        child_link_ids = np.asarray(
            [
                int(joint["child_link_id"])
                for joint in joints
                if bool(joint.get(joint_flag, False))
            ],
            dtype=np.int64,
        )
        visualization = _collect_motion_overparam_visualization(
            motion_label=motion_label,
            query_points=query_points,
            query_link_ids=query_link_ids,
            query_confidences=query_confidences,
            valid_link_ids=child_link_ids,
            closest_axis_points=closest_axis_points,
            low_points=low_points,
            high_points=high_points,
            num_points=num_points,
        )
        if visualization is not None:
            visualizations.append(visualization)

    if not visualizations:
        return None

    out_path = output_dir / "overparam.png"
    plot_joint_overparam_visualizations(
        visualizations=visualizations,
        links=list(links),
        out_path=out_path,
    )
    return out_path


def save_joint_overparam_visualization_from_model_output(
    output_dir: Path,
    *,
    output: Mapping[str, Any],
    query_points: np.ndarray,
    query_link_ids: np.ndarray,
    links: Sequence[dict[str, Any]],
    joints: Sequence[dict[str, Any]],
    center: np.ndarray | None = None,
    scale: float | None = None,
    num_points: int = 256,
) -> Path | None:
    """Plot the overparameterization predictions found in a model output.

    The six prediction arrays and ``joint_decoding_confidences`` are read from
    ``output``. When ``center`` and ``scale`` are given, the prediction arrays
    are mapped back with ``points / scale + center`` before plotting.

    Args:
        output_dir: Directory the image path is built in.
        output: Mapping holding the model predictions.
        query_points: ``(N, 3)`` query point coordinates.
        query_link_ids: Link id per query point.
        links: Link descriptions forwarded to the plot.
        joints: Joint dicts (see ``save_joint_overparam_visualization``).
        center: Optional denormalization offset.
        scale: Optional denormalization divisor.
        num_points: Maximum number of points drawn per motion type.

    Returns:
        The image path, or ``None`` when there is nothing to plot.

    Raises:
        ValueError: If only one of ``center`` / ``scale`` is given, or the
            inputs fail the checks of ``save_joint_overparam_visualization``.
    """
    if (center is None) != (scale is None):
        raise ValueError("center and scale must be provided together for denormalization")

    predictions: dict[str, np.ndarray | None] = {}
    for key in _OVERPARAM_POINT_KEYS:
        points = _output_array(output, key)
        if center is not None and scale is not None:
            points = _denormalize_visualization_points(
                points,
                center=center,
                scale=scale,
            )
        predictions[key] = points

    return save_joint_overparam_visualization(
        output_dir,
        query_points=query_points,
        query_link_ids=query_link_ids,
        query_confidences=_output_array(
            output,
            "joint_decoding_confidences",
            dtype=np.float32,
        ),
        links=links,
        joints=joints,
        revolute_closest_axis_points=predictions["revolute_closest_axis_points"],
        revolute_low_points=predictions["revolute_low_points"],
        revolute_high_points=predictions["revolute_high_points"],
        prismatic_closest_axis_points=predictions["prismatic_closest_axis_points"],
        prismatic_low_points=predictions["prismatic_low_points"],
        prismatic_high_points=predictions["prismatic_high_points"],
        num_points=num_points,
    )