File size: 9,184 Bytes
f9042a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""RL Token encoder-decoder for MolmoAct2 (RLT Stage 1) — PyTorch port.

Faithful port of openpi's ``pi0_rl.py`` (Xu et al. 2025, "RL Tokens") to PyTorch
for the frozen MolmoAct2 lerobot fork. Differences from my earlier
``rlt_logit_autoencoder.py`` (which was wrong): that one MLP-reconstructed the
2048-D action logits; THIS reconstructs the VLA's **per-token prefix hidden
states** ``(M, dim)`` with a transformer encoder + autoregressive decoder, so
the single ``z_rl`` token is forced to regenerate the whole prefix — the real
RLT bottleneck, and what todo Phase 3 specifies.

Design (matches the reference):
  encoder: append a learned <rl> query to the prefix embeddings (b, M, dim),
           run bidirectional pre-norm transformer blocks (RMSNorm + SwiGLU),
           read the query position  ->  z_rl (b, dim).
  decoder: autoregressive. input  [z_rl, z̄_1 … z̄_{M-1}], causal mask,
           predict [z̄_1 … z̄_M]; output_proj.
  loss:    per-token squared-L2 recon (sum over dim, masked mean over tokens),
           targets stop-gradiented. VLA is a frozen server here, so there is no
           L_vla term (alpha = 0): we only train the encoder/decoder.

z_rl is full-dim (= dim), exactly like the reference — the bottleneck is the
sequence compression (M tokens -> 1), not a narrow feature dim. Downstream SAC
consumes z_rl as its (frozen) RLT state.
"""
from __future__ import annotations

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class RLTokenConfig:
    """Hyper-parameters shared by the RL-token encoder and decoder.

    The configuration is validated on construction so that an impossible
    attention head split is reported here, with a ``ValueError``, instead of
    surfacing later while the transformer blocks are being built.

    Raises:
        ValueError: if ``dim``, ``num_heads`` or ``mlp_dim`` is not positive,
            if ``num_layers`` is negative, or if ``dim`` is not divisible by
            ``num_heads``.
    """

    dim: int = 2560          # MolmoAct2 VLM hidden width (cached embeddings are 2560-D)
    num_layers: int = 2      # transformer blocks in the encoder, and again in the decoder
    num_heads: int = 8       # 2560 / 8 = 320 head_dim
    mlp_dim: int = 8192      # hidden width of the SwiGLU feed-forward in each block

    def __post_init__(self) -> None:
        for field_name in ("dim", "num_heads", "mlp_dim"):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be positive, got {value}")
        if self.num_layers < 0:
            raise ValueError(f"num_layers must be non-negative, got {self.num_layers}")
        # Each attention head receives dim // num_heads channels, so the split must be exact.
        if self.dim % self.num_heads != 0:
            raise ValueError(f"dim {self.dim} not divisible by num_heads {self.num_heads}")


class _Block(nn.Module):
    """Pre-norm transformer block: MHA + SwiGLU FFN, RMSNorm. Matches the ref."""

    def __init__(self, dim: int, num_heads: int, mlp_dim: int):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.attn_norm = nn.RMSNorm(dim)
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        self.ffn_norm = nn.RMSNorm(dim)
        self.ffn_gate = nn.Linear(dim, mlp_dim, bias=False)
        self.ffn_up = nn.Linear(dim, mlp_dim, bias=False)
        self.ffn_down = nn.Linear(mlp_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None) -> torch.Tensor:
        b, s, d = x.shape
        h = self.attn_norm(x)
        q = self.q_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)  # (b,nh,s,hd)
        k = self.k_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        # attn_mask: (b, s, s) bool, True = attend. -> (b,1,s,s) for SDPA additive.
        am = None
        if attn_mask is not None:
            am = torch.zeros(b, 1, s, s, dtype=x.dtype, device=x.device)
            am = am.masked_fill(~attn_mask[:, None, :, :], float("-inf"))
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=am)  # (b,nh,s,hd)
        attn = attn.transpose(1, 2).reshape(b, s, d)
        x = x + self.o_proj(attn)
        h = self.ffn_norm(x)
        x = x + self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h))
        return x


class RLTokenEncoder(nn.Module):
    """Compress prefix embeddings (b, M, dim) -> z_rl (b, dim) via a learned query.

    A single learned ``<rl>`` query vector is appended after the prefix, the
    resulting (b, M+1, dim) sequence is run through the transformer blocks,
    and the activation at the query position is returned as ``z_rl``.
    """

    def __init__(self, cfg: RLTokenConfig):
        super().__init__()
        self.rl_query = nn.Parameter(torch.randn(1, 1, cfg.dim) * 0.02)
        self.layers = nn.ModuleList(_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers))

    def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Encode ``prefix`` into one token.

        Args:
            prefix: prefix embeddings of shape (b, M, dim).
            mask: optional (b, M) validity mask, truthy = real token. ``None``
                means every prefix token is valid.

        Returns:
            ``z_rl`` of shape (b, dim).
        """
        b, m, d = prefix.shape
        query = self.rl_query.expand(b, 1, d)
        x = torch.cat([prefix, query], dim=1)  # (b, M+1, dim)

        attn_mask = None
        if mask is not None:
            valid = mask.to(torch.bool)
            # The appended <rl> query position is always a valid key.
            query_valid = torch.ones(b, 1, dtype=torch.bool, device=valid.device)
            key_valid = torch.cat([valid, query_valid], dim=1)  # (b, M+1)
            # Mask padded KEYS only. Also masking padded QUERY rows would leave
            # those rows with no key at all; a softmax over an all-masked row
            # can be NaN, and that NaN would reach z_rl through the next
            # layer's keys/values. Padded rows are never attended to and never
            # read out, so what they compute does not affect z_rl.
            attn_mask = key_valid[:, None, :].expand(b, m + 1, m + 1)  # (b, M+1, M+1)

        for layer in self.layers:
            x = layer(x, attn_mask)
        return x[:, -1, :]  # z_rl at the query position


