# Provenance: 2026-03-25 runpod handoff update — commit e7d8e79 (verified), uploaded by lsnu.
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
import torch
import torch.nn.functional as F
from torch import Tensor, nn
# Task families the task-conditioned heads can specialise to; index 0
# ("generic") doubles as the fallback for unknown task names.
HEAD_TASKS = ("generic", "foliage", "bag", "cloth")

# Names of the scalar task metrics produced by compute_task_metrics_from_fields
# (and re-emitted, with residual corrections, by the occlusion decoder).
TASK_METRIC_NAMES = (
    "opening_quality",
    "actor_feasibility_score",
    "gap_width",
    "damage_proxy",
    "release_collapse_rate",
    "target_visibility_confidence",
    "mouth_aperture",
    "hold_quality",
    "rim_slip_risk",
    "insertable_actor_corridor",
    "layer_separation_quality",
    "fold_preservation",
    "insertion_corridor",
    "top_layer_stability",
    "lift_too_much_risk",
)

# Reverse lookup: task name -> embedding index into HEAD_TASKS.
TASK_INDEX = dict(zip(HEAD_TASKS, range(len(HEAD_TASKS))))


def task_ids_from_names(task_names: Sequence[str] | None, device: torch.device, batch_size: int) -> Tensor:
    """Map task names to a tensor of task ids; unknown names fall back to id 0.

    Args:
        task_names: Optional sequence of task-name strings, one per batch item.
        device: Device on which the id tensor is allocated.
        batch_size: Number of ids produced when ``task_names`` is ``None``.

    Returns:
        A ``torch.long`` tensor of task indices.
    """
    if task_names is None:
        # No conditioning requested: every sample maps to "generic" (id 0).
        return torch.zeros(batch_size, device=device, dtype=torch.long)
    resolved = [TASK_INDEX.get(str(name), 0) for name in task_names]
    return torch.as_tensor(resolved, device=device, dtype=torch.long)
def _mean_map(value: Tensor) -> Tensor:
return value.mean(dim=(-1, -2)).squeeze(1)
def compute_task_metrics_from_fields(
    *,
    access_field: Tensor,
    persistence_field: Tensor,
    disturbance_field: Tensor,
    reocclusion_field: Tensor,
    visibility_field: Tensor,
    clearance_field: Tensor,
    support_stability_field: Tensor,
    uncertainty_field: Tensor,
) -> dict[str, Tensor]:
    """Derive shared scalar task metrics plus diagnostic maps from raw fields.

    ``access_field``, ``visibility_field``, ``clearance_field`` and
    ``support_stability_field`` are treated as logits (sigmoid applied here);
    ``persistence_field``, ``reocclusion_field`` and ``disturbance_field`` are
    used as-is; ``uncertainty_field`` is squashed into [0, 1). Scalar metrics
    are clamped to [0, 1], except ``gap_width`` which spans [0.03, 0.24].
    """

    def pooled(field: Tensor) -> Tensor:
        # Local copy of the module-level _mean_map reduction: (B, 1, H, W) -> (B,).
        return field.mean(dim=(-1, -2)).squeeze(1)

    # Probabilities from the logit-valued fields.
    access_probability = torch.sigmoid(access_field)
    support_probability = torch.sigmoid(support_stability_field)
    visibility_probability = torch.sigmoid(visibility_field)
    clearance_probability = torch.sigmoid(clearance_field).mean(dim=1, keepdim=True)

    # Best access score across modes at each cell; bounded uncertainty.
    opening_mask = access_probability.amax(dim=1, keepdim=True)
    squashed_uncertainty = uncertainty_field / (1.0 + uncertainty_field)

    # Diagnostic maps exposed alongside the scalars.
    opening_quality_field = opening_mask * persistence_field * support_probability
    newly_revealed_field = torch.relu(visibility_probability - reocclusion_field)
    still_visible_field = visibility_probability * persistence_field
    reoccluded_field = reocclusion_field

    # Pooled scalar ingredients.
    opening_quality = pooled(opening_quality_field)
    actor_feasibility_score = 0.6 * pooled(clearance_probability) + 0.4 * pooled(opening_mask)
    base_gap = pooled(opening_mask)
    disturbance_cost = pooled(disturbance_field)
    support_quality = pooled(support_probability)
    visibility_confidence = pooled(visibility_probability * (1.0 - squashed_uncertainty))
    reocclusion_rate = pooled(reocclusion_field)
    persistence_score = pooled(persistence_field)

    return {
        "newly_revealed_field": newly_revealed_field,
        "still_visible_field": still_visible_field,
        "reoccluded_field": reoccluded_field,
        "opening_quality_field": opening_quality_field,
        "opening_quality": torch.clamp(opening_quality, 0.0, 1.0),
        "actor_feasibility_score": torch.clamp(actor_feasibility_score, 0.0, 1.0),
        "gap_width": 0.03 + 0.21 * torch.clamp(base_gap, 0.0, 1.0),
        "damage_proxy": torch.clamp(disturbance_cost + 0.5 * (1.0 - support_quality), 0.0, 1.0),
        "release_collapse_rate": torch.clamp(reocclusion_rate, 0.0, 1.0),
        "target_visibility_confidence": torch.clamp(visibility_confidence, 0.0, 1.0),
        "mouth_aperture": torch.clamp(base_gap, 0.0, 1.0),
        "hold_quality": torch.clamp(0.5 * (persistence_score + support_quality), 0.0, 1.0),
        "rim_slip_risk": torch.clamp(reocclusion_rate + 0.5 * (1.0 - support_quality), 0.0, 1.0),
        "insertable_actor_corridor": torch.clamp(0.6 * actor_feasibility_score + 0.4 * base_gap, 0.0, 1.0),
        "layer_separation_quality": torch.clamp(0.7 * base_gap + 0.3 * actor_feasibility_score, 0.0, 1.0),
        "fold_preservation": torch.clamp(1.0 - disturbance_cost, 0.0, 1.0),
        "insertion_corridor": torch.clamp(0.5 * actor_feasibility_score + 0.5 * base_gap, 0.0, 1.0),
        "top_layer_stability": torch.clamp(support_quality, 0.0, 1.0),
        "lift_too_much_risk": torch.clamp(disturbance_cost + 0.5 * torch.relu(base_gap - 0.5), 0.0, 1.0),
    }
