lsnu committed on
Commit
b1ef16c
·
verified ·
1 Parent(s): 6fa1956

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. code/reveal_vla_bimanual/models/__init__.py +24 -0
  2. code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc +0 -0
  3. code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc +0 -0
  4. code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc +0 -0
  5. code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc +0 -0
  6. code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc +0 -0
  7. code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc +0 -0
  8. code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc +0 -0
  9. code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc +0 -0
  10. code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc +0 -0
  11. code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc +0 -0
  12. code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc +0 -0
  13. code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc +0 -0
  14. code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc +0 -0
  15. code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc +0 -0
  16. code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc +0 -0
  17. code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc +0 -0
  18. code/reveal_vla_bimanual/models/action_decoder.py +68 -0
  19. code/reveal_vla_bimanual/models/backbones.py +116 -0
  20. code/reveal_vla_bimanual/models/multiview_fusion.py +57 -0
  21. code/reveal_vla_bimanual/models/planner.py +61 -0
  22. code/reveal_vla_bimanual/models/policy.py +127 -0
  23. code/reveal_vla_bimanual/models/reveal_head.py +55 -0
  24. code/reveal_vla_bimanual/models/world_model.py +70 -0
  25. code/reveal_vla_bimanual/sim_reveal/__init__.py +15 -0
  26. code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-310.pyc +0 -0
  27. code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc +0 -0
  28. code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-310.pyc +0 -0
  29. code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc +0 -0
  30. code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-310.pyc +0 -0
  31. code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc +0 -0
  32. code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-310.pyc +0 -0
  33. code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-311.pyc +0 -0
  34. code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-310.pyc +0 -0
  35. code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-311.pyc +0 -0
  36. code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-310.pyc +0 -0
  37. code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-311.pyc +0 -0
  38. code/reveal_vla_bimanual/sim_reveal/__pycache__/labels.cpython-311.pyc +0 -0
  39. code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-310.pyc +0 -0
  40. code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc +0 -0
  41. code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-310.pyc +0 -0
  42. code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc +0 -0
  43. code/reveal_vla_bimanual/sim_reveal/__pycache__/teachers.cpython-311.pyc +0 -0
  44. code/reveal_vla_bimanual/sim_reveal/base.py +32 -0
  45. code/reveal_vla_bimanual/sim_reveal/dataset.py +137 -0
  46. code/reveal_vla_bimanual/sim_reveal/generate_dataset.py +40 -0
  47. code/reveal_vla_bimanual/sim_reveal/isaac_smoke.py +29 -0
  48. code/reveal_vla_bimanual/sim_reveal/isaac_wrapper.py +16 -0
  49. code/reveal_vla_bimanual/sim_reveal/labels.py +61 -0
  50. code/reveal_vla_bimanual/sim_reveal/procedural_envs.py +545 -0
code/reveal_vla_bimanual/models/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Public API of the ``models`` package for the REVEAL-VLA bimanual stack.

Re-exports the component classes (backbone, multi-view fusion, action
decoder, reveal head, world model, planner) and the composed policies so
callers can import everything directly from ``models``.
"""

from models.action_decoder import ACTBimanualChunkDecoder, ChunkDecoderConfig
from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
from models.planner import PlannerConfig, RevealPlanner
from models.policy import BackboneOnlyPolicy, RevealBimanualPolicy
from models.reveal_head import RevealHeadConfig, RevealStateHead
from models.world_model import RevealWM, RevealWMConfig

# Alphabetical, mirrors the imports above.
__all__ = [
    "ACTBimanualChunkDecoder",
    "BackboneOnlyPolicy",
    "ChunkDecoderConfig",
    "FrozenVLBackbone",
    "FrozenVLBackboneConfig",
    "MultiViewFusion",
    "MultiViewFusionConfig",
    "PlannerConfig",
    "RevealBimanualPolicy",
    "RevealHeadConfig",
    "RevealPlanner",
    "RevealStateHead",
    "RevealWM",
    "RevealWMConfig",
]
code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (840 Bytes). View file
 
code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.02 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc ADDED
Binary file (2.64 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc ADDED
Binary file (4.77 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc ADDED
Binary file (5.04 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc ADDED
Binary file (9.38 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc ADDED
Binary file (2.25 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc ADDED
Binary file (3.9 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc ADDED
Binary file (2.52 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc ADDED
Binary file (3.62 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc ADDED
Binary file (5.17 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc ADDED
Binary file (8.91 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc ADDED
Binary file (2.04 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc ADDED
Binary file (3.84 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc ADDED
Binary file (2.48 kB). View file
 
code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc ADDED
Binary file (4.71 kB). View file
 
code/reveal_vla_bimanual/models/action_decoder.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+
9
@dataclass
class ChunkDecoderConfig:
    """Hyper-parameters for :class:`ACTBimanualChunkDecoder`."""

    # Transformer width; must match the fused scene-token dimension.
    hidden_dim: int = 512
    num_heads: int = 8
    num_layers: int = 4
    ff_dim: int = 2048
    dropout: float = 0.1
    # Number of future action steps decoded per chunk (one learned query each).
    chunk_size: int = 8
    # Bimanual action vector width (presumably 7 DoF per arm — TODO confirm).
    action_dim: int = 14
    # Default number of stochastic chunk candidates drawn for planning.
    num_candidates: int = 8
19
+
20
+
21
class ACTBimanualChunkDecoder(nn.Module):
    """ACT-style transformer decoder from scene tokens to an action chunk.

    A fixed set of ``chunk_size`` learned queries cross-attends the fused
    scene tokens; each decoded query yields a per-step Gaussian (mean and
    clamped log-std) over the bimanual action vector, and a pooled linear
    head scores the whole chunk as a planning proposal.
    """

    def __init__(self, config: ChunkDecoderConfig) -> None:
        super().__init__()
        self.config = config
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers)
        # One learned query per action step in the chunk.
        self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim)
        self.action_mean = nn.Linear(config.hidden_dim, config.action_dim)
        self.action_log_std = nn.Linear(config.hidden_dim, config.action_dim)
        self.proposal_score = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, 1),
        )

    def forward(self, scene_tokens: Tensor) -> dict[str, Tensor]:
        """Decode one action chunk from ``(batch, tokens, hidden_dim)`` scene features.

        Returns decoded query tokens, per-step Gaussian parameters, and a
        scalar proposal score per batch element.
        """
        batch_size = scene_tokens.shape[0]
        query = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        decoded = self.decoder(query, scene_tokens)
        return {
            "decoded_tokens": decoded,
            "action_mean": self.action_mean(decoded),
            # Clamp keeps sampled noise scales in a numerically safe range.
            "action_log_std": self.action_log_std(decoded).clamp(min=-5.0, max=2.0),
            "proposal_score": self.proposal_score(decoded.mean(dim=1)).squeeze(-1),
        }

    def sample_candidates(self, action_mean: Tensor, action_log_std: Tensor, num_candidates: int | None = None) -> Tensor:
        """Draw ``num_candidates`` noisy chunks around the predicted mean.

        Candidate 0 is always the deterministic mean chunk. Fix: the previous
        ``num_candidates or self.config.num_candidates`` treated an explicit
        ``0`` as "use the config default"; only an omitted/``None`` argument
        now falls back to the config value (``<= 1`` still yields just the
        mean chunk).
        """
        if num_candidates is None:
            num_candidates = self.config.num_candidates
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        std = action_log_std.exp()
        noise = torch.randn(
            action_mean.size(0),
            num_candidates,
            action_mean.size(1),
            action_mean.size(2),
            device=action_mean.device,
            dtype=action_mean.dtype,
        )
        candidates = action_mean.unsqueeze(1) + noise * std.unsqueeze(1)
        # Slot 0 is overwritten with the exact mean so planning always
        # considers the deterministic chunk.
        candidates[:, 0] = action_mean
        return candidates
code/reveal_vla_bimanual/models/backbones.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import math
5
+ from typing import Sequence
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import Tensor, nn
10
+
11
+
12
@dataclass
class FrozenVLBackboneConfig:
    """Hyper-parameters for :class:`FrozenVLBackbone`."""

    # Hugging Face checkpoint name used when the real CLIP model is loaded.
    model_name: str = "openai/clip-vit-base-patch32"
    # Nominal token width; overridden by CLIP's projection_dim in real mode.
    hidden_dim: int = 512
    max_text_tokens: int = 32
    # When True, all backbone parameters get requires_grad = False.
    freeze_backbone: bool = True
    gradient_checkpointing: bool = True
    # When True, deterministic stand-ins replace the pretrained model
    # (no download, no `transformers` import).
    use_dummy_backbone: bool = False
20
+
21
+
22
class _DummyTextTokenizer:
    """Offline stand-in for a real tokenizer.

    Each character becomes the token ``min(ord(c), vocab_size - 1)``;
    sequences are truncated to ``max_length`` and right-padded with zeros,
    with a matching 0/1 attention mask.
    """

    def __init__(self, vocab_size: int = 8192, max_length: int = 32) -> None:
        self.vocab_size = vocab_size
        self.max_length = max_length

    def __call__(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
        token_ids = torch.zeros((len(texts), self.max_length), dtype=torch.long, device=device)
        attention_mask = torch.zeros_like(token_ids)
        cap = self.vocab_size - 1
        for index, text in enumerate(texts):
            codes = [min(ord(symbol), cap) for symbol in text[: self.max_length]]
            if not codes:
                continue  # empty string: all-padding row, all-zero mask
            length = len(codes)
            token_ids[index, :length] = torch.tensor(codes, dtype=torch.long, device=device)
            attention_mask[index, :length] = 1
        return {"input_ids": token_ids, "attention_mask": attention_mask}
36
+
37
+
38
class FrozenVLBackbone(nn.Module):
    """Frozen CLIP-style vision-language encoder.

    With ``config.use_dummy_backbone`` set, deterministic stand-ins
    (average-pooled patches, sinusoidal character embeddings) replace the
    pretrained model so the rest of the stack can run without downloading
    weights or importing ``transformers``.
    """

    def __init__(self, config: FrozenVLBackboneConfig) -> None:
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.use_dummy_backbone = config.use_dummy_backbone

        if config.use_dummy_backbone:
            # Dummy path: 16x16 average pooling stands in for the ViT patch grid.
            self.image_patch_size = 16
            self.tokenizer = _DummyTextTokenizer(max_length=config.max_text_tokens)
        else:
            # Imported lazily so dummy mode never requires `transformers`.
            from transformers import AutoTokenizer, CLIPModel

            clip_model = CLIPModel.from_pretrained(config.model_name)
            self.vision_model = clip_model.vision_model
            self.text_model = clip_model.text_model
            self.visual_projection = clip_model.visual_projection
            self.text_projection = clip_model.text_projection
            self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
            # The real model dictates the token width, overriding the config.
            self.hidden_dim = clip_model.config.projection_dim
            if config.gradient_checkpointing:
                if hasattr(self.vision_model, "gradient_checkpointing_enable"):
                    self.vision_model.gradient_checkpointing_enable()
                if hasattr(self.text_model, "gradient_checkpointing_enable"):
                    self.text_model.gradient_checkpointing_enable()

        if config.freeze_backbone:
            for parameter in self.parameters():
                parameter.requires_grad = False

    def tokenize_text(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
        """Tokenize *texts* and move the resulting id/mask tensors to *device*."""
        if self.use_dummy_backbone:
            return self.tokenizer(texts, device=device)
        return self.tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=self.config.max_text_tokens,
            return_tensors="pt",
        ).to(device)

    def encode_images(self, images: Tensor) -> Tensor:
        """Encode ``(batch, views, C, H, W)`` images to ``(batch, views, tokens, hidden_dim)``."""
        batch_size, num_views, channels, height, width = images.shape
        flat_images = images.reshape(batch_size * num_views, channels, height, width)
        if self.use_dummy_backbone:
            # Patch features via average pooling; append mean intensity and
            # normalized (x, y) patch coordinates, then tile up to hidden_dim.
            pooled = F.avg_pool2d(flat_images.float(), kernel_size=self.image_patch_size, stride=self.image_patch_size)
            patch_tokens = pooled.flatten(2).transpose(1, 2)
            grid_h, grid_w = pooled.shape[-2], pooled.shape[-1]
            y_coords = torch.linspace(-1.0, 1.0, steps=grid_h, device=images.device)
            x_coords = torch.linspace(-1.0, 1.0, steps=grid_w, device=images.device)
            grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")
            coords = torch.stack([grid_x, grid_y], dim=-1).reshape(1, grid_h * grid_w, 2)
            coords = coords.expand(patch_tokens.shape[0], -1, -1)
            intensity = patch_tokens.mean(dim=-1, keepdim=True)
            base = torch.cat([patch_tokens, intensity, coords], dim=-1)
            repeat_factor = math.ceil(self.hidden_dim / base.shape[-1])
            tokens = base.repeat(1, 1, repeat_factor)[..., : self.hidden_dim]
        else:
            outputs = self.vision_model(pixel_values=flat_images)
            tokens = self.visual_projection(outputs.last_hidden_state)
        num_tokens = tokens.shape[1]
        return tokens.reshape(batch_size, num_views, num_tokens, -1)

    def encode_text(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Encode token ids to ``(batch, seq, hidden_dim)`` embeddings.

        In dummy mode, padding positions (mask == 0) are zeroed out.
        """
        if self.use_dummy_backbone:
            vocab_scale = float(self.tokenizer.vocab_size - 1)
            token_values = input_ids.float() / vocab_scale
            # Fix: use ceil(hidden_dim / 2) frequencies so sin+cos always
            # covers hidden_dim channels; the previous hidden_dim // 2 left
            # odd widths one channel short (shape mismatch downstream).
            frequencies = torch.linspace(
                1.0,
                4.0,
                steps=max(1, (self.hidden_dim + 1) // 2),
                device=input_ids.device,
                dtype=token_values.dtype,
            )
            phases = token_values.unsqueeze(-1) * frequencies.view(1, 1, -1) * (2.0 * math.pi)
            embeddings = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)[..., : self.hidden_dim]
            return embeddings * attention_mask.unsqueeze(-1).float()
        outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return self.text_projection(outputs.last_hidden_state)
