Add files using upload-large-folder tool
Browse files- code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/eval/__pycache__/run_peract2_launch_smoke.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/eval/__pycache__/run_peract2_task_sweep.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/eval/__pycache__/run_rlbench_rollout_eval.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/eval/metrics.py +85 -0
- code/reveal_vla_bimanual/eval/run_peract2_launch_smoke.py +131 -0
- code/reveal_vla_bimanual/eval/run_proxy_diagnostics.py +148 -26
- code/reveal_vla_bimanual/eval/run_reveal_benchmark.py +48 -0
- code/reveal_vla_bimanual/eval/run_rlbench_rollout_eval.py +19 -1
- code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/models/action_decoder.py +304 -0
- code/reveal_vla_bimanual/models/backbones.py +249 -24
- code/reveal_vla_bimanual/models/multiview_fusion.py +74 -3
- code/reveal_vla_bimanual/models/observation_memory.py +192 -0
- code/reveal_vla_bimanual/models/planner.py +191 -0
- code/reveal_vla_bimanual/models/policy.py +319 -5
- code/reveal_vla_bimanual/models/reveal_head.py +242 -0
- code/reveal_vla_bimanual/models/world_model.py +185 -0
- code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc +0 -0
- code/reveal_vla_bimanual/sim_reveal/dataset.py +133 -14
- code/reveal_vla_bimanual/sim_reveal/procedural_envs.py +210 -5
- code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc +0 -0
- code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc +0 -0
code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc and b/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc and b/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/eval/__pycache__/run_peract2_launch_smoke.cpython-310.pyc
ADDED
|
Binary file (4.31 kB). View file
|
|
|
code/reveal_vla_bimanual/eval/__pycache__/run_peract2_task_sweep.cpython-310.pyc
ADDED
|
Binary file (6.11 kB). View file
|
|
|
code/reveal_vla_bimanual/eval/__pycache__/run_rlbench_rollout_eval.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/eval/__pycache__/run_rlbench_rollout_eval.cpython-310.pyc and b/code/reveal_vla_bimanual/eval/__pycache__/run_rlbench_rollout_eval.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/eval/metrics.py
CHANGED
|
@@ -22,6 +22,13 @@ class PlannerDiagnostics:
|
|
| 22 |
regret: float
|
| 23 |
risk_calibration_mse: float
|
| 24 |
role_collapse_rate: float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def mean_success(per_task_success: dict[str, float]) -> float:
|
|
@@ -87,6 +94,84 @@ def risk_calibration_mse(predicted_risk: np.ndarray, realized_risk: np.ndarray)
|
|
| 87 |
return float(np.mean((predicted_risk - realized_risk) ** 2))
|
| 88 |
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
def role_collapse_rate(
|
| 91 |
action_chunks: np.ndarray,
|
| 92 |
arm_role_logits: np.ndarray | None = None,
|
|
|
|
| 22 |
regret: float
|
| 23 |
risk_calibration_mse: float
|
| 24 |
role_collapse_rate: float
|
| 25 |
+
proposal_diversity: float | None = None
|
| 26 |
+
planner_score_utility_spearman: float | None = None
|
| 27 |
+
left_right_equivariance_error: float | None = None
|
| 28 |
+
belief_calibration_brier: float | None = None
|
| 29 |
+
reocclusion_calibration_brier: float | None = None
|
| 30 |
+
support_stability_mae: float | None = None
|
| 31 |
+
clearance_auc: float | None = None
|
| 32 |
|
| 33 |
|
| 34 |
def mean_success(per_task_success: dict[str, float]) -> float:
|
|
|
|
| 94 |
return float(np.mean((predicted_risk - realized_risk) ** 2))
|
| 95 |
|
| 96 |
|
| 97 |
+
def proposal_diversity(proposal_chunks: np.ndarray) -> float:
|
| 98 |
+
proposal_chunks = np.asarray(proposal_chunks, dtype=np.float32)
|
| 99 |
+
if proposal_chunks.ndim != 4 or proposal_chunks.shape[1] <= 1:
|
| 100 |
+
return 0.0
|
| 101 |
+
flat = proposal_chunks.reshape(proposal_chunks.shape[0], proposal_chunks.shape[1], -1)
|
| 102 |
+
diffs = flat[:, :, None, :] - flat[:, None, :, :]
|
| 103 |
+
distances = np.abs(diffs).mean(axis=-1)
|
| 104 |
+
mask = ~np.eye(distances.shape[1], dtype=bool)
|
| 105 |
+
if not mask.any():
|
| 106 |
+
return 0.0
|
| 107 |
+
off_diagonal = distances[:, mask]
|
| 108 |
+
return float(off_diagonal.mean())
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def planner_score_utility_spearman(pred_scores: np.ndarray, oracle_utility: np.ndarray) -> float:
|
| 112 |
+
pred_scores = np.asarray(pred_scores, dtype=np.float32)
|
| 113 |
+
oracle_utility = np.asarray(oracle_utility, dtype=np.float32)
|
| 114 |
+
if pred_scores.size == 0:
|
| 115 |
+
return 0.0
|
| 116 |
+
pred_rank = pred_scores.argsort(axis=-1).argsort(axis=-1).astype(np.float32)
|
| 117 |
+
oracle_rank = oracle_utility.argsort(axis=-1).argsort(axis=-1).astype(np.float32)
|
| 118 |
+
pred_rank = pred_rank - pred_rank.mean(axis=-1, keepdims=True)
|
| 119 |
+
oracle_rank = oracle_rank - oracle_rank.mean(axis=-1, keepdims=True)
|
| 120 |
+
denom = np.sqrt((pred_rank**2).sum(axis=-1) * (oracle_rank**2).sum(axis=-1))
|
| 121 |
+
valid = denom > 1e-6
|
| 122 |
+
if not np.any(valid):
|
| 123 |
+
return 0.0
|
| 124 |
+
corr = np.zeros_like(denom)
|
| 125 |
+
corr[valid] = (pred_rank[valid] * oracle_rank[valid]).sum(axis=-1) / denom[valid]
|
| 126 |
+
return float(corr.mean())
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def left_right_equivariance_error(pred: np.ndarray, swapped_target: np.ndarray) -> float:
|
| 130 |
+
pred = np.asarray(pred, dtype=np.float32)
|
| 131 |
+
swapped_target = np.asarray(swapped_target, dtype=np.float32)
|
| 132 |
+
if pred.size == 0 or swapped_target.size == 0:
|
| 133 |
+
return 0.0
|
| 134 |
+
return float(np.abs(pred - swapped_target).mean())
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def belief_calibration_brier(predicted_belief: np.ndarray, target_belief: np.ndarray) -> float:
|
| 138 |
+
predicted_belief = np.asarray(predicted_belief, dtype=np.float32)
|
| 139 |
+
target_belief = np.asarray(target_belief, dtype=np.float32)
|
| 140 |
+
if predicted_belief.size == 0:
|
| 141 |
+
return 0.0
|
| 142 |
+
return float(np.mean((predicted_belief - target_belief) ** 2))
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def reocclusion_calibration_brier(predicted_reocclusion: np.ndarray, target_reocclusion: np.ndarray) -> float:
|
| 146 |
+
predicted_reocclusion = np.asarray(predicted_reocclusion, dtype=np.float32)
|
| 147 |
+
target_reocclusion = np.asarray(target_reocclusion, dtype=np.float32)
|
| 148 |
+
if predicted_reocclusion.size == 0:
|
| 149 |
+
return 0.0
|
| 150 |
+
return float(np.mean((predicted_reocclusion - target_reocclusion) ** 2))
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def support_stability_mae(predicted: np.ndarray, target: np.ndarray) -> float:
|
| 154 |
+
predicted = np.asarray(predicted, dtype=np.float32)
|
| 155 |
+
target = np.asarray(target, dtype=np.float32)
|
| 156 |
+
if predicted.size == 0:
|
| 157 |
+
return 0.0
|
| 158 |
+
return float(np.abs(predicted - target).mean())
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def clearance_auc(predicted: np.ndarray, target: np.ndarray) -> float:
|
| 162 |
+
predicted = np.asarray(predicted, dtype=np.float32).reshape(-1)
|
| 163 |
+
target = np.asarray(target, dtype=np.float32).reshape(-1)
|
| 164 |
+
positives = target > 0.5
|
| 165 |
+
negatives = ~positives
|
| 166 |
+
if positives.sum() == 0 or negatives.sum() == 0:
|
| 167 |
+
return 0.0
|
| 168 |
+
order = np.argsort(predicted)
|
| 169 |
+
ranks = np.empty_like(order, dtype=np.float32)
|
| 170 |
+
ranks[order] = np.arange(order.shape[0], dtype=np.float32)
|
| 171 |
+
pos_ranks = ranks[positives]
|
| 172 |
+
return float((pos_ranks.sum() - positives.sum() * (positives.sum() - 1) / 2.0) / (positives.sum() * negatives.sum()))
|
| 173 |
+
|
| 174 |
+
|
| 175 |
def role_collapse_rate(
|
| 176 |
action_chunks: np.ndarray,
|
| 177 |
arm_role_logits: np.ndarray | None = None,
|
code/reveal_vla_bimanual/eval/run_peract2_launch_smoke.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from sim_rlbench.task_splits import PERACT2_BIMANUAL_TASKS
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _parse_json_payload(stdout: str) -> dict[str, Any]:
|
| 14 |
+
start = stdout.find("{")
|
| 15 |
+
end = stdout.rfind("}")
|
| 16 |
+
if start == -1 or end == -1 or end < start:
|
| 17 |
+
raise ValueError("No JSON object found in subprocess stdout.")
|
| 18 |
+
return json.loads(stdout[start : end + 1])
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _run_task(project_root: Path, output_dir: Path, task_name: str, *, resolution: int, headless: bool) -> dict[str, Any]:
|
| 22 |
+
task_dir = output_dir / task_name
|
| 23 |
+
task_dir.mkdir(parents=True, exist_ok=True)
|
| 24 |
+
command = [
|
| 25 |
+
sys.executable,
|
| 26 |
+
"-m",
|
| 27 |
+
"sim_rlbench.launch_smoke",
|
| 28 |
+
"--task",
|
| 29 |
+
task_name,
|
| 30 |
+
"--resolution",
|
| 31 |
+
str(resolution),
|
| 32 |
+
]
|
| 33 |
+
if headless:
|
| 34 |
+
command.append("--headless")
|
| 35 |
+
completed = subprocess.run(
|
| 36 |
+
command,
|
| 37 |
+
cwd=project_root,
|
| 38 |
+
text=True,
|
| 39 |
+
capture_output=True,
|
| 40 |
+
check=False,
|
| 41 |
+
)
|
| 42 |
+
(task_dir / "command.txt").write_text(" ".join(command) + "\n", encoding="utf-8")
|
| 43 |
+
(task_dir / "stdout.txt").write_text(completed.stdout, encoding="utf-8")
|
| 44 |
+
(task_dir / "stderr.txt").write_text(completed.stderr, encoding="utf-8")
|
| 45 |
+
|
| 46 |
+
payload: dict[str, Any] = {
|
| 47 |
+
"subprocess_returncode": int(completed.returncode),
|
| 48 |
+
"launch_ok": completed.returncode == 0,
|
| 49 |
+
}
|
| 50 |
+
try:
|
| 51 |
+
payload.update(_parse_json_payload(completed.stdout))
|
| 52 |
+
except Exception as exc:
|
| 53 |
+
payload["launch_ok"] = False
|
| 54 |
+
payload["error"] = f"json_parse_failed: {exc}"
|
| 55 |
+
if completed.returncode != 0 and "error" not in payload:
|
| 56 |
+
payload["error"] = f"subprocess_exit_{completed.returncode}"
|
| 57 |
+
return payload
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _write_markdown(path: Path, payload: dict[str, Any]) -> None:
|
| 61 |
+
lines = [
|
| 62 |
+
"# PerAct2 13-Task Launch Smoke",
|
| 63 |
+
"",
|
| 64 |
+
f"- Resolution: `{payload['resolution']}`",
|
| 65 |
+
f"- Headless: `{payload['headless']}`",
|
| 66 |
+
f"- Task count: `{payload['task_count']}`",
|
| 67 |
+
f"- Launch successes: `{payload['launch_successes']}`",
|
| 68 |
+
f"- Finite-action tasks: `{payload['finite_action_tasks']}`",
|
| 69 |
+
f"- Error tasks: `{payload['error_tasks']}`",
|
| 70 |
+
"",
|
| 71 |
+
"## Per-task",
|
| 72 |
+
"",
|
| 73 |
+
]
|
| 74 |
+
for task_name, task_payload in payload["tasks"].items():
|
| 75 |
+
if "error" in task_payload:
|
| 76 |
+
lines.append(
|
| 77 |
+
f"- `{task_name}`: launch_ok={task_payload.get('launch_ok')}, "
|
| 78 |
+
f"action_finite={task_payload.get('action_finite')}, "
|
| 79 |
+
f"error={task_payload['error']}, "
|
| 80 |
+
f"subprocess_returncode={task_payload['subprocess_returncode']}"
|
| 81 |
+
)
|
| 82 |
+
else:
|
| 83 |
+
lines.append(
|
| 84 |
+
f"- `{task_name}`: launch_ok={task_payload.get('launch_ok')}, "
|
| 85 |
+
f"action_finite={task_payload.get('action_finite')}, "
|
| 86 |
+
f"task_class={task_payload.get('task')}, "
|
| 87 |
+
f"reward={task_payload.get('reward')}"
|
| 88 |
+
)
|
| 89 |
+
path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def main() -> None:
|
| 93 |
+
parser = argparse.ArgumentParser()
|
| 94 |
+
parser.add_argument("--output-dir", required=True)
|
| 95 |
+
parser.add_argument("--tasks", nargs="*", default=list(PERACT2_BIMANUAL_TASKS))
|
| 96 |
+
parser.add_argument("--resolution", type=int, default=224)
|
| 97 |
+
parser.add_argument("--headless", action="store_true", default=True)
|
| 98 |
+
args = parser.parse_args()
|
| 99 |
+
|
| 100 |
+
project_root = Path(__file__).resolve().parents[1]
|
| 101 |
+
output_dir = Path(args.output_dir).resolve()
|
| 102 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 103 |
+
|
| 104 |
+
results: dict[str, Any] = {
|
| 105 |
+
"resolution": int(args.resolution),
|
| 106 |
+
"headless": bool(args.headless),
|
| 107 |
+
"tasks": {},
|
| 108 |
+
}
|
| 109 |
+
for task_name in tuple(args.tasks):
|
| 110 |
+
print(f"[peract2-launch-smoke] task={task_name}", flush=True)
|
| 111 |
+
results["tasks"][task_name] = _run_task(
|
| 112 |
+
project_root,
|
| 113 |
+
output_dir,
|
| 114 |
+
task_name,
|
| 115 |
+
resolution=args.resolution,
|
| 116 |
+
headless=args.headless,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
task_payloads = list(results["tasks"].values())
|
| 120 |
+
results["task_count"] = len(task_payloads)
|
| 121 |
+
results["launch_successes"] = int(sum(1 for payload in task_payloads if payload.get("launch_ok")))
|
| 122 |
+
results["finite_action_tasks"] = int(sum(1 for payload in task_payloads if payload.get("action_finite")))
|
| 123 |
+
results["error_tasks"] = sorted(task_name for task_name, payload in results["tasks"].items() if "error" in payload)
|
| 124 |
+
|
| 125 |
+
(output_dir / "launch_smoke_summary.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
|
| 126 |
+
_write_markdown(output_dir / "launch_smoke_summary.md", results)
|
| 127 |
+
print(json.dumps(results, indent=2))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
|
| 131 |
+
main()
|
code/reveal_vla_bimanual/eval/run_proxy_diagnostics.py
CHANGED
|
@@ -8,9 +8,22 @@ from typing import Any
|
|
| 8 |
import numpy as np
|
| 9 |
import torch
|
| 10 |
from torch import Tensor
|
|
|
|
| 11 |
from torch.utils.data import DataLoader
|
| 12 |
|
| 13 |
-
from eval.metrics import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from eval.run_reveal_benchmark import load_model
|
| 15 |
from sim_reveal.dataset import dataset_from_bundle, load_teacher_dataset
|
| 16 |
|
|
@@ -52,41 +65,121 @@ def main() -> None:
|
|
| 52 |
risk_batches: list[np.ndarray] = []
|
| 53 |
realized_risk_batches: list[np.ndarray] = []
|
| 54 |
collapse_batches: list[float] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
with torch.no_grad():
|
| 57 |
for batch in loader:
|
| 58 |
moved = _move_batch_to_device(batch, device)
|
| 59 |
-
|
| 60 |
-
images
|
| 61 |
-
proprio
|
| 62 |
-
texts
|
| 63 |
-
history_images
|
| 64 |
-
history_proprio
|
| 65 |
-
history_actions
|
| 66 |
-
plan
|
| 67 |
-
candidate_chunks_override
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
if "planner_scores" not in outputs:
|
| 70 |
raise RuntimeError("Planner outputs were not produced for proxy diagnostics.")
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
0.0,
|
| 79 |
-
1.0,
|
| 80 |
-
)
|
| 81 |
-
.detach()
|
| 82 |
-
.cpu()
|
| 83 |
-
.numpy()
|
| 84 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
selected_chunk = outputs["planned_chunk"].detach().cpu().numpy()[:, None]
|
|
|
|
| 86 |
role_logits = None
|
| 87 |
-
if
|
| 88 |
-
role_logits =
|
| 89 |
collapse_batches.append(role_collapse_rate(selected_chunk, role_logits))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
scores = np.concatenate(score_batches, axis=0) if score_batches else np.zeros((0, 0), dtype=np.float32)
|
| 92 |
utility = np.concatenate(utility_batches, axis=0) if utility_batches else np.zeros((0, 0), dtype=np.float32)
|
|
@@ -101,8 +194,37 @@ def main() -> None:
|
|
| 101 |
diagnostics = {
|
| 102 |
"planner_top1_accuracy": planner_top1_accuracy(scores, utility),
|
| 103 |
"planner_regret": planner_regret(selected_indices, utility),
|
|
|
|
| 104 |
"risk_calibration_mse": risk_calibration_mse(predicted_risk, realized_risk),
|
| 105 |
"role_collapse_rate": float(np.mean(collapse_batches)) if collapse_batches else 0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
"num_samples": int(scores.shape[0]),
|
| 107 |
}
|
| 108 |
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import torch
|
| 10 |
from torch import Tensor
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
from torch.utils.data import DataLoader
|
| 13 |
|
| 14 |
+
from eval.metrics import (
|
| 15 |
+
belief_calibration_brier,
|
| 16 |
+
clearance_auc,
|
| 17 |
+
left_right_equivariance_error,
|
| 18 |
+
planner_regret,
|
| 19 |
+
planner_score_utility_spearman,
|
| 20 |
+
planner_top1_accuracy,
|
| 21 |
+
proposal_diversity,
|
| 22 |
+
reocclusion_calibration_brier,
|
| 23 |
+
risk_calibration_mse,
|
| 24 |
+
role_collapse_rate,
|
| 25 |
+
support_stability_mae,
|
| 26 |
+
)
|
| 27 |
from eval.run_reveal_benchmark import load_model
|
| 28 |
from sim_reveal.dataset import dataset_from_bundle, load_teacher_dataset
|
| 29 |
|
|
|
|
| 65 |
risk_batches: list[np.ndarray] = []
|
| 66 |
realized_risk_batches: list[np.ndarray] = []
|
| 67 |
collapse_batches: list[float] = []
|
| 68 |
+
proposal_batches: list[np.ndarray] = []
|
| 69 |
+
equivariance_batches: list[float] = []
|
| 70 |
+
belief_pred_batches: list[np.ndarray] = []
|
| 71 |
+
belief_target_batches: list[np.ndarray] = []
|
| 72 |
+
reocclusion_pred_batches: list[np.ndarray] = []
|
| 73 |
+
reocclusion_target_batches: list[np.ndarray] = []
|
| 74 |
+
support_pred_batches: list[np.ndarray] = []
|
| 75 |
+
support_target_batches: list[np.ndarray] = []
|
| 76 |
+
clearance_pred_batches: list[np.ndarray] = []
|
| 77 |
+
clearance_target_batches: list[np.ndarray] = []
|
| 78 |
+
memory_write_batches: list[np.ndarray] = []
|
| 79 |
+
memory_saturation_batches: list[np.ndarray] = []
|
| 80 |
|
| 81 |
with torch.no_grad():
|
| 82 |
for batch in loader:
|
| 83 |
moved = _move_batch_to_device(batch, device)
|
| 84 |
+
forward_kwargs = {
|
| 85 |
+
"images": moved["images"],
|
| 86 |
+
"proprio": moved["proprio"],
|
| 87 |
+
"texts": moved["texts"],
|
| 88 |
+
"history_images": moved.get("history_images"),
|
| 89 |
+
"history_proprio": moved.get("history_proprio"),
|
| 90 |
+
"history_actions": moved.get("history_actions"),
|
| 91 |
+
"plan": True,
|
| 92 |
+
"candidate_chunks_override": moved["candidate_action_chunks"],
|
| 93 |
+
}
|
| 94 |
+
if hasattr(model, "elastic_state_head"):
|
| 95 |
+
forward_kwargs.update(
|
| 96 |
+
{
|
| 97 |
+
"depths": moved.get("depths"),
|
| 98 |
+
"depth_valid": moved.get("depth_valid"),
|
| 99 |
+
"camera_intrinsics": moved.get("camera_intrinsics"),
|
| 100 |
+
"camera_extrinsics": moved.get("camera_extrinsics"),
|
| 101 |
+
"history_depths": moved.get("history_depths"),
|
| 102 |
+
"history_depth_valid": moved.get("history_depth_valid"),
|
| 103 |
+
"use_depth": moved.get("depths") is not None,
|
| 104 |
+
"use_world_model": True,
|
| 105 |
+
"use_planner": True,
|
| 106 |
+
"use_role_tokens": True,
|
| 107 |
+
"compute_equivariance_probe": True,
|
| 108 |
+
}
|
| 109 |
+
)
|
| 110 |
+
outputs = model(**forward_kwargs)
|
| 111 |
if "planner_scores" not in outputs:
|
| 112 |
raise RuntimeError("Planner outputs were not produced for proxy diagnostics.")
|
| 113 |
+
planner_scores = outputs["planner_scores"]
|
| 114 |
+
candidate_utility = moved["candidate_utility"]
|
| 115 |
+
predicted_risk = outputs["planner_risk_values"]
|
| 116 |
+
realized_risk = torch.clamp(
|
| 117 |
+
moved["candidate_final_disturbance_cost"] + moved["candidate_reocclusion_rate"],
|
| 118 |
+
0.0,
|
| 119 |
+
1.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
)
|
| 121 |
+
shortlist_indices = outputs.get("planner_topk_indices")
|
| 122 |
+
if shortlist_indices is not None:
|
| 123 |
+
candidate_utility = candidate_utility.gather(1, shortlist_indices)
|
| 124 |
+
predicted_risk = predicted_risk
|
| 125 |
+
realized_risk = realized_risk.gather(1, shortlist_indices)
|
| 126 |
+
score_batches.append(planner_scores.detach().cpu().numpy())
|
| 127 |
+
utility_batches.append(candidate_utility.detach().cpu().numpy())
|
| 128 |
+
best_index_batches.append(outputs["best_candidate_indices"].detach().cpu().numpy())
|
| 129 |
+
risk_batches.append(predicted_risk.detach().cpu().numpy())
|
| 130 |
+
realized_risk_batches.append(realized_risk.detach().cpu().numpy())
|
| 131 |
selected_chunk = outputs["planned_chunk"].detach().cpu().numpy()[:, None]
|
| 132 |
+
state = outputs.get("interaction_state") or outputs.get("reveal_state")
|
| 133 |
role_logits = None
|
| 134 |
+
if state is not None:
|
| 135 |
+
role_logits = state["arm_role_logits"].detach().cpu().numpy()[:, None]
|
| 136 |
collapse_batches.append(role_collapse_rate(selected_chunk, role_logits))
|
| 137 |
+
if outputs.get("proposal_candidates") is not None:
|
| 138 |
+
proposal_batches.append(outputs["proposal_candidates"].detach().cpu().numpy())
|
| 139 |
+
if outputs.get("equivariance_probe_action_mean") is not None:
|
| 140 |
+
equivariance_batches.append(
|
| 141 |
+
left_right_equivariance_error(
|
| 142 |
+
outputs["equivariance_probe_action_mean"].detach().cpu().numpy(),
|
| 143 |
+
outputs["equivariance_target_action_mean"].detach().cpu().numpy(),
|
| 144 |
+
)
|
| 145 |
+
)
|
| 146 |
+
if state is not None:
|
| 147 |
+
if "belief_map" in state and "belief_map" in moved:
|
| 148 |
+
belief_pred_batches.append(torch.sigmoid(state["belief_map"]).detach().cpu().numpy())
|
| 149 |
+
belief_target_batches.append(moved["belief_map"].detach().cpu().numpy())
|
| 150 |
+
if "reocclusion_field" in state and "reocclusion_target" in moved:
|
| 151 |
+
reocclusion_pred_batches.append(torch.sigmoid(state["reocclusion_field"]).mean(dim=(-1, -2)).detach().cpu().numpy())
|
| 152 |
+
reocclusion_target_batches.append(moved["reocclusion_target"].detach().cpu().numpy())
|
| 153 |
+
if "support_stability_field" in state and "support_stability" in moved:
|
| 154 |
+
support_pred_batches.append(torch.sigmoid(state["support_stability_field"]).mean(dim=(-1, -2)).detach().cpu().numpy())
|
| 155 |
+
support_target_batches.append(moved["support_stability"].detach().cpu().numpy())
|
| 156 |
+
if "clearance_field" in state and "clearance_map" in moved:
|
| 157 |
+
clearance_pred = torch.sigmoid(state["clearance_field"])
|
| 158 |
+
clearance_target = moved["clearance_map"]
|
| 159 |
+
if clearance_pred.shape[-2:] != clearance_target.shape[-2:]:
|
| 160 |
+
clearance_pred = F.interpolate(
|
| 161 |
+
clearance_pred,
|
| 162 |
+
size=clearance_target.shape[-2:],
|
| 163 |
+
mode="bilinear",
|
| 164 |
+
align_corners=False,
|
| 165 |
+
)
|
| 166 |
+
if clearance_pred.shape[1] != clearance_target.shape[1]:
|
| 167 |
+
if clearance_pred.shape[1] == 1:
|
| 168 |
+
clearance_pred = clearance_pred.expand(-1, clearance_target.shape[1], -1, -1)
|
| 169 |
+
elif clearance_target.shape[1] == 1:
|
| 170 |
+
clearance_target = clearance_target.expand_as(clearance_pred)
|
| 171 |
+
else:
|
| 172 |
+
min_channels = min(clearance_pred.shape[1], clearance_target.shape[1])
|
| 173 |
+
clearance_pred = clearance_pred[:, :min_channels]
|
| 174 |
+
clearance_target = clearance_target[:, :min_channels]
|
| 175 |
+
clearance_pred_batches.append(clearance_pred.detach().cpu().numpy())
|
| 176 |
+
clearance_target_batches.append(clearance_target.detach().cpu().numpy())
|
| 177 |
+
if outputs.get("memory_output") is not None:
|
| 178 |
+
memory_output = outputs["memory_output"]
|
| 179 |
+
if "memory_write_rate" in memory_output:
|
| 180 |
+
memory_write_batches.append(memory_output["memory_write_rate"].detach().cpu().numpy())
|
| 181 |
+
if "memory_saturation" in memory_output:
|
| 182 |
+
memory_saturation_batches.append(memory_output["memory_saturation"].detach().cpu().numpy())
|
| 183 |
|
| 184 |
scores = np.concatenate(score_batches, axis=0) if score_batches else np.zeros((0, 0), dtype=np.float32)
|
| 185 |
utility = np.concatenate(utility_batches, axis=0) if utility_batches else np.zeros((0, 0), dtype=np.float32)
|
|
|
|
| 194 |
diagnostics = {
|
| 195 |
"planner_top1_accuracy": planner_top1_accuracy(scores, utility),
|
| 196 |
"planner_regret": planner_regret(selected_indices, utility),
|
| 197 |
+
"planner_score_utility_spearman": planner_score_utility_spearman(scores, utility),
|
| 198 |
"risk_calibration_mse": risk_calibration_mse(predicted_risk, realized_risk),
|
| 199 |
"role_collapse_rate": float(np.mean(collapse_batches)) if collapse_batches else 0.0,
|
| 200 |
+
"proposal_diversity": proposal_diversity(np.concatenate(proposal_batches, axis=0)) if proposal_batches else 0.0,
|
| 201 |
+
"left_right_equivariance_error": float(np.mean(equivariance_batches)) if equivariance_batches else 0.0,
|
| 202 |
+
"belief_calibration_brier": belief_calibration_brier(
|
| 203 |
+
np.concatenate(belief_pred_batches, axis=0),
|
| 204 |
+
np.concatenate(belief_target_batches, axis=0),
|
| 205 |
+
)
|
| 206 |
+
if belief_pred_batches
|
| 207 |
+
else 0.0,
|
| 208 |
+
"reocclusion_calibration_brier": reocclusion_calibration_brier(
|
| 209 |
+
np.concatenate(reocclusion_pred_batches, axis=0),
|
| 210 |
+
np.concatenate(reocclusion_target_batches, axis=0),
|
| 211 |
+
)
|
| 212 |
+
if reocclusion_pred_batches
|
| 213 |
+
else 0.0,
|
| 214 |
+
"support_stability_mae": support_stability_mae(
|
| 215 |
+
np.concatenate(support_pred_batches, axis=0),
|
| 216 |
+
np.concatenate(support_target_batches, axis=0),
|
| 217 |
+
)
|
| 218 |
+
if support_pred_batches
|
| 219 |
+
else 0.0,
|
| 220 |
+
"clearance_auc": clearance_auc(
|
| 221 |
+
np.concatenate(clearance_pred_batches, axis=0),
|
| 222 |
+
np.concatenate(clearance_target_batches, axis=0),
|
| 223 |
+
)
|
| 224 |
+
if clearance_pred_batches
|
| 225 |
+
else 0.0,
|
| 226 |
+
"memory_write_rate": float(np.mean(np.concatenate(memory_write_batches, axis=0))) if memory_write_batches else 0.0,
|
| 227 |
+
"memory_saturation": float(np.mean(np.concatenate(memory_saturation_batches, axis=0))) if memory_saturation_batches else 0.0,
|
| 228 |
"num_samples": int(scores.shape[0]),
|
| 229 |
}
|
| 230 |
|
code/reveal_vla_bimanual/eval/run_reveal_benchmark.py
CHANGED
|
@@ -73,12 +73,18 @@ def _prepare_batch(
|
|
| 73 |
observation: dict[str, Any],
|
| 74 |
device: torch.device,
|
| 75 |
history_images: list[np.ndarray] | None = None,
|
|
|
|
|
|
|
| 76 |
history_proprio: list[np.ndarray] | None = None,
|
| 77 |
history_actions: list[np.ndarray] | None = None,
|
| 78 |
) -> dict[str, Any]:
|
| 79 |
images = torch.from_numpy(observation["images"]).permute(0, 3, 1, 2).unsqueeze(0).float() / 255.0
|
|
|
|
|
|
|
| 80 |
proprio = torch.from_numpy(observation["proprio"]).unsqueeze(0).float()
|
| 81 |
history_images = history_images or []
|
|
|
|
|
|
|
| 82 |
history_proprio = history_proprio or []
|
| 83 |
history_actions = history_actions or []
|
| 84 |
if history_images:
|
|
@@ -90,6 +96,12 @@ def _prepare_batch(
|
|
| 90 |
(1, 0, images.shape[1], images.shape[2], images.shape[3], images.shape[4]),
|
| 91 |
dtype=torch.float32,
|
| 92 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
if history_proprio:
|
| 94 |
history_proprio_tensor = torch.from_numpy(np.stack(history_proprio, axis=0)).unsqueeze(0).float()
|
| 95 |
else:
|
|
@@ -100,7 +112,13 @@ def _prepare_batch(
|
|
| 100 |
history_actions_tensor = torch.zeros((1, 0, 14), dtype=torch.float32)
|
| 101 |
return {
|
| 102 |
"images": images.to(device),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
"history_images": history_images_tensor.to(device),
|
|
|
|
|
|
|
| 104 |
"history_proprio": history_proprio_tensor.to(device),
|
| 105 |
"history_actions": history_actions_tensor.to(device),
|
| 106 |
"proprio": proprio.to(device),
|
|
@@ -147,6 +165,27 @@ def select_chunk(
|
|
| 147 |
if "planned_chunk" in outputs and ablation not in {"no_world_model", "no_interaction_head"}:
|
| 148 |
return outputs["planned_chunk"], outputs
|
| 149 |
return outputs["action_mean"], outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
if hasattr(model, "reveal_head"):
|
| 151 |
if ablation == "no_world_model":
|
| 152 |
outputs = model(**forward_kwargs, plan=False)
|
|
@@ -195,6 +234,8 @@ def evaluate_model(
|
|
| 195 |
episode_corridor = [float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any())]
|
| 196 |
episode_disturbance = [float(privileged_state["disturbance_cost"])]
|
| 197 |
history_images: list[np.ndarray] = []
|
|
|
|
|
|
|
| 198 |
history_proprio: list[np.ndarray] = []
|
| 199 |
history_actions: list[np.ndarray] = []
|
| 200 |
done = False
|
|
@@ -203,6 +244,8 @@ def evaluate_model(
|
|
| 203 |
observation,
|
| 204 |
device=device,
|
| 205 |
history_images=history_images,
|
|
|
|
|
|
|
| 206 |
history_proprio=history_proprio,
|
| 207 |
history_actions=history_actions,
|
| 208 |
)
|
|
@@ -224,9 +267,14 @@ def evaluate_model(
|
|
| 224 |
if history_steps > 0:
|
| 225 |
if len(history_images) >= history_steps:
|
| 226 |
history_images = history_images[-history_steps + 1 :]
|
|
|
|
|
|
|
| 227 |
history_proprio = history_proprio[-history_steps + 1 :]
|
| 228 |
history_actions = history_actions[-history_steps + 1 :]
|
| 229 |
history_images.append(observation["images"])
|
|
|
|
|
|
|
|
|
|
| 230 |
history_proprio.append(observation["proprio"])
|
| 231 |
history_actions.append(action.astype(np.float32))
|
| 232 |
observation, _, terminated, truncated, privileged_state = env.step(action)
|
|
|
|
| 73 |
observation: dict[str, Any],
|
| 74 |
device: torch.device,
|
| 75 |
history_images: list[np.ndarray] | None = None,
|
| 76 |
+
history_depths: list[np.ndarray] | None = None,
|
| 77 |
+
history_depth_valid: list[np.ndarray] | None = None,
|
| 78 |
history_proprio: list[np.ndarray] | None = None,
|
| 79 |
history_actions: list[np.ndarray] | None = None,
|
| 80 |
) -> dict[str, Any]:
|
| 81 |
images = torch.from_numpy(observation["images"]).permute(0, 3, 1, 2).unsqueeze(0).float() / 255.0
|
| 82 |
+
depths = torch.from_numpy(observation.get("depths", np.zeros((3, 1, images.shape[-2], images.shape[-1]), dtype=np.float32))).unsqueeze(0).float()
|
| 83 |
+
depth_valid = torch.from_numpy(observation.get("depth_valid", np.zeros((3, 1, images.shape[-2], images.shape[-1]), dtype=np.float32))).unsqueeze(0).float()
|
| 84 |
proprio = torch.from_numpy(observation["proprio"]).unsqueeze(0).float()
|
| 85 |
history_images = history_images or []
|
| 86 |
+
history_depths = history_depths or []
|
| 87 |
+
history_depth_valid = history_depth_valid or []
|
| 88 |
history_proprio = history_proprio or []
|
| 89 |
history_actions = history_actions or []
|
| 90 |
if history_images:
|
|
|
|
| 96 |
(1, 0, images.shape[1], images.shape[2], images.shape[3], images.shape[4]),
|
| 97 |
dtype=torch.float32,
|
| 98 |
)
|
| 99 |
+
if history_depths:
|
| 100 |
+
history_depths_tensor = torch.from_numpy(np.stack(history_depths, axis=0)).unsqueeze(0).float()
|
| 101 |
+
history_depth_valid_tensor = torch.from_numpy(np.stack(history_depth_valid, axis=0)).unsqueeze(0).float()
|
| 102 |
+
else:
|
| 103 |
+
history_depths_tensor = torch.zeros((1, 0, depths.shape[1], depths.shape[2], depths.shape[3], depths.shape[4]), dtype=torch.float32)
|
| 104 |
+
history_depth_valid_tensor = torch.zeros_like(history_depths_tensor)
|
| 105 |
if history_proprio:
|
| 106 |
history_proprio_tensor = torch.from_numpy(np.stack(history_proprio, axis=0)).unsqueeze(0).float()
|
| 107 |
else:
|
|
|
|
| 112 |
history_actions_tensor = torch.zeros((1, 0, 14), dtype=torch.float32)
|
| 113 |
return {
|
| 114 |
"images": images.to(device),
|
| 115 |
+
"depths": depths.to(device),
|
| 116 |
+
"depth_valid": depth_valid.to(device),
|
| 117 |
+
"camera_intrinsics": torch.from_numpy(observation.get("camera_intrinsics", np.zeros((3, 3, 3), dtype=np.float32))).unsqueeze(0).to(device),
|
| 118 |
+
"camera_extrinsics": torch.from_numpy(observation.get("camera_extrinsics", np.zeros((3, 4, 4), dtype=np.float32))).unsqueeze(0).to(device),
|
| 119 |
"history_images": history_images_tensor.to(device),
|
| 120 |
+
"history_depths": history_depths_tensor.to(device),
|
| 121 |
+
"history_depth_valid": history_depth_valid_tensor.to(device),
|
| 122 |
"history_proprio": history_proprio_tensor.to(device),
|
| 123 |
"history_actions": history_actions_tensor.to(device),
|
| 124 |
"proprio": proprio.to(device),
|
|
|
|
| 165 |
if "planned_chunk" in outputs and ablation not in {"no_world_model", "no_interaction_head"}:
|
| 166 |
return outputs["planned_chunk"], outputs
|
| 167 |
return outputs["action_mean"], outputs
|
| 168 |
+
if hasattr(model, "elastic_state_head"):
|
| 169 |
+
outputs = model(
|
| 170 |
+
**forward_kwargs,
|
| 171 |
+
depths=batch.get("depths"),
|
| 172 |
+
depth_valid=batch.get("depth_valid"),
|
| 173 |
+
camera_intrinsics=batch.get("camera_intrinsics"),
|
| 174 |
+
camera_extrinsics=batch.get("camera_extrinsics"),
|
| 175 |
+
history_depths=batch.get("history_depths"),
|
| 176 |
+
history_depth_valid=batch.get("history_depth_valid"),
|
| 177 |
+
plan=True,
|
| 178 |
+
use_world_model=(ablation not in {"no_world_model", "no_planner"}),
|
| 179 |
+
use_planner=(ablation != "no_planner"),
|
| 180 |
+
use_depth=(ablation != "no_depth"),
|
| 181 |
+
use_role_tokens=(ablation not in {"no_role_tokens", "no_role_symmetry"}),
|
| 182 |
+
history_steps_override=(2 if ablation == "short_history" else None),
|
| 183 |
+
)
|
| 184 |
+
if "planned_chunk" in outputs and ablation != "no_planner":
|
| 185 |
+
return outputs["planned_chunk"], outputs
|
| 186 |
+
if "candidate_chunks" in outputs:
|
| 187 |
+
return outputs["candidate_chunks"][:, 0], outputs
|
| 188 |
+
return outputs["action_mean"], outputs
|
| 189 |
if hasattr(model, "reveal_head"):
|
| 190 |
if ablation == "no_world_model":
|
| 191 |
outputs = model(**forward_kwargs, plan=False)
|
|
|
|
| 234 |
episode_corridor = [float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any())]
|
| 235 |
episode_disturbance = [float(privileged_state["disturbance_cost"])]
|
| 236 |
history_images: list[np.ndarray] = []
|
| 237 |
+
history_depths: list[np.ndarray] = []
|
| 238 |
+
history_depth_valid: list[np.ndarray] = []
|
| 239 |
history_proprio: list[np.ndarray] = []
|
| 240 |
history_actions: list[np.ndarray] = []
|
| 241 |
done = False
|
|
|
|
| 244 |
observation,
|
| 245 |
device=device,
|
| 246 |
history_images=history_images,
|
| 247 |
+
history_depths=history_depths,
|
| 248 |
+
history_depth_valid=history_depth_valid,
|
| 249 |
history_proprio=history_proprio,
|
| 250 |
history_actions=history_actions,
|
| 251 |
)
|
|
|
|
| 267 |
if history_steps > 0:
|
| 268 |
if len(history_images) >= history_steps:
|
| 269 |
history_images = history_images[-history_steps + 1 :]
|
| 270 |
+
history_depths = history_depths[-history_steps + 1 :]
|
| 271 |
+
history_depth_valid = history_depth_valid[-history_steps + 1 :]
|
| 272 |
history_proprio = history_proprio[-history_steps + 1 :]
|
| 273 |
history_actions = history_actions[-history_steps + 1 :]
|
| 274 |
history_images.append(observation["images"])
|
| 275 |
+
if "depths" in observation:
|
| 276 |
+
history_depths.append(observation["depths"])
|
| 277 |
+
history_depth_valid.append(observation["depth_valid"])
|
| 278 |
history_proprio.append(observation["proprio"])
|
| 279 |
history_actions.append(action.astype(np.float32))
|
| 280 |
observation, _, terminated, truncated, privileged_state = env.step(action)
|
code/reveal_vla_bimanual/eval/run_rlbench_rollout_eval.py
CHANGED
|
@@ -52,6 +52,19 @@ def _episode_language_goal(descriptions: Sequence[str]) -> str:
|
|
| 52 |
return str(descriptions[0]) if descriptions else ""
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def _noop_bimanual_action(obs: Any) -> np.ndarray:
|
| 56 |
right_obs = getattr(obs, "right", None)
|
| 57 |
left_obs = getattr(obs, "left", None)
|
|
@@ -113,6 +126,7 @@ def main() -> None:
|
|
| 113 |
parser.add_argument("--disable-support-mode-conditioning", action="store_true")
|
| 114 |
parser.add_argument("--headless", action="store_true", default=True)
|
| 115 |
parser.add_argument("--chunk-commit-steps", type=int, default=0)
|
|
|
|
| 116 |
args = parser.parse_args()
|
| 117 |
|
| 118 |
checkpoint = torch.load(Path(args.checkpoint), map_location="cpu", weights_only=False)
|
|
@@ -155,6 +169,7 @@ def main() -> None:
|
|
| 155 |
"episodes_per_task": args.episodes_per_task,
|
| 156 |
"episode_length": args.episode_length,
|
| 157 |
"resolution": args.resolution,
|
|
|
|
| 158 |
"cameras": list(camera_spec.cameras),
|
| 159 |
"tasks": {},
|
| 160 |
}
|
|
@@ -180,8 +195,10 @@ def main() -> None:
|
|
| 180 |
)
|
| 181 |
env.launch()
|
| 182 |
task = env.get_task(task_class)
|
|
|
|
| 183 |
for _ in range(args.episodes_per_task):
|
| 184 |
-
descriptions, obs = task
|
|
|
|
| 185 |
language_goal = _episode_language_goal(descriptions)
|
| 186 |
total_reward = 0.0
|
| 187 |
success = 0.0
|
|
@@ -291,6 +308,7 @@ def main() -> None:
|
|
| 291 |
"returns": task_returns,
|
| 292 |
"path_recoveries": episode_recoveries if args.episodes_per_task == 1 else None,
|
| 293 |
"noop_fallbacks": episode_noop_fallbacks if args.episodes_per_task == 1 else None,
|
|
|
|
| 294 |
"mean_success": float(np.mean(task_successes)) if task_successes else 0.0,
|
| 295 |
"mean_return": float(np.mean(task_returns)) if task_returns else 0.0,
|
| 296 |
}
|
|
|
|
| 52 |
return str(descriptions[0]) if descriptions else ""
|
| 53 |
|
| 54 |
|
| 55 |
+
def _reset_task_with_retries(task: Any, max_attempts: int) -> tuple[Sequence[str], Any, int]:
|
| 56 |
+
last_error: Exception | None = None
|
| 57 |
+
for attempt in range(max_attempts):
|
| 58 |
+
try:
|
| 59 |
+
descriptions, obs = task.reset()
|
| 60 |
+
return descriptions, obs, attempt
|
| 61 |
+
except Exception as exc: # pragma: no cover - live RLBench failure path
|
| 62 |
+
last_error = exc
|
| 63 |
+
if last_error is not None:
|
| 64 |
+
raise last_error
|
| 65 |
+
raise RuntimeError("Task reset failed without raising a concrete exception.")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
def _noop_bimanual_action(obs: Any) -> np.ndarray:
|
| 69 |
right_obs = getattr(obs, "right", None)
|
| 70 |
left_obs = getattr(obs, "left", None)
|
|
|
|
| 126 |
parser.add_argument("--disable-support-mode-conditioning", action="store_true")
|
| 127 |
parser.add_argument("--headless", action="store_true", default=True)
|
| 128 |
parser.add_argument("--chunk-commit-steps", type=int, default=0)
|
| 129 |
+
parser.add_argument("--reset-retries", type=int, default=20)
|
| 130 |
args = parser.parse_args()
|
| 131 |
|
| 132 |
checkpoint = torch.load(Path(args.checkpoint), map_location="cpu", weights_only=False)
|
|
|
|
| 169 |
"episodes_per_task": args.episodes_per_task,
|
| 170 |
"episode_length": args.episode_length,
|
| 171 |
"resolution": args.resolution,
|
| 172 |
+
"reset_retries": args.reset_retries,
|
| 173 |
"cameras": list(camera_spec.cameras),
|
| 174 |
"tasks": {},
|
| 175 |
}
|
|
|
|
| 195 |
)
|
| 196 |
env.launch()
|
| 197 |
task = env.get_task(task_class)
|
| 198 |
+
task_reset_retries: list[int] = []
|
| 199 |
for _ in range(args.episodes_per_task):
|
| 200 |
+
descriptions, obs, reset_retries = _reset_task_with_retries(task, max_attempts=max(1, args.reset_retries))
|
| 201 |
+
task_reset_retries.append(int(reset_retries))
|
| 202 |
language_goal = _episode_language_goal(descriptions)
|
| 203 |
total_reward = 0.0
|
| 204 |
success = 0.0
|
|
|
|
| 308 |
"returns": task_returns,
|
| 309 |
"path_recoveries": episode_recoveries if args.episodes_per_task == 1 else None,
|
| 310 |
"noop_fallbacks": episode_noop_fallbacks if args.episodes_per_task == 1 else None,
|
| 311 |
+
"reset_retries": task_reset_retries,
|
| 312 |
"mean_success": float(np.mean(task_successes)) if task_successes else 0.0,
|
| 313 |
"mean_return": float(np.mean(task_returns)) if task_returns else 0.0,
|
| 314 |
}
|
code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-310.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-311.pyc
ADDED
|
Binary file (22.6 kB). View file
|
|
|
code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc and b/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc and b/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/models/action_decoder.py
CHANGED
|
@@ -19,6 +19,8 @@ class ChunkDecoderConfig:
|
|
| 19 |
num_candidates: int = 8
|
| 20 |
num_phases: int = 5
|
| 21 |
num_arm_roles: int = 4
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class ACTBimanualChunkDecoder(nn.Module):
|
|
@@ -381,3 +383,305 @@ class InteractionChunkDecoder(nn.Module):
|
|
| 381 |
candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
|
| 382 |
candidates[:, 0] = action_mean
|
| 383 |
return candidates
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
num_candidates: int = 8
|
| 20 |
num_phases: int = 5
|
| 21 |
num_arm_roles: int = 4
|
| 22 |
+
num_proposal_modes: int = 6
|
| 23 |
+
planner_top_k: int = 4
|
| 24 |
|
| 25 |
|
| 26 |
class ACTBimanualChunkDecoder(nn.Module):
|
|
|
|
| 383 |
candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
|
| 384 |
candidates[:, 0] = action_mean
|
| 385 |
return candidates
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
DEFAULT_PROPOSAL_MODES = (
|
| 389 |
+
"widen_opening",
|
| 390 |
+
"maintain_opening",
|
| 391 |
+
"slide_occluder",
|
| 392 |
+
"lift_support_layer",
|
| 393 |
+
"stabilize_support",
|
| 394 |
+
"retrieve",
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def swap_arm_action_order(action_chunk: Tensor) -> Tensor:
|
| 399 |
+
midpoint = action_chunk.shape[-1] // 2
|
| 400 |
+
return torch.cat([action_chunk[..., midpoint:], action_chunk[..., :midpoint]], dim=-1)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
class SymmetricCoordinatedChunkDecoder(nn.Module):
|
| 404 |
+
def __init__(self, config: ChunkDecoderConfig) -> None:
|
| 405 |
+
super().__init__()
|
| 406 |
+
self.config = config
|
| 407 |
+
proposal_context_dim = config.action_dim + (config.hidden_dim * 2)
|
| 408 |
+
decoder_layer = nn.TransformerDecoderLayer(
|
| 409 |
+
d_model=config.hidden_dim,
|
| 410 |
+
nhead=config.num_heads,
|
| 411 |
+
dim_feedforward=config.ff_dim,
|
| 412 |
+
dropout=config.dropout,
|
| 413 |
+
batch_first=True,
|
| 414 |
+
norm_first=True,
|
| 415 |
+
)
|
| 416 |
+
self.arm_decoder = nn.TransformerDecoder(decoder_layer, num_layers=config.num_layers)
|
| 417 |
+
self.query_embed = nn.Embedding(config.chunk_size, config.hidden_dim)
|
| 418 |
+
self.arm_identity = nn.Embedding(2, config.hidden_dim)
|
| 419 |
+
self.phase_adapter = nn.Linear(config.num_phases, config.hidden_dim)
|
| 420 |
+
self.role_adapter = nn.Linear(config.num_arm_roles, config.hidden_dim)
|
| 421 |
+
self.context_proj = nn.Sequential(
|
| 422 |
+
nn.LayerNorm(config.hidden_dim),
|
| 423 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 424 |
+
nn.GELU(),
|
| 425 |
+
)
|
| 426 |
+
self.coordination = nn.Sequential(
|
| 427 |
+
nn.LayerNorm(config.hidden_dim * 3),
|
| 428 |
+
nn.Linear(config.hidden_dim * 3, config.hidden_dim),
|
| 429 |
+
nn.GELU(),
|
| 430 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 431 |
+
)
|
| 432 |
+
self.arm_head = nn.Sequential(
|
| 433 |
+
nn.LayerNorm(config.hidden_dim),
|
| 434 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 435 |
+
nn.GELU(),
|
| 436 |
+
)
|
| 437 |
+
self.arm_mean = nn.Linear(config.hidden_dim, config.arm_action_dim)
|
| 438 |
+
self.arm_log_std = nn.Linear(config.hidden_dim, config.arm_action_dim)
|
| 439 |
+
self.proposal_mode_head = nn.Sequential(
|
| 440 |
+
nn.LayerNorm(proposal_context_dim),
|
| 441 |
+
nn.Linear(proposal_context_dim, config.hidden_dim),
|
| 442 |
+
nn.GELU(),
|
| 443 |
+
nn.Linear(config.hidden_dim, config.num_proposal_modes),
|
| 444 |
+
)
|
| 445 |
+
self.proposal_mode_embeddings = nn.Embedding(config.num_proposal_modes, config.hidden_dim)
|
| 446 |
+
self.proposal_slot_embeddings = nn.Embedding(config.num_candidates, config.hidden_dim)
|
| 447 |
+
self.mode_residual_heads = nn.ModuleList(
|
| 448 |
+
[
|
| 449 |
+
nn.Sequential(
|
| 450 |
+
nn.LayerNorm(proposal_context_dim),
|
| 451 |
+
nn.Linear(proposal_context_dim, config.hidden_dim),
|
| 452 |
+
nn.GELU(),
|
| 453 |
+
nn.Linear(config.hidden_dim, config.chunk_size * config.action_dim),
|
| 454 |
+
)
|
| 455 |
+
for _ in range(config.num_proposal_modes)
|
| 456 |
+
]
|
| 457 |
+
)
|
| 458 |
+
self.slot_delta = nn.Sequential(
|
| 459 |
+
nn.LayerNorm(config.hidden_dim),
|
| 460 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 461 |
+
nn.GELU(),
|
| 462 |
+
nn.Linear(config.hidden_dim, config.chunk_size * config.action_dim),
|
| 463 |
+
)
|
| 464 |
+
self.proposal_score = nn.Sequential(
|
| 465 |
+
nn.LayerNorm(proposal_context_dim + config.hidden_dim),
|
| 466 |
+
nn.Linear(proposal_context_dim + config.hidden_dim, config.hidden_dim),
|
| 467 |
+
nn.GELU(),
|
| 468 |
+
nn.Linear(config.hidden_dim, 1),
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
def _conditioning(
|
| 472 |
+
self,
|
| 473 |
+
interaction_state: dict[str, Tensor] | None,
|
| 474 |
+
batch_size: int,
|
| 475 |
+
device: torch.device,
|
| 476 |
+
dtype: torch.dtype,
|
| 477 |
+
swap_roles: bool = False,
|
| 478 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 479 |
+
if interaction_state is None:
|
| 480 |
+
zero_phase = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
|
| 481 |
+
zero_roles = torch.zeros(batch_size, 2, self.config.hidden_dim, device=device, dtype=dtype)
|
| 482 |
+
zero_context = torch.zeros(batch_size, self.config.hidden_dim, device=device, dtype=dtype)
|
| 483 |
+
return zero_phase, zero_roles, zero_context
|
| 484 |
+
phase_probs = interaction_state["phase_logits"].softmax(dim=-1).to(dtype=dtype)
|
| 485 |
+
arm_role_probs = interaction_state["arm_role_logits"].softmax(dim=-1).to(dtype=dtype)
|
| 486 |
+
if swap_roles:
|
| 487 |
+
arm_role_probs = arm_role_probs.flip(1)
|
| 488 |
+
phase_context = self.phase_adapter(phase_probs)
|
| 489 |
+
role_context = self.role_adapter(arm_role_probs)
|
| 490 |
+
if interaction_state.get("interaction_tokens") is not None:
|
| 491 |
+
interaction_context = interaction_state["interaction_tokens"].mean(dim=1)
|
| 492 |
+
else:
|
| 493 |
+
interaction_context = interaction_state["field_tokens"].mean(dim=1)
|
| 494 |
+
return phase_context, role_context, self.context_proj(interaction_context)
|
| 495 |
+
|
| 496 |
+
def _decode_arm_tokens(
|
| 497 |
+
self,
|
| 498 |
+
queries: Tensor,
|
| 499 |
+
decoder_memory: Tensor,
|
| 500 |
+
phase_context: Tensor,
|
| 501 |
+
role_context: Tensor,
|
| 502 |
+
interaction_context: Tensor,
|
| 503 |
+
swap_roles: bool = False,
|
| 504 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 505 |
+
batch_size, chunk_size, _ = queries.shape
|
| 506 |
+
identity_order = torch.tensor([1, 0], device=queries.device) if swap_roles else torch.tensor([0, 1], device=queries.device)
|
| 507 |
+
arm_queries = queries.unsqueeze(1).expand(-1, 2, -1, -1)
|
| 508 |
+
arm_queries = arm_queries + phase_context.unsqueeze(1).unsqueeze(2)
|
| 509 |
+
arm_queries = arm_queries + role_context.unsqueeze(2)
|
| 510 |
+
arm_queries = arm_queries + self.arm_identity(identity_order).view(1, 2, 1, -1).to(dtype=queries.dtype)
|
| 511 |
+
flat_queries = arm_queries.reshape(batch_size * 2, chunk_size, self.config.hidden_dim)
|
| 512 |
+
flat_memory = decoder_memory.unsqueeze(1).expand(-1, 2, -1, -1).reshape(
|
| 513 |
+
batch_size * 2,
|
| 514 |
+
decoder_memory.shape[1],
|
| 515 |
+
decoder_memory.shape[2],
|
| 516 |
+
)
|
| 517 |
+
decoded = self.arm_decoder(flat_queries, flat_memory).reshape(batch_size, 2, chunk_size, self.config.hidden_dim)
|
| 518 |
+
coordination_input = torch.cat(
|
| 519 |
+
[
|
| 520 |
+
decoded[:, 0],
|
| 521 |
+
decoded[:, 1],
|
| 522 |
+
interaction_context.unsqueeze(1).expand(-1, chunk_size, -1),
|
| 523 |
+
],
|
| 524 |
+
dim=-1,
|
| 525 |
+
)
|
| 526 |
+
coordination = torch.tanh(self.coordination(coordination_input))
|
| 527 |
+
decoded[:, 0] = decoded[:, 0] + coordination
|
| 528 |
+
decoded[:, 1] = decoded[:, 1] + coordination
|
| 529 |
+
decoded = self.arm_head(decoded)
|
| 530 |
+
arm_mean = self.arm_mean(decoded)
|
| 531 |
+
arm_log_std = self.arm_log_std(decoded).clamp(min=-5.0, max=2.0)
|
| 532 |
+
return arm_mean, arm_log_std, coordination
|
| 533 |
+
|
| 534 |
+
def _proposal_outputs(
|
| 535 |
+
self,
|
| 536 |
+
base_action: Tensor,
|
| 537 |
+
pooled_context: Tensor,
|
| 538 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 539 |
+
batch_size = pooled_context.shape[0]
|
| 540 |
+
mode_logits = self.proposal_mode_head(pooled_context)
|
| 541 |
+
mode_residuals = []
|
| 542 |
+
for head in self.mode_residual_heads:
|
| 543 |
+
residual = head(pooled_context).view(batch_size, self.config.chunk_size, self.config.action_dim)
|
| 544 |
+
mode_residuals.append(residual)
|
| 545 |
+
mode_residuals = torch.stack(mode_residuals, dim=1)
|
| 546 |
+
|
| 547 |
+
mode_assignments = torch.arange(self.config.num_candidates, device=pooled_context.device) % self.config.num_proposal_modes
|
| 548 |
+
slot_embeddings = self.proposal_slot_embeddings.weight
|
| 549 |
+
slot_deltas = self.slot_delta(slot_embeddings).view(
|
| 550 |
+
self.config.num_candidates,
|
| 551 |
+
self.config.chunk_size,
|
| 552 |
+
self.config.action_dim,
|
| 553 |
+
)
|
| 554 |
+
proposal_candidates = []
|
| 555 |
+
proposal_logits = []
|
| 556 |
+
for slot_idx in range(self.config.num_candidates):
|
| 557 |
+
mode_idx = int(mode_assignments[slot_idx])
|
| 558 |
+
candidate = base_action + 0.35 * torch.tanh(mode_residuals[:, mode_idx]) + 0.05 * torch.tanh(slot_deltas[slot_idx]).unsqueeze(0)
|
| 559 |
+
proposal_candidates.append(candidate)
|
| 560 |
+
score_features = torch.cat(
|
| 561 |
+
[
|
| 562 |
+
pooled_context,
|
| 563 |
+
self.proposal_mode_embeddings.weight[mode_idx].unsqueeze(0).expand(batch_size, -1)
|
| 564 |
+
+ slot_embeddings[slot_idx].unsqueeze(0).expand(batch_size, -1),
|
| 565 |
+
],
|
| 566 |
+
dim=-1,
|
| 567 |
+
)
|
| 568 |
+
proposal_logits.append(
|
| 569 |
+
self.proposal_score(score_features).squeeze(-1) + mode_logits[:, mode_idx]
|
| 570 |
+
)
|
| 571 |
+
stacked_candidates = torch.stack(proposal_candidates, dim=1)
|
| 572 |
+
stacked_logits = torch.stack(proposal_logits, dim=1)
|
| 573 |
+
stacked_candidates[:, 0] = base_action
|
| 574 |
+
return stacked_candidates, stacked_logits, mode_logits
|
| 575 |
+
|
| 576 |
+
def forward(
|
| 577 |
+
self,
|
| 578 |
+
scene_tokens: Tensor,
|
| 579 |
+
interaction_state: dict[str, Tensor] | None = None,
|
| 580 |
+
memory_tokens: Tensor | None = None,
|
| 581 |
+
reveal_tokens: Tensor | None = None,
|
| 582 |
+
memory_token: Tensor | None = None,
|
| 583 |
+
compute_equivariance_probe: bool = False,
|
| 584 |
+
) -> dict[str, Tensor]:
|
| 585 |
+
if memory_tokens is None:
|
| 586 |
+
memory_tokens = memory_token
|
| 587 |
+
batch_size = scene_tokens.shape[0]
|
| 588 |
+
dtype = scene_tokens.dtype
|
| 589 |
+
phase_context, role_context, interaction_context = self._conditioning(
|
| 590 |
+
interaction_state=interaction_state,
|
| 591 |
+
batch_size=batch_size,
|
| 592 |
+
device=scene_tokens.device,
|
| 593 |
+
dtype=dtype,
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
decoder_memory = scene_tokens
|
| 597 |
+
interaction_tokens = interaction_state.get("interaction_tokens") if interaction_state is not None else None
|
| 598 |
+
if interaction_tokens is not None:
|
| 599 |
+
decoder_memory = torch.cat([decoder_memory, interaction_tokens], dim=1)
|
| 600 |
+
elif reveal_tokens is not None:
|
| 601 |
+
decoder_memory = torch.cat([decoder_memory, reveal_tokens], dim=1)
|
| 602 |
+
if memory_tokens is not None:
|
| 603 |
+
decoder_memory = torch.cat([decoder_memory, memory_tokens], dim=1)
|
| 604 |
+
|
| 605 |
+
base_queries = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
|
| 606 |
+
arm_mean, arm_log_std, coordination = self._decode_arm_tokens(
|
| 607 |
+
queries=base_queries,
|
| 608 |
+
decoder_memory=decoder_memory,
|
| 609 |
+
phase_context=phase_context,
|
| 610 |
+
role_context=role_context,
|
| 611 |
+
interaction_context=interaction_context,
|
| 612 |
+
)
|
| 613 |
+
action_mean = torch.cat([arm_mean[:, 0], arm_mean[:, 1]], dim=-1)
|
| 614 |
+
action_log_std = torch.cat([arm_log_std[:, 0], arm_log_std[:, 1]], dim=-1)
|
| 615 |
+
pooled_context = torch.cat(
|
| 616 |
+
[
|
| 617 |
+
arm_mean[:, 0].mean(dim=1),
|
| 618 |
+
arm_mean[:, 1].mean(dim=1),
|
| 619 |
+
coordination.mean(dim=1),
|
| 620 |
+
interaction_context,
|
| 621 |
+
],
|
| 622 |
+
dim=-1,
|
| 623 |
+
)
|
| 624 |
+
proposal_candidates, proposal_logits, proposal_mode_logits = self._proposal_outputs(action_mean, pooled_context)
|
| 625 |
+
|
| 626 |
+
outputs = {
|
| 627 |
+
"decoded_tokens": torch.cat([arm_mean[:, 0], arm_mean[:, 1]], dim=-1),
|
| 628 |
+
"right_tokens": arm_mean[:, 0],
|
| 629 |
+
"left_tokens": arm_mean[:, 1],
|
| 630 |
+
"coordination_tokens": coordination,
|
| 631 |
+
"action_mean": action_mean,
|
| 632 |
+
"action_log_std": action_log_std,
|
| 633 |
+
"proposal_candidates": proposal_candidates,
|
| 634 |
+
"proposal_logits": proposal_logits,
|
| 635 |
+
"proposal_mode_logits": proposal_mode_logits,
|
| 636 |
+
"proposal_mode_assignments": torch.arange(
|
| 637 |
+
self.config.num_candidates,
|
| 638 |
+
device=scene_tokens.device,
|
| 639 |
+
) % self.config.num_proposal_modes,
|
| 640 |
+
"proposal_mode_names": list(DEFAULT_PROPOSAL_MODES[: self.config.num_proposal_modes]),
|
| 641 |
+
}
|
| 642 |
+
if compute_equivariance_probe:
|
| 643 |
+
swapped_phase, swapped_roles, swapped_context = self._conditioning(
|
| 644 |
+
interaction_state=interaction_state,
|
| 645 |
+
batch_size=batch_size,
|
| 646 |
+
device=scene_tokens.device,
|
| 647 |
+
dtype=dtype,
|
| 648 |
+
swap_roles=True,
|
| 649 |
+
)
|
| 650 |
+
swapped_arm_mean, _, _ = self._decode_arm_tokens(
|
| 651 |
+
queries=base_queries,
|
| 652 |
+
decoder_memory=decoder_memory,
|
| 653 |
+
phase_context=swapped_phase,
|
| 654 |
+
role_context=swapped_roles,
|
| 655 |
+
interaction_context=swapped_context,
|
| 656 |
+
swap_roles=True,
|
| 657 |
+
)
|
| 658 |
+
outputs["equivariance_probe_action_mean"] = torch.cat(
|
| 659 |
+
[swapped_arm_mean[:, 0], swapped_arm_mean[:, 1]],
|
| 660 |
+
dim=-1,
|
| 661 |
+
)
|
| 662 |
+
outputs["equivariance_target_action_mean"] = swap_arm_action_order(action_mean)
|
| 663 |
+
return outputs
|
| 664 |
+
|
| 665 |
+
def sample_candidates(
|
| 666 |
+
self,
|
| 667 |
+
action_mean: Tensor,
|
| 668 |
+
action_log_std: Tensor,
|
| 669 |
+
num_candidates: int | None = None,
|
| 670 |
+
proposal_candidates: Tensor | None = None,
|
| 671 |
+
) -> Tensor:
|
| 672 |
+
if proposal_candidates is not None:
|
| 673 |
+
return proposal_candidates
|
| 674 |
+
num_candidates = num_candidates or self.config.num_candidates
|
| 675 |
+
if num_candidates <= 1:
|
| 676 |
+
return action_mean.unsqueeze(1)
|
| 677 |
+
noise = torch.randn(
|
| 678 |
+
action_mean.size(0),
|
| 679 |
+
num_candidates,
|
| 680 |
+
action_mean.size(1),
|
| 681 |
+
action_mean.size(2),
|
| 682 |
+
device=action_mean.device,
|
| 683 |
+
dtype=action_mean.dtype,
|
| 684 |
+
)
|
| 685 |
+
candidates = action_mean.unsqueeze(1) + noise * action_log_std.exp().unsqueeze(1)
|
| 686 |
+
candidates[:, 0] = action_mean
|
| 687 |
+
return candidates
|
code/reveal_vla_bimanual/models/backbones.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
from dataclasses import dataclass
|
| 4 |
import math
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Sequence
|
| 7 |
|
|
@@ -18,6 +19,157 @@ class FrozenVLBackboneConfig:
|
|
| 18 |
freeze_backbone: bool = True
|
| 19 |
gradient_checkpointing: bool = True
|
| 20 |
use_dummy_backbone: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
class _DummyTextTokenizer:
|
|
@@ -42,6 +194,11 @@ class FrozenVLBackbone(nn.Module):
|
|
| 42 |
self.config = config
|
| 43 |
self.hidden_dim = config.hidden_dim
|
| 44 |
self.use_dummy_backbone = config.use_dummy_backbone
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
if config.use_dummy_backbone:
|
| 47 |
self.image_patch_size = 16
|
|
@@ -51,36 +208,62 @@ class FrozenVLBackbone(nn.Module):
|
|
| 51 |
|
| 52 |
local_model_source: str | None = None
|
| 53 |
if config.model_name == "openai/clip-vit-base-patch32":
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
clip_model = None
|
|
|
|
|
|
|
| 62 |
if local_model_source is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
try:
|
| 64 |
-
clip_model = CLIPModel.from_pretrained(
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
)
|
| 69 |
-
except OSError:
|
| 70 |
-
clip_model = None
|
| 71 |
if clip_model is None:
|
| 72 |
-
|
|
|
|
| 73 |
self.vision_model = clip_model.vision_model
|
| 74 |
self.text_model = clip_model.text_model
|
| 75 |
self.visual_projection = clip_model.visual_projection
|
| 76 |
self.text_projection = clip_model.text_projection
|
|
|
|
|
|
|
|
|
|
| 77 |
if local_model_source is not None:
|
|
|
|
|
|
|
|
|
|
| 78 |
try:
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
| 84 |
self.hidden_dim = clip_model.config.projection_dim
|
| 85 |
if config.gradient_checkpointing:
|
| 86 |
if hasattr(self.vision_model, "gradient_checkpointing_enable"):
|
|
@@ -88,9 +271,17 @@ class FrozenVLBackbone(nn.Module):
|
|
| 88 |
if hasattr(self.text_model, "gradient_checkpointing_enable"):
|
| 89 |
self.text_model.gradient_checkpointing_enable()
|
| 90 |
|
| 91 |
-
if config.freeze_backbone:
|
| 92 |
-
for
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
def tokenize_text(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
|
| 96 |
if self.use_dummy_backbone:
|
|
@@ -103,7 +294,7 @@ class FrozenVLBackbone(nn.Module):
|
|
| 103 |
return_tensors="pt",
|
| 104 |
).to(device)
|
| 105 |
|
| 106 |
-
def
|
| 107 |
batch_size, num_views, channels, height, width = images.shape
|
| 108 |
flat_images = images.reshape(batch_size * num_views, channels, height, width)
|
| 109 |
if self.use_dummy_backbone:
|
|
@@ -125,6 +316,40 @@ class FrozenVLBackbone(nn.Module):
|
|
| 125 |
num_tokens = tokens.shape[1]
|
| 126 |
return tokens.reshape(batch_size, num_views, num_tokens, -1)
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
def encode_text(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
|
| 129 |
if self.use_dummy_backbone:
|
| 130 |
vocab_scale = float(self.tokenizer.vocab_size - 1)
|
|
|
|
| 2 |
|
| 3 |
from dataclasses import dataclass
|
| 4 |
import math
|
| 5 |
+
import os
|
| 6 |
from pathlib import Path
|
| 7 |
from typing import Sequence
|
| 8 |
|
|
|
|
| 19 |
freeze_backbone: bool = True
|
| 20 |
gradient_checkpointing: bool = True
|
| 21 |
use_dummy_backbone: bool = False
|
| 22 |
+
depth_patch_size: int = 16
|
| 23 |
+
geometry_feature_dim: int = 8
|
| 24 |
+
use_camera_geometry: bool = True
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DepthPatchAdapter(nn.Module):
    """Turn per-view depth maps (plus optional camera calibration) into tokens.

    Produces three streams per camera view: depth tokens (from pooled depth
    patches + validity + geometry), geometry tokens (from geometry features
    alone), and a single camera-summary token per view.
    """

    def __init__(
        self,
        hidden_dim: int,
        patch_size: int = 16,
        geometry_feature_dim: int = 8,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.patch_size = patch_size
        self.geometry_feature_dim = geometry_feature_dim
        # Inputs: [patch depth, patch validity] concatenated with geometry features.
        self.depth_proj = nn.Sequential(
            nn.LayerNorm(2 + geometry_feature_dim),
            nn.Linear(2 + geometry_feature_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.geometry_proj = nn.Sequential(
            nn.LayerNorm(geometry_feature_dim),
            nn.Linear(geometry_feature_dim, hidden_dim),
            nn.GELU(),
        )
        # Camera summary is 4 intrinsic scalars (fx, fy, cx, cy) + 3 translation = 7.
        self.camera_proj = nn.Sequential(
            nn.LayerNorm(7),
            nn.Linear(7, hidden_dim),
            nn.GELU(),
        )

    def _patchify(self, tensor: Tensor) -> Tensor:
        """Average-pool into non-overlapping patches; return [BV, patches, C]."""
        pooled = F.avg_pool2d(tensor, kernel_size=self.patch_size, stride=self.patch_size)
        return pooled.flatten(2).transpose(1, 2)

    def _geometry_features(
        self,
        depths: Tensor,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
    ) -> tuple[Tensor, Tensor]:
        """Build per-patch geometry features and a per-view camera summary.

        Returns (geometry [BV, P, geometry_feature_dim], camera_summary [BV, 7]).
        Missing intrinsics/extrinsics are replaced by zeros so the feature
        layout is stable regardless of which calibration inputs are present.
        """
        batch_views, _, height, width = depths.shape
        grid_h = max(1, height // self.patch_size)
        grid_w = max(1, width // self.patch_size)
        # Normalized patch-grid coordinates in [-1, 1] (x, y order).
        y_coords = torch.linspace(-1.0, 1.0, steps=grid_h, device=depths.device, dtype=depths.dtype)
        x_coords = torch.linspace(-1.0, 1.0, steps=grid_w, device=depths.device, dtype=depths.dtype)
        grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")
        coords = torch.stack([grid_x, grid_y], dim=-1).reshape(1, grid_h * grid_w, 2).expand(batch_views, -1, -1)

        geometry_terms: list[Tensor] = [coords]
        if camera_intrinsics is not None:
            # Broadcast fx, fy, cx, cy to every patch position.
            fx = camera_intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
            fy = camera_intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
            cx = camera_intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)
            cy = camera_intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)
            intrinsic_features = torch.cat(
                [
                    fx.expand(-1, grid_h * grid_w, -1),
                    fy.expand(-1, grid_h * grid_w, -1),
                    cx.expand(-1, grid_h * grid_w, -1),
                    cy.expand(-1, grid_h * grid_w, -1),
                ],
                dim=-1,
            )
            geometry_terms.append(intrinsic_features)
        else:
            geometry_terms.append(torch.zeros(batch_views, grid_h * grid_w, 4, device=depths.device, dtype=depths.dtype))

        if camera_extrinsics is not None:
            # Only the camera translation column is used; rotation is ignored here.
            translation = camera_extrinsics[:, :3, 3]
            translation = translation.unsqueeze(1).expand(-1, grid_h * grid_w, -1)
            geometry_terms.append(translation)
        else:
            geometry_terms.append(torch.zeros(batch_views, grid_h * grid_w, 3, device=depths.device, dtype=depths.dtype))

        geometry = torch.cat(geometry_terms, dim=-1)
        # Raw width is 2 + 4 + 3 = 9; pad or truncate to the configured width.
        # NOTE(review): with the default geometry_feature_dim=8 the last
        # translation component is truncated — presumably intentional; confirm.
        if geometry.shape[-1] < self.geometry_feature_dim:
            pad = self.geometry_feature_dim - geometry.shape[-1]
            geometry = F.pad(geometry, (0, pad))
        elif geometry.shape[-1] > self.geometry_feature_dim:
            geometry = geometry[..., : self.geometry_feature_dim]

        if camera_intrinsics is not None:
            camera_summary = torch.cat(
                [
                    camera_intrinsics[:, 0, 0:1],
                    camera_intrinsics[:, 1, 1:2],
                    camera_intrinsics[:, 0, 2:3],
                    camera_intrinsics[:, 1, 2:3],
                ],
                dim=-1,
            )
        else:
            camera_summary = torch.zeros(batch_views, 4, device=depths.device, dtype=depths.dtype)
        if camera_extrinsics is not None:
            camera_summary = torch.cat([camera_summary, camera_extrinsics[:, :3, 3]], dim=-1)
        else:
            camera_summary = torch.cat(
                [camera_summary, torch.zeros(batch_views, 3, device=depths.device, dtype=depths.dtype)],
                dim=-1,
            )
        return geometry, camera_summary

    def forward(
        self,
        depths: Tensor,
        depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Encode depth maps into depth/geometry/camera token dictionaries.

        Accepts depths shaped [B, V, H, W] or [B, V, C, H, W]; extra channels
        are averaged down to one. A missing validity mask defaults to all-ones.

        Raises:
            ValueError: if depths cannot be coerced to rank 5.
        """
        if depths.ndim == 4:
            depths = depths.unsqueeze(2)
        if depth_valid is None:
            depth_valid = torch.ones_like(depths)
        if depth_valid.ndim == 4:
            depth_valid = depth_valid.unsqueeze(2)
        if depths.ndim != 5:
            raise ValueError(f"Expected depths to have shape [B, V, H, W] or [B, V, 1, H, W], got {tuple(depths.shape)}")
        if depths.shape[2] != 1:
            depths = depths.mean(dim=2, keepdim=True)
        if depth_valid.shape[2] != 1:
            depth_valid = depth_valid.mean(dim=2, keepdim=True)

        # Flatten batch and view axes so every view is processed independently.
        batch_size, num_views = depths.shape[:2]
        flat_depths = depths.reshape(batch_size * num_views, 1, depths.shape[-2], depths.shape[-1]).float()
        flat_valid = depth_valid.reshape(batch_size * num_views, 1, depth_valid.shape[-2], depth_valid.shape[-1]).float()
        flat_intrinsics = None
        flat_extrinsics = None
        if camera_intrinsics is not None:
            flat_intrinsics = camera_intrinsics.reshape(batch_size * num_views, *camera_intrinsics.shape[-2:]).float()
        if camera_extrinsics is not None:
            flat_extrinsics = camera_extrinsics.reshape(batch_size * num_views, *camera_extrinsics.shape[-2:]).float()

        depth_patch = self._patchify(flat_depths)
        valid_patch = self._patchify(flat_valid)
        geometry_features, camera_summary = self._geometry_features(
            flat_depths,
            camera_intrinsics=flat_intrinsics,
            camera_extrinsics=flat_extrinsics,
        )
        token_inputs = torch.cat([depth_patch, valid_patch, geometry_features], dim=-1)
        depth_tokens = self.depth_proj(token_inputs)
        geometry_tokens = self.geometry_proj(geometry_features)
        camera_tokens = self.camera_proj(camera_summary).unsqueeze(1)
        # Restore the [B, V, tokens, hidden] layout expected downstream.
        return {
            "depth_tokens": depth_tokens.view(batch_size, num_views, depth_tokens.shape[1], depth_tokens.shape[2]),
            "geometry_tokens": geometry_tokens.view(batch_size, num_views, geometry_tokens.shape[1], geometry_tokens.shape[2]),
            "camera_tokens": camera_tokens.view(batch_size, num_views, 1, camera_tokens.shape[-1]),
        }
|
| 173 |
|
| 174 |
|
| 175 |
class _DummyTextTokenizer:
|
|
|
|
| 194 |
self.config = config
|
| 195 |
self.hidden_dim = config.hidden_dim
|
| 196 |
self.use_dummy_backbone = config.use_dummy_backbone
|
| 197 |
+
self.depth_adapter = DepthPatchAdapter(
|
| 198 |
+
hidden_dim=config.hidden_dim,
|
| 199 |
+
patch_size=config.depth_patch_size,
|
| 200 |
+
geometry_feature_dim=config.geometry_feature_dim,
|
| 201 |
+
)
|
| 202 |
|
| 203 |
if config.use_dummy_backbone:
|
| 204 |
self.image_patch_size = 16
|
|
|
|
| 208 |
|
| 209 |
local_model_source: str | None = None
|
| 210 |
if config.model_name == "openai/clip-vit-base-patch32":
|
| 211 |
+
explicit_local_dir = Path("/workspace/models/openai_clip_vit_base_patch32")
|
| 212 |
+
if (explicit_local_dir / "config.json").exists():
|
| 213 |
+
local_model_source = str(explicit_local_dir)
|
| 214 |
+
cache_home = Path(os.environ.get("HF_HOME", "/workspace/.cache/huggingface"))
|
| 215 |
+
cache_root = cache_home / "hub" / "models--openai--clip-vit-base-patch32"
|
| 216 |
+
if local_model_source is None:
|
| 217 |
+
ref_path = cache_root / "refs" / "main"
|
| 218 |
+
if ref_path.exists():
|
| 219 |
+
snapshot_id = ref_path.read_text(encoding="utf-8").strip()
|
| 220 |
+
snapshot_dir = cache_root / "snapshots" / snapshot_id
|
| 221 |
+
if (snapshot_dir / "config.json").exists():
|
| 222 |
+
local_model_source = str(snapshot_dir)
|
| 223 |
+
if local_model_source is None:
|
| 224 |
+
snapshot_root = cache_root / "snapshots"
|
| 225 |
+
if snapshot_root.exists():
|
| 226 |
+
for snapshot_dir in sorted(snapshot_root.iterdir(), reverse=True):
|
| 227 |
+
if (snapshot_dir / "config.json").exists():
|
| 228 |
+
local_model_source = str(snapshot_dir)
|
| 229 |
+
break
|
| 230 |
clip_model = None
|
| 231 |
+
last_clip_error: Exception | None = None
|
| 232 |
+
model_sources: list[tuple[str, dict[str, object]]] = []
|
| 233 |
if local_model_source is not None:
|
| 234 |
+
model_sources.append((local_model_source, {"use_safetensors": True, "local_files_only": True}))
|
| 235 |
+
model_sources.append((local_model_source, {"local_files_only": True}))
|
| 236 |
+
model_sources.append((config.model_name, {"use_safetensors": True}))
|
| 237 |
+
model_sources.append((config.model_name, {}))
|
| 238 |
+
for source, kwargs in model_sources:
|
| 239 |
try:
|
| 240 |
+
clip_model = CLIPModel.from_pretrained(source, **kwargs)
|
| 241 |
+
break
|
| 242 |
+
except Exception as exc:
|
| 243 |
+
last_clip_error = exc
|
|
|
|
|
|
|
|
|
|
| 244 |
if clip_model is None:
|
| 245 |
+
assert last_clip_error is not None
|
| 246 |
+
raise last_clip_error
|
| 247 |
self.vision_model = clip_model.vision_model
|
| 248 |
self.text_model = clip_model.text_model
|
| 249 |
self.visual_projection = clip_model.visual_projection
|
| 250 |
self.text_projection = clip_model.text_projection
|
| 251 |
+
tokenizer = None
|
| 252 |
+
last_tokenizer_error: Exception | None = None
|
| 253 |
+
tokenizer_sources: list[tuple[str, dict[str, object]]] = []
|
| 254 |
if local_model_source is not None:
|
| 255 |
+
tokenizer_sources.append((local_model_source, {"local_files_only": True}))
|
| 256 |
+
tokenizer_sources.append((config.model_name, {}))
|
| 257 |
+
for source, kwargs in tokenizer_sources:
|
| 258 |
try:
|
| 259 |
+
tokenizer = AutoTokenizer.from_pretrained(source, **kwargs)
|
| 260 |
+
break
|
| 261 |
+
except Exception as exc:
|
| 262 |
+
last_tokenizer_error = exc
|
| 263 |
+
if tokenizer is None:
|
| 264 |
+
assert last_tokenizer_error is not None
|
| 265 |
+
raise last_tokenizer_error
|
| 266 |
+
self.tokenizer = tokenizer
|
| 267 |
self.hidden_dim = clip_model.config.projection_dim
|
| 268 |
if config.gradient_checkpointing:
|
| 269 |
if hasattr(self.vision_model, "gradient_checkpointing_enable"):
|
|
|
|
| 271 |
if hasattr(self.text_model, "gradient_checkpointing_enable"):
|
| 272 |
self.text_model.gradient_checkpointing_enable()
|
| 273 |
|
| 274 |
+
if config.freeze_backbone and not config.use_dummy_backbone:
|
| 275 |
+
for module in (
|
| 276 |
+
getattr(self, "vision_model", None),
|
| 277 |
+
getattr(self, "text_model", None),
|
| 278 |
+
getattr(self, "visual_projection", None),
|
| 279 |
+
getattr(self, "text_projection", None),
|
| 280 |
+
):
|
| 281 |
+
if module is None:
|
| 282 |
+
continue
|
| 283 |
+
for parameter in module.parameters():
|
| 284 |
+
parameter.requires_grad = False
|
| 285 |
|
| 286 |
def tokenize_text(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
|
| 287 |
if self.use_dummy_backbone:
|
|
|
|
| 294 |
return_tensors="pt",
|
| 295 |
).to(device)
|
| 296 |
|
| 297 |
+
def _encode_rgb_tokens(self, images: Tensor) -> Tensor:
|
| 298 |
batch_size, num_views, channels, height, width = images.shape
|
| 299 |
flat_images = images.reshape(batch_size * num_views, channels, height, width)
|
| 300 |
if self.use_dummy_backbone:
|
|
|
|
| 316 |
num_tokens = tokens.shape[1]
|
| 317 |
return tokens.reshape(batch_size, num_views, num_tokens, -1)
|
| 318 |
|
| 319 |
+
def encode_images(
    self,
    images: Tensor,
    depths: Tensor | None = None,
    depth_valid: Tensor | None = None,
    camera_intrinsics: Tensor | None = None,
    camera_extrinsics: Tensor | None = None,
    return_aux: bool = False,
) -> Tensor | dict[str, Tensor | None]:
    """Encode multi-view RGB images, optionally with depth/geometry tokens.

    When no auxiliary inputs are given and ``return_aux`` is False, returns
    just the RGB token tensor (backward-compatible fast path). Otherwise
    returns a dict of RGB, depth, geometry, and camera tokens, with the
    depth-derived entries set to None when no depth input was supplied.
    """
    rgb_tokens = self._encode_rgb_tokens(images)
    aux_inputs = (depths, depth_valid, camera_intrinsics, camera_extrinsics)
    if not return_aux and all(item is None for item in aux_inputs):
        return rgb_tokens

    depth_outputs: dict[str, Tensor | None] = dict.fromkeys(
        ("depth_tokens", "geometry_tokens", "camera_tokens")
    )
    if depths is not None:
        # Depth adapter consumes calibration alongside the raw depth maps.
        depth_outputs = self.depth_adapter(
            depths=depths,
            depth_valid=depth_valid,
            camera_intrinsics=camera_intrinsics,
            camera_extrinsics=camera_extrinsics,
        )

    return {
        "rgb_tokens": rgb_tokens,
        "depth_tokens": depth_outputs["depth_tokens"],
        "geometry_tokens": depth_outputs["geometry_tokens"],
        "camera_tokens": depth_outputs["camera_tokens"],
    }
|
| 352 |
+
|
| 353 |
def encode_text(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
|
| 354 |
if self.use_dummy_backbone:
|
| 355 |
vocab_scale = float(self.tokenizer.vocab_size - 1)
|
code/reveal_vla_bimanual/models/multiview_fusion.py
CHANGED
|
@@ -16,6 +16,37 @@ class MultiViewFusionConfig:
|
|
| 16 |
dropout: float = 0.1
|
| 17 |
proprio_dim: int = 32
|
| 18 |
proprio_tokens: int = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
class MultiViewFusion(nn.Module):
|
|
@@ -35,13 +66,26 @@ class MultiViewFusion(nn.Module):
|
|
| 35 |
encoder_layer,
|
| 36 |
num_layers=config.num_layers,
|
| 37 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
self.proprio_adapter = nn.Sequential(
|
| 39 |
nn.LayerNorm(config.proprio_dim),
|
| 40 |
nn.Linear(config.proprio_dim, config.hidden_dim * config.proprio_tokens),
|
| 41 |
nn.GELU(),
|
| 42 |
)
|
| 43 |
|
| 44 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
batch_size, num_views, num_tokens, hidden_dim = image_tokens.shape
|
| 46 |
if num_views != self.config.num_cameras:
|
| 47 |
raise ValueError(f"Expected {self.config.num_cameras} views, received {num_views}")
|
|
@@ -49,9 +93,36 @@ class MultiViewFusion(nn.Module):
|
|
| 49 |
camera_ids = torch.arange(num_views, device=image_tokens.device)
|
| 50 |
camera_embed = self.camera_embedding(camera_ids).view(1, num_views, 1, hidden_dim)
|
| 51 |
image_tokens = image_tokens + camera_embed
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
proprio_tokens = self.proprio_adapter(proprio).view(
|
| 55 |
batch_size, self.config.proprio_tokens, hidden_dim
|
| 56 |
)
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
dropout: float = 0.1
|
| 17 |
proprio_dim: int = 32
|
| 18 |
proprio_tokens: int = 1
|
| 19 |
+
geometry_num_heads: int = 4
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class GatedCrossAttentionBlock(nn.Module):
    """Cross-attend RGB tokens onto geometry tokens behind a learned gate.

    A sigmoid gate, computed from mean-pooled summaries of both token
    streams, controls how much attended geometry is mixed back into the
    RGB stream before the output projection.
    """

    def __init__(self, hidden_dim: int, num_heads: int, dropout: float) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.gate = nn.Sequential(
            nn.LayerNorm(hidden_dim * 2),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )

    def forward(self, rgb_tokens: Tensor, geometry_tokens: Tensor) -> tuple[Tensor, Tensor]:
        """Return (fused RGB tokens, mean-pooled geometry summary)."""
        attended, _ = self.attn(rgb_tokens, geometry_tokens, geometry_tokens)
        pooled_rgb = rgb_tokens.mean(dim=1)
        pooled_geometry = geometry_tokens.mean(dim=1)
        gate_input = torch.cat([pooled_rgb, pooled_geometry], dim=-1)
        mix = torch.sigmoid(self.gate(gate_input)).unsqueeze(1)
        blended = rgb_tokens + mix * attended
        return self.out(blended), pooled_geometry
|
| 50 |
|
| 51 |
|
| 52 |
class MultiViewFusion(nn.Module):
|
|
|
|
| 66 |
encoder_layer,
|
| 67 |
num_layers=config.num_layers,
|
| 68 |
)
|
| 69 |
+
self.geometry_fusion = GatedCrossAttentionBlock(
|
| 70 |
+
hidden_dim=config.hidden_dim,
|
| 71 |
+
num_heads=max(1, min(config.num_heads, config.geometry_num_heads)),
|
| 72 |
+
dropout=config.dropout,
|
| 73 |
+
)
|
| 74 |
self.proprio_adapter = nn.Sequential(
|
| 75 |
nn.LayerNorm(config.proprio_dim),
|
| 76 |
nn.Linear(config.proprio_dim, config.hidden_dim * config.proprio_tokens),
|
| 77 |
nn.GELU(),
|
| 78 |
)
|
| 79 |
|
| 80 |
+
def forward(
    self,
    image_tokens: Tensor,
    proprio: Tensor,
    language_tokens: Tensor,
    depth_tokens: Tensor | None = None,
    camera_tokens: Tensor | None = None,
    return_aux: bool = False,
) -> Tensor | dict[str, Tensor]:
    """Fuse per-view image tokens with optional geometry, proprio, and language.

    image_tokens is [B, V, T, H]; V must equal config.num_cameras. When no
    geometry inputs are given and return_aux is False, returns the fused
    scene-token tensor (backward-compatible path); otherwise returns a dict
    with scene tokens plus per-view and per-view-geometry summaries.

    Raises:
        ValueError: if the view count does not match config.num_cameras.
    """
    batch_size, num_views, num_tokens, hidden_dim = image_tokens.shape
    if num_views != self.config.num_cameras:
        raise ValueError(f"Expected {self.config.num_cameras} views, received {num_views}")

    # Add a learned per-camera embedding so views stay distinguishable after fusion.
    camera_ids = torch.arange(num_views, device=image_tokens.device)
    camera_embed = self.camera_embedding(camera_ids).view(1, num_views, 1, hidden_dim)
    image_tokens = image_tokens + camera_embed

    per_view_tokens = []
    view_summaries = []
    geometry_summaries = []
    for view_idx in range(num_views):
        rgb_tokens = image_tokens[:, view_idx]
        # Collect whichever geometry streams are present for this view.
        geometry_sources = []
        if depth_tokens is not None:
            geometry_sources.append(depth_tokens[:, view_idx])
        if camera_tokens is not None:
            geometry_sources.append(camera_tokens[:, view_idx])
        if geometry_sources:
            geometry = torch.cat(geometry_sources, dim=1)
            rgb_tokens, geometry_summary = self.geometry_fusion(rgb_tokens, geometry)
        else:
            # No geometry for this view: keep RGB untouched, zero summary.
            geometry_summary = torch.zeros(batch_size, hidden_dim, device=image_tokens.device, dtype=image_tokens.dtype)
        per_view_tokens.append(rgb_tokens)
        view_summaries.append(rgb_tokens.mean(dim=1))
        geometry_summaries.append(geometry_summary)

    # Cross-view attention over the concatenated (geometry-fused) view tokens.
    fused = self.cross_view_transformer(torch.cat(per_view_tokens, dim=1))

    proprio_tokens = self.proprio_adapter(proprio).view(
        batch_size, self.config.proprio_tokens, hidden_dim
    )
    scene_tokens = torch.cat([fused, proprio_tokens, language_tokens], dim=1)
    if not (return_aux or depth_tokens is not None or camera_tokens is not None):
        return scene_tokens
    return {
        "scene_tokens": scene_tokens,
        "view_summaries": torch.stack(view_summaries, dim=1),
        "geometry_summaries": torch.stack(geometry_summaries, dim=1),
    }
|
code/reveal_vla_bimanual/models/observation_memory.py
CHANGED
|
@@ -16,6 +16,12 @@ class ObservationMemoryConfig:
|
|
| 16 |
memory_bank_size: int = 4
|
| 17 |
num_heads: int = 4
|
| 18 |
max_history_steps: int = 8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
class ObservationMemory(nn.Module):
|
|
@@ -173,3 +179,189 @@ class InteractionObservationMemory(nn.Module):
|
|
| 173 |
"memory_tokens": projected_bank,
|
| 174 |
"memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_bank)).squeeze(-1),
|
| 175 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
memory_bank_size: int = 4
|
| 17 |
num_heads: int = 4
|
| 18 |
max_history_steps: int = 8
|
| 19 |
+
scene_bank_size: int = 2
|
| 20 |
+
belief_bank_size: int = 2
|
| 21 |
+
scene_history_steps: int = 3
|
| 22 |
+
belief_history_steps: int = 8
|
| 23 |
+
memory_write_threshold: float = 0.45
|
| 24 |
+
memory_suppression_margin: float = 0.05
|
| 25 |
|
| 26 |
|
| 27 |
class ObservationMemory(nn.Module):
|
|
|
|
| 179 |
"memory_tokens": projected_bank,
|
| 180 |
"memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_bank)).squeeze(-1),
|
| 181 |
}
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class _SelectiveMemoryBank(nn.Module):
    """Gated temporal memory over pooled scene observations.

    Encodes the truncated observation history plus the current step with a
    small transformer, measures novelty of the current step against the
    prior, and writes a fixed-size bank of memory tokens only when the
    novelty-driven gate opens.
    """

    def __init__(
        self,
        hidden_dim: int,
        action_dim: int,
        num_heads: int,
        dropout: float,
        bank_size: int,
        history_steps: int,
        max_history_steps: int,
        write_threshold: float,
        suppression_margin: float,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.history_steps = history_steps
        self.max_history_steps = max_history_steps
        self.write_threshold = write_threshold
        self.suppression_margin = suppression_margin
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        # +1 position for the current step appended after the history.
        self.position_embedding = nn.Parameter(torch.randn(1, max_history_steps + 1, hidden_dim) * 0.02)
        self.bank_queries = nn.Parameter(torch.randn(bank_size, hidden_dim) * 0.02)
        self.bank_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # Projects past actions into the hidden space so they can be added
        # to the pooled history observations.
        self.action_proj = nn.Sequential(
            nn.LayerNorm(action_dim),
            nn.Linear(action_dim, hidden_dim),
            nn.GELU(),
        )
        # Gate input: [current, prior, |current - prior|] concatenated.
        self.write_gate = nn.Sequential(
            nn.LayerNorm(hidden_dim * 3),
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        )
        self.token_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )

    def _truncate(self, history: Tensor | None) -> Tensor | None:
        """Keep only the most recent ``history_steps`` entries; None/empty pass through."""
        if history is None or history.numel() == 0:
            return history
        if history.shape[1] <= self.history_steps:
            return history
        return history[:, -self.history_steps :]

    def forward(
        self,
        pooled_current: Tensor,
        history_scene_tokens: Tensor | None = None,
        history_actions: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Update the memory bank from the current pooled observation.

        pooled_current is [B, H]; history_scene_tokens, when given, is pooled
        over its token axis (dim 2) and optionally augmented with projected
        past actions. Returns the bank tokens plus diagnostics (write gate,
        encoded sequence, saturation).

        Raises:
            ValueError: if history + current exceeds the positional capacity.
        """
        history_scene_tokens = self._truncate(history_scene_tokens)
        pooled_current = pooled_current.unsqueeze(1)
        if history_scene_tokens is not None and history_scene_tokens.numel() > 0:
            history_pooled = history_scene_tokens.mean(dim=2)
            if history_actions is not None and history_actions.numel() > 0:
                # Align actions with the (possibly truncated) history length.
                history_actions = history_actions[:, -history_pooled.shape[1] :]
                history_pooled = history_pooled + self.action_proj(history_actions)
            sequence = torch.cat([history_pooled, pooled_current], dim=1)
        else:
            # Empty-history placeholder keeps downstream shapes consistent.
            history_pooled = pooled_current[:, :0]
            sequence = pooled_current
        if sequence.shape[1] > self.position_embedding.shape[1]:
            raise ValueError(
                f"Sequence length {sequence.shape[1]} exceeds configured maximum {self.position_embedding.shape[1]}"
            )
        encoded = self.sequence_encoder(sequence + self.position_embedding[:, : sequence.shape[1]])
        current_token = encoded[:, -1]
        prior_token = encoded[:, :-1].mean(dim=1) if encoded.shape[1] > 1 else torch.zeros_like(current_token)
        # Novelty of the current step relative to the mean of the past.
        novelty = torch.abs(current_token - prior_token)
        informative = novelty.mean(dim=-1, keepdim=True)
        gate_logit = self.write_gate(torch.cat([current_token, prior_token, novelty], dim=-1))
        gate = torch.sigmoid(gate_logit)
        # Hard-suppress writes whose mean novelty falls below the threshold band.
        gate = gate * (informative > (self.write_threshold - self.suppression_margin)).to(gate.dtype)
        recent_summary = encoded[:, -min(max(1, self.bank_queries.shape[0]), encoded.shape[1]) :].mean(dim=1, keepdim=True)
        queries = self.bank_queries.unsqueeze(0).expand(encoded.shape[0], -1, -1) + recent_summary
        bank_tokens, _ = self.bank_attention(queries, encoded, encoded)
        bank_tokens = bank_tokens + recent_summary
        # Gate closed -> fall back to the prior summary; open -> write new tokens.
        bank_tokens = prior_token.unsqueeze(1) * (1.0 - gate.unsqueeze(1)) + bank_tokens * gate.unsqueeze(1)
        bank_tokens = self.token_proj(bank_tokens)
        return {
            "memory_tokens": bank_tokens,
            "memory_token": bank_tokens.mean(dim=1, keepdim=True),
            "memory_sequence": encoded,
            "memory_state": current_token,
            "write_gate": gate.squeeze(-1),
            "saturation": bank_tokens.abs().mean(dim=(1, 2)),
        }
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class SceneMemory(_SelectiveMemoryBank):
|
| 290 |
+
def __init__(self, config: ObservationMemoryConfig) -> None:
|
| 291 |
+
super().__init__(
|
| 292 |
+
hidden_dim=config.hidden_dim,
|
| 293 |
+
action_dim=config.action_dim,
|
| 294 |
+
num_heads=config.num_heads,
|
| 295 |
+
dropout=config.dropout,
|
| 296 |
+
bank_size=max(1, config.scene_bank_size),
|
| 297 |
+
history_steps=max(1, config.scene_history_steps),
|
| 298 |
+
max_history_steps=config.max_history_steps,
|
| 299 |
+
write_threshold=config.memory_write_threshold,
|
| 300 |
+
suppression_margin=config.memory_suppression_margin,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class BeliefMemory(_SelectiveMemoryBank):
|
| 305 |
+
def __init__(self, config: ObservationMemoryConfig) -> None:
|
| 306 |
+
super().__init__(
|
| 307 |
+
hidden_dim=config.hidden_dim,
|
| 308 |
+
action_dim=config.action_dim,
|
| 309 |
+
num_heads=config.num_heads,
|
| 310 |
+
dropout=config.dropout,
|
| 311 |
+
bank_size=max(1, config.belief_bank_size),
|
| 312 |
+
history_steps=max(1, config.belief_history_steps),
|
| 313 |
+
max_history_steps=config.max_history_steps,
|
| 314 |
+
write_threshold=config.memory_write_threshold + 0.05,
|
| 315 |
+
suppression_margin=config.memory_suppression_margin,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class DualObservationMemory(nn.Module):
|
| 320 |
+
def __init__(self, config: ObservationMemoryConfig) -> None:
|
| 321 |
+
super().__init__()
|
| 322 |
+
self.scene_memory = SceneMemory(config)
|
| 323 |
+
self.belief_memory = BeliefMemory(config)
|
| 324 |
+
self.uncertainty_head = nn.Sequential(
|
| 325 |
+
nn.LayerNorm(config.hidden_dim),
|
| 326 |
+
nn.Linear(config.hidden_dim, 1),
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
def forward(
|
| 330 |
+
self,
|
| 331 |
+
scene_tokens: Tensor,
|
| 332 |
+
history_scene_tokens: Tensor | None = None,
|
| 333 |
+
history_actions: Tensor | None = None,
|
| 334 |
+
) -> dict[str, Tensor]:
|
| 335 |
+
pooled_current = scene_tokens.mean(dim=1)
|
| 336 |
+
scene_output = self.scene_memory(
|
| 337 |
+
pooled_current=pooled_current,
|
| 338 |
+
history_scene_tokens=history_scene_tokens,
|
| 339 |
+
history_actions=history_actions,
|
| 340 |
+
)
|
| 341 |
+
belief_output = self.belief_memory(
|
| 342 |
+
pooled_current=pooled_current,
|
| 343 |
+
history_scene_tokens=history_scene_tokens,
|
| 344 |
+
history_actions=history_actions,
|
| 345 |
+
)
|
| 346 |
+
memory_tokens = torch.cat([scene_output["memory_tokens"], belief_output["memory_tokens"]], dim=1)
|
| 347 |
+
memory_token = memory_tokens.mean(dim=1, keepdim=True)
|
| 348 |
+
memory_state = torch.cat([scene_output["memory_state"], belief_output["memory_state"]], dim=-1)
|
| 349 |
+
pooled_memory = memory_tokens.mean(dim=1)
|
| 350 |
+
return {
|
| 351 |
+
"scene_memory_tokens": scene_output["memory_tokens"],
|
| 352 |
+
"belief_memory_tokens": belief_output["memory_tokens"],
|
| 353 |
+
"memory_tokens": memory_tokens,
|
| 354 |
+
"memory_token": memory_token,
|
| 355 |
+
"memory_sequence": torch.cat(
|
| 356 |
+
[scene_output["memory_sequence"], belief_output["memory_sequence"]],
|
| 357 |
+
dim=1,
|
| 358 |
+
),
|
| 359 |
+
"memory_state": memory_state,
|
| 360 |
+
"memory_uncertainty": torch.nn.functional.softplus(self.uncertainty_head(pooled_memory)).squeeze(-1),
|
| 361 |
+
"memory_write_rate": 0.5 * (scene_output["write_gate"] + belief_output["write_gate"]),
|
| 362 |
+
"memory_saturation": 0.5 * (scene_output["saturation"] + belief_output["saturation"]),
|
| 363 |
+
"scene_write_gate": scene_output["write_gate"],
|
| 364 |
+
"belief_write_gate": belief_output["write_gate"],
|
| 365 |
+
"memory_scene_state": scene_output["memory_state"],
|
| 366 |
+
"memory_belief_state": belief_output["memory_state"],
|
| 367 |
+
}
|
code/reveal_vla_bimanual/models/planner.py
CHANGED
|
@@ -24,6 +24,14 @@ class PlannerConfig:
|
|
| 24 |
num_layers: int = 2
|
| 25 |
num_phases: int = 5
|
| 26 |
num_arm_roles: int = 4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class RevealPlanner(nn.Module):
|
|
@@ -202,3 +210,186 @@ class InteractionPlanner(nn.Module):
|
|
| 202 |
"best_indices": best_idx,
|
| 203 |
"best_chunk": candidate_chunks[batch_indices, best_idx],
|
| 204 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
num_layers: int = 2
|
| 25 |
num_phases: int = 5
|
| 26 |
num_arm_roles: int = 4
|
| 27 |
+
top_k: int = 4
|
| 28 |
+
belief_gain_weight: float = 1.0
|
| 29 |
+
visibility_gain_weight: float = 0.75
|
| 30 |
+
clearance_weight: float = 0.75
|
| 31 |
+
occluder_contact_weight: float = 0.5
|
| 32 |
+
grasp_affordance_weight: float = 0.75
|
| 33 |
+
support_stability_weight: float = 0.5
|
| 34 |
+
residual_weight: float = 0.5
|
| 35 |
|
| 36 |
|
| 37 |
class RevealPlanner(nn.Module):
|
|
|
|
| 210 |
"best_indices": best_idx,
|
| 211 |
"best_chunk": candidate_chunks[batch_indices, best_idx],
|
| 212 |
}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class StructuredElasticUtility(nn.Module):
|
| 216 |
+
def __init__(self, config: PlannerConfig) -> None:
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.config = config
|
| 219 |
+
|
| 220 |
+
def _field_mean(self, tensor: Tensor) -> Tensor:
|
| 221 |
+
if tensor.ndim == 6:
|
| 222 |
+
return tensor.mean(dim=(-1, -2, -3))
|
| 223 |
+
if tensor.ndim == 5:
|
| 224 |
+
return tensor.mean(dim=(-1, -2))
|
| 225 |
+
if tensor.ndim == 4:
|
| 226 |
+
return tensor.mean(dim=(-1, -2))
|
| 227 |
+
return tensor
|
| 228 |
+
|
| 229 |
+
def _initial_scalar(self, state: dict[str, Tensor], key: str) -> Tensor:
|
| 230 |
+
value = state[key]
|
| 231 |
+
if value.ndim >= 4:
|
| 232 |
+
return value.mean(dim=tuple(range(1, value.ndim)))
|
| 233 |
+
if value.ndim == 3:
|
| 234 |
+
return value.mean(dim=(-1, -2))
|
| 235 |
+
if value.ndim == 2:
|
| 236 |
+
return value.mean(dim=-1)
|
| 237 |
+
return value
|
| 238 |
+
|
| 239 |
+
def forward(
|
| 240 |
+
self,
|
| 241 |
+
initial_state: dict[str, Tensor],
|
| 242 |
+
rollout_state: dict[str, Tensor],
|
| 243 |
+
candidate_chunks: Tensor,
|
| 244 |
+
) -> dict[str, Tensor]:
|
| 245 |
+
initial_belief = self._initial_scalar(initial_state, "target_belief_field").unsqueeze(1)
|
| 246 |
+
initial_visibility = self._initial_scalar(initial_state, "visibility_field").unsqueeze(1)
|
| 247 |
+
belief_future = self._field_mean(rollout_state["target_belief_field"]).mean(dim=-1)
|
| 248 |
+
visibility_future = self._field_mean(rollout_state["visibility_field"]).mean(dim=-1)
|
| 249 |
+
clearance = self._field_mean(rollout_state["clearance_field"]).mean(dim=-1)
|
| 250 |
+
occluder_contact = self._field_mean(rollout_state["occluder_contact_field"]).mean(dim=-1)
|
| 251 |
+
grasp_affordance = self._field_mean(rollout_state["grasp_affordance_field"]).mean(dim=-1)
|
| 252 |
+
support_stability = torch.sigmoid(self._field_mean(rollout_state["support_stability_field"])).mean(dim=-1)
|
| 253 |
+
persistence = self._field_mean(rollout_state["persistence_field"]).mean(dim=-1)
|
| 254 |
+
reocclusion = self._field_mean(rollout_state["reocclusion_field"]).mean(dim=-1)
|
| 255 |
+
disturbance = self._field_mean(rollout_state["disturbance_field"]).mean(dim=-1)
|
| 256 |
+
access_quality = torch.sigmoid(self._field_mean(rollout_state["access_field"])).mean(dim=-1)
|
| 257 |
+
retrieve_progress = torch.sigmoid(candidate_chunks[:, :, :, -1]).mean(dim=-1)
|
| 258 |
+
utility = (
|
| 259 |
+
self.config.belief_gain_weight * (belief_future - initial_belief)
|
| 260 |
+
+ self.config.visibility_gain_weight * (visibility_future - initial_visibility)
|
| 261 |
+
+ self.config.clearance_weight * clearance
|
| 262 |
+
+ self.config.occluder_contact_weight * occluder_contact
|
| 263 |
+
+ self.config.grasp_affordance_weight * grasp_affordance
|
| 264 |
+
+ self.config.persistence_weight * persistence
|
| 265 |
+
+ self.config.support_stability_weight * support_stability
|
| 266 |
+
+ self.config.corridor_weight * access_quality
|
| 267 |
+
+ self.config.task_progress_weight * retrieve_progress
|
| 268 |
+
- self.config.reocclusion_weight * reocclusion
|
| 269 |
+
- self.config.disturbance_weight * disturbance
|
| 270 |
+
- self.config.visibility_weight * (1.0 - visibility_future)
|
| 271 |
+
)
|
| 272 |
+
return {
|
| 273 |
+
"belief_gain": belief_future - initial_belief,
|
| 274 |
+
"visibility_gain": visibility_future - initial_visibility,
|
| 275 |
+
"clearance": clearance,
|
| 276 |
+
"occluder_contact_quality": occluder_contact,
|
| 277 |
+
"grasp_affordance": grasp_affordance,
|
| 278 |
+
"persistence": persistence,
|
| 279 |
+
"support_stability": support_stability,
|
| 280 |
+
"reocclusion_penalty": reocclusion,
|
| 281 |
+
"disturbance_penalty": disturbance,
|
| 282 |
+
"access_quality": access_quality,
|
| 283 |
+
"task_progress": retrieve_progress,
|
| 284 |
+
"utility_structured": utility,
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class ResidualPlannerScorer(nn.Module):
|
| 289 |
+
def __init__(self, config: PlannerConfig) -> None:
|
| 290 |
+
super().__init__()
|
| 291 |
+
feature_dim = (config.action_dim * 2) + 11
|
| 292 |
+
self.trunk = nn.Sequential(
|
| 293 |
+
nn.LayerNorm(feature_dim),
|
| 294 |
+
nn.Linear(feature_dim, config.hidden_dim),
|
| 295 |
+
nn.GELU(),
|
| 296 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 297 |
+
nn.GELU(),
|
| 298 |
+
)
|
| 299 |
+
self.success_head = nn.Linear(config.hidden_dim, 1)
|
| 300 |
+
self.risk_head = nn.Linear(config.hidden_dim, 1)
|
| 301 |
+
self.residual_head = nn.Linear(config.hidden_dim, 1)
|
| 302 |
+
|
| 303 |
+
def forward(
|
| 304 |
+
self,
|
| 305 |
+
candidate_chunks: Tensor,
|
| 306 |
+
structured: dict[str, Tensor],
|
| 307 |
+
proposal_logits: Tensor | None = None,
|
| 308 |
+
) -> dict[str, Tensor]:
|
| 309 |
+
candidate_mean = candidate_chunks.mean(dim=2)
|
| 310 |
+
candidate_terminal = candidate_chunks[:, :, -1]
|
| 311 |
+
components = torch.stack(
|
| 312 |
+
[
|
| 313 |
+
structured["belief_gain"],
|
| 314 |
+
structured["visibility_gain"],
|
| 315 |
+
structured["clearance"],
|
| 316 |
+
structured["occluder_contact_quality"],
|
| 317 |
+
structured["grasp_affordance"],
|
| 318 |
+
structured["persistence"],
|
| 319 |
+
structured["support_stability"],
|
| 320 |
+
structured["reocclusion_penalty"],
|
| 321 |
+
structured["disturbance_penalty"],
|
| 322 |
+
structured["access_quality"],
|
| 323 |
+
structured["task_progress"],
|
| 324 |
+
],
|
| 325 |
+
dim=-1,
|
| 326 |
+
)
|
| 327 |
+
features = torch.cat([candidate_mean, candidate_terminal, components], dim=-1)
|
| 328 |
+
hidden = self.trunk(features)
|
| 329 |
+
success_logits = self.success_head(hidden).squeeze(-1)
|
| 330 |
+
risk_values = torch.sigmoid(self.risk_head(hidden)).squeeze(-1)
|
| 331 |
+
residual = self.residual_head(hidden).squeeze(-1)
|
| 332 |
+
if proposal_logits is not None and proposal_logits.shape == residual.shape:
|
| 333 |
+
residual = residual + 0.25 * proposal_logits.sigmoid()
|
| 334 |
+
return {
|
| 335 |
+
"planner_hidden": hidden,
|
| 336 |
+
"success_logits": success_logits,
|
| 337 |
+
"risk_values": risk_values,
|
| 338 |
+
"utility_residual": residual,
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class CascadePlanner(nn.Module):
|
| 343 |
+
def __init__(self, config: PlannerConfig) -> None:
|
| 344 |
+
super().__init__()
|
| 345 |
+
self.config = config
|
| 346 |
+
self.structured = StructuredElasticUtility(config)
|
| 347 |
+
self.residual = ResidualPlannerScorer(config)
|
| 348 |
+
|
| 349 |
+
def shortlist(self, proposal_logits: Tensor | None, candidate_chunks: Tensor) -> Tensor:
|
| 350 |
+
batch_size, num_candidates = candidate_chunks.shape[:2]
|
| 351 |
+
top_k = min(max(1, self.config.top_k), num_candidates)
|
| 352 |
+
if proposal_logits is None:
|
| 353 |
+
cheap_scores = -candidate_chunks.square().mean(dim=(-1, -2))
|
| 354 |
+
else:
|
| 355 |
+
cheap_scores = proposal_logits
|
| 356 |
+
return cheap_scores.topk(top_k, dim=-1).indices
|
| 357 |
+
|
| 358 |
+
def select_best(
|
| 359 |
+
self,
|
| 360 |
+
initial_state: dict[str, Tensor],
|
| 361 |
+
candidate_chunks: Tensor,
|
| 362 |
+
rollout_state: dict[str, Tensor],
|
| 363 |
+
proposal_logits: Tensor | None = None,
|
| 364 |
+
candidate_indices: Tensor | None = None,
|
| 365 |
+
) -> dict[str, Tensor]:
|
| 366 |
+
structured = self.structured(
|
| 367 |
+
initial_state=initial_state,
|
| 368 |
+
rollout_state=rollout_state,
|
| 369 |
+
candidate_chunks=candidate_chunks,
|
| 370 |
+
)
|
| 371 |
+
residual = self.residual(
|
| 372 |
+
candidate_chunks=candidate_chunks,
|
| 373 |
+
structured=structured,
|
| 374 |
+
proposal_logits=proposal_logits,
|
| 375 |
+
)
|
| 376 |
+
utility_total = structured["utility_structured"] + self.config.residual_weight * residual["utility_residual"]
|
| 377 |
+
utility_total = utility_total + residual["success_logits"].sigmoid() - residual["risk_values"]
|
| 378 |
+
best_local = utility_total.argmax(dim=-1)
|
| 379 |
+
batch_indices = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
|
| 380 |
+
if candidate_indices is None:
|
| 381 |
+
best_indices = best_local
|
| 382 |
+
else:
|
| 383 |
+
best_indices = candidate_indices[batch_indices, best_local]
|
| 384 |
+
return {
|
| 385 |
+
**structured,
|
| 386 |
+
**residual,
|
| 387 |
+
"utility_total": utility_total,
|
| 388 |
+
"utility_scores": utility_total,
|
| 389 |
+
"best_indices": best_indices,
|
| 390 |
+
"best_chunk": candidate_chunks[batch_indices, best_local],
|
| 391 |
+
"ranking_diagnostics": {
|
| 392 |
+
"topk_indices": candidate_indices if candidate_indices is not None else best_local.unsqueeze(-1),
|
| 393 |
+
"best_local_indices": best_local,
|
| 394 |
+
},
|
| 395 |
+
}
|
code/reveal_vla_bimanual/models/policy.py
CHANGED
|
@@ -6,13 +6,28 @@ from typing import Sequence
|
|
| 6 |
import torch
|
| 7 |
from torch import Tensor, nn
|
| 8 |
|
| 9 |
-
from models.action_decoder import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
|
| 11 |
from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
|
| 12 |
-
from models.observation_memory import
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
@dataclass
|
|
@@ -351,3 +366,302 @@ class InteractionBimanualPolicy(BackboneOnlyPolicy):
|
|
| 351 |
outputs["planner_scores"] = selected["utility_scores"]
|
| 352 |
outputs["best_candidate_indices"] = selected["best_indices"]
|
| 353 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import torch
|
| 7 |
from torch import Tensor, nn
|
| 8 |
|
| 9 |
+
from models.action_decoder import (
|
| 10 |
+
ACTBimanualChunkDecoder,
|
| 11 |
+
ChunkDecoderConfig,
|
| 12 |
+
InteractionChunkDecoder,
|
| 13 |
+
SymmetricCoordinatedChunkDecoder,
|
| 14 |
+
)
|
| 15 |
from models.backbones import FrozenVLBackbone, FrozenVLBackboneConfig
|
| 16 |
from models.multiview_fusion import MultiViewFusion, MultiViewFusionConfig
|
| 17 |
+
from models.observation_memory import (
|
| 18 |
+
DualObservationMemory,
|
| 19 |
+
InteractionObservationMemory,
|
| 20 |
+
ObservationMemory,
|
| 21 |
+
ObservationMemoryConfig,
|
| 22 |
+
)
|
| 23 |
+
from models.planner import CascadePlanner, InteractionPlanner, PlannerConfig, RevealPlanner
|
| 24 |
+
from models.reveal_head import (
|
| 25 |
+
ElasticOcclusionStateHead,
|
| 26 |
+
InteractionStateHead,
|
| 27 |
+
RevealHeadConfig,
|
| 28 |
+
RevealStateHead,
|
| 29 |
+
)
|
| 30 |
+
from models.world_model import ElasticOcclusionWorldModel, InteractionRolloutModel, RevealWM, RevealWMConfig
|
| 31 |
|
| 32 |
|
| 33 |
@dataclass
|
|
|
|
| 366 |
outputs["planner_scores"] = selected["utility_scores"]
|
| 367 |
outputs["best_candidate_indices"] = selected["best_indices"]
|
| 368 |
return outputs
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class ElasticRevealBimanualPolicy(BackboneOnlyPolicy):
|
| 372 |
+
def __init__(self, config: PolicyConfig) -> None:
|
| 373 |
+
super().__init__(config)
|
| 374 |
+
self.memory = DualObservationMemory(config.memory)
|
| 375 |
+
self.decoder = SymmetricCoordinatedChunkDecoder(config.decoder)
|
| 376 |
+
self.elastic_state_head = ElasticOcclusionStateHead(config.reveal_head)
|
| 377 |
+
self.world_model = ElasticOcclusionWorldModel(config.world_model)
|
| 378 |
+
self.planner = CascadePlanner(config.planner)
|
| 379 |
+
|
| 380 |
+
def _encode_scene_with_optional_depth(
|
| 381 |
+
self,
|
| 382 |
+
images: Tensor,
|
| 383 |
+
proprio: Tensor,
|
| 384 |
+
texts: Sequence[str] | None = None,
|
| 385 |
+
language_tokens: dict[str, Tensor] | None = None,
|
| 386 |
+
depths: Tensor | None = None,
|
| 387 |
+
depth_valid: Tensor | None = None,
|
| 388 |
+
camera_intrinsics: Tensor | None = None,
|
| 389 |
+
camera_extrinsics: Tensor | None = None,
|
| 390 |
+
use_depth: bool = True,
|
| 391 |
+
) -> dict[str, Tensor]:
|
| 392 |
+
encoded = self.backbone.encode_images(
|
| 393 |
+
images,
|
| 394 |
+
depths=depths if use_depth else None,
|
| 395 |
+
depth_valid=depth_valid if use_depth else None,
|
| 396 |
+
camera_intrinsics=camera_intrinsics if use_depth else None,
|
| 397 |
+
camera_extrinsics=camera_extrinsics if use_depth else None,
|
| 398 |
+
return_aux=True,
|
| 399 |
+
)
|
| 400 |
+
assert isinstance(encoded, dict)
|
| 401 |
+
text_tokens = self._encode_language(images, texts=texts, language_tokens=language_tokens)
|
| 402 |
+
fused = self.fusion(
|
| 403 |
+
image_tokens=encoded["rgb_tokens"],
|
| 404 |
+
proprio=proprio,
|
| 405 |
+
language_tokens=text_tokens,
|
| 406 |
+
depth_tokens=encoded.get("depth_tokens"),
|
| 407 |
+
camera_tokens=encoded.get("camera_tokens"),
|
| 408 |
+
return_aux=True,
|
| 409 |
+
)
|
| 410 |
+
assert isinstance(fused, dict)
|
| 411 |
+
return {
|
| 412 |
+
"scene_tokens": fused["scene_tokens"],
|
| 413 |
+
"view_summaries": fused["view_summaries"],
|
| 414 |
+
"geometry_summaries": fused["geometry_summaries"],
|
| 415 |
+
"depth_tokens": encoded.get("depth_tokens"),
|
| 416 |
+
"camera_tokens": encoded.get("camera_tokens"),
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
def _expand_language_tokens_for_history(
|
| 420 |
+
self,
|
| 421 |
+
language_tokens: dict[str, Tensor] | None,
|
| 422 |
+
history_steps: int,
|
| 423 |
+
) -> dict[str, Tensor] | None:
|
| 424 |
+
if language_tokens is None:
|
| 425 |
+
return None
|
| 426 |
+
return {
|
| 427 |
+
key: value.unsqueeze(1).expand(-1, history_steps, *value.shape[1:]).reshape(
|
| 428 |
+
value.shape[0] * history_steps, *value.shape[1:]
|
| 429 |
+
)
|
| 430 |
+
for key, value in language_tokens.items()
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
def encode_history_with_optional_depth(
|
| 434 |
+
self,
|
| 435 |
+
history_images: Tensor | None,
|
| 436 |
+
history_proprio: Tensor | None,
|
| 437 |
+
texts: Sequence[str] | None = None,
|
| 438 |
+
language_tokens: dict[str, Tensor] | None = None,
|
| 439 |
+
history_depths: Tensor | None = None,
|
| 440 |
+
history_depth_valid: Tensor | None = None,
|
| 441 |
+
camera_intrinsics: Tensor | None = None,
|
| 442 |
+
camera_extrinsics: Tensor | None = None,
|
| 443 |
+
use_depth: bool = True,
|
| 444 |
+
) -> Tensor | None:
|
| 445 |
+
if history_images is None or history_proprio is None or history_images.numel() == 0:
|
| 446 |
+
return None
|
| 447 |
+
batch_size, history_steps = history_images.shape[:2]
|
| 448 |
+
flat_images = history_images.reshape(batch_size * history_steps, *history_images.shape[2:])
|
| 449 |
+
flat_proprio = history_proprio.reshape(batch_size * history_steps, history_proprio.shape[-1])
|
| 450 |
+
flat_depths = None
|
| 451 |
+
flat_depth_valid = None
|
| 452 |
+
if history_depths is not None and history_depths.numel() > 0:
|
| 453 |
+
flat_depths = history_depths.reshape(batch_size * history_steps, *history_depths.shape[2:])
|
| 454 |
+
if history_depth_valid is not None and history_depth_valid.numel() > 0:
|
| 455 |
+
flat_depth_valid = history_depth_valid.reshape(batch_size * history_steps, *history_depth_valid.shape[2:])
|
| 456 |
+
if language_tokens is None:
|
| 457 |
+
flat_texts = [text for text in texts for _ in range(history_steps)] if texts is not None else None
|
| 458 |
+
flat_language_tokens = None
|
| 459 |
+
else:
|
| 460 |
+
flat_texts = None
|
| 461 |
+
flat_language_tokens = self._expand_language_tokens_for_history(language_tokens, history_steps)
|
| 462 |
+
history_scene = self._encode_scene_with_optional_depth(
|
| 463 |
+
images=flat_images,
|
| 464 |
+
proprio=flat_proprio,
|
| 465 |
+
texts=flat_texts,
|
| 466 |
+
language_tokens=flat_language_tokens,
|
| 467 |
+
depths=flat_depths,
|
| 468 |
+
depth_valid=flat_depth_valid,
|
| 469 |
+
camera_intrinsics=None,
|
| 470 |
+
camera_extrinsics=None,
|
| 471 |
+
use_depth=use_depth,
|
| 472 |
+
)["scene_tokens"]
|
| 473 |
+
return history_scene.view(batch_size, history_steps, history_scene.shape[1], history_scene.shape[2])
|
| 474 |
+
|
| 475 |
+
def _tile_tensor(self, value: Tensor, num_candidates: int) -> Tensor:
|
| 476 |
+
return value.unsqueeze(1).expand(-1, num_candidates, *value.shape[1:]).reshape(
|
| 477 |
+
value.shape[0] * num_candidates,
|
| 478 |
+
*value.shape[1:],
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
def _tile_state(self, state: dict[str, Tensor], num_candidates: int) -> dict[str, Tensor]:
|
| 482 |
+
tiled: dict[str, Tensor] = {}
|
| 483 |
+
for key, value in state.items():
|
| 484 |
+
if isinstance(value, Tensor):
|
| 485 |
+
tiled[key] = self._tile_tensor(value, num_candidates)
|
| 486 |
+
return tiled
|
| 487 |
+
|
| 488 |
+
def forward(
|
| 489 |
+
self,
|
| 490 |
+
images: Tensor,
|
| 491 |
+
proprio: Tensor,
|
| 492 |
+
texts: Sequence[str] | None = None,
|
| 493 |
+
language_tokens: dict[str, Tensor] | None = None,
|
| 494 |
+
history_images: Tensor | None = None,
|
| 495 |
+
history_proprio: Tensor | None = None,
|
| 496 |
+
history_actions: Tensor | None = None,
|
| 497 |
+
plan: bool = True,
|
| 498 |
+
support_mode_conditioning: bool = True,
|
| 499 |
+
candidate_chunks_override: Tensor | None = None,
|
| 500 |
+
use_depth: bool = True,
|
| 501 |
+
use_world_model: bool = True,
|
| 502 |
+
use_planner: bool = True,
|
| 503 |
+
use_role_tokens: bool = True,
|
| 504 |
+
history_steps_override: int | None = None,
|
| 505 |
+
depths: Tensor | None = None,
|
| 506 |
+
depth_valid: Tensor | None = None,
|
| 507 |
+
camera_intrinsics: Tensor | None = None,
|
| 508 |
+
camera_extrinsics: Tensor | None = None,
|
| 509 |
+
history_depths: Tensor | None = None,
|
| 510 |
+
history_depth_valid: Tensor | None = None,
|
| 511 |
+
compute_equivariance_probe: bool = False,
|
| 512 |
+
) -> dict[str, Tensor]:
|
| 513 |
+
scene_output = self._encode_scene_with_optional_depth(
|
| 514 |
+
images=images,
|
| 515 |
+
proprio=proprio,
|
| 516 |
+
texts=texts,
|
| 517 |
+
language_tokens=language_tokens,
|
| 518 |
+
depths=depths,
|
| 519 |
+
depth_valid=depth_valid,
|
| 520 |
+
camera_intrinsics=camera_intrinsics,
|
| 521 |
+
camera_extrinsics=camera_extrinsics,
|
| 522 |
+
use_depth=use_depth,
|
| 523 |
+
)
|
| 524 |
+
scene_tokens = scene_output["scene_tokens"]
|
| 525 |
+
history_scene_tokens = self.encode_history_with_optional_depth(
|
| 526 |
+
history_images=history_images,
|
| 527 |
+
history_proprio=history_proprio,
|
| 528 |
+
texts=texts,
|
| 529 |
+
language_tokens=language_tokens,
|
| 530 |
+
history_depths=history_depths,
|
| 531 |
+
history_depth_valid=history_depth_valid,
|
| 532 |
+
camera_intrinsics=camera_intrinsics,
|
| 533 |
+
camera_extrinsics=camera_extrinsics,
|
| 534 |
+
use_depth=use_depth,
|
| 535 |
+
)
|
| 536 |
+
if history_steps_override is not None and history_scene_tokens is not None and history_scene_tokens.numel() > 0:
|
| 537 |
+
history_scene_tokens = history_scene_tokens[:, -history_steps_override:]
|
| 538 |
+
if history_actions is not None and history_actions.numel() > 0:
|
| 539 |
+
history_actions = history_actions[:, -history_steps_override:]
|
| 540 |
+
memory_output = self.memory(
|
| 541 |
+
scene_tokens,
|
| 542 |
+
history_scene_tokens=history_scene_tokens,
|
| 543 |
+
history_actions=history_actions,
|
| 544 |
+
)
|
| 545 |
+
elastic_state = self.elastic_state_head(
|
| 546 |
+
scene_tokens,
|
| 547 |
+
memory_tokens=memory_output["memory_tokens"],
|
| 548 |
+
)
|
| 549 |
+
elastic_state["memory_tokens"] = memory_output["memory_tokens"]
|
| 550 |
+
elastic_state["memory_token"] = memory_output["memory_token"]
|
| 551 |
+
elastic_state["scene_memory_tokens"] = memory_output["scene_memory_tokens"]
|
| 552 |
+
elastic_state["belief_memory_tokens"] = memory_output["belief_memory_tokens"]
|
| 553 |
+
if not use_role_tokens:
|
| 554 |
+
elastic_state = dict(elastic_state)
|
| 555 |
+
elastic_state["arm_role_logits"] = torch.zeros_like(elastic_state["arm_role_logits"])
|
| 556 |
+
|
| 557 |
+
decoded = self.decoder(
|
| 558 |
+
scene_tokens,
|
| 559 |
+
interaction_state=elastic_state,
|
| 560 |
+
memory_tokens=memory_output["memory_tokens"],
|
| 561 |
+
compute_equivariance_probe=compute_equivariance_probe,
|
| 562 |
+
)
|
| 563 |
+
outputs = {
|
| 564 |
+
**decoded,
|
| 565 |
+
"scene_tokens": scene_tokens,
|
| 566 |
+
"history_scene_tokens": history_scene_tokens,
|
| 567 |
+
"memory_output": memory_output,
|
| 568 |
+
"memory_uncertainty": memory_output["memory_uncertainty"],
|
| 569 |
+
"interaction_state": elastic_state,
|
| 570 |
+
"reveal_state": elastic_state,
|
| 571 |
+
"view_summaries": scene_output["view_summaries"],
|
| 572 |
+
"geometry_summaries": scene_output["geometry_summaries"],
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
candidate_chunks = candidate_chunks_override
|
| 576 |
+
proposal_logits = outputs.get("proposal_logits")
|
| 577 |
+
if candidate_chunks is None:
|
| 578 |
+
candidate_chunks = self.decoder.sample_candidates(
|
| 579 |
+
outputs["action_mean"],
|
| 580 |
+
outputs["action_log_std"],
|
| 581 |
+
num_candidates=self.config.decoder.num_candidates,
|
| 582 |
+
proposal_candidates=outputs.get("proposal_candidates"),
|
| 583 |
+
)
|
| 584 |
+
else:
|
| 585 |
+
proposal_logits = None
|
| 586 |
+
outputs["candidate_chunks"] = candidate_chunks
|
| 587 |
+
|
| 588 |
+
if not plan or not use_planner:
|
| 589 |
+
outputs["planned_chunk"] = outputs["action_mean"]
|
| 590 |
+
outputs["planned_rollout"] = {}
|
| 591 |
+
outputs["planner_success_logits"] = torch.zeros(
|
| 592 |
+
candidate_chunks.shape[:2],
|
| 593 |
+
device=candidate_chunks.device,
|
| 594 |
+
dtype=candidate_chunks.dtype,
|
| 595 |
+
)
|
| 596 |
+
outputs["planner_risk_values"] = torch.zeros_like(outputs["planner_success_logits"])
|
| 597 |
+
outputs["planner_scores"] = torch.zeros_like(outputs["planner_success_logits"])
|
| 598 |
+
outputs["best_candidate_indices"] = torch.zeros(
|
| 599 |
+
candidate_chunks.shape[0],
|
| 600 |
+
dtype=torch.long,
|
| 601 |
+
device=candidate_chunks.device,
|
| 602 |
+
)
|
| 603 |
+
return outputs
|
| 604 |
+
|
| 605 |
+
shortlist_indices = self.planner.shortlist(proposal_logits=proposal_logits, candidate_chunks=candidate_chunks)
|
| 606 |
+
outputs["planner_topk_indices"] = shortlist_indices
|
| 607 |
+
batch_size = candidate_chunks.shape[0]
|
| 608 |
+
batch_indices = torch.arange(batch_size, device=candidate_chunks.device).unsqueeze(-1)
|
| 609 |
+
topk_candidates = candidate_chunks[batch_indices, shortlist_indices]
|
| 610 |
+
outputs["planner_topk_candidates"] = topk_candidates
|
| 611 |
+
if proposal_logits is not None:
|
| 612 |
+
topk_proposal_logits = proposal_logits.gather(1, shortlist_indices)
|
| 613 |
+
else:
|
| 614 |
+
topk_proposal_logits = None
|
| 615 |
+
|
| 616 |
+
if not use_world_model:
|
| 617 |
+
score_source = topk_proposal_logits if topk_proposal_logits is not None else -topk_candidates.square().mean(dim=(-1, -2))
|
| 618 |
+
best_local = score_source.argmax(dim=-1)
|
| 619 |
+
best_indices = shortlist_indices[torch.arange(batch_size, device=best_local.device), best_local]
|
| 620 |
+
outputs["planned_chunk"] = candidate_chunks[torch.arange(batch_size, device=best_local.device), best_indices]
|
| 621 |
+
outputs["planned_rollout"] = {}
|
| 622 |
+
outputs["planner_success_logits"] = torch.zeros_like(score_source)
|
| 623 |
+
outputs["planner_risk_values"] = torch.zeros_like(score_source)
|
| 624 |
+
outputs["planner_scores"] = score_source
|
| 625 |
+
outputs["best_candidate_indices"] = best_indices
|
| 626 |
+
outputs["utility_structured"] = score_source
|
| 627 |
+
outputs["utility_residual"] = torch.zeros_like(score_source)
|
| 628 |
+
outputs["utility_total"] = score_source
|
| 629 |
+
return outputs
|
| 630 |
+
|
| 631 |
+
num_topk = topk_candidates.shape[1]
|
| 632 |
+
flat_chunks = topk_candidates.view(batch_size * num_topk, topk_candidates.shape[2], topk_candidates.shape[3])
|
| 633 |
+
tiled_scene = self._tile_tensor(scene_tokens, num_topk)
|
| 634 |
+
planning_state = elastic_state
|
| 635 |
+
if not support_mode_conditioning:
|
| 636 |
+
planning_state = dict(elastic_state)
|
| 637 |
+
planning_state["support_mode_logits"] = torch.zeros_like(elastic_state["support_mode_logits"])
|
| 638 |
+
tiled_state = self._tile_state(planning_state, num_topk)
|
| 639 |
+
rollout = self.world_model(
|
| 640 |
+
scene_tokens=tiled_scene,
|
| 641 |
+
interaction_state=tiled_state,
|
| 642 |
+
action_chunk=flat_chunks,
|
| 643 |
+
memory_tokens=self._tile_tensor(memory_output["memory_tokens"], num_topk),
|
| 644 |
+
scene_memory_tokens=self._tile_tensor(memory_output["scene_memory_tokens"], num_topk),
|
| 645 |
+
belief_memory_tokens=self._tile_tensor(memory_output["belief_memory_tokens"], num_topk),
|
| 646 |
+
)
|
| 647 |
+
reshaped_rollout = {
|
| 648 |
+
key: value.view(batch_size, num_topk, *value.shape[1:]) for key, value in rollout.items()
|
| 649 |
+
}
|
| 650 |
+
selected = self.planner.select_best(
|
| 651 |
+
initial_state=elastic_state,
|
| 652 |
+
candidate_chunks=topk_candidates,
|
| 653 |
+
rollout_state=reshaped_rollout,
|
| 654 |
+
proposal_logits=topk_proposal_logits,
|
| 655 |
+
candidate_indices=shortlist_indices,
|
| 656 |
+
)
|
| 657 |
+
outputs["planned_rollout"] = reshaped_rollout
|
| 658 |
+
outputs["planned_chunk"] = selected["best_chunk"]
|
| 659 |
+
outputs["planner_success_logits"] = selected["success_logits"]
|
| 660 |
+
outputs["planner_risk_values"] = selected["risk_values"]
|
| 661 |
+
outputs["planner_scores"] = selected["utility_total"]
|
| 662 |
+
outputs["best_candidate_indices"] = selected["best_indices"]
|
| 663 |
+
outputs["utility_structured"] = selected["utility_structured"]
|
| 664 |
+
outputs["utility_residual"] = selected["utility_residual"]
|
| 665 |
+
outputs["utility_total"] = selected["utility_total"]
|
| 666 |
+
outputs["ranking_diagnostics"] = selected["ranking_diagnostics"]
|
| 667 |
+
return outputs
|
code/reveal_vla_bimanual/models/reveal_head.py
CHANGED
|
@@ -317,3 +317,245 @@ class InteractionStateHead(nn.Module):
|
|
| 317 |
scene_tokens=scene_tokens,
|
| 318 |
memory_tokens=memory_tokens,
|
| 319 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
scene_tokens=scene_tokens,
|
| 318 |
memory_tokens=memory_tokens,
|
| 319 |
)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class ElasticOcclusionFieldDecoder(nn.Module):
|
| 323 |
+
def __init__(self, config: RevealHeadConfig) -> None:
|
| 324 |
+
super().__init__()
|
| 325 |
+
self.config = config
|
| 326 |
+
self.field_queries = nn.Parameter(
|
| 327 |
+
torch.randn(config.field_size * config.field_size, config.hidden_dim) * 0.02
|
| 328 |
+
)
|
| 329 |
+
self.field_attention = nn.MultiheadAttention(
|
| 330 |
+
embed_dim=config.hidden_dim,
|
| 331 |
+
num_heads=config.num_heads,
|
| 332 |
+
batch_first=True,
|
| 333 |
+
)
|
| 334 |
+
self.field_mlp = nn.Sequential(
|
| 335 |
+
nn.LayerNorm(config.hidden_dim),
|
| 336 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 337 |
+
nn.GELU(),
|
| 338 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 339 |
+
)
|
| 340 |
+
summary_dim = config.hidden_dim * 4
|
| 341 |
+
self.summary_proj = nn.Sequential(
|
| 342 |
+
nn.LayerNorm(summary_dim),
|
| 343 |
+
nn.Linear(summary_dim, config.hidden_dim),
|
| 344 |
+
nn.GELU(),
|
| 345 |
+
)
|
| 346 |
+
self.phase_head = nn.Sequential(
|
| 347 |
+
nn.LayerNorm(summary_dim),
|
| 348 |
+
nn.Linear(summary_dim, config.hidden_dim),
|
| 349 |
+
nn.GELU(),
|
| 350 |
+
nn.Linear(config.hidden_dim, config.num_phases),
|
| 351 |
+
)
|
| 352 |
+
self.arm_role_head = nn.Sequential(
|
| 353 |
+
nn.LayerNorm(config.hidden_dim * 2),
|
| 354 |
+
nn.Linear(config.hidden_dim * 2, config.hidden_dim),
|
| 355 |
+
nn.GELU(),
|
| 356 |
+
nn.Linear(config.hidden_dim, config.num_arm_roles),
|
| 357 |
+
)
|
| 358 |
+
self.arm_identity = nn.Embedding(2, config.hidden_dim)
|
| 359 |
+
self.support_mode = nn.Sequential(
|
| 360 |
+
nn.LayerNorm(summary_dim),
|
| 361 |
+
nn.Linear(summary_dim, config.hidden_dim),
|
| 362 |
+
nn.GELU(),
|
| 363 |
+
nn.Linear(config.hidden_dim, config.num_support_modes),
|
| 364 |
+
)
|
| 365 |
+
self.access_field = nn.Conv2d(config.hidden_dim, config.num_support_modes, kernel_size=1)
|
| 366 |
+
self.target_belief_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
|
| 367 |
+
self.visibility_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
|
| 368 |
+
self.clearance_field = nn.Conv2d(config.hidden_dim, 2, kernel_size=1)
|
| 369 |
+
self.occluder_contact_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
|
| 370 |
+
self.grasp_affordance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
|
| 371 |
+
self.support_stability_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
|
| 372 |
+
self.persistence_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
|
| 373 |
+
self.reocclusion_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
|
| 374 |
+
self.disturbance_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
|
| 375 |
+
self.uncertainty_field = nn.Conv2d(config.hidden_dim, 1, kernel_size=1)
|
| 376 |
+
self.reocclusion_head = nn.Sequential(
|
| 377 |
+
nn.LayerNorm(summary_dim),
|
| 378 |
+
nn.Linear(summary_dim, config.hidden_dim),
|
| 379 |
+
nn.GELU(),
|
| 380 |
+
nn.Linear(config.hidden_dim, config.num_support_modes),
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
def _pool_source(self, source_tokens: Tensor | None, fallback: Tensor) -> Tensor:
|
| 384 |
+
if source_tokens is None or source_tokens.numel() == 0:
|
| 385 |
+
return fallback.new_zeros(fallback.shape)
|
| 386 |
+
return source_tokens.mean(dim=1)
|
| 387 |
+
|
| 388 |
+
def _field_mean(self, field: Tensor) -> Tensor:
|
| 389 |
+
return field.mean(dim=(-1, -2))
|
| 390 |
+
|
| 391 |
+
def _upsampled_belief(self, target_belief_field: Tensor) -> Tensor:
|
| 392 |
+
if target_belief_field.shape[-1] == self.config.belief_map_size:
|
| 393 |
+
return target_belief_field
|
| 394 |
+
return F.interpolate(
|
| 395 |
+
target_belief_field,
|
| 396 |
+
size=(self.config.belief_map_size, self.config.belief_map_size),
|
| 397 |
+
mode="bilinear",
|
| 398 |
+
align_corners=False,
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
def forward(
|
| 402 |
+
self,
|
| 403 |
+
interaction_tokens: Tensor,
|
| 404 |
+
scene_tokens: Tensor | None = None,
|
| 405 |
+
memory_tokens: Tensor | None = None,
|
| 406 |
+
) -> dict[str, Tensor]:
|
| 407 |
+
batch_size = interaction_tokens.shape[0]
|
| 408 |
+
pooled_interaction = interaction_tokens.mean(dim=1)
|
| 409 |
+
pooled_scene = self._pool_source(scene_tokens, pooled_interaction)
|
| 410 |
+
pooled_memory = self._pool_source(memory_tokens, pooled_interaction)
|
| 411 |
+
|
| 412 |
+
field_queries = self.field_queries.unsqueeze(0).expand(batch_size, -1, -1)
|
| 413 |
+
source_tokens = interaction_tokens
|
| 414 |
+
if scene_tokens is not None:
|
| 415 |
+
source_tokens = torch.cat([source_tokens, scene_tokens], dim=1)
|
| 416 |
+
if memory_tokens is not None:
|
| 417 |
+
source_tokens = torch.cat([source_tokens, memory_tokens], dim=1)
|
| 418 |
+
field_tokens, _ = self.field_attention(field_queries, source_tokens, source_tokens)
|
| 419 |
+
field_tokens = field_tokens + self.field_mlp(field_tokens)
|
| 420 |
+
|
| 421 |
+
side = self.config.field_size
|
| 422 |
+
grid = field_tokens.transpose(1, 2).reshape(batch_size, self.config.hidden_dim, side, side)
|
| 423 |
+
pooled_field = field_tokens.mean(dim=1)
|
| 424 |
+
summary_input = torch.cat([pooled_interaction, pooled_field, pooled_scene, pooled_memory], dim=-1)
|
| 425 |
+
latent_summary = self.summary_proj(summary_input)
|
| 426 |
+
|
| 427 |
+
access_field = self.access_field(grid)
|
| 428 |
+
target_belief_field = self.target_belief_field(grid)
|
| 429 |
+
visibility_field = self.visibility_field(grid)
|
| 430 |
+
clearance_field = self.clearance_field(grid)
|
| 431 |
+
occluder_contact_field = self.occluder_contact_field(grid)
|
| 432 |
+
grasp_affordance_field = self.grasp_affordance_field(grid)
|
| 433 |
+
support_stability_field = self.support_stability_field(grid)
|
| 434 |
+
persistence_field = torch.sigmoid(self.persistence_field(grid))
|
| 435 |
+
reocclusion_field = torch.sigmoid(self.reocclusion_field(grid))
|
| 436 |
+
disturbance_field = torch.sigmoid(self.disturbance_field(grid))
|
| 437 |
+
uncertainty_field = F.softplus(self.uncertainty_field(grid))
|
| 438 |
+
|
| 439 |
+
support_stability_prob = torch.sigmoid(support_stability_field)
|
| 440 |
+
risk_field = torch.sigmoid(
|
| 441 |
+
disturbance_field
|
| 442 |
+
+ 0.75 * reocclusion_field
|
| 443 |
+
+ 0.5 * (1.0 - support_stability_prob)
|
| 444 |
+
+ 0.25 * uncertainty_field
|
| 445 |
+
)
|
| 446 |
+
corridor_source = access_field.amax(dim=-2)
|
| 447 |
+
corridor_logits = F.interpolate(
|
| 448 |
+
corridor_source,
|
| 449 |
+
size=self.config.num_approach_templates,
|
| 450 |
+
mode="linear",
|
| 451 |
+
align_corners=False,
|
| 452 |
+
)
|
| 453 |
+
access_prob = torch.sigmoid(access_field)
|
| 454 |
+
weighted_persistence = (persistence_field.expand_as(access_prob) * access_prob).sum(dim=(-1, -2))
|
| 455 |
+
access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
|
| 456 |
+
persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
|
| 457 |
+
disturbance_cost = disturbance_field.mean(dim=(-1, -2)).squeeze(1)
|
| 458 |
+
|
| 459 |
+
arm_identity = self.arm_identity.weight.unsqueeze(0).expand(batch_size, -1, -1)
|
| 460 |
+
arm_tokens = pooled_interaction.unsqueeze(1).expand(-1, 2, -1) + arm_identity
|
| 461 |
+
arm_role_input = torch.cat(
|
| 462 |
+
[arm_tokens, latent_summary.unsqueeze(1).expand(-1, 2, -1)],
|
| 463 |
+
dim=-1,
|
| 464 |
+
)
|
| 465 |
+
arm_role_logits = self.arm_role_head(arm_role_input)
|
| 466 |
+
target_belief_map = self._upsampled_belief(target_belief_field)
|
| 467 |
+
compact_components = [
|
| 468 |
+
target_belief_field.mean(dim=(-1, -2)).squeeze(1),
|
| 469 |
+
visibility_field.mean(dim=(-1, -2)).squeeze(1),
|
| 470 |
+
clearance_field.mean(dim=(-1, -2)).mean(dim=1),
|
| 471 |
+
occluder_contact_field.mean(dim=(-1, -2)).squeeze(1),
|
| 472 |
+
grasp_affordance_field.mean(dim=(-1, -2)).squeeze(1),
|
| 473 |
+
support_stability_prob.mean(dim=(-1, -2)).squeeze(1),
|
| 474 |
+
persistence_field.mean(dim=(-1, -2)).squeeze(1),
|
| 475 |
+
reocclusion_field.mean(dim=(-1, -2)).squeeze(1),
|
| 476 |
+
disturbance_field.mean(dim=(-1, -2)).squeeze(1),
|
| 477 |
+
risk_field.mean(dim=(-1, -2)).squeeze(1),
|
| 478 |
+
uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
|
| 479 |
+
access_prob.mean(dim=(-1, -2)).transpose(0, 1).transpose(0, 1),
|
| 480 |
+
self.support_mode(summary_input),
|
| 481 |
+
self.phase_head(summary_input),
|
| 482 |
+
arm_role_logits.reshape(batch_size, -1),
|
| 483 |
+
]
|
| 484 |
+
compact_state = torch.cat(
|
| 485 |
+
[component if component.ndim > 1 else component.unsqueeze(-1) for component in compact_components],
|
| 486 |
+
dim=-1,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
output = {
|
| 490 |
+
"phase_logits": self.phase_head(summary_input),
|
| 491 |
+
"arm_role_logits": arm_role_logits,
|
| 492 |
+
"target_belief_field": target_belief_field,
|
| 493 |
+
"visibility_field": visibility_field,
|
| 494 |
+
"clearance_field": clearance_field,
|
| 495 |
+
"occluder_contact_field": occluder_contact_field,
|
| 496 |
+
"grasp_affordance_field": grasp_affordance_field,
|
| 497 |
+
"support_stability_field": support_stability_field,
|
| 498 |
+
"persistence_field": persistence_field,
|
| 499 |
+
"reocclusion_field": reocclusion_field,
|
| 500 |
+
"disturbance_field": disturbance_field,
|
| 501 |
+
"risk_field": risk_field,
|
| 502 |
+
"uncertainty_field": uncertainty_field,
|
| 503 |
+
"interaction_tokens": interaction_tokens,
|
| 504 |
+
"field_tokens": field_tokens,
|
| 505 |
+
"latent_summary": latent_summary,
|
| 506 |
+
"support_mode_logits": self.support_mode(summary_input),
|
| 507 |
+
"corridor_logits": corridor_logits,
|
| 508 |
+
"persistence_horizon": persistence_horizon,
|
| 509 |
+
"disturbance_cost": disturbance_cost,
|
| 510 |
+
"belief_map": target_belief_map,
|
| 511 |
+
"reocclusion_logit": self.reocclusion_head(summary_input),
|
| 512 |
+
"persistence_uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
|
| 513 |
+
"access_field": access_field,
|
| 514 |
+
"uncertainty": uncertainty_field.mean(dim=(-1, -2)).squeeze(1),
|
| 515 |
+
"compact_state": compact_state,
|
| 516 |
+
}
|
| 517 |
+
output["target_field"] = output["target_belief_field"]
|
| 518 |
+
output["actor_feasibility_field"] = output["clearance_field"]
|
| 519 |
+
return output
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
class ElasticOcclusionStateHead(nn.Module):
|
| 523 |
+
def __init__(self, config: RevealHeadConfig) -> None:
|
| 524 |
+
super().__init__()
|
| 525 |
+
self.config = config
|
| 526 |
+
self.interaction_queries = nn.Parameter(
|
| 527 |
+
torch.randn(config.num_interaction_tokens, config.hidden_dim) * 0.02
|
| 528 |
+
)
|
| 529 |
+
self.interaction_attention = nn.MultiheadAttention(
|
| 530 |
+
embed_dim=config.hidden_dim,
|
| 531 |
+
num_heads=config.num_heads,
|
| 532 |
+
batch_first=True,
|
| 533 |
+
)
|
| 534 |
+
self.interaction_mlp = nn.Sequential(
|
| 535 |
+
nn.LayerNorm(config.hidden_dim),
|
| 536 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 537 |
+
nn.GELU(),
|
| 538 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 539 |
+
)
|
| 540 |
+
self.decoder = ElasticOcclusionFieldDecoder(config)
|
| 541 |
+
|
| 542 |
+
def forward(
|
| 543 |
+
self,
|
| 544 |
+
scene_tokens: Tensor,
|
| 545 |
+
memory_token: Tensor | None = None,
|
| 546 |
+
memory_tokens: Tensor | None = None,
|
| 547 |
+
) -> dict[str, Tensor]:
|
| 548 |
+
if memory_tokens is None:
|
| 549 |
+
memory_tokens = memory_token
|
| 550 |
+
source_tokens = scene_tokens
|
| 551 |
+
if memory_tokens is not None:
|
| 552 |
+
source_tokens = torch.cat([source_tokens, memory_tokens], dim=1)
|
| 553 |
+
batch_size = source_tokens.shape[0]
|
| 554 |
+
interaction_queries = self.interaction_queries.unsqueeze(0).expand(batch_size, -1, -1)
|
| 555 |
+
interaction_tokens, _ = self.interaction_attention(interaction_queries, source_tokens, source_tokens)
|
| 556 |
+
interaction_tokens = interaction_tokens + self.interaction_mlp(interaction_tokens)
|
| 557 |
+
return self.decoder(
|
| 558 |
+
interaction_tokens=interaction_tokens,
|
| 559 |
+
scene_tokens=scene_tokens,
|
| 560 |
+
memory_tokens=memory_tokens,
|
| 561 |
+
)
|
code/reveal_vla_bimanual/models/world_model.py
CHANGED
|
@@ -22,6 +22,8 @@ class RevealWMConfig:
|
|
| 22 |
num_interaction_tokens: int = 8
|
| 23 |
belief_map_size: int = 32
|
| 24 |
predict_belief_map: bool = True
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
class RevealWM(nn.Module):
|
|
@@ -152,3 +154,186 @@ class InteractionRolloutModel(nn.Module):
|
|
| 152 |
for key, values in outputs.items():
|
| 153 |
stacked[key] = torch.stack(values, dim=1)
|
| 154 |
return stacked
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
num_interaction_tokens: int = 8
|
| 23 |
belief_map_size: int = 32
|
| 24 |
predict_belief_map: bool = True
|
| 25 |
+
scene_bank_size: int = 2
|
| 26 |
+
belief_bank_size: int = 2
|
| 27 |
|
| 28 |
|
| 29 |
class RevealWM(nn.Module):
|
|
|
|
| 154 |
for key, values in outputs.items():
|
| 155 |
stacked[key] = torch.stack(values, dim=1)
|
| 156 |
return stacked
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class ElasticOcclusionWorldModel(nn.Module):
|
| 160 |
+
def __init__(self, config: RevealWMConfig) -> None:
|
| 161 |
+
super().__init__()
|
| 162 |
+
self.config = config
|
| 163 |
+
compact_state_dim = (
|
| 164 |
+
11
|
| 165 |
+
+ config.num_support_modes
|
| 166 |
+
+ config.num_support_modes
|
| 167 |
+
+ config.num_phases
|
| 168 |
+
+ (2 * config.num_arm_roles)
|
| 169 |
+
)
|
| 170 |
+
self.state_encoder = nn.Sequential(
|
| 171 |
+
nn.LayerNorm(compact_state_dim),
|
| 172 |
+
nn.Linear(compact_state_dim, config.hidden_dim),
|
| 173 |
+
nn.GELU(),
|
| 174 |
+
)
|
| 175 |
+
self.scene_memory_proj = nn.Sequential(
|
| 176 |
+
nn.LayerNorm(config.hidden_dim),
|
| 177 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 178 |
+
nn.GELU(),
|
| 179 |
+
)
|
| 180 |
+
self.belief_memory_proj = nn.Sequential(
|
| 181 |
+
nn.LayerNorm(config.hidden_dim),
|
| 182 |
+
nn.Linear(config.hidden_dim, config.hidden_dim),
|
| 183 |
+
nn.GELU(),
|
| 184 |
+
)
|
| 185 |
+
self.action_encoder = nn.Sequential(
|
| 186 |
+
nn.LayerNorm(config.action_dim),
|
| 187 |
+
nn.Linear(config.action_dim, config.hidden_dim),
|
| 188 |
+
nn.GELU(),
|
| 189 |
+
)
|
| 190 |
+
self.transition = nn.GRUCell(config.hidden_dim * 4, config.hidden_dim)
|
| 191 |
+
self.scene_memory_update = nn.Linear(config.hidden_dim, config.hidden_dim)
|
| 192 |
+
self.belief_memory_update = nn.Linear(config.hidden_dim, config.hidden_dim)
|
| 193 |
+
self.compact_decoder = nn.Linear(config.hidden_dim, compact_state_dim)
|
| 194 |
+
field_elements = config.field_size * config.field_size
|
| 195 |
+
self.target_belief_head = nn.Linear(config.hidden_dim, field_elements)
|
| 196 |
+
self.visibility_head = nn.Linear(config.hidden_dim, field_elements)
|
| 197 |
+
self.clearance_head = nn.Linear(config.hidden_dim, 2 * field_elements)
|
| 198 |
+
self.occluder_contact_head = nn.Linear(config.hidden_dim, field_elements)
|
| 199 |
+
self.grasp_affordance_head = nn.Linear(config.hidden_dim, field_elements)
|
| 200 |
+
self.support_stability_head = nn.Linear(config.hidden_dim, field_elements)
|
| 201 |
+
self.persistence_head = nn.Linear(config.hidden_dim, field_elements)
|
| 202 |
+
self.reocclusion_head = nn.Linear(config.hidden_dim, field_elements)
|
| 203 |
+
self.disturbance_head = nn.Linear(config.hidden_dim, field_elements)
|
| 204 |
+
self.uncertainty_head = nn.Linear(config.hidden_dim, field_elements)
|
| 205 |
+
self.access_head = nn.Linear(config.hidden_dim, config.num_support_modes * field_elements)
|
| 206 |
+
|
| 207 |
+
def _compact_from_state(self, interaction_state: dict[str, Tensor]) -> Tensor:
|
| 208 |
+
if "compact_state" in interaction_state:
|
| 209 |
+
return interaction_state["compact_state"]
|
| 210 |
+
components = [
|
| 211 |
+
interaction_state["target_belief_field"].mean(dim=(-1, -2)).squeeze(1),
|
| 212 |
+
interaction_state["visibility_field"].mean(dim=(-1, -2)).squeeze(1),
|
| 213 |
+
interaction_state["clearance_field"].mean(dim=(-1, -2)).mean(dim=1),
|
| 214 |
+
interaction_state["occluder_contact_field"].mean(dim=(-1, -2)).squeeze(1),
|
| 215 |
+
interaction_state["grasp_affordance_field"].mean(dim=(-1, -2)).squeeze(1),
|
| 216 |
+
torch.sigmoid(interaction_state["support_stability_field"]).mean(dim=(-1, -2)).squeeze(1),
|
| 217 |
+
interaction_state["persistence_field"].mean(dim=(-1, -2)).squeeze(1),
|
| 218 |
+
interaction_state["reocclusion_field"].mean(dim=(-1, -2)).squeeze(1),
|
| 219 |
+
interaction_state["disturbance_field"].mean(dim=(-1, -2)).squeeze(1),
|
| 220 |
+
interaction_state["risk_field"].mean(dim=(-1, -2)).squeeze(1),
|
| 221 |
+
interaction_state["uncertainty_field"].mean(dim=(-1, -2)).squeeze(1),
|
| 222 |
+
torch.sigmoid(interaction_state["access_field"]).mean(dim=(-1, -2)),
|
| 223 |
+
interaction_state["support_mode_logits"],
|
| 224 |
+
interaction_state["phase_logits"],
|
| 225 |
+
interaction_state["arm_role_logits"].reshape(interaction_state["arm_role_logits"].shape[0], -1),
|
| 226 |
+
]
|
| 227 |
+
return torch.cat([component if component.ndim > 1 else component.unsqueeze(-1) for component in components], dim=-1)
|
| 228 |
+
|
| 229 |
+
def _decode_fields(self, latent: Tensor) -> dict[str, Tensor]:
|
| 230 |
+
batch_size = latent.shape[0]
|
| 231 |
+
side = self.config.field_size
|
| 232 |
+
target_belief_field = self.target_belief_head(latent).view(batch_size, 1, side, side)
|
| 233 |
+
visibility_field = self.visibility_head(latent).view(batch_size, 1, side, side)
|
| 234 |
+
clearance_field = self.clearance_head(latent).view(batch_size, 2, side, side)
|
| 235 |
+
occluder_contact_field = self.occluder_contact_head(latent).view(batch_size, 1, side, side)
|
| 236 |
+
grasp_affordance_field = self.grasp_affordance_head(latent).view(batch_size, 1, side, side)
|
| 237 |
+
support_stability_field = self.support_stability_head(latent).view(batch_size, 1, side, side)
|
| 238 |
+
persistence_field = torch.sigmoid(self.persistence_head(latent).view(batch_size, 1, side, side))
|
| 239 |
+
reocclusion_field = torch.sigmoid(self.reocclusion_head(latent).view(batch_size, 1, side, side))
|
| 240 |
+
disturbance_field = torch.sigmoid(self.disturbance_head(latent).view(batch_size, 1, side, side))
|
| 241 |
+
uncertainty_field = torch.nn.functional.softplus(self.uncertainty_head(latent).view(batch_size, 1, side, side))
|
| 242 |
+
risk_field = torch.sigmoid(
|
| 243 |
+
disturbance_field
|
| 244 |
+
+ 0.75 * reocclusion_field
|
| 245 |
+
+ 0.5 * (1.0 - torch.sigmoid(support_stability_field))
|
| 246 |
+
+ 0.25 * uncertainty_field
|
| 247 |
+
)
|
| 248 |
+
access_field = self.access_head(latent).view(batch_size, self.config.num_support_modes, side, side)
|
| 249 |
+
corridor_source = access_field.amax(dim=-2)
|
| 250 |
+
corridor_logits = torch.nn.functional.interpolate(
|
| 251 |
+
corridor_source,
|
| 252 |
+
size=self.config.num_approach_templates,
|
| 253 |
+
mode="linear",
|
| 254 |
+
align_corners=False,
|
| 255 |
+
)
|
| 256 |
+
access_prob = torch.sigmoid(access_field)
|
| 257 |
+
weighted_persistence = (persistence_field.expand_as(access_prob) * access_prob).sum(dim=(-1, -2))
|
| 258 |
+
access_mass = access_prob.sum(dim=(-1, -2)).clamp_min(1e-4)
|
| 259 |
+
persistence_horizon = self.config.rollout_horizon * weighted_persistence / access_mass
|
| 260 |
+
return {
|
| 261 |
+
"target_belief_field": target_belief_field,
|
| 262 |
+
"visibility_field": visibility_field,
|
| 263 |
+
"clearance_field": clearance_field,
|
| 264 |
+
"occluder_contact_field": occluder_contact_field,
|
| 265 |
+
"grasp_affordance_field": grasp_affordance_field,
|
| 266 |
+
"support_stability_field": support_stability_field,
|
| 267 |
+
"persistence_field": persistence_field,
|
| 268 |
+
"reocclusion_field": reocclusion_field,
|
| 269 |
+
"disturbance_field": disturbance_field,
|
| 270 |
+
"risk_field": risk_field,
|
| 271 |
+
"uncertainty_field": uncertainty_field,
|
| 272 |
+
"access_field": access_field,
|
| 273 |
+
"corridor_logits": corridor_logits,
|
| 274 |
+
"persistence_horizon": persistence_horizon,
|
| 275 |
+
"disturbance_cost": disturbance_field.mean(dim=(-1, -2)).squeeze(1),
|
| 276 |
+
"belief_map": torch.nn.functional.interpolate(
|
| 277 |
+
target_belief_field,
|
| 278 |
+
size=(self.config.belief_map_size, self.config.belief_map_size),
|
| 279 |
+
mode="bilinear",
|
| 280 |
+
align_corners=False,
|
| 281 |
+
),
|
| 282 |
+
"target_field": target_belief_field,
|
| 283 |
+
"actor_feasibility_field": clearance_field,
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
def forward(
|
| 287 |
+
self,
|
| 288 |
+
scene_tokens: Tensor,
|
| 289 |
+
interaction_state: dict[str, Tensor],
|
| 290 |
+
action_chunk: Tensor,
|
| 291 |
+
memory_tokens: Tensor | None = None,
|
| 292 |
+
scene_memory_tokens: Tensor | None = None,
|
| 293 |
+
belief_memory_tokens: Tensor | None = None,
|
| 294 |
+
) -> dict[str, Tensor]:
|
| 295 |
+
if scene_memory_tokens is None:
|
| 296 |
+
scene_memory_tokens = interaction_state.get("scene_memory_tokens")
|
| 297 |
+
if belief_memory_tokens is None:
|
| 298 |
+
belief_memory_tokens = interaction_state.get("belief_memory_tokens")
|
| 299 |
+
if scene_memory_tokens is None and memory_tokens is not None:
|
| 300 |
+
scene_memory_tokens = memory_tokens
|
| 301 |
+
if belief_memory_tokens is None and memory_tokens is not None:
|
| 302 |
+
belief_memory_tokens = memory_tokens
|
| 303 |
+
if scene_memory_tokens is None:
|
| 304 |
+
scene_memory_tokens = scene_tokens[:, :1]
|
| 305 |
+
if belief_memory_tokens is None:
|
| 306 |
+
belief_memory_tokens = scene_tokens[:, :1]
|
| 307 |
+
|
| 308 |
+
latent = self.state_encoder(self._compact_from_state(interaction_state))
|
| 309 |
+
scene_memory = self.scene_memory_proj(scene_memory_tokens.mean(dim=1))
|
| 310 |
+
belief_memory = self.belief_memory_proj(belief_memory_tokens.mean(dim=1))
|
| 311 |
+
outputs: dict[str, list[Tensor]] = {}
|
| 312 |
+
scene_bias = scene_tokens.mean(dim=1)
|
| 313 |
+
|
| 314 |
+
for step in range(action_chunk.shape[1]):
|
| 315 |
+
action_latent = self.action_encoder(action_chunk[:, step])
|
| 316 |
+
transition_input = torch.cat([latent, action_latent, scene_memory, belief_memory], dim=-1)
|
| 317 |
+
latent = self.transition(transition_input, latent + 0.1 * scene_bias)
|
| 318 |
+
scene_memory = 0.75 * scene_memory + 0.25 * torch.tanh(self.scene_memory_update(latent))
|
| 319 |
+
belief_memory = 0.65 * belief_memory + 0.35 * torch.tanh(self.belief_memory_update(latent))
|
| 320 |
+
compact_state = self.compact_decoder(latent)
|
| 321 |
+
decoded = self._decode_fields(latent)
|
| 322 |
+
decoded["compact_state"] = compact_state
|
| 323 |
+
decoded["phase_logits"] = compact_state[:, -(self.config.num_phases + (2 * self.config.num_arm_roles)) : -(2 * self.config.num_arm_roles)]
|
| 324 |
+
role_slice = compact_state[:, -(2 * self.config.num_arm_roles) :]
|
| 325 |
+
decoded["arm_role_logits"] = role_slice.view(role_slice.shape[0], 2, self.config.num_arm_roles)
|
| 326 |
+
decoded["support_mode_logits"] = compact_state[
|
| 327 |
+
:,
|
| 328 |
+
-(self.config.num_phases + (2 * self.config.num_arm_roles) + self.config.num_support_modes) : -(self.config.num_phases + (2 * self.config.num_arm_roles)),
|
| 329 |
+
]
|
| 330 |
+
decoded["scene_memory_tokens"] = scene_memory.unsqueeze(1).expand(-1, self.config.scene_bank_size, -1)
|
| 331 |
+
decoded["belief_memory_tokens"] = belief_memory.unsqueeze(1).expand(-1, self.config.belief_bank_size, -1)
|
| 332 |
+
decoded["memory_tokens"] = torch.cat([decoded["scene_memory_tokens"], decoded["belief_memory_tokens"]], dim=1)
|
| 333 |
+
decoded["memory_token"] = decoded["memory_tokens"].mean(dim=1, keepdim=True)
|
| 334 |
+
decoded["uncertainty"] = decoded["uncertainty_field"].mean(dim=(-1, -2)).squeeze(1)
|
| 335 |
+
decoded["reocclusion_logit"] = decoded["reocclusion_field"].mean(dim=(-1, -2)).expand(-1, self.config.num_support_modes)
|
| 336 |
+
for key, value in decoded.items():
|
| 337 |
+
outputs.setdefault(key, []).append(value)
|
| 338 |
+
|
| 339 |
+
return {key: torch.stack(values, dim=1) for key, values in outputs.items()}
|
code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/__init__.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc and b/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc differ
|
|
|
code/reveal_vla_bimanual/sim_reveal/dataset.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
from typing import Any, Sequence
|
| 5 |
|
|
@@ -9,9 +10,10 @@ from torch.utils.data import Dataset
|
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
|
| 12 |
-
from sim_reveal.procedural_envs import available_proxy_names, make_proxy_env, render_views_from_state
|
| 13 |
|
| 14 |
NOLEAK_PROXY_DATASET_VERSION = "reveal_proxy_v5_noleak_actionhist"
|
|
|
|
| 15 |
LEGACY_PRIVILEGED_RENDER_KEYS = frozenset(
|
| 16 |
{
|
| 17 |
"target_template",
|
|
@@ -44,6 +46,7 @@ def collect_teacher_dataset(
|
|
| 44 |
rollout_horizon: int = 5,
|
| 45 |
history_steps: int = 2,
|
| 46 |
planner_candidates: int = 4,
|
|
|
|
| 47 |
) -> dict[str, Any]:
|
| 48 |
proxy_names = tuple(proxy_names or available_proxy_names())
|
| 49 |
samples: list[dict[str, Any]] = []
|
|
@@ -91,7 +94,7 @@ def collect_teacher_dataset(
|
|
| 91 |
padded_history_actions.append(item["action"])
|
| 92 |
samples.append(
|
| 93 |
{
|
| 94 |
-
"dataset_version":
|
| 95 |
"proxy_name": proxy_name,
|
| 96 |
"episode_id": episode_idx,
|
| 97 |
"render_state": env.render_state(privileged_state),
|
|
@@ -103,10 +106,25 @@ def collect_teacher_dataset(
|
|
| 103 |
"persistence_horizon": privileged_state["persistence_horizon"].astype("float32"),
|
| 104 |
"disturbance_cost": float(privileged_state["disturbance_cost"]),
|
| 105 |
"belief_map": privileged_state["belief_map"].astype("float32"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
"rollout_support_mode": rollout["rollout_support_mode"].astype("int64"),
|
| 107 |
"rollout_corridor_feasible": rollout["rollout_corridor_feasible"].astype("float32"),
|
| 108 |
"rollout_persistence_horizon": rollout["rollout_persistence_horizon"].astype("float32"),
|
| 109 |
"rollout_disturbance_cost": rollout["rollout_disturbance_cost"].astype("float32"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
"history_render_states": padded_history_render_states,
|
| 111 |
"history_proprio": np.stack(padded_history_proprio, axis=0).astype("float32")
|
| 112 |
if padded_history_proprio
|
|
@@ -138,7 +156,7 @@ def collect_teacher_dataset(
|
|
| 138 |
"teacher_success": proxy_success / float(max(1, episodes_per_proxy)),
|
| 139 |
}
|
| 140 |
return {
|
| 141 |
-
"dataset_version":
|
| 142 |
"resolution": resolution,
|
| 143 |
"chunk_horizon": chunk_horizon,
|
| 144 |
"rollout_horizon": rollout_horizon,
|
|
@@ -164,25 +182,46 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
|
|
| 164 |
def __init__(self, samples: Sequence[dict[str, Any]], resolution: int = 96) -> None:
|
| 165 |
self.samples = list(samples)
|
| 166 |
self.resolution = resolution
|
|
|
|
|
|
|
| 167 |
|
| 168 |
def __len__(self) -> int:
|
| 169 |
return len(self.samples)
|
| 170 |
|
| 171 |
-
def
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
proxy_name=sample["proxy_name"],
|
| 176 |
-
render_state=
|
| 177 |
resolution=self.resolution,
|
|
|
|
| 178 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
history_images = []
|
|
|
|
|
|
|
| 180 |
for history_state in sample.get("history_render_states", []):
|
| 181 |
-
rendered =
|
| 182 |
-
proxy_name=sample["proxy_name"],
|
| 183 |
-
render_state=history_state,
|
| 184 |
-
resolution=self.resolution,
|
| 185 |
-
)
|
| 186 |
history_images.append(
|
| 187 |
torch.stack(
|
| 188 |
[
|
|
@@ -193,6 +232,27 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
|
|
| 193 |
dim=0,
|
| 194 |
)
|
| 195 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
stacked = torch.from_numpy(
|
| 197 |
torch.stack(
|
| 198 |
[
|
|
@@ -207,9 +267,42 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
|
|
| 207 |
history_stacked = torch.stack(history_images, dim=0).permute(0, 1, 4, 2, 3).float() / 255.0
|
| 208 |
else:
|
| 209 |
history_stacked = torch.zeros((0, 3, 3, self.resolution, self.resolution), dtype=torch.float32)
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
"images": stacked,
|
|
|
|
|
|
|
| 212 |
"history_images": history_stacked,
|
|
|
|
|
|
|
| 213 |
"history_proprio": torch.as_tensor(sample.get("history_proprio", []), dtype=torch.float32),
|
| 214 |
"history_actions": torch.as_tensor(
|
| 215 |
sample.get(
|
|
@@ -218,6 +311,8 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
|
|
| 218 |
),
|
| 219 |
dtype=torch.float32,
|
| 220 |
),
|
|
|
|
|
|
|
| 221 |
"proprio": torch.as_tensor(sample["proprio"], dtype=torch.float32),
|
| 222 |
"texts": sample["language_goal"],
|
| 223 |
"action_chunk": torch.as_tensor(sample["action_chunk"], dtype=torch.float32),
|
|
@@ -226,15 +321,37 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
|
|
| 226 |
"persistence_horizon": torch.as_tensor(sample["persistence_horizon"], dtype=torch.float32),
|
| 227 |
"disturbance_cost": torch.as_tensor(sample["disturbance_cost"], dtype=torch.float32),
|
| 228 |
"belief_map": torch.as_tensor(sample["belief_map"], dtype=torch.float32).unsqueeze(0),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
"rollout_support_mode": torch.as_tensor(sample["rollout_support_mode"], dtype=torch.long),
|
| 230 |
"rollout_corridor_feasible": torch.as_tensor(sample["rollout_corridor_feasible"], dtype=torch.float32),
|
| 231 |
"rollout_persistence_horizon": torch.as_tensor(sample["rollout_persistence_horizon"], dtype=torch.float32),
|
| 232 |
"rollout_disturbance_cost": torch.as_tensor(sample["rollout_disturbance_cost"], dtype=torch.float32),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
"candidate_action_chunks": torch.as_tensor(sample["candidate_action_chunks"], dtype=torch.float32),
|
| 234 |
"candidate_rollout_support_mode": torch.as_tensor(sample["candidate_rollout_support_mode"], dtype=torch.long),
|
| 235 |
"candidate_rollout_corridor_feasible": torch.as_tensor(sample["candidate_rollout_corridor_feasible"], dtype=torch.float32),
|
| 236 |
"candidate_rollout_persistence_horizon": torch.as_tensor(sample["candidate_rollout_persistence_horizon"], dtype=torch.float32),
|
| 237 |
"candidate_rollout_disturbance_cost": torch.as_tensor(sample["candidate_rollout_disturbance_cost"], dtype=torch.float32),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
"candidate_retrieval_success": torch.as_tensor(sample["candidate_retrieval_success"], dtype=torch.float32),
|
| 239 |
"candidate_final_disturbance_cost": torch.as_tensor(sample["candidate_final_disturbance_cost"], dtype=torch.float32),
|
| 240 |
"candidate_reocclusion_rate": torch.as_tensor(sample["candidate_reocclusion_rate"], dtype=torch.float32),
|
|
@@ -244,6 +361,8 @@ class RevealOfflineDataset(Dataset[dict[str, Any]]):
|
|
| 244 |
"proxy_name": sample["proxy_name"],
|
| 245 |
"episode_id": sample["episode_id"],
|
| 246 |
}
|
|
|
|
|
|
|
| 247 |
|
| 248 |
|
| 249 |
def dataset_from_bundle(dataset_bundle: dict[str, Any], resolution: int | None = None) -> RevealOfflineDataset:
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import pickle
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Any, Sequence
|
| 6 |
|
|
|
|
| 10 |
|
| 11 |
import numpy as np
|
| 12 |
|
| 13 |
+
from sim_reveal.procedural_envs import available_proxy_names, default_camera_matrices, make_proxy_env, render_views_from_state
|
| 14 |
|
| 15 |
NOLEAK_PROXY_DATASET_VERSION = "reveal_proxy_v5_noleak_actionhist"
|
| 16 |
+
RGBD_PROXY_DATASET_VERSION = "reveal_proxy_v6_rgbd_elastic_state"
|
| 17 |
LEGACY_PRIVILEGED_RENDER_KEYS = frozenset(
|
| 18 |
{
|
| 19 |
"target_template",
|
|
|
|
| 46 |
rollout_horizon: int = 5,
|
| 47 |
history_steps: int = 2,
|
| 48 |
planner_candidates: int = 4,
|
| 49 |
+
dataset_version: str = NOLEAK_PROXY_DATASET_VERSION,
|
| 50 |
) -> dict[str, Any]:
|
| 51 |
proxy_names = tuple(proxy_names or available_proxy_names())
|
| 52 |
samples: list[dict[str, Any]] = []
|
|
|
|
| 94 |
padded_history_actions.append(item["action"])
|
| 95 |
samples.append(
|
| 96 |
{
|
| 97 |
+
"dataset_version": dataset_version,
|
| 98 |
"proxy_name": proxy_name,
|
| 99 |
"episode_id": episode_idx,
|
| 100 |
"render_state": env.render_state(privileged_state),
|
|
|
|
| 106 |
"persistence_horizon": privileged_state["persistence_horizon"].astype("float32"),
|
| 107 |
"disturbance_cost": float(privileged_state["disturbance_cost"]),
|
| 108 |
"belief_map": privileged_state["belief_map"].astype("float32"),
|
| 109 |
+
"visibility_map": privileged_state["visibility_map"].astype("float32"),
|
| 110 |
+
"clearance_map": privileged_state["clearance_map"].astype("float32"),
|
| 111 |
+
"occluder_contact_map": privileged_state["occluder_contact_map"].astype("float32"),
|
| 112 |
+
"grasp_affordance_map": privileged_state["grasp_affordance_map"].astype("float32"),
|
| 113 |
+
"support_stability": float(privileged_state["support_stability"]),
|
| 114 |
+
"support_stability_map": privileged_state["support_stability_map"].astype("float32"),
|
| 115 |
+
"reocclusion_target": float(privileged_state["reocclusion_target"]),
|
| 116 |
+
"reocclusion_map": privileged_state["reocclusion_map"].astype("float32"),
|
| 117 |
"rollout_support_mode": rollout["rollout_support_mode"].astype("int64"),
|
| 118 |
"rollout_corridor_feasible": rollout["rollout_corridor_feasible"].astype("float32"),
|
| 119 |
"rollout_persistence_horizon": rollout["rollout_persistence_horizon"].astype("float32"),
|
| 120 |
"rollout_disturbance_cost": rollout["rollout_disturbance_cost"].astype("float32"),
|
| 121 |
+
"rollout_belief_map": rollout["rollout_belief_map"].astype("float32"),
|
| 122 |
+
"rollout_visibility_map": rollout["rollout_visibility_map"].astype("float32"),
|
| 123 |
+
"rollout_clearance_map": rollout["rollout_clearance_map"].astype("float32"),
|
| 124 |
+
"rollout_support_stability": rollout["rollout_support_stability"].astype("float32"),
|
| 125 |
+
"rollout_reocclusion_target": rollout["rollout_reocclusion_target"].astype("float32"),
|
| 126 |
+
"rollout_occluder_contact_map": rollout["rollout_occluder_contact_map"].astype("float32"),
|
| 127 |
+
"rollout_grasp_affordance_map": rollout["rollout_grasp_affordance_map"].astype("float32"),
|
| 128 |
"history_render_states": padded_history_render_states,
|
| 129 |
"history_proprio": np.stack(padded_history_proprio, axis=0).astype("float32")
|
| 130 |
if padded_history_proprio
|
|
|
|
| 156 |
"teacher_success": proxy_success / float(max(1, episodes_per_proxy)),
|
| 157 |
}
|
| 158 |
return {
|
| 159 |
+
"dataset_version": dataset_version,
|
| 160 |
"resolution": resolution,
|
| 161 |
"chunk_horizon": chunk_horizon,
|
| 162 |
"rollout_horizon": rollout_horizon,
|
|
|
|
| 182 |
def __init__(self, samples: Sequence[dict[str, Any]], resolution: int = 96) -> None:
|
| 183 |
self.samples = list(samples)
|
| 184 |
self.resolution = resolution
|
| 185 |
+
self._render_cache: dict[bytes, dict[str, np.ndarray]] = {}
|
| 186 |
+
self._item_cache: dict[int, dict[str, Any]] = {}
|
| 187 |
|
| 188 |
def __len__(self) -> int:
|
| 189 |
return len(self.samples)
|
| 190 |
|
| 191 |
+
def _render_cache_key(self, sample: dict[str, Any], render_state: dict[str, Any]) -> bytes:
|
| 192 |
+
include_depth = sample.get("dataset_version") == RGBD_PROXY_DATASET_VERSION
|
| 193 |
+
return pickle.dumps(
|
| 194 |
+
(sample["proxy_name"], self.resolution, include_depth, render_state),
|
| 195 |
+
protocol=4,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
def _render_sample(self, sample: dict[str, Any], render_state: dict[str, Any]) -> dict[str, np.ndarray]:
|
| 199 |
+
cache_key = self._render_cache_key(sample, render_state)
|
| 200 |
+
cached = self._render_cache.get(cache_key)
|
| 201 |
+
if cached is not None:
|
| 202 |
+
return cached
|
| 203 |
+
include_depth = sample.get("dataset_version") == RGBD_PROXY_DATASET_VERSION
|
| 204 |
+
rendered = render_views_from_state(
|
| 205 |
proxy_name=sample["proxy_name"],
|
| 206 |
+
render_state=render_state,
|
| 207 |
resolution=self.resolution,
|
| 208 |
+
include_depth=include_depth,
|
| 209 |
)
|
| 210 |
+
self._render_cache[cache_key] = rendered
|
| 211 |
+
return rendered
|
| 212 |
+
|
| 213 |
+
def __getitem__(self, index: int) -> dict[str, Any]:
|
| 214 |
+
cached_item = self._item_cache.get(index)
|
| 215 |
+
if cached_item is not None:
|
| 216 |
+
return cached_item
|
| 217 |
+
sample = self.samples[index]
|
| 218 |
+
_assert_noleak_sample(sample)
|
| 219 |
+
images = self._render_sample(sample, sample["render_state"])
|
| 220 |
history_images = []
|
| 221 |
+
history_depths = []
|
| 222 |
+
history_depth_valid = []
|
| 223 |
for history_state in sample.get("history_render_states", []):
|
| 224 |
+
rendered = self._render_sample(sample, history_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
history_images.append(
|
| 226 |
torch.stack(
|
| 227 |
[
|
|
|
|
| 232 |
dim=0,
|
| 233 |
)
|
| 234 |
)
|
| 235 |
+
if sample.get("dataset_version") == RGBD_PROXY_DATASET_VERSION:
|
| 236 |
+
history_depths.append(
|
| 237 |
+
torch.stack(
|
| 238 |
+
[
|
| 239 |
+
torch.from_numpy(rendered["front_depth"]),
|
| 240 |
+
torch.from_numpy(rendered["wrist_left_depth"]),
|
| 241 |
+
torch.from_numpy(rendered["wrist_right_depth"]),
|
| 242 |
+
],
|
| 243 |
+
dim=0,
|
| 244 |
+
)
|
| 245 |
+
)
|
| 246 |
+
history_depth_valid.append(
|
| 247 |
+
torch.stack(
|
| 248 |
+
[
|
| 249 |
+
torch.from_numpy(rendered["front_depth_valid"]),
|
| 250 |
+
torch.from_numpy(rendered["wrist_left_depth_valid"]),
|
| 251 |
+
torch.from_numpy(rendered["wrist_right_depth_valid"]),
|
| 252 |
+
],
|
| 253 |
+
dim=0,
|
| 254 |
+
)
|
| 255 |
+
)
|
| 256 |
stacked = torch.from_numpy(
|
| 257 |
torch.stack(
|
| 258 |
[
|
|
|
|
| 267 |
history_stacked = torch.stack(history_images, dim=0).permute(0, 1, 4, 2, 3).float() / 255.0
|
| 268 |
else:
|
| 269 |
history_stacked = torch.zeros((0, 3, 3, self.resolution, self.resolution), dtype=torch.float32)
|
| 270 |
+
if sample.get("dataset_version") == RGBD_PROXY_DATASET_VERSION:
|
| 271 |
+
depths = torch.stack(
|
| 272 |
+
[
|
| 273 |
+
torch.from_numpy(images["front_depth"]),
|
| 274 |
+
torch.from_numpy(images["wrist_left_depth"]),
|
| 275 |
+
torch.from_numpy(images["wrist_right_depth"]),
|
| 276 |
+
],
|
| 277 |
+
dim=0,
|
| 278 |
+
).unsqueeze(1).float()
|
| 279 |
+
depth_valid = torch.stack(
|
| 280 |
+
[
|
| 281 |
+
torch.from_numpy(images["front_depth_valid"]),
|
| 282 |
+
torch.from_numpy(images["wrist_left_depth_valid"]),
|
| 283 |
+
torch.from_numpy(images["wrist_right_depth_valid"]),
|
| 284 |
+
],
|
| 285 |
+
dim=0,
|
| 286 |
+
).unsqueeze(1).float()
|
| 287 |
+
if history_depths:
|
| 288 |
+
history_depths_tensor = torch.stack(history_depths, dim=0).unsqueeze(2).float()
|
| 289 |
+
history_depth_valid_tensor = torch.stack(history_depth_valid, dim=0).unsqueeze(2).float()
|
| 290 |
+
else:
|
| 291 |
+
history_depths_tensor = torch.zeros((0, 3, 1, self.resolution, self.resolution), dtype=torch.float32)
|
| 292 |
+
history_depth_valid_tensor = torch.zeros((0, 3, 1, self.resolution, self.resolution), dtype=torch.float32)
|
| 293 |
+
else:
|
| 294 |
+
depths = torch.zeros((3, 1, self.resolution, self.resolution), dtype=torch.float32)
|
| 295 |
+
depth_valid = torch.zeros_like(depths)
|
| 296 |
+
history_depths_tensor = torch.zeros((0, 3, 1, self.resolution, self.resolution), dtype=torch.float32)
|
| 297 |
+
history_depth_valid_tensor = torch.zeros_like(history_depths_tensor)
|
| 298 |
+
camera_intrinsics, camera_extrinsics = default_camera_matrices()
|
| 299 |
+
item = {
|
| 300 |
"images": stacked,
|
| 301 |
+
"depths": depths,
|
| 302 |
+
"depth_valid": depth_valid,
|
| 303 |
"history_images": history_stacked,
|
| 304 |
+
"history_depths": history_depths_tensor,
|
| 305 |
+
"history_depth_valid": history_depth_valid_tensor,
|
| 306 |
"history_proprio": torch.as_tensor(sample.get("history_proprio", []), dtype=torch.float32),
|
| 307 |
"history_actions": torch.as_tensor(
|
| 308 |
sample.get(
|
|
|
|
| 311 |
),
|
| 312 |
dtype=torch.float32,
|
| 313 |
),
|
| 314 |
+
"camera_intrinsics": torch.as_tensor(camera_intrinsics, dtype=torch.float32),
|
| 315 |
+
"camera_extrinsics": torch.as_tensor(camera_extrinsics, dtype=torch.float32),
|
| 316 |
"proprio": torch.as_tensor(sample["proprio"], dtype=torch.float32),
|
| 317 |
"texts": sample["language_goal"],
|
| 318 |
"action_chunk": torch.as_tensor(sample["action_chunk"], dtype=torch.float32),
|
|
|
|
| 321 |
"persistence_horizon": torch.as_tensor(sample["persistence_horizon"], dtype=torch.float32),
|
| 322 |
"disturbance_cost": torch.as_tensor(sample["disturbance_cost"], dtype=torch.float32),
|
| 323 |
"belief_map": torch.as_tensor(sample["belief_map"], dtype=torch.float32).unsqueeze(0),
|
| 324 |
+
"visibility_map": torch.as_tensor(sample.get("visibility_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
|
| 325 |
+
"clearance_map": torch.as_tensor(sample.get("clearance_map", np.zeros((2, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 326 |
+
"occluder_contact_map": torch.as_tensor(sample.get("occluder_contact_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
|
| 327 |
+
"grasp_affordance_map": torch.as_tensor(sample.get("grasp_affordance_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
|
| 328 |
+
"support_stability": torch.as_tensor(sample.get("support_stability", 0.0), dtype=torch.float32),
|
| 329 |
+
"support_stability_map": torch.as_tensor(sample.get("support_stability_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
|
| 330 |
+
"reocclusion_target": torch.as_tensor(sample.get("reocclusion_target", 0.0), dtype=torch.float32),
|
| 331 |
+
"reocclusion_map": torch.as_tensor(sample.get("reocclusion_map", np.zeros((32, 32), dtype=np.float32)), dtype=torch.float32).unsqueeze(0),
|
| 332 |
"rollout_support_mode": torch.as_tensor(sample["rollout_support_mode"], dtype=torch.long),
|
| 333 |
"rollout_corridor_feasible": torch.as_tensor(sample["rollout_corridor_feasible"], dtype=torch.float32),
|
| 334 |
"rollout_persistence_horizon": torch.as_tensor(sample["rollout_persistence_horizon"], dtype=torch.float32),
|
| 335 |
"rollout_disturbance_cost": torch.as_tensor(sample["rollout_disturbance_cost"], dtype=torch.float32),
|
| 336 |
+
"rollout_belief_map": torch.as_tensor(sample.get("rollout_belief_map", np.zeros((0, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 337 |
+
"rollout_visibility_map": torch.as_tensor(sample.get("rollout_visibility_map", np.zeros((0, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 338 |
+
"rollout_clearance_map": torch.as_tensor(sample.get("rollout_clearance_map", np.zeros((0, 2, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 339 |
+
"rollout_support_stability": torch.as_tensor(sample.get("rollout_support_stability", np.zeros((0,), dtype=np.float32)), dtype=torch.float32),
|
| 340 |
+
"rollout_reocclusion_target": torch.as_tensor(sample.get("rollout_reocclusion_target", np.zeros((0,), dtype=np.float32)), dtype=torch.float32),
|
| 341 |
+
"rollout_occluder_contact_map": torch.as_tensor(sample.get("rollout_occluder_contact_map", np.zeros((0, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 342 |
+
"rollout_grasp_affordance_map": torch.as_tensor(sample.get("rollout_grasp_affordance_map", np.zeros((0, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 343 |
"candidate_action_chunks": torch.as_tensor(sample["candidate_action_chunks"], dtype=torch.float32),
|
| 344 |
"candidate_rollout_support_mode": torch.as_tensor(sample["candidate_rollout_support_mode"], dtype=torch.long),
|
| 345 |
"candidate_rollout_corridor_feasible": torch.as_tensor(sample["candidate_rollout_corridor_feasible"], dtype=torch.float32),
|
| 346 |
"candidate_rollout_persistence_horizon": torch.as_tensor(sample["candidate_rollout_persistence_horizon"], dtype=torch.float32),
|
| 347 |
"candidate_rollout_disturbance_cost": torch.as_tensor(sample["candidate_rollout_disturbance_cost"], dtype=torch.float32),
|
| 348 |
+
"candidate_rollout_belief_map": torch.as_tensor(sample.get("candidate_rollout_belief_map", np.zeros((0, 0, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 349 |
+
"candidate_rollout_visibility_map": torch.as_tensor(sample.get("candidate_rollout_visibility_map", np.zeros((0, 0, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 350 |
+
"candidate_rollout_clearance_map": torch.as_tensor(sample.get("candidate_rollout_clearance_map", np.zeros((0, 0, 2, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 351 |
+
"candidate_rollout_support_stability": torch.as_tensor(sample.get("candidate_rollout_support_stability", np.zeros((0, 0), dtype=np.float32)), dtype=torch.float32),
|
| 352 |
+
"candidate_rollout_reocclusion_target": torch.as_tensor(sample.get("candidate_rollout_reocclusion_target", np.zeros((0, 0), dtype=np.float32)), dtype=torch.float32),
|
| 353 |
+
"candidate_rollout_occluder_contact_map": torch.as_tensor(sample.get("candidate_rollout_occluder_contact_map", np.zeros((0, 0, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 354 |
+
"candidate_rollout_grasp_affordance_map": torch.as_tensor(sample.get("candidate_rollout_grasp_affordance_map", np.zeros((0, 0, 32, 32), dtype=np.float32)), dtype=torch.float32),
|
| 355 |
"candidate_retrieval_success": torch.as_tensor(sample["candidate_retrieval_success"], dtype=torch.float32),
|
| 356 |
"candidate_final_disturbance_cost": torch.as_tensor(sample["candidate_final_disturbance_cost"], dtype=torch.float32),
|
| 357 |
"candidate_reocclusion_rate": torch.as_tensor(sample["candidate_reocclusion_rate"], dtype=torch.float32),
|
|
|
|
| 361 |
"proxy_name": sample["proxy_name"],
|
| 362 |
"episode_id": sample["episode_id"],
|
| 363 |
}
|
| 364 |
+
self._item_cache[index] = item
|
| 365 |
+
return item
|
| 366 |
|
| 367 |
|
| 368 |
def dataset_from_bundle(dataset_bundle: dict[str, Any], resolution: int | None = None) -> RevealOfflineDataset:
|
code/reveal_vla_bimanual/sim_reveal/procedural_envs.py
CHANGED
|
@@ -83,6 +83,26 @@ PROXY_GOALS = {
|
|
| 83 |
}
|
| 84 |
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
def available_proxy_names() -> tuple[str, ...]:
|
| 87 |
return tuple(PROXY_CONFIGS.keys())
|
| 88 |
|
|
@@ -285,6 +305,57 @@ class ProceduralRevealEnv:
|
|
| 285 |
belief *= visibility
|
| 286 |
return belief.astype(np.float32)
|
| 287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
def get_privileged_state(self) -> dict[str, Any]:
|
| 289 |
support_mode = int(self._current_support_mode())
|
| 290 |
corridor = np.stack(
|
|
@@ -294,12 +365,29 @@ class ProceduralRevealEnv:
|
|
| 294 |
persistence = np.asarray([self._persistence_for_mode(mode) for mode in SupportMode], dtype=np.float32)
|
| 295 |
visibility = self._visibility()
|
| 296 |
disturbance_cost = float(np.clip(self.disturbance + 0.08 * max(0.0, self.opening - self.dynamics.desired_opening), 0.0, 1.0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
return {
|
| 298 |
"support_mode": support_mode,
|
| 299 |
"corridor_feasible": corridor,
|
| 300 |
"persistence_horizon": persistence,
|
| 301 |
"disturbance_cost": disturbance_cost,
|
| 302 |
-
"belief_map":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
"visibility": visibility,
|
| 304 |
"retrieval_success": bool(self.retrieved),
|
| 305 |
"target_template": self.target_template,
|
|
@@ -335,12 +423,18 @@ class ProceduralRevealEnv:
|
|
| 335 |
render_state=render_state,
|
| 336 |
resolution=self.resolution,
|
| 337 |
num_templates=self.num_templates,
|
|
|
|
| 338 |
)
|
|
|
|
| 339 |
return {
|
| 340 |
"images": np.stack([images[camera] for camera in self.camera_names], axis=0),
|
|
|
|
|
|
|
| 341 |
"proprio": self._proprio(privileged_state),
|
| 342 |
"text": PROXY_GOALS[self.proxy_name],
|
| 343 |
"camera_names": self.camera_names,
|
|
|
|
|
|
|
| 344 |
}
|
| 345 |
|
| 346 |
def teacher_action(self) -> np.ndarray:
|
|
@@ -385,6 +479,13 @@ class ProceduralRevealEnv:
|
|
| 385 |
rollout_corridor = []
|
| 386 |
rollout_persistence = []
|
| 387 |
rollout_disturbance = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
for step in range(chunk_horizon):
|
| 389 |
action = self.teacher_action()
|
| 390 |
action_chunk.append(action)
|
|
@@ -394,21 +495,43 @@ class ProceduralRevealEnv:
|
|
| 394 |
rollout_corridor.append(privileged_state["corridor_feasible"])
|
| 395 |
rollout_persistence.append(privileged_state["persistence_horizon"])
|
| 396 |
rollout_disturbance.append(privileged_state["disturbance_cost"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
if terminated or truncated:
|
| 398 |
break
|
| 399 |
while len(action_chunk) < chunk_horizon:
|
| 400 |
action_chunk.append(np.zeros((14,), dtype=np.float32))
|
| 401 |
while len(rollout_support_mode) < rollout_horizon:
|
|
|
|
| 402 |
rollout_support_mode.append(int(self._current_support_mode()))
|
| 403 |
-
rollout_corridor.append(
|
| 404 |
-
rollout_persistence.append(
|
| 405 |
-
rollout_disturbance.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
self.restore_state(snapshot)
|
| 407 |
return np.stack(action_chunk, axis=0).astype(np.float32), {
|
| 408 |
"rollout_support_mode": np.asarray(rollout_support_mode, dtype=np.int64),
|
| 409 |
"rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32),
|
| 410 |
"rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32),
|
| 411 |
"rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
}
|
| 413 |
|
| 414 |
def evaluate_action_chunk(
|
|
@@ -422,6 +545,13 @@ class ProceduralRevealEnv:
|
|
| 422 |
rollout_corridor: list[np.ndarray] = []
|
| 423 |
rollout_persistence: list[np.ndarray] = []
|
| 424 |
rollout_disturbance: list[float] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
corridor_open_trace = [float(self.get_privileged_state()["corridor_feasible"][self._current_support_mode()].any())]
|
| 426 |
visibility_trace = [float(self.get_privileged_state()["visibility"])]
|
| 427 |
terminated = False
|
|
@@ -434,6 +564,13 @@ class ProceduralRevealEnv:
|
|
| 434 |
rollout_corridor.append(privileged_state["corridor_feasible"].astype(np.float32))
|
| 435 |
rollout_persistence.append(privileged_state["persistence_horizon"].astype(np.float32))
|
| 436 |
rollout_disturbance.append(float(privileged_state["disturbance_cost"]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
corridor_open_trace.append(float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any()))
|
| 438 |
visibility_trace.append(float(privileged_state["visibility"]))
|
| 439 |
if terminated or truncated:
|
|
@@ -444,6 +581,13 @@ class ProceduralRevealEnv:
|
|
| 444 |
rollout_corridor.append(current["corridor_feasible"].astype(np.float32))
|
| 445 |
rollout_persistence.append(current["persistence_horizon"].astype(np.float32))
|
| 446 |
rollout_disturbance.append(float(current["disturbance_cost"]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
final_state = self.get_privileged_state()
|
| 448 |
reocclusion = float(
|
| 449 |
np.logical_and(
|
|
@@ -456,6 +600,13 @@ class ProceduralRevealEnv:
|
|
| 456 |
"rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32),
|
| 457 |
"rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32),
|
| 458 |
"rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
"retrieval_success": float(final_state["retrieval_success"]),
|
| 460 |
"final_disturbance_cost": float(final_state["disturbance_cost"]),
|
| 461 |
"reocclusion_rate": reocclusion,
|
|
@@ -493,6 +644,27 @@ class ProceduralRevealEnv:
|
|
| 493 |
"candidate_rollout_disturbance_cost": np.stack(
|
| 494 |
[item["rollout_disturbance_cost"] for item in outcomes], axis=0
|
| 495 |
).astype(np.float32),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
"candidate_retrieval_success": np.asarray([item["retrieval_success"] for item in outcomes], dtype=np.float32),
|
| 497 |
"candidate_final_disturbance_cost": np.asarray(
|
| 498 |
[item["final_disturbance_cost"] for item in outcomes], dtype=np.float32
|
|
@@ -587,6 +759,7 @@ def render_views_from_state(
|
|
| 587 |
render_state: dict[str, Any],
|
| 588 |
resolution: int,
|
| 589 |
num_templates: int = 32,
|
|
|
|
| 590 |
) -> dict[str, np.ndarray]:
|
| 591 |
dynamics = PROXY_DYNAMICS[proxy_name]
|
| 592 |
opening = float(render_state["opening"])
|
|
@@ -668,8 +841,40 @@ def render_views_from_state(
|
|
| 668 |
wrist_right[..., 2] = np.clip(wrist_right[..., 2] + 0.08 * step_fraction + 0.06 * right_band, 0.0, 1.0)
|
| 669 |
wrist_right = np.clip(wrist_right, 0.0, 1.0)
|
| 670 |
|
| 671 |
-
|
| 672 |
"front": (front * 255.0).astype(np.uint8),
|
| 673 |
"wrist_left": (wrist_left * 255.0).astype(np.uint8),
|
| 674 |
"wrist_right": (wrist_right * 255.0).astype(np.uint8),
|
| 675 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
}
|
| 84 |
|
| 85 |
|
| 86 |
+
def default_camera_matrices() -> tuple[np.ndarray, np.ndarray]:
|
| 87 |
+
intrinsics = np.asarray(
|
| 88 |
+
[
|
| 89 |
+
[[140.0, 0.0, 48.0], [0.0, 140.0, 48.0], [0.0, 0.0, 1.0]],
|
| 90 |
+
[[135.0, 0.0, 48.0], [0.0, 135.0, 48.0], [0.0, 0.0, 1.0]],
|
| 91 |
+
[[135.0, 0.0, 48.0], [0.0, 135.0, 48.0], [0.0, 0.0, 1.0]],
|
| 92 |
+
],
|
| 93 |
+
dtype=np.float32,
|
| 94 |
+
)
|
| 95 |
+
extrinsics = np.asarray(
|
| 96 |
+
[
|
| 97 |
+
np.eye(4, dtype=np.float32),
|
| 98 |
+
[[1.0, 0.0, 0.0, -0.18], [0.0, 1.0, 0.0, 0.04], [0.0, 0.0, 1.0, 0.10], [0.0, 0.0, 0.0, 1.0]],
|
| 99 |
+
[[1.0, 0.0, 0.0, 0.18], [0.0, 1.0, 0.0, 0.04], [0.0, 0.0, 1.0, 0.10], [0.0, 0.0, 0.0, 1.0]],
|
| 100 |
+
],
|
| 101 |
+
dtype=np.float32,
|
| 102 |
+
)
|
| 103 |
+
return intrinsics, extrinsics
|
| 104 |
+
|
| 105 |
+
|
| 106 |
def available_proxy_names() -> tuple[str, ...]:
|
| 107 |
return tuple(PROXY_CONFIGS.keys())
|
| 108 |
|
|
|
|
| 305 |
belief *= visibility
|
| 306 |
return belief.astype(np.float32)
|
| 307 |
|
| 308 |
+
def _visibility_map(self, visibility: float) -> np.ndarray:
|
| 309 |
+
belief = self._belief_map(visibility)
|
| 310 |
+
gradient = np.linspace(0.65, 1.0, belief.shape[0], dtype=np.float32).reshape(-1, 1)
|
| 311 |
+
return np.clip(belief * gradient, 0.0, 1.0).astype(np.float32)
|
| 312 |
+
|
| 313 |
+
def _clearance_map(self, visibility: float) -> np.ndarray:
|
| 314 |
+
side = 32
|
| 315 |
+
x = np.linspace(0.0, 1.0, side, dtype=np.float32)
|
| 316 |
+
y = np.linspace(0.0, 1.0, side, dtype=np.float32)
|
| 317 |
+
yy, xx = np.meshgrid(y, x, indexing="ij")
|
| 318 |
+
corridor_width = np.clip(0.05 + 0.18 * self.opening - 0.10 * self.disturbance, 0.01, 0.28)
|
| 319 |
+
corridor = np.exp(-(((xx - self.target_center) ** 2) / max(1e-5, corridor_width**2)))
|
| 320 |
+
vertical = np.exp(-(((yy - (0.72 - 0.25 * self.target_depth)) ** 2) / 0.03))
|
| 321 |
+
left = np.clip(corridor * vertical * visibility * (0.92 - 0.15 * self.disturbance), 0.0, 1.0)
|
| 322 |
+
right = np.clip(corridor * vertical * visibility * (0.88 - 0.10 * self.disturbance), 0.0, 1.0)
|
| 323 |
+
return np.stack([left, right], axis=0).astype(np.float32)
|
| 324 |
+
|
| 325 |
+
def _occluder_contact_map(self) -> np.ndarray:
|
| 326 |
+
side = 32
|
| 327 |
+
x = np.linspace(0.0, 1.0, side, dtype=np.float32)
|
| 328 |
+
y = np.linspace(0.0, 1.0, side, dtype=np.float32)
|
| 329 |
+
yy, xx = np.meshgrid(y, x, indexing="ij")
|
| 330 |
+
gap_width = np.clip(0.03 + 0.16 * self.opening, 0.03, 0.24)
|
| 331 |
+
left_band = np.exp(-(((xx - (self.target_center - gap_width)) ** 2) / 0.0025))
|
| 332 |
+
right_band = np.exp(-(((xx - (self.target_center + gap_width)) ** 2) / 0.0025))
|
| 333 |
+
support = np.exp(-(((yy - 0.55) ** 2) / 0.04))
|
| 334 |
+
return np.clip((left_band + right_band) * support, 0.0, 1.0).astype(np.float32)
|
| 335 |
+
|
| 336 |
+
def _support_stability(self) -> float:
|
| 337 |
+
base = 1.0 - 0.45 * self.disturbance - 0.10 * max(0.0, self.opening - self.dynamics.desired_opening)
|
| 338 |
+
if self._current_support_mode() == self.dynamics.preferred_mode:
|
| 339 |
+
base += 0.08
|
| 340 |
+
return float(np.clip(base, 0.0, 1.0))
|
| 341 |
+
|
| 342 |
+
def _support_stability_map(self) -> np.ndarray:
|
| 343 |
+
return np.full((32, 32), self._support_stability(), dtype=np.float32)
|
| 344 |
+
|
| 345 |
+
def _reocclusion_target(self, persistence: np.ndarray) -> float:
|
| 346 |
+
current_mode = int(self._current_support_mode())
|
| 347 |
+
horizon_ratio = persistence[current_mode] / float(max(1, self.rollout_horizon))
|
| 348 |
+
return float(np.clip(1.0 - horizon_ratio + 0.35 * self.disturbance, 0.0, 1.0))
|
| 349 |
+
|
| 350 |
+
def _grasp_affordance_map(
|
| 351 |
+
self,
|
| 352 |
+
belief_map: np.ndarray,
|
| 353 |
+
visibility_map: np.ndarray,
|
| 354 |
+
clearance_map: np.ndarray,
|
| 355 |
+
) -> np.ndarray:
|
| 356 |
+
combined = belief_map * visibility_map * clearance_map.mean(axis=0)
|
| 357 |
+
return np.clip(combined * (1.0 - 0.35 * self.disturbance), 0.0, 1.0).astype(np.float32)
|
| 358 |
+
|
| 359 |
def get_privileged_state(self) -> dict[str, Any]:
|
| 360 |
support_mode = int(self._current_support_mode())
|
| 361 |
corridor = np.stack(
|
|
|
|
| 365 |
persistence = np.asarray([self._persistence_for_mode(mode) for mode in SupportMode], dtype=np.float32)
|
| 366 |
visibility = self._visibility()
|
| 367 |
disturbance_cost = float(np.clip(self.disturbance + 0.08 * max(0.0, self.opening - self.dynamics.desired_opening), 0.0, 1.0))
|
| 368 |
+
belief_map = self._belief_map(visibility)
|
| 369 |
+
visibility_map = self._visibility_map(visibility)
|
| 370 |
+
clearance_map = self._clearance_map(visibility)
|
| 371 |
+
occluder_contact_map = self._occluder_contact_map()
|
| 372 |
+
support_stability = self._support_stability()
|
| 373 |
+
support_stability_map = self._support_stability_map()
|
| 374 |
+
reocclusion_target = self._reocclusion_target(persistence)
|
| 375 |
+
reocclusion_map = np.full((32, 32), reocclusion_target, dtype=np.float32)
|
| 376 |
+
grasp_affordance_map = self._grasp_affordance_map(belief_map, visibility_map, clearance_map)
|
| 377 |
return {
|
| 378 |
"support_mode": support_mode,
|
| 379 |
"corridor_feasible": corridor,
|
| 380 |
"persistence_horizon": persistence,
|
| 381 |
"disturbance_cost": disturbance_cost,
|
| 382 |
+
"belief_map": belief_map,
|
| 383 |
+
"visibility_map": visibility_map,
|
| 384 |
+
"clearance_map": clearance_map,
|
| 385 |
+
"occluder_contact_map": occluder_contact_map,
|
| 386 |
+
"grasp_affordance_map": grasp_affordance_map,
|
| 387 |
+
"support_stability": support_stability,
|
| 388 |
+
"support_stability_map": support_stability_map,
|
| 389 |
+
"reocclusion_target": reocclusion_target,
|
| 390 |
+
"reocclusion_map": reocclusion_map,
|
| 391 |
"visibility": visibility,
|
| 392 |
"retrieval_success": bool(self.retrieved),
|
| 393 |
"target_template": self.target_template,
|
|
|
|
| 423 |
render_state=render_state,
|
| 424 |
resolution=self.resolution,
|
| 425 |
num_templates=self.num_templates,
|
| 426 |
+
include_depth=True,
|
| 427 |
)
|
| 428 |
+
camera_intrinsics, camera_extrinsics = default_camera_matrices()
|
| 429 |
return {
|
| 430 |
"images": np.stack([images[camera] for camera in self.camera_names], axis=0),
|
| 431 |
+
"depths": np.stack([images[f"{camera}_depth"] for camera in self.camera_names], axis=0)[:, None, :, :],
|
| 432 |
+
"depth_valid": np.stack([images[f"{camera}_depth_valid"] for camera in self.camera_names], axis=0)[:, None, :, :],
|
| 433 |
"proprio": self._proprio(privileged_state),
|
| 434 |
"text": PROXY_GOALS[self.proxy_name],
|
| 435 |
"camera_names": self.camera_names,
|
| 436 |
+
"camera_intrinsics": camera_intrinsics,
|
| 437 |
+
"camera_extrinsics": camera_extrinsics,
|
| 438 |
}
|
| 439 |
|
| 440 |
def teacher_action(self) -> np.ndarray:
|
|
|
|
| 479 |
rollout_corridor = []
|
| 480 |
rollout_persistence = []
|
| 481 |
rollout_disturbance = []
|
| 482 |
+
rollout_belief = []
|
| 483 |
+
rollout_visibility = []
|
| 484 |
+
rollout_clearance = []
|
| 485 |
+
rollout_support_stability = []
|
| 486 |
+
rollout_reocclusion = []
|
| 487 |
+
rollout_occluder_contact = []
|
| 488 |
+
rollout_grasp_affordance = []
|
| 489 |
for step in range(chunk_horizon):
|
| 490 |
action = self.teacher_action()
|
| 491 |
action_chunk.append(action)
|
|
|
|
| 495 |
rollout_corridor.append(privileged_state["corridor_feasible"])
|
| 496 |
rollout_persistence.append(privileged_state["persistence_horizon"])
|
| 497 |
rollout_disturbance.append(privileged_state["disturbance_cost"])
|
| 498 |
+
rollout_belief.append(privileged_state["belief_map"])
|
| 499 |
+
rollout_visibility.append(privileged_state["visibility_map"])
|
| 500 |
+
rollout_clearance.append(privileged_state["clearance_map"])
|
| 501 |
+
rollout_support_stability.append(privileged_state["support_stability"])
|
| 502 |
+
rollout_reocclusion.append(privileged_state["reocclusion_target"])
|
| 503 |
+
rollout_occluder_contact.append(privileged_state["occluder_contact_map"])
|
| 504 |
+
rollout_grasp_affordance.append(privileged_state["grasp_affordance_map"])
|
| 505 |
if terminated or truncated:
|
| 506 |
break
|
| 507 |
while len(action_chunk) < chunk_horizon:
|
| 508 |
action_chunk.append(np.zeros((14,), dtype=np.float32))
|
| 509 |
while len(rollout_support_mode) < rollout_horizon:
|
| 510 |
+
current = self.get_privileged_state()
|
| 511 |
rollout_support_mode.append(int(self._current_support_mode()))
|
| 512 |
+
rollout_corridor.append(current["corridor_feasible"])
|
| 513 |
+
rollout_persistence.append(current["persistence_horizon"])
|
| 514 |
+
rollout_disturbance.append(current["disturbance_cost"])
|
| 515 |
+
rollout_belief.append(current["belief_map"])
|
| 516 |
+
rollout_visibility.append(current["visibility_map"])
|
| 517 |
+
rollout_clearance.append(current["clearance_map"])
|
| 518 |
+
rollout_support_stability.append(current["support_stability"])
|
| 519 |
+
rollout_reocclusion.append(current["reocclusion_target"])
|
| 520 |
+
rollout_occluder_contact.append(current["occluder_contact_map"])
|
| 521 |
+
rollout_grasp_affordance.append(current["grasp_affordance_map"])
|
| 522 |
self.restore_state(snapshot)
|
| 523 |
return np.stack(action_chunk, axis=0).astype(np.float32), {
|
| 524 |
"rollout_support_mode": np.asarray(rollout_support_mode, dtype=np.int64),
|
| 525 |
"rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32),
|
| 526 |
"rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32),
|
| 527 |
"rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32),
|
| 528 |
+
"rollout_belief_map": np.asarray(rollout_belief, dtype=np.float32),
|
| 529 |
+
"rollout_visibility_map": np.asarray(rollout_visibility, dtype=np.float32),
|
| 530 |
+
"rollout_clearance_map": np.asarray(rollout_clearance, dtype=np.float32),
|
| 531 |
+
"rollout_support_stability": np.asarray(rollout_support_stability, dtype=np.float32),
|
| 532 |
+
"rollout_reocclusion_target": np.asarray(rollout_reocclusion, dtype=np.float32),
|
| 533 |
+
"rollout_occluder_contact_map": np.asarray(rollout_occluder_contact, dtype=np.float32),
|
| 534 |
+
"rollout_grasp_affordance_map": np.asarray(rollout_grasp_affordance, dtype=np.float32),
|
| 535 |
}
|
| 536 |
|
| 537 |
def evaluate_action_chunk(
|
|
|
|
| 545 |
rollout_corridor: list[np.ndarray] = []
|
| 546 |
rollout_persistence: list[np.ndarray] = []
|
| 547 |
rollout_disturbance: list[float] = []
|
| 548 |
+
rollout_belief: list[np.ndarray] = []
|
| 549 |
+
rollout_visibility: list[np.ndarray] = []
|
| 550 |
+
rollout_clearance: list[np.ndarray] = []
|
| 551 |
+
rollout_support_stability: list[float] = []
|
| 552 |
+
rollout_reocclusion: list[float] = []
|
| 553 |
+
rollout_occluder_contact: list[np.ndarray] = []
|
| 554 |
+
rollout_grasp_affordance: list[np.ndarray] = []
|
| 555 |
corridor_open_trace = [float(self.get_privileged_state()["corridor_feasible"][self._current_support_mode()].any())]
|
| 556 |
visibility_trace = [float(self.get_privileged_state()["visibility"])]
|
| 557 |
terminated = False
|
|
|
|
| 564 |
rollout_corridor.append(privileged_state["corridor_feasible"].astype(np.float32))
|
| 565 |
rollout_persistence.append(privileged_state["persistence_horizon"].astype(np.float32))
|
| 566 |
rollout_disturbance.append(float(privileged_state["disturbance_cost"]))
|
| 567 |
+
rollout_belief.append(privileged_state["belief_map"].astype(np.float32))
|
| 568 |
+
rollout_visibility.append(privileged_state["visibility_map"].astype(np.float32))
|
| 569 |
+
rollout_clearance.append(privileged_state["clearance_map"].astype(np.float32))
|
| 570 |
+
rollout_support_stability.append(float(privileged_state["support_stability"]))
|
| 571 |
+
rollout_reocclusion.append(float(privileged_state["reocclusion_target"]))
|
| 572 |
+
rollout_occluder_contact.append(privileged_state["occluder_contact_map"].astype(np.float32))
|
| 573 |
+
rollout_grasp_affordance.append(privileged_state["grasp_affordance_map"].astype(np.float32))
|
| 574 |
corridor_open_trace.append(float(privileged_state["corridor_feasible"][privileged_state["support_mode"]].any()))
|
| 575 |
visibility_trace.append(float(privileged_state["visibility"]))
|
| 576 |
if terminated or truncated:
|
|
|
|
| 581 |
rollout_corridor.append(current["corridor_feasible"].astype(np.float32))
|
| 582 |
rollout_persistence.append(current["persistence_horizon"].astype(np.float32))
|
| 583 |
rollout_disturbance.append(float(current["disturbance_cost"]))
|
| 584 |
+
rollout_belief.append(current["belief_map"].astype(np.float32))
|
| 585 |
+
rollout_visibility.append(current["visibility_map"].astype(np.float32))
|
| 586 |
+
rollout_clearance.append(current["clearance_map"].astype(np.float32))
|
| 587 |
+
rollout_support_stability.append(float(current["support_stability"]))
|
| 588 |
+
rollout_reocclusion.append(float(current["reocclusion_target"]))
|
| 589 |
+
rollout_occluder_contact.append(current["occluder_contact_map"].astype(np.float32))
|
| 590 |
+
rollout_grasp_affordance.append(current["grasp_affordance_map"].astype(np.float32))
|
| 591 |
final_state = self.get_privileged_state()
|
| 592 |
reocclusion = float(
|
| 593 |
np.logical_and(
|
|
|
|
| 600 |
"rollout_corridor_feasible": np.asarray(rollout_corridor, dtype=np.float32),
|
| 601 |
"rollout_persistence_horizon": np.asarray(rollout_persistence, dtype=np.float32),
|
| 602 |
"rollout_disturbance_cost": np.asarray(rollout_disturbance, dtype=np.float32),
|
| 603 |
+
"rollout_belief_map": np.asarray(rollout_belief, dtype=np.float32),
|
| 604 |
+
"rollout_visibility_map": np.asarray(rollout_visibility, dtype=np.float32),
|
| 605 |
+
"rollout_clearance_map": np.asarray(rollout_clearance, dtype=np.float32),
|
| 606 |
+
"rollout_support_stability": np.asarray(rollout_support_stability, dtype=np.float32),
|
| 607 |
+
"rollout_reocclusion_target": np.asarray(rollout_reocclusion, dtype=np.float32),
|
| 608 |
+
"rollout_occluder_contact_map": np.asarray(rollout_occluder_contact, dtype=np.float32),
|
| 609 |
+
"rollout_grasp_affordance_map": np.asarray(rollout_grasp_affordance, dtype=np.float32),
|
| 610 |
"retrieval_success": float(final_state["retrieval_success"]),
|
| 611 |
"final_disturbance_cost": float(final_state["disturbance_cost"]),
|
| 612 |
"reocclusion_rate": reocclusion,
|
|
|
|
| 644 |
"candidate_rollout_disturbance_cost": np.stack(
|
| 645 |
[item["rollout_disturbance_cost"] for item in outcomes], axis=0
|
| 646 |
).astype(np.float32),
|
| 647 |
+
"candidate_rollout_belief_map": np.stack(
|
| 648 |
+
[item["rollout_belief_map"] for item in outcomes], axis=0
|
| 649 |
+
).astype(np.float32),
|
| 650 |
+
"candidate_rollout_visibility_map": np.stack(
|
| 651 |
+
[item["rollout_visibility_map"] for item in outcomes], axis=0
|
| 652 |
+
).astype(np.float32),
|
| 653 |
+
"candidate_rollout_clearance_map": np.stack(
|
| 654 |
+
[item["rollout_clearance_map"] for item in outcomes], axis=0
|
| 655 |
+
).astype(np.float32),
|
| 656 |
+
"candidate_rollout_support_stability": np.stack(
|
| 657 |
+
[item["rollout_support_stability"] for item in outcomes], axis=0
|
| 658 |
+
).astype(np.float32),
|
| 659 |
+
"candidate_rollout_reocclusion_target": np.stack(
|
| 660 |
+
[item["rollout_reocclusion_target"] for item in outcomes], axis=0
|
| 661 |
+
).astype(np.float32),
|
| 662 |
+
"candidate_rollout_occluder_contact_map": np.stack(
|
| 663 |
+
[item["rollout_occluder_contact_map"] for item in outcomes], axis=0
|
| 664 |
+
).astype(np.float32),
|
| 665 |
+
"candidate_rollout_grasp_affordance_map": np.stack(
|
| 666 |
+
[item["rollout_grasp_affordance_map"] for item in outcomes], axis=0
|
| 667 |
+
).astype(np.float32),
|
| 668 |
"candidate_retrieval_success": np.asarray([item["retrieval_success"] for item in outcomes], dtype=np.float32),
|
| 669 |
"candidate_final_disturbance_cost": np.asarray(
|
| 670 |
[item["final_disturbance_cost"] for item in outcomes], dtype=np.float32
|
|
|
|
| 759 |
render_state: dict[str, Any],
|
| 760 |
resolution: int,
|
| 761 |
num_templates: int = 32,
|
| 762 |
+
include_depth: bool = False,
|
| 763 |
) -> dict[str, np.ndarray]:
|
| 764 |
dynamics = PROXY_DYNAMICS[proxy_name]
|
| 765 |
opening = float(render_state["opening"])
|
|
|
|
| 841 |
wrist_right[..., 2] = np.clip(wrist_right[..., 2] + 0.08 * step_fraction + 0.06 * right_band, 0.0, 1.0)
|
| 842 |
wrist_right = np.clip(wrist_right, 0.0, 1.0)
|
| 843 |
|
| 844 |
+
outputs = {
|
| 845 |
"front": (front * 255.0).astype(np.uint8),
|
| 846 |
"wrist_left": (wrist_left * 255.0).astype(np.uint8),
|
| 847 |
"wrist_right": (wrist_right * 255.0).astype(np.uint8),
|
| 848 |
}
|
| 849 |
+
if not include_depth:
|
| 850 |
+
return outputs
|
| 851 |
+
|
| 852 |
+
front_depth = np.clip(0.25 + 0.40 * target_depth + 0.15 * disturbance + 0.10 * (1.0 - visibility), 0.0, 1.0)
|
| 853 |
+
target_depth_map = np.clip(0.10 + 0.55 * target_depth, 0.0, 1.0)
|
| 854 |
+
occluder_depth = np.clip(0.30 + 0.20 * disturbance + 0.10 * (1.0 - opening), 0.0, 1.0)
|
| 855 |
+
front_depth_map = np.full((height, width), front_depth, dtype=np.float32)
|
| 856 |
+
front_depth_map[gap_mask] = np.minimum(front_depth_map[gap_mask], occluder_depth)
|
| 857 |
+
front_depth_map[target_mask] = np.minimum(front_depth_map[target_mask], target_depth_map)
|
| 858 |
+
|
| 859 |
+
wrist_left_depth = np.clip(0.35 + 0.25 * target_depth + 0.10 * disturbance, 0.0, 1.0)
|
| 860 |
+
wrist_left_depth_map = np.full((height, width), wrist_left_depth, dtype=np.float32)
|
| 861 |
+
wrist_left_depth_map[left_open] = np.minimum(wrist_left_depth_map[left_open], 0.22 + 0.25 * target_depth)
|
| 862 |
+
wrist_left_depth_map[target_mask] = np.minimum(wrist_left_depth_map[target_mask], target_depth_map)
|
| 863 |
+
|
| 864 |
+
wrist_right_depth = np.clip(0.35 + 0.20 * target_depth + 0.12 * disturbance, 0.0, 1.0)
|
| 865 |
+
wrist_right_depth_map = np.full((height, width), wrist_right_depth, dtype=np.float32)
|
| 866 |
+
right_focus = (right_band * right_clear) > 0.15
|
| 867 |
+
wrist_right_depth_map[right_focus] = np.minimum(wrist_right_depth_map[right_focus], 0.20 + 0.25 * target_depth)
|
| 868 |
+
wrist_right_depth_map[target_mask] = np.minimum(wrist_right_depth_map[target_mask], target_depth_map)
|
| 869 |
+
|
| 870 |
+
outputs.update(
|
| 871 |
+
{
|
| 872 |
+
"front_depth": front_depth_map.astype(np.float32),
|
| 873 |
+
"wrist_left_depth": wrist_left_depth_map.astype(np.float32),
|
| 874 |
+
"wrist_right_depth": wrist_right_depth_map.astype(np.float32),
|
| 875 |
+
"front_depth_valid": np.ones((height, width), dtype=np.float32),
|
| 876 |
+
"wrist_left_depth_valid": np.ones((height, width), dtype=np.float32),
|
| 877 |
+
"wrist_right_depth_valid": np.ones((height, width), dtype=np.float32),
|
| 878 |
+
}
|
| 879 |
+
)
|
| 880 |
+
return outputs
|
code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/__init__.cpython-310.pyc differ
|
|
|
code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc
CHANGED
|
Binary files a/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc and b/code/reveal_vla_bimanual/sim_rlbench/__pycache__/camera_spec.cpython-310.pyc differ
|
|
|