saliacoel commited on
Commit
58a1f6b
·
verified ·
1 Parent(s): 8de5a16

Upload salia_sprite_batch_stabilizer.py

Browse files
Files changed (1) hide show
  1. salia_sprite_batch_stabilizer.py +306 -0
salia_sprite_batch_stabilizer.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections import deque
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ class SpriteBatchStabilizeToTarget:
9
+ """
10
+ ComfyUI IMAGE batch node.
11
+
12
+ Input image tensor: [B, H, W, C], C can be 3/RGB or 4/RGBA.
13
+ Output image tensor: [B, H_out, W_out, 3].
14
+
15
+ The node:
16
+ 1. Composites RGBA over white, if needed.
17
+ 2. Estimates the white/off-white/grey background color from the image border.
18
+ 3. Looks along coord_y_height for the largest contiguous non-background sprite run.
19
+ 4. Moves the whole sprite image so that that run's center lands on target_center_x/y.
20
+ 5. Expands the whole batch canvas enough that no shifted image pixels are clipped.
21
+ 6. Re-composites onto white and returns RGB.
22
+ """
23
+
24
+ # Internal tuning constants. Increase MIN_BACKGROUND_TOLERANCE if JPEG/grey
25
+ # background noise is being detected as sprite. Decrease it if very pale
26
+ # sprites are being ignored.
27
+ MIN_BACKGROUND_TOLERANCE = 0.055
28
+ MAX_BACKGROUND_TOLERANCE = 0.22
29
+ NOISE_SIGMA_MULTIPLIER = 6.0
30
+ SMALL_GAP_FRACTION_OF_WIDTH = 0.01
31
+ SMALL_GAP_MIN_PIXELS = 2
32
+ SMALL_GAP_MAX_PIXELS = 12
33
+
34
+ @classmethod
35
+ def INPUT_TYPES(cls):
36
+ return {
37
+ "required": {
38
+ "images": ("IMAGE",),
39
+ "coord_y_height": ("INT", {
40
+ "default": 0,
41
+ "min": 0,
42
+ "max": 65535,
43
+ "step": 1,
44
+ "display": "number",
45
+ }),
46
+ "target_center_x": ("INT", {
47
+ "default": 0,
48
+ "min": -65535,
49
+ "max": 65535,
50
+ "step": 1,
51
+ "display": "number",
52
+ }),
53
+ "target_center_y": ("INT", {
54
+ "default": 0,
55
+ "min": -65535,
56
+ "max": 65535,
57
+ "step": 1,
58
+ "display": "number",
59
+ }),
60
+ }
61
+ }
62
+
63
+ RETURN_TYPES = ("IMAGE",)
64
+ RETURN_NAMES = ("images",)
65
+ FUNCTION = "stabilize"
66
+ CATEGORY = "image/sprite"
67
+
68
+ def stabilize(self, images, coord_y_height, target_center_x, target_center_y):
69
+ if not torch.is_tensor(images):
70
+ raise TypeError("images must be a torch.Tensor in ComfyUI IMAGE format [B,H,W,C].")
71
+ if images.ndim != 4:
72
+ raise ValueError(f"Expected IMAGE tensor shape [B,H,W,C], got {tuple(images.shape)}.")
73
+
74
+ batch, height, width, channels = images.shape
75
+ if channels not in (3, 4):
76
+ raise ValueError(f"Expected RGB or RGBA images with C=3 or C=4, got C={channels}.")
77
+ if batch < 1 or height < 1 or width < 1:
78
+ raise ValueError("images must contain at least one non-empty image.")
79
+
80
+ input_device = images.device
81
+ input_dtype = images.dtype if images.dtype.is_floating_point else torch.float32
82
+
83
+ rgb = self._rgba_or_rgb_to_rgb_float(images)
84
+ rgb_np = rgb.detach().cpu().numpy().astype(np.float32, copy=False)
85
+
86
+ scan_y = int(np.clip(coord_y_height, 0, height - 1))
87
+ target_x = int(target_center_x)
88
+ target_y = int(target_center_y)
89
+
90
+ prepared = []
91
+ shifts_x = []
92
+ shifts_y = []
93
+
94
+ for index in range(batch):
95
+ arr = rgb_np[index]
96
+ bg_color = self._estimate_background_color(arr)
97
+ dist = self._color_distance(arr, bg_color)
98
+ threshold = self._adaptive_background_threshold(dist)
99
+
100
+ center_x, found = self._find_sprite_center_x_on_row(
101
+ row_distance=dist[scan_y],
102
+ threshold=threshold,
103
+ width=width,
104
+ )
105
+
106
+ if found:
107
+ dx = int(round(target_x - center_x))
108
+ dy = int(round(target_y - scan_y))
109
+ else:
110
+ # Conservative fallback: if the requested scanline does not hit
111
+ # any sprite pixels, do not introduce a potentially wild shift.
112
+ dx = 0
113
+ dy = 0
114
+
115
+ alpha = self._external_background_alpha(dist, threshold)
116
+
117
+ prepared.append((arr, alpha))
118
+ shifts_x.append(dx)
119
+ shifts_y.append(dy)
120
+
121
+ pad_left = int(max(0, max((-dx for dx in shifts_x), default=0)))
122
+ pad_right = int(max(0, max((dx for dx in shifts_x), default=0)))
123
+ pad_top = int(max(0, max((-dy for dy in shifts_y), default=0)))
124
+ pad_bottom = int(max(0, max((dy for dy in shifts_y), default=0)))
125
+
126
+ out_height = height + pad_top + pad_bottom
127
+ out_width = width + pad_left + pad_right
128
+
129
+ outputs = []
130
+ for (arr, alpha), dx, dy in zip(prepared, shifts_x, shifts_y):
131
+ rgba_canvas = np.zeros((out_height, out_width, 4), dtype=np.float32)
132
+ x0 = pad_left + dx
133
+ y0 = pad_top + dy
134
+
135
+ rgba_canvas[y0:y0 + height, x0:x0 + width, 0:3] = arr
136
+ rgba_canvas[y0:y0 + height, x0:x0 + width, 3] = alpha
137
+
138
+ a = rgba_canvas[..., 3:4]
139
+ out_rgb = rgba_canvas[..., 0:3] * a + (1.0 - a) # white background
140
+ outputs.append(np.clip(out_rgb, 0.0, 1.0))
141
+
142
+ out = torch.from_numpy(np.stack(outputs, axis=0)).to(device=input_device, dtype=input_dtype)
143
+ return (out,)
144
+
145
+ @staticmethod
146
+ def _rgba_or_rgb_to_rgb_float(images):
147
+ img = images.to(dtype=torch.float32).clamp(0.0, 1.0)
148
+ if img.shape[-1] == 4:
149
+ rgb = img[..., 0:3]
150
+ alpha = img[..., 3:4]
151
+ return rgb * alpha + (1.0 - alpha) # composite over white
152
+ return img[..., 0:3]
153
+
154
+ @staticmethod
155
+ def _estimate_background_color(arr):
156
+ h, w, _ = arr.shape
157
+ strip = max(1, min(8, min(h, w) // 64 if min(h, w) >= 64 else 1))
158
+
159
+ samples = [
160
+ arr[:strip, :, :].reshape(-1, 3),
161
+ arr[h - strip:, :, :].reshape(-1, 3),
162
+ arr[:, :strip, :].reshape(-1, 3),
163
+ arr[:, w - strip:, :].reshape(-1, 3),
164
+ ]
165
+ border = np.concatenate(samples, axis=0)
166
+
167
+ # Median is robust if a small part of the sprite touches an edge.
168
+ return np.median(border, axis=0).astype(np.float32)
169
+
170
+ @staticmethod
171
+ def _color_distance(arr, bg_color):
172
+ # RMS RGB distance in 0..1. RMS is easier to tune than full Euclidean.
173
+ delta = arr - bg_color.reshape(1, 1, 3)
174
+ return np.sqrt(np.mean(delta * delta, axis=2)).astype(np.float32)
175
+
176
+ def _adaptive_background_threshold(self, dist):
177
+ h, w = dist.shape
178
+ strip = max(1, min(8, min(h, w) // 64 if min(h, w) >= 64 else 1))
179
+ border = np.concatenate([
180
+ dist[:strip, :].reshape(-1),
181
+ dist[h - strip:, :].reshape(-1),
182
+ dist[:, :strip].reshape(-1),
183
+ dist[:, w - strip:].reshape(-1),
184
+ ])
185
+
186
+ med = float(np.median(border))
187
+ mad = float(np.median(np.abs(border - med)))
188
+ robust_sigma = 1.4826 * mad
189
+ threshold = med + self.NOISE_SIGMA_MULTIPLIER * robust_sigma + self.MIN_BACKGROUND_TOLERANCE
190
+ return float(np.clip(threshold, self.MIN_BACKGROUND_TOLERANCE, self.MAX_BACKGROUND_TOLERANCE))
191
+
192
+ def _find_sprite_center_x_on_row(self, row_distance, threshold, width):
193
+ different = row_distance > threshold
194
+ gap = int(round(width * self.SMALL_GAP_FRACTION_OF_WIDTH))
195
+ gap = int(np.clip(gap, self.SMALL_GAP_MIN_PIXELS, self.SMALL_GAP_MAX_PIXELS))
196
+ different = self._close_small_false_gaps(different, gap)
197
+
198
+ runs = self._true_runs(different)
199
+ if not runs:
200
+ return 0.0, False
201
+
202
+ # Largest group with strongest total color difference.
203
+ # score=sum distance; tie-breaker=length.
204
+ best = None
205
+ best_score = -1.0
206
+ best_len = -1
207
+ for start, end in runs:
208
+ length = end - start
209
+ if length <= 0:
210
+ continue
211
+ score = float(np.sum(row_distance[start:end]))
212
+ if score > best_score or (math.isclose(score, best_score) and length > best_len):
213
+ best = (start, end)
214
+ best_score = score
215
+ best_len = length
216
+
217
+ if best is None:
218
+ return 0.0, False
219
+
220
+ start, end = best # end is exclusive
221
+ center_x = (start + end - 1) / 2.0
222
+ return center_x, True
223
+
224
+ @staticmethod
225
+ def _close_small_false_gaps(mask, max_gap):
226
+ # Fill False gaps between True runs when the gap is small.
227
+ closed = mask.astype(bool).copy()
228
+ n = closed.size
229
+ i = 0
230
+ while i < n:
231
+ while i < n and closed[i]:
232
+ i += 1
233
+ gap_start = i
234
+ while i < n and not closed[i]:
235
+ i += 1
236
+ gap_end = i
237
+
238
+ if gap_start == 0 or gap_end == n:
239
+ continue
240
+ if (gap_end - gap_start) <= max_gap and closed[gap_start - 1] and closed[gap_end]:
241
+ closed[gap_start:gap_end] = True
242
+ return closed
243
+
244
+ @staticmethod
245
+ def _true_runs(mask):
246
+ runs = []
247
+ n = mask.size
248
+ i = 0
249
+ while i < n:
250
+ while i < n and not mask[i]:
251
+ i += 1
252
+ start = i
253
+ while i < n and mask[i]:
254
+ i += 1
255
+ end = i
256
+ if end > start:
257
+ runs.append((start, end))
258
+ return runs
259
+
260
+ @staticmethod
261
+ def _external_background_alpha(dist, threshold):
262
+ h, w = dist.shape
263
+ background_like = dist <= threshold
264
+ external = np.zeros((h, w), dtype=bool)
265
+ q = deque()
266
+
267
+ def push(y, x):
268
+ if background_like[y, x] and not external[y, x]:
269
+ external[y, x] = True
270
+ q.append((y, x))
271
+
272
+ for x in range(w):
273
+ push(0, x)
274
+ push(h - 1, x)
275
+ for y in range(h):
276
+ push(y, 0)
277
+ push(y, w - 1)
278
+
279
+ while q:
280
+ y, x = q.popleft()
281
+ yy = y - 1
282
+ if yy >= 0:
283
+ push(yy, x)
284
+ yy = y + 1
285
+ if yy < h:
286
+ push(yy, x)
287
+ xx = x - 1
288
+ if xx >= 0:
289
+ push(y, xx)
290
+ xx = x + 1
291
+ if xx < w:
292
+ push(y, xx)
293
+
294
+ # Keep sprite and enclosed light pixels opaque; make external background transparent.
295
+ return (~external).astype(np.float32)
296
+
297
+
298
+ NODE_CLASS_MAPPINGS = {
299
+ "SpriteBatchStabilizeToTarget": SpriteBatchStabilizeToTarget,
300
+ }
301
+
302
+ NODE_DISPLAY_NAME_MAPPINGS = {
303
+ "SpriteBatchStabilizeToTarget": "Sprite Batch Stabilize To Target",
304
+ }
305
+
306
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]