code/reveal_vla_bimanual/models/multiview_fusion.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+
9
@dataclass
class MultiViewFusionConfig:
    """Hyper-parameters for :class:`MultiViewFusion`."""

    # Token width shared with the backbone and decoder.
    hidden_dim: int = 512
    # Number of camera views expected at every forward call.
    num_cameras: int = 3
    num_layers: int = 4
    num_heads: int = 8
    ff_dim: int = 2048
    dropout: float = 0.1
    # Width of the raw proprioception vector fed to the adapter.
    proprio_dim: int = 32
    # How many fused tokens the proprioception vector is expanded into.
    proprio_tokens: int = 1
19
+
20
+
21
class MultiViewFusion(nn.Module):
    """Fuses per-camera patch tokens, proprioception, and language tokens.

    Adds a learned per-camera embedding, runs a transformer encoder over the
    concatenated view tokens, projects proprioception into token space, and
    returns ``[fused image tokens | proprio tokens | language tokens]``.
    """

    def __init__(self, config: MultiViewFusionConfig) -> None:
        super().__init__()
        self.config = config
        self.camera_embedding = nn.Embedding(config.num_cameras, config.hidden_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            batch_first=True,
            norm_first=True,
        )
        self.cross_view_transformer = nn.TransformerEncoder(layer, num_layers=config.num_layers)
        self.proprio_adapter = nn.Sequential(
            nn.LayerNorm(config.proprio_dim),
            nn.Linear(config.proprio_dim, config.hidden_dim * config.proprio_tokens),
            nn.GELU(),
        )

    def forward(self, image_tokens: Tensor, proprio: Tensor, language_tokens: Tensor) -> Tensor:
        """Fuse ``(B, V, T, D)`` image tokens with proprio and language tokens.

        Raises ``ValueError`` when the number of views differs from
        ``config.num_cameras``.
        """
        batch, views, tokens_per_view, dim = image_tokens.shape
        if views != self.config.num_cameras:
            raise ValueError(f"Expected {self.config.num_cameras} views, received {views}")

        view_ids = torch.arange(views, device=image_tokens.device)
        positioned = image_tokens + self.camera_embedding(view_ids).view(1, views, 1, dim)
        flattened = positioned.reshape(batch, views * tokens_per_view, dim)
        fused = self.cross_view_transformer(flattened)

        state_tokens = self.proprio_adapter(proprio).view(batch, self.config.proprio_tokens, dim)
        return torch.cat([fused, state_tokens, language_tokens], dim=1)
code/reveal_vla_bimanual/models/planner.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+
9
@dataclass
class PlannerConfig:
    """Weights for :class:`RevealPlanner`'s rollout scoring."""

    # Number of candidate chunks considered per planning step.
    num_candidates: int = 8
    # Reward terms (added to the score).
    corridor_weight: float = 1.0
    persistence_weight: float = 0.5
    proposal_weight: float = 0.5
    task_progress_weight: float = 0.75
    # Penalty terms (subtracted from the score).
    disturbance_weight: float = 0.75
    reocclusion_weight: float = 0.5
    # Bonus applied only when a belief-gain tensor is supplied.
    visibility_weight: float = 0.25
19
+
20
+
21
class RevealPlanner:
    """Scores candidate action-chunk rollouts and selects the best one.

    The score rewards open approach corridors, persistence, decoder proposal
    confidence, and a heuristic task-progress term, while penalizing
    disturbance and re-occlusion; an optional belief-gain bonus is added
    when provided.
    """

    def __init__(self, config: PlannerConfig) -> None:
        self.config = config

    def score_rollouts(
        self,
        rollout_state: dict[str, Tensor],
        proposal_scores: Tensor,
        candidate_chunks: Tensor | None = None,
        belief_gain: Tensor | None = None,
    ) -> Tensor:
        """Return one scalar score per (batch, candidate)."""
        cfg = self.config
        # Best corridor probability per support mode, averaged over horizon
        # and modes; its shortfall from 1.0 is the re-occlusion penalty.
        corridor_open = rollout_state["corridor_logits"].sigmoid().amax(dim=-1)
        corridor_term = corridor_open.mean(dim=(-1, -2))
        reocclusion_term = torch.relu(1.0 - corridor_open).mean(dim=(-1, -2))
        persistence_term = rollout_state["persistence_horizon"].mean(dim=(-1, -2))
        disturbance_term = rollout_state["disturbance_cost"].mean(dim=-1)

        progress_term = proposal_scores.new_zeros(proposal_scores.shape)
        if candidate_chunks is not None:
            # Heuristic progress from action dims 8 and 13 — presumably the
            # retrieval arm's reach and grasp channels (TODO confirm); each
            # tanh-squashed value is shifted from [-1, 1] into [0, 1].
            reach = torch.tanh(candidate_chunks[..., 8]).mean(dim=-1)
            retrieve = torch.tanh(candidate_chunks[..., 13]).amax(dim=-1)
            progress_term = 0.5 * (reach + 1.0) * 0.5 + 0.5 * (retrieve + 1.0) * 0.5

        total = (
            cfg.corridor_weight * corridor_term
            + cfg.persistence_weight * persistence_term
            + cfg.proposal_weight * proposal_scores
            + cfg.task_progress_weight * progress_term
            - cfg.disturbance_weight * disturbance_term
            - cfg.reocclusion_weight * reocclusion_term
        )
        if belief_gain is not None:
            total = total + cfg.visibility_weight * belief_gain
        return total

    def select_best(self, candidate_chunks: Tensor, rollout_state: dict[str, Tensor], proposal_scores: Tensor) -> dict[str, Tensor]:
        """Score all candidates and return scores, argmax indices, and the winning chunk."""
        scores = self.score_rollouts(rollout_state, proposal_scores, candidate_chunks=candidate_chunks)
        winners = scores.argmax(dim=-1)
        rows = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
        return {
            "scores": scores,
            "best_indices": winners,
            "best_chunk": candidate_chunks[rows, winners],
        }
