saliacoel committed on
Commit
cf1376b
·
verified ·
1 Parent(s): 71aa491

Upload salia_compare_img.py

Browse files
Files changed (1) hide show
  1. salia_compare_img.py +304 -0
salia_compare_img.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ # -----------------------------
5
+ # Helpers
6
+ # -----------------------------
7
+
8
+ def _bhwc_to_nchw(img: torch.Tensor) -> torch.Tensor:
9
+ # ComfyUI IMAGE is usually float32 in [0,1], shape [B,H,W,C]
10
+ if img.dim() != 4:
11
+ raise ValueError(f"Expected IMAGE tensor with 4 dims [B,H,W,C], got {img.shape}")
12
+ return img.permute(0, 3, 1, 2).contiguous()
13
+
14
+ def _drop_alpha_if_any(x: torch.Tensor) -> torch.Tensor:
15
+ # If RGBA, keep RGB
16
+ if x.shape[1] > 3:
17
+ return x[:, :3, :, :].contiguous()
18
+ return x
19
+
20
+ def _to_luma(x: torch.Tensor) -> torch.Tensor:
21
+ # x: [B,C,H,W], expects C=1 or C=3
22
+ if x.shape[1] == 1:
23
+ return x
24
+ r = x[:, 0:1, :, :]
25
+ g = x[:, 1:2, :, :]
26
+ b = x[:, 2:3, :, :]
27
+ # Standard-ish luma weights
28
+ return (0.2989 * r + 0.5870 * g + 0.1140 * b)
29
+
30
+ def _resize_max(x: torch.Tensor, max_size: int) -> torch.Tensor:
31
+ if max_size <= 0:
32
+ return x
33
+ b, c, h, w = x.shape
34
+ m = max(h, w)
35
+ if m <= max_size:
36
+ return x
37
+ scale = max_size / float(m)
38
+ nh = max(1, int(round(h * scale)))
39
+ nw = max(1, int(round(w * scale)))
40
+ return F.interpolate(x, size=(nh, nw), mode="bilinear", align_corners=False)
41
+
42
+ def _gaussian_blur(x: torch.Tensor, sigma: float) -> torch.Tensor:
43
+ if sigma <= 0:
44
+ return x
45
+
46
+ # radius ~ 3*sigma
47
+ radius = int(max(1, round(3.0 * sigma)))
48
+ ksize = 2 * radius + 1
49
+ device = x.device
50
+ dtype = x.dtype
51
+
52
+ coords = torch.arange(-radius, radius + 1, device=device, dtype=dtype)
53
+ kernel1d = torch.exp(-(coords * coords) / (2.0 * sigma * sigma))
54
+ kernel1d = kernel1d / (kernel1d.sum() + 1e-12)
55
+
56
+ c = x.shape[1]
57
+
58
+ # separable conv: horizontal then vertical
59
+ kh = kernel1d.view(1, 1, 1, ksize).repeat(c, 1, 1, 1)
60
+ kv = kernel1d.view(1, 1, ksize, 1).repeat(c, 1, 1, 1)
61
+
62
+ out = F.conv2d(x, kh, padding=(0, radius), groups=c)
63
+ out = F.conv2d(out, kv, padding=(radius, 0), groups=c)
64
+ return out
65
+
66
+ def _sobel_edges(y: torch.Tensor) -> torch.Tensor:
67
+ # y: [B,1,H,W] or [B,C,H,W]
68
+ device = y.device
69
+ dtype = y.dtype
70
+ c = y.shape[1]
71
+
72
+ kx = torch.tensor(
73
+ [[-1, 0, 1],
74
+ [-2, 0, 2],
75
+ [-1, 0, 1]],
76
+ device=device, dtype=dtype
77
+ ) / 8.0
78
+
79
+ ky = torch.tensor(
80
+ [[-1, -2, -1],
81
+ [ 0, 0, 0],
82
+ [ 1, 2, 1]],
83
+ device=device, dtype=dtype
84
+ ) / 8.0
85
+
86
+ kx = kx.view(1, 1, 3, 3).repeat(c, 1, 1, 1)
87
+ ky = ky.view(1, 1, 3, 3).repeat(c, 1, 1, 1)
88
+
89
+ gx = F.conv2d(y, kx, padding=1, groups=c)
90
+ gy = F.conv2d(y, ky, padding=1, groups=c)
91
+ return torch.sqrt(gx * gx + gy * gy + 1e-12)
92
+
93
+ def _ssim(x: torch.Tensor, y: torch.Tensor, window_size: int = 11, sigma: float = 1.5) -> torch.Tensor:
94
+ """
95
+ SSIM per batch item. Returns shape [B], roughly in [0,1] for normal images.
96
+ x,y: [B,C,H,W]
97
+ """
98
+ device = x.device
99
+ dtype = x.dtype
100
+ c = x.shape[1]
101
+ radius = window_size // 2
102
+
103
+ coords = torch.arange(window_size, device=device, dtype=dtype) - radius
104
+ g = torch.exp(-(coords * coords) / (2.0 * sigma * sigma))
105
+ g = g / (g.sum() + 1e-12)
106
+ w2d = (g[:, None] * g[None, :]).view(1, 1, window_size, window_size)
107
+ w2d = w2d.repeat(c, 1, 1, 1)
108
+
109
+ mu_x = F.conv2d(x, w2d, padding=radius, groups=c)
110
+ mu_y = F.conv2d(y, w2d, padding=radius, groups=c)
111
+
112
+ mu_x2 = mu_x * mu_x
113
+ mu_y2 = mu_y * mu_y
114
+ mu_xy = mu_x * mu_y
115
+
116
+ sigma_x2 = F.conv2d(x * x, w2d, padding=radius, groups=c) - mu_x2
117
+ sigma_y2 = F.conv2d(y * y, w2d, padding=radius, groups=c) - mu_y2
118
+ sigma_xy = F.conv2d(x * y, w2d, padding=radius, groups=c) - mu_xy
119
+
120
+ C1 = (0.01) ** 2
121
+ C2 = (0.03) ** 2
122
+
123
+ num = (2.0 * mu_xy + C1) * (2.0 * sigma_xy + C2)
124
+ den = (mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2)
125
+
126
+ ssim_map = num / (den + 1e-12)
127
+ return ssim_map.mean(dim=[1, 2, 3]) # [B]
128
+
129
+ def _hist_chi2(x: torch.Tensor, y: torch.Tensor, bins: int = 32) -> torch.Tensor:
130
+ """
131
+ Color histogram chi-square distance. Returns [B].
132
+ Done on CPU for compatibility (hist ops can be awkward on some GPUs).
133
+ x,y: [B,C,H,W] in [0,1]
134
+ """
135
+ x_cpu = x.detach().float().cpu()
136
+ y_cpu = y.detach().float().cpu()
137
+ b, c, _, _ = x_cpu.shape
138
+ out = []
139
+
140
+ eps = 1e-12
141
+ for i in range(b):
142
+ dist = 0.0
143
+ for ch in range(c):
144
+ hx = torch.histc(x_cpu[i, ch], bins=bins, min=0.0, max=1.0)
145
+ hy = torch.histc(y_cpu[i, ch], bins=bins, min=0.0, max=1.0)
146
+ hx = hx / (hx.sum() + eps)
147
+ hy = hy / (hy.sum() + eps)
148
+
149
+ # chi-square distance
150
+ dist += 0.5 * torch.sum((hx - hy) ** 2 / (hx + hy + eps)).item()
151
+ out.append(dist / float(c))
152
+
153
+ return torch.tensor(out, dtype=torch.float32, device=x.device)
154
+
155
+
156
+ # -----------------------------
157
+ # ComfyUI Node
158
+ # -----------------------------
159
+
160
class ImageCompareFloat:
    """
    Compares two ComfyUI IMAGE inputs and returns a single float score:
      - smaller score => more similar (likely frozen)
      - larger score  => more different (moving)

    Fix vs. the original: `a`/`b` were re-passed through `_resize_max` in the
    histogram branches even though they were already capped at `downscale_max`
    (a guaranteed no-op); the redundant calls are removed. The channel
    coercions, previously repeated inline, are factored into helpers.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image_a": ("IMAGE",),
                "image_b": ("IMAGE",),
                "mode": (["pixel_mae", "ssim", "hist_chi2", "hybrid"],),
                "color_space": (["RGB", "LUMA"],),
                "downscale_max": ("INT", {"default": 256, "min": 32, "max": 2048, "step": 16}),
                "blur_sigma": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.1}),
                "hist_bins": ("INT", {"default": 32, "min": 8, "max": 256, "step": 8}),
                "scale": ("FLOAT", {"default": 1000.0, "min": 0.001, "max": 1000000.0, "step": 1.0}),

                # Hybrid weights (used only when mode="hybrid")
                "w_pixel": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.05}),
                "w_ssim": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.05}),
                "w_edge": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.05}),
                "w_hist": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.05}),
            }
        }

    RETURN_TYPES = ("FLOAT",)
    RETURN_NAMES = ("difference",)
    FUNCTION = "compare"
    CATEGORY = "image/analysis"

    @staticmethod
    def _as_rgb(x):
        # Guarantee 3 channels for histogram metrics (replicate luma if needed).
        return x if x.shape[1] == 3 else x.repeat(1, 3, 1, 1)

    @staticmethod
    def _as_luma(x):
        # Guarantee a single channel for structure metrics (SSIM / edges).
        return x if x.shape[1] == 1 else _to_luma(x)

    def compare(
        self,
        image_a,
        image_b,
        mode,
        color_space,
        downscale_max,
        blur_sigma,
        hist_bins,
        scale,
        w_pixel,
        w_ssim,
        w_edge,
        w_hist,
    ):
        """Compute the difference score for a pair of IMAGE batches.

        Returns a 1-tuple containing a single float (batch-averaged,
        multiplied by `scale`). Raises ValueError for an unknown `mode`.
        """
        a = _drop_alpha_if_any(_bhwc_to_nchw(image_a))
        b = _drop_alpha_if_any(_bhwc_to_nchw(image_b))

        # Match batch sizes: broadcast a singleton batch; otherwise truncate
        # both to the shorter batch so items pair up.
        if a.shape[0] != b.shape[0]:
            if a.shape[0] == 1:
                a = a.repeat(b.shape[0], 1, 1, 1)
            elif b.shape[0] == 1:
                b = b.repeat(a.shape[0], 1, 1, 1)
            else:
                m = min(a.shape[0], b.shape[0])
                a = a[:m]
                b = b[:m]

        # Match spatial size (avoid errors if upstream produced different sizes)
        if a.shape[2:] != b.shape[2:]:
            b = F.interpolate(b, size=a.shape[2:], mode="bilinear", align_corners=False)

        # Clamp to safe range, then downscale ONCE for speed + robustness.
        # After this, max(H, W) <= downscale_max holds for both tensors.
        a = _resize_max(a.clamp(0.0, 1.0), downscale_max)
        b = _resize_max(b.clamp(0.0, 1.0), downscale_max)

        # Select comparison space
        if color_space == "LUMA":
            a_cs = _to_luma(a)
            b_cs = _to_luma(b)
        else:
            a_cs = a
            b_cs = b

        # Blur to ignore tiny diffusion flicker / grain
        a_blur = _gaussian_blur(a_cs, blur_sigma)
        b_blur = _gaussian_blur(b_cs, blur_sigma)

        if mode == "pixel_mae":
            per_item = torch.mean(torch.abs(a_blur - b_blur), dim=[1, 2, 3])

        elif mode == "ssim":
            # SSIM is more stable on luma/structure, so force luma for this metric
            s = _ssim(self._as_luma(a_blur), self._as_luma(b_blur))
            per_item = (1.0 - s).clamp(min=0.0)

        elif mode == "hist_chi2":
            # Histograms use RGB if available (color distribution); a/b are
            # already capped at downscale_max, so no further resize is needed.
            per_item = _hist_chi2(self._as_rgb(a), self._as_rgb(b), bins=hist_bins)

        elif mode == "hybrid":
            # Pixel MAE (blurred)
            pix = torch.mean(torch.abs(a_blur - b_blur), dim=[1, 2, 3])

            # SSIM diff on luma
            a_y = self._as_luma(a_blur)
            b_y = self._as_luma(b_blur)
            ssim_diff = (1.0 - _ssim(a_y, b_y)).clamp(min=0.0)

            # Edge MAE on luma (good against tiny color shifts)
            edge = torch.mean(torch.abs(_sobel_edges(a_y) - _sobel_edges(b_y)), dim=[1, 2, 3])

            # Histogram chi2 on RGB (global color changes)
            hist = _hist_chi2(self._as_rgb(a), self._as_rgb(b), bins=hist_bins)

            per_item = (w_pixel * pix) + (w_ssim * ssim_diff) + (w_edge * edge) + (w_hist * hist)

        else:
            raise ValueError(f"Unknown mode: {mode}")

        # Reduce to single float (average across batch)
        score = float(per_item.mean().item() * scale)
        return (score,)
296
+
297
+
298
# Registration tables read by ComfyUI when this module is imported:
# internal node ID -> implementing class.
NODE_CLASS_MAPPINGS = {
    "ImageCompareFloat": ImageCompareFloat
}

# Internal node ID -> human-readable title shown in the node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageCompareFloat": "Image Compare → Float (Freeze Detect)"
}