# linxin02's picture
# Open-source PixelControl code (relative paths, identity scrubbed)
# 497c818 verified
# Raw
# History Blame Contribute Delete
# 7.3 kB
"""Minimal extracted cycle losses for the user's multi-control method."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiConditionCycleLoss(nn.Module):
    """Weighted sum of the cycle losses selected by a control-mode string.

    ``control_mode`` is converted to ``str`` and split on ``"_"``. Each of the
    tokens ``"depth"``, ``"seg"`` and ``"edge"`` switches on the matching
    loss, but only if that loss was supplied to the constructor and its
    weight is non-zero. Every active loss is called as
    ``loss(gen_image_m11, target)`` and scaled by its weight.
    """

    def __init__(
        self,
        depth_cycle_loss: nn.Module | None = None,
        seg_cycle_loss: nn.Module | None = None,
        edge_cycle_loss: nn.Module | None = None,
        depth_weight: float = 1.0,
        seg_weight: float = 1.0,
        edge_weight: float = 1.0,
    ):
        """Store the optional per-condition losses and their float weights."""
        super().__init__()
        # Assignment order is kept so sub-module registration order is stable.
        self.depth_cycle_loss = depth_cycle_loss
        self.seg_cycle_loss = seg_cycle_loss
        self.edge_cycle_loss = edge_cycle_loss
        self.depth_weight = float(depth_weight)
        self.seg_weight = float(seg_weight)
        self.edge_weight = float(edge_weight)

    def forward(
        self,
        gen_image_m11: torch.Tensor,
        depth_01: torch.Tensor | None = None,
        seg_01: torch.Tensor | None = None,
        gt_image_m11: torch.Tensor | None = None,
        control_mode: str = "depth_seg",
    ) -> torch.Tensor:
        """Return the weighted sum of the active cycle losses.

        Args:
            gen_image_m11: Generated image; first argument of every sub-loss.
            depth_01: Target handed to the depth loss.
            seg_01: Target handed to the segmentation loss.
            gt_image_m11: Target handed to the edge loss.
            control_mode: ``"_"``-separated tokens naming the active losses.

        Returns:
            A tensor that starts as ``gen_image_m11.new_zeros(())`` and has
            each active weighted loss added to it.

        Raises:
            ValueError: If a loss is active but its target is ``None``.
        """
        active = set(str(control_mode).split("_"))
        # (token, loss module, weight, target, message when target is missing)
        branches = (
            (
                "depth",
                self.depth_cycle_loss,
                self.depth_weight,
                depth_01,
                "depth cycle requested but depth_01 is None",
            ),
            (
                "seg",
                self.seg_cycle_loss,
                self.seg_weight,
                seg_01,
                "seg cycle requested but seg_01 is None",
            ),
            (
                "edge",
                self.edge_cycle_loss,
                self.edge_weight,
                gt_image_m11,
                "edge cycle requested but gt_image_m11 is None",
            ),
        )

        loss_sum = gen_image_m11.new_zeros(())
        for token, criterion, weight, target, missing_msg in branches:
            if token not in active or criterion is None or weight == 0.0:
                continue
            if target is None:
                raise ValueError(missing_msg)
            loss_sum = loss_sum + weight * criterion(gen_image_m11, target)
        return loss_sum
class SoftCannyImagePyramidCycleLoss(nn.Module):
    """Differentiable edge consistency loss.

    Generated RGB and GT RGB are both converted to soft edge maps with the same
    sampled threshold. This avoids directly matching noisy/random offline Canny
    labels.

    Per image the pipeline is: luma conversion -> Gaussian blur -> Sobel
    gradient magnitude -> per-image min-max normalisation -> sigmoid soft
    threshold. The two edge maps are compared with Smooth-L1, either once at
    ``loss_res`` or as a weighted sum over ``cycle_scales``.

    Args:
        loss_res: Square side length used when the pyramid is disabled.
        smooth_l1_beta: ``beta`` forwarded to ``F.smooth_l1_loss``.
        enable_pyramid_cycle_loss: If true, sum the loss over ``cycle_scales``;
            otherwise evaluate a single level at ``loss_res``.
        cycle_scales: Square side lengths of the pyramid levels.
        cycle_scale_weights: One weight per pyramid level; levels whose weight
            is 0 are skipped.
        gaussian_kernel: Blur kernel size; must be a positive odd integer.
        threshold_min: One bound of the uniformly sampled threshold on the
            normalised gradient magnitude (presumably 70/255 -- TODO confirm).
        threshold_max: The other bound (presumably 150/255 -- TODO confirm).
            The two bounds are sorted before use, so their order is free.
        temperature: Sigmoid temperature; floored at ``1e-6`` when applied.
        border_mode: Border handling for the blur and Sobel convolutions.
            ``"zeros"`` (default) reproduces the original zero-padded
            behaviour, which yields strong artificial gradients along the
            border of any non-black image; those also enter the per-image
            min-max normalisation. ``"replicate"`` and ``"reflect"`` pad with
            ``F.pad`` first and avoid that frame (``"reflect"`` needs the
            padding to be smaller than the image side).

    Raises:
        ValueError: If scales and weights differ in length, if
            ``gaussian_kernel`` is not a positive odd integer, or if
            ``border_mode`` is not one of the supported modes.
    """

    _BORDER_MODES = ("zeros", "replicate", "reflect")

    # Class-level fallback so instances created without this attribute
    # (e.g. unpickled from an older version) keep the zero-padded behaviour.
    border_mode: str = "zeros"

    def __init__(
        self,
        loss_res: int = 128,
        smooth_l1_beta: float = 0.05,
        enable_pyramid_cycle_loss: bool = True,
        cycle_scales=(512, 256, 128, 64),
        cycle_scale_weights=(0.1, 0.25, 1.0, 0.25),
        gaussian_kernel: int = 11,
        threshold_min: float = 0.2745,
        threshold_max: float = 0.5882,
        temperature: float = 0.03,
        border_mode: str = "zeros",
    ):
        super().__init__()
        self.loss_res = int(loss_res)
        self.smooth_l1_beta = float(smooth_l1_beta)
        self.enable_pyramid_cycle_loss = bool(enable_pyramid_cycle_loss)
        self.cycle_scales = [int(s) for s in cycle_scales]
        self.cycle_scale_weights = [float(w) for w in cycle_scale_weights]
        if len(self.cycle_scales) != len(self.cycle_scale_weights):
            raise ValueError("cycle_scales and cycle_scale_weights must have same length")
        self.gaussian_kernel = int(gaussian_kernel)
        if self.gaussian_kernel <= 0 or self.gaussian_kernel % 2 == 0:
            raise ValueError(f"gaussian_kernel must be a positive odd int, got {gaussian_kernel}")
        self.threshold_min = float(threshold_min)
        self.threshold_max = float(threshold_max)
        self.temperature = float(temperature)
        if border_mode not in self._BORDER_MODES:
            raise ValueError(
                f"border_mode must be one of {self._BORDER_MODES}, got {border_mode!r}"
            )
        self.border_mode = str(border_mode)

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """Return an empty state (or the caller's ``destination`` untouched).

        ``destination`` may be given by keyword or as the first positional
        argument; either way the very same object is returned.
        """
        # Stateless loss. Returning empty state keeps checkpoints small.
        if destination is None and args:
            # Honour the positional calling convention instead of silently
            # dropping the caller's dict and returning a fresh one.
            destination = args[0]
        if destination is None:
            destination = {}
        return destination

    @staticmethod
    def _to_gray_01(image_m11: torch.Tensor) -> torch.Tensor:
        """Map an RGB tensor from [-1, 1] to [0, 1] and reduce it to one luma channel.

        The result is float32 (``.float()``) and keeps a singleton channel dim.
        """
        image_01 = (image_m11.float() + 1.0) * 0.5
        # ITU-R BT.601 luma weights applied to channels 0, 1, 2.
        return 0.299 * image_01[:, 0:1] + 0.587 * image_01[:, 1:2] + 0.114 * image_01[:, 2:3]

    @staticmethod
    def _conv2d_same(x: torch.Tensor, kernel: torch.Tensor, border_mode: str = "zeros") -> torch.Tensor:
        """Size-preserving conv2d for an odd square kernel with selectable border handling."""
        pad = kernel.shape[-1] // 2
        if border_mode == "zeros" or pad == 0:
            # Identical call to the original implementation.
            return F.conv2d(x, kernel, padding=pad)
        return F.conv2d(F.pad(x, (pad, pad, pad, pad), mode=border_mode), kernel)

    @staticmethod
    def _sobel_xy(x: torch.Tensor, border_mode: str = "zeros"):
        """Return the horizontal and vertical 3x3 Sobel responses of a 1-channel map."""
        kx = x.new_tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).view(1, 1, 3, 3)
        ky = x.new_tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]).view(1, 1, 3, 3)
        conv = SoftCannyImagePyramidCycleLoss._conv2d_same
        return conv(x, kx, border_mode), conv(x, ky, border_mode)

    def _gaussian_blur(self, x: torch.Tensor) -> torch.Tensor:
        """Blur a 1-channel map with a normalised ``gaussian_kernel``-sized Gaussian."""
        k = self.gaussian_kernel
        # Sigma derived from the kernel size (same formula OpenCV uses when
        # sigma is not given explicitly).
        sigma = 0.3 * ((k - 1) * 0.5 - 1.0) + 0.8
        coords = torch.arange(k, device=x.device, dtype=x.dtype) - (k - 1) / 2
        g = torch.exp(-(coords.square()) / (2 * sigma * sigma))
        g = g / g.sum().clamp_min(1e-8)
        # Outer product of the 1-D profile gives the separable 2-D kernel.
        kernel = (g[:, None] * g[None, :]).view(1, 1, k, k)
        return self._conv2d_same(x, kernel, self.border_mode)

    def _sample_threshold(self, batch: int, device, dtype) -> torch.Tensor:
        """Draw one uniform threshold per sample, shaped ``[batch, 1, 1, 1]``."""
        lo, hi = sorted((self.threshold_min, self.threshold_max))
        t = torch.rand((batch, 1, 1, 1), device=device, dtype=dtype)
        return lo + (hi - lo) * t

    def _soft_canny_01(self, image_m11: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
        """Convert an RGB image into a soft edge map with values in (0, 1)."""
        gray = self._gaussian_blur(self._to_gray_01(image_m11))
        gx, gy = self._sobel_xy(gray, self.border_mode)
        # The epsilon keeps the sqrt gradient finite where the image is flat.
        mag = torch.sqrt(gx.square() + gy.square() + 1e-8)
        # Per-image min-max normalisation of the gradient magnitude.
        flat = mag.flatten(1)
        lo = flat.amin(dim=1, keepdim=True).view(-1, 1, 1, 1)
        hi = flat.amax(dim=1, keepdim=True).view(-1, 1, 1, 1)
        mag = ((mag - lo) / (hi - lo).clamp_min(1e-6)).clamp(0.0, 1.0)
        return torch.sigmoid((mag - threshold) / max(self.temperature, 1e-6))

    @staticmethod
    def _resize_image(x: torch.Tensor, scale: int) -> torch.Tensor:
        """Bilinearly resize ``x`` to a ``scale`` x ``scale`` square."""
        return F.interpolate(x, size=(scale, scale), mode="bilinear", align_corners=False)

    def _cycle_at_scale(
        self,
        gen_image_m11: torch.Tensor,
        gt_image_m11: torch.Tensor,
        scale: int,
        threshold: torch.Tensor,
    ) -> torch.Tensor:
        """Smooth-L1 between generated and GT soft edge maps at one resolution."""
        gen_edge = self._soft_canny_01(self._resize_image(gen_image_m11, scale), threshold)
        # The target never receives gradients, so skip building its graph.
        with torch.no_grad():
            gt_edge = self._soft_canny_01(self._resize_image(gt_image_m11, scale), threshold)
        return F.smooth_l1_loss(gen_edge, gt_edge, beta=self.smooth_l1_beta, reduction="mean")

    def forward(self, gen_image_m11: torch.Tensor, gt_image_m11: torch.Tensor) -> torch.Tensor:
        """Compute the edge consistency loss.

        Args:
            gen_image_m11: Generated RGB batch ``[B, 3, H, W]``.
            gt_image_m11: Ground-truth RGB batch with the same ``B`` and 3
                channels. Both inputs are resized per level, so their spatial
                sizes are not compared.

        Returns:
            A scalar tensor: the single-level loss at ``loss_res`` or the
            weighted pyramid sum.

        Raises:
            ValueError: If the inputs are not ``[B, 3, H, W]`` with equal ``B``.
        """
        b, c, _, _ = gen_image_m11.shape
        if c != 3 or gt_image_m11.shape[:2] != (b, 3):
            raise ValueError("SoftCannyImagePyramidCycleLoss expects generated and GT RGB tensors [B,3,H,W]")
        # One threshold per sample, shared by both images and all levels.
        threshold = self._sample_threshold(b, gen_image_m11.device, gen_image_m11.dtype)
        if not self.enable_pyramid_cycle_loss:
            return self._cycle_at_scale(gen_image_m11, gt_image_m11, self.loss_res, threshold)
        total = gen_image_m11.new_zeros(())
        for scale, weight in zip(self.cycle_scales, self.cycle_scale_weights):
            if weight != 0:
                total = total + weight * self._cycle_at_scale(gen_image_m11, gt_image_m11, scale, threshold)
        return total
# Alias: exposes SoftCannyImagePyramidCycleLoss under the name EdgePyramidCycleLoss.
EdgePyramidCycleLoss = SoftCannyImagePyramidCycleLoss