# DiffICM / 1_feature_extractor / augmentations.py
# (last commit 1633fcc by Qiyp: "code of stage1 & 3, remove large files")
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from torchvision import transforms
import torch
# Module logger registered under the shared "dinov2" name; not referenced by the code visible in this section.
logger = logging.getLogger("dinov2")
def collate_data_and_cast_aug(
    samples_list,
    mask_ratio,
    mask_probability,
    dtype,
    n_tokens=None,
    mask_first_n=False,
    mask_generator=None,
    clone_batch=1,
):
    """Collate samples into an image batch and build per-sample boolean token masks.

    Each sample is indexed as ``(image, label, ...)``: ``s[0]`` is stacked into the
    image batch and ``s[1]`` is collected as an integer label.

    Args:
        samples_list: Sequence of samples; ``s[0]`` must be a tensor (it is
            ``torch.stack``-ed) and ``s[1]`` an integer label.
        mask_ratio: Fraction of ``n_tokens`` to mask. Either a two-element
            tuple/list ``(lo, hi)``, or a single ratio.
            For ``(lo, hi)``, the range is split into equal sub-intervals and mask
            ``i`` draws its ratio uniformly from sub-interval ``i``.
            A single ratio may be a plain number or a sequence (list, tensor, ...)
            whose first element is used.
        mask_probability: Fraction of the batch that receives a non-empty mask;
            ``int(B * mask_probability)`` samples are masked.
        dtype: dtype the collated image batch is cast to.
        n_tokens: Tokens per sample, used to turn ratios into mask sizes. It is
            only read when at least one mask is generated.
        mask_first_n: If True, masks are not shuffled, so the masked ones come first.
        mask_generator: Callable invoked with an integer token count and returning
            an array-like of booleans. It must also provide ``get_none_mask()``
            for the masks of unmasked samples.
        clone_batch: Number of masks generated per masked sample.

    Returns:
        dict with keys:
            ``collated_global_crops``: stacked images, cast to ``dtype``.
            ``collated_global_labels``: LongTensor of labels.
            ``collated_masks``: boolean masks flattened to ``(num_masks, -1)``.
            ``mask_indices_list``: flat indices of masked tokens.
            ``masks_weight``: per-masked-token weight ``1 / (masked tokens in that row)``.
            ``upperbound``: sum of the per-mask maximum token counts.
            ``n_masked_patches``: 1-element LongTensor with the masked-token count.
            ``masks_enc`` / ``masks_pred``: 1-element int32 zero placeholders.
    """
    import numbers  # local import: only used to recognise a scalar mask_ratio

    # Exactly one image per sample is collated.
    n_global_crops = 1
    collated_global_crops = torch.stack(
        [s[i] for i in range(n_global_crops) for s in samples_list]
    )
    labels = torch.LongTensor([s[1] for s in samples_list])
    collated_global_labels = labels.repeat(n_global_crops)

    B = len(collated_global_crops)
    N = n_tokens
    n_samples_masked = int(B * mask_probability)

    # Placeholder tensors returned under fixed keys.
    masks_enc = torch.full((1,), 0, dtype=torch.int32)
    masks_pred = torch.full((1,), 0, dtype=torch.int32)

    masks_list = []
    upperbound = 0
    number_masks = n_samples_masked * clone_batch
    if isinstance(mask_ratio, (tuple, list)) and len(mask_ratio) == 2:
        # Split [lo, hi] into number_masks equal bins so the masked samples
        # jointly cover the whole ratio range. Bin edges stay 0-d tensors,
        # which preserves the original rounding of int(N * ...).
        probs = torch.linspace(*mask_ratio, number_masks + 1)
        for prob_min, prob_max in zip(probs[:-1], probs[1:]):
            n_masked = int(N * random.uniform(prob_min, prob_max))
            masks_list.append(torch.BoolTensor(mask_generator(n_masked)))
            upperbound += int(N * prob_max)
    else:
        # Same ratio for every masked sample. Accept a bare number as well as a
        # sequence whose first element is the ratio; indexing a plain float
        # used to raise TypeError.
        ratio = mask_ratio if isinstance(mask_ratio, numbers.Real) else mask_ratio[0]
        for _ in range(number_masks):
            # Computed inside the loop so n_tokens=None is fine when nothing is masked.
            masks_list.append(torch.BoolTensor(mask_generator(int(N * ratio))))
            upperbound += int(N * ratio)

    # Unmasked samples get an empty mask.
    # NOTE(review): with clone_batch > 1 this gives
    # n_samples_masked * clone_batch + (B - n_samples_masked) masks, not
    # B * clone_batch. Confirm the caller expects that count.
    masks_list.extend(
        torch.BoolTensor(mask_generator.get_none_mask())
        for _ in range(n_samples_masked, B)
    )

    # Shuffle so masked samples are spread over the batch; MAE-style
    # (mask_first_n=True) keeps masked samples first.
    if not mask_first_n and mask_probability > 0.0:
        random.shuffle(masks_list)

    collated_masks = torch.stack(masks_list).flatten(1)
    mask_indices_list = collated_masks.flatten().nonzero().flatten()
    # Weight each masked token by 1 / (number of masked tokens in its row).
    # The clamp avoids division by zero for rows with no masked tokens.
    masks_weight = (
        (1 / collated_masks.sum(-1).clamp(min=1.0))
        .unsqueeze(-1)
        .expand_as(collated_masks)[collated_masks]
    )

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_global_labels": collated_global_labels,
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
        "masks_enc": masks_enc,
        "masks_pred": masks_pred,
    }