import torch
from torchvision import datasets
from albumentations import (
    Compose,
    HorizontalFlip,
    ShiftScaleRotate,
    CoarseDropout,
    Normalize,
    ColorJitter,
    PadIfNeeded,
    RandomCrop,
)
from albumentations.pytorch import ToTensorV2
import numpy as np

# CIFAR-100 per-channel statistics (RGB), expressed on a 0-1 scale.
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)


def _coarse_dropout_fill_value_from_mean(
    mean_rgb: tuple[float, float, float],
) -> tuple[int, int, int]:
    """Convert a mean RGB colour on a 0-1 scale to 8-bit channel values.

    The result is used as the CoarseDropout fill colour, which is applied
    before ``Normalize`` in the training pipeline, i.e. while pixel values
    are still on the 0-255 scale.

    Args:
        mean_rgb: Per-channel mean (R, G, B), each nominally in [0, 1].

    Returns:
        The three channels scaled to 0-255, rounded to the nearest integer
        and clamped to the valid 8-bit range.
    """
    # Round instead of truncating so the fill is the closest representable
    # 8-bit value, and clamp so an out-of-range mean cannot produce an
    # invalid pixel value.
    red, green, blue = (
        min(255, max(0, round(channel * 255.0))) for channel in mean_rgb
    )
    return (red, green, blue)


class AlbumentationsAdapter:
    """Adapter to make Albumentations transforms compatible with torchvision datasets.

    torchvision datasets invoke ``transform(img)`` with a single positional
    image, whereas an Albumentations ``Compose`` is called with the keyword
    ``image=`` and returns a dict. This wrapper bridges the two conventions.
    """

    def __init__(self, transform: Compose):
        """Store the Albumentations pipeline to apply on each call."""
        self.transform = transform

    def __call__(self, img):
        """Apply the wrapped pipeline to ``img`` and return the result.

        Args:
            img: Anything ``np.array`` can convert to an image array
                (presumably the PIL image a torchvision dataset passes to
                its transform).

        Returns:
            The ``"image"`` entry of the pipeline's output dict.
        """
        img_np = np.array(img)
        augmented = self.transform(image=img_np)
        return augmented["image"]


def get_transforms(
    _: str | None = None,
) -> tuple[AlbumentationsAdapter, AlbumentationsAdapter]:
    """Return Albumentations transforms for CIFAR-100.

    Args:
        _: Ignored. Accepted so callers can pass a model name without the
            pipeline depending on it.

    Returns:
        A ``(train, test)`` pair of ``AlbumentationsAdapter`` objects. The
        training pipeline augments, then normalizes and converts to a
        tensor; the test pipeline only normalizes and converts.
    """
    fill_value = _coarse_dropout_fill_value_from_mean(CIFAR100_MEAN)

    train_transforms = Compose([
        # Pad to 36x36 and crop back to 32x32: a random translation of up to
        # 2 pixels per side. border_mode=0 is OpenCV's BORDER_CONSTANT.
        PadIfNeeded(min_height=36, min_width=36, border_mode=0, p=1.0),
        RandomCrop(height=32, width=32, p=1.0),
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=5, p=0.3),
        # Exactly one 16x16 hole, filled with the dataset mean colour.
        CoarseDropout(
            num_holes_range=(1, 1),
            hole_height_range=(16, 16),
            hole_width_range=(16, 16),
            fill=fill_value,
            p=0.4,
        ),
        ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02, p=0.4),
        Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
        ToTensorV2(),
    ])

    test_transforms = Compose([
        Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
        ToTensorV2(),
    ])

    return AlbumentationsAdapter(train_transforms), AlbumentationsAdapter(test_transforms)


def get_datasets(
    data_dir: str = "./data",
    model_name: str | None = None,
) -> tuple[datasets.CIFAR100, datasets.CIFAR100]:
    """Return CIFAR-100 train/test datasets with Albumentations transforms.

    Both splits are requested with ``download=True``.

    Args:
        data_dir: Root directory passed to ``datasets.CIFAR100``.
        model_name: Forwarded to ``get_transforms``, which currently
            ignores it.

    Returns:
        A ``(train_dataset, test_dataset)`` pair.
    """
    train_transforms, test_transforms = get_transforms(model_name)

    train_dataset = datasets.CIFAR100(
        root=data_dir,
        train=True,
        download=True,
        transform=train_transforms,
    )
    test_dataset = datasets.CIFAR100(
        root=data_dir,
        train=False,
        download=True,
        transform=test_transforms,
    )
    return train_dataset, test_dataset


def get_data_loaders(
    batch_size: int = 128,
    data_dir: str = "./data",
    num_workers: int = 2,
    pin_memory: bool = True,
    shuffle_train: bool = True,
    model_name: str | None = None,
) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Return CIFAR-100 train/test dataloaders with on-the-fly Albumentations.

    Args:
        batch_size: Batch size used by both loaders.
        data_dir: Root directory passed through to ``get_datasets``.
        num_workers: Worker process count used by both loaders.
        pin_memory: Passed to both loaders as ``pin_memory``.
        shuffle_train: Whether the training loader shuffles. The test
            loader never shuffles.
        model_name: Forwarded to ``get_datasets``.

    Returns:
        A ``(train_loader, test_loader)`` pair.
    """
    train_dataset, test_dataset = get_datasets(data_dir=data_dir, model_name=model_name)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle_train,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return train_loader, test_loader