rltoken-encoder / code /rl_token_encoder.py
atharva-pantheon's picture
Upload code/rl_token_encoder.py with huggingface_hub
f9042a0 verified
"""RL Token encoder-decoder for MolmoAct2 (RLT Stage 1) — PyTorch port.
Faithful port of openpi's ``pi0_rl.py`` (Xu et al. 2025, "RL Tokens") to PyTorch
for the frozen MolmoAct2 lerobot fork. Differences from my earlier
``rlt_logit_autoencoder.py`` (which was wrong): that one MLP-reconstructed the
2048-D action logits; THIS reconstructs the VLA's **per-token prefix hidden
states** ``(M, dim)`` with a transformer encoder + autoregressive decoder, so
the single ``z_rl`` token is forced to regenerate the whole prefix. That is the
real RLT bottleneck, and it is what Phase 3 of the TODO plan specifies.
Design (matches the reference):
encoder: append a learned <rl> query to the prefix embeddings (b, M, dim),
run bidirectional pre-norm transformer blocks (RMSNorm + SwiGLU),
read the query position -> z_rl (b, dim).
decoder: autoregressive. input [z_rl, z̄_1 … z̄_{M-1}], causal mask,
predict [z̄_1 … z̄_M]; output_proj.
loss: per-token squared-L2 recon (sum over dim, masked mean over tokens),
targets stop-gradiented. VLA is a frozen server here, so there is no
L_vla term (alpha = 0): we only train the encoder/decoder.
z_rl is full-dim (= dim), exactly like the reference — the bottleneck is the
sequence compression (M tokens -> 1), not a narrow feature dim. Downstream SAC
consumes z_rl as its (frozen) RLT state.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class RLTokenConfig:
    """Hyper-parameters shared by the RL-token encoder and decoder.

    Values are validated in ``__post_init__`` so an inconsistent config fails
    at construction time with a clear message, rather than later inside a
    transformer block.
    """

    dim: int = 2560  # MolmoAct2 VLM hidden width (cached embeddings are 2560-D)
    num_layers: int = 2  # transformer blocks in the encoder and, separately, the decoder
    num_heads: int = 8  # 2560 / 8 = 320 head_dim
    mlp_dim: int = 8192  # SwiGLU hidden width

    def __post_init__(self) -> None:
        """Validate field values.

        Raises:
            ValueError: if a width/head count is not positive, ``num_layers`` is
                negative, or ``dim`` is not divisible by ``num_heads``.
        """
        # Check positivity first so the divisibility check never divides by zero.
        for name in ("dim", "num_heads", "mlp_dim"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        if self.num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {self.num_layers}")
        if self.dim % self.num_heads != 0:
            raise ValueError(f"dim {self.dim} not divisible by num_heads {self.num_heads}")
class _Block(nn.Module):
"""Pre-norm transformer block: MHA + SwiGLU FFN, RMSNorm. Matches the ref."""
def __init__(self, dim: int, num_heads: int, mlp_dim: int):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.attn_norm = nn.RMSNorm(dim)
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
self.o_proj = nn.Linear(dim, dim, bias=False)
self.ffn_norm = nn.RMSNorm(dim)
self.ffn_gate = nn.Linear(dim, mlp_dim, bias=False)
self.ffn_up = nn.Linear(dim, mlp_dim, bias=False)
self.ffn_down = nn.Linear(mlp_dim, dim, bias=False)
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None) -> torch.Tensor:
b, s, d = x.shape
h = self.attn_norm(x)
q = self.q_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2) # (b,nh,s,hd)
k = self.k_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
# attn_mask: (b, s, s) bool, True = attend. -> (b,1,s,s) for SDPA additive.
am = None
if attn_mask is not None:
am = torch.zeros(b, 1, s, s, dtype=x.dtype, device=x.device)
am = am.masked_fill(~attn_mask[:, None, :, :], float("-inf"))
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=am) # (b,nh,s,hd)
attn = attn.transpose(1, 2).reshape(b, s, d)
x = x + self.o_proj(attn)
h = self.ffn_norm(x)
x = x + self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h))
return x
class RLTokenEncoder(nn.Module):
    """Compress prefix embeddings (b, M, dim) -> z_rl (b, dim) via a learned query.

    A learned ``<rl>`` query token is appended after the prefix. The M+1
    sequence runs through bidirectional pre-norm blocks, and the hidden state
    at the query position is returned as ``z_rl``.
    """

    def __init__(self, cfg: RLTokenConfig):
        super().__init__()
        # Learned <rl> query token, small-std init.
        self.rl_query = nn.Parameter(torch.randn(1, 1, cfg.dim) * 0.02)
        self.layers = nn.ModuleList(_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers))

    def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Encode a prefix into a single RL token.

        Args:
            prefix: (b, M, dim) prefix embeddings.
            mask: optional (b, M) validity mask (True / nonzero = real token).

        Returns:
            (b, dim) hidden state at the appended query position.
        """
        b, m, d = prefix.shape
        query = self.rl_query.expand(b, 1, d)
        x = torch.cat([prefix, query], dim=1)  # (b, M+1, dim)
        attn_mask = None
        if mask is not None:
            query_valid = torch.ones(b, 1, dtype=torch.bool, device=mask.device)
            ext = torch.cat([mask.to(torch.bool), query_valid], dim=1)  # (b, M+1)
            # Key-only padding mask: every row (padded rows included) attends to all
            # valid keys. Padded tokens are never keys, so they cannot influence
            # z_rl. Because the query column is always valid, no row is ever fully
            # masked, which would otherwise give a NaN softmax that leaks through V.
            attn_mask = ext[:, None, :].expand(b, m + 1, m + 1)  # (b, M+1, M+1) bidirectional
        for layer in self.layers:
            x = layer(x, attn_mask)
        return x[:, -1, :]  # z_rl at the query position
