import json import torch import torchvision.transforms.functional as TF from ..utils import log from .trajectory import create_pos_feature_map, draw_tracks_on_video, replace_feature import os from comfy import model_management as mm device = mm.get_torch_device() script_directory = os.path.dirname(os.path.abspath(__file__)) VAE_STRIDE = (4, 8, 8) # t, h, w class WanVideoWanDrawWanMoveTracks: @classmethod def INPUT_TYPES(s): return {"required": { "images": ("IMAGE",), "tracks": ("TRACKS",), }, "optional": { "line_resolution": ("INT", {"default": 24, "min": 4, "max": 64, "step": 1, "tooltip": "Number of points to use for each line segment"}), "circle_size": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1, "tooltip": "Size of the circle to draw for each track point"}), "opacity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Opacity of the circle to draw for each track point"}), "line_width": ("INT", {"default": 14, "min": 1, "max": 50, "step": 1, "tooltip": "Width of the line to draw for each track"}), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("image",) FUNCTION = "execute" CATEGORY = "WanVideoWrapper" def execute(self, images, tracks, line_resolution=24, circle_size=10, opacity=0.5, line_width=14): if tracks is None or "track_path" not in tracks: log.warning("WanVideoWanDrawWanMoveTracks: No tracks provided.") return (images.float().cpu(), ) track = tracks["track_path"].unsqueeze(0) track_visibility = tracks["track_visibility"].unsqueeze(0) images_in = images * 255.0 if images_in.shape[0] != track.shape[1]: repeat_count = track.shape[1] // images.shape[0] images_in = images_in.repeat(repeat_count, 1, 1, 1) track_video = draw_tracks_on_video(images_in, track, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width) track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1) return (track_video.float().cpu(), ) class WanVideoAddWanMoveTracks: @classmethod def INPUT_TYPES(s): return {"required": { "image_embeds": ("WANVIDIMAGE_EMBEDS",), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the reference embedding"}), }, "optional": { "track_mask": ("MASK",), "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}), "tracks": ("TRACKS", {"tooltip": "Alternatively use Comfy Tracks dictionary"}), } } RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "TRACKS") RETURN_NAMES = ("image_embeds", "tracks") FUNCTION = "add" CATEGORY = "WanVideoWrapper" def add(self, image_embeds, track_coords=None, tracks=None, strength=1.0, track_mask=None): updated = dict(image_embeds) track_visibility = None target_shape = image_embeds.get("target_shape") if target_shape is not None: height = target_shape[2] * VAE_STRIDE[1] width = target_shape[3] * VAE_STRIDE[2] else: height = image_embeds["lat_h"] * VAE_STRIDE[1] width = image_embeds["lat_w"] * VAE_STRIDE[2] num_frames = image_embeds["num_frames"] if track_coords is not None: tracks_data = parse_json_tracks(track_coords) track_list = [ [[track[frame]['x'], track[frame]['y']] for track in tracks_data] for frame in range(len(tracks_data[0])) ] track = torch.tensor(track_list, dtype=torch.float32, device=device) # shape: (frames, num_tracks, 2) elif tracks is not None and "track_path" in tracks: track = tracks["track_path"] if track_mask is None: track_visibility = tracks.get("track_visibility", None) track = track[:num_frames] num_tracks = track.shape[-2] if track_visibility is None: if track_mask is None: track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device) else: track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1) feature_map, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device) updated.setdefault("wanmove_embeds", {}) updated["wanmove_embeds"]["track_pos"] = track_pos updated["wanmove_embeds"]["strength"] = strength tracks_dict = { "track_path": track, "track_visibility": track_visibility, } return (updated, tracks_dict,) def parse_json_tracks(tracks): tracks_data = [] try: # If tracks is a string, try to parse it as JSON if isinstance(tracks, str): parsed = json.loads(tracks.replace("'", '"')) tracks_data.extend(parsed) else: # If tracks is a list of strings, parse each one for track_str in tracks: parsed = json.loads(track_str.replace("'", '"')) tracks_data.append(parsed) # Check if we have a single track (dict with x,y) or a list of tracks if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]: # Single track detected, wrap it in a list tracks_data = [tracks_data] elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]: # Already a list of tracks, nothing to do pass else: # Unexpected format log.warning(f"Warning: Unexpected track format: {type(tracks_data[0])}") except json.JSONDecodeError as e: log.warning(f"Error parsing tracks JSON: {e}") tracks_data = [] return tracks_data import node_helpers class WanMove_native: @classmethod def INPUT_TYPES(s): return {"required": { "positive": ("CONDITIONING",), "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}), }, "optional": { "track_mask": ("MASK",), } } RETURN_TYPES = ("CONDITIONING", "TRACKS") RETURN_NAMES = ("positive", "tracks") FUNCTION = "patchcond" CATEGORY = "WanVideoWrapper" DEPRECATED = True def patchcond(self, positive, track_coords, track_mask=None): concat_latent_image = positive[0][1]["concat_latent_image"] B, C, T, H, W = concat_latent_image.shape num_frames = (T-1) * 4 + 1 width = W * 8 height = H * 8 tracks_data = parse_json_tracks(track_coords) track_list = [ [[track[frame]['x'], track[frame]['y']] for track in tracks_data] for frame in range(len(tracks_data[0])) ] track = torch.tensor(track_list, dtype=torch.float32, device=device) # shape: (frames, num_tracks, 2) track = track[:num_frames] num_tracks = track.shape[-2] if track_mask is None: track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device) else: track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1) feature_map, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device) wanmove_cond = replace_feature(concat_latent_image, track_pos.unsqueeze(0)) positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": wanmove_cond}) tracks_dict = { "track_path": track, "track_visibility": track_visibility, } return (positive, tracks_dict) NODE_CLASS_MAPPINGS = { "WanVideoAddWanMoveTracks": WanVideoAddWanMoveTracks, "WanVideoWanDrawWanMoveTracks": WanVideoWanDrawWanMoveTracks, "WanMove_native": WanMove_native, } NODE_DISPLAY_NAME_MAPPINGS = { "WanVideoAddWanMoveTracks": "WanVideo Add WanMove Tracks", "WanVideoWanDrawWanMoveTracks": "WanVideo Draw WanMove Tracks", "WanMove_native": "WanMove Native", }