| """Geometric augmentations for ARC grids.""" |
| import random |
| import torch |
| import numpy as np |
|
|
class GridAugmentations:
    """Geometric (rotation/flip/transpose) and colour-permutation augmentations.

    Every geometric method acts on the last two dimensions of ``grid``, so it
    accepts a single 2-D grid as well as a tensor with leading batch/channel
    dimensions.
    """

    def __init__(self, grid_size=64):
        # Stored for callers; no method of this class reads it.
        self.grid_size = grid_size

    def rotate90(self, grid, k=1):
        """Rotate ``grid`` by ``k`` quarter turns over its last two dims.

        Direction follows ``torch.rot90`` (from dim -2 towards dim -1 for k > 0).
        """
        return torch.rot90(grid, k=k, dims=(-2, -1))

    def flip_horizontal(self, grid):
        """Reverse the row order of ``grid`` (dim -2), i.e. mirror top-to-bottom."""
        return torch.flip(grid, dims=(-2,))

    def flip_vertical(self, grid):
        """Reverse the column order of ``grid`` (dim -1), i.e. mirror left-to-right."""
        return torch.flip(grid, dims=(-1,))

    def transpose(self, grid):
        """Swap the last two dimensions of ``grid`` (main-diagonal reflection)."""
        return grid.transpose(-2, -1)

    def random_color_permute(self, grid, colors=None):
        """Randomly permute the positive colours of ``grid`` among themselves.

        Args:
            grid: tensor of colour values.
            colors: colours to shuffle (list or tensor). Defaults to the
                positive values present in ``grid``. Non-positive entries and
                duplicates are ignored.

        Returns:
            ``(permuted, cmap)``: ``permuted`` is a new tensor with the same
            dtype/device as ``grid``; ``cmap`` maps each shuffled colour and
            the background ``0`` to its new value. Values of ``grid`` that are
            not keys of ``cmap`` are left unchanged.
        """
        if colors is None:
            present = torch.unique(grid)
            colors = present[present > 0].tolist()
        elif torch.is_tensor(colors):
            colors = colors.tolist()
        # dict.fromkeys de-duplicates while keeping order; a repeated colour
        # would otherwise make the mapping non-bijective (two colours merged).
        colors = list(dict.fromkeys(c for c in colors if c > 0))

        if len(colors) <= 1:
            # Nothing to shuffle: identity map, still including background 0 so
            # both code paths return a mapping of the same form.
            cmap = {c: c for c in colors}
            cmap[0] = 0
            return grid.clone(), cmap

        shuffled = colors.copy()
        random.shuffle(shuffled)
        cmap = dict(zip(colors, shuffled))
        cmap[0] = 0

        permuted = grid.clone()
        for old, new in cmap.items():
            # Mask from the untouched ``grid`` so a colour already replaced is
            # never remapped a second time.
            replacement = torch.tensor(new, dtype=grid.dtype, device=grid.device)
            permuted = torch.where(grid == old, replacement, permuted)
        return permuted, cmap

    def apply_all_transforms(self, grid, include_color=True):
        """Return ``[(tensor, name), ...]`` for the identity and each augmentation.

        Order: original, rotate_90/180/270, flip_h, flip_v, transpose and,
        when ``include_color`` is true, one random ``color_permute``. The
        anti-diagonal reflection is not included, so this is 7 of the 8
        square-grid symmetries.
        """
        transforms = [(grid, "original")]
        for k in range(1, 4):
            transforms.append((self.rotate90(grid, k=k), f"rotate_{k*90}"))
        transforms.append((self.flip_horizontal(grid), "flip_h"))
        transforms.append((self.flip_vertical(grid), "flip_v"))
        transforms.append((self.transpose(grid), "transpose"))
        if include_color:
            permuted, _ = self.random_color_permute(grid)
            transforms.append((permuted, "color_permute"))
        return transforms
|
|
|
|
def _remap_colors(grid, cmap):
    """Return a copy of ``grid`` with every key of ``cmap`` replaced by its value.

    Values that are not keys of ``cmap`` are left unchanged and the dtype of
    ``grid`` is preserved; empty grids are handled.
    """
    remapped = grid.copy()
    for old, new in cmap.items():
        # Mask from the untouched input so a colour is never remapped twice.
        remapped[grid == old] = new
    return remapped


def augment_arc_pair(input_grid, output_grid, aug_prob=0.5):
    """Randomly apply one shared augmentation to an ARC input/output pair.

    With probability ``aug_prob`` one augmentation is chosen uniformly from
    rotate90 (k in 1..3), flip_h (reverse rows, axis 0), flip_v (reverse
    columns, axis 1), transpose and color_permute, and applied identically
    to both grids. Otherwise the inputs are returned untouched.

    Args:
        input_grid, output_grid: array-likes of colour values (expected 2-D).
        aug_prob: probability of applying an augmentation.

    Returns:
        ``(new_input, new_output, name)`` where ``name`` is ``"none"``,
        ``"rotate90_{k}"``, ``"flip_h"``, ``"flip_v"``, ``"transpose"`` or
        ``"color_permute"``. Augmented grids are fresh C-contiguous arrays
        (independent of the inputs and accepted by ``torch.from_numpy``).
        ``color_permute`` shuffles the positive colours present in either
        grid, keeps 0, leaves other values unchanged and preserves dtype; with
        fewer than two such colours the grids are returned with ``"none"``.
    """
    # random.random() is in [0, 1), so ``>=`` augments with probability exactly
    # aug_prob (aug_prob=0 never augments, aug_prob=1 always does).
    if random.random() >= aug_prob:
        return input_grid, output_grid, "none"
    input_grid = np.asarray(input_grid)
    output_grid = np.asarray(output_grid)
    aug_name = random.choice(["rotate90", "flip_h", "flip_v", "transpose", "color_permute"])

    # np.rot90 / np.flip / .T return strided views of the inputs (negative
    # strides for rot90/flip, which torch.from_numpy rejects); .copy() gives
    # independent C-contiguous results.
    if aug_name == "rotate90":
        k = random.randint(1, 3)
        return (
            np.rot90(input_grid, k=k).copy(),
            np.rot90(output_grid, k=k).copy(),
            f"rotate90_{k}",
        )
    if aug_name == "flip_h":
        return np.flip(input_grid, 0).copy(), np.flip(output_grid, 0).copy(), "flip_h"
    if aug_name == "flip_v":
        return np.flip(input_grid, 1).copy(), np.flip(output_grid, 1).copy(), "flip_v"
    if aug_name == "transpose":
        return input_grid.T.copy(), output_grid.T.copy(), "transpose"

    # color_permute: one permutation shared by both grids so the task's
    # input->output relationship is preserved.
    colors = np.unique(np.concatenate([input_grid.ravel(), output_grid.ravel()]))
    colors = colors[colors > 0].tolist()
    if len(colors) <= 1:
        return input_grid, output_grid, "none"
    shuffled = colors.copy()
    random.shuffle(shuffled)
    cmap = dict(zip(colors, shuffled))
    cmap[0] = 0
    return _remap_colors(input_grid, cmap), _remap_colors(output_grid, cmap), "color_permute"
|
|