| import json |
| import torch |
| import torchvision.transforms.functional as TF |
| from ..utils import log |
| from .trajectory import create_pos_feature_map, draw_tracks_on_video, replace_feature |
| import os |
| from comfy import model_management as mm |
# Torch device reported by ComfyUI's model management; the nodes below create
# their track tensors on it.
device = mm.get_torch_device()

# Absolute path of the directory that contains this module.
script_directory = os.path.dirname(os.path.abspath(__file__))

# Stride constants: indices 1 and 2 are used below as the height and width
# multipliers that turn latent sizes into pixel sizes; index 0 is presumably
# the temporal stride (TODO confirm).
VAE_STRIDE = (4, 8, 8)
|
|
class WanVideoWanDrawWanMoveTracks:
    """Node that draws a set of tracks over a batch of images.

    The image batch is scaled by 255, matched to the number of track frames,
    handed to ``draw_tracks_on_video`` and the drawn frames are converted back
    into a float tensor on the CPU.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "images": ("IMAGE",),
                "tracks": ("TRACKS",),
            },
            "optional": {
                "line_resolution": ("INT", {"default": 24, "min": 4, "max": 64, "step": 1, "tooltip": "Number of points to use for each line segment"}),
                "circle_size": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1, "tooltip": "Size of the circle to draw for each track point"}),
                "opacity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Opacity of the circle to draw for each track point"}),
                "line_width": ("INT", {"default": 14, "min": 1, "max": 50, "step": 1, "tooltip": "Width of the line to draw for each track"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "execute"
    CATEGORY = "WanVideoWrapper"

    @staticmethod
    def _match_frame_count(frames, num_frames):
        """Tile and/or trim ``frames`` along dim 0 to exactly ``num_frames`` entries.

        ``frames`` must be a 4-D tensor (it is tiled with ``repeat(n, 1, 1, 1)``).
        An empty batch is returned unchanged because there is nothing to tile.
        """
        available = frames.shape[0]
        if available in (0, num_frames):
            return frames
        if available < num_frames:
            # Ceiling division so a non-divisible count is still fully covered;
            # floor division here would leave the batch too short (or empty).
            repeats = -(-num_frames // available)
            frames = frames.repeat(repeats, 1, 1, 1)
        return frames[:num_frames]

    def execute(self, images, tracks, line_resolution=24, circle_size=10, opacity=0.5, line_width=14):
        """Draw ``tracks`` on ``images`` and return the result as a 1-tuple.

        Args:
            images: 4-D image tensor; it is multiplied by 255 before drawing
                (assumes values in [0, 1] — TODO confirm).
            tracks: dict with ``"track_path"`` and ``"track_visibility"``
                tensors, or ``None``.
            line_resolution: forwarded to the drawing helper as ``track_frame``.
            circle_size: forwarded to the drawing helper.
            opacity: forwarded to the drawing helper.
            line_width: forwarded to the drawing helper.

        Returns:
            ``(tensor,)`` holding the drawn frames as a float CPU tensor, or
            the input images (float, CPU) when no tracks are available.
        """
        if tracks is None or "track_path" not in tracks:
            log.warning("WanVideoWanDrawWanMoveTracks: No tracks provided.")
            return (images.float().cpu(),)

        # A leading dimension is added to both tensors before drawing.
        track = tracks["track_path"].unsqueeze(0)
        track_visibility = tracks["track_visibility"].unsqueeze(0)

        # One image per track frame: tile short batches, trim long ones.
        images_in = self._match_frame_count(images * 255.0, track.shape[1])

        track_video = draw_tracks_on_video(
            images_in,
            track,
            track_visibility,
            track_frame=line_resolution,
            circle_size=circle_size,
            opacity=opacity,
            line_width=line_width,
        )
        # Convert each drawn frame to a tensor, stack them, and move dim 1 last.
        track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1)

        return (track_video.float().cpu(),)
|
|
|
|
class WanVideoAddWanMoveTracks:
    """Node that attaches track positions to an image-embeds dictionary.

    Tracks come either from JSON coordinates (``track_coords``) or from a
    tracks dictionary (``tracks``); ``track_coords`` wins when both are given.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "image_embeds": ("WANVIDIMAGE_EMBEDS",),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the reference embedding"}),
            },
            "optional": {
                "track_mask": ("MASK",),
                "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
                "tracks": ("TRACKS", {"tooltip": "Alternatively use Comfy Tracks dictionary"}),
            },
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "TRACKS")
    RETURN_NAMES = ("image_embeds", "tracks")
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, image_embeds, track_coords=None, tracks=None, strength=1.0, track_mask=None):
        """Compute track positions and store them under ``"wanmove_embeds"``.

        Args:
            image_embeds: dict providing ``"num_frames"`` and either
                ``"target_shape"`` or ``"lat_h"``/``"lat_w"``.
            track_coords: JSON string, or list of JSON strings, with the tracks.
            tracks: dict with ``"track_path"`` and optionally
                ``"track_visibility"``; used when ``track_coords`` is ``None``.
            strength: value stored alongside the track positions.
            track_mask: optional mask; when given, visibility is derived from
                it instead of from ``tracks``.

        Returns:
            ``(updated_embeds, tracks_dict)`` where ``updated_embeds`` is a copy
            of ``image_embeds`` with ``"wanmove_embeds"`` set, and
            ``tracks_dict`` holds ``"track_path"`` and ``"track_visibility"``.

        Raises:
            ValueError: if no track source is provided, or if ``track_coords``
                yields no tracks.
        """
        has_coords = track_coords is not None
        has_track_dict = tracks is not None and "track_path" in tracks
        # Fail early with a clear message instead of hitting an unbound
        # variable further down.
        if not (has_coords or has_track_dict):
            raise ValueError(
                "WanVideoAddWanMoveTracks: either track_coords or tracks "
                "(containing 'track_path') must be provided."
            )

        updated = dict(image_embeds)
        track_visibility = None

        # Pixel size is the latent size times the spatial strides.
        target_shape = image_embeds.get("target_shape")
        if target_shape is not None:
            height = target_shape[2] * VAE_STRIDE[1]
            width = target_shape[3] * VAE_STRIDE[2]
        else:
            height = image_embeds["lat_h"] * VAE_STRIDE[1]
            width = image_embeds["lat_w"] * VAE_STRIDE[2]
        num_frames = image_embeds["num_frames"]

        if has_coords:
            tracks_data = parse_json_tracks(track_coords)
            if not tracks_data:
                raise ValueError(
                    "WanVideoAddWanMoveTracks: no tracks could be parsed from track_coords."
                )
            # Re-order from per-track point lists to per-frame [x, y] lists.
            track_list = [
                [[points[frame]["x"], points[frame]["y"]] for points in tracks_data]
                for frame in range(len(tracks_data[0]))
            ]
            track = torch.tensor(track_list, dtype=torch.float32, device=device)
        else:
            track = tracks["track_path"]
            # An explicit mask takes precedence over stored visibility.
            if track_mask is None:
                track_visibility = tracks.get("track_visibility", None)
        track = track[:num_frames]

        num_tracks = track.shape[-2]
        if track_visibility is None:
            if track_mask is None:
                track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
            else:
                # True where any mask value over dims 1 and 2 is positive.
                track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)

        # Only the second returned value is used by this node.
        _, track_pos = create_pos_feature_map(
            track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device
        )

        # Copy the nested dict so the caller's image_embeds is not mutated.
        wanmove_embeds = dict(updated.get("wanmove_embeds") or {})
        wanmove_embeds["track_pos"] = track_pos
        wanmove_embeds["strength"] = strength
        updated["wanmove_embeds"] = wanmove_embeds

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }

        return (updated, tracks_dict,)
