from __future__ import annotations import random import cv2 import numpy as np def random_color_distort( img: np.ndarray, brightness_delta: int = 32, contrast_low: float = 0.5, contrast_high: float = 1.5, saturation_low: float = 0.5, saturation_high: float = 1.5, hue_delta: int = 18, ) -> np.ndarray: """SSD-style random colour jittering. Operates on an HWC **RGB uint8** image and returns the same format. """ cv_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) def _convert(arr, alpha=1.0, beta=0.0): arr = arr.astype(np.float32) * alpha + beta return np.clip(arr, 0, 255).astype(np.uint8) # Brightness if random.random() < 0.5: cv_img = _convert(cv_img, beta=random.uniform(-brightness_delta, brightness_delta)) # Decide order: contrast first or saturation/hue first if random.random() < 0.5: order = ["contrast", "saturation", "hue"] else: order = ["saturation", "hue", "contrast"] for aug in order: if aug == "contrast" and random.random() < 0.5: cv_img = _convert(cv_img, alpha=random.uniform(contrast_low, contrast_high)) elif aug == "saturation" and random.random() < 0.5: hsv = cv2.cvtColor(cv_img, cv2.COLOR_BGR2HSV) hsv[:, :, 1] = _convert(hsv[:, :, 1], alpha=random.uniform(saturation_low, saturation_high)) cv_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) elif aug == "hue" and random.random() < 0.5: hsv = cv2.cvtColor(cv_img, cv2.COLOR_BGR2HSV) hsv[:, :, 0] = ((hsv[:, :, 0].astype(int) + random.randint(-hue_delta, hue_delta)) % 180).astype(np.uint8) cv_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) return cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB) def random_flip(image: np.ndarray, label: np.ndarray): """Random horizontal and/or vertical flip.""" if random.random() < 0.5: image = np.ascontiguousarray(image[:, ::-1]) label = np.ascontiguousarray(label[:, ::-1]) if random.random() < 0.5: image = np.ascontiguousarray(image[::-1]) label = np.ascontiguousarray(label[::-1]) return image, label def random_rotate90(image: np.ndarray, label: np.ndarray): """Random 0/90/180/270° rotation.""" k = random.randint(0, 3) if k > 0: image = np.rot90(image, k, axes=(0, 1)).copy() label = np.rot90(label, k, axes=(0, 1)).copy() return image, label def random_crop(image: np.ndarray, label: np.ndarray, crop_size: int): """Extract a random crop of ``crop_size × crop_size`` from image/label.""" h, w = image.shape[:2] top = random.randint(0, h - crop_size) left = random.randint(0, w - crop_size) image = image[top : top + crop_size, left : left + crop_size] label = label[top : top + crop_size, left : left + crop_size] return image, label def center_crop(image: np.ndarray, label: np.ndarray, crop_size: int): """Center crop for validation.""" h, w = image.shape[:2] top = (h - crop_size) // 2 left = (w - crop_size) // 2 image = image[top : top + crop_size, left : left + crop_size] label = label[top : top + crop_size, left : left + crop_size] return image, label def pad_to_size( image: np.ndarray, label: np.ndarray, min_size: int, pad_label_value: int = 0, ) -> tuple[np.ndarray, np.ndarray]: """Symmetric-pad image and label so both sides are ≥ min_size.""" h, w = image.shape[:2] if h >= min_size and w >= min_size: return image, label H = max(h, min_size) W = max(w, min_size) py1, px1 = (H - h) // 2, (W - w) // 2 py2, px2 = H - h - py1, W - w - px1 image = np.pad(image, ((py1, py2), (px1, px2), (0, 0)), mode="symmetric") label = np.pad(label, ((py1, py2), (px1, px2)), mode="constant", constant_values=pad_label_value) return image, label def get_training_augmentation( image: np.ndarray, label: np.ndarray, crop_size: int = 400, color_distort: bool = True, ) -> tuple[np.ndarray, np.ndarray]: """Full training augmentation pipeline. Steps: 1. Optional colour distortion 2. Pad if smaller than crop_size 3. Random flip 4. Random 90° rotation 5. Random crop """ if color_distort: image = random_color_distort(image) image, label = pad_to_size(image, label, crop_size, pad_label_value=0) image, label = random_flip(image, label) image, label = random_rotate90(image, label) image, label = random_crop(image, label, crop_size) return image, label def get_validation_transform( image: np.ndarray, label: np.ndarray, crop_size: int = 480, ) -> tuple[np.ndarray, np.ndarray]: """Validation transform: pad → center crop (deterministic).""" image, label = pad_to_size(image, label, crop_size, pad_label_value=255) image, label = center_crop(image, label, crop_size) return image, label