|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import random |
|
|
from torchvision import transforms |
|
|
import torch |
|
|
|
|
|
logger = logging.getLogger("dinov2") |
|
|
|
|
|
|
|
|
|
|
|
def collate_data_and_cast_aug(
    samples_list,
    mask_ratio,
    mask_probability,
    dtype,
    n_tokens=None,
    mask_first_n=False,
    mask_generator=None,
    clone_batch=1,
):
    """Collate a list of (crop, label) samples into a training batch with patch masks.

    Args:
        samples_list: sequence of samples where ``s[0]`` is the (single) global
            crop tensor and ``s[1]`` is an integer label.
        mask_ratio: either a scalar ratio, a 1-element tuple/list ``(r,)``, or a
            2-element tuple/list ``(lo, hi)`` defining a range of ratios.
        mask_probability: fraction of the batch that receives a non-empty mask.
        dtype: dtype the collated crops are cast to (e.g. ``torch.half``).
        n_tokens: number of patch tokens per crop; masks are sized against this.
        mask_first_n: when True, masked samples stay at the front of the batch
            (no shuffling of the mask list).
        mask_generator: object producing masks; called as ``mask_generator(k)``
            for a mask with ~k active patches and exposing ``get_none_mask()``
            for an all-False mask.
        clone_batch: multiplier applied to the number of generated masks.
            NOTE(review): when clone_batch > 1 the total mask count is
            n_samples_masked * clone_batch + (B - n_samples_masked), which only
            equals B for clone_batch == 1 — presumably the caller clones crops
            to match; verify against the data-loading pipeline.

    Returns:
        dict with collated crops/labels, boolean patch masks, flat masked-patch
        indices, per-patch loss weights, an ``upperbound`` on masked patches,
        and placeholder ``masks_enc`` / ``masks_pred`` tensors.
    """
    n_global_crops = 1

    assert n_global_crops > 0, "global crops number should be > 0"
    # Stack crop 0 of every sample into a single (B, ...) tensor.
    collated_global_crops = torch.stack([s[i] for i in range(n_global_crops) for s in samples_list])

    labels = torch.LongTensor([s[1] for s in samples_list])
    # Repeat labels once per global crop (no-op while n_global_crops == 1).
    collated_global_labels = labels.repeat(n_global_crops)

    B = len(collated_global_crops)
    N = n_tokens
    n_samples_masked = int(B * mask_probability)

    masks_list = []
    # Upper bound on the total number of masked patches; downstream code can
    # use it to pre-allocate buffers.
    upperbound = 0

    # Placeholders kept so the returned dict matches the JEPA-style interface.
    masks_enc = torch.full((1,), 0, dtype=torch.int32)
    masks_pred = torch.full((1,), 0, dtype=torch.int32)

    number_masks = n_samples_masked * clone_batch

    if isinstance(mask_ratio, (tuple, list)) and len(mask_ratio) == 2:
        # Ratio range: partition [lo, hi] into `number_masks` sub-intervals and
        # draw each mask's ratio uniformly inside its own sub-interval.
        probs = torch.linspace(*mask_ratio, number_masks + 1)
        for i in range(number_masks):
            prob_min = probs[i]
            prob_max = probs[i + 1]
            masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
            upperbound += int(N * prob_max)
    else:
        # Fixed ratio: accept either a bare scalar or a 1-element tuple/list.
        # (Previously this unconditionally did `mask_ratio[0]`, which crashed
        # on scalar inputs.)
        if isinstance(mask_ratio, (tuple, list)):
            mask_ratio = mask_ratio[0]
        for _ in range(number_masks):
            masks_list.append(torch.BoolTensor(mask_generator(int(N * mask_ratio))))
            upperbound += int(N * mask_ratio)

    # The remaining samples get an all-False (no-op) mask.
    for _ in range(n_samples_masked, B):
        masks_list.append(torch.BoolTensor(mask_generator.get_none_mask()))

    # Randomize which samples are masked unless the caller wants the masked
    # ones kept at the front of the batch.
    if not mask_first_n and mask_probability > 0.0:
        random.shuffle(masks_list)

    collated_masks = torch.stack(masks_list).flatten(1)
    mask_indices_list = collated_masks.flatten().nonzero().flatten()

    # Per-patch loss weight: 1 / (#masked patches in its sample), gathered for
    # masked positions only; clamp(min=1) guards against division by zero for
    # samples with an empty mask.
    masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_global_labels": collated_global_labels,
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
        "masks_enc": masks_enc,
        "masks_pred": masks_pred,
    }