# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.

import logging
import math
from typing import Sequence

import PIL
import torch
from torchvision import transforms

logger = logging.getLogger("dinov3")


def make_interpolation_mode(mode_str: str) -> transforms.InterpolationMode:
    return {mode.value: mode for mode in transforms.InterpolationMode}[mode_str]


class GaussianBlur(transforms.RandomApply):
    """

    Apply Gaussian Blur to the PIL image.

    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        # NOTE: torchvision is applying 1 - probability to return the original image
        keep_p = 1 - p
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        super().__init__(transforms=[transform], p=keep_p)


class MaybeToTensor(transforms.ToTensor):
    """

    Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.

    """

    def __call__(self, pic):
        """

        Args:

            pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.

        Returns:

            Tensor: Converted image.

        """
        if isinstance(pic, torch.Tensor):
            return pic
        return super().__call__(pic)


# Use timm's names
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

CROP_DEFAULT_SIZE = 224
RESIZE_DEFAULT_SIZE = int(256 * CROP_DEFAULT_SIZE / 224)


def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    return transforms.Normalize(mean=mean, std=std)


def make_base_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
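    """Compose the minimal preprocessing pipeline: convert to tensor (if needed) and normalize."""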
    return transforms.Compose(
        [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )


# This roughly matches torchvision's preset for classification training:
#   https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
def make_classification_train_transform(
    *,
    crop_size: int = CROP_DEFAULT_SIZE,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
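    """Build the training transform: random resized crop, optional horizontal flip,
    then conversion to tensor and normalization with the given mean and std."""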
    transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    if hflip_prob > 0.0:
        transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
    transforms_list.append(make_base_transform(mean, std))
    transform = transforms.Compose(transforms_list)
    logger.info(f"Built classification train transform\n{transform}")
    return transform


class _MaxSizeResize(object):
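    """Resize so that the larger side fits within ``max_size``, preserving the aspect ratio.

    Uses ``PIL.Image.thumbnail``, which resizes in place and only ever shrinks the image.
    """
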
    def __init__(
        self,
        max_size: int,
        interpolation: transforms.InterpolationMode,
    ):
        self._size = self._make_size(max_size)
        self._resampling = self._make_resampling(interpolation)

    def _make_size(self, max_size: int):
        return (max_size, max_size)

    def _make_resampling(self, interpolation: transforms.InterpolationMode):
        if interpolation == transforms.InterpolationMode.BICUBIC:
            return PIL.Image.Resampling.BICUBIC
        if interpolation == transforms.InterpolationMode.BILINEAR:
            return PIL.Image.Resampling.BILINEAR
        assert interpolation == transforms.InterpolationMode.NEAREST
        return PIL.Image.Resampling.NEAREST

    def __call__(self, image):
        image.thumbnail(size=self._size, resample=self._resampling)
        return image


def make_resize_transform(
    *,
    resize_size: int,
    resize_square: bool = False,
    resize_large_side: bool = False,  # Set the larger side to resize_size instead of the smaller
    interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC,
):
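    """Build a resize transform in one of three modes: resize to a ``resize_size`` square,
    bound the larger side by ``resize_size``, or (default) set the smaller side to ``resize_size``."""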
    assert not (resize_square and resize_large_side), "These two options cannot be set together"
    if resize_square:
        logger.info("resizing image as a square")
        size = (resize_size, resize_size)
        transform = transforms.Resize(size=size, interpolation=interpolation)
        return transform
    elif resize_large_side:
        logger.info("resizing based on large side")
        transform = _MaxSizeResize(max_size=resize_size, interpolation=interpolation)
        return transform
    else:
        transform = transforms.Resize(resize_size, interpolation=interpolation)
        return transform


# Derived from make_classification_eval_transform() with more control over resize and crop
def make_eval_transform(
    *,
    resize_size: int = RESIZE_DEFAULT_SIZE,
    crop_size: int = CROP_DEFAULT_SIZE,
    resize_square: bool = False,
    resize_large_side: bool = False,  # Set the larger side to resize_size instead of the smaller
    interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
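    """Build the evaluation transform: resize, optional center crop,
    then conversion to tensor and normalization."""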
    transforms_list = []
    resize_transform = make_resize_transform(
        resize_size=resize_size,
        resize_square=resize_square,
        resize_large_side=resize_large_side,
        interpolation=interpolation,
    )
    transforms_list.append(resize_transform)
    if crop_size:
        transforms_list.append(transforms.CenterCrop(crop_size))
    transforms_list.append(make_base_transform(mean, std))
    transform = transforms.Compose(transforms_list)
    logger.info(f"Built eval transform\n{transform}")
    return transform


# This matches (roughly) torchvision's preset for classification evaluation:
#   https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
def make_classification_eval_transform(
    *,
    resize_size: int = RESIZE_DEFAULT_SIZE,
    crop_size: int = CROP_DEFAULT_SIZE,
    interpolation=transforms.InterpolationMode.BICUBIC,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
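    """Classification evaluation preset: resize the smaller side, center crop, then normalize."""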
    return make_eval_transform(
        resize_size=resize_size,
        crop_size=crop_size,
        interpolation=interpolation,
        mean=mean,
        std=std,
        resize_square=False,
        resize_large_side=False,
    )


class MultipleResize(object):
    """A resize transform that rounds both sides up to a multiple of a given number.

    This may slightly change the aspect ratio.
    """

    def __init__(self, interpolation=transforms.InterpolationMode.BILINEAR, multiple=1):
        self.multiple = multiple
        self.interpolation = interpolation

    def __call__(self, img):
        if self.multiple == 1:
            return img
        if hasattr(img, "shape"):
            h, w = img.shape[-2:]
        else:
            assert isinstance(
                img, PIL.Image.Image
            ), f"img should have a `shape` attribute or be a PIL Image, got {type(img)}"
            w, h = img.size
        new_h, new_w = [math.ceil(s / self.multiple) * self.multiple for s in (h, w)]
        resized_image = transforms.functional.resize(img, (new_h, new_w))
        return resized_image


def voc2007_classification_target_transform(label, n_categories=20):
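    """Convert a VOC2007 annotation into a multi-hot vector over its 20 object categories."""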
    one_hot = torch.zeros(n_categories, dtype=int)
    for instance in label.instances:
        one_hot[instance.category_id] = True
    return one_hot


def imaterialist_classification_target_transform(label, n_categories=294):
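    """Convert an iMaterialist annotation into a multi-hot vector over its 294 attribute ids."""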
    one_hot = torch.zeros(n_categories, dtype=int)
    one_hot[label.attributes] = True
    return one_hot


def get_target_transform(dataset_str):
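    """Return the target transform matching the dataset string, or None if no special handling is needed."""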
    if "VOC2007" in dataset_str:
        return voc2007_classification_target_transform
    elif "IMaterialist" in dataset_str:
        return imaterialist_classification_target_transform
    return None
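

if __name__ == "__main__":
    # Minimal usage sketch: build the default classification evaluation transform and
    # apply it to a synthetic PIL image; the result is a normalized float tensor whose
    # spatial size matches the default 224x224 center crop.
    from PIL import Image

    example_image = Image.new("RGB", (640, 480), color=(128, 128, 128))
    eval_transform = make_classification_eval_transform()
    output = eval_transform(example_image)
    print(tuple(output.shape))  # expected: (3, 224, 224)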