File size: 4,729 Bytes
e9e2a68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python3
"""Phase-3 validation gate: PCA + t-SNE over z_rl (the RL Token paper's actual test).

The paper validates the encoder NOT by reconstruction error but by visualizing the
token: ordinary pick-place frames should form a *smooth curve* across episodes
(and, with failures, failed grasps / collapses cluster apart). We only have success
demos, so we test the success-side signature:
  1. STRUCTURE   — PCA: how much variance lives in a few dims (low-D manifold?)
  2. SMOOTHNESS  — do consecutive in-episode frames stay close in z_rl space
                   (smooth trajectory), vs random frame pairs?
  3. VISUAL      — t-SNE 2D, colored by normalized in-episode phase (0=start→1=end);
                   a clean task should trace a consistent start→end sweep.

CPU-only (won't fight a GPU training run). Run:
  ./lerobot/.venv/bin/python tsne_gate.py --ckpt checkpoints/rl_token_encoder_nodrop_best.pt
"""
from __future__ import annotations
import argparse, glob, os
import numpy as np, torch
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig


def _select_evenly(items: list, k: int) -> list:
    """Return at most ``k`` elements spread evenly across ``items``, keeping order.

    A non-positive ``k`` disables the cap and returns every element.
    """
    if k <= 0 or len(items) <= k:
        return list(items)
    # linspace hits both ends and never yields more than k indices. The naive
    # ``items[::len(items) // k]`` stride can return up to ~2k items
    # (e.g. 3999 items with k=2000 -> stride 1 -> all 3999).
    picks = np.linspace(0, len(items) - 1, k).round().astype(int)
    return [items[i] for i in np.unique(picks)]


def _encode_shards(ae, files: list) -> tuple:
    """Encode each ``.npz`` shard in ``files`` into one z_rl vector with ``ae``.

    Each shard is read for its ``embeddings``, ``episode_id`` and
    ``step_in_episode`` entries.

    Returns:
        ``(Z, episode_ids, steps)``: ``Z`` stacks one ``ae.encode`` output per
        shard (row order follows ``files``); the other two are int arrays.
    """
    zs, eps, steps = [], [], []
    with torch.no_grad():
        for path in files:
            # ``with`` closes the NpzFile handle; astype() copies the array, so
            # the tensor remains valid after the file is closed.
            with np.load(path) as shard:
                emb = torch.from_numpy(shard["embeddings"].astype(np.float32))[None]
                eps.append(int(shard["episode_id"]))
                steps.append(int(shard["step_in_episode"]))
            # All-True mask over axis 1 of the batched embeddings: every token
            # is treated as valid (assumes shards carry no padding — TODO confirm).
            mask = torch.ones(1, emb.shape[1], dtype=torch.bool)
            zs.append(ae.encode(emb, mask)[0].numpy())
    return np.stack(zs), np.array(eps), np.array(steps)


def _episode_phase(ep: np.ndarray, step: np.ndarray) -> np.ndarray:
    """Map each frame's step to its normalized position 0..1 within its episode."""
    phase = np.zeros(len(ep))
    for e in np.unique(ep):
        sel = ep == e
        s = step[sel]
        # max(1, ...) avoids 0/0 when an episode contributes a single step.
        phase[sel] = (s - s.min()) / max(1, s.max() - s.min())
    return phase


def _smoothness_ratio(Z: np.ndarray, ep: np.ndarray, step: np.ndarray,
                      n_pairs: int = 2000, seed: int = 0) -> float:
    """Mean consecutive-frame distance divided by mean random-pair distance.

    Rows of ``Z`` are L2-normalized first. "Consecutive" means adjacent in
    step order among the frames *present* in ``Z`` for an episode; after
    subsampling these may be several raw steps apart.

    Returns:
        The ratio (<<1 means a smooth temporal trajectory, ~1 means none), or
        ``nan`` when no episode has at least two frames or ``Z`` has < 2 rows.
    """
    Zn = Z / (np.linalg.norm(Z, axis=1, keepdims=True) + 1e-8)
    consec = []
    for e in np.unique(ep):
        idx = np.flatnonzero(ep == e)
        idx = idx[np.argsort(step[idx], kind="stable")]
        # Two frames already give one consecutive pair (the old `> 2` dropped it).
        if len(idx) >= 2:
            consec.append(np.linalg.norm(np.diff(Zn[idx], axis=0), axis=1))
    n = len(Z)
    if not consec or n < 2:
        return float("nan")
    rng = np.random.default_rng(seed)
    i = rng.integers(0, n, n_pairs)
    # Offset in [1, n-1] guarantees j != i, so no zero-distance self-pairs
    # deflate the random baseline; j stays uniform over the other rows.
    j = (i + rng.integers(1, n, n_pairs)) % n
    rand = np.linalg.norm(Zn[i] - Zn[j], axis=1)
    return float(np.concatenate(consec).mean() / (rand.mean() + 1e-8))


