Spaces:
Sleeping
Sleeping
| """ | |
| Image transforms for DL training and evaluation. | |
| Provides separate transform pipelines for training (with augmentation) | |
| and evaluation (resize + normalize only). | |
| """ | |
| from torchvision import transforms | |
| # ImageNet normalization statistics | |
| IMAGENET_MEAN = [0.485, 0.456, 0.406] | |
| IMAGENET_STD = [0.229, 0.224, 0.225] | |
def get_train_transforms(image_size: int = 224):
    """
    Build the augmented transform pipeline used during training.

    Resizes to a square, applies stochastic geometric and photometric
    augmentations (horizontal flip, rotation, color jitter, affine
    jitter, gaussian blur), converts to a tensor, normalizes with
    ImageNet statistics, and finally applies random erasing.

    Args:
        image_size: Side length of the square output image
            (default 224, suitable for ResNet/EfficientNet).

    Returns:
        torchvision.transforms.Compose pipeline
    """
    color_jitter = transforms.ColorJitter(
        brightness=0.3,
        contrast=0.3,
        saturation=0.3,
        hue=0.1,
    )
    affine_jitter = transforms.RandomAffine(
        degrees=0,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
    )
    pipeline = [
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        color_jitter,
        affine_jitter,
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        # RandomErasing operates on tensors, so it must follow ToTensor.
        transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),
    ]
    return transforms.Compose(pipeline)
def get_eval_transforms(image_size: int = 224):
    """
    Build the deterministic transform pipeline for evaluation.

    No augmentation is applied: images are only resized, converted to
    tensors, and normalized with ImageNet statistics.

    Args:
        image_size: Side length of the square output image (default 224).

    Returns:
        torchvision.transforms.Compose pipeline
    """
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
def get_minority_augment_transforms():
    """
    Build a stronger augmentation pipeline for minority-class images.

    Intended to run BEFORE the standard training transforms so that
    under-represented classes (e.g. fake backs) gain extra visual
    diversity. Uses more aggressive geometric and color perturbations
    than the regular training pipeline.

    Returns:
        torchvision.transforms.Compose pipeline (operates on PIL images)
    """
    strong_color = transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.15,
    )
    return transforms.Compose([
        transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.3),
        strong_color,
        transforms.RandomVerticalFlip(p=0.3),
    ])
def denormalize(tensor, mean=None, std=None):
    """
    Reverse channel-wise normalization for visualization.

    Args:
        tensor: Normalized image tensor of shape (C, H, W); a leading
            batch dimension also broadcasts correctly.
        mean: Per-channel normalization mean (defaults to ImageNet).
        std: Per-channel normalization std (defaults to ImageNet).

    Returns:
        Denormalized tensor with values clamped to [0, 1].
    """
    import torch

    if mean is None:
        mean = IMAGENET_MEAN
    if std is None:
        std = IMAGENET_STD
    # Shape (C, 1, 1) so the stats broadcast over spatial dims; create
    # directly with the input's device/dtype instead of building on CPU
    # and moving afterwards.
    mean = torch.tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    # Clamp: the docstring promises [0, 1], but augmentations applied
    # after Normalize (e.g. RandomErasing zeros) can otherwise push
    # pixels outside the displayable range.
    return (tensor * std + mean).clamp(0.0, 1.0)