dreamlessx commited on
Commit
d895a0c
·
verified ·
1 Parent(s): 83b71db

Upload landmarkdiff/face_verifier.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/face_verifier.py +805 -0
landmarkdiff/face_verifier.py ADDED
@@ -0,0 +1,805 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Face distortion detection, neural restoration, and identity verification.
2
+
3
+ Used for cleaning scraped data, post-diffusion QA, and beauty filter removal.
4
+ Cascades: CodeFormer -> GFPGAN -> Real-ESRGAN, with ArcFace identity gate.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ from dataclasses import dataclass, field
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ import cv2
15
+ import numpy as np
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Data structures
20
+ # ---------------------------------------------------------------------------
21
+
22
+ @dataclass
23
+ class DistortionReport:
24
+ """Distortion analysis for a face image."""
25
+
26
+ # Overall quality score (0-100, higher = better)
27
+ quality_score: float = 0.0
28
+
29
+ # Individual distortion scores (0-1, higher = more distorted)
30
+ blur_score: float = 0.0 # Laplacian variance-based
31
+ noise_score: float = 0.0 # High-freq energy ratio
32
+ compression_score: float = 0.0 # JPEG block artifact detection
33
+ oversmooth_score: float = 0.0 # Beauty filter / airbrushed detection
34
+ color_cast_score: float = 0.0 # Unnatural color shift
35
+ geometric_distort: float = 0.0 # Face proportion anomalies
36
+ lighting_score: float = 0.0 # Over/under exposure
37
+
38
+ # Classification
39
+ primary_distortion: str = "none"
40
+ severity: str = "none" # none, mild, moderate, severe
41
+ is_usable: bool = True # Whether image is worth restoring vs rejecting
42
+
43
+ # Details
44
+ details: dict = field(default_factory=dict)
45
+
46
+ def summary(self) -> str:
47
+ lines = [
48
+ f"Quality Score: {self.quality_score:.1f}/100",
49
+ f"Primary Issue: {self.primary_distortion} ({self.severity})",
50
+ f"Usable: {self.is_usable}",
51
+ "",
52
+ "Distortion Breakdown:",
53
+ f" Blur: {self.blur_score:.3f}",
54
+ f" Noise: {self.noise_score:.3f}",
55
+ f" Compression: {self.compression_score:.3f}",
56
+ f" Oversmooth: {self.oversmooth_score:.3f}",
57
+ f" Color Cast: {self.color_cast_score:.3f}",
58
+ f" Geometric: {self.geometric_distort:.3f}",
59
+ f" Lighting: {self.lighting_score:.3f}",
60
+ ]
61
+ return "\n".join(lines)
62
+
63
+
64
+ @dataclass
65
+ class RestorationResult:
66
+ """What came out of the restoration pipeline."""
67
+
68
+ restored: np.ndarray # Restored BGR image
69
+ original: np.ndarray # Original BGR image
70
+ distortion_report: DistortionReport # Pre-restoration analysis
71
+ post_quality_score: float = 0.0 # Quality after restoration
72
+ identity_similarity: float = 0.0 # ArcFace cosine sim (original vs restored)
73
+ identity_preserved: bool = True # Whether identity check passed
74
+ restoration_stages: list[str] = field(default_factory=list) # Which nets ran
75
+ improvement: float = 0.0 # quality_after - quality_before
76
+
77
+ def summary(self) -> str:
78
+ lines = [
79
+ f"Pre-restoration: {self.distortion_report.quality_score:.1f}/100",
80
+ f"Post-restoration: {self.post_quality_score:.1f}/100",
81
+ f"Improvement: +{self.improvement:.1f}",
82
+ f"Identity Sim: {self.identity_similarity:.3f}",
83
+ f"Identity OK: {self.identity_preserved}",
84
+ f"Stages Used: {' -> '.join(self.restoration_stages) or 'none'}",
85
+ ]
86
+ return "\n".join(lines)
87
+
88
+
89
+ @dataclass
90
+ class BatchVerificationReport:
91
+ """Batch verification stats."""
92
+
93
+ total: int = 0
94
+ passed: int = 0 # Good quality, no fix needed
95
+ restored: int = 0 # Fixed and now usable
96
+ rejected: int = 0 # Too distorted to salvage
97
+ identity_failures: int = 0 # Restoration changed identity
98
+ avg_quality_before: float = 0.0
99
+ avg_quality_after: float = 0.0
100
+ avg_identity_sim: float = 0.0
101
+ distortion_counts: dict[str, int] = field(default_factory=dict)
102
+
103
+ def summary(self) -> str:
104
+ lines = [
105
+ f"Total Images: {self.total}",
106
+ f" Passed (good): {self.passed}",
107
+ f" Restored: {self.restored}",
108
+ f" Rejected: {self.rejected}",
109
+ f" Identity Fail: {self.identity_failures}",
110
+ f"Avg Quality Before: {self.avg_quality_before:.1f}",
111
+ f"Avg Quality After: {self.avg_quality_after:.1f}",
112
+ f"Avg Identity Sim: {self.avg_identity_sim:.3f}",
113
+ "",
114
+ "Distortion Breakdown:",
115
+ ]
116
+ for dist_type, count in sorted(
117
+ self.distortion_counts.items(), key=lambda x: -x[1],
118
+ ):
119
+ lines.append(f" {dist_type}: {count}")
120
+ return "\n".join(lines)
121
+
122
+
123
+ # ---------------------------------------------------------------------------
124
+ # Distortion Detection (classical + neural)
125
+ # ---------------------------------------------------------------------------
126
+
127
+ def detect_blur(image: np.ndarray) -> float:
128
+ """Laplacian variance + gradient magnitude blur score (0-1, 1=blurry)."""
129
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
130
+
131
+ # Laplacian variance (primary metric)
132
+ lap_var = cv2.Laplacian(gray, cv2.CV_64F).var()
133
+
134
+ # Gradient magnitude (secondary)
135
+ gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
136
+ gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
137
+ grad_mag = np.sqrt(gx ** 2 + gy ** 2).mean()
138
+
139
+ # Normalize: typical sharp face has lap_var > 500, grad_mag > 30
140
+ blur_lap = 1.0 - min(lap_var / 800.0, 1.0)
141
+ blur_grad = 1.0 - min(grad_mag / 50.0, 1.0)
142
+
143
+ return float(np.clip(0.6 * blur_lap + 0.4 * blur_grad, 0, 1))
144
+
145
+
146
+ def detect_noise(image: np.ndarray) -> float:
147
+ """Noise estimate via MAD of Laplacian (0-1, 1=noisy)."""
148
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
149
+
150
+ # Robust noise estimation via MAD of Laplacian
151
+ lap = cv2.Laplacian(gray.astype(np.float64), cv2.CV_64F)
152
+ sigma_est = np.median(np.abs(lap)) * 1.4826 # MAD -> std conversion
153
+
154
+ # Normalize: sigma > 20 is very noisy
155
+ return float(np.clip(sigma_est / 25.0, 0, 1))
156
+
157
+
158
+ def detect_compression_artifacts(image: np.ndarray) -> float:
159
+ """JPEG 8x8 block boundary energy ratio (0-1, 1=heavy artifacts)."""
160
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
161
+ h, w = gray.shape
162
+
163
+ if h < 16 or w < 16:
164
+ return 0.0
165
+
166
+ gray_f = gray.astype(np.float64)
167
+
168
+ # Compute horizontal and vertical differences
169
+ h_diff = np.abs(np.diff(gray_f, axis=1))
170
+ v_diff = np.abs(np.diff(gray_f, axis=0))
171
+
172
+ # Energy at 8-pixel boundaries vs non-boundaries
173
+ h_boundary = h_diff[:, 7::8].mean() if h_diff[:, 7::8].size > 0 else 0
174
+ h_interior = h_diff.mean()
175
+ v_boundary = v_diff[7::8, :].mean() if v_diff[7::8, :].size > 0 else 0
176
+ v_interior = v_diff.mean()
177
+
178
+ if h_interior < 1e-6 or v_interior < 1e-6:
179
+ return 0.0
180
+
181
+ # Ratio of boundary to interior energy (>1 means block artifacts)
182
+ h_ratio = h_boundary / (h_interior + 1e-6)
183
+ v_ratio = v_boundary / (v_interior + 1e-6)
184
+ artifact_ratio = (h_ratio + v_ratio) / 2.0
185
+
186
+ # Normalize: ratio > 1.5 indicates visible artifacts
187
+ return float(np.clip((artifact_ratio - 1.0) / 0.8, 0, 1))
188
+
189
+
190
+ def detect_oversmoothing(image: np.ndarray) -> float:
191
+ """Catch beauty filters: low texture energy but edges still there."""
192
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
193
+ h, w = gray.shape
194
+
195
+ # Focus on face center region (avoid background)
196
+ roi = gray[h // 4:3 * h // 4, w // 4:3 * w // 4]
197
+
198
+ # Texture energy: variance of high-pass filtered image
199
+ blurred = cv2.GaussianBlur(roi.astype(np.float64), (0, 0), 2.0)
200
+ high_pass = roi.astype(np.float64) - blurred
201
+ texture_energy = np.var(high_pass)
202
+
203
+ # Edge energy: Canny edge density
204
+ edges = cv2.Canny(roi, 50, 150)
205
+ edge_density = np.mean(edges > 0)
206
+
207
+ # Oversmooth: low texture but edges still present
208
+ # Natural skin: texture_energy > 20, beauty filter: < 8
209
+ smooth_score = 1.0 - min(texture_energy / 30.0, 1.0)
210
+
211
+ # If there are still strong edges but no texture, it's a filter
212
+ if edge_density > 0.02:
213
+ smooth_score *= 1.3 # Amplify if edges present but no texture
214
+
215
+ return float(np.clip(smooth_score, 0, 1))
216
+
217
+
218
+ def detect_color_cast(image: np.ndarray) -> float:
219
+ """LAB A/B channel deviation from neutral - catches Instagram filters."""
220
+ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
221
+ h, w = image.shape[:2]
222
+
223
+ # Sample face center region
224
+ roi = lab[h // 4:3 * h // 4, w // 4:3 * w // 4]
225
+
226
+ # A channel: green-red axis (neutral ~128)
227
+ # B channel: blue-yellow axis (neutral ~128)
228
+ a_mean = roi[:, :, 1].mean()
229
+ b_mean = roi[:, :, 2].mean()
230
+
231
+ # Deviation from neutral
232
+ a_dev = abs(a_mean - 128) / 128.0
233
+ b_dev = abs(b_mean - 128) / 128.0
234
+
235
+ # Also check if color distribution is unnaturally narrow (saturated filter)
236
+ a_std = roi[:, :, 1].std()
237
+ b_std = roi[:, :, 2].std()
238
+ narrow_color = max(0, 1.0 - (a_std + b_std) / 30.0)
239
+
240
+ score = 0.5 * (a_dev + b_dev) + 0.3 * narrow_color
241
+ return float(np.clip(score, 0, 1))
242
+
243
+
244
+ def detect_geometric_distortion(image: np.ndarray) -> float:
245
+ """Check face proportions against anatomical norms via landmarks."""
246
+ try:
247
+ from landmarkdiff.landmarks import extract_landmarks
248
+ except ImportError:
249
+ return 0.0
250
+
251
+ face = extract_landmarks(image)
252
+ if face is None:
253
+ return 0.5 # Can't detect face = possibly distorted
254
+
255
+ coords = face.pixel_coords
256
+ h, w = image.shape[:2]
257
+
258
+ # Key ratios that should be anatomically consistent
259
+ left_eye = coords[33]
260
+ right_eye = coords[263]
261
+ nose_tip = coords[1]
262
+ chin = coords[152]
263
+ forehead = coords[10]
264
+
265
+ iod = np.linalg.norm(left_eye - right_eye)
266
+ face_height = np.linalg.norm(forehead - chin)
267
+ nose_to_chin = np.linalg.norm(nose_tip - chin)
268
+
269
+ if iod < 1.0 or face_height < 1.0:
270
+ return 0.5
271
+
272
+ # Anatomical norms (approximate):
273
+ # face_height / iod ≈ 2.5-3.5
274
+ # nose_to_chin / face_height ≈ 0.3-0.45
275
+ height_ratio = face_height / iod
276
+ lower_ratio = nose_to_chin / face_height
277
+
278
+ # Score deviations from normal ranges
279
+ height_dev = max(0, abs(height_ratio - 3.0) - 0.5) / 1.5
280
+ lower_dev = max(0, abs(lower_ratio - 0.38) - 0.08) / 0.15
281
+
282
+ # Eye symmetry check (vertical alignment)
283
+ eye_tilt = abs(left_eye[1] - right_eye[1]) / (iod + 1e-6)
284
+ tilt_dev = max(0, eye_tilt - 0.05) / 0.15
285
+
286
+ score = 0.4 * height_dev + 0.3 * lower_dev + 0.3 * tilt_dev
287
+ return float(np.clip(score, 0, 1))
288
+
289
+
290
+ def detect_lighting_issues(image: np.ndarray) -> float:
291
+ """Luminance histogram clipping and entropy check."""
292
+ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
293
+ l_channel = lab[:, :, 0]
294
+
295
+ # Check for clipping
296
+ overexposed = np.mean(l_channel > 245) * 5 # Fraction near white
297
+ underexposed = np.mean(l_channel < 10) * 5 # Fraction near black
298
+
299
+ # Check for bimodal distribution (harsh shadows)
300
+ hist = cv2.calcHist([l_channel], [0], None, [256], [0, 256]).flatten()
301
+ hist = hist / hist.sum()
302
+ # Measure how spread out the histogram is
303
+ entropy = -np.sum(hist[hist > 0] * np.log2(hist[hist > 0] + 1e-10))
304
+ # Low entropy = concentrated = potentially problematic
305
+ entropy_score = max(0, 1.0 - entropy / 7.0)
306
+
307
+ score = 0.4 * overexposed + 0.4 * underexposed + 0.2 * entropy_score
308
+ return float(np.clip(score, 0, 1))
309
+
310
+
311
+ def analyze_distortions(image: np.ndarray) -> DistortionReport:
312
+ """Run all detectors and return a DistortionReport."""
313
+ blur = detect_blur(image)
314
+ noise = detect_noise(image)
315
+ compression = detect_compression_artifacts(image)
316
+ oversmooth = detect_oversmoothing(image)
317
+ color_cast = detect_color_cast(image)
318
+ geometric = detect_geometric_distortion(image)
319
+ lighting = detect_lighting_issues(image)
320
+
321
+ # weighted combination (inverted, 100 = perfect)
322
+ weighted = (
323
+ 0.25 * blur
324
+ + 0.15 * noise
325
+ + 0.10 * compression
326
+ + 0.20 * oversmooth
327
+ + 0.10 * color_cast
328
+ + 0.10 * geometric
329
+ + 0.10 * lighting
330
+ )
331
+ quality = (1.0 - weighted) * 100.0
332
+
333
+ # Classify primary distortion
334
+ scores = {
335
+ "blur": blur,
336
+ "noise": noise,
337
+ "compression": compression,
338
+ "oversmooth": oversmooth,
339
+ "color_cast": color_cast,
340
+ "geometric": geometric,
341
+ "lighting": lighting,
342
+ }
343
+ primary = max(scores, key=scores.get)
344
+ primary_val = scores[primary]
345
+
346
+ if primary_val < 0.15:
347
+ severity = "none"
348
+ primary = "none"
349
+ elif primary_val < 0.35:
350
+ severity = "mild"
351
+ elif primary_val < 0.60:
352
+ severity = "moderate"
353
+ else:
354
+ severity = "severe"
355
+
356
+ # Image is usable if quality > 30 and no severe geometric distortion
357
+ is_usable = quality > 25 and geometric < 0.7
358
+
359
+ return DistortionReport(
360
+ quality_score=quality,
361
+ blur_score=blur,
362
+ noise_score=noise,
363
+ compression_score=compression,
364
+ oversmooth_score=oversmooth,
365
+ color_cast_score=color_cast,
366
+ geometric_distort=geometric,
367
+ lighting_score=lighting,
368
+ primary_distortion=primary,
369
+ severity=severity,
370
+ is_usable=is_usable,
371
+ details=scores,
372
+ )
373
+
374
+
375
+ # ---------------------------------------------------------------------------
376
+ # Neural Face Quality Scoring (no-reference)
377
+ # ---------------------------------------------------------------------------
378
+
379
+ _FACE_QUALITY_NET = None
380
+
381
+
382
+ def _get_face_quality_scorer():
383
+ """Singleton FaceXLib quality model (or None if not installed)."""
384
+ global _FACE_QUALITY_NET
385
+ if _FACE_QUALITY_NET is not None:
386
+ return _FACE_QUALITY_NET
387
+
388
+ try:
389
+ from facexlib.assessment import init_assessment_model
390
+ _FACE_QUALITY_NET = init_assessment_model("hypernet")
391
+ return _FACE_QUALITY_NET
392
+ except Exception:
393
+ pass
394
+
395
+ return None
396
+
397
+
398
+ def neural_quality_score(image: np.ndarray) -> float:
399
+ """Face quality 0-100. FaceXLib if available, else classical fallback."""
400
+ # Try neural scorer
401
+ scorer = _get_face_quality_scorer()
402
+ if scorer is not None:
403
+ try:
404
+ import torch
405
+ from facexlib.utils import img2tensor
406
+ img_t = img2tensor(image / 255.0, bgr2rgb=True, float32=True)
407
+ img_t = img_t.unsqueeze(0)
408
+ if torch.cuda.is_available():
409
+ img_t = img_t.cuda()
410
+ scorer = scorer.cuda()
411
+ with torch.no_grad():
412
+ score = scorer(img_t).item()
413
+ return float(np.clip(score * 100, 0, 100))
414
+ except Exception:
415
+ pass
416
+
417
+ # Fallback: composite classical score
418
+ report = analyze_distortions(image)
419
+ return report.quality_score
420
+
421
+
422
+ # ---------------------------------------------------------------------------
423
+ # Neural Face Restoration (cascaded)
424
+ # ---------------------------------------------------------------------------
425
+
426
+ def restore_face(
427
+ image: np.ndarray,
428
+ distortion: DistortionReport | None = None,
429
+ mode: str = "auto",
430
+ codeformer_fidelity: float = 0.7,
431
+ ) -> tuple[np.ndarray, list[str]]:
432
+ """Cascaded neural face restoration."""
433
+ if distortion is None:
434
+ distortion = analyze_distortions(image)
435
+
436
+ result = image.copy()
437
+ stages = []
438
+
439
+ # fix color cast first (classical, fast, doesn't affect identity)
440
+ if distortion.color_cast_score > 0.25:
441
+ result = _fix_color_cast(result)
442
+ stages.append("color_correction")
443
+
444
+ # Step 1: Fix lighting issues (classical)
445
+ if distortion.lighting_score > 0.35:
446
+ result = _fix_lighting(result)
447
+ stages.append("lighting_fix")
448
+
449
+ # Step 2: Neural face restoration
450
+ if mode == "auto":
451
+ # Choose based on what's wrong
452
+ needs_face_restore = (
453
+ distortion.blur_score > 0.2
454
+ or distortion.oversmooth_score > 0.25
455
+ or distortion.noise_score > 0.25
456
+ or distortion.compression_score > 0.2
457
+ )
458
+ if needs_face_restore:
459
+ mode = "codeformer" # CodeFormer handles most degradations well
460
+
461
+ if mode in ("codeformer", "all"):
462
+ restored = _try_codeformer(result, fidelity=codeformer_fidelity)
463
+ if restored is not None:
464
+ result = restored
465
+ stages.append("codeformer")
466
+ else:
467
+ # Fallback to GFPGAN
468
+ restored = _try_gfpgan(result)
469
+ if restored is not None:
470
+ result = restored
471
+ stages.append("gfpgan")
472
+
473
+ elif mode == "gfpgan":
474
+ restored = _try_gfpgan(result)
475
+ if restored is not None:
476
+ result = restored
477
+ stages.append("gfpgan")
478
+
479
+ # Step 3: Background enhancement with Real-ESRGAN (if image is low-res)
480
+ h, w = result.shape[:2]
481
+ if h < 400 or w < 400:
482
+ enhanced = _try_realesrgan(result)
483
+ if enhanced is not None:
484
+ result = enhanced
485
+ stages.append("realesrgan")
486
+
487
+ # Step 4: Mild sharpening if still soft after restoration
488
+ post_blur = detect_blur(result)
489
+ if post_blur > 0.3:
490
+ from landmarkdiff.postprocess import frequency_aware_sharpen
491
+ result = frequency_aware_sharpen(result, strength=0.3)
492
+ stages.append("sharpen")
493
+
494
+ return result, stages
495
+
496
+
497
+ def _try_codeformer(image: np.ndarray, fidelity: float = 0.7) -> np.ndarray | None:
498
+ """Try CodeFormer restoration. Returns None if unavailable."""
499
+ try:
500
+ from landmarkdiff.postprocess import restore_face_codeformer
501
+ restored = restore_face_codeformer(image, fidelity=fidelity)
502
+ if restored is not image:
503
+ return restored
504
+ except Exception:
505
+ pass
506
+ return None
507
+
508
+
509
+ def _try_gfpgan(image: np.ndarray) -> np.ndarray | None:
510
+ """Try GFPGAN restoration. Returns None if unavailable."""
511
+ try:
512
+ from landmarkdiff.postprocess import restore_face_gfpgan
513
+ restored = restore_face_gfpgan(image)
514
+ if restored is not image:
515
+ return restored
516
+ except Exception:
517
+ pass
518
+ return None
519
+
520
+
521
+ def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
522
+ """Try Real-ESRGAN 2x upscale + downsample. Returns None if unavailable."""
523
+ try:
524
+ from realesrgan import RealESRGANer
525
+ from basicsr.archs.rrdbnet_arch import RRDBNet
526
+ import torch
527
+
528
+ model = RRDBNet(
529
+ num_in_ch=3, num_out_ch=3, num_feat=64,
530
+ num_block=23, num_grow_ch=32, scale=4,
531
+ )
532
+ upsampler = RealESRGANer(
533
+ scale=4,
534
+ model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
535
+ model=model,
536
+ tile=400,
537
+ tile_pad=10,
538
+ pre_pad=0,
539
+ half=torch.cuda.is_available(),
540
+ )
541
+ enhanced, _ = upsampler.enhance(image, outscale=2)
542
+
543
+ # Downsample to 512x512 for pipeline consistency
544
+ enhanced = cv2.resize(enhanced, (512, 512), interpolation=cv2.INTER_LANCZOS4)
545
+ return enhanced
546
+ except Exception:
547
+ pass
548
+ return None
549
+
550
+
551
+ def _fix_color_cast(image: np.ndarray) -> np.ndarray:
552
+ """Remove color cast by normalizing A/B channels in LAB space."""
553
+ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
554
+
555
+ # Center A and B channels around 128 (neutral)
556
+ for ch in [1, 2]:
557
+ channel = lab[:, :, ch]
558
+ mean_val = channel.mean()
559
+ # Shift toward neutral, but only partially to preserve natural skin tone
560
+ shift = (128.0 - mean_val) * 0.6
561
+ lab[:, :, ch] = np.clip(channel + shift, 0, 255)
562
+
563
+ return cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
564
+
565
+
566
+ def _fix_lighting(image: np.ndarray) -> np.ndarray:
567
+ """Fix over/under exposure using adaptive CLAHE in LAB space."""
568
+ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
569
+
570
+ # CLAHE on luminance channel only
571
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
572
+ lab[:, :, 0] = clahe.apply(lab[:, :, 0])
573
+
574
+ return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
575
+
576
+
577
+ # ---------------------------------------------------------------------------
578
+ # ArcFace Identity Verification
579
+ # ---------------------------------------------------------------------------
580
+
581
+ _ARCFACE_APP = None
582
+
583
+
584
+ def _get_arcface():
585
+ """Get or create singleton ArcFace model."""
586
+ global _ARCFACE_APP
587
+ if _ARCFACE_APP is not None:
588
+ return _ARCFACE_APP
589
+
590
+ try:
591
+ from insightface.app import FaceAnalysis
592
+ import torch
593
+
594
+ app = FaceAnalysis(
595
+ name="buffalo_l",
596
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
597
+ )
598
+ ctx_id = 0 if torch.cuda.is_available() else -1
599
+ app.prepare(ctx_id=ctx_id, det_size=(320, 320))
600
+ _ARCFACE_APP = app
601
+ return app
602
+ except Exception:
603
+ return None
604
+
605
+
606
+ def get_face_embedding(image: np.ndarray) -> np.ndarray | None:
607
+ """ArcFace 512-d embedding, or None if no face / no InsightFace."""
608
+ app = _get_arcface()
609
+ if app is None:
610
+ return None
611
+
612
+ try:
613
+ faces = app.get(image)
614
+ if faces:
615
+ return faces[0].embedding
616
+ except Exception:
617
+ pass
618
+ return None
619
+
620
+
621
+ def verify_identity(
622
+ original: np.ndarray,
623
+ restored: np.ndarray,
624
+ threshold: float = 0.6,
625
+ ) -> tuple[float, bool]:
626
+ """ArcFace cosine sim between original and restored. Returns (sim, passed)."""
627
+ emb_orig = get_face_embedding(original)
628
+ emb_rest = get_face_embedding(restored)
629
+
630
+ if emb_orig is None or emb_rest is None:
631
+ return -1.0, True # can't verify, assume OK
632
+
633
+ sim = float(np.dot(emb_orig, emb_rest) / (
634
+ np.linalg.norm(emb_orig) * np.linalg.norm(emb_rest) + 1e-8
635
+ ))
636
+ sim = float(np.clip(sim, -1, 1))
637
+ return sim, sim >= threshold
638
+
639
+
640
+ # ---------------------------------------------------------------------------
641
+ # Full Verification + Restoration Pipeline
642
+ # ---------------------------------------------------------------------------
643
+
644
+ def verify_and_restore(
645
+ image: np.ndarray,
646
+ quality_threshold: float = 60.0,
647
+ identity_threshold: float = 0.6,
648
+ restore_mode: str = "auto",
649
+ codeformer_fidelity: float = 0.7,
650
+ ) -> RestorationResult:
651
+ """Full pipeline: analyze -> restore -> verify identity."""
652
+ # Step 1: Analyze distortions
653
+ report = analyze_distortions(image)
654
+
655
+ # Step 2: Decide if restoration needed
656
+ if report.quality_score >= quality_threshold and report.severity in ("none", "mild"):
657
+ # image is good enough, skip restoration
658
+ return RestorationResult(
659
+ restored=image.copy(),
660
+ original=image.copy(),
661
+ distortion_report=report,
662
+ post_quality_score=report.quality_score,
663
+ identity_similarity=1.0,
664
+ identity_preserved=True,
665
+ restoration_stages=[],
666
+ improvement=0.0,
667
+ )
668
+
669
+ if not report.is_usable:
670
+ # Too distorted to salvage
671
+ return RestorationResult(
672
+ restored=image.copy(),
673
+ original=image.copy(),
674
+ distortion_report=report,
675
+ post_quality_score=report.quality_score,
676
+ identity_similarity=0.0,
677
+ identity_preserved=False,
678
+ restoration_stages=["rejected"],
679
+ improvement=0.0,
680
+ )
681
+
682
+ # Step 3: Neural restoration
683
+ restored, stages = restore_face(
684
+ image,
685
+ distortion=report,
686
+ mode=restore_mode,
687
+ codeformer_fidelity=codeformer_fidelity,
688
+ )
689
+
690
+ # Step 4: Post-restoration quality check
691
+ post_quality = neural_quality_score(restored)
692
+
693
+ # Step 5: Identity verification
694
+ sim, id_ok = verify_identity(image, restored, threshold=identity_threshold)
695
+
696
+ return RestorationResult(
697
+ restored=restored,
698
+ original=image.copy(),
699
+ distortion_report=report,
700
+ post_quality_score=post_quality,
701
+ identity_similarity=sim,
702
+ identity_preserved=id_ok,
703
+ restoration_stages=stages,
704
+ improvement=post_quality - report.quality_score,
705
+ )
706
+
707
+
708
+ # ---------------------------------------------------------------------------
709
+ # Batch Processing
710
+ # ---------------------------------------------------------------------------
711
+
712
+ def verify_batch(
713
+ image_dir: str,
714
+ output_dir: str | None = None,
715
+ quality_threshold: float = 60.0,
716
+ identity_threshold: float = 0.6,
717
+ restore_mode: str = "auto",
718
+ save_rejected: bool = False,
719
+ extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp", ".bmp"),
720
+ ) -> BatchVerificationReport:
721
+ """Process a directory of face images: analyze, restore, verify, sort."""
722
+ image_path = Path(image_dir)
723
+ if output_dir is None:
724
+ out_path = image_path.parent / f"{image_path.name}_verified"
725
+ else:
726
+ out_path = Path(output_dir)
727
+
728
+ # Create output dirs
729
+ passed_dir = out_path / "passed"
730
+ restored_dir = out_path / "restored"
731
+ rejected_dir = out_path / "rejected"
732
+ passed_dir.mkdir(parents=True, exist_ok=True)
733
+ restored_dir.mkdir(parents=True, exist_ok=True)
734
+ if save_rejected:
735
+ rejected_dir.mkdir(parents=True, exist_ok=True)
736
+
737
+ # Find all images
738
+ image_files = sorted([
739
+ f for f in image_path.iterdir()
740
+ if f.suffix.lower() in extensions and f.is_file()
741
+ ])
742
+
743
+ report = BatchVerificationReport(total=len(image_files))
744
+ quality_before = []
745
+ quality_after = []
746
+ identity_sims = []
747
+
748
+ for i, img_file in enumerate(image_files):
749
+ if (i + 1) % 50 == 0 or i == 0:
750
+ print(f"Processing {i + 1}/{len(image_files)}: {img_file.name}")
751
+
752
+ image = cv2.imread(str(img_file))
753
+ if image is None:
754
+ report.rejected += 1
755
+ continue
756
+
757
+ # Resize to 512x512 for consistency
758
+ image = cv2.resize(image, (512, 512))
759
+
760
+ # Run verification + restoration
761
+ result = verify_and_restore(
762
+ image,
763
+ quality_threshold=quality_threshold,
764
+ identity_threshold=identity_threshold,
765
+ restore_mode=restore_mode,
766
+ )
767
+
768
+ quality_before.append(result.distortion_report.quality_score)
769
+ quality_after.append(result.post_quality_score)
770
+
771
+ # Track distortion types
772
+ dist_type = result.distortion_report.primary_distortion
773
+ report.distortion_counts[dist_type] = report.distortion_counts.get(dist_type, 0) + 1
774
+
775
+ if not result.distortion_report.is_usable or "rejected" in result.restoration_stages:
776
+ report.rejected += 1
777
+ if save_rejected:
778
+ cv2.imwrite(str(rejected_dir / img_file.name), image)
779
+ elif not result.restoration_stages:
780
+ # Passed without restoration
781
+ report.passed += 1
782
+ cv2.imwrite(str(passed_dir / img_file.name), image)
783
+ else:
784
+ # Restored
785
+ if result.identity_preserved:
786
+ report.restored += 1
787
+ cv2.imwrite(str(restored_dir / img_file.name), result.restored)
788
+ identity_sims.append(result.identity_similarity)
789
+ else:
790
+ report.identity_failures += 1
791
+ if save_rejected:
792
+ cv2.imwrite(str(rejected_dir / img_file.name), image)
793
+
794
+ # Compute averages
795
+ report.avg_quality_before = float(np.mean(quality_before)) if quality_before else 0.0
796
+ report.avg_quality_after = float(np.mean(quality_after)) if quality_after else 0.0
797
+ report.avg_identity_sim = float(np.mean(identity_sims)) if identity_sims else 0.0
798
+
799
+ # Save report
800
+ report_text = report.summary()
801
+ (out_path / "report.txt").write_text(report_text)
802
+ print(f"\n{report_text}")
803
+ print(f"\nResults saved to {out_path}/")
804
+
805
+ return report