Spaces:
Sleeping
Sleeping
"""Minimal SmolVLA fine-tune loop for the norma-core pick-and-place parquets.

Merges all parquets, computes stats over the full set, and fine-tunes
`lerobot/smolvla_base`. Validation is done on device, not here.

Run:
    uv run python scripts/train.py --steps 5000 --batch-size 26
"""
| from __future__ import annotations | |
| import argparse | |
| import math | |
| import time | |
| from functools import partial | |
| from pathlib import Path | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from smolvla import SmolVLAPolicy | |
| from smolvla.dataset import PickAndPlaceDataset, collate_samples | |
| from smolvla.normalize import normalize_action, normalize_state | |
| from smolvla.stats import compute_stats, save_stats | |
# Filesystem layout, all derived from this script's resolved location.
ROOT = Path(__file__).resolve().parent.parent  # two levels above this file
DATASETS = ROOT.parent.joinpath("datasets")  # sibling directory of ROOT
CKPT_DIR = ROOT.joinpath("checkpoints")

# Batch keys of the two camera streams handed to the collate function and
# to the policy config.
IMAGE_KEYS = (
    "observation.images.cam0",
    "observation.images.cam1",
)

# State and action vector sizes passed to the policy as config overrides.
STATE_DIM = 6
ACTION_DIM = 6

