File size: 8,695 Bytes
c6535db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import json
import torch
import torchvision.transforms.functional as TF
from ..utils import log
from .trajectory import create_pos_feature_map, draw_tracks_on_video, replace_feature
import os
from comfy import model_management as mm
VAE_STRIDE = (4, 8, 8)  # latent -> pixel scale factors, ordered (t, h, w)

# Absolute path of the folder that holds this module.
script_directory = os.path.dirname(os.path.abspath(__file__))
# Torch device reported by ComfyUI model management; used for the tensors this module creates.
device = mm.get_torch_device()

class WanVideoWanDrawWanMoveTracks:
    """Node that draws WanMove track paths on top of an image batch for previewing."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "images": ("IMAGE",),
                    "tracks": ("TRACKS",),
                },
                "optional": {
                    "line_resolution": ("INT", {"default": 24, "min": 4, "max": 64, "step": 1, "tooltip": "Number of points to use for each line segment"}),
                    "circle_size": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1, "tooltip": "Size of the circle to draw for each track point"}),
                    "opacity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Opacity of the circle to draw for each track point"}),
                    "line_width": ("INT", {"default": 14, "min": 1, "max": 50, "step": 1, "tooltip": "Width of the line to draw for each track"}),
                }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "execute"
    CATEGORY = "WanVideoWrapper"

    def execute(self, images, tracks, line_resolution=24, circle_size=10, opacity=0.5, line_width=14):
        """Render ``tracks`` over ``images`` and return the annotated frames.

        Args:
            images: IMAGE batch indexed by frame on dim 0. It is multiplied by 255
                before drawing, so presumably values are in [0, 1] (TODO confirm).
            tracks: TRACKS dict with "track_path" (assumed (frames, num_tracks, 2),
                TODO confirm) and optionally "track_visibility" (frames, num_tracks).
                When visibility is missing, every point is treated as visible.
            line_resolution: forwarded to ``draw_tracks_on_video`` as ``track_frame``.
            circle_size, opacity, line_width: drawing parameters forwarded as-is.

        Returns:
            1-tuple holding a float CPU tensor of the drawn frames, channels-last
            (frames are converted with ``TF.to_tensor`` and then moved from dim 1 to -1).
            If no tracks are given, the input images are returned unchanged (float, CPU).
        """
        if tracks is None or "track_path" not in tracks:
            log.warning("WanVideoWanDrawWanMoveTracks: No tracks provided.")
            return (images.float().cpu(), )

        # Add a leading batch dimension: the frame axis of `track` becomes dim 1.
        track = tracks["track_path"].unsqueeze(0)
        track_visibility = tracks.get("track_visibility")
        if track_visibility is None:
            # No visibility info: mark every point of every track as visible.
            track_visibility = torch.ones(track.shape[:-1], dtype=torch.bool, device=track.device)
        else:
            track_visibility = track_visibility.unsqueeze(0)

        images_in = images * 255.0
        num_track_frames = track.shape[1]
        num_image_frames = images_in.shape[0]
        if num_image_frames != num_track_frames:
            # Make the image batch exactly as long as the track: loop it when it is
            # shorter (e.g. a single start frame) and trim it when it is longer.
            # Plain floor division gave 0 repeats (an empty batch) in the longer case
            # and a too-short batch when the lengths were not multiples of each other.
            repeat_count = -(-num_track_frames // num_image_frames)  # ceil division
            images_in = images_in.repeat(repeat_count, 1, 1, 1)[:num_track_frames]

        track_video = draw_tracks_on_video(images_in, track, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
        track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1)

        return (track_video.float().cpu(), )


class WanVideoAddWanMoveTracks:
    """Node that attaches WanMove track conditioning to WanVideo image embeds."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "image_embeds": ("WANVIDIMAGE_EMBEDS",),
                    "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the reference embedding"}),
                },
                "optional": {
                    "track_mask": ("MASK",),
                    "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
                    "tracks": ("TRACKS", {"tooltip": "Alternatively use Comfy Tracks dictionary"}),
                }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "TRACKS")
    RETURN_NAMES = ("image_embeds", "tracks")
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, image_embeds, track_coords=None, tracks=None, strength=1.0, track_mask=None):
        """Compute WanMove track positions and store them in a copy of ``image_embeds``.

        Tracks are taken from ``track_coords`` (JSON, via ``parse_json_tracks``) when it
        yields usable points. Otherwise they come from a TRACKS dict's "track_path".
        Only the first ``image_embeds["num_frames"]`` frames are used.

        Visibility comes from the TRACKS dict's "track_visibility" (only when no mask
        is given), else from ``track_mask`` (a frame is visible if any mask pixel is > 0),
        else all points are visible. It is truncated to the same frame count as the track.

        Args:
            image_embeds: embeds dict. Reads "target_shape" (indices 2 and 3 are
                presumably latent H/W, TODO confirm) or "lat_h"/"lat_w", and "num_frames".
            track_coords: JSON string or list of JSON strings of {'x','y'} point tracks.
            tracks: TRACKS dict, used when ``track_coords`` gives no usable tracks.
            strength: stored as ``wanmove_embeds["strength"]``.
            track_mask: optional MASK used to derive per-frame visibility.

        Returns:
            (copy of image_embeds with a new "wanmove_embeds" dict holding "track_pos" and
            "strength", TRACKS dict {"track_path", "track_visibility"}). If no usable
            tracks were supplied, a warning is logged and (unchanged copy, None) is returned.
        """
        updated = dict(image_embeds)

        target_shape = image_embeds.get("target_shape")
        if target_shape is not None:
            height = target_shape[2] * VAE_STRIDE[1]
            width = target_shape[3] * VAE_STRIDE[2]
        else:
            height = image_embeds["lat_h"] * VAE_STRIDE[1]
            width = image_embeds["lat_w"] * VAE_STRIDE[2]
        num_frames = image_embeds["num_frames"]

        track = None
        track_visibility = None
        if track_coords is not None:
            tracks_data = parse_json_tracks(track_coords)
            # Clip to the shortest track so every frame has a point for every track
            # (indexing by the first track's length crashed on shorter tracks).
            track_len = min((len(t) for t in tracks_data), default=0)
            if track_len > 0:
                track_list = [
                    [[t[frame]['x'], t[frame]['y']] for t in tracks_data]
                    for frame in range(track_len)
                ]
                track = torch.tensor(track_list, dtype=torch.float32, device=device)  # shape: (frames, num_tracks, 2)
            else:
                log.warning("WanVideoAddWanMoveTracks: track_coords contained no usable track points.")
        if track is None and tracks is not None and "track_path" in tracks:
            track = tracks["track_path"]
            if track_mask is None:
                track_visibility = tracks.get("track_visibility", None)
        if track is None:
            # Previously this fell through to an UnboundLocalError; match the draw
            # node's behavior of warning and passing the input through instead.
            log.warning("WanVideoAddWanMoveTracks: No tracks provided.")
            return (updated, None)
        track = track[:num_frames]

        num_tracks = track.shape[-2]
        if track_visibility is None:
            if track_mask is None:
                track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
            else:
                track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
        # Keep visibility aligned with the truncated track.
        track_visibility = track_visibility[:num_frames]
        _, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device)

        # Build a fresh nested dict: `updated` is only a shallow copy, so writing into an
        # existing "wanmove_embeds" in place would mutate the upstream node's output.
        updated["wanmove_embeds"] = {
            **(image_embeds.get("wanmove_embeds") or {}),
            "track_pos": track_pos,
            "strength": strength,
        }

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }

        return (updated, tracks_dict,)


