# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import logging
import math
from typing import Sequence
import PIL
import torch
from torchvision import transforms
logger = logging.getLogger("dinov3")
def make_interpolation_mode(mode_str: str) -> transforms.InterpolationMode:
    """Return the torchvision ``InterpolationMode`` whose string value is ``mode_str``.

    Raises:
        KeyError: if ``mode_str`` does not name a known interpolation mode.
    """
    for mode in transforms.InterpolationMode:
        if mode.value == mode_str:
            return mode
    # Mirror the dict-lookup failure mode: unknown strings raise KeyError.
    raise KeyError(mode_str)
class GaussianBlur(transforms.RandomApply):
    """
    Apply Gaussian Blur to the PIL image with probability ``p``.

    Args:
        p: probability of applying the blur (keyword-only).
        radius_min: lower bound of the sigma range passed to the blur kernel.
        radius_max: upper bound of the sigma range passed to the blur kernel.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        # torchvision's RandomApply applies its transform list with probability
        # `p` (and returns the original image with probability 1 - p), so `p`
        # can be forwarded as-is. The previous `keep_p = 1 - p` inversion made
        # the blur fire with probability 1 - p, the opposite of what callers
        # requested.
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        super().__init__(transforms=[transform], p=p)
class MaybeToTensor(transforms.ToTensor):
    """
    Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        # Tensors pass through untouched; anything else takes the ToTensor path.
        if not isinstance(pic, torch.Tensor):
            return super().__call__(pic)
        return pic