@dataclass
class RevealHeadConfig:
    """Shared hyper-parameters for the reveal/interaction/occlusion state heads."""

    hidden_dim: int = 512  # width of token embeddings and field features
    num_support_modes: int = 3  # classes for the support-mode heads
    num_approach_templates: int = 32  # length of the 1-D corridor-logit vector
    rollout_horizon: int = 5  # scale factor for persistence-horizon estimates
    belief_map_size: int = 32  # side length belief maps are resized to
    field_size: int = 16  # side length of the attended feature grid
    num_heads: int = 4  # attention heads in the query cross-attention
    predict_belief_map: bool = False  # whether "belief_map" appears in outputs
    num_phases: int = 5  # classes for the phase heads
    num_arm_roles: int = 4  # classes for the per-arm role heads
    num_interaction_tokens: int = 8  # learned queries in the *StateHead wrappers
    num_tasks: int = len(HEAD_TASKS)  # size of the task-embedding table
class RevealStateHead(nn.Module):
    """Decode scene (+ optional memory) tokens into reveal-state fields.

    A learned grid of ``field_size**2`` queries cross-attends to the source
    tokens; the attended grid is read out with 1x1 convolutions into spatial
    fields, alongside pooled classification/summary heads.
    """

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        # One learned query per cell of the field grid, embedding-style small init.
        self.field_queries = nn.Parameter(
            torch.randn(config.field_size * config.field_size, config.hidden_dim) * 0.02
        )
        self.field_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual MLP applied to the attended field tokens.
        self.field_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Support-mode classifier over concatenated scene/field/memory pools.
        self.support_mode = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        # Per-cell 1x1-conv heads over the feature grid.
        self.access_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.persistence_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.disturbance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.uncertainty_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        # Reocclusion logits from pooled field + memory features.
        self.reocclusion_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        self.latent_summary = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
        )

    def forward(self, scene_tokens: Tensor, memory_token: Tensor | None = None) -> dict[str, Tensor]:
        """Return a dict of spatial fields and pooled heads decoded from the tokens.

        Args:
            scene_tokens: ``(batch, num_tokens, hidden_dim)`` scene features.
            memory_token: Optional memory tokens appended to the attention
                source; presumably ``(batch, 1, hidden_dim)`` since the pooled
                path relies on ``squeeze(1)`` — TODO confirm with callers.
        """
        source_tokens = scene_tokens if memory_token is None else torch.cat([scene_tokens, memory_token], dim=1)
        batch_size = source_tokens.shape[0]
        field_queries = self.field_queries.unsqueeze(0).expand(batch_size, -1, -1)
        field_tokens, _ = self.field_attention(field_queries, source_tokens, source_tokens)
        field_tokens = field_tokens + self.field_mlp(field_tokens)
        side = self.config.field_size
        # (B, side*side, D) -> (B, D, side, side) so Conv2d heads can consume it.
        grid = field_tokens.transpose(1, 2).reshape(batch_size, self.config.hidden_dim, side, side)
        pooled_scene = scene_tokens.mean(dim=1)
        pooled_field = field_tokens.mean(dim=1)
        if memory_token is not None:
            pooled_memory = memory_token.squeeze(1)
        else:
            # Zero vector stands in for "no memory" so downstream concats keep shape.
            pooled_memory = pooled_scene.new_zeros(pooled_scene.shape)
        support_input = torch.cat([pooled_scene, pooled_field, pooled_memory], dim=-1)
        access_field = self.access_field(grid)  # logits, one channel per support mode
        persistence_field = torch.sigmoid(self.persistence_field(grid))
        disturbance_field = torch.sigmoid(self.disturbance_field(grid))
        uncertainty_field = F.softplus(self.uncertainty_field(grid))  # non-negative
        # Collapse grid rows, then resample columns into approach-template bins.
        corridor_source = access_field.amax(dim=-2)
        corridor_logits = F.interpolate(
            corridor_source,
            size=self.config.num_approach_templates,
            mode="linear",
            align_corners=False,
        )
        access_prob = torch.sigmoid(access_field)
        # Access-weighted mean persistence, scaled by the configured horizon.
        weighted_persistence = (persistence_field * access_prob).sum(dim=(-1, -2))
        access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
        persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
        disturbance_cost = disturbance_field.mean(dim=(-1, -2)).squeeze(1)
        # Belief map = per-cell max access logit, resized to belief_map_size.
        belief_map = access_field.max(dim=1, keepdim=True).values
        if belief_map.shape[-1] != self.config.belief_map_size:
            belief_map = F.interpolate(
                belief_map,
                size=(self.config.belief_map_size, self.config.belief_map_size),
                mode="bilinear",
                align_corners=False,
            )
        latent_summary = self.latent_summary(torch.cat([pooled_scene, pooled_field], dim=-1))
        output = {
            "support_mode_logits": self.support_mode(support_input),
            "corridor_logits": corridor_logits,
            "persistence_horizon": persistence_horizon,
            "disturbance_cost": disturbance_cost,
            "access_field": access_field,
            "persistence_field": persistence_field,
            "disturbance_field": disturbance_field,
            "uncertainty_field": uncertainty_field,
            "field_tokens": field_tokens,
            "latent_summary": latent_summary,
            "reocclusion_logit": self.reocclusion_head(torch.cat([pooled_field, pooled_memory], dim=-1)),
            "persistence_uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
        }
        if self.config.predict_belief_map:
            output["belief_map"] = belief_map
        return output
