rayli's picture
Clean unused demo logic
13116e0 verified
Raw
History Blame Contribute Delete
129 kB
from __future__ import annotations
import math
import os
from contextlib import nullcontext
from itertools import chain
from typing import Any, Dict, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from instruct_particulate.utils.inference_utils import (
axis_point_to_plucker_torch,
estimate_prismatic_limit_torch,
estimate_revolute_limit_torch,
fit_axis_to_closest_points_torch,
)
from instruct_particulate.utils.partfield_feature_utils import (
PARTFIELD_FEATURE_DIM,
PartFieldFeatureExtractor,
)
from instruct_particulate.utils.text_embedding_utils import (
encode_clip_text_prompts,
load_clip_text_encoder,
)
def _make_silu_mlp(
    input_dim: int,
    hidden_dim: int,
    output_dim: int,
    *,
    bias: bool = True,
) -> nn.Sequential:
    """Build a two-layer perceptron with a SiLU activation in between.

    Args:
        input_dim: Feature size consumed by the first linear layer.
        hidden_dim: Width of the hidden representation.
        output_dim: Feature size produced by the second linear layer.
        bias: Whether both linear layers carry a bias term.

    Returns:
        An ``nn.Sequential`` laid out as ``Linear -> SiLU -> Linear``.
    """
    # The layers are created in forward order so parameter initialisation
    # consumes the RNG in the same sequence as the layers appear.
    first_linear = nn.Linear(input_dim, hidden_dim, bias=bias)
    activation = nn.SiLU()
    second_linear = nn.Linear(hidden_dim, output_dim, bias=bias)
    return nn.Sequential(first_linear, activation, second_linear)
# Lower bound applied to an AABB half extent by the over-parameterized axis code.
# NOTE(review): the consumer is outside this part of the file; presumably it
# guards against degenerate (near-zero) boxes -- confirm at the use site.
_OVERPARAM_AXIS_AABB_HALF_EXTENT_MIN = 1.0e-4

# Canonical joint decode type names, grouped by decoder family. These are the
# spellings `_normalize_joint_decode_type` returns (or passes through unchanged).
_PLAIN_JOINT_DECODE_TYPES = frozenset(("plain", "plain+fm"))
_OVERPARAM_JOINT_DECODE_TYPES = frozenset(
    ("overparametrized", "overparam+dir", "overparam+singledir")
)
def _normalize_joint_decode_type(joint_decode_type: str) -> str:
    """Map an accepted spelling of a joint decode type to its canonical name.

    The input is converted with ``str()`` and lower-cased before matching.
    Values that match no known alias are returned lower-cased but otherwise
    untouched, so validation is left to the caller.

    Args:
        joint_decode_type: User-supplied decode type name.

    Returns:
        The canonical name for a recognised alias, else the lower-cased input.
    """
    alias_groups = (
        (
            "plain+fm",
            ("plain+flowmatching", "plain+flow-matching", "plain+fm"),
        ),
        (
            "overparametrized",
            ("overparameterization", "overparameterized", "overparam"),
        ),
        (
            "overparam+dir",
            (
                "overparameterization+dir",
                "overparameterized+dir",
                "overparametrized+dir",
                "overparam+dir",
            ),
        ),
        (
            "overparam+singledir",
            (
                "overparameterization+singledir",
                "overparameterized+singledir",
                "overparametrized+singledir",
                "overparam+singledir",
                "overparameterization+single-dir",
                "overparameterized+single-dir",
                "overparametrized+single-dir",
                "overparam+single-dir",
            ),
        ),
    )
    lowered = str(joint_decode_type).lower()
    for canonical_name, spellings in alias_groups:
        if lowered in spellings:
            return canonical_name
    return lowered
def _normalize_joint_fm_prediction_type(prediction_type: str) -> str:
    """Collapse accepted spellings of the prediction type to ``"x"`` or ``"v"``.

    For each of ``x`` and ``v`` the forms ``<k>``, ``<k>pred``, ``<k>-pred`` and
    ``<k>_pred`` are recognised (case-insensitively). Anything else is returned
    lower-cased so that the caller can reject it.

    Args:
        prediction_type: User-supplied prediction type name.

    Returns:
        ``"x"``, ``"v"``, or the lower-cased input when it is not recognised.
    """
    lowered = str(prediction_type).lower()
    for canonical_name in ("x", "v"):
        spellings = (
            canonical_name,
            f"{canonical_name}pred",
            f"{canonical_name}-pred",
            f"{canonical_name}_pred",
        )
        if lowered in spellings:
            return canonical_name
    return lowered
def _normalize_overparam_closest_axis_space(closest_axis_space: str) -> str:
    """Map an accepted spelling of the closest-axis space to its canonical name.

    Recognised aliases resolve to either ``"world"`` or ``"part_aabb"``. The
    match is case-insensitive; unknown values are returned lower-cased so the
    caller can decide how to handle them.

    Args:
        closest_axis_space: User-supplied space name.

    Returns:
        ``"world"``, ``"part_aabb"``, or the lower-cased input.
    """
    canonical_by_alias = {
        "world": "world",
        "sample": "world",
        "global": "world",
        "global-space": "world",
        "world-space": "world",
        "part_aabb": "part_aabb",
        "part-aabb": "part_aabb",
        "aabb": "part_aabb",
        "local_aabb": "part_aabb",
        "local-aabb": "part_aabb",
    }
    lowered = str(closest_axis_space).lower()
    return canonical_by_alias.get(lowered, lowered)
def _coerce_joint_fm_state_stat(
    values: Optional[Sequence[float]],
    *,
    default_value: float,
    name: str,
) -> Tensor:
    """Validate an 8-entry statistic and return it as a float32 tensor.

    Args:
        values: Eight numbers, or ``None`` to use ``default_value`` for every
            entry.
        default_value: Fill value used when ``values`` is ``None``.
        name: Label used in error messages. When it ends with ``"_std"`` the
            entries are additionally required to be strictly positive.

    Returns:
        A 1-D ``torch.float32`` tensor with eight entries.

    Raises:
        ValueError: If the length is not 8, an entry is not finite, or a
            ``*_std`` statistic has a non-positive entry.
    """
    entries = [default_value] * 8 if values is None else values
    num_entries = len(entries)
    if num_entries != 8:
        raise ValueError(f"{name} must have length 8, got {num_entries}")

    stat = torch.tensor(list(map(float, entries)), dtype=torch.float32)
    all_finite = bool(torch.isfinite(stat).all())
    if not all_finite:
        raise ValueError(f"{name} must contain only finite values, got {entries!r}")

    # Standard deviations are divided by / scaled with, so zero or negative
    # entries are rejected up front.
    is_std = name.endswith("_std")
    if is_std and bool((stat <= 0.0).any()):
        raise ValueError(f"{name} must contain only positive values, got {entries!r}")
    return stat
def _gather_joint_link_latents(
    *,
    link_latents: Tensor,
    joint_connections: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Look up the parent and child link latent for every joint.

    Args:
        link_latents: Per-link latents; indexed along dim 1 and expanded along
            the last dim, so it is treated as ``(batch, links, dim)``.
        joint_connections: Integer tensor whose last axis holds
            ``(parent_id, child_id)`` at positions 0 and 1.

    Returns:
        ``(parent_latents, child_latents)``, each gathered along dim 1.
    """
    latent_dim = link_latents.shape[-1]

    def _gather_endpoint(column: int) -> Tensor:
        # Negative ids are clamped to 0 so that `gather` never sees an invalid
        # index (presumably these mark padded joints; callers mask them later).
        link_ids = joint_connections[..., column].clamp_min(0)
        index = link_ids.unsqueeze(-1).expand(-1, -1, latent_dim)
        return link_latents.gather(dim=1, index=index)

    return _gather_endpoint(0), _gather_endpoint(1)
def _build_joint_motion_condition_inputs(
    *,
    parent_latents: Tensor,
    child_latents: Tensor,
    motion_type: str,
    revolute_embedding: Tensor,
    prismatic_embedding: Tensor,
) -> Tensor:
    """Concatenate ``(type embedding, parent latent, child latent)`` per joint.

    Args:
        parent_latents: Latents of each joint's parent link.
        child_latents: Latents of each joint's child link; must have the same
            shape as ``parent_latents``.
        motion_type: ``"revolute"`` or ``"prismatic"``; selects the embedding.
        revolute_embedding: Learned vector used for revolute joints.
        prismatic_embedding: Learned vector used for prismatic joints.

    Returns:
        A tensor whose last dimension is three times that of the latents.

    Raises:
        ValueError: If the latent shapes differ or ``motion_type`` is unknown.
    """
    if parent_latents.shape != child_latents.shape:
        raise ValueError(
            "parent_latents and child_latents must share the same shape, "
            f"got {tuple(parent_latents.shape)} and {tuple(child_latents.shape)}"
        )
    if motion_type not in ("revolute", "prismatic"):
        raise ValueError(
            "motion_type must be 'revolute' or 'prismatic', "
            f"got {motion_type!r}"
        )
    selected = revolute_embedding if motion_type == "revolute" else prismatic_embedding

    # Broadcast the single type vector over every (batch, joint) position.
    selected = selected.to(device=parent_latents.device, dtype=parent_latents.dtype)
    broadcast_embedding = selected.view(1, 1, -1).expand_as(parent_latents)
    return torch.cat((broadcast_embedding, parent_latents, child_latents), dim=-1)
class SwiGLUFeedForward(nn.Module):
    """Gated feed-forward network: ``out_proj(value * silu(gate))``.

    A single input projection produces both the value and the gate halves.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiplier: float = 4.0,
        dropout: float = 0.0,
    ):
        """Create the projections.

        Args:
            dim: Input and output feature size.
            hidden_dim: Width of the gated hidden layer. When falsy it is
                derived as ``int(dim * multiplier)``.
            multiplier: Expansion factor used when ``hidden_dim`` is not given.
            dropout: Dropout probability applied to the gated activations.
        """
        super().__init__()
        if not hidden_dim:
            hidden_dim = int(dim * multiplier)
        # Value and gate share one matmul, hence the doubled output width.
        self.in_proj = nn.Linear(dim, 2 * hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated MLP along the last dimension of ``x``."""
        projected = self.in_proj(x)
        value, gate = projected.chunk(2, dim=-1)
        gated = self.dropout(value * F.silu(gate))
        return self.out_proj(gated)
class FrequencyMLPEmbedder(nn.Module):
    """Fourier-feature encoder followed by an RMS norm and a small MLP.

    Each input channel is multiplied by a bank of frequencies, passed through
    ``sin`` and ``cos``, optionally concatenated with the raw input, then
    normalised and projected to ``output_dim``.
    """

    def __init__(
        self,
        output_dim: int,
        num_frequencies: int,
        input_dim: int = 3,
        hidden_dim: Optional[int] = None,
        max_frequency: float = 32.0,
        include_raw: bool = True,
    ):
        """Set up the frequency bank and the projection MLP.

        Args:
            output_dim: Size of the produced embedding.
            num_frequencies: Number of frequencies per input channel.
            input_dim: Expected size of the last input dimension.
            hidden_dim: Hidden width of the MLP; defaults to ``output_dim``.
            max_frequency: Largest multiplier (before the factor of pi) when
                more than one frequency is used.
            include_raw: Whether the raw input is appended to the features.

        Raises:
            ValueError: If ``num_frequencies``, ``input_dim`` or
                ``max_frequency`` is not positive.
        """
        super().__init__()
        if num_frequencies <= 0:
            raise ValueError(f"num_frequencies must be positive, got {num_frequencies}")
        if input_dim <= 0:
            raise ValueError(f"input_dim must be positive, got {input_dim}")
        if max_frequency <= 0.0:
            raise ValueError(f"max_frequency must be positive, got {max_frequency}")

        self.input_dim = input_dim
        self.include_raw = include_raw

        if num_frequencies > 1:
            # Multipliers are spaced evenly in log space from 1 to
            # max_frequency, then scaled by pi.
            log_multipliers = torch.linspace(
                0.0,
                math.log(max_frequency),
                steps=num_frequencies,
                dtype=torch.float32,
            )
            frequencies = torch.exp(log_multipliers) * math.pi
        else:
            frequencies = torch.tensor([math.pi], dtype=torch.float32)
        # Non-persistent: rebuilt from the constructor arguments, never saved.
        self.register_buffer("frequencies", frequencies, persistent=False)

        features_per_channel = 2 * num_frequencies + int(include_raw)
        encoded_dim = input_dim * features_per_channel
        self.input_norm = nn.RMSNorm(encoded_dim)
        self.mlp = _make_silu_mlp(
            input_dim=encoded_dim,
            hidden_dim=hidden_dim or output_dim,
            output_dim=output_dim,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x``, whose last dimension must equal ``input_dim``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``input_dim``.
        """
        if x.shape[-1] != self.input_dim:
            raise ValueError(
                f"expected inputs with last dimension {self.input_dim}, got {tuple(x.shape)}"
            )
        # The trigonometric features are always computed in float32 and only
        # cast to the MLP's dtype afterwards.
        points = x.float()
        phases = points.unsqueeze(-1) * self.frequencies
        features = torch.cat((phases.sin(), phases.cos()), dim=-1)
        features = features.flatten(start_dim=-2)
        if self.include_raw:
            features = torch.cat((features, points), dim=-1)
        features = features.to(dtype=self.mlp[0].weight.dtype)
        return self.mlp(self.input_norm(features))
class TimestepEmbedder(nn.Module):
    """Sinusoidal embedding of scalar timesteps followed by an MLP."""

    def __init__(
        self,
        hidden_dim: int,
        *,
        frequency_embedding_dim: int = 256,
        max_period: float = 10_000.0,
    ):
        """Create the embedder.

        Args:
            hidden_dim: Hidden and output width of the MLP.
            frequency_embedding_dim: Size of the sinusoidal feature vector.
            max_period: Controls the lowest frequency of the sinusoids.

        Raises:
            ValueError: If any argument is not positive.
        """
        super().__init__()
        if hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be positive, got {hidden_dim}")
        if frequency_embedding_dim <= 0:
            raise ValueError(
                f"frequency_embedding_dim must be positive, got {frequency_embedding_dim}"
            )
        if max_period <= 0.0:
            raise ValueError(f"max_period must be positive, got {max_period}")
        self.frequency_embedding_dim = int(frequency_embedding_dim)
        self.max_period = float(max_period)
        self.mlp = _make_silu_mlp(
            input_dim=self.frequency_embedding_dim,
            hidden_dim=hidden_dim,
            output_dim=hidden_dim,
        )

    def _frequency_embedding(self, t: Tensor) -> Tensor:
        """Return ``[cos(t * f), sin(t * f)]`` features for each timestep.

        With an odd ``frequency_embedding_dim`` one zero column is appended;
        with a dimension of 1 the timestep itself is the only feature.
        """
        half_dim = self.frequency_embedding_dim // 2
        if half_dim == 0:
            return t.unsqueeze(-1)

        # Frequencies decay geometrically from 1 towards 1 / max_period.
        exponents = torch.arange(half_dim, device=t.device, dtype=torch.float32) / half_dim
        frequencies = torch.exp(-math.log(self.max_period) * exponents)
        angles = t.float().unsqueeze(-1) * frequencies.unsqueeze(0)

        parts = [angles.cos(), angles.sin()]
        if self.frequency_embedding_dim % 2 != 0:
            # Pad so the feature size matches the requested odd dimension.
            parts.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t: Tensor) -> Tensor:
        """Embed a 1-D tensor of timesteps.

        Raises:
            ValueError: If ``t`` is not one-dimensional.
        """
        if t.ndim != 1:
            raise ValueError(f"expected a 1D timestep tensor, got shape {tuple(t.shape)}")
        features = self._frequency_embedding(t)
        features = features.to(dtype=self.mlp[0].weight.dtype)
        return self.mlp(features)
