venayc's picture
Upload 31 files
59653ee verified
Raw
History Blame Contribute Delete
7.69 kB
"""Minimal SmolVLA fine-tune loop for the norma-core pick-and-place parquets.
Merges all parquets, computes stats over the full set, and fine-tunes
`lerobot/smolvla_base`. Validation is done on device, not here.
Run:
uv run python scripts/train.py --steps 5000 --batch-size 26
"""
from __future__ import annotations
import argparse
import math
import time
from functools import partial
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from smolvla import SmolVLAPolicy
from smolvla.dataset import PickAndPlaceDataset, collate_samples
from smolvla.normalize import normalize_action, normalize_state
from smolvla.stats import compute_stats, save_stats
# Project root: two directory levels above this file's resolved location.
ROOT = Path(__file__).resolve().parent.parent
# Checkpoints are written inside the project root.
CKPT_DIR = ROOT.joinpath("checkpoints")
# Datasets live in a "datasets" directory that is a sibling of the project root.
DATASETS = ROOT.parent.joinpath("datasets")

# Camera streams fed to the policy; passed both to the dataset and to the
# policy's config overrides.
IMAGE_KEYS = (
    "observation.images.cam0",
    "observation.images.cam1",
)

# Dimensions handed to the policy config as "state_dim" / "action_dim".
STATE_DIM = 6
ACTION_DIM = 6

# Used when --parquets is not given on the command line.
DEFAULT_PARQUETS = [DATASETS.joinpath("dataset.parquet")]
def parse_args() -> argparse.Namespace:
    """Parse the training script's command-line options from ``sys.argv``.

    Returns:
        Namespace with the run length, optimizer / LR-schedule settings,
        logging and checkpoint cadence, data-loading options, the parquet
        files to train on, the base checkpoint id and the output directory.
    """
    parser = argparse.ArgumentParser()

    # Run length and optimizer hyper-parameters.
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-10)

    # Learning-rate schedule (consumed by lr_lambda).
    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=1000,
        help="Nominal warmup steps; auto-scaled if --steps < --decay-steps.",
    )
    parser.add_argument(
        "--decay-steps",
        type=int,
        default=30000,
        help="Nominal cosine decay length; auto-scaled if --steps is shorter.",
    )
    parser.add_argument(
        "--decay-lr",
        type=float,
        default=2.5e-6,
        help="LR floor at end of cosine decay.",
    )
    parser.add_argument("--grad-clip", type=float, default=10.0)

    # Cadence, data loading and reproducibility.
    parser.add_argument("--save-every", type=int, default=1000)
    parser.add_argument("--log-every", type=int, default=20)
    parser.add_argument("--num-workers", type=int, default=2)
    parser.add_argument("--seed", type=int, default=0)

    # Inputs and outputs.
    parser.add_argument(
        "--parquets",
        type=Path,
        nargs="+",
        default=DEFAULT_PARQUETS,
        help="Parquet files to merge.",
    )
    parser.add_argument(
        "--base-checkpoint", type=str, default="lerobot/smolvla_base"
    )
    parser.add_argument("--output", type=Path, default=CKPT_DIR / "run")

    return parser.parse_args()
def make_loader(ds, batch_size: int, shuffle: bool, num_workers: int) -> DataLoader:
    """Wrap ``ds`` in a ``DataLoader`` configured for this script.

    Args:
        ds: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle; also decides whether the last partial
            batch is dropped.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``DataLoader``.
    """
    # Batches are assembled by the project collate function, bound to the
    # camera keys used throughout this script.
    collate = partial(collate_samples, image_keys=IMAGE_KEYS)

    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "collate_fn": collate,
        "pin_memory": True,
        # A shuffled loader drops the trailing partial batch; an unshuffled
        # one keeps every sample.
        "drop_last": shuffle,
        # Persistent workers are only requested when workers exist at all.
        "persistent_workers": num_workers > 0,
    }
    return DataLoader(ds, **loader_options)
def move_batch(batch: dict, device: torch.device) -> dict:
    """Return a new dict with every tensor value transferred to ``device``.

    Non-tensor values are carried over unchanged (same object). The input
    dict itself is not modified.

    Args:
        batch: Mapping of names to tensors and/or arbitrary other values.
        device: Target device for the tensor values.

    Returns:
        A fresh dict with the same keys in the same order.
    """
    moved = {}
    for name, value in batch.items():
        if isinstance(value, torch.Tensor):
            # non_blocking=True requests an asynchronous copy where possible.
            value = value.to(device, non_blocking=True)
        moved[name] = value
    return moved
