Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/README.md +114 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/docs/adapter_stack.md +87 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/docs/public_benchmark_package.md +73 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/protocols.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/public_benchmark_package.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/public_benchmark_package.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/report.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/report.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_public_benchmark_package.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_public_benchmark_package.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_reveal_benchmark.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_reveal_benchmark.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/public_benchmark_package.py +266 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_maniskill_bridge_retrieval_smoke.py +2037 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_maniskill_pickclutter_smoke.py +2005 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_public_benchmark_package.py +369 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/rvt_backbone.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/rvt_backbone.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/planner.py +887 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/__init__.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/__init__.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/transforms.cpython-310.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/transforms.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc +0 -0
- code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-310.pyc +0 -0
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/README.md
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# reveal_vla_bimanual
|
| 2 |
+
|
| 3 |
+
Simulation-first prototype for a language-conditioned bimanual reveal-and-retrieve policy under elastic occlusion.
|
| 4 |
+
|
| 5 |
+
This repo is not a generalist VLA backbone in the RT-2 / OpenVLA / Octo sense. The current contribution is the reveal-state machinery layered on top of a frozen vision-language encoder.
|
| 6 |
+
|
| 7 |
+
This repo is structured around five top-level modules:
|
| 8 |
+
|
| 9 |
+
- `sim_rlbench/`: RLBench2 / PerAct2 wrappers, dataset hooks, camera setup, and benchmark evaluation helpers.
|
| 10 |
+
- `sim_reveal/`: reveal-proxy environments, scripted teachers, and privileged label extraction.
|
| 11 |
+
- `models/`: shared backbone wrappers, multi-view fusion, bimanual decoder, reveal-state head, world model, and planner.
|
| 12 |
+
- `train/`: trainers, losses, checkpointing, and Hydra/YAML configs.
|
| 13 |
+
- `eval/`: benchmark scripts, ablations, metrics, plots, and report generation.
|
| 14 |
+
|
| 15 |
+
Current bootstrap priorities:
|
| 16 |
+
|
| 17 |
+
1. Reproduce the RLBench2 / PerAct2 stack with a fixed 3-camera interface.
|
| 18 |
+
2. Stand up a backbone-only 3-camera policy in the same training/eval harness.
|
| 19 |
+
3. Add reveal-state supervision and short-horizon planning for synthetic reveal proxies.
|
| 20 |
+
|
| 21 |
+
## Public benchmark package
|
| 22 |
+
|
| 23 |
+
The repo now includes a concrete public-benchmark package definition for the next-stage fair comparison:
|
| 24 |
+
|
| 25 |
+
- `eval/public_benchmark_package.py`
|
| 26 |
+
- track registry for bag, dense occluded retrieval, cloth retrieval, and the generic anchor
|
| 27 |
+
- same-protocol signatures across `trunk_only`, `adapter_noop`, and `adapter_active`
|
| 28 |
+
- same-data / same-init fairness signatures for `trunk_only_ft` vs `adapter_active_ft`
|
| 29 |
+
|
| 30 |
+
- `eval/run_public_benchmark_package.py`
|
| 31 |
+
- validates normalized result files from multiple public suites
|
| 32 |
+
- checks protocol identity and training fairness
|
| 33 |
+
- aggregates per-track gains, sign-of-life diagnostics, and anchor regressions
|
| 34 |
+
|
| 35 |
+
Write the default manifest to `~/workspace` with:
|
| 36 |
+
|
| 37 |
+
```bash
|
| 38 |
+
python -m eval.run_public_benchmark_package \
|
| 39 |
+
--write-default-manifest ~/workspace/public_benchmark_package_v1.json
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Summarize normalized result files with:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
python -m eval.run_public_benchmark_package \
|
| 46 |
+
--result /abs/path/result_a.json \
|
| 47 |
+
--result /abs/path/result_b.json \
|
| 48 |
+
--output-dir ~/workspace/reports/public_benchmark_package_v1
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
Upstream dependencies are kept in `/workspace/third_party` and pinned in `docs/upstream_pins.md`.
|
| 52 |
+
|
| 53 |
+
## RLBench env A
|
| 54 |
+
|
| 55 |
+
The RLBench / PerAct2 stack is pinned to Python 3.10 and lives in `/workspace/envs/rlbench`.
|
| 56 |
+
|
| 57 |
+
Bring it up with:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
/workspace/reveal_vla_bimanual/scripts/setup_env_a_rlbench.sh
|
| 61 |
+
/workspace/reveal_vla_bimanual/scripts/setup_rlbench_headless_x.sh
|
| 62 |
+
/workspace/reveal_vla_bimanual/scripts/start_rlbench_x.sh
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
Verify GPU GL on the headless display:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
DISPLAY=:99 glxinfo -B
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Run the RLBench launch/reset/step smoke test:
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
env \
|
| 75 |
+
DISPLAY=:99 \
|
| 76 |
+
XDG_RUNTIME_DIR=/tmp/runtime-root \
|
| 77 |
+
COPPELIASIM_ROOT=/workspace/assets/coppeliasim_v4_1_0 \
|
| 78 |
+
LD_LIBRARY_PATH=/workspace/system_shims/nvidia$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n1 | cut -d. -f1)/usr/lib/x86_64-linux-gnu:/workspace/system_shims/nvidia$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n1 | cut -d. -f1)/usr/lib/x86_64-linux-gnu/nvidia:/workspace/assets/coppeliasim_v4_1_0 \
|
| 79 |
+
QT_QPA_PLATFORM_PLUGIN_PATH=/workspace/assets/coppeliasim_v4_1_0 \
|
| 80 |
+
/workspace/.tools/micromamba/bin/micromamba run \
|
| 81 |
+
-r /workspace/.micromamba \
|
| 82 |
+
-p /workspace/envs/rlbench \
|
| 83 |
+
python -m sim_rlbench.launch_smoke --headless
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
The working benchmark interface is fixed to three cameras only:
|
| 87 |
+
|
| 88 |
+
- `front`
|
| 89 |
+
- `wrist_left`
|
| 90 |
+
- `wrist_right`
|
| 91 |
+
|
| 92 |
+
The smoke test covers launch, bimanual task reset, canonical observation extraction, and one bimanual action step in `headless=True`, which is the same mode used by the upstream PerAct2-style training stack.
|
| 93 |
+
|
| 94 |
+
Generate the PerAct2-compatible train command for the fixed 3-camera interface with:
|
| 95 |
+
|
| 96 |
+
```bash
|
| 97 |
+
micromamba run -r /workspace/.micromamba -p /workspace/envs/rlbench \
|
| 98 |
+
python -m sim_rlbench.smoke_test --print-train-command
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
Download the published PerAct2 demos into `/workspace/data/rlbench2` with checksum verification:
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
micromamba run -r /workspace/.micromamba -p /workspace/envs/rlbench \
|
| 105 |
+
python -m sim_rlbench.dataset_download --resolution 256 --splits train
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
If you want the archives unpacked directly into the demo root expected by RLBench, add `--extract`:
|
| 109 |
+
|
| 110 |
+
```bash
|
| 111 |
+
apt-get install -y squashfs-tools
|
| 112 |
+
micromamba run -r /workspace/.micromamba -p /workspace/envs/rlbench \
|
| 113 |
+
python -m sim_rlbench.dataset_download --resolution 256 --splits train --extract
|
| 114 |
+
```
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/docs/adapter_stack.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapter Stack
|
| 2 |
+
|
| 3 |
+
This repo now contains a no-op-safe `trunk + adapter` path alongside the legacy monolithic policies.
|
| 4 |
+
|
| 5 |
+
## Main classes
|
| 6 |
+
|
| 7 |
+
- `models/policy.py`
|
| 8 |
+
- `FoundationTrunkPolicy`
|
| 9 |
+
- `ElasticOcclusionAdapter`
|
| 10 |
+
- `AdapterWrappedPolicy`
|
| 11 |
+
|
| 12 |
+
- `models/backbones.py`
|
| 13 |
+
- `NoOpAdapterCompatibleTrunkOutput`
|
| 14 |
+
- `TrunkInterface`
|
| 15 |
+
|
| 16 |
+
- `models/action_decoder.py`
|
| 17 |
+
- `TaskRoutedProposalPrior`
|
| 18 |
+
|
| 19 |
+
- `models/planner.py`
|
| 20 |
+
- `ElasticFeasibilityGate`
|
| 21 |
+
- `ResidualActionReranker`
|
| 22 |
+
- `AdapterPlanner`
|
| 23 |
+
|
| 24 |
+
- `models/world_model.py`
|
| 25 |
+
- `LightweightRevealStateTransitionModel`
|
| 26 |
+
|
| 27 |
+
- `models/observation_memory.py`
|
| 28 |
+
- `RevealStateCache`
|
| 29 |
+
|
| 30 |
+
## Trainer modes
|
| 31 |
+
|
| 32 |
+
`train/trainer.py` now supports:
|
| 33 |
+
|
| 34 |
+
- `policy_type: adapter_wrapped`
|
| 35 |
+
- `policy_type: foundation_trunk`
|
| 36 |
+
|
| 37 |
+
Relevant trainer fields:
|
| 38 |
+
|
| 39 |
+
- `training_regime`
|
| 40 |
+
- `eval_mode`
|
| 41 |
+
- `adapter_mode`
|
| 42 |
+
- `adapter_use_transition_model`
|
| 43 |
+
- `adapter_use_task_conditioning`
|
| 44 |
+
|
| 45 |
+
## Guardrail tests
|
| 46 |
+
|
| 47 |
+
New tests:
|
| 48 |
+
|
| 49 |
+
- `tests/test_trunk_noop_equivalence.py`
|
| 50 |
+
- `tests/test_adapter_gate_blocks_unsafe_retrieve.py`
|
| 51 |
+
- `tests/test_task_specific_loss_masking.py`
|
| 52 |
+
- `tests/test_cloth_specific_metrics_affect_selection.py`
|
| 53 |
+
- `tests/test_general_eval_protocol_is_identical.py`
|
| 54 |
+
|
| 55 |
+
## Config templates
|
| 56 |
+
|
| 57 |
+
- `train/configs/proxy_adapter_wrapped_clip_base.yaml`
|
| 58 |
+
- `train/configs/proxy_adapter_wrapped_clip_rank_only.yaml`
|
| 59 |
+
- `train/configs/proxy_adapter_wrapped_clip_noop_eval.yaml`
|
| 60 |
+
|
| 61 |
+
## Benchmark wrappers
|
| 62 |
+
|
| 63 |
+
- `scripts/run_anchor_adapter_ablations.sh`
|
| 64 |
+
- `scripts/run_proxy_adapter_ablations.sh`
|
| 65 |
+
- `scripts/run_target_like_adapter_subset.sh`
|
| 66 |
+
- `eval/public_benchmark_package.py`
|
| 67 |
+
- `eval/run_public_benchmark_package.py`
|
| 68 |
+
|
| 69 |
+
All new configs and scripts default to `~/workspace` outputs and reports.
|
| 70 |
+
|
| 71 |
+
## Public benchmark package
|
| 72 |
+
|
| 73 |
+
The public benchmark package is the current fair-comparison contract for real benchmarks:
|
| 74 |
+
|
| 75 |
+
- target tracks:
|
| 76 |
+
- `bag_track` -> `BEHAVIOR-1K/unpacking_childs_bag-0`
|
| 77 |
+
- `occlusion_track` -> `ManiSkill/PickClutterYCB-v1`
|
| 78 |
+
- `cloth_track` -> `GarmentLab/grasp_protocol_stacked_garment`
|
| 79 |
+
- anchor track:
|
| 80 |
+
- `anchor_track` -> `AnyBimanual/dual_push_buttons`
|
| 81 |
+
|
| 82 |
+
The package code enforces:
|
| 83 |
+
|
| 84 |
+
- mode-invariant eval protocols per track
|
| 85 |
+
- same-data / same-init fairness for `trunk_only_ft` vs `adapter_active_ft`
|
| 86 |
+
- sign-of-life thresholds on intervention and non-base proposal selection
|
| 87 |
+
- no-regression tolerance on the trusted generic anchor
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/docs/public_benchmark_package.md
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public Benchmark Package
|
| 2 |
+
|
| 3 |
+
This repo now contains a concrete public-benchmark package for the real-sim phase.
|
| 4 |
+
|
| 5 |
+
## Tracks
|
| 6 |
+
|
| 7 |
+
- `bag_track`
|
| 8 |
+
- suite: `BEHAVIOR-1K`
|
| 9 |
+
- task: `unpacking_childs_bag-0`
|
| 10 |
+
- `occlusion_track`
|
| 11 |
+
- suite: `ManiSkill 3`
|
| 12 |
+
- task: `PickClutterYCB-v1`
|
| 13 |
+
- `cloth_track`
|
| 14 |
+
- suite: `GarmentLab`
|
| 15 |
+
- task slice: `grasp_protocol_stacked_garment`
|
| 16 |
+
- `anchor_track`
|
| 17 |
+
- suite: `AnyBimanual`
|
| 18 |
+
- task: `dual_push_buttons`
|
| 19 |
+
|
| 20 |
+
## Enforced fairness
|
| 21 |
+
|
| 22 |
+
- `trunk_only_ft` and `adapter_active_ft` must share:
|
| 23 |
+
- train demos
|
| 24 |
+
- val demos
|
| 25 |
+
- init checkpoint group
|
| 26 |
+
- optimizer
|
| 27 |
+
- LR schedule
|
| 28 |
+
- batch size
|
| 29 |
+
- augmentations
|
| 30 |
+
- early stopping metric
|
| 31 |
+
- max gradient steps
|
| 32 |
+
- unfrozen trunk scope
|
| 33 |
+
- dataset split id
|
| 34 |
+
- all modes on a track must share the same eval protocol signature
|
| 35 |
+
- anchor regressions are bounded by an absolute tolerance of `0.02`
|
| 36 |
+
|
| 37 |
+
## Normalized result schema
|
| 38 |
+
|
| 39 |
+
Each external benchmark run should be converted to one JSON object with:
|
| 40 |
+
|
| 41 |
+
- `track_id`
|
| 42 |
+
- `adapter_mode`
|
| 43 |
+
- `successes` or `success_rate`
|
| 44 |
+
- `episodes`
|
| 45 |
+
- `seed`
|
| 46 |
+
- `eval_protocol`
|
| 47 |
+
- for target tracks: `train_spec`
|
| 48 |
+
- optional diagnostics:
|
| 49 |
+
- `intervention_rate`
|
| 50 |
+
- `non_base_selection_rate`
|
| 51 |
+
- `steps_to_first_reveal_or_access`
|
| 52 |
+
- `steps_to_retrieve`
|
| 53 |
+
- `disturbance_proxy`
|
| 54 |
+
|
| 55 |
+
## Commands
|
| 56 |
+
|
| 57 |
+
Write the default manifest:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
python -m eval.run_public_benchmark_package \
|
| 61 |
+
--write-default-manifest ~/workspace/public_benchmark_package_v1.json
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
Summarize results:
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
python -m eval.run_public_benchmark_package \
|
| 68 |
+
--result /abs/path/behavior_bag_adapter_active_seed17.json \
|
| 69 |
+
--result /abs/path/behavior_bag_trunk_seed17.json \
|
| 70 |
+
--result /abs/path/maniskill_occlusion_adapter_active_seed17.json \
|
| 71 |
+
--result /abs/path/maniskill_occlusion_trunk_seed17.json \
|
| 72 |
+
--output-dir ~/workspace/reports/public_benchmark_package_v1
|
| 73 |
+
```
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (288 Bytes). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (322 Bytes). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-310.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/metrics.cpython-311.pyc
ADDED
|
Binary file (22.3 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/protocols.cpython-310.pyc
ADDED
|
Binary file (1.52 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/public_benchmark_package.cpython-310.pyc
ADDED
|
Binary file (8.38 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/public_benchmark_package.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/report.cpython-310.pyc
ADDED
|
Binary file (1.79 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/report.cpython-311.pyc
ADDED
|
Binary file (3.36 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_public_benchmark_package.cpython-310.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_public_benchmark_package.cpython-311.pyc
ADDED
|
Binary file (24 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_reveal_benchmark.cpython-310.pyc
ADDED
|
Binary file (29.2 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/__pycache__/run_reveal_benchmark.cpython-311.pyc
ADDED
|
Binary file (63.1 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/public_benchmark_package.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import asdict, dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Sequence
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
TARGET_ROLE = "target"
|
| 10 |
+
ANCHOR_ROLE = "anchor"
|
| 11 |
+
|
| 12 |
+
TARGET_TRACK_EVAL_MODES: tuple[str, ...] = (
|
| 13 |
+
"trunk_only_ft",
|
| 14 |
+
"adapter_noop",
|
| 15 |
+
"adapter_active_ft",
|
| 16 |
+
)
|
| 17 |
+
ANCHOR_TRACK_EVAL_MODES: tuple[str, ...] = (
|
| 18 |
+
"trunk_only",
|
| 19 |
+
"adapter_noop",
|
| 20 |
+
"adapter_active",
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
DEFAULT_TARGET_TRAIN_DEMOS = 64
|
| 24 |
+
DEFAULT_TARGET_VAL_DEMOS = 16
|
| 25 |
+
DEFAULT_TARGET_TEST_EPISODES = 100
|
| 26 |
+
DEFAULT_ANCHOR_EPISODES = 25
|
| 27 |
+
DEFAULT_RESOLUTION = 256
|
| 28 |
+
DEFAULT_ANCHOR_TOLERANCE = 0.02
|
| 29 |
+
DEFAULT_SIGN_OF_LIFE_INTERVENTION = 0.15
|
| 30 |
+
DEFAULT_SIGN_OF_LIFE_NON_BASE = 0.15
|
| 31 |
+
DEFAULT_SIGN_OF_LIFE_GAIN = 0.05
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass(frozen=True)
|
| 35 |
+
class PublicBenchmarkTrack:
|
| 36 |
+
track_id: str
|
| 37 |
+
suite: str
|
| 38 |
+
benchmark_task: str
|
| 39 |
+
role: str
|
| 40 |
+
task_family: str
|
| 41 |
+
target_behavior: str
|
| 42 |
+
public_source: str
|
| 43 |
+
notes: str = ""
|
| 44 |
+
success_metric: str = "success_rate"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
PUBLIC_BENCHMARK_TRACKS: tuple[PublicBenchmarkTrack, ...] = (
|
| 48 |
+
PublicBenchmarkTrack(
|
| 49 |
+
track_id="bag_track",
|
| 50 |
+
suite="behavior1k",
|
| 51 |
+
benchmark_task="unpacking_childs_bag-0",
|
| 52 |
+
role=TARGET_ROLE,
|
| 53 |
+
task_family="bag_retrieval",
|
| 54 |
+
target_behavior="retrieve target objects from an opened compliant bag or backpack",
|
| 55 |
+
public_source="https://behavior.stanford.edu/knowledgebase/tasks/",
|
| 56 |
+
notes=(
|
| 57 |
+
"Closest public bag retrieval task. Treat as the benchmark-standard bag opening / "
|
| 58 |
+
"retrieval slice."
|
| 59 |
+
),
|
| 60 |
+
),
|
| 61 |
+
PublicBenchmarkTrack(
|
| 62 |
+
track_id="occlusion_track",
|
| 63 |
+
suite="maniskill3",
|
| 64 |
+
benchmark_task="PickClutterYCB-v1",
|
| 65 |
+
role=TARGET_ROLE,
|
| 66 |
+
task_family="dense_occluded_retrieval",
|
| 67 |
+
target_behavior="retrieve a target object from dense occluding clutter",
|
| 68 |
+
public_source="https://maniskill.readthedocs.io/en/latest/tasks/table_top_gripper/index.html",
|
| 69 |
+
notes=(
|
| 70 |
+
"Closest maintained public occluded retrieval task. Treat as the canopy / dense "
|
| 71 |
+
"occlusion proxy."
|
| 72 |
+
),
|
| 73 |
+
),
|
| 74 |
+
PublicBenchmarkTrack(
|
| 75 |
+
track_id="cloth_track",
|
| 76 |
+
suite="garmentlab",
|
| 77 |
+
benchmark_task="grasp_protocol_stacked_garment",
|
| 78 |
+
role=TARGET_ROLE,
|
| 79 |
+
task_family="cloth_retrieval",
|
| 80 |
+
target_behavior="retrieve a hidden or partially covered object from stacked or cluttered garments",
|
| 81 |
+
public_source="https://garmentlab.readthedocs.io/en/latest/tutorial/realworldbenchmark/index.html",
|
| 82 |
+
notes=(
|
| 83 |
+
"Use the GarmentLab grasp protocol in stacked/clutter layouts as the closest public "
|
| 84 |
+
"cloth retrieval benchmark slice."
|
| 85 |
+
),
|
| 86 |
+
),
|
| 87 |
+
PublicBenchmarkTrack(
|
| 88 |
+
track_id="anchor_track",
|
| 89 |
+
suite="anybimanual",
|
| 90 |
+
benchmark_task="dual_push_buttons",
|
| 91 |
+
role=ANCHOR_ROLE,
|
| 92 |
+
task_family="generic_anchor",
|
| 93 |
+
target_behavior="generic bimanual control regression anchor",
|
| 94 |
+
public_source="https://arxiv.org/abs/2412.06779",
|
| 95 |
+
notes="Trusted public anchor on this setup. Keep as a no-regression track only.",
|
| 96 |
+
),
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def public_benchmark_tracks(role: str | None = None) -> list[PublicBenchmarkTrack]:
|
| 101 |
+
if role is None:
|
| 102 |
+
return list(PUBLIC_BENCHMARK_TRACKS)
|
| 103 |
+
return [track for track in PUBLIC_BENCHMARK_TRACKS if track.role == role]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def public_track_by_id(track_id: str) -> PublicBenchmarkTrack:
|
| 107 |
+
normalized = str(track_id).strip()
|
| 108 |
+
for track in PUBLIC_BENCHMARK_TRACKS:
|
| 109 |
+
if track.track_id == normalized:
|
| 110 |
+
return track
|
| 111 |
+
raise KeyError(f"Unknown public benchmark track: {track_id!r}")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def expected_eval_modes(track_id: str) -> tuple[str, ...]:
|
| 115 |
+
track = public_track_by_id(track_id)
|
| 116 |
+
if track.role == TARGET_ROLE:
|
| 117 |
+
return TARGET_TRACK_EVAL_MODES
|
| 118 |
+
return ANCHOR_TRACK_EVAL_MODES
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def build_public_eval_protocol(
|
| 122 |
+
*,
|
| 123 |
+
track_id: str,
|
| 124 |
+
eval_mode: str,
|
| 125 |
+
seed: int = 17,
|
| 126 |
+
episodes: int | None = None,
|
| 127 |
+
resolution: int = DEFAULT_RESOLUTION,
|
| 128 |
+
cameras: Sequence[str] = ("front", "left_wrist", "right_wrist"),
|
| 129 |
+
) -> dict[str, Any]:
|
| 130 |
+
track = public_track_by_id(track_id)
|
| 131 |
+
expected = expected_eval_modes(track.track_id)
|
| 132 |
+
mode = str(eval_mode)
|
| 133 |
+
if mode not in expected:
|
| 134 |
+
raise ValueError(f"Unexpected eval mode {mode!r} for track {track.track_id!r}. Expected one of {expected}.")
|
| 135 |
+
if episodes is None:
|
| 136 |
+
episodes = DEFAULT_TARGET_TEST_EPISODES if track.role == TARGET_ROLE else DEFAULT_ANCHOR_EPISODES
|
| 137 |
+
return {
|
| 138 |
+
"track_id": track.track_id,
|
| 139 |
+
"suite": track.suite,
|
| 140 |
+
"benchmark_task": track.benchmark_task,
|
| 141 |
+
"role": track.role,
|
| 142 |
+
"eval_mode": mode,
|
| 143 |
+
"seed": int(seed),
|
| 144 |
+
"episodes": int(episodes),
|
| 145 |
+
"resolution": int(resolution),
|
| 146 |
+
"cameras": tuple(str(camera) for camera in cameras),
|
| 147 |
+
"observation_stack": "rgbd_3cam",
|
| 148 |
+
"action_horizon": 8,
|
| 149 |
+
"action_space": "bimanual_delta_pose",
|
| 150 |
+
"same_test_episodes": True,
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def public_protocol_identity_signature(protocol: dict[str, Any]) -> tuple[object, ...]:
|
| 155 |
+
return (
|
| 156 |
+
protocol["track_id"],
|
| 157 |
+
protocol["suite"],
|
| 158 |
+
protocol["benchmark_task"],
|
| 159 |
+
protocol["role"],
|
| 160 |
+
protocol["seed"],
|
| 161 |
+
protocol["episodes"],
|
| 162 |
+
protocol["resolution"],
|
| 163 |
+
tuple(protocol["cameras"]),
|
| 164 |
+
protocol["observation_stack"],
|
| 165 |
+
protocol["action_horizon"],
|
| 166 |
+
protocol["action_space"],
|
| 167 |
+
protocol["same_test_episodes"],
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def build_target_training_spec(
|
| 172 |
+
*,
|
| 173 |
+
track_id: str,
|
| 174 |
+
model_variant: str,
|
| 175 |
+
seed: int,
|
| 176 |
+
train_demos: int = DEFAULT_TARGET_TRAIN_DEMOS,
|
| 177 |
+
val_demos: int = DEFAULT_TARGET_VAL_DEMOS,
|
| 178 |
+
init_checkpoint_group: str = "shared_public_trunk",
|
| 179 |
+
optimizer: str = "adamw",
|
| 180 |
+
learning_rate: float = 3e-4,
|
| 181 |
+
lr_schedule: str = "cosine",
|
| 182 |
+
batch_size: int = 32,
|
| 183 |
+
augmentations: str = "matched_rgbd_aug_v1",
|
| 184 |
+
early_stopping_metric: str = "val_success",
|
| 185 |
+
max_gradient_steps: int = 20_000,
|
| 186 |
+
unfreeze_scope: str = "matched_trunk_scope",
|
| 187 |
+
dataset_split_id: str | None = None,
|
| 188 |
+
) -> dict[str, Any]:
|
| 189 |
+
track = public_track_by_id(track_id)
|
| 190 |
+
if track.role != TARGET_ROLE:
|
| 191 |
+
raise ValueError(f"Target training spec is only valid for target tracks, got {track_id!r}.")
|
| 192 |
+
return {
|
| 193 |
+
"track_id": track.track_id,
|
| 194 |
+
"suite": track.suite,
|
| 195 |
+
"benchmark_task": track.benchmark_task,
|
| 196 |
+
"model_variant": str(model_variant),
|
| 197 |
+
"seed": int(seed),
|
| 198 |
+
"train_demos": int(train_demos),
|
| 199 |
+
"val_demos": int(val_demos),
|
| 200 |
+
"init_checkpoint_group": str(init_checkpoint_group),
|
| 201 |
+
"optimizer": str(optimizer),
|
| 202 |
+
"learning_rate": float(learning_rate),
|
| 203 |
+
"lr_schedule": str(lr_schedule),
|
| 204 |
+
"batch_size": int(batch_size),
|
| 205 |
+
"augmentations": str(augmentations),
|
| 206 |
+
"early_stopping_metric": str(early_stopping_metric),
|
| 207 |
+
"max_gradient_steps": int(max_gradient_steps),
|
| 208 |
+
"unfreeze_scope": str(unfreeze_scope),
|
| 209 |
+
"dataset_split_id": dataset_split_id or f"{track.track_id}_shared_split_seed{int(seed)}",
|
| 210 |
+
"same_data_policy": True,
|
| 211 |
+
"same_init_policy": True,
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def training_fairness_signature(spec: dict[str, Any]) -> tuple[object, ...]:
    """Project *spec* onto the ordered fields that define training-fairness equivalence.

    Two training specs are "fair" to compare when their signatures are equal;
    the model variant is deliberately excluded from the projection.
    """
    fairness_keys = (
        "track_id",
        "suite",
        "benchmark_task",
        "seed",
        "train_demos",
        "val_demos",
        "init_checkpoint_group",
        "optimizer",
        "learning_rate",
        "lr_schedule",
        "batch_size",
        "augmentations",
        "early_stopping_metric",
        "max_gradient_steps",
        "unfreeze_scope",
        "dataset_split_id",
        "same_data_policy",
        "same_init_policy",
    )
    return tuple(spec[key] for key in fairness_keys)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def default_public_benchmark_manifest() -> dict[str, Any]:
    """Assemble the default manifest describing the public benchmark package.

    Collects every track definition, the target/anchor track id lists, the
    eval modes per role, and the default episode counts and pass thresholds.
    """
    target_ids = [track.track_id for track in public_benchmark_tracks(TARGET_ROLE)]
    anchor_ids = [track.track_id for track in public_benchmark_tracks(ANCHOR_ROLE)]
    defaults = {
        "target_train_demos": DEFAULT_TARGET_TRAIN_DEMOS,
        "target_val_demos": DEFAULT_TARGET_VAL_DEMOS,
        "target_test_episodes": DEFAULT_TARGET_TEST_EPISODES,
        "anchor_episodes": DEFAULT_ANCHOR_EPISODES,
        "resolution": DEFAULT_RESOLUTION,
    }
    thresholds = {
        "anchor_tolerance": DEFAULT_ANCHOR_TOLERANCE,
        "sign_of_life_intervention_rate": DEFAULT_SIGN_OF_LIFE_INTERVENTION,
        "sign_of_life_non_base_selection_rate": DEFAULT_SIGN_OF_LIFE_NON_BASE,
        "sign_of_life_success_gain": DEFAULT_SIGN_OF_LIFE_GAIN,
    }
    return {
        "package_name": "public_reveal_retrieve_package_v1",
        "tracks": [asdict(track) for track in PUBLIC_BENCHMARK_TRACKS],
        "target_track_ids": target_ids,
        "anchor_track_ids": anchor_ids,
        "target_eval_modes": list(TARGET_TRACK_EVAL_MODES),
        "anchor_eval_modes": list(ANCHOR_TRACK_EVAL_MODES),
        "defaults": defaults,
        "thresholds": thresholds,
    }
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def write_default_public_benchmark_manifest(output_path: str | Path) -> Path:
    """Serialize the default manifest to *output_path* as sorted, indented JSON.

    Parent directories are created as needed. Returns the written path.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(default_public_benchmark_manifest(), indent=2, sort_keys=True)
    destination.write_text(payload + "\n", encoding="utf-8")
    return destination
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_maniskill_bridge_retrieval_smoke.py
ADDED
|
@@ -0,0 +1,2037 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from collections import deque
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Sequence
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
CODE_ROOT = Path(__file__).resolve().parents[1]
|
| 19 |
+
if str(CODE_ROOT) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(CODE_ROOT))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _configure_runtime_env() -> None:
|
| 24 |
+
os.environ.setdefault("VK_ICD_FILENAMES", "/workspace/runtime/vulkan/icd.d/nvidia_icd_egl.json")
|
| 25 |
+
os.environ.setdefault("VK_LAYER_PATH", "/workspace/runtime/vulkan/implicit_layer.d")
|
| 26 |
+
os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp/runtime-root")
|
| 27 |
+
os.environ["MS_ASSET_DIR"] = "/workspace/.maniskill"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Run at import time so the env vars are in place before mani_skill/gym import below.
_configure_runtime_env()
|
| 31 |
+
|
| 32 |
+
from eval.run_maniskill_pickclutter_smoke import (
|
| 33 |
+
DEFAULT_INIT_CHECKPOINT,
|
| 34 |
+
HISTORY_STEPS,
|
| 35 |
+
MAX_MACRO_STEPS,
|
| 36 |
+
NUM_APPROACH_TEMPLATES,
|
| 37 |
+
PROPRIO_DIM,
|
| 38 |
+
ROLL_OUT_HORIZON,
|
| 39 |
+
SMOKE_ADAPTER_CONFIDENCE_THRESHOLD,
|
| 40 |
+
SMOKE_RETRIEVE_ACCESS_THRESHOLD,
|
| 41 |
+
SMOKE_RETRIEVE_PERSISTENCE_THRESHOLD,
|
| 42 |
+
SMOKE_RETRIEVE_REOCCLUSION_THRESHOLD,
|
| 43 |
+
SMOKE_RETRIEVE_SUPPORT_THRESHOLD,
|
| 44 |
+
STATE_METRIC_MASK,
|
| 45 |
+
STATE_SUPERVISION_METRICS,
|
| 46 |
+
SUPPORT_MODE_HOLD,
|
| 47 |
+
SUPPORT_MODE_PASSIVE,
|
| 48 |
+
SUPPORT_MODE_TRANSFER,
|
| 49 |
+
_aggregate_epoch,
|
| 50 |
+
_apply_smoke_planner_overrides,
|
| 51 |
+
_history_stack,
|
| 52 |
+
_init_history_entry,
|
| 53 |
+
_load_checkpoint,
|
| 54 |
+
_load_init_bundle,
|
| 55 |
+
_make_loader,
|
| 56 |
+
_save_training_checkpoint,
|
| 57 |
+
)
|
| 58 |
+
from eval.run_public_benchmark_package import summarize_public_benchmark_package
|
| 59 |
+
from models.action_decoder import ChunkDecoderConfig, TASK_INDEX, semantic_macro_chunk
|
| 60 |
+
from train.losses import LossWeights, compute_total_loss
|
| 61 |
+
from train.run_experiment import _load_init_checkpoint, _move_batch_to_device
|
| 62 |
+
from train.trainer import BimanualTrainer, TrainerConfig, apply_trainable_parameter_prefixes, build_policy
|
| 63 |
+
|
| 64 |
+
import gymnasium as gym # noqa: E402
|
| 65 |
+
import mani_skill.envs # noqa: E402
|
| 66 |
+
from mani_skill.utils.structs.pose import Pose # noqa: E402
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# --- Smoke-run configuration constants ---
WORKSPACE_ROOT = Path("/workspace/workspace")
# Version tag folded into dataset/output/report directory names.
SMOKE_VERSION = "bridge_smoke_v1"
# Logical camera slots; the single real sensor view is replicated across them.
CAMERA_NAMES = ("front", "left", "right")
IMAGE_RESOLUTION = 224
DEFAULT_SEED = 17
# Segmentation-pixel fraction of the frame that maps to full (1.0) source visibility.
VIEW_VISIBILITY_SCALE = 0.0125
# Physics settle steps after teleporting the cloth cover over the hidden source.
CLOTH_HIDDEN_SETTLE_STEPS = 25
# Cloth-task success thresholds — presumably consumed by the evaluation loop; confirm there.
CLOTH_SUCCESS_MIN_Y_DELTA = 0.10
CLOTH_SUCCESS_MIN_PLANAR_DELTA = 0.10
CLOTH_SUCCESS_MIN_VISIBILITY = 0.45
# Fixed world-frame spawn coordinates for the hidden source object and the cloth cover.
CLOTH_FIXED_SOURCE_X = -0.235
CLOTH_FIXED_SOURCE_Y = -0.094
CLOTH_FIXED_SOURCE_Z = 0.8748
CLOTH_FIXED_COVER_X = -0.235
CLOTH_FIXED_COVER_Y = -0.075
CLOTH_FIXED_COVER_Z = 0.885
# Candidate count expected from the default chunk-decoder configuration.
EXPECTED_PROPOSAL_CANDIDATES = ChunkDecoderConfig().num_candidates
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass(frozen=True)
class SmokePaths:
    """Filesystem layout for one smoke run."""

    # Directory holding the generated demonstration dataset.
    data_dir: Path
    # Directory for training artifacts (checkpoints, logs).
    output_dir: Path
    # Directory for evaluation reports.
    report_dir: Path
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@dataclass(frozen=True)
class SmokeSpec:
    """Hyperparameters for the bridge-retrieval smoke run (data sizes plus training knobs)."""

    resolution: int = IMAGE_RESOLUTION  # square image side, pixels
    train_episodes: int = 32
    val_episodes: int = 8
    eval_episodes: int = 50
    dataset_seed: int = DEFAULT_SEED  # seed used when generating the dataset
    train_seed: int = DEFAULT_SEED  # seed used for training
    history_steps: int = HISTORY_STEPS
    max_macro_steps: int = MAX_MACRO_STEPS
    batch_size: int = 4
    epochs: int = 6
    num_workers: int = 16
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4

    @property
    def seed(self) -> int:
        # Alias for `train_seed`, for callers that expect a single `seed` field.
        return self.train_seed
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@dataclass(frozen=True)
class BridgeTaskSpec:
    """Static description of one ManiSkill bridge task used as a retrieval proxy."""

    key: str  # short task key, e.g. "bag" / "cloth"
    env_id: str  # public ManiSkill environment id
    track_id: str  # benchmark track identifier
    suite: str  # benchmark suite name
    benchmark_task: str  # task name reported in the benchmark package
    task_name: str  # task name passed to semantic_macro_chunk
    text_prompt: str  # language instruction given to the policy
    mode_order: tuple[str, ...]  # all macro modes, base_action first
    reveal_modes: tuple[str, ...]  # modes that expose the hidden object
    transfer_modes: tuple[str, ...]  # modes forming the transfer phase (e.g. "insert_actor")
    retrieve_modes: tuple[str, ...]  # modes that extract the object
    notes: str  # human-readable description of the setup
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# Registry of the two supported bridge-retrieval proxy tasks, keyed by task key.
# Looked up (case-insensitively) via `_task_spec`.
TASK_SPECS: dict[str, BridgeTaskSpec] = {
    "bag": BridgeTaskSpec(
        key="bag",
        env_id="PutEggplantInBasketScene-v1",
        track_id="bag_track",
        suite="maniskill3",
        benchmark_task="PutEggplantInBasketRetrievalProxy-v1",
        task_name="bag",
        text_prompt="retrieve the target object from inside the basket and stage it outside the basket",
        mode_order=(
            "base_action",
            "pin_left_rim",
            "pin_right_rim",
            "widen_mouth",
            "maintain_mouth",
            "probe_inside",
            "insert_actor",
            "retrieve",
        ),
        reveal_modes=("pin_left_rim", "pin_right_rim", "widen_mouth", "maintain_mouth", "probe_inside"),
        transfer_modes=("insert_actor",),
        retrieve_modes=("retrieve",),
        notes=(
            "Public ManiSkill bridge scene with custom retrieval initialization. The eggplant is placed inside the "
            "basket region and must be pulled back out to a staging zone."
        ),
    ),
    "cloth": BridgeTaskSpec(
        key="cloth",
        env_id="PutSpoonOnTableClothInScene-v1",
        track_id="cloth_track",
        suite="maniskill3",
        benchmark_task="PutSpoonUnderClothRetrievalProxy-v1",
        task_name="cloth",
        text_prompt="reveal the spoon from under the cloth and retrieve it to the open area",
        mode_order=(
            "base_action",
            "lift_edge",
            "separate_layer",
            "stabilize_fold",
            "maintain_lift",
            "insert_actor",
            "retrieve",
        ),
        reveal_modes=("lift_edge", "separate_layer", "stabilize_fold", "maintain_lift"),
        transfer_modes=("insert_actor",),
        retrieve_modes=("retrieve",),
        notes=(
            "Public ManiSkill bridge scene with custom retrieval initialization. The spoon is placed under the "
            "cloth region and must be revealed and extracted to the open side of the table."
        ),
    ),
}
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _task_spec(task: str) -> BridgeTaskSpec:
    """Look up the bridge-task spec for *task*, case- and whitespace-insensitively.

    Raises:
        KeyError: if the normalized key is not in TASK_SPECS.
    """
    key = str(task).strip().lower()
    try:
        return TASK_SPECS[key]
    except KeyError:
        raise KeyError(f"Unsupported task {task!r}. Expected one of {sorted(TASK_SPECS)}.") from None
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _default_paths(task_spec: BridgeTaskSpec) -> SmokePaths:
    """Derive the conventional data/output/report directories for *task_spec*."""
    run_name = f"{task_spec.key}_{SMOKE_VERSION}"
    return SmokePaths(
        data_dir=WORKSPACE_ROOT / "data" / "maniskill_bridge_retrieval" / run_name,
        output_dir=WORKSPACE_ROOT / "outputs" / f"maniskill_{run_name}",
        report_dir=WORKSPACE_ROOT / "reports" / f"maniskill_{run_name}",
    )
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _dataset_artifact_path(data_dir: Path, basename: str, *, dataset_seed: int) -> Path:
    """Return the dataset artifact path, suffixing the stem with the seed unless it is the default."""
    seed = int(dataset_seed)
    if seed == DEFAULT_SEED:
        return data_dir / basename
    name = Path(basename)
    return data_dir / f"{name.stem}_seed{seed}{name.suffix}"
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _np(value: Any, *, dtype: np.dtype | None = None) -> np.ndarray:
|
| 210 |
+
if isinstance(value, np.ndarray):
|
| 211 |
+
array = value
|
| 212 |
+
elif isinstance(value, torch.Tensor):
|
| 213 |
+
array = value.detach().cpu().numpy()
|
| 214 |
+
else:
|
| 215 |
+
array = np.asarray(value)
|
| 216 |
+
if dtype is not None:
|
| 217 |
+
array = array.astype(dtype, copy=False)
|
| 218 |
+
return array
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _vec3(value: Any) -> np.ndarray:
    """Return the first three components of *value* as a flat float32 vector."""
    flat = _np(value, dtype=np.float32).reshape(-1)
    return flat[:3]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def _resize_rgb(rgb: np.ndarray, size: int) -> np.ndarray:
|
| 226 |
+
tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).float()
|
| 227 |
+
resized = F.interpolate(tensor, size=(size, size), mode="bilinear", align_corners=False)
|
| 228 |
+
return resized[0].permute(1, 2, 0).round().clamp(0, 255).to(dtype=torch.uint8).cpu().numpy()
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _resize_single_channel(image: np.ndarray, size: int, *, dtype: np.dtype) -> np.ndarray:
|
| 232 |
+
tensor = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float()
|
| 233 |
+
resized = F.interpolate(tensor, size=(size, size), mode="nearest")
|
| 234 |
+
return resized[0, 0].to(dtype=torch.float32).cpu().numpy().astype(dtype, copy=False)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def _camera_matrix_from_param(param: dict[str, Any], keys: tuple[str, ...], size: int) -> np.ndarray:
    """Return the first matrix found under *keys* in *param*, or a size x size identity.

    Batched (3-D) matrices are unbatched by taking the first entry.
    """
    for key in keys:
        if key in param:
            matrix = _np(param[key], dtype=np.float32)
            return matrix[0] if matrix.ndim == 3 else matrix
    return np.eye(size, dtype=np.float32)


def _camera_intrinsic_from_param(param: dict[str, Any]) -> np.ndarray:
    """Extract the 3x3 camera intrinsic matrix from a ManiSkill sensor-param dict."""
    return _camera_matrix_from_param(param, ("intrinsic_cv", "intrinsic", "cam_intrinsic"), 3)


def _camera_extrinsic_from_param(param: dict[str, Any]) -> np.ndarray:
    """Extract the 4x4 camera extrinsic matrix from a ManiSkill sensor-param dict."""
    return _camera_matrix_from_param(param, ("cam2world_gl", "cam2world", "extrinsic_cv", "extrinsic"), 4)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _extract_sensor_bundle(obs: dict[str, Any], *, resolution: int) -> dict[str, np.ndarray]:
    """Convert a ManiSkill observation into the fixed multi-camera tensor bundle.

    Only one real sensor view is read (the first camera in `sensor_data`); it
    is replicated across every CAMERA_NAMES slot. Depth is not rendered here,
    so the depth planes and their validity masks are all-zero placeholders.
    """
    camera_name = next(iter(obs["sensor_data"].keys()))
    view = obs["sensor_data"][camera_name]
    param = obs["sensor_param"][camera_name]
    rgb = _np(view["rgb"], dtype=np.uint8)
    segmentation = _np(view["segmentation"], dtype=np.int16)
    # Drop a leading batch dim when present, and the trailing channel dim on the seg map.
    rgb = rgb[0] if rgb.ndim == 4 else rgb
    segmentation = segmentation[0] if segmentation.ndim == 4 else segmentation
    segmentation = segmentation[..., 0] if segmentation.ndim == 3 else segmentation
    rgb_resized = _resize_rgb(rgb, resolution)
    seg_resized = _resize_single_channel(segmentation, resolution, dtype=np.int16)
    intrinsic = _camera_intrinsic_from_param(param)
    extrinsic = _camera_extrinsic_from_param(param)
    # Replicate the single real view into every logical camera slot.
    images = np.stack([rgb_resized.copy() for _ in CAMERA_NAMES], axis=0)
    segmentations = np.stack([seg_resized.copy() for _ in CAMERA_NAMES], axis=0)
    # Depth is unavailable: zero planes with all-zero validity masks.
    depths = np.zeros((len(CAMERA_NAMES), 1, resolution, resolution), dtype=np.float32)
    depth_valid = np.zeros_like(depths, dtype=np.float32)
    intrinsics = np.stack([intrinsic.copy() for _ in CAMERA_NAMES], axis=0)
    extrinsics = np.stack([extrinsic.copy() for _ in CAMERA_NAMES], axis=0)
    return {
        "images": images,
        "segmentations": segmentations,
        "depths": depths,
        "depth_valid": depth_valid,
        "camera_intrinsics": intrinsics,
        "camera_extrinsics": extrinsics,
    }
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def _build_proprio(env: gym.Env[Any, Any]) -> np.ndarray:
    """Assemble a flat proprioception vector, padded or truncated to PROPRIO_DIM.

    Layout: joint positions, joint velocities, 7-D TCP pose (xyz + quaternion),
    then a 1-D gripper width (sum of the last two joint positions).
    """
    base = env.unwrapped
    qpos = _np(base.agent.robot.get_qpos(), dtype=np.float32).reshape(-1)
    qvel = _np(base.agent.robot.get_qvel(), dtype=np.float32).reshape(-1)
    ee_pose = base.agent.robot.links_map["ee_gripper_link"].pose
    tcp_pose = np.concatenate([_vec3(ee_pose.p), _np(ee_pose.q, dtype=np.float32).reshape(-1)[:4]], axis=0)
    # NOTE(review): assumes the last two qpos entries are the gripper fingers — confirm per robot.
    gripper_width = qpos[-2:].sum(keepdims=True).astype(np.float32)
    flat = np.concatenate([qpos, qvel, tcp_pose, gripper_width], axis=0)
    if flat.shape[0] >= PROPRIO_DIM:
        # Truncate when the robot exposes more state than the model consumes.
        return flat[:PROPRIO_DIM]
    padded = np.zeros((PROPRIO_DIM,), dtype=np.float32)
    padded[: flat.shape[0]] = flat
    return padded
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def _source_actor(env: gym.Env[Any, Any]) -> Any:
    """Return the actor handle for the task's source object."""
    unwrapped = env.unwrapped
    return unwrapped.objs[unwrapped.source_obj_name]
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def _target_actor(env: gym.Env[Any, Any]) -> Any:
    """Return the actor handle for the task's target object."""
    unwrapped = env.unwrapped
    return unwrapped.objs[unwrapped.target_obj_name]
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _source_position(env: gym.Env[Any, Any]) -> np.ndarray:
    """World-frame XYZ position of the source object."""
    actor = _source_actor(env)
    return _vec3(actor.pose.p)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def _target_position(env: gym.Env[Any, Any]) -> np.ndarray:
    """World-frame XYZ position of the target object."""
    actor = _target_actor(env)
    return _vec3(actor.pose.p)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def _ee_position(env: gym.Env[Any, Any]) -> np.ndarray:
    """World-frame XYZ position of the end-effector gripper link."""
    gripper_link = env.unwrapped.agent.robot.links_map["ee_gripper_link"]
    return _vec3(gripper_link.pose.p)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def _act_from_world_delta(delta_xyz: Sequence[float]) -> np.ndarray:
|
| 320 |
+
delta = np.asarray(delta_xyz, dtype=np.float32).reshape(3)
|
| 321 |
+
return np.asarray([-delta[0], -delta[1], delta[2]], dtype=np.float32)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _step_action(env: gym.Env[Any, Any], delta_xyz: Sequence[float], *, grip: float) -> None:
    """Step the env once with a clipped end-effector translation plus gripper command."""
    command = np.zeros((1, 7), dtype=np.float32)
    command[0, :3] = np.clip(_act_from_world_delta(delta_xyz), -0.02, 0.02)
    command[0, 6] = float(np.clip(grip, -1.0, 1.0))
    env.step(command)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _hold(env: gym.Env[Any, Any], *, steps: int, grip: float) -> None:
    """Hold position for *steps* env steps while commanding a constant grip."""
    clipped_grip = float(np.clip(grip, -1.0, 1.0))
    for _ in range(int(steps)):
        command = np.zeros((1, 7), dtype=np.float32)
        command[0, 6] = clipped_grip
        env.step(command)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def _move_ee(env: gym.Env[Any, Any], goal_xyz: Sequence[float], *, grip: float, max_steps: int = 60, tol: float = 0.003) -> dict[str, Any]:
    """Servo the end-effector toward *goal_xyz*, stopping once within *tol* meters.

    Returns the final end-effector position (which may not have converged if
    *max_steps* was exhausted).
    """
    goal = np.asarray(goal_xyz, dtype=np.float32).reshape(3)
    for _ in range(int(max_steps)):
        remaining = goal - _ee_position(env)
        if float(np.linalg.norm(remaining)) <= float(tol):
            break
        _step_action(env, remaining, grip=grip)
    return {"ee_position": _ee_position(env)}
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def _repeat_world_delta(env: gym.Env[Any, Any], delta_xyz: Sequence[float], *, grip: float, steps: int) -> None:
    """Apply the same world-frame delta for *steps* consecutive env steps."""
    step_delta = np.asarray(delta_xyz, dtype=np.float32).reshape(3)
    for _ in range(int(steps)):
        _step_action(env, step_delta, grip=grip)
+
|
| 354 |
+
|
| 355 |
+
def _snapshot_env(env: gym.Env[Any, Any]) -> dict[str, Any]:
    """Capture the full simulator state so it can be restored later."""
    state = env.unwrapped.get_state_dict()
    return {"state_dict": state}
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def _restore_env(env: gym.Env[Any, Any], snapshot: dict[str, Any]) -> None:
    """Restore simulator state previously captured by `_snapshot_env`."""
    state = snapshot["state_dict"]
    env.unwrapped.set_state_dict(state)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def _sync_env_state(src_env: gym.Env[Any, Any], dst_env: gym.Env[Any, Any]) -> None:
    """Copy simulator state from *src_env* into *dst_env*."""
    snapshot = _snapshot_env(src_env)
    _restore_env(dst_env, snapshot)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _canonical_chunks(task_spec: BridgeTaskSpec) -> dict[str, np.ndarray]:
    """Build per-mode prototype action chunks for nearest-prototype mode classification.

    The base mode is the all-zero chunk; every other mode in `mode_order` is the
    semantic macro chunk generated from that zero chunk.
    """
    zero_chunk = torch.zeros((1, 8, 14), dtype=torch.float32)
    prototypes: dict[str, np.ndarray] = {"base_action": zero_chunk.squeeze(0).numpy().astype(np.float32)}
    for mode_name in task_spec.mode_order[1:]:
        macro = semantic_macro_chunk(zero_chunk, task_name=task_spec.task_name, mode_name=mode_name)
        prototypes[mode_name] = macro.squeeze(0).cpu().numpy().astype(np.float32)
    return prototypes
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _classify_mode_from_chunk(chunk: np.ndarray, canonical_chunks: dict[str, np.ndarray]) -> str:
|
| 377 |
+
candidate = np.asarray(chunk, dtype=np.float32)
|
| 378 |
+
distances = {
|
| 379 |
+
mode_name: float(np.mean(np.abs(candidate - prototype)))
|
| 380 |
+
for mode_name, prototype in canonical_chunks.items()
|
| 381 |
+
}
|
| 382 |
+
return min(distances, key=distances.get)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def _rng_for_seed(seed: int) -> np.random.Generator:
|
| 386 |
+
return np.random.default_rng(int(seed) + 31)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def _initialize_proxy_state(env: gym.Env[Any, Any], task_spec: BridgeTaskSpec, *, episode_seed: int) -> None:
    """Teleport the scene into its scripted start configuration for one episode.

    Bag tasks place the source slightly above the bag center with a small
    seed-driven jitter; all other tasks use fixed source coordinates, and the
    cloth task additionally places the cover at fixed coordinates and lets the
    simulation settle before returning.
    """
    base = env.unwrapped  # NOTE(review): assigned but never used below — candidate for removal.
    rng = _rng_for_seed(episode_seed)
    source = _source_actor(env)
    source_pose = source.pose
    # Keep the actor's current orientation; only the position is overwritten.
    source_q = _np(source_pose.q, dtype=np.float32).reshape(-1)[:4]
    if task_spec.key == "bag":
        center = _target_position(env)
        # Jittered start just above the bag center (metres), deterministic per seed.
        start = center + np.asarray(
            [
                rng.uniform(-0.006, 0.006),
                rng.uniform(-0.010, 0.004),
                0.010 + rng.uniform(-0.002, 0.002),
            ],
            dtype=np.float32,
        )
    else:
        # ``cover`` is only consumed by the cloth branch below; non-bag keys
        # always take this branch, so it is bound whenever that branch runs.
        cover = _target_actor(env)
        start = np.asarray([CLOTH_FIXED_SOURCE_X, CLOTH_FIXED_SOURCE_Y, CLOTH_FIXED_SOURCE_Z], dtype=np.float32)
    source.set_pose(Pose.create_from_pq(p=start, q=source_q))
    if task_spec.key == "cloth":
        # Brief settle so the source rests before the cover is dropped on top.
        _hold(env, steps=8, grip=1.0)
        cover_pose = cover.pose
        cover_q = _np(cover_pose.q, dtype=np.float32).reshape(-1)[:4]
        cover_start = np.asarray([CLOTH_FIXED_COVER_X, CLOTH_FIXED_COVER_Y, CLOTH_FIXED_COVER_Z], dtype=np.float32)
        cover.set_pose(Pose.create_from_pq(p=cover_start, q=cover_q))
        # Longer settle so the cover drapes over (hides) the source.
        _hold(env, steps=CLOTH_HIDDEN_SETTLE_STEPS, grip=1.0)
        return
    _hold(env, steps=5, grip=1.0)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def _source_visibility(obs_bundle: dict[str, np.ndarray], actor_id: int) -> float:
    """Fraction of segmentation pixels belonging to ``actor_id``, rescaled into [0, 1]."""
    segmentation = obs_bundle["segmentations"]
    pixel_fraction = float(np.mean(segmentation == int(actor_id)))
    scaled = pixel_fraction / VIEW_VISIBILITY_SCALE
    return float(np.clip(scaled, 0.0, 1.0))
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def _all_positions(env: gym.Env[Any, Any], task_spec: BridgeTaskSpec) -> dict[str, np.ndarray]:
    """Collect the reference world positions for the current episode.

    Always returns "source" and "target"; cloth tasks also get a "cover" key.
    NOTE(review): "cover" reuses ``_target_position`` — for the cloth task the
    target actor appears to be the cover itself; confirm that is intended.
    """
    positions = {"source": _source_position(env), "target": _target_position(env)}
    if task_spec.key == "cloth":
        positions["cover"] = _target_position(env)
    return positions
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def _bag_progress(env: gym.Env[Any, Any]) -> float:
    """Weighted, clipped proxy in [0, 1] for bag-retrieval progress.

    Rewards rightward (+x), forward-pull (-y), and upward (+z) displacement
    of the source relative to the bag center, plus overall planar separation.
    """
    source = _source_position(env)
    center = _target_position(env)
    offset = source - center
    x_shift = max(offset[0], 0.0)
    y_pull = max(-offset[1], 0.0)
    z_lift = max(offset[2], 0.0)
    planar = float(np.linalg.norm(offset[:2]))
    score = (
        0.35 * (x_shift / 0.05)
        + 0.30 * (y_pull / 0.18)
        + 0.20 * (z_lift / 0.12)
        + 0.15 * (planar / 0.12)
    )
    return float(np.clip(score, 0.0, 1.0))
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def _bag_success(env: gym.Env[Any, Any]) -> bool:
    """True when the source has separated clearly from the bag center.

    Requires minimum planar separation plus at least one directional escape
    (right of center, pulled forward, or lifted above it).
    """
    source = _source_position(env)
    center = _target_position(env)
    planar = float(np.linalg.norm((source - center)[:2]))
    moved_right = source[0] >= center[0] + 0.020
    pulled_forward = source[1] <= center[1] - 0.050
    lifted = source[2] >= center[2] + 0.050
    return bool(planar >= 0.035 and (moved_right or pulled_forward or lifted))
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def _cloth_progress(
    env: gym.Env[Any, Any],
    *,
    start_positions: dict[str, np.ndarray],
    current_visibility: float,
) -> float:
    """Cloth-task progress in [0, 1]: forward shift, planar shift, and visibility blend."""
    source = _source_position(env)
    source_start = start_positions["source"]
    displacement = source - source_start
    y_shift = max(displacement[1], 0.0)
    planar = float(np.linalg.norm(displacement[:2]))
    score = 0.45 * (y_shift / 0.16) + 0.35 * (planar / 0.16) + 0.20 * current_visibility
    return float(np.clip(score, 0.0, 1.0))
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def _cloth_success(
    env: gym.Env[Any, Any],
    *,
    start_positions: dict[str, np.ndarray],
    current_visibility: float,
) -> bool:
    """True once the source moved far enough from its start and is visible enough."""
    source = _source_position(env)
    source_start = start_positions["source"]
    planar = float(np.linalg.norm((source - source_start)[:2]))
    forward_enough = source[1] >= source_start[1] + CLOTH_SUCCESS_MIN_Y_DELTA
    moved_enough = planar >= CLOTH_SUCCESS_MIN_PLANAR_DELTA
    visible_enough = current_visibility >= CLOTH_SUCCESS_MIN_VISIBILITY
    return bool(forward_enough and moved_enough and visible_enough)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def _candidate_metrics(
    env: gym.Env[Any, Any],
    *,
    task_spec: BridgeTaskSpec,
    start_positions: dict[str, np.ndarray],
    current_obs_bundle: dict[str, np.ndarray] | None = None,
) -> dict[str, float]:
    """Score the current sim state with task-specific proxy metrics.

    Returns "retrieval_success", "disturbance", "visibility", "clearance",
    and "progress", all in [0, 1].  Visibility stays 0.0 when no observation
    bundle is provided.
    """
    source_actor = _source_actor(env)
    # -1 marks an unknown actor id; the segmentation match then yields ~0 visibility.
    actor_id = int(getattr(source_actor, "per_scene_id", -1))
    visibility = 0.0
    if current_obs_bundle is not None:
        visibility = _source_visibility(current_obs_bundle, actor_id)
    if task_spec.key == "bag":
        progress = _bag_progress(env)
        success = float(_bag_success(env))
        disturbance = 0.0  # the bag task defines no disturbance proxy
        access = float(np.clip(0.65 * progress + 0.35 * visibility, 0.0, 1.0))
    else:
        progress = _cloth_progress(env, start_positions=start_positions, current_visibility=visibility)
        success = float(_cloth_success(env, start_positions=start_positions, current_visibility=visibility))
        cloth_start = start_positions["cover"]
        cloth_now = _target_position(env)
        # Planar drift of the cover beyond a 0.24 allowance counts as disturbance,
        # saturating over the next 0.14 of drift.
        cloth_displacement = float(np.linalg.norm((cloth_now - cloth_start)[:2]))
        disturbance = float(np.clip(max(cloth_displacement - 0.24, 0.0) / 0.14, 0.0, 1.0))
        access = float(np.clip(0.55 * progress + 0.45 * visibility, 0.0, 1.0))
    return {
        "retrieval_success": success,
        "disturbance": disturbance,
        "visibility": visibility,
        "clearance": access,
        "progress": progress,
    }
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def _execute_bag_mode(env: gym.Env[Any, Any], mode_name: str) -> None:
    """Run the scripted end-effector routine for one bag-task mode.

    All waypoints are world-frame offsets (metres) from the source or the bag
    center captured at entry; ``grip`` follows the sign convention used by
    ``_move_ee``/``_hold`` elsewhere in this file.

    Raises:
        KeyError: if ``mode_name`` is not a recognized bag mode.
    """
    center = _target_position(env)
    source = _source_position(env)
    if mode_name == "retrieve":
        # Hover above the source, descend, close the gripper, lift, carry away (-y).
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.08], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.010], dtype=np.float32), grip=1.0, tol=0.002)
        _hold(env, steps=10, grip=-1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.12], dtype=np.float32), grip=-1.0)
        _move_ee(env, source + np.asarray([0.0, -0.18, 0.12], dtype=np.float32), grip=-1.0)
        return
    if mode_name == "insert_actor":
        # Approach the source and press down briefly with an open grip.
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.06], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.018], dtype=np.float32), grip=1.0, tol=0.002)
        _hold(env, steps=4, grip=1.0)
        return
    if mode_name == "probe_inside":
        # Descend near the source and sweep inward repeatedly.
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.026], dtype=np.float32), grip=1.0, tol=0.002)
        _repeat_world_delta(env, np.asarray([0.010, -0.004, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name == "widen_mouth":
        # Enter left of the bag mouth and drag toward +x to widen it.
        _move_ee(env, center + np.asarray([-0.04, 0.01, 0.06], dtype=np.float32), grip=1.0)
        _move_ee(env, center + np.asarray([-0.03, 0.01, 0.028], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.012, -0.004, 0.0], dtype=np.float32), grip=1.0, steps=12)
        return
    if mode_name == "pin_left_rim":
        # Press near the left rim with a small inward nudge.
        _move_ee(env, center + np.asarray([-0.03, 0.01, 0.06], dtype=np.float32), grip=1.0)
        _move_ee(env, center + np.asarray([-0.03, 0.01, 0.028], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.006, -0.003, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name == "pin_right_rim":
        # Mirror of pin_left_rim on the +x side.
        _move_ee(env, center + np.asarray([0.03, 0.01, 0.06], dtype=np.float32), grip=1.0)
        _move_ee(env, center + np.asarray([0.03, 0.01, 0.028], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([-0.006, -0.003, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name in {"maintain_mouth", "base_action"}:
        # Passive: hover above the bag center and hold.
        _move_ee(env, center + np.asarray([0.0, 0.0, 0.09], dtype=np.float32), grip=1.0, max_steps=30, tol=0.006)
        _hold(env, steps=3, grip=1.0)
        return
    raise KeyError(f"Unsupported bag mode {mode_name!r}.")
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def _execute_cloth_mode(env: gym.Env[Any, Any], mode_name: str) -> None:
    """Run the scripted end-effector routine for one cloth-task mode.

    Waypoints are world-frame offsets (metres) from either the cloth (the
    target actor) or the source, both captured at entry.

    Raises:
        KeyError: if ``mode_name`` is not a recognized cloth mode.
    """
    cloth = _target_position(env)
    source = _source_position(env)
    if mode_name == "retrieve":
        # Hover over the source, descend, close the gripper, lift, carry away (+y).
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.010], dtype=np.float32), grip=1.0, tol=0.002)
        _hold(env, steps=10, grip=-1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.10], dtype=np.float32), grip=-1.0)
        _move_ee(env, source + np.asarray([0.0, 0.16, 0.10], dtype=np.float32), grip=-1.0)
        return
    if mode_name == "insert_actor":
        # Approach the source and press down briefly with an open grip.
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, source + np.asarray([0.0, 0.0, 0.018], dtype=np.float32), grip=1.0, tol=0.002)
        _hold(env, steps=4, grip=1.0)
        return
    if mode_name == "lift_edge":
        # Engage the near (-y) cloth edge and drag it toward +y.
        _move_ee(env, cloth + np.asarray([0.0, -0.03, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, cloth + np.asarray([0.0, -0.03, 0.015], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.0, 0.006, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name == "separate_layer":
        # Engage left of the cloth and sweep diagonally inward.
        _move_ee(env, cloth + np.asarray([-0.04, 0.0, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, cloth + np.asarray([-0.04, 0.0, 0.015], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.008, 0.002, 0.0], dtype=np.float32), grip=1.0, steps=10)
        return
    if mode_name == "stabilize_fold":
        # Engage the far (+y) side and pull back toward -y to settle the fold.
        _move_ee(env, cloth + np.asarray([0.0, 0.03, 0.05], dtype=np.float32), grip=1.0)
        _move_ee(env, cloth + np.asarray([0.0, 0.03, 0.015], dtype=np.float32), grip=1.0, tol=0.003)
        _repeat_world_delta(env, np.asarray([0.0, -0.006, 0.0], dtype=np.float32), grip=1.0, steps=8)
        return
    if mode_name in {"maintain_lift", "base_action"}:
        # Passive: hover beyond the cloth and hold.
        _move_ee(env, cloth + np.asarray([0.0, 0.06, 0.07], dtype=np.float32), grip=1.0, max_steps=30, tol=0.006)
        _hold(env, steps=3, grip=1.0)
        return
    raise KeyError(f"Unsupported cloth mode {mode_name!r}.")
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def _execute_mode(env: gym.Env[Any, Any], task_spec: BridgeTaskSpec, mode_name: str) -> None:
    """Dispatch mode execution to the task-specific scripted controller."""
    executor = _execute_bag_mode if task_spec.key == "bag" else _execute_cloth_mode
    executor(env, mode_name)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def _mode_support_mode(task_spec: BridgeTaskSpec, mode_name: str, current_support_mode: int) -> int:
    """Map a mode name to its support-mode constant; keep the current mode otherwise.

    Groups are checked in the same precedence as the original chain:
    reveal -> hold, transfer -> transfer, retrieve -> passive.
    """
    for mode_group, group_support_mode in (
        (task_spec.reveal_modes, SUPPORT_MODE_HOLD),
        (task_spec.transfer_modes, SUPPORT_MODE_TRANSFER),
        (task_spec.retrieve_modes, SUPPORT_MODE_PASSIVE),
    ):
        if mode_name in mode_group:
            return group_support_mode
    return int(current_support_mode)
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def _mode_progress_schedule(task_spec: BridgeTaskSpec, mode_name: str) -> np.ndarray:
    """Return the five-step float32 progress schedule for a mode family.

    Reveal, transfer, and retrieve modes each get their own ramp; any other
    mode (e.g. base_action) gets the slow partial ramp.
    """
    if mode_name in task_spec.reveal_modes:
        ramp = [0.18, 0.38, 0.62, 0.84, 1.0]
    elif mode_name in task_spec.transfer_modes:
        ramp = [0.22, 0.44, 0.66, 0.86, 1.0]
    elif mode_name in task_spec.retrieve_modes:
        ramp = [0.34, 0.56, 0.76, 0.92, 1.0]
    else:
        ramp = [0.10, 0.22, 0.34, 0.44, 0.54]
    return np.asarray(ramp, dtype=np.float32)
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def _scalar_rollout(start: float, end: float, schedule: np.ndarray) -> np.ndarray:
    """Interpolate from ``start`` to ``end`` along ``schedule``, clipped to [0, 1], as float32."""
    weights = np.asarray(schedule)
    blended = (1.0 - weights) * float(start) + weights * float(end)
    return np.clip(blended, 0.0, 1.0).astype(np.float32)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def _current_state_targets(
    task_spec: BridgeTaskSpec,
    *,
    obs_bundle: dict[str, np.ndarray],
    candidate_metrics: Sequence[dict[str, float]],
    episode_start_positions: dict[str, np.ndarray],
    selected_mode: str,
    env: gym.Env[Any, Any],
) -> dict[str, Any]:
    """Derive supervision targets for the current state from live sim metrics
    plus the per-mode candidate evaluations.

    ``candidate_metrics`` must be ordered like ``task_spec.mode_order`` (index 0
    is "base_action").  Returns the support mode, per-strategy corridor
    feasibility and persistence horizons, the disturbance cost, an
    intervention-confidence target, the metric mask, and all scalar task
    metrics as np.float32 values.
    """
    metrics_by_name = {mode_name: payload for mode_name, payload in zip(task_spec.mode_order, candidate_metrics)}
    # Re-score the *current* (unperturbed) state with the live observation.
    current_metrics = _candidate_metrics(
        env,
        task_spec=task_spec,
        start_positions=episode_start_positions,
        current_obs_bundle=obs_bundle,
    )
    current_disturbance = float(np.clip(current_metrics["disturbance"], 0.0, 1.0))
    current_visibility = float(np.clip(current_metrics["visibility"], 0.0, 1.0))
    current_clearance = float(np.clip(current_metrics["clearance"], 0.0, 1.0))
    current_progress = float(np.clip(current_metrics["progress"], 0.0, 1.0))
    # Derived scalar qualities, all kept in [0, 1] by construction.
    base_gap = float(np.clip(max(current_clearance, current_progress), 0.0, 1.0))
    support_stability = float(np.clip(1.0 - 0.5 * current_disturbance, 0.0, 1.0))
    hold_quality = float(np.clip(0.5 * (support_stability + max(current_clearance, current_progress)), 0.0, 1.0))
    opening_quality = float(np.clip(0.55 * current_progress + 0.25 * current_clearance + 0.20 * current_visibility, 0.0, 1.0))
    actor_feasibility = float(np.clip(0.6 * current_clearance + 0.4 * max(current_visibility, current_progress), 0.0, 1.0))
    reocclusion_rate = float(np.clip(1.0 - max(current_clearance, current_visibility), 0.0, 1.0))
    insertable_actor_corridor = float(np.clip(0.6 * actor_feasibility + 0.4 * base_gap, 0.0, 1.0))
    insertion_corridor = float(np.clip(0.5 * actor_feasibility + 0.5 * base_gap, 0.0, 1.0))
    layer_separation = float(np.clip(0.7 * base_gap + 0.3 * actor_feasibility, 0.0, 1.0))
    fold_preservation = float(np.clip(1.0 - current_disturbance, 0.0, 1.0))
    lift_too_much_risk = float(np.clip(current_disturbance + 0.5 * max(base_gap - 0.5, 0.0), 0.0, 1.0))
    task_metrics = {
        "opening_quality": opening_quality,
        "actor_feasibility_score": actor_feasibility,
        # gap_width is a synthetic width in [0.03, 0.24], linear in base_gap.
        "gap_width": float(0.03 + 0.21 * base_gap),
        "damage_proxy": current_disturbance,
        "release_collapse_rate": reocclusion_rate,
        "target_visibility_confidence": current_visibility,
        "insertable_actor_corridor": insertable_actor_corridor,
        "insertion_corridor": insertion_corridor,
        "hold_quality": hold_quality,
        "layer_separation_quality": layer_separation,
        "fold_preservation": fold_preservation,
        "top_layer_stability": support_stability,
        "lift_too_much_risk": lift_too_much_risk,
    }

    base_metrics = metrics_by_name["base_action"]
    insert_metrics = metrics_by_name["insert_actor"]
    retrieve_metrics = metrics_by_name["retrieve"]
    # Reveal scores take the best over all reveal-mode candidates.
    reveal_candidates = [metrics_by_name[mode_name] for mode_name in task_spec.reveal_modes]
    reveal_access = max(candidate["candidate_actor_feasibility_auc"] for candidate in reveal_candidates)
    reveal_reveal = max(candidate["candidate_reveal_achieved"] for candidate in reveal_candidates)
    reveal_hold = max(candidate["candidate_hold_persistence"] for candidate in reveal_candidates)
    reveal_visibility = max(candidate["candidate_visibility_integral"] for candidate in reveal_candidates)

    # Corridor feasibility per strategy (reveal / transfer / passive-retrieve).
    reveal_corridor = float(
        np.clip(
            0.45 * opening_quality
            + 0.30 * reveal_access
            + 0.15 * reveal_reveal
            + 0.10 * reveal_visibility
            - 0.10 * current_disturbance,
            0.0,
            1.0,
        )
    )
    transfer_corridor = float(
        np.clip(
            0.45 * insertable_actor_corridor
            + 0.30 * insert_metrics["candidate_actor_feasibility_auc"]
            + 0.15 * insert_metrics["candidate_reveal_achieved"]
            + 0.10 * insert_metrics["candidate_visibility_integral"]
            - 0.15 * current_disturbance,
            0.0,
            1.0,
        )
    )
    passive_corridor = float(
        np.clip(
            0.55 * retrieve_metrics["candidate_retrieval_success"]
            + 0.20 * retrieve_metrics["candidate_actor_feasibility_auc"]
            + 0.15 * current_progress
            + 0.10 * current_clearance
            - 0.10 * current_disturbance,
            0.0,
            1.0,
        )
    )
    # (3, NUM_APPROACH_TEMPLATES): one corridor row per strategy, broadcast
    # over approach templates.
    corridor_feasible = np.stack(
        [
            np.full((NUM_APPROACH_TEMPLATES,), reveal_corridor, dtype=np.float32),
            np.full((NUM_APPROACH_TEMPLATES,), transfer_corridor, dtype=np.float32),
            np.full((NUM_APPROACH_TEMPLATES,), passive_corridor, dtype=np.float32),
        ],
        axis=0,
    )
    # Expected persistence (in rollout steps) for each of the three strategies.
    persistence_horizon = np.asarray(
        [
            ROLL_OUT_HORIZON * float(np.clip(0.35 * hold_quality + 0.35 * reveal_hold + 0.30 * reveal_corridor, 0.0, 1.0)),
            ROLL_OUT_HORIZON
            * float(
                np.clip(
                    0.30 * hold_quality + 0.35 * insert_metrics["candidate_hold_persistence"] + 0.35 * transfer_corridor,
                    0.0,
                    1.0,
                )
            ),
            ROLL_OUT_HORIZON
            * float(
                np.clip(
                    0.25 * hold_quality + 0.35 * retrieve_metrics["candidate_hold_persistence"] + 0.40 * passive_corridor,
                    0.0,
                    1.0,
                )
            ),
        ],
        dtype=np.float32,
    )
    # Support-mode label: expert choice dominates, utility margins break ties.
    retrieve_margin = float(retrieve_metrics["candidate_utility"] - base_metrics["candidate_utility"])
    insert_margin = float(insert_metrics["candidate_utility"] - base_metrics["candidate_utility"])
    if selected_mode == "retrieve" or (retrieve_metrics["candidate_retrieval_success"] >= 0.5 and retrieve_margin >= 0.12):
        support_mode = SUPPORT_MODE_PASSIVE
    elif selected_mode == "insert_actor" or (insert_margin >= 0.12 and transfer_corridor >= 0.35):
        support_mode = SUPPORT_MODE_TRANSFER
    elif selected_mode in task_spec.reveal_modes:
        support_mode = SUPPORT_MODE_HOLD
    elif selected_mode == "base_action":
        support_mode = SUPPORT_MODE_PASSIVE if passive_corridor >= 0.55 and retrieve_margin >= 0.03 else SUPPORT_MODE_HOLD
    else:
        support_mode = SUPPORT_MODE_HOLD
    # Intervention is "warranted" when some non-base mode clearly beats base.
    best_non_base_utility = max(float(payload["candidate_utility"]) for payload in candidate_metrics[1:])
    intervention_warranted = selected_mode != "base_action" and best_non_base_utility >= float(base_metrics["candidate_utility"]) + 0.12
    return {
        "support_mode": int(support_mode),
        "corridor_feasible": corridor_feasible,
        "persistence_horizon": persistence_horizon,
        "disturbance_cost": np.float32(current_disturbance),
        "state_confidence_target": np.float32(1.0 if intervention_warranted else 0.0),
        "task_metric_mask": STATE_METRIC_MASK.copy(),
        **{metric_name: np.float32(metric_value) for metric_name, metric_value in task_metrics.items()},
    }
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
def _candidate_rollout_targets(
    task_spec: BridgeTaskSpec,
    *,
    mode_name: str,
    state_targets: dict[str, Any],
    candidate_payload: dict[str, float],
) -> dict[str, np.ndarray]:
    """Synthesize per-step rollout supervision for one candidate mode.

    Interpolates scalar qualities from their current-state values
    (``state_targets``) toward the candidate's post-execution values
    (``candidate_payload``) along the mode family's progress schedule, then
    expands them into the map/corridor tensors the model is trained against.
    """
    schedule = _mode_progress_schedule(task_spec, mode_name)
    # Start-of-rollout scalars, read from the current-state targets.
    start_visibility = float(state_targets["target_visibility_confidence"])
    start_access = float(state_targets["actor_feasibility_score"])
    start_persistence = float(np.clip(state_targets["hold_quality"], 0.0, 1.0))
    start_support = float(np.clip(state_targets["top_layer_stability"], 0.0, 1.0))
    start_reocclusion = float(np.clip(state_targets["release_collapse_rate"], 0.0, 1.0))
    start_disturbance = float(np.clip(state_targets["disturbance_cost"], 0.0, 1.0))
    start_clearance = float(np.clip(state_targets["actor_feasibility_score"], 0.0, 1.0))
    start_grasp = float(np.clip(max(start_visibility, start_access), 0.0, 1.0))

    # End-of-rollout scalars, read from the candidate's immediate metrics.
    end_visibility = float(np.clip(candidate_payload["candidate_immediate_visibility"], 0.0, 1.0))
    end_access = float(np.clip(candidate_payload["candidate_immediate_access"], 0.0, 1.0))
    end_progress = float(np.clip(candidate_payload["candidate_immediate_progress"], 0.0, 1.0))
    end_disturbance = float(np.clip(candidate_payload["candidate_immediate_disturbance"], 0.0, 1.0))
    end_support = float(np.clip(candidate_payload["candidate_immediate_support_stability"], 0.0, 1.0))
    end_persistence = float(np.clip(candidate_payload["candidate_immediate_hold_persistence"], 0.0, 1.0))
    end_reocclusion = float(np.clip(candidate_payload["candidate_immediate_reocclusion"], 0.0, 1.0))
    end_clearance = float(np.clip(max(end_access, end_progress), 0.0, 1.0))
    end_grasp = float(np.clip(max(end_visibility, 0.5 * end_access + 0.5 * end_progress), 0.0, 1.0))

    # Transfer/retrieve modes assume some quality is already present at the
    # rollout start, so their start values are floored by the end values.
    if mode_name in task_spec.transfer_modes:
        start_visibility = max(start_visibility, 0.35 * end_visibility)
        start_access = max(start_access, 0.40 * end_access)
        start_persistence = max(start_persistence, 0.45 * end_persistence)
        start_support = max(start_support, 0.50 * end_support)
    elif mode_name in task_spec.retrieve_modes:
        start_visibility = max(start_visibility, 0.55 * end_visibility)
        start_access = max(start_access, 0.70 * end_access)
        start_persistence = max(start_persistence, 0.65 * end_persistence)
        start_support = max(start_support, 0.65 * end_support)
        start_reocclusion = min(start_reocclusion, max(0.4 * end_reocclusion, 0.0))

    visibility = _scalar_rollout(start_visibility, end_visibility, schedule)
    access = _scalar_rollout(start_access, end_access, schedule)
    persistence = _scalar_rollout(start_persistence, end_persistence, schedule)
    support = _scalar_rollout(start_support, end_support, schedule)
    reocclusion = _scalar_rollout(start_reocclusion, end_reocclusion, schedule)
    disturbance = _scalar_rollout(start_disturbance, end_disturbance, schedule)
    clearance = _scalar_rollout(start_clearance, end_clearance, schedule)
    grasp = _scalar_rollout(start_grasp, end_grasp, schedule)
    # Per-step corridor feasibility for the three strategies.
    reveal_corridor = np.clip(0.38 * visibility + 0.34 * access + 0.22 * support - 0.12 * disturbance, 0.0, 1.0)
    transfer_corridor = np.clip(
        0.30 * visibility + 0.38 * access + 0.18 * persistence + 0.14 * support - 0.12 * disturbance,
        0.0,
        1.0,
    )
    passive_corridor = np.clip(
        0.22 * visibility + 0.42 * access + 0.20 * persistence + 0.16 * grasp - 0.14 * disturbance - 0.10 * reocclusion,
        0.0,
        1.0,
    )
    # Bias corridors toward the family the candidate mode belongs to.
    if mode_name in task_spec.reveal_modes:
        reveal_corridor = np.clip(reveal_corridor + 0.14, 0.0, 1.0)
        passive_corridor = np.clip(0.75 * passive_corridor, 0.0, 1.0)
    elif mode_name in task_spec.transfer_modes:
        transfer_corridor = np.clip(transfer_corridor + 0.16, 0.0, 1.0)
    elif mode_name in task_spec.retrieve_modes:
        passive_corridor = np.clip(passive_corridor + 0.20, 0.0, 1.0)
        reveal_corridor = np.clip(0.60 * reveal_corridor, 0.0, 1.0)
    else:
        reveal_corridor = np.clip(0.85 * reveal_corridor, 0.0, 1.0)
        transfer_corridor = np.clip(0.75 * transfer_corridor, 0.0, 1.0)
        passive_corridor = np.clip(0.80 * passive_corridor, 0.0, 1.0)
    # (steps, 3, NUM_APPROACH_TEMPLATES): corridors stacked on axis 1 and
    # broadcast across approach templates.
    corridor_feasible = np.stack(
        [
            np.repeat(reveal_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
            np.repeat(transfer_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
            np.repeat(passive_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
        ],
        axis=1,
    ).astype(np.float32)
    # (steps, 3): expected persistence per strategy, in rollout steps.
    persistence_horizon = np.stack(
        [
            np.clip(ROLL_OUT_HORIZON * (0.55 * reveal_corridor + 0.45 * support), 0.0, float(ROLL_OUT_HORIZON)),
            np.clip(ROLL_OUT_HORIZON * (0.50 * transfer_corridor + 0.50 * persistence), 0.0, float(ROLL_OUT_HORIZON)),
            np.clip(ROLL_OUT_HORIZON * (0.55 * passive_corridor + 0.45 * persistence), 0.0, float(ROLL_OUT_HORIZON)),
        ],
        axis=1,
    ).astype(np.float32)
    # base_action keeps the current support mode; other modes map via family.
    support_mode = np.full((ROLL_OUT_HORIZON,), _mode_support_mode(task_spec, mode_name, int(state_targets["support_mode"])), dtype=np.int64)
    if mode_name == "base_action":
        support_mode[:] = int(state_targets["support_mode"])
    return {
        "candidate_rollout_support_mode": support_mode,
        "candidate_rollout_corridor_feasible": corridor_feasible,
        "candidate_rollout_persistence_horizon": persistence_horizon,
        "candidate_rollout_disturbance_cost": disturbance.astype(np.float32),
        # Scalars are expanded to 1x1 (or 2-channel) maps via broadcasting.
        "candidate_rollout_belief_map": visibility[:, None, None].astype(np.float32),
        "candidate_rollout_visibility_map": visibility[:, None, None].astype(np.float32),
        "candidate_rollout_clearance_map": np.repeat(clearance[:, None, None, None], 2, axis=1).astype(np.float32),
        "candidate_rollout_support_stability": support[:, None, None, None].astype(np.float32),
        "candidate_rollout_reocclusion_target": reocclusion[:, None, None].astype(np.float32),
        "candidate_rollout_occluder_contact_map": np.clip(access * support, 0.0, 1.0)[:, None, None].astype(np.float32),
        "candidate_rollout_grasp_affordance_map": grasp[:, None, None].astype(np.float32),
    }
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def _select_expert_mode(
    task_spec: BridgeTaskSpec,
    *,
    decision_step: int,
    candidate_metrics: Sequence[dict[str, float]],
) -> str:
    """Pick the expert's mode for this decision step from candidate utilities.

    Rules, in order: retrieve once it succeeds after step 0; prefer the best
    reveal mode on step 0 (and later, when near base utility); switch to the
    best transfer mode when it clearly beats both reveal and base; otherwise
    fall back to the globally best utility.
    """
    by_name = dict(zip(task_spec.mode_order, candidate_metrics))

    def utility_of(name: str) -> float:
        return float(by_name[name]["candidate_utility"])

    base_utility = utility_of("base_action")
    best_reveal = max(task_spec.reveal_modes, key=utility_of)
    best_transfer = max(task_spec.transfer_modes, key=utility_of)
    retrieve_utility = utility_of("retrieve")
    best_reveal_utility = utility_of(best_reveal)
    best_transfer_utility = utility_of(best_transfer)
    retrieve_success = float(by_name["retrieve"]["candidate_retrieval_success"])
    step = int(decision_step)

    if step > 0 and retrieve_success >= 0.5:
        return "retrieve"
    if step == 0 and best_reveal_utility >= base_utility - 0.02:
        return best_reveal
    if best_transfer_utility >= best_reveal_utility + 0.05 and best_transfer_utility >= base_utility + 0.02:
        return best_transfer
    if best_reveal_utility >= base_utility - 0.02:
        return best_reveal
    if retrieve_success >= 0.5 and retrieve_utility >= base_utility + 0.02:
        return "retrieve"
    if best_transfer_utility >= base_utility + 0.02:
        return best_transfer
    utilities = np.asarray([payload["candidate_utility"] for payload in candidate_metrics], dtype=np.float32)
    return task_spec.mode_order[int(utilities.argmax())]
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
def _evaluate_candidate(
    task_spec: BridgeTaskSpec,
    sim_env: gym.Env[Any, Any],
    obs_env: gym.Env[Any, Any],
    snapshot: dict[str, Any],
    mode_name: str,
) -> dict[str, float]:
    """Simulate one candidate mode from ``snapshot`` and score the outcome.

    Executes the mode in ``sim_env``, mirrors the state into ``obs_env`` to
    render an observation, and — for non-terminal modes that did not already
    succeed — chains a "retrieve" follow-up before scoring the final state.
    Both envs are expected to be restored from ``snapshot`` by the caller
    before/after use; this function restores ``obs_env`` itself.
    """
    _restore_env(sim_env, snapshot)
    start_positions = _all_positions(sim_env, task_spec)
    _execute_mode(sim_env, task_spec, mode_name)
    # obs_env mirrors sim_env state purely to produce the sensor bundle.
    _sync_env_state(sim_env, obs_env)
    after_bundle = _extract_sensor_bundle(obs_env.get_obs(obs_env.get_info()), resolution=IMAGE_RESOLUTION)
    immediate = _candidate_metrics(sim_env, task_spec=task_spec, start_positions=start_positions, current_obs_bundle=after_bundle)
    if not immediate["retrieval_success"] and mode_name not in {"retrieve", "base_action"}:
        # Preparatory modes are judged by whether a retrieve works afterwards.
        _execute_mode(sim_env, task_spec, "retrieve")
        _sync_env_state(sim_env, obs_env)
        follow_bundle = _extract_sensor_bundle(obs_env.get_obs(obs_env.get_info()), resolution=IMAGE_RESOLUTION)
        final_metrics = _candidate_metrics(sim_env, task_spec=task_spec, start_positions=start_positions, current_obs_bundle=follow_bundle)
    else:
        final_metrics = immediate
    # Leave obs_env back at the snapshot; sim_env is restored by the caller.
    _restore_env(obs_env, snapshot)
    # Scalar utility: success dominates, then progress/clearance/visibility,
    # with disturbance as a penalty.
    utility = (
        2.5 * final_metrics["retrieval_success"]
        + 1.0 * final_metrics["progress"]
        + 0.5 * final_metrics["clearance"]
        + 0.25 * final_metrics["visibility"]
        - 0.5 * final_metrics["disturbance"]
    )
    return {
        "candidate_retrieval_success": final_metrics["retrieval_success"],
        "candidate_risk": float(np.clip(final_metrics["disturbance"], 0.0, 1.0)),
        "candidate_utility": float(utility),
        "candidate_final_disturbance_cost": final_metrics["disturbance"],
        "candidate_reocclusion_rate": float(np.clip(1.0 - final_metrics["clearance"], 0.0, 1.0)),
        "candidate_visibility_integral": final_metrics["visibility"],
        "candidate_actor_feasibility_auc": final_metrics["clearance"],
        "candidate_reveal_achieved": float(final_metrics["progress"] > 0.15 or final_metrics["clearance"] > 0.35),
        "candidate_hold_persistence": float(1.0 - final_metrics["disturbance"]),
        "candidate_support_stability_auc": float(1.0 - 0.5 * final_metrics["disturbance"]),
        "candidate_disturbance_auc": final_metrics["disturbance"],
        # "immediate_*" entries capture the state right after the mode itself,
        # before any chained retrieve.
        "candidate_immediate_retrieval_success": immediate["retrieval_success"],
        "candidate_immediate_visibility": immediate["visibility"],
        "candidate_immediate_access": immediate["clearance"],
        "candidate_immediate_progress": immediate["progress"],
        "candidate_immediate_reocclusion": float(np.clip(1.0 - immediate["clearance"], 0.0, 1.0)),
        "candidate_immediate_hold_persistence": float(1.0 - immediate["disturbance"]),
        "candidate_immediate_support_stability": float(1.0 - 0.5 * immediate["disturbance"]),
        "candidate_immediate_disturbance": immediate["disturbance"],
    }
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
def _cloth_seed_is_valid(env: gym.Env[Any, Any], task_spec: BridgeTaskSpec, *, episode_seed: int) -> bool:
    """Return True when *episode_seed* yields a physically solvable cloth episode.

    A seed is valid when (a) the source object starts occluded (visibility at or
    below CLOTH_SUCCESS_MIN_VISIBILITY) and (b) at least one scripted reveal mode
    followed by "retrieve" reaches the cloth success condition.

    Side effects: resets and steps *env*; callers must not rely on its state
    afterwards.
    """
    env.reset(seed=int(episode_seed))
    _initialize_proxy_state(env, task_spec, episode_seed=int(episode_seed))
    # Record the pre-manipulation layout; the success check compares against it.
    start_positions = _all_positions(env, task_spec)
    obs = env.get_obs(env.get_info())
    obs_bundle = _extract_sensor_bundle(obs, resolution=IMAGE_RESOLUTION)
    # -1 when the source actor exposes no per_scene_id attribute.
    actor_id = int(getattr(_source_actor(env), "per_scene_id", -1))
    start_visibility = _source_visibility(obs_bundle, actor_id)
    if start_visibility > CLOTH_SUCCESS_MIN_VISIBILITY:
        # Already visible at reset: the episode would be trivial, reject the seed.
        return False
    # Try each reveal strategy from the same snapshot so attempts are independent.
    snapshot = _snapshot_env(env)
    for reveal_mode in ("lift_edge", "separate_layer"):
        _restore_env(env, snapshot)
        _execute_mode(env, task_spec, reveal_mode)
        _execute_mode(env, task_spec, "retrieve")
        obs = env.get_obs(env.get_info())
        obs_bundle = _extract_sensor_bundle(obs, resolution=IMAGE_RESOLUTION)
        visibility = _source_visibility(obs_bundle, actor_id)
        if _cloth_success(env, start_positions=start_positions, current_visibility=visibility):
            return True
    return False
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
def _build_episode_splits(task_spec: BridgeTaskSpec, spec: SmokeSpec) -> dict[str, list[int]]:
    """Build the train/val/eval episode-seed lists for one task.

    Non-cloth tasks use fixed deterministic seed offsets (train at +0, val at
    +1_000, eval at +2_000 from ``dataset_seed * 10_000``).  The cloth task
    instead filters candidate seeds through ``_cloth_seed_is_valid`` because many
    random layouts are physically unsolvable; validated seeds are then sliced
    into the three splits in discovery order.

    Raises:
        RuntimeError: when fewer than ``target_total`` valid cloth seeds are
            found within ``target_total * 30`` candidates.
    """
    if task_spec.key != "cloth":
        # Deterministic, non-overlapping offsets as long as each split stays < 1_000 episodes.
        return {
            "train": [spec.dataset_seed * 10_000 + index for index in range(spec.train_episodes)],
            "val": [spec.dataset_seed * 10_000 + 1_000 + index for index in range(spec.val_episodes)],
            "eval": [spec.dataset_seed * 10_000 + 2_000 + index for index in range(spec.eval_episodes)],
        }
    target_total = int(spec.train_episodes + spec.val_episodes + spec.eval_episodes)
    valid_seeds: list[int] = []
    candidate_index = 0
    env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
    try:
        while len(valid_seeds) < target_total:
            episode_seed = spec.dataset_seed * 10_000 + candidate_index
            candidate_index += 1
            if _cloth_seed_is_valid(env, task_spec, episode_seed=episode_seed):
                valid_seeds.append(int(episode_seed))
                # Progress breadcrumb: one JSON line per accepted seed.
                print(
                    json.dumps(
                        {
                            "phase": "cloth_seed_selected",
                            "episode_seed": int(episode_seed),
                            "selected": len(valid_seeds),
                            "target_total": target_total,
                        }
                    ),
                    flush=True,
                )
            # Give up after a 30x oversampling budget; the finally still closes env.
            if candidate_index > target_total * 30:
                raise RuntimeError("Unable to find enough physics-valid cloth proxy seeds for the smoke protocol.")
    finally:
        env.close()
    # Slice validated seeds into splits in the order they were discovered.
    return {
        "train": valid_seeds[: spec.train_episodes],
        "val": valid_seeds[spec.train_episodes : spec.train_episodes + spec.val_episodes],
        "eval": valid_seeds[spec.train_episodes + spec.val_episodes : target_total],
    }
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
def _save_episode_splits(output_path: Path, payload: dict[str, list[int]]) -> None:
|
| 1032 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 1033 |
+
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
def _normalize_depth_array(array: np.ndarray) -> np.ndarray:
    """Return *array* as float32 with any trailing singleton channel moved channel-first.

    Arrays shaped ``(..., H, W, 1)`` (ndim >= 4) become ``(..., 1, H, W)``; anything
    else is returned unchanged apart from the float32 conversion done by ``_np``.
    """
    depth = _np(array, dtype=np.float32)
    trailing_channel = depth.ndim >= 4 and depth.shape[-1] == 1
    if not trailing_channel:
        return depth
    # Relocate the last axis to the channel slot three axes from the end.
    return np.moveaxis(depth, -1, depth.ndim - 3)
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
def _candidate_pad_indices(source_candidates: int, expected_candidates: int) -> list[int]:
|
| 1044 |
+
if source_candidates <= 0 or source_candidates >= expected_candidates:
|
| 1045 |
+
return []
|
| 1046 |
+
if source_candidates == 1:
|
| 1047 |
+
return [0] * (expected_candidates - source_candidates)
|
| 1048 |
+
cycle = list(range(1, source_candidates))
|
| 1049 |
+
indices: list[int] = []
|
| 1050 |
+
while len(indices) < (expected_candidates - source_candidates):
|
| 1051 |
+
indices.extend(cycle)
|
| 1052 |
+
return indices[: expected_candidates - source_candidates]
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
def _pad_candidate_axis(
|
| 1056 |
+
value: Any,
|
| 1057 |
+
*,
|
| 1058 |
+
source_candidates: int,
|
| 1059 |
+
expected_candidates: int,
|
| 1060 |
+
pad_indices: Sequence[int],
|
| 1061 |
+
) -> Any:
|
| 1062 |
+
if source_candidates <= 0 or source_candidates >= expected_candidates:
|
| 1063 |
+
return value
|
| 1064 |
+
if isinstance(value, np.ndarray):
|
| 1065 |
+
if value.ndim == 0 or value.shape[0] != source_candidates:
|
| 1066 |
+
return value
|
| 1067 |
+
if not pad_indices:
|
| 1068 |
+
return value
|
| 1069 |
+
padding = np.take(value, indices=list(pad_indices), axis=0)
|
| 1070 |
+
return np.concatenate([value, padding], axis=0)
|
| 1071 |
+
if isinstance(value, torch.Tensor):
|
| 1072 |
+
if value.ndim == 0 or value.shape[0] != source_candidates:
|
| 1073 |
+
return value
|
| 1074 |
+
if not pad_indices:
|
| 1075 |
+
return value
|
| 1076 |
+
pad_index = torch.as_tensor(list(pad_indices), device=value.device, dtype=torch.long)
|
| 1077 |
+
padding = value.index_select(0, pad_index)
|
| 1078 |
+
return torch.cat([value, padding], dim=0)
|
| 1079 |
+
if isinstance(value, list) and len(value) == source_candidates:
|
| 1080 |
+
padded = list(value)
|
| 1081 |
+
padded.extend(value[index] for index in pad_indices)
|
| 1082 |
+
return padded
|
| 1083 |
+
if isinstance(value, tuple) and len(value) == source_candidates:
|
| 1084 |
+
padded = list(value)
|
| 1085 |
+
padded.extend(value[index] for index in pad_indices)
|
| 1086 |
+
return tuple(padded)
|
| 1087 |
+
return value
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
def _normalize_candidate_targets(sample: dict[str, Any]) -> dict[str, Any]:
    """Pad every per-candidate field of *sample* up to EXPECTED_PROPOSAL_CANDIDATES.

    The number of source candidates is read from ``candidate_action_chunks``; all
    keys prefixed ``candidate_`` or ``proposal_target_`` are padded along axis 0
    with the same duplicate indices.  The sample is returned untouched when no
    padding is needed; otherwise a shallow copy with padded values is returned.
    """
    chunks = sample.get("candidate_action_chunks")
    if chunks is None:
        return sample
    chunk_array = _np(chunks)
    if chunk_array.ndim == 0:
        return sample
    source_count = int(chunk_array.shape[0])
    if source_count >= EXPECTED_PROPOSAL_CANDIDATES:
        return sample
    pad_indices = _candidate_pad_indices(source_count, EXPECTED_PROPOSAL_CANDIDATES)
    if not pad_indices:
        return sample
    result = dict(sample)
    for key, value in sample.items():
        if not key.startswith(("candidate_", "proposal_target_")):
            continue
        result[key] = _pad_candidate_axis(
            value,
            source_candidates=source_count,
            expected_candidates=EXPECTED_PROPOSAL_CANDIDATES,
            pad_indices=pad_indices,
        )
    return result
|
| 1114 |
+
|
| 1115 |
+
|
| 1116 |
+
def _normalize_cached_samples(samples: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
    """Normalize a cached dataset: fix depth layouts and pad candidate targets.

    Each sample is shallow-copied, its depth-related arrays are run through
    ``_normalize_depth_array``, and candidate fields are padded via
    ``_normalize_candidate_targets``.  Input samples are never mutated.
    """
    depth_keys = ("depths", "depth_valid", "history_depths", "history_depth_valid")
    normalized: list[dict[str, Any]] = []
    for raw in samples:
        entry = dict(raw)
        for key in depth_keys:
            if key in entry:
                entry[key] = _normalize_depth_array(entry[key])
        normalized.append(_normalize_candidate_targets(entry))
    return normalized
|
| 1126 |
+
|
| 1127 |
+
|
| 1128 |
+
def _collect_split(
    *,
    task_spec: BridgeTaskSpec,
    canonical_chunks: dict[str, np.ndarray],
    split_name: str,
    seeds: Sequence[int],
    spec: SmokeSpec,
    output_path: Path,
) -> dict[str, Any]:
    """Roll out the scripted expert on every seed of one split and cache the samples.

    Two mirrored environments are used: ``obs_env`` is the environment whose
    state actually advances through the episode, while ``sim_env`` is kept in
    sync and used by ``_evaluate_candidate`` to simulate candidate macro-modes
    without disturbing the observed episode.  Each decision step records one
    supervised sample (sensor bundle, rolling history, per-candidate metric and
    rollout targets, and the expert-selected canonical action chunk).  The full
    payload is written to *output_path* with ``torch.save`` and also returned.
    """
    obs_env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
    sim_env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
    samples: list[dict[str, Any]] = []
    episode_records: list[dict[str, Any]] = []
    try:
        for episode_seed in seeds:
            # Reset both envs to the same seed, apply proxy initialization to the
            # observed env, then mirror its state into the simulation env.
            obs, _ = obs_env.reset(seed=int(episode_seed))
            sim_env.reset(seed=int(episode_seed))
            _initialize_proxy_state(obs_env, task_spec, episode_seed=int(episode_seed))
            _sync_env_state(obs_env, sim_env)
            obs = obs_env.get_obs(obs_env.get_info())
            episode_start_positions = _all_positions(obs_env, task_spec)
            # Rolling observation/action window, bounded to the configured history length.
            history: deque[dict[str, Any]] = deque(maxlen=spec.history_steps)
            episode_success = False
            for decision_step in range(spec.max_macro_steps):
                obs_bundle = _extract_sensor_bundle(obs, resolution=spec.resolution)
                proprio = _build_proprio(obs_env)
                snapshot = _snapshot_env(obs_env)
                # Score every candidate macro-mode from the current snapshot.
                candidate_metrics = [
                    _evaluate_candidate(task_spec, sim_env, obs_env, snapshot, mode_name) for mode_name in task_spec.mode_order
                ]
                candidate_chunks = np.stack([canonical_chunks[mode_name] for mode_name in task_spec.mode_order], axis=0).astype(np.float32)
                utilities = np.asarray([payload["candidate_utility"] for payload in candidate_metrics], dtype=np.float32)
                selected_mode = _select_expert_mode(task_spec, decision_step=decision_step, candidate_metrics=candidate_metrics)
                state_targets = _current_state_targets(
                    task_spec,
                    env=obs_env,
                    obs_bundle=obs_bundle,
                    candidate_metrics=candidate_metrics,
                    episode_start_positions=episode_start_positions,
                    selected_mode=selected_mode,
                )
                # One rollout-target bundle per candidate, ordered like mode_order.
                rollout_targets_by_mode = [
                    _candidate_rollout_targets(task_spec, mode_name=mode_name, state_targets=state_targets, candidate_payload=payload)
                    for mode_name, payload in zip(task_spec.mode_order, candidate_metrics)
                ]
                sample = {
                    # Current-frame sensors (copied so later env steps cannot alias them).
                    "images": obs_bundle["images"].copy(),
                    "depths": obs_bundle["depths"].copy(),
                    "depth_valid": obs_bundle["depth_valid"].copy(),
                    "camera_intrinsics": obs_bundle["camera_intrinsics"].copy(),
                    "camera_extrinsics": obs_bundle["camera_extrinsics"].copy(),
                    # History stacks, zero-padded to spec.history_steps entries.
                    "history_images": _history_stack(
                        history,
                        "images",
                        pad_shape=obs_bundle["images"].shape,
                        dtype=np.uint8,
                        history_steps=spec.history_steps,
                    ),
                    "history_depths": _history_stack(
                        history,
                        "depths",
                        pad_shape=obs_bundle["depths"].shape,
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "history_depth_valid": _history_stack(
                        history,
                        "depth_valid",
                        pad_shape=obs_bundle["depth_valid"].shape,
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "history_camera_intrinsics": _history_stack(
                        history,
                        "camera_intrinsics",
                        pad_shape=obs_bundle["camera_intrinsics"].shape,
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "history_camera_extrinsics": _history_stack(
                        history,
                        "camera_extrinsics",
                        pad_shape=obs_bundle["camera_extrinsics"].shape,
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "history_proprio": _history_stack(
                        history,
                        "proprio",
                        pad_shape=(PROPRIO_DIM,),
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    # 14-dim action pad — presumably the bimanual action width; TODO confirm.
                    "history_actions": _history_stack(
                        history,
                        "action",
                        pad_shape=(14,),
                        dtype=np.float32,
                        history_steps=spec.history_steps,
                    ),
                    "proprio": proprio.astype(np.float32),
                    "language_goal": task_spec.text_prompt,
                    "task_name": task_spec.task_name,
                    "task_id": TASK_INDEX[task_spec.task_name],
                    # Imitation target: the canonical chunk of the expert-selected mode.
                    "action_chunk": canonical_chunks[selected_mode].copy(),
                    "candidate_action_chunks": candidate_chunks,
                    # Scalar per-candidate supervision signals (one value per mode).
                    "candidate_retrieval_success": np.asarray([payload["candidate_retrieval_success"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_final_disturbance_cost": np.asarray(
                        [payload["candidate_final_disturbance_cost"] for payload in candidate_metrics],
                        dtype=np.float32,
                    ),
                    "candidate_reocclusion_rate": np.asarray([payload["candidate_reocclusion_rate"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_visibility_integral": np.asarray(
                        [payload["candidate_visibility_integral"] for payload in candidate_metrics],
                        dtype=np.float32,
                    ),
                    "candidate_actor_feasibility_auc": np.asarray(
                        [payload["candidate_actor_feasibility_auc"] for payload in candidate_metrics],
                        dtype=np.float32,
                    ),
                    "candidate_reveal_achieved": np.asarray([payload["candidate_reveal_achieved"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_hold_persistence": np.asarray([payload["candidate_hold_persistence"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_support_stability_auc": np.asarray(
                        [payload["candidate_support_stability_auc"] for payload in candidate_metrics],
                        dtype=np.float32,
                    ),
                    "candidate_disturbance_auc": np.asarray([payload["candidate_disturbance_auc"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_risk": np.asarray([payload["candidate_risk"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_utility": utilities,
                    # Per-candidate rollout targets, stacked along the candidate axis.
                    "candidate_rollout_support_mode": np.stack(
                        [payload["candidate_rollout_support_mode"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.int64),
                    "candidate_rollout_corridor_feasible": np.stack(
                        [payload["candidate_rollout_corridor_feasible"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_persistence_horizon": np.stack(
                        [payload["candidate_rollout_persistence_horizon"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_disturbance_cost": np.stack(
                        [payload["candidate_rollout_disturbance_cost"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_belief_map": np.stack(
                        [payload["candidate_rollout_belief_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_visibility_map": np.stack(
                        [payload["candidate_rollout_visibility_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_clearance_map": np.stack(
                        [payload["candidate_rollout_clearance_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_support_stability": np.stack(
                        [payload["candidate_rollout_support_stability"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_reocclusion_target": np.stack(
                        [payload["candidate_rollout_reocclusion_target"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_occluder_contact_map": np.stack(
                        [payload["candidate_rollout_occluder_contact_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "candidate_rollout_grasp_affordance_map": np.stack(
                        [payload["candidate_rollout_grasp_affordance_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    # NOTE(review): the proposal_target_* keys below duplicate the
                    # candidate_rollout_* stacks verbatim (same source payload keys) —
                    # presumably two consumer heads read different key names; confirm.
                    "proposal_target_rollout_support_mode": np.stack(
                        [payload["candidate_rollout_support_mode"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.int64),
                    "proposal_target_rollout_corridor_feasible": np.stack(
                        [payload["candidate_rollout_corridor_feasible"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_persistence_horizon": np.stack(
                        [payload["candidate_rollout_persistence_horizon"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_disturbance_cost": np.stack(
                        [payload["candidate_rollout_disturbance_cost"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_belief_map": np.stack(
                        [payload["candidate_rollout_belief_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_visibility_map": np.stack(
                        [payload["candidate_rollout_visibility_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_clearance_map": np.stack(
                        [payload["candidate_rollout_clearance_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_support_stability": np.stack(
                        [payload["candidate_rollout_support_stability"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_reocclusion_target": np.stack(
                        [payload["candidate_rollout_reocclusion_target"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_occluder_contact_map": np.stack(
                        [payload["candidate_rollout_occluder_contact_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "proposal_target_rollout_grasp_affordance_map": np.stack(
                        [payload["candidate_rollout_grasp_affordance_map"] for payload in rollout_targets_by_mode],
                        axis=0,
                    ).astype(np.float32),
                    "episode_seed": int(episode_seed),
                    "decision_step": int(decision_step),
                    "selected_mode": selected_mode,
                    **state_targets,
                }
                samples.append(sample)
                # Advance the real episode with the expert's chosen mode.
                _execute_mode(obs_env, task_spec, selected_mode)
                obs = obs_env.get_obs(obs_env.get_info())
                post_bundle = _extract_sensor_bundle(obs, resolution=spec.resolution)
                # History records the pre-action observation paired with the executed chunk.
                history.append(_init_history_entry(obs_bundle, proprio, canonical_chunks[selected_mode]))
                # Task-specific success test: bag tasks check the bag predicate,
                # everything else uses the cloth visibility/disturbance predicate.
                if (
                    _bag_success(obs_env)
                    if task_spec.key == "bag"
                    else _cloth_success(
                        obs_env,
                        start_positions=episode_start_positions,
                        current_visibility=_source_visibility(post_bundle, int(getattr(_source_actor(obs_env), "per_scene_id", -1))),
                    )
                ):
                    episode_success = True
                    break
            # NOTE(review): "steps" is len(history), which a maxlen-bounded deque caps
            # at spec.history_steps — this under-reports long episodes; likely meant
            # to count decision steps. Confirm before relying on this field.
            episode_records.append({"episode_seed": int(episode_seed), "success": episode_success, "steps": len(history)})
            print(
                json.dumps(
                    {
                        "phase": "collect_episode_complete",
                        "task": task_spec.key,
                        "split": split_name,
                        "episode_seed": int(episode_seed),
                        "success": episode_success,
                        "steps": len(history),
                        "samples_collected": len(samples),
                    }
                ),
                flush=True,
            )
    finally:
        obs_env.close()
        sim_env.close()
    payload = {
        "split_name": split_name,
        "resolution": spec.resolution,
        "history_steps": spec.history_steps,
        "samples": samples,
        "episode_records": episode_records,
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, output_path)
    return payload
|
| 1394 |
+
|
| 1395 |
+
|
| 1396 |
+
def _manual_train_spec(task_spec: BridgeTaskSpec, variant: str, spec: SmokeSpec) -> dict[str, Any]:
    """Assemble the manifest dict describing one smoke-protocol training run.

    Pure metadata: records the task identity, optimizer hyperparameters, and the
    dataset split identifier so runs are reproducible and comparable.
    """
    # Default-seed runs keep the short split id; others embed the dataset seed explicitly.
    if int(spec.dataset_seed) == DEFAULT_SEED:
        split_id = f"{task_spec.key}_{SMOKE_VERSION}_seed{spec.dataset_seed}"
    else:
        split_id = f"{task_spec.key}_{SMOKE_VERSION}_dataset_seed{spec.dataset_seed}"
    # Gradient-step budget: epochs times ceil(train episodes / batch size), floored at 1 each.
    steps_per_epoch = math.ceil(max(1, spec.train_episodes) / max(1, spec.batch_size))
    return {
        "track_id": task_spec.track_id,
        "suite": task_spec.suite,
        "benchmark_task": task_spec.benchmark_task,
        "model_variant": str(variant),
        "seed": int(spec.train_seed),
        "train_demos": int(spec.train_episodes),
        "val_demos": int(spec.val_episodes),
        "init_checkpoint_group": str(DEFAULT_INIT_CHECKPOINT),
        "optimizer": "adamw",
        "learning_rate": float(spec.learning_rate),
        "lr_schedule": "constant",
        "batch_size": int(spec.batch_size),
        "augmentations": "none",
        "early_stopping_metric": "val_total",
        "max_gradient_steps": int(spec.epochs * steps_per_epoch),
        "unfreeze_scope": "fusion_memory_decoder",
        "dataset_split_id": split_id,
        "same_data_policy": True,
        "same_init_policy": True,
    }
|
| 1422 |
+
|
| 1423 |
+
|
| 1424 |
+
def _trainer_config_for_variant(variant: str) -> TrainerConfig:
    """Map a model-variant name to its TrainerConfig.

    Raises:
        KeyError: for any variant other than "trunk_only_ft" or "adapter_active_ft".
    """
    if variant == "trunk_only_ft":
        # Baseline: fine-tune only the trunk's fusion/memory/decoder stacks.
        return TrainerConfig(
            policy_type="trunk_only",
            trainable_parameter_prefixes=("fusion", "memory", "decoder"),
            eval_mode="trunk_only",
        )
    if variant == "adapter_active_ft":
        # Adapter variant: same trunk sub-modules plus the four adapter heads.
        adapter_prefixes = (
            "trunk.fusion",
            "trunk.memory",
            "trunk.decoder",
            "adapter.state_head",
            "adapter.transition_model",
            "adapter.proposal_prior",
            "adapter.planner",
        )
        return TrainerConfig(
            policy_type="adapter_wrapped",
            trainable_parameter_prefixes=adapter_prefixes,
            adapter_mode="adapter_active",
            eval_mode="adapter_active",
            adapter_use_transition_model=True,
            adapter_use_task_conditioning=True,
            adapter_action_supervision_source="trunk",
        )
    raise KeyError(f"Unsupported variant {variant!r}.")
|
| 1450 |
+
|
| 1451 |
+
|
| 1452 |
+
def _loss_weights_for_smoke(task_spec: BridgeTaskSpec) -> LossWeights:
    """Fixed loss-weight schedule shared by all smoke-protocol training runs.

    The action imitation term dominates (1.0); every auxiliary head gets a small
    constant weight.  Proposal-mode supervision is restricted to the current task
    via ``proposal_mode_task_filter``.
    """
    return LossWeights(
        action=1.0,  # primary behavior-cloning objective
        support_mode=0.15,
        corridor=0.15,
        persistence=0.08,
        disturbance=0.08,
        planner_success=0.20,
        planner_risk=0.08,
        planner_ranking=0.20,
        proposal_reconstruction=0.10,
        proposal_success=0.12,
        proposal_ranking=0.15,
        proposal_mode=0.10,
        proposal_diversity=0.02,
        task_metrics=0.15,
        transition=0.25,
        gate=0.25,
        calibration=0.10,
        # Only supervise the proposal-mode head on samples from this task.
        proposal_mode_task_filter=[task_spec.task_name],
    )
|
| 1473 |
+
|
| 1474 |
+
|
| 1475 |
+
def _train_variant(
    *,
    task_spec: BridgeTaskSpec,
    variant: str,
    train_samples: Sequence[dict[str, Any]],
    val_samples: Sequence[dict[str, Any]],
    spec: SmokeSpec,
    output_dir: Path,
) -> tuple[Path, dict[str, Any]]:
    """Train one model variant on cached samples and keep the best-val checkpoint.

    Builds the policy from the init bundle, restricts trainable parameters per the
    variant's TrainerConfig, then runs ``spec.epochs`` epochs of training with a
    full validation pass after each.  Whenever the epoch's validation total is at
    least as good as the best seen, a checkpoint and ``summary.json`` are written
    into *output_dir*.

    Returns:
        Tuple of (path to ``checkpoint_best.pt`` inside *output_dir*, the train
        spec manifest dict).
    """
    policy_config, _init_trainer_cfg, _init_loss_weights = _load_init_bundle()
    policy_config = _apply_smoke_planner_overrides(policy_config)
    trainer_config = _trainer_config_for_variant(variant)
    loss_weights = _loss_weights_for_smoke(task_spec)
    model = build_policy(policy_config, trainer_config)
    # False: third argument — presumably a "strict" flag for checkpoint loading; confirm.
    init_info = _load_init_checkpoint(model, str(DEFAULT_INIT_CHECKPOINT), False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    torch.manual_seed(spec.train_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(spec.train_seed)
        torch.backends.cuda.matmul.allow_tf32 = True
    # Freeze everything outside the variant's trainable prefixes; `matched` lists what stayed trainable.
    matched = apply_trainable_parameter_prefixes(model, trainer_config)
    optimizer = torch.optim.AdamW(
        [parameter for parameter in model.parameters() if parameter.requires_grad],
        lr=spec.learning_rate,
        weight_decay=spec.weight_decay,
    )
    trainer = BimanualTrainer(model=model, optimizer=optimizer, config=trainer_config)
    train_loader = _make_loader(train_samples, batch_size=spec.batch_size, num_workers=spec.num_workers, shuffle=True)
    val_loader = _make_loader(val_samples, batch_size=spec.batch_size, num_workers=spec.num_workers, shuffle=False)
    best_val = math.inf
    history: list[dict[str, Any]] = []
    train_spec = _manual_train_spec(task_spec, variant, spec)
    # NOTE(review): overwrites the estimate computed inside _manual_train_spec with
    # the actual loader length — intentional correction, but the two can differ.
    train_spec["max_gradient_steps"] = len(train_loader) * spec.epochs
    for epoch in range(spec.epochs):
        model.train()
        train_losses: list[dict[str, float]] = []
        for batch in train_loader:
            moved = _move_batch_to_device(batch, device)
            loss_dict = trainer.training_step(moved, loss_weights=loss_weights)
            train_losses.append({key: float(value.detach().cpu()) for key, value in loss_dict.items()})
        model.eval()
        val_losses: list[dict[str, float]] = []
        with torch.no_grad():
            for batch in val_loader:
                moved = _move_batch_to_device(batch, device)
                forward_kwargs = {
                    "images": moved["images"],
                    "proprio": moved["proprio"],
                    "texts": moved["texts"],
                    "task_names": moved.get("task_name"),
                    "task_ids": moved.get("task_id"),
                    "history_images": moved.get("history_images"),
                    "history_proprio": moved.get("history_proprio"),
                    "history_actions": moved.get("history_actions"),
                    "depths": moved.get("depths"),
                    "depth_valid": moved.get("depth_valid"),
                    "camera_intrinsics": moved.get("camera_intrinsics"),
                    "camera_extrinsics": moved.get("camera_extrinsics"),
                    "history_depths": moved.get("history_depths"),
                    "history_depth_valid": moved.get("history_depth_valid"),
                    "history_camera_intrinsics": moved.get("history_camera_intrinsics"),
                    "history_camera_extrinsics": moved.get("history_camera_extrinsics"),
                }
                # Validation must mirror the adapter variant's training-time forward flags.
                if variant == "adapter_active_ft":
                    forward_kwargs["adapter_mode"] = "adapter_active"
                    forward_kwargs["use_transition_model"] = True
                    forward_kwargs["use_task_conditioning"] = True
                outputs = model(**forward_kwargs)
                losses = compute_total_loss(outputs, moved, weights=loss_weights)
                val_losses.append({key: float(value.detach().cpu()) for key, value in losses.items()})
        train_summary = _aggregate_epoch(train_losses)
        val_summary = _aggregate_epoch(val_losses)
        history.append({"epoch": epoch, "train": train_summary, "val": val_summary})
        print(
            json.dumps(
                {
                    "phase": "epoch_complete",
                    "task": task_spec.key,
                    "variant": variant,
                    "epoch": epoch,
                    "train_total": train_summary.get("total", 0.0),
                    "val_total": val_summary.get("total", 0.0),
                }
            ),
            flush=True,
        )
        # "<=" so a tie refreshes the checkpoint with the later (more-trained) weights.
        if val_summary.get("total", math.inf) <= best_val:
            best_val = val_summary["total"]
            checkpoint_path = _save_training_checkpoint(
                output_dir=output_dir,
                experiment_name=f"{task_spec.key}_{variant}_seed{spec.train_seed}",
                model=model,
                policy_config=policy_config,
                trainer_config=trainer_config,
                loss_weights=loss_weights,
                history=history,
                best_val=best_val,
                train_spec=train_spec,
            )
            # NOTE(review): checkpoint_path/summary only exist if this branch ran at
            # least once; with spec.epochs == 0 the returned best-checkpoint path
            # would not have been written. Confirm epochs >= 1 is guaranteed upstream.
            (output_dir / "summary.json").write_text(
                json.dumps(
                    {
                        "task": task_spec.key,
                        "variant": variant,
                        "checkpoint_path": str(checkpoint_path),
                        "init_info": init_info,
                        "trainable_parameter_names": matched,
                        "best_val_total": best_val,
                        "history": history,
                        "train_spec": train_spec,
                    },
                    indent=2,
                )
                + "\n",
                encoding="utf-8",
            )
    return output_dir / "checkpoint_best.pt", train_spec
|
| 1593 |
+
|
| 1594 |
+
|
| 1595 |
+
def _eval_mode_name(model_output: dict[str, Any], result_mode_name: str, canonical_chunks: dict[str, np.ndarray]) -> tuple[str, bool, bool]:
    """Decode the macro-mode a model output selected for the first batch element.

    Returns a tuple ``(mode_name, adapter_active, non_default_candidate)`` where
    ``adapter_active`` reflects the model's adapter gate and the last flag is True
    when a candidate other than index 0 was picked.  For non-adapter results (or
    when the gate is off / the index is out of range) the mode is inferred by
    matching ``action_mean`` against the canonical chunks.
    """
    if result_mode_name == "adapter_active_ft" and "proposal_mode_names" in model_output and "best_candidate_indices" in model_output:
        # Gate flag for the first batch element; defaults to False when the key is absent.
        active_mask = bool(_np(model_output.get("adapter_active_mask", np.asarray([False]))).reshape(-1)[0])
        if not active_mask:
            # Adapter gated off: fall back to nearest-canonical-chunk classification.
            return _classify_mode_from_chunk(_np(model_output["action_mean"])[0], canonical_chunks), False, False
        best_index = int(_np(model_output["best_candidate_indices"])[0])
        proposal_mode_names = model_output["proposal_mode_names"][0]
        if best_index < len(proposal_mode_names):
            mode_name = str(proposal_mode_names[best_index])
        else:
            # Index beyond the named proposals: classify from the action chunk instead.
            mode_name = _classify_mode_from_chunk(_np(model_output["action_mean"])[0], canonical_chunks)
        return mode_name, active_mask, bool(best_index > 0)
    return _classify_mode_from_chunk(_np(model_output["action_mean"])[0], canonical_chunks), False, False
|
| 1608 |
+
|
| 1609 |
+
|
| 1610 |
+
def _manual_eval_protocol(task_spec: BridgeTaskSpec, *, eval_mode: str, spec: SmokeSpec, episodes: int) -> dict[str, Any]:
    """Describe the fixed evaluation protocol for one benchmark task.

    Pure metadata used for reporting/bookkeeping; nothing here runs an
    evaluation.
    """
    protocol: dict[str, Any] = {}
    # Task identity.
    protocol["track_id"] = task_spec.track_id
    protocol["suite"] = task_spec.suite
    protocol["benchmark_task"] = task_spec.benchmark_task
    protocol["role"] = "target"
    # Evaluation settings.
    protocol["eval_mode"] = eval_mode
    protocol["seed"] = int(spec.dataset_seed)
    protocol["episodes"] = int(episodes)
    protocol["resolution"] = int(spec.resolution)
    protocol["cameras"] = tuple(CAMERA_NAMES)
    protocol["observation_stack"] = "rgb_triplicate_zero_depth"
    protocol["action_horizon"] = 8
    protocol["action_space"] = "widowx_delta_pose"
    protocol["same_test_episodes"] = True
    return protocol
|
| 1626 |
+
|
| 1627 |
+
|
| 1628 |
+
def _batch_from_obs(
    task_spec: BridgeTaskSpec,
    obs_bundle: dict[str, np.ndarray],
    proprio: np.ndarray,
    history: Sequence[dict[str, Any]],
    device: torch.device,
) -> dict[str, Any]:
    """Assemble a batch-size-1 model input from the current observation and history.

    Image tensors are converted to channel-first float in [0, 1]; every other
    array gets a leading batch dimension and is moved to *device* as float32.
    """

    def batched(arr: np.ndarray) -> torch.Tensor:
        # Add the batch dimension and move to the model device as float32.
        return torch.from_numpy(arr).unsqueeze(0).float().to(device)

    def stacked(key: str, pad_shape: tuple[int, ...], dtype: Any) -> np.ndarray:
        # Pad/stack the last HISTORY_STEPS entries of `key` from the history buffer.
        return _history_stack(history, key, pad_shape=pad_shape, dtype=dtype, history_steps=HISTORY_STEPS)

    current_images = (
        torch.from_numpy(obs_bundle["images"]).permute(0, 3, 1, 2).unsqueeze(0).float().div(255.0).to(device)
    )
    past_images = (
        torch.from_numpy(stacked("images", obs_bundle["images"].shape, np.uint8))
        .permute(0, 1, 4, 2, 3)
        .unsqueeze(0)
        .float()
        .div(255.0)
        .to(device)
    )
    return {
        "images": current_images,
        "depths": batched(obs_bundle["depths"]),
        "depth_valid": batched(obs_bundle["depth_valid"]),
        "camera_intrinsics": batched(obs_bundle["camera_intrinsics"]),
        "camera_extrinsics": batched(obs_bundle["camera_extrinsics"]),
        "history_images": past_images,
        "history_depths": batched(stacked("depths", obs_bundle["depths"].shape, np.float32)),
        "history_depth_valid": batched(stacked("depth_valid", obs_bundle["depth_valid"].shape, np.float32)),
        "history_camera_intrinsics": batched(stacked("camera_intrinsics", obs_bundle["camera_intrinsics"].shape, np.float32)),
        "history_camera_extrinsics": batched(stacked("camera_extrinsics", obs_bundle["camera_extrinsics"].shape, np.float32)),
        "history_proprio": batched(stacked("proprio", (PROPRIO_DIM,), np.float32)),
        "history_actions": batched(stacked("action", (14,), np.float32)),
        "proprio": batched(proprio),
        "texts": [task_spec.text_prompt],
        "task_names": [task_spec.task_name],
        "task_ids": torch.as_tensor([TASK_INDEX[task_spec.task_name]], dtype=torch.long, device=device),
    }
|
| 1685 |
+
|
| 1686 |
+
|
| 1687 |
+
def _evaluate_checkpoint(
    *,
    task_spec: BridgeTaskSpec,
    canonical_chunks: dict[str, np.ndarray],
    checkpoint_path: Path,
    adapter_mode: str,
    result_mode_name: str,
    seeds: Sequence[int],
    report_path: Path,
    train_spec: dict[str, Any] | None,
    planner_overrides: dict[str, float] | None,
) -> dict[str, Any]:
    """Roll out one checkpoint over the given episode seeds and write a JSON report.

    Uses two synchronized copies of the task environment: ``obs_env`` for
    observations/metrics and ``sim_env`` to execute the selected macro mode,
    with state synced between them before and after each execution.
    Returns the report payload that was written to *report_path*.
    """
    # NOTE(review): `checkpoint` is loaded but unused here — presumably kept for
    # parity with other call sites; confirm before removing.
    model, checkpoint = _load_checkpoint(
        checkpoint_path,
        adapter_mode=adapter_mode if adapter_mode != "trunk_only" else None,
        planner_overrides=planner_overrides,
    )
    device = next(model.parameters()).device
    obs_env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
    sim_env = gym.make(task_spec.env_id, obs_mode="rgb+segmentation", render_mode="rgb_array")
    successes: list[int] = []
    episode_records: list[dict[str, Any]] = []
    reveal_steps: list[int] = []          # first step index (1-based) a reveal-type mode fired, per episode
    retrieve_steps: list[int] = []        # first step index (1-based) "retrieve" fired, per episode
    disturbance_values: list[float] = []  # per-episode mean disturbance proxy
    intervention_events = 0               # decisions where the adapter was active
    non_base_events = 0                   # decisions where a non-base candidate was chosen
    total_decisions = 0
    try:
        for episode_seed in seeds:
            obs, _ = obs_env.reset(seed=int(episode_seed))
            sim_env.reset(seed=int(episode_seed))
            _initialize_proxy_state(obs_env, task_spec, episode_seed=int(episode_seed))
            _sync_env_state(obs_env, sim_env)
            # Re-fetch obs after proxy-state initialization so the first decision
            # sees the modified scene.
            obs = obs_env.get_obs(obs_env.get_info())
            history: deque[dict[str, Any]] = deque(maxlen=HISTORY_STEPS)
            episode_start_positions = _all_positions(obs_env, task_spec)
            success = False
            first_reveal_step: int | None = None
            first_retrieve_step: int | None = None
            episode_disturbance: list[float] = []
            for decision_step in range(MAX_MACRO_STEPS):
                obs_bundle = _extract_sensor_bundle(obs, resolution=IMAGE_RESOLUTION)
                proprio = _build_proprio(obs_env)
                batch = _batch_from_obs(task_spec, obs_bundle, proprio, list(history), device)
                with torch.no_grad():
                    if adapter_mode == "trunk_only":
                        outputs = model(**batch)
                    else:
                        outputs = model(**batch, adapter_mode=adapter_mode, use_transition_model=True, use_task_conditioning=True)
                selected_mode, active_mask, non_base = _eval_mode_name(outputs, result_mode_name, canonical_chunks)
                start_positions = _all_positions(obs_env, task_spec)
                # Execute the chosen macro mode in sim_env, then mirror the result
                # back into obs_env so metrics see the post-execution state.
                _sync_env_state(obs_env, sim_env)
                _execute_mode(sim_env, task_spec, selected_mode)
                _sync_env_state(sim_env, obs_env)
                obs = obs_env.get_obs(obs_env.get_info())
                post_bundle = _extract_sensor_bundle(obs, resolution=IMAGE_RESOLUTION)
                end_metrics = _candidate_metrics(
                    obs_env,
                    task_spec=task_spec,
                    start_positions=start_positions,
                    current_obs_bundle=post_bundle,
                )
                # Record the pre-execution observation plus the canonical chunk of
                # the executed mode (falling back to "base_action" for unknown modes).
                history.append(_init_history_entry(obs_bundle, proprio, canonical_chunks.get(selected_mode, canonical_chunks["base_action"])))
                total_decisions += 1
                intervention_events += int(active_mask)
                non_base_events += int(non_base)
                episode_disturbance.append(end_metrics["disturbance"])
                # Any mode other than retrieve/base counts as a reveal/access action.
                if selected_mode != "retrieve" and selected_mode != "base_action" and first_reveal_step is None:
                    first_reveal_step = decision_step + 1
                if selected_mode == "retrieve" and first_retrieve_step is None:
                    first_retrieve_step = decision_step + 1
                # Task-specific success check: bag tasks have their own predicate;
                # cloth tasks compare against episode-start positions and visibility.
                if (
                    _bag_success(obs_env)
                    if task_spec.key == "bag"
                    else _cloth_success(
                        obs_env,
                        start_positions=episode_start_positions,
                        current_visibility=end_metrics["visibility"],
                    )
                ):
                    success = True
                    break
            successes.append(int(success))
            if first_reveal_step is not None:
                reveal_steps.append(first_reveal_step)
            if first_retrieve_step is not None:
                retrieve_steps.append(first_retrieve_step)
            disturbance_values.append(float(np.mean(episode_disturbance)) if episode_disturbance else 0.0)
            # NOTE(review): `history` is a deque with maxlen=HISTORY_STEPS, so the
            # reported "steps" saturates at HISTORY_STEPS even for longer episodes —
            # confirm whether decision_step + 1 was intended.
            episode_records.append(
                {
                    "episode_seed": int(episode_seed),
                    "success": success,
                    "steps": len(history),
                    "first_reveal_step": first_reveal_step,
                    "first_retrieve_step": first_retrieve_step,
                    "episode_disturbance": float(np.mean(episode_disturbance)) if episode_disturbance else 0.0,
                }
            )
            # Emit one JSON progress line per episode for log scraping.
            print(
                json.dumps(
                    {
                        "phase": "eval_episode_complete",
                        "task": task_spec.key,
                        "adapter_mode": result_mode_name,
                        "episode_seed": int(episode_seed),
                        "success": success,
                        "steps": len(history),
                    }
                ),
                flush=True,
            )
    finally:
        # Always release both simulators, even if an episode raised.
        obs_env.close()
        sim_env.close()
    payload = {
        "track_id": task_spec.track_id,
        "suite": task_spec.suite,
        "benchmark_task": task_spec.benchmark_task,
        "role": "target",
        "adapter_mode": result_mode_name,
        "episodes": len(seeds),
        "successes": successes,
        "success_rate": float(np.mean(successes)) if successes else 0.0,
        "intervention_rate": float(intervention_events / max(1, total_decisions)),
        "non_base_selection_rate": float(non_base_events / max(1, total_decisions)),
        # Episodes that never revealed/retrieved contribute the cap MAX_MACRO_STEPS
        # implicitly by being excluded; empty lists fall back to the cap.
        "steps_to_first_reveal_or_access": float(np.mean(reveal_steps)) if reveal_steps else float(MAX_MACRO_STEPS),
        "steps_to_retrieve": float(np.mean(retrieve_steps)) if retrieve_steps else float(MAX_MACRO_STEPS),
        "disturbance_proxy": float(np.mean(disturbance_values)) if disturbance_values else 0.0,
        "episode_records": episode_records,
        "eval_protocol": _manual_eval_protocol(task_spec, eval_mode=result_mode_name, spec=SmokeSpec(), episodes=len(seeds)),
        "proxy_notes": task_spec.notes,
    }
    if train_spec is not None:
        payload["train_spec"] = train_spec
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
    return payload
|
| 1825 |
+
|
| 1826 |
+
|
| 1827 |
+
def _patch_summary_metadata(summary: dict[str, Any], task_spec: BridgeTaskSpec) -> dict[str, Any]:
    """Return a deep copy of *summary* with task metadata stamped onto its track entry.

    The input summary is never mutated; missing tracks are left untouched.
    """
    # JSON round-trip gives a cheap deep copy of the (JSON-serializable) summary.
    patched = json.loads(json.dumps(summary))
    track_entry = patched.get("tracks", {}).get(task_spec.track_id)
    if track_entry is None:
        return patched
    track_entry.update(
        {
            "suite": task_spec.suite,
            "benchmark_task": task_spec.benchmark_task,
            "notes": task_spec.notes,
            "public_source": f"ManiSkill public scene proxy: {task_spec.env_id}",
            "task_family": f"{task_spec.key}_retrieval_proxy",
            "target_behavior": task_spec.text_prompt,
        }
    )
    return patched
|
| 1838 |
+
|
| 1839 |
+
|
| 1840 |
+
def _summarize_task(task_spec: BridgeTaskSpec, results: Sequence[dict[str, Any]], output_dir: Path) -> dict[str, Any]:
    """Summarize the per-mode results and write JSON + Markdown summaries to *output_dir*."""
    summary = summarize_public_benchmark_package(list(results), allow_partial=True)
    summary = _patch_summary_metadata(summary, task_spec)
    output_dir.mkdir(parents=True, exist_ok=True)

    json_target = output_dir / "public_benchmark_package_summary.json"
    md_target = output_dir / "public_benchmark_package_summary.md"
    json_target.write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8")

    track = summary["tracks"][task_spec.track_id]
    md_lines: list[str] = []
    md_lines.extend(
        [
            f"# ManiSkill {task_spec.key.capitalize()} Retrieval Smoke Summary",
            "",
            f"- benchmark_task: {task_spec.benchmark_task}",
            f"- target_macro_average_delta: {summary['target_macro_average_delta']:.3f}",
            f"- headline_pass: {summary['headline_pass']}",
            f"- sign_of_life_pass: {summary['sign_of_life_pass']}",
            "",
            f"## {task_spec.track_id}",
            f"- delta_active_vs_trunk: {track.get('delta_active_vs_trunk', 0.0):.3f}",
            f"- delta_noop_vs_trunk: {track.get('delta_noop_vs_trunk', 0.0):.3f}",
            f"- signs_of_life: {track.get('signs_of_life', False)}",
        ]
    )
    # Confidence interval is optional in the summary payload.
    if "delta_active_vs_trunk_ci95" in track:
        low, high = track["delta_active_vs_trunk_ci95"]
        md_lines.append(f"- delta_active_vs_trunk_ci95: [{low:.3f}, {high:.3f}]")
    for mode, mode_payload in track["modes"].items():
        md_lines.append(f"- {mode}: mean_success={mode_payload['mean_success']:.3f}")
    md_lines.append("")
    md_target.write_text("\n".join(md_lines).rstrip() + "\n", encoding="utf-8")
    return summary
|
| 1869 |
+
|
| 1870 |
+
|
| 1871 |
+
def _parse_args() -> argparse.Namespace:
    """Parse CLI options for the bag/cloth bridge-scene retrieval smoke run."""
    parser = argparse.ArgumentParser(description="Run a fair bridge-scene retrieval smoke for bag or cloth proxy tasks.")
    parser.add_argument("--task", choices=sorted(TASK_SPECS), required=True)
    parser.add_argument("--dataset-seed", type=int, default=DEFAULT_SEED)
    parser.add_argument("--train-seed", type=int, default=DEFAULT_SEED)
    parser.add_argument("--eval-split", choices=("val", "eval"), default="eval")
    parser.add_argument("--report-dir", type=Path, default=None)
    # Stage-skipping / cache-reuse switches.
    for switch in ("--skip-collection", "--skip-train", "--skip-eval", "--reuse-dataset", "--reuse-checkpoints"):
        parser.add_argument(switch, action="store_true")
    # Optional float overrides for the adapter/planner thresholds; None means
    # "fall back to the smoke default" (see _planner_overrides_from_args).
    float_override_flags = (
        "--adapter-confidence-threshold",
        "--retrieve-access-threshold",
        "--retrieve-persistence-threshold",
        "--retrieve-support-threshold",
        "--retrieve-reocclusion-threshold",
        "--planner-mode-preference-bonus",
        "--planner-premature-retrieve-penalty",
        "--planner-premature-insert-penalty",
        "--planner-premature-occlusion-sweep-penalty",
        "--planner-premature-maintain-penalty",
        "--planner-retrieve-stage-access-threshold",
        "--planner-retrieve-stage-reveal-threshold",
        "--planner-retrieve-stage-persistence-threshold",
        "--planner-retrieve-stage-support-threshold",
        "--planner-insert-stage-access-threshold",
        "--planner-insert-stage-visibility-threshold",
        "--planner-insert-stage-support-threshold",
        "--planner-occlusion-maintain-gap-min-access",
        "--planner-occlusion-maintain-gap-min-visibility",
    )
    for flag in float_override_flags:
        parser.add_argument(flag, type=float, default=None)
    return parser.parse_args()
|
| 1903 |
+
|
| 1904 |
+
|
| 1905 |
+
def _planner_overrides_from_args(args: argparse.Namespace) -> dict[str, float]:
    """Build planner-override kwargs: smoke-test defaults, layered with any CLI values.

    A CLI option left at its ``None`` default does not override anything.
    """
    overrides: dict[str, float] = {
        "adapter_confidence_threshold": SMOKE_ADAPTER_CONFIDENCE_THRESHOLD,
        "retrieve_access_threshold": SMOKE_RETRIEVE_ACCESS_THRESHOLD,
        "retrieve_persistence_threshold": SMOKE_RETRIEVE_PERSISTENCE_THRESHOLD,
        "retrieve_support_threshold": SMOKE_RETRIEVE_SUPPORT_THRESHOLD,
        "retrieve_reocclusion_threshold": SMOKE_RETRIEVE_REOCCLUSION_THRESHOLD,
    }
    # Map override keys to the CLI attributes that may supply them.
    cli_values = {
        "adapter_confidence_threshold": args.adapter_confidence_threshold,
        "retrieve_access_threshold": args.retrieve_access_threshold,
        "retrieve_persistence_threshold": args.retrieve_persistence_threshold,
        "retrieve_support_threshold": args.retrieve_support_threshold,
        "retrieve_reocclusion_threshold": args.retrieve_reocclusion_threshold,
        "mode_preference_bonus": args.planner_mode_preference_bonus,
        "premature_retrieve_penalty": args.planner_premature_retrieve_penalty,
        "premature_insert_penalty": args.planner_premature_insert_penalty,
        "premature_occlusion_sweep_penalty": args.planner_premature_occlusion_sweep_penalty,
        "premature_maintain_penalty": args.planner_premature_maintain_penalty,
        "retrieve_stage_access_threshold": args.planner_retrieve_stage_access_threshold,
        "retrieve_stage_reveal_threshold": args.planner_retrieve_stage_reveal_threshold,
        "retrieve_stage_persistence_threshold": args.planner_retrieve_stage_persistence_threshold,
        "retrieve_stage_support_threshold": args.planner_retrieve_stage_support_threshold,
        "insert_stage_access_threshold": args.planner_insert_stage_access_threshold,
        "insert_stage_visibility_threshold": args.planner_insert_stage_visibility_threshold,
        "insert_stage_support_threshold": args.planner_insert_stage_support_threshold,
        "occlusion_maintain_gap_min_access": args.planner_occlusion_maintain_gap_min_access,
        "occlusion_maintain_gap_min_visibility": args.planner_occlusion_maintain_gap_min_visibility,
    }
    overrides.update({key: value for key, value in cli_values.items() if value is not None})
    return overrides
|
| 1938 |
+
|
| 1939 |
+
|
| 1940 |
+
def main() -> None:
    """Run the end-to-end smoke: dataset collection, variant training, evaluation, summary."""
    args = _parse_args()
    task_spec = _task_spec(args.task)
    spec = SmokeSpec(dataset_seed=int(args.dataset_seed), train_seed=int(args.train_seed))
    canonical_chunks = _canonical_chunks(task_spec)
    paths = _default_paths(task_spec)
    report_dir = args.report_dir or paths.report_dir
    planner_overrides = _planner_overrides_from_args(args)
    # Episode splits are cached per dataset seed; build and persist them once.
    split_path = _dataset_artifact_path(paths.data_dir, "episode_splits.json", dataset_seed=spec.dataset_seed)
    if split_path.exists():
        episode_splits = json.loads(split_path.read_text(encoding="utf-8"))
    else:
        episode_splits = _build_episode_splits(task_spec, spec)
        _save_episode_splits(split_path, episode_splits)

    # --- Stage 1: dataset collection (or cache load) ---
    train_path = _dataset_artifact_path(paths.data_dir, "train.pt", dataset_seed=spec.dataset_seed)
    val_path = _dataset_artifact_path(paths.data_dir, "val.pt", dataset_seed=spec.dataset_seed)
    if args.skip_collection and (not train_path.exists() or not val_path.exists()):
        raise FileNotFoundError("Requested --skip-collection but cached dataset files are missing.")
    if not args.skip_collection and (not args.reuse_dataset or not train_path.exists() or not val_path.exists()):
        train_payload = _collect_split(
            task_spec=task_spec,
            canonical_chunks=canonical_chunks,
            split_name="train",
            seeds=episode_splits["train"],
            spec=spec,
            output_path=train_path,
        )
        val_payload = _collect_split(
            task_spec=task_spec,
            canonical_chunks=canonical_chunks,
            split_name="val",
            seeds=episode_splits["val"],
            spec=spec,
            output_path=val_path,
        )
    else:
        # weights_only=False: payload contains arbitrary pickled sample dicts,
        # not just tensors (trusted local cache).
        train_payload = torch.load(train_path, map_location="cpu", weights_only=False)
        val_payload = torch.load(val_path, map_location="cpu", weights_only=False)

    # --- Stage 2: train both variants (or load cached checkpoints) ---
    train_samples = _normalize_cached_samples(train_payload["samples"])
    val_samples = _normalize_cached_samples(val_payload["samples"])
    checkpoints: dict[str, Path] = {}
    train_specs: dict[str, dict[str, Any]] = {}
    for variant in ("trunk_only_ft", "adapter_active_ft"):
        variant_output_dir = paths.output_dir / f"{variant}_seed{spec.train_seed}"
        checkpoint_path = variant_output_dir / "checkpoint_best.pt"
        if args.skip_train and not checkpoint_path.exists():
            raise FileNotFoundError(f"Requested --skip-train but checkpoint is missing: {checkpoint_path}")
        if not args.skip_train and (not args.reuse_checkpoints or not checkpoint_path.exists()):
            checkpoint_path, train_spec = _train_variant(
                task_spec=task_spec,
                variant=variant,
                train_samples=train_samples,
                val_samples=val_samples,
                spec=spec,
                output_dir=variant_output_dir,
            )
        else:
            # Reusing a checkpoint requires its cached training summary for the report.
            summary_path = variant_output_dir / "summary.json"
            if not summary_path.exists():
                raise FileNotFoundError(f"Missing cached summary file for {variant}: {summary_path}")
            summary_payload = json.loads(summary_path.read_text(encoding="utf-8"))
            train_spec = summary_payload["train_spec"]
        checkpoints[variant] = checkpoint_path
        train_specs[variant] = train_spec

    # --- Stage 3: evaluate the three comparison arms and summarize ---
    results: list[dict[str, Any]] = []
    if not args.skip_eval:
        # (result name, checkpoint, adapter mode, train spec). adapter_noop reuses
        # the adapter checkpoint with the adapter disabled as an ablation.
        eval_plan = (
            ("trunk_only_ft", checkpoints["trunk_only_ft"], "trunk_only", None),
            ("adapter_noop", checkpoints["adapter_active_ft"], "adapter_noop", None),
            ("adapter_active_ft", checkpoints["adapter_active_ft"], "adapter_active", train_specs["adapter_active_ft"]),
        )
        for result_mode_name, checkpoint_path, adapter_mode, train_spec in eval_plan:
            result = _evaluate_checkpoint(
                task_spec=task_spec,
                canonical_chunks=canonical_chunks,
                checkpoint_path=checkpoint_path,
                adapter_mode=adapter_mode,
                result_mode_name=result_mode_name,
                seeds=episode_splits[args.eval_split],
                report_path=report_dir / f"{result_mode_name}_seed{spec.train_seed}.json",
                train_spec=train_spec if result_mode_name != "adapter_noop" else None,
                planner_overrides=planner_overrides,
            )
            if result_mode_name == "trunk_only_ft":
                # Attach the trunk train spec and re-write the report to include it.
                result["train_spec"] = train_specs["trunk_only_ft"]
                (report_dir / f"{result_mode_name}_seed{spec.train_seed}.json").write_text(
                    json.dumps(result, indent=2) + "\n",
                    encoding="utf-8",
                )
            results.append(result)
        _summarize_task(task_spec, results, report_dir)
|
| 2034 |
+
|
| 2035 |
+
|
| 2036 |
+
# Script entry point: run the full smoke pipeline when executed directly.
if __name__ == "__main__":
    main()
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_maniskill_pickclutter_smoke.py
ADDED
|
@@ -0,0 +1,2005 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import collections
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
from collections import deque
|
| 10 |
+
from dataclasses import asdict, dataclass
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any, Iterable, Sequence
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from omegaconf import OmegaConf
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from torch.utils.data import DataLoader, Dataset
|
| 19 |
+
|
| 20 |
+
from eval.public_benchmark_package import build_public_eval_protocol, build_target_training_spec
|
| 21 |
+
from models.action_decoder import ChunkDecoderConfig, semantic_macro_chunk
|
| 22 |
+
from models.backbones import FrozenVLBackboneConfig
|
| 23 |
+
from models.multiview_fusion import MultiViewFusionConfig
|
| 24 |
+
from models.observation_memory import ObservationMemoryConfig
|
| 25 |
+
from models.planner import PlannerConfig
|
| 26 |
+
from models.policy import PolicyConfig
|
| 27 |
+
from models.reveal_head import RevealHeadConfig, TASK_METRIC_NAMES
|
| 28 |
+
from models.world_model import RevealWMConfig
|
| 29 |
+
from train.checkpoint_compat import filter_compatible_state_dict
|
| 30 |
+
from train.losses import LossWeights
|
| 31 |
+
from train.trainer import BimanualTrainer, TrainerConfig, apply_trainable_parameter_prefixes, build_policy
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _configure_runtime_env() -> None:
|
| 35 |
+
os.environ.setdefault("VK_ICD_FILENAMES", "/workspace/runtime/vulkan/icd.d/nvidia_icd_egl.json")
|
| 36 |
+
os.environ.setdefault("VK_LAYER_PATH", "/workspace/runtime/vulkan/implicit_layer.d")
|
| 37 |
+
os.environ.setdefault("XDG_RUNTIME_DIR", "/tmp/runtime-root")
|
| 38 |
+
os.environ.setdefault("MS_ASSET_DIR", "/workspace/data/maniskill")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Must run at import time, before the mani_skill/sapien imports below, so the
# renderer sees the Vulkan/runtime environment variables.
_configure_runtime_env()
|
| 42 |
+
|
| 43 |
+
import mani_skill.envs # noqa: E402
|
| 44 |
+
import sapien # noqa: E402
|
| 45 |
+
from mani_skill.envs.tasks.tabletop.pick_clutter_ycb import PickClutterYCBEnv # noqa: E402
|
| 46 |
+
from mani_skill.sensors.camera import CameraConfig # noqa: E402
|
| 47 |
+
from mani_skill.utils import sapien_utils # noqa: E402
|
| 48 |
+
from mani_skill.utils.structs import Pose # noqa: E402
|
| 49 |
+
|
| 50 |
+
from eval.run_public_benchmark_package import summarize_public_benchmark_package # noqa: E402
|
| 51 |
+
from models.action_decoder import TASK_INDEX # noqa: E402
|
| 52 |
+
from train.run_experiment import _load_init_checkpoint, _move_batch_to_device # noqa: E402
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# --- Filesystem layout -------------------------------------------------------
REPO_ROOT = Path(__file__).resolve().parents[3]
WORKSPACE_ROOT = Path("/workspace/workspace")
SMOKE_VERSION = "smoke_v5"
DEFAULT_DATA_DIR = WORKSPACE_ROOT / "data" / "maniskill_pickclutter" / SMOKE_VERSION
DEFAULT_OUTPUT_DIR = WORKSPACE_ROOT / "outputs" / f"maniskill_pickclutter_{SMOKE_VERSION}"
DEFAULT_REPORT_DIR = WORKSPACE_ROOT / "reports" / f"maniskill_pickclutter_{SMOKE_VERSION}"
# Checkpoint used to initialise the policy before smoke training.
DEFAULT_INIT_CHECKPOINT = Path(
    "/workspace/workspace/VLAarchtests2/VLAarchtests/artifacts/outputs/"
    "r3d_handoff_phase/proxy_interaction_r3d_stage3_clip_rgbd_handoff_compact_phase_seed17/checkpoint_best.pt"
)

# --- Task definition ---------------------------------------------------------
TEXT_PROMPT = "retrieve the target object from dense clutter and stage it at the front edge"
TASK_NAME = "foliage"
TASK_ID = TASK_INDEX[TASK_NAME]
# Camera order here must match the sensor configs in PickClutterRevealEnv.
CAMERA_NAMES = ("front", "left", "right")
# Candidate macro modes, evaluated in this fixed order during rollouts.
MODE_ORDER = (
    "base_action",
    "sweep_left",
    "sweep_right",
    "pin_canopy",
    "widen_gap",
    "maintain_gap",
    "insert_actor",
    "retrieve",
)
ROLL_OUT_HORIZON = 5
NUM_SUPPORT_MODES = 3
NUM_APPROACH_TEMPLATES = 32
# Support-mode ids (indices into the NUM_SUPPORT_MODES slots).
SUPPORT_MODE_HOLD = 0
SUPPORT_MODE_TRANSFER = 1
SUPPORT_MODE_PASSIVE = 2
# Mode groupings by phase of the retrieval behaviour.
REVEAL_MODES = ("sweep_left", "sweep_right", "pin_canopy", "widen_gap", "maintain_gap")
TRANSFER_MODES = ("insert_actor",)
RETRIEVE_MODES = ("retrieve",)
# Subset of TASK_METRIC_NAMES that receive supervision from simulator state.
STATE_SUPERVISION_METRICS = (
    "opening_quality",
    "actor_feasibility_score",
    "gap_width",
    "damage_proxy",
    "release_collapse_rate",
    "target_visibility_confidence",
    "insertable_actor_corridor",
    "insertion_corridor",
    "hold_quality",
    "layer_separation_quality",
    "fold_preservation",
    "top_layer_stability",
    "lift_too_much_risk",
)
MAX_MACRO_STEPS = 4
HISTORY_STEPS = 6
PROPRIO_DIM = 32
# Target y-coordinate at/behind which the object counts as extracted.
EXTRACTION_LINE_Y = -0.22
# Minimum planar distance to the nearest clutter actor for a success.
MIN_CLEARANCE_FOR_SUCCESS = 0.05
DEFAULT_SEED = 17
# --- Smoke-run planner thresholds (looser than production defaults) ----------
SMOKE_ADAPTER_CONFIDENCE_THRESHOLD = 0.50
SMOKE_RETRIEVE_ACCESS_THRESHOLD = 0.08
SMOKE_RETRIEVE_PERSISTENCE_THRESHOLD = 0.12
SMOKE_RETRIEVE_SUPPORT_THRESHOLD = 0.08
SMOKE_RETRIEVE_REOCCLUSION_THRESHOLD = 0.92
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@dataclass(frozen=True)
class SmokePaths:
    """Immutable bundle of filesystem locations for one smoke run."""

    # Where the generated smoke dataset lives.
    data_dir: Path = DEFAULT_DATA_DIR
    # Where training outputs/checkpoints are written.
    output_dir: Path = DEFAULT_OUTPUT_DIR
    # Where benchmark reports are written.
    report_dir: Path = DEFAULT_REPORT_DIR
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@dataclass(frozen=True)
class SmokeSpec:
    """Hyperparameters for the smoke-scale dataset + training run."""

    # Square camera resolution (pixels).
    resolution: int = 224
    # Episode counts per split.
    train_episodes: int = 32
    val_episodes: int = 8
    eval_episodes: int = 50
    # Separate seeds for dataset generation and training.
    dataset_seed: int = DEFAULT_SEED
    train_seed: int = DEFAULT_SEED
    history_steps: int = HISTORY_STEPS
    max_macro_steps: int = MAX_MACRO_STEPS
    # Optimisation settings.
    batch_size: int = 4
    epochs: int = 6
    num_workers: int = 16
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4

    @property
    def seed(self) -> int:
        """Backward-compatible alias for ``train_seed``."""
        return self.train_seed
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _apply_smoke_planner_overrides(
    policy_config: PolicyConfig,
    planner_overrides: dict[str, float] | None = None,
) -> PolicyConfig:
    """Install the smoke-test planner thresholds on *policy_config* in place.

    The SMOKE_* defaults are applied first; any non-None entries in
    *planner_overrides* are then set on top of them. Returns the same
    (mutated) config for chaining.
    """
    planner = policy_config.planner
    planner.adapter_confidence_threshold = SMOKE_ADAPTER_CONFIDENCE_THRESHOLD
    planner.retrieve_access_threshold = SMOKE_RETRIEVE_ACCESS_THRESHOLD
    planner.retrieve_persistence_threshold = SMOKE_RETRIEVE_PERSISTENCE_THRESHOLD
    planner.retrieve_support_threshold = SMOKE_RETRIEVE_SUPPORT_THRESHOLD
    planner.retrieve_reocclusion_threshold = SMOKE_RETRIEVE_REOCCLUSION_THRESHOLD
    for key, value in (planner_overrides or {}).items():
        if value is not None:
            setattr(planner, key, value)
    return policy_config
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class PickClutterRevealEnv(PickClutterYCBEnv):
    """PickClutterYCBEnv variant with three fixed 224x224 cameras.

    The camera names and their order ("front", "left", "right") must stay
    in sync with the module-level CAMERA_NAMES tuple.
    """

    @property
    def _default_sensor_configs(self):
        resolution = 224
        # (name, eye, look-at target) for each fixed viewpoint.
        placements = (
            ("front", [0.30, 0.00, 0.62], [-0.06, 0.00, 0.04]),
            ("left", [0.22, 0.34, 0.34], [-0.02, 0.02, 0.03]),
            ("right", [0.22, -0.34, 0.34], [-0.02, -0.02, 0.03]),
        )
        configs = []
        for name, eye, target in placements:
            configs.append(
                CameraConfig(
                    name,
                    pose=sapien_utils.look_at(eye=eye, target=target),
                    width=resolution,
                    height=resolution,
                    fov=np.pi / 2,
                    near=0.01,
                    far=100.0,
                )
            )
        return configs
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _np(value: Any, *, dtype: np.dtype | None = None) -> np.ndarray:
|
| 198 |
+
if isinstance(value, np.ndarray):
|
| 199 |
+
array = value
|
| 200 |
+
elif isinstance(value, Tensor):
|
| 201 |
+
array = value.detach().cpu().numpy()
|
| 202 |
+
else:
|
| 203 |
+
array = np.asarray(value)
|
| 204 |
+
if dtype is not None:
|
| 205 |
+
array = array.astype(dtype, copy=False)
|
| 206 |
+
return array
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _vec3(value: Any) -> np.ndarray:
    """Flatten *value* to float32 and return its first three components (xyz)."""
    return _np(value, dtype=np.float32).reshape(-1)[:3]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _camera_intrinsic_from_param(param: dict[str, Any]) -> np.ndarray:
    """Extract a 3x3 camera intrinsic matrix from a sensor-param dict.

    Tries the known key aliases in preference order, dropping a leading
    batch dimension when present; falls back to the identity matrix.
    """
    for key in ("intrinsic_cv", "intrinsic", "cam_intrinsic"):
        if key not in param:
            continue
        matrix = _np(param[key], dtype=np.float32)
        if matrix.ndim == 3:
            return matrix[0]
        return matrix
    return np.eye(3, dtype=np.float32)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _camera_extrinsic_from_param(param: dict[str, Any]) -> np.ndarray:
    """Extract a 4x4 camera extrinsic matrix from a sensor-param dict.

    Tries the known key aliases in preference order, dropping a leading
    batch dimension when present; falls back to the identity matrix.
    """
    for key in ("cam2world_gl", "cam2world", "extrinsic_cv", "extrinsic"):
        if key not in param:
            continue
        matrix = _np(param[key], dtype=np.float32)
        if matrix.ndim == 3:
            return matrix[0]
        return matrix
    return np.eye(4, dtype=np.float32)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _convert_depth(depth: np.ndarray) -> np.ndarray:
|
| 230 |
+
depth = depth.astype(np.float32, copy=False)
|
| 231 |
+
if np.issubdtype(depth.dtype, np.integer):
|
| 232 |
+
depth = depth / 1000.0
|
| 233 |
+
return depth
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _build_proprio(env: PickClutterRevealEnv) -> np.ndarray:
    """Assemble a fixed-size (PROPRIO_DIM,) proprioception vector.

    Concatenates joint positions, joint velocities, the raw TCP pose, and a
    scalar gripper width, then truncates or zero-pads to PROPRIO_DIM.
    """
    base = env.unwrapped
    qpos = _np(base.agent.robot.get_qpos(), dtype=np.float32).reshape(-1)
    qvel = _np(base.agent.robot.get_qvel(), dtype=np.float32).reshape(-1)
    tcp_pose = _np(base.agent.tcp.pose.raw_pose, dtype=np.float32).reshape(-1)
    # Gripper width = sum of the last two qpos entries (assumes these are the
    # finger joints for the agent in use — TODO confirm). keepdims=True keeps
    # the sum as a length-1 array so it can be concatenated below.
    gripper_width = qpos[-2:].sum(keepdims=True).astype(np.float32)
    pieces = [qpos, qvel, tcp_pose, gripper_width]
    flat = np.concatenate(pieces, axis=0)
    if flat.shape[0] >= PROPRIO_DIM:
        # Truncate: keep the leading PROPRIO_DIM entries.
        return flat[:PROPRIO_DIM]
    # Zero-pad up to the fixed width.
    padded = np.zeros((PROPRIO_DIM,), dtype=np.float32)
    padded[: flat.shape[0]] = flat
    return padded
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _extract_sensor_bundle(obs: dict[str, Any]) -> dict[str, np.ndarray]:
    """Repackage a raw ManiSkill observation into stacked per-camera arrays.

    Iterates CAMERA_NAMES in order and returns a dict with, for V views:
      images             (V, H, W, 3) uint8
      depths             (V, 1, H, W) float32, metres
      depth_valid        (V, 1, H, W) float32 mask (depth > 1e-5)
      segmentations      (V, H, W)   int32
      camera_intrinsics  (V, 3, 3)   float32
      camera_extrinsics  (V, 4, 4)   float32
    """
    sensor_data = obs["sensor_data"]
    sensor_param = obs["sensor_param"]
    rgb_views: list[np.ndarray] = []
    depth_views: list[np.ndarray] = []
    seg_views: list[np.ndarray] = []
    intrinsics: list[np.ndarray] = []
    extrinsics: list[np.ndarray] = []
    for camera_name in CAMERA_NAMES:
        view = sensor_data[camera_name]
        param = sensor_param[camera_name]
        rgb = _np(view["rgb"], dtype=np.uint8)
        depth = _np(view["depth"])
        segmentation = _np(view["segmentation"])
        # Drop the leading batch dimension when the env returns batched obs.
        rgb = rgb[0] if rgb.ndim == 4 else rgb
        depth = depth[0] if depth.ndim == 4 else depth
        segmentation = segmentation[0] if segmentation.ndim == 4 else segmentation
        # Squeeze trailing single-channel axes to plain (H, W) maps.
        if depth.ndim == 3 and depth.shape[-1] == 1:
            depth = depth[..., 0]
        if segmentation.ndim == 3 and segmentation.shape[-1] == 1:
            segmentation = segmentation[..., 0]
        rgb_views.append(rgb.astype(np.uint8, copy=False))
        depth_views.append(_convert_depth(depth))
        seg_views.append(segmentation.astype(np.int32, copy=False))
        intrinsics.append(_camera_intrinsic_from_param(param))
        extrinsics.append(_camera_extrinsic_from_param(param))
    depth_stack = np.stack(depth_views, axis=0).astype(np.float32)
    # Validity mask: strictly positive depth readings only.
    depth_valid = (depth_stack > 1e-5).astype(np.float32)
    return {
        "images": np.stack(rgb_views, axis=0),
        "depths": depth_stack[:, None, :, :],
        "depth_valid": depth_valid[:, None, :, :],
        "segmentations": np.stack(seg_views, axis=0),
        "camera_intrinsics": np.stack(intrinsics, axis=0).astype(np.float32),
        "camera_extrinsics": np.stack(extrinsics, axis=0).astype(np.float32),
    }
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def _target_actor(env: PickClutterRevealEnv) -> Any:
    """Return the first (and only, in this single-scene setup) target actor."""
    return env.unwrapped.target_object._objs[0]
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _all_scene_actors(env: PickClutterRevealEnv) -> list[Any]:
    """Return every clutter actor in the scene (target included)."""
    return list(env.unwrapped.all_objects._objs)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def _target_position(env: PickClutterRevealEnv) -> np.ndarray:
    """World-frame xyz position of the target actor."""
    return _vec3(_target_actor(env).pose.p)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def _all_positions(env: PickClutterRevealEnv) -> dict[str, np.ndarray]:
    """Map actor name -> world-frame xyz position for all scene actors."""
    return {actor.name: _vec3(actor.pose.p) for actor in _all_scene_actors(env)}
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _nearest_non_target_distance(env: PickClutterRevealEnv) -> float:
    """Planar (XY) distance from the target to its closest clutter neighbour.

    Returns 1.0 when the target is the only actor in the scene.
    """
    target = _target_actor(env)
    target_xy = _vec3(target.pose.p)[:2]
    neighbour_distances = (
        float(np.linalg.norm(_vec3(actor.pose.p)[:2] - target_xy))
        for actor in _all_scene_actors(env)
        if actor.name != target.name
    )
    return float(min(neighbour_distances, default=1.0))
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _success_from_state(env: PickClutterRevealEnv) -> bool:
    """Success = target pulled past the extraction line with enough clearance."""
    target = _target_position(env)
    return bool(target[1] <= EXTRACTION_LINE_Y and _nearest_non_target_distance(env) >= MIN_CLEARANCE_FOR_SUCCESS)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def _clearance_score(env: PickClutterRevealEnv) -> float:
    """Nearest-neighbour clearance mapped linearly from [0.03 m, 0.12 m] to [0, 1]."""
    return float(np.clip((_nearest_non_target_distance(env) - 0.03) / 0.09, 0.0, 1.0))
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def _extraction_progress(env: PickClutterRevealEnv) -> float:
    """Fraction of the pull from y = -0.05 to the extraction line, clipped to [0, 1]."""
    y_value = _target_position(env)[1]
    return float(np.clip(((-0.05) - y_value) / ((-0.05) - EXTRACTION_LINE_Y), 0.0, 1.0))
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def _target_visibility(obs_bundle: dict[str, np.ndarray], target_seg_id: int) -> float:
|
| 333 |
+
segmentation = obs_bundle["segmentations"]
|
| 334 |
+
fractions = [(view == int(target_seg_id)).mean() for view in segmentation]
|
| 335 |
+
return float(np.clip(np.mean(fractions) * 80.0, 0.0, 1.0))
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def _snapshot_env(env: PickClutterRevealEnv) -> dict[str, Any]:
    """Capture a restorable snapshot: full sim state dict plus the goal position.

    goal_pos is copied separately because it is not part of get_state_dict().
    """
    base = env.unwrapped
    return {
        "state_dict": base.get_state_dict(),
        "goal_pos": _np(base.goal_pos, dtype=np.float32).copy(),
    }
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def _restore_env(env: PickClutterRevealEnv, snapshot: dict[str, Any]) -> None:
    """Restore a snapshot produced by _snapshot_env.

    Re-applies the sim state dict, then re-installs the goal position and
    moves the goal visual site to match it.
    """
    base = env.unwrapped
    state_dict = snapshot["state_dict"]
    goal_pos = torch.as_tensor(snapshot["goal_pos"], dtype=torch.float32, device=base.device)
    base.set_state_dict(state_dict)
    # view_as keeps whatever batch shape base.goal_pos already has.
    base.goal_pos = goal_pos.view_as(base.goal_pos)
    base.goal_site.set_pose(Pose.create_from_pq(base.goal_pos))
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _sync_env_state(src_env: PickClutterRevealEnv, dst_env: PickClutterRevealEnv) -> None:
    """Copy the full sim state (and goal) from *src_env* into *dst_env*."""
    _restore_env(dst_env, _snapshot_env(src_env))
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def _canonical_chunks() -> dict[str, np.ndarray]:
    """Build the prototype action chunk (8 steps x 14 dims) for each mode.

    "base_action" is the all-zeros chunk; every other mode's prototype is
    produced by semantic_macro_chunk applied to that zero base.
    """
    base = torch.zeros((1, 8, 14), dtype=torch.float32)
    chunks: dict[str, np.ndarray] = {"base_action": base.squeeze(0).numpy().astype(np.float32)}
    for mode_name in MODE_ORDER[1:]:
        chunk = semantic_macro_chunk(base, task_name=TASK_NAME, mode_name=mode_name).squeeze(0).cpu().numpy()
        chunks[mode_name] = chunk.astype(np.float32)
    return chunks
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# Mode-name -> prototype chunk lookup, computed once at import time.
CANONICAL_CHUNKS = _canonical_chunks()
# Boolean mask over TASK_METRIC_NAMES marking the metrics supervised from sim state.
STATE_METRIC_MASK = np.asarray(
    [metric_name in STATE_SUPERVISION_METRICS for metric_name in TASK_METRIC_NAMES],
    dtype=np.bool_,
)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def _classify_mode_from_chunk(chunk: np.ndarray) -> str:
    """Classify an action chunk as the canonical mode with the smallest
    mean absolute distance to its prototype (first mode wins ties)."""
    candidate = np.asarray(chunk, dtype=np.float32)
    best_mode = None
    best_distance = math.inf
    for mode_name, prototype in CANONICAL_CHUNKS.items():
        distance = float(np.mean(np.abs(candidate - prototype)))
        if distance < best_distance:
            best_distance = distance
            best_mode = mode_name
    return best_mode
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def _gripper_action(open_gripper: bool) -> float:
|
| 385 |
+
return 1.0 if open_gripper else -1.0
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _repeat_delta(env: PickClutterRevealEnv, delta_xyz: Sequence[float], *, open_gripper: bool, steps: int) -> dict[str, Any]:
    """Step the env up to *steps* times with a constant (dx, dy, dz, grip) action.

    Stops early if the env reports termination or truncation. Returns a dict
    with the last observation and info (freshly fetched when steps <= 0 and
    nothing was stepped).

    NOTE(review): "terminated"/"truncated" are hard-coded to False even when
    the loop exited early on a real termination — confirm this is intentional
    for the smoke rollout.
    """
    last_obs: dict[str, Any] | None = None
    # Action layout: [dx, dy, dz, gripper] on a batch of 1.
    action = np.zeros((1, 4), dtype=np.float32)
    action[0, :3] = np.asarray(delta_xyz, dtype=np.float32)
    action[0, 3] = _gripper_action(open_gripper)
    for _ in range(int(steps)):
        obs, _, terminated, truncated, info = env.step(action)
        last_obs = obs
        if bool(np.asarray(terminated).reshape(-1)[0]) or bool(np.asarray(truncated).reshape(-1)[0]):
            break
    return {
        # If the loop never ran, last_obs is None and info is unbound — the
        # else branches below avoid touching info in that case.
        "obs": last_obs if last_obs is not None else env.get_obs(env.get_info()),
        "terminated": False,
        "truncated": False,
        "info": info if last_obs is not None else env.get_info(),
    }
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def _move_tcp_to(
    env: PickClutterRevealEnv,
    target_xyz: Sequence[float],
    *,
    open_gripper: bool,
    max_steps: int = 120,
    tolerance: float = 0.008,
) -> dict[str, Any]:
    """Servo the TCP toward *target_xyz* with proportional delta actions.

    Each step commands the clipped error (scaled by 1/0.04) as a delta;
    stops when within *tolerance* metres, after *max_steps* steps, or on
    env termination/truncation. Returns the last observation and info.
    """
    last_obs: dict[str, Any] | None = None
    target = np.asarray(target_xyz, dtype=np.float32)
    # Pre-fetch info so the return dict is valid even if the loop exits
    # on the tolerance check before any step is taken.
    info = env.get_info()
    for _ in range(int(max_steps)):
        tcp = _vec3(env.unwrapped.agent.tcp.pose.p)
        delta = target - tcp
        if float(np.linalg.norm(delta)) <= float(tolerance):
            break
        action = np.zeros((1, 4), dtype=np.float32)
        # P-control: normalise the error by 4 cm and clip to the action range.
        action[0, :3] = np.clip(delta / 0.04, -1.0, 1.0)
        action[0, 3] = _gripper_action(open_gripper)
        obs, _, terminated, truncated, info = env.step(action)
        last_obs = obs
        if bool(np.asarray(terminated).reshape(-1)[0]) or bool(np.asarray(truncated).reshape(-1)[0]):
            break
    return {
        "obs": last_obs if last_obs is not None else env.get_obs(info),
        "info": info,
    }
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def _find_path_blocker(env: PickClutterRevealEnv) -> np.ndarray | None:
    """Locate the clutter actor most likely to block the extraction path.

    First pass: among non-target actors inside the extraction corridor
    (y no more than 6 cm behind the target, x within 10 cm laterally),
    return the position of the one closest to the target in the XY plane.
    Fallback: the nearest non-target actor overall, or None when the target
    is the only actor in the scene.
    """
    target = _target_position(env)
    target_name = _target_actor(env).name
    blockers: list[tuple[float, np.ndarray]] = []
    for actor in _all_scene_actors(env):
        if actor.name == target_name:
            continue
        position = _vec3(actor.pose.p)
        # Corridor membership test.
        if position[1] <= target[1] + 0.06 and abs(position[0] - target[0]) <= 0.10:
            blockers.append((float(np.linalg.norm(position[:2] - target[:2])), position))
    if blockers:
        blockers.sort(key=lambda item: item[0])
        return blockers[0][1]
    # Fallback: nearest non-target actor anywhere.
    nearest: tuple[float, np.ndarray] | None = None
    for actor in _all_scene_actors(env):
        if actor.name == target_name:
            continue
        position = _vec3(actor.pose.p)
        distance = float(np.linalg.norm(position[:2] - target[:2]))
        if nearest is None or distance < nearest[0]:
            nearest = (distance, position)
    return None if nearest is None else nearest[1]
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def _execute_push(
    env: PickClutterRevealEnv,
    *,
    anchor_xyz: np.ndarray,
    pre_offset: np.ndarray,
    push_delta: np.ndarray,
    push_steps: int,
) -> dict[str, Any]:
    """Run an approach / descend / push / retreat primitive.

    1. Move the TCP above the anchor (8 cm up) offset by *pre_offset*.
    2. Descend to the pre-push contact point.
    3. Apply *push_delta* for *push_steps* env steps.
    4. Lift straight up to z = 0.10 at the current XY to clear the clutter.

    Returns the dict produced by the push phase (see _repeat_delta).
    """
    hover = anchor_xyz + np.array([0.0, 0.0, 0.08], dtype=np.float32) + pre_offset
    _move_tcp_to(env, hover, open_gripper=True)
    _move_tcp_to(env, anchor_xyz + pre_offset, open_gripper=True, max_steps=100, tolerance=0.010)
    result = _repeat_delta(env, push_delta, open_gripper=True, steps=push_steps)
    # Query the TCP pose once; the original evaluated _vec3(...tcp.pose.p)
    # twice back-to-back to build the retreat target — both reads saw the
    # same unchanged pose, so a single read is equivalent and cheaper.
    tcp = _vec3(env.unwrapped.agent.tcp.pose.p)
    _move_tcp_to(
        env,
        np.array([tcp[0], tcp[1], 0.10], dtype=np.float32),
        open_gripper=True,
        max_steps=80,
        tolerance=0.012,
    )
    return result
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def _execute_mode(env: PickClutterRevealEnv, mode_name: str) -> dict[str, Any]:
    """Execute one scripted macro *mode_name* and return its rollout result.

    Every mode is realised as a push primitive (_execute_push) or a hover +
    short pull (_repeat_delta). Raises KeyError for an unknown mode name.
    """
    target = _target_position(env)
    blocker = _find_path_blocker(env)
    if mode_name == "retrieve":
        # Long pull of the target itself toward the extraction line (-y).
        return _execute_push(
            env,
            anchor_xyz=target,
            pre_offset=np.array([0.0, 0.035, 0.026], dtype=np.float32),
            push_delta=np.array([0.0, -0.7, 0.0], dtype=np.float32),
            push_steps=18,
        )
    if mode_name == "insert_actor":
        # Shorter, gentler pull from further behind the target.
        return _execute_push(
            env,
            anchor_xyz=target,
            pre_offset=np.array([0.0, 0.045, 0.028], dtype=np.float32),
            push_delta=np.array([0.0, -0.4, 0.0], dtype=np.float32),
            push_steps=10,
        )
    if mode_name == "widen_gap":
        # Push the blocker sideways, away from the target's x position.
        anchor = blocker if blocker is not None else target
        direction = -1.0 if anchor[0] >= target[0] else 1.0
        return _execute_push(
            env,
            anchor_xyz=anchor,
            pre_offset=np.array([0.0, 0.025, 0.028], dtype=np.float32),
            push_delta=np.array([0.75 * direction, -0.12, 0.0], dtype=np.float32),
            push_steps=18,
        )
    if mode_name == "sweep_left":
        # Sweep the blocker toward -x (approach offset slightly from +x).
        anchor = blocker if blocker is not None else target
        return _execute_push(
            env,
            anchor_xyz=anchor,
            pre_offset=np.array([0.015, 0.025, 0.028], dtype=np.float32),
            push_delta=np.array([-0.70, -0.10, 0.0], dtype=np.float32),
            push_steps=14,
        )
    if mode_name == "sweep_right":
        # Mirror of sweep_left: sweep the blocker toward +x.
        anchor = blocker if blocker is not None else target
        return _execute_push(
            env,
            anchor_xyz=anchor,
            pre_offset=np.array([-0.015, 0.025, 0.028], dtype=np.float32),
            push_delta=np.array([0.70, -0.10, 0.0], dtype=np.float32),
            push_steps=14,
        )
    if mode_name == "pin_canopy":
        # Short push in +y to pin/compact clutter behind the blocker.
        anchor = blocker if blocker is not None else target
        return _execute_push(
            env,
            anchor_xyz=anchor,
            pre_offset=np.array([0.0, -0.015, 0.028], dtype=np.float32),
            push_delta=np.array([0.0, 0.35, 0.0], dtype=np.float32),
            push_steps=10,
        )
    if mode_name in {"maintain_gap", "base_action"}:
        # Passive modes: hover above/behind the target, then a brief -y nudge.
        _move_tcp_to(
            env,
            np.array([target[0], target[1] + 0.02, 0.10], dtype=np.float32),
            open_gripper=True,
            max_steps=60,
            tolerance=0.015,
        )
        return _repeat_delta(env, np.array([0.0, -0.10, 0.0], dtype=np.float32), open_gripper=True, steps=4)
    raise KeyError(f"Unsupported mode: {mode_name}")
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def _candidate_metrics(
    env: PickClutterRevealEnv,
    *,
    start_positions: dict[str, np.ndarray],
    current_obs_bundle: dict[str, np.ndarray] | None = None,
) -> dict[str, float]:
    """Score the env state after a candidate mode rollout.

    Returns scalar metrics in [0, 1] (success is 0/1): retrieval_success,
    disturbance (mean non-target XY drift from *start_positions*, normalised
    by 10 cm), visibility (0.0 when no obs bundle is given), clearance, and
    extraction progress.
    """
    positions = _all_positions(env)
    target_name = _target_actor(env).name
    non_target_displacements = []
    for name, start_position in start_positions.items():
        # Ignore the target itself and any actor no longer present.
        if name == target_name or name not in positions:
            continue
        non_target_displacements.append(float(np.linalg.norm((positions[name] - start_position)[:2])))
    disturbance = float(np.clip(np.mean(non_target_displacements) / 0.10, 0.0, 1.0)) if non_target_displacements else 0.0
    visibility = 0.0
    if current_obs_bundle is not None:
        # per_scene_id is assumed to be the segmentation label of the target
        # actor — TODO confirm; -1 matches nothing when the attr is absent.
        visibility = _target_visibility(current_obs_bundle, getattr(_target_actor(env), "per_scene_id", -1))
    return {
        "retrieval_success": float(_success_from_state(env)),
        "disturbance": disturbance,
        "visibility": visibility,
        "clearance": _clearance_score(env),
        "progress": _extraction_progress(env),
    }
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def _mean_non_target_displacement(
|
| 575 |
+
start_positions: dict[str, np.ndarray],
|
| 576 |
+
current_positions: dict[str, np.ndarray],
|
| 577 |
+
*,
|
| 578 |
+
target_name: str,
|
| 579 |
+
) -> float:
|
| 580 |
+
displacements = []
|
| 581 |
+
for actor_name, start_position in start_positions.items():
|
| 582 |
+
if actor_name == target_name or actor_name not in current_positions:
|
| 583 |
+
continue
|
| 584 |
+
displacements.append(float(np.linalg.norm((current_positions[actor_name] - start_position)[:2])))
|
| 585 |
+
if not displacements:
|
| 586 |
+
return 0.0
|
| 587 |
+
return float(np.mean(displacements))
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def _current_state_targets(
    env: PickClutterRevealEnv,
    *,
    obs_bundle: dict[str, np.ndarray],
    candidate_metrics: Sequence[dict[str, float]],
    episode_start_positions: dict[str, np.ndarray],
    selected_mode: str,
) -> dict[str, Any]:
    """Build supervision targets for the adapter's state head from the live env.

    ``candidate_metrics`` is ordered like MODE_ORDER (index 0 is assumed to be
    ``base_action`` — TODO confirm against MODE_ORDER's definition). All derived
    scores below are hand-tuned linear blends clipped to [0, 1].

    Returns:
        Dict with ``support_mode`` (int label), ``corridor_feasible``
        (3 x NUM_APPROACH_TEMPLATES float32), ``persistence_horizon`` (3-vector
        of step counts), ``disturbance_cost``, ``state_confidence_target``,
        ``task_metric_mask``, plus one np.float32 entry per task metric.
    """
    metrics_by_name = {mode_name: payload for mode_name, payload in zip(MODE_ORDER, candidate_metrics)}
    current_positions = _all_positions(env)
    target_name = _target_actor(env).name
    # Disturbance: mean non-target xy drift normalized by 0.10, clipped to [0, 1].
    current_disturbance = float(
        np.clip(
            _mean_non_target_displacement(
                episode_start_positions,
                current_positions,
                target_name=target_name,
            )
            / 0.10,
            0.0,
            1.0,
        )
    )
    current_visibility = _target_visibility(obs_bundle, getattr(_target_actor(env), "per_scene_id", -1))
    current_clearance = _clearance_score(env)
    current_progress = _extraction_progress(env)
    # Derived scalar scores: each is a fixed-weight blend of the raw signals.
    base_gap = float(np.clip(max(current_clearance, current_progress), 0.0, 1.0))
    support_stability = float(np.clip(1.0 - 0.5 * current_disturbance, 0.0, 1.0))
    hold_quality = float(np.clip(0.5 * (support_stability + max(current_clearance, current_progress)), 0.0, 1.0))
    opening_quality = float(
        np.clip(0.55 * current_progress + 0.25 * current_clearance + 0.20 * current_visibility, 0.0, 1.0)
    )
    actor_feasibility = float(np.clip(0.6 * current_clearance + 0.4 * max(current_visibility, current_progress), 0.0, 1.0))
    reocclusion_rate = float(np.clip(1.0 - max(current_clearance, current_visibility), 0.0, 1.0))
    insertable_actor_corridor = float(np.clip(0.6 * actor_feasibility + 0.4 * base_gap, 0.0, 1.0))
    insertion_corridor = float(np.clip(0.5 * actor_feasibility + 0.5 * base_gap, 0.0, 1.0))
    layer_separation = float(np.clip(0.7 * base_gap + 0.3 * actor_feasibility, 0.0, 1.0))
    fold_preservation = float(np.clip(1.0 - current_disturbance, 0.0, 1.0))
    lift_too_much_risk = float(np.clip(current_disturbance + 0.5 * max(base_gap - 0.5, 0.0), 0.0, 1.0))
    task_metrics = {
        "opening_quality": opening_quality,
        "actor_feasibility_score": actor_feasibility,
        # gap_width is an affine proxy in physical units (0.03..0.24), not [0, 1].
        "gap_width": float(0.03 + 0.21 * base_gap),
        "damage_proxy": current_disturbance,
        "release_collapse_rate": reocclusion_rate,
        "target_visibility_confidence": current_visibility,
        "insertable_actor_corridor": insertable_actor_corridor,
        "insertion_corridor": insertion_corridor,
        "hold_quality": hold_quality,
        "layer_separation_quality": layer_separation,
        "fold_preservation": fold_preservation,
        "top_layer_stability": support_stability,
        "lift_too_much_risk": lift_too_much_risk,
    }

    # Best-case reveal statistics are pooled (max) over all reveal-mode candidates.
    base_metrics = metrics_by_name["base_action"]
    insert_metrics = metrics_by_name["insert_actor"]
    retrieve_metrics = metrics_by_name["retrieve"]
    reveal_candidates = [metrics_by_name[mode_name] for mode_name in REVEAL_MODES]
    reveal_access = max(candidate["candidate_actor_feasibility_auc"] for candidate in reveal_candidates)
    reveal_reveal = max(candidate["candidate_reveal_achieved"] for candidate in reveal_candidates)
    reveal_hold = max(candidate["candidate_hold_persistence"] for candidate in reveal_candidates)
    reveal_visibility = max(candidate["candidate_visibility_integral"] for candidate in reveal_candidates)

    # One feasibility score per corridor family (reveal / transfer / passive).
    reveal_corridor = float(
        np.clip(
            0.45 * opening_quality
            + 0.30 * reveal_access
            + 0.15 * reveal_reveal
            + 0.10 * reveal_visibility
            - 0.10 * current_disturbance,
            0.0,
            1.0,
        )
    )
    transfer_corridor = float(
        np.clip(
            0.45 * insertable_actor_corridor
            + 0.30 * insert_metrics["candidate_actor_feasibility_auc"]
            + 0.15 * insert_metrics["candidate_reveal_achieved"]
            + 0.10 * insert_metrics["candidate_visibility_integral"]
            - 0.15 * current_disturbance,
            0.0,
            1.0,
        )
    )
    passive_corridor = float(
        np.clip(
            0.55 * retrieve_metrics["candidate_retrieval_success"]
            + 0.20 * retrieve_metrics["candidate_actor_feasibility_auc"]
            + 0.15 * current_progress
            + 0.10 * current_clearance
            - 0.10 * current_disturbance,
            0.0,
            1.0,
        )
    )
    # Each corridor score is broadcast across all approach templates.
    corridor_feasible = np.stack(
        [
            np.full((NUM_APPROACH_TEMPLATES,), reveal_corridor, dtype=np.float32),
            np.full((NUM_APPROACH_TEMPLATES,), transfer_corridor, dtype=np.float32),
            np.full((NUM_APPROACH_TEMPLATES,), passive_corridor, dtype=np.float32),
        ],
        axis=0,
    )
    # Per-corridor expected hold duration, as a fraction of the rollout horizon.
    persistence_horizon = np.asarray(
        [
            ROLL_OUT_HORIZON
            * float(np.clip(0.35 * hold_quality + 0.35 * reveal_hold + 0.30 * reveal_corridor, 0.0, 1.0)),
            ROLL_OUT_HORIZON
            * float(
                np.clip(
                    0.30 * hold_quality
                    + 0.35 * insert_metrics["candidate_hold_persistence"]
                    + 0.35 * transfer_corridor,
                    0.0,
                    1.0,
                )
            ),
            ROLL_OUT_HORIZON
            * float(
                np.clip(
                    0.25 * hold_quality
                    + 0.35 * retrieve_metrics["candidate_hold_persistence"]
                    + 0.40 * passive_corridor,
                    0.0,
                    1.0,
                )
            ),
        ],
        dtype=np.float32,
    )

    # Support-mode label: explicit selected mode wins; otherwise utility margins
    # over the base action decide (0.15 utility margin threshold).
    retrieve_margin = float(retrieve_metrics["candidate_utility"] - base_metrics["candidate_utility"])
    insert_margin = float(insert_metrics["candidate_utility"] - base_metrics["candidate_utility"])
    if selected_mode == "retrieve" or (retrieve_metrics["candidate_retrieval_success"] >= 0.5 and retrieve_margin >= 0.15):
        support_mode = SUPPORT_MODE_PASSIVE
    elif selected_mode == "insert_actor" or (insert_margin >= 0.15 and transfer_corridor >= 0.40):
        support_mode = SUPPORT_MODE_TRANSFER
    elif selected_mode in REVEAL_MODES or selected_mode == "maintain_gap":
        support_mode = SUPPORT_MODE_HOLD
    elif selected_mode == "base_action":
        support_mode = SUPPORT_MODE_PASSIVE if passive_corridor >= 0.65 and retrieve_margin >= 0.05 else SUPPORT_MODE_HOLD
    else:
        support_mode = SUPPORT_MODE_HOLD

    # Gate target: 1.0 only when a non-base candidate beats base by >= 0.15 utility.
    best_non_base_utility = max(float(payload["candidate_utility"]) for payload in candidate_metrics[1:])
    intervention_warranted = selected_mode != "base_action" and best_non_base_utility >= float(base_metrics["candidate_utility"]) + 0.15

    return {
        "support_mode": int(support_mode),
        "corridor_feasible": corridor_feasible,
        "persistence_horizon": persistence_horizon,
        "disturbance_cost": np.float32(current_disturbance),
        "state_confidence_target": np.float32(1.0 if intervention_warranted else 0.0),
        "task_metric_mask": STATE_METRIC_MASK.copy(),
        **{metric_name: np.float32(metric_value) for metric_name, metric_value in task_metrics.items()},
    }
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def _mode_support_mode(mode_name: str, current_support_mode: int) -> int:
    """Map a candidate mode name to the support-mode label it implies.

    Modes with no intrinsic support semantics keep the current label.
    """
    if mode_name == "maintain_gap" or mode_name in REVEAL_MODES:
        return SUPPORT_MODE_HOLD
    if mode_name in TRANSFER_MODES:
        return SUPPORT_MODE_TRANSFER
    if mode_name in RETRIEVE_MODES:
        return SUPPORT_MODE_PASSIVE
    return int(current_support_mode)
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def _mode_progress_schedule(mode_name: str) -> np.ndarray:
    """Five-waypoint progress fractions used to roll metrics toward their endpoint.

    Reveal/transfer/retrieve modes all reach 1.0; the fallback schedule tops
    out at 0.54 (partial progress only).
    """
    if mode_name in REVEAL_MODES:
        waypoints = [0.18, 0.38, 0.62, 0.84, 1.0]
    elif mode_name in TRANSFER_MODES:
        waypoints = [0.22, 0.44, 0.66, 0.86, 1.0]
    elif mode_name in RETRIEVE_MODES:
        waypoints = [0.34, 0.56, 0.76, 0.92, 1.0]
    else:
        waypoints = [0.10, 0.22, 0.34, 0.44, 0.54]
    return np.asarray(waypoints, dtype=np.float32)
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def _scalar_rollout(start: float, end: float, schedule: np.ndarray) -> np.ndarray:
|
| 771 |
+
return np.clip((1.0 - schedule) * float(start) + schedule * float(end), 0.0, 1.0).astype(np.float32)
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
def _candidate_rollout_targets(
    *,
    mode_name: str,
    state_targets: dict[str, Any],
    candidate_payload: dict[str, float],
) -> dict[str, np.ndarray]:
    """Synthesize per-step rollout supervision by interpolating state → outcome.

    Current-state scalars (from ``state_targets``) are blended toward the
    candidate's immediate post-execution metrics (``candidate_payload``) along
    the mode-specific progress schedule, then combined into corridor scores,
    persistence horizons, and dense map targets.

    Returns:
        Dict of ``candidate_rollout_*`` arrays, all with leading dimension
        len(schedule) (support_mode uses ROLL_OUT_HORIZON steps instead).
    """
    schedule = _mode_progress_schedule(mode_name)
    # Start of the interpolation: the current state's supervision targets.
    start_visibility = float(state_targets["target_visibility_confidence"])
    start_access = float(state_targets["actor_feasibility_score"])
    start_persistence = float(np.clip(state_targets["hold_quality"], 0.0, 1.0))
    start_support = float(np.clip(state_targets["top_layer_stability"], 0.0, 1.0))
    start_reocclusion = float(np.clip(state_targets["release_collapse_rate"], 0.0, 1.0))
    start_disturbance = float(np.clip(state_targets["disturbance_cost"], 0.0, 1.0))
    # NOTE(review): clearance deliberately reuses actor_feasibility_score here —
    # confirm this aliasing is intended rather than a stale key.
    start_clearance = float(np.clip(state_targets["actor_feasibility_score"], 0.0, 1.0))
    start_grasp = float(np.clip(max(start_visibility, start_access), 0.0, 1.0))

    # End of the interpolation: the candidate's immediate outcome metrics.
    end_visibility = float(np.clip(candidate_payload["candidate_immediate_visibility"], 0.0, 1.0))
    end_access = float(np.clip(candidate_payload["candidate_immediate_access"], 0.0, 1.0))
    end_progress = float(np.clip(candidate_payload["candidate_immediate_progress"], 0.0, 1.0))
    end_disturbance = float(np.clip(candidate_payload["candidate_immediate_disturbance"], 0.0, 1.0))
    end_support = float(np.clip(candidate_payload["candidate_immediate_support_stability"], 0.0, 1.0))
    end_persistence = float(np.clip(candidate_payload["candidate_immediate_hold_persistence"], 0.0, 1.0))
    end_reocclusion = float(np.clip(candidate_payload["candidate_immediate_reocclusion"], 0.0, 1.0))
    end_clearance = float(np.clip(max(end_access, end_progress), 0.0, 1.0))
    end_grasp = float(np.clip(max(end_visibility, 0.5 * end_access + 0.5 * end_progress), 0.0, 1.0))

    # Mode-specific warm start: pull the start point partway toward the endpoint
    # so early waypoints already reflect the committed maneuver.
    if mode_name in TRANSFER_MODES:
        start_visibility = max(start_visibility, 0.35 * end_visibility)
        start_access = max(start_access, 0.40 * end_access)
        start_persistence = max(start_persistence, 0.45 * end_persistence)
        start_support = max(start_support, 0.50 * end_support)
    elif mode_name in RETRIEVE_MODES:
        start_visibility = max(start_visibility, 0.55 * end_visibility)
        start_access = max(start_access, 0.70 * end_access)
        start_persistence = max(start_persistence, 0.65 * end_persistence)
        start_support = max(start_support, 0.65 * end_support)
        start_reocclusion = min(start_reocclusion, max(0.4 * end_reocclusion, 0.0))

    visibility = _scalar_rollout(start_visibility, end_visibility, schedule)
    access = _scalar_rollout(start_access, end_access, schedule)
    persistence = _scalar_rollout(start_persistence, end_persistence, schedule)
    support = _scalar_rollout(start_support, end_support, schedule)
    reocclusion = _scalar_rollout(start_reocclusion, end_reocclusion, schedule)
    disturbance = _scalar_rollout(start_disturbance, end_disturbance, schedule)
    clearance = _scalar_rollout(start_clearance, end_clearance, schedule)
    grasp = _scalar_rollout(start_grasp, end_grasp, schedule)

    # Per-step corridor scores (fixed-weight blends, clipped to [0, 1]).
    reveal_corridor = np.clip(
        0.38 * visibility + 0.34 * access + 0.22 * support - 0.12 * disturbance,
        0.0,
        1.0,
    )
    transfer_corridor = np.clip(
        0.30 * visibility + 0.38 * access + 0.18 * persistence + 0.14 * support - 0.12 * disturbance,
        0.0,
        1.0,
    )
    passive_corridor = np.clip(
        0.22 * visibility + 0.42 * access + 0.20 * persistence + 0.16 * grasp - 0.14 * disturbance - 0.10 * reocclusion,
        0.0,
        1.0,
    )
    # Boost the corridor matching the executing mode; damp the competing ones.
    if mode_name in REVEAL_MODES:
        reveal_corridor = np.clip(reveal_corridor + 0.14, 0.0, 1.0)
        passive_corridor = np.clip(0.75 * passive_corridor, 0.0, 1.0)
    elif mode_name in TRANSFER_MODES:
        transfer_corridor = np.clip(transfer_corridor + 0.16, 0.0, 1.0)
    elif mode_name in RETRIEVE_MODES:
        passive_corridor = np.clip(passive_corridor + 0.20, 0.0, 1.0)
        reveal_corridor = np.clip(0.60 * reveal_corridor, 0.0, 1.0)
    else:
        reveal_corridor = np.clip(0.85 * reveal_corridor, 0.0, 1.0)
        transfer_corridor = np.clip(0.75 * transfer_corridor, 0.0, 1.0)
        passive_corridor = np.clip(0.80 * passive_corridor, 0.0, 1.0)

    # (steps, 3 corridors, NUM_APPROACH_TEMPLATES): each score broadcast per template.
    corridor_feasible = np.stack(
        [
            np.repeat(reveal_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
            np.repeat(transfer_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
            np.repeat(passive_corridor[:, None], NUM_APPROACH_TEMPLATES, axis=1),
        ],
        axis=1,
    ).astype(np.float32)
    persistence_horizon = np.stack(
        [
            np.clip(ROLL_OUT_HORIZON * (0.55 * reveal_corridor + 0.45 * support), 0.0, float(ROLL_OUT_HORIZON)),
            np.clip(ROLL_OUT_HORIZON * (0.50 * transfer_corridor + 0.50 * persistence), 0.0, float(ROLL_OUT_HORIZON)),
            np.clip(ROLL_OUT_HORIZON * (0.55 * passive_corridor + 0.45 * persistence), 0.0, float(ROLL_OUT_HORIZON)),
        ],
        axis=1,
    ).astype(np.float32)
    # Constant support-mode label over the horizon; base_action keeps the
    # current label unchanged.
    support_mode = np.full(
        (ROLL_OUT_HORIZON,),
        _mode_support_mode(mode_name, int(state_targets["support_mode"])),
        dtype=np.int64,
    )
    if mode_name == "base_action":
        support_mode[:] = int(state_targets["support_mode"])

    return {
        "candidate_rollout_support_mode": support_mode,
        "candidate_rollout_corridor_feasible": corridor_feasible,
        "candidate_rollout_persistence_horizon": persistence_horizon,
        "candidate_rollout_disturbance_cost": disturbance.astype(np.float32),
        # Belief and visibility maps share the same scalar-per-step proxy.
        "candidate_rollout_belief_map": visibility[:, None, None].astype(np.float32),
        "candidate_rollout_visibility_map": visibility[:, None, None].astype(np.float32),
        "candidate_rollout_clearance_map": np.repeat(clearance[:, None, None, None], 2, axis=1).astype(np.float32),
        "candidate_rollout_support_stability": support[:, None, None, None].astype(np.float32),
        "candidate_rollout_reocclusion_target": reocclusion[:, None, None].astype(np.float32),
        "candidate_rollout_occluder_contact_map": np.clip(access * support, 0.0, 1.0)[:, None, None].astype(np.float32),
        "candidate_rollout_grasp_affordance_map": grasp[:, None, None].astype(np.float32),
    }
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def _evaluate_candidate(
    sim_env: PickClutterRevealEnv,
    obs_env: PickClutterRevealEnv,
    snapshot: dict[str, Any],
    mode_name: str,
) -> dict[str, float]:
    """Roll out one candidate mode from ``snapshot`` and score it.

    ``sim_env`` is mutated to execute the mode; ``obs_env`` is kept in sync only
    to render sensor bundles and is restored to ``snapshot`` before returning.
    The caller is expected to restore ``sim_env`` via the next candidate's
    ``_restore_env`` call — presumably; verify against the call site.

    Returns:
        Flat dict of ``candidate_*`` scalars: final-outcome metrics plus the
        ``candidate_immediate_*`` metrics captured right after the mode itself.
    """
    _restore_env(sim_env, snapshot)
    start_positions = _all_positions(sim_env)
    _execute_mode(sim_env, mode_name)
    _sync_env_state(sim_env, obs_env)
    after_bundle = _extract_sensor_bundle(obs_env.get_obs(obs_env.get_info()))
    immediate = _candidate_metrics(sim_env, start_positions=start_positions, current_obs_bundle=after_bundle)
    # If the mode alone did not finish the retrieval, chain a follow-up
    # "retrieve" to measure how well the mode set up the final grab.
    if not immediate["retrieval_success"] and mode_name != "retrieve":
        _execute_mode(sim_env, "retrieve")
        _sync_env_state(sim_env, obs_env)
        follow_bundle = _extract_sensor_bundle(obs_env.get_obs(obs_env.get_info()))
        final_metrics = _candidate_metrics(sim_env, start_positions=start_positions, current_obs_bundle=follow_bundle)
    else:
        final_metrics = immediate
    _restore_env(obs_env, snapshot)
    # Scalar ranking utility: success dominates (2.5x), disturbance penalizes.
    utility = (
        2.5 * final_metrics["retrieval_success"]
        + 1.0 * final_metrics["progress"]
        + 0.5 * final_metrics["clearance"]
        + 0.25 * final_metrics["visibility"]
        - 0.5 * final_metrics["disturbance"]
    )
    return {
        "candidate_retrieval_success": final_metrics["retrieval_success"],
        "candidate_risk": float(np.clip(final_metrics["disturbance"], 0.0, 1.0)),
        "candidate_utility": float(utility),
        "candidate_final_disturbance_cost": final_metrics["disturbance"],
        "candidate_reocclusion_rate": float(np.clip(1.0 - final_metrics["clearance"], 0.0, 1.0)),
        "candidate_visibility_integral": final_metrics["visibility"],
        "candidate_actor_feasibility_auc": final_metrics["clearance"],
        # Reveal counts as achieved past either a progress or clearance threshold.
        "candidate_reveal_achieved": float(final_metrics["progress"] > 0.15 or final_metrics["clearance"] > 0.35),
        "candidate_hold_persistence": float(1.0 - final_metrics["disturbance"]),
        "candidate_support_stability_auc": float(1.0 - 0.5 * final_metrics["disturbance"]),
        "candidate_disturbance_auc": final_metrics["disturbance"],
        "candidate_immediate_retrieval_success": immediate["retrieval_success"],
        "candidate_immediate_visibility": immediate["visibility"],
        "candidate_immediate_access": immediate["clearance"],
        "candidate_immediate_progress": immediate["progress"],
        "candidate_immediate_reocclusion": float(np.clip(1.0 - immediate["clearance"], 0.0, 1.0)),
        "candidate_immediate_hold_persistence": float(1.0 - immediate["disturbance"]),
        "candidate_immediate_support_stability": float(1.0 - 0.5 * immediate["disturbance"]),
        "candidate_immediate_disturbance": immediate["disturbance"],
    }
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
def _build_episode_splits(spec: SmokeSpec) -> dict[str, list[int]]:
|
| 939 |
+
return {
|
| 940 |
+
"train": [spec.dataset_seed * 10_000 + index for index in range(spec.train_episodes)],
|
| 941 |
+
"val": [spec.dataset_seed * 10_000 + 1_000 + index for index in range(spec.val_episodes)],
|
| 942 |
+
"eval": [spec.dataset_seed * 10_000 + 2_000 + index for index in range(spec.eval_episodes)],
|
| 943 |
+
}
|
| 944 |
+
|
| 945 |
+
|
| 946 |
+
def _save_episode_splits(output_path: Path, payload: dict[str, list[int]]) -> None:
|
| 947 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 948 |
+
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
def _init_history_entry(obs_bundle: dict[str, np.ndarray], proprio: np.ndarray, action_chunk: np.ndarray) -> dict[str, Any]:
|
| 952 |
+
return {
|
| 953 |
+
"images": obs_bundle["images"].copy(),
|
| 954 |
+
"depths": obs_bundle["depths"].copy(),
|
| 955 |
+
"depth_valid": obs_bundle["depth_valid"].copy(),
|
| 956 |
+
"camera_intrinsics": obs_bundle["camera_intrinsics"].copy(),
|
| 957 |
+
"camera_extrinsics": obs_bundle["camera_extrinsics"].copy(),
|
| 958 |
+
"proprio": proprio.copy(),
|
| 959 |
+
"action": action_chunk.mean(axis=0).astype(np.float32, copy=False),
|
| 960 |
+
}
|
| 961 |
+
|
| 962 |
+
|
| 963 |
+
def _history_stack(
|
| 964 |
+
history: Sequence[dict[str, Any]],
|
| 965 |
+
key: str,
|
| 966 |
+
*,
|
| 967 |
+
pad_shape: tuple[int, ...],
|
| 968 |
+
dtype: np.dtype,
|
| 969 |
+
history_steps: int,
|
| 970 |
+
) -> np.ndarray:
|
| 971 |
+
history = list(history)[-history_steps:]
|
| 972 |
+
pad_count = history_steps - len(history)
|
| 973 |
+
chunks = [np.zeros(pad_shape, dtype=dtype) for _ in range(pad_count)]
|
| 974 |
+
chunks.extend(np.asarray(item[key], dtype=dtype) for item in history)
|
| 975 |
+
return np.stack(chunks, axis=0).astype(dtype, copy=False)
|
| 976 |
+
|
| 977 |
+
|
| 978 |
+
class ManiSkillPickClutterDataset(Dataset[dict[str, Any]]):
    """In-memory dataset of pre-computed decision samples.

    Each sample is a dict of numpy arrays plus metadata; ``__getitem__``
    converts it to torch tensors (images normalized to [0, 1], channels-first)
    and mirrors candidate targets under ``proposal_target_*`` aliases.
    """

    def __init__(self, samples: Sequence[dict[str, Any]]) -> None:
        # Copy into a list so the dataset owns a stable, indexable snapshot.
        self.samples = list(samples)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Convert sample ``index`` to a tensor dict; optional keys pass through only if present."""
        sample = self.samples[index]
        item = {
            # uint8 HWC images -> float CHW in [0, 1].
            "images": torch.from_numpy(sample["images"]).permute(0, 3, 1, 2).float() / 255.0,
            "depths": torch.from_numpy(sample["depths"]).float(),
            "depth_valid": torch.from_numpy(sample["depth_valid"]).float(),
            "camera_intrinsics": torch.from_numpy(sample["camera_intrinsics"]).float(),
            "camera_extrinsics": torch.from_numpy(sample["camera_extrinsics"]).float(),
            # History images carry an extra leading time axis (permute 0,1,4,2,3).
            "history_images": torch.from_numpy(sample["history_images"]).permute(0, 1, 4, 2, 3).float() / 255.0,
            "history_depths": torch.from_numpy(sample["history_depths"]).float(),
            "history_depth_valid": torch.from_numpy(sample["history_depth_valid"]).float(),
            "history_camera_intrinsics": torch.from_numpy(sample["history_camera_intrinsics"]).float(),
            "history_camera_extrinsics": torch.from_numpy(sample["history_camera_extrinsics"]).float(),
            "history_proprio": torch.from_numpy(sample["history_proprio"]).float(),
            "history_actions": torch.from_numpy(sample["history_actions"]).float(),
            "proprio": torch.from_numpy(sample["proprio"]).float(),
            "texts": sample["language_goal"],
            "task_name": sample["task_name"],
            "task_id": torch.as_tensor(sample["task_id"], dtype=torch.long),
            "action_chunk": torch.from_numpy(sample["action_chunk"]).float(),
            "candidate_action_chunks": torch.from_numpy(sample["candidate_action_chunks"]).float(),
            "candidate_retrieval_success": torch.from_numpy(sample["candidate_retrieval_success"]).float(),
            "candidate_final_disturbance_cost": torch.from_numpy(sample["candidate_final_disturbance_cost"]).float(),
            "candidate_reocclusion_rate": torch.from_numpy(sample["candidate_reocclusion_rate"]).float(),
            "candidate_visibility_integral": torch.from_numpy(sample["candidate_visibility_integral"]).float(),
            "candidate_actor_feasibility_auc": torch.from_numpy(sample["candidate_actor_feasibility_auc"]).float(),
            "candidate_reveal_achieved": torch.from_numpy(sample["candidate_reveal_achieved"]).float(),
            "candidate_hold_persistence": torch.from_numpy(sample["candidate_hold_persistence"]).float(),
            "candidate_support_stability_auc": torch.from_numpy(sample["candidate_support_stability_auc"]).float(),
            "candidate_disturbance_auc": torch.from_numpy(sample["candidate_disturbance_auc"]).float(),
            "candidate_risk": torch.from_numpy(sample["candidate_risk"]).float(),
            "candidate_utility": torch.from_numpy(sample["candidate_utility"]).float(),
            # Proposal-head targets reuse the candidate arrays verbatim.
            "proposal_target_action_chunks": torch.from_numpy(sample["candidate_action_chunks"]).float(),
            "proposal_target_retrieval_success": torch.from_numpy(sample["candidate_retrieval_success"]).float(),
            "proposal_target_risk": torch.from_numpy(sample["candidate_risk"]).float(),
            "proposal_target_utility": torch.from_numpy(sample["candidate_utility"]).float(),
            "episode_seed": sample["episode_seed"],
            "decision_step": sample["decision_step"],
            "selected_mode": sample["selected_mode"],
        }
        # Optional supervision targets: included only when the sample has them.
        if "support_mode" in sample:
            item["support_mode"] = torch.as_tensor(sample["support_mode"], dtype=torch.long)
        if "corridor_feasible" in sample:
            item["corridor_feasible"] = torch.from_numpy(sample["corridor_feasible"]).float()
        if "persistence_horizon" in sample:
            item["persistence_horizon"] = torch.from_numpy(sample["persistence_horizon"]).float()
        if "disturbance_cost" in sample:
            item["disturbance_cost"] = torch.as_tensor(sample["disturbance_cost"], dtype=torch.float32)
        if "state_confidence_target" in sample:
            item["state_confidence_target"] = torch.as_tensor(sample["state_confidence_target"], dtype=torch.float32)
        if "task_metric_mask" in sample:
            item["task_metric_mask"] = torch.from_numpy(sample["task_metric_mask"]).to(dtype=torch.bool)
        for metric_name in STATE_SUPERVISION_METRICS:
            if metric_name in sample:
                item[metric_name] = torch.as_tensor(sample[metric_name], dtype=torch.float32)
        # Rollout support-mode labels are integer class indices.
        for key in (
            "candidate_rollout_support_mode",
            "proposal_target_rollout_support_mode",
        ):
            if key in sample:
                item[key] = torch.from_numpy(sample[key]).long()
        # Remaining rollout targets are all float tensors.
        for key in (
            "candidate_rollout_corridor_feasible",
            "candidate_rollout_persistence_horizon",
            "candidate_rollout_disturbance_cost",
            "candidate_rollout_belief_map",
            "candidate_rollout_visibility_map",
            "candidate_rollout_clearance_map",
            "candidate_rollout_support_stability",
            "candidate_rollout_reocclusion_target",
            "candidate_rollout_occluder_contact_map",
            "candidate_rollout_grasp_affordance_map",
            "proposal_target_rollout_corridor_feasible",
            "proposal_target_rollout_persistence_horizon",
            "proposal_target_rollout_disturbance_cost",
            "proposal_target_rollout_belief_map",
            "proposal_target_rollout_visibility_map",
            "proposal_target_rollout_clearance_map",
            "proposal_target_rollout_support_stability",
            "proposal_target_rollout_reocclusion_target",
            "proposal_target_rollout_occluder_contact_map",
            "proposal_target_rollout_grasp_affordance_map",
        ):
            if key in sample:
                item[key] = torch.from_numpy(sample[key]).float()
        return item
|
| 1071 |
+
|
| 1072 |
+
|
| 1073 |
+
def _make_loader(samples: Sequence[dict[str, Any]], *, batch_size: int, num_workers: int, shuffle: bool) -> DataLoader:
    """Wrap decision samples in a DataLoader; pins host memory only when CUDA is present."""
    dataset = ManiSkillPickClutterDataset(samples)
    pin = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin,
    )
|
| 1081 |
+
|
| 1082 |
+
|
| 1083 |
+
def _load_init_bundle() -> tuple[PolicyConfig, dict[str, Any], dict[str, Any]]:
    """Load the init checkpoint and rebuild its PolicyConfig.

    Returns:
        (policy_config with smoke planner overrides applied,
         raw trainer_config dict, raw loss_weights dict) from the checkpoint.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — safe only
    # because DEFAULT_INIT_CHECKPOINT is a trusted local artifact.
    checkpoint = torch.load(DEFAULT_INIT_CHECKPOINT, map_location="cpu", weights_only=False)
    # Each sub-config is re-instantiated from its serialized kwargs dict.
    policy_config = PolicyConfig(
        backbone=FrozenVLBackboneConfig(**checkpoint["policy_config"]["backbone"]),
        fusion=MultiViewFusionConfig(**checkpoint["policy_config"]["fusion"]),
        memory=ObservationMemoryConfig(**checkpoint["policy_config"]["memory"]),
        decoder=ChunkDecoderConfig(**checkpoint["policy_config"]["decoder"]),
        reveal_head=RevealHeadConfig(**checkpoint["policy_config"]["reveal_head"]),
        world_model=RevealWMConfig(**checkpoint["policy_config"]["world_model"]),
        planner=PlannerConfig(**checkpoint["policy_config"]["planner"]),
    )
    return _apply_smoke_planner_overrides(policy_config), checkpoint["trainer_config"], checkpoint["loss_weights"]
|
| 1095 |
+
|
| 1096 |
+
|
| 1097 |
+
def _trainer_config_for_variant(variant: str) -> TrainerConfig:
|
| 1098 |
+
if variant == "trunk_only_ft":
|
| 1099 |
+
return TrainerConfig(
|
| 1100 |
+
policy_type="foundation_trunk",
|
| 1101 |
+
use_bf16=True,
|
| 1102 |
+
grad_clip_norm=1.0,
|
| 1103 |
+
freeze_backbone=True,
|
| 1104 |
+
gradient_checkpointing=False,
|
| 1105 |
+
trainable_parameter_prefixes=("fusion", "memory", "decoder"),
|
| 1106 |
+
)
|
| 1107 |
+
if variant == "adapter_active_ft":
|
| 1108 |
+
return TrainerConfig(
|
| 1109 |
+
policy_type="adapter_wrapped",
|
| 1110 |
+
use_bf16=True,
|
| 1111 |
+
grad_clip_norm=1.0,
|
| 1112 |
+
freeze_backbone=True,
|
| 1113 |
+
gradient_checkpointing=False,
|
| 1114 |
+
eval_mode="adapter_active",
|
| 1115 |
+
trainable_parameter_prefixes=(
|
| 1116 |
+
"trunk.fusion",
|
| 1117 |
+
"trunk.memory",
|
| 1118 |
+
"trunk.decoder",
|
| 1119 |
+
"adapter.state_head",
|
| 1120 |
+
"adapter.transition_model",
|
| 1121 |
+
"adapter.proposal_prior",
|
| 1122 |
+
"adapter.planner",
|
| 1123 |
+
),
|
| 1124 |
+
adapter_mode="adapter_active",
|
| 1125 |
+
adapter_use_transition_model=True,
|
| 1126 |
+
adapter_use_task_conditioning=True,
|
| 1127 |
+
adapter_action_supervision_source="trunk",
|
| 1128 |
+
)
|
| 1129 |
+
raise KeyError(f"Unsupported variant: {variant}")
|
| 1130 |
+
|
| 1131 |
+
|
| 1132 |
+
def _loss_weights_for_smoke() -> LossWeights:
    """Fixed loss-weight mix used by the smoke-test training runs.

    Action imitation dominates (1.0); transition/gate (0.25 each) anchor the
    adapter heads; the remaining auxiliary terms sit in the 0.02-0.25 band.
    Proposal-mode supervision is restricted to this file's task only.
    """
    return LossWeights(
        action=1.0,
        support_mode=0.15,
        corridor=0.15,
        persistence=0.08,
        disturbance=0.08,
        planner_success=0.20,
        planner_risk=0.08,
        planner_ranking=0.20,
        proposal_reconstruction=0.10,
        proposal_success=0.12,
        proposal_ranking=0.15,
        proposal_mode=0.10,
        proposal_diversity=0.02,
        task_metrics=0.15,
        transition=0.25,
        gate=0.25,
        calibration=0.10,
        # Only samples from this task contribute to the proposal-mode loss.
        proposal_mode_task_filter=[TASK_NAME],
    )
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
def _save_training_checkpoint(
    *,
    output_dir: Path,
    experiment_name: str,
    model: torch.nn.Module,
    policy_config: PolicyConfig,
    trainer_config: TrainerConfig,
    loss_weights: LossWeights,
    history: list[dict[str, Any]],
    best_val: float,
    train_spec: dict[str, Any],
) -> Path:
    """Persist the best model plus full config provenance to ``checkpoint_best.pt``.

    Returns:
        Path of the written checkpoint file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / "checkpoint_best.pt"
    # Configs are serialized as plain dicts so the checkpoint loads without
    # the dataclass definitions being import-compatible.
    payload = {
        "experiment_name": experiment_name,
        "policy_config": asdict(policy_config),
        "trainer_config": asdict(trainer_config),
        "loss_weights": asdict(loss_weights),
        "state_dict": model.state_dict(),
        "history": history,
        "best_val_total": best_val,
        "train_spec": train_spec,
    }
    torch.save(payload, checkpoint_path)
    return checkpoint_path
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
def _aggregate_epoch(loss_records: Sequence[dict[str, float]]) -> dict[str, float]:
|
| 1186 |
+
keys = sorted({key for record in loss_records for key in record})
|
| 1187 |
+
return {
|
| 1188 |
+
key: float(np.mean([record.get(key, 0.0) for record in loss_records])) if loss_records else 0.0
|
| 1189 |
+
for key in keys
|
| 1190 |
+
}
|
| 1191 |
+
|
| 1192 |
+
|
| 1193 |
+
def _train_variant(
    *,
    variant: str,
    train_samples: Sequence[dict[str, Any]],
    val_samples: Sequence[dict[str, Any]],
    spec: SmokeSpec,
    output_dir: Path,
) -> tuple[Path, dict[str, Any]]:
    """Fine-tune one model variant on the smoke dataset, checkpointing the best epoch.

    Trains for ``spec.epochs`` epochs, validates after each, and saves a
    checkpoint plus ``summary.json`` whenever the validation total improves
    (ties included). Returns (path to checkpoint_best.pt, the training-spec
    dict recorded in it).
    """
    # Policy architecture comes from the init bundle; the bundle's trainer
    # config and loss weights are discarded in favor of smoke-specific ones.
    policy_config, _init_trainer_cfg, _init_loss_weights = _load_init_bundle()
    trainer_config = _trainer_config_for_variant(variant)
    loss_weights = _loss_weights_for_smoke()
    model = build_policy(policy_config, trainer_config)
    # NOTE(review): the trailing False is presumably a strict-load flag —
    # confirm against _load_init_checkpoint's signature.
    init_info = _load_init_checkpoint(model, str(DEFAULT_INIT_CHECKPOINT), False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Seed CPU (and, when present, all CUDA) RNGs for reproducible smoke runs.
    torch.manual_seed(spec.train_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(spec.train_seed)
        torch.backends.cuda.matmul.allow_tf32 = True
    # Only parameters whose names match the trainer config's trainable
    # prefixes keep requires_grad; `matched` records which ones.
    matched = apply_trainable_parameter_prefixes(model, trainer_config)
    optimizer = torch.optim.AdamW(
        [parameter for parameter in model.parameters() if parameter.requires_grad],
        lr=spec.learning_rate,
        weight_decay=spec.weight_decay,
    )
    trainer = BimanualTrainer(model=model, optimizer=optimizer, config=trainer_config)
    train_loader = _make_loader(train_samples, batch_size=spec.batch_size, num_workers=spec.num_workers, shuffle=True)
    val_loader = _make_loader(val_samples, batch_size=spec.batch_size, num_workers=spec.num_workers, shuffle=False)

    best_val = math.inf
    history: list[dict[str, Any]] = []
    # Provenance record stored in every checkpoint / report for this run.
    train_spec = build_target_training_spec(
        track_id="occlusion_track",
        model_variant=variant,
        seed=spec.train_seed,
        train_demos=spec.train_episodes,
        val_demos=spec.val_episodes,
        init_checkpoint_group=str(DEFAULT_INIT_CHECKPOINT),
        optimizer="adamw",
        learning_rate=spec.learning_rate,
        lr_schedule="constant",
        batch_size=spec.batch_size,
        augmentations="none",
        early_stopping_metric="val_total",
        max_gradient_steps=len(train_loader) * spec.epochs,
        unfreeze_scope="fusion_memory_decoder",
        # The default dataset seed keeps the legacy split id; other seeds get
        # an explicit "dataset_seed" marker to avoid colliding artifact names.
        dataset_split_id=(
            f"pickclutter_{SMOKE_VERSION}_seed{spec.dataset_seed}"
            if int(spec.dataset_seed) == DEFAULT_SEED
            else f"pickclutter_{SMOKE_VERSION}_dataset_seed{spec.dataset_seed}"
        ),
    )

    for epoch in range(spec.epochs):
        model.train()
        train_losses: list[dict[str, float]] = []
        for batch in train_loader:
            moved = _move_batch_to_device(batch, device)
            loss_dict = trainer.training_step(moved, loss_weights=loss_weights)
            # Detach to plain floats so records don't pin GPU tensors.
            train_losses.append({key: float(value.detach().cpu()) for key, value in loss_dict.items()})

        model.eval()
        val_losses: list[dict[str, float]] = []
        with torch.no_grad():
            for batch in val_loader:
                moved = _move_batch_to_device(batch, device)
                # Validation bypasses the trainer and calls the model directly,
                # forwarding optional modalities with .get (None when absent).
                forward_kwargs = {
                    "images": moved["images"],
                    "proprio": moved["proprio"],
                    "texts": moved["texts"],
                    "task_names": moved.get("task_name"),
                    "task_ids": moved.get("task_id"),
                    "history_images": moved.get("history_images"),
                    "history_proprio": moved.get("history_proprio"),
                    "history_actions": moved.get("history_actions"),
                    "depths": moved.get("depths"),
                    "depth_valid": moved.get("depth_valid"),
                    "camera_intrinsics": moved.get("camera_intrinsics"),
                    "camera_extrinsics": moved.get("camera_extrinsics"),
                    "history_depths": moved.get("history_depths"),
                    "history_depth_valid": moved.get("history_depth_valid"),
                    "history_camera_intrinsics": moved.get("history_camera_intrinsics"),
                    "history_camera_extrinsics": moved.get("history_camera_extrinsics"),
                }
                # The adapter variant is validated with its adapter path active.
                if variant == "adapter_active_ft":
                    forward_kwargs["adapter_mode"] = "adapter_active"
                    forward_kwargs["use_transition_model"] = True
                    forward_kwargs["use_task_conditioning"] = True
                outputs = model(**forward_kwargs)
                # NOTE(review): per-batch local import; presumably avoids a
                # circular import — could be hoisted to module scope otherwise.
                from train.losses import compute_total_loss

                losses = compute_total_loss(outputs, moved, weights=loss_weights)
                val_losses.append({key: float(value.detach().cpu()) for key, value in losses.items()})

        train_summary = _aggregate_epoch(train_losses)
        val_summary = _aggregate_epoch(val_losses)
        history.append({"epoch": epoch, "train": train_summary, "val": val_summary})
        # One structured progress line per epoch for log scraping.
        print(
            json.dumps(
                {
                    "phase": "epoch_complete",
                    "variant": variant,
                    "epoch": epoch,
                    "train_total": train_summary.get("total", 0.0),
                    "val_total": val_summary.get("total", 0.0),
                }
            ),
            flush=True,
        )
        # "<=" keeps the latest epoch on ties; a missing "total" key never
        # counts as an improvement (defaults to +inf).
        if val_summary.get("total", math.inf) <= best_val:
            best_val = val_summary["total"]
            checkpoint_path = _save_training_checkpoint(
                output_dir=output_dir,
                experiment_name=f"{variant}_seed{spec.train_seed}",
                model=model,
                policy_config=policy_config,
                trainer_config=trainer_config,
                loss_weights=loss_weights,
                history=history,
                best_val=best_val,
                train_spec=train_spec,
            )
            # summary.json is only refreshed when the checkpoint is, so it
            # always describes the on-disk best checkpoint.
            (output_dir / "summary.json").write_text(
                json.dumps(
                    {
                        "variant": variant,
                        "checkpoint_path": str(checkpoint_path),
                        "init_info": init_info,
                        "trainable_parameter_names": matched,
                        "best_val_total": best_val,
                        "history": history,
                        "train_spec": train_spec,
                    },
                    indent=2,
                )
                + "\n",
                encoding="utf-8",
            )
    return output_dir / "checkpoint_best.pt", train_spec
|
| 1332 |
+
|
| 1333 |
+
|
| 1334 |
+
def _load_checkpoint(
    checkpoint_path: Path,
    *,
    adapter_mode: str | None = None,
    planner_overrides: dict[str, float] | None = None,
) -> tuple[torch.nn.Module, dict[str, Any]]:
    """Rebuild a policy from a training checkpoint and load compatible weights.

    Returns (model moved to the active device and set to eval mode, the raw
    checkpoint dict — possibly annotated with ``_shape_skipped_keys``).
    Raises RuntimeError if the checkpoint carries keys the rebuilt model
    does not expect.
    """
    # weights_only=False: the checkpoint stores config dicts and history, not
    # just tensors. This executes pickled data — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    # Reconstruct the typed config tree from the plain dicts saved by
    # _save_training_checkpoint.
    policy_config = PolicyConfig(
        backbone=FrozenVLBackboneConfig(**checkpoint["policy_config"]["backbone"]),
        fusion=MultiViewFusionConfig(**checkpoint["policy_config"]["fusion"]),
        memory=ObservationMemoryConfig(**checkpoint["policy_config"]["memory"]),
        decoder=ChunkDecoderConfig(**checkpoint["policy_config"]["decoder"]),
        reveal_head=RevealHeadConfig(**checkpoint["policy_config"]["reveal_head"]),
        world_model=RevealWMConfig(**checkpoint["policy_config"]["world_model"]),
        planner=PlannerConfig(**checkpoint["policy_config"]["planner"]),
    )
    policy_config = _apply_smoke_planner_overrides(policy_config, planner_overrides=planner_overrides)
    trainer_config = TrainerConfig(**checkpoint["trainer_config"])
    # The adapter-mode override only makes sense for adapter-wrapped policies;
    # for plain policies it is silently ignored.
    if adapter_mode is not None and trainer_config.policy_type == "adapter_wrapped":
        trainer_config.adapter_mode = adapter_mode
        trainer_config.eval_mode = adapter_mode
    model = build_policy(policy_config, trainer_config)
    # Filter out checkpoint entries incompatible with the rebuilt model
    # (presumably shape mismatches — see filter_compatible_state_dict), then
    # load non-strictly: missing keys are tolerated, unexpected ones are not.
    filtered_state_dict, skipped, _remapped = filter_compatible_state_dict(model.state_dict(), checkpoint["state_dict"])
    incompatible = model.load_state_dict(filtered_state_dict, strict=False)
    if incompatible.unexpected_keys:
        raise RuntimeError(f"Unexpected checkpoint keys for {checkpoint_path}: {list(incompatible.unexpected_keys)}")
    if skipped:
        # Surface the skipped keys to callers via the returned checkpoint dict.
        checkpoint["_shape_skipped_keys"] = skipped
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    return model, checkpoint
|
| 1366 |
+
|
| 1367 |
+
|
| 1368 |
+
def _collect_split(
    *,
    split_name: str,
    seeds: Sequence[int],
    spec: SmokeSpec,
    output_path: Path,
) -> dict[str, Any]:
    """Roll out the scripted candidate-ranking policy and record supervision samples.

    Two paired environments are used: ``obs_env`` produces rendered sensor
    observations, while ``sim_env`` (obs_mode="none") is a state-synced copy
    used to cheaply simulate each candidate macro-mode. One sample is recorded
    per decision step; the highest-utility candidate becomes the action label.
    The payload is saved to *output_path* via torch.save and also returned.
    """
    obs_env = PickClutterRevealEnv(
        obs_mode="rgb+depth+segmentation",
        control_mode="pd_ee_delta_pos",
        render_mode="rgb_array",
    )
    sim_env = PickClutterRevealEnv(
        obs_mode="none",
        control_mode="pd_ee_delta_pos",
        render_mode="rgb_array",
    )
    samples: list[dict[str, Any]] = []
    episode_records: list[dict[str, Any]] = []
    try:
        for episode_seed in seeds:
            obs, _ = obs_env.reset(seed=int(episode_seed))
            sim_env.reset(seed=int(episode_seed))
            _sync_env_state(obs_env, sim_env)
            episode_start_positions = _all_positions(obs_env)
            # Sliding observation history fed to the model as context.
            history: deque[dict[str, Any]] = deque(maxlen=spec.history_steps)
            episode_success = False
            for decision_step in range(spec.max_macro_steps):
                obs_bundle = _extract_sensor_bundle(obs)
                proprio = _build_proprio(obs_env)
                snapshot = _snapshot_env(obs_env)
                # Score every canonical macro-mode by simulated rollout from
                # the current snapshot; argmax utility is the teaching signal.
                candidate_metrics = [_evaluate_candidate(sim_env, obs_env, snapshot, mode_name) for mode_name in MODE_ORDER]
                candidate_chunks = np.stack([CANONICAL_CHUNKS[mode_name] for mode_name in MODE_ORDER], axis=0).astype(np.float32)
                utilities = np.asarray([payload["candidate_utility"] for payload in candidate_metrics], dtype=np.float32)
                best_index = int(utilities.argmax())
                selected_mode = MODE_ORDER[best_index]
                state_targets = _current_state_targets(
                    obs_env,
                    obs_bundle=obs_bundle,
                    candidate_metrics=candidate_metrics,
                    episode_start_positions=episode_start_positions,
                    selected_mode=selected_mode,
                )
                # Per-candidate rollout supervision targets, in MODE_ORDER.
                rollout_targets_by_mode = [
                    _candidate_rollout_targets(
                        mode_name=mode_name,
                        state_targets=state_targets,
                        candidate_payload=payload,
                    )
                    for mode_name, payload in zip(MODE_ORDER, candidate_metrics)
                ]
                # One flat training sample: current sensors, padded history,
                # the selected action chunk, per-candidate metric vectors,
                # stacked per-candidate rollout targets (duplicated under the
                # proposal_target_* prefix), and the current-state targets.
                sample = {
                    "images": obs_bundle["images"].copy(),
                    "depths": obs_bundle["depths"].copy(),
                    "depth_valid": obs_bundle["depth_valid"].copy(),
                    "camera_intrinsics": obs_bundle["camera_intrinsics"].copy(),
                    "camera_extrinsics": obs_bundle["camera_extrinsics"].copy(),
                    "history_images": _history_stack(history, "images", pad_shape=obs_bundle["images"].shape, dtype=np.uint8, history_steps=spec.history_steps),
                    "history_depths": _history_stack(history, "depths", pad_shape=obs_bundle["depths"].shape, dtype=np.float32, history_steps=spec.history_steps),
                    "history_depth_valid": _history_stack(history, "depth_valid", pad_shape=obs_bundle["depth_valid"].shape, dtype=np.float32, history_steps=spec.history_steps),
                    "history_camera_intrinsics": _history_stack(history, "camera_intrinsics", pad_shape=obs_bundle["camera_intrinsics"].shape, dtype=np.float32, history_steps=spec.history_steps),
                    "history_camera_extrinsics": _history_stack(history, "camera_extrinsics", pad_shape=obs_bundle["camera_extrinsics"].shape, dtype=np.float32, history_steps=spec.history_steps),
                    "history_proprio": _history_stack(history, "proprio", pad_shape=(PROPRIO_DIM,), dtype=np.float32, history_steps=spec.history_steps),
                    # NOTE(review): (14,) is presumably the bimanual action dim
                    # (2 arms x 7) — confirm against the action space.
                    "history_actions": _history_stack(history, "action", pad_shape=(14,), dtype=np.float32, history_steps=spec.history_steps),
                    "proprio": proprio.astype(np.float32),
                    "language_goal": TEXT_PROMPT,
                    "task_name": TASK_NAME,
                    "task_id": TASK_ID,
                    "action_chunk": CANONICAL_CHUNKS[selected_mode].copy(),
                    "candidate_action_chunks": candidate_chunks,
                    "candidate_retrieval_success": np.asarray([payload["candidate_retrieval_success"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_final_disturbance_cost": np.asarray([payload["candidate_final_disturbance_cost"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_reocclusion_rate": np.asarray([payload["candidate_reocclusion_rate"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_visibility_integral": np.asarray([payload["candidate_visibility_integral"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_actor_feasibility_auc": np.asarray([payload["candidate_actor_feasibility_auc"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_reveal_achieved": np.asarray([payload["candidate_reveal_achieved"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_hold_persistence": np.asarray([payload["candidate_hold_persistence"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_support_stability_auc": np.asarray([payload["candidate_support_stability_auc"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_disturbance_auc": np.asarray([payload["candidate_disturbance_auc"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_risk": np.asarray([payload["candidate_risk"] for payload in candidate_metrics], dtype=np.float32),
                    "candidate_utility": utilities,
                    "candidate_rollout_support_mode": np.stack([payload["candidate_rollout_support_mode"] for payload in rollout_targets_by_mode], axis=0).astype(np.int64),
                    "candidate_rollout_corridor_feasible": np.stack([payload["candidate_rollout_corridor_feasible"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "candidate_rollout_persistence_horizon": np.stack([payload["candidate_rollout_persistence_horizon"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "candidate_rollout_disturbance_cost": np.stack([payload["candidate_rollout_disturbance_cost"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "candidate_rollout_belief_map": np.stack([payload["candidate_rollout_belief_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "candidate_rollout_visibility_map": np.stack([payload["candidate_rollout_visibility_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "candidate_rollout_clearance_map": np.stack([payload["candidate_rollout_clearance_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "candidate_rollout_support_stability": np.stack([payload["candidate_rollout_support_stability"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "candidate_rollout_reocclusion_target": np.stack([payload["candidate_rollout_reocclusion_target"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "candidate_rollout_occluder_contact_map": np.stack([payload["candidate_rollout_occluder_contact_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "candidate_rollout_grasp_affordance_map": np.stack([payload["candidate_rollout_grasp_affordance_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    # Identical stacks re-emitted under the proposal_target_*
                    # prefix (separate consumer in the loss pipeline).
                    "proposal_target_rollout_support_mode": np.stack([payload["candidate_rollout_support_mode"] for payload in rollout_targets_by_mode], axis=0).astype(np.int64),
                    "proposal_target_rollout_corridor_feasible": np.stack([payload["candidate_rollout_corridor_feasible"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "proposal_target_rollout_persistence_horizon": np.stack([payload["candidate_rollout_persistence_horizon"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "proposal_target_rollout_disturbance_cost": np.stack([payload["candidate_rollout_disturbance_cost"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "proposal_target_rollout_belief_map": np.stack([payload["candidate_rollout_belief_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "proposal_target_rollout_visibility_map": np.stack([payload["candidate_rollout_visibility_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "proposal_target_rollout_clearance_map": np.stack([payload["candidate_rollout_clearance_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "proposal_target_rollout_support_stability": np.stack([payload["candidate_rollout_support_stability"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "proposal_target_rollout_reocclusion_target": np.stack([payload["candidate_rollout_reocclusion_target"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "proposal_target_rollout_occluder_contact_map": np.stack([payload["candidate_rollout_occluder_contact_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "proposal_target_rollout_grasp_affordance_map": np.stack([payload["candidate_rollout_grasp_affordance_map"] for payload in rollout_targets_by_mode], axis=0).astype(np.float32),
                    "episode_seed": int(episode_seed),
                    "decision_step": int(decision_step),
                    "selected_mode": selected_mode,
                    **state_targets,
                }
                samples.append(sample)
                # Execute the chosen mode in the real env, refresh the
                # observation, and extend the history window.
                _execute_mode(obs_env, selected_mode)
                obs = obs_env.get_obs(obs_env.get_info())
                history.append(_init_history_entry(obs_bundle, proprio, CANONICAL_CHUNKS[selected_mode]))
                if _success_from_state(obs_env):
                    episode_success = True
                    break
            episode_records.append(
                {
                    "episode_seed": int(episode_seed),
                    "success": episode_success,
                    # NOTE(review): len(history) is capped at spec.history_steps
                    # (deque maxlen), so "steps" under-reports long episodes.
                    "steps": len(history),
                }
            )
            # Structured per-episode progress line for log scraping.
            print(
                json.dumps(
                    {
                        "phase": "collect_episode_complete",
                        "split": split_name,
                        "episode_seed": int(episode_seed),
                        "success": episode_success,
                        "steps": len(history),
                    }
                ),
                flush=True,
            )
    finally:
        obs_env.close()
        sim_env.close()

    payload = {
        "split_name": split_name,
        "resolution": spec.resolution,
        "history_steps": spec.history_steps,
        "samples": samples,
        "episode_records": episode_records,
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, output_path)
    return payload
|
| 1583 |
+
|
| 1584 |
+
|
| 1585 |
+
def _load_split(path: Path) -> dict[str, Any]:
|
| 1586 |
+
return torch.load(path, map_location="cpu", weights_only=False)
|
| 1587 |
+
|
| 1588 |
+
|
| 1589 |
+
def _print_split_supervision_summary(split_name: str, samples: Sequence[dict[str, Any]]) -> None:
|
| 1590 |
+
mode_counter = collections.Counter(str(sample.get("selected_mode", "unknown")) for sample in samples)
|
| 1591 |
+
support_counter = collections.Counter(int(sample.get("support_mode", -1)) for sample in samples if "support_mode" in sample)
|
| 1592 |
+
confidence_values = [float(sample.get("state_confidence_target", 0.0)) for sample in samples if "state_confidence_target" in sample]
|
| 1593 |
+
payload = {
|
| 1594 |
+
"phase": "split_supervision_summary",
|
| 1595 |
+
"split": split_name,
|
| 1596 |
+
"samples": len(samples),
|
| 1597 |
+
"selected_modes": dict(mode_counter),
|
| 1598 |
+
"support_modes": dict(support_counter),
|
| 1599 |
+
"mean_state_confidence_target": float(np.mean(confidence_values)) if confidence_values else None,
|
| 1600 |
+
}
|
| 1601 |
+
print(json.dumps(payload, sort_keys=True), flush=True)
|
| 1602 |
+
|
| 1603 |
+
|
| 1604 |
+
def _batch_from_obs(obs_bundle: dict[str, np.ndarray], proprio: np.ndarray, history: Sequence[dict[str, Any]], device: torch.device) -> dict[str, Any]:
    """Assemble a batch-size-1 model input dict from a single observation.

    Images are normalized to [0, 1] and channel-first; history tensors are
    padded/stacked to HISTORY_STEPS entries.
    """

    def _tensorize(array: np.ndarray) -> torch.Tensor:
        # Add the leading batch dimension and move to the target device.
        return torch.from_numpy(array).unsqueeze(0).float().to(device)

    def _stacked(field: str, pad_shape, dtype) -> torch.Tensor:
        stacked = _history_stack(history, field, pad_shape=pad_shape, dtype=dtype, history_steps=HISTORY_STEPS)
        return _tensorize(stacked)

    return {
        "images": torch.from_numpy(obs_bundle["images"]).permute(0, 3, 1, 2).unsqueeze(0).float().div(255.0).to(device),
        "depths": _tensorize(obs_bundle["depths"]),
        "depth_valid": _tensorize(obs_bundle["depth_valid"]),
        "camera_intrinsics": _tensorize(obs_bundle["camera_intrinsics"]),
        "camera_extrinsics": _tensorize(obs_bundle["camera_extrinsics"]),
        "history_images": torch.from_numpy(
            _history_stack(history, "images", pad_shape=obs_bundle["images"].shape, dtype=np.uint8, history_steps=HISTORY_STEPS)
        ).permute(0, 1, 4, 2, 3).unsqueeze(0).float().div(255.0).to(device),
        "history_depths": _stacked("depths", obs_bundle["depths"].shape, np.float32),
        "history_depth_valid": _stacked("depth_valid", obs_bundle["depth_valid"].shape, np.float32),
        "history_camera_intrinsics": _stacked("camera_intrinsics", obs_bundle["camera_intrinsics"].shape, np.float32),
        "history_camera_extrinsics": _stacked("camera_extrinsics", obs_bundle["camera_extrinsics"].shape, np.float32),
        "history_proprio": _stacked("proprio", (PROPRIO_DIM,), np.float32),
        "history_actions": _stacked("action", (14,), np.float32),
        "proprio": _tensorize(proprio),
        "texts": [TEXT_PROMPT],
        "task_names": [TASK_NAME],
        "task_ids": torch.as_tensor([TASK_ID], dtype=torch.long, device=device),
    }
|
| 1637 |
+
|
| 1638 |
+
|
| 1639 |
+
def _eval_mode_name(
    model_output: dict[str, Any],
    checkpoint_mode: str,
) -> tuple[str, bool, bool]:
    """Resolve (mode_name, adapter_active, non_base_selection) from model outputs.

    Falls back to classifying the trunk's mean action chunk whenever the
    adapter proposal path is unavailable or inactive.
    """

    def _trunk_mode() -> str:
        return _classify_mode_from_chunk(_np(model_output["action_mean"])[0])

    uses_proposals = (
        checkpoint_mode == "adapter_active_ft"
        and "proposal_mode_names" in model_output
        and "best_candidate_indices" in model_output
    )
    if not uses_proposals:
        return _trunk_mode(), False, False

    mask = _np(model_output.get("adapter_active_mask", np.asarray([False])))
    active = bool(mask.reshape(-1)[0])
    if not active:
        # Adapter declined to intervene: report the trunk's choice.
        return _trunk_mode(), False, False

    best_index = int(_np(model_output["best_candidate_indices"])[0])
    names = model_output["proposal_mode_names"][0]
    if best_index < len(names):
        chosen = str(names[best_index])
    else:
        chosen = _trunk_mode()
    # Index 0 is the base candidate; anything else counts as non-base.
    return chosen, active, best_index > 0
|
| 1657 |
+
|
| 1658 |
+
|
| 1659 |
+
def _evaluate_checkpoint(
    *,
    checkpoint_path: Path,
    adapter_mode: str,
    result_mode_name: str,
    seeds: Sequence[int],
    report_path: Path,
    train_spec: dict[str, Any] | None,
    dataset_seed: int,
    planner_overrides: dict[str, float] | None = None,
) -> dict[str, Any]:
    """Roll out a trained checkpoint over the eval seeds and write a JSON report.

    Each decision step runs the policy on the live observation, executes the
    selected macro-mode in a state-synced shadow env to measure disturbance,
    then syncs the result back. Returns the report payload (also written to
    *report_path*).
    """
    model, checkpoint = _load_checkpoint(
        checkpoint_path,
        # trunk_only runs without an adapter override on the loaded config.
        adapter_mode=adapter_mode if adapter_mode != "trunk_only" else None,
        planner_overrides=planner_overrides,
    )
    device = next(model.parameters()).device
    obs_env = PickClutterRevealEnv(
        obs_mode="rgb+depth+segmentation",
        control_mode="pd_ee_delta_pos",
        render_mode="rgb_array",
    )
    sim_env = PickClutterRevealEnv(
        obs_mode="none",
        control_mode="pd_ee_delta_pos",
        render_mode="rgb_array",
    )
    successes: list[int] = []
    episode_records: list[dict[str, Any]] = []
    reveal_steps: list[int] = []
    retrieve_steps: list[int] = []
    disturbance_values: list[float] = []
    # Adapter-activation / non-base-selection counters across all decisions.
    intervention_events = 0
    non_base_events = 0
    total_decisions = 0
    try:
        for episode_seed in seeds:
            obs, _ = obs_env.reset(seed=int(episode_seed))
            sim_env.reset(seed=int(episode_seed))
            _sync_env_state(obs_env, sim_env)
            history: deque[dict[str, Any]] = deque(maxlen=HISTORY_STEPS)
            success = False
            first_reveal_step: int | None = None
            first_retrieve_step: int | None = None
            episode_disturbance: list[float] = []
            for decision_step in range(MAX_MACRO_STEPS):
                obs_bundle = _extract_sensor_bundle(obs)
                proprio = _build_proprio(obs_env)
                batch = _batch_from_obs(obs_bundle, proprio, list(history), device)
                with torch.no_grad():
                    if adapter_mode == "trunk_only":
                        outputs = model(**batch)
                    else:
                        outputs = model(
                            **batch,
                            adapter_mode=adapter_mode,
                            use_transition_model=True,
                            use_task_conditioning=True,
                        )
                selected_mode, active_mask, non_base = _eval_mode_name(outputs, result_mode_name)
                start_positions = _all_positions(obs_env)
                # Execute in the shadow env first so disturbance can be
                # measured, then mirror the resulting state back to obs_env.
                _sync_env_state(obs_env, sim_env)
                _execute_mode(sim_env, selected_mode)
                end_metrics = _candidate_metrics(sim_env, start_positions=start_positions, current_obs_bundle=None)
                _sync_env_state(sim_env, obs_env)
                obs = obs_env.get_obs(obs_env.get_info())
                # Unknown mode names fall back to the base_action chunk.
                history.append(_init_history_entry(obs_bundle, proprio, CANONICAL_CHUNKS.get(selected_mode, CANONICAL_CHUNKS["base_action"])))
                total_decisions += 1
                intervention_events += int(active_mask)
                non_base_events += int(non_base)
                episode_disturbance.append(end_metrics["disturbance"])
                # Any mode other than retrieve/base_action/maintain_gap counts
                # as the first "reveal or access" action.
                if selected_mode != "retrieve" and selected_mode not in {"base_action", "maintain_gap"} and first_reveal_step is None:
                    first_reveal_step = decision_step + 1
                if selected_mode == "retrieve" and first_retrieve_step is None:
                    first_retrieve_step = decision_step + 1
                if _success_from_state(obs_env):
                    success = True
                    break
            successes.append(int(success))
            if first_reveal_step is not None:
                reveal_steps.append(first_reveal_step)
            if first_retrieve_step is not None:
                retrieve_steps.append(first_retrieve_step)
            disturbance_values.append(float(np.mean(episode_disturbance)) if episode_disturbance else 0.0)
            episode_records.append(
                {
                    "episode_seed": int(episode_seed),
                    "success": success,
                    # NOTE(review): len(history) is capped at HISTORY_STEPS
                    # (deque maxlen), so "steps" under-reports long episodes.
                    "steps": len(history),
                    "first_reveal_step": first_reveal_step,
                    "first_retrieve_step": first_retrieve_step,
                    "episode_disturbance": float(np.mean(episode_disturbance)) if episode_disturbance else 0.0,
                }
            )
            # Structured per-episode progress line for log scraping.
            print(
                json.dumps(
                    {
                        "phase": "eval_episode_complete",
                        "adapter_mode": result_mode_name,
                        "episode_seed": int(episode_seed),
                        "success": success,
                        "steps": len(history),
                    }
                ),
                flush=True,
            )
    finally:
        obs_env.close()
        sim_env.close()

    eval_protocol = build_public_eval_protocol(
        track_id="occlusion_track",
        eval_mode=result_mode_name,
        seed=int(dataset_seed),
        episodes=len(seeds),
        resolution=224,
        cameras=CAMERA_NAMES,
    )
    payload = {
        "track_id": "occlusion_track",
        "suite": "maniskill3",
        "benchmark_task": "PickClutterYCB-v1",
        "role": "target",
        "adapter_mode": result_mode_name,
        "episodes": len(seeds),
        "successes": successes,
        "success_rate": float(np.mean(successes)) if successes else 0.0,
        "intervention_rate": float(intervention_events / max(1, total_decisions)),
        "non_base_selection_rate": float(non_base_events / max(1, total_decisions)),
        # Episodes that never revealed/retrieved are imputed with the cap.
        "steps_to_first_reveal_or_access": float(np.mean(reveal_steps)) if reveal_steps else float(MAX_MACRO_STEPS),
        "steps_to_retrieve": float(np.mean(retrieve_steps)) if retrieve_steps else float(MAX_MACRO_STEPS),
        "disturbance_proxy": float(np.mean(disturbance_values)) if disturbance_values else 0.0,
        "episode_records": episode_records,
        "eval_protocol": eval_protocol,
    }
    if train_spec is not None:
        payload["train_spec"] = train_spec
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
    return payload
|
| 1799 |
+
|
| 1800 |
+
|
| 1801 |
+
def _summarize_smoke(results: Sequence[dict[str, Any]], output_dir: Path) -> dict[str, Any]:
    """Summarize smoke-run results and write JSON + markdown reports to *output_dir*.

    Returns the payload produced by ``summarize_public_benchmark_package``
    (called with ``allow_partial=True`` because a smoke run may only cover a
    subset of the benchmark tracks).
    """
    summary = summarize_public_benchmark_package(list(results), allow_partial=True)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "public_benchmark_package_summary.json").write_text(
        json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8"
    )
    report_lines = [
        "# ManiSkill PickClutter Smoke Summary",
        "",
        f"- available_tracks: {summary['available_tracks']}",
        f"- target_macro_average_delta: {summary['target_macro_average_delta']:.3f}",
        f"- headline_pass: {summary['headline_pass']}",
        f"- sign_of_life_pass: {summary['sign_of_life_pass']}",
        "",
    ]
    for track_id, track_payload in summary["tracks"].items():
        report_lines.extend(
            [
                f"## {track_id}",
                f"- delta_active_vs_trunk: {track_payload.get('delta_active_vs_trunk', 0.0):.3f}",
                f"- delta_noop_vs_trunk: {track_payload.get('delta_noop_vs_trunk', 0.0):.3f}",
                f"- signs_of_life: {track_payload.get('signs_of_life', False)}",
            ]
        )
        for mode, mode_payload in track_payload["modes"].items():
            report_lines.append(f"- {mode}: mean_success={mode_payload['mean_success']:.3f}")
        report_lines.append("")
    (output_dir / "public_benchmark_package_summary.md").write_text(
        "\n".join(report_lines).rstrip() + "\n", encoding="utf-8"
    )
    return summary
|
| 1826 |
+
|
| 1827 |
+
|
| 1828 |
+
def _default_paths() -> SmokePaths:
    """Return the default data/output/report locations for the smoke pipeline."""
    return SmokePaths()
|
| 1830 |
+
|
| 1831 |
+
|
| 1832 |
+
def _dataset_artifact_path(data_dir: Path, basename: str, *, dataset_seed: int) -> Path:
    """Resolve the on-disk location of a dataset artifact.

    The default seed keeps the plain basename; any other seed inserts a
    ``_seed<N>`` marker before the extension so artifacts collected under
    different dataset seeds can coexist in one directory.
    """
    seed = int(dataset_seed)
    if seed == DEFAULT_SEED:
        return data_dir / basename
    artifact = Path(basename)
    return data_dir / f"{artifact.stem}_seed{seed}{artifact.suffix}"
|
| 1837 |
+
|
| 1838 |
+
|
| 1839 |
+
def _parse_args() -> argparse.Namespace:
    """Parse CLI options for the ManiSkill PickClutter smoke pipeline.

    All planner/adapter override flags share one shape (``type=float``,
    ``default=None`` meaning "do not override the built-in value"), so they
    are registered from a single table instead of 19 copy-pasted calls.
    This keeps the flag list easy to audit against
    ``_planner_overrides_from_args``.
    """
    parser = argparse.ArgumentParser(description="Minimum-sign-of-life ManiSkill PickClutter smoke run.")
    parser.add_argument("--stage", choices=("collect", "train", "eval", "all"), default="all")
    parser.add_argument("--data-dir", type=Path, default=_default_paths().data_dir)
    parser.add_argument("--output-dir", type=Path, default=_default_paths().output_dir)
    parser.add_argument("--report-dir", type=Path, default=_default_paths().report_dir)
    parser.add_argument("--seed", type=int, default=None, help="Deprecated alias for train/eval dataset seed.")
    parser.add_argument("--train-seed", type=int, default=None)
    parser.add_argument("--dataset-seed", type=int, default=None)
    parser.add_argument("--eval-split", choices=("val", "eval"), default="eval")
    # Optional float overrides; None means "use the planner's built-in default".
    # Order matches the key order consumed by _planner_overrides_from_args.
    float_override_flags = (
        "--adapter-confidence-threshold",
        "--retrieve-access-threshold",
        "--retrieve-persistence-threshold",
        "--retrieve-support-threshold",
        "--retrieve-reocclusion-threshold",
        "--planner-mode-preference-bonus",
        "--planner-premature-retrieve-penalty",
        "--planner-premature-insert-penalty",
        "--planner-premature-occlusion-sweep-penalty",
        "--planner-premature-maintain-penalty",
        "--planner-retrieve-stage-access-threshold",
        "--planner-retrieve-stage-reveal-threshold",
        "--planner-retrieve-stage-persistence-threshold",
        "--planner-retrieve-stage-support-threshold",
        "--planner-insert-stage-access-threshold",
        "--planner-insert-stage-visibility-threshold",
        "--planner-insert-stage-support-threshold",
        "--planner-occlusion-maintain-gap-min-access",
        "--planner-occlusion-maintain-gap-min-visibility",
    )
    for flag in float_override_flags:
        parser.add_argument(flag, type=float, default=None)
    return parser.parse_args()
|
| 1869 |
+
|
| 1870 |
+
|
| 1871 |
+
def _planner_overrides_from_args(args: argparse.Namespace) -> dict[str, float]:
|
| 1872 |
+
override_pairs = (
|
| 1873 |
+
("adapter_confidence_threshold", args.adapter_confidence_threshold),
|
| 1874 |
+
("retrieve_access_threshold", args.retrieve_access_threshold),
|
| 1875 |
+
("retrieve_persistence_threshold", args.retrieve_persistence_threshold),
|
| 1876 |
+
("retrieve_support_threshold", args.retrieve_support_threshold),
|
| 1877 |
+
("retrieve_reocclusion_threshold", args.retrieve_reocclusion_threshold),
|
| 1878 |
+
("mode_preference_bonus", args.planner_mode_preference_bonus),
|
| 1879 |
+
("premature_retrieve_penalty", args.planner_premature_retrieve_penalty),
|
| 1880 |
+
("premature_insert_penalty", args.planner_premature_insert_penalty),
|
| 1881 |
+
("premature_occlusion_sweep_penalty", args.planner_premature_occlusion_sweep_penalty),
|
| 1882 |
+
("premature_maintain_penalty", args.planner_premature_maintain_penalty),
|
| 1883 |
+
("retrieve_stage_access_threshold", args.planner_retrieve_stage_access_threshold),
|
| 1884 |
+
("retrieve_stage_reveal_threshold", args.planner_retrieve_stage_reveal_threshold),
|
| 1885 |
+
("retrieve_stage_persistence_threshold", args.planner_retrieve_stage_persistence_threshold),
|
| 1886 |
+
("retrieve_stage_support_threshold", args.planner_retrieve_stage_support_threshold),
|
| 1887 |
+
("insert_stage_access_threshold", args.planner_insert_stage_access_threshold),
|
| 1888 |
+
("insert_stage_visibility_threshold", args.planner_insert_stage_visibility_threshold),
|
| 1889 |
+
("insert_stage_support_threshold", args.planner_insert_stage_support_threshold),
|
| 1890 |
+
("occlusion_maintain_gap_min_access", args.planner_occlusion_maintain_gap_min_access),
|
| 1891 |
+
("occlusion_maintain_gap_min_visibility", args.planner_occlusion_maintain_gap_min_visibility),
|
| 1892 |
+
)
|
| 1893 |
+
return {key: value for key, value in override_pairs if value is not None}
|
| 1894 |
+
|
| 1895 |
+
|
| 1896 |
+
def main() -> None:
    """Run the smoke pipeline: collect datasets, train variants, evaluate, summarize.

    Stage selection is driven by ``--stage`` (collect / train / eval / all);
    each stage is skipped when its artifacts already exist on disk, so the
    script is resumable.
    """
    args = _parse_args()
    planner_overrides = _planner_overrides_from_args(args)
    # --seed is a deprecated alias that backfills both train and dataset seeds
    # unless the specific flags are given.
    base_seed = DEFAULT_SEED if args.seed is None else int(args.seed)
    train_seed = int(args.train_seed) if args.train_seed is not None else base_seed
    dataset_seed = int(args.dataset_seed) if args.dataset_seed is not None else base_seed
    spec = SmokeSpec(dataset_seed=dataset_seed, train_seed=train_seed)
    splits = _build_episode_splits(spec)
    # Dataset artifacts are seed-suffixed so multiple dataset seeds can coexist.
    split_path = _dataset_artifact_path(args.data_dir, "episode_splits.json", dataset_seed=spec.dataset_seed)
    train_path = _dataset_artifact_path(args.data_dir, "train.pt", dataset_seed=spec.dataset_seed)
    val_path = _dataset_artifact_path(args.data_dir, "val.pt", dataset_seed=spec.dataset_seed)

    # --- Stage 1: data collection (idempotent; skips existing artifacts). ---
    if args.stage in {"collect", "all"}:
        _save_episode_splits(split_path, splits)
        if not train_path.exists():
            print(json.dumps({"phase": "collect_train_start", "episodes": len(splits["train"])}), flush=True)
            _collect_split(split_name="train", seeds=splits["train"], spec=spec, output_path=train_path)
        if not val_path.exists():
            print(json.dumps({"phase": "collect_val_start", "episodes": len(splits["val"])}), flush=True)
            _collect_split(split_name="val", seeds=splits["val"], spec=spec, output_path=val_path)

    if args.stage == "collect":
        return

    train_bundle = _load_split(train_path)
    val_bundle = _load_split(val_path)
    train_samples = train_bundle["samples"]
    val_samples = val_bundle["samples"]
    _print_split_supervision_summary("train", train_samples)
    _print_split_supervision_summary("val", val_samples)

    trunk_checkpoint = args.output_dir / f"trunk_only_ft_seed{spec.train_seed}" / "checkpoint_best.pt"
    adapter_checkpoint = args.output_dir / f"adapter_active_ft_seed{spec.train_seed}" / "checkpoint_best.pt"
    # train_spec is needed later for the fairness audit; it is either produced
    # by training below or recovered from an existing checkpoint.
    trunk_train_spec: dict[str, Any] | None = None
    adapter_train_spec: dict[str, Any] | None = None

    # --- Stage 2: train both variants (skips variants with checkpoints). ---
    if args.stage in {"train", "all"}:
        if not trunk_checkpoint.exists():
            print(json.dumps({"phase": "train_variant_start", "variant": "trunk_only_ft"}), flush=True)
            trunk_checkpoint, trunk_train_spec = _train_variant(
                variant="trunk_only_ft",
                train_samples=train_samples,
                val_samples=val_samples,
                spec=spec,
                output_dir=args.output_dir / f"trunk_only_ft_seed{spec.train_seed}",
            )
        else:
            # Checkpoint already present: recover its training spec for reporting.
            trunk_payload = torch.load(trunk_checkpoint, map_location="cpu", weights_only=False)
            trunk_train_spec = trunk_payload.get("train_spec")
        if not adapter_checkpoint.exists():
            print(json.dumps({"phase": "train_variant_start", "variant": "adapter_active_ft"}), flush=True)
            adapter_checkpoint, adapter_train_spec = _train_variant(
                variant="adapter_active_ft",
                train_samples=train_samples,
                val_samples=val_samples,
                spec=spec,
                output_dir=args.output_dir / f"adapter_active_ft_seed{spec.train_seed}",
            )
        else:
            adapter_payload = torch.load(adapter_checkpoint, map_location="cpu", weights_only=False)
            adapter_train_spec = adapter_payload.get("train_spec")

    if args.stage == "train":
        return

    # When jumping straight to eval (--stage eval), the train stage above was
    # skipped, so recover the train specs from the checkpoints on disk.
    if trunk_train_spec is None and trunk_checkpoint.exists():
        trunk_payload = torch.load(trunk_checkpoint, map_location="cpu", weights_only=False)
        trunk_train_spec = trunk_payload.get("train_spec")
    if adapter_train_spec is None and adapter_checkpoint.exists():
        adapter_payload = torch.load(adapter_checkpoint, map_location="cpu", weights_only=False)
        adapter_train_spec = adapter_payload.get("train_spec")

    # --- Stage 3: evaluate trunk-only, adapter-noop, and adapter-active. ---
    eval_seeds = splits[args.eval_split]
    print(json.dumps({"phase": "eval_start", "episodes": len(eval_seeds)}), flush=True)
    trunk_result = _evaluate_checkpoint(
        checkpoint_path=trunk_checkpoint,
        adapter_mode="trunk_only",
        result_mode_name="trunk_only_ft",
        seeds=eval_seeds,
        report_path=args.report_dir / f"trunk_only_ft_seed{spec.train_seed}.json",
        train_spec=trunk_train_spec,
        dataset_seed=spec.dataset_seed,
        planner_overrides=planner_overrides,
    )
    # The noop ablation reuses the adapter checkpoint with the adapter disabled.
    noop_result = _evaluate_checkpoint(
        checkpoint_path=adapter_checkpoint,
        adapter_mode="adapter_noop",
        result_mode_name="adapter_noop",
        seeds=eval_seeds,
        report_path=args.report_dir / f"adapter_noop_seed{spec.train_seed}.json",
        train_spec=adapter_train_spec,
        dataset_seed=spec.dataset_seed,
        planner_overrides=planner_overrides,
    )
    active_result = _evaluate_checkpoint(
        checkpoint_path=adapter_checkpoint,
        adapter_mode="adapter_active",
        result_mode_name="adapter_active_ft",
        seeds=eval_seeds,
        report_path=args.report_dir / f"adapter_active_ft_seed{spec.train_seed}.json",
        train_spec=adapter_train_spec,
        dataset_seed=spec.dataset_seed,
        planner_overrides=planner_overrides,
    )
    summary = _summarize_smoke([trunk_result, noop_result, active_result], args.report_dir)
    print(json.dumps({"phase": "complete", "summary": summary}, indent=2), flush=True)
|
| 2002 |
+
|
| 2003 |
+
|
| 2004 |
+
if __name__ == "__main__":
    # Script entry point.
    main()
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/eval/run_public_benchmark_package.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from eval.public_benchmark_package import (
|
| 11 |
+
ANCHOR_ROLE,
|
| 12 |
+
DEFAULT_ANCHOR_TOLERANCE,
|
| 13 |
+
DEFAULT_SIGN_OF_LIFE_GAIN,
|
| 14 |
+
DEFAULT_SIGN_OF_LIFE_INTERVENTION,
|
| 15 |
+
DEFAULT_SIGN_OF_LIFE_NON_BASE,
|
| 16 |
+
TARGET_ROLE,
|
| 17 |
+
build_public_eval_protocol,
|
| 18 |
+
build_target_training_spec,
|
| 19 |
+
default_public_benchmark_manifest,
|
| 20 |
+
expected_eval_modes,
|
| 21 |
+
public_benchmark_tracks,
|
| 22 |
+
public_protocol_identity_signature,
|
| 23 |
+
public_track_by_id,
|
| 24 |
+
training_fairness_signature,
|
| 25 |
+
write_default_public_benchmark_manifest,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _load_json(path: str | Path) -> dict[str, Any]:
|
| 30 |
+
with Path(path).open("r", encoding="utf-8") as handle:
|
| 31 |
+
payload = json.load(handle)
|
| 32 |
+
if not isinstance(payload, dict):
|
| 33 |
+
raise TypeError(f"Expected a JSON object in {path!s}, got {type(payload)!r}.")
|
| 34 |
+
return payload
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _normalize_success_samples(payload: dict[str, Any]) -> np.ndarray:
|
| 38 |
+
if "successes" in payload:
|
| 39 |
+
raw = np.asarray(payload["successes"], dtype=np.float32).reshape(-1)
|
| 40 |
+
return raw
|
| 41 |
+
if "success_rate" in payload:
|
| 42 |
+
return np.asarray([float(payload["success_rate"])], dtype=np.float32)
|
| 43 |
+
raise KeyError("Each result payload must include either `successes` or `success_rate`.")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _mean_optional(records: list[dict[str, Any]], key: str) -> float | None:
|
| 47 |
+
values = [float(record[key]) for record in records if key in record]
|
| 48 |
+
if not values:
|
| 49 |
+
return None
|
| 50 |
+
return float(np.mean(values))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _bootstrap_delta_ci(
|
| 54 |
+
lhs: np.ndarray,
|
| 55 |
+
rhs: np.ndarray,
|
| 56 |
+
*,
|
| 57 |
+
bootstrap_samples: int,
|
| 58 |
+
bootstrap_seed: int,
|
| 59 |
+
) -> tuple[float, float]:
|
| 60 |
+
if lhs.size == 0 or rhs.size == 0:
|
| 61 |
+
return 0.0, 0.0
|
| 62 |
+
rng = np.random.default_rng(int(bootstrap_seed))
|
| 63 |
+
deltas = np.empty(int(bootstrap_samples), dtype=np.float32)
|
| 64 |
+
for index in range(int(bootstrap_samples)):
|
| 65 |
+
lhs_sample = lhs[rng.integers(0, lhs.shape[0], size=lhs.shape[0])]
|
| 66 |
+
rhs_sample = rhs[rng.integers(0, rhs.shape[0], size=rhs.shape[0])]
|
| 67 |
+
deltas[index] = float(lhs_sample.mean() - rhs_sample.mean())
|
| 68 |
+
low, high = np.percentile(deltas, [2.5, 97.5])
|
| 69 |
+
return float(low), float(high)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _normalize_record(payload: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *payload* with canonical track metadata and success data.

    Missing descriptive fields are filled from the registered track; success
    samples are flattened to a list and rate/episode counts derived when absent.
    Raises KeyError when ``track_id`` or ``adapter_mode`` is missing.
    """
    for required in ("track_id", "adapter_mode"):
        if required not in payload:
            raise KeyError(f"Missing required field `{required}`.")
    track = public_track_by_id(str(payload["track_id"]))
    samples = _normalize_success_samples(payload)
    record = dict(payload)
    record.update(
        track_id=track.track_id,
        suite=payload.get("suite", track.suite),
        benchmark_task=payload.get("benchmark_task", track.benchmark_task),
        role=payload.get("role", track.role),
        adapter_mode=str(payload["adapter_mode"]),
        successes=samples.tolist(),
        success_rate=float(payload.get("success_rate", float(samples.mean()))),
        episodes=int(payload.get("episodes", samples.shape[0])),
    )
    return record
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _validate_protocols(records: list[dict[str, Any]]) -> None:
    """Ensure each record carries an eval protocol and, per track, one identity.

    Raises ValueError on a missing ``eval_protocol`` or when two runs of the
    same track were evaluated under different protocol identities.
    """
    by_track: dict[str, list[dict[str, Any]]] = {}
    for record in records:
        by_track.setdefault(record["track_id"], []).append(record)
    for track_id, grouped in by_track.items():
        signatures = []
        for record in grouped:
            protocol = record.get("eval_protocol")
            if protocol is None:
                raise ValueError(
                    f"Missing eval_protocol for track {track_id!r}, mode {record['adapter_mode']!r}."
                )
            signatures.append(public_protocol_identity_signature(protocol))
        reference, *others = signatures
        if any(signature != reference for signature in others):
            raise ValueError(f"Protocol identity mismatch detected for track {track_id!r}.")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _validate_training_fairness(records: list[dict[str, Any]]) -> None:
    """Audit that trunk-only and adapter-active target runs were trained fairly.

    For every target-role track that has results for BOTH ``trunk_only_ft``
    and ``adapter_active_ft``, verify the two variants have the same number of
    runs, all carry a ``train_spec``, use the same set of seeds, and share the
    same data/init signature per seed. Raises ValueError on any mismatch.
    """
    # Bucket records by (track, adapter mode) for pairwise comparison.
    grouped: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for record in records:
        grouped.setdefault((record["track_id"], record["adapter_mode"]), []).append(record)
    for track in public_benchmark_tracks(TARGET_ROLE):
        trunk_records = grouped.get((track.track_id, "trunk_only_ft"), [])
        active_records = grouped.get((track.track_id, "adapter_active_ft"), [])
        # Only audit tracks where both variants are present; partial results
        # are handled (or rejected) elsewhere.
        if not trunk_records or not active_records:
            continue
        if len(trunk_records) != len(active_records):
            raise ValueError(
                f"Training fairness mismatch for {track.track_id!r}: different run counts "
                f"between trunk_only_ft ({len(trunk_records)}) and adapter_active_ft ({len(active_records)})."
            )
        if any(record.get("train_spec") is None for record in trunk_records + active_records):
            raise ValueError(
                f"Training fairness mismatch for {track.track_id!r}: missing train_spec on a target-track result."
            )
        # Key the fairness signature by training seed so variants are compared
        # seed-for-seed. NOTE(review): duplicate seeds within one variant would
        # silently collapse here — presumably seeds are unique per variant.
        trunk_by_seed = {
            int(record["train_spec"]["seed"]): training_fairness_signature(record["train_spec"])
            for record in trunk_records
        }
        active_by_seed = {
            int(record["train_spec"]["seed"]): training_fairness_signature(record["train_spec"])
            for record in active_records
        }
        if set(trunk_by_seed) != set(active_by_seed):
            raise ValueError(f"Training fairness mismatch for {track.track_id!r}: seed sets differ.")
        for seed, trunk_signature in trunk_by_seed.items():
            if trunk_signature != active_by_seed[seed]:
                raise ValueError(
                    f"Training fairness mismatch for {track.track_id!r} at seed {seed}: "
                    "trunk_only_ft and adapter_active_ft do not share the same data/init signature."
                )
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _aggregate_mode(records: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-run results for one (track, adapter_mode) cell.

    Returns a payload with the run count, the mean of per-run success rates,
    the concatenated per-episode success samples, and the mean of each optional
    diagnostic metric that at least one record carries.

    Robust to an empty ``records`` list: yields zero runs, a 0.0 mean, and an
    empty sample list instead of letting ``np.concatenate`` raise on an empty
    sequence.
    """
    success_rates = np.asarray([float(record["success_rate"]) for record in records], dtype=np.float32)
    if records:
        success_samples = np.concatenate(
            [np.asarray(record["successes"], dtype=np.float32).reshape(-1) for record in records],
            axis=0,
        )
    else:
        # np.concatenate rejects an empty list of arrays; use an empty array.
        success_samples = np.zeros(0, dtype=np.float32)
    payload: dict[str, Any] = {
        "num_runs": len(records),
        "mean_success": float(success_rates.mean()) if success_rates.size else 0.0,
        "success_samples": success_samples.tolist(),
    }
    # Diagnostics are optional per record; include a mean only when present.
    for key in (
        "intervention_rate",
        "non_base_selection_rate",
        "steps_to_first_reveal_or_access",
        "steps_to_retrieve",
        "disturbance_proxy",
    ):
        mean_value = _mean_optional(records, key)
        if mean_value is not None:
            payload[key] = mean_value
    return payload
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def summarize_public_benchmark_package(
    result_payloads: list[dict[str, Any]],
    *,
    bootstrap_samples: int = 2000,
    bootstrap_seed: int = 0,
    allow_partial: bool = False,
) -> dict[str, Any]:
    """Validate and aggregate benchmark results into a single package summary.

    Normalizes every payload, enforces protocol identity and training-fairness
    invariants, then aggregates per track and per adapter mode. Target-role
    tracks report active-vs-trunk deltas with a bootstrap 95% CI and a
    sign-of-life flag; anchor-role tracks report whether their deltas stay
    within tolerance. With ``allow_partial`` true, tracks missing any expected
    mode are skipped instead of raising ValueError.
    """
    records = [_normalize_record(payload) for payload in result_payloads]
    _validate_protocols(records)
    _validate_training_fairness(records)

    # Bucket records by (track, adapter mode).
    grouped: dict[tuple[str, str], list[dict[str, Any]]] = {}
    for record in records:
        grouped.setdefault((record["track_id"], record["adapter_mode"]), []).append(record)

    track_summaries: dict[str, Any] = {}
    target_deltas: list[float] = []
    # NOTE(review): anchor_pass stays True when no anchor track was evaluated.
    anchor_pass = True
    sign_of_life_tracks: list[str] = []
    ci_above_zero_tracks: list[str] = []
    available_tracks: list[str] = []

    for track in public_benchmark_tracks():
        track_modes = expected_eval_modes(track.track_id)
        mode_payloads: dict[str, Any] = {}
        missing_modes: list[str] = []
        for mode in track_modes:
            mode_records = grouped.get((track.track_id, mode), [])
            if not mode_records:
                missing_modes.append(mode)
                continue
            mode_payloads[mode] = _aggregate_mode(mode_records)
        if missing_modes:
            # Partial runs skip the whole track; strict runs fail loudly.
            if allow_partial:
                continue
            raise ValueError(f"Missing results for track {track.track_id!r}, mode(s) {missing_modes!r}.")
        available_tracks.append(track.track_id)

        track_summary: dict[str, Any] = {
            "suite": track.suite,
            "benchmark_task": track.benchmark_task,
            "role": track.role,
            "task_family": track.task_family,
            "target_behavior": track.target_behavior,
            "public_source": track.public_source,
            "notes": track.notes,
            # Drop the raw per-episode samples from the serialized summary.
            "modes": {
                mode: {
                    key: value
                    for key, value in payload.items()
                    if key != "success_samples"
                }
                for mode, payload in mode_payloads.items()
            },
        }

        if track.role == TARGET_ROLE:
            trunk = mode_payloads["trunk_only_ft"]
            active = mode_payloads["adapter_active_ft"]
            noop = mode_payloads["adapter_noop"]
            delta_active = float(active["mean_success"] - trunk["mean_success"])
            delta_noop = float(noop["mean_success"] - trunk["mean_success"])
            target_deltas.append(delta_active)
            # Per-track seed offset (len after append) keeps CIs independent
            # across tracks while remaining deterministic.
            ci_low, ci_high = _bootstrap_delta_ci(
                np.asarray(active["success_samples"], dtype=np.float32),
                np.asarray(trunk["success_samples"], dtype=np.float32),
                bootstrap_samples=bootstrap_samples,
                bootstrap_seed=bootstrap_seed + len(target_deltas),
            )
            # Sign of life = the adapter is actually intervening AND helping.
            sign_of_life = bool(
                float(active.get("intervention_rate", 0.0)) >= DEFAULT_SIGN_OF_LIFE_INTERVENTION
                and float(active.get("non_base_selection_rate", 0.0)) >= DEFAULT_SIGN_OF_LIFE_NON_BASE
                and delta_active >= DEFAULT_SIGN_OF_LIFE_GAIN
            )
            if sign_of_life:
                sign_of_life_tracks.append(track.track_id)
            if ci_low > 0.0:
                ci_above_zero_tracks.append(track.track_id)
            track_summary.update(
                {
                    "delta_active_vs_trunk": delta_active,
                    "delta_noop_vs_trunk": delta_noop,
                    "delta_active_vs_trunk_ci95": [ci_low, ci_high],
                    "signs_of_life": sign_of_life,
                }
            )
        else:
            # Anchor tracks: the adapter must not move the needle either way.
            trunk = mode_payloads["trunk_only"]
            active = mode_payloads["adapter_active"]
            noop = mode_payloads["adapter_noop"]
            active_delta = float(active["mean_success"] - trunk["mean_success"])
            noop_delta = float(noop["mean_success"] - trunk["mean_success"])
            within_tolerance = bool(
                abs(active_delta) <= DEFAULT_ANCHOR_TOLERANCE
                and abs(noop_delta) <= DEFAULT_ANCHOR_TOLERANCE
            )
            anchor_pass = anchor_pass and within_tolerance
            track_summary.update(
                {
                    "delta_active_vs_trunk": active_delta,
                    "delta_noop_vs_trunk": noop_delta,
                    "anchor_within_tolerance": within_tolerance,
                }
            )

        track_summaries[track.track_id] = track_summary

    # Headline: every target delta positive AND at least one CI strictly above 0.
    headline_pass = bool(
        target_deltas
        and all(delta > 0.0 for delta in target_deltas)
        and len(ci_above_zero_tracks) >= 1
    )
    sign_of_life_pass = len(sign_of_life_tracks) >= 2

    return {
        "package_name": default_public_benchmark_manifest()["package_name"],
        "tracks": track_summaries,
        "available_tracks": available_tracks,
        "target_macro_average_delta": float(np.mean(target_deltas)) if target_deltas else 0.0,
        "headline_pass": headline_pass,
        "sign_of_life_pass": sign_of_life_pass,
        "sign_of_life_track_count": len(sign_of_life_tracks),
        "sign_of_life_tracks": sign_of_life_tracks,
        "ci_above_zero_tracks": ci_above_zero_tracks,
        "anchor_pass": anchor_pass,
    }
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def _write_markdown(output_path: Path, summary: dict[str, Any]) -> None:
|
| 299 |
+
lines = [
|
| 300 |
+
"# Public Benchmark Package Summary",
|
| 301 |
+
"",
|
| 302 |
+
f"- package_name: {summary['package_name']}",
|
| 303 |
+
f"- headline_pass: {summary['headline_pass']}",
|
| 304 |
+
f"- sign_of_life_pass: {summary['sign_of_life_pass']}",
|
| 305 |
+
f"- sign_of_life_track_count: {summary['sign_of_life_track_count']}",
|
| 306 |
+
f"- anchor_pass: {summary['anchor_pass']}",
|
| 307 |
+
f"- target_macro_average_delta: {summary['target_macro_average_delta']:.3f}",
|
| 308 |
+
"",
|
| 309 |
+
]
|
| 310 |
+
for track_id, payload in summary["tracks"].items():
|
| 311 |
+
lines.append(f"## {track_id}")
|
| 312 |
+
lines.append(f"- suite: {payload['suite']}")
|
| 313 |
+
lines.append(f"- benchmark_task: {payload['benchmark_task']}")
|
| 314 |
+
lines.append(f"- role: {payload['role']}")
|
| 315 |
+
for mode, mode_payload in payload["modes"].items():
|
| 316 |
+
lines.append(f"- {mode}: mean_success={mode_payload['mean_success']:.3f}, num_runs={mode_payload['num_runs']}")
|
| 317 |
+
if "delta_active_vs_trunk" in payload:
|
| 318 |
+
lines.append(f"- delta_active_vs_trunk: {payload['delta_active_vs_trunk']:.3f}")
|
| 319 |
+
if "delta_noop_vs_trunk" in payload:
|
| 320 |
+
lines.append(f"- delta_noop_vs_trunk: {payload['delta_noop_vs_trunk']:.3f}")
|
| 321 |
+
if "delta_active_vs_trunk_ci95" in payload:
|
| 322 |
+
low, high = payload["delta_active_vs_trunk_ci95"]
|
| 323 |
+
lines.append(f"- delta_active_vs_trunk_ci95: [{low:.3f}, {high:.3f}]")
|
| 324 |
+
if "signs_of_life" in payload:
|
| 325 |
+
lines.append(f"- signs_of_life: {payload['signs_of_life']}")
|
| 326 |
+
if "anchor_within_tolerance" in payload:
|
| 327 |
+
lines.append(f"- anchor_within_tolerance: {payload['anchor_within_tolerance']}")
|
| 328 |
+
lines.append("")
|
| 329 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 330 |
+
output_path.write_text("\n".join(lines).rstrip() + "\n", encoding="utf-8")
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def _parse_args() -> argparse.Namespace:
|
| 334 |
+
parser = argparse.ArgumentParser(description="Validate and summarize the public benchmark package results.")
|
| 335 |
+
parser.add_argument("--result", action="append", default=[], help="Path to a normalized benchmark result JSON.")
|
| 336 |
+
parser.add_argument("--output-dir", type=Path, default=Path.home() / "workspace" / "reports" / "public_benchmark_package_v1")
|
| 337 |
+
parser.add_argument("--bootstrap-samples", type=int, default=2000)
|
| 338 |
+
parser.add_argument("--bootstrap-seed", type=int, default=0)
|
| 339 |
+
parser.add_argument("--write-default-manifest", type=Path, default=None)
|
| 340 |
+
return parser.parse_args()
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def main() -> None:
    """CLI entry point: optionally emit a default manifest, then summarize results."""
    opts = _parse_args()
    manifest_target = opts.write_default_manifest
    if manifest_target is not None:
        manifest_path = write_default_public_benchmark_manifest(manifest_target)
        print(json.dumps({"wrote_manifest": str(manifest_path)}, indent=2))
        # Writing only the manifest (no --result files) is a complete run.
        if not opts.result:
            return
    if not opts.result:
        raise SystemExit("No results provided. Pass one or more --result files or use --write-default-manifest.")

    loaded = [_load_json(result_path) for result_path in opts.result]
    summary = summarize_public_benchmark_package(
        loaded,
        bootstrap_samples=opts.bootstrap_samples,
        bootstrap_seed=opts.bootstrap_seed,
    )

    out_dir = opts.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    json_path = out_dir / "public_benchmark_package_summary.json"
    md_path = out_dir / "public_benchmark_package_summary.md"
    json_path.write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    _write_markdown(md_path, summary)
    print(json.dumps({"summary_json": str(json_path), "summary_md": str(md_path)}, indent=2))
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# Script entry point: run the CLI when executed directly.
if __name__ == "__main__":
    main()
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.26 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.5 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-310.pyc
ADDED
|
Binary file (26.3 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/action_decoder.cpython-311.pyc
ADDED
|
Binary file (58.2 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-310.pyc
ADDED
|
Binary file (19 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/backbones.cpython-311.pyc
ADDED
|
Binary file (37.3 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-310.pyc
ADDED
|
Binary file (4.25 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/multiview_fusion.cpython-311.pyc
ADDED
|
Binary file (8.01 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-310.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/observation_memory.cpython-311.pyc
ADDED
|
Binary file (27.7 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-310.pyc
ADDED
|
Binary file (25.3 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/planner.cpython-311.pyc
ADDED
|
Binary file (55.6 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-310.pyc
ADDED
|
Binary file (32.1 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/policy.cpython-311.pyc
ADDED
|
Binary file (58.9 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-310.pyc
ADDED
|
Binary file (19.4 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/reveal_head.cpython-311.pyc
ADDED
|
Binary file (46.7 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/rvt_backbone.cpython-310.pyc
ADDED
|
Binary file (14 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/rvt_backbone.cpython-311.pyc
ADDED
|
Binary file (29.6 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-310.pyc
ADDED
|
Binary file (20 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/__pycache__/world_model.cpython-311.pyc
ADDED
|
Binary file (49.6 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/models/planner.py
ADDED
|
@@ -0,0 +1,887 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass
class PlannerConfig:
    """Hyper-parameters shared by the planner modules in this file.

    `*_weight` fields scale individual utility terms; `*_threshold` fields
    gate feasibility / staging decisions.  Defaults are the values used by
    the planners defined below.
    """

    # --- model / problem sizes ---
    hidden_dim: int = 512
    num_candidates: int = 8
    action_dim: int = 14
    num_support_modes: int = 3
    # --- utility term weights ---
    utility_margin: float = 0.1
    corridor_weight: float = 1.0
    persistence_weight: float = 0.5
    proposal_weight: float = 0.5
    task_progress_weight: float = 0.75
    disturbance_weight: float = 0.75
    reocclusion_weight: float = 0.5
    visibility_weight: float = 0.25
    # --- transformer scorer (used by InteractionPlanner) ---
    num_heads: int = 4
    num_layers: int = 2
    num_phases: int = 5
    num_arm_roles: int = 4
    # --- shortlist size (used by CascadePlanner.shortlist) ---
    top_k: int = 4
    # --- structured utility weights (used by StructuredElasticUtility) ---
    belief_gain_weight: float = 1.0
    visibility_gain_weight: float = 0.75
    clearance_weight: float = 0.75
    occluder_contact_weight: float = 0.5
    grasp_affordance_weight: float = 0.5
    support_stability_weight: float = 0.5
    residual_weight: float = 0.5
    # --- feasibility thresholds for retrieve-like candidates ---
    retrieve_access_threshold: float = 0.15
    retrieve_persistence_threshold: float = 0.15
    retrieve_support_threshold: float = 0.25
    retrieve_reocclusion_threshold: float = 0.6
    adapter_confidence_threshold: float = 0.55
    mode_preference_bonus: float = 3.0
    # --- stage-gating penalties and thresholds (consumers not visible in
    # this file's planners; presumably used by the policy — confirm) ---
    premature_retrieve_penalty: float = 1.5
    premature_insert_penalty: float = 0.75
    premature_occlusion_sweep_penalty: float = 0.75
    premature_maintain_penalty: float = 0.0
    retrieve_stage_access_threshold: float = 0.45
    retrieve_stage_reveal_threshold: float = 0.40
    retrieve_stage_persistence_threshold: float = 0.20
    retrieve_stage_support_threshold: float = 0.25
    insert_stage_access_threshold: float = 0.40
    insert_stage_visibility_threshold: float = 0.30
    insert_stage_support_threshold: float = 0.25
    occlusion_maintain_gap_min_access: float = 0.0
    occlusion_maintain_gap_min_visibility: float = 0.0
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class RevealPlanner(nn.Module):
    """Scores candidate action chunks against world-model rollout summaries.

    A small MLP maps a fixed-size per-candidate summary vector to success
    and risk scalars; utility = sigmoid(success_logit) - risk.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config
        # Summary layout (see summarize_candidates): mean action + terminal
        # action (action_dim each), then five rollout statistics.  The
        # 3/3/1/3/1 terms mirror the trailing widths of those statistics —
        # assumed to match the world-model output shapes; TODO confirm.
        summary_dim = (
            config.action_dim * 2
            + 3
            + 3
            + 1
            + 3
            + 1
        )
        self.trunk = nn.Sequential(
            nn.LayerNorm(summary_dim),
            nn.Linear(summary_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        # Scalar prediction heads reading the shared trunk representation.
        self.success_head = nn.Linear(config.hidden_dim, 1)
        self.risk_head = nn.Linear(config.hidden_dim, 1)

    def summarize_candidates(self, candidate_chunks: Tensor, rollout_state: dict[str, Tensor]) -> Tensor:
        """Build one feature vector per candidate from chunks and rollout stats.

        NOTE(review): assumes candidate_chunks is (batch, candidates, steps,
        action_dim) — the mean over dim 2 and the [:, :, -1] terminal index
        both rely on that layout; verify against the caller.
        """
        candidate_mean = candidate_chunks.mean(dim=2)
        candidate_terminal = candidate_chunks[:, :, -1]
        # Max corridor probability over the last dim, averaged over dim -2.
        corridor_prob = rollout_state["corridor_logits"].sigmoid().amax(dim=-1).mean(dim=-2)
        persistence = rollout_state["persistence_horizon"].mean(dim=-2)
        disturbance = rollout_state["disturbance_cost"].mean(dim=-1, keepdim=True)
        reocclusion = rollout_state["reocclusion_logit"].sigmoid().mean(dim=-2)
        uncertainty = rollout_state["uncertainty"].mean(dim=-1, keepdim=True)
        return torch.cat(
            [
                candidate_mean,
                candidate_terminal,
                corridor_prob,
                persistence,
                disturbance,
                reocclusion,
                uncertainty,
            ],
            dim=-1,
        )

    def score_rollouts(self, rollout_state: dict[str, Tensor], candidate_chunks: Tensor) -> dict[str, Tensor]:
        """Return per-candidate success/risk/utility scores plus features."""
        features = self.summarize_candidates(candidate_chunks, rollout_state)
        hidden = self.trunk(features)
        success_logits = self.success_head(hidden).squeeze(-1)
        risk_values = torch.sigmoid(self.risk_head(hidden)).squeeze(-1)
        # Utility trades predicted success probability against predicted risk.
        utility_scores = success_logits.sigmoid() - risk_values
        return {
            "planner_features": features,
            "planner_hidden": hidden,
            "success_logits": success_logits,
            "risk_values": risk_values,
            "utility_scores": utility_scores,
        }

    def select_best(self, candidate_chunks: Tensor, rollout_state: dict[str, Tensor]) -> dict[str, Tensor]:
        """Score all candidates and return the argmax-utility chunk per batch."""
        outputs = self.score_rollouts(rollout_state=rollout_state, candidate_chunks=candidate_chunks)
        best_idx = outputs["utility_scores"].argmax(dim=-1)
        batch_indices = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
        return {
            **outputs,
            "best_indices": best_idx,
            "best_chunk": candidate_chunks[batch_indices, best_idx],
        }
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class InteractionPlanner(nn.Module):
    """Transformer scorer over per-step candidate trajectory features.

    Each candidate trajectory is summarized step-by-step, projected, encoded
    with a Transformer (CLS-token pooling), and mapped to success / risk /
    utility scalars per candidate.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config
        # Per-step feature width; must match the concatenation built in
        # summarize_trajectory: action dims + phase probs + flattened
        # arm-role probs (2 * num_arm_roles) + support-mode probs + 7
        # scalar statistics.
        step_dim = (
            config.action_dim
            + config.num_phases
            + (2 * config.num_arm_roles)
            + config.num_support_modes
            + 7
        )
        self.step_proj = nn.Sequential(
            nn.LayerNorm(step_dim),
            nn.Linear(step_dim, config.hidden_dim),
            nn.GELU(),
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_dim,
            nhead=config.num_heads,
            dim_feedforward=config.hidden_dim * 4,
            batch_first=True,
            norm_first=True,
        )
        self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
        # Learned CLS token; its encoded state is the pooled sequence summary.
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_dim) * 0.02)
        self.success_head = nn.Linear(config.hidden_dim, 1)
        self.risk_head = nn.Linear(config.hidden_dim, 1)
        self.score_head = nn.Linear(config.hidden_dim, 1)

    def _mean_field(self, tensor: Tensor) -> Tensor:
        # Collapse the trailing two dims of a field tensor to a mean scalar.
        return tensor.mean(dim=(-1, -2))

    def summarize_trajectory(self, candidate_chunks: Tensor, rollout_state: dict[str, Tensor]) -> Tensor:
        """Concatenate per-step features for every candidate trajectory.

        Truncates to the shorter of the chunk horizon and the rollout
        horizon.  NOTE(review): assumes candidate_chunks is
        (batch, candidates, steps, action_dim) and that the *_field rollout
        entries keep enough trailing dims for the final cat on dim=-1 to
        align — confirm against the world-model output shapes.
        """
        horizon = min(candidate_chunks.shape[2], rollout_state["phase_logits"].shape[2])
        candidate_steps = candidate_chunks[:, :, :horizon]
        phase_probs = rollout_state["phase_logits"][:, :, :horizon].softmax(dim=-1)
        support_probs = rollout_state["support_mode_logits"][:, :, :horizon].softmax(dim=-1)
        # Per-arm role distributions flattened into one vector per step.
        arm_role_probs = rollout_state["arm_role_logits"][:, :, :horizon].softmax(dim=-1).flatten(start_dim=-2)
        target_mean = self._mean_field(rollout_state["target_field"][:, :, :horizon].sigmoid())
        feasibility_mean = self._mean_field(rollout_state["actor_feasibility_field"][:, :, :horizon].sigmoid())
        persistence_mean = self._mean_field(rollout_state["persistence_field"][:, :, :horizon])
        risk_mean = self._mean_field(rollout_state["risk_field"][:, :, :horizon])
        uncertainty_mean = self._mean_field(rollout_state["uncertainty_field"][:, :, :horizon])
        # Mean absolute gap between the two arm slots' role distributions —
        # presumably left/right arms; confirm the dim-3 layout upstream.
        role_gap = (
            rollout_state["arm_role_logits"][:, :, :horizon, 0].softmax(dim=-1)
            - rollout_state["arm_role_logits"][:, :, :horizon, 1].softmax(dim=-1)
        ).abs().mean(dim=-1, keepdim=True)
        return torch.cat(
            [
                candidate_steps,
                phase_probs,
                arm_role_probs,
                support_probs,
                target_mean,
                feasibility_mean,
                persistence_mean,
                risk_mean,
                uncertainty_mean,
                role_gap,
            ],
            dim=-1,
        )

    def score_rollouts(
        self,
        rollout_state: dict[str, Tensor],
        candidate_chunks: Tensor,
        proposal_logits: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Encode each candidate's step sequence and emit per-candidate scores."""
        features = self.summarize_trajectory(candidate_chunks, rollout_state)
        batch_size, num_candidates, horizon, _ = features.shape
        # Fold candidates into the batch so the encoder sees (B*C, T, D).
        flat_features = features.view(batch_size * num_candidates, horizon, -1)
        hidden_steps = self.step_proj(flat_features)
        cls = self.cls_token.expand(batch_size * num_candidates, -1, -1)
        encoded = self.sequence_encoder(torch.cat([cls, hidden_steps], dim=1))
        # Position 0 is the CLS token: the pooled sequence representation.
        pooled = encoded[:, 0]
        success_logits = self.success_head(pooled).view(batch_size, num_candidates).squeeze(-1)
        risk_values = torch.sigmoid(self.risk_head(pooled)).view(batch_size, num_candidates).squeeze(-1)
        utility_scores = self.score_head(pooled).view(batch_size, num_candidates).squeeze(-1)
        # Learned score plus success-probability bonus minus predicted risk.
        utility_scores = utility_scores + success_logits.sigmoid() - risk_values
        if proposal_logits is not None and proposal_logits.shape == utility_scores.shape:
            # Optional proposal-confidence bonus, weighted by config.
            utility_scores = utility_scores + self.config.proposal_weight * proposal_logits.sigmoid()
        return {
            "planner_features": features.mean(dim=2),
            "planner_hidden": pooled.view(batch_size, num_candidates, -1),
            "success_logits": success_logits,
            "risk_values": risk_values,
            "utility_scores": utility_scores,
        }

    def select_best(
        self,
        candidate_chunks: Tensor,
        rollout_state: dict[str, Tensor],
        proposal_logits: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Score all candidates and return the argmax-utility chunk per batch."""
        outputs = self.score_rollouts(
            rollout_state=rollout_state,
            candidate_chunks=candidate_chunks,
            proposal_logits=proposal_logits,
        )
        best_idx = outputs["utility_scores"].argmax(dim=-1)
        batch_indices = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
        return {
            **outputs,
            "best_indices": best_idx,
            "best_chunk": candidate_chunks[batch_indices, best_idx],
        }
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class StructuredElasticUtility(nn.Module):
    """Hand-weighted utility over world-model field summaries.

    Contains no learned parameters; all weights come from PlannerConfig.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config

    def _field_mean(self, tensor: Tensor) -> Tensor:
        # Reduce a rollout field by averaging its trailing (spatial) dims.
        # NOTE(review): the ndim==5 and ndim==4 branches apply the same
        # reduction — confirm that is intended for 4-D fields.
        if tensor.ndim == 6:
            return tensor.mean(dim=(-1, -2, -3))
        if tensor.ndim == 5:
            return tensor.mean(dim=(-1, -2))
        if tensor.ndim == 4:
            return tensor.mean(dim=(-1, -2))
        return tensor

    def _initial_scalar(self, state: dict[str, Tensor], key: str) -> Tensor:
        # Reduce an initial-state field to one scalar per batch element.
        value = state[key]
        if value.ndim >= 4:
            return value.mean(dim=tuple(range(1, value.ndim)))
        if value.ndim == 3:
            return value.mean(dim=(-1, -2))
        if value.ndim == 2:
            return value.mean(dim=-1)
        return value

    def forward(
        self,
        initial_state: dict[str, Tensor],
        rollout_state: dict[str, Tensor],
        candidate_chunks: Tensor,
    ) -> dict[str, Tensor]:
        """Compute structured utility terms and their weighted sum.

        Returns per-candidate components plus "utility_structured".
        """
        # Baselines from the initial state, broadcast over candidates (dim 1).
        initial_belief = self._initial_scalar(initial_state, "target_belief_field").unsqueeze(1)
        initial_visibility = self._initial_scalar(initial_state, "visibility_field").unsqueeze(1)
        # Rollout field summaries, averaged over the remaining (step) dim.
        belief_future = self._field_mean(rollout_state["target_belief_field"]).mean(dim=-1)
        visibility_future = self._field_mean(rollout_state["visibility_field"]).mean(dim=-1)
        clearance = self._field_mean(rollout_state["clearance_field"]).mean(dim=-1)
        occluder_contact = self._field_mean(rollout_state["occluder_contact_field"]).mean(dim=-1)
        grasp_affordance = self._field_mean(rollout_state["grasp_affordance_field"]).mean(dim=-1)
        support_stability = torch.sigmoid(self._field_mean(rollout_state["support_stability_field"])).mean(dim=-1)
        # Per-step trajectories kept for both means and worst-case extremes.
        persistence_traj = self._field_mean(rollout_state["persistence_field"])
        reocclusion_traj = self._field_mean(rollout_state["reocclusion_field"])
        disturbance_traj = self._field_mean(rollout_state["disturbance_field"])
        access_traj = torch.sigmoid(self._field_mean(rollout_state["access_field"]))
        persistence = persistence_traj.mean(dim=-1)
        reocclusion = reocclusion_traj.mean(dim=-1)
        disturbance = disturbance_traj.mean(dim=-1)
        access_quality = access_traj.mean(dim=-1)
        # Floors/worst: trajectory extremes used downstream as feasibility gates.
        access_floor = access_traj.amin(dim=-1)
        persistence_floor = persistence_traj.amin(dim=-1)
        support_floor = torch.sigmoid(self._field_mean(rollout_state["support_stability_field"])).amin(dim=-1)
        reocclusion_worst = reocclusion_traj.amax(dim=-1)
        # Last action channel interpreted as a progress signal — presumably a
        # gripper/progress dimension; confirm the action layout upstream.
        retrieve_progress = torch.sigmoid(candidate_chunks[:, :, :, -1]).mean(dim=-1)
        # Weighted sum: gains and quality terms add, penalty terms subtract.
        utility = (
            self.config.belief_gain_weight * (belief_future - initial_belief)
            + self.config.visibility_gain_weight * (visibility_future - initial_visibility)
            + self.config.clearance_weight * clearance
            + self.config.occluder_contact_weight * occluder_contact
            + self.config.grasp_affordance_weight * grasp_affordance
            + self.config.persistence_weight * persistence
            + self.config.support_stability_weight * support_stability
            + self.config.corridor_weight * access_quality
            + self.config.task_progress_weight * retrieve_progress
            - self.config.reocclusion_weight * reocclusion
            - self.config.disturbance_weight * disturbance
            - self.config.visibility_weight * (1.0 - visibility_future)
        )
        return {
            "belief_gain": belief_future - initial_belief,
            "visibility_gain": visibility_future - initial_visibility,
            "clearance": clearance,
            "occluder_contact_quality": occluder_contact,
            "grasp_affordance": grasp_affordance,
            "persistence": persistence,
            "support_stability": support_stability,
            "reocclusion_penalty": reocclusion,
            "reocclusion_worst": reocclusion_worst,
            "disturbance_penalty": disturbance,
            "access_quality": access_quality,
            "access_floor": access_floor,
            "persistence_floor": persistence_floor,
            "support_floor": support_floor,
            "task_progress": retrieve_progress,
            "utility_structured": utility,
        }
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class ResidualPlannerScorer(nn.Module):
    """Learned residual utility on top of the structured score components."""

    # Order matters: it fixes the layout of the stacked component features.
    _COMPONENT_KEYS = (
        "belief_gain",
        "visibility_gain",
        "clearance",
        "occluder_contact_quality",
        "grasp_affordance",
        "persistence",
        "support_stability",
        "reocclusion_penalty",
        "disturbance_penalty",
        "access_quality",
        "task_progress",
    )

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        # Mean + terminal action summaries plus the 11 structured components.
        in_dim = (config.action_dim * 2) + 11
        self.trunk = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.success_head = nn.Linear(config.hidden_dim, 1)
        self.risk_head = nn.Linear(config.hidden_dim, 1)
        self.residual_head = nn.Linear(config.hidden_dim, 1)

    def forward(
        self,
        candidate_chunks: Tensor,
        structured: dict[str, Tensor],
        proposal_logits: Tensor | None = None,
    ) -> dict[str, Tensor]:
        """Score candidates from chunk statistics and structured components.

        Returns the trunk hidden state plus success/risk/residual scalars.
        """
        chunk_mean = candidate_chunks.mean(dim=2)
        chunk_last = candidate_chunks[:, :, -1]
        component_stack = torch.stack(
            [structured[key] for key in self._COMPONENT_KEYS], dim=-1
        )
        trunk_input = torch.cat([chunk_mean, chunk_last, component_stack], dim=-1)
        hidden = self.trunk(trunk_input)
        success_logits = self.success_head(hidden).squeeze(-1)
        risk_values = torch.sigmoid(self.risk_head(hidden)).squeeze(-1)
        residual = self.residual_head(hidden).squeeze(-1)
        if proposal_logits is not None and proposal_logits.shape == residual.shape:
            # Fold in proposal confidence with a fixed small gain.
            residual = residual + 0.25 * proposal_logits.sigmoid()
        return {
            "planner_hidden": hidden,
            "success_logits": success_logits,
            "risk_values": risk_values,
            "utility_residual": residual,
        }
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
class CascadePlanner(nn.Module):
    """Two-stage planner: cheap shortlist, then structured + residual scoring.

    ``shortlist`` picks top-k candidates from cheap proposal scores
    (optionally enforcing mode diversity); ``select_best`` combines the
    structured elastic utility with a learned residual and a feasibility
    penalty for retrieve-like candidates.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config
        self.structured = StructuredElasticUtility(config)
        self.residual = ResidualPlannerScorer(config)

    def shortlist(
        self,
        proposal_logits: Tensor | None,
        candidate_chunks: Tensor,
        proposal_mode_assignments: Tensor | None = None,
    ) -> Tensor:
        """Return (batch, top_k) candidate indices from cheap scores.

        Without mode assignments this is a plain top-k.  With assignments,
        the best candidate of each mode is taken first (mode diversity) and
        any remaining slots are filled by global score order.
        """
        batch_size, num_candidates = candidate_chunks.shape[:2]
        top_k = min(max(1, self.config.top_k), num_candidates)
        if proposal_logits is None:
            # No proposal scores available: prefer low-magnitude chunks as a
            # crude prior.
            cheap_scores = -candidate_chunks.square().mean(dim=(-1, -2))
        else:
            cheap_scores = proposal_logits
        if proposal_mode_assignments is None:
            return cheap_scores.topk(top_k, dim=-1).indices
        if proposal_mode_assignments.ndim == 1:
            # A shared assignment vector is broadcast across the batch.
            proposal_mode_assignments = proposal_mode_assignments.unsqueeze(0).expand(batch_size, -1)

        shortlisted = []
        for batch_idx in range(batch_size):
            scores = cheap_scores[batch_idx]
            mode_ids = proposal_mode_assignments[batch_idx]
            mode_best: list[tuple[float, int]] = []
            # One representative (best-scoring) candidate per mode.
            for mode_id in torch.unique(mode_ids):
                mode_indices = torch.nonzero(mode_ids == mode_id, as_tuple=False).squeeze(-1)
                best_local = mode_indices[scores[mode_indices].argmax()]
                mode_best.append((float(scores[best_local].detach()), int(best_local)))
            mode_best.sort(key=lambda item: item[0], reverse=True)
            chosen = [index for _, index in mode_best[:top_k]]
            if len(chosen) < top_k:
                # Fill any remaining slots by global score order.
                for candidate_idx in scores.argsort(descending=True).tolist():
                    if candidate_idx not in chosen:
                        chosen.append(candidate_idx)
                        if len(chosen) >= top_k:
                            break
            shortlisted.append(torch.as_tensor(chosen[:top_k], device=candidate_chunks.device, dtype=torch.long))
        return torch.stack(shortlisted, dim=0)

    def select_best(
        self,
        initial_state: dict[str, Tensor],
        candidate_chunks: Tensor,
        rollout_state: dict[str, Tensor],
        proposal_logits: Tensor | None = None,
        candidate_indices: Tensor | None = None,
        proposal_mode_names: list[list[str]] | None = None,
    ) -> dict[str, Tensor]:
        """Combine structured and residual utilities and pick the best chunk.

        When ``proposal_mode_names`` is given, retrieve-like candidates
        (names containing "retrieve", "insert_actor" or "probe_inside") are
        penalized if any feasibility floor is violated.  When
        ``candidate_indices`` is given, ``best_indices`` is mapped back to
        the original (pre-shortlist) candidate numbering.
        """
        structured = self.structured(
            initial_state=initial_state,
            rollout_state=rollout_state,
            candidate_chunks=candidate_chunks,
        )
        residual = self.residual(
            candidate_chunks=candidate_chunks,
            structured=structured,
            proposal_logits=proposal_logits,
        )
        utility_total = structured["utility_structured"] + self.config.residual_weight * residual["utility_residual"]
        utility_total = utility_total + residual["success_logits"].sigmoid() - residual["risk_values"]
        feasibility_penalty = torch.zeros_like(utility_total)
        if proposal_mode_names is not None:
            retrieve_like = torch.zeros_like(utility_total, dtype=torch.bool)
            for batch_idx, names in enumerate(proposal_mode_names):
                for candidate_idx, name in enumerate(names[: utility_total.shape[1]]):
                    retrieve_like[batch_idx, candidate_idx] = any(
                        token in name for token in ("retrieve", "insert_actor", "probe_inside")
                    )
            # Feasibility floors from PlannerConfig (previously hard-coded
            # literals 0.15 / 0.15 / 0.25 / 0.6 duplicating these defaults).
            blocked = (
                (structured["access_floor"] < self.config.retrieve_access_threshold)
                | (structured["persistence_floor"] < self.config.retrieve_persistence_threshold)
                | (structured["support_floor"] < self.config.retrieve_support_threshold)
                | (structured["reocclusion_worst"] > self.config.retrieve_reocclusion_threshold)
            )
            feasibility_penalty = retrieve_like.to(dtype=utility_total.dtype) * blocked.to(dtype=utility_total.dtype) * 2.0
        # Penalty is zero when no mode names were provided.
        utility_total = utility_total - feasibility_penalty
        best_local = utility_total.argmax(dim=-1)
        batch_indices = torch.arange(candidate_chunks.shape[0], device=candidate_chunks.device)
        if candidate_indices is None:
            best_indices = best_local
        else:
            best_indices = candidate_indices[batch_indices, best_local]
        return {
            **structured,
            **residual,
            "utility_total": utility_total,
            "utility_scores": utility_total,
            "feasibility_penalty": feasibility_penalty,
            "best_indices": best_indices,
            "best_chunk": candidate_chunks[batch_indices, best_local],
            "ranking_diagnostics": {
                "topk_indices": candidate_indices if candidate_indices is not None else best_local.unsqueeze(-1),
                "best_local_indices": best_local,
            },
        }
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def _summary_scalar(state: dict[str, Tensor], key: str, fallback_keys: tuple[str, ...] = ()) -> Tensor:
|
| 476 |
+
for candidate in (key, *fallback_keys):
|
| 477 |
+
value = state.get(candidate)
|
| 478 |
+
if value is None:
|
| 479 |
+
continue
|
| 480 |
+
if value.ndim >= 5:
|
| 481 |
+
return value.mean(dim=tuple(range(value.ndim - 2, value.ndim))).mean(dim=-1)
|
| 482 |
+
if value.ndim == 4:
|
| 483 |
+
return value.mean(dim=(-1, -2))
|
| 484 |
+
if value.ndim == 3:
|
| 485 |
+
return value
|
| 486 |
+
if value.ndim == 2:
|
| 487 |
+
return value
|
| 488 |
+
return value.unsqueeze(-1)
|
| 489 |
+
raise KeyError(f"Missing summary key {key} and fallbacks {fallback_keys}.")
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def _optional_summary_scalar(
    state: dict[str, Tensor],
    key: str,
    *,
    reference: Tensor,
    fallback_keys: tuple[str, ...] = (),
) -> Tensor:
    """Like :func:`_summary_scalar`, but missing keys yield zeros shaped like ``reference``.

    Used for summary channels that only some environments emit (e.g. cloth/bag
    metrics); absence is treated as a neutral, all-zero signal rather than an error.
    """
    try:
        summary = _summary_scalar(state, key, fallback_keys)
    except KeyError:
        # Channel absent from this rollout: substitute a zero tensor matching the reference.
        summary = torch.zeros_like(reference)
    return summary
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
class ElasticFeasibilityGate(nn.Module):
    """Hard feasibility gate for retrieve-style candidate actions.

    A candidate whose worst-case rollout summaries fall below the configured
    floors (or whose re-occlusion exceeds its ceiling) is disallowed when its
    mode name looks retrieve-like; per-candidate diagnostics are emitted for
    every candidate so callers can log why something was rejected.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config

    def forward(
        self,
        *,
        rollout_state: dict[str, Tensor],
        proposal_mode_names: list[list[str]],
    ) -> dict[str, Tensor | list[list[dict[str, float | bool | str]]]]:
        cfg = self.config
        access = _summary_scalar(rollout_state, "access_summary", ("access_quality",))
        persistence = _summary_scalar(rollout_state, "persistence_summary", ("persistence", "persistence_horizon"))
        support = _summary_scalar(rollout_state, "support_summary", ("support_stability",))
        reocclusion = _summary_scalar(rollout_state, "reocclusion_summary", ("reocclusion_penalty",))
        disturbance = _summary_scalar(rollout_state, "disturbance_summary", ("disturbance_penalty",))

        # Worst case over the rollout horizon: floors for the "good" channels,
        # ceilings for the penalty channels.
        floor_access = access.amin(dim=-1)
        floor_persistence = persistence.amin(dim=-1)
        floor_support = support.amin(dim=-1)
        worst_reocclusion = reocclusion.amax(dim=-1)
        worst_disturbance = disturbance.amax(dim=-1)

        blocked = (
            (floor_access < cfg.retrieve_access_threshold)
            | (floor_persistence < cfg.retrieve_persistence_threshold)
            | (floor_support < cfg.retrieve_support_threshold)
            | (worst_reocclusion > cfg.retrieve_reocclusion_threshold)
        )
        penalties = blocked.to(dtype=access.dtype) * 2.0

        allowed = torch.ones_like(floor_access, dtype=torch.bool)
        diagnostics: list[list[dict[str, float | bool | str]]] = []
        max_candidates = floor_access.shape[1]
        for b, names in enumerate(proposal_mode_names):
            records: list[dict[str, float | bool | str]] = []
            for c, mode in enumerate(names[:max_candidates]):
                is_retrieve_like = any(tok in mode for tok in ("retrieve", "insert_actor", "probe_inside"))
                is_blocked = bool(is_retrieve_like and blocked[b, c])
                if is_blocked:
                    # Only retrieve-like candidates are actually masked out;
                    # other modes keep their penalty but stay selectable.
                    allowed[b, c] = False
                records.append(
                    {
                        "mode_name": mode,
                        "retrieve_like": is_retrieve_like,
                        "blocked": is_blocked,
                        "access_floor": float(floor_access[b, c].detach()),
                        "persistence_floor": float(floor_persistence[b, c].detach()),
                        "support_floor": float(floor_support[b, c].detach()),
                        "reocclusion_worst": float(worst_reocclusion[b, c].detach()),
                        "disturbance_worst": float(worst_disturbance[b, c].detach()),
                    }
                )
            diagnostics.append(records)

        # Soft confidence score combining the mean summaries; positive channels
        # push it up, penalty channels pull it down.
        confidence = torch.sigmoid(
            2.0 * access.mean(dim=-1)
            + 1.5 * persistence.mean(dim=-1)
            + 1.5 * support.mean(dim=-1)
            - 1.5 * reocclusion.mean(dim=-1)
            - disturbance.mean(dim=-1)
        )
        return {
            "allowed_mask": allowed,
            "penalties": penalties,
            "blocked_mask": blocked,
            "adapter_confidence": confidence,
            "gate_access_floor": floor_access,
            "gate_persistence_floor": floor_persistence,
            "gate_support_floor": floor_support,
            "gate_reocclusion_worst": worst_reocclusion,
            "reject_diagnostics": diagnostics,
        }
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
class ResidualActionReranker(nn.Module):
    """Small MLP that scores candidate action chunks.

    Features are the per-candidate chunk mean and terminal action plus eight
    pooled rollout-summary scalars; three linear heads emit a residual score,
    a success logit, and a sigmoid-squashed risk value per candidate.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        # Mean + terminal action vectors (2 * action_dim) plus 8 pooled summary scalars.
        feature_dim = (config.action_dim * 2) + 8
        self.network = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, config.hidden_dim),
            nn.GELU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.GELU(),
        )
        self.score_head = nn.Linear(config.hidden_dim, 1)
        self.success_head = nn.Linear(config.hidden_dim, 1)
        self.risk_head = nn.Linear(config.hidden_dim, 1)

    def forward(
        self,
        *,
        candidate_chunks: Tensor,
        rollout_state: dict[str, Tensor],
        proposal_logits: Tensor | None,
    ) -> dict[str, Tensor]:
        # NOTE(review): candidate_chunks looks like (batch, candidate, time, action_dim)
        # given the dim=2 mean and [:, :, -1] terminal slice — confirm with callers.
        chunk_mean = candidate_chunks.mean(dim=2)
        chunk_last = candidate_chunks[:, :, -1]

        visibility = _summary_scalar(rollout_state, "visibility_summary", ("visibility_gain",))
        access = _summary_scalar(rollout_state, "access_summary", ("access_quality",))
        persistence = _summary_scalar(rollout_state, "persistence_summary", ("persistence", "persistence_horizon"))
        support = _summary_scalar(rollout_state, "support_summary", ("support_stability",))
        reocclusion = _summary_scalar(rollout_state, "reocclusion_summary", ("reocclusion_penalty",))
        disturbance = _summary_scalar(rollout_state, "disturbance_summary", ("disturbance_penalty",))
        # Optional channels default to zeros shaped like the visibility summary.
        fold_keep = _optional_summary_scalar(
            rollout_state,
            "fold_preservation_summary",
            reference=visibility,
            fallback_keys=("fold_preservation",),
        )
        lift_risk = _optional_summary_scalar(
            rollout_state,
            "lift_too_much_risk_summary",
            reference=visibility,
            fallback_keys=("lift_too_much_risk",),
        )

        # Pool every summary to one scalar column per candidate before concatenation.
        summary_columns = [
            channel.mean(dim=-1, keepdim=True)
            for channel in (
                visibility,
                access,
                persistence,
                support,
                reocclusion,
                disturbance,
                fold_keep,
                lift_risk,
            )
        ]
        features = torch.cat([chunk_mean, chunk_last, *summary_columns], dim=-1)

        hidden = self.network(features)
        residual = self.score_head(hidden).squeeze(-1)
        success_logits = self.success_head(hidden).squeeze(-1)
        risk = torch.sigmoid(self.risk_head(hidden).squeeze(-1))
        # The proposal prior is optional upstream; fold it in only when shapes agree.
        if proposal_logits is not None and proposal_logits.shape == residual.shape:
            residual = residual + 0.25 * proposal_logits.sigmoid()
        return {
            "residual_scores": residual,
            "planner_success_logits": success_logits,
            "planner_risk_values": risk,
        }
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
class AdapterPlanner(nn.Module):
    """Top-level candidate selector: feasibility gate + learned reranker + staged mode bias.

    Combines the hard ``ElasticFeasibilityGate`` mask, the learned
    ``ResidualActionReranker`` scores, and a hand-tuned stage heuristic
    (reveal -> insert -> retrieve) into one utility per candidate, then picks
    the argmax chunk per batch element.
    """

    def __init__(self, config: PlannerConfig) -> None:
        super().__init__()
        self.config = config
        self.gate = ElasticFeasibilityGate(config)
        self.reranker = ResidualActionReranker(config)

    def select_best(
        self,
        *,
        candidate_chunks: Tensor,
        rollout_state: dict[str, Tensor],
        proposal_mode_names: list[list[str]],
        proposal_logits: Tensor | None = None,
        planning_mode: str = "adapter_active",
    ) -> dict[str, Tensor | list[list[dict[str, float | bool | str]]]]:
        """Score every candidate chunk and return the best one per batch element.

        Args:
            candidate_chunks: candidate action chunks; indexed as
                ``[batch, candidate, ...]`` (assumed (B, C, T, action_dim) — TODO confirm).
            rollout_state: per-candidate rollout summary tensors, reduced via
                ``_summary_scalar`` / ``_optional_summary_scalar``.
            proposal_mode_names: per-batch list of candidate mode-name strings;
                substring tokens ("retrieve", "mouth", "fold", ...) drive the staging logic.
            proposal_logits: optional prior scores forwarded to the reranker.
            planning_mode: "identity"/"trunk_only"/"adapter_noop" short-circuit to
                candidate 0 with zero scores; anything else runs the full adapter.

        Returns:
            Dict with best indices/chunk, utility scores, reranker heads, gate
            confidence/mask, reject diagnostics, and the planning mode used.
        """
        batch_size = candidate_chunks.shape[0]
        batch_indices = torch.arange(batch_size, device=candidate_chunks.device)
        # Ablation/bypass modes: always pick candidate 0 and report neutral scores.
        if planning_mode in {"identity", "trunk_only", "adapter_noop"}:
            zero_scores = candidate_chunks.new_zeros((batch_size, candidate_chunks.shape[1]))
            return {
                "best_indices": torch.zeros(batch_size, dtype=torch.long, device=candidate_chunks.device),
                "best_chunk": candidate_chunks[:, 0],
                "utility_scores": zero_scores,
                "utility_total": zero_scores,
                "planner_success_logits": zero_scores,
                "planner_risk_values": zero_scores,
                "adapter_confidence": candidate_chunks.new_ones((batch_size, candidate_chunks.shape[1])),
                "reject_diagnostics": [[] for _ in range(batch_size)],
                "planning_mode": planning_mode,
            }

        gate_outputs = self.gate(rollout_state=rollout_state, proposal_mode_names=proposal_mode_names)
        reranker = self.reranker(
            candidate_chunks=candidate_chunks,
            rollout_state=rollout_state,
            proposal_logits=proposal_logits,
        )
        # Base utility: learned residual + success probability - risk.
        utility = reranker["residual_scores"] + reranker["planner_success_logits"].sigmoid() - reranker["planner_risk_values"]
        # Per-candidate scalar summaries (mean over the trailing horizon axis).
        visibility = _summary_scalar(rollout_state, "visibility_summary", ("visibility_gain",)).mean(dim=-1)
        access = _summary_scalar(rollout_state, "access_summary", ("access_quality",)).mean(dim=-1)
        persistence = _summary_scalar(rollout_state, "persistence_summary", ("persistence", "persistence_horizon")).mean(dim=-1)
        support = _summary_scalar(rollout_state, "support_summary", ("support_stability",)).mean(dim=-1)
        reocclusion = _summary_scalar(rollout_state, "reocclusion_summary", ("reocclusion_penalty",)).mean(dim=-1)
        disturbance = _summary_scalar(rollout_state, "disturbance_summary", ("disturbance_penalty",)).mean(dim=-1)
        # Optional channels (bag/cloth specific); zeros when the env does not emit them.
        fold_preservation = _optional_summary_scalar(
            rollout_state,
            "fold_preservation_summary",
            reference=_summary_scalar(rollout_state, "access_summary", ("access_quality",)),
            fallback_keys=("fold_preservation",),
        ).mean(dim=-1)
        mouth_aperture = _optional_summary_scalar(
            rollout_state,
            "mouth_aperture_summary",
            reference=_summary_scalar(rollout_state, "access_summary", ("access_quality",)),
            fallback_keys=("mouth_aperture",),
        ).mean(dim=-1)
        layer_separation = _optional_summary_scalar(
            rollout_state,
            "layer_separation_summary",
            reference=_summary_scalar(rollout_state, "access_summary", ("access_quality",)),
            fallback_keys=("layer_separation_quality",),
        ).mean(dim=-1)
        lift_risk = _optional_summary_scalar(
            rollout_state,
            "lift_too_much_risk_summary",
            reference=_summary_scalar(rollout_state, "access_summary", ("access_quality",)),
            fallback_keys=("lift_too_much_risk",),
        ).mean(dim=-1)
        mode_bias = utility.new_zeros(utility.shape)
        stage_penalty = utility.new_zeros(utility.shape)
        # High when the scene is still hidden; high when it is already stabilized.
        unresolved_reveal = (1.0 - visibility) + (1.0 - access)
        stabilized_reveal = 0.5 * (access + persistence + support)
        # Use optimistic scene readiness summaries for stage switching.
        # Candidate-level safety is still enforced by the retrieve gate below, so
        # we should not let one poor candidate keep the entire scene stuck in
        # "reveal forever" mode when another candidate already makes retrieve feasible.
        batch_visibility = visibility.amax(dim=1)
        batch_access = access.amax(dim=1)
        batch_persistence = persistence.amax(dim=1)
        batch_support = support.amax(dim=1)
        batch_reocclusion = reocclusion.amin(dim=1)
        batch_disturbance = disturbance.amin(dim=1)
        batch_fold = fold_preservation.amax(dim=1)
        batch_mouth = mouth_aperture.amax(dim=1)
        batch_layer = layer_separation.amax(dim=1)
        batch_lift = lift_risk.amin(dim=1)
        batch_reveal_readiness = torch.maximum(batch_visibility, batch_access)
        for batch_idx, names in enumerate(proposal_mode_names):
            # Scene type is inferred from mode-name tokens, not from the observation.
            is_bag = any(any(token in name for token in ("mouth", "rim", "probe_inside")) for name in names)
            is_cloth = any(any(token in name for token in ("fold", "lift", "layer")) for name in names)
            # Stage readiness: retrieve requires every readiness summary past its threshold.
            can_retrieve = (
                batch_access[batch_idx] >= self.config.retrieve_stage_access_threshold
                and batch_reveal_readiness[batch_idx] >= self.config.retrieve_stage_reveal_threshold
                and batch_persistence[batch_idx] >= self.config.retrieve_stage_persistence_threshold
                and batch_support[batch_idx] >= self.config.retrieve_stage_support_threshold
                and batch_reocclusion[batch_idx] <= self.config.retrieve_reocclusion_threshold
            )
            if is_bag:
                # Bags additionally need an open mouth and stronger persistence.
                can_retrieve = bool(
                    can_retrieve
                    and batch_mouth[batch_idx] >= 0.30
                    and batch_persistence[batch_idx] >= 0.55
                )
            elif is_cloth:
                # Cloth additionally needs layer separation without over-lifting.
                can_retrieve = bool(
                    can_retrieve
                    and batch_layer[batch_idx] >= 0.18
                    and batch_fold[batch_idx] >= 0.60
                    and batch_lift[batch_idx] <= 0.30
                    and batch_support[batch_idx] >= 0.70
                )
            can_insert = (
                batch_access[batch_idx] >= self.config.insert_stage_access_threshold
                and batch_visibility[batch_idx] >= self.config.insert_stage_visibility_threshold
                and batch_support[batch_idx] >= self.config.insert_stage_support_threshold
                and batch_reocclusion[batch_idx] <= 0.65
            )
            maintain_ready = (
                batch_access[batch_idx] >= self.config.occlusion_maintain_gap_min_access
                and batch_visibility[batch_idx] >= self.config.occlusion_maintain_gap_min_visibility
            )
            # Pick the preferred mode tokens for the current stage and scene type.
            if can_retrieve:
                preferred_tokens = ("retrieve",)
            elif can_insert:
                preferred_tokens = ("probe_inside", "insert_actor") if is_bag else ("insert_actor",)
            elif is_bag:
                if batch_access[batch_idx] < 0.15 or batch_visibility[batch_idx] < 0.20:
                    preferred_tokens = ("widen_mouth", "maintain_mouth")
                else:
                    preferred_tokens = ("maintain_mouth", "widen_mouth")
            elif is_cloth:
                if batch_access[batch_idx] < 0.15 or batch_visibility[batch_idx] < 0.20:
                    preferred_tokens = ("lift_edge", "separate_layer")
                elif batch_lift[batch_idx] > 0.15 or batch_disturbance[batch_idx] > 0.25:
                    preferred_tokens = ("stabilize_fold", "maintain_lift")
                else:
                    preferred_tokens = ("maintain_lift", "stabilize_fold")
            else:
                # Generic occlusion scene (neither bag nor cloth tokens present).
                if not maintain_ready:
                    preferred_tokens = ("widen_gap", "pin_canopy", "sweep_left", "sweep_right")
                elif batch_visibility[batch_idx] < 0.20 or batch_access[batch_idx] < 0.25:
                    preferred_tokens = ("widen_gap", "pin_canopy")
                elif batch_disturbance[batch_idx] > 0.25 or batch_reocclusion[batch_idx] > 0.40:
                    preferred_tokens = ("maintain_gap", "pin_canopy")
                else:
                    preferred_tokens = ("pin_canopy", "widen_gap")
            # Per-candidate bias: reward the candidate's mode according to its family,
            # then apply preference bonuses and premature-stage penalties.
            for candidate_idx, name in enumerate(names[: utility.shape[1]]):
                if name == "base_action":
                    # The trunk's own proposal is never biased either way.
                    continue
                if any(token in name for token in ("retrieve",)):
                    bonus = (
                        0.85 * visibility[batch_idx, candidate_idx]
                        + 0.85 * access[batch_idx, candidate_idx]
                        + 0.65 * persistence[batch_idx, candidate_idx]
                        + 0.50 * support[batch_idx, candidate_idx]
                        - 0.60 * reocclusion[batch_idx, candidate_idx]
                        - 0.25 * disturbance[batch_idx, candidate_idx]
                    )
                elif any(token in name for token in ("insert_actor", "probe_inside")):
                    bonus = (
                        0.70 * visibility[batch_idx, candidate_idx]
                        + 0.70 * access[batch_idx, candidate_idx]
                        + 0.35 * persistence[batch_idx, candidate_idx]
                        - 0.35 * reocclusion[batch_idx, candidate_idx]
                        - 0.15 * disturbance[batch_idx, candidate_idx]
                    )
                elif any(token in name for token in ("maintain", "stabilize", "pin_canopy")):
                    bonus = (
                        0.85 * stabilized_reveal[batch_idx, candidate_idx]
                        + 0.25 * visibility[batch_idx, candidate_idx]
                        - 0.20 * reocclusion[batch_idx, candidate_idx]
                        - 0.10 * disturbance[batch_idx, candidate_idx]
                    )
                else:
                    # Active reveal modes: reward when the scene is still unresolved.
                    bonus = (
                        0.95 * unresolved_reveal[batch_idx, candidate_idx]
                        + 0.20 * (1.0 - persistence[batch_idx, candidate_idx])
                        - 0.10 * disturbance[batch_idx, candidate_idx]
                    )
                if any(token in name for token in ("fold", "lift", "layer")):
                    # Cloth-handling modes trade fold preservation against lift risk.
                    bonus = bonus + 0.35 * fold_preservation[batch_idx, candidate_idx] - 0.35 * lift_risk[batch_idx, candidate_idx]
                if any(token in name for token in preferred_tokens):
                    bonus = bonus + self.config.mode_preference_bonus
                elif "retrieve" in name and not can_retrieve:
                    # Premature retrieve: penalize twice (bonus and stage penalty).
                    bonus = bonus - self.config.premature_retrieve_penalty
                    stage_penalty[batch_idx, candidate_idx] = (
                        stage_penalty[batch_idx, candidate_idx] + self.config.premature_retrieve_penalty
                    )
                elif is_cloth and any(token in name for token in ("stabilize", "maintain")) and any(
                    token in preferred_tokens for token in ("lift_edge", "separate_layer")
                ):
                    # Cloth scene still needs active revealing: discourage settling early.
                    bonus = bonus - 1.0
                    stage_penalty[batch_idx, candidate_idx] = stage_penalty[batch_idx, candidate_idx] + 1.0
                elif (not is_bag and not is_cloth) and any(token in name for token in ("sweep_left", "sweep_right")) and any(
                    token in preferred_tokens for token in ("pin_canopy", "widen_gap", "maintain_gap")
                ):
                    bonus = bonus - self.config.premature_occlusion_sweep_penalty
                elif any(token in name for token in ("probe_inside", "insert_actor", "retrieve")) and not can_insert:
                    bonus = bonus - self.config.premature_insert_penalty
                    stage_penalty[batch_idx, candidate_idx] = (
                        stage_penalty[batch_idx, candidate_idx] + self.config.premature_insert_penalty
                    )
                if (
                    (not is_bag and not is_cloth)
                    and "maintain_gap" in name
                    and not maintain_ready
                    and self.config.premature_maintain_penalty > 0.0
                ):
                    bonus = bonus - self.config.premature_maintain_penalty
                    stage_penalty[batch_idx, candidate_idx] = (
                        stage_penalty[batch_idx, candidate_idx] + self.config.premature_maintain_penalty
                    )
                if is_bag and (batch_mouth[batch_idx] < 0.18 or batch_access[batch_idx] < 0.15) and "widen_mouth" in name:
                    stage_penalty[batch_idx, candidate_idx] = stage_penalty[batch_idx, candidate_idx] + 1.5
                if is_cloth and (batch_layer[batch_idx] < 0.12 or batch_visibility[batch_idx] < 0.05) and any(
                    token in name for token in ("lift_edge", "separate_layer")
                ):
                    stage_penalty[batch_idx, candidate_idx] = stage_penalty[batch_idx, candidate_idx] + 1.5
                mode_bias[batch_idx, candidate_idx] = bonus
        # Combine every term, apply the hard gate, and select the argmax candidate.
        utility = utility + mode_bias
        utility = utility + 0.5 * fold_preservation - 0.5 * lift_risk
        utility = utility - stage_penalty
        utility = utility - gate_outputs["penalties"]
        allowed_mask = gate_outputs["allowed_mask"]
        assert isinstance(allowed_mask, Tensor)
        # Gated-out candidates get -1e6 so they can never win the argmax.
        utility = utility.masked_fill(~allowed_mask, -1e6)
        best_indices = utility.argmax(dim=-1)
        best_chunk = candidate_chunks[batch_indices, best_indices]
        return {
            "best_indices": best_indices,
            "best_chunk": best_chunk,
            "utility_scores": utility,
            "utility_total": utility,
            "planner_success_logits": reranker["planner_success_logits"],
            "planner_risk_values": reranker["planner_risk_values"],
            "adapter_confidence": gate_outputs["adapter_confidence"],
            "allowed_mask": gate_outputs["allowed_mask"],
            "reject_diagnostics": gate_outputs["reject_diagnostics"],
            "planning_mode": planning_mode,
        }
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (284 Bytes). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (317 Bytes). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/transforms.cpython-310.pyc
ADDED
|
Binary file (3.63 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/pytorch3d/__pycache__/transforms.cpython-311.pyc
ADDED
|
Binary file (7.25 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (1.95 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/dataset.cpython-311.pyc
ADDED
|
Binary file (54.1 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/procedural_envs.cpython-311.pyc
ADDED
|
Binary file (98 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/sim_reveal/__pycache__/proxy_specs.cpython-311.pyc
ADDED
|
Binary file (6.32 kB). View file
|
|
|
code/VLAarchtests2_code/VLAarchtests/code/reveal_vla_bimanual/train/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (291 Bytes). View file
|
|
|