class SDPASelfAttention(nn.Module):
"""Self-attention implemented with PyTorch SDPA."""
def __init__(
self,
model_dim: int,
num_heads: int,
head_dim: Optional[int] = None,
attn_dropout: float = 0.0,
proj_dropout: float = 0.0,
qk_norm: bool = True,
qkv_bias: bool = True,
):
super().__init__()
if head_dim is None:
if model_dim % num_heads != 0:
raise ValueError(
f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads}) "
"when head_dim is not specified"
)
head_dim = model_dim // num_heads
self.model_dim = model_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.inner_dim = num_heads * head_dim
self.attn_dropout = attn_dropout
self.qkv_proj = nn.Linear(model_dim, 3 * self.inner_dim, bias=qkv_bias)
self.out_proj = nn.Linear(self.inner_dim, model_dim)
self.out_dropout = nn.Dropout(proj_dropout)
self.q_norm = nn.RMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = nn.RMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
def _reshape_heads(self, x: Tensor) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
return x.transpose(1, 2)
def forward(self, x: Tensor, *, mask: Tensor | None = None) -> Tensor:
qkv = self.qkv_proj(x)
q, k, v = qkv.chunk(3, dim=-1)
q = self.q_norm(self._reshape_heads(q))
k = self.k_norm(self._reshape_heads(k))
v = self._reshape_heads(v)
attn_mask = None
if mask is not None:
attn_mask = mask[:, None, :, None] & mask[:, None, None, :]
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=self.attn_dropout if self.training else 0.0,
)
attn_output = attn_output.transpose(1, 2).contiguous().view(
x.shape[0],
x.shape[1],
self.inner_dim,
)
return self.out_dropout(self.out_proj(attn_output))
class SDPACrossAttention(nn.Module):
"""Cross-attention implemented with PyTorch SDPA."""
def __init__(
self,
model_dim: int,
num_heads: int,
head_dim: Optional[int] = None,
attn_dropout: float = 0.0,
proj_dropout: float = 0.0,
qk_norm: bool = True,
q_bias: bool = True,
kv_bias: bool = True,
):
super().__init__()
if head_dim is None:
if model_dim % num_heads != 0:
raise ValueError(
f"model_dim ({model_dim}) must be divisible by num_heads ({num_heads}) "
"when head_dim is not specified"
)
head_dim = model_dim // num_heads
self.model_dim = model_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.inner_dim = num_heads * head_dim
self.attn_dropout = attn_dropout
self.q_proj = nn.Linear(model_dim, self.inner_dim, bias=q_bias)
self.kv_proj = nn.Linear(model_dim, 2 * self.inner_dim, bias=kv_bias)
self.out_proj = nn.Linear(self.inner_dim, model_dim)
self.out_dropout = nn.Dropout(proj_dropout)
self.q_norm = nn.RMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = nn.RMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
def _reshape_heads(self, x: Tensor) -> Tensor:
batch_size, seq_len, _ = x.shape
x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
return x.transpose(1, 2)
def forward(
self,
query: Tensor,
context: Tensor,
*,
query_mask: Tensor | None = None,
context_mask: Tensor | None = None,
) -> Tensor:
q = self.q_norm(self._reshape_heads(self.q_proj(query)))
k, v = self.kv_proj(context).chunk(2, dim=-1)
k = self.k_norm(self._reshape_heads(k))
v = self._reshape_heads(v)
if query_mask is None and context_mask is None:
attn_mask = None
else:
if query_mask is None:
query_mask = torch.ones(query.shape[:2], device=query.device, dtype=torch.bool)
if context_mask is None:
context_mask = torch.ones(
context.shape[:2],
device=context.device,
dtype=torch.bool,
)
attn_mask = query_mask[:, None, :, None] & context_mask[:, None, None, :]
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=self.attn_dropout if self.training else 0.0,
)
attn_output = attn_output.transpose(1, 2).contiguous().view(
query.shape[0],
query.shape[1],
self.inner_dim,
)
return self.out_dropout(self.out_proj(attn_output))
class CrossAttentionBlock(nn.Module):
    """Pre-norm residual block: cross-attention followed by a gated MLP."""

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        ffn_multiplier: float = 4.0,
        ffn_dropout: float = 0.0,
        qk_norm: bool = True,
        norm_eps: float = 1e-6,
    ):
        """Create the norms, the attention layer and the feed-forward layer.

        Args:
            model_dim: Feature size of queries and context.
            num_heads: Number of attention heads.
            head_dim: Per-head width forwarded to the attention layer.
            attn_dropout: Dropout on attention weights.
            proj_dropout: Dropout after the attention output projection.
            ffn_multiplier: Hidden expansion factor of the feed-forward layer.
            ffn_dropout: Dropout inside the feed-forward layer.
            qk_norm: Whether queries and keys are RMS-normalised.
            norm_eps: Epsilon of every RMS norm in this block.
        """
        super().__init__()
        self.query_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.context_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.attn = SDPACrossAttention(
            model_dim=model_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            qk_norm=qk_norm,
        )
        self.ffn_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.ffn = SwiGLUFeedForward(
            dim=model_dim,
            multiplier=ffn_multiplier,
            dropout=ffn_dropout,
        )

    def forward(
        self,
        query: Tensor,
        context: Tensor,
        *,
        query_mask: Tensor | None = None,
        context_mask: Tensor | None = None,
    ) -> Tensor:
        """Update ``query`` from ``context``.

        When ``query_mask`` is given, rows where it is ``False`` are reset to
        zero after each residual branch.
        """
        padded_rows = None if query_mask is None else ~query_mask.unsqueeze(-1)

        attended = self.attn(
            self.query_norm(query),
            self.context_norm(context),
            query_mask=query_mask,
            context_mask=context_mask,
        )
        query = query + attended
        if padded_rows is not None:
            query = query.masked_fill(padded_rows, 0)

        query = query + self.ffn(self.ffn_norm(query))
        if padded_rows is not None:
            query = query.masked_fill(padded_rows, 0)
        return query
class SelfAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention followed by a gated MLP."""

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        ffn_multiplier: float = 4.0,
        ffn_dropout: float = 0.0,
        qk_norm: bool = True,
        norm_eps: float = 1e-6,
    ):
        """Create the norms, the attention layer and the feed-forward layer.

        Args:
            model_dim: Token feature size.
            num_heads: Number of attention heads.
            head_dim: Per-head width forwarded to the attention layer.
            attn_dropout: Dropout on attention weights.
            proj_dropout: Dropout after the attention output projection.
            ffn_multiplier: Hidden expansion factor of the feed-forward layer.
            ffn_dropout: Dropout inside the feed-forward layer.
            qk_norm: Whether queries and keys are RMS-normalised.
            norm_eps: Epsilon of every RMS norm in this block.
        """
        super().__init__()
        self.attn_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.attn = SDPASelfAttention(
            model_dim=model_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            qk_norm=qk_norm,
        )
        self.ffn_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.ffn = SwiGLUFeedForward(
            dim=model_dim,
            multiplier=ffn_multiplier,
            dropout=ffn_dropout,
        )

    def forward(self, x: Tensor, *, mask: Tensor | None = None) -> Tensor:
        """Run both residual branches over ``x``.

        When ``mask`` is given, rows where it is ``False`` are reset to zero
        after each residual branch.
        """
        padded_rows = None if mask is None else ~mask.unsqueeze(-1)

        attended = self.attn(self.attn_norm(x), mask=mask)
        x = x + attended
        if padded_rows is not None:
            x = x.masked_fill(padded_rows, 0)

        x = x + self.ffn(self.ffn_norm(x))
        if padded_rows is not None:
            x = x.masked_fill(padded_rows, 0)
        return x
class ShapeLatentEncoder(nn.Module):
    """Summarises point tokens into a fixed number of learned latent slots.

    A bank of learned latents acts as the query of one cross-attention block
    whose context is the point tokens.
    """

    def __init__(
        self,
        model_dim: int,
        num_shape_latents: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
        ffn_multiplier: float = 4.0,
        ffn_dropout: float = 0.0,
        qk_norm: bool = True,
        norm_eps: float = 1e-6,
    ):
        """Create the latent bank, the cross-attention block and the final norm.

        Args:
            model_dim: Feature size of tokens and latents.
            num_shape_latents: Number of learned latent slots.
            num_heads: Number of attention heads.
            head_dim: Per-head width forwarded to the block.
            attn_dropout: Dropout on attention weights.
            proj_dropout: Dropout after the attention output projection.
            ffn_multiplier: Hidden expansion factor of the feed-forward layer.
            ffn_dropout: Dropout inside the feed-forward layer.
            qk_norm: Whether queries and keys are RMS-normalised.
            norm_eps: Epsilon of the RMS norms.
        """
        super().__init__()
        # Allocated uninitialised; filled in by reset_parameters() below.
        self.shape_latents = nn.Parameter(torch.empty(num_shape_latents, model_dim))
        self.block = CrossAttentionBlock(
            model_dim=model_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            ffn_multiplier=ffn_multiplier,
            ffn_dropout=ffn_dropout,
            qk_norm=qk_norm,
            norm_eps=norm_eps,
        )
        self.output_norm = nn.RMSNorm(model_dim, eps=norm_eps)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the latent bank from a truncated normal (std 0.02)."""
        nn.init.trunc_normal_(self.shape_latents, std=0.02)

    def forward(self, point_tokens: Tensor) -> Tensor:
        """Encode ``point_tokens`` into ``(batch, num_shape_latents, model_dim)``.

        The learned latents are cast to the token dtype and shared (expanded,
        not copied) across the batch before cross-attending to the tokens.
        """
        learned = self.shape_latents.to(dtype=point_tokens.dtype)
        queries = learned.unsqueeze(0).expand(point_tokens.shape[0], -1, -1)
        return self.output_norm(self.block(queries, point_tokens))
class SegmentationDecoder(nn.Module):
    """Scores every (query, link) pair and masks out invalid links.

    Two scoring modes exist: ``"dot"`` (inner product of the latents) and
    ``"mlp"`` (an MLP over the concatenated link and query latents).
    """

    def __init__(
        self,
        model_dim: int,
        decode_type: str = "dot",
        bias: bool = True,
        query_chunk_size: Optional[int] = None,
    ):
        """Configure the decoder.

        Args:
            model_dim: Latent feature size.
            decode_type: ``"dot"`` or ``"mlp"``.
            bias: Whether the MLP layers carry biases (``"mlp"`` mode only).
            query_chunk_size: If set, the ``"mlp"`` mode processes queries in
                chunks of this size instead of building one large pair tensor.

        Raises:
            ValueError: On an unknown ``decode_type`` or a non-positive
                ``query_chunk_size``.
        """
        super().__init__()
        self.model_dim = model_dim
        self.decode_type = str(decode_type)
        if self.decode_type not in {"dot", "mlp"}:
            raise ValueError(
                f"decode_type must be 'dot' or 'mlp', got {decode_type!r}"
            )
        if query_chunk_size is not None and int(query_chunk_size) <= 0:
            raise ValueError(f"query_chunk_size must be positive when set, got {query_chunk_size}")
        if query_chunk_size is None:
            self.query_chunk_size = None
        else:
            self.query_chunk_size = int(query_chunk_size)
        # The pairwise MLP only exists in "mlp" mode.
        if self.decode_type == "mlp":
            self.decoder = _make_silu_mlp(
                input_dim=model_dim * 2,
                hidden_dim=model_dim * 4,
                output_dim=1,
                bias=bias,
            )

    def _decode_mlp_logits(
        self,
        query_latents: Tensor,
        link_latents: Tensor,
    ) -> Tensor:
        """Score all pairs with the MLP; returns ``(batch, queries, links)``."""
        num_queries = query_latents.shape[1]
        num_links = link_latents.shape[1]
        # Broadcast both sides to (batch, queries, links, dim) without copying.
        links = link_latents.unsqueeze(1).expand(-1, num_queries, -1, -1)
        queries = query_latents.unsqueeze(2).expand(-1, -1, num_links, -1)
        pair_latents = torch.cat((links, queries), dim=-1)
        return self.decoder(pair_latents).squeeze(-1)

    def forward(
        self,
        query_latents: Tensor,
        link_latents: Tensor,
        link_valid_flag: Tensor,
    ) -> Tensor:
        """Return per-query logits over links.

        Args:
            query_latents: ``(batch, queries, dim)`` latents.
            link_latents: ``(batch, links, dim)`` latents.
            link_valid_flag: ``(batch, links)`` boolean flags; logits of links
                flagged ``False`` are set to ``-inf``.

        Returns:
            ``(batch, queries, links)`` logits.
        """
        if self.decode_type == "dot":
            logits = torch.matmul(query_latents, link_latents.transpose(-1, -2))
        else:
            chunk_size = self.query_chunk_size
            num_queries = query_latents.shape[1]
            if chunk_size is None or num_queries <= chunk_size:
                logits = self._decode_mlp_logits(query_latents, link_latents)
            else:
                pieces = []
                for start in range(0, num_queries, chunk_size):
                    query_chunk = query_latents[:, start : start + chunk_size]
                    pieces.append(self._decode_mlp_logits(query_chunk, link_latents))
                logits = torch.cat(pieces, dim=1)
        invalid_links = ~link_valid_flag.unsqueeze(1)
        return logits.masked_fill(invalid_links, float("-inf"))
