# Listing metadata: file size 9,664 bytes; revision 5ba4011.
import os
import itertools
import numpy as np
import torch
from PIL import Image
import psutil
# Shared numeric limits (named to match ComfyUI conventions).
BIGMAX = 1 << 32  # 4294967296; used as the "max" of the open-ended INT inputs
DIMMAX = 1 << 14  # 16384
def strip_path(path):
    """Return *path* without surrounding whitespace or wrapping quotes.

    Whitespace is trimmed first, then any run of double quotes at either
    end, then any run of single quotes at either end (in that order).
    """
    cleaned = path.strip()
    for quote_char in ('"', "'"):
        cleaned = cleaned.strip(quote_char)
    return cleaned
def validate_path(path, allow_none=False):
    """Return True when *path* points at an existing regular file.

    Args:
        path: Filesystem path to check, or None.
        allow_none: When true, a None path is accepted and reported as valid.

    Returns:
        True if the path is an existing file (or is None and allowed),
        otherwise False.
    """
    if path is None:
        # os.path.isfile(None) raises TypeError; a missing path is simply
        # "not valid" unless the caller explicitly allows it.
        return bool(allow_none)
    return os.path.isfile(path)
def target_size(width, height, force_size, downscale_ratio=8):
    """Compute the output frame size for a ``force_size`` preset.

    ``force_size`` is either ``"Disabled"`` or a ``"<W>x<H>"`` pattern in
    which one side may be ``"?"``:

    * ``"256x?"`` fixes the width and scales the height proportionally,
    * ``"?x256"`` fixes the height and scales the width proportionally,
    * ``"256x256"`` forces both sides (aspect ratio is not preserved).

    Any other value leaves the size untouched. The result is then rounded
    to the nearest multiple of ``downscale_ratio``.

    Args:
        width: Source width in pixels.
        height: Source height in pixels.
        force_size: Preset string as described above.
        downscale_ratio: Both sides are snapped to a multiple of this value.

    Returns:
        A ``(width, height)`` tuple; each side is at least
        ``downscale_ratio``.

    Raises:
        ZeroDivisionError: If a proportional preset is used with a zero
            source width/height.
    """
    if isinstance(force_size, str) and force_size != "Disabled":
        forced_w, separator, forced_h = force_size.partition("x")
        if separator:
            if forced_w.isdecimal() and forced_h == "?":
                fixed = int(forced_w)
                height = int(height * fixed / width)
                width = fixed
            elif forced_w == "?" and forced_h.isdecimal():
                fixed = int(forced_h)
                width = int(width * fixed / height)
                height = fixed
            elif forced_w.isdecimal() and forced_h.isdecimal():
                width, height = int(forced_w), int(forced_h)

    def snap(value):
        # Round to the nearest multiple, but never collapse a side to 0:
        # a zero-sized dimension cannot be resized to or allocated.
        snapped = int(value / downscale_ratio + 0.5) * downscale_ratio
        return max(downscale_ratio, snapped)

    return (snap(width), snap(height))
def _iter_source_frame_indices(total_frames, frames_per_step, skip_first_frames, select_every_nth, frame_load_cap):
    """Yield the source frame index of every frame that should be loaded.

    Output frame ``k`` samples the animation ``k * frames_per_step`` source
    frames in (so a step of 1 is the native rate, >1 drops frames and <1
    repeats them). ``skip_first_frames`` and ``select_every_nth`` are applied
    to that resampled sequence, and iteration stops once ``frame_load_cap``
    frames were produced (0 means no cap).
    """
    select_every_nth = max(1, select_every_nth)  # guard the modulo below
    emitted = 0
    output_idx = 0
    while True:
        # The tiny epsilon keeps exact multiples (e.g. 3 * 0.1 / 0.1) from
        # being truncated to the previous frame by float rounding.
        source_idx = int(output_idx * frames_per_step + 1e-9)
        if source_idx >= total_frames:
            return
        position = output_idx - skip_first_frames
        output_idx += 1
        if position < 0 or position % select_every_nth != 0:
            continue
        yield source_idx
        emitted += 1
        if frame_load_cap > 0 and emitted >= frame_load_cap:
            return


