File size: 2,910 Bytes
3f984f1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 | from __future__ import annotations
from typing import List, Tuple
import torchvision.transforms as T
from PIL import Image
from src.config import CFG
# ---------------------------------------------------------------------------
# PIL helpers (TTA expects PIL → PIL transforms; xrv normalisation is applied
# downstream inside the Dataset).
# ---------------------------------------------------------------------------
def _pil_hflip(img: Image.Image) -> Image.Image:
    """Return a left-to-right mirrored copy of the PIL image *img*.

    Defined as a named module-level function (rather than an inline lambda)
    so it can be wrapped in ``T.Lambda`` by the TTA pipelines; presumably this
    also keeps those pipelines picklable.
    """
    mirror = Image.FLIP_LEFT_RIGHT
    return img.transpose(mirror)
# ---------------------------------------------------------------------------
# Training and evaluation transforms
# ---------------------------------------------------------------------------
def make_transforms(img_size: int | None = None) -> Tuple[T.Compose, T.Compose]:
    """Build the PIL-space ``(train_transform, eval_transform)`` pipelines.

    Both pipelines map a PIL image to a PIL image of size
    ``(img_size, img_size)``; ``img_size`` defaults to ``CFG.img_size``.
    The transforms keep the input's image mode (expected to be grayscale).
    The downstream Dataset is expected to turn the result into a
    single-channel xrv-normalised tensor (nominally in [-1024, 1024]) and to
    apply random erasing after that normalisation.

    Train pipeline, in order:
      * resize to ``img_size + 16`` square, then random-crop back to
        ``img_size``;
      * horizontal flip with probability 0.5;
      * small random affine: up to +/-8 degrees, +/-4% shift, 0.95-1.05
        scale, zero fill;
      * brightness/contrast jitter of 0.15.

    Eval pipeline: a single deterministic resize to ``img_size`` square.
    """
    side = CFG.img_size if img_size is None else img_size
    square = (side, side)
    padded_square = (side + 16, side + 16)

    augmentations = [
        T.Resize(padded_square),
        T.RandomCrop(square),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomAffine(
            degrees=8,
            translate=(0.04, 0.04),
            scale=(0.95, 1.05),
            fill=0,
        ),
        T.ColorJitter(brightness=0.15, contrast=0.15),
    ]
    training = T.Compose(augmentations)
    evaluation = T.Compose([T.Resize(square)])
    return training, evaluation
# ---------------------------------------------------------------------------
# Test-time augmentation (TTA) transforms
# ---------------------------------------------------------------------------
def make_tta_transforms(img_size: int | None = None) -> List[T.Compose]:
    """Build six deterministic PIL-space test-time-augmentation pipelines.

    Each pipeline maps a PIL image to a PIL image of size
    ``(img_size, img_size)`` (``img_size`` defaults to ``CFG.img_size``),
    ready for the xrv normalisation applied downstream. The views, in order:

      0. plain resize;
      1. resize + horizontal flip;
      2. zoom in: resize to ``img_size + 20`` square, then centre-crop;
      3. zoom out: resize to ``img_size - 20`` square, then zero-pad 10 px on
         every side back to ``img_size``;
      4. resize + fixed rotation by +6 degrees (torchvision ``affine``
         angle convention, zero fill);
      5. resize + fixed rotation by -6 degrees.

    Callers (e.g. ``tta_predict`` / ``tta_predict_ensemble``, defined
    elsewhere) are expected to average predictions across all passes.

    Raises:
        ValueError: if ``img_size`` is not greater than 20, because the
            zoom-out view would need a non-positive intermediate size.
    """
    img_size = img_size if img_size is not None else CFG.img_size
    margin = 20  # total pixels added (zoom in) / removed (zoom out)
    angle = 6.0  # degrees for the two rotated views
    if img_size <= margin:
        raise ValueError(
            f"make_tta_transforms requires img_size > {margin}, got {img_size}"
        )

    # Functional API: applies an exact affine without sampling parameters.
    import functools

    import torchvision.transforms.functional as TF

    def fixed_rotation(degrees: float) -> T.Lambda:
        # Same pixels as RandomAffine(degrees=(d, d), fill=0): identity
        # translate/scale/shear, default NEAREST interpolation, image-centre
        # pivot. Unlike RandomAffine it does not draw from the global torch RNG
        # on every call, so TTA leaves the RNG state untouched.
        return T.Lambda(
            functools.partial(
                TF.affine,
                angle=degrees,
                translate=[0, 0],
                scale=1.0,
                shear=[0.0, 0.0],
                fill=0,
            )
        )

    size = (img_size, img_size)
    zoom_in = (img_size + margin, img_size + margin)
    zoom_out = (img_size - margin, img_size - margin)
    return [
        T.Compose([T.Resize(size)]),
        T.Compose([T.Resize(size), T.Lambda(_pil_hflip)]),
        T.Compose([T.Resize(zoom_in), T.CenterCrop(size)]),
        # Padding margin // 2 per side restores exactly img_size, so the
        # CenterCrop is a no-op safeguard.
        T.Compose([T.Resize(zoom_out),
                   T.Pad(margin // 2, fill=0), T.CenterCrop(size)]),
        T.Compose([T.Resize(size), fixed_rotation(angle)]),
        T.Compose([T.Resize(size), fixed_rotation(-angle)]),
    ]
|