| """RL Token encoder-decoder for MolmoAct2 (RLT Stage 1) — PyTorch port. |
| |
| Faithful port of openpi's ``pi0_rl.py`` (Xu et al. 2025, "RL Tokens") to PyTorch |
| for the frozen MolmoAct2 lerobot fork. Differences from my earlier |
| ``rlt_logit_autoencoder.py`` (which was wrong): that one MLP-reconstructed the |
| 2048-D action logits; THIS reconstructs the VLA's **per-token prefix hidden |
| states** ``(M, dim)`` with a transformer encoder + autoregressive decoder, so |
| the single ``z_rl`` token is forced to regenerate the whole prefix — the real |
| RLT bottleneck, and what todo Phase 3 specifies. |
| |
| Design (matches the reference): |
| encoder: append a learned <rl> query to the prefix embeddings (b, M, dim), |
| run bidirectional pre-norm transformer blocks (RMSNorm + SwiGLU), |
| read the query position -> z_rl (b, dim). |
| decoder: autoregressive. input [z_rl, z̄_1 … z̄_{M-1}], causal mask, |
| predict [z̄_1 … z̄_M]; output_proj. |
| loss: per-token squared-L2 recon (sum over dim, masked mean over tokens), |
| targets stop-gradiented. VLA is a frozen server here, so there is no |
| L_vla term (alpha = 0): we only train the encoder/decoder. |
| |
| z_rl is full-dim (= dim), exactly like the reference — the bottleneck is the |
| sequence compression (M tokens -> 1), not a narrow feature dim. Downstream SAC |
| consumes z_rl as its (frozen) RLT state. |
| """ |
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class RLTokenConfig:
    """Hyperparameters shared by the RL-token encoder and decoder stacks."""

    # Width of the prefix hidden states; z_rl is produced at this same width.
    dim: int = 2560
    # Number of transformer blocks in the encoder and, separately, the decoder.
    num_layers: int = 2
    # Attention heads per block; ``dim`` has to be divisible by this value.
    num_heads: int = 8
    # Hidden width of the SwiGLU feed-forward inside every block.
    mlp_dim: int = 8192
