saliacoel committed on
Commit
6c8ffa0
·
verified ·
1 Parent(s): 10a8e09

Upload Batch_Stabilize_Sprite.py

Browse files
Files changed (1) hide show
  1. Batch_Stabilize_Sprite.py +399 -0
Batch_Stabilize_Sprite.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+
4
+ class SpriteHeadStabilizeX:
5
+ """
6
+ Stabilize sprite animation wiggle (X only) using a Y-band (e.g. head region).
7
+
8
+ Align frames 1..N to frame 0 by estimating horizontal shift from alpha visibility
9
+ inside the selected Y-range.
10
+
11
+ Methods:
12
+ - bbox_center: leftmost/rightmost visible pixel columns -> center
13
+ - alpha_com: alpha-weighted center-of-mass (recommended)
14
+ - profile_corr: phase correlation on horizontal alpha profile (very robust)
15
+ - hybrid: profile_corr with a sanity check fallback to alpha_com
16
+
17
+ Inputs support:
18
+ - True RGBA IMAGE tensor (C>=4) => alpha taken from channel 4
19
+ - Or IMAGE (RGB) + MASK (ComfyUI LoadImage mask) => alpha derived from mask
20
+ """
21
+
22
+ @classmethod
23
+ def INPUT_TYPES(cls):
24
+ return {
25
+ "required": {
26
+ "images": ("IMAGE", {}),
27
+
28
+ # Head band
29
+ "y_min": ("INT", {"default": 210, "min": -99999, "max": 99999, "step": 1}),
30
+ "y_max": ("INT", {"default": 332, "min": -99999, "max": 99999, "step": 1}),
31
+
32
+ # Alpha tolerance: visible if alpha > threshold_8bit / 255
33
+ "alpha_threshold_8bit": ("INT", {"default": 5, "min": 0, "max": 255, "step": 1}),
34
+
35
+ "method": (["bbox_center", "alpha_com", "profile_corr", "hybrid"], {"default": "alpha_com"}),
36
+
37
+ # ComfyUI LoadImage produces MASK from alpha and inverts it.
38
+ # If your mask is already alpha (0=transparent,1=opaque), set False.
39
+ "mask_is_inverted": ("BOOLEAN", {"default": True}),
40
+
41
+ # Optional safety clamps/smoothing
42
+ "max_abs_shift": ("INT", {"default": 0, "min": 0, "max": 99999, "step": 1}),
43
+ "temporal_median": ("INT", {"default": 1, "min": 1, "max": 99, "step": 1}),
44
+
45
+ # Hybrid sanity check: if corr shift differs from COM shift by more than this,
46
+ # use COM shift instead.
47
+ "hybrid_tolerance_px": ("INT", {"default": 8, "min": 0, "max": 99999, "step": 1}),
48
+ },
49
+ "optional": {
50
+ "mask": ("MASK", {}),
51
+ }
52
+ }
53
+
54
+ RETURN_TYPES = ("IMAGE", "MASK", "STRING")
55
+ RETURN_NAMES = ("images", "mask", "shifts_x")
56
+ FUNCTION = "stabilize"
57
+ CATEGORY = "image/sprite"
58
+ SEARCH_ALIASES = ["wiggle stabilize", "sprite stabilize", "head stabilize", "animation stabilize", "sprite jitter fix"]
59
+
60
+ # ---------- helpers ----------
61
+
62
+ def _get_alpha(self, images: torch.Tensor, mask: torch.Tensor | None, mask_is_inverted: bool) -> torch.Tensor:
63
+ """
64
+ Returns alpha in [0..1], shape [B,H,W].
65
+ """
66
+ if images.dim() != 4:
67
+ raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}")
68
+ B, H, W, C = images.shape
69
+
70
+ if C >= 4:
71
+ return images[..., 3]
72
+
73
+ if mask is None:
74
+ raise ValueError("Need RGBA images (C>=4) OR provide a MASK input.")
75
+
76
+ if mask.dim() == 2:
77
+ mask = mask.unsqueeze(0)
78
+ if mask.dim() != 3:
79
+ raise ValueError(f"mask must have shape [B,H,W] or [H,W], got {tuple(mask.shape)}")
80
+
81
+ if mask.shape[1] != H or mask.shape[2] != W:
82
+ raise ValueError(f"mask H/W must match images; mask={tuple(mask.shape)} images={tuple(images.shape)}")
83
+
84
+ if mask.shape[0] == 1 and B > 1:
85
+ mask = mask.repeat(B, 1, 1)
86
+ if mask.shape[0] != B:
87
+ raise ValueError(f"mask batch must match images batch; mask B={mask.shape[0]} images B={B}")
88
+
89
+ alpha = 1.0 - mask if mask_is_inverted else mask
90
+ return alpha
91
+
92
+ def _clamp_y(self, H: int, y_min: int, y_max: int) -> tuple[int, int]:
93
+ y0 = int(y_min)
94
+ y1 = int(y_max)
95
+ if y1 < y0:
96
+ y0, y1 = y1, y0
97
+ y0 = max(0, min(H - 1, y0))
98
+ y1 = max(0, min(H - 1, y1))
99
+ return y0, y1
100
+
101
+ def _bbox_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None:
102
+ """
103
+ alpha_hw: [H,W]
104
+ Returns center X using leftmost/rightmost visible columns, or None if empty.
105
+ """
106
+ # visible: [H,W]
107
+ visible = alpha_hw > thr
108
+ cols = visible.any(dim=0) # [W]
109
+ if not bool(cols.any()):
110
+ return None
111
+ W = alpha_hw.shape[1]
112
+ left = int(torch.argmax(cols.float()).item())
113
+ right = int((W - 1) - torch.argmax(torch.flip(cols, dims=[0]).float()).item())
114
+ return (left + right) / 2.0
115
+
116
+ def _com_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None:
117
+ """
118
+ alpha_hw: [H,W]
119
+ Alpha-weighted center-of-mass of X within visible area, or None if empty.
120
+ """
121
+ W = alpha_hw.shape[1]
122
+ weights = alpha_hw
123
+ if thr > 0:
124
+ weights = weights * (weights > thr)
125
+
126
+ profile = weights.sum(dim=0) # [W]
127
+ total = float(profile.sum().item())
128
+ if total <= 0.0:
129
+ return None
130
+
131
+ x = torch.arange(W, device=alpha_hw.device, dtype=profile.dtype)
132
+ center = float((profile * x).sum().item() / total)
133
+ return center
134
+
135
+ def _phase_corr_shift_x(self, alpha_hw: torch.Tensor, ref_profile: torch.Tensor, thr: float) -> int | None:
136
+ """
137
+ Estimate integer shift to APPLY to current frame (X) so it matches reference.
138
+ Uses 1D phase correlation on horizontal alpha profile.
139
+ Returns shift_x (int), or None if empty.
140
+ """
141
+ weights = alpha_hw
142
+ if thr > 0:
143
+ weights = weights * (weights > thr)
144
+
145
+ prof = weights.sum(dim=0).float()
146
+ if float(prof.sum().item()) <= 0.0:
147
+ return None
148
+
149
+ # Remove DC component
150
+ prof = prof - prof.mean()
151
+ ref = ref_profile
152
+
153
+ # Phase correlation
154
+ F = torch.fft.rfft(prof)
155
+ R = torch.fft.rfft(ref)
156
+ cps = F * torch.conj(R)
157
+ cps = cps / (torch.abs(cps) + 1e-9)
158
+ corr = torch.fft.irfft(cps, n=prof.numel())
159
+ peak = int(torch.argmax(corr).item())
160
+
161
+ W = prof.numel()
162
+ lag = peak if peak <= W // 2 else peak - W # lag = "current is shifted by lag relative to ref"
163
+ shift_x = -lag # apply negative to align to ref
164
+ return int(shift_x)
165
+
166
+ def _shift_frame_x(self, img_hwc: torch.Tensor, shift_x: int) -> torch.Tensor:
167
+ """
168
+ img_hwc: [H,W,C]
169
+ shift_x: int (positive -> move right)
170
+ """
171
+ H, W, C = img_hwc.shape
172
+ out = torch.zeros_like(img_hwc)
173
+ if shift_x == 0:
174
+ return img_hwc
175
+ if abs(shift_x) >= W:
176
+ return out
177
+
178
+ if shift_x > 0:
179
+ out[:, shift_x:, :] = img_hwc[:, : W - shift_x, :]
180
+ else:
181
+ sx = -shift_x
182
+ out[:, : W - sx, :] = img_hwc[:, sx:, :]
183
+ return out
184
+
185
+ def _shift_mask_x(self, m_hw: torch.Tensor, shift_x: int, fill_val: float) -> torch.Tensor:
186
+ """
187
+ m_hw: [H,W]
188
+ """
189
+ H, W = m_hw.shape
190
+ out = torch.full_like(m_hw, fill_val)
191
+ if shift_x == 0:
192
+ return m_hw
193
+ if abs(shift_x) >= W:
194
+ return out
195
+ if shift_x > 0:
196
+ out[:, shift_x:] = m_hw[:, : W - shift_x]
197
+ else:
198
+ sx = -shift_x
199
+ out[:, : W - sx] = m_hw[:, sx:]
200
+ return out
201
+
202
+ def _median_smooth(self, shifts: list[int], window: int) -> list[int]:
203
+ """
204
+ Median filter over shifts with odd window size. Keeps shifts[0] unchanged.
205
+ """
206
+ if window <= 1 or len(shifts) <= 2:
207
+ return shifts
208
+ w = int(window)
209
+ if w % 2 == 0:
210
+ w += 1
211
+ r = w // 2
212
+ out = shifts[:]
213
+ out[0] = shifts[0]
214
+ n = len(shifts)
215
+ for i in range(1, n):
216
+ lo = max(1, i - r)
217
+ hi = min(n, i + r + 1)
218
+ vals = sorted(shifts[lo:hi])
219
+ out[i] = vals[len(vals) // 2]
220
+ return out
221
+
222
+ # ---------- main ----------
223
+
224
+ def stabilize(
225
+ self,
226
+ images: torch.Tensor,
227
+ y_min: int = 210,
228
+ y_max: int = 332,
229
+ alpha_threshold_8bit: int = 5,
230
+ method: str = "alpha_com",
231
+ mask_is_inverted: bool = True,
232
+ max_abs_shift: int = 0,
233
+ temporal_median: int = 1,
234
+ hybrid_tolerance_px: int = 8,
235
+ mask: torch.Tensor | None = None,
236
+ ):
237
+ if not torch.is_tensor(images):
238
+ raise TypeError("images must be a torch.Tensor")
239
+ if images.dim() != 4:
240
+ raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}")
241
+
242
+ B, H, W, C = images.shape
243
+ if B < 1:
244
+ raise ValueError("images batch is empty")
245
+
246
+ alpha = self._get_alpha(images, mask, mask_is_inverted) # [B,H,W]
247
+ y0, y1 = self._clamp_y(H, y_min, y_max)
248
+ thr = float(alpha_threshold_8bit) / 255.0
249
+
250
+ roi_alpha = alpha[:, y0:y1 + 1, :] # [B, Hr, W]
251
+
252
+ # Reference (frame 0)
253
+ ref_roi = roi_alpha[0] # [Hr,W]
254
+
255
+ # Prepare reference for methods
256
+ ref_center_bbox = None
257
+ ref_center_com = None
258
+ ref_profile = None
259
+
260
+ if method in ("bbox_center", "hybrid"):
261
+ ref_center_bbox = self._bbox_center_x(ref_roi, thr)
262
+ if method in ("alpha_com", "hybrid"):
263
+ ref_center_com = self._com_center_x(ref_roi, thr)
264
+ if method in ("profile_corr", "hybrid"):
265
+ # reference profile for phase correlation
266
+ w = ref_roi
267
+ if thr > 0:
268
+ w = w * (w > thr)
269
+ ref_profile = w.sum(dim=0).float()
270
+ ref_profile = ref_profile - ref_profile.mean()
271
+
272
+ # Fallback reference center if missing
273
+ if ref_center_bbox is None and ref_center_com is None and ref_profile is None:
274
+ # Nothing visible even in reference head region; do nothing.
275
+ out_mask = None
276
+ if mask is not None:
277
+ out_mask = mask if mask.dim() == 3 else mask.unsqueeze(0)
278
+ elif C >= 4:
279
+ a = images[..., 3]
280
+ out_mask = (1.0 - a) if mask_is_inverted else a
281
+ else:
282
+ fill_val = 1.0 if mask_is_inverted else 0.0
283
+ out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype)
284
+
285
+ return (images, out_mask, "[0]" if B == 1 else str([0] * B))
286
+
287
+ # For center-based methods, pick a reference center
288
+ # Preference: COM, else BBOX, else image center
289
+ if ref_center_com is not None:
290
+ ref_center = ref_center_com
291
+ elif ref_center_bbox is not None:
292
+ ref_center = ref_center_bbox
293
+ else:
294
+ ref_center = W / 2.0
295
+
296
+ shifts = [0] * B
297
+ shifts[0] = 0 # frame 0 stays
298
+
299
+ for i in range(1, B):
300
+ a_hw = roi_alpha[i]
301
+
302
+ shift_i = 0
303
+
304
+ if method == "bbox_center":
305
+ c = self._bbox_center_x(a_hw, thr)
306
+ if c is None:
307
+ shift_i = 0
308
+ else:
309
+ shift_i = int(round(ref_center - c))
310
+
311
+ elif method == "alpha_com":
312
+ c = self._com_center_x(a_hw, thr)
313
+ if c is None:
314
+ shift_i = 0
315
+ else:
316
+ shift_i = int(round(ref_center - c))
317
+
318
+ elif method == "profile_corr":
319
+ s = self._phase_corr_shift_x(a_hw, ref_profile, thr) # already int shift to APPLY
320
+ shift_i = 0 if s is None else int(s)
321
+
322
+ elif method == "hybrid":
323
+ # corr shift
324
+ s_corr = self._phase_corr_shift_x(a_hw, ref_profile, thr) if ref_profile is not None else None
325
+
326
+ # com shift
327
+ c = self._com_center_x(a_hw, thr)
328
+ s_com = None if c is None else int(round(ref_center - c))
329
+
330
+ if s_corr is None and s_com is None:
331
+ shift_i = 0
332
+ elif s_corr is None:
333
+ shift_i = int(s_com)
334
+ elif s_com is None:
335
+ shift_i = int(s_corr)
336
+ else:
337
+ if abs(int(s_corr) - int(s_com)) > int(hybrid_tolerance_px):
338
+ shift_i = int(s_com)
339
+ else:
340
+ shift_i = int(s_corr)
341
+
342
+ else:
343
+ raise ValueError(f"Unknown method: {method}")
344
+
345
+ # Clamp extreme shifts if requested
346
+ if max_abs_shift and max_abs_shift > 0:
347
+ shift_i = int(max(-max_abs_shift, min(max_abs_shift, shift_i)))
348
+
349
+ shifts[i] = shift_i
350
+
351
+ # Optional temporal median smoothing (keeps shifts[0] anchored)
352
+ shifts = self._median_smooth(shifts, int(temporal_median))
353
+
354
+ # Apply per-frame shifts
355
+ out_images = torch.zeros_like(images)
356
+
357
+ # Output mask handling:
358
+ # - If input mask provided: shift it
359
+ # - Else if RGBA: derive from shifted alpha
360
+ # - Else: produce blank
361
+ out_mask = None
362
+ in_mask_bhw = None
363
+ if mask is not None:
364
+ in_mask_bhw = mask
365
+ if in_mask_bhw.dim() == 2:
366
+ in_mask_bhw = in_mask_bhw.unsqueeze(0)
367
+ if in_mask_bhw.shape[0] == 1 and B > 1:
368
+ in_mask_bhw = in_mask_bhw.repeat(B, 1, 1)
369
+
370
+ fill_val = 1.0 if mask_is_inverted else 0.0
371
+ out_mask = torch.full_like(in_mask_bhw, fill_val)
372
+
373
+ for i in range(B):
374
+ sx = int(shifts[i])
375
+ out_images[i] = self._shift_frame_x(images[i], sx)
376
+
377
+ if out_mask is not None and in_mask_bhw is not None:
378
+ fill_val = 1.0 if mask_is_inverted else 0.0
379
+ out_mask[i] = self._shift_mask_x(in_mask_bhw[i], sx, fill_val)
380
+
381
+ if out_mask is None:
382
+ if out_images.shape[-1] >= 4:
383
+ a = out_images[..., 3]
384
+ out_mask = (1.0 - a) if mask_is_inverted else a
385
+ else:
386
+ fill_val = 1.0 if mask_is_inverted else 0.0
387
+ out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype)
388
+
389
+ shifts_str = str(shifts)
390
+ return (out_images, out_mask, shifts_str)
391
+
392
+
393
# Node registration tables, both keyed by the node's identifier string.
# NOTE(review): presumably read by the ComfyUI custom-node loader — confirm.

# identifier -> implementing class
NODE_CLASS_MAPPINGS = dict(SpriteHeadStabilizeX=SpriteHeadStabilizeX)

# identifier -> human-readable label
NODE_DISPLAY_NAME_MAPPINGS = {"SpriteHeadStabilizeX": "Sprite Head Stabilize X (Batch)"}