---
license: apache-2.0
tags:
- robotics
- vla
- openpi
- pi0.5
- franka
library_name: openpi
---
# pi0.5 fine-tuned on `zhuoKCL/prgvla_stack`
Single-task pi0.5 (JAX) fine-tune for **cup stacking** on a Franka.
Trained from `gs://openpi-assets/checkpoints/pi05_droid/params` using `openpi`'s `pi05_droid_finetune` recipe with our own re-computed norm stats.
- Step: **55 000**, the last fully finalized checkpoint. The 60 000-step run stalled on I/O while saving step 58 000, but the loss had already plateaued at ~0.001 from about step 50 000, so step 55 000 is effectively the converged model.
- Final loss: ~0.001 (flow-matching MSE)
- Action chunk: `(horizon=16, dim=32)`, the pi0.5 standard
- State: 8-dim Franka (joint_position 7 + gripper 1)
- Cameras (from the DROID layout):
  - `base_0_rgb` ← `ext_1`
  - `left_wrist_0_rgb` ← `wrist`
  - `right_wrist_0_rgb` ← zeros (mask=False)
- Prompt: a per-episode natural-language paraphrase of the task, taken from `tasks.jsonl` (there is no single fixed instruction phrase)
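The camera and state mapping above can be sketched as a plain function. `to_model_inputs` is a hypothetical name, not an openpi API; upstream performs the equivalent step inside its DROID input transform, and the exact keys it emits may differ from this sketch.

```python
import numpy as np

# Hypothetical helper: mirrors the card's camera/state mapping, not openpi's own code.
def to_model_inputs(obs: dict) -> dict:
    joints = np.asarray(obs["observation/joint_position"], dtype=np.float32)  # (7,)
    # The gripper arrives as a scalar; promote it to shape (1,) before concatenating.
    gripper = np.atleast_1d(np.asarray(obs["observation/gripper_position"], dtype=np.float32))
    state = np.concatenate([joints, gripper])  # 8-dim Franka state

    base = np.asarray(obs["observation/exterior_image_1_left"])  # ext_1 camera
    wrist = np.asarray(obs["observation/wrist_image_left"])      # wrist camera
    return {
        "state": state,
        "image": {
            "base_0_rgb": base,
            "left_wrist_0_rgb": wrist,
            "right_wrist_0_rgb": np.zeros_like(base),  # no right wrist camera: zero image
        },
        "image_mask": {
            "base_0_rgb": True,
            "left_wrist_0_rgb": True,
            "right_wrist_0_rgb": False,  # the zero image is masked out
        },
        "prompt": obs["prompt"],
    }
```

Zero-filling the missing slot while masking it keeps the three-camera input shape that pi0.5 expects without letting the blank image influence the policy.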
## Files
| Path | Purpose |
|---|---|
| `params/` | orbax checkpoint, JAX params (12 GB) |
| `assets/zhuoKCL/prgvla_stack/norm_stats.json` | q01/q99 quantile norm stats (pi05 standard) |
| `norm_stats.json` | same file copied to root for quick inspection |
`train_state/` (optimizer state, ~30 GB) is **not** included — inference does not need it.
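The q01/q99 stats in `norm_stats.json` drive quantile normalization: each dimension's [q01, q99] range is mapped to roughly [-1, 1]. The sketch below assumes the openpi JSON layout `{"norm_stats": {"state": {...}, "actions": {...}}}` with `q01`/`q99` lists, and uses a small epsilon in the denominator; the helper names are hypothetical, so verify both against the file and your openpi version before relying on it.

```python
import json

import numpy as np

EPS = 1e-6  # guards against division by zero when q99 == q01

def load_quantiles(path: str, key: str = "state") -> tuple[np.ndarray, np.ndarray]:
    """Read q01/q99 for `key` ("state" or "actions"); assumes the openpi JSON layout."""
    with open(path) as f:
        stats = json.load(f)["norm_stats"][key]
    return np.asarray(stats["q01"], dtype=np.float32), np.asarray(stats["q99"], dtype=np.float32)

def quantile_normalize(x, q01, q99):
    # Map [q01, q99] to about [-1, 1]; values outside the quantiles are not clipped here.
    d = x.shape[-1]  # stats may be wider than x, so slice to x's width
    return (x - q01[:d]) / (q99[:d] - q01[:d] + EPS) * 2.0 - 1.0

def quantile_unnormalize(y, q01, q99):
    # Exact inverse of quantile_normalize, as applied to predicted actions.
    d = y.shape[-1]
    return (y + 1.0) / 2.0 * (q99[:d] - q01[:d] + EPS) + q01[:d]
```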
## Use it from `openpi`
In your local copy of `openpi`, edit the `pi05_droid_finetune` entry in `src/openpi/training/config.py`:
```python
TrainConfig(
    name="pi05_droid_finetune",
    model=pi0_config.Pi0Config(pi05=True, action_dim=32, action_horizon=16),
    data=LeRobotDROIDDataConfig(
        repo_id="zhuoKCL/prgvla_stack",  # 1) was: lerobot's DROID repo
        base_config=DataConfig(prompt_from_task=True),
        # 2) remove / comment out any AssetsConfig(asset_id="droid", ...) line,
        #    so this dataset's recomputed norm stats are used instead of DROID's
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader(
        "<path-to-this-repo>/params"  # local path, e.g. after `huggingface-cli download`
    ),
    num_train_steps=60_000,
),
```
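Then serve and query the policy exactly as upstream openpi does for its DROID configs, pointing the checkpoint directory at your local copy of this repo (the directory that contains `params/` and `assets/`).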
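Because the config drops the DROID `AssetsConfig`, our reading of upstream openpi is that the asset id falls back to `repo_id`, and a trained-policy checkpoint directory is expected to hold `params/` plus `assets/<asset_id>/norm_stats.json`; this repo mirrors that layout. The sketch below spells out those paths so a partial download is easy to spot. `checkpoint_paths` and `missing_files` are hypothetical helpers, and the lookup rule is an assumption that may change between openpi versions.

```python
from pathlib import Path

# Hypothetical helper: expected layout of a checkpoint dir, assuming asset_id == repo_id.
def checkpoint_paths(checkpoint_dir: str, repo_id: str = "zhuoKCL/prgvla_stack") -> dict:
    root = Path(checkpoint_dir)
    return {
        "params": root / "params",  # orbax params directory
        "norm_stats": root / "assets" / repo_id / "norm_stats.json",
    }

# Hypothetical helper: names of expected entries that do not exist on disk.
def missing_files(checkpoint_dir: str) -> list[str]:
    return [name for name, p in checkpoint_paths(checkpoint_dir).items() if not p.exists()]
```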
## Inference contract (observation dict)
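```python
import numpy as np

H, W = 224, 224  # placeholder image size; real frames come from your cameras

obs = {
    "observation/joint_position": np.zeros(7, dtype=np.float32),  # 7 Franka joint positions
    "observation/gripper_position": 0.0,  # scalar float; becomes a 1-D array internally
    "observation/exterior_image_1_left": np.zeros((H, W, 3), dtype=np.uint8),  # → base_0_rgb
    "observation/wrist_image_left": np.zeros((H, W, 3), dtype=np.uint8),  # → left_wrist_0_rgb
    "prompt": "<natural language sentence>",
}
```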
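On the output side, the model produces a `(16, 32)` action chunk. Assuming upstream's DROID output transform behaves as we understand it (keeping only the first 8 action dimensions), a client receives a `(16, 8)` chunk. The hypothetical helper below applies the same slice defensively, so it is harmless on either the raw or the already-sliced chunk. What the 8 dimensions mean follows the dataset's action feature; check its `meta/info.json` rather than assuming joint positions or velocities.

```python
import numpy as np

ACTION_HORIZON = 16    # from the card
FRANKA_ACTION_DIM = 8  # assumption: matches the 8-dim Franka state layout

def franka_actions(chunk) -> np.ndarray:
    """Keep the first 8 dims of a (16, 32) or (16, 8) action chunk (hypothetical helper)."""
    chunk = np.asarray(chunk)
    if chunk.ndim != 2 or chunk.shape[0] != ACTION_HORIZON or chunk.shape[1] < FRANKA_ACTION_DIM:
        raise ValueError(f"unexpected action chunk shape {chunk.shape}")
    # Dims beyond the first 8 are assumed to be pi0.5 padding and are dropped.
    return chunk[:, :FRANKA_ACTION_DIM]
```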
## Training data
[`zhuoKCL/prgvla_stack`](https://huggingface.co/datasets/zhuoKCL/prgvla_stack) (LeRobot v2.1).