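"""CIFAR-10 datasets and dataloaders with Albumentations augmentations."""
from __future__ import annotations  # lets `str | None` annotations load on Python < 3.10

import torch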
from torchvision import datasets
from albumentations import (
    Compose, HorizontalFlip, ShiftScaleRotate, CoarseDropout,
    Normalize, ColorJitter, PadIfNeeded, RandomCrop
)
from albumentations.pytorch import ToTensorV2
import numpy as np


# CIFAR-10 statistics (RGB)
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
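

# Illustrative sketch, not part of the original pipeline: the arithmetic that
# Normalize is expected to apply to one 8-bit RGB pixel, assuming the default
# max_pixel_value of 255. The helper name `_normalize_pixel_sketch` is
# hypothetical; it exists only to make the mean/std statistics concrete.
def _normalize_pixel_sketch(
    pixel_rgb,
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2470, 0.2435, 0.2616),
):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(pixel_rgb, mean, std))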


def _coarse_dropout_fill_value_from_mean(mean_rgb: tuple[float, float, float]) -> tuple[int, int, int]:
    """Convert mean RGB (0–1) to the 0–255 scale for the CoarseDropout fill color."""
    # round() rather than int(): int() truncates, e.g. 0.4822 * 255 = 122.96 would become 122.
    return tuple(round(m * 255.0) for m in mean_rgb)


class AlbumentationsAdapter:
    """Adapter to make Albumentations transforms compatible with torchvision datasets."""
    def __init__(self, transform: Compose):
        self.transform = transform

    def __call__(self, img):
        # torchvision's CIFAR10 yields PIL images; Albumentations expects an HxWxC NumPy array.
        img_np = np.array(img)
        augmented = self.transform(image=img_np)
        return augmented["image"]


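def get_transforms(_: str | None = None):
    """Return (train, test) transforms wrapped for use with torchvision datasets.

    The positional argument is accepted for call-site compatibility
    (get_datasets passes model_name) but is currently unused.
    """
    fill_value = _coarse_dropout_fill_value_from_mean(CIFAR10_MEAN)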

    train_transforms = Compose([
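        # Pad 32x32 -> 36x36 with a constant border (border_mode=0 is cv2.BORDER_CONSTANT),
        # then crop back to 32x32, giving a random translation of up to 2 px per side.
        PadIfNeeded(min_height=36, min_width=36, border_mode=0, p=1.0),
        RandomCrop(height=32, width=32, p=1.0),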
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=5, p=0.3),
        CoarseDropout(
            num_holes_range=(1, 1),
            hole_height_range=(8, 8),
            hole_width_range=(8, 8),
            fill=fill_value,
            p=0.4,
        ),
        ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02, p=0.4),
        Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
        ToTensorV2(),
    ])

    test_transforms = Compose([
        Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
        ToTensorV2(),
    ])

    return AlbumentationsAdapter(train_transforms), AlbumentationsAdapter(test_transforms)


def get_datasets(data_dir: str = "./data", model_name: str | None = None):
    """Return CIFAR-10 train/test datasets with Albumentations transforms."""
    train_transforms, test_transforms = get_transforms(model_name)

    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=train_transforms
    )
    test_dataset = datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=test_transforms
    )

    return train_dataset, test_dataset


def get_data_loaders(
    batch_size: int = 128,
    data_dir: str = "./data",
    num_workers: int = 2,
    pin_memory: bool = True,
    shuffle_train: bool = True,
    model_name: str | None = None,
):
    """Return CIFAR-10 train/test dataloaders with on-the-fly Albumentations."""
    train_dataset, test_dataset = get_datasets(data_dir=data_dir, model_name=model_name)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle_train,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, test_loader