|
|
| """
|
| Patch for ComfyUI-WanAnimatePreprocess: temporal smoothing and lower thresholds.
|
|
|
| Applies three fixes to the VitPose pipeline:
|
| 1. Lowers confidence thresholds (drawing 0.6 → 0.3; retarget 0.5 and 0.4 → 0.25) so bones don't vanish on minor dips
|
| 2. Injects a OneEuroFilter-based temporal smoother for body keypoints
|
| 3. Bridges frames where keypoints drop out briefly by holding the last good position
|
|
|
| Run after cloning the repo:
|
| python3 patch_vitpose_smoothing.py /path/to/ComfyUI-WanAnimatePreprocess
|
|
|
| The patch is idempotent — safe to run multiple times.
|
| """
|
|
|
| import argparse
|
| import os
|
| import re
|
| import sys
|
| import textwrap
|
|
|
|
|
def patch_file(filepath, replacements, marker="# [SMOOTHING_PATCH]"):
    """Apply literal text replacements to a file, at most once per file.

    Parameters
    ----------
    filepath : str
        File to patch in place (read and written as UTF-8).
    replacements : iterable of (old, new) pairs
        Each ``old`` is replaced by ``new`` at its first occurrence only.
        Pairs whose ``old`` text is absent are reported and skipped.
    marker : str
        Text whose presence in the file means "already patched". After a
        successful patch it is written as the first line of the file.

    Returns
    -------
    bool
        True if the file was rewritten; False if it was missing, already
        carried the marker, or none of the patterns matched.
    """
    if not os.path.isfile(filepath):
        print(f" SKIP (not found): {filepath}")
        return False

    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    if marker in content:
        print(f" SKIP (already patched): {filepath}")
        return False

    patched = content
    applied = 0
    for old, new in replacements:
        if old not in patched:
            print(f" WARNING: pattern not found in {filepath}:\n {old[:80]}...")
            continue
        patched = patched.replace(old, new, 1)
        applied += 1

    if not applied:
        # Stamping the marker here would make every later run report
        # "already patched" for a file that was never actually changed.
        print(f" SKIP (no patterns matched): {filepath}")
        return False

    # The default marker is already a comment; only prepend "# " when it is
    # not, so the stamped line does not become "# # [...]".
    header = marker if marker.lstrip().startswith("#") else f"# {marker}"
    patched = f"{header}\n" + patched

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(patched)
    print(f" PATCHED: {filepath}")
    return True
