# arc-agi-3-grid-jepa/src/data/masking.py
# guychuk: Add grid masking strategy (commit b52c75d, verified)
"""Masking strategy for Grid-JEPA."""
import random
import torch
class GridMaskingStrategy:
    """Sample rectangular target and context patch masks on a square grid.

    The grid has ``grid_size * grid_size`` patches addressed by their
    row-major flat index ``row * grid_size + col``. Every sample consists of
    ``num_target_blocks`` rectangular target blocks (which may overlap each
    other) and one rectangular context block from which all target patches
    are removed, so target and context patches never overlap.

    All randomness comes from the global ``random`` module state; call
    ``random.seed`` beforehand for reproducible masks.

    Args:
        grid_size: Side length of the square grid, in patches.
        num_target_blocks: Number of target blocks drawn per sample.
        target_scale: ``(low, high)`` fraction of the grid area covered by
            each target block before clamping.
        target_aspect_ratio: ``(low, high)`` range of the width / height
            ratio of a target block.
        context_scale: ``(low, high)`` fraction of the grid area covered by
            the context block before target patches are removed.
        min_target_size: Minimum side length of a sampled block. It is
            applied to the context block as well, because both kinds of
            block are drawn by the same helper.
        context_aspect_ratio: ``(low, high)`` range of the width / height
            ratio of the context block.
    """

    def __init__(self, grid_size=64, num_target_blocks=4,
                 target_scale=(0.05, 0.15), target_aspect_ratio=(0.5, 2.0),
                 context_scale=(0.85, 1.0), min_target_size=2,
                 context_aspect_ratio=(0.75, 1.5)):
        self.grid_size = grid_size
        self.num_patches = grid_size * grid_size
        self.num_target_blocks = num_target_blocks
        self.target_scale = target_scale
        self.target_aspect_ratio = target_aspect_ratio
        self.context_scale = context_scale
        self.context_aspect_ratio = context_aspect_ratio
        self.min_target_size = min_target_size

    def _sample_block(self, scale_range, aspect_range):
        """Draw one random rectangle that lies inside the grid.

        Args:
            scale_range: ``(low, high)`` fraction of the grid area to cover.
            aspect_range: ``(low, high)`` range of the width / height ratio.

        Returns:
            ``(row_start, row_end, col_start, col_end)`` with exclusive end
            coordinates, all clamped to ``[0, grid_size]``.
        """
        side = self.grid_size
        # Floor at one patch so that min_target_size=0 can neither divide by
        # zero below nor yield an empty block.
        min_side = max(1, self.min_target_size)

        area = max(int(random.uniform(*scale_range) * side * side),
                   min_side * min_side)
        aspect = random.uniform(*aspect_range)

        # aspect = width / height, hence height = sqrt(area / aspect).
        height = max(min(int((area / aspect) ** 0.5), side), min_side)
        width = max(min(int(area / height), side), min_side)

        # max(0, ...) keeps randint valid when min_side exceeds the grid.
        row_start = random.randint(0, max(0, side - height))
        col_start = random.randint(0, max(0, side - width))
        return (row_start, min(row_start + height, side),
                col_start, min(col_start + width, side))

    def _block_to_indices(self, block):
        """Return the row-major flat indices of ``block`` in ascending order.

        Args:
            block: ``(row_start, row_end, col_start, col_end)`` with
                exclusive end coordinates.
        """
        row_start, row_end, col_start, col_end = block
        return [row * self.grid_size + col
                for row in range(row_start, row_end)
                for col in range(col_start, col_end)]

    def sample_masks(self):
        """Sample target and context patches for a single grid.

        Returns:
            A dict with:

            * ``'target_indices'``: sorted, de-duplicated long tensor of the
              flat indices covered by any target block.
            * ``'context_indices'``: ascending long tensor of the flat
              indices in the context block that are not target patches.
            * ``'target_mask'`` / ``'context_mask'``: bool tensors of length
              ``num_patches`` that are ``True`` exactly at those indices.
        """
        target_blocks = [
            self._sample_block(self.target_scale, self.target_aspect_ratio)
            for _ in range(self.num_target_blocks)
        ]
        # Keep the set around: it de-duplicates overlapping blocks and gives
        # O(1) membership tests when filtering the context block below.
        target_set = set()
        for block in target_blocks:
            target_set.update(self._block_to_indices(block))

        context_block = self._sample_block(self.context_scale,
                                           self.context_aspect_ratio)
        context_list = [idx for idx in self._block_to_indices(context_block)
                        if idx not in target_set]

        target_indices = torch.tensor(sorted(target_set), dtype=torch.long)
        context_indices = torch.tensor(context_list, dtype=torch.long)

        # Index with long tensors so that empty index sets are well defined.
        target_mask = torch.zeros(self.num_patches, dtype=torch.bool)
        target_mask[target_indices] = True
        context_mask = torch.zeros(self.num_patches, dtype=torch.bool)
        context_mask[context_indices] = True

        return {
            'target_indices': target_indices,
            'context_indices': context_indices,
            'target_mask': target_mask,
            'context_mask': context_mask,
        }

    def sample_masks_batch(self, batch_size):
        """Sample ``batch_size`` independent mask sets and collate them.

        Index tensors are right-padded with zeros up to the longest sample in
        the batch. Zero is also a valid patch index, so use the ``*_lengths``
        entries (or the boolean masks) to tell padding from real indices.

        Args:
            batch_size: Number of samples to draw; ``0`` yields empty tensors.

        Returns:
            A dict with ``'target_indices'`` and ``'context_indices'`` (long,
            ``(batch_size, max_len)``), ``'target_lengths'`` and
            ``'context_lengths'`` (long, ``(batch_size,)``), and
            ``'target_mask'`` and ``'context_mask'`` (bool,
            ``(batch_size, num_patches)``).

        Raises:
            ValueError: If ``batch_size`` is negative.
        """
        if batch_size < 0:
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")

        samples = [self.sample_masks() for _ in range(batch_size)]

        def pad(key):
            # Right-pad the variable-length index tensors into one matrix.
            rows = [sample[key] for sample in samples]
            lengths = [len(row) for row in rows]
            padded = torch.zeros(batch_size, max(lengths, default=0),
                                 dtype=torch.long)
            for i, row in enumerate(rows):
                padded[i, :len(row)] = row
            # Explicit dtype: an empty list would otherwise become float32.
            return padded, torch.tensor(lengths, dtype=torch.long)

        def stack(key):
            # torch.stack rejects an empty list, so build the empty batch
            # explicitly.
            if not samples:
                return torch.zeros(0, self.num_patches, dtype=torch.bool)
            return torch.stack([sample[key] for sample in samples])

        target_indices, target_lengths = pad('target_indices')
        context_indices, context_lengths = pad('context_indices')
        return {
            'target_indices': target_indices,
            'target_lengths': target_lengths,
            'context_indices': context_indices,
            'context_lengths': context_lengths,
            'target_mask': stack('target_mask'),
            'context_mask': stack('context_mask'),
        }