| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Sequence |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
|
|
|
|
# Task vocabulary for the task-conditioned heads below; index 0 ("generic") is
# the fallback id used for unknown or missing task names (see task_ids_from_names).
HEAD_TASKS = ("generic", "foliage", "bag", "cloth")
# Names of the scalar metrics produced per batch element by
# compute_task_metrics_from_fields; the tuple order fixes the column layout of
# the task-metric residual head in ElasticOcclusionFieldDecoder.
TASK_METRIC_NAMES = (
    "opening_quality",
    "actor_feasibility_score",
    "gap_width",
    "damage_proxy",
    "release_collapse_rate",
    "target_visibility_confidence",
    "mouth_aperture",
    "hold_quality",
    "rim_slip_risk",
    "insertable_actor_corridor",
    "layer_separation_quality",
    "fold_preservation",
    "insertion_corridor",
    "top_layer_stability",
    "lift_too_much_risk",
)
# Reverse lookup: task name -> integer id (used for embedding lookup).
TASK_INDEX = {name: idx for idx, name in enumerate(HEAD_TASKS)}
|
|
|
|
def task_ids_from_names(task_names: Sequence[str] | None, device: torch.device, batch_size: int) -> Tensor:
    """Translate task names into a long tensor of ids; unknown names map to 0."""
    if task_names is None:
        # No names supplied: the whole batch defaults to task id 0 ("generic").
        return torch.zeros(batch_size, device=device, dtype=torch.long)
    resolved = [TASK_INDEX.get(str(name), 0) for name in task_names]
    return torch.as_tensor(resolved, device=device, dtype=torch.long)
|
|
|
|
| def _mean_map(value: Tensor) -> Tensor: |
| return value.mean(dim=(-1, -2)).squeeze(1) |
|
|
|
|
def compute_task_metrics_from_fields(
    *,
    access_field: Tensor,
    persistence_field: Tensor,
    disturbance_field: Tensor,
    reocclusion_field: Tensor,
    visibility_field: Tensor,
    clearance_field: Tensor,
    support_stability_field: Tensor,
    uncertainty_field: Tensor,
) -> dict[str, Tensor]:
    """Derive per-batch task metrics and auxiliary maps from decoder fields.

    ``access_field``, ``visibility_field``, ``clearance_field`` and
    ``support_stability_field`` are consumed as logits (sigmoid is applied
    here); ``persistence_field``, ``disturbance_field`` and
    ``reocclusion_field`` are used as-is, so callers are expected to pass
    probabilities for those (the decoders in this file do).
    ``uncertainty_field`` is assumed non-negative (softplus upstream —
    TODO confirm for any other callers) and squashed to [0, 1) below.

    Returns the four spatial ``*_field`` maps plus one scalar per batch
    element for each name in TASK_METRIC_NAMES; every scalar is clamped to
    [0, 1] except ``gap_width``, which is affinely mapped to [0.03, 0.24]
    (units unclear from this file — presumably meters; verify at callers).
    """
    access_prob = torch.sigmoid(access_field)
    # Best access probability across support-mode channels, kept 1-channel.
    opening_mask = access_prob.amax(dim=1, keepdim=True)
    support_stability = torch.sigmoid(support_stability_field)
    visibility_prob = torch.sigmoid(visibility_field)
    clearance_prob = torch.sigmoid(clearance_field).mean(dim=1, keepdim=True)
    # Map unbounded non-negative uncertainty into [0, 1).
    normalized_uncertainty = uncertainty_field / (1.0 + uncertainty_field)

    # Auxiliary maps: an opening is "good" where accessible, persistent and
    # stably supported; the revealed/visible maps combine visibility with
    # reocclusion and persistence.
    opening_quality_field = opening_mask * persistence_field * support_stability
    newly_revealed_field = torch.relu(visibility_prob - reocclusion_field)
    still_visible_field = visibility_prob * persistence_field
    reoccluded_field = reocclusion_field

    # Spatially pooled scalars feeding the metric formulas below.
    opening_quality = _mean_map(opening_quality_field)
    actor_feasibility_score = 0.6 * _mean_map(clearance_prob) + 0.4 * _mean_map(opening_mask)
    base_gap = _mean_map(opening_mask)
    disturbance_cost = _mean_map(disturbance_field)
    support_quality = _mean_map(support_stability)
    visibility_confidence = _mean_map(visibility_prob * (1.0 - normalized_uncertainty))
    reocclusion_rate = _mean_map(reocclusion_field)
    persistence_score = _mean_map(persistence_field)

    return {
        "newly_revealed_field": newly_revealed_field,
        "still_visible_field": still_visible_field,
        "reoccluded_field": reoccluded_field,
        "opening_quality_field": opening_quality_field,
        "opening_quality": torch.clamp(opening_quality, 0.0, 1.0),
        "actor_feasibility_score": torch.clamp(actor_feasibility_score, 0.0, 1.0),
        # Affine rescale of the mean opening into a small positive range.
        "gap_width": 0.03 + 0.21 * torch.clamp(base_gap, 0.0, 1.0),
        "damage_proxy": torch.clamp(disturbance_cost + 0.5 * (1.0 - support_quality), 0.0, 1.0),
        "release_collapse_rate": torch.clamp(reocclusion_rate, 0.0, 1.0),
        "target_visibility_confidence": torch.clamp(visibility_confidence, 0.0, 1.0),
        "mouth_aperture": torch.clamp(base_gap, 0.0, 1.0),
        "hold_quality": torch.clamp(0.5 * (persistence_score + support_quality), 0.0, 1.0),
        "rim_slip_risk": torch.clamp(reocclusion_rate + 0.5 * (1.0 - support_quality), 0.0, 1.0),
        "insertable_actor_corridor": torch.clamp(0.6 * actor_feasibility_score + 0.4 * base_gap, 0.0, 1.0),
        "layer_separation_quality": torch.clamp(0.7 * base_gap + 0.3 * actor_feasibility_score, 0.0, 1.0),
        "fold_preservation": torch.clamp(1.0 - disturbance_cost, 0.0, 1.0),
        "insertion_corridor": torch.clamp(0.5 * actor_feasibility_score + 0.5 * base_gap, 0.0, 1.0),
        "top_layer_stability": torch.clamp(support_quality, 0.0, 1.0),
        # Penalizes disturbance plus any opening beyond the 0.5 threshold.
        "lift_too_much_risk": torch.clamp(disturbance_cost + 0.5 * torch.relu(base_gap - 0.5), 0.0, 1.0),
    }