def build_train_batch(batch: dict, policy, stats, device) -> dict:
    """Turn a collated batch into the dict passed to the policy's forward.

    The batch is moved to ``device``, state and action are normalized with
    ``stats``, and the ``"task"`` entry is replaced by language tokens and
    their attention mask.

    Args:
        batch: Collated batch; must contain ``"observation.state"``,
            ``"action"`` and ``"task"``.
        policy: Object exposing ``tokenize_task(tasks, device=...)`` that
            returns a ``(tokens, mask)`` pair.
        stats: Normalization statistics forwarded to ``normalize_state`` and
            ``normalize_action``.
        device: Device the tensors and tokens should live on.

    Returns:
        The device-side batch with normalized state/action and the two
        ``observation.language.*`` entries.
    """
    # move_batch returns a new dict, so the edits below (including the pop)
    # do not touch the caller's batch.
    prepared = move_batch(batch, device)

    state = normalize_state(prepared["observation.state"], stats)
    prepared["observation.state"] = state

    action = normalize_action(prepared["action"], stats)
    prepared["action"] = action

    # The raw task entry is removed and replaced by its tokenized form.
    task = prepared.pop("task")
    tokens, mask = policy.tokenize_task(task, device=device)
    prepared.update(
        {
            "observation.language.tokens": tokens,
            "observation.language.attention_mask": mask,
        }
    )
    return prepared
def lr_lambda(
    step: int,
    nominal_warmup: int,
    nominal_decay: int,
    total_steps: int,
    decay_lr: float,
    peak_lr: float,
) -> float:
    """Return the learning-rate multiplier for ``step``.

    The schedule is a linear warmup to 1.0 followed by a cosine decay down to
    ``decay_lr / peak_lr``. If the run is shorter than the nominal decay
    length, warmup is scaled by ``total_steps / nominal_decay`` and the decay
    ends at ``total_steps``.

    Args:
        step: Zero-based scheduler step.
        nominal_warmup: Warmup length for a full-length run.
        nominal_decay: Step at which the cosine decay reaches its floor for a
            full-length run.
        total_steps: Number of steps this run will actually take.
        decay_lr: Absolute learning rate at the end of the decay.
        peak_lr: Absolute learning rate at the end of warmup; must be non-zero.

    Returns:
        Factor to multiply the optimizer's base learning rate by.
    """
    if total_steps < nominal_decay:
        # Short run: shrink warmup proportionally and finish the decay
        # exactly at the last step.
        scale = total_steps / nominal_decay
        warmup = max(int(nominal_warmup * scale), 1)
        decay = total_steps
    else:
        warmup = max(nominal_warmup, 1)
        decay = nominal_decay

    if step < warmup:
        # Linear ramp that reaches 1.0 on the last warmup step.
        return (step + 1) / warmup

    alpha = decay_lr / peak_lr
    span = decay - warmup
    if span <= 0:
        # The decay horizon ends at or before the end of warmup, so there is
        # no cosine segment left: go straight to the floor. Without this
        # guard a negative offset is fed to the cosine and the result flips
        # between the peak and the floor depending on the parity of
        # (warmup - decay).
        return alpha

    # Position inside the cosine segment, held at its end once decay is over.
    s = min(step - warmup, span)
    cos = 0.5 * (1 + math.cos(math.pi * s / span))
    return (1 - alpha) * cos + alpha
def _save_checkpoint(policy, out_dir: Path, stats_path: Path) -> None:
    """Save ``policy`` into ``out_dir`` and copy the stats file next to it."""
    policy.save_pretrained(out_dir)
    (out_dir / "stats.safetensors").write_bytes(stats_path.read_bytes())


