Spaces:
Running on Zero
Running on Zero
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Mapping, Optional, Sequence, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import yaml | |
| from instruct_particulate.utils.articulation_utils import ( | |
| plucker_to_axis_point, | |
| transform_plucker, | |
| ) | |
| from instruct_particulate.utils.data_utils import ( | |
| sample_points as sample_mesh_points, | |
| sample_points_per_face as sample_mesh_points_per_face, | |
| ) | |
def write_json(path: Path, payload: Any) -> None:
    """Write ``payload`` to ``path`` as UTF-8 JSON with two-space indentation.

    The payload is serialized to a string before the file is opened, so a
    payload that ``json`` cannot encode raises without truncating or partially
    overwriting a file that already exists at ``path``.

    Args:
        path: Destination file path.
        payload: Object to serialize with ``json.dumps``.

    Raises:
        TypeError: If ``payload`` contains a value that is not JSON serializable.
        ValueError: If ``payload`` contains a circular reference.
        OSError: If ``path`` cannot be opened for writing.
    """
    # Serialize first: opening in "w" mode truncates immediately, so doing the
    # encoding up front keeps a failed call from destroying previous contents.
    serialized = json.dumps(payload, indent=2)
    with path.open("w", encoding="utf-8") as fh:
        fh.write(serialized)
