File size: 34,294 Bytes
9f3acc0
392e60f
 
 
433e26f
0332541
433e26f
0332541
433e26f
392e60f
 
 
 
 
0f5dd41
4ed47f2
392e60f
 
433e26f
 
 
f069dfc
9f3acc0
 
 
f069dfc
9f3acc0
 
 
 
 
 
 
0332541
 
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
 
 
 
 
 
 
 
0f5dd41
 
 
 
 
5c6d1cb
0f5dd41
 
 
5c6d1cb
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
 
 
 
 
 
 
 
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0332541
5c6d1cb
392e60f
 
 
4ed47f2
392e60f
 
 
b55ddea
9f3acc0
b55ddea
 
 
9f3acc0
b55ddea
 
 
 
 
 
 
9f3acc0
 
 
 
433e26f
 
9f3acc0
 
0332541
 
392e60f
0f5dd41
392e60f
9f3acc0
 
433e26f
0332541
 
433e26f
0f5dd41
433e26f
 
9f3acc0
 
433e26f
0f5dd41
392e60f
0f5dd41
 
4ed47f2
5c6d1cb
 
 
 
9f3acc0
4ed47f2
392e60f
433e26f
 
 
 
392e60f
433e26f
9f3acc0
392e60f
433e26f
 
 
392e60f
5c6d1cb
0332541
 
0f5dd41
 
 
 
 
1654955
0f5dd41
aa29cac
1654955
 
 
aa29cac
 
1654955
 
 
aa29cac
 
1654955
 
 
aa29cac
 
1654955
aa29cac
0f5dd41
 
fedb187
392e60f
433e26f
 
9f3acc0
 
392e60f
 
 
9f3acc0
392e60f
9f3acc0
392e60f
433e26f
0f5dd41
 
433e26f
 
 
 
 
 
 
 
 
 
 
 
 
9f3acc0
 
392e60f
 
 
9f3acc0
392e60f
 
 
433e26f
0f5dd41
 
433e26f
 
 
 
9f3acc0
433e26f
 
 
 
 
 
 
 
 
 
9f3acc0
392e60f
 
 
9f3acc0
 
 
0332541
9f3acc0
5c6d1cb
f069dfc
 
1654955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392e60f
1654955
392e60f
1654955
392e60f
1654955
 
 
 
 
 
 
 
 
 
 
aa29cac
 
1654955
 
 
 
 
 
 
 
9f3acc0
 
 
392e60f
 
 
9f3acc0
392e60f
5c6d1cb
 
 
1654955
 
 
 
aa29cac
1654955
 
 
 
0332541
392e60f
5c6d1cb
 
 
 
 
1654955
 
 
 
 
 
 
 
0332541
392e60f
 
 
 
 
 
 
9f3acc0
392e60f
0332541
 
 
 
9f3acc0
0332541
 
1654955
 
aa29cac
 
 
 
 
 
 
 
 
 
 
 
1654955
 
9f3acc0
0f5dd41
 
5c6d1cb
 
 
 
 
 
392e60f
9f3acc0
 
 
392e60f
 
9f3acc0
392e60f
9f3acc0
392e60f
 
9f3acc0
392e60f
 
9f3acc0
 
392e60f
4ed47f2
9f3acc0
4ed47f2
 
392e60f
 
0332541
 
 
5c6d1cb
 
0332541
 
9f3acc0
392e60f
9f3acc0
392e60f
9f3acc0
392e60f
 
9f3acc0
 
 
392e60f
9f3acc0
392e60f
0332541
 
 
5c6d1cb
 
0332541
 
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
 
 
0f5dd41
 
5c6d1cb
 
 
0f5dd41
 
 
 
 
5c6d1cb
 
 
0f5dd41
 
 
 
 
5c6d1cb
 
0f5dd41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c6d1cb
 
 
0f5dd41
 
5c6d1cb
 
 
0f5dd41
 
5c6d1cb
 
 
0f5dd41
 
 
5c6d1cb
 
0f5dd41
 
5c6d1cb
 
0f5dd41
 
5c6d1cb
 
 
0f5dd41
 
 
 
 
 
 
392e60f
1654955
 
 
 
 
 
 
 
9f3acc0
f069dfc
392e60f
433e26f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
"""LandmarkDiff -- Facial surgery outcome prediction demo (TPS on CPU)."""

from __future__ import annotations

import logging
import time
import traceback
from pathlib import Path

import cv2
import gradio as gr
import numpy as np

from landmarkdiff.conditioning import render_wireframe
from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks
from landmarkdiff.manipulation import PROCEDURE_LANDMARKS, apply_procedure_preset
from landmarkdiff.masking import generate_surgical_mask

# Configure root logging at INFO level as a side effect of importing this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Project repository link.
GITHUB_URL = "https://github.com/dreamlessx/LandmarkDiff-public"
# Procedure names in the iteration order of PROCEDURE_LANDMARKS; this order also
# determines the order of the images returned by compare_procedures().
PROCEDURES = list(PROCEDURE_LANDMARKS.keys())
# Optional example photos: every *.png in an "examples" directory next to this
# file, sorted by path; an empty list when that directory does not exist.
EXAMPLE_DIR = Path(__file__).parent / "examples"
EXAMPLE_IMAGES = sorted(EXAMPLE_DIR.glob("*.png")) if EXAMPLE_DIR.exists() else []

# One-line, human-readable description per procedure key. Used in the
# process_image() report and to build the _proc_table markdown rows.
PROCEDURE_INFO = {
    "rhinoplasty": "Nose reshaping (bridge, tip, alar width)",
    "blepharoplasty": "Eyelid surgery (lid position, canthal tilt)",
    "rhytidectomy": "Facelift (midface, jawline tightening)",
    "orthognathic": "Jaw surgery (maxilla/mandible repositioning)",
    "brow_lift": "Brow elevation, forehead ptosis reduction",
    "mentoplasty": "Chin surgery (projection, vertical height)",
}

