# File size: 5,115 Bytes
# 7fc7772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Grid-JEPA Encoder for ARC-AGI-3."""
import copy
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

class GridPatchEmbed(nn.Module):
    """Embed 2D color grids into patch tokens.

    Every grid cell becomes one token: the learned embedding of its color
    index plus a learned positional embedding.

    Args:
        num_colors: Number of rows in the color embedding table; cell values
            must be valid indices into it.
        embed_dim: Width of each produced token.
        max_grid_size: Side length used to size the positional table, which
            holds ``max_grid_size * max_grid_size`` entries.
    """

    def __init__(self, num_colors=16, embed_dim=384, max_grid_size=64):
        super().__init__()
        self.num_colors = num_colors
        self.embed_dim = embed_dim
        self.num_patches = max_grid_size * max_grid_size
        self.color_embed = nn.Embedding(num_colors, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, grid):
        """Turn a batch of grids into a token sequence.

        Args:
            grid: Integer index tensor of shape ``(B, H, W)``.

        Returns:
            Tensor of shape ``(B, H * W, embed_dim)``.

        Raises:
            ValueError: If the grid has more cells than the positional table
                has entries.
        """
        B, H, W = grid.shape
        num_tokens = H * W
        # Fail early with a readable message instead of the opaque broadcasting
        # error the positional add would otherwise raise.
        if num_tokens > self.num_patches:
            raise ValueError(
                f"grid of shape {H}x{W} has {num_tokens} cells, but the "
                f"positional table only holds {self.num_patches} entries"
            )
        x = self.color_embed(grid)
        x = x.reshape(B, num_tokens, self.embed_dim)
        # Positions are looked up by row-major flat index (row * W + col), so
        # the positional row assigned to a given cell depends on the grid width.
        x = x + self.pos_embed[:, :num_tokens]
        return x

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Width of the input and output tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection and the output projection
            carry a bias term.
        dropout: Dropout probability applied to the attention weights and to
            the projected output.
    """

    def __init__(self, dim, num_heads=6, qkv_bias=True, dropout=0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention to a token sequence.

        Args:
            x: Tensor of shape ``(B, N, D)``.
            mask: Optional tensor broadcastable to the attention scores of
                shape ``(B, num_heads, N, N)``. Entries equal to 0 mark
                query/key pairs that must not attend.

        Returns:
            Tensor of shape ``(B, N, D)``. A query whose keys are all masked
            receives zero attention output (only the projection bias) rather
            than NaN.
        """
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # -> (3, B, heads, N, head_dim) so q, k, v can be split on dim 0.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is None:
            attn = F.softmax(attn, dim=-1)
        else:
            blocked = mask == 0
            attn = attn.masked_fill(blocked, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            # A row that is entirely -inf makes softmax return NaN, and since
            # 0 * NaN is NaN this would contaminate every token in later
            # layers. Give such rows zero attention weight instead.
            attn = attn.masked_fill(blocked.all(dim=-1, keepdim=True), 0.0)
        attn = self.dropout(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer.

    Applies LayerNorm -> self-attention and LayerNorm -> MLP, each wrapped in
    a residual connection.

    Args:
        dim: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Forwarded to the attention module.
        dropout: Dropout probability used in attention and in the MLP.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, dropout=0.0):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)
        # Submodules are created in the same order as before so parameter
        # ordering and random initialisation stay unchanged.
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(dim, num_heads, qkv_bias, dropout)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        feed_forward = [
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x, mask=None):
        """Run the attention and MLP sub-layers; ``mask`` goes to attention only."""
        attended = self.attn(self.norm1(x), mask)
        x = x + attended
        return x + self.mlp(self.norm2(x))

class ViTEncoder(nn.Module):
    """Stack of transformer blocks followed by a final LayerNorm.

    Args:
        embed_dim: Token width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
    """

    def __init__(self, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.0):
        super().__init__()
        layers = nn.ModuleList()
        for _ in range(depth):
            layers.append(TransformerBlock(embed_dim, num_heads, mlp_ratio))
        self.blocks = layers
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x, mask=None):
        """Pass tokens through every block in order, then normalise."""
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, mask)
        return self.norm(hidden)

