# 2026-03-25 runpod handoff update (commit e7d8e79, verified)
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Sequence
import torch
from torch import Tensor, nn
from models.action_decoder import (
ACTBimanualChunkDecoder,
ChunkDecoderConfig,
InteractionChunkDecoder,
SymmetricCoordinatedChunkDecoder,
infer_task_name_from_text,
)
from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
from models.observation_memory import (
DualObservationMemory,
InteractionObservationMemory,
ObservationMemory,
ObservationMemoryConfig,
)
from models.planner import CascadePlanner, InteractionPlanner, PlannerConfig, RevealPlanner
from models.reveal_head import (
ElasticOcclusionStateHead,
InteractionStateHead,
RevealHeadConfig,
RevealStateHead,
)
from models.world_model import ElasticOcclusionWorldModel, InteractionRolloutModel, RevealWM, RevealWMConfig
@dataclass
class PolicyConfig:
    """Aggregated configuration for the bimanual policy stack.

    One field per sub-module; each default is built lazily via
    ``default_factory`` so instances never share mutable config objects.
    """

    # Frozen vision-language backbone (image + text encoders).
    backbone: FrozenVLBackboneConfig = field(default_factory=FrozenVLBackboneConfig)
    # Multi-view image/proprio/language fusion module.
    fusion: MultiViewFusionConfig = field(default_factory=MultiViewFusionConfig)
    # Observation memory (history aggregation).
    memory: ObservationMemoryConfig = field(default_factory=ObservationMemoryConfig)
    # Action-chunk decoder (also provides candidate sampling).
    decoder: ChunkDecoderConfig = field(default_factory=ChunkDecoderConfig)
    # Reveal / interaction / elastic-occlusion state head.
    reveal_head: RevealHeadConfig = field(default_factory=RevealHeadConfig)
    # World model used to roll out candidate action chunks.
    world_model: RevealWMConfig = field(default_factory=RevealWMConfig)
    # Planner that scores rollouts and selects the best chunk.
    planner: PlannerConfig = field(default_factory=PlannerConfig)