|
|
|
|
| class _Block(nn.Module): |
| """Pre-norm transformer block: MHA + SwiGLU FFN, RMSNorm. Matches the ref.""" |
|
|
| def __init__(self, dim: int, num_heads: int, mlp_dim: int): |
| super().__init__() |
| assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}" |
| self.num_heads = num_heads |
| self.head_dim = dim // num_heads |
| self.attn_norm = nn.RMSNorm(dim) |
| self.q_proj = nn.Linear(dim, dim, bias=False) |
| self.k_proj = nn.Linear(dim, dim, bias=False) |
| self.v_proj = nn.Linear(dim, dim, bias=False) |
| self.o_proj = nn.Linear(dim, dim, bias=False) |
| self.ffn_norm = nn.RMSNorm(dim) |
| self.ffn_gate = nn.Linear(dim, mlp_dim, bias=False) |
| self.ffn_up = nn.Linear(dim, mlp_dim, bias=False) |
| self.ffn_down = nn.Linear(mlp_dim, dim, bias=False) |
|
|
| def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None) -> torch.Tensor: |
| b, s, d = x.shape |
| h = self.attn_norm(x) |
| q = self.q_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2) |
| k = self.k_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2) |
| v = self.v_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2) |
| |
| am = None |
| if attn_mask is not None: |
| am = torch.zeros(b, 1, s, s, dtype=x.dtype, device=x.device) |
| am = am.masked_fill(~attn_mask[:, None, :, :], float("-inf")) |
| attn = F.scaled_dot_product_attention(q, k, v, attn_mask=am) |
| attn = attn.transpose(1, 2).reshape(b, s, d) |
| x = x + self.o_proj(attn) |
| h = self.ffn_norm(x) |
| x = x + self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h)) |
| return x |
|
|
|
|
class RLTokenEncoder(nn.Module):
    """Compress prefix embeddings (b, M, dim) -> z_rl (b, dim) via a learned query."""

    def __init__(self, cfg: RLTokenConfig):
        """Create the learned query token and the stack of transformer blocks.

        Args:
            cfg: Supplies ``dim``, ``num_heads``, ``mlp_dim`` and ``num_layers``.
        """
        super().__init__()
        self.rl_query = nn.Parameter(torch.randn(1, 1, cfg.dim) * 0.02)
        self.layers = nn.ModuleList(_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers))

    def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Encode a prefix into a single RL token.

        Args:
            prefix: Prefix embeddings of shape ``(b, M, dim)``.
            mask: Optional ``(b, M)`` validity mask, true for real tokens.
                ``None`` treats every prefix token as valid.

        Returns:
            ``z_rl`` of shape ``(b, dim)``: the final-layer state at the
            appended query position.
        """
        b, m, d = prefix.shape
        # The learned query goes last so it can be read back with index -1.
        x = torch.cat([prefix, self.rl_query.expand(b, 1, d)], dim=1)

        attn_mask = None
        if mask is not None:
            query_valid = torch.ones(b, 1, dtype=torch.bool, device=mask.device)
            key_valid = torch.cat([mask.to(torch.bool), query_valid], dim=1)
            # Mask keys only. Also masking padded *query* rows would leave them
            # with no permitted key, which can produce NaN in attention and then
            # contaminate z_rl in the next layer. Every row can always see the
            # query token, and valid rows see exactly the same keys as before.
            attn_mask = key_valid[:, None, :].expand(b, m + 1, m + 1)

        for layer in self.layers:
            x = layer(x, attn_mask)
        return x[:, -1, :]
|
|
|
|
class RLTokenDecoder(nn.Module):
    """Teacher-forced causal decoder that rebuilds the prefix starting from z_rl."""

    def __init__(self, cfg: RLTokenConfig):
        """Create the block stack and the bias-free output projection.

        Args:
            cfg: Supplies ``dim``, ``num_heads``, ``mlp_dim`` and ``num_layers``.
        """
        super().__init__()
        blocks = [_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.output_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def forward(self, z_rl: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None,
                context_dropout: float = 0.0) -> torch.Tensor:
        """Predict every prefix token from z_rl plus the preceding target tokens.

        Args:
            z_rl: RL token of shape ``(b, dim)``; becomes decoder input 0.
            target: Prefix embeddings ``(b, M, dim)`` used for teacher forcing.
            mask: Optional ``(b, M)`` validity mask over ``target``.
            context_dropout: In training mode, probability of zeroing each
                teacher-forced context token (whole token, no rescaling).

        Returns:
            Predictions of shape ``(b, M, dim)``.
        """
        batch, seq_len, _ = target.shape
        device = target.device

        # Inputs are the targets shifted right by one; z_rl fills slot 0.
        shifted = target[:, :-1, :]
        if self.training and context_dropout > 0.0:
            survive = torch.rand(batch, seq_len - 1, 1, device=device) >= context_dropout
            shifted = shifted * survive.to(target.dtype)
        hidden = torch.cat([z_rl.unsqueeze(1), shifted], dim=1)

        lower_tri = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril().unsqueeze(0)
        if mask is None:
            attn_mask = lower_tri.expand(batch, seq_len, seq_len)
        else:
            # Key j > 0 holds target token j-1, so validity shifts by one as
            # well; the z_rl slot is always attendable.
            always_on = torch.ones(batch, 1, dtype=torch.bool, device=mask.device)
            visible_keys = torch.cat([always_on, mask[:, :-1]], dim=1)
            attn_mask = lower_tri & visible_keys.unsqueeze(1)

        for block in self.layers:
            hidden = block(hidden, attn_mask)
        return self.output_proj(hidden)
|
|
|
|
class RLTokenAutoencoder(nn.Module):
    """Encoder/decoder pair; ``forward`` yields ``(z_rl, recon_loss)`` for training."""

    def __init__(self, cfg: RLTokenConfig | None = None):
        """Build both halves from one config (a default config if none is given)."""
        super().__init__()
        self.cfg = cfg if cfg else RLTokenConfig()
        self.encoder = RLTokenEncoder(self.cfg)
        self.decoder = RLTokenDecoder(self.cfg)

    def encode(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Run only the encoder and return ``z_rl`` for ``prefix``."""
        return self.encoder(prefix, mask)

    def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None, context_dropout: float = 0.0):
        """Encode the prefix, decode it back, and score the reconstruction.

        Args:
            prefix: Prefix embeddings ``(b, M, dim)``; detached before use so
                no gradient flows back into whatever produced them.
            mask: Optional ``(b, M)`` validity mask; masked tokens add nothing
                to the loss.
            context_dropout: Forwarded unchanged to the decoder.

        Returns:
            ``(z_rl, loss)``, where ``loss`` is a scalar: squared L2 error
            summed over the feature axis, averaged over (valid) tokens per
            sample, then averaged over the batch.
        """
        frozen = prefix.detach()
        z_rl = self.encoder(frozen, mask)
        rebuilt = self.decoder(z_rl, frozen, mask, context_dropout=context_dropout)
        sq_err = ((rebuilt - frozen) ** 2).sum(dim=-1)

        if mask is None:
            per_sample = sq_err.mean(dim=1)
        else:
            # Clamp so a sample with no valid token divides by 1, not 0.
            n_valid = mask.sum(dim=1).clamp(min=1)
            per_sample = (sq_err * mask).sum(dim=1) / n_valid
        return z_rl, per_sample.mean()
