# spacenet/utils/transforms.py
# harshinde's picture
# Upload 6 files
# 9eda8e3 verified
from __future__ import annotations
import random
import cv2
import numpy as np
def random_color_distort(
    img: np.ndarray,
    brightness_delta: int = 32,
    contrast_low: float = 0.5,
    contrast_high: float = 1.5,
    saturation_low: float = 0.5,
    saturation_high: float = 1.5,
    hue_delta: int = 18,
) -> np.ndarray:
    """Apply SSD-style random photometric jitter to an image.

    Brightness, contrast, saturation and hue are each applied independently
    with probability 0.5. Brightness always comes first; contrast is then
    applied either before or after the saturation/hue pair (coin flip).

    Args:
        img: HWC **RGB uint8** image.
        brightness_delta: Brightness offset is drawn uniformly from
            ``[-brightness_delta, brightness_delta]``.
        contrast_low: Lower bound of the uniform contrast gain.
        contrast_high: Upper bound of the uniform contrast gain.
        saturation_low: Lower bound of the uniform saturation gain.
        saturation_high: Upper bound of the uniform saturation gain.
        hue_delta: Hue shift is an integer drawn from
            ``[-hue_delta, hue_delta]`` and wrapped modulo 180.

    Returns:
        The jittered image, in the same RGB uint8 layout as the input.
    """

    def _scale_shift(values, gain=1.0, offset=0.0):
        # Compute in float32 so uint8 arithmetic cannot wrap, then clip
        # back into the valid 8-bit range before casting.
        adjusted = values.astype(np.float32) * gain + offset
        return np.clip(adjusted, 0, 255).astype(np.uint8)

    def _jitter_contrast(bgr):
        gain = random.uniform(contrast_low, contrast_high)
        return _scale_shift(bgr, gain=gain)

    def _jitter_saturation(bgr):
        hsv_img = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        gain = random.uniform(saturation_low, saturation_high)
        hsv_img[:, :, 1] = _scale_shift(hsv_img[:, :, 1], gain=gain)
        return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)

    def _jitter_hue(bgr):
        hsv_img = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        shift = random.randint(-hue_delta, hue_delta)
        # Widen to int before adding so negative shifts do not underflow,
        # then wrap into the 0..179 hue range.
        wrapped = (hsv_img[:, :, 0].astype(int) + shift) % 180
        hsv_img[:, :, 0] = wrapped.astype(np.uint8)
        return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)

    working = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Brightness
    if random.random() < 0.5:
        offset = random.uniform(-brightness_delta, brightness_delta)
        working = _scale_shift(working, offset=offset)

    # Coin flip: contrast before or after the saturation/hue pair.
    if random.random() < 0.5:
        stages = (_jitter_contrast, _jitter_saturation, _jitter_hue)
    else:
        stages = (_jitter_saturation, _jitter_hue, _jitter_contrast)

    # Each stage is applied with its own independent 50% chance.
    for stage in stages:
        if random.random() < 0.5:
            working = stage(working)

    return cv2.cvtColor(working, cv2.COLOR_BGR2RGB)
def random_flip(image: np.ndarray, label: np.ndarray):
    """Randomly mirror image and label together.

    A horizontal flip (axis 1) and a vertical flip (axis 0) are each applied
    with independent probability 0.5, always to both arrays at once.

    Args:
        image: Array whose first two axes are height and width.
        label: Array flipped along the same axes as ``image``.

    Returns:
        ``(image, label)``; flipped arrays are fresh C-contiguous copies,
        while untouched inputs are returned as-is.
    """
    reverse = slice(None, None, -1)
    # Horizontal flip first (reverse axis 1), then vertical (reverse axis 0).
    flip_indices = ((slice(None), reverse), (reverse,))
    for index in flip_indices:
        if random.random() < 0.5:
            # Negative-stride views are materialised so downstream consumers
            # receive contiguous memory.
            image = np.ascontiguousarray(image[index])
            label = np.ascontiguousarray(label[index])
    return image, label
def random_rotate90(image: np.ndarray, label: np.ndarray):
    """Rotate image and label by a random multiple of 90 degrees.

    The number of quarter turns is drawn uniformly from {0, 1, 2, 3} and the
    same rotation is applied to both arrays in the plane of axes 0 and 1.

    Args:
        image: Array whose first two axes are the spatial axes.
        label: Array rotated identically to ``image``.

    Returns:
        ``(image, label)``; rotated arrays are copies, and the inputs are
        returned unchanged when zero turns are drawn.
    """
    quarter_turns = random.randint(0, 3)
    if quarter_turns == 0:
        return image, label

    # ``np.rot90`` yields a view; copy so the results own their memory.
    rotated_image = np.rot90(image, quarter_turns, axes=(0, 1)).copy()
    rotated_label = np.rot90(label, quarter_turns, axes=(0, 1)).copy()
    return rotated_image, rotated_label