class BackboneOnlyPolicy(nn.Module):
    """Baseline policy: frozen VL backbone -> fusion -> memory -> ACT decoder.

    No reveal head, world model, or planner; subclasses swap in richer
    modules and add planning on top of this encode/decode pipeline.
    """

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__()
        self.config = config
        self.backbone = FrozenVLBackbone(config.backbone)
        self.fusion = MultiViewFusion(config.fusion)
        self.memory = ObservationMemory(config.memory)
        self.decoder = ACTBimanualChunkDecoder(config.decoder)

    def _encode_language(
        self,
        images: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
    ) -> Tensor:
        """Encode language to text tokens, tokenizing ``texts`` on demand.

        ``images`` is used only to pick the device for tokenization.
        Raises ValueError when both ``texts`` and ``language_tokens`` are None.
        """
        if language_tokens is None:
            if texts is None:
                raise ValueError("Either texts or language_tokens must be provided.")
            language_tokens = self.backbone.tokenize_text(texts, device=images.device)
        return self.backbone.encode_text(
            input_ids=language_tokens["input_ids"],
            attention_mask=language_tokens["attention_mask"],
        )

    def _task_names(self, batch_size: int, texts: Sequence[str] | None = None) -> list[str]:
        # One task name per batch element; "generic" when no text is given.
        if texts is None:
            return ["generic"] * batch_size
        return [infer_task_name_from_text(text) for text in texts]

    def encode_scene(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
    ) -> Tensor:
        """Fuse image, proprio, and language features into scene tokens."""
        image_tokens = self.backbone.encode_images(images)
        text_tokens = self._encode_language(images, texts=texts, language_tokens=language_tokens)
        return self.fusion(image_tokens=image_tokens, proprio=proprio, language_tokens=text_tokens)

    def _expand_language_tokens_for_history(
        self,
        language_tokens: dict[str, Tensor] | None,
        history_steps: int,
    ) -> dict[str, Tensor] | None:
        """Repeat each (B, ...) token tensor per history step -> (B * T, ...)."""
        if language_tokens is None:
            return None
        return {
            key: value.unsqueeze(1).expand(-1, history_steps, *value.shape[1:]).reshape(
                value.shape[0] * history_steps, *value.shape[1:]
            )
            for key, value in language_tokens.items()
        }

    def encode_history(
        self,
        history_images: Tensor | None,
        history_proprio: Tensor | None,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
    ) -> Tensor | None:
        """Encode a (B, T, ...) observation history as per-step scene tokens.

        History is flattened to (B * T, ...), encoded like single frames via
        ``encode_scene``, then reshaped back to (B, T, tokens, dim).
        Returns None when no (or empty) history is provided.
        """
        if history_images is None or history_proprio is None or history_images.numel() == 0:
            return None
        batch_size, history_steps = history_images.shape[:2]
        flat_images = history_images.reshape(batch_size * history_steps, *history_images.shape[2:])
        flat_proprio = history_proprio.reshape(batch_size * history_steps, history_proprio.shape[-1])
        if language_tokens is None:
            if texts is None:
                raise ValueError("Either texts or language_tokens must be provided.")
            # Repeat each text once per history step to match the flattened batch.
            flat_texts = [text for text in texts for _ in range(history_steps)]
            flat_language_tokens = None
        else:
            flat_texts = None
            flat_language_tokens = self._expand_language_tokens_for_history(language_tokens, history_steps)
        history_scene = self.encode_scene(
            flat_images,
            flat_proprio,
            texts=flat_texts,
            language_tokens=flat_language_tokens,
        )
        return history_scene.view(batch_size, history_steps, history_scene.shape[1], history_scene.shape[2])

    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_images: Tensor | None = None,
        history_proprio: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Encode current + history observations, update memory, decode actions.

        Returns the decoder outputs augmented with ``scene_tokens``,
        ``history_scene_tokens`` (may be None), and the raw ``memory_output``.
        """
        scene_tokens = self.encode_scene(images, proprio, texts=texts, language_tokens=language_tokens)
        history_scene_tokens = self.encode_history(
            history_images,
            history_proprio,
            texts=texts,
            language_tokens=language_tokens,
        )
        memory_output = self.memory(
            scene_tokens,
            history_scene_tokens=history_scene_tokens,
            history_actions=history_actions,
        )
        decoded = self.decoder(scene_tokens, memory_token=memory_output["memory_token"])
        decoded["scene_tokens"] = scene_tokens
        decoded["history_scene_tokens"] = history_scene_tokens
        decoded["memory_output"] = memory_output
        return decoded
class RevealBimanualPolicy(BackboneOnlyPolicy):
    """Backbone policy plus reveal-state head, world model, and planner."""

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__(config)
        self.reveal_head = RevealStateHead(config.reveal_head)
        self.world_model = RevealWM(config.world_model)
        self.planner = RevealPlanner(config.planner)

    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_images: Tensor | None = None,
        history_proprio: Tensor | None = None,
        history_actions: Tensor | None = None,
        plan: bool = True,
        support_mode_conditioning: bool = True,
        candidate_chunks_override: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode reveal-conditioned actions; optionally plan over candidates.

        When ``plan`` is True, candidate chunks are sampled (or taken from
        ``candidate_chunks_override``), rolled out through the world model,
        and the planner's selection is added to the outputs.
        """
        # NOTE(review): super().forward() already ran the decoder once without
        # reveal tokens; that decode is re-done below with reveal conditioning,
        # overwriting its action outputs — duplicated compute worth revisiting.
        outputs = super().forward(
            images,
            proprio,
            texts=texts,
            language_tokens=language_tokens,
            history_images=history_images,
            history_proprio=history_proprio,
            history_actions=history_actions,
        )
        reveal_state = self.reveal_head(
            outputs["scene_tokens"],
            memory_token=outputs["memory_output"]["memory_token"],
        )
        outputs["reveal_state"] = reveal_state
        outputs["memory_uncertainty"] = outputs["memory_output"]["memory_uncertainty"]
        # Re-decode with reveal-field conditioning; keys overwrite the first decode.
        decoded = self.decoder(
            outputs["scene_tokens"],
            reveal_tokens=reveal_state["field_tokens"],
            memory_token=outputs["memory_output"]["memory_token"],
        )
        outputs.update(decoded)
        if plan:
            candidate_chunks = candidate_chunks_override
            if candidate_chunks is None:
                candidate_chunks = self.decoder.sample_candidates(
                    outputs["action_mean"],
                    outputs["action_log_std"],
                    num_candidates=self.config.decoder.num_candidates,
                )
            outputs["candidate_chunks"] = candidate_chunks
            batch_size, num_candidates, chunk_size, action_dim = candidate_chunks.shape
            # Flatten candidates into the batch dim so one world-model call
            # evaluates all (batch, candidate) pairs.
            flat_chunks = candidate_chunks.view(batch_size * num_candidates, chunk_size, action_dim)
            tiled_scene = outputs["scene_tokens"].unsqueeze(1).expand(-1, num_candidates, -1, -1)
            tiled_scene = tiled_scene.reshape(batch_size * num_candidates, outputs["scene_tokens"].shape[1], outputs["scene_tokens"].shape[2])
            planning_state = reveal_state
            planning_reveal_state = reveal_state
            if not support_mode_conditioning:
                # Shallow-copy so the outputs' reveal_state keeps real logits.
                planning_reveal_state = dict(reveal_state)
                planning_reveal_state["support_mode_logits"] = torch.zeros_like(reveal_state["support_mode_logits"])
            tiled_reveal = {
                key: value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(batch_size * num_candidates, *value.shape[1:])
                for key, value in planning_reveal_state.items()
            }
            rollout = self.world_model(tiled_scene, tiled_reveal, flat_chunks)
            reshaped_rollout = {
                key: value.view(batch_size, num_candidates, *value.shape[1:]) for key, value in rollout.items()
            }
            selected = self.planner.select_best(
                candidate_chunks=candidate_chunks,
                rollout_state=reshaped_rollout,
            )
            outputs["planned_rollout"] = reshaped_rollout
            outputs["planned_chunk"] = selected["best_chunk"]
            outputs["planner_success_logits"] = selected["success_logits"]
            outputs["planner_risk_values"] = selected["risk_values"]
            outputs["planner_scores"] = selected["utility_scores"]
            outputs["best_candidate_indices"] = selected["best_indices"]
        return outputs