class InteractionFieldDecoder(nn.Module):
    """Decode interaction tokens (optionally with scene/memory context) into
    spatial interaction fields plus pooled phase/arm-role/support heads."""

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        # One learned query per cell of the field grid.
        self.field_queries = nn.Parameter(
            torch.randn(config.field_size * config.field_size, config.hidden_dim) * 0.02
        )
        self.field_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual MLP applied to the attended field tokens.
        self.field_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Pooled summary = [interaction, field, scene, memory] concatenation.
        summary_dim = config.hidden_dim * 4
        self.summary_proj = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.phase_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_phases),
        )
        # Per-arm role classifier over [arm token, summary] pairs.
        self.arm_role_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_arm_roles),
        )
        # Additive identity embeddings distinguishing the two arms.
        self.arm_identity = nn.Embedding(2, config.hidden_dim)
        self.support_mode = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        # 1x1-conv field heads over the attended grid.
        self.target_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.actor_feasibility_field = nn.Conv2d(config.hidden_dim, 2, kernel_size=1)
        self.persistence_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.risk_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.uncertainty_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        # Heads mirroring RevealStateHead's access/persistence outputs.
        self.compat_access_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.compat_persistence = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.reocclusion_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )

    def _pool_source(self, source_tokens: Tensor | None, fallback: Tensor) -> Tensor:
        """Mean-pool tokens over the sequence dim; zeros shaped like
        ``fallback`` when the source is absent or empty."""
        if source_tokens is None or source_tokens.numel() == 0:
            return fallback.new_zeros(fallback.shape)
        return source_tokens.mean(dim=1)

    def forward(
        self,
        interaction_tokens: Tensor,
        scene_tokens: Tensor | None = None,
        memory_tokens: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode spatial fields and pooled heads from interaction tokens.

        Args:
            interaction_tokens: ``(batch, num_tokens, hidden_dim)``; when at
                least two tokens are present, the first two are treated as the
                per-arm tokens.
            scene_tokens: Optional extra attention context.
            memory_tokens: Optional extra attention context.
        """
        batch_size = interaction_tokens.shape[0]
        pooled_interaction = interaction_tokens.mean(dim=1)
        pooled_scene = self._pool_source(scene_tokens, pooled_interaction)
        pooled_memory = self._pool_source(memory_tokens, pooled_interaction)
        field_queries = self.field_queries.unsqueeze(0).expand(batch_size, -1, -1)
        # Attention source: interaction tokens plus whatever context exists.
        source_tokens = interaction_tokens
        if scene_tokens is not None:
            source_tokens = torch.cat([source_tokens, scene_tokens], dim=1)
        if memory_tokens is not None:
            source_tokens = torch.cat([source_tokens, memory_tokens], dim=1)
        field_tokens, _ = self.field_attention(field_queries, source_tokens, source_tokens)
        field_tokens = field_tokens + self.field_mlp(field_tokens)
        side = self.config.field_size
        # (B, side*side, D) -> (B, D, side, side) for the Conv2d heads.
        grid = field_tokens.transpose(1, 2).reshape(batch_size, self.config.hidden_dim, side, side)
        pooled_field = field_tokens.mean(dim=1)
        summary_input = torch.cat([pooled_interaction, pooled_field, pooled_scene, pooled_memory], dim=-1)
        summary = self.summary_proj(summary_input)
        target_field = self.target_field(grid)  # logits
        actor_feasibility_field = self.actor_feasibility_field(grid)  # logits
        persistence_field = torch.sigmoid(self.persistence_field(grid))
        risk_field = torch.sigmoid(self.risk_field(grid))
        uncertainty_field = F.softplus(self.uncertainty_field(grid))  # non-negative
        access_field = self.compat_access_field(grid)
        # Collapse grid rows, then resample columns into approach-template bins.
        corridor_source = access_field.amax(dim=-2)
        corridor_logits = F.interpolate(
            corridor_source,
            size=self.config.num_approach_templates,
            mode="linear",
            align_corners=False,
        )
        compatibility_persistence = torch.sigmoid(self.compat_persistence(grid))
        access_prob = torch.sigmoid(access_field)
        # Access-weighted mean persistence, scaled by the rollout horizon.
        weighted_persistence = (compatibility_persistence * access_prob).sum(dim=(-1, -2))
        access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
        persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
        disturbance_cost = risk_field.mean(dim=(-1, -2)).squeeze(1)
        # Belief map = raw target-field logits, resized to belief_map_size.
        belief_map = target_field
        if belief_map.shape[-1] != self.config.belief_map_size:
            belief_map = F.interpolate(
                belief_map,
                size=(self.config.belief_map_size, self.config.belief_map_size),
                mode="bilinear",
                align_corners=False,
            )
        arm_identity = self.arm_identity.weight.unsqueeze(0).expand(batch_size, -1, -1)
        if interaction_tokens.shape[1] >= 2:
            # First two interaction tokens act as the per-arm tokens.
            arm_tokens = interaction_tokens[:, :2] + arm_identity
        else:
            # Too few tokens: duplicate the pooled token for both arms.
            arm_tokens = pooled_interaction.unsqueeze(1).expand(-1, 2, -1) + arm_identity
        arm_role_input = torch.cat(
            [arm_tokens, summary.unsqueeze(1).expand(-1, arm_tokens.shape[1], -1)],
            dim=-1,
        )
        arm_role_logits = self.arm_role_head(arm_role_input)
        reocclusion_logit = self.reocclusion_head(summary_input)
        output = {
            "phase_logits": self.phase_head(summary_input),
            "arm_role_logits": arm_role_logits,
            "target_field": target_field,
            "actor_feasibility_field": actor_feasibility_field,
            "persistence_field": persistence_field,
            "risk_field": risk_field,
            "uncertainty_field": uncertainty_field,
            "interaction_tokens": interaction_tokens,
            "field_tokens": field_tokens,
            "latent_summary": summary,
            "support_mode_logits": self.support_mode(summary_input),
            "corridor_logits": corridor_logits,
            "persistence_horizon": persistence_horizon,
            "disturbance_cost": disturbance_cost,
            "belief_map": belief_map,
            "reocclusion_logit": reocclusion_logit,
            "persistence_uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
            "access_field": access_field,
            # Alias: the risk field is also exposed under the disturbance key.
            "disturbance_field": risk_field,
            # Duplicate of "persistence_uncertainty", kept for caller compatibility.
            "uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
        }
        # Unlike RevealStateHead (which adds the key conditionally), the map is
        # always computed here and removed afterwards when not requested.
        if not self.config.predict_belief_map:
            output.pop("belief_map")
        return output
class InteractionStateHead(nn.Module):
    """Thin wrapper: attends learned interaction queries over scene (+ memory)
    tokens, then delegates all field decoding to an InteractionFieldDecoder."""

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        # Learned interaction query tokens, embedding-style small init.
        self.interaction_queries = nn.Parameter(
            torch.randn(config.num_interaction_tokens, config.hidden_dim) * 0.02
        )
        self.interaction_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual refinement of the attended interaction tokens.
        self.interaction_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        self.decoder = InteractionFieldDecoder(config)

    def forward(
        self,
        scene_tokens: Tensor,
        memory_token: Tensor | None = None,
        memory_tokens: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Build interaction tokens via cross-attention and run the decoder.

        ``memory_token`` is a legacy alias; ``memory_tokens`` wins when both
        are supplied.
        """
        resolved_memory = memory_token if memory_tokens is None else memory_tokens
        if resolved_memory is None:
            attention_source = scene_tokens
        else:
            attention_source = torch.cat([scene_tokens, resolved_memory], dim=1)
        batch = attention_source.shape[0]
        queries = self.interaction_queries.unsqueeze(0).expand(batch, -1, -1)
        attended, _ = self.interaction_attention(queries, attention_source, attention_source)
        attended = attended + self.interaction_mlp(attended)
        return self.decoder(
            interaction_tokens=attended,
            scene_tokens=scene_tokens,
            memory_tokens=resolved_memory,
        )
class ElasticOcclusionFieldDecoder(nn.Module):
    """Task-conditioned decoder producing occlusion-state fields, pooled heads,
    per-task scalar metrics, and a flat ``compact_state`` vector.

    Compared to InteractionFieldDecoder it adds a task-embedding pathway
    (FiLM-style modulation of the feature grid plus bounded residual
    corrections on the scalar metrics) and a richer set of 1x1-conv heads.
    """

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        # One learned query per cell of the field grid.
        self.field_queries = nn.Parameter(
            torch.randn(config.field_size * config.field_size, config.hidden_dim) * 0.02
        )
        self.field_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual MLP applied to the attended field tokens.
        self.field_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Pooled summary = [interaction, field, scene, memory] concatenation.
        summary_dim = config.hidden_dim * 4
        self.summary_proj = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.phase_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_phases),
        )
        # Per-arm role classifier over [arm token, task summary] pairs.
        self.arm_role_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_arm_roles),
        )
        # Additive identity embeddings distinguishing the two arms.
        self.arm_identity = nn.Embedding(2, config.hidden_dim)
        self.support_mode = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        # 1x1-conv field heads over the attended grid.
        self.access_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.target_belief_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.visibility_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.clearance_field = nn.Conv2d(config.hidden_dim, 2, kernel_size=1)
        self.occluder_contact_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.grasp_affordance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.support_stability_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.persistence_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.reocclusion_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.disturbance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.uncertainty_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.reocclusion_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        # Task-conditioning pathway: embedding -> affine modulation of the grid
        # plus residual adapters on the summary and the pooled heads.
        self.task_embedding = nn.Embedding(config.num_tasks, config.hidden_dim)
        self.task_field_affine = nn.Linear(config.hidden_dim, config.hidden_dim * 2)
        self.task_summary_adapter = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
        )
        self.task_phase_head = nn.Linear(config.hidden_dim, config.num_phases)
        self.task_support_head = nn.Linear(config.hidden_dim, config.num_support_modes)
        self.task_reocclusion_head = nn.Linear(config.hidden_dim, config.num_support_modes)
        # Predicts one bounded residual per metric in TASK_METRIC_NAMES.
        self.task_metric_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, len(TASK_METRIC_NAMES)),
        )

    def _pool_source(self, source_tokens: Tensor | None, fallback: Tensor) -> Tensor:
        """Mean-pool tokens over the sequence dim; zeros shaped like
        ``fallback`` when the source is absent or empty."""
        if source_tokens is None or source_tokens.numel() == 0:
            return fallback.new_zeros(fallback.shape)
        return source_tokens.mean(dim=1)

    def _field_mean(self, field: Tensor) -> Tensor:
        """Spatial mean over the last two (H, W) dims, keeping the channel dim."""
        return field.mean(dim=(-1, -2))

    def _upsampled_belief(self, target_belief_field: Tensor) -> Tensor:
        """Resize the belief field to ``belief_map_size`` (no-op if already there)."""
        if target_belief_field.shape[-1] == self.config.belief_map_size:
            return target_belief_field
        return F.interpolate(
            target_belief_field,
            size=(self.config.belief_map_size, self.config.belief_map_size),
            mode="bilinear",
            align_corners=False,
        )

    def forward(
        self,
        interaction_tokens: Tensor,
        scene_tokens: Tensor | None = None,
        memory_tokens: Tensor | None = None,
        task_names: Sequence[str] | None = None,
        use_task_conditioning: bool = True,
    ) -> dict[str, Tensor]:
        """Decode occlusion-state fields, pooled heads and per-task metrics.

        Args:
            interaction_tokens: ``(batch, num_tokens, hidden_dim)`` tokens.
            scene_tokens: Optional extra attention context.
            memory_tokens: Optional extra attention context.
            task_names: Optional per-sample task names; unknown names map to
                the "generic" task (id 0).
            use_task_conditioning: When False, the task pathway is bypassed
                except for the metric residual head (which always runs).
        """
        batch_size = interaction_tokens.shape[0]
        pooled_interaction = interaction_tokens.mean(dim=1)
        pooled_scene = self._pool_source(scene_tokens, pooled_interaction)
        pooled_memory = self._pool_source(memory_tokens, pooled_interaction)
        field_queries = self.field_queries.unsqueeze(0).expand(batch_size, -1, -1)
        # Attention source: interaction tokens plus whatever context exists.
        source_tokens = interaction_tokens
        if scene_tokens is not None:
            source_tokens = torch.cat([source_tokens, scene_tokens], dim=1)
        if memory_tokens is not None:
            source_tokens = torch.cat([source_tokens, memory_tokens], dim=1)
        field_tokens, _ = self.field_attention(field_queries, source_tokens, source_tokens)
        field_tokens = field_tokens + self.field_mlp(field_tokens)
        side = self.config.field_size
        # (B, side*side, D) -> (B, D, side, side) for the Conv2d heads.
        grid = field_tokens.transpose(1, 2).reshape(batch_size, self.config.hidden_dim, side, side)
        pooled_field = field_tokens.mean(dim=1)
        summary_input = torch.cat([pooled_interaction, pooled_field, pooled_scene, pooled_memory], dim=-1)
        latent_summary = self.summary_proj(summary_input)
        task_ids = task_ids_from_names(task_names, interaction_tokens.device, batch_size)
        task_embed = self.task_embedding(task_ids)
        if use_task_conditioning:
            # Affine (FiLM-like) modulation, damped by 0.1 so task identity
            # nudges rather than dominates the shared features.
            scale, bias = self.task_field_affine(task_embed).chunk(2, dim=-1)
            grid = grid * (1.0 + 0.1 * scale.view(batch_size, self.config.hidden_dim, 1, 1))
            grid = grid + 0.1 * bias.view(batch_size, self.config.hidden_dim, 1, 1)
            task_summary = latent_summary + 0.1 * self.task_summary_adapter(torch.cat([latent_summary, task_embed], dim=-1))
        else:
            task_summary = latent_summary
        # Field heads: raw logits unless squashed below.
        access_field = self.access_field(grid)
        target_belief_field = self.target_belief_field(grid)
        visibility_field = self.visibility_field(grid)
        clearance_field = self.clearance_field(grid)
        occluder_contact_field = self.occluder_contact_field(grid)
        grasp_affordance_field = self.grasp_affordance_field(grid)
        support_stability_field = self.support_stability_field(grid)
        persistence_field = torch.sigmoid(self.persistence_field(grid))
        reocclusion_field = torch.sigmoid(self.reocclusion_field(grid))
        disturbance_field = torch.sigmoid(self.disturbance_field(grid))
        uncertainty_field = F.softplus(self.uncertainty_field(grid))
        task_metrics = compute_task_metrics_from_fields(
            access_field=access_field,
            persistence_field=persistence_field,
            disturbance_field=disturbance_field,
            reocclusion_field=reocclusion_field,
            visibility_field=visibility_field,
            clearance_field=clearance_field,
            support_stability_field=support_stability_field,
            uncertainty_field=uncertainty_field,
        )
        # tanh bounds each per-metric correction to [-0.05, 0.05].
        metric_residuals = 0.05 * torch.tanh(
            self.task_metric_head(torch.cat([task_summary, task_embed], dim=-1))
        )
        metric_residual_map = {
            name: metric_residuals[:, idx]
            for idx, name in enumerate(TASK_METRIC_NAMES)
        }
        support_stability_prob = torch.sigmoid(support_stability_field)
        # Combined risk: disturbance + reocclusion + instability + uncertainty,
        # re-squashed into (0, 1).
        risk_field = torch.sigmoid(
            disturbance_field
            + 0.75 * reocclusion_field
            + 0.5 * (1.0 - support_stability_prob)
            + 0.25 * uncertainty_field
        )
        # Collapse grid rows, then resample columns into approach-template bins.
        corridor_source = access_field.amax(dim=-2)
        corridor_logits = F.interpolate(
            corridor_source,
            size=self.config.num_approach_templates,
            mode="linear",
            align_corners=False,
        )
        access_prob = torch.sigmoid(access_field)
        # Access-weighted mean persistence, scaled by the rollout horizon;
        # the 1-channel persistence field is broadcast across support modes.
        weighted_persistence = (persistence_field.expand_as(access_prob) * access_prob).sum(dim=(-1, -2))
        access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
        persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
        disturbance_cost = disturbance_field.mean(dim=(-1, -2)).squeeze(1)
        # Unlike InteractionFieldDecoder, both arm tokens are always derived
        # from the pooled interaction token here.
        arm_identity = self.arm_identity.weight.unsqueeze(0).expand(batch_size, -1, -1)
        arm_tokens = pooled_interaction.unsqueeze(1).expand(-1, 2, -1) + arm_identity
        arm_role_input = torch.cat(
            [arm_tokens, task_summary.unsqueeze(1).expand(-1, 2, -1)],
            dim=-1,
        )
        arm_role_logits = self.arm_role_head(arm_role_input)
        target_belief_map = self._upsampled_belief(target_belief_field)
        # Flat per-sample vector: pooled field statistics + pooled heads.
        compact_components = [
            target_belief_field.mean(dim=(-1, -2)).squeeze(1),
            visibility_field.mean(dim=(-1, -2)).squeeze(1),
            clearance_field.mean(dim=(-1, -2)).mean(dim=1),
            occluder_contact_field.mean(dim=(-1, -2)).squeeze(1),
            grasp_affordance_field.mean(dim=(-1, -2)).squeeze(1),
            support_stability_prob.mean(dim=(-1, -2)).squeeze(1),
            persistence_field.mean(dim=(-1, -2)).squeeze(1),
            reocclusion_field.mean(dim=(-1, -2)).squeeze(1),
            disturbance_field.mean(dim=(-1, -2)).squeeze(1),
            risk_field.mean(dim=(-1, -2)).squeeze(1),
            uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
            # NOTE(review): the two transposes cancel each other out, so this is
            # just the per-mode spatial mean of access_prob — confirm intent.
            access_prob.mean(dim=(-1, -2)).transpose(0, 1).transpose(0, 1),
            self.support_mode(summary_input) + (self.task_support_head(task_summary) if use_task_conditioning else 0.0),
            self.phase_head(summary_input) + (self.task_phase_head(task_summary) if use_task_conditioning else 0.0),
            arm_role_logits.reshape(batch_size, -1),
        ]
        compact_state = torch.cat(
            [component if component.ndim > 1 else component.unsqueeze(-1) for component in compact_components],
            dim=-1,
        )
        output = {
            "phase_logits": self.phase_head(summary_input) + (self.task_phase_head(task_summary) if use_task_conditioning else 0.0),
            "arm_role_logits": arm_role_logits,
            "target_belief_field": target_belief_field,
            "visibility_field": visibility_field,
            "clearance_field": clearance_field,
            "occluder_contact_field": occluder_contact_field,
            "grasp_affordance_field": grasp_affordance_field,
            "support_stability_field": support_stability_field,
            "persistence_field": persistence_field,
            "reocclusion_field": reocclusion_field,
            "disturbance_field": disturbance_field,
            "risk_field": risk_field,
            "uncertainty_field": uncertainty_field,
            "interaction_tokens": interaction_tokens,
            "field_tokens": field_tokens,
            "latent_summary": task_summary,
            "support_mode_logits": self.support_mode(summary_input) + (self.task_support_head(task_summary) if use_task_conditioning else 0.0),
            "corridor_logits": corridor_logits,
            "persistence_horizon": persistence_horizon,
            "disturbance_cost": disturbance_cost,
            "belief_map": target_belief_map,
            "reocclusion_logit": self.reocclusion_head(summary_input) + (self.task_reocclusion_head(task_summary) if use_task_conditioning else 0.0),
            "persistence_uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
            "access_field": access_field,
            # Duplicate of "persistence_uncertainty", kept for caller compatibility.
            "uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
            "compact_state": compact_state,
            "task_ids": task_ids,
        }
        # Aliases matching InteractionFieldDecoder's key names.
        output["target_field"] = output["target_belief_field"]
        output["actor_feasibility_field"] = output["clearance_field"]
        output.update(
            {
                "newly_revealed_field": task_metrics["newly_revealed_field"],
                "still_visible_field": task_metrics["still_visible_field"],
                "reoccluded_field": task_metrics["reoccluded_field"],
                "opening_quality_field": task_metrics["opening_quality_field"],
            }
        )
        # Apply metric residuals; gap_width lives on its own [0.03, 0.24]
        # scale, so its residual is further damped by 0.01.
        for name in TASK_METRIC_NAMES:
            if name == "gap_width":
                output[name] = torch.clamp(task_metrics[name] + 0.01 * metric_residual_map[name], 0.0, 1.0)
            else:
                output[name] = torch.clamp(task_metrics[name] + metric_residual_map[name], 0.0, 1.0)
        return output
