saliacoel commited on
Commit
8de5a16
·
verified ·
1 Parent(s): ef216e4

Upload salia_get_diff_mask.py

Browse files
Files changed (1) hide show
  1. salia_get_diff_mask.py +1235 -0
salia_get_diff_mask.py ADDED
@@ -0,0 +1,1235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ComfyUI Deterministic Change Mask
3
+ =================================
4
+
5
+ A no-AI custom node that builds a soft change mask from two before/after images.
6
+ It is designed for cases where the after image has global quality loss but also a
7
+ local real edit such as a newly equipped garment.
8
+
9
+ Install:
10
+ ComfyUI/custom_nodes/ComfyUI_DeterministicChangeMask/__init__.py
11
+
12
+ Notes:
13
+ - Standard ComfyUI IMAGE tensors are usually RGB: [B,H,W,3].
14
+ - If a node passes [B,H,W,4], this node extracts channel 4 as alpha.
15
+ - If using ComfyUI LoadImage, connect the IMAGE outputs to before/after_image
16
+ and optionally connect the LoadImage MASK outputs to before/after_alpha.
17
+ LoadImage masks are inverted alpha, so this node converts alpha = 1 - mask.
18
+
19
+ No hard dependency on OpenCV or SciPy. If available, they are used for better
20
+ alignment / distance transform / connected components. Otherwise pure torch/numpy
21
+ fallbacks are used.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import math
27
+ from collections import deque
28
+ from typing import Dict, List, Optional, Sequence, Tuple
29
+
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn.functional as F
33
+
34
+ try:
35
+ import cv2 # type: ignore
36
+ except Exception: # pragma: no cover - optional dependency
37
+ cv2 = None
38
+
39
+ try:
40
+ from scipy import ndimage as ndi # type: ignore
41
+ except Exception: # pragma: no cover - optional dependency
42
+ ndi = None
43
+
44
+
45
+ _EPS = 1.0e-8
46
+
47
+
48
class Salia_Get_Diff_Mask:
    """Deterministic before/after change mask for RGBA-aware workflows.

    The node compares a "before" and an "after" image and returns a soft MASK
    in ``[0, 1]``. Up to four difference maps (color, alpha, structure,
    gradient) are computed, normalized, combined into one score, thresholded
    into a binary core, cleaned up and finally feathered.
    """

    CATEGORY = "mask/deterministic"
    FUNCTION = "make_mask"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI socket/widget specification for this node."""
        return {
            "required": {
                "before_image": ("IMAGE",),
                "after_image": ("IMAGE",),

                # Method switches. Convention: -1 = off, 0 = fast/simple/auto if available,
                # 1 = recommended/default, 2+ = alternate variants.
                "align_mode": ("INT", {"default": -1, "min": -1, "max": 2, "step": 1}),
                "denoise_mode": ("INT", {"default": 1, "min": -1, "max": 3, "step": 1}),
                "color_mode": ("INT", {"default": 1, "min": -1, "max": 3, "step": 1}),
                "alpha_mode": ("INT", {"default": 1, "min": -1, "max": 2, "step": 1}),
                "structure_mode": ("INT", {"default": 1, "min": -1, "max": 2, "step": 1}),
                "gradient_mode": ("INT", {"default": 1, "min": -1, "max": 2, "step": 1}),
                "normalize_mode": ("INT", {"default": 1, "min": -1, "max": 2, "step": 1}),
                "combine_mode": ("INT", {"default": 1, "min": -1, "max": 3, "step": 1}),
                "hysteresis_mode": ("INT", {"default": 1, "min": -1, "max": 2, "step": 1}),
                "morph_mode": ("INT", {"default": 1, "min": -1, "max": 4, "step": 1}),
                "component_mode": ("INT", {"default": 1, "min": -1, "max": 1, "step": 1}),
                "feather_mode": ("INT", {"default": 2, "min": -1, "max": 3, "step": 1}),
                "resize_mode": ("INT", {"default": 1, "min": -1, "max": 1, "step": 1}),

                # Weights. They are normalized internally among enabled terms.
                "color_weight": ("FLOAT", {"default": 0.45, "min": 0.0, "max": 5.0, "step": 0.01}),
                "alpha_weight": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 5.0, "step": 0.01}),
                "structure_weight": ("FLOAT", {"default": 0.20, "min": 0.0, "max": 5.0, "step": 0.01}),
                "gradient_weight": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 5.0, "step": 0.01}),

                # Robust normalization / thresholding.
                "noise_floor_k": ("FLOAT", {"default": 2.5, "min": 0.0, "max": 12.0, "step": 0.05}),
                "mad_scale": ("FLOAT", {"default": 3.0, "min": 0.25, "max": 20.0, "step": 0.05}),
                "low_threshold": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}),
                "high_threshold": ("FLOAT", {"default": 0.58, "min": 0.0, "max": 1.0, "step": 0.01}),

                # Denoising / structural settings.
                "preblur_radius": ("INT", {"default": 1, "min": 0, "max": 12, "step": 1}),
                "preblur_sigma": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 12.0, "step": 0.05}),
                "ssim_window": ("INT", {"default": 11, "min": 3, "max": 31, "step": 2}),
                "ssim_sigma": ("FLOAT", {"default": 1.5, "min": 0.3, "max": 8.0, "step": 0.05}),

                # Geometry / cleanup / feather.
                "max_align_pixels": ("INT", {"default": 24, "min": 0, "max": 256, "step": 1}),
                "valid_alpha_threshold": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.005}),
                "morph_radius": ("INT", {"default": 2, "min": 0, "max": 32, "step": 1}),
                "min_region_area": ("INT", {"default": 64, "min": 0, "max": 1000000, "step": 1}),
                "keep_largest_regions": ("INT", {"default": 0, "min": 0, "max": 128, "step": 1}),
                "feather_radius": ("INT", {"default": 8, "min": 0, "max": 256, "step": 1}),
                "logistic_steepness": ("FLOAT", {"default": 10.0, "min": 0.1, "max": 64.0, "step": 0.1}),
            },
            "optional": {
                "before_alpha": ("MASK",),
                "after_alpha": ("MASK",),
            },
        }

    @staticmethod
    def _ordered_thresholds(low_threshold: float, high_threshold: float) -> Tuple[float, float]:
        """Return ``(low, high)`` as floats with ``low <= high``.

        Both values are computed from the *original* inputs in a single
        expression. Doing it in two sequential assignments would make the
        second one read the already-overwritten low value and collapse a
        swapped pair (e.g. 0.6 / 0.3) to (0.3, 0.3) instead of (0.3, 0.6).
        """
        return (
            float(min(low_threshold, high_threshold)),
            float(max(low_threshold, high_threshold)),
        )

    def make_mask(
        self,
        before_image: torch.Tensor,
        after_image: torch.Tensor,
        align_mode: int,
        denoise_mode: int,
        color_mode: int,
        alpha_mode: int,
        structure_mode: int,
        gradient_mode: int,
        normalize_mode: int,
        combine_mode: int,
        hysteresis_mode: int,
        morph_mode: int,
        component_mode: int,
        feather_mode: int,
        resize_mode: int,
        color_weight: float,
        alpha_weight: float,
        structure_weight: float,
        gradient_weight: float,
        noise_floor_k: float,
        mad_scale: float,
        low_threshold: float,
        high_threshold: float,
        preblur_radius: int,
        preblur_sigma: float,
        ssim_window: int,
        ssim_sigma: float,
        max_align_pixels: int,
        valid_alpha_threshold: float,
        morph_radius: int,
        min_region_area: int,
        keep_largest_regions: int,
        feather_radius: int,
        logistic_steepness: float,
        before_alpha: Optional[torch.Tensor] = None,
        after_alpha: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor]:
        """Node entry point: build one change mask per batch item.

        Parameter names mirror the keys declared in ``INPUT_TYPES``.

        Returns:
            A 1-tuple holding the stacked per-item masks, clamped to [0, 1].

        Raises:
            ValueError: if the two images differ in height/width and
                ``resize_mode < 1``. The helper functions called here may
                raise further errors for malformed inputs.
        """
        before_image = _ensure_image(before_image)
        after_image = _ensure_image(after_image)

        b_rgb, b_a = _split_rgb_alpha(before_image, before_alpha)
        a_rgb, a_a = _split_rgb_alpha(after_image, after_alpha)

        if b_rgb.shape[1:3] != a_rgb.shape[1:3]:
            if resize_mode < 1:
                raise ValueError(
                    "before_image and after_image sizes differ. Set resize_mode=1 to resize after_image to before_image."
                )
            # The before image defines the output resolution.
            target_h, target_w = int(b_rgb.shape[1]), int(b_rgb.shape[2])
            a_rgb = _resize_bhwc(a_rgb, target_h, target_w)
            a_a = _resize_bhw(a_a, target_h, target_w)

        b_rgb, b_a, a_rgb, a_a = _broadcast_batch(b_rgb, b_a, a_rgb, a_a)

        # Tolerate swapped thresholds from the UI without losing either value.
        low_threshold, high_threshold = self._ordered_thresholds(low_threshold, high_threshold)
        ssim_window = _make_odd(int(ssim_window), minimum=3)

        # Scalar settings are identical for every batch item; convert them once.
        settings = dict(
            align_mode=int(align_mode),
            denoise_mode=int(denoise_mode),
            color_mode=int(color_mode),
            alpha_mode=int(alpha_mode),
            structure_mode=int(structure_mode),
            gradient_mode=int(gradient_mode),
            normalize_mode=int(normalize_mode),
            combine_mode=int(combine_mode),
            hysteresis_mode=int(hysteresis_mode),
            morph_mode=int(morph_mode),
            component_mode=int(component_mode),
            color_weight=float(color_weight),
            alpha_weight=float(alpha_weight),
            structure_weight=float(structure_weight),
            gradient_weight=float(gradient_weight),
            noise_floor_k=float(noise_floor_k),
            mad_scale=float(mad_scale),
            low_threshold=float(low_threshold),
            high_threshold=float(high_threshold),
            preblur_radius=int(preblur_radius),
            preblur_sigma=float(preblur_sigma),
            ssim_window=int(ssim_window),
            ssim_sigma=float(ssim_sigma),
            max_align_pixels=int(max_align_pixels),
            valid_alpha_threshold=float(valid_alpha_threshold),
            morph_radius=int(morph_radius),
            min_region_area=int(min_region_area),
            keep_largest_regions=int(keep_largest_regions),
            feather_mode=int(feather_mode),
            feather_radius=int(feather_radius),
            logistic_steepness=float(logistic_steepness),
        )

        out_masks: List[torch.Tensor] = []
        batch = int(b_rgb.shape[0])
        for i in range(batch):
            out_masks.append(
                self._process_one(
                    before_rgb=b_rgb[i],
                    before_alpha=b_a[i],
                    after_rgb=a_rgb[i],
                    after_alpha=a_a[i],
                    **settings,
                )
            )

        return (torch.stack(out_masks, dim=0).clamp(0.0, 1.0),)

    def _process_one(
        self,
        before_rgb: torch.Tensor,  # [H,W,3]
        before_alpha: torch.Tensor,  # [H,W]
        after_rgb: torch.Tensor,  # [H,W,3]
        after_alpha: torch.Tensor,  # [H,W]
        align_mode: int,
        denoise_mode: int,
        color_mode: int,
        alpha_mode: int,
        structure_mode: int,
        gradient_mode: int,
        normalize_mode: int,
        combine_mode: int,
        hysteresis_mode: int,
        morph_mode: int,
        component_mode: int,
        color_weight: float,
        alpha_weight: float,
        structure_weight: float,
        gradient_weight: float,
        noise_floor_k: float,
        mad_scale: float,
        low_threshold: float,
        high_threshold: float,
        preblur_radius: int,
        preblur_sigma: float,
        ssim_window: int,
        ssim_sigma: float,
        max_align_pixels: int,
        valid_alpha_threshold: float,
        morph_radius: int,
        min_region_area: int,
        keep_largest_regions: int,
        feather_mode: int,
        feather_radius: int,
        logistic_steepness: float,
    ) -> torch.Tensor:
        """Compute the change mask for a single before/after pair.

        Returns an ``[H, W]`` tensor on the device and in the dtype of
        ``before_rgb``, clamped to [0, 1]. If every difference term is
        disabled (mode switched off or weight 0) an all-zero mask is returned.
        """
        device = before_rgb.device
        dtype = before_rgb.dtype
        h, w = int(before_rgb.shape[0]), int(before_rgb.shape[1])

        if align_mode >= 1:
            after_rgb, after_alpha = _align_after_to_before(
                before_rgb,
                before_alpha,
                after_rgb,
                after_alpha,
                align_mode=align_mode,
                max_align_pixels=max_align_pixels,
            )

        # Work in channel-first tensors for filtering and SSIM.
        b_rgb_cf = before_rgb.permute(2, 0, 1).unsqueeze(0)
        a_rgb_cf = after_rgb.permute(2, 0, 1).unsqueeze(0)
        b_a_cf = before_alpha.unsqueeze(0).unsqueeze(0)
        a_a_cf = after_alpha.unsqueeze(0).unsqueeze(0)

        if denoise_mode >= 0 and preblur_radius > 0:
            if denoise_mode == 2:
                b_rgb_cf = _median_blur_bchw(b_rgb_cf, preblur_radius)
                a_rgb_cf = _median_blur_bchw(a_rgb_cf, preblur_radius)
                b_a_cf = _median_blur_bchw(b_a_cf, preblur_radius)
                a_a_cf = _median_blur_bchw(a_a_cf, preblur_radius)
            elif denoise_mode == 3 and cv2 is not None:
                # Bilateral filtering is applied to color only; alpha gets a plain Gaussian.
                b_rgb_cf = _bilateral_or_gaussian_bchw(b_rgb_cf, preblur_radius, preblur_sigma)
                a_rgb_cf = _bilateral_or_gaussian_bchw(a_rgb_cf, preblur_radius, preblur_sigma)
                b_a_cf = _gaussian_blur_bchw(b_a_cf, preblur_radius, preblur_sigma)
                a_a_cf = _gaussian_blur_bchw(a_a_cf, preblur_radius, preblur_sigma)
            else:
                # Modes 0/1, and mode 3 without OpenCV, use a Gaussian blur.
                b_rgb_cf = _gaussian_blur_bchw(b_rgb_cf, preblur_radius, preblur_sigma)
                a_rgb_cf = _gaussian_blur_bchw(a_rgb_cf, preblur_radius, preblur_sigma)
                b_a_cf = _gaussian_blur_bchw(b_a_cf, preblur_radius, preblur_sigma)
                a_a_cf = _gaussian_blur_bchw(a_a_cf, preblur_radius, preblur_sigma)

        b_rgb = b_rgb_cf.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0)
        a_rgb = a_rgb_cf.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0)
        b_alpha = b_a_cf.squeeze(0).squeeze(0).clamp(0.0, 1.0)
        a_alpha = a_a_cf.squeeze(0).squeeze(0).clamp(0.0, 1.0)

        # A pixel is "valid" when it is visible in at least one of the two images.
        valid = torch.maximum(b_alpha, a_alpha) > float(valid_alpha_threshold)
        if valid.sum().item() < 16:
            # Almost nothing is visible: fall back to treating the whole frame as valid.
            valid = torch.ones((h, w), dtype=torch.bool, device=device)

        # Each entry: (normalized map, weight, label, fixed normalization scale).
        maps: List[Tuple[torch.Tensor, float, str, float]] = []

        if color_mode >= 0 and color_weight > 0.0:
            d_color = _color_difference_map(b_rgb, a_rgb, b_alpha, a_alpha, mode=color_mode)
            n_color = _normalize_map(
                d_color,
                valid=valid,
                mode=normalize_mode,
                noise_floor_k=noise_floor_k,
                mad_scale=mad_scale,
                fixed_scale=20.0,
            )
            maps.append((n_color, color_weight, "color", 20.0))

        if alpha_mode >= 1 and alpha_weight > 0.0:
            d_alpha = (a_alpha - b_alpha).abs()
            if alpha_mode >= 2:
                # Also take the premultiplied-color distance into account.
                premul_b = b_rgb * b_alpha.unsqueeze(-1)
                premul_a = a_rgb * a_alpha.unsqueeze(-1)
                d_premul = torch.linalg.vector_norm(premul_a - premul_b, dim=-1) / math.sqrt(3.0)
                d_alpha = torch.maximum(d_alpha, d_premul)
            n_alpha = _normalize_map(
                d_alpha,
                valid=valid,
                mode=normalize_mode,
                # The alpha term uses a slightly lower noise floor than the others.
                noise_floor_k=max(0.0, noise_floor_k - 0.5),
                mad_scale=mad_scale,
                fixed_scale=1.0,
            )
            maps.append((n_alpha, alpha_weight, "alpha", 1.0))

        if structure_mode >= 1 and structure_weight > 0.0:
            d_struct = _structure_difference_map(
                b_rgb,
                a_rgb,
                mode=structure_mode,
                window=ssim_window,
                sigma=ssim_sigma,
            )
            d_struct = d_struct * valid.to(dtype=d_struct.dtype)
            n_struct = _normalize_map(
                d_struct,
                valid=valid,
                mode=normalize_mode,
                noise_floor_k=noise_floor_k,
                mad_scale=mad_scale,
                fixed_scale=0.5,
            )
            maps.append((n_struct, structure_weight, "structure", 0.5))

        if gradient_mode >= 1 and gradient_weight > 0.0:
            d_grad = _gradient_difference_map(b_rgb, a_rgb, b_alpha, a_alpha, mode=gradient_mode)
            d_grad = d_grad * valid.to(dtype=d_grad.dtype)
            n_grad = _normalize_map(
                d_grad,
                valid=valid,
                mode=normalize_mode,
                noise_floor_k=noise_floor_k,
                mad_scale=mad_scale,
                fixed_scale=0.25,
            )
            maps.append((n_grad, gradient_weight, "gradient", 0.25))

        if not maps:
            return torch.zeros((h, w), dtype=dtype, device=device)

        score = _combine_maps(maps, mode=combine_mode).clamp(0.0, 1.0)
        score = score * valid.to(dtype=score.dtype)

        # The remaining cleanup steps operate on a NumPy copy of the score on the CPU.
        core_np = _make_core_mask(
            score.detach().float().cpu().numpy(),
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            hysteresis_mode=hysteresis_mode,
        )

        if morph_mode >= 1 and morph_radius > 0:
            core_np = _morph_binary(core_np, mode=morph_mode, radius=morph_radius)

        if component_mode >= 1:
            core_np = _filter_components(
                core_np,
                min_area=max(0, int(min_region_area)),
                keep_largest=max(0, int(keep_largest_regions)),
            )

        mask_np = _feather_core(
            core_np,
            mode=feather_mode,
            radius=max(0, int(feather_radius)),
            logistic_steepness=max(0.1, float(logistic_steepness)),
        )

        mask = torch.from_numpy(mask_np).to(device=device, dtype=dtype)
        return mask.clamp(0.0, 1.0)
395
+
396
+
397
+ # -----------------------------------------------------------------------------
398
+ # Tensor preparation
399
+ # -----------------------------------------------------------------------------
400
+
401
+
402
+ def _ensure_image(image: torch.Tensor) -> torch.Tensor:
403
+ if not isinstance(image, torch.Tensor):
404
+ raise TypeError("Expected ComfyUI IMAGE as torch.Tensor.")
405
+ if image.dim() == 3:
406
+ image = image.unsqueeze(0)
407
+ if image.dim() != 4:
408
+ raise ValueError(f"Expected IMAGE shape [B,H,W,C], got {tuple(image.shape)}")
409
+ return image.float().clamp(0.0, 1.0)
410
+
411
+
412
+ def _split_rgb_alpha(image: torch.Tensor, optional_alpha: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
413
+ b, h, w, c = image.shape
414
+ if c >= 4:
415
+ rgb = image[..., :3]
416
+ alpha = image[..., 3].float().clamp(0.0, 1.0)
417
+ elif c == 3:
418
+ rgb = image
419
+ alpha = torch.ones((b, h, w), dtype=image.dtype, device=image.device)
420
+ elif c == 1:
421
+ rgb = image.repeat(1, 1, 1, 3)
422
+ alpha = torch.ones((b, h, w), dtype=image.dtype, device=image.device)
423
+ else:
424
+ # Unusual, but keep the node from crashing on odd custom tensors.
425
+ first = image[..., :1]
426
+ rgb = first.repeat(1, 1, 1, 3)
427
+ alpha = torch.ones((b, h, w), dtype=image.dtype, device=image.device)
428
+
429
+ # For normal ComfyUI LoadImage workflows, alpha arrives as MASK and is inverted:
430
+ # mask = 1 - opacity. Optional mask is only used when IMAGE has no explicit alpha.
431
+ if optional_alpha is not None and c < 4:
432
+ alpha = 1.0 - _ensure_mask(optional_alpha, b, h, w, image.device, image.dtype)
433
+
434
+ return rgb.float().clamp(0.0, 1.0), alpha.float().clamp(0.0, 1.0)
435
+
436
+
437
+ def _ensure_mask(
438
+ mask: torch.Tensor,
439
+ batch: int,
440
+ height: int,
441
+ width: int,
442
+ device: torch.device,
443
+ dtype: torch.dtype,
444
+ ) -> torch.Tensor:
445
+ if not isinstance(mask, torch.Tensor):
446
+ raise TypeError("Expected ComfyUI MASK as torch.Tensor.")
447
+ mask = mask.float()
448
+
449
+ if mask.dim() == 2:
450
+ mask = mask.unsqueeze(0)
451
+ elif mask.dim() == 3:
452
+ # Could be [B,H,W] or [H,W,1].
453
+ if mask.shape[-1] == 1 and mask.shape[0] == height and mask.shape[1] == width:
454
+ mask = mask[..., 0].unsqueeze(0)
455
+ elif mask.dim() == 4:
456
+ if mask.shape[-1] == 1:
457
+ mask = mask[..., 0]
458
+ elif mask.shape[1] == 1:
459
+ mask = mask[:, 0, :, :]
460
+ else:
461
+ mask = mask[:, 0, :, :]
462
+ else:
463
+ raise ValueError(f"Expected MASK shape [H,W], [B,H,W], [B,H,W,1], or [B,1,H,W], got {tuple(mask.shape)}")
464
+
465
+ if mask.dim() != 3:
466
+ raise ValueError(f"Could not normalize MASK shape, got {tuple(mask.shape)}")
467
+
468
+ mask = mask.to(device=device, dtype=dtype).clamp(0.0, 1.0)
469
+
470
+ if mask.shape[1] != height or mask.shape[2] != width:
471
+ mask = _resize_bhw(mask, height, width)
472
+
473
+ if mask.shape[0] == batch:
474
+ return mask
475
+ if mask.shape[0] == 1:
476
+ return mask.repeat(batch, 1, 1)
477
+ if batch == 1:
478
+ return mask[:1]
479
+ raise ValueError(f"MASK batch {mask.shape[0]} does not match IMAGE batch {batch}.")
480
+
481
+
482
+ def _broadcast_batch(
483
+ b_rgb: torch.Tensor,
484
+ b_a: torch.Tensor,
485
+ a_rgb: torch.Tensor,
486
+ a_a: torch.Tensor,
487
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
488
+ b_batch = int(b_rgb.shape[0])
489
+ a_batch = int(a_rgb.shape[0])
490
+ target = max(b_batch, a_batch)
491
+
492
+ def rep_img(x: torch.Tensor) -> torch.Tensor:
493
+ if x.shape[0] == target:
494
+ return x
495
+ if x.shape[0] == 1:
496
+ return x.repeat(target, 1, 1, 1)
497
+ raise ValueError(f"Incompatible IMAGE batches: {b_batch} and {a_batch}")
498
+
499
+ def rep_mask(x: torch.Tensor) -> torch.Tensor:
500
+ if x.shape[0] == target:
501
+ return x
502
+ if x.shape[0] == 1:
503
+ return x.repeat(target, 1, 1)
504
+ raise ValueError(f"Incompatible MASK batches: {b_batch} and {a_batch}")
505
+
506
+ return rep_img(b_rgb), rep_mask(b_a), rep_img(a_rgb), rep_mask(a_a)
507
+
508
+
509
+ def _resize_bhwc(image: torch.Tensor, height: int, width: int) -> torch.Tensor:
510
+ x = image.permute(0, 3, 1, 2)
511
+ x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=False)
512
+ return x.permute(0, 2, 3, 1).clamp(0.0, 1.0)
513
+
514
+
515
+ def _resize_bhw(mask: torch.Tensor, height: int, width: int) -> torch.Tensor:
516
+ x = mask.unsqueeze(1)
517
+ x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=False)
518
+ return x[:, 0, :, :].clamp(0.0, 1.0)
519
+
520
+
521
+ def _make_odd(value: int, minimum: int = 3) -> int:
522
+ value = max(int(value), minimum)
523
+ if value % 2 == 0:
524
+ value += 1
525
+ return value
526
+
527
+
528
+ # -----------------------------------------------------------------------------
529
+ # Alignment
530
+ # -----------------------------------------------------------------------------
531
+
532
+
533
def _align_after_to_before(
    before_rgb: torch.Tensor,
    before_alpha: torch.Tensor,
    after_rgb: torch.Tensor,
    after_alpha: torch.Tensor,
    align_mode: int,
    max_align_pixels: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Best-effort alignment of the after image onto the before image.

    With ``align_mode >= 2`` and OpenCV available, ECC alignment is tried
    first. If it is unavailable or fails, a translation estimated by phase
    correlation is applied, provided both components are within
    ``max_align_pixels``. Any failure leaves the after image untouched; this
    function deliberately never raises because of an alignment problem.

    Returns:
        ``(after_rgb, after_alpha)``, aligned where possible.
    """
    use_ecc = align_mode >= 2 and cv2 is not None
    if use_ecc:
        try:
            return _align_ecc_cv2(before_rgb, before_alpha, after_rgb, after_alpha, max_align_pixels)
        except Exception:
            # ECC is allowed to fail. Fall back to phase correlation.
            pass

    try:
        dy, dx = _phase_correlation_shift(before_rgb, before_alpha, after_rgb, after_alpha)
        within_limit = abs(dy) <= max_align_pixels and abs(dx) <= max_align_pixels
        if within_limit:
            after_rgb = _translate_hwc(after_rgb, dy, dx)
            after_alpha = _translate_hw(after_alpha, dy, dx)
    except Exception:
        # Alignment is optional; keep whatever we have so far.
        pass

    return after_rgb, after_alpha
