File size: 2,910 Bytes
3f984f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import annotations

from typing import List, Tuple

import torchvision.transforms as T
from PIL import Image

from src.config import CFG


# ---------------------------------------------------------------------------
# PIL helpers (TTA expects PIL → PIL transforms; xrv normalisation is applied
# downstream inside the Dataset).
# ---------------------------------------------------------------------------
def _pil_hflip(img: Image.Image) -> Image.Image:
    """Return a left-right mirrored copy of *img*.

    Kept as a named, module-level function rather than an inline lambda,
    presumably so that a ``T.Lambda`` wrapping it stays picklable (e.g. when
    transforms are sent to DataLoader worker processes).
    """
    mirrored = img.transpose(method=Image.FLIP_LEFT_RIGHT)
    return mirrored


# ---------------------------------------------------------------------------
# Training and evaluation transforms
# ---------------------------------------------------------------------------
def make_transforms(img_size: int | None = None) -> Tuple[T.Compose, T.Compose]:
    """Build the (train, eval) PIL → PIL transform pipelines.

    Both pipelines output a PIL image of size ``(img_size, img_size)``. They
    do not change the image mode, so a grayscale input stays grayscale
    (grayscale input is assumed; TODO confirm against the Dataset). Per the
    module's design, the downstream Dataset turns the result into an
    xrv-normalised single-channel tensor and applies random erasing after
    normalisation; neither step happens here.

    Train: oversize resize (+16 px) → random crop → horizontal flip (p=0.5)
           → small random affine (±8°, ±4 % shift, 0.95–1.05 scale, zero
           fill) → mild brightness/contrast jitter (±0.15).
    Eval:  a single deterministic resize.

    Args:
        img_size: Output side length in pixels; ``None`` means ``CFG.img_size``.

    Returns:
        ``(train_transform, eval_transform)`` as two ``T.Compose`` objects.
    """
    side = CFG.img_size if img_size is None else img_size
    target = (side, side)
    # Resize slightly larger first so RandomCrop has room to jitter position.
    oversized = (side + 16, side + 16)

    train_steps = [
        T.Resize(oversized),
        T.RandomCrop(target),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomAffine(
            degrees=8,
            translate=(0.04, 0.04),
            scale=(0.95, 1.05),
            fill=0,
        ),
        T.ColorJitter(brightness=0.15, contrast=0.15),
    ]
    eval_steps = [T.Resize(target)]

    return T.Compose(train_steps), T.Compose(eval_steps)


# ---------------------------------------------------------------------------
# Test-time augmentation (TTA) transforms
# ---------------------------------------------------------------------------
def make_tta_transforms(img_size: int | None = None) -> List[T.Compose]:
    """Build six deterministic PIL → PIL test-time-augmentation pipelines.

    Each pipeline returns a PIL image of exactly ``(img_size, img_size)``.
    The views, in list order, are:

    1. plain resize;
    2. resize + horizontal flip;
    3. zoom in  -- resize to ``img_size + 20``, then centre-crop;
    4. zoom out -- resize to ``img_size - 20``, zero-pad 10 px per side
       (back to exactly ``img_size``), then a centre-crop that is a no-op;
    5. resize + rotate by +6 degrees (zero fill);
    6. resize + rotate by -6 degrees (zero fill).

    The rotations use ``RandomAffine`` with a degenerate range ``(a, a)``, so
    the angle is always ``a``. torchvision still samples it via torch's RNG,
    so building and applying these pipelines advances that RNG state.

    Per the original design, the outputs are fed to ``xrv_normalize_np()``,
    and predictions from all passes are averaged in logit space by
    ``tta_predict`` / ``tta_predict_ensemble``. Those functions are defined
    elsewhere and are not visible here.

    Args:
        img_size: Output side length in pixels; ``None`` means ``CFG.img_size``.

    Returns:
        A list of six ``T.Compose`` pipelines, in the order listed above.

    Raises:
        ValueError: If ``img_size`` is not larger than the 20 px zoom margin.
            Without this check, the zoom-out view would be built with a
            non-positive resize target and fail only when it is applied.
    """
    img_size = img_size if img_size is not None else CFG.img_size
    margin = 20  # pixels added for the zoom-in view / removed for zoom-out
    if img_size <= margin:
        raise ValueError(
            f"img_size must be greater than {margin} for the zoom-out TTA "
            f"view, got {img_size}"
        )
    size = (img_size, img_size)
    # Per-side padding that brings the shrunk zoom-out view back to `size`.
    pad = margin // 2

    return [
        T.Compose([T.Resize(size)]),
        T.Compose([T.Resize(size), T.Lambda(_pil_hflip)]),
        T.Compose([T.Resize((img_size + margin, img_size + margin)),
                   T.CenterCrop(size)]),
        T.Compose([T.Resize((img_size - margin, img_size - margin)),
                   T.Pad(pad, fill=0), T.CenterCrop(size)]),
        T.Compose([T.Resize(size),
                   T.RandomAffine(degrees=(6, 6), fill=0)]),
        T.Compose([T.Resize(size),
                   T.RandomAffine(degrees=(-6, -6), fill=0)]),
    ]