import math import torch class SpriteHeadStabilizeX: """ Stabilize sprite animation wiggle (X only) using a Y-band (e.g. head region). Align frames 1..N to frame 0 by estimating horizontal shift from alpha visibility inside the selected Y-range. Methods: - bbox_center: leftmost/rightmost visible pixel columns -> center - alpha_com: alpha-weighted center-of-mass (recommended) - profile_corr: phase correlation on horizontal alpha profile (very robust) - hybrid: profile_corr with a sanity check fallback to alpha_com Inputs support: - True RGBA IMAGE tensor (C>=4) => alpha taken from channel 4 - Or IMAGE (RGB) + MASK (ComfyUI LoadImage mask) => alpha derived from mask """ @classmethod def INPUT_TYPES(cls): return { "required": { "images": ("IMAGE", {}), # Head band "y_min": ("INT", {"default": 210, "min": -99999, "max": 99999, "step": 1}), "y_max": ("INT", {"default": 332, "min": -99999, "max": 99999, "step": 1}), # Alpha tolerance: visible if alpha > threshold_8bit / 255 "alpha_threshold_8bit": ("INT", {"default": 5, "min": 0, "max": 255, "step": 1}), "method": (["bbox_center", "alpha_com", "profile_corr", "hybrid"], {"default": "alpha_com"}), # ComfyUI LoadImage produces MASK from alpha and inverts it. # If your mask is already alpha (0=transparent,1=opaque), set False. "mask_is_inverted": ("BOOLEAN", {"default": True}), # Optional safety clamps/smoothing "max_abs_shift": ("INT", {"default": 0, "min": 0, "max": 99999, "step": 1}), "temporal_median": ("INT", {"default": 1, "min": 1, "max": 99, "step": 1}), # Hybrid sanity check: if corr shift differs from COM shift by more than this, # use COM shift instead. "hybrid_tolerance_px": ("INT", {"default": 8, "min": 0, "max": 99999, "step": 1}), }, "optional": { "mask": ("MASK", {}), } } RETURN_TYPES = ("IMAGE", "MASK", "STRING") RETURN_NAMES = ("images", "mask", "shifts_x") FUNCTION = "stabilize" CATEGORY = "image/sprite" SEARCH_ALIASES = ["wiggle stabilize", "sprite stabilize", "head stabilize", "animation stabilize", "sprite jitter fix"] # ---------- helpers ---------- def _get_alpha(self, images: torch.Tensor, mask: torch.Tensor | None, mask_is_inverted: bool) -> torch.Tensor: """ Returns alpha in [0..1], shape [B,H,W]. """ if images.dim() != 4: raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}") B, H, W, C = images.shape if C >= 4: return images[..., 3] if mask is None: raise ValueError("Need RGBA images (C>=4) OR provide a MASK input.") if mask.dim() == 2: mask = mask.unsqueeze(0) if mask.dim() != 3: raise ValueError(f"mask must have shape [B,H,W] or [H,W], got {tuple(mask.shape)}") if mask.shape[1] != H or mask.shape[2] != W: raise ValueError(f"mask H/W must match images; mask={tuple(mask.shape)} images={tuple(images.shape)}") if mask.shape[0] == 1 and B > 1: mask = mask.repeat(B, 1, 1) if mask.shape[0] != B: raise ValueError(f"mask batch must match images batch; mask B={mask.shape[0]} images B={B}") alpha = 1.0 - mask if mask_is_inverted else mask return alpha def _clamp_y(self, H: int, y_min: int, y_max: int) -> tuple[int, int]: y0 = int(y_min) y1 = int(y_max) if y1 < y0: y0, y1 = y1, y0 y0 = max(0, min(H - 1, y0)) y1 = max(0, min(H - 1, y1)) return y0, y1 def _bbox_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None: """ alpha_hw: [H,W] Returns center X using leftmost/rightmost visible columns, or None if empty. """ # visible: [H,W] visible = alpha_hw > thr cols = visible.any(dim=0) # [W] if not bool(cols.any()): return None W = alpha_hw.shape[1] left = int(torch.argmax(cols.float()).item()) right = int((W - 1) - torch.argmax(torch.flip(cols, dims=[0]).float()).item()) return (left + right) / 2.0 def _com_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None: """ alpha_hw: [H,W] Alpha-weighted center-of-mass of X within visible area, or None if empty. """ W = alpha_hw.shape[1] weights = alpha_hw if thr > 0: weights = weights * (weights > thr) profile = weights.sum(dim=0) # [W] total = float(profile.sum().item()) if total <= 0.0: return None x = torch.arange(W, device=alpha_hw.device, dtype=profile.dtype) center = float((profile * x).sum().item() / total) return center def _phase_corr_shift_x(self, alpha_hw: torch.Tensor, ref_profile: torch.Tensor, thr: float) -> int | None: """ Estimate integer shift to APPLY to current frame (X) so it matches reference. Uses 1D phase correlation on horizontal alpha profile. Returns shift_x (int), or None if empty. """ weights = alpha_hw if thr > 0: weights = weights * (weights > thr) prof = weights.sum(dim=0).float() if float(prof.sum().item()) <= 0.0: return None # Remove DC component prof = prof - prof.mean() ref = ref_profile # Phase correlation F = torch.fft.rfft(prof) R = torch.fft.rfft(ref) cps = F * torch.conj(R) cps = cps / (torch.abs(cps) + 1e-9) corr = torch.fft.irfft(cps, n=prof.numel()) peak = int(torch.argmax(corr).item()) W = prof.numel() lag = peak if peak <= W // 2 else peak - W # lag = "current is shifted by lag relative to ref" shift_x = -lag # apply negative to align to ref return int(shift_x) def _shift_frame_x(self, img_hwc: torch.Tensor, shift_x: int) -> torch.Tensor: """ img_hwc: [H,W,C] shift_x: int (positive -> move right) """ H, W, C = img_hwc.shape out = torch.zeros_like(img_hwc) if shift_x == 0: return img_hwc if abs(shift_x) >= W: return out if shift_x > 0: out[:, shift_x:, :] = img_hwc[:, : W - shift_x, :] else: sx = -shift_x out[:, : W - sx, :] = img_hwc[:, sx:, :] return out def _shift_mask_x(self, m_hw: torch.Tensor, shift_x: int, fill_val: float) -> torch.Tensor: """ m_hw: [H,W] """ H, W = m_hw.shape out = torch.full_like(m_hw, fill_val) if shift_x == 0: return m_hw if abs(shift_x) >= W: return out if shift_x > 0: out[:, shift_x:] = m_hw[:, : W - shift_x] else: sx = -shift_x out[:, : W - sx] = m_hw[:, sx:] return out def _median_smooth(self, shifts: list[int], window: int) -> list[int]: """ Median filter over shifts with odd window size. Keeps shifts[0] unchanged. """ if window <= 1 or len(shifts) <= 2: return shifts w = int(window) if w % 2 == 0: w += 1 r = w // 2 out = shifts[:] out[0] = shifts[0] n = len(shifts) for i in range(1, n): lo = max(1, i - r) hi = min(n, i + r + 1) vals = sorted(shifts[lo:hi]) out[i] = vals[len(vals) // 2] return out # ---------- main ---------- def stabilize( self, images: torch.Tensor, y_min: int = 210, y_max: int = 332, alpha_threshold_8bit: int = 5, method: str = "alpha_com", mask_is_inverted: bool = True, max_abs_shift: int = 0, temporal_median: int = 1, hybrid_tolerance_px: int = 8, mask: torch.Tensor | None = None, ): if not torch.is_tensor(images): raise TypeError("images must be a torch.Tensor") if images.dim() != 4: raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}") B, H, W, C = images.shape if B < 1: raise ValueError("images batch is empty") alpha = self._get_alpha(images, mask, mask_is_inverted) # [B,H,W] y0, y1 = self._clamp_y(H, y_min, y_max) thr = float(alpha_threshold_8bit) / 255.0 roi_alpha = alpha[:, y0:y1 + 1, :] # [B, Hr, W] # Reference (frame 0) ref_roi = roi_alpha[0] # [Hr,W] # Prepare reference for methods ref_center_bbox = None ref_center_com = None ref_profile = None if method in ("bbox_center", "hybrid"): ref_center_bbox = self._bbox_center_x(ref_roi, thr) if method in ("alpha_com", "hybrid"): ref_center_com = self._com_center_x(ref_roi, thr) if method in ("profile_corr", "hybrid"): # reference profile for phase correlation w = ref_roi if thr > 0: w = w * (w > thr) ref_profile = w.sum(dim=0).float() ref_profile = ref_profile - ref_profile.mean() # Fallback reference center if missing if ref_center_bbox is None and ref_center_com is None and ref_profile is None: # Nothing visible even in reference head region; do nothing. out_mask = None if mask is not None: out_mask = mask if mask.dim() == 3 else mask.unsqueeze(0) elif C >= 4: a = images[..., 3] out_mask = (1.0 - a) if mask_is_inverted else a else: fill_val = 1.0 if mask_is_inverted else 0.0 out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype) return (images, out_mask, "[0]" if B == 1 else str([0] * B)) # For center-based methods, pick a reference center # Preference: COM, else BBOX, else image center if ref_center_com is not None: ref_center = ref_center_com elif ref_center_bbox is not None: ref_center = ref_center_bbox else: ref_center = W / 2.0 shifts = [0] * B shifts[0] = 0 # frame 0 stays for i in range(1, B): a_hw = roi_alpha[i] shift_i = 0 if method == "bbox_center": c = self._bbox_center_x(a_hw, thr) if c is None: shift_i = 0 else: shift_i = int(round(ref_center - c)) elif method == "alpha_com": c = self._com_center_x(a_hw, thr) if c is None: shift_i = 0 else: shift_i = int(round(ref_center - c)) elif method == "profile_corr": s = self._phase_corr_shift_x(a_hw, ref_profile, thr) # already int shift to APPLY shift_i = 0 if s is None else int(s) elif method == "hybrid": # corr shift s_corr = self._phase_corr_shift_x(a_hw, ref_profile, thr) if ref_profile is not None else None # com shift c = self._com_center_x(a_hw, thr) s_com = None if c is None else int(round(ref_center - c)) if s_corr is None and s_com is None: shift_i = 0 elif s_corr is None: shift_i = int(s_com) elif s_com is None: shift_i = int(s_corr) else: if abs(int(s_corr) - int(s_com)) > int(hybrid_tolerance_px): shift_i = int(s_com) else: shift_i = int(s_corr) else: raise ValueError(f"Unknown method: {method}") # Clamp extreme shifts if requested if max_abs_shift and max_abs_shift > 0: shift_i = int(max(-max_abs_shift, min(max_abs_shift, shift_i))) shifts[i] = shift_i # Optional temporal median smoothing (keeps shifts[0] anchored) shifts = self._median_smooth(shifts, int(temporal_median)) # Apply per-frame shifts out_images = torch.zeros_like(images) # Output mask handling: # - If input mask provided: shift it # - Else if RGBA: derive from shifted alpha # - Else: produce blank out_mask = None in_mask_bhw = None if mask is not None: in_mask_bhw = mask if in_mask_bhw.dim() == 2: in_mask_bhw = in_mask_bhw.unsqueeze(0) if in_mask_bhw.shape[0] == 1 and B > 1: in_mask_bhw = in_mask_bhw.repeat(B, 1, 1) fill_val = 1.0 if mask_is_inverted else 0.0 out_mask = torch.full_like(in_mask_bhw, fill_val) for i in range(B): sx = int(shifts[i]) out_images[i] = self._shift_frame_x(images[i], sx) if out_mask is not None and in_mask_bhw is not None: fill_val = 1.0 if mask_is_inverted else 0.0 out_mask[i] = self._shift_mask_x(in_mask_bhw[i], sx, fill_val) if out_mask is None: if out_images.shape[-1] >= 4: a = out_images[..., 3] out_mask = (1.0 - a) if mask_is_inverted else a else: fill_val = 1.0 if mask_is_inverted else 0.0 out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype) shifts_str = str(shifts) return (out_images, out_mask, shifts_str) NODE_CLASS_MAPPINGS = { "SpriteHeadStabilizeX": SpriteHeadStabilizeX, } NODE_DISPLAY_NAME_MAPPINGS = { "SpriteHeadStabilizeX": "Sprite Head Stabilize X (Batch)", }