556
+
557
+
558
def _phase_correlation_shift(
    before_rgb: torch.Tensor,
    before_alpha: torch.Tensor,
    after_rgb: torch.Tensor,
    after_alpha: torch.Tensor,
) -> Tuple[float, float]:
    """Estimate the integer translation of after relative to before.

    Phase correlation is run on the alpha-weighted, mean-removed luma of both
    images. To align after to before, sample after at ``target + shift``.

    Returns:
        ``(shift_y, shift_x)`` as floats holding whole-pixel values.
    """
    reference = _luma(before_rgb) * before_alpha
    moving = _luma(after_rgb) * after_alpha
    reference = reference - reference.mean()
    moving = moving - moving.mean()

    spec_moving = torch.fft.fft2(moving.float())
    spec_reference = torch.fft.fft2(reference.float())
    cross_power = spec_moving * torch.conj(spec_reference)
    # Keep only the phase so the inverse transform peaks at the displacement.
    cross_power = cross_power / (torch.abs(cross_power) + _EPS)
    corr = torch.fft.ifft2(cross_power).real

    h, w = corr.shape
    py, px = divmod(int(torch.argmax(corr).item()), w)
    # Peaks past the midpoint are wrapped-around negative shifts.
    if py > h // 2:
        py -= h
    if px > w // 2:
        px -= w
    return float(py), float(px)
