# Author: guychuk. Commit 7fc7772: "Add Grid-JEPA encoder module".
"""Grid-JEPA Encoder for ARC-AGI-3."""
import copy
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class GridPatchEmbed(nn.Module):
    """Embed 2D color grids into per-cell tokens with learned 2D positions.

    Each cell becomes one token: the embedding of its color plus a learned
    positional embedding for its (row, col). The positional table stores
    ``max_grid_size * max_grid_size`` vectors laid out row-major over a
    ``max_grid_size x max_grid_size`` canvas. A cell (r, c) therefore gets the
    same positional code regardless of the grid's height or width.

    Args:
        num_colors: size of the color vocabulary; grid values must lie in
            ``[0, num_colors)`` (``nn.Embedding`` rejects anything else).
        embed_dim: width of every token.
        max_grid_size: largest supported grid height and width.
    """

    def __init__(self, num_colors=16, embed_dim=384, max_grid_size=64):
        super().__init__()
        self.num_colors = num_colors
        self.embed_dim = embed_dim
        self.max_grid_size = max_grid_size
        self.num_patches = max_grid_size * max_grid_size
        self.color_embed = nn.Embedding(num_colors, embed_dim)
        # Kept flat (1, max*max, D) so existing checkpoints still load.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, grid):
        """Turn an integer grid ``(B, H, W)`` into tokens ``(B, H * W, embed_dim)``.

        Tokens are emitted in row-major order: cell (r, c) is token ``r * W + c``.

        Raises:
            ValueError: if ``H`` or ``W`` exceeds ``max_grid_size``.
        """
        B, H, W = grid.shape
        side = self.max_grid_size
        if H > side or W > side:
            raise ValueError(
                f"grid of size {H}x{W} exceeds max_grid_size={side}"
            )
        x = self.color_embed(grid).reshape(B, H * W, self.embed_dim)
        # Take the top-left H x W window of the 2D table. Taking the first
        # H*W flat entries instead would give cell (r, c) the code of canvas
        # cell divmod(r * W + c, side), which changes with the grid width.
        # For W == side both approaches select the same entries.
        pos = self.pos_embed.view(1, side, side, self.embed_dim)[:, :H, :W]
        return x + pos.reshape(1, H * W, self.embed_dim)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused qkv projection and the output projection
            carry a bias.
        dropout: rate applied both to the attention weights and to the output
            of the final projection.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=6, qkv_bias=True, dropout=0.0):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(B, N, dim)`` and return the same shape.

        Args:
            x: input tokens ``(B, N, dim)``.
            mask: optional tensor broadcastable to the score shape
                ``(B, num_heads, N, N)``. Query/key pairs where ``mask == 0``
                are excluded. A query with every key excluded attends to
                nothing, so its pre-projection output is zero.
        """
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # -> (3, B, heads, N, head_dim), then split into q, k, v.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            blocked = mask == 0
            attn = attn.masked_fill(blocked, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        if mask is not None:
            # A fully blocked row is softmax(-inf, ..., -inf) = NaN. That NaN
            # would then spread to other tokens in later layers (0 * NaN = NaN),
            # so zero the blocked entries. This leaves partially blocked rows
            # unchanged, since softmax already made them exactly 0.
            attn = attn.masked_fill(blocked, 0.0)
        attn = self.dropout(attn)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        return self.dropout(self.proj(out))
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: residual self-attention, then residual MLP.

    Args:
        dim: token width.
        num_heads: attention heads (forwarded to ``MultiHeadAttention``).
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: forwarded to ``MultiHeadAttention``.
        dropout: used inside attention and after each MLP stage.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(dim, num_heads, qkv_bias, dropout)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(dim * mlp_ratio)
        # Order matters: the Sequential indices (0 and 3 for the Linears)
        # form the state_dict keys.
        stages = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x, mask=None):
        """Apply both residual sub-layers; ``mask`` is passed to attention."""
        attended = self.attn(self.norm1(x), mask)
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed
class ViTEncoder(nn.Module):
    """Stack of ``depth`` pre-norm transformer blocks followed by a final LayerNorm.

    Args:
        embed_dim: token width.
        depth: number of ``TransformerBlock`` layers.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
    """

    def __init__(self, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.0):
        super().__init__()
        layers = [TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x, mask=None):
        """Run every block in order with the same ``mask``, then normalize."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, mask)
        return self.norm(hidden)
