dreamlessx committed on
Commit
9bfed75
·
verified ·
1 Parent(s): 387e567

Upload landmarkdiff/safety.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/safety.py +380 -0
landmarkdiff/safety.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Clinical safety validation for responsible deployment.
2
+
3
+ Implements safety checks for surgical outcome predictions:
4
+ 1. Identity preservation: verify output preserves patient identity
5
+ 2. Anatomical plausibility: check landmark displacements are realistic
6
+ 3. Out-of-distribution detection: flag unusual inputs
7
+ 4. Watermarking: mark AI-generated images
8
+ 5. Consent metadata: embed provenance information
9
+
10
+ Usage:
11
+ from landmarkdiff.safety import SafetyValidator
12
+
13
+ validator = SafetyValidator()
14
+ result = validator.validate(
15
+ input_image=image,
16
+ output_image=generated,
17
+ landmarks_original=face.landmarks,
18
+ landmarks_manipulated=manip.landmarks,
19
+ procedure="rhinoplasty",
20
+ )
21
+
22
+ if not result.passed:
23
+ print(f"Safety check failed: {result.failures}")
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ from dataclasses import dataclass, field
29
+ from typing import Optional
30
+
31
+ import cv2
32
+ import numpy as np
33
+
34
+
35
@dataclass
class SafetyResult:
    """Aggregated outcome of the safety checks run on one prediction.

    Attributes:
        passed: True until any check records a failure.
        failures: Human-readable message for every failed check.
        warnings: Advisory messages; recording a warning never changes
            ``passed``.
        checks: Maps a check name to whether that check succeeded.
        details: Raw measurements recorded by the individual checks.
    """

    passed: bool = True
    failures: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    checks: dict[str, bool] = field(default_factory=dict)
    details: dict[str, object] = field(default_factory=dict)

    def add_failure(self, name: str, message: str) -> None:
        """Record a failed check and mark the overall result as failed.

        Args:
            name: Identifier of the check, used as the key in ``checks``.
            message: Human-readable reason, appended to ``failures``.
        """
        self.passed = False
        self.failures.append(message)
        self.checks[name] = False

    def add_warning(self, name: str, message: str) -> None:
        """Record an advisory message without affecting ``passed``.

        Args:
            name: Identifier of the check that raised the warning. It is
                accepted for symmetry with ``add_failure`` but is not stored;
                only ``message`` is kept.
            message: Human-readable warning, appended to ``warnings``.
        """
        self.warnings.append(message)

    def add_pass(self, name: str) -> None:
        """Record a successful check.

        A failure is sticky: if ``name`` already has an entry in ``checks``
        (in particular a failure), it is left untouched so that ``checks``
        can never report success for a check that made ``passed`` False.

        Args:
            name: Identifier of the check, used as the key in ``checks``.
        """
        self.checks.setdefault(name, True)

    def summary(self) -> str:
        """Return a multi-line report: overall status, each check, each warning."""
        lines = [f"Safety: {'PASS' if self.passed else 'FAIL'}"]
        lines.extend(
            f" [{'OK' if ok else 'FAIL'}] {name}" for name, ok in self.checks.items()
        )
        lines.extend(f" [WARN] {w}" for w in self.warnings)
        return "\n".join(lines)
