"""Fine-tune SmolVLA for ElRobot using LBST/t01_pick_and_place as base. ElRobot has 8 joints (7 DOF + gripper). Uses 1 camera to match LBST base. Run on Google Colab (T4 GPU): uv sync uv run python scripts/train_elrobot.py \ --steps 5000 --batch-size 32 \ --parquets datasets/*.parquet \ --base-checkpoint LBST/t01_pick_and_place \ --output checkpoints/elrobot-run """ from __future__ import annotations import argparse import math import time from functools import partial from pathlib import Path import torch from torch.utils.data import DataLoader from smolvla import SmolVLAPolicy from smolvla.dataset import PickAndPlaceDataset, collate_samples from smolvla.normalize import normalize_action, normalize_state from smolvla.stats import compute_stats, save_stats ROOT = Path(__file__).resolve().parent.parent DATASETS = ROOT / "datasets" CKPT_DIR = ROOT / "checkpoints" IMAGE_KEYS = ("observation.images.cam0",) STATE_DIM = 8 ACTION_DIM = 8 DEFAULT_PARQUETS = list(DATASETS.glob("*.parquet")) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--steps", type=int, default=5000) p.add_argument("--batch-size", type=int, default=32) p.add_argument("--lr", type=float, default=1e-4) p.add_argument("--weight-decay", type=float, default=1e-10) p.add_argument("--warmup-steps", type=int, default=500, help="Nominal warmup steps; auto-scaled if --steps < --decay-steps.") p.add_argument("--decay-steps", type=int, default=10000, help="Nominal cosine decay length; auto-scaled if --steps is shorter.") p.add_argument("--decay-lr", type=float, default=2.5e-6, help="LR floor at end of cosine decay.") p.add_argument("--grad-clip", type=float, default=10.0) p.add_argument("--save-every", type=int, default=1000) p.add_argument("--log-every", type=int, default=20) p.add_argument("--num-workers", type=int, default=2) p.add_argument("--seed", type=int, default=0) p.add_argument( "--parquets", type=Path, nargs="+", default=DEFAULT_PARQUETS, help="Parquet files to merge.", ) p.add_argument("--base-checkpoint", type=str, default="LBST/t01_pick_and_place") p.add_argument("--output", type=Path, default=CKPT_DIR / "elrobot-run") return p.parse_args() def make_loader(ds, batch_size: int, shuffle: bool, num_workers: int) -> DataLoader: return DataLoader( ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=partial(collate_samples, image_keys=IMAGE_KEYS), pin_memory=True, drop_last=shuffle, persistent_workers=num_workers > 0, ) def move_batch(batch: dict, device: torch.device) -> dict: return { k: (v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v) for k, v in batch.items() } def build_train_batch(batch: dict, policy, stats, device) -> dict: batch = move_batch(batch, device) batch["observation.state"] = normalize_state(batch["observation.state"], stats) batch["action"] = normalize_action(batch["action"], stats) tokens, mask = policy.tokenize_task(batch.pop("task"), device=device) batch["observation.language.tokens"] = tokens batch["observation.language.attention_mask"] = mask return batch def lr_lambda( step: int, nominal_warmup: int, nominal_decay: int, total_steps: int, decay_lr: float, peak_lr: float, ) -> float: if total_steps < nominal_decay: scale = total_steps / nominal_decay warmup = max(int(nominal_warmup * scale), 1) decay = total_steps else: warmup = max(nominal_warmup, 1) decay = nominal_decay if step < warmup: return (step + 1) / warmup s = min(step - warmup, decay - warmup) cos = 0.5 * (1 + math.cos(math.pi * s / max(decay - warmup, 1))) alpha = decay_lr / peak_lr return (1 - alpha) * cos + alpha def main() -> None: args = parse_args() torch.manual_seed(args.seed) if not torch.cuda.is_available(): raise SystemExit("CUDA required. Run this on Google Colab with a T4 GPU.") device = torch.device("cuda") args.output.mkdir(parents=True, exist_ok=True) parquets = args.parquets if not parquets: parquets = sorted(DATASETS.glob("*.parquet")) if not parquets: raise SystemExit(f"No parquet files found. Pass --parquets or put them in {DATASETS}") print(f"Loading {len(parquets)} parquet files...") train_ds = PickAndPlaceDataset(parquets, image_keys=IMAGE_KEYS) print(f"episodes: {train_ds.total_episodes} frames: {len(train_ds)}") stats = compute_stats(train_ds._state, train_ds._action) stats_path = args.output / "stats.safetensors" save_stats(stats, stats_path) print(f"\nstats saved -> {stats_path}") for k, v in stats.items(): print(f" {k:12s} {v.tolist()}") stats = {k: v.to(device) for k, v in stats.items()} print(f"\nLoading base checkpoint {args.base_checkpoint} ...") policy = SmolVLAPolicy.from_pretrained( args.base_checkpoint, config_overrides={ "image_keys": list(IMAGE_KEYS), "state_dim": STATE_DIM, "action_dim": ACTION_DIM, "load_vlm_weights": False, "empty_cameras": 0, }, strict=False, ).to(device) policy.train() trainable = [p for p in policy.parameters() if p.requires_grad] n_train = sum(p.numel() for p in trainable) n_total = sum(p.numel() for p in policy.parameters()) print(f"trainable params: {n_train:,} / {n_total:,} ({100*n_train/n_total:.1f}%)") opt = torch.optim.AdamW( trainable, lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay ) sched = torch.optim.lr_scheduler.LambdaLR( opt, lr_lambda=partial( lr_lambda, nominal_warmup=args.warmup_steps, nominal_decay=args.decay_steps, total_steps=args.steps, decay_lr=args.decay_lr, peak_lr=args.lr, ) ) train_loader = make_loader(train_ds, args.batch_size, shuffle=True, num_workers=args.num_workers) print(f"train batches/epoch: {len(train_loader)}") step = 0 t0 = time.time() running_loss = 0.0 running_n = 0 while step < args.steps: for raw_batch in train_loader: if step >= args.steps: break batch = build_train_batch(raw_batch, policy, stats, device) loss, info = policy.forward(batch) opt.zero_grad(set_to_none=True) loss.backward() torch.nn.utils.clip_grad_norm_(trainable, args.grad_clip) opt.step() sched.step() running_loss += info["loss"] running_n += 1 if (step + 1) % args.log_every == 0: dt = time.time() - t0 avg = running_loss / running_n lr = sched.get_last_lr()[0] print(f"step {step+1:>6}/{args.steps} loss {avg:.4f} lr {lr:.2e} " f"({(step+1) / max(dt, 1e-9):.2f} step/s)") running_loss, running_n = 0.0, 0 if (step + 1) % args.save_every == 0: out = args.output / f"step-{step+1:06d}" policy.save_pretrained(out) (out / "stats.safetensors").write_bytes(stats_path.read_bytes()) print(f" [save] {out}") step += 1 final = args.output / "final" policy.save_pretrained(final) (final / "stats.safetensors").write_bytes(stats_path.read_bytes()) print(f"\ndone — final checkpoint at {final}") if __name__ == "__main__": main()