# 🚀 Free GPU Training Guide for LeWorldModel on Google Colab

This guide walks you through training **LeWorldModel (~18M parameters)** on a **free Google Colab T4 GPU** (15 GB VRAM). No paid credits are required.

**Time estimate**: 30–60 minutes for 10 epochs on synthetic data.

---

## Step 1: Open Colab & Enable GPU

1. Go to [colab.research.google.com](https://colab.research.google.com)
2. Create a new notebook
3. **Runtime → Change runtime type → Hardware accelerator: GPU** (T4)
4. Verify the GPU:

```python
!nvidia-smi
import torch
print(torch.cuda.get_device_name(0))  # Should print "Tesla T4"
```

---

## Step 2: Install Dependencies

PyTorch is preinstalled on Colab. Install the remaining packages:

```python
%%capture
!pip install -q transformers einops huggingface_hub matplotlib numpy tqdm
```

---

## Step 3: Mount Google Drive (for checkpoints)

Checkpoints written to Drive survive a runtime disconnect.

```python
from google.colab import drive
drive.mount('/content/drive')

import os
OUTPUT_DIR = "/content/drive/MyDrive/lewm_training"
os.makedirs(OUTPUT_DIR, exist_ok=True)
```

---

## Step 4: Download Implementation

```python
# huggingface_hub was already installed in Step 2
from huggingface_hub import hf_hub_download

REPO_ID = "ar27111994/lewm-implementation"
for fname in ["lewm_model.py", "lewm_train.py"]:
    hf_hub_download(repo_id=REPO_ID, filename=fname, local_dir="/content")
    print(f"✓ {fname}")
```

---

## Step 5: Sanity Check

```python
import sys
sys.path.insert(0, "/content")  # make the downloaded modules importable

import torch
from lewm_model import build_lewm

device = torch.device("cuda")
model = build_lewm(action_dim=2, frameskip=5, history_size=3).to(device)
print(f"Params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

# Quick forward pass on random inputs
B, T = 2, 4
obs = torch.randn(B, T, 3, 224, 224, device=device)  # (batch, time, C, H, W)
acts = torch.randn(B, T, 10, device=device)           # 10-D actions (presumably action_dim * frameskip = 2 * 5)
out = model(obs, acts, history_size=3)
print(f"loss={out['loss'].item():.4f} pred={out['pred_loss'].item():.4f} "
      f"sigreg={out['sigreg_loss'].item():.4f}")
print("✅ Ready!")
```

---

## Step 6: Training Configuration

```python
CONFIG = {
    "use_synthetic": True,     # True = no 12 GB download needed
    "n_episodes": 2000,        # Synthetic episodes (use 500 for a quick test)
    "seq_len": 4,              # Frames per sample: history_size context frames + 1 target
    "frameskip": 5,
    "img_size": 224,
    "action_dim": 2,
    "history_size": 3,
    "embed_dim": 192,
    "lambd": 0.1,              # SIGReg weight: the only loss hyperparameter to tune
    "epochs": 10,
    "batch_size": 128,         # Reduce to 64 if OOM
    "lr": 1e-3,
    "weight_decay": 0.05,
    "grad_clip": 1.0,
    "num_workers": 2,
    "log_interval": 25,        # Batches between progress-bar updates
    "output_dir": OUTPUT_DIR,
}
```

---

## Step 7: Full Training Loop

Copy the full training cell from the notebook in the implementation repository. The key components are:

```python
# Assumes `model` and `device` from Step 5 and `CONFIG` from Step 6
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import get_cosine_schedule_with_warmup

from lewm_model import SIGReg
from lewm_train import SyntheticPushTDataset  # or TrajectoryDataset for real data

H = CONFIG["history_size"]

dataset = SyntheticPushTDataset(
    n_episodes=CONFIG["n_episodes"],
    seq_len=CONFIG["seq_len"],
    frameskip=CONFIG["frameskip"],
    img_size=CONFIG["img_size"],
)
loader = DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=True,
                    num_workers=CONFIG["num_workers"], drop_last=True, pin_memory=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"],
                              weight_decay=CONFIG["weight_decay"], betas=(0.9, 0.95))
total_steps = len(loader) * CONFIG["epochs"]
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.05 * total_steps),  # 5% linear warmup
    num_training_steps=total_steps,
)
sigreg = SIGReg(knots=17, num_proj=1024).to(device)

for epoch in range(CONFIG["epochs"]):
    model.train()
    pbar = tqdm(loader, desc=f"Epoch {epoch + 1}")
    for step, (obs, acts) in enumerate(pbar):
        obs = obs.to(device, non_blocking=True)
        acts = acts.to(device, non_blocking=True)

        emb = model.encode(obs)                 # per-frame embeddings
        act_emb = model.action_encoder(acts)
        pred_emb = model.predict(emb[:, :H], act_emb[:, :H])  # one prediction per context frame

        # The prediction at context position t targets the embedding of frame t + 1
        pred_loss = (pred_emb - emb[:, 1:H + 1]).pow(2).mean()
        sigreg_loss = sigreg(emb.transpose(0, 1))  # SIGReg expects time-first embeddings
        loss = pred_loss + CONFIG["lambd"] * sigreg_loss

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["grad_clip"])
        optimizer.step()
        scheduler.step()

        if step % CONFIG["log_interval"] == 0:
            pbar.set_postfix(loss=f"{loss.item():.4f}", pred=f"{pred_loss.item():.4f}",
                             sigreg=f"{sigreg_loss.item():.4f}")

    # Save a checkpoint to Drive after every epoch
    torch.save(model.state_dict(), f"{OUTPUT_DIR}/epoch_{epoch + 1}.pt")
    print(f"Epoch {epoch + 1} complete. Saved checkpoint.")
```

---

## Step 8: Using Real PushT Dataset (Optional)

**Warning**: The real dataset is **12 GB compressed**, and the decompressed HDF5 file is larger still. Colab's disk is only ~78 GB and can fill up, so delete the archive once it has been decompressed.

```python
# Download & decompress (takes ~5–10 minutes)
!pip install -q zstandard h5py

import os
import zstandard
from huggingface_hub import hf_hub_download

zst_path = hf_hub_download(
    repo_id="quentinll/lewm-pusht",
    filename="pusht_expert_train.h5.zst",
    repo_type="dataset",
    local_dir="/content/data",
)

h5_path = "/content/data/pusht_expert_train.h5"
dctx = zstandard.ZstdDecompressor()
with open(zst_path, "rb") as inf, open(h5_path, "wb") as outf:
    dctx.copy_stream(inf, outf)  # streams, so the file is never held in RAM

os.remove(zst_path)  # free the ~12 GB archive once decompressed
print(f"Ready: {h5_path} ({os.path.getsize(h5_path) / 1e9:.1f} GB)")

# Then use TrajectoryDataset instead of SyntheticPushTDataset
```

---

## Step 9: Push to Hugging Face Hub

You need a Hugging Face access token with write access.

```python
from getpass import getpass
from huggingface_hub import HfApi, login

hf_token = getpass("Enter HF token: ")
login(token=hf_token)

api = HfApi()
repo_id = "YOUR_USERNAME/lewm-colab-pusht"
api.create_repo(repo_id, repo_type="model", exist_ok=True)
api.upload_file(
    path_or_fileobj=f"{OUTPUT_DIR}/epoch_{CONFIG['epochs']}.pt",  # final checkpoint
    path_in_repo="model.pt",
    repo_id=repo_id,
    repo_type="model",
)
print(f"✅ Pushed to https://huggingface.co/{repo_id}")
```

---

## ⚡ Quick Tips

| Issue | Fix |
|-------|-----|
| **OOM** | Reduce `batch_size` to 64 or 32 |
| **Slow training** | Reduce `n_episodes` to 500; use `img_size=112` (make sure the model is configured for the same resolution) |
| **Colab disconnects** | Checkpoints are already saved every epoch, so resume from the latest one; use shorter epochs (fewer `n_episodes`) |
| **Weak planning / generalization** | Train on the real PushT data |
| **`sigreg_loss` too high** | Increase `lambd` to 0.2 or 0.5 |
| **Prediction loss diverges** | Reduce `lr` to 5e-4; check data normalization |

---

## 📊 Expected Results

With 2000 synthetic episodes and 10 epochs:

| Metric | Typical value |
|--------|---------------|
| Final `pred_loss` | 0.05–0.15 |
| Final `sigreg_loss` | 0.3–0.6 |
| Training time | 30–60 min |
| GPU memory used | ~8–10 GB |

---

## 🔗 Resources

- **Paper**: [arXiv:2603.19312](https://arxiv.org/abs/2603.19312)
- **This implementation**: https://huggingface.co/ar27111994/lewm-implementation
- **Interactive demo**: https://huggingface.co/spaces/ar27111994/lewm-explainable
- **Official repo**: https://github.com/lucas-maes/le-wm
- **Official datasets**: `quentinll/lewm-pusht`, `quentinll/lewm-cube`

---

*Happy training! 🚀*
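The Step 5 sanity check feeds 10-dimensional actions even though `action_dim=2`. One plausible reading is that the per-step actions of the `frameskip=5` skipped frames are concatenated, giving 2 × 5 = 10. That is an assumption here, not something this guide confirms about the implementation. The hypothetical `chunk_actions` helper below shows that grouping in plain Python.

```python
from typing import List, Sequence


def chunk_actions(actions: Sequence[Sequence[float]], frameskip: int) -> List[List[float]]:
    """Concatenate every `frameskip` consecutive per-step actions into one flat vector.

    Hypothetical helper: assumes the model consumes actions grouped over the
    skipped frames, so 2-D actions with frameskip=5 become 10-D vectors.
    """
    if frameskip <= 0:
        raise ValueError("frameskip must be positive")
    n_blocks = len(actions) // frameskip  # drop an incomplete trailing block
    blocks: List[List[float]] = []
    for b in range(n_blocks):
        flat: List[float] = []
        for step_action in actions[b * frameskip:(b + 1) * frameskip]:
            flat.extend(step_action)
        blocks.append(flat)
    return blocks


# 20 raw environment steps of 2-D actions -> 4 model steps of 10-D actions
raw = [[float(i), float(-i)] for i in range(20)]
blocks = chunk_actions(raw, frameskip=5)
print(len(blocks), len(blocks[0]))
```

If the real dataset groups actions differently, only the inner loop would need to change.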
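Step 7 uses `transformers.get_cosine_schedule_with_warmup` with 5% warmup. The pure-Python sketch below mirrors that schedule's default shape, a linear warmup followed by a half cosine down to zero, so you can check step counts before launching a run. `n_samples=8000` is a made-up dataset size rather than the real length of `SyntheticPushTDataset`; in the notebook, use `len(dataset)` instead.

```python
import math


def cosine_with_warmup(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: linear warmup from 0 to 1, then half-cosine decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


def schedule_lengths(n_samples: int, batch_size: int, epochs: int, warmup_frac: float = 0.05):
    """Optimizer steps per epoch, warmup steps, and total steps (drop_last=True)."""
    steps_per_epoch = n_samples // batch_size  # the partial last batch is dropped
    total = steps_per_epoch * epochs
    warmup = int(warmup_frac * total)
    return steps_per_epoch, warmup, total


# Hypothetical dataset size, with batch_size and epochs taken from CONFIG in Step 6
steps_per_epoch, warmup, total = schedule_lengths(n_samples=8000, batch_size=128, epochs=10)
base_lr = 1e-3
for s in (0, warmup, total // 2, total):
    print(f"step {s:4d}: lr = {base_lr * cosine_with_warmup(s, warmup, total):.2e}")
```

Halving `batch_size` after an OOM doubles `steps_per_epoch`. Because Step 7 derives both schedule lengths from `len(loader)`, the schedule rescales automatically.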
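The prediction loss in Step 7 is a next-frame objective: the predictor output at context position `t` is compared with the encoder embedding of frame `t + 1`. The small sketch below lists those (prediction, target-frame) index pairs. It assumes, as the training loop's slicing implies, that `model.predict` returns one embedding per context position. `next_frame_pairs` is a hypothetical illustration, not part of the implementation.

```python
from typing import List, Tuple


def next_frame_pairs(seq_len: int, history_size: int) -> List[Tuple[int, int]]:
    """(predictor position, target frame) pairs for next-frame prediction.

    Each of the `history_size` context positions predicts the following frame,
    so the sequence must contain at least history_size + 1 frames.
    """
    if history_size >= seq_len:
        raise ValueError("seq_len must exceed history_size by at least one frame")
    return [(t, t + 1) for t in range(history_size)]


print(next_frame_pairs(seq_len=4, history_size=3))
```

With the Step 6 settings (`seq_len=4`, `history_size=3`), every frame after the first serves as a target exactly once.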
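Per-batch losses are noisy, so comparing a single logged value against the "Expected Results" table can be misleading. The sketch below keeps a running mean over the last few batches using `collections.deque`. The `WindowedMean` class is hypothetical, and tying its window to `log_interval` (25) is an arbitrary choice.

```python
from collections import deque


class WindowedMean:
    """Mean of the most recent `window` values (older values are discarded)."""

    def __init__(self, window: int = 25):
        self.values = deque(maxlen=window)

    def update(self, value: float) -> float:
        self.values.append(float(value))
        return sum(self.values) / len(self.values)


# Tiny demo with a window of 3: the fourth update drops the first value
meter = WindowedMean(window=3)
smoothed = [meter.update(v) for v in (1, 2, 3, 4)]
print(smoothed)
```

Inside the Step 7 loop you could call `meter.update(pred_loss.item())` every batch and show the returned value in `pbar.set_postfix`.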
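Step 7 writes `epoch_N.pt` to Drive after every epoch, which is what makes recovery from a Colab disconnect possible. The hypothetical `latest_checkpoint` helper below finds the highest-numbered file. It compares epoch numbers as integers, because a plain string sort would place `epoch_10.pt` before `epoch_2.pt`.

```python
import os
import re
import tempfile
from typing import Optional, Tuple

CKPT_RE = re.compile(r"^epoch_(\d+)\.pt$")  # matches the names saved in Step 7


def latest_checkpoint(output_dir: str) -> Optional[Tuple[int, str]]:
    """Return (epoch, path) of the highest-numbered epoch_N.pt, or None if absent."""
    if not os.path.isdir(output_dir):
        return None
    best: Optional[Tuple[int, str]] = None
    for name in os.listdir(output_dir):
        match = CKPT_RE.match(name)
        if match:
            epoch = int(match.group(1))
            if best is None or epoch > best[0]:
                best = (epoch, os.path.join(output_dir, name))
    return best


# Demo in a throwaway directory (use OUTPUT_DIR in the notebook)
with tempfile.TemporaryDirectory() as d:
    for name in ("epoch_2.pt", "epoch_10.pt", "notes.txt"):
        open(os.path.join(d, name), "w").close()
    demo = latest_checkpoint(d)
    print(demo[0])
```

In the notebook you could then run `model.load_state_dict(torch.load(path))` and loop over `range(epoch, CONFIG["epochs"])`. Only model weights are saved, so the optimizer and scheduler restart from scratch. Saving their `state_dict()` objects alongside the model would make the resume exact.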
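Step 8's disk warning can be turned into a pre-flight check before you start the download. This sketch uses the standard library's `shutil.disk_usage`. The 10% safety margin is an arbitrary assumption, and this guide does not state the decompressed size of the HDF5 file, so you need to add that figure to `needed_bytes` yourself.

```python
import os
import shutil


def has_free_space(path: str, needed_bytes: int, margin: float = 1.1) -> bool:
    """True if the filesystem holding `path` has needed_bytes * margin free."""
    return shutil.disk_usage(path).free >= needed_bytes * margin


target = "/content" if os.path.isdir("/content") else "."
archive_bytes = 12 * 10**9  # the ~12 GB compressed archive only
print("Room for the archive:", has_free_space(target, archive_bytes))
```

The archive and the decompressed file coexist until the archive is deleted, so the real requirement is their sum.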