62
+
63
+
64
class SafetyValidator:
    """Clinical safety validation for surgical predictions.

    Each ``_check_*`` method records its outcome on a ``SafetyResult`` via
    ``add_pass`` / ``add_failure`` / ``add_warning`` and stores raw
    measurements in ``result.details``. Threshold comparisons are written so
    that NaN measurements fail instead of silently passing.
    """

    # Landmark regions each procedure is expected to displace. The names are
    # looked up in ``landmarkdiff.landmarks.LANDMARK_REGIONS``.
    _EXPECTED_REGIONS = {
        "rhinoplasty": ("nose",),
        "blepharoplasty": ("eye_left", "eye_right"),
        "rhytidectomy": ("jawline",),
        "orthognathic": ("jawline", "lips"),
    }

    def __init__(
        self,
        identity_threshold: float = 0.6,
        max_displacement_fraction: float = 0.05,
        min_face_confidence: float = 0.5,
        max_yaw_degrees: float = 45.0,
        watermark_enabled: bool = True,
        watermark_text: str = "AI-GENERATED PREDICTION",
    ):
        """Store the thresholds and watermark settings used by the checks.

        Args:
            identity_threshold: Minimum identity similarity between input and
                output for the identity check to pass.
            max_displacement_fraction: Largest allowed per-landmark
                displacement, in the same (normalized) units as the landmarks.
            min_face_confidence: Minimum face detection confidence.
            max_yaw_degrees: Stored on the instance; not read by any check in
                this class.
            watermark_enabled: Whether ``apply_watermark`` draws anything.
            watermark_text: Default text drawn by ``apply_watermark``.
        """
        self.identity_threshold = identity_threshold
        self.max_displacement_fraction = max_displacement_fraction
        self.min_face_confidence = min_face_confidence
        self.max_yaw_degrees = max_yaw_degrees
        self.watermark_enabled = watermark_enabled
        self.watermark_text = watermark_text

    def validate(
        self,
        input_image: np.ndarray,
        output_image: np.ndarray,
        landmarks_original: np.ndarray | None = None,
        landmarks_manipulated: np.ndarray | None = None,
        procedure: str | None = None,
        face_confidence: float = 1.0,
    ) -> SafetyResult:
        """Run all safety checks on a prediction.

        Args:
            input_image: Original patient image (BGR, uint8).
            output_image: Generated prediction (BGR, uint8).
            landmarks_original: Original landmarks (N, 2-3), normalized [0, 1].
            landmarks_manipulated: Manipulated landmarks (N, 2-3), normalized [0, 1].
            procedure: Surgical procedure name.
            face_confidence: MediaPipe face detection confidence.

        Returns:
            SafetyResult with all check results.
        """
        result = SafetyResult()

        # 1. Face detection confidence
        self._check_face_confidence(result, face_confidence)

        # 2. Identity preservation
        self._check_identity(result, input_image, output_image)

        # 3. Anatomical plausibility (skipped unless both landmark sets are given)
        if landmarks_original is not None and landmarks_manipulated is not None:
            self._check_anatomical_plausibility(
                result, landmarks_original, landmarks_manipulated, procedure
            )

        # 4. Output quality
        self._check_output_quality(result, output_image)

        # 5. OOD detection (basic)
        self._check_ood(result, input_image)

        return result

    def _check_face_confidence(
        self, result: SafetyResult, confidence: float
    ) -> None:
        """Fail when the face detection confidence is below the threshold."""
        # "not >=" instead of "<" so that a NaN confidence fails the check.
        if not (confidence >= self.min_face_confidence):
            result.add_failure(
                "face_confidence",
                f"Face detection confidence {confidence:.2f} below threshold "
                f"{self.min_face_confidence}",
            )
        else:
            result.add_pass("face_confidence")
        result.details["face_confidence"] = confidence

    def _check_identity(
        self,
        result: SafetyResult,
        input_image: np.ndarray,
        output_image: np.ndarray,
    ) -> None:
        """Check identity preservation using ArcFace similarity.

        Best effort: if the similarity backend cannot be imported or raises,
        a warning is recorded and no "identity" entry is added to
        ``result.checks``; the overall result is not failed.
        """
        try:
            from landmarkdiff.evaluation import compute_identity_similarity

            sim = float(compute_identity_similarity(output_image, input_image))
            result.details["identity_similarity"] = sim

            # "not >=" instead of "<" so that a NaN similarity fails the check.
            if not (sim >= self.identity_threshold):
                result.add_failure(
                    "identity",
                    f"Identity similarity {sim:.3f} below threshold "
                    f"{self.identity_threshold}",
                )
            else:
                result.add_pass("identity")
        except Exception as e:  # deliberate: the identity backend is optional
            result.add_warning("identity", f"Identity check failed: {e}")

    def _check_anatomical_plausibility(
        self,
        result: SafetyResult,
        landmarks_orig: np.ndarray,
        landmarks_manip: np.ndarray,
        procedure: str | None,
    ) -> None:
        """Check that landmark displacements are anatomically plausible.

        Records a failure (instead of raising) when the two landmark sets
        differ in length, are not shaped (N, 2+), are empty, or produce
        non-finite displacements. Otherwise compares the largest per-landmark
        displacement with ``max_displacement_fraction`` and, when a procedure
        is given, runs the procedure-specific region check.
        """
        if len(landmarks_orig) != len(landmarks_manip):
            result.add_failure(
                "anatomical",
                f"Landmark count mismatch: {len(landmarks_orig)} vs {len(landmarks_manip)}",
            )
            return

        orig_arr = np.asarray(landmarks_orig, dtype=float)
        manip_arr = np.asarray(landmarks_manip, dtype=float)
        if (
            orig_arr.ndim != 2
            or manip_arr.ndim != 2
            or min(orig_arr.shape[1], manip_arr.shape[1]) < 2
        ):
            result.add_failure(
                "anatomical",
                f"Landmarks must have shape (N, 2) or (N, 3), got "
                f"{orig_arr.shape} and {manip_arr.shape}",
            )
            return
        if orig_arr.shape[0] == 0:
            # max()/mean() of a zero-size array would raise or yield NaN.
            result.add_failure("anatomical", "Landmark arrays are empty")
            return

        # Only x/y are compared; an optional third column is ignored.
        orig = orig_arr[:, :2]
        manip = manip_arr[:, :2]
        displacements = np.linalg.norm(manip - orig, axis=1)

        if not np.isfinite(displacements).all():
            # NaN compares False against any threshold and would pass below.
            result.add_failure(
                "anatomical_magnitude",
                "Landmark displacements contain NaN or infinite values",
            )
            return

        max_disp = float(displacements.max())
        mean_disp = float(displacements.mean())
        result.details["max_displacement"] = max_disp
        result.details["mean_displacement"] = mean_disp

        # Check maximum displacement
        if max_disp > self.max_displacement_fraction:
            result.add_failure(
                "anatomical_magnitude",
                f"Maximum displacement {max_disp:.4f} exceeds threshold "
                f"{self.max_displacement_fraction}",
            )
        else:
            result.add_pass("anatomical_magnitude")

        # Check procedure-specific regions
        if procedure:
            self._check_procedure_regions(result, orig, manip, displacements, procedure)

    def _check_procedure_regions(
        self,
        result: SafetyResult,
        orig: np.ndarray,
        manip: np.ndarray,
        displacements: np.ndarray,
        procedure: str,
    ) -> None:
        """Verify displacement is concentrated in expected anatomical regions.

        Compares the mean displacement of landmarks inside the procedure's
        expected regions with the mean displacement of all other landmarks.
        A warning (never a failure) is recorded when the other landmarks move
        more than twice as much and by more than 0.005. Unknown procedures,
        and procedures whose regions resolve to no landmark indices, pass.

        ``manip`` is accepted for interface compatibility but not used.
        """
        from landmarkdiff.landmarks import LANDMARK_REGIONS

        expected = self._EXPECTED_REGIONS.get(procedure, ())
        if not expected:
            result.add_pass("procedure_region")
            return

        # Collect the landmark indices of every expected region that exists.
        expected_indices = set()
        for region in expected:
            if region in LANDMARK_REGIONS:
                expected_indices.update(LANDMARK_REGIONS[region])

        if not expected_indices:
            result.add_pass("procedure_region")
            return

        # Check: is most displacement in expected regions?
        n = min(len(displacements), len(orig))
        expected_mask = np.isin(np.arange(n), list(expected_indices))

        if expected_mask.any() and (~expected_mask).any():
            expected_disp = displacements[expected_mask].mean()
            unexpected_disp = displacements[~expected_mask].mean()
            result.details["expected_region_disp"] = float(expected_disp)
            result.details["unexpected_region_disp"] = float(unexpected_disp)

            # Expected regions should have more displacement
            if unexpected_disp > expected_disp * 2 and unexpected_disp > 0.005:
                result.add_warning(
                    "procedure_region",
                    f"{procedure}: unexpected regions displaced more than expected "
                    f"({unexpected_disp:.4f} vs {expected_disp:.4f})",
                )
            else:
                result.add_pass("procedure_region")
        else:
            # All landmarks fall on one side of the split; nothing to compare.
            result.add_pass("procedure_region")

    def _check_output_quality(
        self, result: SafetyResult, output: np.ndarray
    ) -> None:
        """Check output image quality (not blank, not corrupted).

        Fails for a missing/empty image, non-finite pixel statistics, or a
        mean below 5 / above 250. A standard deviation below 10 only produces
        a warning. The thresholds assume 0-255 pixel values.
        """
        if output is None or output.size == 0:
            result.add_failure("output_quality", "Output image is empty")
            return

        mean_val = float(output.mean())
        result.details["output_mean"] = mean_val

        if not np.isfinite(mean_val):
            # NaN would compare False against both bounds below and pass.
            result.add_failure(
                "output_quality", "Output contains non-finite pixel values"
            )
            return

        # Check for blank/black images
        if mean_val < 5:
            result.add_failure("output_quality", f"Output is nearly black (mean={mean_val:.1f})")
            return
        if mean_val > 250:
            result.add_failure("output_quality", f"Output is nearly white (mean={mean_val:.1f})")
            return

        # Check for artifacts (extreme variance)
        std_val = float(output.std())
        if std_val < 10:
            result.add_warning(
                "output_quality",
                f"Output has very low variance (std={std_val:.1f}), may be uniform",
            )

        result.add_pass("output_quality")
        result.details["output_std"] = std_val

    def _check_ood(self, result: SafetyResult, image: np.ndarray) -> None:
        """Basic out-of-distribution detection.

        Checks image properties against expected ranges for face photos.
        Low resolution, extreme aspect ratio and a strongly blue color cast
        produce warnings only; a missing, empty or one-dimensional image is
        recorded as a failure.
        """
        if image is None or image.size == 0 or image.ndim < 2:
            result.add_failure(
                "ood_basic", "Input image is empty or not an image array"
            )
            return

        h, w = image.shape[:2]

        # Resolution check
        if min(h, w) < 128:
            result.add_warning("ood", f"Image resolution too low: {w}x{h}")

        # Aspect ratio (faces should be roughly square after preprocessing)
        aspect = max(h, w) / max(min(h, w), 1)
        if aspect > 3.0:
            result.add_warning("ood", f"Unusual aspect ratio: {aspect:.1f}")

        # Color distribution (face photos should have some skin tones)
        if image.ndim == 3 and image.shape[2] == 3:
            # Channel order is taken to be BGR, as documented on validate().
            mean_b, _mean_g, mean_r = image.mean(axis=(0, 1))
            # Face images typically have red channel > blue channel
            if mean_b > mean_r * 1.5:
                result.add_warning("ood", "Image appears very blue (not typical face photo)")

        result.add_pass("ood_basic")

    def apply_watermark(
        self,
        image: np.ndarray,
        text: str | None = None,
        opacity: float = 0.3,
    ) -> np.ndarray:
        """Apply a text watermark to the output image.

        Places a semi-transparent black bar with white text at the bottom of
        the image to indicate it is AI-generated.

        Args:
            image: Image to mark; it is not modified.
            text: Watermark text; falls back to ``self.watermark_text`` when
                None or empty.
            opacity: Blend weight of the black bar (0 = invisible, 1 = solid).

        Returns:
            A watermarked copy, or ``image`` itself (not a copy) when
            watermarking is disabled.
        """
        if not self.watermark_enabled:
            return image

        text = text or self.watermark_text
        result = image.copy()
        h, w = result.shape[:2]

        # Scale the text with the image width, with a readable minimum.
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = max(0.3, w / 1500)
        thickness = max(1, int(w / 500))

        text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
        # Center horizontally; clamp so text wider than the image starts at
        # the left edge instead of at a negative x.
        x = max(0, (w - text_size[0]) // 2)
        y = h - 10

        # Semi-transparent background bar. Outside the bar the overlay equals
        # the image, so blending leaves those pixels unchanged.
        bar_y1 = y - text_size[1] - 10
        bar_y2 = h
        overlay = result.copy()
        cv2.rectangle(overlay, (0, bar_y1), (w, bar_y2), (0, 0, 0), -1)
        cv2.addWeighted(overlay, opacity, result, 1 - opacity, 0, result)

        # White text
        cv2.putText(result, text, (x, y), font, font_scale,
                    (255, 255, 255), thickness, cv2.LINE_AA)

        return result

    def embed_metadata(
        self,
        image_path: str,
        procedure: str,
        intensity: float,
        model_version: str = "0.3.0",
    ) -> None:
        """Write provenance metadata for an output image as a sidecar file.

        The image file itself is not touched. The metadata is written as
        UTF-8 JSON next to it, with the image's extension replaced by
        ``.meta.json`` (e.g. ``pred.png`` -> ``pred.meta.json``).

        Args:
            image_path: Path of the generated image.
            procedure: Surgical procedure name.
            intensity: Generation intensity parameter.
            model_version: Version string recorded in the metadata.
        """
        import json
        from pathlib import Path

        meta = {
            "generator": "LandmarkDiff",
            "version": model_version,
            "procedure": procedure,
            "intensity": intensity,
            "disclaimer": "AI-generated surgical prediction for visualization only. "
                          "Not a guarantee of surgical outcome.",
        }

        # Save as sidecar JSON (PNG doesn't have easy EXIF support)
        meta_path = Path(image_path).with_suffix(".meta.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)