File size: 3,398 Bytes
a10ce46 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 | import random
import torch
import numpy as np
def complete_mask_randomly_np(mask, num_masking_patches, rng):
    """Top up ``mask`` with random patches until ``num_masking_patches`` are set.

    Extra positions are drawn uniformly, without replacement, from the
    currently unmasked entries. A mask that already has at least
    ``num_masking_patches`` set entries is returned unchanged.

    Args:
        mask: Array of any shape whose nonzero entries count as masked. It is
            modified in place (also when it is a non-contiguous view).
        num_masking_patches: Desired total number of masked entries.
        rng: ``numpy.random.Generator`` used for the random choice.

    Returns:
        ``mask`` itself, after completion.

    Raises:
        ValueError: If ``num_masking_patches`` exceeds ``mask.size``.
    """
    # Read-only flattened view; reshape() may return a copy for
    # non-contiguous input, so it is never written to.
    flat = np.asarray(mask).reshape(-1).astype(bool, copy=False)
    missing = int(num_masking_patches - np.count_nonzero(flat))
    if missing <= 0:
        return mask
    available = np.flatnonzero(~flat)
    if missing > available.size:
        raise ValueError(
            f"cannot mask {num_masking_patches} patches: "
            f"mask only has {flat.size} entries"
        )
    chosen = rng.choice(available, size=missing, replace=False)
    # np.put indexes the flattened target and writes back even into
    # non-contiguous arrays, unlike assigning into a reshape() copy.
    np.put(mask, chosen, True)
    return mask
class IBotMasker:
    """Block-wise random patch masker over an ``h x w`` patch grid.

    ``__call__`` repeatedly stamps random axis-aligned rectangles onto a
    boolean grid (via ``_mask``) until the requested number of patches is
    covered or no rectangle can be placed. It then hands the result to
    ``complete_mask_randomly_np`` to fill any shortfall with random single
    patches.

    Args:
        input_size: Grid size, either an ``int`` (square grid) or an
            ``(h, w)`` pair.
        num_masking_patches: Default patch count used by ``__call__`` when
            none is passed. Also the fallback for ``max_num_patches``.
        min_num_patches: Lower bound of the target area sampled for each
            rectangle.
        max_num_patches: Upper bound on the new patches one rectangle may add.
            Falls back to ``num_masking_patches`` when falsy.
        min_aspect: Minimum rectangle aspect ratio (height / width).
        max_aspect: Maximum aspect ratio; ``1 / min_aspect`` when falsy.
        max_tries: Rectangle sampling attempts per ``_mask`` call.
    """

    def __init__(
        self,
        input_size,
        num_masking_patches=None,
        min_num_patches=0,
        max_num_patches=None,
        min_aspect=0.3,
        max_aspect=3.33,
        max_tries=10,
    ):
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        self.h, self.w = input_size
        self.num_patches = self.h * self.w
        self.min_num_patches = min_num_patches
        self.num_masking_patches = num_masking_patches
        self.max_num_patches = max_num_patches or num_masking_patches
        # Aspect ratios are sampled log-uniformly between these bounds.
        self.log_min_aspect = np.log(min_aspect)
        self.log_max_aspect = np.log(max_aspect or 1 / min_aspect)
        self.max_tries = max_tries

    def __call__(self, num_masking_patches=None, starting_mask=None, rng=None):
        """Return a boolean ``(h, w)`` mask with ``num_masking_patches`` set.

        Args:
            num_masking_patches: Number of patches to mask. When ``None``, the
                value given to the constructor is used.
            starting_mask: Optional mask to extend. It is copied (and coerced
                to bool), never modified.
            rng: ``numpy.random.Generator``. A fresh, unseeded one is created
                when ``None``.

        Returns:
            A new boolean array of shape ``(h, w)``.

        Raises:
            TypeError: If no patch count is given here or at construction.
        """
        if num_masking_patches is None:
            num_masking_patches = self.num_masking_patches
        if num_masking_patches is None:
            raise TypeError(
                "num_masking_patches must be passed to __call__ or __init__"
            )
        if rng is None:
            rng = np.random.default_rng()
        if starting_mask is None:
            mask = np.zeros((self.h, self.w), dtype=np.bool_)
        else:
            # np.array copies by default. Coercing to bool keeps the `~`
            # negations below meaningful for integer masks.
            mask = np.array(starting_mask, dtype=np.bool_)
        mask_count = int(np.count_nonzero(mask))
        while mask_count < num_masking_patches:
            max_mask = num_masking_patches - mask_count
            if self.max_num_patches is not None:
                max_mask = min(max_mask, self.max_num_patches)
            delta = self._mask(mask, max_mask, rng)
            if delta == 0:
                # No rectangle fit within max_tries; finish randomly below.
                break
            mask_count += delta
        return complete_mask_randomly_np(mask, num_masking_patches, rng)

    def _mask(self, mask, max_mask_patches, rng):
        """Try to stamp one random rectangle onto ``mask`` in place.

        A rectangle is accepted only if it adds between 1 and
        ``max_mask_patches`` new patches and is strictly smaller than the grid
        in both dimensions.

        Returns:
            The number of newly masked patches, or 0 if every attempt failed.
        """
        # Generator.uniform requires high - low >= 0, but the remaining
        # budget can drop below min_num_patches, so clamp the lower bound.
        low = min(self.min_num_patches, max_mask_patches)
        for _ in range(self.max_tries):
            target = rng.uniform(low, max_mask_patches)
            aspect = np.exp(rng.uniform(self.log_min_aspect, self.log_max_aspect))
            h = int(round(np.sqrt(target * aspect)))
            w = int(round(np.sqrt(target / aspect)))
            if h <= 0 or w <= 0 or h >= self.h or w >= self.w:
                continue
            # integers() has an exclusive upper bound, so the rectangle fits.
            top = rng.integers(0, self.h - h + 1)
            left = rng.integers(0, self.w - w + 1)
            region = mask[top : top + h, left : left + w]
            newly = int(np.count_nonzero(~region))
            if 0 < newly <= max_mask_patches:
                region[:] = True
                return newly
        return 0