|
|
|
|
@dataclass
class RevealHeadConfig:
    """Shared hyper-parameters for the reveal / interaction state heads below."""

    hidden_dim: int = 512  # width of token and field embeddings
    num_support_modes: int = 3  # number of discrete support-mode classes
    num_approach_templates: int = 32  # length of the 1-D corridor-logit vector
    rollout_horizon: int = 5  # scale factor applied to persistence_horizon
    belief_map_size: int = 32  # output resolution of the (upsampled) belief map
    field_size: int = 16  # side of the square query grid (field_size**2 queries)
    num_heads: int = 4  # attention heads in every MultiheadAttention
    predict_belief_map: bool = False  # include "belief_map" in head outputs
    num_phases: int = 5  # phase-classification classes
    num_arm_roles: int = 4  # per-arm role classes
    num_interaction_tokens: int = 8  # learned interaction query tokens
    num_tasks: int = len(HEAD_TASKS)  # task-embedding vocabulary size
|
|
|
|
class RevealStateHead(nn.Module):
    """Decode reveal-related spatial fields and pooled summaries from scene tokens.

    A learned query per cell of a ``field_size x field_size`` grid
    cross-attends over the scene tokens (optionally concatenated with a
    memory token); the attended tokens are reshaped into a 2-D feature grid
    from which per-cell fields and pooled scalar outputs are decoded.
    """

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        # One learned query vector per grid cell.
        self.field_queries = nn.Parameter(
            torch.randn(config.field_size * config.field_size, config.hidden_dim) * 0.02
        )
        self.field_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual refinement applied to the attended field tokens.
        self.field_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Classifies the support mode from [scene | field | memory] pooled features.
        self.support_mode = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 3),
            nn.Linear(config.hidden_dim * 3, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        # 1x1 convs decoding per-cell fields from the feature grid.
        self.access_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.persistence_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.disturbance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.uncertainty_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        # Reocclusion logits from [field | memory] pooled features.
        self.reocclusion_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        self.latent_summary = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
        )

    def forward(self, scene_tokens: Tensor, memory_token: Tensor | None = None) -> dict[str, Tensor]:
        """Return a dict of spatial fields, logits and pooled summaries.

        Args:
            scene_tokens: (batch, num_tokens, hidden_dim) scene features.
            memory_token: optional memory tokens — assumed (batch, 1, hidden_dim)
                since it is squeezed on dim 1 below; TODO confirm at callers.
        """
        source_tokens = scene_tokens if memory_token is None else torch.cat([scene_tokens, memory_token], dim=1)
        batch_size = source_tokens.shape[0]
        field_queries = self.field_queries.unsqueeze(0).expand(batch_size, -1, -1)
        field_tokens, _ = self.field_attention(field_queries, source_tokens, source_tokens)
        field_tokens = field_tokens + self.field_mlp(field_tokens)
        side = self.config.field_size
        # (B, side*side, D) -> (B, D, side, side) feature grid for the 1x1 convs.
        grid = field_tokens.transpose(1, 2).reshape(batch_size, self.config.hidden_dim, side, side)
        pooled_scene = scene_tokens.mean(dim=1)
        pooled_field = field_tokens.mean(dim=1)
        if memory_token is not None:
            pooled_memory = memory_token.squeeze(1)
        else:
            # Zero placeholder keeps the concatenated input width constant.
            pooled_memory = pooled_scene.new_zeros(pooled_scene.shape)
        support_input = torch.cat([pooled_scene, pooled_field, pooled_memory], dim=-1)
        access_field = self.access_field(grid)  # raw logits, one channel per support mode
        persistence_field = torch.sigmoid(self.persistence_field(grid))  # probabilities
        disturbance_field = torch.sigmoid(self.disturbance_field(grid))  # probabilities
        uncertainty_field = F.softplus(self.uncertainty_field(grid))  # non-negative
        # Collapse rows (dim=-2), then resample the remaining width axis onto
        # the approach-template resolution.
        corridor_source = access_field.amax(dim=-2)
        corridor_logits = F.interpolate(
            corridor_source,
            size=self.config.num_approach_templates,
            mode="linear",
            align_corners=False,
        )
        access_prob = torch.sigmoid(access_field)
        # Access-probability-weighted mean persistence, scaled by the horizon.
        weighted_persistence = (persistence_field * access_prob).sum(dim=(-1, -2))
        access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
        persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
        disturbance_cost = disturbance_field.mean(dim=(-1, -2)).squeeze(1)
        # Belief map: per-cell max access logit across support-mode channels.
        belief_map = access_field.max(dim=1, keepdim=True).values
        if belief_map.shape[-1] != self.config.belief_map_size:
            belief_map = F.interpolate(
                belief_map,
                size=(self.config.belief_map_size, self.config.belief_map_size),
                mode="bilinear",
                align_corners=False,
            )
        latent_summary = self.latent_summary(torch.cat([pooled_scene, pooled_field], dim=-1))
        output = {
            "support_mode_logits": self.support_mode(support_input),
            "corridor_logits": corridor_logits,
            "persistence_horizon": persistence_horizon,
            "disturbance_cost": disturbance_cost,
            "access_field": access_field,
            "persistence_field": persistence_field,
            "disturbance_field": disturbance_field,
            "uncertainty_field": uncertainty_field,
            "field_tokens": field_tokens,
            "latent_summary": latent_summary,
            "reocclusion_logit": self.reocclusion_head(torch.cat([pooled_field, pooled_memory], dim=-1)),
            "persistence_uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
        }
        if self.config.predict_belief_map:
            output["belief_map"] = belief_map
        return output
