---
license: bsd-3-clause
tags:
- world-model
- dynamics-model
- robotics
- manipulation
datasets:
- villekuosmanen/bin_pick_pack_coffee_capsules
---
# Visual World Model — bin-pick-pack (proprio + visual latents)
MLP-based world model (`SystemDynamicsEnsemble`) trained on the bin-pick-pack-coffee-capsules manipulation dataset. Given a history of proprioceptive states, actions, and visual latents, it predicts the next proprioceptive state (17-D) and the next visual latent (32-D).
## Architecture
- **Type**: MLP ensemble (2 heads)
- **State dim**: 17 (7 joint positions + 10-D end-effector (EEF) pose: xyz position (3) + 6D rotation `rot6d` (6) + gripper (1))
- **Action dim**: 17 (same decomposition as the state)
- **Visual dim**: 32 (LAM-encoded visual latents from the front camera)
- **History horizon**: 2
- **Forecast horizon**: 1
- **Checkpoint size**: 1.9 MB
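The 17-D state layout above can be unpacked as in the sketch below. It assumes the components are stored in the order listed (joints, then xyz, then rot6d, then gripper) and that `rot6d` holds the first two rows of the rotation matrix, as in PyTorch3D's `rotation_6d_to_matrix`; some codebases use the first two columns instead, so check against the data pipeline. `split_state` and `rot6d_to_matrix` are hypothetical helpers, not part of this repository.

```python
import numpy as np


def split_state(state):
    """Split a 17-D state/action vector into named parts (assumed ordering)."""
    state = np.asarray(state, dtype=float)
    if state.shape[-1] != 17:
        raise ValueError(f"expected last dim 17, got {state.shape[-1]}")
    return {
        "joints": state[..., 0:7],      # 7 joint positions
        "eef_xyz": state[..., 7:10],    # EEF position
        "eef_rot6d": state[..., 10:16],  # EEF rotation, 6D representation
        "gripper": state[..., 16:17],   # gripper command/state
    }


def rot6d_to_matrix(rot6d):
    """Map a 6D rotation to a 3x3 rotation matrix via Gram-Schmidt.

    Treats the two 3-vectors as the first two rows (PyTorch3D convention).
    """
    rot6d = np.asarray(rot6d, dtype=float)
    a1, a2 = rot6d[..., :3], rot6d[..., 3:]
    b1 = a1 / np.linalg.norm(a1, axis=-1, keepdims=True)
    # Remove the b1 component from a2, then normalize
    b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
    b2 = b2 / np.linalg.norm(b2, axis=-1, keepdims=True)
    b3 = np.cross(b1, b2)  # right-handed third axis
    return np.stack([b1, b2, b3], axis=-2)
```

Gram-Schmidt guarantees an orthonormal, right-handed matrix even when the network's 6D output is not exactly orthonormal, which is the usual reason for predicting rotations in this form.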
## Training
- **Dataset**: [villekuosmanen/bin_pick_pack_coffee_capsules](https://huggingface.co/datasets/villekuosmanen/bin_pick_pack_coffee_capsules) — 47865 frames, 200 episodes
- **Visual latents**: Precomputed from fine-tuned LAM encoder ([pravsels/lam-binpack-finetune](https://huggingface.co/pravsels/lam-binpack-finetune))
- **Split**: 35387 train / 12078 val sequences (val_ratio=0.25, seed=0)
- **Epochs**: 50
- **Batch size**: 64
- **Learning rate**: 3e-4 (cosine schedule, min_lr=3e-5)
- **Final train loss**: 0.09442
- **Final val loss**: 0.09916
- **Visual loss**: 0.07207
- **W&B**: [pravsels/binpack-world-model/runs/2pq0n2mx](https://wandb.ai/pravsels/binpack-world-model/runs/2pq0n2mx)
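The split sizes are consistent with sliding windows over each episode. Assuming each training sequence is `history_horizon + forecast_horizon = 3` consecutive frames taken from within a single episode, an episode of T frames yields T - 2 windows, so 200 episodes give 47865 - 200 * 2 = 47465 windows, which equals 35387 + 12078. The validation share is about 0.254 rather than exactly 0.25, as would be expected if whole episodes (not individual windows) were assigned to the validation set; the card does not state which. The helpers below are hypothetical illustrations of this arithmetic, not the training code.

```python
def windows_per_episode(episode_len, history_horizon=2, forecast_horizon=1):
    """Number of full (history + forecast) windows in one episode."""
    window = history_horizon + forecast_horizon
    return max(0, episode_len - window + 1)


def sliding_windows(episode_len, history_horizon=2, forecast_horizon=1):
    """Yield (history_indices, target_indices) frame indices for one episode."""
    for start in range(windows_per_episode(episode_len, history_horizon, forecast_horizon)):
        hist = list(range(start, start + history_horizon))
        target = list(range(start + history_horizon,
                            start + history_horizon + forecast_horizon))
        yield hist, target


def total_windows(total_frames, num_episodes, history_horizon=2, forecast_horizon=1):
    """Total windows, assuming every episode has at least one full window."""
    return total_frames - num_episodes * (history_horizon + forecast_horizon - 1)


print(total_windows(47865, 200))  # 47465 (= 35387 + 12078)
```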
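The learning-rate schedule can be sketched as standard cosine annealing from 3e-4 down to `min_lr=3e-5`, the same curve as PyTorch's `CosineAnnealingLR` with `eta_min=3e-5`. The sketch assumes no warmup and one schedule step per epoch over the 50 epochs; whether the actual run stepped per batch or used warmup is not stated. `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(step, total_steps, base_lr=3e-4, min_lr=3e-5):
    """Cosine-annealed learning rate at `step` out of `total_steps` (no warmup)."""
    progress = min(max(step / total_steps, 0.0), 1.0)  # clamp to [0, 1]
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 25, 50):
    print(epoch, f"{cosine_lr(epoch, 50):.3e}")
# 0 3.000e-04
# 25 1.650e-04
# 50 3.000e-05
```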
## Files
| File | Description |
|------|-------------|
| `best.pt` | Best checkpoint (epoch 50) |
| `config.yaml` | Training configuration (as run on Isambard) |
## Checkpoint format
```python
import torch

checkpoint = torch.load("best.pt", map_location="cpu")
# checkpoint["model_state_dict"]: weights for SystemDynamicsEnsemble.load_state_dict()
# Metadata: checkpoint["epoch"], checkpoint["train_loss"], checkpoint["val_loss"], etc.
print(sorted(checkpoint.keys()))  # list every stored entry
```
## Integrity
```
sha256: 050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410 best.pt
```
Verified by running `sha256sum` twice on the source file.
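To check a downloaded copy from Python, a minimal sketch using only the standard library's `hashlib` is shown below; it streams the file in chunks so large checkpoints need not fit in memory. `sha256_of_file` and `verify_checkpoint` are hypothetical helpers, not part of this repository.

```python
import hashlib

EXPECTED_SHA256 = "050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410"


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path="best.pt", expected=EXPECTED_SHA256):
    """True if the file's digest matches the expected hex string."""
    return sha256_of_file(path) == expected.lower()
```

Calling `verify_checkpoint("best.pt")` should return `True` for an intact download; on the command line, `sha256sum best.pt` should print the same digest.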
## Usage
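```python
import torch
from rsl_rl.offline.offline_world_model_trainer import build_system_dynamics_model

# Rebuild the ensemble with the dimensions used in training
model = build_system_dynamics_model(
    state_dim=17, action_dim=17, visual_dim=32,
    ensemble_size=2, history_horizon=2, device="cpu",
)
ckpt = torch.load("best.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Single-step prediction from a 2-step history (batch size 1).
# Zero tensors are placeholders; substitute real histories.
state_hist = torch.zeros(1, 2, 17)   # (batch, history, state_dim)
action_hist = torch.zeros(1, 2, 17)  # (batch, history, action_dim)
visual_hist = torch.zeros(1, 2, 32)  # (batch, history, visual_dim)
with torch.no_grad():
    state_pred, _, _, _, _, _, visual_pred = model(state_hist, action_hist, visual_hist)
# state_pred: next proprioceptive state; visual_pred: next visual latent
```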
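Because the forecast horizon is 1, multi-step predictions have to be made autoregressively by feeding each prediction back into the history window. The sketch below is a hypothetical `rollout` helper (not part of `rsl_rl`). It assumes, following the single-step unpacking above, that the model's first output is the next state and its last output is the next visual latent, each holding one step per batch element, and that predictions live in the same space (e.g. absolute values, same normalization) as the inputs; if the model predicts deltas or returns per-member ensemble outputs, the feedback step must be adapted.

```python
import torch


@torch.no_grad()
def rollout(model, state_hist, action_hist, visual_hist, future_actions):
    """Autoregressive multi-step prediction.

    state_hist: (B, H, 17), action_hist: (B, H, 17), visual_hist: (B, H, 32)
    future_actions: (B, K, 17); future_actions[:, k] enters the action window after step k.
    Returns predicted states (B, K + 1, 17) and visual latents (B, K + 1, 32).
    """
    batch = state_hist.shape[0]
    states, visuals = [], []
    num_steps = future_actions.shape[1] + 1
    for k in range(num_steps):
        outputs = model(state_hist, action_hist, visual_hist)
        # Assumed: outputs[0] = next state, outputs[-1] = next visual latent
        state_next = outputs[0].reshape(batch, 1, state_hist.shape[-1])
        visual_next = outputs[-1].reshape(batch, 1, visual_hist.shape[-1])
        states.append(state_next)
        visuals.append(visual_next)
        if k < num_steps - 1:
            # Slide each history window forward by one timestep
            state_hist = torch.cat([state_hist[:, 1:], state_next], dim=1)
            visual_hist = torch.cat([visual_hist[:, 1:], visual_next], dim=1)
            action_hist = torch.cat([action_hist[:, 1:], future_actions[:, k:k + 1]], dim=1)
    return torch.cat(states, dim=1), torch.cat(visuals, dim=1)
```

For example, `rollout(model, state_hist, action_hist, visual_hist, torch.zeros(1, 4, 17))` would return five predicted steps. Errors compound over long rollouts, so short horizons are usually more reliable with a one-step model.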