|
|
import json |
|
|
from .motion import process_tracks |
|
|
import numpy as np |
|
|
from typing import List, Tuple |
|
|
import torch |
|
|
FIXED_LENGTH = 121


def pad_pts(tr):
    """Convert a list of {x, y} point dicts into a (FIXED_LENGTH, 1, 3) array.

    Each point becomes a homogeneous float32 row (x, y, 1). Tracks shorter
    than FIXED_LENGTH are zero-padded at the end; longer ones are truncated.
    """
    coords = np.array([[point['x'], point['y'], 1] for point in tr], dtype=np.float32)
    count = coords.shape[0]
    if count >= FIXED_LENGTH:
        coords = coords[:FIXED_LENGTH]
    else:
        filler = np.zeros((FIXED_LENGTH - count, 3), dtype=np.float32)
        coords = np.vstack((coords, filler))
    return coords.reshape(FIXED_LENGTH, 1, 3)
|
|
|
|
|
def age_to_bgr(ratio: float) -> Tuple[int, int, int]:
    """Map ratio∈[0,1] through: 0→blue, 1/3→green, 2/3→yellow, 1→red.

    NOTE(review): despite the name, the tuple is returned as (R, G, B) —
    which is what paint_point_track expects, since it draws on RGB frames.
    """
    if ratio <= 1/3:
        # blue → green segment
        frac = ratio / (1/3)
        blue = int(255 * (1 - frac))
        green = int(255 * frac)
        red = 0
    elif ratio <= 2/3:
        # green → yellow segment
        frac = (ratio - 1/3) / (1/3)
        blue = 0
        green = 255
        red = int(255 * frac)
    else:
        # yellow → red segment
        frac = (ratio - 2/3) / (1/3)
        blue = 0
        green = int(255 * (1 - frac))
        red = 255
    return (red, green, blue)
|
|
|
|
|
def paint_point_track(
    frames: np.ndarray,
    point_tracks: np.ndarray,
    visibles: np.ndarray,
    min_radius: int = 1,
    max_radius: int = 6,
    max_retain: int = 50
) -> np.ndarray:
    """
    Draws every past point of each track on each frame, with radius and color
    interpolated by the point's age (old→small to new→large).

    Args:
        frames:      [F, H, W, 3] uint8 RGB
        point_tracks:[N, F, 2] float32 – (x,y) in pixel coords
        visibles:    [N, F] bool – visibility mask
        min_radius:  radius for the very first point (oldest)
        max_radius:  radius for the current point (newest)
        max_retain:  points older than this many frames are skipped

    Returns:
        video: [F, H, W, 3] uint8 RGB
    """
    import cv2

    num_points, num_frames = point_tracks.shape[:2]
    H, W = frames.shape[1:3]

    video = frames.copy()

    # Fix: max_retain == 0 is permitted by the node UI and previously caused
    # a ZeroDivisionError when computing the age ratio below.
    denom = max(max_retain, 1)

    for t in range(num_frames):
        frame = video[t].copy()

        for i in range(num_points):
            for tau in range(t + 1):
                if not visibles[i, tau]:
                    continue
                # Drop points older than the retain window.
                if t - tau > max_retain:
                    continue

                # +0.5 recentres onto the pixel centre before truncation.
                x, y = point_tracks[i, tau] + 0.5
                xi = int(np.clip(x, 0, W - 1))
                yi = int(np.clip(y, 0, H - 1))

                # ratio: 1.0 for the newest point, decreasing with age.
                if num_frames > 1:
                    ratio = 1 - float(t - tau) / denom
                else:
                    ratio = 1.0

                radius = int(round(min_radius + (max_radius - min_radius) * ratio))
                # NOTE: age_to_bgr actually returns (R, G, B), which matches
                # the RGB frame layout documented above.
                color_rgb = age_to_bgr(ratio)
                cv2.circle(frame, (xi, yi), radius, color_rgb, thickness=-1)

        video[t] = frame

    return video
|
|
|
|
|
def parse_json_tracks(tracks):
    """Parse track data from a JSON string or an iterable of JSON strings.

    Single quotes are tolerated by rewriting them to double quotes before
    parsing. The result is normalized to a list of tracks, each track being
    a list of {x, y} point dicts.

    Args:
        tracks: a JSON string, or an iterable of JSON strings.

    Returns:
        list: [[{x, y}, ...], ...]; an empty list on parse failure or
        empty input.
    """
    tracks_data = []
    try:
        if isinstance(tracks, str):
            # NOTE: naive quote rewrite — breaks if values contain apostrophes.
            parsed = json.loads(tracks.replace("'", '"'))
            tracks_data.extend(parsed)
        else:
            for track_str in tracks:
                parsed = json.loads(track_str.replace("'", '"'))
                tracks_data.append(parsed)

        # Normalize: a single flat track becomes a one-track list.
        if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
            tracks_data = [tracks_data]
        elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
            pass  # already a list of tracks
        elif tracks_data:
            # Fix: guard with `elif tracks_data` — previously an empty result
            # (e.g. tracks=[]) raised IndexError on tracks_data[0] here.
            print(f"Warning: Unexpected track format: {type(tracks_data[0])}")

    except json.JSONDecodeError as e:
        print(f"Error parsing tracks JSON: {e}")
        tracks_data = []

    return tracks_data
