#!/usr/bin/env python3
"""
Patch for ComfyUI-WanAnimatePreprocess: temporal smoothing and lower thresholds.

Applies three fixes to the VitPose pipeline:
  1. Lowers confidence thresholds (0.6 → 0.3) so bones don't vanish on minor dips
  2. Injects a OneEuroFilter-based temporal smoother for body keypoints
  3. Holds the last good position for frames where keypoints drop out briefly

Run after cloning the repo:
    python3 patch_vitpose_smoothing.py /path/to/ComfyUI-WanAnimatePreprocess

The patch is idempotent — safe to run multiple times.
"""

import argparse
import ast
import os
import re
import sys
import textwrap


def patch_file(filepath, replacements, marker="# [SMOOTHING_PATCH]"):
    """Apply text replacements to a file and stamp it with *marker*.

    Parameters
    ----------
    filepath : str
        File to rewrite in place (read and written as UTF-8).
    replacements : iterable of (old, new)
        ``old`` is either a literal string or a compiled regular expression;
        only its first occurrence is replaced.  For a compiled pattern,
        ``new`` follows ``re.sub`` replacement syntax.
    marker : str
        Comment line written at the top of a patched file.  A file that
        already contains it is left untouched, which makes re-runs no-ops.

    Returns
    -------
    bool
        True if the file was rewritten; False if it is missing, already
        patched, or none of the patterns matched.
    """
    if not os.path.isfile(filepath):
        print(f" SKIP (not found): {filepath}")
        return False

    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    if marker in content:
        print(f" SKIP (already patched): {filepath}")
        return False

    patched = content
    applied = 0
    for old, new in replacements:
        if hasattr(old, "subn"):
            # Compiled regular expression: replace the first match only.
            patched, hits = old.subn(new, patched, count=1)
            shown = old.pattern
        else:
            hits = 1 if old in patched else 0
            if hits:
                patched = patched.replace(old, new, 1)
            shown = old
        if not hits:
            print(f" WARNING: pattern not found in {filepath}:\n {shown[:80]}...")
            continue
        applied += 1

    if not applied:
        # Do not stamp an unchanged file: the marker would make every later
        # run skip it even though nothing was actually patched.
        print(f" SKIP (no patterns matched): {filepath}")
        return False

    # The marker is already a complete comment line; do not prefix another "# ".
    patched = f"{marker}\n" + patched
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(patched)
    print(f" PATCHED: {filepath}")
    return True


