Spaces:
Running
Running
File size: 7,261 Bytes
46ecbf8 2905c51 46ecbf8 2905c51 433e26f 2905c51 46ecbf8 2905c51 46ecbf8 2905c51 46ecbf8 2905c51 46ecbf8 2905c51 46ecbf8 2905c51 46ecbf8 2905c51 46ecbf8 2905c51 46ecbf8 2905c51 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 | """Clinical degradation augmentation pipeline.
Degrades clean FFHQ/CelebA-HQ images to match real clinical photo distribution.
Applied from day 1 of training — domain gap prevention, not afterthought.
Each sample gets 3-5 random augmentations from the pool.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
import cv2
import numpy as np
@dataclass(frozen=True)
class AugmentationConfig:
    """Describe one entry of the clinical augmentation pool.

    Instances are frozen, so a pool entry cannot be mutated after creation.
    """

    # Human-readable identifier of the augmentation, e.g. "jpeg_compression".
    name: str
    # Transform invoked as ``fn(image, rng)``; returns the degraded image.
    fn: Callable[[np.ndarray, np.random.Generator], np.ndarray]
    # Chance of being picked in the independent per-entry selection pass
    # (compared against ``rng.random()``).
    probability: float
def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Simulate point-source clinical lighting from a random direction.

    A light position is drawn uniformly inside the frame. Brightness falls off
    linearly with distance from it: ``1 - (d / diagonal) * intensity`` with
    ``intensity`` in [0.3, 0.7), floored at 0.3.

    Args:
        image: Image array, either ``(H, W)`` or ``(H, W, C)`` for any number
            of channels C (typically uint8 BGR).
        rng: Random generator; consumed in the order light x, light y,
            intensity.

    Returns:
        A new uint8 image of the same shape, or ``image`` itself when either
        spatial side is shorter than 4 pixels.
    """
    h, w = image.shape[:2]
    if h < 4 or w < 4:
        return image
    # Random light source position
    lx = rng.uniform(0, w)
    ly = rng.uniform(0, h)
    intensity = rng.uniform(0.3, 0.7)
    # Distance-based falloff. Open grids broadcast to (h, w) without
    # materialising two full coordinate arrays.
    y_idx, x_idx = np.ogrid[0:h, 0:w]
    y_grid = y_idx.astype(np.float32)
    x_grid = x_idx.astype(np.float32)
    dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
    max_dist = np.sqrt(w ** 2 + h ** 2)
    light_map = 1.0 - (dist / max_dist) * intensity
    light_map = np.clip(light_map, 0.3, 1.0)
    if image.ndim == 3:
        # Apply the same gain to every channel (BGR, BGRA, single-channel...)
        # rather than hard-coding exactly three channels.
        light_map = light_map[:, :, np.newaxis]
    return np.clip(image.astype(np.float32) * light_map, 0, 255).astype(np.uint8)