def webp_frame_generator(webp_path, force_rate, frame_load_cap, skip_first_frames, select_every_nth):
    """Stream frames of a WebP animation as float32 RGB arrays.

    The first value yielded is a metadata tuple
    ``(width, height, fps, total_duration_s, total_frames, target_frame_time)``;
    every following value is one frame as an ``(H, W, 3)`` float32 array
    scaled to ``[0, 1]``.

    Args:
        webp_path: Path to the file; surrounding whitespace/quotes are removed.
        force_rate: Target frames per second; 0 (or less) keeps the source rate.
        frame_load_cap: Maximum number of frames to yield; 0 means unlimited.
        skip_first_frames: Number of (resampled) frames dropped from the start.
        select_every_nth: Keep only every n-th of the remaining frames.

    Raises:
        ValueError: If the file is not in WebP format.

    Note:
        The frame rate is derived from the first frame's duration and is
        assumed to be constant for the whole animation.
    """
    webp_path = strip_path(webp_path)
    print(f"Attempting to load WebP animation: {webp_path}")
    with Image.open(webp_path) as img:
        if img.format != "WEBP":
            raise ValueError(f"File {webp_path} is not a WebP file.")

        # NOTE(review): Pillow's WebP plugin appears to fill info['duration']
        # only while decoding a frame, so decode frame 0 before reading it;
        # otherwise the 100 ms fallback would always be used. Confirm against
        # the Pillow version in use.
        img.load()

        # Get metadata
        width, height = img.size
        total_frames = getattr(img, 'n_frames', 1)
        duration = (getattr(img, 'info', {}).get('duration') or 0) / 1000
        if duration <= 0:
            # Missing/zero duration: fall back to 100 ms (10 FPS) so the
            # reported fps and total duration stay consistent.
            duration = 0.1
        fps = 1 / duration
        print(f"WebP metadata: FPS={fps}, Width={width}, Height={height}, Total Frames={total_frames}")

        base_frame_time = 1 / fps
        resample = force_rate > 0
        target_frame_time = 1 / force_rate if resample else base_frame_time
        yield (width, height, fps, duration * total_frames, total_frames, target_frame_time)

        yieldable_frames = int(total_frames / fps * force_rate) if resample else total_frames
        if frame_load_cap != 0:
            yieldable_frames = min(frame_load_cap, yieldable_frames)
        print(f"Expected yieldable frames: {yieldable_frames}")

        # Source frames advanced per output frame (exactly 1 at native rate).
        frames_per_step = target_frame_time / base_frame_time if resample else 1

        frames_added = 0
        for source_idx in _iter_source_frame_indices(
            total_frames, frames_per_step, skip_first_frames, select_every_nth, frame_load_cap
        ):
            img.seek(source_idx)
            # convert() already returns a new image, so no extra copy needed.
            frame = np.array(img.convert('RGB'), dtype=np.float32) / 255.0
            yield frame
            frames_added += 1
            print(f"Frame {frames_added} added.")

        print(f"Total frames yielded: {frames_added}")
        if frames_added == 0:
            print("Warning: No frames were yielded from the WebP animation.")
def _lanczos_resize(samples, width, height):
    """Resize a channels-first ``(B, C, H, W)`` float tensor with Pillow's Lanczos filter.

    Each sample is converted to an 8-bit image, resized and converted back
    to floats in ``[0, 1]``, so values are quantised to 1/255 steps.
    """
    from PIL import Image  # local import: only needed on the Lanczos path

    if samples.shape[0] == 0:
        return samples.new_empty((0, samples.shape[1], height, width))

    resized = []
    for sample in samples:
        pixels = sample.detach().movedim(0, -1).cpu().numpy()
        # Round (not truncate) so values such as 254.99998 map back to 255.
        pixels = np.clip(np.rint(255.0 * pixels), 0, 255).astype(np.uint8)
        if pixels.shape[-1] == 1:
            pixels = pixels[..., 0]  # Pillow expects (H, W) for single-channel data
        image = Image.fromarray(pixels).resize((width, height), resample=Image.LANCZOS)
        array = np.asarray(image, dtype=np.float32) / 255.0
        if array.ndim == 2:
            array = array[..., None]
        resized.append(torch.from_numpy(array).movedim(-1, 0))
    return torch.stack(resized).to(samples.device, samples.dtype)


