# Prior2DSM: src/dinov3/data/transforms.py
# Uploaded by osherr ("Upload 222 files", revision bc90483).
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import logging
import math
from typing import Sequence
import PIL
import torch
from torchvision import transforms
# Module-level logger shared by the transform builders below ("dinov3" hierarchy).
logger = logging.getLogger(name="dinov3")
def make_interpolation_mode(mode_str: str) -> transforms.InterpolationMode:
    """
    Look up a ``transforms.InterpolationMode`` member by its string value (e.g. ``"bicubic"``).

    Raises:
        KeyError: if ``mode_str`` is not the value of any ``InterpolationMode`` member.
    """
    modes_by_value = dict((mode.value, mode) for mode in transforms.InterpolationMode)
    return modes_by_value[mode_str]
class GaussianBlur(transforms.RandomApply):
    """
    Randomly apply a Gaussian blur (kernel size 9, sigma drawn from
    ``[radius_min, radius_max]``) to an image.

    WARNING: the effective probability of blurring is ``1 - p``, not ``p``.
    ``transforms.RandomApply`` applies its transforms with probability equal to its own
    ``p`` argument, and this class passes ``1 - p`` to it. As a result:

    * ``p=1.0`` never blurs;
    * ``p=0.0`` always blurs;
    * ``p=0.1`` blurs about 90% of the time;
    * the default ``p=0.5`` is unaffected.

    The behavior is intentionally left unchanged because existing augmentation configs may
    be tuned against it. NOTE(review): confirm intent before changing the semantics of ``p``.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        blur = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        # RandomApply *applies* its transforms with probability equal to its `p`, so passing
        # `1 - p` inverts the meaning of this class's `p` (see class docstring).
        apply_prob = 1 - p
        super().__init__(transforms=[blur], p=apply_prob)
class MaybeToTensor(transforms.ToTensor):
    """
    ``transforms.ToTensor`` that passes ``torch.Tensor`` inputs through untouched.

    PIL images and ``numpy.ndarray`` inputs are converted by the parent ``ToTensor``.
    Tensors are returned as the very same object, with no copy and no rescaling.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): Image to convert.

        Returns:
            Tensor: ``pic`` itself if it already is a tensor, otherwise ``ToTensor()(pic)``.
        """
        already_tensor = isinstance(pic, torch.Tensor)
        return pic if already_tensor else super().__call__(pic)
# Per-channel normalization statistics; names follow timm's conventions.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Default eval geometry: crop to 224 after resizing to 256, i.e. the classic 256/224 ratio
# applied to CROP_DEFAULT_SIZE.
CROP_DEFAULT_SIZE = 224
RESIZE_DEFAULT_SIZE = int(256 * CROP_DEFAULT_SIZE / 224)
def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """
    Build a per-channel ``transforms.Normalize`` step.

    Args:
        mean: Per-channel means (defaults to ``IMAGENET_DEFAULT_MEAN``).
        std: Per-channel standard deviations (defaults to ``IMAGENET_DEFAULT_STD``).
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize
def make_base_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the final tensor-conversion + normalization stage appended by the train/eval builders.

    The returned ``transforms.Compose`` runs, in order:
      1. ``MaybeToTensor()``, which converts PIL/numpy input to a tensor and passes tensors through;
      2. ``transforms.Normalize(mean=mean, std=std)``.

    Args:
        mean: Per-channel means used for normalization.
        std: Per-channel standard deviations used for normalization.

    Returns:
        A ``transforms.Compose``. The return annotation previously claimed
        ``transforms.Normalize``, which never matched the actual return value.
    """
    return transforms.Compose(
        [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )
# This roughly matches torchvision's preset for classification training:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
def make_classification_train_transform(
    *,
    crop_size: int = CROP_DEFAULT_SIZE,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
    """
    Build the classification training pipeline.

    The pipeline runs these steps in order:
      1. ``RandomResizedCrop(crop_size)``;
      2. ``RandomHorizontalFlip(hflip_prob)``, omitted entirely unless ``hflip_prob > 0``;
      3. tensor conversion + normalization with ``mean``/``std``.

    The assembled transform is logged at INFO level before being returned.
    """
    crop = transforms.RandomResizedCrop(crop_size, interpolation=interpolation)
    flip = [transforms.RandomHorizontalFlip(hflip_prob)] if hflip_prob > 0.0 else []
    pipeline = transforms.Compose([crop, *flip, make_base_transform(mean, std)])
    logger.info(f"Built classification train transform\n{pipeline}")
    return pipeline
class _MaxSizeResize(object):
    """
    Shrink a PIL image so that its longer side is at most ``max_size``, keeping aspect ratio.

    Backed by ``PIL.Image.Image.thumbnail``, which has two notable properties:
      * it modifies the image **in place**; the same object is also returned;
      * it never enlarges, so images whose sides are both <= ``max_size`` come back unchanged.
    """

    def __init__(
        self,
        max_size: int,
        interpolation: transforms.InterpolationMode,
    ):
        self._size = self._make_size(max_size)
        self._resampling = self._make_resampling(interpolation)

    def _make_size(self, max_size: int):
        # thumbnail() fits the image inside a bounding box; a square box bounds the longer side.
        return (max_size, max_size)

    def _make_resampling(self, interpolation: transforms.InterpolationMode):
        """
        Map a torchvision interpolation mode to the equivalent PIL resampling filter.

        Raises:
            ValueError: if ``interpolation`` has no PIL equivalent (e.g. ``NEAREST_EXACT``).
        """
        resampling = PIL.Image.Resampling
        supported = {
            transforms.InterpolationMode.NEAREST: resampling.NEAREST,
            transforms.InterpolationMode.BILINEAR: resampling.BILINEAR,
            transforms.InterpolationMode.BICUBIC: resampling.BICUBIC,
            transforms.InterpolationMode.BOX: resampling.BOX,
            transforms.InterpolationMode.HAMMING: resampling.HAMMING,
            transforms.InterpolationMode.LANCZOS: resampling.LANCZOS,
        }
        # Explicit raise instead of `assert`: under `python -O` the assert was stripped and
        # every unsupported mode silently fell through to NEAREST.
        if interpolation not in supported:
            raise ValueError(f"Unsupported interpolation mode for PIL resizing: {interpolation!r}")
        return supported[interpolation]

    def __call__(self, image):
        image.thumbnail(size=self._size, resample=self._resampling)
        return image

    def __repr__(self):
        # Makes the logged Compose pipeline readable instead of showing an object address.
        return f"{type(self).__name__}(max_size={self._size[0]}, resample={self._resampling!r})"
def make_resize_transform(
    *,
    resize_size: int,
    resize_square: bool = False,
    resize_large_side: bool = False,  # Set the larger side to resize_size instead of the smaller
    interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC,
):
    """
    Build the resize step used by the eval pipelines.

    The modes are mutually exclusive:
      * ``resize_square``: resize to exactly ``(resize_size, resize_size)``; aspect ratio is not kept.
      * ``resize_large_side``: bound the longer side by ``resize_size`` using ``_MaxSizeResize``.
        That helper calls PIL ``Image.thumbnail``, so it expects PIL images.
      * neither: ``transforms.Resize(resize_size)``, which matches the shorter side to ``resize_size``.

    Raises:
        ValueError: if both ``resize_square`` and ``resize_large_side`` are set.
    """
    # Explicit check rather than `assert`, so it still fires under `python -O`.
    if resize_square and resize_large_side:
        raise ValueError("These two options can not be set together")
    if resize_square:
        logger.info("resizing image as a square")
        return transforms.Resize(size=(resize_size, resize_size), interpolation=interpolation)
    if resize_large_side:
        logger.info("resizing based on large side")
        return _MaxSizeResize(max_size=resize_size, interpolation=interpolation)
    return transforms.Resize(resize_size, interpolation=interpolation)
# Derived from make_classification_eval_transform() with more control over resize and crop
def make_eval_transform(
    *,
    resize_size: int = RESIZE_DEFAULT_SIZE,
    crop_size: int = CROP_DEFAULT_SIZE,
    resize_square: bool = False,
    resize_large_side: bool = False,  # Set the larger side to resize_size instead of the smaller
    interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build an evaluation pipeline: resize -> optional center crop -> tensor + normalize.

    Resize behavior is chosen by ``make_resize_transform`` from ``resize_size``,
    ``resize_square``, ``resize_large_side`` and ``interpolation``. A falsy ``crop_size``
    (e.g. ``0`` or ``None``) skips the center crop. The assembled transform is logged at
    INFO level before being returned.
    """
    resize = make_resize_transform(
        resize_size=resize_size,
        resize_square=resize_square,
        resize_large_side=resize_large_side,
        interpolation=interpolation,
    )
    crop = [transforms.CenterCrop(crop_size)] if crop_size else []
    pipeline = transforms.Compose([resize, *crop, make_base_transform(mean, std)])
    logger.info(f"Built eval transform\n{pipeline}")
    return pipeline