class RLTokenDecoder(nn.Module):
    """Teacher-forced autoregressive decoder: rebuild the prefix from z_rl.

    The decoder input is ``z_rl`` followed by the target shifted right by one
    position, processed under a causal mask and passed through a final linear
    projection.
    """

    def __init__(self, cfg: RLTokenConfig):
        super().__init__()
        blocks = [_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.output_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)

    def forward(self, z_rl: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None,
                context_dropout: float = 0.0) -> torch.Tensor:
        """Predict the prefix tokens from ``z_rl`` and the shifted target.

        Args:
            z_rl: encoded token, (b, dim).
            target: prefix embeddings to reconstruct, (b, M, dim).
            mask: optional (b, M) boolean validity mask for ``target``.
            context_dropout: in training mode only, probability of zeroing
                each teacher-forced context token. 0.0 disables it.

        Returns:
            Predictions of shape (b, M, dim).
        """
        batch, seq_len, _ = target.shape
        device = target.device

        # Teacher forcing: position i is fed the true token i-1 (position 0 gets z_rl).
        shifted = target[:, :-1, :]
        if self.training and context_dropout > 0.0:
            # Dropping whole context tokens stops the decoder from leaning only
            # on the previous true token, so information has to come via z_rl.
            draw = torch.rand(batch, seq_len - 1, 1, device=device)
            shifted = shifted * (draw >= context_dropout).to(target.dtype)
        hidden = torch.cat([z_rl.unsqueeze(1), shifted], dim=1)  # (b, M, dim)

        lower_tri = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril().unsqueeze(0)  # (1,M,M)
        if mask is None:
            attn_mask = lower_tri.expand(batch, seq_len, seq_len)
        else:
            # Key 0 is z_rl and always valid; key j>0 holds target token j-1.
            always_on = torch.ones(batch, 1, dtype=torch.bool, device=mask.device)
            key_valid = torch.cat([always_on, mask[:, :-1]], dim=1)
            attn_mask = lower_tri & key_valid.unsqueeze(1)  # (b, M, M)

        for block in self.layers:
            hidden = block(hidden, attn_mask)
        return self.output_proj(hidden)


