#!/usr/bin/env python3 """Phase-3 validation gate: PCA + t-SNE over z_rl (the RL Token paper's actual test). The paper validates the encoder NOT by reconstruction error but by visualizing the token: ordinary pick-place frames should form a *smooth curve* across episodes (and, with failures, failed grasps / collapses cluster apart). We only have success demos, so we test the success-side signature: 1. STRUCTURE — PCA: how much variance lives in a few dims (low-D manifold?) 2. SMOOTHNESS — do consecutive in-episode frames stay close in z_rl space (smooth trajectory), vs random frame pairs? 3. VISUAL — t-SNE 2D, colored by normalized in-episode phase (0=start→1=end); a clean task should trace a consistent start→end sweep. CPU-only (won't fight a GPU training run). Run: ./lerobot/.venv/bin/python tsne_gate.py --ckpt checkpoints/rl_token_encoder_nodrop_best.pt """ from __future__ import annotations import argparse, glob, os import numpy as np, torch import matplotlib; matplotlib.use("Agg") import matplotlib.pyplot as plt from sklearn.decomposition import PCA from sklearn.manifold import TSNE from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig def main(): p = argparse.ArgumentParser() p.add_argument("--ckpt", default="checkpoints/rl_token_encoder_nodrop_best.pt") p.add_argument("--shard-dir", default="./encoder_cache_prefix") p.add_argument("--max-points", type=int, default=2000) p.add_argument("--weights", default="ema", choices=["ema", "model"]) p.add_argument("--out", default="outputs/tsne_gate.png") args = p.parse_args() ck = torch.load(args.ckpt, map_location="cpu") ae = RLTokenAutoencoder(RLTokenConfig(dim=2560)) ae.load_state_dict(ck[args.weights]); ae.eval() print(f"loaded {args.ckpt} ({args.weights}) step={ck.get('step')} val_recon={ck.get('val_recon')}") fs = sorted(glob.glob(os.path.join(args.shard_dir, "*.npz"))) if len(fs) > args.max_points: fs = fs[:: len(fs) // args.max_points] # even stride across the whole set Z, ep, step = [], [], [] with torch.no_grad(): for f in fs: z = np.load(f) e = torch.from_numpy(z["embeddings"].astype(np.float32))[None] m = torch.ones(1, e.shape[1], dtype=torch.bool) Z.append(ae.encode(e, m)[0].numpy()) ep.append(int(z["episode_id"])); step.append(int(z["step_in_episode"])) Z = np.stack(Z); ep = np.array(ep); step = np.array(step) print(f"encoded {len(Z)} z_rl vectors over {len(set(ep.tolist()))} episodes") # normalized in-episode phase 0..1 phase = np.zeros(len(Z)) for e in set(ep.tolist()): idx = ep == e; s = step[idx] phase[idx] = (s - s.min()) / max(1, (s.max() - s.min())) # 1) STRUCTURE pca = PCA(n_components=min(50, Z.shape[0], Z.shape[1])).fit(Z) ev = pca.explained_variance_ratio_ print(f"PCA var explained: top-2={ev[:2].sum():.2%} top-5={ev[:5].sum():.2%} top-10={ev[:10].sum():.2%}") # 2) SMOOTHNESS: mean ||Δz|| between consecutive in-episode frames vs random pairs Zn = Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-8) consec, rand = [], [] for e in set(ep.tolist()): idx = np.where(ep == e)[0]; idx = idx[np.argsort(step[idx])] if len(idx) > 2: consec += list(np.linalg.norm(np.diff(Zn[idx], axis=0), axis=1)) rng = np.random.default_rng(0) for _ in range(2000): i, j = rng.integers(0, len(Z), 2) rand.append(np.linalg.norm(Zn[i] - Zn[j])) sm = np.mean(consec) / (np.mean(rand) + 1e-8) print(f"SMOOTHNESS: consec-frame dist / random-pair dist = {sm:.3f} " f"(<<1 = smooth temporal trajectory; ~1 = no temporal structure)") # 3) t-SNE plot Zp = pca.transform(Z)[:, : min(50, Z.shape[1])] emb = TSNE(n_components=2, perplexity=30, init="pca", random_state=0).fit_transform(Zp) fig, ax = plt.subplots(1, 2, figsize=(14, 6)) s0 = ax[0].scatter(emb[:, 0], emb[:, 1], c=phase, cmap="viridis", s=8) ax[0].set_title("t-SNE of z_rl — colored by in-episode phase (0=start→1=end)") plt.colorbar(s0, ax=ax[0], label="phase") s1 = ax[1].scatter(emb[:, 0], emb[:, 1], c=ep, cmap="tab20", s=8) ax[1].set_title("colored by episode") os.makedirs(os.path.dirname(args.out), exist_ok=True) plt.tight_layout(); plt.savefig(args.out, dpi=110) print(f"saved {args.out}") # verdict heuristic good = ev[:10].sum() > 0.5 and sm < 0.7 print("GATE:", "✅ structured + smooth (z_rl is task-informative)" if good else "⚠️ weak structure/smoothness — inspect the plot") if __name__ == "__main__": main()