|
|
|
|
def parse_json_tracks(tracks):
    """Parse track coordinates from JSON into a list of tracks.

    Single quotes in the input are replaced with double quotes before
    decoding, so Python-style dict literals are accepted.

    Args:
        tracks: a JSON string, or an iterable of JSON strings (one decoded
            value is collected per string).

    Returns:
        A list of tracks, each a list of ``{"x": ..., "y": ...}`` points. A
        single flat track is wrapped into a one-element list. Returns ``[]``
        for empty input or when the JSON cannot be decoded. Data in any other
        shape is returned as decoded after logging a warning.
    """
    tracks_data = []
    try:
        if isinstance(tracks, str):
            parsed = json.loads(tracks.replace("'", '"'))
            if isinstance(parsed, list):
                tracks_data.extend(parsed)
            else:
                # A lone object (e.g. one point) is kept whole; extending with
                # a dict would add its keys instead of the point.
                tracks_data.append(parsed)
        else:
            for track_str in tracks:
                tracks_data.append(json.loads(track_str.replace("'", '"')))
    except json.JSONDecodeError as e:
        log.warning(f"Error parsing tracks JSON: {e}")
        return []

    # Nothing to classify; also avoids indexing an empty list below.
    if not tracks_data:
        return tracks_data

    first = tracks_data[0]
    is_single_track = isinstance(first, dict) and "x" in first
    is_track_list = (
        isinstance(first, list)
        and bool(first)
        and isinstance(first[0], dict)
        and "x" in first[0]
    )

    if is_single_track:
        # Flat list of points: treat it as one track.
        tracks_data = [tracks_data]
    elif not is_track_list:
        log.warning(f"Warning: Unexpected track format: {type(first)}")

    return tracks_data