def random_crop(image: np.ndarray, label: np.ndarray, crop_size: int):
    """Cut the same random ``crop_size × crop_size`` window from image/label.

    Args:
        image: Array whose first two axes are height and width.
        label: Array indexed with the same window as ``image``.
        crop_size: Side length of the square window.

    Returns:
        ``(image, label)`` sliced to the chosen window (NumPy views).

    Raises:
        ValueError: Propagated from ``random.randint`` when ``crop_size``
            is larger than the image height or width.
    """
    height, width = image.shape[:2]
    # Row offset is drawn before column offset; both bounds are inclusive.
    row_start = random.randint(0, height - crop_size)
    col_start = random.randint(0, width - crop_size)
    rows = slice(row_start, row_start + crop_size)
    cols = slice(col_start, col_start + crop_size)
    return image[rows, cols], label[rows, cols]
def center_crop(image: np.ndarray, label: np.ndarray, crop_size: int):
    """Extract the central ``crop_size × crop_size`` window (for validation).

    Args:
        image: Array whose first two axes are height and width.
        label: Array indexed with the same window as ``image``.
        crop_size: Side length of the square window.

    Returns:
        ``(image, label)`` sliced to the central window (NumPy views). When
        the margin is odd, the extra row/column is left on the bottom/right.

    Raises:
        ValueError: If ``crop_size`` exceeds the image height or width.
    """
    h, w = image.shape[:2]
    # Without this guard the start offsets become negative and NumPy slicing
    # silently wraps around, returning a wrong-sized crop.
    if crop_size > h or crop_size > w:
        raise ValueError(
            f"crop_size {crop_size} exceeds image size {h}x{w}; "
            "pad the image before cropping"
        )
    top = (h - crop_size) // 2
    left = (w - crop_size) // 2
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return image[rows, cols], label[rows, cols]
def pad_to_size(
    image: np.ndarray,
    label: np.ndarray,
    min_size: int,
    pad_label_value: int = 0,
) -> tuple[np.ndarray, np.ndarray]:
    """Pad image and label so both spatial sides are at least ``min_size``.

    The image is padded by symmetric reflection and the label with a
    constant. Padding is split evenly between the two sides of each axis,
    with any odd remainder going to the bottom/right. Only the first two
    axes are padded, so 2-D (grayscale) images and labels carrying trailing
    channel axes are handled as well.

    Args:
        image: Array whose first two axes are height and width.
        label: Array padded by the same amounts as ``image`` (computed from
            the image shape — assumes matching spatial size; TODO confirm
            against callers).
        min_size: Minimum height and width of the result.
        pad_label_value: Constant written into the padded label region.

    Returns:
        ``(image, label)``; the inputs themselves are returned when the
        image already satisfies ``min_size`` on both sides.
    """
    h, w = image.shape[:2]
    if h >= min_size and w >= min_size:
        return image, label

    extra_h = max(min_size - h, 0)
    extra_w = max(min_size - w, 0)
    top, left = extra_h // 2, extra_w // 2
    spatial_pad = ((top, extra_h - top), (left, extra_w - left))

    def _pad_width(arr):
        # Leave every axis beyond height/width untouched, whatever the rank.
        return spatial_pad + ((0, 0),) * (arr.ndim - 2)

    image = np.pad(image, _pad_width(image), mode="symmetric")
    label = np.pad(
        label,
        _pad_width(label),
        mode="constant",
        constant_values=pad_label_value,
    )
    return image, label
def get_training_augmentation(
    image: np.ndarray,
    label: np.ndarray,
    crop_size: int = 400,
    color_distort: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """Run the full training augmentation pipeline on one sample.

    Steps, in order:
        1. Colour distortion (image only, skipped if ``color_distort`` is
           false)
        2. Pad so both sides are at least ``crop_size``
        3. Random flip
        4. Random 90° rotation
        5. Random ``crop_size × crop_size`` crop

    Args:
        image: Input image passed to ``random_color_distort`` and the
            geometric transforms.
        label: Label array transformed with the same geometry as ``image``.
        crop_size: Side length of the final square crop.
        color_distort: Whether to apply the colour jitter step.

    Returns:
        The augmented ``(image, label)`` pair.
    """
    if color_distort:
        image = random_color_distort(image)

    # NOTE(review): padded label pixels are filled with 0 here, whereas
    # get_validation_transform fills with 255 — confirm this is intended.
    image, label = pad_to_size(image, label, crop_size, pad_label_value=0)

    # Geometric transforms that take and return an (image, label) pair.
    for geometric_step in (random_flip, random_rotate90):
        image, label = geometric_step(image, label)

    image, label = random_crop(image, label, crop_size)
    return image, label
def get_validation_transform(
    image: np.ndarray,
    label: np.ndarray,
    crop_size: int = 480,
) -> tuple[np.ndarray, np.ndarray]:
    """Deterministic validation transform: pad, then center crop.

    Args:
        image: Input image.
        label: Label array transformed alongside ``image``.
        crop_size: Side length of the square output; also the minimum size
            the inputs are padded to.

    Returns:
        The ``(image, label)`` pair after padding and center cropping.
    """
    # Padded label pixels are filled with 255.
    padded_image, padded_label = pad_to_size(
        image, label, crop_size, pad_label_value=255
    )
    cropped_image, cropped_label = center_crop(
        padded_image, padded_label, crop_size
    )
    return cropped_image, cropped_label