File size: 4,956 Bytes
9eda8e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from __future__ import annotations

import random
import cv2
import numpy as np

def random_color_distort(
    img: np.ndarray,
    brightness_delta: int = 32,
    contrast_low: float = 0.5,
    contrast_high: float = 1.5,
    saturation_low: float = 0.5,
    saturation_high: float = 1.5,
    hue_delta: int = 18,
) -> np.ndarray:
    """Apply SSD-style random photometric jitter to an image.

    The input is expected to be an HWC **RGB uint8** image; the result is
    returned in the same layout. Brightness, contrast, saturation and hue
    are each perturbed independently with probability 0.5. Contrast is
    applied either before or after the saturation/hue pair, chosen at
    random.

    Args:
        img: Source image (HWC, RGB, uint8).
        brightness_delta: Brightness offset is drawn uniformly from
            ``[-brightness_delta, brightness_delta]``.
        contrast_low: Lower bound of the uniform contrast gain.
        contrast_high: Upper bound of the uniform contrast gain.
        saturation_low: Lower bound of the uniform saturation gain.
        saturation_high: Upper bound of the uniform saturation gain.
        hue_delta: Hue shift is an integer drawn from
            ``[-hue_delta, hue_delta]`` and applied modulo 180.

    Returns:
        The jittered image as an RGB uint8 array.
    """

    def _rescale(values, gain=1.0, offset=0.0):
        # Do the arithmetic in float32, then saturate back into uint8 range.
        scaled = values.astype(np.float32) * gain + offset
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def _jitter_contrast(bgr):
        return _rescale(bgr, gain=random.uniform(contrast_low, contrast_high))

    def _jitter_saturation(bgr):
        hsv_pixels = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        gain = random.uniform(saturation_low, saturation_high)
        hsv_pixels[:, :, 1] = _rescale(hsv_pixels[:, :, 1], gain=gain)
        return cv2.cvtColor(hsv_pixels, cv2.COLOR_HSV2BGR)

    def _jitter_hue(bgr):
        hsv_pixels = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        shift = random.randint(-hue_delta, hue_delta)
        # Widen to int before adding so a negative shift cannot underflow
        # uint8, then wrap the hue channel modulo 180.
        wrapped = (hsv_pixels[:, :, 0].astype(int) + shift) % 180
        hsv_pixels[:, :, 0] = wrapped.astype(np.uint8)
        return cv2.cvtColor(hsv_pixels, cv2.COLOR_HSV2BGR)

    working = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Brightness: additive offset on all channels.
    if random.random() < 0.5:
        offset = random.uniform(-brightness_delta, brightness_delta)
        working = _rescale(working, offset=offset)

    # Contrast goes either first or last relative to saturation + hue.
    contrast_first = random.random() < 0.5
    if contrast_first:
        steps = (_jitter_contrast, _jitter_saturation, _jitter_hue)
    else:
        steps = (_jitter_saturation, _jitter_hue, _jitter_contrast)

    for step in steps:
        # Each step fires independently with probability 0.5.
        if random.random() < 0.5:
            working = step(working)

    return cv2.cvtColor(working, cv2.COLOR_BGR2RGB)

def random_flip(image: np.ndarray, label: np.ndarray):
    """Randomly mirror an image/label pair.

    A horizontal flip (axis 1) and a vertical flip (axis 0) are each applied
    with probability 0.5, independently of one another. The same flips are
    applied to both arrays. Flipped arrays are returned as contiguous
    copies; when no flip fires, the inputs are returned unchanged.

    Returns:
        A ``(image, label)`` tuple.
    """
    # Horizontal mirror (left <-> right).
    if random.random() < 0.5:
        image, label = (
            np.ascontiguousarray(np.flip(arr, axis=1)) for arr in (image, label)
        )

    # Vertical mirror (top <-> bottom).
    if random.random() < 0.5:
        image, label = (
            np.ascontiguousarray(np.flip(arr, axis=0)) for arr in (image, label)
        )

    return image, label

def random_rotate90(image: np.ndarray, label: np.ndarray):
    """Rotate an image/label pair by a random multiple of 90 degrees.

    The number of quarter turns is drawn uniformly from ``{0, 1, 2, 3}`` and
    applied in the plane of the first two axes. With zero turns the inputs
    are returned untouched; otherwise fresh copies of the rotated arrays are
    returned.

    Returns:
        A ``(image, label)`` tuple.
    """
    quarter_turns = random.randint(0, 3)
    if quarter_turns == 0:
        return image, label

    rotated_image = np.rot90(image, quarter_turns, axes=(0, 1)).copy()
    rotated_label = np.rot90(label, quarter_turns, axes=(0, 1)).copy()
    return rotated_image, rotated_label

