---
tags:
- ml-intern
- jepa
- world-model
- robotics
- explainable-ai
license: mit
---

# LeWorldModel (LeWM): Stable End-to-End JEPA from Pixels

This repository contains a **clean, self-contained PyTorch implementation** of **LeWorldModel (LeWM)** from the paper:

> **LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels**
> Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero
> arXiv: 2603.19312 (https://arxiv.org/abs/2603.19312)
> Official repo: https://github.com/lucas-maes/le-wm

---

## 🚀 Quick Start: Free GPU Training on Google Colab

The easiest way to train LeWM on a **free GPU** is via our Colab-ready notebook:

📓 **[Open in Colab](https://colab.research.google.com/github/ar27111994/lewm-implementation/blob/main/lewm_colab.ipynb)** *(upload the notebook from this repo)*

Or read the step-by-step guide:

📖 **[COLAB_GUIDE.md](https://huggingface.co/ar27111994/lewm-implementation/blob/main/COLAB_GUIDE.md)**

**What you need**:

- A Google account (free)
- ~30–60 minutes for 10 epochs on synthetic data
- Optional: Hugging Face token (free) to push trained models

**Hardware**: Free Colab T4 GPU (15 GB VRAM); LeWM's ~18M parameters fit comfortably.

---

## What is LeWorldModel?

LeWorldModel (LeWM) is a **Joint-Embedding Predictive Architecture (JEPA)** world model that learns **directly from raw pixels** with a **single tunable hyperparameter**.
It is the first end-to-end JEPA that trains stably without:

- Stop-gradient / EMA mechanisms
- Pre-trained encoders (e.g., DINOv2)
- Complex multi-term losses (e.g., VICReg variants)

### Key Innovations

| Feature | LeWM | Prior work (PLDM) |
|---------|------|-------------------|
| Loss terms | **2** (prediction + SIGReg) | 7 (prediction + 6 regularizers) |
| Tunable hyperparameters | **1** (lambda) | 6 (grid search O(n^6)) |
| End-to-end trainable | Yes | Partial (fragile) |
| Planning speed | **48x faster** than DINO-WM | Comparable |
| Params | ~18M | Similar |

### Architecture (from paper section 3.1 & Appendix D)

```
Raw Pixels (224x224) ---> ViT-Tiny Encoder ---> [CLS] + MLP+BN ---> Latent z_t
                                                                         |  (192-dim)
                                                                         |
                                                                         v
                                                               +-------------------+
                                                               |   AR Predictor    | <--- Actions (AdaLN-zero)
                                                               |   6 layers, 16h   |
                                                               |  Causal masking   |
                                                               +-------------------+
                                                                         |
                                                                         v
                                                                 Predicted z_{t+1}
                                                                         |
                                                                         v
                                                      MSE(z_{t+1}, pred) + lambda * SIGReg(z)
```

**Components:**

- **Encoder**: ViT-Tiny (patch 14, 12 layers, 3 heads, hidden 192) -> [CLS] token -> **MLP + BatchNorm1d** projector
- **Predictor**: 6-layer transformer with **AdaLN-zero** action conditioning and **causal temporal masking**
- **SIGReg**: Sketched Isotropic Gaussian Regularizer - anti-collapse via the Epps-Pulley test on random 1-D projections
- **Planner**: Cross-Entropy Method (CEM) in latent space for goal-conditioned control

---

## SIGReg: The Anti-Collapse Engine

SIGReg is the critical component that makes stable end-to-end training possible.

**Problem**: A prediction-only loss causes representation collapse (the encoder maps every input to a constant).

**Solution**: SIGReg forces latent embeddings to match an **isotropic Gaussian N(0, I)**.

**How it works**:

1. Collect the latent tensor **Z in R^(TxBxd)** (time x batch x dim)
2. Sample **M=1024** random unit-norm directions **u^(m)** on the hypersphere S^(d-1)
3. Project: **h^(m) = Z dot u^(m)** -> (T, B) 1-D marginals
4. Apply the **Epps-Pulley test statistic** T(h^(m)), which compares the empirical characteristic function of h^(m) with that of a standard Gaussian
5. Evaluate the statistic's integral by trapezoid quadrature on nodes spaced uniformly in [0, 3], with weighting w(t) = exp(-t^2/2)
6. By the **Cramér-Wold theorem**: matching all 1-D marginals <=> matching the full joint distribution

**Key insight**: The projector uses **BatchNorm1d** (not LayerNorm) because the ViT final layer already applies LayerNorm - this is essential for SIGReg optimization.

---

## Training

### Free GPU Training (Google Colab T4)

```python
# In a Colab notebook with GPU runtime enabled:
!pip install -q transformers einops huggingface_hub matplotlib numpy tqdm

# Download implementation
from huggingface_hub import hf_hub_download
hf_hub_download("ar27111994/lewm-implementation", "lewm_model.py", local_dir="/content")
hf_hub_download("ar27111994/lewm-implementation", "lewm_train.py", local_dir="/content")

# Train with synthetic data (no 12GB download needed)
# Note: --output_dir points at Google Drive, so Drive must be mounted at /content/drive first
!python /content/lewm_train.py --use_synthetic \
    --n_episodes 2000 --epochs 10 --batch_size 128 \
    --lambd 0.1 --history_size 3 --seq_len 4 \
    --frameskip 5 --action_dim 2 --output_dir /content/drive/MyDrive/lewm
```

See **[COLAB_GUIDE.md](https://huggingface.co/ar27111994/lewm-implementation/blob/main/COLAB_GUIDE.md)** for the full notebook, troubleshooting, and real dataset download instructions.

### Synthetic Smoke Test (CPU-friendly)

```bash
python lewm_train.py --use_synthetic \
    --n_episodes 2000 --epochs 10 \
    --batch_size 128 --lr 1e-3 \
    --lambd 0.1 --history_size 3 --seq_len 4 \
    --frameskip 5 --action_dim 2
```

### Real PushT Dataset

1. Download official dataset:

   ```bash
   python -c "from huggingface_hub import hf_hub_download; \
   hf_hub_download('quentinll/lewm-pusht', 'pusht_expert_train.h5.zst', repo_type='dataset')"
   ```

2. Decompress and train:

   ```bash
   python lewm_train.py \
       --h5_path /path/to/pusht_expert_train.h5 \
       --epochs 10 --batch_size 128 \
       --lambd 0.1 --history_size 3
   ```

### Hyperparameters (from paper)

| Parameter | Value |
|-----------|-------|
| Batch size | 128 |
| Seq length | 4 frames + 4 action blocks |
| Frame skip | 5 |
| Resolution | 224x224 |
| Epochs | 10 |
| Embedding dim | 192 |
| Predictor dropout | 0.1 |
| lambda (SIGReg weight) | 0.1 |
| History length | 3 (PushT, Cube), 1 (TwoRoom) |
| Optimizer | AdamW with cosine schedule |

**Only lambda needs tuning** - performance is insensitive to the number of projections (M=1024) and the number of integration knots (17).

---

## Planning with CEM

```python
from lewm_model import build_lewm, cem_plan

model = build_lewm(action_dim=10, history_size=3)
# ... load trained weights ...

best_actions = cem_plan(
    model,
    initial_pixels=context_frames,  # (1, H, C, 224, 224)
    goal_pixels=goal_frame,         # (1, 1, C, 224, 224)
    action_dim=10,
    horizon=5,       # 5 latent steps = 25 env steps (frame_skip=5)
    n_samples=300,
    n_iters=30,      # 30 for PushT, 10 for others
    n_elites=30,
    history_size=3,
)
```

---

## Results (from paper)

| Method | PushT Success Rate | Planning Time |
|--------|-------------------|---------------|
| **LeWM (ours)** | **96.0 +/- 2.8%** | **<1 sec** |
| DINO-WM | 92.0 +/- 1.6% | ~48x slower |
| PLDM | 78.0 +/- 5.0% | comparable |

- **48x faster planning** than DINO-WM due to ~200x fewer tokens in latent space
- Single GPU (L40S) training in "a few hours"
- No stop-gradient, no EMA, no pre-trained encoders

---

## Project Structure

```
lewm_model.py        - Core model (Encoder, Predictor, SIGReg, CEM)
lewm_train.py        - Training script (HDF5 + synthetic datasets)
lewm_mini_test.py    - Minimal sanity check
lewm_colab.ipynb     - Full Colab-ready training notebook
COLAB_GUIDE.md       - Step-by-step free GPU training guide
config.json          - Verified architecture config from official model
EXPLANATION.md       - 16KB deep-dive technical explanation
```

---

## Interactive Demo

Try
the **explainable interactive Space** (no training required):

🔗 **https://huggingface.co/spaces/ar27111994/lewm-explainable**

Features:

- **Architecture** tab: Full pipeline schematic
- **SIGReg Explorer**: Adjust collapse level and see real-time distributional analysis
- **CEM Planning**: Visualize Cross-Entropy Method convergence
- **Key Results**: Paper results and hyperparameters

---

## Citation

```bibtex
@article{maes_lelidec2026lewm,
  title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels},
  author={Maes, Lucas and Le Lidec, Quentin and Scieur, Damien and LeCun, Yann and Balestriero, Randall},
  journal={arXiv preprint},
  year={2026}
}
```

---

## License

MIT (same as the official repository).

---

*This implementation is self-contained in standard PyTorch + transformers + einops, with no dependency on the private `stable-pretraining` or `stable-worldmodel` packages for the core model logic.*

## Generated by ML Intern

This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.

- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ar27111994/lewm-implementation"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
```

For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
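
## Appendix: SIGReg Steps as a NumPy Sketch

The six SIGReg steps listed earlier (random unit directions, 1-D projections, an Epps-Pulley style characteristic-function comparison, trapezoid quadrature on [0, 3] with weight exp(-t^2/2)) can be illustrated with a few lines of NumPy. This is a rough sketch for intuition, not the code in `lewm_model.py`. The function name `sigreg_sketch` is hypothetical, the empirical characteristic function is assumed to be averaged over the batch axis at each time step, and any scaling constants the paper or the official code may apply (for example a sample-size factor) are omitted.

```python
import numpy as np


def sigreg_sketch(z, n_proj=1024, n_knots=17, t_max=3.0, seed=0):
    """Illustrative SIGReg-style penalty for latents z of shape (T, B, d).

    Hypothetical helper: averages an Epps-Pulley style discrepancy over
    random 1-D projections. Normalisation constants are an assumption.
    """
    rng = np.random.default_rng(seed)
    d = z.shape[-1]

    # Step 2: random directions, normalised onto the unit hypersphere S^(d-1).
    u = rng.standard_normal((d, n_proj))
    u /= np.linalg.norm(u, axis=0, keepdims=True)

    # Step 3: 1-D projections h^(m) = Z . u^(m), shape (T, B, M).
    h = z @ u

    # Step 5 (nodes): quadrature nodes spaced uniformly in [0, t_max].
    t = np.linspace(0.0, t_max, n_knots)

    # Step 4: empirical characteristic function, averaged over the batch axis
    # (assumption), split into real and imaginary parts. Shape (T, M, K).
    th = h[..., None] * t
    ecf_re = np.cos(th).mean(axis=1)
    ecf_im = np.sin(th).mean(axis=1)

    # Characteristic function of N(0, 1) is exp(-t^2 / 2), purely real.
    target = np.exp(-0.5 * t**2)
    sq_err = (ecf_re - target) ** 2 + ecf_im**2

    # Step 5 (weights): weight w(t) = exp(-t^2 / 2), then a manual trapezoid rule.
    integrand = sq_err * np.exp(-0.5 * t**2)
    dt = t[1] - t[0]
    stat = dt * (
        0.5 * integrand[..., 0]
        + integrand[..., 1:-1].sum(axis=-1)
        + 0.5 * integrand[..., -1]
    )

    # Average over time steps and projections to get a scalar penalty.
    return float(stat.mean())


# Usage: Gaussian latents should score much lower than collapsed (constant) ones.
demo_rng = np.random.default_rng(1)
gaussian_latents = demo_rng.standard_normal((2, 256, 16))
collapsed_latents = np.zeros((2, 256, 16))
print("gaussian :", sigreg_sketch(gaussian_latents, n_proj=128))
print("collapsed:", sigreg_sketch(collapsed_latents, n_proj=128))
```

In this sketch a collapsed encoder (all latents identical) gives an empirical characteristic function of 1 at every node, which is far from exp(-t^2/2), so the penalty is large. Latents that are already close to N(0, I) leave only sampling noise in the comparison. In training this scalar would be multiplied by lambda and added to the prediction MSE, as in the loss shown in the architecture diagram.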