"""RL Token encoder-decoder for MolmoAct2 (RLT Stage 1) — PyTorch port. Faithful port of openpi's ``pi0_rl.py`` (Xu et al. 2025, "RL Tokens") to PyTorch for the frozen MolmoAct2 lerobot fork. Differences from my earlier ``rlt_logit_autoencoder.py`` (which was wrong): that one MLP-reconstructed the 2048-D action logits; THIS reconstructs the VLA's **per-token prefix hidden states** ``(M, dim)`` with a transformer encoder + autoregressive decoder, so the single ``z_rl`` token is forced to regenerate the whole prefix — the real RLT bottleneck, and what todo Phase 3 specifies. Design (matches the reference): encoder: append a learned query to the prefix embeddings (b, M, dim), run bidirectional pre-norm transformer blocks (RMSNorm + SwiGLU), read the query position -> z_rl (b, dim). decoder: autoregressive. input [z_rl, z̄_1 … z̄_{M-1}], causal mask, predict [z̄_1 … z̄_M]; output_proj. loss: per-token squared-L2 recon (sum over dim, masked mean over tokens), targets stop-gradiented. VLA is a frozen server here, so there is no L_vla term (alpha = 0): we only train the encoder/decoder. z_rl is full-dim (= dim), exactly like the reference — the bottleneck is the sequence compression (M tokens -> 1), not a narrow feature dim. Downstream SAC consumes z_rl as its (frozen) RLT state. """ from __future__ import annotations import math from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F @dataclass class RLTokenConfig: dim: int = 2560 # MolmoAct2 VLM hidden width (cached embeddings are 2560-D) num_layers: int = 2 num_heads: int = 8 # 2560 / 8 = 320 head_dim mlp_dim: int = 8192 class _Block(nn.Module): """Pre-norm transformer block: MHA + SwiGLU FFN, RMSNorm. Matches the ref.""" def __init__(self, dim: int, num_heads: int, mlp_dim: int): super().__init__() assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}" self.num_heads = num_heads self.head_dim = dim // num_heads self.attn_norm = nn.RMSNorm(dim) self.q_proj = nn.Linear(dim, dim, bias=False) self.k_proj = nn.Linear(dim, dim, bias=False) self.v_proj = nn.Linear(dim, dim, bias=False) self.o_proj = nn.Linear(dim, dim, bias=False) self.ffn_norm = nn.RMSNorm(dim) self.ffn_gate = nn.Linear(dim, mlp_dim, bias=False) self.ffn_up = nn.Linear(dim, mlp_dim, bias=False) self.ffn_down = nn.Linear(mlp_dim, dim, bias=False) def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None) -> torch.Tensor: b, s, d = x.shape h = self.attn_norm(x) q = self.q_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2) # (b,nh,s,hd) k = self.k_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2) # attn_mask: (b, s, s) bool, True = attend. -> (b,1,s,s) for SDPA additive. am = None if attn_mask is not None: am = torch.zeros(b, 1, s, s, dtype=x.dtype, device=x.device) am = am.masked_fill(~attn_mask[:, None, :, :], float("-inf")) attn = F.scaled_dot_product_attention(q, k, v, attn_mask=am) # (b,nh,s,hd) attn = attn.transpose(1, 2).reshape(b, s, d) x = x + self.o_proj(attn) h = self.ffn_norm(x) x = x + self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h)) return x class RLTokenEncoder(nn.Module): """Compress prefix embeddings (b, M, dim) -> z_rl (b, dim) via a learned query.""" def __init__(self, cfg: RLTokenConfig): super().__init__() self.rl_query = nn.Parameter(torch.randn(1, 1, cfg.dim) * 0.02) self.layers = nn.ModuleList(_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers)) def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: b, m, d = prefix.shape query = self.rl_query.expand(b, 1, d) x = torch.cat([prefix, query], dim=1) # (b, M+1, dim) if mask is not None: ext = torch.cat([mask, torch.ones(b, 1, dtype=torch.bool, device=mask.device)], dim=1) attn_mask = ext[:, None, :] & ext[:, :, None] # (b, M+1, M+1) bidirectional else: attn_mask = None for layer in self.layers: x = layer(x, attn_mask) return x[:, -1, :] # z_rl at the query position class RLTokenDecoder(nn.Module): """Autoregressively reconstruct prefix embeddings from z_rl.""" def __init__(self, cfg: RLTokenConfig): super().__init__() self.layers = nn.ModuleList(_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers)) self.output_proj = nn.Linear(cfg.dim, cfg.dim, bias=False) def forward(self, z_rl: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None, context_dropout: float = 0.0) -> torch.Tensor: # input [z_rl, z̄_1..z̄_{M-1}] -> predict [z̄_1..z̄_M] b, m, d = target.shape ctx = target[:, :-1, :] # Context dropout (train only): randomly zero teacher-forced context tokens # so the decoder cannot reconstruct purely from the true-previous-token leak # and is forced to route information through z_rl. Off (0.0) = bare reference. if self.training and context_dropout > 0.0: keep = (torch.rand(b, m - 1, 1, device=target.device) >= context_dropout).to(target.dtype) ctx = ctx * keep dec_in = torch.cat([z_rl[:, None, :], ctx], dim=1) # (b, M, dim) causal = torch.tril(torch.ones(m, m, dtype=torch.bool, device=target.device))[None] # (1,M,M) if mask is not None: key_valid = torch.cat([torch.ones(b, 1, dtype=torch.bool, device=mask.device), mask[:, :-1]], dim=1) attn_mask = causal & key_valid[:, None, :] # (b, M, M) else: attn_mask = causal.expand(b, m, m) x = dec_in for layer in self.layers: x = layer(x, attn_mask) return self.output_proj(x) class RLTokenAutoencoder(nn.Module): """Encoder + decoder. forward() returns (z_rl, recon_loss) for training.""" def __init__(self, cfg: RLTokenConfig | None = None): super().__init__() self.cfg = cfg or RLTokenConfig() self.encoder = RLTokenEncoder(self.cfg) self.decoder = RLTokenDecoder(self.cfg) def encode(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: return self.encoder(prefix, mask) def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None, context_dropout: float = 0.0): # Targets are stop-gradiented (frozen VLA features). detach() = jax.lax.stop_gradient. target = prefix.detach() z_rl = self.encoder(target, mask) pred = self.decoder(z_rl, target, mask, context_dropout=context_dropout) per_token = (pred - target).pow(2).sum(dim=-1) # (b, M) squared-L2 per token if mask is not None: per_token = per_token * mask denom = mask.sum(dim=1).clamp(min=1) recon = (per_token.sum(dim=1) / denom) # (b,) else: recon = per_token.mean(dim=1) return z_rl, recon.mean() if __name__ == "__main__": # Self-test on COMPRESSIBLE data: each sequence is a per-sample latent c # broadcast across positions + a small FIXED positional pattern. So one z_rl # can capture c. Fair ablation = FIRST-token recon: position 0 sees ONLY # z_rl (no AR context), so it isolates whether z_rl carries information. torch.manual_seed(0) cfg = RLTokenConfig(dim=64, num_layers=2, num_heads=4, mlp_dim=128) # tiny for CPU ae = RLTokenAutoencoder(cfg) opt = torch.optim.AdamW(ae.parameters(), lr=1e-3) B, M = 32, 12 pos_pattern = torch.randn(M, cfg.dim) * 0.3 # fixed per-position offset def batch(): c = torch.randn(B, cfg.dim) # per-sample latent x = c[:, None, :] + pos_pattern[None] # (B, M, dim), compressible return x, torch.ones(B, M, dtype=torch.bool) for step in range(600): x, mask = batch() z, loss = ae(x, mask) opt.zero_grad(); loss.backward(); opt.step() if step % 150 == 0 or step == 599: print(f"step {step:3d} recon={loss.item():.4f}") ae.eval() with torch.no_grad(): x, mask = batch() z, _ = ae(x, mask) def first_tok_err(zt): pred = ae.decoder(zt, x, mask) return (pred[:, 0] - x[:, 0]).pow(2).sum(-1).mean().item() # token-0 only real0 = first_tok_err(z) zero0 = first_tok_err(torch.zeros_like(z)) shuf0 = first_tok_err(z[torch.randperm(B)]) print(f"first-token recon: real={real0:.3f} zeroed={zero0:.3f} shuffled={shuf0:.3f}") ok = real0 < 0.3 * zero0 and real0 < 0.3 * shuf0 print("SELF-TEST:", "PASS ✅ (z_rl carries the prefix latent)" if ok else "FAIL ❌")