File size: 25,666 Bytes
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eac09b2
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
 
 
 
 
 
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
f790aa8
 
 
 
 
 
 
 
 
 
 
 
 
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433e26f
eac09b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
"""Data-driven surgical displacement extraction and modeling.

Extracts real landmark displacements from before/after surgery image pairs,
classifies procedures based on regional displacement patterns, and fits
per-procedure statistical models that can replace the hand-tuned RBF
displacement vectors in ``manipulation.py``.

Typical usage::

    from landmarkdiff.displacement_model import (
        extract_displacements,
        extract_from_directory,
        DisplacementModel,
    )

    # Single pair
    result = extract_displacements(before_img, after_img)

    # Batch from directory
    all_displacements = extract_from_directory("data/surgery_pairs/")

    # Fit model
    model = DisplacementModel()
    model.fit(all_displacements)
    model.save("displacement_model.npz")

    # Generate displacement field
    field = model.get_displacement_field("rhinoplasty", intensity=0.7)
"""

from __future__ import annotations

import json
import logging
from pathlib import Path

import cv2
import numpy as np

from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks
from landmarkdiff.manipulation import PROCEDURE_LANDMARKS

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Number of MediaPipe Face Mesh landmarks (468 face + 10 iris).
# Used throughout this module as the expected row count of landmark and
# displacement arrays.
NUM_LANDMARKS = 478

# Names of all supported procedures, i.e. the keys of PROCEDURE_LANDMARKS
# (imported from landmarkdiff.manipulation).
PROCEDURES = list(PROCEDURE_LANDMARKS.keys())


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _normalized_coords_2d(face: FaceLandmarks) -> np.ndarray:
    """Return an independent copy of the x/y columns of ``face.landmarks``.

    ``FaceLandmarks.landmarks`` is expected to be a (478, 3) array holding
    (x, y, z) in normalized [0, 1] space; the z column is dropped here.

    Args:
        face: Landmark container exposing a 2-D ``landmarks`` array.

    Returns:
        New array with the first two columns of ``face.landmarks``. Writing
        to it does not affect ``face``.
    """
    xy_view = face.landmarks[:, :2]
    # The slice above is only a view; hand back a detached copy.
    return xy_view.copy()


def _compute_alignment_quality(
    landmarks_before: np.ndarray,
    landmarks_after: np.ndarray,
) -> float:
    """Estimate alignment quality between two landmark sets.

    Computes the RMS displacement of landmarks that should *not* move during
    surgery (forehead, temples, outer cheek anchors) and maps it linearly to
    a score. No Procrustes fit is performed: the raw coordinates are compared
    directly, so any pose, scale or framing difference between the two
    images shows up as residual motion and lowers the score.

    A score of 1.0 means the stable landmarks coincide exactly; the score
    reaches 0.0 once their RMS displacement is 0.05 (5% of the normalized
    image extent) or more.

    Args:
        landmarks_before: (478, 2) normalized coordinates.
        landmarks_after: (478, 2) normalized coordinates.

    Returns:
        Quality score in [0, 1]. Returns 0.0 when the score cannot be
        computed: no stable landmark index fits inside both arrays, or the
        RMS is not finite (e.g. NaN coordinates).
    """
    # Stable landmarks: forehead, temple region, outer face oval.
    # These should exhibit near-zero displacement after surgery.
    stable_indices = (
        10, 109, 67, 103, 54, 21, 162, 127,  # left forehead/temple
        338, 297, 332, 284, 251, 389, 356, 454,  # right forehead/temple
        234, 93,  # outer cheek anchors
    )

    before = np.asarray(landmarks_before)
    after = np.asarray(landmarks_after)

    # Guard against the rows actually present rather than a fixed constant,
    # so shorter landmark arrays do not raise IndexError.
    n_rows = min(len(before), len(after))
    usable = [i for i in stable_indices if i < n_rows]
    if not usable:
        return 0.0

    # RMS displacement on the stable points.
    diffs = after[usable] - before[usable]
    rms = np.sqrt(np.mean(np.sum(diffs**2, axis=1)))

    # A NaN score would pass any "quality < threshold" rejection test
    # (comparisons with NaN are False), so report the worst score instead.
    if not np.isfinite(rms):
        return 0.0

    # Map RMS to quality: 0 displacement -> 1.0, rms >= 0.05 -> 0.0.
    return float(np.clip(1.0 - rms / 0.05, 0.0, 1.0))