# Use timm's names
# Per-channel (RGB) ImageNet statistics used for input normalization.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# Default evaluation geometry: resize so the crop keeps the usual 256/224 ratio.
CROP_DEFAULT_SIZE = 224
RESIZE_DEFAULT_SIZE = int(256 * CROP_DEFAULT_SIZE / 224)
def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """Build a per-channel normalization transform (defaults to ImageNet statistics)."""
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize
def make_base_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """Compose the shared tail of every pipeline: tensor conversion then normalization.

    Args:
        mean: per-channel normalization mean (defaults to ImageNet statistics).
        std: per-channel normalization std (defaults to ImageNet statistics).

    Returns:
        A ``transforms.Compose`` of ``MaybeToTensor`` and ``transforms.Normalize``.
        (The previous annotation said ``transforms.Normalize``, which was wrong —
        the function returns a ``Compose``.)
    """
    return transforms.Compose(
        [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )
# This roughly matches torchvision's preset for classification training:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
def make_classification_train_transform(
    *,
    crop_size: int = CROP_DEFAULT_SIZE,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
    """Build the training pipeline: random resized crop, optional horizontal flip,
    then tensor conversion and normalization."""
    pipeline = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    # A non-positive flip probability disables the flip stage entirely.
    if hflip_prob > 0.0:
        pipeline.append(transforms.RandomHorizontalFlip(hflip_prob))
    pipeline.append(make_base_transform(mean, std))
    transform = transforms.Compose(pipeline)
    logger.info(f"Built classification train transform\n{transform}")
    return transform
class _MaxSizeResize(object):
    """Resize a PIL image so neither side exceeds ``max_size``.

    Uses ``PIL.Image.thumbnail``, which preserves the aspect ratio and
    modifies the image in place.
    """

    def __init__(
        self,
        max_size: int,
        interpolation: transforms.InterpolationMode,
    ):
        self._size = self._make_size(max_size)
        self._resampling = self._make_resampling(interpolation)

    def _make_size(self, max_size: int):
        # thumbnail() takes a (width, height) bounding box.
        return (max_size, max_size)

    def _make_resampling(self, interpolation: transforms.InterpolationMode):
        # Translate the torchvision mode into the matching PIL resampling filter.
        if interpolation == transforms.InterpolationMode.BICUBIC:
            resampling = PIL.Image.Resampling.BICUBIC
        elif interpolation == transforms.InterpolationMode.BILINEAR:
            resampling = PIL.Image.Resampling.BILINEAR
        else:
            assert interpolation == transforms.InterpolationMode.NEAREST
            resampling = PIL.Image.Resampling.NEAREST
        return resampling

    def __call__(self, image):
        image.thumbnail(size=self._size, resample=self._resampling)
        return image
def make_resize_transform(
    *,
    resize_size: int,
    resize_square: bool = False,
    resize_large_side: bool = False,  # Set the larger side to resize_size instead of the smaller
    interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC,
):
    """Build a resize transform.

    Exactly one resizing strategy applies:
    - ``resize_square``: force a ``resize_size`` x ``resize_size`` output;
    - ``resize_large_side``: cap the *larger* side at ``resize_size`` (aspect kept);
    - default: set the *smaller* side to ``resize_size`` (aspect kept).
    """
    assert not (resize_square and resize_large_side), "These two options can not be set together"
    if resize_square:
        logger.info("resizing image as a square")
        return transforms.Resize(size=(resize_size, resize_size), interpolation=interpolation)
    if resize_large_side:
        logger.info("resizing based on large side")
        return _MaxSizeResize(max_size=resize_size, interpolation=interpolation)
    return transforms.Resize(resize_size, interpolation=interpolation)
# Derived from make_classification_eval_transform() with more control over resize and crop
def make_eval_transform(
    *,
    resize_size: int = RESIZE_DEFAULT_SIZE,
    crop_size: int = CROP_DEFAULT_SIZE,
    resize_square: bool = False,
    resize_large_side: bool = False,  # Set the larger side to resize_size instead of the smaller
    interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """Build an evaluation pipeline: resize, optional center crop, tensor + normalize.

    Offers finer control over the resize strategy than
    ``make_classification_eval_transform``.
    """
    steps = [
        make_resize_transform(
            resize_size=resize_size,
            resize_square=resize_square,
            resize_large_side=resize_large_side,
            interpolation=interpolation,
        )
    ]
    # A falsy crop_size (0 / None) skips the center crop entirely.
    if crop_size:
        steps.append(transforms.CenterCrop(crop_size))
    steps.append(make_base_transform(mean, std))
    transform = transforms.Compose(steps)
    logger.info(f"Built eval transform\n{transform}")
    return transform
# This matches (roughly) torchvision's preset for classification evaluation:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
def make_classification_eval_transform(
    *,
    resize_size: int = RESIZE_DEFAULT_SIZE,
    crop_size: int = CROP_DEFAULT_SIZE,
    interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """Standard classification evaluation pipeline: resize the smaller side,
    center crop, then tensor conversion and normalization."""
    # Delegate to the general builder with both special resize modes disabled.
    return make_eval_transform(
        resize_size=resize_size,
        resize_square=False,
        resize_large_side=False,
        crop_size=crop_size,
        interpolation=interpolation,
        mean=mean,
        std=std,
    )
class MultipleResize(object):
    """Resize so that both sides become a multiple of a given number.

    Each side is rounded *up* to the next multiple, which may change the
    aspect ratio slightly.

    Args:
        interpolation: interpolation mode used for the resize.
        multiple: the divisor each output side must satisfy; 1 is a no-op.
    """

    def __init__(self, interpolation=transforms.InterpolationMode.BILINEAR, multiple=1):
        self.multiple = multiple
        self.interpolation = interpolation

    def __call__(self, img):
        if self.multiple == 1:
            return img
        # Tensors expose (..., H, W); PIL images expose .size as (W, H).
        if hasattr(img, "shape"):
            h, w = img.shape[-2:]
        else:
            assert isinstance(
                img, PIL.Image.Image
            ), f"img should have a `shape` attribute or be a PIL Image, got {type(img)}"
            w, h = img.size
        new_h, new_w = [math.ceil(s / self.multiple) * self.multiple for s in (h, w)]
        # Fix: previously `self.interpolation` was stored but never forwarded,
        # so a non-default interpolation mode was silently ignored (the
        # functional default, BILINEAR, happened to match the class default).
        resized_image = transforms.functional.resize(img, (new_h, new_w), interpolation=self.interpolation)
        return resized_image
def voc2007_classification_target_transform(label, n_categories=20):
    """Convert a VOC2007 detection-style label into a multi-hot class vector.

    Args:
        label: object with an ``instances`` iterable; each instance carries a
            ``category_id`` expected in ``[0, n_categories)``.
        n_categories: output vector length (20 VOC classes by default).

    Returns:
        1-D integer tensor with 1 at every category present in the label.
    """
    present = [instance.category_id for instance in label.instances]
    multi_hot = torch.zeros(n_categories, dtype=int)
    # Advanced indexing flips all listed categories on at once.
    multi_hot[present] = True
    return multi_hot
def imaterialist_classification_target_transform(label, n_categories=294):
    """Convert an iMaterialist label into a multi-hot attribute vector.

    Args:
        label: object with an ``attributes`` index collection; indices are
            expected in ``[0, n_categories)``.
        n_categories: output vector length (294 attributes by default).

    Returns:
        1-D integer tensor with 1 at every attribute listed in the label.
    """
    multi_hot = torch.zeros(n_categories, dtype=int)
    # Set every listed attribute in a single advanced-indexing assignment.
    multi_hot[label.attributes] = True
    return multi_hot
def get_target_transform(dataset_str):
    """Pick the target transform matching a dataset spec string.

    Args:
        dataset_str: dataset specification; matched by substring.

    Returns:
        The dataset-specific target-transform callable, or ``None`` when the
        dataset needs no target transform.
    """
    # Guard-clause dispatch on substring markers.
    if "VOC2007" in dataset_str:
        return voc2007_classification_target_transform
    if "IMaterialist" in dataset_str:
        return imaterialist_classification_target_transform
    return None
|