def main() -> None:
    """Fine-tune the base checkpoint on the configured parquet files.

    Steps: parse arguments, load the dataset, compute and save normalization
    stats, load the policy, then run ``--steps`` optimizer updates with
    periodic logging and checkpointing, and finally write a ``final``
    checkpoint under ``--output``.

    Raises:
        SystemExit: If CUDA is unavailable, or if the data loader yields no
            batches (dataset smaller than ``--batch-size``).
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    if not torch.cuda.is_available():
        raise SystemExit("CUDA required.")
    device = torch.device("cuda")
    args.output.mkdir(parents=True, exist_ok=True)

    # --- data and normalization stats ---
    train_ds = PickAndPlaceDataset(args.parquets, image_keys=IMAGE_KEYS)
    print(f"episodes: {train_ds.total_episodes} frames: {len(train_ds)}")
    # NOTE(review): reads the dataset's private _state / _action attributes;
    # presumably the full state and action arrays — confirm against the
    # dataset class.
    stats = compute_stats(train_ds._state, train_ds._action)
    stats_path = args.output / "stats.safetensors"
    save_stats(stats, stats_path)
    print(f"\nstats saved → {stats_path}")
    for k, v in stats.items():
        print(f" {k:12s} {v.tolist()}")
    # Stats are saved from their original tensors; the training loop uses
    # copies on the GPU.
    stats = {k: v.to(device) for k, v in stats.items()}

    # --- model ---
    print(f"\nLoading base checkpoint {args.base_checkpoint} ...")
    policy = SmolVLAPolicy.from_pretrained(
        args.base_checkpoint,
        config_overrides={
            "image_keys": list(IMAGE_KEYS),
            "state_dim": STATE_DIM,
            "action_dim": ACTION_DIM,
            "load_vlm_weights": False,
            "empty_cameras": 0,
        },
        strict=False,
    ).to(device)
    policy.train()

    # Only parameters that require gradients are optimized and clipped.
    trainable = [p for p in policy.parameters() if p.requires_grad]
    n_train = sum(p.numel() for p in trainable)
    n_total = sum(p.numel() for p in policy.parameters())
    print(f"trainable params: {n_train:,} / {n_total:,} ({100*n_train/n_total:.1f}%)")

    # --- optimizer and LR schedule ---
    opt = torch.optim.AdamW(
        trainable, lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay
    )
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt,
        lr_lambda=partial(
            lr_lambda,
            nominal_warmup=args.warmup_steps,
            nominal_decay=args.decay_steps,
            total_steps=args.steps,
            decay_lr=args.decay_lr,
            peak_lr=args.lr,
        ),
    )

    train_loader = make_loader(
        train_ds, args.batch_size, shuffle=True, num_workers=args.num_workers
    )
    print(f"train batches/epoch: {len(train_loader)}")
    if len(train_loader) == 0:
        # The shuffled loader drops partial batches, so a dataset smaller
        # than the batch size yields nothing and the outer while-loop below
        # would never advance `step`.
        raise SystemExit(
            f"No training batches: dataset has {len(train_ds)} frames but "
            f"--batch-size is {args.batch_size}."
        )

    # --- training loop ---
    step = 0
    t0 = time.time()
    running_loss = 0.0
    running_n = 0
    # The outer loop restarts the loader (a new epoch) until enough steps
    # have been taken.
    while step < args.steps:
        for raw_batch in train_loader:
            if step >= args.steps:
                break
            batch = build_train_batch(raw_batch, policy, stats, device)
            loss, info = policy.forward(batch)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, args.grad_clip)
            opt.step()
            sched.step()

            # NOTE(review): info["loss"] is accumulated as a plain number;
            # presumably a Python float — confirm against the policy.
            running_loss += info["loss"]
            running_n += 1

            if (step + 1) % args.log_every == 0:
                dt = time.time() - t0
                avg = running_loss / running_n
                lr = sched.get_last_lr()[0]
                print(f"step {step+1:>6}/{args.steps} loss {avg:.4f} lr {lr:.2e} "
                      f"({(step+1) / max(dt, 1e-9):.2f} step/s)")
                # The reported loss is the mean since the previous log line.
                running_loss, running_n = 0.0, 0

            if (step + 1) % args.save_every == 0:
                out = args.output / f"step-{step+1:06d}"
                _save_checkpoint(policy, out, stats_path)
                print(f" [save] {out}")

            step += 1

    final = args.output / "final"
    _save_checkpoint(policy, final, stats_path)
    print(f"\ndone — final checkpoint at {final}")


if __name__ == "__main__":
    main()