# ---------------------------------------------------------------------------
# Procedure classification
# ---------------------------------------------------------------------------


def classify_procedure(displacements: np.ndarray) -> str:
    """Guess the surgical procedure from a set of displacement vectors.

    For every procedure in ``PROCEDURE_LANDMARKS`` the mean displacement
    magnitude over that procedure's landmark indices is computed; the
    procedure with the strictly largest mean wins (the first one seen wins
    ties).

    Args:
        displacements: (478, 2) displacement vectors (after - before) in
            normalized coordinate space.

    Returns:
        Procedure name string, one of ``PROCEDURES``, or ``"unknown"`` if
        no region shows significant displacement.
    """
    per_landmark_mag = np.linalg.norm(displacements, axis=1)
    n_points = len(per_landmark_mag)

    winner = "unknown"
    top_score = 0.0

    for name, region in PROCEDURE_LANDMARKS.items():
        # Drop indices that fall outside the supplied array.
        in_range = [idx for idx in region if idx < n_points]
        if in_range:
            # Region score = mean magnitude over the region's landmarks.
            region_score = float(per_landmark_mag[in_range].mean())
            if region_score > top_score:
                winner, top_score = name, region_score

    # Negligible motion everywhere -> "unknown".
    # Threshold: mean displacement < 0.002 (~1 pixel at 512x512).
    if top_score < 0.002:
        logger.debug(
            "No significant displacement detected (best=%.5f). Classified as 'unknown'.",
            top_score,
        )
        return "unknown"

    return winner


# ---------------------------------------------------------------------------
# Single-pair extraction
# ---------------------------------------------------------------------------


def extract_displacements(
    before_img: np.ndarray,
    after_img: np.ndarray,
    min_detection_confidence: float = 0.5,
) -> dict | None:
    """Compute landmark displacements for one before/after image pair.

    Landmarks are detected on each image with ``extract_landmarks``. The
    after image is only processed if detection succeeded on the before
    image. The 2-D normalized coordinates are then subtracted to get
    per-landmark displacement vectors, which are classified and scored.

    Args:
        before_img: Pre-surgery BGR image as numpy array.
        after_img: Post-surgery BGR image as numpy array.
        min_detection_confidence: Minimum face detection confidence,
            forwarded to ``extract_landmarks`` (default 0.5).

    Returns:
        Dictionary with keys:
            - ``landmarks_before``: (478, 2) normalized coordinates
            - ``landmarks_after``: (478, 2) normalized coordinates
            - ``displacements``: (478, 2) displacement vectors
            - ``magnitude``: (478,) per-landmark displacement magnitudes
            - ``procedure``: classified procedure name or ``"unknown"``
            - ``quality_score``: float in [0, 1] indicating alignment quality

        Returns ``None`` if face detection fails on either image.
    """
    detect_kwargs = {"min_detection_confidence": min_detection_confidence}

    pre_face = extract_landmarks(before_img, **detect_kwargs)
    if pre_face is None:
        logger.warning("Face detection failed on before image.")
        return None

    post_face = extract_landmarks(after_img, **detect_kwargs)
    if post_face is None:
        logger.warning("Face detection failed on after image.")
        return None

    # Work in normalized 2-D space; displacement is after minus before.
    xy_pre = _normalized_coords_2d(pre_face)
    xy_post = _normalized_coords_2d(post_face)
    delta = xy_post - xy_pre

    return {
        "landmarks_before": xy_pre,
        "landmarks_after": xy_post,
        "displacements": delta,
        "magnitude": np.linalg.norm(delta, axis=1),
        "procedure": classify_procedure(delta),
        "quality_score": _compute_alignment_quality(xy_pre, xy_post),
    }


# ---------------------------------------------------------------------------
# Batch extraction from directory
# ---------------------------------------------------------------------------


