dreamlessx committed on
Commit
b163477
·
verified ·
1 Parent(s): 6421899

Upload landmarkdiff/arcface_torch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/arcface_torch.py +678 -0
landmarkdiff/arcface_torch.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch-native ArcFace model for differentiable identity loss.
2
+
3
+ Drop-in replacement for the ONNX-based InsightFace ArcFace used in losses.py.
4
+ The original IdentityLoss extracts embeddings under @torch.no_grad(), which
5
+ means the identity loss term contributes zero gradients during Phase B training.
6
+ This module provides a fully differentiable path so that gradients flow back
7
+ through the predicted image into the ControlNet.
8
+
9
+ Architecture: IResNet-50 (the standard ArcFace backbone from InsightFace).
10
+ conv1(3->64, 3x3) -> BN -> PReLU ->
11
+ 4 IResNet blocks [3, 4, 14, 3] with channels [64, 128, 256, 512] ->
12
+ BN -> Dropout -> Flatten -> FC(512*7*7 -> 512) -> BN (no bias)
13
+ -> L2-normalize
14
+
15
+ Each IBasicBlock: BN-conv3x3-BN-PReLU-conv3x3-BN + SE attention + residual.
16
+
17
+ Pretrained weights: InsightFace distributes IResNet-50 as a PyTorch .pth
18
+ (backbone.pth inside the buffalo_l model pack). This module can load those
19
+ weights directly, or fall back to random initialization with a warning.
20
+
21
+ Usage in losses.py:
22
+ from landmarkdiff.arcface_torch import ArcFaceLoss
23
+ identity_loss = ArcFaceLoss(device=device)
24
+ loss = identity_loss(pred_image, target_image) # gradients flow through pred
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import logging
30
+ import warnings
31
+ from pathlib import Path
32
+ from typing import Optional
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Building blocks
43
+ # ---------------------------------------------------------------------------
44
+
45
class SEModule(nn.Module):
    """Channel-wise Squeeze-and-Excitation gate.

    A global average pool collapses every channel to a single value, a
    1x1-conv bottleneck (``channels -> channels // reduction -> channels``)
    with a ReLU in between produces one logit per channel, and a sigmoid
    turns those logits into gates that rescale the input feature map.
    """

    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        squeezed = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, squeezed, kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, channels, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by its per-channel gate (same shape as ``x``)."""
        pooled = self.avg_pool(x)
        hidden = self.relu(self.fc1(pooled))
        gate = self.sigmoid(self.fc2(hidden))
        return x * gate
66
+
67
+
68
class IBasicBlock(nn.Module):
    """Residual block used by the IResNet stages.

    Main branch: BN -> conv3x3 -> BN -> PReLU -> conv3x3 (carries the stride)
    -> BN -> SE gate (or identity when ``use_se`` is false). The result is
    added to the shortcut, which is the raw input unless a ``downsample``
    module is supplied.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        use_se: bool = True,
    ):
        super().__init__()
        # Both convolutions are bias-free 3x3 with padding 1; only the second
        # one applies the block's stride.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-5)
        self.conv1 = nn.Conv2d(inplanes, planes, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-5)
        self.prelu = nn.PReLU(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=stride, **conv_kwargs)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-5)

        if use_se:
            self.se_module = SEModule(planes)
        else:
            self.se_module = nn.Identity()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the main branch and add the (optionally projected) shortcut."""
        out = self.conv1(self.bn1(x))
        out = self.prelu(self.bn2(out))
        out = self.bn3(self.conv2(out))
        out = self.se_module(out)

        shortcut = x if self.downsample is None else self.downsample(x)
        return out + shortcut
117
+
118
+
119
+ # ---------------------------------------------------------------------------
120
+ # Backbone
121
+ # ---------------------------------------------------------------------------
122
+
123
class ArcFaceBackbone(nn.Module):
    """IResNet-style embedding network built from ``IBasicBlock`` stages.

    Input:  (B, 3, 112, 112). The four stages each halve the resolution
            (112 -> 7) and the FC layer is sized for a 7x7 feature map, so
            other spatial sizes fail at the FC layer.
    Output: (B, embedding_dim) embeddings, L2-normalized along dim 1.

    Attribute names (``conv1``, ``bn1``, ``prelu``, ``layer1..4``, ``bn2``,
    ``dropout``, ``fc``, ``features``) determine the state-dict keys.
    NOTE(review): with ``use_se=True`` every block also owns ``se_module``
    parameters -- confirm that the checkpoint you intend to load provides
    them, otherwise they stay randomly initialized.
    """

    def __init__(
        self,
        layers: tuple[int, ...] = (3, 4, 14, 3),
        dropout_rate: float = 0.0,
        embedding_dim: int = 512,
        use_se: bool = True,
    ):
        super().__init__()
        # Running channel count consumed/updated by _make_layer.
        self.inplanes = 64
        self.use_se = use_se

        # Stem keeps the input resolution (stride 1).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, eps=1e-5)
        self.prelu = nn.PReLU(64)

        # Four residual stages, registered as layer1..layer4, each with stride 2.
        for idx, width in enumerate((64, 128, 256, 512)):
            stage = self._make_layer(IBasicBlock, width, layers[idx], stride=2)
            setattr(self, f"layer{idx + 1}", stage)

        # Head: BN -> Dropout -> flatten -> FC -> BN.
        final_channels = 512 * IBasicBlock.expansion
        self.bn2 = nn.BatchNorm2d(final_channels, eps=1e-5)
        self.dropout = nn.Dropout(p=dropout_rate, inplace=True)
        self.fc = nn.Linear(final_channels * 7 * 7, embedding_dim)
        self.features = nn.BatchNorm1d(embedding_dim, eps=1e-5)
        # The final BN starts with unit scale and its bias is excluded from
        # gradient updates.
        nn.init.constant_(self.features.weight, 1.0)
        self.features.bias.requires_grad_(False)

        self._initialize_weights()

    def _make_layer(
        self,
        block: type[IBasicBlock],
        planes: int,
        num_blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build one stage of ``num_blocks`` blocks; only the first is strided.

        A 1x1-conv + BN projection is attached to the first block whenever
        the stride or the channel count changes. ``self.inplanes`` is updated
        to the stage's output width.
        """
        out_channels = planes * block.expansion

        shortcut = None
        if not (stride == 1 and self.inplanes == out_channels):
            shortcut = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels, eps=1e-5),
            )

        stage = [block(self.inplanes, planes, stride, shortcut, use_se=self.use_se)]
        self.inplanes = out_channels
        for _ in range(num_blocks - 1):
            stage.append(block(self.inplanes, planes, stride=1, use_se=self.use_se))

        return nn.Sequential(*stage)

    def _initialize_weights(self) -> None:
        """Kaiming-normal for conv/linear weights; BN to scale 1, shift 0."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu",
                )
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of face crops.

        Args:
            x: (B, 3, 112, 112) tensor; callers in this file pass values
                in [-1, 1].

        Returns:
            (B, embedding_dim) L2-normalized embeddings.
        """
        feats = self.prelu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)

        feats = self.dropout(self.bn2(feats))
        emb = self.features(self.fc(torch.flatten(feats, 1)))

        # Unit-length embeddings so that a dot product is a cosine similarity.
        return F.normalize(emb, p=2, dim=1)
