| --- |
| tags: |
| - ml-intern |
| - jepa |
| - world-model |
| - robotics |
| - explainable-ai |
| license: mit |
| --- |
| # LeWorldModel (LeWM): Stable End-to-End JEPA from Pixels |
|
|
| This repository contains a **clean, self-contained PyTorch implementation** of **LeWorldModel (LeWM)** from the paper: |
|
|
| > **LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels** |
| > Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero |
| > arXiv: 2603.19312 — https://arxiv.org/abs/2603.19312 |
| > Official repo: https://github.com/lucas-maes/le-wm |
|
|
| --- |
|
|
| ## 🚀 Quick Start: Free GPU Training on Google Colab |
|
|
| The easiest way to train LeWM on a **free GPU** is via our Colab-ready notebook: |
|
|
| 📓 **[Open in Colab](https://colab.research.google.com/github/ar27111994/lewm-implementation/blob/main/lewm_colab.ipynb)** *(upload the notebook from this repo)* |
|
|
| Or read the step-by-step guide: |
| 📖 **[COLAB_GUIDE.md](https://huggingface.co/ar27111994/lewm-implementation/blob/main/COLAB_GUIDE.md)** |
|
|
| **What you need**: |
| - A Google account (free) |
| - ~30–60 minutes for 10 epochs on synthetic data |
| - Optional: Hugging Face token (free) to push trained models |
|
|
| **Hardware**: Free Colab T4 GPU (15 GB VRAM) — LeWM's ~18M parameters fit comfortably. |
|
|
| --- |
|
|
| ## What is LeWorldModel? |
|
|
| LeWorldModel (LeWM) is a **Joint-Embedding Predictive Architecture (JEPA)** world model that learns **directly from raw pixels** with a **single tunable hyperparameter**. The paper presents it as the first end-to-end JEPA that trains stably without: |
| - Stop-gradient / EMA mechanisms |
| - Pre-trained encoders (e.g., DINOv2) |
| - Complex multi-term losses (e.g., VICReg variants) |
|
|
| ### Key Innovations |
|
|
| | Feature | LeWM | Prior work (PLDM) | |
| |---------|------|-------------------| |
| | Loss terms | **2** (prediction + SIGReg) | 7 (prediction + 6 regularizers) | |
| | Tunable hyperparameters | **1** (lambda) | 6 (grid search O(n^6)) | |
| | End-to-end trainable | Yes | Partial (fragile) | |
| | Planning speed | **48x faster** than DINO-WM | Comparable | |
| | Params | ~18M | Similar | |
|
|
| ### Architecture (from paper section 3.1 & Appendix D) |
|
|
| ``` |
| Raw Pixels (224x224) ---> ViT-Tiny Encoder ---> [CLS] + MLP+BN ---> Latent z_t |
| | (192-dim) |
| | |
| v |
| +-------------------+ |
| | AR Predictor | <--- Actions (AdaLN-zero) |
| | 6 layers, 16h | |
| | Causal masking | |
| +-------------------+ |
| | |
| v |
| Predicted z_{t+1} |
| | |
| v |
| MSE(z_{t+1}, pred) + lambda * SIGReg(z) |
| ``` |
|
|
| **Components:** |
| - **Encoder**: ViT-Tiny (patch 14, 12 layers, 3 heads, hidden 192) -> [CLS] token -> **MLP + BatchNorm1d** projector |
| - **Predictor**: 6-layer transformer with **AdaLN-zero** action conditioning, **causal temporal masking** |
| - **SIGReg**: Sketched Isotropic Gaussian Regularization, which prevents collapse by applying the Epps-Pulley test to random 1-D projections of the latents |
| - **Planner**: Cross-Entropy Method (CEM) in latent space for goal-conditioned control |
|
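
To make the predictor bullet more concrete, here is a minimal NumPy sketch of AdaLN-zero conditioning and a causal temporal mask. It is an illustration under stated assumptions, not the code in `lewm_model.py`: the class `AdaLNZeroBlock`, the single linear map standing in for the attention/MLP branch, and the action-embedding width of 192 are all hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Parameter-free LayerNorm over the last axis."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

class AdaLNZeroBlock:
    """One residual sub-block modulated by an action embedding (hypothetical)."""

    def __init__(self, dim, cond_dim, rng):
        # Zero-initialised modulation layer: the "zero" in AdaLN-zero.
        self.w_mod = np.zeros((cond_dim, 3 * dim))
        self.b_mod = np.zeros(3 * dim)
        # Stand-in for the attention/MLP branch (a single linear map here).
        self.w_branch = rng.normal(scale=0.02, size=(dim, dim))

    def __call__(self, x, cond):
        # The conditioning vector produces a shift, a scale and a gate.
        shift, scale, gate = np.split(cond @ self.w_mod + self.b_mod, 3, axis=-1)
        h = layer_norm(x) * (1.0 + scale) + shift
        return x + gate * (h @ self.w_branch)

def causal_mask(n_steps):
    """Boolean mask: entry (i, j) is True when step i may attend to step j."""
    return np.tril(np.ones((n_steps, n_steps), dtype=bool))

rng = np.random.default_rng(0)
block = AdaLNZeroBlock(dim=192, cond_dim=192, rng=rng)
z = rng.normal(size=(3, 192))   # three history latents
a = rng.normal(size=(3, 192))   # action embeddings (width is an assumption)
out = block(z, a)               # equals z at initialisation because gate == 0
mask = causal_mask(3)
```

Because the gate starts at zero, each block is the identity at initialisation, so action conditioning is introduced gradually as the modulation weights are learned.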
|
| --- |
|
|
| ## SIGReg: The Anti-Collapse Engine |
|
|
| SIGReg is the critical component that makes stable end-to-end training possible. |
|
|
| **Problem**: Prediction-only loss causes representation collapse (encoder maps everything to a constant). |
|
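
A toy NumPy illustration of the collapse problem: an encoder that ignores its input gets a perfect prediction loss. The encoder, the identity predictor, and the tensor shapes are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
frames = rng.normal(size=(4, 8, 16))   # (time, batch, features), toy observations

def collapsed_encoder(x):
    """Degenerate encoder: ignores the input and returns a constant latent."""
    return np.zeros(x.shape[:-1] + (4,))

z = collapsed_encoder(frames)
pred = z[:-1]                           # identity predictor: z_t predicts z_{t+1}
pred_loss = float(np.mean((pred - z[1:]) ** 2))
print(pred_loss)                        # 0.0
```

The loss is minimal even though the latents carry no information about the frames, which is why a prediction-only objective needs an additional regularizer.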
|
| **Solution**: SIGReg forces latent embeddings to match an **isotropic Gaussian N(0, I)**. |
|
|
| **How it works**: |
| 1. Collect the latent tensor **Z in R^(TxBxd)** (time x batch x dim) |
| 2. Sample **M=1024** random unit-norm directions **u^(m)** on the hypersphere S^(d-1) |
| 3. Project: **h^(m) = Z dot u^(m)**, giving a (T, B) array of 1-D projections per direction |
| 4. Compute the **Epps-Pulley test statistic** T(h^(m)), which compares the empirical characteristic function of h^(m) with that of N(0, 1) |
| 5. Evaluate the integral in that statistic by trapezoid quadrature on nodes spaced uniformly in [0, 3], weighted by w(t) = exp(-t^2/2) |
| 6. By the **Cramér-Wold theorem**, a distribution is determined by its 1-D projections, so matching every projection to N(0, 1) matches the joint distribution to N(0, I) |
|
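
The steps above can be sketched in NumPy as follows. This is a rough illustration, not the implementation in `lewm_model.py`: the function name `sigreg`, the choice to compute the statistic over the batch axis and average over time steps and directions, and the scaling by the batch size are assumptions.

```python
import numpy as np

def sigreg(z, n_proj=1024, n_knots=17, t_max=3.0, rng=None):
    """Average Epps-Pulley statistic of z (T, B, d) over random 1-D projections."""
    rng = np.random.default_rng() if rng is None else rng
    T, B, d = z.shape
    # Step 2: random unit-norm directions on the hypersphere S^(d-1).
    u = rng.normal(size=(d, n_proj))
    u /= np.linalg.norm(u, axis=0, keepdims=True)
    # Step 3: 1-D projections, shape (T, B, M).
    h = z @ u
    # Step 5: uniform nodes in [0, t_max], trapezoid weights times exp(-t^2/2).
    t = np.linspace(0.0, t_max, n_knots)
    dt = t[1] - t[0]
    trap = np.full(n_knots, dt)
    trap[[0, -1]] = dt / 2
    weights = trap * np.exp(-0.5 * t**2)
    # Step 4: empirical characteristic function over the batch axis.
    ht = h[..., None] * t                   # (T, B, M, K)
    ecf_real = np.cos(ht).mean(axis=1)      # (T, M, K)
    ecf_imag = np.sin(ht).mean(axis=1)
    target = np.exp(-0.5 * t**2)            # characteristic function of N(0, 1)
    err = (ecf_real - target) ** 2 + ecf_imag ** 2
    stat = B * (err @ weights)              # (T, M), one statistic per projection
    return float(stat.mean())

rng = np.random.default_rng(0)
gaussian_latents = rng.normal(size=(4, 128, 16))
collapsed_latents = np.zeros((4, 128, 16))
loss_gaussian = sigreg(gaussian_latents, n_proj=64, rng=rng)
loss_collapsed = sigreg(collapsed_latents, n_proj=64, rng=rng)
```

Latents drawn from N(0, I) give a small value, while collapsed latents give a much larger one, so minimising this term pushes the encoder away from the constant solution.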
|
| **Key insight**: The projector uses **BatchNorm1d** (not LayerNorm) because the ViT final layer already applies LayerNorm - this is essential for SIGReg optimization. |
|
|
| --- |
|
|
| ## Training |
|
|
| ### Free GPU Training (Google Colab T4) |
|
|
| ```python |
| # In a Colab notebook with GPU runtime enabled: |
| !pip install -q transformers einops huggingface_hub matplotlib numpy tqdm |
| |
| # Mount Google Drive so that --output_dir below persists after the session ends |
| from google.colab import drive |
| drive.mount("/content/drive") |
| |
| # Download implementation |
| from huggingface_hub import hf_hub_download |
| hf_hub_download("ar27111994/lewm-implementation", "lewm_model.py", local_dir="/content") |
| hf_hub_download("ar27111994/lewm-implementation", "lewm_train.py", local_dir="/content") |
| |
| # Train with synthetic data (no 12GB download needed) |
| !python /content/lewm_train.py --use_synthetic \ |
| --n_episodes 2000 --epochs 10 --batch_size 128 \ |
| --lambd 0.1 --history_size 3 --seq_len 4 \ |
| --frameskip 5 --action_dim 2 --output_dir /content/drive/MyDrive/lewm |
| ``` |
|
|
| See **[COLAB_GUIDE.md](https://huggingface.co/ar27111994/lewm-implementation/blob/main/COLAB_GUIDE.md)** for the full notebook, troubleshooting, and real dataset download instructions. |
|
|
| ### Synthetic Smoke Test (CPU-friendly) |
|
|
| ```bash |
| python lewm_train.py --use_synthetic \ |
| --n_episodes 2000 --epochs 10 \ |
| --batch_size 128 --lr 1e-3 \ |
| --lambd 0.1 --history_size 3 --seq_len 4 \ |
| --frameskip 5 --action_dim 2 |
| ``` |
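
As a rough guide to what `--seq_len 4` and `--frameskip 5` imply, the sketch below lists which raw environment steps one training window would cover. It assumes each sampled frame is paired with the block of `frameskip` consecutive actions that follows it; the helper `window_indices` is hypothetical and `lewm_train.py` may index differently.

```python
def window_indices(start, seq_len=4, frameskip=5):
    """Return sampled frame indices and the action block paired with each frame."""
    frames = [start + k * frameskip for k in range(seq_len)]
    action_blocks = [list(range(f, f + frameskip)) for f in frames]
    return frames, action_blocks

frames, blocks = window_indices(start=0)
print(frames)      # [0, 5, 10, 15]
print(blocks[0])   # [0, 1, 2, 3, 4]
```

Under this assumption, a raw 2-D action with frame skip 5 flattens to a 10-D action block per latent step, which would be consistent with `action_dim=10` in the planning example.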
|
|
| ### Real PushT Dataset |
|
|
| 1. Download official dataset: |
| ```bash |
| python -c "from huggingface_hub import hf_hub_download; \ |
| print(hf_hub_download('quentinll/lewm-pusht', 'pusht_expert_train.h5.zst', repo_type='dataset'))" |
| ``` |
|
|
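| 2. Decompress and train: |
| ```bash |
| # Decompress with the zstd CLI (add -f if the input path is a symlink, as in the Hugging Face cache) |
| zstd -d /path/to/pusht_expert_train.h5.zst -o /path/to/pusht_expert_train.h5 |
| |
| python lewm_train.py \ |
| --h5_path /path/to/pusht_expert_train.h5 \ |
| --epochs 10 --batch_size 128 \ |
| --lambd 0.1 --history_size 3 |
| ``` |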
|
|
| ### Hyperparameters (from paper) |
|
|
| | Parameter | Value | |
| |-----------|-------| |
| | Batch size | 128 | |
| | Seq length | 4 frames + 4 action blocks | |
| | Frame skip | 5 | |
| | Resolution | 224x224 | |
| | Epochs | 10 | |
| | Embedding dim | 192 | |
| | Predictor dropout | 0.1 | |
| | lambda (SIGReg weight) | 0.1 | |
| | History length | 3 (PushT, Cube), 1 (TwoRoom) | |
| | Optimizer | AdamW with cosine schedule | |
|
|
| **Only lambda needs tuning**: performance is insensitive to the number of random projections (M=1024) and the number of integration knots (17). |
|
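
The table lists a cosine schedule without further detail. A minimal sketch of plain cosine decay is shown below; the base rate of 1e-3 is taken from the `--lr 1e-3` flag above, while the final rate of 0 and the absence of warmup are assumptions.

```python
import math

def cosine_lr(step, total_steps, base_lr=1e-3, min_lr=0.0):
    """Cosine decay from base_lr at step 0 to min_lr at total_steps."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Learning rate at the start, midpoint and end of a 100-step run.
schedule = [cosine_lr(s, 100) for s in (0, 50, 100)]
```

In PyTorch the same curve is available through `torch.optim.lr_scheduler.CosineAnnealingLR` wrapped around a `torch.optim.AdamW` optimizer.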
|
| --- |
|
|
| ## Planning with CEM |
|
|
| ```python |
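| import torch |
| from lewm_model import build_lewm, cem_plan |
| |
| # action_dim=10 differs from --action_dim 2 in the training commands; it may be the |
| # frameskip-stacked action (5 x 2), so check lewm_train.py before reusing weights. |
| model = build_lewm(action_dim=10, history_size=3) |
| # ... load trained weights ... |
| |
| # Placeholder inputs with the documented shapes; replace them with real frames. |
| context_frames = torch.rand(1, 3, 3, 224, 224)  # (1, H, C, 224, 224), H = history_size |
| goal_frame = torch.rand(1, 1, 3, 224, 224)      # (1, 1, C, 224, 224) |
| |
| best_actions = cem_plan( |
| model, |
| initial_pixels=context_frames, |
| goal_pixels=goal_frame, |
| action_dim=10, |
| horizon=5, # 5 latent steps = 25 env steps (frame_skip=5) |
| n_samples=300, # candidate action sequences per iteration |
| n_iters=30, # 30 for PushT, 10 for others |
| n_elites=30, # top candidates used to refit the sampling distribution |
| history_size=3, |
| ) |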
| ``` |
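
For readers unfamiliar with CEM, the loop behind a call like the one above can be sketched in NumPy. This is a generic Cross-Entropy Method with a diagonal Gaussian, not the body of `cem_plan`: the function `cem` and the toy quadratic cost (standing in for the distance between the predicted and goal latents) are assumptions.

```python
import numpy as np

def cem(cost_fn, action_dim, horizon=5, n_samples=300, n_iters=30,
        n_elites=30, rng=None):
    """Generic CEM: refit a diagonal Gaussian to the lowest-cost samples."""
    rng = np.random.default_rng() if rng is None else rng
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(n_iters):
        # Sample candidate action sequences around the current mean.
        noise = rng.normal(size=(n_samples, horizon, action_dim))
        samples = mean + std * noise
        # Score each candidate; a world model would roll out latents here.
        costs = np.array([cost_fn(s) for s in samples])
        # Keep the elites and refit the sampling distribution to them.
        elites = samples[np.argsort(costs)[:n_elites]]
        mean = elites.mean(axis=0)
        std = elites.std(axis=0) + 1e-6
    return mean

# Toy cost: squared distance to a fixed target action sequence.
target = np.full((5, 2), 0.5)
plan = cem(lambda a: float(((a - target) ** 2).sum()), action_dim=2,
           rng=np.random.default_rng(0))
```

With a world model, `cost_fn` would instead encode the context frames, roll the predictor forward under the candidate actions, and return the distance between the final predicted latent and the goal latent.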
|
|
| --- |
|
|
| ## Results (from paper) |
|
|
| | Method | PushT Success Rate | Planning Time | |
| |--------|-------------------|---------------| |
| | **LeWM** | **96.0 +/- 2.8%** | **<1 sec** | |
| | DINO-WM | 92.0 +/- 1.6% | ~48x slower | |
| | PLDM | 78.0 +/- 5.0% | comparable | |
|
|
| - **48x faster planning** than DINO-WM due to ~200x fewer tokens in latent space |
| - Single GPU (L40S) training in "a few hours" |
| - No stop-gradient, no EMA, no pre-trained encoders |
|
|
| --- |
|
|
| ## Project Structure |
|
|
| ``` |
| lewm_model.py - Core model (Encoder, Predictor, SIGReg, CEM) |
| lewm_train.py - Training script (HDF5 + synthetic datasets) |
| lewm_mini_test.py - Minimal sanity check |
| lewm_colab.ipynb - Full Colab-ready training notebook |
| COLAB_GUIDE.md - Step-by-step free GPU training guide |
| config.json - Verified architecture config from official model |
| EXPLANATION.md - 16KB deep-dive technical explanation |
| ``` |
|
|
| --- |
|
|
| ## Interactive Demo |
|
|
| Try the **explainable interactive Space** (no training required): |
|
|
| 🔗 **https://huggingface.co/spaces/ar27111994/lewm-explainable** |
|
|
| Features: |
| - **Architecture** tab: Full pipeline schematic |
| - **SIGReg Explorer**: Adjust collapse level and see real-time distributional analysis |
| - **CEM Planning**: Visualize Cross-Entropy Method convergence |
| - **Key Results**: Paper results and hyperparameters |
|
|
| --- |
|
|
| ## Citation |
|
|
| ```bibtex |
| @article{maes_lelidec2026lewm, |
| title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels}, |
| author={Maes, Lucas and Le Lidec, Quentin and Scieur, Damien and LeCun, Yann and Balestriero, Randall}, |
| journal={arXiv preprint arXiv:2603.19312}, |
| year={2026} |
| } |
| ``` |
|
|
| --- |
|
|
| ## License |
|
|
| MIT (same as the official repository). |
|
|
| --- |
|
|
| *This implementation is self-contained in standard PyTorch + transformers + einops, with no dependency on the external `stable-pretraining` or `stable-worldmodel` packages for the core model logic.* |
|
|
| <!-- ml-intern-provenance --> |
| ## Generated by ML Intern |
|
|
| This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub. |
|
|
| - Try ML Intern: https://smolagents-ml-intern.hf.space |
| - Source code: https://github.com/huggingface/ml-intern |
|
|
| ## Usage |
|
|
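| ```python |
| from huggingface_hub import hf_hub_download |
| |
| # LeWM is a custom PyTorch model, not a transformers causal LM, so it is built |
| # from this repository's own module rather than loaded with AutoModelForCausalLM. |
| hf_hub_download("ar27111994/lewm-implementation", "lewm_model.py", local_dir=".") |
| |
| from lewm_model import build_lewm |
| |
| model = build_lewm(action_dim=10, history_size=3) |
| # ... load trained weights ... |
| ``` |
|
|
| For goal-conditioned control with the built model, see the `cem_plan` example under "Planning with CEM". |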
|
|