dreamlessx committed on
Commit
ff7e8d0
·
verified ·
1 Parent(s): cfdd827

Upload landmarkdiff/losses.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/losses.py +295 -0
landmarkdiff/losses.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """4-term loss for ControlNet fine-tuning.
2
+
3
+ L_total = L_diff + w_lm * L_landmark + w_id * L_identity + w_perc * L_perceptual
4
+
5
+ Phase A (synthetic TPS data): diffusion loss only. No perceptual against
6
+ rubbery TPS warps - it would penalize realism.
7
+ Phase B (FEM/clinical data): all 4 terms.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+
18
@dataclass(frozen=True)
class LossWeights:
    """Immutable scalar multipliers, one per term of the combined loss.

    The defaults weight the diffusion term at 1.0 and the three auxiliary
    terms at 0.05-0.1.
    """

    # Multiplier for the noise-prediction (diffusion) term.
    diffusion: float = 1.0
    # Multiplier for the landmark-distance term.
    landmark: float = 0.1
    # Multiplier for the identity (embedding cosine) term.
    identity: float = 0.05
    # Multiplier for the perceptual term.
    perceptual: float = 0.1
26
+
27
+
28
class DiffusionLoss:
    """Mean-squared error between predicted and target noise (epsilon objective)."""

    def __call__(self, noise_pred: torch.Tensor, noise_target: torch.Tensor) -> torch.Tensor:
        """Return the scalar MSE averaged over every element of the inputs."""
        return F.mse_loss(noise_pred, noise_target, reduction="mean")
37
+
38
+
39
class LandmarkLoss:
    """Mean Euclidean landmark error, optionally masked and IOD-normalized.

    The predicted landmarks have to be re-extracted from the generated image,
    which is too slow to do every training step (evaluation only).
    """

    def __call__(
        self,
        pred_landmarks: torch.Tensor,  # (B, N, 2)
        target_landmarks: torch.Tensor,  # (B, N, 2)
        mask: torch.Tensor | None = None,  # (B, N) binary
        iod: torch.Tensor | None = None,  # (B,) inter-ocular distance
    ) -> torch.Tensor:
        """Return the batch mean of the per-sample average landmark distance.

        Args:
            pred_landmarks: Predicted coordinates.
            target_landmarks: Reference coordinates, same shape as the prediction.
            mask: Optional per-landmark selector; when given, only selected
                landmarks contribute and the average is over their count.
            iod: Optional per-sample inter-ocular distance used as a divisor.

        Returns:
            Scalar tensor.
        """
        # Per-landmark L2 distance over the coordinate axis -> (B, N).
        per_point = torch.norm(pred_landmarks - target_landmarks, dim=-1)

        if mask is None:
            per_sample = per_point.mean(dim=-1)
        else:
            # Clamp the count so an all-zero mask yields 0 instead of 0/0.
            n_selected = mask.sum(dim=-1).clamp(min=1)
            per_sample = (per_point * mask).sum(dim=-1) / n_selected

        if iod is not None:
            # Divisor is floored at 1.0 to avoid blowing up on tiny IODs.
            per_sample = per_sample / iod.clamp(min=1.0)

        return per_sample.mean()
66
+
67
+
68
class IdentityLoss:
    """ArcFace cosine-similarity identity loss on a procedure-dependent crop.

    Uses InsightFace ``buffalo_l`` 512-dim embeddings and falls back to a
    cosine similarity over raw pixels when InsightFace cannot be loaded.
    Disabled (returns 0) for the ``"orthognathic"`` procedure. Images MUST be
    in [-1, 1] at 112x112 for ArcFace; ``__call__`` performs that conversion.

    NOTE: embeddings are extracted under ``torch.no_grad()`` (and, on the
    ArcFace path, through a NumPy round-trip), so the returned value carries
    no gradient with respect to the input images.
    """

    # Fraction (numerator, denominator) of the image height kept from the top.
    # Procedures not listed here are compared on the full image.
    _CROP_FRACTION = {
        "rhinoplasty": (2, 3),  # upper face crop (forehead to nose tip)
        "rhytidectomy": (3, 4),  # upper face (above jawline)
    }

    def __init__(self, device: torch.device | None = None):
        self._model = None
        self._device = device
        self._app = None  # InsightFace FaceAnalysis instance once loaded
        self._has_arcface = None  # None = not checked yet

    def _ensure_loaded(self, device: torch.device) -> None:
        """Lazy-load ArcFace on first call; record failure and use the fallback."""
        if self._has_arcface is not None:
            return
        try:
            from insightface.app import FaceAnalysis

            app = FaceAnalysis(
                name="buffalo_l",
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            if device.type == "cuda":
                ctx_id = device.index if device.index is not None else 0
            else:
                ctx_id = -1
            app.prepare(ctx_id=ctx_id, det_size=(320, 320))
        except Exception as exc:
            # Best-effort dependency: keep going with the pixel fallback, but
            # say so instead of silently changing what the loss measures.
            import logging

            logging.getLogger(__name__).warning(
                "ArcFace unavailable (%s); identity loss falls back to pixel cosine.",
                exc,
            )
            self._app = None
            self._has_arcface = False
        else:
            self._app = app
            self._has_arcface = True

    @torch.no_grad()
    def _extract_embedding(
        self, image_tensor: torch.Tensor
    ) -> tuple[torch.Tensor, list[bool]]:
        """(B,3,112,112) in [-1,1] -> ((B,D) embeddings, per-sample validity).

        With ArcFace, D is 512 and a sample is invalid (zero embedding) when
        no face/embedding is returned. In the pixel fallback the flattened
        image is the embedding and every sample is valid.
        """
        if not self._has_arcface:
            # Fallback: pixel-level features
            return image_tensor.flatten(1), [True] * image_tensor.shape[0]

        import numpy as np

        # Convert the whole batch to uint8 HWC once. ``float()`` keeps this
        # working for dtypes NumPy cannot represent (e.g. bfloat16); rounding
        # avoids the downward bias of a truncating uint8 cast.
        pixels = (image_tensor.detach().float().permute(0, 2, 3, 1) + 1) / 2 * 255
        frames = pixels.round().clamp(0, 255).cpu().numpy().astype(np.uint8)

        embeddings = []
        valid_mask = []
        for rgb in frames:
            # InsightFace expects BGR channel order.
            faces = self._app.get(rgb[:, :, ::-1].copy())
            embedding = getattr(faces[0], "embedding", None) if faces else None
            if embedding is not None:
                embeddings.append(
                    torch.from_numpy(np.asarray(embedding, dtype=np.float32))
                )
                valid_mask.append(True)
            else:
                embeddings.append(torch.zeros(512))
                valid_mask.append(False)

        return torch.stack(embeddings).to(image_tensor.device), valid_mask

    def _embed(
        self, image: torch.Tensor, procedure: str
    ) -> tuple[torch.Tensor, list[bool]]:
        """Crop for ``procedure``, resize to 112x112, map to [-1,1] and embed."""
        crop = self._procedure_crop(image, procedure)
        resized = F.interpolate(
            crop, size=(112, 112), mode="bilinear", align_corners=False
        )
        return self._extract_embedding(resized * 2 - 1)

    def __call__(
        self,
        pred_image: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        target_image: torch.Tensor,
        procedure: str = "rhinoplasty",
    ) -> torch.Tensor:
        """Return mean ``1 - cos(pred_emb, target_emb)`` over valid samples.

        A sample is valid only when an embedding was obtained for both the
        predicted and the target image. Returns a zero scalar when the
        procedure is ``"orthognathic"`` or when no sample is valid.
        """
        if procedure == "orthognathic":
            return torch.tensor(0.0, device=pred_image.device)

        self._ensure_loaded(pred_image.device)

        pred_emb, pred_valid = self._embed(pred_image, procedure)
        target_emb, target_valid = self._embed(target_image, procedure)

        # Only compute loss for samples where both faces were detected
        valid = [p and t for p, t in zip(pred_valid, target_valid)]
        if not any(valid):
            return torch.tensor(0.0, device=pred_image.device)

        # L2 normalize (zero embeddings of invalid samples stay zero)
        pred_emb = F.normalize(pred_emb.float(), dim=1)
        target_emb = F.normalize(target_emb.float(), dim=1)
        cosine_sim = (pred_emb * target_emb).sum(dim=1)

        # Mask the per-sample LOSS, not the similarity: zeroing the similarity
        # would make every invalid sample contribute (1 - 0) = 1 to the sum.
        weights = torch.tensor(valid, device=cosine_sim.device, dtype=cosine_sim.dtype)
        return ((1 - cosine_sim) * weights).sum() / weights.sum()

    def _procedure_crop(
        self,
        image: torch.Tensor,
        procedure: str,
    ) -> torch.Tensor:
        """Procedure-specific crop for identity comparison.

        Keeps the top part of the image for procedures listed in
        ``_CROP_FRACTION`` and returns the full image otherwise
        (e.g. ``"blepharoplasty"`` or unknown names).
        """
        _, _, height, _ = image.shape
        fraction = self._CROP_FRACTION.get(procedure)
        if fraction is None:
            return image
        numerator, denominator = fraction
        return image[:, :, : height * numerator // denominator, :]
184
+
185
+
186
class PerceptualLoss:
    """LPIPS outside surgical mask only. Remember: LPIPS wants [-1,1], VAE gives [0,1].

    Falls back to a plain L1 loss on the masked, normalized images when the
    ``lpips`` package cannot be imported.
    """

    # Side length of the square erosion window applied to the outside mask.
    _ERODE_KERNEL = 5

    def __init__(self):
        # None = not loaded yet; "unavailable" = import failed, use L1 fallback.
        self._lpips = None

    def _ensure_loaded(self, device: torch.device) -> None:
        """Lazy-load a frozen LPIPS (AlexNet) model onto ``device`` on first use."""
        if self._lpips is not None:
            return
        try:
            import lpips

            model = lpips.LPIPS(net="alex").to(device)
            model.eval()
            for param in model.parameters():
                param.requires_grad_(False)
            self._lpips = model
        except ImportError:
            self._lpips = "unavailable"

    def __call__(
        self,
        pred: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        target: torch.Tensor,
        mask: torch.Tensor,  # (B, 1, H, W) surgical mask [0, 1]
    ) -> torch.Tensor:
        """Return the perceptual distance restricted to the non-surgical region.

        Args:
            pred: Generated image.
            target: Reference image, same shape as ``pred``.
            mask: Surgical-region mask (1 inside the region). Boolean and
                integer masks are accepted and converted to ``pred``'s dtype.

        Returns:
            Scalar tensor: mean LPIPS, or mean L1 in the fallback path.
        """
        self._ensure_loaded(pred.device)

        # ``1 - mask`` is undefined for bool tensors and max-pooling is not
        # available for integer ones, so promote non-float masks first.
        if not mask.is_floating_point():
            mask = mask.to(pred.dtype)

        # Invert mask: we want loss OUTSIDE surgical region
        outside_mask = 1 - mask

        # Erode outside_mask by a few pixels to avoid artificial edge features
        # at the mask boundary (the LPIPS backbone responds to the hard
        # 0->value transition). Erosion = min-pool = -max_pool(-x).
        kernel = self._ERODE_KERNEL
        if outside_mask.shape[-1] >= kernel and outside_mask.shape[-2] >= kernel:
            outside_mask = -F.max_pool2d(
                -outside_mask,
                kernel_size=kernel,
                stride=1,
                padding=kernel // 2,
            )

        # Normalize to [-1, 1] for LPIPS FIRST, then mask, so that masked
        # regions become 0 (mid-grey) rather than -1 (black).
        pred_norm = (pred * 2 - 1) * outside_mask
        target_norm = (target * 2 - 1) * outside_mask

        if isinstance(self._lpips, str):
            # Fallback: simple L1 loss
            return F.l1_loss(pred_norm, target_norm)

        return self._lpips(pred_norm, target_norm).mean()
238
+
239
+
240
class CombinedLoss:
    """4-term combined loss. phase='A' = diffusion only, phase='B' = all terms.

    Auxiliary terms are only added in phase ``"B"`` and only when the inputs
    they need are supplied as keyword arguments (and are not ``None``).
    """

    def __init__(
        self,
        weights: LossWeights | None = None,
        phase: str = "A",
    ):
        self.weights = weights if weights is not None else LossWeights()
        self.phase = phase
        self.diffusion_loss = DiffusionLoss()
        self.landmark_loss = LandmarkLoss()
        self.identity_loss = IdentityLoss()
        self.perceptual_loss = PerceptualLoss()

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Compute the weighted loss terms and their sum.

        Recognized keyword arguments (phase ``"B"`` only):
            pred_landmarks, target_landmarks: enable the landmark term;
                ``landmark_mask`` and ``iod`` are forwarded if present.
            pred_image, target_image: enable the identity term; ``procedure``
                defaults to ``"rhinoplasty"``.
            mask: together with the two images, enables the perceptual term.

        Returns:
            Dict with ``"diffusion"`` and ``"total"`` always present, plus
            ``"landmark"``, ``"identity"`` and ``"perceptual"`` for each term
            that was computed. Every entry is already multiplied by its weight.
        """
        losses = {}

        # Always compute diffusion loss
        losses["diffusion"] = self.weights.diffusion * self.diffusion_loss(
            noise_pred, noise_target
        )
        losses["total"] = losses["diffusion"]

        if self.phase != "B":
            return losses

        def supplied(*names: str) -> bool:
            # A key passed explicitly as None counts as "not supplied" so
            # callers can forward optional values without crashing a sub-loss.
            return all(kwargs.get(name) is not None for name in names)

        if supplied("pred_landmarks", "target_landmarks"):
            losses["landmark"] = self.weights.landmark * self.landmark_loss(
                kwargs["pred_landmarks"],
                kwargs["target_landmarks"],
                kwargs.get("landmark_mask"),
                kwargs.get("iod"),
            )
            losses["total"] = losses["total"] + losses["landmark"]

        have_images = supplied("pred_image", "target_image")

        if have_images:
            losses["identity"] = self.weights.identity * self.identity_loss(
                kwargs["pred_image"],
                kwargs["target_image"],
                kwargs.get("procedure", "rhinoplasty"),
            )
            losses["total"] = losses["total"] + losses["identity"]

        if have_images and supplied("mask"):
            losses["perceptual"] = self.weights.perceptual * self.perceptual_loss(
                kwargs["pred_image"],
                kwargs["target_image"],
                kwargs["mask"],
            )
            losses["total"] = losses["total"] + losses["perceptual"]

        return losses