---
license: bsd-3-clause
tags:
  - world-model
  - dynamics-model
  - robotics
  - manipulation
datasets:
  - villekuosmanen/bin_pick_pack_coffee_capsules
---

# Visual World Model — bin-pick-pack (proprio + visual latents)

MLP-based world model (`SystemDynamicsEnsemble`) trained on the bin-pick-pack-coffee-capsules manipulation dataset. Predicts next proprioceptive state (17D) and next visual latent (32D) given a history of states, actions, and visual latents.

## Architecture

- **Type**: MLP ensemble (2 heads, `ensemble_size=2`)
- **State dim**: 17 (7 joint positions + 10D end-effector (EEF) pose: 3 xyz + 6 rot6d + 1 gripper)
- **Action dim**: 17 (same decomposition)
- **Visual dim**: 32 (LAM-encoded visual latents from the front camera)
- **History horizon**: 2
- **Forecast horizon**: 1
- **Checkpoint size**: 1.9 MB
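With a history horizon of 2 and a forecast horizon of 1, each sample pairs two consecutive steps of state, action, and visual latent with the following state and latent. The sketch below shows one way to cut such windows from a single episode. The helper name `make_windows` and the alignment it uses (actions indexed like states, target at the step right after the last history step) are assumptions for illustration, not taken from the training code.

```python
import numpy as np


def make_windows(states, actions, visuals, history=2, forecast=1):
    """Cut sliding (history -> forecast) windows from one episode.

    Hypothetical helper: assumes states, actions and visuals are aligned
    per frame with shapes (T, 17), (T, 17) and (T, 32).
    """
    T = states.shape[0]
    n = T - history - forecast + 1  # number of full windows in this episode
    idx = np.arange(n)[:, None] + np.arange(history)[None, :]             # (n, history)
    tgt = np.arange(n)[:, None] + history + np.arange(forecast)[None, :]  # (n, forecast)
    return {
        "state_hist": states[idx],    # (n, history, 17)
        "action_hist": actions[idx],  # (n, history, 17)
        "visual_hist": visuals[idx],  # (n, history, 32)
        "state_next": states[tgt],    # (n, forecast, 17)
        "visual_next": visuals[tgt],  # (n, forecast, 32)
    }


T = 10
w = make_windows(np.zeros((T, 17)), np.zeros((T, 17)), np.zeros((T, 32)))
print(w["state_hist"].shape, w["visual_next"].shape)  # (8, 2, 17) (8, 1, 32)
```

Under this assumed windowing, an episode of T frames yields T - 2 samples, so 200 episodes give 47865 - 200 × 2 = 47465 windows, which matches 35387 + 12078 in the split below. That agreement is consistent with per-episode windowing without padding, but does not confirm it.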

## Training

- **Dataset**: [villekuosmanen/bin_pick_pack_coffee_capsules](https://huggingface.co/datasets/villekuosmanen/bin_pick_pack_coffee_capsules) — 47865 frames, 200 episodes
- **Visual latents**: Precomputed from fine-tuned LAM encoder ([pravsels/lam-binpack-finetune](https://huggingface.co/pravsels/lam-binpack-finetune))
- **Split**: 35387 train / 12078 val sequences (val_ratio=0.25, seed=0)
- **Epochs**: 50
- **Batch size**: 64
- **Learning rate**: 3e-4 (cosine schedule, min_lr=3e-5)
- **Final train loss**: 0.09442
- **Final val loss**: 0.09916
- **Visual loss**: 0.07207
- **W&B**: [pravsels/binpack-world-model/runs/2pq0n2mx](https://wandb.ai/pravsels/binpack-world-model/runs/2pq0n2mx)
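The learning-rate entry describes a cosine decay from 3e-4 down to a floor of 3e-5. Below is a minimal sketch of the standard closed-form cosine annealing curve, assuming it is evaluated once per epoch over the 50 epochs with no warmup; whether the actual run stepped per epoch or per iteration is not stated in this card.

```python
import math


def cosine_lr(epoch, total_epochs=50, base_lr=3e-4, min_lr=3e-5):
    """Cosine annealing from base_lr (epoch 0) to min_lr (epoch == total_epochs)."""
    progress = min(epoch, total_epochs) / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for e in (0, 25, 50):
    print(e, f"{cosine_lr(e):.3e}")
# 0 3.000e-04
# 25 1.650e-04
# 50 3.000e-05
```

In PyTorch, `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=3e-5)` stepped once per epoch follows the same curve.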

## Files

| File | Description |
|------|-------------|
| `best.pt` | Best checkpoint (epoch 50) |
| `config.yaml` | Training configuration used for the Isambard run |

## Checkpoint format

```python
import torch

checkpoint = torch.load("best.pt", map_location="cpu")
# checkpoint["model_state_dict"] -> pass to SystemDynamicsEnsemble.load_state_dict()
# checkpoint["epoch"], checkpoint["train_loss"], checkpoint["val_loss"], etc. hold training metadata
```

## Integrity

```
sha256: 050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410  best.pt
```

Verified by running `sha256sum` twice on the source file.
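To check a downloaded copy against the published digest without `sha256sum`, a minimal standard-library Python sketch could look like this. `sha256_file` and `verify_checkpoint` are hypothetical helper names; the file is read in chunks so it never needs to fit in memory at once.

```python
import hashlib

EXPECTED_SHA256 = "050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410"


def sha256_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_checkpoint(path="best.pt", expected=EXPECTED_SHA256):
    """True if the file's digest matches the expected hex string (case-insensitive)."""
    return sha256_file(path) == expected.lower()
```

For an intact download, `verify_checkpoint("best.pt")` should return `True`.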

## Usage

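```python
import torch
from rsl_rl.offline.offline_world_model_trainer import build_system_dynamics_model

# Build the architecture with the dimensions listed above, then load the weights.
model = build_system_dynamics_model(
    state_dim=17, action_dim=17, visual_dim=32,
    ensemble_size=2, history_horizon=2, device="cpu",
)
ckpt = torch.load("best.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Single-step prediction from a 2-step history (batch size 1).
# Zero tensors are placeholders; substitute real states, actions and LAM latents.
state_hist = torch.zeros(1, 2, 17)   # (batch, history, state_dim)
action_hist = torch.zeros(1, 2, 17)  # (batch, history, action_dim)
visual_hist = torch.zeros(1, 2, 32)  # (batch, history, visual_dim)
with torch.no_grad():
    # The model returns 7 outputs; only the state and visual predictions are kept here.
    state_pred, _, _, _, _, _, visual_pred = model(state_hist, action_hist, visual_hist)
```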
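For multi-step prediction, one option is to feed each one-step prediction back into the history window. This is a sketch, not part of the released code: `shift_history` and `rollout` are hypothetical helpers, `step_fn` stands in for the model call, and it assumes the state and visual predictions have shapes (B, 17) and (B, 32). A toy `step_fn` keeps the example runnable without the checkpoint.

```python
import numpy as np


def shift_history(hist, new_item):
    """Drop the oldest step of a (B, H, D) history and append new_item of shape (B, D)."""
    return np.concatenate([hist[:, 1:], new_item[:, None]], axis=1)


def rollout(step_fn, state_hist, action_hist, visual_hist, future_actions, horizon):
    """Autoregressively predict `horizon` steps.

    future_actions: (horizon - 1, B, action_dim), the action applied after
    each predicted step except the last.
    """
    states, visuals = [], []
    for k in range(horizon):
        s_next, v_next = step_fn(state_hist, action_hist, visual_hist)
        states.append(s_next)
        visuals.append(v_next)
        if k + 1 < horizon:
            state_hist = shift_history(state_hist, s_next)
            visual_hist = shift_history(visual_hist, v_next)
            action_hist = shift_history(action_hist, future_actions[k])
    return np.stack(states, axis=1), np.stack(visuals, axis=1)  # (B, horizon, D)


def toy_step(state_hist, action_hist, visual_hist):
    """Stand-in dynamics: next state = last state + last action; latent unchanged."""
    return state_hist[:, -1] + action_hist[:, -1], visual_hist[:, -1]


s, v = rollout(toy_step, np.zeros((1, 2, 17)), np.ones((1, 2, 17)),
               np.zeros((1, 2, 32)), np.ones((2, 1, 17)), horizon=3)
print(s[0, :, 0])  # [1. 2. 3.]
```

With the torch model, `np.concatenate`/`np.stack` become `torch.cat`/`torch.stack`, and `step_fn` would wrap the forward pass (under `torch.no_grad()`) to return its first and last outputs, `state_pred` and `visual_pred`. Since the model was trained with a forecast horizon of 1, errors may compound over longer rollouts.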