saliacoel committed on
Commit
7de8be2
·
verified ·
1 Parent(s): 56031a1

Upload salia_sprite_head_stabilizer.py

Browse files
Files changed (1) hide show
  1. salia_sprite_head_stabilizer.py +382 -0
salia_sprite_head_stabilizer.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+
4
+
5
+ class Salia_Sprite_Head_Stabilizer:
6
+ """
7
+ Stabilize sprite animation wiggle (X only) using a fixed Y-band (head region).
8
+
9
+ Frames 1..N are aligned to frame 0 by estimating horizontal shift from alpha visibility
10
+ inside the selected Y-range.
11
+
12
+ Methods supported internally:
13
+ - bbox_center: leftmost/rightmost visible pixel columns -> center
14
+ - alpha_com: alpha-weighted center-of-mass (RECOMMENDED / HARDCODED)
15
+ - profile_corr: phase correlation on horizontal alpha profile (very robust)
16
+ - hybrid: profile_corr with a sanity check fallback to alpha_com
17
+
18
+ Inputs support:
19
+ - True RGBA IMAGE tensor (C>=4) => alpha taken from channel 4
20
+ - Or IMAGE (RGB) + MASK (ComfyUI LoadImage mask) => alpha derived from mask
21
+
22
+ NOTE: User-adjustable parameters are HARDCODED via constants below.
23
+ """
24
+
25
+ # ----------------------------
26
+ # HARDCODED CONSTANTS (as requested)
27
+ # ----------------------------
28
+ Y_MIN = 210
29
+ Y_MAX = 332
30
+ ALPHA_THRESHOLD_8BIT = 5
31
+ METHOD = "alpha_com" # one of: bbox_center, alpha_com, profile_corr, hybrid
32
+ MASK_IS_INVERTED = True
33
+ MAX_ABS_SHIFT = 0
34
+ TEMPORAL_MEDIAN = 1
35
+
36
+ # User asked for "hybrid_tolerance_pc" (typo vs px). Keep both names as constants.
37
+ HYBRID_TOLERANCE_PC = 8
38
+ HYBRID_TOLERANCE_PX = HYBRID_TOLERANCE_PC
39
+
40
+ @classmethod
41
+ def INPUT_TYPES(cls):
42
+ # Only expose inputs that are still meaningful with hardcoded params.
43
+ return {
44
+ "required": {
45
+ "images": ("IMAGE", {}),
46
+ },
47
+ "optional": {
48
+ "mask": ("MASK", {}),
49
+ },
50
+ }
51
+
52
+ RETURN_TYPES = ("IMAGE", "MASK", "STRING")
53
+ RETURN_NAMES = ("images", "mask", "shifts_x")
54
+ FUNCTION = "stabilize"
55
+ CATEGORY = "image/sprite"
56
+ SEARCH_ALIASES = ["wiggle stabilize", "sprite stabilize", "head stabilize", "animation stabilize", "sprite jitter fix"]
57
+
58
+ # ---------- helpers ----------
59
+
60
+ def _get_alpha(self, images: torch.Tensor, mask: torch.Tensor | None, mask_is_inverted: bool) -> torch.Tensor:
61
+ """
62
+ Returns alpha in [0..1], shape [B,H,W].
63
+ """
64
+ if images.dim() != 4:
65
+ raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}")
66
+ B, H, W, C = images.shape
67
+
68
+ if C >= 4:
69
+ return images[..., 3]
70
+
71
+ if mask is None:
72
+ raise ValueError("Need RGBA images (C>=4) OR provide a MASK input.")
73
+
74
+ if mask.dim() == 2:
75
+ mask = mask.unsqueeze(0)
76
+ if mask.dim() != 3:
77
+ raise ValueError(f"mask must have shape [B,H,W] or [H,W], got {tuple(mask.shape)}")
78
+
79
+ if mask.shape[1] != H or mask.shape[2] != W:
80
+ raise ValueError(f"mask H/W must match images; mask={tuple(mask.shape)} images={tuple(images.shape)}")
81
+
82
+ if mask.shape[0] == 1 and B > 1:
83
+ mask = mask.repeat(B, 1, 1)
84
+ if mask.shape[0] != B:
85
+ raise ValueError(f"mask batch must match images batch; mask B={mask.shape[0]} images B={B}")
86
+
87
+ alpha = 1.0 - mask if mask_is_inverted else mask
88
+ return alpha
89
+
90
+ def _clamp_y(self, H: int, y_min: int, y_max: int) -> tuple[int, int]:
91
+ y0 = int(y_min)
92
+ y1 = int(y_max)
93
+ if y1 < y0:
94
+ y0, y1 = y1, y0
95
+ y0 = max(0, min(H - 1, y0))
96
+ y1 = max(0, min(H - 1, y1))
97
+ return y0, y1
98
+
99
+ def _bbox_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None:
100
+ """
101
+ alpha_hw: [H,W]
102
+ Returns center X using leftmost/rightmost visible columns, or None if empty.
103
+ """
104
+ visible = alpha_hw > thr
105
+ cols = visible.any(dim=0) # [W]
106
+ if not bool(cols.any()):
107
+ return None
108
+ W = alpha_hw.shape[1]
109
+ left = int(torch.argmax(cols.float()).item())
110
+ right = int((W - 1) - torch.argmax(torch.flip(cols, dims=[0]).float()).item())
111
+ return (left + right) / 2.0
112
+
113
+ def _com_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None:
114
+ """
115
+ alpha_hw: [H,W]
116
+ Alpha-weighted center-of-mass of X within visible area, or None if empty.
117
+ """
118
+ W = alpha_hw.shape[1]
119
+ weights = alpha_hw
120
+ if thr > 0:
121
+ weights = weights * (weights > thr)
122
+
123
+ profile = weights.sum(dim=0) # [W]
124
+ total = float(profile.sum().item())
125
+ if total <= 0.0:
126
+ return None
127
+
128
+ x = torch.arange(W, device=alpha_hw.device, dtype=profile.dtype)
129
+ center = float((profile * x).sum().item() / total)
130
+ return center
131
+
132
+ def _phase_corr_shift_x(self, alpha_hw: torch.Tensor, ref_profile: torch.Tensor, thr: float) -> int | None:
133
+ """
134
+ Estimate integer shift to APPLY to current frame (X) so it matches reference.
135
+ Uses 1D phase correlation on horizontal alpha profile.
136
+ Returns shift_x (int), or None if empty.
137
+ """
138
+ weights = alpha_hw
139
+ if thr > 0:
140
+ weights = weights * (weights > thr)
141
+
142
+ prof = weights.sum(dim=0).float()
143
+ if float(prof.sum().item()) <= 0.0:
144
+ return None
145
+
146
+ # Remove DC component
147
+ prof = prof - prof.mean()
148
+ ref = ref_profile
149
+
150
+ # Phase correlation
151
+ F = torch.fft.rfft(prof)
152
+ R = torch.fft.rfft(ref)
153
+ cps = F * torch.conj(R)
154
+ cps = cps / (torch.abs(cps) + 1e-9)
155
+ corr = torch.fft.irfft(cps, n=prof.numel())
156
+ peak = int(torch.argmax(corr).item())
157
+
158
+ W = prof.numel()
159
+ lag = peak if peak <= W // 2 else peak - W # lag = "current is shifted by lag relative to ref"
160
+ shift_x = -lag # apply negative to align to ref
161
+ return int(shift_x)
162
+
163
+ def _shift_frame_x(self, img_hwc: torch.Tensor, shift_x: int) -> torch.Tensor:
164
+ """
165
+ img_hwc: [H,W,C]
166
+ shift_x: int (positive -> move right)
167
+ """
168
+ H, W, C = img_hwc.shape
169
+ out = torch.zeros_like(img_hwc)
170
+ if shift_x == 0:
171
+ return img_hwc
172
+ if abs(shift_x) >= W:
173
+ return out
174
+
175
+ if shift_x > 0:
176
+ out[:, shift_x:, :] = img_hwc[:, : W - shift_x, :]
177
+ else:
178
+ sx = -shift_x
179
+ out[:, : W - sx, :] = img_hwc[:, sx:, :]
180
+ return out
181
+
182
+ def _shift_mask_x(self, m_hw: torch.Tensor, shift_x: int, fill_val: float) -> torch.Tensor:
183
+ """
184
+ m_hw: [H,W]
185
+ """
186
+ H, W = m_hw.shape
187
+ out = torch.full_like(m_hw, fill_val)
188
+ if shift_x == 0:
189
+ return m_hw
190
+ if abs(shift_x) >= W:
191
+ return out
192
+ if shift_x > 0:
193
+ out[:, shift_x:] = m_hw[:, : W - shift_x]
194
+ else:
195
+ sx = -shift_x
196
+ out[:, : W - sx] = m_hw[:, sx:]
197
+ return out
198
+
199
+ def _median_smooth(self, shifts: list[int], window: int) -> list[int]:
200
+ """
201
+ Median filter over shifts with odd window size. Keeps shifts[0] unchanged.
202
+ """
203
+ if window <= 1 or len(shifts) <= 2:
204
+ return shifts
205
+ w = int(window)
206
+ if w % 2 == 0:
207
+ w += 1
208
+ r = w // 2
209
+ out = shifts[:]
210
+ out[0] = shifts[0]
211
+ n = len(shifts)
212
+ for i in range(1, n):
213
+ lo = max(1, i - r)
214
+ hi = min(n, i + r + 1)
215
+ vals = sorted(shifts[lo:hi])
216
+ out[i] = vals[len(vals) // 2]
217
+ return out
218
+
219
+ # ---------- main ----------
220
+
221
+ def stabilize(self, images: torch.Tensor, mask: torch.Tensor | None = None):
222
+ # Pull hardcoded parameters
223
+ y_min = int(self.Y_MIN)
224
+ y_max = int(self.Y_MAX)
225
+ alpha_threshold_8bit = int(self.ALPHA_THRESHOLD_8BIT)
226
+ method = str(self.METHOD)
227
+ mask_is_inverted = bool(self.MASK_IS_INVERTED)
228
+ max_abs_shift = int(self.MAX_ABS_SHIFT)
229
+ temporal_median = int(self.TEMPORAL_MEDIAN)
230
+ hybrid_tolerance_px = int(self.HYBRID_TOLERANCE_PX)
231
+
232
+ if not torch.is_tensor(images):
233
+ raise TypeError("images must be a torch.Tensor")
234
+ if images.dim() != 4:
235
+ raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}")
236
+
237
+ B, H, W, C = images.shape
238
+ if B < 1:
239
+ raise ValueError("images batch is empty")
240
+
241
+ alpha = self._get_alpha(images, mask, mask_is_inverted) # [B,H,W]
242
+ y0, y1 = self._clamp_y(H, y_min, y_max)
243
+ thr = float(alpha_threshold_8bit) / 255.0
244
+
245
+ roi_alpha = alpha[:, y0 : y1 + 1, :] # [B, Hr, W]
246
+
247
+ # Reference (frame 0)
248
+ ref_roi = roi_alpha[0] # [Hr,W]
249
+
250
+ # Prepare reference for methods
251
+ ref_center_bbox = None
252
+ ref_center_com = None
253
+ ref_profile = None
254
+
255
+ if method in ("bbox_center", "hybrid"):
256
+ ref_center_bbox = self._bbox_center_x(ref_roi, thr)
257
+ if method in ("alpha_com", "hybrid"):
258
+ ref_center_com = self._com_center_x(ref_roi, thr)
259
+ if method in ("profile_corr", "hybrid"):
260
+ w = ref_roi
261
+ if thr > 0:
262
+ w = w * (w > thr)
263
+ ref_profile = w.sum(dim=0).float()
264
+ ref_profile = ref_profile - ref_profile.mean()
265
+
266
+ # Fallback reference center if missing
267
+ if ref_center_bbox is None and ref_center_com is None and ref_profile is None:
268
+ # Nothing visible even in reference head region; do nothing.
269
+ out_mask = None
270
+ if mask is not None:
271
+ out_mask = mask if mask.dim() == 3 else mask.unsqueeze(0)
272
+ elif C >= 4:
273
+ a = images[..., 3]
274
+ out_mask = (1.0 - a) if mask_is_inverted else a
275
+ else:
276
+ fill_val = 1.0 if mask_is_inverted else 0.0
277
+ out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype)
278
+
279
+ return (images, out_mask, "[0]" if B == 1 else str([0] * B))
280
+
281
+ # For center-based methods, pick a reference center
282
+ if ref_center_com is not None:
283
+ ref_center = ref_center_com
284
+ elif ref_center_bbox is not None:
285
+ ref_center = ref_center_bbox
286
+ else:
287
+ ref_center = W / 2.0
288
+
289
+ shifts = [0] * B
290
+ shifts[0] = 0 # frame 0 stays
291
+
292
+ for i in range(1, B):
293
+ a_hw = roi_alpha[i]
294
+ shift_i = 0
295
+
296
+ if method == "bbox_center":
297
+ c = self._bbox_center_x(a_hw, thr)
298
+ shift_i = 0 if c is None else int(round(ref_center - c))
299
+
300
+ elif method == "alpha_com":
301
+ c = self._com_center_x(a_hw, thr)
302
+ shift_i = 0 if c is None else int(round(ref_center - c))
303
+
304
+ elif method == "profile_corr":
305
+ s = self._phase_corr_shift_x(a_hw, ref_profile, thr) # already int shift to APPLY
306
+ shift_i = 0 if s is None else int(s)
307
+
308
+ elif method == "hybrid":
309
+ s_corr = self._phase_corr_shift_x(a_hw, ref_profile, thr) if ref_profile is not None else None
310
+ c = self._com_center_x(a_hw, thr)
311
+ s_com = None if c is None else int(round(ref_center - c))
312
+
313
+ if s_corr is None and s_com is None:
314
+ shift_i = 0
315
+ elif s_corr is None:
316
+ shift_i = int(s_com)
317
+ elif s_com is None:
318
+ shift_i = int(s_corr)
319
+ else:
320
+ if abs(int(s_corr) - int(s_com)) > int(hybrid_tolerance_px):
321
+ shift_i = int(s_com)
322
+ else:
323
+ shift_i = int(s_corr)
324
+
325
+ else:
326
+ raise ValueError(f"Unknown method: {method}")
327
+
328
+ # Clamp extreme shifts if requested
329
+ if max_abs_shift and max_abs_shift > 0:
330
+ shift_i = int(max(-max_abs_shift, min(max_abs_shift, shift_i)))
331
+
332
+ shifts[i] = shift_i
333
+
334
+ # Optional temporal median smoothing (keeps shifts[0] anchored)
335
+ shifts = self._median_smooth(shifts, int(temporal_median))
336
+
337
+ # Apply per-frame shifts
338
+ out_images = torch.zeros_like(images)
339
+
340
+ # Output mask handling:
341
+ # - If input mask provided: shift it
342
+ # - Else if RGBA: derive from shifted alpha
343
+ # - Else: produce blank
344
+ out_mask = None
345
+ in_mask_bhw = None
346
+ if mask is not None:
347
+ in_mask_bhw = mask
348
+ if in_mask_bhw.dim() == 2:
349
+ in_mask_bhw = in_mask_bhw.unsqueeze(0)
350
+ if in_mask_bhw.shape[0] == 1 and B > 1:
351
+ in_mask_bhw = in_mask_bhw.repeat(B, 1, 1)
352
+
353
+ fill_val = 1.0 if mask_is_inverted else 0.0
354
+ out_mask = torch.full_like(in_mask_bhw, fill_val)
355
+
356
+ for i in range(B):
357
+ sx = int(shifts[i])
358
+ out_images[i] = self._shift_frame_x(images[i], sx)
359
+
360
+ if out_mask is not None and in_mask_bhw is not None:
361
+ fill_val = 1.0 if mask_is_inverted else 0.0
362
+ out_mask[i] = self._shift_mask_x(in_mask_bhw[i], sx, fill_val)
363
+
364
+ if out_mask is None:
365
+ if out_images.shape[-1] >= 4:
366
+ a = out_images[..., 3]
367
+ out_mask = (1.0 - a) if mask_is_inverted else a
368
+ else:
369
+ fill_val = 1.0 if mask_is_inverted else 0.0
370
+ out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype)
371
+
372
+ shifts_str = str(shifts)
373
+ return (out_images, out_mask, shifts_str)
374
+
375
+
376
# Node registration tables, keyed by the node identifier.
# NOTE(review): presumably read by ComfyUI's custom-node loader — confirm against the host application.
NODE_CLASS_MAPPINGS = {"Salia_Sprite_Head_Stabilizer": Salia_Sprite_Head_Stabilizer}

# Display label for the same node identifier.
NODE_DISPLAY_NAME_MAPPINGS = {"Salia_Sprite_Head_Stabilizer": "Salia_Sprite_Head_Stabilizer"}