"""Masking strategy for Grid-JEPA.""" import random import torch class GridMaskingStrategy: def __init__(self, grid_size=64, num_target_blocks=4, target_scale=(0.05, 0.15), target_aspect_ratio=(0.5, 2.0), context_scale=(0.85, 1.0), min_target_size=2): self.grid_size = grid_size self.num_patches = grid_size * grid_size self.num_target_blocks = num_target_blocks self.target_scale = target_scale self.target_aspect_ratio = target_aspect_ratio self.context_scale = context_scale self.min_target_size = min_target_size def _sample_block(self, scale_range, aspect_range): total = self.grid_size * self.grid_size area = max(int(random.uniform(*scale_range) * total), self.min_target_size * self.min_target_size) aspect = random.uniform(*aspect_range) h = max(min(int((area / aspect) ** 0.5), self.grid_size), self.min_target_size) w = max(min(int(area / h), self.grid_size), self.min_target_size) rs = random.randint(0, max(0, self.grid_size - h)) cs = random.randint(0, max(0, self.grid_size - w)) return rs, min(rs + h, self.grid_size), cs, min(cs + w, self.grid_size) def _block_to_indices(self, block): rs, re, cs, ce = block return [r * self.grid_size + c for r in range(rs, re) for c in range(cs, ce)] def sample_masks(self): target_blocks = [self._sample_block(self.target_scale, self.target_aspect_ratio) for _ in range(self.num_target_blocks)] target_indices = set() for b in target_blocks: target_indices.update(self._block_to_indices(b)) target_indices = sorted(list(target_indices)) ctx_block = self._sample_block(self.context_scale, (0.75, 1.5)) ctx_indices = [i for i in self._block_to_indices(ctx_block) if i not in target_indices] tgt_mask = torch.zeros(self.num_patches, dtype=torch.bool) tgt_mask[target_indices] = True ctx_mask = torch.zeros(self.num_patches, dtype=torch.bool) ctx_mask[ctx_indices] = True return { 'target_indices': torch.tensor(target_indices, dtype=torch.long), 'context_indices': torch.tensor(ctx_indices, dtype=torch.long), 'target_mask': tgt_mask, 'context_mask': ctx_mask, } def sample_masks_batch(self, batch_size): all_ti, all_ci, all_tm, all_cm = [], [], [], [] for _ in range(batch_size): m = self.sample_masks() all_ti.append(m['target_indices']) all_ci.append(m['context_indices']) all_tm.append(m['target_mask']) all_cm.append(m['context_mask']) max_t = max(len(t) for t in all_ti) max_c = max(len(c) for c in all_ci) ti_p = torch.zeros(batch_size, max_t, dtype=torch.long) ci_p = torch.zeros(batch_size, max_c, dtype=torch.long) for i, (t, c) in enumerate(zip(all_ti, all_ci)): ti_p[i, :len(t)] = t ci_p[i, :len(c)] = c return { 'target_indices': ti_p, 'target_lengths': torch.tensor([len(t) for t in all_ti]), 'context_indices': ci_p, 'context_lengths': torch.tensor([len(c) for c in all_ci]), 'target_mask': torch.stack(all_tm), 'context_mask': torch.stack(all_cm), }