def extract_from_directory(
    pairs_dir: str | Path,
    min_detection_confidence: float = 0.5,
    min_quality: float = 0.0,
) -> list[dict]:
    """Run ``extract_displacements`` over every image pair in a directory.

    Two naming conventions are recognised (stems are matched
    case-insensitively):
        - ``<name>_before.{png,jpg,...}`` / ``<name>_after.{png,jpg,...}``
        - ``<name>_input.{png,jpg,...}`` / ``<name>_target.{png,jpg,...}``

    Each base name is paired at most once, and pairs are processed in sorted
    order of their (lower-cased) base name. Pairs whose images cannot be
    read, whose faces cannot be detected, or whose quality score is below
    ``min_quality`` are skipped with a log message.

    Args:
        pairs_dir: Path to directory containing image pairs.
        min_detection_confidence: Passed to ``extract_displacements``.
        min_quality: Minimum alignment quality score to include a pair
            in the results (default 0.0 = include all).

    Returns:
        List of displacement dictionaries (same format as
        ``extract_displacements``), each augmented with:
            - ``pair_name``: lower-cased base name of the pair
            - ``before_path``: path to the before image
            - ``after_path``: path to the after image

    Raises:
        FileNotFoundError: If ``pairs_dir`` is not an existing directory.
    """
    root = Path(pairs_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"Directory not found: {root}")

    # Index image files by lower-cased stem (a later file with the same
    # stem replaces an earlier one).
    known_suffixes = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".webp"}
    by_stem: dict[str, Path] = {}
    for entry in root.iterdir():
        if entry.is_file() and entry.suffix.lower() in known_suffixes:
            by_stem[entry.stem.lower()] = entry

    # Match each "before"-style file with its "after"-style partner.
    conventions = (("_before", "_after"), ("_input", "_target"))
    found: list[tuple[str, Path, Path]] = []
    claimed: set[str] = set()

    for stem, first_path in by_stem.items():
        for lead_tag, follow_tag in conventions:
            if not stem.endswith(lead_tag):
                continue
            base = stem[: len(stem) - len(lead_tag)]
            partner = by_stem.get(base + follow_tag)
            if partner is None or base in claimed:
                continue
            # Stored paths keep their original on-disk casing.
            found.append((base, first_path, partner))
            claimed.add(base)

    if not found:
        logger.warning("No image pairs found in %s", root)
        return []

    logger.info("Found %d image pairs in %s", len(found), root)

    extracted: list[dict] = []
    for pair_name, before_path, after_path in sorted(found):
        logger.info("Processing pair: %s", pair_name)

        before_img = cv2.imread(str(before_path))
        if before_img is None:
            logger.warning("Failed to load before image: %s", before_path)
            continue

        after_img = cv2.imread(str(after_path))
        if after_img is None:
            logger.warning("Failed to load after image: %s", after_path)
            continue

        record = extract_displacements(
            before_img, after_img, min_detection_confidence=min_detection_confidence
        )
        if record is None:
            logger.warning("Skipping pair %s: face detection failed.", pair_name)
            continue

        score = record["quality_score"]
        if score < min_quality:
            logger.info(
                "Skipping pair %s: quality %.3f < threshold %.3f",
                pair_name,
                score,
                min_quality,
            )
            continue

        # Attach provenance so callers can trace each record to its files.
        record.update(
            pair_name=pair_name,
            before_path=str(before_path),
            after_path=str(after_path),
        )
        extracted.append(record)

    logger.info(
        "Successfully extracted %d / %d pairs (%.0f%%)",
        len(extracted),
        len(found),
        100.0 * len(extracted) / max(len(found), 1),
    )
    return extracted


# ---------------------------------------------------------------------------
# Displacement model
# ---------------------------------------------------------------------------