class RLTokenAutoencoder(nn.Module):
    """Encoder/decoder pair; ``forward`` yields ``(z_rl, recon_loss)`` for training."""

    def __init__(self, cfg: RLTokenConfig | None = None):
        super().__init__()
        if not cfg:
            cfg = RLTokenConfig()
        self.cfg = cfg
        self.encoder = RLTokenEncoder(cfg)
        self.decoder = RLTokenDecoder(cfg)

    def encode(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """Run the encoder only and return its output for ``prefix``."""
        return self.encoder(prefix, mask)

    def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None, context_dropout: float = 0.0):
        """Encode ``prefix``, decode it again, and score the reconstruction.

        Args:
            prefix: embeddings to compress and reconstruct.
            mask: optional per-token validity mask; masked-out tokens do not
                contribute to the loss.
            context_dropout: forwarded unchanged to the decoder.

        Returns:
            ``(z_rl, loss)`` where ``loss`` is a scalar: squared-L2 error summed
            over the feature axis, averaged over (valid) tokens per sample,
            then averaged over the batch.
        """
        # No gradient flows into the inputs: they act as fixed regression
        # targets (detach() plays the role of jax.lax.stop_gradient).
        target = prefix.detach()
        z_rl = self.encoder(target, mask)
        reconstruction = self.decoder(z_rl, target, mask, context_dropout=context_dropout)

        residual = reconstruction - target
        token_err = (residual * residual).sum(dim=-1)  # (b, M)

        if mask is None:
            sample_err = token_err.mean(dim=1)
        else:
            # Clamp so a sample with no valid token divides by 1, not 0.
            valid_count = mask.sum(dim=1).clamp(min=1)
            sample_err = (token_err * mask).sum(dim=1) / valid_count  # (b,)
        return z_rl, sample_err.mean()


if __name__ == "__main__":
    # Smoke test on data that one token CAN summarise: every sequence is a
    # per-sample latent repeated at each position plus a small position-specific
    # offset shared by all samples. The check looks at token 0 only, because
    # that decoder position has no teacher-forced context and can draw on
    # nothing but z_rl.
    torch.manual_seed(0)
    cfg = RLTokenConfig(dim=64, num_layers=2, num_heads=4, mlp_dim=128)  # tiny for CPU
    ae = RLTokenAutoencoder(cfg)
    opt = torch.optim.AdamW(ae.parameters(), lr=1e-3)
    B, M = 32, 12
    NUM_STEPS, LOG_EVERY = 600, 150
    pos_pattern = torch.randn(M, cfg.dim) * 0.3  # fixed per-position offset

    def batch():
        """Draw one synthetic batch and an all-valid mask."""
        latent = torch.randn(B, cfg.dim)                         # per-sample latent
        seq = latent.unsqueeze(1) + pos_pattern.unsqueeze(0)     # (B, M, dim)
        return seq, torch.ones(B, M, dtype=torch.bool)

    final_step = NUM_STEPS - 1
    for step in range(NUM_STEPS):
        x, mask = batch()
        _, loss = ae(x, mask)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step == final_step or step % LOG_EVERY == 0:
            print(f"step {step:3d}  recon={loss.item():.4f}")

    ae.eval()
    with torch.no_grad():
        x, mask = batch()
        z, _ = ae(x, mask)

        def first_tok_err(zt):
            """Mean squared-L2 error of the decoder's first output token given ``zt``."""
            recon = ae.decoder(zt, x, mask)
            gap = recon[:, 0] - x[:, 0]  # token-0 only
            return gap.pow(2).sum(-1).mean().item()

        real0 = first_tok_err(z)
        zero0 = first_tok_err(torch.zeros_like(z))
        shuf0 = first_tok_err(z[torch.randperm(B)])

    print(f"first-token recon: real={real0:.3f}  zeroed={zero0:.3f}  shuffled={shuf0:.3f}")
    # Pass only if the true z_rl beats both ablations by a wide margin.
    ok = real0 < 0.3 * zero0 and real0 < 0.3 * shuf0
    verdict = "PASS ✅ (z_rl carries the prefix latent)" if ok else "FAIL ❌"
    print("SELF-TEST:", verdict)