#!/usr/bin/env python3
"""Phase-3 validation gate: PCA + t-SNE over z_rl (the RL Token paper's actual test).
The paper validates the encoder NOT by reconstruction error but by visualizing the
token: ordinary pick-place frames should form a *smooth curve* across episodes
(and, with failures, failed grasps / collapses cluster apart). We only have success
demos, so we test the success-side signature:
1. STRUCTURE — PCA: how much variance lives in a few dims (low-D manifold?)
2. SMOOTHNESS — do consecutive in-episode frames stay close in z_rl space
(smooth trajectory), vs random frame pairs?
3. VISUAL — t-SNE 2D, colored by normalized in-episode phase (0=start→1=end);
a clean task should trace a consistent start→end sweep.
CPU-only (won't fight a GPU training run). Run:
./lerobot/.venv/bin/python tsne_gate.py --ckpt checkpoints/rl_token_encoder_nodrop_best.pt
"""
from __future__ import annotations
import argparse, glob, os
import numpy as np, torch
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig
def _subsample_evenly(paths, limit):
    """Return at most ``limit`` entries of ``paths``, evenly spaced, order kept."""
    n = len(paths)
    if n <= limit:
        return list(paths)
    # i * n // limit is strictly increasing when n > limit, so the picks are
    # distinct and exactly ``limit`` long. A plain ``[:: n // limit]`` stride
    # can return almost twice the requested number (e.g. 3999 files, limit
    # 2000 -> stride 1 -> all 3999 kept).
    return [paths[i * n // limit] for i in range(limit)]


def _encode_shards(ae, paths):
    """Encode each ``.npz`` shard with ``ae.encode`` and collect its metadata.

    Each shard must provide the keys ``embeddings``, ``episode_id`` and
    ``step_in_episode``.

    Returns:
        ``(Z, ep, step)`` NumPy arrays with one row / entry per shard.
    """
    Z, ep, step = [], [], []
    with torch.no_grad():
        for path in paths:
            # NpzFile keeps the archive open; close it per shard instead of
            # leaving thousands of handles to the garbage collector.
            with np.load(path) as shard:
                # astype() copies, so the tensor stays valid after the close.
                e = torch.from_numpy(shard["embeddings"].astype(np.float32))[None]
                ep.append(int(shard["episode_id"]))
                step.append(int(shard["step_in_episode"]))
            # All-ones boolean mask over axis 1 of the batched embeddings.
            # NOTE(review): presumably the token axis — confirm against
            # RLTokenAutoencoder.encode.
            m = torch.ones(1, e.shape[1], dtype=torch.bool)
            Z.append(ae.encode(e, m)[0].numpy())
    return np.stack(Z), np.array(ep), np.array(step)


def _episode_phase(ep, step):
    """Map each frame's step to [0, 1] within its own episode."""
    phase = np.zeros(len(ep))
    for e in np.unique(ep):
        idx = ep == e
        s = step[idx]
        # max(1, ...) keeps single-frame episodes at phase 0 instead of 0/0.
        phase[idx] = (s - s.min()) / max(1, s.max() - s.min())
    return phase


def _smoothness_ratio(Z, ep, step, n_random=2000, seed=0):
    """Mean neighbour distance within episodes divided by mean random-pair distance.

    Distances are taken between L2-normalised rows of ``Z``. "Neighbour" means
    adjacent after sorting an episode's *loaded* frames by step; if the shard
    list was subsampled these are not necessarily adjacent original frames.

    Returns NaN when no episode contributes at least two frames.
    """
    Zn = Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-8)
    consec = []
    for e in np.unique(ep):
        idx = np.where(ep == e)[0]
        idx = idx[np.argsort(step[idx])]
        # Two frames already give one valid consecutive distance.
        if len(idx) >= 2:
            consec.extend(np.linalg.norm(np.diff(Zn[idx], axis=0), axis=1))
    if not consec:
        return float("nan")
    rng = np.random.default_rng(seed)
    rand = []
    for _ in range(n_random):
        i, j = rng.integers(0, len(Z), 2)
        rand.append(np.linalg.norm(Zn[i] - Zn[j]))
    return float(np.mean(consec) / (np.mean(rand) + 1e-8))


def main():
    """Run the z_rl validation gate from the command line.

    Loads the checkpoint, encodes up to ``--max-points`` cached shards, then
    prints the PCA explained-variance summary and the smoothness ratio, saves
    a two-panel t-SNE figure to ``--out`` and prints a heuristic GATE verdict.

    Raises:
        SystemExit: on an invalid ``--max-points``, when no shards are found,
            or when fewer than two points could be encoded.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", default="checkpoints/rl_token_encoder_nodrop_best.pt")
    p.add_argument("--shard-dir", default="./encoder_cache_prefix")
    p.add_argument("--max-points", type=int, default=2000)
    p.add_argument("--weights", default="ema", choices=["ema", "model"])
    p.add_argument("--out", default="outputs/tsne_gate.png")
    args = p.parse_args()
    if args.max_points < 2:
        p.error("--max-points must be at least 2")

    # NOTE(security): torch.load unpickles the file; only point --ckpt at
    # checkpoints from a trusted source.
    ck = torch.load(args.ckpt, map_location="cpu")
    ae = RLTokenAutoencoder(RLTokenConfig(dim=2560))
    ae.load_state_dict(ck[args.weights])
    ae.eval()
    print(f"loaded {args.ckpt} ({args.weights}) step={ck.get('step')} val_recon={ck.get('val_recon')}")

    fs = sorted(glob.glob(os.path.join(args.shard_dir, "*.npz")))
    if not fs:
        raise SystemExit(f"no .npz shards found in {args.shard_dir}")
    fs = _subsample_evenly(fs, args.max_points)

    Z, ep, step = _encode_shards(ae, fs)
    print(f"encoded {len(Z)} z_rl vectors over {len(set(ep.tolist()))} episodes")
    if len(Z) < 2:
        raise SystemExit("need at least 2 encoded points for PCA / t-SNE")

    # normalized in-episode phase 0..1
    phase = _episode_phase(ep, step)

    # 1) STRUCTURE
    pca = PCA(n_components=min(50, Z.shape[0], Z.shape[1])).fit(Z)
    ev = pca.explained_variance_ratio_
    print(f"PCA var explained: top-2={ev[:2].sum():.2%} top-5={ev[:5].sum():.2%} top-10={ev[:10].sum():.2%}")

    # 2) SMOOTHNESS: mean ||Δz|| between consecutive in-episode frames vs random pairs
    sm = _smoothness_ratio(Z, ep, step)
    print(f"SMOOTHNESS: consec-frame dist / random-pair dist = {sm:.3f} "
          f"(<<1 = smooth temporal trajectory; ~1 = no temporal structure)")

    # 3) t-SNE plot
    Zp = pca.transform(Z)  # already limited to at most 50 components by the fit above
    # scikit-learn requires perplexity < n_samples; clamp for small shard sets.
    perplexity = min(30, len(Zp) - 1)
    emb = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=0).fit_transform(Zp)
    fig, ax = plt.subplots(1, 2, figsize=(14, 6))
    s0 = ax[0].scatter(emb[:, 0], emb[:, 1], c=phase, cmap="viridis", s=8)
    ax[0].set_title("t-SNE of z_rl — colored by in-episode phase (0=start→1=end)")
    plt.colorbar(s0, ax=ax[0], label="phase")
    ax[1].scatter(emb[:, 0], emb[:, 1], c=ep, cmap="tab20", s=8)
    ax[1].set_title("colored by episode")
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(args.out, dpi=110)
    plt.close(fig)
    print(f"saved {args.out}")

    # verdict heuristic (a NaN smoothness ratio compares False and fails the gate)
    good = ev[:10].sum() > 0.5 and sm < 0.7
    print("GATE:", "✅ structured + smooth (z_rl is task-informative)" if good
          else "⚠️ weak structure/smoothness — inspect the plot")
# Script entry point: run the gate only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|