class DisplacementModel:
    """Statistical model of per-procedure surgical displacements.

    Aggregates displacement vectors from multiple before/after pairs and
    computes per-procedure, per-landmark statistics (mean, std, min, max,
    median and mean magnitude). Can then generate displacement fields for
    use in the conditioning pipeline, replacing hand-tuned RBF vectors.

    Attributes:
        procedures: List of procedure names the model has data for.
        stats: Nested dict ``{procedure: {stat_name: array}}``.
        n_samples: Dict ``{procedure: int}`` sample counts.
    """

    def __init__(self) -> None:
        """Create an empty, unfitted model."""
        self.stats: dict[str, dict[str, np.ndarray]] = {}
        self.n_samples: dict[str, int] = {}
        self._fitted = False

    @property
    def procedures(self) -> list[str]:
        """Return list of procedures the model has been fitted on."""
        return list(self.stats.keys())

    @property
    def fitted(self) -> bool:
        """Whether the model has been fitted."""
        return self._fitted

    def fit(self, displacement_list: list[dict]) -> None:
        """Fit the model from a list of extracted displacement dictionaries.

        Groups displacements by classified procedure and computes per-landmark
        statistics for each group. Any previously fitted statistics are
        discarded. Entries classified as ``"unknown"`` (or lacking a
        ``"procedure"`` key) form their own ``"unknown"`` group.

        Args:
            displacement_list: List of dicts as returned by
                ``extract_displacements`` or ``extract_from_directory``.
                Each must contain ``"displacements"`` (478, 2) and
                ``"procedure"`` (str) keys. Array-likes (e.g. nested lists)
                are accepted for ``"displacements"``.

        Raises:
            ValueError: If ``displacement_list`` is empty or contains no
                valid displacement data.
        """
        if not displacement_list:
            raise ValueError("displacement_list is empty.")

        # Group by procedure
        procedure_groups: dict[str, list[np.ndarray]] = {}
        for entry in displacement_list:
            proc = entry.get("procedure", "unknown")
            disp = entry.get("displacements")
            if disp is None:
                logger.warning("Skipping entry without 'displacements' key.")
                continue

            # Coerce so plain lists are validated by shape instead of
            # failing with AttributeError on ``.shape``.
            disp = np.asarray(disp)
            if disp.shape != (NUM_LANDMARKS, 2):
                logger.warning(
                    "Skipping entry with unexpected shape %s (expected (%d, 2)).",
                    disp.shape,
                    NUM_LANDMARKS,
                )
                continue

            procedure_groups.setdefault(proc, []).append(disp)

        if not procedure_groups:
            raise ValueError("No valid displacement data found in displacement_list.")

        # Compute per-procedure statistics
        self.stats = {}
        self.n_samples = {}

        for proc, disp_arrays in procedure_groups.items():
            stacked = np.stack(disp_arrays, axis=0)  # (N, 478, 2)
            n = stacked.shape[0]

            self.stats[proc] = {
                "mean": np.mean(stacked, axis=0),  # (478, 2)
                "std": np.std(stacked, axis=0),  # (478, 2)
                "min": np.min(stacked, axis=0),  # (478, 2)
                "max": np.max(stacked, axis=0),  # (478, 2)
                "median": np.median(stacked, axis=0),  # (478, 2)
                "mean_magnitude": np.mean(  # (478,)
                    np.linalg.norm(stacked, axis=2), axis=0
                ),
            }
            self.n_samples[proc] = n
            logger.info(
                "Fitted procedure '%s': %d samples, mean magnitude=%.5f",
                proc,
                n,
                float(np.mean(self.stats[proc]["mean_magnitude"])),
            )

        self._fitted = True

    def get_displacement_field(
        self,
        procedure: str,
        intensity: float = 1.0,
        noise_scale: float = 0.0,
        rng: np.random.Generator | None = None,
    ) -> np.ndarray:
        """Generate a displacement field for a given procedure and intensity.

        Returns the mean displacement scaled by ``intensity``, optionally
        with Gaussian noise added (scaled by per-landmark std). The stored
        statistics are not modified.

        Args:
            procedure: Procedure name (must exist in the fitted model).
            intensity: Scaling factor for the mean displacement. 1.0 = average
                observed displacement; 0.5 = half intensity; etc.
            noise_scale: If > 0, adds Gaussian noise with this many standard
                deviations of variation. 0.0 = deterministic mean field.
            rng: NumPy random generator for reproducible noise. If ``None``
                and ``noise_scale > 0``, uses ``np.random.default_rng()``.

        Returns:
            Displacement field as ``float32``, with the same shape as the
            stored mean ((478, 2) for models built by ``fit``), in
            normalized coordinate space.

        Raises:
            RuntimeError: If the model has not been fitted.
            KeyError: If the procedure is not in the model.
        """
        if not self._fitted:
            raise RuntimeError("Model has not been fitted. Call fit() first.")

        if procedure not in self.stats:
            available = ", ".join(self.procedures)
            raise KeyError(f"Procedure '{procedure}' not in model. Available: {available}")

        proc_stats = self.stats[procedure]
        # Multiplying the copy keeps the stored mean untouched.
        field = proc_stats["mean"].copy() * intensity

        if noise_scale > 0:
            if rng is None:
                rng = np.random.default_rng()
            # Array-valued scale: one independent sample per landmark axis.
            noise = rng.normal(
                loc=0.0,
                scale=proc_stats["std"] * noise_scale,
            )
            field += noise

        return field.astype(np.float32)

    def get_summary(self, procedure: str | None = None) -> dict:
        """Get a human-readable summary of the model statistics.

        Args:
            procedure: If provided, return summary for one procedure.
                If ``None``, return summaries for all procedures.
                Procedures not present in the model are silently omitted.

        Returns:
            ``{"fitted": False}`` for an unfitted model; otherwise
            ``{"fitted": True, "procedures": {...}}`` with, per procedure,
            the sample count, global mean/max of the mean magnitudes and
            the ten landmarks with the largest mean magnitude.
        """
        if not self._fitted:
            return {"fitted": False}

        procs = [procedure] if procedure else self.procedures
        summary = {"fitted": True, "procedures": {}}

        for proc in procs:
            if proc not in self.stats:
                continue
            s = self.stats[proc]
            summary["procedures"][proc] = {
                "n_samples": self.n_samples[proc],
                "global_mean_magnitude": float(np.mean(s["mean_magnitude"])),
                "global_max_magnitude": float(np.max(s["mean_magnitude"])),
                "top_landmarks": _top_k_landmarks(s["mean_magnitude"], k=10),
            }

        return summary

    def save(self, path: str | Path) -> None:
        """Save the fitted model to disk as a ``.npz`` file.

        The file contains:
            - Per-procedure stat arrays keyed as ``{procedure}__{stat_name}``
            - A JSON metadata string with sample counts and procedure list

        Args:
            path: Output file path. If it does not already end in ``.npz``
                the suffix is set via ``Path.with_suffix``: a path without
                an extension gains ``.npz``, while a different extension is
                *replaced* (``model.bak`` is written as ``model.npz``).

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if not self._fitted:
            raise RuntimeError("Model has not been fitted. Call fit() first.")

        path = Path(path)
        if path.suffix != ".npz":
            path = path.with_suffix(".npz")

        arrays: dict[str, np.ndarray] = {}
        for proc, proc_stats in self.stats.items():
            for stat_name, arr in proc_stats.items():
                key = f"{proc}__{stat_name}"
                arrays[key] = arr

        # Store metadata as a JSON string encoded to bytes, so the archive
        # can be read back with allow_pickle=False.
        metadata = {
            "procedures": self.procedures,
            "n_samples": self.n_samples,
            "num_landmarks": NUM_LANDMARKS,
        }
        arrays["__metadata__"] = np.frombuffer(json.dumps(metadata).encode("utf-8"), dtype=np.uint8)

        np.savez_compressed(str(path), **arrays)
        logger.info("Saved displacement model to %s", path)

    @classmethod
    def load(cls, path: str | Path) -> DisplacementModel:
        """Load a fitted model from a ``.npz`` file.

        Supports two formats:
        1. ``save()`` format: keys like ``{proc}__{stat}`` with ``__metadata__``
        2. ``extract_displacements.py`` format: keys like ``{proc}_{stat}``
           with a ``procedures`` array

        The archive is closed before returning; all arrays are read into
        memory while it is open. In both formats a procedure without a
        recorded sample count gets ``n_samples[proc] == 0``.

        Args:
            path: Path to the ``.npz`` file.

        Returns:
            A fitted ``DisplacementModel`` instance.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is not an ``.npz`` archive, its layout
                is not one of the two supported formats, or it contains no
                procedure statistics.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Model file not found: {path}")

        model = cls()

        archive = np.load(str(path), allow_pickle=False)
        if not hasattr(archive, "files"):
            # np.load returns a bare array for single-array (.npy) input.
            raise ValueError(
                f"Unrecognized displacement model format: {path} is not an .npz archive."
            )

        # The archive keeps the underlying file open until closed, so scope
        # it with a context manager.
        with archive as data:
            keys = list(data.files)

            # Format 1: save() format with __metadata__
            if "__metadata__" in keys:
                meta_bytes = data["__metadata__"].tobytes()
                metadata = json.loads(meta_bytes.decode("utf-8"))
                model.n_samples = {k: int(v) for k, v in metadata["n_samples"].items()}

                for proc in metadata["procedures"]:
                    prefix = f"{proc}__"
                    model.stats[proc] = {
                        key[len(prefix) :]: data[key] for key in keys if key.startswith(prefix)
                    }
                    # Ensure count is set (mirrors format 2 below)
                    model.n_samples.setdefault(proc, 0)

            # Format 2: extract_displacements.py format with procedures array
            elif "procedures" in keys:
                procedures = [str(p) for p in data["procedures"]]
                # Map from extraction script key names to DisplacementModel stat names
                stat_map = {
                    "mean": "mean",
                    "std": "std",
                    "median": "median",
                    "min": "min",
                    "max": "max",
                    "mag_mean": "mean_magnitude",
                    "mag_std": "std_magnitude",
                    "count": "_count",
                }
                for proc in procedures:
                    model.stats[proc] = {}
                    for ext_key, model_key in stat_map.items():
                        npz_key = f"{proc}_{ext_key}"
                        if npz_key in keys:
                            arr = data[npz_key]
                            if model_key == "_count":
                                model.n_samples[proc] = int(arr)
                            else:
                                model.stats[proc][model_key] = arr

                    # Ensure count is set
                    if proc not in model.n_samples:
                        model.n_samples[proc] = 0

            else:
                raise ValueError(f"Unrecognized displacement model format. Keys: {keys[:10]}")

        # Validate loaded model is not empty
        if not model.stats:
            raise ValueError(
                f"Displacement model at {path} contains no procedure data. "
                f"File may be corrupted or empty. Keys found: {keys[:10]}"
            )
        for proc, stats in model.stats.items():
            if not stats:
                raise ValueError(
                    f"Displacement model at {path} has no statistics for "
                    f"procedure '{proc}'. File may be corrupted."
                )

        model._fitted = True
        logger.info(
            "Loaded displacement model from %s (%d procedures, %s samples)",
            path,
            len(model.procedures),
            model.n_samples,
        )
        return model


# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------


def _top_k_landmarks(
    magnitudes: np.ndarray,
    k: int = 10,
) -> list[dict]:
    """Rank landmarks by magnitude and report the ``k`` largest.

    Args:
        magnitudes: (478,) array of per-landmark magnitudes.
        k: Number of top landmarks to return.

    Returns:
        List of ``{"index": int, "magnitude": float}`` dicts, ordered from
        the largest magnitude downwards.
    """
    ascending = np.argsort(magnitudes)
    # Reverse for descending order, then keep the leading k entries.
    ranked = ascending[::-1][:k]

    top: list[dict] = []
    for landmark_idx in ranked:
        top.append(
            {
                "index": int(landmark_idx),
                "magnitude": float(magnitudes[landmark_idx]),
            }
        )
    return top


def visualize_displacements(
    before_img: np.ndarray,
    result: dict,
    scale: float = 10.0,
    arrow_color: tuple[int, int, int] = (0, 255, 0),
    thickness: int = 1,
) -> np.ndarray:
    """Render displacement vectors as arrows over a copy of the before image.

    Landmark positions and displacements are read from ``result`` in
    normalized coordinates and converted to pixels using the image size.
    Arrows whose scaled pixel length is below 1 are not drawn. A text label
    with the procedure name and quality score is added in the top-left
    corner.

    Args:
        before_img: BGR image (will be copied).
        result: Displacement dict from ``extract_displacements``.
        scale: Arrow length multiplier (displacements are small in
            normalized space, so scale up for visibility).
        arrow_color: BGR color for arrows.
        thickness: Arrow line thickness.

    Returns:
        Annotated BGR image.
    """
    annotated = before_img.copy()
    height, width = annotated.shape[:2]

    start_pts = result["landmarks_before"]
    vectors = result["displacements"]

    for idx in range(NUM_LANDMARKS):
        # Arrow origin in pixel coordinates.
        x0 = int(start_pts[idx, 0] * width)
        y0 = int(start_pts[idx, 1] * height)
        # Exaggerated displacement, also in pixels.
        step_x = int(vectors[idx, 0] * width * scale)
        step_y = int(vectors[idx, 1] * height * scale)

        # Draw only arrows that clear the noise floor.
        if np.sqrt(step_x**2 + step_y**2) >= 1.0:
            cv2.arrowedLine(
                annotated,
                (x0, y0),
                (x0 + step_x, y0 + step_y),
                arrow_color,
                thickness,
                tipLength=0.3,
            )

    # Caption: procedure name plus alignment quality.
    proc = result.get("procedure", "unknown")
    quality = result.get("quality_score", 0.0)
    caption = f"{proc} (quality={quality:.2f})"
    cv2.putText(
        annotated,
        caption,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (255, 255, 255),
        2,
    )

    return annotated