def random_crop(image: np.ndarray, label: np.ndarray, crop_size: int):
    """Extract a random ``crop_size × crop_size`` window from image/label.

    The top-left corner is drawn uniformly over all positions where the
    window fits inside the first two axes of ``image``. The same window is
    then sliced out of ``label``. The results are views, not copies.

    Args:
        image: Array whose first two axes are height and width.
        label: Array sliced with the same window as ``image``.
        crop_size: Side length of the square crop, in pixels.

    Returns:
        A ``(image, label)`` tuple of cropped views.

    Raises:
        ValueError: If ``crop_size`` is not positive or exceeds the image
            height or width (pad the image first in that case).
    """
    h, w = image.shape[:2]
    # Fail with an explicit message instead of the cryptic "empty range"
    # error from random.randint (oversized crop) or a silently empty /
    # misplaced slice (non-positive crop).
    if crop_size < 1 or crop_size > h or crop_size > w:
        raise ValueError(
            f"crop_size must be in [1, min(height, width)]; "
            f"got crop_size={crop_size} for image of size {h}x{w}"
        )

    top = random.randint(0, h - crop_size)
    left = random.randint(0, w - crop_size)
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[rows, cols], label[rows, cols]

def center_crop(image: np.ndarray, label: np.ndarray, crop_size: int):
    """Extract the central ``crop_size × crop_size`` window (deterministic).

    The window is centred on the first two axes of ``image``; when the
    leftover margin is odd, the extra pixel goes to the bottom/right. The
    same window is sliced out of ``label``. The results are views, not
    copies.

    Args:
        image: Array whose first two axes are height and width.
        label: Array sliced with the same window as ``image``.
        crop_size: Side length of the square crop, in pixels.

    Returns:
        A ``(image, label)`` tuple of cropped views.

    Raises:
        ValueError: If ``crop_size`` is not positive or exceeds the image
            height or width (pad the image first in that case).
    """
    h, w = image.shape[:2]
    # An oversized crop would produce negative offsets, which NumPy slicing
    # interprets as "from the end" and silently returns a wrong-shaped
    # result, so reject it up front.
    if crop_size < 1 or crop_size > h or crop_size > w:
        raise ValueError(
            f"crop_size must be in [1, min(height, width)]; "
            f"got crop_size={crop_size} for image of size {h}x{w}"
        )

    top = (h - crop_size) // 2
    left = (w - crop_size) // 2
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[rows, cols], label[rows, cols]

def pad_to_size(
    image: np.ndarray,
    label: np.ndarray,
    min_size: int,
    pad_label_value: int = 0,
) -> tuple[np.ndarray, np.ndarray]:
    """Pad image and label so height and width are both ≥ ``min_size``.

    Padding is split as evenly as possible between the two sides of each
    spatial axis (the extra pixel, if any, goes to the bottom/right). The
    image is filled by mirroring its content (``mode="symmetric"``); the
    label is filled with the constant ``pad_label_value``. Only the first
    two axes are padded, so 2-D (grayscale) images and labels with trailing
    axes are supported as well as HWC images with HW labels.

    Args:
        image: Array whose first two axes are height and width.
        label: Array padded by the same amounts as ``image`` on its first
            two axes.
        min_size: Minimum required height and width after padding.
        pad_label_value: Fill value for the padded label region.

    Returns:
        A ``(image, label)`` tuple. If the image is already large enough,
        the inputs are returned unchanged (no copy).
    """
    h, w = image.shape[:2]
    if h >= min_size and w >= min_size:
        return image, label

    extra_h = max(min_size - h, 0)
    extra_w = max(min_size - w, 0)
    top, left = extra_h // 2, extra_w // 2
    spatial_pad = ((top, extra_h - top), (left, extra_w - left))

    def _pad_spec(arr: np.ndarray):
        # Leave any axes beyond (H, W) -- e.g. channels -- unpadded.
        return spatial_pad + ((0, 0),) * (arr.ndim - 2)

    image = np.pad(image, _pad_spec(image), mode="symmetric")
    label = np.pad(
        label,
        _pad_spec(label),
        mode="constant",
        constant_values=pad_label_value,
    )
    return image, label

def get_training_augmentation(
    image: np.ndarray,
    label: np.ndarray,
    crop_size: int = 400,
    color_distort: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """Run the complete training-time augmentation chain on one sample.

    The stages run in this order:
      1. Colour distortion of the image (skipped when ``color_distort`` is
         false)
      2. Padding up to ``crop_size`` (label padded with 0)
      3. Random horizontal/vertical flip
      4. Random rotation by a multiple of 90°
      5. Random ``crop_size × crop_size`` crop

    Returns:
        The augmented ``(image, label)`` pair.
    """
    # Photometric jitter only touches the image, never the label.
    source_image = random_color_distort(image) if color_distort else image

    # Padding first guarantees both sides are large enough for the final
    # crop, whatever flip or rotation happens in between.
    sample = pad_to_size(source_image, label, crop_size, pad_label_value=0)
    for geometric_op in (random_flip, random_rotate90):
        sample = geometric_op(*sample)

    cropped_image, cropped_label = random_crop(*sample, crop_size)
    return cropped_image, cropped_label

def get_validation_transform(
    image: np.ndarray,
    label: np.ndarray,
    crop_size: int = 480,
) -> tuple[np.ndarray, np.ndarray]:
    """Deterministic validation preprocessing: pad, then centre-crop.

    The sample is first padded so both sides reach ``crop_size`` (the label
    is filled with 255 in the padded region), after which the central
    ``crop_size × crop_size`` window is returned.

    Returns:
        The transformed ``(image, label)`` pair.
    """
    padded_image, padded_label = pad_to_size(
        image, label, crop_size, pad_label_value=255
    )
    cropped_image, cropped_label = center_crop(
        padded_image, padded_label, crop_size
    )
    return cropped_image, cropped_label