import random import torch import numpy as np def complete_mask_randomly_np(mask, num_masking_patches, rng): flat = mask.reshape(-1) missing = num_masking_patches - flat.sum() if missing <= 0: return mask available = np.flatnonzero(~flat) chosen = rng.choice(available, size=missing, replace=False) flat[chosen] = True return mask class IBotMasker: def __init__( self, input_size, num_masking_patches=None, min_num_patches=0, max_num_patches=None, min_aspect=0.3, max_aspect=3.33, max_tries=10, ): if isinstance(input_size, int): input_size = (input_size, input_size) self.h, self.w = input_size self.num_patches = self.h * self.w self.min_num_patches = min_num_patches self.num_masking_patches = num_masking_patches self.max_num_patches = max_num_patches or num_masking_patches self.log_min_aspect = np.log(min_aspect) self.log_max_aspect = np.log(max_aspect or 1 / min_aspect) self.max_tries = max_tries def __call__(self, num_masking_patches, starting_mask=None, rng=None): if rng is None: rng = np.random.default_rng() if starting_mask is None: mask = np.zeros((self.h, self.w), dtype=np.bool_) else: mask = starting_mask.copy() mask_count = mask.sum() while mask_count < num_masking_patches: max_mask = num_masking_patches - mask_count if self.max_num_patches is not None: max_mask = min(max_mask, self.max_num_patches) delta = self._mask(mask, max_mask, rng) if delta == 0: break mask_count += delta return complete_mask_randomly_np(mask, num_masking_patches, rng) def _mask(self, mask, max_mask_patches, rng): for _ in range(self.max_tries): target = rng.uniform(self.min_num_patches, max_mask_patches) aspect = np.exp(rng.uniform(self.log_min_aspect, self.log_max_aspect)) h = int(round(np.sqrt(target * aspect))) w = int(round(np.sqrt(target / aspect))) if h <= 0 or w <= 0 or h >= self.h or w >= self.w: continue top = rng.integers(0, self.h - h + 1) left = rng.integers(0, self.w - w + 1) region = mask[top : top + h, left : left + w] newly = (~region).sum() if 0 < newly <= max_mask_patches: region[:] = True return newly return 0 def generate_masks( mask_generator, number_of_samples, mask_prob=0.5, per_sample_range=(0.1, 0.5) ): num_masks = int(number_of_samples * mask_prob) num_tokens = mask_generator.num_patches prob_per_sample = np.linspace(*per_sample_range, num=num_masks) masks = [ ( mask_generator(num_masking_patches=int(prob_per_sample[i] * num_tokens)) if i < num_masks else mask_generator(num_masking_patches=0) ) for i in range(number_of_samples) ] random.shuffle(masks) masks = np.stack(masks, dtype=bool) masks = torch.from_numpy(masks).flatten(1, -1) return masks