---
license: bsd-3-clause
tags:
- world-model
- dynamics-model
- robotics
- manipulation
datasets:
- villekuosmanen/bin_pick_pack_coffee_capsules
---

# Visual World Model: bin-pick-pack (proprio + visual latents)

MLP-based world model (`SystemDynamicsEnsemble`) trained on the bin-pick-pack-coffee-capsules manipulation dataset. Given a history of states, actions, and visual latents, it predicts the next proprioceptive state (17D) and the next visual latent (32D).

## Architecture

- **Type**: MLP ensemble (2 heads)
- **State dim**: 17 (7 joint positions + 10 end-effector (EEF) pose values: xyz, rot6d, gripper)
- **Action dim**: 17 (same decomposition)
- **Visual dim**: 32 (LAM-encoded visual latents from the front camera)
- **History horizon**: 2
- **Forecast horizon**: 1
- **Checkpoint size**: 1.9 MB

## Training

- **Dataset**: [villekuosmanen/bin_pick_pack_coffee_capsules](https://huggingface.co/datasets/villekuosmanen/bin_pick_pack_coffee_capsules) (47865 frames, 200 episodes)
- **Visual latents**: Precomputed with the fine-tuned LAM encoder ([pravsels/lam-binpack-finetune](https://huggingface.co/pravsels/lam-binpack-finetune))
- **Split**: 35387 train / 12078 val sequences (val_ratio=0.25, seed=0)
- **Epochs**: 50
- **Batch size**: 64
- **Learning rate**: 3e-4 (cosine schedule, min_lr=3e-5)
- **Final train loss**: 0.09442
- **Final val loss**: 0.09916
- **Visual loss**: 0.07207
- **W&B**: [pravsels/binpack-world-model/runs/2pq0n2mx](https://wandb.ai/pravsels/binpack-world-model/runs/2pq0n2mx)

## Files

| File | Description |
|------|-------------|
| `best.pt` | Best checkpoint (epoch 50) |
| `config.yaml` | Training configuration (Isambard) |

## Checkpoint format

```python
import torch

checkpoint = torch.load("best.pt", map_location="cpu")
# checkpoint["model_state_dict"] -> SystemDynamicsEnsemble.load_state_dict()
# checkpoint["epoch"], checkpoint["train_loss"], checkpoint["val_loss"], etc.
```

## Integrity

```
sha256: 050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410 best.pt
```

Verified by running `sha256sum` twice on the source file.

## Usage

```python
import torch
from rsl_rl.offline.offline_world_model_trainer import build_system_dynamics_model

model = build_system_dynamics_model(
    state_dim=17,
    action_dim=17,
    visual_dim=32,
    ensemble_size=2,
    history_horizon=2,
    device="cpu",
)
ckpt = torch.load("best.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Single-step prediction.
# Placeholder inputs with the expected shapes (batch=1, history=2); replace with real data.
state_hist = torch.zeros(1, 2, 17)   # proprioceptive states
action_hist = torch.zeros(1, 2, 17)  # actions
visual_hist = torch.zeros(1, 2, 32)  # LAM visual latents
with torch.no_grad():
    state_pred, _, _, _, _, _, visual_pred = model(state_hist, action_hist, visual_hist)
```
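The sequence counts under Training are consistent with cutting each episode into overlapping windows of 2 history frames plus 1 target frame. Under that rule an episode of T frames yields T - 2 sequences, so 200 episodes lose 400 frames and 47865 - 400 = 47465 = 35387 + 12078. The windowing rule is inferred from these numbers, not taken from the training code, and `num_windows` is a hypothetical helper; the sketch assumes every episode has at least 3 frames.

```python
HISTORY, FORECAST = 2, 1
TOTAL_FRAMES, EPISODES = 47865, 200
TRAIN_SEQS, VAL_SEQS = 35387, 12078


def num_windows(episode_len, history=HISTORY, forecast=FORECAST):
    """Count sliding (input, target) windows in one episode (assumed rule)."""
    # Each window spans `history` input frames followed by `forecast` target frames.
    return max(0, episode_len - (history + forecast) + 1)


# An episode of T >= 3 frames yields T - 2 windows, i.e. it loses 2 frames,
# so the total depends only on the episode count, not on individual lengths.
frames_lost_per_episode = HISTORY + FORECAST - 1
total_windows = TOTAL_FRAMES - EPISODES * frames_lost_per_episode

print(num_windows(10))        # 8
print(total_windows)          # 47465
print(TRAIN_SEQS + VAL_SEQS)  # 47465
```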
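The learning-rate entry under Training names a cosine schedule from 3e-4 down to min_lr=3e-5. The card does not say whether the schedule steps per epoch or per optimizer step, or whether it includes warmup. The sketch below assumes per-epoch cosine annealing over all 50 epochs with no warmup, which matches the closed form of PyTorch's `CosineAnnealingLR`; `cosine_lr` is a hypothetical helper, not part of the training code.

```python
import math


def cosine_lr(epoch, total_epochs=50, base_lr=3e-4, min_lr=3e-5):
    """Cosine annealing from base_lr at epoch 0 to min_lr at total_epochs (assumed, no warmup)."""
    progress = min(epoch, total_epochs) / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for e in (0, 25, 50):
    print(e, f"{cosine_lr(e):.3e}")  # 3.000e-04, 1.650e-04, 3.000e-05
```

The halfway value is the midpoint of the two bounds, since the cosine term is zero at epoch 25.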
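The digest listed under Integrity can also be checked without `sha256sum`, for example on Windows. A minimal standard-library sketch; `sha256_of_file` and `verify_checkpoint` are illustrative names, not part of any released package.

```python
import hashlib

EXPECTED_SHA256 = "050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410"


def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA-256 hex digest matches `expected`."""
    return sha256_of_file(path) == expected.lower()
```

For an intact download, `verify_checkpoint("best.pt")` should return `True`.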