arc-agi-3-grid-jepa / src /utils /augmentations.py
guychuk's picture
Add geometric augmentations
c7b9be8 verified
"""Geometric augmentations for ARC grids."""
import random
import torch
import numpy as np
class GridAugmentations:
    """Geometric and colour augmentations for grids held in torch tensors.

    Every geometric method acts on the last two dimensions of its input, so
    a single 2-D grid and a tensor with extra leading dimensions are handled
    by the same code path.
    """

    def __init__(self, grid_size=64):
        # Kept as an attribute; no method of this class reads it.
        self.grid_size = grid_size

    def rotate90(self, grid, k=1):
        """Return ``grid`` rotated ``k`` times by 90 degrees via ``torch.rot90``.

        The rotation is applied in the plane of the last two dimensions.
        """
        # For a 2-D tensor dims (-2, -1) are dims (0, 1), so one call covers
        # both the single-grid and the leading-dimension case.
        return torch.rot90(grid, k=k, dims=(-2, -1))

    def flip_horizontal(self, grid):
        """Return ``grid`` with the order along dimension -2 reversed."""
        return torch.flip(grid, dims=(-2,))

    def flip_vertical(self, grid):
        """Return ``grid`` with the order along dimension -1 reversed."""
        return torch.flip(grid, dims=(-1,))

    def transpose(self, grid):
        """Return ``grid`` with its last two dimensions swapped."""
        return grid.transpose(-2, -1)

    def random_color_permute(self, grid, colors=None):
        """Randomly permute the positive colour values of ``grid``.

        Args:
            grid: Tensor of colour values. It is never modified.
            colors: Optional iterable of colour values to permute among
                themselves. Values ``<= 0`` are ignored. When omitted, the
                positive values present in ``grid`` are used.

        Returns:
            A ``(permuted, cmap)`` tuple. ``permuted`` is a new tensor with
            the dtype and device of ``grid``; ``cmap`` maps each old colour
            to its replacement and always contains ``0: 0``. Cells whose
            value is not a key of ``cmap`` are left unchanged.
        """
        if colors is None:
            colors = torch.unique(grid).tolist()
        colors = [c for c in colors if c > 0]

        if len(colors) <= 1:
            # Nothing to shuffle. Still hand back a fresh tensor and a
            # complete identity map so callers get the same contract as on
            # the shuffling path below.
            cmap = {c: c for c in colors}
            cmap[0] = 0
            return grid.clone(), cmap

        shuffled = list(colors)
        random.shuffle(shuffled)
        cmap = dict(zip(colors, shuffled))
        cmap[0] = 0

        perm = grid.clone()
        for old, new in cmap.items():
            if old == new:
                continue
            # The mask is always taken from the untouched ``grid`` so that a
            # swap such as 1->2, 2->1 cannot chain through already-written
            # cells.
            replacement = torch.tensor(new, dtype=grid.dtype, device=grid.device)
            perm = torch.where(grid == old, replacement, perm)
        return perm, cmap

    def apply_all_transforms(self, grid, include_color=True):
        """Return every supported variant of ``grid`` as ``(tensor, name)`` pairs.

        The list starts with the untouched grid (``"original"``), followed by
        the three rotations, both flips and the transpose. When
        ``include_color`` is true, one random colour permutation is appended
        as ``"color_permute"``.
        """
        transforms = [(grid, "original")]
        transforms.extend(
            (self.rotate90(grid, k=k), f"rotate_{k*90}") for k in range(1, 4)
        )
        transforms.append((self.flip_horizontal(grid), "flip_h"))
        transforms.append((self.flip_vertical(grid), "flip_v"))
        transforms.append((self.transpose(grid), "transpose"))
        if include_color:
            perm, _ = self.random_color_permute(grid)
            transforms.append((perm, "color_permute"))
        return transforms
def _remap_colors(grid, cmap):
    """Return a copy of ``grid`` with every key of ``cmap`` replaced by its value."""
    remapped = grid.copy()
    for old, new in cmap.items():
        if old != new:
            # Mask on the original grid so swaps (1->2, 2->1) do not chain.
            remapped[grid == old] = new
    return remapped


def augment_arc_pair(input_grid, output_grid, aug_prob=0.5):
    """Apply one randomly chosen augmentation to an input/output grid pair.

    The same transform (and, for colour permutation, the same colour map) is
    applied to both grids.

    Args:
        input_grid: Array-like grid of colour values.
        output_grid: Array-like grid of colour values.
        aug_prob: Probability of augmenting at all. ``0`` never augments,
            ``1`` always picks an augmentation.

    Returns:
        ``(input_grid, output_grid, name)``. ``name`` is ``"none"`` when the
        grids are returned untouched (the very same objects that were passed
        in); otherwise it is one of ``"rotate90_<k>"``, ``"flip_h"``,
        ``"flip_v"``, ``"transpose"`` or ``"color_permute"`` and both grids
        are freshly allocated C-contiguous arrays with their original dtype.
    """
    # ``>=`` rather than ``>``: random.random() may return exactly 0.0, which
    # would otherwise let aug_prob=0 augment once in a blue moon.
    if random.random() >= aug_prob:
        return input_grid, output_grid, "none"

    aug_name = random.choice(["rotate90", "flip_h", "flip_v", "transpose", "color_permute"])
    src = np.asarray(input_grid)
    dst = np.asarray(output_grid)

    if aug_name == "color_permute":
        # Colours are collected over both grids so the pair shares one map.
        colors = np.unique(np.concatenate([src.ravel(), dst.ravel()]))
        colors = colors[colors > 0].tolist()
        if len(colors) <= 1:
            # Nothing to permute; report that no augmentation happened.
            return input_grid, output_grid, "none"
        shuffled = colors.copy()
        random.shuffle(shuffled)
        cmap = dict(zip(colors, shuffled))
        # Values absent from ``cmap`` (0 and anything negative) stay as they
        # are, and the dtype of each grid is preserved.
        return _remap_colors(src, cmap), _remap_colors(dst, cmap), "color_permute"

    if aug_name == "rotate90":
        k = random.randint(1, 3)
        new_in, new_out, label = np.rot90(src, k=k), np.rot90(dst, k=k), f"rotate90_{k}"
    elif aug_name == "flip_h":
        new_in, new_out, label = np.flip(src, 0), np.flip(dst, 0), "flip_h"
    elif aug_name == "flip_v":
        new_in, new_out, label = np.flip(src, 1), np.flip(dst, 1), "flip_v"
    else:
        new_in, new_out, label = src.T, dst.T, "transpose"

    # rot90/flip/.T return views of the inputs, often with negative strides,
    # which consumers such as torch.from_numpy reject. Hand back independent
    # contiguous arrays instead.
    return np.ascontiguousarray(new_in), np.ascontiguousarray(new_out), label