File size: 9,664 Bytes
5ba4011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import os
import itertools
import numpy as np
import torch
from PIL import Image
import psutil

# Numeric limits (consistent with ComfyUI conventions).
# BIGMAX is 2**32 and DIMMAX is 2**14 (16384), written here as bit shifts.
BIGMAX = 1 << 32
DIMMAX = 1 << 14

def strip_path(path):
    """Return *path* without surrounding whitespace or quote characters.

    Whitespace is trimmed first, then any run of double quotes at either
    end, then any run of single quotes at either end (in that order).
    """
    cleaned = path.strip()
    for quote_char in ('"', "'"):
        cleaned = cleaned.strip(quote_char)
    return cleaned

def validate_path(path, allow_none=False):
    """Report whether *path* names an existing regular file.

    When ``allow_none`` is truthy, a ``None`` path is accepted and the
    function returns True without touching the filesystem. Every other
    value is passed to ``os.path.isfile``.
    """
    none_accepted = allow_none and path is None
    return True if none_accepted else os.path.isfile(path)

def target_size(width, height, force_size, downscale_ratio=8):
    """Compute the output (width, height) for a ``force_size`` preset.

    A preset fixes one or both dimensions; a dimension written as ``?`` is
    derived from the source aspect ratio (truncated to an int). Presets that
    are not recognised, including "Disabled", leave the source size as is.
    In every case both dimensions are finally rounded to the nearest
    multiple of ``downscale_ratio``.

    Returns:
        A ``(width, height)`` tuple of ints.
    """
    # preset -> (forced width, forced height); None means "derive from aspect"
    presets = {
        "256x?": (256, None),
        "?x256": (None, 256),
        "256x256": (256, 256),
        "512x?": (512, None),
        "?x512": (None, 512),
        "512x512": (512, 512),
    }
    forced_w, forced_h = presets.get(force_size, (None, None))

    if forced_w is not None and forced_h is not None:
        width, height = forced_w, forced_h
    elif forced_w is not None:
        height = int(height * forced_w / width)
        width = forced_w
    elif forced_h is not None:
        width = int(width * forced_h / height)
        height = forced_h

    def snap(value):
        # Round half up to the nearest multiple of downscale_ratio.
        return int(value / downscale_ratio + 0.5) * downscale_ratio

    return (snap(width), snap(height))

def _plan_source_frames(total_frames, fps, force_rate, frame_load_cap, skip_first_frames, select_every_nth):
    """Return, in output order, the source-frame index to decode for each emitted frame."""
    if force_rate == 0:
        resampled_count = total_frames
    else:
        resampled_count = int(total_frames / fps * force_rate)

    # Positions on the (possibly resampled) timeline that survive the initial
    # skip and the every-nth stride; the cap then keeps only the first few.
    positions = range(max(skip_first_frames, 0), resampled_count, select_every_nth)
    if frame_load_cap > 0:
        positions = positions[:frame_load_cap]

    if force_rate == 0:
        return list(positions)

    # Output position p is shown at time p / force_rate, i.e. during source
    # frame floor(p * fps / force_rate). Clamp to guard against float rounding.
    last_frame = total_frames - 1
    return [min(int(pos * fps / force_rate), last_frame) for pos in positions]


