| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
|
|
| from PIL import Image |
| from torchvision import transforms |
|
|
| from .transforms import ( |
| GaussianBlur, |
| MaybeToTensor, |
| make_normalize_transform, |
| ) |
|
|
|
|
| logger = logging.getLogger("dinov2") |
|
|
|
|
class DataAugmentationDINO:
    """Multi-crop augmentation that turns one image into global and local views.

    Calling an instance returns a dict holding two "global" crops, a second
    list with the same two crops under ``"global_crops_teacher"``,
    ``local_crops_number`` "local" crops, and an empty ``"offsets"`` tuple.

    Args:
        global_crops_scale: ``scale`` argument forwarded to the
            ``RandomResizedCrop`` that produces the global crops.
        local_crops_scale: ``scale`` argument forwarded to the
            ``RandomResizedCrop`` that produces the local crops.
        local_crops_number: Number of local crops generated per call.
        global_crops_size: Output size of each global crop.
        local_crops_size: Output size of each local crop.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        """Store the crop parameters, log them, and build every pipeline."""
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        banner = "###################################"
        logger.info(banner)
        logger.info("Using data augmentation parameters:")
        logger.info(f"global_crops_scale: {global_crops_scale}")
        logger.info(f"local_crops_scale: {local_crops_scale}")
        logger.info(f"local_crops_number: {local_crops_number}")
        logger.info(f"global_crops_size: {global_crops_size}")
        logger.info(f"local_crops_size: {local_crops_size}")
        logger.info(banner)

        def crop_then_flip(size, scale):
            # Geometric stage: bicubic random-resized crop, then a coin-flip mirror.
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        size,
                        scale=scale,
                        interpolation=transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.RandomHorizontalFlip(p=0.5),
                ]
            )

        self.geometric_augmentation_global = crop_then_flip(global_crops_size, global_crops_scale)
        self.geometric_augmentation_local = crop_then_flip(local_crops_size, local_crops_scale)

        # Colour distortion: jitter applied 80% of the time, then grayscale 20% of the time.
        jitter = transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
            p=0.8,
        )
        color_jittering = transforms.Compose([jitter, transforms.RandomGrayscale(p=0.2)])

        # Per-view blur settings; the second global view uses a much lower ``p``.
        # NOTE(review): ``p`` is presumably the probability of applying the blur;
        # confirm against GaussianBlur in .transforms.
        blur_global_1 = GaussianBlur(p=0.5)
        blur_global_2 = transforms.Compose([GaussianBlur(p=0.1)])
        blur_local = GaussianBlur(p=0.5)

        # Final stage shared by every view.
        self.normalize = transforms.Compose([MaybeToTensor(), make_normalize_transform()])

        def colour_blur_normalize(blur):
            # All three pipelines reuse the same colour-jitter and normalize objects.
            return transforms.Compose([color_jittering, blur, self.normalize])

        self.global_transfo1 = colour_blur_normalize(blur_global_1)
        self.global_transfo2 = colour_blur_normalize(blur_global_2)
        self.local_transfo = colour_blur_normalize(blur_local)

    def __call__(self, image):
        """Return the dict of augmented views generated from ``image``.

        Keys: ``"global_crops"`` and ``"global_crops_teacher"`` (two separate
        lists holding the same two crop objects), ``"local_crops"`` (a list of
        ``self.local_crops_number`` crops) and ``"offsets"`` (an empty tuple).
        """
        # Each global view gets its own geometric sample before its own
        # colour/blur pipeline, in the order view 1 then view 2.
        global_views = []
        for pipeline in (self.global_transfo1, self.global_transfo2):
            cropped = self.geometric_augmentation_global(image)
            global_views.append(pipeline(cropped))

        local_views = []
        for _ in range(self.local_crops_number):
            cropped = self.geometric_augmentation_local(image)
            local_views.append(self.local_transfo(cropped))

        return {
            "global_crops": global_views,
            # Distinct list object, same crop elements.
            "global_crops_teacher": list(global_views),
            "local_crops": local_views,
            "offsets": (),
        }
|
|
|
|
def get_online_classification_augmentation_from_config(cfg) -> transforms.Compose:
    """Build the online-classification augmentation pipeline described by ``cfg``.

    Reads ``cfg.evaluation.online.augmentation`` (``interpolation``, ``degrees``,
    ``scale``, ``shear``, ``horizontal_flip``) and ``cfg.crops.global_crops_size``.

    The pipeline is: resize, center crop (both to ``global_crops_size``),
    random affine, ``MaybeToTensor``, the normalize transform, and, only when
    ``horizontal_flip`` is truthy, a trailing random horizontal flip.

    Args:
        cfg: Configuration object exposing the attributes listed above.

    Returns:
        A ``transforms.Compose`` wrapping the assembled transform list.
    """
    aug_cfg = cfg.evaluation.online.augmentation
    # ``interpolation`` is the name of a PIL ``Image.Resampling`` member.
    resample = getattr(Image.Resampling, aug_cfg.interpolation)
    # The same edge length drives both the resize and the center crop.
    side = cfg.crops.global_crops_size

    pipeline = [
        transforms.Resize(side, interpolation=resample),
        transforms.CenterCrop(side),
        transforms.RandomAffine(
            degrees=aug_cfg.degrees,
            scale=aug_cfg.scale,
            shear=aug_cfg.shear,
            interpolation=resample,
        ),
        MaybeToTensor(),
        make_normalize_transform(),
    ]
    if aug_cfg.horizontal_flip:
        # The flip is appended last, i.e. after conversion and normalization.
        pipeline.append(transforms.RandomHorizontalFlip())
    return transforms.Compose(pipeline)
|
|