atharva-pantheon commited on
Commit
e9e2a68
·
verified ·
1 Parent(s): edbd124

Upload code/tsne_gate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/tsne_gate.py +100 -0
code/tsne_gate.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Phase-3 validation gate: PCA + t-SNE over z_rl (the RL Token paper's actual test).
3
+
4
+ The paper validates the encoder NOT by reconstruction error but by visualizing the
5
+ token: ordinary pick-place frames should form a *smooth curve* across episodes
6
+ (and, with failures, failed grasps / collapses cluster apart). We only have success
7
+ demos, so we test the success-side signature:
8
+ 1. STRUCTURE — PCA: how much variance lives in a few dims (low-D manifold?)
9
+ 2. SMOOTHNESS — do consecutive in-episode frames stay close in z_rl space
10
+ (smooth trajectory), vs random frame pairs?
11
+ 3. VISUAL — t-SNE 2D, colored by normalized in-episode phase (0=start→1=end);
12
+ a clean task should trace a consistent start→end sweep.
13
+
14
+ CPU-only (won't fight a GPU training run). Run:
15
+ ./lerobot/.venv/bin/python tsne_gate.py --ckpt checkpoints/rl_token_encoder_nodrop_best.pt
16
+ """
17
+ from __future__ import annotations
18
+ import argparse, glob, os
19
+ import numpy as np, torch
20
+ import matplotlib; matplotlib.use("Agg")
21
+ import matplotlib.pyplot as plt
22
+ from sklearn.decomposition import PCA
23
+ from sklearn.manifold import TSNE
24
+ from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig
25
+
26
+
27
+ def main():
28
+ p = argparse.ArgumentParser()
29
+ p.add_argument("--ckpt", default="checkpoints/rl_token_encoder_nodrop_best.pt")
30
+ p.add_argument("--shard-dir", default="./encoder_cache_prefix")
31
+ p.add_argument("--max-points", type=int, default=2000)
32
+ p.add_argument("--weights", default="ema", choices=["ema", "model"])
33
+ p.add_argument("--out", default="outputs/tsne_gate.png")
34
+ args = p.parse_args()
35
+
36
+ ck = torch.load(args.ckpt, map_location="cpu")
37
+ ae = RLTokenAutoencoder(RLTokenConfig(dim=2560))
38
+ ae.load_state_dict(ck[args.weights]); ae.eval()
39
+ print(f"loaded {args.ckpt} ({args.weights}) step={ck.get('step')} val_recon={ck.get('val_recon')}")
40
+
41
+ fs = sorted(glob.glob(os.path.join(args.shard_dir, "*.npz")))
42
+ if len(fs) > args.max_points:
43
+ fs = fs[:: len(fs) // args.max_points] # even stride across the whole set
44
+ Z, ep, step = [], [], []
45
+ with torch.no_grad():
46
+ for f in fs:
47
+ z = np.load(f)
48
+ e = torch.from_numpy(z["embeddings"].astype(np.float32))[None]
49
+ m = torch.ones(1, e.shape[1], dtype=torch.bool)
50
+ Z.append(ae.encode(e, m)[0].numpy())
51
+ ep.append(int(z["episode_id"])); step.append(int(z["step_in_episode"]))
52
+ Z = np.stack(Z); ep = np.array(ep); step = np.array(step)
53
+ print(f"encoded {len(Z)} z_rl vectors over {len(set(ep.tolist()))} episodes")
54
+
55
+ # normalized in-episode phase 0..1
56
+ phase = np.zeros(len(Z))
57
+ for e in set(ep.tolist()):
58
+ idx = ep == e; s = step[idx]
59
+ phase[idx] = (s - s.min()) / max(1, (s.max() - s.min()))
60
+
61
+ # 1) STRUCTURE
62
+ pca = PCA(n_components=min(50, Z.shape[0], Z.shape[1])).fit(Z)
63
+ ev = pca.explained_variance_ratio_
64
+ print(f"PCA var explained: top-2={ev[:2].sum():.2%} top-5={ev[:5].sum():.2%} top-10={ev[:10].sum():.2%}")
65
+
66
+ # 2) SMOOTHNESS: mean ||Δz|| between consecutive in-episode frames vs random pairs
67
+ Zn = Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-8)
68
+ consec, rand = [], []
69
+ for e in set(ep.tolist()):
70
+ idx = np.where(ep == e)[0]; idx = idx[np.argsort(step[idx])]
71
+ if len(idx) > 2:
72
+ consec += list(np.linalg.norm(np.diff(Zn[idx], axis=0), axis=1))
73
+ rng = np.random.default_rng(0)
74
+ for _ in range(2000):
75
+ i, j = rng.integers(0, len(Z), 2)
76
+ rand.append(np.linalg.norm(Zn[i] - Zn[j]))
77
+ sm = np.mean(consec) / (np.mean(rand) + 1e-8)
78
+ print(f"SMOOTHNESS: consec-frame dist / random-pair dist = {sm:.3f} "
79
+ f"(<<1 = smooth temporal trajectory; ~1 = no temporal structure)")
80
+
81
+ # 3) t-SNE plot
82
+ Zp = pca.transform(Z)[:, : min(50, Z.shape[1])]
83
+ emb = TSNE(n_components=2, perplexity=30, init="pca", random_state=0).fit_transform(Zp)
84
+ fig, ax = plt.subplots(1, 2, figsize=(14, 6))
85
+ s0 = ax[0].scatter(emb[:, 0], emb[:, 1], c=phase, cmap="viridis", s=8)
86
+ ax[0].set_title("t-SNE of z_rl — colored by in-episode phase (0=start→1=end)")
87
+ plt.colorbar(s0, ax=ax[0], label="phase")
88
+ s1 = ax[1].scatter(emb[:, 0], emb[:, 1], c=ep, cmap="tab20", s=8)
89
+ ax[1].set_title("colored by episode")
90
+ os.makedirs(os.path.dirname(args.out), exist_ok=True)
91
+ plt.tight_layout(); plt.savefig(args.out, dpi=110)
92
+ print(f"saved {args.out}")
93
+ # verdict heuristic
94
+ good = ev[:10].sum() > 0.5 and sm < 0.7
95
+ print("GATE:", "✅ structured + smooth (z_rl is task-informative)" if good
96
+ else "⚠️ weak structure/smoothness — inspect the plot")
97
+
98
+
99
+ if __name__ == "__main__":
100
+ main()