def webp_frame_generator(webp_path, force_rate, frame_load_cap, skip_first_frames, select_every_nth):
    """Stream the frames of a WebP animation as float32 RGB arrays.

    The first value yielded is a metadata tuple
    ``(width, height, fps, total_duration_seconds, total_frames,
    target_frame_time)``. Every later value is one frame as a NumPy array of
    shape ``(height, width, 3)`` with values scaled to ``[0, 1]``.

    Args:
        webp_path: Path to the file; surrounding whitespace/quotes are stripped.
        force_rate: Output frame rate. ``0`` keeps the source rate; any other
            value resamples by repeating or dropping source frames.
        frame_load_cap: Maximum number of frames to yield; ``0`` means no cap.
        skip_first_frames: Number of leading frames to drop, counted on the
            output timeline (after any ``force_rate`` resampling).
        select_every_nth: Keep every n-th frame after the skip (``1`` keeps all).

    Raises:
        ValueError: If ``select_every_nth`` is below 1 or the file is not WebP.
    """
    if select_every_nth < 1:
        raise ValueError("select_every_nth must be at least 1.")

    webp_path = strip_path(webp_path)
    print(f"Attempting to load WebP animation: {webp_path}")

    with Image.open(webp_path) as img:
        if img.format != "WEBP":
            raise ValueError(f"File {webp_path} is not a WebP file.")

        # Decode the first frame up front: Pillow's WebP plugin fills in
        # info['duration'] as frames are decoded, so reading it straight after
        # open() would always fall back to the default below.
        img.load()

        # Get metadata
        width, height = img.size
        total_frames = getattr(img, 'n_frames', 1)
        duration = getattr(img, 'info', {}).get('duration', 100) / 1000  # Default to 100ms if no duration
        if duration <= 0:
            # A zero duration cannot be turned into a rate; use 10 FPS and keep
            # the reported total duration consistent with that rate.
            duration = 0.1
        fps = 1 / duration

        print(f"WebP metadata: FPS={fps}, Width={width}, Height={height}, Total Frames={total_frames}")

        base_frame_time = 1 / fps
        target_frame_time = base_frame_time if force_rate == 0 else 1 / force_rate

        yield (width, height, fps, duration * total_frames, total_frames, target_frame_time)

        source_indices = _plan_source_frames(
            total_frames, fps, force_rate, frame_load_cap, skip_first_frames, select_every_nth
        )
        print(f"Expected yieldable frames: {len(source_indices)}")

        frames_added = 0
        for frame_idx in source_indices:
            img.seek(frame_idx)
            frame = np.array(img.convert('RGB'), dtype=np.float32) / 255.0
            yield frame
            frames_added += 1
            print(f"Frame {frames_added} added.")

        print(f"Total frames yielded: {frames_added}")
        if frames_added == 0:
            print("Warning: No frames were yielded from the WebP animation.")

def _lanczos_weights(in_size, out_size, dtype, device, lobes=3):
    """Build the (out_size, in_size) row-normalised Lanczos resampling matrix for one axis."""
    scale = in_size / out_size
    # When shrinking, widen the kernel by the scale factor so it also low-pass
    # filters (anti-aliases); when enlarging, use the kernel at its native width.
    stretch = max(scale, 1.0)
    out_centers = (torch.arange(out_size, dtype=torch.float64) + 0.5) * scale
    in_centers = torch.arange(in_size, dtype=torch.float64) + 0.5
    dist = (in_centers.unsqueeze(0) - out_centers.unsqueeze(1)) / stretch
    kernel = torch.sinc(dist) * torch.sinc(dist / lobes)
    kernel = torch.where(dist.abs() < lobes, kernel, torch.zeros_like(kernel))
    # Normalise each row so flat regions keep their value, including at the
    # borders where part of the kernel falls outside the image.
    kernel = kernel / kernel.sum(dim=1, keepdim=True)
    return kernel.to(dtype=dtype, device=device)


def common_upscale(samples, width, height, upscale_method="lanczos", crop="center"):
    """Resize a channels-last image batch to ``width`` x ``height``.

    Args:
        samples: Floating-point tensor laid out as ``(batch, height, width,
            channels)``.
        width: Target width in pixels.
        height: Target height in pixels.
        upscale_method: ``"lanczos"`` uses a separable 3-lobe Lanczos filter
            implemented here, because ``torch.nn.functional.interpolate`` has
            no such mode. Any other value is forwarded to ``interpolate`` as
            its ``mode``.
        crop: ``"center"`` trims the input symmetrically to the target aspect
            ratio before resizing, so the picture is not stretched. Any other
            value resizes the full input.

    Returns:
        Tensor of shape ``(batch, height, width, channels)``.
    """
    s = samples.movedim(-1, 1)  # Move channels to second dimension

    if crop == "center":
        old_height, old_width = s.shape[2], s.shape[3]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = y = 0
        if old_aspect > new_aspect:
            # Source is too wide: trim equal margins left and right.
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            # Source is too tall: trim equal margins top and bottom.
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = s[:, :, y:old_height - y, x:old_width - x]

    if upscale_method == "lanczos":
        row_weights = _lanczos_weights(s.shape[2], height, s.dtype, s.device)
        col_weights = _lanczos_weights(s.shape[3], width, s.dtype, s.device)
        resized = torch.matmul(row_weights, s)              # filter along height
        resized = torch.matmul(resized, col_weights.t())    # filter along width
        if s.numel() > 0:
            # Lanczos has negative lobes and can overshoot near sharp edges;
            # keep the result inside the value range of the input.
            resized = resized.clamp(min=s.min().item(), max=s.max().item())
        s = resized
    else:
        s = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

    return s.movedim(1, -1)  # Move channels back to last dimension

