rltoken-encoder / code /train_encoder.py
atharva-pantheon's picture
Upload code/train_encoder.py with huggingface_hub
371dfea verified
#!/usr/bin/env python3
"""Train the RLT Stage-1 encoder/decoder on cached (M,2560) prefix shards.
Uses the EXACT training knobs from the openpi reference (pravsels/openpi PR #6,
the `pi05_rl_token_bin_pack_coffee_capsules` TrainConfig), adapted for our
frozen-VLA / pre-cached-features setting (alpha = 0, so no VLA forward here):
optimizer AdamW, clip_gradient_norm = 1.0
lr schedule linear warmup (1000) -> constant peak_lr = 5e-5
(their CosineDecay had peak_lr == decay_lr == 5e-5 = flat)
ema_decay 0.999 (eval/save from the EMA weights)
loss per-token squared-L2, sum over dim, mean over valid tokens,
targets stop-gradiented (matches rl_token_encoder forward())
Deviations from the reference, on purpose:
* batch_size > 1: their bs=1 was forced by running the full pi05 VLA each step;
our enc/dec is tiny and features are cached, so we batch and pad+mask.
* NO feature standardization (reference reconstructs raw prefix_out). A
--standardize escape hatch is provided but OFF by default to stay faithful.
Run (server must be DOWN first to free VRAM):
./lerobot/.venv/bin/python train_encoder.py \
--shard-dir ./encoder_cache_prefix --out ./checkpoints/rl_token_encoder
"""
from __future__ import annotations
import argparse
import glob
import os
import time
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig
class PrefixShards(Dataset):
    """Map-style dataset over cached prefix shards, one ``.npz`` file per prefix.

    Each shard stores an ``embeddings`` array of shape (M, dim) (float16 per the
    cache writer — TODO confirm). ``__getitem__`` returns it as a float32 tensor.
    M may differ between shards, so batches must be padded and masked.
    """

    def __init__(self, shard_dir: str):
        """Index every ``*.npz`` directly under ``shard_dir`` (``~`` expanded), sorted by path.

        Raises:
            FileNotFoundError: if the directory contains no ``.npz`` files.
        """
        self.paths = sorted(glob.glob(os.path.join(os.path.expanduser(shard_dir), "*.npz")))
        if not self.paths:
            raise FileNotFoundError(f"no .npz shards in {shard_dir}")
        # episode_id per shard (parsed from filename ep{NNNN}_...) for the
        # success/failure t-SNE gate later; cheap to keep around. -1 = unknown.
        self.episodes = [self._ep(p) for p in self.paths]

    @staticmethod
    def _ep(path: str) -> int:
        """Return the episode id from a basename of the form ``ep<digits>...``, else -1.

        Takes the whole run of ASCII digits after ``ep`` rather than a fixed
        4-character slice. So ``ep12345_*`` gives 12345 (a fixed slice gave 1234 and
        collided with neighbouring episodes). ``ep1_23_*`` gives 1 (``int("1_23")``
        accepts the digit-group underscore and would give 123).
        """
        base = os.path.basename(path)
        if not base.startswith("ep"):
            return -1
        tail = base[2:]
        n_digits = len(tail) - len(tail.lstrip("0123456789"))
        return int(tail[:n_digits]) if n_digits else -1

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, i: int) -> torch.Tensor:
        """Load shard ``i`` and return its ``embeddings`` array as a float32 tensor."""
        with np.load(self.paths[i]) as z:
            emb = z["embeddings"].astype(np.float32)  # (M, dim)
        return torch.from_numpy(emb)
