File size: 2,535 Bytes
7e5d57c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations

import random
from typing import Any

from PIL import Image, ImageEnhance, ImageOps


# Per-channel mean and standard deviation handed to ``transforms.Normalize``
# by the train and eval transform builders in this module.
# NOTE(review): the names and values match the commonly used ImageNet RGB
# statistics -- confirm the consuming model expects this normalization.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_train_transform(config: dict[str, Any]):
    """Build the torchvision transform pipeline applied to training images.

    The pipeline always starts with a resize to a square of
    ``config["preprocessing"]["image_size"]`` pixels and always ends with
    ``ToTensor`` followed by ``Normalize(IMAGENET_MEAN, IMAGENET_STD)``.
    The optional ``config["augmentation"]`` mapping controls what is
    inserted in between (defaults in parentheses):

    * ``enabled`` (True) -- master switch for every augmentation below.
    * ``horizontal_flip`` (True) -- add ``RandomHorizontalFlip(p=0.5)``.
    * ``rotation_degrees`` (10), ``translate`` (0.03, used for both axes),
      ``scale_min`` (0.95), ``scale_max`` (1.05) -- parameters of a
      ``RandomAffine`` that is always added when augmentation is enabled;
      uncovered pixels are filled with 255.
    * ``color_jitter`` -- nested mapping with ``enabled`` (True),
      ``brightness`` (0.12), ``contrast`` (0.12), ``saturation`` (0.08)
      and ``hue`` (0.02) for ``ColorJitter``.

    A section that is present but null (for example an empty
    ``augmentation:`` key in a YAML file) is treated like a missing
    section, i.e. the defaults above apply.

    Args:
        config: Configuration mapping; must contain
            ``preprocessing.image_size``.

    Returns:
        A ``torchvision.transforms.Compose`` instance.

    Raises:
        KeyError: If ``preprocessing`` or ``image_size`` is missing.
    """
    # Imported lazily so that importing this module does not require
    # torchvision unless a transform is actually built.
    from torchvision import transforms

    size = int(config["preprocessing"]["image_size"])

    aug = config.get("augmentation")
    if aug is None:
        # Key absent or explicitly null: fall back to the defaults.
        aug = {}

    ops: list[Any] = [transforms.Resize((size, size))]
    if aug.get("enabled", True):
        if aug.get("horizontal_flip", True):
            ops.append(transforms.RandomHorizontalFlip(p=0.5))

        # The same translation fraction is used horizontally and vertically.
        shift = float(aug.get("translate", 0.03))
        ops.append(
            transforms.RandomAffine(
                degrees=float(aug.get("rotation_degrees", 10)),
                translate=(shift, shift),
                scale=(float(aug.get("scale_min", 0.95)), float(aug.get("scale_max", 1.05))),
                fill=255,
            )
        )

        jitter = aug.get("color_jitter")
        if jitter is None:
            jitter = {}
        if jitter.get("enabled", True):
            ops.append(
                transforms.ColorJitter(
                    brightness=float(jitter.get("brightness", 0.12)),
                    contrast=float(jitter.get("contrast", 0.12)),
                    saturation=float(jitter.get("saturation", 0.08)),
                    hue=float(jitter.get("hue", 0.02)),
                )
            )

    ops.extend([transforms.ToTensor(), transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    return transforms.Compose(ops)


def build_eval_transform(config: dict[str, Any]):
    """Build the deterministic transform pipeline for evaluation images.

    The image is resized to a square of
    ``config["preprocessing"]["image_size"]`` pixels, converted with
    ``ToTensor`` and normalized with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    No random augmentation is applied.

    Args:
        config: Configuration mapping; must contain
            ``preprocessing.image_size``.

    Returns:
        A ``torchvision.transforms.Compose`` instance.

    Raises:
        KeyError: If ``preprocessing`` or ``image_size`` is missing.
    """
    from torchvision import transforms

    side = int(config["preprocessing"]["image_size"])
    steps = [
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)


def augment_pil_for_classical(image: Image.Image, augment_id: int, seed: int = 42) -> Image.Image:
    """Return a deterministically augmented copy of *image*.

    ``augment_id % 6`` selects one of six fixed recipes:

    * 0 -- unchanged copy
    * 1 -- horizontal mirror
    * 2 -- rotation
    * 3 -- rotation, then brightness and contrast
    * 4 -- horizontal mirror, then brightness and contrast
    * 5 -- rotation, then sharpness

    Magnitudes are drawn from a ``random.Random`` seeded with
    ``seed + augment_id * 9973``, so the same ``(augment_id, seed)`` pair
    always produces the same draws: rotation angle in [-9, 9] degrees,
    brightness factor in [0.9, 1.1], contrast factor in [0.9, 1.12] and
    sharpness factor in [0.9, 1.2].

    Args:
        image: Source image; it is copied and never modified.
        augment_id: Index choosing the recipe and perturbing the seed.
        seed: Base seed for the random magnitudes.

    Returns:
        A new augmented image.
    """
    rng = random.Random(seed + augment_id * 9973)
    result = image.copy()
    recipe = augment_id % 6

    if recipe in (1, 4):
        result = ImageOps.mirror(result)

    if recipe in (2, 3, 5):
        # Pillow expects one fill value per band. A bare int is only
        # well-defined for single-band modes; on RGB, 255 is read as a
        # packed pixel value and paints the exposed corners red instead of
        # white. Build an all-255 fill that matches the image's band count.
        # NOTE: all-255 is white for L/RGB-style modes (opaque with alpha).
        band_count = len(result.getbands())
        white = 255 if band_count == 1 else (255,) * band_count
        angle = rng.uniform(-9.0, 9.0)
        result = result.rotate(angle, resample=Image.Resampling.BILINEAR, fillcolor=white)

    if recipe in (3, 4):
        # Draw order (brightness, then contrast) is part of the
        # deterministic contract for a given seed.
        result = ImageEnhance.Brightness(result).enhance(rng.uniform(0.9, 1.1))
        result = ImageEnhance.Contrast(result).enhance(rng.uniform(0.9, 1.12))

    if recipe == 5:
        result = ImageEnhance.Sharpness(result).enhance(rng.uniform(0.9, 1.2))

    return result