# Prior2DSM: src/dinov3/data/augmentations.py
# (uploaded by osherr in "Upload 222 files", commit bc90483)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import logging
import numpy as np
from torch import nn
from torchvision import transforms
from dinov3.data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, GaussianBlur, make_normalize_transform
# Module logger; uses the fixed name "dinov3" rather than __name__.
logger = logging.getLogger("dinov3")
class DataAugmentationDINO(object):
    """Multi-crop augmentation pipeline for DINO-style self-supervised training.

    Calling an instance on an image returns a dict with:

    - ``"global_crops"``: two student global crops of ``global_crops_size``.
    - ``"global_crops_teacher"``: two teacher global crops. With
      ``teacher_no_color_jitter`` they are built from the geometric base crops
      only (resized to ``global_crops_size`` and normalized, no per-crop color
      jitter / blur / solarization); otherwise they are the same tensors as
      ``"global_crops"``.
    - ``"gram_teacher_crops"``: two crops of ``gram_teacher_crops_size``
      (present only when ``gram_teacher_crops_size`` is not None).
    - ``"local_crops"``: ``local_crops_number`` crops of ``local_crops_size``.
    - ``"offsets"``: ``(row, col)`` pixel offsets of each local crop inside its
      global crop when ``local_crops_subset_of_global_crops`` is set, else ``()``.
    - ``"weak_flag"``: always True (residual field).

    Args:
        global_crops_scale: ``scale`` range for ``RandomResizedCrop`` of global crops.
        local_crops_scale: ``scale`` range for ``RandomResizedCrop`` of local crops.
        local_crops_number: number of local crops to produce.
        global_crops_size: side length of the student global crops.
        local_crops_size: side length of the local crops.
        gram_teacher_crops_size: side length of the gram-teacher crops; None disables them.
        gram_teacher_no_distortions: if True, gram-teacher crops skip color jitter / blur /
            solarization; otherwise they share the distortions of the global crops.
        teacher_no_color_jitter: if True, teacher global crops skip the per-crop distortions.
        local_crops_subset_of_global_crops: if True, local crops are patch-aligned windows
            cut out of the two global views instead of independent random crops.
        patch_size: alignment (in pixels) of the local-crop offsets in subset mode.
        share_color_jitter: if True, a single color jitter is applied to the input image
            before cropping instead of independently per crop.
        horizontal_flips: if False, random horizontal flips are disabled.
        mean: normalization mean.
        std: normalization std.

    Raises:
        ValueError: if ``local_crops_subset_of_global_crops`` is set and
            ``local_crops_size`` exceeds ``global_crops_size``.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
        gram_teacher_crops_size=None,
        gram_teacher_no_distortions=False,
        teacher_no_color_jitter=False,
        local_crops_subset_of_global_crops=False,
        patch_size=16,
        share_color_jitter=False,
        horizontal_flips=True,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
    ):
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size
        self.gram_teacher_crops_size = gram_teacher_crops_size
        self.gram_teacher_no_distortions = gram_teacher_no_distortions
        self.teacher_no_color_jitter = teacher_no_color_jitter
        self.local_crops_subset_of_global_crops = local_crops_subset_of_global_crops
        self.patch_size = patch_size
        self.share_color_jitter = share_color_jitter
        self.horizontal_flips = horizontal_flips
        self.mean = mean
        self.std = std

        logger.info("###################################")
        logger.info("Using data augmentation parameters:")
        logger.info(f"global_crops_scale: {global_crops_scale}")
        logger.info(f"local_crops_scale: {local_crops_scale}")
        logger.info(f"local_crops_number: {local_crops_number}")
        logger.info(f"global_crops_size: {global_crops_size}")
        logger.info(f"local_crops_size: {local_crops_size}")
        logger.info(f"gram_crops_size: {gram_teacher_crops_size}")
        logger.info(f"gram_teacher_no_distortions: {gram_teacher_no_distortions}")
        logger.info(f"teacher_no_color_jitter: {teacher_no_color_jitter}")
        logger.info(f"local_crops_subset_of_global_crops: {local_crops_subset_of_global_crops}")
        logger.info(f"patch_size if local_crops_subset_of_global_crops: {patch_size}")
        logger.info(f"share_color_jitter: {share_color_jitter}")
        logger.info(f"horizontal flips: {horizontal_flips}")
        logger.info("###################################")

        # Fail early instead of inside np.random.randint at the first __call__.
        if local_crops_subset_of_global_crops and local_crops_size > global_crops_size:
            raise ValueError(
                f"local_crops_size ({local_crops_size}) must not exceed global_crops_size "
                f"({global_crops_size}) when local_crops_subset_of_global_crops is set"
            )

        # Global crops and gram teacher crops can have different sizes. We first take a crop of the maximum size
        # and then resize it to the desired size for global and gram teacher crops.
        global_crop_max_size = max(global_crops_size, gram_teacher_crops_size if gram_teacher_crops_size else 0)

        # random resized crop and flip
        flip_p = 0.5 if horizontal_flips else 0.0
        self.geometric_augmentation_global = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    global_crop_max_size,
                    scale=global_crops_scale,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(p=flip_p),
            ]
        )

        resize_global = nn.Identity()  # Resize applied to global crops right after the random crop
        self.resize_global_post_transf = nn.Identity()  # Resize applied to global crops after all other transforms
        self.resize_gram_teacher = None  # Resize applied to crops for the gram teacher
        # Resize applied to an undistorted base crop that is used directly at global-crop resolution
        # (teacher crops without color jitter, source image of local crops in subset mode).
        self.resize_global_base = nn.Identity()
        if gram_teacher_crops_size is not None:
            # All resize transforms will do nothing if the crop size is already the desired size.
            resize_to_global = transforms.Resize(
                global_crops_size,
                interpolation=transforms.InterpolationMode.BICUBIC,
            )
            if gram_teacher_no_distortions:
                # When there are no distortions for the gram teacher crop, we can resize before the distortions.
                # This is the preferred order, because it keeps the image size for the augmentations consistent,
                # which matters e.g. for GaussianBlur.
                resize_global = resize_to_global
            else:
                # When there are distortions for the gram teacher crop, we need to resize after the distortions,
                # because the distortions are shared between global and gram teacher crops.
                self.resize_global_post_transf = resize_to_global
            # Base crops are taken at `global_crop_max_size`; bring them back to `global_crops_size`
            # whenever they are used without going through the global-crop pipelines above.
            self.resize_global_base = resize_to_global
            self.resize_gram_teacher = transforms.Resize(
                gram_teacher_crops_size,
                interpolation=transforms.InterpolationMode.BICUBIC,
            )

        self.geometric_augmentation_local = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    local_crops_size,
                    scale=local_crops_scale,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(p=flip_p),
            ]
        )

        # color distortions / blurring
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )
        global_transfo1_extra = GaussianBlur(p=1.0)
        global_transfo2_extra = transforms.Compose(
            [
                GaussianBlur(p=0.1),
                transforms.RandomSolarize(threshold=128, p=0.2),
            ]
        )
        local_transfo_extra = GaussianBlur(p=0.5)

        # normalization
        self.normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                make_normalize_transform(mean=mean, std=std),
            ]
        )

        if self.share_color_jitter:
            # Color jitter is applied once to the whole input image in __call__.
            self.color_jittering = color_jittering
            self.global_transfo1 = transforms.Compose([resize_global, global_transfo1_extra, self.normalize])
            self.global_transfo2 = transforms.Compose([resize_global, global_transfo2_extra, self.normalize])
            self.local_transfo = transforms.Compose([local_transfo_extra, self.normalize])
        else:
            self.global_transfo1 = transforms.Compose(
                [resize_global, color_jittering, global_transfo1_extra, self.normalize]
            )
            self.global_transfo2 = transforms.Compose(
                [resize_global, color_jittering, global_transfo2_extra, self.normalize]
            )
            self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])

    def __call__(self, image):
        """Apply the multi-crop augmentation to ``image`` and return the crops dict (see class docstring)."""
        output = {}
        output["weak_flag"] = True  # some residual from mugs

        if self.share_color_jitter:
            image = self.color_jittering(image)

        # global crops: one geometric base crop per view, taken at the max of global/gram sizes
        im1_base = self.geometric_augmentation_global(image)
        global_crop_1_transf = self.global_transfo1(im1_base)
        global_crop_1 = self.resize_global_post_transf(global_crop_1_transf)

        im2_base = self.geometric_augmentation_global(image)
        global_crop_2_transf = self.global_transfo2(im2_base)
        global_crop_2 = self.resize_global_post_transf(global_crop_2_transf)

        output["global_crops"] = [global_crop_1, global_crop_2]

        # global crops for teacher:
        if self.teacher_no_color_jitter:
            # Resize so the teacher crops have the same resolution as the student global crops,
            # even when the base crop was taken at a larger gram-teacher size.
            output["global_crops_teacher"] = [
                self.normalize(self.resize_global_base(im1_base)),
                self.normalize(self.resize_global_base(im2_base)),
            ]
        else:
            output["global_crops_teacher"] = [global_crop_1, global_crop_2]

        if self.gram_teacher_crops_size is not None:
            # crops for gram teacher:
            if self.gram_teacher_no_distortions:
                gram_crop_1 = self.normalize(self.resize_gram_teacher(im1_base))
                gram_crop_2 = self.normalize(self.resize_gram_teacher(im2_base))
            else:
                # Reuse the distorted (pre-post-resize) global crops so distortions are shared.
                gram_crop_1 = self.resize_gram_teacher(global_crop_1_transf)
                gram_crop_2 = self.resize_gram_teacher(global_crop_2_transf)
            output["gram_teacher_crops"] = [gram_crop_1, gram_crop_2]

        # local crops:
        if self.local_crops_subset_of_global_crops:
            local_crops, offsets = self._local_crops_subset_of_global(im1_base, im2_base)
            output["local_crops"] = local_crops
            output["offsets"] = offsets
        else:
            output["local_crops"] = [
                self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
            ]
            output["offsets"] = ()

        return output

    def _local_crops_subset_of_global(self, im1_base, im2_base):
        """Cut patch-aligned local crops out of the two global base crops.

        Returns ``(local_crops, offsets)`` where ``offsets[i]`` is the ``(row, col)`` pixel
        offset of ``local_crops[i]`` inside its ``global_crops_size`` global view.
        """
        gs = self.global_crops_size
        ls = self.local_crops_size

        # Split the requested crops between the two views; the first view takes the extra crop
        # when the number is odd so that exactly `local_crops_number` crops are returned.
        n_second = self.local_crops_number // 2
        n_first = self.local_crops_number - n_second

        # Bring the base crops to global-crop resolution so the offsets are expressed in the
        # coordinates of the global crops (base crops may be larger when gram crops are larger).
        base1 = self.resize_global_base(im1_base)
        base2 = self.resize_global_base(im2_base)
        full_views = [self.local_transfo(base1) for _ in range(n_first)] + [
            self.local_transfo(base2) for _ in range(n_second)
        ]

        # Number of patch-aligned start positions per axis. np.random.randint excludes its upper
        # bound, hence the +1: a crop may end flush with the bottom/right edge, and gs == ls is
        # valid (single start position 0).
        n_positions = (gs - ls) // self.patch_size + 1

        local_crops = []
        offsets = []
        for img in full_views:
            rx, ry = np.random.randint(0, n_positions, 2) * self.patch_size
            # img comes out of ToTensor, i.e. (C, H, W): rx indexes rows, ry columns.
            local_crops.append(img[:, rx : rx + ls, ry : ry + ls])
            offsets.append((rx, ry))
        return local_crops, offsets