from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import Tensor, nn


@dataclass
class ChunkDecoderConfig:
    """Hyper-parameters shared by every chunk decoder defined in this module."""

    # Transformer width; also the size of every query/identity embedding.
    hidden_dim: int = 512
    # Attention heads of each ``nn.TransformerDecoderLayer``.
    num_heads: int = 8
    # Number of stacked decoder layers per ``nn.TransformerDecoder``.
    num_layers: int = 4
    # Feed-forward width inside each decoder layer.
    ff_dim: int = 2048
    # Dropout probability inside each decoder layer.
    dropout: float = 0.1
    # Number of learned queries, i.e. time steps emitted per action chunk.
    chunk_size: int = 8
    # Width of the concatenated two-arm action vector.
    action_dim: int = 14
    # Width of the first arm's slice; the second arm gets the remainder.
    arm_action_dim: int = 7
    # Number of candidate chunks / proposal slots.
    num_candidates: int = 8
    # Size of the ``phase_logits`` vector consumed by the phase adapters.
    num_phases: int = 5
    # Size of the per-arm ``arm_role_logits`` vector consumed by role adapters.
    num_arm_roles: int = 4
    # Number of proposal modes in ``SymmetricCoordinatedChunkDecoder``.
    num_proposal_modes: int = 6
    # Not read anywhere in this module; presumably consumed by a planner.
    planner_top_k: int = 4


def _build_transformer_decoder(config: ChunkDecoderConfig) -> nn.TransformerDecoder:
    """Create a pre-norm, batch-first transformer decoder stack from ``config``."""
    layer = nn.TransformerDecoderLayer(
        d_model=config.hidden_dim,
        nhead=config.num_heads,
        dim_feedforward=config.ff_dim,
        dropout=config.dropout,
        batch_first=True,
        norm_first=True,
    )
    return nn.TransformerDecoder(layer, num_layers=config.num_layers)


def _extend_memory(memory: Tensor, extra: Tensor | None) -> Tensor:
    """Append ``extra`` to ``memory`` along the token axis (dim 1) when given."""
    if extra is None:
        return memory
    return torch.cat([memory, extra], dim=1)


def _perturb_mean(action_mean: Tensor, action_log_std: Tensor, noise: Tensor) -> Tensor:
    """Return ``mean + noise * std`` per candidate, with slot 0 pinned to the mean."""
    candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
    # Slot 0 is always the unperturbed mean so it stays available downstream.
    candidates[:, 0] = action_mean
    return candidates


def _random_candidate_noise(action_mean: Tensor, num_candidates: int) -> Tensor:
    """Draw standard-normal noise shaped ``(batch, num_candidates, chunk, action)``."""
    return torch.randn(
        action_mean.size(0),
        num_candidates,
        action_mean.size(1),
        action_mean.size(2),
        device=action_mean.device,
        dtype=action_mean.dtype,
    )


