| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Sequence |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from models.action_decoder import ( |
| ACTBimanualChunkDecoder, |
| ChunkDecoderConfig, |
| InteractionChunkDecoder, |
| SymmetricCoordinatedChunkDecoder, |
| infer_task_name_from_text, |
| ) |
| from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig |
| from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig |
| from models.observation_memory import ( |
| DualObservationMemory, |
| InteractionObservationMemory, |
| ObservationMemory, |
| ObservationMemoryConfig, |
| ) |
| from models.planner import CascadePlanner, InteractionPlanner, PlannerConfig, RevealPlanner |
| from models.reveal_head import ( |
| ElasticOcclusionStateHead, |
| InteractionStateHead, |
| RevealHeadConfig, |
| RevealStateHead, |
| ) |
| from models.world_model import ElasticOcclusionWorldModel, InteractionRolloutModel, RevealWM, RevealWMConfig |
|
|
|
|
@dataclass
class PolicyConfig:
    """Aggregated configuration for all sub-modules of a bimanual policy.

    Each field is the config dataclass consumed by the matching module
    constructed in the policy classes below.
    """

    # Frozen vision-language backbone (image + text encoders).
    backbone: FrozenVLBackboneConfig = field(default_factory=FrozenVLBackboneConfig)
    # Multi-view fusion of image/proprio/language tokens into scene tokens.
    fusion: MultiViewFusionConfig = field(default_factory=MultiViewFusionConfig)
    # Observation memory over current + historical scene tokens.
    memory: ObservationMemoryConfig = field(default_factory=ObservationMemoryConfig)
    # Action chunk decoder (also provides candidate sampling for planning).
    decoder: ChunkDecoderConfig = field(default_factory=ChunkDecoderConfig)
    # Reveal / interaction / elastic-occlusion state head (shared config type).
    reveal_head: RevealHeadConfig = field(default_factory=RevealHeadConfig)
    # World model used to roll out candidate action chunks.
    world_model: RevealWMConfig = field(default_factory=RevealWMConfig)
    # Planner that scores rollouts and selects the best candidate chunk.
    planner: PlannerConfig = field(default_factory=PlannerConfig)
