File size: 4,956 Bytes
9eda8e3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | from __future__ import annotations
import random
import cv2
import numpy as np
def random_color_distort(
    img: np.ndarray,
    brightness_delta: int = 32,
    contrast_low: float = 0.5,
    contrast_high: float = 1.5,
    saturation_low: float = 0.5,
    saturation_high: float = 1.5,
    hue_delta: int = 18,
) -> np.ndarray:
    """Apply SSD-style random colour jitter to an image.

    Expects an HWC **RGB uint8** image and returns one in the same format.
    Brightness is jittered first; contrast is then applied either before or
    after the saturation/hue pair. Every individual distortion is applied
    with probability 0.5.

    Args:
        img: Input image (HWC, RGB, uint8).
        brightness_delta: Brightness offset is drawn uniformly from
            ``[-brightness_delta, brightness_delta]``.
        contrast_low: Lower bound of the random contrast gain.
        contrast_high: Upper bound of the random contrast gain.
        saturation_low: Lower bound of the random saturation gain.
        saturation_high: Upper bound of the random saturation gain.
        hue_delta: Hue shift is an integer drawn from
            ``[-hue_delta, hue_delta]`` (in OpenCV 8-bit hue units).

    Returns:
        The distorted image as RGB uint8.
    """

    def _affine(pixels, gain=1.0, bias=0.0):
        # Work in float32, then saturate back into the uint8 value range.
        adjusted = pixels.astype(np.float32) * gain + bias
        return np.clip(adjusted, 0, 255).astype(np.uint8)

    def _maybe_contrast(bgr):
        if random.random() >= 0.5:
            return bgr
        return _affine(bgr, gain=random.uniform(contrast_low, contrast_high))

    def _maybe_saturation(bgr):
        if random.random() >= 0.5:
            return bgr
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        gain = random.uniform(saturation_low, saturation_high)
        hsv[:, :, 1] = _affine(hsv[:, :, 1], gain=gain)
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    def _maybe_hue(bgr):
        if random.random() >= 0.5:
            return bgr
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        shift = random.randint(-hue_delta, hue_delta)
        # Hue wraps around modulo 180 (OpenCV stores 8-bit hue as 0..179).
        hsv[:, :, 0] = ((hsv[:, :, 0].astype(int) + shift) % 180).astype(np.uint8)
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    # All distortions are carried out on a BGR copy of the input.
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Brightness
    if random.random() < 0.5:
        bgr = _affine(bgr, bias=random.uniform(-brightness_delta, brightness_delta))

    # Contrast goes either before or after the saturation/hue pair.
    if random.random() < 0.5:
        steps = (_maybe_contrast, _maybe_saturation, _maybe_hue)
    else:
        steps = (_maybe_saturation, _maybe_hue, _maybe_contrast)

    for step in steps:
        bgr = step(bgr)

    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def random_flip(image: np.ndarray, label: np.ndarray):
    """Randomly mirror image and label horizontally and/or vertically.

    Two independent coin tosses decide, in this order, whether to flip along
    axis 1 (left-right) and along axis 0 (top-bottom). Image and label always
    receive the same flips. Flipped arrays are passed through
    ``np.ascontiguousarray``; when no flip fires the inputs are returned as-is.

    Returns:
        Tuple ``(image, label)``.
    """
    # Axis 1 first, then axis 0 -- each with its own 50 % chance.
    for axis in (1, 0):
        if random.random() < 0.5:
            image = np.ascontiguousarray(np.flip(image, axis=axis))
            label = np.ascontiguousarray(np.flip(label, axis=axis))
    return image, label
def random_rotate90(image: np.ndarray, label: np.ndarray):
    """Rotate image and label by a random multiple of 90° (0/90/180/270).

    The number of quarter turns is drawn uniformly from ``{0, 1, 2, 3}`` and
    applied to both arrays with ``np.rot90`` in the plane of the first two
    axes. With zero turns the inputs are returned untouched; otherwise the
    rotated arrays are returned as copies.

    Returns:
        Tuple ``(image, label)``.
    """
    quarter_turns = random.randint(0, 3)
    if quarter_turns == 0:
        return image, label

    rotated_image = np.rot90(image, k=quarter_turns, axes=(0, 1)).copy()
    rotated_label = np.rot90(label, k=quarter_turns, axes=(0, 1)).copy()
    return rotated_image, rotated_label
