---
tags:
- ml-intern
- jepa
- world-model
- robotics
- explainable-ai
license: mit
---
# LeWorldModel (LeWM): Stable End-to-End JEPA from Pixels
This repository contains a **clean, self-contained PyTorch implementation** of **LeWorldModel (LeWM)** from the paper:
> **LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels**
> Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero
> arXiv: 2603.19312 — https://arxiv.org/abs/2603.19312
> Official repo: https://github.com/lucas-maes/le-wm
---
## 🚀 Quick Start: Free GPU Training on Google Colab
The easiest way to train LeWM on a **free GPU** is via our Colab-ready notebook:
📓 **[Open in Colab](https://colab.research.google.com/github/ar27111994/lewm-implementation/blob/main/lewm_colab.ipynb)** *(if this link does not resolve, upload `lewm_colab.ipynb` from this repo via File > Upload notebook in Colab)*
Or read the step-by-step guide:
📖 **[COLAB_GUIDE.md](https://huggingface.co/ar27111994/lewm-implementation/blob/main/COLAB_GUIDE.md)**
**What you need**:
- A Google account (free)
- ~30–60 minutes for 10 epochs on synthetic data
- Optional: Hugging Face token (free) to push trained models
**Hardware**: Free Colab T4 GPU (15 GB VRAM) — LeWM's ~18M parameters fit comfortably.
---
## What is LeWorldModel?
LeWorldModel (LeWM) is a **Joint-Embedding Predictive Architecture (JEPA)** world model that learns **directly from raw pixels** with a **single tunable hyperparameter**. It is the first end-to-end JEPA that trains stably without:
- Stop-gradient / EMA mechanisms
- Pre-trained encoders (e.g., DINOv2)
- Complex multi-term losses (e.g., VICReg variants)
### Key Innovations
| Feature | LeWM | Prior work (PLDM) |
|---------|------|-------------------|
| Loss terms | **2** (prediction + SIGReg) | 7 (prediction + 6 regularizers) |
| Tunable hyperparameters | **1** (lambda) | 6 (grid search O(n^6)) |
| End-to-end trainable | Yes | Partial (fragile) |
| Planning speed | **48x faster** than DINO-WM | Comparable to LeWM |
| Params | ~18M | Similar |
### Architecture (from paper section 3.1 & Appendix D)
```
Raw Pixels (224x224) ---> ViT-Tiny Encoder ---> [CLS] + MLP+BN ---> Latent z_t
| (192-dim)
|
v
+-------------------+
| AR Predictor | <--- Actions (AdaLN-zero)
| 6 layers, 16h |
| Causal masking |
+-------------------+
|
v
Predicted z_{t+1}
|
v
MSE(z_{t+1}, pred) + lambda * SIGReg(z)
```
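A minimal sketch of how the objective in the diagram could be computed. It assumes encoder outputs and predictor outputs are both laid out as (B, T, D) with `z_pred[:, t]` predicting `z[:, t + 1]`, a mean-reduced MSE, and a `regularizer` callable that takes (T, B, D) embeddings as in the SIGReg section. The name `lewm_loss` is hypothetical, not the repository's API.

```python
import torch
import torch.nn.functional as F

def lewm_loss(z, z_pred, regularizer, lambd=0.1):
    """Hypothetical sketch: next-embedding MSE + lambda * regularizer.

    z:      encoder embeddings, shape (B, T, D)
    z_pred: predictor outputs, shape (B, T, D); z_pred[:, t] predicts z[:, t + 1]
    regularizer: callable on (T, B, D) embeddings (assumed interface, e.g. SIGReg)
    """
    # Predictions for steps 0..T-2 are compared with embeddings at steps 1..T-1.
    # The target is NOT detached: no stop-gradient, so the encoder gets gradients from both terms.
    pred_loss = F.mse_loss(z_pred[:, :-1], z[:, 1:])
    reg_loss = regularizer(z.transpose(0, 1))
    return pred_loss + lambd * reg_loss

# Usage with a stand-in regularizer (a SIGReg implementation would go here)
z = torch.randn(8, 4, 192, requires_grad=True)  # encoder outputs (B, T, D)
z_pred = torch.randn(8, 4, 192)                 # predictor outputs (B, T, D)
loss = lewm_loss(z, z_pred, regularizer=lambda x: x.pow(2).mean())
loss.backward()  # gradients reach z through both terms
```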
**Components:**
- **Encoder**: ViT-Tiny (patch 14, 12 layers, 3 heads, hidden 192) -> [CLS] token -> **MLP + BatchNorm1d** projector
- **Predictor**: 6-layer transformer with **AdaLN-zero** action conditioning, **causal temporal masking**
- **SIGReg**: Sketch Isotropic Gaussian Regularizer, which prevents collapse by applying the Epps-Pulley test to random 1-D projections
- **Planner**: Cross-Entropy Method (CEM) in latent space for goal-conditioned control
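To make "AdaLN-zero action conditioning" and "causal temporal masking" concrete, here is a minimal sketch of one predictor block. It assumes a DiT-style AdaLN-zero layout: an action embedding produces a shift, scale and gate for each sub-layer, and the gates are zero-initialized so the block starts as the identity. The sizes mirror the numbers above (192-dim, 16 heads, dropout 0.1); the class name `AdaLNZeroBlock`, the action-embedding MLP and the 4x MLP width are hypothetical and may differ from `lewm_model.py`.

```python
import torch
from torch import nn

class AdaLNZeroBlock(nn.Module):
    """Hypothetical predictor block: AdaLN-zero action conditioning + causal self-attention."""

    def __init__(self, dim: int = 192, heads: int = 16, action_dim: int = 10, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * dim, dim))
        # Action -> (shift, scale, gate) for the attention and MLP sub-layers: 6 * dim values per step
        self.ada = nn.Sequential(nn.Linear(action_dim, dim), nn.SiLU(), nn.Linear(dim, 6 * dim))
        # The "zero" in AdaLN-zero: all gates start at 0, so the block starts as the identity
        nn.init.zeros_(self.ada[-1].weight)
        nn.init.zeros_(self.ada[-1].bias)

    def forward(self, z: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        # z: (B, T, dim) latent history; a: (B, T, action_dim) action block for each step
        T = z.size(1)
        shift1, scale1, gate1, shift2, scale2, gate2 = self.ada(a).chunk(6, dim=-1)
        # Boolean causal mask: True marks future positions that step t may not attend to
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=z.device), diagonal=1)
        h = self.norm1(z) * (1 + scale1) + shift1
        attn_out, _ = self.attn(h, h, h, attn_mask=causal, need_weights=False)
        z = z + gate1 * attn_out
        h = self.norm2(z) * (1 + scale2) + shift2
        return z + gate2 * self.mlp(h)
```

Zero-initializing the gates means an untrained stack of such blocks passes latents through unchanged, which is the usual motivation for AdaLN-zero; the causal mask ensures the prediction at step t depends only on steps up to t.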
---
## SIGReg: The Anti-Collapse Engine
SIGReg is the critical component that makes stable end-to-end training possible.
**Problem**: Prediction-only loss causes representation collapse (encoder maps everything to a constant).
**Solution**: SIGReg forces latent embeddings to match an **isotropic Gaussian N(0, I)**.
**How it works**:
1. Collect latent tensor **Z in R^(TxBxd)** (time x batch x dim)
2. Sample **M=1024** random unit-norm directions **u^(m)** on the hypersphere S^(d-1)
3. Project: **h^(m) = Z dot u^(m)** -> (T, B) 1-D marginals
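4. Apply the **Epps-Pulley test statistic** T(h^(m)), which compares the empirical characteristic function φ_N of each projection with that of N(0, 1): T(h) = N ∫ |φ_N(t) - exp(-t^2/2)|^2 w(t) dt, where N is the number of samples in the marginal
5. Approximate the integral with trapezoid quadrature on nodes uniformly spaced in [0, 3] (17 knots by default), with weighting w(t) = exp(-t^2/2)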
6. By the **Cramer-Wold theorem**: matching all 1-D marginals <=> matching the full joint distribution
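The steps above can be sketched in PyTorch as follows. This is a hedged reconstruction from the description in this section, not the code in `lewm_model.py`: the function name `sigreg`, taking the empirical distribution over the batch axis B separately for each time step, and the factor of 2 that turns the [0, 3] integral into one over the symmetric range [-3, 3] are assumptions.

```python
import torch

def sigreg(z: torch.Tensor, num_proj: int = 1024, knots: int = 17, t_max: float = 3.0) -> torch.Tensor:
    """Hypothetical sketch of SIGReg: Epps-Pulley normality statistic on random 1-D projections.

    z: latent tensor of shape (T, B, d); for each time step the empirical
       distribution is taken over the B batch samples (an assumption).
    """
    T, B, d = z.shape
    # Steps 1-2: M random directions, normalized onto the unit sphere S^(d-1)
    u = torch.randn(d, num_proj, device=z.device, dtype=z.dtype)
    u = u / u.norm(dim=0, keepdim=True)
    # Step 3: project every embedding, giving h of shape (T, B, M)
    h = z @ u
    # Quadrature nodes uniformly spaced in [0, t_max] with trapezoid weights
    t = torch.linspace(0.0, t_max, knots, device=z.device, dtype=z.dtype)
    dt = t_max / (knots - 1)
    trap = torch.full((knots,), dt, device=z.device, dtype=z.dtype)
    trap[0] = trap[-1] = dt / 2
    # exp(-t^2/2) is both the N(0, 1) characteristic function and the weight w(t)
    gauss = torch.exp(-t ** 2 / 2)
    # Step 4: empirical characteristic function over the batch axis, shape (T, M, knots)
    ht = h.unsqueeze(-1) * t
    ecf_re = torch.cos(ht).mean(dim=1)
    ecf_im = torch.sin(ht).mean(dim=1)
    sq_err = (ecf_re - gauss) ** 2 + ecf_im ** 2  # |phi_B(t) - phi_0(t)|^2
    # Step 5: the integrand is even in t, so integral over [-t_max, t_max] = 2 * integral over [0, t_max]
    integral = 2 * (sq_err * gauss * trap).sum(dim=-1)  # (T, M)
    # Epps-Pulley scales by the sample count; average over time steps and directions
    return (B * integral).mean()
```

Because every operation is differentiable, the statistic can be added directly to the prediction loss; a collapsed batch (all embeddings equal) scores far higher than samples drawn from N(0, I).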
**Key insight**: The projector uses **BatchNorm1d** (not LayerNorm) because the ViT's final layer already applies LayerNorm; this choice is essential for SIGReg optimization.
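A minimal sketch of such a projector, assuming a two-layer MLP with BatchNorm1d after the hidden layer; the class name `Projector`, the hidden width of 2048, the GELU activation and the exact placement of BatchNorm1d are assumptions, not confirmed details of the official model. The relevant property is that BatchNorm1d normalizes each feature using statistics over the batch, the same axis along which SIGReg compares embeddings to N(0, I), whereas LayerNorm normalizes each sample across its own features.

```python
import torch
from torch import nn

class Projector(nn.Module):
    """Hypothetical MLP projector with BatchNorm1d (hidden width and activation are assumptions)."""

    def __init__(self, in_dim: int = 192, hidden_dim: int = 2048, out_dim: int = 192):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # normalizes each feature over the batch axis
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, cls_tokens: torch.Tensor) -> torch.Tensor:
        # cls_tokens: (N, in_dim) ViT [CLS] outputs, already LayerNorm-ed by the ViT
        return self.net(cls_tokens)

# Usage: flatten the (batch, time) grid so BatchNorm1d sees a single (N, features) batch
proj = Projector()
cls = torch.randn(8, 4, 192)                 # (B, T, 192) [CLS] tokens
z = proj(cls.flatten(0, 1)).view(8, 4, -1)   # (B, T, 192) latents
```

In training mode BatchNorm1d needs more than one sample per batch, which holds for the batch size of 128 used here.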
---
## Training
### Free GPU Training (Google Colab T4)
```python
# In a Colab notebook with GPU runtime enabled:
!pip install -q transformers einops huggingface_hub matplotlib numpy tqdm
# Mount Google Drive first: --output_dir below points into /content/drive
from google.colab import drive
drive.mount("/content/drive")
# Download implementation
from huggingface_hub import hf_hub_download
hf_hub_download("ar27111994/lewm-implementation", "lewm_model.py", local_dir="/content")
hf_hub_download("ar27111994/lewm-implementation", "lewm_train.py", local_dir="/content")
# Train with synthetic data (no 12GB download needed)
!python /content/lewm_train.py --use_synthetic \
  --n_episodes 2000 --epochs 10 --batch_size 128 \
  --lambd 0.1 --history_size 3 --seq_len 4 \
  --frameskip 5 --action_dim 2 --output_dir /content/drive/MyDrive/lewm
```
See **[COLAB_GUIDE.md](https://huggingface.co/ar27111994/lewm-implementation/blob/main/COLAB_GUIDE.md)** for the full notebook, troubleshooting, and real dataset download instructions.
### Synthetic Smoke Test (CPU-friendly)
```bash
python lewm_train.py --use_synthetic \
--n_episodes 2000 --epochs 10 \
--batch_size 128 --lr 1e-3 \
--lambd 0.1 --history_size 3 --seq_len 4 \
--frameskip 5 --action_dim 2
```
### Real PushT Dataset
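1. Download the official dataset (saved under `./data`):
```bash
python -c "from huggingface_hub import hf_hub_download; \
print(hf_hub_download('quentinll/lewm-pusht', 'pusht_expert_train.h5.zst', repo_type='dataset', local_dir='data'))"
```
2. Decompress the Zstandard archive (requires the `zstd` command-line tool) and train:
```bash
zstd -d data/pusht_expert_train.h5.zst -o data/pusht_expert_train.h5
python lewm_train.py \
  --h5_path data/pusht_expert_train.h5 \
  --epochs 10 --batch_size 128 \
  --lambd 0.1 --history_size 3
```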
### Hyperparameters (from paper)
| Parameter | Value |
|-----------|-------|
| Batch size | 128 |
| Seq length | 4 frames + 4 action blocks |
| Frame skip | 5 |
| Resolution | 224x224 |
| Epochs | 10 |
| Embedding dim | 192 |
| Predictor dropout | 0.1 |
| lambda (SIGReg weight) | 0.1 |
| History length | 3 (PushT, Cube), 1 (TwoRoom) |
| Optimizer | AdamW with cosine schedule |
**Only lambda needs tuning**: performance is insensitive to the number of projections (M=1024) and to the number of integration knots (17).
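The "AdamW with cosine schedule" row could be set up as below. This is a sketch under stated assumptions: the learning rate of 1e-3 is taken from the smoke-test command above, while `weight_decay=0.05`, decaying to 0, and stepping the schedule once per optimizer step (rather than per epoch) are assumptions; `make_optimizer` is a hypothetical helper, not part of `lewm_train.py`.

```python
import torch
from torch import nn

def make_optimizer(model: nn.Module, total_steps: int, lr: float = 1e-3, weight_decay: float = 0.05):
    """Hypothetical helper: AdamW with cosine decay from `lr` to 0 over `total_steps` steps.

    weight_decay=0.05 is an assumed value, not one taken from the paper.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)
    return opt, sched

# Toy usage: 10 epochs x 5 steps with a stand-in model and loss
model = nn.Linear(8, 8)
epochs, steps_per_epoch = 10, 5
opt, sched = make_optimizer(model, total_steps=epochs * steps_per_epoch)
for _ in range(epochs * steps_per_epoch):
    loss = model(torch.randn(4, 8)).pow(2).mean()  # placeholder for MSE + lambda * SIGReg
    opt.zero_grad()
    loss.backward()
    opt.step()
    sched.step()  # step the schedule once per optimizer step
```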
---
## Planning with CEM
```python
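import torch
from lewm_model import build_lewm, cem_plan
model = build_lewm(action_dim=10, history_size=3)
# ... load trained weights ...
# Placeholder observations for illustration only; replace with real frames
context_frames = torch.rand(1, 3, 3, 224, 224)  # (1, H, C, 224, 224) with H = history_size = 3, C = 3
goal_frame = torch.rand(1, 1, 3, 224, 224)      # (1, 1, C, 224, 224)
best_actions = cem_plan(
    model,
    initial_pixels=context_frames,
    goal_pixels=goal_frame,
    action_dim=10,
    horizon=5,      # 5 latent steps = 25 env steps (frame_skip=5)
    n_samples=300,
    n_iters=30,     # 30 for PushT, 10 for others
    n_elites=30,
    history_size=3,
)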
```
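To illustrate what a CEM planner does, here is a generic sketch of the Cross-Entropy Method over open-loop action sequences. It is not the repository's `cem_plan`: in LeWM the cost would come from encoding the frames, rolling out the predictor, and comparing the final latent with the goal embedding, whereas the hypothetical `make_cost` below uses toy additive dynamics z_{t+1} = z_t + a_t so the example runs on its own. The defaults mirror the call above (300 samples, 30 elites, 30 iterations).

```python
import torch

def cem(rollout_cost, horizon, action_dim, n_samples=300, n_elites=30, n_iters=30, init_std=1.0):
    """Generic Cross-Entropy Method over open-loop action sequences (illustrative sketch).

    rollout_cost: maps candidate actions of shape (N, horizon, action_dim) to costs of shape (N,).
    Returns the final mean action sequence, shape (horizon, action_dim).
    """
    mean = torch.zeros(horizon, action_dim)
    std = torch.full((horizon, action_dim), init_std)
    for _ in range(n_iters):
        # Sample candidate sequences from the current diagonal Gaussian
        actions = mean + std * torch.randn(n_samples, horizon, action_dim)
        costs = rollout_cost(actions)
        # Keep the lowest-cost candidates (the elites) and refit the Gaussian to them
        elite_idx = torch.topk(costs, n_elites, largest=False).indices
        elites = actions[elite_idx]
        mean = elites.mean(dim=0)
        std = elites.std(dim=0) + 1e-6  # small floor keeps sampling non-degenerate
    return mean

def make_cost(z0, goal):
    """Hypothetical toy cost: additive latent dynamics z_{t+1} = z_t + a_t, squared distance to goal."""
    def cost(actions):
        z_final = z0 + actions.sum(dim=1)  # roll out all horizon steps at once
        return ((z_final - goal) ** 2).sum(dim=-1)
    return cost

# Usage: plan 5 steps of 2-D actions that move the latent from the origin to the goal
plan = cem(make_cost(torch.zeros(2), torch.tensor([1.0, -2.0])), horizon=5, action_dim=2)
```

Since only the cost function touches the model, the planning cost per iteration is one batched rollout of `n_samples` candidates, which is where a compact latent space pays off.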
---
## Results (from paper)
| Method | PushT Success Rate | Planning Time |
|--------|-------------------|---------------|
| **LeWM (ours)** | **96.0 +/- 2.8%** | **<1 sec** |
| DINO-WM | 92.0 +/- 1.6% | ~48x slower |
| PLDM | 78.0 +/- 5.0% | comparable |
- **48x faster planning** than DINO-WM due to ~200x fewer tokens in latent space
- Single GPU (L40S) training in "a few hours"
- No stop-gradient, no EMA, no pre-trained encoders
---
## Project Structure
```
lewm_model.py - Core model (Encoder, Predictor, SIGReg, CEM)
lewm_train.py - Training script (HDF5 + synthetic datasets)
lewm_mini_test.py - Minimal sanity check
lewm_colab.ipynb - Full Colab-ready training notebook
COLAB_GUIDE.md - Step-by-step free GPU training guide
config.json - Verified architecture config from official model
EXPLANATION.md - 16KB deep-dive technical explanation
```
---
## Interactive Demo
Try the **explainable interactive Space** (no training required):
🔗 **https://huggingface.co/spaces/ar27111994/lewm-explainable**
Features:
- **Architecture** tab: Full pipeline schematic
- **SIGReg Explorer**: Adjust collapse level and see real-time distributional analysis
- **CEM Planning**: Visualize Cross-Entropy Method convergence
- **Key Results**: Paper results and hyperparameters
---
## Citation
```bibtex
@article{maes_lelidec2026lewm,
title={LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels},
author={Maes, Lucas and Le Lidec, Quentin and Scieur, Damien and LeCun, Yann and Balestriero, Randall},
journal={arXiv preprint arXiv:2603.19312},
year={2026}
}
```
---
## License
MIT (same as the official repository).
---
*This implementation is self-contained in standard PyTorch + transformers + einops, with no dependency on the private `stable-pretraining` or `stable-worldmodel` packages for the core model logic.*
<!-- ml-intern-provenance -->
## Generated by ML Intern
This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern
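## Usage
LeWM is a JEPA world model rather than a `transformers` language model, so `AutoModelForCausalLM` and `AutoTokenizer` do not apply. Build the model from `lewm_model.py` in this repository instead:
```python
import os
import sys
from huggingface_hub import hf_hub_download
# Download the model definition and make it importable
model_file = hf_hub_download("ar27111994/lewm-implementation", "lewm_model.py")
sys.path.insert(0, os.path.dirname(model_file))
from lewm_model import build_lewm
model = build_lewm(action_dim=10, history_size=3)
```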