day4model / model_training /src /transforms.py
cardio-deploy
Deploy CardioScan inference 2026-04-23T18:11:52Z
3f984f1
from __future__ import annotations
from typing import List, Tuple
import torchvision.transforms as T
from PIL import Image
from src.config import CFG
# ---------------------------------------------------------------------------
# PIL helpers (TTA expects PIL → PIL transforms; xrv normalisation is applied
# downstream inside the Dataset).
# ---------------------------------------------------------------------------
def _pil_hflip(img: Image.Image) -> Image.Image:
return img.transpose(Image.FLIP_LEFT_RIGHT)
# ---------------------------------------------------------------------------
# Training and evaluation transforms
# ---------------------------------------------------------------------------
def make_transforms(img_size: int | None = None) -> Tuple[T.Compose, T.Compose]:
"""Return (train_transform, eval_transform) PIL-space pipelines.
All transforms produce a PIL grayscale image of size (img_size, img_size).
The downstream Dataset converts it to a single-channel xrv-normalised
tensor in [-1024, 1024].
Train pipeline: small affine, mild jitter, light hflip; random erasing
happens after xrv normalisation inside the Dataset.
Eval pipeline: deterministic resize.
"""
img_size = img_size if img_size is not None else CFG.img_size
train_tf = T.Compose([
T.Resize((img_size + 16, img_size + 16)),
T.RandomCrop((img_size, img_size)),
T.RandomHorizontalFlip(p=0.5),
T.RandomAffine(
degrees=8,
translate=(0.04, 0.04),
scale=(0.95, 1.05),
fill=0,
),
T.ColorJitter(brightness=0.15, contrast=0.15),
])
eval_tf = T.Compose([
T.Resize((img_size, img_size)),
])
return train_tf, eval_tf
# ---------------------------------------------------------------------------
# Test-time augmentation (TTA) transforms
# ---------------------------------------------------------------------------
def make_tta_transforms(img_size: int | None = None) -> List[T.Compose]:
"""Six deterministic PIL-space transforms.
All end with a resized PIL image ready for xrv_normalize_np().
Predictions are averaged across all passes (in logit space) inside
`tta_predict` / `tta_predict_ensemble`.
"""
img_size = img_size if img_size is not None else CFG.img_size
size = (img_size, img_size)
return [
T.Compose([T.Resize(size)]),
T.Compose([T.Resize(size), T.Lambda(_pil_hflip)]),
T.Compose([T.Resize((img_size + 20, img_size + 20)), T.CenterCrop(size)]),
T.Compose([T.Resize((img_size - 20, img_size - 20)),
T.Pad(10, fill=0), T.CenterCrop(size)]),
T.Compose([T.Resize(size),
T.RandomAffine(degrees=(6, 6), fill=0)]),
T.Compose([T.Resize(size),
T.RandomAffine(degrees=(-6, -6), fill=0)]),
]