class ACTBimanualChunkDecoder(nn.Module):
    """Two-stage chunk decoder with a "revealer" stream and an "actor" stream.

    The revealer decoder attends to the scene memory; the actor decoder
    attends to that same memory plus the revealer's output tokens. A shared
    coordination residual is then added to both streams before the Gaussian
    action heads are applied.
    """

    def __init__(self, config: ChunkDecoderConfig) -> None:
        super().__init__()
        self.config = config
        hidden = config.hidden_dim
        second_arm_dim = config.action_dim - config.arm_action_dim

        self.revealer_decoder = _build_transformer_decoder(config)
        self.actor_decoder = _build_transformer_decoder(config)
        self.query_embed = nn.Embedding(config.chunk_size, hidden)
        # Learned offset (initialised to zero) that distinguishes actor queries.
        self.actor_role_bias = nn.Parameter(torch.zeros(1, config.chunk_size, hidden))
        self.revealer_mean = nn.Linear(hidden, config.arm_action_dim)
        self.revealer_log_std = nn.Linear(hidden, config.arm_action_dim)
        self.actor_mean = nn.Linear(hidden, second_arm_dim)
        self.actor_log_std = nn.Linear(hidden, second_arm_dim)
        self.coordination = nn.Sequential(
            nn.LayerNorm(hidden * 3),
            nn.Linear(hidden * 3, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
        )
        self.proposal_score = nn.Sequential(
            nn.LayerNorm(hidden * 3),
            nn.Linear(hidden * 3, 1),
        )

    def _deterministic_candidate_noise(
        self,
        action_mean: Tensor,
        num_candidates: int,
    ) -> Tensor:
        """Build a fixed, input-independent noise pattern for evaluation.

        Args:
            action_mean: 3-D tensor ``(batch, chunk, action)``; only its shape,
                device and dtype are used.
            num_candidates: Number of candidate slots.

        Returns:
            Tensor ``(batch, num_candidates, chunk, action)``. Slot 0 is all
            zeros; every other slot holds a sin/cos pattern scaled to unit RMS
            over its ``(chunk, action)`` plane.
        """
        batch_size, chunk_size, action_dim = action_mean.shape
        device, dtype = action_mean.device, action_mean.dtype
        noise = torch.zeros(
            batch_size,
            num_candidates,
            chunk_size,
            action_dim,
            device=device,
            dtype=dtype,
        )
        if num_candidates <= 1:
            return noise

        candidate_index = torch.arange(1, num_candidates, device=device, dtype=dtype).view(
            num_candidates - 1, 1, 1
        )
        step_index = torch.arange(chunk_size, device=device, dtype=dtype).view(1, chunk_size, 1)
        dim_index = torch.arange(action_dim, device=device, dtype=dtype).view(1, 1, action_dim)

        base = torch.sin(candidate_index * 0.73 + step_index * 0.37 + dim_index * 0.19)
        base = base + torch.cos(candidate_index * 1.11 + step_index * 0.17 + dim_index * 0.41)
        # Normalise each candidate's pattern by its RMS; the clamp avoids 0-division.
        base = base / base.square().mean(dim=(1, 2), keepdim=True).sqrt().clamp_min(1e-6)
        noise[:, 1:] = base.unsqueeze(0).expand(batch_size, -1, -1, -1)
        return noise

    def forward(
        self,
        scene_tokens: Tensor,
        reveal_tokens: Tensor | None = None,
        memory_token: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode one action chunk.

        Args:
            scene_tokens: Decoder memory, ``(batch, tokens, hidden_dim)``.
            reveal_tokens: Optional extra memory tokens concatenated on dim 1;
                when given, their mean also forms the coordination context
                (otherwise the mean of ``scene_tokens`` is used).
            memory_token: Optional extra memory tokens concatenated on dim 1.

        Returns:
            Dict with ``decoded_tokens``, ``revealer_tokens``, ``actor_tokens``,
            ``coordination_tokens``, ``action_mean``, ``action_log_std``
            (clamped to ``[-5, 2]``) and ``proposal_score`` (one per sample).
        """
        batch_size = scene_tokens.shape[0]
        chunk_size = self.config.chunk_size
        query = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)

        decoder_memory = _extend_memory(scene_tokens, reveal_tokens)
        decoder_memory = _extend_memory(decoder_memory, memory_token)

        revealer_tokens = self.revealer_decoder(query, decoder_memory)
        actor_query = query + self.actor_role_bias
        # The actor additionally attends to what the revealer produced.
        actor_tokens = self.actor_decoder(
            actor_query,
            torch.cat([decoder_memory, revealer_tokens], dim=1),
        )

        context_source = reveal_tokens if reveal_tokens is not None else scene_tokens
        reveal_context = context_source.mean(dim=1, keepdim=True).expand(-1, chunk_size, -1)

        coordination_input = torch.cat([revealer_tokens, actor_tokens, reveal_context], dim=-1)
        coordination = torch.tanh(self.coordination(coordination_input))
        revealer_tokens = revealer_tokens + coordination
        actor_tokens = actor_tokens + coordination

        action_mean = torch.cat(
            [self.revealer_mean(revealer_tokens), self.actor_mean(actor_tokens)],
            dim=-1,
        )
        action_log_std = torch.cat(
            [
                self.revealer_log_std(revealer_tokens),
                self.actor_log_std(actor_tokens),
            ],
            dim=-1,
        ).clamp(min=-5.0, max=2.0)
        proposal_features = torch.cat(
            [
                revealer_tokens.mean(dim=1),
                actor_tokens.mean(dim=1),
                coordination.mean(dim=1),
            ],
            dim=-1,
        )
        return {
            "decoded_tokens": torch.cat([revealer_tokens, actor_tokens], dim=-1),
            "revealer_tokens": revealer_tokens,
            "actor_tokens": actor_tokens,
            "coordination_tokens": coordination,
            "action_mean": action_mean,
            "action_log_std": action_log_std,
            "proposal_score": self.proposal_score(proposal_features).squeeze(-1),
        }

    def sample_candidates(
        self,
        action_mean: Tensor,
        action_log_std: Tensor,
        num_candidates: int | None = None,
    ) -> Tensor:
        """Return candidate chunks ``(batch, num_candidates, chunk, action)``.

        Noise is random in training mode and a fixed deterministic pattern in
        eval mode. Slot 0 always equals ``action_mean``. A falsy
        ``num_candidates`` (``None`` or ``0``) falls back to the config value.
        """
        num_candidates = num_candidates or self.config.num_candidates
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        if self.training:
            noise = _random_candidate_noise(action_mean, num_candidates)
        else:
            noise = self._deterministic_candidate_noise(action_mean, num_candidates)
        return _perturb_mean(action_mean, action_log_std, noise)


class InteractionChunkDecoder(nn.Module):
    """Right/left arm chunk decoder conditioned on an interaction state.

    Besides the base chunk, ``forward`` decodes ``num_candidates`` proposal
    chunks by adding a learned per-slot bias to the queries and re-running the
    same decoders.
    """

    def __init__(self, config: ChunkDecoderConfig) -> None:
        super().__init__()
        self.config = config
        hidden = config.hidden_dim
        left_arm_dim = config.action_dim - config.arm_action_dim

        self.right_decoder = _build_transformer_decoder(config)
        self.left_decoder = _build_transformer_decoder(config)
        self.query_embed = nn.Embedding(config.chunk_size, hidden)
        self.proposal_queries = nn.Embedding(config.num_candidates, hidden)
        self.arm_identity = nn.Embedding(2, hidden)
        self.phase_adapter = nn.Linear(config.num_phases, hidden)
        self.role_adapter = nn.Linear(config.num_arm_roles, hidden)
        self.context_proj = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, hidden),
            nn.GELU(),
        )
        self.coordination = nn.Sequential(
            nn.LayerNorm(hidden * 3),
            nn.Linear(hidden * 3, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
        )
        self.right_mean = nn.Linear(hidden, config.arm_action_dim)
        self.right_log_std = nn.Linear(hidden, config.arm_action_dim)
        self.left_mean = nn.Linear(hidden, left_arm_dim)
        self.left_log_std = nn.Linear(hidden, left_arm_dim)
        self.proposal_score = nn.Sequential(
            nn.LayerNorm(hidden * 3),
            nn.Linear(hidden * 3, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
        )

    def _conditioning(
        self,
        interaction_state: dict[str, Tensor] | None,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        """Turn the interaction state into query biases.

        Returns:
            ``(phase_context, role_context, interaction_tokens)`` where
            ``phase_context`` is ``(batch, hidden)``, ``role_context`` is
            indexed as ``[:, arm]`` by the caller, and ``interaction_tokens``
            is whatever ``interaction_state.get("interaction_tokens")`` yields
            (possibly ``None``). With no state, the contexts are zeros.

        Raises:
            KeyError: if a state is given without ``phase_logits`` or
                ``arm_role_logits``.
        """
        if interaction_state is None:
            hidden = self.config.hidden_dim
            zero_phase = torch.zeros(batch_size, hidden, device=device, dtype=dtype)
            zero_roles = torch.zeros(batch_size, 2, hidden, device=device, dtype=dtype)
            return zero_phase, zero_roles, None

        phase_probs = interaction_state["phase_logits"].softmax(dim=-1).to(dtype=dtype)
        arm_role_probs = interaction_state["arm_role_logits"].softmax(dim=-1).to(dtype=dtype)
        return (
            self.phase_adapter(phase_probs),
            self.role_adapter(arm_role_probs),
            interaction_state.get("interaction_tokens"),
        )

    def _decode_from_queries(
        self,
        queries: Tensor,
        decoder_memory: Tensor,
        phase_context: Tensor,
        role_context: Tensor,
        interaction_context: Tensor,
    ) -> dict[str, Tensor]:
        """Decode right then left arm tokens and apply the action heads.

        The left decoder attends to the memory plus the right arm's tokens. A
        shared tanh-bounded coordination residual is added to both arms.
        """
        phase_bias = phase_context.unsqueeze(1)
        identity = self.arm_identity.weight.to(dtype=queries.dtype)
        right_queries = (
            queries
            + phase_bias
            + role_context[:, 0].unsqueeze(1)
            + identity[0].view(1, 1, -1)
        )
        left_queries = (
            queries
            + phase_bias
            + role_context[:, 1].unsqueeze(1)
            + identity[1].view(1, 1, -1)
        )

        right_tokens = self.right_decoder(right_queries, decoder_memory)
        left_tokens = self.left_decoder(
            left_queries,
            torch.cat([decoder_memory, right_tokens], dim=1),
        )

        context = interaction_context.unsqueeze(1).expand(-1, queries.shape[1], -1)
        coordination = torch.tanh(
            self.coordination(torch.cat([right_tokens, left_tokens, context], dim=-1))
        )
        right_tokens = right_tokens + coordination
        left_tokens = left_tokens + coordination

        action_mean = torch.cat(
            [self.right_mean(right_tokens), self.left_mean(left_tokens)],
            dim=-1,
        )
        action_log_std = torch.cat(
            [self.right_log_std(right_tokens), self.left_log_std(left_tokens)],
            dim=-1,
        ).clamp(min=-5.0, max=2.0)
        pooled_features = torch.cat(
            [right_tokens.mean(dim=1), left_tokens.mean(dim=1), coordination.mean(dim=1)],
            dim=-1,
        )
        return {
            "right_tokens": right_tokens,
            "left_tokens": left_tokens,
            "coordination_tokens": coordination,
            "decoded_tokens": torch.cat([right_tokens, left_tokens], dim=-1),
            "action_mean": action_mean,
            "action_log_std": action_log_std,
            "proposal_score": self.proposal_score(pooled_features).squeeze(-1),
        }

    def forward(
        self,
        scene_tokens: Tensor,
        interaction_state: dict[str, Tensor] | None = None,
        memory_tokens: Tensor | None = None,
        reveal_tokens: Tensor | None = None,
        memory_token: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode the base chunk plus ``num_candidates`` proposal chunks.

        Args:
            scene_tokens: Decoder memory, ``(batch, tokens, hidden_dim)``.
            interaction_state: Optional dict with ``phase_logits``,
                ``arm_role_logits`` and optionally ``interaction_tokens``.
            memory_tokens: Optional extra memory tokens (dim-1 concat).
            reveal_tokens: Used as extra memory / context only when the state
                provides no ``interaction_tokens``.
            memory_token: Alias used when ``memory_tokens`` is ``None``.

        Returns:
            The base decode dict extended with ``proposal_candidates``
            ``(batch, num_candidates, chunk, action)`` and ``proposal_logits``
            ``(batch, num_candidates)``. Slot 0 of both is the base decode.
        """
        if memory_tokens is None:
            memory_tokens = memory_token
        batch_size = scene_tokens.shape[0]
        num_candidates = self.config.num_candidates
        chunk_size = self.config.chunk_size
        hidden = self.config.hidden_dim

        phase_context, role_context, interaction_tokens = self._conditioning(
            interaction_state=interaction_state,
            batch_size=batch_size,
            device=scene_tokens.device,
            dtype=scene_tokens.dtype,
        )

        # Interaction tokens take precedence over reveal tokens as extra memory.
        extra_memory = interaction_tokens if interaction_tokens is not None else reveal_tokens
        decoder_memory = _extend_memory(scene_tokens, extra_memory)
        decoder_memory = _extend_memory(decoder_memory, memory_tokens)

        # Empty token sets are skipped so the mean never runs over zero elements.
        if interaction_tokens is not None and interaction_tokens.numel() > 0:
            context_source = interaction_tokens
        elif reveal_tokens is not None and reveal_tokens.numel() > 0:
            context_source = reveal_tokens
        else:
            context_source = scene_tokens
        interaction_context = self.context_proj(context_source.mean(dim=1))

        base_queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        decoded = self._decode_from_queries(
            queries=base_queries,
            decoder_memory=decoder_memory,
            phase_context=phase_context,
            role_context=role_context,
            interaction_context=interaction_context,
        )

        def tile(tensor: Tensor) -> Tensor:
            # (batch, ...) -> (batch * num_candidates, ...), repeating each sample.
            trailing = tensor.shape[1:]
            expanded = tensor.unsqueeze(1).expand(-1, num_candidates, *([-1] * len(trailing)))
            return expanded.reshape(batch_size * num_candidates, *trailing)

        proposal_bias = self.proposal_queries.weight.view(1, num_candidates, 1, -1).expand(
            batch_size, -1, chunk_size, -1
        )
        candidate_queries = base_queries.unsqueeze(1) + proposal_bias
        candidate_decoded = self._decode_from_queries(
            queries=candidate_queries.reshape(batch_size * num_candidates, chunk_size, hidden),
            decoder_memory=tile(decoder_memory),
            phase_context=tile(phase_context),
            role_context=tile(role_context),
            interaction_context=tile(interaction_context),
        )

        proposal_deltas = candidate_decoded["action_mean"].view(
            batch_size,
            num_candidates,
            chunk_size,
            self.config.action_dim,
        )
        proposal_logits = candidate_decoded["proposal_score"].view(batch_size, num_candidates)
        # Candidates are bounded offsets (|delta| < 0.35) around the base mean.
        proposal_candidates = decoded["action_mean"].unsqueeze(1) + 0.35 * torch.tanh(proposal_deltas)
        proposal_candidates[:, 0] = decoded["action_mean"]
        proposal_logits[:, 0] = decoded["proposal_score"]

        decoded["proposal_candidates"] = proposal_candidates
        decoded["proposal_logits"] = proposal_logits
        return decoded

    def sample_candidates(
        self,
        action_mean: Tensor,
        action_log_std: Tensor,
        num_candidates: int | None = None,
        proposal_candidates: Tensor | None = None,
    ) -> Tensor:
        """Return ``proposal_candidates`` if given, else Gaussian samples.

        Sampled output is ``(batch, num_candidates, chunk, action)`` with slot
        0 equal to ``action_mean``; noise is random in both train and eval.
        """
        if proposal_candidates is not None:
            return proposal_candidates
        num_candidates = num_candidates or self.config.num_candidates
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        noise = _random_candidate_noise(action_mean, num_candidates)
        return _perturb_mean(action_mean, action_log_std, noise)


DEFAULT_PROPOSAL_MODES = (
    "widen_opening",
    "maintain_opening",
    "slide_occluder",
    "lift_support_layer",
    "stabilize_support",
    "retrieve",
)

TASK_PROPOSAL_MODES = {
    "foliage": (
        "sweep_left",
        "sweep_right",
        "pin_canopy",
        "widen_gap",
        "maintain_gap",
        "insert_actor",
        "retrieve",
    ),
    "bag": (
        "pin_left_rim",
        "pin_right_rim",
        "widen_mouth",
        "maintain_mouth",
        "probe_inside",
        "insert_actor",
        "retrieve",
    ),
    "cloth": (
        "lift_edge",
        "separate_layer",
        "stabilize_fold",
        "maintain_lift",
        "insert_actor",
        "retrieve",
    ),
}

TASK_INDEX = {"foliage": 0, "bag": 1, "cloth": 2}

# Keyword table for ``infer_task_name_from_text``; checked in this order.
_TASK_KEYWORDS = (
    ("foliage", ("foliage", "canopy", "leaf", "leaves", "snail")),
    ("bag", ("bag", "mouth", "rim", "aperture")),
    ("cloth", ("cloth", "fold", "layer", "suitcase", "garment")),
)


def infer_task_name_from_text(text: str | None) -> str:
    """Map free text to ``"foliage"``, ``"bag"``, ``"cloth"`` or ``"generic"``.

    Matching is a case-insensitive *substring* test, so keywords also match
    inside longer words (e.g. ``"rim"`` inside ``"primary"``). The first task
    whose keyword matches wins; empty or ``None`` text yields ``"generic"``.
    """
    if not text:
        return "generic"
    lowered = text.lower()
    for task_name, keywords in _TASK_KEYWORDS:
        if any(keyword in lowered for keyword in keywords):
            return task_name
    return "generic"


def proposal_mode_vocab(task_name: str, num_modes: int) -> tuple[str, ...]:
    """Return exactly ``num_modes`` proposal-mode names for ``task_name``.

    ``"generic"`` uses ``DEFAULT_PROPOSAL_MODES``. A task vocabulary longer
    than ``num_modes`` is reduced: with ``num_modes >= 6`` the first four and
    last two names are kept, otherwise the leading ``num_modes`` names. A
    shorter vocabulary is padded by repeating its last name.

    Raises:
        KeyError: if ``task_name`` is neither ``"generic"`` nor a key of
            ``TASK_PROPOSAL_MODES``.
    """
    if task_name == "generic":
        base_vocab = tuple(DEFAULT_PROPOSAL_MODES)
    else:
        vocab = TASK_PROPOSAL_MODES[task_name]
        if len(vocab) <= num_modes:
            base_vocab = vocab
        elif num_modes >= 6:
            # Keep the four leading modes plus the two trailing ones.
            base_vocab = (*vocab[:4], vocab[-2], vocab[-1])[:num_modes]
        else:
            base_vocab = vocab[:num_modes]

    if len(base_vocab) >= num_modes:
        return tuple(base_vocab[:num_modes])
    if not base_vocab:
        return tuple("retrieve" for _ in range(num_modes))
    padding = [base_vocab[-1]] * (num_modes - len(base_vocab))
    return tuple([*base_vocab, *padding])


def swap_arm_action_order(action_chunk: Tensor) -> Tensor:
    """Swap the two halves of the last dimension of ``action_chunk``.

    The split point is ``shape[-1] // 2``, so for an odd width the second
    (longer) half is moved to the front.
    """
    midpoint = action_chunk.shape[-1] // 2
    first_half = action_chunk[..., :midpoint]
    second_half = action_chunk[..., midpoint:]
    return torch.cat([second_half, first_half], dim=-1)


class SymmetricCoordinatedChunkDecoder(nn.Module):
    """Chunk decoder that runs both arms through one shared decoder stack.

    The arms differ only by additive identity/role biases on their queries.
    Proposal candidates come from per-mode residual heads plus small per-slot
    deltas around the base action.
    """

    def __init__(self, config: ChunkDecoderConfig) -> None:
        super().__init__()
        # Both arms use the same ``arm_action_dim``-wide head, and the proposal
        # heads reshape to ``action_dim``; any other ratio makes every forward
        # fail with a shape error, so reject it here.
        if config.action_dim != 2 * config.arm_action_dim:
            raise ValueError(
                "SymmetricCoordinatedChunkDecoder requires action_dim == 2 * arm_action_dim, "
                f"got action_dim={config.action_dim}, arm_action_dim={config.arm_action_dim}"
            )
        self.config = config
        hidden = config.hidden_dim
        flat_chunk_dim = config.chunk_size * config.action_dim
        proposal_context_dim = config.action_dim + (hidden * 2)

        self.arm_decoder = _build_transformer_decoder(config)
        self.query_embed = nn.Embedding(config.chunk_size, hidden)
        self.arm_identity = nn.Embedding(2, hidden)
        self.task_embedding = nn.Embedding(len(TASK_INDEX), hidden)
        self.phase_adapter = nn.Linear(config.num_phases, hidden)
        self.role_adapter = nn.Linear(config.num_arm_roles, hidden)
        self.context_proj = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, hidden),
            nn.GELU(),
        )
        self.coordination = nn.Sequential(
            nn.LayerNorm(hidden * 3),
            nn.Linear(hidden * 3, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
        )
        self.arm_head = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, hidden),
            nn.GELU(),
        )
        self.arm_mean = nn.Linear(hidden, config.arm_action_dim)
        self.arm_log_std = nn.Linear(hidden, config.arm_action_dim)
        self.proposal_mode_head = nn.Sequential(
            nn.LayerNorm(proposal_context_dim),
            nn.Linear(proposal_context_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, config.num_proposal_modes),
        )
        self.proposal_mode_embeddings = nn.Embedding(config.num_proposal_modes, hidden)
        self.proposal_slot_embeddings = nn.Embedding(config.num_candidates, hidden)
        self.mode_residual_heads = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(proposal_context_dim),
                    nn.Linear(proposal_context_dim, hidden),
                    nn.GELU(),
                    nn.Linear(hidden, flat_chunk_dim),
                )
                for _ in range(config.num_proposal_modes)
            ]
        )
        self.slot_delta = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, flat_chunk_dim),
        )
        self.proposal_score = nn.Sequential(
            nn.LayerNorm(proposal_context_dim + hidden),
            nn.Linear(proposal_context_dim + hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
        )

    def _conditioning(
        self,
        interaction_state: dict[str, Tensor] | None,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
        swap_roles: bool = False,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Return ``(phase_context, role_context, interaction_context)``.

        With no state all three are zeros. ``swap_roles`` flips the arm axis
        (dim 1) of the role probabilities. The interaction context is the
        projected mean of ``interaction_tokens`` or, failing that, of
        ``field_tokens``.

        Raises:
            KeyError: if a required key is missing from ``interaction_state``.
        """
        if interaction_state is None:
            hidden = self.config.hidden_dim
            zero_phase = torch.zeros(batch_size, hidden, device=device, dtype=dtype)
            zero_roles = torch.zeros(batch_size, 2, hidden, device=device, dtype=dtype)
            zero_context = torch.zeros(batch_size, hidden, device=device, dtype=dtype)
            return zero_phase, zero_roles, zero_context

        phase_probs = interaction_state["phase_logits"].softmax(dim=-1).to(dtype=dtype)
        arm_role_probs = interaction_state["arm_role_logits"].softmax(dim=-1).to(dtype=dtype)
        if swap_roles:
            arm_role_probs = arm_role_probs.flip(1)
        phase_context = self.phase_adapter(phase_probs)
        role_context = self.role_adapter(arm_role_probs)

        if interaction_state.get("interaction_tokens") is not None:
            context_tokens = interaction_state["interaction_tokens"]
        else:
            context_tokens = interaction_state["field_tokens"]
        return phase_context, role_context, self.context_proj(context_tokens.mean(dim=1))

    def _decode_arm_tokens(
        self,
        queries: Tensor,
        decoder_memory: Tensor,
        phase_context: Tensor,
        role_context: Tensor,
        interaction_context: Tensor,
        swap_roles: bool = False,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Decode both arms with the shared decoder.

        Returns:
            ``(arm_mean, arm_log_std, coordination)`` where the first two are
            ``(batch, 2, chunk, arm_action_dim)`` and ``coordination`` is
            ``(batch, chunk, hidden)``. ``swap_roles`` exchanges which identity
            embedding each arm slot receives.
        """
        batch_size, chunk_size, _ = queries.shape
        hidden = self.config.hidden_dim
        identity_order = torch.tensor([1, 0] if swap_roles else [0, 1], device=queries.device)

        arm_queries = queries.unsqueeze(1).expand(-1, 2, -1, -1)
        arm_queries = arm_queries + phase_context.unsqueeze(1).unsqueeze(2)
        arm_queries = arm_queries + role_context.unsqueeze(2)
        arm_queries = arm_queries + self.arm_identity(identity_order).view(1, 2, 1, -1).to(
            dtype=queries.dtype
        )

        # Fold the arm axis into the batch so one decoder pass serves both arms.
        flat_queries = arm_queries.reshape(batch_size * 2, chunk_size, hidden)
        flat_memory = decoder_memory.unsqueeze(1).expand(-1, 2, -1, -1).reshape(
            batch_size * 2,
            decoder_memory.shape[1],
            decoder_memory.shape[2],
        )
        decoded = self.arm_decoder(flat_queries, flat_memory).reshape(
            batch_size, 2, chunk_size, hidden
        )

        coordination_input = torch.cat(
            [
                decoded[:, 0],
                decoded[:, 1],
                interaction_context.unsqueeze(1).expand(-1, chunk_size, -1),
            ],
            dim=-1,
        )
        coordination = torch.tanh(self.coordination(coordination_input))
        # Out-of-place broadcast add: same values as writing each arm slice,
        # without mutating the decoder output that autograd is tracking.
        decoded = decoded + coordination.unsqueeze(1)
        decoded = self.arm_head(decoded)
        arm_mean = self.arm_mean(decoded)
        arm_log_std = self.arm_log_std(decoded).clamp(min=-5.0, max=2.0)
        return arm_mean, arm_log_std, coordination

    def _proposal_outputs(
        self,
        base_action: Tensor,
        pooled_context: Tensor,
        task_names: list[str],
    ) -> tuple[Tensor, Tensor, Tensor, list[list[str]]]:
        """Build proposal candidates, their logits and human-readable mode names.

        Slot ``k`` uses mode ``k % num_proposal_modes``. Each candidate is the
        base action plus a bounded mode residual (scale 0.35) and a bounded
        slot delta (scale 0.05); slot 0 is reset to the base action.

        Returns:
            ``(candidates, logits, mode_logits, mode_names)`` with candidates
            ``(batch, slots, chunk, action)``, logits ``(batch, slots)``,
            mode_logits ``(batch, modes)`` and one name list per task name.
        """
        batch_size = pooled_context.shape[0]
        num_slots = self.config.num_candidates
        num_modes = self.config.num_proposal_modes
        chunk_size = self.config.chunk_size
        action_dim = self.config.action_dim

        # Plain Python ints: indexing a device tensor here would force a
        # host/device sync for every slot (and every task name).
        slot_modes = [slot_idx % num_modes for slot_idx in range(num_slots)]

        mode_logits = self.proposal_mode_head(pooled_context)
        mode_residuals = torch.stack(
            [
                head(pooled_context).view(batch_size, chunk_size, action_dim)
                for head in self.mode_residual_heads
            ],
            dim=1,
        )
        slot_embeddings = self.proposal_slot_embeddings.weight
        slot_deltas = self.slot_delta(slot_embeddings).view(num_slots, chunk_size, action_dim)

        # Resolve each distinct task's vocabulary once instead of once per slot.
        vocab_by_task = {
            task_name: proposal_mode_vocab(task_name, num_modes)
            for task_name in set(task_names)
        }
        proposal_mode_names = [
            [vocab_by_task[task_name][mode_idx] for mode_idx in slot_modes]
            for task_name in task_names
        ]

        proposal_candidates = []
        proposal_logits = []
        for slot_idx, mode_idx in enumerate(slot_modes):
            candidate = (
                base_action
                + 0.35 * torch.tanh(mode_residuals[:, mode_idx])
                + 0.05 * torch.tanh(slot_deltas[slot_idx]).unsqueeze(0)
            )
            proposal_candidates.append(candidate)
            score_features = torch.cat(
                [
                    pooled_context,
                    self.proposal_mode_embeddings.weight[mode_idx].unsqueeze(0).expand(batch_size, -1)
                    + slot_embeddings[slot_idx].unsqueeze(0).expand(batch_size, -1),
                ],
                dim=-1,
            )
            proposal_logits.append(
                self.proposal_score(score_features).squeeze(-1) + mode_logits[:, mode_idx]
            )

        stacked_candidates = torch.stack(proposal_candidates, dim=1)
        stacked_logits = torch.stack(proposal_logits, dim=1)
        stacked_candidates[:, 0] = base_action
        return stacked_candidates, stacked_logits, mode_logits, proposal_mode_names

    def forward(
        self,
        scene_tokens: Tensor,
        interaction_state: dict[str, Tensor] | None = None,
        memory_tokens: Tensor | None = None,
        reveal_tokens: Tensor | None = None,
        memory_token: Tensor | None = None,
        compute_equivariance_probe: bool = False,
        task_names: list[str] | None = None,
    ) -> dict[str, Tensor]:
        """Decode a chunk, its proposals and (optionally) an equivariance probe.

        Args:
            scene_tokens: Decoder memory, ``(batch, tokens, hidden_dim)``.
            interaction_state: Optional conditioning dict (see ``_conditioning``).
            memory_tokens: Optional extra memory tokens (dim-1 concat).
            reveal_tokens: Extra memory used only when the state provides no
                ``interaction_tokens``.
            memory_token: Alias used when ``memory_tokens`` is ``None``.
            compute_equivariance_probe: Also decode with arm roles/identities
                swapped and return probe/target means.
            task_names: Free-text task descriptions, one per sample; each is
                canonicalised with ``infer_task_name_from_text``. Defaults to
                ``"generic"`` for every sample.

        Returns:
            Output dict. Note ``proposal_mode_names`` and
            ``proposal_task_names`` are Python lists, not tensors, and the
            ``*_tokens`` entries hold per-arm action means.
        """
        if memory_tokens is None:
            memory_tokens = memory_token
        batch_size = scene_tokens.shape[0]
        device = scene_tokens.device
        dtype = scene_tokens.dtype

        phase_context, role_context, interaction_context = self._conditioning(
            interaction_state=interaction_state,
            batch_size=batch_size,
            device=device,
            dtype=dtype,
        )

        interaction_tokens = (
            interaction_state.get("interaction_tokens") if interaction_state is not None else None
        )
        extra_memory = interaction_tokens if interaction_tokens is not None else reveal_tokens
        decoder_memory = _extend_memory(scene_tokens, extra_memory)
        decoder_memory = _extend_memory(decoder_memory, memory_tokens)

        canonical_task_names = [
            infer_task_name_from_text(name) for name in (task_names or ["generic"] * batch_size)
        ]
        # One id per entry; names outside TASK_INDEX (e.g. "generic") share
        # index 0. A single pass keeps ids aligned with their samples.
        task_ids = torch.as_tensor(
            [TASK_INDEX.get(name, 0) for name in canonical_task_names],
            device=device,
            dtype=torch.long,
        )
        interaction_context = interaction_context + self.task_embedding(task_ids)

        base_queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        arm_mean, arm_log_std, coordination = self._decode_arm_tokens(
            queries=base_queries,
            decoder_memory=decoder_memory,
            phase_context=phase_context,
            role_context=role_context,
            interaction_context=interaction_context,
        )
        right_mean, left_mean = arm_mean[:, 0], arm_mean[:, 1]
        action_mean = torch.cat([right_mean, left_mean], dim=-1)
        action_log_std = torch.cat([arm_log_std[:, 0], arm_log_std[:, 1]], dim=-1)
        pooled_context = torch.cat(
            [
                right_mean.mean(dim=1),
                left_mean.mean(dim=1),
                coordination.mean(dim=1),
                interaction_context,
            ],
            dim=-1,
        )
        (
            proposal_candidates,
            proposal_logits,
            proposal_mode_logits,
            proposal_mode_names,
        ) = self._proposal_outputs(action_mean, pooled_context, canonical_task_names)

        outputs = {
            "decoded_tokens": torch.cat([right_mean, left_mean], dim=-1),
            "right_tokens": right_mean,
            "left_tokens": left_mean,
            "coordination_tokens": coordination,
            "action_mean": action_mean,
            "action_log_std": action_log_std,
            "proposal_candidates": proposal_candidates,
            "proposal_logits": proposal_logits,
            "proposal_mode_logits": proposal_mode_logits,
            "proposal_mode_assignments": torch.arange(
                self.config.num_candidates,
                device=device,
            )
            % self.config.num_proposal_modes,
            "proposal_mode_names": proposal_mode_names,
            "proposal_task_names": canonical_task_names,
        }

        if compute_equivariance_probe:
            # The swapped context is recomputed from the state, so it does not
            # include the task embedding added to the main-path context above.
            swapped_phase, swapped_roles, swapped_context = self._conditioning(
                interaction_state=interaction_state,
                batch_size=batch_size,
                device=device,
                dtype=dtype,
                swap_roles=True,
            )
            swapped_arm_mean, _, _ = self._decode_arm_tokens(
                queries=base_queries,
                decoder_memory=decoder_memory,
                phase_context=swapped_phase,
                role_context=swapped_roles,
                interaction_context=swapped_context,
                swap_roles=True,
            )
            outputs["equivariance_probe_action_mean"] = torch.cat(
                [swapped_arm_mean[:, 0], swapped_arm_mean[:, 1]],
                dim=-1,
            )
            outputs["equivariance_target_action_mean"] = swap_arm_action_order(action_mean)
        return outputs

    def sample_candidates(
        self,
        action_mean: Tensor,
        action_log_std: Tensor,
        num_candidates: int | None = None,
        proposal_candidates: Tensor | None = None,
    ) -> Tensor:
        """Return ``proposal_candidates`` if given, else Gaussian samples.

        Sampled output is ``(batch, num_candidates, chunk, action)`` with slot
        0 equal to ``action_mean``; noise is random in both train and eval.
        """
        if proposal_candidates is not None:
            return proposal_candidates
        num_candidates = num_candidates or self.config.num_candidates
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        noise = _random_candidate_noise(action_mean, num_candidates)
        return _perturb_mean(action_mean, action_log_std, noise)