code/reveal_vla_bimanual/models/policy.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Sequence
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+
9
+ from models.action_decoder import ACTBimanualChunkDecoder, ChunkDecoderConfig
10
+ from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
11
+ from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
12
+ from models.planner import PlannerConfig, RevealPlanner
13
+ from models.reveal_head import RevealHeadConfig, RevealStateHead
14
+ from models.world_model import RevealWM, RevealWMConfig
15
+
16
+
17
@dataclass
class PolicyConfig:
    """Aggregates the sub-module configs used by both policy variants."""

    backbone: FrozenVLBackboneConfig = field(default_factory=FrozenVLBackboneConfig)
    fusion: MultiViewFusionConfig = field(default_factory=MultiViewFusionConfig)
    decoder: ChunkDecoderConfig = field(default_factory=ChunkDecoderConfig)
    # The three fields below are only consumed by RevealBimanualPolicy.
    reveal_head: RevealHeadConfig = field(default_factory=RevealHeadConfig)
    world_model: RevealWMConfig = field(default_factory=RevealWMConfig)
    planner: PlannerConfig = field(default_factory=PlannerConfig)
25
+
26
+
27
class BackboneOnlyPolicy(nn.Module):
    """Baseline policy: frozen VL backbone + multi-view fusion + chunk decoder.

    No reveal head, world model, or planner — the decoder's distribution is
    the final output.
    """

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__()
        self.config = config
        self.backbone = FrozenVLBackbone(config.backbone)
        self.fusion = MultiViewFusion(config.fusion)
        self.decoder = ACTBimanualChunkDecoder(config.decoder)

    def _encode_language(
        self,
        images: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
    ) -> Tensor:
        """Embed the instruction; tokenizes *texts* on the images' device when
        pre-tokenized inputs are not supplied."""
        tokens = language_tokens
        if tokens is None:
            if texts is None:
                raise ValueError("Either texts or language_tokens must be provided.")
            tokens = self.backbone.tokenize_text(texts, device=images.device)
        return self.backbone.encode_text(
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
        )

    def encode_scene(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
    ) -> Tensor:
        """Fuse image, proprioception, and language features into scene tokens."""
        visual = self.backbone.encode_images(images)
        linguistic = self._encode_language(images, texts=texts, language_tokens=language_tokens)
        return self.fusion(image_tokens=visual, proprio=proprio, language_tokens=linguistic)

    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
    ) -> dict[str, Tensor]:
        """Return the decoder outputs plus the fused scene tokens."""
        scene_tokens = self.encode_scene(images, proprio, texts=texts, language_tokens=language_tokens)
        result = self.decoder(scene_tokens)
        result["scene_tokens"] = scene_tokens
        return result
72
+
73
+
74
class RevealBimanualPolicy(BackboneOnlyPolicy):
    """Full REVEAL policy: baseline encoder/decoder plus reveal-state head,
    learned world model, and rollout-scoring planner."""

    def __init__(self, config: PolicyConfig) -> None:
        super().__init__(config)
        self.reveal_head = RevealStateHead(config.reveal_head)
        self.world_model = RevealWM(config.world_model)
        self.planner = RevealPlanner(config.planner)

    def forward(
        self,
        images: Tensor,
        proprio: Tensor,
        texts: Sequence[str] | None = None,
        language_tokens: dict[str, Tensor] | None = None,
        plan: bool = True,  # when False, skip world-model rollout and planning
        support_mode_conditioning: bool = True,  # ablation: zero the support-mode logits fed to the WM
    ) -> dict[str, Tensor]:
        """Run the base policy, predict the reveal state, and (optionally)
        roll out sampled candidate chunks through the world model to pick the
        best-scoring chunk."""
        outputs = super().forward(images, proprio, texts=texts, language_tokens=language_tokens)
        reveal_state = self.reveal_head(outputs["scene_tokens"])
        outputs["reveal_state"] = reveal_state

        # Candidate chunks sampled around the decoder's Gaussian; slot 0 is
        # the deterministic mean chunk.
        candidate_chunks = self.decoder.sample_candidates(
            outputs["action_mean"],
            outputs["action_log_std"],
            num_candidates=self.config.decoder.num_candidates,
        )
        outputs["candidate_chunks"] = candidate_chunks

        if plan:
            # Fold the candidate axis into the batch so the world model sees
            # (batch * num_candidates, ...) inputs in a single call.
            batch_size, num_candidates, chunk_size, action_dim = candidate_chunks.shape
            flat_chunks = candidate_chunks.view(batch_size * num_candidates, chunk_size, action_dim)
            tiled_scene = outputs["scene_tokens"].unsqueeze(1).expand(-1, num_candidates, -1, -1)
            tiled_scene = tiled_scene.reshape(batch_size * num_candidates, outputs["scene_tokens"].shape[1], outputs["scene_tokens"].shape[2])
            planning_reveal_state = reveal_state
            if not support_mode_conditioning:
                # Shallow copy so the logged reveal_state keeps the real logits.
                planning_reveal_state = dict(reveal_state)
                planning_reveal_state["support_mode_logits"] = torch.zeros_like(reveal_state["support_mode_logits"])
            # Tile every reveal-state tensor the same way as the scene tokens.
            tiled_reveal = {
                key: value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(batch_size * num_candidates, *value.shape[1:])
                for key, value in planning_reveal_state.items()
            }
            rollout = self.world_model(tiled_scene, tiled_reveal, flat_chunks)
            # Restore the (batch, candidate) leading axes for the planner.
            reshaped_rollout = {
                key: value.view(batch_size, num_candidates, *value.shape[1:]) for key, value in rollout.items()
            }
            selected = self.planner.select_best(
                candidate_chunks=candidate_chunks,
                rollout_state=reshaped_rollout,
                # One proposal score per batch element, broadcast to all candidates.
                proposal_scores=outputs["proposal_score"].unsqueeze(-1).expand(-1, num_candidates),
            )
            outputs["planned_rollout"] = reshaped_rollout
            outputs["planned_chunk"] = selected["best_chunk"]
            outputs["planner_scores"] = selected["scores"]
            outputs["best_candidate_indices"] = selected["best_indices"]
        return outputs
code/reveal_vla_bimanual/models/reveal_head.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from torch import Tensor, nn
6
+
7
+
8
@dataclass
class RevealHeadConfig:
    """Hyper-parameters for :class:`RevealStateHead`."""

    hidden_dim: int = 512
    # Size of the discrete support-mode classification.
    num_support_modes: int = 3
    # Approach-corridor templates predicted per support mode.
    num_approach_templates: int = 32
    # Carried in the config but not read by the head itself (used by the WM).
    rollout_horizon: int = 5
    # Side length of the optional square belief map.
    belief_map_size: int = 32
    # When True, an extra linear head emits a (1, size, size) belief map.
    predict_belief_map: bool = False
