---
license: other
tags:
- robotics
- rlt
- rl-token
- molmoact2
- vla
---

# RLT Stage-1 RL Token Encoder (MolmoAct2 / YAM stack-cube)
|
|
Backup of the **RL Token (RLT) Stage-1 encoder** for the frozen MolmoAct2-BimanualYAM
stack-cube fine-tune. It is a faithful PyTorch port of openpi's `pi0_rl.py` (Xu et al. 2025):
a learned `<rl>` query compresses the VLA's `(M=690, 2560)` prefix hidden states into a
single **`z_rl`** token, and a causal autoregressive (AR) decoder reconstructs the prefix from
it (per-token squared-L2 loss, stop-gradient targets, α=0 / frozen VLA). `z_rl` is the state
for the downstream SAC actor-critic.
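
As a rough illustration of the bottleneck described above, the sketch below appends a learned query to the prefix, reads `z_rl` from the query's output position, and trains a causal decoder to reconstruct the detached prefix with a per-token squared-L2 loss. It is not the repo's `RLTokenAutoencoder`: the class name `TinyRLTokenAE`, the layer sizes, and the use of `nn.TransformerEncoder` for both halves are assumptions, and positional encodings and the padding mask are left out for brevity.

```python
import torch
import torch.nn as nn


def _stack(dim, heads, layers):
    # Small transformer stack; sizes are illustrative, not the repo's config.
    layer = nn.TransformerEncoderLayer(dim, heads, 4 * dim, batch_first=True)
    return nn.TransformerEncoder(layer, layers)


class TinyRLTokenAE(nn.Module):
    """Hypothetical toy autoencoder with a single-token bottleneck."""

    def __init__(self, dim=64, heads=4, layers=1):
        super().__init__()
        self.rl_query = nn.Parameter(torch.randn(1, 1, dim) * 0.02)  # learned <rl> query
        self.encoder = _stack(dim, heads, layers)
        self.decoder = _stack(dim, heads, layers)
        self.out = nn.Linear(dim, dim)

    def encode(self, prefix):
        # Append the query after the prefix; its output position is z_rl.
        b = prefix.shape[0]
        x = torch.cat([prefix, self.rl_query.expand(b, -1, -1)], dim=1)
        return self.encoder(x)[:, -1]  # (b, dim)

    def loss(self, prefix):
        target = prefix.detach()  # stop-gradient targets
        z = self.encode(prefix)
        # Decoder input: z_rl, then the teacher-forced prefix shifted right by one.
        dec_in = torch.cat([z.unsqueeze(1), target[:, :-1]], dim=1)  # (b, M, dim)
        m = dec_in.shape[1]
        # True entries are positions a token is NOT allowed to attend to.
        causal = torch.triu(
            torch.ones(m, m, dtype=torch.bool, device=prefix.device), diagonal=1
        )
        pred = self.out(self.decoder(dec_in, mask=causal))
        # Per-token squared L2, averaged over tokens and batch.
        return ((pred - target) ** 2).sum(dim=-1).mean()


torch.manual_seed(0)
ae = TinyRLTokenAE()
prefix = torch.randn(2, 16, 64)  # stand-in for (b, M=690, 2560) hidden states
z_rl = ae.encode(prefix)
loss = ae.loss(prefix)
loss.backward()
print(z_rl.shape, float(loss))
```

In this toy layout, decoder position `i` sees `z_rl` plus the ground-truth tokens before `i`, which is the teacher-forced context the real decoder also conditions on.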

## Chosen encoder
**`checkpoints/rl_token_encoder_ctxdrop09_best.pt`** (load the `["ema"]` entry). Trained with the
openpi/paper knobs (AdamW, lr 5e-5, 1k warmup steps, grad-clip 1.0, EMA 0.999, 10k steps) **plus
`context_dropout=0.9`**, which zeroes 90% of the decoder's teacher-forced context. This fixes
the AR leak that otherwise leaves `z_rl` diffuse: with the bare α=0 reconstruction objective, the
decoder can predict each token from the preceding ground-truth tokens and ignore `z_rl`.
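
One way to picture context dropout is the sketch below, which zeroes whole teacher-forced context tokens at random during training. The helper name `apply_context_dropout`, the per-token granularity, and the absence of any rescaling are assumptions for illustration; the actual `context_dropout` in `train_encoder.py` may differ in detail.

```python
import torch


def apply_context_dropout(context, p=0.9, training=True):
    """Zero entire context tokens with probability p (hypothetical helper).

    context: (batch, M, dim) ground-truth prefix embeddings fed to the decoder.
    Assumption: dropout is per token, unscaled, and disabled at eval time.
    """
    if not training or p <= 0.0:
        return context
    # One Bernoulli draw per (batch, token); True means the token is kept.
    keep = torch.rand(context.shape[:2], device=context.device) >= p
    return context * keep.unsqueeze(-1).to(context.dtype)


torch.manual_seed(0)
ctx = torch.randn(4, 690, 64)  # small dim for speed; the real prefix dim is 2560
dropped = apply_context_dropout(ctx, p=0.9)
zeroed_fraction = (dropped.abs().sum(dim=-1) == 0).float().mean().item()
print(f"fraction of context tokens zeroed: {zeroed_fraction:.2f}")
```

With most of the context removed, the decoder has little ground truth to copy from, so reconstruction has to draw on `z_rl`.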

## Validation
| Metric | baseline (α=0) | **dropout-0.9 (chosen)** |
|---|---|---|
| PCA top-10 explained variance | 15% | **28%** |
| temporal smoothness (lower is better) | 0.72 | **0.69** |
| **success-vs-failure** LogReg CV acc | n/a | **99.2%** (silhouette 0.13) |
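
For reference, the "PCA top-10" row can be read as the share of total variance captured by the first ten principal components of a stack of `z_rl` vectors. The sketch below computes that quantity with NumPy on synthetic data; the function name `pca_topk_variance` and the toy inputs are illustrative and are not taken from `tsne_gate.py`.

```python
import numpy as np


def pca_topk_variance(z, k=10):
    """Fraction of total variance captured by the top-k principal components."""
    z = np.asarray(z, dtype=np.float64)
    centered = z - z.mean(axis=0, keepdims=True)
    # Squared singular values of the centered data are proportional to
    # the per-component variances, sorted from largest to smallest.
    s = np.linalg.svd(centered, compute_uv=False)
    var = s ** 2
    return float(var[:k].sum() / var.sum())


rng = np.random.default_rng(0)
# Synthetic stand-in for z_rl vectors: 500 samples in 64 dims whose variance
# is concentrated in a 5-dim subspace, plus a little isotropic noise.
latent = rng.normal(size=(500, 5)) @ rng.normal(size=(5, 64))
z = latent + 0.1 * rng.normal(size=(500, 64))
ratio = pca_topk_variance(z, k=10)
print(f"top-10 explained variance: {ratio:.1%}")
```

A higher ratio means the representation concentrates its variance in a few directions; a low ratio indicates variance spread thinly across many dimensions.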
|
|
`z_rl` cleanly separates success (44 teleop demos) from failure (7 baseline rollouts, SR≈0)
in t-SNE; see `outputs/gate_success_fail.png`. Caveat: the success and failure data come from
different sessions, so part of the 99% reflects domain shift rather than pure task semantics.
Treat the figure as an upper bound.
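
A linear-probe check of this kind could look like the sketch below, which assumes scikit-learn and runs on synthetic vectors rather than the real `z_rl` data. The array names, class sizes, and fold settings are hypothetical and are not taken from `gate_success_fail.py`.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import silhouette_score
from sklearn.model_selection import StratifiedKFold, cross_val_score

rng = np.random.default_rng(0)
# Synthetic stand-ins for pooled z_rl vectors (NOT the real data): two
# overlapping Gaussian clouds with imbalanced class sizes.
z_success = rng.normal(loc=0.0, size=(200, 32))
z_failure = rng.normal(loc=0.8, size=(40, 32))
X = np.concatenate([z_success, z_failure])
y = np.concatenate([
    np.ones(len(z_success), dtype=int),
    np.zeros(len(z_failure), dtype=int),
])

# Stratified folds keep the success/failure ratio stable in every split.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
acc = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=cv).mean()
# Silhouette compares raw pairwise distances and lies in [-1, 1].
sil = silhouette_score(X, y)
print(f"LogReg CV accuracy: {acc:.3f}, silhouette: {sil:.3f}")
```

A linear probe can score high while the silhouette stays low, because silhouette is dominated by within-class spread in high dimensions. If frames from one episode land in both train and test folds, accuracy can be inflated; `GroupKFold` with episode ids is one way to avoid that.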
|
|
## Data
Trained on 9,668 `(690, 2560)` prefix shards from the 44 `atharva-pantheon/yam-stack-cube`
demos (~1.3 h teleop @ 10 Hz). This matches the RL Token paper's "small per-task demo set"
regime (1–10 h).
|
|
## Files
- `code/`: `rl_token_encoder.py` (model), `train_encoder.py`, `collect_prefix.py` (demo→prefix
  collector), `collect_fail_replay.py` (karma-rollout→prefix), `tsne_gate.py`, `gate_success_fail.py`.
- `checkpoints/`: `ctxdrop09_best/final` (chosen), `nodrop_best/final` (baseline), `ctxdrop05_best`.
- `plots/`: `tsne_final.png` (phase structure), `gate_success_fail.png` (success/fail), others.
|
|
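## Use (Phase-4 actor-critic)
```python
import torch
from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig  # code/rl_token_encoder.py

ae = RLTokenAutoencoder(RLTokenConfig(dim=2560))
# The checkpoint file lives under checkpoints/; adjust the path to your layout.
ckpt = torch.load("rl_token_encoder_ctxdrop09_best.pt", map_location="cpu")
ae.load_state_dict(ckpt["ema"])  # use the EMA weights
ae.eval()

# prefix, mask: VLA prefix hidden states (b, M, 2560) and their mask
with torch.no_grad():  # inference only, no gradients needed
    z_rl = ae.encode(prefix, mask)  # (b, M, 2560) -> (b, 2560); SAC state x = (z_rl, proprio)
```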
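
If the SAC state `x = (z_rl, proprio)` is formed by concatenation along the feature axis, which is an assumption here, it could be built as in the sketch below. The helper `build_sac_state` and the 14-dim proprioception vector are hypothetical placeholders.

```python
import torch


def build_sac_state(z_rl, proprio):
    """Concatenate the RL token with proprioception along the feature axis."""
    return torch.cat([z_rl, proprio.to(z_rl.dtype)], dim=-1)


z_rl = torch.randn(8, 2560)   # stand-in for ae.encode(prefix, mask)
proprio = torch.randn(8, 14)  # hypothetical 14-dim bimanual joint state
x = build_sac_state(z_rl, proprio)
print(x.shape)
```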
|
|
**Gotcha:** validate `z_rl` with `tsne_gate.py` / `gate_success_fail.py`, NOT with a first-token
ablation. The first prefix token is a constant special id (151645), so that test is vacuous.
|
|