def random_crop(image: np.ndarray, label: np.ndarray, crop_size: int):
    """Extract a random crop of ``crop_size × crop_size`` from image/label.

    The top-left corner is drawn uniformly from all positions at which the
    window fits inside ``image``; the identical window is then cut from
    ``label``.

    Args:
        image: Array whose first two axes are height and width.
        label: Array indexed by the same first two axes as ``image``.
        crop_size: Side length of the square crop.

    Returns:
        Tuple ``(image_crop, label_crop)`` obtained by slicing the inputs.

    Raises:
        ValueError: If ``crop_size`` is negative or larger than the height or
            width of ``image``.
    """
    h, w = image.shape[:2]
    # Fail with a clear message instead of an opaque "empty range" error from
    # random.randint (too large) or a silently wrong slice (negative).
    if crop_size < 0 or crop_size > h or crop_size > w:
        raise ValueError(
            f"crop_size={crop_size} does not fit inside an image of size {h}x{w}"
        )

    top = random.randint(0, h - crop_size)
    left = random.randint(0, w - crop_size)
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[rows, cols], label[rows, cols]
def center_crop(image: np.ndarray, label: np.ndarray, crop_size: int):
    """Center crop for validation.

    Cuts the ``crop_size × crop_size`` window centred in ``image`` (offsets
    rounded down) and the identical window from ``label``.

    Args:
        image: Array whose first two axes are height and width.
        label: Array indexed by the same first two axes as ``image``.
        crop_size: Side length of the square crop.

    Returns:
        Tuple ``(image_crop, label_crop)`` obtained by slicing the inputs.

    Raises:
        ValueError: If ``crop_size`` is negative or larger than the height or
            width of ``image``.
    """
    h, w = image.shape[:2]
    # Without this check an oversized crop yields negative offsets, which
    # NumPy interprets as "from the end" and silently returns a wrong-sized
    # window.
    if crop_size < 0 or crop_size > h or crop_size > w:
        raise ValueError(
            f"crop_size={crop_size} does not fit inside an image of size {h}x{w}"
        )

    top = (h - crop_size) // 2
    left = (w - crop_size) // 2
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[rows, cols], label[rows, cols]
def pad_to_size(
    image: np.ndarray,
    label: np.ndarray,
    min_size: int,
    pad_label_value: int = 0,
) -> tuple[np.ndarray, np.ndarray]:
    """Pad image and label so that height and width are both ≥ ``min_size``.

    The image is extended by mirroring its content (``np.pad`` with
    ``mode="symmetric"``); the label is extended with the constant
    ``pad_label_value``. The padding of each axis is split evenly between its
    two sides, with the odd extra pixel going to the bottom/right. Only the
    first two axes are padded, so arrays with or without trailing axes
    (e.g. HW as well as HWC) are accepted.

    Args:
        image: Array whose first two axes are height and width.
        label: Array whose first two axes match those of ``image``.
        min_size: Minimum height and width required after padding.
        pad_label_value: Fill value for the padded label region.

    Returns:
        Tuple ``(image, label)``. If the image is already large enough the
        inputs are returned unchanged.
    """
    h, w = image.shape[:2]
    if h >= min_size and w >= min_size:
        return image, label

    pad_h = max(min_size - h, 0)
    pad_w = max(min_size - w, 0)
    top, left = pad_h // 2, pad_w // 2
    spatial = ((top, pad_h - top), (left, pad_w - left))

    def _pad_widths(arr: np.ndarray):
        # Leave any axes beyond height/width (e.g. channels) unpadded.
        return spatial + ((0, 0),) * (arr.ndim - 2)

    image = np.pad(image, _pad_widths(image), mode="symmetric")
    label = np.pad(
        label,
        _pad_widths(label),
        mode="constant",
        constant_values=pad_label_value,
    )
    return image, label
def get_training_augmentation(
    image: np.ndarray,
    label: np.ndarray,
    crop_size: int = 400,
    color_distort: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """Run the full training augmentation pipeline on an image/label pair.

    Steps, in order:
        1. Colour distortion of the image (skipped if ``color_distort`` is false)
        2. Padding up to ``crop_size`` (label padding is filled with 0)
        3. Random flip
        4. Random 90° rotation
        5. Random ``crop_size × crop_size`` crop

    Returns:
        Tuple ``(image, label)`` after augmentation.
    """
    # Photometric jitter touches the image only; the label is left alone.
    if color_distort:
        image = random_color_distort(image)

    # Guarantee both sides are large enough for the final crop.
    image, label = pad_to_size(image, label, crop_size, pad_label_value=0)

    # Geometric transforms are applied identically to image and label.
    for geometric in (random_flip, random_rotate90):
        image, label = geometric(image, label)

    return random_crop(image, label, crop_size)
def get_validation_transform(
    image: np.ndarray,
    label: np.ndarray,
    crop_size: int = 480,
) -> tuple[np.ndarray, np.ndarray]:
    """Deterministic validation transform: pad, then centre crop.

    The pair is first padded so both sides reach ``crop_size`` (padded label
    pixels are set to 255) and then centre-cropped to
    ``crop_size × crop_size``.

    Returns:
        Tuple ``(image, label)`` after the transform.
    """
    padded_image, padded_label = pad_to_size(
        image, label, crop_size, pad_label_value=255
    )
    return center_crop(padded_image, padded_label, crop_size)
|