# ---------------------------------------------------------------------------
# Bilateral symmetry landmark pairs (MediaPipe face mesh indices)
# ---------------------------------------------------------------------------
# Region name -> list of (left_idx, right_idx) landmark pairs.
# compute_symmetry_score() reflects the first landmark of each pair across the
# facial midline and measures its distance to the second one. The region names
# also key the per-region scores and the overlay labels.
SYMMETRY_PAIRS: dict[str, list[tuple[int, int]]] = {
    "eyes": [
        (33, 263),
        (133, 362),
        (159, 386),
        (145, 374),
    ],
    "brows": [
        (70, 300),
        (63, 293),
        (105, 334),
        (66, 296),
        (107, 336),
    ],
    "cheeks": [
        (116, 345),
        (123, 352),
        (147, 376),
        (187, 411),
        (205, 425),
    ],
    "mouth": [
        (61, 291),
        (78, 308),
        (95, 324),
    ],
    "jaw": [
        (172, 397),
        (136, 365),
        (150, 379),
        (149, 378),
        (176, 400),
    ],
}

# Midline landmarks: forehead top and chin bottom
# (the line through these two points is the reflection axis for symmetry scoring)
MIDLINE_TOP = 10
MIDLINE_BOTTOM = 152


# ---------------------------------------------------------------------------
# Image preprocessing helpers
# ---------------------------------------------------------------------------


def _normalize_to_bgr(image: np.ndarray) -> np.ndarray:
    """Convert an uploaded image of any common layout/dtype to BGR uint8.

    Gradio delivers RGB(A) arrays, but grayscale, single-channel, float and
    16-bit inputs are accepted as well.

    Layouts:
      * ``(H, W)`` / ``(H, W, 1)`` grayscale -> replicated into 3 channels
      * ``(H, W, 3)`` RGB -> channel order reversed to BGR
      * ``(H, W, 4)`` RGBA -> alpha dropped, RGB reversed to BGR

    Dtypes:
      * floating point (any width): treated as ``[0, 1]`` when the maximum is
        <= 1, otherwise as already ``[0, 255]``; NaN becomes 0; clipped
      * ``uint16``: rescaled from ``[0, 65535]`` to ``[0, 255]``
      * ``bool``: True -> 255
      * other integers: clipped to ``[0, 255]``

    Returns:
        A new C-contiguous ``(H, W, 3)`` uint8 BGR array.

    Raises:
        ValueError: if ``image`` is None, empty, or has an unsupported shape.
    """
    if image is None:
        raise ValueError("No image provided")

    img = np.asarray(image)
    if img.size == 0:
        raise ValueError(f"Unsupported image shape: {img.shape}")

    # --- dtype -> uint8 ---------------------------------------------------
    if img.dtype == np.uint8:
        pass
    elif img.dtype == np.bool_:
        img = img.astype(np.uint8) * 255
    elif np.issubdtype(img.dtype, np.floating):
        # Covers float16 as well (an explicit float32/float64 check misses it).
        img = np.nan_to_num(img.astype(np.float64), nan=0.0)
        # Floats in [0, 1] are normalized intensities; anything larger is
        # already on the 0-255 scale and must not be clipped to 1.0, which
        # would turn the whole picture white.
        scale = 255.0 if float(img.max()) <= 1.0 else 1.0
        img = np.clip(img * scale, 0.0, 255.0).astype(np.uint8)
    elif img.dtype == np.uint16:
        # 16-bit images: rescale instead of letting astype wrap modulo 256.
        img = np.rint(img / 257.0).astype(np.uint8)
    elif np.issubdtype(img.dtype, np.integer):
        img = np.clip(img, 0, 255).astype(np.uint8)
    else:
        img = img.astype(np.uint8)

    # --- channel layout -> BGR -------------------------------------------
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.ndim == 3:
        channels = img.shape[2]
        if channels == 1:
            return np.repeat(img, 3, axis=2)
        if channels in (3, 4):
            # Take channels 2, 1, 0: reverses RGB -> BGR and drops alpha if present.
            return np.ascontiguousarray(img[:, :, 2::-1])
    raise ValueError(f"Unsupported image shape: {img.shape}")


def _auto_adjust_brightness(image_bgr: np.ndarray) -> np.ndarray:
    """Equalize luminance only when a photo is clearly badly exposed.

    The mean of the LAB lightness channel decides. A mean below 60 (too
    dark) or above 200 (washed out) triggers CLAHE (clip limit 2.0, 8x8 tiles)
    on the L channel only, so the A/B colour channels are untouched.
    Otherwise the input array itself is returned unchanged.
    """
    lab = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2LAB)
    lightness = lab[:, :, 0]
    average = float(np.mean(lightness))

    # Guard clause: a reasonably exposed image needs no correction.
    if not (average < 60 or average > 200):
        return image_bgr

    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab[:, :, 0] = equalizer.apply(lightness)
    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)


def _prepare_image(image_rgb: np.ndarray, size: int = 512) -> tuple[np.ndarray, np.ndarray]:
    """Normalize, letterbox and exposure-correct an uploaded photo.

    Pipeline: ``_normalize_to_bgr`` -> ``resize_preserve_aspect`` onto a
    ``size`` x ``size`` canvas -> ``_auto_adjust_brightness``.

    Returns:
        ``(bgr, rgb)``: the same processed square image in both channel orders.
    """
    processed_bgr = _auto_adjust_brightness(
        resize_preserve_aspect(_normalize_to_bgr(image_rgb), size)
    )
    processed_rgb = cv2.cvtColor(processed_bgr, cv2.COLOR_BGR2RGB)
    return processed_bgr, processed_rgb


