dawdadawdawdawdaw / patch_vitpose_smoothing.py
zAnonymousWizard's picture
Upload patch_vitpose_smoothing.py
df5800d verified
#!/usr/bin/env python3
"""
Patch for ComfyUI-WanAnimatePreprocess: temporal smoothing and lower thresholds.
Applies three fixes to the VitPose pipeline:
1. Lowers confidence thresholds (0.6 → 0.3) so bones don't vanish on minor dips
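2. Installs a OneEuroFilter-based temporal smoother module for body keypoints (nodes.py gets a guarded import of it, but it must still be called from the detection loop)
3. Fills brief keypoint dropouts by holding the last known good position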
Run after cloning the repo:
python3 patch_vitpose_smoothing.py /path/to/ComfyUI-WanAnimatePreprocess
The patch is idempotent — safe to run multiple times.
"""
import argparse
import ast
import os
import re
import sys
import textwrap
def patch_file(filepath, replacements, marker="# [SMOOTHING_PATCH]"):
    """Apply literal text replacements to a file and stamp it with a marker.

    Parameters
    ----------
    filepath : str
        File to rewrite in place (read and written as UTF-8).
    replacements : iterable of (old, new) pairs
        Each ``old`` is replaced by ``new`` once (first occurrence only).
        Pairs whose ``old`` text is absent are reported and skipped.
    marker : str
        Text whose presence means "already patched". It is written as the
        first line of the file after a successful patch.

    Returns
    -------
    bool
        True when at least one replacement was applied and the file was
        rewritten; False when the file is missing, already carries the
        marker, or none of the patterns matched.
    """
    if not os.path.isfile(filepath):
        print(f" SKIP (not found): {filepath}")
        return False

    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    if marker in content:
        print(f" SKIP (already patched): {filepath}")
        return False

    patched = content
    applied = 0
    for old, new in replacements:
        if old not in patched:
            print(f" WARNING: pattern not found in {filepath}:\n {old[:80]}...")
            continue
        patched = patched.replace(old, new, 1)
        applied += 1

    if applied == 0:
        # Do not stamp the marker on a file that was not changed: that would
        # make every later run skip it as "already patched".
        print(f" SKIP (no patterns matched): {filepath}")
        return False

    # The default marker is already a comment; only add "# " when it is not,
    # so the header does not come out as "# # [SMOOTHING_PATCH]".
    header = marker if marker.lstrip().startswith("#") else f"# {marker}"
    patched = f"{header}\n{patched}"

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(patched)
    print(f" PATCHED: {filepath}")
    return True
def inject_smoother_module(repo_dir):
    """Create pose_utils/temporal_smoother.py with OneEuroFilter + gap filling.

    The generated module defines ``OneEuroFilter``, ``PoseSequenceSmoother``
    and ``smooth_pose_sequence``. If the destination file already exists it
    is left untouched, so the step is safe to repeat.

    Parameters
    ----------
    repo_dir : str
        Root of the repository; ``pose_utils/`` is created under it if needed.
    """
    dest = os.path.join(repo_dir, "pose_utils", "temporal_smoother.py")
    if os.path.isfile(dest):
        print(f" SKIP (exists): {dest}")
        return

    code = textwrap.dedent("""\
        # [SMOOTHING_PATCH] — Temporal keypoint smoother for VitPose output.
        import math

        import numpy as np


        class OneEuroFilter:
            \"\"\"1-Euro filter for real-time signal smoothing.

            Parameters
            ----------
            min_cutoff : float
                Minimum cutoff frequency (lower = more smoothing, default 1.0).
            beta : float
                Speed coefficient (higher = less lag during fast moves, default 0.007).
            d_cutoff : float
                Derivative cutoff frequency (default 1.0).
            \"\"\"

            def __init__(self, min_cutoff=1.0, beta=0.007, d_cutoff=1.0):
                self.min_cutoff = min_cutoff
                self.beta = beta
                self.d_cutoff = d_cutoff
                self._x_prev = None
                self._dx_prev = None
                self._t_prev = None

            @staticmethod
            def _smoothing_factor(t_e, cutoff):
                r = 2 * math.pi * cutoff * t_e
                return r / (r + 1)

            def __call__(self, x, t=None):
                \"\"\"Filter one sample; ``t`` defaults to one step after the previous sample.\"\"\"
                if self._t_prev is None:
                    # First sample: nothing to smooth against, just seed the state.
                    # Honour an explicit timestamp so the next step size is correct.
                    self._t_prev = 0.0 if t is None else t
                    self._x_prev = x
                    self._dx_prev = np.zeros_like(x)
                    return x
                if t is None:
                    t = self._t_prev + 1.0
                # Clamp so a repeated or out-of-order timestamp cannot yield a
                # zero or negative step in either the derivative or the factors.
                t_e = max(t - self._t_prev, 1e-8)
                a_d = self._smoothing_factor(t_e, self.d_cutoff)
                dx = (x - self._x_prev) / t_e
                dx_hat = a_d * dx + (1 - a_d) * self._dx_prev
                cutoff = self.min_cutoff + self.beta * np.abs(dx_hat)
                a = self._smoothing_factor(t_e, cutoff)
                x_hat = a * x + (1 - a) * self._x_prev
                self._x_prev = x_hat
                self._dx_prev = dx_hat
                self._t_prev = t
                return x_hat


        class PoseSequenceSmoother:
            \"\"\"Smooth a sequence of per-frame keypoint arrays.

            Handles two issues:
            1. Jittery coordinates → OneEuroFilter per keypoint dimension.
            2. Briefly missing keypoints → hold the last good position for at
               most ``interp_window`` consecutive frames.

            Usage
            -----
            smoother = PoseSequenceSmoother(num_keypoints=20, conf_floor=0.3)
            for frame_idx, kp_array in enumerate(all_keypoints):
                smoothed = smoother.feed(kp_array)
            \"\"\"

            def __init__(self, num_keypoints=20, min_cutoff=1.7, beta=0.01,
                         conf_floor=0.3, interp_window=3):
                self.num_kp = num_keypoints
                self.conf_floor = conf_floor
                self.interp_window = interp_window
                # One filter per (keypoint, dimension) pair
                self._filters_x = [OneEuroFilter(min_cutoff, beta) for _ in range(num_keypoints)]
                self._filters_y = [OneEuroFilter(min_cutoff, beta) for _ in range(num_keypoints)]
                self._filters_c = [OneEuroFilter(min_cutoff=2.0, beta=0.003) for _ in range(num_keypoints)]
                self._history = []  # list of np.ndarray (K, 3) of past outputs
                # Consecutive frames each keypoint has been below conf_floor in the
                # raw input. Held outputs are stored with confidence >= conf_floor,
                # so history alone cannot tell when a hold has lasted too long.
                self._missing_run = [0] * num_keypoints

            def feed(self, kp):
                \"\"\"Feed one frame of keypoints, return smoothed copy.

                Parameters
                ----------
                kp : np.ndarray of shape (K, 3) — [x, y, confidence] per keypoint.
                    May have rows of [0, 0, 0] for missing keypoints.

                Returns
                -------
                np.ndarray of same shape with smoothed values.
                \"\"\"
                kp = np.array(kp, dtype=np.float64)
                out = kp.copy()
                for i in range(min(self.num_kp, len(kp))):
                    x, y, c = kp[i]
                    if c < self.conf_floor:
                        self._missing_run[i] += 1
                        held = None
                        # Only bridge short gaps; after interp_window missing
                        # frames the raw (missing) row is passed through.
                        if self._missing_run[i] <= self.interp_window:
                            held = self._interpolate(i)
                        if held is not None:
                            out[i, 0] = self._filters_x[i](held[0])
                            out[i, 1] = self._filters_y[i](held[1])
                            # Decay confidence while holding, but keep it drawable.
                            out[i, 2] = max(self.conf_floor, held[2] * 0.8)
                        continue
                    self._missing_run[i] = 0
                    out[i, 0] = self._filters_x[i](x)
                    out[i, 1] = self._filters_y[i](y)
                    out[i, 2] = self._filters_c[i](c)
                self._history.append(out.copy())
                if len(self._history) > self.interp_window * 2 + 1:
                    self._history.pop(0)
                return out

            def _interpolate(self, kp_idx):
                \"\"\"Return the most recent good row for a keypoint, or None.

                Looks at the last ``interp_window`` stored outputs and holds the
                newest one whose confidence is at least ``conf_floor``.
                \"\"\"
                if self.interp_window <= 0:
                    return None
                for h in reversed(self._history[-self.interp_window:]):
                    if kp_idx < len(h) and h[kp_idx, 2] >= self.conf_floor:
                        return h[kp_idx]  # hold last known good position
                return None


        def smooth_pose_sequence(all_kp_arrays, num_keypoints=20, **kwargs):
            \"\"\"Batch-smooth a list of per-frame keypoint arrays.

            Parameters
            ----------
            all_kp_arrays : list of np.ndarray (K, 3)
            num_keypoints : int

            Returns
            -------
            list of np.ndarray (K, 3)
            \"\"\"
            smoother = PoseSequenceSmoother(num_keypoints=num_keypoints, **kwargs)
            return [smoother.feed(kp) for kp in all_kp_arrays]
        """)

    os.makedirs(os.path.dirname(dest), exist_ok=True)
    with open(dest, "w", encoding="utf-8") as f:
        f.write(code)
    print(f" CREATED: {dest}")
def patch_human_visualization(repo_dir):
    """Lower default drawing thresholds from 0.6 to 0.3.

    Targets the default value of ``threshold`` / ``hand_score_th`` in the
    signatures of the drawing helpers in ``pose_utils/human_visualization.py``.
    Each signature is located with a whitespace-tolerant search, so both
    one-line and multi-line ``def`` layouts are handled. When a signature
    cannot be located, the literal pattern is passed on unchanged so that
    ``patch_file`` reports it as not found.

    Parameters
    ----------
    repo_dir : str
        Root of the repository being patched.
    """
    filepath = os.path.join(repo_dir, "pose_utils", "human_visualization.py")

    # (function name, parameter name, literal fallback pattern, literal replacement)
    targets = [
        # draw_aapose default threshold
        (
            "draw_aapose",
            "threshold",
            "def draw_aapose(\n img,\n kp2ds,\n threshold=0.6,",
            "def draw_aapose(\n img,\n kp2ds,\n threshold=0.3,",
        ),
        # draw_aapose_new default threshold
        (
            "draw_aapose_new",
            "threshold",
            "def draw_aapose_new(\n img,\n kp2ds,\n threshold=0.6,",
            "def draw_aapose_new(\n img,\n kp2ds,\n threshold=0.3,",
        ),
        # draw_handpose default threshold
        (
            "draw_handpose",
            "hand_score_th",
            "def draw_handpose(canvas, keypoints, hand_score_th=0.6):",
            "def draw_handpose(canvas, keypoints, hand_score_th=0.3):",
        ),
        # draw_handpose_new default threshold
        (
            "draw_handpose_new",
            "hand_score_th",
            "def draw_handpose_new(canvas, keypoints, stickwidth_type='v2',\n"
            "hand_score_th=0.6, hand_stick_width=4):",
            "def draw_handpose_new(canvas, keypoints, stickwidth_type='v2',\n"
            "hand_score_th=0.3, hand_stick_width=4):",
        ),
        # draw_traj default threshold
        (
            "draw_traj",
            "threshold",
            "def draw_traj(metas: List[AAPoseMeta], threshold=0.6",
            "def draw_traj(metas: List[AAPoseMeta], threshold=0.3",
        ),
    ]

    # Read the file up front only to discover the exact signature text;
    # patch_file still owns the missing-file / already-patched handling.
    content = ""
    if os.path.isfile(filepath):
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read()

    replacements = []
    for func_name, param_name, literal_old, literal_new in targets:
        # "\s*\(" right after the name keeps draw_aapose from matching
        # draw_aapose_new; "[^)]*?" keeps the search inside one signature;
        # "(?!\d)" avoids rewriting values such as 0.65.
        signature = re.search(
            rf"def {re.escape(func_name)}\s*\([^)]*?\b{re.escape(param_name)}\s*=\s*0\.6(?!\d)",
            content,
        )
        if signature is None:
            replacements.append((literal_old, literal_new))
            continue
        found = signature.group(0)
        replacements.append((found, found[:-len("0.6")] + "0.3"))

    patch_file(filepath, replacements)
def patch_retarget_pose(repo_dir):
    """Relax the confidence cut-offs hard-coded in ``retarget_pose.py``.

    Two literal edits are handed to ``patch_file``:

    * ``hand_score_th = 0.5`` becomes ``hand_score_th = 0.25``
    * the ``check_full_body`` default ``threshold = 0.4`` becomes ``0.25``

    Parameters
    ----------
    repo_dir : str
        Root of the repository being patched.
    """
    # deal_hand_keypoints threshold
    hand_threshold_edit = ("hand_score_th = 0.5", "hand_score_th = 0.25")
    # check_full_body threshold
    full_body_edit = (
        "def check_full_body(keypoints, threshold = 0.4):",
        "def check_full_body(keypoints, threshold = 0.25):",
    )

    target = os.path.join(repo_dir, "retarget_pose.py")
    patch_file(target, [hand_threshold_edit, full_body_edit])
def _import_insert_offset(content):
    """Return the character offset just past the last top-level import.

    The module is parsed so that a parenthesized or backslash-continued import
    is skipped as a whole instead of being split after its first line. With no
    imports, the offset is just past the module docstring (or 0 if there is
    none). If the source cannot be parsed, falls back to the end of the last
    line that starts with ``import``/``from`` at column 0.
    """
    try:
        tree = ast.parse(content)
    except (SyntaxError, ValueError):
        tree = None

    if tree is None:
        # Best-effort fallback: line-based scan.
        offset = 0
        for m in re.finditer(r'^(?:import |from )', content, re.MULTILINE):
            newline = content.find('\n', m.end())
            offset = len(content) if newline == -1 else newline + 1
        return offset

    last_line = 0
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            # end_lineno is absent before Python 3.8; fall back to the first line.
            last_line = getattr(node, "end_lineno", None) or node.lineno
    if last_line == 0 and tree.body and ast.get_docstring(tree) is not None:
        doc_node = tree.body[0]
        last_line = getattr(doc_node, "end_lineno", None) or doc_node.lineno

    # Convert the 1-based line number into a character offset.
    offset = 0
    for _ in range(last_line):
        newline = content.find('\n', offset)
        if newline == -1:
            # The statement is the last line and has no trailing newline.
            return len(content)
        offset = newline + 1
    return offset


def patch_nodes_add_smoothing(repo_dir):
    """Add a guarded import of PoseSequenceSmoother to ``nodes.py``.

    Only the import (with a ``_HAS_SMOOTHER`` flag) is inserted, after the
    last top-level import statement; no call to the smoother is injected.
    The file is skipped when it is missing or already contains the marker.

    Parameters
    ----------
    repo_dir : str
        Root of the repository being patched.
    """
    filepath = os.path.join(repo_dir, "nodes.py")
    if not os.path.isfile(filepath):
        print(f" SKIP (not found): {filepath}")
        return

    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    marker = "# [SMOOTHING_PATCH]"
    if marker in content:
        print(f" SKIP (already patched): {filepath}")
        return

    # Guarded import so nodes.py still loads if the smoother module is absent.
    import_line = (
        f"\n{marker}\n"
        "try:\n"
        " from .pose_utils.temporal_smoother import PoseSequenceSmoother\n"
        " _HAS_SMOOTHER = True\n"
        "except ImportError:\n"
        " _HAS_SMOOTHER = False\n"
    )

    insert_pos = _import_insert_offset(content)
    content = content[:insert_pos] + import_line + content[insert_pos:]

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)
    print(f" PATCHED (import added): {filepath}")
def main():
    """Command-line entry point: run every patch step against one repo.

    Reads a single positional ``repo_dir`` argument, exits with status 1 if
    it is not a directory, then runs the four patch steps in order and prints
    a closing summary.
    """
    cli = argparse.ArgumentParser(description="Patch VitPose smoothing into WanAnimatePreprocess")
    cli.add_argument("repo_dir", help="Path to ComfyUI-WanAnimatePreprocess directory")
    target = os.path.abspath(cli.parse_args().repo_dir)

    if not os.path.isdir(target):
        print(f"ERROR: directory not found: {target}")
        sys.exit(1)

    print(f"Patching VitPose smoothing in: {target}")
    print()

    # Order matters for the log output only; each step checks its own files.
    steps = (
        inject_smoother_module,
        patch_human_visualization,
        patch_retarget_pose,
        patch_nodes_add_smoothing,
    )
    for step in steps:
        step(target)

    closing_lines = (
        "",
        "Done. Thresholds lowered (0.6→0.3), temporal smoother module installed.",
        "",
        "NOTE: The temporal_smoother module is ready but needs to be called",
        "from within the pose detection loop. If you want automatic smoothing,",
        "wrap the VitPose output in PoseSequenceSmoother.feed() in nodes.py.",
    )
    for line in closing_lines:
        print(line)
# Run the patcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()