| """Grid-JEPA Encoder for ARC-AGI-3.""" |
| import math |
| from typing import Optional, Tuple |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class GridPatchEmbed(nn.Module):
    """Embed a batch of 2D color-index grids as one token per grid cell.

    Each cell's color index is looked up in a learned embedding table and a
    learned positional embedding is added to it.

    Args:
        num_colors: Number of rows in the color embedding table.
        embed_dim: Width of every token vector.
        max_grid_size: Side length used to size the positional table, which
            holds ``max_grid_size * max_grid_size`` entries.
    """

    def __init__(self, num_colors=16, embed_dim=384, max_grid_size=64):
        super().__init__()
        self.num_colors = num_colors
        self.embed_dim = embed_dim
        self.num_patches = max_grid_size * max_grid_size
        self.color_embed = nn.Embedding(num_colors, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, grid):
        """Return cell tokens of shape ``(B, H * W, embed_dim)``.

        Args:
            grid: Integer tensor of shape ``(B, H, W)`` holding color indices
                that are valid rows of the color embedding table.

        Raises:
            ValueError: If ``grid`` is not 3-dimensional, or if it has more
                cells than the positional table has entries.
        """
        if grid.dim() != 3:
            raise ValueError(
                f"expected grid of shape (B, H, W), got {tuple(grid.shape)}"
            )
        B, H, W = grid.shape
        num_tokens = H * W
        # Fail with a clear message instead of an opaque broadcasting error
        # at the positional add below.
        if num_tokens > self.num_patches:
            raise ValueError(
                f"grid of {H}x{W} = {num_tokens} cells exceeds the "
                f"{self.num_patches} available positional embeddings"
            )
        x = self.color_embed(grid).reshape(B, num_tokens, self.embed_dim)
        # NOTE: cells are flattened row-major and the first H * W table
        # entries are used, so the entry paired with cell (r, c) is
        # r * W + c and therefore depends on the width of this grid.
        return x + self.pos_embed[:, :num_tokens]
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Token width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias. The same
            flag also controls the bias of the output projection.
        dropout: Dropout probability, applied both to the attention weights
            and to the projected output.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=6, qkv_bias=True, dropout=0.0):
        super().__init__()
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, dim)``.
            mask: Optional tensor that must broadcast against the
                ``(B, num_heads, N, N)`` attention scores. Entries equal to
                0 are blocked; every other value is attended to.

        Returns:
            Tensor of shape ``(B, N, dim)``. A query whose mask row is
            entirely 0 receives zero attention weights (instead of NaN), so
            its output before projection dropout is ``proj`` applied to a
            zero vector.
        """
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # -> (3, B, heads, N, head_dim), then split into q / k / v.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            blocked = mask == 0
            # Rows with every key blocked would turn into an all -inf row and
            # softmax would emit NaN (in the output and in the gradients).
            # Leave those rows unmasked for the softmax, then zero them.
            dead_rows = blocked.all(dim=-1, keepdim=True)
            attn = attn.masked_fill(blocked & ~dead_rows, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = attn.masked_fill(dead_rows, 0.0)
        else:
            attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        # (B, heads, N, head_dim) -> (B, N, heads * head_dim)
        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)
        x = self.dropout(x)
        return x
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and an MLP, each added residually.

    Args:
        dim: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``
            (truncated to an int).
        qkv_bias: Forwarded to ``MultiHeadAttention``.
        dropout: Dropout probability used in the attention module and after
            each MLP linear stage.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, dropout=0.0):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(dim, num_heads, qkv_bias, dropout)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        feed_forward = [
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x, mask=None):
        """Return ``x`` after the attention and MLP residual updates.

        ``mask`` is handed unchanged to the attention module.
        """
        attended = self.attn(self.norm1(x), mask)
        x = x + attended
        expanded = self.mlp(self.norm2(x))
        return x + expanded
|
|
class ViTEncoder(nn.Module):
    """Stack of ``depth`` transformer blocks followed by a final LayerNorm.

    Args:
        embed_dim: Token width.
        depth: Number of ``TransformerBlock`` layers.
        num_heads: Attention heads per block.
        mlp_ratio: MLP expansion ratio passed to every block.
    """

    def __init__(self, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.0):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.append(TransformerBlock(embed_dim, num_heads, mlp_ratio))
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x, mask=None):
        """Run ``x`` through every block in order, then normalize.

        The same ``mask`` is passed to each block.
        """
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, mask)
        return self.norm(hidden)
|
|
class GridJEPAEncoder(nn.Module):
    """Context encoder: grid patch embedding followed by a ViT encoder.

    Args:
        num_colors: Color vocabulary size for the patch embedding.
        embed_dim: Token width shared by the embedding and the encoder.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP expansion ratio of each block.
        max_grid_size: Side length used to size the positional table.
    """

    def __init__(self, num_colors=16, embed_dim=384, depth=12, num_heads=6,
                 mlp_ratio=4.0, max_grid_size=64):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_patches = max_grid_size * max_grid_size
        self.patch_embed = GridPatchEmbed(num_colors, embed_dim, max_grid_size)
        self.encoder = ViTEncoder(embed_dim, depth, num_heads, mlp_ratio)

    def forward(self, grid, mask=None):
        """Embed ``grid`` into tokens and encode them; ``mask`` goes to the encoder."""
        tokens = self.patch_embed(grid)
        encoded = self.encoder(tokens, mask)
        return encoded
|
|
class EMATargetEncoder(nn.Module):
    """Frozen exponential-moving-average (EMA) copy of a context encoder.

    The target starts as an exact copy of ``context_encoder`` with gradients
    disabled, and is moved towards the live encoder by calling ``update``.

    Args:
        context_encoder: Module to mirror. Its ``forward`` must accept
            ``(x, mask)``, since ``forward`` here calls the copy that way.
        ema_decay: Fraction of the current target weights kept on each
            ``update``.
    """

    def __init__(self, context_encoder, ema_decay=0.996):
        # Local import keeps this block self-contained.
        import copy

        super().__init__()
        self.ema_decay = ema_decay
        # Deep-copy instead of rebuilding from a few inferred hyper-parameters:
        # a rebuild silently falls back to defaults for everything it does not
        # infer (e.g. mlp_ratio) and then fails to load the state dict.
        self.encoder = copy.deepcopy(context_encoder)
        for p in self.encoder.parameters():
            p.requires_grad_(False)

    def update(self, context_encoder):
        """Blend the target towards ``context_encoder`` in place.

        Each target parameter becomes
        ``ema_decay * target + (1 - ema_decay) * context``. Parameters are
        paired in ``parameters()`` order.
        """
        decay = self.ema_decay
        with torch.no_grad():
            for pt, pc in zip(self.encoder.parameters(), context_encoder.parameters()):
                pt.mul_(decay).add_(pc.detach(), alpha=1 - decay)

    def forward(self, x, mask=None):
        """Encode ``x`` with the EMA copy; ``mask`` is passed straight through."""
        return self.encoder(x, mask)
|
|
def build_encoders(num_colors=16, embed_dim=384, depth=12, num_heads=6,
                   mlp_ratio=4.0, max_grid_size=64, ema_decay=0.996):
    """Build the context encoder and its EMA target.

    Returns:
        A ``(context, target)`` pair: a ``GridJEPAEncoder`` and an
        ``EMATargetEncoder`` constructed from the context model's
        transformer stack (``context.encoder``) with ``ema_decay``.
    """
    context = GridJEPAEncoder(
        num_colors=num_colors,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        max_grid_size=max_grid_size,
    )
    target = EMATargetEncoder(context.encoder, ema_decay=ema_decay)
    return context, target
|
|
if __name__ == "__main__":
    # Smoke test: encode a random 10x10 grid batch with the context encoder,
    # run the target on the same embedded tokens, then do one EMA step.
    demo_grid = torch.randint(0, 10, (2, 10, 10))
    context_model, target_model = build_encoders(
        num_colors=10, embed_dim=128, depth=4, num_heads=4, max_grid_size=10
    )
    encoded = context_model(demo_grid)
    print(f"Encoder: {encoded.shape}")
    with torch.no_grad():
        tokens = context_model.patch_embed(demo_grid)
        target_out = target_model(tokens)
        print(f"Target: {target_out.shape}")
    target_model.update(context_model.encoder)
    print("EMA OK")
|
|