GenSeg-Baselines / code /framework /data /transforms.py
MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
2.92 kB
"""Albumentations-based transform builder (the conventional-augmentation tier).
Key correctness guarantees for segmentation:

* Masks always use NEAREST interpolation (integer class ids are never
  blended): every interpolating geometric transform (resize, affine, elastic,
  grid distortion) sets mask_interpolation=cv2.INTER_NEAREST.
* Image and mask receive the SAME random spatial parameters (Albumentations
  applies one transform jointly to image= and mask=).

`aug` presets: none (resize+normalize only) | standard | strong.

`build_transform` returns a callable:
(image HWC uint8, mask HW int) -> (FloatTensor[C,H,W], LongTensor[H,W]).
"""
from __future__ import annotations
from typing import Tuple
import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
def _normalize(in_channels: int, normalize: str) -> A.Normalize:
    """Build the ``A.Normalize`` step for images with ``in_channels`` channels.

    ``normalize`` selects the per-channel statistics:

    * ``"none"``: zero mean and unit std for every channel.
    * any other value (``"imagenet"``, ``"auto"``, ...): the ImageNet mean/std
      when ``in_channels == 3``, otherwise 0.5 mean and 0.5 std per channel.

    ``max_pixel_value`` is always 255.0.
    """
    if normalize == "none":
        stats = ((0.0,) * in_channels, (1.0,) * in_channels)
    elif in_channels == 3:
        # Reached both for an explicit "imagenet" request and for "auto".
        stats = (_IMAGENET_MEAN, _IMAGENET_STD)
    else:
        # Non-3-channel input; "imagenet" also lands here because the
        # ImageNet statistics are only defined for three channels.
        half = (0.5,) * in_channels
        stats = (half, half)

    mean, std = stats
    return A.Normalize(mean=mean, std=std, max_pixel_value=255.0)
def build_transform(img_size: int, in_channels: int, train: bool,
                    aug: str = "standard", normalize: str = "auto"):
    """Assemble the joint image/mask pipeline and return it as a callable.

    Args:
        img_size: Side length passed to ``A.Resize`` for both height and width.
        in_channels: Number of image channels; forwarded to ``_normalize`` and
            used to decide whether CLAHE is added under the ``"strong"`` preset.
        train: When false, only resize + normalize + tensor conversion run.
        aug: ``"none"`` disables augmentation even when ``train`` is true;
            ``"strong"`` adds the elastic/grid/photometric ops on top of the
            flips and affine; any other value gets the flips and affine only.
        normalize: Normalization mode, forwarded unchanged to ``_normalize``.

    Returns:
        A function ``(image, mask) -> (image_tensor, mask_tensor)`` where the
        image tensor has been cast with ``.float()`` and the mask tensor with
        ``.long()``.
    """
    # Images are interpolated linearly; masks always use nearest-neighbour.
    interp = {
        "interpolation": cv2.INTER_LINEAR,
        "mask_interpolation": cv2.INTER_NEAREST,
    }

    # The resize is the first step of every configuration.
    pipeline = [A.Resize(img_size, img_size, **interp)]

    augmenting = train and aug != "none"
    if augmenting:
        # NOTE(review): the translate range (0.0, 0.05) has no negative side;
        # if shifts are sampled directly from it, translation only goes one
        # way per axis -- confirm this is intended.
        pipeline.extend([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Affine(scale=(0.9, 1.1), translate_percent=(0.0, 0.05),
                     rotate=(-15, 15), p=0.5, **interp),
        ])

    if augmenting and aug == "strong":
        pipeline.extend([
            A.ElasticTransform(alpha=30, sigma=6, p=0.3, **interp),
            A.GridDistortion(num_steps=5, distort_limit=0.2, p=0.3, **interp),
            A.RandomBrightnessContrast(p=0.5),
            A.GaussNoise(p=0.2),
        ])
        if in_channels == 3:
            pipeline.append(A.CLAHE(p=0.2))

    pipeline.append(_normalize(in_channels, normalize))
    pipeline.append(ToTensorV2())
    augment = A.Compose(pipeline)

    def _apply(image: np.ndarray, mask: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the composed pipeline on one image/mask pair."""
        result = augment(image=image, mask=mask)
        image_tensor = result["image"].float()  # C,H,W
        mask_tensor = result["mask"].long()     # H,W
        return image_tensor, mask_tensor

    return _apply