# Parquet file(s) used when --parquets is not given on the command line.
DEFAULT_PARQUETS = [DATASETS.joinpath("dataset.parquet")]
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the fine-tuning script.

    Returns:
        The populated namespace. Dashes in flag names become underscores
        (``--batch-size`` is read as ``args.batch_size``).
    """
    parser = argparse.ArgumentParser()

    # (flag, type, default, help) for every plain scalar option, listed in
    # the order they are registered (and therefore shown by --help).
    scalar_options = (
        ("--steps", int, 5000, None),
        ("--batch-size", int, 8, None),
        ("--lr", float, 1e-4, None),
        ("--weight-decay", float, 1e-10, None),
        ("--warmup-steps", int, 1000,
         "Nominal warmup steps; auto-scaled if --steps < --decay-steps."),
        ("--decay-steps", int, 30000,
         "Nominal cosine decay length; auto-scaled if --steps is shorter."),
        ("--decay-lr", float, 2.5e-6, "LR floor at end of cosine decay."),
        ("--grad-clip", float, 10.0, None),
        ("--save-every", int, 1000, None),
        ("--log-every", int, 20, None),
        ("--num-workers", int, 2, None),
        ("--seed", int, 0, None),
    )
    for flag, value_type, default, help_text in scalar_options:
        # help=None is argparse's own default, so options without a help
        # string are registered exactly as if the keyword were omitted.
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    # Input data, starting weights and output location.
    parser.add_argument(
        "--parquets",
        type=Path,
        nargs="+",
        default=DEFAULT_PARQUETS,
        help="Parquet files to merge.",
    )
    parser.add_argument("--base-checkpoint", type=str, default="lerobot/smolvla_base")
    parser.add_argument("--output", type=Path, default=CKPT_DIR / "run")

    return parser.parse_args()
def make_loader(ds, batch_size: int, shuffle: bool, num_workers: int) -> DataLoader:
    """Wrap ``ds`` in a ``DataLoader`` configured for this script.

    Args:
        ds: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle; the same flag also decides whether a
            trailing incomplete batch is dropped.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``DataLoader``.
    """
    # Samples are collated by the project helper, bound to the camera keys.
    collate = partial(collate_samples, image_keys=IMAGE_KEYS)
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "collate_fn": collate,
        "pin_memory": True,
        # A shuffled loader drops the incomplete last batch; an unshuffled
        # one keeps every sample.
        "drop_last": shuffle,
        # Persistent workers are only requested when there are workers.
        "persistent_workers": num_workers > 0,
    }
    return DataLoader(ds, **loader_kwargs)
def move_batch(batch: dict, device: torch.device) -> dict:
    """Return a new dict with every tensor value of ``batch`` moved to ``device``.

    Non-tensor values are carried over unchanged. The input dict itself is
    not modified.
    """
    moved = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            # non_blocking is passed through to Tensor.to for every tensor.
            value = value.to(device, non_blocking=True)
        moved[key] = value
    return moved
def build_train_batch(batch: dict, policy, stats, device) -> dict:
    """Turn a collated batch into the dict passed to ``policy.forward``.

    The batch is moved to ``device``, the state and action entries are
    normalised with ``stats``, and the ``"task"`` entry is replaced by the
    language tokens and attention mask returned by ``policy.tokenize_task``.

    Args:
        batch: Collated batch; must contain ``"observation.state"``,
            ``"action"`` and ``"task"``.
        policy: Object providing ``tokenize_task(task, device=...)``.
        stats: Normalisation statistics forwarded to the normalise helpers.
        device: Target device for tensors and tokens.

    Returns:
        A new dict; ``move_batch`` builds a fresh dict, so the caller's
        ``batch`` is left unmodified.
    """
    prepared = move_batch(batch, device)

    normalizers = (
        ("observation.state", normalize_state),
        ("action", normalize_action),
    )
    for key, normalize in normalizers:
        prepared[key] = normalize(prepared[key], stats)

    # The raw task entry is removed and replaced by its tokenised form.
    task = prepared.pop("task")
    tokens, mask = policy.tokenize_task(task, device=device)
    prepared.update(
        {
            "observation.language.tokens": tokens,
            "observation.language.attention_mask": mask,
        }
    )
    return prepared
def lr_lambda(
    step: int,
    nominal_warmup: int,
    nominal_decay: int,
    total_steps: int,
    decay_lr: float,
    peak_lr: float,
) -> float:
    """Return the LR multiplier for ``step``: linear warmup, then cosine decay.

    The value is a multiplier on the optimizer's base LR (``peak_lr``), in
    the form expected by ``torch.optim.lr_scheduler.LambdaLR``.

    If the run is shorter than the nominal decay length
    (``total_steps < nominal_decay``), the warmup is shrunk by the same
    ratio and the decay is made to end at ``total_steps``.

    Args:
        step: Zero-based scheduler step.
        nominal_warmup: Warmup length for a full-length run.
        nominal_decay: Step at which the cosine decay ends in a full-length
            run (counted from step 0, so it includes the warmup).
        total_steps: Number of steps that will actually be run.
        decay_lr: Absolute LR reached at the end of the decay.
        peak_lr: Absolute LR reached at the end of the warmup.

    Returns:
        ``(step + 1) / warmup`` during warmup; afterwards a cosine
        interpolation from ``1.0`` down to ``decay_lr / peak_lr``, held at
        that floor once the decay window is over.
    """
    if total_steps < nominal_decay:
        scale = total_steps / nominal_decay
        warmup = max(int(nominal_warmup * scale), 1)
        decay = total_steps
    else:
        warmup = max(nominal_warmup, 1)
        decay = nominal_decay

    if step < warmup:
        return (step + 1) / warmup

    # Length of the cosine phase. Clamp at zero: if the warmup is at least
    # as long as the decay there is no room left to decay. Without the
    # clamp a negative span made `elapsed` negative while the divisor was
    # forced to 1, so the cosine was taken at an arbitrary multiple of pi
    # and the result jumped between peak and floor depending on parity.
    # With it, a degenerate window holds the peak, as the zero-length
    # window already did.
    span = max(decay - warmup, 0)
    elapsed = min(step - warmup, span)
    cos = 0.5 * (1 + math.cos(math.pi * elapsed / max(span, 1)))

    # Floor expressed as a fraction of the peak. A zero peak LR makes the
    # multiplier irrelevant (base LR is 0), so avoid dividing by it.
    alpha = decay_lr / peak_lr if peak_lr else 1.0
    return (1 - alpha) * cos + alpha
def _save_checkpoint(policy, out_dir: Path, stats_path: Path) -> None:
    """Save the policy to ``out_dir`` and copy the stats file next to it."""
    policy.save_pretrained(out_dir)
    (out_dir / "stats.safetensors").write_bytes(stats_path.read_bytes())


def main() -> None:
    """Fine-tune the base SmolVLA checkpoint on the merged parquet dataset.

    Steps: parse arguments, load the dataset, compute and save the
    normalisation stats, load the policy, then run ``--steps`` optimizer
    steps (AdamW + warmup/cosine schedule), logging every ``--log-every``
    steps and checkpointing every ``--save-every`` steps. A ``final``
    checkpoint is always written at the end. Passing ``0`` for
    ``--log-every`` or ``--save-every`` disables that periodic action.

    Raises:
        SystemExit: If CUDA is unavailable, or if the data loader would
            yield no batches at all.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    if not torch.cuda.is_available():
        raise SystemExit("CUDA required.")
    device = torch.device("cuda")
    args.output.mkdir(parents=True, exist_ok=True)

    # --- data + normalisation stats ---
    train_ds = PickAndPlaceDataset(args.parquets, image_keys=IMAGE_KEYS)
    print(f"episodes: {train_ds.total_episodes} frames: {len(train_ds)}")
    stats = compute_stats(train_ds._state, train_ds._action)
    stats_path = args.output / "stats.safetensors"
    save_stats(stats, stats_path)
    print(f"\nstats saved → {stats_path}")
    for k, v in stats.items():
        print(f" {k:12s} {v.tolist()}")
    # The stats are saved and printed before being moved to the GPU.
    stats = {k: v.to(device) for k, v in stats.items()}

    # --- model ---
    print(f"\nLoading base checkpoint {args.base_checkpoint} ...")
    policy = SmolVLAPolicy.from_pretrained(
        args.base_checkpoint,
        config_overrides={
            "image_keys": list(IMAGE_KEYS),
            "state_dim": STATE_DIM,
            "action_dim": ACTION_DIM,
            "load_vlm_weights": False,
            "empty_cameras": 0,
        },
        strict=False,
    ).to(device)
    policy.train()
    trainable = [p for p in policy.parameters() if p.requires_grad]
    n_train = sum(p.numel() for p in trainable)
    n_total = sum(p.numel() for p in policy.parameters())
    print(f"trainable params: {n_train:,} / {n_total:,} ({100*n_train/n_total:.1f}%)")

    # --- optimizer + schedule ---
    opt = torch.optim.AdamW(
        trainable, lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay
    )
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt,
        lr_lambda=partial(
            lr_lambda,
            nominal_warmup=args.warmup_steps,
            nominal_decay=args.decay_steps,
            total_steps=args.steps,
            decay_lr=args.decay_lr,
            peak_lr=args.lr,
        ),
    )

    train_loader = make_loader(
        train_ds, args.batch_size, shuffle=True, num_workers=args.num_workers
    )
    print(f"train batches/epoch: {len(train_loader)}")
    if len(train_loader) == 0:
        # With no batches the inner loop below never runs, `step` never
        # advances, and the outer `while` would spin forever.
        raise SystemExit(
            f"No training batches: dataset has {len(train_ds)} frames, "
            f"fewer than --batch-size {args.batch_size}."
        )

    # --- training loop ---
    step = 0
    t0 = time.time()
    running_loss = 0.0
    running_n = 0
    # The outer loop restarts the loader each epoch until enough steps ran.
    while step < args.steps:
        for raw_batch in train_loader:
            if step >= args.steps:
                break
            batch = build_train_batch(raw_batch, policy, stats, device)
            loss, info = policy.forward(batch)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, args.grad_clip)
            opt.step()
            sched.step()
            running_loss += info["loss"]
            running_n += 1

            # An interval of 0 disables the action instead of raising
            # ZeroDivisionError on the first step.
            if args.log_every and (step + 1) % args.log_every == 0:
                dt = time.time() - t0
                avg = running_loss / running_n
                lr = sched.get_last_lr()[0]
                print(f"step {step+1:>6}/{args.steps} loss {avg:.4f} lr {lr:.2e} "
                      f"({(step+1) / max(dt, 1e-9):.2f} step/s)")
                running_loss, running_n = 0.0, 0
            if args.save_every and (step + 1) % args.save_every == 0:
                out = args.output / f"step-{step+1:06d}"
                _save_checkpoint(policy, out, stats_path)
                print(f" [save] {out}")
            step += 1

    final = args.output / "final"
    _save_checkpoint(policy, final, stats_path)
    print(f"\ndone — final checkpoint at {final}")


if __name__ == "__main__":
    main()