def parse_json_tracks(tracks):
    """Parse track coordinates from JSON text into a list of tracks.

    Args:
        tracks: Either a single JSON string, or an iterable of JSON strings. For an
            iterable, each string is parsed separately and appended as one entry.
            Single quotes are replaced with double quotes before parsing, so
            Python-style reprs such as ``"[{'x': 1, 'y': 2}]"`` are accepted.

    Returns:
        A list of tracks, each expected to be a list of ``{'x': ..., 'y': ...}`` dicts.
        A bare single track (a flat list of point dicts) is wrapped in a one-element
        list. An empty list is returned after logging a warning when the JSON cannot
        be decoded or contains no tracks. Data in any other shape is returned as-is
        after a warning.
    """
    tracks_data = []
    try:
        if isinstance(tracks, str):
            # One JSON document: either a single track or a list of tracks.
            parsed = json.loads(tracks.replace("'", '"'))
            tracks_data.extend(parsed)
        else:
            # Iterable of JSON documents, one entry per string.
            for track_str in tracks:
                parsed = json.loads(track_str.replace("'", '"'))
                tracks_data.append(parsed)
    except json.JSONDecodeError as e:
        log.warning(f"Error parsing tracks JSON: {e}")
        return []

    if not tracks_data:
        # Valid but empty JSON: the format checks below would index tracks_data[0]
        # and raise IndexError, so bail out here.
        log.warning("Warning: No tracks found in track coordinates.")
        return tracks_data

    first = tracks_data[0]
    if isinstance(first, dict) and 'x' in first:
        # Single track detected, wrap it in a list.
        tracks_data = [tracks_data]
    elif isinstance(first, list) and first and isinstance(first[0], dict) and 'x' in first[0]:
        # Already a list of tracks, nothing to do.
        pass
    else:
        log.warning(f"Warning: Unexpected track format: {type(tracks_data[0])}")

    return tracks_data