|
|
|
|
|
def inject_smoother_module(repo_dir):
    """Create ``pose_utils/temporal_smoother.py`` inside ``repo_dir``.

    The generated module provides ``OneEuroFilter``, ``PoseSequenceSmoother``
    and ``smooth_pose_sequence``. If the destination file already exists it
    is left untouched, so repeated runs never overwrite it.

    Parameters
    ----------
    repo_dir : str
        Root of the repository; the ``pose_utils`` directory is created
        under it when missing.
    """
    dest = os.path.join(repo_dir, "pose_utils", "temporal_smoother.py")
    if os.path.isfile(dest):
        print(f" SKIP (exists): {dest}")
        return

    # The module source is kept indented here for readability; dedent()
    # strips the common leading whitespace before it is written out.
    code = textwrap.dedent("""\
        # [SMOOTHING_PATCH] — Temporal keypoint smoother for VitPose output.
        import math
        import numpy as np


        class OneEuroFilter:
            \"\"\"1-Euro filter for real-time signal smoothing.

            Parameters
            ----------
            min_cutoff : float
                Minimum cutoff frequency (lower = more smoothing, default 1.0).
            beta : float
                Speed coefficient (higher = less lag during fast moves, default 0.007).
            d_cutoff : float
                Derivative cutoff frequency (default 1.0).
            \"\"\"

            def __init__(self, min_cutoff=1.0, beta=0.007, d_cutoff=1.0):
                self.min_cutoff = min_cutoff
                self.beta = beta
                self.d_cutoff = d_cutoff
                self._x_prev = None
                self._dx_prev = None
                self._t_prev = None

            @staticmethod
            def _smoothing_factor(t_e, cutoff):
                r = 2 * math.pi * cutoff * t_e
                return r / (r + 1)

            def __call__(self, x, t=None):
                # First sample: nothing to smooth against, just seed the state.
                if self._t_prev is None:
                    self._t_prev = 0.0
                    self._x_prev = x
                    self._dx_prev = np.zeros_like(x)
                    return x

                # Without an explicit timestamp, samples are one time unit apart.
                if t is None:
                    t = self._t_prev + 1.0
                t_e = t - self._t_prev

                a_d = self._smoothing_factor(t_e, self.d_cutoff)
                dx = (x - self._x_prev) / max(t_e, 1e-8)
                dx_hat = a_d * dx + (1 - a_d) * self._dx_prev

                cutoff = self.min_cutoff + self.beta * np.abs(dx_hat)
                a = self._smoothing_factor(t_e, cutoff)
                x_hat = a * x + (1 - a) * self._x_prev

                self._x_prev = x_hat
                self._dx_prev = dx_hat
                self._t_prev = t
                return x_hat


        class PoseSequenceSmoother:
            \"\"\"Smooth a sequence of per-frame keypoint arrays.

            Handles two issues:
            1. Jittery coordinates → OneEuroFilter per keypoint dimension.
            2. Briefly missing keypoints → the last good position is held
               for at most ``interp_window`` consecutive frames.

            Usage
            -----
            smoother = PoseSequenceSmoother(num_keypoints=20, conf_floor=0.3)
            for frame_idx, kp_array in enumerate(all_keypoints):
                smoothed = smoother.feed(kp_array)
            \"\"\"

            def __init__(self, num_keypoints=20, min_cutoff=1.7, beta=0.01,
                         conf_floor=0.3, interp_window=3):
                self.num_kp = num_keypoints
                self.conf_floor = conf_floor
                self.interp_window = interp_window
                # One filter per (keypoint, dimension) pair
                self._filters_x = [OneEuroFilter(min_cutoff, beta) for _ in range(num_keypoints)]
                self._filters_y = [OneEuroFilter(min_cutoff, beta) for _ in range(num_keypoints)]
                self._filters_c = [OneEuroFilter(min_cutoff=2.0, beta=0.003) for _ in range(num_keypoints)]
                self._history = []  # list of np.ndarray (K, 3) of past outputs
                # Consecutive frames each keypoint has been below conf_floor.
                self._missing_run = [0] * num_keypoints

            def feed(self, kp):
                \"\"\"Feed one frame of keypoints, return smoothed copy.

                Parameters
                ----------
                kp : np.ndarray of shape (K, 3) — [x, y, confidence] per keypoint.
                    May have rows of [0, 0, 0] for missing keypoints.

                Returns
                -------
                np.ndarray of same shape with smoothed values.
                \"\"\"
                kp = np.array(kp, dtype=np.float64)
                out = kp.copy()

                for i in range(min(self.num_kp, len(kp))):
                    x, y, c = kp[i]

                    if c < self.conf_floor:
                        self._missing_run[i] += 1
                        # Only bridge short gaps. Held rows are stored in the
                        # history with confidence >= conf_floor, so without
                        # this cap a vanished keypoint would be held forever.
                        interp = None
                        if self._missing_run[i] <= self.interp_window:
                            interp = self._interpolate(i)
                        if interp is not None:
                            out[i, 0] = self._filters_x[i](interp[0])
                            out[i, 1] = self._filters_y[i](interp[1])
                            out[i, 2] = max(self.conf_floor, interp[2] * 0.8)
                        else:
                            out[i] = kp[i]
                        continue

                    self._missing_run[i] = 0
                    out[i, 0] = self._filters_x[i](x)
                    out[i, 1] = self._filters_y[i](y)
                    out[i, 2] = self._filters_c[i](c)

                self._history.append(out.copy())
                if len(self._history) > self.interp_window * 2 + 1:
                    self._history.pop(0)

                return out

            def _interpolate(self, kp_idx):
                \"\"\"Return the most recent usable row for a missing keypoint.

                Searches the last ``interp_window`` stored frames, newest
                first, and returns the first row whose confidence is at
                least ``conf_floor`` (hold-last), or None if there is none.
                \"\"\"
                for h in reversed(self._history[-self.interp_window:]):
                    if kp_idx < len(h) and h[kp_idx, 2] >= self.conf_floor:
                        return h[kp_idx]
                return None


        def smooth_pose_sequence(all_kp_arrays, num_keypoints=20, **kwargs):
            \"\"\"Batch-smooth a list of per-frame keypoint arrays.

            Parameters
            ----------
            all_kp_arrays : list of np.ndarray (K, 3)
            num_keypoints : int

            Returns
            -------
            list of np.ndarray (K, 3)
            \"\"\"
            smoother = PoseSequenceSmoother(num_keypoints=num_keypoints, **kwargs)
            return [smoother.feed(kp) for kp in all_kp_arrays]
        """)

    os.makedirs(os.path.dirname(dest), exist_ok=True)
    with open(dest, "w", encoding="utf-8") as f:
        f.write(code)
    print(f" CREATED: {dest}")
|
|
|
|
|
def patch_human_visualization(repo_dir):
    """Lower default drawing thresholds from 0.6 to 0.3.

    Targets ``pose_utils/human_visualization.py`` under ``repo_dir``. Each
    pair below is the literal signature text to find and its replacement;
    matching is exact, so a signature formatted differently in the target
    file is reported by ``patch_file`` as "pattern not found" and skipped.

    Parameters
    ----------
    repo_dir : str
        Root of the repository being patched.
    """
    filepath = os.path.join(repo_dir, "pose_utils", "human_visualization.py")
    replacements = [
        (
            "def draw_aapose(\n img,\n kp2ds,\n threshold=0.6,",
            "def draw_aapose(\n img,\n kp2ds,\n threshold=0.3,",
        ),
        (
            "def draw_aapose_new(\n img,\n kp2ds,\n threshold=0.6,",
            "def draw_aapose_new(\n img,\n kp2ds,\n threshold=0.3,",
        ),
        (
            "def draw_handpose(canvas, keypoints, hand_score_th=0.6):",
            "def draw_handpose(canvas, keypoints, hand_score_th=0.3):",
        ),
        (
            "def draw_handpose_new(canvas, keypoints, stickwidth_type='v2',\n"
            "hand_score_th=0.6, hand_stick_width=4):",
            "def draw_handpose_new(canvas, keypoints, stickwidth_type='v2',\n"
            "hand_score_th=0.3, hand_stick_width=4):",
        ),
        (
            "def draw_traj(metas: List[AAPoseMeta], threshold=0.6",
            "def draw_traj(metas: List[AAPoseMeta], threshold=0.3",
        ),
    ]

    patch_file(filepath, replacements)
