diff --git a/ComfyUI/comfy_extras/chainner_models/model_loading.py b/ComfyUI/comfy_extras/chainner_models/model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..b97f9db365d78ea242712798bed61b0a1765a4af --- /dev/null +++ b/ComfyUI/comfy_extras/chainner_models/model_loading.py @@ -0,0 +1,6 @@ +import logging +from spandrel import ModelLoader + +def load_state_dict(state_dict): + logging.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.") + return ModelLoader().load_from_state_dict(state_dict).eval() diff --git a/ComfyUI/comfy_extras/frame_interpolation_models/film_net.py b/ComfyUI/comfy_extras/frame_interpolation_models/film_net.py new file mode 100644 index 0000000000000000000000000000000000000000..fc86f97bc20abf9b33d18a54a6dbab3030b67e1b --- /dev/null +++ b/ComfyUI/comfy_extras/frame_interpolation_models/film_net.py @@ -0,0 +1,258 @@ +"""FILM: Frame Interpolation for Large Motion (ECCV 2022).""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +class FilmConv2d(nn.Module): + """Conv2d with optional LeakyReLU and FILM-style padding.""" + + def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops): + super().__init__() + self.even_pad = not size % 2 + self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype) + self.activation = nn.LeakyReLU(0.2) if activation else None + + def forward(self, x): + if self.even_pad: + x = F.pad(x, (0, 1, 0, 1)) + x = self.conv(x) + if self.activation is not None: + x = self.activation(x) + return x + + +def _warp_core(image, flow, grid_x, grid_y): + dtype = image.dtype + H, W = flow.shape[2], flow.shape[3] + dx = flow[:, 0].float() / (W * 0.5) + dy = flow[:, 1].float() / (H * 0.5) + grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3) + return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype) + + +def build_image_pyramid(image, pyramid_levels): + pyramid = [image] + for _ in range(1, pyramid_levels): + image = F.avg_pool2d(image, 2, 2) + pyramid.append(image) + return pyramid + + +def flow_pyramid_synthesis(residual_pyramid): + flow = residual_pyramid[-1] + flow_pyramid = [flow] + for residual_flow in residual_pyramid[:-1][::-1]: + flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow) + flow_pyramid.append(flow) + flow_pyramid.reverse() + return flow_pyramid + + +def multiply_pyramid(pyramid, scalar): + return [image * scalar[:, None, None, None] for image in pyramid] + + +def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn): + return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)] + + +def concatenate_pyramids(pyramid1, pyramid2): + return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)] + + +class SubTreeExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops): + super().__init__() + convs = [] + for i in range(n_layers): + out_ch = channels << i + convs.append(nn.Sequential( + FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations))) + in_channels = out_ch + self.convs = nn.ModuleList(convs) + + def forward(self, image, n): + head = image + pyramid = [] + for i, layer in enumerate(self.convs): + head = layer(head) + pyramid.append(head) + if i < n - 1: + head = F.avg_pool2d(head, 2, 2) + return pyramid + + +class FeatureExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops): + super().__init__() + self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations) + self.sub_levels = sub_levels + + def forward(self, image_pyramid): + sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels)) + for i in range(len(image_pyramid))] + feature_pyramid = [] + for i in range(len(image_pyramid)): + features = sub_pyramids[i][0] + for j in range(1, self.sub_levels): + if j <= i: + features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) + feature_pyramid.append(features) + # Free sub-pyramids no longer needed by future levels + if i >= self.sub_levels - 1: + sub_pyramids[i - self.sub_levels + 1] = None + return feature_pyramid + + +class FlowEstimator(nn.Module): + def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops): + super().__init__() + self._convs = nn.ModuleList() + for _ in range(num_convs): + self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations)) + in_channels = num_filters + self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations)) + self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations)) + + def forward(self, features_a, features_b): + net = torch.cat([features_a, features_b], dim=1) + for conv in self._convs: + net = conv(net) + return net + + +class PyramidFlowEstimator(nn.Module): + def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + in_channels = filters << 1 + predictors = [] + for i in range(len(flow_convs)): + predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations)) + in_channels += filters << (i + 2) + self._predictor = predictors[-1] + self._predictors = nn.ModuleList(predictors[:-1][::-1]) + + def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn): + levels = len(feature_pyramid_a) + v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1]) + residuals = [v] + # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels + steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)] + steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)] + for i, predictor in steps: + v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2) + v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v)) + residuals.append(v_residual) + v = v.add_(v_residual) + residuals.reverse() + return residuals + + +def _get_fusion_channels(level, filters): + # Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions + return (sum(filters << i for i in range(level)) + 3 + 2) * 2 + + +class Fusion(nn.Module): + def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops): + super().__init__() + self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype) + self.convs = nn.ModuleList() + in_channels = _get_fusion_channels(n_layers, filters) + increase = 0 + for i in range(n_layers)[::-1]: + num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers) + self.convs.append(nn.ModuleList([ + FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations), + FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)])) + in_channels = num_filters + increase = _get_fusion_channels(i, filters) - num_filters // 2 + + def forward(self, pyramid): + net = pyramid[-1] + for k, layers in enumerate(self.convs): + i = len(self.convs) - 1 - k + net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest")) + net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1))) + return self.output_conv(net) + + +class FILMNet(nn.Module): + def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4, + filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + self.pyramid_levels = pyramid_levels + self.fusion_pyramid_levels = fusion_pyramid_levels + self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations) + self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations) + self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations) + self._warp_grids = {} + + def get_dtype(self): + return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype + + def _build_warp_grids(self, H, W, device): + """Pre-compute warp grids for all pyramid levels.""" + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + for _ in range(self.pyramid_levels): + self._warp_grids[(H, W)] = ( + torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device), + torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device), + ) + H, W = H // 2, W // 2 + + def warp(self, image, flow): + grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])] + return _warp_core(image, flow, grid_x, grid_y) + + def extract_features(self, img): + """Extract image and feature pyramids for a single frame. Can be cached across pairs.""" + image_pyramid = build_image_pyramid(img, self.pyramid_levels) + feature_pyramid = self.extract(image_pyramid) + return image_pyramid, feature_pyramid + + def forward(self, img0, img1, timestep=0.5, cache=None): + # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported) + t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep + return self.forward_multi_timestep(img0, img1, [t], cache=cache) + + def forward_multi_timestep(self, img0, img1, timesteps, cache=None): + """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs.""" + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0) + image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1) + + fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels] + bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels] + + # Build warp targets and free full pyramids (only first fpl levels needed from here) + fpl = self.fusion_pyramid_levels + p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]), + concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])] + del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1 + + results = [] + dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype) + for idx in range(len(timesteps)): + batch_dt = dt_tensors[idx:idx + 1] + bwd_scaled = multiply_pyramid(bwd_flow, batch_dt) + fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt) + fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp) + bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp) + aligned = [torch.cat([fw, bw, bf, ff], dim=1) + for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)] + del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled + results.append(self.fuse(aligned)) + del aligned + return torch.cat(results, dim=0) diff --git a/ComfyUI/comfy_extras/frame_interpolation_models/ifnet.py b/ComfyUI/comfy_extras/frame_interpolation_models/ifnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6bcfb851f1720163125a81e4ecfde5678ed8b739 --- /dev/null +++ b/ComfyUI/comfy_extras/frame_interpolation_models/ifnet.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +def _warp(img, flow, warp_grids): + B, _, H, W = img.shape + base_grid, flow_div = warp_grids[(H, W)] + flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float() + grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1) + return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype) + + +class Head(nn.Module): + def __init__(self, out_ch=4, device=None, dtype=None, operations=ops): + super().__init__() + self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype) + self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + x = self.relu(self.cnn0(x)) + x = self.relu(self.cnn1(x)) + x = self.relu(self.cnn2(x)) + return self.cnn3(x) + + +class ResConv(nn.Module): + def __init__(self, c, device=None, dtype=None, operations=ops): + super().__init__() + self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype)) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(torch.addcmul(x, self.conv(x), self.beta)) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops): + super().__init__() + self.conv0 = nn.Sequential( + nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)), + nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True))) + self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8))) + self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2)) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale) + x = torch.cat((x, flow), 1) + feat = self.convblock(self.conv0(x)) + tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear") + return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:] + + +class IFNet(nn.Module): + def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops): + super().__init__() + self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations) + block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4 + self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)]) + self.scale_list = [16, 8, 4, 2, 1] + self.pad_align = 64 + self._warp_grids = {} + + def get_dtype(self): + return self.encode.cnn0.weight.dtype + + def _build_warp_grids(self, H, W, device): + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + grid_y, grid_x = torch.meshgrid( + torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32), + torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij") + self._warp_grids[(H, W)] = ( + torch.stack((grid_x, grid_y), dim=0).unsqueeze(0), + torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device)) + + def warp(self, img, flow): + return _warp(img, flow, self._warp_grids) + + def extract_features(self, img): + """Extract head features for a single frame. Can be cached across pairs.""" + return self.encode(img) + + def forward(self, img0, img1, timestep=0.5, cache=None): + if not isinstance(timestep, torch.Tensor): + timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype) + + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + B = img0.shape[0] + f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0) + f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1) + flow = mask = feat = None + warped_img0, warped_img1 = img0, img1 + for i, block in enumerate(self.blocks): + if flow is None: + flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i]) + else: + fd, mask, feat = block( + torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1), + flow, scale=self.scale_list[i]) + flow = flow.add_(fd) + warped_img0 = self.warp(img0, flow[:, :2]) + warped_img1 = self.warp(img1, flow[:, 2:4]) + return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask)) + + +def detect_rife_config(state_dict): + head_ch = state_dict["encode.cnn3.weight"].shape[1] # ConvTranspose2d: (in_ch, out_ch, kH, kW) + channels = [] + for i in range(5): + key = f"blocks.{i}.conv0.1.0.weight" + if key in state_dict: + channels.append(state_dict[key].shape[0]) + if len(channels) != 5: + raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}") + return head_ch, channels diff --git a/ComfyUI/comfy_extras/nodes_ace.py b/ComfyUI/comfy_extras/nodes_ace.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0f1a66b7fb231a802d9332a01831979c963619 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_ace.py @@ -0,0 +1,145 @@ +import torch +from typing_extensions import override + +import comfy.model_management +import node_helpers +from comfy_api.latest import ComfyExtension, IO + + +class TextEncodeAceStepAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TextEncodeAceStepAudio", + category="conditioning", + inputs=[ + IO.Clip.Input("clip"), + IO.String.Input("tags", multiline=True, dynamic_prompts=True), + IO.String.Input("lyrics", multiline=True, dynamic_prompts=True), + IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[IO.Conditioning.Output()], + ) + + @classmethod + def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput: + tokens = clip.tokenize(tags, lyrics=lyrics) + conditioning = clip.encode_from_tokens_scheduled(tokens) + conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) + return IO.NodeOutput(conditioning) + +class TextEncodeAceStepAudio15(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TextEncodeAceStepAudio1.5", + category="conditioning", + inputs=[ + IO.Clip.Input("clip"), + IO.String.Input("tags", multiline=True, dynamic_prompts=True), + IO.String.Input("lyrics", multiline=True, dynamic_prompts=True), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + IO.Int.Input("bpm", default=120, min=10, max=300), + IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), + IO.Combo.Input("timesignature", options=['2', '3', '4', '6']), + IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), + IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), + IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True), + IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), + IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), + IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), + IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True), + IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True), + ], + outputs=[IO.Conditioning.Output()], + ) + + @classmethod + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput: + tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p) + conditioning = clip.encode_from_tokens_scheduled(tokens) + return IO.NodeOutput(conditioning) + + +class EmptyAceStepLatentAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyAceStepLatentAudio", + display_name="Empty Ace Step 1.0 Latent Audio", + category="latent/audio", + inputs=[ + IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), + IO.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[IO.Latent.Output()], + ) + + @classmethod + def execute(cls, seconds, batch_size) -> IO.NodeOutput: + length = int(seconds * 44100 / 512 / 8) + latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) + return IO.NodeOutput({"samples": latent, "type": "audio"}) + + +class EmptyAceStep15LatentAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyAceStep1.5LatentAudio", + display_name="Empty Ace Step 1.5 Latent Audio", + category="latent/audio", + inputs=[ + IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), + IO.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[IO.Latent.Output()], + ) + + @classmethod + def execute(cls, seconds, batch_size) -> IO.NodeOutput: + length = round((seconds * 48000 / 1920)) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) + return IO.NodeOutput({"samples": latent, "type": "audio"}) + +class ReferenceAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReferenceTimbreAudio", + display_name="Reference Audio", + category="advanced/conditioning/audio", + is_experimental=True, + description="This node sets the reference audio for ace step 1.5", + inputs=[ + IO.Conditioning.Input("conditioning"), + IO.Latent.Input("latent", optional=True), + ], + outputs=[ + IO.Conditioning.Output(), + ] + ) + + @classmethod + def execute(cls, conditioning, latent=None) -> IO.NodeOutput: + if latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True) + return IO.NodeOutput(conditioning) + +class AceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TextEncodeAceStepAudio, + EmptyAceStepLatentAudio, + TextEncodeAceStepAudio15, + EmptyAceStep15LatentAudio, + ReferenceAudio, + ] + +async def comfy_entrypoint() -> AceExtension: + return AceExtension() diff --git a/ComfyUI/comfy_extras/nodes_advanced_samplers.py b/ComfyUI/comfy_extras/nodes_advanced_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..fc06217ddea0b4d092bba9c6c892c243e3550d76 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_advanced_samplers.py @@ -0,0 +1,121 @@ +import numpy as np +import torch +from tqdm.auto import trange +from typing_extensions import override + +import comfy.model_patcher +import comfy.samplers +import comfy.utils +from comfy.k_diffusion.sampling import to_d +from comfy_api.latest import ComfyExtension, io + + +@torch.no_grad() +def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable=None, total_upscale=2.0, upscale_method="bislerp", upscale_steps=None): + extra_args = {} if extra_args is None else extra_args + + if upscale_steps is None: + upscale_steps = max(len(sigmas) // 2 + 1, 2) + else: + upscale_steps += 1 + upscale_steps = min(upscale_steps, len(sigmas) + 1) + + upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:] + + orig_shape = x.size() + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + x = denoised + if i < len(upscales): + x = comfy.utils.common_upscale(x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled") + + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * torch.randn_like(x) + return x + + +class SamplerLCMUpscale(io.ComfyNode): + UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerLCMUpscale", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True), + io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True), + io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS), + ], + outputs=[io.Sampler.Output()], + ) + + @classmethod + def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput: + if scale_steps < 0: + scale_steps = None + sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method}) + return io.NodeOutput(sampler) + + +@torch.no_grad() +def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): + extra_args = {} if extra_args is None else extra_args + + temp = [0] + def post_cfg_function(args): + temp[0] = args["uncond_denoised"] + return args["denoised"] + + model_options = extra_args.get("model_options", {}).copy() + extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + sigma_hat = sigmas[i] + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x - denoised + temp[0], sigmas[i], denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + x = x + d * dt + return x + + +class SamplerEulerCFGpp(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerEulerCFGpp", + display_name="SamplerEulerCFG++", + category="_for_testing", # "sampling/custom_sampling/samplers" + inputs=[ + io.Combo.Input("version", options=["regular", "alternative"], advanced=True), + ], + outputs=[io.Sampler.Output()], + is_experimental=True, + ) + + @classmethod + def execute(cls, version) -> io.NodeOutput: + if version == "alternative": + sampler = comfy.samplers.KSAMPLER(sample_euler_pp) + else: + sampler = comfy.samplers.ksampler("euler_cfg_pp") + return io.NodeOutput(sampler) + + +class AdvancedSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerLCMUpscale, + SamplerEulerCFGpp, + ] + +async def comfy_entrypoint() -> AdvancedSamplersExtension: + return AdvancedSamplersExtension() diff --git a/ComfyUI/comfy_extras/nodes_align_your_steps.py b/ComfyUI/comfy_extras/nodes_align_your_steps.py new file mode 100644 index 0000000000000000000000000000000000000000..e83f3bb362ab3549595877086e2a6f633404febd --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_align_your_steps.py @@ -0,0 +1,70 @@ +#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html +import numpy as np +import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + +NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582], + "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], + "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]} + +class AlignYourStepsScheduler(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AlignYourStepsScheduler", + search_aliases=["AYS scheduler"], + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]), + io.Int.Input("steps", default=10, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()], + ) + + def get_sigmas(self, model_type, steps, denoise): + # Deprecated: use the V3 schema's `execute` method instead of this. + return AlignYourStepsScheduler().execute(model_type, steps, denoise).result + + @classmethod + def execute(cls, model_type, steps, denoise) -> io.NodeOutput: + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return io.NodeOutput(torch.FloatTensor([])) + total_steps = round(steps * denoise) + + sigmas = NOISE_LEVELS[model_type][:] + if (steps + 1) != len(sigmas): + sigmas = loglinear_interp(sigmas, steps + 1) + + sigmas = sigmas[-(total_steps + 1):] + sigmas[-1] = 0 + return io.NodeOutput(torch.FloatTensor(sigmas)) + + +class AlignYourStepsExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + AlignYourStepsScheduler, + ] + +async def comfy_entrypoint() -> AlignYourStepsExtension: + return AlignYourStepsExtension() diff --git a/ComfyUI/comfy_extras/nodes_apg.py b/ComfyUI/comfy_extras/nodes_apg.py new file mode 100644 index 0000000000000000000000000000000000000000..a7bfd5231c6477950ac27589b06a771b2bda12c3 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_apg.py @@ -0,0 +1,110 @@ +import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +def project(v0, v1): + v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) + v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + return v0_parallel, v0_orthogonal + +class APG(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="APG", + display_name="Adaptive Projected Guidance", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "eta", + default=1.0, + min=-10.0, + max=10.0, + step=0.01, + tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.", + advanced=True, + ), + io.Float.Input( + "norm_threshold", + default=5.0, + min=0.0, + max=50.0, + step=0.1, + tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.", + advanced=True, + ), + io.Float.Input( + "momentum", + default=0.0, + min=-5.0, + max=1.0, + step=0.01, + tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.", + advanced=True, + ), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput: + running_avg = 0 + prev_sigma = None + + def pre_cfg_function(args): + nonlocal running_avg, prev_sigma + + if len(args["conds_out"]) == 1: + return args["conds_out"] + + cond = args["conds_out"][0] + uncond = args["conds_out"][1] + sigma = args["sigma"][0] + cond_scale = args["cond_scale"] + + if prev_sigma is not None and sigma > prev_sigma: + running_avg = 0 + prev_sigma = sigma + + guidance = cond - uncond + + if momentum != 0: + if not torch.is_tensor(running_avg): + running_avg = guidance + else: + running_avg = momentum * running_avg + guidance + guidance = running_avg + + if norm_threshold > 0: + guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True) + scale = torch.minimum( + torch.ones_like(guidance_norm), + norm_threshold / guidance_norm + ) + guidance = guidance * scale + + guidance_parallel, guidance_orthogonal = project(guidance, cond) + modified_guidance = guidance_orthogonal + eta * guidance_parallel + + modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale + + return [modified_cond, uncond] + args["conds_out"][2:] + + m = model.clone() + m.set_model_sampler_pre_cfg_function(pre_cfg_function) + return io.NodeOutput(m) + + +class ApgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + APG, + ] + +async def comfy_entrypoint() -> ApgExtension: + return ApgExtension() diff --git a/ComfyUI/comfy_extras/nodes_attention_multiply.py b/ComfyUI/comfy_extras/nodes_attention_multiply.py new file mode 100644 index 0000000000000000000000000000000000000000..ff53f28f06363015c0d9798d91da6f02c4135e9a --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_attention_multiply.py @@ -0,0 +1,151 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +def attention_multiply(attn, model, q, k, v, out): + m = model.clone() + sd = model.model_state_dict() + + for key in sd: + if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, q) + if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, k) + if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, v) + if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, out) + + return m + + +class UNetSelfAttentionMultiply(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetSelfAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: + m = attention_multiply("attn1", model, q, k, v, out) + return io.NodeOutput(m) + + +class UNetCrossAttentionMultiply(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetCrossAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: + m = attention_multiply("attn2", model, q, k, v, out) + return io.NodeOutput(m) + + +class CLIPAttentionMultiply(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CLIPAttentionMultiply", + search_aliases=["clip attention scale", "text encoder attention"], + category="_for_testing/attention_experiments", + inputs=[ + io.Clip.Input("clip"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + ], + outputs=[io.Clip.Output()], + is_experimental=True, + ) + + @classmethod + def execute(cls, clip, q, k, v, out) -> io.NodeOutput: + m = clip.clone() + sd = m.patcher.model_state_dict() + + for key in sd: + if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"): + m.add_patches({key: (None,)}, 0.0, q) + if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"): + m.add_patches({key: (None,)}, 0.0, k) + if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"): + m.add_patches({key: (None,)}, 0.0, v) + if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"): + m.add_patches({key: (None,)}, 0.0, out) + return io.NodeOutput(m) + + +class UNetTemporalAttentionMultiply(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetTemporalAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput: + m = model.clone() + sd = model.model_state_dict() + + for k in sd: + if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")): + if '.time_stack.' in k: + m.add_patches({k: (None,)}, 0.0, self_temporal) + else: + m.add_patches({k: (None,)}, 0.0, self_structural) + elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")): + if '.time_stack.' in k: + m.add_patches({k: (None,)}, 0.0, cross_temporal) + else: + m.add_patches({k: (None,)}, 0.0, cross_structural) + return io.NodeOutput(m) + + +class AttentionMultiplyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + UNetSelfAttentionMultiply, + UNetCrossAttentionMultiply, + CLIPAttentionMultiply, + UNetTemporalAttentionMultiply, + ] + +async def comfy_entrypoint() -> AttentionMultiplyExtension: + return AttentionMultiplyExtension() diff --git a/ComfyUI/comfy_extras/nodes_audio.py b/ComfyUI/comfy_extras/nodes_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..d5139207491d393be1b213eb0cc632db778bc199 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_audio.py @@ -0,0 +1,794 @@ +from __future__ import annotations + +import av +import torchaudio +import torch +import comfy.model_management +import folder_paths +import os +import hashlib +import node_helpers +import logging +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI + +class EmptyLatentAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentAudio", + display_name="Empty Latent Audio", + category="latent/audio", + essentials_category="Audio", + inputs=[ + IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), + IO.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch.", + ), + ], + outputs=[IO.Latent.Output()], + ) + + @classmethod + def execute(cls, seconds, batch_size) -> IO.NodeOutput: + length = round((seconds * 44100 / 2048) / 2) * 2 + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return IO.NodeOutput({"samples":latent, "type": "audio"}) + + generate = execute # TODO: remove + + +class ConditioningStableAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ConditioningStableAudio", + category="conditioning", + inputs=[ + IO.Conditioning.Input("positive"), + IO.Conditioning.Input("negative"), + IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1), + IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput: + positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total}) + negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total}) + return IO.NodeOutput(positive, negative) + + append = execute # TODO: remove + + +class VAEEncodeAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEEncodeAudio", + search_aliases=["audio to latent"], + display_name="VAE Encode Audio", + category="latent/audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Latent.Output()], + ) + + @classmethod + def execute(cls, vae, audio) -> IO.NodeOutput: + sample_rate = audio["sample_rate"] + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + if vae_sample_rate != sample_rate: + waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate) + else: + waveform = audio["waveform"] + + t = vae.encode(waveform.movedim(1, -1)) + return IO.NodeOutput({"samples": t}) + + encode = execute # TODO: remove + + +def vae_decode_audio(vae, samples, tile=None, overlap=None): + if tile is not None: + audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1) + else: + audio = vae.decode(samples["samples"]).movedim(-1, 1) + + std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 + std[std < 1.0] = 1.0 + audio /= std + vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100)) + return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} + + +class VAEDecodeAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeAudio", + search_aliases=["latent to audio"], + display_name="VAE Decode Audio", + category="latent/audio", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, vae, samples) -> IO.NodeOutput: + return IO.NodeOutput(vae_decode_audio(vae, samples)) + + decode = execute # TODO: remove + + +class VAEDecodeAudioTiled(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeAudioTiled", + search_aliases=["latent to audio"], + display_name="VAE Decode Audio (Tiled)", + category="latent/audio", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8), + IO.Int.Input("overlap", default=64, min=0, max=1024, step=8), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput: + return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap)) + + +class SaveAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudio", + search_aliases=["export flac"], + display_name="Save Audio (FLAC)", + category="audio", + essentials_category="Audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) + ) + + save_flac = execute # TODO: remove + + +class SaveAudioMP3(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioMP3", + search_aliases=["export mp3"], + display_name="Save Audio (MP3)", + category="audio", + essentials_category="Audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) + + save_mp3 = execute # TODO: remove + + +class SaveAudioOpus(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioOpus", + search_aliases=["export opus"], + display_name="Save Audio (Opus)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) + + save_opus = execute # TODO: remove + + +class PreviewAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewAudio", + search_aliases=["play audio"], + display_name="Preview Audio", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, audio) -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls)) + + save_flac = execute # TODO: remove + + +def f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2 ** 15) + elif wav.dtype == torch.int32: + return wav.float() / (2 ** 31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + +def load(filepath: str) -> tuple[torch.Tensor, int]: + with av.open(filepath) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in the file.") + + stream = af.streams.audio[0] + sr = stream.codec_context.sample_rate + n_channels = stream.channels + + frames = [] + length = 0 + for frame in af.decode(streams=stream.index): + buf = torch.from_numpy(frame.to_ndarray()) + if buf.shape[0] != n_channels: + buf = buf.view(-1, n_channels).t() + + frames.append(buf) + length += buf.shape[1] + + if not frames: + raise ValueError("No audio frames decoded.") + + wav = torch.cat(frames, dim=1) + wav = f32_pcm(wav) + return wav, sr + +class LoadAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + input_dir = folder_paths.get_input_directory() + files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) + return IO.Schema( + node_id="LoadAudio", + search_aliases=["import audio", "open audio", "audio file"], + display_name="Load Audio", + category="audio", + essentials_category="Audio", + inputs=[ + IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio) -> IO.NodeOutput: + audio_path = folder_paths.get_annotated_filepath(audio) + waveform, sample_rate = load(audio_path) + audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} + return IO.NodeOutput(audio) + + @classmethod + def fingerprint_inputs(cls, audio): + image_path = folder_paths.get_annotated_filepath(audio) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def validate_inputs(cls, audio): + if not folder_paths.exists_annotated_filepath(audio): + return "Invalid audio file: {}".format(audio) + return True + + load = execute # TODO: remove + + +class RecordAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecordAudio", + search_aliases=["microphone input", "audio capture", "voice input"], + display_name="Record Audio", + category="audio", + inputs=[ + IO.Custom("AUDIO_RECORD").Input("audio"), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio) -> IO.NodeOutput: + audio_path = folder_paths.get_annotated_filepath(audio) + + waveform, sample_rate = load(audio_path) + audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} + return IO.NodeOutput(audio) + + load = execute # TODO: remove + + +class TrimAudioDuration(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TrimAudioDuration", + search_aliases=["cut audio", "audio clip", "shorten audio"], + display_name="Trim Audio Duration", + description="Trim audio tensor into chosen time range.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Float.Input( + "start_index", + default=0.0, + min=-0xffffffffffffffff, + max=0xffffffffffffffff, + step=0.01, + tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).", + ), + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + step=0.01, + tooltip="Duration in seconds", + ), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio, start_index, duration) -> IO.NodeOutput: + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + audio_length = waveform.shape[-1] + + if start_index < 0: + start_frame = audio_length + int(round(start_index * sample_rate)) + else: + start_frame = int(round(start_index * sample_rate)) + start_frame = max(0, min(start_frame, audio_length - 1)) + + end_frame = start_frame + int(round(duration * sample_rate)) + end_frame = max(0, min(end_frame, audio_length)) + + if start_frame >= end_frame: + raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.") + + return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate}) + + trim = execute # TODO: remove + + +class SplitAudioChannels(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplitAudioChannels", + search_aliases=["stereo to mono"], + display_name="Split Audio Channels", + description="Separates the audio into left and right channels.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + outputs=[ + IO.Audio.Output(display_name="left"), + IO.Audio.Output(display_name="right"), + ], + ) + + @classmethod + def execute(cls, audio) -> IO.NodeOutput: + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + if waveform.shape[1] != 2: + raise ValueError("AudioSplit: Input audio has only one channel.") + + left_channel = waveform[..., 0:1, :] + right_channel = waveform[..., 1:2, :] + + return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + + separate = execute # TODO: remove + +class JoinAudioChannels(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="JoinAudioChannels", + display_name="Join Audio Channels", + description="Joins left and right mono audio channels into a stereo audio.", + category="audio", + inputs=[ + IO.Audio.Input("audio_left"), + IO.Audio.Input("audio_right"), + ], + outputs=[ + IO.Audio.Output(display_name="audio"), + ], + ) + + @classmethod + def execute(cls, audio_left, audio_right) -> IO.NodeOutput: + waveform_left = audio_left["waveform"] + sample_rate_left = audio_left["sample_rate"] + waveform_right = audio_right["waveform"] + sample_rate_right = audio_right["sample_rate"] + + if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1: + raise ValueError("AudioJoin: Both input audios must be mono.") + + # Handle different sample rates by resampling to the higher rate + waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates( + waveform_left, sample_rate_left, waveform_right, sample_rate_right + ) + + # Handle different lengths by trimming to the shorter length + length_left = waveform_left.shape[-1] + length_right = waveform_right.shape[-1] + + if length_left != length_right: + min_length = min(length_left, length_right) + if length_left > min_length: + logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.") + waveform_left = waveform_left[..., :min_length] + if length_right > min_length: + logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.") + waveform_right = waveform_right[..., :min_length] + + # Join the channels into stereo + left_channel = waveform_left[..., 0:1, :] + right_channel = waveform_right[..., 0:1, :] + stereo_waveform = torch.cat([left_channel, right_channel], dim=1) + + return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate}) + + +def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): + if sample_rate_1 != sample_rate_2: + if sample_rate_1 > sample_rate_2: + waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1) + output_sample_rate = sample_rate_1 + logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.") + else: + waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2) + output_sample_rate = sample_rate_2 + logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.") + else: + output_sample_rate = sample_rate_1 + return waveform_1, waveform_2, output_sample_rate + + +class AudioConcat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="AudioConcat", + search_aliases=["join audio", "combine audio", "append audio"], + display_name="Audio Concat", + description="Concatenates the audio1 to audio2 in the specified direction.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "direction", + options=['after', 'before'], + default="after", + tooltip="Whether to append audio2 after or before audio1.", + ) + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio1, audio2, direction) -> IO.NodeOutput: + waveform_1 = audio1["waveform"] + waveform_2 = audio2["waveform"] + sample_rate_1 = audio1["sample_rate"] + sample_rate_2 = audio2["sample_rate"] + + if waveform_1.shape[1] == 1: + waveform_1 = waveform_1.repeat(1, 2, 1) + logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.") + if waveform_2.shape[1] == 1: + waveform_2 = waveform_2.repeat(1, 2, 1) + logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.") + + waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2) + + if direction == 'after': + concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2) + elif direction == 'before': + concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2) + + return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate}) + + concat = execute # TODO: remove + + +class AudioMerge(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="AudioMerge", + search_aliases=["mix audio", "overlay audio", "layer audio"], + display_name="Audio Merge", + description="Combine two audio tracks by overlaying their waveforms.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "merge_method", + options=["add", "mean", "subtract", "multiply"], + tooltip="The method used to combine the audio waveforms.", + ) + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput: + waveform_1 = audio1["waveform"] + waveform_2 = audio2["waveform"] + sample_rate_1 = audio1["sample_rate"] + sample_rate_2 = audio2["sample_rate"] + + waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2) + + length_1 = waveform_1.shape[-1] + length_2 = waveform_2.shape[-1] + + if length_2 > length_1: + logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.") + waveform_2 = waveform_2[..., :length_1] + elif length_2 < length_1: + logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.") + pad_shape = list(waveform_2.shape) + pad_shape[-1] = length_1 - length_2 + pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device) + waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1) + + if merge_method == "add": + waveform = waveform_1 + waveform_2 + elif merge_method == "subtract": + waveform = waveform_1 - waveform_2 + elif merge_method == "multiply": + waveform = waveform_1 * waveform_2 + elif merge_method == "mean": + waveform = (waveform_1 + waveform_2) / 2 + + max_val = waveform.abs().max() + if max_val > 1.0: + waveform = waveform / max_val + + return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate}) + + merge = execute # TODO: remove + + +class AudioAdjustVolume(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="AudioAdjustVolume", + search_aliases=["audio gain", "loudness", "audio level"], + display_name="Audio Adjust Volume", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Int.Input( + "volume", + default=1, + min=-100, + max=100, + tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc", + ) + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio, volume) -> IO.NodeOutput: + if volume == 0: + return IO.NodeOutput(audio) + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + gain = 10 ** (volume / 20) + waveform = waveform * gain + + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + adjust_volume = execute # TODO: remove + + +class EmptyAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyAudio", + search_aliases=["blank audio"], + display_name="Empty Audio", + category="audio", + inputs=[ + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + max=0xffffffffffffffff, + step=0.01, + tooltip="Duration of the empty audio clip in seconds", + ), + IO.Int.Input( + "sample_rate", + default=44100, + tooltip="Sample rate of the empty audio clip.", + min=1, + max=192000, + advanced=True, + ), + IO.Int.Input( + "channels", + default=2, + min=1, + max=2, + tooltip="Number of audio channels (1 for mono, 2 for stereo).", + advanced=True, + ), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput: + num_samples = int(round(duration * sample_rate)) + waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + create_empty_audio = execute # TODO: remove + + +class AudioEqualizer3Band(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="AudioEqualizer3Band", + search_aliases=["eq", "bass boost", "treble boost", "equalizer"], + display_name="Audio Equalizer (3-Band)", + category="audio", + is_experimental=True, + inputs=[ + IO.Audio.Input("audio"), + IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"), + IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"), + IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"), + IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"), + IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"), + IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"), + IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput: + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + eq_waveform = waveform.clone() + + # 1. Apply Low Shelf (Bass) + if low_gain_dB != 0: + eq_waveform = torchaudio.functional.bass_biquad( + eq_waveform, + sample_rate, + gain=low_gain_dB, + central_freq=float(low_freq), + Q=0.707 + ) + + # 2. Apply Peaking EQ (Mids) + if mid_gain_dB != 0: + eq_waveform = torchaudio.functional.equalizer_biquad( + eq_waveform, + sample_rate, + center_freq=float(mid_freq), + gain=mid_gain_dB, + Q=mid_q + ) + + # 3. Apply High Shelf (Treble) + if high_gain_dB != 0: + eq_waveform = torchaudio.functional.treble_biquad( + eq_waveform, + sample_rate, + gain=high_gain_dB, + central_freq=float(high_freq), + Q=0.707 + ) + + return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate}) + + +class AudioExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyLatentAudio, + VAEEncodeAudio, + VAEDecodeAudio, + VAEDecodeAudioTiled, + SaveAudio, + SaveAudioMP3, + SaveAudioOpus, + LoadAudio, + PreviewAudio, + ConditioningStableAudio, + RecordAudio, + TrimAudioDuration, + SplitAudioChannels, + JoinAudioChannels, + AudioConcat, + AudioMerge, + AudioAdjustVolume, + EmptyAudio, + AudioEqualizer3Band, + ] + +async def comfy_entrypoint() -> AudioExtension: + return AudioExtension() diff --git a/ComfyUI/comfy_extras/nodes_audio_encoder.py b/ComfyUI/comfy_extras/nodes_audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..530fd09cfc3b754d8c01e66d44822d8b13b69c41 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_audio_encoder.py @@ -0,0 +1,62 @@ +import folder_paths +import comfy.audio_encoders.audio_encoders +import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class AudioEncoderLoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AudioEncoderLoader", + category="loaders", + inputs=[ + io.Combo.Input( + "audio_encoder_name", + options=folder_paths.get_filename_list("audio_encoders"), + ), + ], + outputs=[io.AudioEncoder.Output()], + ) + + @classmethod + def execute(cls, audio_encoder_name) -> io.NodeOutput: + audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name) + sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True) + audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd) + if audio_encoder is None: + raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.") + return io.NodeOutput(audio_encoder) + + +class AudioEncoderEncode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AudioEncoderEncode", + category="conditioning", + inputs=[ + io.AudioEncoder.Input("audio_encoder"), + io.Audio.Input("audio"), + ], + outputs=[io.AudioEncoderOutput.Output()], + ) + + @classmethod + def execute(cls, audio_encoder, audio) -> io.NodeOutput: + output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"]) + return io.NodeOutput(output) + + +class AudioEncoder(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + AudioEncoderLoader, + AudioEncoderEncode, + ] + + +async def comfy_entrypoint() -> AudioEncoder: + return AudioEncoder() diff --git a/ComfyUI/comfy_extras/nodes_camera_trajectory.py b/ComfyUI/comfy_extras/nodes_camera_trajectory.py new file mode 100644 index 0000000000000000000000000000000000000000..b72d32d84a6259068bbbdf6f3181e495858537e6 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_camera_trajectory.py @@ -0,0 +1,239 @@ +import nodes +import torch +import numpy as np +from einops import rearrange +from typing_extensions import override +import comfy.model_management + +from comfy_api.latest import ComfyExtension, io + + +CAMERA_DICT = { + "base_T_norm": 1.5, + "base_angle": np.pi/3, + "Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]}, + "Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]}, + "Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]}, + "Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]}, + "Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]}, + "Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]}, + "Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]}, + "Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]}, + "ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]}, +} + + +def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'): + + def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + + """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + c2w_mat = np.array(entry[7:]).reshape(4, 4) + self.c2w_mat = c2w_mat + self.w2c_mat = np.linalg.inv(c2w_mat) + +def ray_condition(K, c2w, H, W, device): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = torch.meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + indexing='ij' + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + + fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 + + zs = torch.ones_like(i) # [B, HxW] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + zs = zs.expand_as(ys) + + directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 + directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 + + rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW + rays_o = c2w[..., :3, 3] # B, V, 3 + rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW + # c2w @ dirctions + rays_dxo = torch.cross(rays_o, rays_d) + plucker = torch.cat([rays_dxo, rays_d], dim=-1) + plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 + # plucker = plucker.permute(0, 1, 4, 2, 3) + return plucker + +def get_camera_motion(angle, T, speed, n=81): + def compute_R_form_rad_angle(angles): + theta_x, theta_y, theta_z = angles + Rx = np.array([[1, 0, 0], + [0, np.cos(theta_x), -np.sin(theta_x)], + [0, np.sin(theta_x), np.cos(theta_x)]]) + + Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)], + [0, 1, 0], + [-np.sin(theta_y), 0, np.cos(theta_y)]]) + + Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0], + [np.sin(theta_z), np.cos(theta_z), 0], + [0, 0, 1]]) + + R = np.dot(Rz, np.dot(Ry, Rx)) + return R + RT = [] + for i in range(n): + _angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle + R = compute_R_form_rad_angle(_angle) + _T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1)) + _RT = np.concatenate([R,_T], axis=1) + RT.append(_RT) + RT = np.stack(RT) + return RT + +class WanCameraEmbedding(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanCameraEmbedding", + category="camera", + inputs=[ + io.Combo.Input( + "camera_pose", + options=[ + "Static", + "Pan Up", + "Pan Down", + "Pan Left", + "Pan Right", + "Zoom In", + "Zoom Out", + "Anti Clockwise (ACW)", + "ClockWise (CW)", + ], + default="Static", + ), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True), + io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True), + io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True), + io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True), + io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True), + ], + outputs=[ + io.WanCameraEmbedding.Output(display_name="camera_embedding"), + io.Int.Output(display_name="width"), + io.Int.Output(display_name="height"), + io.Int.Output(display_name="length"), + ], + ) + + @classmethod + def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput: + """ + Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021) + Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py + """ + motion_list = [camera_pose] + speed = speed + angle = np.array(CAMERA_DICT[motion_list[0]]["angle"]) + T = np.array(CAMERA_DICT[motion_list[0]]["T"]) + RT = get_camera_motion(angle, T, speed, length) + + trajs=[] + for cp in RT.tolist(): + traj=[fx,fy,cx,cy,0,0] + traj.extend(cp[0]) + traj.extend(cp[1]) + traj.extend(cp[2]) + traj.extend([0,0,0,1]) + trajs.append(traj) + + cam_params = np.array([[float(x) for x in pose] for pose in trajs]) + cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1) + control_camera_video = process_pose_params(cam_params, width=width, height=height) + control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device()) + + control_camera_video = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + + # Reshape, transpose, and view into desired shape + b, f, c, h, w = control_camera_video.shape + control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + + return io.NodeOutput(control_camera_video, width, height, length) + + +class CameraTrajectoryExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanCameraEmbedding, + ] + +async def comfy_entrypoint() -> CameraTrajectoryExtension: + return CameraTrajectoryExtension() diff --git a/ComfyUI/comfy_extras/nodes_canny.py b/ComfyUI/comfy_extras/nodes_canny.py new file mode 100644 index 0000000000000000000000000000000000000000..542c1f46862183135097aac93782f71edd00e766 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_canny.py @@ -0,0 +1,45 @@ +from kornia.filters import canny +from typing_extensions import override + +import comfy.model_management +from comfy_api.latest import ComfyExtension, io +import torch + + +class Canny(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Canny", + display_name="Canny", + search_aliases=["edge detection", "outline", "contour detection", "line art"], + category="image/preprocessors", + essentials_category="Image Tools", + inputs=[ + io.Image.Input("image"), + io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01), + io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01), + ], + outputs=[io.Image.Output()], + ) + + @classmethod + def detect_edge(cls, image, low_threshold, high_threshold): + # Deprecated: use the V3 schema's `execute` method instead of this. + return cls.execute(image, low_threshold, high_threshold) + + @classmethod + def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput: + output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold) + img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1) + return io.NodeOutput(img_out) + + +class CannyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [Canny] + + +async def comfy_entrypoint() -> CannyExtension: + return CannyExtension() diff --git a/ComfyUI/comfy_extras/nodes_cfg.py b/ComfyUI/comfy_extras/nodes_cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..20cf3f2a7566b701edd971801873cdefaa30a7c9 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_cfg.py @@ -0,0 +1,91 @@ +from typing_extensions import override + +import torch + +from comfy_api.latest import ComfyExtension, io + + +# https://github.com/WeichenFan/CFG-Zero-star +def optimized_scale(positive, negative): + positive_flat = positive.reshape(positive.shape[0], -1) + negative_flat = negative.reshape(negative.shape[0], -1) + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) + +class CFGZeroStar(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CFGZeroStar", + category="advanced/guidance", + inputs=[ + io.Model.Input("model"), + ], + outputs=[io.Model.Output(display_name="patched_model")], + ) + + @classmethod + def execute(cls, model) -> io.NodeOutput: + m = model.clone() + def cfg_zero_star(args): + guidance_scale = args['cond_scale'] + x = args['input'] + cond_p = args['cond_denoised'] + uncond_p = args['uncond_denoised'] + out = args["denoised"] + alpha = optimized_scale(x - cond_p, x - uncond_p) + + return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) + m.set_model_sampler_post_cfg_function(cfg_zero_star) + return io.NodeOutput(m) + +class CFGNorm(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CFGNorm", + category="advanced/guidance", + inputs=[ + io.Model.Input("model"), + io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[io.Model.Output(display_name="patched_model")], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, strength) -> io.NodeOutput: + m = model.clone() + def cfg_norm(args): + cond_p = args['cond_denoised'] + pred_text_ = args["denoised"] + + norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) + norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) + scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0) + return pred_text_ * scale * strength + + m.set_model_sampler_post_cfg_function(cfg_norm) + return io.NodeOutput(m) + + +class CfgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CFGZeroStar, + CFGNorm, + ] + + +async def comfy_entrypoint() -> CfgExtension: + return CfgExtension() diff --git a/ComfyUI/comfy_extras/nodes_chroma_radiance.py b/ComfyUI/comfy_extras/nodes_chroma_radiance.py new file mode 100644 index 0000000000000000000000000000000000000000..93d05e771642364ab06eb755e94d6feede99429f --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_chroma_radiance.py @@ -0,0 +1,117 @@ +from typing_extensions import override +from typing import Callable + +import torch + +import comfy.model_management +from comfy_api.latest import ComfyExtension, io + +import nodes + +class EmptyChromaRadianceLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyChromaRadianceLatentImage", + category="latent/chroma_radiance", + inputs=[ + io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input(id="batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent().Output()], + ) + + @classmethod + def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput: + latent = torch.zeros((batch_size, 3, height, width), device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) + + +class ChromaRadianceOptions(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ChromaRadianceOptions", + category="model_patches/chroma_radiance", + description="Allows setting advanced options for the Chroma Radiance model.", + inputs=[ + io.Model.Input(id="model"), + io.Boolean.Input( + id="preserve_wrapper", + default=True, + tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.", + ), + io.Float.Input( + id="start_sigma", + default=1.0, + min=0.0, + max=1.0, + tooltip="First sigma that these options will be in effect.", + advanced=True, + ), + io.Float.Input( + id="end_sigma", + default=0.0, + min=0.0, + max=1.0, + tooltip="Last sigma that these options will be in effect.", + advanced=True, + ), + io.Int.Input( + id="nerf_tile_size", + default=-1, + min=-1, + tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", + advanced=True, + ), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def execute( + cls, + *, + model: io.Model.Type, + preserve_wrapper: bool, + start_sigma: float, + end_sigma: float, + nerf_tile_size: int, + ) -> io.NodeOutput: + radiance_options = {} + if nerf_tile_size >= 0: + radiance_options["nerf_tile_size"] = nerf_tile_size + + if not radiance_options: + return io.NodeOutput(model) + + old_wrapper = model.model_options.get("model_function_wrapper") + + def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor: + c = args["c"].copy() + sigma = args["timestep"].max().detach().cpu().item() + if end_sigma <= sigma <= start_sigma: + transformer_options = c.get("transformer_options", {}).copy() + transformer_options["chroma_radiance_options"] = radiance_options.copy() + c["transformer_options"] = transformer_options + if not (preserve_wrapper and old_wrapper): + return apply_model(args["input"], args["timestep"], **c) + return old_wrapper(apply_model, args | {"c": c}) + + model = model.clone() + model.set_model_unet_function_wrapper(model_function_wrapper) + return io.NodeOutput(model) + + +class ChromaRadianceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyChromaRadianceLatentImage, + ChromaRadianceOptions, + ] + + +async def comfy_entrypoint() -> ChromaRadianceExtension: + return ChromaRadianceExtension() diff --git a/ComfyUI/comfy_extras/nodes_clip_sdxl.py b/ComfyUI/comfy_extras/nodes_clip_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..47076ceef499649039deeca4eca09104959de5dc --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_clip_sdxl.py @@ -0,0 +1,71 @@ +from typing_extensions import override + +import nodes +from comfy_api.latest import ComfyExtension, io + + +class CLIPTextEncodeSDXLRefiner(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSDXLRefiner", + category="advanced/conditioning", + inputs=[ + io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01), + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.String.Input("text", multiline=True, dynamic_prompts=True), + io.Clip.Input("clip"), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, clip, ascore, width, height, text) -> io.NodeOutput: + tokens = clip.tokenize(text) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height})) + +class CLIPTextEncodeSDXL(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSDXL", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True), + io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True), + io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION), + io.String.Input("text_g", multiline=True, dynamic_prompts=True), + io.String.Input("text_l", multiline=True, dynamic_prompts=True), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput: + tokens = clip.tokenize(text_g) + tokens["l"] = clip.tokenize(text_l)["l"] + if len(tokens["l"]) != len(tokens["g"]): + empty = clip.tokenize("") + while len(tokens["l"]) < len(tokens["g"]): + tokens["l"] += empty["l"] + while len(tokens["l"]) > len(tokens["g"]): + tokens["g"] += empty["g"] + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height})) + + +class ClipSdxlExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeSDXLRefiner, + CLIPTextEncodeSDXL, + ] + + +async def comfy_entrypoint() -> ClipSdxlExtension: + return ClipSdxlExtension() diff --git a/ComfyUI/comfy_extras/nodes_color.py b/ComfyUI/comfy_extras/nodes_color.py new file mode 100644 index 0000000000000000000000000000000000000000..96744ec7cc0243d4d1d17db0c276ad8e41483a95 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_color.py @@ -0,0 +1,42 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class ColorToRGBInt(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ColorToRGBInt", + display_name="Color to RGB Int", + category="utils", + description="Convert a color to a RGB integer value.", + inputs=[ + io.Color.Input("color"), + ], + outputs=[ + io.Int.Output(display_name="rgb_int"), + ], + ) + + @classmethod + def execute( + cls, + color: str, + ) -> io.NodeOutput: + # expect format #RRGGBB + if len(color) != 7 or color[0] != "#": + raise ValueError("Color must be in format #RRGGBB") + r = int(color[1:3], 16) + g = int(color[3:5], 16) + b = int(color[5:7], 16) + return io.NodeOutput(r * 256 * 256 + g * 256 + b) + + +class ColorExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ColorToRGBInt] + + +async def comfy_entrypoint() -> ColorExtension: + return ColorExtension() diff --git a/ComfyUI/comfy_extras/nodes_compositing.py b/ComfyUI/comfy_extras/nodes_compositing.py new file mode 100644 index 0000000000000000000000000000000000000000..b5bf613b728a8c589d17b4f03bf06f37fd0d7e42 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_compositing.py @@ -0,0 +1,226 @@ +import torch +import comfy.utils +from enum import Enum +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +def resize_mask(mask, shape): + return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) + +class PorterDuffMode(Enum): + ADD = 0 + CLEAR = 1 + DARKEN = 2 + DST = 3 + DST_ATOP = 4 + DST_IN = 5 + DST_OUT = 6 + DST_OVER = 7 + LIGHTEN = 8 + MULTIPLY = 9 + OVERLAY = 10 + SCREEN = 11 + SRC = 12 + SRC_ATOP = 13 + SRC_IN = 14 + SRC_OUT = 15 + SRC_OVER = 16 + XOR = 17 + + +def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode): + # convert mask to alpha + src_alpha = 1 - src_alpha + dst_alpha = 1 - dst_alpha + # premultiply alpha + src_image = src_image * src_alpha + dst_image = dst_image * dst_alpha + + # composite ops below assume alpha-premultiplied images + if mode == PorterDuffMode.ADD: + out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1) + out_image = torch.clamp(src_image + dst_image, 0, 1) + elif mode == PorterDuffMode.CLEAR: + out_alpha = torch.zeros_like(dst_alpha) + out_image = torch.zeros_like(dst_image) + elif mode == PorterDuffMode.DARKEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image) + elif mode == PorterDuffMode.DST: + out_alpha = dst_alpha + out_image = dst_image + elif mode == PorterDuffMode.DST_ATOP: + out_alpha = src_alpha + out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.DST_IN: + out_alpha = src_alpha * dst_alpha + out_image = dst_image * src_alpha + elif mode == PorterDuffMode.DST_OUT: + out_alpha = (1 - src_alpha) * dst_alpha + out_image = (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.DST_OVER: + out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha + out_image = dst_image + (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.LIGHTEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image) + elif mode == PorterDuffMode.MULTIPLY: + out_alpha = src_alpha * dst_alpha + out_image = src_image * dst_image + elif mode == PorterDuffMode.OVERLAY: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image, + src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image)) + elif mode == PorterDuffMode.SCREEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = src_image + dst_image - src_image * dst_image + elif mode == PorterDuffMode.SRC: + out_alpha = src_alpha + out_image = src_image + elif mode == PorterDuffMode.SRC_ATOP: + out_alpha = dst_alpha + out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.SRC_IN: + out_alpha = src_alpha * dst_alpha + out_image = src_image * dst_alpha + elif mode == PorterDuffMode.SRC_OUT: + out_alpha = (1 - dst_alpha) * src_alpha + out_image = (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.SRC_OVER: + out_alpha = src_alpha + (1 - src_alpha) * dst_alpha + out_image = src_image + (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.XOR: + out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + else: + return None, None + + # back to non-premultiplied alpha + out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image)) + out_image = torch.clamp(out_image, 0, 1) + # convert alpha to mask + out_alpha = 1 - out_alpha + return out_image, out_alpha + + +class PorterDuffImageComposite(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PorterDuffImageComposite", + search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"], + display_name="Porter-Duff Image Composite", + category="mask/compositing", + inputs=[ + io.Image.Input("source"), + io.Mask.Input("source_alpha"), + io.Image.Input("destination"), + io.Mask.Input("destination_alpha"), + io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name), + ], + outputs=[ + io.Image.Output(), + io.Mask.Output(), + ], + ) + + @classmethod + def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput: + batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) + out_images = [] + out_alphas = [] + + for i in range(batch_size): + src_image = source[i] + dst_image = destination[i] + + assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels + + src_alpha = source_alpha[i].unsqueeze(2) + dst_alpha = destination_alpha[i].unsqueeze(2) + + if dst_alpha.shape[:2] != dst_image.shape[:2]: + upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') + dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) + if src_image.shape != dst_image.shape: + upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') + src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0) + if src_alpha.shape != dst_alpha.shape: + upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center') + src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) + + out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode]) + + out_images.append(out_image) + out_alphas.append(out_alpha.squeeze(2)) + + return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas)) + + +class SplitImageWithAlpha(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SplitImageWithAlpha", + search_aliases=["extract alpha", "separate transparency", "remove alpha"], + display_name="Split Image with Alpha", + category="mask/compositing", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(), + io.Mask.Output(), + ], + ) + + @classmethod + def execute(cls, image: torch.Tensor) -> io.NodeOutput: + out_images = [i[:,:,:3] for i in image] + out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] + return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas)) + + +class JoinImageWithAlpha(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="JoinImageWithAlpha", + search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"], + display_name="Join Image with Alpha", + category="mask/compositing", + inputs=[ + io.Image.Input("image"), + io.Mask.Input("alpha"), + ], + outputs=[io.Image.Output()], + ) + + @classmethod + def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: + batch_size = min(len(image), len(alpha)) + out_images = [] + + alpha = 1.0 - resize_mask(alpha, image.shape[1:]) + for i in range(batch_size): + out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) + + return io.NodeOutput(torch.stack(out_images)) + + +class CompositingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PorterDuffImageComposite, + SplitImageWithAlpha, + JoinImageWithAlpha, + ] + + +async def comfy_entrypoint() -> CompositingExtension: + return CompositingExtension() diff --git a/ComfyUI/comfy_extras/nodes_cond.py b/ComfyUI/comfy_extras/nodes_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..012e7b693f6328fe361d95e0f3a2ddb64580f91a --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_cond.py @@ -0,0 +1,68 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class CLIPTextEncodeControlnet(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CLIPTextEncodeControlnet", + category="_for_testing/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Conditioning.Input("conditioning"), + io.String.Input("text", multiline=True, dynamic_prompts=True), + ], + outputs=[io.Conditioning.Output()], + is_experimental=True, + ) + + @classmethod + def execute(cls, clip, conditioning, text) -> io.NodeOutput: + tokens = clip.tokenize(text) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['cross_attn_controlnet'] = cond + n[1]['pooled_output_controlnet'] = pooled + c.append(n) + return io.NodeOutput(c) + +class T5TokenizerOptions(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="T5TokenizerOptions", + category="_for_testing/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True), + io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True), + ], + outputs=[io.Clip.Output()], + is_experimental=True, + ) + + @classmethod + def execute(cls, clip, min_padding, min_length) -> io.NodeOutput: + clip = clip.clone() + for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]: + clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding) + clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length) + + return io.NodeOutput(clip) + + +class CondExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeControlnet, + T5TokenizerOptions, + ] + + +async def comfy_entrypoint() -> CondExtension: + return CondExtension() diff --git a/ComfyUI/comfy_extras/nodes_context_windows.py b/ComfyUI/comfy_extras/nodes_context_windows.py new file mode 100644 index 0000000000000000000000000000000000000000..355c15e70894964cce050e1b3e944c99fc7ef91c --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_context_windows.py @@ -0,0 +1,103 @@ +from __future__ import annotations +from comfy_api.latest import ComfyExtension, io +import comfy.context_windows +import nodes + + +class ContextWindowsManualNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ContextWindowsManual", + display_name="Context Windows (Manual)", + category="context", + description="Manually set context windows.", + inputs=[ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], tooltip="The stride of the context window."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + ], + outputs=[ + io.Model.Output(tooltip="The model with context windows applied during sampling."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: + model = model.clone() + model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( + context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), + fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method), + context_length=context_length, + context_overlap=context_overlap, + context_stride=context_stride, + closed_loop=closed_loop, + dim=dim, + freenoise=freenoise, + cond_retain_index_list=cond_retain_index_list, + split_conds_to_windows=split_conds_to_windows + ) + # make memory usage calculation only take into account the context window latents + comfy.context_windows.create_prepare_sampling_wrapper(model) + if freenoise: # no other use for this wrapper at this time + comfy.context_windows.create_sampler_sample_wrapper(model) + return io.NodeOutput(model) + +class WanContextWindowsManualNode(ContextWindowsManualNode): + @classmethod + def define_schema(cls) -> io.Schema: + schema = super().define_schema() + schema.node_id = "WanContextWindowsManual" + schema.display_name = "WAN Context Windows (Manual)" + schema.description = "Manually set context windows for WAN-like models (dim=2)." + schema.inputs = [ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True), + io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], tooltip="The stride of the context window."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + ] + return schema + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 + context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) + + +class ContextWindowsExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ContextWindowsManualNode, + WanContextWindowsManualNode, + ] + +def comfy_entrypoint(): + return ContextWindowsExtension() diff --git a/ComfyUI/comfy_extras/nodes_controlnet.py b/ComfyUI/comfy_extras/nodes_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa2f5d7325621ca55c5e664c131c88a6c87353e --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_controlnet.py @@ -0,0 +1,85 @@ +from comfy.cldm.control_types import UNION_CONTROLNET_TYPES +import nodes +import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + +class SetUnionControlNetType(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SetUnionControlNetType", + category="conditioning/controlnet", + inputs=[ + io.ControlNet.Input("control_net"), + io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())), + ], + outputs=[ + io.ControlNet.Output(), + ], + ) + + @classmethod + def execute(cls, control_net, type) -> io.NodeOutput: + control_net = control_net.copy() + type_number = UNION_CONTROLNET_TYPES.get(type, -1) + if type_number >= 0: + control_net.set_extra_arg("control_type", [type_number]) + else: + control_net.set_extra_arg("control_type", []) + + return io.NodeOutput(control_net) + + set_controlnet_type = execute # TODO: remove + + +class ControlNetInpaintingAliMamaApply(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ControlNetInpaintingAliMamaApply", + search_aliases=["masked controlnet"], + category="conditioning/controlnet", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.ControlNet.Input("control_net"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Mask.Input("mask"), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput: + extra_concat = [] + if control_net.concat_mask: + mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) + mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round() + image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3]) + extra_concat = [mask] + + result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat) + return io.NodeOutput(result[0], result[1]) + + apply_inpaint_controlnet = execute # TODO: remove + + +class ControlNetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SetUnionControlNetType, + ControlNetInpaintingAliMamaApply, + ] + + +async def comfy_entrypoint() -> ControlNetExtension: + return ControlNetExtension() diff --git a/ComfyUI/comfy_extras/nodes_cosmos.py b/ComfyUI/comfy_extras/nodes_cosmos.py new file mode 100644 index 0000000000000000000000000000000000000000..70b56039395ff39b4a9da0fe70ed926e8fe20e7c --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_cosmos.py @@ -0,0 +1,143 @@ +from typing_extensions import override +import nodes +import torch +import comfy.model_management +import comfy.utils +import comfy.latent_formats + +from comfy_api.latest import ComfyExtension, io + + +class EmptyCosmosLatentVideo(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EmptyCosmosLatentVideo", + category="latent/video", + inputs=[ + io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[io.Latent.Output()], + ) + + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + +def vae_encode_with_padding(vae, image, width, height, length, padding=0): + pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + pixel_len = min(pixels.shape[0], length) + padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7) + padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5 + padded_pixels[:pixel_len] = pixels[:pixel_len] + latent_len = ((pixel_len - 1) // 8) + 1 + latent_temp = vae.encode(padded_pixels) + return latent_temp[:, :, :latent_len] + + +class CosmosImageToVideoLatent(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CosmosImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[io.Latent.Output()], + ) + + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput: + latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is None and end_image is None: + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(out_latent) + + mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1) + latent[:, :, :latent_temp.shape[-3]] = latent_temp + mask[:, :, :latent_temp.shape[-3]] *= 0.0 + + if end_image is not None: + latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0) + latent[:, :, -latent_temp.shape[-3]:] = latent_temp + mask[:, :, -latent_temp.shape[-3]:] *= 0.0 + + out_latent = {} + out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) + out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) + return io.NodeOutput(out_latent) + +class CosmosPredict2ImageToVideoLatent(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CosmosPredict2ImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=93, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[io.Latent.Output()], + ) + + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput: + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is None and end_image is None: + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(out_latent) + + mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1) + latent[:, :, :latent_temp.shape[-3]] = latent_temp + mask[:, :, :latent_temp.shape[-3]] *= 0.0 + + if end_image is not None: + latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0) + latent[:, :, -latent_temp.shape[-3]:] = latent_temp + mask[:, :, -latent_temp.shape[-3]:] *= 0.0 + + out_latent = {} + latent_format = comfy.latent_formats.Wan21() + latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) + out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) + out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) + return io.NodeOutput(out_latent) + + +class CosmosExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyCosmosLatentVideo, + CosmosImageToVideoLatent, + CosmosPredict2ImageToVideoLatent, + ] + + +async def comfy_entrypoint() -> CosmosExtension: + return CosmosExtension() diff --git a/ComfyUI/comfy_extras/nodes_curve.py b/ComfyUI/comfy_extras/nodes_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..aee18587b6631efc4995f0eba68bc5043838cdb4 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_curve.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import numpy as np + +from comfy_api.latest import ComfyExtension, io +from comfy_api.input import CurveInput +from typing_extensions import override + + +class CurveEditor(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CurveEditor", + display_name="Curve Editor", + category="utils", + inputs=[ + io.Curve.Input("curve"), + io.Histogram.Input("histogram", optional=True), + ], + outputs=[ + io.Curve.Output("curve"), + ], + ) + + @classmethod + def execute(cls, curve, histogram=None) -> io.NodeOutput: + result = CurveInput.from_raw(curve) + + ui = {} + if histogram is not None: + ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram) + + return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result) + + +class ImageHistogram(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ImageHistogram", + display_name="Image Histogram", + category="utils", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Histogram.Output("rgb"), + io.Histogram.Output("luminance"), + io.Histogram.Output("red"), + io.Histogram.Output("green"), + io.Histogram.Output("blue"), + ], + ) + + @classmethod + def execute(cls, image) -> io.NodeOutput: + img = image[0].cpu().numpy() + img_uint8 = np.clip(img * 255, 0, 255).astype(np.uint8) + + def bincount(data): + return np.bincount(data.ravel(), minlength=256)[:256] + + hist_r = bincount(img_uint8[:, :, 0]) + hist_g = bincount(img_uint8[:, :, 1]) + hist_b = bincount(img_uint8[:, :, 2]) + + # Average of R, G, B histograms (same as Photoshop's RGB composite) + rgb = ((hist_r + hist_g + hist_b) // 3).tolist() + + # ITU-R BT.709-6, Item 3.2 (p.6) — Derivation of luminance signal + # https://www.itu.int/rec/R-REC-BT.709-6-201506-I/en + lum = 0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] + 0.0722 * img[:, :, 2] + luminance = bincount(np.clip(lum * 255, 0, 255).astype(np.uint8)).tolist() + + return io.NodeOutput( + rgb, + luminance, + hist_r.tolist(), + hist_g.tolist(), + hist_b.tolist(), + ) + + +class CurveExtension(ComfyExtension): + @override + async def get_node_list(self): + return [CurveEditor, ImageHistogram] + + +async def comfy_entrypoint(): + return CurveExtension() diff --git a/ComfyUI/comfy_extras/nodes_custom_sampler.py b/ComfyUI/comfy_extras/nodes_custom_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..59039239fd4ad17d98b08d6e86ced7e9affb7a28 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_custom_sampler.py @@ -0,0 +1,1095 @@ +import math +import comfy.samplers +import comfy.sample +from comfy.k_diffusion import sampling as k_diffusion_sampling +from comfy.k_diffusion import sa_solver +import latent_preview +import torch +import comfy.utils +import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import re + + +class BasicScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BasicScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput: + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return io.NodeOutput(torch.FloatTensor([])) + total_steps = int(steps/denoise) + + sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() + sigmas = sigmas[-(steps + 1):] + return io.NodeOutput(sigmas) + + get_sigmas = execute + + +class KarrasScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="KarrasScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return io.NodeOutput(sigmas) + + get_sigmas = execute + +class ExponentialScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ExponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min) -> io.NodeOutput: + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) + return io.NodeOutput(sigmas) + + get_sigmas = execute + +class PolyexponentialScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PolyexponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: + sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return io.NodeOutput(sigmas) + + get_sigmas = execute + +class LaplaceScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LaplaceScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False, advanced=True), + io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False, advanced=True), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, steps, sigma_max, sigma_min, mu, beta) -> io.NodeOutput: + sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta) + return io.NodeOutput(sigmas) + + get_sigmas = execute + + +class SDTurboScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDTurboScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=1, min=1, max=10), + io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, model, steps, denoise) -> io.NodeOutput: + start_step = 10 - int(10 * denoise) + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] + sigmas = model.get_model_object("model_sampling").sigma(timesteps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + return io.NodeOutput(sigmas) + + get_sigmas = execute + +class BetaSamplingScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BetaSamplingScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True), + io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, model, steps, alpha, beta) -> io.NodeOutput: + sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta) + return io.NodeOutput(sigmas) + + get_sigmas = execute + +class VPScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VPScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), #TODO: fix default values + io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, steps, beta_d, beta_min, eps_s) -> io.NodeOutput: + sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) + return io.NodeOutput(sigmas) + + get_sigmas = execute + +class SplitSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SplitSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("step", default=0, min=0, max=10000), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) + + @classmethod + def execute(cls, sigmas, step) -> io.NodeOutput: + sigmas1 = sigmas[:step + 1] + sigmas2 = sigmas[step:] + return io.NodeOutput(sigmas1, sigmas2) + + get_sigmas = execute + +class SplitSigmasDenoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SplitSigmasDenoise", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) + + @classmethod + def execute(cls, sigmas, denoise) -> io.NodeOutput: + steps = max(sigmas.shape[-1] - 1, 0) + total_steps = round(steps * denoise) + sigmas1 = sigmas[:-(total_steps)] + sigmas2 = sigmas[-(total_steps + 1):] + return io.NodeOutput(sigmas1, sigmas2) + + get_sigmas = execute + +class FlipSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FlipSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[io.Sigmas.Input("sigmas")], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: + if len(sigmas) == 0: + return io.NodeOutput(sigmas) + + sigmas = sigmas.flip(0) + if sigmas[0] == 0: + sigmas[0] = 0.0001 + return io.NodeOutput(sigmas) + + get_sigmas = execute + +class SetFirstSigma(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SetFirstSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, sigmas, sigma) -> io.NodeOutput: + sigmas = sigmas.clone() + sigmas[0] = sigma + return io.NodeOutput(sigmas) + + set_first_sigma = execute + +class ExtendIntermediateSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ExtendIntermediateSigmas", + search_aliases=["interpolate sigmas"], + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("steps", default=2, min=1, max=100), + io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False), + io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False), + io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']), + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str) -> io.NodeOutput: + if start_at_sigma < 0: + start_at_sigma = float("inf") + + interpolator = { + 'linear': lambda x: x, + 'cosine': lambda x: torch.sin(x*math.pi/2), + 'sine': lambda x: 1 - torch.cos(x*math.pi/2) + }[spacing] + + # linear space for our interpolation function + x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1] + computed_spacing = interpolator(x) + + extended_sigmas = [] + for i in range(len(sigmas) - 1): + sigma_current = sigmas[i] + sigma_next = sigmas[i+1] + + extended_sigmas.append(sigma_current) + + if end_at_sigma <= sigma_current <= start_at_sigma: + interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current + extended_sigmas.extend(interpolated_steps.tolist()) + + # Add the last sigma value + if len(sigmas) > 0: + extended_sigmas.append(sigmas[-1]) + + extended_sigmas = torch.FloatTensor(extended_sigmas) + + return io.NodeOutput(extended_sigmas) + + extend = execute + + +class SamplingPercentToSigma(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplingPercentToSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Model.Input("model"), + io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), + io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."), + ], + outputs=[io.Float.Output(display_name="sigma_value")] + ) + + @classmethod + def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput: + model_sampling = model.get_model_object("model_sampling") + sigma_val = model_sampling.percent_to_sigma(sampling_percent) + if return_actual_sigma: + if sampling_percent == 0.0: + sigma_val = model_sampling.sigma_max.item() + elif sampling_percent == 1.0: + sigma_val = model_sampling.sigma_min.item() + return io.NodeOutput(sigma_val) + + get_sigma = execute + + +class KSamplerSelect(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="KSamplerSelect", + category="sampling/custom_sampling/samplers", + inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, sampler_name) -> io.NodeOutput: + sampler = comfy.samplers.sampler_object(sampler_name) + return io.NodeOutput(sampler) + + get_sampler = execute + +class SamplerDPMPP_3M_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_3M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, eta, s_noise, noise_device) -> io.NodeOutput: + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + get_sampler = execute + +class SamplerDPMPP_2M_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=['midpoint', 'heun']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, solver_type, eta, s_noise, noise_device) -> io.NodeOutput: + if noise_device == 'cpu': + sampler_name = "dpmpp_2m_sde" + else: + sampler_name = "dpmpp_2m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) + return io.NodeOutput(sampler) + + get_sampler = execute + + +class SamplerDPMPP_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, eta, s_noise, r, noise_device) -> io.NodeOutput: + if noise_device == 'cpu': + sampler_name = "dpmpp_sde" + else: + sampler_name = "dpmpp_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) + return io.NodeOutput(sampler) + + get_sampler = execute + +class SamplerDPMPP_2S_Ancestral(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2S_Ancestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: + sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + get_sampler = execute + +class SamplerEulerAncestral(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: + sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + get_sampler = execute + +class SamplerEulerAncestralCFGPP(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestralCFGPP", + display_name="SamplerEulerAncestralCFG++", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: + sampler = comfy.samplers.ksampler( + "euler_ancestral_cfg_pp", + {"eta": eta, "s_noise": s_noise}) + return io.NodeOutput(sampler) + + get_sampler = execute + +class SamplerLMS(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerLMS", + category="sampling/custom_sampling/samplers", + inputs=[io.Int.Input("order", default=4, min=1, max=100, advanced=True)], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, order) -> io.NodeOutput: + sampler = comfy.samplers.ksampler("lms", {"order": order}) + return io.NodeOutput(sampler) + + get_sampler = execute + +class SamplerDPMAdaptative(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMAdaptative", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=3, min=2, max=3, advanced=True), + io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) -> io.NodeOutput: + sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) + return io.NodeOutput(sampler) + + get_sampler = execute + + +class SamplerER_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerER_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), + io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type.", advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput: + if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0): + eta = 0 + s_noise = 0 + + def reverse_time_sde_noise_scaler(x): + return x ** (eta + 1) + + if solver_type == "ER-SDE": + # Use the default one in sample_er_sde() + noise_scaler = None + else: + noise_scaler = reverse_time_sde_noise_scaler + + sampler_name = "er_sde" + sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage}) + return io.NodeOutput(sampler) + + get_sampler = execute + + +class SamplerSASolver(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerSASolver", + search_aliases=["sde"], + category="sampling/custom_sampling/samplers", + inputs=[ + io.Model.Input("model"), + io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False, advanced=True), + io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Int.Input("predictor_order", default=3, min=1, max=6, advanced=True), + io.Int.Input("corrector_order", default=4, min=0, max=6, advanced=True), + io.Boolean.Input("use_pece", advanced=True), + io.Boolean.Input("simple_order_2", advanced=True), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput: + model_sampling = model.get_model_object("model_sampling") + start_sigma = model_sampling.percent_to_sigma(sde_start_percent) + end_sigma = model_sampling.percent_to_sigma(sde_end_percent) + tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta) + + sampler_name = "sa_solver" + sampler = comfy.samplers.ksampler( + sampler_name, + { + "tau_func": tau_func, + "s_noise": s_noise, + "predictor_order": predictor_order, + "corrector_order": corrector_order, + "use_pece": use_pece, + "simple_order_2": simple_order_2, + }, + ) + return io.NodeOutput(sampler) + + get_sampler = execute + + +class SamplerSEEDS2(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerSEEDS2", + search_aliases=["sde", "exp heun"], + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["phi_1", "phi_2"]), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength", advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier", advanced=True), + io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)", advanced=True), + ], + outputs=[io.Sampler.Output()], + description=( + "This sampler node can represent multiple samplers:\n\n" + "seeds_2\n" + "- default setting\n\n" + "exp_heun_2_x0\n" + "- solver_type=phi_2, r=1.0, eta=0.0\n\n" + "exp_heun_2_x0_sde\n" + "- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0" + ) + ) + + @classmethod + def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput: + sampler_name = "seeds_2" + sampler = comfy.samplers.ksampler( + sampler_name, + {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type}, + ) + return io.NodeOutput(sampler) + + +class Noise_EmptyNoise: + def __init__(self): + self.seed = 0 + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + if latent_image.is_nested: + tensors = latent_image.unbind() + zeros = [] + for t in tensors: + zeros.append(torch.zeros(t.shape, dtype=t.dtype, layout=t.layout, device="cpu")) + return comfy.nested_tensor.NestedTensor(zeros) + else: + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + + +class Noise_RandomNoise: + def __init__(self, seed): + self.seed = seed + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None + return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) + +class SamplerCustom(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustom", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Boolean.Input("add_noise", default=True, advanced=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput: + latent = latent_image + latent_image = latent["samples"] + latent = latent.copy() + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None)) + latent["samples"] = latent_image + + if not add_noise: + noise = Noise_EmptyNoise().generate_noise(latent) + else: + noise = Noise_RandomNoise(noise_seed).generate_noise(latent) + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) + + out = latent.copy() + out.pop("downscale_ratio_spacial", None) + out["samples"] = samples + if "x0" in x0_output: + x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) + out_denoised = latent.copy() + out_denoised["samples"] = x0_out + else: + out_denoised = out + return io.NodeOutput(out, out_denoised) + + sample = execute + +class Guider_Basic(comfy.samplers.CFGGuider): + def set_conds(self, positive): + self.inner_set_conds({"positive": positive}) + +class BasicGuider(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BasicGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("conditioning"), + ], + outputs=[io.Guider.Output()] + ) + + @classmethod + def execute(cls, model, conditioning) -> io.NodeOutput: + guider = Guider_Basic(model) + guider.set_conds(conditioning) + return io.NodeOutput(guider) + + get_guider = execute + +class CFGGuider(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + ], + outputs=[io.Guider.Output()] + ) + + @classmethod + def execute(cls, model, positive, negative, cfg) -> io.NodeOutput: + guider = comfy.samplers.CFGGuider(model) + guider.set_conds(positive, negative) + guider.set_cfg(cfg) + return io.NodeOutput(guider) + + get_guider = execute + +class Guider_DualCFG(comfy.samplers.CFGGuider): + def set_cfg(self, cfg1, cfg2, nested=False): + self.cfg1 = cfg1 + self.cfg2 = cfg2 + self.nested = nested + + def set_conds(self, positive, middle, negative): + middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative}) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + negative_cond = self.conds.get("negative", None) + middle_cond = self.conds.get("middle", None) + positive_cond = self.conds.get("positive", None) + + if self.nested: + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) + pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond) + return out[0] + self.cfg2 * (pred_text - out[0]) + else: + if model_options.get("disable_cfg1_optimization", False) == False: + if math.isclose(self.cfg2, 1.0): + negative_cond = None + if math.isclose(self.cfg1, 1.0): + middle_cond = None + + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) + return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 + +class DualCFGGuider(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DualCFGGuider", + search_aliases=["dual prompt guidance"], + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("cond1"), + io.Conditioning.Input("cond2"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Combo.Input("style", options=["regular", "nested"]), + ], + outputs=[io.Guider.Output()] + ) + + @classmethod + def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput: + guider = Guider_DualCFG(model) + guider.set_conds(cond1, cond2, negative) + guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested")) + return io.NodeOutput(guider) + + get_guider = execute + +class DisableNoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DisableNoise", + search_aliases=["zero noise"], + category="sampling/custom_sampling/noise", + inputs=[], + outputs=[io.Noise.Output()] + ) + + @classmethod + def execute(cls) -> io.NodeOutput: + return io.NodeOutput(Noise_EmptyNoise()) + + get_noise = execute + + +class RandomNoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="RandomNoise", + category="sampling/custom_sampling/noise", + inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)], + outputs=[io.Noise.Output()] + ) + + @classmethod + def execute(cls, noise_seed) -> io.NodeOutput: + return io.NodeOutput(Noise_RandomNoise(noise_seed)) + + get_noise = execute + + +class SamplerCustomAdvanced(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustomAdvanced", + category="sampling/custom_sampling", + inputs=[ + io.Noise.Input("noise"), + io.Guider.Input("guider"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput: + latent = latent_image + latent_image = latent["samples"] + latent = latent.copy() + latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None)) + latent["samples"] = latent_image + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) + samples = samples.to(comfy.model_management.intermediate_device()) + + out = latent.copy() + out.pop("downscale_ratio_spacial", None) + out["samples"] = samples + if "x0" in x0_output: + x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) + out_denoised = latent.copy() + out_denoised["samples"] = x0_out + else: + out_denoised = out + return io.NodeOutput(out, out_denoised) + + sample = execute + +class AddNoise(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="AddNoise", + category="_for_testing/custom_sampling/noise", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Noise.Input("noise"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(), + ] + ) + + @classmethod + def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput: + if len(sigmas) == 0: + return io.NodeOutput(latent_image) + + latent = latent_image + latent_image = latent["samples"] + + noisy = noise.generate_noise(latent) + + model_sampling = model.get_model_object("model_sampling") + process_latent_out = model.get_model_object("process_latent_out") + process_latent_in = model.get_model_object("process_latent_in") + + if len(sigmas) > 1: + scale = torch.abs(sigmas[0] - sigmas[-1]) + else: + scale = sigmas[0] + + if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. + latent_image = process_latent_in(latent_image) + noisy = model_sampling.noise_scaling(scale, noisy, latent_image) + noisy = process_latent_out(noisy) + noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0) + + out = latent.copy() + out["samples"] = noisy + return io.NodeOutput(out) + + add_noise = execute + +class ManualSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ManualSigmas", + search_aliases=["custom noise schedule", "define sigmas"], + category="_for_testing/custom_sampling", + is_experimental=True, + inputs=[ + io.String.Input("sigmas", default="1, 0.5", multiline=False) + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: + sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas) + sigmas = [float(i) for i in sigmas] + sigmas = torch.FloatTensor(sigmas) + return io.NodeOutput(sigmas) + +class CustomSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerCustom, + BasicScheduler, + KarrasScheduler, + ExponentialScheduler, + PolyexponentialScheduler, + LaplaceScheduler, + VPScheduler, + BetaSamplingScheduler, + SDTurboScheduler, + KSamplerSelect, + SamplerEulerAncestral, + SamplerEulerAncestralCFGPP, + SamplerLMS, + SamplerDPMPP_3M_SDE, + SamplerDPMPP_2M_SDE, + SamplerDPMPP_SDE, + SamplerDPMPP_2S_Ancestral, + SamplerDPMAdaptative, + SamplerER_SDE, + SamplerSASolver, + SamplerSEEDS2, + SplitSigmas, + SplitSigmasDenoise, + FlipSigmas, + SetFirstSigma, + ExtendIntermediateSigmas, + SamplingPercentToSigma, + CFGGuider, + DualCFGGuider, + BasicGuider, + RandomNoise, + DisableNoise, + AddNoise, + SamplerCustomAdvanced, + ManualSigmas, + ] + + +async def comfy_entrypoint() -> CustomSamplersExtension: + return CustomSamplersExtension() diff --git a/ComfyUI/comfy_extras/nodes_dataset.py b/ComfyUI/comfy_extras/nodes_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f56206b49a1ce4d74791eaac1220d3f704e5899b --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_dataset.py @@ -0,0 +1,1537 @@ +import logging +import os +import json + +import numpy as np +import torch +from PIL import Image +from typing_extensions import override + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +def load_and_process_images(image_files, input_dir): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return output_images + + +class LoadImageDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageDataSetFromFolder", + display_name="Load Image Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ) + ], + ) + + @classmethod + def execute(cls, folder): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir) + return io.NodeOutput(output_tensor) + + +class LoadImageTextDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageTextDataSetFromFolder", + display_name="Load Image and Text Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ), + io.String.Output( + display_name="texts", + is_output_list=True, + tooltip="List of text captions", + ), + ], + ) + + @classmethod + def execute(cls, folder): + logging.info(f"Loading images from folder: {folder}") + + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + + image_files = [] + for item in os.listdir(sub_input_dir): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in valid_extensions): + image_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + image_files.extend( + [ + os.path.join(path, f) + for f in os.listdir(path) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + * repeat + ) + + caption_file_path = [ + f.replace(os.path.splitext(f)[1], ".txt") for f in image_files + ] + captions = [] + for caption_file in caption_file_path: + caption_path = os.path.join(sub_input_dir, caption_file) + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + captions.append(caption) + else: + captions.append("") + + output_tensor = load_and_process_images(image_files, sub_input_dir) + + logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") + return io.NodeOutput(output_tensor, captions) + + +def save_images_to_folder(image_list, output_dir, prefix="image"): + """Utility function to save a list of image tensors to disk. + + Args: + image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W]) + output_dir: Directory to save images to + prefix: Filename prefix + + Returns: + List of saved filenames + """ + os.makedirs(output_dir, exist_ok=True) + saved_files = [] + + for idx, img_tensor in enumerate(image_list): + # Handle different tensor shapes + if isinstance(img_tensor, torch.Tensor): + # Remove batch dimension if present [1, H, W, C] -> [H, W, C] + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + + # If tensor is [C, H, W], permute to [H, W, C] + if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]: + if ( + img_tensor.shape[0] <= 4 + and img_tensor.shape[1] > 4 + and img_tensor.shape[2] > 4 + ): + img_tensor = img_tensor.permute(1, 2, 0) + + # Convert to numpy and scale to 0-255 + img_array = img_tensor.cpu().numpy() + img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8) + + # Convert to PIL Image + img = Image.fromarray(img_array) + else: + raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") + + # Save image + filename = f"{prefix}_{idx:05d}.png" + filepath = os.path.join(output_dir, filename) + img.save(filepath) + saved_files.append(filename) + + return saved_files + + +class SaveImageDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageDataSetToFolder", + display_name="Save Image Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive images as list + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + advanced=True, + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + logging.info(f"Saved {len(saved_files)} images to {output_dir}.") + return io.NodeOutput() + + +class SaveImageTextDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageTextDataSetToFolder", + display_name="Save Image and Text Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive both images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input("texts", tooltip="List of text captions to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + advanced=True, + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, texts, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + # Save captions + for idx, (filename, caption) in enumerate(zip(saved_files, texts)): + caption_filename = filename.replace(".png", ".txt") + caption_path = os.path.join(output_dir, caption_filename) + with open(caption_path, "w", encoding="utf-8") as f: + f.write(caption) + + logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.") + return io.NodeOutput() + + +# ========== Helper Functions for Transform Nodes ========== + + +def tensor_to_pil(img_tensor): + """Convert tensor to PIL Image.""" + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + return Image.fromarray(img_array) + + +def pil_to_tensor(img): + """Convert PIL Image to tensor.""" + img_array = np.array(img).astype(np.float32) / 255.0 + return torch.from_numpy(img_array)[None,] + + +# ========== Base Classes for Transform Nodes ========== + + +class ImageProcessingNode(io.ComfyNode): + """Base class for image processing nodes that operate on images. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "images" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, image, **kwargs) -> tensor (for single-item processing) + _group_process(cls, images, **kwargs) -> list[tensor] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = ImageProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + # Auto-detect is_output_list if not explicitly set + # Single processing: False (backend collects results into list) + # Group processing: True by default (can be False for single-output nodes) + output_is_list = ( + cls.is_output_list if cls.is_output_list is not None else is_group + ) + + inputs = [ + io.Image.Input( + "images", + tooltip=( + "List of images to process." if is_group else "Image to process." + ), + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/image", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=output_is_list, + tooltip="Processed images", + ) + ], + ) + + @classmethod + def execute(cls, images, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: images is list, call _group_process + result = cls._group_process(images, **params) + else: + # Individual processing: images is single item, call _process + result = cls._process(images, **params) + + return io.NodeOutput(result) + + @classmethod + def _process(cls, image, **kwargs): + """Override this method for single-item processing. + + Args: + image: tensor - Single image tensor + **kwargs: Additional parameters (already extracted from lists) + + Returns: + tensor - Processed image + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, images, **kwargs): + """Override this method for group processing. + + Args: + images: list[tensor] - List of image tensors + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[tensor] - Processed images + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +class TextProcessingNode(io.ComfyNode): + """Base class for text processing nodes that operate on texts. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "texts" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, text, **kwargs) -> str (for single-item processing) + _group_process(cls, texts, **kwargs) -> list[str] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = TextProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + inputs = [ + io.String.Input( + "texts", + tooltip="List of texts to process." if is_group else "Text to process.", + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/text", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.String.Output( + display_name="texts", + is_output_list=cls.is_output_list, + tooltip="Processed texts", + ) + ], + ) + + @classmethod + def execute(cls, texts, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: texts is list, call _group_process + result = cls._group_process(texts, **params) + else: + # Individual processing: texts is single item, call _process + result = cls._process(texts, **params) + + # Wrap result based on is_output_list + if cls.is_output_list: + # Result should already be a list (or will be for individual) + return io.NodeOutput(result if is_group else [result]) + else: + # Single output - wrap in list for NodeOutput + return io.NodeOutput([result]) + + @classmethod + def _process(cls, text, **kwargs): + """Override this method for single-item processing. + + Args: + text: str - Single text string + **kwargs: Additional parameters (already extracted from lists) + + Returns: + str - Processed text + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, texts, **kwargs): + """Override this method for group processing. + + Args: + texts: list[str] - List of text strings + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[str] - Processed texts + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +# ========== Image Transform Nodes ========== + + +class ResizeImagesByShorterEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByShorterEdge" + display_name = "Resize Images by Shorter Edge" + description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "shorter_edge", + default=512, + min=1, + max=8192, + tooltip="Target length for the shorter edge.", + ), + ] + + @classmethod + def _process(cls, image, shorter_edge): + img = tensor_to_pil(image) + w, h = img.size + if w < h: + new_w = shorter_edge + new_h = int(h * (shorter_edge / w)) + else: + new_h = shorter_edge + new_w = int(w * (shorter_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByLongerEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByLongerEdge" + display_name = "Resize Images by Longer Edge" + description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "longer_edge", + default=1024, + min=1, + max=8192, + tooltip="Target length for the longer edge.", + ), + ] + + @classmethod + def _process(cls, image, longer_edge): + resized_images = [] + for image_i in image: + img = tensor_to_pil(image_i) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + resized_images.append(pil_to_tensor(img)) + return torch.cat(resized_images, dim=0) + + +class CenterCropImagesNode(ImageProcessingNode): + node_id = "CenterCropImages" + display_name = "Center Crop Images" + description = "Center crop all images to the specified dimensions." + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + ] + + @classmethod + def _process(cls, image, width, height): + img = tensor_to_pil(image) + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class RandomCropImagesNode(ImageProcessingNode): + node_id = "RandomCropImages" + display_name = "Random Crop Images" + description = ( + "Randomly crop all images to the specified dimensions (for data augmentation)." + ) + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _process(cls, image, width, height, seed): + np.random.seed(seed % (2**32 - 1)) + img = tensor_to_pil(image) + max_left = max(0, img.width - width) + max_top = max(0, img.height - height) + left = np.random.randint(0, max_left + 1) if max_left > 0 else 0 + top = np.random.randint(0, max_top + 1) if max_top > 0 else 0 + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class NormalizeImagesNode(ImageProcessingNode): + node_id = "NormalizeImages" + display_name = "Normalize Images" + description = "Normalize images using mean and standard deviation." + extra_inputs = [ + io.Float.Input( + "mean", + default=0.5, + min=0.0, + max=1.0, + tooltip="Mean value for normalization.", + advanced=True, + ), + io.Float.Input( + "std", + default=0.5, + min=0.001, + max=1.0, + tooltip="Standard deviation for normalization.", + advanced=True, + ), + ] + + @classmethod + def _process(cls, image, mean, std): + return (image - mean) / std + + +class AdjustBrightnessNode(ImageProcessingNode): + node_id = "AdjustBrightness" + display_name = "Adjust Brightness" + description = "Adjust brightness of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return (image * factor).clamp(0.0, 1.0) + + +class AdjustContrastNode(ImageProcessingNode): + node_id = "AdjustContrast" + display_name = "Adjust Contrast" + description = "Adjust contrast of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0) + + +class ShuffleDatasetNode(ImageProcessingNode): + node_id = "ShuffleDataset" + display_name = "Shuffle Image Dataset" + description = "Randomly shuffle the order of images in the dataset." + is_group_process = True # Requires full list to shuffle + extra_inputs = [ + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _group_process(cls, images, seed): + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + return [images[i] for i in indices] + + +class ShuffleImageTextDatasetNode(io.ComfyNode): + """Special node that shuffles both images and texts together.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ShuffleImageTextDataset", + display_name="Shuffle Image-Text Dataset", + category="dataset/image", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Image.Input("images", tooltip="List of images to shuffle."), + io.String.Input("texts", tooltip="List of texts to shuffle."), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed.", + ), + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="Shuffled images", + ), + io.String.Output( + display_name="texts", is_output_list=True, tooltip="Shuffled texts" + ), + ], + ) + + @classmethod + def execute(cls, images, texts, seed): + seed = seed[0] # Extract scalar + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + shuffled_images = [images[i] for i in indices] + shuffled_texts = [texts[i] for i in indices] + return io.NodeOutput(shuffled_images, shuffled_texts) + + +# ========== Text Transform Nodes ========== + + +class TextToLowercaseNode(TextProcessingNode): + node_id = "TextToLowercase" + display_name = "Text to Lowercase" + description = "Convert all texts to lowercase." + + @classmethod + def _process(cls, text): + return text.lower() + + +class TextToUppercaseNode(TextProcessingNode): + node_id = "TextToUppercase" + display_name = "Text to Uppercase" + description = "Convert all texts to uppercase." + + @classmethod + def _process(cls, text): + return text.upper() + + +class TruncateTextNode(TextProcessingNode): + node_id = "TruncateText" + display_name = "Truncate Text" + description = "Truncate all texts to a maximum length." + extra_inputs = [ + io.Int.Input( + "max_length", default=77, min=1, max=10000, tooltip="Maximum text length." + ), + ] + + @classmethod + def _process(cls, text, max_length): + return text[:max_length] + + +class AddTextPrefixNode(TextProcessingNode): + node_id = "AddTextPrefix" + display_name = "Add Text Prefix" + description = "Add a prefix to all texts." + extra_inputs = [ + io.String.Input("prefix", default="", tooltip="Prefix to add."), + ] + + @classmethod + def _process(cls, text, prefix): + return prefix + text + + +class AddTextSuffixNode(TextProcessingNode): + node_id = "AddTextSuffix" + display_name = "Add Text Suffix" + description = "Add a suffix to all texts." + extra_inputs = [ + io.String.Input("suffix", default="", tooltip="Suffix to add."), + ] + + @classmethod + def _process(cls, text, suffix): + return text + suffix + + +class ReplaceTextNode(TextProcessingNode): + node_id = "ReplaceText" + display_name = "Replace Text" + description = "Replace text in all texts." + extra_inputs = [ + io.String.Input("find", default="", tooltip="Text to find."), + io.String.Input("replace", default="", tooltip="Text to replace with."), + ] + + @classmethod + def _process(cls, text, find, replace): + return text.replace(find, replace) + + +class StripWhitespaceNode(TextProcessingNode): + node_id = "StripWhitespace" + display_name = "Strip Whitespace" + description = "Strip leading and trailing whitespace from all texts." + + @classmethod + def _process(cls, text): + return text.strip() + + +# ========== Group Processing Example Nodes ========== + + +class ImageDeduplicationNode(ImageProcessingNode): + """Remove duplicate or very similar images from the dataset using perceptual hashing.""" + + node_id = "ImageDeduplication" + display_name = "Image Deduplication" + description = "Remove duplicate or very similar images from the dataset." + is_group_process = True # Requires full list to compare images + extra_inputs = [ + io.Float.Input( + "similarity_threshold", + default=0.95, + min=0.0, + max=1.0, + tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.", + advanced=True, + ), + ] + + @classmethod + def _group_process(cls, images, similarity_threshold): + """Remove duplicate images using perceptual hashing.""" + if len(images) == 0: + return [] + + # Compute simple perceptual hash for each image + def compute_hash(img_tensor): + """Compute a simple perceptual hash by resizing to 8x8 and comparing to average.""" + img = tensor_to_pil(img_tensor) + # Resize to 8x8 + img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L") + # Get pixels + pixels = list(img_small.getdata()) + # Compute average + avg = sum(pixels) / len(pixels) + # Create hash (1 if above average, 0 otherwise) + hash_bits = "".join("1" if p > avg else "0" for p in pixels) + return hash_bits + + def hamming_distance(hash1, hash2): + """Compute Hamming distance between two hash strings.""" + return sum(c1 != c2 for c1, c2 in zip(hash1, hash2)) + + # Compute hashes for all images + hashes = [compute_hash(img) for img in images] + + # Find duplicates + keep_indices = [] + for i in range(len(images)): + is_duplicate = False + for j in keep_indices: + # Compare hashes + distance = hamming_distance(hashes[i], hashes[j]) + similarity = 1.0 - (distance / 64.0) # 64 bits total + if similarity >= similarity_threshold: + is_duplicate = True + logging.info( + f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping" + ) + break + + if not is_duplicate: + keep_indices.append(i) + + # Return only unique images + unique_images = [images[i] for i in keep_indices] + logging.info( + f"Deduplication: kept {len(unique_images)} out of {len(images)} images" + ) + return unique_images + + +class ImageGridNode(ImageProcessingNode): + """Combine multiple images into a single grid/collage.""" + + node_id = "ImageGrid" + display_name = "Image Grid" + description = "Arrange multiple images into a grid layout." + is_group_process = True # Requires full list to create grid + is_output_list = False # Outputs single grid image + extra_inputs = [ + io.Int.Input( + "columns", + default=4, + min=1, + max=20, + tooltip="Number of columns in the grid.", + ), + io.Int.Input( + "cell_width", + default=256, + min=32, + max=2048, + tooltip="Width of each cell in the grid.", + advanced=True, + ), + io.Int.Input( + "cell_height", + default=256, + min=32, + max=2048, + tooltip="Height of each cell in the grid.", + advanced=True, + ), + io.Int.Input( + "padding", default=4, min=0, max=50, tooltip="Padding between images.", advanced=True + ), + ] + + @classmethod + def _group_process(cls, images, columns, cell_width, cell_height, padding): + """Arrange images into a grid.""" + if len(images) == 0: + raise ValueError("Cannot create grid from empty image list") + + # Calculate grid dimensions + num_images = len(images) + rows = (num_images + columns - 1) // columns # Ceiling division + + # Calculate total grid size + grid_width = columns * cell_width + (columns - 1) * padding + grid_height = rows * cell_height + (rows - 1) * padding + + # Create blank grid + grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0)) + + # Place images + for idx, img_tensor in enumerate(images): + row = idx // columns + col = idx % columns + + # Convert to PIL and resize to cell size + img = tensor_to_pil(img_tensor) + img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS) + + # Calculate position + x = col * (cell_width + padding) + y = row * (cell_height + padding) + + # Paste into grid + grid.paste(img, (x, y)) + + logging.info( + f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})" + ) + return pil_to_tensor(grid) + + +class MergeImageListsNode(ImageProcessingNode): + """Merge multiple image lists into a single list.""" + + node_id = "MergeImageLists" + display_name = "Merge Image Lists" + description = "Concatenate multiple image lists into one." + is_group_process = True # Receives images as list + + @classmethod + def _group_process(cls, images): + """Simply return the images list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged image list contains {len(images)} images") + return images + + +class MergeTextListsNode(TextProcessingNode): + """Merge multiple text lists into a single list.""" + + node_id = "MergeTextLists" + display_name = "Merge Text Lists" + description = "Concatenate multiple text lists into one." + is_group_process = True # Receives texts as list + + @classmethod + def _group_process(cls, texts): + """Simply return the texts list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged text list contains {len(texts)} texts") + return texts + + +# ========== Training Dataset Nodes ========== + + +class ResolutionBucket(io.ComfyNode): + """Bucket latents and conditions by resolution for efficient batch training.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResolutionBucket", + display_name="Resolution Bucket", + category="dataset", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts to bucket by resolution.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists (must match latents length).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of batched latent dicts, one per resolution bucket.", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of condition lists, one per resolution bucket.", + ), + ], + ) + + @classmethod + def execute(cls, latents, conditioning): + # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1 + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})." + ) + + # Flatten latents and conditions to individual samples + flat_latents = [] # list of (C, H, W) tensors + flat_conditions = [] # list of condition lists + + for latent_dict, cond in zip(latents, conditioning): + samples = latent_dict["samples"] # (B, C, H, W) + batch_size = samples.shape[0] + + # cond is a list of conditions with length == batch_size + for i in range(batch_size): + flat_latents.append(samples[i]) # (C, H, W) + flat_conditions.append(cond[i]) # single condition + + # Group by resolution (H, W) + buckets = {} # (H, W) -> {"latents": list, "conditions": list} + + for latent, cond in zip(flat_latents, flat_conditions): + # latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W) + h, w = latent.shape[-2], latent.shape[-1] + key = (h, w) + + if key not in buckets: + buckets[key] = {"latents": [], "conditions": []} + + buckets[key]["latents"].append(latent) + buckets[key]["conditions"].append(cond) + + # Convert buckets to output format + output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W) + output_conditions = [] # list[list[cond]] where each inner list has Bi conditions + + for (h, w), bucket_data in buckets.items(): + # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W) + stacked_latents = torch.stack(bucket_data["latents"], dim=0) + output_latents.append({"samples": stacked_latents}) + + # Conditions stay as list of condition lists + output_conditions.append(bucket_data["conditions"]) + + logging.info( + f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples" + ) + + logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples") + return io.NodeOutput(output_latents, output_conditions) + + +class MakeTrainingDataset(io.ComfyNode): + """Encode images with VAE and texts with CLIP to create a training dataset.""" + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MakeTrainingDataset", + search_aliases=["encode dataset"], + display_name="Make Training Dataset", + category="dataset", + is_experimental=True, + is_input_list=True, # images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to encode."), + io.Vae.Input( + "vae", tooltip="VAE model for encoding images to latents." + ), + io.Clip.Input( + "clip", tooltip="CLIP model for encoding text to conditioning." + ), + io.String.Input( + "texts", + optional=True, + tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, images, vae, clip, texts=None): + # Extract scalars (vae and clip are single values wrapped in lists) + vae = vae[0] + clip = clip[0] + + # Handle text list + num_images = len(images) + + if texts is None or len(texts) == 0: + # Treat as [""] for unconditional training + texts = [""] + + if len(texts) == 1 and num_images > 1: + # Repeat single text for all images + texts = texts * num_images + elif len(texts) != num_images: + raise ValueError( + f"Number of texts ({len(texts)}) does not match number of images ({num_images}). " + f"Text list should have length {num_images}, 1, or 0." + ) + + # Encode images with VAE + logging.info(f"Encoding {num_images} images with VAE...") + latents_list = [] # list[{"samples": tensor}] + for img_tensor in images: + # img_tensor is [1, H, W, 3] + latent_tensor = vae.encode(img_tensor[:, :, :, :3]) + latents_list.append({"samples": latent_tensor}) + + # Encode texts with CLIP + logging.info(f"Encoding {len(texts)} texts with CLIP...") + conditioning_list = [] # list[list[cond]] + for text in texts: + if text == "": + cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) + else: + tokens = clip.tokenize(text) + cond = clip.encode_from_tokens_scheduled(tokens) + conditioning_list.append(cond) + + logging.info( + f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning." + ) + return io.NodeOutput(latents_list, conditioning_list) + + +class SaveTrainingDataset(io.ComfyNode): + """Save encoded training dataset (latents + conditioning) to disk.""" + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveTrainingDataset", + search_aliases=["export training data"], + display_name="Save Training Dataset", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive lists + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts from MakeTrainingDataset.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists from MakeTrainingDataset.", + ), + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder to save dataset (inside output directory).", + ), + io.Int.Input( + "shard_size", + default=1000, + min=1, + max=100000, + tooltip="Number of samples per shard file.", + advanced=True, + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, latents, conditioning, folder_name, shard_size): + # Extract scalars + folder_name = folder_name[0] + shard_size = shard_size[0] + + # latents: list[{"samples": tensor}] + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). " + f"Something went wrong in dataset preparation." + ) + + # Create output directory + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + os.makedirs(output_dir, exist_ok=True) + + # Prepare data pairs + num_samples = len(latents) + num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division + + logging.info( + f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..." + ) + + # Save data in shards + for shard_idx in range(num_shards): + start_idx = shard_idx * shard_size + end_idx = min(start_idx + shard_size, num_samples) + + # Get shard data (list of latent dicts and conditioning lists) + shard_data = { + "latents": latents[start_idx:end_idx], + "conditioning": conditioning[start_idx:end_idx], + } + + # Save shard + shard_filename = f"shard_{shard_idx:04d}.pkl" + shard_path = os.path.join(output_dir, shard_filename) + + with open(shard_path, "wb") as f: + torch.save(shard_data, f) + + logging.info( + f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)" + ) + + # Save metadata + metadata = { + "num_samples": num_samples, + "num_shards": num_shards, + "shard_size": shard_size, + } + metadata_path = os.path.join(output_dir, "metadata.json") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logging.info(f"Successfully saved {num_samples} samples to {output_dir}.") + return io.NodeOutput() + + +class LoadTrainingDataset(io.ComfyNode): + """Load encoded training dataset from disk.""" + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadTrainingDataset", + search_aliases=["import dataset", "training data"], + display_name="Load Training Dataset", + category="dataset", + is_experimental=True, + inputs=[ + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder containing the saved dataset (inside output directory).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, folder_name): + # Get dataset directory + dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + + if not os.path.exists(dataset_dir): + raise ValueError(f"Dataset directory not found: {dataset_dir}") + + # Find all shard files + shard_files = sorted( + [ + f + for f in os.listdir(dataset_dir) + if f.startswith("shard_") and f.endswith(".pkl") + ] + ) + + if not shard_files: + raise ValueError(f"No shard files found in {dataset_dir}") + + logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") + + # Load all shards + all_latents = [] # list[{"samples": tensor}] + all_conditioning = [] # list[list[cond]] + + for shard_file in shard_files: + shard_path = os.path.join(dataset_dir, shard_file) + + with open(shard_path, "rb") as f: + shard_data = torch.load(f) + + all_latents.extend(shard_data["latents"]) + all_conditioning.extend(shard_data["conditioning"]) + + logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples") + + logging.info( + f"Successfully loaded {len(all_latents)} samples from {dataset_dir}." + ) + return io.NodeOutput(all_latents, all_conditioning) + + +# ========== Extension Setup ========== + + +class DatasetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # Data loading/saving nodes + LoadImageDataSetFromFolderNode, + LoadImageTextDataSetFromFolderNode, + SaveImageDataSetToFolderNode, + SaveImageTextDataSetToFolderNode, + # Image transform nodes + ResizeImagesByShorterEdgeNode, + ResizeImagesByLongerEdgeNode, + CenterCropImagesNode, + RandomCropImagesNode, + NormalizeImagesNode, + AdjustBrightnessNode, + AdjustContrastNode, + ShuffleDatasetNode, + ShuffleImageTextDatasetNode, + # Text transform nodes + TextToLowercaseNode, + TextToUppercaseNode, + TruncateTextNode, + AddTextPrefixNode, + AddTextSuffixNode, + ReplaceTextNode, + StripWhitespaceNode, + # Group processing examples + ImageDeduplicationNode, + ImageGridNode, + MergeImageListsNode, + MergeTextListsNode, + # Training dataset nodes + MakeTrainingDataset, + SaveTrainingDataset, + LoadTrainingDataset, + ResolutionBucket, + ] + + +async def comfy_entrypoint() -> DatasetExtension: + return DatasetExtension() diff --git a/ComfyUI/comfy_extras/nodes_differential_diffusion.py b/ComfyUI/comfy_extras/nodes_differential_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..5279c6a69c186a2e161f601d139ff833b7f0e6c4 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_differential_diffusion.py @@ -0,0 +1,73 @@ +# code adapted from https://github.com/exx8/differential-diffusion + +from typing_extensions import override + +import torch +from comfy_api.latest import ComfyExtension, io + + +class DifferentialDiffusion(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DifferentialDiffusion", + search_aliases=["inpaint gradient", "variable denoise strength"], + display_name="Differential Diffusion", + category="_for_testing", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "strength", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + optional=True, + ), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, strength=1.0) -> io.NodeOutput: + model = model.clone() + model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength)) + return io.NodeOutput(model) + + @classmethod + def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float): + model = extra_options["model"] + step_sigmas = extra_options["sigmas"] + sigma_to = model.inner_model.model_sampling.sigma_min + if step_sigmas[-1] > sigma_to: + sigma_to = step_sigmas[-1] + sigma_from = step_sigmas[0] + + ts_from = model.inner_model.model_sampling.timestep(sigma_from) + ts_to = model.inner_model.model_sampling.timestep(sigma_to) + current_ts = model.inner_model.model_sampling.timestep(sigma[0]) + + threshold = (current_ts - ts_to) / (ts_from - ts_to) + + # Generate the binary mask based on the threshold + binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype) + + # Blend binary mask with the original denoise_mask using strength + if strength and strength < 1: + blended_mask = strength * binary_mask + (1 - strength) * denoise_mask + return blended_mask + else: + return binary_mask + + +class DifferentialDiffusionExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + DifferentialDiffusion, + ] + + +async def comfy_entrypoint() -> DifferentialDiffusionExtension: + return DifferentialDiffusionExtension() diff --git a/ComfyUI/comfy_extras/nodes_easycache.py b/ComfyUI/comfy_extras/nodes_easycache.py new file mode 100644 index 0000000000000000000000000000000000000000..b985960274c2ab1684364685074b486cbf259357 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_easycache.py @@ -0,0 +1,530 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Union +from comfy_api.latest import io, ComfyExtension +import comfy.patcher_extension +import logging +import torch +import comfy.model_patcher +if TYPE_CHECKING: + from uuid import UUID + + +def _extract_tensor(data, output_channels): + """Extract tensor from data, handling both single tensors and lists.""" + if isinstance(data, list): + # LTX2 AV tensors: [video, audio] + return data[0][:, :output_channels], data[1][:, :output_channels] + return data[:, :output_channels], None + + +def easycache_forward_wrapper(executor, *args, **kwargs): + # get values from args + transformer_options: dict[str] = args[-1] + if not isinstance(transformer_options, dict): + transformer_options = kwargs.get("transformer_options") + if not transformer_options: + transformer_options = args[-2] + easycache: EasyCacheHolder = transformer_options["easycache"] + x, ax = _extract_tensor(args[0], easycache.output_channels) + sigmas = transformer_options["sigmas"] + uuids = transformer_options["uuids"] + if sigmas is not None and easycache.is_past_end_timestep(sigmas): + return executor(*args, **kwargs) + # prepare next x_prev + has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) + next_x_prev = x + input_change = None + do_easycache = easycache.should_do_easycache(sigmas) + if do_easycache: + easycache.check_metadata(x) + # if there isn't a cache diff for current conds, we cannot skip this step + can_apply_cache_diff = easycache.can_apply_cache_diff(uuids) + # if first cond marked this step for skipping, skip it and use appropriate cached values + if easycache.skip_current_step and can_apply_cache_diff: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}") + result = easycache.apply_cache_diff(x, uuids) + if ax is not None: + result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True) + return [result, result_audio] + return result + if easycache.initial_step: + easycache.first_cond_uuid = uuids[0] + has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) + easycache.initial_step = False + if has_first_cond_uuid: + if easycache.has_x_prev_subsampled(): + input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean() + if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.cumulative_change_rate += approx_output_change_rate + if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + # other conds should also skip this step, and instead use their cached values + easycache.skip_current_step = True + result = easycache.apply_cache_diff(x, uuids) + if ax is not None: + result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True) + return [result, result_audio] + return result + else: + if easycache.verbose: + logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + easycache.cumulative_change_rate = 0.0 + + full_output: torch.Tensor = executor(*args, **kwargs) + output, audio_output = _extract_tensor(full_output, easycache.output_channels) + if has_first_cond_uuid and easycache.has_output_prev_norm(): + output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() + if easycache.verbose: + output_change_rate = output_change / easycache.output_prev_norm + easycache.output_change_rates.append(output_change_rate.item()) + if easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.approx_output_change_rates.append(approx_output_change_rate.item()) + if easycache.verbose: + logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}") + if input_change is not None: + easycache.relative_transformation_rate = output_change / input_change + if easycache.verbose: + logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}") + # TODO: allow cache_diff to be offloaded + easycache.update_cache_diff(output, next_x_prev, uuids) + if audio_output is not None: + easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True) + if has_first_cond_uuid: + easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids) + easycache.output_prev_subsampled = easycache.subsample(output, uuids) + easycache.output_prev_norm = output.flatten().abs().mean() + if easycache.verbose: + logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") + return full_output + +def lazycache_predict_noise_wrapper(executor, *args, **kwargs): + # get values from args + timestep: float = args[1] + model_options: dict[str] = args[2] + easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] + if easycache.is_past_end_timestep(timestep): + return executor(*args, **kwargs) + x: torch.Tensor = args[0][:, :easycache.output_channels] + # prepare next x_prev + next_x_prev = x + input_change = None + do_easycache = easycache.should_do_easycache(timestep) + if do_easycache: + easycache.check_metadata(x) + if easycache.has_x_prev_subsampled(): + if easycache.has_x_prev_subsampled(): + input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean() + if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.cumulative_change_rate += approx_output_change_rate + if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.verbose: + logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + # other conds should also skip this step, and instead use their cached values + easycache.skip_current_step = True + return easycache.apply_cache_diff(x) + else: + if easycache.verbose: + logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") + easycache.cumulative_change_rate = 0.0 + output: torch.Tensor = executor(*args, **kwargs) + if easycache.has_output_prev_norm(): + output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() + if easycache.verbose: + output_change_rate = output_change / easycache.output_prev_norm + easycache.output_change_rates.append(output_change_rate.item()) + if easycache.has_relative_transformation_rate(): + approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm + easycache.approx_output_change_rates.append(approx_output_change_rate.item()) + if easycache.verbose: + logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}") + if input_change is not None: + easycache.relative_transformation_rate = output_change / input_change + if easycache.verbose: + logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}") + # TODO: allow cache_diff to be offloaded + easycache.update_cache_diff(output, next_x_prev) + easycache.x_prev_subsampled = easycache.subsample(next_x_prev) + easycache.output_prev_subsampled = easycache.subsample(output) + easycache.output_prev_norm = output.flatten().abs().mean() + if easycache.verbose: + logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") + return output + +def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs): + model_options = args[-1] + easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"] + easycache.skip_current_step = False + # TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset + return executor(*args, **kwargs) + +def easycache_sample_wrapper(executor, *args, **kwargs): + """ + This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end. + """ + try: + guider = executor.class_obj + orig_model_options = guider.model_options + guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options) + # clone and prepare timesteps + guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling) + easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache'] + logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}") + return executor(*args, **kwargs) + finally: + easycache = guider.model_options['transformer_options']['easycache'] + output_change_rates = easycache.output_change_rates + approx_output_change_rates = easycache.approx_output_change_rates + if easycache.verbose: + logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}") + logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}") + total_steps = len(args[3])-1 + # catch division by zero for log statement; sucks to crash after all sampling is done + try: + speedup = total_steps/(total_steps-easycache.total_steps_skipped) + except ZeroDivisionError: + speedup = 1.0 + logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).") + easycache.reset() + guider.model_options = orig_model_options + + +class EasyCacheHolder: + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None): + self.name = "EasyCache" + self.reuse_threshold = reuse_threshold + self.start_percent = start_percent + self.end_percent = end_percent + self.subsample_factor = subsample_factor + self.offload_cache_diff = offload_cache_diff + self.verbose = verbose + # timestep values + self.start_t = 0.0 + self.end_t = 0.0 + # control values + self.relative_transformation_rate: float = None + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.skip_current_step = False + # cache values + self.first_cond_uuid = None + self.x_prev_subsampled: torch.Tensor = None + self.output_prev_subsampled: torch.Tensor = None + self.output_prev_norm: torch.Tensor = None + self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {} + self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {} + self.output_change_rates = [] + self.approx_output_change_rates = [] + self.total_steps_skipped = 0 + # how to deal with mismatched dims + self.allow_mismatch = True + self.cut_from_start = True + self.state_metadata = None + self.output_channels = output_channels + + def is_past_end_timestep(self, timestep: float) -> bool: + return not (timestep[0] > self.end_t).item() + + def should_do_easycache(self, timestep: float) -> bool: + return (timestep[0] <= self.start_t).item() + + def has_x_prev_subsampled(self) -> bool: + return self.x_prev_subsampled is not None + + def has_output_prev_subsampled(self) -> bool: + return self.output_prev_subsampled is not None + + def has_output_prev_norm(self) -> bool: + return self.output_prev_norm is not None + + def has_relative_transformation_rate(self) -> bool: + return self.relative_transformation_rate is not None + + def prepare_timesteps(self, model_sampling): + self.start_t = model_sampling.percent_to_sigma(self.start_percent) + self.end_t = model_sampling.percent_to_sigma(self.end_percent) + return self + + def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor: + batch_offset = x.shape[0] // len(uuids) + uuid_idx = uuids.index(self.first_cond_uuid) + if self.subsample_factor > 1: + to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor] + if clone: + return to_return.clone() + return to_return + to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...] + if clone: + return to_return.clone() + return to_return + + def can_apply_cache_diff(self, uuids: list[UUID]) -> bool: + return all(uuid in self.uuid_cache_diffs for uuid in uuids) + + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False): + if self.first_cond_uuid in uuids and not is_audio: + self.total_steps_skipped += 1 + cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs + batch_offset = x.shape[0] // len(uuids) + for i, uuid in enumerate(uuids): + # slice out only what is relevant to this cond + batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)] + # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video) + if x.shape[1:] != cache_diffs[uuid].shape[1:]: + if not self.allow_mismatch: + raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good") + slicing = [] + skip_this_dim = True + for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape): + if skip_this_dim: + skip_this_dim = False + continue + if dim_u != dim_x: + if self.cut_from_start: + slicing.append(slice(dim_x-dim_u, None)) + else: + slicing.append(slice(None, dim_u)) + else: + slicing.append(slice(None)) + batch_slice = batch_slice + slicing + x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device) + return x + + def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False): + cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs + # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video) + if output.shape[1:] != x.shape[1:]: + if not self.allow_mismatch: + raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good") + slicing = [] + skip_dim = True + for dim_o, dim_x in zip(output.shape, x.shape): + if not skip_dim and dim_o != dim_x: + if self.cut_from_start: + slicing.append(slice(dim_x-dim_o, None)) + else: + slicing.append(slice(None, dim_o)) + else: + slicing.append(slice(None)) + skip_dim = False + x = x[tuple(slicing)] + diff = output - x + batch_offset = diff.shape[0] // len(uuids) + for i, uuid in enumerate(uuids): + cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...] + + def has_first_cond_uuid(self, uuids: list[UUID]) -> bool: + return self.first_cond_uuid in uuids + + def check_metadata(self, x: torch.Tensor) -> bool: + metadata = (x.device, x.dtype, x.shape[1:]) + if self.state_metadata is None: + self.state_metadata = metadata + return True + if metadata == self.state_metadata: + return True + logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state") + self.reset() + return False + + def reset(self): + self.relative_transformation_rate = 0.0 + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.skip_current_step = False + self.output_change_rates = [] + self.first_cond_uuid = None + del self.x_prev_subsampled + self.x_prev_subsampled = None + del self.output_prev_subsampled + self.output_prev_subsampled = None + del self.output_prev_norm + self.output_prev_norm = None + del self.uuid_cache_diffs + self.uuid_cache_diffs = {} + del self.uuid_cache_diffs_audio + self.uuid_cache_diffs_audio = {} + self.total_steps_skipped = 0 + self.state_metadata = None + return self + + def clone(self): + return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels) + + +class EasyCacheNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="EasyCache", + display_name="EasyCache", + description="Native EasyCache implementation.", + category="advanced/debug/model", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to add EasyCache to."), + io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache.", advanced=True), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache.", advanced=True), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True), + ], + outputs=[ + io.Model.Output(tooltip="The model with EasyCache."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: + model = model.clone() + model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper) + return io.NodeOutput(model) + + +class LazyCacheHolder: + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None): + self.name = "LazyCache" + self.reuse_threshold = reuse_threshold + self.start_percent = start_percent + self.end_percent = end_percent + self.subsample_factor = subsample_factor + self.offload_cache_diff = offload_cache_diff + self.verbose = verbose + # timestep values + self.start_t = 0.0 + self.end_t = 0.0 + # control values + self.relative_transformation_rate: float = None + self.cumulative_change_rate = 0.0 + self.initial_step = True + # cache values + self.x_prev_subsampled: torch.Tensor = None + self.output_prev_subsampled: torch.Tensor = None + self.output_prev_norm: torch.Tensor = None + self.cache_diff: torch.Tensor = None + self.output_change_rates = [] + self.approx_output_change_rates = [] + self.total_steps_skipped = 0 + self.state_metadata = None + self.output_channels = output_channels + + def has_cache_diff(self) -> bool: + return self.cache_diff is not None + + def is_past_end_timestep(self, timestep: float) -> bool: + return not (timestep[0] > self.end_t).item() + + def should_do_easycache(self, timestep: float) -> bool: + return (timestep[0] <= self.start_t).item() + + def has_x_prev_subsampled(self) -> bool: + return self.x_prev_subsampled is not None + + def has_output_prev_subsampled(self) -> bool: + return self.output_prev_subsampled is not None + + def has_output_prev_norm(self) -> bool: + return self.output_prev_norm is not None + + def has_relative_transformation_rate(self) -> bool: + return self.relative_transformation_rate is not None + + def prepare_timesteps(self, model_sampling): + self.start_t = model_sampling.percent_to_sigma(self.start_percent) + self.end_t = model_sampling.percent_to_sigma(self.end_percent) + return self + + def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor: + if self.subsample_factor > 1: + to_return = x[..., ::self.subsample_factor, ::self.subsample_factor] + if clone: + return to_return.clone() + return to_return + if clone: + return x.clone() + return x + + def apply_cache_diff(self, x: torch.Tensor): + self.total_steps_skipped += 1 + return x + self.cache_diff.to(x.device) + + def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor): + self.cache_diff = output - x + + def check_metadata(self, x: torch.Tensor) -> bool: + metadata = (x.device, x.dtype, x.shape) + if self.state_metadata is None: + self.state_metadata = metadata + return True + if metadata == self.state_metadata: + return True + logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state") + self.reset() + return False + + def reset(self): + self.relative_transformation_rate = 0.0 + self.cumulative_change_rate = 0.0 + self.initial_step = True + self.output_change_rates = [] + self.approx_output_change_rates = [] + del self.cache_diff + self.cache_diff = None + del self.x_prev_subsampled + self.x_prev_subsampled = None + del self.output_prev_subsampled + self.output_prev_subsampled = None + del self.output_prev_norm + self.output_prev_norm = None + self.total_steps_skipped = 0 + self.state_metadata = None + return self + + def clone(self): + return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels) + +class LazyCacheNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LazyCache", + display_name="LazyCache", + description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.", + category="advanced/debug/model", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to add LazyCache to."), + io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache.", advanced=True), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache.", advanced=True), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True), + ], + outputs=[ + io.Model.Output(tooltip="The model with LazyCache."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: + model = model.clone() + model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper) + model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper) + return io.NodeOutput(model) + + +class EasyCacheExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EasyCacheNode, + LazyCacheNode, + ] + +def comfy_entrypoint(): + return EasyCacheExtension() diff --git a/ComfyUI/comfy_extras/nodes_edit_model.py b/ComfyUI/comfy_extras/nodes_edit_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9d213b9789c25fc1a5205f661b2af1fe10b47f --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_edit_model.py @@ -0,0 +1,38 @@ +import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class ReferenceLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReferenceLatent", + category="advanced/conditioning/edit_models", + description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + + @classmethod + def execute(cls, conditioning, latent=None) -> io.NodeOutput: + if latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True) + return io.NodeOutput(conditioning) + + +class EditModelExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ReferenceLatent, + ] + + +def comfy_entrypoint() -> EditModelExtension: + return EditModelExtension() diff --git a/ComfyUI/comfy_extras/nodes_eps.py b/ComfyUI/comfy_extras/nodes_eps.py new file mode 100644 index 0000000000000000000000000000000000000000..887d77081d94f638dcb4714cb1f9cf993cfe7f3a --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_eps.py @@ -0,0 +1,172 @@ +import torch +from typing_extensions import override + +from comfy.k_diffusion.sampling import sigma_to_half_log_snr +from comfy_api.latest import ComfyExtension, io + + +class EpsilonScaling(io.ComfyNode): + """ + Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models' + (https://arxiv.org/abs/2308.15321v6). + + This method mitigates exposure bias by scaling the predicted noise during sampling, + which can significantly improve sample quality. This implementation uses the "uniform schedule" + recommended by the paper for its practicality and effectiveness. + """ + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Epsilon Scaling", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "scaling_factor", + default=1.005, + min=0.5, + max=1.5, + step=0.001, + display_mode=io.NumberDisplay.number, + advanced=True, + ), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, scaling_factor) -> io.NodeOutput: + # Prevent division by zero, though the UI's min value should prevent this. + if scaling_factor == 0: + scaling_factor = 1e-9 + + def epsilon_scaling_function(args): + """ + This function is applied after the CFG guidance has been calculated. + It recalculates the denoised latent by scaling the predicted noise. + """ + denoised = args["denoised"] + x = args["input"] + + noise_pred = x - denoised + + scaled_noise_pred = noise_pred / scaling_factor + + new_denoised = x - scaled_noise_pred + + return new_denoised + + # Clone the model patcher to avoid modifying the original model in place + model_clone = model.clone() + + model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function) + + return io.NodeOutput(model_clone) + + +def compute_tsr_rescaling_factor( + snr: torch.Tensor, tsr_k: float, tsr_variance: float +) -> torch.Tensor: + """Compute the rescaling score ratio in Temporal Score Rescaling. + + See equation (6) in https://arxiv.org/pdf/2510.01184v1. + """ + posinf_mask = torch.isposinf(snr) + rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1) + return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k + + +class TemporalScoreRescaling(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TemporalScoreRescaling", + display_name="TSR - Temporal Score Rescaling", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "tsr_k", + tooltip=( + "Controls the rescaling strength.\n" + "Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling." + ), + default=0.95, + min=0.01, + max=100.0, + step=0.001, + display_mode=io.NumberDisplay.number, + advanced=True, + ), + io.Float.Input( + "tsr_sigma", + tooltip=( + "Controls how early rescaling takes effect.\n" + "Larger values take effect earlier." + ), + default=1.0, + min=0.01, + max=100.0, + step=0.001, + display_mode=io.NumberDisplay.number, + advanced=True, + ), + ], + outputs=[ + io.Model.Output( + display_name="patched_model", + ), + ], + description=( + "[Post-CFG Function]\n" + "TSR - Temporal Score Rescaling (2510.01184)\n\n" + "Rescaling the model's score or noise to steer the sampling diversity.\n" + ), + ) + + @classmethod + def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput: + tsr_variance = tsr_sigma**2 + + def temporal_score_rescaling(args): + denoised = args["denoised"] + x = args["input"] + sigma = args["sigma"] + curr_model = args["model"] + + # No rescaling (r = 1) or no noise + if tsr_k == 1 or sigma == 0: + return denoised + + model_sampling = curr_model.current_patcher.get_model_object("model_sampling") + half_log_snr = sigma_to_half_log_snr(sigma, model_sampling) + snr = (2 * half_log_snr).exp() + + # No rescaling needed (r = 1) + if snr == 0: + return denoised + + rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance) + + # Derived from scaled_denoised = (x - r * sigma * noise) / alpha + alpha = sigma * half_log_snr.exp() + return torch.lerp(x / alpha, denoised, rescaling_r) + + m = model.clone() + m.set_model_sampler_post_cfg_function(temporal_score_rescaling) + return io.NodeOutput(m) + + +class EpsilonScalingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EpsilonScaling, + TemporalScoreRescaling, + ] + + +async def comfy_entrypoint() -> EpsilonScalingExtension: + return EpsilonScalingExtension() diff --git a/ComfyUI/comfy_extras/nodes_flux.py b/ComfyUI/comfy_extras/nodes_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..77d59b1a78b50b26c2cee5ad1ada2f4afbdd5cab --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_flux.py @@ -0,0 +1,314 @@ +import node_helpers +import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import comfy.model_management +import torch +import math +import nodes +import comfy.ldm.flux.math + +class CLIPTextEncodeFlux(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeFlux", + category="advanced/conditioning/flux", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput: + tokens = clip.tokenize(clip_l) + tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] + + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance})) + + encode = execute # TODO: remove + +class EmptyFlux2LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyFlux2LatentImage", + display_name="Empty Flux 2 Latent", + category="latent", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + +class FluxGuidance(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FluxGuidance", + category="advanced/conditioning/flux", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, conditioning, guidance) -> io.NodeOutput: + c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance}) + return io.NodeOutput(c) + + append = execute # TODO: remove + + +class FluxDisableGuidance(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FluxDisableGuidance", + category="advanced/conditioning/flux", + description="This node completely disables the guidance embed on Flux and Flux like models", + inputs=[ + io.Conditioning.Input("conditioning"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, conditioning) -> io.NodeOutput: + c = node_helpers.conditioning_set_values(conditioning, {"guidance": None}) + return io.NodeOutput(c) + + append = execute # TODO: remove + + +PREFERED_KONTEXT_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +class FluxKontextImageScale(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FluxKontextImageScale", + category="advanced/conditioning/flux", + description="This node resizes the image to one that is more optimal for flux kontext.", + inputs=[ + io.Image.Input("image"), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, image) -> io.NodeOutput: + width = image.shape[2] + height = image.shape[1] + aspect_ratio = width / height + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) + image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) + return io.NodeOutput(image) + + scale = execute # TODO: remove + + +class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FluxKontextMultiReferenceLatentMethod", + display_name="Edit Model Reference Method", + category="advanced/conditioning/flux", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Combo.Input( + "reference_latents_method", + options=["offset", "index", "uxo/uno", "index_timestep_zero"], + advanced=True, + ), + ], + outputs=[ + io.Conditioning.Output(), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput: + if "uxo" in reference_latents_method or "uso" in reference_latents_method: + reference_latents_method = "uxo" + c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) + return io.NodeOutput(c) + + append = execute # TODO: remove + + +def generalized_time_snr_shift(t, mu: float, sigma: float): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +def get_schedule(num_steps: int, image_seq_len: int) -> list[float]: + mu = compute_empirical_mu(image_seq_len, num_steps) + timesteps = torch.linspace(1, 0, num_steps + 1) + timesteps = generalized_time_snr_shift(timesteps, mu, 1.0) + return timesteps + + +class Flux2Scheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Flux2Scheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=4096), + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) + + @classmethod + def execute(cls, steps, width, height) -> io.NodeOutput: + seq_len = (width * height / (16 * 16)) + sigmas = get_schedule(steps, round(seq_len)) + return io.NodeOutput(sigmas) + +class KV_Attn_Input: + def __init__(self): + self.cache = {} + + def __call__(self, q, k, v, extra_options, **kwargs): + reference_image_num_tokens = extra_options.get("reference_image_num_tokens", []) + if len(reference_image_num_tokens) == 0: + return {} + + ref_toks = sum(reference_image_num_tokens) + cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"]) + if cache_key in self.cache: + kk, vv = self.cache[cache_key] + self.set_cache = False + return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)} + + self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone()) + self.set_cache = True + return {"q": q, "k": k, "v": v} + + def cleanup(self): + self.cache = {} + + +class FluxKVCache(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="FluxKVCache", + display_name="Flux KV Cache", + description="Enables KV Cache optimization for reference images on Flux family models.", + category="", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to use KV Cache on."), + ], + outputs=[ + io.Model.Output(tooltip="The patched model with KV Cache enabled."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type) -> io.NodeOutput: + m = model.clone() + input_patch_obj = KV_Attn_Input() + + def model_input_patch(inputs): + if len(input_patch_obj.cache) > 0: + ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", [])) + if ref_image_tokens > 0: + img = inputs["img"] + inputs["img"] = img[:, :-ref_image_tokens] + return inputs + + m.set_model_attn1_patch(input_patch_obj) + m.set_model_post_input_patch(model_input_patch) + if hasattr(model.model.diffusion_model, "params"): + m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero") + else: + m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero") + + return io.NodeOutput(m) + +class FluxExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeFlux, + FluxGuidance, + FluxDisableGuidance, + FluxKontextImageScale, + FluxKontextMultiReferenceLatentMethod, + EmptyFlux2LatentImage, + Flux2Scheduler, + FluxKVCache, + ] + + +async def comfy_entrypoint() -> FluxExtension: + return FluxExtension() diff --git a/ComfyUI/comfy_extras/nodes_frame_interpolation.py b/ComfyUI/comfy_extras/nodes_frame_interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..0aab675ec5c0c56b03b72cfc9d7ffb1980b64a04 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_frame_interpolation.py @@ -0,0 +1,211 @@ +import torch +from tqdm import tqdm +from typing_extensions import override + +import comfy.model_patcher +import comfy.utils +import folder_paths +from comfy import model_management +from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config +from comfy_extras.frame_interpolation_models.film_net import FILMNet +from comfy_api.latest import ComfyExtension, io + +FrameInterpolationModel = io.Custom("INTERP_MODEL") + + +class FrameInterpolationModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolationModelLoader", + display_name="Load Frame Interpolation Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"), + tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."), + ], + outputs=[ + FrameInterpolationModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + model = cls._detect_and_load(sd) + dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 + model.eval().to(dtype) + patcher = comfy.model_patcher.ModelPatcher( + model, + load_device=model_management.get_torch_device(), + offload_device=model_management.unet_offload_device(), + ) + return io.NodeOutput(patcher) + + @classmethod + def _detect_and_load(cls, sd): + # Try FILM + if "extract.extract_sublevels.convs.0.0.conv.weight" in sd: + model = FILMNet() + model.load_state_dict(sd) + return model + + # Try RIFE (needs key remapping for raw checkpoints) + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) + key_map = {} + for k in sd: + for i in range(5): + if k.startswith(f"block{i}."): + key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}" + if key_map: + sd = {key_map.get(k, k): v for k, v in sd.items()} + sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))} + + try: + head_ch, channels = detect_rife_config(sd) + except (KeyError, ValueError): + raise ValueError("Unrecognized frame interpolation model format") + model = IFNet(head_ch=head_ch, channels=channels) + model.load_state_dict(sd) + return model + + +class FrameInterpolate(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolate", + display_name="Frame Interpolate", + category="image/video", + search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], + inputs=[ + FrameInterpolationModel.Input("interp_model"), + io.Image.Input("images"), + io.Int.Input("multiplier", default=2, min=2, max=16), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, interp_model, images, multiplier) -> io.NodeOutput: + offload_device = model_management.intermediate_device() + + num_frames = images.shape[0] + if num_frames < 2 or multiplier < 2: + return io.NodeOutput(images) + + model_management.load_model_gpu(interp_model) + device = interp_model.load_device + dtype = interp_model.model_dtype() + inference_model = interp_model.model + + # Free VRAM for inference activations (model weights + ~20x a single frame's worth) + H, W = images.shape[1], images.shape[2] + activation_mem = H * W * 3 * images.element_size() * 20 + model_management.free_memory(activation_mem, device) + align = getattr(inference_model, "pad_align", 1) + + # Prepare a single padded frame on device for determining output dimensions + def prepare_frame(idx): + frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device) + if align > 1: + from comfy.ldm.common_dit import pad_to_patch_size + frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect") + return frame + + # Count total interpolation passes for progress bar + total_pairs = num_frames - 1 + num_interp = multiplier - 1 + total_steps = total_pairs * num_interp + pbar = comfy.utils.ProgressBar(total_steps) + tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation") + + batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit) + t_values = [t / multiplier for t in range(1, multiplier)] + + out_dtype = model_management.intermediate_dtype() + total_out_frames = total_pairs * multiplier + 1 + result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device) + result[0] = images[0].movedim(-1, 0).to(out_dtype) + out_idx = 1 + + # Pre-compute timestep tensor on device (padded dimensions needed) + sample = prepare_frame(0) + pH, pW = sample.shape[2], sample.shape[3] + ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) + ts_full = ts_full.expand(-1, 1, pH, pW) + del sample + + multi_fn = getattr(inference_model, "forward_multi_timestep", None) + feat_cache = {} + prev_frame = None + + try: + for i in range(total_pairs): + img0_single = prev_frame if prev_frame is not None else prepare_frame(i) + img1_single = prepare_frame(i + 1) + prev_frame = img1_single + + # Cache features: img1 of pair N becomes img0 of pair N+1 + feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single) + feat_cache["img1"] = inference_model.extract_features(img1_single) + feat_cache["next"] = feat_cache["img1"] + + used_multi = False + if multi_fn is not None: + # Models with timestep-independent flow can compute it once for all timesteps + try: + mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) + result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype) + out_idx += num_interp + pbar.update(num_interp) + tqdm_bar.update(num_interp) + used_multi = True + except model_management.OOM_EXCEPTION: + model_management.soft_empty_cache() + multi_fn = None # fall through to single-timestep path + + if not used_multi: + j = 0 + while j < num_interp: + b = min(batch, num_interp - j) + try: + img0 = img0_single.expand(b, -1, -1, -1) + img1 = img1_single.expand(b, -1, -1, -1) + mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache) + result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype) + out_idx += b + pbar.update(b) + tqdm_bar.update(b) + j += b + except model_management.OOM_EXCEPTION: + if batch <= 1: + raise + batch = max(1, batch // 2) + model_management.soft_empty_cache() + + result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype) + out_idx += 1 + finally: + tqdm_bar.close() + + # BCHW -> BHWC + result = result.movedim(1, -1).clamp_(0.0, 1.0) + return io.NodeOutput(result) + + +class FrameInterpolationExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FrameInterpolationModelLoader, + FrameInterpolate, + ] + + +async def comfy_entrypoint() -> FrameInterpolationExtension: + return FrameInterpolationExtension() diff --git a/ComfyUI/comfy_extras/nodes_freelunch.py b/ComfyUI/comfy_extras/nodes_freelunch.py new file mode 100644 index 0000000000000000000000000000000000000000..bbb4676f4b6f08cdccde463ab6641fa367d76ecc --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_freelunch.py @@ -0,0 +1,138 @@ +#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) + +import torch +import logging +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO + +def Fourier_filter(x, threshold, scale): + # FFT + x_freq = torch.fft.fftn(x.float(), dim=(-2, -1)) + x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + + crow, ccol = H // 2, W //2 + mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real + + return x_filtered.to(x.dtype) + + +class FreeU(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="FreeU", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True), + ], + outputs=[ + IO.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: + model_channels = model.model.model_config.unet_config["model_channels"] + scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} + on_cpu_devices = {} + + def output_block_patch(h, hsp, transformer_options): + scale = scale_dict.get(int(h.shape[1]), None) + if scale is not None: + h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0] + if hsp.device not in on_cpu_devices: + try: + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + except: + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) + on_cpu_devices[hsp.device] = True + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + else: + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + + return h, hsp + + m = model.clone() + m.set_model_output_block_patch(output_block_patch) + return IO.NodeOutput(m) + + patch = execute # TODO: remove + + +class FreeU_V2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="FreeU_V2", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True), + ], + outputs=[ + IO.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: + model_channels = model.model.model_config.unet_config["model_channels"] + scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} + on_cpu_devices = {} + + def output_block_patch(h, hsp, transformer_options): + scale = scale_dict.get(int(h.shape[1]), None) + if scale is not None: + hidden_mean = h.mean(1).unsqueeze(1) + B = hidden_mean.shape[0] + hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) + + h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1) + + if hsp.device not in on_cpu_devices: + try: + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + except: + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) + on_cpu_devices[hsp.device] = True + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + else: + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + + return h, hsp + + m = model.clone() + m.set_model_output_block_patch(output_block_patch) + return IO.NodeOutput(m) + + patch = execute # TODO: remove + + +class FreelunchExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + FreeU, + FreeU_V2, + ] + + +async def comfy_entrypoint() -> FreelunchExtension: + return FreelunchExtension() diff --git a/ComfyUI/comfy_extras/nodes_fresca.py b/ComfyUI/comfy_extras/nodes_fresca.py new file mode 100644 index 0000000000000000000000000000000000000000..6658eab7b361bd6b77170b974b1658b40262f2c8 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_fresca.py @@ -0,0 +1,115 @@ +# Code based on https://github.com/WikiChao/FreSca (MIT License) +import torch +import torch.fft as fft +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): + """ + Apply frequency-dependent scaling to an image tensor using Fourier transforms. + + Parameters: + x: Input tensor of shape (B, C, H, W) + scale_low: Scaling factor for low-frequency components (default: 1.0) + scale_high: Scaling factor for high-frequency components (default: 1.5) + freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20) + + Returns: + x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied. + """ + # Preserve input dtype and device + dtype, device = x.dtype, x.device + + # Convert to float32 for FFT computations + x = x.to(torch.float32) + + # 1) Apply FFT and shift low frequencies to center + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + # Initialize mask with high-frequency scaling factor + mask = torch.ones(x_freq.shape, device=device) * scale_high + m = mask + for d in range(len(x_freq.shape) - 2): + dim = d + 2 + cc = x_freq.shape[dim] // 2 + f_c = min(freq_cutoff, cc) + m = m.narrow(dim, cc - f_c, f_c * 2) + + # Apply low-frequency scaling factor to center region + m[:] = scale_low + + # 3) Apply frequency-specific scaling + x_freq = x_freq * mask + + # 4) Convert back to spatial domain + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + # 5) Restore original dtype + x_filtered = x_filtered.to(dtype) + + return x_filtered + + +class FreSca(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FreSca", + search_aliases=["frequency guidance"], + display_name="FreSca", + category="_for_testing", + description="Applies frequency-dependent scaling to the guidance", + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01, + tooltip="Scaling factor for low-frequency components", advanced=True), + io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01, + tooltip="Scaling factor for high-frequency components", advanced=True), + io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1, + tooltip="Number of frequency indices around center to consider as low-frequency", advanced=True), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model, scale_low, scale_high, freq_cutoff): + def custom_cfg_function(args): + conds_out = args["conds_out"] + if len(conds_out) <= 1 or None in args["conds"][:2]: + return conds_out + cond = conds_out[0] + uncond = conds_out[1] + + guidance = cond - uncond + filtered_guidance = Fourier_filter( + guidance, + scale_low=scale_low, + scale_high=scale_high, + freq_cutoff=freq_cutoff, + ) + filtered_cond = filtered_guidance + uncond + + return [filtered_cond, uncond] + conds_out[2:] + + m = model.clone() + m.set_model_sampler_pre_cfg_function(custom_cfg_function) + + return io.NodeOutput(m) + + +class FreScaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FreSca, + ] + + +async def comfy_entrypoint() -> FreScaExtension: + return FreScaExtension() diff --git a/ComfyUI/comfy_extras/nodes_gits.py b/ComfyUI/comfy_extras/nodes_gits.py new file mode 100644 index 0000000000000000000000000000000000000000..e6d7318a17709845c93550f7d6f391835a21b24c --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_gits.py @@ -0,0 +1,382 @@ +# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main +import numpy as np +import torch +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + +def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + +NOISE_LEVELS = { + 0.80: [ + [14.61464119, 7.49001646, 0.02916753], + [14.61464119, 11.54541874, 6.77309084, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 3.07277966, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 2.05039096, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 2.05039096, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.75677586, 2.84484982, 1.78698075, 0.803307, 0.02916753], + ], + 0.85: [ + [14.61464119, 7.49001646, 0.02916753], + [14.61464119, 7.49001646, 1.84880662, 0.02916753], + [14.61464119, 11.54541874, 6.77309084, 1.56271636, 0.02916753], + [14.61464119, 11.54541874, 7.11996698, 3.07277966, 1.24153244, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.60512662, 2.6383388, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + ], + 0.90: [ + [14.61464119, 6.77309084, 0.02916753], + [14.61464119, 7.49001646, 1.56271636, 0.02916753], + [14.61464119, 7.49001646, 3.07277966, 0.95350921, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.07277966, 1.61558151, 0.69515091, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.11996698, 4.86714602, 3.07277966, 1.61558151, 0.69515091, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 2.95596409, 1.61558151, 0.69515091, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753], + [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.19988537, 1.51179266, 0.89115214, 0.43325692, 0.02916753], + ], + 0.95: [ + [14.61464119, 6.77309084, 0.02916753], + [14.61464119, 6.77309084, 1.56271636, 0.02916753], + [14.61464119, 7.49001646, 2.84484982, 0.89115214, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.36326075, 0.803307, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.41535246, 0.803307, 0.38853383, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753], + [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753], + [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753], + [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.78698075, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.75677586, 3.07277966, 2.45070267, 1.91321158, 1.46270394, 1.05362725, 0.72133851, 0.43325692, 0.19894916, 0.02916753], + ], + 1.00: [ + [14.61464119, 1.56271636, 0.02916753], + [14.61464119, 6.77309084, 0.95350921, 0.02916753], + [14.61464119, 6.77309084, 2.36326075, 0.803307, 0.02916753], + [14.61464119, 7.11996698, 3.07277966, 1.56271636, 0.59516323, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.41535246, 0.57119018, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.54755926, 0.25053367, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.12350607, 1.56271636, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.803307, 0.50118381, 0.27464288, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.84880662, 1.36964464, 1.01931262, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.46139455, 2.84484982, 2.19988537, 1.67050016, 1.24153244, 0.92192322, 0.64427125, 0.43325692, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + ], + 1.05: [ + [14.61464119, 0.95350921, 0.02916753], + [14.61464119, 6.77309084, 0.89115214, 0.02916753], + [14.61464119, 6.77309084, 2.05039096, 0.72133851, 0.02916753], + [14.61464119, 6.77309084, 2.84484982, 1.28281462, 0.52423614, 0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.803307, 0.34370604, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.56271636, 0.803307, 0.34370604, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.52423614, 0.22545385, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.74807048, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.59516323, 0.34370604, 0.13792117, 0.02916753], + [14.61464119, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.72759056, 1.24153244, 0.86115354, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.67050016, 1.28281462, 0.95350921, 0.72133851, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.36326075, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.61951244, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.98035145, 1.61558151, 1.32549286, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + ], + 1.10: [ + [14.61464119, 0.89115214, 0.02916753], + [14.61464119, 2.36326075, 0.72133851, 0.02916753], + [14.61464119, 5.85520077, 1.61558151, 0.57119018, 0.02916753], + [14.61464119, 6.77309084, 2.45070267, 1.08895338, 0.45573691, 0.02916753], + [14.61464119, 6.77309084, 2.95596409, 1.56271636, 0.803307, 0.34370604, 0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.89115214, 0.4783645, 0.19894916, 0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.08895338, 0.64427125, 0.34370604, 0.13792117, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.54755926, 0.27464288, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.4783645, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.41535246, 0.95350921, 0.64427125, 0.41087446, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.61558151, 1.12534678, 0.803307, 0.54755926, 0.36617002, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.32507086, 2.45070267, 1.72759056, 1.24153244, 0.89115214, 0.64427125, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.05039096, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.12350607, 1.61558151, 1.24153244, 0.95350921, 0.72133851, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + ], + 1.15: [ + [14.61464119, 0.83188516, 0.02916753], + [14.61464119, 1.84880662, 0.59516323, 0.02916753], + [14.61464119, 5.85520077, 1.56271636, 0.52423614, 0.02916753], + [14.61464119, 5.85520077, 1.91321158, 0.83188516, 0.34370604, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.59516323, 0.25053367, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.51179266, 0.803307, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.56271636, 0.89115214, 0.50118381, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.12534678, 0.72133851, 0.43325692, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.36964464, 0.95350921, 0.69515091, 0.4783645, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.20: [ + [14.61464119, 0.803307, 0.02916753], + [14.61464119, 1.56271636, 0.52423614, 0.02916753], + [14.61464119, 2.36326075, 0.92192322, 0.36617002, 0.02916753], + [14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.25053367, 0.02916753], + [14.61464119, 5.85520077, 2.05039096, 0.95350921, 0.45573691, 0.17026083, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.64427125, 0.29807833, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.803307, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.95350921, 0.59516323, 0.36617002, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.83188516, 0.59516323, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 3.07277966, 1.98035145, 1.36964464, 0.95350921, 0.69515091, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 6.77309084, 3.46139455, 2.36326075, 1.56271636, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 6.77309084, 3.46139455, 2.45070267, 1.61558151, 1.162866, 0.86115354, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.20157266, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.25: [ + [14.61464119, 0.72133851, 0.02916753], + [14.61464119, 1.56271636, 0.50118381, 0.02916753], + [14.61464119, 2.05039096, 0.803307, 0.32104823, 0.02916753], + [14.61464119, 2.36326075, 0.95350921, 0.43325692, 0.17026083, 0.02916753], + [14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.27464288, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.51179266, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.36326075, 1.24153244, 0.72133851, 0.41087446, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.98595673, 0.64427125, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.52423614, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.86115354, 0.64427125, 0.4783645, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.28281462, 0.92192322, 0.69515091, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.46270394, 1.08895338, 0.83188516, 0.66947293, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.30: [ + [14.61464119, 0.72133851, 0.02916753], + [14.61464119, 1.24153244, 0.43325692, 0.02916753], + [14.61464119, 1.56271636, 0.59516323, 0.22545385, 0.02916753], + [14.61464119, 1.84880662, 0.803307, 0.36617002, 0.13792117, 0.02916753], + [14.61464119, 2.36326075, 1.01931262, 0.52423614, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.36964464, 0.74807048, 0.41087446, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.56271636, 0.89115214, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.61558151, 0.95350921, 0.61951244, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.41535246, 0.92192322, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.6383388, 1.56271636, 1.01931262, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.77538133, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.35: [ + [14.61464119, 0.69515091, 0.02916753], + [14.61464119, 0.95350921, 0.34370604, 0.02916753], + [14.61464119, 1.56271636, 0.57119018, 0.19894916, 0.02916753], + [14.61464119, 1.61558151, 0.69515091, 0.29807833, 0.09824532, 0.02916753], + [14.61464119, 1.84880662, 0.83188516, 0.43325692, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.162866, 0.64427125, 0.36617002, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.36964464, 0.803307, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.41535246, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.32104823, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.51179266, 1.01931262, 0.74807048, 0.57119018, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.40: [ + [14.61464119, 0.59516323, 0.02916753], + [14.61464119, 0.95350921, 0.34370604, 0.02916753], + [14.61464119, 1.08895338, 0.43325692, 0.13792117, 0.02916753], + [14.61464119, 1.56271636, 0.64427125, 0.27464288, 0.09824532, 0.02916753], + [14.61464119, 1.61558151, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 2.05039096, 0.95350921, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.24153244, 0.72133851, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.52423614, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.41535246, 0.86115354, 0.59516323, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.4783645, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.72133851, 0.54755926, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.43325692, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.45: [ + [14.61464119, 0.59516323, 0.02916753], + [14.61464119, 0.803307, 0.25053367, 0.02916753], + [14.61464119, 0.95350921, 0.34370604, 0.09824532, 0.02916753], + [14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 1.91321158, 0.95350921, 0.57119018, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.19988537, 1.08895338, 0.64427125, 0.41087446, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.72133851, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.59516323, 0.50118381, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.50: [ + [14.61464119, 0.54755926, 0.02916753], + [14.61464119, 0.803307, 0.25053367, 0.02916753], + [14.61464119, 0.86115354, 0.32104823, 0.09824532, 0.02916753], + [14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 1.61558151, 0.83188516, 0.52423614, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.38853383, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.41087446, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 1.84880662, 0.95350921, 0.61951244, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.32549286, 0.86115354, 0.64427125, 0.50118381, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.36964464, 0.92192322, 0.69515091, 0.54755926, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.41535246, 0.95350921, 0.72133851, 0.57119018, 0.4783645, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], +} + +class GITSScheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="GITSScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05, advanced=True), + io.Int.Input("steps", default=10, min=2, max=1000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) + + @classmethod + def execute(cls, coeff, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return io.NodeOutput(torch.FloatTensor([])) + total_steps = round(steps * denoise) + + if steps <= 20: + sigmas = NOISE_LEVELS[round(coeff, 2)][steps-2][:] + else: + sigmas = NOISE_LEVELS[round(coeff, 2)][-1][:] + sigmas = loglinear_interp(sigmas, steps + 1) + + sigmas = sigmas[-(total_steps + 1):] + sigmas[-1] = 0 + return io.NodeOutput(torch.FloatTensor(sigmas)) + + +class GITSSchedulerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + GITSScheduler, + ] + + +async def comfy_entrypoint() -> GITSSchedulerExtension: + return GITSSchedulerExtension() diff --git a/ComfyUI/comfy_extras/nodes_glsl.py b/ComfyUI/comfy_extras/nodes_glsl.py new file mode 100644 index 0000000000000000000000000000000000000000..2d03c486f7e73c0a11894fe135a0652d20623ddb --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_glsl.py @@ -0,0 +1,958 @@ +import os +import sys +import re +import logging +import ctypes.util +import importlib.util +from typing import TypedDict + +import numpy as np +import torch + +import nodes +from comfy_api.latest import ComfyExtension, io, ui +from typing_extensions import override +from utils.install_util import get_missing_requirements_message + +logger = logging.getLogger(__name__) + + +def _check_opengl_availability(): + """Early check for OpenGL availability. Raises RuntimeError if unlikely to work.""" + logger.debug("_check_opengl_availability: starting") + missing = [] + + # Check Python packages (using find_spec to avoid importing) + logger.debug("_check_opengl_availability: checking for glfw package") + if importlib.util.find_spec("glfw") is None: + missing.append("glfw") + + logger.debug("_check_opengl_availability: checking for OpenGL package") + if importlib.util.find_spec("OpenGL") is None: + missing.append("PyOpenGL") + + if missing: + raise RuntimeError( + f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n" + ) + + # On Linux without display, check if headless backends are available + logger.debug(f"_check_opengl_availability: platform={sys.platform}") + if sys.platform.startswith("linux"): + has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY") + logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}") + if not has_display: + # Check for EGL or OSMesa libraries + logger.debug("_check_opengl_availability: checking for EGL library") + has_egl = ctypes.util.find_library("EGL") + logger.debug("_check_opengl_availability: checking for OSMesa library") + has_osmesa = ctypes.util.find_library("OSMesa") + + # Error disabled for CI as it fails this check + # if not has_egl and not has_osmesa: + # raise RuntimeError( + # "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n" + # "See error below for installation instructions." + # ) + logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}") + + logger.debug("_check_opengl_availability: completed") + + +# Run early check at import time +logger.debug("nodes_glsl: running _check_opengl_availability at import time") +_check_opengl_availability() + +# OpenGL modules - initialized lazily when context is created +gl = None +glfw = None +EGL = None + + +def _import_opengl(): + """Import OpenGL module. Called after context is created.""" + global gl + if gl is None: + logger.debug("_import_opengl: importing OpenGL.GL") + import OpenGL.GL as _gl + gl = _gl + logger.debug("_import_opengl: import completed") + return gl + + +class SizeModeInput(TypedDict): + size_mode: str + width: int + height: int + + +MAX_IMAGES = 5 # u_image0-4 +MAX_UNIFORMS = 20 # u_float0-19, u_int0-19 +MAX_BOOLS = 10 # u_bool0-9 +MAX_CURVES = 4 # u_curve0-3 (1D LUT textures) +MAX_OUTPUTS = 4 # fragColor0-3 (MRT) + +# Vertex shader using gl_VertexID trick - no VBO needed. +# Draws a single triangle that covers the entire screen: +# +# (-1,3) +# /| +# / | <- visible area is the unit square from (-1,-1) to (1,1) +# / | parts outside get clipped away +# (-1,-1)---(3,-1) +# +# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1) +VERTEX_SHADER = """#version 330 core +out vec2 v_texCoord; +void main() { + vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3)); + v_texCoord = verts[gl_VertexID] * 0.5 + 0.5; + gl_Position = vec4(verts[gl_VertexID], 0, 1); +} +""" + +DEFAULT_FRAGMENT_SHADER = """#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform vec2 u_resolution; + +in vec2 v_texCoord; +layout(location = 0) out vec4 fragColor0; + +void main() { + fragColor0 = texture(u_image0, v_texCoord); +} +""" + + +def _convert_es_to_desktop(source: str) -> str: + """Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core.""" + # Remove any existing #version directive + source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE) + # Remove precision qualifiers (not needed in desktop GLSL) + source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source) + # Prepend desktop GLSL version + return "#version 330 core\n" + source + + +def _detect_output_count(source: str) -> int: + """Detect how many fragColor outputs are used in the shader. + + Returns the count of outputs needed (1 to MAX_OUTPUTS). + """ + matches = re.findall(r"fragColor(\d+)", source) + if not matches: + return 1 # Default to 1 output if none found + max_index = max(int(m) for m in matches) + return min(max_index + 1, MAX_OUTPUTS) + + +def _detect_pass_count(source: str) -> int: + """Detect multi-pass rendering from #pragma passes N directive. + + Returns the number of passes (1 if not specified). + """ + match = re.search(r'#pragma\s+passes\s+(\d+)', source) + if match: + return max(1, int(match.group(1))) + return 1 + + +def _init_glfw(): + """Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure.""" + logger.debug("_init_glfw: starting") + # On macOS, glfw.init() must be called from main thread or it hangs forever + if sys.platform == "darwin": + logger.debug("_init_glfw: skipping on macOS") + raise RuntimeError("GLFW backend not supported on macOS") + + logger.debug("_init_glfw: importing glfw module") + import glfw as _glfw + + logger.debug("_init_glfw: calling glfw.init()") + if not _glfw.init(): + raise RuntimeError("glfw.init() failed") + + try: + logger.debug("_init_glfw: setting window hints") + _glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE) + _glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3) + _glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3) + _glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE) + + logger.debug("_init_glfw: calling create_window()") + window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None) + if not window: + raise RuntimeError("glfw.create_window() failed") + + logger.debug("_init_glfw: calling make_context_current()") + _glfw.make_context_current(window) + logger.debug("_init_glfw: completed successfully") + return window, _glfw + except Exception: + logger.debug("_init_glfw: failed, terminating glfw") + _glfw.terminate() + raise + + +def _init_egl(): + """Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure.""" + logger.debug("_init_egl: starting") + from OpenGL import EGL as _EGL + from OpenGL.EGL import ( + eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext, + eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI, + eglTerminate, eglDestroyContext, eglDestroySurface, + EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE, + EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, + EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE, + EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API, + ) + logger.debug("_init_egl: imports completed") + + display = None + context = None + surface = None + + try: + logger.debug("_init_egl: calling eglGetDisplay()") + display = eglGetDisplay(EGL_DEFAULT_DISPLAY) + if display == _EGL.EGL_NO_DISPLAY: + raise RuntimeError("eglGetDisplay() failed") + + logger.debug("_init_egl: calling eglInitialize()") + major, minor = _EGL.EGLint(), _EGL.EGLint() + if not eglInitialize(display, major, minor): + display = None # Not initialized, don't terminate + raise RuntimeError("eglInitialize() failed") + logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}") + + config_attribs = [ + EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, + EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, + EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8, + EGL_DEPTH_SIZE, 0, EGL_NONE + ] + configs = (_EGL.EGLConfig * 1)() + num_configs = _EGL.EGLint() + if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0: + raise RuntimeError("eglChooseConfig() failed") + config = configs[0] + logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}") + + if not eglBindAPI(EGL_OPENGL_API): + raise RuntimeError("eglBindAPI() failed") + + logger.debug("_init_egl: calling eglCreateContext()") + context_attribs = [ + _EGL.EGL_CONTEXT_MAJOR_VERSION, 3, + _EGL.EGL_CONTEXT_MINOR_VERSION, 3, + _EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT, + EGL_NONE + ] + context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs) + if context == EGL_NO_CONTEXT: + raise RuntimeError("eglCreateContext() failed") + + logger.debug("_init_egl: calling eglCreatePbufferSurface()") + pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE] + surface = eglCreatePbufferSurface(display, config, pbuffer_attribs) + if surface == _EGL.EGL_NO_SURFACE: + raise RuntimeError("eglCreatePbufferSurface() failed") + + logger.debug("_init_egl: calling eglMakeCurrent()") + if not eglMakeCurrent(display, surface, surface, context): + raise RuntimeError("eglMakeCurrent() failed") + + logger.debug("_init_egl: completed successfully") + return display, context, surface, _EGL + + except Exception: + logger.debug("_init_egl: failed, cleaning up") + # Clean up any resources on failure + if surface is not None: + eglDestroySurface(display, surface) + if context is not None: + eglDestroyContext(display, context) + if display is not None: + eglTerminate(display) + raise + + +def _init_osmesa(): + """Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure.""" + import ctypes + + logger.debug("_init_osmesa: starting") + os.environ["PYOPENGL_PLATFORM"] = "osmesa" + + logger.debug("_init_osmesa: importing OpenGL.osmesa") + from OpenGL import GL as _gl + from OpenGL.osmesa import ( + OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext, + OSMESA_RGBA, + ) + logger.debug("_init_osmesa: imports completed") + + ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None) + if not ctx: + raise RuntimeError("OSMesaCreateContextExt() failed") + + width, height = 64, 64 + buffer = (ctypes.c_ubyte * (width * height * 4))() + + logger.debug("_init_osmesa: calling OSMesaMakeCurrent()") + if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height): + OSMesaDestroyContext(ctx) + raise RuntimeError("OSMesaMakeCurrent() failed") + + logger.debug("_init_osmesa: completed successfully") + return ctx, buffer + + +class GLContext: + """Manages OpenGL context and resources for shader execution. + + Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software). + """ + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if GLContext._initialized: + logger.debug("GLContext.__init__: already initialized, skipping") + return + + logger.debug("GLContext.__init__: starting initialization") + + global glfw, EGL + + import time + start = time.perf_counter() + + self._backend = None + self._window = None + self._egl_display = None + self._egl_context = None + self._egl_surface = None + self._osmesa_ctx = None + self._osmesa_buffer = None + self._vao = None + + # Try backends in order: GLFW → EGL → OSMesa + errors = [] + + logger.debug("GLContext.__init__: trying GLFW backend") + try: + self._window, glfw = _init_glfw() + self._backend = "glfw" + logger.debug("GLContext.__init__: GLFW backend succeeded") + except Exception as e: + logger.debug(f"GLContext.__init__: GLFW backend failed: {e}") + errors.append(("GLFW", e)) + + if self._backend is None: + logger.debug("GLContext.__init__: trying EGL backend") + try: + self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl() + self._backend = "egl" + logger.debug("GLContext.__init__: EGL backend succeeded") + except Exception as e: + logger.debug(f"GLContext.__init__: EGL backend failed: {e}") + errors.append(("EGL", e)) + + if self._backend is None: + logger.debug("GLContext.__init__: trying OSMesa backend") + try: + self._osmesa_ctx, self._osmesa_buffer = _init_osmesa() + self._backend = "osmesa" + logger.debug("GLContext.__init__: OSMesa backend succeeded") + except Exception as e: + logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}") + errors.append(("OSMesa", e)) + + if self._backend is None: + if sys.platform == "win32": + platform_help = ( + "Windows: Ensure GPU drivers are installed and display is available.\n" + " CPU-only/headless mode is not supported on Windows." + ) + elif sys.platform == "darwin": + platform_help = ( + "macOS: GLFW is not supported.\n" + " Install OSMesa via Homebrew: brew install mesa\n" + " Then: pip install PyOpenGL PyOpenGL-accelerate" + ) + else: + platform_help = ( + "Linux: Install one of these backends:\n" + " Desktop: sudo apt install libgl1-mesa-glx libglfw3\n" + " Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n" + " Headless (CPU): sudo apt install libosmesa6" + ) + + error_details = "\n".join(f" {name}: {err}" for name, err in errors) + raise RuntimeError( + f"Failed to create OpenGL context.\n\n" + f"Backend errors:\n{error_details}\n\n" + f"{platform_help}" + ) + + # Now import OpenGL.GL (after context is current) + logger.debug("GLContext.__init__: importing OpenGL.GL") + _import_opengl() + + # Create VAO (required for core profile, but OSMesa may use compat profile) + logger.debug("GLContext.__init__: creating VAO") + try: + vao = gl.glGenVertexArrays(1) + gl.glBindVertexArray(vao) + self._vao = vao # Only store after successful bind + logger.debug("GLContext.__init__: VAO created successfully") + except Exception as e: + logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}") + # OSMesa with older Mesa may not support VAOs + # Clean up if we created but couldn't bind + if vao: + try: + gl.glDeleteVertexArrays(1, [vao]) + except Exception: + pass + + elapsed = (time.perf_counter() - start) * 1000 + + # Log device info + renderer = gl.glGetString(gl.GL_RENDERER) + vendor = gl.glGetString(gl.GL_VENDOR) + version = gl.glGetString(gl.GL_VERSION) + renderer = renderer.decode() if renderer else "Unknown" + vendor = vendor.decode() if vendor else "Unknown" + version = version.decode() if version else "Unknown" + + GLContext._initialized = True + logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}") + + def make_current(self): + if self._backend == "glfw": + glfw.make_context_current(self._window) + elif self._backend == "egl": + from OpenGL.EGL import eglMakeCurrent + eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context) + elif self._backend == "osmesa": + from OpenGL.osmesa import OSMesaMakeCurrent + OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64) + + if self._vao is not None: + gl.glBindVertexArray(self._vao) + + +def _compile_shader(source: str, shader_type: int) -> int: + """Compile a shader and return its ID.""" + shader = gl.glCreateShader(shader_type) + gl.glShaderSource(shader, source) + gl.glCompileShader(shader) + + if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE: + error = gl.glGetShaderInfoLog(shader).decode() + gl.glDeleteShader(shader) + raise RuntimeError(f"Shader compilation failed:\n{error}") + + return shader + + +def _create_program(vertex_source: str, fragment_source: str) -> int: + """Create and link a shader program.""" + vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER) + try: + fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER) + except RuntimeError: + gl.glDeleteShader(vertex_shader) + raise + + program = gl.glCreateProgram() + gl.glAttachShader(program, vertex_shader) + gl.glAttachShader(program, fragment_shader) + gl.glLinkProgram(program) + + gl.glDeleteShader(vertex_shader) + gl.glDeleteShader(fragment_shader) + + if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE: + error = gl.glGetProgramInfoLog(program).decode() + gl.glDeleteProgram(program) + raise RuntimeError(f"Program linking failed:\n{error}") + + return program + + +def _render_shader_batch( + fragment_code: str, + width: int, + height: int, + image_batches: list[list[np.ndarray]], + floats: list[float], + ints: list[int], + bools: list[bool] | None = None, + curves: list[np.ndarray] | None = None, +) -> list[list[np.ndarray]]: + """ + Render a fragment shader for multiple batches efficiently. + + Compiles shader once, reuses framebuffer/textures across batches. + Supports multi-pass rendering via #pragma passes N directive. + + Args: + fragment_code: User's fragment shader code + width: Output width + height: Output height + image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1] + floats: List of float uniforms + ints: List of int uniforms + bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms) + curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N + + Returns: + List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1] + """ + import time + start_time = time.perf_counter() + + if not image_batches: + return [] + + ctx = GLContext() + ctx.make_current() + + # Convert from GLSL ES to desktop GLSL 330 + fragment_source = _convert_es_to_desktop(fragment_code) + + # Detect how many outputs the shader actually uses + num_outputs = _detect_output_count(fragment_code) + + # Detect multi-pass rendering + num_passes = _detect_pass_count(fragment_code) + + if bools is None: + bools = [] + if curves is None: + curves = [] + + # Track resources for cleanup + program = None + fbo = None + output_textures = [] + input_textures = [] + curve_textures = [] + ping_pong_textures = [] + ping_pong_fbos = [] + + num_inputs = len(image_batches[0]) + + try: + # Compile shaders (once for all batches) + try: + program = _create_program(VERTEX_SHADER, fragment_source) + except RuntimeError: + logger.error(f"Fragment shader:\n{fragment_source}") + raise + + gl.glUseProgram(program) + + # Create framebuffer with only the needed color attachments + fbo = gl.glGenFramebuffers(1) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) + + draw_buffers = [] + for i in range(num_outputs): + tex = gl.glGenTextures(1) + output_textures.append(tex) + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) + gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0) + draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i) + + gl.glDrawBuffers(num_outputs, draw_buffers) + + if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE: + raise RuntimeError("Framebuffer is not complete") + + # Create ping-pong resources for multi-pass rendering + if num_passes > 1: + for _ in range(2): + pp_tex = gl.glGenTextures(1) + ping_pong_textures.append(pp_tex) + gl.glBindTexture(gl.GL_TEXTURE_2D, pp_tex) + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE) + + pp_fbo = gl.glGenFramebuffers(1) + ping_pong_fbos.append(pp_fbo) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, pp_fbo) + gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, pp_tex, 0) + gl.glDrawBuffers(1, [gl.GL_COLOR_ATTACHMENT0]) + + if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE: + raise RuntimeError("Ping-pong framebuffer is not complete") + + # Create input textures (reused for all batches) + for i in range(num_inputs): + tex = gl.glGenTextures(1) + input_textures.append(tex) + gl.glActiveTexture(gl.GL_TEXTURE0 + i) + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE) + + loc = gl.glGetUniformLocation(program, f"u_image{i}") + if loc >= 0: + gl.glUniform1i(loc, i) + + # Set static uniforms (once for all batches) + loc = gl.glGetUniformLocation(program, "u_resolution") + if loc >= 0: + gl.glUniform2f(loc, float(width), float(height)) + + for i, v in enumerate(floats): + loc = gl.glGetUniformLocation(program, f"u_float{i}") + if loc >= 0: + gl.glUniform1f(loc, v) + + for i, v in enumerate(ints): + loc = gl.glGetUniformLocation(program, f"u_int{i}") + if loc >= 0: + gl.glUniform1i(loc, v) + + for i, v in enumerate(bools): + loc = gl.glGetUniformLocation(program, f"u_bool{i}") + if loc >= 0: + gl.glUniform1i(loc, 1 if v else 0) + + # Create 1D LUT textures for curves (bound after image texture units) + for i, lut in enumerate(curves): + tex = gl.glGenTextures(1) + curve_textures.append(tex) + unit = MAX_IMAGES + i + gl.glActiveTexture(gl.GL_TEXTURE0 + unit) + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE) + + loc = gl.glGetUniformLocation(program, f"u_curve{i}") + if loc >= 0: + gl.glUniform1i(loc, unit) + + # Get u_pass uniform location for multi-pass + pass_loc = gl.glGetUniformLocation(program, "u_pass") + + gl.glViewport(0, 0, width, height) + gl.glDisable(gl.GL_BLEND) # Ensure no alpha blending - write output directly + + # Process each batch + all_batch_outputs = [] + for images in image_batches: + # Update input textures with this batch's images + for i, img in enumerate(images): + gl.glActiveTexture(gl.GL_TEXTURE0 + i) + gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i]) + + # Flip vertically for GL coordinates, ensure RGBA + h, w, c = img.shape + if c == 3: + img_upload = np.empty((h, w, 4), dtype=np.float32) + img_upload[:, :, :3] = img[::-1, :, :] + img_upload[:, :, 3] = 1.0 + else: + img_upload = np.ascontiguousarray(img[::-1, :, :]) + + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload) + + if num_passes == 1: + # Single pass - render directly to output FBO + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) + if pass_loc >= 0: + gl.glUniform1i(pass_loc, 0) + gl.glClearColor(0, 0, 0, 0) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3) + else: + # Multi-pass rendering with ping-pong + for p in range(num_passes): + is_last_pass = (p == num_passes - 1) + + # Set pass uniform + if pass_loc >= 0: + gl.glUniform1i(pass_loc, p) + + if is_last_pass: + # Last pass renders to the main output FBO + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) + else: + # Intermediate passes render to ping-pong FBO + target_fbo = ping_pong_fbos[p % 2] + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, target_fbo) + + # Set input texture for this pass + gl.glActiveTexture(gl.GL_TEXTURE0) + if p == 0: + # First pass reads from original input + gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[0]) + else: + # Subsequent passes read from previous pass output + source_tex = ping_pong_textures[(p - 1) % 2] + gl.glBindTexture(gl.GL_TEXTURE_2D, source_tex) + + gl.glClearColor(0, 0, 0, 0) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3) + + # Read back outputs for this batch + # (glGetTexImage is synchronous, implicitly waits for rendering) + batch_outputs = [] + for tex in output_textures: + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT) + img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4) + batch_outputs.append(img[::-1, :, :].copy()) + + # Pad with black images for unused outputs + black_img = np.zeros((height, width, 4), dtype=np.float32) + for _ in range(num_outputs, MAX_OUTPUTS): + batch_outputs.append(black_img) + + all_batch_outputs.append(batch_outputs) + + elapsed = (time.perf_counter() - start_time) * 1000 + num_batches = len(image_batches) + pass_info = f", {num_passes} passes" if num_passes > 1 else "" + logger.info(f"GLSL shader executed in {elapsed:.1f}ms ({num_batches} batch{'es' if num_batches != 1 else ''}, {width}x{height}{pass_info})") + + return all_batch_outputs + + finally: + # Unbind before deleting + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) + gl.glUseProgram(0) + + for tex in input_textures: + gl.glDeleteTextures(int(tex)) + for tex in curve_textures: + gl.glDeleteTextures(int(tex)) + for tex in output_textures: + gl.glDeleteTextures(int(tex)) + for tex in ping_pong_textures: + gl.glDeleteTextures(int(tex)) + if fbo is not None: + gl.glDeleteFramebuffers(1, [fbo]) + for pp_fbo in ping_pong_fbos: + gl.glDeleteFramebuffers(1, [pp_fbo]) + if program is not None: + gl.glDeleteProgram(program) + +class GLSLShader(io.ComfyNode): + + @classmethod + def define_schema(cls) -> io.Schema: + image_template = io.Autogrow.TemplatePrefix( + io.Image.Input("image"), + prefix="image", + min=1, + max=MAX_IMAGES, + ) + + float_template = io.Autogrow.TemplatePrefix( + io.Float.Input("float", default=0.0), + prefix="u_float", + min=0, + max=MAX_UNIFORMS, + ) + + int_template = io.Autogrow.TemplatePrefix( + io.Int.Input("int", default=0), + prefix="u_int", + min=0, + max=MAX_UNIFORMS, + ) + + bool_template = io.Autogrow.TemplatePrefix( + io.Boolean.Input("bool", default=False), + prefix="u_bool", + min=0, + max=MAX_BOOLS, + ) + + curve_template = io.Autogrow.TemplatePrefix( + io.Curve.Input("curve"), + prefix="u_curve", + min=0, + max=MAX_CURVES, + ) + + return io.Schema( + node_id="GLSLShader", + display_name="GLSL Shader", + category="image/shader", + description=( + "Apply GLSL ES fragment shaders to images. " + "u_resolution (vec2) is always available." + ), + is_experimental=True, + has_intermediate_output=True, + inputs=[ + io.String.Input( + "fragment_shader", + default=DEFAULT_FRAGMENT_SHADER, + multiline=True, + tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)", + ), + io.DynamicCombo.Input( + "size_mode", + options=[ + io.DynamicCombo.Option("from_input", []), + io.DynamicCombo.Option( + "custom", + [ + io.Int.Input( + "width", + default=512, + min=1, + max=nodes.MAX_RESOLUTION, + ), + io.Int.Input( + "height", + default=512, + min=1, + max=nodes.MAX_RESOLUTION, + ), + ], + ), + ], + tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size", + ), + io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"), + io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"), + io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"), + io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"), + io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"), + ], + outputs=[ + io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"), + io.Image.Output(display_name="IMAGE1", tooltip="Available via layout(location = 1) out vec4 fragColor1 in the shader code"), + io.Image.Output(display_name="IMAGE2", tooltip="Available via layout(location = 2) out vec4 fragColor2 in the shader code"), + io.Image.Output(display_name="IMAGE3", tooltip="Available via layout(location = 3) out vec4 fragColor3 in the shader code"), + ], + ) + + @classmethod + def execute( + cls, + fragment_shader: str, + size_mode: SizeModeInput, + images: io.Autogrow.Type, + floats: io.Autogrow.Type = None, + ints: io.Autogrow.Type = None, + bools: io.Autogrow.Type = None, + curves: io.Autogrow.Type = None, + **kwargs, + ) -> io.NodeOutput: + + image_list = [v for v in images.values() if v is not None] + float_list = ( + [v if v is not None else 0.0 for v in floats.values()] if floats else [] + ) + int_list = [v if v is not None else 0 for v in ints.values()] if ints else [] + bool_list = [v if v is not None else False for v in bools.values()] if bools else [] + + curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else [] + + if not image_list: + raise ValueError("At least one input image is required") + + # Determine output dimensions + if size_mode["size_mode"] == "custom": + out_width = size_mode["width"] + out_height = size_mode["height"] + else: + out_height, out_width = image_list[0].shape[1:3] + + batch_size = image_list[0].shape[0] + + # Prepare batches + image_batches = [] + for batch_idx in range(batch_size): + batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list] + image_batches.append(batch_images) + + all_batch_outputs = _render_shader_batch( + fragment_shader, + out_width, + out_height, + image_batches, + float_list, + int_list, + bool_list, + curve_luts, + ) + + # Collect outputs into tensors + all_outputs = [[] for _ in range(MAX_OUTPUTS)] + for batch_outputs in all_batch_outputs: + for i, out_img in enumerate(batch_outputs): + all_outputs[i].append(torch.from_numpy(out_img)) + + output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)] + return io.NodeOutput( + *output_tensors, + ui=cls._build_ui_output(image_list, output_tensors[0]), + ) + + @classmethod + def _build_ui_output( + cls, image_list: list[torch.Tensor], output_batch: torch.Tensor + ) -> dict[str, list]: + """Build UI output with input and output images for client-side shader execution.""" + input_images_ui = [] + for img in image_list: + input_images_ui.extend(ui.ImageSaveHelper.save_images( + img, + filename_prefix="GLSLShader_input", + folder_type=io.FolderType.temp, + cls=None, + compress_level=1, + )) + + output_images_ui = ui.ImageSaveHelper.save_images( + output_batch, + filename_prefix="GLSLShader_output", + folder_type=io.FolderType.temp, + cls=None, + compress_level=1, + ) + + return {"input_images": input_images_ui, "images": output_images_ui} + + +class GLSLExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [GLSLShader] + + +async def comfy_entrypoint() -> GLSLExtension: + return GLSLExtension() diff --git a/ComfyUI/comfy_extras/nodes_hidream.py b/ComfyUI/comfy_extras/nodes_hidream.py new file mode 100644 index 0000000000000000000000000000000000000000..1599146e583df54350702146ce67623d8a107b6b --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_hidream.py @@ -0,0 +1,74 @@ +from typing_extensions import override + +import folder_paths +import comfy.sd +import comfy.model_management +from comfy_api.latest import ComfyExtension, io + + +class QuadrupleCLIPLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="QuadrupleCLIPLoader", + category="advanced/loaders", + description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct", + inputs=[ + io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")), + ], + outputs=[ + io.Clip.Output(), + ] + ) + + @classmethod + def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4): + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) + clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) + clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) + clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) + return io.NodeOutput(clip) + +class CLIPTextEncodeHiDream(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeHiDream", + search_aliases=["hidream prompt"], + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("clip_g", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.String.Input("llama", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + + @classmethod + def execute(cls, clip, clip_l, clip_g, t5xxl, llama): + tokens = clip.tokenize(clip_g) + tokens["l"] = clip.tokenize(clip_l)["l"] + tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] + tokens["llama"] = clip.tokenize(llama)["llama"] + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + +class HiDreamExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + QuadrupleCLIPLoader, + CLIPTextEncodeHiDream, + ] + + +async def comfy_entrypoint() -> HiDreamExtension: + return HiDreamExtension() diff --git a/ComfyUI/comfy_extras/nodes_hooks.py b/ComfyUI/comfy_extras/nodes_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2c85428108f8490b036c4c7d7f107f15df919f --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_hooks.py @@ -0,0 +1,750 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Union +import logging +import torch +from collections.abc import Iterable + +if TYPE_CHECKING: + from comfy.sd import CLIP + +import comfy.hooks +import comfy.sd +import comfy.utils +import folder_paths + +########################################### +# Mask, Combine, and Hook Conditioning +#------------------------------------------ +class PairConditioningSetProperties: + NodeId = 'PairConditioningSetProperties' + NodeName = 'Cond Pair Set Props' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive_NEW": ("CONDITIONING", ), + "negative_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "set_properties" + + def set_properties(self, positive_NEW, negative_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_positive, final_negative) + +class PairConditioningSetPropertiesAndCombine: + NodeId = 'PairConditioningSetPropertiesAndCombine' + NodeName = 'Cond Pair Set Props Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "positive_NEW": ("CONDITIONING", ), + "negative_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "set_properties" + + def set_properties(self, positive, negative, positive_NEW, negative_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_positive, final_negative) + +class ConditioningSetProperties: + NodeId = 'ConditioningSetProperties' + NodeName = 'Cond Set Props' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/cond single" + FUNCTION = "set_properties" + + def set_properties(self, cond_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + (final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_cond,) + +class ConditioningSetPropertiesAndCombine: + NodeId = 'ConditioningSetPropertiesAndCombine' + NodeName = 'Cond Set Props Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond": ("CONDITIONING", ), + "cond_NEW": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), + }, + "optional": { + "mask": ("MASK", ), + "hooks": ("HOOKS",), + "timesteps": ("TIMESTEPS_RANGE",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/cond single" + FUNCTION = "set_properties" + + def set_properties(self, cond, cond_NEW, + strength: float, set_cond_area: str, + mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): + (final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW], + strength=strength, set_cond_area=set_cond_area, + mask=mask, hooks=hooks, timesteps_range=timesteps) + return (final_cond,) + +class PairConditioningCombine: + NodeId = 'PairConditioningCombine' + NodeName = 'Cond Pair Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive_A": ("CONDITIONING",), + "negative_A": ("CONDITIONING",), + "positive_B": ("CONDITIONING",), + "negative_B": ("CONDITIONING",), + }, + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "combine" + + def combine(self, positive_A, negative_A, positive_B, negative_B): + final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],) + return (final_positive, final_negative,) + +class PairConditioningSetDefaultAndCombine: + NodeId = 'PairConditioningSetDefaultCombine' + NodeName = 'Cond Pair Set Default Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "positive_DEFAULT": ("CONDITIONING",), + "negative_DEFAULT": ("CONDITIONING",), + }, + "optional": { + "hooks": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + CATEGORY = "advanced/hooks/cond pair" + FUNCTION = "set_default_and_combine" + + def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT, + hooks: comfy.hooks.HookGroup=None): + final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], + hooks=hooks) + return (final_positive, final_negative) + +class ConditioningSetDefaultAndCombine: + NodeId = 'ConditioningSetDefaultCombine' + NodeName = 'Cond Set Default Combine' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "cond": ("CONDITIONING",), + "cond_DEFAULT": ("CONDITIONING",), + }, + "optional": { + "hooks": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/cond single" + FUNCTION = "set_default_and_combine" + + def set_default_and_combine(self, cond, cond_DEFAULT, + hooks: comfy.hooks.HookGroup=None): + (final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT], + hooks=hooks) + return (final_conditioning,) + +class SetClipHooks: + NodeId = 'SetClipHooks' + NodeName = 'Set CLIP Hooks' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "clip": ("CLIP",), + "apply_to_conds": ("BOOLEAN", {"default": True, "advanced": True}), + "schedule_clip": ("BOOLEAN", {"default": False, "advanced": True}) + }, + "optional": { + "hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CLIP",) + CATEGORY = "advanced/hooks/clip" + FUNCTION = "apply_hooks" + + def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): + if hooks is not None: + clip = clip.clone(disable_dynamic=True) + if apply_to_conds: + clip.apply_hooks_to_conds = hooks + clip.patcher.forced_hooks = hooks.clone() + clip.use_clip_schedule = schedule_clip + if not clip.use_clip_schedule: + clip.patcher.forced_hooks.set_keyframes_on_hooks(None) + clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) + return (clip,) + +class ConditioningTimestepsRange: + SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"] + NodeId = 'ConditioningTimestepsRange' + NodeName = 'Timesteps Range' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) + }, + } + + EXPERIMENTAL = True + RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE") + RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE") + CATEGORY = "advanced/hooks" + FUNCTION = "create_range" + + def create_range(self, start_percent: float, end_percent: float): + return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0)) +#------------------------------------------ +########################################### + + +########################################### +# Create Hooks +#------------------------------------------ +class CreateHookLora: + NodeId = 'CreateHookLora' + NodeName = 'Create Hook LoRA' + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "prev_hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/create" + FUNCTION = "create_hook" + + def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None): + if prev_hooks is None: + prev_hooks = comfy.hooks.HookGroup() + prev_hooks.clone() + + if strength_model == 0 and strength_clip == 0: + return (prev_hooks,) + + lora_path = folder_paths.get_full_path("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + temp = self.loaded_lora + self.loaded_lora = None + del temp + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip) + return (prev_hooks.clone_and_combine(hooks),) + +class CreateHookLoraModelOnly(CreateHookLora): + NodeId = 'CreateHookLoraModelOnly' + NodeName = 'Create Hook LoRA (MO)' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "prev_hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/create" + FUNCTION = "create_hook_model_only" + + def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None): + return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks) + +class CreateHookModelAsLora: + NodeId = 'CreateHookModelAsLora' + NodeName = 'Create Hook Model as LoRA' + + def __init__(self): + # when not None, will be in following format: + # (ckpt_path: str, weights_model: dict, weights_clip: dict) + self.loaded_weights = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "prev_hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/create" + FUNCTION = "create_hook" + + def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float, + prev_hooks: comfy.hooks.HookGroup=None): + if prev_hooks is None: + prev_hooks = comfy.hooks.HookGroup() + prev_hooks.clone() + + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + weights_model = None + weights_clip = None + if self.loaded_weights is not None: + if self.loaded_weights[0] == ckpt_path: + weights_model = self.loaded_weights[1] + weights_clip = self.loaded_weights[2] + else: + temp = self.loaded_weights + self.loaded_weights = None + del temp + + if weights_model is None: + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + weights_model = comfy.hooks.get_patch_weights_from_model(out[0]) + weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1]) + self.loaded_weights = (ckpt_path, weights_model, weights_clip) + + hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip, + strength_model=strength_model, strength_clip=strength_clip) + return (prev_hooks.clone_and_combine(hooks),) + +class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora): + NodeId = 'CreateHookModelAsLoraModelOnly' + NodeName = 'Create Hook Model as LoRA (MO)' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "prev_hooks": ("HOOKS",) + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/create" + FUNCTION = "create_hook_model_only" + + def create_hook_model_only(self, ckpt_name: str, strength_model: float, + prev_hooks: comfy.hooks.HookGroup=None): + return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks) +#------------------------------------------ +########################################### + + +########################################### +# Schedule Hooks +#------------------------------------------ +class SetHookKeyframes: + NodeId = 'SetHookKeyframes' + NodeName = 'Set Hook Keyframes' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "hooks": ("HOOKS",), + }, + "optional": { + "hook_kf": ("HOOK_KEYFRAMES",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/scheduling" + FUNCTION = "set_hook_keyframes" + + def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None): + if hook_kf is not None: + hooks = hooks.clone() + hooks.set_keyframes_on_hooks(hook_kf=hook_kf) + return (hooks,) + +class CreateHookKeyframe: + SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"] + NodeId = 'CreateHookKeyframe' + NodeName = 'Create Hook Keyframe' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }, + "optional": { + "prev_hook_kf": ("HOOK_KEYFRAMES",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOK_KEYFRAMES",) + RETURN_NAMES = ("HOOK_KF",) + CATEGORY = "advanced/hooks/scheduling" + FUNCTION = "create_hook_keyframe" + + def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None): + if prev_hook_kf is None: + prev_hook_kf = comfy.hooks.HookKeyframeGroup() + prev_hook_kf = prev_hook_kf.clone() + keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent) + prev_hook_kf.add(keyframe) + return (prev_hook_kf,) + +class CreateHookKeyframesInterpolated: + SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"] + NodeId = 'CreateHookKeyframesInterpolated' + NodeName = 'Create Hook Keyframes Interp.' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), + "strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), + "interpolation": (comfy.hooks.InterpolationMethod._LIST, ), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}), + "print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}), + }, + "optional": { + "prev_hook_kf": ("HOOK_KEYFRAMES",), + }, + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOK_KEYFRAMES",) + RETURN_NAMES = ("HOOK_KF",) + CATEGORY = "advanced/hooks/scheduling" + FUNCTION = "create_hook_keyframes" + + def create_hook_keyframes(self, strength_start: float, strength_end: float, interpolation: str, + start_percent: float, end_percent: float, keyframes_count: int, + print_keyframes=False, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None): + if prev_hook_kf is None: + prev_hook_kf = comfy.hooks.HookKeyframeGroup() + prev_hook_kf = prev_hook_kf.clone() + percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=keyframes_count, + method=comfy.hooks.InterpolationMethod.LINEAR) + strengths = comfy.hooks.InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=keyframes_count, method=interpolation) + + is_first = True + for percent, strength in zip(percents, strengths): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = False + prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps)) + if print_keyframes: + logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}") + return (prev_hook_kf,) + +class CreateHookKeyframesFromFloats: + SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"] + NodeId = 'CreateHookKeyframesFromFloats' + NodeName = 'Create Hook Keyframes From Floats' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}), + }, + "optional": { + "prev_hook_kf": ("HOOK_KEYFRAMES",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOK_KEYFRAMES",) + RETURN_NAMES = ("HOOK_KF",) + CATEGORY = "advanced/hooks/scheduling" + FUNCTION = "create_hook_keyframes" + + def create_hook_keyframes(self, floats_strength: Union[float, list[float]], + start_percent: float, end_percent: float, + prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False): + if prev_hook_kf is None: + prev_hook_kf = comfy.hooks.HookKeyframeGroup() + prev_hook_kf = prev_hook_kf.clone() + if type(floats_strength) in (float, int): + floats_strength = [float(floats_strength)] + elif isinstance(floats_strength, Iterable): + pass + else: + raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.") + percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength), + method=comfy.hooks.InterpolationMethod.LINEAR) + + is_first = True + for percent, strength in zip(percents, floats_strength): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = False + prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps)) + if print_keyframes: + logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}") + return (prev_hook_kf,) +#------------------------------------------ +########################################### + + +class SetModelHooksOnCond: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "conditioning": ("CONDITIONING",), + "hooks": ("HOOKS",), + }, + } + + EXPERIMENTAL = True + RETURN_TYPES = ("CONDITIONING",) + CATEGORY = "advanced/hooks/manual" + FUNCTION = "attach_hook" + + def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup): + return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),) + + +########################################### +# Combine Hooks +#------------------------------------------ +class CombineHooks: + SEARCH_ALIASES = ["merge hooks"] + NodeId = 'CombineHooks2' + NodeName = 'Combine Hooks [2]' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "hooks_A": ("HOOKS",), + "hooks_B": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/combine" + FUNCTION = "combine_hooks" + + def combine_hooks(self, + hooks_A: comfy.hooks.HookGroup=None, + hooks_B: comfy.hooks.HookGroup=None): + candidates = [hooks_A, hooks_B] + return (comfy.hooks.HookGroup.combine_all_hooks(candidates),) + +class CombineHooksFour: + NodeId = 'CombineHooks4' + NodeName = 'Combine Hooks [4]' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "hooks_A": ("HOOKS",), + "hooks_B": ("HOOKS",), + "hooks_C": ("HOOKS",), + "hooks_D": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/combine" + FUNCTION = "combine_hooks" + + def combine_hooks(self, + hooks_A: comfy.hooks.HookGroup=None, + hooks_B: comfy.hooks.HookGroup=None, + hooks_C: comfy.hooks.HookGroup=None, + hooks_D: comfy.hooks.HookGroup=None): + candidates = [hooks_A, hooks_B, hooks_C, hooks_D] + return (comfy.hooks.HookGroup.combine_all_hooks(candidates),) + +class CombineHooksEight: + NodeId = 'CombineHooks8' + NodeName = 'Combine Hooks [8]' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + }, + "optional": { + "hooks_A": ("HOOKS",), + "hooks_B": ("HOOKS",), + "hooks_C": ("HOOKS",), + "hooks_D": ("HOOKS",), + "hooks_E": ("HOOKS",), + "hooks_F": ("HOOKS",), + "hooks_G": ("HOOKS",), + "hooks_H": ("HOOKS",), + } + } + + EXPERIMENTAL = True + RETURN_TYPES = ("HOOKS",) + CATEGORY = "advanced/hooks/combine" + FUNCTION = "combine_hooks" + + def combine_hooks(self, + hooks_A: comfy.hooks.HookGroup=None, + hooks_B: comfy.hooks.HookGroup=None, + hooks_C: comfy.hooks.HookGroup=None, + hooks_D: comfy.hooks.HookGroup=None, + hooks_E: comfy.hooks.HookGroup=None, + hooks_F: comfy.hooks.HookGroup=None, + hooks_G: comfy.hooks.HookGroup=None, + hooks_H: comfy.hooks.HookGroup=None): + candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H] + return (comfy.hooks.HookGroup.combine_all_hooks(candidates),) +#------------------------------------------ +########################################### + +node_list = [ + # Create + CreateHookLora, + CreateHookLoraModelOnly, + CreateHookModelAsLora, + CreateHookModelAsLoraModelOnly, + # Scheduling + SetHookKeyframes, + CreateHookKeyframe, + CreateHookKeyframesInterpolated, + CreateHookKeyframesFromFloats, + # Combine + CombineHooks, + CombineHooksFour, + CombineHooksEight, + # Attach + ConditioningSetProperties, + ConditioningSetPropertiesAndCombine, + PairConditioningSetProperties, + PairConditioningSetPropertiesAndCombine, + ConditioningSetDefaultAndCombine, + PairConditioningSetDefaultAndCombine, + PairConditioningCombine, + SetClipHooks, + # Other + ConditioningTimestepsRange, +] +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + +for node in node_list: + NODE_CLASS_MAPPINGS[node.NodeId] = node + NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName diff --git a/ComfyUI/comfy_extras/nodes_hunyuan.py b/ComfyUI/comfy_extras/nodes_hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..90dd4890b36d4207903bbb69eeda522ace13ec44 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_hunyuan.py @@ -0,0 +1,427 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler +import folder_paths +import json + +class CLIPTextEncodeHunyuanDiT(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeHunyuanDiT", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("bert", multiline=True, dynamic_prompts=True), + io.String.Input("mt5xl", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, bert, mt5xl) -> io.NodeOutput: + tokens = clip.tokenize(bert) + tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"] + + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + encode = execute # TODO: remove + + +class EmptyHunyuanLatentVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyHunyuanLatentVideo", + display_name="Empty HunyuanVideo 1.0 Latent", + category="latent/video", + inputs=[ + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8}) + + generate = execute # TODO: remove + + +class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): + @classmethod + def define_schema(cls): + schema = super().define_schema() + schema.node_id = "EmptyHunyuanVideo15Latent" + schema.display_name = "Empty HunyuanVideo 1.5 Latent" + return schema + + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: + # Using scale factor of 16 instead of 8 + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16}) + + +class HunyuanVideo15ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + encoded = vae.encode(start_image[:, :, :, :3]) + concat_latent_image = torch.zeros((latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device()) + concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded + + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + +class HunyuanVideo15SuperResolution(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15SuperResolution", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae", optional=True), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Latent.Input("latent"), + io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01, advanced=True), + + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput: + in_latent = latent["samples"] + in_channels = in_latent.shape[1] + cond_latent = torch.zeros([in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]], device=comfy.model_management.intermediate_device()) + cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent + cond_latent[:, 2 * in_channels + 1] = 1 + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded + cond_latent[:, in_channels + 1, 0] = 1 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + return io.NodeOutput(positive, negative, latent) + + +class LatentUpscaleModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentUpscaleModelLoader", + display_name="Load Latent Upscale Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")), + ], + outputs=[ + io.LatentUpscaleModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) + sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True) + + if "blocks.0.block.0.conv.weight" in sd: + config = { + "in_channels": sd["in_conv.conv.weight"].shape[1], + "out_channels": sd["out_conv.conv.weight"].shape[0], + "hidden_channels": sd["in_conv.conv.weight"].shape[0], + "num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]), + "global_residual": False, + } + model_type = "720p" + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + elif "up.0.block.0.conv1.conv.weight" in sd: + sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} + config = { + "z_channels": sd["conv_in.conv.weight"].shape[1], + "out_channels": sd["conv_out.conv.weight"].shape[0], + "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), + } + model_type = "1080p" + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + elif "post_upsample_res_blocks.0.conv2.bias" in sd: + config = json.loads(metadata["config"]) + model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32])) + model.load_state_dict(sd) + + return io.NodeOutput(model) + + +class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15LatentUpscaleWithModel", + display_name="Hunyuan Video 15 Latent Upscale With Model", + category="latent", + inputs=[ + io.LatentUpscaleModel.Input("model"), + io.Latent.Input("samples"), + io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"), + io.Int.Input("width", default=1280, min=0, max=16384, step=8), + io.Int.Input("height", default=720, min=0, max=16384, step=8), + io.Combo.Input("crop", options=["disabled", "center"]), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput: + if width == 0 and height == 0: + return io.NodeOutput(samples) + else: + if width == 0: + height = max(64, height) + width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2])) + elif height == 0: + width = max(64, width) + height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1])) + else: + width = max(64, width) + height = max(64, height) + s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop) + s = model.resample_latent(s) + return io.NodeOutput({"samples": s.cpu().float()}) + + +PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) + +class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeHunyuanVideo_ImageToVideo", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.ClipVisionOutput.Input("clip_vision_output"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Int.Input( + "image_interleave", + default=2, + min=1, + max=512, + tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.", + advanced=True, + ), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput: + tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + encode = execute # TODO: remove + + +class HunyuanImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"], advanced=True), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + out_latent = {} + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + concat_latent_image = vae.encode(start_image) + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + if guidance_type == "v1 (concat)": + cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask} + elif guidance_type == "v2 (replace)": + cond = {'guiding_frame_index': 0} + latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image + out_latent["noise_mask"] = mask + elif guidance_type == "custom": + cond = {"ref_latent": concat_latent_image} + + positive = node_helpers.conditioning_set_values(positive, cond) + + out_latent["samples"] = latent + return io.NodeOutput(positive, out_latent) + + encode = execute # TODO: remove + + +class EmptyHunyuanImageLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyHunyuanImageLatent", + category="latent", + inputs=[ + io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) + + generate = execute # TODO: remove + + +class HunyuanRefinerLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanRefinerLatent", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Latent.Input("latent"), + io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01, advanced=True), + + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput: + latent = latent["samples"] + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation}) + out_latent = {} + out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + return io.NodeOutput(positive, negative, out_latent) + + +class HunyuanExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + CLIPTextEncodeHunyuanDiT, + TextEncodeHunyuanVideo_ImageToVideo, + EmptyHunyuanLatentVideo, + EmptyHunyuanVideo15Latent, + HunyuanVideo15ImageToVideo, + HunyuanVideo15SuperResolution, + HunyuanVideo15LatentUpscaleWithModel, + LatentUpscaleModelLoader, + HunyuanImageToVideo, + EmptyHunyuanImageLatent, + HunyuanRefinerLatent, + ] + + +async def comfy_entrypoint() -> HunyuanExtension: + return HunyuanExtension() diff --git a/ComfyUI/comfy_extras/nodes_hunyuan3d.py b/ComfyUI/comfy_extras/nodes_hunyuan3d.py new file mode 100644 index 0000000000000000000000000000000000000000..2317de8adb984a6269e9beddd7358526ef1d78b1 --- /dev/null +++ b/ComfyUI/comfy_extras/nodes_hunyuan3d.py @@ -0,0 +1,697 @@ +import torch +import os +import json +import struct +import numpy as np +from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch +import folder_paths +import comfy.model_management +from comfy.cli_args import args +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, Types +from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa + + +class EmptyLatentHunyuan3Dv2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentHunyuan3Dv2", + category="latent/3d", + inputs=[ + IO.Int.Input("resolution", default=3072, min=1, max=8192), + IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + ], + outputs=[ + IO.Latent.Output(), + ] + ) + + @classmethod + def execute(cls, resolution, batch_size) -> IO.NodeOutput: + latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) + return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"}) + + generate = execute # TODO: remove + + +class Hunyuan3Dv2Conditioning(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Hunyuan3Dv2Conditioning", + category="conditioning/video_models", + inputs=[ + IO.ClipVisionOutput.Input("clip_vision_output"), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) + + @classmethod + def execute(cls, clip_vision_output) -> IO.NodeOutput: + embeds = clip_vision_output.last_hidden_state + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return IO.NodeOutput(positive, negative) + + encode = execute # TODO: remove + + +class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Hunyuan3Dv2ConditioningMultiView", + category="conditioning/video_models", + inputs=[ + IO.ClipVisionOutput.Input("front", optional=True), + IO.ClipVisionOutput.Input("left", optional=True), + IO.ClipVisionOutput.Input("back", optional=True), + IO.ClipVisionOutput.Input("right", optional=True), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) + + @classmethod + def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput: + all_embeds = [front, left, back, right] + out = [] + pos_embeds = None + for i, e in enumerate(all_embeds): + if e is not None: + if pos_embeds is None: + pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4)) + out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1)) + + embeds = torch.cat(out, dim=1) + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return IO.NodeOutput(positive, negative) + + encode = execute # TODO: remove + + +class VAEDecodeHunyuan3D(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeHunyuan3D", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Int.Input("num_chunks", default=8000, min=1000, max=500000, advanced=True), + IO.Int.Input("octree_resolution", default=256, min=16, max=512, advanced=True), + ], + outputs=[ + IO.Voxel.Output(), + ] + ) + + @classmethod + def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput: + voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) + return IO.NodeOutput(voxels) + + decode = execute # TODO: remove + + +def voxel_to_mesh(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + binary = (voxels > threshold).float() + padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0) + + D, H, W = binary.shape + + neighbors = torch.tensor([ + [0, 0, 1], + [0, 0, -1], + [0, 1, 0], + [0, -1, 0], + [1, 0, 0], + [-1, 0, 0] + ], device=device) + + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + solid_mask = binary.flatten() > 0 + solid_indices = voxel_indices[solid_mask] + + corner_offsets = [ + torch.tensor([ + [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0] + ], device=device), + torch.tensor([ + [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0] + ], device=device) + ] + + all_vertices = [] + all_indices = [] + + vertex_count = 0 + + for face_idx, offset in enumerate(neighbors): + neighbor_indices = solid_indices + offset + + padded_indices = neighbor_indices + 1 + + is_exposed = padded[ + padded_indices[:, 0], + padded_indices[:, 1], + padded_indices[:, 2] + ] == 0 + + if not is_exposed.any(): + continue + + exposed_indices = solid_indices[is_exposed] + + corners = corner_offsets[face_idx].unsqueeze(0) + + face_vertices = exposed_indices.unsqueeze(1) + corners + + all_vertices.append(face_vertices.reshape(-1, 3)) + + num_faces = exposed_indices.shape[0] + face_indices = torch.arange( + vertex_count, + vertex_count + 4 * num_faces, + device=device + ).reshape(-1, 4) + + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1)) + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1)) + + vertex_count += 4 * num_faces + + if len(all_vertices) > 0: + vertices = torch.cat(all_vertices, dim=0) + faces = torch.cat(all_indices, dim=0) + else: + vertices = torch.zeros((1, 3)) + faces = torch.zeros((1, 3)) + + v_min = 0 + v_max = max(voxels.shape) + + vertices = vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + vertices = vertices / scale + + vertices = torch.fliplr(vertices) + return vertices, faces + +def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + D, H, W = voxels.shape + + padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0) + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + corner_offsets = torch.tensor([ + [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1] + ], device=device) + + pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0) + z_idx, y_idx, x_idx = pos.unbind(-1) + corner_values = padded[z_idx, y_idx, x_idx] + + corner_signs = corner_values > threshold + has_inside = torch.any(corner_signs, dim=1) + has_outside = torch.any(~corner_signs, dim=1) + contains_surface = has_inside & has_outside + + active_cells = cell_positions[contains_surface] + active_signs = corner_signs[contains_surface] + active_values = corner_values[contains_surface] + + if active_cells.shape[0] == 0: + return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device) + + edges = torch.tensor([ + [0, 1], [0, 2], [0, 4], [1, 3], + [1, 5], [2, 3], [2, 6], [3, 7], + [4, 5], [4, 6], [5, 7], [6, 7] + ], device=device) + + cell_vertices = {} + progress = comfy.utils.ProgressBar(100) + + for edge_idx, (e1, e2) in enumerate(edges): + progress.update(1) + crossing = active_signs[:, e1] != active_signs[:, e2] + if not crossing.any(): + continue + + cell_indices = torch.nonzero(crossing, as_tuple=True)[0] + + v1 = active_values[cell_indices, e1] + v2 = active_values[cell_indices, e2] + + t = torch.zeros_like(v1, device=device) + denom = v2 - v1 + valid = denom != 0 + t[valid] = (threshold - v1[valid]) / denom[valid] + t[~valid] = 0.5 + + p1 = corner_offsets[e1].float() + p2 = corner_offsets[e2].float() + + intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0)) + + for i, point in zip(cell_indices.tolist(), intersection): + if i not in cell_vertices: + cell_vertices[i] = [] + cell_vertices[i].append(point) + + # Calculate the final vertices as the average of intersection points for each cell + vertices = [] + vertex_lookup = {} + + vert_progress_mod = round(len(cell_vertices)/50) + + for i, points in cell_vertices.items(): + if not i % vert_progress_mod: + progress.update(1) + + if points: + vertex = torch.stack(points).mean(dim=0) + vertex = vertex + active_cells[i].float() + vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices) + vertices.append(vertex) + + if not vertices: + return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device) + + final_vertices = torch.stack(vertices) + + inside_corners_mask = active_signs + outside_corners_mask = ~active_signs + + inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float() + outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float() + + inside_pos = torch.zeros((active_cells.shape[0], 3), device=device) + outside_pos = torch.zeros((active_cells.shape[0], 3), device=device) + + for i in range(8): + mask_inside = inside_corners_mask[:, i].unsqueeze(1) + mask_outside = outside_corners_mask[:, i].unsqueeze(1) + inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside + outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside + + inside_pos /= inside_counts + outside_pos /= outside_counts + gradients = inside_pos - outside_pos + + pos_dirs = torch.tensor([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1] + ], device=device) + + cross_products = [ + torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float()) + for i in range(3) for j in range(i+1, 3) + ] + + faces = [] + all_keys = set(vertex_lookup.keys()) + + face_progress_mod = round(len(active_cells)/38*3) + + for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]): + dir_i = pos_dirs[i] + dir_j = pos_dirs[j] + cross_product = cross_products[pair_idx] + + ni_positions = active_cells + dir_i + nj_positions = active_cells + dir_j + diag_positions = active_cells + dir_i + dir_j + + alignments = torch.matmul(gradients, cross_product) + + valid_quads = [] + quad_indices = [] + + for idx, active_cell in enumerate(active_cells): + if not idx % face_progress_mod: + progress.update(1) + cell_key = tuple(active_cell.tolist()) + ni_key = tuple(ni_positions[idx].tolist()) + nj_key = tuple(nj_positions[idx].tolist()) + diag_key = tuple(diag_positions[idx].tolist()) + + if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys: + v0 = vertex_lookup[cell_key] + v1 = vertex_lookup[ni_key] + v2 = vertex_lookup[nj_key] + v3 = vertex_lookup[diag_key] + + valid_quads.append((v0, v1, v2, v3)) + quad_indices.append(idx) + + for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads): + cell_idx = quad_indices[q_idx] + if alignments[cell_idx] > 0: + faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long)) + faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long)) + else: + faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long)) + faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long)) + + if faces: + faces = torch.stack(faces) + else: + faces = torch.zeros((0, 3), dtype=torch.long, device=device) + + v_min = 0 + v_max = max(D, H, W) + + final_vertices = final_vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + final_vertices = final_vertices / scale + + final_vertices = torch.fliplr(final_vertices) + + return final_vertices, faces + + +class VoxelToMeshBasic(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VoxelToMeshBasic", + category="3d", + inputs=[ + IO.Voxel.Input("voxel"), + IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), + ], + outputs=[ + IO.Mesh.Output(), + ] + ) + + @classmethod + def execute(cls, voxel, threshold) -> IO.NodeOutput: + vertices = [] + faces = [] + for x in voxel.data: + v, f = voxel_to_mesh(x, threshold=threshold, device=None) + vertices.append(v) + faces.append(f) + + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + + decode = execute # TODO: remove + + +class VoxelToMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VoxelToMesh", + category="3d", + inputs=[ + IO.Voxel.Input("voxel"), + IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True), + IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), + ], + outputs=[ + IO.Mesh.Output(), + ] + ) + + @classmethod + def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput: + vertices = [] + faces = [] + + if algorithm == "basic": + mesh_function = voxel_to_mesh + elif algorithm == "surface net": + mesh_function = voxel_to_mesh_surfnet + + for x in voxel.data: + v, f = mesh_function(x, threshold=threshold, device=None) + vertices.append(v) + faces.append(f) + + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + + decode = execute # TODO: remove + + +def save_glb(vertices, faces, filepath, metadata=None): + """ + Save PyTorch tensor vertices and faces as a GLB file without external dependencies. + + Parameters: + vertices: torch.Tensor of shape (N, 3) - The vertex coordinates + faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces) + filepath: str - Output filepath (should end with .glb) + """ + + # Convert tensors to numpy arrays + vertices_np = vertices.cpu().numpy().astype(np.float32) + faces_np = faces.cpu().numpy().astype(np.uint32) + + vertices_buffer = vertices_np.tobytes() + indices_buffer = faces_np.tobytes() + + def pad_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b'\x00' * padding_length + + vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) + indices_buffer_padded = pad_to_4_bytes(indices_buffer) + + buffer_data = vertices_buffer_padded + indices_buffer_padded + + vertices_byte_length = len(vertices_buffer) + vertices_byte_offset = 0 + indices_byte_length = len(indices_buffer) + indices_byte_offset = len(vertices_buffer_padded) + + gltf = { + "asset": {"version": "2.0", "generator": "ComfyUI"}, + "buffers": [ + { + "byteLength": len(buffer_data) + } + ], + "bufferViews": [ + { + "buffer": 0, + "byteOffset": vertices_byte_offset, + "byteLength": vertices_byte_length, + "target": 34962 # ARRAY_BUFFER + }, + { + "buffer": 0, + "byteOffset": indices_byte_offset, + "byteLength": indices_byte_length, + "target": 34963 # ELEMENT_ARRAY_BUFFER + } + ], + "accessors": [ + { + "bufferView": 0, + "byteOffset": 0, + "componentType": 5126, # FLOAT + "count": len(vertices_np), + "type": "VEC3", + "max": vertices_np.max(axis=0).tolist(), + "min": vertices_np.min(axis=0).tolist() + }, + { + "bufferView": 1, + "byteOffset": 0, + "componentType": 5125, # UNSIGNED_INT + "count": faces_np.size, + "type": "SCALAR" + } + ], + "meshes": [ + { + "primitives": [ + { + "attributes": { + "POSITION": 0 + }, + "indices": 1, + "mode": 4 # TRIANGLES + } + ] + } + ], + "nodes": [ + { + "mesh": 0 + } + ], + "scenes": [ + { + "nodes": [0] + } + ], + "scene": 0 + } + + if metadata is not None: + gltf["asset"]["extras"] = metadata + + # Convert the JSON to bytes + gltf_json = json.dumps(gltf).encode('utf8') + + def pad_json_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b' ' * padding_length + + gltf_json_padded = pad_json_to_4_bytes(gltf_json) + + # Create the GLB header + # Magic glTF + glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data)) + + # Create JSON chunk header (chunk type 0) + json_chunk_header = struct.pack('