"""Albumentations-based transform builder (the conventional-augmentation tier).

Key correctness guarantees for segmentation:

* Masks always use NEAREST interpolation, so integer class ids are never
  blended: every interpolating geometric transform sets
  ``mask_interpolation=cv2.INTER_NEAREST`` (flips only permute pixels and
  need no interpolation).
* Image and mask receive the SAME random spatial parameters, because
  Albumentations applies one transform jointly to ``image=`` and ``mask=``.

``aug`` presets: ``none`` (resize + normalize only) | ``standard`` | ``strong``.

Returns a callable:
    (image HWC uint8, mask HW int) -> (FloatTensor[C, H, W], LongTensor[H, W]).
"""

from __future__ import annotations

from typing import Tuple

import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _normalize(in_channels: int, normalize: str) -> A.Normalize:
    """Build the ``A.Normalize`` op for ``in_channels``-channel images.

    Modes:
      * ``"none"``: per-channel mean 0 / std 1 (pixel values are only divided
        by ``max_pixel_value=255``).
      * any other value (``"auto"``, ``"imagenet"``, ...): ImageNet statistics
        when ``in_channels == 3``, otherwise mean 0.5 / std 0.5 per channel.
        ``"imagenet"`` therefore falls back to 0.5/0.5 for non-RGB inputs.
    """
    if normalize == "none":
        mean = (0.0,) * in_channels
        std = (1.0,) * in_channels
    elif in_channels == 3:
        mean, std = _IMAGENET_MEAN, _IMAGENET_STD
    else:
        mean = (0.5,) * in_channels
        std = (0.5,) * in_channels
    return A.Normalize(mean=mean, std=std, max_pixel_value=255.0)


def _as_cv2_mask(mask: np.ndarray) -> np.ndarray:
    """Return ``mask`` in a dtype OpenCV can resize/warp.

    OpenCV has no 64-bit integer or bool image depth, so int64/uint64 masks
    (NumPy's usual default integer dtype) are cast to int32 and bool masks to
    uint8. Class ids, including negative ignore values such as -1, are kept
    unchanged as long as they fit in int32. Other dtypes pass through as-is.
    """
    mask = np.asarray(mask)
    if mask.dtype == np.bool_:
        return mask.astype(np.uint8)
    if np.issubdtype(mask.dtype, np.integer) and mask.dtype.itemsize == 8:
        return mask.astype(np.int32)
    return mask


def build_transform(img_size: int, in_channels: int, train: bool,
                    aug: str = "standard", normalize: str = "auto"):
    """Build a joint image/mask transform pipeline.

    Args:
        img_size: Output side length; image and mask are resized to
            ``img_size x img_size``.
        in_channels: Number of image channels. Selects the normalization
            statistics and whether CLAHE is added in the ``strong`` preset.
        train: When False, only resize + normalize is applied.
        aug: Augmentation preset used when ``train`` is True: ``"none"``,
            ``"standard"`` or ``"strong"``. Any other value currently behaves
            like ``"standard"``.
        normalize: Normalization mode; see ``_normalize``.

    Returns:
        A callable ``(image, mask) -> (image_tensor, mask_tensor)`` producing a
        float tensor ``[C, H, W]`` and a long tensor ``[H, W]``.
    """
    nearest = cv2.INTER_NEAREST  # masks: never blend integer class ids
    linear = cv2.INTER_LINEAR    # images

    ops = []
    if train and aug != "none":
        ops += [
            A.Resize(img_size, img_size, interpolation=linear,
                     mask_interpolation=nearest),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # A tuple (a, b) is sampled per image for both axes, so the range
            # must be symmetric; (0.0, 0.05) would only ever shift one way.
            A.Affine(scale=(0.9, 1.1), translate_percent=(-0.05, 0.05),
                     rotate=(-15, 15), interpolation=linear,
                     mask_interpolation=nearest, p=0.5),
        ]
        if aug == "strong":
            ops += [
                A.ElasticTransform(alpha=30, sigma=6, interpolation=linear,
                                   mask_interpolation=nearest, p=0.3),
                A.GridDistortion(num_steps=5, distort_limit=0.2,
                                 interpolation=linear,
                                 mask_interpolation=nearest, p=0.3),
                A.RandomBrightnessContrast(p=0.5),
                A.GaussNoise(p=0.2),
            ]
            if in_channels == 3:
                # NOTE(review): presumably gated because Albumentations' CLAHE
                # only accepts 1- or 3-channel uint8 images — confirm.
                ops.append(A.CLAHE(p=0.2))
    else:
        ops.append(A.Resize(img_size, img_size, interpolation=linear,
                            mask_interpolation=nearest))

    ops += [_normalize(in_channels, normalize), ToTensorV2()]
    compose = A.Compose(ops)

    def _apply(image: np.ndarray,
               mask: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        # One joint call so image and mask share identical spatial params.
        out = compose(image=image, mask=_as_cv2_mask(mask))
        img = out["image"].float()  # C,H,W (ToTensorV2 transposes HWC -> CHW)
        msk = out["mask"].long()    # H,W
        return img, msk

    return _apply