|
|
|
|
class InteractionFieldDecoder(nn.Module):
    """Decode interaction fields, logits and summaries from token streams.

    Learned grid queries cross-attend over the concatenated interaction /
    scene / memory tokens; the attended tokens become a 2-D feature grid
    decoded by 1x1 convs, while pooled features drive the phase,
    support-mode, arm-role and reocclusion heads.
    """

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        # One learned query per cell of the field_size x field_size grid.
        self.field_queries = nn.Parameter(
            torch.randn(config.field_size * config.field_size, config.hidden_dim) * 0.02
        )
        self.field_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual refinement of the attended field tokens.
        self.field_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Summary input: [interaction | field | scene | memory] pooled features.
        summary_dim = config.hidden_dim * 4
        self.summary_proj = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.phase_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_phases),
        )
        # Per-arm role classifier over [arm token | summary].
        self.arm_role_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_arm_roles),
        )
        # Learned identity embeddings for the two arms.
        self.arm_identity = nn.Embedding(2, config.hidden_dim)
        self.support_mode = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        # 1x1 convs decoding per-cell fields from the feature grid.
        self.target_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.actor_feasibility_field = nn.Conv2d(config.hidden_dim, 2, kernel_size=1)
        self.persistence_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.risk_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.uncertainty_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        # Compatibility outputs mirroring RevealStateHead's access/persistence fields.
        self.compat_access_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.compat_persistence = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.reocclusion_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )

    def _pool_source(self, source_tokens: Tensor | None, fallback: Tensor) -> Tensor:
        """Mean-pool tokens along dim 1; zeros shaped like ``fallback`` when absent/empty."""
        if source_tokens is None or source_tokens.numel() == 0:
            return fallback.new_zeros(fallback.shape)
        return source_tokens.mean(dim=1)

    def forward(
        self,
        interaction_tokens: Tensor,
        scene_tokens: Tensor | None = None,
        memory_tokens: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Decode fields, logits and summaries.

        Args:
            interaction_tokens: (batch, num_tokens, hidden_dim) query-side tokens.
            scene_tokens / memory_tokens: optional extra token streams attended
                alongside the interaction tokens.
        """
        batch_size = interaction_tokens.shape[0]
        pooled_interaction = interaction_tokens.mean(dim=1)
        pooled_scene = self._pool_source(scene_tokens, pooled_interaction)
        pooled_memory = self._pool_source(memory_tokens, pooled_interaction)

        # Grid queries attend over the concatenated token streams.
        field_queries = self.field_queries.unsqueeze(0).expand(batch_size, -1, -1)
        source_tokens = interaction_tokens
        if scene_tokens is not None:
            source_tokens = torch.cat([source_tokens, scene_tokens], dim=1)
        if memory_tokens is not None:
            source_tokens = torch.cat([source_tokens, memory_tokens], dim=1)
        field_tokens, _ = self.field_attention(field_queries, source_tokens, source_tokens)
        field_tokens = field_tokens + self.field_mlp(field_tokens)

        side = self.config.field_size
        # (B, side*side, D) -> (B, D, side, side) feature grid for the 1x1 convs.
        grid = field_tokens.transpose(1, 2).reshape(batch_size, self.config.hidden_dim, side, side)
        pooled_field = field_tokens.mean(dim=1)
        summary_input = torch.cat([pooled_interaction, pooled_field, pooled_scene, pooled_memory], dim=-1)
        summary = self.summary_proj(summary_input)

        # Raw-logit fields vs. probability (sigmoid) / non-negative (softplus) fields.
        target_field = self.target_field(grid)
        actor_feasibility_field = self.actor_feasibility_field(grid)
        persistence_field = torch.sigmoid(self.persistence_field(grid))
        risk_field = torch.sigmoid(self.risk_field(grid))
        uncertainty_field = F.softplus(self.uncertainty_field(grid))

        # Compatibility outputs mirroring RevealStateHead: corridor logits,
        # persistence horizon and belief map derived from the compat fields.
        access_field = self.compat_access_field(grid)
        corridor_source = access_field.amax(dim=-2)
        corridor_logits = F.interpolate(
            corridor_source,
            size=self.config.num_approach_templates,
            mode="linear",
            align_corners=False,
        )
        compatibility_persistence = torch.sigmoid(self.compat_persistence(grid))
        access_prob = torch.sigmoid(access_field)
        weighted_persistence = (compatibility_persistence * access_prob).sum(dim=(-1, -2))
        access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
        persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
        disturbance_cost = risk_field.mean(dim=(-1, -2)).squeeze(1)
        belief_map = target_field
        if belief_map.shape[-1] != self.config.belief_map_size:
            belief_map = F.interpolate(
                belief_map,
                size=(self.config.belief_map_size, self.config.belief_map_size),
                mode="bilinear",
                align_corners=False,
            )

        # Arm tokens: use the first two interaction tokens when available,
        # otherwise duplicate the pooled features; add learned arm identities.
        arm_identity = self.arm_identity.weight.unsqueeze(0).expand(batch_size, -1, -1)
        if interaction_tokens.shape[1] >= 2:
            arm_tokens = interaction_tokens[:, :2] + arm_identity
        else:
            arm_tokens = pooled_interaction.unsqueeze(1).expand(-1, 2, -1) + arm_identity
        arm_role_input = torch.cat(
            [arm_tokens, summary.unsqueeze(1).expand(-1, arm_tokens.shape[1], -1)],
            dim=-1,
        )
        arm_role_logits = self.arm_role_head(arm_role_input)
        reocclusion_logit = self.reocclusion_head(summary_input)

        output = {
            "phase_logits": self.phase_head(summary_input),
            "arm_role_logits": arm_role_logits,
            "target_field": target_field,
            "actor_feasibility_field": actor_feasibility_field,
            "persistence_field": persistence_field,
            "risk_field": risk_field,
            "uncertainty_field": uncertainty_field,
            "interaction_tokens": interaction_tokens,
            "field_tokens": field_tokens,
            "latent_summary": summary,
            "support_mode_logits": self.support_mode(summary_input),
            "corridor_logits": corridor_logits,
            "persistence_horizon": persistence_horizon,
            "disturbance_cost": disturbance_cost,
            "belief_map": belief_map,
            "reocclusion_logit": reocclusion_logit,
            "persistence_uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
            "access_field": access_field,
            # Alias of risk_field for RevealStateHead-compatible consumers.
            "disturbance_field": risk_field,
            # NOTE(review): "uncertainty" duplicates "persistence_uncertainty" —
            # presumably a backward-compatibility alias; verify before removing.
            "uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
        }
        if not self.config.predict_belief_map:
            output.pop("belief_map")
        return output
|
|
|
|
class InteractionStateHead(nn.Module):
    """Learned interaction queries over scene (+ memory) tokens, decoded by
    InteractionFieldDecoder."""

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        dim = config.hidden_dim
        # Learned query tokens attended against the scene/memory stream.
        self.interaction_queries = nn.Parameter(
            0.02 * torch.randn(config.num_interaction_tokens, dim)
        )
        self.interaction_attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual refinement of the attended tokens.
        self.interaction_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        self.decoder = InteractionFieldDecoder(config)

    def forward(
        self,
        scene_tokens: Tensor,
        memory_token: Tensor | None = None,
        memory_tokens: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Attend the learned queries over the token stream and decode fields.

        ``memory_token`` is accepted as a single-token alias, used only when
        ``memory_tokens`` is not provided.
        """
        effective_memory = memory_token if memory_tokens is None else memory_tokens
        if effective_memory is None:
            attended_source = scene_tokens
        else:
            attended_source = torch.cat([scene_tokens, effective_memory], dim=1)
        queries = self.interaction_queries.unsqueeze(0).expand(attended_source.shape[0], -1, -1)
        refined, _ = self.interaction_attention(queries, attended_source, attended_source)
        refined = refined + self.interaction_mlp(refined)
        return self.decoder(
            interaction_tokens=refined,
            scene_tokens=scene_tokens,
            memory_tokens=effective_memory,
        )
|
|
|
|
class ElasticOcclusionFieldDecoder(nn.Module):
    """Task-conditioned decoder producing occlusion/reveal fields, summaries
    and scalar task metrics from interaction / scene / memory tokens.

    Extends InteractionFieldDecoder's scheme with task conditioning (a task
    embedding applies a gentle per-channel affine to the feature grid and a
    residual adapter to the latent summary) plus the analytic metrics from
    compute_task_metrics_from_fields corrected by small learned residuals.
    """

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        # One learned query per cell of the field_size x field_size grid.
        self.field_queries = nn.Parameter(
            torch.randn(config.field_size * config.field_size, config.hidden_dim) * 0.02
        )
        self.field_attention = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual refinement of the attended field tokens.
        self.field_mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
        )
        # Summary input: [interaction | field | scene | memory] pooled features.
        summary_dim = config.hidden_dim * 4
        self.summary_proj = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.phase_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_phases),
        )
        # Per-arm role classifier over [arm token | task summary].
        self.arm_role_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_arm_roles),
        )
        # Learned identity embeddings for the two arms.
        self.arm_identity = nn.Embedding(2, config.hidden_dim)
        self.support_mode = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        # 1x1 convs decoding per-cell fields from the feature grid.
        self.access_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
        self.target_belief_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.visibility_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.clearance_field = nn.Conv2d(config.hidden_dim, 2, kernel_size=1)
        self.occluder_contact_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.grasp_affordance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.support_stability_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.persistence_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.reocclusion_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.disturbance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.uncertainty_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
        self.reocclusion_head = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.num_support_modes),
        )
        # Task conditioning: embedding, per-channel grid affine (scale/bias)
        # and a residual adapter applied to the latent summary.
        self.task_embedding = nn.Embedding(config.num_tasks, config.hidden_dim)
        self.task_field_affine = nn.Linear(config.hidden_dim, config.hidden_dim * 2)
        self.task_summary_adapter = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
        )
        # Task-conditioned residual logits added onto the shared heads.
        self.task_phase_head = nn.Linear(config.hidden_dim, config.num_phases)
        self.task_support_head = nn.Linear(config.hidden_dim, config.num_support_modes)
        self.task_reocclusion_head = nn.Linear(config.hidden_dim, config.num_support_modes)
        # Small learned corrections to the analytic task metrics.
        self.task_metric_head = nn.Sequential(
            nn.LayerNorm(config.hidden_dim * 2),
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, len(TASK_METRIC_NAMES)),
        )

    def _pool_source(self, source_tokens: Tensor | None, fallback: Tensor) -> Tensor:
        """Mean-pool tokens along dim 1; zeros shaped like ``fallback`` when absent/empty."""
        if source_tokens is None or source_tokens.numel() == 0:
            return fallback.new_zeros(fallback.shape)
        return source_tokens.mean(dim=1)

    def _field_mean(self, field: Tensor) -> Tensor:
        # NOTE(review): not referenced anywhere in this file's visible code —
        # candidate for removal, or for replacing the inline .mean(...) calls below.
        return field.mean(dim=(-1, -2))

    def _upsampled_belief(self, target_belief_field: Tensor) -> Tensor:
        """Bilinearly resize the belief field to belief_map_size when needed."""
        if target_belief_field.shape[-1] == self.config.belief_map_size:
            return target_belief_field
        return F.interpolate(
            target_belief_field,
            size=(self.config.belief_map_size, self.config.belief_map_size),
            mode="bilinear",
            align_corners=False,
        )

    def forward(
        self,
        interaction_tokens: Tensor,
        scene_tokens: Tensor | None = None,
        memory_tokens: Tensor | None = None,
        task_names: Sequence[str] | None = None,
        use_task_conditioning: bool = True,
    ) -> dict[str, Tensor]:
        """Decode fields, logits, task metrics and a compact state vector.

        Args:
            interaction_tokens: (batch, num_tokens, hidden_dim) query-side tokens.
            scene_tokens / memory_tokens: optional extra token streams attended
                alongside the interaction tokens.
            task_names: optional per-batch task names; unknown names fall back
                to the "generic" task (id 0).
            use_task_conditioning: disable to skip all task-specific modulation.
        """
        batch_size = interaction_tokens.shape[0]
        pooled_interaction = interaction_tokens.mean(dim=1)
        pooled_scene = self._pool_source(scene_tokens, pooled_interaction)
        pooled_memory = self._pool_source(memory_tokens, pooled_interaction)

        # Grid queries attend over the concatenated token streams.
        field_queries = self.field_queries.unsqueeze(0).expand(batch_size, -1, -1)
        source_tokens = interaction_tokens
        if scene_tokens is not None:
            source_tokens = torch.cat([source_tokens, scene_tokens], dim=1)
        if memory_tokens is not None:
            source_tokens = torch.cat([source_tokens, memory_tokens], dim=1)
        field_tokens, _ = self.field_attention(field_queries, source_tokens, source_tokens)
        field_tokens = field_tokens + self.field_mlp(field_tokens)

        side = self.config.field_size
        # (B, side*side, D) -> (B, D, side, side) feature grid for the 1x1 convs.
        grid = field_tokens.transpose(1, 2).reshape(batch_size, self.config.hidden_dim, side, side)
        pooled_field = field_tokens.mean(dim=1)
        summary_input = torch.cat([pooled_interaction, pooled_field, pooled_scene, pooled_memory], dim=-1)
        latent_summary = self.summary_proj(summary_input)
        task_ids = task_ids_from_names(task_names, interaction_tokens.device, batch_size)
        task_embed = self.task_embedding(task_ids)
        if use_task_conditioning:
            # Gentle (0.1-weighted) per-channel affine on the grid and a
            # residual adapter on the summary, both keyed by the task embedding.
            scale, bias = self.task_field_affine(task_embed).chunk(2, dim=-1)
            grid = grid * (1.0 + 0.1 * scale.view(batch_size, self.config.hidden_dim, 1, 1))
            grid = grid + 0.1 * bias.view(batch_size, self.config.hidden_dim, 1, 1)
            task_summary = latent_summary + 0.1 * self.task_summary_adapter(torch.cat([latent_summary, task_embed], dim=-1))
        else:
            task_summary = latent_summary

        # Raw-logit fields vs. probability (sigmoid) / non-negative (softplus)
        # fields — compute_task_metrics_from_fields relies on exactly this split.
        access_field = self.access_field(grid)
        target_belief_field = self.target_belief_field(grid)
        visibility_field = self.visibility_field(grid)
        clearance_field = self.clearance_field(grid)
        occluder_contact_field = self.occluder_contact_field(grid)
        grasp_affordance_field = self.grasp_affordance_field(grid)
        support_stability_field = self.support_stability_field(grid)
        persistence_field = torch.sigmoid(self.persistence_field(grid))
        reocclusion_field = torch.sigmoid(self.reocclusion_field(grid))
        disturbance_field = torch.sigmoid(self.disturbance_field(grid))
        uncertainty_field = F.softplus(self.uncertainty_field(grid))
        task_metrics = compute_task_metrics_from_fields(
            access_field=access_field,
            persistence_field=persistence_field,
            disturbance_field=disturbance_field,
            reocclusion_field=reocclusion_field,
            visibility_field=visibility_field,
            clearance_field=clearance_field,
            support_stability_field=support_stability_field,
            uncertainty_field=uncertainty_field,
        )
        # Learned, tanh-bounded corrections (|r| <= 0.05) per metric.
        metric_residuals = 0.05 * torch.tanh(
            self.task_metric_head(torch.cat([task_summary, task_embed], dim=-1))
        )
        metric_residual_map = {
            name: metric_residuals[:, idx]
            for idx, name in enumerate(TASK_METRIC_NAMES)
        }

        support_stability_prob = torch.sigmoid(support_stability_field)
        # Aggregate risk: disturbance, reocclusion, instability and uncertainty.
        risk_field = torch.sigmoid(
            disturbance_field
            + 0.75 * reocclusion_field
            + 0.5 * (1.0 - support_stability_prob)
            + 0.25 * uncertainty_field
        )
        # Collapse rows (dim=-2), then resample the remaining width axis onto
        # the approach-template resolution.
        corridor_source = access_field.amax(dim=-2)
        corridor_logits = F.interpolate(
            corridor_source,
            size=self.config.num_approach_templates,
            mode="linear",
            align_corners=False,
        )
        access_prob = torch.sigmoid(access_field)
        # persistence_field has one channel; broadcast across support modes.
        weighted_persistence = (persistence_field.expand_as(access_prob) * access_prob).sum(dim=(-1, -2))
        access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
        persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
        disturbance_cost = disturbance_field.mean(dim=(-1, -2)).squeeze(1)

        # Two pseudo-arm tokens: pooled interaction features plus the learned
        # arm identity embeddings.
        arm_identity = self.arm_identity.weight.unsqueeze(0).expand(batch_size, -1, -1)
        arm_tokens = pooled_interaction.unsqueeze(1).expand(-1, 2, -1) + arm_identity
        arm_role_input = torch.cat(
            [arm_tokens, task_summary.unsqueeze(1).expand(-1, 2, -1)],
            dim=-1,
        )
        arm_role_logits = self.arm_role_head(arm_role_input)
        target_belief_map = self._upsampled_belief(target_belief_field)
        # Flat per-batch feature vector concatenating pooled field statistics
        # and head logits (1-D entries are unsqueezed to (B, 1) in the cat below).
        compact_components = [
            target_belief_field.mean(dim=(-1, -2)).squeeze(1),
            visibility_field.mean(dim=(-1, -2)).squeeze(1),
            clearance_field.mean(dim=(-1, -2)).mean(dim=1),
            occluder_contact_field.mean(dim=(-1, -2)).squeeze(1),
            grasp_affordance_field.mean(dim=(-1, -2)).squeeze(1),
            support_stability_prob.mean(dim=(-1, -2)).squeeze(1),
            persistence_field.mean(dim=(-1, -2)).squeeze(1),
            reocclusion_field.mean(dim=(-1, -2)).squeeze(1),
            disturbance_field.mean(dim=(-1, -2)).squeeze(1),
            risk_field.mean(dim=(-1, -2)).squeeze(1),
            uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
            # NOTE(review): the two transposes cancel out (no-op) — equivalent
            # to access_prob.mean(dim=(-1, -2)); candidate cleanup.
            access_prob.mean(dim=(-1, -2)).transpose(0, 1).transpose(0, 1),
            # NOTE(review): these two logit sums are recomputed again for the
            # output dict below — could be computed once and reused.
            self.support_mode(summary_input) + (self.task_support_head(task_summary) if use_task_conditioning else 0.0),
            self.phase_head(summary_input) + (self.task_phase_head(task_summary) if use_task_conditioning else 0.0),
            arm_role_logits.reshape(batch_size, -1),
        ]
        compact_state = torch.cat(
            [component if component.ndim > 1 else component.unsqueeze(-1) for component in compact_components],
            dim=-1,
        )

        output = {
            "phase_logits": self.phase_head(summary_input) + (self.task_phase_head(task_summary) if use_task_conditioning else 0.0),
            "arm_role_logits": arm_role_logits,
            "target_belief_field": target_belief_field,
            "visibility_field": visibility_field,
            "clearance_field": clearance_field,
            "occluder_contact_field": occluder_contact_field,
            "grasp_affordance_field": grasp_affordance_field,
            "support_stability_field": support_stability_field,
            "persistence_field": persistence_field,
            "reocclusion_field": reocclusion_field,
            "disturbance_field": disturbance_field,
            "risk_field": risk_field,
            "uncertainty_field": uncertainty_field,
            "interaction_tokens": interaction_tokens,
            "field_tokens": field_tokens,
            "latent_summary": task_summary,
            "support_mode_logits": self.support_mode(summary_input) + (self.task_support_head(task_summary) if use_task_conditioning else 0.0),
            "corridor_logits": corridor_logits,
            "persistence_horizon": persistence_horizon,
            "disturbance_cost": disturbance_cost,
            "belief_map": target_belief_map,
            "reocclusion_logit": self.reocclusion_head(summary_input) + (self.task_reocclusion_head(task_summary) if use_task_conditioning else 0.0),
            "persistence_uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
            "access_field": access_field,
            # NOTE(review): "uncertainty" duplicates "persistence_uncertainty" —
            # presumably a backward-compatibility alias; verify before removing.
            "uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
            "compact_state": compact_state,
            "task_ids": task_ids,
        }
        # Aliases matching InteractionFieldDecoder's output names.
        output["target_field"] = output["target_belief_field"]
        output["actor_feasibility_field"] = output["clearance_field"]
        output.update(
            {
                "newly_revealed_field": task_metrics["newly_revealed_field"],
                "still_visible_field": task_metrics["still_visible_field"],
                "reoccluded_field": task_metrics["reoccluded_field"],
                "opening_quality_field": task_metrics["opening_quality_field"],
            }
        )
        # Apply learned residuals; gap_width gets an extra 0.01 scale so the
        # correction stays small relative to its [0.03, 0.24] analytic range.
        for name in TASK_METRIC_NAMES:
            if name == "gap_width":
                output[name] = torch.clamp(task_metrics[name] + 0.01 * metric_residual_map[name], 0.0, 1.0)
            else:
                output[name] = torch.clamp(task_metrics[name] + metric_residual_map[name], 0.0, 1.0)
        return output
