# LeWorldModel (LeWM): Detailed Technical Explanation

**Paper**: *LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels*

**Authors**: Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero

**arXiv**: [2603.19312](https://arxiv.org/abs/2603.19312)

**Official Repo**: [github.com/lucas-maes/le-wm](https://github.com/lucas-maes/le-wm)

---

## Table of Contents

1. [What is JEPA?](#1-what-is-jepa)
2. [The Representation Collapse Problem](#2-the-representation-collapse-problem)
3. [LeWorldModel Architecture](#3-leworldmodel-architecture)
4. [SIGReg: The Mathematical Core](#4-sigreg-the-mathematical-core)
5. [Training Objective](#5-training-objective)
6. [Latent Planning with CEM](#6-latent-planning-with-cem)
7. [Why It Works: Key Design Decisions](#7-why-it-works-key-design-decisions)
8. [Results & Comparisons](#8-results--comparisons)
9. [Implementation Notes](#9-implementation-notes)
10. [References](#10-references)

---

## 1. What is JEPA?

**Joint-Embedding Predictive Architecture (JEPA)** is a learning framework proposed by Yann LeCun for learning **world models**. A world model is an internal predictive model of the environment that lets an agent plan and reason. Generative models reconstruct pixel-level observations. JEPA instead predicts in a **compact latent space**:

```
Traditional generative model:          JEPA:

obs_t ──► [encoder]                    obs_t ──► [encoder] ──► z_t
              │                                                 │
              ▼                                                 ▼
   reconstruct obs_{t+1}               predict z_{t+1} from (z_t, a_t)
              ↑                                                 ↑
         pixel loss                              latent prediction loss
```

**Key advantages of JEPA over reconstruction-based models**:

- **Efficiency**: each frame is a single 192-dim latent vector rather than a grid of patch embeddings (1 token vs. 196 patch tokens, roughly 200× fewer tokens).
- **Robustness**: the latent space can discard irrelevant visual detail that a pixel loss would force the model to reconstruct.
- **Planning speed**: planning is about 48× faster than with DINO-WM (see Section 8.2).

---

## 2. The Representation Collapse Problem

The fundamental challenge in end-to-end JEPA training is **representation collapse**.

### What is Collapse?
Suppose the loss only penalizes prediction error in latent space. Then the encoder can reach a **trivial solution** by mapping every input to the same constant vector. The predictor learns to output that constant, and the prediction loss becomes exactly zero.

```python
import numpy as np

def encoder(x): return np.zeros(192)          # collapsed: ignores the input frame
def predictor(z, a): return np.zeros(192)     # predicts the same constant

obs_t, obs_next, a_t = np.random.rand(3, 224, 224), np.random.rand(3, 224, 224), np.zeros(2)
loss = np.sum((predictor(encoder(obs_t), a_t) - encoder(obs_next)) ** 2)
print(loss)  # 0.0: perfect but useless
```

### How Prior Work Handled It

| Method | Anti-collapse strategy | Drawback |
|--------|------------------------|----------|
| **I-JEPA** | Stop-gradient + EMA target network | Not end-to-end |
| **DINO-WM** | Frozen, pre-trained DINOv2 encoder | Relies on large-scale pretraining |
| **PLDM** | 7-term VICReg-inspired loss | Fragile; 6 hyperparameters to tune |
| **LeWM** | **SIGReg**, a single distribution-matching statistic | **1 hyperparameter, stable** |

---

## 3. LeWorldModel Architecture

LeWM has four main components: an image encoder, an action encoder, a predictor, and the SIGReg regularizer.

### 3.1 Encoder: ViT-Tiny

```
Input: (B, T, 3, 224, 224) raw pixel frames
   ↓
ViT-Tiny (patch=14, 12 layers, 3 heads, hidden=192)
   ↓
[CLS] token from last layer: (B*T, 192)
   ↓
Projector: Linear(192→2048) → BatchNorm1d → GELU → Linear(2048→192)
   ↓
Latent embedding z_t: (B, T, 192)
```

**Critical design**: The projector uses **BatchNorm1d**, not LayerNorm. The ViT's final layer already applies LayerNorm.
A second LayerNorm would prevent SIGReg from shaping the embedding distribution, for two reasons:

- LayerNorm normalizes each sample independently, which destroys the batch-level statistics that SIGReg measures.
- BatchNorm normalizes across the batch, which preserves the distributional properties that SIGReg optimizes.

### 3.2 Action Encoder

```
Action a_t ∈ ℝ^A   (e.g., A=2 for PushT: dx, dy)
   ↓
Group the 5 actions taken between consecutive frames (frameskip=5)
a_block = concat(a_t, ..., a_{t+4}) ∈ ℝ^(5×A)   (= ℝ^10 for PushT)
   ↓
Conv1d(10→10, kernel=1) + MLP(10→768→192)
   ↓
Action embedding act_emb_t: (B, T, 192)
```

### 3.3 Predictor: Autoregressive Transformer with AdaLN-zero

```
Input: z_t (history of N=3 frames) + act_emb_t
   ↓
Add learned positional embeddings
   ↓
6-layer Transformer, 16 heads, causal temporal masking
   ↓
Each layer uses AdaLN-zero conditioned on the action embeddings:
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
   ↓
Projector (same design as the encoder's, with BatchNorm1d)
   ↓
Predicted z_{t+1}: (B, N, 192)
```

**AdaLN-zero initialization**: The linear layer that produces the six modulation parameters starts at **zero**, following the DiT-style AdaLN-zero scheme. Every gate is therefore 0 at the start, and each residual branch initially contributes nothing. In particular, actions have no effect at the beginning of training. The network then learns to use action information gradually, which prevents early training instability.

**Causal masking**: The predictor can attend only to past and current time steps, never to future ones. PyTorch's `torch.nn.functional.scaled_dot_product_attention(..., is_causal=True)` enforces this.

### 3.4 SIGReg: Anti-Collapse Regularizer

See [Section 4](#4-sigreg-the-mathematical-core) for the full mathematical derivation.

---

## 4. SIGReg: The Mathematical Core

### 4.1 Problem Formulation

SIGReg forces the distribution of latent embeddings to match an **isotropic Gaussian N(0, I)**. Given a batch of latent embeddings **Z ∈ ℝ^(T×B×d)** (time × batch × dimension), the goal is:

```
P_Z ≈ N(0, I_d)
```

### 4.2 Why Not Just Use KL Divergence?

Computing KL(P_Z ‖ N(0, I)) requires estimating the density of P_Z from a mini-batch. In high dimensions (d > 100), both density estimates and multivariate normality tests are unreliable. Matching the full d-dimensional distribution directly is therefore statistically difficult.
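A tempting shortcut is to compare only the first two moments of Z with those of N(0, I). The numpy sketch below is illustration only and is not part of LeWM. The helpers `moment_gap` and `excess_kurtosis` and the sizes `n`, `d` are hypothetical choices for the demo, and Z is treated as an (n, d) sample.

The sketch shows two things:

- A collapsed batch is far from the target.
- ±1 (Rademacher) embeddings have population mean 0 and covariance I, so they pass the moment check. Yet each coordinate is clearly non-Gaussian, with excess kurtosis of about −2 instead of 0.

```python
import numpy as np

def moment_gap(Z):
    """Squared error of Z's mean and covariance against N(0, I); Z has shape (n, d)."""
    mean = Z.mean(axis=0)
    cov = np.cov(Z, rowvar=False)
    return float(mean @ mean + np.sum((cov - np.eye(Z.shape[1])) ** 2))

def excess_kurtosis(x):
    """Fourth-moment check on a 1-D sample: about 0 for a Gaussian."""
    x = x - x.mean()
    return float(np.mean(x**4) / np.mean(x**2) ** 2 - 3.0)

rng = np.random.default_rng(0)
n, d = 4096, 8                                            # hypothetical demo sizes
batches = {
    "gaussian": rng.standard_normal((n, d)),              # the target distribution
    "collapsed": np.zeros((n, d)),                        # constant encoder output
    "rademacher": rng.choice([-1.0, 1.0], size=(n, d)),   # mean 0, cov I, not Gaussian
}

# The collapsed batch has gap exactly d = 8 (zero covariance vs. identity)
for name, Z in batches.items():
    print(f"{name:10s} moment gap = {moment_gap(Z):.3f}")

# A 1-D view exposes what the moment check misses
for name in ("gaussian", "rademacher"):
    print(f"{name:10s} excess kurtosis of dim 0 = {excess_kurtosis(batches[name][:, 0]):+.2f}")
```

Matching moments is necessary but not sufficient. A full distributional test is needed, and the Cramér-Wold and Epps-Pulley steps reduce that test to cheap one-dimensional checks.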
### 4.3 The Cramér-Wold Theorem

The Cramér-Wold theorem says a distribution is determined by all of its 1-D projections. It has a direct consequence for Gaussians:

**Theorem**: A random vector X ∈ ℝ^d has distribution N(0, I) if and only if u^T X ~ N(0, 1) for every unit vector u ∈ S^(d-1).

In other words, **matching all 1-D marginals is equivalent to matching the full joint distribution**. In practice, SIGReg checks a random sample of directions rather than all of them.

### 4.4 The Epps-Pulley Test

For a 1-D sample h = (h_1, ..., h_N), the Epps-Pulley test statistic measures how far h is from N(0, 1) using **characteristic functions**:

```
T(h) = ∫ w(t) |φ_N(t; h) - φ_0(t)|² dt
```

Where:
- **φ_N(t; h)** = (1/N) Σ_{n=1}^N exp(i·t·h_n) is the empirical characteristic function (ECF).
- **φ_0(t)** = exp(-t²/2) is the characteristic function of N(0, 1).
- **w(t)** = exp(-t²/(2σ²)) is a Gaussian weighting window. Section 4.6 uses σ = 1. The width is written σ here so it is not confused with the loss weight λ in Section 5.

**Intuition**: The ECF is the Fourier transform of the empirical distribution. If the sample is Gaussian, its ECF should match the Gaussian CF at every frequency t. The window w(t) down-weights high frequencies, where the ECF is noisiest.

### 4.5 SIGReg Algorithm

```python
import torch

def sigreg(Z: torch.Tensor, M: int = 1024) -> torch.Tensor:
    # Z: (T, B, d): time, batch, embedding dimension
    d = Z.shape[-1]

    # 1. Sample M random directions uniformly on the unit sphere S^(d-1)
    U = torch.randn(d, M, device=Z.device, dtype=Z.dtype)
    U = U / U.norm(dim=0, keepdim=True)          # (d, M)

    # 2. Project Z onto each direction: every slice H[..., m] is a 1-D marginal
    H = Z @ U                                     # (T, B, M)

    # 3. Epps-Pulley statistic per time step and direction, treating the
    #    B batch entries as the 1-D sample (epps_pulley is defined in 4.6)
    T_values = epps_pulley(H.transpose(1, 2))     # (T, M)

    # 4. Average over time steps and projections
    return T_values.mean()
```

### 4.6 Numerical Integration

The integral is approximated by **trapezoid quadrature** on 17 nodes in [0, 3]. Only t ≥ 0 needs to be evaluated, for this reason:

- φ_N(−t) is the complex conjugate of φ_N(t), and φ_0 and w are even, so the integrand is even in t.
- The weights [dt, 2dt, ..., 2dt, dt] are twice the usual trapezoid weights, so summing over t ≥ 0 integrates over [−3, 3].

Beyond |t| = 3 the window falls below exp(−4.5) ≈ 0.011, so truncating the integral there loses little.

```python
import torch

def epps_pulley(h: torch.Tensor, knots: int = 17, t_max: float = 3.0) -> torch.Tensor:
    # h: (..., N); the statistic is computed along the last dim
    N = h.shape[-1]
    t = torch.linspace(0, t_max, knots, device=h.device, dtype=h.dtype)  # nodes t_j
    dt = t_max / (knots - 1)

    # Trapezoid weights [dt, 2dt, ..., 2dt, dt] (doubled because the integrand is even)
    weights = torch.full((knots,), 2 * dt, device=h.device, dtype=h.dtype)
    weights[[0, -1]] = dt

    w = torch.exp(-t**2 / 2)        # Gaussian window w(t), sigma = 1
    phi_0 = torch.exp(-t**2 / 2)    # characteristic function of N(0, 1)

    # ECF at each node, split into real (cos) and imaginary (sin) parts
    th = h.unsqueeze(-1) * t                                          # (..., N, knots)
    err = (th.cos().mean(-2) - phi_0) ** 2 + th.sin().mean(-2) ** 2   # |phi_N - phi_0|^2

    # Integrate, then scale by the sample size N
    return (err * weights * w).sum(-1) * N
```

### 4.7 Why SIGReg Works Well

1. **Dimension-agnostic**: every test is one-dimensional, so no density estimate in ℝ^d is needed (LeWM uses d = 192).
2. **Matches the whole distribution**: VICReg-style losses constrain only variances and covariances, i.e. the first two moments. SIGReg compares the full 1-D distributions through their characteristic functions.
3. **Scalable**: per time step, the projections cost O(B·d·M) and the ECF costs O(B·M·K) for K quadrature knots, with M = 1024 fixed.
4. **Insensitive to hyperparameters**: performance is stable across M ∈ [256, 4096] and knots ∈ [10, 50].

---

## 5. Training Objective

### 5.1 Complete Loss

```
L_LeWM = L_pred + λ · SIGReg(Z)
```

Where:
- **L_pred** = ||ẑ_{t+1} - z_{t+1}||²₂ is the next-embedding prediction loss, computed with teacher forcing.
- **SIGReg(Z)** = (1/M) Σ_m T(h^(m)) is the anti-collapse regularization term.
- **λ = 0.1** is the only tunable hyperparameter.

### 5.2 Teacher Forcing

The predictor receives the **ground-truth** latents z_t as input, not its own previous predictions:

```python
import torch.nn.functional as F

# Assumes z, act_emb: (B, T, D) and a context length `history` (N = 3)
loss = 0.0
for t in range(history - 1, T - 1):
    ctx = slice(t - history + 1, t + 1)                  # ground-truth frames up to t
    pred = predictor(z[:, ctx], act_emb[:, ctx])[:, -1]  # prediction of z_{t+1}
    loss = loss + F.mse_loss(pred, z[:, t + 1])
```

Teacher forcing is standard in sequence modeling and keeps errors from compounding during training. With a causal predictor, all of these windows can be computed in a single forward pass, as in the training step below.

### 5.3 Training Step (adapted from paper Algorithm 5)

```python
import torch.nn.functional as F

def train_step(batch, model, sigreg, optimizer, lambd=0.1):
    obs, actions = batch                        # (B, T, C, H, W), (B, T, A)

    # Encode frames and actions
    emb = model.encode(obs)                     # (B, T, D)
    act_emb = model.encode_actions(actions)     # (B, T, D)

    # Causal predictor: the output at step t is the prediction of z_{t+1}
    pred_emb = model.predict(emb, act_emb)      # (B, T, D)

    # Prediction loss: align each prediction with the next ground-truth embedding
    pred_loss = F.mse_loss(pred_emb[:, :-1], emb[:, 1:])

    # SIGReg on all embeddings, transposed to (T, B, D)
    sigreg_loss = sigreg(emb.transpose(0, 1))

    # Total loss
    loss = pred_loss + lambd * sigreg_loss

    # Backprop through ALL components (end-to-end!)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```

**No stop-gradient, no EMA, no momentum encoder, no pre-trained weights.**

---

## 6. Latent Planning with CEM

### 6.1 Why Latent Planning is Fast

| Space | Representation | Size per frame |
|-------|----------------|----------------|
| Pixel | 224×224 RGB image | 150,528 values (224×224×3) |
| Patch (ViT, p=14) | 16×16 patch grid | 256 patch tokens |
| **Latent (LeWM)** | **192-dim vector** | **1 [CLS] token (192 values)** |

In latent space, the CEM optimizer only has to roll out **192-dimensional** state vectors, one token per frame. It never has to handle grids of patch features or raw pixels.

### 6.2 Cross-Entropy Method (CEM)

CEM is a sampling-based (zero-order) optimizer over action sequences. Each iteration does three things:

1. Sample candidate sequences from a Gaussian.
2. Score each candidate and keep the best ones (the elites).
3. Refit the Gaussian to the elites.

```python
import torch

@torch.no_grad()
def cem_plan(model, initial_obs, goal_obs, action_dim, horizon=5,
             n_iters=30, n_samples=300, n_elites=30):
    # Encode the goal once
    z_goal = model.encode(goal_obs)                           # (D,)

    # Diagonal Gaussian over action sequences
    mu = torch.zeros(horizon, action_dim)
    std = torch.ones(horizon, action_dim)

    for _ in range(n_iters):                                  # 30 iterations for PushT
        # Sample candidate sequences: (n_samples, horizon, action_dim)
        actions = mu + std * torch.randn(n_samples, horizon, action_dim)

        # Roll out all candidates in latent space; assumes model.rollout accepts
        # a batch of plans and returns latents of shape (n_samples, horizon, D)
        z_final = model.rollout(initial_obs, actions)[:, -1]  # (n_samples, D)
        costs = ((z_final - z_goal) ** 2).mean(dim=-1)        # MSE to the goal latent

        # Select the elites (top 30) and refit the Gaussian to them
        elites = actions[costs.argsort()[:n_elites]]
        mu, std = elites.mean(dim=0), elites.std(dim=0)

    return mu                                                 # planned action sequence
```

### 6.3 Receding-Horizon MPC

In practice, CEM runs inside a **Model Predictive Control (MPC)** loop:

1. Plan an action sequence of length H=5 (25 environment steps with frameskip=5).
2. Execute the full sequence.
3. Re-observe the environment.
4. Re-plan from the new state.

This closed-loop execution compensates for model prediction errors.

---

## 7. Why It Works: Key Design Decisions

### 7.1 BatchNorm1d in the Projector (Not LayerNorm)

**Problem**: ViT already ends with LayerNorm.
Adding another LayerNorm in the projector has three effects:

- Each sample's embedding is normalized independently.
- Batch-level statistics (mean, variance) are destroyed.
- SIGReg measures batch-level distributional properties, so it can no longer optimize them.

**Solution**: Use **BatchNorm1d** in the projector. It normalizes across the batch dimension, which preserves the distribution that SIGReg matches against N(0, I).

### 7.2 Dropout = 0.1 in the Predictor (Not 0.0)

**Ablations in the paper** (Table 8):
- dropout=0.0 → PushT success rate: **78%**
- dropout=0.1 → PushT success rate: **96%**

This result is surprising: a small amount of dropout in the predictor is **essential** for strong performance. The authors hypothesize that dropout acts as a regularizer. It stops the predictor from memorizing training trajectories, which forces the encoder to produce more robust representations.

### 7.3 AdaLN-zero Initialization

The AdaLN modulation layers are initialized to **zero**, so actions have no effect at the start of training. This ensures that the predictor:
- first learns a good prior over dynamics,
- then gradually incorporates action information as training progresses,
- and avoids early training collapse caused by noisy action signals.

### 7.4 Frame Skip = 5

Grouping the 5 actions taken between consecutive observed frames has three benefits:
- It reduces temporal redundancy, since consecutive raw frames are nearly identical.
- It extends the effective prediction horizon. For example, a 5-step latent plan covers 5 × 5 = 25 environment steps.
- It makes the dynamics more informative, because states change more between frames.

---

## 8. Results & Comparisons

### 8.1 Planning Performance

| Method | PushT (success rate) | OGBench-Cube | TwoRoom | Reacher |
|--------|----------------------|--------------|---------|---------|
| **LeWM** | **96.0 ± 2.8%** | competitive | lower* | competitive |
| DINO-WM | 92.0 ± 1.6% | **best** | better | competitive |
| PLDM | 78.0 ± 5.0% | competitive | better | competitive |

\* TwoRoom underperformance is a known limitation. This simple navigation task has low intrinsic dimensionality, which makes it hard for SIGReg to match a high-dimensional isotropic Gaussian.

### 8.2 Planning Speed

| Method | Time per plan | Relative speed |
|--------|---------------|----------------|
| **LeWM** | **< 1 second** | **1× (baseline)** |
| DINO-WM | ~48 seconds | **48× slower** |

### 8.3 Physical Understanding (Probing)

LeWM's latent space encodes meaningful physical quantities:

| Property | Linear Probe MSE | Pearson r |
|----------|------------------|-----------|
| Agent Location | 0.052 | 0.974 |
| Block Location | 0.029 | 0.986 |
| Block Angle | 0.187 | 0.902 |

These values match or exceed DINO-WM, even though LeWM is trained from scratch on only 20K PushT episodes. DINO-WM's encoder, by contrast, is a foundation model pre-trained on 142M images.

### 8.4 Violation-of-Expectation

LeWM reliably detects physically implausible events such as object teleportation. It is less sensitive to purely visual changes such as color changes. This indicates that its latent space captures **physical structure** rather than just visual appearance.

---

## 9. Implementation Notes

### 9.1 Self-Contained Dependencies

This implementation requires only:
- `torch` (PyTorch 2.0+)
- `transformers` (for `ViTModel`)
- `einops` (for tensor manipulation)
- `numpy` and `h5py` (for loading the HDF5 datasets)

It does not depend on the `stable-pretraining` or `stable-worldmodel` packages.
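Because the encoder only needs `torch` and `transformers`, it can be assembled directly from the hyperparameters in Section 3.1. The following is a hedged sketch, not the official code:

- `FrameEncoder` is a hypothetical name.
- `intermediate_size=768` assumes the standard ViT MLP ratio of 4; the source does not state this value.
- The weights are randomly initialized, consistent with LeWM training from scratch.

```python
import torch
import torch.nn as nn
from transformers import ViTConfig, ViTModel

class FrameEncoder(nn.Module):
    """Hypothetical sketch: ViT-Tiny [CLS] token followed by the BatchNorm1d projector."""

    def __init__(self, dim: int = 192, proj_hidden: int = 2048):
        super().__init__()
        config = ViTConfig(
            hidden_size=dim, num_hidden_layers=12, num_attention_heads=3,
            intermediate_size=4 * dim,      # assumption: standard MLP ratio of 4
            image_size=224, patch_size=14,  # 16 x 16 = 256 patches + 1 [CLS] token
        )
        self.vit = ViTModel(config, add_pooling_layer=False)  # random init, no download
        self.projector = nn.Sequential(
            nn.Linear(dim, proj_hidden), nn.BatchNorm1d(proj_hidden),
            nn.GELU(), nn.Linear(proj_hidden, dim),
        )

    def forward(self, frames: torch.Tensor) -> torch.Tensor:
        # frames: (B, T, 3, 224, 224)
        B, T = frames.shape[:2]
        x = frames.flatten(0, 1)                                  # (B*T, 3, 224, 224)
        cls = self.vit(pixel_values=x).last_hidden_state[:, 0]    # (B*T, 192), after final LayerNorm
        return self.projector(cls).view(B, T, -1)                 # (B, T, 192)

encoder = FrameEncoder()
z = encoder(torch.randn(2, 3, 3, 224, 224))
print(z.shape)  # torch.Size([2, 3, 192])
```

Flattening (B, T) into one batch axis lets the ViT and the BatchNorm1d projector treat every frame as an independent sample. BatchNorm statistics are then computed over all B×T frames, which is the batch-level view that SIGReg relies on.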
### 9.2 Verified Against Official Implementation

The architecture was cross-checked against the official model config and source code:
- [quentinll/lewm-pusht/config.json](https://huggingface.co/quentinll/lewm-pusht/blob/main/config.json)
- GitHub source: [github.com/lucas-maes/le-wm](https://github.com/lucas-maes/le-wm)

### 9.3 Datasets

Official datasets are available on the HuggingFace Hub:
- [quentinll/lewm-pusht](https://huggingface.co/datasets/quentinll/lewm-pusht): PushT manipulation (12.2 GB)
- [quentinll/lewm-cube](https://huggingface.co/datasets/quentinll/lewm-cube): OGBench-Cube (43 GB)
- [quentinll/lewm-reacher](https://huggingface.co/datasets/quentinll/lewm-reacher)
- [quentinll/lewm-tworooms](https://huggingface.co/datasets/quentinll/lewm-tworooms)

Format: HDF5 with keys `observations/pixels` (uint8, N×T×H×W×C) and `actions` (float32, N×T×A).

---

## 10. References

1. Maes et al., "LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels", arXiv:2603.19312, 2026.
2. LeCun, "A Path Towards Autonomous Machine Intelligence", 2022.
3. Assran et al., "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture", CVPR 2023 (I-JEPA).
4. Sobal et al., "Learning from Reward-Free Offline Data: A Case for Planning with Latent Dynamics Models", 2025 (PLDM).
5. Zhou et al., "DINO-WM: World Models on Pre-trained Visual Features enable Zero-shot Planning", 2024.
6. Balestriero and LeCun, "LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics", 2025 (introduces SIGReg).

---

*This explanation was generated as part of an educational implementation of LeWorldModel. All credit for the scientific contribution goes to the original authors.*