| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
|
|
@dataclass
class ChunkDecoderConfig:
    """Hyperparameters shared by the chunk-decoder variants in this module."""

    hidden_dim: int = 512  # transformer model width (d_model)
    num_heads: int = 8  # attention heads per decoder layer
    num_layers: int = 4  # stacked TransformerDecoder layers per stream
    ff_dim: int = 2048  # feed-forward width inside each decoder layer
    dropout: float = 0.1  # dropout used by all decoder layers
    chunk_size: int = 8  # number of future action steps decoded per chunk
    action_dim: int = 14  # full bimanual action dimension (both arms concatenated)
    arm_action_dim: int = 7  # action dims covered by the first (right/revealer) head
    num_candidates: int = 8  # candidate action chunks produced per forward pass
    num_phases: int = 5  # width of the phase-logits conditioning vector
    num_arm_roles: int = 4  # width of the per-arm role-logits conditioning vector
    num_proposal_modes: int = 6  # discrete proposal modes (see *_PROPOSAL_MODES below)
    planner_top_k: int = 4  # not read in this file; presumably consumed by a planner — TODO confirm
|
|
|
|
class ACTBimanualChunkDecoder(nn.Module):
    """ACT-style transformer decoder emitting one chunk of bimanual actions.

    Two decoder streams share one set of learned chunk queries: a
    "revealer" stream attends to the scene memory (plus optional reveal
    and memory tokens), and an "actor" stream additionally attends to the
    revealer's output tokens. A coordination MLP fuses both streams into
    a shared correction before separate linear heads produce Gaussian
    action parameters (mean / log-std) for each arm's action slice.
    """

    def __init__(self, config: ChunkDecoderConfig) -> None:
        super().__init__()
        self.config = config
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.revealer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers)
        # A second layer template for the actor stack (nn.TransformerDecoder
        # deep-copies its template, so the two stacks never share weights).
        actor_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.actor_decoder = nn.TransformerDecoder(actor_layer, num_layers=config.num_layers)
        # One learned query per chunk step, shared by both decoder streams.
        self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim)
        # Learned additive bias that differentiates actor queries from revealer queries.
        self.actor_role_bias = nn.Parameter(torch.zeros(1, config.chunk_size, config.hidden_dim))
        # Gaussian heads: the revealer covers the first arm_action_dim action
        # dims, the actor covers the remaining action_dim - arm_action_dim dims.
        self.revealer_mean = nn.Linear(config.hidden_dim, config.arm_action_dim)
        self.revealer_log_std = nn.Linear(config.hidden_dim, config.arm_action_dim)
        self.actor_mean = nn.Linear(config.hidden_dim, config.action_dim - config.arm_action_dim)
        self.actor_log_std = nn.Linear(config.hidden_dim, config.action_dim - config.arm_action_dim)
        # Fuses (revealer, actor, reveal-context) token triples into a shared
        # bounded correction added to both streams.
        self.coordination = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Scalar quality score for the pooled decoded chunk.
        self.proposal_score = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, 1),
        )

    def _deterministic_candidate_noise(
        self,
        action_mean: Tensor,
        num_candidates: int,
    ) -> Tensor:
        """Build reproducible, input-independent candidate noise for eval.

        Returns a (batch, num_candidates, chunk_size, action_dim) tensor.
        Candidate 0 stays all-zero; the remaining candidates carry fixed
        sinusoidal patterns normalized to unit RMS per candidate, so
        repeated eval runs produce identical candidate sets.
        """
        batch_size, chunk_size, action_dim = action_mean.shape
        noise = torch.zeros(
            batch_size,
            num_candidates,
            chunk_size,
            action_dim,
            device=action_mean.device,
            dtype=action_mean.dtype,
        )
        if num_candidates <= 1:
            return noise

        candidate_index = torch.arange(1, num_candidates, device=action_mean.device, dtype=action_mean.dtype).view(
            num_candidates - 1, 1, 1
        )
        step_index = torch.arange(chunk_size, device=action_mean.device, dtype=action_mean.dtype).view(1, chunk_size, 1)
        dim_index = torch.arange(action_dim, device=action_mean.device, dtype=action_mean.dtype).view(1, 1, action_dim)

        # Fixed, mutually-incommensurate frequencies decorrelate the patterns
        # across candidates, chunk steps, and action dimensions.
        base = torch.sin(candidate_index * 0.73 + step_index * 0.37 + dim_index * 0.19)
        base = base + torch.cos(candidate_index * 1.11 + step_index * 0.17 + dim_index * 0.41)
        # Normalize each candidate's pattern to unit RMS (clamped for stability).
        base = base / base.square().mean(dim=(1, 2), keepdim=True).sqrt().clamp_min(1e-6)
        noise[:, 1:] = base.unsqueeze(0).expand(batch_size, -1, -1, -1)
        return noise

    def forward(
        self,
        scene_tokens: Tensor,
        reveal_tokens: Tensor | None = None,
        memory_token: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode one action chunk from scene (and optional extra) tokens.

        Args:
            scene_tokens: (batch, num_tokens, hidden_dim) scene encoding.
                Assumes last dim equals config.hidden_dim — TODO confirm.
            reveal_tokens: optional tokens appended to the decoder memory
                and mean-pooled as coordination context.
            memory_token: optional token(s) appended after reveal tokens.

        Returns:
            Dict with the decoded token streams, Gaussian action parameters
            ("action_mean" / "action_log_std", each (batch, chunk_size,
            action_dim)), and a per-sample scalar "proposal_score".
        """
        batch_size = scene_tokens.shape[0]
        # Shared learned chunk queries, broadcast over the batch.
        query = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        decoder_memory = scene_tokens
        if reveal_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, reveal_tokens], dim=1)
        if memory_token is not None:
            decoder_memory = torch.cat([decoder_memory, memory_token], dim=1)

        revealer_tokens = self.revealer_decoder(query, decoder_memory)
        actor_query = query + self.actor_role_bias
        # The actor additionally attends to the revealer's output tokens.
        actor_tokens = self.actor_decoder(actor_query, torch.cat([decoder_memory, revealer_tokens], dim=1))
        # Coordination context: mean-pooled reveal tokens when available,
        # otherwise mean-pooled scene tokens, repeated per chunk step.
        if reveal_tokens is not None:
            reveal_context = reveal_tokens.mean(dim=1, keepdim=True).expand(-1, self.config.chunk_size, -1)
        else:
            reveal_context = scene_tokens.mean(dim=1, keepdim=True).expand(-1, self.config.chunk_size, -1)
        coordination_input = torch.cat([revealer_tokens, actor_tokens, reveal_context], dim=-1)
        coordination = torch.tanh(self.coordination(coordination_input))
        # The same bounded correction is added to both streams.
        revealer_tokens = revealer_tokens + coordination
        actor_tokens = actor_tokens + coordination
        action_mean = torch.cat([self.revealer_mean(revealer_tokens), self.actor_mean(actor_tokens)], dim=-1)
        action_log_std = torch.cat(
            [
                self.revealer_log_std(revealer_tokens),
                self.actor_log_std(actor_tokens),
            ],
            dim=-1,
        ).clamp(min=-5.0, max=2.0)  # keep std within a numerically safe range
        proposal_features = torch.cat(
            [
                revealer_tokens.mean(dim=1),
                actor_tokens.mean(dim=1),
                coordination.mean(dim=1),
            ],
            dim=-1,
        )
        return {
            "decoded_tokens": torch.cat([revealer_tokens, actor_tokens], dim=-1),
            "revealer_tokens": revealer_tokens,
            "actor_tokens": actor_tokens,
            "coordination_tokens": coordination,
            "action_mean": action_mean,
            "action_log_std": action_log_std,
            "proposal_score": self.proposal_score(proposal_features).squeeze(-1),
        }

    def sample_candidates(self, action_mean: Tensor, action_log_std: Tensor, num_candidates: int | None = None) -> Tensor:
        """Sample candidate action chunks around the predicted mean.

        Training uses Gaussian noise; evaluation uses the deterministic
        sinusoidal patterns so repeated runs yield identical candidates.
        Candidate 0 is always the unperturbed mean.
        """
        num_candidates = num_candidates or self.config.num_candidates
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        std = action_log_std.exp()
        if self.training:
            noise = torch.randn(
                action_mean.size(0),
                num_candidates,
                action_mean.size(1),
                action_mean.size(2),
                device=action_mean.device,
                dtype=action_mean.dtype,
            )
        else:
            noise = self._deterministic_candidate_noise(action_mean, num_candidates)
        candidates = action_mean.unsqueeze(1) + noise * std.unsqueeze(1)
        candidates[:, 0] = action_mean  # pin slot 0 to the mean chunk
        return candidates
|
|
|
|
class InteractionChunkDecoder(nn.Module):
    """Bimanual chunk decoder with explicit right/left arm streams.

    The right arm is decoded first against the scene memory; the left arm
    then attends to the scene memory plus the right-arm tokens, so
    left-arm actions can react to the right arm's plan. Optional
    ``interaction_state`` conditioning (phase / arm-role logits, plus
    interaction tokens) biases the queries. Besides the base chunk, the
    decoder re-runs itself once with ``num_candidates`` learned proposal
    query offsets to produce alternative candidate chunks and logits.
    """

    def __init__(self, config: ChunkDecoderConfig) -> None:
        super().__init__()
        self.config = config
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.right_decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers)
        # Independent layer template for the left-arm stack.
        left_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.left_decoder = nn.TransformerDecoder(left_layer, num_layers=config.num_layers)
        # Learned chunk queries, candidate-slot offsets, and an arm-identity
        # embedding (index 0 = right, index 1 = left, per usage in forward).
        self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim)
        self.proposal_queries = nn.Embedding(config.num_candidates, config.hidden_dim)
        self.arm_identity = nn.Embedding(2, config.hidden_dim)
        # Adapters mapping soft phase / arm-role distributions to hidden_dim biases.
        self.phase_adapter = nn.Linear(config.num_phases, config.hidden_dim)
        self.role_adapter = nn.Linear(config.num_arm_roles, config.hidden_dim)
        self.context_proj = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        # Fuses (right, left, interaction-context) triples into a shared correction.
        self.coordination = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Gaussian heads: right arm covers the first arm_action_dim dims,
        # left arm the remaining action_dim - arm_action_dim dims.
        self.right_mean = nn.Linear(config.hidden_dim, config.arm_action_dim)
        self.right_log_std = nn.Linear(config.hidden_dim, config.arm_action_dim)
        self.left_mean = nn.Linear(config.hidden_dim, config.action_dim - config.arm_action_dim)
        self.left_log_std = nn.Linear(config.hidden_dim, config.action_dim - config.arm_action_dim)
        self.proposal_score = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, 1),
        )

    def _conditioning(
        self,
        interaction_state: dict[str, Tensor] | None,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        """Convert the interaction state into query-conditioning tensors.

        Expects keys "phase_logits" and "arm_role_logits" (softmaxed here),
        and optionally "interaction_tokens". With no state, returns zero
        phase/role contexts and ``None`` for the tokens.
        """
        if interaction_state is None:
            zero_phase = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
            zero_roles = torch.zeros(batch_size, 2, self.config.hidden_dim, device=device, dtype=dtype)
            return zero_phase, zero_roles, None
        phase_probs = interaction_state["phase_logits"].softmax(dim=-1).to(dtype=dtype)
        arm_role_probs = interaction_state["arm_role_logits"].softmax(dim=-1).to(dtype=dtype)
        phase_context = self.phase_adapter(phase_probs)
        role_context = self.role_adapter(arm_role_probs)
        return phase_context, role_context, interaction_state.get("interaction_tokens")

    def _decode_from_queries(
        self,
        queries: Tensor,
        decoder_memory: Tensor,
        phase_context: Tensor,
        role_context: Tensor,
        interaction_context: Tensor,
    ) -> dict[str, Tensor]:
        """Run the right-then-left decode for one batch of query chunks.

        role_context is assumed to be (batch, 2, hidden_dim) with slot 0
        conditioning the right arm and slot 1 the left arm.
        """
        phase_bias = phase_context.unsqueeze(1)
        right_queries = (
            queries
            + phase_bias
            + role_context[:, 0].unsqueeze(1)
            + self.arm_identity.weight[0].view(1, 1, -1).to(dtype=queries.dtype)
        )
        left_queries = (
            queries
            + phase_bias
            + role_context[:, 1].unsqueeze(1)
            + self.arm_identity.weight[1].view(1, 1, -1).to(dtype=queries.dtype)
        )
        right_tokens = self.right_decoder(right_queries, decoder_memory)
        # The left decoder also attends to the freshly decoded right-arm tokens.
        left_tokens = self.left_decoder(left_queries, torch.cat([decoder_memory, right_tokens], dim=1))
        context = interaction_context.unsqueeze(1).expand(-1, queries.shape[1], -1)
        coordination_input = torch.cat([right_tokens, left_tokens, context], dim=-1)
        coordination = torch.tanh(self.coordination(coordination_input))
        # Shared bounded correction added to both arm streams.
        right_tokens = right_tokens + coordination
        left_tokens = left_tokens + coordination
        action_mean = torch.cat([self.right_mean(right_tokens), self.left_mean(left_tokens)], dim=-1)
        action_log_std = torch.cat(
            [self.right_log_std(right_tokens), self.left_log_std(left_tokens)],
            dim=-1,
        ).clamp(min=-5.0, max=2.0)  # numerically safe std range
        pooled_features = torch.cat(
            [right_tokens.mean(dim=1), left_tokens.mean(dim=1), coordination.mean(dim=1)],
            dim=-1,
        )
        return {
            "right_tokens": right_tokens,
            "left_tokens": left_tokens,
            "coordination_tokens": coordination,
            "decoded_tokens": torch.cat([right_tokens, left_tokens], dim=-1),
            "action_mean": action_mean,
            "action_log_std": action_log_std,
            "proposal_score": self.proposal_score(pooled_features).squeeze(-1),
        }

    def forward(
        self,
        scene_tokens: Tensor,
        interaction_state: dict[str, Tensor] | None = None,
        memory_tokens: Tensor | None = None,
        reveal_tokens: Tensor | None = None,
        memory_token: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode the base chunk plus ``num_candidates`` proposal chunks.

        ``memory_token`` is a backward-compatible alias used only when
        ``memory_tokens`` is not given. ``reveal_tokens`` is used only
        when no interaction tokens are present in ``interaction_state``.
        """
        if memory_tokens is None:
            memory_tokens = memory_token
        batch_size = scene_tokens.shape[0]
        dtype = scene_tokens.dtype
        phase_context, role_context, interaction_tokens = self._conditioning(
            interaction_state=interaction_state,
            batch_size=batch_size,
            device=scene_tokens.device,
            dtype=dtype,
        )

        # Decoder memory: scene tokens, then interaction OR reveal tokens
        # (interaction takes precedence), then any memory tokens.
        decoder_memory = scene_tokens
        if interaction_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, interaction_tokens], dim=1)
        elif reveal_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, reveal_tokens], dim=1)
        if memory_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, memory_tokens], dim=1)

        # Interaction context: mean-pool the richest available token source.
        if interaction_tokens is not None and interaction_tokens.numel() > 0:
            interaction_context = interaction_tokens.mean(dim=1)
        elif reveal_tokens is not None and reveal_tokens.numel() > 0:
            interaction_context = reveal_tokens.mean(dim=1)
        else:
            interaction_context = scene_tokens.mean(dim=1)
        interaction_context = self.context_proj(interaction_context)

        base_queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        decoded = self._decode_from_queries(
            queries=base_queries,
            decoder_memory=decoder_memory,
            phase_context=phase_context,
            role_context=role_context,
            interaction_context=interaction_context,
        )

        # Candidate pass: offset the base queries by one learned embedding per
        # proposal slot, then decode all slots in a single flattened batch of
        # size (batch * num_candidates).
        num_candidates = self.config.num_candidates
        proposal_bias = self.proposal_queries.weight.view(1, num_candidates, 1, -1).expand(
            batch_size, -1, self.config.chunk_size, -1
        )
        candidate_queries = base_queries.unsqueeze(1) + proposal_bias
        flat_queries = candidate_queries.reshape(batch_size * num_candidates, self.config.chunk_size, self.config.hidden_dim)
        flat_memory = decoder_memory.unsqueeze(1).expand(-1, num_candidates, -1, -1).reshape(
            batch_size * num_candidates, decoder_memory.shape[1], decoder_memory.shape[2]
        )
        flat_phase = phase_context.unsqueeze(1).expand(-1, num_candidates, -1).reshape(
            batch_size * num_candidates, self.config.hidden_dim
        )
        flat_roles = role_context.unsqueeze(1).expand(-1, num_candidates, -1, -1).reshape(
            batch_size * num_candidates, 2, self.config.hidden_dim
        )
        flat_context = interaction_context.unsqueeze(1).expand(-1, num_candidates, -1).reshape(
            batch_size * num_candidates, self.config.hidden_dim
        )
        candidate_decoded = self._decode_from_queries(
            queries=flat_queries,
            decoder_memory=flat_memory,
            phase_context=flat_phase,
            role_context=flat_roles,
            interaction_context=flat_context,
        )

        # Candidates are bounded residuals around the base action mean;
        # slot 0 is pinned to the base chunk and its score.
        proposal_deltas = candidate_decoded["action_mean"].view(
            batch_size,
            num_candidates,
            self.config.chunk_size,
            self.config.action_dim,
        )
        proposal_logits = candidate_decoded["proposal_score"].view(batch_size, num_candidates)
        proposal_candidates = decoded["action_mean"].unsqueeze(1) + 0.35 * torch.tanh(proposal_deltas)
        proposal_candidates[:, 0] = decoded["action_mean"]
        proposal_logits[:, 0] = decoded["proposal_score"]
        decoded["proposal_candidates"] = proposal_candidates
        decoded["proposal_logits"] = proposal_logits
        return decoded

    def sample_candidates(
        self,
        action_mean: Tensor,
        action_log_std: Tensor,
        num_candidates: int | None = None,
        proposal_candidates: Tensor | None = None,
    ) -> Tensor:
        """Return proposal candidates if given, else Gaussian samples.

        Unlike ACTBimanualChunkDecoder, the Gaussian fallback is used in
        both train and eval modes. Candidate 0 is always the mean.
        """
        if proposal_candidates is not None:
            return proposal_candidates
        num_candidates = num_candidates or self.config.num_candidates
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        noise = torch.randn(
            action_mean.size(0),
            num_candidates,
            action_mean.size(1),
            action_mean.size(2),
            device=action_mean.device,
            dtype=action_mean.dtype,
        )
        candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
        candidates[:, 0] = action_mean
        return candidates
