atharva-pantheon commited on
Commit
f9042a0
·
verified ·
1 Parent(s): af282a6

Upload code/rl_token_encoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/rl_token_encoder.py +194 -0
code/rl_token_encoder.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """RL Token encoder-decoder for MolmoAct2 (RLT Stage 1) — PyTorch port.
2
+
3
+ Faithful port of openpi's ``pi0_rl.py`` (Xu et al. 2025, "RL Tokens") to PyTorch
4
+ for the frozen MolmoAct2 lerobot fork. Differences from my earlier
5
+ ``rlt_logit_autoencoder.py`` (which was wrong): that one MLP-reconstructed the
6
+ 2048-D action logits; THIS reconstructs the VLA's **per-token prefix hidden
7
+ states** ``(M, dim)`` with a transformer encoder + autoregressive decoder, so
8
+ the single ``z_rl`` token is forced to regenerate the whole prefix — the real
9
+ RLT bottleneck, and what todo Phase 3 specifies.
10
+
11
+ Design (matches the reference):
12
+ encoder: append a learned <rl> query to the prefix embeddings (b, M, dim),
13
+ run bidirectional pre-norm transformer blocks (RMSNorm + SwiGLU),
14
+ read the query position -> z_rl (b, dim).
15
+ decoder: autoregressive. input [z_rl, z̄_1 … z̄_{M-1}], causal mask,
16
+ predict [z̄_1 … z̄_M]; output_proj.
17
+ loss: per-token squared-L2 recon (sum over dim, masked mean over tokens),
18
+ targets stop-gradiented. VLA is a frozen server here, so there is no
19
+ L_vla term (alpha = 0): we only train the encoder/decoder.
20
+
21
+ z_rl is full-dim (= dim), exactly like the reference — the bottleneck is the
22
+ sequence compression (M tokens -> 1), not a narrow feature dim. Downstream SAC
23
+ consumes z_rl as its (frozen) RLT state.
24
+ """
25
+ from __future__ import annotations
26
+
27
+ import math
28
+ from dataclasses import dataclass
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+
35
+ @dataclass
36
+ class RLTokenConfig:
37
+ dim: int = 2560 # MolmoAct2 VLM hidden width (cached embeddings are 2560-D)
38
+ num_layers: int = 2
39
+ num_heads: int = 8 # 2560 / 8 = 320 head_dim
40
+ mlp_dim: int = 8192
41
+
42
+
43
+ class _Block(nn.Module):
44
+ """Pre-norm transformer block: MHA + SwiGLU FFN, RMSNorm. Matches the ref."""
45
+
46
+ def __init__(self, dim: int, num_heads: int, mlp_dim: int):
47
+ super().__init__()
48
+ assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
49
+ self.num_heads = num_heads
50
+ self.head_dim = dim // num_heads
51
+ self.attn_norm = nn.RMSNorm(dim)
52
+ self.q_proj = nn.Linear(dim, dim, bias=False)
53
+ self.k_proj = nn.Linear(dim, dim, bias=False)
54
+ self.v_proj = nn.Linear(dim, dim, bias=False)
55
+ self.o_proj = nn.Linear(dim, dim, bias=False)
56
+ self.ffn_norm = nn.RMSNorm(dim)
57
+ self.ffn_gate = nn.Linear(dim, mlp_dim, bias=False)
58
+ self.ffn_up = nn.Linear(dim, mlp_dim, bias=False)
59
+ self.ffn_down = nn.Linear(mlp_dim, dim, bias=False)
60
+
61
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None) -> torch.Tensor:
62
+ b, s, d = x.shape
63
+ h = self.attn_norm(x)
64
+ q = self.q_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2) # (b,nh,s,hd)
65
+ k = self.k_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
66
+ v = self.v_proj(h).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
67
+ # attn_mask: (b, s, s) bool, True = attend. -> (b,1,s,s) for SDPA additive.
68
+ am = None
69
+ if attn_mask is not None:
70
+ am = torch.zeros(b, 1, s, s, dtype=x.dtype, device=x.device)
71
+ am = am.masked_fill(~attn_mask[:, None, :, :], float("-inf"))
72
+ attn = F.scaled_dot_product_attention(q, k, v, attn_mask=am) # (b,nh,s,hd)
73
+ attn = attn.transpose(1, 2).reshape(b, s, d)
74
+ x = x + self.o_proj(attn)
75
+ h = self.ffn_norm(x)
76
+ x = x + self.ffn_down(F.silu(self.ffn_gate(h)) * self.ffn_up(h))
77
+ return x
78
+
79
+
80
+ class RLTokenEncoder(nn.Module):
81
+ """Compress prefix embeddings (b, M, dim) -> z_rl (b, dim) via a learned query."""
82
+
83
+ def __init__(self, cfg: RLTokenConfig):
84
+ super().__init__()
85
+ self.rl_query = nn.Parameter(torch.randn(1, 1, cfg.dim) * 0.02)
86
+ self.layers = nn.ModuleList(_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers))
87
+
88
+ def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
89
+ b, m, d = prefix.shape
90
+ query = self.rl_query.expand(b, 1, d)
91
+ x = torch.cat([prefix, query], dim=1) # (b, M+1, dim)
92
+ if mask is not None:
93
+ ext = torch.cat([mask, torch.ones(b, 1, dtype=torch.bool, device=mask.device)], dim=1)
94
+ attn_mask = ext[:, None, :] & ext[:, :, None] # (b, M+1, M+1) bidirectional
95
+ else:
96
+ attn_mask = None
97
+ for layer in self.layers:
98
+ x = layer(x, attn_mask)
99
+ return x[:, -1, :] # z_rl at the query position
100
+
101
+
102
+ class RLTokenDecoder(nn.Module):
103
+ """Autoregressively reconstruct prefix embeddings from z_rl."""
104
+
105
+ def __init__(self, cfg: RLTokenConfig):
106
+ super().__init__()
107
+ self.layers = nn.ModuleList(_Block(cfg.dim, cfg.num_heads, cfg.mlp_dim) for _ in range(cfg.num_layers))
108
+ self.output_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
109
+
110
+ def forward(self, z_rl: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None,
111
+ context_dropout: float = 0.0) -> torch.Tensor:
112
+ # input [z_rl, z̄_1..z̄_{M-1}] -> predict [z̄_1..z̄_M]
113
+ b, m, d = target.shape
114
+ ctx = target[:, :-1, :]
115
+ # Context dropout (train only): randomly zero teacher-forced context tokens
116
+ # so the decoder cannot reconstruct purely from the true-previous-token leak
117
+ # and is forced to route information through z_rl. Off (0.0) = bare reference.
118
+ if self.training and context_dropout > 0.0:
119
+ keep = (torch.rand(b, m - 1, 1, device=target.device) >= context_dropout).to(target.dtype)
120
+ ctx = ctx * keep
121
+ dec_in = torch.cat([z_rl[:, None, :], ctx], dim=1) # (b, M, dim)
122
+ causal = torch.tril(torch.ones(m, m, dtype=torch.bool, device=target.device))[None] # (1,M,M)
123
+ if mask is not None:
124
+ key_valid = torch.cat([torch.ones(b, 1, dtype=torch.bool, device=mask.device), mask[:, :-1]], dim=1)
125
+ attn_mask = causal & key_valid[:, None, :] # (b, M, M)
126
+ else:
127
+ attn_mask = causal.expand(b, m, m)
128
+ x = dec_in
129
+ for layer in self.layers:
130
+ x = layer(x, attn_mask)
131
+ return self.output_proj(x)
132
+
133
+
134
+ class RLTokenAutoencoder(nn.Module):
135
+ """Encoder + decoder. forward() returns (z_rl, recon_loss) for training."""
136
+
137
+ def __init__(self, cfg: RLTokenConfig | None = None):
138
+ super().__init__()
139
+ self.cfg = cfg or RLTokenConfig()
140
+ self.encoder = RLTokenEncoder(self.cfg)
141
+ self.decoder = RLTokenDecoder(self.cfg)
142
+
143
+ def encode(self, prefix: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
144
+ return self.encoder(prefix, mask)
145
+
146
+ def forward(self, prefix: torch.Tensor, mask: torch.Tensor | None = None, context_dropout: float = 0.0):
147
+ # Targets are stop-gradiented (frozen VLA features). detach() = jax.lax.stop_gradient.
148
+ target = prefix.detach()
149
+ z_rl = self.encoder(target, mask)
150
+ pred = self.decoder(z_rl, target, mask, context_dropout=context_dropout)
151
+ per_token = (pred - target).pow(2).sum(dim=-1) # (b, M) squared-L2 per token
152
+ if mask is not None:
153
+ per_token = per_token * mask
154
+ denom = mask.sum(dim=1).clamp(min=1)
155
+ recon = (per_token.sum(dim=1) / denom) # (b,)
156
+ else:
157
+ recon = per_token.mean(dim=1)
158
+ return z_rl, recon.mean()
159
+
160
+
161
+ if __name__ == "__main__":
162
+ # Self-test on COMPRESSIBLE data: each sequence is a per-sample latent c
163
+ # broadcast across positions + a small FIXED positional pattern. So one z_rl
164
+ # can capture c. Fair ablation = FIRST-token recon: position 0 sees ONLY
165
+ # z_rl (no AR context), so it isolates whether z_rl carries information.
166
+ torch.manual_seed(0)
167
+ cfg = RLTokenConfig(dim=64, num_layers=2, num_heads=4, mlp_dim=128) # tiny for CPU
168
+ ae = RLTokenAutoencoder(cfg)
169
+ opt = torch.optim.AdamW(ae.parameters(), lr=1e-3)
170
+ B, M = 32, 12
171
+ pos_pattern = torch.randn(M, cfg.dim) * 0.3 # fixed per-position offset
172
+ def batch():
173
+ c = torch.randn(B, cfg.dim) # per-sample latent
174
+ x = c[:, None, :] + pos_pattern[None] # (B, M, dim), compressible
175
+ return x, torch.ones(B, M, dtype=torch.bool)
176
+ for step in range(600):
177
+ x, mask = batch()
178
+ z, loss = ae(x, mask)
179
+ opt.zero_grad(); loss.backward(); opt.step()
180
+ if step % 150 == 0 or step == 599:
181
+ print(f"step {step:3d} recon={loss.item():.4f}")
182
+ ae.eval()
183
+ with torch.no_grad():
184
+ x, mask = batch()
185
+ z, _ = ae(x, mask)
186
+ def first_tok_err(zt):
187
+ pred = ae.decoder(zt, x, mask)
188
+ return (pred[:, 0] - x[:, 0]).pow(2).sum(-1).mean().item() # token-0 only
189
+ real0 = first_tok_err(z)
190
+ zero0 = first_tok_err(torch.zeros_like(z))
191
+ shuf0 = first_tok_err(z[torch.randperm(B)])
192
+ print(f"first-token recon: real={real0:.3f} zeroed={zero0:.3f} shuffled={shuf0:.3f}")
193
+ ok = real0 < 0.3 * zero0 and real0 < 0.3 * shuf0
194
+ print("SELF-TEST:", "PASS ✅ (z_rl carries the prefix latent)" if ok else "FAIL ❌")