import logging
import math
from typing import Sequence

import PIL.Image  # the ``PIL.Image`` submodule must be imported explicitly
import torch
from torchvision import transforms

logger = logging.getLogger("dinov3")


def make_interpolation_mode(mode_str: str) -> transforms.InterpolationMode:
    # Map a string such as "bicubic" to the matching InterpolationMode enum member.
    return {mode.value: mode for mode in transforms.InterpolationMode}[mode_str]
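
# Example (illustrative only): the lookup resolves torchvision's string values
# to the matching enum members, e.g.
#
#   mode = make_interpolation_mode("bicubic")
#   assert mode is transforms.InterpolationMode.BICUBIC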


class GaussianBlur(transforms.RandomApply):
    """
    Apply Gaussian blur to the PIL image with probability ``p``.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        # ``transforms.RandomApply`` applies its transforms with probability ``p``,
        # so ``p`` is forwarded directly as the blur probability.
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        super().__init__(transforms=[transform], p=p)
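
# Example (illustrative only): blur a PIL image with probability 0.5, with the
# blur radius drawn uniformly at random from [0.1, 2.0].
#
#   blur = GaussianBlur(p=0.5, radius_min=0.1, radius_max=2.0)
#   maybe_blurred = blur(pil_image)  # ``pil_image`` is a hypothetical PIL image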


class MaybeToTensor(transforms.ToTensor):
    """
    Convert a ``PIL Image`` or ``numpy.ndarray`` to a tensor, or keep it as is if
    it is already a tensor.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): Image to be converted to a tensor.
        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, torch.Tensor):
            return pic
        return super().__call__(pic)


IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

CROP_DEFAULT_SIZE = 224
RESIZE_DEFAULT_SIZE = int(256 * CROP_DEFAULT_SIZE / 224)


def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    return transforms.Normalize(mean=mean, std=std)


def make_base_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    return transforms.Compose(
        [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )


def make_classification_train_transform(
    *,
    crop_size: int = CROP_DEFAULT_SIZE,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    if hflip_prob > 0.0:
        transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
    transforms_list.append(make_base_transform(mean, std))
    transform = transforms.Compose(transforms_list)
    logger.info(f"Built classification train transform\n{transform}")
    return transform
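
# Example (illustrative only): build the default train transform and apply it to
# an RGB PIL image, yielding a normalized float tensor.
#
#   from PIL import Image
#   transform = make_classification_train_transform()
#   image = Image.open("example.jpg").convert("RGB")  # hypothetical file
#   tensor = transform(image)  # torch.Size([3, 224, 224])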


class _MaxSizeResize:
    def __init__(
        self,
        max_size: int,
        interpolation: transforms.InterpolationMode,
    ):
        self._size = self._make_size(max_size)
        self._resampling = self._make_resampling(interpolation)

    def _make_size(self, max_size: int):
        return (max_size, max_size)

    def _make_resampling(self, interpolation: transforms.InterpolationMode):
        if interpolation == transforms.InterpolationMode.BICUBIC:
            return PIL.Image.Resampling.BICUBIC
        if interpolation == transforms.InterpolationMode.BILINEAR:
            return PIL.Image.Resampling.BILINEAR
        assert interpolation == transforms.InterpolationMode.NEAREST
        return PIL.Image.Resampling.NEAREST

    def __call__(self, image):
        # ``thumbnail`` resizes in place, preserving the aspect ratio, so that
        # neither side exceeds ``max_size``; it never upscales a smaller image.
        image.thumbnail(size=self._size, resample=self._resampling)
        return image
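
# Example (illustrative only): cap the larger side at 512 pixels while keeping
# the aspect ratio.
#
#   resize = _MaxSizeResize(max_size=512, interpolation=transforms.InterpolationMode.BICUBIC)
#   resized = resize(pil_image)  # a hypothetical 1024x768 image becomes 512x384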


def make_resize_transform(
    *,
    resize_size: int,
    resize_square: bool = False,
    resize_large_side: bool = False,
    interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC,
):
    assert not (resize_square and resize_large_side), "These two options cannot be set together"
    if resize_square:
        logger.info("resizing image as a square")
        size = (resize_size, resize_size)
        transform = transforms.Resize(size=size, interpolation=interpolation)
        return transform
    elif resize_large_side:
        logger.info("resizing based on the large side")
        transform = _MaxSizeResize(max_size=resize_size, interpolation=interpolation)
        return transform
    else:
        # Default torchvision behavior: resize the small side to ``resize_size``.
        transform = transforms.Resize(resize_size, interpolation=interpolation)
        return transform
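
# Example (illustrative only): the three mutually exclusive resize behaviors.
#
#   make_resize_transform(resize_size=256)                           # small side -> 256, keep aspect ratio
#   make_resize_transform(resize_size=256, resize_square=True)       # -> exactly 256x256
#   make_resize_transform(resize_size=256, resize_large_side=True)   # large side capped at 256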


def make_eval_transform(
    *,
    resize_size: int = RESIZE_DEFAULT_SIZE,
    crop_size: int = CROP_DEFAULT_SIZE,
    resize_square: bool = False,
    resize_large_side: bool = False,
    interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    transforms_list = []
    resize_transform = make_resize_transform(
        resize_size=resize_size,
        resize_square=resize_square,
        resize_large_side=resize_large_side,
        interpolation=interpolation,
    )
    transforms_list.append(resize_transform)
    if crop_size:
        transforms_list.append(transforms.CenterCrop(crop_size))
    transforms_list.append(make_base_transform(mean, std))
    transform = transforms.Compose(transforms_list)
    logger.info(f"Built eval transform\n{transform}")
    return transform
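
# Example (illustrative only): the default eval pipeline resizes the small side
# to 256, center-crops to 224, then converts to a tensor and normalizes with
# ImageNet statistics.
#
#   transform = make_eval_transform()
#   tensor = transform(pil_image)  # hypothetical PIL image -> torch.Size([3, 224, 224])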


def make_classification_eval_transform(
    *,
    resize_size: int = RESIZE_DEFAULT_SIZE,
    crop_size: int = CROP_DEFAULT_SIZE,
    interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    return make_eval_transform(
        resize_size=resize_size,
        crop_size=crop_size,
        interpolation=interpolation,
        mean=mean,
        std=std,
        resize_square=False,
        resize_large_side=False,
    )


class MultipleResize:
    def __init__(self, interpolation=transforms.InterpolationMode.BILINEAR, multiple=1):
        self.multiple = multiple
        self.interpolation = interpolation

    def __call__(self, img):
        if self.multiple == 1:
            return img
        if hasattr(img, "shape"):
            h, w = img.shape[-2:]
        else:
            assert isinstance(
                img, PIL.Image.Image
            ), f"img should have a `shape` attribute or be a PIL Image, got {type(img)}"
            w, h = img.size
        # Round each side up to the nearest multiple of ``self.multiple``.
        new_h, new_w = [math.ceil(s / self.multiple) * self.multiple for s in (h, w)]
        resized_image = transforms.functional.resize(img, (new_h, new_w), interpolation=self.interpolation)
        return resized_image
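
# Example (illustrative only): round spatial dimensions up to a multiple of a
# ViT patch size so the image tiles evenly into patches.
#
#   resize = MultipleResize(multiple=16)
#   out = resize(torch.rand(3, 250, 333))  # resized to shape (3, 256, 336)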


def voc2007_classification_target_transform(label, n_categories=20):
    # Multi-label target: a multi-hot vector over the 20 VOC categories, with a 1
    # for every category that appears among the annotated instances.
    one_hot = torch.zeros(n_categories, dtype=torch.long)
    for instance in label.instances:
        one_hot[instance.category_id] = 1
    return one_hot
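
# Example (illustrative only, assuming a label object with an ``instances``
# attribute as in the dataset's annotations): a label containing instances of
# categories 3 and 7 maps to a 20-dim multi-hot vector with ones at indices 3 and 7.
#
#   target = voc2007_classification_target_transform(label)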


def imaterialist_classification_target_transform(label, n_categories=294):
    # Multi-label target: ``label.attributes`` is a sequence of attribute indices,
    # so all of them are set to 1 with a single index assignment.
    one_hot = torch.zeros(n_categories, dtype=torch.long)
    one_hot[label.attributes] = 1
    return one_hot


def get_target_transform(dataset_str):
    # Select a dataset-specific target transform from the dataset descriptor string.
    if "VOC2007" in dataset_str:
        return voc2007_classification_target_transform
    elif "IMaterialist" in dataset_str:
        return imaterialist_classification_target_transform
    return None