# Source: arc-agi-3-grid-jepa, src/models/predictor.py
# Commit e332759 (verified) by guychuk: "Add action-conditioned predictor module"
"""Action-Conditioned Predictor for Grid-JEPA."""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class DiscreteActionEmbed(nn.Module):
    """Embed a discrete action as a key embedding plus an optional cell-position embedding.

    The grid is assumed square, giving ``grid_size * grid_size`` addressable
    cell positions, each with its own learned vector.

    Args:
        num_key_actions: Size of the key-action vocabulary.
        grid_size: Side length of the grid.
        action_embed_dim: Dimension of every returned embedding.
    """

    def __init__(self, num_key_actions=6, grid_size=64, action_embed_dim=64):
        super().__init__()
        self.num_key_actions = num_key_actions
        self.grid_size = grid_size
        self.num_positions = grid_size ** 2
        # Key table is registered before the position table so that seeded
        # initialisation and state_dict ordering stay stable.
        self.key_embed = nn.Embedding(num_key_actions, action_embed_dim)
        self.pos_embed = nn.Embedding(self.num_positions, action_embed_dim)

    def forward(self, key_action, cell_position=None):
        """Look up the action embedding.

        Args:
            key_action: Integer tensor of key-action indices.
            cell_position: Optional integer tensor of flat cell indices, added
                element-wise to the key embedding when given.

        Returns:
            Tensor of shape ``key_action.shape + (action_embed_dim,)``.
        """
        key_vec = self.key_embed(key_action)
        if cell_position is None:
            return key_vec
        return key_vec + self.pos_embed(cell_position)
class PredictorBlock(nn.Module):
    """Pre-norm transformer block: multi-head self-attention followed by an MLP.

    Both sub-layers are residual: ``x + proj(attn(norm1(x)))`` and then
    ``x + mlp(norm2(x))``. Attention is unmasked, so every token attends to
    every token in the sequence.

    Args:
        dim: Token embedding dimension.
        num_heads: Number of attention heads; must be positive and divide ``dim``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        # Fail early with a clear message instead of an opaque reshape error in
        # forward() (or a ZeroDivisionError for num_heads == 0).
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        # Submodule registration order is unchanged so existing checkpoints load.
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        """Apply attention and MLP sub-layers.

        Args:
            x: Tensor of shape (B, N, dim).

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, N, D = x.shape
        xn = self.norm1(x)
        # (B, N, 3*D) -> (3, B, heads, N, head_dim) so q/k/v split on dim 0.
        qkv = self.qkv(xn).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        # Merge heads back: (B, heads, N, head_dim) -> (B, N, D).
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = x + self.proj(out)
        x = x + self.mlp(self.norm2(x))
        return x
class ActionConditionedPredictor(nn.Module):
    """Predict representations at target positions from context tokens and an action.

    One query token is built per target position as
    ``action_mixer(cat(mask_token + pos_embed[pos], action_emb))``. The query
    tokens are appended after the context tokens, the whole sequence is run
    through ``predictor_depth`` transformer blocks plus a final LayerNorm, and
    only the query-token outputs are returned.

    Args:
        num_patches: Number of learnable positional embeddings (valid target
            positions are ``0 .. num_patches - 1``).
        embed_dim: Token dimension of context and predicted representations.
        action_embed_dim: Dimension of the incoming action embedding.
        predictor_depth: Number of transformer blocks.
        num_heads: Attention heads per block.
    """

    def __init__(self, num_patches=4096, embed_dim=384, action_embed_dim=64,
                 predictor_depth=12, num_heads=6):
        super().__init__()
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.action_mixer = nn.Linear(embed_dim + action_embed_dim, embed_dim)
        self.blocks = nn.ModuleList(
            [PredictorBlock(embed_dim, num_heads) for _ in range(predictor_depth)]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, context_repr, action_emb, target_positions):
        """Predict one representation per target position.

        Args:
            context_repr: Tensor of shape (B, N_ctx, embed_dim), used as-is. No
                positional embedding is added to it here (presumably the
                encoder already did; TODO confirm).
            action_emb: Tensor of shape (B, action_embed_dim).
            target_positions: Integer indices into ``pos_embed``, shaped
                (N_tgt,), (1, N_tgt), or (B, N_tgt). The 1-row forms are
                shared across the batch.

        Returns:
            Tensor of shape (B, N_tgt, embed_dim). If N_tgt == 0 the result is
            empty along dim 1.
        """
        B, N_ctx = context_repr.shape[0], context_repr.shape[1]
        if target_positions.dim() == 1:
            target_positions = target_positions.unsqueeze(0)
        # (1, N_tgt) broadcasts across the batch; (B, N_tgt) passes through.
        tgt_pos = target_positions.expand(B, -1)
        N_tgt = tgt_pos.shape[1]
        # Direct indexing: (num_patches, D)[(B, N_tgt)] -> (B, N_tgt, D).
        tgt_pos_emb = self.pos_embed[0][tgt_pos]
        mask_tokens = self.mask_token.expand(B, N_tgt, -1) + tgt_pos_emb
        act_tokens = action_emb.unsqueeze(1).expand(-1, N_tgt, -1)
        tokens = self.action_mixer(torch.cat([mask_tokens, act_tokens], dim=-1))
        x = torch.cat([context_repr, tokens], dim=1)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Slice by context length. The old x[:, -N_tgt:] returned every token
        # (context included) when N_tgt == 0, because -0 == 0.
        return x[:, N_ctx:, :]
