| import torch |
| from torchvision import datasets |
| from albumentations import ( |
| Compose, HorizontalFlip, ShiftScaleRotate, CoarseDropout, |
| Normalize, ColorJitter, PadIfNeeded, RandomCrop |
| ) |
| from albumentations.pytorch import ToTensorV2 |
| import numpy as np |
|
|
|
|
| |
# Per-channel (R, G, B) statistics on the 0-1 scale. Both are handed to
# Normalize(mean=..., std=...); the mean is additionally rescaled to 0-255
# and used as the CoarseDropout fill colour.
CIFAR10_MEAN: tuple[float, float, float] = (0.4914, 0.4822, 0.4465)
CIFAR10_STD: tuple[float, float, float] = (0.2470, 0.2435, 0.2616)
|
|
|
|
| def _coarse_dropout_fill_value_from_mean(mean_rgb: tuple[float, float, float]) -> tuple[int, int, int]: |
| """Convert mean RGB (0–1) to 0–255 scale for CoarseDropout fill color.""" |
| return tuple(int(m * 255.0) for m in mean_rgb) |
|
|
|
|
class AlbumentationsAdapter:
    """Callable wrapper so an Albumentations pipeline can serve as a dataset ``transform``.

    torchvision datasets call ``transform(img)`` with a single positional
    argument, whereas the wrapped pipeline is invoked with ``image=<array>``
    and returns a mapping. This class bridges the two calling conventions.
    """

    def __init__(self, transform: Compose):
        # Kept as a public attribute; called with the ``image=`` keyword.
        self.transform = transform

    def __call__(self, img):
        """Apply the wrapped pipeline to ``img`` and return its ``"image"`` entry.

        ``img`` is first converted with ``np.array`` so the pipeline always
        receives a NumPy array.
        """
        return self.transform(image=np.array(img))["image"]
|
|
|
|
def get_transforms(_: str | None = None):
    """Build the training and test transform callables.

    The single argument is accepted for call-site compatibility but is not
    used; the same pipelines are returned regardless of its value.

    Returns:
        A ``(train, test)`` tuple of ``AlbumentationsAdapter`` objects. The
        training pipeline applies random augmentation followed by
        normalization and tensor conversion; the test pipeline only
        normalizes and converts to a tensor.
    """
    hole_fill = _coarse_dropout_fill_value_from_mean(CIFAR10_MEAN)

    training_steps = [
        # Pad up to at least 36x36, then take a random 32x32 crop.
        PadIfNeeded(min_height=36, min_width=36, border_mode=0, p=1.0),
        RandomCrop(height=32, width=32, p=1.0),
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=5, p=0.3),
        # Exactly one 8x8 hole, filled with the dataset mean colour (0-255);
        # this runs before Normalize, so the fill is on the raw pixel scale.
        CoarseDropout(
            num_holes_range=(1, 1),
            hole_height_range=(8, 8),
            hole_width_range=(8, 8),
            fill=hole_fill,
            p=0.4,
        ),
        ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02, p=0.4),
        Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
        ToTensorV2(),
    ]
    train_pipeline = Compose(training_steps)

    # No augmentation at test time: normalization and tensor conversion only.
    eval_pipeline = Compose([
        Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
        ToTensorV2(),
    ])

    return AlbumentationsAdapter(train_pipeline), AlbumentationsAdapter(eval_pipeline)
|
|
|
|
def get_datasets(data_dir: str = "./data", model_name: str | None = None):
    """Return CIFAR-10 train/test datasets with Albumentations transforms.

    Args:
        data_dir: Root directory passed to ``datasets.CIFAR10``; the data is
            requested with ``download=True`` for both splits.
        model_name: Forwarded to ``get_transforms``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    train_tf, test_tf = get_transforms(model_name)

    # Arguments common to both splits.
    shared = {"root": data_dir, "download": True}

    training_split = datasets.CIFAR10(train=True, transform=train_tf, **shared)
    held_out_split = datasets.CIFAR10(train=False, transform=test_tf, **shared)
    return training_split, held_out_split
|
|
|
|
def get_data_loaders(
    batch_size: int = 128,
    data_dir: str = "./data",
    num_workers: int = 2,
    pin_memory: bool = True,
    shuffle_train: bool = True,
    model_name: str | None = None,
):
    """Return CIFAR-10 train/test dataloaders with on-the-fly Albumentations.

    Args:
        batch_size: Batch size used by both loaders.
        data_dir: Dataset root directory, forwarded to ``get_datasets``.
        num_workers: Worker count used by both loaders.
        pin_memory: ``pin_memory`` flag used by both loaders.
        shuffle_train: Whether the training loader shuffles; the test loader
            never does.
        model_name: Forwarded to ``get_datasets``.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    train_split, test_split = get_datasets(data_dir=data_dir, model_name=model_name)

    # Settings shared by both loaders; only ``shuffle`` differs between them.
    loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    make_loader = torch.utils.data.DataLoader

    return (
        make_loader(train_split, shuffle=shuffle_train, **loader_options),
        make_loader(test_split, shuffle=False, **loader_options),
    )
|
|