"""Grid-JEPA Encoder for ARC-AGI-3.""" import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F class GridPatchEmbed(nn.Module): """Embed 2D color grids into patch tokens.""" def __init__(self, num_colors=16, embed_dim=384, max_grid_size=64): super().__init__() self.num_colors = num_colors self.embed_dim = embed_dim self.num_patches = max_grid_size * max_grid_size self.color_embed = nn.Embedding(num_colors, embed_dim) self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim)) nn.init.trunc_normal_(self.pos_embed, std=0.02) def forward(self, grid): B, H, W = grid.shape x = self.color_embed(grid) x = x.reshape(B, H * W, self.embed_dim) x = x + self.pos_embed[:, :H * W] return x class MultiHeadAttention(nn.Module): def __init__(self, dim, num_heads=6, qkv_bias=True, dropout=0.0): super().__init__() assert dim % num_heads == 0 self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim, bias=qkv_bias) self.dropout = nn.Dropout(dropout) def forward(self, x, mask=None): B, N, D = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale if mask is not None: attn = attn.masked_fill(mask == 0, float("-inf")) attn = F.softmax(attn, dim=-1) attn = self.dropout(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, D) x = self.proj(x) x = self.dropout(x) return x class TransformerBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, dropout=0.0): super().__init__() self.norm1 = nn.LayerNorm(dim, eps=1e-6) self.attn = MultiHeadAttention(dim, num_heads, qkv_bias, dropout) self.norm2 = nn.LayerNorm(dim, eps=1e-6) mlp_hidden = int(dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(dim, mlp_hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden, dim), nn.Dropout(dropout), ) def forward(self, x, mask=None): x = x + self.attn(self.norm1(x), mask) x = x + self.mlp(self.norm2(x)) return x class ViTEncoder(nn.Module): def __init__(self, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.0): super().__init__() self.blocks = nn.ModuleList([ TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth) ]) self.norm = nn.LayerNorm(embed_dim, eps=1e-6) def forward(self, x, mask=None): for block in self.blocks: x = block(x, mask) return self.norm(x) class GridJEPAEncoder(nn.Module): def __init__(self, num_colors=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.0, max_grid_size=64): super().__init__() self.patch_embed = GridPatchEmbed(num_colors, embed_dim, max_grid_size) self.encoder = ViTEncoder(embed_dim, depth, num_heads, mlp_ratio) self.embed_dim = embed_dim self.num_patches = max_grid_size * max_grid_size def forward(self, grid, mask=None): x = self.patch_embed(grid) return self.encoder(x, mask) class EMATargetEncoder(nn.Module): def __init__(self, context_encoder, ema_decay=0.996): super().__init__() self.ema_decay = ema_decay self.encoder = ViTEncoder( embed_dim=context_encoder.blocks[0].attn.head_dim * context_encoder.blocks[0].attn.num_heads, depth=len(context_encoder.blocks), num_heads=context_encoder.blocks[0].attn.num_heads, ) self.encoder.load_state_dict(context_encoder.state_dict()) for p in self.encoder.parameters(): p.requires_grad = False def update(self, context_encoder): with torch.no_grad(): for pt, pc in zip(self.encoder.parameters(), context_encoder.parameters()): pt.data.mul_(self.ema_decay).add_(pc.data, alpha=1 - self.ema_decay) def forward(self, x, mask=None): return self.encoder(x, mask) def build_encoders(num_colors=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.0, max_grid_size=64, ema_decay=0.996): ctx = GridJEPAEncoder(num_colors, embed_dim, depth, num_heads, mlp_ratio, max_grid_size) tgt = EMATargetEncoder(ctx.encoder, ema_decay) return ctx, tgt if __name__ == "__main__": grid = torch.randint(0, 10, (2, 10, 10)) enc, tgt = build_encoders(num_colors=10, embed_dim=128, depth=4, num_heads=4, max_grid_size=10) out = enc(grid) print(f"Encoder: {out.shape}") with torch.no_grad(): print(f"Target: {tgt(enc.patch_embed(grid)).shape}") tgt.update(enc.encoder) print("EMA OK")