|
|
|
|
# Generic proposal-mode vocabulary used when no task family is recognized.
DEFAULT_PROPOSAL_MODES = (
    "widen_opening",
    "maintain_opening",
    "slide_occluder",
    "lift_support_layer",
    "stabilize_support",
    "retrieve",
)


# Task-specific proposal vocabularies, keyed by the canonical task names
# returned by infer_task_name_from_text(). Each ends with the terminal
# insert/retrieve modes (relied upon by proposal_mode_vocab's trimming).
TASK_PROPOSAL_MODES = {
    "foliage": (
        "sweep_left",
        "sweep_right",
        "pin_canopy",
        "widen_gap",
        "maintain_gap",
        "insert_actor",
        "retrieve",
    ),
    "bag": (
        "pin_left_rim",
        "pin_right_rim",
        "widen_mouth",
        "maintain_mouth",
        "probe_inside",
        "insert_actor",
        "retrieve",
    ),
    "cloth": (
        "lift_edge",
        "separate_layer",
        "stabilize_fold",
        "maintain_lift",
        "insert_actor",
        "retrieve",
    ),
}


# Stable integer ids for the known task families (used for task embeddings).
TASK_INDEX = {"foliage": 0, "bag": 1, "cloth": 2}
|
|
|
|
def infer_task_name_from_text(text: str | None) -> str:
    """Map a free-form task description to a canonical task family name.

    Performs case-insensitive substring matching; families are tested in
    priority order (foliage, bag, cloth) and "generic" is returned for
    empty/None input or when no keyword matches.
    """
    if not text:
        return "generic"
    lowered = text.lower()
    keyword_table = (
        ("foliage", ("foliage", "canopy", "leaf", "leaves", "snail")),
        ("bag", ("bag", "mouth", "rim", "aperture")),
        ("cloth", ("cloth", "fold", "layer", "suitcase", "garment")),
    )
    for family, keywords in keyword_table:
        if any(keyword in lowered for keyword in keywords):
            return family
    return "generic"
