LeWorldModel (LeWM): Stable End-to-End JEPA from Pixels
This repository contains a clean, self-contained PyTorch implementation of LeWorldModel (LeWM) from the paper:
LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels
Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero
arXiv: 2603.19312 (https://arxiv.org/abs/2603.19312)
Official repo: https://github.com/lucas-maes/le-wm
Quick Start: Free GPU Training on Google Colab
The easiest way to train LeWM on a free GPU is via our Colab-ready notebook:
Open in Colab (upload lewm_colab.ipynb from this repo)
Or read the step-by-step guide: COLAB_GUIDE.md
What you need:
- A Google account (free)
- ~30-60 minutes for 10 epochs on synthetic data
- Optional: Hugging Face token (free) to push trained models
Hardware: Free Colab T4 GPU (15 GB VRAM). LeWM's ~18M parameters fit comfortably.
What is LeWorldModel?
LeWorldModel (LeWM) is a Joint-Embedding Predictive Architecture (JEPA) world model that learns directly from raw pixels with a single tunable hyperparameter. It is the first end-to-end JEPA that trains stably without:
- Stop-gradient / EMA mechanisms
- Pre-trained encoders (e.g., DINOv2)
- Complex multi-term losses (e.g., VICReg variants)
Key Innovations
| Feature | LeWM | Prior work (PLDM) |
|---|---|---|
| Loss terms | 2 (prediction + SIGReg) | 7 (prediction + 6 regularizers) |
| Tunable hyperparameters | 1 (lambda) | 6 (grid search O(n^6)) |
| End-to-end trainable | Yes | Partial (fragile) |
| Planning speed | 48x faster than DINO-WM | Comparable |
| Params | ~18M | Similar |
Architecture (from paper section 3.1 & Appendix D)
Raw Pixels (224x224) ---> ViT-Tiny Encoder ---> [CLS] + MLP+BN ---> Latent z_t
                                                                         |  (192-dim)
                                                                         |
                                                                         v
                                                               +-------------------+
                                                               |   AR Predictor    | <--- Actions (AdaLN-zero)
                                                               |  6 layers, 16h    |
                                                               |  Causal masking   |
                                                               +-------------------+
                                                                         |
                                                                         v
                                                                 Predicted z_{t+1}
                                                                         |
                                                                         v
                                                      MSE(z_{t+1}, pred) + lambda * SIGReg(z)
Components:
- Encoder: ViT-Tiny (patch 14, 12 layers, 3 heads, hidden 192) -> [CLS] token -> MLP + BatchNorm1d projector
- Predictor: 6-layer transformer with AdaLN-zero action conditioning, causal temporal masking
- SIGReg: Sketched Isotropic Gaussian Regularization, an anti-collapse term based on the Epps-Pulley test applied to random 1-D projections
- Planner: Cross-Entropy Method (CEM) in latent space for goal-conditioned control
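To make the predictor's action conditioning concrete, here is a minimal NumPy sketch of one AdaLN-zero residual block, in the style popularized by diffusion transformers. It is an illustration under assumptions and not the code in lewm_model.py: the names `adaln_zero_block`, `layer_norm`, `W_mod`, and `b_mod`, the action-embedding shape, and the `np.tanh` stand-in for the attention/MLP sublayer are all hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Parameter-free LayerNorm over the last (feature) axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def adaln_zero_block(x, action_emb, W_mod, b_mod, sublayer):
    """One residual block whose normalization is modulated by the action.

    x:          (B, T, d) latent tokens
    action_emb: (B, T, d_a) action embeddings (hypothetical shape)
    W_mod:      (d_a, 3*d) modulation weights, zero-initialized
    b_mod:      (3*d,) modulation bias, zero-initialized
    """
    # Regress shift, scale, and gate from the action embedding.
    shift, scale, gate = np.split(action_emb @ W_mod + b_mod, 3, axis=-1)
    h = layer_norm(x) * (1.0 + scale) + shift   # adaptive LayerNorm
    return x + gate * sublayer(h)               # gated residual update

rng = np.random.default_rng(0)
B, T, d, d_a = 2, 3, 192, 10
x = rng.standard_normal((B, T, d))
a = rng.standard_normal((B, T, d_a))
W_mod, b_mod = np.zeros((d_a, 3 * d)), np.zeros(3 * d)

# np.tanh is only a placeholder for the real sublayer.
out_at_init = adaln_zero_block(x, a, W_mod, b_mod, np.tanh)
```

With the modulation parameters at zero the gate is zero, so the block starts out as the identity and the influence of the action is learned gradually during training.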
SIGReg: The Anti-Collapse Engine
SIGReg is the critical component that makes stable end-to-end training possible.
Problem: Prediction-only loss causes representation collapse (encoder maps everything to a constant).
Solution: SIGReg forces latent embeddings to match an isotropic Gaussian N(0, I).
How it works:
- Collect latent tensor Z in R^(TxBxd) (time x batch x dim)
- Sample M=1024 random unit-norm directions u^(m) on the hypersphere S^(d-1)
- Project: h^(m) = Z dot u^(m) -> (T, B) 1-D marginals
- Apply the Epps-Pulley test statistic T(h^(m)), which compares the empirical characteristic function of h^(m) with that of N(0, 1)
- Evaluate the integral with trapezoid quadrature on nodes uniformly spaced in [0, 3], using the weighting w(t) = exp(-t^2/2)
- By the Cramér-Wold theorem: matching all 1-D marginals <=> matching the full joint distribution
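The steps above can be sketched in NumPy as follows. This is a hedged illustration, not the repository's implementation: the function name `sigreg`, its arguments, and the choice to average the statistic over time steps and directions are assumptions, and constant factors may differ from the official code.

```python
import numpy as np

def sigreg(z, num_proj=1024, knots=17, t_max=3.0, seed=0):
    """Sketch of a SIGReg-style penalty for latents z of shape (T, B, d)."""
    rng = np.random.default_rng(seed)
    T, B, d = z.shape

    # Random unit-norm directions on the hypersphere S^(d-1).
    u = rng.standard_normal((d, num_proj))
    u /= np.linalg.norm(u, axis=0, keepdims=True)

    h = z @ u                                   # (T, B, M) 1-D projections
    t = np.linspace(0.0, t_max, knots)          # quadrature nodes in [0, t_max]
    target_cf = np.exp(-0.5 * t**2)             # characteristic function of N(0, 1)

    # Empirical characteristic function over the batch axis: mean of exp(i t h).
    ht = h[..., None] * t                       # (T, B, M, K)
    cf_real = np.cos(ht).mean(axis=1)           # (T, M, K)
    cf_imag = np.sin(ht).mean(axis=1)

    # Weighted squared distance between the two characteristic functions.
    weight = np.exp(-0.5 * t**2)                # w(t) = exp(-t^2 / 2)
    integrand = ((cf_real - target_cf) ** 2 + cf_imag**2) * weight

    # Trapezoid rule on the uniform grid.
    dt = t[1] - t[0]
    integral = dt * (integrand[..., 1:-1].sum(axis=-1)
                     + 0.5 * (integrand[..., 0] + integrand[..., -1]))

    # Scale by the sample size, then average over time steps and directions.
    return float(B * integral.mean())

rng = np.random.default_rng(1)
gaussian = rng.standard_normal((4, 256, 192))   # already close to N(0, I)
collapsed = np.zeros((4, 256, 192))             # encoder outputs a constant
score_gaussian = sigreg(gaussian, num_proj=64)
score_collapsed = sigreg(collapsed, num_proj=64)
```

In this toy comparison the collapsed latents receive a much larger penalty than the Gaussian ones, which is the behavior the regularizer relies on to discourage collapse.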
Key insight: The projector uses BatchNorm1d (not LayerNorm) because the ViT final layer already applies LayerNorm - this is essential for SIGReg optimization.
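One way to read this point is that the two layers normalize along different axes. The NumPy sketch below only illustrates that difference on hypothetical features; `batch_norm` and `layer_norm` are simplified stand-ins without learnable affine parameters, not the repository's projector.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Normalize each feature across the batch axis (training-mode BatchNorm1d, no affine)."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

def layer_norm(x, eps=1e-5):
    """Normalize each sample across its feature axis (LayerNorm, no affine)."""
    mean = x.mean(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=1, keepdims=True) + eps)

rng = np.random.default_rng(0)
# Hypothetical [CLS] features with a different offset and scale per dimension.
feats = rng.standard_normal((512, 192)) * rng.uniform(0.5, 3.0, 192) + rng.uniform(-2, 2, 192)

bn_out, ln_out = batch_norm(feats), layer_norm(feats)
print("BatchNorm: max |per-feature mean| =", np.abs(bn_out.mean(axis=0)).max())
print("LayerNorm: max |per-feature mean| =", np.abs(ln_out.mean(axis=0)).max())
```

BatchNorm standardizes every latent dimension across the batch, which matches the per-dimension statistics of the N(0, I) target, whereas LayerNorm standardizes each sample and leaves per-dimension offsets in place.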
Training
Free GPU Training (Google Colab T4)
# In a Colab notebook with GPU runtime enabled:
!pip install -q transformers einops huggingface_hub matplotlib numpy tqdm
# Download implementation
from huggingface_hub import hf_hub_download
hf_hub_download("ar27111994/lewm-implementation", "lewm_model.py", local_dir="/content")
hf_hub_download("ar27111994/lewm-implementation", "lewm_train.py", local_dir="/content")
# Train with synthetic data (no 12GB download needed)
!python /content/lewm_train.py --use_synthetic \
--n_episodes 2000 --epochs 10 --batch_size 128 \
--lambd 0.1 --history_size 3 --seq_len 4 \
--frameskip 5 --action_dim 2 --output_dir /content/drive/MyDrive/lewm
See COLAB_GUIDE.md for the full notebook, troubleshooting, and real dataset download instructions.
Synthetic Smoke Test (CPU-friendly)
python lewm_train.py --use_synthetic \
--n_episodes 2000 --epochs 10 \
--batch_size 128 --lr 1e-3 \
--lambd 0.1 --history_size 3 --seq_len 4 \
--frameskip 5 --action_dim 2
Real PushT Dataset
- Download official dataset:
python -c "from huggingface_hub import hf_hub_download; \
hf_hub_download('quentinll/lewm-pusht', 'pusht_expert_train.h5.zst', repo_type='dataset')"
- Decompress and train:
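# Requires the zstd command-line tool; writes /path/to/pusht_expert_train.h5
zstd -d /path/to/pusht_expert_train.h5.zst
python lewm_train.py \
--h5_path /path/to/pusht_expert_train.h5 \
--epochs 10 --batch_size 128 \
--lambd 0.1 --history_size 3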
Hyperparameters (from paper)
| Parameter | Value |
|---|---|
| Batch size | 128 |
| Seq length | 4 frames + 4 action blocks |
| Frame skip | 5 |
| Resolution | 224x224 |
| Epochs | 10 |
| Embedding dim | 192 |
| Predictor dropout | 0.1 |
| lambda (SIGReg weight) | 0.1 |
| History length | 3 (PushT, Cube), 1 (TwoRoom) |
| Optimizer | AdamW with cosine schedule |
Only lambda needs tuning: performance is insensitive to the number of projections (M=1024) and the number of integration knots (17).
Planning with CEM
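import torch
from lewm_model import build_lewm, cem_plan

# action_dim=10 presumably corresponds to 5 skipped frames x 2-D PushT actions
model = build_lewm(action_dim=10, history_size=3)
# ... load trained weights ...

# Placeholder inputs with the expected shapes; replace with real frames.
context_frames = torch.rand(1, 3, 3, 224, 224)  # (1, H, C, 224, 224), H = history_size
goal_frame = torch.rand(1, 1, 3, 224, 224)      # (1, 1, C, 224, 224)

best_actions = cem_plan(
    model,
    initial_pixels=context_frames,  # (1, H, C, 224, 224)
    goal_pixels=goal_frame,         # (1, 1, C, 224, 224)
    action_dim=10,
    horizon=5,       # 5 latent steps = 25 env steps (frame_skip=5)
    n_samples=300,
    n_iters=30,      # 30 for PushT, 10 for others
    n_elites=30,
    history_size=3,
)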
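For readers unfamiliar with the Cross-Entropy Method, the loop behind a planner of this kind can be sketched in NumPy on a toy problem. Everything here is illustrative: `cem_plan_toy` and `toy_rollout` are hypothetical names, the additive dynamics stand in for the learned predictor, and the real `cem_plan` may differ in signature and details.

```python
import numpy as np

def cem_plan_toy(z0, z_goal, rollout, action_dim=2, horizon=5,
                 n_samples=300, n_iters=30, n_elites=30, seed=0):
    """Cross-Entropy Method over action sequences for a generic latent rollout."""
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(n_iters):
        # Sample candidate action sequences around the current distribution.
        noise = rng.standard_normal((n_samples, horizon, action_dim))
        actions = mean + std * noise
        # Cost: squared distance between the final predicted latent and the goal.
        z_final = rollout(z0, actions)                     # (n_samples, d)
        cost = ((z_final - z_goal) ** 2).sum(axis=-1)
        # Refit the sampling distribution to the lowest-cost candidates.
        elites = actions[np.argsort(cost)[:n_elites]]
        mean, std = elites.mean(axis=0), elites.std(axis=0)
    return mean

def toy_rollout(z0, actions):
    """Hypothetical dynamics z_{t+1} = z_t + a_t, standing in for the learned predictor."""
    return z0 + actions.sum(axis=1)

z0, z_goal = np.zeros(2), np.array([1.0, -0.5])
plan = cem_plan_toy(z0, z_goal, toy_rollout)
final_error = np.linalg.norm(toy_rollout(z0, plan[None]) - z_goal)
```

The returned `plan` is the mean of the final sampling distribution. In the real planner the rollout would call the predictor on encoded context frames, and the cost would compare against the encoded goal frame.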
Results (from paper)
| Method | PushT Success Rate | Planning Time |
|---|---|---|
| LeWM (ours) | 96.0 +/- 2.8% | <1 sec |
| DINO-WM | 92.0 +/- 1.6% | ~48x slower |
| PLDM | 78.0 +/- 5.0% | comparable |
- 48x faster planning than DINO-WM due to ~200x fewer tokens in latent space
- Single GPU (L40S) training in "a few hours"
- No stop-gradient, no EMA, no pre-trained encoders
Project Structure
lewm_model.py - Core model (Encoder, Predictor, SIGReg, CEM)
lewm_train.py - Training script (HDF5 + synthetic datasets)
lewm_mini_test.py - Minimal sanity check
lewm_colab.ipynb - Full Colab-ready training notebook
COLAB_GUIDE.md - Step-by-step free GPU training guide
config.json - Verified architecture config from official model
EXPLANATION.md - 16KB deep-dive technical explanation
Interactive Demo
Try the explainable interactive Space (no training required):
https://huggingface.co/spaces/ar27111994/lewm-explainable
Features:
- Architecture tab: Full pipeline schematic
- SIGReg Explorer: Adjust collapse level and see real-time distributional analysis
- CEM Planning: Visualize Cross-Entropy Method convergence
- Key Results: Paper results and hyperparameters
Citation
@article{maes_lelidec2026lewm,
title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels},
author={Maes, Lucas and Le Lidec, Quentin and Scieur, Damien and LeCun, Yann and Balestriero, Randall},
journal={arXiv preprint arXiv:2603.19312},
year={2026}
}
License
MIT (same as the official repository).
This implementation is self-contained in standard PyTorch + transformers + einops, with no dependency on the external stable-pretraining or stable-worldmodel packages for the core model logic.
Generated by ML Intern
This model repository was generated by ML Intern, an agent for machine learning research and development on the Hugging Face Hub.
- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern
Usage
# LeWM is not a causal language model, so it is not loaded through the transformers Auto classes.
# Download the implementation file from the Hub and build the model directly.
from huggingface_hub import hf_hub_download
model_id = "ar27111994/lewm-implementation"
hf_hub_download(model_id, "lewm_model.py", local_dir=".")
from lewm_model import build_lewm
model = build_lewm(action_dim=10, history_size=3)
# ... load trained weights ...