lsnu committed on
Commit
b3a756d
·
verified ·
1 Parent(s): c725033

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/README.md +114 -0
  2. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/docs/adapter_stack.md +87 -0
  3. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/docs/public_benchmark_package.md +73 -0
  4. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc +0 -0
  5. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-311.pyc +0 -0
  6. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc +0 -0
  7. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-311.pyc +0 -0
  8. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/protocols.cpython-310.pyc +0 -0
  9. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/public_benchmark_package.cpython-310.pyc +0 -0
  10. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/public_benchmark_package.cpython-311.pyc +0 -0
  11. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/report.cpython-310.pyc +0 -0
  12. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/report.cpython-311.pyc +0 -0
  13. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_public_benchmark_package.cpython-310.pyc +0 -0
  14. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_public_benchmark_package.cpython-311.pyc +0 -0
  15. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_reveal_benchmark.cpython-310.pyc +0 -0
  16. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_reveal_benchmark.cpython-311.pyc +0 -0
  17. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/public_benchmark_package.py +266 -0
  18. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_maniskill_bridge_retrieval_smoke.py +2037 -0
  19. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_maniskill_pickclutter_smoke.py +2005 -0
  20. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_public_benchmark_package.py +369 -0
  21. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc +0 -0
  22. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc +0 -0
  23. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc +0 -0
  24. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc +0 -0
  25. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc +0 -0
  26. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc +0 -0
  27. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc +0 -0
  28. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc +0 -0
  29. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-310.pyc +0 -0
  30. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-311.pyc +0 -0
  31. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc +0 -0
  32. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc +0 -0
  33. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc +0 -0
  34. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc +0 -0
  35. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc +0 -0
  36. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc +0 -0
  37. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/rvt_backbone.cpython-310.pyc +0 -0
  38. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/rvt_backbone.cpython-311.pyc +0 -0
  39. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc +0 -0
  40. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc +0 -0
  41. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/planner.py +887 -0
  42. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/__init__.cpython-310.pyc +0 -0
  43. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/__init__.cpython-311.pyc +0 -0
  44. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/transforms.cpython-310.pyc +0 -0
  45. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/transforms.cpython-311.pyc +0 -0
  46. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc +0 -0
  47. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc +0 -0
  48. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc +0 -0
  49. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc +0 -0
  50. code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-310.pyc +0 -0
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/README.md ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # reveal_vla_bimanual
2
+
3
+ Simulation-first prototype for a language-conditioned bimanual reveal-and-retrieve policy under elastic occlusion.
4
+
5
+ This repo is not a generalist VLA backbone in the RT-2 / OpenVLA / Octo sense. The current contribution is the reveal-state machinery layered on top of a frozen vision-language encoder.
6
+
7
+ This repo is structured around five top-level modules:
8
+
9
+ - `sim_rlbench/`: RLBench2 / PerAct2 wrappers, dataset hooks, camera setup, and benchmark evaluation helpers.
10
+ - `sim_reveal/`: reveal-proxy environments, scripted teachers, and privileged label extraction.
11
+ - `models/`: shared backbone wrappers, multi-view fusion, bimanual decoder, reveal-state head, world model, and planner.
12
+ - `train/`: trainers, losses, checkpointing, and Hydra/YAML configs.
13
+ - `eval/`: benchmark scripts, ablations, metrics, plots, and report generation.
14
+
15
+ Current bootstrap priorities:
16
+
17
+ 1. Reproduce the RLBench2 / PerAct2 stack with a fixed 3-camera interface.
18
+ 2. Stand up a backbone-only 3-camera policy in the same training/eval harness.
19
+ 3. Add reveal-state supervision and short-horizon planning for synthetic reveal proxies.
20
+
21
+ ## Public benchmark package
22
+
23
+ The repo now includes a concrete public-benchmark package definition for the next-stage fair comparison:
24
+
25
+ - `eval/public_benchmark_package.py`
26
+ - track registry for bag, dense occluded retrieval, cloth retrieval, and the generic anchor
27
+ - same-protocol signatures across `trunk_only`, `adapter_noop`, and `adapter_active`
28
+ - same-data / same-init fairness signatures for `trunk_only_ft` vs `adapter_active_ft`
29
+
30
+ - `eval/run_public_benchmark_package.py`
31
+ - validates normalized result files from multiple public suites
32
+ - checks protocol identity and training fairness
33
+ - aggregates per-track gains, sign-of-life diagnostics, and anchor regressions
34
+
35
+ Write the default manifest to `~/workspace` with:
36
+
37
+ ```bash
38
+ python -m eval.run_public_benchmark_package \
39
+ --write-default-manifest ~/workspace/public_benchmark_package_v1.json
40
+ ```
41
+
42
+ Summarize normalized result files with:
43
+
44
+ ```bash
45
+ python -m eval.run_public_benchmark_package \
46
+ --result /abs/path/result_a.json \
47
+ --result /abs/path/result_b.json \
48
+ --output-dir ~/workspace/reports/public_benchmark_package_v1
49
+ ```
50
+
51
+ Upstream dependencies are kept in `/workspace/third_party` and pinned in `docs/upstream_pins.md`.
52
+
53
+ ## RLBench env A
54
+
55
+ The RLBench / PerAct2 stack is pinned to Python 3.10 and lives in `/workspace/envs/rlbench`.
56
+
57
+ Bring it up with:
58
+
59
+ ```bash
60
+ /workspace/reveal_vla_bimanual/scripts/setup_env_a_rlbench.sh
61
+ /workspace/reveal_vla_bimanual/scripts/setup_rlbench_headless_x.sh
62
+ /workspace/reveal_vla_bimanual/scripts/start_rlbench_x.sh
63
+ ```
64
+
65
+ Verify GPU GL on the headless display:
66
+
67
+ ```bash
68
+ DISPLAY=:99 glxinfo -B
69
+ ```
70
+
71
+ Run the RLBench launch/reset/step smoke test:
72
+
73
+ ```bash
74
+ env \
75
+ DISPLAY=:99 \
76
+ XDG_RUNTIME_DIR=/tmp/runtime-root \
77
+ COPPELIASIM_ROOT=/workspace/assets/coppeliasim_v4_1_0 \
78
+ LD_LIBRARY_PATH=/workspace/system_shims/nvidia$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n1 | cut -d. -f1)/usr/lib/x86_64-linux-gnu:/workspace/system_shims/nvidia$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n1 | cut -d. -f1)/usr/lib/x86_64-linux-gnu/nvidia:/workspace/assets/coppeliasim_v4_1_0 \
79
+ QT_QPA_PLATFORM_PLUGIN_PATH=/workspace/assets/coppeliasim_v4_1_0 \
80
+ /workspace/.tools/micromamba/bin/micromamba run \
81
+ -r /workspace/.micromamba \
82
+ -p /workspace/envs/rlbench \
83
+ python -m sim_rlbench.launch_smoke --headless
84
+ ```
85
+
86
+ The working benchmark interface is fixed to three cameras only:
87
+
88
+ - `front`
89
+ - `wrist_left`
90
+ - `wrist_right`
91
+
92
+ The smoke test covers launch, bimanual task reset, canonical observation extraction, and one bimanual action step in `headless=True`, which is the same mode used by the upstream PerAct2-style training stack.
93
+
94
+ Generate the PerAct2-compatible train command for the fixed 3-camera interface with:
95
+
96
+ ```bash
97
+ micromamba run -r /workspace/.micromamba -p /workspace/envs/rlbench \
98
+ python -m sim_rlbench.smoke_test --print-train-command
99
+ ```
100
+
101
+ Download the published PerAct2 demos into `/workspace/data/rlbench2` with checksum verification:
102
+
103
+ ```bash
104
+ micromamba run -r /workspace/.micromamba -p /workspace/envs/rlbench \
105
+ python -m sim_rlbench.dataset_download --resolution 256 --splits train
106
+ ```
107
+
108
+ If you want the archives unpacked directly into the demo root expected by RLBench, add `--extract`:
109
+
110
+ ```bash
111
+ apt-get install -y squashfs-tools
112
+ micromamba run -r /workspace/.micromamba -p /workspace/envs/rlbench \
113
+ python -m sim_rlbench.dataset_download --resolution 256 --splits train --extract
114
+ ```
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/docs/adapter_stack.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapter Stack
2
+
3
+ This repo now contains a no-op-safe `trunk + adapter` path alongside the legacy monolithic policies.
4
+
5
+ ## Main classes
6
+
7
+ - `models/policy.py`
8
+ - `FoundationTrunkPolicy`
9
+ - `ElasticOcclusionAdapter`
10
+ - `AdapterWrappedPolicy`
11
+
12
+ - `models/backbones.py`
13
+ - `NoOpAdapterCompatibleTrunkOutput`
14
+ - `TrunkInterface`
15
+
16
+ - `models/action_decoder.py`
17
+ - `TaskRoutedProposalPrior`
18
+
19
+ - `models/planner.py`
20
+ - `ElasticFeasibilityGate`
21
+ - `ResidualActionReranker`
22
+ - `AdapterPlanner`
23
+
24
+ - `models/world_model.py`
25
+ - `LightweightRevealStateTransitionModel`
26
+
27
+ - `models/observation_memory.py`
28
+ - `RevealStateCache`
29
+
30
+ ## Trainer modes
31
+
32
+ `train/trainer.py` now supports:
33
+
34
+ - `policy_type: adapter_wrapped`
35
+ - `policy_type: foundation_trunk`
36
+
37
+ Relevant trainer fields:
38
+
39
+ - `training_regime`
40
+ - `eval_mode`
41
+ - `adapter_mode`
42
+ - `adapter_use_transition_model`
43
+ - `adapter_use_task_conditioning`
44
+
45
+ ## Guardrail tests
46
+
47
+ New tests:
48
+
49
+ - `tests/test_trunk_noop_equivalence.py`
50
+ - `tests/test_adapter_gate_blocks_unsafe_retrieve.py`
51
+ - `tests/test_task_specific_loss_masking.py`
52
+ - `tests/test_cloth_specific_metrics_affect_selection.py`
53
+ - `tests/test_general_eval_protocol_is_identical.py`
54
+
55
+ ## Config templates
56
+
57
+ - `train/configs/proxy_adapter_wrapped_clip_base.yaml`
58
+ - `train/configs/proxy_adapter_wrapped_clip_rank_only.yaml`
59
+ - `train/configs/proxy_adapter_wrapped_clip_noop_eval.yaml`
60
+
61
+ ## Benchmark wrappers
62
+
63
+ - `scripts/run_anchor_adapter_ablations.sh`
64
+ - `scripts/run_proxy_adapter_ablations.sh`
65
+ - `scripts/run_target_like_adapter_subset.sh`
66
+ - `eval/public_benchmark_package.py`
67
+ - `eval/run_public_benchmark_package.py`
68
+
69
+ All new configs and scripts default to `~/workspace` outputs and reports.
70
+
71
+ ## Public benchmark package
72
+
73
+ The public benchmark package is the current fair-comparison contract for real benchmarks:
74
+
75
+ - target tracks:
76
+ - `bag_track` -> `BEHAVIOR-1K/unpacking_childs_bag-0`
77
+ - `occlusion_track` -> `ManiSkill/PickClutterYCB-v1`
78
+ - `cloth_track` -> `GarmentLab/grasp_protocol_stacked_garment`
79
+ - anchor track:
80
+ - `anchor_track` -> `AnyBimanual/dual_push_buttons`
81
+
82
+ The package code enforces:
83
+
84
+ - mode-invariant eval protocols per track
85
+ - same-data / same-init fairness for `trunk_only_ft` vs `adapter_active_ft`
86
+ - sign-of-life thresholds on intervention and non-base proposal selection
87
+ - no-regression tolerance on the trusted generic anchor
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/docs/public_benchmark_package.md ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Public Benchmark Package
2
+
3
+ This repo now contains a concrete public-benchmark package for the real-sim phase.
4
+
5
+ ## Tracks
6
+
7
+ - `bag_track`
8
+ - suite: `BEHAVIOR-1K`
9
+ - task: `unpacking_childs_bag-0`
10
+ - `occlusion_track`
11
+ - suite: `ManiSkill 3`
12
+ - task: `PickClutterYCB-v1`
13
+ - `cloth_track`
14
+ - suite: `GarmentLab`
15
+ - task slice: `grasp_protocol_stacked_garment`
16
+ - `anchor_track`
17
+ - suite: `AnyBimanual`
18
+ - task: `dual_push_buttons`
19
+
20
+ ## Enforced fairness
21
+
22
+ - `trunk_only_ft` and `adapter_active_ft` must share:
23
+ - train demos
24
+ - val demos
25
+ - init checkpoint group
26
+ - optimizer
27
+ - LR schedule
28
+ - batch size
29
+ - augmentations
30
+ - early stopping metric
31
+ - max gradient steps
32
+ - unfrozen trunk scope
33
+ - dataset split id
34
+ - all modes on a track must share the same eval protocol signature
35
+ - anchor regressions are bounded by an absolute tolerance of `0.02`
36
+
37
+ ## Normalized result schema
38
+
39
+ Each external benchmark run should be converted to one JSON object with:
40
+
41
+ - `track_id`
42
+ - `adapter_mode`
43
+ - `successes` or `success_rate`
44
+ - `episodes`
45
+ - `seed`
46
+ - `eval_protocol`
47
+ - for target tracks: `train_spec`
48
+ - optional diagnostics:
49
+ - `intervention_rate`
50
+ - `non_base_selection_rate`
51
+ - `steps_to_first_reveal_or_access`
52
+ - `steps_to_retrieve`
53
+ - `disturbance_proxy`
54
+
55
+ ## Commands
56
+
57
+ Write the default manifest:
58
+
59
+ ```bash
60
+ python -m eval.run_public_benchmark_package \
61
+ --write-default-manifest ~/workspace/public_benchmark_package_v1.json
62
+ ```
63
+
64
+ Summarize results:
65
+
66
+ ```bash
67
+ python -m eval.run_public_benchmark_package \
68
+ --result /abs/path/behavior_bag_adapter_active_seed17.json \
69
+ --result /abs/path/behavior_bag_trunk_seed17.json \
70
+ --result /abs/path/maniskill_occlusion_adapter_active_seed17.json \
71
+ --result /abs/path/maniskill_occlusion_trunk_seed17.json \
72
+ --output-dir ~/workspace/reports/public_benchmark_package_v1
73
+ ```
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (288 Bytes). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (322 Bytes). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (22.3 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/protocols.cpython-310.pyc ADDED
Binary file (1.52 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/public_benchmark_package.cpython-310.pyc ADDED
Binary file (8.38 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/public_benchmark_package.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/report.cpython-310.pyc ADDED
Binary file (1.79 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/report.cpython-311.pyc ADDED
Binary file (3.36 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_public_benchmark_package.cpython-310.pyc ADDED
Binary file (12.3 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_public_benchmark_package.cpython-311.pyc ADDED
Binary file (24 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_reveal_benchmark.cpython-310.pyc ADDED
Binary file (29.2 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_reveal_benchmark.cpython-311.pyc ADDED
Binary file (63.1 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/public_benchmark_package.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import asdict, dataclass
5
+ from pathlib import Path
6
+ from typing import Any, Sequence
7
+
8
+
9
+ TARGET_ROLE = "target"
10
+ ANCHOR_ROLE = "anchor"
11
+
12
+ TARGET_TRACK_EVAL_MODES: tuple[str, ...] = (
13
+ "trunk_only_ft",
14
+ "adapter_noop",
15
+ "adapter_active_ft",
16
+ )
17
+ ANCHOR_TRACK_EVAL_MODES: tuple[str, ...] = (
18
+ "trunk_only",
19
+ "adapter_noop",
20
+ "adapter_active",
21
+ )
22
+
23
+ DEFAULT_TARGET_TRAIN_DEMOS = 64
24
+ DEFAULT_TARGET_VAL_DEMOS = 16
25
+ DEFAULT_TARGET_TEST_EPISODES = 100
26
+ DEFAULT_ANCHOR_EPISODES = 25
27
+ DEFAULT_RESOLUTION = 256
28
+ DEFAULT_ANCHOR_TOLERANCE = 0.02
29
+ DEFAULT_SIGN_OF_LIFE_INTERVENTION = 0.15
30
+ DEFAULT_SIGN_OF_LIFE_NON_BASE = 0.15
31
+ DEFAULT_SIGN_OF_LIFE_GAIN = 0.05
32
+
33
+
34
+ @dataclass(frozen=True)
35
+ class PublicBenchmarkTrack:
36
+ track_id: str
37
+ suite: str
38
+ benchmark_task: str
39
+ role: str
40
+ task_family: str
41
+ target_behavior: str
42
+ public_source: str
43
+ notes: str = ""
44
+ success_metric: str = "success_rate"
45
+
46
+
47
+ PUBLIC_BENCHMARK_TRACKS: tuple[PublicBenchmarkTrack, ...] = (
48
+ PublicBenchmarkTrack(
49
+ track_id="bag_track",
50
+ suite="behavior1k",
51
+ benchmark_task="unpacking_childs_bag-0",
52
+ role=TARGET_ROLE,
53
+ task_family="bag_retrieval",
54
+ target_behavior="retrieve target objects from an opened compliant bag or backpack",
55
+ public_source="https://behavior.stanford.edu/knowledgebase/tasks/",
56
+ notes=(
57
+ "Closest public bag retrieval task. Treat as the benchmark-standard bag opening / "
58
+ "retrieval slice."
59
+ ),
60
+ ),
61
+ PublicBenchmarkTrack(
62
+ track_id="occlusion_track",
63
+ suite="maniskill3",
64
+ benchmark_task="PickClutterYCB-v1",
65
+ role=TARGET_ROLE,
66
+ task_family="dense_occluded_retrieval",
67
+ target_behavior="retrieve a target object from dense occluding clutter",
68
+ public_source="https://maniskill.readthedocs.io/en/latest/tasks/table_top_gripper/index.html",
69
+ notes=(
70
+ "Closest maintained public occluded retrieval task. Treat as the canopy / dense "
71
+ "occlusion proxy."
72
+ ),
73
+ ),
74
+ PublicBenchmarkTrack(
75
+ track_id="cloth_track",
76
+ suite="garmentlab",
77
+ benchmark_task="grasp_protocol_stacked_garment",
78
+ role=TARGET_ROLE,
79
+ task_family="cloth_retrieval",
80
+ target_behavior="retrieve a hidden or partially covered object from stacked or cluttered garments",
81
+ public_source="https://garmentlab.readthedocs.io/en/latest/tutorial/realworldbenchmark/index.html",
82
+ notes=(
83
+ "Use the GarmentLab grasp protocol in stacked/clutter layouts as the closest public "
84
+ "cloth retrieval benchmark slice."
85
+ ),
86
+ ),
87
+ PublicBenchmarkTrack(
88
+ track_id="anchor_track",
89
+ suite="anybimanual",
90
+ benchmark_task="dual_push_buttons",
91
+ role=ANCHOR_ROLE,
92
+ task_family="generic_anchor",
93
+ target_behavior="generic bimanual control regression anchor",
94
+ public_source="https://arxiv.org/abs/2412.06779",
95
+ notes="Trusted public anchor on this setup. Keep as a no-regression track only.",
96
+ ),
97
+ )
98
+
99
+
100
+ def public_benchmark_tracks(role: str | None = None) -> list[PublicBenchmarkTrack]:
101
+ if role is None:
102
+ return list(PUBLIC_BENCHMARK_TRACKS)
103
+ return [track for track in PUBLIC_BENCHMARK_TRACKS if track.role == role]
104
+
105
+
106
+ def public_track_by_id(track_id: str) -> PublicBenchmarkTrack:
107
+ normalized = str(track_id).strip()
108
+ for track in PUBLIC_BENCHMARK_TRACKS:
109
+ if track.track_id == normalized:
110
+ return track
111
+ raise KeyError(f"Unknown public benchmark track: {track_id!r}")
112
+
113
+
114
+ def expected_eval_modes(track_id: str) -> tuple[str, ...]:
115
+ track = public_track_by_id(track_id)
116
+ if track.role == TARGET_ROLE:
117
+ return TARGET_TRACK_EVAL_MODES
118
+ return ANCHOR_TRACK_EVAL_MODES
119
+
120
+
121
+ def build_public_eval_protocol(
122
+ *,
123
+ track_id: str,
124
+ eval_mode: str,
125
+ seed: int = 17,
126
+ episodes: int | None = None,
127
+ resolution: int = DEFAULT_RESOLUTION,
128
+ cameras: Sequence[str] = ("front", "left_wrist", "right_wrist"),
129
+ ) -> dict[str, Any]:
130
+ track = public_track_by_id(track_id)
131
+ expected = expected_eval_modes(track.track_id)
132
+ mode = str(eval_mode)
133
+ if mode not in expected:
134
+ raise ValueError(f"Unexpected eval mode {mode!r} for track {track.track_id!r}. Expected one of {expected}.")
135
+ if episodes is None:
136
+ episodes = DEFAULT_TARGET_TEST_EPISODES if track.role == TARGET_ROLE else DEFAULT_ANCHOR_EPISODES
137
+ return {
138
+ "track_id": track.track_id,
139
+ "suite": track.suite,
140
+ "benchmark_task": track.benchmark_task,
141
+ "role": track.role,
142
+ "eval_mode": mode,
143
+ "seed": int(seed),
144
+ "episodes": int(episodes),
145
+ "resolution": int(resolution),
146
+ "cameras": tuple(str(camera) for camera in cameras),
147
+ "observation_stack": "rgbd_3cam",
148
+ "action_horizon": 8,
149
+ "action_space": "bimanual_delta_pose",
150
+ "same_test_episodes": True,
151
+ }
152
+
153
+
154
+ def public_protocol_identity_signature(protocol: dict[str, Any]) -> tuple[object, ...]:
155
+ return (
156
+ protocol["track_id"],
157
+ protocol["suite"],
158
+ protocol["benchmark_task"],
159
+ protocol["role"],
160
+ protocol["seed"],
161
+ protocol["episodes"],
162
+ protocol["resolution"],
163
+ tuple(protocol["cameras"]),
164
+ protocol["observation_stack"],
165
+ protocol["action_horizon"],
166
+ protocol["action_space"],
167
+ protocol["same_test_episodes"],
168
+ )
169
+
170
+
171
+ def build_target_training_spec(
172
+ *,
173
+ track_id: str,
174
+ model_variant: str,
175
+ seed: int,
176
+ train_demos: int = DEFAULT_TARGET_TRAIN_DEMOS,
177
+ val_demos: int = DEFAULT_TARGET_VAL_DEMOS,
178
+ init_checkpoint_group: str = "shared_public_trunk",
179
+ optimizer: str = "adamw",
180
+ learning_rate: float = 3e-4,
181
+ lr_schedule: str = "cosine",
182
+ batch_size: int = 32,
183
+ augmentations: str = "matched_rgbd_aug_v1",
184
+ early_stopping_metric: str = "val_success",
185
+ max_gradient_steps: int = 20_000,
186
+ unfreeze_scope: str = "matched_trunk_scope",
187
+ dataset_split_id: str | None = None,
188
+ ) -> dict[str, Any]:
189
+ track = public_track_by_id(track_id)
190
+ if track.role != TARGET_ROLE:
191
+ raise ValueError(f"Target training spec is only valid for target tracks, got {track_id!r}.")
192
+ return {
193
+ "track_id": track.track_id,
194
+ "suite": track.suite,
195
+ "benchmark_task": track.benchmark_task,
196
+ "model_variant": str(model_variant),
197
+ "seed": int(seed),
198
+ "train_demos": int(train_demos),
199
+ "val_demos": int(val_demos),
200
+ "init_checkpoint_group": str(init_checkpoint_group),
201
+ "optimizer": str(optimizer),
202
+ "learning_rate": float(learning_rate),
203
+ "lr_schedule": str(lr_schedule),
204
+ "batch_size": int(batch_size),
205
+ "augmentations": str(augmentations),
206
+ "early_stopping_metric": str(early_stopping_metric),
207
+ "max_gradient_steps": int(max_gradient_steps),
208
+ "unfreeze_scope": str(unfreeze_scope),
209
+ "dataset_split_id": dataset_split_id or f"{track.track_id}_shared_split_seed{int(seed)}",
210
+ "same_data_policy": True,
211
+ "same_init_policy": True,
212
+ }
213
+
214
+
215
+ def training_fairness_signature(spec: dict[str, Any]) -> tuple[object, ...]:
216
+ return (
217
+ spec["track_id"],
218
+ spec["suite"],
219
+ spec["benchmark_task"],
220
+ spec["seed"],
221
+ spec["train_demos"],
222
+ spec["val_demos"],
223
+ spec["init_checkpoint_group"],
224
+ spec["optimizer"],
225
+ spec["learning_rate"],
226
+ spec["lr_schedule"],
227
+ spec["batch_size"],
228
+ spec["augmentations"],
229
+ spec["early_stopping_metric"],
230
+ spec["max_gradient_steps"],
231
+ spec["unfreeze_scope"],
232
+ spec["dataset_split_id"],
233
+ spec["same_data_policy"],
234
+ spec["same_init_policy"],
235
+ )
236
+
237
+
238
+ def default_public_benchmark_manifest() -> dict[str, Any]:
239
+ return {
240
+ "package_name": "public_reveal_retrieve_package_v1",
241
+ "tracks": [asdict(track) for track in PUBLIC_BENCHMARK_TRACKS],
242
+ "target_track_ids": [track.track_id for track in public_benchmark_tracks(TARGET_ROLE)],
243
+ "anchor_track_ids": [track.track_id for track in public_benchmark_tracks(ANCHOR_ROLE)],
244
+ "target_eval_modes": list(TARGET_TRACK_EVAL_MODES),
245
+ "anchor_eval_modes": list(ANCHOR_TRACK_EVAL_MODES),
246
+ "defaults": {
247
+ "target_train_demos": DEFAULT_TARGET_TRAIN_DEMOS,
248
+ "target_val_demos": DEFAULT_TARGET_VAL_DEMOS,
249
+ "target_test_episodes": DEFAULT_TARGET_TEST_EPISODES,
250
+ "anchor_episodes": DEFAULT_ANCHOR_EPISODES,
251
+ "resolution": DEFAULT_RESOLUTION,
252
+ },
253
+ "thresholds": {
254
+ "anchor_tolerance": DEFAULT_ANCHOR_TOLERANCE,
255
+ "sign_of_life_intervention_rate": DEFAULT_SIGN_OF_LIFE_INTERVENTION,
256
+ "sign_of_life_non_base_selection_rate": DEFAULT_SIGN_OF_LIFE_NON_BASE,
257
+ "sign_of_life_success_gain": DEFAULT_SIGN_OF_LIFE_GAIN,
258
+ },
259
+ }
260
+
261
+
262
+ def write_default_public_benchmark_manifest(output_path: str | Path) -> Path:
263
+ path = Path(output_path)
264
+ path.parent.mkdir(parents=True, exist_ok=True)
265
+ path.write_text(json.dumps(default_public_benchmark_manifest(), indent=2, sort_keys=True) + "\n", encoding="utf-8")
266
+ return path
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_maniskill_bridge_retrieval_smoke.py ADDED
@@ -0,0 +1,2037 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import math
6
+ import os
7
+ import sys
8
+ from collections import deque
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Any, Sequence
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
+ CODE_ROOT = Path(__file__).resolve().parents[1]
19
+ if str(CODE_ROOT) not in sys.path:
20
+ sys.path.insert(0, str(CODE_ROOT))
21
+
22
+
23
def _configure_runtime_env() -> None:
    """Prime Vulkan/ManiSkill environment variables before simulator imports.

    setdefault leaves any caller-provided values intact; MS_ASSET_DIR is forced
    so the ManiSkill asset cache lives at a fixed workspace location.
    """
    os.environ.setdefault("VK_ICD_FILENAMES", "/workspace/runtime/vulkan/icd.d/nvidia_icd_egl.json")
    os.environ.setdefault("VK_LAYER_PATH", "/workspace/runtime/vulkan/implicit_layer.d")
    os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp/runtime-root")
    os.environ["MS_ASSET_DIR"] = "/workspace/.maniskill"


# Must run before gymnasium / mani_skill are imported further down.
_configure_runtime_env()
31
+
32
+ from eval.run_maniskill_pickclutter_smoke import (
33
+ DEFAULT_INIT_CHECKPOINT,
34
+ HISTORY_STEPS,
35
+ MAX_MACRO_STEPS,
36
+ NUM_APPROACH_TEMPLATES,
37
+ PROPRIO_DIM,
38
+ ROLL_OUT_HORIZON,
39
+ SMOKE_ADAPTER_CONFIDENCE_THRESHOLD,
40
+ SMOKE_RETRIEVE_ACCESS_THRESHOLD,
41
+ SMOKE_RETRIEVE_PERSISTENCE_THRESHOLD,
42
+ SMOKE_RETRIEVE_REOCCLUSION_THRESHOLD,
43
+ SMOKE_RETRIEVE_SUPPORT_THRESHOLD,
44
+ STATE_METRIC_MASK,
45
+ STATE_SUPERVISION_METRICS,
46
+ SUPPORT_MODE_HOLD,
47
+ SUPPORT_MODE_PASSIVE,
48
+ SUPPORT_MODE_TRANSFER,
49
+ _aggregate_epoch,
50
+ _apply_smoke_planner_overrides,
51
+ _history_stack,
52
+ _init_history_entry,
53
+ _load_checkpoint,
54
+ _load_init_bundle,
55
+ _make_loader,
56
+ _save_training_checkpoint,
57
+ )
58
+ from eval.run_public_benchmark_package import summarize_public_benchmark_package
59
+ from models.action_decoder import ChunkDecoderConfig, TASK_INDEX, semantic_macro_chunk
60
+ from train.losses import LossWeights, compute_total_loss
61
+ from train.run_experiment import _load_init_checkpoint, _move_batch_to_device
62
+ from train.trainer import BimanualTrainer, TrainerConfig, apply_trainable_parameter_prefixes, build_policy
63
+
64
+ import gymnasium as gym # noqa: E402
65
+ import mani_skill.envs # noqa: E402
66
+ from mani_skill.utils.structs.pose import Pose # noqa: E402
67
+
68
+
69
+ WORKSPACE_ROOT = Path("/workspace/workspace")
70
+ SMOKE_VERSION = "bridge_smoke_v1"
71
+ CAMERA_NAMES = ("front", "left", "right")
72
+ IMAGE_RESOLUTION = 224
73
+ DEFAULT_SEED = 17
74
+ VIEW_VISIBILITY_SCALE = 0.0125
75
+ CLOTH_HIDDEN_SETTLE_STEPS = 25
76
+ CLOTH_SUCCESS_MIN_Y_DELTA = 0.10
77
+ CLOTH_SUCCESS_MIN_PLANAR_DELTA = 0.10
78
+ CLOTH_SUCCESS_MIN_VISIBILITY = 0.45
79
+ CLOTH_FIXED_SOURCE_X = -0.235
80
+ CLOTH_FIXED_SOURCE_Y = -0.094
81
+ CLOTH_FIXED_SOURCE_Z = 0.8748
82
+ CLOTH_FIXED_COVER_X = -0.235
83
+ CLOTH_FIXED_COVER_Y = -0.075
84
+ CLOTH_FIXED_COVER_Z = 0.885
85
+ EXPECTED_PROPOSAL_CANDIDATES = ChunkDecoderConfig().num_candidates
86
+
87
+
88
@dataclass(frozen=True)
class SmokePaths:
    """Filesystem layout for one smoke run."""

    data_dir: Path  # generated dataset artifacts
    output_dir: Path  # training outputs
    report_dir: Path  # evaluation reports
93
+
94
+
95
@dataclass(frozen=True)
class SmokeSpec:
    """Hyper-parameters for the bridge smoke run (dataset sizes and training knobs)."""

    resolution: int = IMAGE_RESOLUTION  # square image side for all cameras
    train_episodes: int = 32
    val_episodes: int = 8
    eval_episodes: int = 50
    dataset_seed: int = DEFAULT_SEED  # seed baked into dataset artifact names
    train_seed: int = DEFAULT_SEED
    history_steps: int = HISTORY_STEPS
    max_macro_steps: int = MAX_MACRO_STEPS
    batch_size: int = 4
    epochs: int = 6
    num_workers: int = 16
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4

    @property
    def seed(self) -> int:
        """Alias for ``train_seed``."""
        return self.train_seed
114
+
115
+
116
@dataclass(frozen=True)
class BridgeTaskSpec:
    """Static description of one bridge retrieval proxy task."""

    key: str  # short task name ("bag" / "cloth")
    env_id: str  # public ManiSkill scene backing the proxy
    track_id: str
    suite: str
    benchmark_task: str  # name used in the public benchmark package
    task_name: str  # task label passed to semantic_macro_chunk
    text_prompt: str  # language instruction for the policy
    mode_order: tuple[str, ...]  # canonical mode vocabulary; index 0 is "base_action"
    reveal_modes: tuple[str, ...]  # modes that expose the hidden object
    transfer_modes: tuple[str, ...]  # modes that insert/hand over the actor
    retrieve_modes: tuple[str, ...]  # modes that extract the object
    notes: str
130
+
131
+
132
# Registry of the two public bridge retrieval proxies, keyed by normalized task name.
TASK_SPECS: dict[str, BridgeTaskSpec] = {
    "bag": BridgeTaskSpec(
        key="bag",
        env_id="PutEggplantInBasketScene-v1",
        track_id="bag_track",
        suite="maniskill3",
        benchmark_task="PutEggplantInBasketRetrievalProxy-v1",
        task_name="bag",
        text_prompt="retrieve the target object from inside the basket and stage it outside the basket",
        mode_order=(
            "base_action",
            "pin_left_rim",
            "pin_right_rim",
            "widen_mouth",
            "maintain_mouth",
            "probe_inside",
            "insert_actor",
            "retrieve",
        ),
        reveal_modes=("pin_left_rim", "pin_right_rim", "widen_mouth", "maintain_mouth", "probe_inside"),
        transfer_modes=("insert_actor",),
        retrieve_modes=("retrieve",),
        notes=(
            "Public ManiSkill bridge scene with custom retrieval initialization. The eggplant is placed inside the "
            "basket region and must be pulled back out to a staging zone."
        ),
    ),
    "cloth": BridgeTaskSpec(
        key="cloth",
        env_id="PutSpoonOnTableClothInScene-v1",
        track_id="cloth_track",
        suite="maniskill3",
        benchmark_task="PutSpoonUnderClothRetrievalProxy-v1",
        task_name="cloth",
        text_prompt="reveal the spoon from under the cloth and retrieve it to the open area",
        mode_order=(
            "base_action",
            "lift_edge",
            "separate_layer",
            "stabilize_fold",
            "maintain_lift",
            "insert_actor",
            "retrieve",
        ),
        reveal_modes=("lift_edge", "separate_layer", "stabilize_fold", "maintain_lift"),
        transfer_modes=("insert_actor",),
        retrieve_modes=("retrieve",),
        notes=(
            "Public ManiSkill bridge scene with custom retrieval initialization. The spoon is placed under the "
            "cloth region and must be revealed and extracted to the open side of the table."
        ),
    ),
}
185
+
186
+
187
def _task_spec(task: str) -> BridgeTaskSpec:
    """Look up a BridgeTaskSpec by name, ignoring case and surrounding whitespace.

    Raises KeyError (with the supported names) for unknown tasks.
    """
    lookup_key = str(task).strip().lower()
    try:
        return TASK_SPECS[lookup_key]
    except KeyError:
        raise KeyError(f"Unsupported task {task!r}. Expected one of {sorted(TASK_SPECS)}.") from None
192
+
193
+
194
def _default_paths(task_spec: BridgeTaskSpec) -> SmokePaths:
    """Derive the default data/output/report directories for a task under WORKSPACE_ROOT."""
    return SmokePaths(
        data_dir=WORKSPACE_ROOT / "data" / "maniskill_bridge_retrieval" / f"{task_spec.key}_{SMOKE_VERSION}",
        output_dir=WORKSPACE_ROOT / "outputs" / f"maniskill_{task_spec.key}_{SMOKE_VERSION}",
        report_dir=WORKSPACE_ROOT / "reports" / f"maniskill_{task_spec.key}_{SMOKE_VERSION}",
    )
200
+
201
+
202
def _dataset_artifact_path(data_dir: Path, basename: str, *, dataset_seed: int) -> Path:
    """Resolve an artifact path, suffixing the stem with the seed when non-default."""
    seed = int(dataset_seed)
    if seed == DEFAULT_SEED:
        return data_dir / basename
    name = Path(basename)
    return data_dir / f"{name.stem}_seed{seed}{name.suffix}"
207
+
208
+
209
+ def _np(value: Any, *, dtype: np.dtype | None = None) -> np.ndarray:
210
+ if isinstance(value, np.ndarray):
211
+ array = value
212
+ elif isinstance(value, torch.Tensor):
213
+ array = value.detach().cpu().numpy()
214
+ else:
215
+ array = np.asarray(value)
216
+ if dtype is not None:
217
+ array = array.astype(dtype, copy=False)
218
+ return array
219
+
220
+
221
def _vec3(value: Any) -> np.ndarray:
    """Return the first three components of *value*, flattened as float32."""
    flat = _np(value, dtype=np.float32).reshape(-1)
    return flat[:3]
223
+
224
+
225
+ def _resize_rgb(rgb: np.ndarray, size: int) -> np.ndarray:
226
+ tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).float()
227
+ resized = F.interpolate(tensor, size=(size, size), mode="bilinear", align_corners=False)
228
+ return resized[0].permute(1, 2, 0).round().clamp(0, 255).to(dtype=torch.uint8).cpu().numpy()
229
+
230
+
231
+ def _resize_single_channel(image: np.ndarray, size: int, *, dtype: np.dtype) -> np.ndarray:
232
+ tensor = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float()
233
+ resized = F.interpolate(tensor, size=(size, size), mode="nearest")
234
+ return resized[0, 0].to(dtype=torch.float32).cpu().numpy().astype(dtype, copy=False)
235
+
236
+
237
def _camera_intrinsic_from_param(param: dict[str, Any]) -> np.ndarray:
    """Extract a 3x3 intrinsic matrix from a sensor-param dict, falling back to identity.

    Keys are probed in priority order; a leading batch dimension is stripped.
    """
    for candidate_key in ("intrinsic_cv", "intrinsic", "cam_intrinsic"):
        if candidate_key not in param:
            continue
        matrix = _np(param[candidate_key], dtype=np.float32)
        if matrix.ndim == 3:
            matrix = matrix[0]
        return matrix
    return np.eye(3, dtype=np.float32)
243
+
244
+
245
def _camera_extrinsic_from_param(param: dict[str, Any]) -> np.ndarray:
    """Extract a 4x4 extrinsic matrix from a sensor-param dict, falling back to identity.

    Keys are probed in priority order; a leading batch dimension is stripped.
    """
    for candidate_key in ("cam2world_gl", "cam2world", "extrinsic_cv", "extrinsic"):
        if candidate_key not in param:
            continue
        matrix = _np(param[candidate_key], dtype=np.float32)
        if matrix.ndim == 3:
            matrix = matrix[0]
        return matrix
    return np.eye(4, dtype=np.float32)
251
+
252
+
253
def _extract_sensor_bundle(obs: dict[str, Any], *, resolution: int) -> dict[str, np.ndarray]:
    """Convert a ManiSkill observation dict into the policy's fixed camera bundle.

    Only the first available sensor is read; its RGB and segmentation are
    resized to ``resolution`` and replicated into every CAMERA_NAMES slot.
    Depth channels are zero-filled placeholders (no depth sensor is read here).
    """
    camera_name = next(iter(obs["sensor_data"].keys()))
    view = obs["sensor_data"][camera_name]
    param = obs["sensor_param"][camera_name]
    rgb = _np(view["rgb"], dtype=np.uint8)
    segmentation = _np(view["segmentation"], dtype=np.int16)
    # Strip a leading batch dimension when the env returns batched tensors.
    rgb = rgb[0] if rgb.ndim == 4 else rgb
    segmentation = segmentation[0] if segmentation.ndim == 4 else segmentation
    # Collapse a trailing channel axis so segmentation is HxW.
    segmentation = segmentation[..., 0] if segmentation.ndim == 3 else segmentation
    rgb_resized = _resize_rgb(rgb, resolution)
    seg_resized = _resize_single_channel(segmentation, resolution, dtype=np.int16)
    intrinsic = _camera_intrinsic_from_param(param)
    extrinsic = _camera_extrinsic_from_param(param)
    # The single rendered view is copied into each named camera slot.
    images = np.stack([rgb_resized.copy() for _ in CAMERA_NAMES], axis=0)
    segmentations = np.stack([seg_resized.copy() for _ in CAMERA_NAMES], axis=0)
    depths = np.zeros((len(CAMERA_NAMES), 1, resolution, resolution), dtype=np.float32)
    depth_valid = np.zeros_like(depths, dtype=np.float32)
    intrinsics = np.stack([intrinsic.copy() for _ in CAMERA_NAMES], axis=0)
    extrinsics = np.stack([extrinsic.copy() for _ in CAMERA_NAMES], axis=0)
    return {
        "images": images,
        "segmentations": segmentations,
        "depths": depths,
        "depth_valid": depth_valid,
        "camera_intrinsics": intrinsics,
        "camera_extrinsics": extrinsics,
    }
280
+
281
+
282
def _build_proprio(env: gym.Env[Any, Any]) -> np.ndarray:
    """Flatten qpos, qvel, TCP pose (pos + quat) and gripper width into a proprio vector.

    The concatenated vector is truncated or zero-padded to exactly PROPRIO_DIM entries.
    """
    base = env.unwrapped
    qpos = _np(base.agent.robot.get_qpos(), dtype=np.float32).reshape(-1)
    qvel = _np(base.agent.robot.get_qvel(), dtype=np.float32).reshape(-1)
    ee_pose = base.agent.robot.links_map["ee_gripper_link"].pose
    tcp_pose = np.concatenate([_vec3(ee_pose.p), _np(ee_pose.q, dtype=np.float32).reshape(-1)[:4]], axis=0)
    # Assumes the last two qpos entries are the gripper fingers — TODO confirm for this robot.
    gripper_width = qpos[-2:].sum(keepdims=True).astype(np.float32)
    flat = np.concatenate([qpos, qvel, tcp_pose, gripper_width], axis=0)
    if flat.shape[0] >= PROPRIO_DIM:
        return flat[:PROPRIO_DIM]
    padded = np.zeros((PROPRIO_DIM,), dtype=np.float32)
    padded[: flat.shape[0]] = flat
    return padded
295
+
296
+
297
def _source_actor(env: gym.Env[Any, Any]) -> Any:
    """Actor for the object to be retrieved (the env's source object)."""
    base = env.unwrapped
    return base.objs[base.source_obj_name]
300
+
301
+
302
def _target_actor(env: gym.Env[Any, Any]) -> Any:
    """Actor for the env's target object (basket in the bag task, cloth cover in the cloth task)."""
    base = env.unwrapped
    return base.objs[base.target_obj_name]
305
+
306
+
307
def _source_position(env: gym.Env[Any, Any]) -> np.ndarray:
    """World XYZ of the source actor as float32."""
    return _vec3(_source_actor(env).pose.p)
309
+
310
+
311
def _target_position(env: gym.Env[Any, Any]) -> np.ndarray:
    """World XYZ of the target actor as float32."""
    return _vec3(_target_actor(env).pose.p)
313
+
314
+
315
def _ee_position(env: gym.Env[Any, Any]) -> np.ndarray:
    """World XYZ of the end-effector gripper link as float32."""
    return _vec3(env.unwrapped.agent.robot.links_map["ee_gripper_link"].pose.p)
317
+
318
+
319
+ def _act_from_world_delta(delta_xyz: Sequence[float]) -> np.ndarray:
320
+ delta = np.asarray(delta_xyz, dtype=np.float32).reshape(3)
321
+ return np.asarray([-delta[0], -delta[1], delta[2]], dtype=np.float32)
322
+
323
+
324
def _step_action(env: gym.Env[Any, Any], delta_xyz: Sequence[float], *, grip: float) -> None:
    """Step the env once with a world-frame EE delta (clipped to +/-0.02 m) and a grip command.

    Action layout is [dx, dy, dz, 0, 0, 0, grip] for a single-env batch; the
    rotation slots stay at zero.
    """
    action = np.zeros((1, 7), dtype=np.float32)
    action[0, :3] = np.clip(_act_from_world_delta(delta_xyz), -0.02, 0.02)
    action[0, 6] = float(np.clip(grip, -1.0, 1.0))
    env.step(action)
329
+
330
+
331
def _hold(env: gym.Env[Any, Any], *, steps: int, grip: float) -> None:
    """Step the env *steps* times with zero motion while applying a grip command."""
    for _ in range(int(steps)):
        action = np.zeros((1, 7), dtype=np.float32)
        action[0, 6] = float(np.clip(grip, -1.0, 1.0))
        env.step(action)
336
+
337
+
338
def _move_ee(env: gym.Env[Any, Any], goal_xyz: Sequence[float], *, grip: float, max_steps: int = 60, tol: float = 0.003) -> dict[str, Any]:
    """Greedily servo the end-effector toward *goal_xyz*.

    Steps at most *max_steps* times, stopping early once the EE is within *tol*
    metres of the goal. Returns the final EE position under key "ee_position".
    """
    goal = np.asarray(goal_xyz, dtype=np.float32).reshape(3)
    for _ in range(int(max_steps)):
        ee = _ee_position(env)
        delta = goal - ee
        if float(np.linalg.norm(delta)) <= float(tol):
            break
        _step_action(env, delta, grip=grip)
    return {"ee_position": _ee_position(env)}
347
+
348
+
349
def _repeat_world_delta(env: gym.Env[Any, Any], delta_xyz: Sequence[float], *, grip: float, steps: int) -> None:
    """Apply the same world-frame delta for *steps* consecutive env steps."""
    delta = np.asarray(delta_xyz, dtype=np.float32).reshape(3)
    for _ in range(int(steps)):
        _step_action(env, delta, grip=grip)
353
+
354
+
355
def _snapshot_env(env: gym.Env[Any, Any]) -> dict[str, Any]:
    """Capture the full simulator state so it can be restored later."""
    return {"state_dict": env.unwrapped.get_state_dict()}
357
+
358
+
359
def _restore_env(env: gym.Env[Any, Any], snapshot: dict[str, Any]) -> None:
    """Restore a simulator state previously captured by _snapshot_env."""
    env.unwrapped.set_state_dict(snapshot["state_dict"])
361
+
362
+
363
def _sync_env_state(src_env: gym.Env[Any, Any], dst_env: gym.Env[Any, Any]) -> None:
    """Copy the simulator state of *src_env* into *dst_env*."""
    _restore_env(dst_env, _snapshot_env(src_env))
365
+
366
+
367
def _canonical_chunks(task_spec: BridgeTaskSpec) -> dict[str, np.ndarray]:
    """Build per-mode prototype action chunks for nearest-prototype mode classification.

    "base_action" is an all-zero (8, 14) chunk; every other mode in
    ``mode_order`` is generated by semantic_macro_chunk from the same zero base.
    """
    base = torch.zeros((1, 8, 14), dtype=torch.float32)
    chunks = {"base_action": base.squeeze(0).numpy().astype(np.float32)}
    for mode_name in task_spec.mode_order[1:]:
        chunk = semantic_macro_chunk(base, task_name=task_spec.task_name, mode_name=mode_name).squeeze(0).cpu().numpy()
        chunks[mode_name] = chunk.astype(np.float32)
    return chunks
374
+
375
+
376
+ def _classify_mode_from_chunk(chunk: np.ndarray, canonical_chunks: dict[str, np.ndarray]) -> str:
377
+ candidate = np.asarray(chunk, dtype=np.float32)
378
+ distances = {
379
+ mode_name: float(np.mean(np.abs(candidate - prototype)))
380
+ for mode_name, prototype in canonical_chunks.items()
381
+ }
382
+ return min(distances, key=distances.get)
383
+
384
+
385
+ def _rng_for_seed(seed: int) -> np.random.Generator:
386
+ return np.random.default_rng(int(seed) + 31)
387
+
388
+
389
def _initialize_proxy_state(env: gym.Env[Any, Any], task_spec: BridgeTaskSpec, *, episode_seed: int) -> None:
    """Teleport the scene into the hidden-object retrieval start configuration.

    Bag: the source object is placed just above the basket centre with a small
    seeded jitter. Cloth: the source and cover are placed at fixed calibrated
    poses and the scene is stepped so the cloth settles over the source.

    The original orientation of each teleported actor is preserved.
    """
    # Removed an unused `base = env.unwrapped` local from the original.
    rng = _rng_for_seed(episode_seed)
    source = _source_actor(env)
    source_q = _np(source.pose.q, dtype=np.float32).reshape(-1)[:4]
    if task_spec.key == "bag":
        center = _target_position(env)
        # Seeded jitter around the basket centre; order of rng draws is load-bearing.
        start = center + np.asarray(
            [
                rng.uniform(-0.006, 0.006),
                rng.uniform(-0.010, 0.004),
                0.010 + rng.uniform(-0.002, 0.002),
            ],
            dtype=np.float32,
        )
    else:
        cover = _target_actor(env)
        start = np.asarray([CLOTH_FIXED_SOURCE_X, CLOTH_FIXED_SOURCE_Y, CLOTH_FIXED_SOURCE_Z], dtype=np.float32)
    source.set_pose(Pose.create_from_pq(p=start, q=source_q))
    if task_spec.key == "cloth":
        # Let the source settle before dropping the cover over it.
        _hold(env, steps=8, grip=1.0)
        cover_q = _np(cover.pose.q, dtype=np.float32).reshape(-1)[:4]
        cover_start = np.asarray([CLOTH_FIXED_COVER_X, CLOTH_FIXED_COVER_Y, CLOTH_FIXED_COVER_Z], dtype=np.float32)
        cover.set_pose(Pose.create_from_pq(p=cover_start, q=cover_q))
        _hold(env, steps=CLOTH_HIDDEN_SETTLE_STEPS, grip=1.0)
        return
    _hold(env, steps=5, grip=1.0)
418
+
419
+
420
def _source_visibility(obs_bundle: dict[str, np.ndarray], actor_id: int) -> float:
    """Visibility of an actor in [0, 1] from its segmentation pixel fraction.

    The raw fraction is divided by VIEW_VISIBILITY_SCALE so a fully visible
    object saturates at 1.0.
    """
    seg = obs_bundle["segmentations"]
    match = float(np.mean(seg == int(actor_id)))
    return float(np.clip(match / VIEW_VISIBILITY_SCALE, 0.0, 1.0))
424
+
425
+
426
def _all_positions(env: gym.Env[Any, Any], task_spec: BridgeTaskSpec) -> dict[str, np.ndarray]:
    """Snapshot source/target world positions (plus a "cover" entry for cloth).

    NOTE(review): for cloth, "cover" duplicates the target actor's position —
    the cloth cover is the target object in that proxy.
    """
    positions = {"source": _source_position(env), "target": _target_position(env)}
    if task_spec.key == "cloth":
        positions["cover"] = _target_position(env)
    return positions
431
+
432
+
433
def _bag_progress(env: gym.Env[Any, Any]) -> float:
    """Heuristic retrieval progress in [0, 1] for the bag task.

    Rewards moving the source away from the basket centre: +x shift, -y pull,
    +z lift and planar separation, each normalized by a hand-tuned scale.
    """
    source = _source_position(env)
    center = _target_position(env)
    x_shift = max(source[0] - center[0], 0.0)
    y_pull = max(center[1] - source[1], 0.0)
    z_lift = max(source[2] - center[2], 0.0)
    planar = float(np.linalg.norm((source - center)[:2]))
    return float(
        np.clip(
            0.35 * (x_shift / 0.05)
            + 0.30 * (y_pull / 0.18)
            + 0.20 * (z_lift / 0.12)
            + 0.15 * (planar / 0.12),
            0.0,
            1.0,
        )
    )
450
+
451
+
452
def _bag_success(env: gym.Env[Any, Any]) -> bool:
    """Bag-task success: source is planar-separated from the basket centre AND
    displaced enough along at least one escape direction (+x, -y, or +z)."""
    source = _source_position(env)
    center = _target_position(env)
    planar = float(np.linalg.norm((source - center)[:2]))
    return bool(
        planar >= 0.035
        and (
            source[0] >= center[0] + 0.020
            or source[1] <= center[1] - 0.050
            or source[2] >= center[2] + 0.050
        )
    )
464
+
465
+
466
def _cloth_progress(
    env: gym.Env[Any, Any],
    *,
    start_positions: dict[str, np.ndarray],
    current_visibility: float,
) -> float:
    """Heuristic retrieval progress in [0, 1] for the cloth task.

    Combines +y displacement and planar displacement of the source from its
    episode start with the current source visibility.
    """
    source = _source_position(env)
    source_start = start_positions["source"]
    y_shift = max(source[1] - source_start[1], 0.0)
    planar = float(np.linalg.norm((source - source_start)[:2]))
    return float(np.clip(0.45 * (y_shift / 0.16) + 0.35 * (planar / 0.16) + 0.20 * current_visibility, 0.0, 1.0))
477
+
478
+
479
def _cloth_success(
    env: gym.Env[Any, Any],
    *,
    start_positions: dict[str, np.ndarray],
    current_visibility: float,
) -> bool:
    """Cloth-task success: the source moved far enough (+y and planar) from its
    start AND is sufficiently visible, per the CLOTH_SUCCESS_* thresholds."""
    source = _source_position(env)
    source_start = start_positions["source"]
    planar = float(np.linalg.norm((source - source_start)[:2]))
    return bool(
        source[1] >= source_start[1] + CLOTH_SUCCESS_MIN_Y_DELTA
        and planar >= CLOTH_SUCCESS_MIN_PLANAR_DELTA
        and current_visibility >= CLOTH_SUCCESS_MIN_VISIBILITY
    )
493
+
494
+
495
def _candidate_metrics(
    env: gym.Env[Any, Any],
    *,
    task_spec: BridgeTaskSpec,
    start_positions: dict[str, np.ndarray],
    current_obs_bundle: dict[str, np.ndarray] | None = None,
) -> dict[str, float]:
    """Score the current env state with task-specific retrieval metrics.

    Returns retrieval_success / disturbance / visibility / clearance / progress,
    all in [0, 1]. Visibility is 0.0 when no observation bundle is supplied.
    """
    source_actor = _source_actor(env)
    actor_id = int(getattr(source_actor, "per_scene_id", -1))
    visibility = 0.0
    if current_obs_bundle is not None:
        visibility = _source_visibility(current_obs_bundle, actor_id)
    if task_spec.key == "bag":
        progress = _bag_progress(env)
        success = float(_bag_success(env))
        # No deformable cover to disturb in the bag proxy.
        disturbance = 0.0
        access = float(np.clip(0.65 * progress + 0.35 * visibility, 0.0, 1.0))
    else:
        progress = _cloth_progress(env, start_positions=start_positions, current_visibility=visibility)
        success = float(_cloth_success(env, start_positions=start_positions, current_visibility=visibility))
        cloth_start = start_positions["cover"]
        cloth_now = _target_position(env)
        cloth_displacement = float(np.linalg.norm((cloth_now - cloth_start)[:2]))
        # Planar cloth drift beyond 0.24 m counts as disturbance, saturating over the next 0.14 m.
        disturbance = float(np.clip(max(cloth_displacement - 0.24, 0.0) / 0.14, 0.0, 1.0))
        access = float(np.clip(0.55 * progress + 0.45 * visibility, 0.0, 1.0))
    return {
        "retrieval_success": success,
        "disturbance": disturbance,
        "visibility": visibility,
        "clearance": access,
        "progress": progress,
    }
527
+
528
+
529
def _execute_bag_mode(env: gym.Env[Any, Any], mode_name: str) -> None:
    """Run the scripted EE primitive for *mode_name* in the bag task.

    Raises KeyError for modes outside the bag vocabulary.
    NOTE(review): grip=-1.0 appears to be the closed-gripper command here —
    confirm against the controller's convention.
    """
    center = _target_position(env)
    source = _source_position(env)
    if mode_name == "retrieve":
        # Approach from above, grasp, lift, carry toward the -y staging zone.
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.08], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.010], dtype=np.float32), grip=1.0, tol=0.002)
        _hold(env, steps=10, grip=-1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.12], dtype=np.float32), grip=-1.0)
        _move_ee(env, source + np.asarray([0.0, -0.18, 0.12], dtype=np.float32), grip=-1.0)
        return
    if mode_name == "insert_actor":
        # Lower the open gripper just above the source and dwell.
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.06], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.018], dtype=np.float32), grip=1.0, tol=0.002)
        _hold(env, steps=4, grip=1.0)
        return
    if mode_name == "probe_inside":
        # Nudge near the source with small lateral sweeps.
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.026], dtype=np.float32), grip=1.0, tol=0.002)
        _repeat_world_delta(env, np.asarray([0.010, -0.004, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name == "widen_mouth":
        _move_ee(env, center + np.asarray([-0.04, 0.01, 0.06], dtype=np.float32), grip=1.0)
        _move_ee(env, center + np.asarray([-0.03, 0.01, 0.028], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.012, -0.004, 0.0], dtype=np.float32), grip=1.0, steps=12)
        return
    if mode_name == "pin_left_rim":
        _move_ee(env, center + np.asarray([-0.03, 0.01, 0.06], dtype=np.float32), grip=1.0)
        _move_ee(env, center + np.asarray([-0.03, 0.01, 0.028], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.006, -0.003, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name == "pin_right_rim":
        _move_ee(env, center + np.asarray([0.03, 0.01, 0.06], dtype=np.float32), grip=1.0)
        _move_ee(env, center + np.asarray([0.03, 0.01, 0.028], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([-0.006, -0.003, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name in {"maintain_mouth", "base_action"}:
        # Hover above the basket centre and hold.
        _move_ee(env, center + np.asarray([0.0, 0.0, 0.09], dtype=np.float32), grip=1.0, max_steps=30, tol=0.006)
        _hold(env, steps=3, grip=1.0)
        return
    raise KeyError(f"Unsupported bag mode {mode_name!r}.")
569
+
570
+
571
def _execute_cloth_mode(env: gym.Env[Any, Any], mode_name: str) -> None:
    """Run the scripted EE primitive for *mode_name* in the cloth task.

    Raises KeyError for modes outside the cloth vocabulary.
    NOTE(review): grip=-1.0 appears to be the closed-gripper command here —
    confirm against the controller's convention.
    """
    cloth = _target_position(env)
    source = _source_position(env)
    if mode_name == "retrieve":
        # Approach from above, grasp, lift, carry toward the +y open area.
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.010], dtype=np.float32), grip=1.0, tol=0.002)
        _hold(env, steps=10, grip=-1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.10], dtype=np.float32), grip=-1.0)
        _move_ee(env, source + np.asarray([0.0, 0.16, 0.10], dtype=np.float32), grip=-1.0)
        return
    if mode_name == "insert_actor":
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.018], dtype=np.float32), grip=1.0, tol=0.002)
        _hold(env, steps=4, grip=1.0)
        return
    if mode_name == "lift_edge":
        # Sweep the near (-y) cloth edge toward +y.
        _move_ee(env, cloth + np.asarray([0.0, -0.03, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, cloth + np.asarray([0.0, -0.03, 0.015], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.0, 0.006, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name == "separate_layer":
        _move_ee(env, cloth + np.asarray([-0.04, 0.0, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, cloth + np.asarray([-0.04, 0.0, 0.015], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.008, 0.002, 0.0], dtype=np.float32), grip=1.0, steps=10)
        return
    if mode_name == "stabilize_fold":
        _move_ee(env, cloth + np.asarray([0.0, 0.03, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, cloth + np.asarray([0.0, 0.03, 0.015], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.0, -0.006, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name in {"maintain_lift", "base_action"}:
        # Hover past the cloth and hold.
        _move_ee(env, cloth + np.asarray([0.0, 0.06, 0.07], dtype=np.float32), grip=1.0, max_steps=30, tol=0.006)
        _hold(env, steps=3, grip=1.0)
        return
    raise KeyError(f"Unsupported cloth mode {mode_name!r}.")
606
+
607
+
608
def _execute_mode(env: gym.Env[Any, Any], task_spec: BridgeTaskSpec, mode_name: str) -> None:
    """Dispatch mode execution to the task-specific scripted executor."""
    executor = _execute_bag_mode if task_spec.key == "bag" else _execute_cloth_mode
    executor(env, mode_name)
613
+
614
+
615
def _mode_support_mode(task_spec: BridgeTaskSpec, mode_name: str, current_support_mode: int) -> int:
    """Map a selected mode onto its support-mode label.

    Families are checked in reveal -> transfer -> retrieve order (matching the
    original precedence); an unlisted mode keeps the current support mode.
    """
    families = (
        (task_spec.reveal_modes, SUPPORT_MODE_HOLD),
        (task_spec.transfer_modes, SUPPORT_MODE_TRANSFER),
        (task_spec.retrieve_modes, SUPPORT_MODE_PASSIVE),
    )
    for family, label in families:
        if mode_name in family:
            return label
    return int(current_support_mode)
623
+
624
+
625
+ def _mode_progress_schedule(task_spec: BridgeTaskSpec, mode_name: str) -> np.ndarray:
626
+ if mode_name in task_spec.reveal_modes:
627
+ return np.asarray([0.18, 0.38, 0.62, 0.84, 1.0], dtype=np.float32)
628
+ if mode_name in task_spec.transfer_modes:
629
+ return np.asarray([0.22, 0.44, 0.66, 0.86, 1.0], dtype=np.float32)
630
+ if mode_name in task_spec.retrieve_modes:
631
+ return np.asarray([0.34, 0.56, 0.76, 0.92, 1.0], dtype=np.float32)
632
+ return np.asarray([0.10, 0.22, 0.34, 0.44, 0.54], dtype=np.float32)
633
+
634
+
635
+ def _scalar_rollout(start: float, end: float, schedule: np.ndarray) -> np.ndarray:
636
+ return np.clip((1.0 - schedule) * float(start) + schedule * float(end), 0.0, 1.0).astype(np.float32)
637
+
638
+
639
def _current_state_targets(
    task_spec: BridgeTaskSpec,
    *,
    obs_bundle: dict[str, np.ndarray],
    candidate_metrics: Sequence[dict[str, float]],
    episode_start_positions: dict[str, np.ndarray],
    selected_mode: str,
    env: gym.Env[Any, Any],
) -> dict[str, Any]:
    """Assemble per-step supervision targets for the state head.

    Combines the live env metrics with per-mode candidate rollout scores to
    produce task-metric scalars, corridor feasibility, persistence horizons,
    the support-mode label and an intervention-confidence target.

    NOTE(review): entries of *candidate_metrics* are assumed to carry
    "candidate_*" keys (utility, actor_feasibility_auc, reveal_achieved,
    hold_persistence, visibility_integral, retrieval_success) produced by an
    upstream rollout scorer, aligned with task_spec.mode_order — confirm.
    """
    metrics_by_name = {mode_name: payload for mode_name, payload in zip(task_spec.mode_order, candidate_metrics)}
    current_metrics = _candidate_metrics(
        env,
        task_spec=task_spec,
        start_positions=episode_start_positions,
        current_obs_bundle=obs_bundle,
    )
    # Clamp the live metrics to [0, 1] before mixing them into derived scores.
    current_disturbance = float(np.clip(current_metrics["disturbance"], 0.0, 1.0))
    current_visibility = float(np.clip(current_metrics["visibility"], 0.0, 1.0))
    current_clearance = float(np.clip(current_metrics["clearance"], 0.0, 1.0))
    current_progress = float(np.clip(current_metrics["progress"], 0.0, 1.0))
    # Hand-tuned weighted combinations; all derived scores stay in [0, 1].
    base_gap = float(np.clip(max(current_clearance, current_progress), 0.0, 1.0))
    support_stability = float(np.clip(1.0 - 0.5 * current_disturbance, 0.0, 1.0))
    hold_quality = float(np.clip(0.5 * (support_stability + max(current_clearance, current_progress)), 0.0, 1.0))
    opening_quality = float(np.clip(0.55 * current_progress + 0.25 * current_clearance + 0.20 * current_visibility, 0.0, 1.0))
    actor_feasibility = float(np.clip(0.6 * current_clearance + 0.4 * max(current_visibility, current_progress), 0.0, 1.0))
    reocclusion_rate = float(np.clip(1.0 - max(current_clearance, current_visibility), 0.0, 1.0))
    insertable_actor_corridor = float(np.clip(0.6 * actor_feasibility + 0.4 * base_gap, 0.0, 1.0))
    insertion_corridor = float(np.clip(0.5 * actor_feasibility + 0.5 * base_gap, 0.0, 1.0))
    layer_separation = float(np.clip(0.7 * base_gap + 0.3 * actor_feasibility, 0.0, 1.0))
    fold_preservation = float(np.clip(1.0 - current_disturbance, 0.0, 1.0))
    lift_too_much_risk = float(np.clip(current_disturbance + 0.5 * max(base_gap - 0.5, 0.0), 0.0, 1.0))
    task_metrics = {
        "opening_quality": opening_quality,
        "actor_feasibility_score": actor_feasibility,
        "gap_width": float(0.03 + 0.21 * base_gap),
        "damage_proxy": current_disturbance,
        "release_collapse_rate": reocclusion_rate,
        "target_visibility_confidence": current_visibility,
        "insertable_actor_corridor": insertable_actor_corridor,
        "insertion_corridor": insertion_corridor,
        "hold_quality": hold_quality,
        "layer_separation_quality": layer_separation,
        "fold_preservation": fold_preservation,
        "top_layer_stability": support_stability,
        "lift_too_much_risk": lift_too_much_risk,
    }

    # Per-family candidate aggregates (reveal uses the best candidate per key).
    base_metrics = metrics_by_name["base_action"]
    insert_metrics = metrics_by_name["insert_actor"]
    retrieve_metrics = metrics_by_name["retrieve"]
    reveal_candidates = [metrics_by_name[mode_name] for mode_name in task_spec.reveal_modes]
    reveal_access = max(candidate["candidate_actor_feasibility_auc"] for candidate in reveal_candidates)
    reveal_reveal = max(candidate["candidate_reveal_achieved"] for candidate in reveal_candidates)
    reveal_hold = max(candidate["candidate_hold_persistence"] for candidate in reveal_candidates)
    reveal_visibility = max(candidate["candidate_visibility_integral"] for candidate in reveal_candidates)

    # Corridor feasibility per family; disturbance acts as a penalty term.
    reveal_corridor = float(
        np.clip(
            0.45 * opening_quality
            + 0.30 * reveal_access
            + 0.15 * reveal_reveal
            + 0.10 * reveal_visibility
            - 0.10 * current_disturbance,
            0.0,
            1.0,
        )
    )
    transfer_corridor = float(
        np.clip(
            0.45 * insertable_actor_corridor
            + 0.30 * insert_metrics["candidate_actor_feasibility_auc"]
            + 0.15 * insert_metrics["candidate_reveal_achieved"]
            + 0.10 * insert_metrics["candidate_visibility_integral"]
            - 0.15 * current_disturbance,
            0.0,
            1.0,
        )
    )
    passive_corridor = float(
        np.clip(
            0.55 * retrieve_metrics["candidate_retrieval_success"]
            + 0.20 * retrieve_metrics["candidate_actor_feasibility_auc"]
            + 0.15 * current_progress
            + 0.10 * current_clearance
            - 0.10 * current_disturbance,
            0.0,
            1.0,
        )
    )
    # One row per family, broadcast across approach templates.
    corridor_feasible = np.stack(
        [
            np.full((NUM_APPROACH_TEMPLATES,), reveal_corridor, dtype=np.float32),
            np.full((NUM_APPROACH_TEMPLATES,), transfer_corridor, dtype=np.float32),
            np.full((NUM_APPROACH_TEMPLATES,), passive_corridor, dtype=np.float32),
        ],
        axis=0,
    )
    # Expected hold duration per family, as a fraction of the rollout horizon.
    persistence_horizon = np.asarray(
        [
            ROLL_OUT_HORIZON * float(np.clip(0.35 * hold_quality + 0.35 * reveal_hold + 0.30 * reveal_corridor, 0.0, 1.0)),
            ROLL_OUT_HORIZON
            * float(
                np.clip(
                    0.30 * hold_quality + 0.35 * insert_metrics["candidate_hold_persistence"] + 0.35 * transfer_corridor,
                    0.0,
                    1.0,
                )
            ),
            ROLL_OUT_HORIZON
            * float(
                np.clip(
                    0.25 * hold_quality + 0.35 * retrieve_metrics["candidate_hold_persistence"] + 0.40 * passive_corridor,
                    0.0,
                    1.0,
                )
            ),
        ],
        dtype=np.float32,
    )
    # Utility margins over base_action drive the support-mode label.
    retrieve_margin = float(retrieve_metrics["candidate_utility"] - base_metrics["candidate_utility"])
    insert_margin = float(insert_metrics["candidate_utility"] - base_metrics["candidate_utility"])
    if selected_mode == "retrieve" or (retrieve_metrics["candidate_retrieval_success"] >= 0.5 and retrieve_margin >= 0.12):
        support_mode = SUPPORT_MODE_PASSIVE
    elif selected_mode == "insert_actor" or (insert_margin >= 0.12 and transfer_corridor >= 0.35):
        support_mode = SUPPORT_MODE_TRANSFER
    elif selected_mode in task_spec.reveal_modes:
        support_mode = SUPPORT_MODE_HOLD
    elif selected_mode == "base_action":
        support_mode = SUPPORT_MODE_PASSIVE if passive_corridor >= 0.55 and retrieve_margin >= 0.03 else SUPPORT_MODE_HOLD
    else:
        support_mode = SUPPORT_MODE_HOLD
    # Confidence target fires only when a non-base mode is selected AND clearly beats base.
    best_non_base_utility = max(float(payload["candidate_utility"]) for payload in candidate_metrics[1:])
    intervention_warranted = selected_mode != "base_action" and best_non_base_utility >= float(base_metrics["candidate_utility"]) + 0.12
    return {
        "support_mode": int(support_mode),
        "corridor_feasible": corridor_feasible,
        "persistence_horizon": persistence_horizon,
        "disturbance_cost": np.float32(current_disturbance),
        "state_confidence_target": np.float32(1.0 if intervention_warranted else 0.0),
        "task_metric_mask": STATE_METRIC_MASK.copy(),
        **{metric_name: np.float32(metric_value) for metric_name, metric_value in task_metrics.items()},
    }
781
+
782
+
783
def _candidate_rollout_targets(
    task_spec: BridgeTaskSpec,
    *,
    mode_name: str,
    state_targets: dict[str, Any],
    candidate_payload: dict[str, float],
) -> dict[str, np.ndarray]:
    """Synthesize per-timestep rollout supervision targets for one candidate mode.

    Interpolates scalar scene statistics from the current state (``state_targets``)
    to the candidate's simulated outcome (``candidate_payload``) along the mode's
    progress schedule, then derives corridor-feasibility, persistence-horizon and
    map-shaped targets from the interpolated curves.

    Args:
        task_spec: Task description providing the mode groupings
            (``reveal_modes`` / ``transfer_modes`` / ``retrieve_modes``).
        mode_name: Candidate mode whose rollout is being annotated.
        state_targets: Current-state scalar targets (visibility, feasibility,
            hold quality, stability, collapse rate, disturbance, support mode).
        candidate_payload: ``candidate_immediate_*`` scalars measured right
            after simulating this mode.

    Returns:
        Dict of ``candidate_rollout_*`` arrays, each with a leading time axis of
        length ``ROLL_OUT_HORIZON`` (assumed — the length comes from
        ``_mode_progress_schedule``; TODO confirm).
    """
    schedule = _mode_progress_schedule(task_spec, mode_name)
    # Start-of-rollout values, taken from the current-state targets and clipped to [0, 1].
    start_visibility = float(state_targets["target_visibility_confidence"])
    start_access = float(state_targets["actor_feasibility_score"])
    start_persistence = float(np.clip(state_targets["hold_quality"], 0.0, 1.0))
    start_support = float(np.clip(state_targets["top_layer_stability"], 0.0, 1.0))
    start_reocclusion = float(np.clip(state_targets["release_collapse_rate"], 0.0, 1.0))
    start_disturbance = float(np.clip(state_targets["disturbance_cost"], 0.0, 1.0))
    start_clearance = float(np.clip(state_targets["actor_feasibility_score"], 0.0, 1.0))
    start_grasp = float(np.clip(max(start_visibility, start_access), 0.0, 1.0))

    # End-of-rollout values, taken from the candidate's immediate post-execution metrics.
    end_visibility = float(np.clip(candidate_payload["candidate_immediate_visibility"], 0.0, 1.0))
    end_access = float(np.clip(candidate_payload["candidate_immediate_access"], 0.0, 1.0))
    end_progress = float(np.clip(candidate_payload["candidate_immediate_progress"], 0.0, 1.0))
    end_disturbance = float(np.clip(candidate_payload["candidate_immediate_disturbance"], 0.0, 1.0))
    end_support = float(np.clip(candidate_payload["candidate_immediate_support_stability"], 0.0, 1.0))
    end_persistence = float(np.clip(candidate_payload["candidate_immediate_hold_persistence"], 0.0, 1.0))
    end_reocclusion = float(np.clip(candidate_payload["candidate_immediate_reocclusion"], 0.0, 1.0))
    end_clearance = float(np.clip(max(end_access, end_progress), 0.0, 1.0))
    end_grasp = float(np.clip(max(end_visibility, 0.5 * end_access + 0.5 * end_progress), 0.0, 1.0))

    # Transfer/retrieve modes begin partway toward their outcome, so pull the
    # start values up toward (a fraction of) the end values.
    if mode_name in task_spec.transfer_modes:
        start_visibility = max(start_visibility, 0.35 * end_visibility)
        start_access = max(start_access, 0.40 * end_access)
        start_persistence = max(start_persistence, 0.45 * end_persistence)
        start_support = max(start_support, 0.50 * end_support)
    elif mode_name in task_spec.retrieve_modes:
        start_visibility = max(start_visibility, 0.55 * end_visibility)
        start_access = max(start_access, 0.70 * end_access)
        start_persistence = max(start_persistence, 0.65 * end_persistence)
        start_support = max(start_support, 0.65 * end_support)
        start_reocclusion = min(start_reocclusion, max(0.4 * end_reocclusion, 0.0))

    # Interpolate every scalar along the mode's progress schedule.
    visibility = _scalar_rollout(start_visibility, end_visibility, schedule)
    access = _scalar_rollout(start_access, end_access, schedule)
    persistence = _scalar_rollout(start_persistence, end_persistence, schedule)
    support = _scalar_rollout(start_support, end_support, schedule)
    reocclusion = _scalar_rollout(start_reocclusion, end_reocclusion, schedule)
    disturbance = _scalar_rollout(start_disturbance, end_disturbance, schedule)
    clearance = _scalar_rollout(start_clearance, end_clearance, schedule)
    grasp = _scalar_rollout(start_grasp, end_grasp, schedule)
    # Per-timestep feasibility of the three support corridors (reveal/transfer/passive)
    # as fixed linear combinations of the interpolated curves.
    reveal_corridor = np.clip(0.38 * visibility + 0.34 * access + 0.22 * support - 0.12 * disturbance, 0.0, 1.0)
    transfer_corridor = np.clip(
        0.30 * visibility + 0.38 * access + 0.18 * persistence + 0.14 * support - 0.12 * disturbance,
        0.0,
        1.0,
    )
    passive_corridor = np.clip(
        0.22 * visibility + 0.42 * access + 0.20 * persistence + 0.16 * grasp - 0.14 * disturbance - 0.10 * reocclusion,
        0.0,
        1.0,
    )
    # Bias the corridor matching the candidate's own family upward, and damp the others.
    if mode_name in task_spec.reveal_modes:
        reveal_corridor = np.clip(reveal_corridor + 0.14, 0.0, 1.0)
        passive_corridor = np.clip(0.75 * passive_corridor, 0.0, 1.0)
    elif mode_name in task_spec.transfer_modes:
        transfer_corridor = np.clip(transfer_corridor + 0.16, 0.0, 1.0)
    elif mode_name in task_spec.retrieve_modes:
        passive_corridor = np.clip(passive_corridor + 0.20, 0.0, 1.0)
        reveal_corridor = np.clip(0.60 * reveal_corridor, 0.0, 1.0)
    else:
        # base_action (or any unclassified mode): damp all three corridors.
        reveal_corridor = np.clip(0.85 * reveal_corridor, 0.0, 1.0)
        transfer_corridor = np.clip(0.75 * transfer_corridor, 0.0, 1.0)
        passive_corridor = np.clip(0.80 * passive_corridor, 0.0, 1.0)
    # Broadcast each corridor across the approach-template axis: (T, 3, NUM_APPROACH_TEMPLATES).
    corridor_feasible = np.stack(
        [
            np.repeat(reveal_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
            np.repeat(transfer_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
            np.repeat(passive_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
        ],
        axis=1,
    ).astype(np.float32)
    # Expected steps each corridor stays usable, expressed in [0, ROLL_OUT_HORIZON].
    persistence_horizon = np.stack(
        [
            np.clip(ROLL_OUT_HORIZON * (0.55 * reveal_corridor + 0.45 * support), 0.0, float(ROLL_OUT_HORIZON)),
            np.clip(ROLL_OUT_HORIZON * (0.50 * transfer_corridor + 0.50 * persistence), 0.0, float(ROLL_OUT_HORIZON)),
            np.clip(ROLL_OUT_HORIZON * (0.55 * passive_corridor + 0.45 * persistence), 0.0, float(ROLL_OUT_HORIZON)),
        ],
        axis=1,
    ).astype(np.float32)
    # Constant support-mode label over the horizon; base_action keeps the current mode.
    support_mode = np.full((ROLL_OUT_HORIZON,), _mode_support_mode(task_spec, mode_name, int(state_targets["support_mode"])), dtype=np.int64)
    if mode_name == "base_action":
        support_mode[:] = int(state_targets["support_mode"])
    return {
        "candidate_rollout_support_mode": support_mode,
        "candidate_rollout_corridor_feasible": corridor_feasible,
        "candidate_rollout_persistence_horizon": persistence_horizon,
        "candidate_rollout_disturbance_cost": disturbance.astype(np.float32),
        # Belief and visibility maps share the same scalar curve, reshaped to (T, 1, 1).
        "candidate_rollout_belief_map": visibility[:, None, None].astype(np.float32),
        "candidate_rollout_visibility_map": visibility[:, None, None].astype(np.float32),
        # Clearance is duplicated over axis 1 — presumably one plane per arm; TODO confirm.
        "candidate_rollout_clearance_map": np.repeat(clearance[:, None, None, None], 2, axis=1).astype(np.float32),
        "candidate_rollout_support_stability": support[:, None, None, None].astype(np.float32),
        "candidate_rollout_reocclusion_target": reocclusion[:, None, None].astype(np.float32),
        "candidate_rollout_occluder_contact_map": np.clip(access * support, 0.0, 1.0)[:, None, None].astype(np.float32),
        "candidate_rollout_grasp_affordance_map": grasp[:, None, None].astype(np.float32),
    }
885
+
886
+
887
+ def _select_expert_mode(
888
+ task_spec: BridgeTaskSpec,
889
+ *,
890
+ decision_step: int,
891
+ candidate_metrics: Sequence[dict[str, float]],
892
+ ) -> str:
893
+ metrics_by_name = {mode_name: payload for mode_name, payload in zip(task_spec.mode_order, candidate_metrics)}
894
+ base_utility = float(metrics_by_name["base_action"]["candidate_utility"])
895
+ reveal_best = max(task_spec.reveal_modes, key=lambda name: float(metrics_by_name[name]["candidate_utility"]))
896
+ transfer_best = max(task_spec.transfer_modes, key=lambda name: float(metrics_by_name[name]["candidate_utility"]))
897
+ retrieve_utility = float(metrics_by_name["retrieve"]["candidate_utility"])
898
+ reveal_best_utility = float(metrics_by_name[reveal_best]["candidate_utility"])
899
+ transfer_best_utility = float(metrics_by_name[transfer_best]["candidate_utility"])
900
+ retrieve_success = float(metrics_by_name["retrieve"]["candidate_retrieval_success"])
901
+
902
+ if int(decision_step) > 0 and retrieve_success >= 0.5:
903
+ return "retrieve"
904
+ if int(decision_step) == 0 and reveal_best_utility >= base_utility - 0.02:
905
+ return reveal_best
906
+ if transfer_best_utility >= reveal_best_utility + 0.05 and transfer_best_utility >= base_utility + 0.02:
907
+ return transfer_best
908
+ if reveal_best_utility >= base_utility - 0.02:
909
+ return reveal_best
910
+ if retrieve_success >= 0.5 and retrieve_utility >= base_utility + 0.02:
911
+ return "retrieve"
912
+ if transfer_best_utility >= base_utility + 0.02:
913
+ return transfer_best
914
+ utilities = np.asarray([payload["candidate_utility"] for payload in candidate_metrics], dtype=np.float32)
915
+ return task_spec.mode_order[int(utilities.argmax())]
916
+
917
+
918
def _evaluate_candidate(
    task_spec: BridgeTaskSpec,
    sim_env: gym.Env[Any, Any],
    obs_env: gym.Env[Any, Any],
    snapshot: dict[str, Any],
    mode_name: str,
) -> dict[str, float]:
    """Simulate one candidate mode from *snapshot* and score its outcome.

    Restores ``sim_env`` to the snapshot, executes ``mode_name``, measures the
    immediate metrics through ``obs_env``'s sensors, and — unless the mode is
    terminal (``retrieve``/``base_action``) or retrieval already succeeded —
    follows up with a scripted ``retrieve`` to obtain the final metrics.
    ``obs_env`` is restored to the snapshot before returning; ``sim_env`` is
    left in its post-execution state (callers re-restore it per candidate).

    Returns:
        Flat dict of ``candidate_*`` floats: a scalar utility plus final and
        immediate proxies for visibility, access, progress, disturbance,
        hold persistence, support stability, and reocclusion.
    """
    _restore_env(sim_env, snapshot)
    start_positions = _all_positions(sim_env, task_spec)
    _execute_mode(sim_env, task_spec, mode_name)
    # Observations are rendered through obs_env, kept in sync with sim_env's state.
    _sync_env_state(sim_env, obs_env)
    after_bundle = _extract_sensor_bundle(obs_env.get_obs(obs_env.get_info()), resolution=IMAGE_RESOLUTION)
    immediate = _candidate_metrics(sim_env, task_spec=task_spec, start_positions=start_positions, current_obs_bundle=after_bundle)
    if not immediate["retrieval_success"] and mode_name not in {"retrieve", "base_action"}:
        # Non-terminal mode that has not yet retrieved: score it by what a
        # follow-up retrieve achieves from the state it produced.
        _execute_mode(sim_env, task_spec, "retrieve")
        _sync_env_state(sim_env, obs_env)
        follow_bundle = _extract_sensor_bundle(obs_env.get_obs(obs_env.get_info()), resolution=IMAGE_RESOLUTION)
        final_metrics = _candidate_metrics(sim_env, task_spec=task_spec, start_positions=start_positions, current_obs_bundle=follow_bundle)
    else:
        final_metrics = immediate
    # Undo the sync side effects on the observation env before returning.
    _restore_env(obs_env, snapshot)
    # Fixed-weight scalarization: success dominates, disturbance penalizes.
    utility = (
        2.5 * final_metrics["retrieval_success"]
        + 1.0 * final_metrics["progress"]
        + 0.5 * final_metrics["clearance"]
        + 0.25 * final_metrics["visibility"]
        - 0.5 * final_metrics["disturbance"]
    )
    return {
        "candidate_retrieval_success": final_metrics["retrieval_success"],
        "candidate_risk": float(np.clip(final_metrics["disturbance"], 0.0, 1.0)),
        "candidate_utility": float(utility),
        "candidate_final_disturbance_cost": final_metrics["disturbance"],
        "candidate_reocclusion_rate": float(np.clip(1.0 - final_metrics["clearance"], 0.0, 1.0)),
        "candidate_visibility_integral": final_metrics["visibility"],
        "candidate_actor_feasibility_auc": final_metrics["clearance"],
        # Reveal counts as achieved on modest progress OR clearance thresholds.
        "candidate_reveal_achieved": float(final_metrics["progress"] > 0.15 or final_metrics["clearance"] > 0.35),
        "candidate_hold_persistence": float(1.0 - final_metrics["disturbance"]),
        "candidate_support_stability_auc": float(1.0 - 0.5 * final_metrics["disturbance"]),
        "candidate_disturbance_auc": final_metrics["disturbance"],
        # "immediate" variants reflect the state right after the mode, before the follow-up retrieve.
        "candidate_immediate_retrieval_success": immediate["retrieval_success"],
        "candidate_immediate_visibility": immediate["visibility"],
        "candidate_immediate_access": immediate["clearance"],
        "candidate_immediate_progress": immediate["progress"],
        "candidate_immediate_reocclusion": float(np.clip(1.0 - immediate["clearance"], 0.0, 1.0)),
        "candidate_immediate_hold_persistence": float(1.0 - immediate["disturbance"]),
        "candidate_immediate_support_stability": float(1.0 - 0.5 * immediate["disturbance"]),
        "candidate_immediate_disturbance": immediate["disturbance"],
    }
967
+
968
+
969
def _cloth_seed_is_valid(env: gym.Env[Any, Any], task_spec: BridgeTaskSpec, *, episode_seed: int) -> bool:
    """Return True when *episode_seed* yields a solvable, initially occluded cloth episode.

    A seed is accepted only if (a) the source object starts with visibility at
    or below ``CLOTH_SUCCESS_MIN_VISIBILITY`` (otherwise the episode is
    trivial) and (b) at least one scripted reveal mode followed by a retrieve
    reaches cloth success.  Mutates *env* (reset + scripted execution), so the
    caller must not rely on its state afterwards.
    """
    env.reset(seed=int(episode_seed))
    _initialize_proxy_state(env, task_spec, episode_seed=int(episode_seed))
    start_positions = _all_positions(env, task_spec)
    obs = env.get_obs(env.get_info())
    obs_bundle = _extract_sensor_bundle(obs, resolution=IMAGE_RESOLUTION)
    actor_id = int(getattr(_source_actor(env), "per_scene_id", -1))
    start_visibility = _source_visibility(obs_bundle, actor_id)
    if start_visibility > CLOTH_SUCCESS_MIN_VISIBILITY:
        # Target already visible at reset: reject the seed.
        return False
    snapshot = _snapshot_env(env)
    for reveal_mode in ("lift_edge", "separate_layer"):
        # Each reveal attempt restarts from the same snapshot so attempts are independent.
        _restore_env(env, snapshot)
        _execute_mode(env, task_spec, reveal_mode)
        _execute_mode(env, task_spec, "retrieve")
        obs = env.get_obs(env.get_info())
        obs_bundle = _extract_sensor_bundle(obs, resolution=IMAGE_RESOLUTION)
        visibility = _source_visibility(obs_bundle, actor_id)
        if _cloth_success(env, start_positions=start_positions, current_visibility=visibility):
            return True
    return False
990
+
991
+
992
def _build_episode_splits(task_spec: BridgeTaskSpec, spec: SmokeSpec) -> dict[str, list[int]]:
    """Build train/val/eval episode-seed lists for the smoke protocol.

    Non-cloth tasks use fixed deterministic seed ranges derived from
    ``spec.dataset_seed`` (offsets 0 / 1000 / 2000).  The cloth task instead
    scans consecutive seeds in a scratch env and keeps only those that pass
    the physics validity check, assigning them to splits in discovery order.

    Raises:
        RuntimeError: when fewer than ``30x`` the required number of cloth
            candidate seeds yields enough valid ones.
    """
    if task_spec.key != "cloth":
        return {
            "train": [spec.dataset_seed * 10_000 + index for index in range(spec.train_episodes)],
            "val": [spec.dataset_seed * 10_000 + 1_000 + index for index in range(spec.val_episodes)],
            "eval": [spec.dataset_seed * 10_000 + 2_000 + index for index in range(spec.eval_episodes)],
        }
    target_total = int(spec.train_episodes + spec.val_episodes + spec.eval_episodes)
    valid_seeds: list[int] = []
    candidate_index = 0
    env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
    try:
        while len(valid_seeds) < target_total:
            episode_seed = spec.dataset_seed * 10_000 + candidate_index
            candidate_index += 1
            if _cloth_seed_is_valid(env, task_spec, episode_seed=episode_seed):
                valid_seeds.append(int(episode_seed))
                # Progress is emitted as one JSON object per line for log scraping.
                print(
                    json.dumps(
                        {
                            "phase": "cloth_seed_selected",
                            "episode_seed": int(episode_seed),
                            "selected": len(valid_seeds),
                            "target_total": target_total,
                        }
                    ),
                    flush=True,
                )
            # Hard cap: give up after scanning 30x the required number of seeds.
            if candidate_index > target_total * 30:
                raise RuntimeError("Unable to find enough physics-valid cloth proxy seeds for the smoke protocol.")
    finally:
        env.close()
    # Valid seeds are partitioned into splits in the order they were found.
    return {
        "train": valid_seeds[: spec.train_episodes],
        "val": valid_seeds[spec.train_episodes : spec.train_episodes + spec.val_episodes],
        "eval": valid_seeds[spec.train_episodes + spec.val_episodes : target_total],
    }
1029
+
1030
+
1031
+ def _save_episode_splits(output_path: Path, payload: dict[str, list[int]]) -> None:
1032
+ output_path.parent.mkdir(parents=True, exist_ok=True)
1033
+ output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
1034
+
1035
+
1036
def _normalize_depth_array(array: np.ndarray) -> np.ndarray:
    """Coerce a cached depth array to float32 and channel-first layout.

    Arrays with four or more dims and a trailing singleton channel (e.g.
    ``(V, H, W, 1)``) have that channel moved in front of the last two spatial
    axes; all other shapes pass through unchanged.
    """
    depth = _np(array, dtype=np.float32)
    trailing_channel = depth.ndim >= 4 and depth.shape[-1] == 1
    if not trailing_channel:
        return depth
    return np.moveaxis(depth, -1, depth.ndim - 3)
1041
+
1042
+
1043
+ def _candidate_pad_indices(source_candidates: int, expected_candidates: int) -> list[int]:
1044
+ if source_candidates <= 0 or source_candidates >= expected_candidates:
1045
+ return []
1046
+ if source_candidates == 1:
1047
+ return [0] * (expected_candidates - source_candidates)
1048
+ cycle = list(range(1, source_candidates))
1049
+ indices: list[int] = []
1050
+ while len(indices) < (expected_candidates - source_candidates):
1051
+ indices.extend(cycle)
1052
+ return indices[: expected_candidates - source_candidates]
1053
+
1054
+
1055
+ def _pad_candidate_axis(
1056
+ value: Any,
1057
+ *,
1058
+ source_candidates: int,
1059
+ expected_candidates: int,
1060
+ pad_indices: Sequence[int],
1061
+ ) -> Any:
1062
+ if source_candidates <= 0 or source_candidates >= expected_candidates:
1063
+ return value
1064
+ if isinstance(value, np.ndarray):
1065
+ if value.ndim == 0 or value.shape[0] != source_candidates:
1066
+ return value
1067
+ if not pad_indices:
1068
+ return value
1069
+ padding = np.take(value, indices=list(pad_indices), axis=0)
1070
+ return np.concatenate([value, padding], axis=0)
1071
+ if isinstance(value, torch.Tensor):
1072
+ if value.ndim == 0 or value.shape[0] != source_candidates:
1073
+ return value
1074
+ if not pad_indices:
1075
+ return value
1076
+ pad_index = torch.as_tensor(list(pad_indices), device=value.device, dtype=torch.long)
1077
+ padding = value.index_select(0, pad_index)
1078
+ return torch.cat([value, padding], dim=0)
1079
+ if isinstance(value, list) and len(value) == source_candidates:
1080
+ padded = list(value)
1081
+ padded.extend(value[index] for index in pad_indices)
1082
+ return padded
1083
+ if isinstance(value, tuple) and len(value) == source_candidates:
1084
+ padded = list(value)
1085
+ padded.extend(value[index] for index in pad_indices)
1086
+ return tuple(padded)
1087
+ return value
1088
+
1089
+
1090
def _normalize_candidate_targets(sample: dict[str, Any]) -> dict[str, Any]:
    """Pad every candidate-indexed field of *sample* up to EXPECTED_PROPOSAL_CANDIDATES.

    The number of source candidates is read from ``candidate_action_chunks``.
    Samples without that key, with a scalar entry, or already holding enough
    candidates are returned as-is; otherwise a shallow copy is returned with
    all ``candidate_*`` and ``proposal_target_*`` fields padded in lockstep.
    """
    chunks = sample.get("candidate_action_chunks")
    if chunks is None:
        return sample
    chunk_array = _np(chunks)
    if chunk_array.ndim == 0:
        return sample
    found = int(chunk_array.shape[0])
    if found >= EXPECTED_PROPOSAL_CANDIDATES:
        return sample
    pad_indices = _candidate_pad_indices(found, EXPECTED_PROPOSAL_CANDIDATES)
    if not pad_indices:
        return sample
    result = dict(sample)
    for key in sample:
        if not key.startswith(("candidate_", "proposal_target_")):
            continue
        result[key] = _pad_candidate_axis(
            sample[key],
            source_candidates=found,
            expected_candidates=EXPECTED_PROPOSAL_CANDIDATES,
            pad_indices=pad_indices,
        )
    return result
1114
+
1115
+
1116
def _normalize_cached_samples(samples: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return shallow copies of *samples* with legacy layouts normalized.

    Depth-related arrays are converted to the current channel-first layout and
    candidate-indexed targets are padded to the expected candidate count.
    """
    depth_keys = ("depths", "depth_valid", "history_depths", "history_depth_valid")
    normalized: list[dict[str, Any]] = []
    for sample in samples:
        updated = dict(sample)
        for key in depth_keys:
            if key in updated:
                updated[key] = _normalize_depth_array(updated[key])
        normalized.append(_normalize_candidate_targets(updated))
    return normalized
1126
+
1127
+
1128
def _collect_split(
    *,
    task_spec: BridgeTaskSpec,
    canonical_chunks: dict[str, np.ndarray],
    split_name: str,
    seeds: Sequence[int],
    spec: SmokeSpec,
    output_path: Path,
) -> dict[str, Any]:
    """Collect one dataset split by rolling out the scripted expert on every seed.

    For each episode seed, a paired observation env and simulation env are reset
    identically; at every macro decision step each candidate mode is simulated
    in the sim env, the expert mode is selected and executed in the obs env, and
    a fully-supervised sample (observations, history, candidate metrics, rollout
    targets, state targets) is appended.  The resulting payload is saved to
    *output_path* with ``torch.save`` and also returned.

    Returns:
        Dict with ``split_name``, ``resolution``, ``history_steps``,
        ``samples`` and per-episode ``episode_records``.
    """
    # Two identical envs: obs_env is the "real" rollout, sim_env is scratch space
    # for candidate evaluation.
    obs_env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
    sim_env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
    samples: list[dict[str, Any]] = []
    episode_records: list[dict[str, Any]] = []
    try:
        for episode_seed in seeds:
            obs, _ = obs_env.reset(seed=int(episode_seed))
            sim_env.reset(seed=int(episode_seed))
            _initialize_proxy_state(obs_env, task_spec, episode_seed=int(episode_seed))
            _sync_env_state(obs_env, sim_env)
            obs = obs_env.get_obs(obs_env.get_info())
            episode_start_positions = _all_positions(obs_env, task_spec)
            # Bounded history buffer of past observations/actions for the history_* fields.
            history: deque[dict[str, Any]] = deque(maxlen=spec.history_steps)
            episode_success = False
            for decision_step in range(spec.max_macro_steps):
                obs_bundle = _extract_sensor_bundle(obs, resolution=spec.resolution)
                proprio = _build_proprio(obs_env)
                # Snapshot of the current state; every candidate rollout restores it.
                snapshot = _snapshot_env(obs_env)
                candidate_metrics = [
                    _evaluate_candidate(task_spec, sim_env, obs_env, snapshot, mode_name) for mode_name in task_spec.mode_order
                ]
                candidate_chunks = np.stack([canonical_chunks[mode_name] for mode_name in task_spec.mode_order], axis=0).astype(np.float32)
                utilities = np.asarray([payload["candidate_utility"] for payload in candidate_metrics], dtype=np.float32)
                selected_mode = _select_expert_mode(task_spec, decision_step=decision_step, candidate_metrics=candidate_metrics)
                state_targets = _current_state_targets(
                    task_spec,
                    env=obs_env,
                    obs_bundle=obs_bundle,
                    candidate_metrics=candidate_metrics,
                    episode_start_positions=episode_start_positions,
                    selected_mode=selected_mode,
                )
                rollout_targets_by_mode = [
                    _candidate_rollout_targets(task_spec, mode_name=mode_name, state_targets=state_targets, candidate_payload=payload)
                    for mode_name, payload in zip(task_spec.mode_order, candidate_metrics)
                ]
                # One fully-supervised training sample for this decision step.
                sample = {
                    # Current sensor snapshot (copied so later env steps cannot alias it).
                    "images": obs_bundle["images"].copy(),
                    "depths": obs_bundle["depths"].copy(),
                    "depth_valid": obs_bundle["depth_valid"].copy(),
                    "camera_intrinsics": obs_bundle["camera_intrinsics"].copy(),
                    "camera_extrinsics": obs_bundle["camera_extrinsics"].copy(),
                    # Fixed-length stacks of the previous decision steps, zero-padded at episode start.
                    "history_images": _history_stack(
                        history,
                        "images",
                        pad_shape=obs_bundle["images"].shape,
                        dtype=np.uint8,
                        history_steps=spec.history_steps,
                    ),
                    "history_depths": _history_stack(
                        history,
                        "depths",
                        pad_shape=obs_bundle["depths"].shape,
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "history_depth_valid": _history_stack(
                        history,
                        "depth_valid",
                        pad_shape=obs_bundle["depth_valid"].shape,
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "history_camera_intrinsics": _history_stack(
                        history,
                        "camera_intrinsics",
                        pad_shape=obs_bundle["camera_intrinsics"].shape,
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "history_camera_extrinsics": _history_stack(
                        history,
                        "camera_extrinsics",
                        pad_shape=obs_bundle["camera_extrinsics"].shape,
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "history_proprio": _history_stack(
                        history,
                        "proprio",
                        pad_shape=(PROPRIO_DIM,),
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    # NOTE(review): pad_shape (14,) looks like the bimanual action dim — confirm.
                    "history_actions": _history_stack(
                        history,
                        "action",
                        pad_shape=(14,),
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "proprio": proprio.astype(np.float32),
                    "language_goal": task_spec.text_prompt,
                    "task_name": task_spec.task_name,
                    "task_id": TASK_INDEX[task_spec.task_name],
                    # Expert supervision: the selected mode's canonical chunk plus all candidates.
                    "action_chunk": canonical_chunks[selected_mode].copy(),
                    "candidate_action_chunks": candidate_chunks,
                    # Per-candidate scalar metrics, one float per mode in task_spec.mode_order.
                    "candidate_retrieval_success": np.asarray([payload["candidate_retrieval_success"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_final_disturbance_cost": np.asarray(
                        [payload["candidate_final_disturbance_cost"] for payload in candidate_metrics],
                        dtype=np.float32,
                    ),
                    "candidate_reocclusion_rate": np.asarray([payload["candidate_reocclusion_rate"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_visibility_integral": np.asarray(
                        [payload["candidate_visibility_integral"] for payload in candidate_metrics],
                        dtype=np.float32,
                    ),
                    "candidate_actor_feasibility_auc": np.asarray(
                        [payload["candidate_actor_feasibility_auc"] for payload in candidate_metrics],
                        dtype=np.float32,
                    ),
                    "candidate_reveal_achieved": np.asarray([payload["candidate_reveal_achieved"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_hold_persistence": np.asarray([payload["candidate_hold_persistence"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_support_stability_auc": np.asarray(
                        [payload["candidate_support_stability_auc"] for payload in candidate_metrics],
                        dtype=np.float32,
                    ),
                    "candidate_disturbance_auc": np.asarray([payload["candidate_disturbance_auc"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_risk": np.asarray([payload["candidate_risk"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_utility": utilities,
                    # Per-candidate rollout targets stacked along a leading candidate axis.
                    "candidate_rollout_support_mode": np.stack(
                        [payload["candidate_rollout_support_mode"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.int64),
                    "candidate_rollout_corridor_feasible": np.stack(
                        [payload["candidate_rollout_corridor_feasible"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_persistence_horizon": np.stack(
                        [payload["candidate_rollout_persistence_horizon"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_disturbance_cost": np.stack(
                        [payload["candidate_rollout_disturbance_cost"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_belief_map": np.stack(
                        [payload["candidate_rollout_belief_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_visibility_map": np.stack(
                        [payload["candidate_rollout_visibility_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_clearance_map": np.stack(
                        [payload["candidate_rollout_clearance_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_support_stability": np.stack(
                        [payload["candidate_rollout_support_stability"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_reocclusion_target": np.stack(
                        [payload["candidate_rollout_reocclusion_target"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_occluder_contact_map": np.stack(
                        [payload["candidate_rollout_occluder_contact_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_grasp_affordance_map": np.stack(
                        [payload["candidate_rollout_grasp_affordance_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    # proposal_target_* mirrors the candidate_rollout_* targets under the
                    # key names the proposal head trains against.
                    "proposal_target_rollout_support_mode": np.stack(
                        [payload["candidate_rollout_support_mode"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.int64),
                    "proposal_target_rollout_corridor_feasible": np.stack(
                        [payload["candidate_rollout_corridor_feasible"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_persistence_horizon": np.stack(
                        [payload["candidate_rollout_persistence_horizon"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_disturbance_cost": np.stack(
                        [payload["candidate_rollout_disturbance_cost"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_belief_map": np.stack(
                        [payload["candidate_rollout_belief_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_visibility_map": np.stack(
                        [payload["candidate_rollout_visibility_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_clearance_map": np.stack(
                        [payload["candidate_rollout_clearance_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_support_stability": np.stack(
                        [payload["candidate_rollout_support_stability"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_reocclusion_target": np.stack(
                        [payload["candidate_rollout_reocclusion_target"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_occluder_contact_map": np.stack(
                        [payload["candidate_rollout_occluder_contact_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_grasp_affordance_map": np.stack(
                        [payload["candidate_rollout_grasp_affordance_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "episode_seed": int(episode_seed),
                    "decision_step": int(decision_step),
                    "selected_mode": selected_mode,
                    **state_targets,
                }
                samples.append(sample)
                # Commit the expert's choice in the real rollout and refresh observations.
                _execute_mode(obs_env, task_spec, selected_mode)
                obs = obs_env.get_obs(obs_env.get_info())
                post_bundle = _extract_sensor_bundle(obs, resolution=spec.resolution)
                history.append(_init_history_entry(obs_bundle, proprio, canonical_chunks[selected_mode]))
                # Task-specific success check; bag uses its own predicate, others use cloth success.
                if (
                    _bag_success(obs_env)
                    if task_spec.key == "bag"
                    else _cloth_success(
                        obs_env,
                        start_positions=episode_start_positions,
                        current_visibility=_source_visibility(post_bundle, int(getattr(_source_actor(obs_env), "per_scene_id", -1))),
                    )
                ):
                    episode_success = True
                    break
            # NOTE(review): "steps" is len(history), but history is a deque with
            # maxlen=spec.history_steps — for episodes longer than the history
            # window this under-reports the true step count; confirm intent.
            episode_records.append({"episode_seed": int(episode_seed), "success": episode_success, "steps": len(history)})
            print(
                json.dumps(
                    {
                        "phase": "collect_episode_complete",
                        "task": task_spec.key,
                        "split": split_name,
                        "episode_seed": int(episode_seed),
                        "success": episode_success,
                        "steps": len(history),
                        "samples_collected": len(samples),
                    }
                ),
                flush=True,
            )
    finally:
        obs_env.close()
        sim_env.close()
    payload = {
        "split_name": split_name,
        "resolution": spec.resolution,
        "history_steps": spec.history_steps,
        "samples": samples,
        "episode_records": episode_records,
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, output_path)
    return payload
1394
+
1395
+
1396
def _manual_train_spec(task_spec: BridgeTaskSpec, variant: str, spec: SmokeSpec) -> dict[str, Any]:
    """Assemble the manifest describing one manual smoke-training run."""
    # One epoch covers ceil(train_episodes / batch_size) optimizer steps.
    steps_per_epoch = math.ceil(max(1, spec.train_episodes) / max(1, spec.batch_size))
    # The split id distinguishes the default dataset seed from overridden ones.
    if int(spec.dataset_seed) == DEFAULT_SEED:
        dataset_split_id = f"{task_spec.key}_{SMOKE_VERSION}_seed{spec.dataset_seed}"
    else:
        dataset_split_id = f"{task_spec.key}_{SMOKE_VERSION}_dataset_seed{spec.dataset_seed}"
    return {
        "track_id": task_spec.track_id,
        "suite": task_spec.suite,
        "benchmark_task": task_spec.benchmark_task,
        "model_variant": str(variant),
        "seed": int(spec.train_seed),
        "train_demos": int(spec.train_episodes),
        "val_demos": int(spec.val_episodes),
        "init_checkpoint_group": str(DEFAULT_INIT_CHECKPOINT),
        "optimizer": "adamw",
        "learning_rate": float(spec.learning_rate),
        "lr_schedule": "constant",
        "batch_size": int(spec.batch_size),
        "augmentations": "none",
        "early_stopping_metric": "val_total",
        "max_gradient_steps": int(spec.epochs * steps_per_epoch),
        "unfreeze_scope": "fusion_memory_decoder",
        "dataset_split_id": dataset_split_id,
        "same_data_policy": True,
        "same_init_policy": True,
    }
1422
+
1423
+
1424
+ def _trainer_config_for_variant(variant: str) -> TrainerConfig:
1425
+ if variant == "trunk_only_ft":
1426
+ return TrainerConfig(
1427
+ policy_type="trunk_only",
1428
+ trainable_parameter_prefixes=("fusion", "memory", "decoder"),
1429
+ eval_mode="trunk_only",
1430
+ )
1431
+ if variant == "adapter_active_ft":
1432
+ return TrainerConfig(
1433
+ policy_type="adapter_wrapped",
1434
+ trainable_parameter_prefixes=(
1435
+ "trunk.fusion",
1436
+ "trunk.memory",
1437
+ "trunk.decoder",
1438
+ "adapter.state_head",
1439
+ "adapter.transition_model",
1440
+ "adapter.proposal_prior",
1441
+ "adapter.planner",
1442
+ ),
1443
+ adapter_mode="adapter_active",
1444
+ eval_mode="adapter_active",
1445
+ adapter_use_transition_model=True,
1446
+ adapter_use_task_conditioning=True,
1447
+ adapter_action_supervision_source="trunk",
1448
+ )
1449
+ raise KeyError(f"Unsupported variant {variant!r}.")
1450
+
1451
+
1452
+ def _loss_weights_for_smoke(task_spec: BridgeTaskSpec) -> LossWeights:
1453
+ return LossWeights(
1454
+ action=1.0,
1455
+ support_mode=0.15,
1456
+ corridor=0.15,
1457
+ persistence=0.08,
1458
+ disturbance=0.08,
1459
+ planner_success=0.20,
1460
+ planner_risk=0.08,
1461
+ planner_ranking=0.20,
1462
+ proposal_reconstruction=0.10,
1463
+ proposal_success=0.12,
1464
+ proposal_ranking=0.15,
1465
+ proposal_mode=0.10,
1466
+ proposal_diversity=0.02,
1467
+ task_metrics=0.15,
1468
+ transition=0.25,
1469
+ gate=0.25,
1470
+ calibration=0.10,
1471
+ proposal_mode_task_filter=[task_spec.task_name],
1472
+ )
1473
+
1474
+
1475
+ def _train_variant(
1476
+ *,
1477
+ task_spec: BridgeTaskSpec,
1478
+ variant: str,
1479
+ train_samples: Sequence[dict[str, Any]],
1480
+ val_samples: Sequence[dict[str, Any]],
1481
+ spec: SmokeSpec,
1482
+ output_dir: Path,
1483
+ ) -> tuple[Path, dict[str, Any]]:
1484
+ policy_config, _init_trainer_cfg, _init_loss_weights = _load_init_bundle()
1485
+ policy_config = _apply_smoke_planner_overrides(policy_config)
1486
+ trainer_config = _trainer_config_for_variant(variant)
1487
+ loss_weights = _loss_weights_for_smoke(task_spec)
1488
+ model = build_policy(policy_config, trainer_config)
1489
+ init_info = _load_init_checkpoint(model, str(DEFAULT_INIT_CHECKPOINT), False)
1490
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1491
+ model = model.to(device)
1492
+ torch.manual_seed(spec.train_seed)
1493
+ if torch.cuda.is_available():
1494
+ torch.cuda.manual_seed_all(spec.train_seed)
1495
+ torch.backends.cuda.matmul.allow_tf32 = True
1496
+ matched = apply_trainable_parameter_prefixes(model, trainer_config)
1497
+ optimizer = torch.optim.AdamW(
1498
+ [parameter for parameter in model.parameters() if parameter.requires_grad],
1499
+ lr=spec.learning_rate,
1500
+ weight_decay=spec.weight_decay,
1501
+ )
1502
+ trainer = BimanualTrainer(model=model, optimizer=optimizer, config=trainer_config)
1503
+ train_loader = _make_loader(train_samples, batch_size=spec.batch_size, num_workers=spec.num_workers, shuffle=True)
1504
+ val_loader = _make_loader(val_samples, batch_size=spec.batch_size, num_workers=spec.num_workers, shuffle=False)
1505
+ best_val = math.inf
1506
+ history: list[dict[str, Any]] = []
1507
+ train_spec = _manual_train_spec(task_spec, variant, spec)
1508
+ train_spec["max_gradient_steps"] = len(train_loader) * spec.epochs
1509
+ for epoch in range(spec.epochs):
1510
+ model.train()
1511
+ train_losses: list[dict[str, float]] = []
1512
+ for batch in train_loader:
1513
+ moved = _move_batch_to_device(batch, device)
1514
+ loss_dict = trainer.training_step(moved, loss_weights=loss_weights)
1515
+ train_losses.append({key: float(value.detach().cpu()) for key, value in loss_dict.items()})
1516
+ model.eval()
1517
+ val_losses: list[dict[str, float]] = []
1518
+ with torch.no_grad():
1519
+ for batch in val_loader:
1520
+ moved = _move_batch_to_device(batch, device)
1521
+ forward_kwargs = {
1522
+ "images": moved["images"],
1523
+ "proprio": moved["proprio"],
1524
+ "texts": moved["texts"],
1525
+ "task_names": moved.get("task_name"),
1526
+ "task_ids": moved.get("task_id"),
1527
+ "history_images": moved.get("history_images"),
1528
+ "history_proprio": moved.get("history_proprio"),
1529
+ "history_actions": moved.get("history_actions"),
1530
+ "depths": moved.get("depths"),
1531
+ "depth_valid": moved.get("depth_valid"),
1532
+ "camera_intrinsics": moved.get("camera_intrinsics"),
1533
+ "camera_extrinsics": moved.get("camera_extrinsics"),
1534
+ "history_depths": moved.get("history_depths"),
1535
+ "history_depth_valid": moved.get("history_depth_valid"),
1536
+ "history_camera_intrinsics": moved.get("history_camera_intrinsics"),
1537
+ "history_camera_extrinsics": moved.get("history_camera_extrinsics"),
1538
+ }
1539
+ if variant == "adapter_active_ft":
1540
+ forward_kwargs["adapter_mode"] = "adapter_active"
1541
+ forward_kwargs["use_transition_model"] = True
1542
+ forward_kwargs["use_task_conditioning"] = True
1543
+ outputs = model(**forward_kwargs)
1544
+ losses = compute_total_loss(outputs, moved, weights=loss_weights)
1545
+ val_losses.append({key: float(value.detach().cpu()) for key, value in losses.items()})
1546
+ train_summary = _aggregate_epoch(train_losses)
1547
+ val_summary = _aggregate_epoch(val_losses)
1548
+ history.append({"epoch": epoch, "train": train_summary, "val": val_summary})
1549
+ print(
1550
+ json.dumps(
1551
+ {
1552
+ "phase": "epoch_complete",
1553
+ "task": task_spec.key,
1554
+ "variant": variant,
1555
+ "epoch": epoch,
1556
+ "train_total": train_summary.get("total", 0.0),
1557
+ "val_total": val_summary.get("total", 0.0),
1558
+ }
1559
+ ),
1560
+ flush=True,
1561
+ )
1562
+ if val_summary.get("total", math.inf) <= best_val:
1563
+ best_val = val_summary["total"]
1564
+ checkpoint_path = _save_training_checkpoint(
1565
+ output_dir=output_dir,
1566
+ experiment_name=f"{task_spec.key}_{variant}_seed{spec.train_seed}",
1567
+ model=model,
1568
+ policy_config=policy_config,
1569
+ trainer_config=trainer_config,
1570
+ loss_weights=loss_weights,
1571
+ history=history,
1572
+ best_val=best_val,
1573
+ train_spec=train_spec,
1574
+ )
1575
+ (output_dir / "summary.json").write_text(
1576
+ json.dumps(
1577
+ {
1578
+ "task": task_spec.key,
1579
+ "variant": variant,
1580
+ "checkpoint_path": str(checkpoint_path),
1581
+ "init_info": init_info,
1582
+ "trainable_parameter_names": matched,
1583
+ "best_val_total": best_val,
1584
+ "history": history,
1585
+ "train_spec": train_spec,
1586
+ },
1587
+ indent=2,
1588
+ )
1589
+ + "\n",
1590
+ encoding="utf-8",
1591
+ )
1592
+ return output_dir / "checkpoint_best.pt", train_spec
1593
+
1594
+
1595
+ def _eval_mode_name(model_output: dict[str, Any], result_mode_name: str, canonical_chunks: dict[str, np.ndarray]) -> tuple[str, bool, bool]:
1596
+ if result_mode_name == "adapter_active_ft" and "proposal_mode_names" in model_output and "best_candidate_indices" in model_output:
1597
+ active_mask = bool(_np(model_output.get("adapter_active_mask", np.asarray([False]))).reshape(-1)[0])
1598
+ if not active_mask:
1599
+ return _classify_mode_from_chunk(_np(model_output["action_mean"])[0], canonical_chunks), False, False
1600
+ best_index = int(_np(model_output["best_candidate_indices"])[0])
1601
+ proposal_mode_names = model_output["proposal_mode_names"][0]
1602
+ if best_index < len(proposal_mode_names):
1603
+ mode_name = str(proposal_mode_names[best_index])
1604
+ else:
1605
+ mode_name = _classify_mode_from_chunk(_np(model_output["action_mean"])[0], canonical_chunks)
1606
+ return mode_name, active_mask, bool(best_index > 0)
1607
+ return _classify_mode_from_chunk(_np(model_output["action_mean"])[0], canonical_chunks), False, False
1608
+
1609
+
1610
+ def _manual_eval_protocol(task_spec: BridgeTaskSpec, *, eval_mode: str, spec: SmokeSpec, episodes: int) -> dict[str, Any]:
1611
+ return {
1612
+ "track_id": task_spec.track_id,
1613
+ "suite": task_spec.suite,
1614
+ "benchmark_task": task_spec.benchmark_task,
1615
+ "role": "target",
1616
+ "eval_mode": eval_mode,
1617
+ "seed": int(spec.dataset_seed),
1618
+ "episodes": int(episodes),
1619
+ "resolution": int(spec.resolution),
1620
+ "cameras": tuple(CAMERA_NAMES),
1621
+ "observation_stack": "rgb_triplicate_zero_depth",
1622
+ "action_horizon": 8,
1623
+ "action_space": "widowx_delta_pose",
1624
+ "same_test_episodes": True,
1625
+ }
1626
+
1627
+
1628
+ def _batch_from_obs(
1629
+ task_spec: BridgeTaskSpec,
1630
+ obs_bundle: dict[str, np.ndarray],
1631
+ proprio: np.ndarray,
1632
+ history: Sequence[dict[str, Any]],
1633
+ device: torch.device,
1634
+ ) -> dict[str, Any]:
1635
+ return {
1636
+ "images": torch.from_numpy(obs_bundle["images"]).permute(0, 3, 1, 2).unsqueeze(0).float().div(255.0).to(device),
1637
+ "depths": torch.from_numpy(obs_bundle["depths"]).unsqueeze(0).float().to(device),
1638
+ "depth_valid": torch.from_numpy(obs_bundle["depth_valid"]).unsqueeze(0).float().to(device),
1639
+ "camera_intrinsics": torch.from_numpy(obs_bundle["camera_intrinsics"]).unsqueeze(0).float().to(device),
1640
+ "camera_extrinsics": torch.from_numpy(obs_bundle["camera_extrinsics"]).unsqueeze(0).float().to(device),
1641
+ "history_images": torch.from_numpy(
1642
+ _history_stack(history, "images", pad_shape=obs_bundle["images"].shape, dtype=np.uint8, history_steps=HISTORY_STEPS)
1643
+ ).permute(0, 1, 4, 2, 3).unsqueeze(0).float().div(255.0).to(device),
1644
+ "history_depths": torch.from_numpy(
1645
+ _history_stack(history, "depths", pad_shape=obs_bundle["depths"].shape, dtype=np.float32, history_steps=HISTORY_STEPS)
1646
+ ).unsqueeze(0).float().to(device),
1647
+ "history_depth_valid": torch.from_numpy(
1648
+ _history_stack(
1649
+ history,
1650
+ "depth_valid",
1651
+ pad_shape=obs_bundle["depth_valid"].shape,
1652
+ dtype=np.float32,
1653
+ history_steps=HISTORY_STEPS,
1654
+ )
1655
+ ).unsqueeze(0).float().to(device),
1656
+ "history_camera_intrinsics": torch.from_numpy(
1657
+ _history_stack(
1658
+ history,
1659
+ "camera_intrinsics",
1660
+ pad_shape=obs_bundle["camera_intrinsics"].shape,
1661
+ dtype=np.float32,
1662
+ history_steps=HISTORY_STEPS,
1663
+ )
1664
+ ).unsqueeze(0).float().to(device),
1665
+ "history_camera_extrinsics": torch.from_numpy(
1666
+ _history_stack(
1667
+ history,
1668
+ "camera_extrinsics",
1669
+ pad_shape=obs_bundle["camera_extrinsics"].shape,
1670
+ dtype=np.float32,
1671
+ history_steps=HISTORY_STEPS,
1672
+ )
1673
+ ).unsqueeze(0).float().to(device),
1674
+ "history_proprio": torch.from_numpy(
1675
+ _history_stack(history, "proprio", pad_shape=(PROPRIO_DIM,), dtype=np.float32, history_steps=HISTORY_STEPS)
1676
+ ).unsqueeze(0).float().to(device),
1677
+ "history_actions": torch.from_numpy(
1678
+ _history_stack(history, "action", pad_shape=(14,), dtype=np.float32, history_steps=HISTORY_STEPS)
1679
+ ).unsqueeze(0).float().to(device),
1680
+ "proprio": torch.from_numpy(proprio).unsqueeze(0).float().to(device),
1681
+ "texts": [task_spec.text_prompt],
1682
+ "task_names": [task_spec.task_name],
1683
+ "task_ids": torch.as_tensor([TASK_INDEX[task_spec.task_name]], dtype=torch.long, device=device),
1684
+ }
1685
+
1686
+
1687
+ def _evaluate_checkpoint(
1688
+ *,
1689
+ task_spec: BridgeTaskSpec,
1690
+ canonical_chunks: dict[str, np.ndarray],
1691
+ checkpoint_path: Path,
1692
+ adapter_mode: str,
1693
+ result_mode_name: str,
1694
+ seeds: Sequence[int],
1695
+ report_path: Path,
1696
+ train_spec: dict[str, Any] | None,
1697
+ planner_overrides: dict[str, float] | None,
1698
+ ) -> dict[str, Any]:
1699
+ model, checkpoint = _load_checkpoint(
1700
+ checkpoint_path,
1701
+ adapter_mode=adapter_mode if adapter_mode != "trunk_only" else None,
1702
+ planner_overrides=planner_overrides,
1703
+ )
1704
+ device = next(model.parameters()).device
1705
+ obs_env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
1706
+ sim_env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
1707
+ successes: list[int] = []
1708
+ episode_records: list[dict[str, Any]] = []
1709
+ reveal_steps: list[int] = []
1710
+ retrieve_steps: list[int] = []
1711
+ disturbance_values: list[float] = []
1712
+ intervention_events = 0
1713
+ non_base_events = 0
1714
+ total_decisions = 0
1715
+ try:
1716
+ for episode_seed in seeds:
1717
+ obs, _ = obs_env.reset(seed=int(episode_seed))
1718
+ sim_env.reset(seed=int(episode_seed))
1719
+ _initialize_proxy_state(obs_env, task_spec, episode_seed=int(episode_seed))
1720
+ _sync_env_state(obs_env, sim_env)
1721
+ obs = obs_env.get_obs(obs_env.get_info())
1722
+ history: deque[dict[str, Any]] = deque(maxlen=HISTORY_STEPS)
1723
+ episode_start_positions = _all_positions(obs_env, task_spec)
1724
+ success = False
1725
+ first_reveal_step: int | None = None
1726
+ first_retrieve_step: int | None = None
1727
+ episode_disturbance: list[float] = []
1728
+ for decision_step in range(MAX_MACRO_STEPS):
1729
+ obs_bundle = _extract_sensor_bundle(obs, resolution=IMAGE_RESOLUTION)
1730
+ proprio = _build_proprio(obs_env)
1731
+ batch = _batch_from_obs(task_spec, obs_bundle, proprio, list(history), device)
1732
+ with torch.no_grad():
1733
+ if adapter_mode == "trunk_only":
1734
+ outputs = model(**batch)
1735
+ else:
1736
+ outputs = model(**batch, adapter_mode=adapter_mode, use_transition_model=True, use_task_conditioning=True)
1737
+ selected_mode, active_mask, non_base = _eval_mode_name(outputs, result_mode_name, canonical_chunks)
1738
+ start_positions = _all_positions(obs_env, task_spec)
1739
+ _sync_env_state(obs_env, sim_env)
1740
+ _execute_mode(sim_env, task_spec, selected_mode)
1741
+ _sync_env_state(sim_env, obs_env)
1742
+ obs = obs_env.get_obs(obs_env.get_info())
1743
+ post_bundle = _extract_sensor_bundle(obs, resolution=IMAGE_RESOLUTION)
1744
+ end_metrics = _candidate_metrics(
1745
+ obs_env,
1746
+ task_spec=task_spec,
1747
+ start_positions=start_positions,
1748
+ current_obs_bundle=post_bundle,
1749
+ )
1750
+ history.append(_init_history_entry(obs_bundle, proprio, canonical_chunks.get(selected_mode, canonical_chunks["base_action"])))
1751
+ total_decisions += 1
1752
+ intervention_events += int(active_mask)
1753
+ non_base_events += int(non_base)
1754
+ episode_disturbance.append(end_metrics["disturbance"])
1755
+ if selected_mode != "retrieve" and selected_mode != "base_action" and first_reveal_step is None:
1756
+ first_reveal_step = decision_step + 1
1757
+ if selected_mode == "retrieve" and first_retrieve_step is None:
1758
+ first_retrieve_step = decision_step + 1
1759
+ if (
1760
+ _bag_success(obs_env)
1761
+ if task_spec.key == "bag"
1762
+ else _cloth_success(
1763
+ obs_env,
1764
+ start_positions=episode_start_positions,
1765
+ current_visibility=end_metrics["visibility"],
1766
+ )
1767
+ ):
1768
+ success = True
1769
+ break
1770
+ successes.append(int(success))
1771
+ if first_reveal_step is not None:
1772
+ reveal_steps.append(first_reveal_step)
1773
+ if first_retrieve_step is not None:
1774
+ retrieve_steps.append(first_retrieve_step)
1775
+ disturbance_values.append(float(np.mean(episode_disturbance)) if episode_disturbance else 0.0)
1776
+ episode_records.append(
1777
+ {
1778
+ "episode_seed": int(episode_seed),
1779
+ "success": success,
1780
+ "steps": len(history),
1781
+ "first_reveal_step": first_reveal_step,
1782
+ "first_retrieve_step": first_retrieve_step,
1783
+ "episode_disturbance": float(np.mean(episode_disturbance)) if episode_disturbance else 0.0,
1784
+ }
1785
+ )
1786
+ print(
1787
+ json.dumps(
1788
+ {
1789
+ "phase": "eval_episode_complete",
1790
+ "task": task_spec.key,
1791
+ "adapter_mode": result_mode_name,
1792
+ "episode_seed": int(episode_seed),
1793
+ "success": success,
1794
+ "steps": len(history),
1795
+ }
1796
+ ),
1797
+ flush=True,
1798
+ )
1799
+ finally:
1800
+ obs_env.close()
1801
+ sim_env.close()
1802
+ payload = {
1803
+ "track_id": task_spec.track_id,
1804
+ "suite": task_spec.suite,
1805
+ "benchmark_task": task_spec.benchmark_task,
1806
+ "role": "target",
1807
+ "adapter_mode": result_mode_name,
1808
+ "episodes": len(seeds),
1809
+ "successes": successes,
1810
+ "success_rate": float(np.mean(successes)) if successes else 0.0,
1811
+ "intervention_rate": float(intervention_events / max(1, total_decisions)),
1812
+ "non_base_selection_rate": float(non_base_events / max(1, total_decisions)),
1813
+ "steps_to_first_reveal_or_access": float(np.mean(reveal_steps)) if reveal_steps else float(MAX_MACRO_STEPS),
1814
+ "steps_to_retrieve": float(np.mean(retrieve_steps)) if retrieve_steps else float(MAX_MACRO_STEPS),
1815
+ "disturbance_proxy": float(np.mean(disturbance_values)) if disturbance_values else 0.0,
1816
+ "episode_records": episode_records,
1817
+ "eval_protocol": _manual_eval_protocol(task_spec, eval_mode=result_mode_name, spec=SmokeSpec(), episodes=len(seeds)),
1818
+ "proxy_notes": task_spec.notes,
1819
+ }
1820
+ if train_spec is not None:
1821
+ payload["train_spec"] = train_spec
1822
+ report_path.parent.mkdir(parents=True, exist_ok=True)
1823
+ report_path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
1824
+ return payload
1825
+
1826
+
1827
+ def _patch_summary_metadata(summary: dict[str, Any], task_spec: BridgeTaskSpec) -> dict[str, Any]:
1828
+ patched = json.loads(json.dumps(summary))
1829
+ track_payload = patched.get("tracks", {}).get(task_spec.track_id)
1830
+ if track_payload is not None:
1831
+ track_payload["suite"] = task_spec.suite
1832
+ track_payload["benchmark_task"] = task_spec.benchmark_task
1833
+ track_payload["notes"] = task_spec.notes
1834
+ track_payload["public_source"] = f"ManiSkill public scene proxy: {task_spec.env_id}"
1835
+ track_payload["task_family"] = f"{task_spec.key}_retrieval_proxy"
1836
+ track_payload["target_behavior"] = task_spec.text_prompt
1837
+ return patched
1838
+
1839
+
1840
+ def _summarize_task(task_spec: BridgeTaskSpec, results: Sequence[dict[str, Any]], output_dir: Path) -> dict[str, Any]:
1841
+ summary = summarize_public_benchmark_package(list(results), allow_partial=True)
1842
+ summary = _patch_summary_metadata(summary, task_spec)
1843
+ output_dir.mkdir(parents=True, exist_ok=True)
1844
+ json_path = output_dir / "public_benchmark_package_summary.json"
1845
+ md_path = output_dir / "public_benchmark_package_summary.md"
1846
+ json_path.write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8")
1847
+ track_payload = summary["tracks"][task_spec.track_id]
1848
+ lines = [
1849
+ f"# ManiSkill {task_spec.key.capitalize()} Retrieval Smoke Summary",
1850
+ "",
1851
+ f"- benchmark_task: {task_spec.benchmark_task}",
1852
+ f"- target_macro_average_delta: {summary['target_macro_average_delta']:.3f}",
1853
+ f"- headline_pass: {summary['headline_pass']}",
1854
+ f"- sign_of_life_pass: {summary['sign_of_life_pass']}",
1855
+ "",
1856
+ f"## {task_spec.track_id}",
1857
+ f"- delta_active_vs_trunk: {track_payload.get('delta_active_vs_trunk', 0.0):.3f}",
1858
+ f"- delta_noop_vs_trunk: {track_payload.get('delta_noop_vs_trunk', 0.0):.3f}",
1859
+ f"- signs_of_life: {track_payload.get('signs_of_life', False)}",
1860
+ ]
1861
+ if "delta_active_vs_trunk_ci95" in track_payload:
1862
+ low, high = track_payload["delta_active_vs_trunk_ci95"]
1863
+ lines.append(f"- delta_active_vs_trunk_ci95: [{low:.3f}, {high:.3f}]")
1864
+ for mode, mode_payload in track_payload["modes"].items():
1865
+ lines.append(f"- {mode}: mean_success={mode_payload['mean_success']:.3f}")
1866
+ lines.append("")
1867
+ md_path.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")
1868
+ return summary
1869
+
1870
+
1871
+ def _parse_args() -> argparse.Namespace:
1872
+ parser = argparse.ArgumentParser(description="Run a fair bridge-scene retrieval smoke for bag or cloth proxy tasks.")
1873
+ parser.add_argument("--task", choices=sorted(TASK_SPECS), required=True)
1874
+ parser.add_argument("--dataset-seed", type=int, default=DEFAULT_SEED)
1875
+ parser.add_argument("--train-seed", type=int, default=DEFAULT_SEED)
1876
+ parser.add_argument("--eval-split", choices=("val", "eval"), default="eval")
1877
+ parser.add_argument("--report-dir", type=Path, default=None)
1878
+ parser.add_argument("--skip-collection", action="store_true")
1879
+ parser.add_argument("--skip-train", action="store_true")
1880
+ parser.add_argument("--skip-eval", action="store_true")
1881
+ parser.add_argument("--reuse-dataset", action="store_true")
1882
+ parser.add_argument("--reuse-checkpoints", action="store_true")
1883
+ parser.add_argument("--adapter-confidence-threshold", type=float, default=None)
1884
+ parser.add_argument("--retrieve-access-threshold", type=float, default=None)
1885
+ parser.add_argument("--retrieve-persistence-threshold", type=float, default=None)
1886
+ parser.add_argument("--retrieve-support-threshold", type=float, default=None)
1887
+ parser.add_argument("--retrieve-reocclusion-threshold", type=float, default=None)
1888
+ parser.add_argument("--planner-mode-preference-bonus", type=float, default=None)
1889
+ parser.add_argument("--planner-premature-retrieve-penalty", type=float, default=None)
1890
+ parser.add_argument("--planner-premature-insert-penalty", type=float, default=None)
1891
+ parser.add_argument("--planner-premature-occlusion-sweep-penalty", type=float, default=None)
1892
+ parser.add_argument("--planner-premature-maintain-penalty", type=float, default=None)
1893
+ parser.add_argument("--planner-retrieve-stage-access-threshold", type=float, default=None)
1894
+ parser.add_argument("--planner-retrieve-stage-reveal-threshold", type=float, default=None)
1895
+ parser.add_argument("--planner-retrieve-stage-persistence-threshold", type=float, default=None)
1896
+ parser.add_argument("--planner-retrieve-stage-support-threshold", type=float, default=None)
1897
+ parser.add_argument("--planner-insert-stage-access-threshold", type=float, default=None)
1898
+ parser.add_argument("--planner-insert-stage-visibility-threshold", type=float, default=None)
1899
+ parser.add_argument("--planner-insert-stage-support-threshold", type=float, default=None)
1900
+ parser.add_argument("--planner-occlusion-maintain-gap-min-access", type=float, default=None)
1901
+ parser.add_argument("--planner-occlusion-maintain-gap-min-visibility", type=float, default=None)
1902
+ return parser.parse_args()
1903
+
1904
+
1905
+ def _planner_overrides_from_args(args: argparse.Namespace) -> dict[str, float]:
1906
+ overrides = {
1907
+ "adapter_confidence_threshold": SMOKE_ADAPTER_CONFIDENCE_THRESHOLD,
1908
+ "retrieve_access_threshold": SMOKE_RETRIEVE_ACCESS_THRESHOLD,
1909
+ "retrieve_persistence_threshold": SMOKE_RETRIEVE_PERSISTENCE_THRESHOLD,
1910
+ "retrieve_support_threshold": SMOKE_RETRIEVE_SUPPORT_THRESHOLD,
1911
+ "retrieve_reocclusion_threshold": SMOKE_RETRIEVE_REOCCLUSION_THRESHOLD,
1912
+ }
1913
+ optional_pairs = (
1914
+ ("adapter_confidence_threshold", args.adapter_confidence_threshold),
1915
+ ("retrieve_access_threshold", args.retrieve_access_threshold),
1916
+ ("retrieve_persistence_threshold", args.retrieve_persistence_threshold),
1917
+ ("retrieve_support_threshold", args.retrieve_support_threshold),
1918
+ ("retrieve_reocclusion_threshold", args.retrieve_reocclusion_threshold),
1919
+ ("mode_preference_bonus", args.planner_mode_preference_bonus),
1920
+ ("premature_retrieve_penalty", args.planner_premature_retrieve_penalty),
1921
+ ("premature_insert_penalty", args.planner_premature_insert_penalty),
1922
+ ("premature_occlusion_sweep_penalty", args.planner_premature_occlusion_sweep_penalty),
1923
+ ("premature_maintain_penalty", args.planner_premature_maintain_penalty),
1924
+ ("retrieve_stage_access_threshold", args.planner_retrieve_stage_access_threshold),
1925
+ ("retrieve_stage_reveal_threshold", args.planner_retrieve_stage_reveal_threshold),
1926
+ ("retrieve_stage_persistence_threshold", args.planner_retrieve_stage_persistence_threshold),
1927
+ ("retrieve_stage_support_threshold", args.planner_retrieve_stage_support_threshold),
1928
+ ("insert_stage_access_threshold", args.planner_insert_stage_access_threshold),
1929
+ ("insert_stage_visibility_threshold", args.planner_insert_stage_visibility_threshold),
1930
+ ("insert_stage_support_threshold", args.planner_insert_stage_support_threshold),
1931
+ ("occlusion_maintain_gap_min_access", args.planner_occlusion_maintain_gap_min_access),
1932
+ ("occlusion_maintain_gap_min_visibility", args.planner_occlusion_maintain_gap_min_visibility),
1933
+ )
1934
+ for key, value in optional_pairs:
1935
+ if value is not None:
1936
+ overrides[key] = value
1937
+ return overrides
1938
+
1939
+
1940
+ def main() -> None:
1941
+ args = _parse_args()
1942
+ task_spec = _task_spec(args.task)
1943
+ spec = SmokeSpec(dataset_seed=int(args.dataset_seed), train_seed=int(args.train_seed))
1944
+ canonical_chunks = _canonical_chunks(task_spec)
1945
+ paths = _default_paths(task_spec)
1946
+ report_dir = args.report_dir or paths.report_dir
1947
+ planner_overrides = _planner_overrides_from_args(args)
1948
+ split_path = _dataset_artifact_path(paths.data_dir, "episode_splits.json", dataset_seed=spec.dataset_seed)
1949
+ if split_path.exists():
1950
+ episode_splits = json.loads(split_path.read_text(encoding="utf-8"))
1951
+ else:
1952
+ episode_splits = _build_episode_splits(task_spec, spec)
1953
+ _save_episode_splits(split_path, episode_splits)
1954
+
1955
+ train_path = _dataset_artifact_path(paths.data_dir, "train.pt", dataset_seed=spec.dataset_seed)
1956
+ val_path = _dataset_artifact_path(paths.data_dir, "val.pt", dataset_seed=spec.dataset_seed)
1957
+ if args.skip_collection and (not train_path.exists() or not val_path.exists()):
1958
+ raise FileNotFoundError("Requested --skip-collection but cached dataset files are missing.")
1959
+ if not args.skip_collection and (not args.reuse_dataset or not train_path.exists() or not val_path.exists()):
1960
+ train_payload = _collect_split(
1961
+ task_spec=task_spec,
1962
+ canonical_chunks=canonical_chunks,
1963
+ split_name="train",
1964
+ seeds=episode_splits["train"],
1965
+ spec=spec,
1966
+ output_path=train_path,
1967
+ )
1968
+ val_payload = _collect_split(
1969
+ task_spec=task_spec,
1970
+ canonical_chunks=canonical_chunks,
1971
+ split_name="val",
1972
+ seeds=episode_splits["val"],
1973
+ spec=spec,
1974
+ output_path=val_path,
1975
+ )
1976
+ else:
1977
+ train_payload = torch.load(train_path, map_location="cpu", weights_only=False)
1978
+ val_payload = torch.load(val_path, map_location="cpu", weights_only=False)
1979
+
1980
+ train_samples = _normalize_cached_samples(train_payload["samples"])
1981
+ val_samples = _normalize_cached_samples(val_payload["samples"])
1982
+ checkpoints: dict[str, Path] = {}
1983
+ train_specs: dict[str, dict[str, Any]] = {}
1984
+ for variant in ("trunk_only_ft", "adapter_active_ft"):
1985
+ variant_output_dir = paths.output_dir / f"{variant}_seed{spec.train_seed}"
1986
+ checkpoint_path = variant_output_dir / "checkpoint_best.pt"
1987
+ if args.skip_train and not checkpoint_path.exists():
1988
+ raise FileNotFoundError(f"Requested --skip-train but checkpoint is missing: {checkpoint_path}")
1989
+ if not args.skip_train and (not args.reuse_checkpoints or not checkpoint_path.exists()):
1990
+ checkpoint_path, train_spec = _train_variant(
1991
+ task_spec=task_spec,
1992
+ variant=variant,
1993
+ train_samples=train_samples,
1994
+ val_samples=val_samples,
1995
+ spec=spec,
1996
+ output_dir=variant_output_dir,
1997
+ )
1998
+ else:
1999
+ summary_path = variant_output_dir / "summary.json"
2000
+ if not summary_path.exists():
2001
+ raise FileNotFoundError(f"Missing cached summary file for {variant}: {summary_path}")
2002
+ summary_payload = json.loads(summary_path.read_text(encoding="utf-8"))
2003
+ train_spec = summary_payload["train_spec"]
2004
+ checkpoints[variant] = checkpoint_path
2005
+ train_specs[variant] = train_spec
2006
+
2007
+ results: list[dict[str, Any]] = []
2008
+ if not args.skip_eval:
2009
+ eval_plan = (
2010
+ ("trunk_only_ft", checkpoints["trunk_only_ft"], "trunk_only", None),
2011
+ ("adapter_noop", checkpoints["adapter_active_ft"], "adapter_noop", None),
2012
+ ("adapter_active_ft", checkpoints["adapter_active_ft"], "adapter_active", train_specs["adapter_active_ft"]),
2013
+ )
2014
+ for result_mode_name, checkpoint_path, adapter_mode, train_spec in eval_plan:
2015
+ result = _evaluate_checkpoint(
2016
+ task_spec=task_spec,
2017
+ canonical_chunks=canonical_chunks,
2018
+ checkpoint_path=checkpoint_path,
2019
+ adapter_mode=adapter_mode,
2020
+ result_mode_name=result_mode_name,
2021
+ seeds=episode_splits[args.eval_split],
2022
+ report_path=report_dir / f"{result_mode_name}_seed{spec.train_seed}.json",
2023
+ train_spec=train_spec if result_mode_name != "adapter_noop" else None,
2024
+ planner_overrides=planner_overrides,
2025
+ )
2026
+ if result_mode_name == "trunk_only_ft":
2027
+ result["train_spec"] = train_specs["trunk_only_ft"]
2028
+ (report_dir / f"{result_mode_name}_seed{spec.train_seed}.json").write_text(
2029
+ json.dumps(result, indent=2) + "\n",
2030
+ encoding="utf-8",
2031
+ )
2032
+ results.append(result)
2033
+ _summarize_task(task_spec, results, report_dir)
2034
+
2035
+
2036
+ if __name__ == "__main__":
2037
+ main()
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_maniskill_pickclutter_smoke.py ADDED
@@ -0,0 +1,2005 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import collections
5
+ import json
6
+ import math
7
+ import os
8
+ import time
9
+ from collections import deque
10
+ from dataclasses import asdict, dataclass
11
+ from pathlib import Path
12
+ from typing import Any, Iterable, Sequence
13
+
14
+ import numpy as np
15
+ import torch
16
+ from omegaconf import OmegaConf
17
+ from torch import Tensor
18
+ from torch.utils.data import DataLoader, Dataset
19
+
20
+ from eval.public_benchmark_package import build_public_eval_protocol, build_target_training_spec
21
+ from models.action_decoder import ChunkDecoderConfig, semantic_macro_chunk
22
+ from models.backbones import FrozenVLBackboneConfig
23
+ from models.multiview_fusion import MultiViewFusionConfig
24
+ from models.observation_memory import ObservationMemoryConfig
25
+ from models.planner import PlannerConfig
26
+ from models.policy import PolicyConfig
27
+ from models.reveal_head import RevealHeadConfig, TASK_METRIC_NAMES
28
+ from models.world_model import RevealWMConfig
29
+ from train.checkpoint_compat import filter_compatible_state_dict
30
+ from train.losses import LossWeights
31
+ from train.trainer import BimanualTrainer, TrainerConfig, apply_trainable_parameter_prefixes, build_policy
32
+
33
+
34
+ def _configure_runtime_env() -> None:
35
+ os.environ.setdefault("VK_ICD_FILENAMES", "/workspace/runtime/vulkan/icd.d/nvidia_icd_egl.json")
36
+ os.environ.setdefault("VK_LAYER_PATH", "/workspace/runtime/vulkan/implicit_layer.d")
37
+ os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp/runtime-root")
38
+ os.environ.setdefault("MS_ASSET_DIR", "/workspace/data/maniskill")
39
+
40
+
41
+ _configure_runtime_env()
42
+
43
+ import mani_skill.envs # noqa: E402
44
+ import sapien # noqa: E402
45
+ from mani_skill.envs.tasks.tabletop.pick_clutter_ycb import PickClutterYCBEnv # noqa: E402
46
+ from mani_skill.sensors.camera import CameraConfig # noqa: E402
47
+ from mani_skill.utils import sapien_utils # noqa: E402
48
+ from mani_skill.utils.structs import Pose # noqa: E402
49
+
50
+ from eval.run_public_benchmark_package import summarize_public_benchmark_package # noqa: E402
51
+ from models.action_decoder import TASK_INDEX # noqa: E402
52
+ from train.run_experiment import _load_init_checkpoint, _move_batch_to_device # noqa: E402
53
+
54
+
55
# --- Paths ------------------------------------------------------------------
REPO_ROOT = Path(__file__).resolve().parents[3]
WORKSPACE_ROOT = Path("/workspace/workspace")
SMOKE_VERSION = "smoke_v5"
DEFAULT_DATA_DIR = WORKSPACE_ROOT / "data" / "maniskill_pickclutter" / SMOKE_VERSION
DEFAULT_OUTPUT_DIR = WORKSPACE_ROOT / "outputs" / f"maniskill_pickclutter_{SMOKE_VERSION}"
DEFAULT_REPORT_DIR = WORKSPACE_ROOT / "reports" / f"maniskill_pickclutter_{SMOKE_VERSION}"
DEFAULT_INIT_CHECKPOINT = Path(
    "/workspace/workspace/VLAarchtests2/VLAarchtests/artifacts/outputs/"
    "r3d_handoff_phase/proxy_interaction_r3d_stage3_clip_rgbd_handoff_compact_phase_seed17/checkpoint_best.pt"
)

# --- Task / mode vocabulary -------------------------------------------------
TEXT_PROMPT = "retrieve the target object from dense clutter and stage it at the front edge"
TASK_NAME = "foliage"
TASK_ID = TASK_INDEX[TASK_NAME]
CAMERA_NAMES = ("front", "left", "right")
# Candidate macro modes, in the fixed order used to align candidate-metric
# sequences with mode names elsewhere in this file.
MODE_ORDER = (
    "base_action",
    "sweep_left",
    "sweep_right",
    "pin_canopy",
    "widen_gap",
    "maintain_gap",
    "insert_actor",
    "retrieve",
)
ROLL_OUT_HORIZON = 5
NUM_SUPPORT_MODES = 3
NUM_APPROACH_TEMPLATES = 32
# Support-arm mode ids (must stay < NUM_SUPPORT_MODES).
SUPPORT_MODE_HOLD = 0
SUPPORT_MODE_TRANSFER = 1
SUPPORT_MODE_PASSIVE = 2
# Mode families; note "maintain_gap" is already included in REVEAL_MODES.
REVEAL_MODES = ("sweep_left", "sweep_right", "pin_canopy", "widen_gap", "maintain_gap")
TRANSFER_MODES = ("insert_actor",)
RETRIEVE_MODES = ("retrieve",)
# Subset of TASK_METRIC_NAMES that this script supervises from simulator state.
STATE_SUPERVISION_METRICS = (
    "opening_quality",
    "actor_feasibility_score",
    "gap_width",
    "damage_proxy",
    "release_collapse_rate",
    "target_visibility_confidence",
    "insertable_actor_corridor",
    "insertion_corridor",
    "hold_quality",
    "layer_separation_quality",
    "fold_preservation",
    "top_layer_stability",
    "lift_too_much_risk",
)

# --- Episode / model dimensions ---------------------------------------------
MAX_MACRO_STEPS = 4
HISTORY_STEPS = 6
PROPRIO_DIM = 32
# Target must be pulled past this world-frame Y to count as extracted.
EXTRACTION_LINE_Y = -0.22
MIN_CLEARANCE_FOR_SUCCESS = 0.05
DEFAULT_SEED = 17

# --- Smoke-run planner thresholds (applied in _apply_smoke_planner_overrides)
SMOKE_ADAPTER_CONFIDENCE_THRESHOLD = 0.50
SMOKE_RETRIEVE_ACCESS_THRESHOLD = 0.08
SMOKE_RETRIEVE_PERSISTENCE_THRESHOLD = 0.12
SMOKE_RETRIEVE_SUPPORT_THRESHOLD = 0.08
SMOKE_RETRIEVE_REOCCLUSION_THRESHOLD = 0.92
115
+
116
+
117
@dataclass(frozen=True)
class SmokePaths:
    """Filesystem locations (dataset, training outputs, reports) for the smoke run."""

    data_dir: Path = DEFAULT_DATA_DIR        # generated episode dataset
    output_dir: Path = DEFAULT_OUTPUT_DIR    # checkpoints / training artifacts
    report_dir: Path = DEFAULT_REPORT_DIR    # evaluation reports
122
+
123
+
124
@dataclass(frozen=True)
class SmokeSpec:
    """Dataset / training hyperparameters for the PickClutter smoke experiment."""

    resolution: int = 224                    # square camera render resolution
    train_episodes: int = 32
    val_episodes: int = 8
    eval_episodes: int = 50
    dataset_seed: int = DEFAULT_SEED         # seed used when generating episodes
    train_seed: int = DEFAULT_SEED           # seed used for training
    history_steps: int = HISTORY_STEPS
    max_macro_steps: int = MAX_MACRO_STEPS
    batch_size: int = 4
    epochs: int = 6
    num_workers: int = 16
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4

    @property
    def seed(self) -> int:
        """Convenience alias for ``train_seed``."""
        return self.train_seed
143
+
144
+
145
def _apply_smoke_planner_overrides(
    policy_config: PolicyConfig,
    planner_overrides: dict[str, float] | None = None,
) -> PolicyConfig:
    """Install the smoke-run planner thresholds, then any explicit overrides.

    Overrides with a ``None`` value are skipped. The config is mutated in
    place and also returned for call-chaining.
    """
    planner = policy_config.planner
    smoke_defaults = {
        "adapter_confidence_threshold": SMOKE_ADAPTER_CONFIDENCE_THRESHOLD,
        "retrieve_access_threshold": SMOKE_RETRIEVE_ACCESS_THRESHOLD,
        "retrieve_persistence_threshold": SMOKE_RETRIEVE_PERSISTENCE_THRESHOLD,
        "retrieve_support_threshold": SMOKE_RETRIEVE_SUPPORT_THRESHOLD,
        "retrieve_reocclusion_threshold": SMOKE_RETRIEVE_REOCCLUSION_THRESHOLD,
    }
    for attr_name, attr_value in smoke_defaults.items():
        setattr(planner, attr_name, attr_value)
    # Explicit overrides win over the smoke defaults.
    for attr_name, attr_value in (planner_overrides or {}).items():
        if attr_value is not None:
            setattr(planner, attr_name, attr_value)
    return policy_config
160
+
161
+
162
class PickClutterRevealEnv(PickClutterYCBEnv):
    """PickClutterYCB variant exposing three fixed cameras (front/left/right)."""

    @property
    def _default_sensor_configs(self):
        resolution = 224
        # (name, eye, look-at target) for each viewpoint.
        placements = (
            ("front", [0.30, 0.00, 0.62], [-0.06, 0.00, 0.04]),
            ("left", [0.22, 0.34, 0.34], [-0.02, 0.02, 0.03]),
            ("right", [0.22, -0.34, 0.34], [-0.02, -0.02, 0.03]),
        )
        return [
            CameraConfig(
                camera_name,
                pose=sapien_utils.look_at(eye=eye, target=look_target),
                width=resolution,
                height=resolution,
                fov=np.pi / 2,
                near=0.01,
                far=100.0,
            )
            for camera_name, eye, look_target in placements
        ]
195
+
196
+
197
+ def _np(value: Any, *, dtype: np.dtype | None = None) -> np.ndarray:
198
+ if isinstance(value, np.ndarray):
199
+ array = value
200
+ elif isinstance(value, Tensor):
201
+ array = value.detach().cpu().numpy()
202
+ else:
203
+ array = np.asarray(value)
204
+ if dtype is not None:
205
+ array = array.astype(dtype, copy=False)
206
+ return array
207
+
208
+
209
def _vec3(value: Any) -> np.ndarray:
    """Flatten ``value`` to float32 and keep the first three components (xyz)."""
    flat = _np(value, dtype=np.float32).reshape(-1)
    return flat[:3]
211
+
212
+
213
+ def _camera_intrinsic_from_param(param: dict[str, Any]) -> np.ndarray:
214
+ for key in ("intrinsic_cv", "intrinsic", "cam_intrinsic"):
215
+ if key in param:
216
+ matrix = _np(param[key], dtype=np.float32)
217
+ return matrix[0] if matrix.ndim == 3 else matrix
218
+ return np.eye(3, dtype=np.float32)
219
+
220
+
221
+ def _camera_extrinsic_from_param(param: dict[str, Any]) -> np.ndarray:
222
+ for key in ("cam2world_gl", "cam2world", "extrinsic_cv", "extrinsic"):
223
+ if key in param:
224
+ matrix = _np(param[key], dtype=np.float32)
225
+ return matrix[0] if matrix.ndim == 3 else matrix
226
+ return np.eye(4, dtype=np.float32)
227
+
228
+
229
+ def _convert_depth(depth: np.ndarray) -> np.ndarray:
230
+ depth = depth.astype(np.float32, copy=False)
231
+ if np.issubdtype(depth.dtype, np.integer):
232
+ depth = depth / 1000.0
233
+ return depth
234
+
235
+
236
def _build_proprio(env: PickClutterRevealEnv) -> np.ndarray:
    """Assemble a fixed-size proprioception vector of length PROPRIO_DIM.

    Concatenates joint positions, joint velocities, the raw TCP pose, and a
    scalar gripper width (sum of the last two qpos entries), then truncates
    or zero-pads to exactly PROPRIO_DIM entries.
    """
    base = env.unwrapped
    qpos = _np(base.agent.robot.get_qpos(), dtype=np.float32).reshape(-1)
    qvel = _np(base.agent.robot.get_qvel(), dtype=np.float32).reshape(-1)
    tcp_pose = _np(base.agent.tcp.pose.raw_pose, dtype=np.float32).reshape(-1)
    gripper_width = qpos[-2:].sum(keepdims=True).astype(np.float32)
    flat = np.concatenate([qpos, qvel, tcp_pose, gripper_width], axis=0)
    result = np.zeros((PROPRIO_DIM,), dtype=np.float32)
    copied = min(flat.shape[0], PROPRIO_DIM)
    result[:copied] = flat[:copied]
    return result
249
+
250
+
251
def _extract_sensor_bundle(obs: dict[str, Any]) -> dict[str, np.ndarray]:
    """Flatten a ManiSkill observation dict into stacked per-camera arrays.

    Views are stacked in CAMERA_NAMES order. Returned keys:
      images            stacked RGB views, uint8 (assumes each view is
                        (H, W, 3) after batch squeeze — confirm with obs mode)
      depths            (V, 1, H, W) float32, metres via `_convert_depth`
      depth_valid       (V, 1, H, W) float32 mask, 1.0 where depth > 1e-5
      segmentations     (V, H, W) int32 segmentation ids
      camera_intrinsics (V, 3, 3) float32
      camera_extrinsics (V, 4, 4) float32
    """
    sensor_data = obs["sensor_data"]
    sensor_param = obs["sensor_param"]
    rgb_views: list[np.ndarray] = []
    depth_views: list[np.ndarray] = []
    seg_views: list[np.ndarray] = []
    intrinsics: list[np.ndarray] = []
    extrinsics: list[np.ndarray] = []
    for camera_name in CAMERA_NAMES:
        view = sensor_data[camera_name]
        param = sensor_param[camera_name]
        rgb = _np(view["rgb"], dtype=np.uint8)
        depth = _np(view["depth"])
        segmentation = _np(view["segmentation"])
        # Strip a leading batch axis when the env returns batched sensors.
        rgb = rgb[0] if rgb.ndim == 4 else rgb
        depth = depth[0] if depth.ndim == 4 else depth
        segmentation = segmentation[0] if segmentation.ndim == 4 else segmentation
        # Squeeze a trailing singleton channel on depth/segmentation maps.
        if depth.ndim == 3 and depth.shape[-1] == 1:
            depth = depth[..., 0]
        if segmentation.ndim == 3 and segmentation.shape[-1] == 1:
            segmentation = segmentation[..., 0]
        rgb_views.append(rgb.astype(np.uint8, copy=False))
        depth_views.append(_convert_depth(depth))
        seg_views.append(segmentation.astype(np.int32, copy=False))
        intrinsics.append(_camera_intrinsic_from_param(param))
        extrinsics.append(_camera_extrinsic_from_param(param))
    depth_stack = np.stack(depth_views, axis=0).astype(np.float32)
    depth_valid = (depth_stack > 1e-5).astype(np.float32)
    return {
        "images": np.stack(rgb_views, axis=0),
        "depths": depth_stack[:, None, :, :],
        "depth_valid": depth_valid[:, None, :, :],
        "segmentations": np.stack(seg_views, axis=0),
        "camera_intrinsics": np.stack(intrinsics, axis=0).astype(np.float32),
        "camera_extrinsics": np.stack(extrinsics, axis=0).astype(np.float32),
    }
287
+
288
+
289
def _target_actor(env: PickClutterRevealEnv) -> Any:
    """Physx actor backing the episode's target object (first sub-env)."""
    base = env.unwrapped
    return base.target_object._objs[0]
291
+
292
+
293
def _all_scene_actors(env: PickClutterRevealEnv) -> list[Any]:
    """Every clutter actor in the scene, materialized as a plain list."""
    return [actor for actor in env.unwrapped.all_objects._objs]
295
+
296
+
297
def _target_position(env: PickClutterRevealEnv) -> np.ndarray:
    """Current world-frame XYZ of the target object."""
    actor = _target_actor(env)
    return _vec3(actor.pose.p)
299
+
300
+
301
def _all_positions(env: PickClutterRevealEnv) -> dict[str, np.ndarray]:
    """Map each scene actor's name to its current XYZ position."""
    positions: dict[str, np.ndarray] = {}
    for actor in _all_scene_actors(env):
        positions[actor.name] = _vec3(actor.pose.p)
    return positions
303
+
304
+
305
def _nearest_non_target_distance(env: PickClutterRevealEnv) -> float:
    """XY distance from the target to its closest non-target actor.

    Returns 1.0 when the target is the only actor in the scene.
    """
    target = _target_actor(env)
    target_xy = _vec3(target.pose.p)[:2]
    xy_distances = (
        float(np.linalg.norm(_vec3(actor.pose.p)[:2] - target_xy))
        for actor in _all_scene_actors(env)
        if actor.name != target.name
    )
    return float(min(xy_distances, default=1.0))
316
+
317
+
318
def _success_from_state(env: PickClutterRevealEnv) -> bool:
    """Success: target pulled past EXTRACTION_LINE_Y with enough clearance."""
    if _target_position(env)[1] > EXTRACTION_LINE_Y:
        return False
    return bool(_nearest_non_target_distance(env) >= MIN_CLEARANCE_FOR_SUCCESS)
321
+
322
+
323
def _clearance_score(env: PickClutterRevealEnv) -> float:
    """Nearest-neighbour distance mapped to [0, 1]: 0 at <=3 cm, 1 at >=12 cm."""
    distance = _nearest_non_target_distance(env)
    return float(np.clip((distance - 0.03) / 0.09, 0.0, 1.0))
325
+
326
+
327
def _extraction_progress(env: PickClutterRevealEnv) -> float:
    """Fraction of the pull from y = -0.05 down to the extraction line, in [0, 1]."""
    y_value = _target_position(env)[1]
    span = (-0.05) - EXTRACTION_LINE_Y
    return float(np.clip(((-0.05) - y_value) / span, 0.0, 1.0))
330
+
331
+
332
+ def _target_visibility(obs_bundle: dict[str, np.ndarray], target_seg_id: int) -> float:
333
+ segmentation = obs_bundle["segmentations"]
334
+ fractions = [(view == int(target_seg_id)).mean() for view in segmentation]
335
+ return float(np.clip(np.mean(fractions) * 80.0, 0.0, 1.0))
336
+
337
+
338
def _snapshot_env(env: PickClutterRevealEnv) -> dict[str, Any]:
    """Capture enough env state to restore later: full sim state plus goal position."""
    base = env.unwrapped
    snapshot: dict[str, Any] = {"state_dict": base.get_state_dict()}
    snapshot["goal_pos"] = _np(base.goal_pos, dtype=np.float32).copy()
    return snapshot
344
+
345
+
346
def _restore_env(env: PickClutterRevealEnv, snapshot: dict[str, Any]) -> None:
    """Restore a `_snapshot_env` snapshot, including the goal site pose."""
    base = env.unwrapped
    base.set_state_dict(snapshot["state_dict"])
    goal = torch.as_tensor(snapshot["goal_pos"], dtype=torch.float32, device=base.device)
    base.goal_pos = goal.view_as(base.goal_pos)
    # Keep the visual goal marker in sync with the restored goal position.
    base.goal_site.set_pose(Pose.create_from_pq(base.goal_pos))
353
+
354
+
355
def _sync_env_state(src_env: PickClutterRevealEnv, dst_env: PickClutterRevealEnv) -> None:
    """Copy the full simulation state of ``src_env`` into ``dst_env``."""
    snapshot = _snapshot_env(src_env)
    _restore_env(dst_env, snapshot)
357
+
358
+
359
def _canonical_chunks() -> dict[str, np.ndarray]:
    """Build the per-mode prototype action chunks used for nearest-mode matching.

    "base_action" is the all-zero chunk; every other mode's prototype is the
    semantic macro chunk generated from that zero chunk.
    """
    zero_chunk = torch.zeros((1, 8, 14), dtype=torch.float32)
    prototypes: dict[str, np.ndarray] = {
        "base_action": zero_chunk.squeeze(0).numpy().astype(np.float32)
    }
    for mode_name in MODE_ORDER[1:]:
        macro = semantic_macro_chunk(zero_chunk, task_name=TASK_NAME, mode_name=mode_name)
        prototypes[mode_name] = macro.squeeze(0).cpu().numpy().astype(np.float32)
    return prototypes
366
+
367
+
368
# Prototype chunk per mode, computed once at import time.
CANONICAL_CHUNKS = _canonical_chunks()
# Boolean mask over TASK_METRIC_NAMES marking the metrics this script
# supervises from simulator state (see STATE_SUPERVISION_METRICS).
STATE_METRIC_MASK = np.asarray(
    [metric_name in STATE_SUPERVISION_METRICS for metric_name in TASK_METRIC_NAMES],
    dtype=np.bool_,
)
373
+
374
+
375
def _classify_mode_from_chunk(chunk: np.ndarray) -> str:
    """Return the mode whose canonical prototype is closest in mean absolute difference."""
    candidate = np.asarray(chunk, dtype=np.float32)
    best_mode = ""
    best_distance = math.inf
    for mode_name, prototype in CANONICAL_CHUNKS.items():
        distance = float(np.mean(np.abs(candidate - prototype)))
        if distance < best_distance:
            best_mode = mode_name
            best_distance = distance
    return best_mode
382
+
383
+
384
+ def _gripper_action(open_gripper: bool) -> float:
385
+ return 1.0 if open_gripper else -1.0
386
+
387
+
388
def _repeat_delta(env: PickClutterRevealEnv, delta_xyz: Sequence[float], *, open_gripper: bool, steps: int) -> dict[str, Any]:
    """Step the env ``steps`` times with a constant [dx, dy, dz, grip] action.

    Stops early if the episode terminates or truncates. Returns the last
    observation and info; when no step ran (``steps`` <= 0) a fresh
    observation/info pair is fetched instead.
    """
    last_obs: dict[str, Any] | None = None
    action = np.zeros((1, 4), dtype=np.float32)
    action[0, :3] = np.asarray(delta_xyz, dtype=np.float32)
    action[0, 3] = _gripper_action(open_gripper)
    for _ in range(int(steps)):
        obs, _, terminated, truncated, info = env.step(action)
        last_obs = obs
        if bool(np.asarray(terminated).reshape(-1)[0]) or bool(np.asarray(truncated).reshape(-1)[0]):
            break
    # NOTE(review): "terminated"/"truncated" are hard-coded False even when
    # the loop broke on an actual termination — looks intentional (callers
    # treat the macro as non-terminal) but worth confirming.
    return {
        "obs": last_obs if last_obs is not None else env.get_obs(env.get_info()),
        "terminated": False,
        "truncated": False,
        # ``info`` is only bound when at least one step ran; the conditional
        # guards the zero-step case.
        "info": info if last_obs is not None else env.get_info(),
    }
404
+
405
+
406
def _move_tcp_to(
    env: PickClutterRevealEnv,
    target_xyz: Sequence[float],
    *,
    open_gripper: bool,
    max_steps: int = 120,
    tolerance: float = 0.008,
) -> dict[str, Any]:
    """Servo the TCP toward ``target_xyz`` with clipped per-step delta actions.

    Steps until the TCP is within ``tolerance`` of the target, the episode
    terminates/truncates, or ``max_steps`` is exhausted. Returns the last
    observation (a fresh one when no step ran) and the latest info dict.
    """
    last_obs: dict[str, Any] | None = None
    target = np.asarray(target_xyz, dtype=np.float32)
    info = env.get_info()
    for _ in range(int(max_steps)):
        tcp = _vec3(env.unwrapped.agent.tcp.pose.p)
        delta = target - tcp
        if float(np.linalg.norm(delta)) <= float(tolerance):
            break
        action = np.zeros((1, 4), dtype=np.float32)
        # Normalize the remaining delta by 0.04 (presumably the controller's
        # per-step translation scale — confirm against the env's action space)
        # and clip to the [-1, 1] action range.
        action[0, :3] = np.clip(delta / 0.04, -1.0, 1.0)
        action[0, 3] = _gripper_action(open_gripper)
        obs, _, terminated, truncated, info = env.step(action)
        last_obs = obs
        if bool(np.asarray(terminated).reshape(-1)[0]) or bool(np.asarray(truncated).reshape(-1)[0]):
            break
    return {
        "obs": last_obs if last_obs is not None else env.get_obs(info),
        "info": info,
    }
433
+
434
+
435
def _find_path_blocker(env: PickClutterRevealEnv) -> np.ndarray | None:
    """Pick the non-target actor most likely to block extraction toward -Y.

    Preference order: the closest actor inside the extraction corridor
    (ahead of the target in -Y with <= 0.10 m lateral offset); otherwise the
    closest non-target actor overall. Returns None when the target is the
    only actor in the scene.
    """
    target = _target_position(env)
    target_name = _target_actor(env).name
    candidate_positions = [
        _vec3(actor.pose.p)
        for actor in _all_scene_actors(env)
        if actor.name != target_name
    ]
    if not candidate_positions:
        return None

    def xy_distance(position: np.ndarray) -> float:
        return float(np.linalg.norm(position[:2] - target[:2]))

    corridor = [
        position
        for position in candidate_positions
        if position[1] <= target[1] + 0.06 and abs(position[0] - target[0]) <= 0.10
    ]
    pool = corridor if corridor else candidate_positions
    return min(pool, key=xy_distance)
457
+
458
+
459
def _execute_push(
    env: PickClutterRevealEnv,
    *,
    anchor_xyz: np.ndarray,
    pre_offset: np.ndarray,
    push_delta: np.ndarray,
    push_steps: int,
) -> dict[str, Any]:
    """Approach ``anchor_xyz`` from above, push along ``push_delta``, then retract.

    Phases: hover 8 cm above the offset anchor, descend to the anchor,
    execute the push for ``push_steps`` steps, then lift straight up to
    z = 0.10 at the current XY. Returns the result dict of the push phase
    (not of the retract move).
    """
    hover_point = anchor_xyz + np.array([0.0, 0.0, 0.08], dtype=np.float32) + pre_offset
    _move_tcp_to(env, hover_point, open_gripper=True)
    _move_tcp_to(env, anchor_xyz + pre_offset, open_gripper=True, max_steps=100, tolerance=0.010)
    result = _repeat_delta(env, push_delta, open_gripper=True, steps=push_steps)
    # Hoisted: the original queried the TCP pose twice (once for x, once
    # for y) to build the retract waypoint; read it once instead.
    tcp = _vec3(env.unwrapped.agent.tcp.pose.p)
    _move_tcp_to(
        env,
        np.array([tcp[0], tcp[1], 0.10], dtype=np.float32),
        open_gripper=True,
        max_steps=80,
        tolerance=0.012,
    )
    return result
478
+
479
+
480
def _execute_mode(env: PickClutterRevealEnv, mode_name: str) -> dict[str, Any]:
    """Execute one scripted macro behaviour and return its push-phase result.

    Anchors are either the target itself (retrieve / insert_actor) or the
    current path blocker (falling back to the target when no blocker exists).
    Raises KeyError for a ``mode_name`` outside MODE_ORDER.
    """
    target = _target_position(env)
    blocker = _find_path_blocker(env)
    if mode_name == "retrieve":
        # Long pull of the target itself toward the extraction line (-Y).
        return _execute_push(
            env,
            anchor_xyz=target,
            pre_offset=np.array([0.0, 0.035, 0.026], dtype=np.float32),
            push_delta=np.array([0.0, -0.7, 0.0], dtype=np.float32),
            push_steps=18,
        )
    if mode_name == "insert_actor":
        # Shorter, gentler pull from slightly further behind the target.
        return _execute_push(
            env,
            anchor_xyz=target,
            pre_offset=np.array([0.0, 0.045, 0.028], dtype=np.float32),
            push_delta=np.array([0.0, -0.4, 0.0], dtype=np.float32),
            push_steps=10,
        )
    if mode_name == "widen_gap":
        anchor = blocker if blocker is not None else target
        # Push the blocker laterally away from the target's side.
        direction = -1.0 if anchor[0] >= target[0] else 1.0
        return _execute_push(
            env,
            anchor_xyz=anchor,
            pre_offset=np.array([0.0, 0.025, 0.028], dtype=np.float32),
            push_delta=np.array([0.75 * direction, -0.12, 0.0], dtype=np.float32),
            push_steps=18,
        )
    if mode_name == "sweep_left":
        anchor = blocker if blocker is not None else target
        return _execute_push(
            env,
            anchor_xyz=anchor,
            pre_offset=np.array([0.015, 0.025, 0.028], dtype=np.float32),
            push_delta=np.array([-0.70, -0.10, 0.0], dtype=np.float32),
            push_steps=14,
        )
    if mode_name == "sweep_right":
        anchor = blocker if blocker is not None else target
        return _execute_push(
            env,
            anchor_xyz=anchor,
            pre_offset=np.array([-0.015, 0.025, 0.028], dtype=np.float32),
            push_delta=np.array([0.70, -0.10, 0.0], dtype=np.float32),
            push_steps=14,
        )
    if mode_name == "pin_canopy":
        anchor = blocker if blocker is not None else target
        # Press the blocker back (+Y), away from the extraction direction.
        return _execute_push(
            env,
            anchor_xyz=anchor,
            pre_offset=np.array([0.0, -0.015, 0.028], dtype=np.float32),
            push_delta=np.array([0.0, 0.35, 0.0], dtype=np.float32),
            push_steps=10,
        )
    if mode_name in {"maintain_gap", "base_action"}:
        # Conservative default: hover behind the target, then nudge toward -Y.
        _move_tcp_to(
            env,
            np.array([target[0], target[1] + 0.02, 0.10], dtype=np.float32),
            open_gripper=True,
            max_steps=60,
            tolerance=0.015,
        )
        return _repeat_delta(env, np.array([0.0, -0.10, 0.0], dtype=np.float32), open_gripper=True, steps=4)
    raise KeyError(f"Unsupported mode: {mode_name}")
546
+
547
+
548
def _candidate_metrics(
    env: PickClutterRevealEnv,
    *,
    start_positions: dict[str, np.ndarray],
    current_obs_bundle: dict[str, np.ndarray] | None = None,
) -> dict[str, float]:
    """Score the current env state relative to a pre-action snapshot.

    Returns success, mean non-target disturbance (normalized by 0.10 m),
    target visibility (0.0 when no observation bundle is given), clearance,
    and extraction progress — all floats in [0, 1].
    """
    positions = _all_positions(env)
    target_name = _target_actor(env).name
    non_target_displacements = []
    for name, start_position in start_positions.items():
        # Skip the target and any actor no longer present in the scene.
        if name == target_name or name not in positions:
            continue
        non_target_displacements.append(float(np.linalg.norm((positions[name] - start_position)[:2])))
    disturbance = float(np.clip(np.mean(non_target_displacements) / 0.10, 0.0, 1.0)) if non_target_displacements else 0.0
    visibility = 0.0
    if current_obs_bundle is not None:
        # per_scene_id: presumably the target's segmentation id — -1 fallback
        # yields zero visibility when the attribute is missing.
        visibility = _target_visibility(current_obs_bundle, getattr(_target_actor(env), "per_scene_id", -1))
    return {
        "retrieval_success": float(_success_from_state(env)),
        "disturbance": disturbance,
        "visibility": visibility,
        "clearance": _clearance_score(env),
        "progress": _extraction_progress(env),
    }
572
+
573
+
574
+ def _mean_non_target_displacement(
575
+ start_positions: dict[str, np.ndarray],
576
+ current_positions: dict[str, np.ndarray],
577
+ *,
578
+ target_name: str,
579
+ ) -> float:
580
+ displacements = []
581
+ for actor_name, start_position in start_positions.items():
582
+ if actor_name == target_name or actor_name not in current_positions:
583
+ continue
584
+ displacements.append(float(np.linalg.norm((current_positions[actor_name] - start_position)[:2])))
585
+ if not displacements:
586
+ return 0.0
587
+ return float(np.mean(displacements))
588
+
589
+
590
def _current_state_targets(
    env: PickClutterRevealEnv,
    *,
    obs_bundle: dict[str, np.ndarray],
    candidate_metrics: Sequence[dict[str, float]],
    episode_start_positions: dict[str, np.ndarray],
    selected_mode: str,
) -> dict[str, Any]:
    """Build the per-step supervision targets from sim state and candidate scores.

    ``candidate_metrics`` must be aligned with MODE_ORDER (one payload per
    mode). Returns the support-mode label, per-branch corridor feasibility
    and persistence-horizon arrays, disturbance cost, a binary confidence
    target, the state-supervision mask, and one scalar per task metric.
    """
    metrics_by_name = {mode_name: payload for mode_name, payload in zip(MODE_ORDER, candidate_metrics)}
    current_positions = _all_positions(env)
    target_name = _target_actor(env).name
    # Mean non-target drift since episode start, normalized by 0.10 m.
    current_disturbance = float(
        np.clip(
            _mean_non_target_displacement(
                episode_start_positions,
                current_positions,
                target_name=target_name,
            )
            / 0.10,
            0.0,
            1.0,
        )
    )
    current_visibility = _target_visibility(obs_bundle, getattr(_target_actor(env), "per_scene_id", -1))
    current_clearance = _clearance_score(env)
    current_progress = _extraction_progress(env)
    # Heuristic scalar metrics derived from the four base signals; weights
    # are hand-tuned blends, each clipped to [0, 1].
    base_gap = float(np.clip(max(current_clearance, current_progress), 0.0, 1.0))
    support_stability = float(np.clip(1.0 - 0.5 * current_disturbance, 0.0, 1.0))
    hold_quality = float(np.clip(0.5 * (support_stability + max(current_clearance, current_progress)), 0.0, 1.0))
    opening_quality = float(
        np.clip(0.55 * current_progress + 0.25 * current_clearance + 0.20 * current_visibility, 0.0, 1.0)
    )
    actor_feasibility = float(np.clip(0.6 * current_clearance + 0.4 * max(current_visibility, current_progress), 0.0, 1.0))
    reocclusion_rate = float(np.clip(1.0 - max(current_clearance, current_visibility), 0.0, 1.0))
    insertable_actor_corridor = float(np.clip(0.6 * actor_feasibility + 0.4 * base_gap, 0.0, 1.0))
    insertion_corridor = float(np.clip(0.5 * actor_feasibility + 0.5 * base_gap, 0.0, 1.0))
    layer_separation = float(np.clip(0.7 * base_gap + 0.3 * actor_feasibility, 0.0, 1.0))
    fold_preservation = float(np.clip(1.0 - current_disturbance, 0.0, 1.0))
    lift_too_much_risk = float(np.clip(current_disturbance + 0.5 * max(base_gap - 0.5, 0.0), 0.0, 1.0))
    task_metrics = {
        "opening_quality": opening_quality,
        "actor_feasibility_score": actor_feasibility,
        # gap_width is in metres (0.03–0.24 range), not a [0, 1] score.
        "gap_width": float(0.03 + 0.21 * base_gap),
        "damage_proxy": current_disturbance,
        "release_collapse_rate": reocclusion_rate,
        "target_visibility_confidence": current_visibility,
        "insertable_actor_corridor": insertable_actor_corridor,
        "insertion_corridor": insertion_corridor,
        "hold_quality": hold_quality,
        "layer_separation_quality": layer_separation,
        "fold_preservation": fold_preservation,
        "top_layer_stability": support_stability,
        "lift_too_much_risk": lift_too_much_risk,
    }

    # Best-of-family aggregates over the reveal-mode candidates.
    base_metrics = metrics_by_name["base_action"]
    insert_metrics = metrics_by_name["insert_actor"]
    retrieve_metrics = metrics_by_name["retrieve"]
    reveal_candidates = [metrics_by_name[mode_name] for mode_name in REVEAL_MODES]
    reveal_access = max(candidate["candidate_actor_feasibility_auc"] for candidate in reveal_candidates)
    reveal_reveal = max(candidate["candidate_reveal_achieved"] for candidate in reveal_candidates)
    reveal_hold = max(candidate["candidate_hold_persistence"] for candidate in reveal_candidates)
    reveal_visibility = max(candidate["candidate_visibility_integral"] for candidate in reveal_candidates)

    # Per-branch corridor feasibility scores (reveal / transfer / passive).
    reveal_corridor = float(
        np.clip(
            0.45 * opening_quality
            + 0.30 * reveal_access
            + 0.15 * reveal_reveal
            + 0.10 * reveal_visibility
            - 0.10 * current_disturbance,
            0.0,
            1.0,
        )
    )
    transfer_corridor = float(
        np.clip(
            0.45 * insertable_actor_corridor
            + 0.30 * insert_metrics["candidate_actor_feasibility_auc"]
            + 0.15 * insert_metrics["candidate_reveal_achieved"]
            + 0.10 * insert_metrics["candidate_visibility_integral"]
            - 0.15 * current_disturbance,
            0.0,
            1.0,
        )
    )
    passive_corridor = float(
        np.clip(
            0.55 * retrieve_metrics["candidate_retrieval_success"]
            + 0.20 * retrieve_metrics["candidate_actor_feasibility_auc"]
            + 0.15 * current_progress
            + 0.10 * current_clearance
            - 0.10 * current_disturbance,
            0.0,
            1.0,
        )
    )
    # Broadcast each branch score across all approach templates: shape (3, T).
    corridor_feasible = np.stack(
        [
            np.full((NUM_APPROACH_TEMPLATES,), reveal_corridor, dtype=np.float32),
            np.full((NUM_APPROACH_TEMPLATES,), transfer_corridor, dtype=np.float32),
            np.full((NUM_APPROACH_TEMPLATES,), passive_corridor, dtype=np.float32),
        ],
        axis=0,
    )
    # Expected hold duration per branch, as a fraction of the rollout horizon.
    persistence_horizon = np.asarray(
        [
            ROLL_OUT_HORIZON
            * float(np.clip(0.35 * hold_quality + 0.35 * reveal_hold + 0.30 * reveal_corridor, 0.0, 1.0)),
            ROLL_OUT_HORIZON
            * float(
                np.clip(
                    0.30 * hold_quality
                    + 0.35 * insert_metrics["candidate_hold_persistence"]
                    + 0.35 * transfer_corridor,
                    0.0,
                    1.0,
                )
            ),
            ROLL_OUT_HORIZON
            * float(
                np.clip(
                    0.25 * hold_quality
                    + 0.35 * retrieve_metrics["candidate_hold_persistence"]
                    + 0.40 * passive_corridor,
                    0.0,
                    1.0,
                )
            ),
        ],
        dtype=np.float32,
    )

    # Support-mode label: executed mode wins; otherwise utility margins
    # against the base action decide.
    retrieve_margin = float(retrieve_metrics["candidate_utility"] - base_metrics["candidate_utility"])
    insert_margin = float(insert_metrics["candidate_utility"] - base_metrics["candidate_utility"])
    if selected_mode == "retrieve" or (retrieve_metrics["candidate_retrieval_success"] >= 0.5 and retrieve_margin >= 0.15):
        support_mode = SUPPORT_MODE_PASSIVE
    elif selected_mode == "insert_actor" or (insert_margin >= 0.15 and transfer_corridor >= 0.40):
        support_mode = SUPPORT_MODE_TRANSFER
    elif selected_mode in REVEAL_MODES or selected_mode == "maintain_gap":
        support_mode = SUPPORT_MODE_HOLD
    elif selected_mode == "base_action":
        support_mode = SUPPORT_MODE_PASSIVE if passive_corridor >= 0.65 and retrieve_margin >= 0.05 else SUPPORT_MODE_HOLD
    else:
        support_mode = SUPPORT_MODE_HOLD

    # Confidence target: 1.0 only when a non-base mode was selected AND it
    # beats the base action's utility by at least 0.15.
    best_non_base_utility = max(float(payload["candidate_utility"]) for payload in candidate_metrics[1:])
    intervention_warranted = selected_mode != "base_action" and best_non_base_utility >= float(base_metrics["candidate_utility"]) + 0.15

    return {
        "support_mode": int(support_mode),
        "corridor_feasible": corridor_feasible,
        "persistence_horizon": persistence_horizon,
        "disturbance_cost": np.float32(current_disturbance),
        "state_confidence_target": np.float32(1.0 if intervention_warranted else 0.0),
        "task_metric_mask": STATE_METRIC_MASK.copy(),
        **{metric_name: np.float32(metric_value) for metric_name, metric_value in task_metrics.items()},
    }
748
+
749
+
750
def _mode_support_mode(mode_name: str, current_support_mode: int) -> int:
    """Map an executed mode to the support mode it implies; otherwise keep the current one."""
    if mode_name == "maintain_gap" or mode_name in REVEAL_MODES:
        return SUPPORT_MODE_HOLD
    if mode_name in TRANSFER_MODES:
        return SUPPORT_MODE_TRANSFER
    if mode_name in RETRIEVE_MODES:
        return SUPPORT_MODE_PASSIVE
    return int(current_support_mode)
758
+
759
+
760
def _mode_progress_schedule(mode_name: str) -> np.ndarray:
    """Per-step progress fractions over the rollout horizon for a mode family."""
    if mode_name in REVEAL_MODES:
        values = [0.18, 0.38, 0.62, 0.84, 1.0]
    elif mode_name in TRANSFER_MODES:
        values = [0.22, 0.44, 0.66, 0.86, 1.0]
    elif mode_name in RETRIEVE_MODES:
        values = [0.34, 0.56, 0.76, 0.92, 1.0]
    else:
        # base_action and any unrecognized mode: slow, incomplete progress.
        values = [0.10, 0.22, 0.34, 0.44, 0.54]
    return np.asarray(values, dtype=np.float32)
768
+
769
+
770
+ def _scalar_rollout(start: float, end: float, schedule: np.ndarray) -> np.ndarray:
771
+ return np.clip((1.0 - schedule) * float(start) + schedule * float(end), 0.0, 1.0).astype(np.float32)
772
+
773
+
774
def _candidate_rollout_targets(
    *,
    mode_name: str,
    state_targets: dict[str, Any],
    candidate_payload: dict[str, float],
) -> dict[str, np.ndarray]:
    """Synthesize per-timestep rollout supervision targets for one candidate mode.

    Interpolates several scalar quality signals from the current state
    (*state_targets*) to the candidate's simulated post-execution values
    (*candidate_payload*) along the mode's progress schedule, then derives
    corridor-feasibility, persistence-horizon, support-mode and map-shaped
    targets from them.  All coefficients below are hand-tuned heuristics.
    """
    schedule = _mode_progress_schedule(mode_name)
    # Start values: current-state estimates, clamped to [0, 1] where needed.
    start_visibility = float(state_targets["target_visibility_confidence"])
    start_access = float(state_targets["actor_feasibility_score"])
    start_persistence = float(np.clip(state_targets["hold_quality"], 0.0, 1.0))
    start_support = float(np.clip(state_targets["top_layer_stability"], 0.0, 1.0))
    start_reocclusion = float(np.clip(state_targets["release_collapse_rate"], 0.0, 1.0))
    start_disturbance = float(np.clip(state_targets["disturbance_cost"], 0.0, 1.0))
    # Clearance starts from the same feasibility score as access (intentional reuse).
    start_clearance = float(np.clip(state_targets["actor_feasibility_score"], 0.0, 1.0))
    start_grasp = float(np.clip(max(start_visibility, start_access), 0.0, 1.0))

    # End values: what the candidate simulation reported immediately after execution.
    end_visibility = float(np.clip(candidate_payload["candidate_immediate_visibility"], 0.0, 1.0))
    end_access = float(np.clip(candidate_payload["candidate_immediate_access"], 0.0, 1.0))
    end_progress = float(np.clip(candidate_payload["candidate_immediate_progress"], 0.0, 1.0))
    end_disturbance = float(np.clip(candidate_payload["candidate_immediate_disturbance"], 0.0, 1.0))
    end_support = float(np.clip(candidate_payload["candidate_immediate_support_stability"], 0.0, 1.0))
    end_persistence = float(np.clip(candidate_payload["candidate_immediate_hold_persistence"], 0.0, 1.0))
    end_reocclusion = float(np.clip(candidate_payload["candidate_immediate_reocclusion"], 0.0, 1.0))
    end_clearance = float(np.clip(max(end_access, end_progress), 0.0, 1.0))
    end_grasp = float(np.clip(max(end_visibility, 0.5 * end_access + 0.5 * end_progress), 0.0, 1.0))

    # Transfer/retrieve candidates are assumed to start partway toward their
    # end state, so their start values are floored at a fraction of the end.
    if mode_name in TRANSFER_MODES:
        start_visibility = max(start_visibility, 0.35 * end_visibility)
        start_access = max(start_access, 0.40 * end_access)
        start_persistence = max(start_persistence, 0.45 * end_persistence)
        start_support = max(start_support, 0.50 * end_support)
    elif mode_name in RETRIEVE_MODES:
        start_visibility = max(start_visibility, 0.55 * end_visibility)
        start_access = max(start_access, 0.70 * end_access)
        start_persistence = max(start_persistence, 0.65 * end_persistence)
        start_support = max(start_support, 0.65 * end_support)
        start_reocclusion = min(start_reocclusion, max(0.4 * end_reocclusion, 0.0))

    # Interpolate each scalar along the schedule -> per-timestep float32 arrays.
    visibility = _scalar_rollout(start_visibility, end_visibility, schedule)
    access = _scalar_rollout(start_access, end_access, schedule)
    persistence = _scalar_rollout(start_persistence, end_persistence, schedule)
    support = _scalar_rollout(start_support, end_support, schedule)
    reocclusion = _scalar_rollout(start_reocclusion, end_reocclusion, schedule)
    disturbance = _scalar_rollout(start_disturbance, end_disturbance, schedule)
    clearance = _scalar_rollout(start_clearance, end_clearance, schedule)
    grasp = _scalar_rollout(start_grasp, end_grasp, schedule)

    # Weighted blends scoring each corridor type; weights are heuristic.
    reveal_corridor = np.clip(
        0.38 * visibility + 0.34 * access + 0.22 * support - 0.12 * disturbance,
        0.0,
        1.0,
    )
    transfer_corridor = np.clip(
        0.30 * visibility + 0.38 * access + 0.18 * persistence + 0.14 * support - 0.12 * disturbance,
        0.0,
        1.0,
    )
    passive_corridor = np.clip(
        0.22 * visibility + 0.42 * access + 0.20 * persistence + 0.16 * grasp - 0.14 * disturbance - 0.10 * reocclusion,
        0.0,
        1.0,
    )
    # Bias the corridor matching the executed mode up, and competitors down.
    if mode_name in REVEAL_MODES:
        reveal_corridor = np.clip(reveal_corridor + 0.14, 0.0, 1.0)
        passive_corridor = np.clip(0.75 * passive_corridor, 0.0, 1.0)
    elif mode_name in TRANSFER_MODES:
        transfer_corridor = np.clip(transfer_corridor + 0.16, 0.0, 1.0)
    elif mode_name in RETRIEVE_MODES:
        passive_corridor = np.clip(passive_corridor + 0.20, 0.0, 1.0)
        reveal_corridor = np.clip(0.60 * reveal_corridor, 0.0, 1.0)
    else:
        # base_action / unknown: damp all corridors.
        reveal_corridor = np.clip(0.85 * reveal_corridor, 0.0, 1.0)
        transfer_corridor = np.clip(0.75 * transfer_corridor, 0.0, 1.0)
        passive_corridor = np.clip(0.80 * passive_corridor, 0.0, 1.0)

    # Broadcast each corridor over the approach-template axis:
    # shape (T, 3, NUM_APPROACH_TEMPLATES).
    corridor_feasible = np.stack(
        [
            np.repeat(reveal_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
            np.repeat(transfer_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
            np.repeat(passive_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
        ],
        axis=1,
    ).astype(np.float32)
    # Horizon estimates in [0, ROLL_OUT_HORIZON] steps per corridor type.
    persistence_horizon = np.stack(
        [
            np.clip(ROLL_OUT_HORIZON * (0.55 * reveal_corridor + 0.45 * support), 0.0, float(ROLL_OUT_HORIZON)),
            np.clip(ROLL_OUT_HORIZON * (0.50 * transfer_corridor + 0.50 * persistence), 0.0, float(ROLL_OUT_HORIZON)),
            np.clip(ROLL_OUT_HORIZON * (0.55 * passive_corridor + 0.45 * persistence), 0.0, float(ROLL_OUT_HORIZON)),
        ],
        axis=1,
    ).astype(np.float32)
    # Constant support mode over the horizon; base_action keeps the current mode.
    support_mode = np.full(
        (ROLL_OUT_HORIZON,),
        _mode_support_mode(mode_name, int(state_targets["support_mode"])),
        dtype=np.int64,
    )
    if mode_name == "base_action":
        support_mode[:] = int(state_targets["support_mode"])

    # Map-shaped targets are scalar fields broadcast to their expected ranks
    # (1x1 spatial maps; clearance additionally carries a 2-channel axis).
    return {
        "candidate_rollout_support_mode": support_mode,
        "candidate_rollout_corridor_feasible": corridor_feasible,
        "candidate_rollout_persistence_horizon": persistence_horizon,
        "candidate_rollout_disturbance_cost": disturbance.astype(np.float32),
        "candidate_rollout_belief_map": visibility[:, None, None].astype(np.float32),
        "candidate_rollout_visibility_map": visibility[:, None, None].astype(np.float32),
        "candidate_rollout_clearance_map": np.repeat(clearance[:, None, None, None], 2, axis=1).astype(np.float32),
        "candidate_rollout_support_stability": support[:, None, None, None].astype(np.float32),
        "candidate_rollout_reocclusion_target": reocclusion[:, None, None].astype(np.float32),
        "candidate_rollout_occluder_contact_map": np.clip(access * support, 0.0, 1.0)[:, None, None].astype(np.float32),
        "candidate_rollout_grasp_affordance_map": grasp[:, None, None].astype(np.float32),
    }
886
+
887
+
888
def _evaluate_candidate(
    sim_env: PickClutterRevealEnv,
    obs_env: PickClutterRevealEnv,
    snapshot: dict[str, Any],
    mode_name: str,
) -> dict[str, float]:
    """Score one candidate macro-mode by simulating it from *snapshot*.

    ``sim_env`` is restored to *snapshot*, rolled forward under *mode_name*,
    and mirrored into ``obs_env`` so observations can be rendered.  If the
    candidate did not itself retrieve, a follow-up "retrieve" is also
    simulated to measure the candidate's downstream value.  ``obs_env`` is
    restored to *snapshot* before returning; ``sim_env`` is intentionally
    left rolled-forward (it is re-restored at the start of the next call).

    Returns a flat dict of "candidate_*" floats: final metrics, a scalar
    utility, and the immediate (pre-follow-up) metrics used elsewhere to
    build rollout targets.
    """
    _restore_env(sim_env, snapshot)
    start_positions = _all_positions(sim_env)
    _execute_mode(sim_env, mode_name)
    _sync_env_state(sim_env, obs_env)
    after_bundle = _extract_sensor_bundle(obs_env.get_obs(obs_env.get_info()))
    immediate = _candidate_metrics(sim_env, start_positions=start_positions, current_obs_bundle=after_bundle)
    # If the candidate alone did not succeed, chain a retrieve attempt to
    # estimate how well the candidate sets up the final grasp.
    if not immediate["retrieval_success"] and mode_name != "retrieve":
        _execute_mode(sim_env, "retrieve")
        _sync_env_state(sim_env, obs_env)
        follow_bundle = _extract_sensor_bundle(obs_env.get_obs(obs_env.get_info()))
        final_metrics = _candidate_metrics(sim_env, start_positions=start_positions, current_obs_bundle=follow_bundle)
    else:
        final_metrics = immediate
    _restore_env(obs_env, snapshot)
    # Hand-tuned linear utility: success dominates, disturbance penalizes.
    utility = (
        2.5 * final_metrics["retrieval_success"]
        + 1.0 * final_metrics["progress"]
        + 0.5 * final_metrics["clearance"]
        + 0.25 * final_metrics["visibility"]
        - 0.5 * final_metrics["disturbance"]
    )
    return {
        "candidate_retrieval_success": final_metrics["retrieval_success"],
        "candidate_risk": float(np.clip(final_metrics["disturbance"], 0.0, 1.0)),
        "candidate_utility": float(utility),
        "candidate_final_disturbance_cost": final_metrics["disturbance"],
        # Reocclusion / feasibility proxies derived from clearance.
        "candidate_reocclusion_rate": float(np.clip(1.0 - final_metrics["clearance"], 0.0, 1.0)),
        "candidate_visibility_integral": final_metrics["visibility"],
        "candidate_actor_feasibility_auc": final_metrics["clearance"],
        "candidate_reveal_achieved": float(final_metrics["progress"] > 0.15 or final_metrics["clearance"] > 0.35),
        # Persistence / stability proxies derived from disturbance.
        "candidate_hold_persistence": float(1.0 - final_metrics["disturbance"]),
        "candidate_support_stability_auc": float(1.0 - 0.5 * final_metrics["disturbance"]),
        "candidate_disturbance_auc": final_metrics["disturbance"],
        # Immediate (pre-follow-up) metrics, consumed by rollout-target builders.
        "candidate_immediate_retrieval_success": immediate["retrieval_success"],
        "candidate_immediate_visibility": immediate["visibility"],
        "candidate_immediate_access": immediate["clearance"],
        "candidate_immediate_progress": immediate["progress"],
        "candidate_immediate_reocclusion": float(np.clip(1.0 - immediate["clearance"], 0.0, 1.0)),
        "candidate_immediate_hold_persistence": float(1.0 - immediate["disturbance"]),
        "candidate_immediate_support_stability": float(1.0 - 0.5 * immediate["disturbance"]),
        "candidate_immediate_disturbance": immediate["disturbance"],
    }
936
+
937
+
938
+ def _build_episode_splits(spec: SmokeSpec) -> dict[str, list[int]]:
939
+ return {
940
+ "train": [spec.dataset_seed * 10_000 + index for index in range(spec.train_episodes)],
941
+ "val": [spec.dataset_seed * 10_000 + 1_000 + index for index in range(spec.val_episodes)],
942
+ "eval": [spec.dataset_seed * 10_000 + 2_000 + index for index in range(spec.eval_episodes)],
943
+ }
944
+
945
+
946
+ def _save_episode_splits(output_path: Path, payload: dict[str, list[int]]) -> None:
947
+ output_path.parent.mkdir(parents=True, exist_ok=True)
948
+ output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
949
+
950
+
951
+ def _init_history_entry(obs_bundle: dict[str, np.ndarray], proprio: np.ndarray, action_chunk: np.ndarray) -> dict[str, Any]:
952
+ return {
953
+ "images": obs_bundle["images"].copy(),
954
+ "depths": obs_bundle["depths"].copy(),
955
+ "depth_valid": obs_bundle["depth_valid"].copy(),
956
+ "camera_intrinsics": obs_bundle["camera_intrinsics"].copy(),
957
+ "camera_extrinsics": obs_bundle["camera_extrinsics"].copy(),
958
+ "proprio": proprio.copy(),
959
+ "action": action_chunk.mean(axis=0).astype(np.float32, copy=False),
960
+ }
961
+
962
+
963
+ def _history_stack(
964
+ history: Sequence[dict[str, Any]],
965
+ key: str,
966
+ *,
967
+ pad_shape: tuple[int, ...],
968
+ dtype: np.dtype,
969
+ history_steps: int,
970
+ ) -> np.ndarray:
971
+ history = list(history)[-history_steps:]
972
+ pad_count = history_steps - len(history)
973
+ chunks = [np.zeros(pad_shape, dtype=dtype) for _ in range(pad_count)]
974
+ chunks.extend(np.asarray(item[key], dtype=dtype) for item in history)
975
+ return np.stack(chunks, axis=0).astype(dtype, copy=False)
976
+
977
+
978
class ManiSkillPickClutterDataset(Dataset[dict[str, Any]]):
    """In-memory torch Dataset over precomputed pick-clutter decision samples.

    Each sample is a dict of numpy arrays / scalars produced by the data
    collection loop; ``__getitem__`` converts it to torch tensors (images
    normalized to [0, 1] and channel-first) and mirrors the candidate targets
    under ``proposal_target_*`` aliases for the proposal heads.
    """

    def __init__(self, samples: Sequence[dict[str, Any]]) -> None:
        # Copy into a list so the dataset owns a stable, indexable sequence.
        self.samples = list(samples)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> dict[str, Any]:
        sample = self.samples[index]
        item = {
            # Observations: uint8 (V, H, W, C) -> float (V, C, H, W) in [0, 1].
            "images": torch.from_numpy(sample["images"]).permute(0, 3, 1, 2).float() / 255.0,
            "depths": torch.from_numpy(sample["depths"]).float(),
            "depth_valid": torch.from_numpy(sample["depth_valid"]).float(),
            "camera_intrinsics": torch.from_numpy(sample["camera_intrinsics"]).float(),
            "camera_extrinsics": torch.from_numpy(sample["camera_extrinsics"]).float(),
            # History carries an extra leading time axis, hence the 5-d permute.
            "history_images": torch.from_numpy(sample["history_images"]).permute(0, 1, 4, 2, 3).float() / 255.0,
            "history_depths": torch.from_numpy(sample["history_depths"]).float(),
            "history_depth_valid": torch.from_numpy(sample["history_depth_valid"]).float(),
            "history_camera_intrinsics": torch.from_numpy(sample["history_camera_intrinsics"]).float(),
            "history_camera_extrinsics": torch.from_numpy(sample["history_camera_extrinsics"]).float(),
            "history_proprio": torch.from_numpy(sample["history_proprio"]).float(),
            "history_actions": torch.from_numpy(sample["history_actions"]).float(),
            "proprio": torch.from_numpy(sample["proprio"]).float(),
            # Task conditioning (texts/task_name stay as plain Python strings).
            "texts": sample["language_goal"],
            "task_name": sample["task_name"],
            "task_id": torch.as_tensor(sample["task_id"], dtype=torch.long),
            # Action supervision and per-candidate planner targets.
            "action_chunk": torch.from_numpy(sample["action_chunk"]).float(),
            "candidate_action_chunks": torch.from_numpy(sample["candidate_action_chunks"]).float(),
            "candidate_retrieval_success": torch.from_numpy(sample["candidate_retrieval_success"]).float(),
            "candidate_final_disturbance_cost": torch.from_numpy(sample["candidate_final_disturbance_cost"]).float(),
            "candidate_reocclusion_rate": torch.from_numpy(sample["candidate_reocclusion_rate"]).float(),
            "candidate_visibility_integral": torch.from_numpy(sample["candidate_visibility_integral"]).float(),
            "candidate_actor_feasibility_auc": torch.from_numpy(sample["candidate_actor_feasibility_auc"]).float(),
            "candidate_reveal_achieved": torch.from_numpy(sample["candidate_reveal_achieved"]).float(),
            "candidate_hold_persistence": torch.from_numpy(sample["candidate_hold_persistence"]).float(),
            "candidate_support_stability_auc": torch.from_numpy(sample["candidate_support_stability_auc"]).float(),
            "candidate_disturbance_auc": torch.from_numpy(sample["candidate_disturbance_auc"]).float(),
            "candidate_risk": torch.from_numpy(sample["candidate_risk"]).float(),
            "candidate_utility": torch.from_numpy(sample["candidate_utility"]).float(),
            # Proposal-head targets alias the candidate arrays (same tensors).
            "proposal_target_action_chunks": torch.from_numpy(sample["candidate_action_chunks"]).float(),
            "proposal_target_retrieval_success": torch.from_numpy(sample["candidate_retrieval_success"]).float(),
            "proposal_target_risk": torch.from_numpy(sample["candidate_risk"]).float(),
            "proposal_target_utility": torch.from_numpy(sample["candidate_utility"]).float(),
            # Bookkeeping passed through untouched.
            "episode_seed": sample["episode_seed"],
            "decision_step": sample["decision_step"],
            "selected_mode": sample["selected_mode"],
        }
        # Optional state-supervision targets: only present for samples that
        # were collected with state-target extraction enabled.
        if "support_mode" in sample:
            item["support_mode"] = torch.as_tensor(sample["support_mode"], dtype=torch.long)
        if "corridor_feasible" in sample:
            item["corridor_feasible"] = torch.from_numpy(sample["corridor_feasible"]).float()
        if "persistence_horizon" in sample:
            item["persistence_horizon"] = torch.from_numpy(sample["persistence_horizon"]).float()
        if "disturbance_cost" in sample:
            item["disturbance_cost"] = torch.as_tensor(sample["disturbance_cost"], dtype=torch.float32)
        if "state_confidence_target" in sample:
            item["state_confidence_target"] = torch.as_tensor(sample["state_confidence_target"], dtype=torch.float32)
        if "task_metric_mask" in sample:
            item["task_metric_mask"] = torch.from_numpy(sample["task_metric_mask"]).to(dtype=torch.bool)
        # Scalar task metrics named by the module-level supervision list.
        for metric_name in STATE_SUPERVISION_METRICS:
            if metric_name in sample:
                item[metric_name] = torch.as_tensor(sample[metric_name], dtype=torch.float32)
        # Rollout support modes are integer class targets -> long tensors.
        for key in (
            "candidate_rollout_support_mode",
            "proposal_target_rollout_support_mode",
        ):
            if key in sample:
                item[key] = torch.from_numpy(sample[key]).long()
        # All remaining rollout targets are real-valued -> float tensors.
        for key in (
            "candidate_rollout_corridor_feasible",
            "candidate_rollout_persistence_horizon",
            "candidate_rollout_disturbance_cost",
            "candidate_rollout_belief_map",
            "candidate_rollout_visibility_map",
            "candidate_rollout_clearance_map",
            "candidate_rollout_support_stability",
            "candidate_rollout_reocclusion_target",
            "candidate_rollout_occluder_contact_map",
            "candidate_rollout_grasp_affordance_map",
            "proposal_target_rollout_corridor_feasible",
            "proposal_target_rollout_persistence_horizon",
            "proposal_target_rollout_disturbance_cost",
            "proposal_target_rollout_belief_map",
            "proposal_target_rollout_visibility_map",
            "proposal_target_rollout_clearance_map",
            "proposal_target_rollout_support_stability",
            "proposal_target_rollout_reocclusion_target",
            "proposal_target_rollout_occluder_contact_map",
            "proposal_target_rollout_grasp_affordance_map",
        ):
            if key in sample:
                item[key] = torch.from_numpy(sample[key]).float()
        return item
1071
+
1072
+
1073
def _make_loader(samples: Sequence[dict[str, Any]], *, batch_size: int, num_workers: int, shuffle: bool) -> DataLoader:
    """Wrap *samples* in a ``ManiSkillPickClutterDataset`` and build its loader.

    Pinned host memory is enabled automatically whenever CUDA is available.
    """
    dataset = ManiSkillPickClutterDataset(samples)
    pin = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin,
    )
1081
+
1082
+
1083
def _load_init_bundle() -> tuple[PolicyConfig, dict[str, Any], dict[str, Any]]:
    """Load the default init checkpoint and rebuild its policy configuration.

    Returns the smoke-planner-adjusted ``PolicyConfig`` together with the raw
    trainer-config and loss-weight dicts stored in the checkpoint.
    """
    checkpoint = torch.load(DEFAULT_INIT_CHECKPOINT, map_location="cpu", weights_only=False)
    raw = checkpoint["policy_config"]
    policy_config = PolicyConfig(
        backbone=FrozenVLBackboneConfig(**raw["backbone"]),
        fusion=MultiViewFusionConfig(**raw["fusion"]),
        memory=ObservationMemoryConfig(**raw["memory"]),
        decoder=ChunkDecoderConfig(**raw["decoder"]),
        reveal_head=RevealHeadConfig(**raw["reveal_head"]),
        world_model=RevealWMConfig(**raw["world_model"]),
        planner=PlannerConfig(**raw["planner"]),
    )
    adjusted = _apply_smoke_planner_overrides(policy_config)
    return adjusted, checkpoint["trainer_config"], checkpoint["loss_weights"]
1095
+
1096
+
1097
def _trainer_config_for_variant(variant: str) -> TrainerConfig:
    """Build the ``TrainerConfig`` for one smoke-test training variant.

    Raises:
        KeyError: if *variant* is not a known smoke variant.
    """
    # Settings shared by every smoke variant.
    common = dict(
        use_bf16=True,
        grad_clip_norm=1.0,
        freeze_backbone=True,
        gradient_checkpointing=False,
    )
    if variant == "trunk_only_ft":
        return TrainerConfig(
            policy_type="foundation_trunk",
            trainable_parameter_prefixes=("fusion", "memory", "decoder"),
            **common,
        )
    if variant == "adapter_active_ft":
        return TrainerConfig(
            policy_type="adapter_wrapped",
            eval_mode="adapter_active",
            trainable_parameter_prefixes=(
                "trunk.fusion",
                "trunk.memory",
                "trunk.decoder",
                "adapter.state_head",
                "adapter.transition_model",
                "adapter.proposal_prior",
                "adapter.planner",
            ),
            adapter_mode="adapter_active",
            adapter_use_transition_model=True,
            adapter_use_task_conditioning=True,
            adapter_action_supervision_source="trunk",
            **common,
        )
    raise KeyError(f"Unsupported variant: {variant}")
1130
+
1131
+
1132
def _loss_weights_for_smoke() -> LossWeights:
    """Return the loss weighting used by the smoke-training runs."""
    scalar_weights = {
        "action": 1.0,
        "support_mode": 0.15,
        "corridor": 0.15,
        "persistence": 0.08,
        "disturbance": 0.08,
        "planner_success": 0.20,
        "planner_risk": 0.08,
        "planner_ranking": 0.20,
        "proposal_reconstruction": 0.10,
        "proposal_success": 0.12,
        "proposal_ranking": 0.15,
        "proposal_mode": 0.10,
        "proposal_diversity": 0.02,
        "task_metrics": 0.15,
        "transition": 0.25,
        "gate": 0.25,
        "calibration": 0.10,
    }
    # Only the smoke task contributes to the proposal-mode loss.
    return LossWeights(proposal_mode_task_filter=[TASK_NAME], **scalar_weights)
1153
+
1154
+
1155
+ def _save_training_checkpoint(
1156
+ *,
1157
+ output_dir: Path,
1158
+ experiment_name: str,
1159
+ model: torch.nn.Module,
1160
+ policy_config: PolicyConfig,
1161
+ trainer_config: TrainerConfig,
1162
+ loss_weights: LossWeights,
1163
+ history: list[dict[str, Any]],
1164
+ best_val: float,
1165
+ train_spec: dict[str, Any],
1166
+ ) -> Path:
1167
+ output_dir.mkdir(parents=True, exist_ok=True)
1168
+ checkpoint_path = output_dir / "checkpoint_best.pt"
1169
+ torch.save(
1170
+ {
1171
+ "experiment_name": experiment_name,
1172
+ "policy_config": asdict(policy_config),
1173
+ "trainer_config": asdict(trainer_config),
1174
+ "loss_weights": asdict(loss_weights),
1175
+ "state_dict": model.state_dict(),
1176
+ "history": history,
1177
+ "best_val_total": best_val,
1178
+ "train_spec": train_spec,
1179
+ },
1180
+ checkpoint_path,
1181
+ )
1182
+ return checkpoint_path
1183
+
1184
+
1185
+ def _aggregate_epoch(loss_records: Sequence[dict[str, float]]) -> dict[str, float]:
1186
+ keys = sorted({key for record in loss_records for key in record})
1187
+ return {
1188
+ key: float(np.mean([record.get(key, 0.0) for record in loss_records])) if loss_records else 0.0
1189
+ for key in keys
1190
+ }
1191
+
1192
+
1193
def _train_variant(
    *,
    variant: str,
    train_samples: Sequence[dict[str, Any]],
    val_samples: Sequence[dict[str, Any]],
    spec: SmokeSpec,
    output_dir: Path,
) -> tuple[Path, dict[str, Any]]:
    """Fine-tune one smoke variant and checkpoint its best validation epoch.

    Builds the policy from the default init checkpoint, trains for
    ``spec.epochs`` epochs, and saves ``checkpoint_best.pt`` plus a
    ``summary.json`` whenever the validation total improves (ties included,
    so the latest equal-best epoch wins).  Returns the best-checkpoint path
    and the training spec used.

    Fix vs. previous revision: ``compute_total_loss`` was imported inside the
    per-batch validation loop; the project-local import is hoisted here so it
    executes once per call instead of once per batch.
    """
    from train.losses import compute_total_loss  # project-local; hoisted out of the hot loop

    policy_config, _init_trainer_cfg, _init_loss_weights = _load_init_bundle()
    trainer_config = _trainer_config_for_variant(variant)
    loss_weights = _loss_weights_for_smoke()
    model = build_policy(policy_config, trainer_config)
    init_info = _load_init_checkpoint(model, str(DEFAULT_INIT_CHECKPOINT), False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Seed CPU and (when present) CUDA RNGs for reproducible smoke runs.
    torch.manual_seed(spec.train_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(spec.train_seed)
        torch.backends.cuda.matmul.allow_tf32 = True
    matched = apply_trainable_parameter_prefixes(model, trainer_config)
    optimizer = torch.optim.AdamW(
        [parameter for parameter in model.parameters() if parameter.requires_grad],
        lr=spec.learning_rate,
        weight_decay=spec.weight_decay,
    )
    trainer = BimanualTrainer(model=model, optimizer=optimizer, config=trainer_config)
    train_loader = _make_loader(train_samples, batch_size=spec.batch_size, num_workers=spec.num_workers, shuffle=True)
    val_loader = _make_loader(val_samples, batch_size=spec.batch_size, num_workers=spec.num_workers, shuffle=False)

    best_val = math.inf
    history: list[dict[str, Any]] = []
    train_spec = build_target_training_spec(
        track_id="occlusion_track",
        model_variant=variant,
        seed=spec.train_seed,
        train_demos=spec.train_episodes,
        val_demos=spec.val_episodes,
        init_checkpoint_group=str(DEFAULT_INIT_CHECKPOINT),
        optimizer="adamw",
        learning_rate=spec.learning_rate,
        lr_schedule="constant",
        batch_size=spec.batch_size,
        augmentations="none",
        early_stopping_metric="val_total",
        max_gradient_steps=len(train_loader) * spec.epochs,
        unfreeze_scope="fusion_memory_decoder",
        dataset_split_id=(
            f"pickclutter_{SMOKE_VERSION}_seed{spec.dataset_seed}"
            if int(spec.dataset_seed) == DEFAULT_SEED
            else f"pickclutter_{SMOKE_VERSION}_dataset_seed{spec.dataset_seed}"
        ),
    )

    for epoch in range(spec.epochs):
        # ---- training pass ----
        model.train()
        train_losses: list[dict[str, float]] = []
        for batch in train_loader:
            moved = _move_batch_to_device(batch, device)
            loss_dict = trainer.training_step(moved, loss_weights=loss_weights)
            train_losses.append({key: float(value.detach().cpu()) for key, value in loss_dict.items()})

        # ---- validation pass (no grad; mirrors the trainer's forward inputs) ----
        model.eval()
        val_losses: list[dict[str, float]] = []
        with torch.no_grad():
            for batch in val_loader:
                moved = _move_batch_to_device(batch, device)
                forward_kwargs = {
                    "images": moved["images"],
                    "proprio": moved["proprio"],
                    "texts": moved["texts"],
                    "task_names": moved.get("task_name"),
                    "task_ids": moved.get("task_id"),
                    "history_images": moved.get("history_images"),
                    "history_proprio": moved.get("history_proprio"),
                    "history_actions": moved.get("history_actions"),
                    "depths": moved.get("depths"),
                    "depth_valid": moved.get("depth_valid"),
                    "camera_intrinsics": moved.get("camera_intrinsics"),
                    "camera_extrinsics": moved.get("camera_extrinsics"),
                    "history_depths": moved.get("history_depths"),
                    "history_depth_valid": moved.get("history_depth_valid"),
                    "history_camera_intrinsics": moved.get("history_camera_intrinsics"),
                    "history_camera_extrinsics": moved.get("history_camera_extrinsics"),
                }
                if variant == "adapter_active_ft":
                    # The adapter variant evaluates with its full stack enabled.
                    forward_kwargs["adapter_mode"] = "adapter_active"
                    forward_kwargs["use_transition_model"] = True
                    forward_kwargs["use_task_conditioning"] = True
                outputs = model(**forward_kwargs)
                losses = compute_total_loss(outputs, moved, weights=loss_weights)
                val_losses.append({key: float(value.detach().cpu()) for key, value in losses.items()})

        train_summary = _aggregate_epoch(train_losses)
        val_summary = _aggregate_epoch(val_losses)
        history.append({"epoch": epoch, "train": train_summary, "val": val_summary})
        # One JSON progress line per epoch for log scraping.
        print(
            json.dumps(
                {
                    "phase": "epoch_complete",
                    "variant": variant,
                    "epoch": epoch,
                    "train_total": train_summary.get("total", 0.0),
                    "val_total": val_summary.get("total", 0.0),
                }
            ),
            flush=True,
        )
        # <= (not <) so the most recent equal-best epoch overwrites the checkpoint.
        if val_summary.get("total", math.inf) <= best_val:
            best_val = val_summary["total"]
            checkpoint_path = _save_training_checkpoint(
                output_dir=output_dir,
                experiment_name=f"{variant}_seed{spec.train_seed}",
                model=model,
                policy_config=policy_config,
                trainer_config=trainer_config,
                loss_weights=loss_weights,
                history=history,
                best_val=best_val,
                train_spec=train_spec,
            )
            (output_dir / "summary.json").write_text(
                json.dumps(
                    {
                        "variant": variant,
                        "checkpoint_path": str(checkpoint_path),
                        "init_info": init_info,
                        "trainable_parameter_names": matched,
                        "best_val_total": best_val,
                        "history": history,
                        "train_spec": train_spec,
                    },
                    indent=2,
                )
                + "\n",
                encoding="utf-8",
            )
    return output_dir / "checkpoint_best.pt", train_spec
1332
+
1333
+
1334
def _load_checkpoint(
    checkpoint_path: Path,
    *,
    adapter_mode: str | None = None,
    planner_overrides: dict[str, float] | None = None,
) -> tuple[torch.nn.Module, dict[str, Any]]:
    """Rebuild a policy from a saved checkpoint and load its weights.

    Shape-incompatible weights are silently skipped (recorded under the
    ``"_shape_skipped_keys"`` key of the returned checkpoint dict), while
    unexpected state-dict keys raise.  The model is moved to CUDA when
    available and put into eval mode.  *adapter_mode* only takes effect for
    ``adapter_wrapped`` checkpoints.
    """
    payload = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    raw_policy = payload["policy_config"]
    config = PolicyConfig(
        backbone=FrozenVLBackboneConfig(**raw_policy["backbone"]),
        fusion=MultiViewFusionConfig(**raw_policy["fusion"]),
        memory=ObservationMemoryConfig(**raw_policy["memory"]),
        decoder=ChunkDecoderConfig(**raw_policy["decoder"]),
        reveal_head=RevealHeadConfig(**raw_policy["reveal_head"]),
        world_model=RevealWMConfig(**raw_policy["world_model"]),
        planner=PlannerConfig(**raw_policy["planner"]),
    )
    config = _apply_smoke_planner_overrides(config, planner_overrides=planner_overrides)
    run_config = TrainerConfig(**payload["trainer_config"])
    if adapter_mode is not None and run_config.policy_type == "adapter_wrapped":
        run_config.adapter_mode = adapter_mode
        run_config.eval_mode = adapter_mode
    model = build_policy(config, run_config)
    compatible, skipped, _remapped = filter_compatible_state_dict(model.state_dict(), payload["state_dict"])
    load_result = model.load_state_dict(compatible, strict=False)
    if load_result.unexpected_keys:
        raise RuntimeError(f"Unexpected checkpoint keys for {checkpoint_path}: {list(load_result.unexpected_keys)}")
    if skipped:
        payload["_shape_skipped_keys"] = skipped
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(target_device)
    model.eval()
    return model, payload
1366
+
1367
+
1368
+ def _collect_split(
1369
+ *,
1370
+ split_name: str,
1371
+ seeds: Sequence[int],
1372
+ spec: SmokeSpec,
1373
+ output_path: Path,
1374
+ ) -> dict[str, Any]:
1375
+ obs_env = PickClutterRevealEnv(
1376
+ obs_mode="rgb+depth+segmentation",
1377
+ control_mode="pd_ee_delta_pos",
1378
+ render_mode="rgb_array",
1379
+ )
1380
+ sim_env = PickClutterRevealEnv(
1381
+ obs_mode="none",
1382
+ control_mode="pd_ee_delta_pos",
1383
+ render_mode="rgb_array",
1384
+ )
1385
+ samples: list[dict[str, Any]] = []
1386
+ episode_records: list[dict[str, Any]] = []
1387
+ try:
1388
+ for episode_seed in seeds:
1389
+ obs, _ = obs_env.reset(seed=int(episode_seed))
1390
+ sim_env.reset(seed=int(episode_seed))
1391
+ _sync_env_state(obs_env, sim_env)
1392
+ episode_start_positions = _all_positions(obs_env)
1393
+ history: deque[dict[str, Any]] = deque(maxlen=spec.history_steps)
1394
+ episode_success = False
1395
+ for decision_step in range(spec.max_macro_steps):
1396
+ obs_bundle = _extract_sensor_bundle(obs)
1397
+ proprio = _build_proprio(obs_env)
1398
+ snapshot = _snapshot_env(obs_env)
1399
+ candidate_metrics = [_evaluate_candidate(sim_env, obs_env, snapshot, mode_name) for mode_name in MODE_ORDER]
1400
+ candidate_chunks = np.stack([CANONICAL_CHUNKS[mode_name] for mode_name in MODE_ORDER], axis=0).astype(np.float32)
1401
+ utilities = np.asarray([payload["candidate_utility"] for payload in candidate_metrics], dtype=np.float32)
1402
+ best_index = int(utilities.argmax())
1403
+ selected_mode = MODE_ORDER[best_index]
1404
+ state_targets = _current_state_targets(
1405
+ obs_env,
1406
+ obs_bundle=obs_bundle,
1407
+ candidate_metrics=candidate_metrics,
1408
+ episode_start_positions=episode_start_positions,
1409
+ selected_mode=selected_mode,
1410
+ )
1411
+ rollout_targets_by_mode = [
1412
+ _candidate_rollout_targets(
1413
+ mode_name=mode_name,
1414
+ state_targets=state_targets,
1415
+ candidate_payload=payload,
1416
+ )
1417
+ for mode_name, payload in zip(MODE_ORDER, candidate_metrics)
1418
+ ]
1419
+ sample = {
1420
+ "images": obs_bundle["images"].copy(),
1421
+ "depths": obs_bundle["depths"].copy(),
1422
+ "depth_valid": obs_bundle["depth_valid"].copy(),
1423
+ "camera_intrinsics": obs_bundle["camera_intrinsics"].copy(),
1424
+ "camera_extrinsics": obs_bundle["camera_extrinsics"].copy(),
1425
+ "history_images": _history_stack(history, "images", pad_shape=obs_bundle["images"].shape, dtype=np.uint8, history_steps=spec.history_steps),
1426
+ "history_depths": _history_stack(history, "depths", pad_shape=obs_bundle["depths"].shape, dtype=np.float32, history_steps=spec.history_steps),
1427
+ "history_depth_valid": _history_stack(history, "depth_valid", pad_shape=obs_bundle["depth_valid"].shape, dtype=np.float32, history_steps=spec.history_steps),
1428
+ "history_camera_intrinsics": _history_stack(history, "camera_intrinsics", pad_shape=obs_bundle["camera_intrinsics"].shape, dtype=np.float32, history_steps=spec.history_steps),
1429
+ "history_camera_extrinsics": _history_stack(history, "camera_extrinsics", pad_shape=obs_bundle["camera_extrinsics"].shape, dtype=np.float32, history_steps=spec.history_steps),
1430
+ "history_proprio": _history_stack(history, "proprio", pad_shape=(PROPRIO_DIM,), dtype=np.float32, history_steps=spec.history_steps),
1431
+ "history_actions": _history_stack(history, "action", pad_shape=(14,), dtype=np.float32, history_steps=spec.history_steps),
1432
+ "proprio": proprio.astype(np.float32),
1433
+ "language_goal": TEXT_PROMPT,
1434
+ "task_name": TASK_NAME,
1435
+ "task_id": TASK_ID,
1436
+ "action_chunk": CANONICAL_CHUNKS[selected_mode].copy(),
1437
+ "candidate_action_chunks": candidate_chunks,
1438
+ "candidate_retrieval_success": np.asarray([payload["candidate_retrieval_success"] for payload in candidate_metrics], dtype=np.float32),
1439
+ "candidate_final_disturbance_cost": np.asarray([payload["candidate_final_disturbance_cost"] for payload in candidate_metrics], dtype=np.float32),
1440
+ "candidate_reocclusion_rate": np.asarray([payload["candidate_reocclusion_rate"] for payload in candidate_metrics], dtype=np.float32),
1441
+ "candidate_visibility_integral": np.asarray([payload["candidate_visibility_integral"] for payload in candidate_metrics], dtype=np.float32),
1442
+ "candidate_actor_feasibility_auc": np.asarray([payload["candidate_actor_feasibility_auc"] for payload in candidate_metrics], dtype=np.float32),
1443
+ "candidate_reveal_achieved": np.asarray([payload["candidate_reveal_achieved"] for payload in candidate_metrics], dtype=np.float32),
1444
+ "candidate_hold_persistence": np.asarray([payload["candidate_hold_persistence"] for payload in candidate_metrics], dtype=np.float32),
1445
+ "candidate_support_stability_auc": np.asarray([payload["candidate_support_stability_auc"] for payload in candidate_metrics], dtype=np.float32),
1446
+ "candidate_disturbance_auc": np.asarray([payload["candidate_disturbance_auc"] for payload in candidate_metrics], dtype=np.float32),
1447
+ "candidate_risk": np.asarray([payload["candidate_risk"] for payload in candidate_metrics], dtype=np.float32),
1448
+ "candidate_utility": utilities,
1449
+ "candidate_rollout_support_mode": np.stack(
1450
+ [payload["candidate_rollout_support_mode"] for payload in rollout_targets_by_mode],
1451
+ axis=0,
1452
+ ).astype(np.int64),
1453
+ "candidate_rollout_corridor_feasible": np.stack(
1454
+ [payload["candidate_rollout_corridor_feasible"] for payload in rollout_targets_by_mode],
1455
+ axis=0,
1456
+ ).astype(np.float32),
1457
+ "candidate_rollout_persistence_horizon": np.stack(
1458
+ [payload["candidate_rollout_persistence_horizon"] for payload in rollout_targets_by_mode],
1459
+ axis=0,
1460
+ ).astype(np.float32),
1461
+ "candidate_rollout_disturbance_cost": np.stack(
1462
+ [payload["candidate_rollout_disturbance_cost"] for payload in rollout_targets_by_mode],
1463
+ axis=0,
1464
+ ).astype(np.float32),
1465
+ "candidate_rollout_belief_map": np.stack(
1466
+ [payload["candidate_rollout_belief_map"] for payload in rollout_targets_by_mode],
1467
+ axis=0,
1468
+ ).astype(np.float32),
1469
+ "candidate_rollout_visibility_map": np.stack(
1470
+ [payload["candidate_rollout_visibility_map"] for payload in rollout_targets_by_mode],
1471
+ axis=0,
1472
+ ).astype(np.float32),
1473
+ "candidate_rollout_clearance_map": np.stack(
1474
+ [payload["candidate_rollout_clearance_map"] for payload in rollout_targets_by_mode],
1475
+ axis=0,
1476
+ ).astype(np.float32),
1477
+ "candidate_rollout_support_stability": np.stack(
1478
+ [payload["candidate_rollout_support_stability"] for payload in rollout_targets_by_mode],
1479
+ axis=0,
1480
+ ).astype(np.float32),
1481
+ "candidate_rollout_reocclusion_target": np.stack(
1482
+ [payload["candidate_rollout_reocclusion_target"] for payload in rollout_targets_by_mode],
1483
+ axis=0,
1484
+ ).astype(np.float32),
1485
+ "candidate_rollout_occluder_contact_map": np.stack(
1486
+ [payload["candidate_rollout_occluder_contact_map"] for payload in rollout_targets_by_mode],
1487
+ axis=0,
1488
+ ).astype(np.float32),
1489
+ "candidate_rollout_grasp_affordance_map": np.stack(
1490
+ [payload["candidate_rollout_grasp_affordance_map"] for payload in rollout_targets_by_mode],
1491
+ axis=0,
1492
+ ).astype(np.float32),
1493
+ "proposal_target_rollout_support_mode": np.stack(
1494
+ [payload["candidate_rollout_support_mode"] for payload in rollout_targets_by_mode],
1495
+ axis=0,
1496
+ ).astype(np.int64),
1497
+ "proposal_target_rollout_corridor_feasible": np.stack(
1498
+ [payload["candidate_rollout_corridor_feasible"] for payload in rollout_targets_by_mode],
1499
+ axis=0,
1500
+ ).astype(np.float32),
1501
+ "proposal_target_rollout_persistence_horizon": np.stack(
1502
+ [payload["candidate_rollout_persistence_horizon"] for payload in rollout_targets_by_mode],
1503
+ axis=0,
1504
+ ).astype(np.float32),
1505
+ "proposal_target_rollout_disturbance_cost": np.stack(
1506
+ [payload["candidate_rollout_disturbance_cost"] for payload in rollout_targets_by_mode],
1507
+ axis=0,
1508
+ ).astype(np.float32),
1509
+ "proposal_target_rollout_belief_map": np.stack(
1510
+ [payload["candidate_rollout_belief_map"] for payload in rollout_targets_by_mode],
1511
+ axis=0,
1512
+ ).astype(np.float32),
1513
+ "proposal_target_rollout_visibility_map": np.stack(
1514
+ [payload["candidate_rollout_visibility_map"] for payload in rollout_targets_by_mode],
1515
+ axis=0,
1516
+ ).astype(np.float32),
1517
+ "proposal_target_rollout_clearance_map": np.stack(
1518
+ [payload["candidate_rollout_clearance_map"] for payload in rollout_targets_by_mode],
1519
+ axis=0,
1520
+ ).astype(np.float32),
1521
+ "proposal_target_rollout_support_stability": np.stack(
1522
+ [payload["candidate_rollout_support_stability"] for payload in rollout_targets_by_mode],
1523
+ axis=0,
1524
+ ).astype(np.float32),
1525
+ "proposal_target_rollout_reocclusion_target": np.stack(
1526
+ [payload["candidate_rollout_reocclusion_target"] for payload in rollout_targets_by_mode],
1527
+ axis=0,
1528
+ ).astype(np.float32),
1529
+ "proposal_target_rollout_occluder_contact_map": np.stack(
1530
+ [payload["candidate_rollout_occluder_contact_map"] for payload in rollout_targets_by_mode],
1531
+ axis=0,
1532
+ ).astype(np.float32),
1533
+ "proposal_target_rollout_grasp_affordance_map": np.stack(
1534
+ [payload["candidate_rollout_grasp_affordance_map"] for payload in rollout_targets_by_mode],
1535
+ axis=0,
1536
+ ).astype(np.float32),
1537
+ "episode_seed": int(episode_seed),
1538
+ "decision_step": int(decision_step),
1539
+ "selected_mode": selected_mode,
1540
+ **state_targets,
1541
+ }
1542
+ samples.append(sample)
1543
+ _execute_mode(obs_env, selected_mode)
1544
+ obs = obs_env.get_obs(obs_env.get_info())
1545
+ history.append(_init_history_entry(obs_bundle, proprio, CANONICAL_CHUNKS[selected_mode]))
1546
+ if _success_from_state(obs_env):
1547
+ episode_success = True
1548
+ break
1549
+ episode_records.append(
1550
+ {
1551
+ "episode_seed": int(episode_seed),
1552
+ "success": episode_success,
1553
+ "steps": len(history),
1554
+ }
1555
+ )
1556
+ print(
1557
+ json.dumps(
1558
+ {
1559
+ "phase": "collect_episode_complete",
1560
+ "split": split_name,
1561
+ "episode_seed": int(episode_seed),
1562
+ "success": episode_success,
1563
+ "steps": len(history),
1564
+ "samples_collected": len(samples),
1565
+ }
1566
+ ),
1567
+ flush=True,
1568
+ )
1569
+ finally:
1570
+ obs_env.close()
1571
+ sim_env.close()
1572
+
1573
+ payload = {
1574
+ "split_name": split_name,
1575
+ "resolution": spec.resolution,
1576
+ "history_steps": spec.history_steps,
1577
+ "samples": samples,
1578
+ "episode_records": episode_records,
1579
+ }
1580
+ output_path.parent.mkdir(parents=True, exist_ok=True)
1581
+ torch.save(payload, output_path)
1582
+ return payload
1583
+
1584
+
1585
+ def _load_split(path: Path) -> dict[str, Any]:
1586
+ return torch.load(path, map_location="cpu", weights_only=False)
1587
+
1588
+
1589
+ def _print_split_supervision_summary(split_name: str, samples: Sequence[dict[str, Any]]) -> None:
1590
+ mode_counter = collections.Counter(str(sample.get("selected_mode", "unknown")) for sample in samples)
1591
+ support_counter = collections.Counter(int(sample.get("support_mode", -1)) for sample in samples if "support_mode" in sample)
1592
+ confidence_values = [float(sample.get("state_confidence_target", 0.0)) for sample in samples if "state_confidence_target" in sample]
1593
+ payload = {
1594
+ "phase": "split_supervision_summary",
1595
+ "split": split_name,
1596
+ "samples": len(samples),
1597
+ "selected_modes": dict(mode_counter),
1598
+ "support_modes": dict(support_counter),
1599
+ "mean_state_confidence_target": float(np.mean(confidence_values)) if confidence_values else None,
1600
+ }
1601
+ print(json.dumps(payload, sort_keys=True), flush=True)
1602
+
1603
+
1604
def _batch_from_obs(obs_bundle: dict[str, np.ndarray], proprio: np.ndarray, history: Sequence[dict[str, Any]], device: torch.device) -> dict[str, Any]:
    """Assemble a batch-of-one model input dict from the current observation.

    Images are normalized to [0, 1] floats in channel-first layout; all
    history streams are padded/stacked by ``_history_stack``.
    """

    def to_tensor(array: np.ndarray) -> torch.Tensor:
        # Add a leading batch axis and move to the model device.
        return torch.from_numpy(array).unsqueeze(0).float().to(device)

    def stacked(key: str, pad_shape, dtype) -> np.ndarray:
        return _history_stack(history, key, pad_shape=pad_shape, dtype=dtype, history_steps=HISTORY_STEPS)

    # (cam, H, W, C) uint8 -> (1, cam, C, H, W) float in [0, 1].
    images = torch.from_numpy(obs_bundle["images"]).permute(0, 3, 1, 2).unsqueeze(0).float().div(255.0).to(device)
    # History images carry a (time, cam) prefix, hence the 5-d permute.
    history_images = (
        torch.from_numpy(stacked("images", obs_bundle["images"].shape, np.uint8))
        .permute(0, 1, 4, 2, 3)
        .unsqueeze(0)
        .float()
        .div(255.0)
        .to(device)
    )
    return {
        "images": images,
        "depths": to_tensor(obs_bundle["depths"]),
        "depth_valid": to_tensor(obs_bundle["depth_valid"]),
        "camera_intrinsics": to_tensor(obs_bundle["camera_intrinsics"]),
        "camera_extrinsics": to_tensor(obs_bundle["camera_extrinsics"]),
        "history_images": history_images,
        "history_depths": to_tensor(stacked("depths", obs_bundle["depths"].shape, np.float32)),
        "history_depth_valid": to_tensor(stacked("depth_valid", obs_bundle["depth_valid"].shape, np.float32)),
        "history_camera_intrinsics": to_tensor(stacked("camera_intrinsics", obs_bundle["camera_intrinsics"].shape, np.float32)),
        "history_camera_extrinsics": to_tensor(stacked("camera_extrinsics", obs_bundle["camera_extrinsics"].shape, np.float32)),
        "history_proprio": to_tensor(stacked("proprio", (PROPRIO_DIM,), np.float32)),
        "history_actions": to_tensor(stacked("action", (14,), np.float32)),
        "proprio": to_tensor(proprio),
        "texts": [TEXT_PROMPT],
        "task_names": [TASK_NAME],
        "task_ids": torch.as_tensor([TASK_ID], dtype=torch.long, device=device),
    }
1637
+
1638
+
1639
def _eval_mode_name(
    model_output: dict[str, Any],
    checkpoint_mode: str,
) -> tuple[str, bool, bool]:
    """Resolve (mode_name, adapter_active, non_base_selected) for one decision.

    Only the adapter-active checkpoint with proposal outputs can select a
    non-base mode; every other path falls back to classifying the mean
    action chunk.
    """

    def fallback_mode() -> str:
        return _classify_mode_from_chunk(_np(model_output["action_mean"])[0])

    has_proposals = "proposal_mode_names" in model_output and "best_candidate_indices" in model_output
    if checkpoint_mode != "adapter_active_ft" or not has_proposals:
        return fallback_mode(), False, False
    mask = _np(model_output.get("adapter_active_mask", np.asarray([False])))
    active = bool(mask.reshape(-1)[0])
    if not active:
        return fallback_mode(), False, False
    best = int(_np(model_output["best_candidate_indices"])[0])
    names = model_output["proposal_mode_names"][0]
    # Fall back to chunk classification when the index is out of range.
    chosen = str(names[best]) if best < len(names) else fallback_mode()
    return chosen, active, bool(best > 0)
1657
+
1658
+
1659
def _evaluate_checkpoint(
    *,
    checkpoint_path: Path,
    adapter_mode: str,
    result_mode_name: str,
    seeds: Sequence[int],
    report_path: Path,
    train_spec: dict[str, Any] | None,
    dataset_seed: int,
    planner_overrides: dict[str, float] | None = None,
) -> dict[str, Any]:
    """Roll out one checkpoint over *seeds* and write a per-mode JSON report.

    For each episode seed, two mirrored environments are kept in sync: an
    observation env (rgb+depth+segmentation) that feeds the model and a
    state-only sim env in which the selected macro mode is executed; state
    is copied back after execution. Returns the report payload (also
    written to *report_path*); it includes success/intervention rates,
    step-to-event statistics, per-episode records and the eval protocol.
    """
    # `checkpoint` metadata is currently unused here; only the model is needed.
    model, checkpoint = _load_checkpoint(
        checkpoint_path,
        adapter_mode=adapter_mode if adapter_mode != "trunk_only" else None,
        planner_overrides=planner_overrides,
    )
    device = next(model.parameters()).device
    obs_env = PickClutterRevealEnv(
        obs_mode="rgb+depth+segmentation",
        control_mode="pd_ee_delta_pos",
        render_mode="rgb_array",
    )
    # Scratch env used to execute the chosen mode and score its outcome.
    sim_env = PickClutterRevealEnv(
        obs_mode="none",
        control_mode="pd_ee_delta_pos",
        render_mode="rgb_array",
    )
    successes: list[int] = []
    episode_records: list[dict[str, Any]] = []
    reveal_steps: list[int] = []
    retrieve_steps: list[int] = []
    disturbance_values: list[float] = []
    intervention_events = 0
    non_base_events = 0
    total_decisions = 0
    try:
        for episode_seed in seeds:
            obs, _ = obs_env.reset(seed=int(episode_seed))
            sim_env.reset(seed=int(episode_seed))
            _sync_env_state(obs_env, sim_env)
            history: deque[dict[str, Any]] = deque(maxlen=HISTORY_STEPS)
            success = False
            first_reveal_step: int | None = None
            first_retrieve_step: int | None = None
            episode_disturbance: list[float] = []
            for decision_step in range(MAX_MACRO_STEPS):
                obs_bundle = _extract_sensor_bundle(obs)
                proprio = _build_proprio(obs_env)
                batch = _batch_from_obs(obs_bundle, proprio, list(history), device)
                with torch.no_grad():
                    if adapter_mode == "trunk_only":
                        outputs = model(**batch)
                    else:
                        outputs = model(
                            **batch,
                            adapter_mode=adapter_mode,
                            use_transition_model=True,
                            use_task_conditioning=True,
                        )
                selected_mode, active_mask, non_base = _eval_mode_name(outputs, result_mode_name)
                # Snapshot positions before acting so disturbance can be measured.
                start_positions = _all_positions(obs_env)
                _sync_env_state(obs_env, sim_env)
                _execute_mode(sim_env, selected_mode)
                end_metrics = _candidate_metrics(sim_env, start_positions=start_positions, current_obs_bundle=None)
                # Copy the post-action state back into the observation env.
                _sync_env_state(sim_env, obs_env)
                obs = obs_env.get_obs(obs_env.get_info())
                history.append(_init_history_entry(obs_bundle, proprio, CANONICAL_CHUNKS.get(selected_mode, CANONICAL_CHUNKS["base_action"])))
                total_decisions += 1
                intervention_events += int(active_mask)
                non_base_events += int(non_base)
                episode_disturbance.append(end_metrics["disturbance"])
                # Any mode other than retrieve/base_action/maintain_gap counts
                # as a "reveal/access" action for the first-reveal statistic.
                if selected_mode != "retrieve" and selected_mode not in {"base_action", "maintain_gap"} and first_reveal_step is None:
                    first_reveal_step = decision_step + 1
                if selected_mode == "retrieve" and first_retrieve_step is None:
                    first_retrieve_step = decision_step + 1
                if _success_from_state(obs_env):
                    success = True
                    break
            successes.append(int(success))
            if first_reveal_step is not None:
                reveal_steps.append(first_reveal_step)
            if first_retrieve_step is not None:
                retrieve_steps.append(first_retrieve_step)
            disturbance_values.append(float(np.mean(episode_disturbance)) if episode_disturbance else 0.0)
            episode_records.append(
                {
                    "episode_seed": int(episode_seed),
                    "success": success,
                    "steps": len(history),
                    "first_reveal_step": first_reveal_step,
                    "first_retrieve_step": first_retrieve_step,
                    "episode_disturbance": float(np.mean(episode_disturbance)) if episode_disturbance else 0.0,
                }
            )
            # One structured progress line per episode for log scraping.
            print(
                json.dumps(
                    {
                        "phase": "eval_episode_complete",
                        "adapter_mode": result_mode_name,
                        "episode_seed": int(episode_seed),
                        "success": success,
                        "steps": len(history),
                    }
                ),
                flush=True,
            )
    finally:
        obs_env.close()
        sim_env.close()

    eval_protocol = build_public_eval_protocol(
        track_id="occlusion_track",
        eval_mode=result_mode_name,
        seed=int(dataset_seed),
        episodes=len(seeds),
        resolution=224,
        cameras=CAMERA_NAMES,
    )
    payload = {
        "track_id": "occlusion_track",
        "suite": "maniskill3",
        "benchmark_task": "PickClutterYCB-v1",
        "role": "target",
        "adapter_mode": result_mode_name,
        "episodes": len(seeds),
        "successes": successes,
        "success_rate": float(np.mean(successes)) if successes else 0.0,
        # Rates are over all macro decisions; max(1, ...) guards empty runs.
        "intervention_rate": float(intervention_events / max(1, total_decisions)),
        "non_base_selection_rate": float(non_base_events / max(1, total_decisions)),
        # Episodes with no event count as the full horizon.
        "steps_to_first_reveal_or_access": float(np.mean(reveal_steps)) if reveal_steps else float(MAX_MACRO_STEPS),
        "steps_to_retrieve": float(np.mean(retrieve_steps)) if retrieve_steps else float(MAX_MACRO_STEPS),
        "disturbance_proxy": float(np.mean(disturbance_values)) if disturbance_values else 0.0,
        "episode_records": episode_records,
        "eval_protocol": eval_protocol,
    }
    if train_spec is not None:
        payload["train_spec"] = train_spec
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
    return payload
1799
+
1800
+
1801
def _summarize_smoke(results: Sequence[dict[str, Any]], output_dir: Path) -> dict[str, Any]:
    """Aggregate per-mode eval results and write JSON + markdown summaries."""
    summary = summarize_public_benchmark_package(list(results), allow_partial=True)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "public_benchmark_package_summary.json").write_text(
        json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8"
    )
    report_lines: list[str] = [
        "# ManiSkill PickClutter Smoke Summary",
        "",
        f"- available_tracks: {summary['available_tracks']}",
        f"- target_macro_average_delta: {summary['target_macro_average_delta']:.3f}",
        f"- headline_pass: {summary['headline_pass']}",
        f"- sign_of_life_pass: {summary['sign_of_life_pass']}",
        "",
    ]
    for track_id, track_payload in summary["tracks"].items():
        report_lines.extend(
            [
                f"## {track_id}",
                f"- delta_active_vs_trunk: {track_payload.get('delta_active_vs_trunk', 0.0):.3f}",
                f"- delta_noop_vs_trunk: {track_payload.get('delta_noop_vs_trunk', 0.0):.3f}",
                f"- signs_of_life: {track_payload.get('signs_of_life', False)}",
            ]
        )
        for mode, mode_payload in track_payload["modes"].items():
            report_lines.append(f"- {mode}: mean_success={mode_payload['mean_success']:.3f}")
        report_lines.append("")
    (output_dir / "public_benchmark_package_summary.md").write_text(
        "\n".join(report_lines).rstrip() + "\n", encoding="utf-8"
    )
    return summary
1826
+
1827
+
1828
def _default_paths() -> SmokePaths:
    """Return the default data/output/report path bundle for the smoke run."""
    return SmokePaths()
1830
+
1831
+
1832
def _dataset_artifact_path(data_dir: Path, basename: str, *, dataset_seed: int) -> Path:
    """Resolve a dataset artifact path, suffixing the stem for non-default seeds."""
    seed = int(dataset_seed)
    if seed != DEFAULT_SEED:
        name = Path(basename)
        return data_dir / f"{name.stem}_seed{seed}{name.suffix}"
    # The default seed keeps the plain basename for backwards compatibility.
    return data_dir / basename
1837
+
1838
+
1839
def _parse_args() -> argparse.Namespace:
    """Parse CLI flags for the smoke run.

    Covers stage selection, artifact directories, seed handling, and the
    optional planner-threshold overrides (``None`` means "keep the
    checkpoint's default"; only explicitly set values are forwarded).
    """
    # Build the default path bundle once instead of once per argument.
    defaults = _default_paths()
    parser = argparse.ArgumentParser(description="Minimum-sign-of-life ManiSkill PickClutter smoke run.")
    parser.add_argument("--stage", choices=("collect", "train", "eval", "all"), default="all")
    parser.add_argument("--data-dir", type=Path, default=defaults.data_dir)
    parser.add_argument("--output-dir", type=Path, default=defaults.output_dir)
    parser.add_argument("--report-dir", type=Path, default=defaults.report_dir)
    parser.add_argument("--seed", type=int, default=None, help="Deprecated alias for train/eval dataset seed.")
    parser.add_argument("--train-seed", type=int, default=None)
    parser.add_argument("--dataset-seed", type=int, default=None)
    parser.add_argument("--eval-split", choices=("val", "eval"), default="eval")
    # Planner/adapter override flags; names map onto planner config keys in
    # _planner_overrides_from_args.
    parser.add_argument("--adapter-confidence-threshold", type=float, default=None)
    parser.add_argument("--retrieve-access-threshold", type=float, default=None)
    parser.add_argument("--retrieve-persistence-threshold", type=float, default=None)
    parser.add_argument("--retrieve-support-threshold", type=float, default=None)
    parser.add_argument("--retrieve-reocclusion-threshold", type=float, default=None)
    parser.add_argument("--planner-mode-preference-bonus", type=float, default=None)
    parser.add_argument("--planner-premature-retrieve-penalty", type=float, default=None)
    parser.add_argument("--planner-premature-insert-penalty", type=float, default=None)
    parser.add_argument("--planner-premature-occlusion-sweep-penalty", type=float, default=None)
    parser.add_argument("--planner-premature-maintain-penalty", type=float, default=None)
    parser.add_argument("--planner-retrieve-stage-access-threshold", type=float, default=None)
    parser.add_argument("--planner-retrieve-stage-reveal-threshold", type=float, default=None)
    parser.add_argument("--planner-retrieve-stage-persistence-threshold", type=float, default=None)
    parser.add_argument("--planner-retrieve-stage-support-threshold", type=float, default=None)
    parser.add_argument("--planner-insert-stage-access-threshold", type=float, default=None)
    parser.add_argument("--planner-insert-stage-visibility-threshold", type=float, default=None)
    parser.add_argument("--planner-insert-stage-support-threshold", type=float, default=None)
    parser.add_argument("--planner-occlusion-maintain-gap-min-access", type=float, default=None)
    parser.add_argument("--planner-occlusion-maintain-gap-min-visibility", type=float, default=None)
    return parser.parse_args()
1869
+
1870
+
1871
+ def _planner_overrides_from_args(args: argparse.Namespace) -> dict[str, float]:
1872
+ override_pairs = (
1873
+ ("adapter_confidence_threshold", args.adapter_confidence_threshold),
1874
+ ("retrieve_access_threshold", args.retrieve_access_threshold),
1875
+ ("retrieve_persistence_threshold", args.retrieve_persistence_threshold),
1876
+ ("retrieve_support_threshold", args.retrieve_support_threshold),
1877
+ ("retrieve_reocclusion_threshold", args.retrieve_reocclusion_threshold),
1878
+ ("mode_preference_bonus", args.planner_mode_preference_bonus),
1879
+ ("premature_retrieve_penalty", args.planner_premature_retrieve_penalty),
1880
+ ("premature_insert_penalty", args.planner_premature_insert_penalty),
1881
+ ("premature_occlusion_sweep_penalty", args.planner_premature_occlusion_sweep_penalty),
1882
+ ("premature_maintain_penalty", args.planner_premature_maintain_penalty),
1883
+ ("retrieve_stage_access_threshold", args.planner_retrieve_stage_access_threshold),
1884
+ ("retrieve_stage_reveal_threshold", args.planner_retrieve_stage_reveal_threshold),
1885
+ ("retrieve_stage_persistence_threshold", args.planner_retrieve_stage_persistence_threshold),
1886
+ ("retrieve_stage_support_threshold", args.planner_retrieve_stage_support_threshold),
1887
+ ("insert_stage_access_threshold", args.planner_insert_stage_access_threshold),
1888
+ ("insert_stage_visibility_threshold", args.planner_insert_stage_visibility_threshold),
1889
+ ("insert_stage_support_threshold", args.planner_insert_stage_support_threshold),
1890
+ ("occlusion_maintain_gap_min_access", args.planner_occlusion_maintain_gap_min_access),
1891
+ ("occlusion_maintain_gap_min_visibility", args.planner_occlusion_maintain_gap_min_visibility),
1892
+ )
1893
+ return {key: value for key, value in override_pairs if value is not None}
1894
+
1895
+
1896
def main() -> None:
    """Run the smoke pipeline: collect datasets, train variants, evaluate.

    Stage gating via ``--stage``; each stage is skipped when its artifacts
    already exist on disk, so re-runs are incremental.
    """
    args = _parse_args()
    planner_overrides = _planner_overrides_from_args(args)
    # --seed is a deprecated alias that feeds both train and dataset seeds
    # unless the specific flags override it.
    base_seed = DEFAULT_SEED if args.seed is None else int(args.seed)
    train_seed = int(args.train_seed) if args.train_seed is not None else base_seed
    dataset_seed = int(args.dataset_seed) if args.dataset_seed is not None else base_seed
    spec = SmokeSpec(dataset_seed=dataset_seed, train_seed=train_seed)
    splits = _build_episode_splits(spec)
    split_path = _dataset_artifact_path(args.data_dir, "episode_splits.json", dataset_seed=spec.dataset_seed)
    train_path = _dataset_artifact_path(args.data_dir, "train.pt", dataset_seed=spec.dataset_seed)
    val_path = _dataset_artifact_path(args.data_dir, "val.pt", dataset_seed=spec.dataset_seed)

    if args.stage in {"collect", "all"}:
        _save_episode_splits(split_path, splits)
        # Collection is skipped per-split when the .pt artifact already exists.
        if not train_path.exists():
            print(json.dumps({"phase": "collect_train_start", "episodes": len(splits["train"])}), flush=True)
            _collect_split(split_name="train", seeds=splits["train"], spec=spec, output_path=train_path)
        if not val_path.exists():
            print(json.dumps({"phase": "collect_val_start", "episodes": len(splits["val"])}), flush=True)
            _collect_split(split_name="val", seeds=splits["val"], spec=spec, output_path=val_path)

    if args.stage == "collect":
        return

    train_bundle = _load_split(train_path)
    val_bundle = _load_split(val_path)
    train_samples = train_bundle["samples"]
    val_samples = val_bundle["samples"]
    _print_split_supervision_summary("train", train_samples)
    _print_split_supervision_summary("val", val_samples)

    trunk_checkpoint = args.output_dir / f"trunk_only_ft_seed{spec.train_seed}" / "checkpoint_best.pt"
    adapter_checkpoint = args.output_dir / f"adapter_active_ft_seed{spec.train_seed}" / "checkpoint_best.pt"
    trunk_train_spec: dict[str, Any] | None = None
    adapter_train_spec: dict[str, Any] | None = None

    if args.stage in {"train", "all"}:
        # Train each variant only if its best checkpoint is missing;
        # otherwise recover the train_spec from the saved checkpoint.
        if not trunk_checkpoint.exists():
            print(json.dumps({"phase": "train_variant_start", "variant": "trunk_only_ft"}), flush=True)
            trunk_checkpoint, trunk_train_spec = _train_variant(
                variant="trunk_only_ft",
                train_samples=train_samples,
                val_samples=val_samples,
                spec=spec,
                output_dir=args.output_dir / f"trunk_only_ft_seed{spec.train_seed}",
            )
        else:
            trunk_payload = torch.load(trunk_checkpoint, map_location="cpu", weights_only=False)
            trunk_train_spec = trunk_payload.get("train_spec")
        if not adapter_checkpoint.exists():
            print(json.dumps({"phase": "train_variant_start", "variant": "adapter_active_ft"}), flush=True)
            adapter_checkpoint, adapter_train_spec = _train_variant(
                variant="adapter_active_ft",
                train_samples=train_samples,
                val_samples=val_samples,
                spec=spec,
                output_dir=args.output_dir / f"adapter_active_ft_seed{spec.train_seed}",
            )
        else:
            adapter_payload = torch.load(adapter_checkpoint, map_location="cpu", weights_only=False)
            adapter_train_spec = adapter_payload.get("train_spec")

    if args.stage == "train":
        return

    # When jumping straight to eval, lazily load the train specs from disk.
    if trunk_train_spec is None and trunk_checkpoint.exists():
        trunk_payload = torch.load(trunk_checkpoint, map_location="cpu", weights_only=False)
        trunk_train_spec = trunk_payload.get("train_spec")
    if adapter_train_spec is None and adapter_checkpoint.exists():
        adapter_payload = torch.load(adapter_checkpoint, map_location="cpu", weights_only=False)
        adapter_train_spec = adapter_payload.get("train_spec")

    eval_seeds = splits[args.eval_split]
    print(json.dumps({"phase": "eval_start", "episodes": len(eval_seeds)}), flush=True)
    # Three arms: trunk-only baseline, adapter forced no-op, adapter active.
    trunk_result = _evaluate_checkpoint(
        checkpoint_path=trunk_checkpoint,
        adapter_mode="trunk_only",
        result_mode_name="trunk_only_ft",
        seeds=eval_seeds,
        report_path=args.report_dir / f"trunk_only_ft_seed{spec.train_seed}.json",
        train_spec=trunk_train_spec,
        dataset_seed=spec.dataset_seed,
        planner_overrides=planner_overrides,
    )
    noop_result = _evaluate_checkpoint(
        checkpoint_path=adapter_checkpoint,
        adapter_mode="adapter_noop",
        result_mode_name="adapter_noop",
        seeds=eval_seeds,
        report_path=args.report_dir / f"adapter_noop_seed{spec.train_seed}.json",
        train_spec=adapter_train_spec,
        dataset_seed=spec.dataset_seed,
        planner_overrides=planner_overrides,
    )
    active_result = _evaluate_checkpoint(
        checkpoint_path=adapter_checkpoint,
        adapter_mode="adapter_active",
        result_mode_name="adapter_active_ft",
        seeds=eval_seeds,
        report_path=args.report_dir / f"adapter_active_ft_seed{spec.train_seed}.json",
        train_spec=adapter_train_spec,
        dataset_seed=spec.dataset_seed,
        planner_overrides=planner_overrides,
    )
    summary = _summarize_smoke([trunk_result, noop_result, active_result], args.report_dir)
    print(json.dumps({"phase": "complete", "summary": summary}, indent=2), flush=True)
2002
+
2003
+
2004
+ if __name__ == "__main__":
2005
+ main()
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_public_benchmark_package.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+
10
+ from eval.public_benchmark_package import (
11
+ ANCHOR_ROLE,
12
+ DEFAULT_ANCHOR_TOLERANCE,
13
+ DEFAULT_SIGN_OF_LIFE_GAIN,
14
+ DEFAULT_SIGN_OF_LIFE_INTERVENTION,
15
+ DEFAULT_SIGN_OF_LIFE_NON_BASE,
16
+ TARGET_ROLE,
17
+ build_public_eval_protocol,
18
+ build_target_training_spec,
19
+ default_public_benchmark_manifest,
20
+ expected_eval_modes,
21
+ public_benchmark_tracks,
22
+ public_protocol_identity_signature,
23
+ public_track_by_id,
24
+ training_fairness_signature,
25
+ write_default_public_benchmark_manifest,
26
+ )
27
+
28
+
29
+ def _load_json(path: str | Path) -> dict[str, Any]:
30
+ with Path(path).open("r", encoding="utf-8") as handle:
31
+ payload = json.load(handle)
32
+ if not isinstance(payload, dict):
33
+ raise TypeError(f"Expected a JSON object in {path!s}, got {type(payload)!r}.")
34
+ return payload
35
+
36
+
37
+ def _normalize_success_samples(payload: dict[str, Any]) -> np.ndarray:
38
+ if "successes" in payload:
39
+ raw = np.asarray(payload["successes"], dtype=np.float32).reshape(-1)
40
+ return raw
41
+ if "success_rate" in payload:
42
+ return np.asarray([float(payload["success_rate"])], dtype=np.float32)
43
+ raise KeyError("Each result payload must include either `successes` or `success_rate`.")
44
+
45
+
46
+ def _mean_optional(records: list[dict[str, Any]], key: str) -> float | None:
47
+ values = [float(record[key]) for record in records if key in record]
48
+ if not values:
49
+ return None
50
+ return float(np.mean(values))
51
+
52
+
53
+ def _bootstrap_delta_ci(
54
+ lhs: np.ndarray,
55
+ rhs: np.ndarray,
56
+ *,
57
+ bootstrap_samples: int,
58
+ bootstrap_seed: int,
59
+ ) -> tuple[float, float]:
60
+ if lhs.size == 0 or rhs.size == 0:
61
+ return 0.0, 0.0
62
+ rng = np.random.default_rng(int(bootstrap_seed))
63
+ deltas = np.empty(int(bootstrap_samples), dtype=np.float32)
64
+ for index in range(int(bootstrap_samples)):
65
+ lhs_sample = lhs[rng.integers(0, lhs.shape[0], size=lhs.shape[0])]
66
+ rhs_sample = rhs[rng.integers(0, rhs.shape[0], size=rhs.shape[0])]
67
+ deltas[index] = float(lhs_sample.mean() - rhs_sample.mean())
68
+ low, high = np.percentile(deltas, [2.5, 97.5])
69
+ return float(low), float(high)
70
+
71
+
72
def _normalize_record(payload: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *payload* with canonical track metadata and success fields.

    Looks the track up by id, fills missing suite/task/role fields from the
    track definition, and normalizes `successes`, `success_rate`, `episodes`.

    Raises:
        KeyError: when `track_id` or `adapter_mode` is missing.
    """
    if "track_id" not in payload:
        raise KeyError("Missing required field `track_id`.")
    if "adapter_mode" not in payload:
        raise KeyError("Missing required field `adapter_mode`.")
    track = public_track_by_id(str(payload["track_id"]))
    samples = _normalize_success_samples(payload)
    normalized = dict(payload)
    normalized["track_id"] = track.track_id
    # Track metadata only fills gaps; explicit payload values win.
    normalized["suite"] = payload.get("suite", track.suite)
    normalized["benchmark_task"] = payload.get("benchmark_task", track.benchmark_task)
    normalized["role"] = payload.get("role", track.role)
    normalized["adapter_mode"] = str(payload["adapter_mode"])
    normalized["successes"] = samples.tolist()
    normalized["success_rate"] = float(payload.get("success_rate", float(samples.mean())))
    normalized["episodes"] = int(payload.get("episodes", samples.shape[0]))
    return normalized
91
+
92
+
93
def _validate_protocols(records: list[dict[str, Any]]) -> None:
    """Ensure every record on a given track shares one eval-protocol identity.

    Raises:
        ValueError: when a record lacks `eval_protocol`, or when two records
            on the same track carry different protocol identity signatures.
    """
    per_track: dict[str, list[dict[str, Any]]] = {}
    for entry in records:
        per_track.setdefault(entry["track_id"], []).append(entry)
    for track_id, entries in per_track.items():
        # Collect all signatures first so a missing protocol anywhere in the
        # group is reported before any mismatch.
        signatures: list[Any] = []
        for entry in entries:
            protocol = entry.get("eval_protocol")
            if protocol is None:
                raise ValueError(
                    f"Missing eval_protocol for track {track_id!r}, mode {entry['adapter_mode']!r}."
                )
            signatures.append(public_protocol_identity_signature(protocol))
        head, *rest = signatures
        if any(signature != head for signature in rest):
            raise ValueError(f"Protocol identity mismatch detected for track {track_id!r}.")
108
+
109
+
110
def _validate_training_fairness(records: list[dict[str, Any]]) -> None:
    """Check that paired fine-tuning modes on each target track are comparable.

    For every target-role track that has both `trunk_only_ft` and
    `adapter_active_ft` results, require: the same number of runs, a
    `train_spec` on every run, identical training-seed sets, and a matching
    data/init fairness signature per seed. Tracks missing either mode are
    skipped (partial submissions are allowed here).

    Raises:
        ValueError: on any fairness violation described above.
    """
    # Bucket results by (track, adapter mode) so the two modes can be paired.
    grouped: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for record in records:
        grouped.setdefault((record["track_id"], record["adapter_mode"]), []).append(record)
    for track in public_benchmark_tracks(TARGET_ROLE):
        trunk_records = grouped.get((track.track_id, "trunk_only_ft"), [])
        active_records = grouped.get((track.track_id, "adapter_active_ft"), [])
        if not trunk_records or not active_records:
            # Only complete trunk/active pairs are validated.
            continue
        if len(trunk_records) != len(active_records):
            raise ValueError(
                f"Training fairness mismatch for {track.track_id!r}: different run counts "
                f"between trunk_only_ft ({len(trunk_records)}) and adapter_active_ft ({len(active_records)})."
            )
        if any(record.get("train_spec") is None for record in trunk_records + active_records):
            raise ValueError(
                f"Training fairness mismatch for {track.track_id!r}: missing train_spec on a target-track result."
            )
        # Index each mode's fairness signature by training seed for pairing.
        trunk_by_seed = {
            int(record["train_spec"]["seed"]): training_fairness_signature(record["train_spec"])
            for record in trunk_records
        }
        active_by_seed = {
            int(record["train_spec"]["seed"]): training_fairness_signature(record["train_spec"])
            for record in active_records
        }
        if set(trunk_by_seed) != set(active_by_seed):
            raise ValueError(f"Training fairness mismatch for {track.track_id!r}: seed sets differ.")
        for seed, trunk_signature in trunk_by_seed.items():
            if trunk_signature != active_by_seed[seed]:
                raise ValueError(
                    f"Training fairness mismatch for {track.track_id!r} at seed {seed}: "
                    "trunk_only_ft and adapter_active_ft do not share the same data/init signature."
                )
144
+
145
+
146
def _aggregate_mode(records: list[dict[str, Any]]) -> dict[str, Any]:
    """Pool per-run results for one (track, adapter-mode) cell.

    Returns run count, mean success rate, pooled per-episode samples, and
    the mean of each optional metric that any record reports.
    """
    rates = np.asarray([float(entry["success_rate"]) for entry in records], dtype=np.float32)
    pooled = np.concatenate(
        [np.asarray(entry["successes"], dtype=np.float32).reshape(-1) for entry in records],
        axis=0,
    )
    summary: dict[str, Any] = {
        "num_runs": len(records),
        "mean_success": float(rates.mean()) if rates.size else 0.0,
        "success_samples": pooled.tolist(),
    }
    optional_metrics = (
        "intervention_rate",
        "non_base_selection_rate",
        "steps_to_first_reveal_or_access",
        "steps_to_retrieve",
        "disturbance_proxy",
    )
    for metric in optional_metrics:
        value = _mean_optional(records, metric)
        if value is not None:
            summary[metric] = value
    return summary
168
+
169
+
170
def summarize_public_benchmark_package(
    result_payloads: list[dict[str, Any]],
    *,
    bootstrap_samples: int = 2000,
    bootstrap_seed: int = 0,
    allow_partial: bool = False,
) -> dict[str, Any]:
    """Validate raw result payloads and build the package-level summary.

    Normalizes each payload, enforces protocol identity and training
    fairness, aggregates runs per (track, mode), then computes per-track
    deltas and the package-level pass/fail flags.

    Args:
        result_payloads: raw JSON result objects, one per run.
        bootstrap_samples: replicates for the active-vs-trunk delta CI.
        bootstrap_seed: base seed; offset per target track for independence.
        allow_partial: when True, tracks with missing modes are skipped
            instead of raising.

    Raises:
        ValueError: missing modes (unless allow_partial) or validation
            failures from the protocol/fairness checks.
    """
    records = [_normalize_record(payload) for payload in result_payloads]
    _validate_protocols(records)
    _validate_training_fairness(records)

    # Bucket runs by (track, adapter mode) for per-cell aggregation.
    grouped: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for record in records:
        grouped.setdefault((record["track_id"], record["adapter_mode"]), []).append(record)

    track_summaries: dict[str, Any] = {}
    target_deltas: list[float] = []
    anchor_pass = True
    sign_of_life_tracks: list[str] = []
    ci_above_zero_tracks: list[str] = []
    available_tracks: list[str] = []

    for track in public_benchmark_tracks():
        track_modes = expected_eval_modes(track.track_id)
        mode_payloads: dict[str, Any] = {}
        missing_modes: list[str] = []
        for mode in track_modes:
            mode_records = grouped.get((track.track_id, mode), [])
            if not mode_records:
                missing_modes.append(mode)
                continue
            mode_payloads[mode] = _aggregate_mode(mode_records)
        if missing_modes:
            if allow_partial:
                continue
            raise ValueError(f"Missing results for track {track.track_id!r}, mode(s) {missing_modes!r}.")
        available_tracks.append(track.track_id)

        track_summary: dict[str, Any] = {
            "suite": track.suite,
            "benchmark_task": track.benchmark_task,
            "role": track.role,
            "task_family": track.task_family,
            "target_behavior": track.target_behavior,
            "public_source": track.public_source,
            "notes": track.notes,
            # Drop the raw sample lists from the per-mode payloads; they are
            # only needed below for the bootstrap CI.
            "modes": {
                mode: {
                    key: value
                    for key, value in payload.items()
                    if key != "success_samples"
                }
                for mode, payload in mode_payloads.items()
            },
        }

        if track.role == TARGET_ROLE:
            # Target tracks compare fine-tuned modes against a trunk-only
            # fine-tuned baseline.
            trunk = mode_payloads["trunk_only_ft"]
            active = mode_payloads["adapter_active_ft"]
            noop = mode_payloads["adapter_noop"]
            delta_active = float(active["mean_success"] - trunk["mean_success"])
            delta_noop = float(noop["mean_success"] - trunk["mean_success"])
            target_deltas.append(delta_active)
            # Offset the seed per target track so CIs use independent draws.
            ci_low, ci_high = _bootstrap_delta_ci(
                np.asarray(active["success_samples"], dtype=np.float32),
                np.asarray(trunk["success_samples"], dtype=np.float32),
                bootstrap_samples=bootstrap_samples,
                bootstrap_seed=bootstrap_seed + len(target_deltas),
            )
            # "Sign of life": the adapter intervenes, selects non-base modes,
            # and yields at least the minimum gain. Missing metrics count as 0.
            sign_of_life = bool(
                float(active.get("intervention_rate", 0.0)) >= DEFAULT_SIGN_OF_LIFE_INTERVENTION
                and float(active.get("non_base_selection_rate", 0.0)) >= DEFAULT_SIGN_OF_LIFE_NON_BASE
                and delta_active >= DEFAULT_SIGN_OF_LIFE_GAIN
            )
            if sign_of_life:
                sign_of_life_tracks.append(track.track_id)
            if ci_low > 0.0:
                ci_above_zero_tracks.append(track.track_id)
            track_summary.update(
                {
                    "delta_active_vs_trunk": delta_active,
                    "delta_noop_vs_trunk": delta_noop,
                    "delta_active_vs_trunk_ci95": [ci_low, ci_high],
                    "signs_of_life": sign_of_life,
                }
            )
        else:
            # Anchor tracks: the adapter must leave performance unchanged
            # within tolerance (no-regression check).
            trunk = mode_payloads["trunk_only"]
            active = mode_payloads["adapter_active"]
            noop = mode_payloads["adapter_noop"]
            active_delta = float(active["mean_success"] - trunk["mean_success"])
            noop_delta = float(noop["mean_success"] - trunk["mean_success"])
            within_tolerance = bool(
                abs(active_delta) <= DEFAULT_ANCHOR_TOLERANCE
                and abs(noop_delta) <= DEFAULT_ANCHOR_TOLERANCE
            )
            anchor_pass = anchor_pass and within_tolerance
            track_summary.update(
                {
                    "delta_active_vs_trunk": active_delta,
                    "delta_noop_vs_trunk": noop_delta,
                    "anchor_within_tolerance": within_tolerance,
                }
            )

        track_summaries[track.track_id] = track_summary

    # Headline: every target delta positive AND at least one CI strictly
    # above zero.
    headline_pass = bool(
        target_deltas
        and all(delta > 0.0 for delta in target_deltas)
        and len(ci_above_zero_tracks) >= 1
    )
    sign_of_life_pass = len(sign_of_life_tracks) >= 2

    return {
        "package_name": default_public_benchmark_manifest()["package_name"],
        "tracks": track_summaries,
        "available_tracks": available_tracks,
        "target_macro_average_delta": float(np.mean(target_deltas)) if target_deltas else 0.0,
        "headline_pass": headline_pass,
        "sign_of_life_pass": sign_of_life_pass,
        "sign_of_life_track_count": len(sign_of_life_tracks),
        "sign_of_life_tracks": sign_of_life_tracks,
        "ci_above_zero_tracks": ci_above_zero_tracks,
        "anchor_pass": anchor_pass,
    }
296
+
297
+
298
+ def _write_markdown(output_path: Path, summary: dict[str, Any]) -> None:
299
+ lines = [
300
+ "# Public Benchmark Package Summary",
301
+ "",
302
+ f"- package_name: {summary['package_name']}",
303
+ f"- headline_pass: {summary['headline_pass']}",
304
+ f"- sign_of_life_pass: {summary['sign_of_life_pass']}",
305
+ f"- sign_of_life_track_count: {summary['sign_of_life_track_count']}",
306
+ f"- anchor_pass: {summary['anchor_pass']}",
307
+ f"- target_macro_average_delta: {summary['target_macro_average_delta']:.3f}",
308
+ "",
309
+ ]
310
+ for track_id, payload in summary["tracks"].items():
311
+ lines.append(f"## {track_id}")
312
+ lines.append(f"- suite: {payload['suite']}")
313
+ lines.append(f"- benchmark_task: {payload['benchmark_task']}")
314
+ lines.append(f"- role: {payload['role']}")
315
+ for mode, mode_payload in payload["modes"].items():
316
+ lines.append(f"- {mode}: mean_success={mode_payload['mean_success']:.3f}, num_runs={mode_payload['num_runs']}")
317
+ if "delta_active_vs_trunk" in payload:
318
+ lines.append(f"- delta_active_vs_trunk: {payload['delta_active_vs_trunk']:.3f}")
319
+ if "delta_noop_vs_trunk" in payload:
320
+ lines.append(f"- delta_noop_vs_trunk: {payload['delta_noop_vs_trunk']:.3f}")
321
+ if "delta_active_vs_trunk_ci95" in payload:
322
+ low, high = payload["delta_active_vs_trunk_ci95"]
323
+ lines.append(f"- delta_active_vs_trunk_ci95: [{low:.3f}, {high:.3f}]")
324
+ if "signs_of_life" in payload:
325
+ lines.append(f"- signs_of_life: {payload['signs_of_life']}")
326
+ if "anchor_within_tolerance" in payload:
327
+ lines.append(f"- anchor_within_tolerance: {payload['anchor_within_tolerance']}")
328
+ lines.append("")
329
+ output_path.parent.mkdir(parents=True, exist_ok=True)
330
+ output_path.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")
331
+
332
+
333
def _parse_args() -> argparse.Namespace:
    """Define and parse the CLI for the package summarizer."""
    parser = argparse.ArgumentParser(
        description="Validate and summarize the public benchmark package results."
    )
    parser.add_argument(
        "--result",
        action="append",
        default=[],
        help="Path to a normalized benchmark result JSON.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path.home() / "workspace" / "reports" / "public_benchmark_package_v1",
    )
    parser.add_argument("--bootstrap-samples", type=int, default=2000)
    parser.add_argument("--bootstrap-seed", type=int, default=0)
    parser.add_argument("--write-default-manifest", type=Path, default=None)
    return parser.parse_args()
341
+
342
+
343
def main() -> None:
    """CLI entry point: optionally write the default manifest, then summarize results.

    Invocations:
      * --write-default-manifest alone: write the manifest and exit.
      * --result files (optionally plus the manifest flag): validate,
        summarize, and write JSON + markdown reports to --output-dir.

    Raises:
        SystemExit: when neither --result files nor --write-default-manifest
            were provided.
    """
    args = _parse_args()
    if args.write_default_manifest is not None:
        path = write_default_public_benchmark_manifest(args.write_default_manifest)
        print(json.dumps({"wrote_manifest": str(path)}, indent=2))
    # The original duplicated `if not args.result:` checks collapsed into one
    # explicit branch: no results is fine only when the manifest was written.
    if not args.result:
        if args.write_default_manifest is not None:
            return
        raise SystemExit("No results provided. Pass one or more --result files or use --write-default-manifest.")

    payloads = [_load_json(path) for path in args.result]
    summary = summarize_public_benchmark_package(
        payloads,
        bootstrap_samples=args.bootstrap_samples,
        bootstrap_seed=args.bootstrap_seed,
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    json_path = args.output_dir / "public_benchmark_package_summary.json"
    md_path = args.output_dir / "public_benchmark_package_summary.md"
    json_path.write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    _write_markdown(md_path, summary)
    print(json.dumps({"summary_json": str(json_path), "summary_md": str(md_path)}, indent=2))
366
+
367
+
368
+ if __name__ == "__main__":
369
+ main()
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.26 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.5 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc ADDED
Binary file (26.3 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc ADDED
Binary file (58.2 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc ADDED
Binary file (19 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc ADDED
Binary file (37.3 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc ADDED
Binary file (4.25 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc ADDED
Binary file (8.01 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-311.pyc ADDED
Binary file (27.7 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc ADDED
Binary file (25.3 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc ADDED
Binary file (55.6 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc ADDED
Binary file (32.1 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc ADDED
Binary file (58.9 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc ADDED
Binary file (19.4 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc ADDED
Binary file (46.7 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/rvt_backbone.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/rvt_backbone.cpython-311.pyc ADDED
Binary file (29.6 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc ADDED
Binary file (20 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc ADDED
Binary file (49.6 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/planner.py ADDED
@@ -0,0 +1,887 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+
9
@dataclass
class PlannerConfig:
    """Hyperparameters shared by the planner modules in this file.

    Groups: network sizes, utility-term weights, stage/threshold gates, and
    mode-preference penalties. Consumers read only the fields they need.
    """

    # --- network / candidate geometry ---
    hidden_dim: int = 512
    num_candidates: int = 8
    action_dim: int = 14
    num_support_modes: int = 3
    utility_margin: float = 0.1
    # --- utility-term weights (RevealPlanner / InteractionPlanner) ---
    corridor_weight: float = 1.0
    persistence_weight: float = 0.5
    proposal_weight: float = 0.5
    task_progress_weight: float = 0.75
    disturbance_weight: float = 0.75
    reocclusion_weight: float = 0.5
    visibility_weight: float = 0.25
    # --- transformer scorer (InteractionPlanner) ---
    num_heads: int = 4
    num_layers: int = 2
    num_phases: int = 5
    num_arm_roles: int = 4
    top_k: int = 4
    # --- structured-utility weights (StructuredElasticUtility) ---
    belief_gain_weight: float = 1.0
    visibility_gain_weight: float = 0.75
    clearance_weight: float = 0.75
    occluder_contact_weight: float = 0.5
    grasp_affordance_weight: float = 0.75
    support_stability_weight: float = 0.5
    residual_weight: float = 0.5
    # --- retrieve gating thresholds ---
    retrieve_access_threshold: float = 0.15
    retrieve_persistence_threshold: float = 0.15
    retrieve_support_threshold: float = 0.25
    retrieve_reocclusion_threshold: float = 0.6
    adapter_confidence_threshold: float = 0.55
    # --- mode-preference shaping ---
    mode_preference_bonus: float = 3.0
    premature_retrieve_penalty: float = 1.5
    premature_insert_penalty: float = 0.75
    premature_occlusion_sweep_penalty: float = 0.75
    premature_maintain_penalty: float = 0.0
    # --- stage-entry thresholds per behavior ---
    retrieve_stage_access_threshold: float = 0.45
    retrieve_stage_reveal_threshold: float = 0.40
    retrieve_stage_persistence_threshold: float = 0.20
    retrieve_stage_support_threshold: float = 0.25
    insert_stage_access_threshold: float = 0.40
    insert_stage_visibility_threshold: float = 0.30
    insert_stage_support_threshold: float = 0.25
    occlusion_maintain_gap_min_access: float = 0.0
    occlusion_maintain_gap_min_visibility: float = 0.0
54
+
55
+
56
class RevealPlanner(nn.Module):
    """Scores candidate action chunks against imagined rollout statistics.

    A small MLP trunk maps a fixed-size per-candidate summary to success
    and risk heads; utility = sigmoid(success_logit) - risk.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config
        # Summary layout: chunk mean + terminal step (2 * action_dim), then
        # corridor (3), persistence (3), disturbance (1), reocclusion (3),
        # uncertainty (1) rollout statistics.
        summary_dim = 2 * config.action_dim + 3 + 3 + 1 + 3 + 1
        self.trunk = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.success_head = nn.Linear(config.hidden_dim, 1)
        self.risk_head = nn.Linear(config.hidden_dim, 1)

    def summarize_candidates(self, candidate_chunks: Tensor, rollout_state: dict[str, Tensor]) -> Tensor:
        """Build the fixed-size per-candidate feature vector."""
        chunk_mean = candidate_chunks.mean(dim=2)
        chunk_last = candidate_chunks[:, :, -1]
        pieces = [
            chunk_mean,
            chunk_last,
            rollout_state["corridor_logits"].sigmoid().amax(dim=-1).mean(dim=-2),
            rollout_state["persistence_horizon"].mean(dim=-2),
            rollout_state["disturbance_cost"].mean(dim=-1, keepdim=True),
            rollout_state["reocclusion_logit"].sigmoid().mean(dim=-2),
            rollout_state["uncertainty"].mean(dim=-1, keepdim=True),
        ]
        return torch.cat(pieces, dim=-1)

    def score_rollouts(self, rollout_state: dict[str, Tensor], candidate_chunks: Tensor) -> dict[str, Tensor]:
        """Return per-candidate success logits, risk values, and utilities."""
        summary = self.summarize_candidates(candidate_chunks, rollout_state)
        encoded = self.trunk(summary)
        success = self.success_head(encoded).squeeze(-1)
        risk = torch.sigmoid(self.risk_head(encoded)).squeeze(-1)
        return {
            "planner_features": summary,
            "planner_hidden": encoded,
            "success_logits": success,
            "risk_values": risk,
            "utility_scores": success.sigmoid() - risk,
        }

    def select_best(self, candidate_chunks: Tensor, rollout_state: dict[str, Tensor]) -> dict[str, Tensor]:
        """Pick the highest-utility candidate chunk per batch element."""
        scored = self.score_rollouts(rollout_state=rollout_state, candidate_chunks=candidate_chunks)
        winners = scored["utility_scores"].argmax(dim=-1)
        rows = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
        return {
            **scored,
            "best_indices": winners,
            "best_chunk": candidate_chunks[rows, winners],
        }
122
+
123
+
124
class InteractionPlanner(nn.Module):
    """Transformer-based scorer over per-step trajectory summaries.

    Each candidate's steps are projected, prefixed with a learned CLS token,
    encoded, and the CLS embedding feeds success/risk/score heads.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config
        # Per-step feature layout: action + phase probs + flattened arm-role
        # probs (2 * num_arm_roles) + support-mode probs + 7 scalar stats.
        step_dim = (
            config.action_dim
            + config.num_phases
            + (2 * config.num_arm_roles)
            + config.num_support_modes
            + 7
        )
        self.step_proj = nn.Sequential(
            nn.LayerNorm(step_dim),
            nn.Linear(step_dim, config.hidden_dim),
            nn.GELU(),
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim * 4,
            batch_first=True,
            norm_first=True,  # pre-norm for training stability
        )
        self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
        # Learned CLS token prepended to every candidate sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_dim) * 0.02)
        self.success_head = nn.Linear(config.hidden_dim, 1)
        self.risk_head = nn.Linear(config.hidden_dim, 1)
        self.score_head = nn.Linear(config.hidden_dim, 1)

    def _mean_field(self, tensor: Tensor) -> Tensor:
        # Averages the last two dims; assumes spatial fields carry per-step
        # maps in their trailing dims — TODO confirm expected field rank.
        return tensor.mean(dim=(-1, -2))

    def summarize_trajectory(self, candidate_chunks: Tensor, rollout_state: dict[str, Tensor]) -> Tensor:
        """Assemble the per-step feature tensor over the shared horizon."""
        # Clip everything to the shortest horizon between actions and rollout.
        horizon = min(candidate_chunks.shape[2], rollout_state["phase_logits"].shape[2])
        candidate_steps = candidate_chunks[:, :, :horizon]
        phase_probs = rollout_state["phase_logits"][:, :, :horizon].softmax(dim=-1)
        support_probs = rollout_state["support_mode_logits"][:, :, :horizon].softmax(dim=-1)
        arm_role_probs = rollout_state["arm_role_logits"][:, :, :horizon].softmax(dim=-1).flatten(start_dim=-2)
        target_mean = self._mean_field(rollout_state["target_field"][:, :, :horizon].sigmoid())
        feasibility_mean = self._mean_field(rollout_state["actor_feasibility_field"][:, :, :horizon].sigmoid())
        persistence_mean = self._mean_field(rollout_state["persistence_field"][:, :, :horizon])
        risk_mean = self._mean_field(rollout_state["risk_field"][:, :, :horizon])
        uncertainty_mean = self._mean_field(rollout_state["uncertainty_field"][:, :, :horizon])
        # NOTE(review): role_gap applies softmax after indexing arm 0/1, so
        # normalization runs over the remaining trailing dim — confirm this
        # matches the intended per-step role-probability gap.
        role_gap = (
            rollout_state["arm_role_logits"][:, :, :horizon, 0].softmax(dim=-1)
            - rollout_state["arm_role_logits"][:, :, :horizon, 1].softmax(dim=-1)
        ).abs().mean(dim=-1, keepdim=True)
        return torch.cat(
            [
                candidate_steps,
                phase_probs,
                arm_role_probs,
                support_probs,
                target_mean,
                feasibility_mean,
                persistence_mean,
                risk_mean,
                uncertainty_mean,
                role_gap,
            ],
            dim=-1,
        )

    def score_rollouts(
        self,
        rollout_state: dict[str, Tensor],
        candidate_chunks: Tensor,
        proposal_logits: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Encode every candidate sequence and score it.

        Returns per-candidate success logits, risk values, and composite
        utility scores (optionally biased by proposal logits).
        """
        features = self.summarize_trajectory(candidate_chunks, rollout_state)
        batch_size, num_candidates, horizon, _ = features.shape
        # Fold candidates into the batch so the encoder sees one sequence each.
        flat_features = features.view(batch_size * num_candidates, horizon, -1)
        hidden_steps = self.step_proj(flat_features)
        cls = self.cls_token.expand(batch_size * num_candidates, -1, -1)
        encoded = self.sequence_encoder(torch.cat([cls, hidden_steps], dim=1))
        pooled = encoded[:, 0]  # CLS embedding per candidate
        success_logits = self.success_head(pooled).view(batch_size, num_candidates).squeeze(-1)
        risk_values = torch.sigmoid(self.risk_head(pooled)).view(batch_size, num_candidates).squeeze(-1)
        utility_scores = self.score_head(pooled).view(batch_size, num_candidates).squeeze(-1)
        # Composite utility: learned score plus success probability minus risk.
        utility_scores = utility_scores + success_logits.sigmoid() - risk_values
        if proposal_logits is not None and proposal_logits.shape == utility_scores.shape:
            utility_scores = utility_scores + self.config.proposal_weight * proposal_logits.sigmoid()
        return {
            "planner_features": features.mean(dim=2),
            "planner_hidden": pooled.view(batch_size, num_candidates, -1),
            "success_logits": success_logits,
            "risk_values": risk_values,
            "utility_scores": utility_scores,
        }

    def select_best(
        self,
        candidate_chunks: Tensor,
        rollout_state: dict[str, Tensor],
        proposal_logits: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Score all candidates and return the argmax-utility chunk per batch."""
        outputs = self.score_rollouts(
            rollout_state=rollout_state,
            candidate_chunks=candidate_chunks,
            proposal_logits=proposal_logits,
        )
        best_idx = outputs["utility_scores"].argmax(dim=-1)
        batch_indices = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
        return {
            **outputs,
            "best_indices": best_idx,
            "best_chunk": candidate_chunks[batch_indices, best_idx],
        }
232
+
233
+
234
class StructuredElasticUtility(nn.Module):
    """Hand-weighted utility from interpretable rollout field statistics.

    Parameter-free: combines mean/min/max reductions of named rollout fields
    with the weights from PlannerConfig into a per-candidate scalar.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config

    def _field_mean(self, tensor: Tensor) -> Tensor:
        # Reduce spatial dims by rank; 5-D and 4-D share the same reduction.
        # NOTE(review): assumes 6-D fields are (B, C, T, D, H, W)-like —
        # confirm against the world-model output shapes.
        if tensor.ndim == 6:
            return tensor.mean(dim=(-1, -2, -3))
        if tensor.ndim == 5:
            return tensor.mean(dim=(-1, -2))
        if tensor.ndim == 4:
            return tensor.mean(dim=(-1, -2))
        return tensor

    def _initial_scalar(self, state: dict[str, Tensor], key: str) -> Tensor:
        """Collapse an initial-state field to one scalar per batch element."""
        value = state[key]
        if value.ndim >= 4:
            # Average everything except the batch dim.
            return value.mean(dim=tuple(range(1, value.ndim)))
        if value.ndim == 3:
            return value.mean(dim=(-1, -2))
        if value.ndim == 2:
            return value.mean(dim=-1)
        return value

    def forward(
        self,
        initial_state: dict[str, Tensor],
        rollout_state: dict[str, Tensor],
        candidate_chunks: Tensor,
    ) -> dict[str, Tensor]:
        """Compute per-candidate structured utility and its components.

        Gains are measured relative to the initial state; floor/worst values
        are exposed for downstream threshold gating.
        """
        # Baselines from the current state (broadcast across candidates).
        initial_belief = self._initial_scalar(initial_state, "target_belief_field").unsqueeze(1)
        initial_visibility = self._initial_scalar(initial_state, "visibility_field").unsqueeze(1)
        belief_future = self._field_mean(rollout_state["target_belief_field"]).mean(dim=-1)
        visibility_future = self._field_mean(rollout_state["visibility_field"]).mean(dim=-1)
        clearance = self._field_mean(rollout_state["clearance_field"]).mean(dim=-1)
        occluder_contact = self._field_mean(rollout_state["occluder_contact_field"]).mean(dim=-1)
        grasp_affordance = self._field_mean(rollout_state["grasp_affordance_field"]).mean(dim=-1)
        support_stability = torch.sigmoid(self._field_mean(rollout_state["support_stability_field"])).mean(dim=-1)
        persistence_traj = self._field_mean(rollout_state["persistence_field"])
        reocclusion_traj = self._field_mean(rollout_state["reocclusion_field"])
        disturbance_traj = self._field_mean(rollout_state["disturbance_field"])
        access_traj = torch.sigmoid(self._field_mean(rollout_state["access_field"]))
        persistence = persistence_traj.mean(dim=-1)
        reocclusion = reocclusion_traj.mean(dim=-1)
        disturbance = disturbance_traj.mean(dim=-1)
        access_quality = access_traj.mean(dim=-1)
        # Worst-case values over the trajectory, used as gating floors.
        access_floor = access_traj.amin(dim=-1)
        persistence_floor = persistence_traj.amin(dim=-1)
        support_floor = torch.sigmoid(self._field_mean(rollout_state["support_stability_field"])).amin(dim=-1)
        reocclusion_worst = reocclusion_traj.amax(dim=-1)
        # Proxy for task progress from the last action channel — presumably a
        # gripper/progress signal; TODO confirm channel semantics.
        retrieve_progress = torch.sigmoid(candidate_chunks[:, :, :, -1]).mean(dim=-1)
        # Weighted sum: gains and quality terms positive, hazards negative.
        utility = (
            self.config.belief_gain_weight * (belief_future - initial_belief)
            + self.config.visibility_gain_weight * (visibility_future - initial_visibility)
            + self.config.clearance_weight * clearance
            + self.config.occluder_contact_weight * occluder_contact
            + self.config.grasp_affordance_weight * grasp_affordance
            + self.config.persistence_weight * persistence
            + self.config.support_stability_weight * support_stability
            + self.config.corridor_weight * access_quality
            + self.config.task_progress_weight * retrieve_progress
            - self.config.reocclusion_weight * reocclusion
            - self.config.disturbance_weight * disturbance
            - self.config.visibility_weight * (1.0 - visibility_future)
        )
        return {
            "belief_gain": belief_future - initial_belief,
            "visibility_gain": visibility_future - initial_visibility,
            "clearance": clearance,
            "occluder_contact_quality": occluder_contact,
            "grasp_affordance": grasp_affordance,
            "persistence": persistence,
            "support_stability": support_stability,
            "reocclusion_penalty": reocclusion,
            "reocclusion_worst": reocclusion_worst,
            "disturbance_penalty": disturbance,
            "access_quality": access_quality,
            "access_floor": access_floor,
            "persistence_floor": persistence_floor,
            "support_floor": support_floor,
            "task_progress": retrieve_progress,
            "utility_structured": utility,
        }
317
+
318
+
319
class ResidualPlannerScorer(nn.Module):
    """Learned residual on top of the structured utility terms.

    An MLP over chunk statistics plus the 11 structured components produces
    success/risk heads and a residual utility correction.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        # Chunk mean + terminal step + 11 structured components.
        feature_dim = (config.action_dim * 2) + 11
        self.trunk = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.success_head = nn.Linear(config.hidden_dim, 1)
        self.risk_head = nn.Linear(config.hidden_dim, 1)
        self.residual_head = nn.Linear(config.hidden_dim, 1)

    def forward(
        self,
        candidate_chunks: Tensor,
        structured: dict[str, Tensor],
        proposal_logits: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Score candidates from chunk statistics plus structured terms."""
        chunk_mean = candidate_chunks.mean(dim=2)
        chunk_last = candidate_chunks[:, :, -1]
        component_keys = (
            "belief_gain",
            "visibility_gain",
            "clearance",
            "occluder_contact_quality",
            "grasp_affordance",
            "persistence",
            "support_stability",
            "reocclusion_penalty",
            "disturbance_penalty",
            "access_quality",
            "task_progress",
        )
        components = torch.stack([structured[key] for key in component_keys], dim=-1)
        encoded = self.trunk(torch.cat([chunk_mean, chunk_last, components], dim=-1))
        success = self.success_head(encoded).squeeze(-1)
        risk = torch.sigmoid(self.risk_head(encoded)).squeeze(-1)
        residual = self.residual_head(encoded).squeeze(-1)
        if proposal_logits is not None and proposal_logits.shape == residual.shape:
            # Nudge the learned residual toward externally proposed candidates.
            residual = residual + 0.25 * proposal_logits.sigmoid()
        return {
            "planner_hidden": encoded,
            "success_logits": success,
            "risk_values": risk,
            "utility_residual": residual,
        }
371
+
372
+
373
class CascadePlanner(nn.Module):
    """Two-stage planner: cheap shortlist, then structured + learned scoring.

    ``shortlist`` prunes candidates with cheap scores (optionally enforcing
    mode diversity); ``select_best`` combines the structured elastic utility
    with the learned residual scorer and a hard feasibility penalty for
    retrieve-like candidates.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config
        self.structured = StructuredElasticUtility(config)
        self.residual = ResidualPlannerScorer(config)

    def shortlist(
        self,
        proposal_logits: Tensor | None,
        candidate_chunks: Tensor,
        proposal_mode_assignments: Tensor | None = None,
    ) -> Tensor:
        """Return indices of the top-k candidates per batch element.

        Args:
            proposal_logits: Cheap per-candidate scores; when ``None`` an
                action-magnitude heuristic (small actions score higher) is used.
            candidate_chunks: Candidate action chunks, (batch, candidate, ...).
            proposal_mode_assignments: Optional per-candidate mode ids; when
                given, the best candidate of each mode is preferred so the
                shortlist stays mode-diverse.

        Returns:
            Long tensor of shortlisted candidate indices, (batch, top_k).
        """
        batch_size, num_candidates = candidate_chunks.shape[:2]
        # Clamp top_k into [1, num_candidates].
        top_k = min(max(1, self.config.top_k), num_candidates)
        if proposal_logits is None:
            # Heuristic fallback: prefer low-energy (small-magnitude) chunks.
            cheap_scores = -candidate_chunks.square().mean(dim=(-1, -2))
        else:
            cheap_scores = proposal_logits
        if proposal_mode_assignments is None:
            return cheap_scores.topk(top_k, dim=-1).indices
        if proposal_mode_assignments.ndim == 1:
            # Broadcast a shared assignment vector across the batch.
            proposal_mode_assignments = proposal_mode_assignments.unsqueeze(0).expand(batch_size, -1)

        shortlisted = []
        for batch_idx in range(batch_size):
            scores = cheap_scores[batch_idx]
            mode_ids = proposal_mode_assignments[batch_idx]
            mode_best: list[tuple[float, int]] = []
            # Keep only the best-scoring candidate within each mode.
            for mode_id in torch.unique(mode_ids):
                mode_indices = torch.nonzero(mode_ids == mode_id, as_tuple=False).squeeze(-1)
                best_local = mode_indices[scores[mode_indices].argmax()]
                mode_best.append((float(scores[best_local].detach()), int(best_local)))
            mode_best.sort(key=lambda item: item[0], reverse=True)
            chosen = [index for _, index in mode_best[:top_k]]
            if len(chosen) < top_k:
                # Fewer modes than top_k: pad with the next-best global candidates.
                for candidate_idx in scores.argsort(descending=True).tolist():
                    if candidate_idx not in chosen:
                        chosen.append(candidate_idx)
                        if len(chosen) >= top_k:
                            break
            shortlisted.append(torch.as_tensor(chosen[:top_k], device=candidate_chunks.device, dtype=torch.long))
        return torch.stack(shortlisted, dim=0)

    def select_best(
        self,
        initial_state: dict[str, Tensor],
        candidate_chunks: Tensor,
        rollout_state: dict[str, Tensor],
        proposal_logits: Tensor | None = None,
        candidate_indices: Tensor | None = None,
        proposal_mode_names: list[list[str]] | None = None,
    ) -> dict[str, Tensor]:
        """Score the (already shortlisted) candidates and pick the best one.

        Args:
            initial_state: Pre-rollout scene summaries for the structured utility.
            candidate_chunks: Candidate action chunks, (batch, candidate, ...).
            rollout_state: Post-rollout scene summaries.
            proposal_logits: Optional proposal prior forwarded to the residual scorer.
            candidate_indices: Original indices of the shortlisted candidates; when
                given, ``best_indices`` is mapped back into the original index space.
            proposal_mode_names: Optional per-candidate mode names used to apply a
                hard feasibility penalty to retrieve-like candidates.

        Returns:
            Dict merging structured and residual outputs with the total utility,
            best indices/chunk, and ranking diagnostics.
        """
        structured = self.structured(
            initial_state=initial_state,
            rollout_state=rollout_state,
            candidate_chunks=candidate_chunks,
        )
        residual = self.residual(
            candidate_chunks=candidate_chunks,
            structured=structured,
            proposal_logits=proposal_logits,
        )
        # Total utility: structured term + weighted learned residual
        # + success probability - bounded risk.
        utility_total = structured["utility_structured"] + self.config.residual_weight * residual["utility_residual"]
        utility_total = utility_total + residual["success_logits"].sigmoid() - residual["risk_values"]
        feasibility_penalty = torch.zeros_like(utility_total)
        if proposal_mode_names is not None:
            # Flag candidates whose mode name marks them as retrieve-like.
            retrieve_like = torch.zeros_like(utility_total, dtype=torch.bool)
            for batch_idx, names in enumerate(proposal_mode_names):
                for candidate_idx, name in enumerate(names[: utility_total.shape[1]]):
                    retrieve_like[batch_idx, candidate_idx] = any(
                        token in name for token in ("retrieve", "insert_actor", "probe_inside")
                    )
            # Hard safety floors; thresholds here are fixed (not config-driven).
            blocked = (
                (structured["access_floor"] < 0.15)
                | (structured["persistence_floor"] < 0.15)
                | (structured["support_floor"] < 0.25)
                | (structured["reocclusion_worst"] > 0.6)
            )
            # Flat penalty of 2.0 applied only to blocked retrieve-like candidates.
            feasibility_penalty = retrieve_like.to(dtype=utility_total.dtype) * blocked.to(dtype=utility_total.dtype) * 2.0
            utility_total = utility_total - feasibility_penalty
        best_local = utility_total.argmax(dim=-1)
        batch_indices = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
        if candidate_indices is None:
            best_indices = best_local
        else:
            # Map shortlist-local winners back to original candidate indices.
            best_indices = candidate_indices[batch_indices, best_local]
        return {
            **structured,
            **residual,
            "utility_total": utility_total,
            "utility_scores": utility_total,
            "feasibility_penalty": feasibility_penalty,
            "best_indices": best_indices,
            "best_chunk": candidate_chunks[batch_indices, best_local],
            "ranking_diagnostics": {
                "topk_indices": candidate_indices if candidate_indices is not None else best_local.unsqueeze(-1),
                "best_local_indices": best_local,
            },
        }
473
+
474
+
475
+ def _summary_scalar(state: dict[str, Tensor], key: str, fallback_keys: tuple[str, ...] = ()) -> Tensor:
476
+ for candidate in (key, *fallback_keys):
477
+ value = state.get(candidate)
478
+ if value is None:
479
+ continue
480
+ if value.ndim >= 5:
481
+ return value.mean(dim=tuple(range(value.ndim - 2, value.ndim))).mean(dim=-1)
482
+ if value.ndim == 4:
483
+ return value.mean(dim=(-1, -2))
484
+ if value.ndim == 3:
485
+ return value
486
+ if value.ndim == 2:
487
+ return value
488
+ return value.unsqueeze(-1)
489
+ raise KeyError(f"Missing summary key {key} and fallbacks {fallback_keys}.")
490
+
491
+
492
def _optional_summary_scalar(
    state: dict[str, Tensor],
    key: str,
    *,
    reference: Tensor,
    fallback_keys: tuple[str, ...] = (),
) -> Tensor:
    """Like ``_summary_scalar`` but tolerant of missing keys.

    Returns zeros shaped like ``reference`` when neither ``key`` nor any of
    ``fallback_keys`` is present in ``state``.
    """
    try:
        summary = _summary_scalar(state, key, fallback_keys)
    except KeyError:
        summary = torch.zeros_like(reference)
    return summary
503
+
504
+
505
class ElasticFeasibilityGate(nn.Module):
    """Hard feasibility gate for retrieve-like candidates.

    Reads per-candidate rollout summaries (access, persistence, support,
    reocclusion, disturbance), computes worst-case floors/peaks over the
    trailing axis, blocks retrieve/insert/probe candidates that violate the
    configured thresholds, and reports per-candidate diagnostics plus a
    soft confidence score.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config

    def forward(
        self,
        *,
        rollout_state: dict[str, Tensor],
        proposal_mode_names: list[list[str]],
    ) -> dict[str, Tensor | list[list[dict[str, float | bool | str]]]]:
        """Gate candidates by rollout safety summaries.

        Args:
            rollout_state: Dict of summary tensors consumed via ``_summary_scalar``;
                assumed (batch, candidate, horizon) after reduction — TODO confirm.
            proposal_mode_names: Per-batch candidate mode-name strings; substring
                tokens mark which candidates count as retrieve-like.

        Returns:
            Dict with the boolean ``allowed_mask``, additive ``penalties``,
            the raw ``blocked_mask``, a sigmoid ``adapter_confidence``,
            per-candidate gate floors/peaks, and ``reject_diagnostics`` records.
        """
        access = _summary_scalar(rollout_state, "access_summary", ("access_quality",))
        persistence = _summary_scalar(rollout_state, "persistence_summary", ("persistence", "persistence_horizon"))
        support = _summary_scalar(rollout_state, "support_summary", ("support_stability",))
        reocclusion = _summary_scalar(rollout_state, "reocclusion_summary", ("reocclusion_penalty",))
        disturbance = _summary_scalar(rollout_state, "disturbance_summary", ("disturbance_penalty",))
        # Worst case over the horizon: minima for "good" signals, maxima for "bad" ones.
        access_floor = access.amin(dim=-1)
        persistence_floor = persistence.amin(dim=-1)
        support_floor = support.amin(dim=-1)
        reocclusion_worst = reocclusion.amax(dim=-1)
        disturbance_worst = disturbance.amax(dim=-1)

        # A candidate is blocked when any configured safety threshold is violated.
        blocked = (
            (access_floor < self.config.retrieve_access_threshold)
            | (persistence_floor < self.config.retrieve_persistence_threshold)
            | (support_floor < self.config.retrieve_support_threshold)
            | (reocclusion_worst > self.config.retrieve_reocclusion_threshold)
        )
        # Flat additive penalty of 2.0 for blocked candidates (applied by callers).
        penalties = blocked.to(dtype=access.dtype) * 2.0
        allowed_mask = torch.ones_like(access_floor, dtype=torch.bool)
        reject_diagnostics: list[list[dict[str, float | bool | str]]] = []
        for batch_idx, names in enumerate(proposal_mode_names):
            sample_records: list[dict[str, float | bool | str]] = []
            for candidate_idx, name in enumerate(names[: access_floor.shape[1]]):
                # Only retrieve-like modes are hard-masked; others keep their penalty only.
                retrieve_like = any(token in name for token in ("retrieve", "insert_actor", "probe_inside"))
                candidate_blocked = bool(retrieve_like and blocked[batch_idx, candidate_idx])
                if candidate_blocked:
                    allowed_mask[batch_idx, candidate_idx] = False
                sample_records.append(
                    {
                        "mode_name": name,
                        "retrieve_like": retrieve_like,
                        "blocked": candidate_blocked,
                        "access_floor": float(access_floor[batch_idx, candidate_idx].detach()),
                        "persistence_floor": float(persistence_floor[batch_idx, candidate_idx].detach()),
                        "support_floor": float(support_floor[batch_idx, candidate_idx].detach()),
                        "reocclusion_worst": float(reocclusion_worst[batch_idx, candidate_idx].detach()),
                        "disturbance_worst": float(disturbance_worst[batch_idx, candidate_idx].detach()),
                    }
                )
            reject_diagnostics.append(sample_records)

        # Soft confidence: fixed-weight logistic blend of mean summaries.
        confidence = torch.sigmoid(
            2.0 * access.mean(dim=-1)
            + 1.5 * persistence.mean(dim=-1)
            + 1.5 * support.mean(dim=-1)
            - 1.5 * reocclusion.mean(dim=-1)
            - disturbance.mean(dim=-1)
        )
        return {
            "allowed_mask": allowed_mask,
            "penalties": penalties,
            "blocked_mask": blocked,
            "adapter_confidence": confidence,
            "gate_access_floor": access_floor,
            "gate_persistence_floor": persistence_floor,
            "gate_support_floor": support_floor,
            "gate_reocclusion_worst": reocclusion_worst,
            "reject_diagnostics": reject_diagnostics,
        }
575
+
576
+
577
class ResidualActionReranker(nn.Module):
    """Learned reranker over candidate chunks and rollout safety summaries.

    Builds a per-candidate feature vector from action statistics plus eight
    mean rollout summaries and emits a residual score, success logits, and a
    bounded risk value.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        # Mean + terminal action (2 * action_dim) plus 8 scalar summaries.
        in_features = (config.action_dim * 2) + 8
        self.network = nn.Sequential(
            nn.LayerNorm(in_features),
            nn.Linear(in_features, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.score_head = nn.Linear(config.hidden_dim, 1)
        self.success_head = nn.Linear(config.hidden_dim, 1)
        self.risk_head = nn.Linear(config.hidden_dim, 1)

    def forward(
        self,
        *,
        candidate_chunks: Tensor,
        rollout_state: dict[str, Tensor],
        proposal_logits: Tensor | None,
    ) -> dict[str, Tensor]:
        """Score candidates; chunks assumed (batch, candidate, time, action) — TODO confirm."""
        chunk_mean = candidate_chunks.mean(dim=2)
        chunk_terminal = candidate_chunks[:, :, -1]
        # Required summaries raise on absence; cloth-specific ones default to zeros.
        visibility = _summary_scalar(rollout_state, "visibility_summary", ("visibility_gain",))
        summaries = [
            visibility,
            _summary_scalar(rollout_state, "access_summary", ("access_quality",)),
            _summary_scalar(rollout_state, "persistence_summary", ("persistence", "persistence_horizon")),
            _summary_scalar(rollout_state, "support_summary", ("support_stability",)),
            _summary_scalar(rollout_state, "reocclusion_summary", ("reocclusion_penalty",)),
            _summary_scalar(rollout_state, "disturbance_summary", ("disturbance_penalty",)),
            _optional_summary_scalar(
                rollout_state,
                "fold_preservation_summary",
                reference=visibility,
                fallback_keys=("fold_preservation",),
            ),
            _optional_summary_scalar(
                rollout_state,
                "lift_too_much_risk_summary",
                reference=visibility,
                fallback_keys=("lift_too_much_risk",),
            ),
        ]
        features = torch.cat(
            [chunk_mean, chunk_terminal]
            + [summary.mean(dim=-1, keepdim=True) for summary in summaries],
            dim=-1,
        )
        hidden = self.network(features)
        residual = self.score_head(hidden).squeeze(-1)
        success = self.success_head(hidden).squeeze(-1)
        risk = torch.sigmoid(self.risk_head(hidden).squeeze(-1))
        # Blend in the (squashed) proposal prior when its shape lines up.
        if proposal_logits is not None and proposal_logits.shape == residual.shape:
            residual = residual + 0.25 * proposal_logits.sigmoid()
        return {
            "residual_scores": residual,
            "planner_success_logits": success,
            "planner_risk_values": risk,
        }
645
+
646
+
647
class AdapterPlanner(nn.Module):
    """Stage-aware planner: feasibility gate + learned reranker + mode shaping.

    Combines the hard ``ElasticFeasibilityGate`` with the learned
    ``ResidualActionReranker``, then layers hand-crafted per-mode bonuses and
    stage penalties so candidates are preferred roughly in the order
    reveal -> stabilize -> insert/probe -> retrieve, with bag- and
    cloth-specific readiness checks keyed off mode-name tokens.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config
        self.gate = ElasticFeasibilityGate(config)
        self.reranker = ResidualActionReranker(config)

    def select_best(
        self,
        *,
        candidate_chunks: Tensor,
        rollout_state: dict[str, Tensor],
        proposal_mode_names: list[list[str]],
        proposal_logits: Tensor | None = None,
        planning_mode: str = "adapter_active",
    ) -> dict[str, Tensor | list[list[dict[str, float | bool | str]]]]:
        """Pick the best candidate chunk per batch element.

        Args:
            candidate_chunks: Candidate action chunks; indexed as
                (batch, candidate, ...) by the selections below.
            rollout_state: Dict of rollout summary tensors consumed via
                ``_summary_scalar`` / ``_optional_summary_scalar``.
            proposal_mode_names: Per-batch candidate mode-name strings; substring
                tokens (e.g. "retrieve", "widen_mouth", "fold") drive the shaping.
            proposal_logits: Optional proposal prior forwarded to the reranker.
            planning_mode: "identity" / "trunk_only" / "adapter_noop" bypass
                planning and return candidate 0; any other value runs the cascade.

        Returns:
            Dict with best indices/chunk, utility scores, planner head outputs,
            gate confidence/diagnostics, and the planning mode used.
        """
        batch_size = candidate_chunks.shape[0]
        batch_indices = torch.arange(batch_size, device=candidate_chunks.device)
        # Ablation/bypass modes: skip gating and reranking, take the first candidate.
        if planning_mode in {"identity", "trunk_only", "adapter_noop"}:
            zero_scores = candidate_chunks.new_zeros((batch_size, candidate_chunks.shape[1]))
            return {
                "best_indices": torch.zeros(batch_size, dtype=torch.long, device=candidate_chunks.device),
                "best_chunk": candidate_chunks[:, 0],
                "utility_scores": zero_scores,
                "utility_total": zero_scores,
                "planner_success_logits": zero_scores,
                "planner_risk_values": zero_scores,
                "adapter_confidence": candidate_chunks.new_ones((batch_size, candidate_chunks.shape[1])),
                "reject_diagnostics": [[] for _ in range(batch_size)],
                "planning_mode": planning_mode,
            }

        gate_outputs = self.gate(rollout_state=rollout_state, proposal_mode_names=proposal_mode_names)
        reranker = self.reranker(
            candidate_chunks=candidate_chunks,
            rollout_state=rollout_state,
            proposal_logits=proposal_logits,
        )
        # Base utility: learned residual + success probability - bounded risk.
        utility = reranker["residual_scores"] + reranker["planner_success_logits"].sigmoid() - reranker["planner_risk_values"]
        # Per-candidate mean summaries (horizon-averaged).
        visibility = _summary_scalar(rollout_state, "visibility_summary", ("visibility_gain",)).mean(dim=-1)
        access = _summary_scalar(rollout_state, "access_summary", ("access_quality",)).mean(dim=-1)
        persistence = _summary_scalar(rollout_state, "persistence_summary", ("persistence", "persistence_horizon")).mean(dim=-1)
        support = _summary_scalar(rollout_state, "support_summary", ("support_stability",)).mean(dim=-1)
        reocclusion = _summary_scalar(rollout_state, "reocclusion_summary", ("reocclusion_penalty",)).mean(dim=-1)
        disturbance = _summary_scalar(rollout_state, "disturbance_summary", ("disturbance_penalty",)).mean(dim=-1)
        # Optional cloth/bag summaries default to zeros when absent.
        fold_preservation = _optional_summary_scalar(
            rollout_state,
            "fold_preservation_summary",
            reference=_summary_scalar(rollout_state, "access_summary", ("access_quality",)),
            fallback_keys=("fold_preservation",),
        ).mean(dim=-1)
        mouth_aperture = _optional_summary_scalar(
            rollout_state,
            "mouth_aperture_summary",
            reference=_summary_scalar(rollout_state, "access_summary", ("access_quality",)),
            fallback_keys=("mouth_aperture",),
        ).mean(dim=-1)
        layer_separation = _optional_summary_scalar(
            rollout_state,
            "layer_separation_summary",
            reference=_summary_scalar(rollout_state, "access_summary", ("access_quality",)),
            fallback_keys=("layer_separation_quality",),
        ).mean(dim=-1)
        lift_risk = _optional_summary_scalar(
            rollout_state,
            "lift_too_much_risk_summary",
            reference=_summary_scalar(rollout_state, "access_summary", ("access_quality",)),
            fallback_keys=("lift_too_much_risk",),
        ).mean(dim=-1)
        mode_bias = utility.new_zeros(utility.shape)
        stage_penalty = utility.new_zeros(utility.shape)
        # How much reveal work remains / how well the reveal is stabilized, per candidate.
        unresolved_reveal = (1.0 - visibility) + (1.0 - access)
        stabilized_reveal = 0.5 * (access + persistence + support)
        # Use optimistic scene readiness summaries for stage switching.
        # Candidate-level safety is still enforced by the retrieve gate below, so
        # we should not let one poor candidate keep the entire scene stuck in
        # "reveal forever" mode when another candidate already makes retrieve feasible.
        batch_visibility = visibility.amax(dim=1)
        batch_access = access.amax(dim=1)
        batch_persistence = persistence.amax(dim=1)
        batch_support = support.amax(dim=1)
        batch_reocclusion = reocclusion.amin(dim=1)
        batch_disturbance = disturbance.amin(dim=1)
        batch_fold = fold_preservation.amax(dim=1)
        batch_mouth = mouth_aperture.amax(dim=1)
        batch_layer = layer_separation.amax(dim=1)
        batch_lift = lift_risk.amin(dim=1)
        batch_reveal_readiness = torch.maximum(batch_visibility, batch_access)
        for batch_idx, names in enumerate(proposal_mode_names):
            # Scene type is inferred from the candidate mode-name vocabulary.
            is_bag = any(any(token in name for token in ("mouth", "rim", "probe_inside")) for name in names)
            is_cloth = any(any(token in name for token in ("fold", "lift", "layer")) for name in names)
            # Stage readiness: retrieve requires all configured thresholds met.
            can_retrieve = (
                batch_access[batch_idx] >= self.config.retrieve_stage_access_threshold
                and batch_reveal_readiness[batch_idx] >= self.config.retrieve_stage_reveal_threshold
                and batch_persistence[batch_idx] >= self.config.retrieve_stage_persistence_threshold
                and batch_support[batch_idx] >= self.config.retrieve_stage_support_threshold
                and batch_reocclusion[batch_idx] <= self.config.retrieve_reocclusion_threshold
            )
            if is_bag:
                # Bags additionally need a wide-enough mouth and extra persistence.
                can_retrieve = bool(
                    can_retrieve
                    and batch_mouth[batch_idx] >= 0.30
                    and batch_persistence[batch_idx] >= 0.55
                )
            elif is_cloth:
                # Cloth additionally needs layer separation, preserved folds,
                # bounded lift risk, and strong support.
                can_retrieve = bool(
                    can_retrieve
                    and batch_layer[batch_idx] >= 0.18
                    and batch_fold[batch_idx] >= 0.60
                    and batch_lift[batch_idx] <= 0.30
                    and batch_support[batch_idx] >= 0.70
                )
            can_insert = (
                batch_access[batch_idx] >= self.config.insert_stage_access_threshold
                and batch_visibility[batch_idx] >= self.config.insert_stage_visibility_threshold
                and batch_support[batch_idx] >= self.config.insert_stage_support_threshold
                and batch_reocclusion[batch_idx] <= 0.65
            )
            maintain_ready = (
                batch_access[batch_idx] >= self.config.occlusion_maintain_gap_min_access
                and batch_visibility[batch_idx] >= self.config.occlusion_maintain_gap_min_visibility
            )
            # Pick the preferred mode tokens for the current stage and scene type.
            if can_retrieve:
                preferred_tokens = ("retrieve",)
            elif can_insert:
                preferred_tokens = ("probe_inside", "insert_actor") if is_bag else ("insert_actor",)
            elif is_bag:
                if batch_access[batch_idx] < 0.15 or batch_visibility[batch_idx] < 0.20:
                    preferred_tokens = ("widen_mouth", "maintain_mouth")
                else:
                    preferred_tokens = ("maintain_mouth", "widen_mouth")
            elif is_cloth:
                if batch_access[batch_idx] < 0.15 or batch_visibility[batch_idx] < 0.20:
                    preferred_tokens = ("lift_edge", "separate_layer")
                elif batch_lift[batch_idx] > 0.15 or batch_disturbance[batch_idx] > 0.25:
                    preferred_tokens = ("stabilize_fold", "maintain_lift")
                else:
                    preferred_tokens = ("maintain_lift", "stabilize_fold")
            else:
                if not maintain_ready:
                    preferred_tokens = ("widen_gap", "pin_canopy", "sweep_left", "sweep_right")
                elif batch_visibility[batch_idx] < 0.20 or batch_access[batch_idx] < 0.25:
                    preferred_tokens = ("widen_gap", "pin_canopy")
                elif batch_disturbance[batch_idx] > 0.25 or batch_reocclusion[batch_idx] > 0.40:
                    preferred_tokens = ("maintain_gap", "pin_canopy")
                else:
                    preferred_tokens = ("pin_canopy", "widen_gap")
            for candidate_idx, name in enumerate(names[: utility.shape[1]]):
                if name == "base_action":
                    # The base action keeps a neutral (zero) bias.
                    continue
                # Per-family bonus: fixed-weight blend of candidate summaries.
                if any(token in name for token in ("retrieve",)):
                    bonus = (
                        0.85 * visibility[batch_idx, candidate_idx]
                        + 0.85 * access[batch_idx, candidate_idx]
                        + 0.65 * persistence[batch_idx, candidate_idx]
                        + 0.50 * support[batch_idx, candidate_idx]
                        - 0.60 * reocclusion[batch_idx, candidate_idx]
                        - 0.25 * disturbance[batch_idx, candidate_idx]
                    )
                elif any(token in name for token in ("insert_actor", "probe_inside")):
                    bonus = (
                        0.70 * visibility[batch_idx, candidate_idx]
                        + 0.70 * access[batch_idx, candidate_idx]
                        + 0.35 * persistence[batch_idx, candidate_idx]
                        - 0.35 * reocclusion[batch_idx, candidate_idx]
                        - 0.15 * disturbance[batch_idx, candidate_idx]
                    )
                elif any(token in name for token in ("maintain", "stabilize", "pin_canopy")):
                    bonus = (
                        0.85 * stabilized_reveal[batch_idx, candidate_idx]
                        + 0.25 * visibility[batch_idx, candidate_idx]
                        - 0.20 * reocclusion[batch_idx, candidate_idx]
                        - 0.10 * disturbance[batch_idx, candidate_idx]
                    )
                else:
                    # Reveal-family modes: reward unresolved reveal work.
                    bonus = (
                        0.95 * unresolved_reveal[batch_idx, candidate_idx]
                        + 0.20 * (1.0 - persistence[batch_idx, candidate_idx])
                        - 0.10 * disturbance[batch_idx, candidate_idx]
                    )
                if any(token in name for token in ("fold", "lift", "layer")):
                    # Cloth modes: reward preserved folds, penalize over-lifting.
                    bonus = bonus + 0.35 * fold_preservation[batch_idx, candidate_idx] - 0.35 * lift_risk[batch_idx, candidate_idx]
                if any(token in name for token in preferred_tokens):
                    bonus = bonus + self.config.mode_preference_bonus
                elif "retrieve" in name and not can_retrieve:
                    # Premature retrieve: penalize both the bias and the stage penalty.
                    bonus = bonus - self.config.premature_retrieve_penalty
                    stage_penalty[batch_idx, candidate_idx] = (
                        stage_penalty[batch_idx, candidate_idx] + self.config.premature_retrieve_penalty
                    )
                elif is_cloth and any(token in name for token in ("stabilize", "maintain")) and any(
                    token in preferred_tokens for token in ("lift_edge", "separate_layer")
                ):
                    # Cloth still needs lifting/separating: discourage early stabilization.
                    bonus = bonus - 1.0
                    stage_penalty[batch_idx, candidate_idx] = stage_penalty[batch_idx, candidate_idx] + 1.0
                elif (not is_bag and not is_cloth) and any(token in name for token in ("sweep_left", "sweep_right")) and any(
                    token in preferred_tokens for token in ("pin_canopy", "widen_gap", "maintain_gap")
                ):
                    bonus = bonus - self.config.premature_occlusion_sweep_penalty
                elif any(token in name for token in ("probe_inside", "insert_actor", "retrieve")) and not can_insert:
                    bonus = bonus - self.config.premature_insert_penalty
                    stage_penalty[batch_idx, candidate_idx] = (
                        stage_penalty[batch_idx, candidate_idx] + self.config.premature_insert_penalty
                    )
                if (
                    (not is_bag and not is_cloth)
                    and "maintain_gap" in name
                    and not maintain_ready
                    and self.config.premature_maintain_penalty > 0.0
                ):
                    bonus = bonus - self.config.premature_maintain_penalty
                    stage_penalty[batch_idx, candidate_idx] = (
                        stage_penalty[batch_idx, candidate_idx] + self.config.premature_maintain_penalty
                    )
                if is_bag and (batch_mouth[batch_idx] < 0.18 or batch_access[batch_idx] < 0.15) and "widen_mouth" in name:
                    stage_penalty[batch_idx, candidate_idx] = stage_penalty[batch_idx, candidate_idx] + 1.5
                if is_cloth and (batch_layer[batch_idx] < 0.12 or batch_visibility[batch_idx] < 0.05) and any(
                    token in name for token in ("lift_edge", "separate_layer")
                ):
                    stage_penalty[batch_idx, candidate_idx] = stage_penalty[batch_idx, candidate_idx] + 1.5
                mode_bias[batch_idx, candidate_idx] = bonus
        # Combine shaping terms, then apply gate penalties and the hard mask.
        utility = utility + mode_bias
        utility = utility + 0.5 * fold_preservation - 0.5 * lift_risk
        utility = utility - stage_penalty
        utility = utility - gate_outputs["penalties"]
        allowed_mask = gate_outputs["allowed_mask"]
        assert isinstance(allowed_mask, Tensor)
        # Disallowed candidates get a large negative utility so argmax skips them.
        utility = utility.masked_fill(~allowed_mask, -1e6)
        best_indices = utility.argmax(dim=-1)
        best_chunk = candidate_chunks[batch_indices, best_indices]
        return {
            "best_indices": best_indices,
            "best_chunk": best_chunk,
            "utility_scores": utility,
            "utility_total": utility,
            "planner_success_logits": reranker["planner_success_logits"],
            "planner_risk_values": reranker["planner_risk_values"],
            "adapter_confidence": gate_outputs["adapter_confidence"],
            "allowed_mask": gate_outputs["allowed_mask"],
            "reject_diagnostics": gate_outputs["reject_diagnostics"],
            "planning_mode": planning_mode,
        }
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (284 Bytes). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (317 Bytes). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (3.63 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/transforms.cpython-311.pyc ADDED
Binary file (7.25 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc ADDED
Binary file (1.95 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (54.1 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc ADDED
Binary file (98 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc ADDED
Binary file (6.32 kB). View file
 
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (291 Bytes). View file