|
|
|
|
|
class WanVideoATITracks:
    """ComfyUI node: attach ATI motion tracks to a WanVideo model patcher."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("WANVIDEOMODEL", ),
            "tracks": ("STRING",),
            "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}),
            # Fix: max was 29048, an apparent typo for 2048 (matches width).
            "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 8, "tooltip": "Height of the image to encode"}),
            "temperature": ("FLOAT", {"default": 220.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
            "topk": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent of the steps to apply ATI"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent of the steps to apply ATI"}),
            },
        }

    RETURN_TYPES = ("WANVIDEOMODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "patchmodel"
    CATEGORY = "WanVideoWrapper"

    def patchmodel(self, model, tracks, width, height, temperature, topk, start_percent, end_percent):
        """Parse `tracks`, preprocess them, and store the result plus the ATI
        hyper-parameters in the cloned model's transformer_options.

        Returns:
            tuple: (patched model clone,)
        """
        tracks_data = parse_json_tracks(tracks)
        # Pad/truncate every track to FIXED_LENGTH and stack to [N, L, 1, 3].
        tracks_np = np.stack([pad_pts(track) for track in tracks_data], axis=0)

        processed_tracks = process_tracks(tracks_np, (width, height))

        patcher = model.clone()
        options = patcher.model_options["transformer_options"]
        options["ati_tracks"] = processed_tracks.unsqueeze(0)
        options["ati_temperature"] = temperature
        options["ati_topk"] = topk
        options["ati_start_percent"] = start_percent
        options["ati_end_percent"] = end_percent

        return (patcher,)
|
|
|
|
|
class WanVideoATITracksVisualize:
    """ComfyUI node: render ATI point tracks onto a batch of images."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "images": ("IMAGE",),
            "tracks": ("STRING",),
            "min_radius": ("INT", {"default": 1, "min": 0, "max": 100, "step": 1, "tooltip": "radius for the very first point (oldest)"}),
            "max_radius": ("INT", {"default": 6, "min": 0, "max": 100, "step": 1, "tooltip": "radius for the current point (newest)"}),
            "max_retain": ("INT", {"default": 50, "min": 0, "max": 100, "step": 1, "tooltip": "Maximum number of points to retain"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    # NOTE(review): name says "patchmodel" but this node only draws an overlay.
    FUNCTION = "patchmodel"
    CATEGORY = "WanVideoWrapper"

    def patchmodel(self, images, tracks, min_radius, max_radius, max_retain):
        """Draw the parsed tracks over `images` and return annotated frames.

        Returns:
            tuple: (float tensor [F, H, W, 3] with circles painted,)
        """
        tracks_data = parse_json_tracks(tracks)
        arrs = []
        for track in tracks_data:
            pts = pad_pts(track)
            arrs.append(pts)

        # [N, FIXED_LENGTH, 1, 3] — (x, y, 1) per point, zero rows = padding.
        tracks_np = np.stack(arrs, axis=0)
        # Resample along time: duplicate each step (x2) then keep every 3rd,
        # i.e. a 2/3-rate resample of the padded tracks.
        # NOTE(review): the intent of this exact ratio is not evident here —
        # presumably it matches a latent↔pixel frame-rate relation; confirm.
        track = np.repeat(tracks_np, 2, axis=1)[:, ::3]
        points = track[:, :, 0, :2].astype(np.float32)
        # Third channel (homogeneous 1, or 0 for padding) doubles as visibility.
        visibles = track[:, :, 0, 2].astype(np.float32)

        # Tile the image batch to cover all track frames, then trim to match.
        if images.shape[0] < points.shape[1]:
            repeat_count = (points.shape[1] + images.shape[0] - 1) // images.shape[0]
            images = images.repeat(repeat_count, 1, 1, 1)
            images = images[:points.shape[1]]
        elif images.shape[0] > points.shape[1]:
            images = images[:points.shape[1]]

        video_viz = paint_point_track(images.cpu().numpy(), points, visibles, min_radius, max_radius, max_retain)
        video_viz = torch.from_numpy(video_viz).float()

        return (video_viz,)
|
|
|
|
|
from comfy import utils |
|
|
import types |
|
|
from .motion_patch import patch_motion |
|
|
|
|
|
class WanConcatCondPatch:
    """Descriptor that swaps a model's `concat_cond` for an ATI-aware version.

    Accessing the attribute yields a method bound to the model instance that
    forwards to `modified_concat_cond` with the stored tracks, temperature,
    and topk values.
    """

    def __init__(self, tracks, temperature, topk):
        self.tracks = tracks
        self.temperature = temperature
        self.topk = topk

    def __get__(self, obj, objtype=None):
        patch = self  # capture the descriptor for the closure below

        def wrapped_concat_cond(self_module, *args, **kwargs):
            return modified_concat_cond(
                self_module, patch.tracks, patch.temperature, patch.topk,
                *args, **kwargs
            )

        return types.MethodType(wrapped_concat_cond, obj)
|
|
|
|
|
def modified_concat_cond(self, tracks, temperature, topk, **kwargs):
    """Replacement `concat_cond` that injects ATI motion conditioning.

    Mirrors the model's original concat-conditioning construction (zeros or
    an upscaled/normalized latent image, plus a 4-channel mask for
    image-to-video), then runs `patch_motion` over the result.

    Args:
        self: the Wan model object this function is bound to (provides
            diffusion_model, process_latent_in, image_to_video).
        tracks: processed ATI track tensor; moved to the cond's device/dtype.
        temperature, topk: ATI hyper-parameters forwarded to patch_motion.
        **kwargs: sampler-supplied values — expects at least "noise" and
            "device"; optionally "concat_latent_image", "concat_mask"/"denoise_mask".

    Returns:
        The patched conditioning tensor with a leading batch dim, the plain
        image cond when no ATI path applies, or None when the model takes
        no extra channels.
    """
    noise = kwargs.get("noise", None)
    # Extra conditioning channels the patch embedding expects beyond noise.
    extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
    if extra_channels == 0:
        return None

    image = kwargs.get("concat_latent_image", None)
    device = kwargs["device"]

    if image is None:
        # No conditioning image: zeros with the required channel count.
        shape_image = list(noise.shape)
        shape_image[1] = extra_channels
        image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
    else:
        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
        # Latent-space normalization applied 16 channels at a time.
        for i in range(0, image.shape[1], 16):
            image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
        image = utils.resize_to_batch_size(image, noise.shape[0])

    if not self.image_to_video or extra_channels == image.shape[1]:
        return image

    # Reserve 4 channels for the mask concatenated below.
    if image.shape[1] > (extra_channels - 4):
        image = image[:, :(extra_channels - 4)]

    mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
    if mask is None:
        mask = torch.zeros_like(noise)[:, :4]
    else:
        if mask.shape[1] != 4:
            mask = torch.mean(mask, dim=1, keepdim=True)
        # Invert: denoise mask → keep mask.
        mask = 1.0 - mask
        mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
        if mask.shape[-3] < noise.shape[-3]:
            # Zero-pad the temporal dimension up to the noise length.
            mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
        if mask.shape[1] == 1:
            mask = mask.repeat(1, 4, 1, 1, 1)
        mask = utils.resize_to_batch_size(mask, noise.shape[0])

    image_cond = torch.cat((mask, image), dim=1)
    # patch_motion operates on the batchless cond; batch dim restored below.
    image_cond_ati = patch_motion(tracks.to(image_cond.device, image_cond.dtype), image_cond[0],
                                  temperature=temperature, topk=topk)

    return image_cond_ati.unsqueeze(0)
|
|
|
|
|
class WanVideoATI_comfy:
    """ComfyUI node: apply ATI motion tracks to a native comfy Wan MODEL by
    patching its `concat_cond` with a WanConcatCondPatch descriptor."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL", ),
            "width": ("INT", {"default": 832, "min": 64, "max": 2048, "step": 8, "tooltip": "Width of the image to encode"}),
            # Fix: max was 29048, an apparent typo for 2048 (matches width).
            "height": ("INT", {"default": 480, "min": 64, "max": 2048, "step": 8, "tooltip": "Height of the image to encode"}),
            "tracks": ("STRING",),
            "temperature": ("FLOAT", {"default": 220.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
            "topk": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model", )
    FUNCTION = "patchcond"
    CATEGORY = "WanVideoWrapper"

    def patchcond(self, model, tracks, width, height, temperature, topk):
        """Parse and preprocess `tracks`, then clone `model` and patch its
        concat_cond so ATI conditioning is applied at sampling time.

        Returns:
            tuple: (patched model clone,)
        """
        tracks_data = parse_json_tracks(tracks)
        # Pad/truncate every track to FIXED_LENGTH and stack to [N, L, 1, 3].
        tracks_np = np.stack([pad_pts(track) for track in tracks_data], axis=0)

        processed_tracks = process_tracks(tracks_np, (width, height))

        model_clone = model.clone()
        # Manually bind the descriptor to the inner model object so the
        # patched concat_cond receives it as `self`.
        model_clone.add_object_patch(
            "concat_cond",
            WanConcatCondPatch(
                processed_tracks.unsqueeze(0), temperature, topk
            ).__get__(model.model, model.model.__class__)
        )

        return (model_clone,)
|
|
|
|
|
# Registration tables consumed by ComfyUI: node-name → class, and
# node-name → human-readable display name shown in the node picker.
NODE_CLASS_MAPPINGS = {
    "WanVideoATITracks": WanVideoATITracks,
    "WanVideoATITracksVisualize": WanVideoATITracksVisualize,
    "WanVideoATI_comfy": WanVideoATI_comfy,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoATITracks": "WanVideo ATI Tracks",
    "WanVideoATITracksVisualize": "WanVideo ATI Tracks Visualize",
    "WanVideoATI_comfy": "WanVideo ATI Comfy",
}
|
|
|