# Source path: arc-agi-3-grid-jepa/src/models/grid_jepa.py
# Hub page header: avatar "guychuk", commit 0bec89e (verified)
# Commit message: "Add full Grid-JEPA system"
"""Complete Grid-JEPA system."""
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
try:
from .encoder import GridPatchEmbed, ViTEncoder
from .predictor import DiscreteActionEmbed, ActionConditionedPredictor
except ImportError:
from encoder import GridPatchEmbed, ViTEncoder
from predictor import DiscreteActionEmbed, ActionConditionedPredictor
class GridJEPA(nn.Module):
    """Joint-embedding predictive model over discrete grids.

    The model holds a shared patch embedding, a trainable context encoder,
    a frozen target encoder that starts as a copy of the context encoder and
    is moved towards it by :meth:`update_ema`, an action embedding, and an
    action-conditioned predictor that regresses target-encoder features at
    masked positions from the context-encoder features.

    Args:
        num_colors: Passed to ``GridPatchEmbed`` (presumably the number of
            distinct cell values -- confirm against the encoder module).
        embed_dim: Feature width shared by the embedding, both encoders and
            the predictor.
        encoder_depth: Depth argument for both ``ViTEncoder`` instances.
        predictor_depth: Depth argument for the predictor.
        num_heads: Head count passed to the encoders and the predictor.
        mlp_ratio: MLP ratio passed to the encoders.
        max_grid_size: Grid side length; ``num_patches`` is its square.
        ema_decay: Decay used by :meth:`update_ema`.
        num_key_actions: Passed to ``DiscreteActionEmbed``.
        action_embed_dim: Width of the action embedding handed to the
            predictor; also the width of the all-zero fallback embedding.
        use_action_conditioning: When False, ``forward`` always uses the
            all-zero action embedding.
    """

    def __init__(self, num_colors=16, embed_dim=384, encoder_depth=12, predictor_depth=12,
                 num_heads=6, mlp_ratio=4.0, max_grid_size=64, ema_decay=0.996,
                 num_key_actions=6, action_embed_dim=64, use_action_conditioning=True):
        super().__init__()
        self.num_colors = num_colors
        self.embed_dim = embed_dim
        self.max_grid_size = max_grid_size
        self.num_patches = max_grid_size * max_grid_size
        self.use_action_conditioning = use_action_conditioning
        self.ema_decay = ema_decay
        # Kept so forward() can size the "no action" embedding correctly
        # instead of assuming the default width.
        self.action_embed_dim = action_embed_dim

        self.patch_embed = GridPatchEmbed(num_colors, embed_dim, max_grid_size)
        self.context_encoder = ViTEncoder(embed_dim, encoder_depth, num_heads, mlp_ratio)
        self.target_encoder = ViTEncoder(embed_dim, encoder_depth, num_heads, mlp_ratio)

        # The target encoder starts as an exact copy and is never trained by
        # backprop; it only changes through update_ema().
        self.target_encoder.load_state_dict(self.context_encoder.state_dict())
        for p in self.target_encoder.parameters():
            p.requires_grad = False

        self.action_embed = DiscreteActionEmbed(num_key_actions, max_grid_size, action_embed_dim)
        self.predictor = ActionConditionedPredictor(
            self.num_patches, embed_dim, action_embed_dim, predictor_depth, num_heads
        )

    def forward(self, grid, context_mask, target_mask, action_key=None, action_pos=None):
        """Compute the masked feature-prediction loss for a batch.

        Args:
            grid: Batch of grids fed to ``patch_embed``; the first dimension
                is the batch size.
            context_mask: Boolean mask over embedded positions; positions
                where it is False are zeroed before the context encoder.
                (Assumed shape ``(B, num_patches)`` -- confirm with caller.)
            target_mask: Boolean mask, indexed per sample, selecting the
                positions whose target features must be predicted.
            action_key: Optional action identifiers for ``action_embed``.
            action_pos: Optional action positions for ``action_embed``.

        Returns:
            ``(loss, metrics)``. ``loss`` is the mean over samples (that have
            at least one target) of the per-sample squared error, summed over
            the feature dimension and averaged over target positions.
            ``metrics`` holds ``"loss"`` (a Python float) and
            ``"num_targets"`` (total target count divided by batch size).
            If no sample has any target, returns a constant zero tensor
            (not attached to the autograd graph) and an empty dict.
        """
        batch_size = grid.shape[0]
        x = self.patch_embed(grid)

        # Hide everything outside the context from the context encoder.
        ctx_x = x.clone()
        ctx_x[~context_mask] = 0
        context_repr = self.context_encoder(ctx_x)

        # Targets come from the unmasked input and carry no gradient.
        with torch.no_grad():
            target_repr = self.target_encoder(x)

        target_indices = [
            target_mask[b].nonzero(as_tuple=True)[0] for b in range(batch_size)
        ]

        if self.use_action_conditioning and action_key is not None:
            action_emb = self.action_embed(action_key, action_pos)
        else:
            # Width must follow the configured action_embed_dim; a literal 64
            # breaks the predictor for any non-default configuration.
            action_emb = torch.zeros(batch_size, self.action_embed_dim, device=grid.device)

        # Samples may have different numbers of targets, so the predictor is
        # invoked one sample at a time.
        losses = []
        for b, tgt_idx in enumerate(target_indices):
            if tgt_idx.numel() == 0:
                continue
            pred = self.predictor(
                context_repr[b:b + 1], action_emb[b:b + 1], tgt_idx.unsqueeze(0)
            )
            tgt = target_repr[b, tgt_idx].unsqueeze(0)
            losses.append((pred - tgt).pow(2).sum(dim=-1).mean())

        if not losses:
            return torch.tensor(0.0, device=grid.device), {}

        total = torch.stack(losses).mean()
        metrics = {
            "loss": total.item(),
            "num_targets": sum(len(t) for t in target_indices) / batch_size,
        }
        return total, metrics

    def update_ema(self):
        """Move target-encoder parameters towards the context encoder.

        Each target parameter becomes
        ``ema_decay * target + (1 - ema_decay) * context``. Only parameters
        are updated; module buffers, if any, are left untouched.
        """
        with torch.no_grad():
            for pt, pc in zip(self.target_encoder.parameters(),
                              self.context_encoder.parameters()):
                pt.mul_(self.ema_decay).add_(pc.detach(), alpha=1 - self.ema_decay)

    def encode(self, grid):
        """Return context-encoder features for ``grid`` (gradients enabled)."""
        return self.context_encoder(self.patch_embed(grid))

    @torch.no_grad()
    def encode_target(self, grid):
        """Return target-encoder features for ``grid`` without gradients."""
        return self.target_encoder(self.patch_embed(grid))