|
|
|
|
class BackboneOnlyPolicy(nn.Module):
    """Baseline policy: frozen VL backbone -> multi-view fusion ->
    observation memory -> ACT-style bimanual chunk decoder.

    Subclasses swap individual modules (memory, decoder, heads) while reusing
    the scene/history encoding helpers defined here.
    """

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__()
        self.config = config
        self.backbone = FrozenVLBackbone(config.backbone)
        self.fusion = MultiViewFusion(config.fusion)
        self.memory = ObservationMemory(config.memory)
        self.decoder = ACTBimanualChunkDecoder(config.decoder)

    def _encode_language(
        self,
        images: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
    ) -> Tensor:
        """Encode language, tokenizing raw ``texts`` on demand.

        ``images`` is used only to pick the device for tokenization.
        Raises ValueError when neither texts nor tokens are given.
        """
        tokens = language_tokens
        if tokens is None:
            if texts is None:
                raise ValueError("Either texts or language_tokens must be provided.")
            tokens = self.backbone.tokenize_text(texts, device=images.device)
        return self.backbone.encode_text(
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
        )

    def _task_names(self, batch_size: int, texts: Sequence[str] | None = None) -> list[str]:
        """Return one task name per batch element ("generic" without texts)."""
        if texts is None:
            return ["generic"] * batch_size
        return [infer_task_name_from_text(entry) for entry in texts]

    def encode_scene(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
    ) -> Tensor:
        """Fuse image, proprioception and language features into scene tokens."""
        visual = self.backbone.encode_images(images)
        lingual = self._encode_language(images, texts=texts, language_tokens=language_tokens)
        return self.fusion(image_tokens=visual, proprio=proprio, language_tokens=lingual)

    def _expand_language_tokens_for_history(
        self,
        language_tokens: dict[str, Tensor] | None,
        history_steps: int,
    ) -> dict[str, Tensor] | None:
        """Tile each token tensor ``history_steps`` times along the batch axis
        so one language encoding can serve every history frame."""
        if language_tokens is None:
            return None
        expanded: dict[str, Tensor] = {}
        for key, tensor in language_tokens.items():
            tiled = tensor.unsqueeze(1).expand(-1, history_steps, *tensor.shape[1:])
            expanded[key] = tiled.reshape(tensor.shape[0] * history_steps, *tensor.shape[1:])
        return expanded

    def encode_history(
        self,
        history_images: Tensor | None,
        history_proprio: Tensor | None,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
    ) -> Tensor | None:
        """Encode a (batch, steps, ...) observation history into scene tokens.

        Frames are flattened into the batch axis, encoded with
        ``encode_scene`` and reshaped back to (batch, steps, tokens, dim).
        Returns None when no usable history is provided.
        """
        if history_images is None or history_proprio is None or history_images.numel() == 0:
            return None
        batch, steps = history_images.shape[:2]
        images_flat = history_images.reshape(batch * steps, *history_images.shape[2:])
        proprio_flat = history_proprio.reshape(batch * steps, history_proprio.shape[-1])
        if language_tokens is not None:
            texts_flat = None
            tokens_flat = self._expand_language_tokens_for_history(language_tokens, steps)
        else:
            if texts is None:
                raise ValueError("Either texts or language_tokens must be provided.")
            # Repeat each prompt once per history frame, matching the flatten order.
            texts_flat = [entry for entry in texts for _ in range(steps)]
            tokens_flat = None
        scene_flat = self.encode_scene(
            images_flat,
            proprio_flat,
            texts=texts_flat,
            language_tokens=tokens_flat,
        )
        return scene_flat.view(batch, steps, scene_flat.shape[1], scene_flat.shape[2])

    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_images: Tensor | None = None,
        history_proprio: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Encode current + historical observations, update memory, and decode
        an action chunk; returns decoder outputs plus intermediate tokens."""
        scene = self.encode_scene(images, proprio, texts=texts, language_tokens=language_tokens)
        history = self.encode_history(
            history_images,
            history_proprio,
            texts=texts,
            language_tokens=language_tokens,
        )
        memory_state = self.memory(
            scene,
            history_scene_tokens=history,
            history_actions=history_actions,
        )
        result = self.decoder(scene, memory_token=memory_state["memory_token"])
        result["scene_tokens"] = scene
        result["history_scene_tokens"] = history
        result["memory_output"] = memory_state
        return result
|
|
|
|
class RevealBimanualPolicy(BackboneOnlyPolicy):
    """Extends the baseline policy with a reveal-state head, a learned world
    model over candidate action chunks, and a planner that picks the best
    chunk from world-model rollouts."""

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__(config)
        # Additional modules on top of the inherited backbone/fusion/memory/decoder.
        self.reveal_head = RevealStateHead(config.reveal_head)
        self.world_model = RevealWM(config.world_model)
        self.planner = RevealPlanner(config.planner)


    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_images: Tensor | None = None,
        history_proprio: Tensor | None = None,
        history_actions: Tensor | None = None,
        plan: bool = True,
        support_mode_conditioning: bool = True,
        candidate_chunks_override: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Run perception + reveal estimation + decoding, optionally followed
        by candidate sampling, world-model rollout and planner selection.

        Args:
            plan: when True, attach ``planned_*`` / ``candidate_chunks`` keys.
            support_mode_conditioning: when False the world model sees zeroed
                ``support_mode_logits`` (ablation switch).
            candidate_chunks_override: pre-sampled candidates of shape
                (batch, num_candidates, chunk, action_dim); skips sampling.
        """
        # NOTE(review): super().forward() already runs self.decoder once
        # (without reveal tokens); the reveal-conditioned decode below then
        # overwrites the overlapping keys via outputs.update(decoded). The
        # first decode therefore appears to be redundant compute — confirm
        # whether its extra keys (if any) are relied upon before removing it.
        outputs = super().forward(
            images,
            proprio,
            texts=texts,
            language_tokens=language_tokens,
            history_images=history_images,
            history_proprio=history_proprio,
            history_actions=history_actions,
        )
        # Estimate the reveal/occlusion state from scene tokens + memory.
        reveal_state = self.reveal_head(
            outputs["scene_tokens"],
            memory_token=outputs["memory_output"]["memory_token"],
        )
        outputs["reveal_state"] = reveal_state
        outputs["memory_uncertainty"] = outputs["memory_output"]["memory_uncertainty"]


        # Re-decode, this time conditioned on the reveal field tokens.
        decoded = self.decoder(
            outputs["scene_tokens"],
            reveal_tokens=reveal_state["field_tokens"],
            memory_token=outputs["memory_output"]["memory_token"],
        )
        outputs.update(decoded)


        if plan:
            candidate_chunks = candidate_chunks_override
            if candidate_chunks is None:
                # Sample candidates from the decoder's action distribution.
                candidate_chunks = self.decoder.sample_candidates(
                    outputs["action_mean"],
                    outputs["action_log_std"],
                    num_candidates=self.config.decoder.num_candidates,
                )
            outputs["candidate_chunks"] = candidate_chunks
            batch_size, num_candidates, chunk_size, action_dim = candidate_chunks.shape
            # Flatten candidates into the batch axis so the world model can
            # roll out all (batch, candidate) pairs in one forward pass.
            flat_chunks = candidate_chunks.view(batch_size * num_candidates, chunk_size, action_dim)
            tiled_scene = outputs["scene_tokens"].unsqueeze(1).expand(-1, num_candidates, -1, -1)
            tiled_scene = tiled_scene.reshape(batch_size * num_candidates, outputs["scene_tokens"].shape[1], outputs["scene_tokens"].shape[2])
            planning_reveal_state = reveal_state
            if not support_mode_conditioning:
                # Ablation: hide support-mode information from the world model
                # (shallow copy so the returned reveal_state stays untouched).
                planning_reveal_state = dict(reveal_state)
                planning_reveal_state["support_mode_logits"] = torch.zeros_like(reveal_state["support_mode_logits"])
            # Tile every reveal-state tensor the same way as the scene tokens.
            tiled_reveal = {
                key: value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(batch_size * num_candidates, *value.shape[1:])
                for key, value in planning_reveal_state.items()
            }
            rollout = self.world_model(tiled_scene, tiled_reveal, flat_chunks)
            # Un-flatten rollouts back to (batch, candidate, ...) for ranking.
            reshaped_rollout = {
                key: value.view(batch_size, num_candidates, *value.shape[1:]) for key, value in rollout.items()
            }
            selected = self.planner.select_best(
                candidate_chunks=candidate_chunks,
                rollout_state=reshaped_rollout,
            )
            outputs["planned_rollout"] = reshaped_rollout
            outputs["planned_chunk"] = selected["best_chunk"]
            outputs["planner_success_logits"] = selected["success_logits"]
            outputs["planner_risk_values"] = selected["risk_values"]
            outputs["planner_scores"] = selected["utility_scores"]
            outputs["best_candidate_indices"] = selected["best_indices"]
        return outputs