|
|
|
|
def proposal_mode_vocab(task_name: str, num_modes: int) -> tuple[str, ...]:
    """Return a vocabulary of exactly ``num_modes`` proposal-mode names.

    Known task families use their entry in TASK_PROPOSAL_MODES; "generic"
    — and, as a robustness fix, any unrecognized task name — falls back
    to DEFAULT_PROPOSAL_MODES instead of raising KeyError. A too-long
    vocabulary is trimmed (keeping the first four modes plus the final
    two terminal modes when at least six are requested); a too-short one
    is padded by repeating its last mode.
    """
    if task_name == "generic":
        base_vocab: tuple[str, ...] = tuple(DEFAULT_PROPOSAL_MODES)
    else:
        # .get() so an unknown task name degrades to the generic
        # vocabulary rather than crashing the caller with KeyError.
        vocab = TASK_PROPOSAL_MODES.get(task_name, DEFAULT_PROPOSAL_MODES)
        if len(vocab) > num_modes:
            if num_modes >= 6:
                # Keep the four task-setup modes plus the two terminal
                # (insert/retrieve) modes when trimming a longer vocab.
                base_vocab = (
                    vocab[0],
                    vocab[1],
                    vocab[2],
                    vocab[3],
                    vocab[-2],
                    vocab[-1],
                )[:num_modes]
            else:
                base_vocab = vocab[:num_modes]
        else:
            base_vocab = vocab
    if len(base_vocab) >= num_modes:
        return tuple(base_vocab[:num_modes])
    if not base_vocab:
        # Degenerate guard: an empty vocabulary repeats the safe terminal mode.
        return tuple("retrieve" for _ in range(num_modes))
    padded = list(base_vocab)
    while len(padded) < num_modes:
        padded.append(base_vocab[-1])
    return tuple(padded)
