lsnu committed
Commit 504ec88 · verified · Parent: 58418ff

Add files using upload-large-folder tool

Files changed (45)
  1. code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc +0 -0
  2. code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc +0 -0
  3. code/reveal_vla_bimanual/eval/__pycache__/run_peract2_launch_smoke.cpython-310.pyc +0 -0
  4. code/reveal_vla_bimanual/eval/__pycache__/run_peract2_task_sweep.cpython-310.pyc +0 -0
  5. code/reveal_vla_bimanual/eval/__pycache__/run_rlbench_rollout_eval.cpython-310.pyc +0 -0
  6. code/reveal_vla_bimanual/eval/metrics.py +85 -0
  7. code/reveal_vla_bimanual/eval/run_peract2_launch_smoke.py +131 -0
  8. code/reveal_vla_bimanual/eval/run_proxy_diagnostics.py +148 -26
  9. code/reveal_vla_bimanual/eval/run_reveal_benchmark.py +48 -0
  10. code/reveal_vla_bimanual/eval/run_rlbench_rollout_eval.py +19 -1
  11. code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc +0 -0
  12. code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc +0 -0
  13. code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc +0 -0
  14. code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc +0 -0
  15. code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc +0 -0
  16. code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc +0 -0
  17. code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc +0 -0
  18. code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc +0 -0
  19. code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-310.pyc +0 -0
  20. code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-311.pyc +0 -0
  21. code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc +0 -0
  22. code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc +0 -0
  23. code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc +0 -0
  24. code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc +0 -0
  25. code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc +0 -0
  26. code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc +0 -0
  27. code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc +0 -0
  28. code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc +0 -0
  29. code/reveal_vla_bimanual/models/action_decoder.py +304 -0
  30. code/reveal_vla_bimanual/models/backbones.py +249 -24
  31. code/reveal_vla_bimanual/models/multiview_fusion.py +74 -3
  32. code/reveal_vla_bimanual/models/observation_memory.py +192 -0
  33. code/reveal_vla_bimanual/models/planner.py +191 -0
  34. code/reveal_vla_bimanual/models/policy.py +319 -5
  35. code/reveal_vla_bimanual/models/reveal_head.py +242 -0
  36. code/reveal_vla_bimanual/models/world_model.py +185 -0
  37. code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc +0 -0
  38. code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc +0 -0
  39. code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc +0 -0
  40. code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc +0 -0
  41. code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc +0 -0
  42. code/reveal_vla_bimanual/sim_reveal/dataset.py +133 -14
  43. code/reveal_vla_bimanual/sim_reveal/procedural_envs.py +210 -5
  44. code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc +0 -0
  45. code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc +0 -0
code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc and b/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc differ
 
code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc and b/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc differ
 
code/reveal_vla_bimanual/eval/__pycache__/run_peract2_launch_smoke.cpython-310.pyc ADDED
Binary file (4.31 kB).
 
code/reveal_vla_bimanual/eval/__pycache__/run_peract2_task_sweep.cpython-310.pyc ADDED
Binary file (6.11 kB).
 
code/reveal_vla_bimanual/eval/__pycache__/run_rlbench_rollout_eval.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/eval/__pycache__/run_rlbench_rollout_eval.cpython-310.pyc and b/code/reveal_vla_bimanual/eval/__pycache__/run_rlbench_rollout_eval.cpython-310.pyc differ
 
code/reveal_vla_bimanual/eval/metrics.py CHANGED
@@ -22,6 +22,13 @@ class PlannerDiagnostics:
     regret: float
     risk_calibration_mse: float
     role_collapse_rate: float
+    proposal_diversity: float | None = None
+    planner_score_utility_spearman: float | None = None
+    left_right_equivariance_error: float | None = None
+    belief_calibration_brier: float | None = None
+    reocclusion_calibration_brier: float | None = None
+    support_stability_mae: float | None = None
+    clearance_auc: float | None = None


 def mean_success(per_task_success: dict[str, float]) -> float:
@@ -87,6 +94,84 @@ def risk_calibration_mse(predicted_risk: np.ndarray, realized_risk: np.ndarray)
     return float(np.mean((predicted_risk - realized_risk) ** 2))


+def proposal_diversity(proposal_chunks: np.ndarray) -> float:
+    proposal_chunks = np.asarray(proposal_chunks, dtype=np.float32)
+    if proposal_chunks.ndim != 4 or proposal_chunks.shape[1] <= 1:
+        return 0.0
+    flat = proposal_chunks.reshape(proposal_chunks.shape[0], proposal_chunks.shape[1], -1)
+    diffs = flat[:, :, None, :] - flat[:, None, :, :]
+    distances = np.abs(diffs).mean(axis=-1)
+    mask = ~np.eye(distances.shape[1], dtype=bool)
+    if not mask.any():
+        return 0.0
+    off_diagonal = distances[:, mask]
+    return float(off_diagonal.mean())
+
+
+def planner_score_utility_spearman(pred_scores: np.ndarray, oracle_utility: np.ndarray) -> float:
+    pred_scores = np.asarray(pred_scores, dtype=np.float32)
+    oracle_utility = np.asarray(oracle_utility, dtype=np.float32)
+    if pred_scores.size == 0:
+        return 0.0
+    pred_rank = pred_scores.argsort(axis=-1).argsort(axis=-1).astype(np.float32)
+    oracle_rank = oracle_utility.argsort(axis=-1).argsort(axis=-1).astype(np.float32)
+    pred_rank = pred_rank - pred_rank.mean(axis=-1, keepdims=True)
+    oracle_rank = oracle_rank - oracle_rank.mean(axis=-1, keepdims=True)
+    denom = np.sqrt((pred_rank**2).sum(axis=-1) * (oracle_rank**2).sum(axis=-1))
+    valid = denom > 1e-6
+    if not np.any(valid):
+        return 0.0
+    corr = np.zeros_like(denom)
+    corr[valid] = (pred_rank[valid] * oracle_rank[valid]).sum(axis=-1) / denom[valid]
+    return float(corr.mean())
+
+
+def left_right_equivariance_error(pred: np.ndarray, swapped_target: np.ndarray) -> float:
+    pred = np.asarray(pred, dtype=np.float32)
+    swapped_target = np.asarray(swapped_target, dtype=np.float32)
+    if pred.size == 0 or swapped_target.size == 0:
+        return 0.0
+    return float(np.abs(pred - swapped_target).mean())
+
+
+def belief_calibration_brier(predicted_belief: np.ndarray, target_belief: np.ndarray) -> float:
+    predicted_belief = np.asarray(predicted_belief, dtype=np.float32)
+    target_belief = np.asarray(target_belief, dtype=np.float32)
+    if predicted_belief.size == 0:
+        return 0.0
+    return float(np.mean((predicted_belief - target_belief) ** 2))
+
+
+def reocclusion_calibration_brier(predicted_reocclusion: np.ndarray, target_reocclusion: np.ndarray) -> float:
+    predicted_reocclusion = np.asarray(predicted_reocclusion, dtype=np.float32)
+    target_reocclusion = np.asarray(target_reocclusion, dtype=np.float32)
+    if predicted_reocclusion.size == 0:
+        return 0.0
+    return float(np.mean((predicted_reocclusion - target_reocclusion) ** 2))
+
+
+def support_stability_mae(predicted: np.ndarray, target: np.ndarray) -> float:
+    predicted = np.asarray(predicted, dtype=np.float32)
+    target = np.asarray(target, dtype=np.float32)
+    if predicted.size == 0:
+        return 0.0
+    return float(np.abs(predicted - target).mean())
+
+
+def clearance_auc(predicted: np.ndarray, target: np.ndarray) -> float:
+    predicted = np.asarray(predicted, dtype=np.float32).reshape(-1)
+    target = np.asarray(target, dtype=np.float32).reshape(-1)
+    positives = target > 0.5
+    negatives = ~positives
+    if positives.sum() == 0 or negatives.sum() == 0:
+        return 0.0
+    order = np.argsort(predicted)
+    ranks = np.empty_like(order, dtype=np.float32)
+    ranks[order] = np.arange(order.shape[0], dtype=np.float32)
+    pos_ranks = ranks[positives]
+    return float((pos_ranks.sum() - positives.sum() * (positives.sum() - 1) / 2.0) / (positives.sum() * negatives.sum()))
+
+
 def role_collapse_rate(
     action_chunks: np.ndarray,
     arm_role_logits: np.ndarray | None = None,
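
As a quick sanity check on the new metric functions, the following snippet (hypothetical; it assumes the project root is on PYTHONPATH so that eval.metrics resolves, matching the import style used by run_proxy_diagnostics.py) exercises clearance_auc, which is the Mann-Whitney rank-sum AUC, and proposal_diversity, which averages pairwise L1 distances between candidates:

import numpy as np

from eval.metrics import clearance_auc, proposal_diversity

# Both positives score above both negatives, so the ranking is perfect: AUC == 1.0.
predicted = np.array([0.9, 0.8, 0.2, 0.1], dtype=np.float32)
target = np.array([1.0, 1.0, 0.0, 0.0], dtype=np.float32)
assert clearance_auc(predicted, target) == 1.0

# Identical candidates have zero pairwise distance, so diversity collapses to 0.0.
chunks = np.zeros((2, 4, 10, 14), dtype=np.float32)  # (batch, candidates, chunk steps, action dim)
assert proposal_diversity(chunks) == 0.0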
code/reveal_vla_bimanual/eval/run_peract2_launch_smoke.py ADDED
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+import argparse
+import json
+import subprocess
+import sys
+from pathlib import Path
+from typing import Any
+
+from sim_rlbench.task_splits import PERACT2_BIMANUAL_TASKS
+
+
+def _parse_json_payload(stdout: str) -> dict[str, Any]:
+    start = stdout.find("{")
+    end = stdout.rfind("}")
+    if start == -1 or end == -1 or end < start:
+        raise ValueError("No JSON object found in subprocess stdout.")
+    return json.loads(stdout[start : end + 1])
+
+
+def _run_task(project_root: Path, output_dir: Path, task_name: str, *, resolution: int, headless: bool) -> dict[str, Any]:
+    task_dir = output_dir / task_name
+    task_dir.mkdir(parents=True, exist_ok=True)
+    command = [
+        sys.executable,
+        "-m",
+        "sim_rlbench.launch_smoke",
+        "--task",
+        task_name,
+        "--resolution",
+        str(resolution),
+    ]
+    if headless:
+        command.append("--headless")
+    completed = subprocess.run(
+        command,
+        cwd=project_root,
+        text=True,
+        capture_output=True,
+        check=False,
+    )
+    (task_dir / "command.txt").write_text(" ".join(command) + "\n", encoding="utf-8")
+    (task_dir / "stdout.txt").write_text(completed.stdout, encoding="utf-8")
+    (task_dir / "stderr.txt").write_text(completed.stderr, encoding="utf-8")
+
+    payload: dict[str, Any] = {
+        "subprocess_returncode": int(completed.returncode),
+        "launch_ok": completed.returncode == 0,
+    }
+    try:
+        payload.update(_parse_json_payload(completed.stdout))
+    except Exception as exc:
+        payload["launch_ok"] = False
+        payload["error"] = f"json_parse_failed: {exc}"
+    if completed.returncode != 0 and "error" not in payload:
+        payload["error"] = f"subprocess_exit_{completed.returncode}"
+    return payload
+
+
+def _write_markdown(path: Path, payload: dict[str, Any]) -> None:
+    lines = [
+        "# PerAct2 13-Task Launch Smoke",
+        "",
+        f"- Resolution: `{payload['resolution']}`",
+        f"- Headless: `{payload['headless']}`",
+        f"- Task count: `{payload['task_count']}`",
+        f"- Launch successes: `{payload['launch_successes']}`",
+        f"- Finite-action tasks: `{payload['finite_action_tasks']}`",
+        f"- Error tasks: `{payload['error_tasks']}`",
+        "",
+        "## Per-task",
+        "",
+    ]
+    for task_name, task_payload in payload["tasks"].items():
+        if "error" in task_payload:
+            lines.append(
+                f"- `{task_name}`: launch_ok={task_payload.get('launch_ok')}, "
+                f"action_finite={task_payload.get('action_finite')}, "
+                f"error={task_payload['error']}, "
+                f"subprocess_returncode={task_payload['subprocess_returncode']}"
+            )
+        else:
+            lines.append(
+                f"- `{task_name}`: launch_ok={task_payload.get('launch_ok')}, "
+                f"action_finite={task_payload.get('action_finite')}, "
+                f"task_class={task_payload.get('task')}, "
+                f"reward={task_payload.get('reward')}"
+            )
+    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output-dir", required=True)
+    parser.add_argument("--tasks", nargs="*", default=list(PERACT2_BIMANUAL_TASKS))
+    parser.add_argument("--resolution", type=int, default=224)
+    parser.add_argument("--headless", action="store_true", default=True)
+    args = parser.parse_args()
+
+    project_root = Path(__file__).resolve().parents[1]
+    output_dir = Path(args.output_dir).resolve()
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    results: dict[str, Any] = {
+        "resolution": int(args.resolution),
+        "headless": bool(args.headless),
+        "tasks": {},
+    }
+    for task_name in tuple(args.tasks):
+        print(f"[peract2-launch-smoke] task={task_name}", flush=True)
+        results["tasks"][task_name] = _run_task(
+            project_root,
+            output_dir,
+            task_name,
+            resolution=args.resolution,
+            headless=args.headless,
+        )
+
+    task_payloads = list(results["tasks"].values())
+    results["task_count"] = len(task_payloads)
+    results["launch_successes"] = int(sum(1 for payload in task_payloads if payload.get("launch_ok")))
+    results["finite_action_tasks"] = int(sum(1 for payload in task_payloads if payload.get("action_finite")))
+    results["error_tasks"] = sorted(task_name for task_name, payload in results["tasks"].items() if "error" in payload)
+
+    (output_dir / "launch_smoke_summary.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
+    _write_markdown(output_dir / "launch_smoke_summary.md", results)
+    print(json.dumps(results, indent=2))
+
+
+if __name__ == "__main__":
+    main()
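
One fragile spot worth noting: _parse_json_payload recovers the summary by slicing from the first "{" to the last "}" in mixed stdout, so it assumes no stray braces appear in log lines before the JSON. A minimal standalone check of that behavior (the helper body is copied from the file above; the sample stdout is hypothetical):

import json

def _parse_json_payload(stdout: str) -> dict:
    # Copied behavior: slice from the first "{" to the last "}" and parse.
    start = stdout.find("{")
    end = stdout.rfind("}")
    if start == -1 or end == -1 or end < start:
        raise ValueError("No JSON object found in subprocess stdout.")
    return json.loads(stdout[start : end + 1])

mixed = '[peract2-launch-smoke] task=lift_tray\n{"launch_ok": true, "reward": 0.0}\n'
assert _parse_json_payload(mixed) == {"launch_ok": True, "reward": 0.0}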
code/reveal_vla_bimanual/eval/run_proxy_diagnostics.py CHANGED
@@ -8,9 +8,22 @@ from typing import Any
 import numpy as np
 import torch
 from torch import Tensor
+import torch.nn.functional as F
 from torch.utils.data import DataLoader

-from eval.metrics import planner_regret, planner_top1_accuracy, risk_calibration_mse, role_collapse_rate
+from eval.metrics import (
+    belief_calibration_brier,
+    clearance_auc,
+    left_right_equivariance_error,
+    planner_regret,
+    planner_score_utility_spearman,
+    planner_top1_accuracy,
+    proposal_diversity,
+    reocclusion_calibration_brier,
+    risk_calibration_mse,
+    role_collapse_rate,
+    support_stability_mae,
+)
 from eval.run_reveal_benchmark import load_model
 from sim_reveal.dataset import dataset_from_bundle, load_teacher_dataset

@@ -52,41 +65,121 @@ def main() -> None:
     risk_batches: list[np.ndarray] = []
     realized_risk_batches: list[np.ndarray] = []
     collapse_batches: list[float] = []
+    proposal_batches: list[np.ndarray] = []
+    equivariance_batches: list[float] = []
+    belief_pred_batches: list[np.ndarray] = []
+    belief_target_batches: list[np.ndarray] = []
+    reocclusion_pred_batches: list[np.ndarray] = []
+    reocclusion_target_batches: list[np.ndarray] = []
+    support_pred_batches: list[np.ndarray] = []
+    support_target_batches: list[np.ndarray] = []
+    clearance_pred_batches: list[np.ndarray] = []
+    clearance_target_batches: list[np.ndarray] = []
+    memory_write_batches: list[np.ndarray] = []
+    memory_saturation_batches: list[np.ndarray] = []

     with torch.no_grad():
         for batch in loader:
             moved = _move_batch_to_device(batch, device)
-            outputs = model(
-                images=moved["images"],
-                proprio=moved["proprio"],
-                texts=moved["texts"],
-                history_images=moved.get("history_images"),
-                history_proprio=moved.get("history_proprio"),
-                history_actions=moved.get("history_actions"),
-                plan=True,
-                candidate_chunks_override=moved["candidate_action_chunks"],
-            )
+            forward_kwargs = {
+                "images": moved["images"],
+                "proprio": moved["proprio"],
+                "texts": moved["texts"],
+                "history_images": moved.get("history_images"),
+                "history_proprio": moved.get("history_proprio"),
+                "history_actions": moved.get("history_actions"),
+                "plan": True,
+                "candidate_chunks_override": moved["candidate_action_chunks"],
+            }
+            if hasattr(model, "elastic_state_head"):
+                forward_kwargs.update(
+                    {
+                        "depths": moved.get("depths"),
+                        "depth_valid": moved.get("depth_valid"),
+                        "camera_intrinsics": moved.get("camera_intrinsics"),
+                        "camera_extrinsics": moved.get("camera_extrinsics"),
+                        "history_depths": moved.get("history_depths"),
+                        "history_depth_valid": moved.get("history_depth_valid"),
+                        "use_depth": moved.get("depths") is not None,
+                        "use_world_model": True,
+                        "use_planner": True,
+                        "use_role_tokens": True,
+                        "compute_equivariance_probe": True,
+                    }
+                )
+            outputs = model(**forward_kwargs)
             if "planner_scores" not in outputs:
                 raise RuntimeError("Planner outputs were not produced for proxy diagnostics.")
-            score_batches.append(outputs["planner_scores"].detach().cpu().numpy())
-            utility_batches.append(moved["candidate_utility"].detach().cpu().numpy())
-            best_index_batches.append(outputs["best_candidate_indices"].detach().cpu().numpy())
-            risk_batches.append(outputs["planner_risk_values"].detach().cpu().numpy())
-            realized_risk_batches.append(
-                torch.clamp(
-                    moved["candidate_final_disturbance_cost"] + moved["candidate_reocclusion_rate"],
-                    0.0,
-                    1.0,
-                )
-                .detach()
-                .cpu()
-                .numpy()
+            planner_scores = outputs["planner_scores"]
+            candidate_utility = moved["candidate_utility"]
+            predicted_risk = outputs["planner_risk_values"]
+            realized_risk = torch.clamp(
+                moved["candidate_final_disturbance_cost"] + moved["candidate_reocclusion_rate"],
+                0.0,
+                1.0,
             )
+            shortlist_indices = outputs.get("planner_topk_indices")
+            if shortlist_indices is not None:
+                candidate_utility = candidate_utility.gather(1, shortlist_indices)
+                predicted_risk = predicted_risk
+                realized_risk = realized_risk.gather(1, shortlist_indices)
+            score_batches.append(planner_scores.detach().cpu().numpy())
+            utility_batches.append(candidate_utility.detach().cpu().numpy())
+            best_index_batches.append(outputs["best_candidate_indices"].detach().cpu().numpy())
+            risk_batches.append(predicted_risk.detach().cpu().numpy())
+            realized_risk_batches.append(realized_risk.detach().cpu().numpy())
             selected_chunk = outputs["planned_chunk"].detach().cpu().numpy()[:, None]
+            state = outputs.get("interaction_state") or outputs.get("reveal_state")
             role_logits = None
-            if outputs.get("interaction_state") is not None:
-                role_logits = outputs["interaction_state"]["arm_role_logits"].detach().cpu().numpy()[:, None]
+            if state is not None:
+                role_logits = state["arm_role_logits"].detach().cpu().numpy()[:, None]
             collapse_batches.append(role_collapse_rate(selected_chunk, role_logits))
+            if outputs.get("proposal_candidates") is not None:
+                proposal_batches.append(outputs["proposal_candidates"].detach().cpu().numpy())
+            if outputs.get("equivariance_probe_action_mean") is not None:
+                equivariance_batches.append(
+                    left_right_equivariance_error(
+                        outputs["equivariance_probe_action_mean"].detach().cpu().numpy(),
+                        outputs["equivariance_target_action_mean"].detach().cpu().numpy(),
+                    )
+                )
+            if state is not None:
+                if "belief_map" in state and "belief_map" in moved:
+                    belief_pred_batches.append(torch.sigmoid(state["belief_map"]).detach().cpu().numpy())
+                    belief_target_batches.append(moved["belief_map"].detach().cpu().numpy())
+                if "reocclusion_field" in state and "reocclusion_target" in moved:
+                    reocclusion_pred_batches.append(torch.sigmoid(state["reocclusion_field"]).mean(dim=(-1, -2)).detach().cpu().numpy())
+                    reocclusion_target_batches.append(moved["reocclusion_target"].detach().cpu().numpy())
+                if "support_stability_field" in state and "support_stability" in moved:
+                    support_pred_batches.append(torch.sigmoid(state["support_stability_field"]).mean(dim=(-1, -2)).detach().cpu().numpy())
+                    support_target_batches.append(moved["support_stability"].detach().cpu().numpy())
+                if "clearance_field" in state and "clearance_map" in moved:
+                    clearance_pred = torch.sigmoid(state["clearance_field"])
+                    clearance_target = moved["clearance_map"]
+                    if clearance_pred.shape[-2:] != clearance_target.shape[-2:]:
+                        clearance_pred = F.interpolate(
+                            clearance_pred,
+                            size=clearance_target.shape[-2:],
+                            mode="bilinear",
+                            align_corners=False,
+                        )
+                    if clearance_pred.shape[1] != clearance_target.shape[1]:
+                        if clearance_pred.shape[1] == 1:
+                            clearance_pred = clearance_pred.expand(-1, clearance_target.shape[1], -1, -1)
+                        elif clearance_target.shape[1] == 1:
+                            clearance_target = clearance_target.expand_as(clearance_pred)
+                        else:
+                            min_channels = min(clearance_pred.shape[1], clearance_target.shape[1])
+                            clearance_pred = clearance_pred[:, :min_channels]
+                            clearance_target = clearance_target[:, :min_channels]
+                    clearance_pred_batches.append(clearance_pred.detach().cpu().numpy())
+                    clearance_target_batches.append(clearance_target.detach().cpu().numpy())
+            if outputs.get("memory_output") is not None:
+                memory_output = outputs["memory_output"]
+                if "memory_write_rate" in memory_output:
+                    memory_write_batches.append(memory_output["memory_write_rate"].detach().cpu().numpy())
+                if "memory_saturation" in memory_output:
+                    memory_saturation_batches.append(memory_output["memory_saturation"].detach().cpu().numpy())

     scores = np.concatenate(score_batches, axis=0) if score_batches else np.zeros((0, 0), dtype=np.float32)
     utility = np.concatenate(utility_batches, axis=0) if utility_batches else np.zeros((0, 0), dtype=np.float32)
@@ -101,8 +194,37 @@ def main() -> None:
     diagnostics = {
         "planner_top1_accuracy": planner_top1_accuracy(scores, utility),
         "planner_regret": planner_regret(selected_indices, utility),
+        "planner_score_utility_spearman": planner_score_utility_spearman(scores, utility),
         "risk_calibration_mse": risk_calibration_mse(predicted_risk, realized_risk),
         "role_collapse_rate": float(np.mean(collapse_batches)) if collapse_batches else 0.0,
+        "proposal_diversity": proposal_diversity(np.concatenate(proposal_batches, axis=0)) if proposal_batches else 0.0,
+        "left_right_equivariance_error": float(np.mean(equivariance_batches)) if equivariance_batches else 0.0,
+        "belief_calibration_brier": belief_calibration_brier(
+            np.concatenate(belief_pred_batches, axis=0),
+            np.concatenate(belief_target_batches, axis=0),
+        )
+        if belief_pred_batches
+        else 0.0,
+        "reocclusion_calibration_brier": reocclusion_calibration_brier(
+            np.concatenate(reocclusion_pred_batches, axis=0),
+            np.concatenate(reocclusion_target_batches, axis=0),
+        )
+        if reocclusion_pred_batches
+        else 0.0,
+        "support_stability_mae": support_stability_mae(
+            np.concatenate(support_pred_batches, axis=0),
+            np.concatenate(support_target_batches, axis=0),
+        )
+        if support_pred_batches
+        else 0.0,
+        "clearance_auc": clearance_auc(
+            np.concatenate(clearance_pred_batches, axis=0),
+            np.concatenate(clearance_target_batches, axis=0),
+        )
+        if clearance_pred_batches
+        else 0.0,
+        "memory_write_rate": float(np.mean(np.concatenate(memory_write_batches, axis=0))) if memory_write_batches else 0.0,
+        "memory_saturation": float(np.mean(np.concatenate(memory_saturation_batches, axis=0))) if memory_saturation_batches else 0.0,
         "num_samples": int(scores.shape[0]),
     }

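The new planner_score_utility_spearman metric ranks by double argsort, which matches the textbook Spearman coefficient only on tie-free rows. If SciPy is available, a quick cross-check on random (hence almost surely tie-free) data looks like this (hypothetical; it again assumes eval.metrics is importable from the project root):

import numpy as np
from scipy.stats import spearmanr

from eval.metrics import planner_score_utility_spearman

rng = np.random.default_rng(0)
scores = rng.normal(size=(1, 8)).astype(np.float32)   # one row of candidate scores
utility = rng.normal(size=(1, 8)).astype(np.float32)  # matching oracle utilities

# Without ties, Pearson correlation on double-argsort ranks equals Spearman's rho.
reference = spearmanr(scores[0], utility[0]).correlation
assert abs(planner_score_utility_spearman(scores, utility) - reference) < 1e-4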
code/reveal_vla_bimanual/eval/run_reveal_benchmark.py CHANGED
@@ -73,12 +73,18 @@ def _prepare_batch(
     observation: dict[str, Any],
     device: torch.device,
     history_images: list[np.ndarray] | None = None,
+    history_depths: list[np.ndarray] | None = None,
+    history_depth_valid: list[np.ndarray] | None = None,
     history_proprio: list[np.ndarray] | None = None,
     history_actions: list[np.ndarray] | None = None,
 ) -> dict[str, Any]:
     images = torch.from_numpy(observation["images"]).permute(0, 3, 1, 2).unsqueeze(0).float() / 255.0
+    depths = torch.from_numpy(observation.get("depths", np.zeros((3, 1, images.shape[-2], images.shape[-1]), dtype=np.float32))).unsqueeze(0).float()
+    depth_valid = torch.from_numpy(observation.get("depth_valid", np.zeros((3, 1, images.shape[-2], images.shape[-1]), dtype=np.float32))).unsqueeze(0).float()
     proprio = torch.from_numpy(observation["proprio"]).unsqueeze(0).float()
     history_images = history_images or []
+    history_depths = history_depths or []
+    history_depth_valid = history_depth_valid or []
     history_proprio = history_proprio or []
     history_actions = history_actions or []
     if history_images:
@@ -90,6 +96,12 @@
             (1, 0, images.shape[1], images.shape[2], images.shape[3], images.shape[4]),
             dtype=torch.float32,
         )
+    if history_depths:
+        history_depths_tensor = torch.from_numpy(np.stack(history_depths, axis=0)).unsqueeze(0).float()
+        history_depth_valid_tensor = torch.from_numpy(np.stack(history_depth_valid, axis=0)).unsqueeze(0).float()
+    else:
+        history_depths_tensor = torch.zeros((1, 0, depths.shape[1], depths.shape[2], depths.shape[3], depths.shape[4]), dtype=torch.float32)
+        history_depth_valid_tensor = torch.zeros_like(history_depths_tensor)
     if history_proprio:
         history_proprio_tensor = torch.from_numpy(np.stack(history_proprio, axis=0)).unsqueeze(0).float()
     else:
@@ -100,7 +112,13 @@
         history_actions_tensor = torch.zeros((1, 0, 14), dtype=torch.float32)
     return {
         "images": images.to(device),
+        "depths": depths.to(device),
+        "depth_valid": depth_valid.to(device),
+        "camera_intrinsics": torch.from_numpy(observation.get("camera_intrinsics", np.zeros((3, 3, 3), dtype=np.float32))).unsqueeze(0).to(device),
+        "camera_extrinsics": torch.from_numpy(observation.get("camera_extrinsics", np.zeros((3, 4, 4), dtype=np.float32))).unsqueeze(0).to(device),
         "history_images": history_images_tensor.to(device),
+        "history_depths": history_depths_tensor.to(device),
+        "history_depth_valid": history_depth_valid_tensor.to(device),
         "history_proprio": history_proprio_tensor.to(device),
         "history_actions": history_actions_tensor.to(device),
         "proprio": proprio.to(device),
@@ -147,6 +165,27 @@ def select_chunk(
         if "planned_chunk" in outputs and ablation not in {"no_world_model", "no_interaction_head"}:
             return outputs["planned_chunk"], outputs
         return outputs["action_mean"], outputs
+    if hasattr(model, "elastic_state_head"):
+        outputs = model(
+            **forward_kwargs,
+            depths=batch.get("depths"),
+            depth_valid=batch.get("depth_valid"),
+            camera_intrinsics=batch.get("camera_intrinsics"),
+            camera_extrinsics=batch.get("camera_extrinsics"),
+            history_depths=batch.get("history_depths"),
+            history_depth_valid=batch.get("history_depth_valid"),
+            plan=True,
+            use_world_model=(ablation not in {"no_world_model", "no_planner"}),
+            use_planner=(ablation != "no_planner"),
+            use_depth=(ablation != "no_depth"),
+            use_role_tokens=(ablation not in {"no_role_tokens", "no_role_symmetry"}),
+            history_steps_override=(2 if ablation == "short_history" else None),
+        )
+        if "planned_chunk" in outputs and ablation != "no_planner":
+            return outputs["planned_chunk"], outputs
+        if "candidate_chunks" in outputs:
+            return outputs["candidate_chunks"][:, 0], outputs
+        return outputs["action_mean"], outputs
     if hasattr(model, "reveal_head"):
         if ablation == "no_world_model":
             outputs = model(**forward_kwargs, plan=False)
@@ -195,6 +234,8 @@ def evaluate_model(
     episode_corridor = [float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any())]
     episode_disturbance = [float(privileged_state["disturbance_cost"])]
     history_images: list[np.ndarray] = []
+    history_depths: list[np.ndarray] = []
+    history_depth_valid: list[np.ndarray] = []
     history_proprio: list[np.ndarray] = []
     history_actions: list[np.ndarray] = []
     done = False
@@ -203,6 +244,8 @@
             observation,
            device=device,
            history_images=history_images,
+            history_depths=history_depths,
+            history_depth_valid=history_depth_valid,
            history_proprio=history_proprio,
            history_actions=history_actions,
        )
@@ -224,9 +267,14 @@
         if history_steps > 0:
             if len(history_images) >= history_steps:
                 history_images = history_images[-history_steps + 1 :]
+                history_depths = history_depths[-history_steps + 1 :]
+                history_depth_valid = history_depth_valid[-history_steps + 1 :]
                 history_proprio = history_proprio[-history_steps + 1 :]
                 history_actions = history_actions[-history_steps + 1 :]
             history_images.append(observation["images"])
+            if "depths" in observation:
+                history_depths.append(observation["depths"])
+                history_depth_valid.append(observation["depth_valid"])
             history_proprio.append(observation["proprio"])
             history_actions.append(action.astype(np.float32))
         observation, _, terminated, truncated, privileged_state = env.step(action)
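
When an observation carries no depth, _prepare_batch now falls back to zero-filled tensors shaped for three cameras with a single depth channel. A standalone sketch of that fallback (hypothetical shapes: 3 views at 224x224, mirroring the fallback expression in the diff):

import numpy as np
import torch

observation = {"images": np.zeros((3, 224, 224, 3), dtype=np.uint8)}
images = torch.from_numpy(observation["images"]).permute(0, 3, 1, 2).unsqueeze(0).float() / 255.0
# Same fallback as _prepare_batch: zeros when the "depths" key is absent.
depths = torch.from_numpy(
    observation.get("depths", np.zeros((3, 1, images.shape[-2], images.shape[-1]), dtype=np.float32))
).unsqueeze(0).float()
assert depths.shape == (1, 3, 1, 224, 224)  # (batch, views, channel, H, W)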
code/reveal_vla_bimanual/eval/run_rlbench_rollout_eval.py CHANGED
@@ -52,6 +52,19 @@ def _episode_language_goal(descriptions: Sequence[str]) -> str:
     return str(descriptions[0]) if descriptions else ""


+def _reset_task_with_retries(task: Any, max_attempts: int) -> tuple[Sequence[str], Any, int]:
+    last_error: Exception | None = None
+    for attempt in range(max_attempts):
+        try:
+            descriptions, obs = task.reset()
+            return descriptions, obs, attempt
+        except Exception as exc:  # pragma: no cover - live RLBench failure path
+            last_error = exc
+    if last_error is not None:
+        raise last_error
+    raise RuntimeError("Task reset failed without raising a concrete exception.")
+
+
 def _noop_bimanual_action(obs: Any) -> np.ndarray:
     right_obs = getattr(obs, "right", None)
     left_obs = getattr(obs, "left", None)
@@ -113,6 +126,7 @@ def main() -> None:
     parser.add_argument("--disable-support-mode-conditioning", action="store_true")
     parser.add_argument("--headless", action="store_true", default=True)
     parser.add_argument("--chunk-commit-steps", type=int, default=0)
+    parser.add_argument("--reset-retries", type=int, default=20)
     args = parser.parse_args()

     checkpoint = torch.load(Path(args.checkpoint), map_location="cpu", weights_only=False)
@@ -155,6 +169,7 @@
         "episodes_per_task": args.episodes_per_task,
         "episode_length": args.episode_length,
         "resolution": args.resolution,
+        "reset_retries": args.reset_retries,
         "cameras": list(camera_spec.cameras),
         "tasks": {},
     }
@@ -180,8 +195,10 @@
         )
         env.launch()
         task = env.get_task(task_class)
+        task_reset_retries: list[int] = []
         for _ in range(args.episodes_per_task):
-            descriptions, obs = task.reset()
+            descriptions, obs, reset_retries = _reset_task_with_retries(task, max_attempts=max(1, args.reset_retries))
+            task_reset_retries.append(int(reset_retries))
             language_goal = _episode_language_goal(descriptions)
             total_reward = 0.0
             success = 0.0
@@ -291,6 +308,7 @@
             "returns": task_returns,
             "path_recoveries": episode_recoveries if args.episodes_per_task == 1 else None,
             "noop_fallbacks": episode_noop_fallbacks if args.episodes_per_task == 1 else None,
+            "reset_retries": task_reset_retries,
             "mean_success": float(np.mean(task_successes)) if task_successes else 0.0,
             "mean_return": float(np.mean(task_returns)) if task_returns else 0.0,
         }
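
The retry helper returns the zero-based attempt index at which task.reset() finally succeeded, which is what gets logged per episode. A self-contained check of that contract using a hypothetical flaky-task stub (the helper body is copied from the diff above):

from typing import Any, Sequence

def _reset_task_with_retries(task: Any, max_attempts: int) -> tuple[Sequence[str], Any, int]:
    # Same logic as the helper added in the diff above.
    last_error: Exception | None = None
    for attempt in range(max_attempts):
        try:
            descriptions, obs = task.reset()
            return descriptions, obs, attempt
        except Exception as exc:
            last_error = exc
    if last_error is not None:
        raise last_error
    raise RuntimeError("Task reset failed without raising a concrete exception.")

class FlakyTask:
    # Stub standing in for a live RLBench task: fails twice, then succeeds.
    def __init__(self) -> None:
        self.calls = 0

    def reset(self):
        self.calls += 1
        if self.calls < 3:
            raise RuntimeError("scene spawn failed")
        return ["lift the tray"], object()

descriptions, obs, attempt = _reset_task_with_retries(FlakyTask(), max_attempts=5)
assert descriptions == ["lift the tray"] and attempt == 2  # two retries consumed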
code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-310.pyc ADDED
Binary file (10.5 kB).
 
code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-311.pyc ADDED
Binary file (22.6 kB).
 
code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc differ
 
code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc differ
 
code/reveal_vla_bimanual/models/action_decoder.py CHANGED
@@ -19,6 +19,8 @@ class ChunkDecoderConfig:
19
  num_candidates: int = 8
20
  num_phases: int = 5
21
  num_arm_roles: int = 4
 
 
22
 
23
 
24
  class ACTBimanualChunkDecoder(nn.Module):
@@ -381,3 +383,305 @@ class InteractionChunkDecoder(nn.Module):
381
  candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
382
  candidates[:, 0] = action_mean
383
  return candidates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  num_candidates: int = 8
20
  num_phases: int = 5
21
  num_arm_roles: int = 4
22
+ num_proposal_modes: int = 6
23
+ planner_top_k: int = 4
24
 
25
 
26
  class ACTBimanualChunkDecoder(nn.Module):
 
383
  candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
384
  candidates[:, 0] = action_mean
385
  return candidates
386
+
387
+
388
+ DEFAULT_PROPOSAL_MODES = (
389
+ "widen_opening",
390
+ "maintain_opening",
391
+ "slide_occluder",
392
+ "lift_support_layer",
393
+ "stabilize_support",
394
+ "retrieve",
395
+ )
396
+
397
+
398
+ def swap_arm_action_order(action_chunk: Tensor) -> Tensor:
399
+ midpoint = action_chunk.shape[-1] // 2
400
+ return torch.cat([action_chunk[..., midpoint:], action_chunk[..., :midpoint]], dim=-1)
401
+
402
+
403
+ class SymmetricCoordinatedChunkDecoder(nn.Module):
404
+ def __init__(self, config: ChunkDecoderConfig) -> None:
405
+ super().__init__()
406
+ self.config = config
407
+ proposal_context_dim = config.action_dim + (config.hidden_dim * 2)
408
+ decoder_layer = nn.TransformerDecoderLayer(
409
+ d_model=config.hidden_dim,
410
+ nhead=config.num_heads,
411
+ dim_feedforward=config.ff_dim,
412
+ dropout=config.dropout,
413
+ batch_first=True,
414
+ norm_first=True,
415
+ )
416
+ self.arm_decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers)
417
+ self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim)
418
+ self.arm_identity = nn.Embedding(2, config.hidden_dim)
419
+ self.phase_adapter = nn.Linear(config.num_phases, config.hidden_dim)
420
+ self.role_adapter = nn.Linear(config.num_arm_roles, config.hidden_dim)
421
+ self.context_proj = nn.Sequential(
422
+ nn.LayerNorm(config.hidden_dim),
423
+ nn.Linear(config.hidden_dim, config.hidden_dim),
424
+ nn.GELU(),
425
+ )
426
+ self.coordination = nn.Sequential(
427
+ nn.LayerNorm(config.hidden_dim * 3),
428
+ nn.Linear(config.hidden_dim * 3, config.hidden_dim),
429
+ nn.GELU(),
430
+ nn.Linear(config.hidden_dim, config.hidden_dim),
431
+ )
432
+ self.arm_head = nn.Sequential(
433
+ nn.LayerNorm(config.hidden_dim),
434
+ nn.Linear(config.hidden_dim, config.hidden_dim),
435
+ nn.GELU(),
436
+ )
437
+ self.arm_mean = nn.Linear(config.hidden_dim, config.arm_action_dim)
438
+ self.arm_log_std = nn.Linear(config.hidden_dim, config.arm_action_dim)
439
+ self.proposal_mode_head = nn.Sequential(
440
+ nn.LayerNorm(proposal_context_dim),
441
+ nn.Linear(proposal_context_dim, config.hidden_dim),
442
+ nn.GELU(),
443
+ nn.Linear(config.hidden_dim, config.num_proposal_modes),
444
+ )
445
+ self.proposal_mode_embeddings = nn.Embedding(config.num_proposal_modes, config.hidden_dim)
446
+ self.proposal_slot_embeddings = nn.Embedding(config.num_candidates, config.hidden_dim)
447
+ self.mode_residual_heads = nn.ModuleList(
448
+ [
449
+ nn.Sequential(
450
+ nn.LayerNorm(proposal_context_dim),
451
+ nn.Linear(proposal_context_dim, config.hidden_dim),
452
+ nn.GELU(),
453
+ nn.Linear(config.hidden_dim, config.chunk_size * config.action_dim),
454
+ )
455
+ for _ in range(config.num_proposal_modes)
456
+ ]
457
+ )
458
+ self.slot_delta = nn.Sequential(
459
+ nn.LayerNorm(config.hidden_dim),
460
+ nn.Linear(config.hidden_dim, config.hidden_dim),
461
+ nn.GELU(),
462
+ nn.Linear(config.hidden_dim, config.chunk_size * config.action_dim),
463
+ )
464
+ self.proposal_score = nn.Sequential(
465
+ nn.LayerNorm(proposal_context_dim + config.hidden_dim),
466
+ nn.Linear(proposal_context_dim + config.hidden_dim, config.hidden_dim),
467
+ nn.GELU(),
468
+ nn.Linear(config.hidden_dim, 1),
469
+ )
470
+
471
+ def _conditioning(
472
+ self,
473
+ interaction_state: dict[str, Tensor] | None,
474
+ batch_size: int,
475
+ device: torch.device,
476
+ dtype: torch.dtype,
477
+ swap_roles: bool = False,
478
+ ) -> tuple[Tensor, Tensor, Tensor]:
479
+ if interaction_state is None:
480
+ zero_phase = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
481
+ zero_roles = torch.zeros(batch_size, 2, self.config.hidden_dim, device=device, dtype=dtype)
482
+ zero_context = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
483
+ return zero_phase, zero_roles, zero_context
484
+ phase_probs = interaction_state["phase_logits"].softmax(dim=-1).to(dtype=dtype)
485
+ arm_role_probs = interaction_state["arm_role_logits"].softmax(dim=-1).to(dtype=dtype)
486
+ if swap_roles:
487
+ arm_role_probs = arm_role_probs.flip(1)
488
+ phase_context = self.phase_adapter(phase_probs)
489
+ role_context = self.role_adapter(arm_role_probs)
490
+ if interaction_state.get("interaction_tokens") is not None:
491
+ interaction_context = interaction_state["interaction_tokens"].mean(dim=1)
492
+ else:
493
+ interaction_context = interaction_state["field_tokens"].mean(dim=1)
494
+ return phase_context, role_context, self.context_proj(interaction_context)
495
+
496
+ def _decode_arm_tokens(
497
+ self,
498
+ queries: Tensor,
499
+ decoder_memory: Tensor,
500
+ phase_context: Tensor,
501
+ role_context: Tensor,
502
+ interaction_context: Tensor,
503
+ swap_roles: bool = False,
504
+ ) -> tuple[Tensor, Tensor, Tensor]:
505
+ batch_size, chunk_size, _ = queries.shape
506
+ identity_order = torch.tensor([1, 0], device=queries.device) if swap_roles else torch.tensor([0, 1], device=queries.device)
507
+ arm_queries = queries.unsqueeze(1).expand(-1, 2, -1, -1)
508
+ arm_queries = arm_queries + phase_context.unsqueeze(1).unsqueeze(2)
509
+ arm_queries = arm_queries + role_context.unsqueeze(2)
510
+ arm_queries = arm_queries + self.arm_identity(identity_order).view(1, 2, 1, -1).to(dtype=queries.dtype)
511
+ flat_queries = arm_queries.reshape(batch_size * 2, chunk_size, self.config.hidden_dim)
512
+ flat_memory = decoder_memory.unsqueeze(1).expand(-1, 2, -1, -1).reshape(
513
+ batch_size * 2,
514
+ decoder_memory.shape[1],
515
+ decoder_memory.shape[2],
516
+ )
517
+ decoded = self.arm_decoder(flat_queries, flat_memory).reshape(batch_size, 2, chunk_size, self.config.hidden_dim)
518
+ coordination_input = torch.cat(
519
+ [
520
+ decoded[:, 0],
521
+ decoded[:, 1],
522
+ interaction_context.unsqueeze(1).expand(-1, chunk_size, -1),
523
+ ],
524
+ dim=-1,
525
+ )
526
+ coordination = torch.tanh(self.coordination(coordination_input))
527
+ decoded[:, 0] = decoded[:, 0] + coordination
528
+ decoded[:, 1] = decoded[:, 1] + coordination
529
+ decoded = self.arm_head(decoded)
530
+ arm_mean = self.arm_mean(decoded)
531
+ arm_log_std = self.arm_log_std(decoded).clamp(min=-5.0, max=2.0)
532
+ return arm_mean, arm_log_std, coordination
533
+
534
+ def _proposal_outputs(
535
+ self,
536
+ base_action: Tensor,
537
+ pooled_context: Tensor,
538
+ ) -> tuple[Tensor, Tensor, Tensor]:
539
+ batch_size = pooled_context.shape[0]
540
+ mode_logits = self.proposal_mode_head(pooled_context)
541
+ mode_residuals = []
542
+ for head in self.mode_residual_heads:
543
+ residual = head(pooled_context).view(batch_size, self.config.chunk_size, self.config.action_dim)
544
+ mode_residuals.append(residual)
545
+ mode_residuals = torch.stack(mode_residuals, dim=1)
546
+
547
+ mode_assignments = torch.arange(self.config.num_candidates, device=pooled_context.device) % self.config.num_proposal_modes
548
+ slot_embeddings = self.proposal_slot_embeddings.weight
549
+ slot_deltas = self.slot_delta(slot_embeddings).view(
550
+ self.config.num_candidates,
551
+ self.config.chunk_size,
552
+ self.config.action_dim,
553
+ )
554
+ proposal_candidates = []
555
+ proposal_logits = []
556
+ for slot_idx in range(self.config.num_candidates):
557
+ mode_idx = int(mode_assignments[slot_idx])
558
+ candidate = base_action + 0.35 * torch.tanh(mode_residuals[:, mode_idx]) + 0.05 * torch.tanh(slot_deltas[slot_idx]).unsqueeze(0)
559
+ proposal_candidates.append(candidate)
560
+ score_features = torch.cat(
561
+ [
562
+ pooled_context,
563
+ self.proposal_mode_embeddings.weight[mode_idx].unsqueeze(0).expand(batch_size, -1)
564
+ + slot_embeddings[slot_idx].unsqueeze(0).expand(batch_size, -1),
565
+ ],
566
+ dim=-1,
567
+ )
568
+ proposal_logits.append(
569
+ self.proposal_score(score_features).squeeze(-1) + mode_logits[:, mode_idx]
570
+ )
571
+ stacked_candidates = torch.stack(proposal_candidates, dim=1)
572
+ stacked_logits = torch.stack(proposal_logits, dim=1)
573
+ stacked_candidates[:, 0] = base_action
574
+ return stacked_candidates, stacked_logits, mode_logits
575
+
576
+ def forward(
577
+ self,
578
+ scene_tokens: Tensor,
579
+ interaction_state: dict[str, Tensor] | None = None,
580
+ memory_tokens: Tensor | None = None,
581
+ reveal_tokens: Tensor | None = None,
582
+ memory_token: Tensor | None = None,
583
+ compute_equivariance_probe: bool = False,
584
+ ) -> dict[str, Tensor]:
585
+ if memory_tokens is None:
586
+ memory_tokens = memory_token
587
+ batch_size = scene_tokens.shape[0]
588
+ dtype = scene_tokens.dtype
589
+ phase_context, role_context, interaction_context = self._conditioning(
590
+ interaction_state=interaction_state,
591
+ batch_size=batch_size,
592
+ device=scene_tokens.device,
593
+            dtype=dtype,
+        )
+
+        decoder_memory = scene_tokens
+        interaction_tokens = interaction_state.get("interaction_tokens") if interaction_state is not None else None
+        if interaction_tokens is not None:
+            decoder_memory = torch.cat([decoder_memory, interaction_tokens], dim=1)
+        elif reveal_tokens is not None:
+            decoder_memory = torch.cat([decoder_memory, reveal_tokens], dim=1)
+        if memory_tokens is not None:
+            decoder_memory = torch.cat([decoder_memory, memory_tokens], dim=1)
+
+        base_queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
+        arm_mean, arm_log_std, coordination = self._decode_arm_tokens(
+            queries=base_queries,
+            decoder_memory=decoder_memory,
+            phase_context=phase_context,
+            role_context=role_context,
+            interaction_context=interaction_context,
+        )
+        action_mean = torch.cat([arm_mean[:, 0], arm_mean[:, 1]], dim=-1)
+        action_log_std = torch.cat([arm_log_std[:, 0], arm_log_std[:, 1]], dim=-1)
+        pooled_context = torch.cat(
+            [
+                arm_mean[:, 0].mean(dim=1),
+                arm_mean[:, 1].mean(dim=1),
+                coordination.mean(dim=1),
+                interaction_context,
+            ],
+            dim=-1,
+        )
+        proposal_candidates, proposal_logits, proposal_mode_logits = self._proposal_outputs(action_mean, pooled_context)
+
+        outputs = {
+            "decoded_tokens": torch.cat([arm_mean[:, 0], arm_mean[:, 1]], dim=-1),
+            "right_tokens": arm_mean[:, 0],
+            "left_tokens": arm_mean[:, 1],
+            "coordination_tokens": coordination,
+            "action_mean": action_mean,
+            "action_log_std": action_log_std,
+            "proposal_candidates": proposal_candidates,
+            "proposal_logits": proposal_logits,
+            "proposal_mode_logits": proposal_mode_logits,
+            "proposal_mode_assignments": torch.arange(
+                self.config.num_candidates,
+                device=scene_tokens.device,
+            ) % self.config.num_proposal_modes,
+            "proposal_mode_names": list(DEFAULT_PROPOSAL_MODES[: self.config.num_proposal_modes]),
+        }
+        if compute_equivariance_probe:
+            swapped_phase, swapped_roles, swapped_context = self._conditioning(
+                interaction_state=interaction_state,
+                batch_size=batch_size,
+                device=scene_tokens.device,
+                dtype=dtype,
+                swap_roles=True,
+            )
+            swapped_arm_mean, _, _ = self._decode_arm_tokens(
+                queries=base_queries,
+                decoder_memory=decoder_memory,
+                phase_context=swapped_phase,
+                role_context=swapped_roles,
+                interaction_context=swapped_context,
+                swap_roles=True,
+            )
+            outputs["equivariance_probe_action_mean"] = torch.cat(
+                [swapped_arm_mean[:, 0], swapped_arm_mean[:, 1]],
+                dim=-1,
+            )
+            outputs["equivariance_target_action_mean"] = swap_arm_action_order(action_mean)
+        return outputs
+
+    def sample_candidates(
+        self,
+        action_mean: Tensor,
+        action_log_std: Tensor,
+        num_candidates: int | None = None,
+        proposal_candidates: Tensor | None = None,
+    ) -> Tensor:
+        if proposal_candidates is not None:
+            return proposal_candidates
+        num_candidates = num_candidates or self.config.num_candidates
+        if num_candidates <= 1:
+            return action_mean.unsqueeze(1)
+        noise = torch.randn(
+            action_mean.size(0),
+            num_candidates,
+            action_mean.size(1),
+            action_mean.size(2),
+            device=action_mean.device,
+            dtype=action_mean.dtype,
+        )
+        candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
+        candidates[:, 0] = action_mean
+        return candidates
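Note on sample_candidates: when the proposal head already produced candidates they are returned unchanged; otherwise candidates are drawn around the decoded mean, with slot 0 pinned to the deterministic mean. A minimal plain-torch sketch of that sampling rule (the [B, T, D] shapes are assumptions):

    # Mean-preserving candidate sampling, mirroring sample_candidates above.
    import torch

    def sample_candidates(action_mean, action_log_std, num_candidates=4):
        if num_candidates <= 1:
            return action_mean.unsqueeze(1)
        noise = torch.randn(
            action_mean.size(0), num_candidates, action_mean.size(1), action_mean.size(2),
            device=action_mean.device, dtype=action_mean.dtype,
        )
        candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
        candidates[:, 0] = action_mean  # slot 0 always carries the deterministic mean
        return candidates

    mean = torch.zeros(2, 8, 16)           # B=2, chunk length T=8, action dim D=16
    log_std = torch.full_like(mean, -1.0)
    cands = sample_candidates(mean, log_std)
    assert cands.shape == (2, 4, 8, 16)
    assert torch.equal(cands[:, 0], mean)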
code/reveal_vla_bimanual/models/backbones.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 import math
+import os
 from pathlib import Path
 from typing import Sequence
 
@@ -18,6 +19,157 @@ class FrozenVLBackboneConfig:
     freeze_backbone: bool = True
     gradient_checkpointing: bool = True
     use_dummy_backbone: bool = False
+    depth_patch_size: int = 16
+    geometry_feature_dim: int = 8
+    use_camera_geometry: bool = True
+
+
+class DepthPatchAdapter(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        patch_size: int = 16,
+        geometry_feature_dim: int = 8,
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.patch_size = patch_size
+        self.geometry_feature_dim = geometry_feature_dim
+        self.depth_proj = nn.Sequential(
+            nn.LayerNorm(2 + geometry_feature_dim),
+            nn.Linear(2 + geometry_feature_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, hidden_dim),
+        )
+        self.geometry_proj = nn.Sequential(
+            nn.LayerNorm(geometry_feature_dim),
+            nn.Linear(geometry_feature_dim, hidden_dim),
+            nn.GELU(),
+        )
+        self.camera_proj = nn.Sequential(
+            nn.LayerNorm(7),
+            nn.Linear(7, hidden_dim),
+            nn.GELU(),
+        )
+
+    def _patchify(self, tensor: Tensor) -> Tensor:
+        pooled = F.avg_pool2d(tensor, kernel_size=self.patch_size, stride=self.patch_size)
+        return pooled.flatten(2).transpose(1, 2)
+
+    def _geometry_features(
+        self,
+        depths: Tensor,
+        camera_intrinsics: Tensor | None = None,
+        camera_extrinsics: Tensor | None = None,
+    ) -> tuple[Tensor, Tensor]:
+        batch_views, _, height, width = depths.shape
+        grid_h = max(1, height // self.patch_size)
+        grid_w = max(1, width // self.patch_size)
+        y_coords = torch.linspace(-1.0, 1.0, steps=grid_h, device=depths.device, dtype=depths.dtype)
+        x_coords = torch.linspace(-1.0, 1.0, steps=grid_w, device=depths.device, dtype=depths.dtype)
+        grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")
+        coords = torch.stack([grid_x, grid_y], dim=-1).reshape(1, grid_h * grid_w, 2).expand(batch_views, -1, -1)
+
+        geometry_terms: list[Tensor] = [coords]
+        if camera_intrinsics is not None:
+            fx = camera_intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
+            fy = camera_intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
+            cx = camera_intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)
+            cy = camera_intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)
+            intrinsic_features = torch.cat(
+                [
+                    fx.expand(-1, grid_h * grid_w, -1),
+                    fy.expand(-1, grid_h * grid_w, -1),
+                    cx.expand(-1, grid_h * grid_w, -1),
+                    cy.expand(-1, grid_h * grid_w, -1),
+                ],
+                dim=-1,
+            )
+            geometry_terms.append(intrinsic_features)
+        else:
+            geometry_terms.append(torch.zeros(batch_views, grid_h * grid_w, 4, device=depths.device, dtype=depths.dtype))
+
+        if camera_extrinsics is not None:
+            translation = camera_extrinsics[:, :3, 3]
+            translation = translation.unsqueeze(1).expand(-1, grid_h * grid_w, -1)
+            geometry_terms.append(translation)
+        else:
+            geometry_terms.append(torch.zeros(batch_views, grid_h * grid_w, 3, device=depths.device, dtype=depths.dtype))
+
+        geometry = torch.cat(geometry_terms, dim=-1)
+        if geometry.shape[-1] < self.geometry_feature_dim:
+            pad = self.geometry_feature_dim - geometry.shape[-1]
+            geometry = F.pad(geometry, (0, pad))
+        elif geometry.shape[-1] > self.geometry_feature_dim:
+            geometry = geometry[..., : self.geometry_feature_dim]
+
+        if camera_intrinsics is not None:
+            camera_summary = torch.cat(
+                [
+                    camera_intrinsics[:, 0, 0:1],
+                    camera_intrinsics[:, 1, 1:2],
+                    camera_intrinsics[:, 0, 2:3],
+                    camera_intrinsics[:, 1, 2:3],
+                ],
+                dim=-1,
+            )
+        else:
+            camera_summary = torch.zeros(batch_views, 4, device=depths.device, dtype=depths.dtype)
+        if camera_extrinsics is not None:
+            camera_summary = torch.cat([camera_summary, camera_extrinsics[:, :3, 3]], dim=-1)
+        else:
+            camera_summary = torch.cat(
+                [camera_summary, torch.zeros(batch_views, 3, device=depths.device, dtype=depths.dtype)],
+                dim=-1,
+            )
+        return geometry, camera_summary
+
+    def forward(
+        self,
+        depths: Tensor,
+        depth_valid: Tensor | None = None,
+        camera_intrinsics: Tensor | None = None,
+        camera_extrinsics: Tensor | None = None,
+    ) -> dict[str, Tensor]:
+        if depths.ndim == 4:
+            depths = depths.unsqueeze(2)
+        if depth_valid is None:
+            depth_valid = torch.ones_like(depths)
+        if depth_valid.ndim == 4:
+            depth_valid = depth_valid.unsqueeze(2)
+        if depths.ndim != 5:
+            raise ValueError(f"Expected depths to have shape [B, V, H, W] or [B, V, 1, H, W], got {tuple(depths.shape)}")
+        if depths.shape[2] != 1:
+            depths = depths.mean(dim=2, keepdim=True)
+        if depth_valid.shape[2] != 1:
+            depth_valid = depth_valid.mean(dim=2, keepdim=True)
+
+        batch_size, num_views = depths.shape[:2]
+        flat_depths = depths.reshape(batch_size * num_views, 1, depths.shape[-2], depths.shape[-1]).float()
+        flat_valid = depth_valid.reshape(batch_size * num_views, 1, depth_valid.shape[-2], depth_valid.shape[-1]).float()
+        flat_intrinsics = None
+        flat_extrinsics = None
+        if camera_intrinsics is not None:
+            flat_intrinsics = camera_intrinsics.reshape(batch_size * num_views, *camera_intrinsics.shape[-2:]).float()
+        if camera_extrinsics is not None:
+            flat_extrinsics = camera_extrinsics.reshape(batch_size * num_views, *camera_extrinsics.shape[-2:]).float()
+
+        depth_patch = self._patchify(flat_depths)
+        valid_patch = self._patchify(flat_valid)
+        geometry_features, camera_summary = self._geometry_features(
+            flat_depths,
+            camera_intrinsics=flat_intrinsics,
+            camera_extrinsics=flat_extrinsics,
+        )
+        token_inputs = torch.cat([depth_patch, valid_patch, geometry_features], dim=-1)
+        depth_tokens = self.depth_proj(token_inputs)
+        geometry_tokens = self.geometry_proj(geometry_features)
+        camera_tokens = self.camera_proj(camera_summary).unsqueeze(1)
+        return {
+            "depth_tokens": depth_tokens.view(batch_size, num_views, depth_tokens.shape[1], depth_tokens.shape[2]),
+            "geometry_tokens": geometry_tokens.view(batch_size, num_views, geometry_tokens.shape[1], geometry_tokens.shape[2]),
+            "camera_tokens": camera_tokens.view(batch_size, num_views, 1, camera_tokens.shape[-1]),
+        }
 
 
 class _DummyTextTokenizer:
@@ -42,6 +194,11 @@ class FrozenVLBackbone(nn.Module):
         self.config = config
         self.hidden_dim = config.hidden_dim
         self.use_dummy_backbone = config.use_dummy_backbone
+        self.depth_adapter = DepthPatchAdapter(
+            hidden_dim=config.hidden_dim,
+            patch_size=config.depth_patch_size,
+            geometry_feature_dim=config.geometry_feature_dim,
+        )
 
         if config.use_dummy_backbone:
             self.image_patch_size = 16
@@ -51,36 +208,62 @@ class FrozenVLBackbone(nn.Module):
 
         local_model_source: str | None = None
         if config.model_name == "openai/clip-vit-base-patch32":
-            cache_root = Path("/workspace/.cache/huggingface/hub/models--openai--clip-vit-base-patch32")
-            ref_path = cache_root / "refs" / "main"
-            if ref_path.exists():
-                snapshot_id = ref_path.read_text(encoding="utf-8").strip()
-                snapshot_dir = cache_root / "snapshots" / snapshot_id
-                if (snapshot_dir / "config.json").exists():
-                    local_model_source = str(snapshot_dir)
+            explicit_local_dir = Path("/workspace/models/openai_clip_vit_base_patch32")
+            if (explicit_local_dir / "config.json").exists():
+                local_model_source = str(explicit_local_dir)
+            cache_home = Path(os.environ.get("HF_HOME", "/workspace/.cache/huggingface"))
+            cache_root = cache_home / "hub" / "models--openai--clip-vit-base-patch32"
+            if local_model_source is None:
+                ref_path = cache_root / "refs" / "main"
+                if ref_path.exists():
+                    snapshot_id = ref_path.read_text(encoding="utf-8").strip()
+                    snapshot_dir = cache_root / "snapshots" / snapshot_id
+                    if (snapshot_dir / "config.json").exists():
+                        local_model_source = str(snapshot_dir)
+            if local_model_source is None:
+                snapshot_root = cache_root / "snapshots"
+                if snapshot_root.exists():
+                    for snapshot_dir in sorted(snapshot_root.iterdir(), reverse=True):
+                        if (snapshot_dir / "config.json").exists():
+                            local_model_source = str(snapshot_dir)
+                            break
         clip_model = None
+        last_clip_error: Exception | None = None
+        model_sources: list[tuple[str, dict[str, object]]] = []
         if local_model_source is not None:
+            model_sources.append((local_model_source, {"use_safetensors": True, "local_files_only": True}))
+            model_sources.append((local_model_source, {"local_files_only": True}))
+        model_sources.append((config.model_name, {"use_safetensors": True}))
+        model_sources.append((config.model_name, {}))
+        for source, kwargs in model_sources:
             try:
-                clip_model = CLIPModel.from_pretrained(
-                    local_model_source,
-                    use_safetensors=True,
-                    local_files_only=True,
-                )
-            except OSError:
-                clip_model = None
+                clip_model = CLIPModel.from_pretrained(source, **kwargs)
+                break
+            except Exception as exc:
+                last_clip_error = exc
         if clip_model is None:
-            clip_model = CLIPModel.from_pretrained(config.model_name, use_safetensors=True)
+            assert last_clip_error is not None
+            raise last_clip_error
         self.vision_model = clip_model.vision_model
         self.text_model = clip_model.text_model
         self.visual_projection = clip_model.visual_projection
         self.text_projection = clip_model.text_projection
+        tokenizer = None
+        last_tokenizer_error: Exception | None = None
+        tokenizer_sources: list[tuple[str, dict[str, object]]] = []
         if local_model_source is not None:
+            tokenizer_sources.append((local_model_source, {"local_files_only": True}))
+        tokenizer_sources.append((config.model_name, {}))
+        for source, kwargs in tokenizer_sources:
             try:
-                self.tokenizer = AutoTokenizer.from_pretrained(local_model_source, local_files_only=True)
-            except OSError:
-                self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
-        else:
-            self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+                tokenizer = AutoTokenizer.from_pretrained(source, **kwargs)
+                break
+            except Exception as exc:
+                last_tokenizer_error = exc
+        if tokenizer is None:
+            assert last_tokenizer_error is not None
+            raise last_tokenizer_error
+        self.tokenizer = tokenizer
         self.hidden_dim = clip_model.config.projection_dim
         if config.gradient_checkpointing:
            if hasattr(self.vision_model, "gradient_checkpointing_enable"):
@@ -88,9 +271,17 @@ class FrozenVLBackbone(nn.Module):
             if hasattr(self.text_model, "gradient_checkpointing_enable"):
                 self.text_model.gradient_checkpointing_enable()
 
-        if config.freeze_backbone:
-            for parameter in self.parameters():
-                parameter.requires_grad = False
+        if config.freeze_backbone and not config.use_dummy_backbone:
+            for module in (
+                getattr(self, "vision_model", None),
+                getattr(self, "text_model", None),
+                getattr(self, "visual_projection", None),
+                getattr(self, "text_projection", None),
+            ):
+                if module is None:
+                    continue
+                for parameter in module.parameters():
+                    parameter.requires_grad = False
 
     def tokenize_text(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
         if self.use_dummy_backbone:
@@ -103,7 +294,7 @@ class FrozenVLBackbone(nn.Module):
             return_tensors="pt",
         ).to(device)
 
-    def encode_images(self, images: Tensor) -> Tensor:
+    def _encode_rgb_tokens(self, images: Tensor) -> Tensor:
         batch_size, num_views, channels, height, width = images.shape
         flat_images = images.reshape(batch_size * num_views, channels, height, width)
         if self.use_dummy_backbone:
@@ -125,6 +316,40 @@ class FrozenVLBackbone(nn.Module):
         num_tokens = tokens.shape[1]
         return tokens.reshape(batch_size, num_views, num_tokens, -1)
 
+    def encode_images(
+        self,
+        images: Tensor,
+        depths: Tensor | None = None,
+        depth_valid: Tensor | None = None,
+        camera_intrinsics: Tensor | None = None,
+        camera_extrinsics: Tensor | None = None,
+        return_aux: bool = False,
+    ) -> Tensor | dict[str, Tensor | None]:
+        rgb_tokens = self._encode_rgb_tokens(images)
+        wants_aux = return_aux or depths is not None or depth_valid is not None or camera_intrinsics is not None or camera_extrinsics is not None
+        if not wants_aux:
+            return rgb_tokens
+
+        depth_outputs: dict[str, Tensor | None] = {
+            "depth_tokens": None,
+            "geometry_tokens": None,
+            "camera_tokens": None,
+        }
+        if depths is not None:
+            depth_outputs = self.depth_adapter(
+                depths=depths,
+                depth_valid=depth_valid,
+                camera_intrinsics=camera_intrinsics,
+                camera_extrinsics=camera_extrinsics,
+            )
+
+        return {
+            "rgb_tokens": rgb_tokens,
+            "depth_tokens": depth_outputs["depth_tokens"],
+            "geometry_tokens": depth_outputs["geometry_tokens"],
+            "camera_tokens": depth_outputs["camera_tokens"],
+        }
+
     def encode_text(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
         if self.use_dummy_backbone:
             vocab_scale = float(self.tokenizer.vocab_size - 1)
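The new DepthPatchAdapter has no CLIP dependency, so it can be exercised in isolation. A hedged usage sketch, assuming the repo's models package is importable (and transformers installed, since importing the module pulls in CLIPModel); the printed shapes follow the adapter code above for 224x224 inputs and patch_size=16:

    # Depth-only smoke test of DepthPatchAdapter; intrinsics/extrinsics omitted,
    # so their geometry slots fall back to zero features.
    import torch
    from models.backbones import DepthPatchAdapter

    adapter = DepthPatchAdapter(hidden_dim=64, patch_size=16, geometry_feature_dim=8)
    depths = torch.rand(2, 3, 224, 224)        # [B, V, H, W], unsqueezed to [B, V, 1, H, W] internally
    out = adapter(depths)
    print(out["depth_tokens"].shape)           # [2, 3, 196, 64]  (14 x 14 patches)
    print(out["geometry_tokens"].shape)        # [2, 3, 196, 64]
    print(out["camera_tokens"].shape)          # [2, 3, 1, 64]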
code/reveal_vla_bimanual/models/multiview_fusion.py CHANGED
@@ -16,6 +16,37 @@ class MultiViewFusionConfig:
     dropout: float = 0.1
     proprio_dim: int = 32
     proprio_tokens: int = 1
+    geometry_num_heads: int = 4
+
+
+class GatedCrossAttentionBlock(nn.Module):
+    def __init__(self, hidden_dim: int, num_heads: int, dropout: float) -> None:
+        super().__init__()
+        self.attn = nn.MultiheadAttention(
+            embed_dim=hidden_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            batch_first=True,
+        )
+        self.gate = nn.Sequential(
+            nn.LayerNorm(hidden_dim * 2),
+            nn.Linear(hidden_dim * 2, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, hidden_dim),
+        )
+        self.out = nn.Sequential(
+            nn.LayerNorm(hidden_dim),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.GELU(),
+        )
+
+    def forward(self, rgb_tokens: Tensor, geometry_tokens: Tensor) -> tuple[Tensor, Tensor]:
+        attended, _ = self.attn(rgb_tokens, geometry_tokens, geometry_tokens)
+        rgb_summary = rgb_tokens.mean(dim=1)
+        geometry_summary = geometry_tokens.mean(dim=1)
+        gate = torch.sigmoid(self.gate(torch.cat([rgb_summary, geometry_summary], dim=-1))).unsqueeze(1)
+        fused = rgb_tokens + gate * attended
+        return self.out(fused), geometry_summary
 
 
 class MultiViewFusion(nn.Module):
@@ -35,13 +66,26 @@ class MultiViewFusion(nn.Module):
             encoder_layer,
             num_layers=config.num_layers,
         )
+        self.geometry_fusion = GatedCrossAttentionBlock(
+            hidden_dim=config.hidden_dim,
+            num_heads=max(1, min(config.num_heads, config.geometry_num_heads)),
+            dropout=config.dropout,
+        )
         self.proprio_adapter = nn.Sequential(
            nn.LayerNorm(config.proprio_dim),
            nn.Linear(config.proprio_dim, config.hidden_dim * config.proprio_tokens),
            nn.GELU(),
        )
 
-    def forward(self, image_tokens: Tensor, proprio: Tensor, language_tokens: Tensor) -> Tensor:
+    def forward(
+        self,
+        image_tokens: Tensor,
+        proprio: Tensor,
+        language_tokens: Tensor,
+        depth_tokens: Tensor | None = None,
+        camera_tokens: Tensor | None = None,
+        return_aux: bool = False,
+    ) -> Tensor | dict[str, Tensor]:
         batch_size, num_views, num_tokens, hidden_dim = image_tokens.shape
         if num_views != self.config.num_cameras:
             raise ValueError(f"Expected {self.config.num_cameras} views, received {num_views}")
@@ -49,9 +93,36 @@ class MultiViewFusion(nn.Module):
         camera_ids = torch.arange(num_views, device=image_tokens.device)
         camera_embed = self.camera_embedding(camera_ids).view(1, num_views, 1, hidden_dim)
         image_tokens = image_tokens + camera_embed
-        fused = self.cross_view_transformer(image_tokens.reshape(batch_size, num_views * num_tokens, hidden_dim))
+
+        per_view_tokens = []
+        view_summaries = []
+        geometry_summaries = []
+        for view_idx in range(num_views):
+            rgb_tokens = image_tokens[:, view_idx]
+            geometry_sources = []
+            if depth_tokens is not None:
+                geometry_sources.append(depth_tokens[:, view_idx])
+            if camera_tokens is not None:
+                geometry_sources.append(camera_tokens[:, view_idx])
+            if geometry_sources:
+                geometry = torch.cat(geometry_sources, dim=1)
+                rgb_tokens, geometry_summary = self.geometry_fusion(rgb_tokens, geometry)
+            else:
+                geometry_summary = torch.zeros(batch_size, hidden_dim, device=image_tokens.device, dtype=image_tokens.dtype)
+            per_view_tokens.append(rgb_tokens)
+            view_summaries.append(rgb_tokens.mean(dim=1))
+            geometry_summaries.append(geometry_summary)
+
+        fused = self.cross_view_transformer(torch.cat(per_view_tokens, dim=1))
 
         proprio_tokens = self.proprio_adapter(proprio).view(
             batch_size, self.config.proprio_tokens, hidden_dim
         )
-        return torch.cat([fused, proprio_tokens, language_tokens], dim=1)
+        scene_tokens = torch.cat([fused, proprio_tokens, language_tokens], dim=1)
+        if not (return_aux or depth_tokens is not None or camera_tokens is not None):
+            return scene_tokens
+        return {
+            "scene_tokens": scene_tokens,
+            "view_summaries": torch.stack(view_summaries, dim=1),
+            "geometry_summaries": torch.stack(geometry_summaries, dim=1),
+        }
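The GatedCrossAttentionBlock lets geometry tokens modulate the RGB tokens only through a sigmoid gate, so a missing or uninformative geometry branch degrades gracefully toward the plain RGB path. A plain-torch sketch of the gating rule (dimensions illustrative):

    # RGB queries attend into geometry tokens; a per-channel gate in (0, 1),
    # computed from pooled summaries of both streams, scales the residual.
    import torch
    import torch.nn as nn

    hidden = 32
    attn = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
    gate = nn.Linear(hidden * 2, hidden)

    rgb = torch.randn(2, 49, hidden)       # [B, N_rgb, hidden]
    geom = torch.randn(2, 16, hidden)      # [B, N_geom, hidden]
    attended, _ = attn(rgb, geom, geom)
    g = torch.sigmoid(gate(torch.cat([rgb.mean(1), geom.mean(1)], dim=-1))).unsqueeze(1)
    fused = rgb + g * attended             # zero gate keeps the RGB tokens unchanged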
code/reveal_vla_bimanual/models/observation_memory.py CHANGED
@@ -16,6 +16,12 @@ class ObservationMemoryConfig:
     memory_bank_size: int = 4
     num_heads: int = 4
     max_history_steps: int = 8
+    scene_bank_size: int = 2
+    belief_bank_size: int = 2
+    scene_history_steps: int = 3
+    belief_history_steps: int = 8
+    memory_write_threshold: float = 0.45
+    memory_suppression_margin: float = 0.05
 
 
 class ObservationMemory(nn.Module):
@@ -173,3 +179,189 @@ class InteractionObservationMemory(nn.Module):
             "memory_tokens": projected_bank,
             "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_bank)).squeeze(-1),
         }
+
+
+class _SelectiveMemoryBank(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        action_dim: int,
+        num_heads: int,
+        dropout: float,
+        bank_size: int,
+        history_steps: int,
+        max_history_steps: int,
+        write_threshold: float,
+        suppression_margin: float,
+    ) -> None:
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.history_steps = history_steps
+        self.max_history_steps = max_history_steps
+        self.write_threshold = write_threshold
+        self.suppression_margin = suppression_margin
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=hidden_dim,
+            nhead=num_heads,
+            dim_feedforward=hidden_dim * 4,
+            dropout=dropout,
+            batch_first=True,
+            norm_first=True,
+        )
+        self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
+        self.position_embedding = nn.Parameter(torch.randn(1, max_history_steps + 1, hidden_dim) * 0.02)
+        self.bank_queries = nn.Parameter(torch.randn(bank_size, hidden_dim) * 0.02)
+        self.bank_attention = nn.MultiheadAttention(
+            embed_dim=hidden_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            batch_first=True,
+        )
+        self.action_proj = nn.Sequential(
+            nn.LayerNorm(action_dim),
+            nn.Linear(action_dim, hidden_dim),
+            nn.GELU(),
+        )
+        self.write_gate = nn.Sequential(
+            nn.LayerNorm(hidden_dim * 3),
+            nn.Linear(hidden_dim * 3, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, 1),
+        )
+        self.token_proj = nn.Sequential(
+            nn.LayerNorm(hidden_dim),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.GELU(),
+        )
+
+    def _truncate(self, history: Tensor | None) -> Tensor | None:
+        if history is None or history.numel() == 0:
+            return history
+        if history.shape[1] <= self.history_steps:
+            return history
+        return history[:, -self.history_steps :]
+
+    def forward(
+        self,
+        pooled_current: Tensor,
+        history_scene_tokens: Tensor | None = None,
+        history_actions: Tensor | None = None,
+    ) -> dict[str, Tensor]:
+        history_scene_tokens = self._truncate(history_scene_tokens)
+        pooled_current = pooled_current.unsqueeze(1)
+        if history_scene_tokens is not None and history_scene_tokens.numel() > 0:
+            history_pooled = history_scene_tokens.mean(dim=2)
+            if history_actions is not None and history_actions.numel() > 0:
+                history_actions = history_actions[:, -history_pooled.shape[1] :]
+                history_pooled = history_pooled + self.action_proj(history_actions)
+            sequence = torch.cat([history_pooled, pooled_current], dim=1)
+        else:
+            history_pooled = pooled_current[:, :0]
+            sequence = pooled_current
+        if sequence.shape[1] > self.position_embedding.shape[1]:
+            raise ValueError(
+                f"Sequence length {sequence.shape[1]} exceeds configured maximum {self.position_embedding.shape[1]}"
+            )
+        encoded = self.sequence_encoder(sequence + self.position_embedding[:, : sequence.shape[1]])
+        current_token = encoded[:, -1]
+        prior_token = encoded[:, :-1].mean(dim=1) if encoded.shape[1] > 1 else torch.zeros_like(current_token)
+        novelty = torch.abs(current_token - prior_token)
+        informative = novelty.mean(dim=-1, keepdim=True)
+        gate_logit = self.write_gate(torch.cat([current_token, prior_token, novelty], dim=-1))
+        gate = torch.sigmoid(gate_logit)
+        gate = gate * (informative > (self.write_threshold - self.suppression_margin)).to(gate.dtype)
+        recent_summary = encoded[:, -min(max(1, self.bank_queries.shape[0]), encoded.shape[1]) :].mean(dim=1, keepdim=True)
+        queries = self.bank_queries.unsqueeze(0).expand(encoded.shape[0], -1, -1) + recent_summary
+        bank_tokens, _ = self.bank_attention(queries, encoded, encoded)
+        bank_tokens = bank_tokens + recent_summary
+        bank_tokens = prior_token.unsqueeze(1) * (1.0 - gate.unsqueeze(1)) + bank_tokens * gate.unsqueeze(1)
+        bank_tokens = self.token_proj(bank_tokens)
+        return {
+            "memory_tokens": bank_tokens,
+            "memory_token": bank_tokens.mean(dim=1, keepdim=True),
+            "memory_sequence": encoded,
+            "memory_state": current_token,
+            "write_gate": gate.squeeze(-1),
+            "saturation": bank_tokens.abs().mean(dim=(1, 2)),
+        }
+
+
+class SceneMemory(_SelectiveMemoryBank):
+    def __init__(self, config: ObservationMemoryConfig) -> None:
+        super().__init__(
+            hidden_dim=config.hidden_dim,
+            action_dim=config.action_dim,
+            num_heads=config.num_heads,
+            dropout=config.dropout,
+            bank_size=max(1, config.scene_bank_size),
+            history_steps=max(1, config.scene_history_steps),
+            max_history_steps=config.max_history_steps,
+            write_threshold=config.memory_write_threshold,
+            suppression_margin=config.memory_suppression_margin,
+        )
+
+
+class BeliefMemory(_SelectiveMemoryBank):
+    def __init__(self, config: ObservationMemoryConfig) -> None:
+        super().__init__(
+            hidden_dim=config.hidden_dim,
+            action_dim=config.action_dim,
+            num_heads=config.num_heads,
+            dropout=config.dropout,
+            bank_size=max(1, config.belief_bank_size),
+            history_steps=max(1, config.belief_history_steps),
+            max_history_steps=config.max_history_steps,
+            write_threshold=config.memory_write_threshold + 0.05,
+            suppression_margin=config.memory_suppression_margin,
+        )
+
+
+class DualObservationMemory(nn.Module):
+    def __init__(self, config: ObservationMemoryConfig) -> None:
+        super().__init__()
+        self.scene_memory = SceneMemory(config)
+        self.belief_memory = BeliefMemory(config)
+        self.uncertainty_head = nn.Sequential(
+            nn.LayerNorm(config.hidden_dim),
+            nn.Linear(config.hidden_dim, 1),
+        )
+
+    def forward(
+        self,
+        scene_tokens: Tensor,
+        history_scene_tokens: Tensor | None = None,
+        history_actions: Tensor | None = None,
+    ) -> dict[str, Tensor]:
+        pooled_current = scene_tokens.mean(dim=1)
+        scene_output = self.scene_memory(
+            pooled_current=pooled_current,
+            history_scene_tokens=history_scene_tokens,
+            history_actions=history_actions,
+        )
+        belief_output = self.belief_memory(
+            pooled_current=pooled_current,
+            history_scene_tokens=history_scene_tokens,
+            history_actions=history_actions,
+        )
+        memory_tokens = torch.cat([scene_output["memory_tokens"], belief_output["memory_tokens"]], dim=1)
+        memory_token = memory_tokens.mean(dim=1, keepdim=True)
+        memory_state = torch.cat([scene_output["memory_state"], belief_output["memory_state"]], dim=-1)
+        pooled_memory = memory_tokens.mean(dim=1)
+        return {
+            "scene_memory_tokens": scene_output["memory_tokens"],
+            "belief_memory_tokens": belief_output["memory_tokens"],
+            "memory_tokens": memory_tokens,
+            "memory_token": memory_token,
+            "memory_sequence": torch.cat(
+                [scene_output["memory_sequence"], belief_output["memory_sequence"]],
+                dim=1,
+            ),
+            "memory_state": memory_state,
+            "memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_memory)).squeeze(-1),
+            "memory_write_rate": 0.5 * (scene_output["write_gate"] + belief_output["write_gate"]),
+            "memory_saturation": 0.5 * (scene_output["saturation"] + belief_output["saturation"]),
+            "scene_write_gate": scene_output["write_gate"],
+            "belief_write_gate": belief_output["write_gate"],
+            "memory_scene_state": scene_output["memory_state"],
+            "memory_belief_state": belief_output["memory_state"],
+        }
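The write gate in _SelectiveMemoryBank combines a learned sigmoid gate with a hard novelty floor at write_threshold - suppression_margin, so near-duplicate observations never overwrite the bank. A plain-torch sketch of that rule (values illustrative):

    # Write-gate rule: sigmoid gate, hard-zeroed when mean novelty is too low.
    import torch

    write_threshold, margin = 0.45, 0.05
    current = torch.randn(4, 32)                # encoded current observation
    prior = torch.randn(4, 32)                  # mean of encoded history
    novelty = (current - prior).abs()
    informative = novelty.mean(dim=-1, keepdim=True)
    gate_logit = torch.randn(4, 1)              # stands in for self.write_gate(...)
    gate = torch.sigmoid(gate_logit)
    gate = gate * (informative > (write_threshold - margin)).to(gate.dtype)
    # Where gate == 0, the bank keeps its previous contents for that batch item.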
code/reveal_vla_bimanual/models/planner.py CHANGED
@@ -24,6 +24,14 @@ class PlannerConfig:
     num_layers: int = 2
     num_phases: int = 5
     num_arm_roles: int = 4
+    top_k: int = 4
+    belief_gain_weight: float = 1.0
+    visibility_gain_weight: float = 0.75
+    clearance_weight: float = 0.75
+    occluder_contact_weight: float = 0.5
+    grasp_affordance_weight: float = 0.75
+    support_stability_weight: float = 0.5
+    residual_weight: float = 0.5
 
 
 class RevealPlanner(nn.Module):
@@ -202,3 +210,186 @@ class InteractionPlanner(nn.Module):
             "best_indices": best_idx,
             "best_chunk": candidate_chunks[batch_indices, best_idx],
         }
+
+
+class StructuredElasticUtility(nn.Module):
+    def __init__(self, config: PlannerConfig) -> None:
+        super().__init__()
+        self.config = config
+
+    def _field_mean(self, tensor: Tensor) -> Tensor:
+        if tensor.ndim == 6:
+            return tensor.mean(dim=(-1, -2, -3))
+        if tensor.ndim == 5:
+            return tensor.mean(dim=(-1, -2))
+        if tensor.ndim == 4:
+            return tensor.mean(dim=(-1, -2))
+        return tensor
+
+    def _initial_scalar(self, state: dict[str, Tensor], key: str) -> Tensor:
+        value = state[key]
+        if value.ndim >= 4:
+            return value.mean(dim=tuple(range(1, value.ndim)))
+        if value.ndim == 3:
+            return value.mean(dim=(-1, -2))
+        if value.ndim == 2:
+            return value.mean(dim=-1)
+        return value
+
+    def forward(
+        self,
+        initial_state: dict[str, Tensor],
+        rollout_state: dict[str, Tensor],
+        candidate_chunks: Tensor,
+    ) -> dict[str, Tensor]:
+        initial_belief = self._initial_scalar(initial_state, "target_belief_field").unsqueeze(1)
+        initial_visibility = self._initial_scalar(initial_state, "visibility_field").unsqueeze(1)
+        belief_future = self._field_mean(rollout_state["target_belief_field"]).mean(dim=-1)
+        visibility_future = self._field_mean(rollout_state["visibility_field"]).mean(dim=-1)
+        clearance = self._field_mean(rollout_state["clearance_field"]).mean(dim=-1)
+        occluder_contact = self._field_mean(rollout_state["occluder_contact_field"]).mean(dim=-1)
+        grasp_affordance = self._field_mean(rollout_state["grasp_affordance_field"]).mean(dim=-1)
+        support_stability = torch.sigmoid(self._field_mean(rollout_state["support_stability_field"])).mean(dim=-1)
+        persistence = self._field_mean(rollout_state["persistence_field"]).mean(dim=-1)
+        reocclusion = self._field_mean(rollout_state["reocclusion_field"]).mean(dim=-1)
+        disturbance = self._field_mean(rollout_state["disturbance_field"]).mean(dim=-1)
+        access_quality = torch.sigmoid(self._field_mean(rollout_state["access_field"])).mean(dim=-1)
+        retrieve_progress = torch.sigmoid(candidate_chunks[:, :, :, -1]).mean(dim=-1)
+        utility = (
+            self.config.belief_gain_weight * (belief_future - initial_belief)
+            + self.config.visibility_gain_weight * (visibility_future - initial_visibility)
+            + self.config.clearance_weight * clearance
+            + self.config.occluder_contact_weight * occluder_contact
+            + self.config.grasp_affordance_weight * grasp_affordance
+            + self.config.persistence_weight * persistence
+            + self.config.support_stability_weight * support_stability
+            + self.config.corridor_weight * access_quality
+            + self.config.task_progress_weight * retrieve_progress
+            - self.config.reocclusion_weight * reocclusion
+            - self.config.disturbance_weight * disturbance
+            - self.config.visibility_weight * (1.0 - visibility_future)
+        )
+        return {
+            "belief_gain": belief_future - initial_belief,
+            "visibility_gain": visibility_future - initial_visibility,
+            "clearance": clearance,
+            "occluder_contact_quality": occluder_contact,
+            "grasp_affordance": grasp_affordance,
+            "persistence": persistence,
+            "support_stability": support_stability,
+            "reocclusion_penalty": reocclusion,
+            "disturbance_penalty": disturbance,
+            "access_quality": access_quality,
+            "task_progress": retrieve_progress,
+            "utility_structured": utility,
+        }
+
+
+class ResidualPlannerScorer(nn.Module):
+    def __init__(self, config: PlannerConfig) -> None:
+        super().__init__()
+        feature_dim = (config.action_dim * 2) + 11
+        self.trunk = nn.Sequential(
+            nn.LayerNorm(feature_dim),
+            nn.Linear(feature_dim, config.hidden_dim),
+            nn.GELU(),
+            nn.Linear(config.hidden_dim, config.hidden_dim),
+            nn.GELU(),
+        )
+        self.success_head = nn.Linear(config.hidden_dim, 1)
+        self.risk_head = nn.Linear(config.hidden_dim, 1)
+        self.residual_head = nn.Linear(config.hidden_dim, 1)
+
+    def forward(
+        self,
+        candidate_chunks: Tensor,
+        structured: dict[str, Tensor],
+        proposal_logits: Tensor | None = None,
+    ) -> dict[str, Tensor]:
+        candidate_mean = candidate_chunks.mean(dim=2)
+        candidate_terminal = candidate_chunks[:, :, -1]
+        components = torch.stack(
+            [
+                structured["belief_gain"],
+                structured["visibility_gain"],
+                structured["clearance"],
+                structured["occluder_contact_quality"],
+                structured["grasp_affordance"],
+                structured["persistence"],
+                structured["support_stability"],
+                structured["reocclusion_penalty"],
+                structured["disturbance_penalty"],
+                structured["access_quality"],
+                structured["task_progress"],
+            ],
+            dim=-1,
+        )
+        features = torch.cat([candidate_mean, candidate_terminal, components], dim=-1)
+        hidden = self.trunk(features)
+        success_logits = self.success_head(hidden).squeeze(-1)
+        risk_values = torch.sigmoid(self.risk_head(hidden)).squeeze(-1)
+        residual = self.residual_head(hidden).squeeze(-1)
+        if proposal_logits is not None and proposal_logits.shape == residual.shape:
+            residual = residual + 0.25 * proposal_logits.sigmoid()
+        return {
+            "planner_hidden": hidden,
+            "success_logits": success_logits,
+            "risk_values": risk_values,
+            "utility_residual": residual,
+        }
+
+
+class CascadePlanner(nn.Module):
+    def __init__(self, config: PlannerConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.structured = StructuredElasticUtility(config)
+        self.residual = ResidualPlannerScorer(config)
+
+    def shortlist(self, proposal_logits: Tensor | None, candidate_chunks: Tensor) -> Tensor:
+        batch_size, num_candidates = candidate_chunks.shape[:2]
+        top_k = min(max(1, self.config.top_k), num_candidates)
+        if proposal_logits is None:
+            cheap_scores = -candidate_chunks.square().mean(dim=(-1, -2))
+        else:
+            cheap_scores = proposal_logits
+        return cheap_scores.topk(top_k, dim=-1).indices
+
+    def select_best(
+        self,
+        initial_state: dict[str, Tensor],
+        candidate_chunks: Tensor,
+        rollout_state: dict[str, Tensor],
+        proposal_logits: Tensor | None = None,
+        candidate_indices: Tensor | None = None,
+    ) -> dict[str, Tensor]:
+        structured = self.structured(
+            initial_state=initial_state,
+            rollout_state=rollout_state,
+            candidate_chunks=candidate_chunks,
+        )
+        residual = self.residual(
+            candidate_chunks=candidate_chunks,
+            structured=structured,
+            proposal_logits=proposal_logits,
+        )
+        utility_total = structured["utility_structured"] + self.config.residual_weight * residual["utility_residual"]
+        utility_total = utility_total + residual["success_logits"].sigmoid() - residual["risk_values"]
+        best_local = utility_total.argmax(dim=-1)
+        batch_indices = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
+        if candidate_indices is None:
+            best_indices = best_local
+        else:
+            best_indices = candidate_indices[batch_indices, best_local]
+        return {
+            **structured,
+            **residual,
+            "utility_total": utility_total,
+            "utility_scores": utility_total,
+            "best_indices": best_indices,
+            "best_chunk": candidate_chunks[batch_indices, best_local],
+            "ranking_diagnostics": {
+                "topk_indices": candidate_indices if candidate_indices is not None else best_local.unsqueeze(-1),
+                "best_local_indices": best_local,
+            },
+        }
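CascadePlanner is a two-stage cascade: a cheap shortlist first, then world-model rollouts scored by StructuredElasticUtility plus ResidualPlannerScorer for the survivors only. A plain-torch sketch of the shortlist stage (shapes illustrative):

    # Stage 1: keep top_k candidates by proposal logits when available,
    # else by a cheap action-magnitude prior.
    import torch

    candidate_chunks = torch.randn(2, 8, 10, 16)   # [B, K, T, D]
    proposal_logits = torch.randn(2, 8)            # may be None in the real module
    top_k = 4

    if proposal_logits is None:
        cheap_scores = -candidate_chunks.square().mean(dim=(-1, -2))
    else:
        cheap_scores = proposal_logits
    shortlist = cheap_scores.topk(top_k, dim=-1).indices          # [B, top_k]
    topk_candidates = candidate_chunks[
        torch.arange(2).unsqueeze(-1), shortlist                  # gather per batch row
    ]
    assert topk_candidates.shape == (2, 4, 10, 16)
    # Only these survivors are rolled out and scored in stage 2.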
code/reveal_vla_bimanual/models/policy.py CHANGED
@@ -6,13 +6,28 @@ from typing import Sequence
 import torch
 from torch import Tensor, nn
 
-from models.action_decoder import ACTBimanualChunkDecoder, ChunkDecoderConfig, InteractionChunkDecoder
+from models.action_decoder import (
+    ACTBimanualChunkDecoder,
+    ChunkDecoderConfig,
+    InteractionChunkDecoder,
+    SymmetricCoordinatedChunkDecoder,
+)
 from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
 from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
-from models.observation_memory import InteractionObservationMemory, ObservationMemory, ObservationMemoryConfig
-from models.planner import InteractionPlanner, PlannerConfig, RevealPlanner
-from models.reveal_head import InteractionStateHead, RevealHeadConfig, RevealStateHead
-from models.world_model import InteractionRolloutModel, RevealWM, RevealWMConfig
+from models.observation_memory import (
+    DualObservationMemory,
+    InteractionObservationMemory,
+    ObservationMemory,
+    ObservationMemoryConfig,
+)
+from models.planner import CascadePlanner, InteractionPlanner, PlannerConfig, RevealPlanner
+from models.reveal_head import (
+    ElasticOcclusionStateHead,
+    InteractionStateHead,
+    RevealHeadConfig,
+    RevealStateHead,
+)
+from models.world_model import ElasticOcclusionWorldModel, InteractionRolloutModel, RevealWM, RevealWMConfig
 
 
 @dataclass
@@ -351,3 +366,302 @@ class InteractionBimanualPolicy(BackboneOnlyPolicy):
         outputs["planner_scores"] = selected["utility_scores"]
         outputs["best_candidate_indices"] = selected["best_indices"]
         return outputs
+
+
+class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
+    def __init__(self, config: PolicyConfig) -> None:
+        super().__init__(config)
+        self.memory = DualObservationMemory(config.memory)
+        self.decoder = SymmetricCoordinatedChunkDecoder(config.decoder)
+        self.elastic_state_head = ElasticOcclusionStateHead(config.reveal_head)
+        self.world_model = ElasticOcclusionWorldModel(config.world_model)
+        self.planner = CascadePlanner(config.planner)
+
+    def _encode_scene_with_optional_depth(
+        self,
+        images: Tensor,
+        proprio: Tensor,
+        texts: Sequence[str] | None = None,
+        language_tokens: dict[str, Tensor] | None = None,
+        depths: Tensor | None = None,
+        depth_valid: Tensor | None = None,
+        camera_intrinsics: Tensor | None = None,
+        camera_extrinsics: Tensor | None = None,
+        use_depth: bool = True,
+    ) -> dict[str, Tensor]:
+        encoded = self.backbone.encode_images(
+            images,
+            depths=depths if use_depth else None,
+            depth_valid=depth_valid if use_depth else None,
+            camera_intrinsics=camera_intrinsics if use_depth else None,
+            camera_extrinsics=camera_extrinsics if use_depth else None,
+            return_aux=True,
+        )
+        assert isinstance(encoded, dict)
+        text_tokens = self._encode_language(images, texts=texts, language_tokens=language_tokens)
+        fused = self.fusion(
+            image_tokens=encoded["rgb_tokens"],
+            proprio=proprio,
+            language_tokens=text_tokens,
+            depth_tokens=encoded.get("depth_tokens"),
+            camera_tokens=encoded.get("camera_tokens"),
+            return_aux=True,
+        )
+        assert isinstance(fused, dict)
+        return {
+            "scene_tokens": fused["scene_tokens"],
+            "view_summaries": fused["view_summaries"],
+            "geometry_summaries": fused["geometry_summaries"],
+            "depth_tokens": encoded.get("depth_tokens"),
+            "camera_tokens": encoded.get("camera_tokens"),
+        }
+
+    def _expand_language_tokens_for_history(
+        self,
+        language_tokens: dict[str, Tensor] | None,
+        history_steps: int,
+    ) -> dict[str, Tensor] | None:
+        if language_tokens is None:
+            return None
+        return {
+            key: value.unsqueeze(1).expand(-1, history_steps, *value.shape[1:]).reshape(
+                value.shape[0] * history_steps, *value.shape[1:]
+            )
+            for key, value in language_tokens.items()
+        }
+
+    def encode_history_with_optional_depth(
+        self,
+        history_images: Tensor | None,
+        history_proprio: Tensor | None,
+        texts: Sequence[str] | None = None,
+        language_tokens: dict[str, Tensor] | None = None,
+        history_depths: Tensor | None = None,
+        history_depth_valid: Tensor | None = None,
+        camera_intrinsics: Tensor | None = None,
+        camera_extrinsics: Tensor | None = None,
+        use_depth: bool = True,
+    ) -> Tensor | None:
+        if history_images is None or history_proprio is None or history_images.numel() == 0:
+            return None
+        batch_size, history_steps = history_images.shape[:2]
+        flat_images = history_images.reshape(batch_size * history_steps, *history_images.shape[2:])
+        flat_proprio = history_proprio.reshape(batch_size * history_steps, history_proprio.shape[-1])
+        flat_depths = None
+        flat_depth_valid = None
+        if history_depths is not None and history_depths.numel() > 0:
+            flat_depths = history_depths.reshape(batch_size * history_steps, *history_depths.shape[2:])
+        if history_depth_valid is not None and history_depth_valid.numel() > 0:
+            flat_depth_valid = history_depth_valid.reshape(batch_size * history_steps, *history_depth_valid.shape[2:])
+        if language_tokens is None:
+            flat_texts = [text for text in texts for _ in range(history_steps)] if texts is not None else None
+            flat_language_tokens = None
+        else:
+            flat_texts = None
+            flat_language_tokens = self._expand_language_tokens_for_history(language_tokens, history_steps)
+        history_scene = self._encode_scene_with_optional_depth(
+            images=flat_images,
+            proprio=flat_proprio,
+            texts=flat_texts,
+            language_tokens=flat_language_tokens,
+            depths=flat_depths,
+            depth_valid=flat_depth_valid,
+            camera_intrinsics=None,
+            camera_extrinsics=None,
+            use_depth=use_depth,
+        )["scene_tokens"]
+        return history_scene.view(batch_size, history_steps, history_scene.shape[1], history_scene.shape[2])
+
+    def _tile_tensor(self, value: Tensor, num_candidates: int) -> Tensor:
+        return value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(
+            value.shape[0] * num_candidates,
+            *value.shape[1:],
+        )
+
+    def _tile_state(self, state: dict[str, Tensor], num_candidates: int) -> dict[str, Tensor]:
+        tiled: dict[str, Tensor] = {}
+        for key, value in state.items():
+            if isinstance(value, Tensor):
+                tiled[key] = self._tile_tensor(value, num_candidates)
+        return tiled
+
+    def forward(
+        self,
+        images: Tensor,
+        proprio: Tensor,
+        texts: Sequence[str] | None = None,
+        language_tokens: dict[str, Tensor] | None = None,
+        history_images: Tensor | None = None,
+        history_proprio: Tensor | None = None,
+        history_actions: Tensor | None = None,
+        plan: bool = True,
+        support_mode_conditioning: bool = True,
+        candidate_chunks_override: Tensor | None = None,
+        use_depth: bool = True,
+        use_world_model: bool = True,
+        use_planner: bool = True,
+        use_role_tokens: bool = True,
+        history_steps_override: int | None = None,
+        depths: Tensor | None = None,
+        depth_valid: Tensor | None = None,
+        camera_intrinsics: Tensor | None = None,
+        camera_extrinsics: Tensor | None = None,
+        history_depths: Tensor | None = None,
+        history_depth_valid: Tensor | None = None,
+        compute_equivariance_probe: bool = False,
+    ) -> dict[str, Tensor]:
+        scene_output = self._encode_scene_with_optional_depth(
+            images=images,
+            proprio=proprio,
+            texts=texts,
+            language_tokens=language_tokens,
+            depths=depths,
+            depth_valid=depth_valid,
+            camera_intrinsics=camera_intrinsics,
+            camera_extrinsics=camera_extrinsics,
+            use_depth=use_depth,
+        )
+        scene_tokens = scene_output["scene_tokens"]
+        history_scene_tokens = self.encode_history_with_optional_depth(
+            history_images=history_images,
+            history_proprio=history_proprio,
+            texts=texts,
+            language_tokens=language_tokens,
+            history_depths=history_depths,
+            history_depth_valid=history_depth_valid,
+            camera_intrinsics=camera_intrinsics,
+            camera_extrinsics=camera_extrinsics,
+            use_depth=use_depth,
+        )
+        if history_steps_override is not None and history_scene_tokens is not None and history_scene_tokens.numel() > 0:
+            history_scene_tokens = history_scene_tokens[:, -history_steps_override:]
+            if history_actions is not None and history_actions.numel() > 0:
+                history_actions = history_actions[:, -history_steps_override:]
+        memory_output = self.memory(
+            scene_tokens,
+            history_scene_tokens=history_scene_tokens,
+            history_actions=history_actions,
+        )
+        elastic_state = self.elastic_state_head(
+            scene_tokens,
+            memory_tokens=memory_output["memory_tokens"],
+        )
+        elastic_state["memory_tokens"] = memory_output["memory_tokens"]
+        elastic_state["memory_token"] = memory_output["memory_token"]
+        elastic_state["scene_memory_tokens"] = memory_output["scene_memory_tokens"]
+        elastic_state["belief_memory_tokens"] = memory_output["belief_memory_tokens"]
+        if not use_role_tokens:
+            elastic_state = dict(elastic_state)
+            elastic_state["arm_role_logits"] = torch.zeros_like(elastic_state["arm_role_logits"])
+
+        decoded = self.decoder(
+            scene_tokens,
+            interaction_state=elastic_state,
+            memory_tokens=memory_output["memory_tokens"],
+            compute_equivariance_probe=compute_equivariance_probe,
+        )
+        outputs = {
+            **decoded,
+            "scene_tokens": scene_tokens,
+            "history_scene_tokens": history_scene_tokens,
+            "memory_output": memory_output,
+            "memory_uncertainty": memory_output["memory_uncertainty"],
+            "interaction_state": elastic_state,
+            "reveal_state": elastic_state,
+            "view_summaries": scene_output["view_summaries"],
+            "geometry_summaries": scene_output["geometry_summaries"],
+        }
+
+        candidate_chunks = candidate_chunks_override
+        proposal_logits = outputs.get("proposal_logits")
+        if candidate_chunks is None:
+            candidate_chunks = self.decoder.sample_candidates(
+                outputs["action_mean"],
+                outputs["action_log_std"],
+                num_candidates=self.config.decoder.num_candidates,
+                proposal_candidates=outputs.get("proposal_candidates"),
+            )
+        else:
+            proposal_logits = None
+        outputs["candidate_chunks"] = candidate_chunks
+
+        if not plan or not use_planner:
+            outputs["planned_chunk"] = outputs["action_mean"]
+            outputs["planned_rollout"] = {}
+            outputs["planner_success_logits"] = torch.zeros(
+                candidate_chunks.shape[:2],
+                device=candidate_chunks.device,
+                dtype=candidate_chunks.dtype,
+            )
+            outputs["planner_risk_values"] = torch.zeros_like(outputs["planner_success_logits"])
+            outputs["planner_scores"] = torch.zeros_like(outputs["planner_success_logits"])
+            outputs["best_candidate_indices"] = torch.zeros(
+                candidate_chunks.shape[0],
+                dtype=torch.long,
+                device=candidate_chunks.device,
+            )
+            return outputs
+
+        shortlist_indices = self.planner.shortlist(proposal_logits=proposal_logits, candidate_chunks=candidate_chunks)
+        outputs["planner_topk_indices"] = shortlist_indices
+        batch_size = candidate_chunks.shape[0]
+        batch_indices = torch.arange(batch_size, device=candidate_chunks.device).unsqueeze(-1)
+        topk_candidates = candidate_chunks[batch_indices, shortlist_indices]
+        outputs["planner_topk_candidates"] = topk_candidates
+        if proposal_logits is not None:
+            topk_proposal_logits = proposal_logits.gather(1, shortlist_indices)
+        else:
+            topk_proposal_logits = None
+
+        if not use_world_model:
+            score_source = topk_proposal_logits if topk_proposal_logits is not None else -topk_candidates.square().mean(dim=(-1, -2))
+            best_local = score_source.argmax(dim=-1)
+            best_indices = shortlist_indices[torch.arange(batch_size, device=best_local.device), best_local]
+            outputs["planned_chunk"] = candidate_chunks[torch.arange(batch_size, device=best_local.device), best_indices]
+            outputs["planned_rollout"] = {}
+            outputs["planner_success_logits"] = torch.zeros_like(score_source)
+            outputs["planner_risk_values"] = torch.zeros_like(score_source)
+            outputs["planner_scores"] = score_source
+            outputs["best_candidate_indices"] = best_indices
+            outputs["utility_structured"] = score_source
+            outputs["utility_residual"] = torch.zeros_like(score_source)
+            outputs["utility_total"] = score_source
+            return outputs
+
+        num_topk = topk_candidates.shape[1]
+        flat_chunks = topk_candidates.view(batch_size * num_topk, topk_candidates.shape[2], topk_candidates.shape[3])
+        tiled_scene = self._tile_tensor(scene_tokens, num_topk)
+        planning_state = elastic_state
+        if not support_mode_conditioning:
+            planning_state = dict(elastic_state)
+            planning_state["support_mode_logits"] = torch.zeros_like(elastic_state["support_mode_logits"])
+        tiled_state = self._tile_state(planning_state, num_topk)
+        rollout = self.world_model(
+            scene_tokens=tiled_scene,
+            interaction_state=tiled_state,
+            action_chunk=flat_chunks,
+            memory_tokens=self._tile_tensor(memory_output["memory_tokens"], num_topk),
+            scene_memory_tokens=self._tile_tensor(memory_output["scene_memory_tokens"], num_topk),
+            belief_memory_tokens=self._tile_tensor(memory_output["belief_memory_tokens"], num_topk),
+        )
+        reshaped_rollout = {
+            key: value.view(batch_size, num_topk, *value.shape[1:]) for key, value in rollout.items()
+        }
+        selected = self.planner.select_best(
+            initial_state=elastic_state,
+            candidate_chunks=topk_candidates,
+            rollout_state=reshaped_rollout,
+            proposal_logits=topk_proposal_logits,
+            candidate_indices=shortlist_indices,
+        )
+        outputs["planned_rollout"] = reshaped_rollout
+        outputs["planned_chunk"] = selected["best_chunk"]
+        outputs["planner_success_logits"] = selected["success_logits"]
+        outputs["planner_risk_values"] = selected["risk_values"]
+        outputs["planner_scores"] = selected["utility_total"]
+        outputs["best_candidate_indices"] = selected["best_indices"]
+        outputs["utility_structured"] = selected["utility_structured"]
+        outputs["utility_residual"] = selected["utility_residual"]
+        outputs["utility_total"] = selected["utility_total"]
+        outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
+        return outputs
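Before the world-model rollout, _tile_tensor repeats every batch item once per shortlisted candidate so all rollouts run as one flat batch of size B * K. A plain-torch sketch of that helper:

    # Tile [B, ...] to [B * K, ...] by repeating along a new candidate axis.
    import torch

    def tile_tensor(value, num_candidates):
        return value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(
            value.shape[0] * num_candidates, *value.shape[1:]
        )

    scene_tokens = torch.randn(2, 60, 32)    # [B, N_tokens, hidden]
    tiled = tile_tensor(scene_tokens, 4)     # [B * K, N_tokens, hidden]
    assert tiled.shape == (8, 60, 32)
    # After the rollout, outputs are viewed back to [B, K, ...] for per-candidate scoring.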
code/reveal_vla_bimanual/models/reveal_head.py CHANGED
@@ -317,3 +317,245 @@ class InteractionStateHead(nn.Module):
             scene_tokens=scene_tokens,
             memory_tokens=memory_tokens,
         )
+
+
+class ElasticOcclusionFieldDecoder(nn.Module):
+    def __init__(self, config: RevealHeadConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.field_queries = nn.Parameter(
+            torch.randn(config.field_size * config.field_size, config.hidden_dim) * 0.02
+        )
+        self.field_attention = nn.MultiheadAttention(
+            embed_dim=config.hidden_dim,
+            num_heads=config.num_heads,
+            batch_first=True,
+        )
+        self.field_mlp = nn.Sequential(
+            nn.LayerNorm(config.hidden_dim),
+            nn.Linear(config.hidden_dim, config.hidden_dim),
+            nn.GELU(),
+            nn.Linear(config.hidden_dim, config.hidden_dim),
+        )
+        summary_dim = config.hidden_dim * 4
+        self.summary_proj = nn.Sequential(
+            nn.LayerNorm(summary_dim),
+            nn.Linear(summary_dim, config.hidden_dim),
+            nn.GELU(),
+        )
+        self.phase_head = nn.Sequential(
+            nn.LayerNorm(summary_dim),
+            nn.Linear(summary_dim, config.hidden_dim),
+            nn.GELU(),
+            nn.Linear(config.hidden_dim, config.num_phases),
+        )
+        self.arm_role_head = nn.Sequential(
+            nn.LayerNorm(config.hidden_dim * 2),
+            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
+            nn.GELU(),
+            nn.Linear(config.hidden_dim, config.num_arm_roles),
+        )
+        self.arm_identity = nn.Embedding(2, config.hidden_dim)
+        self.support_mode = nn.Sequential(
+            nn.LayerNorm(summary_dim),
+            nn.Linear(summary_dim, config.hidden_dim),
+            nn.GELU(),
+            nn.Linear(config.hidden_dim, config.num_support_modes),
+        )
+        self.access_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
+        self.target_belief_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
+        self.visibility_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
+        self.clearance_field = nn.Conv2d(config.hidden_dim, 2, kernel_size=1)
+        self.occluder_contact_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
+        self.grasp_affordance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
+        self.support_stability_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
+        self.persistence_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
+        self.reocclusion_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
+        self.disturbance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
+        self.uncertainty_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
+        self.reocclusion_head = nn.Sequential(
+            nn.LayerNorm(summary_dim),
+            nn.Linear(summary_dim, config.hidden_dim),
+            nn.GELU(),
+            nn.Linear(config.hidden_dim, config.num_support_modes),
+        )
+
+    def _pool_source(self, source_tokens: Tensor | None, fallback: Tensor) -> Tensor:
+        if source_tokens is None or source_tokens.numel() == 0:
+            return fallback.new_zeros(fallback.shape)
+        return source_tokens.mean(dim=1)
+
+    def _field_mean(self, field: Tensor) -> Tensor:
+        return field.mean(dim=(-1, -2))
+
+    def _upsampled_belief(self, target_belief_field: Tensor) -> Tensor:
+        if target_belief_field.shape[-1] == self.config.belief_map_size:
+            return target_belief_field
+        return F.interpolate(
+            target_belief_field,
+            size=(self.config.belief_map_size, self.config.belief_map_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+    def forward(
+        self,
+        interaction_tokens: Tensor,
+        scene_tokens: Tensor | None = None,
+        memory_tokens: Tensor | None = None,
+    ) -> dict[str, Tensor]:
+        batch_size = interaction_tokens.shape[0]
+        pooled_interaction = interaction_tokens.mean(dim=1)
+        pooled_scene = self._pool_source(scene_tokens, pooled_interaction)
+        pooled_memory = self._pool_source(memory_tokens, pooled_interaction)
+
+        field_queries = self.field_queries.unsqueeze(0).expand(batch_size, -1, -1)
+        source_tokens = interaction_tokens
+        if scene_tokens is not None:
+            source_tokens = torch.cat([source_tokens, scene_tokens], dim=1)
+        if memory_tokens is not None:
+            source_tokens = torch.cat([source_tokens, memory_tokens], dim=1)
+        field_tokens, _ = self.field_attention(field_queries, source_tokens, source_tokens)
+        field_tokens = field_tokens + self.field_mlp(field_tokens)
+
+        side = self.config.field_size
+        grid = field_tokens.transpose(1, 2).reshape(batch_size, self.config.hidden_dim, side, side)
+        pooled_field = field_tokens.mean(dim=1)
+        summary_input = torch.cat([pooled_interaction, pooled_field, pooled_scene, pooled_memory], dim=-1)
+        latent_summary = self.summary_proj(summary_input)
+
+        access_field = self.access_field(grid)
+        target_belief_field = self.target_belief_field(grid)
+        visibility_field = self.visibility_field(grid)
+        clearance_field = self.clearance_field(grid)
+        occluder_contact_field = self.occluder_contact_field(grid)
+        grasp_affordance_field = self.grasp_affordance_field(grid)
+        support_stability_field = self.support_stability_field(grid)
+        persistence_field = torch.sigmoid(self.persistence_field(grid))
+        reocclusion_field = torch.sigmoid(self.reocclusion_field(grid))
+        disturbance_field = torch.sigmoid(self.disturbance_field(grid))
+        uncertainty_field = F.softplus(self.uncertainty_field(grid))
+
+        support_stability_prob = torch.sigmoid(support_stability_field)
+        risk_field = torch.sigmoid(
+            disturbance_field
+            + 0.75 * reocclusion_field
+            + 0.5 * (1.0 - support_stability_prob)
+            + 0.25 * uncertainty_field
+        )
+        corridor_source = access_field.amax(dim=-2)
+        corridor_logits = F.interpolate(
+            corridor_source,
+            size=self.config.num_approach_templates,
+            mode="linear",
+            align_corners=False,
+        )
+        access_prob = torch.sigmoid(access_field)
+        weighted_persistence = (persistence_field.expand_as(access_prob) * access_prob).sum(dim=(-1, -2))
+        access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
+        persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
+        disturbance_cost = disturbance_field.mean(dim=(-1, -2)).squeeze(1)
+
+        arm_identity = self.arm_identity.weight.unsqueeze(0).expand(batch_size, -1, -1)
+        arm_tokens = pooled_interaction.unsqueeze(1).expand(-1, 2, -1) + arm_identity
+        arm_role_input = torch.cat(
+            [arm_tokens, latent_summary.unsqueeze(1).expand(-1, 2, -1)],
+            dim=-1,
+        )
+        arm_role_logits = self.arm_role_head(arm_role_input)
+        target_belief_map = self._upsampled_belief(target_belief_field)
+        compact_components = [
+            target_belief_field.mean(dim=(-1, -2)).squeeze(1),
+            visibility_field.mean(dim=(-1, -2)).squeeze(1),
+            clearance_field.mean(dim=(-1, -2)).mean(dim=1),
+            occluder_contact_field.mean(dim=(-1, -2)).squeeze(1),
+            grasp_affordance_field.mean(dim=(-1, -2)).squeeze(1),
+            support_stability_prob.mean(dim=(-1, -2)).squeeze(1),
+            persistence_field.mean(dim=(-1, -2)).squeeze(1),
+            reocclusion_field.mean(dim=(-1, -2)).squeeze(1),
+            disturbance_field.mean(dim=(-1, -2)).squeeze(1),
+            risk_field.mean(dim=(-1, -2)).squeeze(1),
+            uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
+            access_prob.mean(dim=(-1, -2)),
+            self.support_mode(summary_input),
+            self.phase_head(summary_input),
+            arm_role_logits.reshape(batch_size, -1),
+        ]
+        compact_state = torch.cat(
+            [component if component.ndim > 1 else component.unsqueeze(-1) for component in compact_components],
+            dim=-1,
+        )
+
+        output = {
+            "phase_logits": self.phase_head(summary_input),
+            "arm_role_logits": arm_role_logits,
+            "target_belief_field": target_belief_field,
+            "visibility_field": visibility_field,
+            "clearance_field": clearance_field,
+            "occluder_contact_field": occluder_contact_field,
+            "grasp_affordance_field": grasp_affordance_field,
+            "support_stability_field": support_stability_field,
+            "persistence_field": persistence_field,
+            "reocclusion_field": reocclusion_field,
+            "disturbance_field": disturbance_field,
+            "risk_field": risk_field,
+            "uncertainty_field": uncertainty_field,
+            "interaction_tokens": interaction_tokens,
+            "field_tokens": field_tokens,
+            "latent_summary": latent_summary,
+            "support_mode_logits": self.support_mode(summary_input),
+            "corridor_logits": corridor_logits,
+            "persistence_horizon": persistence_horizon,
+            "disturbance_cost": disturbance_cost,
+            "belief_map": target_belief_map,
+            "reocclusion_logit": self.reocclusion_head(summary_input),
+            "persistence_uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
+            "access_field": access_field,
+            "uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
+            "compact_state": compact_state,
+        }
+        output["target_field"] = output["target_belief_field"]
+        output["actor_feasibility_field"] = output["clearance_field"]
+        return output
+
+
+class ElasticOcclusionStateHead(nn.Module):
+    def __init__(self, config: RevealHeadConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.interaction_queries = nn.Parameter(
+            torch.randn(config.num_interaction_tokens, config.hidden_dim) * 0.02
+        )
+        self.interaction_attention = nn.MultiheadAttention(
+            embed_dim=config.hidden_dim,
+            num_heads=config.num_heads,
+            batch_first=True,
+        )
+        self.interaction_mlp = nn.Sequential(
+            nn.LayerNorm(config.hidden_dim),
+            nn.Linear(config.hidden_dim, config.hidden_dim),
+            nn.GELU(),
+            nn.Linear(config.hidden_dim, config.hidden_dim),
+        )
+        self.decoder = ElasticOcclusionFieldDecoder(config)
+
+    def forward(
+        self,
+        scene_tokens: Tensor,
+        memory_token: Tensor | None = None,
+        memory_tokens: Tensor | None = None,
+    ) -> dict[str, Tensor]:
+        if memory_tokens is None:
+            memory_tokens = memory_token
+        source_tokens = scene_tokens
+        if memory_tokens is not None:
+            source_tokens = torch.cat([source_tokens, memory_tokens], dim=1)
+        batch_size = source_tokens.shape[0]
+        interaction_queries = self.interaction_queries.unsqueeze(0).expand(batch_size, -1, -1)
+        interaction_tokens, _ = self.interaction_attention(interaction_queries, source_tokens, source_tokens)
+        interaction_tokens = interaction_tokens + self.interaction_mlp(interaction_tokens)
+        return self.decoder(
+            interaction_tokens=interaction_tokens,
+            scene_tokens=scene_tokens,
+            memory_tokens=memory_tokens,
+        )
code/reveal_vla_bimanual/models/world_model.py CHANGED
@@ -22,6 +22,8 @@ class RevealWMConfig:
     num_interaction_tokens: int = 8
     belief_map_size: int = 32
     predict_belief_map: bool = True
+    scene_bank_size: int = 2
+    belief_bank_size: int = 2
 
 
 class RevealWM(nn.Module):
@@ -152,3 +154,186 @@ class InteractionRolloutModel(nn.Module):
         for key, values in outputs.items():
             stacked[key] = torch.stack(values, dim=1)
         return stacked
+
+
+class ElasticOcclusionWorldModel(nn.Module):
+    def __init__(self, config: RevealWMConfig) -> None:
+        super().__init__()
+        self.config = config
+        compact_state_dim = (
+            11
+            + config.num_support_modes
+            + config.num_support_modes
+            + config.num_phases
+            + (2 * config.num_arm_roles)
+        )
+        self.state_encoder = nn.Sequential(
+            nn.LayerNorm(compact_state_dim),
+            nn.Linear(compact_state_dim, config.hidden_dim),
+            nn.GELU(),
+        )
+        self.scene_memory_proj = nn.Sequential(
+            nn.LayerNorm(config.hidden_dim),
+            nn.Linear(config.hidden_dim, config.hidden_dim),
+            nn.GELU(),
+        )
+        self.belief_memory_proj = nn.Sequential(
+            nn.LayerNorm(config.hidden_dim),
+            nn.Linear(config.hidden_dim, config.hidden_dim),
+            nn.GELU(),
+        )
+        self.action_encoder = nn.Sequential(
+            nn.LayerNorm(config.action_dim),
+            nn.Linear(config.action_dim, config.hidden_dim),
+            nn.GELU(),
+        )
+        self.transition = nn.GRUCell(config.hidden_dim * 4, config.hidden_dim)
+        self.scene_memory_update = nn.Linear(config.hidden_dim, config.hidden_dim)
+        self.belief_memory_update = nn.Linear(config.hidden_dim, config.hidden_dim)
+        self.compact_decoder = nn.Linear(config.hidden_dim, compact_state_dim)
+        field_elements = config.field_size * config.field_size
+        self.target_belief_head = nn.Linear(config.hidden_dim, field_elements)
+        self.visibility_head = nn.Linear(config.hidden_dim, field_elements)
+        self.clearance_head = nn.Linear(config.hidden_dim, 2 * field_elements)
+        self.occluder_contact_head = nn.Linear(config.hidden_dim, field_elements)
+        self.grasp_affordance_head = nn.Linear(config.hidden_dim, field_elements)
+        self.support_stability_head = nn.Linear(config.hidden_dim, field_elements)
+        self.persistence_head = nn.Linear(config.hidden_dim, field_elements)
+        self.reocclusion_head = nn.Linear(config.hidden_dim, field_elements)
+        self.disturbance_head = nn.Linear(config.hidden_dim, field_elements)
+        self.uncertainty_head = nn.Linear(config.hidden_dim, field_elements)
+        self.access_head = nn.Linear(config.hidden_dim, config.num_support_modes * field_elements)
+
+    def _compact_from_state(self, interaction_state: dict[str, Tensor]) -> Tensor:
+        if "compact_state" in interaction_state:
+            return interaction_state["compact_state"]
+        components = [
+            interaction_state["target_belief_field"].mean(dim=(-1, -2)).squeeze(1),
+            interaction_state["visibility_field"].mean(dim=(-1, -2)).squeeze(1),
+            interaction_state["clearance_field"].mean(dim=(-1, -2)).mean(dim=1),
+            interaction_state["occluder_contact_field"].mean(dim=(-1, -2)).squeeze(1),
+            interaction_state["grasp_affordance_field"].mean(dim=(-1, -2)).squeeze(1),
+            torch.sigmoid(interaction_state["support_stability_field"]).mean(dim=(-1, -2)).squeeze(1),
+            interaction_state["persistence_field"].mean(dim=(-1, -2)).squeeze(1),
+            interaction_state["reocclusion_field"].mean(dim=(-1, -2)).squeeze(1),
+            interaction_state["disturbance_field"].mean(dim=(-1, -2)).squeeze(1),
+            interaction_state["risk_field"].mean(dim=(-1, -2)).squeeze(1),
+            interaction_state["uncertainty_field"].mean(dim=(-1, -2)).squeeze(1),
+            torch.sigmoid(interaction_state["access_field"]).mean(dim=(-1, -2)),
+            interaction_state["support_mode_logits"],
+            interaction_state["phase_logits"],
+            interaction_state["arm_role_logits"].reshape(interaction_state["arm_role_logits"].shape[0], -1),
+        ]
+        return torch.cat([component if component.ndim > 1 else component.unsqueeze(-1) for component in components], dim=-1)
+
+    def _decode_fields(self, latent: Tensor) -> dict[str, Tensor]:
+        batch_size = latent.shape[0]
+        side = self.config.field_size
+        target_belief_field = self.target_belief_head(latent).view(batch_size, 1, side, side)
+        visibility_field = self.visibility_head(latent).view(batch_size, 1, side, side)
+        clearance_field = self.clearance_head(latent).view(batch_size, 2, side, side)
+        occluder_contact_field = self.occluder_contact_head(latent).view(batch_size, 1, side, side)
+        grasp_affordance_field = self.grasp_affordance_head(latent).view(batch_size, 1, side, side)
+        support_stability_field = self.support_stability_head(latent).view(batch_size, 1, side, side)
+        persistence_field = torch.sigmoid(self.persistence_head(latent).view(batch_size, 1, side, side))
+        reocclusion_field = torch.sigmoid(self.reocclusion_head(latent).view(batch_size, 1, side, side))
+        disturbance_field = torch.sigmoid(self.disturbance_head(latent).view(batch_size, 1, side, side))
+        uncertainty_field = torch.nn.functional.softplus(self.uncertainty_head(latent).view(batch_size, 1, side, side))
+        risk_field = torch.sigmoid(
+            disturbance_field
+            + 0.75 * reocclusion_field
+            + 0.5 * (1.0 - torch.sigmoid(support_stability_field))
+            + 0.25 * uncertainty_field
+        )
+        access_field = self.access_head(latent).view(batch_size, self.config.num_support_modes, side, side)
+        corridor_source = access_field.amax(dim=-2)
+        corridor_logits = torch.nn.functional.interpolate(
+            corridor_source,
+            size=self.config.num_approach_templates,
+            mode="linear",
+            align_corners=False,
+        )
+        access_prob = torch.sigmoid(access_field)
+        weighted_persistence = (persistence_field.expand_as(access_prob) * access_prob).sum(dim=(-1, -2))
+        access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
+        persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
+        return {
+            "target_belief_field": target_belief_field,
+            "visibility_field": visibility_field,
+            "clearance_field": clearance_field,
+            "occluder_contact_field": occluder_contact_field,
+            "grasp_affordance_field": grasp_affordance_field,
+            "support_stability_field": support_stability_field,
+            "persistence_field": persistence_field,
+            "reocclusion_field": reocclusion_field,
+            "disturbance_field": disturbance_field,
+            "risk_field": risk_field,
+            "uncertainty_field": uncertainty_field,
+            "access_field": access_field,
+            "corridor_logits": corridor_logits,
+            "persistence_horizon": persistence_horizon,
+            "disturbance_cost": disturbance_field.mean(dim=(-1, -2)).squeeze(1),
+            "belief_map": torch.nn.functional.interpolate(
+                target_belief_field,
+                size=(self.config.belief_map_size, self.config.belief_map_size),
+                mode="bilinear",
+                align_corners=False,
+            ),
+            "target_field": target_belief_field,
+            "actor_feasibility_field": clearance_field,
+        }
+
+    def forward(
+        self,
+        scene_tokens: Tensor,
+        interaction_state: dict[str, Tensor],
+        action_chunk: Tensor,
+        memory_tokens: Tensor | None = None,
+        scene_memory_tokens: Tensor | None = None,
+        belief_memory_tokens: Tensor | None = None,
+    ) -> dict[str, Tensor]:
+        if scene_memory_tokens is None:
+            scene_memory_tokens = interaction_state.get("scene_memory_tokens")
+        if belief_memory_tokens is None:
+            belief_memory_tokens = interaction_state.get("belief_memory_tokens")
+        if scene_memory_tokens is None and memory_tokens is not None:
+            scene_memory_tokens = memory_tokens
+        if belief_memory_tokens is None and memory_tokens is not None:
+            belief_memory_tokens = memory_tokens
+        if scene_memory_tokens is None:
+            scene_memory_tokens = scene_tokens[:, :1]
+        if belief_memory_tokens is None:
+            belief_memory_tokens = scene_tokens[:, :1]
+
+        latent = self.state_encoder(self._compact_from_state(interaction_state))
+        scene_memory = self.scene_memory_proj(scene_memory_tokens.mean(dim=1))
+        belief_memory = self.belief_memory_proj(belief_memory_tokens.mean(dim=1))
+        outputs: dict[str, list[Tensor]] = {}
+        scene_bias = scene_tokens.mean(dim=1)
+
+        for step in range(action_chunk.shape[1]):
+            action_latent = self.action_encoder(action_chunk[:, step])
+            transition_input = torch.cat([latent, action_latent, scene_memory, belief_memory], dim=-1)
+            latent = self.transition(transition_input, latent + 0.1 * scene_bias)
+            scene_memory = 0.75 * scene_memory + 0.25 * torch.tanh(self.scene_memory_update(latent))
+            belief_memory = 0.65 * belief_memory + 0.35 * torch.tanh(self.belief_memory_update(latent))
+            compact_state = self.compact_decoder(latent)
+            decoded = self._decode_fields(latent)
+            decoded["compact_state"] = compact_state
+            decoded["phase_logits"] = compact_state[:, -(self.config.num_phases + (2 * self.config.num_arm_roles)) : -(2 * self.config.num_arm_roles)]
+            role_slice = compact_state[:, -(2 * self.config.num_arm_roles) :]
+            decoded["arm_role_logits"] = role_slice.view(role_slice.shape[0], 2, self.config.num_arm_roles)
+            decoded["support_mode_logits"] = compact_state[
+                :,
+                -(self.config.num_phases + (2 * self.config.num_arm_roles) + self.config.num_support_modes) : -(self.config.num_phases + (2 * self.config.num_arm_roles)),
+            ]
+            decoded["scene_memory_tokens"] = scene_memory.unsqueeze(1).expand(-1, self.config.scene_bank_size, -1)
+            decoded["belief_memory_tokens"] = belief_memory.unsqueeze(1).expand(-1, self.config.belief_bank_size, -1)
+            decoded["memory_tokens"] = torch.cat([decoded["scene_memory_tokens"], decoded["belief_memory_tokens"]], dim=1)
+            decoded["memory_token"] = decoded["memory_tokens"].mean(dim=1, keepdim=True)
+            decoded["uncertainty"] = decoded["uncertainty_field"].mean(dim=(-1, -2)).squeeze(1)
+            decoded["reocclusion_logit"] = decoded["reocclusion_field"].mean(dim=(-1, -2)).expand(-1, self.config.num_support_modes)
+            for key, value in decoded.items():
+                outputs.setdefault(key, []).append(value)
+
+        return {key: torch.stack(values, dim=1) for key, values in outputs.items()}
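The world model consumes the head's `compact_state` and rolls it forward with a GRU cell, one action per step, re-decoding the field set at every step. A hedged usage sketch, assuming both configs import with defaults and share `hidden_dim`, `field_size`, and the phase / arm-role / support-mode counts:

```python
import torch

# Assumption: these import paths match the repo layout and the configs agree.
from reveal_vla_bimanual.models.reveal_head import ElasticOcclusionStateHead, RevealHeadConfig
from reveal_vla_bimanual.models.world_model import ElasticOcclusionWorldModel, RevealWMConfig

head = ElasticOcclusionStateHead(RevealHeadConfig())
wm_config = RevealWMConfig()
world_model = ElasticOcclusionWorldModel(wm_config)

scene_tokens = torch.randn(2, 16, wm_config.hidden_dim)
state = head(scene_tokens=scene_tokens)            # supplies "compact_state"
actions = torch.randn(2, 5, wm_config.action_dim)  # (batch, horizon, action_dim)

rollout = world_model(scene_tokens=scene_tokens, interaction_state=state, action_chunk=actions)
# Every decoded quantity is stacked over the rollout horizon: (batch, T, ...).
print(rollout["risk_field"].shape)  # (2, 5, 1, field_size, field_size)
```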
code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc differ
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc differ
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc differ
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc differ
 
code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc differ
 
code/reveal_vla_bimanual/sim_reveal/dataset.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import pickle
 from pathlib import Path
 from typing import Any, Sequence
 
@@ -9,9 +10,10 @@ from torch.utils.data import Dataset
 
 import numpy as np
 
-from sim_reveal.procedural_envs import available_proxy_names, make_proxy_env, render_views_from_state
+from sim_reveal.procedural_envs import available_proxy_names, default_camera_matrices, make_proxy_env, render_views_from_state
 
 NOLEAK_PROXY_DATASET_VERSION = "reveal_proxy_v5_noleak_actionhist"
+RGBD_PROXY_DATASET_VERSION = "reveal_proxy_v6_rgbd_elastic_state"
 LEGACY_PRIVILEGED_RENDER_KEYS = frozenset(
     {
         "target_template",
@@ -44,6 +46,7 @@ def collect_teacher_dataset(
     rollout_horizon: int = 5,
     history_steps: int = 2,
     planner_candidates: int = 4,
+    dataset_version: str = NOLEAK_PROXY_DATASET_VERSION,
 ) -> dict[str, Any]:
     proxy_names = tuple(proxy_names or available_proxy_names())
     samples: list[dict[str, Any]] = []
@@ -91,7 +94,7 @@
                 padded_history_actions.append(item["action"])
            samples.append(
                 {
-                    "dataset_version": NOLEAK_PROXY_DATASET_VERSION,
+                    "dataset_version": dataset_version,
                     "proxy_name": proxy_name,
                     "episode_id": episode_idx,
                     "render_state": env.render_state(privileged_state),
@@ -103,10 +106,25 @@
                     "persistence_horizon": privileged_state["persistence_horizon"].astype("float32"),
                     "disturbance_cost": float(privileged_state["disturbance_cost"]),
                     "belief_map": privileged_state["belief_map"].astype("float32"),
+                    "visibility_map": privileged_state["visibility_map"].astype("float32"),
+                    "clearance_map": privileged_state["clearance_map"].astype("float32"),
+                    "occluder_contact_map": privileged_state["occluder_contact_map"].astype("float32"),
+                    "grasp_affordance_map": privileged_state["grasp_affordance_map"].astype("float32"),
+                    "support_stability": float(privileged_state["support_stability"]),
+                    "support_stability_map": privileged_state["support_stability_map"].astype("float32"),
+                    "reocclusion_target": float(privileged_state["reocclusion_target"]),
+                    "reocclusion_map": privileged_state["reocclusion_map"].astype("float32"),
                     "rollout_support_mode": rollout["rollout_support_mode"].astype("int64"),
                     "rollout_corridor_feasible": rollout["rollout_corridor_feasible"].astype("float32"),
                     "rollout_persistence_horizon": rollout["rollout_persistence_horizon"].astype("float32"),
                     "rollout_disturbance_cost": rollout["rollout_disturbance_cost"].astype("float32"),
+                    "rollout_belief_map": rollout["rollout_belief_map"].astype("float32"),
+                    "rollout_visibility_map": rollout["rollout_visibility_map"].astype("float32"),
+                    "rollout_clearance_map": rollout["rollout_clearance_map"].astype("float32"),
+                    "rollout_support_stability": rollout["rollout_support_stability"].astype("float32"),
+                    "rollout_reocclusion_target": rollout["rollout_reocclusion_target"].astype("float32"),
+                    "rollout_occluder_contact_map": rollout["rollout_occluder_contact_map"].astype("float32"),
+                    "rollout_grasp_affordance_map": rollout["rollout_grasp_affordance_map"].astype("float32"),
                     "history_render_states": padded_history_render_states,
                     "history_proprio": np.stack(padded_history_proprio, axis=0).astype("float32")
                     if padded_history_proprio
@@ -138,7 +156,7 @@
         "teacher_success": proxy_success / float(max(1, episodes_per_proxy)),
     }
     return {
-        "dataset_version": NOLEAK_PROXY_DATASET_VERSION,
+        "dataset_version": dataset_version,
         "resolution": resolution,
         "chunk_horizon": chunk_horizon,
         "rollout_horizon": rollout_horizon,
@@ -164,25 +182,46 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
     def __init__(self, samples: Sequence[dict[str, Any]], resolution: int = 96) -> None:
         self.samples = list(samples)
         self.resolution = resolution
+        self._render_cache: dict[bytes, dict[str, np.ndarray]] = {}
+        self._item_cache: dict[int, dict[str, Any]] = {}
 
     def __len__(self) -> int:
         return len(self.samples)
 
-    def __getitem__(self, index: int) -> dict[str, Any]:
-        sample = self.samples[index]
-        _assert_noleak_sample(sample)
-        images = render_views_from_state(
+    def _render_cache_key(self, sample: dict[str, Any], render_state: dict[str, Any]) -> bytes:
+        include_depth = sample.get("dataset_version") == RGBD_PROXY_DATASET_VERSION
+        return pickle.dumps(
+            (sample["proxy_name"], self.resolution, include_depth, render_state),
+            protocol=4,
+        )
+
+    def _render_sample(self, sample: dict[str, Any], render_state: dict[str, Any]) -> dict[str, np.ndarray]:
+        cache_key = self._render_cache_key(sample, render_state)
+        cached = self._render_cache.get(cache_key)
+        if cached is not None:
+            return cached
+        include_depth = sample.get("dataset_version") == RGBD_PROXY_DATASET_VERSION
+        rendered = render_views_from_state(
             proxy_name=sample["proxy_name"],
-            render_state=sample["render_state"],
+            render_state=render_state,
            resolution=self.resolution,
+            include_depth=include_depth,
         )
+        self._render_cache[cache_key] = rendered
+        return rendered
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+        cached_item = self._item_cache.get(index)
+        if cached_item is not None:
+            return cached_item
+        sample = self.samples[index]
+        _assert_noleak_sample(sample)
+        images = self._render_sample(sample, sample["render_state"])
         history_images = []
+        history_depths = []
+        history_depth_valid = []
         for history_state in sample.get("history_render_states", []):
-            rendered = render_views_from_state(
-                proxy_name=sample["proxy_name"],
-                render_state=history_state,
-                resolution=self.resolution,
-            )
+            rendered = self._render_sample(sample, history_state)
             history_images.append(
                 torch.stack(
                     [
@@ -193,6 +232,27 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
                     dim=0,
                 )
             )
+            if sample.get("dataset_version") == RGBD_PROXY_DATASET_VERSION:
+                history_depths.append(
+                    torch.stack(
+                        [
+                            torch.from_numpy(rendered["front_depth"]),
+                            torch.from_numpy(rendered["wrist_left_depth"]),
+                            torch.from_numpy(rendered["wrist_right_depth"]),
+                        ],
+                        dim=0,
+                    )
+                )
+                history_depth_valid.append(
+                    torch.stack(
+                        [
+                            torch.from_numpy(rendered["front_depth_valid"]),
+                            torch.from_numpy(rendered["wrist_left_depth_valid"]),
+                            torch.from_numpy(rendered["wrist_right_depth_valid"]),
+                        ],
+                        dim=0,
+                    )
+                )
         stacked = torch.from_numpy(
             torch.stack(
                 [
@@ -207,9 +267,42 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
             history_stacked = torch.stack(history_images, dim=0).permute(0, 1, 4, 2, 3).float() / 255.0
         else:
             history_stacked = torch.zeros((0, 3, 3, self.resolution, self.resolution), dtype=torch.float32)
-        return {
+        if sample.get("dataset_version") == RGBD_PROXY_DATASET_VERSION:
+            depths = torch.stack(
+                [
+                    torch.from_numpy(images["front_depth"]),
+                    torch.from_numpy(images["wrist_left_depth"]),
+                    torch.from_numpy(images["wrist_right_depth"]),
+                ],
+                dim=0,
+            ).unsqueeze(1).float()
+            depth_valid = torch.stack(
+                [
+                    torch.from_numpy(images["front_depth_valid"]),
+                    torch.from_numpy(images["wrist_left_depth_valid"]),
+                    torch.from_numpy(images["wrist_right_depth_valid"]),
+                ],
+                dim=0,
+            ).unsqueeze(1).float()
+            if history_depths:
+                history_depths_tensor = torch.stack(history_depths, dim=0).unsqueeze(2).float()
+                history_depth_valid_tensor = torch.stack(history_depth_valid, dim=0).unsqueeze(2).float()
+            else:
+                history_depths_tensor = torch.zeros((0, 3, 1, self.resolution, self.resolution), dtype=torch.float32)
+                history_depth_valid_tensor = torch.zeros((0, 3, 1, self.resolution, self.resolution), dtype=torch.float32)
+        else:
+            depths = torch.zeros((3, 1, self.resolution, self.resolution), dtype=torch.float32)
+            depth_valid = torch.zeros_like(depths)
+            history_depths_tensor = torch.zeros((0, 3, 1, self.resolution, self.resolution), dtype=torch.float32)
+            history_depth_valid_tensor = torch.zeros_like(history_depths_tensor)
+        camera_intrinsics, camera_extrinsics = default_camera_matrices()
+        item = {
             "images": stacked,
+            "depths": depths,
+            "depth_valid": depth_valid,
             "history_images": history_stacked,
+            "history_depths": history_depths_tensor,
+            "history_depth_valid": history_depth_valid_tensor,
             "history_proprio": torch.as_tensor(sample.get("history_proprio", []), dtype=torch.float32),
             "history_actions": torch.as_tensor(
                 sample.get(
@@ -218,6 +311,8 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
                ),
                 dtype=torch.float32,
             ),
+            "camera_intrinsics": torch.as_tensor(camera_intrinsics, dtype=torch.float32),
+            "camera_extrinsics": torch.as_tensor(camera_extrinsics, dtype=torch.float32),
             "proprio": torch.as_tensor(sample["proprio"], dtype=torch.float32),
             "texts": sample["language_goal"],
             "action_chunk": torch.as_tensor(sample["action_chunk"], dtype=torch.float32),
@@ -226,15 +321,37 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
             "persistence_horizon": torch.as_tensor(sample["persistence_horizon"], dtype=torch.float32),
             "disturbance_cost": torch.as_tensor(sample["disturbance_cost"], dtype=torch.float32),
             "belief_map": torch.as_tensor(sample["belief_map"], dtype=torch.float32).unsqueeze(0),
+            "visibility_map": torch.as_tensor(sample.get("visibility_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
+            "clearance_map": torch.as_tensor(sample.get("clearance_map", np.zeros((2, 32, 32), dtype=np.float32)), dtype=torch.float32),
+            "occluder_contact_map": torch.as_tensor(sample.get("occluder_contact_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
+            "grasp_affordance_map": torch.as_tensor(sample.get("grasp_affordance_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
+            "support_stability": torch.as_tensor(sample.get("support_stability", 0.0), dtype=torch.float32),
+            "support_stability_map": torch.as_tensor(sample.get("support_stability_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
+            "reocclusion_target": torch.as_tensor(sample.get("reocclusion_target", 0.0), dtype=torch.float32),
+            "reocclusion_map": torch.as_tensor(sample.get("reocclusion_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
             "rollout_support_mode": torch.as_tensor(sample["rollout_support_mode"], dtype=torch.long),
             "rollout_corridor_feasible": torch.as_tensor(sample["rollout_corridor_feasible"], dtype=torch.float32),
             "rollout_persistence_horizon": torch.as_tensor(sample["rollout_persistence_horizon"], dtype=torch.float32),
             "rollout_disturbance_cost": torch.as_tensor(sample["rollout_disturbance_cost"], dtype=torch.float32),
+            "rollout_belief_map": torch.as_tensor(sample.get("rollout_belief_map", np.zeros((0, 32, 32), dtype=np.float32)), dtype=torch.float32),
+            "rollout_visibility_map": torch.as_tensor(sample.get("rollout_visibility_map", np.zeros((0, 32, 32), dtype=np.float32)), dtype=torch.float32),
+            "rollout_clearance_map": torch.as_tensor(sample.get("rollout_clearance_map", np.zeros((0, 2, 32, 32), dtype=np.float32)), dtype=torch.float32),
+            "rollout_support_stability": torch.as_tensor(sample.get("rollout_support_stability", np.zeros((0,), dtype=np.float32)), dtype=torch.float32),
+            "rollout_reocclusion_target": torch.as_tensor(sample.get("rollout_reocclusion_target", np.zeros((0,), dtype=np.float32)), dtype=torch.float32),
+            "rollout_occluder_contact_map": torch.as_tensor(sample.get("rollout_occluder_contact_map", np.zeros((0, 32, 32), dtype=np.float32)), dtype=torch.float32),
+            "rollout_grasp_affordance_map": torch.as_tensor(sample.get("rollout_grasp_affordance_map", np.zeros((0, 32, 32), dtype=np.float32)), dtype=torch.float32),
             "candidate_action_chunks": torch.as_tensor(sample["candidate_action_chunks"], dtype=torch.float32),
             "candidate_rollout_support_mode": torch.as_tensor(sample["candidate_rollout_support_mode"], dtype=torch.long),
             "candidate_rollout_corridor_feasible": torch.as_tensor(sample["candidate_rollout_corridor_feasible"], dtype=torch.float32),
             "candidate_rollout_persistence_horizon": torch.as_tensor(sample["candidate_rollout_persistence_horizon"], dtype=torch.float32),
             "candidate_rollout_disturbance_cost": torch.as_tensor(sample["candidate_rollout_disturbance_cost"], dtype=torch.float32),
+            "candidate_rollout_belief_map": torch.as_tensor(sample.get("candidate_rollout_belief_map", np.zeros((0, 0, 32, 32), dtype=np.float32)), dtype=torch.float32),
+            "candidate_rollout_visibility_map": torch.as_tensor(sample.get("candidate_rollout_visibility_map", np.zeros((0, 0, 32, 32), dtype=np.float32)), dtype=torch.float32),
+            "candidate_rollout_clearance_map": torch.as_tensor(sample.get("candidate_rollout_clearance_map", np.zeros((0, 0, 2, 32, 32), dtype=np.float32)), dtype=torch.float32),
+            "candidate_rollout_support_stability": torch.as_tensor(sample.get("candidate_rollout_support_stability", np.zeros((0, 0), dtype=np.float32)), dtype=torch.float32),
+            "candidate_rollout_reocclusion_target": torch.as_tensor(sample.get("candidate_rollout_reocclusion_target", np.zeros((0, 0), dtype=np.float32)), dtype=torch.float32),
+            "candidate_rollout_occluder_contact_map": torch.as_tensor(sample.get("candidate_rollout_occluder_contact_map", np.zeros((0, 0, 32, 32), dtype=np.float32)), dtype=torch.float32),
+            "candidate_rollout_grasp_affordance_map": torch.as_tensor(sample.get("candidate_rollout_grasp_affordance_map", np.zeros((0, 0, 32, 32), dtype=np.float32)), dtype=torch.float32),
             "candidate_retrieval_success": torch.as_tensor(sample["candidate_retrieval_success"], dtype=torch.float32),
             "candidate_final_disturbance_cost": torch.as_tensor(sample["candidate_final_disturbance_cost"], dtype=torch.float32),
             "candidate_reocclusion_rate": torch.as_tensor(sample["candidate_reocclusion_rate"], dtype=torch.float32),
@@ -244,6 +361,8 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
             "proxy_name": sample["proxy_name"],
             "episode_id": sample["episode_id"],
         }
+        self._item_cache[index] = item
+        return item
 
 
 def dataset_from_bundle(dataset_bundle: dict[str, Any], resolution: int | None = None) -> RevealOfflineDataset:
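The v6 RGB-D path is opt-in via `dataset_version`; older v5 bundles keep working and simply return zero-filled depth tensors. A sketch of collecting a small bundle and batching it, assuming `collect_teacher_dataset`'s remaining arguments default sensibly:

```python
from torch.utils.data import DataLoader

# Assumption: sim_reveal is importable and a tiny episodes_per_proxy is accepted.
from sim_reveal.dataset import (
    RGBD_PROXY_DATASET_VERSION,
    collect_teacher_dataset,
    dataset_from_bundle,
)

bundle = collect_teacher_dataset(
    episodes_per_proxy=1,
    dataset_version=RGBD_PROXY_DATASET_VERSION,  # opt in to depth + elastic-state targets
)
dataset = dataset_from_bundle(bundle)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

batch = next(iter(loader))
# The render and item caches make repeated epochs cheap after the first pass.
print(batch["images"].shape, batch["depths"].shape, batch["camera_intrinsics"].shape)
```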
code/reveal_vla_bimanual/sim_reveal/procedural_envs.py CHANGED
@@ -83,6 +83,26 @@ PROXY_GOALS = {
83
  }
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def available_proxy_names() -> tuple[str, ...]:
87
  return tuple(PROXY_CONFIGS.keys())
88
 
@@ -285,6 +305,57 @@ class ProceduralRevealEnv:
285
  belief *= visibility
286
  return belief.astype(np.float32)
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  def get_privileged_state(self) -> dict[str, Any]:
289
  support_mode = int(self._current_support_mode())
290
  corridor = np.stack(
@@ -294,12 +365,29 @@ class ProceduralRevealEnv:
294
  persistence = np.asarray([self._persistence_for_mode(mode) for mode in SupportMode], dtype=np.float32)
295
  visibility = self._visibility()
296
  disturbance_cost = float(np.clip(self.disturbance + 0.08 * max(0.0, self.opening - self.dynamics.desired_opening), 0.0, 1.0))
 
 
 
 
 
 
 
 
 
297
  return {
298
  "support_mode": support_mode,
299
  "corridor_feasible": corridor,
300
  "persistence_horizon": persistence,
301
  "disturbance_cost": disturbance_cost,
302
- "belief_map": self._belief_map(visibility),
 
 
 
 
 
 
 
 
303
  "visibility": visibility,
304
  "retrieval_success": bool(self.retrieved),
305
  "target_template": self.target_template,
@@ -335,12 +423,18 @@ class ProceduralRevealEnv:
335
  render_state=render_state,
336
  resolution=self.resolution,
337
  num_templates=self.num_templates,
 
338
  )
 
339
  return {
340
  "images": np.stack([images[camera] for camera in self.camera_names], axis=0),
 
 
341
  "proprio": self._proprio(privileged_state),
342
  "text": PROXY_GOALS[self.proxy_name],
343
  "camera_names": self.camera_names,
 
 
344
  }
345
 
346
  def teacher_action(self) -> np.ndarray:
@@ -385,6 +479,13 @@ class ProceduralRevealEnv:
385
  rollout_corridor = []
386
  rollout_persistence = []
387
  rollout_disturbance = []
 
 
 
 
 
 
 
388
  for step in range(chunk_horizon):
389
  action = self.teacher_action()
390
  action_chunk.append(action)
@@ -394,21 +495,43 @@ class ProceduralRevealEnv:
394
  rollout_corridor.append(privileged_state["corridor_feasible"])
395
  rollout_persistence.append(privileged_state["persistence_horizon"])
396
  rollout_disturbance.append(privileged_state["disturbance_cost"])
 
 
 
 
 
 
 
397
  if terminated or truncated:
398
  break
399
  while len(action_chunk) < chunk_horizon:
400
  action_chunk.append(np.zeros((14,), dtype=np.float32))
401
  while len(rollout_support_mode) < rollout_horizon:
 
402
  rollout_support_mode.append(int(self._current_support_mode()))
403
- rollout_corridor.append(self.get_privileged_state()["corridor_feasible"])
404
- rollout_persistence.append(self.get_privileged_state()["persistence_horizon"])
405
- rollout_disturbance.append(self.get_privileged_state()["disturbance_cost"])
 
 
 
 
 
 
 
406
  self.restore_state(snapshot)
407
  return np.stack(action_chunk, axis=0).astype(np.float32), {
408
  "rollout_support_mode": np.asarray(rollout_support_mode, dtype=np.int64),
409
  "rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32),
410
  "rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32),
411
  "rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32),
 
 
 
 
 
 
 
412
  }
413
 
414
  def evaluate_action_chunk(
@@ -422,6 +545,13 @@ class ProceduralRevealEnv:
422
  rollout_corridor: list[np.ndarray] = []
423
  rollout_persistence: list[np.ndarray] = []
424
  rollout_disturbance: list[float] = []
 
 
 
 
 
 
 
425
  corridor_open_trace = [float(self.get_privileged_state()["corridor_feasible"][self._current_support_mode()].any())]
426
  visibility_trace = [float(self.get_privileged_state()["visibility"])]
427
  terminated = False
@@ -434,6 +564,13 @@ class ProceduralRevealEnv:
434
  rollout_corridor.append(privileged_state["corridor_feasible"].astype(np.float32))
435
  rollout_persistence.append(privileged_state["persistence_horizon"].astype(np.float32))
436
  rollout_disturbance.append(float(privileged_state["disturbance_cost"]))
 
 
 
 
 
 
 
437
  corridor_open_trace.append(float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any()))
438
  visibility_trace.append(float(privileged_state["visibility"]))
439
  if terminated or truncated:
@@ -444,6 +581,13 @@ class ProceduralRevealEnv:
444
  rollout_corridor.append(current["corridor_feasible"].astype(np.float32))
445
  rollout_persistence.append(current["persistence_horizon"].astype(np.float32))
446
  rollout_disturbance.append(float(current["disturbance_cost"]))
 
 
 
 
 
 
 
447
  final_state = self.get_privileged_state()
448
  reocclusion = float(
449
  np.logical_and(
@@ -456,6 +600,13 @@ class ProceduralRevealEnv:
456
  "rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32),
457
  "rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32),
458
  "rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32),
 
 
 
 
 
 
 
459
  "retrieval_success": float(final_state["retrieval_success"]),
460
  "final_disturbance_cost": float(final_state["disturbance_cost"]),
461
  "reocclusion_rate": reocclusion,
@@ -493,6 +644,27 @@ class ProceduralRevealEnv:
493
  "candidate_rollout_disturbance_cost": np.stack(
494
  [item["rollout_disturbance_cost"] for item in outcomes], axis=0
495
  ).astype(np.float32),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  "candidate_retrieval_success": np.asarray([item["retrieval_success"] for item in outcomes], dtype=np.float32),
497
  "candidate_final_disturbance_cost": np.asarray(
498
  [item["final_disturbance_cost"] for item in outcomes], dtype=np.float32
@@ -587,6 +759,7 @@ def render_views_from_state(
587
  render_state: dict[str, Any],
588
  resolution: int,
589
  num_templates: int = 32,
 
590
  ) -> dict[str, np.ndarray]:
591
  dynamics = PROXY_DYNAMICS[proxy_name]
592
  opening = float(render_state["opening"])
@@ -668,8 +841,40 @@ def render_views_from_state(
668
  wrist_right[..., 2] = np.clip(wrist_right[..., 2] + 0.08 * step_fraction + 0.06 * right_band, 0.0, 1.0)
669
  wrist_right = np.clip(wrist_right, 0.0, 1.0)
670
 
671
- return {
672
  "front": (front * 255.0).astype(np.uint8),
673
  "wrist_left": (wrist_left * 255.0).astype(np.uint8),
674
  "wrist_right": (wrist_right * 255.0).astype(np.uint8),
675
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  }
84
 
85
 
86
+ def default_camera_matrices() -> tuple[np.ndarray, np.ndarray]:
87
+ intrinsics = np.asarray(
88
+ [
89
+ [[140.0, 0.0, 48.0], [0.0, 140.0, 48.0], [0.0, 0.0, 1.0]],
90
+ [[135.0, 0.0, 48.0], [0.0, 135.0, 48.0], [0.0, 0.0, 1.0]],
91
+ [[135.0, 0.0, 48.0], [0.0, 135.0, 48.0], [0.0, 0.0, 1.0]],
92
+ ],
93
+ dtype=np.float32,
94
+ )
95
+ extrinsics = np.asarray(
96
+ [
97
+ np.eye(4, dtype=np.float32),
98
+ [[1.0, 0.0, 0.0, -0.18], [0.0, 1.0, 0.0, 0.04], [0.0, 0.0, 1.0, 0.10], [0.0, 0.0, 0.0, 1.0]],
99
+ [[1.0, 0.0, 0.0, 0.18], [0.0, 1.0, 0.0, 0.04], [0.0, 0.0, 1.0, 0.10], [0.0, 0.0, 0.0, 1.0]],
100
+ ],
101
+ dtype=np.float32,
102
+ )
103
+ return intrinsics, extrinsics
104
+
105
+
106
  def available_proxy_names() -> tuple[str, ...]:
107
  return tuple(PROXY_CONFIGS.keys())
108
 
 
305
  belief *= visibility
306
  return belief.astype(np.float32)
307
 
308
+ def _visibility_map(self, visibility: float) -> np.ndarray:
309
+ belief = self._belief_map(visibility)
310
+ gradient = np.linspace(0.65, 1.0, belief.shape[0], dtype=np.float32).reshape(-1, 1)
311
+ return np.clip(belief * gradient, 0.0, 1.0).astype(np.float32)
312
+
313
+ def _clearance_map(self, visibility: float) -> np.ndarray:
314
+ side = 32
315
+ x = np.linspace(0.0, 1.0, side, dtype=np.float32)
316
+ y = np.linspace(0.0, 1.0, side, dtype=np.float32)
317
+ yy, xx = np.meshgrid(y, x, indexing="ij")
318
+ corridor_width = np.clip(0.05 + 0.18 * self.opening - 0.10 * self.disturbance, 0.01, 0.28)
319
+ corridor = np.exp(-(((xx - self.target_center) ** 2) / max(1e-5, corridor_width**2)))
320
+ vertical = np.exp(-(((yy - (0.72 - 0.25 * self.target_depth)) ** 2) / 0.03))
321
+ left = np.clip(corridor * vertical * visibility * (0.92 - 0.15 * self.disturbance), 0.0, 1.0)
322
+ right = np.clip(corridor * vertical * visibility * (0.88 - 0.10 * self.disturbance), 0.0, 1.0)
323
+ return np.stack([left, right], axis=0).astype(np.float32)
324
+
325
+ def _occluder_contact_map(self) -> np.ndarray:
326
+ side = 32
327
+ x = np.linspace(0.0, 1.0, side, dtype=np.float32)
328
+ y = np.linspace(0.0, 1.0, side, dtype=np.float32)
329
+ yy, xx = np.meshgrid(y, x, indexing="ij")
330
+ gap_width = np.clip(0.03 + 0.16 * self.opening, 0.03, 0.24)
331
+ left_band = np.exp(-(((xx - (self.target_center - gap_width)) ** 2) / 0.0025))
332
+ right_band = np.exp(-(((xx - (self.target_center + gap_width)) ** 2) / 0.0025))
333
+ support = np.exp(-(((yy - 0.55) ** 2) / 0.04))
334
+ return np.clip((left_band + right_band) * support, 0.0, 1.0).astype(np.float32)
335
+
336
+ def _support_stability(self) -> float:
337
+ base = 1.0 - 0.45 * self.disturbance - 0.10 * max(0.0, self.opening - self.dynamics.desired_opening)
338
+ if self._current_support_mode() == self.dynamics.preferred_mode:
339
+ base += 0.08
340
+ return float(np.clip(base, 0.0, 1.0))
341
+
342
+ def _support_stability_map(self) -> np.ndarray:
343
+ return np.full((32, 32), self._support_stability(), dtype=np.float32)
344
+
345
+ def _reocclusion_target(self, persistence: np.ndarray) -> float:
346
+ current_mode = int(self._current_support_mode())
347
+ horizon_ratio = persistence[current_mode] / float(max(1, self.rollout_horizon))
348
+ return float(np.clip(1.0 - horizon_ratio + 0.35 * self.disturbance, 0.0, 1.0))
349
+
350
+ def _grasp_affordance_map(
351
+ self,
352
+ belief_map: np.ndarray,
353
+ visibility_map: np.ndarray,
354
+ clearance_map: np.ndarray,
355
+ ) -> np.ndarray:
356
+ combined = belief_map * visibility_map * clearance_map.mean(axis=0)
357
+ return np.clip(combined * (1.0 - 0.35 * self.disturbance), 0.0, 1.0).astype(np.float32)
358
+
359
  def get_privileged_state(self) -> dict[str, Any]:
360
  support_mode = int(self._current_support_mode())
361
  corridor = np.stack(
 
365
  persistence = np.asarray([self._persistence_for_mode(mode) for mode in SupportMode], dtype=np.float32)
366
  visibility = self._visibility()
367
  disturbance_cost = float(np.clip(self.disturbance + 0.08 * max(0.0, self.opening - self.dynamics.desired_opening), 0.0, 1.0))
368
+ belief_map = self._belief_map(visibility)
369
+ visibility_map = self._visibility_map(visibility)
370
+ clearance_map = self._clearance_map(visibility)
371
+ occluder_contact_map = self._occluder_contact_map()
372
+ support_stability = self._support_stability()
373
+ support_stability_map = self._support_stability_map()
374
+ reocclusion_target = self._reocclusion_target(persistence)
375
+ reocclusion_map = np.full((32, 32), reocclusion_target, dtype=np.float32)
376
+ grasp_affordance_map = self._grasp_affordance_map(belief_map, visibility_map, clearance_map)
377
  return {
378
  "support_mode": support_mode,
379
  "corridor_feasible": corridor,
380
  "persistence_horizon": persistence,
381
  "disturbance_cost": disturbance_cost,
382
+ "belief_map": belief_map,
383
+ "visibility_map": visibility_map,
384
+ "clearance_map": clearance_map,
385
+ "occluder_contact_map": occluder_contact_map,
386
+ "grasp_affordance_map": grasp_affordance_map,
387
+ "support_stability": support_stability,
388
+ "support_stability_map": support_stability_map,
389
+ "reocclusion_target": reocclusion_target,
390
+ "reocclusion_map": reocclusion_map,
391
  "visibility": visibility,
392
  "retrieval_success": bool(self.retrieved),
393
  "target_template": self.target_template,
 
423
  render_state=render_state,
424
  resolution=self.resolution,
425
  num_templates=self.num_templates,
426
+ include_depth=True,
427
  )
428
+ camera_intrinsics, camera_extrinsics = default_camera_matrices()
429
  return {
430
  "images": np.stack([images[camera] for camera in self.camera_names], axis=0),
431
+ "depths": np.stack([images[f"{camera}_depth"] for camera in self.camera_names], axis=0)[:, None, :, :],
432
+ "depth_valid": np.stack([images[f"{camera}_depth_valid"] for camera in self.camera_names], axis=0)[:, None, :, :],
433
  "proprio": self._proprio(privileged_state),
434
  "text": PROXY_GOALS[self.proxy_name],
435
  "camera_names": self.camera_names,
436
+ "camera_intrinsics": camera_intrinsics,
437
+ "camera_extrinsics": camera_extrinsics,
438
  }
439
 
440
  def teacher_action(self) -> np.ndarray:
 
479
  rollout_corridor = []
480
  rollout_persistence = []
481
  rollout_disturbance = []
482
+ rollout_belief = []
483
+ rollout_visibility = []
484
+ rollout_clearance = []
485
+ rollout_support_stability = []
486
+ rollout_reocclusion = []
487
+ rollout_occluder_contact = []
488
+ rollout_grasp_affordance = []
489
  for step in range(chunk_horizon):
490
  action = self.teacher_action()
491
  action_chunk.append(action)
 
495
  rollout_corridor.append(privileged_state["corridor_feasible"])
496
  rollout_persistence.append(privileged_state["persistence_horizon"])
497
  rollout_disturbance.append(privileged_state["disturbance_cost"])
498
+ rollout_belief.append(privileged_state["belief_map"])
499
+ rollout_visibility.append(privileged_state["visibility_map"])
500
+ rollout_clearance.append(privileged_state["clearance_map"])
501
+ rollout_support_stability.append(privileged_state["support_stability"])
502
+ rollout_reocclusion.append(privileged_state["reocclusion_target"])
503
+ rollout_occluder_contact.append(privileged_state["occluder_contact_map"])
504
+ rollout_grasp_affordance.append(privileged_state["grasp_affordance_map"])
505
  if terminated or truncated:
506
  break
507
  while len(action_chunk) < chunk_horizon:
508
  action_chunk.append(np.zeros((14,), dtype=np.float32))
509
  while len(rollout_support_mode) < rollout_horizon:
510
+ current = self.get_privileged_state()
  rollout_support_mode.append(int(self._current_support_mode()))
+ rollout_corridor.append(current["corridor_feasible"])
+ rollout_persistence.append(current["persistence_horizon"])
+ rollout_disturbance.append(current["disturbance_cost"])
+ rollout_belief.append(current["belief_map"])
+ rollout_visibility.append(current["visibility_map"])
+ rollout_clearance.append(current["clearance_map"])
+ rollout_support_stability.append(current["support_stability"])
+ rollout_reocclusion.append(current["reocclusion_target"])
+ rollout_occluder_contact.append(current["occluder_contact_map"])
+ rollout_grasp_affordance.append(current["grasp_affordance_map"])
  self.restore_state(snapshot)
  return np.stack(action_chunk, axis=0).astype(np.float32), {
  "rollout_support_mode": np.asarray(rollout_support_mode, dtype=np.int64),
  "rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32),
  "rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32),
  "rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32),
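+ # Map-valued traces (assumed HxW grids) stack to float32 (T, H, W) arrays;
+ # scalar traces stack to (T,).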
+ "rollout_belief_map": np.asarray(rollout_belief, dtype=np.float32),
529
+ "rollout_visibility_map": np.asarray(rollout_visibility, dtype=np.float32),
530
+ "rollout_clearance_map": np.asarray(rollout_clearance, dtype=np.float32),
531
+ "rollout_support_stability": np.asarray(rollout_support_stability, dtype=np.float32),
532
+ "rollout_reocclusion_target": np.asarray(rollout_reocclusion, dtype=np.float32),
533
+ "rollout_occluder_contact_map": np.asarray(rollout_occluder_contact, dtype=np.float32),
534
+ "rollout_grasp_affordance_map": np.asarray(rollout_grasp_affordance, dtype=np.float32),
535
  }
536
 
537
  def evaluate_action_chunk(
 
  rollout_corridor: list[np.ndarray] = []
  rollout_persistence: list[np.ndarray] = []
  rollout_disturbance: list[float] = []
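+ # Mirrors the accumulators in get_action_chunk, but typed explicitly because
+ # this path casts each entry (astype / float) eagerly at append time.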
+ rollout_belief: list[np.ndarray] = []
+ rollout_visibility: list[np.ndarray] = []
+ rollout_clearance: list[np.ndarray] = []
+ rollout_support_stability: list[float] = []
+ rollout_reocclusion: list[float] = []
+ rollout_occluder_contact: list[np.ndarray] = []
+ rollout_grasp_affordance: list[np.ndarray] = []
  corridor_open_trace = [float(self.get_privileged_state()["corridor_feasible"][self._current_support_mode()].any())]
  visibility_trace = [float(self.get_privileged_state()["visibility"])]
  terminated = False
 
  rollout_corridor.append(privileged_state["corridor_feasible"].astype(np.float32))
  rollout_persistence.append(privileged_state["persistence_horizon"].astype(np.float32))
  rollout_disturbance.append(float(privileged_state["disturbance_cost"]))
+ rollout_belief.append(privileged_state["belief_map"].astype(np.float32))
+ rollout_visibility.append(privileged_state["visibility_map"].astype(np.float32))
+ rollout_clearance.append(privileged_state["clearance_map"].astype(np.float32))
+ rollout_support_stability.append(float(privileged_state["support_stability"]))
+ rollout_reocclusion.append(float(privileged_state["reocclusion_target"]))
+ rollout_occluder_contact.append(privileged_state["occluder_contact_map"].astype(np.float32))
+ rollout_grasp_affordance.append(privileged_state["grasp_affordance_map"].astype(np.float32))
  corridor_open_trace.append(float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any()))
  visibility_trace.append(float(privileged_state["visibility"]))
  if terminated or truncated:
 
  rollout_corridor.append(current["corridor_feasible"].astype(np.float32))
  rollout_persistence.append(current["persistence_horizon"].astype(np.float32))
  rollout_disturbance.append(float(current["disturbance_cost"]))
+ rollout_belief.append(current["belief_map"].astype(np.float32))
+ rollout_visibility.append(current["visibility_map"].astype(np.float32))
+ rollout_clearance.append(current["clearance_map"].astype(np.float32))
+ rollout_support_stability.append(float(current["support_stability"]))
+ rollout_reocclusion.append(float(current["reocclusion_target"]))
+ rollout_occluder_contact.append(current["occluder_contact_map"].astype(np.float32))
+ rollout_grasp_affordance.append(current["grasp_affordance_map"].astype(np.float32))
  final_state = self.get_privileged_state()
  reocclusion = float(
  np.logical_and(
 
  "rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32),
  "rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32),
  "rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32),
+ "rollout_belief_map": np.asarray(rollout_belief, dtype=np.float32),
+ "rollout_visibility_map": np.asarray(rollout_visibility, dtype=np.float32),
+ "rollout_clearance_map": np.asarray(rollout_clearance, dtype=np.float32),
+ "rollout_support_stability": np.asarray(rollout_support_stability, dtype=np.float32),
+ "rollout_reocclusion_target": np.asarray(rollout_reocclusion, dtype=np.float32),
+ "rollout_occluder_contact_map": np.asarray(rollout_occluder_contact, dtype=np.float32),
+ "rollout_grasp_affordance_map": np.asarray(rollout_grasp_affordance, dtype=np.float32),
  "retrieval_success": float(final_state["retrieval_success"]),
  "final_disturbance_cost": float(final_state["disturbance_cost"]),
  "reocclusion_rate": reocclusion,
 
  "candidate_rollout_disturbance_cost": np.stack(
  [item["rollout_disturbance_cost"] for item in outcomes], axis=0
  ).astype(np.float32),
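+ # Stacking over candidate outcomes adds a leading axis: map traces become
+ # (num_candidates, T, H, W) and scalar traces (num_candidates, T).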
+ "candidate_rollout_belief_map": np.stack(
648
+ [item["rollout_belief_map"] for item in outcomes], axis=0
649
+ ).astype(np.float32),
650
+ "candidate_rollout_visibility_map": np.stack(
651
+ [item["rollout_visibility_map"] for item in outcomes], axis=0
652
+ ).astype(np.float32),
653
+ "candidate_rollout_clearance_map": np.stack(
654
+ [item["rollout_clearance_map"] for item in outcomes], axis=0
655
+ ).astype(np.float32),
656
+ "candidate_rollout_support_stability": np.stack(
657
+ [item["rollout_support_stability"] for item in outcomes], axis=0
658
+ ).astype(np.float32),
659
+ "candidate_rollout_reocclusion_target": np.stack(
660
+ [item["rollout_reocclusion_target"] for item in outcomes], axis=0
661
+ ).astype(np.float32),
662
+ "candidate_rollout_occluder_contact_map": np.stack(
663
+ [item["rollout_occluder_contact_map"] for item in outcomes], axis=0
664
+ ).astype(np.float32),
665
+ "candidate_rollout_grasp_affordance_map": np.stack(
666
+ [item["rollout_grasp_affordance_map"] for item in outcomes], axis=0
667
+ ).astype(np.float32),
668
  "candidate_retrieval_success": np.asarray([item["retrieval_success"] for item in outcomes], dtype=np.float32),
669
  "candidate_final_disturbance_cost": np.asarray(
670
  [item["final_disturbance_cost"] for item in outcomes], dtype=np.float32
 
  render_state: dict[str, Any],
  resolution: int,
  num_templates: int = 32,
+ include_depth: bool = False,
  ) -> dict[str, np.ndarray]:
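+ # include_depth defaults to False so RGB-only callers keep the original
+ # three-key return dict; the observation builder above opts in explicitly
+ # with include_depth=True.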
  dynamics = PROXY_DYNAMICS[proxy_name]
  opening = float(render_state["opening"])
 
  wrist_right[..., 2] = np.clip(wrist_right[..., 2] + 0.08 * step_fraction + 0.06 * right_band, 0.0, 1.0)
  wrist_right = np.clip(wrist_right, 0.0, 1.0)

+ outputs = {
  "front": (front * 255.0).astype(np.uint8),
  "wrist_left": (wrist_left * 255.0).astype(np.uint8),
  "wrist_right": (wrist_right * 255.0).astype(np.uint8),
  }
+ if not include_depth:
+ return outputs
+
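+ # Synthetic (non-metric) depth: each view starts from a constant background
+ # depth derived from the scene scalars, then np.minimum pulls the occluder
+ # gap and target regions nearer, as a real depth camera would observe them.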
+ front_depth = np.clip(0.25 + 0.40 * target_depth + 0.15 * disturbance + 0.10 * (1.0 - visibility), 0.0, 1.0)
+ target_depth_map = np.clip(0.10 + 0.55 * target_depth, 0.0, 1.0)
+ occluder_depth = np.clip(0.30 + 0.20 * disturbance + 0.10 * (1.0 - opening), 0.0, 1.0)
+ front_depth_map = np.full((height, width), front_depth, dtype=np.float32)
+ front_depth_map[gap_mask] = np.minimum(front_depth_map[gap_mask], occluder_depth)
+ front_depth_map[target_mask] = np.minimum(front_depth_map[target_mask], target_depth_map)
+
+ wrist_left_depth = np.clip(0.35 + 0.25 * target_depth + 0.10 * disturbance, 0.0, 1.0)
+ wrist_left_depth_map = np.full((height, width), wrist_left_depth, dtype=np.float32)
+ wrist_left_depth_map[left_open] = np.minimum(wrist_left_depth_map[left_open], 0.22 + 0.25 * target_depth)
+ wrist_left_depth_map[target_mask] = np.minimum(wrist_left_depth_map[target_mask], target_depth_map)
+
+ wrist_right_depth = np.clip(0.35 + 0.20 * target_depth + 0.12 * disturbance, 0.0, 1.0)
+ wrist_right_depth_map = np.full((height, width), wrist_right_depth, dtype=np.float32)
+ right_focus = (right_band * right_clear) > 0.15
+ wrist_right_depth_map[right_focus] = np.minimum(wrist_right_depth_map[right_focus], 0.20 + 0.25 * target_depth)
+ wrist_right_depth_map[target_mask] = np.minimum(wrist_right_depth_map[target_mask], target_depth_map)
+
+ outputs.update(
+ {
+ "front_depth": front_depth_map.astype(np.float32),
+ "wrist_left_depth": wrist_left_depth_map.astype(np.float32),
+ "wrist_right_depth": wrist_right_depth_map.astype(np.float32),
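+ # Synthetic depth has no sensor dropout, so the validity masks are all ones;
+ # a real-sensor pipeline would presumably zero out invalid returns here.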
+ "front_depth_valid": np.ones((height, width), dtype=np.float32),
876
+ "wrist_left_depth_valid": np.ones((height, width), dtype=np.float32),
877
+ "wrist_right_depth_valid": np.ones((height, width), dtype=np.float32),
878
+ }
879
+ )
880
+ return outputs
code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc differ
 
code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc CHANGED
Binary files a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc differ