239
+
240
+
241
+ # ---------------------------------------------------------------------------
242
+ # Pretrained weight loading
243
+ # ---------------------------------------------------------------------------
244
+
245
+ # Known locations where InsightFace buffalo_l backbone.pth may live
246
+ _KNOWN_WEIGHT_PATHS = [
247
+ Path.home() / ".insightface" / "models" / "buffalo_l" / "w600k_r50.onnx",
248
+ Path.home() / ".insightface" / "models" / "buffalo_l" / "backbone.pth",
249
+ # Common manual download location
250
+ Path.home() / ".cache" / "arcface" / "backbone.pth",
251
+ ]
252
+
253
+ # Glint360K R50 weights URL (InsightFace official release)
254
+ _WEIGHT_URL = (
255
+ "https://github.com/deepinsight/insightface/releases/download/"
256
+ "v0.7/glint360k_cosface_r50_fp16_0.1-backbone.pth"
257
+ )
258
+
259
+
260
def _find_pretrained_weights() -> Optional[Path]:
    """Return the first existing ``.pth`` file in ``_KNOWN_WEIGHT_PATHS``.

    Entries are checked in list order; ``None`` is returned when no entry
    both exists on disk and has the ``.pth`` suffix.
    """
    usable = (
        candidate
        for candidate in _KNOWN_WEIGHT_PATHS
        if candidate.exists() and candidate.suffix == ".pth"
    )
    return next(usable, None)