def inject_smoother_module(repo_dir):
    """Create pose_utils/temporal_smoother.py with OneEuroFilter + gap filling.

    The module is written only if it does not exist yet; an existing file is
    never overwritten.  The ``pose_utils`` directory is created if missing.
    """
    dest = os.path.join(repo_dir, "pose_utils", "temporal_smoother.py")
    if os.path.isfile(dest):
        print(f" SKIP (exists): {dest}")
        return

    code = textwrap.dedent("""\
        # [SMOOTHING_PATCH] — Temporal keypoint smoother for VitPose output.
        import math

        import numpy as np


        class OneEuroFilter:
            \"\"\"1-Euro filter for real-time signal smoothing.

            Parameters
            ----------
            min_cutoff : float
                Minimum cutoff frequency (lower = more smoothing, default 1.0).
            beta : float
                Speed coefficient (higher = less lag during fast moves, default 0.007).
            d_cutoff : float
                Derivative cutoff frequency (default 1.0).
            \"\"\"

            def __init__(self, min_cutoff=1.0, beta=0.007, d_cutoff=1.0):
                self.min_cutoff = min_cutoff
                self.beta = beta
                self.d_cutoff = d_cutoff
                self.reset()

            def reset(self):
                \"\"\"Forget all state; the next sample is returned unfiltered.\"\"\"
                self._x_prev = None
                self._dx_prev = None
                self._t_prev = None

            @staticmethod
            def _smoothing_factor(t_e, cutoff):
                r = 2 * math.pi * cutoff * t_e
                return r / (r + 1)

            def __call__(self, x, t=None):
                \"\"\"Filter sample *x* taken at time *t* (default: previous t + 1).\"\"\"
                if self._t_prev is None:
                    # Honour the caller's clock on the first sample so the
                    # second sample sees the true elapsed time.
                    self._t_prev = 0.0 if t is None else t
                    self._x_prev = x
                    self._dx_prev = np.zeros_like(x)
                    return x

                if t is None:
                    t = self._t_prev + 1.0
                # Clamp so repeated or out-of-order timestamps cannot produce
                # a zero division or a negative smoothing factor.
                t_e = max(t - self._t_prev, 1e-8)

                a_d = self._smoothing_factor(t_e, self.d_cutoff)
                dx = (x - self._x_prev) / t_e
                dx_hat = a_d * dx + (1 - a_d) * self._dx_prev

                cutoff = self.min_cutoff + self.beta * np.abs(dx_hat)
                a = self._smoothing_factor(t_e, cutoff)
                x_hat = a * x + (1 - a) * self._x_prev

                self._x_prev = x_hat
                self._dx_prev = dx_hat
                self._t_prev = t
                return x_hat


        class PoseSequenceSmoother:
            \"\"\"Smooth a sequence of per-frame keypoint arrays.

            Handles two issues:
            1. Jittery coordinates → OneEuroFilter per keypoint dimension.
            2. Briefly missing keypoints → the last observed position is held
               for at most ``interp_window`` consecutive frames.

            Usage
            -----
                smoother = PoseSequenceSmoother(num_keypoints=20, conf_floor=0.3)
                for frame_idx, kp_array in enumerate(all_keypoints):
                    smoothed = smoother.feed(kp_array)
            \"\"\"

            def __init__(self, num_keypoints=20, min_cutoff=1.7, beta=0.01,
                         conf_floor=0.3, interp_window=3):
                self.num_kp = num_keypoints
                self.conf_floor = conf_floor
                self.interp_window = interp_window
                # One filter per (keypoint, dimension) pair
                self._filters_x = [OneEuroFilter(min_cutoff, beta) for _ in range(num_keypoints)]
                self._filters_y = [OneEuroFilter(min_cutoff, beta) for _ in range(num_keypoints)]
                self._filters_c = [OneEuroFilter(min_cutoff=2.0, beta=0.003) for _ in range(num_keypoints)]
                # Last smoothed [x, y, c] produced from a real detection, and
                # the number of consecutive frames since then.  Held frames
                # never update these, so a hold cannot extend itself.
                self._last_good = [None] * num_keypoints
                self._missing_run = [0] * num_keypoints

            def feed(self, kp):
                \"\"\"Feed one frame of keypoints, return smoothed copy.

                Parameters
                ----------
                kp : np.ndarray of shape (K, 3) — [x, y, confidence] per keypoint.
                     May have rows of [0, 0, 0] for missing keypoints.

                Returns
                -------
                np.ndarray of same shape with smoothed values.
                \"\"\"
                kp = np.array(kp, dtype=np.float64)
                out = kp.copy()

                for i in range(min(self.num_kp, len(kp))):
                    x, y, c = kp[i]

                    if c >= self.conf_floor:
                        out[i, 0] = self._filters_x[i](x)
                        out[i, 1] = self._filters_y[i](y)
                        out[i, 2] = self._filters_c[i](c)
                        self._last_good[i] = out[i].copy()
                        self._missing_run[i] = 0
                        continue

                    self._missing_run[i] += 1
                    held = self._interpolate(i)
                    if held is None:
                        # Never seen, or missing for longer than the window:
                        # pass the raw row through and drop stale state so a
                        # later re-detection is not dragged from the old spot.
                        self._forget(i)
                        continue

                    out[i, 0] = self._filters_x[i](held[0])
                    out[i, 1] = self._filters_y[i](held[1])
                    out[i, 2] = max(self.conf_floor, held[2] * 0.8)

                return out

            def _interpolate(self, kp_idx):
                \"\"\"Return the last observed value of a keypoint, or None.

                None is returned when the keypoint has never been observed or
                has been missing for more than ``interp_window`` frames.
                \"\"\"
                if self._missing_run[kp_idx] > self.interp_window:
                    return None
                return self._last_good[kp_idx]

            def _forget(self, kp_idx):
                \"\"\"Discard held value and filter state for one keypoint.\"\"\"
                if self._last_good[kp_idx] is None:
                    return
                self._last_good[kp_idx] = None
                self._filters_x[kp_idx].reset()
                self._filters_y[kp_idx].reset()
                self._filters_c[kp_idx].reset()


        def smooth_pose_sequence(all_kp_arrays, num_keypoints=20, **kwargs):
            \"\"\"Batch-smooth a list of per-frame keypoint arrays.

            Parameters
            ----------
            all_kp_arrays : list of np.ndarray (K, 3)
            num_keypoints : int

            Returns
            -------
            list of np.ndarray (K, 3)
            \"\"\"
            smoother = PoseSequenceSmoother(num_keypoints=num_keypoints, **kwargs)
            return [smoother.feed(kp) for kp in all_kp_arrays]
        """)

    os.makedirs(os.path.dirname(dest), exist_ok=True)
    with open(dest, "w", encoding="utf-8") as f:
        f.write(code)
    print(f" CREATED: {dest}")