|
|
|
|
class InteractionBimanualPolicy(BackboneOnlyPolicy):
    """Policy variant built around an interaction-state head.

    Replaces the base memory/decoder with interaction-aware versions and adds
    an interaction state head, a rollout world model, and a planner that
    scores sampled candidate action chunks.

    Fix vs. previous revision: when ``history_steps_override`` trims the
    observation history, ``history_actions`` is now trimmed to the same
    number of trailing steps, matching ``ElasticRevealBimanualPolicy.forward``
    so the memory always sees aligned observation/action histories.
    """

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__(config)
        # Override the modules created by BackboneOnlyPolicy.__init__ with
        # interaction-aware counterparts; backbone/fusion are inherited as-is.
        self.memory = InteractionObservationMemory(config.memory)
        self.decoder = InteractionChunkDecoder(config.decoder)
        self.interaction_head = InteractionStateHead(config.reveal_head)
        self.world_model = InteractionRolloutModel(config.world_model)
        self.planner = InteractionPlanner(config.planner)

    def _tile_tensor(self, value: Tensor, num_candidates: int) -> Tensor:
        """Repeat ``value`` per candidate and fold the candidate axis into the
        batch dimension: (B, ...) -> (B * num_candidates, ...)."""
        return value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(
            value.shape[0] * num_candidates,
            *value.shape[1:],
        )

    def _tile_state(self, state: dict[str, Tensor], num_candidates: int) -> dict[str, Tensor]:
        """Apply :meth:`_tile_tensor` to every tensor in a state dict."""
        return {key: self._tile_tensor(value, num_candidates) for key, value in state.items()}

    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_images: Tensor | None = None,
        history_proprio: Tensor | None = None,
        history_actions: Tensor | None = None,
        plan: bool = True,
        support_mode_conditioning: bool = True,
        candidate_chunks_override: Tensor | None = None,
        use_interaction_head: bool = True,
        use_role_tokens: bool = True,
        history_steps_override: int | None = None,
    ) -> dict[str, Tensor]:
        """Encode the scene, estimate interaction state, decode a chunk and
        (optionally) plan over candidate chunks with the world model.

        Args:
            plan: when True, attach ``planned_*`` / ``candidate_chunks`` keys.
            support_mode_conditioning: when False, the world model sees zeroed
                ``support_mode_logits`` (ablation switch).
            candidate_chunks_override: pre-sampled candidates of shape
                (batch, num_candidates, chunk, action_dim); skips sampling and
                disables proposal-logit re-use.
            use_interaction_head: when False, skip the interaction head; the
                planner then degenerates to returning ``action_mean`` with
                zeroed planner diagnostics.
            use_role_tokens: when False, zero ``arm_role_logits`` (ablation).
            history_steps_override: keep only the last N history steps (both
                scene tokens and actions) before the memory update.
        """
        scene_tokens = self.encode_scene(images, proprio, texts=texts, language_tokens=language_tokens)
        history_scene_tokens = self.encode_history(
            history_images,
            history_proprio,
            texts=texts,
            language_tokens=language_tokens,
        )
        if history_steps_override is not None and history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            history_scene_tokens = history_scene_tokens[:, -history_steps_override:]
            # Keep the action history aligned with the trimmed observation
            # history (previously only the scene tokens were trimmed).
            if history_actions is not None and history_actions.numel() > 0:
                history_actions = history_actions[:, -history_steps_override:]
        memory_output = self.memory(
            scene_tokens,
            history_scene_tokens=history_scene_tokens,
            history_actions=history_actions,
        )

        interaction_state = None
        if use_interaction_head:
            interaction_state = self.interaction_head(
                scene_tokens,
                memory_tokens=memory_output["memory_tokens"],
            )
            # Expose memory tokens through the state so tiling for the world
            # model carries them along with the head outputs.
            interaction_state["memory_tokens"] = memory_output["memory_tokens"]
            interaction_state["memory_token"] = memory_output["memory_token"]

        if interaction_state is not None and not use_role_tokens:
            # Ablation: hide arm-role information (shallow copy keeps the
            # head's original output intact).
            interaction_state = dict(interaction_state)
            interaction_state["arm_role_logits"] = torch.zeros_like(interaction_state["arm_role_logits"])

        decoded = self.decoder(
            scene_tokens,
            interaction_state=interaction_state,
            memory_tokens=memory_output["memory_tokens"],
        )
        outputs = {
            **decoded,
            "scene_tokens": scene_tokens,
            "history_scene_tokens": history_scene_tokens,
            "memory_output": memory_output,
            "memory_uncertainty": memory_output["memory_uncertainty"],
            "interaction_state": interaction_state,
            # Alias kept for compatibility with RevealBimanualPolicy consumers.
            "reveal_state": interaction_state,
        }

        if plan:
            candidate_chunks = candidate_chunks_override
            proposal_logits = outputs.get("proposal_logits")
            if candidate_chunks is None:
                candidate_chunks = self.decoder.sample_candidates(
                    outputs["action_mean"],
                    outputs["action_log_std"],
                    num_candidates=self.config.decoder.num_candidates,
                    proposal_candidates=outputs.get("proposal_candidates"),
                )
            else:
                # Overridden candidates are not the decoder's proposals, so
                # their logits no longer apply.
                proposal_logits = None
            outputs["candidate_chunks"] = candidate_chunks

            if interaction_state is None:
                # No interaction state -> world model cannot roll out; fall
                # back to the mean chunk with zeroed planner diagnostics.
                outputs["planned_chunk"] = outputs["action_mean"]
                outputs["planner_success_logits"] = torch.zeros(
                    candidate_chunks.shape[:2],
                    device=candidate_chunks.device,
                    dtype=candidate_chunks.dtype,
                )
                outputs["planner_risk_values"] = torch.zeros_like(outputs["planner_success_logits"])
                outputs["planner_scores"] = torch.zeros_like(outputs["planner_success_logits"])
                outputs["best_candidate_indices"] = torch.zeros(
                    candidate_chunks.shape[0],
                    dtype=torch.long,
                    device=candidate_chunks.device,
                )
                outputs["planned_rollout"] = {}
                return outputs

            batch_size, num_candidates, chunk_size, action_dim = candidate_chunks.shape
            # Flatten candidates into the batch axis for a single rollout pass.
            flat_chunks = candidate_chunks.view(batch_size * num_candidates, chunk_size, action_dim)
            tiled_scene = self._tile_tensor(scene_tokens, num_candidates)
            planning_state = interaction_state
            if not support_mode_conditioning:
                planning_state = dict(interaction_state)
                planning_state["support_mode_logits"] = torch.zeros_like(interaction_state["support_mode_logits"])
            tiled_state = self._tile_state(planning_state, num_candidates)
            tiled_memory_tokens = self._tile_tensor(memory_output["memory_tokens"], num_candidates)
            rollout = self.world_model(
                scene_tokens=tiled_scene,
                interaction_state=tiled_state,
                action_chunk=flat_chunks,
                memory_tokens=tiled_memory_tokens,
            )
            # Un-flatten rollouts back to (batch, candidate, ...) for ranking.
            reshaped_rollout = {
                key: value.view(batch_size, num_candidates, *value.shape[1:]) for key, value in rollout.items()
            }
            selected = self.planner.select_best(
                candidate_chunks=candidate_chunks,
                rollout_state=reshaped_rollout,
                proposal_logits=proposal_logits,
            )
            outputs["planned_rollout"] = reshaped_rollout
            outputs["planned_chunk"] = selected["best_chunk"]
            outputs["planner_success_logits"] = selected["success_logits"]
            outputs["planner_risk_values"] = selected["risk_values"]
            outputs["planner_scores"] = selected["utility_scores"]
            outputs["best_candidate_indices"] = selected["best_indices"]
        return outputs
