# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

import logging

import numpy as np
from torch import nn
from torchvision import transforms

from dinov3.data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, GaussianBlur, make_normalize_transform

logger = logging.getLogger("dinov3")


class DataAugmentationDINO(object):
    """Multi-crop augmentation pipeline for DINO-style self-supervised training.

    From one input image, produces two "global" crops, optional crops for a
    Gram teacher, and several smaller "local" crops.  Photometric distortions
    (color jitter, grayscale, Gaussian blur, solarization) and normalization
    follow the DINO recipe; the exact set applied to each crop is controlled
    by the constructor flags.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
        gram_teacher_crops_size=None,
        gram_teacher_no_distortions=False,
        teacher_no_color_jitter=False,
        local_crops_subset_of_global_crops=False,
        patch_size=16,
        share_color_jitter=False,
        horizontal_flips=True,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
    ):
        """Build the crop, distortion and normalization transforms.

        Args:
            global_crops_scale: (min, max) area-scale range for global crops.
            local_crops_scale: (min, max) area-scale range for local crops.
            local_crops_number: number of local crops produced per image.
            global_crops_size: output side length of global crops.
            local_crops_size: output side length of local crops.
            gram_teacher_crops_size: if not None, also produce crops of this
                size for the Gram teacher.
            gram_teacher_no_distortions: if True, Gram teacher crops skip the
                photometric distortions (global resize then happens before the
                distortions rather than after).
            teacher_no_color_jitter: if True, the teacher's global crops are
                only normalized, without jitter/blur/solarize.
            local_crops_subset_of_global_crops: if True, local crops are cut
                out of the transformed global crops on a patch-size grid
                instead of being independent random crops of the image.
            patch_size: grid step for cutting local crops out of global crops
                (only used when local_crops_subset_of_global_crops is True).
            share_color_jitter: if True, apply a single color jitter to the
                input image, shared by all crops, instead of per-crop jitter.
            horizontal_flips: whether random horizontal flips are enabled.
            mean: per-channel normalization mean.
            std: per-channel normalization std.
        """
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size
        self.gram_teacher_crops_size = gram_teacher_crops_size
        self.gram_teacher_no_distortions = gram_teacher_no_distortions
        self.teacher_no_color_jitter = teacher_no_color_jitter
        self.local_crops_subset_of_global_crops = local_crops_subset_of_global_crops
        self.patch_size = patch_size
        self.share_color_jitter = share_color_jitter
        self.mean = mean
        self.std = std

        logger.info("###################################")
        logger.info("Using data augmentation parameters:")
        logger.info(f"global_crops_scale: {global_crops_scale}")
        logger.info(f"local_crops_scale: {local_crops_scale}")
        logger.info(f"local_crops_number: {local_crops_number}")
        logger.info(f"global_crops_size: {global_crops_size}")
        logger.info(f"local_crops_size: {local_crops_size}")
        logger.info(f"gram_crops_size: {gram_teacher_crops_size}")
        logger.info(f"gram_teacher_no_distortions: {gram_teacher_no_distortions}")
        logger.info(f"teacher_no_color_jitter: {teacher_no_color_jitter}")
        logger.info(f"local_crops_subset_of_global_crops: {local_crops_subset_of_global_crops}")
        logger.info(f"patch_size if local_crops_subset_of_global_crops: {patch_size}")
        logger.info(f"share_color_jitter: {share_color_jitter}")
        logger.info(f"horizontal flips: {horizontal_flips}")
        logger.info("###################################")

        # Global crops and gram teacher crops can have different sizes. We first take a crop of the maximum size
        # and then resize it to the desired size for global and gram teacher crops.
        global_crop_max_size = max(global_crops_size, gram_teacher_crops_size if gram_teacher_crops_size else 0)

        # random resized crop and flip
        self.geometric_augmentation_global = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    global_crop_max_size,
                    scale=global_crops_scale,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(p=0.5 if horizontal_flips else 0.0),
            ]
        )
        resize_global = nn.Identity()  # Resize transform applied to global crops after random crop
        self.resize_global_post_transf = (
            nn.Identity()
        )  # Resize transform applied to global crops after all other transforms
        self.resize_gram_teacher = None  # Resize transform applied to crops for gram teacher
        if gram_teacher_crops_size is not None:
            # All resize transforms will do nothing if the crop size is already the desired size.
            if gram_teacher_no_distortions:
                # When there are no distortions for the gram teacher crop, we can resize before the distortions.
                # This is the preferred order, because it keeps the image size for the augmentations consistent,
                # which matters e.g. for GaussianBlur.
                resize_global = transforms.Resize(
                    global_crops_size,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                )
            else:
                # When there are distortions for the gram teacher crop, we need to resize after the distortions,
                # because the distortions are shared between global and gram teacher crops.
                self.resize_global_post_transf = transforms.Resize(
                    global_crops_size,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                )
            self.resize_gram_teacher = transforms.Resize(
                gram_teacher_crops_size,
                interpolation=transforms.InterpolationMode.BICUBIC,
            )

        self.geometric_augmentation_local = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    local_crops_size,
                    scale=local_crops_scale,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(p=0.5 if horizontal_flips else 0.0),
            ]
        )

        # color distortions / blurring
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        global_transfo1_extra = GaussianBlur(p=1.0)

        global_transfo2_extra = transforms.Compose(
            [
                GaussianBlur(p=0.1),
                transforms.RandomSolarize(threshold=128, p=0.2),
            ]
        )

        local_transfo_extra = GaussianBlur(p=0.5)

        # normalization
        self.normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                make_normalize_transform(mean=mean, std=std),
            ]
        )

        if self.share_color_jitter:
            # One shared jitter is applied to the full image in __call__, so the
            # per-crop pipelines omit it here.
            self.color_jittering = color_jittering
            self.global_transfo1 = transforms.Compose([resize_global, global_transfo1_extra, self.normalize])
            self.global_transfo2 = transforms.Compose([resize_global, global_transfo2_extra, self.normalize])
            self.local_transfo = transforms.Compose([local_transfo_extra, self.normalize])
        else:
            self.global_transfo1 = transforms.Compose(
                [resize_global, color_jittering, global_transfo1_extra, self.normalize]
            )
            self.global_transfo2 = transforms.Compose(
                [resize_global, color_jittering, global_transfo2_extra, self.normalize]
            )
            self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])

    def __call__(self, image):
        """Apply the augmentation pipeline to one image.

        Args:
            image: input PIL image.

        Returns:
            dict with keys "global_crops", "global_crops_teacher",
            "local_crops", "offsets", "weak_flag", and — when
            gram_teacher_crops_size is set — "gram_teacher_crops".
        """
        output = {}
        output["weak_flag"] = True  # some residual from mugs

        if self.share_color_jitter:
            image = self.color_jittering(image)

        # global crops:
        im1_base = self.geometric_augmentation_global(image)
        global_crop_1_transf = self.global_transfo1(im1_base)
        global_crop_1 = self.resize_global_post_transf(global_crop_1_transf)

        im2_base = self.geometric_augmentation_global(image)
        global_crop_2_transf = self.global_transfo2(im2_base)
        global_crop_2 = self.resize_global_post_transf(global_crop_2_transf)

        output["global_crops"] = [global_crop_1, global_crop_2]

        # global crops for teacher:
        if self.teacher_no_color_jitter:
            # NOTE(review): im1_base/im2_base are still at global_crop_max_size here,
            # which can differ from global_crops_size when gram teacher crops are
            # larger — confirm the teacher is meant to receive that size.
            output["global_crops_teacher"] = [
                self.normalize(im1_base),
                self.normalize(im2_base),
            ]
        else:
            output["global_crops_teacher"] = [global_crop_1, global_crop_2]

        if self.gram_teacher_crops_size is not None:
            # crops for gram teacher:
            if self.gram_teacher_no_distortions:
                # Undistorted path: resize the raw geometric crops, then normalize.
                gram_crop_1 = self.normalize(self.resize_gram_teacher(im1_base))
                gram_crop_2 = self.normalize(self.resize_gram_teacher(im2_base))
            else:
                # Distorted path: reuse the distorted (pre-post-resize) global crops.
                gram_crop_1 = self.resize_gram_teacher(global_crop_1_transf)
                gram_crop_2 = self.resize_gram_teacher(global_crop_2_transf)
            output["gram_teacher_crops"] = [gram_crop_1, gram_crop_2]

        # local crops:
        if self.local_crops_subset_of_global_crops:
            # Half of the local crops come from each global crop (integer halves,
            # so an odd local_crops_number yields one fewer crop).
            _local_crops = [self.local_transfo(im1_base) for _ in range(self.local_crops_number // 2)] + [
                self.local_transfo(im2_base) for _ in range(self.local_crops_number // 2)
            ]
            local_crops = []
            offsets = []
            gs = self.global_crops_size
            ls = self.local_crops_size
            for img in _local_crops:
                # Sample a patch-aligned top-left corner inside the global crop.
                # NOTE(review): requires gs > ls (np.random.randint raises on an empty
                # range), and the exclusive upper bound means the bottom/right-most
                # aligned offset (gs - ls) is never sampled — confirm this is intended.
                rx, ry = np.random.randint(0, (gs - ls) // self.patch_size, 2) * self.patch_size
                local_crops.append(img[:, rx : rx + ls, ry : ry + ls])
                offsets.append((rx, ry))
            output["local_crops"] = local_crops
            output["offsets"] = offsets
        else:
            local_crops = [
                self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
            ]
            output["local_crops"] = local_crops
            output["offsets"] = ()

        return output