def main():
    """Run the three-part z_rl gate: PCA structure, smoothness, t-SNE plot.

    Loads encoder weights (``--weights`` key) from ``--ckpt``, and encodes at
    most ``--max-points`` shards sampled evenly from ``--shard-dir``. It then
    prints PCA / smoothness metrics and a heuristic verdict, and writes a
    two-panel t-SNE figure to ``--out``. A non-positive ``--max-points`` uses
    every shard.

    Raises:
        SystemExit: if fewer than two shards are selected (PCA / t-SNE need
            at least two points).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", default="checkpoints/rl_token_encoder_nodrop_best.pt")
    p.add_argument("--shard-dir", default="./encoder_cache_prefix")
    p.add_argument("--max-points", type=int, default=2000)
    p.add_argument("--weights", default="ema", choices=["ema", "model"])
    p.add_argument("--out", default="outputs/tsne_gate.png")
    args = p.parse_args()

    ck = torch.load(args.ckpt, map_location="cpu")
    ae = RLTokenAutoencoder(RLTokenConfig(dim=2560))
    ae.load_state_dict(ck[args.weights]); ae.eval()
    print(f"loaded {args.ckpt} ({args.weights})  step={ck.get('step')}  val_recon={ck.get('val_recon')}")

    fs = _select_evenly(sorted(glob.glob(os.path.join(args.shard_dir, "*.npz"))),
                        args.max_points)
    if len(fs) < 2:
        raise SystemExit(f"need at least 2 .npz shards to analyse; got {len(fs)} "
                         f"from {args.shard_dir!r} (--max-points={args.max_points})")
    Z, ep, step = _encode_shards(ae, fs)
    print(f"encoded {len(Z)} z_rl vectors over {len(np.unique(ep))} episodes")

    phase = _episode_phase(ep, step)

    # 1) STRUCTURE
    pca = PCA(n_components=min(50, Z.shape[0], Z.shape[1])).fit(Z)
    ev = pca.explained_variance_ratio_
    print(f"PCA var explained: top-2={ev[:2].sum():.2%}  top-5={ev[:5].sum():.2%}  top-10={ev[:10].sum():.2%}")

    # 2) SMOOTHNESS: mean ||Δz|| between consecutive in-episode frames vs random pairs
    sm = _smoothness_ratio(Z, ep, step)
    print(f"SMOOTHNESS: consec-frame dist / random-pair dist = {sm:.3f}  "
          f"(<<1 = smooth temporal trajectory; ~1 = no temporal structure)")
    if np.isnan(sm):
        print("  (no episode has >=2 sampled frames; raise --max-points to measure smoothness)")

    # 3) t-SNE plot (the fitted PCA already caps the input at <= 50 dims)
    # sklearn requires perplexity < n_samples; keep 30 unless the run is tiny.
    perplexity = min(30.0, float(len(Z) - 1))
    emb = TSNE(n_components=2, perplexity=perplexity, init="pca",
               random_state=0).fit_transform(pca.transform(Z))
    fig, ax = plt.subplots(1, 2, figsize=(14, 6))
    s0 = ax[0].scatter(emb[:, 0], emb[:, 1], c=phase, cmap="viridis", s=8)
    ax[0].set_title("t-SNE of z_rl — colored by in-episode phase (0=start→1=end)")
    fig.colorbar(s0, ax=ax[0], label="phase")
    ax[1].scatter(emb[:, 0], emb[:, 1], c=ep, cmap="tab20", s=8)
    ax[1].set_title("colored by episode")
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.tight_layout(); fig.savefig(args.out, dpi=110)
    plt.close(fig)
    print(f"saved {args.out}")

    # verdict heuristic (a nan smoothness ratio compares False -> warning)
    good = ev[:10].sum() > 0.5 and sm < 0.7
    print("GATE:", "✅ structured + smooth (z_rl is task-informative)" if good
          else "⚠️ weak structure/smoothness — inspect the plot")


if __name__ == "__main__":
    # Run the gate only when executed as a script, not when imported.
    main()