dreamlessx commited on
Commit
2905c51
·
verified ·
1 Parent(s): 6ba7e19

Upload landmarkdiff/synthetic/augmentation.py with huggingface_hub

Browse files
landmarkdiff/synthetic/augmentation.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Clinical degradation augmentations.
2
+
3
+ Degrades clean FFHQ/CelebA-HQ to match real clinical photo distribution.
4
+ Applied from day 1 - domain gap prevention, not afterthought.
5
+ 3-5 random augmentations per sample.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass
11
+ from typing import Callable
12
+
13
+ import cv2
14
+ import numpy as np
15
+
16
+
17
+ @dataclass(frozen=True)
18
+ class AugmentationConfig:
19
+ """Configuration for a single augmentation."""
20
+
21
+ name: str
22
+ fn: Callable[[np.ndarray, np.random.Generator], np.ndarray]
23
+ probability: float
24
+
25
+
26
+ def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
27
+ """Simulate point-source clinical lighting from a random direction."""
28
+ h, w = image.shape[:2]
29
+
30
+ # Random light source position
31
+ lx = rng.uniform(0, w)
32
+ ly = rng.uniform(0, h)
33
+ intensity = rng.uniform(0.3, 0.7)
34
+
35
+ # Distance-based falloff
36
+ y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
37
+ dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
38
+ max_dist = np.sqrt(w ** 2 + h ** 2)
39
+ light_map = 1.0 - (dist / max_dist) * intensity
40
+
41
+ light_map = np.clip(light_map, 0.3, 1.0)
42
+ light_3ch = np.stack([light_map] * 3, axis=-1)
43
+
44
+ return np.clip(image.astype(np.float32) * light_3ch, 0, 255).astype(np.uint8)
45
+
46
+
47
+ def color_temperature_jitter(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
48
+ """Jitter color temperature +/- 2000K equivalent."""
49
+ shift = rng.uniform(-0.15, 0.15)
50
+
51
+ result = image.astype(np.float32)
52
+ if shift > 0:
53
+ # Warmer: boost red, reduce blue
54
+ result[:, :, 2] *= 1 + shift # red (BGR)
55
+ result[:, :, 0] *= 1 - shift * 0.5 # blue
56
+ else:
57
+ # Cooler: boost blue, reduce red
58
+ result[:, :, 0] *= 1 + abs(shift)
59
+ result[:, :, 2] *= 1 - abs(shift) * 0.5
60
+
61
+ return np.clip(result, 0, 255).astype(np.uint8)
62
+
63
+
64
+ def green_fluorescent_cast(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
65
+ """Add green fluorescent lighting cast (common in clinical settings)."""
66
+ intensity = rng.uniform(0.05, 0.15)
67
+ result = image.astype(np.float32)
68
+ result[:, :, 1] *= 1 + intensity # green channel boost
69
+ result[:, :, 0] *= 1 - intensity * 0.3 # slight blue reduction
70
+ result[:, :, 2] *= 1 - intensity * 0.3 # slight red reduction
71
+ return np.clip(result, 0, 255).astype(np.uint8)
72
+
73
+
74
+ def jpeg_compression(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
75
+ """Simulate JPEG compression artifacts (quality 40-85)."""
76
+ quality = int(rng.uniform(40, 85))
77
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
78
+ _, encoded = cv2.imencode(".jpg", image, encode_param)
79
+ return cv2.imdecode(encoded, cv2.IMREAD_COLOR)
80
+
81
+
82
+ def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
83
+ """Add Gaussian sensor noise (sigma 5-25)."""
84
+ sigma = rng.uniform(5, 25)
85
+ noise = rng.normal(0, sigma, image.shape).astype(np.float32)
86
+ return np.clip(image.astype(np.float32) + noise, 0, 255).astype(np.uint8)
87
+
88
+
89
+ def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
90
+ """Apply barrel/pincushion distortion simulating phone camera lens."""
91
+ h, w = image.shape[:2]
92
+ k1 = rng.uniform(-0.2, 0.2)
93
+
94
+ fx = fy = max(w, h)
95
+ cx, cy = w / 2, h / 2
96
+
97
+ camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float64)
98
+ dist_coeffs = np.array([k1, 0, 0, 0, 0], dtype=np.float64)
99
+
100
+ map1, map2 = cv2.initUndistortRectifyMap(
101
+ camera_matrix, dist_coeffs, None, camera_matrix, (w, h), cv2.CV_32FC1
102
+ )
103
+ return cv2.remap(image, map1, map2, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
104
+
105
+
106
+ def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
107
+ """Slight motion blur (common in handheld clinical photos)."""
108
+ size = int(rng.uniform(3, 7))
109
+ angle = rng.uniform(0, 180)
110
+
111
+ kernel = np.zeros((size, size))
112
+ kernel[size // 2, :] = 1.0 / size
113
+
114
+ M = cv2.getRotationMatrix2D((size / 2, size / 2), angle, 1)
115
+ kernel = cv2.warpAffine(kernel, M, (size, size))
116
+ ksum = kernel.sum()
117
+ if ksum > 0:
118
+ kernel = kernel / ksum
119
+ else:
120
+ # rotation can zero out the kernel - fall back to identity
121
+ kernel = np.zeros_like(kernel)
122
+ kernel[size // 2, size // 2] = 1.0
123
+
124
+ return cv2.filter2D(image, -1, kernel)
125
+
126
+
127
+ def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
128
+ """Add lens vignetting (darkened corners)."""
129
+ h, w = image.shape[:2]
130
+ strength = rng.uniform(0.3, 0.7)
131
+
132
+ y, x = np.mgrid[0:h, 0:w].astype(np.float32)
133
+ cx, cy = w / 2, h / 2
134
+ dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
135
+ max_dist = np.sqrt(cx ** 2 + cy ** 2)
136
+
137
+ mask = 1 - strength * (dist / max_dist) ** 2
138
+ mask = np.clip(mask, 0.3, 1.0)
139
+ mask_3ch = np.stack([mask] * 3, axis=-1)
140
+
141
+ return np.clip(image.astype(np.float32) * mask_3ch, 0, 255).astype(np.uint8)
142
+
143
+
144
+ # Augmentation pool with probabilities from the spec
145
+ AUGMENTATION_POOL: list[AugmentationConfig] = [
146
+ AugmentationConfig("point_source_lighting", point_source_lighting, 0.40),
147
+ AugmentationConfig("color_temperature", color_temperature_jitter, 0.60),
148
+ AugmentationConfig("green_fluorescent", green_fluorescent_cast, 0.25),
149
+ AugmentationConfig("jpeg_compression", jpeg_compression, 0.30),
150
+ AugmentationConfig("sensor_noise", gaussian_sensor_noise, 0.40),
151
+ AugmentationConfig("barrel_distortion", barrel_distortion, 0.30),
152
+ AugmentationConfig("motion_blur", motion_blur, 0.20),
153
+ AugmentationConfig("vignette", vignette, 0.25),
154
+ ]
155
+
156
+
157
+ def apply_clinical_augmentation(
158
+ image: np.ndarray,
159
+ min_augmentations: int = 3,
160
+ max_augmentations: int = 5,
161
+ rng: np.random.Generator | None = None,
162
+ ) -> np.ndarray:
163
+ """Apply random clinical degradation augmentations to an image."""
164
+ rng = rng or np.random.default_rng()
165
+
166
+ # Select augmentations by probability
167
+ selected = []
168
+ for aug in AUGMENTATION_POOL:
169
+ if rng.random() < aug.probability:
170
+ selected.append(aug)
171
+
172
+ # Ensure min/max bounds
173
+ if len(selected) < min_augmentations:
174
+ remaining = [a for a in AUGMENTATION_POOL if a not in selected]
175
+ rng.shuffle(remaining)
176
+ selected.extend(remaining[: min_augmentations - len(selected)])
177
+
178
+ if len(selected) > max_augmentations:
179
+ rng.shuffle(selected)
180
+ selected = selected[:max_augmentations]
181
+
182
+ # Apply in random order
183
+ rng.shuffle(selected)
184
+ result = image.copy()
185
+ for aug in selected:
186
+ result = aug.fn(result, rng)
187
+
188
+ return result