class ElasticOcclusionStateHead(nn.Module):
    """Wrapper that attends learned interaction queries over scene (+ memory)
    tokens, then runs the ElasticOcclusionFieldDecoder on the result."""

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        # Learned interaction query tokens, embedding-style small init.
        self.interaction_queries = nn.Parameter(
            torch.randn(config.num_interaction_tokens, config.hidden_dim) * 0.02
        )
        self.interaction_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual refinement of the attended interaction tokens.
        self.interaction_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        self.decoder = ElasticOcclusionFieldDecoder(config)

    def forward(
        self,
        scene_tokens: Tensor,
        memory_token: Tensor | None = None,
        memory_tokens: Tensor | None = None,
        task_names: Sequence[str] | None = None,
        use_task_conditioning: bool = True,
    ) -> dict[str, Tensor]:
        """Build interaction tokens via cross-attention and decode them.

        ``memory_token`` is a legacy alias; ``memory_tokens`` wins when both
        are supplied.
        """
        resolved_memory = memory_token if memory_tokens is None else memory_tokens
        attention_source = (
            scene_tokens
            if resolved_memory is None
            else torch.cat([scene_tokens, resolved_memory], dim=1)
        )
        batch = attention_source.shape[0]
        queries = self.interaction_queries.unsqueeze(0).expand(batch, -1, -1)
        attended, _ = self.interaction_attention(queries, attention_source, attention_source)
        attended = attended + self.interaction_mlp(attended)
        return self.decoder(
            interaction_tokens=attended,
            scene_tokens=scene_tokens,
            memory_tokens=resolved_memory,
            task_names=task_names,
            use_task_conditioning=use_task_conditioning,
        )