def color_temperature_jitter(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Shift white balance warmer or cooler (+/- 2000K equivalent).

    A signed shift is drawn from [-0.15, 0.15). The channel the shift favours
    (red, index 2 in BGR, when positive; blue, index 0, otherwise) is scaled
    by ``1 + |shift|``. The opposite channel is scaled by
    ``1 - |shift| * 0.5``. The result is clipped to [0, 255] and returned as
    uint8.
    """
    shift = rng.uniform(-0.15, 0.15)
    magnitude = abs(shift)
    # Warm (boost red, damp blue) for a positive shift, cool otherwise.
    boosted, damped = (2, 0) if shift > 0 else (0, 2)
    out = image.astype(np.float32)
    out[:, :, boosted] *= 1 + magnitude
    out[:, :, damped] *= 1 - magnitude * 0.5
    return np.clip(out, 0, 255).astype(np.uint8)
def green_fluorescent_cast(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Tint the image green, as under fluorescent lighting in clinics.

    With ``k`` drawn from [0.05, 0.15), the green channel (index 1) is scaled
    by ``1 + k``. Blue (0) and red (2) are each scaled by ``1 - 0.3 * k``. The
    result is clipped to [0, 255] and returned as uint8.
    """
    k = rng.uniform(0.05, 0.15)
    damping = 1 - k * 0.3
    tinted = image.astype(np.float32)
    tinted[:, :, 1] *= 1 + k
    for channel in (0, 2):
        tinted[:, :, channel] *= damping
    return np.clip(tinted, 0, 255).astype(np.uint8)
def jpeg_compression(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Simulate JPEG compression artifacts.

    The image is round-tripped through an in-memory JPEG at quality
    ``int(rng.uniform(40, 85))``, i.e. an integer in 40..84.

    Args:
        image: 8-bit image (typically BGR; grayscale also works).
        rng: Random generator used to draw the quality.

    Returns:
        The decoded image with the same shape as ``image`` where the decoder
        preserves the channel count. Returns ``image`` itself if encoding or
        decoding fails.
    """
    quality = int(rng.uniform(40, 85))
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    ok, encoded = cv2.imencode(".jpg", image, encode_param)
    if not ok:
        return image
    # IMREAD_UNCHANGED keeps the channel count the JPEG was written with, so a
    # grayscale input stays grayscale. IMREAD_COLOR would silently expand it
    # to 3 channels and change the array shape.
    decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
    if decoded is None:
        return image
    if decoded.shape != image.shape and decoded.size == image.size:
        # e.g. an (H, W, 1) input decodes as (H, W); restore the caller's shape.
        decoded = decoded.reshape(image.shape)
    return decoded
def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Add zero-mean Gaussian sensor noise with sigma drawn from [5, 25).

    An independent noise sample is drawn for every array element. The noisy
    image is clipped to [0, 255] and cast back to uint8.
    """
    sigma = rng.uniform(5, 25)
    perturbation = rng.normal(0, sigma, image.shape).astype(np.float32)
    noisy = image.astype(np.float32)
    noisy += perturbation
    return np.clip(noisy, 0, 255).astype(np.uint8)
def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Warp the image radially, simulating a phone camera lens.

    A single radial coefficient ``k1`` is drawn from [-0.2, 0.2), so both
    barrel- and pincushion-like warps occur. The pinhole model uses a focal
    length of ``max(w, h)`` pixels and a principal point at the image centre.
    Samples falling outside the frame are filled by reflection
    (``BORDER_REFLECT_101``). Images with a side shorter than 4 pixels are
    returned unchanged.
    """
    height, width = image.shape[:2]
    if min(height, width) < 4:
        return image
    k1 = rng.uniform(-0.2, 0.2)
    focal = max(width, height)
    intrinsics = np.array(
        [
            [focal, 0, width / 2],
            [0, focal, height / 2],
            [0, 0, 1],
        ],
        dtype=np.float64,
    )
    # Only the first radial term is non-zero (OpenCV order: k1, k2, p1, p2, k3).
    distortion = np.array([k1, 0, 0, 0, 0], dtype=np.float64)
    map_x, map_y = cv2.initUndistortRectifyMap(
        intrinsics, distortion, None, intrinsics, (width, height), cv2.CV_32FC1
    )
    return cv2.remap(
        image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101
    )
def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Slight motion blur (common in handheld clinical photos).

    A ``size x size`` line kernel (``size`` in 3..6) is drawn through the
    kernel's anchor row. It is rotated by a random angle in [0, 180) degrees
    about the anchor, renormalised to sum to 1, and applied with
    ``cv2.filter2D``. Images with a side shorter than 4 pixels are returned
    unchanged.
    """
    h, w = image.shape[:2]
    if h < 4 or w < 4:
        return image
    size = int(rng.uniform(3, 7))
    angle = rng.uniform(0, 180)
    # filter2D's default anchor is the kernel centre (size // 2, size // 2).
    anchor = size // 2
    kernel = np.zeros((size, size))
    kernel[anchor, :] = 1.0 / size
    # Rotate about the anchor pixel itself. OpenCV puts pixel centres at
    # integer coordinates, so (size / 2, size / 2) is a pixel corner for odd
    # sizes. Rotating about that point moved the line off the anchor, which
    # translated the blurred image by up to a pixel depending on the angle.
    M = cv2.getRotationMatrix2D((float(anchor), float(anchor)), angle, 1)
    kernel = cv2.warpAffine(kernel, M, (size, size))
    ksum = kernel.sum()
    if ksum > 0:
        kernel = kernel / ksum
    else:
        # Defensive fallback: identity kernel (no blur).
        kernel = np.zeros_like(kernel)
        kernel[anchor, anchor] = 1.0
    return cv2.filter2D(image, -1, kernel)
def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Add lens vignetting (darkened corners).

    The gain is ``1 - strength * (d / d_max) ** 2``, floored at 0.3, where:

    - ``d`` is the distance from ``(w / 2, h / 2)``;
    - ``d_max`` is the distance from that point to the origin corner;
    - ``strength`` is drawn from [0.3, 0.7).

    Args:
        image: Image array, either ``(H, W)`` or ``(H, W, C)`` for any number
            of channels C (typically uint8 BGR).
        rng: Random generator used to draw the strength.

    Returns:
        A new uint8 image of the same shape, or ``image`` itself when either
        spatial side is shorter than 4 pixels.
    """
    h, w = image.shape[:2]
    if h < 4 or w < 4:
        return image
    strength = rng.uniform(0.3, 0.7)
    # Open grids broadcast to (h, w) without building full coordinate arrays.
    y_idx, x_idx = np.ogrid[0:h, 0:w]
    y = y_idx.astype(np.float32)
    x = x_idx.astype(np.float32)
    cx, cy = w / 2, h / 2
    dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
    max_dist = np.sqrt(cx ** 2 + cy ** 2)
    mask = 1 - strength * (dist / max_dist) ** 2
    mask = np.clip(mask, 0.3, 1.0)
    if image.ndim == 3:
        # Apply the same gain to every channel instead of hard-coding three.
        mask = mask[:, :, np.newaxis]
    return np.clip(image.astype(np.float32) * mask, 0, 255).astype(np.uint8)
# Augmentation pool with probabilities from the spec.
# Each probability is the chance that the entry is picked in the independent
# per-entry selection pass of apply_clinical_augmentation; list order is the
# order in which those draws happen.
AUGMENTATION_POOL: list[AugmentationConfig] = [
    AugmentationConfig(
        name="point_source_lighting", fn=point_source_lighting, probability=0.40
    ),
    AugmentationConfig(
        name="color_temperature", fn=color_temperature_jitter, probability=0.60
    ),
    AugmentationConfig(
        name="green_fluorescent", fn=green_fluorescent_cast, probability=0.25
    ),
    AugmentationConfig(
        name="jpeg_compression", fn=jpeg_compression, probability=0.30
    ),
    AugmentationConfig(
        name="sensor_noise", fn=gaussian_sensor_noise, probability=0.40
    ),
    AugmentationConfig(
        name="barrel_distortion", fn=barrel_distortion, probability=0.30
    ),
    AugmentationConfig(name="motion_blur", fn=motion_blur, probability=0.20),
    AugmentationConfig(name="vignette", fn=vignette, probability=0.25),
]
def _choose_augmentations(
    pool: list[AugmentationConfig],
    min_augmentations: int,
    max_augmentations: int,
    rng: np.random.Generator,
) -> list[AugmentationConfig]:
    """Pick which pool entries to apply, returned in the order they will run."""
    selected: list[AugmentationConfig] = []
    remaining: list[AugmentationConfig] = []
    # One independent draw per entry, in pool order.
    for aug in pool:
        if rng.random() < aug.probability:
            selected.append(aug)
        else:
            remaining.append(aug)
    # Too few: top up with randomly chosen unselected entries.
    if len(selected) < min_augmentations:
        rng.shuffle(remaining)
        selected.extend(remaining[: min_augmentations - len(selected)])
    # Too many: keep a random subset.
    if len(selected) > max_augmentations:
        rng.shuffle(selected)
        selected = selected[:max_augmentations]
    # Application order is random.
    rng.shuffle(selected)
    return selected


def apply_clinical_augmentation(
    image: np.ndarray,
    min_augmentations: int = 3,
    max_augmentations: int = 5,
    rng: np.random.Generator | None = None,
    *,
    pool: list[AugmentationConfig] | None = None,
) -> np.ndarray:
    """Apply random clinical degradation augmentations to an image.

    Selection proceeds in four steps:

    1. Each pool entry is first kept independently with its own probability.
    2. If fewer than ``min_augmentations`` were kept, randomly chosen unkept
       entries are added. If the pool is smaller than ``min_augmentations``,
       every entry is applied.
    3. If more than ``max_augmentations`` were kept, a random subset of that
       size is retained.
    4. The chosen augmentations run in random order, each receiving the
       previous one's output.

    Args:
        image: BGR input image (clean FFHQ/CelebA-HQ). Not modified in place.
        min_augmentations: Minimum number of augmentations to apply.
        max_augmentations: Maximum number of augmentations to apply.
        rng: Random number generator; a fresh unseeded one is used if None.
        pool: Augmentations to draw from; defaults to ``AUGMENTATION_POOL``.

    Returns:
        Degraded image matching clinical photo distribution.

    Raises:
        ValueError: If ``min_augmentations`` is negative or greater than
            ``max_augmentations``.
    """
    if min_augmentations < 0 or min_augmentations > max_augmentations:
        raise ValueError(
            "expected 0 <= min_augmentations <= max_augmentations, got "
            f"min_augmentations={min_augmentations}, "
            f"max_augmentations={max_augmentations}"
        )
    if rng is None:
        rng = np.random.default_rng()
    if pool is None:
        pool = AUGMENTATION_POOL
    chosen = _choose_augmentations(pool, min_augmentations, max_augmentations, rng)
    result = image.copy()
    for aug in chosen:
        result = aug.fn(result, rng)
    return result
|