File size: 3,247 Bytes
c7b9be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Geometric augmentations for ARC grids."""
import random
import torch
import numpy as np

class GridAugmentations:
    """Spatial and color augmentations for ARC grids held as torch tensors.

    Spatial transforms act on the last two dimensions, so they accept a plain
    2-D grid as well as a tensor with extra leading dimensions.
    """

    def __init__(self, grid_size=64):
        # Stored for callers; none of the methods below read it.
        self.grid_size = grid_size

    def rotate90(self, grid, k=1):
        """Rotate ``grid`` by ``k`` quarter turns in the plane of its last two dims.

        Follows ``torch.rot90``'s convention: for ``k > 0`` the rotation goes
        from dim -2 towards dim -1.
        """
        # (-2, -1) is the same pair as (0, 1) for a 2-D grid.
        return torch.rot90(grid, k=k, dims=(-2, -1))

    def flip_horizontal(self, grid):
        """Reverse the row order (dim -2), i.e. mirror the grid top-to-bottom."""
        return torch.flip(grid, dims=(-2,))

    def flip_vertical(self, grid):
        """Reverse the column order (dim -1), i.e. mirror the grid left-to-right."""
        return torch.flip(grid, dims=(-1,))

    def transpose(self, grid):
        """Swap the last two dims (main-diagonal reflection).

        Returns a view that shares storage with ``grid``.
        """
        return grid.transpose(-2, -1)

    def random_color_permute(self, grid, colors=None):
        """Randomly shuffle the non-background colors of ``grid`` among themselves.

        Args:
            grid: Tensor of color values.
            colors: Optional iterable of colors to shuffle. Defaults to every
                distinct positive value in ``grid``. Values ``<= 0`` are
                ignored and duplicates are collapsed, so 0 always maps to 0.

        Returns:
            ``(permuted, cmap)``. ``cmap`` maps each shuffled color, and 0, to
            its replacement. Cells whose value is not a key of ``cmap`` keep
            their value. When fewer than two colors are available, ``grid``
            itself is returned with an identity ``cmap``.
        """
        if colors is None:
            unique = torch.unique(grid)
            colors = unique[unique > 0].tolist()
        # De-duplicate while keeping first-seen order. A repeated color would
        # otherwise overwrite its own dict entry, and the resulting map would
        # merge two distinct colors into one.
        colors = list(dict.fromkeys(c for c in colors if c > 0))
        if len(colors) <= 1:
            # Nothing to shuffle; still report the same shape of mapping
            # (including the 0 -> 0 entry) as the shuffling path.
            cmap = {c: c for c in colors}
            cmap[0] = 0
            return grid, cmap
        shuffled = colors.copy()
        random.shuffle(shuffled)
        cmap = dict(zip(colors, shuffled))
        cmap[0] = 0
        perm = grid.clone()
        for old, new in cmap.items():
            # Masks come from the untouched input, so one substitution can
            # never feed into another (e.g. 1->2 followed by 2->1).
            fill = torch.tensor(new, dtype=grid.dtype, device=grid.device)
            perm = torch.where(grid == old, fill, perm)
        return perm, cmap

    def apply_all_transforms(self, grid, include_color=True):
        """Return ``grid`` under each fixed spatial transform plus an optional recoloring.

        Returns:
            A list of ``(tensor, name)`` pairs, in this order:
            ``"original"`` (the input object itself, not a copy),
            ``"rotate_90"``, ``"rotate_180"``, ``"rotate_270"``, ``"flip_h"``,
            ``"flip_v"``, ``"transpose"`` and, if ``include_color`` is true,
            one random ``"color_permute"``.
        """
        # NOTE: these are 7 of the 8 symmetries of the square; the
        # anti-diagonal reflection is not produced.
        transforms = [(grid, "original")]
        for k in range(1, 4):
            transforms.append((self.rotate90(grid, k=k), f"rotate_{k*90}"))
        transforms.append((self.flip_horizontal(grid), "flip_h"))
        transforms.append((self.flip_vertical(grid), "flip_v"))
        transforms.append((self.transpose(grid), "transpose"))
        if include_color:
            perm, _ = self.random_color_permute(grid)
            transforms.append((perm, "color_permute"))
        return transforms


def augment_arc_pair(input_grid, output_grid, aug_prob=0.5):
    """Apply one random augmentation, identically, to an ARC input/output pair.

    With probability ``aug_prob`` one augmentation is drawn uniformly from
    ``rotate90`` (k = 1..3 quarter turns), ``flip_h`` (reverse axis 0),
    ``flip_v`` (reverse axis 1), ``transpose`` and ``color_permute``. The same
    transform, including the same color mapping, is applied to both grids.

    Args:
        input_grid: Array-like grid (ndarray or nested lists); presumably a
            2-D (H, W) grid of color indices -- TODO confirm with callers.
        output_grid: Array-like grid transformed together with ``input_grid``.
        aug_prob: Probability of applying any augmentation.

    Returns:
        ``(new_input, new_output, name)``. When no augmentation is applied,
        ``name`` is ``"none"`` and the original objects are returned.
        Otherwise ``name`` is e.g. ``"rotate90_2"``, ``"flip_h"`` or
        ``"color_permute"``. Augmented grids are fresh C-contiguous arrays
        that keep the input dtype, never alias the inputs, and can be passed
        directly to ``torch.from_numpy``.
    """
    if random.random() > aug_prob:
        return input_grid, output_grid, "none"
    aug_name = random.choice(["rotate90", "flip_h", "flip_v", "transpose", "color_permute"])
    # Accept nested lists as well as arrays (``.T`` and ``.ravel`` need ndarrays).
    src_in = np.asarray(input_grid)
    src_out = np.asarray(output_grid)
    # The geometric ops below return views, some with negative strides.
    # ``.copy()`` (C order) materializes them, so callers never mutate the
    # source data by accident and torch.from_numpy accepts the result.
    if aug_name == "rotate90":
        k = random.randint(1, 3)
        return np.rot90(src_in, k=k).copy(), np.rot90(src_out, k=k).copy(), f"rotate90_{k}"
    elif aug_name == "flip_h":
        return np.flip(src_in, 0).copy(), np.flip(src_out, 0).copy(), "flip_h"
    elif aug_name == "flip_v":
        return np.flip(src_in, 1).copy(), np.flip(src_out, 1).copy(), "flip_v"
    elif aug_name == "transpose":
        return src_in.T.copy(), src_out.T.copy(), "transpose"
    elif aug_name == "color_permute":
        # One shared mapping across both grids keeps the pair consistent.
        colors = np.unique(np.concatenate([src_in.ravel(), src_out.ravel()]))
        colors = colors[colors > 0].tolist()
        if len(colors) > 1:
            shuffled = colors.copy()
            random.shuffle(shuffled)
            cmap = dict(zip(colors, shuffled))
            cmap[0] = 0

            def _remap(grid):
                # Masks come from the untouched grid, so substitutions never
                # chain. Values absent from cmap (e.g. negative padding) are
                # kept unchanged, and the dtype is preserved.
                out = grid.copy()
                for old, new in cmap.items():
                    out[grid == old] = new
                return out

            return _remap(src_in), _remap(src_out), "color_permute"
    # Reached when color_permute finds fewer than two non-zero colors to shuffle.
    return input_grid, output_grid, "none"