|
|
|
|
class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
    """Most feature-complete policy variant: optional depth/geometry/camera
    tokens, dual (scene + belief) memory, an elastic-occlusion state head, a
    world model with multiple rollout modes, and a cascade planner that
    shortlists candidates before rollout-based selection."""

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__(config)
        # Swap in the elastic-occlusion counterparts of the base modules.
        self.memory = DualObservationMemory(config.memory)
        self.decoder = SymmetricCoordinatedChunkDecoder(config.decoder)
        self.elastic_state_head = ElasticOcclusionStateHead(config.reveal_head)
        self.world_model = ElasticOcclusionWorldModel(config.world_model)
        self.planner = CascadePlanner(config.planner)


    def _encode_scene_with_optional_depth(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        depths: Tensor | None = None,
        depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        use_depth: bool = True,
        use_geometry_tokens: bool | None = None,
        use_camera_pose_tokens: bool | None = None,
    ) -> dict[str, Tensor]:
        """Encode one observation, optionally with depth/geometry/camera-pose
        tokens, and return the fused scene tokens plus auxiliary summaries.

        When ``use_depth`` is False, all depth-related inputs are suppressed
        before reaching the backbone.
        """
        encoded = self.backbone.encode_images(
            images,
            depths=depths if use_depth else None,
            depth_valid=depth_valid if use_depth else None,
            camera_intrinsics=camera_intrinsics if use_depth else None,
            camera_extrinsics=camera_extrinsics if use_depth else None,
            return_aux=True,
            use_depth_tokens=use_depth,
            use_geometry_tokens=use_geometry_tokens,
            use_camera_pose_tokens=use_camera_pose_tokens,
        )
        # With return_aux=True the backbone returns a dict of token groups.
        assert isinstance(encoded, dict)
        text_tokens = self._encode_language(images, texts=texts, language_tokens=language_tokens)
        fused = self.fusion(
            image_tokens=encoded["rgb_tokens"],
            proprio=proprio,
            language_tokens=text_tokens,
            depth_tokens=encoded.get("depth_tokens"),
            geometry_tokens=encoded.get("geometry_tokens"),
            camera_tokens=encoded.get("camera_tokens"),
            return_aux=True,
        )
        assert isinstance(fused, dict)
        return {
            "scene_tokens": fused["scene_tokens"],
            "view_summaries": fused["view_summaries"],
            "geometry_summaries": fused["geometry_summaries"],
            "depth_tokens": encoded.get("depth_tokens"),
            "geometry_tokens": encoded.get("geometry_tokens"),
            "camera_tokens": encoded.get("camera_tokens"),
        }


    def _expand_language_tokens_for_history(
        self,
        language_tokens: dict[str, Tensor] | None,
        history_steps: int,
    ) -> dict[str, Tensor] | None:
        """Tile language token tensors once per history frame.

        NOTE(review): this override is byte-identical to the base-class
        implementation in BackboneOnlyPolicy and appears redundant — confirm
        and consider removing it.
        """
        if language_tokens is None:
            return None
        return {
            key: value.unsqueeze(1).expand(-1, history_steps, *value.shape[1:]).reshape(
                value.shape[0] * history_steps, *value.shape[1:]
            )
            for key, value in language_tokens.items()
        }


    def encode_history_with_optional_depth(
        self,
        history_images: Tensor | None,
        history_proprio: Tensor | None,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_depths: Tensor | None = None,
        history_depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        use_depth: bool = True,
        use_geometry_tokens: bool | None = None,
        use_camera_pose_tokens: bool | None = None,
    ) -> Tensor | None:
        """Encode a (batch, steps, ...) history with optional depth, returning
        scene tokens of shape (batch, steps, tokens, dim) or None when no
        usable history is provided."""
        if history_images is None or history_proprio is None or history_images.numel() == 0:
            return None
        batch_size, history_steps = history_images.shape[:2]
        # Fold the time axis into the batch axis for a single encoding pass.
        flat_images = history_images.reshape(batch_size * history_steps, *history_images.shape[2:])
        flat_proprio = history_proprio.reshape(batch_size * history_steps, history_proprio.shape[-1])
        flat_depths = None
        flat_depth_valid = None
        if history_depths is not None and history_depths.numel() > 0:
            flat_depths = history_depths.reshape(batch_size * history_steps, *history_depths.shape[2:])
        if history_depth_valid is not None and history_depth_valid.numel() > 0:
            flat_depth_valid = history_depth_valid.reshape(batch_size * history_steps, *history_depth_valid.shape[2:])
        if language_tokens is None:
            flat_texts = [text for text in texts for _ in range(history_steps)] if texts is not None else None
            flat_language_tokens = None
        else:
            flat_texts = None
            flat_language_tokens = self._expand_language_tokens_for_history(language_tokens, history_steps)
        # NOTE(review): camera_intrinsics/extrinsics are accepted above but
        # deliberately(?) dropped here (passed as None) — they are per-frame
        # in the current observation but not tiled for history; confirm this
        # is intended rather than an oversight.
        history_scene = self._encode_scene_with_optional_depth(
            images=flat_images,
            proprio=flat_proprio,
            texts=flat_texts,
            language_tokens=flat_language_tokens,
            depths=flat_depths,
            depth_valid=flat_depth_valid,
            camera_intrinsics=None,
            camera_extrinsics=None,
            use_depth=use_depth,
            use_geometry_tokens=use_geometry_tokens,
            use_camera_pose_tokens=use_camera_pose_tokens,
        )["scene_tokens"]
        return history_scene.view(batch_size, history_steps, history_scene.shape[1], history_scene.shape[2])


    def _tile_tensor(self, value: Tensor, num_candidates: int) -> Tensor:
        """Repeat ``value`` per candidate and fold the candidate axis into the
        batch dimension: (B, ...) -> (B * num_candidates, ...)."""
        return value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(
            value.shape[0] * num_candidates,
            *value.shape[1:],
        )


    def _tile_state(self, state: dict[str, Tensor], num_candidates: int) -> dict[str, Tensor]:
        """Tile every *tensor* entry of a state dict; non-tensor entries
        (e.g. name lists) are silently dropped."""
        tiled: dict[str, Tensor] = {}
        for key, value in state.items():
            if isinstance(value, Tensor):
                tiled[key] = self._tile_tensor(value, num_candidates)
        return tiled


    def _detach_state(self, state: dict[str, Tensor]) -> dict[str, Tensor]:
        """Detach every tensor entry from the autograd graph; non-tensor
        entries are passed through unchanged."""
        detached: dict[str, Tensor] = {}
        for key, value in state.items():
            detached[key] = value.detach() if isinstance(value, Tensor) else value
        return detached


    def _repeat_rollout_tensor(self, value: Tensor, num_candidates: int, horizon: int) -> Tensor:
        """Broadcast a per-batch tensor to (B, num_candidates, horizon, ...)
        without copying (detached expand), for identity rollouts."""
        value = value.detach()
        return value.unsqueeze(1).unsqueeze(2).expand(-1, num_candidates, horizon, *value.shape[1:])


    def _zero_memory_output(self, scene_tokens: Tensor) -> dict[str, Tensor]:
        """Build an all-zeros stand-in for the DualObservationMemory output,
        used when ``use_memory=False`` so downstream modules keep their
        expected inputs/shapes."""
        batch_size, _, hidden_dim = scene_tokens.shape
        scene_memory_tokens = scene_tokens.new_zeros((batch_size, self.config.memory.scene_bank_size, hidden_dim))
        belief_memory_tokens = scene_tokens.new_zeros((batch_size, self.config.memory.belief_bank_size, hidden_dim))
        memory_tokens = torch.cat([scene_memory_tokens, belief_memory_tokens], dim=1)
        return {
            "scene_memory_tokens": scene_memory_tokens,
            "belief_memory_tokens": belief_memory_tokens,
            "memory_tokens": memory_tokens,
            # Mean over an all-zeros bank is still zeros; keeps (B, 1, D) shape.
            "memory_token": memory_tokens.mean(dim=1, keepdim=True),
            "memory_sequence": scene_tokens.new_zeros((batch_size, 0, hidden_dim)),
            "memory_state": scene_tokens.new_zeros((batch_size, hidden_dim * 2)),
            "memory_uncertainty": scene_tokens.new_zeros((batch_size,)),
            "memory_write_rate": scene_tokens.new_zeros((batch_size,)),
            "memory_saturation": scene_tokens.new_zeros((batch_size,)),
            "scene_write_gate": scene_tokens.new_zeros((batch_size, self.config.memory.scene_bank_size)),
            "belief_write_gate": scene_tokens.new_zeros((batch_size, self.config.memory.belief_bank_size)),
            "memory_scene_state": scene_tokens.new_zeros((batch_size, hidden_dim)),
            "memory_belief_state": scene_tokens.new_zeros((batch_size, hidden_dim)),
        }


    def _identity_rollout(
        self,
        interaction_state: dict[str, Tensor],
        num_candidates: int,
    ) -> dict[str, Tensor]:
        """Fake a world-model rollout by repeating the current state over the
        configured horizon (used when ``use_world_model=False``)."""
        horizon = max(1, self.config.world_model.rollout_horizon)
        rollout: dict[str, Tensor] = {}
        for key, value in interaction_state.items():
            if isinstance(value, Tensor):
                rollout[key] = self._repeat_rollout_tensor(value, num_candidates, horizon)
        return rollout


    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_images: Tensor | None = None,
        history_proprio: Tensor | None = None,
        history_actions: Tensor | None = None,
        plan: bool = True,
        support_mode_conditioning: bool = True,
        candidate_chunks_override: Tensor | None = None,
        use_depth: bool = True,
        use_world_model: bool = True,
        use_planner: bool = True,
        use_role_tokens: bool = True,
        history_steps_override: int | None = None,
        depths: Tensor | None = None,
        depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        history_depths: Tensor | None = None,
        history_depth_valid: Tensor | None = None,
        compute_equivariance_probe: bool = False,
        use_geometry_tokens: bool | None = None,
        use_camera_pose_tokens: bool | None = None,
        use_memory: bool = True,
        use_task_conditioning: bool = True,
        rollout_mode_override: str | None = None,
        use_proposal_candidates: bool = True,
    ) -> dict[str, Tensor]:
        """Full pipeline: encode (with optional depth), update dual memory,
        estimate elastic-occlusion state, decode, then shortlist and plan.

        The many boolean flags are ablation switches; each one zeroes or
        bypasses a single component while preserving output keys/shapes.
        ``rollout_source`` in the returned dict records which rollout path ran
        ("none", "identity", or "learned").
        """
        task_names = self._task_names(images.shape[0], texts=texts)
        scene_output = self._encode_scene_with_optional_depth(
            images=images,
            proprio=proprio,
            texts=texts,
            language_tokens=language_tokens,
            depths=depths,
            depth_valid=depth_valid,
            camera_intrinsics=camera_intrinsics,
            camera_extrinsics=camera_extrinsics,
            use_depth=use_depth,
            use_geometry_tokens=use_geometry_tokens,
            use_camera_pose_tokens=use_camera_pose_tokens,
        )
        scene_tokens = scene_output["scene_tokens"]
        history_scene_tokens = self.encode_history_with_optional_depth(
            history_images=history_images,
            history_proprio=history_proprio,
            texts=texts,
            language_tokens=language_tokens,
            history_depths=history_depths,
            history_depth_valid=history_depth_valid,
            camera_intrinsics=camera_intrinsics,
            camera_extrinsics=camera_extrinsics,
            use_depth=use_depth,
            use_geometry_tokens=use_geometry_tokens,
            use_camera_pose_tokens=use_camera_pose_tokens,
        )
        # Trim both observation and action histories to the same trailing
        # window so the memory sees aligned sequences.
        if history_steps_override is not None and history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            history_scene_tokens = history_scene_tokens[:, -history_steps_override:]
            if history_actions is not None and history_actions.numel() > 0:
                history_actions = history_actions[:, -history_steps_override:]
        if use_memory:
            memory_output = self.memory(
                scene_tokens,
                history_scene_tokens=history_scene_tokens,
                history_actions=history_actions,
            )
        else:
            # Ablation: replace memory with zeros of the expected shapes.
            memory_output = self._zero_memory_output(scene_tokens)
        elastic_state = self.elastic_state_head(
            scene_tokens,
            memory_tokens=memory_output["memory_tokens"],
            task_names=task_names,
            use_task_conditioning=use_task_conditioning,
        )
        # Carry memory tokens inside the state so tiling for the world model
        # keeps them alongside the head outputs.
        elastic_state["memory_tokens"] = memory_output["memory_tokens"]
        elastic_state["memory_token"] = memory_output["memory_token"]
        elastic_state["scene_memory_tokens"] = memory_output["scene_memory_tokens"]
        elastic_state["belief_memory_tokens"] = memory_output["belief_memory_tokens"]
        if not use_role_tokens:
            # Ablation: hide arm-role information (shallow copy keeps the
            # head's original output intact).
            elastic_state = dict(elastic_state)
            elastic_state["arm_role_logits"] = torch.zeros_like(elastic_state["arm_role_logits"])


        decoded = self.decoder(
            scene_tokens,
            interaction_state=elastic_state,
            memory_tokens=memory_output["memory_tokens"],
            compute_equivariance_probe=compute_equivariance_probe,
            task_names=task_names,
        )
        outputs = {
            **decoded,
            "scene_tokens": scene_tokens,
            "history_scene_tokens": history_scene_tokens,
            "memory_output": memory_output,
            "memory_uncertainty": memory_output["memory_uncertainty"],
            "interaction_state": elastic_state,
            # Alias kept for compatibility with RevealBimanualPolicy consumers.
            "reveal_state": elastic_state,
            "view_summaries": scene_output["view_summaries"],
            "geometry_summaries": scene_output["geometry_summaries"],
            "depth_tokens": scene_output["depth_tokens"],
            "geometry_tokens": scene_output["geometry_tokens"],
            "camera_tokens": scene_output["camera_tokens"],
            "rollout_source": "none",
            "task_names": task_names,
        }


        # NOTE(review): candidates are sampled even when plan/use_planner is
        # False (they are attached to the outputs below); confirm that the
        # extra sampling in the no-plan path is intentional.
        candidate_chunks = candidate_chunks_override
        proposal_logits = outputs.get("proposal_logits")
        if candidate_chunks is None:
            candidate_chunks = self.decoder.sample_candidates(
                outputs["action_mean"],
                outputs["action_log_std"],
                num_candidates=self.config.decoder.num_candidates,
                proposal_candidates=outputs.get("proposal_candidates") if use_proposal_candidates else None,
            )
            if not use_proposal_candidates:
                proposal_logits = None
        else:
            # Overridden candidates are not the decoder's proposals, so their
            # logits no longer apply.
            proposal_logits = None
        outputs["candidate_chunks"] = candidate_chunks


        if not plan or not use_planner:
            # Planner disabled: return the mean chunk with zeroed diagnostics.
            outputs["planned_chunk"] = outputs["action_mean"]
            outputs["planned_rollout"] = {}
            outputs["planner_success_logits"] = torch.zeros(
                candidate_chunks.shape[:2],
                device=candidate_chunks.device,
                dtype=candidate_chunks.dtype,
            )
            outputs["planner_risk_values"] = torch.zeros_like(outputs["planner_success_logits"])
            outputs["planner_scores"] = torch.zeros_like(outputs["planner_success_logits"])
            outputs["best_candidate_indices"] = torch.zeros(
                candidate_chunks.shape[0],
                dtype=torch.long,
                device=candidate_chunks.device,
            )
            return outputs


        # Stage 1 of the cascade: shortlist top-k candidates before rollout.
        shortlist_indices = self.planner.shortlist(
            proposal_logits=proposal_logits,
            candidate_chunks=candidate_chunks,
            proposal_mode_assignments=outputs.get("proposal_mode_assignments") if use_proposal_candidates else None,
        )
        outputs["planner_topk_indices"] = shortlist_indices
        batch_size = candidate_chunks.shape[0]
        # Advanced indexing: gather each batch row's shortlisted candidates.
        batch_indices = torch.arange(batch_size, device=candidate_chunks.device).unsqueeze(-1)
        topk_candidates = candidate_chunks[batch_indices, shortlist_indices]
        num_topk = topk_candidates.shape[1]
        outputs["planner_topk_candidates"] = topk_candidates
        proposal_mode_names = outputs.get("proposal_mode_names")
        topk_proposal_mode_names = None
        if proposal_mode_names is not None and use_proposal_candidates:
            # Re-index the per-candidate mode names to match the shortlist.
            topk_proposal_mode_names = [
                [proposal_mode_names[batch_idx][int(candidate_idx.item())] for candidate_idx in shortlist_indices[batch_idx]]
                for batch_idx in range(batch_size)
            ]
        outputs["planner_topk_mode_names"] = topk_proposal_mode_names
        if proposal_logits is not None:
            topk_proposal_logits = proposal_logits.gather(1, shortlist_indices)
        else:
            topk_proposal_logits = None
        planning_state = elastic_state
        if not support_mode_conditioning:
            # Ablation: hide support-mode information from the world model.
            planning_state = dict(elastic_state)
            planning_state["support_mode_logits"] = torch.zeros_like(elastic_state["support_mode_logits"])


        if not use_world_model:
            # Ablation: score shortlisted candidates against a frozen copy of
            # the current state repeated over the horizon ("identity" rollout).
            detached_state = self._detach_state(planning_state)
            identity_rollout = self._identity_rollout(
                interaction_state=detached_state,
                num_candidates=num_topk,
            )
            selected = self.planner.select_best(
                initial_state=detached_state,
                candidate_chunks=topk_candidates,
                rollout_state=identity_rollout,
                proposal_logits=topk_proposal_logits,
                candidate_indices=shortlist_indices,
                proposal_mode_names=topk_proposal_mode_names,
            )
            outputs["planned_rollout"] = identity_rollout
            outputs["planned_chunk"] = selected["best_chunk"]
            outputs["planner_success_logits"] = selected["success_logits"]
            outputs["planner_risk_values"] = selected["risk_values"]
            outputs["planner_scores"] = selected["utility_total"]
            outputs["best_candidate_indices"] = selected["best_indices"]
            outputs["utility_structured"] = selected["utility_structured"]
            outputs["utility_residual"] = selected["utility_residual"]
            outputs["utility_total"] = selected["utility_total"]
            outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
            outputs["rollout_source"] = "identity"
            return outputs


        # Stage 2: learned world-model rollout of the shortlisted candidates,
        # flattened into the batch axis for one forward pass.
        flat_chunks = topk_candidates.view(batch_size * num_topk, topk_candidates.shape[2], topk_candidates.shape[3])
        tiled_scene = self._tile_tensor(scene_tokens, num_topk)
        tiled_state = self._tile_state(planning_state, num_topk)
        rollout = self.world_model(
            scene_tokens=tiled_scene,
            interaction_state=tiled_state,
            action_chunk=flat_chunks,
            memory_tokens=self._tile_tensor(memory_output["memory_tokens"], num_topk),
            scene_memory_tokens=self._tile_tensor(memory_output["scene_memory_tokens"], num_topk),
            belief_memory_tokens=self._tile_tensor(memory_output["belief_memory_tokens"], num_topk),
            task_names=[name for name in task_names for _ in range(num_topk)],
            rollout_mode_override=rollout_mode_override,
        )
        # Un-flatten rollouts back to (batch, topk, ...) for ranking.
        reshaped_rollout = {
            key: value.view(batch_size, num_topk, *value.shape[1:]) for key, value in rollout.items()
        }
        selected = self.planner.select_best(
            initial_state=elastic_state,
            candidate_chunks=topk_candidates,
            rollout_state=reshaped_rollout,
            proposal_logits=topk_proposal_logits,
            candidate_indices=shortlist_indices,
            proposal_mode_names=topk_proposal_mode_names,
        )
        outputs["planned_rollout"] = reshaped_rollout
        outputs["planned_chunk"] = selected["best_chunk"]
        outputs["planner_success_logits"] = selected["success_logits"]
        outputs["planner_risk_values"] = selected["risk_values"]
        outputs["planner_scores"] = selected["utility_total"]
        outputs["best_candidate_indices"] = selected["best_indices"]
        outputs["utility_structured"] = selected["utility_structured"]
        outputs["utility_residual"] = selected["utility_residual"]
        outputs["utility_total"] = selected["utility_total"]
        outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
        outputs["rollout_source"] = "learned"
        return outputs
|
|