"""ARC dataset loader."""

import json
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
from typing import Dict, List, Union


class ARCDataset(Dataset):
    """Flat dataset of ARC input/output grid pairs read from per-task JSON files.

    Every ``*.json`` file in ``<data_dir>/<split>`` is parsed. Each pair listed
    under the task's ``"train"`` and ``"test"`` keys becomes one sample. On
    access, a sample is optionally augmented. It is then cropped and
    zero-padded to ``max_grid_size`` x ``max_grid_size`` and clipped to
    ``[0, num_colors - 1]``.

    NOTE(review): padding uses value 0. That is also a real color in the
    source grids, and the original grid shape is not returned, so padded
    cells cannot be told apart from color-0 cells downstream. Confirm that
    consumers do not need that distinction.
    """

    def __init__(
        self,
        data_dir: Union[str, os.PathLike],
        split: str = "training",
        max_grid_size: int = 64,
        num_colors: int = 16,
        augment: bool = False,
        mode: str = "pairs",
    ):
        """Load all pairs of ``split`` from ``data_dir`` into memory.

        Args:
            data_dir: Root directory containing one sub-directory per split.
            split: Name of the sub-directory to read (e.g. ``"training"``).
            max_grid_size: Side length every grid is cropped/padded to.
            num_colors: Values are clipped to ``[0, num_colors - 1]``.
            augment: If True, apply random flips/rotations/color permutations.
            mode: Stored as ``self.mode``; not read by this class.

        Raises:
            ValueError: If ``max_grid_size`` or ``num_colors`` is below 1.
            FileNotFoundError: If ``<data_dir>/<split>`` does not exist.
        """
        if max_grid_size < 1:
            raise ValueError(f"max_grid_size must be >= 1, got {max_grid_size}")
        if num_colors < 1:
            raise ValueError(f"num_colors must be >= 1, got {num_colors}")
        self.data_dir = Path(data_dir)
        self.split = split
        self.max_grid_size = max_grid_size
        self.num_colors = num_colors
        self.augment = augment
        self.mode = mode
        self.samples = self._load()

    def _load(self) -> List[Dict[str, object]]:
        """Read every task file and return one sample dict per pair.

        Files are processed in sorted order. Within a file, all ``"train"``
        pairs come before all ``"test"`` pairs.
        """
        task_dir = self.data_dir / self.split
        if not task_dir.exists():
            raise FileNotFoundError(task_dir)
        samples = []
        for path in sorted(task_dir.glob("*.json")):
            with path.open(encoding="utf-8") as fh:
                task = json.load(fh)
            for pair_split in ("train", "test"):
                for pair in task.get(pair_split, []):
                    samples.append({
                        "task_id": path.stem,
                        "input": np.array(pair["input"], dtype=np.int64),
                        "output": np.array(pair["output"], dtype=np.int64),
                        "split": pair_split,
                    })
        return samples

    def _pad(self, g: np.ndarray) -> np.ndarray:
        """Crop ``g`` to at most max_grid_size per side, then zero-pad to exactly that.

        Cropping keeps the top-left corner. Padding is added on the bottom
        and right.
        """
        size = self.max_grid_size
        g = g[:size, :size]
        ph = size - g.shape[0]
        pw = size - g.shape[1]
        if ph > 0 or pw > 0:
            g = np.pad(g, ((0, ph), (0, pw)), mode="constant", constant_values=0)
        return g

    @staticmethod
    def _permute_colors(inp: np.ndarray, out: np.ndarray):
        """Apply one random permutation of the non-zero colors to both grids.

        Color 0 always maps to itself. The permutation is drawn with
        ``np.random.shuffle``, and only when at least two distinct non-zero
        colors are present. Zero-size grids are handled.
        """
        palette = np.unique(np.concatenate([inp.ravel(), out.ravel()]))
        nonzero = palette > 0
        colors = palette[nonzero]
        if len(colors) <= 1:
            return inp, out
        shuffled = colors.copy()
        np.random.shuffle(shuffled)
        target = palette.copy()
        target[nonzero] = shuffled
        # `palette` is sorted and contains every value in both grids, so
        # searchsorted yields each cell's exact palette index: a vectorised
        # lookup table with no per-cell Python calls.
        return (
            target[np.searchsorted(palette, inp)],
            target[np.searchsorted(palette, out)],
        )

    def _aug(self, inp: np.ndarray, out: np.ndarray):
        """Randomly apply the same flip/rotation/color permutation to both grids.

        Branch decisions and the rotation count come from the stdlib
        ``random`` module. The color shuffle uses ``np.random``, so both
        global RNGs must be seeded for reproducible augmentation.
        """
        if not self.augment:
            return inp, out
        if random.random() < 0.5:
            inp, out = np.flip(inp, 0).copy(), np.flip(out, 0).copy()
        if random.random() < 0.5:
            inp, out = np.flip(inp, 1).copy(), np.flip(out, 1).copy()
        k = random.randint(0, 3)
        if k > 0:
            inp, out = np.rot90(inp, k=k).copy(), np.rot90(out, k=k).copy()
        if random.random() < 0.5:
            inp, out = self._permute_colors(inp, out)
        return inp, out

    def __len__(self) -> int:
        """Return the number of pairs across all loaded tasks."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        """Return one (optionally augmented) pair as fixed-size int64 tensors."""
        s = self.samples[idx]
        inp, out = self._aug(s["input"], s["output"])
        inp = self._pad(inp)
        out = self._pad(out)
        # np.clip returns a fresh array, so the cached sample arrays are
        # never shared with the returned tensors.
        return {
            "task_id": s["task_id"],
            "input_grid": torch.from_numpy(np.clip(inp, 0, self.num_colors - 1)).long(),
            "output_grid": torch.from_numpy(np.clip(out, 0, self.num_colors - 1)).long(),
            "split": s["split"],
        }


def collate_fn(batch):
    """Stack grid tensors along a new batch dimension and keep string fields as lists."""
    return {
        "task_ids": [b["task_id"] for b in batch],
        "input_grids": torch.stack([b["input_grid"] for b in batch]),
        "output_grids": torch.stack([b["output_grid"] for b in batch]),
        "splits": [b["split"] for b in batch],
    }


def load_arc_data(source="local", split="training", cache_dir="./data/arc_agi_1", **kwargs):
    """Build an ``ARCDataset`` for ``split`` rooted at ``cache_dir``.

    ``source`` is accepted for interface compatibility but ignored: data is
    always read from ``cache_dir`` on local disk. Extra keyword arguments are
    forwarded to ``ARCDataset``.
    """
    return ARCDataset(cache_dir, split=split, **kwargs)