| """Minimal extracted cycle losses for the user's multi-control method.""" |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class MultiConditionCycleLoss(nn.Module):
    """Sum the weighted cycle losses selected by a control-mode string.

    The control mode is split on ``"_"`` (e.g. ``"depth_seg"`` selects the
    ``depth`` and ``seg`` branches). A branch contributes only when its token
    is present, its loss module was provided, and its weight is non-zero.
    """

    def __init__(
        self,
        depth_cycle_loss: nn.Module | None = None,
        seg_cycle_loss: nn.Module | None = None,
        edge_cycle_loss: nn.Module | None = None,
        depth_weight: float = 1.0,
        seg_weight: float = 1.0,
        edge_weight: float = 1.0,
    ):
        """Store the optional per-condition loss modules and their weights.

        Args:
            depth_cycle_loss: Called as ``loss(gen_image_m11, depth_01)``.
            seg_cycle_loss: Called as ``loss(gen_image_m11, seg_01)``.
            edge_cycle_loss: Called as ``loss(gen_image_m11, gt_image_m11)``.
            depth_weight: Multiplier applied to the depth loss.
            seg_weight: Multiplier applied to the segmentation loss.
            edge_weight: Multiplier applied to the edge loss.
        """
        super().__init__()
        # Sub-losses are assigned in depth, seg, edge order so that module
        # registration order is stable.
        self.depth_cycle_loss = depth_cycle_loss
        self.seg_cycle_loss = seg_cycle_loss
        self.edge_cycle_loss = edge_cycle_loss
        self.depth_weight = float(depth_weight)
        self.seg_weight = float(seg_weight)
        self.edge_weight = float(edge_weight)

    def forward(
        self,
        gen_image_m11: torch.Tensor,
        depth_01: torch.Tensor | None = None,
        seg_01: torch.Tensor | None = None,
        gt_image_m11: torch.Tensor | None = None,
        control_mode: str = "depth_seg",
    ) -> torch.Tensor:
        """Return the weighted sum of the cycle losses active for this mode.

        Args:
            gen_image_m11: Generated image passed as first argument to every
                active loss.
            depth_01: Target for the depth branch.
            seg_01: Target for the segmentation branch.
            gt_image_m11: Target for the edge branch.
            control_mode: ``"_"``-separated tokens naming the active branches.

        Returns:
            A scalar tensor; zero (created from ``gen_image_m11``) when no
            branch is active.

        Raises:
            ValueError: If an active branch is missing its target tensor.
        """
        requested = set(str(control_mode).split("_"))
        # (token, loss module, weight, target, message when target is missing)
        branches = (
            (
                "depth",
                self.depth_cycle_loss,
                self.depth_weight,
                depth_01,
                "depth cycle requested but depth_01 is None",
            ),
            (
                "seg",
                self.seg_cycle_loss,
                self.seg_weight,
                seg_01,
                "seg cycle requested but seg_01 is None",
            ),
            (
                "edge",
                self.edge_cycle_loss,
                self.edge_weight,
                gt_image_m11,
                "edge cycle requested but gt_image_m11 is None",
            ),
        )

        accumulated = gen_image_m11.new_zeros(())
        for token, criterion, weight, target, missing_msg in branches:
            # A branch is skipped silently when not requested or disabled.
            if token not in requested or criterion is None or weight == 0.0:
                continue
            if target is None:
                raise ValueError(missing_msg)
            accumulated = accumulated + weight * criterion(gen_image_m11, target)
        return accumulated