|
|
|
|
def _self_test(num_steps: int = 600, log_every: int = 150) -> bool:
    """Train a tiny autoencoder on synthetic prefixes and check z_rl is used.

    Each synthetic prefix is one per-sample latent vector added to a fixed
    per-position pattern, so a single token is enough to describe it. After
    training, the first decoded token is scored three ways: with the real
    z_rl, with a zeroed z_rl, and with z_rl shuffled across the batch.
    Decoder position 0 receives only z_rl as input under the causal mask, so
    the real z_rl should give a much lower error than the other two.

    Args:
        num_steps: Number of optimizer steps.
        log_every: Print the loss every this many steps (and on the last step).

    Returns:
        True when the real-z_rl error is below 30% of both ablations.
    """
    torch.manual_seed(0)
    cfg = RLTokenConfig(dim=64, num_layers=2, num_heads=4, mlp_dim=128)
    ae = RLTokenAutoencoder(cfg)
    opt = torch.optim.AdamW(ae.parameters(), lr=1e-3)
    batch_size, seq_len = 32, 12
    pos_pattern = torch.randn(seq_len, cfg.dim) * 0.3

    def batch():
        # One latent per sample, broadcast over positions, plus the pattern.
        c = torch.randn(batch_size, cfg.dim)
        x = c[:, None, :] + pos_pattern[None]
        return x, torch.ones(batch_size, seq_len, dtype=torch.bool)

    for step in range(num_steps):
        x, mask = batch()
        _, loss = ae(x, mask)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % log_every == 0 or step == num_steps - 1:
            print(f"step {step:3d} recon={loss.item():.4f}")

    ae.eval()
    with torch.no_grad():
        x, mask = batch()
        z, _ = ae(x, mask)

        def first_tok_err(zt):
            pred = ae.decoder(zt, x, mask)
            return (pred[:, 0] - x[:, 0]).pow(2).sum(-1).mean().item()

        real0 = first_tok_err(z)
        zero0 = first_tok_err(torch.zeros_like(z))
        shuf0 = first_tok_err(z[torch.randperm(batch_size)])

    print(f"first-token recon: real={real0:.3f} zeroed={zero0:.3f} shuffled={shuf0:.3f}")
    ok = real0 < 0.3 * zero0 and real0 < 0.3 * shuf0
    print("SELF-TEST:", "PASS ✅ (z_rl carries the prefix latent)" if ok else "FAIL ❌")
    return ok


if __name__ == "__main__":
    # Exit non-zero on failure so scripts and CI can gate on the self-test.
    if not _self_test():
        raise SystemExit(1)
|
|