Dataset: villekuosmanen/bin_pick_pack_coffee_capsules
An MLP-based world model (`SystemDynamicsEnsemble`) trained on the bin-pick-pack-coffee-capsules manipulation dataset. It predicts the next proprioceptive state (17-D) and the next visual latent (32-D) from a history of states, actions, and visual latents.
| File | Description |
|---|---|
| `best.pt` | Best checkpoint (epoch 50) |
| `config.yaml` | Training configuration (Isambard) |
```python
import torch

# Load onto the CPU so the checkpoint opens on machines without a GPU.
checkpoint = torch.load("best.pt", map_location="cpu")
# checkpoint["model_state_dict"] -> SystemDynamicsEnsemble.load_state_dict()
# checkpoint["epoch"], checkpoint["train_loss"], checkpoint["val_loss"], etc.
```
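To see what a loaded checkpoint actually contains before wiring it into a model, a small summary helper can be useful. The following is a minimal sketch, not part of this repository: `summarize_checkpoint` is a hypothetical name, and it assumes only what is listed above, namely that the checkpoint is a dict with a `model_state_dict` entry whose values expose a `.shape` attribute. Keys hidden behind the "etc." are not guessed at; any non-scalar value is reported by type name only.

```python
from typing import Any, Dict


def summarize_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Return a compact, printable summary of an already-loaded checkpoint dict.

    Hypothetical helper: it relies only on the checkpoint being a dict and on
    state-dict values exposing `.shape` (as tensors do).
    """
    summary: Dict[str, Any] = {}
    for key, value in checkpoint.items():
        if key == "model_state_dict":
            # Map each parameter name to its shape as a plain tuple.
            summary[key] = {
                name: tuple(getattr(param, "shape", ()))
                for name, param in value.items()
            }
        elif isinstance(value, (int, float, str, bool)) or value is None:
            # Scalars such as the epoch or a loss value are kept as-is.
            summary[key] = value
        else:
            # Anything else (for example optimizer state) is reported by type only.
            summary[key] = f"<{type(value).__name__}>"
    return summary
```

Calling `summarize_checkpoint(checkpoint)` on the dict returned by `torch.load` gives the parameter shapes alongside the scalar metadata, which makes a dimension mismatch easy to spot before `load_state_dict` is called.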
SHA-256 of `best.pt`: `050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410`
Verified by running `sha256sum` twice on the source file.
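Where `sha256sum` is not available, the same check can be done in Python with the standard-library `hashlib` module. This is a minimal sketch: the function names `sha256_of_file` and `verify_checkpoint` are illustrative, and the default path assumes `best.pt` sits in the current working directory.

```python
import hashlib

# Digest published above for best.pt.
EXPECTED_SHA256 = "050dfaffa2c98ff112d6a0d2eba738328bac8b3934863bfecff59d62bd2d2410"


def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so large checkpoints are never fully in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(path="best.pt", expected=EXPECTED_SHA256):
    """Return True if the file at `path` matches the expected SHA-256 digest."""
    return sha256_of_file(path) == expected
```

`verify_checkpoint()` returns `True` for an intact download. A `False` result suggests a truncated or modified file, and the file should be fetched again before it is passed to `torch.load`.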
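```python
from rsl_rl.offline.offline_world_model_trainer import build_system_dynamics_model
import torch

# Build the model; these dimensions must match the checkpoint being loaded.
model = build_system_dynamics_model(
    state_dim=17, action_dim=17, visual_dim=32,
    ensemble_size=2, history_horizon=2, device="cpu",
)
ckpt = torch.load("best.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()  # switch to inference mode

# Single-step prediction
# state_hist: (1, 2, 17), action_hist: (1, 2, 17), visual_hist: (1, 2, 32)
# Placeholder zero inputs with those shapes; replace them with real histories.
state_hist = torch.zeros(1, 2, 17)
action_hist = torch.zeros(1, 2, 17)
visual_hist = torch.zeros(1, 2, 32)
with torch.no_grad():  # no gradients are needed for inference
    state_pred, _, _, _, _, _, visual_pred = model(state_hist, action_hist, visual_hist)
```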
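The single-step call can be chained into a multi-step rollout by feeding each prediction back into the history window. The sketch below is one possible way to do this and is not taken from the repository. `rollout` and `step_fn` are hypothetical names; histories are kept as plain Python sequences of per-step items, oldest first. It also assumes that the last entry of the action history is the action applied at the current step, which this model card does not confirm.

```python
from collections import deque


def rollout(step_fn, state_hist, action_hist, visual_hist, next_actions):
    """Chain single-step predictions into a trajectory.

    step_fn(states, actions, visuals) -> (state_pred, visual_pred) is a
    user-supplied wrapper around the model (hypothetical interface).
    Histories are sequences of per-step items, oldest first.
    Returns len(next_actions) + 1 (state_pred, visual_pred) pairs.
    """
    horizon = len(state_hist)
    # Bounded deques drop the oldest entry automatically on append.
    states = deque(state_hist, maxlen=horizon)
    actions = deque(action_hist, maxlen=horizon)
    visuals = deque(visual_hist, maxlen=horizon)

    # The first prediction uses the supplied histories unchanged.
    state_pred, visual_pred = step_fn(list(states), list(actions), list(visuals))
    trajectory = [(state_pred, visual_pred)]

    for action in next_actions:
        # Slide every window forward by one step, then predict again.
        states.append(state_pred)
        visuals.append(visual_pred)
        actions.append(action)
        state_pred, visual_pred = step_fn(list(states), list(actions), list(visuals))
        trajectory.append((state_pred, visual_pred))
    return trajectory
```

For this model, `step_fn` would stack the per-step tensors along dimension 1, call `model(...)` under `torch.no_grad()`, and return the first and last elements of the output tuple. Prediction errors compound over long rollouts, so it is worth comparing rolled-out states against held-out trajectories before relying on long horizons.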