class InteractionBimanualPolicy(BackboneOnlyPolicy):
    """Policy with an interaction-state head, rollout world model, and planner.

    Replaces the base memory/decoder with interaction-aware variants; when
    ``plan`` is True, sampled candidate action chunks are rolled out through
    the world model and scored by the planner.
    """

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__(config)
        # Override base modules with interaction-aware variants.
        self.memory = InteractionObservationMemory(config.memory)
        self.decoder = InteractionChunkDecoder(config.decoder)
        self.interaction_head = InteractionStateHead(config.reveal_head)
        self.world_model = InteractionRolloutModel(config.world_model)
        self.planner = InteractionPlanner(config.planner)

    def _tile_tensor(self, value: Tensor, num_candidates: int) -> Tensor:
        """Repeat per candidate: (B, ...) -> (B * num_candidates, ...)."""
        return value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(
            value.shape[0] * num_candidates,
            *value.shape[1:],
        )

    def _tile_state(self, state: dict[str, Tensor], num_candidates: int) -> dict[str, Tensor]:
        """Tile every tensor in a state dict per candidate."""
        return {key: self._tile_tensor(value, num_candidates) for key, value in state.items()}

    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_images: Tensor | None = None,
        history_proprio: Tensor | None = None,
        history_actions: Tensor | None = None,
        plan: bool = True,
        support_mode_conditioning: bool = True,
        candidate_chunks_override: Tensor | None = None,
        use_interaction_head: bool = True,
        use_role_tokens: bool = True,
        history_steps_override: int | None = None,
    ) -> dict[str, Tensor]:
        """Encode, update memory, decode, then (optionally) plan.

        Ablation switches: ``use_interaction_head`` / ``use_role_tokens``
        disable the interaction state / arm-role conditioning;
        ``history_steps_override`` keeps only the last N history steps.
        """
        scene_tokens = self.encode_scene(images, proprio, texts=texts, language_tokens=language_tokens)
        history_scene_tokens = self.encode_history(
            history_images,
            history_proprio,
            texts=texts,
            language_tokens=language_tokens,
        )
        if history_steps_override is not None and history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            history_scene_tokens = history_scene_tokens[:, -history_steps_override:]
            # Fix: truncate history_actions alongside the scene history so the
            # memory module sees aligned step counts (mirrors
            # ElasticRevealBimanualPolicy.forward).
            if history_actions is not None and history_actions.numel() > 0:
                history_actions = history_actions[:, -history_steps_override:]
        memory_output = self.memory(
            scene_tokens,
            history_scene_tokens=history_scene_tokens,
            history_actions=history_actions,
        )
        interaction_state = None
        if use_interaction_head:
            interaction_state = self.interaction_head(
                scene_tokens,
                memory_tokens=memory_output["memory_tokens"],
            )
            interaction_state["memory_tokens"] = memory_output["memory_tokens"]
            interaction_state["memory_token"] = memory_output["memory_token"]
        if interaction_state is not None and not use_role_tokens:
            # Shallow-copy before zeroing so the head's output dict is untouched.
            interaction_state = dict(interaction_state)
            interaction_state["arm_role_logits"] = torch.zeros_like(interaction_state["arm_role_logits"])
        decoded = self.decoder(
            scene_tokens,
            interaction_state=interaction_state,
            memory_tokens=memory_output["memory_tokens"],
        )
        outputs = {
            **decoded,
            "scene_tokens": scene_tokens,
            "history_scene_tokens": history_scene_tokens,
            "memory_output": memory_output,
            "memory_uncertainty": memory_output["memory_uncertainty"],
            "interaction_state": interaction_state,
            # Alias kept for compatibility with reveal-policy consumers.
            "reveal_state": interaction_state,
        }
        if plan:
            candidate_chunks = candidate_chunks_override
            proposal_logits = outputs.get("proposal_logits")
            if candidate_chunks is None:
                candidate_chunks = self.decoder.sample_candidates(
                    outputs["action_mean"],
                    outputs["action_log_std"],
                    num_candidates=self.config.decoder.num_candidates,
                    proposal_candidates=outputs.get("proposal_candidates"),
                )
            else:
                # Override chunks did not come from the proposal head, so its
                # logits do not apply to them.
                proposal_logits = None
            outputs["candidate_chunks"] = candidate_chunks
            if interaction_state is None:
                # No interaction state -> planner cannot score rollouts; fall
                # back to the mean action with zeroed planner diagnostics.
                outputs["planned_chunk"] = outputs["action_mean"]
                outputs["planner_success_logits"] = torch.zeros(
                    candidate_chunks.shape[:2],
                    device=candidate_chunks.device,
                    dtype=candidate_chunks.dtype,
                )
                outputs["planner_risk_values"] = torch.zeros_like(outputs["planner_success_logits"])
                outputs["planner_scores"] = torch.zeros_like(outputs["planner_success_logits"])
                outputs["best_candidate_indices"] = torch.zeros(
                    candidate_chunks.shape[0],
                    dtype=torch.long,
                    device=candidate_chunks.device,
                )
                outputs["planned_rollout"] = {}
                return outputs
            batch_size, num_candidates, chunk_size, action_dim = candidate_chunks.shape
            # Flatten candidates into the batch dim for one world-model call.
            flat_chunks = candidate_chunks.view(batch_size * num_candidates, chunk_size, action_dim)
            tiled_scene = self._tile_tensor(scene_tokens, num_candidates)
            planning_state = interaction_state
            if not support_mode_conditioning:
                planning_state = dict(interaction_state)
                planning_state["support_mode_logits"] = torch.zeros_like(interaction_state["support_mode_logits"])
            tiled_state = self._tile_state(planning_state, num_candidates)
            tiled_memory_tokens = self._tile_tensor(memory_output["memory_tokens"], num_candidates)
            rollout = self.world_model(
                scene_tokens=tiled_scene,
                interaction_state=tiled_state,
                action_chunk=flat_chunks,
                memory_tokens=tiled_memory_tokens,
            )
            reshaped_rollout = {
                key: value.view(batch_size, num_candidates, *value.shape[1:]) for key, value in rollout.items()
            }
            selected = self.planner.select_best(
                candidate_chunks=candidate_chunks,
                rollout_state=reshaped_rollout,
                proposal_logits=proposal_logits,
            )
            outputs["planned_rollout"] = reshaped_rollout
            outputs["planned_chunk"] = selected["best_chunk"]
            outputs["planner_success_logits"] = selected["success_logits"]
            outputs["planner_risk_values"] = selected["risk_values"]
            outputs["planner_scores"] = selected["utility_scores"]
            outputs["best_candidate_indices"] = selected["best_indices"]
        return outputs
