# CIFAR10-ImageClassifier / preprocess.py
# Author: nirmalpratheep
# Commit message: Initial deployment of CIFAR-10 classifier to Hugging Face Spaces
# Commit: 2a12d90
import torch
from torchvision import datasets
from albumentations import (
Compose, HorizontalFlip, ShiftScaleRotate, CoarseDropout,
Normalize, ColorJitter, PadIfNeeded, RandomCrop
)
from albumentations.pytorch import ToTensorV2
import numpy as np
# Per-channel CIFAR-10 statistics in RGB order, on the 0–1 scale
# (i.e. relative to pixel values divided by 255). Fed to Normalize and,
# for the mean, to the CoarseDropout fill colour.
CIFAR10_MEAN = (
    0.4914,  # R
    0.4822,  # G
    0.4465,  # B
)
CIFAR10_STD = (
    0.2470,  # R
    0.2435,  # G
    0.2616,  # B
)
def _coarse_dropout_fill_value_from_mean(mean_rgb: tuple[float, float, float]) -> tuple[int, int, int]:
"""Convert mean RGB (0–1) to 0–255 scale for CoarseDropout fill color."""
return tuple(int(m * 255.0) for m in mean_rgb)
class AlbumentationsAdapter:
    """Adapter to make Albumentations transforms compatible with torchvision datasets.

    A torchvision dataset calls its ``transform`` with a single image
    argument. An Albumentations pipeline is instead called as
    ``pipeline(image=array)`` and returns a dict. This wrapper bridges the two
    calling conventions.
    """

    def __init__(self, transform: "Compose"):
        # Quoted annotation: documents the expected type without evaluating the
        # name at class-definition time. Any callable that accepts ``image=``
        # and returns a mapping with an "image" key works here.
        self.transform = transform

    def __call__(self, img):
        """Apply the wrapped pipeline to ``img`` and return the transformed image.

        Args:
            img: a PIL image or array-like. It is converted with ``np.array``,
                which copies the data, so the pipeline receives its own
                writable array.

        Returns:
            The ``"image"`` entry of the pipeline's output.
        """
        img_np = np.array(img)
        augmented = self.transform(image=img_np)
        return augmented["image"]

    def __repr__(self) -> str:
        # Printing a torchvision dataset includes repr(transform); show the
        # wrapped pipeline rather than a bare object address.
        return f"{type(self).__name__}({self.transform!r})"
def get_transforms(_: str | None = None):
    """Build the CIFAR-10 train and test preprocessing pipelines.

    The single optional argument (``get_datasets`` passes a model name) is
    accepted for interface compatibility and is not used.

    Returns:
        ``(train, test)``, two ``AlbumentationsAdapter`` instances. The train
        pipeline pads and random-crops, flips, applies a small shift/scale/
        rotate, cuts out one 8x8 patch filled with the mean colour and jitters
        colour, then normalises and converts to a tensor. The test pipeline
        only normalises and converts to a tensor.
    """

    def normalize_to_tensor():
        # Fresh instances per call so the two pipelines never share objects.
        return [Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD), ToTensorV2()]

    augmentations = [
        # Pad to 36x36 (border_mode=0 is cv2.BORDER_CONSTANT), then crop a
        # random 32x32 window back out.
        PadIfNeeded(min_height=36, min_width=36, border_mode=0, p=1.0),
        RandomCrop(height=32, width=32, p=1.0),
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=5, p=0.3),
        # Exactly one 8x8 hole, painted with the dataset mean colour.
        CoarseDropout(
            num_holes_range=(1, 1),
            hole_height_range=(8, 8),
            hole_width_range=(8, 8),
            fill=_coarse_dropout_fill_value_from_mean(CIFAR10_MEAN),
            p=0.4,
        ),
        ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02, p=0.4),
    ]

    train_pipeline = Compose(augmentations + normalize_to_tensor())
    test_pipeline = Compose(normalize_to_tensor())
    return AlbumentationsAdapter(train_pipeline), AlbumentationsAdapter(test_pipeline)
def get_datasets(data_dir: str = "./data", model_name: str | None = None):
    """Load the CIFAR-10 train and test splits with Albumentations transforms attached.

    Args:
        data_dir: root directory given to ``datasets.CIFAR10``. ``download=True``
            is always passed, so the data is fetched there if missing.
        model_name: forwarded unchanged to ``get_transforms``.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    train_tf, test_tf = get_transforms(model_name)

    def load_split(is_train, transform):
        # One helper for both splits; only the split flag and transform differ.
        return datasets.CIFAR10(
            root=data_dir, train=is_train, download=True, transform=transform
        )

    return load_split(True, train_tf), load_split(False, test_tf)
def get_data_loaders(
    batch_size: int = 128,
    data_dir: str = "./data",
    num_workers: int = 2,
    pin_memory: bool = True,
    shuffle_train: bool = True,
    model_name: str | None = None,
):
    """Create CIFAR-10 train/test DataLoaders that apply Albumentations on the fly.

    Args:
        batch_size: samples per batch for both loaders.
        data_dir: dataset root, forwarded to ``get_datasets``.
        num_workers: worker processes per loader.
        pin_memory: forwarded to both DataLoaders.
        shuffle_train: whether the train loader shuffles; the test loader never does.
        model_name: forwarded to ``get_datasets``.

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_ds, test_ds = get_datasets(data_dir=data_dir, model_name=model_name)

    # Settings shared by both loaders; only shuffling differs between them.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    make_loader = torch.utils.data.DataLoader

    train_loader = make_loader(train_ds, shuffle=shuffle_train, **loader_kwargs)
    test_loader = make_loader(test_ds, shuffle=False, **loader_kwargs)
    return train_loader, test_loader