| """ARC dataset loader.""" |
| import json |
| import os |
| import random |
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
| from pathlib import Path |
| from typing import Dict, List, Union |
|
|
|
|
class ARCDataset(Dataset):
    """Map-style dataset over ARC task files in ``<data_dir>/<split>/*.json``.

    Each JSON file is one task. Every pair listed under its ``"train"`` and
    ``"test"`` keys becomes one sample, so ``len(dataset)`` is the total
    number of input/output pairs across all task files of the split.

    Args:
        data_dir: Root directory that contains one sub-directory per split.
        split: Name of the sub-directory to read task files from.
        max_grid_size: Side length of the square grids returned by
            ``__getitem__``. Larger grids are cropped to their top-left
            corner; smaller grids are padded with 0 at the bottom/right.
        num_colors: Cell values are clipped into ``[0, num_colors - 1]``.
        augment: If true, apply random flips, a random 90-degree rotation and
            a random permutation of the non-zero colors to each pair.
        mode: Stored on the instance as ``self.mode``; not read by any method
            of this class.

    Raises:
        FileNotFoundError: If ``<data_dir>/<split>`` does not exist.
    """

    def __init__(self, data_dir, split="training", max_grid_size=64, num_colors=16,
                 augment=False, mode="pairs"):
        self.data_dir = Path(data_dir)
        self.split = split
        self.max_grid_size = max_grid_size
        self.num_colors = num_colors
        self.augment = augment
        self.mode = mode
        # All task files are parsed eagerly and kept in memory.
        self.samples = self._load()

    def _load(self):
        """Read every ``*.json`` task file of the split into a flat sample list.

        Returns:
            A list of dicts with keys ``"task_id"`` (file stem), ``"input"``
            and ``"output"`` (int64 arrays) and ``"split"`` (``"train"`` or
            ``"test"``, i.e. the key of the task file the pair came from).
            Files are visited in sorted order; within a file the "train"
            pairs precede the "test" pairs.

        Raises:
            FileNotFoundError: If the split directory does not exist.
        """
        task_dir = self.data_dir / self.split
        if not task_dir.exists():
            raise FileNotFoundError(task_dir)

        samples = []
        for path in sorted(task_dir.glob("*.json")):
            # Explicit encoding so parsing does not depend on the locale.
            with open(path, encoding="utf-8") as fh:
                task = json.load(fh)
            for part in ("train", "test"):
                for pair in task.get(part, []):
                    samples.append({
                        "task_id": path.stem,
                        "input": np.array(pair["input"], dtype=np.int64),
                        "output": np.array(pair["output"], dtype=np.int64),
                        "split": part,
                    })
        return samples

    def _pad(self, g):
        """Crop/pad a 2-D grid to ``max_grid_size x max_grid_size``.

        Oversized grids keep only their top-left corner. Undersized grids are
        extended at the bottom and right with 0, so padding cannot be told
        apart from cells whose value is 0.
        """
        size = self.max_grid_size
        h, w = g.shape
        if h > size or w > size:
            g = g[:size, :size]
            h, w = g.shape
        if h < size or w < size:
            g = np.pad(g, ((0, size - h), (0, size - w)),
                       mode="constant", constant_values=0)
        return g

    @staticmethod
    def _permute_colors(inp, out):
        """Apply one random permutation of the positive values to both grids.

        The permutation is drawn over the positive values that occur in
        either grid; values <= 0 are left untouched. With fewer than two
        positive values the grids are returned unchanged.
        """
        colors = np.unique(np.concatenate([inp.ravel(), out.ravel()]))
        colors = colors[colors > 0]
        if len(colors) < 2:
            return inp, out

        # Use the `random` module like the other augmentations, so that
        # seeding a single generator makes the whole pipeline reproducible.
        shuffled = colors.tolist()
        random.shuffle(shuffled)
        targets = np.array(shuffled, dtype=colors.dtype)

        def remap(grid):
            # `colors` is sorted and contains every positive value of `grid`,
            # so searchsorted yields the exact index for those cells. Cells
            # <= 0 get index 0 and are masked out again by np.where. This
            # also works for size-0 grids, unlike np.vectorize.
            idx = np.searchsorted(colors, grid)
            idx = np.minimum(idx, len(colors) - 1)
            return np.where(grid > 0, targets[idx], grid)

        return remap(inp), remap(out)

    def _aug(self, inp, out):
        """Randomly transform an input/output pair, applying the same ops to both.

        Returns the pair unchanged when ``self.augment`` is false. Otherwise
        applies, each independently: a flip along axis 0 (p=0.5), a flip
        along axis 1 (p=0.5), a rotation by ``k * 90`` degrees with ``k``
        uniform in 0..3, and a permutation of the positive colors (p=0.5).
        """
        if not self.augment:
            return inp, out
        if random.random() < 0.5:
            inp, out = np.flip(inp, 0).copy(), np.flip(out, 0).copy()
        if random.random() < 0.5:
            inp, out = np.flip(inp, 1).copy(), np.flip(out, 1).copy()
        k = random.randint(0, 3)
        if k > 0:
            inp, out = np.rot90(inp, k=k).copy(), np.rot90(out, k=k).copy()
        if random.random() < 0.5:
            inp, out = self._permute_colors(inp, out)
        return inp, out

    def _to_tensor(self, grid):
        """Pad/crop ``grid``, clip it to the valid color range, return a long tensor."""
        clipped = np.clip(self._pad(grid), 0, self.num_colors - 1)
        return torch.from_numpy(clipped).long()

    def __len__(self):
        """Return the number of input/output pairs in the split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of fixed-size long tensors plus metadata.

        Keys: ``"task_id"``, ``"input_grid"``, ``"output_grid"`` (both of shape
        ``(max_grid_size, max_grid_size)``) and ``"split"``.
        """
        s = self.samples[idx]
        inp, out = self._aug(s["input"], s["output"])
        return {
            "task_id": s["task_id"],
            "input_grid": self._to_tensor(inp),
            "output_grid": self._to_tensor(out),
            "split": s["split"],
        }
|
|
|
|
def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict.

    Each element of ``batch`` must provide the keys ``"task_id"``,
    ``"input_grid"``, ``"output_grid"`` and ``"split"``. The grids are
    combined with ``torch.stack`` (so they must share a shape); the task ids
    and splits are returned as plain lists in batch order.
    """

    def column(key):
        # Collect one field across the whole batch, preserving order.
        return [sample[key] for sample in batch]

    task_ids = column("task_id")
    input_grids = torch.stack(column("input_grid"))
    output_grids = torch.stack(column("output_grid"))
    splits = column("split")

    return {
        "task_ids": task_ids,
        "input_grids": input_grids,
        "output_grids": output_grids,
        "splits": splits,
    }
|
|
|
|
def load_arc_data(source="local", split="training", cache_dir="./data/arc_agi_1", **kwargs):
    """Build an ``ARCDataset`` that reads ``split`` from ``cache_dir``.

    Args:
        source: Accepted for interface compatibility; this function does not
            read it and always loads from ``cache_dir``.
        split: Sub-directory of ``cache_dir`` to load.
        cache_dir: Directory passed to ``ARCDataset`` as its data root.
        **kwargs: Forwarded unchanged to the ``ARCDataset`` constructor.

    Returns:
        The constructed ``ARCDataset``.
    """
    dataset = ARCDataset(cache_dir, split=split, **kwargs)
    return dataset
|
|