| | |
| | |
| | |
| | |
| |
|
| | import logging |
| |
|
| | from torchvision import transforms |
| |
|
| | from .transforms import ( |
| | GaussianBlur, |
| | make_normalize_transform, |
| | ) |
| |
|
| |
|
# Module-level logger. It is looked up under the fixed name "dinov2" rather
# than this module's __name__, so records emitted here go to that logger.
logger = logging.getLogger("dinov2")
| |
|
| |
|
class DataAugmentationDINO:
    """Multi-crop augmentation: two global views plus several local views.

    Calling an instance on an image returns a dict with the keys
    ``"global_crops"``, ``"global_crops_teacher"``, ``"local_crops"`` and
    ``"offsets"`` (see ``__call__``).

    Each view is produced by a geometric stage (random resized crop with
    bicubic interpolation, then a horizontal flip with probability 0.5)
    followed by a photometric stage (colour jitter / grayscale, a
    view-specific blur or solarize step, then ``ToTensor`` and the transform
    returned by ``make_normalize_transform()``).

    Args:
        global_crops_scale: ``scale`` argument forwarded to
            ``RandomResizedCrop`` for the global views.
        local_crops_scale: ``scale`` argument forwarded to
            ``RandomResizedCrop`` for the local views.
        local_crops_number: Number of local views generated per call.
        global_crops_size: Output size of the global crops.
        local_crops_size: Output size of the local crops.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        # Report the configuration once at construction time. Lazy %-style
        # arguments keep the emitted text identical to the previous f-strings
        # while skipping the formatting when INFO is disabled.
        banner = "###################################"
        logger.info(banner)
        logger.info("Using data augmentation parameters:")
        for name, value in (
            ("global_crops_scale", global_crops_scale),
            ("local_crops_scale", local_crops_scale),
            ("local_crops_number", local_crops_number),
            ("global_crops_size", global_crops_size),
            ("local_crops_size", local_crops_size),
        ):
            logger.info("%s: %s", name, value)
        logger.info(banner)

        # Geometric stage: random resized crop and flip. Global and local
        # views differ only in crop size and scale.
        self.geometric_augmentation_global = self._make_geometric_augmentation(
            global_crops_size, global_crops_scale
        )
        self.geometric_augmentation_local = self._make_geometric_augmentation(
            local_crops_size, local_crops_scale
        )

        # Colour distortions: jitter applied with probability 0.8, then
        # grayscale conversion with probability 0.2. The same transform
        # object is reused by all three view pipelines below.
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        # View-specific extras. GaussianBlur is a project-local transform;
        # its ``p`` is presumably the application probability (confirm in
        # .transforms).
        global_transfo1_extra = GaussianBlur(p=1.0)
        global_transfo2_extra = transforms.Compose(
            [
                GaussianBlur(p=0.1),
                transforms.RandomSolarize(threshold=128, p=0.2),
            ]
        )
        local_transfo_extra = GaussianBlur(p=0.5)

        # Normalization: tensor conversion followed by the project's
        # normalize transform.
        self.normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                make_normalize_transform(),
            ]
        )

        self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
        self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
        self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])

    @staticmethod
    def _make_geometric_augmentation(crop_size, crop_scale):
        """Build the random-resized-crop + horizontal-flip pipeline for one view type."""
        return transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    crop_size, scale=crop_scale, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        )

    def __call__(self, image):
        """Generate all augmented views of ``image``.

        Args:
            image: Input accepted by the configured transforms (presumably a
                PIL image, since ``ToTensor`` runs last - confirm with caller).

        Returns:
            dict with, in insertion order:
                ``"global_crops"``: list of the two global views.
                ``"global_crops_teacher"``: a separate list holding the same
                    two global-view objects.
                ``"local_crops"``: list of ``self.local_crops_number`` local
                    views.
                ``"offsets"``: always an empty tuple.
        """
        # Global crops: each view gets its own independent geometric draw,
        # followed by its own photometric pipeline.
        im1_base = self.geometric_augmentation_global(image)
        global_crop_1 = self.global_transfo1(im1_base)

        im2_base = self.geometric_augmentation_global(image)
        global_crop_2 = self.global_transfo2(im2_base)

        # Local crops: one independent draw per requested view.
        local_crops = [
            self.local_transfo(self.geometric_augmentation_local(image))
            for _ in range(self.local_crops_number)
        ]

        return {
            "global_crops": [global_crop_1, global_crop_2],
            # Global crops for the teacher: same view objects, distinct list.
            "global_crops_teacher": [global_crop_1, global_crop_2],
            "local_crops": local_crops,
            "offsets": (),
        }
| |
|