Free GPU Training Guide for LeWorldModel on Google Colab
This guide walks you through training LeWorldModel (~18M parameters) on a free Google Colab T4 GPU (15 GB VRAM); no paid credits required.
Time estimate: 30–60 minutes for 10 epochs on synthetic data.
Step 1: Open Colab & Enable GPU
- Go to colab.research.google.com
- Create a new notebook
- Runtime → Change runtime type → Hardware accelerator: GPU (T4)
- Verify the GPU:
```python
!nvidia-smi
import torch
print(torch.cuda.get_device_name(0))  # should print "Tesla T4"
```
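If you want the notebook to fail fast when Colab hands you a CPU-only runtime, you can parse the CSV output of an `nvidia-smi` query instead of reading it by eye. This is a minimal sketch: the helper names `parse_gpu_query` and `query_gpus` are hypothetical. It relies on the `--query-gpu=name,memory.total --format=csv,noheader,nounits` flags, which report total memory in MiB.

```python
import subprocess


def parse_gpu_query(output: str) -> list[tuple[str, int]]:
    """Parse `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits`.

    Each non-empty line looks like "Tesla T4, 15360" (memory in MiB).
    Returns a list of (gpu_name, total_memory_mib) tuples.
    """
    gpus = []
    for line in output.strip().splitlines():
        if not line.strip():
            continue
        # Split on the last comma so GPU names containing commas stay intact
        name, mem = (field.strip() for field in line.rsplit(",", 1))
        gpus.append((name, int(mem)))
    return gpus


def query_gpus() -> list[tuple[str, int]]:
    """Run nvidia-smi and parse its output.

    Raises FileNotFoundError or subprocess.CalledProcessError if nvidia-smi
    is missing or fails, which is typical of a runtime without a GPU.
    """
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    )
    return parse_gpu_query(result.stdout)
```

In a Colab cell, `gpus = query_gpus(); assert gpus and "T4" in gpus[0][0]` stops the notebook early if no T4 was assigned.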
Step 2: Install Dependencies
```python
%%capture
# PyTorch is preinstalled on Colab; %%capture hides pip's output
!pip install -q transformers einops huggingface_hub matplotlib numpy tqdm
```
Step 3: Mount Google Drive (for checkpoints)
```python
from google.colab import drive
drive.mount('/content/drive')

import os
OUTPUT_DIR = "/content/drive/MyDrive/lewm_training"  # checkpoints on Drive survive disconnects
os.makedirs(OUTPUT_DIR, exist_ok=True)
```
Step 4: Download Implementation
```python
!pip install -q huggingface_hub  # already installed in Step 2; harmless to repeat
from huggingface_hub import hf_hub_download

REPO_ID = "ar27111994/lewm-implementation"
for fname in ["lewm_model.py", "lewm_train.py"]:
    # Save each file into /content so it can be imported directly
    hf_hub_download(repo_id=REPO_ID, filename=fname, local_dir="/content")
    print(f"✓ {fname}")
```
Step 5: Sanity Check
```python
import sys
sys.path.insert(0, "/content")  # make the downloaded modules importable

import torch
from lewm_model import build_lewm

device = torch.device("cuda")
model = build_lewm(action_dim=2, frameskip=5, history_size=3).to(device)
print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

# Quick forward pass on random inputs (no gradients needed for a smoke test)
B, T = 2, 4
with torch.no_grad():
    out = model(
        torch.randn(B, T, 3, 224, 224, device=device),  # frames: (batch, time, C, H, W)
        torch.randn(B, T, 10, device=device),           # actions: 10 values per step
        history_size=3,
    )
print(f"loss={out['loss'].item():.4f} pred={out['pred_loss'].item():.4f} sigreg={out['sigreg_loss'].item():.4f}")
print("✅ Ready!")
```
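The action width of 10 in the dummy input matches `action_dim * frameskip` (2 × 5), which suggests the actions of the skipped frames are concatenated into one vector per step. That reading is an assumption inferred from the numbers, not something stated by the implementation. If it holds, a small helper keeps the dummy shapes consistent when you change `action_dim` or `frameskip`; `dummy_input_shapes` is a hypothetical name.

```python
def dummy_input_shapes(batch: int, seq_len: int, img_size: int,
                       action_dim: int, frameskip: int) -> dict[str, tuple[int, ...]]:
    """Return tensor shapes for a random forward-pass check.

    Assumes RGB frames laid out as (B, T, C, H, W) and actions stacked over
    `frameskip` environment steps, i.e. action_dim * frameskip values per step.
    """
    return {
        "obs": (batch, seq_len, 3, img_size, img_size),
        "actions": (batch, seq_len, action_dim * frameskip),
    }


shapes = dummy_input_shapes(batch=2, seq_len=4, img_size=224, action_dim=2, frameskip=5)
print(shapes["actions"])  # (2, 4, 10)
```

With these, the forward pass can use `torch.randn(*shapes["obs"], device=device)` and `torch.randn(*shapes["actions"], device=device)`.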
Step 6: Training Configuration
```python
CONFIG = {
    "use_synthetic": True,   # True = no 12 GB download needed
    "n_episodes": 2000,      # synthetic episodes (use 500 for a quick test)
    "seq_len": 4,
    "frameskip": 5,
    "img_size": 224,
    "action_dim": 2,
    "history_size": 3,
    "embed_dim": 192,
    "lambd": 0.1,            # SIGReg weight λ: the only hyperparameter to tune
    "epochs": 10,
    "batch_size": 128,       # reduce to 64 if you hit OOM
    "lr": 1e-3,
    "weight_decay": 0.05,
    "grad_clip": 1.0,
    "num_workers": 2,        # DataLoader worker processes
    "log_interval": 25,      # log every N optimizer steps
    "output_dir": OUTPUT_DIR,
}
```
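To see what the cosine schedule in Step 7 will do, it helps to work out the step counts up front. With `drop_last=True`, `len(loader)` is `num_samples // batch_size`, and the warmup covers 5% of all steps. The sample count in the example is an assumption: it treats each synthetic episode as one sample, which may not match how `SyntheticPushTDataset` slices episodes, so substitute `len(dataset)` once the dataset is built. `schedule_steps` is a hypothetical helper.

```python
def schedule_steps(num_samples: int, batch_size: int, epochs: int,
                   warmup_frac: float = 0.05) -> tuple[int, int, int]:
    """Return (steps_per_epoch, total_steps, warmup_steps) for a drop_last DataLoader."""
    steps_per_epoch = num_samples // batch_size    # drop_last=True discards the partial batch
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(warmup_frac * total_steps)  # truncates, like int(0.05 * ...) in Step 7
    return steps_per_epoch, total_steps, warmup_steps


# Hypothetical: one sample per synthetic episode
print(schedule_steps(num_samples=2000, batch_size=128, epochs=10))  # (15, 150, 7)
```

If `steps_per_epoch` turns out smaller than `log_interval`, consider lowering `log_interval` so losses are reported at least once per epoch.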
Step 7: Full Training Loop
The full training cell is in the notebook in the repository; its key components are:
```python
import torch
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from tqdm.notebook import tqdm

from lewm_model import SIGReg
from lewm_train import SyntheticPushTDataset  # or TrajectoryDataset for real data

H = CONFIG["history_size"]

dataset = SyntheticPushTDataset(
    n_episodes=CONFIG["n_episodes"],
    seq_len=CONFIG["seq_len"],
    frameskip=CONFIG["frameskip"],
    img_size=CONFIG["img_size"],
)
loader = DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=True,
                    num_workers=CONFIG["num_workers"], drop_last=True, pin_memory=True)

# `model` and `device` come from Step 5
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"],
                              weight_decay=CONFIG["weight_decay"], betas=(0.9, 0.95))
total_steps = len(loader) * CONFIG["epochs"]
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=int(0.05 * total_steps),  # 5% linear warmup
    num_training_steps=total_steps)
sigreg = SIGReg(knots=17, num_proj=1024).to(device)

step = 0
for epoch in range(CONFIG["epochs"]):
    model.train()
    for obs, acts in tqdm(loader, desc=f"Epoch {epoch+1}"):
        obs = obs.to(device, non_blocking=True)
        acts = acts.to(device, non_blocking=True)

        emb = model.encode(obs)               # per-frame embeddings, (B, T, ...)
        act_emb = model.action_encoder(acts)  # per-step action embeddings
        ctx_emb, ctx_act = emb[:, :H], act_emb[:, :H]
        pred_emb = model.predict(ctx_emb, ctx_act)  # one next-step prediction per context step

        # The prediction at step t targets the embedding at t+1 (needs seq_len >= H + 1)
        pred_loss = (pred_emb - emb[:, 1:H + 1]).pow(2).mean()
        sigreg_loss = sigreg(emb.transpose(0, 1))  # transpose to (T, B, ...) before SIGReg
        loss = pred_loss + CONFIG["lambd"] * sigreg_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["grad_clip"])
        optimizer.step()
        scheduler.step()

        step += 1
        if step % CONFIG["log_interval"] == 0:
            tqdm.write(f"step {step}: loss={loss.item():.4f} "
                       f"pred={pred_loss.item():.4f} sigreg={sigreg_loss.item():.4f}")

    # Save a checkpoint to Drive after every epoch
    torch.save(model.state_dict(), f"{OUTPUT_DIR}/epoch_{epoch+1}.pt")
    print(f"Epoch {epoch+1} complete. Saved checkpoint.")
```
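Because a checkpoint lands on Drive after every epoch, a disconnected session can pick up where it left off. The sketch below finds the highest-numbered `epoch_N.pt` file in `OUTPUT_DIR`; `latest_checkpoint` is a hypothetical helper, not part of the implementation. The loop saves only the model weights, so on resume the optimizer and scheduler start fresh (including a new warmup); saving their `state_dict()`s alongside the model would make resuming exact.

```python
import os
import re
from typing import Optional, Tuple

# Matches names like "epoch_7.pt" as written by the training loop
_CKPT_RE = re.compile(r"^epoch_(\d+)\.pt$")


def latest_checkpoint(output_dir: str) -> Tuple[int, Optional[str]]:
    """Return (last_saved_epoch, path) for the highest-numbered epoch_N.pt, or (0, None)."""
    best_epoch, best_path = 0, None
    if not os.path.isdir(output_dir):
        return best_epoch, best_path
    for name in os.listdir(output_dir):
        match = _CKPT_RE.match(name)
        # Compare epochs numerically so epoch_10 beats epoch_9 (a string sort would not)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(output_dir, name)
    return best_epoch, best_path


# Resuming in the notebook (assumes `torch`, `model`, `device`, `CONFIG`, `OUTPUT_DIR`
# from the earlier steps):
# start_epoch, path = latest_checkpoint(OUTPUT_DIR)
# if path is not None:
#     model.load_state_dict(torch.load(path, map_location=device))
# for epoch in range(start_epoch, CONFIG["epochs"]):
#     ...
```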
Step 8: Using Real PushT Dataset (Optional)
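Warning: The real dataset is ~12 GB compressed, and the decompressed file needs additional space next to it. The Colab disk is only ~78 GB and can fill up; the code below writes to the local disk under /content/data.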
```python
# Download & decompress (takes ~5-10 minutes)
!pip install -q zstandard h5py
import os
import zstandard
from huggingface_hub import hf_hub_download

zst_path = hf_hub_download(
    repo_id="quentinll/lewm-pusht",
    filename="pusht_expert_train.h5.zst",
    repo_type="dataset",
    local_dir="/content/data",
)
h5_path = "/content/data/pusht_expert_train.h5"

# Stream-decompress so the file never has to fit in RAM
dctx = zstandard.ZstdDecompressor()
with open(zst_path, "rb") as inf, open(h5_path, "wb") as outf:
    dctx.copy_stream(inf, outf)
os.remove(zst_path)  # optional: delete the compressed archive to free disk space
print(f"Ready: {h5_path} ({os.path.getsize(h5_path)/1e9:.1f} GB)")
# Then use TrajectoryDataset instead of SyntheticPushTDataset
```
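Given the disk warning above, it can be worth checking free space before starting the download. This sketch uses `shutil.disk_usage` from the standard library; the helper names are hypothetical. The decompressed size is not stated here, so the amount you require is your own estimate: the ~12 GB archive plus the decompressed file, since both exist on disk until the archive is deleted.

```python
import shutil


def free_gb(path: str) -> float:
    """Free space, in GB (1e9 bytes), on the filesystem that holds `path`."""
    return shutil.disk_usage(path).free / 1e9


def require_free_gb(path: str, needed_gb: float) -> float:
    """Raise RuntimeError if fewer than `needed_gb` GB are free; otherwise return free GB."""
    available = free_gb(path)
    if available < needed_gb:
        raise RuntimeError(f"Only {available:.1f} GB free at {path}; need ~{needed_gb:.1f} GB")
    return available


# In Colab, before downloading (decompressed_gb is your own estimate, not a known size):
# require_free_gb("/content", 12 + decompressed_gb)
```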
Step 9: Push to Hugging Face Hub
```python
from getpass import getpass
from huggingface_hub import HfApi, login

hf_token = getpass("Enter HF token: ")  # needs a token with write access
login(token=hf_token)

api = HfApi()
repo_id = "YOUR_USERNAME/lewm-colab-pusht"  # replace with your HF username
api.create_repo(repo_id, repo_type="model", exist_ok=True)
api.upload_file(
    path_or_fileobj=f"{OUTPUT_DIR}/epoch_{CONFIG['epochs']}.pt",  # final-epoch checkpoint
    path_in_repo="model.pt",
    repo_id=repo_id,
    repo_type="model",
)
print(f"✅ Pushed to https://huggingface.co/{repo_id}")
```
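Uploading only `model.pt` leaves anyone who downloads it guessing at `action_dim`, `frameskip`, and `history_size`. A minimal sketch, assuming you also want a `config.json` next to the weights: serialize `CONFIG` without the Colab-specific `output_dir`, then pass the bytes to `HfApi.upload_file`, whose `path_or_fileobj` accepts raw bytes. The helper name `config_json_bytes` is hypothetical.

```python
import json


def config_json_bytes(config: dict, drop_keys: tuple[str, ...] = ("output_dir",)) -> bytes:
    """Serialize a training config to UTF-8 JSON, omitting machine-specific keys."""
    portable = {k: v for k, v in config.items() if k not in drop_keys}
    return json.dumps(portable, indent=2, sort_keys=True).encode("utf-8")


# In the notebook, after the model upload (uses `api`, `repo_id`, `CONFIG` from above):
# api.upload_file(
#     path_or_fileobj=config_json_bytes(CONFIG),
#     path_in_repo="config.json",
#     repo_id=repo_id,
#     repo_type="model",
# )
```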
Quick Tips
| Issue | Fix |
|---|---|
| OOM | Reduce batch_size to 64 or 32 |
| Slow | Reduce n_episodes to 500; try img_size=112 if the encoder accepts smaller inputs |
| Colab disconnects | Save every epoch (already done); use shorter epochs |
| Weak planning performance | Train on the real PushT dataset (Step 8) for better generalization |
| SIGReg loss stays high | Increase lambd to 0.2 or 0.5 |
| Prediction loss diverges | Reduce lr to 5e-4; check data normalization |
Expected Results
With 2000 synthetic episodes, 10 epochs:
| Metric | Typical Value |
|---|---|
| Final pred_loss | 0.05–0.15 |
| Final sigreg_loss | 0.3–0.6 |
| Training time | 30–60 min |
| GPU memory used | ~8–10 GB |
Resources
- Paper: arXiv:2603.19312
- This implementation: https://huggingface.co/ar27111994/lewm-implementation
- Interactive demo: https://huggingface.co/spaces/ar27111994/lewm-explainable
- Official repo: https://github.com/lucas-maes/le-wm
- Official datasets: quentinll/lewm-pusht, quentinll/lewm-cube
Happy training!