|
|
|
|
def swap_arm_action_order(action_chunk: Tensor) -> Tensor:
    """Swap the two arm halves along the trailing action dimension.

    The split point is ``last_dim // 2``; for an odd last dimension the
    larger back half moves to the front (matching floor division).
    """
    half = action_chunk.shape[-1] // 2
    front_half = action_chunk[..., :half]
    back_half = action_chunk[..., half:]
    return torch.cat((back_half, front_half), dim=-1)
|
|
|
|
class SymmetricCoordinatedChunkDecoder(nn.Module):
    """Weight-shared bimanual decoder with mode-structured proposals.

    Both arms are decoded by a single shared transformer decoder,
    distinguished only by arm-identity embeddings and role conditioning,
    which supports a left/right equivariance probe (decode with swapped
    roles and compare against the swapped base actions). Proposal
    candidates are synthesized as bounded residuals from per-mode heads
    plus per-slot deltas, with proposal modes assigned round-robin over
    the candidate slots.
    """

    def __init__(self, config: ChunkDecoderConfig) -> None:
        super().__init__()
        self.config = config
        # Pooled context fed to the proposal heads: action_dim (pooled arm
        # means use hidden_dim each... see forward) — composed of two pooled
        # arm features plus coordination plus interaction context.
        proposal_context_dim = config.action_dim + (config.hidden_dim * 2)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        # Single decoder shared by both arms (symmetric weights).
        self.arm_decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers)
        self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim)
        self.arm_identity = nn.Embedding(2, config.hidden_dim)
        self.task_embedding = nn.Embedding(len(TASK_INDEX), config.hidden_dim)
        self.phase_adapter = nn.Linear(config.num_phases, config.hidden_dim)
        self.role_adapter = nn.Linear(config.num_arm_roles, config.hidden_dim)
        self.context_proj = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.coordination = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Shared post-coordination head applied to both arms' tokens.
        self.arm_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.arm_mean = nn.Linear(config.hidden_dim, config.arm_action_dim)
        self.arm_log_std = nn.Linear(config.hidden_dim, config.arm_action_dim)
        # Classifies which proposal mode fits the pooled context.
        self.proposal_mode_head = nn.Sequential(
            nn.LayerNorm(proposal_context_dim),
            nn.Linear(proposal_context_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_proposal_modes),
        )
        self.proposal_mode_embeddings = nn.Embedding(config.num_proposal_modes, config.hidden_dim)
        self.proposal_slot_embeddings = nn.Embedding(config.num_candidates, config.hidden_dim)
        # One residual head per proposal mode, each emitting a full
        # (chunk_size, action_dim) action residual from the pooled context.
        self.mode_residual_heads = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(proposal_context_dim),
                    nn.Linear(proposal_context_dim, config.hidden_dim),
                    nn.GELU(),
                    nn.Linear(config.hidden_dim, config.chunk_size * config.action_dim),
                )
                for _ in range(config.num_proposal_modes)
            ]
        )
        # Small per-slot action delta computed from the slot embedding alone.
        self.slot_delta = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.chunk_size * config.action_dim),
        )
        self.proposal_score = nn.Sequential(
            nn.LayerNorm(proposal_context_dim + config.hidden_dim),
            nn.Linear(proposal_context_dim + config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, 1),
        )

    def _conditioning(
        self,
        interaction_state: dict[str, Tensor] | None,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
        swap_roles: bool = False,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Build phase/role/interaction conditioning tensors.

        Expects keys "phase_logits" and "arm_role_logits"; uses
        "interaction_tokens" when present, otherwise "field_tokens"
        (NOTE(review): "field_tokens" is assumed to always exist in that
        case — confirm against the producer of interaction_state).
        With ``swap_roles`` the two arm-role rows are exchanged, which
        drives the equivariance probe in forward().
        """
        if interaction_state is None:
            zero_phase = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
            zero_roles = torch.zeros(batch_size, 2, self.config.hidden_dim, device=device, dtype=dtype)
            zero_context = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
            return zero_phase, zero_roles, zero_context
        phase_probs = interaction_state["phase_logits"].softmax(dim=-1).to(dtype=dtype)
        arm_role_probs = interaction_state["arm_role_logits"].softmax(dim=-1).to(dtype=dtype)
        if swap_roles:
            arm_role_probs = arm_role_probs.flip(1)
        phase_context = self.phase_adapter(phase_probs)
        role_context = self.role_adapter(arm_role_probs)
        if interaction_state.get("interaction_tokens") is not None:
            interaction_context = interaction_state["interaction_tokens"].mean(dim=1)
        else:
            interaction_context = interaction_state["field_tokens"].mean(dim=1)
        return phase_context, role_context, self.context_proj(interaction_context)

    def _decode_arm_tokens(
        self,
        queries: Tensor,
        decoder_memory: Tensor,
        phase_context: Tensor,
        role_context: Tensor,
        interaction_context: Tensor,
        swap_roles: bool = False,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Decode both arms with the shared decoder in one flattened batch.

        Returns (arm_mean, arm_log_std, coordination) where the first two
        are (batch, 2, chunk_size, arm_action_dim). With ``swap_roles``
        the identity embeddings are applied in reversed order so arm 0
        receives the "left" identity.
        """
        batch_size, chunk_size, _ = queries.shape
        identity_order = torch.tensor([1, 0], device=queries.device) if swap_roles else torch.tensor([0, 1], device=queries.device)
        arm_queries = queries.unsqueeze(1).expand(-1, 2, -1, -1)
        arm_queries = arm_queries + phase_context.unsqueeze(1).unsqueeze(2)
        arm_queries = arm_queries + role_context.unsqueeze(2)
        arm_queries = arm_queries + self.arm_identity(identity_order).view(1, 2, 1, -1).to(dtype=queries.dtype)
        # Fold the arm axis into the batch so one decoder call covers both arms.
        flat_queries = arm_queries.reshape(batch_size * 2, chunk_size, self.config.hidden_dim)
        flat_memory = decoder_memory.unsqueeze(1).expand(-1, 2, -1, -1).reshape(
            batch_size * 2,
            decoder_memory.shape[1],
            decoder_memory.shape[2],
        )
        decoded = self.arm_decoder(flat_queries, flat_memory).reshape(batch_size, 2, chunk_size, self.config.hidden_dim)
        coordination_input = torch.cat(
            [
                decoded[:, 0],
                decoded[:, 1],
                interaction_context.unsqueeze(1).expand(-1, chunk_size, -1),
            ],
            dim=-1,
        )
        coordination = torch.tanh(self.coordination(coordination_input))
        # In-place add of the shared correction to both arm slices.
        decoded[:, 0] = decoded[:, 0] + coordination
        decoded[:, 1] = decoded[:, 1] + coordination
        decoded = self.arm_head(decoded)
        arm_mean = self.arm_mean(decoded)
        arm_log_std = self.arm_log_std(decoded).clamp(min=-5.0, max=2.0)  # safe std range
        return arm_mean, arm_log_std, coordination

    def _proposal_outputs(
        self,
        base_action: Tensor,
        pooled_context: Tensor,
        task_names: list[str],
    ) -> tuple[Tensor, Tensor, Tensor, list[list[str]]]:
        """Synthesize candidate action chunks, their logits, and mode names.

        Each candidate slot is assigned a proposal mode round-robin
        (slot_idx % num_proposal_modes); a candidate is the base action
        plus a bounded per-mode residual and a smaller per-slot delta.
        Slot 0 is pinned to the unmodified base action.
        """
        batch_size = pooled_context.shape[0]
        mode_logits = self.proposal_mode_head(pooled_context)
        mode_residuals = []
        for head in self.mode_residual_heads:
            residual = head(pooled_context).view(batch_size, self.config.chunk_size, self.config.action_dim)
            mode_residuals.append(residual)
        mode_residuals = torch.stack(mode_residuals, dim=1)

        mode_assignments = torch.arange(self.config.num_candidates, device=pooled_context.device) % self.config.num_proposal_modes
        slot_embeddings = self.proposal_slot_embeddings.weight
        slot_deltas = self.slot_delta(slot_embeddings).view(
            self.config.num_candidates,
            self.config.chunk_size,
            self.config.action_dim,
        )
        proposal_candidates = []
        proposal_logits = []
        # Human-readable mode names per batch element (task-specific vocab).
        proposal_mode_names = [
            [
                proposal_mode_vocab(task_name, self.config.num_proposal_modes)[int(mode_assignments[slot_idx])]
                for slot_idx in range(self.config.num_candidates)
            ]
            for task_name in task_names
        ]
        for slot_idx in range(self.config.num_candidates):
            mode_idx = int(mode_assignments[slot_idx])
            # Bounded residuals: 0.35 mode-level + 0.05 slot-level perturbation.
            candidate = base_action + 0.35 * torch.tanh(mode_residuals[:, mode_idx]) + 0.05 * torch.tanh(slot_deltas[slot_idx]).unsqueeze(0)
            proposal_candidates.append(candidate)
            score_features = torch.cat(
                [
                    pooled_context,
                    self.proposal_mode_embeddings.weight[mode_idx].unsqueeze(0).expand(batch_size, -1)
                    + slot_embeddings[slot_idx].unsqueeze(0).expand(batch_size, -1),
                ],
                dim=-1,
            )
            # Slot logit = learned score plus the mode classifier's logit.
            proposal_logits.append(
                self.proposal_score(score_features).squeeze(-1) + mode_logits[:, mode_idx]
            )
        stacked_candidates = torch.stack(proposal_candidates, dim=1)
        stacked_logits = torch.stack(proposal_logits, dim=1)
        stacked_candidates[:, 0] = base_action
        return stacked_candidates, stacked_logits, mode_logits, proposal_mode_names

    def forward(
        self,
        scene_tokens: Tensor,
        interaction_state: dict[str, Tensor] | None = None,
        memory_tokens: Tensor | None = None,
        reveal_tokens: Tensor | None = None,
        memory_token: Tensor | None = None,
        compute_equivariance_probe: bool = False,
        task_names: list[str] | None = None,
    ) -> dict[str, Tensor]:
        """Decode the base chunk, proposals, and (optionally) the probe.

        NOTE(review): the return annotation is loose — the dict also
        contains Python lists ("proposal_mode_names",
        "proposal_task_names"). ``memory_token`` is a backward-compatible
        alias used only when ``memory_tokens`` is not given.
        """
        if memory_tokens is None:
            memory_tokens = memory_token
        batch_size = scene_tokens.shape[0]
        dtype = scene_tokens.dtype
        phase_context, role_context, interaction_context = self._conditioning(
            interaction_state=interaction_state,
            batch_size=batch_size,
            device=scene_tokens.device,
            dtype=dtype,
        )

        # Decoder memory: scene, then interaction OR reveal tokens
        # (interaction takes precedence), then memory tokens.
        decoder_memory = scene_tokens
        interaction_tokens = interaction_state.get("interaction_tokens") if interaction_state is not None else None
        if interaction_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, interaction_tokens], dim=1)
        elif reveal_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, reveal_tokens], dim=1)
        if memory_tokens is not None:
            decoder_memory = torch.cat([decoder_memory, memory_tokens], dim=1)

        # Canonicalize task names; first keep only recognized families, and
        # if any name was unrecognized (list too short) rebuild with id 0 as
        # the fallback for unknown/"generic" names.
        canonical_task_names = [infer_task_name_from_text(name) for name in (task_names or ["generic"] * batch_size)]
        task_ids = torch.as_tensor(
            [TASK_INDEX[name] for name in canonical_task_names if name in TASK_INDEX],
            device=scene_tokens.device,
            dtype=torch.long,
        )
        if task_ids.numel() != batch_size:
            task_ids = torch.as_tensor(
                [TASK_INDEX.get(name, 0) for name in canonical_task_names],
                device=scene_tokens.device,
                dtype=torch.long,
            )
        interaction_context = interaction_context + self.task_embedding(task_ids)

        base_queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        arm_mean, arm_log_std, coordination = self._decode_arm_tokens(
            queries=base_queries,
            decoder_memory=decoder_memory,
            phase_context=phase_context,
            role_context=role_context,
            interaction_context=interaction_context,
        )
        # Full action = right-arm slice followed by left-arm slice.
        action_mean = torch.cat([arm_mean[:, 0], arm_mean[:, 1]], dim=-1)
        action_log_std = torch.cat([arm_log_std[:, 0], arm_log_std[:, 1]], dim=-1)
        pooled_context = torch.cat(
            [
                arm_mean[:, 0].mean(dim=1),
                arm_mean[:, 1].mean(dim=1),
                coordination.mean(dim=1),
                interaction_context,
            ],
            dim=-1,
        )
        proposal_candidates, proposal_logits, proposal_mode_logits, proposal_mode_names = self._proposal_outputs(
            action_mean,
            pooled_context,
            canonical_task_names,
        )

        outputs = {
            "decoded_tokens": torch.cat([arm_mean[:, 0], arm_mean[:, 1]], dim=-1),
            "right_tokens": arm_mean[:, 0],
            "left_tokens": arm_mean[:, 1],
            "coordination_tokens": coordination,
            "action_mean": action_mean,
            "action_log_std": action_log_std,
            "proposal_candidates": proposal_candidates,
            "proposal_logits": proposal_logits,
            "proposal_mode_logits": proposal_mode_logits,
            "proposal_mode_assignments": torch.arange(
                self.config.num_candidates,
                device=scene_tokens.device,
            ) % self.config.num_proposal_modes,
            "proposal_mode_names": proposal_mode_names,
            "proposal_task_names": canonical_task_names,
        }
        if compute_equivariance_probe:
            # Re-decode with swapped arm roles/identities; the probe output
            # should match the base actions with the arm halves swapped.
            swapped_phase, swapped_roles, swapped_context = self._conditioning(
                interaction_state=interaction_state,
                batch_size=batch_size,
                device=scene_tokens.device,
                dtype=dtype,
                swap_roles=True,
            )
            swapped_arm_mean, _, _ = self._decode_arm_tokens(
                queries=base_queries,
                decoder_memory=decoder_memory,
                phase_context=swapped_phase,
                role_context=swapped_roles,
                interaction_context=swapped_context,
                swap_roles=True,
            )
            outputs["equivariance_probe_action_mean"] = torch.cat(
                [swapped_arm_mean[:, 0], swapped_arm_mean[:, 1]],
                dim=-1,
            )
            outputs["equivariance_target_action_mean"] = swap_arm_action_order(action_mean)
        return outputs

    def sample_candidates(
        self,
        action_mean: Tensor,
        action_log_std: Tensor,
        num_candidates: int | None = None,
        proposal_candidates: Tensor | None = None,
    ) -> Tensor:
        """Return proposal candidates if given, else Gaussian samples.

        Gaussian fallback is used in both train and eval modes; candidate
        0 is always the unperturbed mean.
        """
        if proposal_candidates is not None:
            return proposal_candidates
        num_candidates = num_candidates or self.config.num_candidates
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        noise = torch.randn(
            action_mean.size(0),
            num_candidates,
            action_mean.size(1),
            action_mean.size(2),
            device=action_mean.device,
            dtype=action_mean.dtype,
        )
        candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
        candidates[:, 0] = action_mean
        return candidates
|
|