Upload code/rl_token_encoder.py with huggingface_hub
Browse files- code/rl_token_encoder.py +194 -0
code/rl_token_encoder.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RL Token encoder-decoder for MolmoAct2 (RLT Stage 1) — PyTorch port.
|
| 2 |
+
|
| 3 |
+
Faithful port of openpi's ``pi0_rl.py`` (Xu et al. 2025, "RL Tokens") to PyTorch
|
| 4 |
+
for the frozen MolmoAct2 lerobot fork. Differences from my earlier
|
| 5 |
+
``rlt_logit_autoencoder.py`` (which was wrong): that one MLP-reconstructed the
|
| 6 |
+
2048-D action logits; THIS reconstructs the VLA's **per-token prefix hidden
|
| 7 |
+
states** ``(M, dim)`` with a transformer encoder + autoregressive decoder, so
|
| 8 |
+
the single ``z_rl`` token is forced to regenerate the whole prefix — the real
|
| 9 |
+
RLT bottleneck, and what todo Phase 3 specifies.
|
| 10 |
+
|
| 11 |
+
Design (matches the reference):
|
| 12 |
+
encoder: append a learned <rl> query to the prefix embeddings (b, M, dim),
|
| 13 |
+
run bidirectional pre-norm transformer blocks (RMSNorm + SwiGLU),
|
| 14 |
+
read the query position -> z_rl (b, dim).
|
| 15 |
+
decoder: autoregressive. input [z_rl, z̄_1 … z̄_{M-1}], causal mask,
|
| 16 |
+
predict [z̄_1 … z̄_M]; output_proj.
|
| 17 |
+
loss: per-token squared-L2 recon (sum over dim, masked mean over tokens),
|
| 18 |
+
targets stop-gradiented. VLA is a frozen server here, so there is no
|
| 19 |
+
L_vla term (alpha = 0): we only train the encoder/decoder.
|
| 20 |
+
|
| 21 |
+
z_rl is full-dim (= dim), exactly like the reference — the bottleneck is the
|
| 22 |
+
sequence compression (M tokens -> 1), not a narrow feature dim. Downstream SAC
|
| 23 |
+
consumes z_rl as its (frozen) RLT state.
|
| 24 |
+
"""
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import math
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class RLTokenConfig:
    """Hyper-parameters for the RL-token encoder/decoder transformer stacks."""

    # Token width. The default matches the cached MolmoAct2 VLM embeddings,
    # which are 2560-D.
    dim: int = 2560
    # Number of transformer blocks stacked in the encoder and in the decoder.
    num_layers: int = 2
    # Attention heads per block; must divide ``dim`` (2560 / 8 = 320 per head).
    num_heads: int = 8
    # Hidden width of the SwiGLU feed-forward network inside each block.
    mlp_dim: int = 8192
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class _Block(nn.Module):
|
| 44 |
+
"""Pre-norm transformer block: MHA + SwiGLU FFN, RMSNorm. Matches the ref."""
|
| 45 |
+
|
| 46 |
+
def __init__(self, dim: int, num_heads: int, mlp_dim: int):
|
| 47 |
+
super().__init__()
|
| 48 |
+
assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
|
| 49 |
+
self.num_heads = num_heads
|
| 50 |
+
self.head_dim = dim // num_heads
|
| 51 |
+
self.attn_norm = nn.RMSNorm(dim)
|
| 52 |
+
self.q_proj = nn.Linear(dim, dim, bias=False)
|
| 53 |
+
self.k_proj = nn.Linear(dim, dim, bias=False)
|
| 54 |
+
self.v_proj = nn.Linear(dim, dim, bias=False)
|
| 55 |
+
self.o_proj = nn.Linear(dim, dim, bias=False)
|
| 56 |
+
self.ffn_norm = nn.RMSNorm(dim)
|
| 57 |
+
self.ffn_gate = nn.Linear(dim, mlp_dim, bias=False)
|
| 58 |
+
self.ffn_up = nn.Linear(dim, mlp_dim, bias=False)
|
| 59 |
+
self.ffn_down = nn.Linear(mlp_dim, dim, bias=False)
|
| 60 |
+
|
| 61 |
+
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None) -> torch.Tensor:
|
| 62 |
+
b, s, d = x.shape
|
| 63 |
+
h = self.attn_norm(x)
|
| 64 |
+
q = self.q_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2) # (b,nh,s,hd)
|
| 65 |
+
k = self.k_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
|
| 66 |
+
v = self.v_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
|
| 67 |
+
# attn_mask: (b, s, s) bool, True = attend. -> (b,1,s,s) for SDPA additive.
|
| 68 |
+
am = None
|
| 69 |
+
if attn_mask is not None:
|
| 70 |
+
am = torch.zeros(b, 1, s, s, dtype=x.dtype, device=x.device)
|
| 71 |
+
am = am.masked_fill(~attn_mask[:, None, :, :], float("-inf"))
|
| 72 |
+
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=am) # (b,nh,s,hd)
|
| 73 |
+
attn = attn.transpose(1, 2).reshape(b, s, d)
|
| 74 |
+
x = x + self.o_proj(attn)
|
| 75 |
+
h = self.ffn_norm(x)
|
| 76 |
+
x = x + self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h))
|
| 77 |
+
return x
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class RLTokenEncoder(nn.Module):
    """Compress prefix embeddings (b, M, dim) -> z_rl (b, dim) via a learned query."""

    def __init__(self, cfg: RLTokenConfig):
        """Create the learned ``<rl>`` query and ``cfg.num_layers`` transformer blocks."""
        super().__init__()
        self.rl_query = nn.Parameter(torch.randn(1, 1, cfg.dim) * 0.02)
        self.layers = nn.ModuleList(
            _Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers)
        )

    def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Encode a prefix into a single RL token.

        Args:
            prefix: Prefix embeddings of shape ``(b, M, dim)``.
            mask: Optional ``(b, M)`` validity mask, truthy = real token. It is
                coerced to bool, so 0/1 integer masks are accepted. ``None``
                treats every position as valid.

        Returns:
            ``z_rl`` of shape ``(b, dim)``: the final hidden state at the
            appended query position.
        """
        b, m, d = prefix.shape
        query = self.rl_query.expand(b, 1, d)
        x = torch.cat([prefix, query], dim=1)  # (b, M+1, dim)

        attn_mask = None
        if mask is not None:
            query_valid = torch.ones(b, 1, dtype=torch.bool, device=mask.device)
            valid = torch.cat([mask.to(torch.bool), query_valid], dim=1)  # (b, M+1)
            # Mask KEYS only. Padded query rows still attend to the valid keys,
            # so no attention row is ever fully masked (the query token is
            # always a valid key). A fully masked row can come back as NaN from
            # SDPA, and that NaN would then reach every row of the next layer
            # through K/V. Padded rows are never used as keys and only the
            # query position is returned, so z_rl is unaffected by their values.
            attn_mask = valid[:, None, :].expand(b, m + 1, m + 1)

        for layer in self.layers:
            x = layer(x, attn_mask)
        return x[:, -1, :]  # z_rl at the query position
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class RLTokenDecoder(nn.Module):
    """Autoregressively reconstruct prefix embeddings from z_rl."""

    def __init__(self, cfg: RLTokenConfig):
        """Create ``cfg.num_layers`` transformer blocks and a bias-free output projection."""
        super().__init__()
        self.layers = nn.ModuleList(
            _Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers)
        )
        self.output_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def forward(
        self,
        z_rl: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor | None = None,
        context_dropout: float = 0.0,
    ) -> torch.Tensor:
        """Teacher-forced reconstruction of ``target`` conditioned on ``z_rl``.

        The decoder input is ``[z_rl, target_1 .. target_{M-1}]`` under a causal
        mask, and output position ``i`` predicts ``target_i``.

        Args:
            z_rl: RL token of shape ``(b, dim)``.
            target: Prefix embeddings to reconstruct, shape ``(b, M, dim)``.
            mask: Optional ``(b, M)`` validity mask, truthy = real token. It is
                coerced to bool, so 0/1 integer masks are accepted.
            context_dropout: Probability of zeroing each teacher-forced context
                token. Only applied in training mode; ``0.0`` disables it.

        Returns:
            Predicted embeddings of shape ``(b, M, dim)``.
        """
        b, m, d = target.shape
        ctx = target[:, :-1, :]
        # Context dropout (train only): randomly zero teacher-forced context
        # tokens so the decoder cannot reconstruct purely from the true
        # previous token and has to route information through z_rl.
        if self.training and context_dropout > 0.0:
            keep = torch.rand(b, m - 1, 1, device=target.device) >= context_dropout
            ctx = ctx * keep.to(target.dtype)
        dec_in = torch.cat([z_rl[:, None, :], ctx], dim=1)  # (b, M, dim)

        causal = torch.tril(torch.ones(m, m, dtype=torch.bool, device=target.device))[None]  # (1,M,M)
        if mask is None:
            attn_mask = causal.expand(b, m, m)
        else:
            # Decoder input slot 0 holds z_rl (always valid); slot j >= 1 holds
            # target token j-1, so key validity is the mask shifted right by one.
            # Coercing to bool keeps `&` a logical AND for integer masks too.
            z_valid = torch.ones(b, 1, dtype=torch.bool, device=mask.device)
            key_valid = torch.cat([z_valid, mask.to(torch.bool)[:, :-1]], dim=1)
            attn_mask = causal & key_valid[:, None, :]  # (b, M, M)

        x = dec_in
        for layer in self.layers:
            x = layer(x, attn_mask)
        return self.output_proj(x)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class RLTokenAutoencoder(nn.Module):
    """Encoder + decoder. forward() returns (z_rl, recon_loss) for training."""

    def __init__(self, cfg: RLTokenConfig | None = None):
        """Build the encoder and decoder from ``cfg`` (defaults to ``RLTokenConfig()``)."""
        super().__init__()
        self.cfg = cfg if cfg is not None else RLTokenConfig()
        self.encoder = RLTokenEncoder(self.cfg)
        self.decoder = RLTokenDecoder(self.cfg)

    def encode(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Return ``z_rl`` of shape ``(b, dim)`` for ``prefix`` of shape ``(b, M, dim)``."""
        return self.encoder(prefix, mask)

    def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None, context_dropout: float = 0.0):
        """Encode ``prefix``, reconstruct it, and return the reconstruction loss.

        Args:
            prefix: Prefix embeddings of shape ``(b, M, dim)``. They are
                detached first, so no gradient flows into whatever produced them.
            mask: Optional ``(b, M)`` validity mask; masked-out tokens
                contribute nothing to the loss.
            context_dropout: Forwarded to the decoder.

        Returns:
            ``(z_rl, loss)``. ``loss`` is a scalar: the per-token squared-L2
            error (summed over ``dim``), averaged over valid tokens per sample
            and then over the batch.
        """
        # Targets are stop-gradiented (frozen VLA features). detach() = jax.lax.stop_gradient.
        target = prefix.detach()
        z_rl = self.encoder(target, mask)
        pred = self.decoder(z_rl, target, mask, context_dropout=context_dropout)

        diff = pred - target
        # Summing `dim` squared values in fp16/bf16 can overflow or lose most of
        # its precision, so accumulate the loss in float32 for those inputs.
        if diff.dtype in (torch.float16, torch.bfloat16):
            diff = diff.float()
        per_token = diff.pow(2).sum(dim=-1)  # (b, M) squared-L2 per token

        if mask is None:
            recon = per_token.mean(dim=1)
        else:
            weights = mask.to(per_token.dtype)
            # clamp keeps an all-padding sample from dividing by zero; it then
            # contributes a loss of 0.
            denom = weights.sum(dim=1).clamp(min=1)
            recon = (per_token * weights).sum(dim=1) / denom  # (b,)
        return z_rl, recon.mean()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
    # Self-test on COMPRESSIBLE data: each sequence is a per-sample latent c
    # broadcast across positions plus a small FIXED positional pattern, so one
    # z_rl can capture c. Fair ablation = FIRST-token recon: position 0 sees
    # ONLY z_rl (no autoregressive context), so it isolates whether z_rl
    # carries information.
    torch.manual_seed(0)
    cfg = RLTokenConfig(dim=64, num_layers=2, num_heads=4, mlp_dim=128)  # tiny for CPU
    ae = RLTokenAutoencoder(cfg)
    opt = torch.optim.AdamW(ae.parameters(), lr=1e-3)
    B, M = 32, 12
    num_steps, log_every = 600, 150
    pos_pattern = torch.randn(M, cfg.dim) * 0.3  # fixed per-position offset

    def batch():
        """Return one synthetic (x, mask) batch; every position is valid."""
        c = torch.randn(B, cfg.dim)  # per-sample latent
        x = c[:, None, :] + pos_pattern[None]  # (B, M, dim), compressible
        return x, torch.ones(B, M, dtype=torch.bool)

    for step in range(num_steps):
        x, mask = batch()
        z, loss = ae(x, mask)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % log_every == 0 or step == num_steps - 1:
            print(f"step {step:3d} recon={loss.item():.4f}")

    ae.eval()
    with torch.no_grad():
        x, mask = batch()
        z, _ = ae(x, mask)

        def first_tok_err(zt):
            """Mean squared-L2 error of the first reconstructed token given RL token ``zt``."""
            pred = ae.decoder(zt, x, mask)
            return (pred[:, 0] - x[:, 0]).pow(2).sum(-1).mean().item()  # token-0 only

        real0 = first_tok_err(z)
        zero0 = first_tok_err(torch.zeros_like(z))
        shuf0 = first_tok_err(z[torch.randperm(B)])

    print(f"first-token recon: real={real0:.3f} zeroed={zero0:.3f} shuffled={shuf0:.3f}")
    ok = real0 < 0.3 * zero0 and real0 < 0.3 * shuf0
    print("SELF-TEST:", "PASS ✅ (z_rl carries the prefix latent)" if ok else "FAIL ❌")
    # Report failure through the exit status so the self-test is usable from
    # scripts and CI, not only by reading stdout.
    if not ok:
        raise SystemExit(1)
|