# This matches (roughly) torchvision's preset for classification evaluation:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
def make_classification_eval_transform(
    *,
    resize_size: int = RESIZE_DEFAULT_SIZE,
    crop_size: int = CROP_DEFAULT_SIZE,
    interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the standard classification eval preset.

    The steps are: resize the shorter side to ``resize_size``, center-crop to ``crop_size``,
    then convert to a tensor and normalize. This is a thin wrapper over ``make_eval_transform``
    with both alternative resize modes (square / large-side) disabled.
    """
    forwarded = dict(
        resize_size=resize_size,
        crop_size=crop_size,
        interpolation=interpolation,
        mean=mean,
        std=std,
    )
    return make_eval_transform(resize_square=False, resize_large_side=False, **forwarded)
class MultipleResize(object):
    """
    Resize an image so that both its height and its width are multiples of ``multiple``.

    Each side is rounded **up** to the nearest multiple independently, so the aspect ratio
    may change slightly. ``multiple == 1`` is a no-op that returns the input unchanged.

    Accepted inputs:
      * an object with a ``shape`` attribute whose last two entries are (height, width),
        e.g. a ``torch.Tensor`` laid out as ``(..., H, W)``;
      * a PIL image.
    """

    def __init__(self, interpolation=transforms.InterpolationMode.BILINEAR, multiple=1):
        self.multiple = multiple
        self.interpolation = interpolation

    def __call__(self, img):
        if self.multiple == 1:
            return img
        if hasattr(img, "shape"):
            # assumes a (..., H, W) layout such as a CHW tensor — TODO confirm for other inputs
            h, w = img.shape[-2:]
        else:
            # Explicit raise instead of `assert`, so the check survives `python -O`.
            if not isinstance(img, PIL.Image.Image):
                raise TypeError(f"img should have a `shape` attribute or be a PIL Image, got {type(img)}")
            w, h = img.size
        new_h, new_w = (math.ceil(s / self.multiple) * self.multiple for s in (h, w))
        # Bug fix: the configured interpolation used to be ignored, so resize always ran
        # with its own default (BILINEAR) regardless of what was passed to __init__.
        return transforms.functional.resize(img, (new_h, new_w), interpolation=self.interpolation)
def voc2007_classification_target_transform(label, n_categories=20):
    """
    Turn a VOC2007 label with object instances into a multi-hot class vector.

    Args:
        label: object with an ``instances`` iterable. Each instance has a ``category_id``
            that is used as an index (presumably an int in ``[0, n_categories)`` — TODO confirm).
        n_categories: length of the output vector.

    Returns:
        ``torch.int64`` tensor of shape ``(n_categories,)``. Each category present is 1
        (repeated categories still give 1); all other entries are 0.
    """
    multi_hot = torch.zeros(n_categories, dtype=int)
    category_ids = (instance.category_id for instance in label.instances)
    for category_id in category_ids:
        multi_hot[category_id] = True
    return multi_hot
def imaterialist_classification_target_transform(label, n_categories=294):
    """
    Turn an iMaterialist label into a multi-hot attribute vector.

    Args:
        label: object whose ``attributes`` is used directly as a tensor index
            (presumably a list of integer attribute ids — TODO confirm).
        n_categories: length of the output vector.

    Returns:
        ``torch.int64`` tensor of shape ``(n_categories,)``, with 1 at every index in
        ``label.attributes`` and 0 elsewhere.
    """
    multi_hot = torch.zeros(n_categories, dtype=int)
    attribute_ids = label.attributes
    multi_hot[attribute_ids] = True
    return multi_hot
def get_target_transform(dataset_str):
    """
    Pick the label transform for a dataset from substrings of its description string.

    ``"VOC2007"`` is checked before ``"IMaterialist"``, and the first match wins.
    Returns ``None`` when neither substring is present.
    """
    dispatch = (
        ("VOC2007", voc2007_classification_target_transform),
        ("IMaterialist", imaterialist_classification_target_transform),
    )
    for marker, target_transform in dispatch:
        if marker in dataset_str:
            return target_transform
    return None