"""Image transforms for DL training and evaluation.

Provides separate transform pipelines for training (with augmentation)
and evaluation (resize + normalize only).
"""

import torch
from torchvision import transforms

# ImageNet normalization statistics (RGB channel order).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def get_train_transforms(image_size: int = 224):
    """Get training transforms with data augmentation.

    Includes: resize, random flip, rotation, color jitter, affine,
    gaussian blur, random erasing, and ImageNet normalization.

    Args:
        image_size: Target image size (default 224 for ResNet/EfficientNet)

    Returns:
        torchvision.transforms.Compose pipeline
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.1,
        ),
        transforms.RandomAffine(
            degrees=0,
            translate=(0.1, 0.1),
            scale=(0.9, 1.1),
        ),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        # RandomErasing operates on tensors, so it must come after ToTensor.
        transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),
    ])


def get_eval_transforms(image_size: int = 224):
    """Get evaluation transforms (no augmentation).

    Includes: resize, tensor conversion, and ImageNet normalization only.

    Args:
        image_size: Target image size (default 224)

    Returns:
        torchvision.transforms.Compose pipeline
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


def get_minority_augment_transforms():
    """Get stronger augmentation pipeline for minority class images.

    Applied BEFORE the standard train transforms to create visual
    diversity for under-represented classes (e.g., fake backs).
    Includes more aggressive geometric and color perturbations.

    Returns:
        torchvision.transforms.Compose pipeline (operates on PIL images)
    """
    return transforms.Compose([
        transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.3),
        transforms.ColorJitter(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.15,
        ),
        transforms.RandomVerticalFlip(p=0.3),
    ])


def denormalize(tensor, mean=None, std=None):
    """Reverse ImageNet normalization for visualization.

    Args:
        tensor: Normalized image tensor (C, H, W)
        mean: Normalization mean (defaults to ImageNet)
        std: Normalization std (defaults to ImageNet)

    Returns:
        Denormalized tensor with values clamped to [0, 1]
    """
    if mean is None:
        mean = IMAGENET_MEAN
    if std is None:
        std = IMAGENET_STD
    # Shape (C, 1, 1) so the stats broadcast across H and W; create the
    # tensors directly on the input's device instead of transferring after.
    mean = torch.tensor(mean, device=tensor.device).view(-1, 1, 1)
    std = torch.tensor(std, device=tensor.device).view(-1, 1, 1)
    # Clamp so float round-trip error cannot push values outside [0, 1],
    # which would otherwise break image-display utilities.
    return (tensor * std + mean).clamp(0.0, 1.0)