File size: 3,967 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2025 STMicroelectronics.
#  * All rights reserved.
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/
from timm.data.transforms_factory import (
    transforms_imagenet_eval,
    transforms_imagenet_train,
)
from timm.data.transforms import ToNumpy
from torchvision import transforms

from image_classification.pt.src.datasets.augmentations.augs.cutout import Cutout


# Per-channel normalization statistics passed to transforms.Normalize / timm
# (one value per channel, three channels).
IMAGENET_DEFAULT_MEAN: tuple = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD: tuple = (0.229, 0.224, 0.225)
# Eval-time crop fraction: images are resized to int(img_size / crop_pct)
# before being center-cropped to img_size.
DEFAULT_CROP_PCT: float = 0.875


def get_imagenet_transforms(
    img_size,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    crop_pct=DEFAULT_CROP_PCT,
    scale=(0.08, 1.0),
    ratio=(3.0 / 4.0, 4.0 / 3.0),
    hflip=0.5,
    vflip=0.0,
    color_jitter=0.4,
    auto_augment=None,
    train_interpolation='random',
    test_interpolation='bilinear',
    re_prob=0.0,
    re_mode='pixel',
    re_count=1,
    re_num_splits=0,
    use_prefetcher=False,
):
    """Build the timm ImageNet-style training and evaluation pipelines.

    This is a thin wrapper around timm's ``transforms_imagenet_train`` and
    ``transforms_imagenet_eval``; every option is forwarded unchanged.

    Args:
        img_size: Target image side. When a tuple or list is given, only its
            last element is used (both pipelines receive a single int-like
            size).
        mean: Per-channel normalization mean, shared by both pipelines.
        std: Per-channel normalization std, shared by both pipelines.
        crop_pct: Eval-only crop fraction forwarded to timm.
        scale: Random-resized-crop area range for training.
        ratio: Random-resized-crop aspect-ratio range for training.
        hflip: Horizontal-flip probability for training.
        vflip: Vertical-flip probability for training.
        color_jitter: Color-jitter strength for training.
        auto_augment: AutoAugment / RandAugment / AugMix policy string, or
            ``None`` to disable. Accepted values are defined by timm — consult
            timm's documentation rather than this wrapper.
        train_interpolation: Interpolation used by the training pipeline
            (``'random'`` by default).
        test_interpolation: Interpolation used by the eval pipeline.
        re_prob: Random-erasing probability (training only).
        re_mode: Random-erasing fill mode (training only).
        re_count: Number of random-erasing regions (training only).
        re_num_splits: Random-erasing split count forwarded to timm.
        use_prefetcher: Forwarded to both timm factories.

    Returns:
        A ``(train_transforms, val_transforms)`` tuple, exactly as returned
        by the two timm factories.
    """
    side = img_size[-1] if isinstance(img_size, (tuple, list)) else img_size

    # Options that both timm factories receive with identical values.
    common = {
        'mean': mean,
        'std': std,
        'use_prefetcher': use_prefetcher,
    }

    train_pipeline = transforms_imagenet_train(
        side,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=train_interpolation,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        **common,
    )
    eval_pipeline = transforms_imagenet_eval(
        side,
        crop_pct=crop_pct,
        interpolation=test_interpolation,
        **common,
    )
    return train_pipeline, eval_pipeline


def _as_transform_list(extra):
    """Normalize an optional "extra transforms" argument into a list.

    ``None`` yields an empty list; a list or tuple is copied element-wise;
    any other object (e.g. a single callable transform or a ``Compose``) is
    wrapped in a one-element list.
    """
    if extra is None:
        return []
    if isinstance(extra, (list, tuple)):
        return list(extra)
    return [extra]


def get_vanilla_transforms(
    img_size,
    hflip=0.5,
    jitter=0.4,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    crop_pct=DEFAULT_CROP_PCT,
    add_train_transforms=None,
    add_test_transforms=None,
    cutout_args=None,
    use_prefetcher=False,
):
    """Build simple torchvision training and test pipelines.

    Training: RandomResizedCrop(img_size) -> RandomHorizontalFlip(hflip) ->
    ColorJitter(brightness/contrast/saturation=jitter, hue=0) -> optional
    extra transforms -> ToTensor (or timm ``ToNumpy`` when
    ``use_prefetcher``) -> optional Cutout -> Normalize (skipped when
    ``use_prefetcher``).

    Test: Resize(int(img_size / crop_pct)) -> CenterCrop(img_size) ->
    optional extra transforms -> ToTensor + Normalize (or ``ToNumpy`` when
    ``use_prefetcher``).

    Args:
        img_size: Target side. When a tuple or list is given, only its last
            element is used.
        hflip: Horizontal-flip probability for training.
        jitter: Brightness, contrast and saturation jitter strength.
        mean: Per-channel normalization mean.
        std: Per-channel normalization std.
        crop_pct: Test-time crop fraction; the resize target is
            ``int(img_size / crop_pct)``.
        add_train_transforms: Extra training transform(s), inserted after
            color jitter. May be a single callable or a list/tuple of
            callables; ``None`` adds nothing.
        add_test_transforms: Extra test transform(s), inserted after the
            center crop. Same accepted forms as ``add_train_transforms``.
        cutout_args: Keyword arguments for ``Cutout``; ``None`` disables it.
        use_prefetcher: When true, pipelines end with timm's ``ToNumpy`` and
            skip ``Normalize`` (presumably normalization happens in the
            prefetcher — confirm with the caller).

    Returns:
        A ``(train_transform, test_transform)`` tuple of
        ``transforms.Compose`` objects.
    """
    if isinstance(img_size, (tuple, list)):
        img_size = img_size[-1]

    train_transforms = [
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(hflip),
        transforms.ColorJitter(
            brightness=jitter, contrast=jitter, saturation=jitter, hue=0
        ),
    ]
    # Splice sequences element-wise: appending a list/tuple as a single item
    # would put a non-callable into Compose and fail only when the pipeline
    # is first applied.
    train_transforms.extend(_as_transform_list(add_train_transforms))

    if use_prefetcher:
        train_transforms.append(ToNumpy())
    else:
        train_transforms.append(transforms.ToTensor())

    if cutout_args is not None:
        # NOTE(review): with use_prefetcher=True, Cutout receives the ToNumpy
        # output instead of a tensor — confirm Cutout supports that input.
        train_transforms.append(Cutout(**cutout_args))

    if not use_prefetcher:
        train_transforms.append(transforms.Normalize(mean, std))

    test_transforms = [
        transforms.Resize(int(img_size / crop_pct)),
        transforms.CenterCrop(img_size),
    ]
    test_transforms.extend(_as_transform_list(add_test_transforms))

    if use_prefetcher:
        test_transforms.append(ToNumpy())
    else:
        test_transforms += [
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]

    return transforms.Compose(train_transforms), transforms.Compose(test_transforms)