"""Integration test for Grid-JEPA.

Builds a tiny synthetic ARC-style dataset in a private temporary directory,
trains GridJEPA for two epochs (with a validation pass each epoch), and checks
that every loader yields batches and every reported loss is finite.
"""

import json
import math
import sys
import tempfile
from pathlib import Path

# Make the project-local ``data`` and ``models`` packages imported below
# resolvable when this file is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

import numpy as np  # noqa: E402,F401  (kept from the original import set)
import torch  # noqa: E402
from torch.utils.data import DataLoader  # noqa: E402

from data.arc_dataset import ARCDataset, collate_fn  # noqa: E402
from data.masking import GridMaskingStrategy  # noqa: E402
from models.grid_jepa import GridJEPA  # noqa: E402

# One 3x3 task whose only change is the centre cell going from 1 to 2; the
# same task is used for both the training and the evaluation split.
_TOY_TASK = {
    "train": [
        {
            "input": [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
            "output": [[0, 0, 0], [0, 2, 0], [0, 0, 0]],
        }
    ],
    "test": [
        {
            "input": [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
            "output": [[0, 0, 0], [0, 2, 0], [0, 0, 0]],
        }
    ],
}


def _write_toy_dataset(root: Path) -> None:
    """Write ``_TOY_TASK`` as ``task.json`` under ``root/training`` and ``root/evaluation``."""
    for split in ("training", "evaluation"):
        split_dir = root / split
        split_dir.mkdir(parents=True, exist_ok=True)
        (split_dir / "task.json").write_text(json.dumps(_TOY_TASK), encoding="utf-8")


def _forward_loss(model, batch, mask_strategy, num_cells: int, device: str):
    """Run one GridJEPA forward pass on ``batch`` with freshly sampled inputs.

    Context/target masks come from ``mask_strategy``; the action id is drawn
    uniformly from ``[0, 6)`` and the position from ``[0, num_cells)``, one
    per batch element. Returns the loss tensor produced by ``model``.
    """
    grid = batch["input_grids"].to(device)
    batch_size = grid.shape[0]
    masks = mask_strategy.sample_masks_batch(batch_size)
    actions = torch.randint(0, 6, (batch_size,), device=device)
    positions = torch.randint(0, num_cells, (batch_size,), device=device)
    loss, _ = model(
        grid,
        masks["context_mask"].to(device),
        masks["target_mask"].to(device),
        actions,
        positions,
    )
    return loss


def test():
    """Train GridJEPA for two epochs on a toy dataset and validate each epoch.

    Raises:
        AssertionError: if a data loader yields no batches or an epoch's
            average train/validation loss is not finite.
    """
    # B: batch size; G: grid size (passed as ``max_grid_size``); C: presumably
    # the number of colours (ARC uses 10) — TODO confirm against GridJEPA.
    B, G, C = 4, 10, 10
    num_cells = G * G
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # A private temp dir cannot clobber a pre-existing ./test_data and is
    # removed even when an assertion or the model raises.
    with tempfile.TemporaryDirectory() as tmp:
        d = Path(tmp)
        _write_toy_dataset(d)

        train_ds = ARCDataset(d, "training", G, C, augment=True)
        val_ds = ARCDataset(d, "evaluation", G, C, augment=False)
        train_loader = DataLoader(train_ds, batch_size=B, shuffle=True, collate_fn=collate_fn)
        val_loader = DataLoader(val_ds, batch_size=B, shuffle=False, collate_fn=collate_fn)
        print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

        model = GridJEPA(C, 128, 4, 4, 4, max_grid_size=G).to(device)
        # Only these sub-modules are handed to the optimiser; update_ema() below
        # presumably refreshes the remaining (target) weights — TODO confirm.
        params = (
            list(model.patch_embed.parameters())
            + list(model.context_encoder.parameters())
            + list(model.predictor.parameters())
        )
        if model.use_action_conditioning:
            params += list(model.action_embed.parameters())
        optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.04)
        mask_strategy = GridMaskingStrategy(G, 2, (0.1, 0.3))

        for epoch in range(1, 3):
            model.train()
            train_total, train_batches = 0.0, 0
            for batch in train_loader:
                loss = _forward_loss(model, batch, mask_strategy, num_cells, device)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                model.update_ema()
                train_total += loss.item()
                train_batches += 1
            assert train_batches > 0, "training loader produced no batches"
            train_avg = train_total / train_batches
            assert math.isfinite(train_avg), f"non-finite train loss: {train_avg}"

            model.eval()
            val_total, val_batches = 0.0, 0
            with torch.no_grad():
                for batch in val_loader:
                    val_total += _forward_loss(model, batch, mask_strategy, num_cells, device).item()
                    val_batches += 1
            assert val_batches > 0, "validation loader produced no batches"
            val_avg = val_total / val_batches
            assert math.isfinite(val_avg), f"non-finite validation loss: {val_avg}"

            print(f"Epoch {epoch}: train={train_avg:.2f}, val={val_avg:.2f}")

    print("\nIntegration test PASSED!")


if __name__ == "__main__":
    test()