|
|
|
|
|
def patch_retarget_pose(repo_dir):
    """Relax the thresholds hard-coded in ``retarget_pose.py``.

    Two literal edits are handed to ``patch_file``: the ``hand_score_th``
    assignment goes from 0.5 to 0.25, and the default ``threshold`` of
    ``check_full_body`` goes from 0.4 to 0.25.

    Parameters
    ----------
    repo_dir : str
        Root of the repository being patched.
    """
    hand_score_edit = ("hand_score_th = 0.5", "hand_score_th = 0.25")
    full_body_edit = (
        "def check_full_body(keypoints, threshold = 0.4):",
        "def check_full_body(keypoints, threshold = 0.25):",
    )
    target = os.path.join(repo_dir, "retarget_pose.py")
    patch_file(target, [hand_score_edit, full_body_edit])
|
|
|
|
|
def patch_nodes_add_smoothing(repo_dir):
    """Add a guarded import of ``PoseSequenceSmoother`` to ``nodes.py``.

    The snippet is inserted directly after the last top-level import
    statement and defines ``_HAS_SMOOTHER`` according to whether the import
    succeeded. The file is left untouched when it is missing, already
    contains the patch marker, cannot be parsed, or has no top-level import.

    Parameters
    ----------
    repo_dir : str
        Root of the repository being patched.
    """
    import ast  # local: only this function needs it

    filepath = os.path.join(repo_dir, "nodes.py")
    if not os.path.isfile(filepath):
        print(f" SKIP (not found): {filepath}")
        return

    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    marker = "# [SMOOTHING_PATCH]"
    if marker in content:
        print(f" SKIP (already patched): {filepath}")
        return

    import_line = (
        f"\n{marker}\n"
        "try:\n"
        " from .pose_utils.temporal_smoother import PoseSequenceSmoother\n"
        " _HAS_SMOOTHER = True\n"
        "except ImportError:\n"
        " _HAS_SMOOTHER = False\n"
    )

    # Locate the import via the syntax tree rather than a line regex: a
    # parenthesised import spans several lines, and inserting after its
    # first line would leave the file syntactically broken.
    try:
        tree = ast.parse(content)
    except SyntaxError as exc:
        print(f" WARNING: could not parse {filepath} ({exc.msg}); import not added")
        return

    last_import_end = 0
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            # end_lineno is absent before Python 3.8; fall back to the
            # statement's first line there.
            last_import_end = getattr(node, "end_lineno", None) or node.lineno

    if not last_import_end:
        print(f" WARNING: no top-level import found in {filepath}; import not added")
        return

    # Text-mode reading normalises line endings to "\n", so splitting on it
    # lines up with the line numbers reported by ast.
    lines = content.split("\n")
    head = "\n".join(lines[:last_import_end]) + "\n"
    tail = "\n".join(lines[last_import_end:])
    content = head + import_line + tail

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)
    print(f" PATCHED (import added): {filepath}")
|
|
|
|
|
def main():
    """Command-line entry point: run every patch step against one repo.

    Takes a single positional argument, the repository directory. Exits
    with status 1 when that directory does not exist; otherwise runs the
    four patch steps in order and prints a closing note.
    """
    cli = argparse.ArgumentParser(description="Patch VitPose smoothing into WanAnimatePreprocess")
    cli.add_argument("repo_dir", help="Path to ComfyUI-WanAnimatePreprocess directory")
    repo_dir = os.path.abspath(cli.parse_args().repo_dir)

    if not os.path.isdir(repo_dir):
        print(f"ERROR: directory not found: {repo_dir}")
        sys.exit(1)

    print(f"Patching VitPose smoothing in: {repo_dir}")
    print()

    # The smoother module is created first, then the existing files are edited.
    steps = (
        inject_smoother_module,
        patch_human_visualization,
        patch_retarget_pose,
        patch_nodes_add_smoothing,
    )
    for step in steps:
        step(repo_dir)

    closing_note = (
        "",
        "Done. Thresholds lowered (0.6→0.3), temporal smoother module installed.",
        "",
        "NOTE: The temporal_smoother module is ready but needs to be called",
        "from within the pose detection loop. If you want automatic smoothing,",
        "wrap the VitPose output in PoseSequenceSmoother.feed() in nodes.py.",
    )
    for line in closing_note:
        print(line)
|
|
|
|
|
# Run the patcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|