#!/usr/bin/env python3 """Train the RLT Stage-1 encoder/decoder on cached (M,2560) prefix shards. Uses the EXACT training knobs from the openpi reference (pravsels/openpi PR #6, the `pi05_rl_token_bin_pack_coffee_capsules` TrainConfig), adapted for our frozen-VLA / pre-cached-features setting (alpha = 0, so no VLA forward here): optimizer AdamW, clip_gradient_norm = 1.0 lr schedule linear warmup (1000) -> constant peak_lr = 5e-5 (their CosineDecay had peak_lr == decay_lr == 5e-5 = flat) ema_decay 0.999 (eval/save from the EMA weights) loss per-token squared-L2, sum over dim, mean over valid tokens, targets stop-gradiented (matches rl_token_encoder forward()) Deviations from the reference, on purpose: * batch_size > 1: their bs=1 was forced by running the full pi05 VLA each step; our enc/dec is tiny and features are cached, so we batch and pad+mask. * NO feature standardization (reference reconstructs raw prefix_out). A --standardize escape hatch is provided but OFF by default to stay faithful. Run (server must be DOWN first to free VRAM): ./lerobot/.venv/bin/python train_encoder.py \ --shard-dir ./encoder_cache_prefix --out ./checkpoints/rl_token_encoder """ from __future__ import annotations import argparse import glob import os import time import numpy as np import torch from torch.utils.data import DataLoader, Dataset, random_split from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig class PrefixShards(Dataset): """Each .npz holds `embeddings` (M, dim) float16 — one cached prefix.""" def __init__(self, shard_dir: str): self.paths = sorted(glob.glob(os.path.join(os.path.expanduser(shard_dir), "*.npz"))) if not self.paths: raise FileNotFoundError(f"no .npz shards in {shard_dir}") # episode_id per shard (parsed from filename ep{NNNN}_...) for the # success/failure t-SNE gate later; cheap to keep around. self.episodes = [self._ep(p) for p in self.paths] @staticmethod def _ep(path: str) -> int: base = os.path.basename(path) if base.startswith("ep"): try: return int(base[2:6]) except ValueError: pass return -1 def __len__(self) -> int: return len(self.paths) def __getitem__(self, i: int) -> torch.Tensor: with np.load(self.paths[i]) as z: emb = z["embeddings"].astype(np.float32) # (M, dim) return torch.from_numpy(emb) def collate(batch: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: """Pad variable-M prefixes to the batch max; build the valid-token mask.""" dim = batch[0].shape[-1] M = max(x.shape[0] for x in batch) b = len(batch) out = torch.zeros(b, M, dim, dtype=torch.float32) mask = torch.zeros(b, M, dtype=torch.bool) for i, x in enumerate(batch): m = x.shape[0] out[i, :m] = x mask[i, :m] = True return out, mask def linear_warmup_then_constant(step: int, warmup: int, peak: float) -> float: if step < warmup: return peak * (step + 1) / warmup return peak @torch.no_grad() def ema_update(ema: dict[str, torch.Tensor], model: torch.nn.Module, decay: float) -> None: for k, v in model.state_dict().items(): if v.dtype.is_floating_point: ema[k].mul_(decay).add_(v.detach(), alpha=1 - decay) else: ema[k].copy_(v) @torch.no_grad() def z_rl_structure(model: RLTokenAutoencoder, loader: DataLoader, device: str) -> dict: """Valid z_rl probe (label-free). NOTE: the old first-token ablation was VACUOUS here — the first prefix token is a constant special token (id 151645, std=0), so token-0 recon is trivially constant and real==shuffled regardless of z_rl quality. Instead measure (1) cross-sample cosine of z_rl (collapse: ~1 bad, ~0 diverse) and (2) PCA top-10 variance ratio (structure: higher = lower-D task manifold).""" model.eval() Z = [] for x, mask in loader: x, mask = x.to(device), mask.to(device) Z.append(model.encode(x, mask).float().cpu()) if sum(z.shape[0] for z in Z) >= 512: break Z = torch.cat(Z)[:512] Zn = torch.nn.functional.normalize(Z, dim=1) n = Z.shape[0] cos = (Zn @ Zn.T)[~torch.eye(n, dtype=torch.bool)].mean().item() s = torch.linalg.svdvals(Z - Z.mean(0)) var = s ** 2 pca10 = (var[:10].sum() / var.sum().clamp(min=1e-9)).item() return {"cos": cos, "pca10": pca10} def main() -> None: p = argparse.ArgumentParser() p.add_argument("--shard-dir", default="./encoder_cache_prefix") p.add_argument("--out", default="./checkpoints/rl_token_encoder") p.add_argument("--dim", type=int, default=2560) p.add_argument("--batch-size", type=int, default=16) p.add_argument("--num-train-steps", type=int, default=10_000) # reference default; cap by the gate p.add_argument("--peak-lr", type=float, default=5e-5) # reference p.add_argument("--warmup-steps", type=int, default=1_000) # reference p.add_argument("--clip-grad-norm", type=float, default=1.0) # reference p.add_argument("--ema-decay", type=float, default=0.999) # reference p.add_argument("--weight-decay", type=float, default=1e-4) # AdamW default-ish; ref AdamW unspecified p.add_argument("--val-frac", type=float, default=0.1) p.add_argument("--eval-every", type=int, default=500) p.add_argument("--standardize", action="store_true", help="(off=faithful) z-score features first") p.add_argument("--context-dropout", type=float, default=0.0, help="train-only: prob of zeroing each decoder teacher-forced context token, " "forcing info through z_rl (fixes latent collapse / the AR leak). 0=bare reference, 0.5=fix") p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--seed", type=int, default=0) args = p.parse_args() torch.manual_seed(args.seed) os.makedirs(os.path.dirname(os.path.abspath(args.out)) or ".", exist_ok=True) full = PrefixShards(args.shard_dir) n_val = max(1, int(len(full) * args.val_frac)) n_tr = len(full) - n_val tr, va = random_split(full, [n_tr, n_val], generator=torch.Generator().manual_seed(args.seed)) print(f"shards: {len(full)} train: {n_tr} val: {n_val} episodes: {len(set(full.episodes))}") # Optional standardization (per-feature mean/std over a sample of train shards). mean = std = None if args.standardize: acc, c = torch.zeros(args.dim), 0 sq = torch.zeros(args.dim) for idx in list(tr.indices)[:512]: x = full[idx] acc += x.sum(0); sq += (x * x).sum(0); c += x.shape[0] mean = acc / c std = (sq / c - mean**2).clamp_min(1e-6).sqrt() print("standardize ON: feature mean/std computed over", c, "tokens") def norm(x): return (x - mean) / std if mean is not None else x dl_kw = dict(batch_size=args.batch_size, collate_fn=collate, num_workers=4, pin_memory=True) tr_loader = DataLoader(tr, shuffle=True, drop_last=True, **dl_kw) va_loader = DataLoader(va, shuffle=False, **dl_kw) model = RLTokenAutoencoder(RLTokenConfig(dim=args.dim)).to(args.device) n_params = sum(p.numel() for p in model.parameters()) print(f"model params: {n_params/1e6:.1f}M device: {args.device}") opt = torch.optim.AdamW(model.parameters(), lr=args.peak_lr, weight_decay=args.weight_decay) ema = {k: v.detach().clone() for k, v in model.state_dict().items()} def save(tag: str, extra: dict) -> None: torch.save({ "model": model.state_dict(), "ema": ema, "cfg": vars(RLTokenConfig(dim=args.dim)), "mean": mean, "std": std, "args": vars(args), **extra, }, f"{args.out}_{tag}.pt") step = 0 best_val = float("inf") t0 = time.time() model.train() print("training... (their knobs: AdamW lr5e-5, warmup1k, clip1.0, ema0.999)") while step < args.num_train_steps: for x, mask in tr_loader: if step >= args.num_train_steps: break x, mask = norm(x).to(args.device), mask.to(args.device) for g in opt.param_groups: g["lr"] = linear_warmup_then_constant(step, args.warmup_steps, args.peak_lr) _, loss = model(x, mask, context_dropout=args.context_dropout) opt.zero_grad(set_to_none=True) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) opt.step() ema_update(ema, model, args.ema_decay) if step % 50 == 0: print(f"step {step:6d} recon={loss.item():10.3f} lr={opt.param_groups[0]['lr']:.2e}" f" {(step+1)/(time.time()-t0):.1f} it/s") if step > 0 and step % args.eval_every == 0: # eval from EMA weights (reference uses EMA for eval/save) live = {k: v.detach().clone() for k, v in model.state_dict().items()} model.load_state_dict(ema) vlosses = [] with torch.no_grad(): for vx, vm in va_loader: vx, vm = norm(vx).to(args.device), vm.to(args.device) vlosses.append(model(vx, vm)[1].item()) vmean = float(np.mean(vlosses)) st = z_rl_structure(model, va_loader, args.device) structured = st["cos"] < 0.5 and st["pca10"] > 0.3 print(f" [eval] val_recon={vmean:.3f} z_rl: cos={st['cos']:.3f} (low=diverse) " f"pca10={st['pca10']:.2%} (high=structured) " f"{'✅ structured' if structured else '⚠️ diffuse'}") if vmean < best_val: best_val = vmean save("best", {"step": step, "val_recon": vmean, "z_rl_structure": st}) model.load_state_dict(live) model.train() step += 1 model.load_state_dict(ema) save("final", {"step": step, "val_recon": best_val}) print(f"done. best val_recon={best_val:.3f}. saved {args.out}_best.pt / _final.pt") if __name__ == "__main__": main()