class GridJEPAEncoder(nn.Module):
    """Grid encoder: per-cell embedding followed by a transformer trunk.

    Args:
        num_colors: Forwarded to the patch embedding.
        embed_dim: Token width shared by the embedding and the trunk.
        depth: Number of transformer blocks in the trunk.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        max_grid_size: Forwarded to the patch embedding; also used to derive
            ``num_patches``.
    """

    def __init__(self, num_colors=16, embed_dim=384, depth=12, num_heads=6,
                 mlp_ratio=4.0, max_grid_size=64):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_patches = max_grid_size * max_grid_size
        # Embedding first, trunk second: keeps parameter order and random
        # initialisation identical.
        self.patch_embed = GridPatchEmbed(num_colors, embed_dim, max_grid_size)
        self.encoder = ViTEncoder(embed_dim, depth, num_heads, mlp_ratio)

    def forward(self, grid, mask=None):
        """Embed ``grid`` into tokens and encode them; ``mask`` goes to the trunk."""
        tokens = self.patch_embed(grid)
        encoded = self.encoder(tokens, mask)
        return encoded

class EMATargetEncoder(nn.Module):
    """Frozen copy of an encoder whose weights track it by exponential moving average.

    Args:
        context_encoder: Module to mirror. It is deep-copied, so the target
            starts with the same architecture and the same weights whatever
            hyper-parameters the context encoder was built with.
        ema_decay: Fraction of the old target weight kept on each update.
    """

    def __init__(self, context_encoder, ema_decay=0.996):
        super().__init__()
        self.ema_decay = ema_decay
        # Deep-copying instead of re-instantiating avoids having to recover
        # every constructor argument (e.g. the MLP ratio) from the live module;
        # a mismatch there would make the weight copy fail.
        self.encoder = copy.deepcopy(context_encoder)
        for p in self.encoder.parameters():
            p.requires_grad_(False)

    def update(self, context_encoder):
        """Blend the context encoder's parameters into the target in place.

        Each target parameter becomes
        ``ema_decay * target + (1 - ema_decay) * context``. Parameters are
        paired by iteration order; buffers are not touched.
        """
        keep = self.ema_decay
        with torch.no_grad():
            for pt, pc in zip(self.encoder.parameters(), context_encoder.parameters()):
                pt.mul_(keep).add_(pc.detach(), alpha=1 - keep)

    def forward(self, x, mask=None):
        """Run the wrapped encoder on ``x`` with the optional ``mask``."""
        return self.encoder(x, mask)

def build_encoders(num_colors=16, embed_dim=384, depth=12, num_heads=6,
                   mlp_ratio=4.0, max_grid_size=64, ema_decay=0.996):
    """Create a context encoder and its EMA target.

    The target wraps only the context encoder's transformer trunk
    (``.encoder``), not its patch embedding.

    Returns:
        A ``(context_encoder, target_encoder)`` tuple.
    """
    context = GridJEPAEncoder(
        num_colors=num_colors,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        max_grid_size=max_grid_size,
    )
    target = EMATargetEncoder(context.encoder, ema_decay=ema_decay)
    return context, target

if __name__ == "__main__":
    # Smoke test: encode a random batch, run the target on the same tokens,
    # then perform one EMA step.
    batch, side, colors = 2, 10, 10
    grid = torch.randint(0, colors, (batch, side, side))
    enc, tgt = build_encoders(
        num_colors=colors, embed_dim=128, depth=4, num_heads=4, max_grid_size=side,
    )
    out = enc(grid)
    print(f"Encoder: {out.shape}")
    with torch.no_grad():
        # The target holds only the transformer trunk, so it is fed tokens
        # that the context encoder's patch embedding has already produced.
        tokens = enc.patch_embed(grid)
        target_out = tgt(tokens)
    print(f"Target: {target_out.shape}")
    tgt.update(enc.encoder)
    print("EMA OK")