| |
| """Train the RLT Stage-1 encoder/decoder on cached (M,2560) prefix shards. |
| |
| Uses the EXACT training knobs from the openpi reference (pravsels/openpi PR #6, |
| the `pi05_rl_token_bin_pack_coffee_capsules` TrainConfig), adapted for our |
| frozen-VLA / pre-cached-features setting (alpha = 0, so no VLA forward here): |
| |
| optimizer AdamW, clip_gradient_norm = 1.0 |
| lr schedule linear warmup (1000) -> constant peak_lr = 5e-5 |
| (their CosineDecay had peak_lr == decay_lr == 5e-5 = flat) |
| ema_decay 0.999 (eval/save from the EMA weights) |
| loss per-token squared-L2, sum over dim, mean over valid tokens, |
| targets stop-gradiented (matches rl_token_encoder forward()) |
| |
| Deviations from the reference, on purpose: |
| * batch_size > 1: their bs=1 was forced by running the full pi05 VLA each step; |
| our enc/dec is tiny and features are cached, so we batch and pad+mask. |
| * NO feature standardization (reference reconstructs raw prefix_out). A |
| --standardize escape hatch is provided but OFF by default to stay faithful. |
| |
| Run (server must be DOWN first to free VRAM): |
| ./lerobot/.venv/bin/python train_encoder.py \ |
| --shard-dir ./encoder_cache_prefix --out ./checkpoints/rl_token_encoder |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import glob |
| import os |
| import time |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader, Dataset, random_split |
|
|
| from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig |
|
|
|
|
class PrefixShards(Dataset):
    """Dataset of cached prefixes, one ``.npz`` shard per item.

    Each .npz holds `embeddings` (M, dim) float16 — one cached prefix. M may vary
    between shards. Items are returned as float32 tensors.
    """

    def __init__(self, shard_dir: str):
        # Sorted so item order (and therefore any seeded split over it) is deterministic.
        self.paths = sorted(glob.glob(os.path.join(os.path.expanduser(shard_dir), "*.npz")))
        if not self.paths:
            raise FileNotFoundError(f"no .npz shards in {shard_dir}")
        # Episode id per shard, -1 when the filename does not encode one.
        self.episodes = [self._ep(p) for p in self.paths]

    @staticmethod
    def _ep(path: str) -> int:
        """Return the episode id encoded as ``ep`` + up to 4 digits in the basename, else -1.

        The characters after ``ep`` (up to four) must all be decimal digits. A bare
        ``int()`` would also accept a sign, surrounding whitespace or underscores
        (``ep-123…`` -> -123, ``ep1_23…`` -> 123), silently yielding bogus episode ids.
        """
        base = os.path.basename(path)
        digits = base[2:6]
        if base.startswith("ep") and digits.isdecimal():
            return int(digits)
        return -1

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, i: int) -> torch.Tensor:
        """Load shard ``i``'s ``embeddings`` array and return it as a float32 tensor."""
        with np.load(self.paths[i]) as z:
            emb = z["embeddings"].astype(np.float32)
        return torch.from_numpy(emb)