266
+
267
+
268
def _try_download_weights(dest: Path) -> bool:
    """Attempt to download pretrained weights from ``_WEIGHT_URL`` to ``dest``.

    The payload is first written to a per-process temporary file next to
    ``dest`` and only renamed onto ``dest`` once the transfer has finished.
    An interrupted or failed download therefore never leaves a truncated
    file at ``dest`` (which would otherwise be picked up as "pretrained
    weights" on the next run).

    Args:
        dest: Final location of the weights file. Parent directories are
            created if needed.

    Returns:
        ``True`` if the file was downloaded and moved into place, ``False``
        on any failure (the failure is logged, not raised).
    """
    import os

    # Include the PID so that concurrent processes (e.g. multi-GPU ranks)
    # do not write into the same partial file.
    tmp = dest.with_name(f"{dest.name}.{os.getpid()}.part")
    try:
        import urllib.request

        dest.parent.mkdir(parents=True, exist_ok=True)
        logger.info("Downloading ArcFace IResNet-50 weights from %s ...", _WEIGHT_URL)
        urllib.request.urlretrieve(_WEIGHT_URL, str(tmp))
        # Atomic on the same filesystem: readers see either no file or the
        # complete file.
        tmp.replace(dest)
        logger.info("Downloaded to %s", dest)
        return True
    except Exception as e:
        logger.warning("Failed to download ArcFace weights: %s", e)
        # Best-effort cleanup of the partial download; a missing temp file
        # (or an unremovable one) must not mask the original failure.
        try:
            tmp.unlink()
        except OSError:
            pass
        return False
280
+
281
+
282
def load_pretrained_weights(
    model: ArcFaceBackbone,
    weights_path: Optional[str] = None,
    download: bool = True,
) -> bool:
    """Load pretrained IResNet weights into ``model``.

    Resolution order for the checkpoint file: the explicit ``weights_path``
    (if it exists), then ``_find_pretrained_weights()``, then -- when
    ``download`` is true -- a download to ``~/.cache/arcface/backbone.pth``.

    Loading is attempted strictly first. If that fails, keys are remapped
    (a leading ``module.`` prefix is stripped and ``output_layer.*`` is
    renamed to ``features.*``) and a non-strict load is performed.

    Args:
        model: An ``ArcFaceBackbone`` instance.
        weights_path: Explicit path to a ``.pth`` file. If ``None``, searches
            known locations and optionally downloads.
        download: Whether to attempt downloading if no local weights found.

    Returns:
        ``True`` if at least part of the checkpoint was loaded into the
        model, ``False`` otherwise (no file found, unreadable file, no
        matching keys, or a load error). On ``False`` a ``UserWarning`` is
        issued and the model keeps whatever initialization it had.
    """
    default_dest = Path.home() / ".cache" / "arcface" / "backbone.pth"
    path: Optional[Path] = None

    if weights_path is not None:
        path = Path(weights_path)
        if not path.exists():
            logger.warning("Specified weights path does not exist: %s", path)
            path = None

    if path is None:
        path = _find_pretrained_weights()

    if path is None and download:
        if _try_download_weights(default_dest):
            path = default_dest

    if path is None:
        warnings.warn(
            "No pretrained ArcFace weights found. The model will use random "
            "initialization. Identity loss values will be meaningless until "
            "proper weights are loaded. Place backbone.pth at "
            f"{default_dest}",
            UserWarning,
            stacklevel=2,
        )
        return False

    logger.info("Loading ArcFace weights from %s", path)
    # A truncated or otherwise corrupt file must degrade to "not loaded"
    # (as documented) instead of crashing the caller.
    try:
        state_dict = torch.load(str(path), map_location="cpu", weights_only=True)
    except Exception as e:
        warnings.warn(
            f"Failed to read ArcFace weights from {path}: {e}. "
            "Using random initialization.",
            UserWarning,
            stacklevel=2,
        )
        return False

    # Handle the case where the checkpoint wraps the state dict
    if isinstance(state_dict, dict) and "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    if not isinstance(state_dict, dict):
        warnings.warn(
            f"Checkpoint at {path} does not contain a state dict "
            f"(got {type(state_dict).__name__}). Using random initialization.",
            UserWarning,
            stacklevel=2,
        )
        return False

    # Try direct load first
    try:
        model.load_state_dict(state_dict, strict=True)
        logger.info("Loaded ArcFace weights (strict match)")
        return True
    except RuntimeError:
        pass

    # Fall back to a non-strict load after normalizing common key variants.
    try:
        remapped = {}
        for k, v in state_dict.items():
            new_k = k
            # Checkpoints saved from DataParallel/DDP wrappers prefix every key.
            if new_k.startswith("module."):
                new_k = new_k[len("module."):]
            # Some checkpoints use 'output_layer' for the final BatchNorm1d
            if new_k.startswith("output_layer."):
                new_k = "features." + new_k[len("output_layer."):]
            remapped[new_k] = v

        # If nothing lines up, a non-strict load would "succeed" while
        # leaving the model entirely random -- report that as a failure.
        model_keys = set(model.state_dict())
        if not model_keys.intersection(remapped):
            warnings.warn(
                f"No keys in {path} match the ArcFace backbone. "
                "Using random initialization.",
                UserWarning,
                stacklevel=2,
            )
            return False

        missing, unexpected = model.load_state_dict(remapped, strict=False)
        if missing:
            # Missing keys are model parameters/buffers absent from the
            # checkpoint; they keep their current (random) initialization.
            logger.warning(
                "Missing keys when loading ArcFace weights; these parameters "
                "keep their random initialization (%d total, first 10 "
                "shown): %s",
                len(missing),
                missing[:10],
            )
        if unexpected:
            logger.info("Unexpected keys (ignored): %s", unexpected[:10])
        logger.info("Loaded ArcFace weights (non-strict)")
        return True
    except Exception as e:
        warnings.warn(
            f"Failed to load ArcFace weights from {path}: {e}. "
            "Using random initialization.",
            UserWarning,
            stacklevel=2,
        )
        return False
