| """Albumentations-based transform builder (the conventional-augmentation tier). |
| |
| Key correctness guarantees for segmentation: |
| * masks always use NEAREST interpolation (integer class ids never blended); |
| every geometric transform sets mask_interpolation=cv2.INTER_NEAREST. |
| * image and mask receive the SAME random spatial parameters (Albumentations |
| applies one transform jointly to image= and mask=). |
| |
| `aug` presets: none (resize+normalize only) | standard | strong. |
| Returns a callable: (image HWC uint8, mask HW int) -> (FloatTensor[C,H,W], LongTensor[H,W]). |
| """ |
| from __future__ import annotations |
|
|
| from typing import Tuple |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
|
|
|
|
| _IMAGENET_MEAN = (0.485, 0.456, 0.406) |
| _IMAGENET_STD = (0.229, 0.224, 0.225) |
|
|
|
|
def _normalize(in_channels: int, normalize: str) -> A.Normalize:
    """Return the ``A.Normalize`` op for a normalization mode.

    Modes:
      * ``"none"``: mean 0 / std 1 per channel. Because ``max_pixel_value=255``
        is always passed, this still rescales pixel values by 1/255.
      * ``"imagenet"``: ImageNet mean/std. Those statistics only exist for 3
        channels; for any other channel count a warning is issued and
        mean = std = 0.5 per channel is used instead.
      * ``"auto"``: ImageNet statistics for 3 channels, otherwise
        mean = std = 0.5 per channel.

    Args:
        in_channels: number of image channels; sets the length of mean/std.
        normalize: one of ``"none"``, ``"imagenet"``, ``"auto"``.

    Returns:
        An ``A.Normalize`` transform configured with ``max_pixel_value=255.0``.

    Raises:
        ValueError: if ``normalize`` is not one of the modes above (previously a
            typo silently behaved like ``"auto"``).
    """
    if normalize == "none":
        mean, std = (0.0,) * in_channels, (1.0,) * in_channels
    elif normalize not in ("imagenet", "auto"):
        raise ValueError(
            f"unknown normalize mode {normalize!r}; "
            "expected one of 'none', 'imagenet', 'auto'"
        )
    elif in_channels == 3:
        mean, std = _IMAGENET_MEAN, _IMAGENET_STD
    else:
        if normalize == "imagenet":
            # An explicit request is being overridden; make that visible
            # instead of silently using different statistics.
            import warnings

            warnings.warn(
                f"normalize='imagenet' requires 3 channels, got {in_channels}; "
                "falling back to mean=std=0.5 per channel",
                stacklevel=2,
            )
        mean = std = (0.5,) * in_channels
    return A.Normalize(mean=mean, std=std, max_pixel_value=255.0)
|
|
|
|
def build_transform(img_size: int, in_channels: int, train: bool,
                    aug: str = "standard", normalize: str = "auto"):
    """Build a joint image/mask transform for segmentation samples.

    Every pipeline starts with a resize to ``img_size x img_size`` and ends
    with ``_normalize(in_channels, normalize)`` followed by ``ToTensorV2``.
    When ``train`` is True and ``aug != "none"``, random ops are inserted in
    between:

    * ``"standard"``: horizontal/vertical flips (p=0.5 each) and a random
      affine (scale 0.9-1.1, translation up to +/-5%, rotation +/-15 deg, p=0.5);
    * ``"strong"``: the standard ops plus elastic and grid distortion, random
      brightness/contrast, Gaussian noise and, for 3-channel input, CLAHE.

    Resampling ops use linear interpolation for the image and nearest for the
    mask, so integer class ids are never blended. The pipeline is called once
    with both ``image=`` and ``mask=``, so they share the same random spatial
    parameters.

    Args:
        img_size: side length of the (square) output.
        in_channels: number of image channels; forwarded to ``_normalize`` and
            used to decide whether CLAHE is added.
        train: when False, only resize + normalize are applied.
        aug: augmentation preset, one of ``"none"``, ``"standard"``,
            ``"strong"``.
        normalize: normalization mode forwarded to ``_normalize``.

    Returns:
        A callable ``(image, mask) -> (image_tensor, mask_tensor)``; the image
        tensor is cast to float and the mask tensor to long.

    Raises:
        ValueError: if ``aug`` is not a known preset (previously an unknown
            value silently behaved like ``"standard"``).
    """
    if aug not in ("none", "standard", "strong"):
        raise ValueError(
            f"unknown aug preset {aug!r}; expected 'none', 'standard' or 'strong'"
        )

    nearest = cv2.INTER_NEAREST  # masks: never blend integer class ids
    linear = cv2.INTER_LINEAR    # images

    # Resize is the first op of every pipeline (train and eval alike), so all
    # random ops below operate at the final resolution.
    ops = [A.Resize(img_size, img_size, interpolation=linear,
                    mask_interpolation=nearest)]

    if train and aug != "none":
        ops += [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Symmetric translation range: a (0.0, x) range only ever samples
            # non-negative offsets, shifting content in one direction and
            # biasing the training distribution.
            A.Affine(scale=(0.9, 1.1), translate_percent=(-0.05, 0.05),
                     rotate=(-15, 15), interpolation=linear,
                     mask_interpolation=nearest, p=0.5),
        ]
        if aug == "strong":
            ops += [
                A.ElasticTransform(alpha=30, sigma=6, interpolation=linear,
                                   mask_interpolation=nearest, p=0.3),
                A.GridDistortion(num_steps=5, distort_limit=0.2,
                                 interpolation=linear,
                                 mask_interpolation=nearest, p=0.3),
                A.RandomBrightnessContrast(p=0.5),
                A.GaussNoise(p=0.2),
            ]
            # NOTE(review): the original indentation was lost; CLAHE is assumed
            # to belong to the "strong" preset -- confirm. Only added for
            # 3-channel images.
            if in_channels == 3:
                ops.append(A.CLAHE(p=0.2))

    ops += [_normalize(in_channels, normalize), ToTensorV2()]
    compose = A.Compose(ops)

    def _apply(image: np.ndarray, mask: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the composed pipeline jointly on one image/mask pair."""
        out = compose(image=image, mask=mask)
        return out["image"].float(), out["mask"].long()

    return _apply
|
|