pravsels committed · Commit f98b59c · verified · 1 Parent(s): 156978a

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +79 -0
README.md ADDED
@@ -0,0 +1,79 @@
+ ---
+ license: bsd-3-clause
+ tags:
+ - world-model
+ - dynamics-model
+ - robotics
+ - manipulation
+ datasets:
+ - villekuosmanen/bin_pick_pack_coffee_capsules
+ ---
+
+ # Visual World Model: bin-pick-pack (proprio + visual latents)
+
+ MLP-based world model (`SystemDynamicsEnsemble`) trained on the bin-pick-pack-coffee-capsules manipulation dataset. It predicts the next proprioceptive state (17D) and the next visual latent (32D) from a history of states, actions, and visual latents.
+
+ ## Architecture
+
+ - **Type**: MLP ensemble (2 heads)
+ - **State dim**: 17 (7 joint positions + 10 EEF pose: 3 xyz + 6 rot6d + 1 gripper)
+ - **Action dim**: 17 (same decomposition)
+ - **Visual dim**: 32 (LAM-encoded visual latents from front camera)
+ - **History horizon**: 2
+ - **Forecast horizon**: 1
+ - **Checkpoint size**: 1.9 MB
+
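The 17D layout listed above can be unpacked as in the following sketch. The ordering of the four blocks (joints, xyz, rot6d, gripper) and the helper names `split_state` and `rot6d_to_axes` are assumptions for illustration; the README states only the sizes. The rot6d conversion uses the common Gram-Schmidt construction, and the dataset's actual convention (for example, whether the axes are rows or columns of the rotation matrix) may differ.

```python
import math

# Block sizes from the README; the ORDER of the blocks is an assumption.
LAYOUT = (("joints", 7), ("xyz", 3), ("rot6d", 6), ("gripper", 1))


def split_state(state):
    """Split a 17D state (or action) vector into named blocks."""
    assert len(state) == sum(dim for _, dim in LAYOUT)
    parts, start = {}, 0
    for name, dim in LAYOUT:
        parts[name] = list(state[start:start + dim])
        start += dim
    return parts


def _normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _cross(a, b):
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]


def rot6d_to_axes(rot6d):
    """Gram-Schmidt on two 3-vectors; returns three orthonormal axes."""
    a1, a2 = rot6d[:3], rot6d[3:]
    b1 = _normalize(a1)
    dot = sum(x * y for x, y in zip(b1, a2))
    b2 = _normalize([y - dot * x for x, y in zip(b1, a2)])
    b3 = _cross(b1, b2)
    return [b1, b2, b3]


# Hypothetical example vector: zero joints and position, identity rotation.
example = [0.0] * 7 + [0.0] * 3 + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0] + [0.5]
parts = split_state(example)
axes = rot6d_to_axes(parts["rot6d"])
```
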
+ ## Training
+
+ - **Dataset**: [villekuosmanen/bin_pick_pack_coffee_capsules](https://huggingface.co/datasets/villekuosmanen/bin_pick_pack_coffee_capsules) (47865 frames, 200 episodes)
+ - **Visual latents**: Precomputed from fine-tuned LAM encoder ([pravsels/lam-binpack-finetune](https://huggingface.co/pravsels/lam-binpack-finetune))
+ - **Split**: 35387 train / 12078 val sequences (val_ratio=0.25, seed=0)
+ - **Epochs**: 50
+ - **Batch size**: 64
+ - **Learning rate**: 3e-4 (cosine schedule, min_lr=3e-5)
+ - **Final train loss**: 0.09442
+ - **Final val loss**: 0.09916
+ - **Visual loss**: 0.07207
+ - **W&B**: [pravsels/binpack-world-model/runs/2pq0n2mx](https://wandb.ai/pravsels/binpack-world-model/runs/2pq0n2mx)
+
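Two of the numbers in the list above can be sanity-checked with a few lines of Python. This is a sketch under two assumptions that the README does not confirm: that training sequences are sliding windows of length history + forecast taken within each episode, and that the schedule is standard cosine annealing from 3e-4 down to 3e-5 over the 50 epochs. The function names are hypothetical.

```python
import math


def num_windows(frames, episodes, history=2, forecast=1):
    # An episode of length T yields T - (history + forecast - 1) windows,
    # assuming every episode has at least history + forecast frames.
    return frames - episodes * (history + forecast - 1)


def cosine_lr(step, total_steps, base_lr=3e-4, min_lr=3e-5):
    # Closed-form cosine annealing from base_lr down to min_lr.
    progress = min(max(step / total_steps, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


print(num_windows(47865, 200))  # 47465, which equals 35387 + 12078
print([cosine_lr(epoch, 50) for epoch in (0, 25, 50)])
```

Under the windowing assumption, the window count matches the reported train + val total exactly, which suggests the split counts sequences rather than raw frames.
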
+ ## Files
+
+ | File | Description |
+ |------|-------------|
+ | `best.pt` | Best checkpoint (epoch 50) |
+ | `config.yaml` | Training configuration (Isambard) |
+
+ ## Checkpoint format
+
+ ```python
+ import torch
+
+ checkpoint = torch.load("best.pt", map_location="cpu")
+ # checkpoint["model_state_dict"] -> pass to SystemDynamicsEnsemble.load_state_dict()
+ # checkpoint["epoch"], checkpoint["train_loss"], checkpoint["val_loss"], etc. hold training metadata
+ ```
+
+ ## Integrity
+
+ ```
+ sha256: 050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410 best.pt
+ ```
+
+ Verified by running `sha256sum` twice on the source file.
+
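For environments without `sha256sum`, the same check can be sketched with Python's standard `hashlib`. The helper names `sha256_of_file` and `verify` are illustrative, and the local path `best.pt` is an assumption about where the checkpoint was downloaded.

```python
import hashlib

# Digest published in the Integrity section above.
EXPECTED_SHA256 = "050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410"


def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in chunks so large checkpoints need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """Return True when the file's SHA-256 digest matches the expected one."""
    return sha256_of_file(path) == expected
```

Calling `verify("best.pt")` before `torch.load` gives an early failure if the download was truncated or altered.
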
+ ## Usage
+
+ ```python
+ from rsl_rl.offline.offline_world_model_trainer import build_system_dynamics_model
+ import torch
+
+ model = build_system_dynamics_model(
+     state_dim=17, action_dim=17, visual_dim=32,
+     ensemble_size=2, history_horizon=2, device="cpu",
+ )
+ ckpt = torch.load("best.pt", map_location="cpu")
+ model.load_state_dict(ckpt["model_state_dict"])
+ model.eval()
+
+ # Single-step prediction
+ # Shapes: state_hist (1, 2, 17), action_hist (1, 2, 17), visual_hist (1, 2, 32)
+ # Zero tensors are placeholders; substitute real history from the dataset.
+ state_hist = torch.zeros(1, 2, 17)
+ action_hist = torch.zeros(1, 2, 17)
+ visual_hist = torch.zeros(1, 2, 32)
+ with torch.no_grad():
+     state_pred, _, _, _, _, _, visual_pred = model(state_hist, action_hist, visual_hist)
+ ```
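Because the forecast horizon is 1, predicting further ahead means feeding each prediction back into the history. The sketch below shows that sliding-window loop with a hypothetical `step_fn` callable standing in for the model; it is not part of `rsl_rl`, and the alignment between actions and states in the history is an assumption that should be checked against the training code.

```python
from collections import deque


def rollout(step_fn, state_hist, action_hist, visual_hist, future_actions):
    """Autoregressive rollout using a single-step predictor.

    step_fn is a hypothetical callable:
        (states, actions, visuals) -> (next_state, next_visual)
    Each history is a sequence with one entry per time step.
    """
    horizon = len(state_hist)
    # Fixed-length deques drop the oldest entry when a new one is appended.
    states = deque(state_hist, maxlen=horizon)
    actions = deque(action_hist, maxlen=horizon)
    visuals = deque(visual_hist, maxlen=horizon)

    predictions = []
    for action in future_actions:
        next_state, next_visual = step_fn(list(states), list(actions), list(visuals))
        predictions.append((next_state, next_visual))
        states.append(next_state)
        visuals.append(next_visual)
        actions.append(action)
    return predictions


# Toy predictor on scalars, used only to exercise the loop.
def toy_step(states, actions, visuals):
    return states[-1] + actions[-1], visuals[-1]


preds = rollout(toy_step, [0, 1], [1, 1], [0, 0], future_actions=[2, 3])
```

Prediction errors compound over such a rollout, so longer horizons are worth validating against held-out episodes.
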