| """Integration test for Grid-JEPA.""" |
| import sys |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).parent.parent / "src")) |
|
|
| import torch |
| import numpy as np |
| from data.arc_dataset import ARCDataset, collate_fn |
| from data.masking import GridMaskingStrategy |
| from models.grid_jepa import GridJEPA |
| from torch.utils.data import DataLoader |
|
|
def _write_toy_arc_tasks(root):
    """Write one tiny ARC-style task file into each split directory of *root*.

    Creates ``root/training/task.json`` and ``root/evaluation/task.json``.
    Both files hold the same task: a single train pair and a single test
    pair of 3x3 grids.

    Args:
        root: ``pathlib.Path`` of the directory that receives the two split
            sub-directories.
    """
    import json

    grid_in = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
    grid_out = [[0, 0, 0], [0, 2, 0], [0, 0, 0]]
    task = {
        "train": [{"input": grid_in, "output": grid_out}],
        "test": [{"input": grid_in, "output": grid_out}],
    }
    for split in ("training", "evaluation"):
        split_dir = root / split
        split_dir.mkdir(parents=True, exist_ok=True)
        with open(split_dir / "task.json", "w") as f:
            json.dump(task, f)


def _masked_forward_loss(model, batch, mask_strategy, grid_size, device):
    """Run one forward pass on *batch* with freshly sampled masks and return the loss.

    Args:
        model: The model under test; called as
            ``model(grid, context_mask, target_mask, action, position)`` and
            expected to return a ``(loss, extra)`` pair.
        batch: Collated batch dict; only ``batch["input_grids"]`` is read.
        mask_strategy: Object providing ``sample_masks_batch(batch_size)``
            that returns a dict with ``"context_mask"`` and ``"target_mask"``.
        grid_size: Side length used to bound the random position indices.
        device: Device the inputs are moved to / created on.

    Returns:
        The loss tensor produced by the model.
    """
    grid = batch["input_grids"].to(device)
    batch_size = grid.shape[0]
    masks = mask_strategy.sample_masks_batch(batch_size)
    ctx = masks["context_mask"].to(device)
    tgt = masks["target_mask"].to(device)
    # Random conditioning inputs: ids in [0, 6) and flat indices in
    # [0, grid_size**2). NOTE(review): 6 presumably matches the model's
    # number of actions -- confirm against GridJEPA.
    act = torch.randint(0, 6, (batch_size,), device=device)
    pos = torch.randint(0, grid_size * grid_size, (batch_size,), device=device)
    loss, _ = model(grid, ctx, tgt, act, pos)
    return loss


def test():
    """Smoke-test the Grid-JEPA data pipeline and training loop end to end.

    Builds a throw-away dataset with one toy task per split, then runs two
    epochs of training followed by validation, printing the average losses.

    The fixture lives in a ``tempfile.TemporaryDirectory`` so it is removed
    even when the test fails, and so an unrelated ``./test_data`` directory
    in the working directory is never overwritten or deleted.

    Raises:
        RuntimeError: If the training or validation loader yields no batches,
            since the test would otherwise have exercised nothing.
    """
    import tempfile

    B, G, C = 4, 10, 10
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    with tempfile.TemporaryDirectory(prefix="grid_jepa_test_") as tmp:
        d = Path(tmp)
        _write_toy_arc_tasks(d)

        train_ds = ARCDataset(d, "training", G, C, augment=True)
        val_ds = ARCDataset(d, "evaluation", G, C, augment=False)
        train_loader = DataLoader(train_ds, batch_size=B, shuffle=True, collate_fn=collate_fn)
        val_loader = DataLoader(val_ds, batch_size=B, shuffle=False, collate_fn=collate_fn)
        print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

        model = GridJEPA(C, 128, 4, 4, 4, max_grid_size=G).to(device)
        # Only these sub-modules are handed to the optimizer; any other
        # parameters of the model are not stepped by it.
        params = (
            list(model.patch_embed.parameters())
            + list(model.context_encoder.parameters())
            + list(model.predictor.parameters())
        )
        if model.use_action_conditioning:
            params += list(model.action_embed.parameters())
        optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.04)
        mask_strategy = GridMaskingStrategy(G, 2, (0.1, 0.3))

        for epoch in range(1, 3):
            model.train()
            epoch_loss = 0.0
            nb = 0
            for batch in train_loader:
                loss = _masked_forward_loss(model, batch, mask_strategy, G, device)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                # Called after every optimizer step, as in the training recipe.
                model.update_ema()
                epoch_loss += loss.item()
                nb += 1
            if nb == 0:
                raise RuntimeError("training loader produced no batches")
            avg = epoch_loss / nb

            model.eval()
            vl = 0.0
            vb = 0
            with torch.no_grad():
                for batch in val_loader:
                    loss = _masked_forward_loss(model, batch, mask_strategy, G, device)
                    vl += loss.item()
                    vb += 1
            if vb == 0:
                raise RuntimeError("validation loader produced no batches")
            print(f"Epoch {epoch}: train={avg:.2f}, val={vl/vb:.2f}")

    print("\nIntegration test PASSED!")
|
|
# Allow running this file directly as a script, outside a test runner.
if __name__ == "__main__":
    test()
|
|