lsnu's picture
2026-03-25 runpod handoff update
e7d8e79 verified
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import Tensor, nn
@dataclass
class ChunkDecoderConfig:
    """Hyper-parameters shared by the chunk decoders defined in this module."""

    # Transformer width: passed as ``d_model`` to every decoder layer and used
    # as the size of all query / identity / proposal embeddings.
    hidden_dim: int = 512
    # Number of attention heads (``nhead``) in each decoder layer.
    num_heads: int = 8
    # Number of stacked layers in each ``nn.TransformerDecoder``.
    num_layers: int = 4
    # Feed-forward width (``dim_feedforward``) inside each decoder layer.
    ff_dim: int = 2048
    # Dropout probability passed to each decoder layer.
    dropout: float = 0.1
    # Number of learned query embeddings, i.e. decoded steps per action chunk.
    chunk_size: int = 8
    # Width of the concatenated two-arm action vector.
    action_dim: int = 14
    # Width of a single arm's action head; the asymmetric decoders give the
    # second arm ``action_dim - arm_action_dim`` outputs.
    arm_action_dim: int = 7
    # Default number of candidates / proposal slots produced per sample.
    num_candidates: int = 8
    # Input width of the ``phase_adapter`` linear layers (size of phase logits).
    num_phases: int = 5
    # Input width of the ``role_adapter`` linear layers (size of arm-role logits).
    num_arm_roles: int = 4
    # Number of proposal modes scored by SymmetricCoordinatedChunkDecoder.
    num_proposal_modes: int = 6
    # Not read anywhere in this module; presumably consumed by a planner
    # elsewhere -- TODO confirm.
    planner_top_k: int = 4