if __name__ == "__main__":
    # Smoke test: build a small model, run one forward pass on random data
    # with a random half/half context-target split, then exercise the EMA
    # update and the encode() path.
    batch_size, grid_size, num_colors = 2, 10, 10
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Positional arguments map to: num_colors, embed_dim=128, encoder_depth=4,
    # predictor_depth=4, num_heads=4.
    model = GridJEPA(num_colors, 128, 4, 4, 4, max_grid_size=grid_size).to(device)
    grid = torch.randint(
        0, num_colors, (batch_size, grid_size, grid_size), device=device
    )

    num_cells = grid_size * grid_size
    half = num_cells // 2
    ctx_mask = torch.zeros(batch_size, num_cells, dtype=torch.bool, device=device)
    tgt_mask = torch.zeros(batch_size, num_cells, dtype=torch.bool, device=device)
    for row in range(batch_size):
        # Shuffle the cell indices; the first half is context, the rest target.
        order = list(range(num_cells))
        random.shuffle(order)
        ctx_mask[row, order[:half]] = True
        tgt_mask[row, order[half:]] = True

    action_key = torch.randint(0, 6, (batch_size,), device=device)
    action_pos = torch.randint(0, num_cells, (batch_size,), device=device)

    loss, metrics = model(grid, ctx_mask, tgt_mask, action_key, action_pos)
    print(f"Loss: {loss.item():.2f}, Metrics: {metrics}")

    model.update_ema()
    print("EMA OK")

    encoded = model.encode(grid)
    print(f"Encode: {encoded.shape}")