| """Masking strategy for Grid-JEPA.""" |
| import random |
| import torch |
|
|
class GridMaskingStrategy:
    """Sample target / context patch masks over a square grid for Grid-JEPA.

    Patches are addressed by their row-major flat index
    ``row * grid_size + col`` in ``[0, grid_size ** 2)``. Each sample draws
    ``num_target_blocks`` rectangular target blocks (which may overlap) and a
    single rectangular context block. The context keeps only those patches of
    its block that no target block covers, so the two index sets are always
    disjoint.

    Randomness comes from the module-level ``random`` generator, so
    ``random.seed`` controls reproducibility.

    Args:
        grid_size: Side length of the square patch grid.
        num_target_blocks: Number of target rectangles drawn per sample.
        target_scale: ``(low, high)`` range for a target block's requested
            area, as a fraction of the whole grid.
        target_aspect_ratio: ``(low, high)`` range for a target block's
            width / height ratio.
        context_scale: ``(low, high)`` requested-area fraction range for the
            context block.
        min_target_size: Minimum side length of every sampled block (it is
            applied to the context block as well as the target blocks).
        context_aspect_ratio: ``(low, high)`` width / height range for the
            context block. Defaults to ``(0.75, 1.5)``.
    """

    def __init__(self, grid_size=64, num_target_blocks=4,
                 target_scale=(0.05, 0.15), target_aspect_ratio=(0.5, 2.0),
                 context_scale=(0.85, 1.0), min_target_size=2,
                 context_aspect_ratio=(0.75, 1.5)):
        self.grid_size = grid_size
        self.num_patches = grid_size * grid_size
        self.num_target_blocks = num_target_blocks
        self.target_scale = target_scale
        self.target_aspect_ratio = target_aspect_ratio
        self.context_scale = context_scale
        self.min_target_size = min_target_size
        self.context_aspect_ratio = context_aspect_ratio

    def _sample_block(self, scale_range, aspect_range):
        """Sample one axis-aligned rectangle placed uniformly inside the grid.

        The requested area is ``uniform(*scale_range) * num_patches`` (but at
        least ``min_target_size ** 2``). It is split into height/width so that
        width / height is approximately ``uniform(*aspect_range)``. Each side
        is then clipped to ``[min_target_size, grid_size]``, so the realized
        area can be smaller than requested when the aspect ratio pushes a
        side past the grid edge.

        Returns:
            ``(row_start, row_end, col_start, col_end)``, half-open bounds.
        """
        size = self.grid_size
        min_side = self.min_target_size
        area = max(int(random.uniform(*scale_range) * self.num_patches), min_side * min_side)
        aspect = random.uniform(*aspect_range)
        # h = sqrt(area / aspect) and w = area / h together give w / h == aspect.
        h = max(min(int((area / aspect) ** 0.5), size), min_side)
        w = max(min(int(area / h), size), min_side)
        # max(0, ...) guards the case min_target_size > grid_size.
        rs = random.randint(0, max(0, size - h))
        cs = random.randint(0, max(0, size - w))
        return rs, min(rs + h, size), cs, min(cs + w, size)

    def _block_to_indices(self, block):
        """Return the row-major flat indices (ascending) covered by ``block``."""
        rs, re, cs, ce = block
        return [r * self.grid_size + c for r in range(rs, re) for c in range(cs, ce)]

    def sample_masks(self):
        """Sample one set of disjoint target and context patches.

        Returns:
            dict with
              'target_indices'  1-D long tensor, ascending, union of target blocks
              'context_indices' 1-D long tensor, ascending, context block minus targets
              'target_mask'     bool tensor of shape (num_patches,)
              'context_mask'    bool tensor of shape (num_patches,)

        Note: if the target blocks happen to cover the entire context block,
        ``context_indices`` is empty.
        """
        # All target blocks are drawn before the context block, which keeps
        # the same random-number consumption order as earlier versions.
        target_blocks = [
            self._sample_block(self.target_scale, self.target_aspect_ratio)
            for _ in range(self.num_target_blocks)
        ]
        target_set = set()
        for block in target_blocks:
            target_set.update(self._block_to_indices(block))
        target_indices = sorted(target_set)

        ctx_block = self._sample_block(self.context_scale, self.context_aspect_ratio)
        # Membership test against the set (O(1)), not the sorted list (O(n)).
        ctx_indices = [i for i in self._block_to_indices(ctx_block) if i not in target_set]

        target_idx = torch.tensor(target_indices, dtype=torch.long)
        ctx_idx = torch.tensor(ctx_indices, dtype=torch.long)
        tgt_mask = torch.zeros(self.num_patches, dtype=torch.bool)
        tgt_mask[target_idx] = True
        ctx_mask = torch.zeros(self.num_patches, dtype=torch.bool)
        ctx_mask[ctx_idx] = True
        return {
            'target_indices': target_idx,
            'context_indices': ctx_idx,
            'target_mask': tgt_mask,
            'context_mask': ctx_mask,
        }

    def sample_masks_batch(self, batch_size):
        """Draw ``batch_size`` independent samples and collate them.

        Index tensors are right-padded with ``0`` up to the longest sample in
        the batch. Because ``0`` is also a valid patch index, consumers must
        use the ``*_lengths`` tensors, not the padding value, to separate real
        entries from padding.

        Returns:
            dict with
              'target_indices'  long (batch_size, max_targets), zero-padded
              'target_lengths'  long (batch_size,)
              'context_indices' long (batch_size, max_context), zero-padded
              'context_lengths' long (batch_size,)
              'target_mask'     bool (batch_size, num_patches)
              'context_mask'    bool (batch_size, num_patches)

        Raises:
            ValueError: if ``batch_size`` is negative.
        """
        if batch_size < 0:
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")
        samples = [self.sample_masks() for _ in range(batch_size)]
        all_ti = [m['target_indices'] for m in samples]
        all_ci = [m['context_indices'] for m in samples]

        # default=0 keeps an empty batch well-defined instead of raising.
        max_t = max((len(t) for t in all_ti), default=0)
        max_c = max((len(c) for c in all_ci), default=0)
        ti_p = torch.zeros(batch_size, max_t, dtype=torch.long)
        ci_p = torch.zeros(batch_size, max_c, dtype=torch.long)
        for i, (t, c) in enumerate(zip(all_ti, all_ci)):
            ti_p[i, :len(t)] = t
            ci_p[i, :len(c)] = c

        if samples:
            tgt_masks = torch.stack([m['target_mask'] for m in samples])
            ctx_masks = torch.stack([m['context_mask'] for m in samples])
        else:
            # torch.stack rejects an empty list; build the empty batch directly.
            tgt_masks = torch.zeros(0, self.num_patches, dtype=torch.bool)
            ctx_masks = torch.zeros(0, self.num_patches, dtype=torch.bool)

        return {
            'target_indices': ti_p,
            # Explicit dtype: torch.tensor([]) would otherwise infer float32.
            'target_lengths': torch.tensor([len(t) for t in all_ti], dtype=torch.long),
            'context_indices': ci_p,
            'context_lengths': torch.tensor([len(c) for c in all_ci], dtype=torch.long),
            'target_mask': tgt_masks,
            'context_mask': ctx_masks,
        }
|
|