585
+
586
+
587
+ def _translate_hwc(image: torch.Tensor, shift_y: float, shift_x: float) -> torch.Tensor:
588
+ h, w, c = image.shape
589
+ x = image.permute(2, 0, 1).unsqueeze(0)
590
+ y_coords, x_coords = torch.meshgrid(
591
+ torch.arange(h, device=image.device, dtype=image.dtype),
592
+ torch.arange(w, device=image.device, dtype=image.dtype),
593
+ indexing="ij",
594
+ )
595
+ sample_x = x_coords + float(shift_x)
596
+ sample_y = y_coords + float(shift_y)
597
+ if w > 1:
598
+ sample_x = sample_x / (w - 1) * 2.0 - 1.0
599
+ else:
600
+ sample_x = torch.zeros_like(sample_x)
601
+ if h > 1:
602
+ sample_y = sample_y / (h - 1) * 2.0 - 1.0
603
+ else:
604
+ sample_y = torch.zeros_like(sample_y)
605
+ grid = torch.stack([sample_x, sample_y], dim=-1).unsqueeze(0)
606
+ y = F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
607
+ return y.squeeze(0).permute(1, 2, 0).clamp(0.0, 1.0)
608
+
609
+
610
def _translate_hw(mask: torch.Tensor, shift_y: float, shift_x: float) -> torch.Tensor:
    """Translate an [H,W] mask by treating it as a one-channel image."""
    as_image = mask[..., None]
    shifted = _translate_hwc(as_image, shift_y, shift_x)
    return shifted[..., 0]
