---
license: bsd-3-clause
tags:
- world-model
- dynamics-model
- robotics
- manipulation
datasets:
- villekuosmanen/bin_pick_pack_coffee_capsules
---

# Visual World Model: bin-pick-pack (proprio + visual latents)
|
|
An MLP-based world model (`SystemDynamicsEnsemble`) trained on the bin-pick-pack-coffee-capsules manipulation dataset. Given a history of proprioceptive states, actions, and visual latents, it predicts the next proprioceptive state (17D) and the next visual latent (32D).
|
|
## Architecture


- **Type**: MLP ensemble (2 heads)
- **State dim**: 17 (7 joint positions + 10D end-effector (EEF) pose: xyz position (3) + rot6d orientation (6) + gripper (1))
- **Action dim**: 17 (same decomposition as the state)
- **Visual dim**: 32 (LAM-encoded visual latents from the front camera)
- **History horizon**: 2 steps
- **Forecast horizon**: 1 step
- **Checkpoint size**: 1.9 MB
|
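The 17D state (and action) vector can be unpacked into its named parts as below. This is a minimal sketch under stated assumptions: the card lists the components but not their order, so the slice layout in `STATE_LAYOUT` and the helper names are hypothetical. `rot6d_to_matrix` uses the Gram-Schmidt construction from Zhou et al. (CVPR 2019) and assumes the six numbers are the first two columns of the rotation matrix, concatenated. Check both assumptions against the dataset's feature definitions before relying on them.

```python
import numpy as np

# HYPOTHETICAL layout: component order is assumed, not documented in this card.
STATE_LAYOUT = {
    "joint_pos": slice(0, 7),    # 7 joint positions
    "eef_xyz": slice(7, 10),     # end-effector position (3)
    "eef_rot6d": slice(10, 16),  # end-effector orientation, 6D representation (6)
    "gripper": slice(16, 17),    # gripper (1)
}
STATE_DIM = 17


def split_state(x: np.ndarray) -> dict:
    """Split a (..., 17) state or action array into named components."""
    assert x.shape[-1] == STATE_DIM, f"expected last dim {STATE_DIM}, got {x.shape[-1]}"
    return {name: x[..., s] for name, s in STATE_LAYOUT.items()}


def rot6d_to_matrix(r6: np.ndarray) -> np.ndarray:
    """Map (..., 6) rot6d values to (..., 3, 3) rotation matrices via Gram-Schmidt.

    ASSUMPTION: r6 = [first column, second column] of the rotation matrix.
    Codebases that store rows instead would need a transpose of the result.
    """
    a1, a2 = r6[..., 0:3], r6[..., 3:6]
    b1 = a1 / np.linalg.norm(a1, axis=-1, keepdims=True)
    # Remove the b1 component from a2, then normalize.
    a2_perp = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
    b2 = a2_perp / np.linalg.norm(a2_perp, axis=-1, keepdims=True)
    b3 = np.cross(b1, b2)  # right-handed third axis
    return np.stack([b1, b2, b3], axis=-1)  # b1, b2, b3 become the columns
```

The Gram-Schmidt step always yields a valid (orthonormal, determinant +1) rotation, even if the model's predicted rot6d values are not exactly orthonormal.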
|
## Training


- **Dataset**: [villekuosmanen/bin_pick_pack_coffee_capsules](https://huggingface.co/datasets/villekuosmanen/bin_pick_pack_coffee_capsules) (47,865 frames, 200 episodes)
- **Visual latents**: Precomputed with a fine-tuned LAM encoder ([pravsels/lam-binpack-finetune](https://huggingface.co/pravsels/lam-binpack-finetune))
- **Split**: 35,387 train / 12,078 val sequences (val_ratio=0.25, seed=0)
- **Epochs**: 50
- **Batch size**: 64
- **Learning rate**: 3e-4 (cosine schedule, min_lr=3e-5)
- **Final train loss**: 0.09442
- **Final val loss**: 0.09916
- **Visual loss**: 0.07207
- **W&B**: [pravsels/binpack-world-model/runs/2pq0n2mx](https://wandb.ai/pravsels/binpack-world-model/runs/2pq0n2mx)
|
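The split sizes line up with the history and forecast horizons. Assuming each sequence is a stride-1 sliding window of `history_horizon + forecast_horizon` = 3 consecutive frames that never crosses an episode boundary (an assumption; the card does not describe the windowing), an episode of T frames contributes T - 2 sequences. The helper name `num_sequences` is illustrative:

```python
def num_sequences(total_frames: int, num_episodes: int,
                  history_horizon: int = 2, forecast_horizon: int = 1) -> int:
    """Count sliding windows of history + forecast frames that stay inside episodes.

    ASSUMPTION: windows are contiguous, stride 1, and never cross episode
    boundaries, so an episode of T frames yields T - window + 1 windows.
    """
    window = history_horizon + forecast_horizon
    return total_frames - num_episodes * (window - 1)


total = num_sequences(47865, 200)
print(total, 35387 + 12078)  # 47465 47465
```

That is, 47,865 - 200 × 2 = 47,465, which equals 35,387 train + 12,078 val sequences.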
|
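For reference, a standard cosine decay from 3e-4 down to `min_lr` = 3e-5 over the 50 epochs follows the closed form below. This is a sketch assuming one update per epoch and no warmup; the card does not say whether the trainer steps the schedule per epoch or per iteration, and `cosine_lr` is an illustrative name.

```python
import math


def cosine_lr(epoch: float, total_epochs: int = 50,
              base_lr: float = 3e-4, min_lr: float = 3e-5) -> float:
    """Cosine decay from base_lr at epoch 0 to min_lr at epoch total_epochs."""
    progress = min(max(epoch, 0), total_epochs) / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for e in (0, 25, 50):
    print(e, f"{cosine_lr(e):.3e}")
```

In PyTorch, `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=3e-5)`, stepped once per epoch, traces the same curve.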
## Files


| File | Description |
|------|-------------|
| `best.pt` | Best checkpoint (epoch 50) |
| `config.yaml` | Training configuration (Isambard) |
|
|
## Checkpoint format


```python
import torch

checkpoint = torch.load("best.pt", map_location="cpu")
# checkpoint["model_state_dict"]: weights for SystemDynamicsEnsemble.load_state_dict()
# checkpoint["epoch"], checkpoint["train_loss"], checkpoint["val_loss"], etc.: training metadata
```
|
|
## Integrity


```
sha256: 050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410 best.pt
```


Verified by running `sha256sum` twice on the source file.
|
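To check a downloaded copy where `sha256sum` is not available, the same digest can be computed with Python's standard-library `hashlib`. A minimal sketch; `sha256_of_file` and `verify_checkpoint` are illustrative helper names, not part of any package:

```python
import hashlib

EXPECTED_SHA256 = "050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410"


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 of a file, read in 1 MiB chunks to bound memory use."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_checkpoint(path: str = "best.pt", expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's digest matches the published hash."""
    return sha256_of_file(path) == expected
```

Calling `verify_checkpoint("best.pt")` in the download directory returns `True` when the file matches the hash above.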
|
## Usage


```python
import torch
from rsl_rl.offline.offline_world_model_trainer import build_system_dynamics_model

model = build_system_dynamics_model(
    state_dim=17, action_dim=17, visual_dim=32,
    ensemble_size=2, history_horizon=2, device="cpu",
)
ckpt = torch.load("best.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Placeholder inputs (batch=1, history=2); replace with real histories
# from the dataset and the LAM encoder.
state_hist = torch.zeros(1, 2, 17)   # proprioceptive states
action_hist = torch.zeros(1, 2, 17)  # actions
visual_hist = torch.zeros(1, 2, 32)  # LAM visual latents

# Single-step prediction: the model returns 7 outputs; only the next-state
# and next-visual-latent predictions are used here.
with torch.no_grad():
    state_pred, _, _, _, _, _, visual_pred = model(state_hist, action_hist, visual_hist)
```
|
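Because the forecast horizon is 1, longer predictions have to be produced autoregressively: each prediction is appended to the history window and the oldest entry is dropped. The sketch below shows that sliding-window loop. It is hypothetical: `rollout` and `predict_next` are illustrative names, `predict_next` is assumed to return the next state as `(1, 17)` and the next visual latent as `(1, 32)`, and the caller supplies the future actions.

```python
from typing import Callable, Tuple

import torch

Tensor = torch.Tensor
# Hypothetical one-step predictor: (state_hist, action_hist, visual_hist) -> (next_state, next_visual)
StepFn = Callable[[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]


@torch.no_grad()
def rollout(predict_next: StepFn, state_hist: Tensor, action_hist: Tensor,
            visual_hist: Tensor, future_actions: Tensor) -> Tuple[Tensor, Tensor]:
    """Roll out K + 1 steps by sliding the history window.

    state_hist: (1, H, 17), action_hist: (1, H, 17), visual_hist: (1, H, 32),
    future_actions: (1, K, 17), the action taken after each predicted step.
    Returns predicted states (1, K + 1, 17) and visual latents (1, K + 1, 32).
    """
    states, visuals = [], []
    for k in range(future_actions.shape[1] + 1):
        if k > 0:
            # Drop the oldest entry and append the newest one along the time axis.
            state_hist = torch.cat([state_hist[:, 1:], s_next.unsqueeze(1)], dim=1)
            visual_hist = torch.cat([visual_hist[:, 1:], v_next.unsqueeze(1)], dim=1)
            action_hist = torch.cat([action_hist[:, 1:], future_actions[:, k - 1 : k]], dim=1)
        s_next, v_next = predict_next(state_hist, action_hist, visual_hist)
        states.append(s_next)
        visuals.append(v_next)
    return torch.stack(states, dim=1), torch.stack(visuals, dim=1)
```

With the model loaded above, `predict_next` could wrap `model(...)` and keep its first and last outputs, after confirming their shapes (the ensemble may add a member dimension that needs to be averaged or selected). Prediction errors can compound over long rollouts.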
|