class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
    """Depth-aware policy with elastic-occlusion state, cascade planner, and
    optional learned world-model rollouts.

    Adds many ablation switches (depth, geometry/camera tokens, memory,
    world model, planner, role tokens, task conditioning) over the base
    encode/decode pipeline.
    """

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__(config)
        self.memory = DualObservationMemory(config.memory)
        self.decoder = SymmetricCoordinatedChunkDecoder(config.decoder)
        self.elastic_state_head = ElasticOcclusionStateHead(config.reveal_head)
        self.world_model = ElasticOcclusionWorldModel(config.world_model)
        self.planner = CascadePlanner(config.planner)

    def _encode_scene_with_optional_depth(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        depths: Tensor | None = None,
        depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        use_depth: bool = True,
        use_geometry_tokens: bool | None = None,
        use_camera_pose_tokens: bool | None = None,
    ) -> dict[str, Tensor]:
        """Encode a frame with optional depth/geometry/camera conditioning.

        Returns fused scene tokens plus the auxiliary token streams; the
        ``depth_tokens`` / ``geometry_tokens`` / ``camera_tokens`` entries may
        be None when the backbone does not produce them.
        """
        encoded = self.backbone.encode_images(
            images,
            # All depth-related inputs are gated together by use_depth.
            depths=depths if use_depth else None,
            depth_valid=depth_valid if use_depth else None,
            camera_intrinsics=camera_intrinsics if use_depth else None,
            camera_extrinsics=camera_extrinsics if use_depth else None,
            return_aux=True,
            use_depth_tokens=use_depth,
            use_geometry_tokens=use_geometry_tokens,
            use_camera_pose_tokens=use_camera_pose_tokens,
        )
        assert isinstance(encoded, dict)
        text_tokens = self._encode_language(images, texts=texts, language_tokens=language_tokens)
        fused = self.fusion(
            image_tokens=encoded["rgb_tokens"],
            proprio=proprio,
            language_tokens=text_tokens,
            depth_tokens=encoded.get("depth_tokens"),
            geometry_tokens=encoded.get("geometry_tokens"),
            camera_tokens=encoded.get("camera_tokens"),
            return_aux=True,
        )
        assert isinstance(fused, dict)
        return {
            "scene_tokens": fused["scene_tokens"],
            "view_summaries": fused["view_summaries"],
            "geometry_summaries": fused["geometry_summaries"],
            "depth_tokens": encoded.get("depth_tokens"),
            "geometry_tokens": encoded.get("geometry_tokens"),
            "camera_tokens": encoded.get("camera_tokens"),
        }

    # NOTE(review): identical to the inherited BackboneOnlyPolicy
    # implementation — this override is redundant and could be removed.
    def _expand_language_tokens_for_history(
        self,
        language_tokens: dict[str, Tensor] | None,
        history_steps: int,
    ) -> dict[str, Tensor] | None:
        """Repeat each (B, ...) token tensor per history step -> (B * T, ...)."""
        if language_tokens is None:
            return None
        return {
            key: value.unsqueeze(1).expand(-1, history_steps, *value.shape[1:]).reshape(
                value.shape[0] * history_steps, *value.shape[1:]
            )
            for key, value in language_tokens.items()
        }

    def encode_history_with_optional_depth(
        self,
        history_images: Tensor | None,
        history_proprio: Tensor | None,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_depths: Tensor | None = None,
        history_depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        use_depth: bool = True,
        use_geometry_tokens: bool | None = None,
        use_camera_pose_tokens: bool | None = None,
    ) -> Tensor | None:
        """Encode a (B, T, ...) history with optional per-step depth.

        Flattens history to (B * T, ...), encodes each step, and reshapes
        back to (B, T, tokens, dim). Returns None when no history is given.
        """
        if history_images is None or history_proprio is None or history_images.numel() == 0:
            return None
        batch_size, history_steps = history_images.shape[:2]
        flat_images = history_images.reshape(batch_size * history_steps, *history_images.shape[2:])
        flat_proprio = history_proprio.reshape(batch_size * history_steps, history_proprio.shape[-1])
        flat_depths = None
        flat_depth_valid = None
        if history_depths is not None and history_depths.numel() > 0:
            flat_depths = history_depths.reshape(batch_size * history_steps, *history_depths.shape[2:])
        if history_depth_valid is not None and history_depth_valid.numel() > 0:
            flat_depth_valid = history_depth_valid.reshape(batch_size * history_steps, *history_depth_valid.shape[2:])
        if language_tokens is None:
            # Unlike the base encode_history, a fully-None language input is
            # tolerated here and passed through as texts=None.
            flat_texts = [text for text in texts for _ in range(history_steps)] if texts is not None else None
            flat_language_tokens = None
        else:
            flat_texts = None
            flat_language_tokens = self._expand_language_tokens_for_history(language_tokens, history_steps)
        # NOTE(review): camera_intrinsics/extrinsics are accepted above but
        # deliberately(?) not forwarded to per-step encoding — confirm intent.
        history_scene = self._encode_scene_with_optional_depth(
            images=flat_images,
            proprio=flat_proprio,
            texts=flat_texts,
            language_tokens=flat_language_tokens,
            depths=flat_depths,
            depth_valid=flat_depth_valid,
            camera_intrinsics=None,
            camera_extrinsics=None,
            use_depth=use_depth,
            use_geometry_tokens=use_geometry_tokens,
            use_camera_pose_tokens=use_camera_pose_tokens,
        )["scene_tokens"]
        return history_scene.view(batch_size, history_steps, history_scene.shape[1], history_scene.shape[2])

    def _tile_tensor(self, value: Tensor, num_candidates: int) -> Tensor:
        """Repeat per candidate: (B, ...) -> (B * num_candidates, ...)."""
        return value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(
            value.shape[0] * num_candidates,
            *value.shape[1:],
        )

    def _tile_state(self, state: dict[str, Tensor], num_candidates: int) -> dict[str, Tensor]:
        """Tile tensor entries of a state dict; non-tensor entries are dropped."""
        tiled: dict[str, Tensor] = {}
        for key, value in state.items():
            if isinstance(value, Tensor):
                tiled[key] = self._tile_tensor(value, num_candidates)
        return tiled

    def _detach_state(self, state: dict[str, Tensor]) -> dict[str, Tensor]:
        """Detach tensor entries from the graph; pass others through unchanged."""
        detached: dict[str, Tensor] = {}
        for key, value in state.items():
            detached[key] = value.detach() if isinstance(value, Tensor) else value
        return detached

    def _repeat_rollout_tensor(self, value: Tensor, num_candidates: int, horizon: int) -> Tensor:
        """Detach and broadcast (B, ...) -> (B, num_candidates, horizon, ...)."""
        value = value.detach()
        return value.unsqueeze(1).unsqueeze(2).expand(-1, num_candidates, horizon, *value.shape[1:])

    def _zero_memory_output(self, scene_tokens: Tensor) -> dict[str, Tensor]:
        """Build an all-zero memory output matching DualObservationMemory's
        schema, for the ``use_memory=False`` ablation."""
        batch_size, _, hidden_dim = scene_tokens.shape
        scene_memory_tokens = scene_tokens.new_zeros((batch_size, self.config.memory.scene_bank_size, hidden_dim))
        belief_memory_tokens = scene_tokens.new_zeros((batch_size, self.config.memory.belief_bank_size, hidden_dim))
        memory_tokens = torch.cat([scene_memory_tokens, belief_memory_tokens], dim=1)
        return {
            "scene_memory_tokens": scene_memory_tokens,
            "belief_memory_tokens": belief_memory_tokens,
            "memory_tokens": memory_tokens,
            "memory_token": memory_tokens.mean(dim=1, keepdim=True),
            "memory_sequence": scene_tokens.new_zeros((batch_size, 0, hidden_dim)),
            "memory_state": scene_tokens.new_zeros((batch_size, hidden_dim * 2)),
            "memory_uncertainty": scene_tokens.new_zeros((batch_size,)),
            "memory_write_rate": scene_tokens.new_zeros((batch_size,)),
            "memory_saturation": scene_tokens.new_zeros((batch_size,)),
            "scene_write_gate": scene_tokens.new_zeros((batch_size, self.config.memory.scene_bank_size)),
            "belief_write_gate": scene_tokens.new_zeros((batch_size, self.config.memory.belief_bank_size)),
            "memory_scene_state": scene_tokens.new_zeros((batch_size, hidden_dim)),
            "memory_belief_state": scene_tokens.new_zeros((batch_size, hidden_dim)),
        }

    def _identity_rollout(
        self,
        interaction_state: dict[str, Tensor],
        num_candidates: int,
    ) -> dict[str, Tensor]:
        """Fake rollout for the ``use_world_model=False`` ablation: repeat the
        current (detached) state over candidates and the rollout horizon."""
        horizon = max(1, self.config.world_model.rollout_horizon)
        rollout: dict[str, Tensor] = {}
        for key, value in interaction_state.items():
            if isinstance(value, Tensor):
                rollout[key] = self._repeat_rollout_tensor(value, num_candidates, horizon)
        return rollout

    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        history_images: Tensor | None = None,
        history_proprio: Tensor | None = None,
        history_actions: Tensor | None = None,
        plan: bool = True,
        support_mode_conditioning: bool = True,
        candidate_chunks_override: Tensor | None = None,
        use_depth: bool = True,
        use_world_model: bool = True,
        use_planner: bool = True,
        use_role_tokens: bool = True,
        history_steps_override: int | None = None,
        depths: Tensor | None = None,
        depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        history_depths: Tensor | None = None,
        history_depth_valid: Tensor | None = None,
        compute_equivariance_probe: bool = False,
        use_geometry_tokens: bool | None = None,
        use_camera_pose_tokens: bool | None = None,
        use_memory: bool = True,
        use_task_conditioning: bool = True,
        rollout_mode_override: str | None = None,
        use_proposal_candidates: bool = True,
    ) -> dict[str, Tensor]:
        """Full encode -> state -> decode -> shortlist -> rollout -> select
        pipeline, with per-stage ablation switches.

        ``outputs["rollout_source"]`` records which path produced the planned
        rollout: "none" (no planning), "identity" (world model disabled), or
        "learned" (world-model rollout).
        """
        task_names = self._task_names(images.shape[0], texts=texts)
        scene_output = self._encode_scene_with_optional_depth(
            images=images,
            proprio=proprio,
            texts=texts,
            language_tokens=language_tokens,
            depths=depths,
            depth_valid=depth_valid,
            camera_intrinsics=camera_intrinsics,
            camera_extrinsics=camera_extrinsics,
            use_depth=use_depth,
            use_geometry_tokens=use_geometry_tokens,
            use_camera_pose_tokens=use_camera_pose_tokens,
        )
        scene_tokens = scene_output["scene_tokens"]
        history_scene_tokens = self.encode_history_with_optional_depth(
            history_images=history_images,
            history_proprio=history_proprio,
            texts=texts,
            language_tokens=language_tokens,
            history_depths=history_depths,
            history_depth_valid=history_depth_valid,
            camera_intrinsics=camera_intrinsics,
            camera_extrinsics=camera_extrinsics,
            use_depth=use_depth,
            use_geometry_tokens=use_geometry_tokens,
            use_camera_pose_tokens=use_camera_pose_tokens,
        )
        if history_steps_override is not None and history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            # Keep only the last N steps, truncating actions in lockstep.
            history_scene_tokens = history_scene_tokens[:, -history_steps_override:]
            if history_actions is not None and history_actions.numel() > 0:
                history_actions = history_actions[:, -history_steps_override:]
        if use_memory:
            memory_output = self.memory(
                scene_tokens,
                history_scene_tokens=history_scene_tokens,
                history_actions=history_actions,
            )
        else:
            # Ablation: zeroed memory with the same schema as the real module.
            memory_output = self._zero_memory_output(scene_tokens)
        elastic_state = self.elastic_state_head(
            scene_tokens,
            memory_tokens=memory_output["memory_tokens"],
            task_names=task_names,
            use_task_conditioning=use_task_conditioning,
        )
        # Attach memory views so downstream consumers get one state dict.
        elastic_state["memory_tokens"] = memory_output["memory_tokens"]
        elastic_state["memory_token"] = memory_output["memory_token"]
        elastic_state["scene_memory_tokens"] = memory_output["scene_memory_tokens"]
        elastic_state["belief_memory_tokens"] = memory_output["belief_memory_tokens"]
        if not use_role_tokens:
            # Shallow-copy before zeroing so the head's own dict is untouched.
            elastic_state = dict(elastic_state)
            elastic_state["arm_role_logits"] = torch.zeros_like(elastic_state["arm_role_logits"])
        decoded = self.decoder(
            scene_tokens,
            interaction_state=elastic_state,
            memory_tokens=memory_output["memory_tokens"],
            compute_equivariance_probe=compute_equivariance_probe,
            task_names=task_names,
        )
        outputs = {
            **decoded,
            "scene_tokens": scene_tokens,
            "history_scene_tokens": history_scene_tokens,
            "memory_output": memory_output,
            "memory_uncertainty": memory_output["memory_uncertainty"],
            "interaction_state": elastic_state,
            # Alias kept for compatibility with reveal-policy consumers.
            "reveal_state": elastic_state,
            "view_summaries": scene_output["view_summaries"],
            "geometry_summaries": scene_output["geometry_summaries"],
            "depth_tokens": scene_output["depth_tokens"],
            "geometry_tokens": scene_output["geometry_tokens"],
            "camera_tokens": scene_output["camera_tokens"],
            "rollout_source": "none",
            "task_names": task_names,
        }
        candidate_chunks = candidate_chunks_override
        proposal_logits = outputs.get("proposal_logits")
        if candidate_chunks is None:
            candidate_chunks = self.decoder.sample_candidates(
                outputs["action_mean"],
                outputs["action_log_std"],
                num_candidates=self.config.decoder.num_candidates,
                proposal_candidates=outputs.get("proposal_candidates") if use_proposal_candidates else None,
            )
            if not use_proposal_candidates:
                proposal_logits = None
        else:
            # Override chunks did not come from the proposal head; its logits
            # do not apply to them.
            proposal_logits = None
        outputs["candidate_chunks"] = candidate_chunks
        if not plan or not use_planner:
            # No planning: return the mean action with zeroed diagnostics.
            outputs["planned_chunk"] = outputs["action_mean"]
            outputs["planned_rollout"] = {}
            outputs["planner_success_logits"] = torch.zeros(
                candidate_chunks.shape[:2],
                device=candidate_chunks.device,
                dtype=candidate_chunks.dtype,
            )
            outputs["planner_risk_values"] = torch.zeros_like(outputs["planner_success_logits"])
            outputs["planner_scores"] = torch.zeros_like(outputs["planner_success_logits"])
            outputs["best_candidate_indices"] = torch.zeros(
                candidate_chunks.shape[0],
                dtype=torch.long,
                device=candidate_chunks.device,
            )
            return outputs
        # Cascade stage 1: shortlist candidates before the expensive rollout.
        shortlist_indices = self.planner.shortlist(
            proposal_logits=proposal_logits,
            candidate_chunks=candidate_chunks,
            proposal_mode_assignments=outputs.get("proposal_mode_assignments") if use_proposal_candidates else None,
        )
        outputs["planner_topk_indices"] = shortlist_indices
        batch_size = candidate_chunks.shape[0]
        # Advanced indexing: pick each batch row's shortlisted candidates.
        batch_indices = torch.arange(batch_size, device=candidate_chunks.device).unsqueeze(-1)
        topk_candidates = candidate_chunks[batch_indices, shortlist_indices]
        num_topk = topk_candidates.shape[1]
        outputs["planner_topk_candidates"] = topk_candidates
        proposal_mode_names = outputs.get("proposal_mode_names")
        topk_proposal_mode_names = None
        if proposal_mode_names is not None and use_proposal_candidates:
            # Nested python lists (not tensors), so gather by hand.
            topk_proposal_mode_names = [
                [proposal_mode_names[batch_idx][int(candidate_idx.item())] for candidate_idx in shortlist_indices[batch_idx]]
                for batch_idx in range(batch_size)
            ]
        outputs["planner_topk_mode_names"] = topk_proposal_mode_names
        if proposal_logits is not None:
            topk_proposal_logits = proposal_logits.gather(1, shortlist_indices)
        else:
            topk_proposal_logits = None
        planning_state = elastic_state
        if not support_mode_conditioning:
            planning_state = dict(elastic_state)
            planning_state["support_mode_logits"] = torch.zeros_like(elastic_state["support_mode_logits"])
        if not use_world_model:
            # Ablation: score candidates against a frozen copy of the current
            # state instead of a learned rollout.
            detached_state = self._detach_state(planning_state)
            identity_rollout = self._identity_rollout(
                interaction_state=detached_state,
                num_candidates=num_topk,
            )
            selected = self.planner.select_best(
                initial_state=detached_state,
                candidate_chunks=topk_candidates,
                rollout_state=identity_rollout,
                proposal_logits=topk_proposal_logits,
                candidate_indices=shortlist_indices,
                proposal_mode_names=topk_proposal_mode_names,
            )
            outputs["planned_rollout"] = identity_rollout
            outputs["planned_chunk"] = selected["best_chunk"]
            outputs["planner_success_logits"] = selected["success_logits"]
            outputs["planner_risk_values"] = selected["risk_values"]
            outputs["planner_scores"] = selected["utility_total"]
            outputs["best_candidate_indices"] = selected["best_indices"]
            outputs["utility_structured"] = selected["utility_structured"]
            outputs["utility_residual"] = selected["utility_residual"]
            outputs["utility_total"] = selected["utility_total"]
            outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
            outputs["rollout_source"] = "identity"
            return outputs
        # Cascade stage 2: learned world-model rollout over shortlisted chunks.
        flat_chunks = topk_candidates.view(batch_size * num_topk, topk_candidates.shape[2], topk_candidates.shape[3])
        tiled_scene = self._tile_tensor(scene_tokens, num_topk)
        tiled_state = self._tile_state(planning_state, num_topk)
        rollout = self.world_model(
            scene_tokens=tiled_scene,
            interaction_state=tiled_state,
            action_chunk=flat_chunks,
            memory_tokens=self._tile_tensor(memory_output["memory_tokens"], num_topk),
            scene_memory_tokens=self._tile_tensor(memory_output["scene_memory_tokens"], num_topk),
            belief_memory_tokens=self._tile_tensor(memory_output["belief_memory_tokens"], num_topk),
            # Repeat each task name per candidate to match the tiled batch.
            task_names=[name for name in task_names for _ in range(num_topk)],
            rollout_mode_override=rollout_mode_override,
        )
        reshaped_rollout = {
            key: value.view(batch_size, num_topk, *value.shape[1:]) for key, value in rollout.items()
        }
        selected = self.planner.select_best(
            initial_state=elastic_state,
            candidate_chunks=topk_candidates,
            rollout_state=reshaped_rollout,
            proposal_logits=topk_proposal_logits,
            candidate_indices=shortlist_indices,
            proposal_mode_names=topk_proposal_mode_names,
        )
        outputs["planned_rollout"] = reshaped_rollout
        outputs["planned_chunk"] = selected["best_chunk"]
        outputs["planner_success_logits"] = selected["success_logits"]
        outputs["planner_risk_values"] = selected["risk_values"]
        outputs["planner_scores"] = selected["utility_total"]
        outputs["best_candidate_indices"] = selected["best_indices"]
        outputs["utility_structured"] = selected["utility_structured"]
        outputs["utility_residual"] = selected["utility_residual"]
        outputs["utility_total"] = selected["utility_total"]
        outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
        outputs["rollout_source"] = "learned"
        return outputs