"""Geometric augmentations for ARC grids."""
import random
import torch
import numpy as np


class GridAugmentations:
    """Geometric and colour augmentations for grids stored as torch tensors.

    Every geometric method acts on the last two dimensions of ``grid``, so
    a plain 2-D grid and a tensor with leading (e.g. batch) dimensions are
    handled by the same code path.
    """

    def __init__(self, grid_size=64):
        # Kept for callers; no method defined in this class reads it.
        self.grid_size = grid_size

    def rotate90(self, grid: torch.Tensor, k: int = 1) -> torch.Tensor:
        """Rotate ``grid`` by ``k`` quarter turns via ``torch.rot90``.

        The rotation is applied in the plane of the last two dimensions.
        """
        # For a 2-D tensor dims (-2, -1) are the same axes as (0, 1), so no
        # separate 2-D branch is needed.
        return torch.rot90(grid, k=k, dims=(-2, -1))

    def flip_horizontal(self, grid: torch.Tensor) -> torch.Tensor:
        """Reverse ``grid`` along its second-to-last dimension (row order)."""
        return torch.flip(grid, dims=(-2,))

    def flip_vertical(self, grid: torch.Tensor) -> torch.Tensor:
        """Reverse ``grid`` along its last dimension (column order)."""
        return torch.flip(grid, dims=(-1,))

    def transpose(self, grid: torch.Tensor) -> torch.Tensor:
        """Swap the last two dimensions of ``grid``."""
        return grid.transpose(-2, -1)

    def random_color_permute(self, grid, colors=None):
        """Randomly permute the positive values ("colours") of ``grid``.

        Args:
            grid: Tensor of colour values. Values ``<= 0`` are never changed.
            colors: Optional iterable of colour values to permute. When
                omitted, the distinct positive values found in ``grid`` are
                used. Non-positive entries are dropped in either case.

        Returns:
            A ``(permuted_grid, cmap)`` tuple where ``cmap`` maps each old
            colour to its new colour. When at most one colour is available
            the input tensor itself is returned (not a copy) together with
            an identity map over those colours. Otherwise ``cmap`` also
            contains ``0 -> 0``.
        """
        if colors is None:
            present = torch.unique(grid)
            colors = present[present > 0].tolist()
        colors = [c for c in colors if c > 0]
        if len(colors) <= 1:
            return grid, {c: c for c in colors}

        shuffled = colors.copy()
        random.shuffle(shuffled)
        cmap = dict(zip(colors, shuffled))
        cmap[0] = 0

        perm = grid.clone()
        for old, new in cmap.items():
            replacement = torch.tensor(new, dtype=grid.dtype, device=grid.device)
            # Compare against the untouched ``grid`` (not ``perm``) so that a
            # swap such as 1->2, 2->1 cannot cascade.
            perm = torch.where(grid == old, replacement, perm)
        return perm, cmap

    def apply_all_transforms(self, grid, include_color=True):
        """Return every supported transform of ``grid`` with a label.

        Returns:
            A list of ``(tensor, name)`` tuples in the order: ``"original"``,
            ``"rotate_90"``, ``"rotate_180"``, ``"rotate_270"``, ``"flip_h"``,
            ``"flip_v"``, ``"transpose"`` and, when ``include_color`` is true,
            ``"color_permute"`` (one random colour permutation).
        """
        transforms = [(grid, "original")]
        transforms.extend(
            (self.rotate90(grid, k=k), f"rotate_{k*90}") for k in range(1, 4)
        )
        transforms.append((self.flip_horizontal(grid), "flip_h"))
        transforms.append((self.flip_vertical(grid), "flip_v"))
        transforms.append((self.transpose(grid), "transpose"))
        if include_color:
            perm, _ = self.random_color_permute(grid)
            transforms.append((perm, "color_permute"))
        return transforms


def _remap_colors(grid, cmap):
    """Return a copy of ``grid`` with each ``old`` value replaced per ``cmap``.

    Values that are not keys of ``cmap`` are left unchanged and the dtype of
    ``grid`` is preserved.
    """
    remapped = grid.copy()
    for old, new in cmap.items():
        # Mask on the original array so already-remapped cells are not
        # remapped a second time.
        remapped[grid == old] = new
    return remapped


def augment_arc_pair(input_grid, output_grid, aug_prob=0.5):
    """Apply one random augmentation identically to an input/output pair.

    Args:
        input_grid: NumPy array holding the input grid.
        output_grid: NumPy array holding the output grid.
        aug_prob: Probability of augmenting. ``0`` never augments and ``1``
            always attempts an augmentation.

    Returns:
        ``(input_grid, output_grid, name)`` where ``name`` is one of
        ``"none"``, ``"rotate90_<k>"``, ``"flip_h"``, ``"flip_v"``,
        ``"transpose"`` or ``"color_permute"``. ``"none"`` is also returned
        when colour permutation is drawn but the pair has fewer than two
        positive colours; in every ``"none"`` case the original objects are
        returned. The geometric results are NumPy views of the inputs.
    """
    # random.random() lies in [0.0, 1.0); ">=" guarantees that aug_prob == 0
    # never augments while aug_prob == 1 still always does.
    if random.random() >= aug_prob:
        return input_grid, output_grid, "none"

    aug_name = random.choice(
        ["rotate90", "flip_h", "flip_v", "transpose", "color_permute"]
    )

    if aug_name == "rotate90":
        k = random.randint(1, 3)
        return np.rot90(input_grid, k=k), np.rot90(output_grid, k=k), f"rotate90_{k}"
    if aug_name == "flip_h":
        return np.flip(input_grid, 0), np.flip(output_grid, 0), "flip_h"
    if aug_name == "flip_v":
        return np.flip(input_grid, 1), np.flip(output_grid, 1), "flip_v"
    if aug_name == "transpose":
        return input_grid.T, output_grid.T, "transpose"

    # Remaining choice: "color_permute". The colour set is taken from both
    # grids so that the same mapping is valid for input and output.
    colors = np.unique(
        np.concatenate([input_grid.flatten(), output_grid.flatten()])
    )
    colors = colors[colors > 0].tolist()
    if len(colors) > 1:
        shuffled = colors.copy()
        random.shuffle(shuffled)
        cmap = dict(zip(colors, shuffled))
        # Only positive colours are remapped; zeros, negative values and any
        # other unmapped cells pass through untouched.
        return (
            _remap_colors(input_grid, cmap),
            _remap_colors(output_grid, cmap),
            "color_permute",
        )
    return input_grid, output_grid, "none"