| # Free GPU Training Guide for LeWorldModel on Google Colab |
|
|
| This guide walks you through training **LeWorldModel (~18M params)** on a **free Google Colab T4 GPU** (15 GB VRAM), with no paid credits required. |
|
|
| **Time estimate**: 30–60 minutes for 10 epochs on synthetic data. |
|
|
| --- |
|
|
| ## Step 1: Open Colab & Enable GPU |
|
|
| 1. Go to [colab.research.google.com](https://colab.research.google.com) |
| 2. Create a new notebook |
| 3. **Runtime → Change runtime type → Hardware accelerator: GPU** (T4) |
| 4. Verify GPU: |
| ```python |
| !nvidia-smi |
| import torch |
| print(torch.cuda.get_device_name(0)) # Should print Tesla T4 |
| ``` |
|
|
| --- |
|
|
| ## Step 2: Install Dependencies |
|
|
| ```python |
| %%capture |
| !pip install -q transformers einops huggingface_hub matplotlib numpy tqdm |
| ``` |
|
|
| --- |
|
|
| ## Step 3: Mount Google Drive (for checkpoints) |
|
|
| ```python |
| from google.colab import drive |
| drive.mount('/content/drive') |
| |
| import os |
| OUTPUT_DIR = "/content/drive/MyDrive/lewm_training" |
| os.makedirs(OUTPUT_DIR, exist_ok=True) |
| ``` |
|
|
| --- |
|
|
| ## Step 4: Download Implementation |
|
|
| ```python |
| from huggingface_hub import hf_hub_download  # already installed in Step 2 |
| 
| REPO_ID = "ar27111994/lewm-implementation" |
| 
| # Fetch the model and training modules into /content so they can be imported |
| for fname in ["lewm_model.py", "lewm_train.py"]: |
|     hf_hub_download(repo_id=REPO_ID, filename=fname, local_dir="/content") |
|     print(f"Downloaded {fname}") |
| ``` |
|
|
| --- |
|
|
| ## Step 5: Sanity Check |
|
|
| ```python |
| import sys |
| sys.path.insert(0, "/content") |
| |
| from lewm_model import build_lewm |
| import torch |
| |
| device = torch.device("cuda") |
| model = build_lewm(action_dim=2, frameskip=5, history_size=3).to(device) |
| print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M") |
| |
| # Quick forward pass on random data |
| B, T = 2, 4 |
| out = model( |
|     torch.randn(B, T, 3, 224, 224, device=device),  # frames: (batch, time, C, H, W) |
|     torch.randn(B, T, 10, device=device),  # actions: width 10, presumably action_dim * frameskip |
|     history_size=3, |
| ) |
| print(f"loss={out['loss'].item():.4f} pred={out['pred_loss'].item():.4f} sigreg={out['sigreg_loss'].item():.4f}") |
| print("Ready!") |
| ``` |
|
|
| --- |
|
|
| ## Step 6: Training Configuration |
|
|
| ```python |
| CONFIG = { |
|     "use_synthetic": True,    # True = skip the 12 GB real-dataset download |
|     "n_episodes": 2000,       # Synthetic episodes (use 500 for a quick test) |
|     "seq_len": 4, |
|     "frameskip": 5, |
|     "img_size": 224, |
|     "action_dim": 2, |
|     "history_size": 3, |
|     "embed_dim": 192, |
|     "lambd": 0.1,             # SIGReg loss weight; the only hyperparameter to tune |
|     "epochs": 10, |
|     "batch_size": 128,        # Reduce to 64 if you hit OOM |
|     "lr": 1e-3, |
|     "weight_decay": 0.05, |
|     "grad_clip": 1.0, |
|     "num_workers": 2, |
|     "log_interval": 25, |
|     "output_dir": OUTPUT_DIR, |
| } |
| ``` |
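
Before launching a long run, it can help to sanity-check the configuration. The sketch below is not part of the repository, and `check_config` is a hypothetical helper. It rests on two assumptions inferred from Step 5 (where `action_dim=2`, `frameskip=5`, and the action tensor has width 10): that each action input packs `frameskip` raw actions, and that a sequence must be longer than the context window so at least one future frame is left to predict.

```python
def check_config(cfg):
    """Hypothetical helper: validate a CONFIG-style dict and return derived sizes."""
    # Assumption: the predictor needs at least one frame beyond the context window.
    if cfg["seq_len"] <= cfg["history_size"]:
        raise ValueError("seq_len must be greater than history_size")
    if cfg["batch_size"] <= 0 or cfg["epochs"] <= 0:
        raise ValueError("batch_size and epochs must be positive")
    return {
        # Assumption: one action input packs `frameskip` raw actions (2 * 5 = 10 in Step 5).
        "action_input_dim": cfg["action_dim"] * cfg["frameskip"],
        # Frames left over as prediction targets after the context window.
        "n_future_frames": cfg["seq_len"] - cfg["history_size"],
    }


example = {"seq_len": 4, "history_size": 3, "action_dim": 2,
           "frameskip": 5, "batch_size": 128, "epochs": 10}
derived = check_config(example)
print(derived)
```

In the notebook, calling `check_config(CONFIG)` right after defining `CONFIG` would surface a mismatched `seq_len` before any GPU time is spent.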
|
|
| --- |
|
|
| ## Step 7: Full Training Loop |
|
|
| Copy the full training cell from the notebook in the repository. Its key components are: |
|
|
| ```python |
| import torch |
| from torch.utils.data import DataLoader |
| from tqdm.notebook import tqdm |
| from transformers import get_cosine_schedule_with_warmup |
| 
| from lewm_model import SIGReg |
| from lewm_train import SyntheticPushTDataset  # or TrajectoryDataset for real data |
| 
| dataset = SyntheticPushTDataset( |
|     n_episodes=CONFIG["n_episodes"], |
|     seq_len=CONFIG["seq_len"], |
|     frameskip=CONFIG["frameskip"], |
|     img_size=CONFIG["img_size"], |
| ) |
| 
| loader = DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=True, |
|                     num_workers=CONFIG["num_workers"], drop_last=True, pin_memory=True) |
| 
| # `model` is the instance built in Step 5 |
| optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"], |
|                               weight_decay=CONFIG["weight_decay"], betas=(0.9, 0.95)) |
| 
| # Cosine decay with linear warmup over the first 5% of optimizer steps |
| scheduler = get_cosine_schedule_with_warmup( |
|     optimizer, num_warmup_steps=int(0.05 * len(loader) * CONFIG["epochs"]), |
|     num_training_steps=len(loader) * CONFIG["epochs"]) |
| 
| sigreg = SIGReg(knots=17, num_proj=1024).cuda() |
| 
| for epoch in range(CONFIG["epochs"]): |
|     model.train() |
|     for obs, acts in tqdm(loader, desc=f"Epoch {epoch+1}"): |
|         obs, acts = obs.cuda(), acts.cuda() |
| 
|         emb = model.encode(obs) |
|         act_emb = model.action_encoder(acts) |
| 
|         # Use the first `history_size` steps as context for the predictor |
|         ctx_emb = emb[:, :CONFIG["history_size"]] |
|         ctx_act = act_emb[:, :CONFIG["history_size"]] |
|         pred_emb = model.predict(ctx_emb, ctx_act) |
| 
|         # Compare each prediction with the embedding of the following frame |
|         pred_loss = (pred_emb[:, :-1] - emb[:, 1:CONFIG["history_size"]]).pow(2).mean() |
|         sigreg_loss = sigreg(emb.transpose(0, 1)) |
|         loss = pred_loss + CONFIG["lambd"] * sigreg_loss |
| 
|         optimizer.zero_grad() |
|         loss.backward() |
|         torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["grad_clip"]) |
|         optimizer.step() |
|         scheduler.step() |
| 
|     # Save a checkpoint at the end of every epoch |
|     torch.save(model.state_dict(), f"{OUTPUT_DIR}/epoch_{epoch+1}.pt") |
|     print(f"Epoch {epoch+1} complete. Saved checkpoint.") |
| ``` |
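
For intuition about the learning-rate schedule used in the training cell, the sketch below reimplements in plain Python the multiplier that `get_cosine_schedule_with_warmup` applies with its default half-cosine cycle: a linear ramp from 0 to 1 during warmup, then a cosine decay towards 0. The name `cosine_warmup_factor` is hypothetical, and the step counts in the example are illustrative rather than taken from an actual run.

```python
import math


def cosine_warmup_factor(step, warmup_steps, total_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Linear warmup from 0 to 1.
        return step / max(1, warmup_steps)
    # Fraction of the post-warmup phase that has elapsed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # Half-cosine decay from 1 down to 0.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr = 1e-3                # CONFIG["lr"]
total = 1000                  # illustrative stand-in for len(loader) * epochs
warmup = int(0.05 * total)    # 5% warmup, as in the training cell
for s in (0, warmup // 2, warmup, total // 2, total):
    print(s, base_lr * cosine_warmup_factor(s, warmup, total))
```

The learning rate peaks at `CONFIG["lr"]` exactly when warmup ends, so lowering `lr` scales the whole curve rather than only its peak.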
|
|
| --- |
|
|
| ## Step 8: Using Real PushT Dataset (Optional) |
|
|
| **Warning**: The real dataset is **12 GB compressed** and larger once decompressed. Colab provides only about 78 GB of disk, so it may fill up. |
|
|
| ```python |
| # Download and decompress (takes ~5-10 minutes) |
| !pip install -q zstandard h5py |
| import os |
| import zstandard |
| from huggingface_hub import hf_hub_download |
| 
| zst_path = hf_hub_download( |
|     repo_id='quentinll/lewm-pusht', |
|     filename='pusht_expert_train.h5.zst', |
|     repo_type='dataset', |
|     local_dir='/content/data' |
| ) |
| 
| h5_path = '/content/data/pusht_expert_train.h5' |
| dctx = zstandard.ZstdDecompressor() |
| # Stream the archive to disk so the whole file never has to fit in memory |
| with open(zst_path, 'rb') as inf, open(h5_path, 'wb') as outf: |
|     dctx.copy_stream(inf, outf) |
| 
| print(f"Ready: {h5_path} ({os.path.getsize(h5_path)/1e9:.1f} GB)") |
| 
| # Then use TrajectoryDataset instead of SyntheticPushTDataset |
| ``` |
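
Since the archive and its decompressed copy both sit on disk at the same time, it may be worth checking free space before starting the download. The sketch below uses only the standard library. `has_free_space` is a hypothetical helper and the 40 GB threshold is an illustrative assumption, because the guide does not state the decompressed size.

```python
import shutil


def has_free_space(path, required_gb):
    """Return True if the filesystem holding `path` has at least `required_gb` GB free."""
    free_gb = shutil.disk_usage(path).free / 1e9
    print(f"Free space at {path}: {free_gb:.1f} GB (need {required_gb} GB)")
    return free_gb >= required_gb


# Assumed headroom for the archive plus the decompressed file; adjust as needed.
# On Colab, pass "/content" instead of "." to check the notebook's working disk.
if not has_free_space(".", required_gb=40):
    print("Not enough disk space; consider staying on synthetic data.")
```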
|
|
| --- |
|
|
| ## Step 9: Push to Hugging Face Hub |
|
|
| ```python |
| from huggingface_hub import HfApi, login |
| from getpass import getpass |
| |
| hf_token = getpass("Enter HF token: ") |
| login(token=hf_token) |
| |
| api = HfApi() |
| repo_id = "YOUR_USERNAME/lewm-colab-pusht" |
| api.create_repo(repo_id, repo_type="model", exist_ok=True) |
| |
| # Upload the final checkpoint written by the training loop |
| api.upload_file( |
|     path_or_fileobj=f"{OUTPUT_DIR}/epoch_{CONFIG['epochs']}.pt", |
|     path_in_repo="model.pt", |
|     repo_id=repo_id, |
|     repo_type="model", |
| ) |
| print(f"Pushed to https://huggingface.co/{repo_id}") |
| ``` |
|
|
| --- |
|
|
| ## Quick Tips |
|
|
| | Issue | Fix | |
| |-------|-----| |
| | **OOM** | Reduce `batch_size` to 64 or 32 | |
| | **Slow** | Reduce `n_episodes` to 500; use `img_size=112` | |
| | **Colab disconnects** | Save every epoch (already done); shorten epochs by reducing `n_episodes` | |
| | **Want better planning** | Train on real data for better generalization | |
| | **SIGReg loss stays high** | Increase `lambd` to 0.2 or 0.5 | |
| | **Prediction loss diverges** | Reduce `lr` to 5e-4; check data normalization | |
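
If Colab disconnects mid-run, training can pick up from the most recent file in `OUTPUT_DIR`. This is a minimal sketch, assuming the `epoch_{N}.pt` naming used in Step 7. `latest_checkpoint` is a hypothetical helper, and the commented lines show how it might be combined with `model.load_state_dict`. The Step 7 checkpoints hold only model weights, so the optimizer and scheduler state would restart from scratch.

```python
import re
from pathlib import Path


def latest_checkpoint(output_dir):
    """Return (epoch, path) for the highest-numbered epoch_N.pt, or None if none exist."""
    pattern = re.compile(r"epoch_(\d+)\.pt")
    found = []
    for path in Path(output_dir).glob("epoch_*.pt"):
        match = pattern.fullmatch(path.name)
        if match:
            found.append((int(match.group(1)), path))
    # Compare on the epoch number so that epoch_10 ranks above epoch_9.
    return max(found, key=lambda item: item[0]) if found else None


# Usage in the notebook (assumes `model`, `torch`, CONFIG, and OUTPUT_DIR from earlier steps):
# ckpt = latest_checkpoint(OUTPUT_DIR)
# start_epoch = 0
# if ckpt is not None:
#     start_epoch, path = ckpt
#     model.load_state_dict(torch.load(path, map_location="cuda"))
# for epoch in range(start_epoch, CONFIG["epochs"]):
#     ...
```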
|
|
| --- |
|
|
| ## Expected Results |
|
|
| With 2000 synthetic episodes and 10 epochs: |
|
|
| | Metric | Typical Value | |
| |--------|--------------| |
| | Final pred_loss | 0.05–0.15 | |
| | Final sigreg_loss | 0.3–0.6 | |
| | Training time | 30–60 min | |
| | GPU memory used | ~8–10 GB | |
|
|
| --- |
|
|
| ## Resources |
|
|
| - **Paper**: [arXiv:2603.19312](https://arxiv.org/abs/2603.19312) |
| - **This implementation**: https://huggingface.co/ar27111994/lewm-implementation |
| - **Interactive demo**: https://huggingface.co/spaces/ar27111994/lewm-explainable |
| - **Official repo**: https://github.com/lucas-maes/le-wm |
| - **Official datasets**: `quentinll/lewm-pusht`, `quentinll/lewm-cube` |
|
|
| --- |
|
|
| *Happy training!* |
|
|