def common_upscale(samples, width, height, upscale_method="lanczos", crop="center"):
    """Resize a channels-last image batch to ``width`` x ``height``.

    Args:
        samples: Tensor whose last dimension holds the channels; it is moved
            to dimension 1 for resizing, i.e. a ``(B, H, W, C)`` batch.
        width: Target width in pixels.
        height: Target height in pixels.
        upscale_method: ``"lanczos"`` (resized through Pillow, because
            ``torch.nn.functional.interpolate`` has no Lanczos mode) or any
            mode accepted by ``torch.nn.functional.interpolate``.
        crop: Accepted for signature compatibility; currently ignored, so the
            image is stretched when the aspect ratio changes.

    Returns:
        The resized batch, channels-last again.
    """
    channels_first = samples.movedim(-1, 1)  # Move channels to second dimension
    if upscale_method == "lanczos":
        resized = _lanczos_resize(channels_first, width, height)
    else:
        resized = torch.nn.functional.interpolate(
            channels_first, size=(height, width), mode=upscale_method
        )
    return resized.movedim(1, -1)  # Move channels back to last dimension
def load_webp_advanced(webp_path, force_rate, force_size, frame_load_cap, skip_first_frames, select_every_nth, memory_limit_mb=None):
    """Load a WebP animation into a single ``(N, H, W, 3)`` float32 tensor.

    Args:
        webp_path: Path of the WebP file.
        force_rate: Target frames per second; 0 keeps the source rate.
        force_size: ``"Disabled"`` or a size preset understood by
            ``target_size``.
        frame_load_cap: Maximum number of frames to load; 0 means unlimited.
        skip_first_frames: Number of frames dropped from the start.
        select_every_nth: Keep only every n-th frame.
        memory_limit_mb: Frame-buffer budget in MiB. When None or not
            positive, the budget is the free RAM + swap reported by psutil
            minus 128 MiB.

    Returns:
        ``(images, frame_count, video_info)`` where ``video_info`` describes
        both the source and the loaded clip.

    Raises:
        RuntimeError: If no frame could be loaded.
    """
    frame_source = webp_frame_generator(webp_path, force_rate, frame_load_cap, skip_first_frames, select_every_nth)
    try:
        metadata = next(frame_source)
        width, height, fps, duration, total_frames, target_frame_time = metadata
        print(f"Loaded metadata: {metadata}")

        # Decide the output size first so the memory budget can account for it.
        downscale_ratio = 8
        if force_size != "Disabled":
            new_size = target_size(width, height, force_size, downscale_ratio)
        else:
            new_size = (width, height)
        needs_resize = new_size[0] != width or new_size[1] != height

        # Memory limit calculation
        memory_limit = None
        if memory_limit_mb is not None and memory_limit_mb > 0:
            memory_limit = memory_limit_mb * (2 ** 20)  # Convert MB to bytes
        else:
            try:
                memory_limit = (psutil.virtual_memory().available + psutil.swap_memory().free) - (2 ** 27)
            except Exception:
                # Best effort only: without a measurement, load without a cap.
                print("Warning: Failed to calculate memory limit.")

        gen = frame_source
        if memory_limit is not None:
            # Budget by the larger of the source and output frame
            # (3 channels, 4 bytes per float32).
            pixels_per_frame = max(width * height, new_size[0] * new_size[1], 1)
            bytes_per_frame = pixels_per_frame * 3 * 4
            # islice rejects negative counts, which a tiny budget would give.
            max_loadable_frames = max(0, int(memory_limit // bytes_per_frame))
            gen = itertools.islice(gen, max_loadable_frames)
            print(f"Applied memory limit: Max frames = {max_loadable_frames}")

        # Handle resizing
        if needs_resize:
            def rescale(frame):
                # common_upscale works on channels-last batches, so only a
                # batch dimension has to be added: (H, W, C) -> (1, H, W, C).
                batch = torch.from_numpy(np.asarray(frame, dtype=np.float32)).unsqueeze(0)
                resized = common_upscale(batch, new_size[0], new_size[1], "lanczos", "center")
                return resized.squeeze(0).numpy()

            gen = map(rescale, gen)
            print(f"Resizing frames to {new_size}")

        # Load frames into a tensor
        frame_dtype = np.dtype((np.float32, (new_size[1], new_size[0], 3)))
        images = torch.from_numpy(np.fromiter(gen, dtype=frame_dtype))
    finally:
        # Release the image file even when loading stopped early or failed.
        frame_source.close()

    if len(images) == 0:
        raise RuntimeError("No frames generated from the WebP animation.")

    # Video info dictionary
    video_info = {
        "source_fps": fps,
        "source_frame_count": total_frames,
        "source_duration": duration,
        "source_width": width,
        "source_height": height,
        "loaded_fps": 1 / (target_frame_time * select_every_nth),
        "loaded_frame_count": len(images),
        "loaded_duration": len(images) * target_frame_time * select_every_nth,
        "loaded_width": new_size[0],
        "loaded_height": new_size[1],
    }
    print(f"Loaded {len(images)} frames. Video info: {video_info}")
    return (images, len(images), video_info)
class LoadWebPAnimationAdvanced:
    """Node that loads a WebP animation from the sibling ``input`` directory.

    The class exposes the attributes a ComfyUI-style node is expected to
    provide (``INPUT_TYPES``, ``RETURN_TYPES``, ``FUNCTION``, ...); the actual
    loading is delegated to ``load_webp_advanced``.
    """

    CATEGORY = "Image Helper"
    RETURN_TYPES = ("IMAGE", "INT", "DICT")
    RETURN_NAMES = ("IMAGE", "frame_count", "video_info")
    FUNCTION = "load_webp"

    @staticmethod
    def _input_dir():
        """Return the ``input`` directory located next to this file's parent."""
        return os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "input")

    @classmethod
    def _resolve(cls, webp_file):
        """Return the path of *webp_file* inside the input directory."""
        return os.path.join(cls._input_dir(), strip_path(webp_file))

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs, listing the ``.webp`` files available."""
        input_dir = cls._input_dir()
        try:
            entries = os.listdir(input_dir)
        except OSError:
            # A missing/unreadable input directory must not prevent the node
            # from being registered; it simply offers no files.
            entries = []
        files = sorted(
            name for name in entries
            if name.lower().endswith('.webp') and os.path.isfile(os.path.join(input_dir, name))
        )
        return {
            "required": {
                "webp_file": (files,),
                "force_rate": ("INT", {"default": 0, "min": 0, "max": 60, "step": 1}),
                "force_size": (["Disabled", "256x?", "?x256", "256x256", "512x?", "?x512", "512x512"],),
                "frame_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
                "skip_first_frames": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
                "select_every_nth": ("INT", {"default": 1, "min": 1, "max": BIGMAX, "step": 1}),
            },
            "optional": {
                "memory_limit_mb": ("INT", {"default": 0, "min": 0, "max": 1024*1024, "step": 1}),
            },
        }

    def load_webp(self, webp_file, force_rate, force_size, frame_load_cap, skip_first_frames, select_every_nth, memory_limit_mb=None):
        """Validate *webp_file* and load it via ``load_webp_advanced``.

        Raises:
            ValueError: If the file does not exist or is not a ``.webp`` file.
        """
        webp_path = self._resolve(webp_file)
        if not validate_path(webp_path):
            raise ValueError(f"Invalid WebP file path: {webp_path}")
        if not webp_path.lower().endswith('.webp'):
            raise ValueError("This node only supports .webp files.")
        return load_webp_advanced(
            webp_path=webp_path,
            force_rate=force_rate,
            force_size=force_size,
            frame_load_cap=frame_load_cap,
            skip_first_frames=skip_first_frames,
            select_every_nth=select_every_nth,
            memory_limit_mb=memory_limit_mb,
        )

    @classmethod
    def IS_CHANGED(cls, webp_file, **kwargs):
        """Return a value that changes when the file path or its mtime changes."""
        webp_path = cls._resolve(webp_file)
        try:
            modified = os.path.getmtime(webp_path)
        except OSError:
            # Missing file (or one removed between checks) counts as mtime 0.
            modified = 0
        return hash(str(webp_path) + str(modified))

    @classmethod
    def VALIDATE_INPUTS(cls, webp_file, **kwargs):
        """Return True when *webp_file* is usable, otherwise an error message."""
        webp_path = cls._resolve(webp_file)
        if not validate_path(webp_path):
            return f"Invalid WebP file path: {webp_path}"
        if not webp_path.lower().endswith('.webp'):
            return "Only .webp files are supported."
        return True
# Registration tables: node identifier -> implementing class, and
# node identifier -> human-readable display name.
NODE_CLASS_MAPPINGS = {
    "LoadWebPAnimationAdvanced": LoadWebPAnimationAdvanced,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadWebPAnimationAdvanced": "Load WebP Animation (Advanced)",
}