def patch_human_visualization(repo_dir):
    """Lower default drawing thresholds from 0.6 to 0.3.

    The signatures are matched with whitespace-tolerant regular expressions so
    that both one-line and multi-line ``def`` layouts are patched.
    """
    filepath = os.path.join(repo_dir, "pose_utils", "human_visualization.py")
    lowered = r"\g<1>0.3"
    replacements = [
        # draw_aapose default threshold
        (
            re.compile(r"(def draw_aapose\(\s*img,\s*kp2ds,\s*threshold=)0\.6(?!\d)"),
            lowered,
        ),
        # draw_aapose_new default threshold
        (
            re.compile(r"(def draw_aapose_new\(\s*img,\s*kp2ds,\s*threshold=)0\.6(?!\d)"),
            lowered,
        ),
        # draw_handpose default threshold
        (
            re.compile(r"(def draw_handpose\(canvas,\s*keypoints,\s*hand_score_th=)0\.6(?!\d)"),
            lowered,
        ),
        # draw_handpose_new default threshold
        (
            re.compile(
                r"(def draw_handpose_new\(canvas,\s*keypoints,\s*stickwidth_type='v2',"
                r"\s*hand_score_th=)0\.6(?!\d)"
            ),
            lowered,
        ),
        # draw_traj default threshold
        (
            re.compile(r"(def draw_traj\(metas:\s*List\[AAPoseMeta\],\s*threshold=)0\.6(?!\d)"),
            lowered,
        ),
    ]
    patch_file(filepath, replacements)


def patch_retarget_pose(repo_dir):
    """Lower hand_score_th and skeleton validation thresholds."""
    filepath = os.path.join(repo_dir, "retarget_pose.py")
    replacements = [
        # deal_hand_keypoints threshold
        (
            "hand_score_th = 0.5",
            "hand_score_th = 0.25",
        ),
        # check_full_body threshold
        (
            "def check_full_body(keypoints, threshold = 0.4):",
            "def check_full_body(keypoints, threshold = 0.25):",
        ),
    ]
    patch_file(filepath, replacements)


def _import_insertion_line(source):
    """Return the 1-based line after which a new import block can be placed.

    That is the last line of the last top-level import statement (multi-line
    parenthesised imports are handled as a whole).  Without any import the
    end of the module docstring is used, or 0 (top of file) if there is none.

    Raises SyntaxError if *source* cannot be parsed.
    """
    tree = ast.parse(source)
    line = 0
    if tree.body and ast.get_docstring(tree, clean=False) is not None:
        first = tree.body[0]
        line = getattr(first, "end_lineno", None) or first.lineno
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            line = max(line, getattr(node, "end_lineno", None) or node.lineno)
    return line


def patch_nodes_add_smoothing(repo_dir):
    """Inject the guarded PoseSequenceSmoother import into nodes.py.

    Only the import (with a ``_HAS_SMOOTHER`` flag) is added; calling the
    smoother from the detection loop is left to the user.  The file is
    skipped if it is missing, already carries the marker, or cannot be parsed.
    """
    filepath = os.path.join(repo_dir, "nodes.py")
    if not os.path.isfile(filepath):
        print(f" SKIP (not found): {filepath}")
        return

    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    marker = "# [SMOOTHING_PATCH]"
    if marker in content:
        print(f" SKIP (already patched): {filepath}")
        return

    import_line = (
        f"\n{marker}\n"
        "try:\n"
        "    from .pose_utils.temporal_smoother import PoseSequenceSmoother\n"
        "    _HAS_SMOOTHER = True\n"
        "except ImportError:\n"
        "    _HAS_SMOOTHER = False\n"
    )

    # Locate the insertion point from the syntax tree rather than a line
    # regex, so the block never lands inside a parenthesised import.
    try:
        after_line = _import_insertion_line(content)
    except SyntaxError as exc:
        print(f" SKIP (cannot parse: {exc.msg}): {filepath}")
        return

    # Convert the line count into a character offset.
    insert_pos = 0
    for _ in range(after_line):
        newline_at = content.find("\n", insert_pos)
        if newline_at == -1:
            # Last line has no trailing newline; import_line starts with one.
            insert_pos = len(content)
            break
        insert_pos = newline_at + 1

    content = content[:insert_pos] + import_line + content[insert_pos:]

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)
    print(f" PATCHED (import added): {filepath}")


def main():
    """Command-line entry point: patch the repository given as argument."""
    parser = argparse.ArgumentParser(description="Patch VitPose smoothing into WanAnimatePreprocess")
    parser.add_argument("repo_dir", help="Path to ComfyUI-WanAnimatePreprocess directory")
    args = parser.parse_args()

    repo_dir = os.path.abspath(args.repo_dir)
    if not os.path.isdir(repo_dir):
        print(f"ERROR: directory not found: {repo_dir}", file=sys.stderr)
        sys.exit(1)

    print(f"Patching VitPose smoothing in: {repo_dir}")
    print()

    inject_smoother_module(repo_dir)
    patch_human_visualization(repo_dir)
    patch_retarget_pose(repo_dir)
    patch_nodes_add_smoothing(repo_dir)

    print()
    print("Done. Thresholds lowered (0.6→0.3), temporal smoother module installed.")
    print()
    print("NOTE: The temporal_smoother module is ready but needs to be called")
    print("from within the pose detection loop. If you want automatic smoothing,")
    print("wrap the VitPose output in PoseSequenceSmoother.feed() in nodes.py.")


if __name__ == "__main__":
    main()