377
+
378
+
379
+ # ---------------------------------------------------------------------------
380
+ # Differentiable face alignment
381
+ # ---------------------------------------------------------------------------
382
+
383
def align_face(
    images: torch.Tensor,
    size: int = 112,
    crop_frac: float = 0.8,
) -> torch.Tensor:
    """Center-crop and resize face images to (size x size) differentiably.

    Uses ``F.affine_grid`` + ``F.grid_sample`` with bilinear interpolation
    so that gradients flow back through the spatial transform into the
    input images.

    The crop keeps the central ``crop_frac`` fraction of the image along
    each axis (default 80%) and resizes it to the target size. Height and
    width are scaled independently, so non-square inputs are not
    aspect-preserved.

    Images whose spatial size already equals ``size`` are returned
    unchanged (no crop is applied).

    Args:
        images: (B, C, H, W) tensor, any normalization.
        size: Target spatial size (default 112 for ArcFace).
        crop_frac: Fraction of each spatial axis to keep around the center.
            Must be positive; ``1.0`` is a plain resize.

    Returns:
        (B, C, size, size) tensor with the same normalization as input.

    Raises:
        ValueError: If ``crop_frac`` is not positive.
    """
    if crop_frac <= 0:
        raise ValueError(f"crop_frac must be positive, got {crop_frac}")

    B, C, H, W = images.shape

    if H == size and W == size:
        return images

    # grid_sample addresses the input in normalized coordinates where the
    # full image spans [-1, 1]. Scaling the output grid by ``crop_frac``
    # maps the output's [-1, 1] onto the input's [-crop_frac, +crop_frac],
    # i.e. exactly the central ``crop_frac`` of each axis. (Scaling by
    # crop_frac / 2 would keep only half of the intended region.)
    theta = torch.zeros(B, 2, 3, device=images.device, dtype=images.dtype)
    theta[:, 0, 0] = crop_frac  # x scale
    theta[:, 1, 1] = crop_frac  # y scale
    # translation stays 0 (centered)

    grid = F.affine_grid(theta, [B, C, size, size], align_corners=False)
    aligned = F.grid_sample(
        images, grid, mode="bilinear", padding_mode="border", align_corners=False,
    )
    return aligned