class ACTBimanualChunkDecoder(nn.Module):
    """Two-stage bimanual chunk decoder.

    A "revealer" transformer decoder is run first over the memory tokens; an
    "actor" transformer decoder is then run over the same memory extended with
    the revealer's output tokens. A small coordination MLP mixes both streams
    before per-arm Gaussian heads (mean / log-std) produce the action chunk.
    """

    def __init__(self, config: ChunkDecoderConfig) -> None:
        """Build both decoders, the shared step queries and the output heads."""
        super().__init__()
        self.config = config
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.revealer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers)
        actor_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.actor_decoder = nn.TransformerDecoder(actor_layer, num_layers=config.num_layers)
        # One learned query per step of the chunk, shared by both decoders.
        self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim)
        # Learned offset (initialised to zero) that distinguishes actor queries
        # from revealer queries.
        self.actor_role_bias = nn.Parameter(torch.zeros(1, config.chunk_size, config.hidden_dim))
        self.revealer_mean = nn.Linear(config.hidden_dim, config.arm_action_dim)
        self.revealer_log_std = nn.Linear(config.hidden_dim, config.arm_action_dim)
        # The actor head receives whatever action width the revealer did not take.
        self.actor_mean = nn.Linear(config.hidden_dim, config.action_dim - config.arm_action_dim)
        self.actor_log_std = nn.Linear(config.hidden_dim, config.action_dim - config.arm_action_dim)
        self.coordination = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        self.proposal_score = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, 1),
        )

    def _deterministic_candidate_noise(
        self,
        action_mean: Tensor,
        num_candidates: int,
    ) -> Tensor:
        """Return a fixed, input-independent perturbation pattern.

        The result has shape ``(batch, num_candidates, chunk, action_dim)``.
        Candidate 0 is all zeros; every other candidate is a sum of a sine and
        a cosine of its (candidate, step, dim) indices, rescaled to unit RMS
        over the (step, dim) axes. Only the shape, device and dtype of
        ``action_mean`` are used.
        """
        batch_size, chunk_size, action_dim = action_mean.shape
        noise = torch.zeros(
            batch_size,
            num_candidates,
            chunk_size,
            action_dim,
            device=action_mean.device,
            dtype=action_mean.dtype,
        )
        if num_candidates <= 1:
            return noise
        candidate_index = torch.arange(1, num_candidates, device=action_mean.device, dtype=action_mean.dtype).view(
            num_candidates - 1, 1, 1
        )
        step_index = torch.arange(chunk_size, device=action_mean.device, dtype=action_mean.dtype).view(1, chunk_size, 1)
        dim_index = torch.arange(action_dim, device=action_mean.device, dtype=action_mean.dtype).view(1, 1, action_dim)
        base = torch.sin(candidate_index * 0.73 + step_index * 0.37 + dim_index * 0.19)
        base = base + torch.cos(candidate_index * 1.11 + step_index * 0.17 + dim_index * 0.41)
        # Normalise each candidate's pattern to unit RMS so the caller's std
        # alone controls the perturbation scale.
        base = base / base.square().mean(dim=(1, 2), keepdim=True).sqrt().clamp_min(1e-6)
        noise[:, 1:] = base.unsqueeze(0).expand(batch_size, -1, -1, -1)
        return noise

    def forward(
        self,
        scene_tokens: Tensor,
        reveal_tokens: Tensor | None = None,
        memory_token: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode one action chunk.

        Args:
            scene_tokens: Memory tokens, batch-first; the leading dimension is
                the batch size and the last must equal ``hidden_dim``.
            reveal_tokens: Optional extra tokens appended to the memory along
                the token axis. Their mean is used as coordination context;
                when absent or empty the scene-token mean is used instead.
            memory_token: Optional tokens appended last to the memory along
                the token axis (must therefore be 3-D like ``scene_tokens``).

        Returns:
            Dict with ``decoded_tokens``, ``revealer_tokens``,
            ``actor_tokens``, ``coordination_tokens``, ``action_mean``,
            ``action_log_std`` (clamped to ``[-5, 2]``) and ``proposal_score``
            (one scalar per batch element).
        """
        batch_size = scene_tokens.shape[0]
        query = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        decoder_memory = scene_tokens
        if reveal_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, reveal_tokens], dim=1)
        if memory_token is not None:
            decoder_memory = torch.cat([decoder_memory, memory_token], dim=1)
        revealer_tokens = self.revealer_decoder(query, decoder_memory)
        actor_query = query + self.actor_role_bias
        # The actor additionally attends to what the revealer decoded.
        actor_tokens = self.actor_decoder(actor_query, torch.cat([decoder_memory, revealer_tokens], dim=1))
        # Averaging over an empty token axis yields NaN, which would poison
        # every output below; treat empty reveal tokens like missing ones
        # (same guard as InteractionChunkDecoder.forward).
        if reveal_tokens is not None and reveal_tokens.numel() > 0:
            context_source = reveal_tokens
        else:
            context_source = scene_tokens
        reveal_context = context_source.mean(dim=1, keepdim=True).expand(-1, self.config.chunk_size, -1)
        coordination_input = torch.cat([revealer_tokens, actor_tokens, reveal_context], dim=-1)
        coordination = torch.tanh(self.coordination(coordination_input))
        # The same bounded coordination signal is added to both arms' tokens.
        revealer_tokens = revealer_tokens + coordination
        actor_tokens = actor_tokens + coordination
        action_mean = torch.cat([self.revealer_mean(revealer_tokens), self.actor_mean(actor_tokens)], dim=-1)
        action_log_std = torch.cat(
            [
                self.revealer_log_std(revealer_tokens),
                self.actor_log_std(actor_tokens),
            ],
            dim=-1,
        ).clamp(min=-5.0, max=2.0)
        proposal_features = torch.cat(
            [
                revealer_tokens.mean(dim=1),
                actor_tokens.mean(dim=1),
                coordination.mean(dim=1),
            ],
            dim=-1,
        )
        return {
            "decoded_tokens": torch.cat([revealer_tokens, actor_tokens], dim=-1),
            "revealer_tokens": revealer_tokens,
            "actor_tokens": actor_tokens,
            "coordination_tokens": coordination,
            "action_mean": action_mean,
            "action_log_std": action_log_std,
            "proposal_score": self.proposal_score(proposal_features).squeeze(-1),
        }

    def sample_candidates(self, action_mean: Tensor, action_log_std: Tensor, num_candidates: int | None = None) -> Tensor:
        """Draw candidate action chunks around ``action_mean``.

        Candidate 0 is always exactly ``action_mean``. In training mode the
        remaining candidates use Gaussian noise; in eval mode they use the
        fixed pattern from ``_deterministic_candidate_noise``. A falsy
        ``num_candidates`` (``None`` or ``0``) falls back to
        ``config.num_candidates``.

        Returns:
            Tensor of shape ``(batch, num_candidates, chunk, action_dim)``.
        """
        num_candidates = num_candidates or self.config.num_candidates
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        std = action_log_std.exp()
        if self.training:
            noise = torch.randn(
                action_mean.size(0),
                num_candidates,
                action_mean.size(1),
                action_mean.size(2),
                device=action_mean.device,
                dtype=action_mean.dtype,
            )
        else:
            noise = self._deterministic_candidate_noise(action_mean, num_candidates)
        candidates = action_mean.unsqueeze(1) + noise * std.unsqueeze(1)
        candidates[:, 0] = action_mean
        return candidates
class InteractionChunkDecoder(nn.Module):
    """Right-then-left chunk decoder conditioned on an interaction state.

    The right arm is decoded first; the left arm then attends to the memory
    plus the right arm's tokens. Phase and arm-role distributions from the
    optional interaction state are projected into query biases. Besides the
    base chunk, ``forward`` decodes ``num_candidates`` proposal variants by
    offsetting the step queries with learned per-proposal embeddings.
    """

    @staticmethod
    def _make_decoder(config: ChunkDecoderConfig) -> nn.TransformerDecoder:
        """Build one pre-norm, batch-first transformer decoder stack."""
        layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        return nn.TransformerDecoder(layer, num_layers=config.num_layers)

    def __init__(self, config: ChunkDecoderConfig) -> None:
        """Create decoders, embeddings, conditioning adapters and heads."""
        super().__init__()
        self.config = config
        width = config.hidden_dim
        right_width = config.arm_action_dim
        left_width = config.action_dim - config.arm_action_dim

        # Construction order is kept stable so seeded initialisation and
        # state-dict ordering do not change.
        self.right_decoder = self._make_decoder(config)
        self.left_decoder = self._make_decoder(config)

        self.query_embed = nn.Embedding(config.chunk_size, width)
        self.proposal_queries = nn.Embedding(config.num_candidates, width)
        self.arm_identity = nn.Embedding(2, width)

        self.phase_adapter = nn.Linear(config.num_phases, width)
        self.role_adapter = nn.Linear(config.num_arm_roles, width)
        self.context_proj = nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, width),
            nn.GELU(),
        )
        self.coordination = nn.Sequential(
            nn.LayerNorm(width * 3),
            nn.Linear(width * 3, width),
            nn.GELU(),
            nn.Linear(width, width),
        )

        self.right_mean = nn.Linear(width, right_width)
        self.right_log_std = nn.Linear(width, right_width)
        self.left_mean = nn.Linear(width, left_width)
        self.left_log_std = nn.Linear(width, left_width)

        self.proposal_score = nn.Sequential(
            nn.LayerNorm(width * 3),
            nn.Linear(width * 3, width),
            nn.GELU(),
            nn.Linear(width, 1),
        )

    def _conditioning(
        self,
        interaction_state: dict[str, Tensor] | None,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        """Turn the interaction state into (phase bias, per-arm role bias, tokens).

        Without a state, both biases are zeros and the token entry is ``None``.
        Otherwise the phase / arm-role logits are soft-maxed and projected, and
        the state's optional ``"interaction_tokens"`` entry is passed through.
        """
        if interaction_state is None:
            width = self.config.hidden_dim
            no_phase = torch.zeros(batch_size, width, device=device, dtype=dtype)
            no_roles = torch.zeros(batch_size, 2, width, device=device, dtype=dtype)
            return no_phase, no_roles, None

        phase_dist = torch.softmax(interaction_state["phase_logits"], dim=-1).to(dtype=dtype)
        role_dist = torch.softmax(interaction_state["arm_role_logits"], dim=-1).to(dtype=dtype)
        return (
            self.phase_adapter(phase_dist),
            self.role_adapter(role_dist),
            interaction_state.get("interaction_tokens"),
        )

    def _decode_from_queries(
        self,
        queries: Tensor,
        decoder_memory: Tensor,
        phase_context: Tensor,
        role_context: Tensor,
        interaction_context: Tensor,
    ) -> dict[str, Tensor]:
        """Decode both arms from step queries and build every output head.

        Index 0 of ``role_context`` / ``arm_identity`` conditions the right
        arm, index 1 the left arm.
        """
        steps = queries.shape[1]
        identity = self.arm_identity.weight
        # Both arms share the query + phase part of the bias.
        phased = queries + phase_context.unsqueeze(1)
        right_in = phased + role_context[:, 0].unsqueeze(1) + identity[0].view(1, 1, -1).to(dtype=queries.dtype)
        left_in = phased + role_context[:, 1].unsqueeze(1) + identity[1].view(1, 1, -1).to(dtype=queries.dtype)

        right_out = self.right_decoder(right_in, decoder_memory)
        # The left arm also sees what the right arm decoded.
        left_memory = torch.cat([decoder_memory, right_out], dim=1)
        left_out = self.left_decoder(left_in, left_memory)

        tiled_context = interaction_context.unsqueeze(1).expand(-1, steps, -1)
        mix = torch.tanh(self.coordination(torch.cat([right_out, left_out, tiled_context], dim=-1)))
        right_out = right_out + mix
        left_out = left_out + mix

        mean = torch.cat([self.right_mean(right_out), self.left_mean(left_out)], dim=-1)
        log_std = torch.cat([self.right_log_std(right_out), self.left_log_std(left_out)], dim=-1)
        log_std = log_std.clamp(min=-5.0, max=2.0)
        pooled = torch.cat([right_out.mean(dim=1), left_out.mean(dim=1), mix.mean(dim=1)], dim=-1)

        return {
            "right_tokens": right_out,
            "left_tokens": left_out,
            "coordination_tokens": mix,
            "decoded_tokens": torch.cat([right_out, left_out], dim=-1),
            "action_mean": mean,
            "action_log_std": log_std,
            "proposal_score": self.proposal_score(pooled).squeeze(-1),
        }

    @staticmethod
    def _repeat_per_candidate(tensor: Tensor, copies: int) -> Tensor:
        """Repeat each batch row ``copies`` times and fold the copies into the batch axis."""
        tail = tuple(tensor.shape[1:])
        tiled = tensor.unsqueeze(1).expand(-1, copies, *(-1 for _ in tail))
        return tiled.reshape(tensor.shape[0] * copies, *tail)

    def forward(
        self,
        scene_tokens: Tensor,
        interaction_state: dict[str, Tensor] | None = None,
        memory_tokens: Tensor | None = None,
        reveal_tokens: Tensor | None = None,
        memory_token: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode the base chunk plus ``config.num_candidates`` proposals.

        ``memory_token`` is accepted as an alias used only when
        ``memory_tokens`` is ``None``. Interaction tokens take precedence over
        ``reveal_tokens`` both as extra memory and as the context source.

        Returns the dict from ``_decode_from_queries`` for the base queries,
        extended with ``proposal_candidates`` and ``proposal_logits``; slot 0
        of both is overwritten with the base action mean / base score.
        """
        memory_tokens = memory_token if memory_tokens is None else memory_tokens
        batch_size = scene_tokens.shape[0]
        phase_context, role_context, interaction_tokens = self._conditioning(
            interaction_state=interaction_state,
            batch_size=batch_size,
            device=scene_tokens.device,
            dtype=scene_tokens.dtype,
        )

        # Assemble the cross-attention memory along the token axis.
        memory_parts = [scene_tokens]
        if interaction_tokens is not None:
            memory_parts.append(interaction_tokens)
        elif reveal_tokens is not None:
            memory_parts.append(reveal_tokens)
        if memory_tokens is not None:
            memory_parts.append(memory_tokens)
        decoder_memory = memory_parts[0] if len(memory_parts) == 1 else torch.cat(memory_parts, dim=1)

        # Pick the first non-empty source for the pooled interaction context.
        context_source = scene_tokens
        if interaction_tokens is not None and interaction_tokens.numel() > 0:
            context_source = interaction_tokens
        elif reveal_tokens is not None and reveal_tokens.numel() > 0:
            context_source = reveal_tokens
        interaction_context = self.context_proj(context_source.mean(dim=1))

        base_queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        decoded = self._decode_from_queries(
            queries=base_queries,
            decoder_memory=decoder_memory,
            phase_context=phase_context,
            role_context=role_context,
            interaction_context=interaction_context,
        )

        slots = self.config.num_candidates
        steps = self.config.chunk_size
        # Each proposal slot shifts every step query by its own embedding.
        slot_offsets = self.proposal_queries.weight.view(1, slots, 1, -1)
        slot_queries = (base_queries.unsqueeze(1) + slot_offsets).reshape(
            batch_size * slots, steps, self.config.hidden_dim
        )
        slot_decoded = self._decode_from_queries(
            queries=slot_queries,
            decoder_memory=self._repeat_per_candidate(decoder_memory, slots),
            phase_context=self._repeat_per_candidate(phase_context, slots),
            role_context=self._repeat_per_candidate(role_context, slots),
            interaction_context=self._repeat_per_candidate(interaction_context, slots),
        )

        deltas = slot_decoded["action_mean"].view(batch_size, slots, steps, self.config.action_dim)
        logits = slot_decoded["proposal_score"].view(batch_size, slots)
        # Proposals are bounded residuals (at most +/-0.35) around the base mean.
        candidates = decoded["action_mean"].unsqueeze(1) + 0.35 * torch.tanh(deltas)
        candidates[:, 0] = decoded["action_mean"]
        logits[:, 0] = decoded["proposal_score"]

        decoded["proposal_candidates"] = candidates
        decoded["proposal_logits"] = logits
        return decoded

    def sample_candidates(
        self,
        action_mean: Tensor,
        action_log_std: Tensor,
        num_candidates: int | None = None,
        proposal_candidates: Tensor | None = None,
    ) -> Tensor:
        """Return ``proposal_candidates`` if given, else Gaussian samples.

        Sampled candidates have shape ``(batch, num_candidates, chunk,
        action_dim)`` and slot 0 is exactly ``action_mean``. A falsy
        ``num_candidates`` falls back to ``config.num_candidates``.
        """
        if proposal_candidates is not None:
            return proposal_candidates

        count = num_candidates or self.config.num_candidates
        if count <= 1:
            return action_mean.unsqueeze(1)

        batch, steps, width = action_mean.size(0), action_mean.size(1), action_mean.size(2)
        noise = torch.randn(batch, count, steps, width, device=action_mean.device, dtype=action_mean.dtype)
        scale = action_log_std.exp().unsqueeze(1)
        samples = action_mean.unsqueeze(1) + noise * scale
        samples[:, 0] = action_mean
        return samples
# Proposal-mode names used by proposal_mode_vocab() for the "generic" task.
DEFAULT_PROPOSAL_MODES = (
    "widen_opening",
    "maintain_opening",
    "slide_occluder",
    "lift_support_layer",
    "stabilize_support",
    "retrieve",
)
# Task-specific proposal-mode names, keyed by the canonical task name returned
# by infer_task_name_from_text(). Order matters: proposal_mode_vocab() selects
# entries by position when it has to shorten a vocabulary.
TASK_PROPOSAL_MODES = {
    "foliage": (
        "sweep_left",
        "sweep_right",
        "pin_canopy",
        "widen_gap",
        "maintain_gap",
        "insert_actor",
        "retrieve",
    ),
    "bag": (
        "pin_left_rim",
        "pin_right_rim",
        "widen_mouth",
        "maintain_mouth",
        "probe_inside",
        "insert_actor",
        "retrieve",
    ),
    "cloth": (
        "lift_edge",
        "separate_layer",
        "stabilize_fold",
        "maintain_lift",
        "insert_actor",
        "retrieve",
    ),
}
# Row index of each canonical task in SymmetricCoordinatedChunkDecoder's
# task_embedding table ("generic" has no entry of its own).
TASK_INDEX = {"foliage": 0, "bag": 1, "cloth": 2}
def infer_task_name_from_text(text: str | None) -> str:
    """Map free-form text to a canonical task name by keyword lookup.

    Matching is case-insensitive and substring-based (so "primary" matches the
    "rim" keyword). Task groups are tried in the order foliage, bag, cloth and
    the first group with any hit wins. Empty or missing text, or text without
    any keyword, yields ``"generic"``.
    """
    keyword_groups = (
        ("foliage", ("foliage", "canopy", "leaf", "leaves", "snail")),
        ("bag", ("bag", "mouth", "rim", "aperture")),
        ("cloth", ("cloth", "fold", "layer", "suitcase", "garment")),
    )
    haystack = text.lower() if text else ""
    for task, keywords in keyword_groups:
        for keyword in keywords:
            if keyword in haystack:
                return task
    return "generic"
def proposal_mode_vocab(task_name: str, num_modes: int) -> tuple[str, ...]:
    """Return exactly ``num_modes`` proposal-mode names for ``task_name``.

    Args:
        task_name: Canonical task name. ``"generic"`` and any name without an
            entry in ``TASK_PROPOSAL_MODES`` use ``DEFAULT_PROPOSAL_MODES``
            (unknown names previously raised ``KeyError``).
        num_modes: Requested vocabulary length; values ``<= 0`` give ``()``.

    Returns:
        A tuple of length ``num_modes``. A task vocabulary that is too long is
        shortened: when at least six modes are requested the first four and
        the last two entries are kept, otherwise the leading entries are
        kept. A vocabulary that is too short is padded by repeating its last
        entry (or ``"retrieve"`` if it is empty).
    """
    if num_modes <= 0:
        # A negative count would otherwise turn into a "drop from the end" slice.
        return ()
    if task_name == "generic" or task_name not in TASK_PROPOSAL_MODES:
        base_vocab = tuple(DEFAULT_PROPOSAL_MODES)
    else:
        vocab = tuple(TASK_PROPOSAL_MODES[task_name])
        if len(vocab) <= num_modes:
            base_vocab = vocab
        elif num_modes >= 6:
            # Keep the four leading modes plus the two trailing ones so the
            # final entries of the task vocabulary survive the cut.
            base_vocab = (*vocab[:4], *vocab[-2:])[:num_modes]
        else:
            base_vocab = vocab[:num_modes]
    if len(base_vocab) >= num_modes:
        return tuple(base_vocab[:num_modes])
    if not base_vocab:
        return ("retrieve",) * num_modes
    padding = (base_vocab[-1],) * (num_modes - len(base_vocab))
    return tuple(base_vocab) + padding
def swap_arm_action_order(action_chunk: Tensor) -> Tensor:
    """Exchange the two halves of the last axis of ``action_chunk``.

    The split point is ``last_dim // 2``; for an odd width the second part is
    one element longer and still moves to the front.
    """
    split_at = action_chunk.shape[-1] // 2
    leading_part = action_chunk[..., :split_at]
    trailing_part = action_chunk[..., split_at:]
    return torch.cat((trailing_part, leading_part), dim=-1)
class SymmetricCoordinatedChunkDecoder(nn.Module):
    """Chunk decoder that shares one transformer decoder and one head across both arms.

    Each arm is decoded from the same step queries, distinguished only by an
    arm-identity embedding and its role bias. A coordination MLP mixes the two
    streams, and a set of proposal heads turns the base action into
    ``num_candidates`` mode-specific candidates with scores.

    Both arms use the same ``arm_action_dim``-wide head, so the concatenated
    action is ``2 * arm_action_dim`` wide; the proposal heads are sized with
    ``action_dim`` and therefore require ``action_dim == 2 * arm_action_dim``.
    """

    def __init__(self, config: ChunkDecoderConfig) -> None:
        """Create the shared decoder, conditioning adapters and proposal heads."""
        super().__init__()
        self.config = config
        # Width of the pooled feature vector built in forward(): both arms'
        # mean actions plus the pooled coordination and interaction contexts.
        proposal_context_dim = config.action_dim + (config.hidden_dim * 2)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.arm_decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers)
        self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim)
        self.arm_identity = nn.Embedding(2, config.hidden_dim)
        self.task_embedding = nn.Embedding(len(TASK_INDEX), config.hidden_dim)
        self.phase_adapter = nn.Linear(config.num_phases, config.hidden_dim)
        self.role_adapter = nn.Linear(config.num_arm_roles, config.hidden_dim)
        self.context_proj = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.coordination = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        self.arm_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.arm_mean = nn.Linear(config.hidden_dim, config.arm_action_dim)
        self.arm_log_std = nn.Linear(config.hidden_dim, config.arm_action_dim)
        self.proposal_mode_head = nn.Sequential(
            nn.LayerNorm(proposal_context_dim),
            nn.Linear(proposal_context_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_proposal_modes),
        )
        self.proposal_mode_embeddings = nn.Embedding(config.num_proposal_modes, config.hidden_dim)
        self.proposal_slot_embeddings = nn.Embedding(config.num_candidates, config.hidden_dim)
        # One residual generator per proposal mode.
        self.mode_residual_heads = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(proposal_context_dim),
                    nn.Linear(proposal_context_dim, config.hidden_dim),
                    nn.GELU(),
                    nn.Linear(config.hidden_dim, config.chunk_size * config.action_dim),
                )
                for _ in range(config.num_proposal_modes)
            ]
        )
        self.slot_delta = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.chunk_size * config.action_dim),
        )
        self.proposal_score = nn.Sequential(
            nn.LayerNorm(proposal_context_dim + config.hidden_dim),
            nn.Linear(proposal_context_dim + config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, 1),
        )

    def _conditioning(
        self,
        interaction_state: dict[str, Tensor] | None,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
        swap_roles: bool = False,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Project the interaction state into (phase, per-arm role, context) biases.

        Without a state all three are zeros. With ``swap_roles`` the arm axis
        of the role probabilities is flipped before projection. The context is
        the projected mean of ``"interaction_tokens"`` or, when that entry is
        missing/``None``, of ``"field_tokens"``. The task embedding is NOT
        included here; ``forward`` adds it.
        """
        if interaction_state is None:
            zero_phase = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
            zero_roles = torch.zeros(batch_size, 2, self.config.hidden_dim, device=device, dtype=dtype)
            zero_context = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
            return zero_phase, zero_roles, zero_context
        phase_probs = interaction_state["phase_logits"].softmax(dim=-1).to(dtype=dtype)
        arm_role_probs = interaction_state["arm_role_logits"].softmax(dim=-1).to(dtype=dtype)
        if swap_roles:
            arm_role_probs = arm_role_probs.flip(1)
        phase_context = self.phase_adapter(phase_probs)
        role_context = self.role_adapter(arm_role_probs)
        if interaction_state.get("interaction_tokens") is not None:
            interaction_context = interaction_state["interaction_tokens"].mean(dim=1)
        else:
            interaction_context = interaction_state["field_tokens"].mean(dim=1)
        return phase_context, role_context, self.context_proj(interaction_context)

    def _decode_arm_tokens(
        self,
        queries: Tensor,
        decoder_memory: Tensor,
        phase_context: Tensor,
        role_context: Tensor,
        interaction_context: Tensor,
        swap_roles: bool = False,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Decode both arms with the shared decoder and head.

        Returns ``(arm_mean, arm_log_std, coordination)`` where the first two
        have an arm axis of size 2 at dim 1 and the log-std is clamped to
        ``[-5, 2]``. With ``swap_roles`` the two arm-identity embeddings are
        exchanged.
        """
        batch_size, chunk_size, _ = queries.shape
        identity_order = torch.tensor([1, 0] if swap_roles else [0, 1], device=queries.device)
        arm_queries = queries.unsqueeze(1).expand(-1, 2, -1, -1)
        arm_queries = arm_queries + phase_context.unsqueeze(1).unsqueeze(2)
        arm_queries = arm_queries + role_context.unsqueeze(2)
        arm_queries = arm_queries + self.arm_identity(identity_order).view(1, 2, 1, -1).to(dtype=queries.dtype)
        # Fold the arm axis into the batch so one decoder pass serves both arms.
        flat_queries = arm_queries.reshape(batch_size * 2, chunk_size, self.config.hidden_dim)
        flat_memory = decoder_memory.unsqueeze(1).expand(-1, 2, -1, -1).reshape(
            batch_size * 2,
            decoder_memory.shape[1],
            decoder_memory.shape[2],
        )
        decoded = self.arm_decoder(flat_queries, flat_memory).reshape(batch_size, 2, chunk_size, self.config.hidden_dim)
        coordination_input = torch.cat(
            [
                decoded[:, 0],
                decoded[:, 1],
                interaction_context.unsqueeze(1).expand(-1, chunk_size, -1),
            ],
            dim=-1,
        )
        coordination = torch.tanh(self.coordination(coordination_input))
        # Same coordination signal for both arms; a broadcast add avoids
        # writing in place into the decoder output.
        decoded = decoded + coordination.unsqueeze(1)
        decoded = self.arm_head(decoded)
        arm_mean = self.arm_mean(decoded)
        arm_log_std = self.arm_log_std(decoded).clamp(min=-5.0, max=2.0)
        return arm_mean, arm_log_std, coordination

    def _proposal_outputs(
        self,
        base_action: Tensor,
        pooled_context: Tensor,
        task_names: list[str],
    ) -> tuple[Tensor, Tensor, Tensor, list[list[str]]]:
        """Build per-slot candidates, their logits and their mode names.

        Slot ``i`` is assigned mode ``i % num_proposal_modes``. Each candidate
        is the base action plus a bounded mode residual (scale 0.35) and a
        bounded slot-specific delta (scale 0.05); slot 0 is then reset to the
        unmodified base action.

        Returns:
            ``(candidates, logits, mode_logits, mode_names)`` where
            ``mode_names`` holds one list of slot names per entry of
            ``task_names``.
        """
        batch_size = pooled_context.shape[0]
        num_slots = self.config.num_candidates
        num_modes = self.config.num_proposal_modes
        mode_logits = self.proposal_mode_head(pooled_context)
        mode_residuals = torch.stack(
            [
                head(pooled_context).view(batch_size, self.config.chunk_size, self.config.action_dim)
                for head in self.mode_residual_heads
            ],
            dim=1,
        )
        # Computed on the host: indexing a device tensor and calling int() on
        # it would force a device synchronisation for every slot.
        slot_modes = [slot_idx % num_modes for slot_idx in range(num_slots)]
        slot_embeddings = self.proposal_slot_embeddings.weight
        slot_deltas = self.slot_delta(slot_embeddings).view(
            num_slots,
            self.config.chunk_size,
            self.config.action_dim,
        )
        proposal_mode_names = []
        for task_name in task_names:
            vocab = proposal_mode_vocab(task_name, num_modes)
            proposal_mode_names.append([vocab[mode_idx] for mode_idx in slot_modes])
        proposal_candidates = []
        proposal_logits = []
        for slot_idx, mode_idx in enumerate(slot_modes):
            candidate = (
                base_action
                + 0.35 * torch.tanh(mode_residuals[:, mode_idx])
                + 0.05 * torch.tanh(slot_deltas[slot_idx]).unsqueeze(0)
            )
            proposal_candidates.append(candidate)
            score_features = torch.cat(
                [
                    pooled_context,
                    self.proposal_mode_embeddings.weight[mode_idx].unsqueeze(0).expand(batch_size, -1)
                    + slot_embeddings[slot_idx].unsqueeze(0).expand(batch_size, -1),
                ],
                dim=-1,
            )
            proposal_logits.append(
                self.proposal_score(score_features).squeeze(-1) + mode_logits[:, mode_idx]
            )
        stacked_candidates = torch.stack(proposal_candidates, dim=1)
        stacked_logits = torch.stack(proposal_logits, dim=1)
        stacked_candidates[:, 0] = base_action
        return stacked_candidates, stacked_logits, mode_logits, proposal_mode_names

    def forward(
        self,
        scene_tokens: Tensor,
        interaction_state: dict[str, Tensor] | None = None,
        memory_tokens: Tensor | None = None,
        reveal_tokens: Tensor | None = None,
        memory_token: Tensor | None = None,
        compute_equivariance_probe: bool = False,
        task_names: list[str] | None = None,
    ) -> dict[str, Tensor]:
        """Decode the base chunk, proposal candidates and optional swap probe.

        Args:
            scene_tokens: Batch-first memory tokens; dim 0 is the batch size.
            interaction_state: Optional dict with ``"phase_logits"``,
                ``"arm_role_logits"`` and ``"interaction_tokens"`` or
                ``"field_tokens"``.
            memory_tokens: Optional tokens appended last to the memory.
            reveal_tokens: Appended to the memory only when the state supplies
                no interaction tokens.
            memory_token: Alias used when ``memory_tokens`` is ``None``.
            compute_equivariance_probe: Also decode with arm roles and arm
                identities swapped, and return that result together with the
                arm-swapped base action as a comparison target.
            task_names: Free-form task descriptions, one per batch element;
                each is canonicalised with ``infer_task_name_from_text``.
                Defaults to "generic" for every element. Names without an
                entry in ``TASK_INDEX`` (including "generic") use embedding
                row 0.

        Returns:
            Dict of outputs. Note that ``"right_tokens"``, ``"left_tokens"``
            and ``"decoded_tokens"`` hold per-arm action means here, not
            hidden states. ``"proposal_mode_names"`` and
            ``"proposal_task_names"`` are Python lists, not tensors.
        """
        if memory_tokens is None:
            memory_tokens = memory_token
        batch_size = scene_tokens.shape[0]
        dtype = scene_tokens.dtype
        phase_context, role_context, interaction_context = self._conditioning(
            interaction_state=interaction_state,
            batch_size=batch_size,
            device=scene_tokens.device,
            dtype=dtype,
        )
        decoder_memory = scene_tokens
        interaction_tokens = interaction_state.get("interaction_tokens") if interaction_state is not None else None
        if interaction_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, interaction_tokens], dim=1)
        elif reveal_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, reveal_tokens], dim=1)
        if memory_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, memory_tokens], dim=1)
        canonical_task_names = [infer_task_name_from_text(name) for name in (task_names or ["generic"] * batch_size)]
        task_ids = torch.as_tensor(
            [TASK_INDEX.get(name, 0) for name in canonical_task_names],
            device=scene_tokens.device,
            dtype=torch.long,
        )
        task_context = self.task_embedding(task_ids)
        interaction_context = interaction_context + task_context
        base_queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        arm_mean, arm_log_std, coordination = self._decode_arm_tokens(
            queries=base_queries,
            decoder_memory=decoder_memory,
            phase_context=phase_context,
            role_context=role_context,
            interaction_context=interaction_context,
        )
        action_mean = torch.cat([arm_mean[:, 0], arm_mean[:, 1]], dim=-1)
        action_log_std = torch.cat([arm_log_std[:, 0], arm_log_std[:, 1]], dim=-1)
        pooled_context = torch.cat(
            [
                arm_mean[:, 0].mean(dim=1),
                arm_mean[:, 1].mean(dim=1),
                coordination.mean(dim=1),
                interaction_context,
            ],
            dim=-1,
        )
        proposal_candidates, proposal_logits, proposal_mode_logits, proposal_mode_names = self._proposal_outputs(
            action_mean,
            pooled_context,
            canonical_task_names,
        )
        outputs = {
            "decoded_tokens": torch.cat([arm_mean[:, 0], arm_mean[:, 1]], dim=-1),
            "right_tokens": arm_mean[:, 0],
            "left_tokens": arm_mean[:, 1],
            "coordination_tokens": coordination,
            "action_mean": action_mean,
            "action_log_std": action_log_std,
            "proposal_candidates": proposal_candidates,
            "proposal_logits": proposal_logits,
            "proposal_mode_logits": proposal_mode_logits,
            "proposal_mode_assignments": torch.arange(
                self.config.num_candidates,
                device=scene_tokens.device,
            ) % self.config.num_proposal_modes,
            "proposal_mode_names": proposal_mode_names,
            "proposal_task_names": canonical_task_names,
        }
        if compute_equivariance_probe:
            swapped_phase, swapped_roles, swapped_context = self._conditioning(
                interaction_state=interaction_state,
                batch_size=batch_size,
                device=scene_tokens.device,
                dtype=dtype,
                swap_roles=True,
            )
            # The probe must differ from the main pass ONLY by the arm swap,
            # so it needs the same task embedding the main pass received.
            swapped_context = swapped_context + task_context
            swapped_arm_mean, _, _ = self._decode_arm_tokens(
                queries=base_queries,
                decoder_memory=decoder_memory,
                phase_context=swapped_phase,
                role_context=swapped_roles,
                interaction_context=swapped_context,
                swap_roles=True,
            )
            outputs["equivariance_probe_action_mean"] = torch.cat(
                [swapped_arm_mean[:, 0], swapped_arm_mean[:, 1]],
                dim=-1,
            )
            outputs["equivariance_target_action_mean"] = swap_arm_action_order(action_mean)
        return outputs

    def sample_candidates(
        self,
        action_mean: Tensor,
        action_log_std: Tensor,
        num_candidates: int | None = None,
        proposal_candidates: Tensor | None = None,
    ) -> Tensor:
        """Return ``proposal_candidates`` if given, else Gaussian samples.

        Sampled candidates have shape ``(batch, num_candidates, chunk,
        action_dim)`` with slot 0 equal to ``action_mean``. A falsy
        ``num_candidates`` falls back to ``config.num_candidates``.
        """
        if proposal_candidates is not None:
            return proposal_candidates
        num_candidates = num_candidates or self.config.num_candidates
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        noise = torch.randn(
            action_mean.size(0),
            num_candidates,
            action_mean.size(1),
            action_mean.size(2),
            device=action_mean.device,
            dtype=action_mean.dtype,
        )
        candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
        candidates[:, 0] = action_mean
        return candidates