LeWorldModel (LeWM): Stable End-to-End JEPA from Pixels

This repository contains a clean, self-contained PyTorch implementation of LeWorldModel (LeWM) from the paper:

LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels
Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero
arXiv: 2603.19312 - https://arxiv.org/abs/2603.19312
Official repo: https://github.com/lucas-maes/le-wm


🚀 Quick Start: Free GPU Training on Google Colab

The easiest way to train LeWM on a free GPU is via our Colab-ready notebook:

📓 Open in Colab (upload lewm_colab.ipynb from this repo)

Or read the step-by-step guide: 📖 COLAB_GUIDE.md

What you need:

  • A Google account (free)
  • ~30–60 minutes for 10 epochs on synthetic data
  • Optional: Hugging Face token (free) to push trained models

Hardware: Free Colab T4 GPU (15 GB VRAM); LeWM's ~18M parameters fit comfortably.
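A back-of-envelope check of the "fits comfortably" claim, assuming fp32 weights, fp32 gradients and the two fp32 AdamW moment buffers. Activations, which dominate memory at batch size 128 and 224x224 inputs, are deliberately not counted here; the helper name is hypothetical.

```python
def training_state_bytes(n_params: int, bytes_per_value: int = 4) -> int:
    """Bytes for weights + gradients + AdamW exp_avg + exp_avg_sq (4 copies)."""
    copies = 4
    return n_params * bytes_per_value * copies


n_params = 18_000_000  # approximate LeWM size quoted above
gib = training_state_bytes(n_params) / 2**30
print(f"{gib:.2f} GiB of optimizer-related state")  # 0.27 GiB of optimizer-related state
```

Roughly a quarter of a GiB, so nearly all of the T4's memory is left for activations.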


What is LeWorldModel?

LeWorldModel (LeWM) is a Joint-Embedding Predictive Architecture (JEPA) world model that learns directly from raw pixels with a single tunable hyperparameter. The paper presents it as the first end-to-end JEPA that trains stably without:

  • Stop-gradient / EMA mechanisms
  • Pre-trained encoders (e.g., DINOv2)
  • Complex multi-term losses (e.g., VICReg variants)

Key Innovations

| Feature | LeWM | Prior work (PLDM) |
|---|---|---|
| Loss terms | 2 (prediction + SIGReg) | 7 (prediction + 6 regularizers) |
| Tunable hyperparameters | 1 (lambda) | 6 (grid search O(n^6)) |
| End-to-end trainable | Yes | Partial (fragile) |
| Planning speed | 48x faster than DINO-WM | Comparable |
| Params | ~18M | Similar |
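To make the O(n^6) entry concrete: a full grid over k hyperparameters with n candidate values each costs n^k training runs. The choice of n = 5 below is purely illustrative.

```python
def grid_size(values_per_hparam: int, n_hparams: int) -> int:
    """Number of runs for an exhaustive grid search."""
    return values_per_hparam ** n_hparams


n = 5  # hypothetical number of candidate values per hyperparameter
print(grid_size(n, 1), grid_size(n, 6))  # 5 15625
```

Tuning only lambda costs 5 runs in this setting, versus 15,625 for six coupled hyperparameters.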

Architecture (from paper section 3.1 & Appendix D)

Raw Pixels (224x224) ---> ViT-Tiny Encoder ---> [CLS] + MLP+BN ---> Latent z_t
                                              |    (192-dim)
                                              |
                                              v
                                    +-------------------+
                                    |  AR Predictor     |  <--- Actions (AdaLN-zero)
                                    |  6 layers, 16h    |
                                    |  Causal masking   |
                                    +-------------------+
                                              |
                                              v
                                    Predicted z_{t+1}
                                              |
                                              v
                                    MSE(z_{t+1}, pred) + lambda * SIGReg(z)
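The diagram combines two ideas: causal masking, so the prediction at step t only sees latents up to t, and a next-step MSE target plus the lambda-weighted SIGReg term. The NumPy sketch below illustrates both under stated assumptions: a single unprojected attention head stands in for the predictor, predictions are assumed to be aligned so that `pred[:, t]` targets `z[:, t + 1]`, and the SIGReg value is passed in precomputed. None of these function names come from the repository.

```python
import numpy as np


def causal_mask(t: int) -> np.ndarray:
    """True where attention is allowed: step i may attend to steps j <= i."""
    return np.tril(np.ones((t, t), dtype=bool))


def masked_self_attention(x: np.ndarray) -> np.ndarray:
    """Single-head attention without projections over a (T, d) sequence."""
    t, d = x.shape
    scores = x @ x.T / np.sqrt(d)
    scores = np.where(causal_mask(t), scores, -np.inf)  # block future steps
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


def next_step_mse(pred: np.ndarray, z: np.ndarray) -> float:
    """pred, z: (B, T, d). pred[:, t] is compared with z[:, t + 1];
    the last prediction has no target and is dropped (assumed alignment)."""
    return float(np.mean((pred[:, :-1] - z[:, 1:]) ** 2))


def lewm_objective(pred: np.ndarray, z: np.ndarray, sigreg_value: float,
                   lambd: float = 0.1) -> float:
    """MSE(z_{t+1}, pred) + lambda * SIGReg(z), as in the diagram."""
    return next_step_mse(pred, z) + lambd * sigreg_value
```

Because of the mask, perturbing the last frame's latent leaves every earlier output unchanged, which is what makes teacher-forced training on a whole sequence equivalent to step-by-step prediction.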

Components:

  • Encoder: ViT-Tiny (patch 14, 12 layers, 3 heads, hidden 192) -> [CLS] token -> MLP + BatchNorm1d projector
  • Predictor: 6-layer transformer with AdaLN-zero action conditioning, causal temporal masking
  • SIGReg: Sketch Isotropic Gaussian Regularizer - anti-collapse via Epps-Pulley test on random 1-D projections
  • Planner: Cross-Entropy Method (CEM) in latent space for goal-conditioned control
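AdaLN-zero conditioning, as popularized by DiT, regresses a shift, scale and gate from the conditioning vector and zero-initializes that regression, so each conditioned residual branch starts as the identity. The sketch below applies the recipe to per-step action vectors. Which sublayers the repo conditions, and whether actions are embedded before modulation, are assumptions here; `tanh(h @ w_sub)` is only a stand-in for the attention or MLP sublayer.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    return x / (1.0 + np.exp(-x))


def layer_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """LayerNorm over the last axis, without affine parameters."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


class AdaLNZeroBranch:
    """One action-conditioned residual branch (hypothetical, DiT-style)."""

    def __init__(self, d: int, action_dim: int, rng: np.random.Generator):
        # Zero-initialized modulation: shift = scale = gate = 0 at init.
        self.w_mod = np.zeros((action_dim, 3 * d))
        self.b_mod = np.zeros(3 * d)
        # Stand-in for the attention/MLP sublayer weights.
        self.w_sub = rng.standard_normal((d, d)) / np.sqrt(d)

    def __call__(self, x: np.ndarray, a: np.ndarray) -> np.ndarray:
        """x: (B, T, d) latents, a: (B, T, action_dim) actions."""
        shift, scale, gate = np.split(silu(a) @ self.w_mod + self.b_mod, 3, axis=-1)
        h = layer_norm(x) * (1.0 + scale) + shift
        return x + gate * np.tanh(h @ self.w_sub)
```

The zero gate means a freshly initialized predictor passes latents through unchanged, and action information enters gradually as the modulation weights move away from zero.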

SIGReg: The Anti-Collapse Engine

SIGReg is the critical component that makes stable end-to-end training possible.

Problem: Prediction-only loss causes representation collapse (encoder maps everything to a constant).

Solution: SIGReg forces latent embeddings to match an isotropic Gaussian N(0, I).

How it works:

  1. Collect latent tensor Z in R^(TxBxd) (time x batch x dim)
  2. Sample M=1024 random unit-norm directions u^(m) on the hypersphere S^(d-1)
  3. Project: h^(m) = Z dot u^(m) -> (T, B) 1-D marginals
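  4. Compute the Epps-Pulley test statistic T(h^(m)): a weighted L2 distance between the empirical characteristic function of h^(m) and the characteristic function of N(0, 1), exp(-t^2/2)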
  5. Trapezoid quadrature on nodes uniformly in [0, 3] with weighting w(t) = exp(-t^2/2)
  6. By the Cramer-Wold theorem: matching all 1-D marginals <=> matching the full joint distribution

Key insight: The projector uses BatchNorm1d (not LayerNorm) because the ViT final layer already applies LayerNorm - this is essential for SIGReg optimization.
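The difference between the two normalizations is easy to see numerically. Training-mode BatchNorm1d standardizes each feature across the batch, while LayerNorm standardizes each sample across its features and leaves per-feature batch statistics free. Reading this as the reason BatchNorm suits SIGReg's per-dimension N(0, I) target is an interpretation, not a statement from the source. Both helpers omit the learnable affine parameters.

```python
import numpy as np


def batch_norm_train(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Training-mode BatchNorm1d without affine params: per feature, over the batch."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm without affine params: per sample, over the features."""
    return (x - x.mean(axis=1, keepdims=True)) / np.sqrt(x.var(axis=1, keepdims=True) + eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((256, 8))  # (batch, features)
x[:, 0] += 5.0                     # one feature with a large offset

bn, ln = batch_norm_train(x), layer_norm(x)
print(np.round(bn.mean(axis=0), 3))  # every feature centred near 0
print(np.round(ln.mean(axis=0), 3))  # feature 0 stays far from 0
```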


Training

Free GPU Training (Google Colab T4)

# In a Colab notebook with GPU runtime enabled:
!pip install -q transformers einops huggingface_hub matplotlib numpy tqdm

# Mount Google Drive so checkpoints written to --output_dir persist after the session ends
from google.colab import drive
drive.mount("/content/drive")

# Download implementation
from huggingface_hub import hf_hub_download
hf_hub_download("ar27111994/lewm-implementation", "lewm_model.py", local_dir="/content")
hf_hub_download("ar27111994/lewm-implementation", "lewm_train.py", local_dir="/content")

# Train with synthetic data (no 12GB download needed)
!python /content/lewm_train.py --use_synthetic \
    --n_episodes 2000 --epochs 10 --batch_size 128 \
    --lambd 0.1 --history_size 3 --seq_len 4 \
    --frameskip 5 --action_dim 2 --output_dir /content/drive/MyDrive/lewm

See COLAB_GUIDE.md for the full notebook, troubleshooting, and real dataset download instructions.

Synthetic Smoke Test (CPU-friendly)

python lewm_train.py --use_synthetic \
    --n_episodes 2000 --epochs 10 \
    --batch_size 128 --lr 1e-3 \
    --lambd 0.1 --history_size 3 --seq_len 4 \
    --frameskip 5 --action_dim 2

Real PushT Dataset

  1. Download the official dataset (local_dir='.' saves it to the current directory instead of the Hugging Face cache):
python -c "from huggingface_hub import hf_hub_download; \
    hf_hub_download('quentinll/lewm-pusht', 'pusht_expert_train.h5.zst', repo_type='dataset', local_dir='.')"
  2. Decompress and train:
zstd -d pusht_expert_train.h5.zst
python lewm_train.py \
    --h5_path pusht_expert_train.h5 \
    --epochs 10 --batch_size 128 \
    --lambd 0.1 --history_size 3

Hyperparameters (from paper)

| Parameter | Value |
|---|---|
| Batch size | 128 |
| Seq length | 4 frames + 4 action blocks |
| Frame skip | 5 |
| Resolution | 224x224 |
| Epochs | 10 |
| Embedding dim | 192 |
| Predictor dropout | 0.1 |
| lambda (SIGReg weight) | 0.1 |
| History length | 3 (PushT, Cube), 1 (TwoRoom) |
| Optimizer | AdamW with cosine schedule |

Only lambda needs tuning: performance is insensitive to the number of projections (M=1024) and to the number of integration knots (17).
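The table can be collected into a single frozen config object, which also makes a few derived quantities explicit. The class and method names below are hypothetical, not the training script's actual config. `action_block_dim` rests on an assumption: that one action block concatenates `frameskip` raw actions, which would match the 2 x 5 = 10 `action_dim` used in the planning example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LeWMConfig:
    """Hypothetical container for the paper hyperparameters listed above."""
    batch_size: int = 128
    seq_len: int = 4            # frames per training sequence
    frameskip: int = 5
    resolution: int = 224
    patch_size: int = 14        # ViT-Tiny patch size from the Components list
    epochs: int = 10
    embed_dim: int = 192
    predictor_dropout: float = 0.1
    lambd: float = 0.1          # SIGReg weight, the only hyperparameter to tune
    history_size: int = 3       # 1 for TwoRoom
    n_projections: int = 1024   # SIGReg directions M
    n_knots: int = 17           # quadrature nodes

    @property
    def env_steps_per_sequence(self) -> int:
        return self.seq_len * self.frameskip

    @property
    def patches_per_frame(self) -> int:
        return (self.resolution // self.patch_size) ** 2

    def action_block_dim(self, raw_action_dim: int) -> int:
        # ASSUMPTION: one action block concatenates `frameskip` raw actions.
        return raw_action_dim * self.frameskip


cfg = LeWMConfig()
print(cfg.env_steps_per_sequence, cfg.patches_per_frame, cfg.action_block_dim(2))  # 20 256 10
```

Freezing the dataclass keeps a run's settings immutable; variants such as TwoRoom are created as new instances, e.g. `LeWMConfig(history_size=1)`.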


Planning with CEM

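import torch
from lewm_model import build_lewm, cem_plan

model = build_lewm(action_dim=10, history_size=3)
# ... load trained weights ...
model.eval()  # assumes build_lewm returns a torch.nn.Module (puts BatchNorm in eval mode)

# Placeholder inputs so the snippet runs; replace with real preprocessed frames.
# H = history_size context frames, C = 3 RGB channels.
context_frames = torch.zeros(1, 3, 3, 224, 224)  # (1, H, C, 224, 224)
goal_frame = torch.zeros(1, 1, 3, 224, 224)      # (1, 1, C, 224, 224)

best_actions = cem_plan(
    model,
    initial_pixels=context_frames,
    goal_pixels=goal_frame,
    action_dim=10,
    horizon=5,        # 5 latent steps = 25 env steps (frame_skip=5)
    n_samples=300,
    n_iters=30,       # 30 for PushT, 10 for others
    n_elites=30,
    history_size=3,
)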

Results (from paper)

| Method | PushT success rate | Planning time |
|---|---|---|
| LeWM | 96.0 +/- 2.8% | <1 sec |
| DINO-WM | 92.0 +/- 1.6% | ~48x slower |
| PLDM | 78.0 +/- 5.0% | Comparable |

  • 48x faster planning than DINO-WM due to ~200x fewer tokens in latent space
  • Single GPU (L40S) training in "a few hours"
  • No stop-gradient, no EMA, no pre-trained encoders

Project Structure

lewm_model.py         - Core model (Encoder, Predictor, SIGReg, CEM)
lewm_train.py         - Training script (HDF5 + synthetic datasets)
lewm_mini_test.py     - Minimal sanity check
lewm_colab.ipynb      - Full Colab-ready training notebook
COLAB_GUIDE.md        - Step-by-step free GPU training guide
config.json           - Verified architecture config from official model
EXPLANATION.md        - 16KB deep-dive technical explanation

Interactive Demo

Try the explainable interactive Space (no training required):

🔗 https://huggingface.co/spaces/ar27111994/lewm-explainable

Features:

  • Architecture tab: Full pipeline schematic
  • SIGReg Explorer: Adjust collapse level and see real-time distributional analysis
  • CEM Planning: Visualize Cross-Entropy Method convergence
  • Key Results: Paper results and hyperparameters

Citation

@article{maes_lelidec2026lewm,
  title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels},
  author={Maes, Lucas and Le Lidec, Quentin and Scieur, Damien and LeCun, Yann and Balestriero, Randall},
  journal={arXiv preprint arXiv:2603.19312},
  year={2026}
}

License

MIT (same as the official repository).


This implementation is self-contained in standard PyTorch + transformers + einops, with no dependency on the private stable-pretraining or stable-worldmodel packages for the core model logic.

Generated by ML Intern

This model repository was generated by ML Intern, an agent for machine learning research and development on the Hugging Face Hub.

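Usage

LeWM is a pixel-based world model, not a transformers causal language model, so the AutoModelForCausalLM / AutoTokenizer template does not apply. Load the model code from this repository instead:

from huggingface_hub import hf_hub_download

# Download the model definition so it is importable from the current directory
hf_hub_download("ar27111994/lewm-implementation", "lewm_model.py", local_dir=".")

from lewm_model import build_lewm

model = build_lewm(action_dim=10, history_size=3)
# ... load trained weights ...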