def load_webp_advanced(webp_path, force_rate, force_size, frame_load_cap, skip_first_frames, select_every_nth, memory_limit_mb=None):
    """Load a WebP animation into one image tensor plus descriptive metadata.

    Args:
        webp_path: Path of the WebP file to read.
        force_rate: Output frame rate passed to the frame generator (0 keeps
            the source rate).
        force_size: Size preset understood by ``target_size``; "Disabled"
            keeps the source dimensions.
        frame_load_cap: Maximum number of frames to load (0 means no cap).
        skip_first_frames: Number of leading frames to drop.
        select_every_nth: Keep every n-th frame.
        memory_limit_mb: Frame-buffer budget in MiB. ``None`` or a value <= 0
            derives the budget from the free RAM and swap reported by psutil,
            minus a 128 MiB reserve.

    Returns:
        ``(images, frame_count, video_info)`` where ``images`` is a float32
        tensor of shape ``(frames, height, width, 3)`` and ``video_info`` is a
        dict describing the source and the loaded clip.

    Raises:
        RuntimeError: If no frame could be loaded.
    """
    gen = webp_frame_generator(webp_path, force_rate, frame_load_cap, skip_first_frames, select_every_nth)
    metadata = next(gen)
    width, height, fps, duration, total_frames, target_frame_time = metadata
    print(f"Loaded metadata: {metadata}")

    # Memory limit calculation
    memory_limit = None
    if memory_limit_mb is not None and memory_limit_mb > 0:
        memory_limit = memory_limit_mb * (2 ** 20)  # Convert MB to bytes
    else:
        try:
            memory_limit = (psutil.virtual_memory().available + psutil.swap_memory().free) - (2 ** 27)
        except Exception:
            # Best effort only: without a usable reading, load without a limit.
            print("Warning: Failed to calculate memory limit.")

    if memory_limit is not None:
        bytes_per_frame = width * height * 3 * 4  # 3 channels, 4 bytes per float32
        # Clamp at zero: with less free memory than the reserve the budget is
        # negative, and islice() rejects a negative stop value.
        max_loadable_frames = max(0, int(memory_limit // bytes_per_frame))
        gen = itertools.islice(gen, max_loadable_frames)
        print(f"Applied memory limit: Max frames = {max_loadable_frames}")

    # Handle resizing
    downscale_ratio = 8
    new_size = (width, height)
    if force_size != "Disabled":
        new_size = target_size(width, height, force_size, downscale_ratio)
        if new_size[0] != width or new_size[1] != height:
            def rescale(frame):
                # common_upscale takes a channels-last batch, so a frame only
                # needs a leading batch axis: (H, W, C) -> (1, H, W, C).
                batch = torch.from_numpy(np.array(frame, dtype=np.float32)).unsqueeze(0)
                resized = common_upscale(batch, new_size[0], new_size[1], "lanczos", "center")
                return resized.squeeze(0).numpy()  # back to (H, W, C)
            gen = map(rescale, gen)
            print(f"Resizing frames to {new_size}")

    # Load frames into a tensor
    frame_dtype = np.dtype((np.float32, (new_size[1], new_size[0], 3)))
    images = torch.from_numpy(np.fromiter(gen, dtype=frame_dtype))
    if len(images) == 0:
        raise RuntimeError("No frames generated from the WebP animation.")

    # Video info dictionary
    video_info = {
        "source_fps": fps,
        "source_frame_count": total_frames,
        "source_duration": duration,
        "source_width": width,
        "source_height": height,
        "loaded_fps": 1 / (target_frame_time * select_every_nth),
        "loaded_frame_count": len(images),
        "loaded_duration": len(images) * target_frame_time * select_every_nth,
        "loaded_width": new_size[0],
        "loaded_height": new_size[1],
    }
    print(f"Loaded {len(images)} frames. Video info: {video_info}")

    return (images, len(images), video_info)

class LoadWebPAnimationAdvanced:
    """Node that loads an animated WebP file from the ``input`` directory.

    The class exposes the attributes and hooks of a ComfyUI-style node:
    ``INPUT_TYPES`` describes the inputs, ``FUNCTION`` names the method that
    does the work, and ``IS_CHANGED`` / ``VALIDATE_INPUTS`` are the
    change-detection and validation hooks.
    """

    CATEGORY = "Image Helper"
    RETURN_TYPES = ("IMAGE", "INT", "DICT")
    RETURN_NAMES = ("IMAGE", "frame_count", "video_info")
    FUNCTION = "load_webp"

    @staticmethod
    def _input_dir():
        """Return the ``input`` directory that sits beside this module's folder."""
        return os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "input")

    @classmethod
    def _resolve(cls, webp_file):
        """Return the path of ``webp_file`` inside the input directory."""
        return os.path.join(cls._input_dir(), strip_path(webp_file))

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs; ``webp_file`` lists the .webp files found."""
        input_dir = cls._input_dir()
        try:
            entries = os.listdir(input_dir)
        except OSError:
            # A missing or unreadable input directory should yield an empty
            # file list rather than an exception.
            entries = []
        files = sorted(
            f for f in entries
            if f.lower().endswith('.webp') and os.path.isfile(os.path.join(input_dir, f))
        )
        return {
            "required": {
                "webp_file": (files,),
                "force_rate": ("INT", {"default": 0, "min": 0, "max": 60, "step": 1}),
                "force_size": (["Disabled", "256x?", "?x256", "256x256", "512x?", "?x512", "512x512"],),
                "frame_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
                "skip_first_frames": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
                "select_every_nth": ("INT", {"default": 1, "min": 1, "max": BIGMAX, "step": 1}),
            },
            "optional": {
                "memory_limit_mb": ("INT", {"default": 0, "min": 0, "max": 1024*1024, "step": 1}),
            },
        }

    def load_webp(self, webp_file, force_rate, force_size, frame_load_cap, skip_first_frames, select_every_nth, memory_limit_mb=None):
        """Load the selected file and return ``(images, frame_count, video_info)``.

        Raises:
            ValueError: If the file does not exist or is not named ``*.webp``.
        """
        webp_path = self._resolve(webp_file)
        if not validate_path(webp_path):
            raise ValueError(f"Invalid WebP file path: {webp_path}")
        if not webp_path.lower().endswith('.webp'):
            raise ValueError("This node only supports .webp files.")

        return load_webp_advanced(
            webp_path=webp_path,
            force_rate=force_rate,
            force_size=force_size,
            frame_load_cap=frame_load_cap,
            skip_first_frames=skip_first_frames,
            select_every_nth=select_every_nth,
            memory_limit_mb=memory_limit_mb
        )

    @classmethod
    def IS_CHANGED(cls, webp_file, **kwargs):
        """Return a value that changes when the file path or its mtime changes."""
        webp_path = cls._resolve(webp_file)
        return hash(str(webp_path) + str(os.path.getmtime(webp_path) if os.path.exists(webp_path) else 0))

    @classmethod
    def VALIDATE_INPUTS(cls, webp_file, **kwargs):
        """Return True when the file is usable, otherwise an error message string."""
        webp_path = cls._resolve(webp_file)
        if not validate_path(webp_path):
            return f"Invalid WebP file path: {webp_path}"
        if not webp_path.lower().endswith('.webp'):
            return "Only .webp files are supported."
        return True

# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = dict(LoadWebPAnimationAdvanced=LoadWebPAnimationAdvanced)

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    LoadWebPAnimationAdvanced="Load WebP Animation (Advanced)",
)