class RLTokenDecoder(nn.Module):
    """Autoregressively reconstruct prefix embeddings from z_rl."""

    def __init__(self, cfg: RLTokenConfig):
        super().__init__()
        self.layers = nn.ModuleList(_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers))
        self.output_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def forward(self, z_rl: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None,
                context_dropout: float = 0.0) -> torch.Tensor:
        """Teacher-forced reconstruction of ``target`` conditioned on ``z_rl``.

        Args:
            z_rl: (b, dim) RL token produced by the encoder.
            target: (b, M, dim) prefix embeddings to reconstruct.
            mask: optional (b, M) validity mask for ``target`` (True / nonzero = real).
            context_dropout: per-token probability of zeroing a teacher-forced
                context token; applied only in training mode.

        Returns:
            (b, M, dim) predictions; output position i predicts ``target[:, i]``.
        """
        # input [z_rl, z̄_1..z̄_{M-1}] -> predict [z̄_1..z̄_M]
        b, m, d = target.shape
        ctx = target[:, :-1, :]
        # Context dropout (train only): randomly zero teacher-forced context tokens
        # so the decoder cannot reconstruct purely from the true-previous-token leak
        # and is forced to route information through z_rl. Off (0.0) = bare reference.
        if self.training and context_dropout > 0.0:
            keep = (torch.rand(b, m - 1, 1, device=target.device) >= context_dropout).to(target.dtype)
            ctx = ctx * keep
        dec_in = torch.cat([z_rl[:, None, :], ctx], dim=1)  # (b, M, dim)
        causal = torch.tril(torch.ones(m, m, dtype=torch.bool, device=target.device))[None]  # (1,M,M)
        if mask is not None:
            mask = mask.to(torch.bool)
            # Input position 0 is z_rl (always valid); input position j>0 holds
            # target token j-1, so its key validity is mask[:, j-1].
            z_valid = torch.ones(b, 1, dtype=torch.bool, device=mask.device)
            key_valid = torch.cat([z_valid, mask[:, :-1]], dim=1)
            attn_mask = causal & key_valid[:, None, :]  # (b, M, M)
        else:
            attn_mask = causal.expand(b, m, m)
        x = dec_in
        for layer in self.layers:
            x = layer(x, attn_mask)
        return self.output_proj(x)