|
|
|
|
class SoftCannyImagePyramidCycleLoss(nn.Module):
    """Differentiable edge-consistency loss between generated and GT RGB images.

    Generated RGB and GT RGB are both converted to soft edge maps (Gaussian
    blur, Sobel gradient magnitude, per-sample min-max normalisation, sigmoid
    around a threshold) with the same sampled threshold, and the two maps are
    compared with a smooth-L1 loss. This avoids directly matching noisy/random
    offline Canny labels.

    The module stores only Python scalars and lists; it has no parameters or
    buffers, and the filter kernels are rebuilt on every call.
    """

    def __init__(
        self,
        loss_res: int = 128,
        smooth_l1_beta: float = 0.05,
        enable_pyramid_cycle_loss: bool = True,
        cycle_scales=(512, 256, 128, 64),
        cycle_scale_weights=(0.1, 0.25, 1.0, 0.25),
        gaussian_kernel: int = 11,
        threshold_min: float = 0.2745,
        threshold_max: float = 0.5882,
        temperature: float = 0.03,
    ):
        """Configure the edge loss.

        Args:
            loss_res: Side length used when the pyramid is disabled.
            smooth_l1_beta: ``beta`` passed to ``F.smooth_l1_loss``.
            enable_pyramid_cycle_loss: If true, sum the loss over
                ``cycle_scales``; otherwise evaluate once at ``loss_res``.
            cycle_scales: Side lengths the images are resized to.
            cycle_scale_weights: Per-scale multipliers; a zero weight skips
                that scale.
            gaussian_kernel: Side length of the blur kernel (positive, odd).
            threshold_min: One end of the threshold sampling range
                (presumably 70/255 by default -- TODO confirm).
            threshold_max: Other end of the threshold sampling range
                (presumably 150/255 by default -- TODO confirm).
            temperature: Sigmoid softness; floored at ``1e-6`` when used.

        Raises:
            ValueError: If the scale and weight sequences differ in length or
                ``gaussian_kernel`` is not a positive odd integer.
        """
        super().__init__()
        self.loss_res = int(loss_res)
        self.smooth_l1_beta = float(smooth_l1_beta)
        self.enable_pyramid_cycle_loss = bool(enable_pyramid_cycle_loss)
        self.cycle_scales = [int(s) for s in cycle_scales]
        self.cycle_scale_weights = [float(w) for w in cycle_scale_weights]
        if len(self.cycle_scales) != len(self.cycle_scale_weights):
            raise ValueError("cycle_scales and cycle_scale_weights must have same length")
        self.gaussian_kernel = int(gaussian_kernel)
        if self.gaussian_kernel <= 0 or self.gaussian_kernel % 2 == 0:
            raise ValueError(f"gaussian_kernel must be a positive odd int, got {gaussian_kernel}")
        self.threshold_min = float(threshold_min)
        self.threshold_max = float(threshold_max)
        self.temperature = float(temperature)

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """Return the destination mapping without adding any entries.

        This loss contributes nothing to checkpoints. ``destination`` may be
        given by keyword or, in the legacy positional form, as the first
        positional argument; in both cases the very same object is returned.
        A new empty dict is returned when no destination is supplied.
        ``prefix`` and ``keep_vars`` are accepted for signature compatibility
        and are unused.
        """
        if destination is None and args:
            # Legacy positional call: state_dict(destination, prefix, keep_vars).
            destination = args[0]
        if destination is None:
            destination = {}
        return destination

    @staticmethod
    def _to_gray_01(image_m11: torch.Tensor) -> torch.Tensor:
        """Map a ``[-1, 1]`` 3-channel image to a float32 1-channel ``[0, 1]`` luma."""
        image_01 = (image_m11.float() + 1.0) * 0.5
        # Weighted channel sum; slices keep the channel axis (dim 1).
        return 0.299 * image_01[:, 0:1] + 0.587 * image_01[:, 1:2] + 0.114 * image_01[:, 2:3]

    @staticmethod
    def _sobel_xy(x: torch.Tensor):
        """Return the horizontal and vertical 3x3 Sobel responses of ``x``."""
        kx = x.new_tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).view(1, 1, 3, 3)
        ky = x.new_tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]).view(1, 1, 3, 3)
        return F.conv2d(x, kx, padding=1), F.conv2d(x, ky, padding=1)

    def _gaussian_blur(self, x: torch.Tensor) -> torch.Tensor:
        """Blur a single-channel map with a normalised ``k x k`` Gaussian kernel."""
        k = self.gaussian_kernel
        # Sigma derived from the kernel size (same formula as OpenCV's
        # getGaussianKernel default).
        sigma = 0.3 * ((k - 1) * 0.5 - 1.0) + 0.8
        coords = torch.arange(k, device=x.device, dtype=x.dtype) - (k - 1) / 2
        g = torch.exp(-(coords.square()) / (2 * sigma * sigma))
        g = g / g.sum().clamp_min(1e-8)
        # Outer product of the 1-D kernel gives the separable 2-D kernel.
        kernel = (g[:, None] * g[None, :]).view(1, 1, k, k)
        # NOTE(review): conv2d zero-pads, so blurred values fall off toward the
        # image border; this may create border gradients that influence the
        # per-sample max used for normalisation -- confirm this is intended.
        return F.conv2d(x, kernel, padding=k // 2)

    def _sample_threshold(self, batch: int, device, dtype) -> torch.Tensor:
        """Draw one uniform threshold per sample, shaped ``(batch, 1, 1, 1)``."""
        # sorted() tolerates threshold_min/threshold_max given in either order.
        lo, hi = sorted((self.threshold_min, self.threshold_max))
        t = torch.rand((batch, 1, 1, 1), device=device, dtype=dtype)
        return lo + (hi - lo) * t

    def _soft_canny_01(self, image_m11: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
        """Return a soft edge map in ``(0, 1)`` for ``image_m11``."""
        gray = self._gaussian_blur(self._to_gray_01(image_m11))
        gx, gy = self._sobel_xy(gray)
        # The epsilon keeps sqrt differentiable where the gradient is zero.
        mag = torch.sqrt(gx.square() + gy.square() + 1e-8)
        # Min-max normalise each sample independently.
        flat = mag.flatten(1)
        lo = flat.amin(dim=1, keepdim=True).view(-1, 1, 1, 1)
        hi = flat.amax(dim=1, keepdim=True).view(-1, 1, 1, 1)
        mag = ((mag - lo) / (hi - lo).clamp_min(1e-6)).clamp(0.0, 1.0)
        # Sigmoid replaces the hard threshold so the map stays differentiable.
        return torch.sigmoid((mag - threshold) / max(self.temperature, 1e-6))

    @staticmethod
    def _resize_image(x: torch.Tensor, scale: int) -> torch.Tensor:
        """Bilinearly resize ``x`` to a square ``scale x scale`` image."""
        return F.interpolate(x, size=(scale, scale), mode="bilinear", align_corners=False)

    def _cycle_at_scale(
        self,
        gen_image_m11: torch.Tensor,
        gt_image_m11: torch.Tensor,
        scale: int,
        threshold: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean smooth-L1 distance between edge maps at one scale."""
        gen_s = self._resize_image(gen_image_m11, scale)
        gt_s = self._resize_image(gt_image_m11, scale)
        gen_edge = self._soft_canny_01(gen_s, threshold)
        gt_edge = self._soft_canny_01(gt_s, threshold)
        # The GT edge map is a fixed target: no gradient flows through it.
        return F.smooth_l1_loss(gen_edge, gt_edge.detach(), beta=self.smooth_l1_beta, reduction="mean")

    def forward(self, gen_image_m11: torch.Tensor, gt_image_m11: torch.Tensor) -> torch.Tensor:
        """Compute the edge-consistency loss.

        Args:
            gen_image_m11: Generated image, ``[B, 3, H, W]``.
            gt_image_m11: Ground-truth image with the same batch size and 3
                channels. Both images are resized independently to each
                scale, so their spatial sizes are not compared here.

        Returns:
            A scalar tensor: the weighted sum over ``cycle_scales`` when the
            pyramid is enabled, otherwise the loss at ``loss_res``.

        Raises:
            ValueError: If either input is not 4-D, or the channel count or
                batch size does not match ``[B, 3, H, W]``.
        """
        shape_msg = "SoftCannyImagePyramidCycleLoss expects generated and GT RGB tensors [B,3,H,W]"
        # Check rank up front so the caller gets this message rather than an
        # unpacking error or a late failure inside F.interpolate.
        if gen_image_m11.ndim != 4 or gt_image_m11.ndim != 4:
            raise ValueError(shape_msg)
        b, c, _, _ = gen_image_m11.shape
        if c != 3 or gt_image_m11.shape[:2] != (b, 3):
            raise ValueError(shape_msg)
        # One threshold per sample, shared by both images and by all scales.
        threshold = self._sample_threshold(b, gen_image_m11.device, gen_image_m11.dtype)
        if not self.enable_pyramid_cycle_loss:
            return self._cycle_at_scale(gen_image_m11, gt_image_m11, self.loss_res, threshold)
        total = gen_image_m11.new_zeros(())
        for scale, weight in zip(self.cycle_scales, self.cycle_scale_weights):
            if weight != 0:
                total = total + float(weight) * self._cycle_at_scale(gen_image_m11, gt_image_m11, scale, threshold)
        return total


# Alternative name for the same class.
EdgePyramidCycleLoss = SoftCannyImagePyramidCycleLoss
|
|