def generate_masks(
    mask_generator,
    number_of_samples,
    mask_prob=0.5,
    per_sample_range=(0.1, 0.5),
    rng=None,
):
    """Build a shuffled batch of flattened boolean masks.

    ``int(number_of_samples * mask_prob)`` samples (clipped to the batch
    size) request masks whose masking ratio is spread linearly over
    ``per_sample_range``. The remaining samples request 0 patches. The batch
    order is then shuffled.

    Args:
        mask_generator: Callable exposing ``num_patches`` and accepting
            ``num_masking_patches=`` (plus ``rng=`` when ``rng`` is given),
            returning one numpy mask per call.
        number_of_samples: Batch size.
        mask_prob: Fraction of samples that get a non-empty mask request.
        per_sample_range: ``(low, high)`` masking ratios for those samples.
        rng: Optional ``numpy.random.Generator``. When given, it is forwarded
            to every ``mask_generator`` call and drives the shuffle, which
            makes the batch reproducible. When ``None``, the global
            ``random.shuffle`` is used as before.

    Returns:
        ``torch.bool`` tensor of shape ``(number_of_samples, K)``, where K is
        the element count of one generated mask. An empty batch yields shape
        ``(0, num_patches)``.
    """
    # Clip so that mask_prob > 1 cannot build a longer ratio ramp than the
    # batch and silently truncate it.
    num_masks = min(int(number_of_samples * mask_prob), number_of_samples)
    num_tokens = mask_generator.num_patches
    ratios = np.linspace(*per_sample_range, num=num_masks)
    extra = {} if rng is None else {"rng": rng}
    masks = [
        mask_generator(num_masking_patches=int(ratio * num_tokens), **extra)
        for ratio in ratios
    ]
    masks.extend(
        mask_generator(num_masking_patches=0, **extra)
        for _ in range(number_of_samples - num_masks)
    )
    if not masks:
        # np.stack refuses an empty sequence.
        return torch.zeros((0, num_tokens), dtype=torch.bool)
    if rng is None:
        random.shuffle(masks)
    else:
        masks = [masks[i] for i in rng.permutation(len(masks))]
    # np.stack(..., dtype=...) needs NumPy >= 1.24; astype works everywhere.
    stacked = np.stack(masks).astype(bool, copy=False)
    return torch.from_numpy(stacked).flatten(1, -1)
|