# 🚀 Free GPU Training Guide for LeWorldModel on Google Colab
This guide walks you through training **LeWorldModel (~18M params)** on a **free Google Colab T4 GPU** (15 GB VRAM). No paid credits are required.
**Time estimate**: 30–60 minutes for 10 epochs on synthetic data.
---
## Step 1: Open Colab & Enable GPU
1. Go to [colab.research.google.com](https://colab.research.google.com)
2. Create a new notebook
3. **Runtime → Change runtime type → Hardware accelerator: GPU** (T4)
4. Verify GPU:
```python
!nvidia-smi
import torch
print(torch.cuda.get_device_name(0)) # Should print Tesla T4
```
---
## Step 2: Install Dependencies
```python
%%capture
!pip install -q transformers einops huggingface_hub matplotlib numpy tqdm
```
---
## Step 3: Mount Google Drive (for checkpoints)
```python
from google.colab import drive
drive.mount('/content/drive')
import os
OUTPUT_DIR = "/content/drive/MyDrive/lewm_training"
os.makedirs(OUTPUT_DIR, exist_ok=True)
```
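Because Colab can disconnect mid-run, it helps to resume from the newest checkpoint saved on Drive. The sketch below assumes checkpoints use the `epoch_{N}.pt` naming from Step 7; `latest_checkpoint` is a hypothetical helper, not part of the repository.

```python
import os
import re
from typing import Optional, Tuple

# Matches the epoch_{N}.pt filenames written by the training loop (assumed naming).
_CKPT_RE = re.compile(r"^epoch_(\d+)\.pt$")


def latest_checkpoint(output_dir: str) -> Optional[Tuple[int, str]]:
    """Return (epoch, path) of the highest-numbered epoch_{N}.pt file, or None."""
    if not os.path.isdir(output_dir):
        return None
    best = None
    for name in os.listdir(output_dir):
        match = _CKPT_RE.match(name)
        if match:
            epoch = int(match.group(1))
            # Compare numerically so epoch_10 beats epoch_9 (a string sort would not).
            if best is None or epoch > best[0]:
                best = (epoch, os.path.join(output_dir, name))
    return best
```

After building the model in Step 5, something like `model.load_state_dict(torch.load(path, map_location=device))` would restore the weights, and the epoch loop could then run over `range(start_epoch, CONFIG["epochs"])`. The loop in Step 7 saves only model weights, so optimizer and scheduler state would restart from scratch.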
---
## Step 4: Download Implementation
```python
from huggingface_hub import hf_hub_download  # already installed in Step 2

REPO_ID = "ar27111994/lewm-implementation"
for fname in ["lewm_model.py", "lewm_train.py"]:
    hf_hub_download(repo_id=REPO_ID, filename=fname, local_dir="/content")
    print(f"✓ {fname}")
```
---
## Step 5: Sanity Check
```python
import sys
sys.path.insert(0, "/content")  # make the downloaded modules importable

import torch
from lewm_model import build_lewm

device = torch.device("cuda")
model = build_lewm(action_dim=2, frameskip=5, history_size=3).to(device)
print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

# Quick forward pass on random data: B sequences of T frames each
B, T = 2, 4
obs = torch.randn(B, T, 3, 224, 224, device=device)
# 10 action features per step = action_dim (2) x frameskip (5),
# assuming the actions for skipped frames are concatenated
acts = torch.randn(B, T, 10, device=device)
out = model(obs, acts, history_size=3)
print(f"loss={out['loss'].item():.4f} pred={out['pred_loss'].item():.4f} sigreg={out['sigreg_loss'].item():.4f}")
print("✅ Ready!")
```
---
## Step 6: Training Configuration
```python
CONFIG = {
"use_synthetic": True, # True = no 12GB download needed
"n_episodes": 2000, # Synthetic episodes (use 500 for quick test)
"seq_len": 4,
"frameskip": 5,
"img_size": 224,
"action_dim": 2,
"history_size": 3,
"embed_dim": 192,
"lambd": 0.1, # ONLY hyperparameter to tune!
"epochs": 10,
"batch_size": 128, # Reduce to 64 if OOM
"lr": 1e-3,
"weight_decay": 0.05,
"grad_clip": 1.0,
"num_workers": 2,
"log_interval": 25,
"output_dir": OUTPUT_DIR,
}
```
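Several settings depend on each other: the context window must leave at least one frame to predict, and the warmup length in Step 7 is derived from the number of batches. The hypothetical helper below checks the first and computes the second, assuming the 5% warmup and `drop_last=True` used in Step 7.

```python
from typing import Dict


def derive_training_quantities(config: Dict, num_samples: int) -> Dict[str, int]:
    """Check a few consistency constraints and compute derived sizes (hypothetical helper)."""
    seq_len, history = config["seq_len"], config["history_size"]
    # The predictor needs at least one frame beyond the context window as a target.
    if not 1 <= history < seq_len:
        raise ValueError(f"history_size ({history}) must be in [1, seq_len={seq_len})")
    if config["lambd"] <= 0:
        raise ValueError("lambd must be positive")
    # Assumption: actions for the skipped frames are concatenated, which is why the
    # sanity check in Step 5 feeds 2 * 5 = 10 action features per step.
    action_input_dim = config["action_dim"] * config["frameskip"]
    # drop_last=True in the DataLoader discards a partial final batch.
    steps_per_epoch = num_samples // config["batch_size"]
    total_steps = steps_per_epoch * config["epochs"]
    # Mirrors the 5% linear warmup passed to get_cosine_schedule_with_warmup in Step 7.
    warmup_steps = int(0.05 * total_steps)
    return {
        "action_input_dim": action_input_dim,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }
```

Once the dataset from Step 7 exists, `derive_training_quantities(CONFIG, len(dataset))` returns the step counts the scheduler will see.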
---
## Step 7: Full Training Loop
The complete training cell is in the notebook in the repository. Its key components are:
```python
import torch
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from tqdm.notebook import tqdm

from lewm_model import SIGReg
from lewm_train import SyntheticPushTDataset  # or TrajectoryDataset for real data

# `model` and `device` come from the sanity check in Step 5.
dataset = SyntheticPushTDataset(
    n_episodes=CONFIG["n_episodes"],
    seq_len=CONFIG["seq_len"],
    frameskip=CONFIG["frameskip"],
    img_size=CONFIG["img_size"],
)
loader = DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=True,
                    num_workers=CONFIG["num_workers"], drop_last=True, pin_memory=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"],
                              weight_decay=CONFIG["weight_decay"], betas=(0.9, 0.95))
total_steps = len(loader) * CONFIG["epochs"]
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=int(0.05 * total_steps),  # 5% linear warmup
    num_training_steps=total_steps)
sigreg = SIGReg(knots=17, num_proj=1024).to(device)
H = CONFIG["history_size"]

for epoch in range(CONFIG["epochs"]):
    model.train()
    for step, (obs, acts) in enumerate(tqdm(loader, desc=f"Epoch {epoch+1}")):
        obs = obs.to(device, non_blocking=True)    # (B, T, 3, img_size, img_size)
        acts = acts.to(device, non_blocking=True)  # (B, T, action_dim * frameskip)
        emb = model.encode(obs)                    # (B, T, D), assumed
        act_emb = model.action_encoder(acts)
        # Predict from the first H frames; prediction t targets the embedding of frame t+1,
        # so the targets are frames 1..H (the last one lies outside the context window).
        pred_emb = model.predict(emb[:, :H], act_emb[:, :H])
        pred_loss = (pred_emb - emb[:, 1:H + 1]).pow(2).mean()
        # SIGReg pushes the embeddings toward an isotropic Gaussian to prevent collapse.
        sigreg_loss = sigreg(emb.transpose(0, 1))
        loss = pred_loss + CONFIG["lambd"] * sigreg_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["grad_clip"])
        optimizer.step()
        scheduler.step()

        if step % CONFIG["log_interval"] == 0:
            print(f"step {step}: loss={loss.item():.4f} pred={pred_loss.item():.4f} "
                  f"sigreg={sigreg_loss.item():.4f}")

    # Save a checkpoint every epoch so a Colab disconnect loses at most one epoch.
    torch.save(model.state_dict(), f"{OUTPUT_DIR}/epoch_{epoch+1}.pt")
    print(f"Epoch {epoch+1} complete. Saved checkpoint.")
```
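`SIGReg` is imported from `lewm_model` and its internals are not shown in this guide. For intuition only, the pure-Python sketch below computes an Epps-Pulley-style statistic on random 1-D projections of the embeddings, comparing each projection's empirical characteristic function with that of a standard normal. The integration range `[0, 3]`, the Gaussian weight, and averaging over projections are assumptions of this sketch, not necessarily what `SIGReg(knots=17, num_proj=1024)` does.

```python
import math
import random
from typing import List, Sequence


def epps_pulley_1d(xs: Sequence[float], knots: int = 17, t_max: float = 3.0) -> float:
    """n * integral over [0, t_max] of |ECF(t) - exp(-t^2/2)|^2 * exp(-t^2/2) dt (trapezoid rule)."""
    n = len(xs)
    ts = [t_max * k / (knots - 1) for k in range(knots)]
    values = []
    for t in ts:
        # Empirical characteristic function: mean of exp(i*t*x) = mean cos + i * mean sin.
        real = sum(math.cos(t * x) for x in xs) / n
        imag = sum(math.sin(t * x) for x in xs) / n
        target = math.exp(-0.5 * t * t)  # CF of N(0, 1) is real-valued
        weight = math.exp(-0.5 * t * t)  # down-weights large |t|
        values.append(((real - target) ** 2 + imag ** 2) * weight)
    h = ts[1] - ts[0]
    integral = h * (sum(values) - 0.5 * (values[0] + values[-1]))
    return n * integral


def sigreg_sketch(embeddings: List[List[float]], num_proj: int = 64,
                  knots: int = 17, seed: int = 0) -> float:
    """Average the 1-D statistic over random unit-norm projection directions."""
    rng = random.Random(seed)
    dim = len(embeddings[0])
    total = 0.0
    for _ in range(num_proj):
        direction = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(d * d for d in direction))
        direction = [d / norm for d in direction]
        projected = [sum(e * d for e, d in zip(emb, direction)) for emb in embeddings]
        total += epps_pulley_1d(projected, knots=knots)
    return total / num_proj
```

In this sketch, collapsed embeddings (every sample mapped to the same vector) score far higher than isotropic Gaussian ones, which is the failure mode the regularizer is meant to penalize; `lambd` sets how strongly that term counts against `pred_loss`.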
---
## Step 8: Using Real PushT Dataset (Optional)
**Warning**: The real dataset is **12 GB compressed**, and decompressing it writes a second, larger file next to the archive. Colab's disk is only ~78 GB and may fill up.
```python
# Download & decompress (takes ~5-10 minutes)
!pip install -q zstandard h5py
import os
import zstandard
from huggingface_hub import hf_hub_download

zst_path = hf_hub_download(
    repo_id='quentinll/lewm-pusht',
    filename='pusht_expert_train.h5.zst',
    repo_type='dataset',
    local_dir='/content/data',
)
h5_path = '/content/data/pusht_expert_train.h5'
dctx = zstandard.ZstdDecompressor()  # decompression context
with open(zst_path, 'rb') as inf, open(h5_path, 'wb') as outf:
    dctx.copy_stream(inf, outf)  # streams in chunks, so RAM use stays small
# Optional: free ~12 GB of disk once the .h5 file has been written
# os.remove(zst_path)
print(f"Ready: {h5_path} ({os.path.getsize(h5_path)/1e9:.1f} GB)")
# Then use TrajectoryDataset instead of SyntheticPushTDataset
```
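Since the archive and the decompressed file coexist on disk until you delete one, it may be worth checking free space before starting. A minimal standard-library sketch; `has_free_space` is a hypothetical helper, and the required size is an estimate you supply, since this guide does not state the decompressed size.

```python
import shutil


def has_free_space(path: str, needed_gb: float) -> bool:
    """Return True if the filesystem holding `path` has at least `needed_gb` GB free."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes / 1e9 >= needed_gb
```

For example, `has_free_space("/content", needed_gb)` with a threshold of your choosing could gate the decompression cell.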
---
## Step 9: Push to Hugging Face Hub
```python
from huggingface_hub import HfApi, login
from getpass import getpass
hf_token = getpass("Enter HF token: ")
login(token=hf_token)
api = HfApi()
repo_id = "YOUR_USERNAME/lewm-colab-pusht"
api.create_repo(repo_id, repo_type="model", exist_ok=True)
checkpoint = f"{OUTPUT_DIR}/epoch_{CONFIG['epochs']}.pt"  # final epoch from Step 7
api.upload_file(
    path_or_fileobj=checkpoint,
    path_in_repo="model.pt",
    repo_id=repo_id,
    repo_type="model",
)
print(f"✅ Pushed to https://huggingface.co/{repo_id}")
```
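Optionally, the training config can be uploaded next to the weights so the pushed checkpoint is reproducible. A minimal sketch; `save_config` is a hypothetical helper and `config.json` an arbitrary filename.

```python
import json
import os
from typing import Any, Dict


def save_config(config: Dict[str, Any], output_dir: str, filename: str = "config.json") -> str:
    """Write the training config as JSON so it can be uploaded next to the checkpoint."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, filename)
    with open(path, "w", encoding="utf-8") as f:
        # default=str turns any non-JSON value (e.g. a Path) into a string instead of failing.
        json.dump(config, f, indent=2, sort_keys=True, default=str)
    return path
```

The returned path can then go through the same call as the checkpoint: `api.upload_file(path_or_fileobj=save_config(CONFIG, OUTPUT_DIR), path_in_repo="config.json", repo_id=repo_id, repo_type="model")`.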
---
## ⚡ Quick Tips
| Issue | Fix |
|-------|-----|
| **OOM** | Reduce `batch_size` to 64 or 32 |
| **Slow** | Reduce `n_episodes` to 500; use `img_size=112` |
| **Colab disconnects** | Save every epoch (already done); use shorter epochs |
| **Want better planning** | Train on the real PushT data (Step 8) for better generalization |
| **SIGReg too high** | Increase `lambd` to 0.2 or 0.5 |
| **Prediction loss diverges** | Reduce `lr` to 5e-4; check data normalization |
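Instead of lowering `batch_size` by hand after each OOM, a retry wrapper can halve it automatically. A sketch under assumptions: `train_fn` is a hypothetical function you write that rebuilds the `DataLoader` with the given batch size and trains; with PyTorch you would pass `oom_error=torch.cuda.OutOfMemoryError` (available in recent PyTorch releases).

```python
from typing import Callable, Type


def run_with_batch_fallback(
    train_fn: Callable[[int], None],
    batch_size: int,
    oom_error: Type[BaseException] = MemoryError,
    min_batch_size: int = 8,
) -> int:
    """Call train_fn(batch_size), halving batch_size after each out-of-memory error."""
    while batch_size >= min_batch_size:
        try:
            train_fn(batch_size)
            return batch_size  # the batch size that finally worked
        except oom_error:
            # On GPU, calling torch.cuda.empty_cache() here can help before retrying.
            print(f"OOM at batch_size={batch_size}; halving")
            batch_size //= 2
    raise RuntimeError(f"Still out of memory below min_batch_size={min_batch_size}")
```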
---
## πŸ“Š Expected Results
With 2000 synthetic episodes, 10 epochs:
| Metric | Typical Value |
|--------|--------------|
| Final pred_loss | 0.05–0.15 |
| Final sigreg_loss | 0.3–0.6 |
| Training time | 30–60 min |
| GPU memory used | ~8–10 GB |
---
## πŸ”— Resources
- **Paper**: [arXiv:2603.19312](https://arxiv.org/abs/2603.19312)
- **This implementation**: https://huggingface.co/ar27111994/lewm-implementation
- **Interactive demo**: https://huggingface.co/spaces/ar27111994/lewm-explainable
- **Official repo**: https://github.com/lucas-maes/le-wm
- **Official datasets**: `quentinll/lewm-pusht`, `quentinll/lewm-cube`
---
*Happy training! 🚀*