|
|
|
|
def collate(batch: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad a batch of variable-length prefixes with zeros to a common length.

    Returns ``(x, mask)``: ``x`` is float32 of shape (B, longest, dim) and ``mask`` is a
    bool (B, longest) tensor that is True exactly on each sample's real tokens.
    """
    lengths = torch.tensor([item.shape[0] for item in batch])
    padded = torch.nn.utils.rnn.pad_sequence(
        [item.float() for item in batch], batch_first=True
    )
    positions = torch.arange(padded.shape[1])
    valid = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return padded, valid
|
|
|
|
def linear_warmup_then_constant(step: int, warmup: int, peak: float) -> float:
    """Learning rate for 0-based optimizer ``step``.

    Ramps linearly (peak/warmup, 2*peak/warmup, ..., peak) over the first ``warmup``
    steps, then holds ``peak`` forever. ``warmup <= 0`` means no ramp at all.
    """
    in_warmup = step < warmup
    return peak * (step + 1) / warmup if in_warmup else peak
|
|
|
|
@torch.no_grad()
def ema_update(ema: dict[str, torch.Tensor], model: torch.nn.Module, decay: float) -> None:
    """Fold ``model``'s current state_dict into ``ema`` in place.

    Floating-point entries become ``decay * ema + (1 - decay) * current``; any
    non-floating entry (e.g. an integer buffer) is copied over verbatim.
    """
    blend = 1 - decay
    for name, current in model.state_dict().items():
        target = ema[name]
        if not current.dtype.is_floating_point:
            target.copy_(current)
            continue
        target.mul_(decay).add_(current.detach(), alpha=blend)
|
|
|
|
@torch.no_grad()
def z_rl_structure(
    model: RLTokenAutoencoder,
    loader: DataLoader,
    device: str,
    max_samples: int = 512,
    transform=None,
) -> dict:
    """Valid z_rl probe (label-free).

    NOTE: the old first-token ablation was VACUOUS here — the first prefix token is a
    constant special token (id 151645, std=0), so token-0 recon is trivially constant
    and real==shuffled regardless of z_rl quality. Instead this measures:

    * ``cos``   – mean off-diagonal cosine similarity between samples' z_rl
      (collapse: ~1 bad, ~0 diverse). NaN when fewer than 2 samples are collected.
    * ``pca10`` – fraction of the centred variance captured by the top-10 principal
      components (structure: higher = lower-D task manifold).

    Args:
        model: must expose ``encode(x, mask)``; its output is assumed to be one latent
            row per sample, i.e. (B, D) — TODO confirm against RLTokenAutoencoder.
        loader: any iterable of ``(x, mask)`` batches (e.g. a DataLoader using ``collate``).
        device: device each batch is moved to before encoding.
        max_samples: cap on the number of latents collected (previously hard-coded 512).
        transform: optional callable applied to ``x`` after the device move, e.g. the
            same feature standardization used in training. ``None`` = identity.

    The model is put in eval mode for the probe; its previous mode is restored afterwards.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    try:
        latents = []
        collected = 0
        for x, mask in loader:
            x, mask = x.to(device), mask.to(device)
            if transform is not None:
                x = transform(x)
            z = model.encode(x, mask).float().cpu()
            latents.append(z)
            collected += z.shape[0]
            if collected >= max_samples:
                break
    finally:
        # Don't leave the caller's module silently switched to eval mode.
        model.train(was_training)

    if not latents:
        raise ValueError("z_rl_structure: loader yielded no batches")
    Z = torch.cat(latents)[:max_samples]
    n = Z.shape[0]
    if n < 2:
        # No off-diagonal pairs: cross-sample similarity is undefined.
        cos = float("nan")
    else:
        Zn = torch.nn.functional.normalize(Z, dim=1)
        off_diag = ~torch.eye(n, dtype=torch.bool)
        cos = (Zn @ Zn.T)[off_diag].mean().item()
    s = torch.linalg.svdvals(Z - Z.mean(0))
    var = s ** 2
    pca10 = (var[:10].sum() / var.sum().clamp(min=1e-9)).item()
    return {"cos": cos, "pca10": pca10}
|
|
|
|
def _standardization_stats(
    ds: Dataset, indices, dim: int, max_items: int = 512
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Per-feature mean/std over every token of the first ``max_items`` items in ``indices``.

    Sums are accumulated in float64 because E[x^2] - E[x]^2 cancels badly in float32
    over many tokens. Results are returned as float32 so standardized inputs stay
    float32. The variance is floored at 1e-6 (so std >= 1e-3).

    Returns:
        (mean, std, n_tokens)

    Raises:
        ValueError: if the selected items contain no tokens.
    """
    total = torch.zeros(dim, dtype=torch.float64)
    total_sq = torch.zeros(dim, dtype=torch.float64)
    count = 0
    for idx in list(indices)[:max_items]:
        x = ds[idx].double()
        total += x.sum(0)
        total_sq += (x * x).sum(0)
        count += x.shape[0]
    if count == 0:
        raise ValueError("--standardize: no training tokens to compute feature statistics from")
    mean = total / count
    std = (total_sq / count - mean**2).clamp_min(1e-6).sqrt()
    return mean.float(), std.float(), count


@torch.no_grad()
def _val_recon(model: torch.nn.Module, loader, norm, device: str) -> float:
    """Valid-token-weighted mean reconstruction loss of ``model`` over ``loader``.

    Each batch loss is a mean over that batch's valid tokens (see module docstring), so
    weighting by the mask count yields the per-token mean over the whole split instead
    of over-weighting a short final batch. The caller sets eval mode.
    Returns NaN if the loader contains no valid tokens.
    """
    total, n_tokens = 0.0, 0
    for vx, vm in loader:
        vx, vm = norm(vx.to(device)), vm.to(device)
        n = int(vm.sum().item())
        total += model(vx, vm)[1].item() * n
        n_tokens += n
    return total / n_tokens if n_tokens else float("nan")


def main() -> None:
    """CLI entry point: train the RLT encoder/decoder on cached prefix shards.

    Evaluates and saves from the EMA weights, writing ``<out>_best.pt`` (lowest val
    recon seen) and ``<out>_final.pt``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--shard-dir", default="./encoder_cache_prefix")
    p.add_argument("--out", default="./checkpoints/rl_token_encoder")
    p.add_argument("--dim", type=int, default=2560)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--num-train-steps", type=int, default=10_000)
    p.add_argument("--peak-lr", type=float, default=5e-5)
    p.add_argument("--warmup-steps", type=int, default=1_000)
    p.add_argument("--clip-grad-norm", type=float, default=1.0)
    p.add_argument("--ema-decay", type=float, default=0.999)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--val-frac", type=float, default=0.1)
    p.add_argument("--eval-every", type=int, default=500)
    p.add_argument("--standardize", action="store_true", help="(off=faithful) z-score features first")
    p.add_argument("--context-dropout", type=float, default=0.0,
                   help="train-only: prob of zeroing each decoder teacher-forced context token, "
                        "forcing info through z_rl (fixes latent collapse / the AR leak). 0=bare reference, 0.5=fix")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--seed", type=int, default=0)
    args = p.parse_args()

    torch.manual_seed(args.seed)
    os.makedirs(os.path.dirname(os.path.abspath(args.out)) or ".", exist_ok=True)

    full = PrefixShards(args.shard_dir)
    n_val = max(1, int(len(full) * args.val_frac))
    n_tr = len(full) - n_val
    if n_tr < args.batch_size:
        # The train loader uses drop_last=True, so it would yield zero batches and the
        # `while step < num_train_steps` loop below would spin forever without training.
        raise ValueError(
            f"train split has {n_tr} shard(s) but --batch-size is {args.batch_size}; "
            "no full training batch can be formed (add shards or lower --batch-size / --val-frac)")
    tr, va = random_split(full, [n_tr, n_val], generator=torch.Generator().manual_seed(args.seed))
    print(f"shards: {len(full)} train: {n_tr} val: {n_val} episodes: {len(set(full.episodes))}")

    mean = std = None
    if args.standardize:
        mean, std, n_tok = _standardization_stats(full, tr.indices, args.dim)
        print("standardize ON: feature mean/std computed over", n_tok, "tokens")
    # Device-resident copies so standardization runs after the host->device copy;
    # the CPU `mean` / `std` are what get stored in checkpoints.
    mean_dev = mean.to(args.device) if mean is not None else None
    std_dev = std.to(args.device) if std is not None else None

    def norm(x):
        return (x - mean_dev) / std_dev if mean_dev is not None else x

    dl_kw = dict(batch_size=args.batch_size, collate_fn=collate, num_workers=4,
                 pin_memory=args.device.startswith("cuda"))  # pinning only helps host->GPU copies
    tr_loader = DataLoader(tr, shuffle=True, drop_last=True, **dl_kw)
    va_loader = DataLoader(va, shuffle=False, **dl_kw)

    model = RLTokenAutoencoder(RLTokenConfig(dim=args.dim)).to(args.device)
    n_params = sum(t.numel() for t in model.parameters())
    print(f"model params: {n_params/1e6:.1f}M device: {args.device}")
    opt = torch.optim.AdamW(model.parameters(), lr=args.peak_lr, weight_decay=args.weight_decay)
    ema = {k: v.detach().clone() for k, v in model.state_dict().items()}

    def save(tag: str, extra: dict) -> None:
        # NOTE: every call site runs while the EMA weights are loaded into `model`,
        # so "model" and "ema" hold the same weights in these checkpoints.
        torch.save({
            "model": model.state_dict(),
            "ema": ema,
            "cfg": vars(RLTokenConfig(dim=args.dim)),
            "mean": mean, "std": std,
            "args": vars(args),
            **extra,
        }, f"{args.out}_{tag}.pt")

    step = 0
    best_val = float("inf")

    def evaluate_loaded(at_step: int) -> tuple[float, dict]:
        """Evaluate the weights currently in `model`, print a summary, save `_best` on improvement."""
        nonlocal best_val
        model.eval()  # validation must not run with training-mode behaviour (e.g. dropout)
        vmean = _val_recon(model, va_loader, norm, args.device)
        # z_rl_structure only iterates (x, mask) pairs: hand it the same standardized
        # inputs the model is trained on.
        st = z_rl_structure(
            model, ((norm(vx.to(args.device)), vm) for vx, vm in va_loader), args.device)
        structured = st["cos"] < 0.5 and st["pca10"] > 0.3
        print(f" [eval] val_recon={vmean:.3f} z_rl: cos={st['cos']:.3f} (low=diverse) "
              f"pca10={st['pca10']:.2%} (high=structured) "
              f"{'✅ structured' if structured else '⚠️ diffuse'}")
        if vmean < best_val:
            best_val = vmean
            save("best", {"step": at_step, "val_recon": vmean, "z_rl_structure": st})
        return vmean, st

    t0 = time.time()
    model.train()
    print("training... (their knobs: AdamW lr5e-5, warmup1k, clip1.0, ema0.999)")
    while step < args.num_train_steps:
        for x, mask in tr_loader:
            if step >= args.num_train_steps:
                break
            x, mask = norm(x.to(args.device)), mask.to(args.device)
            lr = linear_warmup_then_constant(step, args.warmup_steps, args.peak_lr)
            for g in opt.param_groups:
                g["lr"] = lr
            _, loss = model(x, mask, context_dropout=args.context_dropout)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            opt.step()
            ema_update(ema, model, args.ema_decay)

            if step % 50 == 0:
                print(f"step {step:6d} recon={loss.item():10.3f} lr={opt.param_groups[0]['lr']:.2e}"
                      f" {(step+1)/(time.time()-t0):.1f} it/s")
            # eval_every <= 0 disables periodic evals (the final eval below still runs).
            if args.eval_every > 0 and step > 0 and step % args.eval_every == 0:
                # Evaluate the EMA weights, then put the live training weights back.
                live = {k: v.detach().clone() for k, v in model.state_dict().items()}
                model.load_state_dict(ema)
                evaluate_loaded(step)
                model.load_state_dict(live)
                model.train()
            step += 1

    # Final EMA evaluation: guarantees `_best.pt` exists even if no periodic eval fired,
    # and records the final model's own val loss (not the best one) in `_final.pt`.
    model.load_state_dict(ema)
    final_val, final_st = evaluate_loaded(step)
    save("final", {"step": step, "val_recon": final_val, "best_val_recon": best_val,
                   "z_rl_structure": final_st})
    print(f"done. best val_recon={best_val:.3f}. saved {args.out}_best.pt / _final.pt")
|
|
|
|
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|