"""4-term loss function module for ControlNet fine-tuning.
L_total = L_diffusion + w_landmark * L_landmark + w_identity * L_identity + w_perceptual * L_perceptual
Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
rubbery TPS warps — it would penalize realism.
Phase B (FEM/clinical data): All 4 terms enabled.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn.functional as F
@dataclass(frozen=True)
class LossWeights:
    """Per-term weights for the combined training loss.

    The diffusion term is the primary signal; the auxiliary terms default
    to weights roughly an order of magnitude smaller.
    """

    diffusion: float = 1.0    # weight for the epsilon-prediction MSE term
    landmark: float = 0.1     # weight for the masked landmark-distance term
    identity: float = 0.1     # weight for the ArcFace identity term
    perceptual: float = 0.05  # weight for the LPIPS perceptual term
class DiffusionLoss:
    """Epsilon-prediction MSE — the primary denoising training signal."""

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean squared error between predicted and true noise."""
        residual = noise_pred - noise_target
        return residual.pow(2).mean()
class LandmarkLoss:
    """Mean L2 landmark distance, optionally masked and IOD-normalized.

    Distances are computed INSIDE the surgical mask only when a mask is
    supplied. Landmark re-extraction (MediaPipe) from the generated image
    is done at eval time, not at every training step, for speed.
    """

    def __call__(
        self,
        pred_landmarks: torch.Tensor,       # (B, N, 2)
        target_landmarks: torch.Tensor,     # (B, N, 2)
        mask: torch.Tensor | None = None,   # (B, N) binary
        iod: torch.Tensor | None = None,    # (B,) inter-ocular distance
    ) -> torch.Tensor:
        # Per-point Euclidean distance: (B, N).
        per_point = torch.norm(pred_landmarks - target_landmarks, dim=-1)
        if mask is None:
            per_sample = per_point.mean(dim=-1)
        else:
            # Average over masked-in points only; clamp keeps an all-zero
            # mask from dividing by zero.
            kept = mask.sum(dim=-1).clamp(min=1)
            per_sample = (per_point * mask).sum(dim=-1) / kept
        if iod is not None:
            # NOTE(review): clamp(min=1.0) assumes IOD is in pixel units;
            # normalized-coordinate IODs (< 1) would be flattened — confirm.
            per_sample = per_sample / iod.clamp(min=1.0)
        return per_sample.mean()
