| |
| """Phase-3 validation gate: PCA + t-SNE over z_rl (the RL Token paper's actual test). |
| |
| The paper validates the encoder NOT by reconstruction error but by visualizing the |
| token: ordinary pick-place frames should form a *smooth curve* across episodes |
| (and, with failures, failed grasps / collapses cluster apart). We only have success |
| demos, so we test the success-side signature: |
| 1. STRUCTURE — PCA: how much variance lives in a few dims (low-D manifold?) |
| 2. SMOOTHNESS — do consecutive in-episode frames stay close in z_rl space |
| (smooth trajectory), vs random frame pairs? |
| 3. VISUAL — t-SNE 2D, colored by normalized in-episode phase (0=start→1=end); |
| a clean task should trace a consistent start→end sweep. |
| |
| CPU-only (won't fight a GPU training run). Run: |
| ./lerobot/.venv/bin/python tsne_gate.py --ckpt checkpoints/rl_token_encoder_nodrop_best.pt |
| """ |
| from __future__ import annotations |
| import argparse, glob, os |
| import numpy as np, torch |
| import matplotlib; matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from sklearn.decomposition import PCA |
| from sklearn.manifold import TSNE |
| from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig |
|
|
|
|
def _subsample_evenly(items, max_points):
    """Return at most ``max_points`` entries of ``items``, spread evenly and in original order."""
    if len(items) <= max_points:
        return list(items)
    # len(items) > max_points, so the linspace step is > 1 and the rounded indices are distinct.
    keep = np.linspace(0, len(items) - 1, num=max_points).round().astype(int)
    return [items[k] for k in keep]


def _encode_shards(ae, files):
    """Encode each ``.npz`` shard to one z_rl vector; return ``(Z, episode_ids, steps)`` arrays."""
    zs, eps, steps = [], [], []
    with torch.no_grad():
        for path in files:
            # NpzFile keeps the zip archive open; close it as soon as the arrays are read.
            with np.load(path) as shard:
                # astype() copies, so the tensor stays valid after the file is closed.
                # [None] adds a leading batch axis (assumes embeddings are (tokens, dim) — confirm).
                emb = torch.from_numpy(shard["embeddings"].astype(np.float32))[None]
                episode = int(shard["episode_id"])
                t = int(shard["step_in_episode"])
            # All-True mask over the shard's tokens (assumes shards carry no padding — confirm).
            mask = torch.ones(1, emb.shape[1], dtype=torch.bool)
            zs.append(ae.encode(emb, mask)[0].numpy())
            eps.append(episode)
            steps.append(t)
    return np.stack(zs), np.asarray(eps), np.asarray(steps)


def _episode_phase(ep, step):
    """Map each frame's step to [0, 1] within its episode (0=earliest, 1=latest loaded frame)."""
    phase = np.zeros(len(ep))
    for e in np.unique(ep):
        idx = ep == e
        s = step[idx]
        # max(1, ...) keeps single-frame episodes at phase 0 instead of dividing by zero.
        phase[idx] = (s - s.min()) / max(1, s.max() - s.min())
    return phase


def _smoothness_ratio(Z, ep, step, n_pairs=2000, seed=0):
    """Mean consecutive-frame distance over mean random-pair distance, on L2-normalized z_rl.

    "Consecutive" means adjacent among the *loaded* frames of an episode, which are
    sparser than the raw episode when shards were subsampled. Returns NaN when no
    episode contributed at least two frames.
    """
    Zn = Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-8)
    consec = []
    for e in np.unique(ep):
        idx = np.flatnonzero(ep == e)
        idx = idx[np.argsort(step[idx], kind="stable")]
        if len(idx) >= 2:  # two frames already give one consecutive pair
            consec.append(np.linalg.norm(np.diff(Zn[idx], axis=0), axis=1))
    if not consec:
        return float("nan")
    # Here at least one episode has >= 2 frames, so len(Zn) >= 2 and distinct pairs exist.
    rng = np.random.default_rng(seed)
    i = rng.integers(0, len(Zn), n_pairs)
    j = rng.integers(0, len(Zn) - 1, n_pairs)
    j = j + (j >= i)  # skip over i: uniform over the other frames, never a zero self-distance
    rand = np.linalg.norm(Zn[i] - Zn[j], axis=1)
    return float(np.concatenate(consec).mean() / (rand.mean() + 1e-8))