class GridJEPAEncoder(nn.Module):
    """Context encoder: embed a color grid into tokens, then run the ViT trunk.

    Args:
        num_colors: color vocabulary size for ``GridPatchEmbed``.
        embed_dim: token width.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        max_grid_size: largest grid side supported by the positional table.
    """

    def __init__(self, num_colors=16, embed_dim=384, depth=12, num_heads=6,
                 mlp_ratio=4.0, max_grid_size=64):
        super().__init__()
        # Submodules are registered patch_embed first, then encoder, as before,
        # so seeded initialisation and state_dict order are unchanged.
        self.patch_embed = GridPatchEmbed(num_colors, embed_dim, max_grid_size)
        self.encoder = ViTEncoder(embed_dim, depth, num_heads, mlp_ratio)
        self.embed_dim = embed_dim
        self.num_patches = max_grid_size ** 2

    def forward(self, grid, mask=None):
        """Embed ``grid`` and encode the tokens, forwarding ``mask`` to attention."""
        tokens = self.patch_embed(grid)
        encoded = self.encoder(tokens, mask)
        return encoded
class EMATargetEncoder(nn.Module):
    """Frozen copy of a context encoder that tracks it by exponential moving average.

    Args:
        context_encoder: module to mirror (``build_encoders`` passes the
            context model's transformer trunk). It is deep-copied, so the
            architecture, including ``mlp_ratio``, ``qkv_bias`` and the
            device, matches exactly.
        ema_decay: fraction of the current target weights kept on each
            ``update``; the remaining ``1 - ema_decay`` comes from the context.
    """

    def __init__(self, context_encoder, ema_decay=0.996):
        super().__init__()
        self.ema_decay = ema_decay
        # Rebuilding a ViTEncoder from a few attributes dropped mlp_ratio (and
        # any other non-default setting). load_state_dict then failed with a
        # size mismatch whenever mlp_ratio != 4.0. A deep copy cannot drift.
        self.encoder = copy.deepcopy(context_encoder)
        for p in self.encoder.parameters():
            p.requires_grad = False

    def update(self, context_encoder):
        """Blend context weights in place: ``t <- decay * t + (1 - decay) * c``.

        Raises:
            ValueError: if the two modules expose a different number of
                parameters (zip would otherwise silently skip the extras).
        """
        target_params = list(self.encoder.parameters())
        context_params = list(context_encoder.parameters())
        if len(target_params) != len(context_params):
            raise ValueError(
                f"parameter count mismatch: target has {len(target_params)}, "
                f"context has {len(context_params)}"
            )
        with torch.no_grad():
            for pt, pc in zip(target_params, context_params):
                pt.mul_(self.ema_decay).add_(pc.detach(), alpha=1 - self.ema_decay)

    def forward(self, x, mask=None):
        """Encode ``x`` with the frozen target weights.

        The parameters are frozen, but gradients can still flow back into
        ``x``. If ``x`` comes from a trainable module (e.g. the context patch
        embedding), call this under ``torch.no_grad()`` or pass a detached tensor.
        """
        return self.encoder(x, mask)
def build_encoders(num_colors=16, embed_dim=384, depth=12, num_heads=6,
                   mlp_ratio=4.0, max_grid_size=64, ema_decay=0.996):
    """Create a matched (context, target) encoder pair.

    Returns:
        ``(context, target)``: ``context`` is a ``GridJEPAEncoder``, and
        ``target`` is an ``EMATargetEncoder`` built from ``context.encoder``.
        The target mirrors only the transformer trunk; the patch embedding
        belongs to the context model alone.
    """
    context = GridJEPAEncoder(
        num_colors=num_colors,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        max_grid_size=max_grid_size,
    )
    target = EMATargetEncoder(context.encoder, ema_decay=ema_decay)
    return context, target
if __name__ == "__main__":
    # Smoke test on a small configuration: two random 10x10 grids over 10 colors.
    grid = torch.randint(0, 10, (2, 10, 10))
    enc, tgt = build_encoders(
        num_colors=10,
        embed_dim=128,
        depth=4,
        num_heads=4,
        max_grid_size=10,
    )

    # Context branch: grid -> tokens -> encoded tokens.
    out = enc(grid)
    print(f"Encoder: {out.shape}")

    # Target branch on the context's patch tokens, with no gradient tracking.
    with torch.no_grad():
        print(f"Target: {tgt(enc.patch_embed(grid)).shape}")

    # One EMA step pulling the target trunk toward the context trunk.
    tgt.update(enc.encoder)
    print("EMA OK")