16
+
17
+
18
class RevealStateHead(nn.Module):
    """Predicts the REVEAL latent state from mean-pooled scene tokens.

    Outputs support-mode logits, per-mode approach-corridor logits, per-mode
    persistence horizons, a scalar disturbance cost, and (optionally) a
    square belief map.
    """

    def __init__(self, config: RevealHeadConfig) -> None:
        super().__init__()
        self.config = config
        self.trunk = nn.Sequential(
            nn.LayerNorm(config.hidden_dim),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.support_mode = nn.Linear(config.hidden_dim, config.num_support_modes)
        self.corridor = nn.Linear(
            config.hidden_dim,
            config.num_support_modes * config.num_approach_templates,
        )
        self.persistence = nn.Linear(config.hidden_dim, config.num_support_modes)
        self.disturbance = nn.Linear(config.hidden_dim, 1)
        # Optional head; stays None unless the config asks for a belief map.
        self.belief_map = None
        if config.predict_belief_map:
            side = config.belief_map_size
            self.belief_map = nn.Linear(config.hidden_dim, side * side)

    def forward(self, scene_tokens: Tensor) -> dict[str, Tensor]:
        """Map ``(batch, tokens, hidden_dim)`` scene tokens to the reveal-state dict."""
        batch = scene_tokens.shape[0]
        hidden = self.trunk(scene_tokens.mean(dim=1))
        corridor = self.corridor(hidden).view(
            batch,
            self.config.num_support_modes,
            self.config.num_approach_templates,
        )
        result: dict[str, Tensor] = {
            "support_mode_logits": self.support_mode(hidden),
            "corridor_logits": corridor,
            "persistence_horizon": self.persistence(hidden),
            "disturbance_cost": self.disturbance(hidden).squeeze(-1),
        }
        if self.belief_map is not None:
            side = self.config.belief_map_size
            result["belief_map"] = self.belief_map(hidden).view(batch, 1, side, side)
        return result
code/reveal_vla_bimanual/models/world_model.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+
9
@dataclass
class RevealWMConfig:
    """Hyper-parameters for :class:`RevealWM`."""

    hidden_dim: int = 512
    # Width of one action step fed to the GRU (matches the decoder's action_dim).
    action_dim: int = 14
    # Must agree with the reveal head so the flattened state widths line up.
    num_support_modes: int = 3
    num_approach_templates: int = 32
    # Nominal rollout length; the GRU actually follows the action chunk length.
    rollout_horizon: int = 5
16
+
17
+
18
class RevealWM(nn.Module):
    """GRU world model that rolls the REVEAL state forward under an action chunk.

    The initial hidden state is computed from pooled scene tokens plus the
    flattened current reveal state; each action step then advances the GRU,
    and linear heads read out the predicted per-step reveal quantities.
    """

    def __init__(self, config: RevealWMConfig) -> None:
        super().__init__()
        self.config = config
        # Flattened reveal-state width: support-mode logits + corridor logits
        # + persistence horizons + scalar disturbance cost.
        modes = config.num_support_modes
        reveal_dim = modes + modes * config.num_approach_templates + modes + 1
        self.initial = nn.Sequential(
            nn.LayerNorm(config.hidden_dim + reveal_dim),
            nn.Linear(config.hidden_dim + reveal_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.action_encoder = nn.Linear(config.action_dim, config.hidden_dim)
        self.gru = nn.GRU(config.hidden_dim, config.hidden_dim, batch_first=True)
        self.support_mode = nn.Linear(config.hidden_dim, modes)
        self.corridor = nn.Linear(config.hidden_dim, modes * config.num_approach_templates)
        self.persistence = nn.Linear(config.hidden_dim, modes)
        self.disturbance = nn.Linear(config.hidden_dim, 1)

    def _flatten_reveal(self, reveal_state: dict[str, Tensor]) -> Tensor:
        """Concatenate the reveal-state dict into one per-sample feature vector."""
        parts = [
            reveal_state["support_mode_logits"],
            reveal_state["corridor_logits"].flatten(start_dim=1),
            reveal_state["persistence_horizon"],
            reveal_state["disturbance_cost"].unsqueeze(-1),
        ]
        return torch.cat(parts, dim=-1)

    def forward(self, scene_tokens: Tensor, reveal_state: dict[str, Tensor], action_chunk: Tensor) -> dict[str, Tensor]:
        """Predict per-step reveal quantities for ``(batch, steps, action_dim)`` chunks."""
        pooled = scene_tokens.mean(dim=1)
        hidden0 = self.initial(torch.cat([pooled, self._flatten_reveal(reveal_state)], dim=-1))
        rollout, _ = self.gru(self.action_encoder(action_chunk), hidden0.unsqueeze(0))
        batch, horizon, _ = rollout.shape
        corridor = self.corridor(rollout).view(
            batch,
            horizon,
            self.config.num_support_modes,
            self.config.num_approach_templates,
        )
        return {
            "support_mode_logits": self.support_mode(rollout),
            "corridor_logits": corridor,
            "persistence_horizon": self.persistence(rollout),
            "disturbance_cost": self.disturbance(rollout).squeeze(-1),
        }
code/reveal_vla_bimanual/sim_reveal/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Public API of the ``sim_reveal`` package.

Re-exports the base types, the procedural proxy environments, and the three
built-in proxy specs so callers can import them directly from ``sim_reveal``.
"""

from sim_reveal.base import RevealProxyConfig, RevealState, SupportMode
from sim_reveal.procedural_envs import ProceduralRevealEnv, available_proxy_names, make_proxy_env
from sim_reveal.proxy_specs import BAG_PROXY, CLOTH_PROXY, FOLIAGE_PROXY

# Alphabetical, mirrors the imports above.
__all__ = [
    "BAG_PROXY",
    "CLOTH_PROXY",
    "FOLIAGE_PROXY",
    "ProceduralRevealEnv",
    "RevealProxyConfig",
    "RevealState",
    "SupportMode",
    "available_proxy_names",
    "make_proxy_env",
]
code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (542 Bytes). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (653 Bytes). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-310.pyc ADDED
Binary file (1.34 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc ADDED
Binary file (1.88 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (4.5 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (8.42 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-310.pyc ADDED
Binary file (1.37 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/generate_dataset.cpython-311.pyc ADDED
Binary file (2.44 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-310.pyc ADDED
Binary file (868 Bytes). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_smoke.cpython-311.pyc ADDED
Binary file (1.45 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-310.pyc ADDED
Binary file (874 Bytes). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/isaac_wrapper.cpython-311.pyc ADDED
Binary file (1.22 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/labels.cpython-311.pyc ADDED
Binary file (3.57 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-310.pyc ADDED
Binary file (16.7 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc ADDED
Binary file (33.2 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-310.pyc ADDED
Binary file (922 Bytes). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc ADDED
Binary file (1.09 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/teachers.cpython-311.pyc ADDED
Binary file (3.68 kB). View file
 
code/reveal_vla_bimanual/sim_reveal/base.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from enum import IntEnum
5
+
6
+ import numpy as np
7
+
8
+
9
class SupportMode(IntEnum):
    """How the supporting arm maintains the revealed opening.

    HOLD: the arm actively holds the occluder open.
    TRANSFER: the occluder has been handed off / transferred to a support.
    PASSIVE: no active support; the material is left to settle on its own.

    IntEnum so modes double as class indices into per-mode label arrays.
    """

    HOLD = 0
    TRANSFER = 1
    PASSIVE = 2
14
+
15
@dataclass
class RevealState:
    """Privileged "reveal" supervision targets for a single timestep.

    Shapes below follow the validation in ``labels.py``:
    support_mode_logits is (num_modes,), corridor_logits is
    (num_modes, num_templates), persistence_horizon is (num_modes,).
    """

    # (num_modes,) logits over SupportMode classes.
    support_mode_logits: np.ndarray
    # (num_modes, num_templates) per-template corridor feasibility logits.
    corridor_logits: np.ndarray
    # (num_modes,) expected number of future steps each mode keeps a corridor open.
    persistence_horizon: np.ndarray
    # (1,) scalar disturbance cost wrapped in an array.
    disturbance_cost: np.ndarray
    # Optional 2D belief map over the workspace; None when unavailable.
    belief_map: np.ndarray | None = None
22
+
23
+
24
@dataclass
class RevealProxyConfig:
    """Static configuration describing one procedural reveal proxy task."""

    # Unique registry key for the proxy (e.g. foliage / bag / cloth).
    name: str
    # Number of discrete template slots the actor arm can target.
    num_templates: int = 32
    # Horizon used when rolling labels forward for persistence estimates.
    rollout_horizon: int = 5
    # Episode step budget before truncation.
    max_steps: int = 80
    # Keys used to read labels out of the privileged state dict.
    disturbance_key: str = "disturbance_cost"
    success_key: str = "retrieval_success"
    # Free-form string metadata attached to the proxy.
    metadata: dict[str, str] = field(default_factory=dict)
code/reveal_vla_bimanual/sim_reveal/dataset.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Sequence
5
+
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.utils.data import Dataset
9
+
10
+ from sim_reveal.procedural_envs import available_proxy_names, make_proxy_env, render_views_from_state
11
+
12
+
13
def collect_teacher_dataset(
    proxy_names: Sequence[str] | None = None,
    episodes_per_proxy: int = 32,
    resolution: int = 96,
    seed: int = 0,
    chunk_horizon: int = 8,
    rollout_horizon: int = 5,
) -> dict[str, Any]:
    """Roll out the scripted teacher in every proxy env and collect samples.

    For each (proxy, episode) pair the env is stepped with the teacher
    policy; at every step we record the teacher action chunk, the privileged
    reveal labels, and the short label rollout from
    ``teacher_chunk_and_rollout`` (which restores the env state afterwards).

    Args:
        proxy_names: Proxies to collect from; defaults to all registered ones.
        episodes_per_proxy: Number of teacher episodes per proxy.
        resolution: Square render resolution for the stored render states.
        seed: Base seed; each episode gets a distinct derived seed.
        chunk_horizon: Length of each recorded teacher action chunk.
        rollout_horizon: Length of the recorded label rollout.

    Returns:
        Bundle dict with ``samples`` (per-step dicts), a per-proxy ``summary``
        (episode/sample counts, teacher success rate) and the collection
        settings needed to rebuild a dataset.
    """
    proxy_names = tuple(proxy_names or available_proxy_names())
    samples: list[dict[str, Any]] = []
    summary: dict[str, dict[str, float]] = {}

    for proxy_offset, proxy_name in enumerate(proxy_names):
        proxy_samples = 0
        proxy_success = 0
        for episode_idx in range(episodes_per_proxy):
            # One distinct, reproducible seed per (proxy, episode) pair.
            episode_seed = seed + proxy_offset * 10_000 + episode_idx
            env = make_proxy_env(
                proxy_name=proxy_name,
                resolution=resolution,
                seed=episode_seed,
                rollout_horizon=rollout_horizon,
            )
            _, privileged_state = env.reset(seed=episode_seed)
            while True:
                action_chunk, rollout = env.teacher_chunk_and_rollout(
                    chunk_horizon=chunk_horizon,
                    rollout_horizon=rollout_horizon,
                )
                # Fetch the observation once. The previous version called
                # env.get_observation() twice per step, rendering every
                # camera view a second time just to read the "text" field.
                observation = env.get_observation(privileged_state)
                samples.append(
                    {
                        "proxy_name": proxy_name,
                        "episode_id": episode_idx,
                        "render_state": env.render_state(privileged_state),
                        "proprio": observation["proprio"].astype("float32"),
                        "language_goal": observation["text"],
                        "action_chunk": action_chunk.astype("float32"),
                        "support_mode": int(privileged_state["support_mode"]),
                        "corridor_feasible": privileged_state["corridor_feasible"].astype("float32"),
                        "persistence_horizon": privileged_state["persistence_horizon"].astype("float32"),
                        "disturbance_cost": float(privileged_state["disturbance_cost"]),
                        "belief_map": privileged_state["belief_map"].astype("float32"),
                        "rollout_support_mode": rollout["rollout_support_mode"].astype("int64"),
                        "rollout_corridor_feasible": rollout["rollout_corridor_feasible"].astype("float32"),
                        "rollout_persistence_horizon": rollout["rollout_persistence_horizon"].astype("float32"),
                        "rollout_disturbance_cost": rollout["rollout_disturbance_cost"].astype("float32"),
                    }
                )
                proxy_samples += 1
                _, _, terminated, truncated, privileged_state = env.step(env.teacher_action())
                if terminated:
                    proxy_success += 1
                if terminated or truncated:
                    break
        summary[proxy_name] = {
            "episodes": float(episodes_per_proxy),
            "samples": float(proxy_samples),
            "teacher_success": proxy_success / float(max(1, episodes_per_proxy)),
        }
    return {
        "resolution": resolution,
        "chunk_horizon": chunk_horizon,
        "rollout_horizon": rollout_horizon,
        "samples": samples,
        "summary": summary,
    }
78
+
79
+
80
def save_teacher_dataset(output_path: str | Path, dataset_bundle: dict[str, Any]) -> Path:
    """Serialize *dataset_bundle* with ``torch.save``.

    Parent directories are created as needed. Returns the destination path.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    torch.save(dataset_bundle, destination)
    return destination
85
+
86
+
87
def load_teacher_dataset(dataset_path: str | Path) -> dict[str, Any]:
    """Load a bundle previously written by ``save_teacher_dataset``.

    Tensors are mapped onto CPU; ``weights_only=False`` is required because
    the bundle contains arbitrary Python objects, not just tensors.
    """
    path = Path(dataset_path)
    return torch.load(path, map_location="cpu", weights_only=False)
89
+
90
+
91
class RevealOfflineDataset(Dataset[dict[str, Any]]):
    """Map-style dataset over pre-collected teacher samples.

    Camera views are re-rendered on the fly from each sample's stored
    ``render_state``, so the serialized bundle stays small.
    """

    def __init__(self, samples: Sequence[dict[str, Any]], resolution: int = 96) -> None:
        """Wrap *samples* (as produced by ``collect_teacher_dataset``)."""
        self.samples = list(samples)
        self.resolution = resolution

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Render and tensorize one sample.

        Images come back as (views=3, C, H, W) float in [0, 1], views ordered
        front, wrist_left, wrist_right.
        """
        sample = self.samples[index]
        images = render_views_from_state(
            proxy_name=sample["proxy_name"],
            render_state=sample["render_state"],
            resolution=self.resolution,
        )
        # Stack (view, H, W, C) -> (view, C, H, W) and normalize to [0, 1].
        # Fixed: the original wrapped the stacked tensor in
        # torch.from_numpy(...numpy()), a pointless torch->numpy->torch
        # round trip that copied every image an extra time.
        stacked = (
            torch.stack(
                [
                    torch.from_numpy(images["front"]),
                    torch.from_numpy(images["wrist_left"]),
                    torch.from_numpy(images["wrist_right"]),
                ],
                dim=0,
            )
            .permute(0, 3, 1, 2)
            .float()
            / 255.0
        )
        return {
            "images": stacked,
            "proprio": torch.as_tensor(sample["proprio"], dtype=torch.float32),
            "texts": sample["language_goal"],
            "action_chunk": torch.as_tensor(sample["action_chunk"], dtype=torch.float32),
            "support_mode": torch.as_tensor(sample["support_mode"], dtype=torch.long),
            "corridor_feasible": torch.as_tensor(sample["corridor_feasible"], dtype=torch.float32),
            "persistence_horizon": torch.as_tensor(sample["persistence_horizon"], dtype=torch.float32),
            "disturbance_cost": torch.as_tensor(sample["disturbance_cost"], dtype=torch.float32),
            "belief_map": torch.as_tensor(sample["belief_map"], dtype=torch.float32).unsqueeze(0),
            "rollout_support_mode": torch.as_tensor(sample["rollout_support_mode"], dtype=torch.long),
            "rollout_corridor_feasible": torch.as_tensor(sample["rollout_corridor_feasible"], dtype=torch.float32),
            "rollout_persistence_horizon": torch.as_tensor(sample["rollout_persistence_horizon"], dtype=torch.float32),
            "rollout_disturbance_cost": torch.as_tensor(sample["rollout_disturbance_cost"], dtype=torch.float32),
            "proxy_name": sample["proxy_name"],
            "episode_id": sample["episode_id"],
        }
133
+
134
+
135
def dataset_from_bundle(dataset_bundle: dict[str, Any], resolution: int | None = None) -> RevealOfflineDataset:
    """Build a ``RevealOfflineDataset`` from a saved bundle dict.

    When *resolution* is not given (or falsy), the resolution recorded in
    the bundle is used.
    """
    chosen_resolution = resolution or int(dataset_bundle["resolution"])
    return RevealOfflineDataset(dataset_bundle["samples"], resolution=chosen_resolution)
code/reveal_vla_bimanual/sim_reveal/generate_dataset.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+
7
+ from sim_reveal.dataset import collect_teacher_dataset, save_teacher_dataset
8
+
9
+
10
def main() -> None:
    """CLI entry point: collect a teacher dataset and save it as one bundle.

    Prints a JSON summary (output path, resolution, sample count, per-proxy
    stats) to stdout when done.
    """
    parser = argparse.ArgumentParser()
    # None means "all registered proxies" inside collect_teacher_dataset.
    parser.add_argument("--proxies", nargs="*", default=None)
    parser.add_argument("--episodes-per-proxy", type=int, default=32)
    parser.add_argument("--resolution", type=int, default=96)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--chunk-horizon", type=int, default=8)
    parser.add_argument("--rollout-horizon", type=int, default=5)
    parser.add_argument("--output-path", default="/workspace/data/reveal_proxy/reveal_proxy_teacher.pt")
    args = parser.parse_args()

    dataset_bundle = collect_teacher_dataset(
        proxy_names=args.proxies,
        episodes_per_proxy=args.episodes_per_proxy,
        resolution=args.resolution,
        seed=args.seed,
        chunk_horizon=args.chunk_horizon,
        rollout_horizon=args.rollout_horizon,
    )
    output_path = save_teacher_dataset(Path(args.output_path), dataset_bundle)
    # Machine-readable run report for downstream tooling / logs.
    payload = {
        "output_path": str(output_path),
        "resolution": dataset_bundle["resolution"],
        "num_samples": len(dataset_bundle["samples"]),
        "summary": dataset_bundle["summary"],
    }
    print(json.dumps(payload, indent=2))


if __name__ == "__main__":
    main()
code/reveal_vla_bimanual/sim_reveal/isaac_smoke.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+
6
+ from sim_reveal.isaac_wrapper import IsaacRevealRuntime
7
+
8
+
9
def main() -> None:
    """Smoke test: boot Isaac Sim via the wrapper and report its version.

    Prints a JSON status payload; the runtime is always closed, even if the
    import or version probe fails.
    """
    parser = argparse.ArgumentParser()
    # --visible runs with the UI; default is headless.
    parser.add_argument("--visible", action="store_true")
    args = parser.parse_args()

    runtime = IsaacRevealRuntime(headless=not args.visible)
    try:
        # Imported after the runtime exists — SimulationApp is created in the
        # wrapper's __post_init__ before this module-level probe.
        import isaacsim

        payload = {
            "headless": not args.visible,
            "isaacsim_version": getattr(isaacsim, "__version__", "unknown"),
            "status": "ok",
        }
        print(json.dumps(payload, indent=2))
    finally:
        runtime.close()


if __name__ == "__main__":
    main()
code/reveal_vla_bimanual/sim_reveal/isaac_wrapper.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass
class IsaacRevealRuntime:
    """Thin lifecycle wrapper around Isaac Sim's ``SimulationApp``.

    Constructing an instance starts the simulation app; call :meth:`close`
    when done. NOTE(review): if ``SimulationApp`` construction raises,
    ``_simulation_app`` is never set and ``close`` would fail — acceptable
    for a smoke-test helper, but verify before wider use.
    """

    # Run without a UI by default.
    headless: bool = True

    def __post_init__(self) -> None:
        # Imported lazily so this module can be imported on machines
        # without Isaac Sim installed.
        from isaacsim import SimulationApp

        self._simulation_app = SimulationApp({"headless": self.headless})

    def close(self) -> None:
        """Shut down the underlying SimulationApp."""
        self._simulation_app.close()
code/reveal_vla_bimanual/sim_reveal/labels.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+
7
+ from sim_reveal.base import RevealState, SupportMode
8
+
9
+
10
def privileged_state_to_reveal_labels(
    state: dict[str, Any],
    num_modes: int = 3,
    num_templates: int = 32,
    rollout_horizon: int = 5,
) -> RevealState:
    """Convert a simulator privileged-state dict into ``RevealState`` labels.

    Hard class labels are encoded as saturated logits (+4 / -4) so they can
    feed sigmoid/softmax heads directly. Shapes are validated against
    (num_modes, num_templates) and (num_modes,).

    Raises:
        ValueError: If corridor or persistence arrays have the wrong shape.
    """
    mode_index = int(state["support_mode"])
    mode_logits = np.full((num_modes,), -4.0, dtype=np.float32)
    mode_logits[mode_index] = 4.0

    feasibility = np.asarray(state["corridor_feasible"], dtype=np.float32)
    if feasibility.shape != (num_modes, num_templates):
        raise ValueError(
            f"Expected corridor_feasible shape {(num_modes, num_templates)}, got {feasibility.shape}"
        )
    feasibility_logits = np.where(feasibility > 0.5, 4.0, -4.0).astype(np.float32)

    horizon = np.asarray(state["persistence_horizon"], dtype=np.float32)
    if horizon.shape != (num_modes,):
        raise ValueError(f"Expected persistence_horizon shape {(num_modes,)}, got {horizon.shape}")
    horizon = np.clip(horizon, 0.0, float(rollout_horizon))

    cost = np.asarray([state["disturbance_cost"]], dtype=np.float32)

    belief = state.get("belief_map")
    belief = None if belief is None else np.asarray(belief, dtype=np.float32)

    return RevealState(
        support_mode_logits=mode_logits,
        corridor_logits=feasibility_logits,
        persistence_horizon=horizon,
        disturbance_cost=cost,
        belief_map=belief,
    )
44
+
45
+
46
def reocclusion_rate(corridor_open_history: np.ndarray) -> float:
    """Fraction of consecutive step pairs where an open corridor closed.

    The history is a 1D sequence of open/closed indicators (values > 0.5
    count as open). Histories with fewer than two entries yield 0.0.

    Raises:
        ValueError: If the input is not one-dimensional.
    """
    history = np.asarray(corridor_open_history, dtype=np.float32)
    if history.ndim != 1:
        raise ValueError("corridor_open_history must be 1D.")
    if history.size < 2:
        return 0.0
    was_open = history[:-1] > 0.5
    now_closed = history[1:] <= 0.5
    return float((was_open & now_closed).mean())
54
+
55
+
56
def infer_support_mode_from_flags(holding: bool, transferred: bool) -> SupportMode:
    """Map boolean contact flags to a ``SupportMode``.

    HOLD takes priority over TRANSFER; with neither flag set the mode is
    PASSIVE.
    """
    if holding:
        return SupportMode.HOLD
    return SupportMode.TRANSFER if transferred else SupportMode.PASSIVE
code/reveal_vla_bimanual/sim_reveal/procedural_envs.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+
8
+ from sim_reveal.base import RevealProxyConfig, SupportMode
9
+ from sim_reveal.proxy_specs import BAG_PROXY, CLOTH_PROXY, FOLIAGE_PROXY
10
+
11
+
12
@dataclass(frozen=True)
class ProxyDynamics:
    """Scalar dynamics coefficients for one procedural proxy environment."""

    # Per-step opening decay under each SupportMode.
    hold_decay: float
    transfer_decay: float
    passive_decay: float
    # How strongly open/close commands add to the disturbance level.
    disturbance_gain: float
    # Per-step fraction by which disturbance settles back down.
    settle_rate: float
    # Target opening the teacher tries to reach before committing to a mode.
    desired_opening: float
    # Support mode the scripted teacher prefers once the opening is ready.
    preferred_mode: SupportMode
    # Corridor-width multipliers for TRANSFER / PASSIVE relative to HOLD (1.0).
    transfer_support_factor: float
    passive_support_factor: float
    # Additive offset in the visibility computation.
    visibility_bias: float
    # Minimum visibility required before a retrieval can succeed.
    retrieve_visibility_threshold: float
    # Base RGB color used when rendering this proxy's views.
    palette: tuple[float, float, float]
26
+
27
+
28
# Registry of proxy task configs, keyed by proxy name.
PROXY_CONFIGS: dict[str, RevealProxyConfig] = {
    FOLIAGE_PROXY.name: FOLIAGE_PROXY,
    BAG_PROXY.name: BAG_PROXY,
    CLOTH_PROXY.name: CLOTH_PROXY,
}

# Hand-tuned scalar dynamics per proxy. Each proxy favors a different
# support mode: foliage must be actively held, a bag mouth stays open
# after transfer, cloth mostly stays where it is left (passive).
PROXY_DYNAMICS: dict[str, ProxyDynamics] = {
    FOLIAGE_PROXY.name: ProxyDynamics(
        hold_decay=0.02,
        transfer_decay=0.07,
        passive_decay=0.15,
        disturbance_gain=0.06,
        settle_rate=0.03,
        desired_opening=0.60,
        preferred_mode=SupportMode.HOLD,
        transfer_support_factor=0.76,
        passive_support_factor=0.42,
        visibility_bias=0.03,
        retrieve_visibility_threshold=0.42,
        palette=(0.16, 0.30, 0.12),
    ),
    BAG_PROXY.name: ProxyDynamics(
        hold_decay=0.04,
        transfer_decay=0.03,
        passive_decay=0.12,
        disturbance_gain=0.05,
        settle_rate=0.02,
        desired_opening=0.68,
        preferred_mode=SupportMode.TRANSFER,
        transfer_support_factor=0.96,
        passive_support_factor=0.55,
        visibility_bias=0.06,
        retrieve_visibility_threshold=0.48,
        palette=(0.26, 0.17, 0.10),
    ),
    CLOTH_PROXY.name: ProxyDynamics(
        hold_decay=0.03,
        transfer_decay=0.05,
        passive_decay=0.04,
        disturbance_gain=0.04,
        settle_rate=0.04,
        desired_opening=0.50,
        preferred_mode=SupportMode.PASSIVE,
        transfer_support_factor=0.82,
        passive_support_factor=0.90,
        visibility_bias=0.08,
        retrieve_visibility_threshold=0.38,
        palette=(0.24, 0.24, 0.29),
    ),
}

# Natural-language goal string attached to observations for each proxy.
PROXY_GOALS = {
    FOLIAGE_PROXY.name: "create a gap in the foliage and retrieve the target",
    BAG_PROXY.name: "open the bag mouth and retrieve the target object",
    CLOTH_PROXY.name: "lift the top layer enough to retrieve the hidden object",
}
84
+
85
+
86
def available_proxy_names() -> tuple[str, ...]:
    """Return the names of all registered procedural proxy environments."""
    return tuple(PROXY_CONFIGS)
88
+
89
+
90
def make_proxy_env(
    proxy_name: str,
    resolution: int = 96,
    seed: int = 0,
    num_templates: int = 32,
    rollout_horizon: int = 5,
    max_steps: int | None = None,
) -> "ProceduralRevealEnv":
    """Factory wrapper: construct a ``ProceduralRevealEnv`` for *proxy_name*.

    All arguments are forwarded unchanged; ``max_steps=None`` lets the env
    fall back to the proxy config's step budget.
    """
    return ProceduralRevealEnv(
        proxy_name=proxy_name,
        resolution=resolution,
        seed=seed,
        num_templates=num_templates,
        rollout_horizon=rollout_horizon,
        max_steps=max_steps,
    )
106
+
107
+
108
class ProceduralRevealEnv:
    """Procedural two-arm "reveal and retrieve" proxy environment.

    A scalar ``opening`` models how far the occluder (foliage gap, bag
    mouth, cloth layer) is held open and ``disturbance`` models how much
    the scene has been perturbed. A scripted actor arm makes progress
    toward a target template slot whenever a feasible corridor covers it;
    the episode succeeds when the target is retrieved. All dynamics are
    first-order scalar updates, so episodes are cheap and reproducible
    from the seed.

    Action layout (14 channels, inferred from teacher_action/step usage):
    [0] open/close command, [1] transfer flag, [2] passive flag,
    [6] hold flag, [7] target template (normalized to [-1, 1]),
    [8] actor reach effort, [13] retrieve command. Remaining channels are
    unused by this proxy.
    """

    # Camera order used when stacking rendered views into observations.
    camera_names = ("front", "wrist_left", "wrist_right")

    def __init__(
        self,
        proxy_name: str,
        resolution: int = 96,
        seed: int = 0,
        num_templates: int = 32,
        rollout_horizon: int = 5,
        max_steps: int | None = None,
    ) -> None:
        """Create the env for *proxy_name* and immediately reset it.

        Raises:
            KeyError: If *proxy_name* is not a registered proxy.
        """
        if proxy_name not in PROXY_CONFIGS:
            raise KeyError(f"Unknown proxy: {proxy_name}")
        self.proxy = PROXY_CONFIGS[proxy_name]
        self.dynamics = PROXY_DYNAMICS[proxy_name]
        self.proxy_name = proxy_name
        self.resolution = resolution
        self.num_templates = num_templates
        self.rollout_horizon = rollout_horizon
        # Fall back to the proxy config's step budget when not overridden.
        self.max_steps = max_steps or self.proxy.max_steps
        self.rng = np.random.default_rng(seed)
        self.reset(seed=seed)

    def clone_state(self) -> dict[str, Any]:
        """Snapshot all mutable episode state (used by teacher lookahead)."""
        return {
            "step_count": self.step_count,
            "opening": self.opening,
            "disturbance": self.disturbance,
            "target_template": self.target_template,
            "target_depth": self.target_depth,
            "holding": self.holding,
            "transferred": self.transferred,
            "retrieved": self.retrieved,
            "actor_progress": self.actor_progress,
            "last_actor_template": self.last_actor_template,
            "visibility_trace": list(self.visibility_trace),
            "corridor_trace": list(self.corridor_trace),
        }

    def restore_state(self, state: dict[str, Any]) -> None:
        """Restore a snapshot produced by :meth:`clone_state`."""
        self.step_count = int(state["step_count"])
        self.opening = float(state["opening"])
        self.disturbance = float(state["disturbance"])
        self.target_template = int(state["target_template"])
        self.target_depth = float(state["target_depth"])
        self.holding = bool(state["holding"])
        self.transferred = bool(state["transferred"])
        self.retrieved = bool(state["retrieved"])
        self.actor_progress = float(state["actor_progress"])
        self.last_actor_template = int(state["last_actor_template"])
        self.visibility_trace = list(state["visibility_trace"])
        self.corridor_trace = list(state["corridor_trace"])

    def reset(self, seed: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]:
        """Randomize a fresh episode; return (observation, privileged_state)."""
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        self.step_count = 0
        self.opening = float(self.rng.uniform(0.08, 0.22))
        self.disturbance = float(self.rng.uniform(0.02, 0.12))
        # Keep the target away from the template edges.
        self.target_template = int(self.rng.integers(4, self.num_templates - 4))
        self.target_depth = float(self.rng.uniform(0.15, 0.45))
        self.holding = False
        self.transferred = False
        self.retrieved = False
        self.actor_progress = 0.0
        self.last_actor_template = self.target_template
        privileged_state = self.get_privileged_state()
        self.visibility_trace = [float(privileged_state["visibility"])]
        # Records whether any corridor cell is open under the current mode.
        self.corridor_trace = [float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any())]
        return self.get_observation(privileged_state), privileged_state

    def _normalized_template(self, template_index: int) -> float:
        """Map a template index onto [-1, 1] for the action encoding."""
        return (template_index / float(self.num_templates - 1)) * 2.0 - 1.0

    def _current_support_mode(self) -> SupportMode:
        """Derive the current SupportMode from the holding/transferred flags."""
        if self.holding:
            return SupportMode.HOLD
        if self.transferred:
            return SupportMode.TRANSFER
        return SupportMode.PASSIVE

    def _mode_from_action(self, action: np.ndarray) -> SupportMode:
        """Decode the commanded support mode from action channels 6/1/2.

        Scores are tanh-squashed to [0, 1]; HOLD wins ties, and TRANSFER
        additionally requires the opening to already be at least 0.32.
        """
        hold_score = (np.tanh(float(action[6])) + 1.0) * 0.5
        transfer_score = (np.tanh(float(action[1])) + 1.0) * 0.5
        passive_score = (np.tanh(float(action[2])) + 1.0) * 0.5
        if hold_score >= max(transfer_score, passive_score):
            return SupportMode.HOLD
        if transfer_score >= passive_score and self.opening >= 0.32:
            return SupportMode.TRANSFER
        return SupportMode.PASSIVE

    def _visibility(self, opening: float | None = None, disturbance: float | None = None) -> float:
        """Target visibility in [0, 1]; defaults to the current env state."""
        opening = self.opening if opening is None else float(opening)
        disturbance = self.disturbance if disturbance is None else float(disturbance)
        visibility = (
            1.35 * opening
            - 0.58 * disturbance
            - 0.25 * self.target_depth
            + self.dynamics.visibility_bias
        )
        return float(np.clip(visibility, 0.0, 1.0))

    def _mode_factor(self, mode: SupportMode) -> float:
        """Corridor-width multiplier for *mode* (HOLD is the 1.0 baseline)."""
        if mode == SupportMode.HOLD:
            return 1.0
        if mode == SupportMode.TRANSFER:
            return self.dynamics.transfer_support_factor
        return self.dynamics.passive_support_factor

    def _mode_decay(self, mode: SupportMode) -> float:
        """Per-step opening decay for *mode* from the proxy dynamics."""
        if mode == SupportMode.HOLD:
            return self.dynamics.hold_decay
        if mode == SupportMode.TRANSFER:
            return self.dynamics.transfer_decay
        return self.dynamics.passive_decay

    def _corridor_for_mode(
        self,
        mode: SupportMode,
        opening: float | None = None,
        disturbance: float | None = None,
    ) -> np.ndarray:
        """Binary (num_templates,) mask of reachable slots under *mode*.

        The corridor is a band of ``width`` slots on each side of the
        target; it collapses to all zeros when the effective opening is too
        small or visibility drops below 70% of the retrieval threshold.
        """
        opening = self.opening if opening is None else float(opening)
        disturbance = self.disturbance if disturbance is None else float(disturbance)
        visibility = self._visibility(opening, disturbance)
        effective = opening * self._mode_factor(mode) - 0.35 * disturbance - 0.18 * self.target_depth
        width = int(np.floor(max(0.0, effective) * 8.0))
        corridor = np.zeros((self.num_templates,), dtype=np.float32)
        if visibility < self.dynamics.retrieve_visibility_threshold * 0.7 or width <= 0:
            return corridor
        low = max(0, self.target_template - width)
        high = min(self.num_templates, self.target_template + width + 1)
        corridor[low:high] = 1.0
        return corridor

    def _persistence_for_mode(self, mode: SupportMode) -> float:
        """Count how many future steps *mode* keeps a corridor open.

        Rolls the scalar opening/disturbance dynamics forward up to
        ``rollout_horizon`` steps without mutating the env; HOLD also gets
        a small per-step opening boost (the arm keeps working).
        """
        opening = self.opening
        disturbance = self.disturbance
        persisted = 0.0
        for _ in range(self.rollout_horizon):
            if self._corridor_for_mode(mode, opening, disturbance).any():
                persisted += 1.0
            else:
                break
            opening = float(np.clip(opening - self._mode_decay(mode) + (0.035 if mode == SupportMode.HOLD else 0.0), 0.0, 1.0))
            disturbance = float(np.clip(disturbance * (1.0 - self.dynamics.settle_rate), 0.0, 1.0))
        return persisted

    def _belief_map(self, visibility: float) -> np.ndarray:
        """32x32 Gaussian blob over the target location, scaled by visibility.

        Lower visibility widens the blob (more positional uncertainty) and
        dims it (lower peak mass).
        """
        side = 32
        x = np.linspace(0.0, 1.0, side, dtype=np.float32)
        y = np.linspace(0.0, 1.0, side, dtype=np.float32)
        yy, xx = np.meshgrid(y, x, indexing="ij")
        center_x = self.target_template / float(self.num_templates - 1)
        center_y = 0.72 - 0.25 * self.target_depth
        sigma = 0.08 + 0.05 * (1.0 - visibility)
        belief = np.exp(-(((xx - center_x) ** 2) + ((yy - center_y) ** 2)) / (2.0 * sigma**2))
        belief *= visibility
        return belief.astype(np.float32)

    def get_privileged_state(self) -> dict[str, Any]:
        """Assemble the full privileged label dict for the current state."""
        support_mode = int(self._current_support_mode())
        # (num_modes, num_templates): corridor masks for every mode.
        corridor = np.stack(
            [self._corridor_for_mode(mode) for mode in SupportMode],
            axis=0,
        )
        persistence = np.asarray([self._persistence_for_mode(mode) for mode in SupportMode], dtype=np.float32)
        visibility = self._visibility()
        # Over-opening past the desired amount is itself penalized as disturbance.
        disturbance_cost = float(np.clip(self.disturbance + 0.08 * max(0.0, self.opening - self.dynamics.desired_opening), 0.0, 1.0))
        return {
            "support_mode": support_mode,
            "corridor_feasible": corridor,
            "persistence_horizon": persistence,
            "disturbance_cost": disturbance_cost,
            "belief_map": self._belief_map(visibility),
            "visibility": visibility,
            "retrieval_success": bool(self.retrieved),
            "target_template": self.target_template,
        }

    def render_state(self, privileged_state: dict[str, Any] | None = None) -> dict[str, Any]:
        """Minimal serializable dict from which camera views can be re-rendered."""
        privileged_state = privileged_state or self.get_privileged_state()
        current_mode = int(privileged_state["support_mode"])
        return {
            "opening": float(self.opening),
            "disturbance": float(self.disturbance),
            "target_template": int(self.target_template),
            "support_mode": current_mode,
            "visibility": float(privileged_state["visibility"]),
            "actor_template": int(self.last_actor_template),
            "actor_progress": float(self.actor_progress),
            "corridor_current": privileged_state["corridor_feasible"][current_mode].astype(np.float32),
            "step_fraction": float(self.step_count / max(1, self.max_steps)),
        }

    def _proprio(self, privileged_state: dict[str, Any]) -> np.ndarray:
        """32-dim proprioceptive feature vector (indices 15..31 stay zero)."""
        mode = privileged_state["support_mode"]
        features = np.zeros((32,), dtype=np.float32)
        features[0] = self.opening
        features[1] = self.disturbance
        features[2] = privileged_state["visibility"]
        # One-hot support mode at indices 3..5.
        features[3 + mode] = 1.0
        features[6] = self.target_template / float(self.num_templates - 1)
        features[7] = self.last_actor_template / float(self.num_templates - 1)
        features[8] = self.step_count / float(max(1, self.max_steps))
        # Per-mode persistence, normalized by the rollout horizon.
        features[9:12] = privileged_state["persistence_horizon"] / float(self.rollout_horizon)
        features[12] = float(privileged_state["corridor_feasible"][mode].any())
        features[13] = float(self.retrieved)
        features[14] = self.actor_progress
        return features

    def get_observation(self, privileged_state: dict[str, Any] | None = None) -> dict[str, Any]:
        """Render cameras and package the policy-facing observation dict."""
        privileged_state = privileged_state or self.get_privileged_state()
        render_state = self.render_state(privileged_state)
        images = render_views_from_state(
            proxy_name=self.proxy_name,
            render_state=render_state,
            resolution=self.resolution,
            num_templates=self.num_templates,
        )
        return {
            "images": np.stack([images[camera] for camera in self.camera_names], axis=0),
            "proprio": self._proprio(privileged_state),
            "text": PROXY_GOALS[self.proxy_name],
            "camera_names": self.camera_names,
            "render_state": render_state,
        }

    def teacher_action(self) -> np.ndarray:
        """Scripted expert: open first, then hold/commit, then retrieve.

        Strategy: keep opening under HOLD until ``desired_opening`` is
        reached; switch to the proxy's preferred mode once that mode would
        persist for at least 2 steps; command retrieval only when the
        corridor covers the target, visibility is high enough, and the
        actor has made sufficient progress.
        """
        privileged_state = self.get_privileged_state()
        preferred_mode = self.dynamics.preferred_mode
        if self.opening < self.dynamics.desired_opening:
            chosen_mode = SupportMode.HOLD
            open_cmd = 0.95
        elif privileged_state["persistence_horizon"][preferred_mode] >= 2.0:
            chosen_mode = preferred_mode
            open_cmd = 0.12
        else:
            chosen_mode = SupportMode.HOLD
            open_cmd = 0.30

        corridor = privileged_state["corridor_feasible"][int(chosen_mode)]
        actor_ready = bool(corridor[self.target_template] > 0.5)
        retrieve = (
            actor_ready
            and privileged_state["visibility"] >= self.dynamics.retrieve_visibility_threshold
            and self.actor_progress >= 0.55
        )
        action = np.zeros((14,), dtype=np.float32)
        action[0] = np.float32(open_cmd)
        action[1] = np.float32(1.0 if chosen_mode == SupportMode.TRANSFER else -1.0)
        action[2] = np.float32(1.0 if chosen_mode == SupportMode.PASSIVE else -1.0)
        action[6] = np.float32(1.0 if chosen_mode == SupportMode.HOLD else -1.0)
        action[7] = np.float32(self._normalized_template(self.target_template))
        action[8] = np.float32(1.0 if actor_ready else 0.2)
        action[13] = np.float32(1.0 if retrieve else -1.0)
        return action

    def teacher_chunk_and_rollout(
        self,
        chunk_horizon: int = 8,
        rollout_horizon: int | None = None,
    ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
        """Simulate the teacher forward and record actions plus label rollout.

        Steps the env up to *chunk_horizon* teacher actions, collecting the
        first *rollout_horizon* privileged labels along the way, then
        restores the pre-call snapshot — the env is left unchanged. Short
        episodes are padded: actions with zeros, labels with the (restored-
        state-adjacent) final state's labels.
        """
        rollout_horizon = rollout_horizon or self.rollout_horizon
        snapshot = self.clone_state()
        action_chunk: list[np.ndarray] = []
        rollout_support_mode = []
        rollout_corridor = []
        rollout_persistence = []
        rollout_disturbance = []
        for step in range(chunk_horizon):
            action = self.teacher_action()
            action_chunk.append(action)
            _, _, terminated, truncated, privileged_state = self.step(action)
            if step < rollout_horizon:
                rollout_support_mode.append(privileged_state["support_mode"])
                rollout_corridor.append(privileged_state["corridor_feasible"])
                rollout_persistence.append(privileged_state["persistence_horizon"])
                rollout_disturbance.append(privileged_state["disturbance_cost"])
            if terminated or truncated:
                break
        # Pad to fixed chunk / rollout lengths for batched training.
        while len(action_chunk) < chunk_horizon:
            action_chunk.append(np.zeros((14,), dtype=np.float32))
        while len(rollout_support_mode) < rollout_horizon:
            rollout_support_mode.append(int(self._current_support_mode()))
            rollout_corridor.append(self.get_privileged_state()["corridor_feasible"])
            rollout_persistence.append(self.get_privileged_state()["persistence_horizon"])
            rollout_disturbance.append(self.get_privileged_state()["disturbance_cost"])
        self.restore_state(snapshot)
        return np.stack(action_chunk, axis=0).astype(np.float32), {
            "rollout_support_mode": np.asarray(rollout_support_mode, dtype=np.int64),
            "rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32),
            "rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32),
            "rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32),
        }

    def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
        """Advance one step; returns (obs, reward, terminated, truncated, privileged).

        Order matters here: mode flags are updated first, then the
        opening/disturbance scalars, then actor progress (which can trigger
        a corridor "shock" collapse when persistence is low), and finally
        the retrieval check.
        """
        action = np.asarray(action, dtype=np.float32)
        mode = self._mode_from_action(action)
        self.holding = mode == SupportMode.HOLD
        self.transferred = mode == SupportMode.TRANSFER
        open_cmd = float(np.clip(action[0], -1.0, 1.0))
        actor_reach = float((np.tanh(float(action[8])) + 1.0) * 0.5)
        retrieve_cmd = float((np.tanh(float(action[13])) + 1.0) * 0.5)
        # Decode the commanded template slot from [-1, 1] back to an index.
        self.last_actor_template = int(
            np.clip(
                round(((float(np.clip(action[7], -1.0, 1.0)) + 1.0) * 0.5) * (self.num_templates - 1)),
                0,
                self.num_templates - 1,
            )
        )

        support_bonus = {SupportMode.HOLD: 0.08, SupportMode.TRANSFER: 0.04, SupportMode.PASSIVE: 0.0}[mode]
        closure = self._mode_decay(mode)
        self.opening = float(
            np.clip(
                self.opening + 0.16 * open_cmd + support_bonus - closure - 0.05 * self.disturbance,
                0.0,
                1.0,
            )
        )
        # Disturbance grows with commanded motion and over-opening, settles otherwise.
        self.disturbance = float(
            np.clip(
                self.disturbance
                + self.dynamics.disturbance_gain * abs(open_cmd)
                + 0.025 * actor_reach
                + 0.05 * max(0.0, self.opening - self.dynamics.desired_opening)
                - self.dynamics.settle_rate,
                0.0,
                1.0,
            )
        )

        self.step_count += 1
        privileged_state = self.get_privileged_state()
        corridor = privileged_state["corridor_feasible"][privileged_state["support_mode"]]
        if corridor[self.last_actor_template] > 0.5 and actor_reach >= 0.55:
            # Progress scales with how long the current mode will keep the
            # corridor open; reaching through a fragile corridor also
            # "shocks" the opening partially closed.
            persistence_ratio = privileged_state["persistence_horizon"][privileged_state["support_mode"]] / float(
                max(1, self.rollout_horizon)
            )
            self.actor_progress = float(np.clip(self.actor_progress + 0.55 * persistence_ratio, 0.0, 1.0))
            shock = 0.16 * max(0.0, 0.8 - persistence_ratio)
            if shock > 0.0:
                self.opening = float(np.clip(self.opening - shock, 0.0, 1.0))
                # Re-derive labels after the shock changed the opening.
                privileged_state = self.get_privileged_state()
                corridor = privileged_state["corridor_feasible"][privileged_state["support_mode"]]
        else:
            # Reaching without a corridor (or too timidly) loses progress.
            self.actor_progress = float(np.clip(self.actor_progress - 0.20, 0.0, 1.0))
        success = bool(
            retrieve_cmd >= 0.55
            and self.actor_progress >= 0.80
            and corridor[self.last_actor_template] > 0.5
            and privileged_state["visibility"] >= self.dynamics.retrieve_visibility_threshold
            and self.disturbance < 0.9
        )
        if success:
            self.retrieved = True
            privileged_state["retrieval_success"] = True

        self.visibility_trace.append(float(privileged_state["visibility"]))
        self.corridor_trace.append(float(corridor.any()))

        # Sparse success reward plus a small visibility-vs-disturbance shaping term.
        reward = 1.0 if success else (0.08 * privileged_state["visibility"] - 0.03 * privileged_state["disturbance_cost"])
        terminated = bool(self.retrieved)
        truncated = bool(self.step_count >= self.max_steps)
        return self.get_observation(privileged_state), float(reward), terminated, truncated, privileged_state
475
+
476
+
477
def render_views_from_state(
    proxy_name: str,
    render_state: dict[str, Any],
    resolution: int,
    num_templates: int = 32,
) -> dict[str, np.ndarray]:
    """Render the three synthetic camera views for a proxy environment state.

    Args:
        proxy_name: Key into ``PROXY_DYNAMICS`` selecting the palette/dynamics.
        render_state: Flat dict of scalars/arrays describing the current state.
            Keys read here: ``opening``, ``disturbance``, ``target_template``,
            ``support_mode``, ``visibility``, ``actor_template``,
            ``actor_progress``, ``corridor_current``, ``step_fraction``.
        resolution: Height and width (square) of each rendered image, in pixels.
        num_templates: Number of template columns in the wrist-right view.

    Returns:
        Dict with keys ``"front"``, ``"wrist_left"``, ``"wrist_right"``, each a
        ``(resolution, resolution, 3)`` uint8 RGB image.
    """
    dynamics = PROXY_DYNAMICS[proxy_name]
    opening = float(render_state["opening"])
    disturbance = float(render_state["disturbance"])
    target_template = int(render_state["target_template"])
    support_mode = int(render_state["support_mode"])
    visibility = float(render_state["visibility"])
    actor_template = int(render_state["actor_template"])
    actor_progress = float(render_state["actor_progress"])
    corridor_current = np.asarray(render_state["corridor_current"], dtype=np.float32)
    step_fraction = float(render_state["step_fraction"])

    height = width = resolution
    # Base canvas tinted by the proxy's palette (assumed RGB triple in [0, 1] —
    # TODO confirm against PROXY_DYNAMICS definitions).
    base = np.ones((height, width, 3), dtype=np.float32)
    base *= np.asarray(dynamics.palette, dtype=np.float32)

    # Normalized pixel coordinate grids; xx varies along width, yy along height.
    x = np.linspace(0.0, 1.0, width, dtype=np.float32)
    y = np.linspace(0.0, 1.0, height, dtype=np.float32)
    yy, xx = np.meshgrid(y, x, indexing="ij")
    # Horizontal position of the target template's gap; widens with `opening`.
    center_x = target_template / float(max(1, num_templates - 1))
    gap_width = 0.04 + 0.18 * opening
    gap_mask = np.abs(xx - center_x) <= gap_width
    stripe_mask = (np.sin(xx * np.pi * 18.0) > 0.2).astype(np.float32)

    # Front view: striped backdrop, brightened gap, target blob, blue tint that
    # grows with disturbance and episode progress.
    front = base.copy()
    front[..., 1] += 0.22 * stripe_mask
    front[..., 0] += 0.07 * stripe_mask
    front[gap_mask, :] = np.clip(front[gap_mask, :] + np.asarray([0.18, 0.18, 0.18], dtype=np.float32), 0.0, 1.0)
    target_mask = ((xx - center_x) ** 2 + (yy - 0.76) ** 2) <= (0.03 + 0.015 * visibility) ** 2
    front[target_mask, 0] = np.clip(front[target_mask, 0] + 0.55 * visibility, 0.0, 1.0)
    front[target_mask, 1] *= 0.55
    front[..., 2] = np.clip(front[..., 2] + 0.18 * disturbance + 0.05 * step_fraction, 0.0, 1.0)
    # BUGFIX: the stripe additions above are not clipped; with a bright palette
    # channel values can exceed 1.0, and `(x * 255).astype(np.uint8)` wraps
    # modulo 256 (e.g. 1.12 -> ~29), corrupting pixels. Clip the whole image.
    front = np.clip(front, 0.0, 1.0)

    # Wrist-left view: left third = gripper opening bar, middle third =
    # disturbance bar, right third = solid support-mode color.
    wrist_left = np.full((height, width, 3), 0.12, dtype=np.float32)
    open_rows = int(opening * height)
    wrist_left[height - open_rows :, : width // 3, 1] = 0.75
    wrist_left[height - int(disturbance * height) :, width // 3 : (2 * width) // 3, 0] = 0.85
    mode_colors = {
        SupportMode.HOLD: np.asarray([0.92, 0.82, 0.16], dtype=np.float32),
        SupportMode.TRANSFER: np.asarray([0.16, 0.78, 0.92], dtype=np.float32),
        SupportMode.PASSIVE: np.asarray([0.86, 0.86, 0.86], dtype=np.float32),
    }
    wrist_left[:, (2 * width) // 3 :, :] = mode_colors[SupportMode(support_mode)]

    # Wrist-right view: one vertical band per template; green = corridor
    # feasible, red = target template, blue = actor template. Top band encodes
    # visibility, bottom band encodes actor progress.
    wrist_right = np.full((height, width, 3), 0.08, dtype=np.float32)
    template_edges = np.linspace(0, width, num_templates + 1, dtype=np.int32)
    for template_idx in range(num_templates):
        col_start = template_edges[template_idx]
        col_end = template_edges[template_idx + 1]
        if corridor_current[template_idx] > 0.5:
            wrist_right[:, col_start:col_end, 1] = 0.70
        if template_idx == target_template:
            wrist_right[:, col_start:col_end, 0] = 0.78
        if template_idx == actor_template:
            wrist_right[:, col_start:col_end, 2] = 0.90
    wrist_right[: max(1, int(visibility * height)), :, :] += 0.10
    wrist_right[height - max(1, int(actor_progress * height)) :, :, 2] += 0.12
    wrist_right = np.clip(wrist_right, 0.0, 1.0)

    # Defensive clip on all views before quantization to keep uint8 conversion
    # from wrapping (no-op for arrays already in [0, 1]).
    return {
        "front": (np.clip(front, 0.0, 1.0) * 255.0).astype(np.uint8),
        "wrist_left": (np.clip(wrist_left, 0.0, 1.0) * 255.0).astype(np.uint8),
        "wrist_right": (np.clip(wrist_right, 0.0, 1.0) * 255.0).astype(np.uint8),
    }