# LeWorldModel (LeWM) — Detailed Technical Explanation
**Paper**: *LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels*
**Authors**: Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero
**arXiv**: [2603.19312](https://arxiv.org/abs/2603.19312)
**Official Repo**: [github.com/lucas-maes/le-wm](https://github.com/lucas-maes/le-wm)
---
## Table of Contents
1. [What is JEPA?](#1-what-is-jepa)
2. [The Representation Collapse Problem](#2-the-representation-collapse-problem)
3. [LeWorldModel Architecture](#3-leworldmodel-architecture)
4. [SIGReg: The Mathematical Core](#4-sigreg-the-mathematical-core)
5. [Training Objective](#5-training-objective)
6. [Latent Planning with CEM](#6-latent-planning-with-cem)
7. [Why It Works: Key Design Decisions](#7-why-it-works-key-design-decisions)
8. [Results & Comparisons](#8-results--comparisons)
9. [Implementation Notes](#9-implementation-notes)
10. [References](#10-references)
---
## 1. What is JEPA?
**Joint-Embedding Predictive Architecture (JEPA)** is a learning framework introduced by Yann LeCun for learning **world models** — internal predictive models of the environment that enable an agent to plan and reason.
Unlike generative models (which reconstruct pixel-level observations), JEPA learns to predict in a **compact latent space**:
```
Traditional Generative Model:            JEPA:

obs_t ──► [encoder]                      obs_t ──► [encoder] ──► z_t
              │                                                   │
              ▼                                                   ▼
      reconstruct obs_{t+1}                   predict z_{t+1} from (z_t, a_t)
              ↑                                                   ↑
          pixel loss                                   latent prediction loss
```
**Key advantages of JEPA over reconstruction-based models**:
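- **Efficiency**: Each frame is summarized by a single 192-dim embedding instead of a grid of patch tokens (256 patches for a 224×224 image with 14×14 patches), i.e. more than 200× fewer tokens
- **Robustness**: Latent space discards irrelevant visual noise
- **Planning speed**: ~48× faster planning than DINO-WM, which plans over patch-level features (Section 8.2)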
---
## 2. The Representation Collapse Problem
The fundamental challenge in end-to-end JEPA training is **representation collapse**.
### What is Collapse?
If the loss only penalizes prediction error in latent space, the encoder learns a **trivial solution**: map all inputs to a constant vector. The predictor then trivially learns to predict that same constant, achieving zero prediction loss.
```python
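import torch
def encoder(x):       # collapsed encoder: maps every input to zero
    return torch.zeros(x.shape[0], 192)
def predictor(z, a):  # ...and the predictor trivially outputs that same constant
    return torch.zeros_like(z)
obs_t, obs_next, a_t = torch.randn(8, 3, 224, 224), torch.randn(8, 3, 224, 224), torch.randn(8, 2)
loss = ((predictor(encoder(obs_t), a_t) - encoder(obs_next)) ** 2).mean()
print(loss.item())  # 0.0 <- perfect but useless!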
```
### How Prior Work Handled It
| Method | Anti-Collapse Strategy | Drawback |
|--------|----------------------|----------|
| **I-JEPA** | Stop-gradient + EMA target network | Not end-to-end |
| **DINO-WM** | Frozen DINOv2 pre-trained encoder | Relies on large-scale pretraining |
| **PLDM** | 7-term VICReg-inspired loss | Fragile, 6 hyperparameters |
| **LeWM** | **SIGReg**, a single statistical regularizer | **1 hyperparameter, stable** |
---
## 3. LeWorldModel Architecture
LeWM consists of four main components:
### 3.1 Encoder: ViT-Tiny
```
Input: (B, T, 3, 224, 224) raw pixel frames
        ↓
ViT-Tiny (patch=14, 12 layers, 3 heads, hidden=192)
        ↓
[CLS] token from last layer: (B*T, 192)
        ↓
Projector: Linear(192→2048) → BatchNorm1d → GELU → Linear(2048→192)
        ↓
Latent embedding z_t: (B, T, 192)
```
**Critical design**: The projector uses **BatchNorm1d** (not LayerNorm). The ViT's final layer already applies LayerNorm, and adding another LayerNorm in the projector would block SIGReg from optimizing because:
- LayerNorm normalizes each sample independently across its own features, so it cannot adjust the batch-level statistics that SIGReg measures
- BatchNorm normalizes each feature across the batch, acting directly on the per-dimension distribution that SIGReg pushes toward N(0, I)
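A minimal PyTorch sketch of the projector described above, plus a check of what each normalization actually standardizes. The layer sizes follow the diagram in Section 3.1; the class name and the `affine=False` settings in the comparison are illustrative assumptions, not taken from the official code.

```python
import torch
import torch.nn as nn

class Projector(nn.Module):
    """Linear(192->2048) -> BatchNorm1d -> GELU -> Linear(2048->192), applied to [CLS] tokens."""
    def __init__(self, dim=192, hidden=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, cls_tokens):  # (B*T, dim)
        return self.net(cls_tokens)

# LayerNorm vs. BatchNorm on the same batch of 256 vectors
x = torch.randn(256, 192) * 5 + 3
ln_out = nn.LayerNorm(192, elementwise_affine=False)(x)  # standardizes each row (sample) over its features
bn_out = nn.BatchNorm1d(192, affine=False)(x)            # training mode: standardizes each column over the batch
print(ln_out.mean(dim=1).abs().max())  # ~0: every sample has zero mean across its features
print(bn_out.mean(dim=0).abs().max())  # ~0: every feature has zero mean across the batch
```

Because BatchNorm1d uses batch statistics in training mode, the projector sees the (B*T)-sized batch as a whole, which is exactly the population SIGReg tests.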
### 3.2 Action Encoder
```
Action a_t ∈ ℝ^A (e.g., A=2 for PushT: dx, dy)
        ↓
Group 5 consecutive actions (frameskip=5)
        ↓
a_block = concat(a_t, ..., a_{t+4}) ∈ ℝ^(5×A) = ℝ^10 for PushT
        ↓
Conv1d(10→10, kernel=1) + MLP(10→768→192)
        ↓
Action embedding act_emb_t: (B, T, 192)
```
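A sketch of the action pathway above, assuming raw per-step actions arrive as a (B, T·5, A) tensor. The widths (10 → 768 → 192) come from the diagram; the GELU between the MLP layers and the class and method names are assumptions for illustration.

```python
import torch
import torch.nn as nn

class ActionEncoder(nn.Module):
    def __init__(self, action_dim=2, frameskip=5, hidden=768, out_dim=192):
        super().__init__()
        self.frameskip = frameskip
        in_dim = frameskip * action_dim                        # 10 for PushT
        self.conv = nn.Conv1d(in_dim, in_dim, kernel_size=1)   # kernel 1: same linear map at every step
        # GELU here is an assumption; the source only specifies MLP(10 -> 768 -> 192)
        self.mlp = nn.Sequential(nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim))

    def group(self, actions):
        """(B, T*frameskip, A) -> (B, T, frameskip*A): concatenate each run of consecutive actions."""
        B, L, A = actions.shape
        T = L // self.frameskip
        return actions[:, : T * self.frameskip].reshape(B, T, self.frameskip * A)

    def forward(self, actions):
        x = self.group(actions)                                # (B, T, 10)
        x = self.conv(x.transpose(1, 2)).transpose(1, 2)       # Conv1d expects (B, channels, T)
        return self.mlp(x)                                     # (B, T, 192)
```

The reshape in `group` keeps actions in time order, so block k is `[a_{5k}, a_{5k+1}, ..., a_{5k+4}]` flattened.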
### 3.3 Predictor: Autoregressive Transformer with AdaLN-zero
```
Input: z_t (history of N=3 frames) + act_emb_t
        ↓
Add learned positional embeddings
        ↓
6-layer Transformer, 16 heads, causal temporal masking
  Each layer uses AdaLN-zero conditioned on action embeddings:
  shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
        ↓
Projector (same design as the encoder's, with BatchNorm1d)
        ↓
Predicted z_{t+1}: (B, N, 192)
```
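**AdaLN-zero initialization**: The layer that produces the modulation parameters (shift, scale, gate) is initialized to **zero**, so every gate starts at zero, each Transformer block starts as an identity map, and actions initially have no effect. As training progresses, the network gradually learns to use action information, which prevents early training instability.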
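A sketch of one predictor layer in the DiT style of AdaLN-zero: a zero-initialized linear layer maps the action embedding to the six modulation vectors listed above. Width (192) and head count (16) follow Section 3.3; the MLP ratio of 4, the SiLU before the modulation layer, and where dropout is applied are assumptions, and this is not the official implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaLNZeroBlock(nn.Module):
    """One predictor layer: causal self-attention + MLP, both modulated by the action embedding."""
    def __init__(self, dim=192, heads=16, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_ratio * dim), nn.GELU(), nn.Linear(mlp_ratio * dim, dim))
        self.drop = nn.Dropout(dropout)
        # Maps the action embedding to the six modulation vectors; zero-init is the "zero" in AdaLN-zero
        self.ada = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))
        nn.init.zeros_(self.ada[1].weight)
        nn.init.zeros_(self.ada[1].bias)

    def forward(self, x, c):  # x: (B, T, D) latents, c: (B, T, D) action embeddings
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.ada(c).chunk(6, dim=-1)
        B, T, D = x.shape
        h = self.norm1(x) * (1 + scale_msa) + shift_msa
        q, k, v = self.qkv(h).reshape(B, T, 3, self.heads, D // self.heads).permute(2, 0, 3, 1, 4)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, heads, T, D/heads)
        x = x + gate_msa * self.drop(self.proj(attn.transpose(1, 2).reshape(B, T, D)))
        h = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        return x + gate_mlp * self.drop(self.mlp(h))
```

With the zero-initialized modulation layer, both gates are zero, so `block(x, c)` returns `x` unchanged at initialization regardless of the actions.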
**Causal masking**: The predictor can only attend to past and current time steps, not future ones. This is enforced via PyTorch's `scaled_dot_product_attention(..., is_causal=True)`.
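To make `is_causal=True` concrete, the snippet below checks that it is equivalent to an explicit lower-triangular boolean mask. The tensor sizes are arbitrary illustrative choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, Dh = 2, 4, 3, 8                      # batch, heads, time steps, head dim
q, k, v = (torch.randn(B, H, T, Dh) for _ in range(3))

causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Equivalent explicit mask: True marks (query, key) pairs that may attend,
# i.e. step t may look at steps 0..t but never at t+1, t+2, ...
allowed = torch.ones(T, T, dtype=torch.bool).tril()
explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
```

A side effect worth noticing: the first time step can only attend to itself, so its output is exactly its own value vector.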
### 3.4 SIGReg: Anti-Collapse Regularizer
See [Section 4](#4-sigreg-the-mathematical-core) for the full mathematical derivation.
---
## 4. SIGReg: The Mathematical Core
### 4.1 Problem Formulation
SIGReg forces the distribution of latent embeddings to match an **isotropic Gaussian N(0, I)**.
Given a batch of latent embeddings **Z ∈ ℝ^(T×B×d)** (time × batch × dimension), we want:
```
P_Z ≈ N(0, I_d)
```
### 4.2 Why Not Just Use KL Divergence?
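Fitting P_Z to N(0, I) directly, e.g. via a KL divergence, would require estimating the density of the embeddings from a finite batch, which is impractical in high dimensions; multivariate normality tests also become unreliable there (beyond ~100 dimensions). Directly matching the full d-dimensional distribution is therefore statistically difficult.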
### 4.3 The Cramér-Wold Theorem
**Theorem** (a consequence of the Cramér-Wold theorem): A random vector X ∈ ℝ^d has distribution N(0, I) if and only if u^T X ~ N(0, 1) for every unit vector u ∈ S^(d-1).
This means: **matching all 1-D marginals = matching the full joint distribution**.
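A quick illustration of why random 1-D projections are a useful probe. The dimensions and the anisotropic scaling below are arbitrary illustrative choices, and the check looks only at the variance of each projection; SIGReg's actual test compares the whole 1-D distribution. For an isotropic Gaussian every direction gives variance ≈ 1, while a non-isotropic Gaussian is exposed by directions whose variance differs.

```python
import torch

torch.manual_seed(0)
d, n, M = 64, 20000, 512
scales = torch.linspace(0.1, 2.0, d)   # per-dimension std of a non-isotropic Gaussian
iso = torch.randn(n, d)                # samples from N(0, I)
aniso = torch.randn(n, d) * scales     # samples from N(0, diag(scales**2))
U = torch.randn(d, M)
U = U / U.norm(dim=0, keepdim=True)    # M random unit directions u (columns)
var_iso = (iso @ U).var(dim=0)         # Var(u^T X) = 1 for every u
var_aniso = (aniso @ U).var(dim=0)     # Var(u^T X) = sum_i u_i^2 scales_i^2, which depends on u
print(var_iso.min().item(), var_iso.max().item())
print(var_aniso.min().item(), var_aniso.max().item())
```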
### 4.4 The Epps-Pulley Test
For a 1-D sample h = (h_1, ..., h_N), the Epps-Pulley test statistic measures how far h is from N(0,1) using **characteristic functions**:
```
T(h) = N · ∫ w(t) |φ_N(t; h) - φ_0(t)|² dt
```
Where:
- **N** is the number of samples in h
- **φ_N(t; h)** = (1/N) Σ_{n=1}^N exp(i·t·h_n) is the empirical characteristic function (ECF)
- **φ_0(t)** = exp(-t²/2) is the standard Gaussian CF
- **w(t)** = exp(-t²/(2σ²)) is a Gaussian weighting window (σ = 1 in the quadrature of Section 4.6)
**Intuition**: The ECF is the Fourier transform of the empirical distribution. If the sample is Gaussian, its ECF should match the Gaussian CF at all frequencies t.
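A small sketch of that intuition: compute the ECF of a Gaussian sample and of a uniform sample with the *same* mean and variance, and compare both to exp(-t²/2). The sample sizes and the grid of t values are illustrative choices. The Gaussian sample's ECF tracks φ_0 closely, while the uniform one departs from it at moderate t even though its first two moments match.

```python
import torch

torch.manual_seed(0)
t = torch.linspace(0, 3, 7)

def ecf(h, t):
    """Empirical characteristic function (1/N) sum_n exp(i t h_n), returned as a complex tensor."""
    return torch.exp(1j * t[:, None] * h[None, :]).mean(dim=1)

phi_0 = torch.exp(-t**2 / 2)                  # characteristic function of N(0, 1)
gauss = torch.randn(50000)
unif = (torch.rand(50000) - 0.5) * 12**0.5    # uniform with mean 0, variance 1
err_gauss = (ecf(gauss, t) - phi_0).abs()     # small at every t
err_unif = (ecf(unif, t) - phi_0).abs()       # clearly nonzero for t around 1.5 to 3
```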
### 4.5 SIGReg Algorithm
```python
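import torch

def sample_sphere(d, M, device=None):
    """M directions drawn uniformly on the unit sphere S^(d-1), as columns of a (d, M) matrix."""
    U = torch.randn(d, M, device=device)      # isotropic Gaussian directions...
    return U / U.norm(dim=0, keepdim=True)    # ...normalized to unit length

def SIGReg(Z, M=1024):
    # Z: (T, B, d), time, batch, dimension
    # 1. Sample M random directions on the unit sphere (fresh directions at every call)
    U = sample_sphere(Z.shape[-1], M, device=Z.device)   # (d, M)
    # 2. Project Z onto each direction
    H = Z @ U                                            # (T, B, M), 1-D marginals
    # 3. Compute the Epps-Pulley statistic of each projection (defined in Section 4.6)
    T_values = [epps_pulley_integral(H[:, :, m]) for m in range(M)]
    # 4. Average over all projections
    return torch.stack(T_values).mean()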
```
### 4.6 Numerical Integration
The integral is approximated via **trapezoid quadrature**:
```python
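import torch

def epps_pulley_integral(h, knots=17, t_max=3.0):
    h = h.reshape(-1)                          # pool all values (e.g. a (T, B) slice) into one sample
    N = h.numel()
    t = torch.linspace(0, t_max, knots, device=h.device)   # integration nodes on [0, 3]
    dt = t_max / (knots - 1)
    # Trapezoid weights on [0, 3], doubled because the integrand is even in t
    # (the integral over the real line equals twice the integral over [0, inf))
    weights = torch.full((knots,), 2 * dt, device=h.device)
    weights[0] = weights[-1] = dt
    w = torch.exp(-t**2 / 2)                   # Gaussian window (sigma = 1)
    # ECF at each node, split into real and imaginary parts: (knots,)
    th = t[:, None] * h[None, :]               # (knots, N)
    phi_re, phi_im = torch.cos(th).mean(dim=1), torch.sin(th).mean(dim=1)
    phi_0 = torch.exp(-t**2 / 2)               # standard Gaussian CF (real-valued)
    err = (phi_re - phi_0) ** 2 + phi_im ** 2  # |phi_N - phi_0|^2
    return (err * weights * w).sum() * N       # integrate, scaled by the sample size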
```
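In practice the loop over M directions is typically vectorized. The module below is a sketch, not the official code: it evaluates all projections and quadrature knots in one batched computation and precomputes the weights as buffers. One assumption to flag: it computes each ECF over the batch axis separately for every time step, then averages over time steps and directions; pooling all T·B embeddings into one sample (as in the loop above) is an equally plausible reading of the pseudocode.

```python
import torch
import torch.nn as nn

class SIGRegLoss(nn.Module):
    """Vectorized SIGReg sketch: Epps-Pulley statistics for M random projections at once."""
    def __init__(self, num_proj=1024, knots=17, t_max=3.0):
        super().__init__()
        self.num_proj = num_proj
        t = torch.linspace(0, t_max, knots)
        dt = t_max / (knots - 1)
        trap = torch.full((knots,), 2 * dt)    # doubled trapezoid weights (even integrand)
        trap[0] = trap[-1] = dt
        self.register_buffer("t", t)
        self.register_buffer("phi0", torch.exp(-t**2 / 2))      # N(0, 1) characteristic function
        self.register_buffer("w", trap * torch.exp(-t**2 / 2))  # quadrature weights x Gaussian window

    def forward(self, Z):                      # Z: (T, B, d)
        T, B, d = Z.shape
        U = torch.randn(d, self.num_proj, device=Z.device, dtype=Z.dtype)
        U = U / U.norm(dim=0, keepdim=True)    # fresh unit directions at every call
        x = (Z @ U).unsqueeze(-1) * self.t     # (T, B, M, K): t_k * h_n
        re = torch.cos(x).mean(dim=1)          # ECF real part over the batch: (T, M, K)
        im = torch.sin(x).mean(dim=1)          # ECF imaginary part
        err = (re - self.phi0) ** 2 + im ** 2  # |phi_N - phi_0|^2
        stat = (err * self.w).sum(dim=-1) * B  # Epps-Pulley statistic per (time step, direction)
        return stat.mean()                     # average over time steps and directions
```

Memory grows as T·B·M·K, so for large batches one might chunk over directions; the loss stays differentiable end to end, which is what lets SIGReg shape the encoder directly.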
### 4.7 Why SIGReg is Brilliant
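1. **Dimension-independent**: The test only ever examines 1-D projections, so it avoids the unreliability of multivariate normality tests in high dimensions (LeWM uses d = 192)
2. **Full distribution matching**: VICReg-style losses constrain only variances and covariances (second moments); SIGReg compares each 1-D marginal to N(0, 1) in full through its characteristic function
3. **Scalable**: Random-projection sketching keeps the cost linear in the number of embeddings N: O(N·d·M) for the projections plus O(N·M·K) for the ECF with K knots, with M = 1024 fixed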
4. **Insensitive to hyperparameters**: Performance is stable across M ∈ [256, 4096] and knots ∈ [10, 50]
---
## 5. Training Objective
### 5.1 Complete Loss
```
L_LeWM = L_pred + λ · SIGReg(Z)
```
Where:
- **L_pred** = ||ẑ_{t+1} - z_{t+1}||²₂ (next-embedding prediction, teacher-forcing)
- **SIGReg(Z)** = (1/M) Σ_m T(h^(m)) (anti-collapse regularization)
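- **λ = 0.1** is the only hyperparameter that needs tuning (the number of projections M and quadrature knots stay at their defaults; Section 4.7 notes results are insensitive to them)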
### 5.2 Teacher-Forcing
The predictor receives **ground-truth** latent z_t (not its own previous prediction) as input:
```python
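import torch.nn.functional as F

# Teacher forcing: each prediction is conditioned on ground-truth latents,
# never on the predictor's own earlier outputs (predictor, z, action_emb, history, T as above)
loss = 0.0
for t in range(T - 1):
    ctx = slice(max(0, t + 1 - history), t + 1)             # last `history` (=3) ground-truth steps
    pred = predictor(z[:, ctx], action_emb[:, ctx])[:, -1]  # prediction of z_{t+1}
    loss = loss + F.mse_loss(pred, z[:, t + 1])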
```
This is standard in sequence modeling and prevents error accumulation during training.
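The contrast between training and planning-time use of the predictor can be sketched with a toy one-step predictor. During training, every prediction starts from a ground-truth latent; at planning time (Section 6), the model must roll out autoregressively, feeding its own predictions back in. The toy linear dynamics and the function names are hypothetical; with the causal Transformer of Section 3.3, all teacher-forced predictions come out of a single forward pass, and the loop here is only for clarity.

```python
import torch

def teacher_forced(predictor, z, a):
    """One-step predictions, each made from the ground-truth latent z[:, t] (training)."""
    return torch.stack([predictor(z[:, t], a[:, t]) for t in range(z.shape[1] - 1)], dim=1)

def autoregressive(predictor, z0, a):
    """Multi-step rollout that feeds each prediction back as the next input (planning)."""
    preds, z = [], z0
    for t in range(a.shape[1]):
        z = predictor(z, a[:, t])
        preds.append(z)
    return torch.stack(preds, dim=1)

# Hypothetical toy dynamics: true z_{t+1} = 0.9 z_t + a_t; the "learned" predictor is slightly off
def true_step(z, a):
    return 0.9 * z + a

def learned_step(z, a):
    return 0.95 * z + a

torch.manual_seed(0)
a = torch.randn(8, 5, 4)
z = [torch.randn(8, 4)]
for t in range(5):
    z.append(true_step(z[-1], a[:, t]))
z = torch.stack(z, dim=1)                      # (8, 6, 4) ground-truth trajectory
tf = teacher_forced(learned_step, z, a)        # (8, 5, 4)
ar = autoregressive(learned_step, z[:, 0], a)  # (8, 5, 4)
```

Both agree on the first step; afterwards the autoregressive rollout carries its own errors forward, which teacher forcing avoids during training.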
### 5.3 Training Pseudocode (from paper Algorithm 5)
```python
import torch.nn.functional as F

def train_step(batch, model, sigreg, optimizer, lambd=0.1):
    obs, actions = batch                       # (B, T, C, H, W), (B, T, A)
    # Encode frames and actions
    emb = model.encode(obs)                    # (B, T, D)
    act_emb = model.encode_actions(actions)    # (B, T, D)
    # Causal predictor: the output at position t is the prediction of z_{t+1}
    pred_emb = model.predict(emb, act_emb)     # (B, T, D)
    # Prediction loss (teacher forcing): align predictions with next-step targets
    pred_loss = F.mse_loss(pred_emb[:, :-1], emb[:, 1:])
    # SIGReg on all embeddings, time-major layout (T, B, D)
    sigreg_loss = sigreg(emb.transpose(0, 1))
    # Total loss
    loss = pred_loss + lambd * sigreg_loss
    # Backprop through ALL components (end-to-end!)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```
**No stop-gradient, no EMA, no momentum encoder, no pre-trained weights.**
---
## 6. Latent Planning with CEM
### 6.1 Why Latent Planning is Fast
| Space | Representation | Tokens per frame | Values per frame |
|-------|----------------|------------------|------------------|
| Pixel | 224×224 RGB | 50,176 pixels | 150,528 (224 × 224 × 3) |
| Patch (ViT, p=14) | 16×16 grid of patch embeddings | 256 patches | 256 × embedding dim |
| **Latent (LeWM)** | **192-dim vector** | **1 [CLS] token** | **192** |

Planning in latent space means the predictor, which CEM queries for thousands of candidate rollouts per plan, processes a single **192-dimensional** vector per frame: 256× fewer tokens than a 14-pixel patch grid.
### 6.2 Cross-Entropy Method (CEM)
CEM is a sampling-based (zero-order) optimizer that iteratively refines a Gaussian distribution over action sequences.
```python
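import torch

def CEM_plan(model, initial_obs, goal_obs, horizon=5, action_dim=2,
             n_samples=300, n_elites=30, n_iters=30):
    # Assumed model interface: encode(obs) -> (D,), rollout(z0, actions) -> (n_samples, horizon, D)
    with torch.no_grad():
        z_goal = model.encode(goal_obs)            # (D,)
        z0 = model.encode(initial_obs)             # (D,), encoded once and shared by all candidates
        # Initialize a diagonal Gaussian over action sequences
        mu = torch.zeros(horizon, action_dim)
        std = torch.ones(horizon, action_dim)
        for _ in range(n_iters):                   # 30 for PushT
            # Sample candidate action sequences: (n_samples, horizon, action_dim)
            actions = mu + std * torch.randn(n_samples, horizon, action_dim)
            # Roll out all candidates in latent space; cost = distance of final latent to goal
            z_rollout = model.rollout(z0, actions)
            costs = ((z_rollout[:, -1] - z_goal) ** 2).mean(dim=-1)
            # Select elites (top 30) and refit the distribution to them
            elites = actions[costs.argsort()[:n_elites]]
            mu, std = elites.mean(dim=0), elites.std(dim=0)
    return mu                                      # mean of the final distribution = planned sequence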
```
### 6.3 Receding-Horizon MPC
In practice, CEM is used in a **Model Predictive Control (MPC)** loop:
1. Plan an action sequence of length H=5 (25 environment steps with frameskip=5)
2. Execute the full sequence
3. Re-observe the environment
4. Re-plan from the new state
This closed-loop execution compensates for model prediction errors.
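To see the plan / execute / re-observe / re-plan loop end to end, here is a self-contained toy: a hypothetical 2-D point-mass environment, a stand-in "world model" whose latent is the position itself with exact dynamics, and a compact CEM using the PushT settings from Section 6.2 (300 samples, 30 elites, 30 iterations). None of these names come from the LeWM code; the point is the control flow, not the model.

```python
import torch

def cem(cost_fn, horizon, action_dim, n_samples=300, n_elites=30, n_iters=30):
    """Cross-entropy method with a diagonal Gaussian over (horizon, action_dim) sequences."""
    mu = torch.zeros(horizon, action_dim)
    std = torch.ones(horizon, action_dim)
    for _ in range(n_iters):
        actions = mu + std * torch.randn(n_samples, horizon, action_dim)
        elites = actions[cost_fn(actions).argsort()[:n_elites]]
        mu, std = elites.mean(dim=0), elites.std(dim=0)
    return mu

class PointMass:
    """Hypothetical toy environment: the state is a 2-D position, an action is a clipped displacement."""
    def __init__(self):
        self.pos = torch.zeros(2)

    def step(self, action):
        self.pos = self.pos + action.clamp(-1, 1)
        return self.pos.clone()  # the new observation

def rollout(z0, actions):
    """Stand-in latent model with exact dynamics: z_{k+1} = z_k + clip(a_k)."""
    return z0 + actions.clamp(-1, 1).cumsum(dim=-2)  # (n_samples, horizon, 2)

def mpc(env, goal, horizon=5, n_replans=3):
    obs = env.pos.clone()
    for _ in range(n_replans):
        z0 = obs  # "encode" the latest observation (identity in this toy)
        plan = cem(lambda acts: ((rollout(z0, acts)[:, -1] - goal) ** 2).sum(dim=-1), horizon, 2)
        for action in plan:  # execute the full planned sequence...
            obs = env.step(action)  # ...and keep the last observation to re-plan from
    return obs
```

With a learned, imperfect model, each re-plan starts from a fresh observation, so errors in one planned chunk do not compound across the whole episode.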
---
## 7. Why It Works: Key Design Decisions
### 7.1 BatchNorm1d in the Projector (Not LayerNorm)
**Problem**: ViT already ends with LayerNorm. Adding another LayerNorm in the projector means:
- Each sample's embedding is normalized independently
- Batch-level statistics (mean, variance) are destroyed
- SIGReg, which measures batch-level distributional properties, cannot optimize
**Solution**: Use **BatchNorm1d** in the projector. This normalizes across the batch dimension, preserving the distribution that SIGReg matches against N(0,I).
### 7.2 Dropout = 0.1 in the Predictor (Not 0.0)
**Ablations in the paper** (Table 8):
- dropout=0.0 → PushT success rate: **78%**
- dropout=0.1 → PushT success rate: **96%**
This is surprising — a small amount of dropout in the predictor is **essential** for strong performance. The authors hypothesize it acts as a regularizer that prevents the predictor from memorizing training trajectories, forcing the encoder to produce more robust representations.
### 7.3 AdaLN-zero Initialization
All AdaLN parameters are initialized to **zero**, meaning actions have no effect at the start of training. This ensures:
- The predictor first learns a good prior over dynamics
- Gradually incorporates action information as training progresses
- Prevents early training collapse due to noisy action signals
### 7.4 Frame Skip = 5
Grouping 5 consecutive actions between frames:
- Reduces temporal redundancy (consecutive frames are nearly identical)
- Increases the effective prediction horizon (e.g., a 4-frame window spans 4 × 5 = 20 environment steps)
- Makes dynamics more informative (larger state changes between frames)
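A NumPy sketch of the frame-skip preprocessing, assuming one episode is stored as per-step arrays (frames of shape (T_raw, H, W, C) and actions of shape (T_raw, A)) and that any trailing remainder shorter than a full group is dropped. The function name is hypothetical.

```python
import numpy as np

def apply_frameskip(pixels, actions, frameskip=5):
    """Subsample frames and group actions.

    pixels:  (T_raw, H, W, C) frames of one episode
    actions: (T_raw, A) per-step actions
    returns: frames (T, H, W, C) and action blocks (T, frameskip * A), where block k holds
             the `frameskip` actions executed starting from kept frame k
    """
    T = len(actions) // frameskip
    frames = pixels[: T * frameskip : frameskip]  # keep every 5th frame
    blocks = actions[: T * frameskip].reshape(T, frameskip * actions.shape[1])  # concatenate 5 actions
    return frames, blocks
```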
---
## 8. Results & Comparisons
### 8.1 Planning Performance
| Method | PushT | OGBench-Cube | TwoRoom | Reacher |
|--------|-------|-------------|---------|---------|
| **LeWM** | **96.0 ± 2.8%** | competitive | lower* | competitive |
| DINO-WM | 92.0 ± 1.6% | **best** | better | competitive |
| PLDM | 78.0 ± 5.0% | competitive | better | competitive |
\* TwoRoom underperformance: The low intrinsic dimensionality of this simple navigation task makes it difficult for SIGReg to match a high-dimensional isotropic Gaussian — a known limitation.
### 8.2 Planning Speed
| Method | Time per plan | Relative speed |
|--------|--------------|----------------|
| **LeWM** | **< 1 second** | **1× (baseline)** |
| DINO-WM | ~48 seconds | **48× slower** |
### 8.3 Physical Understanding (Probing)
LeWM's latent space encodes meaningful physical quantities:
| Property | Linear Probe MSE | Pearson r |
|----------|-----------------|-----------|
| Agent Location | 0.052 | 0.974 |
| Block Location | 0.029 | 0.986 |
| Block Angle | 0.187 | 0.902 |
These values match or exceed DINO-WM (whose DINOv2 encoder was pre-trained on ~142M images) despite LeWM training from scratch on only 20K PushT episodes.
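A linear probe fits a linear map from frozen embeddings to a physical quantity and reports held-out error. Below is a minimal NumPy sketch, assuming ordinary least squares with a bias term and a Pearson r computed over all target coordinates jointly; the paper's exact probing protocol (target normalization, regularization, splits) may differ.

```python
import numpy as np

def linear_probe(Z_train, y_train, Z_test, y_test):
    """Fit y ~ [Z, 1] @ W by least squares; return held-out MSE and Pearson r."""
    def add_bias(Z):
        return np.hstack([Z, np.ones((len(Z), 1))])
    W, *_ = np.linalg.lstsq(add_bias(Z_train), y_train, rcond=None)
    pred = add_bias(Z_test) @ W
    mse = float(np.mean((pred - y_test) ** 2))
    r = float(np.corrcoef(pred.ravel(), y_test.ravel())[0, 1])
    return mse, r
```

Usage: with `Z` a (num_frames, 192) array of LeWM embeddings and `y` the (num_frames, 2) agent positions, `linear_probe(Z[:n], y[:n], Z[n:], y[n:])` returns the two numbers reported per row of the table above.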
### 8.4 Violation-of-Expectation
LeWM reliably detects physically implausible events (object teleportation) but is less sensitive to purely visual changes (color changes), demonstrating that its latent space captures **physical structure** rather than just visual appearance.
---
## 9. Implementation Notes
### 9.1 Self-Contained Dependencies
This implementation requires only:
- `torch` (PyTorch 2.0+)
- `transformers` (for ViTModel)
- `einops` (for tensor manipulation)
- `numpy` and `h5py` (for loading the HDF5 datasets)
No dependency on the `stable-pretraining` or `stable-worldmodel` packages.
### 9.2 Verified Against Official Implementation
The architecture was cross-checked against the official model config:
- [quentinll/lewm-pusht/config.json](https://huggingface.co/quentinll/lewm-pusht/blob/main/config.json)
- GitHub source: [github.com/lucas-maes/le-wm](https://github.com/lucas-maes/le-wm)
### 9.3 Datasets
Official datasets are available on HuggingFace Hub:
- [quentinll/lewm-pusht](https://huggingface.co/datasets/quentinll/lewm-pusht) — PushT manipulation (12.2 GB)
- [quentinll/lewm-cube](https://huggingface.co/datasets/quentinll/lewm-cube) — OGBench-Cube (43 GB)
- [quentinll/lewm-reacher](https://huggingface.co/datasets/quentinll/lewm-reacher)
- [quentinll/lewm-tworooms](https://huggingface.co/datasets/quentinll/lewm-tworooms)
Format: HDF5 with keys `observations/pixels` (uint8, N×T×H×W×C) and `actions` (float32, N×T×A).
---
## 10. References
1. Maes et al., "LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels", arXiv:2603.19312, 2026.
2. LeCun, "A Path Towards Autonomous Machine Intelligence", 2022.
3. Assran et al., "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture", CVPR 2023 (I-JEPA).
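4. Sobal et al., "Learning from Reward-Free Offline Data: A Case for Planning with Latent Dynamics Models", 2025 (PLDM).
5. Zhou et al., "DINO-WM: World Models on Pre-trained Visual Features enable Zero-shot Planning", 2024 (DINO-WM).
6. Balestriero & LeCun, "LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics", 2025 (introduces SIGReg).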
---
*This explanation was generated as part of an educational implementation of LeWorldModel. All credit for the scientific contribution goes to the original authors.*