Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import math | |
| import os | |
| from contextlib import nullcontext | |
| from itertools import chain | |
| from typing import Any, Dict, Optional, Sequence, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import Tensor, nn | |
| from instruct_particulate.utils.inference_utils import ( | |
| axis_point_to_plucker_torch, | |
| estimate_prismatic_limit_torch, | |
| estimate_revolute_limit_torch, | |
| fit_axis_to_closest_points_torch, | |
| ) | |
| from instruct_particulate.utils.partfield_feature_utils import ( | |
| PARTFIELD_FEATURE_DIM, | |
| PartFieldFeatureExtractor, | |
| ) | |
| from instruct_particulate.utils.text_embedding_utils import ( | |
| encode_clip_text_prompts, | |
| load_clip_text_encoder, | |
| ) | |
| def _make_silu_mlp( | |
| input_dim: int, | |
| hidden_dim: int, | |
| output_dim: int, | |
| *, | |
| bias: bool = True, | |
| ) -> nn.Sequential: | |
| return nn.Sequential( | |
| nn.Linear(input_dim, hidden_dim, bias=bias), | |
| nn.SiLU(), | |
| nn.Linear(hidden_dim, output_dim, bias=bias), | |
| ) | |
| _OVERPARAM_AXIS_AABB_HALF_EXTENT_MIN = 1e-4 | |
| _PLAIN_JOINT_DECODE_TYPES = frozenset({"plain", "plain+fm"}) | |
| _OVERPARAM_JOINT_DECODE_TYPES = frozenset( | |
| {"overparametrized", "overparam+dir", "overparam+singledir"} | |
| ) | |
| def _normalize_joint_decode_type(joint_decode_type: str) -> str: | |
| normalized_joint_decode_type = str(joint_decode_type).lower() | |
| if normalized_joint_decode_type in { | |
| "plain+flowmatching", | |
| "plain+flow-matching", | |
| "plain+fm", | |
| }: | |
| return "plain+fm" | |
| if normalized_joint_decode_type in { | |
| "overparameterization", | |
| "overparameterized", | |
| "overparam", | |
| }: | |
| return "overparametrized" | |
| if normalized_joint_decode_type in { | |
| "overparameterization+dir", | |
| "overparameterized+dir", | |
| "overparametrized+dir", | |
| "overparam+dir", | |
| }: | |
| return "overparam+dir" | |
| if normalized_joint_decode_type in { | |
| "overparameterization+singledir", | |
| "overparameterized+singledir", | |
| "overparametrized+singledir", | |
| "overparam+singledir", | |
| "overparameterization+single-dir", | |
| "overparameterized+single-dir", | |
| "overparametrized+single-dir", | |
| "overparam+single-dir", | |
| }: | |
| return "overparam+singledir" | |
| return normalized_joint_decode_type | |
| def _normalize_joint_fm_prediction_type(prediction_type: str) -> str: | |
| normalized_prediction_type = str(prediction_type).lower() | |
| if normalized_prediction_type in {"x", "xpred", "x-pred", "x_pred"}: | |
| return "x" | |
| if normalized_prediction_type in {"v", "vpred", "v-pred", "v_pred"}: | |
| return "v" | |
| return normalized_prediction_type | |
| def _normalize_overparam_closest_axis_space(closest_axis_space: str) -> str: | |
| normalized_closest_axis_space = str(closest_axis_space).lower() | |
| if normalized_closest_axis_space in { | |
| "world", | |
| "sample", | |
| "global", | |
| "global-space", | |
| "world-space", | |
| }: | |
| return "world" | |
| if normalized_closest_axis_space in { | |
| "part_aabb", | |
| "part-aabb", | |
| "aabb", | |
| "local_aabb", | |
| "local-aabb", | |
| }: | |
| return "part_aabb" | |
| return normalized_closest_axis_space | |
| def _coerce_joint_fm_state_stat( | |
| values: Optional[Sequence[float]], | |
| *, | |
| default_value: float, | |
| name: str, | |
| ) -> Tensor: | |
| if values is None: | |
| values = [default_value] * 8 | |
| if len(values) != 8: | |
| raise ValueError(f"{name} must have length 8, got {len(values)}") | |
| tensor = torch.tensor([float(value) for value in values], dtype=torch.float32) | |
| if not torch.isfinite(tensor).all(): | |
| raise ValueError(f"{name} must contain only finite values, got {values!r}") | |
| if name.endswith("_std") and torch.any(tensor <= 0.0): | |
| raise ValueError(f"{name} must contain only positive values, got {values!r}") | |
| return tensor | |
| def _gather_joint_link_latents( | |
| *, | |
| link_latents: Tensor, | |
| joint_connections: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| parent_indices = joint_connections[..., 0].clamp_min(0) | |
| child_indices = joint_connections[..., 1].clamp_min(0) | |
| gather_index = parent_indices.unsqueeze(-1).expand(-1, -1, link_latents.shape[-1]) | |
| parent_latents = link_latents.gather(dim=1, index=gather_index) | |
| gather_index = child_indices.unsqueeze(-1).expand(-1, -1, link_latents.shape[-1]) | |
| child_latents = link_latents.gather(dim=1, index=gather_index) | |
| return parent_latents, child_latents | |
| def _build_joint_motion_condition_inputs( | |
| *, | |
| parent_latents: Tensor, | |
| child_latents: Tensor, | |
| motion_type: str, | |
| revolute_embedding: Tensor, | |
| prismatic_embedding: Tensor, | |
| ) -> Tensor: | |
| if parent_latents.shape != child_latents.shape: | |
| raise ValueError( | |
| "parent_latents and child_latents must share the same shape, " | |
| f"got {tuple(parent_latents.shape)} and {tuple(child_latents.shape)}" | |
| ) | |
| if motion_type == "revolute": | |
| type_embedding = revolute_embedding | |
| elif motion_type == "prismatic": | |
| type_embedding = prismatic_embedding | |
| else: | |
| raise ValueError( | |
| "motion_type must be 'revolute' or 'prismatic', " | |
| f"got {motion_type!r}" | |
| ) | |
| type_embeddings = type_embedding.to( | |
| device=parent_latents.device, | |
| dtype=parent_latents.dtype, | |
| ).view(1, 1, -1).expand_as(parent_latents) | |
| return torch.cat((type_embeddings, parent_latents, child_latents), dim=-1) | |
class SwiGLUFeedForward(nn.Module):
    """Gated feed-forward layer: ``out_proj(dropout(value * silu(gate)))``.

    A single input projection produces both the value and the gate halves.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiplier: float = 4.0,
        dropout: float = 0.0,
    ):
        """Create the projections.

        Args:
            dim: Input and output feature size.
            hidden_dim: Gated hidden width; falls back to
                ``int(dim * multiplier)`` when ``None`` (or 0).
            multiplier: Expansion factor used for the fallback width.
            dropout: Dropout probability applied to the gated activations.
        """
        super().__init__()
        gated_width = hidden_dim or int(dim * multiplier)
        # One matmul yields both halves; they are split apart in forward().
        self.in_proj = nn.Linear(dim, 2 * gated_width)
        self.out_proj = nn.Linear(gated_width, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated MLP along the last dimension of ``x``."""
        projected = self.in_proj(x)
        value, gate = projected.chunk(2, dim=-1)
        gated = self.dropout(value * F.silu(gate))
        return self.out_proj(gated)
class FrequencyMLPEmbedder(nn.Module):
    """Fourier-feature encoder followed by RMSNorm and a SiLU MLP.

    Each input coordinate is multiplied by a bank of frequencies, passed
    through sine and cosine, optionally concatenated with the raw input,
    normalised and projected to ``output_dim``.
    """

    def __init__(
        self,
        output_dim: int,
        num_frequencies: int,
        input_dim: int = 3,
        hidden_dim: Optional[int] = None,
        max_frequency: float = 32.0,
        include_raw: bool = True,
    ):
        """Set up the frequency bank and the projection MLP.

        Args:
            output_dim: Size of the produced embedding.
            num_frequencies: Number of frequency bands (must be positive).
            input_dim: Size of the last input dimension (must be positive).
            hidden_dim: MLP hidden width; defaults to ``output_dim``.
            max_frequency: Largest band before the factor of pi (must be
                positive). Bands are spaced geometrically from 1 to this value.
            include_raw: Whether to append the raw input to the features.

        Raises:
            ValueError: If any of the positive-valued arguments is not positive.
        """
        super().__init__()
        if num_frequencies <= 0:
            raise ValueError(f"num_frequencies must be positive, got {num_frequencies}")
        if input_dim <= 0:
            raise ValueError(f"input_dim must be positive, got {input_dim}")
        if max_frequency <= 0.0:
            raise ValueError(f"max_frequency must be positive, got {max_frequency}")

        mlp_width = hidden_dim or output_dim
        self.input_dim = input_dim
        self.include_raw = include_raw

        if num_frequencies == 1:
            bands = torch.tensor([math.pi], dtype=torch.float32)
        else:
            # Evenly spaced in log space, i.e. geometric from 1 to max_frequency.
            log_bands = torch.linspace(
                0.0,
                math.log(max_frequency),
                steps=num_frequencies,
                dtype=torch.float32,
            )
            bands = torch.exp(log_bands) * math.pi
        # Non-persistent: rebuilt from the constructor arguments, not checkpointed.
        self.register_buffer("frequencies", bands, persistent=False)

        features_per_coordinate = 2 * num_frequencies + int(include_raw)
        encoded_dim = input_dim * features_per_coordinate
        self.input_norm = nn.RMSNorm(encoded_dim)
        self.mlp = _make_silu_mlp(
            input_dim=encoded_dim,
            hidden_dim=mlp_width,
            output_dim=output_dim,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x``, whose last dimension must equal ``input_dim``.

        Raises:
            ValueError: If the last dimension of ``x`` has the wrong size.
        """
        if x.shape[-1] != self.input_dim:
            raise ValueError(
                f"expected inputs with last dimension {self.input_dim}, got {tuple(x.shape)}"
            )
        # Phases are computed in float32 regardless of the input dtype.
        coordinates = x.float()
        phases = coordinates.unsqueeze(-1) * self.frequencies
        features = torch.cat((phases.sin(), phases.cos()), dim=-1)
        features = features.flatten(start_dim=-2)
        if self.include_raw:
            features = torch.cat((features, coordinates), dim=-1)
        # Match the MLP parameter dtype before normalising and projecting.
        features = features.to(dtype=self.mlp[0].weight.dtype)
        return self.mlp(self.input_norm(features))
class TimestepEmbedder(nn.Module):
    """Sinusoidal timestep features projected by a SiLU MLP."""

    def __init__(
        self,
        hidden_dim: int,
        *,
        frequency_embedding_dim: int = 256,
        max_period: float = 10_000.0,
    ):
        """Configure the sinusoidal bank and the projection MLP.

        Args:
            hidden_dim: Hidden and output width of the MLP.
            frequency_embedding_dim: Size of the sinusoidal feature vector.
            max_period: Base of the geometric frequency progression.

        Raises:
            ValueError: If any argument is not positive.
        """
        super().__init__()
        if hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be positive, got {hidden_dim}")
        if frequency_embedding_dim <= 0:
            raise ValueError(
                f"frequency_embedding_dim must be positive, got {frequency_embedding_dim}"
            )
        if max_period <= 0.0:
            raise ValueError(f"max_period must be positive, got {max_period}")

        self.frequency_embedding_dim = int(frequency_embedding_dim)
        self.max_period = float(max_period)
        self.mlp = _make_silu_mlp(
            input_dim=self.frequency_embedding_dim,
            hidden_dim=hidden_dim,
            output_dim=hidden_dim,
        )

    def _frequency_embedding(self, t: Tensor) -> Tensor:
        """Return ``[cos(t * f), sin(t * f)]`` features for a 1-D tensor ``t``.

        Frequencies are ``max_period ** (-i / half_dim)``. An odd
        ``frequency_embedding_dim`` is padded with one zero column; a
        dimension of 1 returns ``t`` itself as a single column.
        """
        half_dim = self.frequency_embedding_dim // 2
        if half_dim == 0:
            return t.unsqueeze(-1)

        exponents = torch.arange(half_dim, device=t.device, dtype=torch.float32)
        exponents = exponents / half_dim
        frequencies = torch.exp(-math.log(self.max_period) * exponents)
        angles = t.float().unsqueeze(-1) * frequencies.unsqueeze(0)

        pieces = [angles.cos(), angles.sin()]
        if self.frequency_embedding_dim % 2 != 0:
            # Pad so the feature width equals the configured (odd) dimension.
            pieces.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t: Tensor) -> Tensor:
        """Embed a 1-D tensor of timesteps.

        Raises:
            ValueError: If ``t`` is not one-dimensional.
        """
        if t.ndim != 1:
            raise ValueError(f"expected a 1D timestep tensor, got shape {tuple(t.shape)}")
        sinusoid = self._frequency_embedding(t)
        sinusoid = sinusoid.to(dtype=self.mlp[0].weight.dtype)
        return self.mlp(sinusoid)
| class SDPASelfAttention(nn.Module): | |
| """Self-attention implemented with PyTorch SDPA.""" | |
| def __init__( | |
| self, | |
| model_dim: int, | |
| num_heads: int, | |
| head_dim: Optional[int] = None, | |
| attn_dropout: float = 0.0, | |
| proj_dropout: float = 0.0, | |
| qk_norm: bool = True, | |
| qkv_bias: bool = True, | |
| ): | |
| super().__init__() | |
| if head_dim is None: | |
| if model_dim % num_heads != 0: | |
| raise ValueError( | |
| f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads}) " | |
| "when head_dim is not specified" | |
| ) | |
| head_dim = model_dim // num_heads | |
| self.model_dim = model_dim | |
| self.num_heads = num_heads | |
| self.head_dim = head_dim | |
| self.inner_dim = num_heads * head_dim | |
| self.attn_dropout = attn_dropout | |
| self.qkv_proj = nn.Linear(model_dim, 3 * self.inner_dim, bias=qkv_bias) | |
| self.out_proj = nn.Linear(self.inner_dim, model_dim) | |
| self.out_dropout = nn.Dropout(proj_dropout) | |
| self.q_norm = nn.RMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity() | |
| self.k_norm = nn.RMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity() | |
| def _reshape_heads(self, x: Tensor) -> Tensor: | |
| batch_size, seq_len, _ = x.shape | |
| x = x.view(batch_size, seq_len, self.num_heads, self.head_dim) | |
| return x.transpose(1, 2) | |
| def forward(self, x: Tensor, *, mask: Tensor | None = None) -> Tensor: | |
| qkv = self.qkv_proj(x) | |
| q, k, v = qkv.chunk(3, dim=-1) | |
| q = self.q_norm(self._reshape_heads(q)) | |
| k = self.k_norm(self._reshape_heads(k)) | |
| v = self._reshape_heads(v) | |
| attn_mask = None | |
| if mask is not None: | |
| attn_mask = mask[:, None, :, None] & mask[:, None, None, :] | |
| attn_output = F.scaled_dot_product_attention( | |
| q, | |
| k, | |
| v, | |
| attn_mask=attn_mask, | |
| dropout_p=self.attn_dropout if self.training else 0.0, | |
| ) | |
| attn_output = attn_output.transpose(1, 2).contiguous().view( | |
| x.shape[0], | |
| x.shape[1], | |
| self.inner_dim, | |
| ) | |
| return self.out_dropout(self.out_proj(attn_output)) | |
| class SDPACrossAttention(nn.Module): | |
| """Cross-attention implemented with PyTorch SDPA.""" | |
| def __init__( | |
| self, | |
| model_dim: int, | |
| num_heads: int, | |
| head_dim: Optional[int] = None, | |
| attn_dropout: float = 0.0, | |
| proj_dropout: float = 0.0, | |
| qk_norm: bool = True, | |
| q_bias: bool = True, | |
| kv_bias: bool = True, | |
| ): | |
| super().__init__() | |
| if head_dim is None: | |
| if model_dim % num_heads != 0: | |
| raise ValueError( | |
| f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads}) " | |
| "when head_dim is not specified" | |
| ) | |
| head_dim = model_dim // num_heads | |
| self.model_dim = model_dim | |
| self.num_heads = num_heads | |
| self.head_dim = head_dim | |
| self.inner_dim = num_heads * head_dim | |
| self.attn_dropout = attn_dropout | |
| self.q_proj = nn.Linear(model_dim, self.inner_dim, bias=q_bias) | |
| self.kv_proj = nn.Linear(model_dim, 2 * self.inner_dim, bias=kv_bias) | |
| self.out_proj = nn.Linear(self.inner_dim, model_dim) | |
| self.out_dropout = nn.Dropout(proj_dropout) | |
| self.q_norm = nn.RMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity() | |
| self.k_norm = nn.RMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity() | |
| def _reshape_heads(self, x: Tensor) -> Tensor: | |
| batch_size, seq_len, _ = x.shape | |
| x = x.view(batch_size, seq_len, self.num_heads, self.head_dim) | |
| return x.transpose(1, 2) | |
| def forward( | |
| self, | |
| query: Tensor, | |
| context: Tensor, | |
| *, | |
| query_mask: Tensor | None = None, | |
| context_mask: Tensor | None = None, | |
| ) -> Tensor: | |
| q = self.q_norm(self._reshape_heads(self.q_proj(query))) | |
| k, v = self.kv_proj(context).chunk(2, dim=-1) | |
| k = self.k_norm(self._reshape_heads(k)) | |
| v = self._reshape_heads(v) | |
| if query_mask is None and context_mask is None: | |
| attn_mask = None | |
| else: | |
| if query_mask is None: | |
| query_mask = torch.ones(query.shape[:2], device=query.device, dtype=torch.bool) | |
| if context_mask is None: | |
| context_mask = torch.ones( | |
| context.shape[:2], | |
| device=context.device, | |
| dtype=torch.bool, | |
| ) | |
| attn_mask = query_mask[:, None, :, None] & context_mask[:, None, None, :] | |
| attn_output = F.scaled_dot_product_attention( | |
| q, | |
| k, | |
| v, | |
| attn_mask=attn_mask, | |
| dropout_p=self.attn_dropout if self.training else 0.0, | |
| ) | |
| attn_output = attn_output.transpose(1, 2).contiguous().view( | |
| query.shape[0], | |
| query.shape[1], | |
| self.inner_dim, | |
| ) | |
| return self.out_dropout(self.out_proj(attn_output)) | |
class CrossAttentionBlock(nn.Module):
    """Residual cross-attention followed by a residual SwiGLU feed-forward.

    Both sub-layers are pre-normalised with RMSNorm. When a query mask is
    supplied, padded query rows are reset to zero after each sub-layer.
    """

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        ffn_multiplier: float = 4.0,
        ffn_dropout: float = 0.0,
        qk_norm: bool = True,
        norm_eps: float = 1e-6,
    ):
        """Build the norms, the cross-attention and the feed-forward layer.

        Args:
            model_dim: Token feature size.
            num_heads: Number of attention heads.
            head_dim: Optional per-head width forwarded to the attention.
            attn_dropout: Attention-weight dropout probability.
            proj_dropout: Dropout after the attention output projection.
            ffn_multiplier: Hidden expansion factor of the feed-forward layer.
            ffn_dropout: Dropout inside the feed-forward layer.
            qk_norm: Whether the attention normalises queries and keys.
            norm_eps: Epsilon for every RMSNorm in this block.
        """
        super().__init__()
        self.query_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.context_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        attention_kwargs = dict(
            model_dim=model_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            qk_norm=qk_norm,
        )
        self.attn = SDPACrossAttention(**attention_kwargs)
        self.ffn_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.ffn = SwiGLUFeedForward(
            dim=model_dim,
            multiplier=ffn_multiplier,
            dropout=ffn_dropout,
        )

    def forward(
        self,
        query: Tensor,
        context: Tensor,
        *,
        query_mask: Tensor | None = None,
        context_mask: Tensor | None = None,
    ) -> Tensor:
        """Update ``query`` tokens from ``context`` tokens.

        Args:
            query: Query tokens.
            context: Context tokens attended to.
            query_mask: Optional boolean mask of real query tokens.
            context_mask: Optional boolean mask of real context tokens.

        Returns:
            Updated query tokens; rows where ``query_mask`` is ``False`` are 0.
        """
        padded_rows = None if query_mask is None else ~query_mask.unsqueeze(-1)

        attended = self.attn(
            self.query_norm(query),
            self.context_norm(context),
            query_mask=query_mask,
            context_mask=context_mask,
        )
        query = query + attended
        if padded_rows is not None:
            query = query.masked_fill(padded_rows, 0)

        query = query + self.ffn(self.ffn_norm(query))
        if padded_rows is not None:
            # The feed-forward bias would otherwise leak into padded rows.
            query = query.masked_fill(padded_rows, 0)
        return query
class SelfAttentionBlock(nn.Module):
    """Residual self-attention followed by a residual SwiGLU feed-forward.

    Both sub-layers are pre-normalised with RMSNorm. When a mask is supplied,
    padded rows are reset to zero after each sub-layer.
    """

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        ffn_multiplier: float = 4.0,
        ffn_dropout: float = 0.0,
        qk_norm: bool = True,
        norm_eps: float = 1e-6,
    ):
        """Build the norms, the self-attention and the feed-forward layer.

        Args:
            model_dim: Token feature size.
            num_heads: Number of attention heads.
            head_dim: Optional per-head width forwarded to the attention.
            attn_dropout: Attention-weight dropout probability.
            proj_dropout: Dropout after the attention output projection.
            ffn_multiplier: Hidden expansion factor of the feed-forward layer.
            ffn_dropout: Dropout inside the feed-forward layer.
            qk_norm: Whether the attention normalises queries and keys.
            norm_eps: Epsilon for every RMSNorm in this block.
        """
        super().__init__()
        self.attn_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        attention_kwargs = dict(
            model_dim=model_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            qk_norm=qk_norm,
        )
        self.attn = SDPASelfAttention(**attention_kwargs)
        self.ffn_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.ffn = SwiGLUFeedForward(
            dim=model_dim,
            multiplier=ffn_multiplier,
            dropout=ffn_dropout,
        )

    def forward(self, x: Tensor, *, mask: Tensor | None = None) -> Tensor:
        """Run both sub-layers on ``x``.

        Args:
            x: Input tokens.
            mask: Optional boolean mask of real tokens.

        Returns:
            Updated tokens; rows where ``mask`` is ``False`` are 0.
        """
        padded_rows = None if mask is None else ~mask.unsqueeze(-1)

        x = x + self.attn(self.attn_norm(x), mask=mask)
        if padded_rows is not None:
            x = x.masked_fill(padded_rows, 0)

        x = x + self.ffn(self.ffn_norm(x))
        if padded_rows is not None:
            # The feed-forward bias would otherwise leak into padded rows.
            x = x.masked_fill(padded_rows, 0)
        return x
class ShapeLatentEncoder(nn.Module):
    """Summarise point tokens into a fixed number of learned latents.

    A learned latent set acts as the query of one cross-attention block over
    the point tokens; the result is RMS-normalised.
    """

    def __init__(
        self,
        model_dim: int,
        num_shape_latents: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        ffn_multiplier: float = 4.0,
        ffn_dropout: float = 0.0,
        qk_norm: bool = True,
        norm_eps: float = 1e-6,
    ):
        """Create the latent queries and the cross-attention block.

        Args:
            model_dim: Feature size of tokens and latents.
            num_shape_latents: Number of learned latent vectors.
            num_heads: Number of attention heads.
            head_dim: Optional per-head width.
            attn_dropout: Attention-weight dropout probability.
            proj_dropout: Dropout after the attention output projection.
            ffn_multiplier: Hidden expansion factor of the feed-forward layer.
            ffn_dropout: Dropout inside the feed-forward layer.
            qk_norm: Whether the attention normalises queries and keys.
            norm_eps: Epsilon for the RMSNorm layers.
        """
        super().__init__()
        # Values are filled in by reset_parameters() below.
        self.shape_latents = nn.Parameter(torch.empty(num_shape_latents, model_dim))
        block_kwargs = dict(
            model_dim=model_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            ffn_multiplier=ffn_multiplier,
            ffn_dropout=ffn_dropout,
            qk_norm=qk_norm,
            norm_eps=norm_eps,
        )
        self.block = CrossAttentionBlock(**block_kwargs)
        self.output_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the learned latents from a truncated normal (std 0.02)."""
        nn.init.trunc_normal_(self.shape_latents, std=0.02)

    def forward(self, point_tokens: Tensor) -> Tensor:
        """Encode ``point_tokens`` of shape (batch, points, model_dim).

        Returns:
            Normalised latents of shape (batch, num_shape_latents, model_dim).
        """
        num_samples = point_tokens.shape[0]
        queries = self.shape_latents.to(dtype=point_tokens.dtype)
        # expand() shares the same learned latents across the batch without copying.
        queries = queries.unsqueeze(0).expand(num_samples, -1, -1)
        attended = self.block(queries, point_tokens)
        return self.output_norm(attended)
class SegmentationDecoder(nn.Module):
    """Score every query token against every (padded) link token.

    Two scoring modes are supported: ``"dot"`` (inner product) and ``"mlp"``
    (a small MLP on the concatenated link/query pair). Columns that belong to
    invalid links are filled with ``-inf``.
    """

    def __init__(
        self,
        model_dim: int,
        decode_type: str = "dot",
        bias: bool = True,
        query_chunk_size: Optional[int] = None,
    ):
        """Configure the decoder.

        Args:
            model_dim: Feature size of query and link latents.
            decode_type: ``"dot"`` or ``"mlp"``.
            bias: Whether the MLP layers carry a bias (``"mlp"`` mode only).
            query_chunk_size: If set, ``"mlp"`` scoring processes queries in
                chunks of this size to bound the pairwise tensor.

        Raises:
            ValueError: On an unknown ``decode_type`` or a non-positive
                ``query_chunk_size``.
        """
        super().__init__()
        self.model_dim = model_dim
        self.decode_type = str(decode_type)
        if self.decode_type not in {"dot", "mlp"}:
            raise ValueError(
                f"decode_type must be 'dot' or 'mlp', got {decode_type!r}"
            )
        if query_chunk_size is not None and int(query_chunk_size) <= 0:
            raise ValueError(f"query_chunk_size must be positive when set, got {query_chunk_size}")
        self.query_chunk_size = int(query_chunk_size) if query_chunk_size is not None else None
        # The pairwise MLP only exists in "mlp" mode.
        if self.decode_type == "mlp":
            self.decoder = _make_silu_mlp(
                input_dim=model_dim * 2,
                hidden_dim=model_dim * 4,
                output_dim=1,
                bias=bias,
            )

    def _decode_mlp_logits(
        self,
        query_latents: Tensor,
        link_latents: Tensor,
    ) -> Tensor:
        """Score all (query, link) pairs with the MLP.

        Returns:
            Logits of shape (batch, queries, links).
        """
        num_queries = query_latents.shape[1]
        num_links = link_latents.shape[1]
        links_per_query = link_latents.unsqueeze(1).expand(-1, num_queries, -1, -1)
        queries_per_link = query_latents.unsqueeze(2).expand(-1, -1, num_links, -1)
        # Link features come first in the pair, matching the trained MLP input layout.
        pairs = torch.cat((links_per_query, queries_per_link), dim=-1)
        return self.decoder(pairs).squeeze(-1)

    def forward(
        self,
        query_latents: Tensor,
        link_latents: Tensor,
        link_valid_flag: Tensor,
    ) -> Tensor:
        """Compute query-to-link logits.

        Args:
            query_latents: Tensor of shape (batch, queries, model_dim).
            link_latents: Tensor of shape (batch, links, model_dim).
            link_valid_flag: Boolean (batch, links) tensor of valid links.

        Returns:
            Logits of shape (batch, queries, links) with ``-inf`` in the
            columns of invalid links.
        """
        if self.decode_type == "dot":
            scores = torch.matmul(query_latents, link_latents.transpose(-1, -2))
        else:
            num_queries = query_latents.shape[1]
            chunk_size = self.query_chunk_size
            if chunk_size is None or num_queries <= chunk_size:
                scores = self._decode_mlp_logits(query_latents, link_latents)
            else:
                chunk_scores = []
                for begin in range(0, num_queries, chunk_size):
                    query_chunk = query_latents[:, begin : begin + chunk_size]
                    chunk_scores.append(self._decode_mlp_logits(query_chunk, link_latents))
                scores = torch.cat(chunk_scores, dim=1)

        invalid_links = ~link_valid_flag.unsqueeze(1)
        return scores.masked_fill(invalid_links, float("-inf"))
class JointDecoderPlain(nn.Module):
    """Regress an 8-value motion state per joint from its two link latents.

    The same MLP is evaluated once with a revolute and once with a prismatic
    type embedding. Each output is split into a 6-value axis part and a
    2-value range part.
    """

    # Size of the decoded state: 6 axis values followed by 2 range values.
    motion_state_dim = 8

    def __init__(
        self,
        model_dim: int,
        bias: bool = True,
    ):
        """Create the type embeddings and the decoding MLP.

        Args:
            model_dim: Feature size of the link latents.
            bias: Whether the MLP layers carry a bias.
        """
        super().__init__()
        self.model_dim = model_dim
        # Values are filled in by reset_parameters() below.
        self.revolute_embedding = nn.Parameter(torch.empty(model_dim))
        self.prismatic_embedding = nn.Parameter(torch.empty(model_dim))
        self.decoder = _make_silu_mlp(
            input_dim=model_dim * 3,
            hidden_dim=model_dim * 4,
            output_dim=self.motion_state_dim,
            bias=bias,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise both type embeddings from a truncated normal (std 0.02)."""
        for embedding in (self.revolute_embedding, self.prismatic_embedding):
            nn.init.trunc_normal_(embedding, std=0.02)

    def _predict_motion_parameters(
        self,
        *,
        parent_latents: Tensor,
        child_latents: Tensor,
        active_mask: Tensor,
        motion_type: str,
    ) -> Tuple[Tensor, Tensor]:
        """Decode one motion type and zero the rows that are not active.

        Returns:
            ``(axis, range)`` where ``axis`` holds the first 6 decoded values
            and ``range`` the following 2.
        """
        condition_inputs = _build_joint_motion_condition_inputs(
            parent_latents=parent_latents,
            child_latents=child_latents,
            motion_type=motion_type,
            revolute_embedding=self.revolute_embedding,
            prismatic_embedding=self.prismatic_embedding,
        )
        inactive_rows = ~active_mask.unsqueeze(-1)
        motion_state = self.decoder(condition_inputs).masked_fill(inactive_rows, 0)
        axis_part = motion_state[..., :6]
        range_part = motion_state[..., 6:8]
        return axis_part, range_part

    def predict(
        self,
        link_latents: Tensor,
        joint_connections: Tensor,
        joint_valid_flag: Tensor,
        is_revolute: Tensor,
        is_prismatic: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Decode padded joint tensors; inactive rows come back as zeros.

        Args:
            link_latents: Per-link latents.
            joint_connections: Parent/child link indices per joint.
            joint_valid_flag: Boolean mask of real (non-padding) joints.
            is_revolute: Boolean mask of revolute joints.
            is_prismatic: Boolean mask of prismatic joints.

        Returns:
            ``(revolute_axis, prismatic_axis, revolute_range, prismatic_range)``.
        """
        parent_latents, child_latents = _gather_joint_link_latents(
            link_latents=link_latents,
            joint_connections=joint_connections,
        )
        revolute_axis, revolute_range = self._predict_motion_parameters(
            parent_latents=parent_latents,
            child_latents=child_latents,
            active_mask=joint_valid_flag & is_revolute,
            motion_type="revolute",
        )
        prismatic_axis, prismatic_range = self._predict_motion_parameters(
            parent_latents=parent_latents,
            child_latents=child_latents,
            active_mask=joint_valid_flag & is_prismatic,
            motion_type="prismatic",
        )
        return revolute_axis, prismatic_axis, revolute_range, prismatic_range
| class JointDecoderPlainFlowMatching(nn.Module): | |
| """Predicts plain joint FM states with configurable x/v parameterization.""" | |
| motion_state_dim = 8 | |
def __init__(
    self,
    model_dim: int,
    *,
    hidden_dim: Optional[int] = None,
    prediction_type: str = "v",
    time_embedding_dim: int = 256,
    inference_steps: int = 100,
    time_scale: float = 1000.0,
    sigma_min: float = 0.0,
    rescale_t: float = 1.0,
    cfg_scale: float = 1.0,
    revolute_state_mean: Optional[Sequence[float]] = None,
    revolute_state_std: Optional[Sequence[float]] = None,
    prismatic_state_mean: Optional[Sequence[float]] = None,
    prismatic_state_std: Optional[Sequence[float]] = None,
    bias: bool = True,
):
    """Configure the flow-matching joint decoder.

    Args:
        model_dim: Feature size of the link latents.
        hidden_dim: Width of the internal network; defaults to ``model_dim``.
        prediction_type: ``"x"`` or ``"v"`` (aliases are normalised).
        time_embedding_dim: Sinusoidal feature size of the time embedder.
        inference_steps: Default number of integration steps (must be > 0).
        time_scale: Factor applied to ``t`` before it is embedded.
        sigma_min: Noise level retained at ``t = 1`` by the bridge scale.
        rescale_t: Time-grid warping factor; 1.0 leaves the grid uniform.
        cfg_scale: Default classifier-free guidance scale.
        revolute_state_mean: Optional 8 per-dimension means (default 0).
        revolute_state_std: Optional 8 per-dimension stds (default 1).
        prismatic_state_mean: Optional 8 per-dimension means (default 0).
        prismatic_state_std: Optional 8 per-dimension stds (default 1).
        bias: Whether the linear layers carry a bias.

    Raises:
        ValueError: On non-positive ``inference_steps``, an unknown
            ``prediction_type``, or invalid state statistics.
    """
    super().__init__()
    self.model_dim = model_dim
    self.hidden_dim = int(hidden_dim or model_dim)

    self.inference_steps = int(inference_steps)
    if self.inference_steps <= 0:
        raise ValueError(f"inference_steps must be positive, got {inference_steps}")

    prediction_type = _normalize_joint_fm_prediction_type(prediction_type)
    if prediction_type not in {"x", "v"}:
        raise ValueError(
            "prediction_type must be 'x' or 'v', "
            f"got {prediction_type!r}"
        )
    self.prediction_type = prediction_type

    self.time_scale = float(time_scale)
    self.sigma_min = float(sigma_min)
    self.rescale_t = float(rescale_t)
    self.cfg_scale = float(cfg_scale)

    # Normalisation statistics are persistent buffers so they are stored in
    # checkpoints together with the weights.
    state_statistics = (
        ("revolute_state_mean", revolute_state_mean, 0.0),
        ("revolute_state_std", revolute_state_std, 1.0),
        ("prismatic_state_mean", prismatic_state_mean, 0.0),
        ("prismatic_state_std", prismatic_state_std, 1.0),
    )
    for stat_name, stat_values, stat_default in state_statistics:
        self.register_buffer(
            stat_name,
            _coerce_joint_fm_state_stat(
                stat_values,
                default_value=stat_default,
                name=stat_name,
            ),
            persistent=True,
        )

    # Values are filled in by reset_parameters() at the end.
    self.revolute_embedding = nn.Parameter(torch.empty(model_dim))
    self.prismatic_embedding = nn.Parameter(torch.empty(model_dim))

    width = self.hidden_dim
    self.condition_projector = _make_silu_mlp(
        input_dim=model_dim * 3,
        hidden_dim=width * 2,
        output_dim=width,
        bias=bias,
    )
    self.state_projector = nn.Linear(self.motion_state_dim, width, bias=bias)
    self.time_embedder = TimestepEmbedder(
        width,
        frequency_embedding_dim=time_embedding_dim,
    )
    self.input_norm = nn.RMSNorm(width)
    self.residual_block_1 = _make_silu_mlp(
        input_dim=width,
        hidden_dim=width * 4,
        output_dim=width,
        bias=bias,
    )
    self.residual_block_2 = _make_silu_mlp(
        input_dim=width,
        hidden_dim=width * 4,
        output_dim=width,
        bias=bias,
    )
    self.output_norm = nn.RMSNorm(width)
    self.output_projector = nn.Linear(width, self.motion_state_dim, bias=bias)
    self.reset_parameters()
def reset_parameters(self) -> None:
    """Initialise both motion-type embeddings from a truncated normal (std 0.02)."""
    for embedding in (self.revolute_embedding, self.prismatic_embedding):
        nn.init.trunc_normal_(embedding, std=0.02)
| def _prepare_t_sequence( | |
| self, | |
| *, | |
| device: torch.device, | |
| steps: int | None = None, | |
| dtype: torch.dtype = torch.float32, | |
| ) -> Tensor: | |
| num_steps = self.inference_steps if steps is None else int(steps) | |
| if num_steps <= 0: | |
| raise ValueError(f"steps must be positive, got {num_steps}") | |
| t_sequence = torch.linspace( | |
| 0.0, | |
| 1.0, | |
| num_steps + 1, | |
| device=device, | |
| dtype=torch.float32, | |
| ) | |
| if self.rescale_t: | |
| t_sequence = t_sequence / ( | |
| 1.0 + (self.rescale_t - 1.0) * (1.0 - t_sequence) | |
| ) | |
| return t_sequence.to(dtype=dtype) | |
| def _motion_state_mean_and_std( | |
| self, | |
| *, | |
| motion_type: str, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| ) -> Tuple[Tensor, Tensor]: | |
| if motion_type == "revolute": | |
| mean = self.revolute_state_mean | |
| std = self.revolute_state_std | |
| elif motion_type == "prismatic": | |
| mean = self.prismatic_state_mean | |
| std = self.prismatic_state_std | |
| else: | |
| raise ValueError( | |
| "motion_type must be 'revolute' or 'prismatic', " | |
| f"got {motion_type!r}" | |
| ) | |
| safe_std = std.to(device=device, dtype=dtype).clamp_min(1.0e-6) | |
| return mean.to(device=device, dtype=dtype), safe_std | |
| def _unnormalize_motion_state( | |
| self, | |
| motion_state: Tensor, | |
| *, | |
| motion_type: str, | |
| ) -> Tensor: | |
| mean, std = self._motion_state_mean_and_std( | |
| motion_type=motion_type, | |
| device=motion_state.device, | |
| dtype=motion_state.dtype, | |
| ) | |
| return motion_state * std + mean | |
| def _sample_motion_state( | |
| self, | |
| *, | |
| condition_embeddings: Tensor, | |
| steps: int | None = None, | |
| cfg_scale: float | None = None, | |
| ) -> Tensor: | |
| num_samples = condition_embeddings.shape[0] | |
| if num_samples == 0: | |
| return condition_embeddings.new_zeros((0, self.motion_state_dim)) | |
| guidance_scale = self.cfg_scale if cfg_scale is None else float(cfg_scale) | |
| motion_state = torch.randn( | |
| (num_samples, self.motion_state_dim), | |
| device=condition_embeddings.device, | |
| dtype=condition_embeddings.dtype, | |
| ) | |
| t_sequence = self._prepare_t_sequence( | |
| device=condition_embeddings.device, | |
| steps=steps, | |
| dtype=condition_embeddings.dtype, | |
| ) | |
| for t0, t1 in zip(t_sequence[:-1], t_sequence[1:]): | |
| time_vector = motion_state.new_full((num_samples,), t0) | |
| if guidance_scale == 1.0: | |
| velocity = self.predict_velocity( | |
| x_t=motion_state, | |
| t=time_vector, | |
| condition_embeddings=condition_embeddings, | |
| ) | |
| else: | |
| velocity = self._predict_guided_velocity( | |
| x_t=motion_state, | |
| t=time_vector, | |
| condition_embeddings=condition_embeddings, | |
| cfg_scale=guidance_scale, | |
| ) | |
| motion_state = motion_state + (t1 - t0).to(dtype=motion_state.dtype) * velocity | |
| return motion_state | |
| def _bridge_noise_scale(self, t: Tensor) -> Tensor: | |
| return 1.0 - (1.0 - self.sigma_min) * t | |
def _predict_model_output(
    self,
    *,
    x_t: Tensor,
    t: Tensor,
    condition_embeddings: Tensor,
) -> Tensor:
    """Run the residual network on a batch of noisy motion states.

    Args:
        x_t: Rank-2 tensor whose last dimension equals `self.motion_state_dim`.
        t: Rank-1 tensor with one time value per row of `x_t`.
        condition_embeddings: Rank-2 tensor with one row per row of `x_t`.

    Returns:
        The output of `self.output_projector` for every row of `x_t`.

    Raises:
        ValueError: If any of the shape requirements above is violated.
    """
    if x_t.ndim != 2:
        raise ValueError(f"x_t must be rank-2, got shape {tuple(x_t.shape)}")
    state_shape = tuple(x_t.shape)
    if state_shape[-1] != self.motion_state_dim:
        raise ValueError(
            f"x_t last dim must be {self.motion_state_dim}, got {state_shape}"
        )
    if condition_embeddings.ndim != 2:
        raise ValueError(
            "condition_embeddings must be rank-2, "
            f"got shape {tuple(condition_embeddings.shape)}"
        )
    num_states = state_shape[0]
    if condition_embeddings.shape[0] != num_states:
        raise ValueError(
            "condition_embeddings batch dim must match x_t, "
            f"got {tuple(condition_embeddings.shape)} and {state_shape}"
        )
    if t.ndim != 1 or t.shape[0] != num_states:
        raise ValueError(
            "t must be a vector matching x_t batch size, "
            f"got {tuple(t.shape)} and {state_shape}"
        )

    # The state is cast to the projector weight dtype before projection.
    projector_dtype = self.state_projector.weight.dtype
    state_hidden = self.state_projector(x_t.to(dtype=projector_dtype))
    time_hidden = self.time_embedder(t * self.time_scale)
    condition_hidden = condition_embeddings.to(dtype=state_hidden.dtype)

    hidden = self.input_norm(state_hidden + time_hidden + condition_hidden)
    for residual_block in (self.residual_block_1, self.residual_block_2):
        hidden = hidden + residual_block(hidden)
    return self.output_projector(self.output_norm(hidden))
def _predicted_clean_state_to_velocity(
    self,
    *,
    predicted_clean_state: Tensor,
    x_t: Tensor,
    t: Tensor,
) -> Tensor:
    """Convert a clean-state prediction into a velocity.

    Computes `(predicted_clean_state - (1 - sigma_min) * x_t) / noise_scale`,
    where `noise_scale` is `self._bridge_noise_scale(t)` broadcast over the
    last dimension. All arithmetic runs in the dtype of
    `predicted_clean_state`.
    """
    working_dtype = predicted_clean_state.dtype
    denominator = self._bridge_noise_scale(t).unsqueeze(-1).to(dtype=working_dtype)
    # Clamp to machine epsilon so the division stays finite when the noise
    # scale reaches zero.
    denominator = denominator.clamp_min(torch.finfo(working_dtype).eps)
    noisy_state = x_t.to(dtype=working_dtype)
    numerator = predicted_clean_state - (1.0 - self.sigma_min) * noisy_state
    return numerator / denominator
def predict_velocity(
    self,
    *,
    x_t: Tensor,
    t: Tensor,
    condition_embeddings: Tensor,
) -> Tensor:
    """Predict the velocity at state `x_t` and time `t`.

    When `self.prediction_type` is `"v"` the network output is returned
    unchanged. For any other prediction type the output is treated as a
    predicted clean state and converted to a velocity.
    """
    network_output = self._predict_model_output(
        x_t=x_t,
        t=t,
        condition_embeddings=condition_embeddings,
    )
    if self.prediction_type != "v":
        return self._predicted_clean_state_to_velocity(
            predicted_clean_state=network_output,
            x_t=x_t,
            t=t,
        )
    return network_output
def _predict_guided_velocity(
    self,
    *,
    x_t: Tensor,
    t: Tensor,
    condition_embeddings: Tensor,
    cfg_scale: float | None = None,
) -> Tensor:
    """Blend conditional and unconditional velocities with a guidance scale.

    The unconditional pass reuses `predict_velocity` with an all-zero
    condition of the same shape. The result is
    `unconditional + scale * (conditional - unconditional)`.

    Args:
        x_t: Current state passed to `predict_velocity`.
        t: Time values passed to `predict_velocity`.
        condition_embeddings: Conditioning used for the conditional pass.
        cfg_scale: Guidance scale; `self.cfg_scale` is used when `None`.

    Raises:
        ValueError: If the resolved scale is negative, infinite, or NaN.
    """
    if cfg_scale is None:
        guidance_scale = self.cfg_scale
    else:
        guidance_scale = float(cfg_scale)
    scale_is_usable = math.isfinite(guidance_scale) and guidance_scale >= 0.0
    if not scale_is_usable:
        raise ValueError(
            "cfg_scale must be a finite non-negative number, "
            f"got {guidance_scale!r}"
        )

    def velocity_for(embeddings: Tensor) -> Tensor:
        return self.predict_velocity(
            x_t=x_t,
            t=t,
            condition_embeddings=embeddings,
        )

    conditional = velocity_for(condition_embeddings)
    unconditional = velocity_for(torch.zeros_like(condition_embeddings))
    return unconditional + guidance_scale * (conditional - unconditional)
def predict(
    self,
    link_latents: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    is_revolute: Tensor,
    is_prismatic: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Sample joint axes and ranges for every active joint.

    For each motion type, joints where both `joint_valid_flag` and the
    matching type flag are set receive a sampled, unnormalized motion state.
    Its first six values are written to the axis output and the next two to
    the range output. All other joint slots stay zero.

    Returns:
        `(revolute_axis, prismatic_axis, revolute_range, prismatic_range)`.
        Axis tensors have a trailing dimension of 6 and range tensors a
        trailing dimension of 2; both are indexed by `joint_valid_flag`'s
        two dimensions.
    """
    parent_latents, child_latents = _gather_joint_link_latents(
        link_latents=link_latents,
        joint_connections=joint_connections,
    )
    batch_size, max_joints = joint_valid_flag.shape
    axes: Dict[str, Tensor] = {}
    ranges: Dict[str, Tensor] = {}
    typed_flags = (("revolute", is_revolute), ("prismatic", is_prismatic))
    for motion_type, type_flag in typed_flags:
        axes[motion_type] = link_latents.new_zeros((batch_size, max_joints, 6))
        ranges[motion_type] = link_latents.new_zeros((batch_size, max_joints, 2))
        selected = joint_valid_flag & type_flag
        if not torch.any(selected):
            # No joint of this type: keep the zero-filled outputs.
            continue
        condition_inputs = _build_joint_motion_condition_inputs(
            parent_latents=parent_latents,
            child_latents=child_latents,
            motion_type=motion_type,
            revolute_embedding=self.revolute_embedding,
            prismatic_embedding=self.prismatic_embedding,
        )
        condition_embeddings = self.condition_projector(condition_inputs)[selected]
        sampled_state = self._sample_motion_state(
            condition_embeddings=condition_embeddings,
        )
        sampled_state = self._unnormalize_motion_state(
            sampled_state,
            motion_type=motion_type,
        )
        axes[motion_type][selected] = sampled_state[..., :6]
        ranges[motion_type][selected] = sampled_state[..., 6:8]
    return (
        axes["revolute"],
        axes["prismatic"],
        ranges["revolute"],
        ranges["prismatic"],
    )
class JointDecoderOverParametrized(nn.Module):
    """Decodes per-query over-parameterized joint supervision targets.

    Each query is decoded twice by the same MLP, once with the learned
    revolute type embedding and once with the prismatic one.
    """

    def __init__(
        self,
        model_dim: int,
        output_dim: int = 9,
        bias: bool = True,
    ):
        """Build the type embeddings and the shared decoder MLP.

        Args:
            model_dim: Feature size of query and link latents.
            output_dim: Number of decoded values per query; must be 9 or 12.
            bias: Whether the decoder's linear layers use a bias term.

        Raises:
            ValueError: If `output_dim` is neither 9 nor 12.
        """
        super().__init__()
        self.model_dim = model_dim
        self.output_dim = int(output_dim)
        if self.output_dim not in {9, 12}:
            raise ValueError(
                f"JointDecoderOverParametrized output_dim must be 9 or 12, got {output_dim}"
            )
        self.revolute_embedding = nn.Parameter(torch.empty(model_dim))
        self.prismatic_embedding = nn.Parameter(torch.empty(model_dim))
        # Input is the concatenation of four model_dim-sized vectors:
        # type embedding, query latent, parent link latent, child link latent.
        self.joint_decoder = _make_silu_mlp(
            input_dim=model_dim * 4,
            hidden_dim=model_dim * 4,
            output_dim=self.output_dim,
            bias=bias,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize both joint-type embeddings (truncated normal, std 0.02)."""
        nn.init.trunc_normal_(self.revolute_embedding, std=0.02)
        nn.init.trunc_normal_(self.prismatic_embedding, std=0.02)

    def forward(
        self,
        query_latents: Tensor,
        link_latents: Tensor,
        assigned_link_ids: Tensor,
        joint_connections: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Decode every query under the revolute and the prismatic hypothesis.

        Args:
            query_latents: Per-query latents, `(B, Nq, D)`.
            link_latents: Per-link latents, `(B, L, D)`.
            assigned_link_ids: Link index of every query, `(B, Nq)`. It is
                used as a `gather` index, so it must be an int64 tensor.
            joint_connections: `(B, J, 2)` integer tensor of
                `(parent_link, child_link)` pairs. Rows whose child index is
                negative are treated as padding.

        Returns:
            `(revolute_output, prismatic_output)`, each `(B, Nq, output_dim)`.
        """
        batch_size, num_links = link_latents.shape[:2]
        latent_dim = link_latents.shape[-1]

        # Start with every link as its own parent, then overwrite the entry of
        # each child link that appears in a valid joint.
        parent_link_ids = torch.arange(
            num_links,
            device=assigned_link_ids.device,
            dtype=assigned_link_ids.dtype,
        ).view(1, -1).expand(batch_size, -1).clone()  # [batch_size, num_links]
        child_link_ids = joint_connections[..., 1]
        valid_joint_mask = child_link_ids >= 0
        batch_indices = torch.arange(
            batch_size,
            device=joint_connections.device,
        ).unsqueeze(1).expand_as(child_link_ids)
        # Indexed assignment requires the written values to share the
        # destination dtype, and `joint_connections` may use a different
        # integer dtype than `assigned_link_ids`, so cast explicitly.
        parent_link_ids[
            batch_indices[valid_joint_mask],
            child_link_ids[valid_joint_mask].to(dtype=torch.long),
        ] = joint_connections[..., 0][valid_joint_mask].to(
            dtype=parent_link_ids.dtype
        )

        assigned_link_latents = link_latents.gather(
            dim=1,
            index=assigned_link_ids.unsqueeze(-1).expand(-1, -1, latent_dim),
        )
        assigned_parent_link_ids = parent_link_ids.gather(dim=1, index=assigned_link_ids)
        assigned_parent_link_latents = link_latents.gather(
            dim=1,
            index=assigned_parent_link_ids.unsqueeze(-1).expand(-1, -1, latent_dim),
        )

        # Everything except the type embedding is shared by both passes.
        shared_context = torch.cat(
            (query_latents, assigned_parent_link_latents, assigned_link_latents),
            dim=-1,
        )

        def decode_with(type_embedding: Tensor) -> Tensor:
            type_tokens = type_embedding.to(
                device=assigned_link_latents.device,
                dtype=assigned_link_latents.dtype,
            ).view(1, 1, -1).expand_as(assigned_link_latents)
            return self.joint_decoder(torch.cat((type_tokens, shared_context), dim=-1))

        return (
            decode_with(self.revolute_embedding),
            decode_with(self.prismatic_embedding),
        )
class JointDecoderSingleDirection(nn.Module):
    """Decodes one axis direction per joint from parent/child link latents."""

    # Number of values decoded for every joint direction.
    direction_dim = 3

    def __init__(
        self,
        model_dim: int,
        bias: bool = True,
    ):
        """Create the joint-type embeddings and the direction MLP.

        Args:
            model_dim: Size of the type embeddings; the MLP input is three
                times this size and its hidden layer four times.
            bias: Whether the MLP's linear layers use a bias term.
        """
        super().__init__()
        self.model_dim = model_dim
        self.revolute_embedding = nn.Parameter(torch.empty(model_dim))
        self.prismatic_embedding = nn.Parameter(torch.empty(model_dim))
        self.decoder = _make_silu_mlp(
            input_dim=3 * model_dim,
            hidden_dim=4 * model_dim,
            output_dim=self.direction_dim,
            bias=bias,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize both type embeddings (truncated normal, std 0.02)."""
        for embedding in (self.revolute_embedding, self.prismatic_embedding):
            nn.init.trunc_normal_(embedding, std=0.02)

    def _predict_directions(
        self,
        *,
        parent_latents: Tensor,
        child_latents: Tensor,
        active_mask: Tensor,
        motion_type: str,
    ) -> Tensor:
        """Decode directions for one motion type and zero the inactive joints."""
        decoder_inputs = _build_joint_motion_condition_inputs(
            parent_latents=parent_latents,
            child_latents=child_latents,
            motion_type=motion_type,
            revolute_embedding=self.revolute_embedding,
            prismatic_embedding=self.prismatic_embedding,
        )
        inactive_mask = ~active_mask.unsqueeze(-1)
        return self.decoder(decoder_inputs).masked_fill(inactive_mask, 0)

    def predict(
        self,
        *,
        link_latents: Tensor,
        joint_connections: Tensor,
        joint_valid_flag: Tensor,
        is_revolute: Tensor,
        is_prismatic: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Return `(revolute_directions, prismatic_directions)`.

        A joint contributes to an output only when `joint_valid_flag` and the
        matching type flag are both set; every other row is zero.
        """
        parent_latents, child_latents = _gather_joint_link_latents(
            link_latents=link_latents,
            joint_connections=joint_connections,
        )
        revolute_directions = self._predict_directions(
            parent_latents=parent_latents,
            child_latents=child_latents,
            active_mask=joint_valid_flag & is_revolute,
            motion_type="revolute",
        )
        prismatic_directions = self._predict_directions(
            parent_latents=parent_latents,
            child_latents=child_latents,
            active_mask=joint_valid_flag & is_prismatic,
            motion_type="prismatic",
        )
        return revolute_directions, prismatic_directions
| class Particulate2Encoder(nn.Module): | |
| """Encoder backbone for articulated point-cloud understanding.""" | |
def __init__(
    self,
    model_dim: int = 768,
    num_heads: int = 12,
    head_dim: Optional[int] = None,
    num_shape_latents: int = 256,
    num_attention_blocks: int = 6,
    coordinate_num_frequencies: int = 32,
    normal_num_frequencies: int = 32,
    coordinate_max_frequency: float = 64.0,
    normal_max_frequency: float = 16.0,
    link_text_feature_dim: int = 768,
    embedding_hidden_dim: Optional[int] = None,
    attn_dropout: float = 0.0,
    proj_dropout: float = 0.0,
    ffn_multiplier: float = 4.0,
    ffn_dropout: float = 0.0,
    qk_norm: bool = True,
    norm_eps: float = 1e-6,
    clip_model_name: str = "openai/clip-vit-large-patch14",
    compute_link_text_embeddings_on_the_fly: bool = False,
    use_text_conditioning: bool = True,
    clip_text_batch_size: int = 64,
    dropout_all_normals_prob: float = 0.3,
    dropout_all_point_prompts_prob: float = 0.0,
    dropout_individual_point_prompt_prob: float = 0.0,
    dropout_individual_text_conditioning_prob: float = 0.0,
    use_pretrained_features_shape: bool = False,
    use_pretrained_features_query: bool = False,
    use_pretrained_features_point_prompt: bool = False,
    pretrained_semantic_point_feature_dim: int = PARTFIELD_FEATURE_DIM,
):
    """Build the embedders, projectors, and attention stacks of the encoder.

    Args:
        model_dim: Width of every latent token.
        num_heads, head_dim, attn_dropout, proj_dropout, ffn_multiplier,
            ffn_dropout, qk_norm: Forwarded unchanged to every attention
            block and to the shape encoder.
        norm_eps: Epsilon forwarded to the attention blocks and used by all
            RMSNorm layers created here.
        num_shape_latents: Forwarded to `ShapeLatentEncoder`.
        num_attention_blocks: Number of blocks in each of the four stacks.
        coordinate_num_frequencies, coordinate_max_frequency: Settings of the
            coordinate `FrequencyMLPEmbedder`.
        normal_num_frequencies, normal_max_frequency: Settings of the normal
            `FrequencyMLPEmbedder`.
        link_text_feature_dim: Input size of the link text projector.
        embedding_hidden_dim: Hidden size of the embedders and projectors;
            the projectors fall back to `model_dim` when it is falsy.
        clip_model_name, compute_link_text_embeddings_on_the_fly,
            clip_text_batch_size: Stored as attributes for later use.
        use_text_conditioning: Whether link text features are used.
        dropout_all_normals_prob, dropout_all_point_prompts_prob,
            dropout_individual_point_prompt_prob,
            dropout_individual_text_conditioning_prob: Dropout
            probabilities; each must lie in `[0, 1]`.
        use_pretrained_features_shape, use_pretrained_features_query,
            use_pretrained_features_point_prompt: When any is true, the
            pretrained-feature norm and projector are created.
        pretrained_semantic_point_feature_dim: Input size of the
            pretrained-feature projector.

    Raises:
        ValueError: If `clip_text_batch_size` is not positive, if any
            dropout probability is outside `[0, 1]`, or if text
            conditioning is disabled while a point-prompt or text dropout
            probability is non-zero.
    """
    super().__init__()
    self.model_dim = model_dim
    self.link_text_feature_dim = link_text_feature_dim
    self.clip_model_name = clip_model_name
    self.compute_link_text_embeddings_on_the_fly = bool(compute_link_text_embeddings_on_the_fly)
    self.use_text_conditioning = bool(use_text_conditioning)
    if int(clip_text_batch_size) <= 0:
        raise ValueError(f"clip_text_batch_size must be positive, got {clip_text_batch_size}")
    self.clip_text_batch_size = int(clip_text_batch_size)
    self.dropout_all_normals_prob = float(dropout_all_normals_prob)
    if not 0.0 <= self.dropout_all_normals_prob <= 1.0:
        raise ValueError(
            "dropout_all_normals_prob must be in [0, 1], "
            f"got {dropout_all_normals_prob}"
        )
    # The two point-prompt probabilities are normalized and range-checked
    # exactly like the other dropout probabilities.
    self.dropout_all_point_prompts_prob = float(dropout_all_point_prompts_prob)
    if not 0.0 <= self.dropout_all_point_prompts_prob <= 1.0:
        raise ValueError(
            "dropout_all_point_prompts_prob must be in [0, 1], "
            f"got {dropout_all_point_prompts_prob}"
        )
    self.dropout_individual_point_prompt_prob = float(
        dropout_individual_point_prompt_prob
    )
    if not 0.0 <= self.dropout_individual_point_prompt_prob <= 1.0:
        raise ValueError(
            "dropout_individual_point_prompt_prob must be in [0, 1], "
            f"got {dropout_individual_point_prompt_prob}"
        )
    self.dropout_individual_text_conditioning_prob = float(
        dropout_individual_text_conditioning_prob
    )
    if not 0.0 <= self.dropout_individual_text_conditioning_prob <= 1.0:
        raise ValueError(
            "dropout_individual_text_conditioning_prob must be in [0, 1], "
            f"got {dropout_individual_text_conditioning_prob}"
        )
    if not self.use_text_conditioning and (
        self.dropout_all_point_prompts_prob != 0.0
        or self.dropout_individual_point_prompt_prob != 0.0
        or self.dropout_individual_text_conditioning_prob != 0.0
    ):
        raise ValueError(
            "use_text_conditioning=False requires "
            "dropout_all_point_prompts_prob=0 and "
            "dropout_individual_point_prompt_prob=0 and "
            "dropout_individual_text_conditioning_prob=0 to avoid ambiguous "
            "text-free vs. point-prompt-dropped training examples"
        )
    self.use_pretrained_features_shape = bool(use_pretrained_features_shape)
    self.use_pretrained_features_query = bool(use_pretrained_features_query)
    self.use_pretrained_features_point_prompt = bool(use_pretrained_features_point_prompt)
    self.pretrained_semantic_point_feature_dim = int(pretrained_semantic_point_feature_dim)
    # Created lazily by `_get_partfield_feature_extractor`.
    self._partfield_feature_extractor: Optional[PartFieldFeatureExtractor] = None
    # Text encoder handles start unset; nothing in this constructor loads them.
    self.text_tokenizer: Any = None
    self.text_model: Optional[nn.Module] = None
    # Keyword arguments shared by every attention block and the shape encoder.
    attn_block_kwargs = {
        "model_dim": model_dim,
        "num_heads": num_heads,
        "head_dim": head_dim,
        "attn_dropout": attn_dropout,
        "proj_dropout": proj_dropout,
        "ffn_multiplier": ffn_multiplier,
        "ffn_dropout": ffn_dropout,
        "qk_norm": qk_norm,
        "norm_eps": norm_eps,
    }
    self.coordinate_embedder = FrequencyMLPEmbedder(
        output_dim=model_dim,
        num_frequencies=coordinate_num_frequencies,
        input_dim=3,
        hidden_dim=embedding_hidden_dim,
        max_frequency=coordinate_max_frequency,
    )
    self.normal_embedder = FrequencyMLPEmbedder(
        output_dim=model_dim,
        num_frequencies=normal_num_frequencies,
        input_dim=3,
        hidden_dim=embedding_hidden_dim,
        max_frequency=normal_max_frequency,
    )
    self.shape_encoder = ShapeLatentEncoder(
        num_shape_latents=num_shape_latents,
        **attn_block_kwargs,
    )
    text_hidden_dim = embedding_hidden_dim or model_dim
    self.link_text_input_norm = nn.RMSNorm(link_text_feature_dim, eps=norm_eps)
    self.link_text_projector = _make_silu_mlp(
        input_dim=link_text_feature_dim,
        hidden_dim=text_hidden_dim,
        output_dim=model_dim,
    )
    if (
        self.use_pretrained_features_shape
        or self.use_pretrained_features_query
        or self.use_pretrained_features_point_prompt
    ):
        self.pretrained_feature_input_norm = nn.RMSNorm(
            self.pretrained_semantic_point_feature_dim,
            eps=norm_eps,
        )
        self.pretrained_feature_projector = _make_silu_mlp(
            input_dim=self.pretrained_semantic_point_feature_dim,
            hidden_dim=text_hidden_dim,
            output_dim=model_dim,
        )
    else:
        # `_embed_point_tokens` raises if pretrained features are supplied
        # while these are None.
        self.pretrained_feature_input_norm = None
        self.pretrained_feature_projector = None
    self.link_to_shape_cross_attn = nn.ModuleList(
        [
            CrossAttentionBlock(**attn_block_kwargs)
            for _ in range(num_attention_blocks)
        ]
    )
    self.link_self_attn = nn.ModuleList(
        [
            SelfAttentionBlock(**attn_block_kwargs)
            for _ in range(num_attention_blocks)
        ]
    )
    self.query_to_shape_cross_attn = nn.ModuleList(
        [
            CrossAttentionBlock(**attn_block_kwargs)
            for _ in range(num_attention_blocks)
        ]
    )
    self.query_to_link_cross_attn = nn.ModuleList(
        [
            CrossAttentionBlock(**attn_block_kwargs)
            for _ in range(num_attention_blocks)
        ]
    )
    self.no_point_prompt_embedding = nn.Parameter(torch.zeros(model_dim))
    self.no_text_conditioning_embedding = nn.Parameter(torch.zeros(model_dim))
    self.link_output_norm = nn.RMSNorm(model_dim, eps=norm_eps)
    self.query_output_norm = nn.RMSNorm(model_dim, eps=norm_eps)
def encode_shape(
    self,
    shape_points: Tensor,
    shape_point_normals: Tensor,
    pretrained_features: Tensor | None = None,
    drop_normal_mask: Tensor | None = None,
) -> Tensor:
    """Embed the shape points and pass the tokens through the shape encoder.

    `pretrained_features` and `drop_normal_mask` are forwarded unchanged to
    `_embed_point_tokens`.
    """
    shape_tokens = self._embed_point_tokens(
        shape_points,
        shape_point_normals,
        pretrained_features=pretrained_features,
        drop_normal_mask=drop_normal_mask,
    )
    shape_latents = self.shape_encoder(shape_tokens)
    return shape_latents
def encode_links(
    self,
    link_point_prompts: Tensor | None,
    link_point_prompt_normals: Tensor | None,
    link_valid_flag: Tensor,
    link_point_prompt_dropout_eligible: Tensor | None = None,
    forced_no_point_prompt_mask: Tensor | None = None,
    drop_normal_mask: Tensor | None = None,
    link_point_prompt_pretrained_features: Tensor | None = None,
    link_text_prompts: Optional[Sequence[Sequence[str]]] = None,
    link_text_embeddings: Tensor | None = None,
) -> Tensor:
    """Build the initial latent of every link slot.

    Valid links receive the sum of their point-prompt features and their
    projected text features; padded slots remain zero.

    With text conditioning disabled, the text branch is fed all-zero
    features. With it enabled, rows selected by the text dropout mask are
    replaced by `no_text_conditioning_embedding`.

    The point-prompt branch is resolved by `_resolve_no_point_prompt_mask`
    and `_resolve_valid_link_point_features`. Links without a usable prompt
    (omitted prompt tensors, dropout, or `forced_no_point_prompt_mask`) are
    flagged in the mask handed to the feature resolver.

    Returns:
        Tensor of shape `(B, L, model_dim)`, where `(B, L)` is the shape of
        `link_valid_flag`.
    """
    batch_size, max_links = link_valid_flag.shape
    no_point_prompt_mask = self._resolve_no_point_prompt_mask(
        link_point_prompts=link_point_prompts,
        link_valid_flag=link_valid_flag,
        link_point_prompt_dropout_eligible=link_point_prompt_dropout_eligible,
        forced_no_point_prompt_mask=forced_no_point_prompt_mask,
    )
    point_features = self._resolve_valid_link_point_features(
        link_point_prompts=link_point_prompts,
        link_point_prompt_normals=link_point_prompt_normals,
        link_valid_flag=link_valid_flag,
        no_point_prompt_mask=no_point_prompt_mask,
        drop_normal_mask=drop_normal_mask,
        link_point_prompt_pretrained_features=link_point_prompt_pretrained_features,
    )

    if not self.use_text_conditioning:
        # One all-zero text feature row per valid link.
        num_valid_links = int(link_valid_flag.sum().item())
        text_features = torch.zeros(
            (num_valid_links, self.link_text_feature_dim),
            device=link_valid_flag.device,
            dtype=self.link_text_projector[0].weight.dtype,
        )
    else:
        text_features = self._resolve_valid_link_text_features(
            link_text_prompts=link_text_prompts,
            link_text_embeddings=link_text_embeddings,
            link_valid_flag=link_valid_flag,
        )
    normalized_text_features = self.link_text_input_norm(text_features)
    projected_text_features = self.link_text_projector(normalized_text_features)

    if self.use_text_conditioning:
        text_dropout_mask = self._sample_text_conditioning_dropout_mask(
            link_valid_flag,
            no_point_prompt_mask=no_point_prompt_mask,
        )
        # Reduce the (B, L) mask to one entry per valid link.
        dropped_rows = text_dropout_mask[link_valid_flag]
        no_text_rows = self.no_text_conditioning_embedding.to(
            device=projected_text_features.device,
            dtype=projected_text_features.dtype,
        ).unsqueeze(0).expand_as(projected_text_features)
        projected_text_features = torch.where(
            dropped_rows.unsqueeze(-1),
            no_text_rows,
            projected_text_features,
        )

    point_features = point_features.to(dtype=projected_text_features.dtype)
    link_latents = projected_text_features.new_zeros(
        (batch_size, max_links, self.model_dim)
    )
    link_latents[link_valid_flag] = point_features + projected_text_features
    return link_latents
| def _run_attention_blocks( | |
| self, | |
| shape_latents: Tensor, | |
| link_latents: Tensor, | |
| query_latents: Tensor, | |
| link_valid_flag: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| for ( | |
| link_to_shape_cross_attn, | |
| link_self_attn, | |
| query_to_shape_cross_attn, | |
| query_to_link_cross_attn, | |
| ) in zip( | |
| self.link_to_shape_cross_attn, | |
| self.link_self_attn, | |
| self.query_to_shape_cross_attn, | |
| self.query_to_link_cross_attn, | |
| strict=True, | |
| ): | |
| link_latents = link_to_shape_cross_attn( | |
| link_latents, | |
| shape_latents, | |
| query_mask=link_valid_flag, | |
| ) | |
| link_latents = link_self_attn(link_latents, mask=link_valid_flag) | |
| query_latents = query_to_shape_cross_attn(query_latents, shape_latents) | |
| query_latents = query_to_link_cross_attn( | |
| query_latents, | |
| link_latents, | |
| context_mask=link_valid_flag, | |
| ) | |
| # Residual blocks can write into padded rows unless they are zeroed again. | |
| link_latents = self.link_output_norm(link_latents) | |
| link_latents = link_latents.masked_fill(~link_valid_flag.unsqueeze(-1), 0) | |
| query_latents = self.query_output_norm(query_latents) | |
| return link_latents, query_latents | |
def forward(
    self,
    shape_points: Tensor | None = None,
    shape_point_normals: Tensor | None = None,
    query_points: Tensor | None = None,
    query_point_normals: Tensor | None = None,
    link_point_prompts: Tensor | None = None,
    link_point_prompt_normals: Tensor | None = None,
    link_valid_flag: Tensor | None = None,
    link_point_prompt_dropout_eligible: Tensor | None = None,
    link_text_prompts: Sequence[Sequence[str]] | None = None,
    link_text_embeddings: Tensor | None = None,
) -> Dict[str, Any]:
    """Encode shape, link, and query inputs into latent tokens.

    Args:
        shape_points: Shape points, `(B, Ns, 3)`.
        shape_point_normals: Shape normals, `(B, Ns, 3)`.
        query_points: Query points, `(B, Nq, 3)`.
        query_point_normals: Query normals, `(B, Nq, 3)`.
        link_point_prompts: Optional link prompt points, `(B, L, 3)`. Must
            be given together with `link_point_prompt_normals`; when both
            are omitted, valid link slots use the learned no-prompt
            embedding instead.
        link_point_prompt_normals: Optional link prompt normals aligned with
            `link_point_prompts`, `(B, L, 3)`.
        link_valid_flag: Boolean `(B, L)` tensor marking which padded link
            slots are valid.
        link_point_prompt_dropout_eligible: Optional boolean `(B, L)` tensor
            marking the valid links whose point prompt and prompt normal may
            be dropped during training.
        link_text_prompts: Per-link text prompts as a batch list of string
            lists, where sample `i` has length `sum(link_valid_flag[i])`.
        link_text_embeddings: Optional padded language features,
            `(B, L, D_lang)`.

    Returns:
        Dict with keys `"shape_latents"`, `"query_latents"`, and
        `"link_latents"`.

    Raises:
        ValueError: If only one of the two prompt tensors is given, if a
            required shape/query tensor or `link_valid_flag` is missing, or
            if text conditioning is enabled but neither text input is given.
    """
    prompts_missing = link_point_prompts is None
    prompt_normals_missing = link_point_prompt_normals is None
    if prompts_missing != prompt_normals_missing:
        raise ValueError(
            "link_point_prompts and link_point_prompt_normals must both be provided or both be None"
        )
    required_inputs = (
        shape_points,
        shape_point_normals,
        query_points,
        query_point_normals,
        link_valid_flag,
    )
    if any(value is None for value in required_inputs):
        raise ValueError(
            "forward requires shape/query tensors and link_valid_flag"
        )
    has_text_input = link_text_prompts is not None or link_text_embeddings is not None
    if self.use_text_conditioning and not has_text_input:
        raise ValueError(
            "forward requires either link_text_prompts or link_text_embeddings "
            "when use_text_conditioning=True"
        )

    pretrained_features = self._compute_pretrained_point_features(
        shape_points=shape_points,
        query_points=query_points,
        link_point_prompts=link_point_prompts,
    )
    (
        shape_pretrained_features,
        query_pretrained_features,
        link_point_prompt_pretrained_features,
    ) = pretrained_features

    # One per-sample mask is drawn once and shared by the shape, link, and
    # query embeddings.
    drop_normal_mask = self._sample_all_normals_dropout_mask(
        batch_size=int(shape_points.shape[0]),
        device=shape_points.device,
    )
    shape_latents = self.encode_shape(
        shape_points,
        shape_point_normals,
        pretrained_features=shape_pretrained_features,
        drop_normal_mask=drop_normal_mask,
    )
    initial_link_latents = self.encode_links(
        link_point_prompts=link_point_prompts,
        link_point_prompt_normals=link_point_prompt_normals,
        link_valid_flag=link_valid_flag,
        link_point_prompt_dropout_eligible=link_point_prompt_dropout_eligible,
        drop_normal_mask=drop_normal_mask,
        link_point_prompt_pretrained_features=link_point_prompt_pretrained_features,
        link_text_prompts=link_text_prompts,
        link_text_embeddings=link_text_embeddings,
    )
    initial_query_latents = self._embed_point_tokens(
        query_points,
        query_point_normals,
        pretrained_features=query_pretrained_features,
        drop_normal_mask=drop_normal_mask,
    )
    link_latents, query_latents = self._run_attention_blocks(
        shape_latents=shape_latents,
        link_latents=initial_link_latents,
        query_latents=initial_query_latents,
        link_valid_flag=link_valid_flag,
    )
    return {
        "shape_latents": shape_latents,
        "query_latents": query_latents,
        "link_latents": link_latents,
    }
def _embed_point_tokens(
    self,
    points: Tensor,
    normals: Tensor,
    *,
    pretrained_features: Tensor | None = None,
    drop_normal_mask: Tensor | None = None,
) -> Tensor:
    """Embed points as the sum of coordinate, normal, and optional feature tokens.

    Args:
        points: Input to `self.coordinate_embedder`; its first dimension is
            the batch dimension.
        normals: Input to `self.normal_embedder`.
        pretrained_features: Optional extra features. They are normalized,
            projected, and added to the tokens.
        drop_normal_mask: Optional `(B,)` mask. Samples where it is true
            have their normal tokens zeroed.

    Raises:
        ValueError: If `drop_normal_mask` is not a `(B,)` vector, or if
            `pretrained_features` is given while the pretrained feature
            projection is disabled.
    """
    coordinate_tokens = self.coordinate_embedder(points)
    normal_tokens = self.normal_embedder(normals)
    if drop_normal_mask is not None:
        if drop_normal_mask.ndim != 1 or drop_normal_mask.shape[0] != points.shape[0]:
            raise ValueError(
                "drop_normal_mask must have shape (B,), "
                f"got {tuple(drop_normal_mask.shape)} for points {tuple(points.shape)}"
            )
        per_sample_mask = drop_normal_mask.to(
            device=normal_tokens.device,
            dtype=torch.bool,
        )
        # Append singleton axes so the mask broadcasts over each sample.
        trailing_axes = (None,) * (normal_tokens.ndim - 1)
        broadcast_mask = per_sample_mask[(slice(None),) + trailing_axes]
        normal_tokens = normal_tokens.masked_fill(broadcast_mask, 0)
    tokens = coordinate_tokens.to(dtype=normal_tokens.dtype) + normal_tokens
    if pretrained_features is None:
        return tokens

    feature_norm = self.pretrained_feature_input_norm
    feature_projector = self.pretrained_feature_projector
    if feature_norm is None or feature_projector is None:
        raise ValueError(
            "Received pretrained_features but pretrained feature projection is disabled"
        )
    projector_dtype = feature_projector[0].weight.dtype
    projected_features = feature_projector(
        feature_norm(pretrained_features.to(dtype=projector_dtype))
    )
    return tokens.to(dtype=projected_features.dtype) + projected_features
def _get_partfield_feature_extractor(self) -> PartFieldFeatureExtractor:
    """Return the cached feature extractor, creating it on first use."""
    extractor = self._partfield_feature_extractor
    if extractor is None:
        extractor = PartFieldFeatureExtractor()
        self._partfield_feature_extractor = extractor
    return extractor
def _compute_pretrained_point_features(
    self,
    *,
    shape_points: Tensor,
    query_points: Tensor,
    link_point_prompts: Tensor | None = None,
) -> tuple[Tensor | None, Tensor | None, Tensor | None]:
    """Extract pretrained features for shape, query, and prompt points.

    The shape points are always the extractor's encode input. Query and
    prompt points are concatenated along dimension 1 into one decode batch,
    and the decoded features are split back in the same order.

    Returns:
        `(shape_features, query_features, prompt_features)`. The last two
        are `None` for disabled inputs, and all three are `None` when no
        pretrained features are requested.

    Raises:
        ValueError: If prompt features are requested and the batch size of
            `link_point_prompts` differs from that of `shape_points`.
    """
    wants_prompt_features = (
        self.use_pretrained_features_point_prompt
        and link_point_prompts is not None
    )
    if (
        not self.use_pretrained_features_shape
        and not self.use_pretrained_features_query
        and not wants_prompt_features
    ):
        return None, None, None

    segments: list[tuple[str, Tensor]] = []
    if self.use_pretrained_features_query:
        segments.append(("query", query_points))
    if wants_prompt_features:
        if link_point_prompts.shape[0] != shape_points.shape[0]:
            raise ValueError(
                "link_point_prompts batch dimension must match shape_points when "
                "use_pretrained_features_point_prompt=True, "
                f"got {tuple(link_point_prompts.shape)} and {tuple(shape_points.shape)}"
            )
        segments.append(("prompt", link_point_prompts))

    decode_query_points = None
    if segments:
        decode_query_points = torch.cat(
            [segment_points for _, segment_points in segments], dim=1
        )

    extractor = self._get_partfield_feature_extractor()
    shape_features, combined_features = extractor.extract(
        encode_points=shape_points,
        decode_shape_points=shape_points if self.use_pretrained_features_shape else None,
        decode_query_points=decode_query_points,
    )

    features_by_name: dict[str, Tensor] = {}
    if combined_features is not None:
        if len(segments) == 1:
            # A single segment owns the whole decode output; no slicing.
            features_by_name[segments[0][0]] = combined_features
        else:
            start = 0
            for segment_name, segment_points in segments:
                stop = start + segment_points.shape[1]
                features_by_name[segment_name] = combined_features[:, start:stop]
                start = stop
    return (
        shape_features,
        features_by_name.get("query"),
        features_by_name.get("prompt"),
    )
def _sample_point_prompt_dropout_mask(
    self,
    link_valid_flag: Tensor,
    link_point_prompt_dropout_eligible: Tensor | None = None,
) -> Tensor:
    """Sample which links lose their point prompt during training.

    Two mechanisms are combined. With probability
    `dropout_all_point_prompts_prob` a whole sample drops every eligible
    prompt. In the remaining samples each eligible prompt is dropped
    independently with probability `dropout_individual_point_prompt_prob`.
    A link is eligible when it is valid and, if
    `link_point_prompt_dropout_eligible` is given, flagged there as well.

    Returns:
        Mask shaped like `link_valid_flag`. It is all zeros outside training
        or when both probabilities are zero.

    Raises:
        ValueError: If `link_point_prompt_dropout_eligible` has a different
            shape than `link_valid_flag`.
    """
    whole_sample_prob = self.dropout_all_point_prompts_prob
    per_link_prob = self.dropout_individual_point_prompt_prob
    dropout_disabled = whole_sample_prob == 0.0 and per_link_prob == 0.0
    if dropout_disabled or not self.training:
        return torch.zeros_like(link_valid_flag)

    num_samples = link_valid_flag.shape[0]
    device = link_valid_flag.device
    eligible = link_valid_flag
    if link_point_prompt_dropout_eligible is not None:
        if link_point_prompt_dropout_eligible.shape != link_valid_flag.shape:
            raise ValueError(
                "link_point_prompt_dropout_eligible must match link_valid_flag shape, "
                f"got {tuple(link_point_prompt_dropout_eligible.shape)} and "
                f"{tuple(link_valid_flag.shape)}"
            )
        eligible = link_valid_flag & link_point_prompt_dropout_eligible

    dropped = torch.zeros_like(link_valid_flag)
    whole_sample_drop = torch.zeros(num_samples, device=device, dtype=torch.bool)
    if whole_sample_prob > 0.0:
        whole_sample_drop = torch.rand(num_samples, device=device) < whole_sample_prob
        dropped[whole_sample_drop] = eligible[whole_sample_drop]
    if per_link_prob > 0.0:
        per_link_drop = (
            torch.rand(link_valid_flag.shape, device=device) < per_link_prob
        )
        # Samples already dropped as a whole keep their whole-sample result.
        untouched = ~whole_sample_drop
        dropped[untouched] = per_link_drop[untouched] & eligible[untouched]
    return dropped
def _sample_all_normals_dropout_mask(
    self,
    *,
    batch_size: int,
    device: torch.device,
) -> Tensor:
    """Sample a per-sample mask of inputs whose normals should be dropped.

    Returns a boolean `(batch_size,)` tensor. It is all false outside
    training or when `dropout_all_normals_prob` is zero. Otherwise each entry
    is true with probability `dropout_all_normals_prob`.
    """
    if self.training and self.dropout_all_normals_prob != 0.0:
        return torch.rand(batch_size, device=device) < self.dropout_all_normals_prob
    return torch.zeros((batch_size,), device=device, dtype=torch.bool)
def _resolve_no_point_prompt_mask(
    self,
    *,
    link_point_prompts: Tensor | None,
    link_valid_flag: Tensor,
    link_point_prompt_dropout_eligible: Tensor | None = None,
    forced_no_point_prompt_mask: Tensor | None = None,
) -> Tensor:
    """Decide which valid links must go without a point prompt.

    When no prompt tensor is given, every valid link is selected. Otherwise
    the selection comes from `_sample_point_prompt_dropout_mask`. Links in
    `forced_no_point_prompt_mask` are always added, and the result is
    restricted to valid links.

    Raises:
        ValueError: If `forced_no_point_prompt_mask` has a different shape
            than `link_valid_flag`.
    """
    forced_mask = forced_no_point_prompt_mask
    if forced_mask is not None:
        if forced_mask.shape != link_valid_flag.shape:
            raise ValueError(
                "forced_no_point_prompt_mask must match link_valid_flag shape, "
                f"got {tuple(forced_mask.shape)} and "
                f"{tuple(link_valid_flag.shape)}"
            )
        forced_mask = link_valid_flag & forced_mask.to(
            device=link_valid_flag.device,
            dtype=torch.bool,
        )

    if link_point_prompts is not None:
        resolved_mask = self._sample_point_prompt_dropout_mask(
            link_valid_flag,
            link_point_prompt_dropout_eligible=link_point_prompt_dropout_eligible,
        )
    else:
        resolved_mask = link_valid_flag.clone()

    if forced_mask is not None:
        resolved_mask = resolved_mask | forced_mask
    return resolved_mask & link_valid_flag
| def _sample_text_conditioning_dropout_mask( | |
| self, | |
| link_valid_flag: Tensor, | |
| *, | |
| no_point_prompt_mask: Tensor | None = None, | |
| ) -> Tensor: | |
| if ( | |
| not self.training | |
| or self.dropout_individual_text_conditioning_prob == 0.0 | |
| ): | |
| return torch.zeros_like(link_valid_flag) | |
| if no_point_prompt_mask is None: | |
| text_dropout_eligible_mask = link_valid_flag | |
| else: | |
| if no_point_prompt_mask.shape != link_valid_flag.shape: | |
| raise ValueError( | |
| "no_point_prompt_mask must match link_valid_flag shape, " | |
| f"got {tuple(no_point_prompt_mask.shape)} and " | |
| f"{tuple(link_valid_flag.shape)}" | |
| ) | |
| text_dropout_eligible_mask = link_valid_flag & ~no_point_prompt_mask.to( | |
| device=link_valid_flag.device, | |
| dtype=torch.bool, | |
| ) | |
| return ( | |
| torch.rand(link_valid_flag.shape, device=link_valid_flag.device) | |
| < self.dropout_individual_text_conditioning_prob | |
| ) & text_dropout_eligible_mask | |
| def _resolve_valid_link_point_features( | |
| self, | |
| link_point_prompts: Tensor | None, | |
| link_point_prompt_normals: Tensor | None, | |
| link_valid_flag: Tensor, | |
| no_point_prompt_mask: Tensor | None = None, | |
| drop_normal_mask: Tensor | None = None, | |
| link_point_prompt_pretrained_features: Tensor | None = None, | |
| ) -> Tensor: | |
| if no_point_prompt_mask is not None: | |
| if no_point_prompt_mask.shape != link_valid_flag.shape: | |
| raise ValueError( | |
| "no_point_prompt_mask must match link_valid_flag shape, " | |
| f"got {tuple(no_point_prompt_mask.shape)} and " | |
| f"{tuple(link_valid_flag.shape)}" | |
| ) | |
| no_point_prompt_mask = ( | |
| no_point_prompt_mask.to( | |
| device=link_valid_flag.device, | |
| dtype=torch.bool, | |
| ) | |
| & link_valid_flag | |
| ) | |
| if link_point_prompts is None: | |
| if link_point_prompt_pretrained_features is not None: | |
| raise ValueError( | |
| "link_point_prompt_pretrained_features requires link_point_prompts" | |
| ) | |
| no_prompt_embedding = self.no_point_prompt_embedding.to( | |
| device=link_valid_flag.device, | |
| dtype=self.no_point_prompt_embedding.dtype, | |
| ) | |
| return no_prompt_embedding.view(1, 1, -1).expand(*link_valid_flag.shape, -1)[link_valid_flag] | |
| point_features = self._embed_point_tokens( | |
| link_point_prompts, | |
| link_point_prompt_normals, | |
| pretrained_features=link_point_prompt_pretrained_features, | |
| drop_normal_mask=drop_normal_mask, | |
| )[link_valid_flag] | |
| if no_point_prompt_mask is None: | |
| no_point_prompt_mask = torch.zeros_like(link_valid_flag) | |
| dropped_valid_links = no_point_prompt_mask[link_valid_flag] | |
| return torch.where( | |
| dropped_valid_links.unsqueeze(-1), | |
| self.no_point_prompt_embedding.to( | |
| device=point_features.device, | |
| dtype=point_features.dtype, | |
| ).unsqueeze(0).expand_as(point_features), | |
| point_features, | |
| ) | |
| def _resolve_valid_link_text_features( | |
| self, | |
| link_text_prompts: Optional[Sequence[Sequence[str]]], | |
| link_text_embeddings: Tensor | None, | |
| link_valid_flag: Tensor, | |
| ) -> Tensor: | |
| """Returns text features flattened in the same order as `link_valid_flag`.""" | |
| if link_text_embeddings is not None: | |
| return link_text_embeddings[link_valid_flag].to( | |
| dtype=self.link_text_projector[0].weight.dtype | |
| ) | |
| if link_text_prompts is None: | |
| flattened_prompts = [""] * int(link_valid_flag.sum().item()) | |
| else: | |
| flattened_prompts = list(chain.from_iterable(link_text_prompts)) | |
| return self._encode_link_text_prompts( | |
| flattened_prompts, | |
| dtype=self.link_text_projector[0].weight.dtype, | |
| ) | |
| def _ensure_text_model_loaded(self) -> None: | |
| if self.text_model is not None and self.text_tokenizer is not None: | |
| return | |
| cache_dir = os.environ.get("HF_HOME") | |
| self.text_tokenizer, self.text_model = load_clip_text_encoder( | |
| self.clip_model_name, | |
| device=self.link_text_projector[0].weight.device, | |
| expected_embedding_dim=self.link_text_feature_dim, | |
| cache_dir=cache_dir, | |
| ) | |
| def _encode_link_text_prompts( | |
| self, | |
| prompts: Sequence[str], | |
| dtype: torch.dtype, | |
| ) -> Tensor: | |
| if len(prompts) == 0: | |
| return torch.zeros( | |
| (0, self.link_text_feature_dim), | |
| device=self.link_text_projector[0].weight.device, | |
| dtype=dtype, | |
| ) | |
| if not self.compute_link_text_embeddings_on_the_fly: | |
| raise ValueError( | |
| "Missing link_text_embeddings but compute_link_text_embeddings_on_the_fly=False. " | |
| "Provide dataset-side text embeddings or enable on-the-fly CLIP text encoding." | |
| ) | |
| self._ensure_text_model_loaded() | |
| return encode_clip_text_prompts( | |
| prompts, | |
| tokenizer=self.text_tokenizer, | |
| text_model=self.text_model, | |
| batch_size=self.clip_text_batch_size, | |
| output_device=self.link_text_projector[0].weight.device, | |
| output_dtype=dtype, | |
| ) | |
| class Particulate2ArticulationModel(nn.Module): | |
| """Encoder, decoders, and training losses for articulation prediction.""" | |
def __init__(
    self,
    encoder: Optional[Particulate2Encoder] = None,
    *,
    segmentation_decode_type: str = "mlp",
    segmentation_query_chunk_size: Optional[int] = None,
    joint_decode_type: str = "overparameterization",
    joint_fm_hidden_dim: Optional[int] = None,
    joint_fm_prediction_type: str = "v",
    joint_fm_time_embedding_dim: int = 256,
    joint_fm_inference_steps: int = 100,
    joint_fm_time_scale: float = 1000.0,
    joint_fm_sigma_min: float = 0.0,
    joint_fm_rescale_t: float = 1.0,
    joint_fm_cfg_scale: float = 1.0,
    revolute_joint_fm_state_mean: Optional[Sequence[float]] = None,
    revolute_joint_fm_state_std: Optional[Sequence[float]] = None,
    prismatic_joint_fm_state_mean: Optional[Sequence[float]] = None,
    prismatic_joint_fm_state_std: Optional[Sequence[float]] = None,
    joint_fm_training_time_mean: float = 0.0,
    joint_fm_training_time_std: float = 1.0,
    use_ancestor_context_for_segmentation: bool = False,
    ancestor_context_decay: float = 0.5,
    segmentation_bias: bool = True,
    joint_decoder_bias: bool = True,
    segmentation_cross_entropy_weight: float = 1.0,
    segmentation_dice_weight: float = 1.0,
    revolute_joint_axis_l1_weight: float = 1.0,
    prismatic_joint_axis_l1_weight: float = 1.0,
    revolute_joint_range_l1_weight: float = 1.0,
    prismatic_joint_range_l1_weight: float = 1.0,
    revolute_joint_fm_weight: float = 1.0,
    prismatic_joint_fm_weight: float = 1.0,
    revolute_overparam_axis_l1_weight: float = 1.0,
    revolute_overparam_point_l1_weight: float = 1.0,
    prismatic_overparam_axis_l1_weight: float = 1.0,
    prismatic_overparam_point_l1_weight: float = 1.0,
    revolute_overparam_direction_weight: float = 1.0,
    prismatic_overparam_direction_weight: float = 1.0,
    overparam_closest_axis_space: str = "world",
    dice_smoothing: float = 1e-6,
    **encoder_kwargs: Any,
):
    """Build the encoder, the segmentation decoder and the joint decoder.

    Args:
        encoder: Ready-made encoder instance. Mutually exclusive with
            ``encoder_kwargs``; when omitted, a ``Particulate2Encoder`` is
            constructed from ``encoder_kwargs``.
        segmentation_decode_type, segmentation_query_chunk_size,
        segmentation_bias: Forwarded to ``SegmentationDecoder``.
        joint_decode_type: Decode mode; aliases are resolved by
            ``_normalize_joint_decode_type`` and the result selects which
            joint decoder class is instantiated.
        joint_fm_* / *_joint_fm_state_mean / *_joint_fm_state_std: Forwarded
            to ``JointDecoderPlainFlowMatching`` in ``"plain+fm"`` mode only.
            ``joint_fm_training_time_mean`` / ``_std`` are stored on ``self``
            instead.
        use_ancestor_context_for_segmentation: When true, an ancestor-context
            projector and gate MLP are created; otherwise both are ``None``.
        ancestor_context_decay: Must be finite and within ``[0, 1]``.
        joint_decoder_bias: Forwarded as ``bias`` to the joint decoder(s).
        *_weight, dice_smoothing: Stored as floats on ``self``; presumably
            consumed by the loss code outside this view - TODO confirm.
        overparam_closest_axis_space: Normalized by
            ``_normalize_overparam_closest_axis_space``; must resolve to
            ``"world"`` or ``"part_aabb"``.
        **encoder_kwargs: Constructor arguments for ``Particulate2Encoder``.

    Raises:
        ValueError: if both ``encoder`` and ``encoder_kwargs`` are given, if
            ``joint_decode_type`` or ``overparam_closest_axis_space`` is not
            supported, or if ``ancestor_context_decay`` is non-finite or
            outside ``[0, 1]``.
    """
    super().__init__()
    if encoder is not None and encoder_kwargs:
        raise ValueError("Pass either an encoder instance or encoder kwargs, not both")
    self.encoder = encoder if encoder is not None else Particulate2Encoder(**encoder_kwargs)
    # Resolve aliases first so every later comparison uses the canonical name.
    joint_decode_type = _normalize_joint_decode_type(joint_decode_type)
    if joint_decode_type not in {*_PLAIN_JOINT_DECODE_TYPES, *_OVERPARAM_JOINT_DECODE_TYPES}:
        raise ValueError(
            "joint_decode_type must be 'plain', 'plain+fm', 'overparametrized', "
            "'overparameterized', 'overparameterization', 'overparam+dir', or "
            "'overparam+singledir', "
            f"got {joint_decode_type!r}"
        )
    self.joint_decode_type = joint_decode_type
    # Convenience flags derived from the canonical decode type.
    self.plain_flow_matching_enabled = self.joint_decode_type == "plain+fm"
    self.overparam_predicts_query_axis_direction = self.joint_decode_type == "overparam+dir"
    self.overparam_predicts_single_axis_direction = (
        self.joint_decode_type == "overparam+singledir"
    )
    self.overparam_uses_axis_direction = (
        self.overparam_predicts_query_axis_direction
        or self.overparam_predicts_single_axis_direction
    )
    overparam_closest_axis_space = _normalize_overparam_closest_axis_space(
        overparam_closest_axis_space
    )
    if overparam_closest_axis_space not in {"world", "part_aabb"}:
        raise ValueError(
            "overparam_closest_axis_space must be 'world', 'sample', 'part_aabb', or "
            f"'local_aabb', got {overparam_closest_axis_space!r}"
        )
    self.overparam_closest_axis_space = overparam_closest_axis_space
    self.overparam_closest_axis_uses_part_aabb = (
        self.overparam_closest_axis_space == "part_aabb"
    )
    self.use_ancestor_context_for_segmentation = bool(
        use_ancestor_context_for_segmentation
    )
    self.ancestor_context_decay = float(ancestor_context_decay)
    # Reject NaN/inf before the range check, which NaN would silently pass.
    if not math.isfinite(self.ancestor_context_decay):
        raise ValueError(
            "ancestor_context_decay must be finite, "
            f"got {ancestor_context_decay!r}"
        )
    if self.ancestor_context_decay < 0.0 or self.ancestor_context_decay > 1.0:
        raise ValueError(
            "ancestor_context_decay must be in [0, 1], "
            f"got {ancestor_context_decay!r}"
        )
    self.segmentation_decoder = SegmentationDecoder(
        model_dim=self.encoder.model_dim,
        decode_type=segmentation_decode_type,
        bias=segmentation_bias,
        query_chunk_size=segmentation_query_chunk_size,
    )
    # Only populated in 'overparam+singledir' mode (see below).
    self.joint_direction_decoder: JointDecoderSingleDirection | None = None
    if self.joint_decode_type == "plain":
        self.joint_decoder = JointDecoderPlain(
            model_dim=self.encoder.model_dim,
            bias=joint_decoder_bias,
        )
    elif self.plain_flow_matching_enabled:
        self.joint_decoder = JointDecoderPlainFlowMatching(
            model_dim=self.encoder.model_dim,
            hidden_dim=joint_fm_hidden_dim,
            prediction_type=joint_fm_prediction_type,
            time_embedding_dim=joint_fm_time_embedding_dim,
            inference_steps=joint_fm_inference_steps,
            time_scale=joint_fm_time_scale,
            sigma_min=joint_fm_sigma_min,
            rescale_t=joint_fm_rescale_t,
            cfg_scale=joint_fm_cfg_scale,
            revolute_state_mean=revolute_joint_fm_state_mean,
            revolute_state_std=revolute_joint_fm_state_std,
            prismatic_state_mean=prismatic_joint_fm_state_mean,
            prismatic_state_std=prismatic_joint_fm_state_std,
            bias=joint_decoder_bias,
        )
    else:
        # Over-parameterized modes: 12 output channels when a per-query axis
        # direction is predicted alongside the 9 point-target channels.
        self.joint_decoder = JointDecoderOverParametrized(
            model_dim=self.encoder.model_dim,
            output_dim=12 if self.overparam_predicts_query_axis_direction else 9,
            bias=joint_decoder_bias,
        )
        if self.overparam_predicts_single_axis_direction:
            self.joint_direction_decoder = JointDecoderSingleDirection(
                model_dim=self.encoder.model_dim,
                bias=joint_decoder_bias,
            )
    if self.use_ancestor_context_for_segmentation:
        self.ancestor_context_projector = nn.Linear(
            self.encoder.model_dim,
            self.encoder.model_dim,
            bias=False,
        )
        # The gate takes a 2 * model_dim input and emits a single scalar.
        self.ancestor_context_gate = _make_silu_mlp(
            input_dim=self.encoder.model_dim * 2,
            hidden_dim=self.encoder.model_dim,
            output_dim=1,
        )
    else:
        self.ancestor_context_projector = None
        self.ancestor_context_gate = None
    # Scalar hyper-parameters, coerced to plain Python floats.
    self.segmentation_cross_entropy_weight = float(segmentation_cross_entropy_weight)
    self.segmentation_dice_weight = float(segmentation_dice_weight)
    self.revolute_joint_axis_l1_weight = float(revolute_joint_axis_l1_weight)
    self.prismatic_joint_axis_l1_weight = float(prismatic_joint_axis_l1_weight)
    self.revolute_joint_range_l1_weight = float(revolute_joint_range_l1_weight)
    self.prismatic_joint_range_l1_weight = float(prismatic_joint_range_l1_weight)
    self.revolute_joint_fm_weight = float(revolute_joint_fm_weight)
    self.prismatic_joint_fm_weight = float(prismatic_joint_fm_weight)
    self.revolute_overparam_axis_l1_weight = float(revolute_overparam_axis_l1_weight)
    self.revolute_overparam_point_l1_weight = float(revolute_overparam_point_l1_weight)
    self.prismatic_overparam_axis_l1_weight = float(prismatic_overparam_axis_l1_weight)
    self.prismatic_overparam_point_l1_weight = float(prismatic_overparam_point_l1_weight)
    self.revolute_overparam_direction_weight = float(revolute_overparam_direction_weight)
    self.prismatic_overparam_direction_weight = float(prismatic_overparam_direction_weight)
    self.joint_fm_training_time_mean = float(joint_fm_training_time_mean)
    self.joint_fm_training_time_std = float(joint_fm_training_time_std)
    self.dice_smoothing = float(dice_smoothing)
def decode_segmentation(
    self,
    query_latents: Tensor,
    link_latents: Tensor,
    link_valid_flag: Tensor,
) -> Tensor:
    """Run the segmentation decoder on query and link latents.

    The three arguments are forwarded by keyword to
    ``self.segmentation_decoder`` and its result is returned unchanged.
    """
    decoder_inputs = {
        "query_latents": query_latents,
        "link_latents": link_latents,
        "link_valid_flag": link_valid_flag,
    }
    return self.segmentation_decoder(**decoder_inputs)
def decode_joint_parameters(
    self,
    *,
    link_latents: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    is_revolute: Tensor,
    is_prismatic: Tensor,
    query_latents: Tensor | None = None,
    query_points: Tensor | None = None,
    assigned_link_ids: Tensor | None = None,
    decoded_motion_points: Tuple[Tensor, Tensor] | None = None,
    decoded_axis_directions: Tuple[Tensor, Tensor] | None = None,
    decoded_motion_points_are_world: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Decode joint axes and ranges using the configured decode mode.

    In the ``"plain"`` and ``"plain+fm"`` modes the call is delegated to
    ``self.joint_decoder.predict`` with the link-level inputs only. In the
    over-parameterized modes, query-wise motion points are decoded (unless
    supplied), converted to world coordinates (unless already flagged as
    world), and then reduced to per-joint parameters by
    ``_recover_overparam_joint_parameters`` with autocast disabled.

    Raises:
        ValueError: in over-parameterized mode when ``query_points`` or
            ``assigned_link_ids`` is missing, or when neither
            ``decoded_motion_points`` nor ``query_latents`` is provided.
    """
    # Both plain variants expose the same `predict` interface.
    if self.joint_decode_type == "plain" or self.plain_flow_matching_enabled:
        return self.joint_decoder.predict(
            link_latents=link_latents,
            joint_connections=joint_connections,
            joint_valid_flag=joint_valid_flag,
            is_revolute=is_revolute,
            is_prismatic=is_prismatic,
        )

    if query_points is None or assigned_link_ids is None:
        raise ValueError(
            "over-parameterized joint decoding requires query_points and assigned_link_ids"
        )

    if decoded_motion_points is None:
        if query_latents is None:
            raise ValueError(
                "over-parameterized joint decoding requires query_latents when decoded_motion_points are not provided"
            )
        decoded_motion_points = self._decode_joint_motion_points(
            query_latents=query_latents,
            link_latents=link_latents,
            assigned_link_ids=assigned_link_ids,
            joint_connections=joint_connections,
        )

    needs_single_direction = (
        self.overparam_predicts_single_axis_direction
        and decoded_axis_directions is None
    )
    if needs_single_direction:
        decoded_axis_directions = self._decode_joint_axis_directions(
            link_latents=link_latents,
            joint_connections=joint_connections,
            joint_valid_flag=joint_valid_flag,
            is_revolute=is_revolute,
            is_prismatic=is_prismatic,
        )

    if not decoded_motion_points_are_world:

        def _to_world(motion_points: Tensor) -> Tensor:
            return self._convert_overparam_motion_points_to_world_coordinates(
                motion_points=motion_points,
                query_points=query_points,
                assigned_link_ids=assigned_link_ids,
            )

        decoded_motion_points = (
            _to_world(decoded_motion_points[0]),
            _to_world(decoded_motion_points[1]),
        )

    # The recovery step runs with autocast disabled; fall back to a no-op
    # context when autocast cannot be constructed for this device type.
    try:
        precision_guard = torch.autocast(
            device_type=query_points.device.type, enabled=False
        )
    except (RuntimeError, TypeError, ValueError):
        precision_guard = nullcontext()

    with precision_guard:
        if decoded_axis_directions is None:
            revolute_directions = None
            prismatic_directions = None
        else:
            revolute_directions = decoded_axis_directions[0]
            prismatic_directions = decoded_axis_directions[1]
        return self._recover_overparam_joint_parameters(
            query_points=query_points,
            assigned_link_ids=assigned_link_ids,
            joint_connections=joint_connections,
            joint_valid_flag=joint_valid_flag,
            is_revolute=is_revolute,
            is_prismatic=is_prismatic,
            revolute_motion_points=decoded_motion_points[0],
            prismatic_motion_points=decoded_motion_points[1],
            revolute_axis_directions=revolute_directions,
            prismatic_axis_directions=prismatic_directions,
        )
def _decode_joint_motion_points(
    self,
    *,
    query_latents: Tensor,
    link_latents: Tensor,
    assigned_link_ids: Tensor,
    joint_connections: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Run the over-parameterized joint decoder to get query-wise motion points.

    Raises:
        ValueError: if the model is not configured with one of the
            over-parameterized joint decode types.
    """
    is_overparam_mode = self.joint_decode_type in _OVERPARAM_JOINT_DECODE_TYPES
    if not is_overparam_mode:
        raise ValueError(
            "_decode_joint_motion_points is only available in over-parameterized joint decoding mode"
        )
    decoder_inputs = {
        "query_latents": query_latents,
        "link_latents": link_latents,
        "assigned_link_ids": assigned_link_ids,
        "joint_connections": joint_connections,
    }
    return self.joint_decoder(**decoder_inputs)
| def _decode_joint_axis_directions( | |
| self, | |
| *, | |
| link_latents: Tensor, | |
| joint_connections: Tensor, | |
| joint_valid_flag: Tensor, | |
| is_revolute: Tensor, | |
| is_prismatic: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| if not self.overparam_predicts_single_axis_direction or self.joint_direction_decoder is None: | |
| raise ValueError( | |
| "_decode_joint_axis_directions is only available for joint_decode_type='overparam+singledir'" | |
| ) | |
| return self.joint_direction_decoder.predict( | |
| link_latents=link_latents, | |
| joint_connections=joint_connections, | |
| joint_valid_flag=joint_valid_flag, | |
| is_revolute=is_revolute, | |
| is_prismatic=is_prismatic, | |
| ) | |
| def _compute_query_link_aabb_parameters( | |
| self, | |
| *, | |
| query_points: Tensor, | |
| link_ids: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| """Returns per-query AABB centers and half-extents for the assigned link ID.""" | |
| if query_points.ndim != 3 or query_points.shape[-1] != 3: | |
| raise ValueError( | |
| f"query_points must have shape (B, Q, 3), got {tuple(query_points.shape)}" | |
| ) | |
| if link_ids.shape != query_points.shape[:2]: | |
| raise ValueError( | |
| "link_ids must match query_points batch/query dims, " | |
| f"got {tuple(link_ids.shape)} and {tuple(query_points.shape)}" | |
| ) | |
| query_points = query_points.float() | |
| centers = query_points.new_zeros(query_points.shape) | |
| half_extents = query_points.new_ones(query_points.shape) | |
| for batch_idx in range(query_points.shape[0]): | |
| batch_link_ids = link_ids[batch_idx] | |
| valid_mask = batch_link_ids >= 0 | |
| if not bool(valid_mask.any().item()): | |
| continue | |
| unique_link_ids = torch.unique(batch_link_ids[valid_mask]) | |
| for link_id in unique_link_ids.tolist(): | |
| query_mask = batch_link_ids == int(link_id) | |
| link_query_points = query_points[batch_idx][query_mask] | |
| min_corner = link_query_points.min(dim=0).values | |
| max_corner = link_query_points.max(dim=0).values | |
| centers[batch_idx][query_mask] = 0.5 * (min_corner + max_corner) | |
| half_extents[batch_idx][query_mask] = ( | |
| 0.5 * (max_corner - min_corner) | |
| ).clamp_min(_OVERPARAM_AXIS_AABB_HALF_EXTENT_MIN) | |
| return centers, half_extents | |
| def _denormalize_overparam_axis_points( | |
| self, | |
| *, | |
| axis_points: Tensor, | |
| query_points: Tensor, | |
| link_ids: Tensor, | |
| ) -> Tensor: | |
| """Converts normalized closest-axis points back into sample/world coordinates.""" | |
| if not self.overparam_closest_axis_uses_part_aabb: | |
| return axis_points | |
| centers, half_extents = self._compute_query_link_aabb_parameters( | |
| query_points=query_points, | |
| link_ids=link_ids, | |
| ) | |
| return axis_points.float() * half_extents + centers | |
| def _convert_overparam_motion_points_to_world_coordinates( | |
| self, | |
| *, | |
| motion_points: Tensor, | |
| query_points: Tensor, | |
| assigned_link_ids: Tensor, | |
| ) -> Tensor: | |
| """Returns motion-point predictions in sample/world coordinates.""" | |
| if not self.overparam_closest_axis_uses_part_aabb: | |
| return motion_points | |
| closest_axis_points = self._denormalize_overparam_axis_points( | |
| axis_points=motion_points[..., :3], | |
| query_points=query_points, | |
| link_ids=assigned_link_ids, | |
| ) | |
| return torch.cat((closest_axis_points, motion_points[..., 3:]), dim=-1) | |
| def _fit_revolute_joint_parameters( | |
| self, | |
| query_points: Tensor, | |
| revolute_motion_points: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| """Fits one revolute axis and low/high angles from query-wise motion targets.""" | |
| # This solver is small and numerically sensitive; keep it in fp32 even | |
| # when the surrounding network runs under AMP. | |
| query_points = query_points.float() | |
| revolute_motion_points = revolute_motion_points.float() | |
| if query_points.numel() == 0: | |
| zero_axis = revolute_motion_points.new_zeros(6) | |
| zero_range = revolute_motion_points.new_zeros(2) | |
| return zero_axis, zero_range | |
| closest_axis_points = revolute_motion_points[..., :3] | |
| low_points = revolute_motion_points[..., 3:6] | |
| high_points = revolute_motion_points[..., 6:9] | |
| direction_hint = ( | |
| torch.linalg.cross( | |
| query_points - closest_axis_points, | |
| low_points - closest_axis_points, | |
| dim=-1, | |
| ) | |
| + torch.linalg.cross( | |
| query_points - closest_axis_points, | |
| high_points - closest_axis_points, | |
| dim=-1, | |
| ) | |
| + torch.linalg.cross( | |
| low_points - closest_axis_points, | |
| high_points - closest_axis_points, | |
| dim=-1, | |
| ) | |
| ).mean(dim=0) | |
| axis_direction, axis_point = fit_axis_to_closest_points_torch( | |
| query_points, | |
| closest_axis_points, | |
| direction_hint=direction_hint, | |
| ) | |
| revolute_axis = axis_point_to_plucker_torch(axis_direction, axis_point) | |
| low_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| high_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| return revolute_axis, torch.stack((low_limit, high_limit)) | |
| def _aggregate_flip_invariant_axis_direction( | |
| self, | |
| predicted_directions: Tensor, | |
| *, | |
| sign_hint: Tensor | None = None, | |
| ) -> Tensor: | |
| """Returns one averaged unit direction while treating per-query flips as equivalent.""" | |
| predicted_directions = predicted_directions.float() | |
| if predicted_directions.numel() == 0: | |
| return predicted_directions.new_zeros(3) | |
| direction_norms = torch.linalg.vector_norm(predicted_directions, dim=-1) | |
| valid_mask = direction_norms > 1e-8 | |
| if not bool(valid_mask.any().item()): | |
| return predicted_directions.new_zeros(3) | |
| unit_directions = predicted_directions[valid_mask] / direction_norms[valid_mask].unsqueeze(-1) | |
| direction_covariance = unit_directions.transpose(0, 1) @ unit_directions | |
| _, eigenvectors = torch.linalg.eigh(direction_covariance) | |
| anchor_direction = eigenvectors[:, -1] | |
| alignment = torch.sign(unit_directions @ anchor_direction) | |
| alignment = torch.where( | |
| alignment == 0, | |
| torch.ones_like(alignment), | |
| alignment, | |
| ) | |
| aligned_mean_direction = (unit_directions * alignment.unsqueeze(-1)).mean(dim=0) | |
| if float(torch.linalg.vector_norm(aligned_mean_direction).item()) <= 1e-8: | |
| axis_direction = F.normalize(anchor_direction, dim=0, eps=1e-8) | |
| else: | |
| axis_direction = F.normalize(aligned_mean_direction, dim=0, eps=1e-8) | |
| if sign_hint is not None: | |
| sign_hint = sign_hint.float() | |
| if float(torch.linalg.vector_norm(sign_hint).item()) > 1e-8: | |
| if float(torch.dot(axis_direction, sign_hint).item()) < 0.0: | |
| axis_direction = -axis_direction | |
| return axis_direction | |
| def _fit_revolute_joint_parameters_with_direction( | |
| self, | |
| query_points: Tensor, | |
| revolute_motion_points: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| """Fits one revolute axis from predicted axis directions plus point targets.""" | |
| query_points = query_points.float() | |
| revolute_motion_points = revolute_motion_points.float() | |
| if query_points.numel() == 0: | |
| zero_axis = revolute_motion_points.new_zeros(6) | |
| zero_range = revolute_motion_points.new_zeros(2) | |
| return zero_axis, zero_range | |
| closest_axis_points = revolute_motion_points[..., :3] | |
| low_points = revolute_motion_points[..., 3:6] | |
| high_points = revolute_motion_points[..., 6:9] | |
| predicted_axis_directions = revolute_motion_points[..., 9:12] | |
| direction_sign_hint = ( | |
| torch.linalg.cross( | |
| query_points - closest_axis_points, | |
| low_points - closest_axis_points, | |
| dim=-1, | |
| ) | |
| + torch.linalg.cross( | |
| query_points - closest_axis_points, | |
| high_points - closest_axis_points, | |
| dim=-1, | |
| ) | |
| + torch.linalg.cross( | |
| low_points - closest_axis_points, | |
| high_points - closest_axis_points, | |
| dim=-1, | |
| ) | |
| ).mean(dim=0) | |
| axis_direction = self._aggregate_flip_invariant_axis_direction( | |
| predicted_axis_directions, | |
| sign_hint=direction_sign_hint, | |
| ) | |
| if float(torch.linalg.vector_norm(axis_direction).item()) <= 1e-8: | |
| return self._fit_revolute_joint_parameters( | |
| query_points=query_points, | |
| revolute_motion_points=revolute_motion_points[..., :9], | |
| ) | |
| axis_point = torch.quantile(closest_axis_points, 0.5, dim=0) | |
| low_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| high_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| if float(low_limit.item()) > float(high_limit.item()): | |
| axis_direction = -axis_direction | |
| low_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| high_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| revolute_axis = axis_point_to_plucker_torch(axis_direction, axis_point) | |
| return revolute_axis, torch.stack((low_limit, high_limit)) | |
| def _fit_prismatic_joint_parameters( | |
| self, | |
| query_points: Tensor, | |
| prismatic_motion_points: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| """Fits one prismatic axis and low/high displacements from query-wise targets.""" | |
| # This solver is small and numerically sensitive; keep it in fp32 even | |
| # when the surrounding network runs under AMP. | |
| query_points = query_points.float() | |
| prismatic_motion_points = prismatic_motion_points.float() | |
| if query_points.numel() == 0: | |
| zero_axis = prismatic_motion_points.new_zeros(6) | |
| zero_range = prismatic_motion_points.new_zeros(2) | |
| return zero_axis, zero_range | |
| closest_axis_points = prismatic_motion_points[..., :3] | |
| low_points = prismatic_motion_points[..., 3:6] | |
| high_points = prismatic_motion_points[..., 6:9] | |
| direction_hint = (high_points - low_points).mean(dim=0) | |
| axis_direction, axis_point = fit_axis_to_closest_points_torch( | |
| query_points, | |
| closest_axis_points, | |
| direction_hint=direction_hint, | |
| ) | |
| prismatic_axis = axis_point_to_plucker_torch(axis_direction, axis_point) | |
| low_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| ) | |
| high_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| ) | |
| return prismatic_axis, torch.stack((low_limit, high_limit)) | |
| def _fit_prismatic_joint_parameters_with_direction( | |
| self, | |
| query_points: Tensor, | |
| prismatic_motion_points: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| """Fits one prismatic axis from predicted axis directions plus point targets.""" | |
| query_points = query_points.float() | |
| prismatic_motion_points = prismatic_motion_points.float() | |
| if query_points.numel() == 0: | |
| zero_axis = prismatic_motion_points.new_zeros(6) | |
| zero_range = prismatic_motion_points.new_zeros(2) | |
| return zero_axis, zero_range | |
| closest_axis_points = prismatic_motion_points[..., :3] | |
| low_points = prismatic_motion_points[..., 3:6] | |
| high_points = prismatic_motion_points[..., 6:9] | |
| predicted_axis_directions = prismatic_motion_points[..., 9:12] | |
| axis_direction = self._aggregate_flip_invariant_axis_direction( | |
| predicted_axis_directions, | |
| sign_hint=(high_points - low_points).mean(dim=0), | |
| ) | |
| if float(torch.linalg.vector_norm(axis_direction).item()) <= 1e-8: | |
| return self._fit_prismatic_joint_parameters( | |
| query_points=query_points, | |
| prismatic_motion_points=prismatic_motion_points[..., :9], | |
| ) | |
| axis_point = torch.quantile(closest_axis_points, 0.5, dim=0) | |
| low_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| ) | |
| high_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| ) | |
| if float(low_limit.item()) > float(high_limit.item()): | |
| axis_direction = -axis_direction | |
| low_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| ) | |
| high_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| ) | |
| prismatic_axis = axis_point_to_plucker_torch(axis_direction, axis_point) | |
| return prismatic_axis, torch.stack((low_limit, high_limit)) | |
| def _fit_revolute_joint_parameters_with_single_direction( | |
| self, | |
| query_points: Tensor, | |
| revolute_motion_points: Tensor, | |
| axis_direction: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| """Fits one revolute axis from a single predicted axis direction plus point targets.""" | |
| query_points = query_points.float() | |
| revolute_motion_points = revolute_motion_points.float() | |
| axis_direction = axis_direction.float() | |
| if query_points.numel() == 0: | |
| zero_axis = revolute_motion_points.new_zeros(6) | |
| zero_range = revolute_motion_points.new_zeros(2) | |
| return zero_axis, zero_range | |
| if float(torch.linalg.vector_norm(axis_direction).item()) <= 1e-8: | |
| return self._fit_revolute_joint_parameters( | |
| query_points=query_points, | |
| revolute_motion_points=revolute_motion_points, | |
| ) | |
| axis_direction = F.normalize(axis_direction, dim=0, eps=1e-8) | |
| closest_axis_points = revolute_motion_points[..., :3] | |
| low_points = revolute_motion_points[..., 3:6] | |
| high_points = revolute_motion_points[..., 6:9] | |
| axis_point = torch.quantile(closest_axis_points, 0.5, dim=0) | |
| low_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| high_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| if float(low_limit.item()) > float(high_limit.item()): | |
| axis_direction = -axis_direction | |
| low_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| high_limit = estimate_revolute_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| axis_point, | |
| ) | |
| revolute_axis = axis_point_to_plucker_torch(axis_direction, axis_point) | |
| return revolute_axis, torch.stack((low_limit, high_limit)) | |
| def _fit_prismatic_joint_parameters_with_single_direction( | |
| self, | |
| query_points: Tensor, | |
| prismatic_motion_points: Tensor, | |
| axis_direction: Tensor, | |
| ) -> Tuple[Tensor, Tensor]: | |
| """Fits one prismatic axis from a single predicted axis direction plus point targets.""" | |
| query_points = query_points.float() | |
| prismatic_motion_points = prismatic_motion_points.float() | |
| axis_direction = axis_direction.float() | |
| if query_points.numel() == 0: | |
| zero_axis = prismatic_motion_points.new_zeros(6) | |
| zero_range = prismatic_motion_points.new_zeros(2) | |
| return zero_axis, zero_range | |
| if float(torch.linalg.vector_norm(axis_direction).item()) <= 1e-8: | |
| return self._fit_prismatic_joint_parameters( | |
| query_points=query_points, | |
| prismatic_motion_points=prismatic_motion_points, | |
| ) | |
| axis_direction = F.normalize(axis_direction, dim=0, eps=1e-8) | |
| closest_axis_points = prismatic_motion_points[..., :3] | |
| low_points = prismatic_motion_points[..., 3:6] | |
| high_points = prismatic_motion_points[..., 6:9] | |
| axis_point = torch.quantile(closest_axis_points, 0.5, dim=0) | |
| low_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| ) | |
| high_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| ) | |
| if float(low_limit.item()) > float(high_limit.item()): | |
| axis_direction = -axis_direction | |
| low_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| low_points, | |
| axis_direction, | |
| ) | |
| high_limit = estimate_prismatic_limit_torch( | |
| query_points, | |
| high_points, | |
| axis_direction, | |
| ) | |
| prismatic_axis = axis_point_to_plucker_torch(axis_direction, axis_point) | |
| return prismatic_axis, torch.stack((low_limit, high_limit)) | |
| def _recover_overparam_joint_parameters( | |
| self, | |
| *, | |
| query_points: Tensor, | |
| assigned_link_ids: Tensor, | |
| joint_connections: Tensor, | |
| joint_valid_flag: Tensor, | |
| is_revolute: Tensor, | |
| is_prismatic: Tensor, | |
| revolute_motion_points: Tensor, | |
| prismatic_motion_points: Tensor, | |
| revolute_axis_directions: Tensor | None = None, | |
| prismatic_axis_directions: Tensor | None = None, | |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: | |
| """Recovers per-joint parameters by fitting each child link's query-wise targets.""" | |
| query_points = query_points.float() | |
| revolute_motion_points = revolute_motion_points.float() | |
| prismatic_motion_points = prismatic_motion_points.float() | |
| batch_size, max_joints = joint_connections.shape[:2] | |
| revolute_axis = query_points.new_zeros((batch_size, max_joints, 6)) | |
| prismatic_axis = query_points.new_zeros((batch_size, max_joints, 6)) | |
| revolute_range = query_points.new_zeros((batch_size, max_joints, 2)) | |
| prismatic_range = query_points.new_zeros((batch_size, max_joints, 2)) | |
| child_link_ids = joint_connections[..., 1] | |
| for batch_idx in range(batch_size): | |
| for joint_idx in range(max_joints): | |
| query_mask = assigned_link_ids[batch_idx] == child_link_ids[batch_idx, joint_idx] | |
| joint_query_points = query_points[batch_idx][query_mask] | |
| if self.overparam_uses_axis_direction: | |
| if self.joint_decode_type == "overparam+dir": | |
| ( | |
| revolute_axis[batch_idx, joint_idx], | |
| revolute_range[batch_idx, joint_idx], | |
| ) = self._fit_revolute_joint_parameters_with_direction( | |
| joint_query_points, | |
| revolute_motion_points[batch_idx][query_mask], | |
| ) | |
| ( | |
| prismatic_axis[batch_idx, joint_idx], | |
| prismatic_range[batch_idx, joint_idx], | |
| ) = self._fit_prismatic_joint_parameters_with_direction( | |
| joint_query_points, | |
| prismatic_motion_points[batch_idx][query_mask], | |
| ) | |
| else: | |
| if revolute_axis_directions is None or prismatic_axis_directions is None: | |
| raise ValueError( | |
| "overparam+singledir recovery requires per-joint axis directions" | |
| ) | |
| ( | |
| revolute_axis[batch_idx, joint_idx], | |
| revolute_range[batch_idx, joint_idx], | |
| ) = self._fit_revolute_joint_parameters_with_single_direction( | |
| joint_query_points, | |
| revolute_motion_points[batch_idx][query_mask], | |
| revolute_axis_directions[batch_idx, joint_idx], | |
| ) | |
| ( | |
| prismatic_axis[batch_idx, joint_idx], | |
| prismatic_range[batch_idx, joint_idx], | |
| ) = self._fit_prismatic_joint_parameters_with_single_direction( | |
| joint_query_points, | |
| prismatic_motion_points[batch_idx][query_mask], | |
| prismatic_axis_directions[batch_idx, joint_idx], | |
| ) | |
| else: | |
| ( | |
| revolute_axis[batch_idx, joint_idx], | |
| revolute_range[batch_idx, joint_idx], | |
| ) = self._fit_revolute_joint_parameters( | |
| joint_query_points, | |
| revolute_motion_points[batch_idx][query_mask], | |
| ) | |
| ( | |
| prismatic_axis[batch_idx, joint_idx], | |
| prismatic_range[batch_idx, joint_idx], | |
| ) = self._fit_prismatic_joint_parameters( | |
| joint_query_points, | |
| prismatic_motion_points[batch_idx][query_mask], | |
| ) | |
| revolute_mask = (joint_valid_flag & is_revolute).unsqueeze(-1) | |
| prismatic_mask = (joint_valid_flag & is_prismatic).unsqueeze(-1) | |
| return ( | |
| revolute_axis.masked_fill(~revolute_mask, 0), | |
| prismatic_axis.masked_fill(~prismatic_mask, 0), | |
| revolute_range.masked_fill(~revolute_mask[..., :1], 0), | |
| prismatic_range.masked_fill(~prismatic_mask[..., :1], 0), | |
| ) | |
| def _resolve_joint_decoding_link_ids( | |
| self, | |
| *, | |
| segmentation_logits: Tensor, | |
| link_ids: Tensor | None, | |
| ) -> Tensor: | |
| if link_ids is not None: | |
| return link_ids | |
| if self.training: | |
| raise ValueError( | |
| "forward requires link_ids when joint_decode_type uses over-parameterized decoding during training" | |
| ) | |
| return segmentation_logits.argmax(dim=-1) | |
| def _build_parent_index( | |
| self, | |
| *, | |
| joint_connections: Tensor, | |
| joint_valid_flag: Tensor, | |
| num_links: int, | |
| ) -> Tensor: | |
| """Returns one parent-link index per link, using `-1` for roots/padding.""" | |
| device = joint_connections.device | |
| parent_index = torch.full( | |
| (joint_connections.shape[0], num_links), | |
| fill_value=-1, | |
| dtype=joint_connections.dtype, | |
| device=device, | |
| ) | |
| if joint_connections.shape[1] == 0 or num_links == 0: | |
| return parent_index | |
| parent_indices = joint_connections[..., 0] | |
| child_indices = joint_connections[..., 1] | |
| valid_joint_mask = ( | |
| joint_valid_flag | |
| & (parent_indices >= 0) | |
| & (parent_indices < num_links) | |
| & (child_indices >= 0) | |
| & (child_indices < num_links) | |
| ) | |
| if not bool(valid_joint_mask.any().item()): | |
| return parent_index | |
| valid_child_mask = torch.zeros( | |
| (joint_connections.shape[0], num_links), | |
| dtype=torch.bool, | |
| device=device, | |
| ) | |
| valid_child_mask.scatter_( | |
| dim=1, | |
| index=child_indices.clamp_min(0), | |
| src=valid_joint_mask, | |
| ) | |
| parent_values = torch.where(valid_joint_mask, parent_indices, torch.zeros_like(parent_indices)) | |
| parent_index.scatter_( | |
| dim=1, | |
| index=child_indices.clamp_min(0), | |
| src=parent_values, | |
| ) | |
| return parent_index.masked_fill(~valid_child_mask, -1) | |
| def _build_depth_decayed_ancestor_context( | |
| self, | |
| *, | |
| link_latents: Tensor, | |
| joint_connections: Tensor, | |
| joint_valid_flag: Tensor, | |
| ) -> Tensor: | |
| """Returns a depth-decayed ancestor average for each link latent.""" | |
| if not self.use_ancestor_context_for_segmentation: | |
| return link_latents | |
| batch_size, num_links, _ = link_latents.shape | |
| if num_links == 0: | |
| return torch.zeros_like(link_latents) | |
| parent_index = self._build_parent_index( | |
| joint_connections=joint_connections, | |
| joint_valid_flag=joint_valid_flag, | |
| num_links=num_links, | |
| ) | |
| valid_parent_mask = parent_index >= 0 | |
| if self.ancestor_context_decay == 0.0: | |
| parent_latents = link_latents.gather( | |
| dim=1, | |
| index=parent_index.clamp_min(0).unsqueeze(-1).expand_as(link_latents), | |
| ) | |
| return parent_latents * valid_parent_mask.unsqueeze(-1).to(dtype=link_latents.dtype) | |
| if not bool(valid_parent_mask.any().item()): | |
| return torch.zeros_like(link_latents) | |
| solve_dtype = ( | |
| torch.float64 | |
| if link_latents.device.type == "cpu" | |
| else torch.float32 | |
| ) | |
| parent_adjacency = torch.zeros( | |
| (batch_size, num_links, num_links), | |
| device=link_latents.device, | |
| dtype=solve_dtype, | |
| ) | |
| parent_adjacency.scatter_( | |
| dim=2, | |
| index=parent_index.clamp_min(0).unsqueeze(-1), | |
| src=valid_parent_mask.unsqueeze(-1).to(dtype=solve_dtype), | |
| ) | |
| system_matrix = ( | |
| torch.eye(num_links, device=link_latents.device, dtype=solve_dtype).unsqueeze(0) | |
| - self.ancestor_context_decay * parent_adjacency | |
| ) | |
| weighted_ancestor_matrix = torch.linalg.solve(system_matrix, parent_adjacency) | |
| normalization = weighted_ancestor_matrix.sum(dim=-1, keepdim=True) | |
| normalized_ancestor_matrix = weighted_ancestor_matrix / normalization.clamp_min(1.0) | |
| normalized_ancestor_matrix = normalized_ancestor_matrix.masked_fill( | |
| normalization <= 0.0, | |
| 0.0, | |
| ) | |
| return torch.matmul( | |
| normalized_ancestor_matrix.to(dtype=link_latents.dtype), | |
| link_latents, | |
| ) | |
def build_segmentation_link_latents(
    self,
    *,
    link_latents: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
) -> Tensor:
    """Apply the optional ancestor-context refinement used for segmentation only.

    When the feature is enabled, each link latent receives a gated residual:
    ``link + sigmoid(gate([link, proj(ctx)])) * proj(ctx)``, where ``ctx`` is
    the depth-decayed ancestor context. When disabled, ``link_latents`` is
    returned unchanged.

    Raises:
        RuntimeError: If the feature is enabled but the projector or gate
            module is missing.
    """
    if not self.use_ancestor_context_for_segmentation:
        return link_latents

    gate_module = self.ancestor_context_gate
    projector = self.ancestor_context_projector
    if gate_module is None or projector is None:
        raise RuntimeError(
            "ancestor-context projector and gate must exist when "
            "use_ancestor_context_for_segmentation=True"
        )

    context = self._build_depth_decayed_ancestor_context(
        link_latents=link_latents,
        joint_connections=joint_connections,
        joint_valid_flag=joint_valid_flag,
    )
    projected_context = projector(context)
    gate_input = torch.cat((link_latents, projected_context), dim=-1)
    mixing_weight = torch.sigmoid(gate_module(gate_input))
    return link_latents + mixing_weight * projected_context
def decode(
    self,
    *,
    query_latents: Tensor,
    query_points: Tensor,
    link_latents: Tensor,
    link_valid_flag: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    is_revolute: Tensor,
    is_prismatic: Tensor,
    link_ids: Tensor | None = None,
) -> Dict[str, Any]:
    """Decode predictions while preserving the caller-provided joint layout.

    Segmentation logits are always produced. Joint outputs depend on the
    configured decode type:

    * plain types: joint parameters are decoded from link latents, but only
      outside training; every motion-point entry stays ``None``.
    * all other (over-parameterized) types: per-query motion points are
      decoded, converted to world coordinates and fitted into joint
      parameters; the intermediate point sets are returned as well.

    Returns:
        A dict with a fixed key set. Entries that do not apply to the active
        decode path are ``None``; the joint layout tensors and
        ``query_points`` are echoed back unchanged.
    """
    # Zero out padded link slots before any decoder reads them.
    link_latents = link_latents.masked_fill(~link_valid_flag.unsqueeze(-1), 0)
    segmentation_logits = self.decode_segmentation(
        query_latents=query_latents,
        link_latents=self.build_segmentation_link_latents(
            link_latents=link_latents,
            joint_connections=joint_connections,
            joint_valid_flag=joint_valid_flag,
        ),
        link_valid_flag=link_valid_flag,
    )

    # Seed every output key up front; each path fills in only what it computes.
    outputs: Dict[str, Any] = {
        "segmentation_logits": segmentation_logits,
        "revolute_axis": None,
        "prismatic_axis": None,
        "revolute_range": None,
        "prismatic_range": None,
        "revolute_closest_axis_points": None,
        "revolute_closest_axis_points_decoder": None,
        "revolute_low_points": None,
        "revolute_high_points": None,
        "revolute_axis_directions": None,
        "prismatic_closest_axis_points": None,
        "prismatic_closest_axis_points_decoder": None,
        "prismatic_low_points": None,
        "prismatic_high_points": None,
        "prismatic_axis_directions": None,
        "joint_decoding_link_ids": None,
        "joint_connections": joint_connections,
        "joint_valid_flag": joint_valid_flag,
        "is_revolute": is_revolute,
        "is_prismatic": is_prismatic,
        "query_points": query_points,
    }
    # Keyword arguments shared by the joint-level decoders.
    joint_layout = {
        "link_latents": link_latents,
        "joint_connections": joint_connections,
        "joint_valid_flag": joint_valid_flag,
        "is_revolute": is_revolute,
        "is_prismatic": is_prismatic,
    }

    if self.joint_decode_type in _PLAIN_JOINT_DECODE_TYPES:
        if not self.training:
            (
                outputs["revolute_axis"],
                outputs["prismatic_axis"],
                outputs["revolute_range"],
                outputs["prismatic_range"],
            ) = self.decode_joint_parameters(**joint_layout)
        return outputs

    # Over-parameterized path.
    assigned_link_ids = self._resolve_joint_decoding_link_ids(
        segmentation_logits=segmentation_logits,
        link_ids=link_ids,
    )
    revolute_decoder_points, prismatic_decoder_points = self._decode_joint_motion_points(
        query_latents=query_latents,
        link_latents=link_latents,
        assigned_link_ids=assigned_link_ids,
        joint_connections=joint_connections,
    )

    predicts_single_direction = self.overparam_predicts_single_axis_direction
    revolute_joint_directions = None
    prismatic_joint_directions = None
    if predicts_single_direction:
        (
            revolute_joint_directions,
            prismatic_joint_directions,
        ) = self._decode_joint_axis_directions(**joint_layout)

    revolute_world_points = self._convert_overparam_motion_points_to_world_coordinates(
        motion_points=revolute_decoder_points,
        query_points=query_points,
        assigned_link_ids=assigned_link_ids,
    )
    prismatic_world_points = self._convert_overparam_motion_points_to_world_coordinates(
        motion_points=prismatic_decoder_points,
        query_points=query_points,
        assigned_link_ids=assigned_link_ids,
    )

    # Per-joint directions are forwarded only when both were produced.
    if revolute_joint_directions is None or prismatic_joint_directions is None:
        joint_directions = None
    else:
        joint_directions = (revolute_joint_directions, prismatic_joint_directions)

    (
        outputs["revolute_axis"],
        outputs["prismatic_axis"],
        outputs["revolute_range"],
        outputs["prismatic_range"],
    ) = self.decode_joint_parameters(
        **joint_layout,
        query_points=query_points,
        assigned_link_ids=assigned_link_ids,
        decoded_motion_points=(revolute_world_points, prismatic_world_points),
        decoded_axis_directions=joint_directions,
        decoded_motion_points_are_world=True,
    )

    # Reported axis directions: per-query channels 9:12 when predicted per
    # query, otherwise the per-joint directions (or nothing at all).
    if self.overparam_predicts_query_axis_direction:
        revolute_directions = revolute_world_points[..., 9:12]
        prismatic_directions = prismatic_world_points[..., 9:12]
    elif predicts_single_direction:
        revolute_directions = revolute_joint_directions
        prismatic_directions = prismatic_joint_directions
    else:
        revolute_directions = None
        prismatic_directions = None

    outputs["revolute_closest_axis_points"] = revolute_world_points[..., :3]
    outputs["revolute_closest_axis_points_decoder"] = revolute_decoder_points[..., :3]
    outputs["revolute_low_points"] = revolute_world_points[..., 3:6]
    outputs["revolute_high_points"] = revolute_world_points[..., 6:9]
    outputs["revolute_axis_directions"] = revolute_directions
    outputs["prismatic_closest_axis_points"] = prismatic_world_points[..., :3]
    outputs["prismatic_closest_axis_points_decoder"] = prismatic_decoder_points[..., :3]
    outputs["prismatic_low_points"] = prismatic_world_points[..., 3:6]
    outputs["prismatic_high_points"] = prismatic_world_points[..., 6:9]
    outputs["prismatic_axis_directions"] = prismatic_directions
    outputs["joint_decoding_link_ids"] = assigned_link_ids
    return outputs
def forward(
    self,
    shape_points: Tensor,
    shape_point_normals: Tensor,
    query_points: Tensor,
    query_point_normals: Tensor,
    link_point_prompts: Tensor | None,
    link_point_prompt_normals: Tensor | None,
    link_valid_flag: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    is_revolute: Tensor,
    is_prismatic: Tensor,
    link_point_prompt_dropout_eligible: Tensor | None = None,
    link_text_prompts: Sequence[Sequence[str]] | None = None,
    link_text_embeddings: Tensor | None = None,
    link_ids: Tensor | None = None,
) -> Dict[str, Any]:
    """Encode the inputs, then decode with the caller-supplied padded joint tensors.

    Returns:
        The encoder outputs merged with the decoder outputs; on a key clash
        the decoder's value wins.
    """
    encoded = self.encoder(
        shape_points=shape_points,
        shape_point_normals=shape_point_normals,
        query_points=query_points,
        query_point_normals=query_point_normals,
        link_point_prompts=link_point_prompts,
        link_point_prompt_normals=link_point_prompt_normals,
        link_valid_flag=link_valid_flag,
        link_point_prompt_dropout_eligible=link_point_prompt_dropout_eligible,
        link_text_prompts=link_text_prompts,
        link_text_embeddings=link_text_embeddings,
    )
    decoded = self.decode(
        query_latents=encoded["query_latents"],
        query_points=query_points,
        link_latents=encoded["link_latents"],
        link_valid_flag=link_valid_flag,
        joint_connections=joint_connections,
        joint_valid_flag=joint_valid_flag,
        is_revolute=is_revolute,
        is_prismatic=is_prismatic,
        link_ids=link_ids,
    )
    merged: Dict[str, Any] = dict(encoded)
    merged.update(decoded)
    return merged
# Names exported by `from <this module> import *`.
__all__ = [
    "Particulate2ArticulationModel",
    "Particulate2Encoder",
    "JointDecoderPlainFlowMatching",
    "JointDecoderOverParametrized",
    "JointDecoderSingleDirection",
    "JointDecoderPlain",
    "SegmentationDecoder",
]