612
+
613
+
614
def _align_ecc_cv2(
    before_rgb: torch.Tensor,
    before_alpha: torch.Tensor,
    after_rgb: torch.Tensor,
    after_alpha: torch.Tensor,
    max_align_pixels: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Align after to before with an affine ECC fit from OpenCV.

    The fit is computed on alpha-weighted luma and then applied to both the
    RGB and the alpha of the after image.

    Returns:
        ``(aligned_rgb, aligned_alpha)`` on the device/dtype of ``after_rgb``,
        clamped to [0, 1].

    Raises:
        RuntimeError: if OpenCV is unavailable, or if the fitted translation
            exceeds ``max_align_pixels`` on either axis. OpenCV itself may
            raise as well (e.g. when ECC does not converge).
    """
    if cv2 is None:
        raise RuntimeError("OpenCV not available")

    out_device = after_rgb.device
    out_dtype = after_rgb.dtype
    h, w = before_alpha.shape

    def as_float32_array(t: torch.Tensor) -> np.ndarray:
        return t.detach().float().cpu().numpy().astype(np.float32)

    template = as_float32_array(_luma(before_rgb) * before_alpha)
    moving = as_float32_array(_luma(after_rgb) * after_alpha)

    initial_warp = np.eye(2, 3, dtype=np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 50, 1.0e-5)
    _, warp = cv2.findTransformECC(template, moving, initial_warp, cv2.MOTION_AFFINE, criteria, None, 5)

    # Reject extreme translations because a garment change can confuse ECC.
    shift_x = abs(float(warp[0, 2]))
    shift_y = abs(float(warp[1, 2]))
    if shift_x > max_align_pixels or shift_y > max_align_pixels:
        raise RuntimeError("ECC alignment rejected due to large transform")

    def warp_to_template(arr: np.ndarray) -> np.ndarray:
        return cv2.warpAffine(
            arr,
            warp,
            (w, h),
            flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=0,
        )

    def as_tensor(arr: np.ndarray) -> torch.Tensor:
        return torch.from_numpy(arr).to(device=out_device, dtype=out_dtype).clamp(0.0, 1.0)

    aligned_rgb = warp_to_template(as_float32_array(after_rgb))
    aligned_a = warp_to_template(as_float32_array(after_alpha))
    return as_tensor(aligned_rgb), as_tensor(aligned_a)
660
+
661
+
662
+ # -----------------------------------------------------------------------------
663
+ # Filtering
664
+ # -----------------------------------------------------------------------------
665
+
666
+
667
+ def _gaussian_kernel1d(radius: int, sigma: float, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
668
+ radius = max(0, int(radius))
669
+ if radius == 0:
670
+ return torch.ones(1, device=device, dtype=dtype)
671
+ x = torch.arange(-radius, radius + 1, device=device, dtype=dtype)
672
+ sigma = max(float(sigma), 1.0e-3)
673
+ k = torch.exp(-(x * x) / (2.0 * sigma * sigma))
674
+ return k / (k.sum() + _EPS)
675
+
676
+
677
+ def _pad_mode_for(x: torch.Tensor, radius: int) -> str:
678
+ if radius <= 0:
679
+ return "constant"
680
+ h, w = int(x.shape[-2]), int(x.shape[-1])
681
+ return "reflect" if h > radius and w > radius else "replicate"
682
+
683
+
684
+ def _gaussian_blur_bchw(x: torch.Tensor, radius: int, sigma: float) -> torch.Tensor:
685
+ radius = int(radius)
686
+ if radius <= 0:
687
+ return x
688
+ b, c, h, w = x.shape
689
+ k = _gaussian_kernel1d(radius, sigma, x.device, x.dtype)
690
+ kx = k.view(1, 1, 1, -1).repeat(c, 1, 1, 1)
691
+ ky = k.view(1, 1, -1, 1).repeat(c, 1, 1, 1)
692
+ mode = _pad_mode_for(x, radius)
693
+ y = F.pad(x, (radius, radius, 0, 0), mode=mode)
694
+ y = F.conv2d(y, kx, groups=c)
695
+ y = F.pad(y, (0, 0, radius, radius), mode=mode)
696
+ y = F.conv2d(y, ky, groups=c)
697
+ return y.clamp(0.0, 1.0)
698
+
699
+
700
+ def _median_blur_bchw(x: torch.Tensor, radius: int) -> torch.Tensor:
701
+ radius = int(radius)
702
+ if radius <= 0:
703
+ return x
704
+ k = 2 * radius + 1
705
+ b, c, h, w = x.shape
706
+ mode = _pad_mode_for(x, radius)
707
+ xp = F.pad(x, (radius, radius, radius, radius), mode=mode)
708
+ patches = F.unfold(xp, kernel_size=k) # [B, C*k*k, H*W]
709
+ patches = patches.view(b, c, k * k, h, w)
710
+ return patches.median(dim=2).values.clamp(0.0, 1.0)
711
+
712
+
713
def _bilateral_or_gaussian_bchw(x: torch.Tensor, radius: int, sigma: float) -> torch.Tensor:
    """Edge-preserving blur via OpenCV's bilateral filter.

    Falls back to ``_gaussian_blur_bchw`` when OpenCV is not available. The
    input is [B,C,H,W]; each batch item is filtered on the CPU and the result
    is returned on the original device/dtype, clamped to [0, 1].
    """
    if cv2 is None:
        return _gaussian_blur_bchw(x, radius, sigma)

    _, channels, _, _ = x.shape
    out_device, out_dtype = x.device, x.dtype
    diameter = max(3, 2 * int(radius) + 1)
    sigma_space = max(1.0, float(sigma))

    filtered_items = []
    for item in x:
        hwc = item.detach().float().cpu().permute(1, 2, 0).numpy().astype(np.float32)
        if channels == 1:
            # Filter as a plain 2-D array, then restore the channel axis.
            plane = cv2.bilateralFilter(hwc[..., 0], diameter, sigmaColor=0.08, sigmaSpace=sigma_space)
            smoothed = plane[..., None]
        else:
            smoothed = cv2.bilateralFilter(hwc, diameter, sigmaColor=0.08, sigmaSpace=sigma_space)
        filtered_items.append(torch.from_numpy(smoothed).permute(2, 0, 1))

    stacked = torch.stack(filtered_items, dim=0)
    return stacked.to(device=out_device, dtype=out_dtype).clamp(0.0, 1.0)
729
+
730
+
731
+ # -----------------------------------------------------------------------------
732
+ # Difference maps
733
+ # -----------------------------------------------------------------------------
734
+
735
+
736
+ def _luma(rgb: torch.Tensor) -> torch.Tensor:
737
+ return 0.2126 * rgb[..., 0] + 0.7152 * rgb[..., 1] + 0.0722 * rgb[..., 2]
738
+
739
+
740
+ def _color_difference_map(
741
+ before_rgb: torch.Tensor,
742
+ after_rgb: torch.Tensor,
743
+ before_alpha: torch.Tensor,
744
+ after_alpha: torch.Tensor,
745
+ mode: int,
746
+ ) -> torch.Tensor:
747
+ overlap = torch.sqrt((before_alpha * after_alpha).clamp(0.0, 1.0))
748
+ if mode == 0:
749
+ # Fast RGB L2 on premultiplied color, scaled roughly like Delta-E.
750
+ premul_b = before_rgb * before_alpha.unsqueeze(-1)
751
+ premul_a = after_rgb * after_alpha.unsqueeze(-1)
752
+ return torch.linalg.vector_norm(premul_a - premul_b, dim=-1) * (100.0 / math.sqrt(3.0))
753
+
754
+ lab_b = _rgb_to_lab(before_rgb)
755
+ lab_a = _rgb_to_lab(after_rgb)
756
+ if mode == 2:
757
+ d = torch.linalg.vector_norm(lab_a - lab_b, dim=-1)
758
+ else:
759
+ d = _delta_e_ciede2000(lab_b, lab_a)
760
+
761
+ d = d * overlap
762
+
763
+ if mode >= 3:
764
+ # Hybrid: add a premultiplied color guard for transparent/antialiased boundaries.
765
+ premul_b = before_rgb * before_alpha.unsqueeze(-1)
766
+ premul_a = after_rgb * after_alpha.unsqueeze(-1)
767
+ premul = torch.linalg.vector_norm(premul_a - premul_b, dim=-1) * (100.0 / math.sqrt(3.0))
768
+ d = torch.maximum(d, premul)
769
+ return d.clamp_min(0.0)
770
+
771
+
772
def _structure_difference_map(
    before_rgb: torch.Tensor,
    after_rgb: torch.Tensor,
    mode: int,
    window: int,
    sigma: float,
) -> torch.Tensor:
    """SSIM-based structural difference map, clamped to [0, 1].

    Mode 2 averages the per-channel SSIM differences of R, G and B; every
    other mode computes a single SSIM difference on luma.
    """
    if mode != 2:
        luma_diff = _ssim_difference(_luma(before_rgb), _luma(after_rgb), window, sigma)
        return luma_diff.clamp(0.0, 1.0)

    per_channel = [
        _ssim_difference(before_rgb[..., channel], after_rgb[..., channel], window, sigma)
        for channel in range(3)
    ]
    return torch.stack(per_channel, dim=0).mean(dim=0).clamp(0.0, 1.0)
785
+
786
+
787
def _ssim_difference(x: torch.Tensor, y: torch.Tensor, window: int, sigma: float) -> torch.Tensor:
    """Return ``(1 - SSIM) / 2`` per pixel for two single-channel images.

    Local means, variances and covariance are estimated with a separable
    Gaussian window. ``x`` and ``y`` are given two leading singleton axes
    before filtering, so they are presumably 2-D ``(H, W)`` planes -- confirm
    against callers. The result is clamped to ``[0, 1]``.
    """
    half = _make_odd(window, minimum=3) // 2
    plane_x = x[None, None]
    plane_y = y[None, None]

    taps = _gaussian_kernel1d(half, sigma, x.device, x.dtype)
    row_kernel = taps.view(1, 1, 1, -1)
    col_kernel = taps.view(1, 1, -1, 1)
    pad_mode = _pad_mode_for(plane_x, half)

    def smooth(plane: torch.Tensor) -> torch.Tensor:
        # Horizontal pass followed by vertical pass; padding keeps the size.
        plane = F.conv2d(F.pad(plane, (half, half, 0, 0), mode=pad_mode), row_kernel)
        return F.conv2d(F.pad(plane, (0, 0, half, half), mode=pad_mode), col_kernel)

    mean_x = smooth(plane_x)
    mean_y = smooth(plane_y)
    mean_xx = mean_x * mean_x
    mean_yy = mean_y * mean_y
    mean_xy = mean_x * mean_y

    var_x = smooth(plane_x * plane_x) - mean_xx
    var_y = smooth(plane_y * plane_y) - mean_yy
    cov_xy = smooth(plane_x * plane_y) - mean_xy

    # Stabilising constants of the SSIM formula.
    c1 = 0.01 ** 2
    c2 = 0.03 ** 2
    numerator = (2.0 * mean_xy + c1) * (2.0 * cov_xy + c2)
    denominator = (mean_xx + mean_yy + c1) * (var_x + var_y + c2) + _EPS
    ssim = numerator / denominator

    # SSIM in [-1, 1] maps to a difference in [0, 1]; negative SSIM becomes a
    # strong difference.
    bounded = ssim.squeeze(0).squeeze(0).clamp(-1.0, 1.0)
    return ((1.0 - bounded) * 0.5).clamp(0.0, 1.0)
818
+
819
+
820
def _gradient_difference_map(
    before_rgb: torch.Tensor,
    after_rgb: torch.Tensor,
    before_alpha: torch.Tensor,
    after_alpha: torch.Tensor,
    mode: int,
) -> torch.Tensor:
    """Return a non-negative per-pixel difference of Sobel gradients.

    Gradients are taken on alpha-weighted luma. With ``mode == 2`` the result
    is the magnitude of the gradient *vector* difference; otherwise it is the
    absolute difference of the two gradient magnitudes.
    """
    luma_before = (_luma(before_rgb) * before_alpha)[None, None]
    luma_after = (_luma(after_rgb) * after_alpha)[None, None]

    gx_before, gy_before = _sobel_xy(luma_before)
    gx_after, gy_after = _sobel_xy(luma_after)

    if mode == 2:
        delta_x = gx_after - gx_before
        delta_y = gy_after - gy_before
        difference = torch.sqrt(delta_x.square() + delta_y.square() + _EPS)
    else:
        strength_before = torch.sqrt(gx_before.square() + gy_before.square() + _EPS)
        strength_after = torch.sqrt(gx_after.square() + gy_after.square() + _EPS)
        difference = (strength_after - strength_before).abs()

    return difference.squeeze(0).squeeze(0).clamp_min(0.0)
838
+
839
+
840
+ def _sobel_xy(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
841
+ kx = torch.tensor(
842
+ [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]],
843
+ dtype=x.dtype,
844
+ device=x.device,
845
+ ).view(1, 1, 3, 3) / 8.0
846
+ ky = torch.tensor(
847
+ [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]],
848
+ dtype=x.dtype,
849
+ device=x.device,
850
+ ).view(1, 1, 3, 3) / 8.0
851
+ xp = F.pad(x, (1, 1, 1, 1), mode="reflect" if x.shape[-1] > 1 and x.shape[-2] > 1 else "replicate")
852
+ return F.conv2d(xp, kx), F.conv2d(xp, ky)
853
+
854
+
855
+ # -----------------------------------------------------------------------------
856
+ # Color science: sRGB -> Lab and CIEDE2000
857
+ # -----------------------------------------------------------------------------
858
+
859
+
860
+ def _srgb_to_linear(rgb: torch.Tensor) -> torch.Tensor:
861
+ return torch.where(rgb <= 0.04045, rgb / 12.92, torch.pow((rgb + 0.055) / 1.055, 2.4))
862
+
863
+
864
def _rgb_to_lab(rgb: torch.Tensor) -> torch.Tensor:
    """Convert sRGB values (channels on the last axis) to CIE L*a*b*.

    The input is clamped to ``[0, 1]``, linearised, projected to XYZ,
    normalised by the D65 white point and passed through the Lab companding
    function. The output stacks ``L``, ``a`` and ``b`` on the last axis.
    """
    linear = _srgb_to_linear(rgb.clamp(0.0, 1.0))
    red, green, blue = linear[..., 0], linear[..., 1], linear[..., 2]

    # Linear sRGB -> XYZ.
    x = 0.4124564 * red + 0.3575761 * green + 0.1804375 * blue
    y = 0.2126729 * red + 0.7151522 * green + 0.0721750 * blue
    z = 0.0193339 * red + 0.1191920 * green + 0.9503041 * blue

    # Normalise by the D65 reference white.
    x = x / 0.95047
    y = y / 1.00000
    z = z / 1.08883

    delta = 6.0 / 29.0

    def compand(t: torch.Tensor) -> torch.Tensor:
        # Cube root above the threshold, linear ramp below it.
        cube_root = torch.pow(t.clamp_min(0.0), 1.0 / 3.0)
        ramp = t / (3.0 * delta * delta) + 4.0 / 29.0
        return torch.where(t > delta ** 3, cube_root, ramp)

    fx = compand(x)
    fy = compand(y)
    fz = compand(z)

    lightness = 116.0 * fy - 16.0
    green_red = 500.0 * (fx - fy)
    blue_yellow = 200.0 * (fy - fz)
    return torch.stack([lightness, green_red, blue_yellow], dim=-1)
888
+
889
+
890
def _delta_e_ciede2000(lab1: torch.Tensor, lab2: torch.Tensor) -> torch.Tensor:
    """Return the element-wise CIEDE2000 color difference of two Lab tensors.

    Both inputs carry ``L``, ``a``, ``b`` in entries 0..2 of the last axis; the
    result drops that axis. ``_EPS`` is added inside square roots and
    denominators for numerical stability, so values deviate marginally from
    the textbook formula.
    """
    light_1, a_1, b_1 = lab1[..., 0], lab1[..., 1], lab1[..., 2]
    light_2, a_2, b_2 = lab2[..., 0], lab2[..., 1], lab2[..., 2]

    # Chroma-dependent rescaling of the a axis.
    chroma_1 = torch.sqrt(a_1 * a_1 + b_1 * b_1 + _EPS)
    chroma_2 = torch.sqrt(a_2 * a_2 + b_2 * b_2 + _EPS)
    chroma_mean = (chroma_1 + chroma_2) * 0.5
    chroma_mean_7 = chroma_mean.pow(7.0)
    g_factor = 0.5 * (1.0 - torch.sqrt(chroma_mean_7 / (chroma_mean_7 + 25.0 ** 7 + _EPS)))

    a_1_adj = (1.0 + g_factor) * a_1
    a_2_adj = (1.0 + g_factor) * a_2
    chroma_1_adj = torch.sqrt(a_1_adj * a_1_adj + b_1 * b_1 + _EPS)
    chroma_2_adj = torch.sqrt(a_2_adj * a_2_adj + b_2 * b_2 + _EPS)

    # Hue angles in degrees, wrapped to [0, 360).
    hue_1 = torch.rad2deg(torch.atan2(b_1, a_1_adj)) % 360.0
    hue_2 = torch.rad2deg(torch.atan2(b_2, a_2_adj)) % 360.0

    delta_light = light_2 - light_1
    delta_chroma = chroma_2_adj - chroma_1_adj

    # Hue difference: zero when either color is achromatic, otherwise wrapped
    # into (-180, 180].
    raw_hue_delta = hue_2 - hue_1
    achromatic = (chroma_1_adj * chroma_2_adj) <= _EPS
    chromatic = ~achromatic
    hue_delta = torch.where(achromatic, torch.zeros_like(raw_hue_delta), raw_hue_delta)
    hue_delta = torch.where((hue_delta > 180.0) & chromatic, hue_delta - 360.0, hue_delta)
    hue_delta = torch.where((hue_delta < -180.0) & chromatic, hue_delta + 360.0, hue_delta)
    delta_hue = (
        2.0
        * torch.sqrt(chroma_1_adj * chroma_2_adj + _EPS)
        * torch.sin(torch.deg2rad(hue_delta * 0.5))
    )

    light_mean = (light_1 + light_2) * 0.5
    chroma_adj_mean = (chroma_1_adj + chroma_2_adj) * 0.5

    # Mean hue, taking the short way around the hue circle.
    hue_sum = hue_1 + hue_2
    hue_gap = torch.abs(hue_1 - hue_2)
    far_apart = chromatic & (hue_gap > 180.0)
    hue_mean = torch.where(achromatic, hue_sum, hue_sum * 0.5)
    hue_mean = torch.where(far_apart & (hue_sum < 360.0), (hue_sum + 360.0) * 0.5, hue_mean)
    hue_mean = torch.where(far_apart & (hue_sum >= 360.0), (hue_sum - 360.0) * 0.5, hue_mean)

    t_term = (
        1.0
        - 0.17 * torch.cos(torch.deg2rad(hue_mean - 30.0))
        + 0.24 * torch.cos(torch.deg2rad(2.0 * hue_mean))
        + 0.32 * torch.cos(torch.deg2rad(3.0 * hue_mean + 6.0))
        - 0.20 * torch.cos(torch.deg2rad(4.0 * hue_mean - 63.0))
    )
    rotation_angle = 30.0 * torch.exp(-((hue_mean - 275.0) / 25.0).square())
    chroma_adj_mean_7 = chroma_adj_mean.pow(7.0)
    rotation_chroma = 2.0 * torch.sqrt(chroma_adj_mean_7 / (chroma_adj_mean_7 + 25.0 ** 7 + _EPS))

    # Weighting functions for lightness, chroma and hue.
    light_offset_sq = (light_mean - 50.0).square()
    weight_light = 1.0 + (0.015 * light_offset_sq) / torch.sqrt(20.0 + light_offset_sq + _EPS)
    weight_chroma = 1.0 + 0.045 * chroma_adj_mean
    weight_hue = 1.0 + 0.015 * chroma_adj_mean * t_term
    rotation = -torch.sin(torch.deg2rad(2.0 * rotation_angle)) * rotation_chroma

    term_light = delta_light / (weight_light + _EPS)
    term_chroma = delta_chroma / (weight_chroma + _EPS)
    term_hue = delta_hue / (weight_hue + _EPS)
    squared = (
        term_light * term_light
        + term_chroma * term_chroma
        + term_hue * term_hue
        + rotation * term_chroma * term_hue
    )
    # The rotation term can push the sum marginally below zero.
    return torch.sqrt(squared.clamp_min(0.0))
947
+
948
+
949
+ # -----------------------------------------------------------------------------
950
+ # Normalization and combination
951
+ # -----------------------------------------------------------------------------
952
+
953
+
954
+ def _normalize_map(
955
+ d: torch.Tensor,
956
+ valid: torch.Tensor,
957
+ mode: int,
958
+ noise_floor_k: float,
959
+ mad_scale: float,
960
+ fixed_scale: float,
961
+ ) -> torch.Tensor:
962
+ d = torch.nan_to_num(d.float(), nan=0.0, posinf=0.0, neginf=0.0).clamp_min(0.0)
963
+ if mode < 0:
964
+ return (d / max(float(fixed_scale), _EPS)).clamp(0.0, 1.0)
965
+
966
+ vals = d[valid]
967
+ if vals.numel() < 16:
968
+ vals = d.reshape(-1)
969
+ if vals.numel() < 1:
970
+ return torch.zeros_like(d)
971
+
972
+ vals = vals.float()
973
+ q50 = torch.quantile(vals, 0.50)
974
+ q95 = torch.quantile(vals, 0.95)
975
+ q99 = torch.quantile(vals, 0.99)
976
+
977
+ if mode == 0:
978
+ denom = (q95 - q50).abs().clamp_min(1.0e-6)
979
+ return ((d - q50) / denom).clamp(0.0, 1.0)
980
+
981
+ med = q50
982
+ mad = torch.quantile((vals - med).abs(), 0.50) * 1.4826
983
+ floor = med + float(noise_floor_k) * mad
984
+
985
+ if mode == 2:
986
+ # Hybrid: threshold by MAD, but stretch by high percentile for less brittle behavior.
987
+ denom = torch.maximum((q95 - floor).abs(), float(mad_scale) * mad).clamp_min(1.0e-6)
988
+ else:
989
+ denom = (float(mad_scale) * mad).clamp_min(1.0e-6)
990
+
991
+ # Fallback if the image is almost constant and MAD collapses.
992
+ if float(denom.detach().cpu()) <= 1.0e-5:
993
+ floor = q50
994
+ denom = (q99 - q50).abs().clamp_min(1.0e-6)
995
+
996
+ return ((d - floor) / denom).clamp(0.0, 1.0)
997
+
998
+
999
def _combine_maps(maps: Sequence[Tuple[torch.Tensor, float, str, float]], mode: int) -> torch.Tensor:
    """Fuse several normalised difference maps into one map in ``[0, 1]``.

    Only the first two entries of each tuple are used here: the map tensor
    (clamped to ``[0, 1]``) and its weight (negative weights count as 0).
    Maps with zero weight never contribute.

    ``mode`` selects the fusion rule:

    * ``2`` -- per-pixel maximum of the maps, each multiplied by its weight
      relative to the mean weight of the *contributing* maps. Equal weights
      give every map a gain of 1, however many other maps are switched off.
    * ``3`` -- noisy OR: ``1 - prod(1 - map * weight / total_weight)``.
    * anything else -- weighted average.

    Returns zeros shaped like the first map when the weights sum to (almost)
    nothing.
    """
    tensors = [m[0].clamp(0.0, 1.0) for m in maps]
    weights = [max(0.0, float(m[1])) for m in maps]
    total_w = sum(weights)
    if total_w <= _EPS:
        return torch.zeros_like(tensors[0])

    active = [(t, w) for t, w in zip(tensors, weights) if w > 0.0]

    if mode == 2:
        # Max catches small decisive signals. Weights act as linear gains.
        # Normalise by the number of active maps so that disabling a cue does
        # not amplify the remaining ones (the other modes behave that way).
        scaled = [t * (w / total_w * len(active)) for t, w in active]
        return torch.stack(scaled, dim=0).max(dim=0).values.clamp(0.0, 1.0)

    if mode == 3:
        # Noisy OR: useful when any strong cue should fire, but isolated weak
        # cues should not dominate.
        remaining = torch.ones_like(tensors[0])
        for t, w in active:
            remaining = remaining * (1.0 - (t * (w / total_w)).clamp(0.0, 1.0))
        return (1.0 - remaining).clamp(0.0, 1.0)

    # mode -1, 0, 1 all resolve to weighted sum; mode 1 is the recommended one.
    acc = torch.zeros_like(tensors[0])
    for t, w in active:
        acc = acc + t * w
    return (acc / total_w).clamp(0.0, 1.0)
1025
+
1026
+
1027
+ # -----------------------------------------------------------------------------
1028
+ # Core mask, morphology, components, feathering
1029
+ # -----------------------------------------------------------------------------
1030
+
1031
+
1032
+ def _make_core_mask(score: np.ndarray, low_threshold: float, high_threshold: float, hysteresis_mode: int) -> np.ndarray:
1033
+ score = np.nan_to_num(score.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
1034
+ if hysteresis_mode < 0:
1035
+ return score >= float(high_threshold)
1036
+ if hysteresis_mode == 2:
1037
+ return score >= float(high_threshold)
1038
+
1039
+ strong = score >= float(high_threshold)
1040
+ weak = score >= float(low_threshold)
1041
+ if not strong.any():
1042
+ return strong
1043
+ return _hysteresis_connected(weak, strong)
1044
+
1045
+
1046
+ def _hysteresis_connected(weak: np.ndarray, strong: np.ndarray) -> np.ndarray:
1047
+ weak = weak.astype(bool)
1048
+ strong = strong.astype(bool)
1049
+ if ndi is not None:
1050
+ labels, n = ndi.label(weak, structure=np.ones((3, 3), dtype=np.uint8))
1051
+ if n == 0:
1052
+ return np.zeros_like(weak, dtype=bool)
1053
+ strong_labels = np.unique(labels[strong])
1054
+ strong_labels = strong_labels[strong_labels != 0]
1055
+ if len(strong_labels) == 0:
1056
+ return np.zeros_like(weak, dtype=bool)
1057
+ return np.isin(labels, strong_labels)
1058
+
1059
+ h, w = weak.shape
1060
+ out = np.zeros_like(weak, dtype=bool)
1061
+ seen = np.zeros_like(weak, dtype=bool)
1062
+ neighbors = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
1063
+
1064
+ ys, xs = np.nonzero(weak)
1065
+ for sy, sx in zip(ys, xs):
1066
+ if seen[sy, sx]:
1067
+ continue
1068
+ q = deque([(int(sy), int(sx))])
1069
+ seen[sy, sx] = True
1070
+ coords = []
1071
+ has_strong = False
1072
+ while q:
1073
+ y, x = q.popleft()
1074
+ coords.append((y, x))
1075
+ if strong[y, x]:
1076
+ has_strong = True
1077
+ for dy, dx in neighbors:
1078
+ ny, nx = y + dy, x + dx
1079
+ if 0 <= ny < h and 0 <= nx < w and weak[ny, nx] and not seen[ny, nx]:
1080
+ seen[ny, nx] = True
1081
+ q.append((ny, nx))
1082
+ if has_strong:
1083
+ for y, x in coords:
1084
+ out[y, x] = True
1085
+ return out
1086
+
1087
+
1088
+ def _morph_binary(mask: np.ndarray, mode: int, radius: int) -> np.ndarray:
1089
+ radius = int(radius)
1090
+ if radius <= 0:
1091
+ return mask.astype(bool)
1092
+ x = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0).unsqueeze(0)
1093
+
1094
+ def dilate(t: torch.Tensor) -> torch.Tensor:
1095
+ return F.max_pool2d(t, kernel_size=2 * radius + 1, stride=1, padding=radius)
1096
+
1097
+ def erode(t: torch.Tensor) -> torch.Tensor:
1098
+ return 1.0 - F.max_pool2d(1.0 - t, kernel_size=2 * radius + 1, stride=1, padding=radius)
1099
+
1100
+ if mode == 2:
1101
+ y = erode(dilate(x)) # close only
1102
+ elif mode == 3:
1103
+ y = dilate(x)
1104
+ elif mode == 4:
1105
+ y = erode(x)
1106
+ else:
1107
+ y = dilate(erode(x)) # open
1108
+ y = erode(dilate(y)) # close
1109
+ return (y.squeeze(0).squeeze(0).numpy() >= 0.5)
1110
+
1111
+
1112
+ def _filter_components(mask: np.ndarray, min_area: int, keep_largest: int) -> np.ndarray:
1113
+ mask = mask.astype(bool)
1114
+ if not mask.any():
1115
+ return mask
1116
+ min_area = max(0, int(min_area))
1117
+ keep_largest = max(0, int(keep_largest))
1118
+
1119
+ if cv2 is not None:
1120
+ num, labels, stats, _ = cv2.connectedComponentsWithStats(mask.astype(np.uint8), connectivity=8)
1121
+ if num <= 1:
1122
+ return mask
1123
+ components = []
1124
+ for i in range(1, num):
1125
+ area = int(stats[i, cv2.CC_STAT_AREA])
1126
+ if area >= min_area:
1127
+ components.append((i, area))
1128
+ if keep_largest > 0:
1129
+ components = sorted(components, key=lambda x: x[1], reverse=True)[:keep_largest]
1130
+ keep_ids = {i for i, _ in components}
1131
+ return np.isin(labels, list(keep_ids))
1132
+
1133
+ if ndi is not None:
1134
+ labels, n = ndi.label(mask, structure=np.ones((3, 3), dtype=np.uint8))
1135
+ areas = np.bincount(labels.ravel())
1136
+ components = [(i, int(areas[i])) for i in range(1, n + 1) if int(areas[i]) >= min_area]
1137
+ if keep_largest > 0:
1138
+ components = sorted(components, key=lambda x: x[1], reverse=True)[:keep_largest]
1139
+ keep_ids = [i for i, _ in components]
1140
+ return np.isin(labels, keep_ids)
1141
+
1142
+ # Pure numpy fallback.
1143
+ h, w = mask.shape
1144
+ out = np.zeros_like(mask, dtype=bool)
1145
+ seen = np.zeros_like(mask, dtype=bool)
1146
+ comps: List[List[Tuple[int, int]]] = []
1147
+ neighbors = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
1148
+ ys, xs = np.nonzero(mask)
1149
+ for sy, sx in zip(ys, xs):
1150
+ if seen[sy, sx]:
1151
+ continue
1152
+ q = deque([(int(sy), int(sx))])
1153
+ seen[sy, sx] = True
1154
+ coords = []
1155
+ while q:
1156
+ y, x = q.popleft()
1157
+ coords.append((y, x))
1158
+ for dy, dx in neighbors:
1159
+ ny, nx = y + dy, x + dx
1160
+ if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and not seen[ny, nx]:
1161
+ seen[ny, nx] = True
1162
+ q.append((ny, nx))
1163
+ if len(coords) >= min_area:
1164
+ comps.append(coords)
1165
+
1166
+ if keep_largest > 0:
1167
+ comps = sorted(comps, key=len, reverse=True)[:keep_largest]
1168
+ for coords in comps:
1169
+ for y, x in coords:
1170
+ out[y, x] = True
1171
+ return out
1172
+
1173
+
1174
+ def _feather_core(mask: np.ndarray, mode: int, radius: int, logistic_steepness: float) -> np.ndarray:
1175
+ mask = mask.astype(bool)
1176
+ if mode < 0 or radius <= 0:
1177
+ return mask.astype(np.float32)
1178
+ if not mask.any():
1179
+ return np.zeros_like(mask, dtype=np.float32)
1180
+
1181
+ if mode == 1:
1182
+ x = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0).unsqueeze(0)
1183
+ y = _gaussian_blur_bchw(x, radius=max(1, radius), sigma=max(0.1, radius / 2.0))
1184
+ return y.squeeze(0).squeeze(0).numpy().clip(0.0, 1.0).astype(np.float32)
1185
+
1186
+ dist = _outside_distance(mask, max_radius=radius).astype(np.float32)
1187
+ inside = mask
1188
+ d = np.clip(dist, 0.0, float(radius))
1189
+
1190
+ if mode == 3:
1191
+ r = max(float(radius), 1.0)
1192
+ z = logistic_steepness * (0.5 - d / r)
1193
+ raw = 1.0 / (1.0 + np.exp(-z))
1194
+ raw0 = 1.0 / (1.0 + math.exp(-logistic_steepness * 0.5))
1195
+ rawr = 1.0 / (1.0 + math.exp(logistic_steepness * 0.5))
1196
+ falloff = (raw - rawr) / max(raw0 - rawr, 1.0e-6)
1197
+ else:
1198
+ # Smoothstep distance falloff. Recommended default.
1199
+ t = np.clip(1.0 - d / max(float(radius), 1.0), 0.0, 1.0)
1200
+ falloff = t * t * (3.0 - 2.0 * t)
1201
+
1202
+ falloff[d >= float(radius)] = 0.0
1203
+ falloff[inside] = 1.0
1204
+ return falloff.clip(0.0, 1.0).astype(np.float32)
1205
+
1206
+
1207
+ def _outside_distance(mask: np.ndarray, max_radius: int) -> np.ndarray:
1208
+ # Distance outside mask to nearest mask pixel. Inside mask is 0.
1209
+ if cv2 is not None:
1210
+ outside = (~mask).astype(np.uint8)
1211
+ return cv2.distanceTransform(outside, cv2.DIST_L2, 5)
1212
+ if ndi is not None:
1213
+ return ndi.distance_transform_edt(~mask).astype(np.float32)
1214
+
1215
+ # Fallback: Chebyshev ring distance up to max_radius.
1216
+ max_radius = max(1, int(max_radius))
1217
+ dist = np.full(mask.shape, fill_value=max_radius + 1, dtype=np.float32)
1218
+ current = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0).unsqueeze(0)
1219
+ dist[mask] = 0.0
1220
+ prev = current.clone()
1221
+ for r in range(1, max_radius + 1):
1222
+ current = F.max_pool2d(current, kernel_size=3, stride=1, padding=1)
1223
+ ring = (current.squeeze().numpy() >= 0.5) & (prev.squeeze().numpy() < 0.5)
1224
+ dist[ring] = float(r)
1225
+ prev = current.clone()
1226
+ return dist
1227
+
1228
+
1229
# Node identifier -> implementing class.
# NOTE(review): presumably read by the host application's node loader
# (ComfyUI-style registration) -- confirm against the host.
NODE_CLASS_MAPPINGS = {"Salia_Get_Diff_Mask": Salia_Get_Diff_Mask}

# Node identifier -> label shown to the user.
NODE_DISPLAY_NAME_MAPPINGS = {"Salia_Get_Diff_Mask": "Salia_Get_Diff_Mask"}