class IdentityLoss:
    """ArcFace cosine-similarity identity loss with procedure-dependent masking.

    Uses the InsightFace ArcFace model (buffalo_l) for 512-dim identity
    embeddings, falling back to pixel-level cosine similarity if
    InsightFace is unavailable.

    Procedure handling:
      - Full face for blepharoplasty
      - Upper-face crop for rhinoplasty
      - Disabled (returns 0) for orthognathic

    Images are cropped, resized to 112x112, and normalized to [-1, 1]
    internally before embedding extraction (ArcFace expects 112x112 input;
    feeding 1024x1024 directly produces unusable embeddings).

    NOTE: the ONNX-based InsightFace path runs under ``torch.no_grad`` and
    therefore provides no gradient signal (see CombinedLoss / DA2-03).
    """

    def __init__(self, device: torch.device | None = None):
        self._app = None          # InsightFace FaceAnalysis app (lazy-loaded)
        self._device = device     # preferred device hint (currently unused)
        self._has_arcface = None  # None = not probed yet; True/False after probe

    def _ensure_loaded(self, device: torch.device) -> None:
        """Lazy-load the ArcFace model on first use and record availability."""
        if self._has_arcface is not None:
            return
        try:
            from insightface.app import FaceAnalysis

            self._app = FaceAnalysis(
                name="buffalo_l",
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            # Map the torch device onto an insightface ctx_id (-1 = CPU).
            if device.type == "cuda":
                ctx_id = device.index if device.index is not None else 0
            else:
                ctx_id = -1
            self._app.prepare(ctx_id=ctx_id, det_size=(320, 320))
            self._has_arcface = True
        except Exception:
            # Any failure (missing package, missing model weights, ONNX
            # runtime issues) downgrades to the pixel-level fallback.
            self._has_arcface = False

    @torch.no_grad()
    def _extract_embedding(
        self, image_tensor: torch.Tensor
    ) -> tuple[torch.Tensor, list[bool]]:
        """Extract identity embeddings from a batch of images.

        Args:
            image_tensor: (B, 3, 112, 112) in [-1, 1].

        Returns:
            Tuple of:
              - (B, 512) ArcFace embeddings (zeros where no face was
                detected), or (B, D) flattened pixels in the fallback path.
              - list of B bools, True where a valid embedding was produced.
        """
        if not self._has_arcface:
            # Fallback: raw pixels act as the "embedding"; all samples valid.
            return image_tensor.flatten(1), [True] * image_tensor.shape[0]

        import numpy as np

        embeddings = []
        valid_mask = []
        for i in range(image_tensor.shape[0]):
            # Convert [-1, 1] CHW float -> uint8 HWC BGR for InsightFace.
            img = ((image_tensor[i].permute(1, 2, 0) + 1) / 2 * 255).clamp(0, 255)
            img_np = img.cpu().numpy().astype(np.uint8)
            img_bgr = img_np[:, :, ::-1].copy()
            faces = self._app.get(img_bgr)
            if faces and getattr(faces[0], "embedding", None) is not None:
                embeddings.append(torch.from_numpy(faces[0].embedding))
                valid_mask.append(True)
            else:
                # No detectable face: zero placeholder, flagged invalid.
                embeddings.append(torch.zeros(512))
                valid_mask.append(False)
        return torch.stack(embeddings).to(image_tensor.device), valid_mask

    def __call__(
        self,
        pred_image: torch.Tensor,    # (B, 3, H, W) in [0, 1]
        target_image: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        procedure: str = "rhinoplasty",
    ) -> torch.Tensor:
        """Return mean (1 - cosine similarity) over samples with valid faces.

        Returns 0 for orthognathic procedures (identity is not enforced
        there) and when no sample pair has a face detected on both sides.
        """
        if procedure == "orthognathic":
            # Identity loss disabled for orthognathic procedures.
            return torch.tensor(0.0, device=pred_image.device)
        self._ensure_loaded(pred_image.device)
        # Procedure-dependent crop, then resize to ArcFace's 112x112 input.
        pred_crop = self._procedure_crop(pred_image, procedure)
        target_crop = self._procedure_crop(target_image, procedure)
        pred_112 = F.interpolate(pred_crop, size=(112, 112), mode="bilinear", align_corners=False)
        target_112 = F.interpolate(target_crop, size=(112, 112), mode="bilinear", align_corners=False)
        # Shift [0, 1] -> [-1, 1] as the embedding extractor expects.
        pred_norm = pred_112 * 2 - 1
        target_norm = target_112 * 2 - 1
        pred_emb, pred_valid = self._extract_embedding(pred_norm)
        target_emb, target_valid = self._extract_embedding(target_norm)
        # Only samples where BOTH faces were detected contribute to the loss.
        valid = [p and t for p, t in zip(pred_valid, target_valid)]
        if not any(valid):
            return torch.tensor(0.0, device=pred_image.device)
        valid_indices = [i for i, v in enumerate(valid) if v]
        valid_idx_t = torch.tensor(valid_indices, device=pred_image.device, dtype=torch.long)
        # Select ONLY valid embeddings before normalization so the zero
        # placeholders can never produce a 0/0 NaN.
        pred_valid_emb = F.normalize(pred_emb[valid_idx_t].float(), dim=1)
        target_valid_emb = F.normalize(target_emb[valid_idx_t].float(), dim=1)
        cosine_sim = (pred_valid_emb * target_valid_emb).sum(dim=1)
        return (1 - cosine_sim).mean()

    def _procedure_crop(
        self,
        image: torch.Tensor,
        procedure: str,
    ) -> torch.Tensor:
        """Crop the region used for identity comparison for this procedure."""
        _, _, h, _ = image.shape
        if procedure == "rhinoplasty":
            # Upper-face crop (forehead to nose tip): top 2/3 of the image.
            return image[:, :, : h * 2 // 3, :]
        if procedure == "blepharoplasty":
            # Full face.
            return image
        if procedure == "rhytidectomy":
            # Upper face (above jawline): top 3/4 of the image.
            return image[:, :, : h * 3 // 4, :]
        # Unknown procedure: compare full images.
        return image
class PerceptualLoss:
    """LPIPS perceptual loss restricted to regions OUTSIDE the surgical mask.

    LPIPS expects inputs in [-1, 1] while the VAE decodes to [0, 1], so
    (x * 2) - 1 is applied before every call. Falls back to plain L1 when
    the ``lpips`` package cannot be imported.
    """

    def __init__(self):
        self._lpips = None  # None = not loaded; "unavailable" = import failed

    def _ensure_loaded(self, device: torch.device) -> None:
        """One-shot lazy load of the frozen AlexNet LPIPS network."""
        if self._lpips is not None:
            return
        try:
            import lpips

            net = lpips.LPIPS(net="alex").to(device)
            net.eval()
            for param in net.parameters():
                param.requires_grad_(False)
            self._lpips = net
        except ImportError:
            self._lpips = "unavailable"

    def __call__(
        self,
        pred: torch.Tensor,    # (B, 3, H, W) in [0, 1]
        target: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        mask: torch.Tensor,    # (B, 1, H, W) surgical mask in [0, 1]
    ) -> torch.Tensor:
        self._ensure_loaded(pred.device)
        # Shift both images into LPIPS's expected [-1, 1] range.
        pred_norm = pred * 2 - 1
        target_norm = target * 2 - 1
        # An (almost) all-ones mask means "no mask file available" -> score
        # the full image. Otherwise invert so only the region OUTSIDE the
        # surgical area contributes.
        if mask.sum() < mask.numel() * 0.99:
            outside = 1 - mask
            k = 5
            if outside.shape[-2] >= k and outside.shape[-1] >= k:
                # Erode the outside region (min-pool == -maxpool(-x)) so
                # pixels at the mask boundary are excluded too.
                outside = -F.max_pool2d(
                    -outside, kernel_size=k, stride=1, padding=k // 2
                )
            pred_norm = pred_norm * outside
            target_norm = target_norm * outside
        if self._lpips == "unavailable":
            # Fallback: simple L1 on the (masked) normalized images.
            return F.l1_loss(pred_norm, target_norm)
        return self._lpips(pred_norm, target_norm).mean()
class CombinedLoss:
    """Weighted combination of the four training loss terms.

    phase='A': diffusion loss only (Phase A, synthetic TPS data).
    phase='B': diffusion + landmark + identity + perceptual (Phase B).

    For Phase B, set ``use_differentiable_arcface=True`` to use the
    PyTorch-native ArcFace backbone (``arcface_torch.py``) that provides
    actual gradient signal. The default ONNX-based IdentityLoss produces
    zero gradients (DA2-03).
    """

    def __init__(
        self,
        weights: LossWeights | None = None,
        phase: str = "A",
        use_differentiable_arcface: bool = False,
        arcface_weights_path: str | None = None,
    ):
        self.weights = weights if weights is not None else LossWeights()
        self.phase = phase
        self.diffusion_loss = DiffusionLoss()
        self.landmark_loss = LandmarkLoss()
        self.perceptual_loss = PerceptualLoss()
        if use_differentiable_arcface:
            # Differentiable PyTorch ArcFace (Phase B gradient signal).
            from landmarkdiff.arcface_torch import ArcFaceLoss

            self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
        else:
            # ONNX/InsightFace-based identity loss (no gradient signal).
            self.identity_loss = IdentityLoss()

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Return the individual weighted terms plus their sum under 'total'.

        Auxiliary terms are computed only in phase 'B' and only when the
        corresponding kwargs are present (pred/target landmarks, decoded
        images, surgical mask).
        """
        w = self.weights
        losses: dict[str, torch.Tensor] = {}
        # The diffusion term is always computed.
        losses["diffusion"] = w.diffusion * self.diffusion_loss(noise_pred, noise_target)
        losses["total"] = losses["diffusion"]
        if self.phase == "B":
            # Landmark term: requires landmarks re-extracted for both images.
            if "pred_landmarks" in kwargs and "target_landmarks" in kwargs:
                losses["landmark"] = w.landmark * self.landmark_loss(
                    kwargs["pred_landmarks"],
                    kwargs["target_landmarks"],
                    kwargs.get("landmark_mask"),
                    kwargs.get("iod"),
                )
                losses["total"] = losses["total"] + losses["landmark"]
            has_images = "pred_image" in kwargs and "target_image" in kwargs
            # Identity term: requires decoded images.
            if has_images:
                losses["identity"] = w.identity * self.identity_loss(
                    kwargs["pred_image"],
                    kwargs["target_image"],
                    kwargs.get("procedure", "rhinoplasty"),
                )
                losses["total"] = losses["total"] + losses["identity"]
            # Perceptual term: requires decoded images plus the surgical mask.
            if has_images and "mask" in kwargs:
                losses["perceptual"] = w.perceptual * self.perceptual_loss(
                    kwargs["pred_image"],
                    kwargs["target_image"],
                    kwargs["mask"],
                )
                losses["total"] = losses["total"] + losses["perceptual"]
        return losses