|
|
| import node_helpers |
|
|
class WanMove_native:
    """Deprecated node that writes track positions into a conditioning's
    ``concat_latent_image``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "positive": ("CONDITIONING",),
                "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
            },
            "optional": {
                "track_mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("CONDITIONING", "TRACKS")
    RETURN_NAMES = ("positive", "tracks")
    FUNCTION = "patchcond"
    CATEGORY = "WanVideoWrapper"
    DEPRECATED = True

    def patchcond(self, positive, track_coords, track_mask=None):
        """Patch ``positive`` with track features built from ``track_coords``.

        Args:
            positive: conditioning whose first entry carries a 5-D
                ``"concat_latent_image"`` tensor.
            track_coords: JSON string, or list of JSON strings, with the tracks.
            track_mask: optional mask used to derive per-frame visibility.

        Returns:
            ``(positive, tracks_dict)`` where ``positive`` has an updated
            ``"concat_latent_image"`` and ``tracks_dict`` holds
            ``"track_path"`` and ``"track_visibility"``.

        Raises:
            ValueError: if ``track_coords`` yields no tracks.
        """
        concat_latent_image = positive[0][1]["concat_latent_image"]
        _, _, latent_frames, latent_h, latent_w = concat_latent_image.shape
        # Same values as the former literals 4 and 8, taken from VAE_STRIDE.
        num_frames = (latent_frames - 1) * VAE_STRIDE[0] + 1
        width = latent_w * VAE_STRIDE[2]
        height = latent_h * VAE_STRIDE[1]

        tracks_data = parse_json_tracks(track_coords)
        if not tracks_data:
            # Clear error instead of an IndexError on tracks_data[0] below.
            raise ValueError("WanMove_native: no tracks could be parsed from track_coords.")

        # Re-order from per-track point lists to per-frame [x, y] lists.
        track_list = [
            [[points[frame]["x"], points[frame]["y"]] for points in tracks_data]
            for frame in range(len(tracks_data[0]))
        ]
        track = torch.tensor(track_list, dtype=torch.float32, device=device)
        track = track[:num_frames]

        num_tracks = track.shape[-2]
        if track_mask is None:
            track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
        else:
            # True where any mask value over dims 1 and 2 is positive.
            track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)

        # Only the second returned value is used by this node.
        _, track_pos = create_pos_feature_map(
            track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device
        )
        wanmove_cond = replace_feature(concat_latent_image, track_pos.unsqueeze(0))
        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": wanmove_cond})

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }
        return (positive, tracks_dict)
|
|
|
|
# Node identifier -> class implementing that node.
NODE_CLASS_MAPPINGS = {
    "WanVideoAddWanMoveTracks": WanVideoAddWanMoveTracks,
    "WanVideoWanDrawWanMoveTracks": WanVideoWanDrawWanMoveTracks,
    "WanMove_native": WanMove_native,
}

# Node identifier -> human-readable title for that node.
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddWanMoveTracks": "WanVideo Add WanMove Tracks",
    "WanVideoWanDrawWanMoveTracks": "WanVideo Draw WanMove Tracks",
    "WanMove_native": "WanMove Native",
}
|
|