# Source: arc-agi-3-grid-jepa/src/data/arc_dataset.py
# Commit c211096 (verified) by guychuk: "Add ARC dataset loader"
"""ARC dataset loader."""
import json
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
from typing import Dict, List, Union
class ARCDataset(Dataset):
    """Map-style dataset of ARC input/output grid pairs read from task JSON files.

    Every ``*.json`` file found in ``data_dir / split`` is treated as one task
    (the file stem is the task id). Each entry of the task's ``"train"`` and
    ``"test"`` lists becomes one sample. Items are returned as square
    ``max_grid_size x max_grid_size`` long tensors.

    Args:
        data_dir: Root directory that contains one sub-directory per split.
        split: Name of the sub-directory to read tasks from.
        max_grid_size: Side length every grid is cropped / zero-padded to.
        num_colors: Number of colour classes; cell values are clipped to
            ``[0, num_colors - 1]`` when an item is fetched.
        augment: When true, random flips, rotations and colour permutations
            are applied to each pair on access.
        mode: Stored on the instance; no method of this class reads it.

    Raises:
        FileNotFoundError: If ``data_dir / split`` does not exist.
    """

    def __init__(self, data_dir, split="training", max_grid_size=64, num_colors=16,
                 augment=False, mode="pairs"):
        self.data_dir = Path(data_dir)
        self.split = split
        self.max_grid_size = max_grid_size
        self.num_colors = num_colors
        self.augment = augment
        self.mode = mode
        # All tasks are parsed eagerly and kept in memory as numpy arrays.
        self.samples = self._load()

    def _load(self) -> List[Dict[str, Union[str, np.ndarray]]]:
        """Read every task file of the split and flatten it into a sample list.

        Returns:
            One dict per pair with keys ``task_id``, ``input``, ``output``
            (int64 arrays) and ``split`` (``"train"`` or ``"test"``). Files are
            visited in sorted order; within a task, train pairs precede test
            pairs.

        Raises:
            FileNotFoundError: If the split directory does not exist.
        """
        task_dir = self.data_dir / self.split
        if not task_dir.exists():
            raise FileNotFoundError(task_dir)

        samples = []
        for path in sorted(task_dir.glob("*.json")):
            # JSON is UTF-8 by specification; do not rely on the platform default.
            with open(path, encoding="utf-8") as fh:
                task = json.load(fh)
            task_id = path.stem
            for part in ("train", "test"):
                for pair in task.get(part, []):
                    samples.append({
                        "task_id": task_id,
                        "input": np.array(pair["input"], dtype=np.int64),
                        "output": np.array(pair["output"], dtype=np.int64),
                        "split": part,
                    })
        return samples

    def _pad(self, g: np.ndarray) -> np.ndarray:
        """Crop ``g`` to at most ``max_grid_size`` per side, then zero-pad it.

        Cropping keeps the top-left corner; padding is appended at the bottom
        and right with the value 0.
        """
        h, w = g.shape
        if h > self.max_grid_size or w > self.max_grid_size:
            g = g[:self.max_grid_size, :self.max_grid_size]
            h, w = g.shape
        ph = self.max_grid_size - h
        pw = self.max_grid_size - w
        if ph > 0 or pw > 0:
            g = np.pad(g, ((0, ph), (0, pw)), mode="constant", constant_values=0)
        return g

    def _aug(self, inp: np.ndarray, out: np.ndarray):
        """Apply the same random transform to an input/output pair.

        When ``self.augment`` is false the arrays are returned unchanged.
        Otherwise, each with probability 0.5, the pair is flipped vertically
        and flipped horizontally; it is then rotated by a random multiple of
        90 degrees; finally, with probability 0.5, the positive colours that
        occur in the pair are permuted among themselves.

        NOTE: flips and rotations draw from the ``random`` module while the
        colour shuffle draws from ``np.random``; seed both for reproducibility.

        Returns:
            Tuple ``(inp, out)`` of transformed arrays.
        """
        if not self.augment:
            return inp, out

        if random.random() < 0.5:
            inp, out = np.flip(inp, 0).copy(), np.flip(out, 0).copy()
        if random.random() < 0.5:
            inp, out = np.flip(inp, 1).copy(), np.flip(out, 1).copy()

        k = random.randint(0, 3)
        if k > 0:
            inp, out = np.rot90(inp, k=k).copy(), np.rot90(out, k=k).copy()

        if random.random() < 0.5:
            # Sorted, unique positive colours present in either grid.
            palette = np.unique(np.concatenate([inp.ravel(), out.ravel()]))
            palette = palette[palette > 0]
            if len(palette) > 1:
                permuted = palette.copy()
                np.random.shuffle(permuted)
                inp = self._remap_colors(inp, palette, permuted)
                out = self._remap_colors(out, palette, permuted)
        return inp, out

    @staticmethod
    def _remap_colors(grid: np.ndarray, palette: np.ndarray,
                      permuted: np.ndarray) -> np.ndarray:
        """Return a copy of ``grid`` with ``palette[i]`` replaced by ``permuted[i]``.

        ``palette`` must be sorted and contain every positive value of
        ``grid``. Cells that are not positive (background 0 or any negative
        marker) are left untouched, so the result keeps the grid's integer
        dtype instead of degrading to an object array.
        """
        remapped = grid.copy()
        positive = grid > 0
        # palette is sorted and complete, so searchsorted yields exact indices.
        remapped[positive] = permuted[np.searchsorted(palette, grid[positive])]
        return remapped

    def __len__(self) -> int:
        """Return the number of loaded input/output pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one pair as padded, clipped long tensors.

        Returns:
            Dict with ``task_id`` and ``split`` (strings) and ``input_grid`` /
            ``output_grid`` (long tensors of shape
            ``(max_grid_size, max_grid_size)``).
        """
        s = self.samples[idx]
        inp, out = self._aug(s["input"], s["output"])
        inp = self._pad(inp)
        out = self._pad(out)
        top = self.num_colors - 1  # highest colour index allowed
        return {
            "task_id": s["task_id"],
            "input_grid": torch.from_numpy(np.clip(inp, 0, top)).long(),
            "output_grid": torch.from_numpy(np.clip(out, 0, top)).long(),
            "split": s["split"],
        }
def collate_fn(batch):
    """Merge a list of dataset items into one batch dictionary.

    Grid tensors are stacked along a new leading dimension; task ids and
    split labels are gathered into plain lists in batch order.
    """

    def column(key):
        # Pull one field out of every item, preserving batch order.
        return [item[key] for item in batch]

    task_ids = column("task_id")
    input_grids = torch.stack(column("input_grid"))
    output_grids = torch.stack(column("output_grid"))
    splits = column("split")
    return {
        "task_ids": task_ids,
        "input_grids": input_grids,
        "output_grids": output_grids,
        "splits": splits,
    }
def load_arc_data(source="local", split="training", cache_dir="./data/arc_agi_1", **kwargs):
    """Build an ``ARCDataset`` rooted at ``cache_dir`` for the given split.

    Args:
        source: Accepted for interface compatibility; this function does not
            read it.
        split: Forwarded to ``ARCDataset`` as ``split``.
        cache_dir: Passed to ``ARCDataset`` as its data directory.
        **kwargs: Any further ``ARCDataset`` keyword arguments.

    Returns:
        The constructed ``ARCDataset``.
    """
    dataset_options = {"split": split, **kwargs}
    return ARCDataset(cache_dir, **dataset_options)