| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import random |
| |
|
| |
|
def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
    """Stack the crops of a list of samples and build one boolean mask per global crop.

    Args:
        samples_list: Sequence of samples; element ``[0]`` of each sample is a
            mapping with ``"global_crops"`` and ``"local_crops"`` sequences of
            tensors. The number of crops is read from the first sample only.
        mask_ratio_tuple: ``(start, end)`` pair unpacked into ``torch.linspace``
            to produce the edges of the per-crop masking-ratio intervals.
        mask_probability: Fraction of the stacked global crops that receive a
            non-trivial mask (``int(B * mask_probability)`` of them).
        dtype: Dtype the stacked crops are cast to in the returned dict.
        n_tokens: Multiplier turning a masking ratio into the count handed to
            ``mask_generator``. Only read when at least one crop is masked.
        mask_generator: Callable invoked with an integer count; its result is
            wrapped in ``torch.BoolTensor``.
            NOTE(review): presumably returns a 2-D boolean patch grid - confirm
            against the generator implementation.

    Returns:
        A dict with the keys ``collated_global_crops``, ``collated_local_crops``,
        ``collated_masks``, ``mask_indices_list``, ``masks_weight``,
        ``upperbound`` and ``n_masked_patches``.
    """
    reference_crops = samples_list[0][0]
    crop_counts = {key: len(reference_crops[key]) for key in ("global_crops", "local_crops")}

    def _stack_crops(key):
        # Crop index is the outer loop: every sample's crop 0 comes first,
        # then every sample's crop 1, and so on.
        return torch.stack(
            [sample[0][key][crop_idx] for crop_idx in range(crop_counts[key]) for sample in samples_list]
        )

    collated_global_crops = _stack_crops("global_crops")
    collated_local_crops = _stack_crops("local_crops")

    batch_size = len(collated_global_crops)
    masked_count = int(batch_size * mask_probability)
    # masked_count + 1 edges delimit masked_count consecutive ratio intervals.
    ratio_edges = torch.linspace(*mask_ratio_tuple, masked_count + 1)

    upperbound = 0
    masks = []
    for ratio_lo, ratio_hi in zip(ratio_edges[:-1], ratio_edges[1:]):
        # Draw the ratio inside this crop's own interval, then scale to a count.
        n_to_mask = int(n_tokens * random.uniform(ratio_lo, ratio_hi))
        masks.append(torch.BoolTensor(mask_generator(n_to_mask)))
        upperbound += int(n_tokens * ratio_hi)
    # The remaining crops get whatever the generator yields for a zero count.
    masks.extend(torch.BoolTensor(mask_generator(0)) for _ in range(batch_size - masked_count))

    # Shuffle so the masked crops are not always the leading ones.
    random.shuffle(masks)

    collated_masks = torch.stack(masks).flatten(1)
    mask_indices_list = collated_masks.flatten().nonzero().flatten()

    # Each selected position is weighted by 1 / (number of selected positions in
    # its row); the clamp avoids dividing by zero for rows with nothing selected.
    per_row_weight = 1 / collated_masks.sum(-1).clamp(min=1.0)
    masks_weight = per_row_weight.unsqueeze(-1).expand_as(collated_masks)[collated_masks]

    n_masked_patches = torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long)

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_local_crops": collated_local_crops.to(dtype),
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": n_masked_patches,
    }
| |
|