class RLTokenAutoencoder(nn.Module):
    """Encoder + decoder. forward() returns (z_rl, recon_loss) for training."""

    def __init__(self, cfg: RLTokenConfig | None = None):
        super().__init__()
        self.cfg = cfg or RLTokenConfig()
        self.encoder = RLTokenEncoder(self.cfg)
        self.decoder = RLTokenDecoder(self.cfg)

    def encode(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Return z_rl (b, dim) for ``prefix`` (b, M, dim) without decoding."""
        return self.encoder(prefix, mask)

    def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None, context_dropout: float = 0.0):
        """Encode, reconstruct, and compute the reconstruction loss.

        Args:
            prefix: (b, M, dim) prefix embeddings (reconstruction targets).
            mask: optional (b, M) validity mask; padded tokens are excluded
                from the loss.
            context_dropout: forwarded to the decoder (training only).

        Returns:
            ``(z_rl, loss)``: z_rl is (b, dim). loss is a scalar: the per-token
            squared L2 error (summed over dim), averaged over valid tokens per
            sample, then averaged over the batch.
        """
        # Targets are stop-gradiented (frozen VLA features). detach() = jax.lax.stop_gradient.
        target = prefix.detach()
        z_rl = self.encoder(target, mask)
        pred = self.decoder(z_rl, target, mask, context_dropout=context_dropout)
        # Accumulate the loss in at least float32: summing ~dim squared terms in
        # bf16/fp16 loses a lot of precision. float64 inputs stay float64.
        loss_dtype = torch.promote_types(pred.dtype, torch.float32)
        err = pred.to(loss_dtype) - target.to(loss_dtype)
        per_token = err.pow(2).sum(dim=-1)  # (b, M) squared-L2 per token
        if mask is not None:
            weights = mask.to(loss_dtype)
            denom = weights.sum(dim=1).clamp(min=1)
            recon = (per_token * weights).sum(dim=1) / denom  # (b,)
        else:
            recon = per_token.mean(dim=1)
        return z_rl, recon.mean()
if __name__ == "__main__":
    import sys

    # Self-test on COMPRESSIBLE data: each sequence is a per-sample latent c
    # broadcast across positions + a small FIXED positional pattern. So one z_rl
    # can capture c. Fair ablation = FIRST-token recon: position 0 sees ONLY
    # z_rl (no AR context), so it isolates whether z_rl carries information.
    torch.manual_seed(0)
    cfg = RLTokenConfig(dim=64, num_layers=2, num_heads=4, mlp_dim=128)  # tiny for CPU
    ae = RLTokenAutoencoder(cfg)
    opt = torch.optim.AdamW(ae.parameters(), lr=1e-3)
    B, M = 32, 12
    pos_pattern = torch.randn(M, cfg.dim) * 0.3  # fixed per-position offset

    def batch():
        """Return one compressible (B, M, dim) batch and an all-valid mask."""
        c = torch.randn(B, cfg.dim)  # per-sample latent
        x = c[:, None, :] + pos_pattern[None]  # (B, M, dim), compressible
        return x, torch.ones(B, M, dtype=torch.bool)

    for step in range(600):
        x, mask = batch()
        _, loss = ae(x, mask)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % 150 == 0 or step == 599:
            print(f"step {step:3d} recon={loss.item():.4f}")

    ae.eval()
    with torch.no_grad():
        x, mask = batch()
        z, _ = ae(x, mask)

        def first_tok_err(zt):
            """Mean squared-L2 error of the decoder's token-0 prediction given zt."""
            pred = ae.decoder(zt, x, mask)
            return (pred[:, 0] - x[:, 0]).pow(2).sum(-1).mean().item()  # token-0 only

        real0 = first_tok_err(z)
        zero0 = first_tok_err(torch.zeros_like(z))
        shuf0 = first_tok_err(z[torch.randperm(B)])
    print(f"first-token recon: real={real0:.3f} zeroed={zero0:.3f} shuffled={shuf0:.3f}")
    ok = real0 < 0.3 * zero0 and real0 < 0.3 * shuf0
    print("SELF-TEST:", "PASS ✅ (z_rl carries the prefix latent)" if ok else "FAIL ❌")
    # Non-zero exit status on failure so CI / shell scripts can detect it.
    sys.exit(0 if ok else 1)