| # LeWorldModel (LeWM) — Detailed Technical Explanation |
|
|
| **Paper**: *LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels* |
| **Authors**: Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero |
| **arXiv**: [2603.19312](https://arxiv.org/abs/2603.19312) |
| **Official Repo**: [github.com/lucas-maes/le-wm](https://github.com/lucas-maes/le-wm) |
|
|
| --- |
|
|
| ## Table of Contents |
|
|
| 1. [What is JEPA?](#1-what-is-jepa) |
| 2. [The Representation Collapse Problem](#2-the-representation-collapse-problem) |
| 3. [LeWorldModel Architecture](#3-leworldmodel-architecture) |
| 4. [SIGReg: The Mathematical Core](#4-sigreg-the-mathematical-core) |
| 5. [Training Objective](#5-training-objective) |
| 6. [Latent Planning with CEM](#6-latent-planning-with-cem) |
| 7. [Why It Works: Key Design Decisions](#7-why-it-works-key-design-decisions) |
| 8. [Results & Comparisons](#8-results--comparisons) |
| 9. [Implementation Notes](#9-implementation-notes) |
| 10. [References](#10-references) |
|
|
| --- |
|
|
| ## 1. What is JEPA? |
|
|
| **Joint-Embedding Predictive Architecture (JEPA)** is a learning framework introduced by Yann LeCun for learning **world models** — internal predictive models of the environment that enable an agent to plan and reason. |
|
|
| Unlike generative models (which reconstruct pixel-level observations), JEPA learns to predict in a **compact latent space**: |
|
|
| ``` |
| Traditional generative model: |
|   obs_t ──► [encoder] ──► reconstruct obs_{t+1}   (pixel loss) |
|
| JEPA: |
|   obs_t ──► [encoder] ──► z_t ──► predict z_{t+1} from (z_t, a_t)   (latent prediction loss) |
| ``` |
|
|
| **Key advantages of JEPA over reconstruction-based models**: |
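| - **Efficiency**: each frame is summarized by a single 192-dim latent vector instead of a grid of patch tokens (16×16 = 256 tokens for a 224×224 input at patch size 14), i.e. more than 200× fewer tokens |
| - **Robustness**: the latent space can discard visual detail that is irrelevant for prediction |
| - **Planning speed**: ~48× faster planning than DINO-WM (see [Section 8](#8-results--comparisons)) |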
|
|
| --- |
|
|
| ## 2. The Representation Collapse Problem |
|
|
| The fundamental challenge in end-to-end JEPA training is **representation collapse**. |
|
|
| ### What is Collapse? |
|
|
| If the loss only penalizes prediction error in latent space, the encoder learns a **trivial solution**: map all inputs to a constant vector. The predictor then trivially learns to predict that same constant, achieving zero prediction loss. |
|
|
| ```python |
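| # Collapsed solution: encoder and predictor both ignore their inputs |
| def encoder(x): |
|     return 0.0                      # encoder(x) = 0 for all x |
| def predictor(z, a): |
|     return 0.0                      # predictor(0, a) = 0 for all a |
| loss = (predictor(encoder("obs_t"), "a_t") - encoder("obs_t+1")) ** 2 |
| print(loss)                         # 0.0: perfect prediction, but a useless representation |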
| ``` |
|
|
| ### How Prior Work Handled It |
|
|
| | Method | Anti-Collapse Strategy | Drawback | |
| |--------|----------------------|----------| |
| | **I-JEPA** | Stop-gradient + EMA target network | Not end-to-end | |
| | **DINO-WM** | Frozen DINOv2 pre-trained encoder | Relies on large-scale pretraining | |
| | **PLDM** | 7-term VICReg-inspired loss | Fragile, 6 hyperparameters | |
| | **LeWM** | **SIGReg** — single statistic | **1 hyperparameter, stable** | |
|
|
| --- |
|
|
| ## 3. LeWorldModel Architecture |
|
|
| LeWM consists of four main components: |
|
|
| ### 3.1 Encoder: ViT-Tiny |
|
|
| ``` |
| Input: (B, T, 3, 224, 224) raw pixel frames |
| ↓ |
| ViT-Tiny (patch=14, 12 layers, 3 heads, hidden=192) |
| ↓ |
| [CLS] token from last layer: (B*T, 192) |
| ↓ |
| Projector: Linear(192→2048) → BatchNorm1d → GELU → Linear(2048→192) |
| ↓ |
| Latent embedding z_t: (B, T, 192) |
| ``` |
|
|
| **Critical design**: The projector uses **BatchNorm1d** (not LayerNorm). The ViT's final layer already applies LayerNorm. Adding another LayerNorm would block SIGReg from optimizing because: |
| - LayerNorm normalizes within the sample, destroying batch-level statistics that SIGReg measures |
| - BatchNorm normalizes across the batch, preserving the distributional properties SIGReg optimizes |
|
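The difference between the two normalizations can be checked numerically. The following is a minimal NumPy sketch on made-up activations (the helpers `batch_norm` and `layer_norm` are illustrative stand-ins without learnable affine terms, not the official modules): batch normalization fixes each feature's statistics across the batch, whereas layer normalization leaves them unconstrained.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy projector activations: 256 samples, 8 features with very different scales.
x = rng.normal(size=(256, 8)) * np.arange(1, 9) + 5.0

def batch_norm(x, eps=1e-5):
    """Normalize each feature across the batch (BatchNorm1d without affine terms)."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

def layer_norm(x, eps=1e-5):
    """Normalize each sample across its own features (LayerNorm without affine terms)."""
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

bn, ln = batch_norm(x), layer_norm(x)
bn_feature_std = bn.std(axis=0)   # statistics across the batch, per feature
ln_feature_std = ln.std(axis=0)

print(np.round(bn_feature_std, 2))  # every feature has unit spread across the batch
print(np.round(ln_feature_std, 2))  # spread across the batch still differs per feature
```

Because SIGReg compares the batch distribution of each projection with N(0, 1), the quantity printed on the first line is the one that matters for it.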
|
| ### 3.2 Action Encoder |
|
|
| ``` |
| Action a_t ∈ ℝ^A (e.g., A=2 for PushT: dx, dy) |
| ↓ |
| Group 5 consecutive actions (frameskip=5) |
| a_block = concat(a_t, ..., a_{t+4}) ∈ ℝ^(5×A) = ℝ^10 for PushT |
| ↓ |
| Conv1d(10→10, kernel=1) + MLP(10→768→192) |
| ↓ |
| Action embedding act_emb_t: (B, T, 192) |
| ``` |
|
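Since the block dimension is 5×A (10 for PushT), the grouping step amounts to concatenating five consecutive actions into one vector. A minimal NumPy sketch follows; the function name `group_actions` and the choice to drop a trailing partial block are assumptions made for illustration, not details taken from the official code.

```python
import numpy as np

def group_actions(actions, frameskip=5):
    """Stack each run of `frameskip` consecutive actions into one block.

    actions: (num_env_steps, A) -> (num_env_steps // frameskip, frameskip * A)
    """
    steps, action_dim = actions.shape
    usable = (steps // frameskip) * frameskip      # drop a trailing partial block
    return actions[:usable].reshape(-1, frameskip * action_dim)

# PushT-style toy example: 20 environment steps of (dx, dy) actions.
actions = np.arange(40, dtype=np.float32).reshape(20, 2)
blocks = group_actions(actions)

print(blocks.shape)   # (4, 10)
print(blocks[0])      # the first five (dx, dy) pairs, concatenated
```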
|
| ### 3.3 Predictor: Autoregressive Transformer with AdaLN-zero |
|
|
| ``` |
| Input: z_t (history of N=3 frames) + act_emb_t |
| ↓ |
| Add learned positional embeddings |
| ↓ |
| 6-layer Transformer, 16 heads, causal temporal masking |
| ↓ |
| Each layer uses AdaLN-zero conditioned on action embeddings: |
| shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp |
| ↓ |
| Projector (same as encoder): BatchNorm1d |
| ↓ |
| Predicted z_{t+1}: (B, N, 192) |
| ``` |
|
|
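| **AdaLN-zero initialization**: The layer that produces the modulation parameters (shift, scale, gate) is initialized to **zero**, so every gate starts at zero, each block initially acts as an identity map on its residual stream, and actions have no effect. As training progresses, the network gradually learns to use action information, which prevents early training instability. |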
|
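The effect of the zero initialization can be seen in a toy block. This is a hedged NumPy sketch with a single MLP sub-layer: the class `AdaLNZeroBlock`, its sizes, and the `tanh` sub-layer are hypothetical simplifications, not the predictor's actual layer.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

class AdaLNZeroBlock:
    """Toy residual block whose shift/scale/gate come from a conditioning vector."""

    def __init__(self, dim, rng):
        self.w_mlp = rng.normal(scale=0.1, size=(dim, dim))   # ordinary sub-layer
        # Modulation head: zero-initialized, so shift = scale = gate = 0 at start.
        self.w_mod = np.zeros((dim, 3 * dim))
        self.b_mod = np.zeros(3 * dim)

    def __call__(self, z, cond):
        shift, scale, gate = np.split(cond @ self.w_mod + self.b_mod, 3, axis=-1)
        h = layer_norm(z) * (1.0 + scale) + shift              # adaptive LayerNorm
        return z + gate * np.tanh(h @ self.w_mlp)              # gated residual branch

rng = np.random.default_rng(0)
block = AdaLNZeroBlock(dim=8, rng=rng)
z = rng.normal(size=(4, 8))          # latent tokens
act_emb = rng.normal(size=(4, 8))    # action embeddings used as conditioning

out_at_init = block(z, act_emb)
print(np.allclose(out_at_init, z))   # True: the block is an identity map at init

block.w_mod = rng.normal(scale=0.1, size=block.w_mod.shape)  # stand-in for training
out_trained = block(z, act_emb)
print(np.allclose(out_trained, z))   # False: actions now modulate the output
```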
|
| **Causal masking**: The predictor can only attend to past and current time steps, not future ones. This is enforced via PyTorch's `scaled_dot_product_attention(..., is_causal=True)`. |
|
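To illustrate what the causal mask guarantees, here is a small single-head attention written in NumPy rather than PyTorch (an illustrative re-implementation, not the library kernel): perturbing a later time step leaves the outputs at earlier steps unchanged.

```python
import numpy as np

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal (lower-triangular) mask."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)   # True above the diagonal
    scores = np.where(future, -np.inf, scores)           # hide future positions
    scores -= scores.max(axis=-1, keepdims=True)         # stabilise the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))          # 4 time steps, 8 features
out = causal_attention(x, x, x)

x_changed = x.copy()
x_changed[3] += 10.0                 # perturb only the last time step
out_changed = causal_attention(x_changed, x_changed, x_changed)

print(np.allclose(out[:3], out_changed[:3]))   # True: earlier steps are unaffected
print(np.allclose(out[3], out_changed[3]))     # False: the perturbed step itself changes
```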
|
| ### 3.4 SIGReg: Anti-Collapse Regularizer |
|
|
| See [Section 4](#4-sigreg-the-mathematical-core) for the full mathematical derivation. |
|
|
| --- |
|
|
| ## 4. SIGReg: The Mathematical Core |
|
|
| ### 4.1 Problem Formulation |
|
|
| SIGReg forces the distribution of latent embeddings to match an **isotropic Gaussian N(0, I)**. |
|
|
| Given a batch of latent embeddings **Z ∈ ℝ^(T×B×d)** (time × batch × dimension), we want: |
|
|
| ``` |
| P_Z ≈ N(0, I_d) |
| ``` |
|
|
| ### 4.2 Why Not Just Use KL Divergence? |
|
|
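| Directly estimating a divergence such as KL between the empirical d-dimensional distribution and N(0, I) requires a density estimate, which is statistically difficult in high dimensions (d > 100); multivariate normality tests are unreliable in the same regime. SIGReg instead reduces the problem to many 1-D tests, as described next. |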
|
|
| ### 4.3 The Cramér-Wold Theorem |
|
|
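| **Theorem (Cramér-Wold)**: The distribution of a random vector X ∈ ℝ^d is uniquely determined by the distributions of its 1-D projections u^T X. In particular, X ~ N(0, I) if and only if u^T X ~ N(0, 1) for every unit vector u ∈ S^(d-1). |
|
|
| This means: **matching every 1-D projection = matching the full joint distribution**. SIGReg approximates "every" with a finite set of random directions. |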
|
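A quick numerical illustration of the "only if" direction, as a sketch on synthetic data (the sample size, dimension, and the helper `random_directions` are arbitrary choices): projections of isotropic Gaussian samples all have variance close to 1, while a Gaussian with the wrong covariance is exposed by its 1-D projections.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 20000, 16

def random_directions(d, m, rng):
    """Draw m unit vectors uniformly on the sphere S^(d-1)."""
    u = rng.normal(size=(d, m))
    return u / np.linalg.norm(u, axis=0, keepdims=True)

iso = rng.normal(size=(n, d))            # samples from N(0, I)
aniso = iso * np.linspace(0.2, 2.0, d)   # Gaussian, but not isotropic

U = random_directions(d, 64, rng)
iso_var = (iso @ U).var(axis=0)          # variance of each 1-D projection
aniso_var = (aniso @ U).var(axis=0)

print(np.round([iso_var.min(), iso_var.max()], 2))      # both close to 1
print(np.round([aniso_var.min(), aniso_var.max()], 2))  # spread away from 1
```

Variance is only one symptom; the Epps-Pulley statistic in the next subsection compares the whole 1-D distribution, not just its second moment.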
|
| ### 4.4 The Epps-Pulley Test |
|
|
| For a 1-D sample h = (h_1, ..., h_N), the Epps-Pulley test statistic measures how far h is from N(0,1) using **characteristic functions**: |
|
|
| ``` |
| T(h) = ∫ w(t) |φ_N(t; h) - φ_0(t)|² dt |
| ``` |
|
|
| Where: |
| - **φ_N(t; h)** = (1/N) Σ_{n=1}^N exp(i·t·h_n) is the empirical characteristic function (ECF) |
| - **φ_0(t)** = exp(-t²/2) is the standard Gaussian CF |
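| - **w(t)** = exp(-t²/(2λ²)) is a weighting function (a Gaussian window with bandwidth λ; this λ is unrelated to the loss weight λ in Section 5, and the quadrature in Section 4.6 uses λ = 1) |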
|
|
| **Intuition**: The ECF is the Fourier transform of the empirical distribution. If the sample is Gaussian, its ECF should match the Gaussian CF at all frequencies t. |
|
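The comparison of characteristic functions can be tried directly. In this NumPy sketch the helper `ecf`, the frequency grid, and the uniform sample used as a non-Gaussian counterexample are all illustrative assumptions:

```python
import numpy as np

def ecf(h, t):
    """Empirical characteristic function of sample h at frequencies t."""
    return np.exp(1j * np.outer(t, h)).mean(axis=1)   # shape (len(t),)

rng = np.random.default_rng(0)
t = np.linspace(0.0, 3.0, 7)
phi_0 = np.exp(-t**2 / 2)                     # CF of N(0, 1)

gaussian = rng.normal(size=5000)
uniform = rng.uniform(-1.0, 1.0, size=5000)   # clearly non-Gaussian sample

gap_gaussian = np.abs(ecf(gaussian, t) - phi_0).max()
gap_uniform = np.abs(ecf(uniform, t) - phi_0).max()

print(gap_gaussian < gap_uniform)   # True: the Gaussian sample tracks phi_0 far more closely
```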
|
| ### 4.5 SIGReg Algorithm |
|
|
| ```python |
| import torch |
|
| def SIGReg(Z, M=1024): |
|     """Average Epps-Pulley statistic over M random 1-D projections of Z (T, B, d).""" |
|     d = Z.shape[-1] |
|
|     # 1. Sample M random directions on the unit sphere |
|     U = torch.randn(d, M, device=Z.device, dtype=Z.dtype) |
|     U = U / U.norm(dim=0, keepdim=True)        # (d, M), unit-norm columns |
|
|     # 2. Project Z onto each direction |
|     H = Z @ U                                  # (T, B, M) |
|
|     # 3. Epps-Pulley statistic per projection (epps_pulley_integral: Section 4.6) |
|     T_values = [epps_pulley_integral(H[:, :, m]) for m in range(M)] |
|
|     # 4. Average over all projections |
|     return torch.stack(T_values).mean() |
| ``` |
|
|
| ### 4.6 Numerical Integration |
|
|
| The integral is approximated via **trapezoid quadrature**: |
|
|
| ```python |
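| import torch |
|
| def epps_pulley_integral(h, knots=17): |
|     h = h.reshape(-1)                          # treat all entries as one 1-D sample |
|     N = h.numel() |
|     t = torch.linspace(0, 3, knots, dtype=h.dtype, device=h.device)  # integration nodes |
|     dt = 3 / (knots - 1) |
|
|     # Weights [dt, 2*dt, ..., 2*dt, dt]: the trapezoid rule on [0, 3], doubled. |
|     # The integrand is even in t, so doubling covers the interval [-3, 3]. |
|     weights = torch.full((knots,), 2 * dt, dtype=h.dtype, device=h.device) |
|     weights[[0, -1]] = dt |
|
|     w = torch.exp(-t**2 / 2)                   # Gaussian window |
|     phi_0 = torch.exp(-t**2 / 2)               # CF of N(0, 1) |
|
|     # ECF at each node: phi_N(t_j) = mean(cos(t_j h)) + i * mean(sin(t_j h)) |
|     x = h[:, None] * t                         # (N, knots) |
|     err = (x.cos().mean(0) - phi_0) ** 2 + x.sin().mean(0) ** 2   # squared modulus of phi_N - phi_0 |
|
|     # Integrate and scale by the sample size N |
|     return (err * weights * w).sum() * N |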
| ``` |
|
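Putting the projection and quadrature steps together, the following is a vectorized NumPy sketch of the regularizer for a single batch of shape (N, d). It is an illustration under stated assumptions (NumPy instead of PyTorch, 256 directions, a fixed seed, no time axis), not the official implementation. It shows that an isotropic Gaussian batch scores much lower than a collapsed or anisotropic one.

```python
import numpy as np

def sigreg(z, num_proj=256, knots=17, rng=None):
    """NumPy sketch of SIGReg for a batch z of shape (N, d)."""
    rng = np.random.default_rng(0) if rng is None else rng
    n, d = z.shape
    u = rng.normal(size=(d, num_proj))
    u /= np.linalg.norm(u, axis=0, keepdims=True)      # unit directions
    t = np.linspace(0.0, 3.0, knots)
    dt = 3.0 / (knots - 1)
    weights = np.full(knots, 2.0 * dt)                 # [dt, 2dt, ..., 2dt, dt]
    weights[[0, -1]] = dt
    window = np.exp(-t**2 / 2)                         # Gaussian window w(t)
    phi_0 = np.exp(-t**2 / 2)                          # CF of N(0, 1)

    x = (z @ u)[..., None] * t                         # (N, num_proj, knots)
    # Squared modulus of phi_N - phi_0, split into real (cos) and imaginary (sin) parts
    err = (np.cos(x).mean(axis=0) - phi_0) ** 2 + np.sin(x).mean(axis=0) ** 2
    stat = (err * weights * window).sum(axis=-1) * n   # one value per direction
    return float(stat.mean())

rng = np.random.default_rng(1)
gaussian = rng.normal(size=(512, 32))
collapsed = np.zeros((512, 32))                        # every embedding identical
anisotropic = gaussian * np.linspace(0.1, 3.0, 32)     # wrong covariance

for name, z in [("gaussian", gaussian), ("collapsed", collapsed), ("anisotropic", anisotropic)]:
    print(name, round(sigreg(z), 3))
```

The collapsed batch is exactly the failure mode from Section 2, so its large score is what pushes the encoder away from the trivial solution.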
|
| ### 4.7 Why SIGReg is Brilliant |
|
|
| 1. **Dimension-independent**: Works in any embedding dimension (tested up to 192+ in the paper) |
| 2. **No mode collapse**: Unlike VICReg-style variance/covariance losses, SIGReg directly measures distribution matching |
| 3. **Scalable**: Random projection sketching costs O(M·T·B·d), where M=1024 is fixed |
| 4. **Insensitive to hyperparameters**: Performance is stable across M ∈ [256, 4096] and knots ∈ [10, 50] |
|
|
| --- |
|
|
| ## 5. Training Objective |
|
|
| ### 5.1 Complete Loss |
|
|
| ``` |
| L_LeWM = L_pred + λ · SIGReg(Z) |
| ``` |
|
|
| Where: |
| - **L_pred** = ||ẑ_{t+1} - z_{t+1}||²₂ (next-embedding prediction, teacher-forcing) |
| - **SIGReg(Z)** = (1/M) Σ_m T(h^(m)) (anti-collapse regularization) |
| - **λ = 0.1** is the only tunable hyperparameter |
|
|
| ### 5.2 Teacher-Forcing |
|
|
| The predictor receives **ground-truth** latent z_t (not its own previous prediction) as input: |
| |
| ```python |
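| # Teacher forcing (pseudocode, tensors shaped (B, T, D)): every prediction is |
| # conditioned on ground-truth latents z, never on the predictor's own outputs |
| loss = 0.0 |
| for t in range(T - 1): |
|     start = max(0, t + 1 - history)            # sliding context window |
|     pred = predictor(z[:, start:t + 1], action_emb[:, start:t + 1]) |
|     loss += MSE(pred[:, -1], z[:, t + 1])      # last position predicts z_{t+1} |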
| ``` |
| |
| This is standard in sequence modeling and prevents error accumulation during training. |
| |
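Error accumulation is easy to see on a toy system. In the sketch below the "latent dynamics" are a 2-D rotation and the "learned model" is a rotation with a slightly wrong angle; both are hypothetical stand-ins chosen so the arithmetic is transparent. Teacher-forced one-step errors stay constant, while open-loop rollout errors grow with the horizon.

```python
import numpy as np

def rotation(theta):
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

A = rotation(0.30)       # true latent dynamics z_{t+1} = A z_t
A_hat = rotation(0.33)   # learned model with a small systematic error

z = [np.array([1.0, 0.0])]
for _ in range(10):
    z.append(A @ z[-1])
z = np.stack(z)                                   # ground-truth latents, shape (11, 2)

# Teacher forcing: each prediction starts from the true z_t.
tf_pred = z[:-1] @ A_hat.T
tf_err = np.linalg.norm(tf_pred - z[1:], axis=1)

# Open-loop rollout: each prediction starts from the previous prediction.
roll = [z[0]]
for _ in range(10):
    roll.append(A_hat @ roll[-1])
roll_err = np.linalg.norm(np.stack(roll)[1:] - z[1:], axis=1)

print(np.round(tf_err[[0, -1]], 3))     # one-step errors stay small
print(np.round(roll_err[[0, -1]], 3))   # rollout error grows with the horizon
```

At planning time the model is rolled out open-loop, which is one reason the closed-loop re-planning described in Section 6.3 matters.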
| ### 5.3 Training Pseudocode (from paper Algorithm 5) |
| |
| ```python |
| import torch.nn.functional as F |
|
| def train_step(batch, model, sigreg, optimizer, lambd=0.1): |
|     obs, actions = batch                          # (B, T, C, H, W), (B, T, A) |
|
|     # Encode |
|     emb = model.encode(obs)                       # (B, T, D) |
|     act_emb = model.encode_actions(actions)       # (B, T, D) |
|
|     # Predict next embeddings: position i of the output predicts emb[:, i + 1] |
|     pred_emb = model.predict(emb[:, :-1], act_emb[:, :-1])   # (B, T-1, D) |
|
|     # Prediction loss against the ground-truth next embeddings |
|     pred_loss = F.mse_loss(pred_emb, emb[:, 1:]) |
|
|     # SIGReg on all embeddings (transpose → (T, B, D)) |
|     sigreg_loss = sigreg(emb.transpose(0, 1)) |
|
|     # Total loss |
|     loss = pred_loss + lambd * sigreg_loss |
|
|     # Backprop through ALL components (end-to-end!) |
|     optimizer.zero_grad() |
|     loss.backward() |
|     optimizer.step() |
|     return loss.detach() |
| ``` |
| |
| **No stop-gradient, no EMA, no momentum encoder, no pre-trained weights.** |
|
|
| --- |
|
|
| ## 6. Latent Planning with CEM |
|
|
| ### 6.1 Why Latent Planning is Fast |
|
|
| | Space | Representation | Size per frame | Reduction vs. 50,176 pixels | |
| |-------|--------------|------------------|---------------| |
| | Pixel | 224×224 RGB | 50,176 pixels (150,528 values) | 1× | |
| | Patch (ViT p=14) | 16×16 patches | 256 patch tokens | ~196× fewer units | |
| | **Latent (LeWM)** | **192-dim vector** | **1 [CLS] token (192 values)** | **~261× fewer values** | |
|
|
| Planning in latent space means the CEM optimizer only needs to evaluate **192-dimensional** state vectors instead of pixel patches. |
|
|
| ### 6.2 Cross-Entropy Method (CEM) |
|
|
| CEM is a sampling-based (zero-order) optimizer that iteratively refines a Gaussian distribution over action sequences. |
|
|
| ```python |
| import torch |
|
| def CEM_plan(model, initial_obs, goal_obs, action_dim, horizon=5, |
|              n_iters=30, n_samples=300, n_elites=30): |
|     # Encode goal |
|     z_goal = model.encode(goal_obs)                    # (D,) |
|
|     # Initialize a diagonal Gaussian over action sequences |
|     mu = torch.zeros(horizon, action_dim) |
|     sigma = torch.ones(horizon, action_dim)            # per-dimension std |
|
|     for _ in range(n_iters):                           # 30 for PushT |
|         # Sample action sequences: (n_samples, horizon, action_dim) |
|         actions = mu + sigma * torch.randn(n_samples, horizon, action_dim) |
|
|         # Roll out in latent space and score the final latent against the goal |
|         costs = torch.stack([ |
|             ((model.rollout(initial_obs, plan)[-1] - z_goal) ** 2).mean() |
|             for plan in actions |
|         ]) |
|
|         # Select elites (the 30 lowest-cost plans) |
|         elites = actions[costs.argsort()[:n_elites]] |
|
|         # Refit distribution (std, to match how sigma is used when sampling) |
|         mu = elites.mean(dim=0) |
|         sigma = elites.std(dim=0) |
|
|     return mu  # Mean of the final distribution, used as the plan |
| ``` |
|
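The same loop can be run end to end on a toy problem. In this NumPy sketch the learned predictor is replaced by a hypothetical additive model (`toy_rollout`, where z_{t+1} = z_t + a_t), so it only demonstrates the optimizer, not LeWM's dynamics; the sample, elite, and iteration counts follow the values quoted above.

```python
import numpy as np

def cem_plan(rollout_fn, z0, z_goal, action_dim, horizon=5,
             n_iters=30, n_samples=300, n_elites=30, seed=0):
    """Cross-entropy method over action sequences with a diagonal Gaussian."""
    rng = np.random.default_rng(seed)
    mu = np.zeros((horizon, action_dim))
    sigma = np.ones((horizon, action_dim))
    for _ in range(n_iters):
        noise = rng.normal(size=(n_samples, horizon, action_dim))
        plans = mu + sigma * noise                          # candidate sequences
        finals = np.stack([rollout_fn(z0, p) for p in plans])
        costs = ((finals - z_goal) ** 2).mean(axis=1)       # terminal latent MSE
        elites = plans[np.argsort(costs)[:n_elites]]        # lowest-cost plans
        mu, sigma = elites.mean(axis=0), elites.std(axis=0)
    return mu

def toy_rollout(z0, plan):
    """Hypothetical stand-in for the learned predictor: z_{t+1} = z_t + a_t."""
    return z0 + plan.sum(axis=0)

z0 = np.zeros(2)
z_goal = np.array([1.0, -2.0])
best_plan = cem_plan(toy_rollout, z0, z_goal, action_dim=2)
final_cost = float(((toy_rollout(z0, best_plan) - z_goal) ** 2).mean())
print(final_cost < 1e-2)   # True
```

Because CEM only needs cost evaluations, the predictor does not have to be differentiated through at planning time.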
|
| ### 6.3 Receding-Horizon MPC |
|
|
| In practice, CEM is used in a **Model Predictive Control (MPC)** loop: |
| 1. Plan an action sequence of length H=5 (25 environment steps with frameskip=5) |
| 2. Execute the full sequence |
| 3. Re-observe the environment |
| 4. Re-plan from the new state |
|
|
| This closed-loop execution compensates for model prediction errors. |
|
|
| --- |
|
|
| ## 7. Why It Works: Key Design Decisions |
|
|
| ### 7.1 BatchNorm1d in the Projector (Not LayerNorm) |
|
|
| **Problem**: ViT already ends with LayerNorm. Adding another LayerNorm in the projector means: |
| - Each sample's embedding is normalized independently |
| - Batch-level statistics (mean, variance) are destroyed |
| - SIGReg, which measures batch-level distributional properties, cannot optimize |
|
|
| **Solution**: Use **BatchNorm1d** in the projector. This normalizes across the batch dimension, preserving the distribution that SIGReg matches against N(0,I). |
|
|
| ### 7.2 Dropout = 0.1 in the Predictor (Not 0.0) |
|
|
| **Ablations in the paper** (Table 8): |
| - dropout=0.0 → PushT success rate: **78%** |
| - dropout=0.1 → PushT success rate: **96%** |
|
|
| This is surprising — a small amount of dropout in the predictor is **essential** for strong performance. The authors hypothesize it acts as a regularizer that prevents the predictor from memorizing training trajectories, forcing the encoder to produce more robust representations. |
|
|
| ### 7.3 AdaLN-zero Initialization |
|
|
| All AdaLN modulation outputs (shift, scale, gate) start at **zero**, so each predictor block initially acts as an identity map and actions have no effect at the start of training. This ensures: |
| - The predictor first learns a good prior over dynamics |
| - Gradually incorporates action information as training progresses |
| - Prevents early training collapse due to noisy action signals |
|
|
| ### 7.4 Frame Skip = 5 |
|
|
| Grouping 5 consecutive actions between frames: |
| - Reduces temporal redundancy (consecutive frames are nearly identical) |
| - Increases effective prediction horizon (4 frames × 5 skip = 20 environment steps) |
| - Makes dynamics more informative (larger state changes between frames) |
|
|
| --- |
|
|
| ## 8. Results & Comparisons |
|
|
| ### 8.1 Planning Performance |
|
|
| | Method | PushT | OGBench-Cube | TwoRoom | Reacher | |
| |--------|-------|-------------|---------|---------| |
| | **LeWM** | **96.0 ± 2.8%** | competitive | lower* | competitive | |
| | DINO-WM | 92.0 ± 1.6% | **best** | better | competitive | |
| | PLDM | 78.0 ± 5.0% | competitive | better | competitive | |
|
|
| \* TwoRoom underperformance: The low intrinsic dimensionality of this simple navigation task makes it difficult for SIGReg to match a high-dimensional isotropic Gaussian — a known limitation. |
|
|
| ### 8.2 Planning Speed |
|
|
| | Method | Time per plan | Relative speed | |
| |--------|--------------|----------------| |
| | **LeWM** | **< 1 second** | **1× (baseline)** | |
| | DINO-WM | ~48 seconds | **48× slower** | |
|
|
| ### 8.3 Physical Understanding (Probing) |
|
|
| LeWM's latent space encodes meaningful physical quantities: |
|
|
| | Property | Linear Probe MSE | Pearson r | |
| |----------|-----------------|-----------| |
| | Agent Location | 0.052 | 0.974 | |
| | Block Location | 0.029 | 0.986 | |
| | Block Angle | 0.187 | 0.902 | |
|
|
| These values match or exceed DINO-WM (which uses the DINOv2 foundation model, pre-trained on 142M images) despite LeWM training from scratch on only 20K PushT episodes. |
|
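For readers unfamiliar with probing, the two metrics in the table can be computed as follows. This is a sketch on synthetic data: the random `latents` stand in for frozen LeWM embeddings and `target` for a physical quantity such as a block coordinate, so the printed numbers are not the paper's results.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 2000, 192
latents = rng.normal(size=(n, d))                      # stand-in for frozen z_t
true_w = rng.normal(size=d) / np.sqrt(d)
target = latents @ true_w + 0.1 * rng.normal(size=n)   # synthetic physical quantity

train, test = slice(0, 1500), slice(1500, None)
X_train = np.hstack([latents[train], np.ones((1500, 1))])   # add a bias column
X_test = np.hstack([latents[test], np.ones((500, 1))])

# Linear probe = ordinary least squares on top of fixed embeddings.
w, *_ = np.linalg.lstsq(X_train, target[train], rcond=None)
pred = X_test @ w

mse = float(np.mean((pred - target[test]) ** 2))
pearson_r = float(np.corrcoef(pred, target[test])[0, 1])
print(round(mse, 3), round(pearson_r, 3))
```

The encoder is never updated during probing, so a low MSE and a high Pearson r indicate that the quantity is linearly decodable from the representation as learned.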
|
| ### 8.4 Violation-of-Expectation |
|
|
| LeWM reliably detects physically implausible events (object teleportation) but is less sensitive to purely visual changes (color changes), demonstrating that its latent space captures **physical structure** rather than just visual appearance. |
|
|
| --- |
|
|
| ## 9. Implementation Notes |
|
|
| ### 9.1 Self-Contained Dependencies |
|
|
| This implementation requires only: |
| - `torch` (PyTorch 2.0+) |
| - `transformers` (for ViTModel) |
| - `einops` (for tensor manipulation) |
| - Other third-party packages: `numpy`, `h5py` |
|
|
| No dependency on the private `stable-pretraining` or `stable-worldmodel` packages. |
|
|
| ### 9.2 Verified Against Official Implementation |
|
|
| The architecture was cross-checked against the official model config: |
| - [quentinll/lewm-pusht/config.json](https://huggingface.co/quentinll/lewm-pusht/blob/main/config.json) |
| - GitHub source: [github.com/lucas-maes/le-wm](https://github.com/lucas-maes/le-wm) |
|
|
| ### 9.3 Datasets |
|
|
| Official datasets are available on HuggingFace Hub: |
| - [quentinll/lewm-pusht](https://huggingface.co/datasets/quentinll/lewm-pusht) — PushT manipulation (12.2 GB) |
| - [quentinll/lewm-cube](https://huggingface.co/datasets/quentinll/lewm-cube) — OGBench-Cube (43 GB) |
| - [quentinll/lewm-reacher](https://huggingface.co/datasets/quentinll/lewm-reacher) |
| - [quentinll/lewm-tworooms](https://huggingface.co/datasets/quentinll/lewm-tworooms) |
|
|
| Format: HDF5 with keys `observations/pixels` (uint8, N×T×H×W×C) and `actions` (float32, N×T×A). |
|
|
| --- |
|
|
| ## 10. References |
|
|
| 1. Maes et al., "LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels", arXiv:2603.19312, 2026. |
| 2. LeCun, "A Path Towards Autonomous Machine Intelligence", 2022. |
| 3. Assran et al., "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture", CVPR 2023 (I-JEPA). |
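| 4. Sobal et al., "Learning from Reward-Free Offline Data: A Case for Planning with Latent Dynamics Models", 2025 (PLDM). |
| 5. Zhou et al., "DINO-WM: World Models on Pre-trained Visual Features Enable Zero-shot Planning", 2024 (DINO-WM). |
| 6. Balestriero and LeCun, "LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics", 2025 (introduces SIGReg). |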
|
|
| --- |
|
|
| *This explanation was generated as part of an educational implementation of LeWorldModel. All credit for the scientific contribution goes to the original authors.* |
|
|