def _detect_face_with_hints(image_bgr: np.ndarray) -> tuple[FaceLandmarks | None, str]:
    """Run landmark detection and explain failures in user-facing terms.

    Returns:
        ``(face, "")`` when ``extract_landmarks`` finds a face. Otherwise
        ``(None, hint)``. ``hint`` is the extraction error text if extraction
        raised. If no face was found, ``hint`` is the first matching heuristic
        diagnosis (very dark, very low contrast, unusual aspect ratio), or a
        generic "no face detected" message.
    """
    try:
        face = extract_landmarks(image_bgr)
    except Exception as exc:
        logger.error("Landmark extraction failed: %s\n%s", exc, traceback.format_exc())
        return None, f"Landmark extraction error: {exc}"

    if face is not None:
        return face, ""

    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    height, width = gray.shape[:2]
    aspect = width / max(height, 1)

    # Diagnostics in priority order; the first one that matches wins.
    if float(np.mean(gray)) < 30:
        hint = (
            "No face detected -- the image appears very dark. Try a photo with better lighting."
        )
    elif float(np.std(gray)) < 15:
        hint = (
            "No face detected -- the image has very low contrast. "
            "Try a photo with more natural lighting."
        )
    elif aspect > 2.5 or aspect < 0.4:
        # Extremely tall/wide inputs are often side-profile crops.
        hint = (
            "No face detected -- unusual aspect ratio. Use a standard portrait or headshot photo."
        )
    else:
        hint = (
            "No face detected. Make sure the photo shows a clear, "
            "well-lit frontal face. Side profiles and heavily occluded "
            "faces are not supported."
        )
    return None, hint


# ---------------------------------------------------------------------------
# Symmetry analysis
# ---------------------------------------------------------------------------


def compute_symmetry_score(
    face: FaceLandmarks,
) -> tuple[float, dict[str, float]]:
    """Score bilateral facial symmetry from detected landmarks.

    The facial midline is the line through landmarks ``MIDLINE_TOP`` and
    ``MIDLINE_BOTTOM``. For every ``(left_idx, right_idx)`` pair in
    ``SYMMETRY_PAIRS``, the first landmark is reflected across that line and
    its Euclidean distance to the second landmark is measured. Distances are
    divided by the distance between landmarks 33 and 263 (outer eye corners),
    so the score does not depend on image scale. Each normalized distance ``d``
    is then mapped to ``100 * exp(-3 * d)``.

    Args:
        face: Landmarks object whose ``pixel_coords`` is an ``(N, 2)`` array
            of pixel positions indexed by face-mesh landmark id.

    Returns:
        ``(overall, region_scores)``, plain Python floats in ``[0, 100]``
        rounded to one decimal (100 = perfectly symmetric). ``overall`` uses
        the mean distance over all pairs. ``region_scores`` maps each region
        of ``SYMMETRY_PAIRS`` to the score of its own pairs. If the two midline
        landmarks coincide, every score is 0.0.
    """
    coords = np.asarray(face.pixel_coords, dtype=np.float64)

    mid_top = coords[MIDLINE_TOP]
    mid_bot = coords[MIDLINE_BOTTOM]

    midline_dir = mid_bot - mid_top
    midline_len = float(np.linalg.norm(midline_dir))
    if midline_len < 1e-6:
        # Degenerate case -- midline landmarks are stacked, no reflection axis.
        return 0.0, {region: 0.0 for region in SYMMETRY_PAIRS}

    midline_unit = midline_dir / midline_len
    # Unit normal to the midline (rotated 90 degrees).
    midline_normal = np.array([midline_unit[1], -midline_unit[0]])

    # Scale reference: outer eye corners (33 <-> 263).
    inter_eye = float(np.linalg.norm(coords[33] - coords[263]))
    if inter_eye < 1e-6:
        inter_eye = midline_len * 0.4  # fallback when the eye corners coincide

    def _to_score(mean_dist: float) -> float:
        # Exponential decay: small asymmetries are penalized gently.
        score = 100.0 * np.exp(-3.0 * mean_dist / inter_eye)
        return round(float(max(0.0, min(100.0, score))), 1)

    region_scores: dict[str, float] = {}
    all_distances: list[float] = []

    for region, pairs in SYMMETRY_PAIRS.items():
        if not pairs:
            region_scores[region] = 0.0
            continue

        idx = np.asarray(pairs, dtype=np.intp)
        left = coords[idx[:, 0]]
        right = coords[idx[:, 1]]

        # Reflect each left point across the midline: remove twice its
        # component along the midline normal (measured from mid_top).
        normal_component = (left - mid_top) @ midline_normal
        reflected = left - 2.0 * np.outer(normal_component, midline_normal)

        dists = np.linalg.norm(reflected - right, axis=1)
        region_scores[region] = _to_score(float(dists.mean()))
        all_distances.extend(dists.tolist())

    overall = _to_score(float(np.mean(all_distances))) if all_distances else 0.0
    return overall, region_scores


