LeWM Billiards: Trained World Model
Author: Santosh Jaiswal (@hellojais)
Base architecture: LeWM by Lucas Maes et al. (2025)
Training data: hellojais/billiards-worldmodel
Code: hellojais/le-wm
Model variants
| File | embed_dim | Input | λ_aux | Best epoch | val/pred_loss | Notes |
|---|---|---|---|---|---|---|
| lewm_epoch_8_object.ckpt | 192 | 3-ch | - | 8 | 0.00946 | Original full-size transformer |
| lewm_small_epoch_8_object.ckpt | 32 | 3-ch | - | 8 | 0.00280 | Transformer baseline |
| lewm_mamba_best_object.ckpt | 32 | 3-ch | - | best | 0.00340 | Mamba predictor |
| lewm_framestacked_best_object.ckpt | 32 | 9-ch | - | 7 | 0.00594 | Frame-stacking; JEPA eviction occurs |
| lewm_auxloss_full_best_object.ckpt | 32 | 9-ch | 0.1 | 7 | 0.00105 | Aux state supervision; eviction fixed; best model |
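The 3-ch and 9-ch entries in the Input column can be read as a single RGB frame versus several consecutive frames concatenated along the channel axis. The sketch below assumes that 9-ch means three stacked RGB frames, which this card does not state explicitly. The function name `stack_frames`, the 64x64 resolution, and the choice to pad by repeating the first frame are all hypothetical and may differ from the training code.

```python
import numpy as np

def stack_frames(frames: np.ndarray, k: int = 3) -> np.ndarray:
    """Stack each frame with its k-1 predecessors along the channel axis.

    frames: array of shape (T, C, H, W). Returns shape (T, k*C, H, W).
    The first frame is repeated as padding so the output keeps length T
    (an illustrative choice, not necessarily what the training code does).
    """
    t = frames.shape[0]
    pad = np.repeat(frames[:1], k - 1, axis=0)      # (k-1, C, H, W)
    padded = np.concatenate([pad, frames], axis=0)  # (T+k-1, C, H, W)
    # Slice j holds, for every output index i, the frame padded[i + j];
    # j = 0 is the oldest frame in the window and j = k-1 is the current one.
    windows = [padded[j:j + t] for j in range(k)]
    return np.concatenate(windows, axis=1)

# Hypothetical episode: 5 RGB frames at 64x64.
episode = np.random.rand(5, 3, 64, 64).astype(np.float32)
stacked = stack_frames(episode, k=3)
print(stacked.shape)  # (5, 9, 64, 64)
```

With this layout the last three channels of each stacked input are the current frame, and the earlier channels carry the motion context that a single frame lacks.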
What this model learned
The model was trained on 4,000 episodes (971,321 frames) of 2D billiards gameplay. It learned to predict future frame embeddings from current embeddings and actions, encoding billiards physics purely from pixels.
Probe results (linear probe on encoder representations; 192-dim = pre-projector CLS token, 32-dim = post-projector):
| Model | Rep dim | pos R² | vel R² |
|---|---|---|---|
| lewm_small | 32 | 0.983 | 0.296 |
| lewm_mamba | 32 | 0.983 | 0.297 |
| lewm_framestacked | 192 | 0.446 ⚠️ | 0.138 |
| lewm_framestacked | 32 | 0.599 | 0.417 |
| lewm_auxloss_full | 192 | 0.999 ✅ | 0.947 ✅ |
| lewm_auxloss_full | 32 | 0.982 | 0.554 |
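For readers unfamiliar with linear probing, the R² values above measure how well a linear map from frozen representations predicts a ground-truth quantity such as ball position. The following is a minimal sketch on synthetic data, not the probe code used for this card: the function name `linear_probe_r2`, the 80/20 split, and the plain least-squares fit are assumptions, and the actual probe may use regularization or a different split.

```python
import numpy as np

def linear_probe_r2(reps, targets, train_frac=0.8, seed=0):
    """Fit a least-squares linear probe and return held-out R^2.

    R^2 is pooled over all target dimensions.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(reps))
    n_train = int(train_frac * len(reps))
    tr, te = idx[:n_train], idx[n_train:]

    # Append a constant column so the probe has a bias term.
    X = np.hstack([reps, np.ones((len(reps), 1))])
    W, *_ = np.linalg.lstsq(X[tr], targets[tr], rcond=None)

    pred = X[te] @ W
    ss_res = np.sum((targets[te] - pred) ** 2)
    ss_tot = np.sum((targets[te] - targets[te].mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot

# Synthetic stand-in for encoder outputs: 32-dim representations that
# linearly encode a 2-D "position" plus a little noise.
rng = np.random.default_rng(42)
reps = rng.normal(size=(2000, 32))
true_map = rng.normal(size=(32, 2))
pos = reps @ true_map + 0.01 * rng.normal(size=(2000, 2))
unrelated = rng.normal(size=(2000, 2))  # not encoded in reps at all

r2_pos = linear_probe_r2(reps, pos)
r2_unrelated = linear_probe_r2(reps, unrelated)
print(f"encoded target R^2:   {r2_pos:.3f}")
print(f"unrelated target R^2: {r2_unrelated:.3f}")
```

A target that is linearly decodable scores close to 1, while a target the representation does not carry scores near or below 0 on held-out data. That contrast is what the pos R² and vel R² columns report.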
Key finding: Frame-stacking causes JEPA representational eviction. The ViT encoder stops encoding ball position (pos R²=0.446 at 192-dim) under optical-flow pressure. Adding a lightweight auxiliary state supervision head (λ=0.1) fully recovers position encoding (pos R²=0.999) and achieves the best prediction loss across all variants.
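One way to picture the auxiliary supervision is as a weighted state-regression term added to the embedding prediction loss. The sketch below assumes the combined objective has the form `pred_loss + λ_aux * aux_loss` with mean squared error for both terms; this card does not give the exact formulation. The function names and the 4-dimensional state (for example x, y, vx, vy) are hypothetical.

```python
import numpy as np

def mse(a, b):
    """Mean squared error over all elements."""
    return float(np.mean((a - b) ** 2))

def total_loss(pred_emb, target_emb, state_pred, state_true, lam_aux=0.1):
    """Embedding prediction loss plus a weighted auxiliary state loss.

    Assumed form: total = pred_loss + lam_aux * aux_loss.
    """
    pred_loss = mse(pred_emb, target_emb)
    aux_loss = mse(state_pred, state_true)
    return pred_loss + lam_aux * aux_loss, pred_loss, aux_loss

# Synthetic batch: 16 samples, 32-dim embeddings, 4-dim state.
rng = np.random.default_rng(0)
target_emb = rng.normal(size=(16, 32))
pred_emb = target_emb + 0.1 * rng.normal(size=(16, 32))
state_true = rng.uniform(size=(16, 4))
state_pred = state_true + 0.05 * rng.normal(size=(16, 4))

loss, pred_loss, aux_loss = total_loss(pred_emb, target_emb, state_pred, state_true)
print(f"total={loss:.5f} pred={pred_loss:.5f} aux={aux_loss:.5f}")
```

Under this reading, a small λ keeps the prediction objective dominant while still penalizing representations from which the state head cannot recover position.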
Planning results
| Approach | Same-episode | Novel cross-episode |
|---|---|---|
| Pure JEPA embedding CEM | ❌ FAIL | ❌ FAIL |
| State-based hybrid CEM | ✅ SUCCESS (9 steps) | ✅ SUCCESS (13 steps) |
Pure JEPA planning failed due to uniform embedding geometry in this visually simple domain. See FINDINGS.md for complete analysis.
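For context, CEM (the cross-entropy method) plans by sampling action sequences, scoring each one with a rollout, and refitting the sampling distribution to the best candidates. The sketch below is a generic illustration, not the planner from this repository: it uses toy point-mass dynamics in place of the trained world model and a state-space distance as the cost, and every name and hyperparameter in it is hypothetical.

```python
import numpy as np

def rollout_cost(state, actions, goal, step_fn):
    """Roll a sequence of actions through step_fn; cost is final distance to goal."""
    for a in actions:
        state = step_fn(state, a)
    return float(np.linalg.norm(state - goal))

def cem_plan(state, goal, step_fn, horizon=10, act_dim=2,
             n_samples=256, n_elite=32, n_iters=5, seed=0):
    """Cross-entropy method over open-loop action sequences."""
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, act_dim))
    std = np.ones((horizon, act_dim))
    for _ in range(n_iters):
        # Sample candidate sequences and clip to the action bounds [-1, 1].
        samples = mean + std * rng.normal(size=(n_samples, horizon, act_dim))
        samples = np.clip(samples, -1.0, 1.0)
        costs = np.array([rollout_cost(state, s, goal, step_fn) for s in samples])
        # Refit the sampling distribution to the lowest-cost sequences.
        elites = samples[np.argsort(costs)[:n_elite]]
        mean = elites.mean(axis=0)
        std = elites.std(axis=0) + 1e-6
    return mean

def toy_step(state, action):
    """Toy dynamics standing in for a learned model: move by 0.1 * action."""
    return state + 0.1 * action

start = np.zeros(2)
goal = np.array([0.5, -0.3])
plan = cem_plan(start, goal, toy_step)
final_cost = rollout_cost(start, plan, goal, toy_step)
print(f"distance to goal after plan: {final_cost:.3f}")
```

In a pure-embedding variant the cost would instead be a distance between predicted and goal embeddings. If that distance varies little across candidate sequences, the elite set carries almost no signal, which is one way to read the failure described above.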
Usage
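# Load checkpoint
# stable_worldmodel is imported as in the original snippet; it is likely
# needed so that classes referenced by the checkpoint can be resolved.
import stable_worldmodel as swm
import torch

device = torch.device("mps")  # or "cuda" or "cpu"

# Load the small model (recommended).
# Note: on PyTorch >= 2.6, torch.load defaults to weights_only=True. If the
# checkpoint stores a full pickled object, pass weights_only=False, and only
# do so for files you trust.
checkpoint = torch.load(
    "lewm_small_epoch_8_object.ckpt",
    map_location=device,
)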
Training setup
- Hardware: Apple M5 Max (64GB unified memory)
- Backend: PyTorch MPS
- Training time: ~10–11 hours per run (10 epochs); 4 model variants trained
- Framework: PyTorch Lightning + stable-worldmodel
Credits
Original LeWM architecture by:
Lucas Maes, Quentin Leroux, Gauthier Gidel, Glen Berseth
Mila / McGill University (2025)
arXiv:2603.19312