def axis_point_to_plucker_torch(axis: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
    """Build a 6-vector line representation from a direction and a point.

    The direction is normalized to unit length along the last dimension. The
    moment is computed as ``unit_direction x point`` (in that operand order),
    and the result concatenates ``(unit_direction, moment)`` on the last axis.
    """
    unit_direction = F.normalize(axis, dim=-1, eps=1e-8)
    line_moment = torch.linalg.cross(unit_direction, point, dim=-1)
    components = [unit_direction, line_moment]
    return torch.cat(components, dim=-1)
def fit_axis_to_closest_points_torch(
    query_points: torch.Tensor,
    closest_axis_points: torch.Tensor,
    *,
    direction_hint: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fits an axis from closest-point targets using a closed-form projection objective.

    The fitted line minimizes

        sum_i ||c_i - proj_L(q_i)||^2

    where `q_i` are `query_points`, `c_i` are `closest_axis_points`, and `L`
    is the recovered axis. This enforces both that the predicted closest points
    lie on the axis and that `(q_i - c_i)` is perpendicular to the axis.

    Args:
        query_points: Tensor of shape `(N, 3)`.
        closest_axis_points: Tensor of shape `(N, 3)`, one target per query.
        direction_hint: Optional 3-vector. When its norm exceeds `1e-8` it is
            used to pick a direction inside a (near-)degenerate smallest
            eigenspace and to fix the sign of the returned direction.

    Returns:
        `(axis_direction, axis_point)`, where `axis_point` is the point on the
        recovered line closest to the origin. Both are returned in the input
        dtype when that is float32/float64 and in float32 otherwise. For empty
        input (`N == 0`) two independent zero vectors of that dtype are
        returned.

    Raises:
        ValueError: If the two inputs differ in shape or are not `(N, 3)`.
    """
    if query_points.shape != closest_axis_points.shape:
        raise ValueError(
            "query_points and closest_axis_points must share the same shape, "
            f"got {tuple(query_points.shape)} and {tuple(closest_axis_points.shape)}"
        )
    if query_points.ndim != 2 or query_points.shape[-1] != 3:
        raise ValueError(
            f"Expected query_points and closest_axis_points to have shape (N, 3), got {tuple(query_points.shape)}"
        )

    result_dtype = (
        query_points.dtype
        if query_points.dtype in (torch.float32, torch.float64)
        else torch.float32
    )
    if query_points.numel() == 0:
        # Return two separate tensors (no aliasing between the outputs) in the
        # same dtype the non-empty path would produce.
        return (
            query_points.new_zeros(3, dtype=result_dtype),
            query_points.new_zeros(3, dtype=result_dtype),
        )

    # Solve in float64 except on MPS, where float32 is used instead
    # (presumably because MPS lacks float64 support — TODO confirm).
    solve_dtype = torch.float32 if query_points.device.type == "mps" else torch.float64
    query_points = query_points.to(dtype=solve_dtype)
    closest_axis_points = closest_axis_points.to(dtype=solve_dtype)
    if direction_hint is None:
        direction_hint = query_points.new_zeros(3)
    else:
        direction_hint = direction_hint.to(
            device=query_points.device,
            dtype=solve_dtype,
        )

    closest_axis_mean = closest_axis_points.mean(dim=0)
    centered_closest_axis_points = closest_axis_points - closest_axis_mean.unsqueeze(0)
    query_residuals = query_points - closest_axis_points
    # Symmetric 3x3 matrix whose quadratic form is the part of the objective
    # that depends on the direction; its smallest eigenvector is the minimizer.
    objective_matrix = (
        query_residuals.transpose(0, 1) @ query_residuals
        - centered_closest_axis_points.transpose(0, 1) @ centered_closest_axis_points
    )

    direction_hint_norm = torch.linalg.norm(direction_hint)
    normalized_direction_hint = direction_hint / direction_hint_norm.clamp_min(1e-12)

    # eigh returns eigenvalues in ascending order, so column 0 is the minimizer.
    eigenvalues, eigenvectors = torch.linalg.eigh(objective_matrix)
    axis_direction = eigenvectors[:, 0]
    if direction_hint_norm > 1e-8:
        # When several eigenvalues tie (within a scale-relative tolerance) the
        # minimizer is not unique; project the hint onto that eigenspace.
        eigenspace_scale = torch.maximum(
            eigenvalues.abs().max(),
            objective_matrix.abs().max(),
        ).clamp_min(1.0)
        eigenspace_tol = eigenspace_scale * (
            1e-8 if solve_dtype == torch.float64 else 1e-5
        )
        smallest_eigenspace = eigenvectors[
            :,
            eigenvalues <= (eigenvalues[0] + eigenspace_tol),
        ]
        projected_hint = smallest_eigenspace @ (
            smallest_eigenspace.transpose(0, 1) @ normalized_direction_hint
        )
        if torch.linalg.norm(projected_hint) > 1e-8:
            axis_direction = projected_hint

    # Eigenvector sign is arbitrary; align it with the hint when one is given.
    if direction_hint_norm > 1e-8 and torch.dot(axis_direction, normalized_direction_hint) < 0:
        axis_direction = -axis_direction
    axis_direction = F.normalize(axis_direction, dim=-1, eps=1e-8)

    fallback_direction = torch.where(
        direction_hint_norm > 1e-8,
        normalized_direction_hint,
        query_points.new_tensor([1.0, 0.0, 0.0]),
    )
    if not torch.isfinite(axis_direction).all() or torch.linalg.norm(axis_direction) <= 1e-8:
        axis_direction = fallback_direction

    # Remove the component of the mean along the axis to get the point on the
    # line nearest the origin.
    axis_point = closest_axis_mean - axis_direction * torch.dot(axis_direction, closest_axis_mean)
    if not torch.isfinite(axis_point).all():
        axis_point = closest_axis_mean
    return axis_direction.to(dtype=result_dtype), axis_point.to(dtype=result_dtype)
def estimate_prismatic_limit_torch(
    current_points: torch.Tensor,
    target_points: torch.Tensor,
    axis_direction: torch.Tensor,
) -> torch.Tensor:
    """Estimate one translation distance along an axis shared by all points.

    Each per-point offset ``target - current`` is projected onto the
    unit-normalized ``axis_direction`` and the projections are averaged.
    Returns a zero scalar tensor when ``current_points`` has no elements.
    """
    if current_points.numel() == 0:
        return current_points.new_zeros(())

    unit_axis = F.normalize(axis_direction, dim=-1, eps=1e-8)
    displacement = target_points - current_points
    signed_travel = (displacement * unit_axis[None]).sum(dim=-1)
    return signed_travel.mean()
def estimate_revolute_limit_torch(
    current_points: torch.Tensor,
    target_points: torch.Tensor,
    axis_direction: torch.Tensor,
    axis_point: torch.Tensor,
) -> torch.Tensor:
    """Estimate one rotation angle about an axis shared by all points.

    Both point sets are expressed relative to ``axis_point`` and stripped of
    their component along the unit axis. The summed dot products (cosine term)
    and summed axis-aligned cross products (sine term) of the resulting
    perpendicular vectors are combined with ``atan2``. Returns a zero scalar
    tensor when ``current_points`` has no elements.
    """
    if current_points.numel() == 0:
        return current_points.new_zeros(())

    unit_axis = F.normalize(axis_direction, dim=-1, eps=1e-8)
    axis_row = unit_axis.unsqueeze(0)
    origin_row = axis_point.unsqueeze(0)

    # Perpendicular (radial) part of each cloud with respect to the axis.
    radial_parts = []
    for cloud in (current_points, target_points):
        relative = cloud - origin_row
        along_axis = (relative * axis_row).sum(dim=-1, keepdim=True)
        radial_parts.append(relative - along_axis * axis_row)
    current_radial, target_radial = radial_parts

    cos_sum = (current_radial * target_radial).sum(dim=-1).sum()
    twist = torch.linalg.cross(current_radial, target_radial, dim=-1)
    sin_sum = (twist * axis_row).sum(dim=-1).sum()
    return torch.atan2(sin_sum, cos_sum)
def strip_text_model_state(state_dict: Mapping[str, Any]) -> dict[str, Any]:
    """Return a new dict without entries whose key starts with ``encoder.text_model.``.

    The input mapping is not modified; remaining entries keep their order.
    """
    kept: dict[str, Any] = {}
    for name, value in state_dict.items():
        if name.startswith("encoder.text_model."):
            continue
        kept[name] = value
    return kept
# Model state-dict keys that are allowed to be absent from a checkpoint when
# the missing-key check is applied while loading a model for inference.
_ALLOWED_MISSING_CHECKPOINT_KEYS = frozenset(
    ("encoder.no_text_conditioning_embedding",)
)
def load_run_config(run_dir: Path) -> dict[str, Any]:
    """Load the saved YAML configuration of a run directory.

    ``config.resolved.yaml`` is preferred; ``config.source.yaml`` is used when
    the resolved file does not exist. The file is parsed with
    ``yaml.safe_load``.

    Raises:
        FileNotFoundError: If neither file exists under ``run_dir``.
        ValueError: If the parsed document is not a dict.
    """
    for candidate_name in ("config.resolved.yaml", "config.source.yaml"):
        config_path = run_dir / candidate_name
        if config_path.exists():
            break
    else:
        raise FileNotFoundError(
            f"Could not find config.resolved.yaml or config.source.yaml under {run_dir}"
        )

    with config_path.open("r", encoding="utf-8") as fh:
        loaded = yaml.safe_load(fh)
    if isinstance(loaded, dict):
        return loaded
    raise ValueError(f"Expected a mapping in {config_path}, got {type(loaded).__name__}")
def configure_runtime_environment(config: Mapping[str, Any]) -> None:
    """Point the Hugging Face cache environment variables at a usable directory.

    Reads ``config["runtime"]["hf_cache_dir"]`` (default
    ``.cache/huggingface``; a ``None`` value, e.g. from an empty YAML key, is
    treated as unset). The configured directory is tried first, then
    ``<cwd>/.cache/huggingface``; the first one whose ``hub`` subdirectory can
    be created wins. On success ``HF_HOME`` and ``HF_HUB_CACHE`` are
    overwritten, and ``TOKENIZERS_PARALLELISM`` is set to ``"false"`` only if
    it is not already defined.

    Does nothing when ``config["runtime"]`` is present but not a mapping.

    Raises:
        OSError: If the ``hub`` directory cannot be created under any candidate.
    """
    runtime_config = config.get("runtime", {})
    if not isinstance(runtime_config, Mapping):
        return

    configured_value = runtime_config.get("hf_cache_dir")
    if configured_value is None:
        # Path(None) raises TypeError; fall back to the documented default.
        configured_value = ".cache/huggingface"
    configured_hf_cache_dir = Path(configured_value).expanduser()
    fallback_hf_cache_dir = (Path.cwd() / ".cache" / "huggingface").resolve()
    candidate_hf_cache_dirs = [
        configured_hf_cache_dir.resolve(),
        fallback_hf_cache_dir,
    ]

    hf_cache_dir = None
    hf_hub_cache = None
    for candidate_hf_cache_dir in candidate_hf_cache_dirs:
        candidate_hf_hub_cache = candidate_hf_cache_dir / "hub"
        try:
            candidate_hf_hub_cache.mkdir(parents=True, exist_ok=True)
        except OSError:
            # Best effort: move on to the next candidate location.
            continue
        hf_cache_dir = candidate_hf_cache_dir
        hf_hub_cache = candidate_hf_hub_cache
        break

    if hf_cache_dir is None or hf_hub_cache is None:
        raise OSError(
            "Could not create a writable Hugging Face cache directory. "
            f"Tried: {candidate_hf_cache_dirs}"
        )

    os.environ["HF_HOME"] = str(hf_cache_dir)
    os.environ["HF_HUB_CACHE"] = str(hf_hub_cache)
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
def load_model_checkpoint_for_inference(
    model: torch.nn.Module,
    checkpoint_path: Path,
    *,
    device: torch.device,
) -> dict[str, Any]:
    """Load checkpoint weights into ``model`` and switch it to eval mode.

    The checkpoint is loaded onto the CPU, its ``"model"`` state dict is
    filtered through ``strip_text_model_state`` and applied non-strictly.
    Missing keys are tolerated only under ``encoder.text_model.`` or when
    listed in ``_ALLOWED_MISSING_CHECKPOINT_KEYS``; any other missing key, or
    any unexpected key, is an error. The model is then moved to ``device``.

    Returns:
        The full checkpoint object returned by ``torch.load``.

    Raises:
        RuntimeError: On a checkpoint/model key mismatch.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only use this with checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    filtered_state = strip_text_model_state(checkpoint["model"])
    missing_keys, unexpected_keys = model.load_state_dict(filtered_state, strict=False)

    non_text_missing_keys = []
    for key in missing_keys:
        if key.startswith("encoder.text_model."):
            continue
        if key in _ALLOWED_MISSING_CHECKPOINT_KEYS:
            continue
        non_text_missing_keys.append(key)

    if non_text_missing_keys or unexpected_keys:
        raise RuntimeError(
            "Checkpoint/model state mismatch. "
            f"Missing keys: {non_text_missing_keys}; unexpected keys: {unexpected_keys}"
        )

    model.to(device)
    model.eval()
    return checkpoint
| def resolve_inference_sampling_config( | |
| config: Mapping[str, Any], | |
| ) -> tuple[int, int | None, float]: | |
| datasets_config = config.get("datasets", {}) | |
| if not isinstance(datasets_config, Mapping): | |
| raise ValueError("Config is missing a datasets mapping") | |
| shared_config = datasets_config.get("shared") | |
| if shared_config is not None and not isinstance(shared_config, Mapping): | |
| raise ValueError( | |
| "datasets.shared must be a mapping when provided, " | |
| f"got {type(shared_config).__name__}" | |
| ) | |
| dataset_entries: list[Mapping[str, Any]] = [] | |
| for split_name in ("train", "val"): | |
| split_entries = datasets_config.get(split_name, []) | |
| if isinstance(split_entries, Sequence) and not isinstance(split_entries, (str, bytes)): | |
| dataset_entries.extend( | |
| entry for entry in split_entries if isinstance(entry, Mapping) | |
| ) | |
| if not dataset_entries: | |
| raise ValueError("Config does not contain any dataset entries under datasets.train or datasets.val") | |
| if shared_config is not None: | |
| missing_keys = [ | |
| key for key in ("num_shape_points", "num_query_points") if key not in shared_config | |
| ] | |
| if missing_keys: | |
| raise KeyError( | |
| "datasets.shared must define num_shape_points and num_query_points, " | |
| f"missing {missing_keys}" | |
| ) | |
| num_shape_points = int(shared_config["num_shape_points"]) | |
| default_num_query_points = int(shared_config["num_query_points"]) | |
| else: | |
| num_shape_points_values = { | |
| int(entry["num_shape_points"]) | |
| for entry in dataset_entries | |
| if "num_shape_points" in entry | |
| } | |
| if len(num_shape_points_values) != 1: | |
| raise ValueError( | |
| "Inference requires a single num_shape_points value across saved dataset configs, " | |
| f"got {sorted(num_shape_points_values)}" | |
| ) | |
| num_query_points_values = { | |
| int(entry["num_query_points"]) | |
| for entry in dataset_entries | |
| if "num_query_points" in entry | |
| } | |
| default_num_query_points = None | |
| if len(num_query_points_values) == 1: | |
| default_num_query_points = next(iter(num_query_points_values)) | |
| num_shape_points = next(iter(num_shape_points_values)) | |
| sharp_point_ratio_values = { | |
| float(entry.get("sharp_point_ratio", 0.5)) | |
| for entry in dataset_entries | |
| } | |
| if len(sharp_point_ratio_values) != 1: | |
| raise ValueError( | |
| "Inference requires a single sharp_point_ratio across saved dataset configs, " | |
| f"got {sorted(sharp_point_ratio_values)}" | |
| ) | |
| return ( | |
| num_shape_points, | |
| default_num_query_points, | |
| next(iter(sharp_point_ratio_values)), | |
| ) | |
def validate_link_names(
    link_names: Sequence[str],
    *,
    require_unique: bool = True,
) -> list[str]:
    """Normalize link names and reject unusable input.

    Every item is converted with ``str`` and stripped of surrounding
    whitespace. The stripped names are returned in input order.

    Raises:
        ValueError: If no names are given, if any stripped name is empty, or
            if ``require_unique`` is set and a stripped name repeats.
    """
    stripped_names = [str(raw_name).strip() for raw_name in link_names]

    if len(stripped_names) == 0:
        raise ValueError("At least one link name is required")
    if "" in stripped_names:
        raise ValueError("Link names must be non-empty strings")
    if require_unique:
        distinct_count = len(set(stripped_names))
        if distinct_count != len(stripped_names):
            raise ValueError(f"Link names must be unique, got {stripped_names}")
    return stripped_names
| def build_joint_tensors( | |
| num_links: int, | |
| joint_specs: Sequence[tuple[int, int, str]], | |
| *, | |
| device: torch.device | None = None, | |
| ) -> dict[str, torch.Tensor]: | |
| if num_links <= 0: | |
| raise ValueError(f"num_links must be positive, got {num_links}") | |
| parent_by_child: dict[int, int] = {} | |
| for parent_link_id, child_link_id, joint_type in joint_specs: | |
| if not 0 <= parent_link_id < num_links: | |
| raise ValueError( | |
| f"Parent link ID {parent_link_id} is outside the valid range [0, {num_links - 1}]" | |
| ) | |
| if not 0 <= child_link_id < num_links: | |
| raise ValueError( | |
| f"Child link ID {child_link_id} is outside the valid range [0, {num_links - 1}]" | |
| ) | |
| if parent_link_id == child_link_id: | |
| raise ValueError("Joint parent and child link IDs must be different") | |
| if child_link_id in parent_by_child: | |
| raise ValueError(f"Child link {child_link_id} appears in more than one joint") | |
| parent_by_child[child_link_id] = parent_link_id | |
| for start_link_id in range(num_links): | |
| seen: set[int] = set() | |
| current_link_id = start_link_id | |
| while current_link_id in parent_by_child: | |
| if current_link_id in seen: | |
| raise ValueError("Joint specs must form an acyclic forest") | |
| seen.add(current_link_id) | |
| current_link_id = parent_by_child[current_link_id] | |
| num_joints = len(joint_specs) | |
| joint_connections = torch.full( | |
| (1, num_joints, 2), | |
| fill_value=-1, | |
| dtype=torch.long, | |
| device=device, | |
| ) | |
| joint_valid_flag = torch.zeros((1, num_joints), dtype=torch.bool, device=device) | |
| is_revolute = torch.zeros((1, num_joints), dtype=torch.bool, device=device) | |
| is_prismatic = torch.zeros((1, num_joints), dtype=torch.bool, device=device) | |
| for joint_idx, (parent_link_id, child_link_id, joint_type) in enumerate(joint_specs): | |
| joint_connections[0, joint_idx] = torch.tensor( | |
| [parent_link_id, child_link_id], | |
| dtype=torch.long, | |
| device=device, | |
| ) | |
| joint_valid_flag[0, joint_idx] = True | |
| if joint_type == "revolute": | |
| is_revolute[0, joint_idx] = True | |
| else: | |
| is_prismatic[0, joint_idx] = True | |
| return { | |
| "joint_connections": joint_connections, | |
| "joint_valid_flag": joint_valid_flag, | |
| "is_revolute": is_revolute, | |
| "is_prismatic": is_prismatic, | |
| } | |
def _to_batched_float_tensor(array: np.ndarray, *, device: torch.device) -> torch.Tensor:
    """Convert a NumPy array to a float32 tensor with a leading batch axis on ``device``."""
    host_tensor = torch.from_numpy(array).to(dtype=torch.float32)
    return host_tensor[None].to(device)
def _build_inference_batch(
    *,
    shape_points: np.ndarray,
    shape_normals: np.ndarray,
    query_points: np.ndarray,
    query_normals: np.ndarray,
    num_links: int,
    joint_connections: torch.Tensor,
    joint_valid_flag: torch.Tensor,
    is_revolute: torch.Tensor,
    is_prismatic: torch.Tensor,
    device: torch.device,
    link_point_prompts: torch.Tensor | None = None,
    link_point_prompt_normals: torch.Tensor | None = None,
    link_point_prompt_dropout_eligible: torch.Tensor | None = None,
    link_text_prompts: Sequence[str] | None = None,
    link_text_embeddings: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Assemble a batch-of-one input dict on ``device``.

    NumPy point and normal arrays go through ``_to_batched_float_tensor``.
    Every tensor argument gets a leading batch axis and is moved to
    ``device``. All links are marked valid and ``link_ids`` is left as
    ``None``. Optional inputs that are not provided stay ``None``, except
    ``link_point_prompt_dropout_eligible``, whose key is added only when a
    value is given.
    """

    def _batched_or_none(tensor: torch.Tensor | None) -> torch.Tensor | None:
        # Add the batch axis and move to the target device, passing None through.
        if tensor is None:
            return None
        return tensor.unsqueeze(0).to(device)

    text_prompts = None
    if link_text_prompts is not None:
        text_prompts = [list(link_text_prompts)]

    batch: dict[str, Any] = {
        "shape_points": _to_batched_float_tensor(shape_points, device=device),
        "shape_point_normals": _to_batched_float_tensor(shape_normals, device=device),
        "query_points": _to_batched_float_tensor(query_points, device=device),
        "query_point_normals": _to_batched_float_tensor(query_normals, device=device),
        "link_point_prompts": _batched_or_none(link_point_prompts),
        "link_point_prompt_normals": _batched_or_none(link_point_prompt_normals),
        "link_valid_flag": torch.ones((1, num_links), dtype=torch.bool, device=device),
        "link_text_prompts": text_prompts,
        "link_text_embeddings": _batched_or_none(link_text_embeddings),
        "link_ids": None,
        "joint_connections": joint_connections.unsqueeze(0).to(device),
        "joint_valid_flag": joint_valid_flag.unsqueeze(0).to(device),
        "is_revolute": is_revolute.unsqueeze(0).to(device),
        "is_prismatic": is_prismatic.unsqueeze(0).to(device),
    }
    if link_point_prompt_dropout_eligible is not None:
        batch["link_point_prompt_dropout_eligible"] = _batched_or_none(
            link_point_prompt_dropout_eligible
        )
    return batch
def _link_point_prompt_dropout_eligibility_from_link_names(
    link_names: Sequence[str],
) -> torch.Tensor:
    """Return a boolean tensor flagging names that occur exactly once.

    Names are first cleaned by ``validate_link_names`` (duplicates allowed);
    entry ``i`` is ``True`` when cleaned name ``i`` is unique within the list.
    """
    cleaned_link_names = validate_link_names(link_names, require_unique=False)

    seen_names: set[str] = set()
    repeated_names: set[str] = set()
    for link_name in cleaned_link_names:
        if link_name in seen_names:
            repeated_names.add(link_name)
        else:
            seen_names.add(link_name)

    unique_flags = [link_name not in repeated_names for link_name in cleaned_link_names]
    return torch.tensor(unique_flags, dtype=torch.bool)
def prepare_inference_batch_from_mesh_with_prompts(
    mesh,
    *,
    num_shape_points: int,
    num_query_points: int,
    num_query_points_per_face_for_seg: int | None = None,
    sharp_point_ratio: float,
    link_names: Sequence[str],
    joint_specs: Sequence[tuple[int, int, str]],
    device: torch.device,
    link_point_prompts: torch.Tensor | None = None,
    link_point_prompt_normals: torch.Tensor | None = None,
    require_unique_link_names: bool = True,
) -> tuple[dict[str, Any], np.ndarray]:
    """Sample a mesh and build an inference batch, optionally with point prompts.

    Shape points are sampled with ``sharp_point_ratio``. Query points are
    sampled with a sharp ratio of 0.0, or per face when
    ``num_query_points_per_face_for_seg`` is given (``num_query_points`` is
    then unused). Link names are validated, and point prompts and their
    normals must be supplied together with shape ``(num_links, 3)``.

    Returns:
        ``(batch, query_face_indices)`` where the indices are returned as an
        int64 array alongside the sampled query points.

    Raises:
        ValueError: If only one of the prompt tensors is given, or either has
            the wrong shape. Errors from link-name and joint-spec validation
            propagate as well.
    """
    shape_points, shape_normals, _, _ = sample_mesh_points(
        mesh,
        num_points=num_shape_points,
        sharp_point_ratio=sharp_point_ratio,
    )
    if num_query_points_per_face_for_seg is not None:
        query_samples = sample_mesh_points_per_face(
            mesh,
            num_points_per_face=int(num_query_points_per_face_for_seg),
        )
    else:
        query_samples = sample_mesh_points(
            mesh,
            num_points=num_query_points,
            sharp_point_ratio=0.0,
        )
    query_points, query_normals, _, query_face_indices = query_samples

    cleaned_link_names = validate_link_names(
        link_names,
        require_unique=require_unique_link_names,
    )
    num_links = len(cleaned_link_names)

    if (link_point_prompts is None) != (link_point_prompt_normals is None):
        raise ValueError(
            "link_point_prompts and link_point_prompt_normals must both be provided or both be None"
        )
    if link_point_prompts is not None:
        # Guaranteed by the both-or-neither check above; narrows the type.
        assert link_point_prompt_normals is not None
        expected_shape = (num_links, 3)
        labelled_prompts = (
            ("link_point_prompts", link_point_prompts),
            ("link_point_prompt_normals", link_point_prompt_normals),
        )
        for label, prompt_tensor in labelled_prompts:
            if prompt_tensor.ndim != 2 or prompt_tensor.shape != expected_shape:
                raise ValueError(
                    f"{label} must have shape (num_links, 3), "
                    f"got {tuple(prompt_tensor.shape)} for {num_links} links"
                )

    joint_tensors = build_joint_tensors(num_links, joint_specs, device=device)
    dropout_eligible = _link_point_prompt_dropout_eligibility_from_link_names(
        cleaned_link_names
    )
    # build_joint_tensors returns batched tensors; drop that axis here because
    # _build_inference_batch adds its own.
    batch = _build_inference_batch(
        shape_points=shape_points,
        shape_normals=shape_normals,
        query_points=query_points,
        query_normals=query_normals,
        num_links=num_links,
        joint_connections=joint_tensors["joint_connections"][0],
        joint_valid_flag=joint_tensors["joint_valid_flag"][0],
        is_revolute=joint_tensors["is_revolute"][0],
        is_prismatic=joint_tensors["is_prismatic"][0],
        device=device,
        link_point_prompts=link_point_prompts,
        link_point_prompt_normals=link_point_prompt_normals,
        link_point_prompt_dropout_eligible=dropout_eligible,
        link_text_prompts=cleaned_link_names,
    )
    return batch, query_face_indices.astype(np.int64, copy=False)
def prepare_inference_batch_from_mesh(
    mesh,
    *,
    num_shape_points: int,
    num_query_points: int,
    num_query_points_per_face_for_seg: int | None = None,
    sharp_point_ratio: float,
    link_names: Sequence[str],
    joint_specs: Sequence[tuple[int, int, str]],
    device: torch.device,
) -> tuple[dict[str, Any], np.ndarray]:
    """Build a mesh inference batch without point prompts.

    Thin wrapper over ``prepare_inference_batch_from_mesh_with_prompts`` that
    passes no point prompts and requires unique link names.
    """
    sampling_options = {
        "num_shape_points": num_shape_points,
        "num_query_points": num_query_points,
        "num_query_points_per_face_for_seg": num_query_points_per_face_for_seg,
        "sharp_point_ratio": sharp_point_ratio,
    }
    return prepare_inference_batch_from_mesh_with_prompts(
        mesh,
        **sampling_options,
        link_names=link_names,
        joint_specs=joint_specs,
        device=device,
        link_point_prompts=None,
        link_point_prompt_normals=None,
        require_unique_link_names=True,
    )
def _sample_points_on_triangles(triangle_vertices: np.ndarray) -> np.ndarray:
    """Draw one random point inside each triangle.

    Two random numbers per triangle from NumPy's global RNG are turned into
    barycentric weights via the square-root mapping, and the three vertices
    are blended with those weights.

    Args:
        triangle_vertices: Array with shape `(N, 3, 3)` (triangle, vertex, xyz).

    Returns:
        Sampled points with shape `(N, 3)` and dtype float32.

    Raises:
        ValueError: If `triangle_vertices` does not have shape `(N, 3, 3)`.
    """
    if triangle_vertices.ndim != 3 or triangle_vertices.shape[1:] != (3, 3):
        raise ValueError(
            "triangle_vertices must have shape (N, 3, 3), "
            f"got {triangle_vertices.shape}"
        )

    triangle_count = triangle_vertices.shape[0]
    if triangle_count == 0:
        return np.zeros((0, 3), dtype=np.float32)

    # Keep the two draws separate and in this order so the RNG stream is
    # consumed exactly as callers relying on a seed expect.
    u = np.random.random((triangle_count, 1))
    v = np.random.random((triangle_count, 1))
    root_u = np.sqrt(u)

    weight_first = 1.0 - root_u
    weight_second = root_u * (1.0 - v)
    weight_third = root_u * v
    weights = np.hstack((weight_first, weight_second, weight_third)).astype(
        np.float32,
        copy=False,
    )

    blended = triangle_vertices * weights[..., np.newaxis]
    return blended.sum(axis=1).astype(np.float32, copy=False)
def sample_balanced_query_points_from_face_seg(
    mesh,
    *,
    face_part_ids: np.ndarray,
    num_points: int,
) -> dict[str, np.ndarray]:
    """Sample query points with a near-equal share per part ID.

    The budget is split evenly over the distinct non-negative part IDs, with
    the first ``num_points % num_parts`` IDs (in sorted order) receiving one
    extra sample. Within a part, faces are drawn with replacement in
    proportion to their area, or uniformly when no face of the part has
    positive area; one point is then sampled on each drawn face. The combined
    samples are shuffled. All randomness comes from NumPy's global RNG.

    Args:
        mesh: Object exposing ``faces``, ``vertices``, ``face_normals`` and
            ``area_faces``, with face order matching `face_part_ids`.
        face_part_ids: One part ID per face; negative IDs are ignored.
        num_points: Total number of query points to sample.

    Returns:
        Dict with ``query_points``, ``query_normals``, ``face_indices``,
        ``link_ids`` (part ID per sample), ``unique_part_ids`` and
        ``query_counts`` (samples allotted per unique part ID).

    Raises:
        ValueError: If `num_points` is not positive, `face_part_ids` does not
            have one entry per face, or no non-negative part ID is present.
    """
    if num_points <= 0:
        raise ValueError(f"num_points must be positive, got {num_points}")

    mesh_faces = np.asarray(mesh.faces, dtype=np.int64)
    face_part_ids = np.asarray(face_part_ids, dtype=np.int64)
    num_faces = mesh_faces.shape[0]
    if face_part_ids.shape != (num_faces,):
        raise ValueError(
            "face_part_ids must have one entry per face, "
            f"got {face_part_ids.shape} for {num_faces} faces"
        )

    labelled_faces = face_part_ids >= 0
    unique_part_ids = np.unique(face_part_ids[labelled_faces]).astype(np.int64, copy=False)
    if unique_part_ids.size == 0:
        raise ValueError("face_part_ids must contain at least one non-negative part ID")

    base_quota, leftover = divmod(num_points, unique_part_ids.size)
    samples_per_part = np.full(unique_part_ids.shape, fill_value=base_quota, dtype=np.int64)
    samples_per_part[:leftover] += 1

    mesh_vertices = np.asarray(mesh.vertices, dtype=np.float32)
    mesh_face_normals = np.asarray(mesh.face_normals, dtype=np.float32)
    mesh_face_areas = np.asarray(mesh.area_faces, dtype=np.float64)

    # One (points, normals, face indices, part ids) tuple per sampled part.
    per_part_samples: list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]] = []
    for raw_part_id, raw_quota in zip(unique_part_ids, samples_per_part, strict=True):
        part_id = int(raw_part_id)
        quota = int(raw_quota)
        if quota <= 0:
            continue

        candidate_faces = np.flatnonzero(face_part_ids == part_id).astype(np.int64, copy=False)
        if candidate_faces.size == 0:
            raise ValueError(
                f"Cannot sample query points for part_id={part_id} because it has no faces"
            )

        candidate_areas = mesh_face_areas[candidate_faces]
        # p=None makes np.random.choice uniform, the fallback for zero-area parts.
        area_weights = (
            candidate_areas / candidate_areas.sum()
            if np.any(candidate_areas > 0.0)
            else None
        )
        chosen_faces = np.random.choice(
            candidate_faces,
            size=quota,
            replace=True,
            p=area_weights,
        ).astype(np.int64, copy=False)

        part_points = _sample_points_on_triangles(mesh_vertices[mesh_faces[chosen_faces]])
        per_part_samples.append(
            (
                part_points,
                mesh_face_normals[chosen_faces].astype(np.float32, copy=False),
                chosen_faces,
                np.full((quota,), part_id, dtype=np.int64),
            )
        )

    point_chunks, normal_chunks, face_index_chunks, part_id_chunks = zip(*per_part_samples)
    all_points = np.concatenate(point_chunks, axis=0)
    all_normals = np.concatenate(normal_chunks, axis=0)
    all_face_indices = np.concatenate(face_index_chunks, axis=0)
    all_part_ids = np.concatenate(part_id_chunks, axis=0)

    # Shuffle so samples are not grouped by part.
    shuffle_order = np.random.permutation(len(all_points)).astype(np.int64, copy=False)
    return {
        "query_points": all_points[shuffle_order].astype(np.float32, copy=False),
        "query_normals": all_normals[shuffle_order].astype(np.float32, copy=False),
        "face_indices": all_face_indices[shuffle_order].astype(np.int64, copy=False),
        "link_ids": all_part_ids[shuffle_order].astype(np.int64, copy=False),
        "unique_part_ids": unique_part_ids.astype(np.int64, copy=False),
        "query_counts": samples_per_part.astype(np.int64, copy=False),
    }
def build_joint_refit_batch(
    batch: Mapping[str, Any],
    *,
    query_points: np.ndarray,
    query_normals: np.ndarray,
    assigned_link_ids: np.ndarray,
) -> dict[str, Any]:
    """Copy an inference batch, swapping in new query samples and link IDs.

    The copy is shallow: every other entry is shared with `batch`. The new
    tensors are placed on the device of ``batch["shape_points"]`` with a
    leading batch axis; ``link_ids`` is stored as int64.
    """
    device = batch["shape_points"].device
    return {
        **batch,
        "query_points": _to_batched_float_tensor(
            np.asarray(query_points, dtype=np.float32),
            device=device,
        ),
        "query_point_normals": _to_batched_float_tensor(
            np.asarray(query_normals, dtype=np.float32),
            device=device,
        ),
        "link_ids": (
            torch.from_numpy(np.asarray(assigned_link_ids, dtype=np.int64))
            .long()
            .unsqueeze(0)
            .to(device)
        ),
    }
def run_joint_refit_from_face_seg(
    model: Any,
    *,
    batch: Mapping[str, Any],
    mesh: Any,
    face_part_ids: np.ndarray,
    num_query_points: int,
    query_batch_size: int = 8192,
    no_point_prompt_for_unique_text: bool = False,
) -> tuple[dict[str, Any], dict[str, np.ndarray]]:
    """Run a second inference pass on part-balanced query samples.

    Query points are resampled from `mesh` with
    ``sample_balanced_query_points_from_face_seg`` using `face_part_ids`. They
    are swapped into a copy of `batch` together with their assigned part IDs
    (as ``link_ids``), and the copy is passed to
    ``run_batched_model_inference``. The input `batch` is not modified.

    Returns:
        ``(refit_output, refit_sampling)``: the model output of the second
        pass and the sampling dict that produced its query points.
    """
    balanced_samples = sample_balanced_query_points_from_face_seg(
        mesh,
        face_part_ids=face_part_ids,
        num_points=int(num_query_points),
    )
    second_pass_inputs = build_joint_refit_batch(
        batch,
        query_points=balanced_samples["query_points"],
        query_normals=balanced_samples["query_normals"],
        assigned_link_ids=balanced_samples["link_ids"],
    )
    second_pass_output = run_batched_model_inference(
        model,
        query_batch_size=int(query_batch_size),
        no_point_prompt_for_unique_text=bool(no_point_prompt_for_unique_text),
        **second_pass_inputs,
    )
    return second_pass_output, balanced_samples
| def run_batched_model_inference( | |
| model: Any, | |
| *, | |
| shape_points: torch.Tensor, | |
| shape_point_normals: torch.Tensor, | |
| query_points: torch.Tensor, | |
| query_point_normals: torch.Tensor, | |
| link_point_prompts: torch.Tensor | None, | |
| link_point_prompt_normals: torch.Tensor | None, | |
| link_valid_flag: torch.Tensor, | |
| joint_connections: torch.Tensor, | |
| joint_valid_flag: torch.Tensor, | |
| is_revolute: torch.Tensor, | |
| is_prismatic: torch.Tensor, | |
| link_point_prompt_dropout_eligible: torch.Tensor | None = None, | |
| link_text_prompts: Sequence[Sequence[str]] | None = None, | |
| link_text_embeddings: torch.Tensor | None = None, | |
| link_ids: torch.Tensor | None = None, | |
| query_batch_size: int = 8192, | |
| no_point_prompt_for_unique_text: bool = False, | |
| decode_joint_parameters: bool = True, | |
| ) -> dict[str, Any]: | |
| if query_batch_size <= 0: | |
| raise ValueError(f"query_batch_size must be positive, got {query_batch_size}") | |
| forced_no_point_prompt_mask: torch.Tensor | None = None | |
| if no_point_prompt_for_unique_text and link_point_prompts is not None: | |
| if link_point_prompt_dropout_eligible is None: | |
| raise ValueError( | |
| "no_point_prompt_for_unique_text requires link_point_prompt_dropout_eligible " | |
| "when link_point_prompts are provided" | |
| ) | |
| if link_point_prompt_dropout_eligible.shape != link_valid_flag.shape: | |
| raise ValueError( | |
| "link_point_prompt_dropout_eligible must match link_valid_flag shape, " | |
| f"got {tuple(link_point_prompt_dropout_eligible.shape)} and " | |
| f"{tuple(link_valid_flag.shape)}" | |
| ) | |
| forced_no_point_prompt_mask = ( | |
| link_point_prompt_dropout_eligible.to( | |
| device=link_valid_flag.device, | |
| dtype=torch.bool, | |
| ) | |
| & link_valid_flag | |
| ) | |
| with torch.inference_mode(): | |
| ( | |
| shape_pretrained_features, | |
| query_pretrained_features, | |
| link_point_prompt_pretrained_features, | |
| ) = model.encoder._compute_pretrained_point_features( | |
| shape_points=shape_points, | |
| query_points=query_points, | |
| link_point_prompts=link_point_prompts, | |
| ) | |
| shape_latents = model.encoder.encode_shape( | |
| shape_points, | |
| shape_point_normals, | |
| pretrained_features=shape_pretrained_features, | |
| ) | |
| initial_link_latents = model.encoder.encode_links( | |
| link_point_prompts=link_point_prompts, | |
| link_point_prompt_normals=link_point_prompt_normals, | |
| link_valid_flag=link_valid_flag, | |
| link_point_prompt_dropout_eligible=link_point_prompt_dropout_eligible, | |
| forced_no_point_prompt_mask=forced_no_point_prompt_mask, | |
| link_point_prompt_pretrained_features=link_point_prompt_pretrained_features, | |
| link_text_prompts=link_text_prompts, | |
| link_text_embeddings=link_text_embeddings, | |
| ) | |
| link_block_outputs: list[torch.Tensor] = [] | |
| running_link_latents = initial_link_latents | |
| for link_to_shape_cross_attn, link_self_attn in zip( | |
| model.encoder.link_to_shape_cross_attn, | |
| model.encoder.link_self_attn, | |
| strict=True, | |
| ): | |
| running_link_latents = link_to_shape_cross_attn( | |
| running_link_latents, | |
| shape_latents, | |
| query_mask=link_valid_flag, | |
| ) | |
| running_link_latents = link_self_attn(running_link_latents, mask=link_valid_flag) | |
| link_block_outputs.append(running_link_latents) | |
| link_latents = model.encoder.link_output_norm(running_link_latents) | |
| link_latents = link_latents.masked_fill(~link_valid_flag.unsqueeze(-1), 0) | |
| segmentation_link_latents = model.build_segmentation_link_latents( | |
| link_latents=link_latents, | |
| joint_connections=joint_connections, | |
| joint_valid_flag=joint_valid_flag, | |
| ) | |
| all_query_latents: list[torch.Tensor] | None = ( | |
| [] if decode_joint_parameters else None | |
| ) | |
| all_segmentation_logits: list[torch.Tensor] = [] | |
| all_joint_decoding_link_ids: list[torch.Tensor] = [] | |
| all_joint_decoding_confidences: list[torch.Tensor] = [] | |
| all_revolute_motion_points: list[torch.Tensor] = [] | |
| all_prismatic_motion_points: list[torch.Tensor] = [] | |
| num_queries = query_points.shape[1] | |
| for start_idx in range(0, num_queries, query_batch_size): | |
| end_idx = min(start_idx + query_batch_size, num_queries) | |
| query_points_chunk = query_points[:, start_idx:end_idx] | |
| query_point_normals_chunk = query_point_normals[:, start_idx:end_idx] | |
| query_latents_chunk = model.encoder._embed_point_tokens( | |
| query_points_chunk, | |
| query_point_normals_chunk, | |
| pretrained_features=( | |
| None | |
| if query_pretrained_features is None | |
| else query_pretrained_features[:, start_idx:end_idx] | |
| ), | |
| ) | |
| for block_link_latents, query_to_shape_cross_attn, query_to_link_cross_attn in zip( | |
| link_block_outputs, | |
| model.encoder.query_to_shape_cross_attn, | |
| model.encoder.query_to_link_cross_attn, | |
| strict=True, | |
| ): | |
| query_latents_chunk = query_to_shape_cross_attn(query_latents_chunk, shape_latents) | |
| query_latents_chunk = query_to_link_cross_attn( | |
| query_latents_chunk, | |
| block_link_latents, | |
| context_mask=link_valid_flag, | |
| ) | |
| query_latents_chunk = model.encoder.query_output_norm(query_latents_chunk) | |
| segmentation_logits_chunk = model.decode_segmentation( | |
| query_latents=query_latents_chunk, | |
| link_latents=segmentation_link_latents, | |
| link_valid_flag=link_valid_flag, | |
| ) | |
| if all_query_latents is not None: | |
| all_query_latents.append(query_latents_chunk) | |
| all_segmentation_logits.append(segmentation_logits_chunk) | |
| joint_decoding_link_ids_chunk: torch.Tensor | None = None | |
| joint_decoding_confidences_chunk: torch.Tensor | None = None | |
| if link_ids is not None or ( | |
| decode_joint_parameters | |
| and model.joint_decode_type in { | |
| "overparametrized", | |
| "overparam+dir", | |
| "overparam+singledir", | |
| } | |
| ): | |
| segmentation_probabilities_chunk = segmentation_logits_chunk.softmax(dim=-1) | |
| if link_ids is not None: | |
| joint_decoding_link_ids_chunk = link_ids[:, start_idx:end_idx] | |
| joint_decoding_confidences_chunk = segmentation_probabilities_chunk.gather( | |
| dim=-1, | |
| index=joint_decoding_link_ids_chunk.clamp_min(0).unsqueeze(-1), | |
| ).squeeze(-1) | |
| joint_decoding_confidences_chunk = joint_decoding_confidences_chunk.masked_fill( | |
| joint_decoding_link_ids_chunk < 0, | |
| 0.0, | |
| ) | |
| else: | |
| ( | |
| joint_decoding_confidences_chunk, | |
| joint_decoding_link_ids_chunk, | |
| ) = segmentation_probabilities_chunk.max(dim=-1) | |
| all_joint_decoding_link_ids.append(joint_decoding_link_ids_chunk) | |
| all_joint_decoding_confidences.append(joint_decoding_confidences_chunk) | |
| if ( | |
| decode_joint_parameters | |
| and model.joint_decode_type in { | |
| "overparametrized", | |
| "overparam+dir", | |
| "overparam+singledir", | |
| } | |
| ): | |
| assert joint_decoding_link_ids_chunk is not None | |
| revolute_motion_points_chunk, prismatic_motion_points_chunk = ( | |
| model._decode_joint_motion_points( | |
| query_latents=query_latents_chunk, | |
| link_latents=link_latents, | |
| assigned_link_ids=joint_decoding_link_ids_chunk, | |
| joint_connections=joint_connections, | |
| ) | |
| ) | |
| all_revolute_motion_points.append(revolute_motion_points_chunk) | |
| all_prismatic_motion_points.append(prismatic_motion_points_chunk) | |
| query_latents = ( | |
| torch.cat(all_query_latents, dim=1) | |
| if all_query_latents is not None | |
| else None | |
| ) | |
| segmentation_logits = torch.cat(all_segmentation_logits, dim=1) | |
| output: dict[str, Any] = { | |
| "shape_latents": shape_latents, | |
| "query_latents": query_latents, | |
| "query_points": query_points, | |
| "link_latents": link_latents, | |
| "segmentation_logits": segmentation_logits, | |
| "joint_connections": joint_connections, | |
| "joint_valid_flag": joint_valid_flag, | |
| "is_revolute": is_revolute, | |
| "is_prismatic": is_prismatic, | |
| } | |
| if not decode_joint_parameters: | |
| output.update( | |
| { | |
| "revolute_axis": None, | |
| "prismatic_axis": None, | |
| "revolute_range": None, | |
| "prismatic_range": None, | |
| "revolute_closest_axis_points": None, | |
| "revolute_closest_axis_points_decoder": None, | |
| "revolute_low_points": None, | |
| "revolute_high_points": None, | |
| "revolute_axis_directions": None, | |
| "prismatic_closest_axis_points": None, | |
| "prismatic_closest_axis_points_decoder": None, | |
| "prismatic_low_points": None, | |
| "prismatic_high_points": None, | |
| "prismatic_axis_directions": None, | |
| "joint_decoding_link_ids": None, | |
| "joint_decoding_confidences": None, | |
| "joint_query_counts": None, | |
| "revolute_parameter_valid": None, | |
| "prismatic_parameter_valid": None, | |
| } | |
| ) | |
| return output | |
| joint_decoding_link_ids = None | |
| joint_decoding_confidences = None | |
| joint_query_counts = None | |
| if all_joint_decoding_link_ids: | |
| joint_decoding_link_ids = torch.cat(all_joint_decoding_link_ids, dim=1) | |
| joint_decoding_confidences = torch.cat(all_joint_decoding_confidences, dim=1) | |
| joint_query_counts = _count_assigned_queries_per_joint( | |
| assigned_link_ids=joint_decoding_link_ids, | |
| joint_connections=joint_connections, | |
| joint_valid_flag=joint_valid_flag, | |
| ) | |
| if model.joint_decode_type in {"plain", "plain+fm"}: | |
| if joint_query_counts is None: | |
| revolute_parameter_valid = joint_valid_flag & is_revolute | |
| prismatic_parameter_valid = joint_valid_flag & is_prismatic | |
| else: | |
| revolute_parameter_valid = ( | |
| joint_valid_flag & is_revolute & (joint_query_counts > 0) | |
| ) | |
| prismatic_parameter_valid = ( | |
| joint_valid_flag & is_prismatic & (joint_query_counts > 0) | |
| ) | |
| revolute_axis, prismatic_axis, revolute_range, prismatic_range = ( | |
| model.decode_joint_parameters( | |
| link_latents=link_latents, | |
| joint_connections=joint_connections, | |
| joint_valid_flag=joint_valid_flag, | |
| is_revolute=is_revolute, | |
| is_prismatic=is_prismatic, | |
| ) | |
| ) | |
| output.update( | |
| { | |
| "revolute_axis": revolute_axis, | |
| "prismatic_axis": prismatic_axis, | |
| "revolute_range": revolute_range, | |
| "prismatic_range": prismatic_range, | |
| "revolute_closest_axis_points": None, | |
| "revolute_low_points": None, | |
| "revolute_high_points": None, | |
| "revolute_axis_directions": None, | |
| "prismatic_closest_axis_points": None, | |
| "prismatic_low_points": None, | |
| "prismatic_high_points": None, | |
| "prismatic_axis_directions": None, | |
| "joint_decoding_link_ids": joint_decoding_link_ids, | |
| "joint_decoding_confidences": joint_decoding_confidences, | |
| "joint_query_counts": joint_query_counts, | |
| "revolute_parameter_valid": revolute_parameter_valid, | |
| "prismatic_parameter_valid": prismatic_parameter_valid, | |
| } | |
| ) | |
| return output | |
| assert joint_decoding_link_ids is not None | |
| assert joint_query_counts is not None | |
| revolute_parameter_valid = ( | |
| joint_valid_flag & is_revolute & (joint_query_counts > 0) | |
| ) | |
| prismatic_parameter_valid = ( | |
| joint_valid_flag & is_prismatic & (joint_query_counts > 0) | |
| ) | |
| revolute_motion_points_decoder_raw = torch.cat(all_revolute_motion_points, dim=1) | |
| prismatic_motion_points_decoder_raw = torch.cat(all_prismatic_motion_points, dim=1) | |
| revolute_joint_axis_directions = None | |
| prismatic_joint_axis_directions = None | |
| if model.joint_decode_type == "overparam+singledir": | |
| ( | |
| revolute_joint_axis_directions, | |
| prismatic_joint_axis_directions, | |
| ) = model._decode_joint_axis_directions( | |
| link_latents=link_latents, | |
| joint_connections=joint_connections, | |
| joint_valid_flag=joint_valid_flag, | |
| is_revolute=is_revolute, | |
| is_prismatic=is_prismatic, | |
| ) | |
| revolute_motion_points_decoder = revolute_motion_points_decoder_raw | |
| prismatic_motion_points_decoder = prismatic_motion_points_decoder_raw | |
| revolute_motion_points = model._convert_overparam_motion_points_to_world_coordinates( | |
| motion_points=revolute_motion_points_decoder, | |
| query_points=query_points, | |
| assigned_link_ids=joint_decoding_link_ids, | |
| ) | |
| prismatic_motion_points = model._convert_overparam_motion_points_to_world_coordinates( | |
| motion_points=prismatic_motion_points_decoder, | |
| query_points=query_points, | |
| assigned_link_ids=joint_decoding_link_ids, | |
| ) | |
| revolute_axis, prismatic_axis, revolute_range, prismatic_range = ( | |
| model.decode_joint_parameters( | |
| link_latents=link_latents, | |
| joint_connections=joint_connections, | |
| joint_valid_flag=joint_valid_flag, | |
| is_revolute=is_revolute, | |
| is_prismatic=is_prismatic, | |
| query_points=query_points, | |
| assigned_link_ids=joint_decoding_link_ids, | |
| decoded_motion_points=(revolute_motion_points, prismatic_motion_points), | |
| decoded_axis_directions=( | |
| None | |
| if revolute_joint_axis_directions is None or prismatic_joint_axis_directions is None | |
| else ( | |
| revolute_joint_axis_directions, | |
| prismatic_joint_axis_directions, | |
| ) | |
| ), | |
| decoded_motion_points_are_world=True, | |
| ) | |
| ) | |
| output.update( | |
| { | |
| "revolute_axis": revolute_axis, | |
| "prismatic_axis": prismatic_axis, | |
| "revolute_range": revolute_range, | |
| "prismatic_range": prismatic_range, | |
| "revolute_closest_axis_points": revolute_motion_points[..., :3], | |
| "revolute_closest_axis_points_decoder": revolute_motion_points_decoder_raw[..., :3], | |
| "revolute_low_points": revolute_motion_points[..., 3:6], | |
| "revolute_high_points": revolute_motion_points[..., 6:9], | |
| "revolute_axis_directions": ( | |
| revolute_motion_points[..., 9:12] | |
| if model.joint_decode_type == "overparam+dir" | |
| else revolute_joint_axis_directions | |
| if model.joint_decode_type == "overparam+singledir" | |
| else None | |
| ), | |
| "prismatic_closest_axis_points": prismatic_motion_points[..., :3], | |
| "prismatic_closest_axis_points_decoder": prismatic_motion_points_decoder_raw[..., :3], | |
| "prismatic_low_points": prismatic_motion_points[..., 3:6], | |
| "prismatic_high_points": prismatic_motion_points[..., 6:9], | |
| "prismatic_axis_directions": ( | |
| prismatic_motion_points[..., 9:12] | |
| if model.joint_decode_type == "overparam+dir" | |
| else prismatic_joint_axis_directions | |
| if model.joint_decode_type == "overparam+singledir" | |
| else None | |
| ), | |
| "joint_decoding_link_ids": joint_decoding_link_ids, | |
| "joint_decoding_confidences": joint_decoding_confidences, | |
| "joint_query_counts": joint_query_counts, | |
| "revolute_parameter_valid": revolute_parameter_valid, | |
| "prismatic_parameter_valid": prismatic_parameter_valid, | |
| } | |
| ) | |
| return output | |
def _count_assigned_queries_per_joint(
    *,
    assigned_link_ids: torch.Tensor,
    joint_connections: torch.Tensor,
    joint_valid_flag: torch.Tensor,
) -> torch.Tensor:
    """Counts, per joint, the queries assigned to that joint's child link.

    The child link id is entry 1 along the last axis of ``joint_connections``.
    Matches are summed over dim 1 (the query dimension of the comparison), and
    joints whose ``joint_valid_flag`` is false get a count of zero.
    """
    invalid_joints = ~joint_valid_flag
    # Broadcast every assigned id (trailing axis added) against every joint's
    # child id (axis inserted at dim 1).
    assigned_column = assigned_link_ids[..., None]
    child_row = joint_connections[..., 1][:, None]
    matches = assigned_column == child_row
    return matches.sum(dim=1).masked_fill(invalid_joints, 0)
def motion_arrays_from_model_output(
    output: Mapping[str, Any],
    *,
    num_links: int,
) -> dict[str, np.ndarray]:
    """Scatters per-joint predictions of batch element 0 into per-link arrays.

    For every valid joint, the axis and range of its type are written to the
    row of the joint's child link (entry 1 of ``joint_connections``). A row is
    written only when the joint is of that type and its
    ``*_parameter_valid`` flag is set; when ``output`` carries no such flag
    (missing or ``None``), ``joint_valid_flag & is_<type>`` is used instead.
    Rows that are never written stay zero / False.

    Returns:
        Float32 ``revolute_plucker`` / ``prismatic_plucker`` of shape
        ``(num_links, 6)``, float32 ``revolute_range`` / ``prismatic_range``
        of shape ``(num_links, 2)``, and boolean masks of length
        ``num_links``: ``*_parameter_valid`` plus ``is_part_*`` copies of them.
    """

    def first_sample(tensor: Any) -> np.ndarray:
        return tensor[0].detach().cpu().numpy()

    def first_sample_mask(tensor: Any) -> np.ndarray:
        return first_sample(tensor).astype(bool, copy=False)

    def joint_parameter_mask(valid_key: str, type_key: str) -> np.ndarray:
        flags = output.get(valid_key)
        if flags is None:
            flags = output["joint_valid_flag"] & output[type_key]
        return first_sample_mask(flags)

    link_has_revolute = np.zeros(num_links, dtype=np.bool_)
    link_has_prismatic = np.zeros(num_links, dtype=np.bool_)
    revolute_plucker = np.zeros((num_links, 6), dtype=np.float32)
    prismatic_plucker = np.zeros((num_links, 6), dtype=np.float32)
    revolute_range = np.zeros((num_links, 2), dtype=np.float32)
    prismatic_range = np.zeros((num_links, 2), dtype=np.float32)

    connections = first_sample(output["joint_connections"])
    valid_joints = first_sample_mask(output["joint_valid_flag"])

    # One row per joint type:
    # (type mask, parameter mask, link mask, plucker out, range out, axes, limits)
    per_type = (
        (
            first_sample_mask(output["is_revolute"]),
            joint_parameter_mask("revolute_parameter_valid", "is_revolute"),
            link_has_revolute,
            revolute_plucker,
            revolute_range,
            first_sample(output["revolute_axis"]),
            first_sample(output["revolute_range"]),
        ),
        (
            first_sample_mask(output["is_prismatic"]),
            joint_parameter_mask("prismatic_parameter_valid", "is_prismatic"),
            link_has_prismatic,
            prismatic_plucker,
            prismatic_range,
            first_sample(output["prismatic_axis"]),
            first_sample(output["prismatic_range"]),
        ),
    )

    for joint_idx in np.flatnonzero(valid_joints):
        child = int(connections[joint_idx, 1])
        for of_type, has_parameters, link_mask, plucker_out, range_out, axes, limits in per_type:
            if not (of_type[joint_idx] and has_parameters[joint_idx]):
                continue
            link_mask[child] = True
            plucker_out[child] = axes[joint_idx].astype(np.float32, copy=False)
            range_out[child] = limits[joint_idx].astype(np.float32, copy=False)

    return {
        "is_part_revolute": link_has_revolute.copy(),
        "is_part_prismatic": link_has_prismatic.copy(),
        "revolute_parameter_valid": link_has_revolute,
        "prismatic_parameter_valid": link_has_prismatic,
        "revolute_plucker": revolute_plucker,
        "prismatic_plucker": prismatic_plucker,
        "revolute_range": revolute_range,
        "prismatic_range": prismatic_range,
    }
def denormalize_motion_parameters(
    revolute_plucker: np.ndarray,
    prismatic_plucker: np.ndarray,
    revolute_range: np.ndarray,
    prismatic_range: np.ndarray,
    *,
    center: np.ndarray,
    scale: float,
) -> dict[str, np.ndarray]:
    """Maps per-link motion parameters through the ``center`` / ``scale`` transform.

    A 4x4 transform is built with ``1 / scale`` on the first three diagonal
    entries and ``center`` as translation, and every Plucker row whose
    direction part (first three entries) is non-zero is passed through
    ``transform_plucker`` with it; other rows stay zero. Revolute ranges are
    returned as an unchanged float32 copy, prismatic ranges are divided by
    ``scale``.

    NOTE(review): presumably this undoes a normalization of the form
    ``(x - center) * scale`` -- confirm against the caller.
    """
    to_original_frame = np.eye(4, dtype=np.float32)
    to_original_frame[:3, :3] *= np.float32(1.0 / scale)
    to_original_frame[:3, 3] = np.asarray(center, dtype=np.float32)

    restored_revolute = np.zeros_like(revolute_plucker, dtype=np.float32)
    restored_prismatic = np.zeros_like(prismatic_plucker, dtype=np.float32)
    line_sets = (
        (revolute_plucker, restored_revolute),
        (prismatic_plucker, restored_prismatic),
    )
    # The row count of the revolute array drives the loop for both arrays.
    for row in range(revolute_plucker.shape[0]):
        for normalized, restored in line_sets:
            if np.linalg.norm(normalized[row, :3]) > 0.0:
                restored[row] = transform_plucker(normalized[row], to_original_frame)

    revolute_range_out = np.asarray(revolute_range, dtype=np.float32).copy()
    prismatic_range_out = np.asarray(prismatic_range, dtype=np.float32) / np.float32(scale)
    return {
        "revolute_plucker": restored_revolute,
        "prismatic_plucker": restored_prismatic,
        "revolute_range": revolute_range_out,
        "prismatic_range": prismatic_range_out,
    }
def prismatic_directions_from_plucker(prismatic_plucker: np.ndarray) -> np.ndarray:
    """Extracts one direction vector per row of ``prismatic_plucker``.

    Rows whose first three entries have zero norm keep a zero direction; for
    the others the direction is the first value returned by
    ``plucker_to_axis_point``. The result is float32 with shape
    ``(num_rows, 3)``.
    """
    num_rows = prismatic_plucker.shape[0]
    directions = np.zeros((num_rows, 3), dtype=np.float32)
    has_direction = [
        np.linalg.norm(prismatic_plucker[row, :3]) > 0.0 for row in range(num_rows)
    ]
    for row in np.flatnonzero(has_direction):
        direction, _unused_point = plucker_to_axis_point(prismatic_plucker[row])
        directions[row] = direction.astype(np.float32, copy=False)
    return directions
def build_predicted_kinematic_records(
    link_names: Sequence[str],
    joint_specs: Sequence[tuple[int, int, str]],
    *,
    revolute_plucker: np.ndarray,
    prismatic_plucker: np.ndarray,
    prismatic_axis: np.ndarray | None = None,
    revolute_range: np.ndarray,
    prismatic_range: np.ndarray,
    revolute_parameter_valid: np.ndarray | None = None,
    prismatic_parameter_valid: np.ndarray | None = None,
) -> dict[str, Any]:
    """Builds predicted kinematic records from link names and joint triples.

    Links are numbered by their position in ``link_names`` and joints by their
    position in ``joint_specs``, whose entries are
    ``(parent_link_id, child_link_id, joint_type)``. The joint type sets
    ``is_revolute`` / ``is_prismatic`` by comparison with ``"revolute"`` and
    ``"prismatic"``. The resulting link and joint dicts are handed, together
    with the parameter arrays, to
    ``build_predicted_kinematic_records_from_links_and_joints``.
    """
    link_records = []
    for position, link_name in enumerate(link_names):
        link_records.append({"link_id": int(position), "name": str(link_name)})

    joint_records: list[dict[str, Any]] = [
        {
            "joint_id": int(position),
            "parent_link_id": int(parent_id),
            "child_link_id": int(child_id),
            "joint_description": (
                f"{link_names[parent_id]} -> {link_names[child_id]} ({kind})"
            ),
            "is_revolute": kind == "revolute",
            "is_prismatic": kind == "prismatic",
        }
        for position, (parent_id, child_id, kind) in enumerate(joint_specs)
    ]

    parameter_arrays = {
        "revolute_plucker": revolute_plucker,
        "prismatic_plucker": prismatic_plucker,
        "prismatic_axis": prismatic_axis,
        "revolute_range": revolute_range,
        "prismatic_range": prismatic_range,
        "revolute_parameter_valid": revolute_parameter_valid,
        "prismatic_parameter_valid": prismatic_parameter_valid,
    }
    return build_predicted_kinematic_records_from_links_and_joints(
        link_records,
        joint_records,
        **parameter_arrays,
    )
| def _resolve_link_parameter_valid_mask( | |
| parameter_valid: np.ndarray | None, | |
| *, | |
| num_links: int, | |
| ) -> np.ndarray: | |
| if parameter_valid is None: | |
| return np.ones(num_links, dtype=np.bool_) | |
| parameter_valid = np.asarray(parameter_valid, dtype=np.bool_) | |
| if parameter_valid.shape != (num_links,): | |
| raise ValueError( | |
| f"Expected parameter-valid mask to have shape ({num_links},), " | |
| f"got {parameter_valid.shape}" | |
| ) | |
| return parameter_valid | |
def build_predicted_joint_records_from_links_and_joints(
    links: Sequence[Mapping[str, Any]],
    joints: Sequence[Mapping[str, Any]],
    *,
    revolute_plucker: np.ndarray,
    prismatic_plucker: np.ndarray,
    prismatic_axis: np.ndarray | None = None,
    revolute_range: np.ndarray,
    prismatic_range: np.ndarray,
    revolute_parameter_valid: np.ndarray | None = None,
    prismatic_parameter_valid: np.ndarray | None = None,
) -> dict[str, Any]:
    """Attaches predicted joint parameters to an existing link/joint topology.

    Links are emitted sorted by ``link_id`` and joints sorted by ``joint_id``.
    Each joint looks up its parameters in the row of its ``child_link_id``:
    revolute joints get ``revolute_axis`` (the Plucker row) and
    ``revolute_range``, prismatic joints get ``prismatic_axis`` (a direction
    only) and ``prismatic_range``. A field is ``None`` when the joint is not
    of that type or the corresponding per-link validity mask is false.

    Prismatic directions are taken from ``prismatic_axis`` when it is given
    (it must then have shape ``(num_links, 3)``); otherwise they are derived
    from ``prismatic_plucker``.

    Returns:
        ``{"links": [...], "joints": [...]}`` where the parameter fields are
        float32 NumPy arrays or ``None``.

    Raises:
        ValueError: If an explicit ``prismatic_axis`` or a validity mask has
            the wrong shape.
    """

    def float_row_or_none(table: np.ndarray, row: int, enabled: Any) -> np.ndarray | None:
        # Index lazily so rows of disabled joints are never touched.
        if not enabled:
            return None
        return np.asarray(table[row], dtype=np.float32).copy()

    predicted_links = []
    for link in sorted(links, key=lambda item: int(item["link_id"])):
        link_id = int(link["link_id"])
        predicted_links.append(
            {"link_id": link_id, "name": str(link.get("name", f"link_{link_id}"))}
        )
    num_links = len(predicted_links)

    revolute_ok = _resolve_link_parameter_valid_mask(
        revolute_parameter_valid,
        num_links=num_links,
    )
    prismatic_ok = _resolve_link_parameter_valid_mask(
        prismatic_parameter_valid,
        num_links=num_links,
    )

    if prismatic_axis is not None:
        prismatic_axis = np.asarray(prismatic_axis, dtype=np.float32)
        if prismatic_axis.shape != (num_links, 3):
            raise ValueError(
                f"Expected prismatic_axis to have shape ({num_links}, 3), "
                f"got {prismatic_axis.shape}"
            )
    else:
        prismatic_axis = prismatic_directions_from_plucker(prismatic_plucker)

    ordered_joints = sorted(joints, key=lambda item: int(item["joint_id"]))
    predicted_joints: list[dict[str, Any]] = []
    for position, joint in enumerate(ordered_joints):
        child = int(joint["child_link_id"])
        revolute_joint = bool(joint.get("is_revolute", False))
        prismatic_joint = bool(joint.get("is_prismatic", False))
        use_revolute = revolute_joint and revolute_ok[child]
        use_prismatic = prismatic_joint and prismatic_ok[child]

        record: dict[str, Any] = {
            "joint_id": int(joint.get("joint_id", position)),
            "parent_link_id": int(joint["parent_link_id"]),
            "child_link_id": child,
            "joint_description": str(joint.get("joint_description", "")),
            "is_revolute": revolute_joint,
            "is_prismatic": prismatic_joint,
        }
        record["revolute_axis"] = float_row_or_none(revolute_plucker, child, use_revolute)
        record["prismatic_axis"] = float_row_or_none(prismatic_axis, child, use_prismatic)
        record["revolute_range"] = float_row_or_none(revolute_range, child, use_revolute)
        record["prismatic_range"] = float_row_or_none(prismatic_range, child, use_prismatic)
        predicted_joints.append(record)

    return {"links": predicted_links, "joints": predicted_joints}
def build_predicted_kinematic_records_from_links_and_joints(
    links: Sequence[Mapping[str, Any]],
    joints: Sequence[Mapping[str, Any]],
    *,
    revolute_plucker: np.ndarray,
    prismatic_plucker: np.ndarray,
    prismatic_axis: np.ndarray | None = None,
    revolute_range: np.ndarray,
    prismatic_range: np.ndarray,
    revolute_parameter_valid: np.ndarray | None = None,
    prismatic_parameter_valid: np.ndarray | None = None,
) -> dict[str, Any]:
    """Builds predicted kinematic metadata with list-valued parameter fields.

    Delegates to ``build_predicted_joint_records_from_links_and_joints`` and
    then replaces every non-``None`` axis/range field of each joint with a
    plain list (via a float32 array), leaving ``None`` fields and all other
    fields untouched. Links are passed through as returned by the delegate.
    """
    array_fields = (
        "revolute_axis",
        "prismatic_axis",
        "revolute_range",
        "prismatic_range",
    )

    def with_list_parameters(joint: Mapping[str, Any]) -> dict[str, Any]:
        record = dict(joint)
        for field in array_fields:
            value = record.get(field)
            if value is None:
                continue
            record[field] = np.asarray(value, dtype=np.float32).tolist()
        return record

    records = build_predicted_joint_records_from_links_and_joints(
        links,
        joints,
        revolute_plucker=revolute_plucker,
        prismatic_plucker=prismatic_plucker,
        prismatic_axis=prismatic_axis,
        revolute_range=revolute_range,
        prismatic_range=prismatic_range,
        revolute_parameter_valid=revolute_parameter_valid,
        prismatic_parameter_valid=prismatic_parameter_valid,
    )
    return {
        "links": records["links"],
        "joints": [with_list_parameters(joint) for joint in records["joints"]],
    }