atharva-pantheon commited on
Commit
371dfea
·
verified ·
1 Parent(s): f9042a0

Upload code/train_encoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/train_encoder.py +238 -0
code/train_encoder.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Train the RLT Stage-1 encoder/decoder on cached (M,2560) prefix shards.
3
+
4
+ Uses the EXACT training knobs from the openpi reference (pravsels/openpi PR #6,
5
+ the `pi05_rl_token_bin_pack_coffee_capsules` TrainConfig), adapted for our
6
+ frozen-VLA / pre-cached-features setting (alpha = 0, so no VLA forward here):
7
+
8
+ optimizer AdamW, clip_gradient_norm = 1.0
9
+ lr schedule linear warmup (1000) -> constant peak_lr = 5e-5
10
+ (their CosineDecay had peak_lr == decay_lr == 5e-5 = flat)
11
+ ema_decay 0.999 (eval/save from the EMA weights)
12
+ loss per-token squared-L2, sum over dim, mean over valid tokens,
13
+ targets stop-gradiented (matches rl_token_encoder forward())
14
+
15
+ Deviations from the reference, on purpose:
16
+ * batch_size > 1: their bs=1 was forced by running the full pi05 VLA each step;
17
+ our enc/dec is tiny and features are cached, so we batch and pad+mask.
18
+ * NO feature standardization (reference reconstructs raw prefix_out). A
19
+ --standardize escape hatch is provided but OFF by default to stay faithful.
20
+
21
+ Run (server must be DOWN first to free VRAM):
22
+ ./lerobot/.venv/bin/python train_encoder.py \
23
+ --shard-dir ./encoder_cache_prefix --out ./checkpoints/rl_token_encoder
24
+ """
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import glob
29
+ import os
30
+ import time
31
+
32
+ import numpy as np
33
+ import torch
34
+ from torch.utils.data import DataLoader, Dataset, random_split
35
+
36
+ from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig
37
+
38
+
39
+ class PrefixShards(Dataset):
40
+ """Each .npz holds `embeddings` (M, dim) float16 — one cached prefix."""
41
+
42
+ def __init__(self, shard_dir: str):
43
+ self.paths = sorted(glob.glob(os.path.join(os.path.expanduser(shard_dir), "*.npz")))
44
+ if not self.paths:
45
+ raise FileNotFoundError(f"no .npz shards in {shard_dir}")
46
+ # episode_id per shard (parsed from filename ep{NNNN}_...) for the
47
+ # success/failure t-SNE gate later; cheap to keep around.
48
+ self.episodes = [self._ep(p) for p in self.paths]
49
+
50
+ @staticmethod
51
+ def _ep(path: str) -> int:
52
+ base = os.path.basename(path)
53
+ if base.startswith("ep"):
54
+ try:
55
+ return int(base[2:6])
56
+ except ValueError:
57
+ pass
58
+ return -1
59
+
60
+ def __len__(self) -> int:
61
+ return len(self.paths)
62
+
63
+ def __getitem__(self, i: int) -> torch.Tensor:
64
+ with np.load(self.paths[i]) as z:
65
+ emb = z["embeddings"].astype(np.float32) # (M, dim)
66
+ return torch.from_numpy(emb)
67
+
68
+
69
+ def collate(batch: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
70
+ """Pad variable-M prefixes to the batch max; build the valid-token mask."""
71
+ dim = batch[0].shape[-1]
72
+ M = max(x.shape[0] for x in batch)
73
+ b = len(batch)
74
+ out = torch.zeros(b, M, dim, dtype=torch.float32)
75
+ mask = torch.zeros(b, M, dtype=torch.bool)
76
+ for i, x in enumerate(batch):
77
+ m = x.shape[0]
78
+ out[i, :m] = x
79
+ mask[i, :m] = True
80
+ return out, mask
81
+
82
+
83
+ def linear_warmup_then_constant(step: int, warmup: int, peak: float) -> float:
84
+ if step < warmup:
85
+ return peak * (step + 1) / warmup
86
+ return peak
87
+
88
+
89
+ @torch.no_grad()
90
+ def ema_update(ema: dict[str, torch.Tensor], model: torch.nn.Module, decay: float) -> None:
91
+ for k, v in model.state_dict().items():
92
+ if v.dtype.is_floating_point:
93
+ ema[k].mul_(decay).add_(v.detach(), alpha=1 - decay)
94
+ else:
95
+ ema[k].copy_(v)
96
+
97
+
98
+ @torch.no_grad()
99
+ def z_rl_structure(model: RLTokenAutoencoder, loader: DataLoader, device: str) -> dict:
100
+ """Valid z_rl probe (label-free). NOTE: the old first-token ablation was VACUOUS
101
+ here — the first prefix token is a constant special token (id 151645, std=0), so
102
+ token-0 recon is trivially constant and real==shuffled regardless of z_rl quality.
103
+ Instead measure (1) cross-sample cosine of z_rl (collapse: ~1 bad, ~0 diverse) and
104
+ (2) PCA top-10 variance ratio (structure: higher = lower-D task manifold)."""
105
+ model.eval()
106
+ Z = []
107
+ for x, mask in loader:
108
+ x, mask = x.to(device), mask.to(device)
109
+ Z.append(model.encode(x, mask).float().cpu())
110
+ if sum(z.shape[0] for z in Z) >= 512:
111
+ break
112
+ Z = torch.cat(Z)[:512]
113
+ Zn = torch.nn.functional.normalize(Z, dim=1)
114
+ n = Z.shape[0]
115
+ cos = (Zn @ Zn.T)[~torch.eye(n, dtype=torch.bool)].mean().item()
116
+ s = torch.linalg.svdvals(Z - Z.mean(0))
117
+ var = s ** 2
118
+ pca10 = (var[:10].sum() / var.sum().clamp(min=1e-9)).item()
119
+ return {"cos": cos, "pca10": pca10}
120
+
121
+
122
+ def main() -> None:
123
+ p = argparse.ArgumentParser()
124
+ p.add_argument("--shard-dir", default="./encoder_cache_prefix")
125
+ p.add_argument("--out", default="./checkpoints/rl_token_encoder")
126
+ p.add_argument("--dim", type=int, default=2560)
127
+ p.add_argument("--batch-size", type=int, default=16)
128
+ p.add_argument("--num-train-steps", type=int, default=10_000) # reference default; cap by the gate
129
+ p.add_argument("--peak-lr", type=float, default=5e-5) # reference
130
+ p.add_argument("--warmup-steps", type=int, default=1_000) # reference
131
+ p.add_argument("--clip-grad-norm", type=float, default=1.0) # reference
132
+ p.add_argument("--ema-decay", type=float, default=0.999) # reference
133
+ p.add_argument("--weight-decay", type=float, default=1e-4) # AdamW default-ish; ref AdamW unspecified
134
+ p.add_argument("--val-frac", type=float, default=0.1)
135
+ p.add_argument("--eval-every", type=int, default=500)
136
+ p.add_argument("--standardize", action="store_true", help="(off=faithful) z-score features first")
137
+ p.add_argument("--context-dropout", type=float, default=0.0,
138
+ help="train-only: prob of zeroing each decoder teacher-forced context token, "
139
+ "forcing info through z_rl (fixes latent collapse / the AR leak). 0=bare reference, 0.5=fix")
140
+ p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
141
+ p.add_argument("--seed", type=int, default=0)
142
+ args = p.parse_args()
143
+
144
+ torch.manual_seed(args.seed)
145
+ os.makedirs(os.path.dirname(os.path.abspath(args.out)) or ".", exist_ok=True)
146
+
147
+ full = PrefixShards(args.shard_dir)
148
+ n_val = max(1, int(len(full) * args.val_frac))
149
+ n_tr = len(full) - n_val
150
+ tr, va = random_split(full, [n_tr, n_val], generator=torch.Generator().manual_seed(args.seed))
151
+ print(f"shards: {len(full)} train: {n_tr} val: {n_val} episodes: {len(set(full.episodes))}")
152
+
153
+ # Optional standardization (per-feature mean/std over a sample of train shards).
154
+ mean = std = None
155
+ if args.standardize:
156
+ acc, c = torch.zeros(args.dim), 0
157
+ sq = torch.zeros(args.dim)
158
+ for idx in list(tr.indices)[:512]:
159
+ x = full[idx]
160
+ acc += x.sum(0); sq += (x * x).sum(0); c += x.shape[0]
161
+ mean = acc / c
162
+ std = (sq / c - mean**2).clamp_min(1e-6).sqrt()
163
+ print("standardize ON: feature mean/std computed over", c, "tokens")
164
+
165
+ def norm(x):
166
+ return (x - mean) / std if mean is not None else x
167
+
168
+ dl_kw = dict(batch_size=args.batch_size, collate_fn=collate, num_workers=4, pin_memory=True)
169
+ tr_loader = DataLoader(tr, shuffle=True, drop_last=True, **dl_kw)
170
+ va_loader = DataLoader(va, shuffle=False, **dl_kw)
171
+
172
+ model = RLTokenAutoencoder(RLTokenConfig(dim=args.dim)).to(args.device)
173
+ n_params = sum(p.numel() for p in model.parameters())
174
+ print(f"model params: {n_params/1e6:.1f}M device: {args.device}")
175
+ opt = torch.optim.AdamW(model.parameters(), lr=args.peak_lr, weight_decay=args.weight_decay)
176
+ ema = {k: v.detach().clone() for k, v in model.state_dict().items()}
177
+
178
+ def save(tag: str, extra: dict) -> None:
179
+ torch.save({
180
+ "model": model.state_dict(),
181
+ "ema": ema,
182
+ "cfg": vars(RLTokenConfig(dim=args.dim)),
183
+ "mean": mean, "std": std,
184
+ "args": vars(args),
185
+ **extra,
186
+ }, f"{args.out}_{tag}.pt")
187
+
188
+ step = 0
189
+ best_val = float("inf")
190
+ t0 = time.time()
191
+ model.train()
192
+ print("training... (their knobs: AdamW lr5e-5, warmup1k, clip1.0, ema0.999)")
193
+ while step < args.num_train_steps:
194
+ for x, mask in tr_loader:
195
+ if step >= args.num_train_steps:
196
+ break
197
+ x, mask = norm(x).to(args.device), mask.to(args.device)
198
+ for g in opt.param_groups:
199
+ g["lr"] = linear_warmup_then_constant(step, args.warmup_steps, args.peak_lr)
200
+ _, loss = model(x, mask, context_dropout=args.context_dropout)
201
+ opt.zero_grad(set_to_none=True)
202
+ loss.backward()
203
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
204
+ opt.step()
205
+ ema_update(ema, model, args.ema_decay)
206
+
207
+ if step % 50 == 0:
208
+ print(f"step {step:6d} recon={loss.item():10.3f} lr={opt.param_groups[0]['lr']:.2e}"
209
+ f" {(step+1)/(time.time()-t0):.1f} it/s")
210
+ if step > 0 and step % args.eval_every == 0:
211
+ # eval from EMA weights (reference uses EMA for eval/save)
212
+ live = {k: v.detach().clone() for k, v in model.state_dict().items()}
213
+ model.load_state_dict(ema)
214
+ vlosses = []
215
+ with torch.no_grad():
216
+ for vx, vm in va_loader:
217
+ vx, vm = norm(vx).to(args.device), vm.to(args.device)
218
+ vlosses.append(model(vx, vm)[1].item())
219
+ vmean = float(np.mean(vlosses))
220
+ st = z_rl_structure(model, va_loader, args.device)
221
+ structured = st["cos"] < 0.5 and st["pca10"] > 0.3
222
+ print(f" [eval] val_recon={vmean:.3f} z_rl: cos={st['cos']:.3f} (low=diverse) "
223
+ f"pca10={st['pca10']:.2%} (high=structured) "
224
+ f"{'✅ structured' if structured else '⚠️ diffuse'}")
225
+ if vmean < best_val:
226
+ best_val = vmean
227
+ save("best", {"step": step, "val_recon": vmean, "z_rl_structure": st})
228
+ model.load_state_dict(live)
229
+ model.train()
230
+ step += 1
231
+
232
+ model.load_state_dict(ema)
233
+ save("final", {"step": step, "val_recon": best_val})
234
+ print(f"done. best val_recon={best_val:.3f}. saved {args.out}_best.pt / _final.pt")
235
+
236
+
237
+ if __name__ == "__main__":
238
+ main()