426
+
427
+
428
def align_face_no_crop(
    images: torch.Tensor,
    size: int = 112,
) -> torch.Tensor:
    """Resize images to (size x size) without cropping, differentiably.

    Plain bilinear resize via ``F.interpolate`` (``align_corners=False``).
    Intended for inputs that are already tightly cropped faces. Inputs that
    already have the requested spatial size are returned as-is.

    Args:
        images: (B, C, H, W) tensor.
        size: Target spatial size.

    Returns:
        (B, C, size, size) tensor.
    """
    height, width = images.shape[-2], images.shape[-1]
    needs_resize = not (height == size and width == size)
    if needs_resize:
        return F.interpolate(
            images, size=(size, size), mode="bilinear", align_corners=False,
        )
    return images
449
+
450
+
451
+ # ---------------------------------------------------------------------------
452
+ # ArcFaceLoss: differentiable identity preservation loss
453
+ # ---------------------------------------------------------------------------
454
+
455
class ArcFaceLoss(nn.Module):
    """Differentiable identity loss using a PyTorch-native ArcFace backbone.

    Loss = mean(1 - cosine_similarity(embed(pred), embed(target).detach()))

    The backbone is frozen: its parameters have ``requires_grad=False`` and
    it is kept in eval mode (also across ``.train()`` calls on this module
    or a parent). Gradients still flow *through* the backbone into the
    predicted image; the target embedding is computed under ``no_grad``.

    Example::

        loss_fn = ArcFaceLoss(device=torch.device("cuda"))
        loss = loss_fn(pred_images, target_images)
        loss.backward()  # gradients flow into pred_images
    """

    def __init__(
        self,
        device: Optional[torch.device] = None,
        weights_path: Optional[str] = None,
        crop_face: bool = True,
    ):
        """
        Args:
            device: Recorded as ``_target_device``. The backbone is actually
                moved to the device of the first batch it sees (see
                ``_ensure_initialized``).
            weights_path: Path to pretrained backbone.pth. If ``None``,
                searches known locations and attempts download.
            crop_face: Whether to center-crop images before embedding.
                Set ``False`` if images are already tight face crops.
        """
        super().__init__()
        self.crop_face = crop_face
        self._weights_path = weights_path
        self._target_device = device
        self._initialized = False

        # Build backbone on CPU; weights and device placement are deferred
        # to the first use.
        self.backbone = ArcFaceBackbone()

    def train(self, mode: bool = True) -> "ArcFaceLoss":
        """Set the training flag on this module but keep the backbone in eval.

        Without this, a ``.train()`` call on this module or on any parent
        module would switch the frozen backbone's BatchNorm layers to batch
        statistics (and update their running stats), changing the
        embeddings.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def _ensure_initialized(self, device: torch.device) -> None:
        """Lazy initialization: load weights and move to device on first use."""
        if self._initialized:
            return

        # Load pretrained weights
        loaded = load_pretrained_weights(self.backbone, self._weights_path)
        if not loaded:
            logger.warning(
                "ArcFaceLoss using random weights -- identity loss will not "
                "be meaningful. Download pretrained weights for proper training."
            )

        # Move to device and freeze
        self.backbone = self.backbone.to(device)
        self.backbone.eval()
        # Freeze all parameters -- we do NOT want to train ArcFace
        for param in self.backbone.parameters():
            param.requires_grad_(False)

        self._initialized = True

    def _prepare_images(self, images: torch.Tensor) -> torch.Tensor:
        """Prepare images for ArcFace: crop/resize to 112x112, map to [-1, 1].

        Args:
            images: (B, 3, H, W) in [0, 1].

        Returns:
            (B, 3, 112, 112) in [-1, 1].
        """
        if self.crop_face:
            x = align_face(images, size=112)
        else:
            x = align_face_no_crop(images, size=112)

        # Normalize from [0, 1] to [-1, 1]
        x = x * 2.0 - 1.0
        return x

    def _extract_embedding(
        self,
        images: torch.Tensor,
        enable_grad: bool = True,
    ) -> torch.Tensor:
        """Run the backbone on prepared images.

        Args:
            images: (B, 3, 112, 112) in [-1, 1].
            enable_grad: If ``True``, the forward pass is recorded by
                autograd so gradients reach ``images`` (used for pred).
                If ``False``, it runs under ``torch.no_grad()`` (target).

        Returns:
            (B, 512) L2-normalized embeddings.
        """
        if enable_grad:
            # Backbone parameters are frozen, so only the input tensor
            # carries gradients: d(loss)/d(pred image), not d(loss)/d(weights).
            return self.backbone(images)
        with torch.no_grad():
            return self.backbone(images)

    def forward(
        self,
        pred_image: torch.Tensor,
        target_image: torch.Tensor,
        procedure: str = "rhinoplasty",
    ) -> torch.Tensor:
        """Compute differentiable identity loss.

        Args:
            pred_image: (B, 3, H, W) predicted images in [0, 1].
                Gradients WILL flow back through this tensor.
            target_image: (B, 3, H, W) target images in [0, 1].
                Gradients will NOT flow through this (detached).
            procedure: Surgical procedure type. ``"orthognathic"`` returns
                zero loss.

        Returns:
            Scalar loss: mean(1 - cosine_similarity(pred_emb, target_emb)).
            Returns a constant 0 (not connected to the autograd graph) for
            ``"orthognathic"`` or for empty batches.
        """
        if procedure == "orthognathic":
            return torch.tensor(0.0, device=pred_image.device, dtype=pred_image.dtype)

        # The mean over an empty batch would be NaN and poison the total
        # loss; return the documented zero instead.
        if pred_image.shape[0] == 0:
            return torch.tensor(0.0, device=pred_image.device, dtype=pred_image.dtype)

        device = pred_image.device
        self._ensure_initialized(device)

        # Procedure-specific cropping (before ArcFace alignment)
        pred_crop = self._procedure_crop(pred_image, procedure)
        target_crop = self._procedure_crop(target_image, procedure)

        # Prepare for ArcFace (crop, resize to 112x112, normalize to [-1, 1])
        pred_prepared = self._prepare_images(pred_crop)
        target_prepared = self._prepare_images(target_crop)

        # pred: WITH gradient flow (so the generator gets identity signal)
        pred_emb = self._extract_embedding(pred_prepared, enable_grad=True)
        # target: WITHOUT gradient flow; detach() is a belt-and-braces guard
        target_emb = self._extract_embedding(target_prepared, enable_grad=False)
        target_emb = target_emb.detach()

        # Both embeddings are already L2-normalized by the backbone, so the
        # dot product is the cosine similarity.
        cosine_sim = (pred_emb * target_emb).sum(dim=1)  # (B,)

        # Clamp to valid range (numerical safety for BF16)
        cosine_sim = cosine_sim.clamp(-1.0, 1.0)

        loss = (1.0 - cosine_sim).mean()
        return loss

    def _procedure_crop(
        self,
        image: torch.Tensor,
        procedure: str,
    ) -> torch.Tensor:
        """Crop rows of ``image`` depending on the surgical procedure.

        ``"rhinoplasty"`` keeps the top 2/3 of the rows, ``"rhytidectomy"``
        the top 3/4; every other value (including ``"blepharoplasty"``)
        returns the image unchanged.
        """
        _, _, h, w = image.shape

        if procedure == "rhinoplasty":
            # Upper face crop (forehead to nose tip) -- exclude surgical region
            return image[:, :, : h * 2 // 3, :]
        elif procedure == "blepharoplasty":
            # Full face
            return image
        elif procedure == "rhytidectomy":
            # Upper face (above jawline)
            return image[:, :, : h * 3 // 4, :]
        else:
            return image

    def get_embedding(self, images: torch.Tensor) -> torch.Tensor:
        """Extract identity embeddings (utility method for evaluation).

        No procedure-specific crop is applied here.

        Args:
            images: (B, 3, H, W) in [0, 1].

        Returns:
            (B, 512) L2-normalized embeddings computed under ``no_grad``.
        """
        self._ensure_initialized(images.device)
        prepared = self._prepare_images(images)
        return self._extract_embedding(prepared, enable_grad=False)
659
+
660
+
661
+ # ---------------------------------------------------------------------------
662
+ # Convenience: create a pre-configured loss instance
663
+ # ---------------------------------------------------------------------------
664
+
665
def create_arcface_loss(
    device: Optional[torch.device] = None,
    weights_path: Optional[str] = None,
) -> ArcFaceLoss:
    """Build an ``ArcFaceLoss`` with its default options.

    Both arguments are forwarded unchanged to the ``ArcFaceLoss``
    constructor; every other constructor option keeps its default.

    Args:
        device: Forwarded as ``ArcFaceLoss(device=...)``.
        weights_path: Forwarded as ``ArcFaceLoss(weights_path=...)``.

    Returns:
        The newly constructed ``ArcFaceLoss`` instance.
    """
    loss_module = ArcFaceLoss(weights_path=weights_path, device=device)
    return loss_module