"""4-term loss function module for ControlNet fine-tuning.

L_total = L_diffusion + w_landmark * L_landmark + w_identity * L_identity + w_perceptual * L_perceptual

Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
rubbery TPS warps — it would penalize realism.

Phase B (FEM/clinical data): All 4 terms enabled.
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass(frozen=True)
class LossWeights:
    """Loss term weights."""

    diffusion: float = 1.0
    landmark: float = 0.1
    identity: float = 0.1
    perceptual: float = 0.05
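
# Example (hedged: the override values below are hypothetical, not tuned
# numbers from this project):
#
#     weights = LossWeights(landmark=0.2, perceptual=0.01)
#     loss_fn = CombinedLoss(weights=weights, phase="B")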


class DiffusionLoss:
    """Standard epsilon-prediction MSE loss (primary training signal)."""

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
    ) -> torch.Tensor:
        return F.mse_loss(noise_pred, noise_target)


class LandmarkLoss:
    """L2 landmark distance normalized by inter-ocular distance.

    Computed INSIDE the surgical mask only. Requires MediaPipe re-extraction
    from the generated image (done at eval time rather than at every training
    step, for speed).
    """

    def __call__(
        self,
        pred_landmarks: torch.Tensor,  # (B, N, 2)
        target_landmarks: torch.Tensor,  # (B, N, 2)
        mask: torch.Tensor | None = None,  # (B, N) binary
        iod: torch.Tensor | None = None,  # (B,) inter-ocular distance
    ) -> torch.Tensor:
        diff = pred_landmarks - target_landmarks  # (B, N, 2)
        dist = torch.norm(diff, dim=-1)  # (B, N)

        if mask is not None:
            dist = dist * mask
            count = mask.sum(dim=-1).clamp(min=1)
            mean_dist = dist.sum(dim=-1) / count
        else:
            mean_dist = dist.mean(dim=-1)

        if iod is not None:
            mean_dist = mean_dist / iod.clamp(min=1.0)

        return mean_dist.mean()
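
# Minimal LandmarkLoss sketch (synthetic shapes/values, chosen only to show
# the mask + inter-ocular normalization; 68 landmarks is an arbitrary count):
#
#     lm_loss = LandmarkLoss()
#     pred = torch.rand(2, 68, 2) * 512          # landmarks in pixel coords
#     target = pred + 4.0                        # constant 4 px offset per axis
#     mask = torch.ones(2, 68)                   # every landmark inside mask
#     iod = torch.full((2,), 64.0)               # inter-ocular distance in px
#     loss = lm_loss(pred, target, mask=mask, iod=iod)
#     # per-point distance = sqrt(4^2 + 4^2) ≈ 5.66, so loss ≈ 5.66 / 64 ≈ 0.088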


class IdentityLoss:
    """ArcFace cosine similarity loss with procedure-dependent masking.

    Uses InsightFace ArcFace model (buffalo_l) for 512-dim identity embeddings.
    Falls back to pixel-level cosine similarity if InsightFace is unavailable.

    - Full face for blepharoplasty
    - Upper-face crop for rhinoplasty
    - Disabled for orthognathic

    Input images MUST be normalized to [-1, 1] and resized to 112x112
    before passing to ArcFace (the model produces garbage embeddings for
    raw 1024x1024 inputs).
    """

    def __init__(self, device: torch.device | None = None):
        self._app = None
        self._device = device
        self._has_arcface = None  # None = not checked yet

    def _ensure_loaded(self, device: torch.device) -> None:
        """Lazy-load ArcFace model on first use."""
        if self._has_arcface is not None:
            return
        try:
            from insightface.app import FaceAnalysis
            self._app = FaceAnalysis(
                name="buffalo_l",
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            if device.type == "cuda":
                ctx_id = device.index if device.index is not None else 0
            else:
                ctx_id = -1
            self._app.prepare(ctx_id=ctx_id, det_size=(320, 320))
            self._has_arcface = True
        except Exception:
            self._has_arcface = False

    @torch.no_grad()
    def _extract_embedding(
        self, image_tensor: torch.Tensor
    ) -> tuple[torch.Tensor, list[bool]]:
        """Extract ArcFace embeddings from a batch of images.

        Args:
            image_tensor: (B, 3, 112, 112) in [-1, 1]

        Returns:
            Tuple of (B, 512) identity embeddings (or (B, D) pixel-level
            features in fallback mode) and a per-sample validity list,
            False where no face was detected.
        """
        if self._has_arcface:
            import numpy as np
            embeddings = []
            valid_mask = []
            for i in range(image_tensor.shape[0]):
                # Convert to uint8 BGR for InsightFace
                img = ((image_tensor[i].permute(1, 2, 0) + 1) / 2 * 255).clamp(0, 255)
                img_np = img.cpu().numpy().astype(np.uint8)
                img_bgr = img_np[:, :, ::-1].copy()

                faces = self._app.get(img_bgr)
                if faces and hasattr(faces[0], "embedding") and faces[0].embedding is not None:
                    embeddings.append(torch.from_numpy(faces[0].embedding))
                    valid_mask.append(True)
                else:
                    embeddings.append(torch.zeros(512))
                    valid_mask.append(False)

            return torch.stack(embeddings).to(image_tensor.device), valid_mask
        else:
            # Fallback: pixel-level features
            return image_tensor.flatten(1), [True] * image_tensor.shape[0]

    def __call__(
        self,
        pred_image: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        target_image: torch.Tensor,
        procedure: str = "rhinoplasty",
    ) -> torch.Tensor:
        if procedure == "orthognathic":
            return torch.tensor(0.0, device=pred_image.device)

        self._ensure_loaded(pred_image.device)

        # Crop based on procedure
        pred_crop = self._procedure_crop(pred_image, procedure)
        target_crop = self._procedure_crop(target_image, procedure)

        # Resize to 112x112 for ArcFace
        pred_112 = F.interpolate(pred_crop, size=(112, 112), mode="bilinear", align_corners=False)
        target_112 = F.interpolate(target_crop, size=(112, 112), mode="bilinear", align_corners=False)

        # Normalize to [-1, 1]
        pred_norm = pred_112 * 2 - 1
        target_norm = target_112 * 2 - 1

        # Extract embeddings (ArcFace or fallback)
        pred_emb, pred_valid = self._extract_embedding(pred_norm)
        target_emb, target_valid = self._extract_embedding(target_norm)

        # Only compute loss for samples where both faces were detected
        valid = [p and t for p, t in zip(pred_valid, target_valid)]
        if not any(valid):
            return torch.tensor(0.0, device=pred_image.device)

        valid_indices = [i for i, v in enumerate(valid) if v]
        valid_idx_t = torch.tensor(valid_indices, device=pred_image.device, dtype=torch.long)

        # Select ONLY valid embeddings before normalization to avoid 0/0 NaN
        pred_valid_emb = pred_emb[valid_idx_t].float()
        target_valid_emb = target_emb[valid_idx_t].float()

        # L2 normalize (safe — zero vectors excluded above)
        pred_valid_emb = F.normalize(pred_valid_emb, dim=1)
        target_valid_emb = F.normalize(target_valid_emb, dim=1)

        cosine_sim = (pred_valid_emb * target_valid_emb).sum(dim=1)
        return (1 - cosine_sim).mean()

    def _procedure_crop(
        self,
        image: torch.Tensor,
        procedure: str,
    ) -> torch.Tensor:
        """Crop image based on procedure for identity comparison."""
        _, _, h, w = image.shape

        if procedure == "rhinoplasty":
            # Upper face crop (forehead to nose tip)
            return image[:, :, : h * 2 // 3, :]
        elif procedure == "blepharoplasty":
            # Full face
            return image
        elif procedure == "rhytidectomy":
            # Upper face (above jawline)
            return image[:, :, : h * 3 // 4, :]
        else:
            return image
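
# Crop-behavior sketch (shapes only; 512x512 is an arbitrary example size):
#
#     id_loss = IdentityLoss()
#     img = torch.rand(1, 3, 512, 512)
#     id_loss._procedure_crop(img, "rhinoplasty").shape     # (1, 3, 341, 512)
#     id_loss._procedure_crop(img, "rhytidectomy").shape    # (1, 3, 384, 512)
#     id_loss._procedure_crop(img, "blepharoplasty").shape  # (1, 3, 512, 512)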


class PerceptualLoss:
    """LPIPS perceptual loss on regions OUTSIDE surgical mask only.

    LPIPS expects [-1, 1] input. VAE outputs [0, 1].
    Must apply (x * 2) - 1 before every call.
    """

    def __init__(self):
        self._lpips = None
        self._available: bool | None = None  # None = not checked yet

    def _ensure_loaded(self, device: torch.device) -> None:
        if self._available is not None:
            return
        try:
            import lpips

            self._lpips = lpips.LPIPS(net="alex").to(device)
            self._lpips.eval()
            for p in self._lpips.parameters():
                p.requires_grad_(False)
            self._available = True
        except ImportError:
            self._available = False

    def __call__(
        self,
        pred: torch.Tensor,    # (B, 3, H, W) in [0, 1]
        target: torch.Tensor,
        mask: torch.Tensor,    # (B, 1, H, W) surgical mask [0, 1]
    ) -> torch.Tensor:
        self._ensure_loaded(pred.device)

        # Normalize to [-1, 1] for LPIPS
        pred_norm = pred * 2 - 1
        target_norm = target * 2 - 1

        # When mask is all-ones (no mask file available), compute on full image.
        # Otherwise invert mask to get loss OUTSIDE the surgical region only.
        has_mask = mask.sum() < mask.numel() * 0.99
        if has_mask:
            outside_mask = 1 - mask
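            # Erode the kept (outside) region slightly so pixels on the
            # surgical-mask boundary are excluded: min-pooling, written as
            # -max_pool2d(-x), shrinks the kept region by ~kernel_size // 2 px.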
            erode_kernel = 5
            if outside_mask.shape[-1] >= erode_kernel and outside_mask.shape[-2] >= erode_kernel:
                outside_mask = -F.max_pool2d(
                    -outside_mask,
                    kernel_size=erode_kernel,
                    stride=1,
                    padding=erode_kernel // 2,
                )
            pred_norm = pred_norm * outside_mask
            target_norm = target_norm * outside_mask

        if not self._available:
            # Fallback when the lpips package is missing: simple L1 loss
            return F.l1_loss(pred_norm, target_norm)

        return self._lpips(pred_norm, target_norm).mean()
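
# PerceptualLoss usage sketch (synthetic tensors; the real LPIPS path needs
# the `lpips` package, otherwise the L1 fallback fires):
#
#     perc_loss = PerceptualLoss()
#     pred = torch.rand(1, 3, 256, 256)
#     target = torch.rand(1, 3, 256, 256)
#     mask = torch.zeros(1, 1, 256, 256)
#     mask[:, :, 96:160, 96:160] = 1.0        # fake surgical region
#     loss = perc_loss(pred, target, mask)    # scored outside the square only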


class CombinedLoss:
    """Combined 4-term loss with configurable weights.

    Use phase='A' for Phase A training (diffusion only).
    Use phase='B' for Phase B training (all terms).

    For Phase B, set ``use_differentiable_arcface=True`` to use the
    PyTorch-native ArcFace backbone (``arcface_torch.py``) that provides
    actual gradient signal. The default ONNX-based IdentityLoss produces
    zero gradients (DA2-03).
    """

    def __init__(
        self,
        weights: LossWeights | None = None,
        phase: str = "A",
        use_differentiable_arcface: bool = False,
        arcface_weights_path: str | None = None,
    ):
        self.weights = weights or LossWeights()
        self.phase = phase
        self.diffusion_loss = DiffusionLoss()
        self.landmark_loss = LandmarkLoss()
        self.perceptual_loss = PerceptualLoss()

        # Identity loss: differentiable PyTorch ArcFace for Phase B,
        # or ONNX-based fallback
        if use_differentiable_arcface:
            from landmarkdiff.arcface_torch import ArcFaceLoss
            self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
        else:
            self.identity_loss = IdentityLoss()

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        losses = {}

        # Always compute diffusion loss
        losses["diffusion"] = self.weights.diffusion * self.diffusion_loss(noise_pred, noise_target)
        losses["total"] = losses["diffusion"]

        if self.phase == "B":
            # Phase B: add auxiliary losses
            if "pred_landmarks" in kwargs and "target_landmarks" in kwargs:
                losses["landmark"] = self.weights.landmark * self.landmark_loss(
                    kwargs["pred_landmarks"],
                    kwargs["target_landmarks"],
                    kwargs.get("landmark_mask"),
                    kwargs.get("iod"),
                )
                losses["total"] = losses["total"] + losses["landmark"]

            if "pred_image" in kwargs and "target_image" in kwargs:
                procedure = kwargs.get("procedure", "rhinoplasty")
                losses["identity"] = self.weights.identity * self.identity_loss(
                    kwargs["pred_image"],
                    kwargs["target_image"],
                    procedure,
                )
                losses["total"] = losses["total"] + losses["identity"]

            if "pred_image" in kwargs and "target_image" in kwargs and "mask" in kwargs:
                losses["perceptual"] = self.weights.perceptual * self.perceptual_loss(
                    kwargs["pred_image"],
                    kwargs["target_image"],
                    kwargs["mask"],
                )
                losses["total"] = losses["total"] + losses["perceptual"]

        return losses
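

if __name__ == "__main__":
    # Smoke test on random tensors; illustrative only, no real model outputs.
    # Phase A exercises just the diffusion term; Phase B additionally pulls in
    # the landmark, identity, and perceptual terms when the corresponding
    # kwargs are supplied (optional deps degrade to their fallbacks).
    noise_pred = torch.rand(2, 4, 64, 64)
    noise_target = torch.rand(2, 4, 64, 64)

    loss_a = CombinedLoss(phase="A")
    print({k: v.item() for k, v in loss_a(noise_pred, noise_target).items()})

    loss_b = CombinedLoss(phase="B")
    out = loss_b(
        noise_pred,
        noise_target,
        pred_landmarks=torch.rand(2, 68, 2),
        target_landmarks=torch.rand(2, 68, 2),
        iod=torch.full((2,), 64.0),
        pred_image=torch.rand(2, 3, 256, 256),
        target_image=torch.rand(2, 3, 256, 256),
        mask=torch.ones(2, 1, 256, 256),
    )
    print({k: v.item() for k, v in out.items()})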