# MyCustomNodes / Batch_Stabilize_Sprite.py
# Author: saliacoel
# Upload Batch_Stabilize_Sprite.py (commit 6c8ffa0, verified)
import math
import torch
class SpriteHeadStabilizeX:
"""
Stabilize sprite animation wiggle (X only) using a Y-band (e.g. head region).
Align frames 1..N to frame 0 by estimating horizontal shift from alpha visibility
inside the selected Y-range.
Methods:
- bbox_center: leftmost/rightmost visible pixel columns -> center
- alpha_com: alpha-weighted center-of-mass (recommended)
- profile_corr: phase correlation on horizontal alpha profile (very robust)
- hybrid: profile_corr with a sanity check fallback to alpha_com
Inputs support:
- True RGBA IMAGE tensor (C>=4) => alpha taken from channel 4
- Or IMAGE (RGB) + MASK (ComfyUI LoadImage mask) => alpha derived from mask
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE", {}),
# Head band
"y_min": ("INT", {"default": 210, "min": -99999, "max": 99999, "step": 1}),
"y_max": ("INT", {"default": 332, "min": -99999, "max": 99999, "step": 1}),
# Alpha tolerance: visible if alpha > threshold_8bit / 255
"alpha_threshold_8bit": ("INT", {"default": 5, "min": 0, "max": 255, "step": 1}),
"method": (["bbox_center", "alpha_com", "profile_corr", "hybrid"], {"default": "alpha_com"}),
# ComfyUI LoadImage produces MASK from alpha and inverts it.
# If your mask is already alpha (0=transparent,1=opaque), set False.
"mask_is_inverted": ("BOOLEAN", {"default": True}),
# Optional safety clamps/smoothing
"max_abs_shift": ("INT", {"default": 0, "min": 0, "max": 99999, "step": 1}),
"temporal_median": ("INT", {"default": 1, "min": 1, "max": 99, "step": 1}),
# Hybrid sanity check: if corr shift differs from COM shift by more than this,
# use COM shift instead.
"hybrid_tolerance_px": ("INT", {"default": 8, "min": 0, "max": 99999, "step": 1}),
},
"optional": {
"mask": ("MASK", {}),
}
}
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
RETURN_NAMES = ("images", "mask", "shifts_x")
FUNCTION = "stabilize"
CATEGORY = "image/sprite"
SEARCH_ALIASES = ["wiggle stabilize", "sprite stabilize", "head stabilize", "animation stabilize", "sprite jitter fix"]
# ---------- helpers ----------
def _get_alpha(self, images: torch.Tensor, mask: torch.Tensor | None, mask_is_inverted: bool) -> torch.Tensor:
"""
Returns alpha in [0..1], shape [B,H,W].
"""
if images.dim() != 4:
raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}")
B, H, W, C = images.shape
if C >= 4:
return images[..., 3]
if mask is None:
raise ValueError("Need RGBA images (C>=4) OR provide a MASK input.")
if mask.dim() == 2:
mask = mask.unsqueeze(0)
if mask.dim() != 3:
raise ValueError(f"mask must have shape [B,H,W] or [H,W], got {tuple(mask.shape)}")
if mask.shape[1] != H or mask.shape[2] != W:
raise ValueError(f"mask H/W must match images; mask={tuple(mask.shape)} images={tuple(images.shape)}")
if mask.shape[0] == 1 and B > 1:
mask = mask.repeat(B, 1, 1)
if mask.shape[0] != B:
raise ValueError(f"mask batch must match images batch; mask B={mask.shape[0]} images B={B}")
alpha = 1.0 - mask if mask_is_inverted else mask
return alpha
def _clamp_y(self, H: int, y_min: int, y_max: int) -> tuple[int, int]:
y0 = int(y_min)
y1 = int(y_max)
if y1 < y0:
y0, y1 = y1, y0
y0 = max(0, min(H - 1, y0))
y1 = max(0, min(H - 1, y1))
return y0, y1
def _bbox_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None:
"""
alpha_hw: [H,W]
Returns center X using leftmost/rightmost visible columns, or None if empty.
"""
# visible: [H,W]
visible = alpha_hw > thr
cols = visible.any(dim=0) # [W]
if not bool(cols.any()):
return None
W = alpha_hw.shape[1]
left = int(torch.argmax(cols.float()).item())
right = int((W - 1) - torch.argmax(torch.flip(cols, dims=[0]).float()).item())
return (left + right) / 2.0
def _com_center_x(self, alpha_hw: torch.Tensor, thr: float) -> float | None:
"""
alpha_hw: [H,W]
Alpha-weighted center-of-mass of X within visible area, or None if empty.
"""
W = alpha_hw.shape[1]
weights = alpha_hw
if thr > 0:
weights = weights * (weights > thr)
profile = weights.sum(dim=0) # [W]
total = float(profile.sum().item())
if total <= 0.0:
return None
x = torch.arange(W, device=alpha_hw.device, dtype=profile.dtype)
center = float((profile * x).sum().item() / total)
return center
def _phase_corr_shift_x(self, alpha_hw: torch.Tensor, ref_profile: torch.Tensor, thr: float) -> int | None:
"""
Estimate integer shift to APPLY to current frame (X) so it matches reference.
Uses 1D phase correlation on horizontal alpha profile.
Returns shift_x (int), or None if empty.
"""
weights = alpha_hw
if thr > 0:
weights = weights * (weights > thr)
prof = weights.sum(dim=0).float()
if float(prof.sum().item()) <= 0.0:
return None
# Remove DC component
prof = prof - prof.mean()
ref = ref_profile
# Phase correlation
F = torch.fft.rfft(prof)
R = torch.fft.rfft(ref)
cps = F * torch.conj(R)
cps = cps / (torch.abs(cps) + 1e-9)
corr = torch.fft.irfft(cps, n=prof.numel())
peak = int(torch.argmax(corr).item())
W = prof.numel()
lag = peak if peak <= W // 2 else peak - W # lag = "current is shifted by lag relative to ref"
shift_x = -lag # apply negative to align to ref
return int(shift_x)
def _shift_frame_x(self, img_hwc: torch.Tensor, shift_x: int) -> torch.Tensor:
"""
img_hwc: [H,W,C]
shift_x: int (positive -> move right)
"""
H, W, C = img_hwc.shape
out = torch.zeros_like(img_hwc)
if shift_x == 0:
return img_hwc
if abs(shift_x) >= W:
return out
if shift_x > 0:
out[:, shift_x:, :] = img_hwc[:, : W - shift_x, :]
else:
sx = -shift_x
out[:, : W - sx, :] = img_hwc[:, sx:, :]
return out
def _shift_mask_x(self, m_hw: torch.Tensor, shift_x: int, fill_val: float) -> torch.Tensor:
"""
m_hw: [H,W]
"""
H, W = m_hw.shape
out = torch.full_like(m_hw, fill_val)
if shift_x == 0:
return m_hw
if abs(shift_x) >= W:
return out
if shift_x > 0:
out[:, shift_x:] = m_hw[:, : W - shift_x]
else:
sx = -shift_x
out[:, : W - sx] = m_hw[:, sx:]
return out
def _median_smooth(self, shifts: list[int], window: int) -> list[int]:
"""
Median filter over shifts with odd window size. Keeps shifts[0] unchanged.
"""
if window <= 1 or len(shifts) <= 2:
return shifts
w = int(window)
if w % 2 == 0:
w += 1
r = w // 2
out = shifts[:]
out[0] = shifts[0]
n = len(shifts)
for i in range(1, n):
lo = max(1, i - r)
hi = min(n, i + r + 1)
vals = sorted(shifts[lo:hi])
out[i] = vals[len(vals) // 2]
return out
# ---------- main ----------
def stabilize(
self,
images: torch.Tensor,
y_min: int = 210,
y_max: int = 332,
alpha_threshold_8bit: int = 5,
method: str = "alpha_com",
mask_is_inverted: bool = True,
max_abs_shift: int = 0,
temporal_median: int = 1,
hybrid_tolerance_px: int = 8,
mask: torch.Tensor | None = None,
):
if not torch.is_tensor(images):
raise TypeError("images must be a torch.Tensor")
if images.dim() != 4:
raise ValueError(f"images must have shape [B,H,W,C], got {tuple(images.shape)}")
B, H, W, C = images.shape
if B < 1:
raise ValueError("images batch is empty")
alpha = self._get_alpha(images, mask, mask_is_inverted) # [B,H,W]
y0, y1 = self._clamp_y(H, y_min, y_max)
thr = float(alpha_threshold_8bit) / 255.0
roi_alpha = alpha[:, y0:y1 + 1, :] # [B, Hr, W]
# Reference (frame 0)
ref_roi = roi_alpha[0] # [Hr,W]
# Prepare reference for methods
ref_center_bbox = None
ref_center_com = None
ref_profile = None
if method in ("bbox_center", "hybrid"):
ref_center_bbox = self._bbox_center_x(ref_roi, thr)
if method in ("alpha_com", "hybrid"):
ref_center_com = self._com_center_x(ref_roi, thr)
if method in ("profile_corr", "hybrid"):
# reference profile for phase correlation
w = ref_roi
if thr > 0:
w = w * (w > thr)
ref_profile = w.sum(dim=0).float()
ref_profile = ref_profile - ref_profile.mean()
# Fallback reference center if missing
if ref_center_bbox is None and ref_center_com is None and ref_profile is None:
# Nothing visible even in reference head region; do nothing.
out_mask = None
if mask is not None:
out_mask = mask if mask.dim() == 3 else mask.unsqueeze(0)
elif C >= 4:
a = images[..., 3]
out_mask = (1.0 - a) if mask_is_inverted else a
else:
fill_val = 1.0 if mask_is_inverted else 0.0
out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype)
return (images, out_mask, "[0]" if B == 1 else str([0] * B))
# For center-based methods, pick a reference center
# Preference: COM, else BBOX, else image center
if ref_center_com is not None:
ref_center = ref_center_com
elif ref_center_bbox is not None:
ref_center = ref_center_bbox
else:
ref_center = W / 2.0
shifts = [0] * B
shifts[0] = 0 # frame 0 stays
for i in range(1, B):
a_hw = roi_alpha[i]
shift_i = 0
if method == "bbox_center":
c = self._bbox_center_x(a_hw, thr)
if c is None:
shift_i = 0
else:
shift_i = int(round(ref_center - c))
elif method == "alpha_com":
c = self._com_center_x(a_hw, thr)
if c is None:
shift_i = 0
else:
shift_i = int(round(ref_center - c))
elif method == "profile_corr":
s = self._phase_corr_shift_x(a_hw, ref_profile, thr) # already int shift to APPLY
shift_i = 0 if s is None else int(s)
elif method == "hybrid":
# corr shift
s_corr = self._phase_corr_shift_x(a_hw, ref_profile, thr) if ref_profile is not None else None
# com shift
c = self._com_center_x(a_hw, thr)
s_com = None if c is None else int(round(ref_center - c))
if s_corr is None and s_com is None:
shift_i = 0
elif s_corr is None:
shift_i = int(s_com)
elif s_com is None:
shift_i = int(s_corr)
else:
if abs(int(s_corr) - int(s_com)) > int(hybrid_tolerance_px):
shift_i = int(s_com)
else:
shift_i = int(s_corr)
else:
raise ValueError(f"Unknown method: {method}")
# Clamp extreme shifts if requested
if max_abs_shift and max_abs_shift > 0:
shift_i = int(max(-max_abs_shift, min(max_abs_shift, shift_i)))
shifts[i] = shift_i
# Optional temporal median smoothing (keeps shifts[0] anchored)
shifts = self._median_smooth(shifts, int(temporal_median))
# Apply per-frame shifts
out_images = torch.zeros_like(images)
# Output mask handling:
# - If input mask provided: shift it
# - Else if RGBA: derive from shifted alpha
# - Else: produce blank
out_mask = None
in_mask_bhw = None
if mask is not None:
in_mask_bhw = mask
if in_mask_bhw.dim() == 2:
in_mask_bhw = in_mask_bhw.unsqueeze(0)
if in_mask_bhw.shape[0] == 1 and B > 1:
in_mask_bhw = in_mask_bhw.repeat(B, 1, 1)
fill_val = 1.0 if mask_is_inverted else 0.0
out_mask = torch.full_like(in_mask_bhw, fill_val)
for i in range(B):
sx = int(shifts[i])
out_images[i] = self._shift_frame_x(images[i], sx)
if out_mask is not None and in_mask_bhw is not None:
fill_val = 1.0 if mask_is_inverted else 0.0
out_mask[i] = self._shift_mask_x(in_mask_bhw[i], sx, fill_val)
if out_mask is None:
if out_images.shape[-1] >= 4:
a = out_images[..., 3]
out_mask = (1.0 - a) if mask_is_inverted else a
else:
fill_val = 1.0 if mask_is_inverted else 0.0
out_mask = torch.full((B, H, W), fill_val, device=images.device, dtype=images.dtype)
shifts_str = str(shifts)
return (out_images, out_mask, shifts_str)
# ComfyUI registration: node id -> implementing class.
NODE_CLASS_MAPPINGS = dict(
    SpriteHeadStabilizeX=SpriteHeadStabilizeX,
)

# ComfyUI registration: node id -> human-readable label shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    SpriteHeadStabilizeX="Sprite Head Stabilize X (Batch)",
)