File size: 3,369 Bytes
0ba6002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61ba70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ba6002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Image transforms for DL training and evaluation.

Provides separate transform pipelines for training (with augmentation)
and evaluation (resize + normalize only).
"""

from torchvision import transforms


# Per-channel ImageNet statistics (mean, std). The train/eval pipelines below
# normalize with them and denormalize() uses them as its defaults, so keeping
# them in one place ensures both directions agree.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def get_train_transforms(image_size: int = 224):
    """
    Build the augmentation pipeline used while training.

    The input image is resized to a square, randomly flipped horizontally,
    rotated, colour-jittered, translated/scaled and Gaussian-blurred (the
    blur step has no probability argument, so it runs on every image with a
    random sigma). It is then converted to a tensor, normalized with the
    ImageNet statistics, and finally has a random patch erased.

    Args:
        image_size: Side length of the square output (default 224 for
            ResNet/EfficientNet)

    Returns:
        torchvision.transforms.Compose pipeline
    """
    side = (image_size, image_size)

    # Image-space augmentations, applied before conversion to a tensor.
    image_stage = [
        transforms.Resize(side),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
        ),
        # Rotation is already covered above; this step only shifts and scales.
        transforms.RandomAffine(
            degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
        ),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    ]

    # Tensor-space steps; RandomErasing works on tensors, so it follows ToTensor.
    tensor_stage = [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),
    ]

    return transforms.Compose(image_stage + tensor_stage)


def get_eval_transforms(image_size: int = 224):
    """
    Build the deterministic pipeline used for validation and testing.

    The image is resized to a square, converted to a tensor and normalized
    with the ImageNet statistics. No random augmentation is involved.

    Args:
        image_size: Side length of the square output (default 224)

    Returns:
        torchvision.transforms.Compose pipeline
    """
    resize = transforms.Resize((image_size, image_size))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    return transforms.Compose([resize, to_tensor, normalize])


def get_minority_augment_transforms():
    """
    Build a heavier augmentation pipeline for under-represented classes.

    Intended to run on PIL images ahead of the standard training transforms,
    adding visual diversity for scarce classes (e.g., fake backs). Compared
    with the standard training pipeline it uses stronger colour jitter and
    adds perspective warps, sharpness changes and vertical flips.

    Returns:
        torchvision.transforms.Compose pipeline (operates on PIL images)
    """
    perspective = transforms.RandomPerspective(distortion_scale=0.2, p=0.5)
    sharpen = transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.3)
    jitter = transforms.ColorJitter(
        brightness=0.4, contrast=0.4, saturation=0.4, hue=0.15
    )
    vertical_flip = transforms.RandomVerticalFlip(p=0.3)
    return transforms.Compose([perspective, sharpen, jitter, vertical_flip])


def denormalize(tensor, mean=None, std=None, clamp=False):
    """
    Reverse ImageNet normalization for visualization.

    Computes ``tensor * std + mean`` per channel, i.e. the inverse of
    ``transforms.Normalize(mean, std)``.

    Args:
        tensor: Normalized image tensor (C, H, W). A batch (N, C, H, W) also
            works, because the statistics are shaped (C, 1, 1) and broadcast
            over the trailing dimensions.
        mean: Normalization mean as a sequence or tensor (defaults to ImageNet)
        std: Normalization std as a sequence or tensor (defaults to ImageNet)
        clamp: If True, clip the result to [0, 1]. Useful before plotting,
            since float rounding (or inputs that were not produced by
            Normalize) can yield values slightly outside that range.

    Returns:
        Denormalized tensor. For input produced by normalizing a [0, 1]
        image the values are back in [0, 1] up to float rounding; pass
        ``clamp=True`` to guarantee that range.
    """
    import torch

    if mean is None:
        mean = IMAGENET_MEAN
    if std is None:
        std = IMAGENET_STD

    # as_tensor accepts lists and tensors alike without the copy-construct
    # warning torch.tensor() emits for tensor inputs, and places the stats
    # on the input's device in one step. reshape (not view) also handles
    # non-contiguous stat tensors.
    mean = torch.as_tensor(mean, device=tensor.device).reshape(-1, 1, 1)
    std = torch.as_tensor(std, device=tensor.device).reshape(-1, 1, 1)

    result = tensor * std + mean
    if clamp:
        result = result.clamp(0.0, 1.0)
    return result