def _plot_tsne(Z_pca, phase, ep, out):
    """Save a two-panel t-SNE scatter of ``Z_pca``, colored by in-episode phase and by episode id."""
    # sklearn's TSNE rejects perplexity >= n_samples; shrink it for small point sets.
    perplexity = min(30.0, len(Z_pca) - 1.0)
    emb = TSNE(n_components=2, perplexity=perplexity, init="pca",
               random_state=0).fit_transform(Z_pca)
    fig, ax = plt.subplots(1, 2, figsize=(14, 6))
    s0 = ax[0].scatter(emb[:, 0], emb[:, 1], c=phase, cmap="viridis", s=8)
    ax[0].set_title("t-SNE of z_rl — colored by in-episode phase (0=start→1=end)")
    fig.colorbar(s0, ax=ax[0], label="phase")
    ax[1].scatter(emb[:, 0], emb[:, 1], c=ep, cmap="tab20", s=8)
    ax[1].set_title("colored by episode")
    out_dir = os.path.dirname(out)
    if out_dir:  # os.makedirs("") raises, so skip it when --out is a bare filename
        os.makedirs(out_dir, exist_ok=True)
    fig.tight_layout()
    fig.savefig(out, dpi=110)
    plt.close(fig)


def main():
    """Run the phase-3 z_rl validation gate (PCA structure, temporal smoothness, t-SNE plot).

    Loads the autoencoder weights (``--weights``) from ``--ckpt``. Encodes up to
    ``--max-points`` evenly spaced ``*.npz`` shards from ``--shard-dir``; each shard
    holds ``embeddings``, ``episode_id`` and ``step_in_episode``. Prints PCA
    explained-variance and smoothness statistics, saves a t-SNE figure to ``--out``,
    and prints a pass/warn verdict.

    Exits via ``SystemExit`` with a message when the arguments, checkpoint or shard
    directory make the gate impossible to run.
    """
    p = argparse.ArgumentParser(description="PCA + t-SNE validation gate over z_rl.")
    p.add_argument("--ckpt", default="checkpoints/rl_token_encoder_nodrop_best.pt")
    p.add_argument("--shard-dir", default="./encoder_cache_prefix")
    p.add_argument("--max-points", type=int, default=2000)
    p.add_argument("--weights", default="ema", choices=["ema", "model"])
    p.add_argument("--dim", type=int, default=2560,
                   help="embedding dim passed to RLTokenConfig (must match the checkpoint)")
    p.add_argument("--out", default="outputs/tsne_gate.png")
    args = p.parse_args()
    if args.max_points <= 0:
        p.error("--max-points must be a positive integer")

    # ---- load encoder (CPU only) ----
    ck = torch.load(args.ckpt, map_location="cpu")
    if args.weights not in ck:
        raise SystemExit(f"checkpoint {args.ckpt} has no {args.weights!r} entry "
                         f"(available: {list(ck)})")
    ae = RLTokenAutoencoder(RLTokenConfig(dim=args.dim))
    ae.load_state_dict(ck[args.weights])
    ae.eval()
    print(f"loaded {args.ckpt} ({args.weights}) step={ck.get('step')} val_recon={ck.get('val_recon')}")

    # ---- encode an evenly spaced, capped subset of shards ----
    files = _subsample_evenly(sorted(glob.glob(os.path.join(args.shard_dir, "*.npz"))),
                              args.max_points)
    if not files:
        raise SystemExit(f"no *.npz shards found in {args.shard_dir!r}")
    Z, ep, step = _encode_shards(ae, files)
    print(f"encoded {len(Z)} z_rl vectors over {len(np.unique(ep))} episodes")
    if len(Z) < 2:
        raise SystemExit("need at least 2 encoded frames for PCA / smoothness / t-SNE")

    phase = _episode_phase(ep, step)

    # ---- 1. STRUCTURE: how much variance lives in a few principal directions ----
    pca = PCA(n_components=min(50, *Z.shape)).fit(Z)
    ev = pca.explained_variance_ratio_
    print(f"PCA var explained: top-2={ev[:2].sum():.2%} top-5={ev[:5].sum():.2%} top-10={ev[:10].sum():.2%}")

    # ---- 2. SMOOTHNESS: consecutive in-episode frames vs random frame pairs ----
    sm = _smoothness_ratio(Z, ep, step)
    if np.isnan(sm):
        print("SMOOTHNESS: n/a (no episode has >= 2 loaded frames)")
    else:
        print(f"SMOOTHNESS: consec-frame dist / random-pair dist = {sm:.3f} "
              f"(<<1 = smooth temporal trajectory; ~1 = no temporal structure)")

    # ---- 3. VISUAL: t-SNE on the PCA-reduced vectors ----
    _plot_tsne(pca.transform(Z), phase, ep, args.out)
    print(f"saved {args.out}")

    # NaN smoothness compares False, so an unmeasurable run never passes the gate.
    good = ev[:10].sum() > 0.5 and sm < 0.7
    print("GATE:", "✅ structured + smooth (z_rl is task-informative)" if good
          else "⚠️ weak structure/smoothness — inspect the plot")
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with status 0,
    # exactly like letting the interpreter fall off the end after calling main().
    raise SystemExit(main())
|
|