arc-agi-3-grid-jepa / tests /test_integration.py
guychuk's picture
Add integration test
fd16635 verified
"""Integration test for Grid-JEPA."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
import torch
import numpy as np
from data.arc_dataset import ARCDataset, collate_fn
from data.masking import GridMaskingStrategy
from models.grid_jepa import GridJEPA
from torch.utils.data import DataLoader
def _write_synthetic_tasks(root):
    """Write one tiny ARC-style task into ``training/`` and ``evaluation/`` under *root*.

    Each split directory receives the same ``task.json`` containing a single
    3x3 train pair and a single 3x3 test pair.
    """
    import json

    pair = {
        "input": [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
        "output": [[0, 0, 0], [0, 2, 0], [0, 0, 0]],
    }
    task = {"train": [pair], "test": [pair]}
    for split in ("training", "evaluation"):
        split_dir = root / split
        split_dir.mkdir(parents=True, exist_ok=True)
        with open(split_dir / "task.json", "w") as f:
            json.dump(task, f)


def _masked_forward_loss(model, batch, mask_strategy, grid_size, device):
    """Run one masked forward pass over *batch* and return the loss the model reports.

    Masks are sampled fresh from *mask_strategy* for every call, and the two
    integer conditioning tensors are drawn at random (values in ``[0, 6)`` and
    ``[0, grid_size * grid_size)`` respectively).
    NOTE(review): these look like action ids and flattened grid positions --
    confirm against ``GridJEPA.forward``.
    """
    grid = batch["input_grids"].to(device)
    batch_size = grid.shape[0]
    masks = mask_strategy.sample_masks_batch(batch_size)
    context_mask = masks["context_mask"].to(device)
    target_mask = masks["target_mask"].to(device)
    actions = torch.randint(0, 6, (batch_size,), device=device)
    positions = torch.randint(0, grid_size * grid_size, (batch_size,), device=device)
    # The model returns a pair; only the first element (the loss) is used here.
    loss, _ = model(grid, context_mask, target_mask, actions, positions)
    return loss


def test():
    """Train and validate Grid-JEPA for two epochs on a synthetic one-task dataset.

    The fixture data lives in a private temporary directory that is removed
    even when the test fails, so an existing ``./test_data`` directory in the
    working directory is never touched.

    Raises:
        AssertionError: if a data loader yields no batches or a loss is not
            finite.
    """
    import math
    import tempfile

    B, G, C = 4, 10, 10
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # TemporaryDirectory guarantees cleanup on every exit path and avoids
    # clobbering (and later deleting) a user-owned directory.
    with tempfile.TemporaryDirectory(prefix="grid_jepa_test_") as tmp:
        d = Path(tmp)
        _write_synthetic_tasks(d)

        train_ds = ARCDataset(d, "training", G, C, augment=True)
        val_ds = ARCDataset(d, "evaluation", G, C, augment=False)
        train_loader = DataLoader(train_ds, batch_size=B, shuffle=True, collate_fn=collate_fn)
        val_loader = DataLoader(val_ds, batch_size=B, shuffle=False, collate_fn=collate_fn)
        print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

        model = GridJEPA(C, 128, 4, 4, 4, max_grid_size=G).to(device)
        # Only these sub-modules are handed to the optimizer; the rest of the
        # model is advanced through model.update_ema() after each step.
        params = (
            list(model.patch_embed.parameters())
            + list(model.context_encoder.parameters())
            + list(model.predictor.parameters())
        )
        if model.use_action_conditioning:
            params += list(model.action_embed.parameters())
        optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.04)
        mask_strategy = GridMaskingStrategy(G, 2, (0.1, 0.3))

        for epoch in range(1, 3):
            model.train()
            epoch_loss = 0.0
            nb = 0
            for batch in train_loader:
                loss = _masked_forward_loss(model, batch, mask_strategy, G, device)
                loss_value = loss.item()
                assert math.isfinite(loss_value), f"non-finite train loss: {loss_value}"
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                model.update_ema()
                epoch_loss += loss_value
                nb += 1
            # Fail with a clear message instead of a ZeroDivisionError below.
            assert nb > 0, "training loader produced no batches"
            avg = epoch_loss / nb

            model.eval()
            vl = 0.0
            vb = 0
            with torch.no_grad():
                for batch in val_loader:
                    loss = _masked_forward_loss(model, batch, mask_strategy, G, device)
                    loss_value = loss.item()
                    assert math.isfinite(loss_value), f"non-finite val loss: {loss_value}"
                    vl += loss_value
                    vb += 1
            assert vb > 0, "validation loader produced no batches"
            print(f"Epoch {epoch}: train={avg:.2f}, val={vl/vb:.2f}")

    print("\nIntegration test PASSED!")
# Script entry point: lets the integration test be run directly with
# `python test_integration.py`, without going through a test runner.
if __name__ == "__main__":
    test()