import node_helpers

class WanMove_native:
    """Deprecated node that writes WanMove track features into native conditioning."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "positive": ("CONDITIONING",),
            "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
            },
            "optional": {
                "track_mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("CONDITIONING", "TRACKS")
    RETURN_NAMES = ("positive", "tracks")
    FUNCTION = "patchcond"
    CATEGORY = "WanVideoWrapper"
    DEPRECATED = True

    def patchcond(self, positive, track_coords, track_mask=None):
        """Replace ``concat_latent_image`` in ``positive`` with a track-conditioned version.

        Args:
            positive: CONDITIONING whose first entry's options dict has a 5-D
                "concat_latent_image" (unpacked as B, C, T, H, W).
            track_coords: JSON string or list of JSON strings of {'x','y'} point tracks.
            track_mask: optional MASK. A frame counts as visible when any of its pixels
                is > 0. Without a mask every point is visible.

        Returns:
            (patched conditioning, TRACKS dict {"track_path", "track_visibility"}).
            If no usable tracks are parsed, a warning is logged and
            (positive unchanged, None) is returned.
        """
        concat_latent_image = positive[0][1]["concat_latent_image"]
        _, _, T, H, W = concat_latent_image.shape
        # Pixel-space size implied by the latent: the first latent frame maps to one
        # frame, each later one to VAE_STRIDE[0] frames.
        num_frames = (T - 1) * VAE_STRIDE[0] + 1
        width = W * VAE_STRIDE[2]
        height = H * VAE_STRIDE[1]

        tracks_data = parse_json_tracks(track_coords)
        # Clip to the shortest track so every frame has a point for every track.
        track_len = min((len(t) for t in tracks_data), default=0)
        if track_len == 0:
            log.warning("WanMove_native: No usable tracks provided, conditioning left unchanged.")
            return (positive, None)
        track_list = [
            [[t[frame]['x'], t[frame]['y']] for t in tracks_data]
            for frame in range(track_len)
        ]
        track = torch.tensor(track_list, dtype=torch.float32, device=device)  # shape: (frames, num_tracks, 2)
        track = track[:num_frames]

        num_tracks = track.shape[-2]
        if track_mask is None:
            track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
        else:
            # Truncate like the track so the frame counts stay aligned.
            track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)[:num_frames]

        _, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device)
        wanmove_cond = replace_feature(concat_latent_image, track_pos.unsqueeze(0))
        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": wanmove_cond})

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }
        return (positive, tracks_dict)


# ComfyUI registration: node id -> implementing class.
NODE_CLASS_MAPPINGS = {
    "WanVideoAddWanMoveTracks": WanVideoAddWanMoveTracks,
    "WanVideoWanDrawWanMoveTracks": WanVideoWanDrawWanMoveTracks,
    "WanMove_native": WanMove_native,
}

# Titles shown in the UI for the same node ids.
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddWanMoveTracks": "WanVideo Add WanMove Tracks",
    "WanVideoWanDrawWanMoveTracks": "WanVideo Draw WanMove Tracks",
    "WanMove_native": "WanMove Native",
}