
🚀 Free GPU Training Guide for LeWorldModel on Google Colab

This guide walks you through training LeWorldModel (~18M parameters) on a free Google Colab T4 GPU (15 GB VRAM), with no paid credits required.

Time estimate: 30–60 minutes for 10 epochs on synthetic data.


Step 1: Open Colab & Enable GPU

  1. Go to colab.research.google.com
  2. Create a new notebook
  3. Runtime → Change runtime type → Hardware accelerator: GPU (T4)
  4. Verify GPU:
!nvidia-smi
import torch
print(torch.cuda.get_device_name(0))  # Should print Tesla T4

Step 2: Install Dependencies

%%capture
!pip install -q transformers einops huggingface_hub matplotlib numpy tqdm

Step 3: Mount Google Drive (for checkpoints)

from google.colab import drive
drive.mount('/content/drive')

import os
OUTPUT_DIR = "/content/drive/MyDrive/lewm_training"
os.makedirs(OUTPUT_DIR, exist_ok=True)

Step 4: Download Implementation

!pip install -q huggingface_hub
from huggingface_hub import hf_hub_download

REPO_ID = "ar27111994/lewm-implementation"

for fname in ["lewm_model.py", "lewm_train.py"]:
    hf_hub_download(repo_id=REPO_ID, filename=fname, local_dir="/content")
    print(f"✓ {fname}")

Step 5: Sanity Check

import sys
sys.path.insert(0, "/content")

from lewm_model import build_lewm
import torch

device = torch.device("cuda")
model = build_lewm(action_dim=2, frameskip=5, history_size=3).to(device)
print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

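# Quick forward pass on random inputs
B, T = 2, 4
out = model(
    torch.randn(B, T, 3, 224, 224, device=device),  # observations: (batch, time, channels, height, width)
    torch.randn(B, T, 10, device=device),           # actions: action_dim * frameskip = 2 * 5 = 10 values per step
    history_size=3,
)
print(f"loss={out['loss'].item():.4f} pred={out['pred_loss'].item():.4f} sigreg={out['sigreg_loss'].item():.4f}")
print("✅ Ready!")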
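The random action tensor in the sanity check has a last dimension of 10, which matches action_dim × frameskip (2 × 5). Below is a minimal sketch of that packing. It assumes the raw per-step actions inside each frameskip window are concatenated in time order into one vector per model step; the repository does not confirm this. pack_actions is a hypothetical helper, not part of lewm_model.py:

```python
from typing import List, Sequence


def pack_actions(raw: Sequence[Sequence[float]], frameskip: int) -> List[List[float]]:
    """Group per-environment-step actions into one flat vector per model step.

    raw: T * frameskip actions, each of length action_dim.
    Returns T vectors, each of length action_dim * frameskip.
    Assumption: actions inside a window are concatenated in time order.
    """
    if frameskip <= 0 or len(raw) % frameskip != 0:
        raise ValueError("len(raw) must be a multiple of a positive frameskip")
    packed: List[List[float]] = []
    for start in range(0, len(raw), frameskip):
        window = raw[start:start + frameskip]
        # Flatten frameskip consecutive (action_dim,) actions into one vector.
        packed.append([float(x) for action in window for x in action])
    return packed


# 4 model steps (T in the sanity check) x frameskip 5, each action 2-D
raw = [[float(i), float(-i)] for i in range(4 * 5)]
packed = pack_actions(raw, frameskip=5)
print(len(packed), len(packed[0]))  # 4 10
```

The dataset code might order the actions differently, for example grouped by dimension. The shape would stay (T, 10) but the layout would change, so check lewm_train.py before relying on this ordering.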

Step 6: Training Configuration

CONFIG = {
    "use_synthetic": True,      # True = no 12 GB download needed
    "n_episodes": 2000,         # Synthetic episodes (use 500 for a quick test)
    "seq_len": 4,
    "frameskip": 5,
    "img_size": 224,
    "action_dim": 2,
    "history_size": 3,
    "embed_dim": 192,
    "lambd": 0.1,               # SIGReg loss weight: the only loss hyperparameter to tune
    "epochs": 10,
    "batch_size": 128,          # Reduce to 64 if OOM
    "lr": 1e-3,
    "weight_decay": 0.05,
    "grad_clip": 1.0,
    "num_workers": 2,
    "log_interval": 25,
    "output_dir": OUTPUT_DIR,
}
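The lr and epochs entries feed the cosine schedule that Step 7 builds with transformers' get_cosine_schedule_with_warmup, which warms up linearly over the first 5% of steps. The pure-Python sketch below lets you check the learning rate the optimizer sees at a given step. It is written to follow that scheduler's formula at its default num_cycles=0.5. steps_per_epoch is a made-up value; in the notebook the real one is len(loader), which depends on the dataset's length.

```python
import math


def cosine_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear warmup from 0, then half-cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Hypothetical numbers: substitute len(loader) for steps_per_epoch.
steps_per_epoch = 100
epochs = 10
total = steps_per_epoch * epochs   # num_training_steps in Step 7
warmup = int(0.05 * total)         # num_warmup_steps in Step 7
for s in (0, warmup // 2, warmup, (warmup + total) // 2, total):
    print(s, f"{cosine_warmup_lr(s, 1e-3, warmup, total):.2e}")
```

The peak rate is reached exactly at the end of warmup. Halfway through the decay phase the rate is half of lr, and it reaches 0 on the final step.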

Step 7: Full Training Loop

Copy the full training cell from the notebook in the repository. Its key components are:

from lewm_train import SyntheticPushTDataset  # or TrajectoryDataset for real data
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from lewm_model import SIGReg
from tqdm.notebook import tqdm

dataset = SyntheticPushTDataset(
    n_episodes=CONFIG["n_episodes"],
    seq_len=CONFIG["seq_len"],
    frameskip=CONFIG["frameskip"],
    img_size=CONFIG["img_size"],
)

loader = DataLoader(dataset, batch_size=CONFIG["batch_size"], shuffle=True,
                    num_workers=CONFIG["num_workers"], drop_last=True, pin_memory=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"],
                               weight_decay=CONFIG["weight_decay"], betas=(0.9, 0.95))

scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=int(0.05 * len(loader) * CONFIG["epochs"]),
    num_training_steps=len(loader) * CONFIG["epochs"])

sigreg = SIGReg(knots=17, num_proj=1024).cuda()

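for epoch in range(CONFIG["epochs"]):
    model.train()
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}")
    for step, (obs, acts) in enumerate(pbar):
        obs, acts = obs.cuda(non_blocking=True), acts.cuda(non_blocking=True)

        emb = model.encode(obs)               # (B, T, D) frame embeddings
        act_emb = model.action_encoder(acts)  # (B, T, D) action embeddings

        H = CONFIG["history_size"]
        ctx_emb = emb[:, :H]
        ctx_act = act_emb[:, :H]
        pred_emb = model.predict(ctx_emb, ctx_act)  # (B, H, D): output t predicts frame t + 1

        # Next-step loss over every context position: pred_emb[:, t] vs emb[:, t + 1]
        pred_loss = (pred_emb - emb[:, 1:H + 1]).pow(2).mean()
        sigreg_loss = sigreg(emb.transpose(0, 1))  # SIGReg on (T, B, D)
        loss = pred_loss + CONFIG["lambd"] * sigreg_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["grad_clip"])
        optimizer.step()
        scheduler.step()

        if step % CONFIG["log_interval"] == 0:
            pbar.set_postfix(loss=f"{loss.item():.4f}", pred=f"{pred_loss.item():.4f}",
                             sigreg=f"{sigreg_loss.item():.4f}")

    # Save a checkpoint every epoch so a Colab disconnect loses at most one epoch
    torch.save(model.state_dict(), f"{OUTPUT_DIR}/epoch_{epoch+1}.pt")
    print(f"Epoch {epoch+1} complete. Saved checkpoint.")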
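The prediction loss is teacher-forced next-step prediction: every context position produces a guess for the frame that follows it. Below is a minimal sketch of the index bookkeeping. It assumes, as the slicing in the loop implies, that model.predict returns one output per context frame and that output t targets frame t + 1. next_step_pairs is a hypothetical helper:

```python
from typing import List, Tuple


def next_step_pairs(history_size: int, seq_len: int) -> List[Tuple[int, int]]:
    """(prediction index, target frame index) pairs for next-step prediction."""
    # Context frames are 0 .. history_size - 1; the last one needs frame history_size as its target.
    if seq_len < history_size + 1:
        raise ValueError("seq_len must be at least history_size + 1")
    return [(t, t + 1) for t in range(history_size)]


pairs = next_step_pairs(history_size=3, seq_len=4)
print(pairs)  # [(0, 1), (1, 2), (2, 3)]
```

In tensor terms, these pairs mean pred_emb is compared with emb[:, 1:history_size + 1]. That slice needs seq_len ≥ history_size + 1, which holds for the CONFIG above (4 ≥ 3 + 1).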

Step 8: Using Real PushT Dataset (Optional)

Warning: The real dataset is ~12 GB compressed and larger once decompressed. Both copies sit on disk until the archive is deleted, so Colab's ~78 GB disk can fill up.
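Before starting the download, it can help to confirm the disk has room for both the archive and the decompressed file. Below is a minimal sketch using Python's shutil.disk_usage. The helper names are hypothetical. The decompressed size is a placeholder guess, since the guide only states the compressed size.

```python
import os
import shutil


def free_disk_gb(path: str = ".") -> float:
    """Free space, in GB, on the filesystem that holds `path`."""
    return shutil.disk_usage(path).free / 1e9


def has_room_for(needed_gb: float, path: str = ".") -> bool:
    """True if at least `needed_gb` GB are free at `path`."""
    return free_disk_gb(path) >= needed_gb


COMPRESSED_GB = 12          # size of the .zst archive, per the warning above
DECOMPRESSED_GB_GUESS = 24  # hypothetical placeholder; replace once the .h5 size is known
needed = COMPRESSED_GB + DECOMPRESSED_GB_GUESS
here = os.getcwd()          # /content in a Colab session
print(f"{free_disk_gb(here):.1f} GB free at {here}; room for ~{needed} GB: {has_room_for(needed, here)}")
```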

# Download & decompress (takes ~5-10 minutes)
!pip install -q zstandard h5py
from huggingface_hub import hf_hub_download
import zstandard, os

zst_path = hf_hub_download(
    repo_id='quentinll/lewm-pusht',
    filename='pusht_expert_train.h5.zst',
    repo_type='dataset',
    local_dir='/content/data'
)

h5_path = '/content/data/pusht_expert_train.h5'
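dctx = zstandard.ZstdDecompressor()
with open(zst_path, 'rb') as inf, open(h5_path, 'wb') as outf:
    dctx.copy_stream(inf, outf)  # streams in chunks, so the file never has to fit in RAM

os.remove(zst_path)  # optional: delete the ~12 GB archive to free disk space
print(f"Ready: {h5_path} ({os.path.getsize(h5_path)/1e9:.1f} GB)")

Then use TrajectoryDataset instead of SyntheticPushTDataset in Step 7.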

Step 9: Push to Hugging Face Hub

from huggingface_hub import HfApi, login
from getpass import getpass

hf_token = getpass("Enter HF token: ")
login(token=hf_token)

api = HfApi()
repo_id = "YOUR_USERNAME/lewm-colab-pusht"
api.create_repo(repo_id, repo_type="model", exist_ok=True)

api.upload_file(
    path_or_fileobj=f"{OUTPUT_DIR}/epoch_{CONFIG['epochs']}.pt",  # final checkpoint from Step 7
    path_in_repo="model.pt",
    repo_id=repo_id,
    repo_type="model",
)
print(f"✅ Pushed to https://huggingface.co/{repo_id}")

⚑ Quick Tips

| Issue | Fix |
|---|---|
| Out of memory (OOM) | Reduce batch_size to 64 or 32 |
| Slow training | Reduce n_episodes to 500; use img_size=112 |
| Colab disconnects | Save every epoch (already done); use shorter epochs |
| Want faster planning | Train on real data for better generalization |
| SIGReg loss too high | Increase lambd to 0.2 or 0.5 |
| Prediction loss diverges | Reduce lr to 5e-4; check data normalization |
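The Step 7 loop writes epoch_N.pt after every epoch, so a disconnected session can resume from the newest file in OUTPUT_DIR. Below is a minimal sketch that assumes that naming scheme. latest_checkpoint is a hypothetical helper, and the commented resume lines are a sketch rather than code from the repository.

```python
import os
import re
import tempfile
from typing import Optional, Tuple

# Matches the file names written by the training loop: epoch_1.pt, epoch_2.pt, ...
CKPT_PATTERN = re.compile(r"^epoch_(\d+)\.pt$")


def latest_checkpoint(output_dir: str) -> Optional[Tuple[int, str]]:
    """Return (epoch, path) for the highest-numbered epoch_N.pt in output_dir, or None."""
    if not os.path.isdir(output_dir):
        return None
    best: Optional[Tuple[int, str]] = None
    for name in os.listdir(output_dir):
        match = CKPT_PATTERN.match(name)
        if match is None:
            continue
        epoch = int(match.group(1))  # numeric compare, so epoch_10 beats epoch_9
        if best is None or epoch > best[0]:
            best = (epoch, os.path.join(output_dir, name))
    return best


# Demo on a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    for n in (1, 2, 9, 10):
        open(os.path.join(tmp, f"epoch_{n}.pt"), "wb").close()
    print(latest_checkpoint(tmp)[0])  # 10

# In the notebook (needs torch plus model, CONFIG and OUTPUT_DIR from earlier steps):
# found = latest_checkpoint(OUTPUT_DIR)
# start_epoch = 0
# if found is not None:
#     start_epoch, path = found
#     model.load_state_dict(torch.load(path, map_location="cuda"))
# for epoch in range(start_epoch, CONFIG["epochs"]):
#     ...
```

Only model.state_dict() is saved each epoch. On resume, the optimizer state and the learning-rate schedule (including warmup) therefore start over. Saving optimizer.state_dict() and scheduler.state_dict() alongside the weights would avoid that.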

📊 Expected Results

With 2000 synthetic episodes, 10 epochs:

| Metric | Typical value |
|---|---|
| Final pred_loss | 0.05–0.15 |
| Final sigreg_loss | 0.3–0.6 |
| Training time | 30–60 min |
| GPU memory used | ~8–10 GB |


Happy training! 🚀