def collate(batch: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Stack variable-length (M_i, dim) prefixes into one zero-padded float32 batch.

    Returns ``(padded, valid)``:
      * ``padded`` is (B, max M_i, dim), float32, zeros past each sample's length.
      * ``valid`` is (B, max M_i), bool, True exactly on the real (unpadded) tokens.
    """
    lengths = [item.shape[0] for item in batch]
    feat_dim = batch[0].shape[-1]
    longest = max(lengths)
    n_items = len(batch)
    padded = torch.zeros(n_items, longest, feat_dim, dtype=torch.float32)
    valid = torch.zeros(n_items, longest, dtype=torch.bool)
    for row, (item, n_tok) in enumerate(zip(batch, lengths)):
        padded[row, :n_tok] = item
        valid[row, :n_tok] = True
    return padded, valid
def linear_warmup_then_constant(step: int, warmup: int, peak: float) -> float:
    """Return the learning rate at ``step``: a linear ramp over ``warmup`` steps, then ``peak``.

    During warmup the value is ``peak * (step + 1) / warmup``. Step 0 therefore
    already uses ``peak / warmup``, and step ``warmup - 1`` reaches ``peak`` exactly.
    With ``warmup <= 0`` and a non-negative ``step`` the ramp is skipped entirely.
    """
    in_warmup = step < warmup
    if not in_warmup:
        return peak
    return peak * (step + 1) / warmup
@torch.no_grad()
def ema_update(ema: dict[str, torch.Tensor], model: torch.nn.Module, decay: float) -> None:
    """Blend ``model``'s current state into the shadow copy ``ema``, in place.

    Floating-point entries become ``decay * ema + (1 - decay) * live``. Every other
    entry (e.g. integer buffers) is overwritten with the live value. ``ema`` must
    already hold a tensor for every key of ``model.state_dict()``.
    """
    blend = 1 - decay
    for name, live in model.state_dict().items():
        shadow = ema[name]
        if not live.dtype.is_floating_point:
            shadow.copy_(live)
            continue
        shadow.mul_(decay).add_(live.detach(), alpha=blend)
@torch.no_grad()
def z_rl_structure(model: RLTokenAutoencoder, loader: DataLoader, device: str,
                   max_samples: int = 512) -> dict:
    """Label-free probe of how structured the z_rl latents are.

    NOTE: the old first-token ablation was VACUOUS here — the first prefix token is
    a constant special token (id 151645, std=0), so token-0 recon is trivially
    constant and real==shuffled regardless of z_rl quality. Instead this reports:
      * ``cos``: mean off-diagonal cosine similarity between z_rl vectors
        (collapse: ~1 bad, ~0 diverse). NaN when fewer than 2 samples were encoded.
      * ``pca10``: fraction of centred variance in the top-10 principal components
        (structure: higher = lower-D task manifold).

    Args:
        model: must expose ``encode(x, mask)``. Presumably it returns one z_rl
            vector per sample, shape (B, D) — TODO confirm against rl_token_encoder.
        loader: any iterable of ``(x, mask)`` batches (e.g. a DataLoader). The
            inputs must already be in the space the model was trained on.
        device: device each batch is moved to before encoding.
        max_samples: cap on how many samples are encoded (default 512).

    Returns:
        ``{"cos": float, "pca10": float}``.

    Raises:
        ValueError: if ``max_samples < 1`` or ``loader`` yields no samples.

    The model is put in eval mode for encoding, and its previous train/eval mode
    is restored before returning.
    """
    if max_samples < 1:
        raise ValueError(f"max_samples must be >= 1, got {max_samples}")
    was_training = model.training
    model.eval()
    chunks: list[torch.Tensor] = []
    seen = 0  # running count; avoids re-summing every chunk's size on each batch
    try:
        for x, mask in loader:
            x, mask = x.to(device), mask.to(device)
            z = model.encode(x, mask).float().cpu()
            chunks.append(z)
            seen += z.shape[0]
            if seen >= max_samples:
                break
    finally:
        model.train(was_training)
    if seen == 0:
        raise ValueError("z_rl_structure: loader yielded no samples")
    Z = torch.cat(chunks)[:max_samples]
    n = Z.shape[0]
    Zn = torch.nn.functional.normalize(Z, dim=1)
    if n > 1:
        off_diag = ~torch.eye(n, dtype=torch.bool)
        cos = (Zn @ Zn.T)[off_diag].mean().item()
    else:
        # No pairs to compare; make the undefined case explicit instead of
        # taking the mean of an empty selection.
        cos = float("nan")
    var = torch.linalg.svdvals(Z - Z.mean(0)) ** 2
    # clamp guards the all-identical case (zero total variance) against 0/0.
    pca10 = (var[:10].sum() / var.sum().clamp(min=1e-9)).item()
    return {"cos": cos, "pca10": pca10}
def _feature_stats(dataset: PrefixShards, indices: list[int],
                   dim: int) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Compute per-feature mean/std over every token of ``dataset[i]`` for ``i`` in ``indices``.

    Accumulates sum and sum-of-squares in float32 and uses Var = E[x^2] - E[x]^2.
    The variance is clamped at 1e-6 before the square root so constant features
    never yield a zero std. Returns ``(mean, std, n_tokens)``.
    """
    acc = torch.zeros(dim)
    sq = torch.zeros(dim)
    n_tok = 0
    for idx in indices:
        x = dataset[idx]
        acc += x.sum(0)
        sq += (x * x).sum(0)
        n_tok += x.shape[0]
    mean = acc / n_tok
    std = (sq / n_tok - mean**2).clamp_min(1e-6).sqrt()
    return mean, std, n_tok


def _evaluate(model: torch.nn.Module, loader: DataLoader, norm, device: str,
              tag: str = "eval") -> tuple[float, dict]:
    """Score the currently-loaded weights on ``loader`` and print a one-line report.

    Runs in eval mode without gradients, so any dropout inside the model is off.
    ``norm`` is applied to every input batch exactly as in the training loop, for
    both the reconstruction loss and the z_rl structure probe. Returns
    ``(mean of per-batch val losses, z_rl_structure dict)``. Every batch is weighted
    equally, including a smaller last batch. Leaves the model in eval mode.
    """
    model.eval()
    vlosses = []
    with torch.no_grad():
        for vx, vm in loader:
            vx, vm = norm(vx).to(device), vm.to(device)
            vlosses.append(model(vx, vm)[1].item())
    vmean = float(np.mean(vlosses))
    # Feed the probe normalized batches too, so with --standardize it sees the
    # same input space the model was trained on.
    st = z_rl_structure(model, ((norm(vx), vm) for vx, vm in loader), device)
    structured = st["cos"] < 0.5 and st["pca10"] > 0.3
    print(f" [{tag}] val_recon={vmean:.3f} z_rl: cos={st['cos']:.3f} (low=diverse) "
          f"pca10={st['pca10']:.2%} (high=structured) "
          f"{'✅ structured' if structured else '⚠️ diffuse'}")
    return vmean, st


def main() -> None:
    """CLI entry point: train the RL-token autoencoder on cached prefix shards.

    Writes ``{--out}_best.pt`` (lowest val recon among all evaluated EMA snapshots,
    including the final one) and ``{--out}_final.pt`` (final EMA weights).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--shard-dir", default="./encoder_cache_prefix")
    p.add_argument("--out", default="./checkpoints/rl_token_encoder")
    p.add_argument("--dim", type=int, default=2560)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--num-train-steps", type=int, default=10_000)  # reference default; cap by the gate
    p.add_argument("--peak-lr", type=float, default=5e-5)  # reference
    p.add_argument("--warmup-steps", type=int, default=1_000)  # reference
    p.add_argument("--clip-grad-norm", type=float, default=1.0)  # reference
    p.add_argument("--ema-decay", type=float, default=0.999)  # reference
    p.add_argument("--weight-decay", type=float, default=1e-4)  # AdamW default-ish; ref AdamW unspecified
    p.add_argument("--val-frac", type=float, default=0.1)
    p.add_argument("--eval-every", type=int, default=500)
    p.add_argument("--standardize", action="store_true", help="(off=faithful) z-score features first")
    p.add_argument("--context-dropout", type=float, default=0.0,
                   help="train-only: prob of zeroing each decoder teacher-forced context token, "
                        "forcing info through z_rl (fixes latent collapse / the AR leak). 0=bare reference, 0.5=fix")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--seed", type=int, default=0)
    args = p.parse_args()
    if args.batch_size < 1:
        p.error("--batch-size must be >= 1")
    if args.eval_every < 1:
        p.error("--eval-every must be >= 1")  # used as a modulus in the training loop

    torch.manual_seed(args.seed)
    os.makedirs(os.path.dirname(os.path.abspath(args.out)) or ".", exist_ok=True)

    full = PrefixShards(args.shard_dir)
    n_val = max(1, int(len(full) * args.val_frac))
    n_tr = len(full) - n_val
    # The train loader uses drop_last=True, so it yields n_tr // batch_size batches.
    # Zero batches would make the `while step < num_train_steps` loop spin forever.
    if n_tr < args.batch_size:
        raise SystemExit(
            f"need at least {args.batch_size} training shards (--batch-size), got {n_tr} "
            f"({len(full)} shards, {n_val} held out for val)"
        )
    tr, va = random_split(full, [n_tr, n_val], generator=torch.Generator().manual_seed(args.seed))
    print(f"shards: {len(full)} train: {n_tr} val: {n_val} episodes: {len(set(full.episodes))}")

    # Optional standardization (per-feature mean/std over a sample of train shards).
    mean = std = None
    if args.standardize:
        mean, std, n_tok = _feature_stats(full, list(tr.indices)[:512], args.dim)
        print("standardize ON: feature mean/std computed over", n_tok, "tokens")

    def norm(x: torch.Tensor) -> torch.Tensor:
        return (x - mean) / std if mean is not None else x

    # Pinned host memory only helps host->CUDA copies.
    dl_kw = dict(batch_size=args.batch_size, collate_fn=collate, num_workers=4,
                 pin_memory=str(args.device).startswith("cuda"))
    tr_loader = DataLoader(tr, shuffle=True, drop_last=True, **dl_kw)
    va_loader = DataLoader(va, shuffle=False, **dl_kw)

    model = RLTokenAutoencoder(RLTokenConfig(dim=args.dim)).to(args.device)
    n_params = sum(prm.numel() for prm in model.parameters())
    print(f"model params: {n_params/1e6:.1f}M device: {args.device}")
    opt = torch.optim.AdamW(model.parameters(), lr=args.peak_lr, weight_decay=args.weight_decay)
    ema = {k: v.detach().clone() for k, v in model.state_dict().items()}

    def save(tag: str, extra: dict) -> None:
        torch.save({
            "model": model.state_dict(),
            "ema": ema,
            "cfg": vars(RLTokenConfig(dim=args.dim)),
            "mean": mean, "std": std,
            "args": vars(args),
            **extra,
        }, f"{args.out}_{tag}.pt")

    step = 0
    best_val = float("inf")
    t0 = time.time()
    model.train()
    print(f"training... (AdamW lr={args.peak_lr:g}, warmup={args.warmup_steps}, "
          f"clip={args.clip_grad_norm:g}, ema={args.ema_decay:g})")
    while step < args.num_train_steps:
        for x, mask in tr_loader:
            if step >= args.num_train_steps:
                break
            x, mask = norm(x).to(args.device), mask.to(args.device)
            for g in opt.param_groups:
                g["lr"] = linear_warmup_then_constant(step, args.warmup_steps, args.peak_lr)
            _, loss = model(x, mask, context_dropout=args.context_dropout)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
            opt.step()
            ema_update(ema, model, args.ema_decay)
            if step % 50 == 0:
                print(f"step {step:6d} recon={loss.item():10.3f} lr={opt.param_groups[0]['lr']:.2e}"
                      f" {(step+1)/(time.time()-t0):.1f} it/s")
            if step > 0 and step % args.eval_every == 0:
                # Eval from EMA weights (reference uses EMA for eval/save): swap
                # them in, score, then restore the live weights and train mode.
                live = {k: v.detach().clone() for k, v in model.state_dict().items()}
                model.load_state_dict(ema)
                vmean, st = _evaluate(model, va_loader, norm, args.device)
                if vmean < best_val:
                    best_val = vmean
                    save("best", {"step": step, "val_recon": vmean, "z_rl_structure": st})
                model.load_state_dict(live)
                model.train()
            step += 1

    # The final model is the EMA. Evaluate it as well: the last periodic eval can
    # be up to eval_every-1 steps stale, and none runs at all when
    # num_train_steps <= eval_every (which would otherwise leave no _best.pt).
    model.load_state_dict(ema)
    final_val, final_st = _evaluate(model, va_loader, norm, args.device, tag="final")
    if final_val < best_val:
        best_val = final_val
        save("best", {"step": step, "val_recon": final_val, "z_rl_structure": final_st})
    save("final", {"step": step, "val_recon": best_val, "final_val_recon": final_val})
    print(f"done. best val_recon={best_val:.3f}. saved {args.out}_best.pt / _final.pt")
if __name__ == "__main__":
    # Script entry point: parse CLI args and run training (see main()).
    main()