class JointDecoderPlain(nn.Module):
    """Regresses an 8-value motion state per joint from parent/child latents.

    The first six decoded values are returned as the axis and the last two as
    the range. One MLP is shared by both motion types; a learned type
    embedding tells it which one is being decoded.
    """

    motion_state_dim = 8

    def __init__(
        self,
        model_dim: int,
        bias: bool = True,
    ):
        """Create the type embeddings and the shared decoder MLP.

        Args:
            model_dim: Latent feature size of the links.
            bias: Whether the MLP layers carry biases.
        """
        super().__init__()
        self.model_dim = model_dim
        # Allocated uninitialised; filled in by reset_parameters() below.
        self.revolute_embedding = nn.Parameter(torch.empty(model_dim))
        self.prismatic_embedding = nn.Parameter(torch.empty(model_dim))
        self.decoder = _make_silu_mlp(
            input_dim=model_dim * 3,
            hidden_dim=model_dim * 4,
            output_dim=self.motion_state_dim,
            bias=bias,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise both type embeddings from a truncated normal (std 0.02)."""
        nn.init.trunc_normal_(self.revolute_embedding, std=0.02)
        nn.init.trunc_normal_(self.prismatic_embedding, std=0.02)

    def _predict_motion_parameters(
        self,
        *,
        parent_latents: Tensor,
        child_latents: Tensor,
        active_mask: Tensor,
        motion_type: str,
    ) -> Tuple[Tensor, Tensor]:
        """Decode ``(axis, range)`` for one motion type, zeroing inactive joints."""
        condition = _build_joint_motion_condition_inputs(
            parent_latents=parent_latents,
            child_latents=child_latents,
            motion_type=motion_type,
            revolute_embedding=self.revolute_embedding,
            prismatic_embedding=self.prismatic_embedding,
        )
        inactive = ~active_mask.unsqueeze(-1)
        state = self.decoder(condition).masked_fill(inactive, 0)
        return state[..., :6], state[..., 6:8]

    def predict(
        self,
        link_latents: Tensor,
        joint_connections: Tensor,
        joint_valid_flag: Tensor,
        is_revolute: Tensor,
        is_prismatic: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Decode padded joint tensors; rows that are not active are zeroed.

        A joint is active for a motion type when it is flagged valid and
        flagged as that type.

        Returns:
            ``(revolute_axis, prismatic_axis, revolute_range, prismatic_range)``.
        """
        parent_latents, child_latents = _gather_joint_link_latents(
            link_latents=link_latents,
            joint_connections=joint_connections,
        )
        revolute_axis, revolute_range = self._predict_motion_parameters(
            parent_latents=parent_latents,
            child_latents=child_latents,
            active_mask=joint_valid_flag & is_revolute,
            motion_type="revolute",
        )
        prismatic_axis, prismatic_range = self._predict_motion_parameters(
            parent_latents=parent_latents,
            child_latents=child_latents,
            active_mask=joint_valid_flag & is_prismatic,
            motion_type="prismatic",
        )
        return revolute_axis, prismatic_axis, revolute_range, prismatic_range
class JointDecoderPlainFlowMatching(nn.Module):
    """Samples 8-value joint motion states by integrating a learned flow.

    The network predicts either a velocity (``"v"``) or a clean state
    (``"x"``, converted to a velocity). Sampling starts from Gaussian noise
    and takes explicit Euler steps from ``t = 0`` to ``t = 1``.
    """

    motion_state_dim = 8

    def __init__(
        self,
        model_dim: int,
        *,
        hidden_dim: Optional[int] = None,
        prediction_type: str = "v",
        time_embedding_dim: int = 256,
        inference_steps: int = 100,
        time_scale: float = 1000.0,
        sigma_min: float = 0.0,
        rescale_t: float = 1.0,
        cfg_scale: float = 1.0,
        revolute_state_mean: Optional[Sequence[float]] = None,
        revolute_state_std: Optional[Sequence[float]] = None,
        prismatic_state_mean: Optional[Sequence[float]] = None,
        prismatic_state_std: Optional[Sequence[float]] = None,
        bias: bool = True,
    ):
        """Build the conditioning, time and state networks.

        Args:
            model_dim: Latent feature size of the links.
            hidden_dim: Width of the flow network; defaults to ``model_dim``.
            prediction_type: ``"x"`` or ``"v"`` (aliases are normalised).
            time_embedding_dim: Size of the sinusoidal time features.
            inference_steps: Default number of Euler steps when sampling.
            time_scale: Factor applied to ``t`` before it is embedded.
            sigma_min: Residual noise scale at ``t = 1``.
            rescale_t: Time-warp parameter for the sampling schedule; a falsy
                value disables the warp.
            cfg_scale: Default classifier-free guidance scale.
            revolute_state_mean: 8 values (or ``None`` for zeros) used to
                un-normalise revolute samples.
            revolute_state_std: 8 positive values (or ``None`` for ones).
            prismatic_state_mean: Same as above for prismatic joints.
            prismatic_state_std: Same as above for prismatic joints.
            bias: Whether linear layers carry biases.

        Raises:
            ValueError: On non-positive ``inference_steps`` or an unknown
                ``prediction_type``.
        """
        super().__init__()
        self.model_dim = model_dim
        self.hidden_dim = int(hidden_dim or model_dim)
        self.inference_steps = int(inference_steps)
        if self.inference_steps <= 0:
            raise ValueError(f"inference_steps must be positive, got {inference_steps}")

        canonical_prediction_type = _normalize_joint_fm_prediction_type(prediction_type)
        if canonical_prediction_type not in {"x", "v"}:
            raise ValueError(
                "prediction_type must be 'x' or 'v', "
                f"got {canonical_prediction_type!r}"
            )
        self.prediction_type = canonical_prediction_type

        self.time_scale = float(time_scale)
        self.sigma_min = float(sigma_min)
        self.rescale_t = float(rescale_t)
        self.cfg_scale = float(cfg_scale)

        # Persistent buffers, registered in the order revolute mean/std then
        # prismatic mean/std.
        provided_stats = (
            ("revolute", revolute_state_mean, revolute_state_std),
            ("prismatic", prismatic_state_mean, prismatic_state_std),
        )
        for motion_type, provided_mean, provided_std in provided_stats:
            for suffix, provided, fallback in (
                ("mean", provided_mean, 0.0),
                ("std", provided_std, 1.0),
            ):
                buffer_name = f"{motion_type}_state_{suffix}"
                self.register_buffer(
                    buffer_name,
                    _coerce_joint_fm_state_stat(
                        provided,
                        default_value=fallback,
                        name=buffer_name,
                    ),
                    persistent=True,
                )

        # Allocated uninitialised; filled in by reset_parameters() below.
        self.revolute_embedding = nn.Parameter(torch.empty(model_dim))
        self.prismatic_embedding = nn.Parameter(torch.empty(model_dim))

        width = self.hidden_dim
        self.condition_projector = _make_silu_mlp(
            input_dim=model_dim * 3,
            hidden_dim=width * 2,
            output_dim=width,
            bias=bias,
        )
        self.state_projector = nn.Linear(self.motion_state_dim, width, bias=bias)
        self.time_embedder = TimestepEmbedder(
            width,
            frequency_embedding_dim=time_embedding_dim,
        )
        self.input_norm = nn.RMSNorm(width)
        self.residual_block_1 = _make_silu_mlp(
            input_dim=width,
            hidden_dim=width * 4,
            output_dim=width,
            bias=bias,
        )
        self.residual_block_2 = _make_silu_mlp(
            input_dim=width,
            hidden_dim=width * 4,
            output_dim=width,
            bias=bias,
        )
        self.output_norm = nn.RMSNorm(width)
        self.output_projector = nn.Linear(width, self.motion_state_dim, bias=bias)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise both type embeddings from a truncated normal (std 0.02)."""
        nn.init.trunc_normal_(self.revolute_embedding, std=0.02)
        nn.init.trunc_normal_(self.prismatic_embedding, std=0.02)

    def _prepare_t_sequence(
        self,
        *,
        device: torch.device,
        steps: int | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> Tensor:
        """Return ``steps + 1`` increasing time points from 0 to 1.

        The grid is built in float32, optionally warped by ``rescale_t`` and
        cast to ``dtype`` last.

        Raises:
            ValueError: If the resolved number of steps is not positive.
        """
        num_steps = int(steps) if steps is not None else self.inference_steps
        if num_steps <= 0:
            raise ValueError(f"steps must be positive, got {num_steps}")
        grid = torch.linspace(
            0.0,
            1.0,
            num_steps + 1,
            device=device,
            dtype=torch.float32,
        )
        if self.rescale_t:
            # t / (1 + (r - 1) * (1 - t)); the identity for r == 1.
            denominator = 1.0 + (self.rescale_t - 1.0) * (1.0 - grid)
            grid = grid / denominator
        return grid.to(dtype=dtype)

    def _motion_state_mean_and_std(
        self,
        *,
        motion_type: str,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple[Tensor, Tensor]:
        """Fetch ``(mean, std)`` for a motion type; std is floored at 1e-6.

        Raises:
            ValueError: If ``motion_type`` is neither revolute nor prismatic.
        """
        if motion_type not in ("revolute", "prismatic"):
            raise ValueError(
                "motion_type must be 'revolute' or 'prismatic', "
                f"got {motion_type!r}"
            )
        if motion_type == "revolute":
            mean, std = self.revolute_state_mean, self.revolute_state_std
        else:
            mean, std = self.prismatic_state_mean, self.prismatic_state_std
        floored_std = std.to(device=device, dtype=dtype).clamp_min(1.0e-6)
        return mean.to(device=device, dtype=dtype), floored_std

    def _unnormalize_motion_state(
        self,
        motion_state: Tensor,
        *,
        motion_type: str,
    ) -> Tensor:
        """Map a normalised state back with ``state * std + mean``."""
        mean, std = self._motion_state_mean_and_std(
            motion_type=motion_type,
            device=motion_state.device,
            dtype=motion_state.dtype,
        )
        return motion_state * std + mean

    def _sample_motion_state(
        self,
        *,
        condition_embeddings: Tensor,
        steps: int | None = None,
        cfg_scale: float | None = None,
    ) -> Tensor:
        """Integrate the flow from noise for every row of the conditioning.

        Args:
            condition_embeddings: ``(num_samples, hidden_dim)`` conditioning.
            steps: Number of Euler steps; defaults to ``inference_steps``.
            cfg_scale: Guidance scale; defaults to ``self.cfg_scale``. A value
                of exactly 1.0 skips the unconditional pass.

        Returns:
            ``(num_samples, motion_state_dim)`` states in normalised space.
        """
        num_samples = condition_embeddings.shape[0]
        if num_samples == 0:
            return condition_embeddings.new_zeros((0, self.motion_state_dim))

        guidance_scale = float(cfg_scale) if cfg_scale is not None else self.cfg_scale
        use_guidance = guidance_scale != 1.0

        motion_state = torch.randn(
            (num_samples, self.motion_state_dim),
            device=condition_embeddings.device,
            dtype=condition_embeddings.dtype,
        )
        t_sequence = self._prepare_t_sequence(
            device=condition_embeddings.device,
            steps=steps,
            dtype=condition_embeddings.dtype,
        )
        for step_index in range(t_sequence.shape[0] - 1):
            t_start = t_sequence[step_index]
            t_end = t_sequence[step_index + 1]
            time_vector = motion_state.new_full((num_samples,), t_start)
            if use_guidance:
                velocity = self._predict_guided_velocity(
                    x_t=motion_state,
                    t=time_vector,
                    condition_embeddings=condition_embeddings,
                    cfg_scale=guidance_scale,
                )
            else:
                velocity = self.predict_velocity(
                    x_t=motion_state,
                    t=time_vector,
                    condition_embeddings=condition_embeddings,
                )
            step_size = (t_end - t_start).to(dtype=motion_state.dtype)
            motion_state = motion_state + step_size * velocity
        return motion_state

    def _bridge_noise_scale(self, t: Tensor) -> Tensor:
        """Noise coefficient at time ``t``: ``1 - (1 - sigma_min) * t``."""
        return 1.0 - (1.0 - self.sigma_min) * t

    def _predict_model_output(
        self,
        *,
        x_t: Tensor,
        t: Tensor,
        condition_embeddings: Tensor,
    ) -> Tensor:
        """Run the flow network on a flat batch of states.

        Args:
            x_t: ``(n, motion_state_dim)`` current states.
            t: ``(n,)`` times.
            condition_embeddings: ``(n, hidden_dim)`` conditioning.

        Raises:
            ValueError: If any input has the wrong rank or batch size.
        """
        if x_t.ndim != 2:
            raise ValueError(f"x_t must be rank-2, got shape {tuple(x_t.shape)}")
        if x_t.shape[-1] != self.motion_state_dim:
            raise ValueError(
                f"x_t last dim must be {self.motion_state_dim}, got {tuple(x_t.shape)}"
            )
        if condition_embeddings.ndim != 2:
            raise ValueError(
                "condition_embeddings must be rank-2, "
                f"got shape {tuple(condition_embeddings.shape)}"
            )
        num_rows = x_t.shape[0]
        if condition_embeddings.shape[0] != num_rows:
            raise ValueError(
                "condition_embeddings batch dim must match x_t, "
                f"got {tuple(condition_embeddings.shape)} and {tuple(x_t.shape)}"
            )
        if t.ndim != 1 or t.shape[0] != num_rows:
            raise ValueError(
                "t must be a vector matching x_t batch size, "
                f"got {tuple(t.shape)} and {tuple(x_t.shape)}"
            )

        state_hidden = self.state_projector(x_t.to(dtype=self.state_projector.weight.dtype))
        time_hidden = self.time_embedder(t * self.time_scale)
        condition_hidden = condition_embeddings.to(dtype=state_hidden.dtype)

        # State, time and conditioning are fused by summation.
        hidden = self.input_norm(state_hidden + time_hidden + condition_hidden)
        for residual_block in (self.residual_block_1, self.residual_block_2):
            hidden = hidden + residual_block(hidden)
        return self.output_projector(self.output_norm(hidden))

    def _predicted_clean_state_to_velocity(
        self,
        *,
        predicted_clean_state: Tensor,
        x_t: Tensor,
        t: Tensor,
    ) -> Tensor:
        """Convert a clean-state prediction into the equivalent velocity.

        Computes ``(x_pred - (1 - sigma_min) * x_t) / noise_scale(t)`` with the
        denominator floored at the dtype's machine epsilon.
        """
        target_dtype = predicted_clean_state.dtype
        noise_scale = self._bridge_noise_scale(t).unsqueeze(-1).to(dtype=target_dtype)
        # With sigma_min == 0 the scale reaches zero at t == 1.
        noise_scale = noise_scale.clamp_min(torch.finfo(target_dtype).eps)
        numerator = predicted_clean_state - (1.0 - self.sigma_min) * x_t.to(dtype=target_dtype)
        return numerator / noise_scale

    def predict_velocity(
        self,
        *,
        x_t: Tensor,
        t: Tensor,
        condition_embeddings: Tensor,
    ) -> Tensor:
        """Return the velocity at ``(x_t, t)`` regardless of parameterisation."""
        model_output = self._predict_model_output(
            x_t=x_t,
            t=t,
            condition_embeddings=condition_embeddings,
        )
        if self.prediction_type != "v":
            return self._predicted_clean_state_to_velocity(
                predicted_clean_state=model_output,
                x_t=x_t,
                t=t,
            )
        return model_output

    def _predict_guided_velocity(
        self,
        *,
        x_t: Tensor,
        t: Tensor,
        condition_embeddings: Tensor,
        cfg_scale: float | None = None,
    ) -> Tensor:
        """Blend conditional and unconditional velocities.

        The unconditional branch uses an all-zero conditioning tensor. The
        result is ``uncond + scale * (cond - uncond)``.

        Raises:
            ValueError: If the guidance scale is negative or not finite.
        """
        guidance_scale = float(cfg_scale) if cfg_scale is not None else self.cfg_scale
        if guidance_scale < 0.0 or not math.isfinite(guidance_scale):
            raise ValueError(
                "cfg_scale must be a finite non-negative number, "
                f"got {guidance_scale!r}"
            )
        conditional = self.predict_velocity(
            x_t=x_t,
            t=t,
            condition_embeddings=condition_embeddings,
        )
        unconditional = self.predict_velocity(
            x_t=x_t,
            t=t,
            condition_embeddings=torch.zeros_like(condition_embeddings),
        )
        return unconditional + guidance_scale * (conditional - unconditional)

    def predict(
        self,
        link_latents: Tensor,
        joint_connections: Tensor,
        joint_valid_flag: Tensor,
        is_revolute: Tensor,
        is_prismatic: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Sample motion states for every active joint.

        Outputs are zero-initialised padded tensors; only joints that are valid
        and of the matching motion type are filled with un-normalised samples.

        Returns:
            ``(revolute_axis, prismatic_axis, revolute_range, prismatic_range)``
            with last dimensions 6, 6, 2 and 2.
        """
        parent_latents, child_latents = _gather_joint_link_latents(
            link_latents=link_latents,
            joint_connections=joint_connections,
        )
        batch_size, max_joints = joint_valid_flag.shape

        axes: Dict[str, Tensor] = {}
        ranges: Dict[str, Tensor] = {}
        for motion_type, type_flag in (
            ("revolute", is_revolute),
            ("prismatic", is_prismatic),
        ):
            axis_out = link_latents.new_zeros((batch_size, max_joints, 6))
            range_out = link_latents.new_zeros((batch_size, max_joints, 2))
            axes[motion_type] = axis_out
            ranges[motion_type] = range_out

            active_mask = joint_valid_flag & type_flag
            if not torch.any(active_mask):
                continue

            # Conditioning is computed for all joints, then restricted to the
            # active ones so the sampler works on a flat batch.
            condition_inputs = _build_joint_motion_condition_inputs(
                parent_latents=parent_latents,
                child_latents=child_latents,
                motion_type=motion_type,
                revolute_embedding=self.revolute_embedding,
                prismatic_embedding=self.prismatic_embedding,
            )
            condition_embeddings = self.condition_projector(condition_inputs)[active_mask]
            sampled_state = self._sample_motion_state(
                condition_embeddings=condition_embeddings,
            )
            sampled_state = self._unnormalize_motion_state(
                sampled_state,
                motion_type=motion_type,
            )
            axis_out[active_mask] = sampled_state[..., :6]
            range_out[active_mask] = sampled_state[..., 6:8]

        return (
            axes["revolute"],
            axes["prismatic"],
            ranges["revolute"],
            ranges["prismatic"],
        )
class JointDecoderOverParametrized(nn.Module):
    """Decodes per-query over-parameterized joint targets.

    For every query the decoder sees the query latent together with the latent
    of the link the query is assigned to and that link's parent, once with the
    revolute and once with the prismatic type embedding.
    """

    def __init__(
        self,
        model_dim: int,
        output_dim: int = 9,
        bias: bool = True,
    ):
        """Create the type embeddings and the shared decoder MLP.

        Args:
            model_dim: Latent feature size.
            output_dim: Values decoded per query; must be 9 or 12.
            bias: Whether the MLP layers carry biases.

        Raises:
            ValueError: If ``output_dim`` is neither 9 nor 12.
        """
        super().__init__()
        self.model_dim = model_dim
        self.output_dim = int(output_dim)
        if self.output_dim not in {9, 12}:
            raise ValueError(
                f"JointDecoderOverParametrized output_dim must be 9 or 12, got {output_dim}"
            )
        # Allocated uninitialised; filled in by reset_parameters() below.
        self.revolute_embedding = nn.Parameter(torch.empty(model_dim))
        self.prismatic_embedding = nn.Parameter(torch.empty(model_dim))
        self.joint_decoder = _make_silu_mlp(
            input_dim=model_dim * 4,
            hidden_dim=model_dim * 4,
            output_dim=self.output_dim,
            bias=bias,
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise both type embeddings from a truncated normal (std 0.02)."""
        nn.init.trunc_normal_(self.revolute_embedding, std=0.02)
        nn.init.trunc_normal_(self.prismatic_embedding, std=0.02)

    def forward(
        self,
        query_latents: Tensor,
        link_latents: Tensor,
        assigned_link_ids: Tensor,
        joint_connections: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Decode `(closest_axis_point, low_pose_point, high_pose_point)` per query.

        Args:
            query_latents: ``(batch, queries, dim)`` latents.
            link_latents: ``(batch, links, dim)`` latents.
            assigned_link_ids: ``(batch, queries)`` integer link index of each
                query.
            joint_connections: Integer tensor whose last axis holds
                ``(parent_id, child_id)``; joints with a negative child id are
                ignored.

        Returns:
            ``(revolute_output, prismatic_output)`` from the shared decoder.
        """
        batch_size, num_links = link_latents.shape[:2]
        latent_dim = link_latents.shape[-1]

        # Parent lookup table: each link starts as its own parent, then every
        # joint with a non-negative child id overwrites its child's entry.
        parent_link_ids = torch.arange(
            num_links,
            device=assigned_link_ids.device,
            dtype=assigned_link_ids.dtype,
        ).unsqueeze(0).repeat(batch_size, 1)  # [batch_size, num_links]
        joint_parent_ids = joint_connections[..., 0]
        joint_child_ids = joint_connections[..., 1]
        has_child = joint_child_ids >= 0
        batch_ids = torch.arange(
            batch_size,
            device=joint_connections.device,
        ).unsqueeze(1).expand_as(joint_child_ids)
        parent_link_ids[
            batch_ids[has_child],
            joint_child_ids[has_child],
        ] = joint_parent_ids[has_child]

        def _gather_links(link_ids: Tensor) -> Tensor:
            index = link_ids.unsqueeze(-1).expand(-1, -1, latent_dim)
            return link_latents.gather(dim=1, index=index)

        assigned_link_latents = _gather_links(assigned_link_ids)
        assigned_parent_link_latents = _gather_links(
            parent_link_ids.gather(dim=1, index=assigned_link_ids)
        )

        def _decode(type_embedding: Tensor) -> Tensor:
            broadcast_embedding = type_embedding.to(
                device=assigned_link_latents.device,
                dtype=assigned_link_latents.dtype,
            ).view(1, 1, -1).expand_as(assigned_link_latents)
            decoder_input = torch.cat(
                (
                    broadcast_embedding,
                    query_latents,
                    assigned_parent_link_latents,
                    assigned_link_latents,
                ),
                dim=-1,
            )
            return self.joint_decoder(decoder_input)

        revolute_output = _decode(self.revolute_embedding)
        prismatic_output = _decode(self.prismatic_embedding)
        return revolute_output, prismatic_output
class JointDecoderSingleDirection(nn.Module):
    """Predicts one axis direction per joint from parent/child link latents.

    The module owns one learned embedding per motion type (revolute and
    prismatic) and a SiLU MLP that maps a conditioning vector to
    ``direction_dim`` values.
    """

    direction_dim = 3

    def __init__(
        self,
        model_dim: int,
        bias: bool = True,
    ):
        """Creates the type embeddings and the direction MLP.

        Args:
            model_dim: Size of each motion-type embedding. The MLP input is
                ``3 * model_dim`` wide (presumably type embedding plus parent
                and child latents; the layout is built by
                `_build_joint_motion_condition_inputs`).
            bias: Forwarded to `_make_silu_mlp` for its linear layers.
        """
        super().__init__()
        self.model_dim = model_dim
        self.revolute_embedding = nn.Parameter(torch.empty(model_dim))
        self.prismatic_embedding = nn.Parameter(torch.empty(model_dim))
        self.decoder = _make_silu_mlp(
            input_dim=3 * model_dim,
            hidden_dim=4 * model_dim,
            output_dim=self.direction_dim,
            bias=bias,
        )
        # The embeddings are allocated uninitialized above; fill them now.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Draws both type embeddings from a truncated normal (std 0.02)."""
        for embedding in (self.revolute_embedding, self.prismatic_embedding):
            nn.init.trunc_normal_(embedding, std=0.02)

    def _predict_directions(
        self,
        *,
        parent_latents: Tensor,
        child_latents: Tensor,
        active_mask: Tensor,
        motion_type: str,
    ) -> Tensor:
        """Decodes directions for one motion type and zeroes inactive joints.

        Args:
            parent_latents: Latents of each joint's parent link.
            child_latents: Latents of each joint's child link.
            active_mask: Boolean mask; joints where it is False get an
                all-zero output.
            motion_type: Passed through to
                `_build_joint_motion_condition_inputs`.
        """
        condition_inputs = _build_joint_motion_condition_inputs(
            parent_latents=parent_latents,
            child_latents=child_latents,
            motion_type=motion_type,
            revolute_embedding=self.revolute_embedding,
            prismatic_embedding=self.prismatic_embedding,
        )
        directions = self.decoder(condition_inputs)
        inactive = ~active_mask.unsqueeze(-1)
        return directions.masked_fill(inactive, 0)

    def predict(
        self,
        *,
        link_latents: Tensor,
        joint_connections: Tensor,
        joint_valid_flag: Tensor,
        is_revolute: Tensor,
        is_prismatic: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Returns ``(revolute_directions, prismatic_directions)``.

        A joint contributes to the revolute output only where it is both
        valid and revolute, and to the prismatic output only where it is both
        valid and prismatic; every other entry is zero.
        """
        parent_latents, child_latents = _gather_joint_link_latents(
            link_latents=link_latents,
            joint_connections=joint_connections,
        )
        revolute_directions = self._predict_directions(
            parent_latents=parent_latents,
            child_latents=child_latents,
            active_mask=joint_valid_flag & is_revolute,
            motion_type="revolute",
        )
        prismatic_directions = self._predict_directions(
            parent_latents=parent_latents,
            child_latents=child_latents,
            active_mask=joint_valid_flag & is_prismatic,
            motion_type="prismatic",
        )
        return revolute_directions, prismatic_directions
class Particulate2Encoder(nn.Module):
"""Encoder backbone for articulated point-cloud understanding."""
def __init__(
    self,
    model_dim: int = 768,
    num_heads: int = 12,
    head_dim: Optional[int] = None,
    num_shape_latents: int = 256,
    num_attention_blocks: int = 6,
    coordinate_num_frequencies: int = 32,
    normal_num_frequencies: int = 32,
    coordinate_max_frequency: float = 64.0,
    normal_max_frequency: float = 16.0,
    link_text_feature_dim: int = 768,
    embedding_hidden_dim: Optional[int] = None,
    attn_dropout: float = 0.0,
    proj_dropout: float = 0.0,
    ffn_multiplier: float = 4.0,
    ffn_dropout: float = 0.0,
    qk_norm: bool = True,
    norm_eps: float = 1e-6,
    clip_model_name: str = "openai/clip-vit-large-patch14",
    compute_link_text_embeddings_on_the_fly: bool = False,
    use_text_conditioning: bool = True,
    clip_text_batch_size: int = 64,
    dropout_all_normals_prob: float = 0.3,
    dropout_all_point_prompts_prob: float = 0.0,
    dropout_individual_point_prompt_prob: float = 0.0,
    dropout_individual_text_conditioning_prob: float = 0.0,
    use_pretrained_features_shape: bool = False,
    use_pretrained_features_query: bool = False,
    use_pretrained_features_point_prompt: bool = False,
    pretrained_semantic_point_feature_dim: int = PARTFIELD_FEATURE_DIM,
):
    """Builds the embedders, projectors, and attention stacks of the encoder.

    Args:
        model_dim: Width of every latent token produced by this module.
        num_heads, head_dim, attn_dropout, proj_dropout, ffn_multiplier,
            ffn_dropout, qk_norm, norm_eps: Forwarded unchanged to every
            attention block and to the shape encoder (``norm_eps`` is also
            used for the RMSNorm layers created here).
        num_shape_latents: Forwarded to `ShapeLatentEncoder`.
        num_attention_blocks: Number of blocks in each of the four
            attention stacks.
        coordinate_num_frequencies, coordinate_max_frequency: Settings of
            the coordinate `FrequencyMLPEmbedder`.
        normal_num_frequencies, normal_max_frequency: Settings of the
            normal `FrequencyMLPEmbedder`.
        link_text_feature_dim: Width of the incoming text features.
        embedding_hidden_dim: Hidden width of the embedders; the text and
            pretrained-feature projectors fall back to ``model_dim`` when
            it is falsy.
        clip_model_name: Name handed to the CLIP loader when text is
            encoded on the fly.
        compute_link_text_embeddings_on_the_fly: Whether prompts may be
            encoded with CLIP inside the model.
        use_text_conditioning: Whether link text features are used.
        clip_text_batch_size: Batch size for on-the-fly text encoding.
        dropout_all_normals_prob, dropout_all_point_prompts_prob,
            dropout_individual_point_prompt_prob,
            dropout_individual_text_conditioning_prob: Training-time
            dropout probabilities; each must lie in ``[0, 1]``.
        use_pretrained_features_shape, use_pretrained_features_query,
            use_pretrained_features_point_prompt: Select which point sets
            receive projected pretrained features.
        pretrained_semantic_point_feature_dim: Width of those pretrained
            features.

    Raises:
        ValueError: If ``clip_text_batch_size`` is not positive, if any
            dropout probability lies outside ``[0, 1]``, or if
            ``use_text_conditioning`` is False while a point-prompt or
            text-conditioning dropout probability is non-zero.
    """
    super().__init__()

    def _checked_probability(name: str, value: float) -> float:
        # The dropout masks in this class are sampled as
        # `torch.rand(...) < p`, so `p` only makes sense inside [0, 1].
        probability = float(value)
        if not 0.0 <= probability <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {value}")
        return probability

    self.model_dim = model_dim
    self.link_text_feature_dim = link_text_feature_dim
    self.clip_model_name = clip_model_name
    self.compute_link_text_embeddings_on_the_fly = bool(compute_link_text_embeddings_on_the_fly)
    self.use_text_conditioning = bool(use_text_conditioning)
    if int(clip_text_batch_size) <= 0:
        raise ValueError(f"clip_text_batch_size must be positive, got {clip_text_batch_size}")
    self.clip_text_batch_size = int(clip_text_batch_size)

    self.dropout_all_normals_prob = _checked_probability(
        "dropout_all_normals_prob", dropout_all_normals_prob
    )
    # These two used to be stored unchecked, unlike the other dropout
    # probabilities; they are now normalized and validated the same way.
    self.dropout_all_point_prompts_prob = _checked_probability(
        "dropout_all_point_prompts_prob", dropout_all_point_prompts_prob
    )
    self.dropout_individual_point_prompt_prob = _checked_probability(
        "dropout_individual_point_prompt_prob", dropout_individual_point_prompt_prob
    )
    self.dropout_individual_text_conditioning_prob = _checked_probability(
        "dropout_individual_text_conditioning_prob",
        dropout_individual_text_conditioning_prob,
    )
    if not self.use_text_conditioning and (
        self.dropout_all_point_prompts_prob != 0.0
        or self.dropout_individual_point_prompt_prob != 0.0
        or self.dropout_individual_text_conditioning_prob != 0.0
    ):
        raise ValueError(
            "use_text_conditioning=False requires "
            "dropout_all_point_prompts_prob=0 and "
            "dropout_individual_point_prompt_prob=0 and "
            "dropout_individual_text_conditioning_prob=0 to avoid ambiguous "
            "text-free vs. point-prompt-dropped training examples"
        )

    self.use_pretrained_features_shape = bool(use_pretrained_features_shape)
    self.use_pretrained_features_query = bool(use_pretrained_features_query)
    self.use_pretrained_features_point_prompt = bool(use_pretrained_features_point_prompt)
    self.pretrained_semantic_point_feature_dim = int(pretrained_semantic_point_feature_dim)
    # Created lazily by `_get_partfield_feature_extractor`.
    self._partfield_feature_extractor: Optional[PartFieldFeatureExtractor] = None
    # Populated lazily by `_ensure_text_model_loaded`.
    self.text_tokenizer: Any = None
    self.text_model: Optional[nn.Module] = None

    # Shared constructor arguments for every attention block below.
    attn_block_kwargs = {
        "model_dim": model_dim,
        "num_heads": num_heads,
        "head_dim": head_dim,
        "attn_dropout": attn_dropout,
        "proj_dropout": proj_dropout,
        "ffn_multiplier": ffn_multiplier,
        "ffn_dropout": ffn_dropout,
        "qk_norm": qk_norm,
        "norm_eps": norm_eps,
    }

    self.coordinate_embedder = FrequencyMLPEmbedder(
        output_dim=model_dim,
        num_frequencies=coordinate_num_frequencies,
        input_dim=3,
        hidden_dim=embedding_hidden_dim,
        max_frequency=coordinate_max_frequency,
    )
    self.normal_embedder = FrequencyMLPEmbedder(
        output_dim=model_dim,
        num_frequencies=normal_num_frequencies,
        input_dim=3,
        hidden_dim=embedding_hidden_dim,
        max_frequency=normal_max_frequency,
    )
    self.shape_encoder = ShapeLatentEncoder(
        num_shape_latents=num_shape_latents,
        **attn_block_kwargs,
    )

    text_hidden_dim = embedding_hidden_dim or model_dim
    self.link_text_input_norm = nn.RMSNorm(link_text_feature_dim, eps=norm_eps)
    self.link_text_projector = _make_silu_mlp(
        input_dim=link_text_feature_dim,
        hidden_dim=text_hidden_dim,
        output_dim=model_dim,
    )

    # The pretrained-feature projection only exists when at least one
    # consumer is enabled; `_embed_point_tokens` checks for None.
    if (
        self.use_pretrained_features_shape
        or self.use_pretrained_features_query
        or self.use_pretrained_features_point_prompt
    ):
        self.pretrained_feature_input_norm = nn.RMSNorm(
            self.pretrained_semantic_point_feature_dim,
            eps=norm_eps,
        )
        self.pretrained_feature_projector = _make_silu_mlp(
            input_dim=self.pretrained_semantic_point_feature_dim,
            hidden_dim=text_hidden_dim,
            output_dim=model_dim,
        )
    else:
        self.pretrained_feature_input_norm = None
        self.pretrained_feature_projector = None

    self.link_to_shape_cross_attn = nn.ModuleList(
        [
            CrossAttentionBlock(**attn_block_kwargs)
            for _ in range(num_attention_blocks)
        ]
    )
    self.link_self_attn = nn.ModuleList(
        [
            SelfAttentionBlock(**attn_block_kwargs)
            for _ in range(num_attention_blocks)
        ]
    )
    self.query_to_shape_cross_attn = nn.ModuleList(
        [
            CrossAttentionBlock(**attn_block_kwargs)
            for _ in range(num_attention_blocks)
        ]
    )
    self.query_to_link_cross_attn = nn.ModuleList(
        [
            CrossAttentionBlock(**attn_block_kwargs)
            for _ in range(num_attention_blocks)
        ]
    )

    # Learned stand-ins used when a link has no point prompt or when its
    # text conditioning is dropped.
    self.no_point_prompt_embedding = nn.Parameter(torch.zeros(model_dim))
    self.no_text_conditioning_embedding = nn.Parameter(torch.zeros(model_dim))
    self.link_output_norm = nn.RMSNorm(model_dim, eps=norm_eps)
    self.query_output_norm = nn.RMSNorm(model_dim, eps=norm_eps)
def encode_shape(
    self,
    shape_points: Tensor,
    shape_point_normals: Tensor,
    pretrained_features: Tensor | None = None,
    drop_normal_mask: Tensor | None = None,
) -> Tensor:
    """Embeds the shape points and passes the tokens through the shape encoder.

    Args:
        shape_points: Shape point coordinates.
        shape_point_normals: Normals aligned with ``shape_points``.
        pretrained_features: Optional per-point pretrained features,
            forwarded to `_embed_point_tokens`.
        drop_normal_mask: Optional per-sample mask forwarded to
            `_embed_point_tokens`.

    Returns:
        The output of ``self.shape_encoder`` on the embedded point tokens.
    """
    return self.shape_encoder(
        self._embed_point_tokens(
            shape_points,
            shape_point_normals,
            pretrained_features=pretrained_features,
            drop_normal_mask=drop_normal_mask,
        )
    )
def encode_links(
    self,
    link_point_prompts: Tensor | None,
    link_point_prompt_normals: Tensor | None,
    link_valid_flag: Tensor,
    link_point_prompt_dropout_eligible: Tensor | None = None,
    forced_no_point_prompt_mask: Tensor | None = None,
    drop_normal_mask: Tensor | None = None,
    link_point_prompt_pretrained_features: Tensor | None = None,
    link_text_prompts: Optional[Sequence[Sequence[str]]] = None,
    link_text_embeddings: Tensor | None = None,
) -> Tensor:
    """Builds the initial link latents; padded link slots stay at zero.

    Each valid link latent is the sum of a point-prompt feature and a
    projected text feature. Links without a point prompt (prompt tensors
    omitted, prompt dropped during training, or flagged through
    `forced_no_point_prompt_mask`, which also applies in eval mode) use the
    learned `no_point_prompt_embedding` instead of point-derived features.
    With text conditioning disabled, the text branch projects an all-zero
    feature vector.

    Returns:
        Tensor of shape ``(batch_size, max_links, model_dim)`` where the
        first two sizes come from ``link_valid_flag``.
    """
    batch_size, max_links = link_valid_flag.shape

    prompt_missing_mask = self._resolve_no_point_prompt_mask(
        link_point_prompts=link_point_prompts,
        link_valid_flag=link_valid_flag,
        link_point_prompt_dropout_eligible=link_point_prompt_dropout_eligible,
        forced_no_point_prompt_mask=forced_no_point_prompt_mask,
    )
    point_features = self._resolve_valid_link_point_features(
        link_point_prompts=link_point_prompts,
        link_point_prompt_normals=link_point_prompt_normals,
        link_valid_flag=link_valid_flag,
        no_point_prompt_mask=prompt_missing_mask,
        drop_normal_mask=drop_normal_mask,
        link_point_prompt_pretrained_features=link_point_prompt_pretrained_features,
    )

    if self.use_text_conditioning:
        raw_text_features = self._resolve_valid_link_text_features(
            link_text_prompts=link_text_prompts,
            link_text_embeddings=link_text_embeddings,
            link_valid_flag=link_valid_flag,
        )
    else:
        # One all-zero text feature row per valid link.
        valid_link_count = int(link_valid_flag.sum().item())
        raw_text_features = torch.zeros(
            (valid_link_count, self.link_text_feature_dim),
            device=link_valid_flag.device,
            dtype=self.link_text_projector[0].weight.dtype,
        )
    text_features = self.link_text_projector(
        self.link_text_input_norm(raw_text_features)
    )

    if self.use_text_conditioning:
        text_dropout_mask = self._sample_text_conditioning_dropout_mask(
            link_valid_flag,
            no_point_prompt_mask=prompt_missing_mask,
        )
        # Rows are in valid-link order, so index the mask the same way.
        dropped_rows = text_dropout_mask[link_valid_flag].unsqueeze(-1)
        null_text_rows = self.no_text_conditioning_embedding.to(
            device=text_features.device,
            dtype=text_features.dtype,
        ).unsqueeze(0).expand_as(text_features)
        text_features = torch.where(dropped_rows, null_text_rows, text_features)

    if point_features.dtype != text_features.dtype:
        point_features = point_features.to(dtype=text_features.dtype)

    link_latents = text_features.new_zeros((batch_size, max_links, self.model_dim))
    link_latents[link_valid_flag] = point_features + text_features
    return link_latents
def _run_attention_blocks(
self,
shape_latents: Tensor,
link_latents: Tensor,
query_latents: Tensor,
link_valid_flag: Tensor,
) -> Tuple[Tensor, Tensor]:
for (
link_to_shape_cross_attn,
link_self_attn,
query_to_shape_cross_attn,
query_to_link_cross_attn,
) in zip(
self.link_to_shape_cross_attn,
self.link_self_attn,
self.query_to_shape_cross_attn,
self.query_to_link_cross_attn,
strict=True,
):
link_latents = link_to_shape_cross_attn(
link_latents,
shape_latents,
query_mask=link_valid_flag,
)
link_latents = link_self_attn(link_latents, mask=link_valid_flag)
query_latents = query_to_shape_cross_attn(query_latents, shape_latents)
query_latents = query_to_link_cross_attn(
query_latents,
link_latents,
context_mask=link_valid_flag,
)
# Residual blocks can write into padded rows unless they are zeroed again.
link_latents = self.link_output_norm(link_latents)
link_latents = link_latents.masked_fill(~link_valid_flag.unsqueeze(-1), 0)
query_latents = self.query_output_norm(query_latents)
return link_latents, query_latents
def forward(
    self,
    shape_points: Tensor | None = None,
    shape_point_normals: Tensor | None = None,
    query_points: Tensor | None = None,
    query_point_normals: Tensor | None = None,
    link_point_prompts: Tensor | None = None,
    link_point_prompt_normals: Tensor | None = None,
    link_valid_flag: Tensor | None = None,
    link_point_prompt_dropout_eligible: Tensor | None = None,
    link_text_prompts: Sequence[Sequence[str]] | None = None,
    link_text_embeddings: Tensor | None = None,
) -> Dict[str, Any]:
    """Encodes shape, link, and query inputs into latent tokens.

    Args:
        shape_points: Shape points with shape `(B, Ns, 3)`.
        shape_point_normals: Shape normals with shape `(B, Ns, 3)`.
        query_points: Query points with shape `(B, Nq, 3)`.
        query_point_normals: Query normals with shape `(B, Nq, 3)`.
        link_point_prompts: Optional link prompt points with shape
            `(B, L, 3)`. When omitted together with
            `link_point_prompt_normals`, valid link slots use the learned
            no-prompt embedding instead.
        link_point_prompt_normals: Optional link prompt normals aligned
            with `link_point_prompts`, with shape `(B, L, 3)`.
        link_valid_flag: Boolean tensor with shape `(B, L)` indicating
            which padded link slots are valid.
        link_point_prompt_dropout_eligible: Optional boolean tensor with
            shape `(B, L)` indicating which valid links may have their
            point prompt and prompt normal dropped during training.
        link_text_prompts: Per-link text prompts as a batch list of string
            lists, where sample `i` has length `sum(link_valid_flag[i])`.
        link_text_embeddings: Optional padded language features with shape
            `(B, L, D_lang)`.

    Returns:
        Dict with the keys ``"shape_latents"``, ``"query_latents"`` and
        ``"link_latents"``.

    Raises:
        ValueError: If only one of the prompt tensors is given, if a
            required shape/query tensor or ``link_valid_flag`` is missing,
            or if text conditioning is enabled but neither text prompts nor
            text embeddings are supplied.
    """
    has_prompts = link_point_prompts is not None
    has_prompt_normals = link_point_prompt_normals is not None
    if has_prompts != has_prompt_normals:
        raise ValueError(
            "link_point_prompts and link_point_prompt_normals must both be provided or both be None"
        )

    required_inputs = (
        shape_points,
        shape_point_normals,
        query_points,
        query_point_normals,
        link_valid_flag,
    )
    if any(value is None for value in required_inputs):
        raise ValueError(
            "forward requires shape/query tensors and link_valid_flag"
        )

    text_missing = link_text_prompts is None and link_text_embeddings is None
    if self.use_text_conditioning and text_missing:
        raise ValueError(
            "forward requires either link_text_prompts or link_text_embeddings "
            "when use_text_conditioning=True"
        )

    shape_features, query_features, prompt_features = (
        self._compute_pretrained_point_features(
            shape_points=shape_points,
            query_points=query_points,
            link_point_prompts=link_point_prompts,
        )
    )

    # One normal-dropout decision per sample, shared by shape, link, and
    # query embeddings.
    drop_normal_mask = self._sample_all_normals_dropout_mask(
        batch_size=int(shape_points.shape[0]),
        device=shape_points.device,
    )

    shape_latents = self.encode_shape(
        shape_points,
        shape_point_normals,
        pretrained_features=shape_features,
        drop_normal_mask=drop_normal_mask,
    )
    initial_link_latents = self.encode_links(
        link_point_prompts=link_point_prompts,
        link_point_prompt_normals=link_point_prompt_normals,
        link_valid_flag=link_valid_flag,
        link_point_prompt_dropout_eligible=link_point_prompt_dropout_eligible,
        drop_normal_mask=drop_normal_mask,
        link_point_prompt_pretrained_features=prompt_features,
        link_text_prompts=link_text_prompts,
        link_text_embeddings=link_text_embeddings,
    )
    initial_query_latents = self._embed_point_tokens(
        query_points,
        query_point_normals,
        pretrained_features=query_features,
        drop_normal_mask=drop_normal_mask,
    )

    link_latents, query_latents = self._run_attention_blocks(
        shape_latents=shape_latents,
        link_latents=initial_link_latents,
        query_latents=initial_query_latents,
        link_valid_flag=link_valid_flag,
    )
    return {
        "shape_latents": shape_latents,
        "query_latents": query_latents,
        "link_latents": link_latents,
    }
def _embed_point_tokens(
    self,
    points: Tensor,
    normals: Tensor,
    *,
    pretrained_features: Tensor | None = None,
    drop_normal_mask: Tensor | None = None,
) -> Tensor:
    """Embeds points as the sum of coordinate, normal, and optional feature tokens.

    Args:
        points: Point coordinates; the first dimension is the batch.
        normals: Normals aligned with ``points``.
        pretrained_features: Optional per-point features. When given they
            are normalized, projected, and added to the tokens.
        drop_normal_mask: Optional mask of shape ``(B,)``. Samples where it
            is true have their normal tokens replaced by zeros.

    Raises:
        ValueError: If ``drop_normal_mask`` is not 1-D with one entry per
            sample, or if ``pretrained_features`` is given while the
            pretrained-feature projection was not built.
    """
    coordinate_tokens = self.coordinate_embedder(points)
    normal_tokens = self.normal_embedder(normals)

    if drop_normal_mask is not None:
        if drop_normal_mask.ndim != 1 or drop_normal_mask.shape[0] != points.shape[0]:
            raise ValueError(
                "drop_normal_mask must have shape (B,), "
                f"got {tuple(drop_normal_mask.shape)} for points {tuple(points.shape)}"
            )
        per_sample_drop = drop_normal_mask.to(
            device=normal_tokens.device,
            dtype=torch.bool,
        )
        # Reshape to (B, 1, ..., 1) so the mask broadcasts over every token.
        trailing_ones = (1,) * (normal_tokens.ndim - 1)
        broadcast_shape = (per_sample_drop.shape[0],) + trailing_ones
        normal_tokens = normal_tokens.masked_fill(per_sample_drop.view(broadcast_shape), 0)

    if coordinate_tokens.dtype != normal_tokens.dtype:
        coordinate_tokens = coordinate_tokens.to(dtype=normal_tokens.dtype)
    tokens = coordinate_tokens + normal_tokens

    if pretrained_features is None:
        return tokens

    feature_norm = self.pretrained_feature_input_norm
    feature_projector = self.pretrained_feature_projector
    if feature_norm is None or feature_projector is None:
        raise ValueError(
            "Received pretrained_features but pretrained feature projection is disabled"
        )
    # Match the projector's parameter dtype before normalizing.
    cast_features = pretrained_features.to(dtype=feature_projector[0].weight.dtype)
    projected_features = feature_projector(feature_norm(cast_features))
    if tokens.dtype != projected_features.dtype:
        tokens = tokens.to(dtype=projected_features.dtype)
    return tokens + projected_features
def _get_partfield_feature_extractor(self) -> PartFieldFeatureExtractor:
    """Returns the cached feature extractor, creating it on first use."""
    extractor = self._partfield_feature_extractor
    if extractor is None:
        # Built lazily so the extractor is only created when first requested.
        extractor = PartFieldFeatureExtractor()
        self._partfield_feature_extractor = extractor
    return extractor
def _compute_pretrained_point_features(
    self,
    *,
    shape_points: Tensor,
    query_points: Tensor,
    link_point_prompts: Tensor | None = None,
) -> tuple[Tensor | None, Tensor | None, Tensor | None]:
    """Extracts pretrained features for the enabled point sets in one call.

    Query points and link prompts (when enabled) are concatenated along
    dimension 1, decoded together by the extractor, and split back by their
    point counts.

    Returns:
        ``(shape_features, query_features, prompt_features)``. Entries for
        disabled consumers are ``None``; all three are ``None`` when no
        consumer is enabled.

    Raises:
        ValueError: If prompt features are needed and the batch dimension of
            ``link_point_prompts`` differs from that of ``shape_points``.
    """
    wants_prompt_features = (
        self.use_pretrained_features_point_prompt
        and link_point_prompts is not None
    )
    any_consumer_enabled = (
        self.use_pretrained_features_shape
        or self.use_pretrained_features_query
        or wants_prompt_features
    )
    if not any_consumer_enabled:
        return None, None, None

    # Parallel lists describing the concatenated decode request, in order.
    segment_names: list[str] = []
    segment_points: list[Tensor] = []
    if self.use_pretrained_features_query:
        segment_names.append("query")
        segment_points.append(query_points)
    if wants_prompt_features:
        if link_point_prompts.shape[0] != shape_points.shape[0]:
            raise ValueError(
                "link_point_prompts batch dimension must match shape_points when "
                "use_pretrained_features_point_prompt=True, "
                f"got {tuple(link_point_prompts.shape)} and {tuple(shape_points.shape)}"
            )
        segment_names.append("prompt")
        segment_points.append(link_point_prompts)

    decode_query_points = torch.cat(segment_points, dim=1) if segment_points else None

    shape_features, combined_features = self._get_partfield_feature_extractor().extract(
        encode_points=shape_points,
        decode_shape_points=shape_points if self.use_pretrained_features_shape else None,
        decode_query_points=decode_query_points,
    )

    features_by_name: dict[str, Tensor] = {}
    if combined_features is not None:
        if len(segment_names) == 1:
            # A single segment owns the whole decoded tensor; no slicing.
            features_by_name[segment_names[0]] = combined_features
        else:
            start = 0
            for name, points in zip(segment_names, segment_points):
                stop = start + points.shape[1]
                features_by_name[name] = combined_features[:, start:stop]
                start = stop
    return (
        shape_features,
        features_by_name.get("query"),
        features_by_name.get("prompt"),
    )
def _sample_point_prompt_dropout_mask(
    self,
    link_valid_flag: Tensor,
    link_point_prompt_dropout_eligible: Tensor | None = None,
) -> Tensor:
    """Samples which links lose their point prompt for this training step.

    Two mechanisms are combined: with probability
    ``dropout_all_point_prompts_prob`` a sample drops every eligible prompt;
    in the remaining samples each eligible prompt is dropped independently
    with probability ``dropout_individual_point_prompt_prob``. Only valid
    links (further restricted by ``link_point_prompt_dropout_eligible`` when
    given) can be dropped. Outside training, or when both probabilities are
    zero, the mask is all zeros.

    Raises:
        ValueError: If the eligibility tensor's shape differs from
            ``link_valid_flag``.
    """
    if not self.training:
        return torch.zeros_like(link_valid_flag)
    if (
        self.dropout_all_point_prompts_prob == 0.0
        and self.dropout_individual_point_prompt_prob == 0.0
    ):
        return torch.zeros_like(link_valid_flag)

    num_samples = link_valid_flag.shape[0]
    device = link_valid_flag.device

    eligible = link_valid_flag
    if link_point_prompt_dropout_eligible is not None:
        if link_point_prompt_dropout_eligible.shape != link_valid_flag.shape:
            raise ValueError(
                "link_point_prompt_dropout_eligible must match link_valid_flag shape, "
                f"got {tuple(link_point_prompt_dropout_eligible.shape)} and "
                f"{tuple(link_valid_flag.shape)}"
            )
        eligible = link_valid_flag & link_point_prompt_dropout_eligible

    drop_mask = torch.zeros_like(link_valid_flag)
    whole_sample_drop = torch.zeros(num_samples, device=device, dtype=torch.bool)

    if self.dropout_all_point_prompts_prob > 0.0:
        whole_sample_drop = (
            torch.rand(num_samples, device=device) < self.dropout_all_point_prompts_prob
        )
        drop_mask[whole_sample_drop] = eligible[whole_sample_drop]

    if self.dropout_individual_point_prompt_prob > 0.0:
        per_link_drop = (
            torch.rand(link_valid_flag.shape, device=device)
            < self.dropout_individual_point_prompt_prob
        )
        # Per-link dropout only applies to samples not already fully dropped.
        partial_samples = ~whole_sample_drop
        drop_mask[partial_samples] = (
            per_link_drop[partial_samples] & eligible[partial_samples]
        )
    return drop_mask
def _sample_all_normals_dropout_mask(
    self,
    *,
    batch_size: int,
    device: torch.device,
) -> Tensor:
    """Samples, per batch element, whether all of its normals are dropped.

    Returns:
        Boolean tensor of shape ``(batch_size,)``. It is all False outside
        training or when ``dropout_all_normals_prob`` is zero; otherwise each
        entry is true with probability ``dropout_all_normals_prob``.
    """
    if self.training and self.dropout_all_normals_prob != 0.0:
        uniform_draws = torch.rand(batch_size, device=device)
        return uniform_draws < self.dropout_all_normals_prob
    return torch.zeros((batch_size,), device=device, dtype=torch.bool)
def _resolve_no_point_prompt_mask(
    self,
    *,
    link_point_prompts: Tensor | None,
    link_valid_flag: Tensor,
    link_point_prompt_dropout_eligible: Tensor | None = None,
    forced_no_point_prompt_mask: Tensor | None = None,
) -> Tensor:
    """Determines which valid links must use the no-prompt embedding.

    The result is the union of (a) every valid link when no prompt tensor is
    given, otherwise the sampled training-time dropout mask, and (b) the
    caller-forced mask, all restricted to valid links.

    Raises:
        ValueError: If ``forced_no_point_prompt_mask`` does not have the same
            shape as ``link_valid_flag``.
    """
    forced_mask = None
    if forced_no_point_prompt_mask is not None:
        if forced_no_point_prompt_mask.shape != link_valid_flag.shape:
            raise ValueError(
                "forced_no_point_prompt_mask must match link_valid_flag shape, "
                f"got {tuple(forced_no_point_prompt_mask.shape)} and "
                f"{tuple(link_valid_flag.shape)}"
            )
        forced_mask = forced_no_point_prompt_mask.to(
            device=link_valid_flag.device,
            dtype=torch.bool,
        )
        forced_mask = forced_mask & link_valid_flag

    if link_point_prompts is not None:
        resolved_mask = self._sample_point_prompt_dropout_mask(
            link_valid_flag,
            link_point_prompt_dropout_eligible=link_point_prompt_dropout_eligible,
        )
    else:
        # Without prompt tensors every valid link is prompt-free.
        resolved_mask = link_valid_flag.clone()

    if forced_mask is not None:
        resolved_mask = resolved_mask | forced_mask
    return resolved_mask & link_valid_flag
def _sample_text_conditioning_dropout_mask(
    self,
    link_valid_flag: Tensor,
    *,
    no_point_prompt_mask: Tensor | None = None,
) -> Tensor:
    """Samples which links have their text conditioning dropped.

    Each eligible link is dropped independently with probability
    ``dropout_individual_text_conditioning_prob``. A link is eligible when it
    is valid and, if ``no_point_prompt_mask`` is given, not marked there, so
    a link that already lost its point prompt keeps its text. Outside
    training or with a zero probability the mask is all zeros.

    Raises:
        ValueError: If ``no_point_prompt_mask`` does not have the same shape
            as ``link_valid_flag``.
    """
    dropout_active = (
        self.training
        and self.dropout_individual_text_conditioning_prob != 0.0
    )
    if not dropout_active:
        return torch.zeros_like(link_valid_flag)

    eligible = link_valid_flag
    if no_point_prompt_mask is not None:
        if no_point_prompt_mask.shape != link_valid_flag.shape:
            raise ValueError(
                "no_point_prompt_mask must match link_valid_flag shape, "
                f"got {tuple(no_point_prompt_mask.shape)} and "
                f"{tuple(link_valid_flag.shape)}"
            )
        prompt_free = no_point_prompt_mask.to(
            device=link_valid_flag.device,
            dtype=torch.bool,
        )
        eligible = link_valid_flag & ~prompt_free

    sampled_drop = (
        torch.rand(link_valid_flag.shape, device=link_valid_flag.device)
        < self.dropout_individual_text_conditioning_prob
    )
    return sampled_drop & eligible
def _resolve_valid_link_point_features(
    self,
    link_point_prompts: Tensor | None,
    link_point_prompt_normals: Tensor | None,
    link_valid_flag: Tensor,
    no_point_prompt_mask: Tensor | None = None,
    drop_normal_mask: Tensor | None = None,
    link_point_prompt_pretrained_features: Tensor | None = None,
) -> Tensor:
    """Returns one point-prompt feature row per valid link.

    Rows follow the order in which ``link_valid_flag`` selects entries. When
    no prompts are given, every row is the learned no-prompt embedding.
    Otherwise rows come from `_embed_point_tokens`, except for links marked
    in ``no_point_prompt_mask``, which also receive the no-prompt embedding.

    Raises:
        ValueError: If ``no_point_prompt_mask`` has a different shape than
            ``link_valid_flag``, or if pretrained prompt features are given
            without prompts.
    """
    if no_point_prompt_mask is not None:
        if no_point_prompt_mask.shape != link_valid_flag.shape:
            raise ValueError(
                "no_point_prompt_mask must match link_valid_flag shape, "
                f"got {tuple(no_point_prompt_mask.shape)} and "
                f"{tuple(link_valid_flag.shape)}"
            )
        no_point_prompt_mask = link_valid_flag & no_point_prompt_mask.to(
            device=link_valid_flag.device,
            dtype=torch.bool,
        )

    if link_point_prompts is None:
        if link_point_prompt_pretrained_features is not None:
            raise ValueError(
                "link_point_prompt_pretrained_features requires link_point_prompts"
            )
        fallback = self.no_point_prompt_embedding.to(
            device=link_valid_flag.device,
            dtype=self.no_point_prompt_embedding.dtype,
        )
        # Tile the embedding over the (B, L) grid, then keep valid slots.
        tiled = fallback.view(1, 1, -1).expand(*link_valid_flag.shape, -1)
        return tiled[link_valid_flag]

    embedded_prompts = self._embed_point_tokens(
        link_point_prompts,
        link_point_prompt_normals,
        pretrained_features=link_point_prompt_pretrained_features,
        drop_normal_mask=drop_normal_mask,
    )
    valid_prompt_features = embedded_prompts[link_valid_flag]

    if no_point_prompt_mask is None:
        no_point_prompt_mask = torch.zeros_like(link_valid_flag)
    use_fallback = no_point_prompt_mask[link_valid_flag].unsqueeze(-1)
    fallback_rows = self.no_point_prompt_embedding.to(
        device=valid_prompt_features.device,
        dtype=valid_prompt_features.dtype,
    ).unsqueeze(0).expand_as(valid_prompt_features)
    return torch.where(use_fallback, fallback_rows, valid_prompt_features)
def _resolve_valid_link_text_features(
    self,
    link_text_prompts: Optional[Sequence[Sequence[str]]],
    link_text_embeddings: Tensor | None,
    link_valid_flag: Tensor,
) -> Tensor:
    """Returns text features flattened in the same order as `link_valid_flag`.

    Precomputed ``link_text_embeddings`` take precedence and are simply
    indexed by ``link_valid_flag``. Otherwise the per-sample prompt lists are
    flattened and encoded; with no prompts at all, one empty string per valid
    link is encoded instead.

    Args:
        link_text_prompts: Optional batch of per-sample prompt lists.
        link_text_embeddings: Optional padded text features indexed by
            ``link_valid_flag``.
        link_valid_flag: Boolean mask selecting the valid link slots.

    Returns:
        One feature row per valid link, in the dtype of the first layer of
        ``self.link_text_projector``.

    Raises:
        ValueError: If the total number of supplied prompts differs from the
            number of valid links.
    """
    projector_dtype = self.link_text_projector[0].weight.dtype
    if link_text_embeddings is not None:
        return link_text_embeddings[link_valid_flag].to(dtype=projector_dtype)

    num_valid_links = int(link_valid_flag.sum().item())
    if link_text_prompts is None:
        flattened_prompts = [""] * num_valid_links
    else:
        flattened_prompts = list(chain.from_iterable(link_text_prompts))
        # The caller adds these rows to one point feature row per valid link.
        # A count mismatch would either fail there with an opaque shape
        # error or, for a single prompt, broadcast silently over all links.
        if len(flattened_prompts) != num_valid_links:
            raise ValueError(
                "link_text_prompts must contain exactly one prompt per valid link, "
                f"got {len(flattened_prompts)} prompts for {num_valid_links} valid links"
            )
    return self._encode_link_text_prompts(
        flattened_prompts,
        dtype=projector_dtype,
    )
def _ensure_text_model_loaded(self) -> None:
    """Loads the CLIP tokenizer and text model unless both are already set.

    The encoder is placed on the device of the first layer of
    ``self.link_text_projector``; the cache directory is read from the
    ``HF_HOME`` environment variable (``None`` when unset).
    """
    if self.text_model is None or self.text_tokenizer is None:
        self.text_tokenizer, self.text_model = load_clip_text_encoder(
            self.clip_model_name,
            device=self.link_text_projector[0].weight.device,
            expected_embedding_dim=self.link_text_feature_dim,
            cache_dir=os.environ.get("HF_HOME"),
        )
@torch.no_grad()
def _encode_link_text_prompts(
    self,
    prompts: Sequence[str],
    dtype: torch.dtype,
) -> Tensor:
    """Encodes text prompts with the CLIP text encoder, without gradients.

    Args:
        prompts: Flat sequence of prompt strings.
        dtype: Dtype of the returned features.

    Returns:
        One feature row per prompt on the device of the first layer of
        ``self.link_text_projector``. An empty prompt list yields an empty
        ``(0, link_text_feature_dim)`` tensor without loading the encoder.

    Raises:
        ValueError: If prompts are given but on-the-fly text encoding is
            disabled.
    """
    if len(prompts) == 0:
        return torch.zeros(
            (0, self.link_text_feature_dim),
            device=self.link_text_projector[0].weight.device,
            dtype=dtype,
        )

    if not self.compute_link_text_embeddings_on_the_fly:
        raise ValueError(
            "Missing link_text_embeddings but compute_link_text_embeddings_on_the_fly=False. "
            "Provide dataset-side text embeddings or enable on-the-fly CLIP text encoding."
        )

    # The text encoder is loaded lazily, the first time it is needed.
    self._ensure_text_model_loaded()
    encode_kwargs = {
        "tokenizer": self.text_tokenizer,
        "text_model": self.text_model,
        "batch_size": self.clip_text_batch_size,
        "output_device": self.link_text_projector[0].weight.device,
        "output_dtype": dtype,
    }
    return encode_clip_text_prompts(prompts, **encode_kwargs)
class Particulate2ArticulationModel(nn.Module):
"""Encoder, decoders, and training losses for articulation prediction."""
def __init__(
    self,
    encoder: Optional[Particulate2Encoder] = None,
    *,
    segmentation_decode_type: str = "mlp",
    segmentation_query_chunk_size: Optional[int] = None,
    joint_decode_type: str = "overparameterization",
    joint_fm_hidden_dim: Optional[int] = None,
    joint_fm_prediction_type: str = "v",
    joint_fm_time_embedding_dim: int = 256,
    joint_fm_inference_steps: int = 100,
    joint_fm_time_scale: float = 1000.0,
    joint_fm_sigma_min: float = 0.0,
    joint_fm_rescale_t: float = 1.0,
    joint_fm_cfg_scale: float = 1.0,
    revolute_joint_fm_state_mean: Optional[Sequence[float]] = None,
    revolute_joint_fm_state_std: Optional[Sequence[float]] = None,
    prismatic_joint_fm_state_mean: Optional[Sequence[float]] = None,
    prismatic_joint_fm_state_std: Optional[Sequence[float]] = None,
    joint_fm_training_time_mean: float = 0.0,
    joint_fm_training_time_std: float = 1.0,
    use_ancestor_context_for_segmentation: bool = False,
    ancestor_context_decay: float = 0.5,
    segmentation_bias: bool = True,
    joint_decoder_bias: bool = True,
    segmentation_cross_entropy_weight: float = 1.0,
    segmentation_dice_weight: float = 1.0,
    revolute_joint_axis_l1_weight: float = 1.0,
    prismatic_joint_axis_l1_weight: float = 1.0,
    revolute_joint_range_l1_weight: float = 1.0,
    prismatic_joint_range_l1_weight: float = 1.0,
    revolute_joint_fm_weight: float = 1.0,
    prismatic_joint_fm_weight: float = 1.0,
    revolute_overparam_axis_l1_weight: float = 1.0,
    revolute_overparam_point_l1_weight: float = 1.0,
    prismatic_overparam_axis_l1_weight: float = 1.0,
    prismatic_overparam_point_l1_weight: float = 1.0,
    revolute_overparam_direction_weight: float = 1.0,
    prismatic_overparam_direction_weight: float = 1.0,
    overparam_closest_axis_space: str = "world",
    dice_smoothing: float = 1e-6,
    **encoder_kwargs: Any,
):
    """Assembles the encoder, the decoders, and the loss configuration.

    Args:
        encoder: Ready-made encoder. When omitted, a `Particulate2Encoder`
            is built from ``encoder_kwargs``.
        segmentation_decode_type, segmentation_query_chunk_size,
            segmentation_bias: Forwarded to `SegmentationDecoder`.
        joint_decode_type: Joint decoding mode; aliases are canonicalized by
            `_normalize_joint_decode_type` before validation.
        joint_fm_* and *_joint_fm_state_*: Forwarded to
            `JointDecoderPlainFlowMatching` in ``"plain+fm"`` mode, except
            ``joint_fm_training_time_mean`` / ``joint_fm_training_time_std``,
            which are stored on the model as floats.
        use_ancestor_context_for_segmentation: Creates the ancestor-context
            projector and gate when true.
        ancestor_context_decay: Finite value in ``[0, 1]``.
        joint_decoder_bias: ``bias`` argument of the joint decoders.
        *_weight, dice_smoothing: Stored on the model as floats.
        overparam_closest_axis_space: Canonicalized by
            `_normalize_overparam_closest_axis_space`; must resolve to
            ``"world"`` or ``"part_aabb"``.
        **encoder_kwargs: Encoder constructor arguments; only allowed when
            ``encoder`` is None.

    Raises:
        ValueError: If both an encoder and encoder kwargs are given, if the
            joint decode type or closest-axis space is unsupported, or if
            ``ancestor_context_decay`` is not a finite value in ``[0, 1]``.
    """
    super().__init__()
    if encoder is not None and encoder_kwargs:
        raise ValueError("Pass either an encoder instance or encoder kwargs, not both")
    if encoder is None:
        encoder = Particulate2Encoder(**encoder_kwargs)
    self.encoder = encoder
    model_dim = self.encoder.model_dim

    # --- joint decoding mode -------------------------------------------
    joint_decode_type = _normalize_joint_decode_type(joint_decode_type)
    supported_joint_decode_types = (
        _PLAIN_JOINT_DECODE_TYPES | _OVERPARAM_JOINT_DECODE_TYPES
    )
    if joint_decode_type not in supported_joint_decode_types:
        raise ValueError(
            "joint_decode_type must be 'plain', 'plain+fm', 'overparametrized', "
            "'overparameterized', 'overparameterization', 'overparam+dir', or "
            "'overparam+singledir', "
            f"got {joint_decode_type!r}"
        )
    self.joint_decode_type = joint_decode_type
    self.plain_flow_matching_enabled = joint_decode_type == "plain+fm"
    self.overparam_predicts_query_axis_direction = joint_decode_type == "overparam+dir"
    self.overparam_predicts_single_axis_direction = (
        joint_decode_type == "overparam+singledir"
    )
    self.overparam_uses_axis_direction = (
        self.overparam_predicts_query_axis_direction
        or self.overparam_predicts_single_axis_direction
    )

    # --- closest-axis coordinate space ---------------------------------
    overparam_closest_axis_space = _normalize_overparam_closest_axis_space(
        overparam_closest_axis_space
    )
    if overparam_closest_axis_space not in {"world", "part_aabb"}:
        raise ValueError(
            "overparam_closest_axis_space must be 'world', 'sample', 'part_aabb', or "
            f"'local_aabb', got {overparam_closest_axis_space!r}"
        )
    self.overparam_closest_axis_space = overparam_closest_axis_space
    self.overparam_closest_axis_uses_part_aabb = (
        overparam_closest_axis_space == "part_aabb"
    )

    # --- ancestor context settings -------------------------------------
    self.use_ancestor_context_for_segmentation = bool(
        use_ancestor_context_for_segmentation
    )
    self.ancestor_context_decay = float(ancestor_context_decay)
    if not math.isfinite(self.ancestor_context_decay):
        raise ValueError(
            "ancestor_context_decay must be finite, "
            f"got {ancestor_context_decay!r}"
        )
    if self.ancestor_context_decay < 0.0 or self.ancestor_context_decay > 1.0:
        raise ValueError(
            "ancestor_context_decay must be in [0, 1], "
            f"got {ancestor_context_decay!r}"
        )

    # --- decoders ------------------------------------------------------
    self.segmentation_decoder = SegmentationDecoder(
        model_dim=model_dim,
        decode_type=segmentation_decode_type,
        bias=segmentation_bias,
        query_chunk_size=segmentation_query_chunk_size,
    )
    # Only replaced by a module in "overparam+singledir" mode.
    self.joint_direction_decoder: JointDecoderSingleDirection | None = None
    if self.joint_decode_type == "plain":
        self.joint_decoder = JointDecoderPlain(
            model_dim=model_dim,
            bias=joint_decoder_bias,
        )
    elif self.plain_flow_matching_enabled:
        self.joint_decoder = JointDecoderPlainFlowMatching(
            model_dim=model_dim,
            hidden_dim=joint_fm_hidden_dim,
            prediction_type=joint_fm_prediction_type,
            time_embedding_dim=joint_fm_time_embedding_dim,
            inference_steps=joint_fm_inference_steps,
            time_scale=joint_fm_time_scale,
            sigma_min=joint_fm_sigma_min,
            rescale_t=joint_fm_rescale_t,
            cfg_scale=joint_fm_cfg_scale,
            revolute_state_mean=revolute_joint_fm_state_mean,
            revolute_state_std=revolute_joint_fm_state_std,
            prismatic_state_mean=prismatic_joint_fm_state_mean,
            prismatic_state_std=prismatic_joint_fm_state_std,
            bias=joint_decoder_bias,
        )
    else:
        # Per-query axis directions add three outputs to the nine base ones.
        overparam_output_dim = 12 if self.overparam_predicts_query_axis_direction else 9
        self.joint_decoder = JointDecoderOverParametrized(
            model_dim=model_dim,
            output_dim=overparam_output_dim,
            bias=joint_decoder_bias,
        )
        if self.overparam_predicts_single_axis_direction:
            self.joint_direction_decoder = JointDecoderSingleDirection(
                model_dim=model_dim,
                bias=joint_decoder_bias,
            )

    if self.use_ancestor_context_for_segmentation:
        self.ancestor_context_projector = nn.Linear(
            model_dim,
            model_dim,
            bias=False,
        )
        self.ancestor_context_gate = _make_silu_mlp(
            input_dim=model_dim * 2,
            hidden_dim=model_dim,
            output_dim=1,
        )
    else:
        self.ancestor_context_projector = None
        self.ancestor_context_gate = None

    # --- scalar loss / training hyper-parameters, stored as floats ------
    float_settings = {
        "segmentation_cross_entropy_weight": segmentation_cross_entropy_weight,
        "segmentation_dice_weight": segmentation_dice_weight,
        "revolute_joint_axis_l1_weight": revolute_joint_axis_l1_weight,
        "prismatic_joint_axis_l1_weight": prismatic_joint_axis_l1_weight,
        "revolute_joint_range_l1_weight": revolute_joint_range_l1_weight,
        "prismatic_joint_range_l1_weight": prismatic_joint_range_l1_weight,
        "revolute_joint_fm_weight": revolute_joint_fm_weight,
        "prismatic_joint_fm_weight": prismatic_joint_fm_weight,
        "revolute_overparam_axis_l1_weight": revolute_overparam_axis_l1_weight,
        "revolute_overparam_point_l1_weight": revolute_overparam_point_l1_weight,
        "prismatic_overparam_axis_l1_weight": prismatic_overparam_axis_l1_weight,
        "prismatic_overparam_point_l1_weight": prismatic_overparam_point_l1_weight,
        "revolute_overparam_direction_weight": revolute_overparam_direction_weight,
        "prismatic_overparam_direction_weight": prismatic_overparam_direction_weight,
        "joint_fm_training_time_mean": joint_fm_training_time_mean,
        "joint_fm_training_time_std": joint_fm_training_time_std,
        "dice_smoothing": dice_smoothing,
    }
    for attribute_name, raw_value in float_settings.items():
        setattr(self, attribute_name, float(raw_value))
def decode_segmentation(
    self,
    query_latents: Tensor,
    link_latents: Tensor,
    link_valid_flag: Tensor,
) -> Tensor:
    """Run the segmentation decoder on query and link latents.

    The three arguments are forwarded by keyword to
    ``self.segmentation_decoder`` and its result is returned unchanged.
    """
    decoder_inputs = {
        "query_latents": query_latents,
        "link_latents": link_latents,
        "link_valid_flag": link_valid_flag,
    }
    return self.segmentation_decoder(**decoder_inputs)
def decode_joint_parameters(
    self,
    *,
    link_latents: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    is_revolute: Tensor,
    is_prismatic: Tensor,
    query_latents: Tensor | None = None,
    query_points: Tensor | None = None,
    assigned_link_ids: Tensor | None = None,
    decoded_motion_points: Tuple[Tensor, Tensor] | None = None,
    decoded_axis_directions: Tuple[Tensor, Tensor] | None = None,
    decoded_motion_points_are_world: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Decode revolute/prismatic axes and ranges for the configured decode mode.

    * ``"plain"`` and plain flow-matching modes: ``self.joint_decoder.predict``
      is called on the link latents and joint layout and its result returned.
    * Over-parameterized modes: query-wise motion targets are obtained
      (decoded here unless passed in), converted to world coordinates unless
      ``decoded_motion_points_are_world`` is set, and one axis/range per joint
      is recovered from them with autocast disabled.

    Returns:
        ``(revolute_axis, prismatic_axis, revolute_range, prismatic_range)``.

    Raises:
        ValueError: in over-parameterized mode when ``query_points`` or
            ``assigned_link_ids`` is missing, or when neither
            ``decoded_motion_points`` nor ``query_latents`` is given.
    """
    joint_layout = {
        "joint_connections": joint_connections,
        "joint_valid_flag": joint_valid_flag,
        "is_revolute": is_revolute,
        "is_prismatic": is_prismatic,
    }

    # Both plain variants delegate straight to the joint decoder.
    if self.joint_decode_type == "plain" or self.plain_flow_matching_enabled:
        return self.joint_decoder.predict(link_latents=link_latents, **joint_layout)

    if query_points is None or assigned_link_ids is None:
        raise ValueError(
            "over-parameterized joint decoding requires query_points and assigned_link_ids"
        )

    if decoded_motion_points is None:
        if query_latents is None:
            raise ValueError(
                "over-parameterized joint decoding requires query_latents when decoded_motion_points are not provided"
            )
        decoded_motion_points = self._decode_joint_motion_points(
            query_latents=query_latents,
            link_latents=link_latents,
            assigned_link_ids=assigned_link_ids,
            joint_connections=joint_connections,
        )

    if self.overparam_predicts_single_axis_direction and decoded_axis_directions is None:
        decoded_axis_directions = self._decode_joint_axis_directions(
            link_latents=link_latents,
            **joint_layout,
        )

    revolute_points = decoded_motion_points[0]
    prismatic_points = decoded_motion_points[1]
    if not decoded_motion_points_are_world:
        revolute_points = self._convert_overparam_motion_points_to_world_coordinates(
            motion_points=revolute_points,
            query_points=query_points,
            assigned_link_ids=assigned_link_ids,
        )
        prismatic_points = self._convert_overparam_motion_points_to_world_coordinates(
            motion_points=prismatic_points,
            query_points=query_points,
            assigned_link_ids=assigned_link_ids,
        )

    if decoded_axis_directions is None:
        revolute_directions = None
        prismatic_directions = None
    else:
        revolute_directions = decoded_axis_directions[0]
        prismatic_directions = decoded_axis_directions[1]

    # Run the recovery with autocast disabled; fall back to a no-op context
    # when torch.autocast cannot be constructed for this device type.
    try:
        recovery_context = torch.autocast(
            device_type=query_points.device.type, enabled=False
        )
    except (RuntimeError, TypeError, ValueError):
        recovery_context = nullcontext()

    with recovery_context:
        return self._recover_overparam_joint_parameters(
            query_points=query_points,
            assigned_link_ids=assigned_link_ids,
            revolute_motion_points=revolute_points,
            prismatic_motion_points=prismatic_points,
            revolute_axis_directions=revolute_directions,
            prismatic_axis_directions=prismatic_directions,
            **joint_layout,
        )
def _decode_joint_motion_points(
    self,
    *,
    query_latents: Tensor,
    link_latents: Tensor,
    assigned_link_ids: Tensor,
    joint_connections: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Call the over-parameterized joint decoder and return its output.

    Raises:
        ValueError: if ``self.joint_decode_type`` is not one of the
            over-parameterized decode types.
    """
    overparam_mode = self.joint_decode_type in _OVERPARAM_JOINT_DECODE_TYPES
    if overparam_mode:
        decoder_inputs = {
            "query_latents": query_latents,
            "link_latents": link_latents,
            "assigned_link_ids": assigned_link_ids,
            "joint_connections": joint_connections,
        }
        return self.joint_decoder(**decoder_inputs)
    raise ValueError(
        "_decode_joint_motion_points is only available in over-parameterized joint decoding mode"
    )
def _decode_joint_axis_directions(
    self,
    *,
    link_latents: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    is_revolute: Tensor,
    is_prismatic: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Return the output of ``self.joint_direction_decoder.predict``.

    Raises:
        ValueError: if single-direction prediction is disabled or the
            direction decoder is ``None``.
    """
    decoder_ready = (
        self.overparam_predicts_single_axis_direction
        and self.joint_direction_decoder is not None
    )
    if not decoder_ready:
        raise ValueError(
            "_decode_joint_axis_directions is only available for joint_decode_type='overparam+singledir'"
        )
    predict_inputs = {
        "link_latents": link_latents,
        "joint_connections": joint_connections,
        "joint_valid_flag": joint_valid_flag,
        "is_revolute": is_revolute,
        "is_prismatic": is_prismatic,
    }
    return self.joint_direction_decoder.predict(**predict_inputs)
def _compute_query_link_aabb_parameters(
    self,
    *,
    query_points: Tensor,
    link_ids: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Compute, per query, the AABB center and half-extent of its link.

    For every batch element and every non-negative link ID, the bounding box
    of the query points carrying that ID is computed, and its center and
    half-extent are written to those queries. Half-extents are clamped from
    below by ``_OVERPARAM_AXIS_AABB_HALF_EXTENT_MIN``. Queries with a negative
    ID keep center 0 and half-extent 1.

    Returns:
        ``(centers, half_extents)``, both float tensors shaped like
        ``query_points``.

    Raises:
        ValueError: if ``query_points`` is not ``(B, Q, 3)`` or ``link_ids``
            does not match its first two dimensions.
    """
    points_shape = tuple(query_points.shape)
    if query_points.ndim != 3 or query_points.shape[-1] != 3:
        raise ValueError(
            f"query_points must have shape (B, Q, 3), got {points_shape}"
        )
    if link_ids.shape != query_points.shape[:2]:
        raise ValueError(
            "link_ids must match query_points batch/query dims, "
            f"got {tuple(link_ids.shape)} and {points_shape}"
        )

    points = query_points.float()
    centers = points.new_zeros(points.shape)
    half_extents = points.new_ones(points.shape)

    for sample_idx, sample_ids in enumerate(link_ids):
        assigned = sample_ids >= 0
        if not bool(assigned.any().item()):
            continue
        sample_points = points[sample_idx]
        for link_id in torch.unique(sample_ids[assigned]).tolist():
            members = sample_ids == int(link_id)
            member_points = sample_points[members]
            lower = member_points.min(dim=0).values
            upper = member_points.max(dim=0).values
            centers[sample_idx, members] = 0.5 * (lower + upper)
            half_extents[sample_idx, members] = (0.5 * (upper - lower)).clamp_min(
                _OVERPARAM_AXIS_AABB_HALF_EXTENT_MIN
            )
    return centers, half_extents
def _denormalize_overparam_axis_points(
    self,
    *,
    axis_points: Tensor,
    query_points: Tensor,
    link_ids: Tensor,
) -> Tensor:
    """Map AABB-normalized closest-axis points back to sample/world coordinates.

    When ``self.overparam_closest_axis_uses_part_aabb`` is false the input is
    returned untouched. Otherwise each point is scaled by the half-extent and
    shifted by the center of its query's link AABB.
    """
    if self.overparam_closest_axis_uses_part_aabb:
        box_centers, box_half_extents = self._compute_query_link_aabb_parameters(
            query_points=query_points,
            link_ids=link_ids,
        )
        scaled_points = axis_points.float() * box_half_extents
        return scaled_points + box_centers
    return axis_points
def _convert_overparam_motion_points_to_world_coordinates(
    self,
    *,
    motion_points: Tensor,
    query_points: Tensor,
    assigned_link_ids: Tensor,
) -> Tensor:
    """Return motion-point predictions expressed in sample/world coordinates.

    Only the first three channels (the closest-axis point) are denormalized;
    the remaining channels are passed through. When
    ``self.overparam_closest_axis_uses_part_aabb`` is false the input is
    returned untouched.
    """
    if self.overparam_closest_axis_uses_part_aabb:
        normalized_axis_points = motion_points[..., :3]
        passthrough_channels = motion_points[..., 3:]
        world_axis_points = self._denormalize_overparam_axis_points(
            axis_points=normalized_axis_points,
            query_points=query_points,
            link_ids=assigned_link_ids,
        )
        return torch.cat((world_axis_points, passthrough_channels), dim=-1)
    return motion_points
def _fit_revolute_joint_parameters(
    self,
    query_points: Tensor,
    revolute_motion_points: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Fit one revolute axis and its low/high limits from query-wise targets.

    Channels 0:3, 3:6 and 6:9 of ``revolute_motion_points`` are read as the
    closest-axis point, the low target and the high target of each query.

    Returns:
        ``(axis, limits)``: ``axis`` is what ``axis_point_to_plucker_torch``
        produces for the fitted line, ``limits`` stacks the low and high
        limit. With no query points, zero tensors of 6 and 2 elements are
        returned.
    """
    # The solver is small and numerically sensitive, so it always runs in
    # fp32 regardless of the surrounding AMP dtype.
    points = query_points.float()
    targets = revolute_motion_points.float()
    if points.numel() == 0:
        return targets.new_zeros(6), targets.new_zeros(2)

    axis_anchor = targets[..., :3]
    low_target = targets[..., 3:6]
    high_target = targets[..., 6:9]

    # Summed cross products of the anchor-relative vectors, averaged over
    # queries, are handed to the line fit as its direction hint.
    to_query = points - axis_anchor
    to_low = low_target - axis_anchor
    to_high = high_target - axis_anchor
    orientation_hint = (
        torch.linalg.cross(to_query, to_low, dim=-1)
        + torch.linalg.cross(to_query, to_high, dim=-1)
        + torch.linalg.cross(to_low, to_high, dim=-1)
    ).mean(dim=0)

    direction, origin = fit_axis_to_closest_points_torch(
        points,
        axis_anchor,
        direction_hint=orientation_hint,
    )
    fitted_axis = axis_point_to_plucker_torch(direction, origin)
    limits = tuple(
        estimate_revolute_limit_torch(points, target, direction, origin)
        for target in (low_target, high_target)
    )
    return fitted_axis, torch.stack(limits)
def _aggregate_flip_invariant_axis_direction(
    self,
    predicted_directions: Tensor,
    *,
    sign_hint: Tensor | None = None,
) -> Tensor:
    """Average per-query directions into one unit vector, ignoring sign flips.

    Rows with norm at most ``1e-8`` are dropped. The remaining rows are
    normalized, sign-aligned to the top eigenvector of their 3x3 second-moment
    matrix and averaged. If that average is degenerate the eigenvector itself
    is used. A non-degenerate ``sign_hint`` decides the final sign.

    Returns:
        A 3-element float tensor; all zeros when no usable direction exists.
    """
    directions = predicted_directions.float()
    if directions.numel() == 0:
        return directions.new_zeros(3)

    norms = torch.linalg.vector_norm(directions, dim=-1)
    usable = norms > 1e-8
    if not bool(usable.any().item()):
        return directions.new_zeros(3)

    unit = directions[usable] / norms[usable].unsqueeze(-1)
    second_moment = unit.transpose(0, 1) @ unit
    # eigh orders eigenvalues ascending, so the last column is the dominant axis.
    anchor = torch.linalg.eigh(second_moment)[1][:, -1]

    signs = torch.sign(unit @ anchor)
    signs = signs.masked_fill(signs == 0, 1.0)
    mean_direction = (unit * signs.unsqueeze(-1)).mean(dim=0)

    mean_is_degenerate = (
        float(torch.linalg.vector_norm(mean_direction).item()) <= 1e-8
    )
    chosen = anchor if mean_is_degenerate else mean_direction
    axis = F.normalize(chosen, dim=0, eps=1e-8)

    if sign_hint is None:
        return axis
    hint = sign_hint.float()
    hint_is_usable = float(torch.linalg.vector_norm(hint).item()) > 1e-8
    if hint_is_usable and float(torch.dot(axis, hint).item()) < 0.0:
        axis = -axis
    return axis
def _fit_revolute_joint_parameters_with_direction(
    self,
    query_points: Tensor,
    revolute_motion_points: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Fit one revolute axis from per-query predicted directions plus point targets.

    Channels 0:3 / 3:6 / 6:9 / 9:12 of ``revolute_motion_points`` are read as
    closest-axis point, low target, high target and predicted axis direction.
    The direction is the flip-invariant aggregate of the per-query directions;
    the axis point is the per-coordinate 0.5 quantile of the closest-axis
    points. If the aggregate is degenerate the direction-free fit is used on
    the first nine channels. The axis is negated when that is needed to make
    ``low <= high``.

    Returns:
        ``(axis, limits)``; zero tensors of 6 and 2 elements for empty input.
    """
    points = query_points.float()
    targets = revolute_motion_points.float()
    if points.numel() == 0:
        return targets.new_zeros(6), targets.new_zeros(2)

    axis_anchor = targets[..., :3]
    low_target = targets[..., 3:6]
    high_target = targets[..., 6:9]
    per_query_directions = targets[..., 9:12]

    to_query = points - axis_anchor
    to_low = low_target - axis_anchor
    to_high = high_target - axis_anchor
    orientation_hint = (
        torch.linalg.cross(to_query, to_low, dim=-1)
        + torch.linalg.cross(to_query, to_high, dim=-1)
        + torch.linalg.cross(to_low, to_high, dim=-1)
    ).mean(dim=0)

    direction = self._aggregate_flip_invariant_axis_direction(
        per_query_directions,
        sign_hint=orientation_hint,
    )
    if float(torch.linalg.vector_norm(direction).item()) <= 1e-8:
        return self._fit_revolute_joint_parameters(
            query_points=points,
            revolute_motion_points=targets[..., :9],
        )

    origin = torch.quantile(axis_anchor, 0.5, dim=0)

    def measure_limits(candidate: Tensor) -> Tuple[Tensor, Tensor]:
        low = estimate_revolute_limit_torch(points, low_target, candidate, origin)
        high = estimate_revolute_limit_torch(points, high_target, candidate, origin)
        return low, high

    low_limit, high_limit = measure_limits(direction)
    if float(low_limit.item()) > float(high_limit.item()):
        # Limits came out reversed: flip the axis and measure again.
        direction = -direction
        low_limit, high_limit = measure_limits(direction)

    fitted_axis = axis_point_to_plucker_torch(direction, origin)
    return fitted_axis, torch.stack((low_limit, high_limit))
def _fit_prismatic_joint_parameters(
    self,
    query_points: Tensor,
    prismatic_motion_points: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Fit one prismatic axis and its low/high displacements from query-wise targets.

    Channels 0:3, 3:6 and 6:9 of ``prismatic_motion_points`` are read as the
    closest-axis point, the low target and the high target of each query. The
    mean ``high - low`` vector is passed to the line fit as its direction hint.

    Returns:
        ``(axis, limits)``; zero tensors of 6 and 2 elements for empty input.
    """
    # The solver is small and numerically sensitive, so it always runs in
    # fp32 regardless of the surrounding AMP dtype.
    points = query_points.float()
    targets = prismatic_motion_points.float()
    if points.numel() == 0:
        return targets.new_zeros(6), targets.new_zeros(2)

    axis_anchor = targets[..., :3]
    low_target = targets[..., 3:6]
    high_target = targets[..., 6:9]
    travel_hint = (high_target - low_target).mean(dim=0)

    direction, origin = fit_axis_to_closest_points_torch(
        points,
        axis_anchor,
        direction_hint=travel_hint,
    )
    fitted_axis = axis_point_to_plucker_torch(direction, origin)
    limits = tuple(
        estimate_prismatic_limit_torch(points, target, direction)
        for target in (low_target, high_target)
    )
    return fitted_axis, torch.stack(limits)
def _fit_prismatic_joint_parameters_with_direction(
    self,
    query_points: Tensor,
    prismatic_motion_points: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Fit one prismatic axis from per-query predicted directions plus point targets.

    Channels 0:3 / 3:6 / 6:9 / 9:12 of ``prismatic_motion_points`` are read as
    closest-axis point, low target, high target and predicted axis direction.
    The direction is the flip-invariant aggregate of the per-query directions,
    signed by the mean ``high - low`` vector; the axis point is the
    per-coordinate 0.5 quantile of the closest-axis points. If the aggregate
    is degenerate the direction-free fit is used on the first nine channels.
    The axis is negated when that is needed to make ``low <= high``.

    Returns:
        ``(axis, limits)``; zero tensors of 6 and 2 elements for empty input.
    """
    points = query_points.float()
    targets = prismatic_motion_points.float()
    if points.numel() == 0:
        return targets.new_zeros(6), targets.new_zeros(2)

    axis_anchor = targets[..., :3]
    low_target = targets[..., 3:6]
    high_target = targets[..., 6:9]
    per_query_directions = targets[..., 9:12]

    direction = self._aggregate_flip_invariant_axis_direction(
        per_query_directions,
        sign_hint=(high_target - low_target).mean(dim=0),
    )
    if float(torch.linalg.vector_norm(direction).item()) <= 1e-8:
        return self._fit_prismatic_joint_parameters(
            query_points=points,
            prismatic_motion_points=targets[..., :9],
        )

    origin = torch.quantile(axis_anchor, 0.5, dim=0)

    def measure_limits(candidate: Tensor) -> Tuple[Tensor, Tensor]:
        low = estimate_prismatic_limit_torch(points, low_target, candidate)
        high = estimate_prismatic_limit_torch(points, high_target, candidate)
        return low, high

    low_limit, high_limit = measure_limits(direction)
    if float(low_limit.item()) > float(high_limit.item()):
        # Limits came out reversed: flip the axis and measure again.
        direction = -direction
        low_limit, high_limit = measure_limits(direction)

    fitted_axis = axis_point_to_plucker_torch(direction, origin)
    return fitted_axis, torch.stack((low_limit, high_limit))
def _fit_revolute_joint_parameters_with_single_direction(
    self,
    query_points: Tensor,
    revolute_motion_points: Tensor,
    axis_direction: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Fit one revolute axis from a single predicted direction plus point targets.

    ``axis_direction`` is normalized and used as the axis direction; the axis
    point is the per-coordinate 0.5 quantile of channels 0:3 of
    ``revolute_motion_points``. A direction with norm at most ``1e-8`` falls
    back to the direction-free fit. The axis is negated when that is needed
    to make ``low <= high``.

    Returns:
        ``(axis, limits)``; zero tensors of 6 and 2 elements for empty input.
    """
    points = query_points.float()
    targets = revolute_motion_points.float()
    direction = axis_direction.float()
    if points.numel() == 0:
        return targets.new_zeros(6), targets.new_zeros(2)
    if float(torch.linalg.vector_norm(direction).item()) <= 1e-8:
        return self._fit_revolute_joint_parameters(
            query_points=points,
            revolute_motion_points=targets,
        )

    direction = F.normalize(direction, dim=0, eps=1e-8)
    axis_anchor = targets[..., :3]
    low_target = targets[..., 3:6]
    high_target = targets[..., 6:9]
    origin = torch.quantile(axis_anchor, 0.5, dim=0)

    def measure_limits(candidate: Tensor) -> Tuple[Tensor, Tensor]:
        low = estimate_revolute_limit_torch(points, low_target, candidate, origin)
        high = estimate_revolute_limit_torch(points, high_target, candidate, origin)
        return low, high

    low_limit, high_limit = measure_limits(direction)
    if float(low_limit.item()) > float(high_limit.item()):
        # Limits came out reversed: flip the axis and measure again.
        direction = -direction
        low_limit, high_limit = measure_limits(direction)

    fitted_axis = axis_point_to_plucker_torch(direction, origin)
    return fitted_axis, torch.stack((low_limit, high_limit))
def _fit_prismatic_joint_parameters_with_single_direction(
    self,
    query_points: Tensor,
    prismatic_motion_points: Tensor,
    axis_direction: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Fit one prismatic axis from a single predicted direction plus point targets.

    ``axis_direction`` is normalized and used as the axis direction; the axis
    point is the per-coordinate 0.5 quantile of channels 0:3 of
    ``prismatic_motion_points``. A direction with norm at most ``1e-8`` falls
    back to the direction-free fit. The axis is negated when that is needed
    to make ``low <= high``.

    Returns:
        ``(axis, limits)``; zero tensors of 6 and 2 elements for empty input.
    """
    points = query_points.float()
    targets = prismatic_motion_points.float()
    direction = axis_direction.float()
    if points.numel() == 0:
        return targets.new_zeros(6), targets.new_zeros(2)
    if float(torch.linalg.vector_norm(direction).item()) <= 1e-8:
        return self._fit_prismatic_joint_parameters(
            query_points=points,
            prismatic_motion_points=targets,
        )

    direction = F.normalize(direction, dim=0, eps=1e-8)
    axis_anchor = targets[..., :3]
    low_target = targets[..., 3:6]
    high_target = targets[..., 6:9]
    origin = torch.quantile(axis_anchor, 0.5, dim=0)

    def measure_limits(candidate: Tensor) -> Tuple[Tensor, Tensor]:
        low = estimate_prismatic_limit_torch(points, low_target, candidate)
        high = estimate_prismatic_limit_torch(points, high_target, candidate)
        return low, high

    low_limit, high_limit = measure_limits(direction)
    if float(low_limit.item()) > float(high_limit.item()):
        # Limits came out reversed: flip the axis and measure again.
        direction = -direction
        low_limit, high_limit = measure_limits(direction)

    fitted_axis = axis_point_to_plucker_torch(direction, origin)
    return fitted_axis, torch.stack((low_limit, high_limit))
def _recover_overparam_joint_parameters(
    self,
    *,
    query_points: Tensor,
    assigned_link_ids: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    is_revolute: Tensor,
    is_prismatic: Tensor,
    revolute_motion_points: Tensor,
    prismatic_motion_points: Tensor,
    revolute_axis_directions: Tensor | None = None,
    prismatic_axis_directions: Tensor | None = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Recover per-joint parameters by fitting each child link's query-wise targets.

    For every joint slot, the queries assigned to the joint's child link
    (``joint_connections[..., 1]``) are selected and passed to the fitting
    routine for the active decode mode.

    Only slots whose result survives the final mask are fitted: the revolute
    solver runs where ``joint_valid_flag & is_revolute`` holds and the
    prismatic solver where ``joint_valid_flag & is_prismatic`` holds. All
    other slots stay zero, which is exactly what masking would have produced.
    This avoids running the solvers on padding or wrong-type targets.

    Returns:
        ``(revolute_axis, prismatic_axis, revolute_range, prismatic_range)``
        with shapes ``(B, J, 6)``, ``(B, J, 6)``, ``(B, J, 2)``, ``(B, J, 2)``
        in float32, zero wherever the corresponding mask is false.

    Raises:
        ValueError: in single-direction mode when there is at least one joint
            slot and either per-joint direction tensor is missing.
    """
    query_points = query_points.float()
    revolute_motion_points = revolute_motion_points.float()
    prismatic_motion_points = prismatic_motion_points.float()

    batch_size, max_joints = joint_connections.shape[:2]
    revolute_axis = query_points.new_zeros((batch_size, max_joints, 6))
    prismatic_axis = query_points.new_zeros((batch_size, max_joints, 6))
    revolute_range = query_points.new_zeros((batch_size, max_joints, 2))
    prismatic_range = query_points.new_zeros((batch_size, max_joints, 2))

    uses_direction = bool(self.overparam_uses_axis_direction)
    per_query_direction = uses_direction and self.joint_decode_type == "overparam+dir"
    single_direction = uses_direction and not per_query_direction
    if (
        single_direction
        and batch_size * max_joints > 0
        and (revolute_axis_directions is None or prismatic_axis_directions is None)
    ):
        raise ValueError(
            "overparam+singledir recovery requires per-joint axis directions"
        )

    revolute_mask = (joint_valid_flag & is_revolute).unsqueeze(-1)
    prismatic_mask = (joint_valid_flag & is_prismatic).unsqueeze(-1)
    # One host transfer per mask instead of a device sync per joint slot.
    needs_revolute = revolute_mask.squeeze(-1).expand(batch_size, max_joints).tolist()
    needs_prismatic = prismatic_mask.squeeze(-1).expand(batch_size, max_joints).tolist()

    child_link_ids = joint_connections[..., 1]
    for batch_idx in range(batch_size):
        for joint_idx in range(max_joints):
            fit_revolute = needs_revolute[batch_idx][joint_idx]
            fit_prismatic = needs_prismatic[batch_idx][joint_idx]
            if not (fit_revolute or fit_prismatic):
                # Padding / invalid slot: the output stays zero.
                continue

            query_mask = assigned_link_ids[batch_idx] == child_link_ids[batch_idx, joint_idx]
            joint_query_points = query_points[batch_idx][query_mask]

            if fit_revolute:
                joint_targets = revolute_motion_points[batch_idx][query_mask]
                if per_query_direction:
                    fitted = self._fit_revolute_joint_parameters_with_direction(
                        joint_query_points,
                        joint_targets,
                    )
                elif single_direction:
                    fitted = self._fit_revolute_joint_parameters_with_single_direction(
                        joint_query_points,
                        joint_targets,
                        revolute_axis_directions[batch_idx, joint_idx],
                    )
                else:
                    fitted = self._fit_revolute_joint_parameters(
                        joint_query_points,
                        joint_targets,
                    )
                (
                    revolute_axis[batch_idx, joint_idx],
                    revolute_range[batch_idx, joint_idx],
                ) = fitted

            if fit_prismatic:
                joint_targets = prismatic_motion_points[batch_idx][query_mask]
                if per_query_direction:
                    fitted = self._fit_prismatic_joint_parameters_with_direction(
                        joint_query_points,
                        joint_targets,
                    )
                elif single_direction:
                    fitted = self._fit_prismatic_joint_parameters_with_single_direction(
                        joint_query_points,
                        joint_targets,
                        prismatic_axis_directions[batch_idx, joint_idx],
                    )
                else:
                    fitted = self._fit_prismatic_joint_parameters(
                        joint_query_points,
                        joint_targets,
                    )
                (
                    prismatic_axis[batch_idx, joint_idx],
                    prismatic_range[batch_idx, joint_idx],
                ) = fitted

    # Skipped slots are already zero; the masking is kept as a cheap guard.
    return (
        revolute_axis.masked_fill(~revolute_mask, 0),
        prismatic_axis.masked_fill(~prismatic_mask, 0),
        revolute_range.masked_fill(~revolute_mask[..., :1], 0),
        prismatic_range.masked_fill(~prismatic_mask[..., :1], 0),
    )
def _resolve_joint_decoding_link_ids(
    self,
    *,
    segmentation_logits: Tensor,
    link_ids: Tensor | None,
) -> Tensor:
    """Pick the per-query link IDs used for joint decoding.

    Caller-provided ``link_ids`` are returned as-is. Without them, the argmax
    of ``segmentation_logits`` over the last dimension is used, which is only
    allowed outside training.

    Raises:
        ValueError: if ``link_ids`` is ``None`` while ``self.training`` is set.
    """
    if link_ids is None:
        if self.training:
            raise ValueError(
                "forward requires link_ids when joint_decode_type uses over-parameterized decoding during training"
            )
        return segmentation_logits.argmax(dim=-1)
    return link_ids
def _build_parent_index(
    self,
    *,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    num_links: int,
) -> Tensor:
    """Return one parent-link index per link, using ``-1`` for roots/padding.

    ``joint_connections[..., 0]`` is read as the parent and
    ``joint_connections[..., 1]`` as the child of each joint. A joint
    contributes only if it is flagged valid and both indices lie in
    ``[0, num_links)``.

    Joints that do not contribute are scattered into an extra scratch column
    that is dropped before returning. They therefore can neither index out of
    range (child index ``>= num_links``) nor overwrite the entry a valid joint
    wrote for link 0 (negative or zero-padded child index).

    Returns:
        A ``(B, num_links)`` tensor with the dtype and device of
        ``joint_connections``.

    Note:
        If several valid joints name the same child link, which parent is
        kept is unspecified (scatter with duplicate indices).
    """
    batch_size = joint_connections.shape[0]
    device = joint_connections.device
    index_dtype = joint_connections.dtype
    if joint_connections.shape[1] == 0 or num_links == 0:
        return torch.full(
            (batch_size, num_links),
            fill_value=-1,
            dtype=index_dtype,
            device=device,
        )

    parent_indices = joint_connections[..., 0]
    child_indices = joint_connections[..., 1]
    valid_joint_mask = (
        joint_valid_flag
        & (parent_indices >= 0)
        & (parent_indices < num_links)
        & (child_indices >= 0)
        & (child_indices < num_links)
    )

    # Column ``num_links`` is a scratch slot that absorbs every joint that
    # must not contribute; scatter_ needs an int64 index.
    scatter_index = torch.where(
        valid_joint_mask,
        child_indices,
        torch.full_like(child_indices, num_links),
    ).long()
    padded_parent_index = torch.full(
        (batch_size, num_links + 1),
        fill_value=-1,
        dtype=index_dtype,
        device=device,
    )
    padded_parent_index.scatter_(dim=1, index=scatter_index, src=parent_indices)
    return padded_parent_index[:, :num_links].contiguous()
def _build_depth_decayed_ancestor_context(
    self,
    *,
    link_latents: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
) -> Tensor:
    """Return a depth-decayed ancestor average for each link latent.

    With ``P`` the one-hot child-to-parent matrix, the ancestor weights are
    the solution of ``(I - decay * P) W = P``. For an acyclic parent graph
    this is ``P + decay * P^2 + decay^2 * P^3 + ...``. Each row is divided by
    its sum (clamped to at least 1) and rows without any ancestor are zero.

    Shortcuts: the input is returned unchanged when ancestor context is
    disabled; with a decay of exactly 0 the direct parent's latent is gathered
    (zero for links without a parent).
    """
    if not self.use_ancestor_context_for_segmentation:
        return link_latents

    batch_size, num_links, _ = link_latents.shape
    if num_links == 0:
        return torch.zeros_like(link_latents)

    parent_of = self._build_parent_index(
        joint_connections=joint_connections,
        joint_valid_flag=joint_valid_flag,
        num_links=num_links,
    )
    has_parent = parent_of >= 0
    safe_parent = parent_of.clamp_min(0)

    if self.ancestor_context_decay == 0.0:
        gather_index = safe_parent.unsqueeze(-1).expand_as(link_latents)
        direct_parent_latents = link_latents.gather(dim=1, index=gather_index)
        keep = has_parent.unsqueeze(-1).to(dtype=link_latents.dtype)
        return direct_parent_latents * keep

    if not bool(has_parent.any().item()):
        return torch.zeros_like(link_latents)

    device = link_latents.device
    # Double precision on CPU, single precision elsewhere.
    solve_dtype = torch.float64 if device.type == "cpu" else torch.float32

    adjacency = torch.zeros(
        (batch_size, num_links, num_links),
        device=device,
        dtype=solve_dtype,
    )
    # Row i gets a 1 in the column of link i's parent (0 is written for roots).
    adjacency.scatter_(
        dim=2,
        index=safe_parent.unsqueeze(-1),
        src=has_parent.unsqueeze(-1).to(dtype=solve_dtype),
    )

    identity = torch.eye(num_links, device=device, dtype=solve_dtype).unsqueeze(0)
    system = identity - self.ancestor_context_decay * adjacency
    ancestor_weights = torch.linalg.solve(system, adjacency)

    row_totals = ancestor_weights.sum(dim=-1, keepdim=True)
    averaged_weights = ancestor_weights / row_totals.clamp_min(1.0)
    averaged_weights = averaged_weights.masked_fill(row_totals <= 0.0, 0.0)

    return torch.matmul(averaged_weights.to(dtype=link_latents.dtype), link_latents)
def build_segmentation_link_latents(
    self,
    *,
    link_latents: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
) -> Tensor:
    """Optionally refine link latents with gated ancestor context (segmentation only).

    When ancestor context is disabled the latents are returned unchanged.
    Otherwise the depth-decayed ancestor context is projected, a sigmoid gate
    is computed from the concatenation of latents and projected context, and
    ``link_latents + gate * projected_context`` is returned.

    Raises:
        RuntimeError: if ancestor context is enabled but the projector or the
            gate module is ``None``.
    """
    if not self.use_ancestor_context_for_segmentation:
        return link_latents

    if self.ancestor_context_gate is None or self.ancestor_context_projector is None:
        raise RuntimeError(
            "ancestor-context projector and gate must exist when "
            "use_ancestor_context_for_segmentation=True"
        )

    raw_context = self._build_depth_decayed_ancestor_context(
        link_latents=link_latents,
        joint_connections=joint_connections,
        joint_valid_flag=joint_valid_flag,
    )
    context = self.ancestor_context_projector(raw_context)
    gate_input = torch.cat((link_latents, context), dim=-1)
    gate = torch.sigmoid(self.ancestor_context_gate(gate_input))
    return link_latents + gate * context
def decode(
    self,
    *,
    query_latents: Tensor,
    query_points: Tensor,
    link_latents: Tensor,
    link_valid_flag: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    is_revolute: Tensor,
    is_prismatic: Tensor,
    link_ids: Tensor | None = None,
) -> Dict[str, Any]:
    """Decode segmentation and joint predictions, keeping the caller's joint layout.

    Latents of invalid links are zeroed first. Segmentation logits are always
    produced. Joint outputs depend on the decode mode:

    * plain modes: axes and ranges are decoded only outside training; all
      motion-point entries stay ``None``.
    * over-parameterized modes: query-wise motion points are decoded,
      converted to world coordinates and fitted into per-joint axes/ranges.
      The dict also carries the world-space motion points, the raw decoder
      closest-axis points and the link IDs used for the assignment.

    Returns:
        A dict with a fixed key set; entries that do not apply to the active
        mode are ``None``. The joint layout tensors and ``query_points`` are
        echoed back unchanged.
    """
    link_latents = link_latents.masked_fill(~link_valid_flag.unsqueeze(-1), 0)
    joint_layout = {
        "joint_connections": joint_connections,
        "joint_valid_flag": joint_valid_flag,
        "is_revolute": is_revolute,
        "is_prismatic": is_prismatic,
    }

    segmentation_logits = self.decode_segmentation(
        query_latents=query_latents,
        link_latents=self.build_segmentation_link_latents(
            link_latents=link_latents,
            joint_connections=joint_connections,
            joint_valid_flag=joint_valid_flag,
        ),
        link_valid_flag=link_valid_flag,
    )

    # Seed every key up front so the key order is fixed regardless of mode.
    outputs: Dict[str, Any] = {
        "segmentation_logits": segmentation_logits,
        "revolute_axis": None,
        "prismatic_axis": None,
        "revolute_range": None,
        "prismatic_range": None,
        "revolute_closest_axis_points": None,
        "revolute_closest_axis_points_decoder": None,
        "revolute_low_points": None,
        "revolute_high_points": None,
        "revolute_axis_directions": None,
        "prismatic_closest_axis_points": None,
        "prismatic_closest_axis_points_decoder": None,
        "prismatic_low_points": None,
        "prismatic_high_points": None,
        "prismatic_axis_directions": None,
        "joint_decoding_link_ids": None,
        **joint_layout,
        "query_points": query_points,
    }

    if self.joint_decode_type in _PLAIN_JOINT_DECODE_TYPES:
        if not self.training:
            (
                outputs["revolute_axis"],
                outputs["prismatic_axis"],
                outputs["revolute_range"],
                outputs["prismatic_range"],
            ) = self.decode_joint_parameters(link_latents=link_latents, **joint_layout)
        return outputs

    assigned_link_ids = self._resolve_joint_decoding_link_ids(
        segmentation_logits=segmentation_logits,
        link_ids=link_ids,
    )
    raw_revolute_points, raw_prismatic_points = self._decode_joint_motion_points(
        query_latents=query_latents,
        link_latents=link_latents,
        assigned_link_ids=assigned_link_ids,
        joint_connections=joint_connections,
    )

    revolute_joint_directions = None
    prismatic_joint_directions = None
    if self.overparam_predicts_single_axis_direction:
        (
            revolute_joint_directions,
            prismatic_joint_directions,
        ) = self._decode_joint_axis_directions(link_latents=link_latents, **joint_layout)

    world_revolute_points = self._convert_overparam_motion_points_to_world_coordinates(
        motion_points=raw_revolute_points,
        query_points=query_points,
        assigned_link_ids=assigned_link_ids,
    )
    world_prismatic_points = self._convert_overparam_motion_points_to_world_coordinates(
        motion_points=raw_prismatic_points,
        query_points=query_points,
        assigned_link_ids=assigned_link_ids,
    )

    if revolute_joint_directions is None or prismatic_joint_directions is None:
        joint_directions = None
    else:
        joint_directions = (revolute_joint_directions, prismatic_joint_directions)

    (
        revolute_axis,
        prismatic_axis,
        revolute_range,
        prismatic_range,
    ) = self.decode_joint_parameters(
        link_latents=link_latents,
        query_points=query_points,
        assigned_link_ids=assigned_link_ids,
        decoded_motion_points=(world_revolute_points, world_prismatic_points),
        decoded_axis_directions=joint_directions,
        decoded_motion_points_are_world=True,
        **joint_layout,
    )

    # Per-query directions live in channels 9:12; single-direction mode
    # reports the per-joint directions instead.
    if self.overparam_predicts_query_axis_direction:
        revolute_directions_out = world_revolute_points[..., 9:12]
        prismatic_directions_out = world_prismatic_points[..., 9:12]
    elif self.overparam_predicts_single_axis_direction:
        revolute_directions_out = revolute_joint_directions
        prismatic_directions_out = prismatic_joint_directions
    else:
        revolute_directions_out = None
        prismatic_directions_out = None

    outputs.update(
        {
            "revolute_axis": revolute_axis,
            "prismatic_axis": prismatic_axis,
            "revolute_range": revolute_range,
            "prismatic_range": prismatic_range,
            "revolute_closest_axis_points": world_revolute_points[..., :3],
            "revolute_closest_axis_points_decoder": raw_revolute_points[..., :3],
            "revolute_low_points": world_revolute_points[..., 3:6],
            "revolute_high_points": world_revolute_points[..., 6:9],
            "revolute_axis_directions": revolute_directions_out,
            "prismatic_closest_axis_points": world_prismatic_points[..., :3],
            "prismatic_closest_axis_points_decoder": raw_prismatic_points[..., :3],
            "prismatic_low_points": world_prismatic_points[..., 3:6],
            "prismatic_high_points": world_prismatic_points[..., 6:9],
            "prismatic_axis_directions": prismatic_directions_out,
            "joint_decoding_link_ids": assigned_link_ids,
        }
    )
    return outputs
def forward(
    self,
    shape_points: Tensor,
    shape_point_normals: Tensor,
    query_points: Tensor,
    query_point_normals: Tensor,
    link_point_prompts: Tensor | None,
    link_point_prompt_normals: Tensor | None,
    link_valid_flag: Tensor,
    joint_connections: Tensor,
    joint_valid_flag: Tensor,
    is_revolute: Tensor,
    is_prismatic: Tensor,
    link_point_prompt_dropout_eligible: Tensor | None = None,
    link_text_prompts: Sequence[Sequence[str]] | None = None,
    link_text_embeddings: Tensor | None = None,
    link_ids: Tensor | None = None,
) -> Dict[str, Any]:
    """Encode the inputs, decode them, and return both result dicts merged.

    The encoder output must provide ``"query_latents"`` and
    ``"link_latents"``. The padded joint tensors are passed to ``decode``
    exactly as supplied. On key collisions the decoder's value wins.
    """
    encoder_inputs = {
        "shape_points": shape_points,
        "shape_point_normals": shape_point_normals,
        "query_points": query_points,
        "query_point_normals": query_point_normals,
        "link_point_prompts": link_point_prompts,
        "link_point_prompt_normals": link_point_prompt_normals,
        "link_valid_flag": link_valid_flag,
        "link_point_prompt_dropout_eligible": link_point_prompt_dropout_eligible,
        "link_text_prompts": link_text_prompts,
        "link_text_embeddings": link_text_embeddings,
    }
    encoded = self.encoder(**encoder_inputs)

    decoded = self.decode(
        query_latents=encoded["query_latents"],
        query_points=query_points,
        link_latents=encoded["link_latents"],
        link_valid_flag=link_valid_flag,
        joint_connections=joint_connections,
        joint_valid_flag=joint_valid_flag,
        is_revolute=is_revolute,
        is_prismatic=is_prismatic,
        link_ids=link_ids,
    )
    return {**encoded, **decoded}
# Names exported by a star-import of this module (its declared public API).
__all__ = [
"Particulate2ArticulationModel",
"Particulate2Encoder",
"JointDecoderPlainFlowMatching",
"JointDecoderOverParametrized",
"JointDecoderSingleDirection",
"JointDecoderPlain",
"SegmentationDecoder",
]