class GridWorldPredictor(nn.Module):
    """Action-conditioned next-state model over a square grid.

    Combines three parts:

    - a ``DiscreteActionEmbed`` for actions;
    - an ``ActionConditionedPredictor`` that predicts a representation for every
      one of the ``grid_size * grid_size`` positions;
    - a small MLP decoder mapping each representation to ``num_colors`` scores
      (presumably per-colour logits; TODO confirm against the training loss).

    Args:
        num_colors: Output width of the decoder.
        embed_dim: Representation dimension.
        predictor_depth: Number of transformer blocks in the predictor.
        num_heads: Attention heads per predictor block.
        num_key_actions: Size of the key-action vocabulary.
        grid_size: Side length of the grid.
        action_embed_dim: Dimension of the action embedding.
    """

    def __init__(self, num_colors=16, embed_dim=384, predictor_depth=12, num_heads=6,
                 num_key_actions=6, grid_size=64, action_embed_dim=64):
        super().__init__()
        self.num_patches = grid_size * grid_size
        # Submodules are registered in the same order as before (action_embed,
        # predictor, decoder) so state_dict layout is unchanged.
        self.action_embed = DiscreteActionEmbed(
            num_key_actions=num_key_actions,
            grid_size=grid_size,
            action_embed_dim=action_embed_dim,
        )
        self.predictor = ActionConditionedPredictor(
            num_patches=self.num_patches,
            embed_dim=embed_dim,
            action_embed_dim=action_embed_dim,
            predictor_depth=predictor_depth,
            num_heads=num_heads,
        )
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, num_colors),
        )

    def predict_next_state(self, current_repr, action_key, action_pos=None):
        """Predict representations at all grid positions after an action.

        Args:
            current_repr: Context tokens of shape (B, N_ctx, embed_dim).
            action_key: Key-action indices of shape (B,).
            action_pos: Optional flat cell indices of shape (B,).

        Returns:
            Tensor of shape (B, grid_size * grid_size, embed_dim).
        """
        action_vec = self.action_embed(action_key, action_pos)
        batch = current_repr.shape[0]
        every_position = torch.arange(self.num_patches, device=current_repr.device)
        every_position = every_position.unsqueeze(0).expand(batch, -1)
        return self.predictor(current_repr, action_vec, every_position)

    def decode_to_grid(self, repr):
        """Map representations (..., embed_dim) to scores (..., num_colors).

        The parameter name shadows the builtin ``repr``; it is kept as-is for
        keyword-argument compatibility with existing callers.
        """
        return self.decoder(repr)
if __name__ == "__main__":
    # Smoke test on a tiny 10x10 grid (100 positions) with random inputs;
    # prints output shapes only.
    torch.manual_seed(0)
    B, N_ctx, D, N_tgt = 2, 60, 128, 40
    context = torch.randn(B, N_ctx, D)
    action_key = torch.randint(0, 6, (B,))
    action_pos = torch.randint(0, 100, (B,))
    target_pos = torch.randint(0, 100, (N_tgt,))
    # No gradients are needed for a shape check.
    with torch.no_grad():
        embed = DiscreteActionEmbed(6, 10, 64)
        action_emb = embed(action_key, action_pos)
        print(f"Action emb: {action_emb.shape}")
        pred = ActionConditionedPredictor(100, D, 64, 4, 4)
        print(f"Pred: {pred(context, action_emb, target_pos).shape}")
        wp = GridWorldPredictor(10, D, 4, 4, 6, 10, 64)
        # Compute the next state once and reuse it for decoding; previously
        # this ran a second full predictor forward pass.
        next_state = wp.predict_next_state(context, action_key, action_pos)
        print(f"Next state: {next_state.shape}")
        print(f"Decode: {wp.decode_to_grid(next_state).shape}")