|
|
|
|
class ElasticOcclusionStateHead(nn.Module):
    """Learned interaction queries over scene (+ memory) tokens, decoded by
    ElasticOcclusionFieldDecoder with optional task conditioning."""

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        dim = config.hidden_dim
        # Learned query tokens attended against the scene/memory stream.
        self.interaction_queries = nn.Parameter(
            0.02 * torch.randn(config.num_interaction_tokens, dim)
        )
        self.interaction_attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=config.num_heads,
            batch_first=True,
        )
        # Residual refinement of the attended tokens.
        self.interaction_mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )
        self.decoder = ElasticOcclusionFieldDecoder(config)

    def forward(
        self,
        scene_tokens: Tensor,
        memory_token: Tensor | None = None,
        memory_tokens: Tensor | None = None,
        task_names: Sequence[str] | None = None,
        use_task_conditioning: bool = True,
    ) -> dict[str, Tensor]:
        """Attend the learned queries over the token stream and decode.

        ``memory_token`` is accepted as a single-token alias, used only when
        ``memory_tokens`` is not provided.
        """
        effective_memory = memory_token if memory_tokens is None else memory_tokens
        if effective_memory is None:
            attended_source = scene_tokens
        else:
            attended_source = torch.cat([scene_tokens, effective_memory], dim=1)
        queries = self.interaction_queries.unsqueeze(0).expand(attended_source.shape[0], -1, -1)
        refined, _ = self.interaction_attention(queries, attended_source, attended_source)
        refined = refined + self.interaction_mlp(refined)
        return self.decoder(
            interaction_tokens=refined,
            scene_tokens=scene_tokens,
            memory_tokens=effective_memory,
            task_names=task_names,
            use_task_conditioning=use_task_conditioning,
        )
|
|