def render_symmetry_overlay(
    image_bgr: np.ndarray,
    face: FaceLandmarks,
    region_scores: dict[str, float],
) -> np.ndarray:
    """Return a copy of ``image_bgr`` annotated with symmetry diagnostics.

    Drawing order:
      1. the midline from ``MIDLINE_TOP`` to ``MIDLINE_BOTTOM``, with a
         "midline" caption;
      2. for every region of ``SYMMETRY_PAIRS``, a dot on both landmarks of
         each pair plus a thin line joining them;
      3. a "<region>: <score>" label near each region's landmark centroid.

    Regions are coloured by score: green for >= 80, yellow for >= 50, red
    otherwise. A region missing from ``region_scores`` counts as 0. The
    input image is not modified.
    """
    midline_color = (255, 200, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    label_shift: dict[str, tuple[int, int]] = {
        "eyes": (0, -15),
        "brows": (0, -10),
        "cheeks": (15, 0),
        "mouth": (0, 10),
        "jaw": (0, 15),
    }

    def color_for(value: float) -> tuple[int, int, int]:
        # BGR colours, checked from the highest threshold down.
        for threshold, bgr in ((80, (0, 200, 0)), (50, (0, 200, 220))):
            if value >= threshold:
                return bgr
        return (0, 0, 220)

    canvas = image_bgr.copy()
    pts = face.pixel_coords

    top = pts[MIDLINE_TOP].astype(int)
    bottom = pts[MIDLINE_BOTTOM].astype(int)
    cv2.line(canvas, tuple(top), tuple(bottom), midline_color, 2, cv2.LINE_AA)
    caption_pos = (int(top[0]) + 5, int(top[1]) - 5)
    cv2.putText(canvas, "midline", caption_pos, font, 0.4, midline_color, 1, cv2.LINE_AA)

    # Pass 1: dots and connectors for every region, so the labels drawn in
    # pass 2 always end up on top.
    for region, pairs in SYMMETRY_PAIRS.items():
        color = color_for(region_scores.get(region, 0.0))
        for a, b in pairs:
            left = tuple(pts[a].astype(int))
            right = tuple(pts[b].astype(int))
            cv2.circle(canvas, left, 3, color, -1, cv2.LINE_AA)
            cv2.circle(canvas, right, 3, color, -1, cv2.LINE_AA)
            cv2.line(canvas, left, right, color, 1, cv2.LINE_AA)

    # Pass 2: score labels placed at the region centroid plus a fixed offset.
    for region, pairs in SYMMETRY_PAIRS.items():
        score = region_scores.get(region, 0.0)
        members = [pts[i] for pair in pairs for i in pair]
        cx, cy = np.mean(members, axis=0).astype(int)
        dx, dy = label_shift.get(region, (0, 0))
        cv2.putText(
            canvas,
            f"{region}: {score:.0f}",
            (int(cx) + dx, int(cy) + dy),
            font,
            0.4,
            color_for(score),
            1,
            cv2.LINE_AA,
        )

    return canvas


def _format_symmetry_text(
    overall: float,
    region_scores: dict[str, float],
    prefix: str = "",
) -> str:
    """Render symmetry scores as a monospace-friendly text block.

    Args:
        overall: Overall symmetry score (0-100).
        region_scores: Region name -> score (0-100), listed in dict order.
        prefix: Optional header line emitted first when non-empty.

    Returns:
        Newline-joined lines: the optional prefix, the overall score, then one
        line per region with a 20-character bar (one ``|`` per 5 points).
    """
    bar_width = 20
    lines = [prefix] if prefix else []
    lines.append(f"Overall symmetry: {overall:.1f}/100")
    for region, score in region_scores.items():
        # Clamp so out-of-range scores cannot stretch or shrink the bar.
        filled = min(bar_width, max(0, int(score / 5)))
        bar = "|" * filled + "." * (bar_width - filled)
        lines.append(f"  {region:>6s}: {score:5.1f}  [{bar}]")
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Symmetry tab callbacks
# ---------------------------------------------------------------------------


def analyze_symmetry(image_rgb: np.ndarray):
    """Gradio callback: score the bilateral symmetry of a single photo.

    Returns:
        ``(image, message)``. On success this is the annotated RGB overlay
        and the score summary. When no face is found it is the preprocessed
        photo and a detection hint. When the input is missing or cannot be
        converted it is a blank canvas and an explanation.
    """
    if image_rgb is None:
        return _blank(), "Upload a face photo to analyze symmetry."

    try:
        bgr, rgb = _prepare_image(image_rgb, 512)
    except Exception as exc:
        logger.error("Image conversion failed: %s", exc)
        return _blank(), f"Image conversion failed: {exc}"

    face, hint = _detect_face_with_hints(bgr)
    if face is None:
        return rgb, hint

    overall, per_region = compute_symmetry_score(face)
    annotated = cv2.cvtColor(
        render_symmetry_overlay(bgr, face, per_region), cv2.COLOR_BGR2RGB
    )
    return annotated, _format_symmetry_text(overall, per_region)


def analyze_symmetry_comparison(
    pre_image_rgb: np.ndarray,
    post_image_rgb: np.ndarray,
):
    """Compare facial symmetry between pre- and post-procedure photos.

    Args:
        pre_image_rgb: Photo taken before the procedure (RGB array), or None.
        post_image_rgb: Photo taken after the procedure (RGB array), or None.

    Returns:
        ``(pre_overlay, post_overlay, report)``. On success both overlays are
        RGB images annotated by ``render_symmetry_overlay``, and ``report``
        lists both score blocks, the overall change (labelled improved,
        decreased or unchanged) and the per-region deltas. On missing input,
        conversion failure or failed face detection, both overlays are blank
        canvases and ``report`` explains the problem.
    """
    b = _blank()
    if pre_image_rgb is None or post_image_rgb is None:
        return b, b, "Upload both pre and post photos to compare."

    try:
        pre_bgr, _ = _prepare_image(pre_image_rgb, 512)
        post_bgr, _ = _prepare_image(post_image_rgb, 512)
    except Exception as exc:
        logger.error("Image conversion failed: %s", exc)
        return b, b, f"Image conversion failed: {exc}"

    pre_face, pre_hint = _detect_face_with_hints(pre_bgr)
    if pre_face is None:
        return b, b, f"Pre-procedure: {pre_hint}"

    post_face, post_hint = _detect_face_with_hints(post_bgr)
    if post_face is None:
        return b, b, f"Post-procedure: {post_hint}"

    pre_overall, pre_regions = compute_symmetry_score(pre_face)
    post_overall, post_regions = compute_symmetry_score(post_face)

    pre_rgb = cv2.cvtColor(
        render_symmetry_overlay(pre_bgr, pre_face, pre_regions), cv2.COLOR_BGR2RGB
    )
    post_rgb = cv2.cvtColor(
        render_symmetry_overlay(post_bgr, post_face, post_regions), cv2.COLOR_BGR2RGB
    )

    # Classify the change at the displayed 0.1 resolution so a delta shown as
    # 0.0 reads "unchanged" (it used to read "decreased"); "+ 0.0" also turns
    # a possible -0.0 into +0.0 for display.
    delta = round(post_overall - pre_overall, 1) + 0.0
    if delta > 0:
        direction = "improved"
    elif delta < 0:
        direction = "decreased"
    else:
        direction = "unchanged"

    lines = [
        _format_symmetry_text(pre_overall, pre_regions, prefix="-- Pre-procedure --"),
        "",
        _format_symmetry_text(post_overall, post_regions, prefix="-- Post-procedure --"),
        "",
        f"Change: {delta:+.1f} ({direction})",
    ]

    # Per-region deltas (a region missing after the procedure counts as 0).
    for region, pre_score in pre_regions.items():
        d = post_regions.get(region, 0.0) - pre_score
        lines.append(f"  {region:>6s}: {d:+.1f}")

    return pre_rgb, post_rgb, "\n".join(lines)


# ---------------------------------------------------------------------------
# Core pipeline functions
# ---------------------------------------------------------------------------


def warp_image_tps(image, src_pts, dst_pts):
    """Warp ``image`` so the landmarks at ``src_pts`` move to ``dst_pts``.

    Thin-plate-spline warp on the CPU. This delegates to
    ``landmarkdiff.synthetic.tps_warp.warp_image_tps``, which is imported
    lazily at call time.
    """
    from landmarkdiff.synthetic.tps_warp import warp_image_tps as tps_impl

    warped = tps_impl(image, src_pts, dst_pts)
    return warped


def resize_preserve_aspect(image, size=512):
    """Letterbox ``image`` onto a black ``size`` x ``size`` uint8 canvas.

    The longer side is scaled to exactly ``size``. The shorter side is scaled
    proportionally (floored, at least 1 px). The resized image is centred and
    the remaining border is zero-filled.

    Args:
        image: Input image array, ``(H, W)`` or ``(H, W, C)``.
        size: Side length of the square output canvas in pixels.

    Returns:
        A ``uint8`` array of shape ``(size, size)`` plus whatever channel
        axis ``cv2.resize`` returns for the input.

    Raises:
        ValueError: if ``size`` is not positive or ``image`` has a zero-length
            dimension.
    """
    h, w = image.shape[:2]
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    if h == 0 or w == 0:
        raise ValueError(f"Cannot resize an empty image of shape {image.shape}")

    # Integer arithmetic keeps the long side exactly `size`. The float form
    # int(h * (size / h)) can land just below `size` (e.g. size=512 with a
    # 784-px side gives 511), leaving a stray black row or column.
    if w >= h:
        new_w, new_h = size, max(1, h * size // w)
    else:
        new_w, new_h = max(1, w * size // h), size

    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
    # Match the channel layout of the resized image instead of assuming 3 channels.
    canvas = np.zeros((size, size) + resized.shape[2:], dtype=np.uint8)
    y_off = (size - new_h) // 2
    x_off = (size - new_w) // 2
    canvas[y_off : y_off + new_h, x_off : x_off + new_w] = resized
    return canvas


def mask_composite(warped, original, mask):
    """Alpha-blend ``warped`` over ``original`` using ``mask`` as per-pixel alpha.

    Args:
        warped: Foreground image, shape ``(H, W, C)``.
        original: Background image, same shape as ``warped``.
        mask: Alpha weights, either ``(H, W)`` (broadcast across all channels)
            or already shaped to broadcast against the images. 1 selects
            ``warped`` and 0 selects ``original``.

    Returns:
        The blended ``uint8`` image. Values are rounded to the nearest integer
        instead of truncated, and clipped to ``[0, 255]`` so alpha values
        slightly outside ``[0, 1]`` cannot wrap around on the uint8 cast.
    """
    alpha = np.asarray(mask, dtype=np.float64)
    if alpha.ndim == 2:
        # Trailing channel axis lets one mask broadcast over any channel count.
        alpha = alpha[..., np.newaxis]
    blended = warped * alpha + original * (1.0 - alpha)
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


def _blank():
    """Return a new all-black 512x512 RGB placeholder image (``uint8``).

    A fresh array is allocated on every call.
    """
    height = width = 512
    return np.zeros((height, width, 3), np.uint8)


def process_image(image_rgb, procedure, intensity):
    """Gradio callback: run the TPS pipeline for one procedure and report on it.

    Args:
        image_rgb: Uploaded photo (RGB array) or None.
        procedure: Procedure key passed to ``apply_procedure_preset`` and
            ``generate_surgical_mask``.
        intensity: Procedure strength; shown as a percentage in the report.

    Returns:
        5-tuple ``(wireframe, mask, result, original, info)``: four RGB images
        and a multi-line text report (procedure, detection, symmetry before
        and after, timing). On missing input or an error, the image slots hold
        blank canvases, or the preprocessed photo when no face is detected,
        and ``info`` explains why.
    """
    if image_rgb is None:
        blank = _blank()
        return blank, blank, blank, blank, "Upload a face photo to begin."

    start = time.monotonic()

    try:
        image_bgr, image_rgb_512 = _prepare_image(image_rgb, 512)
    except Exception as exc:
        logger.error("Image conversion failed: %s", exc)
        blank = _blank()
        return blank, blank, blank, blank, f"Image conversion failed: {exc}"

    face, hint = _detect_face_with_hints(image_bgr)
    if face is None:
        message = hint or "No face detected. Try a clearer, well-lit frontal photo."
        return (image_rgb_512,) * 4 + (message,)

    try:
        target = apply_procedure_preset(face, procedure, float(intensity), image_size=512)
        wire_rgb = cv2.cvtColor(
            render_wireframe(target, width=512, height=512), cv2.COLOR_GRAY2RGB
        )

        mask = generate_surgical_mask(face, procedure, 512, 512)
        mask_rgb = cv2.cvtColor((mask * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)

        warped = warp_image_tps(image_bgr, face.pixel_coords, target.pixel_coords)
        result_rgb = cv2.cvtColor(
            mask_composite(warped, image_bgr, mask), cv2.COLOR_BGR2RGB
        )

        shifts = np.linalg.norm(target.pixel_coords - face.pixel_coords, axis=1)
        displacement = np.mean(shifts)
        # Timing is taken before the symmetry analysis below.
        elapsed = time.monotonic() - start

        # Symmetry of the detected face vs. the manipulated landmark set.
        pre_overall, _ = compute_symmetry_score(face)
        post_overall, _ = compute_symmetry_score(target)
        sym_delta = post_overall - pre_overall
        sym_arrow = "+" if sym_delta > 0 else ""

        info_lines = [
            "--- Procedure ---",
            f"  Type:          {procedure.replace('_', ' ').title()}",
            f"  Intensity:     {intensity:.0f}%",
            f"  Description:   {PROCEDURE_INFO.get(procedure, '')}",
            "",
            "--- Detection ---",
            f"  Landmarks:     {len(face.landmarks)} points",
            f"  Confidence:    {face.confidence:.2f}",
            f"  Avg shift:     {displacement:.1f} px",
            "",
            "--- Symmetry ---",
            f"  Before:        {pre_overall:.1f} / 100",
            f"  After:         {post_overall:.1f} / 100",
            f"  Change:        {sym_arrow}{sym_delta:.1f}",
            "",
            "--- Performance ---",
            f"  Time:          {elapsed:.2f}s",
            "  Mode:          TPS (CPU)",
        ]
        return wire_rgb, mask_rgb, result_rgb, image_rgb_512, "\n".join(info_lines)

    except Exception as exc:
        logger.error("Processing failed: %s\n%s", exc, traceback.format_exc())
        blank = _blank()
        return blank, blank, blank, blank, f"Processing error: {exc}"


def compare_procedures(image_rgb, intensity):
    """Gradio callback: render every procedure in PROCEDURES at one intensity.

    Returns:
        One RGB image per entry of ``PROCEDURES``, in the same order. The
        list is all blank canvases for missing input or on any error. The
        preprocessed photo is repeated when no face is detected.
    """
    count = len(PROCEDURES)
    if image_rgb is None:
        return [_blank()] * count

    try:
        image_bgr, _ = _prepare_image(image_rgb, 512)
        face, _ = _detect_face_with_hints(image_bgr)
        if face is None:
            return [cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)] * count

        def render(proc):
            # Warp toward the procedure's target landmarks, then blend the
            # warp back into the photo only inside that procedure's mask.
            target = apply_procedure_preset(face, proc, float(intensity), image_size=512)
            region = generate_surgical_mask(face, proc, 512, 512)
            warped = warp_image_tps(image_bgr, face.pixel_coords, target.pixel_coords)
            return cv2.cvtColor(mask_composite(warped, image_bgr, region), cv2.COLOR_BGR2RGB)

        return [render(proc) for proc in PROCEDURES]
    except Exception as exc:
        logger.error("Compare failed: %s\n%s", exc, traceback.format_exc())
        return [_blank()] * count


def intensity_sweep(image_rgb, procedure):
    """Generate results at 0%, 20%, 40%, 60%, 80%, 100% intensity.

    Args:
        image_rgb: The uploaded photo as an RGB array, or ``None``.
        procedure: Procedure name passed through to the preset and mask helpers.

    Returns:
        A list of ``(rgb_image, caption)`` pairs for a Gradio gallery. The 0%
        entry is the prepared input photo itself. Returns an empty list when
        there is no input, no face is detected, or processing fails.
    """
    if image_rgb is None:
        return []

    try:
        image_bgr, _ = _prepare_image(image_rgb, 512)
        face, _ = _detect_face_with_hints(image_bgr)
        if face is None:
            logger.warning("Sweep skipped: no face detected")
            return []

        # None of the mask's arguments depend on intensity, so build it once
        # instead of once per step. This also keeps every frame composited
        # with the same mask. Assumes mask_composite does not modify the mask
        # in place -- NOTE(review): confirm if that helper changes.
        mask = generate_surgical_mask(face, procedure, 512, 512)

        # 0% is by definition the unmodified photo; skip the warp for it.
        results = [(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), "0%")]
        for val in (20, 40, 60, 80, 100):
            manip = apply_procedure_preset(face, procedure, float(val), image_size=512)
            warped = warp_image_tps(image_bgr, face.pixel_coords, manip.pixel_coords)
            comp = mask_composite(warped, image_bgr, mask)
            results.append((cv2.cvtColor(comp, cv2.COLOR_BGR2RGB), f"{val}%"))
        return results
    except Exception as exc:
        # logger.exception records the active traceback automatically.
        logger.exception("Sweep failed: %s", exc)
        return []


# ---------------------------------------------------------------------------
# Build the Gradio UI
# ---------------------------------------------------------------------------

# Markdown table rows, one per procedure: "| Pretty Name | description |".
_proc_table = "\n".join(
    "| {} | {} |".format(proc_key.replace("_", " ").title(), proc_desc)
    for proc_key, proc_desc in PROCEDURE_INFO.items()
)

# Custom stylesheet passed to gr.Blocks(css=...). ".header-banner" and
# ".link-bar" style the gr.HTML header at the top of the page. ".info-output"
# is attached to the Results textbox via elem_classes so its text renders in a
# monospace font.
_CSS = """
.header-banner {
    background: linear-gradient(135deg, #1a1a2e 0%, #16213e 50%, #0f3460 100%);
    border-radius: 12px;
    padding: 24px 32px;
    margin-bottom: 8px;
    color: white;
}
.header-banner h1 { color: white !important; margin-bottom: 4px; font-size: 2em; }
.header-banner p { color: #ccd; margin: 4px 0; font-size: 0.95em; }
.header-banner a { color: #7eb8f7; text-decoration: none; }
.header-banner a:hover { text-decoration: underline; }
.link-bar { display: flex; gap: 16px; margin-top: 10px; flex-wrap: wrap; }
.info-output textarea {
    font-family: 'SF Mono', 'Fira Code', 'Consolas', monospace !important;
    font-size: 0.88em !important;
    line-height: 1.6 !important;
}
"""

def _add_examples(target, label):
    """Attach the bundled example photos to ``target`` under ``label``.

    Renders into whichever Blocks layout context is active when called. Does
    nothing when ``EXAMPLE_IMAGES`` is empty.
    """
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=[[str(p)] for p in EXAMPLE_IMAGES],
            inputs=[target],
            label=label,
        )


# Assemble the Gradio interface as the module-level ``demo`` object.
with gr.Blocks(
    title="LandmarkDiff -- Facial Surgery Prediction",
    theme=gr.themes.Soft(),
    css=_CSS,
) as demo:
    gr.HTML(
        f"""<div class="header-banner">
        <h1>LandmarkDiff</h1>
        <p>
            Anatomically-conditioned facial surgery outcome prediction from standard clinical
            photography. Upload a face photo, select a procedure, adjust intensity, and see
            the predicted result in real time.
        </p>
        <p style="font-size:0.85em; color:#aab;">
            Powered by MediaPipe 478-point face mesh, thin-plate spline warping, and
            procedure-specific anatomical displacement models. Runs entirely on CPU.
            This 2D demo is the foundation -- 3D face reconstruction from phone video
            is on the roadmap.
        </p>
        <div class="link-bar">
            <a href="{GITHUB_URL}">GitHub</a>
            <a href="{GITHUB_URL}/tree/main/docs">Documentation</a>
            <a href="{GITHUB_URL}/wiki">Wiki</a>
            <a href="{GITHUB_URL}/discussions">Discussions</a>
        </div>
        </div>"""
    )

    # -- Tab 1: Single Procedure --
    with gr.Tab("Single Procedure"):
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(label="Face Photo", type="numpy", height=350)
                procedure = gr.Radio(
                    choices=PROCEDURES,
                    value="rhinoplasty",
                    label="Procedure",
                    info="Select a surgical procedure to simulate",
                )
                # Show a brief description for each procedure. The text sits
                # inside a raw HTML <div>, where Markdown emphasis (**...**) is
                # not parsed, so the names are bolded with <b> tags instead.
                _proc_desc_md = " | ".join(
                    f"<b>{k.replace('_', ' ').title()}</b>: {v}" for k, v in PROCEDURE_INFO.items()
                )
                gr.Markdown(
                    f"<div style='font-size:0.82em;color:#666;line-height:1.5;'>"
                    f"{_proc_desc_md}</div>"
                )
                intensity = gr.Slider(
                    0,
                    100,
                    50,
                    step=1,
                    label="Intensity (%)",
                    info="0 = no change, 100 = maximum effect",
                )
                run_btn = gr.Button("Generate Prediction", variant="primary", size="lg")
                info_box = gr.Textbox(
                    label="Results",
                    lines=10,
                    interactive=False,
                    elem_classes=["info-output"],
                )

            with gr.Column(scale=2):
                with gr.Row():
                    out_wireframe = gr.Image(label="Deformed Wireframe", height=256)
                    out_mask = gr.Image(label="Surgical Mask", height=256)
                with gr.Row():
                    out_result = gr.Image(label="Predicted Result", height=256)
                    out_original = gr.Image(label="Original", height=256)

        _add_examples(input_image, "Example faces (click to load)")

        with gr.Accordion("Photo Tips for Best Results", open=False):
            gr.Markdown(
                "- **Front-facing**: Use a straight-on frontal photo, "
                "not a side profile\n"
                "- **Good lighting**: Even, natural lighting works best. "
                "Avoid harsh shadows\n"
                "- **Neutral expression**: Keep a relaxed, neutral face "
                "for accurate landmark detection\n"
                "- **No obstructions**: Remove glasses, hats, or anything "
                "covering the face\n"
                "- **Resolution**: At least 256x256 pixels. The image will "
                "be resized to 512x512 internally\n"
                "- **Single face**: Make sure only one face is clearly "
                "visible in the frame"
            )

        outputs = [out_wireframe, out_mask, out_result, out_original, info_box]
        _inputs = [input_image, procedure, intensity]
        run_btn.click(fn=process_image, inputs=_inputs, outputs=outputs)
        # Auto-trigger on image upload and procedure change, but not on every
        # slider tick during drag (each tick would re-run TPS on free CPU,
        # causing severe lag). Use .release so it fires once on mouse-up.
        input_image.change(fn=process_image, inputs=_inputs, outputs=outputs)
        procedure.change(fn=process_image, inputs=_inputs, outputs=outputs)
        intensity.release(fn=process_image, inputs=_inputs, outputs=outputs)

    # -- Tab 2: Compare Procedures --
    with gr.Tab("Compare All"):
        gr.Markdown("All six procedures at the same intensity, side by side.")
        with gr.Row():
            with gr.Column(scale=1):
                cmp_image = gr.Image(label="Face Photo", type="numpy", height=300)
                cmp_intensity = gr.Slider(0, 100, 50, step=1, label="Intensity (%)")
                cmp_btn = gr.Button("Compare", variant="primary", size="lg")
            with gr.Column(scale=2):
                # One output slot per procedure, three per row. The slot count
                # must equal the length of the list compare_procedures returns
                # (len(PROCEDURES)), so derive the grid from PROCEDURES.
                cmp_outputs = []
                for row_start in range(0, len(PROCEDURES), 3):
                    with gr.Row():
                        for proc_name in PROCEDURES[row_start : row_start + 3]:
                            cmp_outputs.append(
                                gr.Image(
                                    label=proc_name.replace("_", " ").title(),
                                    height=200,
                                )
                            )

        _add_examples(cmp_image, "Examples")

        cmp_btn.click(fn=compare_procedures, inputs=[cmp_image, cmp_intensity], outputs=cmp_outputs)

    # -- Tab 3: Intensity Sweep --
    with gr.Tab("Intensity Sweep"):
        gr.Markdown("See a procedure at 0% through 100% in six steps.")
        with gr.Row():
            with gr.Column(scale=1):
                sweep_image = gr.Image(label="Face Photo", type="numpy", height=300)
                sweep_proc = gr.Radio(choices=PROCEDURES, value="rhinoplasty", label="Procedure")
                sweep_btn = gr.Button("Sweep", variant="primary", size="lg")
            with gr.Column(scale=2):
                sweep_gallery = gr.Gallery(label="0% to 100%", columns=3, height=400)

        _add_examples(sweep_image, "Examples")

        sweep_btn.click(
            fn=intensity_sweep,
            inputs=[sweep_image, sweep_proc],
            outputs=[sweep_gallery],
        )

    # -- Tab 4: Symmetry Analysis --
    with gr.Tab("Symmetry Analysis"):
        gr.Markdown(
            "### Bilateral Facial Symmetry\n\n"
            "Analyzes left-right symmetry across five facial regions "
            "(eyes, brows, cheeks, mouth, jaw) using MediaPipe 478-point "
            "face mesh landmark pairs. The midline is computed from the "
            "forehead apex to the chin, and left landmarks are reflected "
            "across it to measure deviation from the right side.\n\n"
            "**Score interpretation:** 90-100 = highly symmetric, "
            "70-89 = mild asymmetry, <70 = notable asymmetry."
        )

        with gr.Tabs():
            # Sub-tab: Single image analysis
            with gr.Tab("Single Photo"):
                with gr.Row():
                    with gr.Column(scale=1):
                        sym_image = gr.Image(
                            label="Face Photo",
                            type="numpy",
                            height=350,
                        )
                        sym_btn = gr.Button(
                            "Analyze Symmetry",
                            variant="primary",
                            size="lg",
                        )
                    with gr.Column(scale=1):
                        sym_overlay = gr.Image(label="Symmetry Overlay", height=350)

                sym_scores_box = gr.Textbox(
                    label="Symmetry Scores",
                    lines=8,
                    interactive=False,
                )

                _add_examples(sym_image, "Examples")

                sym_btn.click(
                    fn=analyze_symmetry,
                    inputs=[sym_image],
                    outputs=[sym_overlay, sym_scores_box],
                )

            # Sub-tab: Pre vs post comparison
            with gr.Tab("Pre vs Post Comparison"):
                gr.Markdown(
                    "Upload a pre-procedure and post-procedure photo to compare "
                    "how symmetry changed."
                )
                with gr.Row():
                    sym_pre_image = gr.Image(
                        label="Pre-Procedure",
                        type="numpy",
                        height=300,
                    )
                    sym_post_image = gr.Image(
                        label="Post-Procedure",
                        type="numpy",
                        height=300,
                    )
                sym_cmp_btn = gr.Button(
                    "Compare Symmetry",
                    variant="primary",
                    size="lg",
                )
                with gr.Row():
                    sym_pre_overlay = gr.Image(
                        label="Pre Symmetry Overlay",
                        height=300,
                    )
                    sym_post_overlay = gr.Image(
                        label="Post Symmetry Overlay",
                        height=300,
                    )
                sym_cmp_box = gr.Textbox(
                    label="Comparison",
                    lines=14,
                    interactive=False,
                )

                sym_cmp_btn.click(
                    fn=analyze_symmetry_comparison,
                    inputs=[sym_pre_image, sym_post_image],
                    outputs=[sym_pre_overlay, sym_post_overlay, sym_cmp_box],
                )

    gr.HTML(
        f"<div style='text-align:center;color:#888;font-size:0.78em;padding:12px 8px;"
        f"border-top:1px solid #eee;margin-top:12px;'>"
        f"LandmarkDiff v0.2.2 &middot; TPS on CPU &middot; "
        f"MediaPipe 478-point mesh &middot; "
        f"<a href='{GITHUB_URL}' style='color:#7eb8f7;'>GitHub</a> &middot; "
        f"MIT License &middot; For research and educational purposes only"
        f"</div>"
    )

def _launch_demo():
    """Start the Gradio server for ``demo``.

    ``show_error=True`` makes Gradio report handler exceptions in the browser.
    """
    demo.launch(show_error=True)


if __name__ == "__main__":
    _launch_demo()