Spaces:
Runtime error
Runtime error
Upload ComfyUI/comfy_extras
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- ComfyUI/comfy_extras/chainner_models/model_loading.py +6 -0
- ComfyUI/comfy_extras/frame_interpolation_models/film_net.py +258 -0
- ComfyUI/comfy_extras/frame_interpolation_models/ifnet.py +128 -0
- ComfyUI/comfy_extras/nodes_ace.py +145 -0
- ComfyUI/comfy_extras/nodes_advanced_samplers.py +121 -0
- ComfyUI/comfy_extras/nodes_align_your_steps.py +70 -0
- ComfyUI/comfy_extras/nodes_apg.py +110 -0
- ComfyUI/comfy_extras/nodes_attention_multiply.py +151 -0
- ComfyUI/comfy_extras/nodes_audio.py +794 -0
- ComfyUI/comfy_extras/nodes_audio_encoder.py +62 -0
- ComfyUI/comfy_extras/nodes_camera_trajectory.py +239 -0
- ComfyUI/comfy_extras/nodes_canny.py +45 -0
- ComfyUI/comfy_extras/nodes_cfg.py +91 -0
- ComfyUI/comfy_extras/nodes_chroma_radiance.py +117 -0
- ComfyUI/comfy_extras/nodes_clip_sdxl.py +71 -0
- ComfyUI/comfy_extras/nodes_color.py +42 -0
- ComfyUI/comfy_extras/nodes_compositing.py +226 -0
- ComfyUI/comfy_extras/nodes_cond.py +68 -0
- ComfyUI/comfy_extras/nodes_context_windows.py +103 -0
- ComfyUI/comfy_extras/nodes_controlnet.py +85 -0
- ComfyUI/comfy_extras/nodes_cosmos.py +143 -0
- ComfyUI/comfy_extras/nodes_curve.py +92 -0
- ComfyUI/comfy_extras/nodes_custom_sampler.py +1095 -0
- ComfyUI/comfy_extras/nodes_dataset.py +1537 -0
- ComfyUI/comfy_extras/nodes_differential_diffusion.py +73 -0
- ComfyUI/comfy_extras/nodes_easycache.py +530 -0
- ComfyUI/comfy_extras/nodes_edit_model.py +38 -0
- ComfyUI/comfy_extras/nodes_eps.py +172 -0
- ComfyUI/comfy_extras/nodes_flux.py +314 -0
- ComfyUI/comfy_extras/nodes_frame_interpolation.py +211 -0
- ComfyUI/comfy_extras/nodes_freelunch.py +138 -0
- ComfyUI/comfy_extras/nodes_fresca.py +115 -0
- ComfyUI/comfy_extras/nodes_gits.py +382 -0
- ComfyUI/comfy_extras/nodes_glsl.py +958 -0
- ComfyUI/comfy_extras/nodes_hidream.py +74 -0
- ComfyUI/comfy_extras/nodes_hooks.py +750 -0
- ComfyUI/comfy_extras/nodes_hunyuan.py +427 -0
- ComfyUI/comfy_extras/nodes_hunyuan3d.py +697 -0
- ComfyUI/comfy_extras/nodes_hypernetwork.py +138 -0
- ComfyUI/comfy_extras/nodes_hypertile.py +98 -0
- ComfyUI/comfy_extras/nodes_image_compare.py +54 -0
- ComfyUI/comfy_extras/nodes_images.py +851 -0
- ComfyUI/comfy_extras/nodes_ip2p.py +63 -0
- ComfyUI/comfy_extras/nodes_kandinsky5.py +137 -0
- ComfyUI/comfy_extras/nodes_latent.py +504 -0
- ComfyUI/comfy_extras/nodes_load_3d.py +131 -0
- ComfyUI/comfy_extras/nodes_logic.py +274 -0
- ComfyUI/comfy_extras/nodes_lora_debug.py +79 -0
- ComfyUI/comfy_extras/nodes_lora_extract.py +145 -0
- ComfyUI/comfy_extras/nodes_lotus.py +39 -0
ComfyUI/comfy_extras/chainner_models/model_loading.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from spandrel import ModelLoader
|
| 3 |
+
|
| 4 |
+
def load_state_dict(state_dict):
|
| 5 |
+
logging.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
|
| 6 |
+
return ModelLoader().load_from_state_dict(state_dict).eval()
|
ComfyUI/comfy_extras/frame_interpolation_models/film_net.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FILM: Frame Interpolation for Large Motion (ECCV 2022)."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
import comfy.ops
|
| 8 |
+
|
| 9 |
+
ops = comfy.ops.disable_weight_init
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FilmConv2d(nn.Module):
|
| 13 |
+
"""Conv2d with optional LeakyReLU and FILM-style padding."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.even_pad = not size % 2
|
| 18 |
+
self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype)
|
| 19 |
+
self.activation = nn.LeakyReLU(0.2) if activation else None
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
if self.even_pad:
|
| 23 |
+
x = F.pad(x, (0, 1, 0, 1))
|
| 24 |
+
x = self.conv(x)
|
| 25 |
+
if self.activation is not None:
|
| 26 |
+
x = self.activation(x)
|
| 27 |
+
return x
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _warp_core(image, flow, grid_x, grid_y):
|
| 31 |
+
dtype = image.dtype
|
| 32 |
+
H, W = flow.shape[2], flow.shape[3]
|
| 33 |
+
dx = flow[:, 0].float() / (W * 0.5)
|
| 34 |
+
dy = flow[:, 1].float() / (H * 0.5)
|
| 35 |
+
grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3)
|
| 36 |
+
return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def build_image_pyramid(image, pyramid_levels):
|
| 40 |
+
pyramid = [image]
|
| 41 |
+
for _ in range(1, pyramid_levels):
|
| 42 |
+
image = F.avg_pool2d(image, 2, 2)
|
| 43 |
+
pyramid.append(image)
|
| 44 |
+
return pyramid
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def flow_pyramid_synthesis(residual_pyramid):
|
| 48 |
+
flow = residual_pyramid[-1]
|
| 49 |
+
flow_pyramid = [flow]
|
| 50 |
+
for residual_flow in residual_pyramid[:-1][::-1]:
|
| 51 |
+
flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow)
|
| 52 |
+
flow_pyramid.append(flow)
|
| 53 |
+
flow_pyramid.reverse()
|
| 54 |
+
return flow_pyramid
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def multiply_pyramid(pyramid, scalar):
|
| 58 |
+
return [image * scalar[:, None, None, None] for image in pyramid]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn):
|
| 62 |
+
return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def concatenate_pyramids(pyramid1, pyramid2):
|
| 66 |
+
return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class SubTreeExtractor(nn.Module):
|
| 70 |
+
def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops):
|
| 71 |
+
super().__init__()
|
| 72 |
+
convs = []
|
| 73 |
+
for i in range(n_layers):
|
| 74 |
+
out_ch = channels << i
|
| 75 |
+
convs.append(nn.Sequential(
|
| 76 |
+
FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations),
|
| 77 |
+
FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations)))
|
| 78 |
+
in_channels = out_ch
|
| 79 |
+
self.convs = nn.ModuleList(convs)
|
| 80 |
+
|
| 81 |
+
def forward(self, image, n):
|
| 82 |
+
head = image
|
| 83 |
+
pyramid = []
|
| 84 |
+
for i, layer in enumerate(self.convs):
|
| 85 |
+
head = layer(head)
|
| 86 |
+
pyramid.append(head)
|
| 87 |
+
if i < n - 1:
|
| 88 |
+
head = F.avg_pool2d(head, 2, 2)
|
| 89 |
+
return pyramid
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class FeatureExtractor(nn.Module):
|
| 93 |
+
def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops):
|
| 94 |
+
super().__init__()
|
| 95 |
+
self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations)
|
| 96 |
+
self.sub_levels = sub_levels
|
| 97 |
+
|
| 98 |
+
def forward(self, image_pyramid):
|
| 99 |
+
sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels))
|
| 100 |
+
for i in range(len(image_pyramid))]
|
| 101 |
+
feature_pyramid = []
|
| 102 |
+
for i in range(len(image_pyramid)):
|
| 103 |
+
features = sub_pyramids[i][0]
|
| 104 |
+
for j in range(1, self.sub_levels):
|
| 105 |
+
if j <= i:
|
| 106 |
+
features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
|
| 107 |
+
feature_pyramid.append(features)
|
| 108 |
+
# Free sub-pyramids no longer needed by future levels
|
| 109 |
+
if i >= self.sub_levels - 1:
|
| 110 |
+
sub_pyramids[i - self.sub_levels + 1] = None
|
| 111 |
+
return feature_pyramid
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class FlowEstimator(nn.Module):
|
| 115 |
+
def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops):
|
| 116 |
+
super().__init__()
|
| 117 |
+
self._convs = nn.ModuleList()
|
| 118 |
+
for _ in range(num_convs):
|
| 119 |
+
self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations))
|
| 120 |
+
in_channels = num_filters
|
| 121 |
+
self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations))
|
| 122 |
+
self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations))
|
| 123 |
+
|
| 124 |
+
def forward(self, features_a, features_b):
|
| 125 |
+
net = torch.cat([features_a, features_b], dim=1)
|
| 126 |
+
for conv in self._convs:
|
| 127 |
+
net = conv(net)
|
| 128 |
+
return net
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class PyramidFlowEstimator(nn.Module):
|
| 132 |
+
def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
|
| 133 |
+
super().__init__()
|
| 134 |
+
in_channels = filters << 1
|
| 135 |
+
predictors = []
|
| 136 |
+
for i in range(len(flow_convs)):
|
| 137 |
+
predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations))
|
| 138 |
+
in_channels += filters << (i + 2)
|
| 139 |
+
self._predictor = predictors[-1]
|
| 140 |
+
self._predictors = nn.ModuleList(predictors[:-1][::-1])
|
| 141 |
+
|
| 142 |
+
def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn):
|
| 143 |
+
levels = len(feature_pyramid_a)
|
| 144 |
+
v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
|
| 145 |
+
residuals = [v]
|
| 146 |
+
# Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels
|
| 147 |
+
steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)]
|
| 148 |
+
steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)]
|
| 149 |
+
for i, predictor in steps:
|
| 150 |
+
v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2)
|
| 151 |
+
v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v))
|
| 152 |
+
residuals.append(v_residual)
|
| 153 |
+
v = v.add_(v_residual)
|
| 154 |
+
residuals.reverse()
|
| 155 |
+
return residuals
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _get_fusion_channels(level, filters):
|
| 159 |
+
# Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions
|
| 160 |
+
return (sum(filters << i for i in range(level)) + 3 + 2) * 2
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class Fusion(nn.Module):
|
| 164 |
+
def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops):
|
| 165 |
+
super().__init__()
|
| 166 |
+
self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype)
|
| 167 |
+
self.convs = nn.ModuleList()
|
| 168 |
+
in_channels = _get_fusion_channels(n_layers, filters)
|
| 169 |
+
increase = 0
|
| 170 |
+
for i in range(n_layers)[::-1]:
|
| 171 |
+
num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
|
| 172 |
+
self.convs.append(nn.ModuleList([
|
| 173 |
+
FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations),
|
| 174 |
+
FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations),
|
| 175 |
+
FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)]))
|
| 176 |
+
in_channels = num_filters
|
| 177 |
+
increase = _get_fusion_channels(i, filters) - num_filters // 2
|
| 178 |
+
|
| 179 |
+
def forward(self, pyramid):
|
| 180 |
+
net = pyramid[-1]
|
| 181 |
+
for k, layers in enumerate(self.convs):
|
| 182 |
+
i = len(self.convs) - 1 - k
|
| 183 |
+
net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest"))
|
| 184 |
+
net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1)))
|
| 185 |
+
return self.output_conv(net)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class FILMNet(nn.Module):
|
| 189 |
+
def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4,
|
| 190 |
+
filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.pyramid_levels = pyramid_levels
|
| 193 |
+
self.fusion_pyramid_levels = fusion_pyramid_levels
|
| 194 |
+
self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations)
|
| 195 |
+
self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations)
|
| 196 |
+
self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations)
|
| 197 |
+
self._warp_grids = {}
|
| 198 |
+
|
| 199 |
+
def get_dtype(self):
|
| 200 |
+
return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
|
| 201 |
+
|
| 202 |
+
def _build_warp_grids(self, H, W, device):
|
| 203 |
+
"""Pre-compute warp grids for all pyramid levels."""
|
| 204 |
+
if (H, W) in self._warp_grids:
|
| 205 |
+
return
|
| 206 |
+
self._warp_grids = {} # clear old resolution grids to prevent memory leaks
|
| 207 |
+
for _ in range(self.pyramid_levels):
|
| 208 |
+
self._warp_grids[(H, W)] = (
|
| 209 |
+
torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device),
|
| 210 |
+
torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device),
|
| 211 |
+
)
|
| 212 |
+
H, W = H // 2, W // 2
|
| 213 |
+
|
| 214 |
+
def warp(self, image, flow):
|
| 215 |
+
grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])]
|
| 216 |
+
return _warp_core(image, flow, grid_x, grid_y)
|
| 217 |
+
|
| 218 |
+
def extract_features(self, img):
|
| 219 |
+
"""Extract image and feature pyramids for a single frame. Can be cached across pairs."""
|
| 220 |
+
image_pyramid = build_image_pyramid(img, self.pyramid_levels)
|
| 221 |
+
feature_pyramid = self.extract(image_pyramid)
|
| 222 |
+
return image_pyramid, feature_pyramid
|
| 223 |
+
|
| 224 |
+
def forward(self, img0, img1, timestep=0.5, cache=None):
|
| 225 |
+
# FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported)
|
| 226 |
+
t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep
|
| 227 |
+
return self.forward_multi_timestep(img0, img1, [t], cache=cache)
|
| 228 |
+
|
| 229 |
+
def forward_multi_timestep(self, img0, img1, timesteps, cache=None):
|
| 230 |
+
"""Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs."""
|
| 231 |
+
self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
|
| 232 |
+
|
| 233 |
+
image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0)
|
| 234 |
+
image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1)
|
| 235 |
+
|
| 236 |
+
fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels]
|
| 237 |
+
bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels]
|
| 238 |
+
|
| 239 |
+
# Build warp targets and free full pyramids (only first fpl levels needed from here)
|
| 240 |
+
fpl = self.fusion_pyramid_levels
|
| 241 |
+
p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]),
|
| 242 |
+
concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])]
|
| 243 |
+
del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1
|
| 244 |
+
|
| 245 |
+
results = []
|
| 246 |
+
dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype)
|
| 247 |
+
for idx in range(len(timesteps)):
|
| 248 |
+
batch_dt = dt_tensors[idx:idx + 1]
|
| 249 |
+
bwd_scaled = multiply_pyramid(bwd_flow, batch_dt)
|
| 250 |
+
fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt)
|
| 251 |
+
fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp)
|
| 252 |
+
bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp)
|
| 253 |
+
aligned = [torch.cat([fw, bw, bf, ff], dim=1)
|
| 254 |
+
for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)]
|
| 255 |
+
del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled
|
| 256 |
+
results.append(self.fuse(aligned))
|
| 257 |
+
del aligned
|
| 258 |
+
return torch.cat(results, dim=0)
|
ComfyUI/comfy_extras/frame_interpolation_models/ifnet.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
import comfy.ops
|
| 6 |
+
|
| 7 |
+
ops = comfy.ops.disable_weight_init
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _warp(img, flow, warp_grids):
|
| 11 |
+
B, _, H, W = img.shape
|
| 12 |
+
base_grid, flow_div = warp_grids[(H, W)]
|
| 13 |
+
flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float()
|
| 14 |
+
grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1)
|
| 15 |
+
return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Head(nn.Module):
|
| 19 |
+
def __init__(self, out_ch=4, device=None, dtype=None, operations=ops):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype)
|
| 22 |
+
self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
|
| 23 |
+
self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
|
| 24 |
+
self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype)
|
| 25 |
+
self.relu = nn.LeakyReLU(0.2, True)
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
x = self.relu(self.cnn0(x))
|
| 29 |
+
x = self.relu(self.cnn1(x))
|
| 30 |
+
x = self.relu(self.cnn2(x))
|
| 31 |
+
return self.cnn3(x)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ResConv(nn.Module):
|
| 35 |
+
def __init__(self, c, device=None, dtype=None, operations=ops):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype)
|
| 38 |
+
self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype))
|
| 39 |
+
self.relu = nn.LeakyReLU(0.2, True)
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
return self.relu(torch.addcmul(x, self.conv(x), self.beta))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class IFBlock(nn.Module):
|
| 46 |
+
def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.conv0 = nn.Sequential(
|
| 49 |
+
nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)),
|
| 50 |
+
nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)))
|
| 51 |
+
self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8)))
|
| 52 |
+
self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2))
|
| 53 |
+
|
| 54 |
+
def forward(self, x, flow=None, scale=1):
|
| 55 |
+
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
|
| 56 |
+
if flow is not None:
|
| 57 |
+
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale)
|
| 58 |
+
x = torch.cat((x, flow), 1)
|
| 59 |
+
feat = self.convblock(self.conv0(x))
|
| 60 |
+
tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear")
|
| 61 |
+
return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class IFNet(nn.Module):
|
| 65 |
+
def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops):
|
| 66 |
+
super().__init__()
|
| 67 |
+
self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations)
|
| 68 |
+
block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4
|
| 69 |
+
self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)])
|
| 70 |
+
self.scale_list = [16, 8, 4, 2, 1]
|
| 71 |
+
self.pad_align = 64
|
| 72 |
+
self._warp_grids = {}
|
| 73 |
+
|
| 74 |
+
def get_dtype(self):
|
| 75 |
+
return self.encode.cnn0.weight.dtype
|
| 76 |
+
|
| 77 |
+
def _build_warp_grids(self, H, W, device):
|
| 78 |
+
if (H, W) in self._warp_grids:
|
| 79 |
+
return
|
| 80 |
+
self._warp_grids = {} # clear old resolution grids to prevent memory leaks
|
| 81 |
+
grid_y, grid_x = torch.meshgrid(
|
| 82 |
+
torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32),
|
| 83 |
+
torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij")
|
| 84 |
+
self._warp_grids[(H, W)] = (
|
| 85 |
+
torch.stack((grid_x, grid_y), dim=0).unsqueeze(0),
|
| 86 |
+
torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device))
|
| 87 |
+
|
| 88 |
+
def warp(self, img, flow):
|
| 89 |
+
return _warp(img, flow, self._warp_grids)
|
| 90 |
+
|
| 91 |
+
def extract_features(self, img):
|
| 92 |
+
"""Extract head features for a single frame. Can be cached across pairs."""
|
| 93 |
+
return self.encode(img)
|
| 94 |
+
|
| 95 |
+
def forward(self, img0, img1, timestep=0.5, cache=None):
|
| 96 |
+
if not isinstance(timestep, torch.Tensor):
|
| 97 |
+
timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype)
|
| 98 |
+
|
| 99 |
+
self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
|
| 100 |
+
|
| 101 |
+
B = img0.shape[0]
|
| 102 |
+
f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0)
|
| 103 |
+
f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1)
|
| 104 |
+
flow = mask = feat = None
|
| 105 |
+
warped_img0, warped_img1 = img0, img1
|
| 106 |
+
for i, block in enumerate(self.blocks):
|
| 107 |
+
if flow is None:
|
| 108 |
+
flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
|
| 109 |
+
else:
|
| 110 |
+
fd, mask, feat = block(
|
| 111 |
+
torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1),
|
| 112 |
+
flow, scale=self.scale_list[i])
|
| 113 |
+
flow = flow.add_(fd)
|
| 114 |
+
warped_img0 = self.warp(img0, flow[:, :2])
|
| 115 |
+
warped_img1 = self.warp(img1, flow[:, 2:4])
|
| 116 |
+
return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def detect_rife_config(state_dict):
|
| 120 |
+
head_ch = state_dict["encode.cnn3.weight"].shape[1] # ConvTranspose2d: (in_ch, out_ch, kH, kW)
|
| 121 |
+
channels = []
|
| 122 |
+
for i in range(5):
|
| 123 |
+
key = f"blocks.{i}.conv0.1.0.weight"
|
| 124 |
+
if key in state_dict:
|
| 125 |
+
channels.append(state_dict[key].shape[0])
|
| 126 |
+
if len(channels) != 5:
|
| 127 |
+
raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}")
|
| 128 |
+
return head_ch, channels
|
ComfyUI/comfy_extras/nodes_ace.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing_extensions import override
|
| 3 |
+
|
| 4 |
+
import comfy.model_management
|
| 5 |
+
import node_helpers
|
| 6 |
+
from comfy_api.latest import ComfyExtension, IO
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TextEncodeAceStepAudio(IO.ComfyNode):
|
| 10 |
+
@classmethod
|
| 11 |
+
def define_schema(cls):
|
| 12 |
+
return IO.Schema(
|
| 13 |
+
node_id="TextEncodeAceStepAudio",
|
| 14 |
+
category="conditioning",
|
| 15 |
+
inputs=[
|
| 16 |
+
IO.Clip.Input("clip"),
|
| 17 |
+
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
| 18 |
+
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
| 19 |
+
IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
| 20 |
+
],
|
| 21 |
+
outputs=[IO.Conditioning.Output()],
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
@classmethod
|
| 25 |
+
def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput:
|
| 26 |
+
tokens = clip.tokenize(tags, lyrics=lyrics)
|
| 27 |
+
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
| 28 |
+
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
| 29 |
+
return IO.NodeOutput(conditioning)
|
| 30 |
+
|
| 31 |
+
class TextEncodeAceStepAudio15(IO.ComfyNode):
|
| 32 |
+
@classmethod
|
| 33 |
+
def define_schema(cls):
|
| 34 |
+
return IO.Schema(
|
| 35 |
+
node_id="TextEncodeAceStepAudio1.5",
|
| 36 |
+
category="conditioning",
|
| 37 |
+
inputs=[
|
| 38 |
+
IO.Clip.Input("clip"),
|
| 39 |
+
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
| 40 |
+
IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
| 41 |
+
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
| 42 |
+
IO.Int.Input("bpm", default=120, min=10, max=300),
|
| 43 |
+
IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
| 44 |
+
IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
| 45 |
+
IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
| 46 |
+
IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
| 47 |
+
IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
| 48 |
+
IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
| 49 |
+
IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
| 50 |
+
IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
| 51 |
+
IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
| 52 |
+
IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
| 53 |
+
],
|
| 54 |
+
outputs=[IO.Conditioning.Output()],
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
@classmethod
|
| 58 |
+
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput:
|
| 59 |
+
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
|
| 60 |
+
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
| 61 |
+
return IO.NodeOutput(conditioning)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class EmptyAceStepLatentAudio(IO.ComfyNode):
|
| 65 |
+
@classmethod
|
| 66 |
+
def define_schema(cls):
|
| 67 |
+
return IO.Schema(
|
| 68 |
+
node_id="EmptyAceStepLatentAudio",
|
| 69 |
+
display_name="Empty Ace Step 1.0 Latent Audio",
|
| 70 |
+
category="latent/audio",
|
| 71 |
+
inputs=[
|
| 72 |
+
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
| 73 |
+
IO.Int.Input(
|
| 74 |
+
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
| 75 |
+
),
|
| 76 |
+
],
|
| 77 |
+
outputs=[IO.Latent.Output()],
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
@classmethod
|
| 81 |
+
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
| 82 |
+
length = int(seconds * 44100 / 512 / 8)
|
| 83 |
+
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
| 84 |
+
return IO.NodeOutput({"samples": latent, "type": "audio"})
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class EmptyAceStep15LatentAudio(IO.ComfyNode):
|
| 88 |
+
@classmethod
|
| 89 |
+
def define_schema(cls):
|
| 90 |
+
return IO.Schema(
|
| 91 |
+
node_id="EmptyAceStep1.5LatentAudio",
|
| 92 |
+
display_name="Empty Ace Step 1.5 Latent Audio",
|
| 93 |
+
category="latent/audio",
|
| 94 |
+
inputs=[
|
| 95 |
+
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
|
| 96 |
+
IO.Int.Input(
|
| 97 |
+
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
| 98 |
+
),
|
| 99 |
+
],
|
| 100 |
+
outputs=[IO.Latent.Output()],
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
@classmethod
|
| 104 |
+
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
| 105 |
+
length = round((seconds * 48000 / 1920))
|
| 106 |
+
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
| 107 |
+
return IO.NodeOutput({"samples": latent, "type": "audio"})
|
| 108 |
+
|
| 109 |
+
class ReferenceAudio(IO.ComfyNode):
|
| 110 |
+
@classmethod
|
| 111 |
+
def define_schema(cls):
|
| 112 |
+
return IO.Schema(
|
| 113 |
+
node_id="ReferenceTimbreAudio",
|
| 114 |
+
display_name="Reference Audio",
|
| 115 |
+
category="advanced/conditioning/audio",
|
| 116 |
+
is_experimental=True,
|
| 117 |
+
description="This node sets the reference audio for ace step 1.5",
|
| 118 |
+
inputs=[
|
| 119 |
+
IO.Conditioning.Input("conditioning"),
|
| 120 |
+
IO.Latent.Input("latent", optional=True),
|
| 121 |
+
],
|
| 122 |
+
outputs=[
|
| 123 |
+
IO.Conditioning.Output(),
|
| 124 |
+
]
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
@classmethod
|
| 128 |
+
def execute(cls, conditioning, latent=None) -> IO.NodeOutput:
|
| 129 |
+
if latent is not None:
|
| 130 |
+
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
|
| 131 |
+
return IO.NodeOutput(conditioning)
|
| 132 |
+
|
| 133 |
+
class AceExtension(ComfyExtension):
|
| 134 |
+
@override
|
| 135 |
+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
| 136 |
+
return [
|
| 137 |
+
TextEncodeAceStepAudio,
|
| 138 |
+
EmptyAceStepLatentAudio,
|
| 139 |
+
TextEncodeAceStepAudio15,
|
| 140 |
+
EmptyAceStep15LatentAudio,
|
| 141 |
+
ReferenceAudio,
|
| 142 |
+
]
|
| 143 |
+
|
| 144 |
+
async def comfy_entrypoint() -> AceExtension:
|
| 145 |
+
return AceExtension()
|
ComfyUI/comfy_extras/nodes_advanced_samplers.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from tqdm.auto import trange
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
|
| 6 |
+
import comfy.model_patcher
|
| 7 |
+
import comfy.samplers
|
| 8 |
+
import comfy.utils
|
| 9 |
+
from comfy.k_diffusion.sampling import to_d
|
| 10 |
+
from comfy_api.latest import ComfyExtension, io
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@torch.no_grad()
|
| 14 |
+
def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable=None, total_upscale=2.0, upscale_method="bislerp", upscale_steps=None):
|
| 15 |
+
extra_args = {} if extra_args is None else extra_args
|
| 16 |
+
|
| 17 |
+
if upscale_steps is None:
|
| 18 |
+
upscale_steps = max(len(sigmas) // 2 + 1, 2)
|
| 19 |
+
else:
|
| 20 |
+
upscale_steps += 1
|
| 21 |
+
upscale_steps = min(upscale_steps, len(sigmas) + 1)
|
| 22 |
+
|
| 23 |
+
upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:]
|
| 24 |
+
|
| 25 |
+
orig_shape = x.size()
|
| 26 |
+
s_in = x.new_ones([x.shape[0]])
|
| 27 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
| 28 |
+
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
| 29 |
+
if callback is not None:
|
| 30 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
| 31 |
+
|
| 32 |
+
x = denoised
|
| 33 |
+
if i < len(upscales):
|
| 34 |
+
x = comfy.utils.common_upscale(x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled")
|
| 35 |
+
|
| 36 |
+
if sigmas[i + 1] > 0:
|
| 37 |
+
x += sigmas[i + 1] * torch.randn_like(x)
|
| 38 |
+
return x
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class SamplerLCMUpscale(io.ComfyNode):
|
| 42 |
+
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
| 43 |
+
|
| 44 |
+
@classmethod
|
| 45 |
+
def define_schema(cls) -> io.Schema:
|
| 46 |
+
return io.Schema(
|
| 47 |
+
node_id="SamplerLCMUpscale",
|
| 48 |
+
category="sampling/custom_sampling/samplers",
|
| 49 |
+
inputs=[
|
| 50 |
+
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True),
|
| 51 |
+
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True),
|
| 52 |
+
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
| 53 |
+
],
|
| 54 |
+
outputs=[io.Sampler.Output()],
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
@classmethod
|
| 58 |
+
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
|
| 59 |
+
if scale_steps < 0:
|
| 60 |
+
scale_steps = None
|
| 61 |
+
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
| 62 |
+
return io.NodeOutput(sampler)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@torch.no_grad()
|
| 66 |
+
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
| 67 |
+
extra_args = {} if extra_args is None else extra_args
|
| 68 |
+
|
| 69 |
+
temp = [0]
|
| 70 |
+
def post_cfg_function(args):
|
| 71 |
+
temp[0] = args["uncond_denoised"]
|
| 72 |
+
return args["denoised"]
|
| 73 |
+
|
| 74 |
+
model_options = extra_args.get("model_options", {}).copy()
|
| 75 |
+
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
|
| 76 |
+
|
| 77 |
+
s_in = x.new_ones([x.shape[0]])
|
| 78 |
+
for i in trange(len(sigmas) - 1, disable=disable):
|
| 79 |
+
sigma_hat = sigmas[i]
|
| 80 |
+
denoised = model(x, sigma_hat * s_in, **extra_args)
|
| 81 |
+
d = to_d(x - denoised + temp[0], sigmas[i], denoised)
|
| 82 |
+
if callback is not None:
|
| 83 |
+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
| 84 |
+
dt = sigmas[i + 1] - sigma_hat
|
| 85 |
+
x = x + d * dt
|
| 86 |
+
return x
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class SamplerEulerCFGpp(io.ComfyNode):
|
| 90 |
+
@classmethod
|
| 91 |
+
def define_schema(cls) -> io.Schema:
|
| 92 |
+
return io.Schema(
|
| 93 |
+
node_id="SamplerEulerCFGpp",
|
| 94 |
+
display_name="SamplerEulerCFG++",
|
| 95 |
+
category="_for_testing", # "sampling/custom_sampling/samplers"
|
| 96 |
+
inputs=[
|
| 97 |
+
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
|
| 98 |
+
],
|
| 99 |
+
outputs=[io.Sampler.Output()],
|
| 100 |
+
is_experimental=True,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
@classmethod
|
| 104 |
+
def execute(cls, version) -> io.NodeOutput:
|
| 105 |
+
if version == "alternative":
|
| 106 |
+
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
|
| 107 |
+
else:
|
| 108 |
+
sampler = comfy.samplers.ksampler("euler_cfg_pp")
|
| 109 |
+
return io.NodeOutput(sampler)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class AdvancedSamplersExtension(ComfyExtension):
|
| 113 |
+
@override
|
| 114 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 115 |
+
return [
|
| 116 |
+
SamplerLCMUpscale,
|
| 117 |
+
SamplerEulerCFGpp,
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
async def comfy_entrypoint() -> AdvancedSamplersExtension:
|
| 121 |
+
return AdvancedSamplersExtension()
|
ComfyUI/comfy_extras/nodes_align_your_steps.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
|
| 6 |
+
from comfy_api.latest import ComfyExtension, io
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def loglinear_interp(t_steps, num_steps):
|
| 10 |
+
"""
|
| 11 |
+
Performs log-linear interpolation of a given array of decreasing numbers.
|
| 12 |
+
"""
|
| 13 |
+
xs = np.linspace(0, 1, len(t_steps))
|
| 14 |
+
ys = np.log(t_steps[::-1])
|
| 15 |
+
|
| 16 |
+
new_xs = np.linspace(0, 1, num_steps)
|
| 17 |
+
new_ys = np.interp(new_xs, xs, ys)
|
| 18 |
+
|
| 19 |
+
interped_ys = np.exp(new_ys)[::-1].copy()
|
| 20 |
+
return interped_ys
|
| 21 |
+
|
| 22 |
+
NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
|
| 23 |
+
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
|
| 24 |
+
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
|
| 25 |
+
|
| 26 |
+
class AlignYourStepsScheduler(io.ComfyNode):
|
| 27 |
+
@classmethod
|
| 28 |
+
def define_schema(cls) -> io.Schema:
|
| 29 |
+
return io.Schema(
|
| 30 |
+
node_id="AlignYourStepsScheduler",
|
| 31 |
+
search_aliases=["AYS scheduler"],
|
| 32 |
+
category="sampling/custom_sampling/schedulers",
|
| 33 |
+
inputs=[
|
| 34 |
+
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
|
| 35 |
+
io.Int.Input("steps", default=10, min=1, max=10000),
|
| 36 |
+
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
| 37 |
+
],
|
| 38 |
+
outputs=[io.Sigmas.Output()],
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
def get_sigmas(self, model_type, steps, denoise):
|
| 42 |
+
# Deprecated: use the V3 schema's `execute` method instead of this.
|
| 43 |
+
return AlignYourStepsScheduler().execute(model_type, steps, denoise).result
|
| 44 |
+
|
| 45 |
+
@classmethod
|
| 46 |
+
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
|
| 47 |
+
total_steps = steps
|
| 48 |
+
if denoise < 1.0:
|
| 49 |
+
if denoise <= 0.0:
|
| 50 |
+
return io.NodeOutput(torch.FloatTensor([]))
|
| 51 |
+
total_steps = round(steps * denoise)
|
| 52 |
+
|
| 53 |
+
sigmas = NOISE_LEVELS[model_type][:]
|
| 54 |
+
if (steps + 1) != len(sigmas):
|
| 55 |
+
sigmas = loglinear_interp(sigmas, steps + 1)
|
| 56 |
+
|
| 57 |
+
sigmas = sigmas[-(total_steps + 1):]
|
| 58 |
+
sigmas[-1] = 0
|
| 59 |
+
return io.NodeOutput(torch.FloatTensor(sigmas))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class AlignYourStepsExtension(ComfyExtension):
|
| 63 |
+
@override
|
| 64 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 65 |
+
return [
|
| 66 |
+
AlignYourStepsScheduler,
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
async def comfy_entrypoint() -> AlignYourStepsExtension:
|
| 70 |
+
return AlignYourStepsExtension()
|
ComfyUI/comfy_extras/nodes_apg.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing_extensions import override
|
| 3 |
+
|
| 4 |
+
from comfy_api.latest import ComfyExtension, io
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def project(v0, v1):
|
| 8 |
+
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
|
| 9 |
+
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
|
| 10 |
+
v0_orthogonal = v0 - v0_parallel
|
| 11 |
+
return v0_parallel, v0_orthogonal
|
| 12 |
+
|
| 13 |
+
class APG(io.ComfyNode):
|
| 14 |
+
@classmethod
|
| 15 |
+
def define_schema(cls) -> io.Schema:
|
| 16 |
+
return io.Schema(
|
| 17 |
+
node_id="APG",
|
| 18 |
+
display_name="Adaptive Projected Guidance",
|
| 19 |
+
category="sampling/custom_sampling",
|
| 20 |
+
inputs=[
|
| 21 |
+
io.Model.Input("model"),
|
| 22 |
+
io.Float.Input(
|
| 23 |
+
"eta",
|
| 24 |
+
default=1.0,
|
| 25 |
+
min=-10.0,
|
| 26 |
+
max=10.0,
|
| 27 |
+
step=0.01,
|
| 28 |
+
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
|
| 29 |
+
advanced=True,
|
| 30 |
+
),
|
| 31 |
+
io.Float.Input(
|
| 32 |
+
"norm_threshold",
|
| 33 |
+
default=5.0,
|
| 34 |
+
min=0.0,
|
| 35 |
+
max=50.0,
|
| 36 |
+
step=0.1,
|
| 37 |
+
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
|
| 38 |
+
advanced=True,
|
| 39 |
+
),
|
| 40 |
+
io.Float.Input(
|
| 41 |
+
"momentum",
|
| 42 |
+
default=0.0,
|
| 43 |
+
min=-5.0,
|
| 44 |
+
max=1.0,
|
| 45 |
+
step=0.01,
|
| 46 |
+
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
|
| 47 |
+
advanced=True,
|
| 48 |
+
),
|
| 49 |
+
],
|
| 50 |
+
outputs=[io.Model.Output()],
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
@classmethod
|
| 54 |
+
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
|
| 55 |
+
running_avg = 0
|
| 56 |
+
prev_sigma = None
|
| 57 |
+
|
| 58 |
+
def pre_cfg_function(args):
|
| 59 |
+
nonlocal running_avg, prev_sigma
|
| 60 |
+
|
| 61 |
+
if len(args["conds_out"]) == 1:
|
| 62 |
+
return args["conds_out"]
|
| 63 |
+
|
| 64 |
+
cond = args["conds_out"][0]
|
| 65 |
+
uncond = args["conds_out"][1]
|
| 66 |
+
sigma = args["sigma"][0]
|
| 67 |
+
cond_scale = args["cond_scale"]
|
| 68 |
+
|
| 69 |
+
if prev_sigma is not None and sigma > prev_sigma:
|
| 70 |
+
running_avg = 0
|
| 71 |
+
prev_sigma = sigma
|
| 72 |
+
|
| 73 |
+
guidance = cond - uncond
|
| 74 |
+
|
| 75 |
+
if momentum != 0:
|
| 76 |
+
if not torch.is_tensor(running_avg):
|
| 77 |
+
running_avg = guidance
|
| 78 |
+
else:
|
| 79 |
+
running_avg = momentum * running_avg + guidance
|
| 80 |
+
guidance = running_avg
|
| 81 |
+
|
| 82 |
+
if norm_threshold > 0:
|
| 83 |
+
guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
|
| 84 |
+
scale = torch.minimum(
|
| 85 |
+
torch.ones_like(guidance_norm),
|
| 86 |
+
norm_threshold / guidance_norm
|
| 87 |
+
)
|
| 88 |
+
guidance = guidance * scale
|
| 89 |
+
|
| 90 |
+
guidance_parallel, guidance_orthogonal = project(guidance, cond)
|
| 91 |
+
modified_guidance = guidance_orthogonal + eta * guidance_parallel
|
| 92 |
+
|
| 93 |
+
modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
|
| 94 |
+
|
| 95 |
+
return [modified_cond, uncond] + args["conds_out"][2:]
|
| 96 |
+
|
| 97 |
+
m = model.clone()
|
| 98 |
+
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
| 99 |
+
return io.NodeOutput(m)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ApgExtension(ComfyExtension):
|
| 103 |
+
@override
|
| 104 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 105 |
+
return [
|
| 106 |
+
APG,
|
| 107 |
+
]
|
| 108 |
+
|
| 109 |
+
async def comfy_entrypoint() -> ApgExtension:
|
| 110 |
+
return ApgExtension()
|
ComfyUI/comfy_extras/nodes_attention_multiply.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import override
|
| 2 |
+
|
| 3 |
+
from comfy_api.latest import ComfyExtension, io
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def attention_multiply(attn, model, q, k, v, out):
|
| 7 |
+
m = model.clone()
|
| 8 |
+
sd = model.model_state_dict()
|
| 9 |
+
|
| 10 |
+
for key in sd:
|
| 11 |
+
if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)):
|
| 12 |
+
m.add_patches({key: (None,)}, 0.0, q)
|
| 13 |
+
if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)):
|
| 14 |
+
m.add_patches({key: (None,)}, 0.0, k)
|
| 15 |
+
if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)):
|
| 16 |
+
m.add_patches({key: (None,)}, 0.0, v)
|
| 17 |
+
if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)):
|
| 18 |
+
m.add_patches({key: (None,)}, 0.0, out)
|
| 19 |
+
|
| 20 |
+
return m
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class UNetSelfAttentionMultiply(io.ComfyNode):
|
| 24 |
+
@classmethod
|
| 25 |
+
def define_schema(cls) -> io.Schema:
|
| 26 |
+
return io.Schema(
|
| 27 |
+
node_id="UNetSelfAttentionMultiply",
|
| 28 |
+
category="_for_testing/attention_experiments",
|
| 29 |
+
inputs=[
|
| 30 |
+
io.Model.Input("model"),
|
| 31 |
+
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 32 |
+
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 33 |
+
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 34 |
+
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 35 |
+
],
|
| 36 |
+
outputs=[io.Model.Output()],
|
| 37 |
+
is_experimental=True,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
@classmethod
|
| 41 |
+
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
| 42 |
+
m = attention_multiply("attn1", model, q, k, v, out)
|
| 43 |
+
return io.NodeOutput(m)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class UNetCrossAttentionMultiply(io.ComfyNode):
|
| 47 |
+
@classmethod
|
| 48 |
+
def define_schema(cls) -> io.Schema:
|
| 49 |
+
return io.Schema(
|
| 50 |
+
node_id="UNetCrossAttentionMultiply",
|
| 51 |
+
category="_for_testing/attention_experiments",
|
| 52 |
+
inputs=[
|
| 53 |
+
io.Model.Input("model"),
|
| 54 |
+
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 55 |
+
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 56 |
+
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 57 |
+
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 58 |
+
],
|
| 59 |
+
outputs=[io.Model.Output()],
|
| 60 |
+
is_experimental=True,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
@classmethod
|
| 64 |
+
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
| 65 |
+
m = attention_multiply("attn2", model, q, k, v, out)
|
| 66 |
+
return io.NodeOutput(m)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class CLIPAttentionMultiply(io.ComfyNode):
|
| 70 |
+
@classmethod
|
| 71 |
+
def define_schema(cls) -> io.Schema:
|
| 72 |
+
return io.Schema(
|
| 73 |
+
node_id="CLIPAttentionMultiply",
|
| 74 |
+
search_aliases=["clip attention scale", "text encoder attention"],
|
| 75 |
+
category="_for_testing/attention_experiments",
|
| 76 |
+
inputs=[
|
| 77 |
+
io.Clip.Input("clip"),
|
| 78 |
+
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 79 |
+
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 80 |
+
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 81 |
+
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 82 |
+
],
|
| 83 |
+
outputs=[io.Clip.Output()],
|
| 84 |
+
is_experimental=True,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
@classmethod
|
| 88 |
+
def execute(cls, clip, q, k, v, out) -> io.NodeOutput:
|
| 89 |
+
m = clip.clone()
|
| 90 |
+
sd = m.patcher.model_state_dict()
|
| 91 |
+
|
| 92 |
+
for key in sd:
|
| 93 |
+
if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"):
|
| 94 |
+
m.add_patches({key: (None,)}, 0.0, q)
|
| 95 |
+
if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"):
|
| 96 |
+
m.add_patches({key: (None,)}, 0.0, k)
|
| 97 |
+
if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"):
|
| 98 |
+
m.add_patches({key: (None,)}, 0.0, v)
|
| 99 |
+
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
| 100 |
+
m.add_patches({key: (None,)}, 0.0, out)
|
| 101 |
+
return io.NodeOutput(m)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class UNetTemporalAttentionMultiply(io.ComfyNode):
|
| 105 |
+
@classmethod
|
| 106 |
+
def define_schema(cls) -> io.Schema:
|
| 107 |
+
return io.Schema(
|
| 108 |
+
node_id="UNetTemporalAttentionMultiply",
|
| 109 |
+
category="_for_testing/attention_experiments",
|
| 110 |
+
inputs=[
|
| 111 |
+
io.Model.Input("model"),
|
| 112 |
+
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 113 |
+
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 114 |
+
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 115 |
+
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 116 |
+
],
|
| 117 |
+
outputs=[io.Model.Output()],
|
| 118 |
+
is_experimental=True,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
@classmethod
|
| 122 |
+
def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput:
|
| 123 |
+
m = model.clone()
|
| 124 |
+
sd = model.model_state_dict()
|
| 125 |
+
|
| 126 |
+
for k in sd:
|
| 127 |
+
if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")):
|
| 128 |
+
if '.time_stack.' in k:
|
| 129 |
+
m.add_patches({k: (None,)}, 0.0, self_temporal)
|
| 130 |
+
else:
|
| 131 |
+
m.add_patches({k: (None,)}, 0.0, self_structural)
|
| 132 |
+
elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")):
|
| 133 |
+
if '.time_stack.' in k:
|
| 134 |
+
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
| 135 |
+
else:
|
| 136 |
+
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
| 137 |
+
return io.NodeOutput(m)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class AttentionMultiplyExtension(ComfyExtension):
|
| 141 |
+
@override
|
| 142 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 143 |
+
return [
|
| 144 |
+
UNetSelfAttentionMultiply,
|
| 145 |
+
UNetCrossAttentionMultiply,
|
| 146 |
+
CLIPAttentionMultiply,
|
| 147 |
+
UNetTemporalAttentionMultiply,
|
| 148 |
+
]
|
| 149 |
+
|
| 150 |
+
async def comfy_entrypoint() -> AttentionMultiplyExtension:
|
| 151 |
+
return AttentionMultiplyExtension()
|
ComfyUI/comfy_extras/nodes_audio.py
ADDED
|
@@ -0,0 +1,794 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import av
|
| 4 |
+
import torchaudio
|
| 5 |
+
import torch
|
| 6 |
+
import comfy.model_management
|
| 7 |
+
import folder_paths
|
| 8 |
+
import os
|
| 9 |
+
import hashlib
|
| 10 |
+
import node_helpers
|
| 11 |
+
import logging
|
| 12 |
+
from typing_extensions import override
|
| 13 |
+
from comfy_api.latest import ComfyExtension, IO, UI
|
| 14 |
+
|
| 15 |
+
class EmptyLatentAudio(IO.ComfyNode):
|
| 16 |
+
@classmethod
|
| 17 |
+
def define_schema(cls):
|
| 18 |
+
return IO.Schema(
|
| 19 |
+
node_id="EmptyLatentAudio",
|
| 20 |
+
display_name="Empty Latent Audio",
|
| 21 |
+
category="latent/audio",
|
| 22 |
+
essentials_category="Audio",
|
| 23 |
+
inputs=[
|
| 24 |
+
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
| 25 |
+
IO.Int.Input(
|
| 26 |
+
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch.",
|
| 27 |
+
),
|
| 28 |
+
],
|
| 29 |
+
outputs=[IO.Latent.Output()],
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
| 34 |
+
length = round((seconds * 44100 / 2048) / 2) * 2
|
| 35 |
+
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
|
| 36 |
+
return IO.NodeOutput({"samples":latent, "type": "audio"})
|
| 37 |
+
|
| 38 |
+
generate = execute # TODO: remove
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ConditioningStableAudio(IO.ComfyNode):
|
| 42 |
+
@classmethod
|
| 43 |
+
def define_schema(cls):
|
| 44 |
+
return IO.Schema(
|
| 45 |
+
node_id="ConditioningStableAudio",
|
| 46 |
+
category="conditioning",
|
| 47 |
+
inputs=[
|
| 48 |
+
IO.Conditioning.Input("positive"),
|
| 49 |
+
IO.Conditioning.Input("negative"),
|
| 50 |
+
IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
|
| 51 |
+
IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
|
| 52 |
+
],
|
| 53 |
+
outputs=[
|
| 54 |
+
IO.Conditioning.Output(display_name="positive"),
|
| 55 |
+
IO.Conditioning.Output(display_name="negative"),
|
| 56 |
+
],
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
@classmethod
|
| 60 |
+
def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
|
| 61 |
+
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
| 62 |
+
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
| 63 |
+
return IO.NodeOutput(positive, negative)
|
| 64 |
+
|
| 65 |
+
append = execute # TODO: remove
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class VAEEncodeAudio(IO.ComfyNode):
|
| 69 |
+
@classmethod
|
| 70 |
+
def define_schema(cls):
|
| 71 |
+
return IO.Schema(
|
| 72 |
+
node_id="VAEEncodeAudio",
|
| 73 |
+
search_aliases=["audio to latent"],
|
| 74 |
+
display_name="VAE Encode Audio",
|
| 75 |
+
category="latent/audio",
|
| 76 |
+
inputs=[
|
| 77 |
+
IO.Audio.Input("audio"),
|
| 78 |
+
IO.Vae.Input("vae"),
|
| 79 |
+
],
|
| 80 |
+
outputs=[IO.Latent.Output()],
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
@classmethod
|
| 84 |
+
def execute(cls, vae, audio) -> IO.NodeOutput:
|
| 85 |
+
sample_rate = audio["sample_rate"]
|
| 86 |
+
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
| 87 |
+
if vae_sample_rate != sample_rate:
|
| 88 |
+
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate)
|
| 89 |
+
else:
|
| 90 |
+
waveform = audio["waveform"]
|
| 91 |
+
|
| 92 |
+
t = vae.encode(waveform.movedim(1, -1))
|
| 93 |
+
return IO.NodeOutput({"samples": t})
|
| 94 |
+
|
| 95 |
+
encode = execute # TODO: remove
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
| 99 |
+
if tile is not None:
|
| 100 |
+
audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1)
|
| 101 |
+
else:
|
| 102 |
+
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
| 103 |
+
|
| 104 |
+
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
| 105 |
+
std[std < 1.0] = 1.0
|
| 106 |
+
audio /= std
|
| 107 |
+
vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
|
| 108 |
+
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class VAEDecodeAudio(IO.ComfyNode):
|
| 112 |
+
@classmethod
|
| 113 |
+
def define_schema(cls):
|
| 114 |
+
return IO.Schema(
|
| 115 |
+
node_id="VAEDecodeAudio",
|
| 116 |
+
search_aliases=["latent to audio"],
|
| 117 |
+
display_name="VAE Decode Audio",
|
| 118 |
+
category="latent/audio",
|
| 119 |
+
inputs=[
|
| 120 |
+
IO.Latent.Input("samples"),
|
| 121 |
+
IO.Vae.Input("vae"),
|
| 122 |
+
],
|
| 123 |
+
outputs=[IO.Audio.Output()],
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
@classmethod
|
| 127 |
+
def execute(cls, vae, samples) -> IO.NodeOutput:
|
| 128 |
+
return IO.NodeOutput(vae_decode_audio(vae, samples))
|
| 129 |
+
|
| 130 |
+
decode = execute # TODO: remove
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class VAEDecodeAudioTiled(IO.ComfyNode):
|
| 134 |
+
@classmethod
|
| 135 |
+
def define_schema(cls):
|
| 136 |
+
return IO.Schema(
|
| 137 |
+
node_id="VAEDecodeAudioTiled",
|
| 138 |
+
search_aliases=["latent to audio"],
|
| 139 |
+
display_name="VAE Decode Audio (Tiled)",
|
| 140 |
+
category="latent/audio",
|
| 141 |
+
inputs=[
|
| 142 |
+
IO.Latent.Input("samples"),
|
| 143 |
+
IO.Vae.Input("vae"),
|
| 144 |
+
IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
|
| 145 |
+
IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
|
| 146 |
+
],
|
| 147 |
+
outputs=[IO.Audio.Output()],
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
@classmethod
|
| 151 |
+
def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
|
| 152 |
+
return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class SaveAudio(IO.ComfyNode):
|
| 156 |
+
@classmethod
|
| 157 |
+
def define_schema(cls):
|
| 158 |
+
return IO.Schema(
|
| 159 |
+
node_id="SaveAudio",
|
| 160 |
+
search_aliases=["export flac"],
|
| 161 |
+
display_name="Save Audio (FLAC)",
|
| 162 |
+
category="audio",
|
| 163 |
+
essentials_category="Audio",
|
| 164 |
+
inputs=[
|
| 165 |
+
IO.Audio.Input("audio"),
|
| 166 |
+
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
| 167 |
+
],
|
| 168 |
+
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
| 169 |
+
is_output_node=True,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
@classmethod
|
| 173 |
+
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
|
| 174 |
+
return IO.NodeOutput(
|
| 175 |
+
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
save_flac = execute # TODO: remove
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class SaveAudioMP3(IO.ComfyNode):
|
| 182 |
+
@classmethod
|
| 183 |
+
def define_schema(cls):
|
| 184 |
+
return IO.Schema(
|
| 185 |
+
node_id="SaveAudioMP3",
|
| 186 |
+
search_aliases=["export mp3"],
|
| 187 |
+
display_name="Save Audio (MP3)",
|
| 188 |
+
category="audio",
|
| 189 |
+
essentials_category="Audio",
|
| 190 |
+
inputs=[
|
| 191 |
+
IO.Audio.Input("audio"),
|
| 192 |
+
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
| 193 |
+
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
|
| 194 |
+
],
|
| 195 |
+
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
| 196 |
+
is_output_node=True,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
@classmethod
|
| 200 |
+
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
|
| 201 |
+
return IO.NodeOutput(
|
| 202 |
+
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
| 203 |
+
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
| 204 |
+
)
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
save_mp3 = execute # TODO: remove
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class SaveAudioOpus(IO.ComfyNode):
|
| 211 |
+
@classmethod
|
| 212 |
+
def define_schema(cls):
|
| 213 |
+
return IO.Schema(
|
| 214 |
+
node_id="SaveAudioOpus",
|
| 215 |
+
search_aliases=["export opus"],
|
| 216 |
+
display_name="Save Audio (Opus)",
|
| 217 |
+
category="audio",
|
| 218 |
+
inputs=[
|
| 219 |
+
IO.Audio.Input("audio"),
|
| 220 |
+
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
| 221 |
+
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
|
| 222 |
+
],
|
| 223 |
+
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
| 224 |
+
is_output_node=True,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
@classmethod
|
| 228 |
+
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
|
| 229 |
+
return IO.NodeOutput(
|
| 230 |
+
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
| 231 |
+
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
| 232 |
+
)
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
save_opus = execute # TODO: remove
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class PreviewAudio(IO.ComfyNode):
|
| 239 |
+
@classmethod
|
| 240 |
+
def define_schema(cls):
|
| 241 |
+
return IO.Schema(
|
| 242 |
+
node_id="PreviewAudio",
|
| 243 |
+
search_aliases=["play audio"],
|
| 244 |
+
display_name="Preview Audio",
|
| 245 |
+
category="audio",
|
| 246 |
+
inputs=[
|
| 247 |
+
IO.Audio.Input("audio"),
|
| 248 |
+
],
|
| 249 |
+
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
| 250 |
+
is_output_node=True,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
@classmethod
|
| 254 |
+
def execute(cls, audio) -> IO.NodeOutput:
|
| 255 |
+
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
|
| 256 |
+
|
| 257 |
+
save_flac = execute # TODO: remove
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
| 261 |
+
"""Convert audio to float 32 bits PCM format."""
|
| 262 |
+
if wav.dtype.is_floating_point:
|
| 263 |
+
return wav
|
| 264 |
+
elif wav.dtype == torch.int16:
|
| 265 |
+
return wav.float() / (2 ** 15)
|
| 266 |
+
elif wav.dtype == torch.int32:
|
| 267 |
+
return wav.float() / (2 ** 31)
|
| 268 |
+
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
| 269 |
+
|
| 270 |
+
def load(filepath: str) -> tuple[torch.Tensor, int]:
|
| 271 |
+
with av.open(filepath) as af:
|
| 272 |
+
if not af.streams.audio:
|
| 273 |
+
raise ValueError("No audio stream found in the file.")
|
| 274 |
+
|
| 275 |
+
stream = af.streams.audio[0]
|
| 276 |
+
sr = stream.codec_context.sample_rate
|
| 277 |
+
n_channels = stream.channels
|
| 278 |
+
|
| 279 |
+
frames = []
|
| 280 |
+
length = 0
|
| 281 |
+
for frame in af.decode(streams=stream.index):
|
| 282 |
+
buf = torch.from_numpy(frame.to_ndarray())
|
| 283 |
+
if buf.shape[0] != n_channels:
|
| 284 |
+
buf = buf.view(-1, n_channels).t()
|
| 285 |
+
|
| 286 |
+
frames.append(buf)
|
| 287 |
+
length += buf.shape[1]
|
| 288 |
+
|
| 289 |
+
if not frames:
|
| 290 |
+
raise ValueError("No audio frames decoded.")
|
| 291 |
+
|
| 292 |
+
wav = torch.cat(frames, dim=1)
|
| 293 |
+
wav = f32_pcm(wav)
|
| 294 |
+
return wav, sr
|
| 295 |
+
|
| 296 |
+
class LoadAudio(IO.ComfyNode):
|
| 297 |
+
@classmethod
|
| 298 |
+
def define_schema(cls):
|
| 299 |
+
input_dir = folder_paths.get_input_directory()
|
| 300 |
+
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
| 301 |
+
return IO.Schema(
|
| 302 |
+
node_id="LoadAudio",
|
| 303 |
+
search_aliases=["import audio", "open audio", "audio file"],
|
| 304 |
+
display_name="Load Audio",
|
| 305 |
+
category="audio",
|
| 306 |
+
essentials_category="Audio",
|
| 307 |
+
inputs=[
|
| 308 |
+
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
|
| 309 |
+
],
|
| 310 |
+
outputs=[IO.Audio.Output()],
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
@classmethod
|
| 314 |
+
def execute(cls, audio) -> IO.NodeOutput:
|
| 315 |
+
audio_path = folder_paths.get_annotated_filepath(audio)
|
| 316 |
+
waveform, sample_rate = load(audio_path)
|
| 317 |
+
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
| 318 |
+
return IO.NodeOutput(audio)
|
| 319 |
+
|
| 320 |
+
@classmethod
|
| 321 |
+
def fingerprint_inputs(cls, audio):
|
| 322 |
+
image_path = folder_paths.get_annotated_filepath(audio)
|
| 323 |
+
m = hashlib.sha256()
|
| 324 |
+
with open(image_path, 'rb') as f:
|
| 325 |
+
m.update(f.read())
|
| 326 |
+
return m.digest().hex()
|
| 327 |
+
|
| 328 |
+
@classmethod
|
| 329 |
+
def validate_inputs(cls, audio):
|
| 330 |
+
if not folder_paths.exists_annotated_filepath(audio):
|
| 331 |
+
return "Invalid audio file: {}".format(audio)
|
| 332 |
+
return True
|
| 333 |
+
|
| 334 |
+
load = execute # TODO: remove
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class RecordAudio(IO.ComfyNode):
|
| 338 |
+
@classmethod
|
| 339 |
+
def define_schema(cls):
|
| 340 |
+
return IO.Schema(
|
| 341 |
+
node_id="RecordAudio",
|
| 342 |
+
search_aliases=["microphone input", "audio capture", "voice input"],
|
| 343 |
+
display_name="Record Audio",
|
| 344 |
+
category="audio",
|
| 345 |
+
inputs=[
|
| 346 |
+
IO.Custom("AUDIO_RECORD").Input("audio"),
|
| 347 |
+
],
|
| 348 |
+
outputs=[IO.Audio.Output()],
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
@classmethod
|
| 352 |
+
def execute(cls, audio) -> IO.NodeOutput:
|
| 353 |
+
audio_path = folder_paths.get_annotated_filepath(audio)
|
| 354 |
+
|
| 355 |
+
waveform, sample_rate = load(audio_path)
|
| 356 |
+
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
| 357 |
+
return IO.NodeOutput(audio)
|
| 358 |
+
|
| 359 |
+
load = execute # TODO: remove
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class TrimAudioDuration(IO.ComfyNode):
|
| 363 |
+
@classmethod
|
| 364 |
+
def define_schema(cls):
|
| 365 |
+
return IO.Schema(
|
| 366 |
+
node_id="TrimAudioDuration",
|
| 367 |
+
search_aliases=["cut audio", "audio clip", "shorten audio"],
|
| 368 |
+
display_name="Trim Audio Duration",
|
| 369 |
+
description="Trim audio tensor into chosen time range.",
|
| 370 |
+
category="audio",
|
| 371 |
+
inputs=[
|
| 372 |
+
IO.Audio.Input("audio"),
|
| 373 |
+
IO.Float.Input(
|
| 374 |
+
"start_index",
|
| 375 |
+
default=0.0,
|
| 376 |
+
min=-0xffffffffffffffff,
|
| 377 |
+
max=0xffffffffffffffff,
|
| 378 |
+
step=0.01,
|
| 379 |
+
tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
|
| 380 |
+
),
|
| 381 |
+
IO.Float.Input(
|
| 382 |
+
"duration",
|
| 383 |
+
default=60.0,
|
| 384 |
+
min=0.0,
|
| 385 |
+
step=0.01,
|
| 386 |
+
tooltip="Duration in seconds",
|
| 387 |
+
),
|
| 388 |
+
],
|
| 389 |
+
outputs=[IO.Audio.Output()],
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
@classmethod
|
| 393 |
+
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
|
| 394 |
+
waveform = audio["waveform"]
|
| 395 |
+
sample_rate = audio["sample_rate"]
|
| 396 |
+
audio_length = waveform.shape[-1]
|
| 397 |
+
|
| 398 |
+
if start_index < 0:
|
| 399 |
+
start_frame = audio_length + int(round(start_index * sample_rate))
|
| 400 |
+
else:
|
| 401 |
+
start_frame = int(round(start_index * sample_rate))
|
| 402 |
+
start_frame = max(0, min(start_frame, audio_length - 1))
|
| 403 |
+
|
| 404 |
+
end_frame = start_frame + int(round(duration * sample_rate))
|
| 405 |
+
end_frame = max(0, min(end_frame, audio_length))
|
| 406 |
+
|
| 407 |
+
if start_frame >= end_frame:
|
| 408 |
+
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
|
| 409 |
+
|
| 410 |
+
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
|
| 411 |
+
|
| 412 |
+
trim = execute # TODO: remove
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class SplitAudioChannels(IO.ComfyNode):
|
| 416 |
+
@classmethod
|
| 417 |
+
def define_schema(cls):
|
| 418 |
+
return IO.Schema(
|
| 419 |
+
node_id="SplitAudioChannels",
|
| 420 |
+
search_aliases=["stereo to mono"],
|
| 421 |
+
display_name="Split Audio Channels",
|
| 422 |
+
description="Separates the audio into left and right channels.",
|
| 423 |
+
category="audio",
|
| 424 |
+
inputs=[
|
| 425 |
+
IO.Audio.Input("audio"),
|
| 426 |
+
],
|
| 427 |
+
outputs=[
|
| 428 |
+
IO.Audio.Output(display_name="left"),
|
| 429 |
+
IO.Audio.Output(display_name="right"),
|
| 430 |
+
],
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
@classmethod
|
| 434 |
+
def execute(cls, audio) -> IO.NodeOutput:
|
| 435 |
+
waveform = audio["waveform"]
|
| 436 |
+
sample_rate = audio["sample_rate"]
|
| 437 |
+
|
| 438 |
+
if waveform.shape[1] != 2:
|
| 439 |
+
raise ValueError("AudioSplit: Input audio has only one channel.")
|
| 440 |
+
|
| 441 |
+
left_channel = waveform[..., 0:1, :]
|
| 442 |
+
right_channel = waveform[..., 1:2, :]
|
| 443 |
+
|
| 444 |
+
return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
|
| 445 |
+
|
| 446 |
+
separate = execute # TODO: remove
|
| 447 |
+
|
| 448 |
+
class JoinAudioChannels(IO.ComfyNode):
|
| 449 |
+
@classmethod
|
| 450 |
+
def define_schema(cls):
|
| 451 |
+
return IO.Schema(
|
| 452 |
+
node_id="JoinAudioChannels",
|
| 453 |
+
display_name="Join Audio Channels",
|
| 454 |
+
description="Joins left and right mono audio channels into a stereo audio.",
|
| 455 |
+
category="audio",
|
| 456 |
+
inputs=[
|
| 457 |
+
IO.Audio.Input("audio_left"),
|
| 458 |
+
IO.Audio.Input("audio_right"),
|
| 459 |
+
],
|
| 460 |
+
outputs=[
|
| 461 |
+
IO.Audio.Output(display_name="audio"),
|
| 462 |
+
],
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
@classmethod
|
| 466 |
+
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
|
| 467 |
+
waveform_left = audio_left["waveform"]
|
| 468 |
+
sample_rate_left = audio_left["sample_rate"]
|
| 469 |
+
waveform_right = audio_right["waveform"]
|
| 470 |
+
sample_rate_right = audio_right["sample_rate"]
|
| 471 |
+
|
| 472 |
+
if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1:
|
| 473 |
+
raise ValueError("AudioJoin: Both input audios must be mono.")
|
| 474 |
+
|
| 475 |
+
# Handle different sample rates by resampling to the higher rate
|
| 476 |
+
waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates(
|
| 477 |
+
waveform_left, sample_rate_left, waveform_right, sample_rate_right
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
# Handle different lengths by trimming to the shorter length
|
| 481 |
+
length_left = waveform_left.shape[-1]
|
| 482 |
+
length_right = waveform_right.shape[-1]
|
| 483 |
+
|
| 484 |
+
if length_left != length_right:
|
| 485 |
+
min_length = min(length_left, length_right)
|
| 486 |
+
if length_left > min_length:
|
| 487 |
+
logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.")
|
| 488 |
+
waveform_left = waveform_left[..., :min_length]
|
| 489 |
+
if length_right > min_length:
|
| 490 |
+
logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.")
|
| 491 |
+
waveform_right = waveform_right[..., :min_length]
|
| 492 |
+
|
| 493 |
+
# Join the channels into stereo
|
| 494 |
+
left_channel = waveform_left[..., 0:1, :]
|
| 495 |
+
right_channel = waveform_right[..., 0:1, :]
|
| 496 |
+
stereo_waveform = torch.cat([left_channel, right_channel], dim=1)
|
| 497 |
+
|
| 498 |
+
return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate})
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
|
| 502 |
+
if sample_rate_1 != sample_rate_2:
|
| 503 |
+
if sample_rate_1 > sample_rate_2:
|
| 504 |
+
waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1)
|
| 505 |
+
output_sample_rate = sample_rate_1
|
| 506 |
+
logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
|
| 507 |
+
else:
|
| 508 |
+
waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2)
|
| 509 |
+
output_sample_rate = sample_rate_2
|
| 510 |
+
logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
|
| 511 |
+
else:
|
| 512 |
+
output_sample_rate = sample_rate_1
|
| 513 |
+
return waveform_1, waveform_2, output_sample_rate
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
class AudioConcat(IO.ComfyNode):
|
| 517 |
+
@classmethod
|
| 518 |
+
def define_schema(cls):
|
| 519 |
+
return IO.Schema(
|
| 520 |
+
node_id="AudioConcat",
|
| 521 |
+
search_aliases=["join audio", "combine audio", "append audio"],
|
| 522 |
+
display_name="Audio Concat",
|
| 523 |
+
description="Concatenates the audio1 to audio2 in the specified direction.",
|
| 524 |
+
category="audio",
|
| 525 |
+
inputs=[
|
| 526 |
+
IO.Audio.Input("audio1"),
|
| 527 |
+
IO.Audio.Input("audio2"),
|
| 528 |
+
IO.Combo.Input(
|
| 529 |
+
"direction",
|
| 530 |
+
options=['after', 'before'],
|
| 531 |
+
default="after",
|
| 532 |
+
tooltip="Whether to append audio2 after or before audio1.",
|
| 533 |
+
)
|
| 534 |
+
],
|
| 535 |
+
outputs=[IO.Audio.Output()],
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
@classmethod
|
| 539 |
+
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
|
| 540 |
+
waveform_1 = audio1["waveform"]
|
| 541 |
+
waveform_2 = audio2["waveform"]
|
| 542 |
+
sample_rate_1 = audio1["sample_rate"]
|
| 543 |
+
sample_rate_2 = audio2["sample_rate"]
|
| 544 |
+
|
| 545 |
+
if waveform_1.shape[1] == 1:
|
| 546 |
+
waveform_1 = waveform_1.repeat(1, 2, 1)
|
| 547 |
+
logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
|
| 548 |
+
if waveform_2.shape[1] == 1:
|
| 549 |
+
waveform_2 = waveform_2.repeat(1, 2, 1)
|
| 550 |
+
logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")
|
| 551 |
+
|
| 552 |
+
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
|
| 553 |
+
|
| 554 |
+
if direction == 'after':
|
| 555 |
+
concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
|
| 556 |
+
elif direction == 'before':
|
| 557 |
+
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
|
| 558 |
+
|
| 559 |
+
return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
|
| 560 |
+
|
| 561 |
+
concat = execute # TODO: remove
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class AudioMerge(IO.ComfyNode):
|
| 565 |
+
@classmethod
|
| 566 |
+
def define_schema(cls):
|
| 567 |
+
return IO.Schema(
|
| 568 |
+
node_id="AudioMerge",
|
| 569 |
+
search_aliases=["mix audio", "overlay audio", "layer audio"],
|
| 570 |
+
display_name="Audio Merge",
|
| 571 |
+
description="Combine two audio tracks by overlaying their waveforms.",
|
| 572 |
+
category="audio",
|
| 573 |
+
inputs=[
|
| 574 |
+
IO.Audio.Input("audio1"),
|
| 575 |
+
IO.Audio.Input("audio2"),
|
| 576 |
+
IO.Combo.Input(
|
| 577 |
+
"merge_method",
|
| 578 |
+
options=["add", "mean", "subtract", "multiply"],
|
| 579 |
+
tooltip="The method used to combine the audio waveforms.",
|
| 580 |
+
)
|
| 581 |
+
],
|
| 582 |
+
outputs=[IO.Audio.Output()],
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
@classmethod
|
| 586 |
+
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
|
| 587 |
+
waveform_1 = audio1["waveform"]
|
| 588 |
+
waveform_2 = audio2["waveform"]
|
| 589 |
+
sample_rate_1 = audio1["sample_rate"]
|
| 590 |
+
sample_rate_2 = audio2["sample_rate"]
|
| 591 |
+
|
| 592 |
+
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
|
| 593 |
+
|
| 594 |
+
length_1 = waveform_1.shape[-1]
|
| 595 |
+
length_2 = waveform_2.shape[-1]
|
| 596 |
+
|
| 597 |
+
if length_2 > length_1:
|
| 598 |
+
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
|
| 599 |
+
waveform_2 = waveform_2[..., :length_1]
|
| 600 |
+
elif length_2 < length_1:
|
| 601 |
+
logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
|
| 602 |
+
pad_shape = list(waveform_2.shape)
|
| 603 |
+
pad_shape[-1] = length_1 - length_2
|
| 604 |
+
pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device)
|
| 605 |
+
waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1)
|
| 606 |
+
|
| 607 |
+
if merge_method == "add":
|
| 608 |
+
waveform = waveform_1 + waveform_2
|
| 609 |
+
elif merge_method == "subtract":
|
| 610 |
+
waveform = waveform_1 - waveform_2
|
| 611 |
+
elif merge_method == "multiply":
|
| 612 |
+
waveform = waveform_1 * waveform_2
|
| 613 |
+
elif merge_method == "mean":
|
| 614 |
+
waveform = (waveform_1 + waveform_2) / 2
|
| 615 |
+
|
| 616 |
+
max_val = waveform.abs().max()
|
| 617 |
+
if max_val > 1.0:
|
| 618 |
+
waveform = waveform / max_val
|
| 619 |
+
|
| 620 |
+
return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
|
| 621 |
+
|
| 622 |
+
merge = execute # TODO: remove
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
class AudioAdjustVolume(IO.ComfyNode):
|
| 626 |
+
@classmethod
|
| 627 |
+
def define_schema(cls):
|
| 628 |
+
return IO.Schema(
|
| 629 |
+
node_id="AudioAdjustVolume",
|
| 630 |
+
search_aliases=["audio gain", "loudness", "audio level"],
|
| 631 |
+
display_name="Audio Adjust Volume",
|
| 632 |
+
category="audio",
|
| 633 |
+
inputs=[
|
| 634 |
+
IO.Audio.Input("audio"),
|
| 635 |
+
IO.Int.Input(
|
| 636 |
+
"volume",
|
| 637 |
+
default=1,
|
| 638 |
+
min=-100,
|
| 639 |
+
max=100,
|
| 640 |
+
tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
|
| 641 |
+
)
|
| 642 |
+
],
|
| 643 |
+
outputs=[IO.Audio.Output()],
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
@classmethod
|
| 647 |
+
def execute(cls, audio, volume) -> IO.NodeOutput:
|
| 648 |
+
if volume == 0:
|
| 649 |
+
return IO.NodeOutput(audio)
|
| 650 |
+
waveform = audio["waveform"]
|
| 651 |
+
sample_rate = audio["sample_rate"]
|
| 652 |
+
|
| 653 |
+
gain = 10 ** (volume / 20)
|
| 654 |
+
waveform = waveform * gain
|
| 655 |
+
|
| 656 |
+
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
|
| 657 |
+
|
| 658 |
+
adjust_volume = execute # TODO: remove
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
class EmptyAudio(IO.ComfyNode):
|
| 662 |
+
@classmethod
|
| 663 |
+
def define_schema(cls):
|
| 664 |
+
return IO.Schema(
|
| 665 |
+
node_id="EmptyAudio",
|
| 666 |
+
search_aliases=["blank audio"],
|
| 667 |
+
display_name="Empty Audio",
|
| 668 |
+
category="audio",
|
| 669 |
+
inputs=[
|
| 670 |
+
IO.Float.Input(
|
| 671 |
+
"duration",
|
| 672 |
+
default=60.0,
|
| 673 |
+
min=0.0,
|
| 674 |
+
max=0xffffffffffffffff,
|
| 675 |
+
step=0.01,
|
| 676 |
+
tooltip="Duration of the empty audio clip in seconds",
|
| 677 |
+
),
|
| 678 |
+
IO.Int.Input(
|
| 679 |
+
"sample_rate",
|
| 680 |
+
default=44100,
|
| 681 |
+
tooltip="Sample rate of the empty audio clip.",
|
| 682 |
+
min=1,
|
| 683 |
+
max=192000,
|
| 684 |
+
advanced=True,
|
| 685 |
+
),
|
| 686 |
+
IO.Int.Input(
|
| 687 |
+
"channels",
|
| 688 |
+
default=2,
|
| 689 |
+
min=1,
|
| 690 |
+
max=2,
|
| 691 |
+
tooltip="Number of audio channels (1 for mono, 2 for stereo).",
|
| 692 |
+
advanced=True,
|
| 693 |
+
),
|
| 694 |
+
],
|
| 695 |
+
outputs=[IO.Audio.Output()],
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
@classmethod
|
| 699 |
+
def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
|
| 700 |
+
num_samples = int(round(duration * sample_rate))
|
| 701 |
+
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
|
| 702 |
+
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
|
| 703 |
+
|
| 704 |
+
create_empty_audio = execute # TODO: remove
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
class AudioEqualizer3Band(IO.ComfyNode):
|
| 708 |
+
@classmethod
|
| 709 |
+
def define_schema(cls):
|
| 710 |
+
return IO.Schema(
|
| 711 |
+
node_id="AudioEqualizer3Band",
|
| 712 |
+
search_aliases=["eq", "bass boost", "treble boost", "equalizer"],
|
| 713 |
+
display_name="Audio Equalizer (3-Band)",
|
| 714 |
+
category="audio",
|
| 715 |
+
is_experimental=True,
|
| 716 |
+
inputs=[
|
| 717 |
+
IO.Audio.Input("audio"),
|
| 718 |
+
IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"),
|
| 719 |
+
IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"),
|
| 720 |
+
IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"),
|
| 721 |
+
IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"),
|
| 722 |
+
IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"),
|
| 723 |
+
IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"),
|
| 724 |
+
IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"),
|
| 725 |
+
],
|
| 726 |
+
outputs=[IO.Audio.Output()],
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
@classmethod
|
| 730 |
+
def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
|
| 731 |
+
waveform = audio["waveform"]
|
| 732 |
+
sample_rate = audio["sample_rate"]
|
| 733 |
+
eq_waveform = waveform.clone()
|
| 734 |
+
|
| 735 |
+
# 1. Apply Low Shelf (Bass)
|
| 736 |
+
if low_gain_dB != 0:
|
| 737 |
+
eq_waveform = torchaudio.functional.bass_biquad(
|
| 738 |
+
eq_waveform,
|
| 739 |
+
sample_rate,
|
| 740 |
+
gain=low_gain_dB,
|
| 741 |
+
central_freq=float(low_freq),
|
| 742 |
+
Q=0.707
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
# 2. Apply Peaking EQ (Mids)
|
| 746 |
+
if mid_gain_dB != 0:
|
| 747 |
+
eq_waveform = torchaudio.functional.equalizer_biquad(
|
| 748 |
+
eq_waveform,
|
| 749 |
+
sample_rate,
|
| 750 |
+
center_freq=float(mid_freq),
|
| 751 |
+
gain=mid_gain_dB,
|
| 752 |
+
Q=mid_q
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
# 3. Apply High Shelf (Treble)
|
| 756 |
+
if high_gain_dB != 0:
|
| 757 |
+
eq_waveform = torchaudio.functional.treble_biquad(
|
| 758 |
+
eq_waveform,
|
| 759 |
+
sample_rate,
|
| 760 |
+
gain=high_gain_dB,
|
| 761 |
+
central_freq=float(high_freq),
|
| 762 |
+
Q=0.707
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate})
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
class AudioExtension(ComfyExtension):
|
| 769 |
+
@override
|
| 770 |
+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
| 771 |
+
return [
|
| 772 |
+
EmptyLatentAudio,
|
| 773 |
+
VAEEncodeAudio,
|
| 774 |
+
VAEDecodeAudio,
|
| 775 |
+
VAEDecodeAudioTiled,
|
| 776 |
+
SaveAudio,
|
| 777 |
+
SaveAudioMP3,
|
| 778 |
+
SaveAudioOpus,
|
| 779 |
+
LoadAudio,
|
| 780 |
+
PreviewAudio,
|
| 781 |
+
ConditioningStableAudio,
|
| 782 |
+
RecordAudio,
|
| 783 |
+
TrimAudioDuration,
|
| 784 |
+
SplitAudioChannels,
|
| 785 |
+
JoinAudioChannels,
|
| 786 |
+
AudioConcat,
|
| 787 |
+
AudioMerge,
|
| 788 |
+
AudioAdjustVolume,
|
| 789 |
+
EmptyAudio,
|
| 790 |
+
AudioEqualizer3Band,
|
| 791 |
+
]
|
| 792 |
+
|
| 793 |
+
async def comfy_entrypoint() -> AudioExtension:
|
| 794 |
+
return AudioExtension()
|
ComfyUI/comfy_extras/nodes_audio_encoder.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import folder_paths
|
| 2 |
+
import comfy.audio_encoders.audio_encoders
|
| 3 |
+
import comfy.utils
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AudioEncoderLoader(io.ComfyNode):
|
| 9 |
+
@classmethod
|
| 10 |
+
def define_schema(cls) -> io.Schema:
|
| 11 |
+
return io.Schema(
|
| 12 |
+
node_id="AudioEncoderLoader",
|
| 13 |
+
category="loaders",
|
| 14 |
+
inputs=[
|
| 15 |
+
io.Combo.Input(
|
| 16 |
+
"audio_encoder_name",
|
| 17 |
+
options=folder_paths.get_filename_list("audio_encoders"),
|
| 18 |
+
),
|
| 19 |
+
],
|
| 20 |
+
outputs=[io.AudioEncoder.Output()],
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
@classmethod
|
| 24 |
+
def execute(cls, audio_encoder_name) -> io.NodeOutput:
|
| 25 |
+
audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
|
| 26 |
+
sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
|
| 27 |
+
audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
|
| 28 |
+
if audio_encoder is None:
|
| 29 |
+
raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
|
| 30 |
+
return io.NodeOutput(audio_encoder)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class AudioEncoderEncode(io.ComfyNode):
|
| 34 |
+
@classmethod
|
| 35 |
+
def define_schema(cls) -> io.Schema:
|
| 36 |
+
return io.Schema(
|
| 37 |
+
node_id="AudioEncoderEncode",
|
| 38 |
+
category="conditioning",
|
| 39 |
+
inputs=[
|
| 40 |
+
io.AudioEncoder.Input("audio_encoder"),
|
| 41 |
+
io.Audio.Input("audio"),
|
| 42 |
+
],
|
| 43 |
+
outputs=[io.AudioEncoderOutput.Output()],
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
@classmethod
|
| 47 |
+
def execute(cls, audio_encoder, audio) -> io.NodeOutput:
|
| 48 |
+
output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
|
| 49 |
+
return io.NodeOutput(output)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class AudioEncoder(ComfyExtension):
|
| 53 |
+
@override
|
| 54 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 55 |
+
return [
|
| 56 |
+
AudioEncoderLoader,
|
| 57 |
+
AudioEncoderEncode,
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
async def comfy_entrypoint() -> AudioEncoder:
|
| 62 |
+
return AudioEncoder()
|
ComfyUI/comfy_extras/nodes_camera_trajectory.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nodes
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from typing_extensions import override
|
| 6 |
+
import comfy.model_management
|
| 7 |
+
|
| 8 |
+
from comfy_api.latest import ComfyExtension, io
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
CAMERA_DICT = {
|
| 12 |
+
"base_T_norm": 1.5,
|
| 13 |
+
"base_angle": np.pi/3,
|
| 14 |
+
"Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
|
| 15 |
+
"Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]},
|
| 16 |
+
"Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]},
|
| 17 |
+
"Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]},
|
| 18 |
+
"Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]},
|
| 19 |
+
"Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
|
| 20 |
+
"Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
|
| 21 |
+
"Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
|
| 22 |
+
"ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
|
| 27 |
+
|
| 28 |
+
def get_relative_pose(cam_params):
|
| 29 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 30 |
+
"""
|
| 31 |
+
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
| 32 |
+
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
| 33 |
+
cam_to_origin = 0
|
| 34 |
+
target_cam_c2w = np.array([
|
| 35 |
+
[1, 0, 0, 0],
|
| 36 |
+
[0, 1, 0, -cam_to_origin],
|
| 37 |
+
[0, 0, 1, 0],
|
| 38 |
+
[0, 0, 0, 1]
|
| 39 |
+
])
|
| 40 |
+
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
| 41 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
| 42 |
+
ret_poses = np.array(ret_poses, dtype=np.float32)
|
| 43 |
+
return ret_poses
|
| 44 |
+
|
| 45 |
+
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 46 |
+
"""
|
| 47 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
| 48 |
+
|
| 49 |
+
sample_wh_ratio = width / height
|
| 50 |
+
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
| 51 |
+
|
| 52 |
+
if pose_wh_ratio > sample_wh_ratio:
|
| 53 |
+
resized_ori_w = height * pose_wh_ratio
|
| 54 |
+
for cam_param in cam_params:
|
| 55 |
+
cam_param.fx = resized_ori_w * cam_param.fx / width
|
| 56 |
+
else:
|
| 57 |
+
resized_ori_h = width / pose_wh_ratio
|
| 58 |
+
for cam_param in cam_params:
|
| 59 |
+
cam_param.fy = resized_ori_h * cam_param.fy / height
|
| 60 |
+
|
| 61 |
+
intrinsic = np.asarray([[cam_param.fx * width,
|
| 62 |
+
cam_param.fy * height,
|
| 63 |
+
cam_param.cx * width,
|
| 64 |
+
cam_param.cy * height]
|
| 65 |
+
for cam_param in cam_params], dtype=np.float32)
|
| 66 |
+
|
| 67 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
| 68 |
+
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
| 69 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
| 70 |
+
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
| 71 |
+
plucker_embedding = plucker_embedding[None]
|
| 72 |
+
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
| 73 |
+
return plucker_embedding
|
| 74 |
+
|
| 75 |
+
class Camera(object):
|
| 76 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 77 |
+
"""
|
| 78 |
+
def __init__(self, entry):
|
| 79 |
+
fx, fy, cx, cy = entry[1:5]
|
| 80 |
+
self.fx = fx
|
| 81 |
+
self.fy = fy
|
| 82 |
+
self.cx = cx
|
| 83 |
+
self.cy = cy
|
| 84 |
+
c2w_mat = np.array(entry[7:]).reshape(4, 4)
|
| 85 |
+
self.c2w_mat = c2w_mat
|
| 86 |
+
self.w2c_mat = np.linalg.inv(c2w_mat)
|
| 87 |
+
|
| 88 |
+
def ray_condition(K, c2w, H, W, device):
|
| 89 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 90 |
+
"""
|
| 91 |
+
# c2w: B, V, 4, 4
|
| 92 |
+
# K: B, V, 4
|
| 93 |
+
|
| 94 |
+
B = K.shape[0]
|
| 95 |
+
|
| 96 |
+
j, i = torch.meshgrid(
|
| 97 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
| 98 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
| 99 |
+
indexing='ij'
|
| 100 |
+
)
|
| 101 |
+
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 102 |
+
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 103 |
+
|
| 104 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
| 105 |
+
|
| 106 |
+
zs = torch.ones_like(i) # [B, HxW]
|
| 107 |
+
xs = (i - cx) / fx * zs
|
| 108 |
+
ys = (j - cy) / fy * zs
|
| 109 |
+
zs = zs.expand_as(ys)
|
| 110 |
+
|
| 111 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
| 112 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
| 113 |
+
|
| 114 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
| 115 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
| 116 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
| 117 |
+
# c2w @ dirctions
|
| 118 |
+
rays_dxo = torch.cross(rays_o, rays_d)
|
| 119 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
| 120 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
| 121 |
+
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
| 122 |
+
return plucker
|
| 123 |
+
|
| 124 |
+
def get_camera_motion(angle, T, speed, n=81):
|
| 125 |
+
def compute_R_form_rad_angle(angles):
|
| 126 |
+
theta_x, theta_y, theta_z = angles
|
| 127 |
+
Rx = np.array([[1, 0, 0],
|
| 128 |
+
[0, np.cos(theta_x), -np.sin(theta_x)],
|
| 129 |
+
[0, np.sin(theta_x), np.cos(theta_x)]])
|
| 130 |
+
|
| 131 |
+
Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
|
| 132 |
+
[0, 1, 0],
|
| 133 |
+
[-np.sin(theta_y), 0, np.cos(theta_y)]])
|
| 134 |
+
|
| 135 |
+
Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
|
| 136 |
+
[np.sin(theta_z), np.cos(theta_z), 0],
|
| 137 |
+
[0, 0, 1]])
|
| 138 |
+
|
| 139 |
+
R = np.dot(Rz, np.dot(Ry, Rx))
|
| 140 |
+
return R
|
| 141 |
+
RT = []
|
| 142 |
+
for i in range(n):
|
| 143 |
+
_angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
|
| 144 |
+
R = compute_R_form_rad_angle(_angle)
|
| 145 |
+
_T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
|
| 146 |
+
_RT = np.concatenate([R,_T], axis=1)
|
| 147 |
+
RT.append(_RT)
|
| 148 |
+
RT = np.stack(RT)
|
| 149 |
+
return RT
|
| 150 |
+
|
| 151 |
+
class WanCameraEmbedding(io.ComfyNode):
|
| 152 |
+
@classmethod
|
| 153 |
+
def define_schema(cls):
|
| 154 |
+
return io.Schema(
|
| 155 |
+
node_id="WanCameraEmbedding",
|
| 156 |
+
category="camera",
|
| 157 |
+
inputs=[
|
| 158 |
+
io.Combo.Input(
|
| 159 |
+
"camera_pose",
|
| 160 |
+
options=[
|
| 161 |
+
"Static",
|
| 162 |
+
"Pan Up",
|
| 163 |
+
"Pan Down",
|
| 164 |
+
"Pan Left",
|
| 165 |
+
"Pan Right",
|
| 166 |
+
"Zoom In",
|
| 167 |
+
"Zoom Out",
|
| 168 |
+
"Anti Clockwise (ACW)",
|
| 169 |
+
"ClockWise (CW)",
|
| 170 |
+
],
|
| 171 |
+
default="Static",
|
| 172 |
+
),
|
| 173 |
+
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 174 |
+
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 175 |
+
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
| 176 |
+
io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
|
| 177 |
+
io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True),
|
| 178 |
+
io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True),
|
| 179 |
+
io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True),
|
| 180 |
+
io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True),
|
| 181 |
+
],
|
| 182 |
+
outputs=[
|
| 183 |
+
io.WanCameraEmbedding.Output(display_name="camera_embedding"),
|
| 184 |
+
io.Int.Output(display_name="width"),
|
| 185 |
+
io.Int.Output(display_name="height"),
|
| 186 |
+
io.Int.Output(display_name="length"),
|
| 187 |
+
],
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
@classmethod
|
| 191 |
+
def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
|
| 192 |
+
"""
|
| 193 |
+
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021)
|
| 194 |
+
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
|
| 195 |
+
"""
|
| 196 |
+
motion_list = [camera_pose]
|
| 197 |
+
speed = speed
|
| 198 |
+
angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
|
| 199 |
+
T = np.array(CAMERA_DICT[motion_list[0]]["T"])
|
| 200 |
+
RT = get_camera_motion(angle, T, speed, length)
|
| 201 |
+
|
| 202 |
+
trajs=[]
|
| 203 |
+
for cp in RT.tolist():
|
| 204 |
+
traj=[fx,fy,cx,cy,0,0]
|
| 205 |
+
traj.extend(cp[0])
|
| 206 |
+
traj.extend(cp[1])
|
| 207 |
+
traj.extend(cp[2])
|
| 208 |
+
traj.extend([0,0,0,1])
|
| 209 |
+
trajs.append(traj)
|
| 210 |
+
|
| 211 |
+
cam_params = np.array([[float(x) for x in pose] for pose in trajs])
|
| 212 |
+
cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
|
| 213 |
+
control_camera_video = process_pose_params(cam_params, width=width, height=height)
|
| 214 |
+
control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
|
| 215 |
+
|
| 216 |
+
control_camera_video = torch.concat(
|
| 217 |
+
[
|
| 218 |
+
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
|
| 219 |
+
control_camera_video[:, :, 1:]
|
| 220 |
+
], dim=2
|
| 221 |
+
).transpose(1, 2)
|
| 222 |
+
|
| 223 |
+
# Reshape, transpose, and view into desired shape
|
| 224 |
+
b, f, c, h, w = control_camera_video.shape
|
| 225 |
+
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
| 226 |
+
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
| 227 |
+
|
| 228 |
+
return io.NodeOutput(control_camera_video, width, height, length)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class CameraTrajectoryExtension(ComfyExtension):
|
| 232 |
+
@override
|
| 233 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 234 |
+
return [
|
| 235 |
+
WanCameraEmbedding,
|
| 236 |
+
]
|
| 237 |
+
|
| 238 |
+
async def comfy_entrypoint() -> CameraTrajectoryExtension:
|
| 239 |
+
return CameraTrajectoryExtension()
|
ComfyUI/comfy_extras/nodes_canny.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from kornia.filters import canny
|
| 2 |
+
from typing_extensions import override
|
| 3 |
+
|
| 4 |
+
import comfy.model_management
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Canny(io.ComfyNode):
|
| 10 |
+
@classmethod
|
| 11 |
+
def define_schema(cls):
|
| 12 |
+
return io.Schema(
|
| 13 |
+
node_id="Canny",
|
| 14 |
+
display_name="Canny",
|
| 15 |
+
search_aliases=["edge detection", "outline", "contour detection", "line art"],
|
| 16 |
+
category="image/preprocessors",
|
| 17 |
+
essentials_category="Image Tools",
|
| 18 |
+
inputs=[
|
| 19 |
+
io.Image.Input("image"),
|
| 20 |
+
io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01),
|
| 21 |
+
io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01),
|
| 22 |
+
],
|
| 23 |
+
outputs=[io.Image.Output()],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
@classmethod
|
| 27 |
+
def detect_edge(cls, image, low_threshold, high_threshold):
|
| 28 |
+
# Deprecated: use the V3 schema's `execute` method instead of this.
|
| 29 |
+
return cls.execute(image, low_threshold, high_threshold)
|
| 30 |
+
|
| 31 |
+
@classmethod
|
| 32 |
+
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
|
| 33 |
+
output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
|
| 34 |
+
img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
|
| 35 |
+
return io.NodeOutput(img_out)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class CannyExtension(ComfyExtension):
|
| 39 |
+
@override
|
| 40 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 41 |
+
return [Canny]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
async def comfy_entrypoint() -> CannyExtension:
|
| 45 |
+
return CannyExtension()
|
ComfyUI/comfy_extras/nodes_cfg.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import override
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# https://github.com/WeichenFan/CFG-Zero-star
|
| 9 |
+
def optimized_scale(positive, negative):
|
| 10 |
+
positive_flat = positive.reshape(positive.shape[0], -1)
|
| 11 |
+
negative_flat = negative.reshape(negative.shape[0], -1)
|
| 12 |
+
|
| 13 |
+
# Calculate dot production
|
| 14 |
+
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
| 15 |
+
|
| 16 |
+
# Squared norm of uncondition
|
| 17 |
+
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
| 18 |
+
|
| 19 |
+
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
| 20 |
+
st_star = dot_product / squared_norm
|
| 21 |
+
|
| 22 |
+
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
|
| 23 |
+
|
| 24 |
+
class CFGZeroStar(io.ComfyNode):
|
| 25 |
+
@classmethod
|
| 26 |
+
def define_schema(cls) -> io.Schema:
|
| 27 |
+
return io.Schema(
|
| 28 |
+
node_id="CFGZeroStar",
|
| 29 |
+
category="advanced/guidance",
|
| 30 |
+
inputs=[
|
| 31 |
+
io.Model.Input("model"),
|
| 32 |
+
],
|
| 33 |
+
outputs=[io.Model.Output(display_name="patched_model")],
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
@classmethod
|
| 37 |
+
def execute(cls, model) -> io.NodeOutput:
|
| 38 |
+
m = model.clone()
|
| 39 |
+
def cfg_zero_star(args):
|
| 40 |
+
guidance_scale = args['cond_scale']
|
| 41 |
+
x = args['input']
|
| 42 |
+
cond_p = args['cond_denoised']
|
| 43 |
+
uncond_p = args['uncond_denoised']
|
| 44 |
+
out = args["denoised"]
|
| 45 |
+
alpha = optimized_scale(x - cond_p, x - uncond_p)
|
| 46 |
+
|
| 47 |
+
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
|
| 48 |
+
m.set_model_sampler_post_cfg_function(cfg_zero_star)
|
| 49 |
+
return io.NodeOutput(m)
|
| 50 |
+
|
| 51 |
+
class CFGNorm(io.ComfyNode):
|
| 52 |
+
@classmethod
|
| 53 |
+
def define_schema(cls) -> io.Schema:
|
| 54 |
+
return io.Schema(
|
| 55 |
+
node_id="CFGNorm",
|
| 56 |
+
category="advanced/guidance",
|
| 57 |
+
inputs=[
|
| 58 |
+
io.Model.Input("model"),
|
| 59 |
+
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
|
| 60 |
+
],
|
| 61 |
+
outputs=[io.Model.Output(display_name="patched_model")],
|
| 62 |
+
is_experimental=True,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def execute(cls, model, strength) -> io.NodeOutput:
|
| 67 |
+
m = model.clone()
|
| 68 |
+
def cfg_norm(args):
|
| 69 |
+
cond_p = args['cond_denoised']
|
| 70 |
+
pred_text_ = args["denoised"]
|
| 71 |
+
|
| 72 |
+
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
| 73 |
+
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
| 74 |
+
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
| 75 |
+
return pred_text_ * scale * strength
|
| 76 |
+
|
| 77 |
+
m.set_model_sampler_post_cfg_function(cfg_norm)
|
| 78 |
+
return io.NodeOutput(m)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class CfgExtension(ComfyExtension):
|
| 82 |
+
@override
|
| 83 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 84 |
+
return [
|
| 85 |
+
CFGZeroStar,
|
| 86 |
+
CFGNorm,
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
async def comfy_entrypoint() -> CfgExtension:
|
| 91 |
+
return CfgExtension()
|
ComfyUI/comfy_extras/nodes_chroma_radiance.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import override
|
| 2 |
+
from typing import Callable
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
import comfy.model_management
|
| 7 |
+
from comfy_api.latest import ComfyExtension, io
|
| 8 |
+
|
| 9 |
+
import nodes
|
| 10 |
+
|
| 11 |
+
class EmptyChromaRadianceLatentImage(io.ComfyNode):
|
| 12 |
+
@classmethod
|
| 13 |
+
def define_schema(cls) -> io.Schema:
|
| 14 |
+
return io.Schema(
|
| 15 |
+
node_id="EmptyChromaRadianceLatentImage",
|
| 16 |
+
category="latent/chroma_radiance",
|
| 17 |
+
inputs=[
|
| 18 |
+
io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 19 |
+
io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 20 |
+
io.Int.Input(id="batch_size", default=1, min=1, max=4096),
|
| 21 |
+
],
|
| 22 |
+
outputs=[io.Latent().Output()],
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
@classmethod
|
| 26 |
+
def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput:
|
| 27 |
+
latent = torch.zeros((batch_size, 3, height, width), device=comfy.model_management.intermediate_device())
|
| 28 |
+
return io.NodeOutput({"samples":latent})
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ChromaRadianceOptions(io.ComfyNode):
|
| 32 |
+
@classmethod
|
| 33 |
+
def define_schema(cls) -> io.Schema:
|
| 34 |
+
return io.Schema(
|
| 35 |
+
node_id="ChromaRadianceOptions",
|
| 36 |
+
category="model_patches/chroma_radiance",
|
| 37 |
+
description="Allows setting advanced options for the Chroma Radiance model.",
|
| 38 |
+
inputs=[
|
| 39 |
+
io.Model.Input(id="model"),
|
| 40 |
+
io.Boolean.Input(
|
| 41 |
+
id="preserve_wrapper",
|
| 42 |
+
default=True,
|
| 43 |
+
tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.",
|
| 44 |
+
),
|
| 45 |
+
io.Float.Input(
|
| 46 |
+
id="start_sigma",
|
| 47 |
+
default=1.0,
|
| 48 |
+
min=0.0,
|
| 49 |
+
max=1.0,
|
| 50 |
+
tooltip="First sigma that these options will be in effect.",
|
| 51 |
+
advanced=True,
|
| 52 |
+
),
|
| 53 |
+
io.Float.Input(
|
| 54 |
+
id="end_sigma",
|
| 55 |
+
default=0.0,
|
| 56 |
+
min=0.0,
|
| 57 |
+
max=1.0,
|
| 58 |
+
tooltip="Last sigma that these options will be in effect.",
|
| 59 |
+
advanced=True,
|
| 60 |
+
),
|
| 61 |
+
io.Int.Input(
|
| 62 |
+
id="nerf_tile_size",
|
| 63 |
+
default=-1,
|
| 64 |
+
min=-1,
|
| 65 |
+
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
| 66 |
+
advanced=True,
|
| 67 |
+
),
|
| 68 |
+
],
|
| 69 |
+
outputs=[io.Model.Output()],
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def execute(
|
| 74 |
+
cls,
|
| 75 |
+
*,
|
| 76 |
+
model: io.Model.Type,
|
| 77 |
+
preserve_wrapper: bool,
|
| 78 |
+
start_sigma: float,
|
| 79 |
+
end_sigma: float,
|
| 80 |
+
nerf_tile_size: int,
|
| 81 |
+
) -> io.NodeOutput:
|
| 82 |
+
radiance_options = {}
|
| 83 |
+
if nerf_tile_size >= 0:
|
| 84 |
+
radiance_options["nerf_tile_size"] = nerf_tile_size
|
| 85 |
+
|
| 86 |
+
if not radiance_options:
|
| 87 |
+
return io.NodeOutput(model)
|
| 88 |
+
|
| 89 |
+
old_wrapper = model.model_options.get("model_function_wrapper")
|
| 90 |
+
|
| 91 |
+
def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor:
|
| 92 |
+
c = args["c"].copy()
|
| 93 |
+
sigma = args["timestep"].max().detach().cpu().item()
|
| 94 |
+
if end_sigma <= sigma <= start_sigma:
|
| 95 |
+
transformer_options = c.get("transformer_options", {}).copy()
|
| 96 |
+
transformer_options["chroma_radiance_options"] = radiance_options.copy()
|
| 97 |
+
c["transformer_options"] = transformer_options
|
| 98 |
+
if not (preserve_wrapper and old_wrapper):
|
| 99 |
+
return apply_model(args["input"], args["timestep"], **c)
|
| 100 |
+
return old_wrapper(apply_model, args | {"c": c})
|
| 101 |
+
|
| 102 |
+
model = model.clone()
|
| 103 |
+
model.set_model_unet_function_wrapper(model_function_wrapper)
|
| 104 |
+
return io.NodeOutput(model)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class ChromaRadianceExtension(ComfyExtension):
|
| 108 |
+
@override
|
| 109 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 110 |
+
return [
|
| 111 |
+
EmptyChromaRadianceLatentImage,
|
| 112 |
+
ChromaRadianceOptions,
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
async def comfy_entrypoint() -> ChromaRadianceExtension:
|
| 117 |
+
return ChromaRadianceExtension()
|
ComfyUI/comfy_extras/nodes_clip_sdxl.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import override
|
| 2 |
+
|
| 3 |
+
import nodes
|
| 4 |
+
from comfy_api.latest import ComfyExtension, io
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CLIPTextEncodeSDXLRefiner(io.ComfyNode):
|
| 8 |
+
@classmethod
|
| 9 |
+
def define_schema(cls):
|
| 10 |
+
return io.Schema(
|
| 11 |
+
node_id="CLIPTextEncodeSDXLRefiner",
|
| 12 |
+
category="advanced/conditioning",
|
| 13 |
+
inputs=[
|
| 14 |
+
io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01),
|
| 15 |
+
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
| 16 |
+
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
| 17 |
+
io.String.Input("text", multiline=True, dynamic_prompts=True),
|
| 18 |
+
io.Clip.Input("clip"),
|
| 19 |
+
],
|
| 20 |
+
outputs=[io.Conditioning.Output()],
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
@classmethod
|
| 24 |
+
def execute(cls, clip, ascore, width, height, text) -> io.NodeOutput:
|
| 25 |
+
tokens = clip.tokenize(text)
|
| 26 |
+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}))
|
| 27 |
+
|
| 28 |
+
class CLIPTextEncodeSDXL(io.ComfyNode):
|
| 29 |
+
@classmethod
|
| 30 |
+
def define_schema(cls):
|
| 31 |
+
return io.Schema(
|
| 32 |
+
node_id="CLIPTextEncodeSDXL",
|
| 33 |
+
category="advanced/conditioning",
|
| 34 |
+
inputs=[
|
| 35 |
+
io.Clip.Input("clip"),
|
| 36 |
+
io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
| 37 |
+
io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
| 38 |
+
io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True),
|
| 39 |
+
io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True),
|
| 40 |
+
io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
| 41 |
+
io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
|
| 42 |
+
io.String.Input("text_g", multiline=True, dynamic_prompts=True),
|
| 43 |
+
io.String.Input("text_l", multiline=True, dynamic_prompts=True),
|
| 44 |
+
],
|
| 45 |
+
outputs=[io.Conditioning.Output()],
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
@classmethod
|
| 49 |
+
def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput:
|
| 50 |
+
tokens = clip.tokenize(text_g)
|
| 51 |
+
tokens["l"] = clip.tokenize(text_l)["l"]
|
| 52 |
+
if len(tokens["l"]) != len(tokens["g"]):
|
| 53 |
+
empty = clip.tokenize("")
|
| 54 |
+
while len(tokens["l"]) < len(tokens["g"]):
|
| 55 |
+
tokens["l"] += empty["l"]
|
| 56 |
+
while len(tokens["l"]) > len(tokens["g"]):
|
| 57 |
+
tokens["g"] += empty["g"]
|
| 58 |
+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class ClipSdxlExtension(ComfyExtension):
|
| 62 |
+
@override
|
| 63 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 64 |
+
return [
|
| 65 |
+
CLIPTextEncodeSDXLRefiner,
|
| 66 |
+
CLIPTextEncodeSDXL,
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
async def comfy_entrypoint() -> ClipSdxlExtension:
|
| 71 |
+
return ClipSdxlExtension()
|
ComfyUI/comfy_extras/nodes_color.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import override
|
| 2 |
+
from comfy_api.latest import ComfyExtension, io
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ColorToRGBInt(io.ComfyNode):
|
| 6 |
+
@classmethod
|
| 7 |
+
def define_schema(cls) -> io.Schema:
|
| 8 |
+
return io.Schema(
|
| 9 |
+
node_id="ColorToRGBInt",
|
| 10 |
+
display_name="Color to RGB Int",
|
| 11 |
+
category="utils",
|
| 12 |
+
description="Convert a color to a RGB integer value.",
|
| 13 |
+
inputs=[
|
| 14 |
+
io.Color.Input("color"),
|
| 15 |
+
],
|
| 16 |
+
outputs=[
|
| 17 |
+
io.Int.Output(display_name="rgb_int"),
|
| 18 |
+
],
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
@classmethod
|
| 22 |
+
def execute(
|
| 23 |
+
cls,
|
| 24 |
+
color: str,
|
| 25 |
+
) -> io.NodeOutput:
|
| 26 |
+
# expect format #RRGGBB
|
| 27 |
+
if len(color) != 7 or color[0] != "#":
|
| 28 |
+
raise ValueError("Color must be in format #RRGGBB")
|
| 29 |
+
r = int(color[1:3], 16)
|
| 30 |
+
g = int(color[3:5], 16)
|
| 31 |
+
b = int(color[5:7], 16)
|
| 32 |
+
return io.NodeOutput(r * 256 * 256 + g * 256 + b)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ColorExtension(ComfyExtension):
|
| 36 |
+
@override
|
| 37 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 38 |
+
return [ColorToRGBInt]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def comfy_entrypoint() -> ColorExtension:
|
| 42 |
+
return ColorExtension()
|
ComfyUI/comfy_extras/nodes_compositing.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import comfy.utils
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def resize_mask(mask, shape):
|
| 9 |
+
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
|
| 10 |
+
|
| 11 |
+
class PorterDuffMode(Enum):
|
| 12 |
+
ADD = 0
|
| 13 |
+
CLEAR = 1
|
| 14 |
+
DARKEN = 2
|
| 15 |
+
DST = 3
|
| 16 |
+
DST_ATOP = 4
|
| 17 |
+
DST_IN = 5
|
| 18 |
+
DST_OUT = 6
|
| 19 |
+
DST_OVER = 7
|
| 20 |
+
LIGHTEN = 8
|
| 21 |
+
MULTIPLY = 9
|
| 22 |
+
OVERLAY = 10
|
| 23 |
+
SCREEN = 11
|
| 24 |
+
SRC = 12
|
| 25 |
+
SRC_ATOP = 13
|
| 26 |
+
SRC_IN = 14
|
| 27 |
+
SRC_OUT = 15
|
| 28 |
+
SRC_OVER = 16
|
| 29 |
+
XOR = 17
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
|
| 33 |
+
# convert mask to alpha
|
| 34 |
+
src_alpha = 1 - src_alpha
|
| 35 |
+
dst_alpha = 1 - dst_alpha
|
| 36 |
+
# premultiply alpha
|
| 37 |
+
src_image = src_image * src_alpha
|
| 38 |
+
dst_image = dst_image * dst_alpha
|
| 39 |
+
|
| 40 |
+
# composite ops below assume alpha-premultiplied images
|
| 41 |
+
if mode == PorterDuffMode.ADD:
|
| 42 |
+
out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
|
| 43 |
+
out_image = torch.clamp(src_image + dst_image, 0, 1)
|
| 44 |
+
elif mode == PorterDuffMode.CLEAR:
|
| 45 |
+
out_alpha = torch.zeros_like(dst_alpha)
|
| 46 |
+
out_image = torch.zeros_like(dst_image)
|
| 47 |
+
elif mode == PorterDuffMode.DARKEN:
|
| 48 |
+
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
| 49 |
+
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
|
| 50 |
+
elif mode == PorterDuffMode.DST:
|
| 51 |
+
out_alpha = dst_alpha
|
| 52 |
+
out_image = dst_image
|
| 53 |
+
elif mode == PorterDuffMode.DST_ATOP:
|
| 54 |
+
out_alpha = src_alpha
|
| 55 |
+
out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image
|
| 56 |
+
elif mode == PorterDuffMode.DST_IN:
|
| 57 |
+
out_alpha = src_alpha * dst_alpha
|
| 58 |
+
out_image = dst_image * src_alpha
|
| 59 |
+
elif mode == PorterDuffMode.DST_OUT:
|
| 60 |
+
out_alpha = (1 - src_alpha) * dst_alpha
|
| 61 |
+
out_image = (1 - src_alpha) * dst_image
|
| 62 |
+
elif mode == PorterDuffMode.DST_OVER:
|
| 63 |
+
out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha
|
| 64 |
+
out_image = dst_image + (1 - dst_alpha) * src_image
|
| 65 |
+
elif mode == PorterDuffMode.LIGHTEN:
|
| 66 |
+
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
| 67 |
+
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image)
|
| 68 |
+
elif mode == PorterDuffMode.MULTIPLY:
|
| 69 |
+
out_alpha = src_alpha * dst_alpha
|
| 70 |
+
out_image = src_image * dst_image
|
| 71 |
+
elif mode == PorterDuffMode.OVERLAY:
|
| 72 |
+
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
| 73 |
+
out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
|
| 74 |
+
src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
|
| 75 |
+
elif mode == PorterDuffMode.SCREEN:
|
| 76 |
+
out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
|
| 77 |
+
out_image = src_image + dst_image - src_image * dst_image
|
| 78 |
+
elif mode == PorterDuffMode.SRC:
|
| 79 |
+
out_alpha = src_alpha
|
| 80 |
+
out_image = src_image
|
| 81 |
+
elif mode == PorterDuffMode.SRC_ATOP:
|
| 82 |
+
out_alpha = dst_alpha
|
| 83 |
+
out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image
|
| 84 |
+
elif mode == PorterDuffMode.SRC_IN:
|
| 85 |
+
out_alpha = src_alpha * dst_alpha
|
| 86 |
+
out_image = src_image * dst_alpha
|
| 87 |
+
elif mode == PorterDuffMode.SRC_OUT:
|
| 88 |
+
out_alpha = (1 - dst_alpha) * src_alpha
|
| 89 |
+
out_image = (1 - dst_alpha) * src_image
|
| 90 |
+
elif mode == PorterDuffMode.SRC_OVER:
|
| 91 |
+
out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
|
| 92 |
+
out_image = src_image + (1 - src_alpha) * dst_image
|
| 93 |
+
elif mode == PorterDuffMode.XOR:
|
| 94 |
+
out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
|
| 95 |
+
out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
|
| 96 |
+
else:
|
| 97 |
+
return None, None
|
| 98 |
+
|
| 99 |
+
# back to non-premultiplied alpha
|
| 100 |
+
out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image))
|
| 101 |
+
out_image = torch.clamp(out_image, 0, 1)
|
| 102 |
+
# convert alpha to mask
|
| 103 |
+
out_alpha = 1 - out_alpha
|
| 104 |
+
return out_image, out_alpha
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class PorterDuffImageComposite(io.ComfyNode):
|
| 108 |
+
@classmethod
|
| 109 |
+
def define_schema(cls):
|
| 110 |
+
return io.Schema(
|
| 111 |
+
node_id="PorterDuffImageComposite",
|
| 112 |
+
search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"],
|
| 113 |
+
display_name="Porter-Duff Image Composite",
|
| 114 |
+
category="mask/compositing",
|
| 115 |
+
inputs=[
|
| 116 |
+
io.Image.Input("source"),
|
| 117 |
+
io.Mask.Input("source_alpha"),
|
| 118 |
+
io.Image.Input("destination"),
|
| 119 |
+
io.Mask.Input("destination_alpha"),
|
| 120 |
+
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
|
| 121 |
+
],
|
| 122 |
+
outputs=[
|
| 123 |
+
io.Image.Output(),
|
| 124 |
+
io.Mask.Output(),
|
| 125 |
+
],
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
@classmethod
|
| 129 |
+
def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
|
| 130 |
+
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
| 131 |
+
out_images = []
|
| 132 |
+
out_alphas = []
|
| 133 |
+
|
| 134 |
+
for i in range(batch_size):
|
| 135 |
+
src_image = source[i]
|
| 136 |
+
dst_image = destination[i]
|
| 137 |
+
|
| 138 |
+
assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
|
| 139 |
+
|
| 140 |
+
src_alpha = source_alpha[i].unsqueeze(2)
|
| 141 |
+
dst_alpha = destination_alpha[i].unsqueeze(2)
|
| 142 |
+
|
| 143 |
+
if dst_alpha.shape[:2] != dst_image.shape[:2]:
|
| 144 |
+
upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2)
|
| 145 |
+
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
|
| 146 |
+
dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
| 147 |
+
if src_image.shape != dst_image.shape:
|
| 148 |
+
upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2)
|
| 149 |
+
upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
|
| 150 |
+
src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
| 151 |
+
if src_alpha.shape != dst_alpha.shape:
|
| 152 |
+
upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2)
|
| 153 |
+
upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
|
| 154 |
+
src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
|
| 155 |
+
|
| 156 |
+
out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
|
| 157 |
+
|
| 158 |
+
out_images.append(out_image)
|
| 159 |
+
out_alphas.append(out_alpha.squeeze(2))
|
| 160 |
+
|
| 161 |
+
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class SplitImageWithAlpha(io.ComfyNode):
|
| 165 |
+
@classmethod
|
| 166 |
+
def define_schema(cls):
|
| 167 |
+
return io.Schema(
|
| 168 |
+
node_id="SplitImageWithAlpha",
|
| 169 |
+
search_aliases=["extract alpha", "separate transparency", "remove alpha"],
|
| 170 |
+
display_name="Split Image with Alpha",
|
| 171 |
+
category="mask/compositing",
|
| 172 |
+
inputs=[
|
| 173 |
+
io.Image.Input("image"),
|
| 174 |
+
],
|
| 175 |
+
outputs=[
|
| 176 |
+
io.Image.Output(),
|
| 177 |
+
io.Mask.Output(),
|
| 178 |
+
],
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
@classmethod
|
| 182 |
+
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
|
| 183 |
+
out_images = [i[:,:,:3] for i in image]
|
| 184 |
+
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
|
| 185 |
+
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class JoinImageWithAlpha(io.ComfyNode):
|
| 189 |
+
@classmethod
|
| 190 |
+
def define_schema(cls):
|
| 191 |
+
return io.Schema(
|
| 192 |
+
node_id="JoinImageWithAlpha",
|
| 193 |
+
search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"],
|
| 194 |
+
display_name="Join Image with Alpha",
|
| 195 |
+
category="mask/compositing",
|
| 196 |
+
inputs=[
|
| 197 |
+
io.Image.Input("image"),
|
| 198 |
+
io.Mask.Input("alpha"),
|
| 199 |
+
],
|
| 200 |
+
outputs=[io.Image.Output()],
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
@classmethod
|
| 204 |
+
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
| 205 |
+
batch_size = min(len(image), len(alpha))
|
| 206 |
+
out_images = []
|
| 207 |
+
|
| 208 |
+
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
| 209 |
+
for i in range(batch_size):
|
| 210 |
+
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
| 211 |
+
|
| 212 |
+
return io.NodeOutput(torch.stack(out_images))
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class CompositingExtension(ComfyExtension):
|
| 216 |
+
@override
|
| 217 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 218 |
+
return [
|
| 219 |
+
PorterDuffImageComposite,
|
| 220 |
+
SplitImageWithAlpha,
|
| 221 |
+
JoinImageWithAlpha,
|
| 222 |
+
]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
async def comfy_entrypoint() -> CompositingExtension:
|
| 226 |
+
return CompositingExtension()
|
ComfyUI/comfy_extras/nodes_cond.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import override
|
| 2 |
+
|
| 3 |
+
from comfy_api.latest import ComfyExtension, io
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CLIPTextEncodeControlnet(io.ComfyNode):
|
| 7 |
+
@classmethod
|
| 8 |
+
def define_schema(cls) -> io.Schema:
|
| 9 |
+
return io.Schema(
|
| 10 |
+
node_id="CLIPTextEncodeControlnet",
|
| 11 |
+
category="_for_testing/conditioning",
|
| 12 |
+
inputs=[
|
| 13 |
+
io.Clip.Input("clip"),
|
| 14 |
+
io.Conditioning.Input("conditioning"),
|
| 15 |
+
io.String.Input("text", multiline=True, dynamic_prompts=True),
|
| 16 |
+
],
|
| 17 |
+
outputs=[io.Conditioning.Output()],
|
| 18 |
+
is_experimental=True,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
@classmethod
|
| 22 |
+
def execute(cls, clip, conditioning, text) -> io.NodeOutput:
|
| 23 |
+
tokens = clip.tokenize(text)
|
| 24 |
+
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
| 25 |
+
c = []
|
| 26 |
+
for t in conditioning:
|
| 27 |
+
n = [t[0], t[1].copy()]
|
| 28 |
+
n[1]['cross_attn_controlnet'] = cond
|
| 29 |
+
n[1]['pooled_output_controlnet'] = pooled
|
| 30 |
+
c.append(n)
|
| 31 |
+
return io.NodeOutput(c)
|
| 32 |
+
|
| 33 |
+
class T5TokenizerOptions(io.ComfyNode):
|
| 34 |
+
@classmethod
|
| 35 |
+
def define_schema(cls) -> io.Schema:
|
| 36 |
+
return io.Schema(
|
| 37 |
+
node_id="T5TokenizerOptions",
|
| 38 |
+
category="_for_testing/conditioning",
|
| 39 |
+
inputs=[
|
| 40 |
+
io.Clip.Input("clip"),
|
| 41 |
+
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
| 42 |
+
io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True),
|
| 43 |
+
],
|
| 44 |
+
outputs=[io.Clip.Output()],
|
| 45 |
+
is_experimental=True,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
@classmethod
|
| 49 |
+
def execute(cls, clip, min_padding, min_length) -> io.NodeOutput:
|
| 50 |
+
clip = clip.clone()
|
| 51 |
+
for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
|
| 52 |
+
clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
|
| 53 |
+
clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
|
| 54 |
+
|
| 55 |
+
return io.NodeOutput(clip)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CondExtension(ComfyExtension):
|
| 59 |
+
@override
|
| 60 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 61 |
+
return [
|
| 62 |
+
CLIPTextEncodeControlnet,
|
| 63 |
+
T5TokenizerOptions,
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
async def comfy_entrypoint() -> CondExtension:
|
| 68 |
+
return CondExtension()
|
ComfyUI/comfy_extras/nodes_context_windows.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from comfy_api.latest import ComfyExtension, io
|
| 3 |
+
import comfy.context_windows
|
| 4 |
+
import nodes
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ContextWindowsManualNode(io.ComfyNode):
|
| 8 |
+
@classmethod
|
| 9 |
+
def define_schema(cls) -> io.Schema:
|
| 10 |
+
return io.Schema(
|
| 11 |
+
node_id="ContextWindowsManual",
|
| 12 |
+
display_name="Context Windows (Manual)",
|
| 13 |
+
category="context",
|
| 14 |
+
description="Manually set context windows.",
|
| 15 |
+
inputs=[
|
| 16 |
+
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
| 17 |
+
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True),
|
| 18 |
+
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True),
|
| 19 |
+
io.Combo.Input("context_schedule", options=[
|
| 20 |
+
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
| 21 |
+
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
| 22 |
+
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
| 23 |
+
comfy.context_windows.ContextSchedules.BATCHED,
|
| 24 |
+
], tooltip="The stride of the context window."),
|
| 25 |
+
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
| 26 |
+
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
| 27 |
+
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
| 28 |
+
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
| 29 |
+
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
| 30 |
+
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
| 31 |
+
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
| 32 |
+
],
|
| 33 |
+
outputs=[
|
| 34 |
+
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
| 35 |
+
],
|
| 36 |
+
is_experimental=True,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
@classmethod
|
| 40 |
+
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
| 41 |
+
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
| 42 |
+
model = model.clone()
|
| 43 |
+
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
| 44 |
+
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
| 45 |
+
fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method),
|
| 46 |
+
context_length=context_length,
|
| 47 |
+
context_overlap=context_overlap,
|
| 48 |
+
context_stride=context_stride,
|
| 49 |
+
closed_loop=closed_loop,
|
| 50 |
+
dim=dim,
|
| 51 |
+
freenoise=freenoise,
|
| 52 |
+
cond_retain_index_list=cond_retain_index_list,
|
| 53 |
+
split_conds_to_windows=split_conds_to_windows
|
| 54 |
+
)
|
| 55 |
+
# make memory usage calculation only take into account the context window latents
|
| 56 |
+
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
| 57 |
+
if freenoise: # no other use for this wrapper at this time
|
| 58 |
+
comfy.context_windows.create_sampler_sample_wrapper(model)
|
| 59 |
+
return io.NodeOutput(model)
|
| 60 |
+
|
| 61 |
+
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
| 62 |
+
@classmethod
|
| 63 |
+
def define_schema(cls) -> io.Schema:
|
| 64 |
+
schema = super().define_schema()
|
| 65 |
+
schema.node_id = "WanContextWindowsManual"
|
| 66 |
+
schema.display_name = "WAN Context Windows (Manual)"
|
| 67 |
+
schema.description = "Manually set context windows for WAN-like models (dim=2)."
|
| 68 |
+
schema.inputs = [
|
| 69 |
+
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
| 70 |
+
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True),
|
| 71 |
+
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True),
|
| 72 |
+
io.Combo.Input("context_schedule", options=[
|
| 73 |
+
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
| 74 |
+
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
| 75 |
+
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
| 76 |
+
comfy.context_windows.ContextSchedules.BATCHED,
|
| 77 |
+
], tooltip="The stride of the context window."),
|
| 78 |
+
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
| 79 |
+
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
| 80 |
+
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
| 81 |
+
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
| 82 |
+
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
| 83 |
+
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
| 84 |
+
]
|
| 85 |
+
return schema
|
| 86 |
+
|
| 87 |
+
@classmethod
|
| 88 |
+
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
| 89 |
+
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
| 90 |
+
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
| 91 |
+
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
| 92 |
+
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class ContextWindowsExtension(ComfyExtension):
|
| 96 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 97 |
+
return [
|
| 98 |
+
ContextWindowsManualNode,
|
| 99 |
+
WanContextWindowsManualNode,
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
def comfy_entrypoint():
|
| 103 |
+
return ContextWindowsExtension()
|
ComfyUI/comfy_extras/nodes_controlnet.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
| 2 |
+
import nodes
|
| 3 |
+
import comfy.utils
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
|
| 7 |
+
class SetUnionControlNetType(io.ComfyNode):
|
| 8 |
+
@classmethod
|
| 9 |
+
def define_schema(cls):
|
| 10 |
+
return io.Schema(
|
| 11 |
+
node_id="SetUnionControlNetType",
|
| 12 |
+
category="conditioning/controlnet",
|
| 13 |
+
inputs=[
|
| 14 |
+
io.ControlNet.Input("control_net"),
|
| 15 |
+
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
|
| 16 |
+
],
|
| 17 |
+
outputs=[
|
| 18 |
+
io.ControlNet.Output(),
|
| 19 |
+
],
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
@classmethod
|
| 23 |
+
def execute(cls, control_net, type) -> io.NodeOutput:
|
| 24 |
+
control_net = control_net.copy()
|
| 25 |
+
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
|
| 26 |
+
if type_number >= 0:
|
| 27 |
+
control_net.set_extra_arg("control_type", [type_number])
|
| 28 |
+
else:
|
| 29 |
+
control_net.set_extra_arg("control_type", [])
|
| 30 |
+
|
| 31 |
+
return io.NodeOutput(control_net)
|
| 32 |
+
|
| 33 |
+
set_controlnet_type = execute # TODO: remove
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ControlNetInpaintingAliMamaApply(io.ComfyNode):
|
| 37 |
+
@classmethod
|
| 38 |
+
def define_schema(cls):
|
| 39 |
+
return io.Schema(
|
| 40 |
+
node_id="ControlNetInpaintingAliMamaApply",
|
| 41 |
+
search_aliases=["masked controlnet"],
|
| 42 |
+
category="conditioning/controlnet",
|
| 43 |
+
inputs=[
|
| 44 |
+
io.Conditioning.Input("positive"),
|
| 45 |
+
io.Conditioning.Input("negative"),
|
| 46 |
+
io.ControlNet.Input("control_net"),
|
| 47 |
+
io.Vae.Input("vae"),
|
| 48 |
+
io.Image.Input("image"),
|
| 49 |
+
io.Mask.Input("mask"),
|
| 50 |
+
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
| 51 |
+
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True),
|
| 52 |
+
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True),
|
| 53 |
+
],
|
| 54 |
+
outputs=[
|
| 55 |
+
io.Conditioning.Output(display_name="positive"),
|
| 56 |
+
io.Conditioning.Output(display_name="negative"),
|
| 57 |
+
],
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
@classmethod
|
| 61 |
+
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
|
| 62 |
+
extra_concat = []
|
| 63 |
+
if control_net.concat_mask:
|
| 64 |
+
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
| 65 |
+
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
|
| 66 |
+
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
| 67 |
+
extra_concat = [mask]
|
| 68 |
+
|
| 69 |
+
result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
| 70 |
+
return io.NodeOutput(result[0], result[1])
|
| 71 |
+
|
| 72 |
+
apply_inpaint_controlnet = execute # TODO: remove
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class ControlNetExtension(ComfyExtension):
|
| 76 |
+
@override
|
| 77 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 78 |
+
return [
|
| 79 |
+
SetUnionControlNetType,
|
| 80 |
+
ControlNetInpaintingAliMamaApply,
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
async def comfy_entrypoint() -> ControlNetExtension:
|
| 85 |
+
return ControlNetExtension()
|
ComfyUI/comfy_extras/nodes_cosmos.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import override
|
| 2 |
+
import nodes
|
| 3 |
+
import torch
|
| 4 |
+
import comfy.model_management
|
| 5 |
+
import comfy.utils
|
| 6 |
+
import comfy.latent_formats
|
| 7 |
+
|
| 8 |
+
from comfy_api.latest import ComfyExtension, io
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EmptyCosmosLatentVideo(io.ComfyNode):
|
| 12 |
+
@classmethod
|
| 13 |
+
def define_schema(cls) -> io.Schema:
|
| 14 |
+
return io.Schema(
|
| 15 |
+
node_id="EmptyCosmosLatentVideo",
|
| 16 |
+
category="latent/video",
|
| 17 |
+
inputs=[
|
| 18 |
+
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 19 |
+
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 20 |
+
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
|
| 21 |
+
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
| 22 |
+
],
|
| 23 |
+
outputs=[io.Latent.Output()],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
@classmethod
|
| 27 |
+
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
| 28 |
+
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
| 29 |
+
return io.NodeOutput({"samples": latent})
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
|
| 33 |
+
pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
| 34 |
+
pixel_len = min(pixels.shape[0], length)
|
| 35 |
+
padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
|
| 36 |
+
padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
|
| 37 |
+
padded_pixels[:pixel_len] = pixels[:pixel_len]
|
| 38 |
+
latent_len = ((pixel_len - 1) // 8) + 1
|
| 39 |
+
latent_temp = vae.encode(padded_pixels)
|
| 40 |
+
return latent_temp[:, :, :latent_len]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class CosmosImageToVideoLatent(io.ComfyNode):
|
| 44 |
+
@classmethod
|
| 45 |
+
def define_schema(cls) -> io.Schema:
|
| 46 |
+
return io.Schema(
|
| 47 |
+
node_id="CosmosImageToVideoLatent",
|
| 48 |
+
category="conditioning/inpaint",
|
| 49 |
+
inputs=[
|
| 50 |
+
io.Vae.Input("vae"),
|
| 51 |
+
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 52 |
+
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 53 |
+
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
|
| 54 |
+
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
| 55 |
+
io.Image.Input("start_image", optional=True),
|
| 56 |
+
io.Image.Input("end_image", optional=True),
|
| 57 |
+
],
|
| 58 |
+
outputs=[io.Latent.Output()],
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
@classmethod
|
| 62 |
+
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
|
| 63 |
+
latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
| 64 |
+
if start_image is None and end_image is None:
|
| 65 |
+
out_latent = {}
|
| 66 |
+
out_latent["samples"] = latent
|
| 67 |
+
return io.NodeOutput(out_latent)
|
| 68 |
+
|
| 69 |
+
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
| 70 |
+
|
| 71 |
+
if start_image is not None:
|
| 72 |
+
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
|
| 73 |
+
latent[:, :, :latent_temp.shape[-3]] = latent_temp
|
| 74 |
+
mask[:, :, :latent_temp.shape[-3]] *= 0.0
|
| 75 |
+
|
| 76 |
+
if end_image is not None:
|
| 77 |
+
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
|
| 78 |
+
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
|
| 79 |
+
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
|
| 80 |
+
|
| 81 |
+
out_latent = {}
|
| 82 |
+
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
| 83 |
+
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
| 84 |
+
return io.NodeOutput(out_latent)
|
| 85 |
+
|
| 86 |
+
class CosmosPredict2ImageToVideoLatent(io.ComfyNode):
|
| 87 |
+
@classmethod
|
| 88 |
+
def define_schema(cls) -> io.Schema:
|
| 89 |
+
return io.Schema(
|
| 90 |
+
node_id="CosmosPredict2ImageToVideoLatent",
|
| 91 |
+
category="conditioning/inpaint",
|
| 92 |
+
inputs=[
|
| 93 |
+
io.Vae.Input("vae"),
|
| 94 |
+
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 95 |
+
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 96 |
+
io.Int.Input("length", default=93, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
| 97 |
+
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
| 98 |
+
io.Image.Input("start_image", optional=True),
|
| 99 |
+
io.Image.Input("end_image", optional=True),
|
| 100 |
+
],
|
| 101 |
+
outputs=[io.Latent.Output()],
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
@classmethod
|
| 105 |
+
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
|
| 106 |
+
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
| 107 |
+
if start_image is None and end_image is None:
|
| 108 |
+
out_latent = {}
|
| 109 |
+
out_latent["samples"] = latent
|
| 110 |
+
return io.NodeOutput(out_latent)
|
| 111 |
+
|
| 112 |
+
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
| 113 |
+
|
| 114 |
+
if start_image is not None:
|
| 115 |
+
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
|
| 116 |
+
latent[:, :, :latent_temp.shape[-3]] = latent_temp
|
| 117 |
+
mask[:, :, :latent_temp.shape[-3]] *= 0.0
|
| 118 |
+
|
| 119 |
+
if end_image is not None:
|
| 120 |
+
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
|
| 121 |
+
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
|
| 122 |
+
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
|
| 123 |
+
|
| 124 |
+
out_latent = {}
|
| 125 |
+
latent_format = comfy.latent_formats.Wan21()
|
| 126 |
+
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
| 127 |
+
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
| 128 |
+
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
| 129 |
+
return io.NodeOutput(out_latent)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class CosmosExtension(ComfyExtension):
|
| 133 |
+
@override
|
| 134 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 135 |
+
return [
|
| 136 |
+
EmptyCosmosLatentVideo,
|
| 137 |
+
CosmosImageToVideoLatent,
|
| 138 |
+
CosmosPredict2ImageToVideoLatent,
|
| 139 |
+
]
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
async def comfy_entrypoint() -> CosmosExtension:
|
| 143 |
+
return CosmosExtension()
|
ComfyUI/comfy_extras/nodes_curve.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
from comfy_api.input import CurveInput
|
| 7 |
+
from typing_extensions import override
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CurveEditor(io.ComfyNode):
|
| 11 |
+
@classmethod
|
| 12 |
+
def define_schema(cls):
|
| 13 |
+
return io.Schema(
|
| 14 |
+
node_id="CurveEditor",
|
| 15 |
+
display_name="Curve Editor",
|
| 16 |
+
category="utils",
|
| 17 |
+
inputs=[
|
| 18 |
+
io.Curve.Input("curve"),
|
| 19 |
+
io.Histogram.Input("histogram", optional=True),
|
| 20 |
+
],
|
| 21 |
+
outputs=[
|
| 22 |
+
io.Curve.Output("curve"),
|
| 23 |
+
],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
@classmethod
|
| 27 |
+
def execute(cls, curve, histogram=None) -> io.NodeOutput:
|
| 28 |
+
result = CurveInput.from_raw(curve)
|
| 29 |
+
|
| 30 |
+
ui = {}
|
| 31 |
+
if histogram is not None:
|
| 32 |
+
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
|
| 33 |
+
|
| 34 |
+
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ImageHistogram(io.ComfyNode):
|
| 38 |
+
@classmethod
|
| 39 |
+
def define_schema(cls):
|
| 40 |
+
return io.Schema(
|
| 41 |
+
node_id="ImageHistogram",
|
| 42 |
+
display_name="Image Histogram",
|
| 43 |
+
category="utils",
|
| 44 |
+
inputs=[
|
| 45 |
+
io.Image.Input("image"),
|
| 46 |
+
],
|
| 47 |
+
outputs=[
|
| 48 |
+
io.Histogram.Output("rgb"),
|
| 49 |
+
io.Histogram.Output("luminance"),
|
| 50 |
+
io.Histogram.Output("red"),
|
| 51 |
+
io.Histogram.Output("green"),
|
| 52 |
+
io.Histogram.Output("blue"),
|
| 53 |
+
],
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
@classmethod
|
| 57 |
+
def execute(cls, image) -> io.NodeOutput:
|
| 58 |
+
img = image[0].cpu().numpy()
|
| 59 |
+
img_uint8 = np.clip(img * 255, 0, 255).astype(np.uint8)
|
| 60 |
+
|
| 61 |
+
def bincount(data):
|
| 62 |
+
return np.bincount(data.ravel(), minlength=256)[:256]
|
| 63 |
+
|
| 64 |
+
hist_r = bincount(img_uint8[:, :, 0])
|
| 65 |
+
hist_g = bincount(img_uint8[:, :, 1])
|
| 66 |
+
hist_b = bincount(img_uint8[:, :, 2])
|
| 67 |
+
|
| 68 |
+
# Average of R, G, B histograms (same as Photoshop's RGB composite)
|
| 69 |
+
rgb = ((hist_r + hist_g + hist_b) // 3).tolist()
|
| 70 |
+
|
| 71 |
+
# ITU-R BT.709-6, Item 3.2 (p.6) — Derivation of luminance signal
|
| 72 |
+
# https://www.itu.int/rec/R-REC-BT.709-6-201506-I/en
|
| 73 |
+
lum = 0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] + 0.0722 * img[:, :, 2]
|
| 74 |
+
luminance = bincount(np.clip(lum * 255, 0, 255).astype(np.uint8)).tolist()
|
| 75 |
+
|
| 76 |
+
return io.NodeOutput(
|
| 77 |
+
rgb,
|
| 78 |
+
luminance,
|
| 79 |
+
hist_r.tolist(),
|
| 80 |
+
hist_g.tolist(),
|
| 81 |
+
hist_b.tolist(),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class CurveExtension(ComfyExtension):
|
| 86 |
+
@override
|
| 87 |
+
async def get_node_list(self):
|
| 88 |
+
return [CurveEditor, ImageHistogram]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
async def comfy_entrypoint():
|
| 92 |
+
return CurveExtension()
|
ComfyUI/comfy_extras/nodes_custom_sampler.py
ADDED
|
@@ -0,0 +1,1095 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import comfy.samplers
|
| 3 |
+
import comfy.sample
|
| 4 |
+
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
| 5 |
+
from comfy.k_diffusion import sa_solver
|
| 6 |
+
import latent_preview
|
| 7 |
+
import torch
|
| 8 |
+
import comfy.utils
|
| 9 |
+
import node_helpers
|
| 10 |
+
from typing_extensions import override
|
| 11 |
+
from comfy_api.latest import ComfyExtension, io
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BasicScheduler(io.ComfyNode):
|
| 16 |
+
@classmethod
|
| 17 |
+
def define_schema(cls):
|
| 18 |
+
return io.Schema(
|
| 19 |
+
node_id="BasicScheduler",
|
| 20 |
+
category="sampling/custom_sampling/schedulers",
|
| 21 |
+
inputs=[
|
| 22 |
+
io.Model.Input("model"),
|
| 23 |
+
io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES),
|
| 24 |
+
io.Int.Input("steps", default=20, min=1, max=10000),
|
| 25 |
+
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
| 26 |
+
],
|
| 27 |
+
outputs=[io.Sigmas.Output()]
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
@classmethod
|
| 31 |
+
def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput:
|
| 32 |
+
total_steps = steps
|
| 33 |
+
if denoise < 1.0:
|
| 34 |
+
if denoise <= 0.0:
|
| 35 |
+
return io.NodeOutput(torch.FloatTensor([]))
|
| 36 |
+
total_steps = int(steps/denoise)
|
| 37 |
+
|
| 38 |
+
sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
|
| 39 |
+
sigmas = sigmas[-(steps + 1):]
|
| 40 |
+
return io.NodeOutput(sigmas)
|
| 41 |
+
|
| 42 |
+
get_sigmas = execute
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class KarrasScheduler(io.ComfyNode):
|
| 46 |
+
@classmethod
|
| 47 |
+
def define_schema(cls):
|
| 48 |
+
return io.Schema(
|
| 49 |
+
node_id="KarrasScheduler",
|
| 50 |
+
category="sampling/custom_sampling/schedulers",
|
| 51 |
+
inputs=[
|
| 52 |
+
io.Int.Input("steps", default=20, min=1, max=10000),
|
| 53 |
+
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
| 54 |
+
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
| 55 |
+
io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 56 |
+
],
|
| 57 |
+
outputs=[io.Sigmas.Output()]
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
@classmethod
|
| 61 |
+
def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput:
|
| 62 |
+
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
| 63 |
+
return io.NodeOutput(sigmas)
|
| 64 |
+
|
| 65 |
+
get_sigmas = execute
|
| 66 |
+
|
| 67 |
+
class ExponentialScheduler(io.ComfyNode):
|
| 68 |
+
@classmethod
|
| 69 |
+
def define_schema(cls):
|
| 70 |
+
return io.Schema(
|
| 71 |
+
node_id="ExponentialScheduler",
|
| 72 |
+
category="sampling/custom_sampling/schedulers",
|
| 73 |
+
inputs=[
|
| 74 |
+
io.Int.Input("steps", default=20, min=1, max=10000),
|
| 75 |
+
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
| 76 |
+
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
| 77 |
+
],
|
| 78 |
+
outputs=[io.Sigmas.Output()]
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
@classmethod
|
| 82 |
+
def execute(cls, steps, sigma_max, sigma_min) -> io.NodeOutput:
|
| 83 |
+
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max)
|
| 84 |
+
return io.NodeOutput(sigmas)
|
| 85 |
+
|
| 86 |
+
get_sigmas = execute
|
| 87 |
+
|
| 88 |
+
class PolyexponentialScheduler(io.ComfyNode):
|
| 89 |
+
@classmethod
|
| 90 |
+
def define_schema(cls):
|
| 91 |
+
return io.Schema(
|
| 92 |
+
node_id="PolyexponentialScheduler",
|
| 93 |
+
category="sampling/custom_sampling/schedulers",
|
| 94 |
+
inputs=[
|
| 95 |
+
io.Int.Input("steps", default=20, min=1, max=10000),
|
| 96 |
+
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
| 97 |
+
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
| 98 |
+
io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 99 |
+
],
|
| 100 |
+
outputs=[io.Sigmas.Output()]
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
@classmethod
|
| 104 |
+
def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput:
|
| 105 |
+
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
|
| 106 |
+
return io.NodeOutput(sigmas)
|
| 107 |
+
|
| 108 |
+
get_sigmas = execute
|
| 109 |
+
|
| 110 |
+
class LaplaceScheduler(io.ComfyNode):
|
| 111 |
+
@classmethod
|
| 112 |
+
def define_schema(cls):
|
| 113 |
+
return io.Schema(
|
| 114 |
+
node_id="LaplaceScheduler",
|
| 115 |
+
category="sampling/custom_sampling/schedulers",
|
| 116 |
+
inputs=[
|
| 117 |
+
io.Int.Input("steps", default=20, min=1, max=10000),
|
| 118 |
+
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
| 119 |
+
io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
| 120 |
+
io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False, advanced=True),
|
| 121 |
+
io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False, advanced=True),
|
| 122 |
+
],
|
| 123 |
+
outputs=[io.Sigmas.Output()]
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
@classmethod
|
| 127 |
+
def execute(cls, steps, sigma_max, sigma_min, mu, beta) -> io.NodeOutput:
|
| 128 |
+
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
|
| 129 |
+
return io.NodeOutput(sigmas)
|
| 130 |
+
|
| 131 |
+
get_sigmas = execute
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class SDTurboScheduler(io.ComfyNode):
|
| 135 |
+
@classmethod
|
| 136 |
+
def define_schema(cls):
|
| 137 |
+
return io.Schema(
|
| 138 |
+
node_id="SDTurboScheduler",
|
| 139 |
+
category="sampling/custom_sampling/schedulers",
|
| 140 |
+
inputs=[
|
| 141 |
+
io.Model.Input("model"),
|
| 142 |
+
io.Int.Input("steps", default=1, min=1, max=10),
|
| 143 |
+
io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01),
|
| 144 |
+
],
|
| 145 |
+
outputs=[io.Sigmas.Output()]
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
@classmethod
|
| 149 |
+
def execute(cls, model, steps, denoise) -> io.NodeOutput:
|
| 150 |
+
start_step = 10 - int(10 * denoise)
|
| 151 |
+
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
|
| 152 |
+
sigmas = model.get_model_object("model_sampling").sigma(timesteps)
|
| 153 |
+
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
| 154 |
+
return io.NodeOutput(sigmas)
|
| 155 |
+
|
| 156 |
+
get_sigmas = execute
|
| 157 |
+
|
| 158 |
+
class BetaSamplingScheduler(io.ComfyNode):
|
| 159 |
+
@classmethod
|
| 160 |
+
def define_schema(cls):
|
| 161 |
+
return io.Schema(
|
| 162 |
+
node_id="BetaSamplingScheduler",
|
| 163 |
+
category="sampling/custom_sampling/schedulers",
|
| 164 |
+
inputs=[
|
| 165 |
+
io.Model.Input("model"),
|
| 166 |
+
io.Int.Input("steps", default=20, min=1, max=10000),
|
| 167 |
+
io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True),
|
| 168 |
+
io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True),
|
| 169 |
+
],
|
| 170 |
+
outputs=[io.Sigmas.Output()]
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
@classmethod
|
| 174 |
+
def execute(cls, model, steps, alpha, beta) -> io.NodeOutput:
|
| 175 |
+
sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
|
| 176 |
+
return io.NodeOutput(sigmas)
|
| 177 |
+
|
| 178 |
+
get_sigmas = execute
|
| 179 |
+
|
| 180 |
+
class VPScheduler(io.ComfyNode):
|
| 181 |
+
@classmethod
|
| 182 |
+
def define_schema(cls):
|
| 183 |
+
return io.Schema(
|
| 184 |
+
node_id="VPScheduler",
|
| 185 |
+
category="sampling/custom_sampling/schedulers",
|
| 186 |
+
inputs=[
|
| 187 |
+
io.Int.Input("steps", default=20, min=1, max=10000),
|
| 188 |
+
io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), #TODO: fix default values
|
| 189 |
+
io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
| 190 |
+
io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
|
| 191 |
+
],
|
| 192 |
+
outputs=[io.Sigmas.Output()]
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
@classmethod
|
| 196 |
+
def execute(cls, steps, beta_d, beta_min, eps_s) -> io.NodeOutput:
|
| 197 |
+
sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
|
| 198 |
+
return io.NodeOutput(sigmas)
|
| 199 |
+
|
| 200 |
+
get_sigmas = execute
|
| 201 |
+
|
| 202 |
+
class SplitSigmas(io.ComfyNode):
|
| 203 |
+
@classmethod
|
| 204 |
+
def define_schema(cls):
|
| 205 |
+
return io.Schema(
|
| 206 |
+
node_id="SplitSigmas",
|
| 207 |
+
category="sampling/custom_sampling/sigmas",
|
| 208 |
+
inputs=[
|
| 209 |
+
io.Sigmas.Input("sigmas"),
|
| 210 |
+
io.Int.Input("step", default=0, min=0, max=10000),
|
| 211 |
+
],
|
| 212 |
+
outputs=[
|
| 213 |
+
io.Sigmas.Output(display_name="high_sigmas"),
|
| 214 |
+
io.Sigmas.Output(display_name="low_sigmas"),
|
| 215 |
+
]
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
@classmethod
|
| 219 |
+
def execute(cls, sigmas, step) -> io.NodeOutput:
|
| 220 |
+
sigmas1 = sigmas[:step + 1]
|
| 221 |
+
sigmas2 = sigmas[step:]
|
| 222 |
+
return io.NodeOutput(sigmas1, sigmas2)
|
| 223 |
+
|
| 224 |
+
get_sigmas = execute
|
| 225 |
+
|
| 226 |
+
class SplitSigmasDenoise(io.ComfyNode):
|
| 227 |
+
@classmethod
|
| 228 |
+
def define_schema(cls):
|
| 229 |
+
return io.Schema(
|
| 230 |
+
node_id="SplitSigmasDenoise",
|
| 231 |
+
category="sampling/custom_sampling/sigmas",
|
| 232 |
+
inputs=[
|
| 233 |
+
io.Sigmas.Input("sigmas"),
|
| 234 |
+
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
| 235 |
+
],
|
| 236 |
+
outputs=[
|
| 237 |
+
io.Sigmas.Output(display_name="high_sigmas"),
|
| 238 |
+
io.Sigmas.Output(display_name="low_sigmas"),
|
| 239 |
+
]
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
@classmethod
|
| 243 |
+
def execute(cls, sigmas, denoise) -> io.NodeOutput:
|
| 244 |
+
steps = max(sigmas.shape[-1] - 1, 0)
|
| 245 |
+
total_steps = round(steps * denoise)
|
| 246 |
+
sigmas1 = sigmas[:-(total_steps)]
|
| 247 |
+
sigmas2 = sigmas[-(total_steps + 1):]
|
| 248 |
+
return io.NodeOutput(sigmas1, sigmas2)
|
| 249 |
+
|
| 250 |
+
get_sigmas = execute
|
| 251 |
+
|
| 252 |
+
class FlipSigmas(io.ComfyNode):
|
| 253 |
+
@classmethod
|
| 254 |
+
def define_schema(cls):
|
| 255 |
+
return io.Schema(
|
| 256 |
+
node_id="FlipSigmas",
|
| 257 |
+
category="sampling/custom_sampling/sigmas",
|
| 258 |
+
inputs=[io.Sigmas.Input("sigmas")],
|
| 259 |
+
outputs=[io.Sigmas.Output()]
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
@classmethod
|
| 263 |
+
def execute(cls, sigmas) -> io.NodeOutput:
|
| 264 |
+
if len(sigmas) == 0:
|
| 265 |
+
return io.NodeOutput(sigmas)
|
| 266 |
+
|
| 267 |
+
sigmas = sigmas.flip(0)
|
| 268 |
+
if sigmas[0] == 0:
|
| 269 |
+
sigmas[0] = 0.0001
|
| 270 |
+
return io.NodeOutput(sigmas)
|
| 271 |
+
|
| 272 |
+
get_sigmas = execute
|
| 273 |
+
|
| 274 |
+
class SetFirstSigma(io.ComfyNode):
|
| 275 |
+
@classmethod
|
| 276 |
+
def define_schema(cls):
|
| 277 |
+
return io.Schema(
|
| 278 |
+
node_id="SetFirstSigma",
|
| 279 |
+
category="sampling/custom_sampling/sigmas",
|
| 280 |
+
inputs=[
|
| 281 |
+
io.Sigmas.Input("sigmas"),
|
| 282 |
+
io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False),
|
| 283 |
+
],
|
| 284 |
+
outputs=[io.Sigmas.Output()]
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
@classmethod
|
| 288 |
+
def execute(cls, sigmas, sigma) -> io.NodeOutput:
|
| 289 |
+
sigmas = sigmas.clone()
|
| 290 |
+
sigmas[0] = sigma
|
| 291 |
+
return io.NodeOutput(sigmas)
|
| 292 |
+
|
| 293 |
+
set_first_sigma = execute
|
| 294 |
+
|
| 295 |
+
class ExtendIntermediateSigmas(io.ComfyNode):
|
| 296 |
+
@classmethod
|
| 297 |
+
def define_schema(cls):
|
| 298 |
+
return io.Schema(
|
| 299 |
+
node_id="ExtendIntermediateSigmas",
|
| 300 |
+
search_aliases=["interpolate sigmas"],
|
| 301 |
+
category="sampling/custom_sampling/sigmas",
|
| 302 |
+
inputs=[
|
| 303 |
+
io.Sigmas.Input("sigmas"),
|
| 304 |
+
io.Int.Input("steps", default=2, min=1, max=100),
|
| 305 |
+
io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False),
|
| 306 |
+
io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False),
|
| 307 |
+
io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']),
|
| 308 |
+
],
|
| 309 |
+
outputs=[io.Sigmas.Output()]
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
@classmethod
|
| 313 |
+
def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str) -> io.NodeOutput:
|
| 314 |
+
if start_at_sigma < 0:
|
| 315 |
+
start_at_sigma = float("inf")
|
| 316 |
+
|
| 317 |
+
interpolator = {
|
| 318 |
+
'linear': lambda x: x,
|
| 319 |
+
'cosine': lambda x: torch.sin(x*math.pi/2),
|
| 320 |
+
'sine': lambda x: 1 - torch.cos(x*math.pi/2)
|
| 321 |
+
}[spacing]
|
| 322 |
+
|
| 323 |
+
# linear space for our interpolation function
|
| 324 |
+
x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1]
|
| 325 |
+
computed_spacing = interpolator(x)
|
| 326 |
+
|
| 327 |
+
extended_sigmas = []
|
| 328 |
+
for i in range(len(sigmas) - 1):
|
| 329 |
+
sigma_current = sigmas[i]
|
| 330 |
+
sigma_next = sigmas[i+1]
|
| 331 |
+
|
| 332 |
+
extended_sigmas.append(sigma_current)
|
| 333 |
+
|
| 334 |
+
if end_at_sigma <= sigma_current <= start_at_sigma:
|
| 335 |
+
interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current
|
| 336 |
+
extended_sigmas.extend(interpolated_steps.tolist())
|
| 337 |
+
|
| 338 |
+
# Add the last sigma value
|
| 339 |
+
if len(sigmas) > 0:
|
| 340 |
+
extended_sigmas.append(sigmas[-1])
|
| 341 |
+
|
| 342 |
+
extended_sigmas = torch.FloatTensor(extended_sigmas)
|
| 343 |
+
|
| 344 |
+
return io.NodeOutput(extended_sigmas)
|
| 345 |
+
|
| 346 |
+
extend = execute
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class SamplingPercentToSigma(io.ComfyNode):
|
| 350 |
+
@classmethod
|
| 351 |
+
def define_schema(cls):
|
| 352 |
+
return io.Schema(
|
| 353 |
+
node_id="SamplingPercentToSigma",
|
| 354 |
+
category="sampling/custom_sampling/sigmas",
|
| 355 |
+
inputs=[
|
| 356 |
+
io.Model.Input("model"),
|
| 357 |
+
io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001),
|
| 358 |
+
io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."),
|
| 359 |
+
],
|
| 360 |
+
outputs=[io.Float.Output(display_name="sigma_value")]
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
@classmethod
|
| 364 |
+
def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput:
|
| 365 |
+
model_sampling = model.get_model_object("model_sampling")
|
| 366 |
+
sigma_val = model_sampling.percent_to_sigma(sampling_percent)
|
| 367 |
+
if return_actual_sigma:
|
| 368 |
+
if sampling_percent == 0.0:
|
| 369 |
+
sigma_val = model_sampling.sigma_max.item()
|
| 370 |
+
elif sampling_percent == 1.0:
|
| 371 |
+
sigma_val = model_sampling.sigma_min.item()
|
| 372 |
+
return io.NodeOutput(sigma_val)
|
| 373 |
+
|
| 374 |
+
get_sigma = execute
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class KSamplerSelect(io.ComfyNode):
|
| 378 |
+
@classmethod
|
| 379 |
+
def define_schema(cls):
|
| 380 |
+
return io.Schema(
|
| 381 |
+
node_id="KSamplerSelect",
|
| 382 |
+
category="sampling/custom_sampling/samplers",
|
| 383 |
+
inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)],
|
| 384 |
+
outputs=[io.Sampler.Output()]
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
@classmethod
|
| 388 |
+
def execute(cls, sampler_name) -> io.NodeOutput:
|
| 389 |
+
sampler = comfy.samplers.sampler_object(sampler_name)
|
| 390 |
+
return io.NodeOutput(sampler)
|
| 391 |
+
|
| 392 |
+
get_sampler = execute
|
| 393 |
+
|
| 394 |
+
class SamplerDPMPP_3M_SDE(io.ComfyNode):
|
| 395 |
+
@classmethod
|
| 396 |
+
def define_schema(cls):
|
| 397 |
+
return io.Schema(
|
| 398 |
+
node_id="SamplerDPMPP_3M_SDE",
|
| 399 |
+
category="sampling/custom_sampling/samplers",
|
| 400 |
+
inputs=[
|
| 401 |
+
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 402 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 403 |
+
io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True),
|
| 404 |
+
],
|
| 405 |
+
outputs=[io.Sampler.Output()]
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
@classmethod
|
| 409 |
+
def execute(cls, eta, s_noise, noise_device) -> io.NodeOutput:
|
| 410 |
+
if noise_device == 'cpu':
|
| 411 |
+
sampler_name = "dpmpp_3m_sde"
|
| 412 |
+
else:
|
| 413 |
+
sampler_name = "dpmpp_3m_sde_gpu"
|
| 414 |
+
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise})
|
| 415 |
+
return io.NodeOutput(sampler)
|
| 416 |
+
|
| 417 |
+
get_sampler = execute
|
| 418 |
+
|
| 419 |
+
class SamplerDPMPP_2M_SDE(io.ComfyNode):
|
| 420 |
+
@classmethod
|
| 421 |
+
def define_schema(cls):
|
| 422 |
+
return io.Schema(
|
| 423 |
+
node_id="SamplerDPMPP_2M_SDE",
|
| 424 |
+
category="sampling/custom_sampling/samplers",
|
| 425 |
+
inputs=[
|
| 426 |
+
io.Combo.Input("solver_type", options=['midpoint', 'heun']),
|
| 427 |
+
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 428 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 429 |
+
io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True),
|
| 430 |
+
],
|
| 431 |
+
outputs=[io.Sampler.Output()]
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
@classmethod
|
| 435 |
+
def execute(cls, solver_type, eta, s_noise, noise_device) -> io.NodeOutput:
|
| 436 |
+
if noise_device == 'cpu':
|
| 437 |
+
sampler_name = "dpmpp_2m_sde"
|
| 438 |
+
else:
|
| 439 |
+
sampler_name = "dpmpp_2m_sde_gpu"
|
| 440 |
+
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
|
| 441 |
+
return io.NodeOutput(sampler)
|
| 442 |
+
|
| 443 |
+
get_sampler = execute
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class SamplerDPMPP_SDE(io.ComfyNode):
|
| 447 |
+
@classmethod
|
| 448 |
+
def define_schema(cls):
|
| 449 |
+
return io.Schema(
|
| 450 |
+
node_id="SamplerDPMPP_SDE",
|
| 451 |
+
category="sampling/custom_sampling/samplers",
|
| 452 |
+
inputs=[
|
| 453 |
+
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 454 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 455 |
+
io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 456 |
+
io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True),
|
| 457 |
+
],
|
| 458 |
+
outputs=[io.Sampler.Output()]
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
@classmethod
|
| 462 |
+
def execute(cls, eta, s_noise, r, noise_device) -> io.NodeOutput:
|
| 463 |
+
if noise_device == 'cpu':
|
| 464 |
+
sampler_name = "dpmpp_sde"
|
| 465 |
+
else:
|
| 466 |
+
sampler_name = "dpmpp_sde_gpu"
|
| 467 |
+
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
| 468 |
+
return io.NodeOutput(sampler)
|
| 469 |
+
|
| 470 |
+
get_sampler = execute
|
| 471 |
+
|
| 472 |
+
class SamplerDPMPP_2S_Ancestral(io.ComfyNode):
|
| 473 |
+
@classmethod
|
| 474 |
+
def define_schema(cls):
|
| 475 |
+
return io.Schema(
|
| 476 |
+
node_id="SamplerDPMPP_2S_Ancestral",
|
| 477 |
+
category="sampling/custom_sampling/samplers",
|
| 478 |
+
inputs=[
|
| 479 |
+
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
|
| 480 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
|
| 481 |
+
],
|
| 482 |
+
outputs=[io.Sampler.Output()]
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
@classmethod
|
| 486 |
+
def execute(cls, eta, s_noise) -> io.NodeOutput:
|
| 487 |
+
sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise})
|
| 488 |
+
return io.NodeOutput(sampler)
|
| 489 |
+
|
| 490 |
+
get_sampler = execute
|
| 491 |
+
|
| 492 |
+
class SamplerEulerAncestral(io.ComfyNode):
|
| 493 |
+
@classmethod
|
| 494 |
+
def define_schema(cls):
|
| 495 |
+
return io.Schema(
|
| 496 |
+
node_id="SamplerEulerAncestral",
|
| 497 |
+
category="sampling/custom_sampling/samplers",
|
| 498 |
+
inputs=[
|
| 499 |
+
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 500 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 501 |
+
],
|
| 502 |
+
outputs=[io.Sampler.Output()]
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
@classmethod
|
| 506 |
+
def execute(cls, eta, s_noise) -> io.NodeOutput:
|
| 507 |
+
sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
|
| 508 |
+
return io.NodeOutput(sampler)
|
| 509 |
+
|
| 510 |
+
get_sampler = execute
|
| 511 |
+
|
| 512 |
+
class SamplerEulerAncestralCFGPP(io.ComfyNode):
|
| 513 |
+
@classmethod
|
| 514 |
+
def define_schema(cls):
|
| 515 |
+
return io.Schema(
|
| 516 |
+
node_id="SamplerEulerAncestralCFGPP",
|
| 517 |
+
display_name="SamplerEulerAncestralCFG++",
|
| 518 |
+
category="sampling/custom_sampling/samplers",
|
| 519 |
+
inputs=[
|
| 520 |
+
io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False),
|
| 521 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
|
| 522 |
+
],
|
| 523 |
+
outputs=[io.Sampler.Output()]
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
@classmethod
|
| 527 |
+
def execute(cls, eta, s_noise) -> io.NodeOutput:
|
| 528 |
+
sampler = comfy.samplers.ksampler(
|
| 529 |
+
"euler_ancestral_cfg_pp",
|
| 530 |
+
{"eta": eta, "s_noise": s_noise})
|
| 531 |
+
return io.NodeOutput(sampler)
|
| 532 |
+
|
| 533 |
+
get_sampler = execute
|
| 534 |
+
|
| 535 |
+
class SamplerLMS(io.ComfyNode):
|
| 536 |
+
@classmethod
|
| 537 |
+
def define_schema(cls):
|
| 538 |
+
return io.Schema(
|
| 539 |
+
node_id="SamplerLMS",
|
| 540 |
+
category="sampling/custom_sampling/samplers",
|
| 541 |
+
inputs=[io.Int.Input("order", default=4, min=1, max=100, advanced=True)],
|
| 542 |
+
outputs=[io.Sampler.Output()]
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
@classmethod
|
| 546 |
+
def execute(cls, order) -> io.NodeOutput:
|
| 547 |
+
sampler = comfy.samplers.ksampler("lms", {"order": order})
|
| 548 |
+
return io.NodeOutput(sampler)
|
| 549 |
+
|
| 550 |
+
get_sampler = execute
|
| 551 |
+
|
| 552 |
+
class SamplerDPMAdaptative(io.ComfyNode):
|
| 553 |
+
@classmethod
|
| 554 |
+
def define_schema(cls):
|
| 555 |
+
return io.Schema(
|
| 556 |
+
node_id="SamplerDPMAdaptative",
|
| 557 |
+
category="sampling/custom_sampling/samplers",
|
| 558 |
+
inputs=[
|
| 559 |
+
io.Int.Input("order", default=3, min=2, max=3, advanced=True),
|
| 560 |
+
io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 561 |
+
io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 562 |
+
io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 563 |
+
io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 564 |
+
io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 565 |
+
io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 566 |
+
io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 567 |
+
io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 568 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 569 |
+
],
|
| 570 |
+
outputs=[io.Sampler.Output()]
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
@classmethod
|
| 574 |
+
def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) -> io.NodeOutput:
|
| 575 |
+
sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff,
|
| 576 |
+
"icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
|
| 577 |
+
"s_noise":s_noise })
|
| 578 |
+
return io.NodeOutput(sampler)
|
| 579 |
+
|
| 580 |
+
get_sampler = execute
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
class SamplerER_SDE(io.ComfyNode):
|
| 584 |
+
@classmethod
|
| 585 |
+
def define_schema(cls):
|
| 586 |
+
return io.Schema(
|
| 587 |
+
node_id="SamplerER_SDE",
|
| 588 |
+
category="sampling/custom_sampling/samplers",
|
| 589 |
+
inputs=[
|
| 590 |
+
io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
|
| 591 |
+
io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True),
|
| 592 |
+
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type.", advanced=True),
|
| 593 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 594 |
+
],
|
| 595 |
+
outputs=[io.Sampler.Output()]
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
@classmethod
|
| 599 |
+
def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput:
|
| 600 |
+
if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
|
| 601 |
+
eta = 0
|
| 602 |
+
s_noise = 0
|
| 603 |
+
|
| 604 |
+
def reverse_time_sde_noise_scaler(x):
|
| 605 |
+
return x ** (eta + 1)
|
| 606 |
+
|
| 607 |
+
if solver_type == "ER-SDE":
|
| 608 |
+
# Use the default one in sample_er_sde()
|
| 609 |
+
noise_scaler = None
|
| 610 |
+
else:
|
| 611 |
+
noise_scaler = reverse_time_sde_noise_scaler
|
| 612 |
+
|
| 613 |
+
sampler_name = "er_sde"
|
| 614 |
+
sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
|
| 615 |
+
return io.NodeOutput(sampler)
|
| 616 |
+
|
| 617 |
+
get_sampler = execute
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
class SamplerSASolver(io.ComfyNode):
|
| 621 |
+
@classmethod
|
| 622 |
+
def define_schema(cls):
|
| 623 |
+
return io.Schema(
|
| 624 |
+
node_id="SamplerSASolver",
|
| 625 |
+
search_aliases=["sde"],
|
| 626 |
+
category="sampling/custom_sampling/samplers",
|
| 627 |
+
inputs=[
|
| 628 |
+
io.Model.Input("model"),
|
| 629 |
+
io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False, advanced=True),
|
| 630 |
+
io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001, advanced=True),
|
| 631 |
+
io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001, advanced=True),
|
| 632 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
| 633 |
+
io.Int.Input("predictor_order", default=3, min=1, max=6, advanced=True),
|
| 634 |
+
io.Int.Input("corrector_order", default=4, min=0, max=6, advanced=True),
|
| 635 |
+
io.Boolean.Input("use_pece", advanced=True),
|
| 636 |
+
io.Boolean.Input("simple_order_2", advanced=True),
|
| 637 |
+
],
|
| 638 |
+
outputs=[io.Sampler.Output()]
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
@classmethod
|
| 642 |
+
def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput:
|
| 643 |
+
model_sampling = model.get_model_object("model_sampling")
|
| 644 |
+
start_sigma = model_sampling.percent_to_sigma(sde_start_percent)
|
| 645 |
+
end_sigma = model_sampling.percent_to_sigma(sde_end_percent)
|
| 646 |
+
tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta)
|
| 647 |
+
|
| 648 |
+
sampler_name = "sa_solver"
|
| 649 |
+
sampler = comfy.samplers.ksampler(
|
| 650 |
+
sampler_name,
|
| 651 |
+
{
|
| 652 |
+
"tau_func": tau_func,
|
| 653 |
+
"s_noise": s_noise,
|
| 654 |
+
"predictor_order": predictor_order,
|
| 655 |
+
"corrector_order": corrector_order,
|
| 656 |
+
"use_pece": use_pece,
|
| 657 |
+
"simple_order_2": simple_order_2,
|
| 658 |
+
},
|
| 659 |
+
)
|
| 660 |
+
return io.NodeOutput(sampler)
|
| 661 |
+
|
| 662 |
+
get_sampler = execute
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
class SamplerSEEDS2(io.ComfyNode):
|
| 666 |
+
@classmethod
|
| 667 |
+
def define_schema(cls):
|
| 668 |
+
return io.Schema(
|
| 669 |
+
node_id="SamplerSEEDS2",
|
| 670 |
+
search_aliases=["sde", "exp heun"],
|
| 671 |
+
category="sampling/custom_sampling/samplers",
|
| 672 |
+
inputs=[
|
| 673 |
+
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
| 674 |
+
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength", advanced=True),
|
| 675 |
+
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier", advanced=True),
|
| 676 |
+
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)", advanced=True),
|
| 677 |
+
],
|
| 678 |
+
outputs=[io.Sampler.Output()],
|
| 679 |
+
description=(
|
| 680 |
+
"This sampler node can represent multiple samplers:\n\n"
|
| 681 |
+
"seeds_2\n"
|
| 682 |
+
"- default setting\n\n"
|
| 683 |
+
"exp_heun_2_x0\n"
|
| 684 |
+
"- solver_type=phi_2, r=1.0, eta=0.0\n\n"
|
| 685 |
+
"exp_heun_2_x0_sde\n"
|
| 686 |
+
"- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0"
|
| 687 |
+
)
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
@classmethod
|
| 691 |
+
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
|
| 692 |
+
sampler_name = "seeds_2"
|
| 693 |
+
sampler = comfy.samplers.ksampler(
|
| 694 |
+
sampler_name,
|
| 695 |
+
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
|
| 696 |
+
)
|
| 697 |
+
return io.NodeOutput(sampler)
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
class Noise_EmptyNoise:
|
| 701 |
+
def __init__(self):
|
| 702 |
+
self.seed = 0
|
| 703 |
+
|
| 704 |
+
def generate_noise(self, input_latent):
|
| 705 |
+
latent_image = input_latent["samples"]
|
| 706 |
+
if latent_image.is_nested:
|
| 707 |
+
tensors = latent_image.unbind()
|
| 708 |
+
zeros = []
|
| 709 |
+
for t in tensors:
|
| 710 |
+
zeros.append(torch.zeros(t.shape, dtype=t.dtype, layout=t.layout, device="cpu"))
|
| 711 |
+
return comfy.nested_tensor.NestedTensor(zeros)
|
| 712 |
+
else:
|
| 713 |
+
return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
class Noise_RandomNoise:
|
| 717 |
+
def __init__(self, seed):
|
| 718 |
+
self.seed = seed
|
| 719 |
+
|
| 720 |
+
def generate_noise(self, input_latent):
|
| 721 |
+
latent_image = input_latent["samples"]
|
| 722 |
+
batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
|
| 723 |
+
return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
|
| 724 |
+
|
| 725 |
+
class SamplerCustom(io.ComfyNode):
|
| 726 |
+
@classmethod
|
| 727 |
+
def define_schema(cls):
|
| 728 |
+
return io.Schema(
|
| 729 |
+
node_id="SamplerCustom",
|
| 730 |
+
category="sampling/custom_sampling",
|
| 731 |
+
inputs=[
|
| 732 |
+
io.Model.Input("model"),
|
| 733 |
+
io.Boolean.Input("add_noise", default=True, advanced=True),
|
| 734 |
+
io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
| 735 |
+
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
| 736 |
+
io.Conditioning.Input("positive"),
|
| 737 |
+
io.Conditioning.Input("negative"),
|
| 738 |
+
io.Sampler.Input("sampler"),
|
| 739 |
+
io.Sigmas.Input("sigmas"),
|
| 740 |
+
io.Latent.Input("latent_image"),
|
| 741 |
+
],
|
| 742 |
+
outputs=[
|
| 743 |
+
io.Latent.Output(display_name="output"),
|
| 744 |
+
io.Latent.Output(display_name="denoised_output"),
|
| 745 |
+
]
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
@classmethod
|
| 749 |
+
def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput:
|
| 750 |
+
latent = latent_image
|
| 751 |
+
latent_image = latent["samples"]
|
| 752 |
+
latent = latent.copy()
|
| 753 |
+
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
|
| 754 |
+
latent["samples"] = latent_image
|
| 755 |
+
|
| 756 |
+
if not add_noise:
|
| 757 |
+
noise = Noise_EmptyNoise().generate_noise(latent)
|
| 758 |
+
else:
|
| 759 |
+
noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
|
| 760 |
+
|
| 761 |
+
noise_mask = None
|
| 762 |
+
if "noise_mask" in latent:
|
| 763 |
+
noise_mask = latent["noise_mask"]
|
| 764 |
+
|
| 765 |
+
x0_output = {}
|
| 766 |
+
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
|
| 767 |
+
|
| 768 |
+
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
| 769 |
+
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
| 770 |
+
|
| 771 |
+
out = latent.copy()
|
| 772 |
+
out.pop("downscale_ratio_spacial", None)
|
| 773 |
+
out["samples"] = samples
|
| 774 |
+
if "x0" in x0_output:
|
| 775 |
+
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
|
| 776 |
+
if samples.is_nested:
|
| 777 |
+
latent_shapes = [x.shape for x in samples.unbind()]
|
| 778 |
+
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
|
| 779 |
+
out_denoised = latent.copy()
|
| 780 |
+
out_denoised["samples"] = x0_out
|
| 781 |
+
else:
|
| 782 |
+
out_denoised = out
|
| 783 |
+
return io.NodeOutput(out, out_denoised)
|
| 784 |
+
|
| 785 |
+
sample = execute
|
| 786 |
+
|
| 787 |
+
class Guider_Basic(comfy.samplers.CFGGuider):
|
| 788 |
+
def set_conds(self, positive):
|
| 789 |
+
self.inner_set_conds({"positive": positive})
|
| 790 |
+
|
| 791 |
+
class BasicGuider(io.ComfyNode):
|
| 792 |
+
@classmethod
|
| 793 |
+
def define_schema(cls):
|
| 794 |
+
return io.Schema(
|
| 795 |
+
node_id="BasicGuider",
|
| 796 |
+
category="sampling/custom_sampling/guiders",
|
| 797 |
+
inputs=[
|
| 798 |
+
io.Model.Input("model"),
|
| 799 |
+
io.Conditioning.Input("conditioning"),
|
| 800 |
+
],
|
| 801 |
+
outputs=[io.Guider.Output()]
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
@classmethod
|
| 805 |
+
def execute(cls, model, conditioning) -> io.NodeOutput:
|
| 806 |
+
guider = Guider_Basic(model)
|
| 807 |
+
guider.set_conds(conditioning)
|
| 808 |
+
return io.NodeOutput(guider)
|
| 809 |
+
|
| 810 |
+
get_guider = execute
|
| 811 |
+
|
| 812 |
+
class CFGGuider(io.ComfyNode):
|
| 813 |
+
@classmethod
|
| 814 |
+
def define_schema(cls):
|
| 815 |
+
return io.Schema(
|
| 816 |
+
node_id="CFGGuider",
|
| 817 |
+
category="sampling/custom_sampling/guiders",
|
| 818 |
+
inputs=[
|
| 819 |
+
io.Model.Input("model"),
|
| 820 |
+
io.Conditioning.Input("positive"),
|
| 821 |
+
io.Conditioning.Input("negative"),
|
| 822 |
+
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
| 823 |
+
],
|
| 824 |
+
outputs=[io.Guider.Output()]
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
@classmethod
|
| 828 |
+
def execute(cls, model, positive, negative, cfg) -> io.NodeOutput:
|
| 829 |
+
guider = comfy.samplers.CFGGuider(model)
|
| 830 |
+
guider.set_conds(positive, negative)
|
| 831 |
+
guider.set_cfg(cfg)
|
| 832 |
+
return io.NodeOutput(guider)
|
| 833 |
+
|
| 834 |
+
get_guider = execute
|
| 835 |
+
|
| 836 |
+
class Guider_DualCFG(comfy.samplers.CFGGuider):
|
| 837 |
+
def set_cfg(self, cfg1, cfg2, nested=False):
|
| 838 |
+
self.cfg1 = cfg1
|
| 839 |
+
self.cfg2 = cfg2
|
| 840 |
+
self.nested = nested
|
| 841 |
+
|
| 842 |
+
def set_conds(self, positive, middle, negative):
|
| 843 |
+
middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
|
| 844 |
+
self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})
|
| 845 |
+
|
| 846 |
+
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
| 847 |
+
negative_cond = self.conds.get("negative", None)
|
| 848 |
+
middle_cond = self.conds.get("middle", None)
|
| 849 |
+
positive_cond = self.conds.get("positive", None)
|
| 850 |
+
|
| 851 |
+
if self.nested:
|
| 852 |
+
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
| 853 |
+
pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond)
|
| 854 |
+
return out[0] + self.cfg2 * (pred_text - out[0])
|
| 855 |
+
else:
|
| 856 |
+
if model_options.get("disable_cfg1_optimization", False) == False:
|
| 857 |
+
if math.isclose(self.cfg2, 1.0):
|
| 858 |
+
negative_cond = None
|
| 859 |
+
if math.isclose(self.cfg1, 1.0):
|
| 860 |
+
middle_cond = None
|
| 861 |
+
|
| 862 |
+
out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
|
| 863 |
+
return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
|
| 864 |
+
|
| 865 |
+
class DualCFGGuider(io.ComfyNode):
|
| 866 |
+
@classmethod
|
| 867 |
+
def define_schema(cls):
|
| 868 |
+
return io.Schema(
|
| 869 |
+
node_id="DualCFGGuider",
|
| 870 |
+
search_aliases=["dual prompt guidance"],
|
| 871 |
+
category="sampling/custom_sampling/guiders",
|
| 872 |
+
inputs=[
|
| 873 |
+
io.Model.Input("model"),
|
| 874 |
+
io.Conditioning.Input("cond1"),
|
| 875 |
+
io.Conditioning.Input("cond2"),
|
| 876 |
+
io.Conditioning.Input("negative"),
|
| 877 |
+
io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
| 878 |
+
io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
| 879 |
+
io.Combo.Input("style", options=["regular", "nested"]),
|
| 880 |
+
],
|
| 881 |
+
outputs=[io.Guider.Output()]
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
@classmethod
|
| 885 |
+
def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput:
|
| 886 |
+
guider = Guider_DualCFG(model)
|
| 887 |
+
guider.set_conds(cond1, cond2, negative)
|
| 888 |
+
guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
|
| 889 |
+
return io.NodeOutput(guider)
|
| 890 |
+
|
| 891 |
+
get_guider = execute
|
| 892 |
+
|
| 893 |
+
class DisableNoise(io.ComfyNode):
|
| 894 |
+
@classmethod
|
| 895 |
+
def define_schema(cls):
|
| 896 |
+
return io.Schema(
|
| 897 |
+
node_id="DisableNoise",
|
| 898 |
+
search_aliases=["zero noise"],
|
| 899 |
+
category="sampling/custom_sampling/noise",
|
| 900 |
+
inputs=[],
|
| 901 |
+
outputs=[io.Noise.Output()]
|
| 902 |
+
)
|
| 903 |
+
|
| 904 |
+
@classmethod
|
| 905 |
+
def execute(cls) -> io.NodeOutput:
|
| 906 |
+
return io.NodeOutput(Noise_EmptyNoise())
|
| 907 |
+
|
| 908 |
+
get_noise = execute
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
class RandomNoise(io.ComfyNode):
|
| 912 |
+
@classmethod
|
| 913 |
+
def define_schema(cls):
|
| 914 |
+
return io.Schema(
|
| 915 |
+
node_id="RandomNoise",
|
| 916 |
+
category="sampling/custom_sampling/noise",
|
| 917 |
+
inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)],
|
| 918 |
+
outputs=[io.Noise.Output()]
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
@classmethod
|
| 922 |
+
def execute(cls, noise_seed) -> io.NodeOutput:
|
| 923 |
+
return io.NodeOutput(Noise_RandomNoise(noise_seed))
|
| 924 |
+
|
| 925 |
+
get_noise = execute
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
class SamplerCustomAdvanced(io.ComfyNode):
|
| 929 |
+
@classmethod
|
| 930 |
+
def define_schema(cls):
|
| 931 |
+
return io.Schema(
|
| 932 |
+
node_id="SamplerCustomAdvanced",
|
| 933 |
+
category="sampling/custom_sampling",
|
| 934 |
+
inputs=[
|
| 935 |
+
io.Noise.Input("noise"),
|
| 936 |
+
io.Guider.Input("guider"),
|
| 937 |
+
io.Sampler.Input("sampler"),
|
| 938 |
+
io.Sigmas.Input("sigmas"),
|
| 939 |
+
io.Latent.Input("latent_image"),
|
| 940 |
+
],
|
| 941 |
+
outputs=[
|
| 942 |
+
io.Latent.Output(display_name="output"),
|
| 943 |
+
io.Latent.Output(display_name="denoised_output"),
|
| 944 |
+
]
|
| 945 |
+
)
|
| 946 |
+
|
| 947 |
+
@classmethod
|
| 948 |
+
def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
|
| 949 |
+
latent = latent_image
|
| 950 |
+
latent_image = latent["samples"]
|
| 951 |
+
latent = latent.copy()
|
| 952 |
+
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None))
|
| 953 |
+
latent["samples"] = latent_image
|
| 954 |
+
|
| 955 |
+
noise_mask = None
|
| 956 |
+
if "noise_mask" in latent:
|
| 957 |
+
noise_mask = latent["noise_mask"]
|
| 958 |
+
|
| 959 |
+
x0_output = {}
|
| 960 |
+
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
|
| 961 |
+
|
| 962 |
+
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
| 963 |
+
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
|
| 964 |
+
samples = samples.to(comfy.model_management.intermediate_device())
|
| 965 |
+
|
| 966 |
+
out = latent.copy()
|
| 967 |
+
out.pop("downscale_ratio_spacial", None)
|
| 968 |
+
out["samples"] = samples
|
| 969 |
+
if "x0" in x0_output:
|
| 970 |
+
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
|
| 971 |
+
if samples.is_nested:
|
| 972 |
+
latent_shapes = [x.shape for x in samples.unbind()]
|
| 973 |
+
x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
|
| 974 |
+
out_denoised = latent.copy()
|
| 975 |
+
out_denoised["samples"] = x0_out
|
| 976 |
+
else:
|
| 977 |
+
out_denoised = out
|
| 978 |
+
return io.NodeOutput(out, out_denoised)
|
| 979 |
+
|
| 980 |
+
sample = execute
|
| 981 |
+
|
| 982 |
+
class AddNoise(io.ComfyNode):
|
| 983 |
+
@classmethod
|
| 984 |
+
def define_schema(cls):
|
| 985 |
+
return io.Schema(
|
| 986 |
+
node_id="AddNoise",
|
| 987 |
+
category="_for_testing/custom_sampling/noise",
|
| 988 |
+
is_experimental=True,
|
| 989 |
+
inputs=[
|
| 990 |
+
io.Model.Input("model"),
|
| 991 |
+
io.Noise.Input("noise"),
|
| 992 |
+
io.Sigmas.Input("sigmas"),
|
| 993 |
+
io.Latent.Input("latent_image"),
|
| 994 |
+
],
|
| 995 |
+
outputs=[
|
| 996 |
+
io.Latent.Output(),
|
| 997 |
+
]
|
| 998 |
+
)
|
| 999 |
+
|
| 1000 |
+
@classmethod
|
| 1001 |
+
def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput:
|
| 1002 |
+
if len(sigmas) == 0:
|
| 1003 |
+
return io.NodeOutput(latent_image)
|
| 1004 |
+
|
| 1005 |
+
latent = latent_image
|
| 1006 |
+
latent_image = latent["samples"]
|
| 1007 |
+
|
| 1008 |
+
noisy = noise.generate_noise(latent)
|
| 1009 |
+
|
| 1010 |
+
model_sampling = model.get_model_object("model_sampling")
|
| 1011 |
+
process_latent_out = model.get_model_object("process_latent_out")
|
| 1012 |
+
process_latent_in = model.get_model_object("process_latent_in")
|
| 1013 |
+
|
| 1014 |
+
if len(sigmas) > 1:
|
| 1015 |
+
scale = torch.abs(sigmas[0] - sigmas[-1])
|
| 1016 |
+
else:
|
| 1017 |
+
scale = sigmas[0]
|
| 1018 |
+
|
| 1019 |
+
if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
| 1020 |
+
latent_image = process_latent_in(latent_image)
|
| 1021 |
+
noisy = model_sampling.noise_scaling(scale, noisy, latent_image)
|
| 1022 |
+
noisy = process_latent_out(noisy)
|
| 1023 |
+
noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0)
|
| 1024 |
+
|
| 1025 |
+
out = latent.copy()
|
| 1026 |
+
out["samples"] = noisy
|
| 1027 |
+
return io.NodeOutput(out)
|
| 1028 |
+
|
| 1029 |
+
add_noise = execute
|
| 1030 |
+
|
| 1031 |
+
class ManualSigmas(io.ComfyNode):
|
| 1032 |
+
@classmethod
|
| 1033 |
+
def define_schema(cls):
|
| 1034 |
+
return io.Schema(
|
| 1035 |
+
node_id="ManualSigmas",
|
| 1036 |
+
search_aliases=["custom noise schedule", "define sigmas"],
|
| 1037 |
+
category="_for_testing/custom_sampling",
|
| 1038 |
+
is_experimental=True,
|
| 1039 |
+
inputs=[
|
| 1040 |
+
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
| 1041 |
+
],
|
| 1042 |
+
outputs=[io.Sigmas.Output()]
|
| 1043 |
+
)
|
| 1044 |
+
|
| 1045 |
+
@classmethod
|
| 1046 |
+
def execute(cls, sigmas) -> io.NodeOutput:
|
| 1047 |
+
sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas)
|
| 1048 |
+
sigmas = [float(i) for i in sigmas]
|
| 1049 |
+
sigmas = torch.FloatTensor(sigmas)
|
| 1050 |
+
return io.NodeOutput(sigmas)
|
| 1051 |
+
|
| 1052 |
+
class CustomSamplersExtension(ComfyExtension):
|
| 1053 |
+
@override
|
| 1054 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 1055 |
+
return [
|
| 1056 |
+
SamplerCustom,
|
| 1057 |
+
BasicScheduler,
|
| 1058 |
+
KarrasScheduler,
|
| 1059 |
+
ExponentialScheduler,
|
| 1060 |
+
PolyexponentialScheduler,
|
| 1061 |
+
LaplaceScheduler,
|
| 1062 |
+
VPScheduler,
|
| 1063 |
+
BetaSamplingScheduler,
|
| 1064 |
+
SDTurboScheduler,
|
| 1065 |
+
KSamplerSelect,
|
| 1066 |
+
SamplerEulerAncestral,
|
| 1067 |
+
SamplerEulerAncestralCFGPP,
|
| 1068 |
+
SamplerLMS,
|
| 1069 |
+
SamplerDPMPP_3M_SDE,
|
| 1070 |
+
SamplerDPMPP_2M_SDE,
|
| 1071 |
+
SamplerDPMPP_SDE,
|
| 1072 |
+
SamplerDPMPP_2S_Ancestral,
|
| 1073 |
+
SamplerDPMAdaptative,
|
| 1074 |
+
SamplerER_SDE,
|
| 1075 |
+
SamplerSASolver,
|
| 1076 |
+
SamplerSEEDS2,
|
| 1077 |
+
SplitSigmas,
|
| 1078 |
+
SplitSigmasDenoise,
|
| 1079 |
+
FlipSigmas,
|
| 1080 |
+
SetFirstSigma,
|
| 1081 |
+
ExtendIntermediateSigmas,
|
| 1082 |
+
SamplingPercentToSigma,
|
| 1083 |
+
CFGGuider,
|
| 1084 |
+
DualCFGGuider,
|
| 1085 |
+
BasicGuider,
|
| 1086 |
+
RandomNoise,
|
| 1087 |
+
DisableNoise,
|
| 1088 |
+
AddNoise,
|
| 1089 |
+
SamplerCustomAdvanced,
|
| 1090 |
+
ManualSigmas,
|
| 1091 |
+
]
|
| 1092 |
+
|
| 1093 |
+
|
| 1094 |
+
async def comfy_entrypoint() -> CustomSamplersExtension:
|
| 1095 |
+
return CustomSamplersExtension()
|
ComfyUI/comfy_extras/nodes_dataset.py
ADDED
|
@@ -0,0 +1,1537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from typing_extensions import override
|
| 9 |
+
|
| 10 |
+
import folder_paths
|
| 11 |
+
import node_helpers
|
| 12 |
+
from comfy_api.latest import ComfyExtension, io
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_and_process_images(image_files, input_dir):
|
| 16 |
+
"""Utility function to load and process a list of images.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
image_files: List of image filenames
|
| 20 |
+
input_dir: Base directory containing the images
|
| 21 |
+
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
torch.Tensor: Batch of processed images
|
| 25 |
+
"""
|
| 26 |
+
if not image_files:
|
| 27 |
+
raise ValueError("No valid images found in input")
|
| 28 |
+
|
| 29 |
+
output_images = []
|
| 30 |
+
|
| 31 |
+
for file in image_files:
|
| 32 |
+
image_path = os.path.join(input_dir, file)
|
| 33 |
+
img = node_helpers.pillow(Image.open, image_path)
|
| 34 |
+
|
| 35 |
+
if img.mode == "I":
|
| 36 |
+
img = img.point(lambda i: i * (1 / 255))
|
| 37 |
+
img = img.convert("RGB")
|
| 38 |
+
img_array = np.array(img).astype(np.float32) / 255.0
|
| 39 |
+
img_tensor = torch.from_numpy(img_array)[None,]
|
| 40 |
+
output_images.append(img_tensor)
|
| 41 |
+
|
| 42 |
+
return output_images
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class LoadImageDataSetFromFolderNode(io.ComfyNode):
|
| 46 |
+
@classmethod
|
| 47 |
+
def define_schema(cls):
|
| 48 |
+
return io.Schema(
|
| 49 |
+
node_id="LoadImageDataSetFromFolder",
|
| 50 |
+
display_name="Load Image Dataset from Folder",
|
| 51 |
+
category="dataset",
|
| 52 |
+
is_experimental=True,
|
| 53 |
+
inputs=[
|
| 54 |
+
io.Combo.Input(
|
| 55 |
+
"folder",
|
| 56 |
+
options=folder_paths.get_input_subfolders(),
|
| 57 |
+
tooltip="The folder to load images from.",
|
| 58 |
+
)
|
| 59 |
+
],
|
| 60 |
+
outputs=[
|
| 61 |
+
io.Image.Output(
|
| 62 |
+
display_name="images",
|
| 63 |
+
is_output_list=True,
|
| 64 |
+
tooltip="List of loaded images",
|
| 65 |
+
)
|
| 66 |
+
],
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
@classmethod
|
| 70 |
+
def execute(cls, folder):
|
| 71 |
+
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
| 72 |
+
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
| 73 |
+
image_files = [
|
| 74 |
+
f
|
| 75 |
+
for f in os.listdir(sub_input_dir)
|
| 76 |
+
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
| 77 |
+
]
|
| 78 |
+
output_tensor = load_and_process_images(image_files, sub_input_dir)
|
| 79 |
+
return io.NodeOutput(output_tensor)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
|
| 83 |
+
@classmethod
|
| 84 |
+
def define_schema(cls):
|
| 85 |
+
return io.Schema(
|
| 86 |
+
node_id="LoadImageTextDataSetFromFolder",
|
| 87 |
+
display_name="Load Image and Text Dataset from Folder",
|
| 88 |
+
category="dataset",
|
| 89 |
+
is_experimental=True,
|
| 90 |
+
inputs=[
|
| 91 |
+
io.Combo.Input(
|
| 92 |
+
"folder",
|
| 93 |
+
options=folder_paths.get_input_subfolders(),
|
| 94 |
+
tooltip="The folder to load images from.",
|
| 95 |
+
)
|
| 96 |
+
],
|
| 97 |
+
outputs=[
|
| 98 |
+
io.Image.Output(
|
| 99 |
+
display_name="images",
|
| 100 |
+
is_output_list=True,
|
| 101 |
+
tooltip="List of loaded images",
|
| 102 |
+
),
|
| 103 |
+
io.String.Output(
|
| 104 |
+
display_name="texts",
|
| 105 |
+
is_output_list=True,
|
| 106 |
+
tooltip="List of text captions",
|
| 107 |
+
),
|
| 108 |
+
],
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def execute(cls, folder):
|
| 113 |
+
logging.info(f"Loading images from folder: {folder}")
|
| 114 |
+
|
| 115 |
+
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
| 116 |
+
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
| 117 |
+
|
| 118 |
+
image_files = []
|
| 119 |
+
for item in os.listdir(sub_input_dir):
|
| 120 |
+
path = os.path.join(sub_input_dir, item)
|
| 121 |
+
if any(item.lower().endswith(ext) for ext in valid_extensions):
|
| 122 |
+
image_files.append(path)
|
| 123 |
+
elif os.path.isdir(path):
|
| 124 |
+
# Support kohya-ss/sd-scripts folder structure
|
| 125 |
+
repeat = 1
|
| 126 |
+
if item.split("_")[0].isdigit():
|
| 127 |
+
repeat = int(item.split("_")[0])
|
| 128 |
+
image_files.extend(
|
| 129 |
+
[
|
| 130 |
+
os.path.join(path, f)
|
| 131 |
+
for f in os.listdir(path)
|
| 132 |
+
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
| 133 |
+
]
|
| 134 |
+
* repeat
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
caption_file_path = [
|
| 138 |
+
f.replace(os.path.splitext(f)[1], ".txt") for f in image_files
|
| 139 |
+
]
|
| 140 |
+
captions = []
|
| 141 |
+
for caption_file in caption_file_path:
|
| 142 |
+
caption_path = os.path.join(sub_input_dir, caption_file)
|
| 143 |
+
if os.path.exists(caption_path):
|
| 144 |
+
with open(caption_path, "r", encoding="utf-8") as f:
|
| 145 |
+
caption = f.read().strip()
|
| 146 |
+
captions.append(caption)
|
| 147 |
+
else:
|
| 148 |
+
captions.append("")
|
| 149 |
+
|
| 150 |
+
output_tensor = load_and_process_images(image_files, sub_input_dir)
|
| 151 |
+
|
| 152 |
+
logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
|
| 153 |
+
return io.NodeOutput(output_tensor, captions)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def save_images_to_folder(image_list, output_dir, prefix="image"):
|
| 157 |
+
"""Utility function to save a list of image tensors to disk.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W])
|
| 161 |
+
output_dir: Directory to save images to
|
| 162 |
+
prefix: Filename prefix
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
List of saved filenames
|
| 166 |
+
"""
|
| 167 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 168 |
+
saved_files = []
|
| 169 |
+
|
| 170 |
+
for idx, img_tensor in enumerate(image_list):
|
| 171 |
+
# Handle different tensor shapes
|
| 172 |
+
if isinstance(img_tensor, torch.Tensor):
|
| 173 |
+
# Remove batch dimension if present [1, H, W, C] -> [H, W, C]
|
| 174 |
+
if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
|
| 175 |
+
img_tensor = img_tensor.squeeze(0)
|
| 176 |
+
|
| 177 |
+
# If tensor is [C, H, W], permute to [H, W, C]
|
| 178 |
+
if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]:
|
| 179 |
+
if (
|
| 180 |
+
img_tensor.shape[0] <= 4
|
| 181 |
+
and img_tensor.shape[1] > 4
|
| 182 |
+
and img_tensor.shape[2] > 4
|
| 183 |
+
):
|
| 184 |
+
img_tensor = img_tensor.permute(1, 2, 0)
|
| 185 |
+
|
| 186 |
+
# Convert to numpy and scale to 0-255
|
| 187 |
+
img_array = img_tensor.cpu().numpy()
|
| 188 |
+
img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8)
|
| 189 |
+
|
| 190 |
+
# Convert to PIL Image
|
| 191 |
+
img = Image.fromarray(img_array)
|
| 192 |
+
else:
|
| 193 |
+
raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}")
|
| 194 |
+
|
| 195 |
+
# Save image
|
| 196 |
+
filename = f"{prefix}_{idx:05d}.png"
|
| 197 |
+
filepath = os.path.join(output_dir, filename)
|
| 198 |
+
img.save(filepath)
|
| 199 |
+
saved_files.append(filename)
|
| 200 |
+
|
| 201 |
+
return saved_files
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class SaveImageDataSetToFolderNode(io.ComfyNode):
|
| 205 |
+
@classmethod
|
| 206 |
+
def define_schema(cls):
|
| 207 |
+
return io.Schema(
|
| 208 |
+
node_id="SaveImageDataSetToFolder",
|
| 209 |
+
display_name="Save Image Dataset to Folder",
|
| 210 |
+
category="dataset",
|
| 211 |
+
is_experimental=True,
|
| 212 |
+
is_output_node=True,
|
| 213 |
+
is_input_list=True, # Receive images as list
|
| 214 |
+
inputs=[
|
| 215 |
+
io.Image.Input("images", tooltip="List of images to save."),
|
| 216 |
+
io.String.Input(
|
| 217 |
+
"folder_name",
|
| 218 |
+
default="dataset",
|
| 219 |
+
tooltip="Name of the folder to save images to (inside output directory).",
|
| 220 |
+
),
|
| 221 |
+
io.String.Input(
|
| 222 |
+
"filename_prefix",
|
| 223 |
+
default="image",
|
| 224 |
+
tooltip="Prefix for saved image filenames.",
|
| 225 |
+
advanced=True,
|
| 226 |
+
),
|
| 227 |
+
],
|
| 228 |
+
outputs=[],
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
@classmethod
|
| 232 |
+
def execute(cls, images, folder_name, filename_prefix):
|
| 233 |
+
# Extract scalar values
|
| 234 |
+
folder_name = folder_name[0]
|
| 235 |
+
filename_prefix = filename_prefix[0]
|
| 236 |
+
|
| 237 |
+
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
| 238 |
+
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
|
| 239 |
+
|
| 240 |
+
logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
|
| 241 |
+
return io.NodeOutput()
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
| 245 |
+
@classmethod
|
| 246 |
+
def define_schema(cls):
|
| 247 |
+
return io.Schema(
|
| 248 |
+
node_id="SaveImageTextDataSetToFolder",
|
| 249 |
+
display_name="Save Image and Text Dataset to Folder",
|
| 250 |
+
category="dataset",
|
| 251 |
+
is_experimental=True,
|
| 252 |
+
is_output_node=True,
|
| 253 |
+
is_input_list=True, # Receive both images and texts as lists
|
| 254 |
+
inputs=[
|
| 255 |
+
io.Image.Input("images", tooltip="List of images to save."),
|
| 256 |
+
io.String.Input("texts", tooltip="List of text captions to save."),
|
| 257 |
+
io.String.Input(
|
| 258 |
+
"folder_name",
|
| 259 |
+
default="dataset",
|
| 260 |
+
tooltip="Name of the folder to save images to (inside output directory).",
|
| 261 |
+
),
|
| 262 |
+
io.String.Input(
|
| 263 |
+
"filename_prefix",
|
| 264 |
+
default="image",
|
| 265 |
+
tooltip="Prefix for saved image filenames.",
|
| 266 |
+
advanced=True,
|
| 267 |
+
),
|
| 268 |
+
],
|
| 269 |
+
outputs=[],
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
@classmethod
|
| 273 |
+
def execute(cls, images, texts, folder_name, filename_prefix):
|
| 274 |
+
# Extract scalar values
|
| 275 |
+
folder_name = folder_name[0]
|
| 276 |
+
filename_prefix = filename_prefix[0]
|
| 277 |
+
|
| 278 |
+
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
| 279 |
+
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
|
| 280 |
+
|
| 281 |
+
# Save captions
|
| 282 |
+
for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
|
| 283 |
+
caption_filename = filename.replace(".png", ".txt")
|
| 284 |
+
caption_path = os.path.join(output_dir, caption_filename)
|
| 285 |
+
with open(caption_path, "w", encoding="utf-8") as f:
|
| 286 |
+
f.write(caption)
|
| 287 |
+
|
| 288 |
+
logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
|
| 289 |
+
return io.NodeOutput()
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# ========== Helper Functions for Transform Nodes ==========
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def tensor_to_pil(img_tensor):
|
| 296 |
+
"""Convert tensor to PIL Image."""
|
| 297 |
+
if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
|
| 298 |
+
img_tensor = img_tensor.squeeze(0)
|
| 299 |
+
img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
| 300 |
+
return Image.fromarray(img_array)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def pil_to_tensor(img):
|
| 304 |
+
"""Convert PIL Image to tensor."""
|
| 305 |
+
img_array = np.array(img).astype(np.float32) / 255.0
|
| 306 |
+
return torch.from_numpy(img_array)[None,]
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
# ========== Base Classes for Transform Nodes ==========
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class ImageProcessingNode(io.ComfyNode):
|
| 313 |
+
"""Base class for image processing nodes that operate on images.
|
| 314 |
+
|
| 315 |
+
Child classes should set:
|
| 316 |
+
node_id: Unique node identifier (required)
|
| 317 |
+
display_name: Display name (optional, defaults to node_id)
|
| 318 |
+
description: Node description (optional)
|
| 319 |
+
extra_inputs: List of additional io.Input objects beyond "images" (optional)
|
| 320 |
+
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
| 321 |
+
is_output_list: True (list output) or False (single output) (optional, default True)
|
| 322 |
+
|
| 323 |
+
Child classes must implement ONE of:
|
| 324 |
+
_process(cls, image, **kwargs) -> tensor (for single-item processing)
|
| 325 |
+
_group_process(cls, images, **kwargs) -> list[tensor] (for group processing)
|
| 326 |
+
"""
|
| 327 |
+
|
| 328 |
+
node_id = None
|
| 329 |
+
display_name = None
|
| 330 |
+
description = None
|
| 331 |
+
extra_inputs = []
|
| 332 |
+
is_group_process = None # None = auto-detect, True/False = explicit
|
| 333 |
+
is_output_list = None # None = auto-detect based on processing mode
|
| 334 |
+
|
| 335 |
+
@classmethod
|
| 336 |
+
def _detect_processing_mode(cls):
|
| 337 |
+
"""Detect whether this node uses group or individual processing.
|
| 338 |
+
|
| 339 |
+
Returns:
|
| 340 |
+
bool: True if group processing, False if individual processing
|
| 341 |
+
"""
|
| 342 |
+
# Explicit setting takes precedence
|
| 343 |
+
if cls.is_group_process is not None:
|
| 344 |
+
return cls.is_group_process
|
| 345 |
+
|
| 346 |
+
# Check which method is overridden by looking at the defining class in MRO
|
| 347 |
+
base_class = ImageProcessingNode
|
| 348 |
+
|
| 349 |
+
# Find which class in MRO defines _process
|
| 350 |
+
process_definer = None
|
| 351 |
+
for klass in cls.__mro__:
|
| 352 |
+
if "_process" in klass.__dict__:
|
| 353 |
+
process_definer = klass
|
| 354 |
+
break
|
| 355 |
+
|
| 356 |
+
# Find which class in MRO defines _group_process
|
| 357 |
+
group_definer = None
|
| 358 |
+
for klass in cls.__mro__:
|
| 359 |
+
if "_group_process" in klass.__dict__:
|
| 360 |
+
group_definer = klass
|
| 361 |
+
break
|
| 362 |
+
|
| 363 |
+
# Check what was overridden (not defined in base class)
|
| 364 |
+
has_process = process_definer is not None and process_definer is not base_class
|
| 365 |
+
has_group = group_definer is not None and group_definer is not base_class
|
| 366 |
+
|
| 367 |
+
if has_process and has_group:
|
| 368 |
+
raise ValueError(
|
| 369 |
+
f"{cls.__name__}: Cannot override both _process and _group_process. "
|
| 370 |
+
"Override only one, or set is_group_process explicitly."
|
| 371 |
+
)
|
| 372 |
+
if not has_process and not has_group:
|
| 373 |
+
raise ValueError(
|
| 374 |
+
f"{cls.__name__}: Must override either _process or _group_process"
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
return has_group
|
| 378 |
+
|
| 379 |
+
@classmethod
|
| 380 |
+
def define_schema(cls):
|
| 381 |
+
if cls.node_id is None:
|
| 382 |
+
raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
|
| 383 |
+
|
| 384 |
+
is_group = cls._detect_processing_mode()
|
| 385 |
+
|
| 386 |
+
# Auto-detect is_output_list if not explicitly set
|
| 387 |
+
# Single processing: False (backend collects results into list)
|
| 388 |
+
# Group processing: True by default (can be False for single-output nodes)
|
| 389 |
+
output_is_list = (
|
| 390 |
+
cls.is_output_list if cls.is_output_list is not None else is_group
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
inputs = [
|
| 394 |
+
io.Image.Input(
|
| 395 |
+
"images",
|
| 396 |
+
tooltip=(
|
| 397 |
+
"List of images to process." if is_group else "Image to process."
|
| 398 |
+
),
|
| 399 |
+
)
|
| 400 |
+
]
|
| 401 |
+
inputs.extend(cls.extra_inputs)
|
| 402 |
+
|
| 403 |
+
return io.Schema(
|
| 404 |
+
node_id=cls.node_id,
|
| 405 |
+
display_name=cls.display_name or cls.node_id,
|
| 406 |
+
category="dataset/image",
|
| 407 |
+
is_experimental=True,
|
| 408 |
+
is_input_list=is_group, # True for group, False for individual
|
| 409 |
+
inputs=inputs,
|
| 410 |
+
outputs=[
|
| 411 |
+
io.Image.Output(
|
| 412 |
+
display_name="images",
|
| 413 |
+
is_output_list=output_is_list,
|
| 414 |
+
tooltip="Processed images",
|
| 415 |
+
)
|
| 416 |
+
],
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
@classmethod
|
| 420 |
+
def execute(cls, images, **kwargs):
|
| 421 |
+
"""Execute the node. Routes to _process or _group_process based on mode."""
|
| 422 |
+
is_group = cls._detect_processing_mode()
|
| 423 |
+
|
| 424 |
+
# Extract scalar values from lists for parameters
|
| 425 |
+
params = {}
|
| 426 |
+
for k, v in kwargs.items():
|
| 427 |
+
if isinstance(v, list) and len(v) == 1:
|
| 428 |
+
params[k] = v[0]
|
| 429 |
+
else:
|
| 430 |
+
params[k] = v
|
| 431 |
+
|
| 432 |
+
if is_group:
|
| 433 |
+
# Group processing: images is list, call _group_process
|
| 434 |
+
result = cls._group_process(images, **params)
|
| 435 |
+
else:
|
| 436 |
+
# Individual processing: images is single item, call _process
|
| 437 |
+
result = cls._process(images, **params)
|
| 438 |
+
|
| 439 |
+
return io.NodeOutput(result)
|
| 440 |
+
|
| 441 |
+
@classmethod
|
| 442 |
+
def _process(cls, image, **kwargs):
|
| 443 |
+
"""Override this method for single-item processing.
|
| 444 |
+
|
| 445 |
+
Args:
|
| 446 |
+
image: tensor - Single image tensor
|
| 447 |
+
**kwargs: Additional parameters (already extracted from lists)
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
tensor - Processed image
|
| 451 |
+
"""
|
| 452 |
+
raise NotImplementedError(f"{cls.__name__} must implement _process method")
|
| 453 |
+
|
| 454 |
+
@classmethod
|
| 455 |
+
def _group_process(cls, images, **kwargs):
|
| 456 |
+
"""Override this method for group processing.
|
| 457 |
+
|
| 458 |
+
Args:
|
| 459 |
+
images: list[tensor] - List of image tensors
|
| 460 |
+
**kwargs: Additional parameters (already extracted from lists)
|
| 461 |
+
|
| 462 |
+
Returns:
|
| 463 |
+
list[tensor] - Processed images
|
| 464 |
+
"""
|
| 465 |
+
raise NotImplementedError(
|
| 466 |
+
f"{cls.__name__} must implement _group_process method"
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class TextProcessingNode(io.ComfyNode):
|
| 471 |
+
"""Base class for text processing nodes that operate on texts.
|
| 472 |
+
|
| 473 |
+
Child classes should set:
|
| 474 |
+
node_id: Unique node identifier (required)
|
| 475 |
+
display_name: Display name (optional, defaults to node_id)
|
| 476 |
+
description: Node description (optional)
|
| 477 |
+
extra_inputs: List of additional io.Input objects beyond "texts" (optional)
|
| 478 |
+
is_group_process: None (auto-detect), True (group), or False (individual) (optional)
|
| 479 |
+
is_output_list: True (list output) or False (single output) (optional, default True)
|
| 480 |
+
|
| 481 |
+
Child classes must implement ONE of:
|
| 482 |
+
_process(cls, text, **kwargs) -> str (for single-item processing)
|
| 483 |
+
_group_process(cls, texts, **kwargs) -> list[str] (for group processing)
|
| 484 |
+
"""
|
| 485 |
+
|
| 486 |
+
node_id = None
|
| 487 |
+
display_name = None
|
| 488 |
+
description = None
|
| 489 |
+
extra_inputs = []
|
| 490 |
+
is_group_process = None # None = auto-detect, True/False = explicit
|
| 491 |
+
is_output_list = None # None = auto-detect based on processing mode
|
| 492 |
+
|
| 493 |
+
@classmethod
|
| 494 |
+
def _detect_processing_mode(cls):
|
| 495 |
+
"""Detect whether this node uses group or individual processing.
|
| 496 |
+
|
| 497 |
+
Returns:
|
| 498 |
+
bool: True if group processing, False if individual processing
|
| 499 |
+
"""
|
| 500 |
+
# Explicit setting takes precedence
|
| 501 |
+
if cls.is_group_process is not None:
|
| 502 |
+
return cls.is_group_process
|
| 503 |
+
|
| 504 |
+
# Check which method is overridden by looking at the defining class in MRO
|
| 505 |
+
base_class = TextProcessingNode
|
| 506 |
+
|
| 507 |
+
# Find which class in MRO defines _process
|
| 508 |
+
process_definer = None
|
| 509 |
+
for klass in cls.__mro__:
|
| 510 |
+
if "_process" in klass.__dict__:
|
| 511 |
+
process_definer = klass
|
| 512 |
+
break
|
| 513 |
+
|
| 514 |
+
# Find which class in MRO defines _group_process
|
| 515 |
+
group_definer = None
|
| 516 |
+
for klass in cls.__mro__:
|
| 517 |
+
if "_group_process" in klass.__dict__:
|
| 518 |
+
group_definer = klass
|
| 519 |
+
break
|
| 520 |
+
|
| 521 |
+
# Check what was overridden (not defined in base class)
|
| 522 |
+
has_process = process_definer is not None and process_definer is not base_class
|
| 523 |
+
has_group = group_definer is not None and group_definer is not base_class
|
| 524 |
+
|
| 525 |
+
if has_process and has_group:
|
| 526 |
+
raise ValueError(
|
| 527 |
+
f"{cls.__name__}: Cannot override both _process and _group_process. "
|
| 528 |
+
"Override only one, or set is_group_process explicitly."
|
| 529 |
+
)
|
| 530 |
+
if not has_process and not has_group:
|
| 531 |
+
raise ValueError(
|
| 532 |
+
f"{cls.__name__}: Must override either _process or _group_process"
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
return has_group
|
| 536 |
+
|
| 537 |
+
@classmethod
|
| 538 |
+
def define_schema(cls):
|
| 539 |
+
if cls.node_id is None:
|
| 540 |
+
raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
|
| 541 |
+
|
| 542 |
+
is_group = cls._detect_processing_mode()
|
| 543 |
+
|
| 544 |
+
inputs = [
|
| 545 |
+
io.String.Input(
|
| 546 |
+
"texts",
|
| 547 |
+
tooltip="List of texts to process." if is_group else "Text to process.",
|
| 548 |
+
)
|
| 549 |
+
]
|
| 550 |
+
inputs.extend(cls.extra_inputs)
|
| 551 |
+
|
| 552 |
+
return io.Schema(
|
| 553 |
+
node_id=cls.node_id,
|
| 554 |
+
display_name=cls.display_name or cls.node_id,
|
| 555 |
+
category="dataset/text",
|
| 556 |
+
is_experimental=True,
|
| 557 |
+
is_input_list=is_group, # True for group, False for individual
|
| 558 |
+
inputs=inputs,
|
| 559 |
+
outputs=[
|
| 560 |
+
io.String.Output(
|
| 561 |
+
display_name="texts",
|
| 562 |
+
is_output_list=cls.is_output_list,
|
| 563 |
+
tooltip="Processed texts",
|
| 564 |
+
)
|
| 565 |
+
],
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
@classmethod
|
| 569 |
+
def execute(cls, texts, **kwargs):
|
| 570 |
+
"""Execute the node. Routes to _process or _group_process based on mode."""
|
| 571 |
+
is_group = cls._detect_processing_mode()
|
| 572 |
+
|
| 573 |
+
# Extract scalar values from lists for parameters
|
| 574 |
+
params = {}
|
| 575 |
+
for k, v in kwargs.items():
|
| 576 |
+
if isinstance(v, list) and len(v) == 1:
|
| 577 |
+
params[k] = v[0]
|
| 578 |
+
else:
|
| 579 |
+
params[k] = v
|
| 580 |
+
|
| 581 |
+
if is_group:
|
| 582 |
+
# Group processing: texts is list, call _group_process
|
| 583 |
+
result = cls._group_process(texts, **params)
|
| 584 |
+
else:
|
| 585 |
+
# Individual processing: texts is single item, call _process
|
| 586 |
+
result = cls._process(texts, **params)
|
| 587 |
+
|
| 588 |
+
# Wrap result based on is_output_list
|
| 589 |
+
if cls.is_output_list:
|
| 590 |
+
# Result should already be a list (or will be for individual)
|
| 591 |
+
return io.NodeOutput(result if is_group else [result])
|
| 592 |
+
else:
|
| 593 |
+
# Single output - wrap in list for NodeOutput
|
| 594 |
+
return io.NodeOutput([result])
|
| 595 |
+
|
| 596 |
+
@classmethod
|
| 597 |
+
def _process(cls, text, **kwargs):
|
| 598 |
+
"""Override this method for single-item processing.
|
| 599 |
+
|
| 600 |
+
Args:
|
| 601 |
+
text: str - Single text string
|
| 602 |
+
**kwargs: Additional parameters (already extracted from lists)
|
| 603 |
+
|
| 604 |
+
Returns:
|
| 605 |
+
str - Processed text
|
| 606 |
+
"""
|
| 607 |
+
raise NotImplementedError(f"{cls.__name__} must implement _process method")
|
| 608 |
+
|
| 609 |
+
@classmethod
|
| 610 |
+
def _group_process(cls, texts, **kwargs):
|
| 611 |
+
"""Override this method for group processing.
|
| 612 |
+
|
| 613 |
+
Args:
|
| 614 |
+
texts: list[str] - List of text strings
|
| 615 |
+
**kwargs: Additional parameters (already extracted from lists)
|
| 616 |
+
|
| 617 |
+
Returns:
|
| 618 |
+
list[str] - Processed texts
|
| 619 |
+
"""
|
| 620 |
+
raise NotImplementedError(
|
| 621 |
+
f"{cls.__name__} must implement _group_process method"
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
# ========== Image Transform Nodes ==========
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
|
| 629 |
+
node_id = "ResizeImagesByShorterEdge"
|
| 630 |
+
display_name = "Resize Images by Shorter Edge"
|
| 631 |
+
description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
|
| 632 |
+
extra_inputs = [
|
| 633 |
+
io.Int.Input(
|
| 634 |
+
"shorter_edge",
|
| 635 |
+
default=512,
|
| 636 |
+
min=1,
|
| 637 |
+
max=8192,
|
| 638 |
+
tooltip="Target length for the shorter edge.",
|
| 639 |
+
),
|
| 640 |
+
]
|
| 641 |
+
|
| 642 |
+
@classmethod
|
| 643 |
+
def _process(cls, image, shorter_edge):
|
| 644 |
+
img = tensor_to_pil(image)
|
| 645 |
+
w, h = img.size
|
| 646 |
+
if w < h:
|
| 647 |
+
new_w = shorter_edge
|
| 648 |
+
new_h = int(h * (shorter_edge / w))
|
| 649 |
+
else:
|
| 650 |
+
new_h = shorter_edge
|
| 651 |
+
new_w = int(w * (shorter_edge / h))
|
| 652 |
+
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
| 653 |
+
return pil_to_tensor(img)
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
|
| 657 |
+
node_id = "ResizeImagesByLongerEdge"
|
| 658 |
+
display_name = "Resize Images by Longer Edge"
|
| 659 |
+
description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
|
| 660 |
+
extra_inputs = [
|
| 661 |
+
io.Int.Input(
|
| 662 |
+
"longer_edge",
|
| 663 |
+
default=1024,
|
| 664 |
+
min=1,
|
| 665 |
+
max=8192,
|
| 666 |
+
tooltip="Target length for the longer edge.",
|
| 667 |
+
),
|
| 668 |
+
]
|
| 669 |
+
|
| 670 |
+
@classmethod
|
| 671 |
+
def _process(cls, image, longer_edge):
|
| 672 |
+
resized_images = []
|
| 673 |
+
for image_i in image:
|
| 674 |
+
img = tensor_to_pil(image_i)
|
| 675 |
+
w, h = img.size
|
| 676 |
+
if w > h:
|
| 677 |
+
new_w = longer_edge
|
| 678 |
+
new_h = int(h * (longer_edge / w))
|
| 679 |
+
else:
|
| 680 |
+
new_h = longer_edge
|
| 681 |
+
new_w = int(w * (longer_edge / h))
|
| 682 |
+
img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
| 683 |
+
resized_images.append(pil_to_tensor(img))
|
| 684 |
+
return torch.cat(resized_images, dim=0)
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
class CenterCropImagesNode(ImageProcessingNode):
|
| 688 |
+
node_id = "CenterCropImages"
|
| 689 |
+
display_name = "Center Crop Images"
|
| 690 |
+
description = "Center crop all images to the specified dimensions."
|
| 691 |
+
extra_inputs = [
|
| 692 |
+
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
| 693 |
+
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
| 694 |
+
]
|
| 695 |
+
|
| 696 |
+
@classmethod
|
| 697 |
+
def _process(cls, image, width, height):
|
| 698 |
+
img = tensor_to_pil(image)
|
| 699 |
+
left = max(0, (img.width - width) // 2)
|
| 700 |
+
top = max(0, (img.height - height) // 2)
|
| 701 |
+
right = min(img.width, left + width)
|
| 702 |
+
bottom = min(img.height, top + height)
|
| 703 |
+
img = img.crop((left, top, right, bottom))
|
| 704 |
+
return pil_to_tensor(img)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
class RandomCropImagesNode(ImageProcessingNode):
|
| 708 |
+
node_id = "RandomCropImages"
|
| 709 |
+
display_name = "Random Crop Images"
|
| 710 |
+
description = (
|
| 711 |
+
"Randomly crop all images to the specified dimensions (for data augmentation)."
|
| 712 |
+
)
|
| 713 |
+
extra_inputs = [
|
| 714 |
+
io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
|
| 715 |
+
io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
|
| 716 |
+
io.Int.Input(
|
| 717 |
+
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
|
| 718 |
+
),
|
| 719 |
+
]
|
| 720 |
+
|
| 721 |
+
@classmethod
|
| 722 |
+
def _process(cls, image, width, height, seed):
|
| 723 |
+
np.random.seed(seed % (2**32 - 1))
|
| 724 |
+
img = tensor_to_pil(image)
|
| 725 |
+
max_left = max(0, img.width - width)
|
| 726 |
+
max_top = max(0, img.height - height)
|
| 727 |
+
left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
|
| 728 |
+
top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
|
| 729 |
+
right = min(img.width, left + width)
|
| 730 |
+
bottom = min(img.height, top + height)
|
| 731 |
+
img = img.crop((left, top, right, bottom))
|
| 732 |
+
return pil_to_tensor(img)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
class NormalizeImagesNode(ImageProcessingNode):
|
| 736 |
+
node_id = "NormalizeImages"
|
| 737 |
+
display_name = "Normalize Images"
|
| 738 |
+
description = "Normalize images using mean and standard deviation."
|
| 739 |
+
extra_inputs = [
|
| 740 |
+
io.Float.Input(
|
| 741 |
+
"mean",
|
| 742 |
+
default=0.5,
|
| 743 |
+
min=0.0,
|
| 744 |
+
max=1.0,
|
| 745 |
+
tooltip="Mean value for normalization.",
|
| 746 |
+
advanced=True,
|
| 747 |
+
),
|
| 748 |
+
io.Float.Input(
|
| 749 |
+
"std",
|
| 750 |
+
default=0.5,
|
| 751 |
+
min=0.001,
|
| 752 |
+
max=1.0,
|
| 753 |
+
tooltip="Standard deviation for normalization.",
|
| 754 |
+
advanced=True,
|
| 755 |
+
),
|
| 756 |
+
]
|
| 757 |
+
|
| 758 |
+
@classmethod
|
| 759 |
+
def _process(cls, image, mean, std):
|
| 760 |
+
return (image - mean) / std
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
class AdjustBrightnessNode(ImageProcessingNode):
|
| 764 |
+
node_id = "AdjustBrightness"
|
| 765 |
+
display_name = "Adjust Brightness"
|
| 766 |
+
description = "Adjust brightness of all images."
|
| 767 |
+
extra_inputs = [
|
| 768 |
+
io.Float.Input(
|
| 769 |
+
"factor",
|
| 770 |
+
default=1.0,
|
| 771 |
+
min=0.0,
|
| 772 |
+
max=2.0,
|
| 773 |
+
tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.",
|
| 774 |
+
),
|
| 775 |
+
]
|
| 776 |
+
|
| 777 |
+
@classmethod
|
| 778 |
+
def _process(cls, image, factor):
|
| 779 |
+
return (image * factor).clamp(0.0, 1.0)
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
class AdjustContrastNode(ImageProcessingNode):
|
| 783 |
+
node_id = "AdjustContrast"
|
| 784 |
+
display_name = "Adjust Contrast"
|
| 785 |
+
description = "Adjust contrast of all images."
|
| 786 |
+
extra_inputs = [
|
| 787 |
+
io.Float.Input(
|
| 788 |
+
"factor",
|
| 789 |
+
default=1.0,
|
| 790 |
+
min=0.0,
|
| 791 |
+
max=2.0,
|
| 792 |
+
tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.",
|
| 793 |
+
),
|
| 794 |
+
]
|
| 795 |
+
|
| 796 |
+
@classmethod
|
| 797 |
+
def _process(cls, image, factor):
|
| 798 |
+
return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0)
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
class ShuffleDatasetNode(ImageProcessingNode):
|
| 802 |
+
node_id = "ShuffleDataset"
|
| 803 |
+
display_name = "Shuffle Image Dataset"
|
| 804 |
+
description = "Randomly shuffle the order of images in the dataset."
|
| 805 |
+
is_group_process = True # Requires full list to shuffle
|
| 806 |
+
extra_inputs = [
|
| 807 |
+
io.Int.Input(
|
| 808 |
+
"seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
|
| 809 |
+
),
|
| 810 |
+
]
|
| 811 |
+
|
| 812 |
+
@classmethod
|
| 813 |
+
def _group_process(cls, images, seed):
|
| 814 |
+
np.random.seed(seed % (2**32 - 1))
|
| 815 |
+
indices = np.random.permutation(len(images))
|
| 816 |
+
return [images[i] for i in indices]
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
class ShuffleImageTextDatasetNode(io.ComfyNode):
|
| 820 |
+
"""Special node that shuffles both images and texts together."""
|
| 821 |
+
|
| 822 |
+
@classmethod
|
| 823 |
+
def define_schema(cls):
|
| 824 |
+
return io.Schema(
|
| 825 |
+
node_id="ShuffleImageTextDataset",
|
| 826 |
+
display_name="Shuffle Image-Text Dataset",
|
| 827 |
+
category="dataset/image",
|
| 828 |
+
is_experimental=True,
|
| 829 |
+
is_input_list=True,
|
| 830 |
+
inputs=[
|
| 831 |
+
io.Image.Input("images", tooltip="List of images to shuffle."),
|
| 832 |
+
io.String.Input("texts", tooltip="List of texts to shuffle."),
|
| 833 |
+
io.Int.Input(
|
| 834 |
+
"seed",
|
| 835 |
+
default=0,
|
| 836 |
+
min=0,
|
| 837 |
+
max=0xFFFFFFFFFFFFFFFF,
|
| 838 |
+
tooltip="Random seed.",
|
| 839 |
+
),
|
| 840 |
+
],
|
| 841 |
+
outputs=[
|
| 842 |
+
io.Image.Output(
|
| 843 |
+
display_name="images",
|
| 844 |
+
is_output_list=True,
|
| 845 |
+
tooltip="Shuffled images",
|
| 846 |
+
),
|
| 847 |
+
io.String.Output(
|
| 848 |
+
display_name="texts", is_output_list=True, tooltip="Shuffled texts"
|
| 849 |
+
),
|
| 850 |
+
],
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
@classmethod
|
| 854 |
+
def execute(cls, images, texts, seed):
|
| 855 |
+
seed = seed[0] # Extract scalar
|
| 856 |
+
np.random.seed(seed % (2**32 - 1))
|
| 857 |
+
indices = np.random.permutation(len(images))
|
| 858 |
+
shuffled_images = [images[i] for i in indices]
|
| 859 |
+
shuffled_texts = [texts[i] for i in indices]
|
| 860 |
+
return io.NodeOutput(shuffled_images, shuffled_texts)
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
# ========== Text Transform Nodes ==========
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
class TextToLowercaseNode(TextProcessingNode):
|
| 867 |
+
node_id = "TextToLowercase"
|
| 868 |
+
display_name = "Text to Lowercase"
|
| 869 |
+
description = "Convert all texts to lowercase."
|
| 870 |
+
|
| 871 |
+
@classmethod
|
| 872 |
+
def _process(cls, text):
|
| 873 |
+
return text.lower()
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
class TextToUppercaseNode(TextProcessingNode):
|
| 877 |
+
node_id = "TextToUppercase"
|
| 878 |
+
display_name = "Text to Uppercase"
|
| 879 |
+
description = "Convert all texts to uppercase."
|
| 880 |
+
|
| 881 |
+
@classmethod
|
| 882 |
+
def _process(cls, text):
|
| 883 |
+
return text.upper()
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
class TruncateTextNode(TextProcessingNode):
|
| 887 |
+
node_id = "TruncateText"
|
| 888 |
+
display_name = "Truncate Text"
|
| 889 |
+
description = "Truncate all texts to a maximum length."
|
| 890 |
+
extra_inputs = [
|
| 891 |
+
io.Int.Input(
|
| 892 |
+
"max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
|
| 893 |
+
),
|
| 894 |
+
]
|
| 895 |
+
|
| 896 |
+
@classmethod
|
| 897 |
+
def _process(cls, text, max_length):
|
| 898 |
+
return text[:max_length]
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
class AddTextPrefixNode(TextProcessingNode):
|
| 902 |
+
node_id = "AddTextPrefix"
|
| 903 |
+
display_name = "Add Text Prefix"
|
| 904 |
+
description = "Add a prefix to all texts."
|
| 905 |
+
extra_inputs = [
|
| 906 |
+
io.String.Input("prefix", default="", tooltip="Prefix to add."),
|
| 907 |
+
]
|
| 908 |
+
|
| 909 |
+
@classmethod
|
| 910 |
+
def _process(cls, text, prefix):
|
| 911 |
+
return prefix + text
|
| 912 |
+
|
| 913 |
+
|
| 914 |
+
class AddTextSuffixNode(TextProcessingNode):
|
| 915 |
+
node_id = "AddTextSuffix"
|
| 916 |
+
display_name = "Add Text Suffix"
|
| 917 |
+
description = "Add a suffix to all texts."
|
| 918 |
+
extra_inputs = [
|
| 919 |
+
io.String.Input("suffix", default="", tooltip="Suffix to add."),
|
| 920 |
+
]
|
| 921 |
+
|
| 922 |
+
@classmethod
|
| 923 |
+
def _process(cls, text, suffix):
|
| 924 |
+
return text + suffix
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
class ReplaceTextNode(TextProcessingNode):
|
| 928 |
+
node_id = "ReplaceText"
|
| 929 |
+
display_name = "Replace Text"
|
| 930 |
+
description = "Replace text in all texts."
|
| 931 |
+
extra_inputs = [
|
| 932 |
+
io.String.Input("find", default="", tooltip="Text to find."),
|
| 933 |
+
io.String.Input("replace", default="", tooltip="Text to replace with."),
|
| 934 |
+
]
|
| 935 |
+
|
| 936 |
+
@classmethod
|
| 937 |
+
def _process(cls, text, find, replace):
|
| 938 |
+
return text.replace(find, replace)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
class StripWhitespaceNode(TextProcessingNode):
|
| 942 |
+
node_id = "StripWhitespace"
|
| 943 |
+
display_name = "Strip Whitespace"
|
| 944 |
+
description = "Strip leading and trailing whitespace from all texts."
|
| 945 |
+
|
| 946 |
+
@classmethod
|
| 947 |
+
def _process(cls, text):
|
| 948 |
+
return text.strip()
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
# ========== Group Processing Example Nodes ==========
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
class ImageDeduplicationNode(ImageProcessingNode):
|
| 955 |
+
"""Remove duplicate or very similar images from the dataset using perceptual hashing."""
|
| 956 |
+
|
| 957 |
+
node_id = "ImageDeduplication"
|
| 958 |
+
display_name = "Image Deduplication"
|
| 959 |
+
description = "Remove duplicate or very similar images from the dataset."
|
| 960 |
+
is_group_process = True # Requires full list to compare images
|
| 961 |
+
extra_inputs = [
|
| 962 |
+
io.Float.Input(
|
| 963 |
+
"similarity_threshold",
|
| 964 |
+
default=0.95,
|
| 965 |
+
min=0.0,
|
| 966 |
+
max=1.0,
|
| 967 |
+
tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
|
| 968 |
+
advanced=True,
|
| 969 |
+
),
|
| 970 |
+
]
|
| 971 |
+
|
| 972 |
+
@classmethod
|
| 973 |
+
def _group_process(cls, images, similarity_threshold):
|
| 974 |
+
"""Remove duplicate images using perceptual hashing."""
|
| 975 |
+
if len(images) == 0:
|
| 976 |
+
return []
|
| 977 |
+
|
| 978 |
+
# Compute simple perceptual hash for each image
|
| 979 |
+
def compute_hash(img_tensor):
|
| 980 |
+
"""Compute a simple perceptual hash by resizing to 8x8 and comparing to average."""
|
| 981 |
+
img = tensor_to_pil(img_tensor)
|
| 982 |
+
# Resize to 8x8
|
| 983 |
+
img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L")
|
| 984 |
+
# Get pixels
|
| 985 |
+
pixels = list(img_small.getdata())
|
| 986 |
+
# Compute average
|
| 987 |
+
avg = sum(pixels) / len(pixels)
|
| 988 |
+
# Create hash (1 if above average, 0 otherwise)
|
| 989 |
+
hash_bits = "".join("1" if p > avg else "0" for p in pixels)
|
| 990 |
+
return hash_bits
|
| 991 |
+
|
| 992 |
+
def hamming_distance(hash1, hash2):
|
| 993 |
+
"""Compute Hamming distance between two hash strings."""
|
| 994 |
+
return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))
|
| 995 |
+
|
| 996 |
+
# Compute hashes for all images
|
| 997 |
+
hashes = [compute_hash(img) for img in images]
|
| 998 |
+
|
| 999 |
+
# Find duplicates
|
| 1000 |
+
keep_indices = []
|
| 1001 |
+
for i in range(len(images)):
|
| 1002 |
+
is_duplicate = False
|
| 1003 |
+
for j in keep_indices:
|
| 1004 |
+
# Compare hashes
|
| 1005 |
+
distance = hamming_distance(hashes[i], hashes[j])
|
| 1006 |
+
similarity = 1.0 - (distance / 64.0) # 64 bits total
|
| 1007 |
+
if similarity >= similarity_threshold:
|
| 1008 |
+
is_duplicate = True
|
| 1009 |
+
logging.info(
|
| 1010 |
+
f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping"
|
| 1011 |
+
)
|
| 1012 |
+
break
|
| 1013 |
+
|
| 1014 |
+
if not is_duplicate:
|
| 1015 |
+
keep_indices.append(i)
|
| 1016 |
+
|
| 1017 |
+
# Return only unique images
|
| 1018 |
+
unique_images = [images[i] for i in keep_indices]
|
| 1019 |
+
logging.info(
|
| 1020 |
+
f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
|
| 1021 |
+
)
|
| 1022 |
+
return unique_images
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
class ImageGridNode(ImageProcessingNode):
|
| 1026 |
+
"""Combine multiple images into a single grid/collage."""
|
| 1027 |
+
|
| 1028 |
+
node_id = "ImageGrid"
|
| 1029 |
+
display_name = "Image Grid"
|
| 1030 |
+
description = "Arrange multiple images into a grid layout."
|
| 1031 |
+
is_group_process = True # Requires full list to create grid
|
| 1032 |
+
is_output_list = False # Outputs single grid image
|
| 1033 |
+
extra_inputs = [
|
| 1034 |
+
io.Int.Input(
|
| 1035 |
+
"columns",
|
| 1036 |
+
default=4,
|
| 1037 |
+
min=1,
|
| 1038 |
+
max=20,
|
| 1039 |
+
tooltip="Number of columns in the grid.",
|
| 1040 |
+
),
|
| 1041 |
+
io.Int.Input(
|
| 1042 |
+
"cell_width",
|
| 1043 |
+
default=256,
|
| 1044 |
+
min=32,
|
| 1045 |
+
max=2048,
|
| 1046 |
+
tooltip="Width of each cell in the grid.",
|
| 1047 |
+
advanced=True,
|
| 1048 |
+
),
|
| 1049 |
+
io.Int.Input(
|
| 1050 |
+
"cell_height",
|
| 1051 |
+
default=256,
|
| 1052 |
+
min=32,
|
| 1053 |
+
max=2048,
|
| 1054 |
+
tooltip="Height of each cell in the grid.",
|
| 1055 |
+
advanced=True,
|
| 1056 |
+
),
|
| 1057 |
+
io.Int.Input(
|
| 1058 |
+
"padding", default=4, min=0, max=50, tooltip="Padding between images.", advanced=True
|
| 1059 |
+
),
|
| 1060 |
+
]
|
| 1061 |
+
|
| 1062 |
+
@classmethod
|
| 1063 |
+
def _group_process(cls, images, columns, cell_width, cell_height, padding):
|
| 1064 |
+
"""Arrange images into a grid."""
|
| 1065 |
+
if len(images) == 0:
|
| 1066 |
+
raise ValueError("Cannot create grid from empty image list")
|
| 1067 |
+
|
| 1068 |
+
# Calculate grid dimensions
|
| 1069 |
+
num_images = len(images)
|
| 1070 |
+
rows = (num_images + columns - 1) // columns # Ceiling division
|
| 1071 |
+
|
| 1072 |
+
# Calculate total grid size
|
| 1073 |
+
grid_width = columns * cell_width + (columns - 1) * padding
|
| 1074 |
+
grid_height = rows * cell_height + (rows - 1) * padding
|
| 1075 |
+
|
| 1076 |
+
# Create blank grid
|
| 1077 |
+
grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0))
|
| 1078 |
+
|
| 1079 |
+
# Place images
|
| 1080 |
+
for idx, img_tensor in enumerate(images):
|
| 1081 |
+
row = idx // columns
|
| 1082 |
+
col = idx % columns
|
| 1083 |
+
|
| 1084 |
+
# Convert to PIL and resize to cell size
|
| 1085 |
+
img = tensor_to_pil(img_tensor)
|
| 1086 |
+
img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS)
|
| 1087 |
+
|
| 1088 |
+
# Calculate position
|
| 1089 |
+
x = col * (cell_width + padding)
|
| 1090 |
+
y = row * (cell_height + padding)
|
| 1091 |
+
|
| 1092 |
+
# Paste into grid
|
| 1093 |
+
grid.paste(img, (x, y))
|
| 1094 |
+
|
| 1095 |
+
logging.info(
|
| 1096 |
+
f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
|
| 1097 |
+
)
|
| 1098 |
+
return pil_to_tensor(grid)
|
| 1099 |
+
|
| 1100 |
+
|
| 1101 |
+
class MergeImageListsNode(ImageProcessingNode):
|
| 1102 |
+
"""Merge multiple image lists into a single list."""
|
| 1103 |
+
|
| 1104 |
+
node_id = "MergeImageLists"
|
| 1105 |
+
display_name = "Merge Image Lists"
|
| 1106 |
+
description = "Concatenate multiple image lists into one."
|
| 1107 |
+
is_group_process = True # Receives images as list
|
| 1108 |
+
|
| 1109 |
+
@classmethod
|
| 1110 |
+
def _group_process(cls, images):
|
| 1111 |
+
"""Simply return the images list (already merged by input handling)."""
|
| 1112 |
+
# When multiple list inputs are connected, they're concatenated
|
| 1113 |
+
# For now, this is a simple pass-through
|
| 1114 |
+
logging.info(f"Merged image list contains {len(images)} images")
|
| 1115 |
+
return images
|
| 1116 |
+
|
| 1117 |
+
|
| 1118 |
+
class MergeTextListsNode(TextProcessingNode):
|
| 1119 |
+
"""Merge multiple text lists into a single list."""
|
| 1120 |
+
|
| 1121 |
+
node_id = "MergeTextLists"
|
| 1122 |
+
display_name = "Merge Text Lists"
|
| 1123 |
+
description = "Concatenate multiple text lists into one."
|
| 1124 |
+
is_group_process = True # Receives texts as list
|
| 1125 |
+
|
| 1126 |
+
@classmethod
|
| 1127 |
+
def _group_process(cls, texts):
|
| 1128 |
+
"""Simply return the texts list (already merged by input handling)."""
|
| 1129 |
+
# When multiple list inputs are connected, they're concatenated
|
| 1130 |
+
# For now, this is a simple pass-through
|
| 1131 |
+
logging.info(f"Merged text list contains {len(texts)} texts")
|
| 1132 |
+
return texts
|
| 1133 |
+
|
| 1134 |
+
|
| 1135 |
+
# ========== Training Dataset Nodes ==========
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
class ResolutionBucket(io.ComfyNode):
|
| 1139 |
+
"""Bucket latents and conditions by resolution for efficient batch training."""
|
| 1140 |
+
|
| 1141 |
+
@classmethod
|
| 1142 |
+
def define_schema(cls):
|
| 1143 |
+
return io.Schema(
|
| 1144 |
+
node_id="ResolutionBucket",
|
| 1145 |
+
display_name="Resolution Bucket",
|
| 1146 |
+
category="dataset",
|
| 1147 |
+
is_experimental=True,
|
| 1148 |
+
is_input_list=True,
|
| 1149 |
+
inputs=[
|
| 1150 |
+
io.Latent.Input(
|
| 1151 |
+
"latents",
|
| 1152 |
+
tooltip="List of latent dicts to bucket by resolution.",
|
| 1153 |
+
),
|
| 1154 |
+
io.Conditioning.Input(
|
| 1155 |
+
"conditioning",
|
| 1156 |
+
tooltip="List of conditioning lists (must match latents length).",
|
| 1157 |
+
),
|
| 1158 |
+
],
|
| 1159 |
+
outputs=[
|
| 1160 |
+
io.Latent.Output(
|
| 1161 |
+
display_name="latents",
|
| 1162 |
+
is_output_list=True,
|
| 1163 |
+
tooltip="List of batched latent dicts, one per resolution bucket.",
|
| 1164 |
+
),
|
| 1165 |
+
io.Conditioning.Output(
|
| 1166 |
+
display_name="conditioning",
|
| 1167 |
+
is_output_list=True,
|
| 1168 |
+
tooltip="List of condition lists, one per resolution bucket.",
|
| 1169 |
+
),
|
| 1170 |
+
],
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
@classmethod
|
| 1174 |
+
def execute(cls, latents, conditioning):
|
| 1175 |
+
# latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
|
| 1176 |
+
# conditioning: list[list[cond]]
|
| 1177 |
+
|
| 1178 |
+
# Validate lengths match
|
| 1179 |
+
if len(latents) != len(conditioning):
|
| 1180 |
+
raise ValueError(
|
| 1181 |
+
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
|
| 1182 |
+
)
|
| 1183 |
+
|
| 1184 |
+
# Flatten latents and conditions to individual samples
|
| 1185 |
+
flat_latents = [] # list of (C, H, W) tensors
|
| 1186 |
+
flat_conditions = [] # list of condition lists
|
| 1187 |
+
|
| 1188 |
+
for latent_dict, cond in zip(latents, conditioning):
|
| 1189 |
+
samples = latent_dict["samples"] # (B, C, H, W)
|
| 1190 |
+
batch_size = samples.shape[0]
|
| 1191 |
+
|
| 1192 |
+
# cond is a list of conditions with length == batch_size
|
| 1193 |
+
for i in range(batch_size):
|
| 1194 |
+
flat_latents.append(samples[i]) # (C, H, W)
|
| 1195 |
+
flat_conditions.append(cond[i]) # single condition
|
| 1196 |
+
|
| 1197 |
+
# Group by resolution (H, W)
|
| 1198 |
+
buckets = {} # (H, W) -> {"latents": list, "conditions": list}
|
| 1199 |
+
|
| 1200 |
+
for latent, cond in zip(flat_latents, flat_conditions):
|
| 1201 |
+
# latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
|
| 1202 |
+
h, w = latent.shape[-2], latent.shape[-1]
|
| 1203 |
+
key = (h, w)
|
| 1204 |
+
|
| 1205 |
+
if key not in buckets:
|
| 1206 |
+
buckets[key] = {"latents": [], "conditions": []}
|
| 1207 |
+
|
| 1208 |
+
buckets[key]["latents"].append(latent)
|
| 1209 |
+
buckets[key]["conditions"].append(cond)
|
| 1210 |
+
|
| 1211 |
+
# Convert buckets to output format
|
| 1212 |
+
output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
|
| 1213 |
+
output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
|
| 1214 |
+
|
| 1215 |
+
for (h, w), bucket_data in buckets.items():
|
| 1216 |
+
# Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
|
| 1217 |
+
stacked_latents = torch.stack(bucket_data["latents"], dim=0)
|
| 1218 |
+
output_latents.append({"samples": stacked_latents})
|
| 1219 |
+
|
| 1220 |
+
# Conditions stay as list of condition lists
|
| 1221 |
+
output_conditions.append(bucket_data["conditions"])
|
| 1222 |
+
|
| 1223 |
+
logging.info(
|
| 1224 |
+
f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
|
| 1225 |
+
)
|
| 1226 |
+
|
| 1227 |
+
logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
|
| 1228 |
+
return io.NodeOutput(output_latents, output_conditions)
|
| 1229 |
+
|
| 1230 |
+
|
| 1231 |
+
class MakeTrainingDataset(io.ComfyNode):
|
| 1232 |
+
"""Encode images with VAE and texts with CLIP to create a training dataset."""
|
| 1233 |
+
@classmethod
|
| 1234 |
+
def define_schema(cls):
|
| 1235 |
+
return io.Schema(
|
| 1236 |
+
node_id="MakeTrainingDataset",
|
| 1237 |
+
search_aliases=["encode dataset"],
|
| 1238 |
+
display_name="Make Training Dataset",
|
| 1239 |
+
category="dataset",
|
| 1240 |
+
is_experimental=True,
|
| 1241 |
+
is_input_list=True, # images and texts as lists
|
| 1242 |
+
inputs=[
|
| 1243 |
+
io.Image.Input("images", tooltip="List of images to encode."),
|
| 1244 |
+
io.Vae.Input(
|
| 1245 |
+
"vae", tooltip="VAE model for encoding images to latents."
|
| 1246 |
+
),
|
| 1247 |
+
io.Clip.Input(
|
| 1248 |
+
"clip", tooltip="CLIP model for encoding text to conditioning."
|
| 1249 |
+
),
|
| 1250 |
+
io.String.Input(
|
| 1251 |
+
"texts",
|
| 1252 |
+
optional=True,
|
| 1253 |
+
tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
|
| 1254 |
+
),
|
| 1255 |
+
],
|
| 1256 |
+
outputs=[
|
| 1257 |
+
io.Latent.Output(
|
| 1258 |
+
display_name="latents",
|
| 1259 |
+
is_output_list=True,
|
| 1260 |
+
tooltip="List of latent dicts",
|
| 1261 |
+
),
|
| 1262 |
+
io.Conditioning.Output(
|
| 1263 |
+
display_name="conditioning",
|
| 1264 |
+
is_output_list=True,
|
| 1265 |
+
tooltip="List of conditioning lists",
|
| 1266 |
+
),
|
| 1267 |
+
],
|
| 1268 |
+
)
|
| 1269 |
+
|
| 1270 |
+
@classmethod
|
| 1271 |
+
def execute(cls, images, vae, clip, texts=None):
|
| 1272 |
+
# Extract scalars (vae and clip are single values wrapped in lists)
|
| 1273 |
+
vae = vae[0]
|
| 1274 |
+
clip = clip[0]
|
| 1275 |
+
|
| 1276 |
+
# Handle text list
|
| 1277 |
+
num_images = len(images)
|
| 1278 |
+
|
| 1279 |
+
if texts is None or len(texts) == 0:
|
| 1280 |
+
# Treat as [""] for unconditional training
|
| 1281 |
+
texts = [""]
|
| 1282 |
+
|
| 1283 |
+
if len(texts) == 1 and num_images > 1:
|
| 1284 |
+
# Repeat single text for all images
|
| 1285 |
+
texts = texts * num_images
|
| 1286 |
+
elif len(texts) != num_images:
|
| 1287 |
+
raise ValueError(
|
| 1288 |
+
f"Number of texts ({len(texts)}) does not match number of images ({num_images}). "
|
| 1289 |
+
f"Text list should have length {num_images}, 1, or 0."
|
| 1290 |
+
)
|
| 1291 |
+
|
| 1292 |
+
# Encode images with VAE
|
| 1293 |
+
logging.info(f"Encoding {num_images} images with VAE...")
|
| 1294 |
+
latents_list = [] # list[{"samples": tensor}]
|
| 1295 |
+
for img_tensor in images:
|
| 1296 |
+
# img_tensor is [1, H, W, 3]
|
| 1297 |
+
latent_tensor = vae.encode(img_tensor[:, :, :, :3])
|
| 1298 |
+
latents_list.append({"samples": latent_tensor})
|
| 1299 |
+
|
| 1300 |
+
# Encode texts with CLIP
|
| 1301 |
+
logging.info(f"Encoding {len(texts)} texts with CLIP...")
|
| 1302 |
+
conditioning_list = [] # list[list[cond]]
|
| 1303 |
+
for text in texts:
|
| 1304 |
+
if text == "":
|
| 1305 |
+
cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
|
| 1306 |
+
else:
|
| 1307 |
+
tokens = clip.tokenize(text)
|
| 1308 |
+
cond = clip.encode_from_tokens_scheduled(tokens)
|
| 1309 |
+
conditioning_list.append(cond)
|
| 1310 |
+
|
| 1311 |
+
logging.info(
|
| 1312 |
+
f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning."
|
| 1313 |
+
)
|
| 1314 |
+
return io.NodeOutput(latents_list, conditioning_list)
|
| 1315 |
+
|
| 1316 |
+
|
| 1317 |
+
class SaveTrainingDataset(io.ComfyNode):
|
| 1318 |
+
"""Save encoded training dataset (latents + conditioning) to disk."""
|
| 1319 |
+
@classmethod
|
| 1320 |
+
def define_schema(cls):
|
| 1321 |
+
return io.Schema(
|
| 1322 |
+
node_id="SaveTrainingDataset",
|
| 1323 |
+
search_aliases=["export training data"],
|
| 1324 |
+
display_name="Save Training Dataset",
|
| 1325 |
+
category="dataset",
|
| 1326 |
+
is_experimental=True,
|
| 1327 |
+
is_output_node=True,
|
| 1328 |
+
is_input_list=True, # Receive lists
|
| 1329 |
+
inputs=[
|
| 1330 |
+
io.Latent.Input(
|
| 1331 |
+
"latents",
|
| 1332 |
+
tooltip="List of latent dicts from MakeTrainingDataset.",
|
| 1333 |
+
),
|
| 1334 |
+
io.Conditioning.Input(
|
| 1335 |
+
"conditioning",
|
| 1336 |
+
tooltip="List of conditioning lists from MakeTrainingDataset.",
|
| 1337 |
+
),
|
| 1338 |
+
io.String.Input(
|
| 1339 |
+
"folder_name",
|
| 1340 |
+
default="training_dataset",
|
| 1341 |
+
tooltip="Name of folder to save dataset (inside output directory).",
|
| 1342 |
+
),
|
| 1343 |
+
io.Int.Input(
|
| 1344 |
+
"shard_size",
|
| 1345 |
+
default=1000,
|
| 1346 |
+
min=1,
|
| 1347 |
+
max=100000,
|
| 1348 |
+
tooltip="Number of samples per shard file.",
|
| 1349 |
+
advanced=True,
|
| 1350 |
+
),
|
| 1351 |
+
],
|
| 1352 |
+
outputs=[],
|
| 1353 |
+
)
|
| 1354 |
+
|
| 1355 |
+
@classmethod
|
| 1356 |
+
def execute(cls, latents, conditioning, folder_name, shard_size):
|
| 1357 |
+
# Extract scalars
|
| 1358 |
+
folder_name = folder_name[0]
|
| 1359 |
+
shard_size = shard_size[0]
|
| 1360 |
+
|
| 1361 |
+
# latents: list[{"samples": tensor}]
|
| 1362 |
+
# conditioning: list[list[cond]]
|
| 1363 |
+
|
| 1364 |
+
# Validate lengths match
|
| 1365 |
+
if len(latents) != len(conditioning):
|
| 1366 |
+
raise ValueError(
|
| 1367 |
+
f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
|
| 1368 |
+
f"Something went wrong in dataset preparation."
|
| 1369 |
+
)
|
| 1370 |
+
|
| 1371 |
+
# Create output directory
|
| 1372 |
+
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
| 1373 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 1374 |
+
|
| 1375 |
+
# Prepare data pairs
|
| 1376 |
+
num_samples = len(latents)
|
| 1377 |
+
num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division
|
| 1378 |
+
|
| 1379 |
+
logging.info(
|
| 1380 |
+
f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
|
| 1381 |
+
)
|
| 1382 |
+
|
| 1383 |
+
# Save data in shards
|
| 1384 |
+
for shard_idx in range(num_shards):
|
| 1385 |
+
start_idx = shard_idx * shard_size
|
| 1386 |
+
end_idx = min(start_idx + shard_size, num_samples)
|
| 1387 |
+
|
| 1388 |
+
# Get shard data (list of latent dicts and conditioning lists)
|
| 1389 |
+
shard_data = {
|
| 1390 |
+
"latents": latents[start_idx:end_idx],
|
| 1391 |
+
"conditioning": conditioning[start_idx:end_idx],
|
| 1392 |
+
}
|
| 1393 |
+
|
| 1394 |
+
# Save shard
|
| 1395 |
+
shard_filename = f"shard_{shard_idx:04d}.pkl"
|
| 1396 |
+
shard_path = os.path.join(output_dir, shard_filename)
|
| 1397 |
+
|
| 1398 |
+
with open(shard_path, "wb") as f:
|
| 1399 |
+
torch.save(shard_data, f)
|
| 1400 |
+
|
| 1401 |
+
logging.info(
|
| 1402 |
+
f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
|
| 1403 |
+
)
|
| 1404 |
+
|
| 1405 |
+
# Save metadata
|
| 1406 |
+
metadata = {
|
| 1407 |
+
"num_samples": num_samples,
|
| 1408 |
+
"num_shards": num_shards,
|
| 1409 |
+
"shard_size": shard_size,
|
| 1410 |
+
}
|
| 1411 |
+
metadata_path = os.path.join(output_dir, "metadata.json")
|
| 1412 |
+
with open(metadata_path, "w") as f:
|
| 1413 |
+
json.dump(metadata, f, indent=2)
|
| 1414 |
+
|
| 1415 |
+
logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
|
| 1416 |
+
return io.NodeOutput()
|
| 1417 |
+
|
| 1418 |
+
|
| 1419 |
+
class LoadTrainingDataset(io.ComfyNode):
|
| 1420 |
+
"""Load encoded training dataset from disk."""
|
| 1421 |
+
@classmethod
|
| 1422 |
+
def define_schema(cls):
|
| 1423 |
+
return io.Schema(
|
| 1424 |
+
node_id="LoadTrainingDataset",
|
| 1425 |
+
search_aliases=["import dataset", "training data"],
|
| 1426 |
+
display_name="Load Training Dataset",
|
| 1427 |
+
category="dataset",
|
| 1428 |
+
is_experimental=True,
|
| 1429 |
+
inputs=[
|
| 1430 |
+
io.String.Input(
|
| 1431 |
+
"folder_name",
|
| 1432 |
+
default="training_dataset",
|
| 1433 |
+
tooltip="Name of folder containing the saved dataset (inside output directory).",
|
| 1434 |
+
),
|
| 1435 |
+
],
|
| 1436 |
+
outputs=[
|
| 1437 |
+
io.Latent.Output(
|
| 1438 |
+
display_name="latents",
|
| 1439 |
+
is_output_list=True,
|
| 1440 |
+
tooltip="List of latent dicts",
|
| 1441 |
+
),
|
| 1442 |
+
io.Conditioning.Output(
|
| 1443 |
+
display_name="conditioning",
|
| 1444 |
+
is_output_list=True,
|
| 1445 |
+
tooltip="List of conditioning lists",
|
| 1446 |
+
),
|
| 1447 |
+
],
|
| 1448 |
+
)
|
| 1449 |
+
|
| 1450 |
+
@classmethod
|
| 1451 |
+
def execute(cls, folder_name):
|
| 1452 |
+
# Get dataset directory
|
| 1453 |
+
dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
| 1454 |
+
|
| 1455 |
+
if not os.path.exists(dataset_dir):
|
| 1456 |
+
raise ValueError(f"Dataset directory not found: {dataset_dir}")
|
| 1457 |
+
|
| 1458 |
+
# Find all shard files
|
| 1459 |
+
shard_files = sorted(
|
| 1460 |
+
[
|
| 1461 |
+
f
|
| 1462 |
+
for f in os.listdir(dataset_dir)
|
| 1463 |
+
if f.startswith("shard_") and f.endswith(".pkl")
|
| 1464 |
+
]
|
| 1465 |
+
)
|
| 1466 |
+
|
| 1467 |
+
if not shard_files:
|
| 1468 |
+
raise ValueError(f"No shard files found in {dataset_dir}")
|
| 1469 |
+
|
| 1470 |
+
logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
|
| 1471 |
+
|
| 1472 |
+
# Load all shards
|
| 1473 |
+
all_latents = [] # list[{"samples": tensor}]
|
| 1474 |
+
all_conditioning = [] # list[list[cond]]
|
| 1475 |
+
|
| 1476 |
+
for shard_file in shard_files:
|
| 1477 |
+
shard_path = os.path.join(dataset_dir, shard_file)
|
| 1478 |
+
|
| 1479 |
+
with open(shard_path, "rb") as f:
|
| 1480 |
+
shard_data = torch.load(f)
|
| 1481 |
+
|
| 1482 |
+
all_latents.extend(shard_data["latents"])
|
| 1483 |
+
all_conditioning.extend(shard_data["conditioning"])
|
| 1484 |
+
|
| 1485 |
+
logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
|
| 1486 |
+
|
| 1487 |
+
logging.info(
|
| 1488 |
+
f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
|
| 1489 |
+
)
|
| 1490 |
+
return io.NodeOutput(all_latents, all_conditioning)
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
# ========== Extension Setup ==========
|
| 1494 |
+
|
| 1495 |
+
|
| 1496 |
+
class DatasetExtension(ComfyExtension):
|
| 1497 |
+
@override
|
| 1498 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 1499 |
+
return [
|
| 1500 |
+
# Data loading/saving nodes
|
| 1501 |
+
LoadImageDataSetFromFolderNode,
|
| 1502 |
+
LoadImageTextDataSetFromFolderNode,
|
| 1503 |
+
SaveImageDataSetToFolderNode,
|
| 1504 |
+
SaveImageTextDataSetToFolderNode,
|
| 1505 |
+
# Image transform nodes
|
| 1506 |
+
ResizeImagesByShorterEdgeNode,
|
| 1507 |
+
ResizeImagesByLongerEdgeNode,
|
| 1508 |
+
CenterCropImagesNode,
|
| 1509 |
+
RandomCropImagesNode,
|
| 1510 |
+
NormalizeImagesNode,
|
| 1511 |
+
AdjustBrightnessNode,
|
| 1512 |
+
AdjustContrastNode,
|
| 1513 |
+
ShuffleDatasetNode,
|
| 1514 |
+
ShuffleImageTextDatasetNode,
|
| 1515 |
+
# Text transform nodes
|
| 1516 |
+
TextToLowercaseNode,
|
| 1517 |
+
TextToUppercaseNode,
|
| 1518 |
+
TruncateTextNode,
|
| 1519 |
+
AddTextPrefixNode,
|
| 1520 |
+
AddTextSuffixNode,
|
| 1521 |
+
ReplaceTextNode,
|
| 1522 |
+
StripWhitespaceNode,
|
| 1523 |
+
# Group processing examples
|
| 1524 |
+
ImageDeduplicationNode,
|
| 1525 |
+
ImageGridNode,
|
| 1526 |
+
MergeImageListsNode,
|
| 1527 |
+
MergeTextListsNode,
|
| 1528 |
+
# Training dataset nodes
|
| 1529 |
+
MakeTrainingDataset,
|
| 1530 |
+
SaveTrainingDataset,
|
| 1531 |
+
LoadTrainingDataset,
|
| 1532 |
+
ResolutionBucket,
|
| 1533 |
+
]
|
| 1534 |
+
|
| 1535 |
+
|
| 1536 |
+
async def comfy_entrypoint() -> DatasetExtension:
|
| 1537 |
+
return DatasetExtension()
|
ComfyUI/comfy_extras/nodes_differential_diffusion.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# code adapted from https://github.com/exx8/differential-diffusion
|
| 2 |
+
|
| 3 |
+
from typing_extensions import override
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from comfy_api.latest import ComfyExtension, io
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DifferentialDiffusion(io.ComfyNode):
|
| 10 |
+
@classmethod
|
| 11 |
+
def define_schema(cls):
|
| 12 |
+
return io.Schema(
|
| 13 |
+
node_id="DifferentialDiffusion",
|
| 14 |
+
search_aliases=["inpaint gradient", "variable denoise strength"],
|
| 15 |
+
display_name="Differential Diffusion",
|
| 16 |
+
category="_for_testing",
|
| 17 |
+
inputs=[
|
| 18 |
+
io.Model.Input("model"),
|
| 19 |
+
io.Float.Input(
|
| 20 |
+
"strength",
|
| 21 |
+
default=1.0,
|
| 22 |
+
min=0.0,
|
| 23 |
+
max=1.0,
|
| 24 |
+
step=0.01,
|
| 25 |
+
optional=True,
|
| 26 |
+
),
|
| 27 |
+
],
|
| 28 |
+
outputs=[io.Model.Output()],
|
| 29 |
+
is_experimental=True,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def execute(cls, model, strength=1.0) -> io.NodeOutput:
|
| 34 |
+
model = model.clone()
|
| 35 |
+
model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength))
|
| 36 |
+
return io.NodeOutput(model)
|
| 37 |
+
|
| 38 |
+
@classmethod
|
| 39 |
+
def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
|
| 40 |
+
model = extra_options["model"]
|
| 41 |
+
step_sigmas = extra_options["sigmas"]
|
| 42 |
+
sigma_to = model.inner_model.model_sampling.sigma_min
|
| 43 |
+
if step_sigmas[-1] > sigma_to:
|
| 44 |
+
sigma_to = step_sigmas[-1]
|
| 45 |
+
sigma_from = step_sigmas[0]
|
| 46 |
+
|
| 47 |
+
ts_from = model.inner_model.model_sampling.timestep(sigma_from)
|
| 48 |
+
ts_to = model.inner_model.model_sampling.timestep(sigma_to)
|
| 49 |
+
current_ts = model.inner_model.model_sampling.timestep(sigma[0])
|
| 50 |
+
|
| 51 |
+
threshold = (current_ts - ts_to) / (ts_from - ts_to)
|
| 52 |
+
|
| 53 |
+
# Generate the binary mask based on the threshold
|
| 54 |
+
binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype)
|
| 55 |
+
|
| 56 |
+
# Blend binary mask with the original denoise_mask using strength
|
| 57 |
+
if strength and strength < 1:
|
| 58 |
+
blended_mask = strength * binary_mask + (1 - strength) * denoise_mask
|
| 59 |
+
return blended_mask
|
| 60 |
+
else:
|
| 61 |
+
return binary_mask
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DifferentialDiffusionExtension(ComfyExtension):
|
| 65 |
+
@override
|
| 66 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 67 |
+
return [
|
| 68 |
+
DifferentialDiffusion,
|
| 69 |
+
]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def comfy_entrypoint() -> DifferentialDiffusionExtension:
|
| 73 |
+
return DifferentialDiffusionExtension()
|
ComfyUI/comfy_extras/nodes_easycache.py
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import TYPE_CHECKING, Union
|
| 3 |
+
from comfy_api.latest import io, ComfyExtension
|
| 4 |
+
import comfy.patcher_extension
|
| 5 |
+
import logging
|
| 6 |
+
import torch
|
| 7 |
+
import comfy.model_patcher
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
from uuid import UUID
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _extract_tensor(data, output_channels):
|
| 13 |
+
"""Extract tensor from data, handling both single tensors and lists."""
|
| 14 |
+
if isinstance(data, list):
|
| 15 |
+
# LTX2 AV tensors: [video, audio]
|
| 16 |
+
return data[0][:, :output_channels], data[1][:, :output_channels]
|
| 17 |
+
return data[:, :output_channels], None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def easycache_forward_wrapper(executor, *args, **kwargs):
|
| 21 |
+
# get values from args
|
| 22 |
+
transformer_options: dict[str] = args[-1]
|
| 23 |
+
if not isinstance(transformer_options, dict):
|
| 24 |
+
transformer_options = kwargs.get("transformer_options")
|
| 25 |
+
if not transformer_options:
|
| 26 |
+
transformer_options = args[-2]
|
| 27 |
+
easycache: EasyCacheHolder = transformer_options["easycache"]
|
| 28 |
+
x, ax = _extract_tensor(args[0], easycache.output_channels)
|
| 29 |
+
sigmas = transformer_options["sigmas"]
|
| 30 |
+
uuids = transformer_options["uuids"]
|
| 31 |
+
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
|
| 32 |
+
return executor(*args, **kwargs)
|
| 33 |
+
# prepare next x_prev
|
| 34 |
+
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
|
| 35 |
+
next_x_prev = x
|
| 36 |
+
input_change = None
|
| 37 |
+
do_easycache = easycache.should_do_easycache(sigmas)
|
| 38 |
+
if do_easycache:
|
| 39 |
+
easycache.check_metadata(x)
|
| 40 |
+
# if there isn't a cache diff for current conds, we cannot skip this step
|
| 41 |
+
can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
|
| 42 |
+
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
| 43 |
+
if easycache.skip_current_step and can_apply_cache_diff:
|
| 44 |
+
if easycache.verbose:
|
| 45 |
+
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
| 46 |
+
result = easycache.apply_cache_diff(x, uuids)
|
| 47 |
+
if ax is not None:
|
| 48 |
+
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
|
| 49 |
+
return [result, result_audio]
|
| 50 |
+
return result
|
| 51 |
+
if easycache.initial_step:
|
| 52 |
+
easycache.first_cond_uuid = uuids[0]
|
| 53 |
+
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
|
| 54 |
+
easycache.initial_step = False
|
| 55 |
+
if has_first_cond_uuid:
|
| 56 |
+
if easycache.has_x_prev_subsampled():
|
| 57 |
+
input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
|
| 58 |
+
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
| 59 |
+
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
| 60 |
+
easycache.cumulative_change_rate += approx_output_change_rate
|
| 61 |
+
if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
|
| 62 |
+
if easycache.verbose:
|
| 63 |
+
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
| 64 |
+
# other conds should also skip this step, and instead use their cached values
|
| 65 |
+
easycache.skip_current_step = True
|
| 66 |
+
result = easycache.apply_cache_diff(x, uuids)
|
| 67 |
+
if ax is not None:
|
| 68 |
+
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
|
| 69 |
+
return [result, result_audio]
|
| 70 |
+
return result
|
| 71 |
+
else:
|
| 72 |
+
if easycache.verbose:
|
| 73 |
+
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
| 74 |
+
easycache.cumulative_change_rate = 0.0
|
| 75 |
+
|
| 76 |
+
full_output: torch.Tensor = executor(*args, **kwargs)
|
| 77 |
+
output, audio_output = _extract_tensor(full_output, easycache.output_channels)
|
| 78 |
+
if has_first_cond_uuid and easycache.has_output_prev_norm():
|
| 79 |
+
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
|
| 80 |
+
if easycache.verbose:
|
| 81 |
+
output_change_rate = output_change / easycache.output_prev_norm
|
| 82 |
+
easycache.output_change_rates.append(output_change_rate.item())
|
| 83 |
+
if easycache.has_relative_transformation_rate():
|
| 84 |
+
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
| 85 |
+
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
| 86 |
+
if easycache.verbose:
|
| 87 |
+
logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
| 88 |
+
if input_change is not None:
|
| 89 |
+
easycache.relative_transformation_rate = output_change / input_change
|
| 90 |
+
if easycache.verbose:
|
| 91 |
+
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
|
| 92 |
+
# TODO: allow cache_diff to be offloaded
|
| 93 |
+
easycache.update_cache_diff(output, next_x_prev, uuids)
|
| 94 |
+
if audio_output is not None:
|
| 95 |
+
easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
|
| 96 |
+
if has_first_cond_uuid:
|
| 97 |
+
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
|
| 98 |
+
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
|
| 99 |
+
easycache.output_prev_norm = output.flatten().abs().mean()
|
| 100 |
+
if easycache.verbose:
|
| 101 |
+
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
| 102 |
+
return full_output
|
| 103 |
+
|
| 104 |
+
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
| 105 |
+
# get values from args
|
| 106 |
+
timestep: float = args[1]
|
| 107 |
+
model_options: dict[str] = args[2]
|
| 108 |
+
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
|
| 109 |
+
if easycache.is_past_end_timestep(timestep):
|
| 110 |
+
return executor(*args, **kwargs)
|
| 111 |
+
x: torch.Tensor = args[0][:, :easycache.output_channels]
|
| 112 |
+
# prepare next x_prev
|
| 113 |
+
next_x_prev = x
|
| 114 |
+
input_change = None
|
| 115 |
+
do_easycache = easycache.should_do_easycache(timestep)
|
| 116 |
+
if do_easycache:
|
| 117 |
+
easycache.check_metadata(x)
|
| 118 |
+
if easycache.has_x_prev_subsampled():
|
| 119 |
+
if easycache.has_x_prev_subsampled():
|
| 120 |
+
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
|
| 121 |
+
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
|
| 122 |
+
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
| 123 |
+
easycache.cumulative_change_rate += approx_output_change_rate
|
| 124 |
+
if easycache.cumulative_change_rate < easycache.reuse_threshold:
|
| 125 |
+
if easycache.verbose:
|
| 126 |
+
logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
| 127 |
+
# other conds should also skip this step, and instead use their cached values
|
| 128 |
+
easycache.skip_current_step = True
|
| 129 |
+
return easycache.apply_cache_diff(x)
|
| 130 |
+
else:
|
| 131 |
+
if easycache.verbose:
|
| 132 |
+
logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
| 133 |
+
easycache.cumulative_change_rate = 0.0
|
| 134 |
+
output: torch.Tensor = executor(*args, **kwargs)
|
| 135 |
+
if easycache.has_output_prev_norm():
|
| 136 |
+
output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
|
| 137 |
+
if easycache.verbose:
|
| 138 |
+
output_change_rate = output_change / easycache.output_prev_norm
|
| 139 |
+
easycache.output_change_rates.append(output_change_rate.item())
|
| 140 |
+
if easycache.has_relative_transformation_rate():
|
| 141 |
+
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
|
| 142 |
+
easycache.approx_output_change_rates.append(approx_output_change_rate.item())
|
| 143 |
+
if easycache.verbose:
|
| 144 |
+
logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
|
| 145 |
+
if input_change is not None:
|
| 146 |
+
easycache.relative_transformation_rate = output_change / input_change
|
| 147 |
+
if easycache.verbose:
|
| 148 |
+
logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
|
| 149 |
+
# TODO: allow cache_diff to be offloaded
|
| 150 |
+
easycache.update_cache_diff(output, next_x_prev)
|
| 151 |
+
easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
|
| 152 |
+
easycache.output_prev_subsampled = easycache.subsample(output)
|
| 153 |
+
easycache.output_prev_norm = output.flatten().abs().mean()
|
| 154 |
+
if easycache.verbose:
|
| 155 |
+
logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
| 156 |
+
return output
|
| 157 |
+
|
| 158 |
+
def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
|
| 159 |
+
model_options = args[-1]
|
| 160 |
+
easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
|
| 161 |
+
easycache.skip_current_step = False
|
| 162 |
+
# TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
|
| 163 |
+
return executor(*args, **kwargs)
|
| 164 |
+
|
| 165 |
+
def easycache_sample_wrapper(executor, *args, **kwargs):
|
| 166 |
+
"""
|
| 167 |
+
This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end.
|
| 168 |
+
"""
|
| 169 |
+
try:
|
| 170 |
+
guider = executor.class_obj
|
| 171 |
+
orig_model_options = guider.model_options
|
| 172 |
+
guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
|
| 173 |
+
# clone and prepare timesteps
|
| 174 |
+
guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
|
| 175 |
+
easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
|
| 176 |
+
logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
|
| 177 |
+
return executor(*args, **kwargs)
|
| 178 |
+
finally:
|
| 179 |
+
easycache = guider.model_options['transformer_options']['easycache']
|
| 180 |
+
output_change_rates = easycache.output_change_rates
|
| 181 |
+
approx_output_change_rates = easycache.approx_output_change_rates
|
| 182 |
+
if easycache.verbose:
|
| 183 |
+
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
|
| 184 |
+
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
|
| 185 |
+
total_steps = len(args[3])-1
|
| 186 |
+
# catch division by zero for log statement; sucks to crash after all sampling is done
|
| 187 |
+
try:
|
| 188 |
+
speedup = total_steps/(total_steps-easycache.total_steps_skipped)
|
| 189 |
+
except ZeroDivisionError:
|
| 190 |
+
speedup = 1.0
|
| 191 |
+
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
|
| 192 |
+
easycache.reset()
|
| 193 |
+
guider.model_options = orig_model_options
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class EasyCacheHolder:
|
| 197 |
+
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
|
| 198 |
+
self.name = "EasyCache"
|
| 199 |
+
self.reuse_threshold = reuse_threshold
|
| 200 |
+
self.start_percent = start_percent
|
| 201 |
+
self.end_percent = end_percent
|
| 202 |
+
self.subsample_factor = subsample_factor
|
| 203 |
+
self.offload_cache_diff = offload_cache_diff
|
| 204 |
+
self.verbose = verbose
|
| 205 |
+
# timestep values
|
| 206 |
+
self.start_t = 0.0
|
| 207 |
+
self.end_t = 0.0
|
| 208 |
+
# control values
|
| 209 |
+
self.relative_transformation_rate: float = None
|
| 210 |
+
self.cumulative_change_rate = 0.0
|
| 211 |
+
self.initial_step = True
|
| 212 |
+
self.skip_current_step = False
|
| 213 |
+
# cache values
|
| 214 |
+
self.first_cond_uuid = None
|
| 215 |
+
self.x_prev_subsampled: torch.Tensor = None
|
| 216 |
+
self.output_prev_subsampled: torch.Tensor = None
|
| 217 |
+
self.output_prev_norm: torch.Tensor = None
|
| 218 |
+
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
|
| 219 |
+
self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
|
| 220 |
+
self.output_change_rates = []
|
| 221 |
+
self.approx_output_change_rates = []
|
| 222 |
+
self.total_steps_skipped = 0
|
| 223 |
+
# how to deal with mismatched dims
|
| 224 |
+
self.allow_mismatch = True
|
| 225 |
+
self.cut_from_start = True
|
| 226 |
+
self.state_metadata = None
|
| 227 |
+
self.output_channels = output_channels
|
| 228 |
+
|
| 229 |
+
def is_past_end_timestep(self, timestep: float) -> bool:
|
| 230 |
+
return not (timestep[0] > self.end_t).item()
|
| 231 |
+
|
| 232 |
+
def should_do_easycache(self, timestep: float) -> bool:
|
| 233 |
+
return (timestep[0] <= self.start_t).item()
|
| 234 |
+
|
| 235 |
+
def has_x_prev_subsampled(self) -> bool:
|
| 236 |
+
return self.x_prev_subsampled is not None
|
| 237 |
+
|
| 238 |
+
def has_output_prev_subsampled(self) -> bool:
|
| 239 |
+
return self.output_prev_subsampled is not None
|
| 240 |
+
|
| 241 |
+
def has_output_prev_norm(self) -> bool:
|
| 242 |
+
return self.output_prev_norm is not None
|
| 243 |
+
|
| 244 |
+
def has_relative_transformation_rate(self) -> bool:
|
| 245 |
+
return self.relative_transformation_rate is not None
|
| 246 |
+
|
| 247 |
+
def prepare_timesteps(self, model_sampling):
|
| 248 |
+
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
|
| 249 |
+
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
|
| 250 |
+
return self
|
| 251 |
+
|
| 252 |
+
def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor:
|
| 253 |
+
batch_offset = x.shape[0] // len(uuids)
|
| 254 |
+
uuid_idx = uuids.index(self.first_cond_uuid)
|
| 255 |
+
if self.subsample_factor > 1:
|
| 256 |
+
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor]
|
| 257 |
+
if clone:
|
| 258 |
+
return to_return.clone()
|
| 259 |
+
return to_return
|
| 260 |
+
to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...]
|
| 261 |
+
if clone:
|
| 262 |
+
return to_return.clone()
|
| 263 |
+
return to_return
|
| 264 |
+
|
| 265 |
+
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
|
| 266 |
+
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
|
| 267 |
+
|
| 268 |
+
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
|
| 269 |
+
if self.first_cond_uuid in uuids and not is_audio:
|
| 270 |
+
self.total_steps_skipped += 1
|
| 271 |
+
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
|
| 272 |
+
batch_offset = x.shape[0] // len(uuids)
|
| 273 |
+
for i, uuid in enumerate(uuids):
|
| 274 |
+
# slice out only what is relevant to this cond
|
| 275 |
+
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
|
| 276 |
+
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
| 277 |
+
if x.shape[1:] != cache_diffs[uuid].shape[1:]:
|
| 278 |
+
if not self.allow_mismatch:
|
| 279 |
+
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
|
| 280 |
+
slicing = []
|
| 281 |
+
skip_this_dim = True
|
| 282 |
+
for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
|
| 283 |
+
if skip_this_dim:
|
| 284 |
+
skip_this_dim = False
|
| 285 |
+
continue
|
| 286 |
+
if dim_u != dim_x:
|
| 287 |
+
if self.cut_from_start:
|
| 288 |
+
slicing.append(slice(dim_x-dim_u, None))
|
| 289 |
+
else:
|
| 290 |
+
slicing.append(slice(None, dim_u))
|
| 291 |
+
else:
|
| 292 |
+
slicing.append(slice(None))
|
| 293 |
+
batch_slice = batch_slice + slicing
|
| 294 |
+
x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
|
| 295 |
+
return x
|
| 296 |
+
|
| 297 |
+
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
|
| 298 |
+
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
|
| 299 |
+
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
| 300 |
+
if output.shape[1:] != x.shape[1:]:
|
| 301 |
+
if not self.allow_mismatch:
|
| 302 |
+
raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good")
|
| 303 |
+
slicing = []
|
| 304 |
+
skip_dim = True
|
| 305 |
+
for dim_o, dim_x in zip(output.shape, x.shape):
|
| 306 |
+
if not skip_dim and dim_o != dim_x:
|
| 307 |
+
if self.cut_from_start:
|
| 308 |
+
slicing.append(slice(dim_x-dim_o, None))
|
| 309 |
+
else:
|
| 310 |
+
slicing.append(slice(None, dim_o))
|
| 311 |
+
else:
|
| 312 |
+
slicing.append(slice(None))
|
| 313 |
+
skip_dim = False
|
| 314 |
+
x = x[tuple(slicing)]
|
| 315 |
+
diff = output - x
|
| 316 |
+
batch_offset = diff.shape[0] // len(uuids)
|
| 317 |
+
for i, uuid in enumerate(uuids):
|
| 318 |
+
cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
|
| 319 |
+
|
| 320 |
+
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
|
| 321 |
+
return self.first_cond_uuid in uuids
|
| 322 |
+
|
| 323 |
+
def check_metadata(self, x: torch.Tensor) -> bool:
|
| 324 |
+
metadata = (x.device, x.dtype, x.shape[1:])
|
| 325 |
+
if self.state_metadata is None:
|
| 326 |
+
self.state_metadata = metadata
|
| 327 |
+
return True
|
| 328 |
+
if metadata == self.state_metadata:
|
| 329 |
+
return True
|
| 330 |
+
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
| 331 |
+
self.reset()
|
| 332 |
+
return False
|
| 333 |
+
|
| 334 |
+
def reset(self):
|
| 335 |
+
self.relative_transformation_rate = 0.0
|
| 336 |
+
self.cumulative_change_rate = 0.0
|
| 337 |
+
self.initial_step = True
|
| 338 |
+
self.skip_current_step = False
|
| 339 |
+
self.output_change_rates = []
|
| 340 |
+
self.first_cond_uuid = None
|
| 341 |
+
del self.x_prev_subsampled
|
| 342 |
+
self.x_prev_subsampled = None
|
| 343 |
+
del self.output_prev_subsampled
|
| 344 |
+
self.output_prev_subsampled = None
|
| 345 |
+
del self.output_prev_norm
|
| 346 |
+
self.output_prev_norm = None
|
| 347 |
+
del self.uuid_cache_diffs
|
| 348 |
+
self.uuid_cache_diffs = {}
|
| 349 |
+
del self.uuid_cache_diffs_audio
|
| 350 |
+
self.uuid_cache_diffs_audio = {}
|
| 351 |
+
self.total_steps_skipped = 0
|
| 352 |
+
self.state_metadata = None
|
| 353 |
+
return self
|
| 354 |
+
|
| 355 |
+
def clone(self):
|
| 356 |
+
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class EasyCacheNode(io.ComfyNode):
|
| 360 |
+
@classmethod
|
| 361 |
+
def define_schema(cls) -> io.Schema:
|
| 362 |
+
return io.Schema(
|
| 363 |
+
node_id="EasyCache",
|
| 364 |
+
display_name="EasyCache",
|
| 365 |
+
description="Native EasyCache implementation.",
|
| 366 |
+
category="advanced/debug/model",
|
| 367 |
+
is_experimental=True,
|
| 368 |
+
inputs=[
|
| 369 |
+
io.Model.Input("model", tooltip="The model to add EasyCache to."),
|
| 370 |
+
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True),
|
| 371 |
+
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache.", advanced=True),
|
| 372 |
+
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache.", advanced=True),
|
| 373 |
+
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True),
|
| 374 |
+
],
|
| 375 |
+
outputs=[
|
| 376 |
+
io.Model.Output(tooltip="The model with EasyCache."),
|
| 377 |
+
],
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
@classmethod
|
| 381 |
+
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
| 382 |
+
model = model.clone()
|
| 383 |
+
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
|
| 384 |
+
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
|
| 385 |
+
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
|
| 386 |
+
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
|
| 387 |
+
return io.NodeOutput(model)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class LazyCacheHolder:
|
| 391 |
+
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
|
| 392 |
+
self.name = "LazyCache"
|
| 393 |
+
self.reuse_threshold = reuse_threshold
|
| 394 |
+
self.start_percent = start_percent
|
| 395 |
+
self.end_percent = end_percent
|
| 396 |
+
self.subsample_factor = subsample_factor
|
| 397 |
+
self.offload_cache_diff = offload_cache_diff
|
| 398 |
+
self.verbose = verbose
|
| 399 |
+
# timestep values
|
| 400 |
+
self.start_t = 0.0
|
| 401 |
+
self.end_t = 0.0
|
| 402 |
+
# control values
|
| 403 |
+
self.relative_transformation_rate: float = None
|
| 404 |
+
self.cumulative_change_rate = 0.0
|
| 405 |
+
self.initial_step = True
|
| 406 |
+
# cache values
|
| 407 |
+
self.x_prev_subsampled: torch.Tensor = None
|
| 408 |
+
self.output_prev_subsampled: torch.Tensor = None
|
| 409 |
+
self.output_prev_norm: torch.Tensor = None
|
| 410 |
+
self.cache_diff: torch.Tensor = None
|
| 411 |
+
self.output_change_rates = []
|
| 412 |
+
self.approx_output_change_rates = []
|
| 413 |
+
self.total_steps_skipped = 0
|
| 414 |
+
self.state_metadata = None
|
| 415 |
+
self.output_channels = output_channels
|
| 416 |
+
|
| 417 |
+
def has_cache_diff(self) -> bool:
|
| 418 |
+
return self.cache_diff is not None
|
| 419 |
+
|
| 420 |
+
def is_past_end_timestep(self, timestep: float) -> bool:
|
| 421 |
+
return not (timestep[0] > self.end_t).item()
|
| 422 |
+
|
| 423 |
+
def should_do_easycache(self, timestep: float) -> bool:
|
| 424 |
+
return (timestep[0] <= self.start_t).item()
|
| 425 |
+
|
| 426 |
+
def has_x_prev_subsampled(self) -> bool:
|
| 427 |
+
return self.x_prev_subsampled is not None
|
| 428 |
+
|
| 429 |
+
def has_output_prev_subsampled(self) -> bool:
|
| 430 |
+
return self.output_prev_subsampled is not None
|
| 431 |
+
|
| 432 |
+
def has_output_prev_norm(self) -> bool:
|
| 433 |
+
return self.output_prev_norm is not None
|
| 434 |
+
|
| 435 |
+
def has_relative_transformation_rate(self) -> bool:
|
| 436 |
+
return self.relative_transformation_rate is not None
|
| 437 |
+
|
| 438 |
+
def prepare_timesteps(self, model_sampling):
|
| 439 |
+
self.start_t = model_sampling.percent_to_sigma(self.start_percent)
|
| 440 |
+
self.end_t = model_sampling.percent_to_sigma(self.end_percent)
|
| 441 |
+
return self
|
| 442 |
+
|
| 443 |
+
def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
|
| 444 |
+
if self.subsample_factor > 1:
|
| 445 |
+
to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
|
| 446 |
+
if clone:
|
| 447 |
+
return to_return.clone()
|
| 448 |
+
return to_return
|
| 449 |
+
if clone:
|
| 450 |
+
return x.clone()
|
| 451 |
+
return x
|
| 452 |
+
|
| 453 |
+
def apply_cache_diff(self, x: torch.Tensor):
|
| 454 |
+
self.total_steps_skipped += 1
|
| 455 |
+
return x + self.cache_diff.to(x.device)
|
| 456 |
+
|
| 457 |
+
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
|
| 458 |
+
self.cache_diff = output - x
|
| 459 |
+
|
| 460 |
+
def check_metadata(self, x: torch.Tensor) -> bool:
|
| 461 |
+
metadata = (x.device, x.dtype, x.shape)
|
| 462 |
+
if self.state_metadata is None:
|
| 463 |
+
self.state_metadata = metadata
|
| 464 |
+
return True
|
| 465 |
+
if metadata == self.state_metadata:
|
| 466 |
+
return True
|
| 467 |
+
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
| 468 |
+
self.reset()
|
| 469 |
+
return False
|
| 470 |
+
|
| 471 |
+
def reset(self):
|
| 472 |
+
self.relative_transformation_rate = 0.0
|
| 473 |
+
self.cumulative_change_rate = 0.0
|
| 474 |
+
self.initial_step = True
|
| 475 |
+
self.output_change_rates = []
|
| 476 |
+
self.approx_output_change_rates = []
|
| 477 |
+
del self.cache_diff
|
| 478 |
+
self.cache_diff = None
|
| 479 |
+
del self.x_prev_subsampled
|
| 480 |
+
self.x_prev_subsampled = None
|
| 481 |
+
del self.output_prev_subsampled
|
| 482 |
+
self.output_prev_subsampled = None
|
| 483 |
+
del self.output_prev_norm
|
| 484 |
+
self.output_prev_norm = None
|
| 485 |
+
self.total_steps_skipped = 0
|
| 486 |
+
self.state_metadata = None
|
| 487 |
+
return self
|
| 488 |
+
|
| 489 |
+
def clone(self):
|
| 490 |
+
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
|
| 491 |
+
|
| 492 |
+
class LazyCacheNode(io.ComfyNode):
|
| 493 |
+
@classmethod
|
| 494 |
+
def define_schema(cls) -> io.Schema:
|
| 495 |
+
return io.Schema(
|
| 496 |
+
node_id="LazyCache",
|
| 497 |
+
display_name="LazyCache",
|
| 498 |
+
description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
|
| 499 |
+
category="advanced/debug/model",
|
| 500 |
+
is_experimental=True,
|
| 501 |
+
inputs=[
|
| 502 |
+
io.Model.Input("model", tooltip="The model to add LazyCache to."),
|
| 503 |
+
io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True),
|
| 504 |
+
io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache.", advanced=True),
|
| 505 |
+
io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache.", advanced=True),
|
| 506 |
+
io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True),
|
| 507 |
+
],
|
| 508 |
+
outputs=[
|
| 509 |
+
io.Model.Output(tooltip="The model with LazyCache."),
|
| 510 |
+
],
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
@classmethod
|
| 514 |
+
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
| 515 |
+
model = model.clone()
|
| 516 |
+
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
|
| 517 |
+
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
|
| 518 |
+
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
|
| 519 |
+
return io.NodeOutput(model)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
class EasyCacheExtension(ComfyExtension):
|
| 523 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 524 |
+
return [
|
| 525 |
+
EasyCacheNode,
|
| 526 |
+
LazyCacheNode,
|
| 527 |
+
]
|
| 528 |
+
|
| 529 |
+
def comfy_entrypoint():
|
| 530 |
+
return EasyCacheExtension()
|
ComfyUI/comfy_extras/nodes_edit_model.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import node_helpers
|
| 2 |
+
from typing_extensions import override
|
| 3 |
+
from comfy_api.latest import ComfyExtension, io
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ReferenceLatent(io.ComfyNode):
|
| 7 |
+
@classmethod
|
| 8 |
+
def define_schema(cls):
|
| 9 |
+
return io.Schema(
|
| 10 |
+
node_id="ReferenceLatent",
|
| 11 |
+
category="advanced/conditioning/edit_models",
|
| 12 |
+
description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.",
|
| 13 |
+
inputs=[
|
| 14 |
+
io.Conditioning.Input("conditioning"),
|
| 15 |
+
io.Latent.Input("latent", optional=True),
|
| 16 |
+
],
|
| 17 |
+
outputs=[
|
| 18 |
+
io.Conditioning.Output(),
|
| 19 |
+
]
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
@classmethod
|
| 23 |
+
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
|
| 24 |
+
if latent is not None:
|
| 25 |
+
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
|
| 26 |
+
return io.NodeOutput(conditioning)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class EditModelExtension(ComfyExtension):
|
| 30 |
+
@override
|
| 31 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 32 |
+
return [
|
| 33 |
+
ReferenceLatent,
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def comfy_entrypoint() -> EditModelExtension:
|
| 38 |
+
return EditModelExtension()
|
ComfyUI/comfy_extras/nodes_eps.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing_extensions import override
|
| 3 |
+
|
| 4 |
+
from comfy.k_diffusion.sampling import sigma_to_half_log_snr
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class EpsilonScaling(io.ComfyNode):
|
| 9 |
+
"""
|
| 10 |
+
Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
|
| 11 |
+
(https://arxiv.org/abs/2308.15321v6).
|
| 12 |
+
|
| 13 |
+
This method mitigates exposure bias by scaling the predicted noise during sampling,
|
| 14 |
+
which can significantly improve sample quality. This implementation uses the "uniform schedule"
|
| 15 |
+
recommended by the paper for its practicality and effectiveness.
|
| 16 |
+
"""
|
| 17 |
+
@classmethod
|
| 18 |
+
def define_schema(cls):
|
| 19 |
+
return io.Schema(
|
| 20 |
+
node_id="Epsilon Scaling",
|
| 21 |
+
category="model_patches/unet",
|
| 22 |
+
inputs=[
|
| 23 |
+
io.Model.Input("model"),
|
| 24 |
+
io.Float.Input(
|
| 25 |
+
"scaling_factor",
|
| 26 |
+
default=1.005,
|
| 27 |
+
min=0.5,
|
| 28 |
+
max=1.5,
|
| 29 |
+
step=0.001,
|
| 30 |
+
display_mode=io.NumberDisplay.number,
|
| 31 |
+
advanced=True,
|
| 32 |
+
),
|
| 33 |
+
],
|
| 34 |
+
outputs=[
|
| 35 |
+
io.Model.Output(),
|
| 36 |
+
],
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
@classmethod
|
| 40 |
+
def execute(cls, model, scaling_factor) -> io.NodeOutput:
|
| 41 |
+
# Prevent division by zero, though the UI's min value should prevent this.
|
| 42 |
+
if scaling_factor == 0:
|
| 43 |
+
scaling_factor = 1e-9
|
| 44 |
+
|
| 45 |
+
def epsilon_scaling_function(args):
|
| 46 |
+
"""
|
| 47 |
+
This function is applied after the CFG guidance has been calculated.
|
| 48 |
+
It recalculates the denoised latent by scaling the predicted noise.
|
| 49 |
+
"""
|
| 50 |
+
denoised = args["denoised"]
|
| 51 |
+
x = args["input"]
|
| 52 |
+
|
| 53 |
+
noise_pred = x - denoised
|
| 54 |
+
|
| 55 |
+
scaled_noise_pred = noise_pred / scaling_factor
|
| 56 |
+
|
| 57 |
+
new_denoised = x - scaled_noise_pred
|
| 58 |
+
|
| 59 |
+
return new_denoised
|
| 60 |
+
|
| 61 |
+
# Clone the model patcher to avoid modifying the original model in place
|
| 62 |
+
model_clone = model.clone()
|
| 63 |
+
|
| 64 |
+
model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
|
| 65 |
+
|
| 66 |
+
return io.NodeOutput(model_clone)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def compute_tsr_rescaling_factor(
|
| 70 |
+
snr: torch.Tensor, tsr_k: float, tsr_variance: float
|
| 71 |
+
) -> torch.Tensor:
|
| 72 |
+
"""Compute the rescaling score ratio in Temporal Score Rescaling.
|
| 73 |
+
|
| 74 |
+
See equation (6) in https://arxiv.org/pdf/2510.01184v1.
|
| 75 |
+
"""
|
| 76 |
+
posinf_mask = torch.isposinf(snr)
|
| 77 |
+
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
|
| 78 |
+
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class TemporalScoreRescaling(io.ComfyNode):
|
| 82 |
+
@classmethod
|
| 83 |
+
def define_schema(cls):
|
| 84 |
+
return io.Schema(
|
| 85 |
+
node_id="TemporalScoreRescaling",
|
| 86 |
+
display_name="TSR - Temporal Score Rescaling",
|
| 87 |
+
category="model_patches/unet",
|
| 88 |
+
inputs=[
|
| 89 |
+
io.Model.Input("model"),
|
| 90 |
+
io.Float.Input(
|
| 91 |
+
"tsr_k",
|
| 92 |
+
tooltip=(
|
| 93 |
+
"Controls the rescaling strength.\n"
|
| 94 |
+
"Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling."
|
| 95 |
+
),
|
| 96 |
+
default=0.95,
|
| 97 |
+
min=0.01,
|
| 98 |
+
max=100.0,
|
| 99 |
+
step=0.001,
|
| 100 |
+
display_mode=io.NumberDisplay.number,
|
| 101 |
+
advanced=True,
|
| 102 |
+
),
|
| 103 |
+
io.Float.Input(
|
| 104 |
+
"tsr_sigma",
|
| 105 |
+
tooltip=(
|
| 106 |
+
"Controls how early rescaling takes effect.\n"
|
| 107 |
+
"Larger values take effect earlier."
|
| 108 |
+
),
|
| 109 |
+
default=1.0,
|
| 110 |
+
min=0.01,
|
| 111 |
+
max=100.0,
|
| 112 |
+
step=0.001,
|
| 113 |
+
display_mode=io.NumberDisplay.number,
|
| 114 |
+
advanced=True,
|
| 115 |
+
),
|
| 116 |
+
],
|
| 117 |
+
outputs=[
|
| 118 |
+
io.Model.Output(
|
| 119 |
+
display_name="patched_model",
|
| 120 |
+
),
|
| 121 |
+
],
|
| 122 |
+
description=(
|
| 123 |
+
"[Post-CFG Function]\n"
|
| 124 |
+
"TSR - Temporal Score Rescaling (2510.01184)\n\n"
|
| 125 |
+
"Rescaling the model's score or noise to steer the sampling diversity.\n"
|
| 126 |
+
),
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
@classmethod
|
| 130 |
+
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
|
| 131 |
+
tsr_variance = tsr_sigma**2
|
| 132 |
+
|
| 133 |
+
def temporal_score_rescaling(args):
|
| 134 |
+
denoised = args["denoised"]
|
| 135 |
+
x = args["input"]
|
| 136 |
+
sigma = args["sigma"]
|
| 137 |
+
curr_model = args["model"]
|
| 138 |
+
|
| 139 |
+
# No rescaling (r = 1) or no noise
|
| 140 |
+
if tsr_k == 1 or sigma == 0:
|
| 141 |
+
return denoised
|
| 142 |
+
|
| 143 |
+
model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
|
| 144 |
+
half_log_snr = sigma_to_half_log_snr(sigma, model_sampling)
|
| 145 |
+
snr = (2 * half_log_snr).exp()
|
| 146 |
+
|
| 147 |
+
# No rescaling needed (r = 1)
|
| 148 |
+
if snr == 0:
|
| 149 |
+
return denoised
|
| 150 |
+
|
| 151 |
+
rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)
|
| 152 |
+
|
| 153 |
+
# Derived from scaled_denoised = (x - r * sigma * noise) / alpha
|
| 154 |
+
alpha = sigma * half_log_snr.exp()
|
| 155 |
+
return torch.lerp(x / alpha, denoised, rescaling_r)
|
| 156 |
+
|
| 157 |
+
m = model.clone()
|
| 158 |
+
m.set_model_sampler_post_cfg_function(temporal_score_rescaling)
|
| 159 |
+
return io.NodeOutput(m)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class EpsilonScalingExtension(ComfyExtension):
|
| 163 |
+
@override
|
| 164 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 165 |
+
return [
|
| 166 |
+
EpsilonScaling,
|
| 167 |
+
TemporalScoreRescaling,
|
| 168 |
+
]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
async def comfy_entrypoint() -> EpsilonScalingExtension:
|
| 172 |
+
return EpsilonScalingExtension()
|
ComfyUI/comfy_extras/nodes_flux.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import node_helpers
|
| 2 |
+
import comfy.utils
|
| 3 |
+
from typing_extensions import override
|
| 4 |
+
from comfy_api.latest import ComfyExtension, io
|
| 5 |
+
import comfy.model_management
|
| 6 |
+
import torch
|
| 7 |
+
import math
|
| 8 |
+
import nodes
|
| 9 |
+
import comfy.ldm.flux.math
|
| 10 |
+
|
| 11 |
+
class CLIPTextEncodeFlux(io.ComfyNode):
|
| 12 |
+
@classmethod
|
| 13 |
+
def define_schema(cls):
|
| 14 |
+
return io.Schema(
|
| 15 |
+
node_id="CLIPTextEncodeFlux",
|
| 16 |
+
category="advanced/conditioning/flux",
|
| 17 |
+
inputs=[
|
| 18 |
+
io.Clip.Input("clip"),
|
| 19 |
+
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
| 20 |
+
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
| 21 |
+
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
| 22 |
+
],
|
| 23 |
+
outputs=[
|
| 24 |
+
io.Conditioning.Output(),
|
| 25 |
+
],
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
@classmethod
|
| 29 |
+
def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
|
| 30 |
+
tokens = clip.tokenize(clip_l)
|
| 31 |
+
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
| 32 |
+
|
| 33 |
+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
|
| 34 |
+
|
| 35 |
+
encode = execute # TODO: remove
|
| 36 |
+
|
| 37 |
+
class EmptyFlux2LatentImage(io.ComfyNode):
|
| 38 |
+
@classmethod
|
| 39 |
+
def define_schema(cls):
|
| 40 |
+
return io.Schema(
|
| 41 |
+
node_id="EmptyFlux2LatentImage",
|
| 42 |
+
display_name="Empty Flux 2 Latent",
|
| 43 |
+
category="latent",
|
| 44 |
+
inputs=[
|
| 45 |
+
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 46 |
+
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 47 |
+
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
| 48 |
+
],
|
| 49 |
+
outputs=[
|
| 50 |
+
io.Latent.Output(),
|
| 51 |
+
],
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
@classmethod
|
| 55 |
+
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
|
| 56 |
+
latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
| 57 |
+
return io.NodeOutput({"samples": latent})
|
| 58 |
+
|
| 59 |
+
class FluxGuidance(io.ComfyNode):
|
| 60 |
+
@classmethod
|
| 61 |
+
def define_schema(cls):
|
| 62 |
+
return io.Schema(
|
| 63 |
+
node_id="FluxGuidance",
|
| 64 |
+
category="advanced/conditioning/flux",
|
| 65 |
+
inputs=[
|
| 66 |
+
io.Conditioning.Input("conditioning"),
|
| 67 |
+
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
| 68 |
+
],
|
| 69 |
+
outputs=[
|
| 70 |
+
io.Conditioning.Output(),
|
| 71 |
+
],
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
@classmethod
|
| 75 |
+
def execute(cls, conditioning, guidance) -> io.NodeOutput:
|
| 76 |
+
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
| 77 |
+
return io.NodeOutput(c)
|
| 78 |
+
|
| 79 |
+
append = execute # TODO: remove
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class FluxDisableGuidance(io.ComfyNode):
|
| 83 |
+
@classmethod
|
| 84 |
+
def define_schema(cls):
|
| 85 |
+
return io.Schema(
|
| 86 |
+
node_id="FluxDisableGuidance",
|
| 87 |
+
category="advanced/conditioning/flux",
|
| 88 |
+
description="This node completely disables the guidance embed on Flux and Flux like models",
|
| 89 |
+
inputs=[
|
| 90 |
+
io.Conditioning.Input("conditioning"),
|
| 91 |
+
],
|
| 92 |
+
outputs=[
|
| 93 |
+
io.Conditioning.Output(),
|
| 94 |
+
],
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
@classmethod
|
| 98 |
+
def execute(cls, conditioning) -> io.NodeOutput:
|
| 99 |
+
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
|
| 100 |
+
return io.NodeOutput(c)
|
| 101 |
+
|
| 102 |
+
append = execute # TODO: remove
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
PREFERED_KONTEXT_RESOLUTIONS = [
|
| 106 |
+
(672, 1568),
|
| 107 |
+
(688, 1504),
|
| 108 |
+
(720, 1456),
|
| 109 |
+
(752, 1392),
|
| 110 |
+
(800, 1328),
|
| 111 |
+
(832, 1248),
|
| 112 |
+
(880, 1184),
|
| 113 |
+
(944, 1104),
|
| 114 |
+
(1024, 1024),
|
| 115 |
+
(1104, 944),
|
| 116 |
+
(1184, 880),
|
| 117 |
+
(1248, 832),
|
| 118 |
+
(1328, 800),
|
| 119 |
+
(1392, 752),
|
| 120 |
+
(1456, 720),
|
| 121 |
+
(1504, 688),
|
| 122 |
+
(1568, 672),
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class FluxKontextImageScale(io.ComfyNode):
|
| 127 |
+
@classmethod
|
| 128 |
+
def define_schema(cls):
|
| 129 |
+
return io.Schema(
|
| 130 |
+
node_id="FluxKontextImageScale",
|
| 131 |
+
category="advanced/conditioning/flux",
|
| 132 |
+
description="This node resizes the image to one that is more optimal for flux kontext.",
|
| 133 |
+
inputs=[
|
| 134 |
+
io.Image.Input("image"),
|
| 135 |
+
],
|
| 136 |
+
outputs=[
|
| 137 |
+
io.Image.Output(),
|
| 138 |
+
],
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
@classmethod
|
| 142 |
+
def execute(cls, image) -> io.NodeOutput:
|
| 143 |
+
width = image.shape[2]
|
| 144 |
+
height = image.shape[1]
|
| 145 |
+
aspect_ratio = width / height
|
| 146 |
+
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
| 147 |
+
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
| 148 |
+
return io.NodeOutput(image)
|
| 149 |
+
|
| 150 |
+
scale = execute # TODO: remove
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
|
| 154 |
+
@classmethod
|
| 155 |
+
def define_schema(cls):
|
| 156 |
+
return io.Schema(
|
| 157 |
+
node_id="FluxKontextMultiReferenceLatentMethod",
|
| 158 |
+
display_name="Edit Model Reference Method",
|
| 159 |
+
category="advanced/conditioning/flux",
|
| 160 |
+
inputs=[
|
| 161 |
+
io.Conditioning.Input("conditioning"),
|
| 162 |
+
io.Combo.Input(
|
| 163 |
+
"reference_latents_method",
|
| 164 |
+
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
|
| 165 |
+
advanced=True,
|
| 166 |
+
),
|
| 167 |
+
],
|
| 168 |
+
outputs=[
|
| 169 |
+
io.Conditioning.Output(),
|
| 170 |
+
],
|
| 171 |
+
is_experimental=True,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
@classmethod
|
| 175 |
+
def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
|
| 176 |
+
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
|
| 177 |
+
reference_latents_method = "uxo"
|
| 178 |
+
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
| 179 |
+
return io.NodeOutput(c)
|
| 180 |
+
|
| 181 |
+
append = execute # TODO: remove
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def generalized_time_snr_shift(t, mu: float, sigma: float):
|
| 185 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
| 189 |
+
a1, b1 = 8.73809524e-05, 1.89833333
|
| 190 |
+
a2, b2 = 0.00016927, 0.45666666
|
| 191 |
+
|
| 192 |
+
if image_seq_len > 4300:
|
| 193 |
+
mu = a2 * image_seq_len + b2
|
| 194 |
+
return float(mu)
|
| 195 |
+
|
| 196 |
+
m_200 = a2 * image_seq_len + b2
|
| 197 |
+
m_10 = a1 * image_seq_len + b1
|
| 198 |
+
|
| 199 |
+
a = (m_200 - m_10) / 190.0
|
| 200 |
+
b = m_200 - 200.0 * a
|
| 201 |
+
mu = a * num_steps + b
|
| 202 |
+
|
| 203 |
+
return float(mu)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
|
| 207 |
+
mu = compute_empirical_mu(image_seq_len, num_steps)
|
| 208 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
| 209 |
+
timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
|
| 210 |
+
return timesteps
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class Flux2Scheduler(io.ComfyNode):
|
| 214 |
+
@classmethod
|
| 215 |
+
def define_schema(cls):
|
| 216 |
+
return io.Schema(
|
| 217 |
+
node_id="Flux2Scheduler",
|
| 218 |
+
category="sampling/custom_sampling/schedulers",
|
| 219 |
+
inputs=[
|
| 220 |
+
io.Int.Input("steps", default=20, min=1, max=4096),
|
| 221 |
+
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
|
| 222 |
+
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
|
| 223 |
+
],
|
| 224 |
+
outputs=[
|
| 225 |
+
io.Sigmas.Output(),
|
| 226 |
+
],
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
@classmethod
|
| 230 |
+
def execute(cls, steps, width, height) -> io.NodeOutput:
|
| 231 |
+
seq_len = (width * height / (16 * 16))
|
| 232 |
+
sigmas = get_schedule(steps, round(seq_len))
|
| 233 |
+
return io.NodeOutput(sigmas)
|
| 234 |
+
|
| 235 |
+
class KV_Attn_Input:
|
| 236 |
+
def __init__(self):
|
| 237 |
+
self.cache = {}
|
| 238 |
+
|
| 239 |
+
def __call__(self, q, k, v, extra_options, **kwargs):
|
| 240 |
+
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
|
| 241 |
+
if len(reference_image_num_tokens) == 0:
|
| 242 |
+
return {}
|
| 243 |
+
|
| 244 |
+
ref_toks = sum(reference_image_num_tokens)
|
| 245 |
+
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
| 246 |
+
if cache_key in self.cache:
|
| 247 |
+
kk, vv = self.cache[cache_key]
|
| 248 |
+
self.set_cache = False
|
| 249 |
+
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
| 250 |
+
|
| 251 |
+
self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
|
| 252 |
+
self.set_cache = True
|
| 253 |
+
return {"q": q, "k": k, "v": v}
|
| 254 |
+
|
| 255 |
+
def cleanup(self):
|
| 256 |
+
self.cache = {}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class FluxKVCache(io.ComfyNode):
|
| 260 |
+
@classmethod
|
| 261 |
+
def define_schema(cls) -> io.Schema:
|
| 262 |
+
return io.Schema(
|
| 263 |
+
node_id="FluxKVCache",
|
| 264 |
+
display_name="Flux KV Cache",
|
| 265 |
+
description="Enables KV Cache optimization for reference images on Flux family models.",
|
| 266 |
+
category="",
|
| 267 |
+
is_experimental=True,
|
| 268 |
+
inputs=[
|
| 269 |
+
io.Model.Input("model", tooltip="The model to use KV Cache on."),
|
| 270 |
+
],
|
| 271 |
+
outputs=[
|
| 272 |
+
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
|
| 273 |
+
],
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
@classmethod
|
| 277 |
+
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
|
| 278 |
+
m = model.clone()
|
| 279 |
+
input_patch_obj = KV_Attn_Input()
|
| 280 |
+
|
| 281 |
+
def model_input_patch(inputs):
|
| 282 |
+
if len(input_patch_obj.cache) > 0:
|
| 283 |
+
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
|
| 284 |
+
if ref_image_tokens > 0:
|
| 285 |
+
img = inputs["img"]
|
| 286 |
+
inputs["img"] = img[:, :-ref_image_tokens]
|
| 287 |
+
return inputs
|
| 288 |
+
|
| 289 |
+
m.set_model_attn1_patch(input_patch_obj)
|
| 290 |
+
m.set_model_post_input_patch(model_input_patch)
|
| 291 |
+
if hasattr(model.model.diffusion_model, "params"):
|
| 292 |
+
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
|
| 293 |
+
else:
|
| 294 |
+
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
|
| 295 |
+
|
| 296 |
+
return io.NodeOutput(m)
|
| 297 |
+
|
| 298 |
+
class FluxExtension(ComfyExtension):
|
| 299 |
+
@override
|
| 300 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 301 |
+
return [
|
| 302 |
+
CLIPTextEncodeFlux,
|
| 303 |
+
FluxGuidance,
|
| 304 |
+
FluxDisableGuidance,
|
| 305 |
+
FluxKontextImageScale,
|
| 306 |
+
FluxKontextMultiReferenceLatentMethod,
|
| 307 |
+
EmptyFlux2LatentImage,
|
| 308 |
+
Flux2Scheduler,
|
| 309 |
+
FluxKVCache,
|
| 310 |
+
]
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
async def comfy_entrypoint() -> FluxExtension:
|
| 314 |
+
return FluxExtension()
|
ComfyUI/comfy_extras/nodes_frame_interpolation.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from typing_extensions import override
|
| 4 |
+
|
| 5 |
+
import comfy.model_patcher
|
| 6 |
+
import comfy.utils
|
| 7 |
+
import folder_paths
|
| 8 |
+
from comfy import model_management
|
| 9 |
+
from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config
|
| 10 |
+
from comfy_extras.frame_interpolation_models.film_net import FILMNet
|
| 11 |
+
from comfy_api.latest import ComfyExtension, io
|
| 12 |
+
|
| 13 |
+
FrameInterpolationModel = io.Custom("INTERP_MODEL")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class FrameInterpolationModelLoader(io.ComfyNode):
|
| 17 |
+
@classmethod
|
| 18 |
+
def define_schema(cls):
|
| 19 |
+
return io.Schema(
|
| 20 |
+
node_id="FrameInterpolationModelLoader",
|
| 21 |
+
display_name="Load Frame Interpolation Model",
|
| 22 |
+
category="loaders",
|
| 23 |
+
inputs=[
|
| 24 |
+
io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"),
|
| 25 |
+
tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."),
|
| 26 |
+
],
|
| 27 |
+
outputs=[
|
| 28 |
+
FrameInterpolationModel.Output(),
|
| 29 |
+
],
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def execute(cls, model_name) -> io.NodeOutput:
|
| 34 |
+
model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name)
|
| 35 |
+
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
| 36 |
+
|
| 37 |
+
model = cls._detect_and_load(sd)
|
| 38 |
+
dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
|
| 39 |
+
model.eval().to(dtype)
|
| 40 |
+
patcher = comfy.model_patcher.ModelPatcher(
|
| 41 |
+
model,
|
| 42 |
+
load_device=model_management.get_torch_device(),
|
| 43 |
+
offload_device=model_management.unet_offload_device(),
|
| 44 |
+
)
|
| 45 |
+
return io.NodeOutput(patcher)
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def _detect_and_load(cls, sd):
|
| 49 |
+
# Try FILM
|
| 50 |
+
if "extract.extract_sublevels.convs.0.0.conv.weight" in sd:
|
| 51 |
+
model = FILMNet()
|
| 52 |
+
model.load_state_dict(sd)
|
| 53 |
+
return model
|
| 54 |
+
|
| 55 |
+
# Try RIFE (needs key remapping for raw checkpoints)
|
| 56 |
+
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""})
|
| 57 |
+
key_map = {}
|
| 58 |
+
for k in sd:
|
| 59 |
+
for i in range(5):
|
| 60 |
+
if k.startswith(f"block{i}."):
|
| 61 |
+
key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}"
|
| 62 |
+
if key_map:
|
| 63 |
+
sd = {key_map.get(k, k): v for k, v in sd.items()}
|
| 64 |
+
sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))}
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
head_ch, channels = detect_rife_config(sd)
|
| 68 |
+
except (KeyError, ValueError):
|
| 69 |
+
raise ValueError("Unrecognized frame interpolation model format")
|
| 70 |
+
model = IFNet(head_ch=head_ch, channels=channels)
|
| 71 |
+
model.load_state_dict(sd)
|
| 72 |
+
return model
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class FrameInterpolate(io.ComfyNode):
|
| 76 |
+
@classmethod
|
| 77 |
+
def define_schema(cls):
|
| 78 |
+
return io.Schema(
|
| 79 |
+
node_id="FrameInterpolate",
|
| 80 |
+
display_name="Frame Interpolate",
|
| 81 |
+
category="image/video",
|
| 82 |
+
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
|
| 83 |
+
inputs=[
|
| 84 |
+
FrameInterpolationModel.Input("interp_model"),
|
| 85 |
+
io.Image.Input("images"),
|
| 86 |
+
io.Int.Input("multiplier", default=2, min=2, max=16),
|
| 87 |
+
],
|
| 88 |
+
outputs=[
|
| 89 |
+
io.Image.Output(),
|
| 90 |
+
],
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
@classmethod
|
| 94 |
+
def execute(cls, interp_model, images, multiplier) -> io.NodeOutput:
|
| 95 |
+
offload_device = model_management.intermediate_device()
|
| 96 |
+
|
| 97 |
+
num_frames = images.shape[0]
|
| 98 |
+
if num_frames < 2 or multiplier < 2:
|
| 99 |
+
return io.NodeOutput(images)
|
| 100 |
+
|
| 101 |
+
model_management.load_model_gpu(interp_model)
|
| 102 |
+
device = interp_model.load_device
|
| 103 |
+
dtype = interp_model.model_dtype()
|
| 104 |
+
inference_model = interp_model.model
|
| 105 |
+
|
| 106 |
+
# Free VRAM for inference activations (model weights + ~20x a single frame's worth)
|
| 107 |
+
H, W = images.shape[1], images.shape[2]
|
| 108 |
+
activation_mem = H * W * 3 * images.element_size() * 20
|
| 109 |
+
model_management.free_memory(activation_mem, device)
|
| 110 |
+
align = getattr(inference_model, "pad_align", 1)
|
| 111 |
+
|
| 112 |
+
# Prepare a single padded frame on device for determining output dimensions
|
| 113 |
+
def prepare_frame(idx):
|
| 114 |
+
frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device)
|
| 115 |
+
if align > 1:
|
| 116 |
+
from comfy.ldm.common_dit import pad_to_patch_size
|
| 117 |
+
frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect")
|
| 118 |
+
return frame
|
| 119 |
+
|
| 120 |
+
# Count total interpolation passes for progress bar
|
| 121 |
+
total_pairs = num_frames - 1
|
| 122 |
+
num_interp = multiplier - 1
|
| 123 |
+
total_steps = total_pairs * num_interp
|
| 124 |
+
pbar = comfy.utils.ProgressBar(total_steps)
|
| 125 |
+
tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation")
|
| 126 |
+
|
| 127 |
+
batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit)
|
| 128 |
+
t_values = [t / multiplier for t in range(1, multiplier)]
|
| 129 |
+
|
| 130 |
+
out_dtype = model_management.intermediate_dtype()
|
| 131 |
+
total_out_frames = total_pairs * multiplier + 1
|
| 132 |
+
result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device)
|
| 133 |
+
result[0] = images[0].movedim(-1, 0).to(out_dtype)
|
| 134 |
+
out_idx = 1
|
| 135 |
+
|
| 136 |
+
# Pre-compute timestep tensor on device (padded dimensions needed)
|
| 137 |
+
sample = prepare_frame(0)
|
| 138 |
+
pH, pW = sample.shape[2], sample.shape[3]
|
| 139 |
+
ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1)
|
| 140 |
+
ts_full = ts_full.expand(-1, 1, pH, pW)
|
| 141 |
+
del sample
|
| 142 |
+
|
| 143 |
+
multi_fn = getattr(inference_model, "forward_multi_timestep", None)
|
| 144 |
+
feat_cache = {}
|
| 145 |
+
prev_frame = None
|
| 146 |
+
|
| 147 |
+
try:
|
| 148 |
+
for i in range(total_pairs):
|
| 149 |
+
img0_single = prev_frame if prev_frame is not None else prepare_frame(i)
|
| 150 |
+
img1_single = prepare_frame(i + 1)
|
| 151 |
+
prev_frame = img1_single
|
| 152 |
+
|
| 153 |
+
# Cache features: img1 of pair N becomes img0 of pair N+1
|
| 154 |
+
feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single)
|
| 155 |
+
feat_cache["img1"] = inference_model.extract_features(img1_single)
|
| 156 |
+
feat_cache["next"] = feat_cache["img1"]
|
| 157 |
+
|
| 158 |
+
used_multi = False
|
| 159 |
+
if multi_fn is not None:
|
| 160 |
+
# Models with timestep-independent flow can compute it once for all timesteps
|
| 161 |
+
try:
|
| 162 |
+
mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache)
|
| 163 |
+
result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype)
|
| 164 |
+
out_idx += num_interp
|
| 165 |
+
pbar.update(num_interp)
|
| 166 |
+
tqdm_bar.update(num_interp)
|
| 167 |
+
used_multi = True
|
| 168 |
+
except model_management.OOM_EXCEPTION:
|
| 169 |
+
model_management.soft_empty_cache()
|
| 170 |
+
multi_fn = None # fall through to single-timestep path
|
| 171 |
+
|
| 172 |
+
if not used_multi:
|
| 173 |
+
j = 0
|
| 174 |
+
while j < num_interp:
|
| 175 |
+
b = min(batch, num_interp - j)
|
| 176 |
+
try:
|
| 177 |
+
img0 = img0_single.expand(b, -1, -1, -1)
|
| 178 |
+
img1 = img1_single.expand(b, -1, -1, -1)
|
| 179 |
+
mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache)
|
| 180 |
+
result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype)
|
| 181 |
+
out_idx += b
|
| 182 |
+
pbar.update(b)
|
| 183 |
+
tqdm_bar.update(b)
|
| 184 |
+
j += b
|
| 185 |
+
except model_management.OOM_EXCEPTION:
|
| 186 |
+
if batch <= 1:
|
| 187 |
+
raise
|
| 188 |
+
batch = max(1, batch // 2)
|
| 189 |
+
model_management.soft_empty_cache()
|
| 190 |
+
|
| 191 |
+
result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype)
|
| 192 |
+
out_idx += 1
|
| 193 |
+
finally:
|
| 194 |
+
tqdm_bar.close()
|
| 195 |
+
|
| 196 |
+
# BCHW -> BHWC
|
| 197 |
+
result = result.movedim(1, -1).clamp_(0.0, 1.0)
|
| 198 |
+
return io.NodeOutput(result)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class FrameInterpolationExtension(ComfyExtension):
|
| 202 |
+
@override
|
| 203 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 204 |
+
return [
|
| 205 |
+
FrameInterpolationModelLoader,
|
| 206 |
+
FrameInterpolate,
|
| 207 |
+
]
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
async def comfy_entrypoint() -> FrameInterpolationExtension:
|
| 211 |
+
return FrameInterpolationExtension()
|
ComfyUI/comfy_extras/nodes_freelunch.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import logging
|
| 5 |
+
from typing_extensions import override
|
| 6 |
+
from comfy_api.latest import ComfyExtension, IO
|
| 7 |
+
|
| 8 |
+
def Fourier_filter(x, threshold, scale):
|
| 9 |
+
# FFT
|
| 10 |
+
x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
|
| 11 |
+
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
|
| 12 |
+
|
| 13 |
+
B, C, H, W = x_freq.shape
|
| 14 |
+
mask = torch.ones((B, C, H, W), device=x.device)
|
| 15 |
+
|
| 16 |
+
crow, ccol = H // 2, W //2
|
| 17 |
+
mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
|
| 18 |
+
x_freq = x_freq * mask
|
| 19 |
+
|
| 20 |
+
# IFFT
|
| 21 |
+
x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
|
| 22 |
+
x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
|
| 23 |
+
|
| 24 |
+
return x_filtered.to(x.dtype)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class FreeU(IO.ComfyNode):
|
| 28 |
+
@classmethod
|
| 29 |
+
def define_schema(cls):
|
| 30 |
+
return IO.Schema(
|
| 31 |
+
node_id="FreeU",
|
| 32 |
+
category="model_patches/unet",
|
| 33 |
+
inputs=[
|
| 34 |
+
IO.Model.Input("model"),
|
| 35 |
+
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 36 |
+
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 37 |
+
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 38 |
+
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 39 |
+
],
|
| 40 |
+
outputs=[
|
| 41 |
+
IO.Model.Output(),
|
| 42 |
+
],
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
@classmethod
|
| 46 |
+
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
|
| 47 |
+
model_channels = model.model.model_config.unet_config["model_channels"]
|
| 48 |
+
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
| 49 |
+
on_cpu_devices = {}
|
| 50 |
+
|
| 51 |
+
def output_block_patch(h, hsp, transformer_options):
|
| 52 |
+
scale = scale_dict.get(int(h.shape[1]), None)
|
| 53 |
+
if scale is not None:
|
| 54 |
+
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
|
| 55 |
+
if hsp.device not in on_cpu_devices:
|
| 56 |
+
try:
|
| 57 |
+
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
| 58 |
+
except:
|
| 59 |
+
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
|
| 60 |
+
on_cpu_devices[hsp.device] = True
|
| 61 |
+
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
| 62 |
+
else:
|
| 63 |
+
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
| 64 |
+
|
| 65 |
+
return h, hsp
|
| 66 |
+
|
| 67 |
+
m = model.clone()
|
| 68 |
+
m.set_model_output_block_patch(output_block_patch)
|
| 69 |
+
return IO.NodeOutput(m)
|
| 70 |
+
|
| 71 |
+
patch = execute # TODO: remove
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class FreeU_V2(IO.ComfyNode):
|
| 75 |
+
@classmethod
|
| 76 |
+
def define_schema(cls):
|
| 77 |
+
return IO.Schema(
|
| 78 |
+
node_id="FreeU_V2",
|
| 79 |
+
category="model_patches/unet",
|
| 80 |
+
inputs=[
|
| 81 |
+
IO.Model.Input("model"),
|
| 82 |
+
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 83 |
+
IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 84 |
+
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 85 |
+
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True),
|
| 86 |
+
],
|
| 87 |
+
outputs=[
|
| 88 |
+
IO.Model.Output(),
|
| 89 |
+
],
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
@classmethod
|
| 93 |
+
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
|
| 94 |
+
model_channels = model.model.model_config.unet_config["model_channels"]
|
| 95 |
+
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
| 96 |
+
on_cpu_devices = {}
|
| 97 |
+
|
| 98 |
+
def output_block_patch(h, hsp, transformer_options):
|
| 99 |
+
scale = scale_dict.get(int(h.shape[1]), None)
|
| 100 |
+
if scale is not None:
|
| 101 |
+
hidden_mean = h.mean(1).unsqueeze(1)
|
| 102 |
+
B = hidden_mean.shape[0]
|
| 103 |
+
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 104 |
+
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 105 |
+
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
| 106 |
+
|
| 107 |
+
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
|
| 108 |
+
|
| 109 |
+
if hsp.device not in on_cpu_devices:
|
| 110 |
+
try:
|
| 111 |
+
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
| 112 |
+
except:
|
| 113 |
+
logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
|
| 114 |
+
on_cpu_devices[hsp.device] = True
|
| 115 |
+
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
| 116 |
+
else:
|
| 117 |
+
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
| 118 |
+
|
| 119 |
+
return h, hsp
|
| 120 |
+
|
| 121 |
+
m = model.clone()
|
| 122 |
+
m.set_model_output_block_patch(output_block_patch)
|
| 123 |
+
return IO.NodeOutput(m)
|
| 124 |
+
|
| 125 |
+
patch = execute # TODO: remove
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class FreelunchExtension(ComfyExtension):
|
| 129 |
+
@override
|
| 130 |
+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
| 131 |
+
return [
|
| 132 |
+
FreeU,
|
| 133 |
+
FreeU_V2,
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
async def comfy_entrypoint() -> FreelunchExtension:
|
| 138 |
+
return FreelunchExtension()
|
ComfyUI/comfy_extras/nodes_fresca.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code based on https://github.com/WikiChao/FreSca (MIT License)
|
| 2 |
+
import torch
|
| 3 |
+
import torch.fft as fft
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
|
| 9 |
+
"""
|
| 10 |
+
Apply frequency-dependent scaling to an image tensor using Fourier transforms.
|
| 11 |
+
|
| 12 |
+
Parameters:
|
| 13 |
+
x: Input tensor of shape (B, C, H, W)
|
| 14 |
+
scale_low: Scaling factor for low-frequency components (default: 1.0)
|
| 15 |
+
scale_high: Scaling factor for high-frequency components (default: 1.5)
|
| 16 |
+
freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
|
| 20 |
+
"""
|
| 21 |
+
# Preserve input dtype and device
|
| 22 |
+
dtype, device = x.dtype, x.device
|
| 23 |
+
|
| 24 |
+
# Convert to float32 for FFT computations
|
| 25 |
+
x = x.to(torch.float32)
|
| 26 |
+
|
| 27 |
+
# 1) Apply FFT and shift low frequencies to center
|
| 28 |
+
x_freq = fft.fftn(x, dim=(-2, -1))
|
| 29 |
+
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
|
| 30 |
+
|
| 31 |
+
# Initialize mask with high-frequency scaling factor
|
| 32 |
+
mask = torch.ones(x_freq.shape, device=device) * scale_high
|
| 33 |
+
m = mask
|
| 34 |
+
for d in range(len(x_freq.shape) - 2):
|
| 35 |
+
dim = d + 2
|
| 36 |
+
cc = x_freq.shape[dim] // 2
|
| 37 |
+
f_c = min(freq_cutoff, cc)
|
| 38 |
+
m = m.narrow(dim, cc - f_c, f_c * 2)
|
| 39 |
+
|
| 40 |
+
# Apply low-frequency scaling factor to center region
|
| 41 |
+
m[:] = scale_low
|
| 42 |
+
|
| 43 |
+
# 3) Apply frequency-specific scaling
|
| 44 |
+
x_freq = x_freq * mask
|
| 45 |
+
|
| 46 |
+
# 4) Convert back to spatial domain
|
| 47 |
+
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
|
| 48 |
+
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
|
| 49 |
+
|
| 50 |
+
# 5) Restore original dtype
|
| 51 |
+
x_filtered = x_filtered.to(dtype)
|
| 52 |
+
|
| 53 |
+
return x_filtered
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class FreSca(io.ComfyNode):
|
| 57 |
+
@classmethod
|
| 58 |
+
def define_schema(cls):
|
| 59 |
+
return io.Schema(
|
| 60 |
+
node_id="FreSca",
|
| 61 |
+
search_aliases=["frequency guidance"],
|
| 62 |
+
display_name="FreSca",
|
| 63 |
+
category="_for_testing",
|
| 64 |
+
description="Applies frequency-dependent scaling to the guidance",
|
| 65 |
+
inputs=[
|
| 66 |
+
io.Model.Input("model"),
|
| 67 |
+
io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
|
| 68 |
+
tooltip="Scaling factor for low-frequency components", advanced=True),
|
| 69 |
+
io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
|
| 70 |
+
tooltip="Scaling factor for high-frequency components", advanced=True),
|
| 71 |
+
io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
|
| 72 |
+
tooltip="Number of frequency indices around center to consider as low-frequency", advanced=True),
|
| 73 |
+
],
|
| 74 |
+
outputs=[
|
| 75 |
+
io.Model.Output(),
|
| 76 |
+
],
|
| 77 |
+
is_experimental=True,
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
@classmethod
|
| 81 |
+
def execute(cls, model, scale_low, scale_high, freq_cutoff):
|
| 82 |
+
def custom_cfg_function(args):
|
| 83 |
+
conds_out = args["conds_out"]
|
| 84 |
+
if len(conds_out) <= 1 or None in args["conds"][:2]:
|
| 85 |
+
return conds_out
|
| 86 |
+
cond = conds_out[0]
|
| 87 |
+
uncond = conds_out[1]
|
| 88 |
+
|
| 89 |
+
guidance = cond - uncond
|
| 90 |
+
filtered_guidance = Fourier_filter(
|
| 91 |
+
guidance,
|
| 92 |
+
scale_low=scale_low,
|
| 93 |
+
scale_high=scale_high,
|
| 94 |
+
freq_cutoff=freq_cutoff,
|
| 95 |
+
)
|
| 96 |
+
filtered_cond = filtered_guidance + uncond
|
| 97 |
+
|
| 98 |
+
return [filtered_cond, uncond] + conds_out[2:]
|
| 99 |
+
|
| 100 |
+
m = model.clone()
|
| 101 |
+
m.set_model_sampler_pre_cfg_function(custom_cfg_function)
|
| 102 |
+
|
| 103 |
+
return io.NodeOutput(m)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class FreScaExtension(ComfyExtension):
|
| 107 |
+
@override
|
| 108 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 109 |
+
return [
|
| 110 |
+
FreSca,
|
| 111 |
+
]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
async def comfy_entrypoint() -> FreScaExtension:
|
| 115 |
+
return FreScaExtension()
|
ComfyUI/comfy_extras/nodes_gits.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
|
| 7 |
+
def loglinear_interp(t_steps, num_steps):
|
| 8 |
+
"""
|
| 9 |
+
Performs log-linear interpolation of a given array of decreasing numbers.
|
| 10 |
+
"""
|
| 11 |
+
xs = np.linspace(0, 1, len(t_steps))
|
| 12 |
+
ys = np.log(t_steps[::-1])
|
| 13 |
+
|
| 14 |
+
new_xs = np.linspace(0, 1, num_steps)
|
| 15 |
+
new_ys = np.interp(new_xs, xs, ys)
|
| 16 |
+
|
| 17 |
+
interped_ys = np.exp(new_ys)[::-1].copy()
|
| 18 |
+
return interped_ys
|
| 19 |
+
|
| 20 |
+
NOISE_LEVELS = {
|
| 21 |
+
0.80: [
|
| 22 |
+
[14.61464119, 7.49001646, 0.02916753],
|
| 23 |
+
[14.61464119, 11.54541874, 6.77309084, 0.02916753],
|
| 24 |
+
[14.61464119, 11.54541874, 7.49001646, 3.07277966, 0.02916753],
|
| 25 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 2.05039096, 0.02916753],
|
| 26 |
+
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 2.05039096, 0.02916753],
|
| 27 |
+
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
| 28 |
+
[14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
| 29 |
+
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
| 30 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
| 31 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
|
| 32 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
|
| 33 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
|
| 34 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
|
| 35 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
|
| 36 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753],
|
| 37 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753],
|
| 38 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753],
|
| 39 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753],
|
| 40 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.75677586, 2.84484982, 1.78698075, 0.803307, 0.02916753],
|
| 41 |
+
],
|
| 42 |
+
0.85: [
|
| 43 |
+
[14.61464119, 7.49001646, 0.02916753],
|
| 44 |
+
[14.61464119, 7.49001646, 1.84880662, 0.02916753],
|
| 45 |
+
[14.61464119, 11.54541874, 6.77309084, 1.56271636, 0.02916753],
|
| 46 |
+
[14.61464119, 11.54541874, 7.11996698, 3.07277966, 1.24153244, 0.02916753],
|
| 47 |
+
[14.61464119, 11.54541874, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753],
|
| 48 |
+
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753],
|
| 49 |
+
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753],
|
| 50 |
+
[14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753],
|
| 51 |
+
[14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
| 52 |
+
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
| 53 |
+
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
| 54 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
| 55 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
|
| 56 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.60512662, 2.6383388, 1.56271636, 0.72133851, 0.02916753],
|
| 57 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
| 58 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
| 59 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
| 60 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
| 61 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
|
| 62 |
+
],
|
| 63 |
+
0.90: [
|
| 64 |
+
[14.61464119, 6.77309084, 0.02916753],
|
| 65 |
+
[14.61464119, 7.49001646, 1.56271636, 0.02916753],
|
| 66 |
+
[14.61464119, 7.49001646, 3.07277966, 0.95350921, 0.02916753],
|
| 67 |
+
[14.61464119, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753],
|
| 68 |
+
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753],
|
| 69 |
+
[14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.07277966, 1.61558151, 0.69515091, 0.02916753],
|
| 70 |
+
[14.61464119, 12.2308979, 8.75849152, 7.11996698, 4.86714602, 3.07277966, 1.61558151, 0.69515091, 0.02916753],
|
| 71 |
+
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 2.95596409, 1.61558151, 0.69515091, 0.02916753],
|
| 72 |
+
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
|
| 73 |
+
[14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
|
| 74 |
+
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
|
| 75 |
+
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753],
|
| 76 |
+
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753],
|
| 77 |
+
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753],
|
| 78 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753],
|
| 79 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
|
| 80 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
|
| 81 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
|
| 82 |
+
[14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.19988537, 1.51179266, 0.89115214, 0.43325692, 0.02916753],
|
| 83 |
+
],
|
| 84 |
+
0.95: [
|
| 85 |
+
[14.61464119, 6.77309084, 0.02916753],
|
| 86 |
+
[14.61464119, 6.77309084, 1.56271636, 0.02916753],
|
| 87 |
+
[14.61464119, 7.49001646, 2.84484982, 0.89115214, 0.02916753],
|
| 88 |
+
[14.61464119, 7.49001646, 4.86714602, 2.36326075, 0.803307, 0.02916753],
|
| 89 |
+
[14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753],
|
| 90 |
+
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753],
|
| 91 |
+
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
|
| 92 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
|
| 93 |
+
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
|
| 94 |
+
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.41535246, 0.803307, 0.38853383, 0.02916753],
|
| 95 |
+
[14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
|
| 96 |
+
[14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
|
| 97 |
+
[14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
|
| 98 |
+
[14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
|
| 99 |
+
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
|
| 100 |
+
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.78698075, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
|
| 101 |
+
[14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
|
| 102 |
+
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
|
| 103 |
+
[14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.75677586, 3.07277966, 2.45070267, 1.91321158, 1.46270394, 1.05362725, 0.72133851, 0.43325692, 0.19894916, 0.02916753],
|
| 104 |
+
],
|
| 105 |
+
1.00: [
|
| 106 |
+
[14.61464119, 1.56271636, 0.02916753],
|
| 107 |
+
[14.61464119, 6.77309084, 0.95350921, 0.02916753],
|
| 108 |
+
[14.61464119, 6.77309084, 2.36326075, 0.803307, 0.02916753],
|
| 109 |
+
[14.61464119, 7.11996698, 3.07277966, 1.56271636, 0.59516323, 0.02916753],
|
| 110 |
+
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.41535246, 0.57119018, 0.02916753],
|
| 111 |
+
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753],
|
| 112 |
+
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753],
|
| 113 |
+
[14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
|
| 114 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
|
| 115 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.54755926, 0.25053367, 0.02916753],
|
| 116 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
|
| 117 |
+
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
|
| 118 |
+
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.12350607, 1.56271636, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
|
| 119 |
+
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.803307, 0.50118381, 0.27464288, 0.09824532, 0.02916753],
|
| 120 |
+
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.84880662, 1.36964464, 1.01931262, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
| 121 |
+
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.46139455, 2.84484982, 2.19988537, 1.67050016, 1.24153244, 0.92192322, 0.64427125, 0.43325692, 0.25053367, 0.09824532, 0.02916753],
|
| 122 |
+
[14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
| 123 |
+
[14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
| 124 |
+
[14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
| 125 |
+
],
|
| 126 |
+
1.05: [
|
| 127 |
+
[14.61464119, 0.95350921, 0.02916753],
|
| 128 |
+
[14.61464119, 6.77309084, 0.89115214, 0.02916753],
|
| 129 |
+
[14.61464119, 6.77309084, 2.05039096, 0.72133851, 0.02916753],
|
| 130 |
+
[14.61464119, 6.77309084, 2.84484982, 1.28281462, 0.52423614, 0.02916753],
|
| 131 |
+
[14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.803307, 0.34370604, 0.02916753],
|
| 132 |
+
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.56271636, 0.803307, 0.34370604, 0.02916753],
|
| 133 |
+
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.52423614, 0.22545385, 0.02916753],
|
| 134 |
+
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.74807048, 0.41087446, 0.17026083, 0.02916753],
|
| 135 |
+
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.59516323, 0.34370604, 0.13792117, 0.02916753],
|
| 136 |
+
[14.61464119, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
| 137 |
+
[14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
| 138 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
| 139 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.72759056, 1.24153244, 0.86115354, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
| 140 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
|
| 141 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.67050016, 1.28281462, 0.95350921, 0.72133851, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
| 142 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.36326075, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.61951244, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
| 143 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
| 144 |
+
[14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
| 145 |
+
[14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.98035145, 1.61558151, 1.32549286, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
| 146 |
+
],
|
| 147 |
+
1.10: [
|
| 148 |
+
[14.61464119, 0.89115214, 0.02916753],
|
| 149 |
+
[14.61464119, 2.36326075, 0.72133851, 0.02916753],
|
| 150 |
+
[14.61464119, 5.85520077, 1.61558151, 0.57119018, 0.02916753],
|
| 151 |
+
[14.61464119, 6.77309084, 2.45070267, 1.08895338, 0.45573691, 0.02916753],
|
| 152 |
+
[14.61464119, 6.77309084, 2.95596409, 1.56271636, 0.803307, 0.34370604, 0.02916753],
|
| 153 |
+
[14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.89115214, 0.4783645, 0.19894916, 0.02916753],
|
| 154 |
+
[14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.08895338, 0.64427125, 0.34370604, 0.13792117, 0.02916753],
|
| 155 |
+
[14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.54755926, 0.27464288, 0.09824532, 0.02916753],
|
| 156 |
+
[14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.4783645, 0.25053367, 0.09824532, 0.02916753],
|
| 157 |
+
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.41535246, 0.95350921, 0.64427125, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
|
| 158 |
+
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.61558151, 1.12534678, 0.803307, 0.54755926, 0.36617002, 0.22545385, 0.09824532, 0.02916753],
|
| 159 |
+
[14.61464119, 7.49001646, 4.86714602, 3.32507086, 2.45070267, 1.72759056, 1.24153244, 0.89115214, 0.64427125, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
| 160 |
+
[14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.05039096, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
| 161 |
+
[14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.12350607, 1.61558151, 1.24153244, 0.95350921, 0.72133851, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
| 162 |
+
[14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 163 |
+
[14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 164 |
+
[14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 165 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 166 |
+
[14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
| 167 |
+
],
|
| 168 |
+
1.15: [
|
| 169 |
+
[14.61464119, 0.83188516, 0.02916753],
|
| 170 |
+
[14.61464119, 1.84880662, 0.59516323, 0.02916753],
|
| 171 |
+
[14.61464119, 5.85520077, 1.56271636, 0.52423614, 0.02916753],
|
| 172 |
+
[14.61464119, 5.85520077, 1.91321158, 0.83188516, 0.34370604, 0.02916753],
|
| 173 |
+
[14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.59516323, 0.25053367, 0.02916753],
|
| 174 |
+
[14.61464119, 5.85520077, 2.84484982, 1.51179266, 0.803307, 0.41087446, 0.17026083, 0.02916753],
|
| 175 |
+
[14.61464119, 5.85520077, 2.84484982, 1.56271636, 0.89115214, 0.50118381, 0.25053367, 0.09824532, 0.02916753],
|
| 176 |
+
[14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.12534678, 0.72133851, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
|
| 177 |
+
[14.61464119, 6.77309084, 3.07277966, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
| 178 |
+
[14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
| 179 |
+
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.36964464, 0.95350921, 0.69515091, 0.4783645, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
| 180 |
+
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
|
| 181 |
+
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 182 |
+
[14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 183 |
+
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
| 184 |
+
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 185 |
+
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 186 |
+
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 187 |
+
[14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 188 |
+
],
|
| 189 |
+
1.20: [
|
| 190 |
+
[14.61464119, 0.803307, 0.02916753],
|
| 191 |
+
[14.61464119, 1.56271636, 0.52423614, 0.02916753],
|
| 192 |
+
[14.61464119, 2.36326075, 0.92192322, 0.36617002, 0.02916753],
|
| 193 |
+
[14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.25053367, 0.02916753],
|
| 194 |
+
[14.61464119, 5.85520077, 2.05039096, 0.95350921, 0.45573691, 0.17026083, 0.02916753],
|
| 195 |
+
[14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.64427125, 0.29807833, 0.09824532, 0.02916753],
|
| 196 |
+
[14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.803307, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
|
| 197 |
+
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.95350921, 0.59516323, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
|
| 198 |
+
[14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
| 199 |
+
[14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.83188516, 0.59516323, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
| 200 |
+
[14.61464119, 5.85520077, 3.07277966, 1.98035145, 1.36964464, 0.95350921, 0.69515091, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 201 |
+
[14.61464119, 6.77309084, 3.46139455, 2.36326075, 1.56271636, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 202 |
+
[14.61464119, 6.77309084, 3.46139455, 2.45070267, 1.61558151, 1.162866, 0.86115354, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
| 203 |
+
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
| 204 |
+
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 205 |
+
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 206 |
+
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.20157266, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 207 |
+
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 208 |
+
[14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 209 |
+
],
|
| 210 |
+
1.25: [
|
| 211 |
+
[14.61464119, 0.72133851, 0.02916753],
|
| 212 |
+
[14.61464119, 1.56271636, 0.50118381, 0.02916753],
|
| 213 |
+
[14.61464119, 2.05039096, 0.803307, 0.32104823, 0.02916753],
|
| 214 |
+
[14.61464119, 2.36326075, 0.95350921, 0.43325692, 0.17026083, 0.02916753],
|
| 215 |
+
[14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.27464288, 0.09824532, 0.02916753],
|
| 216 |
+
[14.61464119, 3.07277966, 1.51179266, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
|
| 217 |
+
[14.61464119, 5.85520077, 2.36326075, 1.24153244, 0.72133851, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
|
| 218 |
+
[14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
| 219 |
+
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.98595673, 0.64427125, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
| 220 |
+
[14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.52423614, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 221 |
+
[14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 222 |
+
[14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.86115354, 0.64427125, 0.4783645, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 223 |
+
[14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.28281462, 0.92192322, 0.69515091, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 224 |
+
[14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 225 |
+
[14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 226 |
+
[14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 227 |
+
[14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 228 |
+
[14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 229 |
+
[14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.46270394, 1.08895338, 0.83188516, 0.66947293, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 230 |
+
],
|
| 231 |
+
1.30: [
|
| 232 |
+
[14.61464119, 0.72133851, 0.02916753],
|
| 233 |
+
[14.61464119, 1.24153244, 0.43325692, 0.02916753],
|
| 234 |
+
[14.61464119, 1.56271636, 0.59516323, 0.22545385, 0.02916753],
|
| 235 |
+
[14.61464119, 1.84880662, 0.803307, 0.36617002, 0.13792117, 0.02916753],
|
| 236 |
+
[14.61464119, 2.36326075, 1.01931262, 0.52423614, 0.25053367, 0.09824532, 0.02916753],
|
| 237 |
+
[14.61464119, 2.84484982, 1.36964464, 0.74807048, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
|
| 238 |
+
[14.61464119, 3.07277966, 1.56271636, 0.89115214, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
| 239 |
+
[14.61464119, 3.07277966, 1.61558151, 0.95350921, 0.61951244, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
| 240 |
+
[14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 241 |
+
[14.61464119, 5.85520077, 2.45070267, 1.41535246, 0.92192322, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 242 |
+
[14.61464119, 5.85520077, 2.6383388, 1.56271636, 1.01931262, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 243 |
+
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 244 |
+
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.77538133, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 245 |
+
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 246 |
+
[14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 247 |
+
[14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 248 |
+
[14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 249 |
+
[14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 250 |
+
[14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 251 |
+
],
|
| 252 |
+
1.35: [
|
| 253 |
+
[14.61464119, 0.69515091, 0.02916753],
|
| 254 |
+
[14.61464119, 0.95350921, 0.34370604, 0.02916753],
|
| 255 |
+
[14.61464119, 1.56271636, 0.57119018, 0.19894916, 0.02916753],
|
| 256 |
+
[14.61464119, 1.61558151, 0.69515091, 0.29807833, 0.09824532, 0.02916753],
|
| 257 |
+
[14.61464119, 1.84880662, 0.83188516, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
|
| 258 |
+
[14.61464119, 2.45070267, 1.162866, 0.64427125, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
|
| 259 |
+
[14.61464119, 2.84484982, 1.36964464, 0.803307, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
|
| 260 |
+
[14.61464119, 2.84484982, 1.41535246, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 261 |
+
[14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.32104823, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
|
| 262 |
+
[14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 263 |
+
[14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 264 |
+
[14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 265 |
+
[14.61464119, 3.07277966, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 266 |
+
[14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 267 |
+
[14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 268 |
+
[14.61464119, 5.85520077, 2.45070267, 1.51179266, 1.01931262, 0.74807048, 0.57119018, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 269 |
+
[14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 270 |
+
[14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 271 |
+
[14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 272 |
+
],
|
| 273 |
+
1.40: [
|
| 274 |
+
[14.61464119, 0.59516323, 0.02916753],
|
| 275 |
+
[14.61464119, 0.95350921, 0.34370604, 0.02916753],
|
| 276 |
+
[14.61464119, 1.08895338, 0.43325692, 0.13792117, 0.02916753],
|
| 277 |
+
[14.61464119, 1.56271636, 0.64427125, 0.27464288, 0.09824532, 0.02916753],
|
| 278 |
+
[14.61464119, 1.61558151, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
|
| 279 |
+
[14.61464119, 2.05039096, 0.95350921, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
|
| 280 |
+
[14.61464119, 2.45070267, 1.24153244, 0.72133851, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
| 281 |
+
[14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 282 |
+
[14.61464119, 2.45070267, 1.28281462, 0.803307, 0.52423614, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 283 |
+
[14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 284 |
+
[14.61464119, 2.84484982, 1.41535246, 0.86115354, 0.59516323, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 285 |
+
[14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 286 |
+
[14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.4783645, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 287 |
+
[14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 288 |
+
[14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.72133851, 0.54755926, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 289 |
+
[14.61464119, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 290 |
+
[14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 291 |
+
[14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.43325692, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 292 |
+
[14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 293 |
+
],
|
| 294 |
+
1.45: [
|
| 295 |
+
[14.61464119, 0.59516323, 0.02916753],
|
| 296 |
+
[14.61464119, 0.803307, 0.25053367, 0.02916753],
|
| 297 |
+
[14.61464119, 0.95350921, 0.34370604, 0.09824532, 0.02916753],
|
| 298 |
+
[14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753],
|
| 299 |
+
[14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
|
| 300 |
+
[14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
| 301 |
+
[14.61464119, 1.91321158, 0.95350921, 0.57119018, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 302 |
+
[14.61464119, 2.19988537, 1.08895338, 0.64427125, 0.41087446, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 303 |
+
[14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 304 |
+
[14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 305 |
+
[14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 306 |
+
[14.61464119, 2.45070267, 1.28281462, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 307 |
+
[14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 308 |
+
[14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 309 |
+
[14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 310 |
+
[14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 311 |
+
[14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.72133851, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 312 |
+
[14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 313 |
+
[14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.59516323, 0.50118381, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 314 |
+
],
|
| 315 |
+
1.50: [
|
| 316 |
+
[14.61464119, 0.54755926, 0.02916753],
|
| 317 |
+
[14.61464119, 0.803307, 0.25053367, 0.02916753],
|
| 318 |
+
[14.61464119, 0.86115354, 0.32104823, 0.09824532, 0.02916753],
|
| 319 |
+
[14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753],
|
| 320 |
+
[14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
|
| 321 |
+
[14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
|
| 322 |
+
[14.61464119, 1.61558151, 0.83188516, 0.52423614, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
|
| 323 |
+
[14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.38853383, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
|
| 324 |
+
[14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.41087446, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 325 |
+
[14.61464119, 1.84880662, 0.95350921, 0.61951244, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 326 |
+
[14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 327 |
+
[14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 328 |
+
[14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 329 |
+
[14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 330 |
+
[14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 331 |
+
[14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 332 |
+
[14.61464119, 2.45070267, 1.32549286, 0.86115354, 0.64427125, 0.50118381, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 333 |
+
[14.61464119, 2.45070267, 1.36964464, 0.92192322, 0.69515091, 0.54755926, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 334 |
+
[14.61464119, 2.45070267, 1.41535246, 0.95350921, 0.72133851, 0.57119018, 0.4783645, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
|
| 335 |
+
],
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
class GITSScheduler(io.ComfyNode):
|
| 339 |
+
@classmethod
|
| 340 |
+
def define_schema(cls):
|
| 341 |
+
return io.Schema(
|
| 342 |
+
node_id="GITSScheduler",
|
| 343 |
+
category="sampling/custom_sampling/schedulers",
|
| 344 |
+
inputs=[
|
| 345 |
+
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05, advanced=True),
|
| 346 |
+
io.Int.Input("steps", default=10, min=2, max=1000),
|
| 347 |
+
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
| 348 |
+
],
|
| 349 |
+
outputs=[
|
| 350 |
+
io.Sigmas.Output(),
|
| 351 |
+
],
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
@classmethod
|
| 355 |
+
def execute(cls, coeff, steps, denoise):
|
| 356 |
+
total_steps = steps
|
| 357 |
+
if denoise < 1.0:
|
| 358 |
+
if denoise <= 0.0:
|
| 359 |
+
return io.NodeOutput(torch.FloatTensor([]))
|
| 360 |
+
total_steps = round(steps * denoise)
|
| 361 |
+
|
| 362 |
+
if steps <= 20:
|
| 363 |
+
sigmas = NOISE_LEVELS[round(coeff, 2)][steps-2][:]
|
| 364 |
+
else:
|
| 365 |
+
sigmas = NOISE_LEVELS[round(coeff, 2)][-1][:]
|
| 366 |
+
sigmas = loglinear_interp(sigmas, steps + 1)
|
| 367 |
+
|
| 368 |
+
sigmas = sigmas[-(total_steps + 1):]
|
| 369 |
+
sigmas[-1] = 0
|
| 370 |
+
return io.NodeOutput(torch.FloatTensor(sigmas))
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
class GITSSchedulerExtension(ComfyExtension):
|
| 374 |
+
@override
|
| 375 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 376 |
+
return [
|
| 377 |
+
GITSScheduler,
|
| 378 |
+
]
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
async def comfy_entrypoint() -> GITSSchedulerExtension:
|
| 382 |
+
return GITSSchedulerExtension()
|
ComfyUI/comfy_extras/nodes_glsl.py
ADDED
|
@@ -0,0 +1,958 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import re
|
| 4 |
+
import logging
|
| 5 |
+
import ctypes.util
|
| 6 |
+
import importlib.util
|
| 7 |
+
from typing import TypedDict
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import nodes
|
| 13 |
+
from comfy_api.latest import ComfyExtension, io, ui
|
| 14 |
+
from typing_extensions import override
|
| 15 |
+
from utils.install_util import get_missing_requirements_message
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _check_opengl_availability():
|
| 21 |
+
"""Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
|
| 22 |
+
logger.debug("_check_opengl_availability: starting")
|
| 23 |
+
missing = []
|
| 24 |
+
|
| 25 |
+
# Check Python packages (using find_spec to avoid importing)
|
| 26 |
+
logger.debug("_check_opengl_availability: checking for glfw package")
|
| 27 |
+
if importlib.util.find_spec("glfw") is None:
|
| 28 |
+
missing.append("glfw")
|
| 29 |
+
|
| 30 |
+
logger.debug("_check_opengl_availability: checking for OpenGL package")
|
| 31 |
+
if importlib.util.find_spec("OpenGL") is None:
|
| 32 |
+
missing.append("PyOpenGL")
|
| 33 |
+
|
| 34 |
+
if missing:
|
| 35 |
+
raise RuntimeError(
|
| 36 |
+
f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# On Linux without display, check if headless backends are available
|
| 40 |
+
logger.debug(f"_check_opengl_availability: platform={sys.platform}")
|
| 41 |
+
if sys.platform.startswith("linux"):
|
| 42 |
+
has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
|
| 43 |
+
logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
|
| 44 |
+
if not has_display:
|
| 45 |
+
# Check for EGL or OSMesa libraries
|
| 46 |
+
logger.debug("_check_opengl_availability: checking for EGL library")
|
| 47 |
+
has_egl = ctypes.util.find_library("EGL")
|
| 48 |
+
logger.debug("_check_opengl_availability: checking for OSMesa library")
|
| 49 |
+
has_osmesa = ctypes.util.find_library("OSMesa")
|
| 50 |
+
|
| 51 |
+
# Error disabled for CI as it fails this check
|
| 52 |
+
# if not has_egl and not has_osmesa:
|
| 53 |
+
# raise RuntimeError(
|
| 54 |
+
# "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
|
| 55 |
+
# "See error below for installation instructions."
|
| 56 |
+
# )
|
| 57 |
+
logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
|
| 58 |
+
|
| 59 |
+
logger.debug("_check_opengl_availability: completed")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Run early check at import time
|
| 63 |
+
logger.debug("nodes_glsl: running _check_opengl_availability at import time")
|
| 64 |
+
_check_opengl_availability()
|
| 65 |
+
|
| 66 |
+
# OpenGL modules - initialized lazily when context is created
|
| 67 |
+
gl = None
|
| 68 |
+
glfw = None
|
| 69 |
+
EGL = None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _import_opengl():
|
| 73 |
+
"""Import OpenGL module. Called after context is created."""
|
| 74 |
+
global gl
|
| 75 |
+
if gl is None:
|
| 76 |
+
logger.debug("_import_opengl: importing OpenGL.GL")
|
| 77 |
+
import OpenGL.GL as _gl
|
| 78 |
+
gl = _gl
|
| 79 |
+
logger.debug("_import_opengl: import completed")
|
| 80 |
+
return gl
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class SizeModeInput(TypedDict):
|
| 84 |
+
size_mode: str
|
| 85 |
+
width: int
|
| 86 |
+
height: int
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
MAX_IMAGES = 5 # u_image0-4
|
| 90 |
+
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
|
| 91 |
+
MAX_BOOLS = 10 # u_bool0-9
|
| 92 |
+
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
|
| 93 |
+
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
| 94 |
+
|
| 95 |
+
# Vertex shader using gl_VertexID trick - no VBO needed.
|
| 96 |
+
# Draws a single triangle that covers the entire screen:
|
| 97 |
+
#
|
| 98 |
+
# (-1,3)
|
| 99 |
+
# /|
|
| 100 |
+
# / | <- visible area is the unit square from (-1,-1) to (1,1)
|
| 101 |
+
# / | parts outside get clipped away
|
| 102 |
+
# (-1,-1)---(3,-1)
|
| 103 |
+
#
|
| 104 |
+
# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
|
| 105 |
+
VERTEX_SHADER = """#version 330 core
|
| 106 |
+
out vec2 v_texCoord;
|
| 107 |
+
void main() {
|
| 108 |
+
vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
|
| 109 |
+
v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
|
| 110 |
+
gl_Position = vec4(verts[gl_VertexID], 0, 1);
|
| 111 |
+
}
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
DEFAULT_FRAGMENT_SHADER = """#version 300 es
|
| 115 |
+
precision highp float;
|
| 116 |
+
|
| 117 |
+
uniform sampler2D u_image0;
|
| 118 |
+
uniform vec2 u_resolution;
|
| 119 |
+
|
| 120 |
+
in vec2 v_texCoord;
|
| 121 |
+
layout(location = 0) out vec4 fragColor0;
|
| 122 |
+
|
| 123 |
+
void main() {
|
| 124 |
+
fragColor0 = texture(u_image0, v_texCoord);
|
| 125 |
+
}
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _convert_es_to_desktop(source: str) -> str:
|
| 130 |
+
"""Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
|
| 131 |
+
# Remove any existing #version directive
|
| 132 |
+
source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
|
| 133 |
+
# Remove precision qualifiers (not needed in desktop GLSL)
|
| 134 |
+
source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
|
| 135 |
+
# Prepend desktop GLSL version
|
| 136 |
+
return "#version 330 core\n" + source
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _detect_output_count(source: str) -> int:
|
| 140 |
+
"""Detect how many fragColor outputs are used in the shader.
|
| 141 |
+
|
| 142 |
+
Returns the count of outputs needed (1 to MAX_OUTPUTS).
|
| 143 |
+
"""
|
| 144 |
+
matches = re.findall(r"fragColor(\d+)", source)
|
| 145 |
+
if not matches:
|
| 146 |
+
return 1 # Default to 1 output if none found
|
| 147 |
+
max_index = max(int(m) for m in matches)
|
| 148 |
+
return min(max_index + 1, MAX_OUTPUTS)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _detect_pass_count(source: str) -> int:
|
| 152 |
+
"""Detect multi-pass rendering from #pragma passes N directive.
|
| 153 |
+
|
| 154 |
+
Returns the number of passes (1 if not specified).
|
| 155 |
+
"""
|
| 156 |
+
match = re.search(r'#pragma\s+passes\s+(\d+)', source)
|
| 157 |
+
if match:
|
| 158 |
+
return max(1, int(match.group(1)))
|
| 159 |
+
return 1
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _init_glfw():
|
| 163 |
+
"""Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
|
| 164 |
+
logger.debug("_init_glfw: starting")
|
| 165 |
+
# On macOS, glfw.init() must be called from main thread or it hangs forever
|
| 166 |
+
if sys.platform == "darwin":
|
| 167 |
+
logger.debug("_init_glfw: skipping on macOS")
|
| 168 |
+
raise RuntimeError("GLFW backend not supported on macOS")
|
| 169 |
+
|
| 170 |
+
logger.debug("_init_glfw: importing glfw module")
|
| 171 |
+
import glfw as _glfw
|
| 172 |
+
|
| 173 |
+
logger.debug("_init_glfw: calling glfw.init()")
|
| 174 |
+
if not _glfw.init():
|
| 175 |
+
raise RuntimeError("glfw.init() failed")
|
| 176 |
+
|
| 177 |
+
try:
|
| 178 |
+
logger.debug("_init_glfw: setting window hints")
|
| 179 |
+
_glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
|
| 180 |
+
_glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
|
| 181 |
+
_glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
|
| 182 |
+
_glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
|
| 183 |
+
|
| 184 |
+
logger.debug("_init_glfw: calling create_window()")
|
| 185 |
+
window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
|
| 186 |
+
if not window:
|
| 187 |
+
raise RuntimeError("glfw.create_window() failed")
|
| 188 |
+
|
| 189 |
+
logger.debug("_init_glfw: calling make_context_current()")
|
| 190 |
+
_glfw.make_context_current(window)
|
| 191 |
+
logger.debug("_init_glfw: completed successfully")
|
| 192 |
+
return window, _glfw
|
| 193 |
+
except Exception:
|
| 194 |
+
logger.debug("_init_glfw: failed, terminating glfw")
|
| 195 |
+
_glfw.terminate()
|
| 196 |
+
raise
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _init_egl():
|
| 200 |
+
"""Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
|
| 201 |
+
logger.debug("_init_egl: starting")
|
| 202 |
+
from OpenGL import EGL as _EGL
|
| 203 |
+
from OpenGL.EGL import (
|
| 204 |
+
eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
|
| 205 |
+
eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
|
| 206 |
+
eglTerminate, eglDestroyContext, eglDestroySurface,
|
| 207 |
+
EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
|
| 208 |
+
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
| 209 |
+
EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
|
| 210 |
+
EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
|
| 211 |
+
)
|
| 212 |
+
logger.debug("_init_egl: imports completed")
|
| 213 |
+
|
| 214 |
+
display = None
|
| 215 |
+
context = None
|
| 216 |
+
surface = None
|
| 217 |
+
|
| 218 |
+
try:
|
| 219 |
+
logger.debug("_init_egl: calling eglGetDisplay()")
|
| 220 |
+
display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
|
| 221 |
+
if display == _EGL.EGL_NO_DISPLAY:
|
| 222 |
+
raise RuntimeError("eglGetDisplay() failed")
|
| 223 |
+
|
| 224 |
+
logger.debug("_init_egl: calling eglInitialize()")
|
| 225 |
+
major, minor = _EGL.EGLint(), _EGL.EGLint()
|
| 226 |
+
if not eglInitialize(display, major, minor):
|
| 227 |
+
display = None # Not initialized, don't terminate
|
| 228 |
+
raise RuntimeError("eglInitialize() failed")
|
| 229 |
+
logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
|
| 230 |
+
|
| 231 |
+
config_attribs = [
|
| 232 |
+
EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
|
| 233 |
+
EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
|
| 234 |
+
EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
|
| 235 |
+
EGL_DEPTH_SIZE, 0, EGL_NONE
|
| 236 |
+
]
|
| 237 |
+
configs = (_EGL.EGLConfig * 1)()
|
| 238 |
+
num_configs = _EGL.EGLint()
|
| 239 |
+
if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
|
| 240 |
+
raise RuntimeError("eglChooseConfig() failed")
|
| 241 |
+
config = configs[0]
|
| 242 |
+
logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
|
| 243 |
+
|
| 244 |
+
if not eglBindAPI(EGL_OPENGL_API):
|
| 245 |
+
raise RuntimeError("eglBindAPI() failed")
|
| 246 |
+
|
| 247 |
+
logger.debug("_init_egl: calling eglCreateContext()")
|
| 248 |
+
context_attribs = [
|
| 249 |
+
_EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
|
| 250 |
+
_EGL.EGL_CONTEXT_MINOR_VERSION, 3,
|
| 251 |
+
_EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
|
| 252 |
+
EGL_NONE
|
| 253 |
+
]
|
| 254 |
+
context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
|
| 255 |
+
if context == EGL_NO_CONTEXT:
|
| 256 |
+
raise RuntimeError("eglCreateContext() failed")
|
| 257 |
+
|
| 258 |
+
logger.debug("_init_egl: calling eglCreatePbufferSurface()")
|
| 259 |
+
pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
|
| 260 |
+
surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
|
| 261 |
+
if surface == _EGL.EGL_NO_SURFACE:
|
| 262 |
+
raise RuntimeError("eglCreatePbufferSurface() failed")
|
| 263 |
+
|
| 264 |
+
logger.debug("_init_egl: calling eglMakeCurrent()")
|
| 265 |
+
if not eglMakeCurrent(display, surface, surface, context):
|
| 266 |
+
raise RuntimeError("eglMakeCurrent() failed")
|
| 267 |
+
|
| 268 |
+
logger.debug("_init_egl: completed successfully")
|
| 269 |
+
return display, context, surface, _EGL
|
| 270 |
+
|
| 271 |
+
except Exception:
|
| 272 |
+
logger.debug("_init_egl: failed, cleaning up")
|
| 273 |
+
# Clean up any resources on failure
|
| 274 |
+
if surface is not None:
|
| 275 |
+
eglDestroySurface(display, surface)
|
| 276 |
+
if context is not None:
|
| 277 |
+
eglDestroyContext(display, context)
|
| 278 |
+
if display is not None:
|
| 279 |
+
eglTerminate(display)
|
| 280 |
+
raise
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def _init_osmesa():
|
| 284 |
+
"""Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
|
| 285 |
+
import ctypes
|
| 286 |
+
|
| 287 |
+
logger.debug("_init_osmesa: starting")
|
| 288 |
+
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
| 289 |
+
|
| 290 |
+
logger.debug("_init_osmesa: importing OpenGL.osmesa")
|
| 291 |
+
from OpenGL import GL as _gl
|
| 292 |
+
from OpenGL.osmesa import (
|
| 293 |
+
OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
|
| 294 |
+
OSMESA_RGBA,
|
| 295 |
+
)
|
| 296 |
+
logger.debug("_init_osmesa: imports completed")
|
| 297 |
+
|
| 298 |
+
ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
|
| 299 |
+
if not ctx:
|
| 300 |
+
raise RuntimeError("OSMesaCreateContextExt() failed")
|
| 301 |
+
|
| 302 |
+
width, height = 64, 64
|
| 303 |
+
buffer = (ctypes.c_ubyte * (width * height * 4))()
|
| 304 |
+
|
| 305 |
+
logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
|
| 306 |
+
if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
|
| 307 |
+
OSMesaDestroyContext(ctx)
|
| 308 |
+
raise RuntimeError("OSMesaMakeCurrent() failed")
|
| 309 |
+
|
| 310 |
+
logger.debug("_init_osmesa: completed successfully")
|
| 311 |
+
return ctx, buffer
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class GLContext:
|
| 315 |
+
"""Manages OpenGL context and resources for shader execution.
|
| 316 |
+
|
| 317 |
+
Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
|
| 318 |
+
"""
|
| 319 |
+
|
| 320 |
+
_instance = None
|
| 321 |
+
_initialized = False
|
| 322 |
+
|
| 323 |
+
def __new__(cls):
|
| 324 |
+
if cls._instance is None:
|
| 325 |
+
cls._instance = super().__new__(cls)
|
| 326 |
+
return cls._instance
|
| 327 |
+
|
| 328 |
+
def __init__(self):
|
| 329 |
+
if GLContext._initialized:
|
| 330 |
+
logger.debug("GLContext.__init__: already initialized, skipping")
|
| 331 |
+
return
|
| 332 |
+
|
| 333 |
+
logger.debug("GLContext.__init__: starting initialization")
|
| 334 |
+
|
| 335 |
+
global glfw, EGL
|
| 336 |
+
|
| 337 |
+
import time
|
| 338 |
+
start = time.perf_counter()
|
| 339 |
+
|
| 340 |
+
self._backend = None
|
| 341 |
+
self._window = None
|
| 342 |
+
self._egl_display = None
|
| 343 |
+
self._egl_context = None
|
| 344 |
+
self._egl_surface = None
|
| 345 |
+
self._osmesa_ctx = None
|
| 346 |
+
self._osmesa_buffer = None
|
| 347 |
+
self._vao = None
|
| 348 |
+
|
| 349 |
+
# Try backends in order: GLFW → EGL → OSMesa
|
| 350 |
+
errors = []
|
| 351 |
+
|
| 352 |
+
logger.debug("GLContext.__init__: trying GLFW backend")
|
| 353 |
+
try:
|
| 354 |
+
self._window, glfw = _init_glfw()
|
| 355 |
+
self._backend = "glfw"
|
| 356 |
+
logger.debug("GLContext.__init__: GLFW backend succeeded")
|
| 357 |
+
except Exception as e:
|
| 358 |
+
logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
|
| 359 |
+
errors.append(("GLFW", e))
|
| 360 |
+
|
| 361 |
+
if self._backend is None:
|
| 362 |
+
logger.debug("GLContext.__init__: trying EGL backend")
|
| 363 |
+
try:
|
| 364 |
+
self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
|
| 365 |
+
self._backend = "egl"
|
| 366 |
+
logger.debug("GLContext.__init__: EGL backend succeeded")
|
| 367 |
+
except Exception as e:
|
| 368 |
+
logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
|
| 369 |
+
errors.append(("EGL", e))
|
| 370 |
+
|
| 371 |
+
if self._backend is None:
|
| 372 |
+
logger.debug("GLContext.__init__: trying OSMesa backend")
|
| 373 |
+
try:
|
| 374 |
+
self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
|
| 375 |
+
self._backend = "osmesa"
|
| 376 |
+
logger.debug("GLContext.__init__: OSMesa backend succeeded")
|
| 377 |
+
except Exception as e:
|
| 378 |
+
logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
|
| 379 |
+
errors.append(("OSMesa", e))
|
| 380 |
+
|
| 381 |
+
if self._backend is None:
|
| 382 |
+
if sys.platform == "win32":
|
| 383 |
+
platform_help = (
|
| 384 |
+
"Windows: Ensure GPU drivers are installed and display is available.\n"
|
| 385 |
+
" CPU-only/headless mode is not supported on Windows."
|
| 386 |
+
)
|
| 387 |
+
elif sys.platform == "darwin":
|
| 388 |
+
platform_help = (
|
| 389 |
+
"macOS: GLFW is not supported.\n"
|
| 390 |
+
" Install OSMesa via Homebrew: brew install mesa\n"
|
| 391 |
+
" Then: pip install PyOpenGL PyOpenGL-accelerate"
|
| 392 |
+
)
|
| 393 |
+
else:
|
| 394 |
+
platform_help = (
|
| 395 |
+
"Linux: Install one of these backends:\n"
|
| 396 |
+
" Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
|
| 397 |
+
" Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
|
| 398 |
+
" Headless (CPU): sudo apt install libosmesa6"
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
error_details = "\n".join(f" {name}: {err}" for name, err in errors)
|
| 402 |
+
raise RuntimeError(
|
| 403 |
+
f"Failed to create OpenGL context.\n\n"
|
| 404 |
+
f"Backend errors:\n{error_details}\n\n"
|
| 405 |
+
f"{platform_help}"
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# Now import OpenGL.GL (after context is current)
|
| 409 |
+
logger.debug("GLContext.__init__: importing OpenGL.GL")
|
| 410 |
+
_import_opengl()
|
| 411 |
+
|
| 412 |
+
# Create VAO (required for core profile, but OSMesa may use compat profile)
|
| 413 |
+
logger.debug("GLContext.__init__: creating VAO")
|
| 414 |
+
try:
|
| 415 |
+
vao = gl.glGenVertexArrays(1)
|
| 416 |
+
gl.glBindVertexArray(vao)
|
| 417 |
+
self._vao = vao # Only store after successful bind
|
| 418 |
+
logger.debug("GLContext.__init__: VAO created successfully")
|
| 419 |
+
except Exception as e:
|
| 420 |
+
logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
|
| 421 |
+
# OSMesa with older Mesa may not support VAOs
|
| 422 |
+
# Clean up if we created but couldn't bind
|
| 423 |
+
if vao:
|
| 424 |
+
try:
|
| 425 |
+
gl.glDeleteVertexArrays(1, [vao])
|
| 426 |
+
except Exception:
|
| 427 |
+
pass
|
| 428 |
+
|
| 429 |
+
elapsed = (time.perf_counter() - start) * 1000
|
| 430 |
+
|
| 431 |
+
# Log device info
|
| 432 |
+
renderer = gl.glGetString(gl.GL_RENDERER)
|
| 433 |
+
vendor = gl.glGetString(gl.GL_VENDOR)
|
| 434 |
+
version = gl.glGetString(gl.GL_VERSION)
|
| 435 |
+
renderer = renderer.decode() if renderer else "Unknown"
|
| 436 |
+
vendor = vendor.decode() if vendor else "Unknown"
|
| 437 |
+
version = version.decode() if version else "Unknown"
|
| 438 |
+
|
| 439 |
+
GLContext._initialized = True
|
| 440 |
+
logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
|
| 441 |
+
|
| 442 |
+
def make_current(self):
|
| 443 |
+
if self._backend == "glfw":
|
| 444 |
+
glfw.make_context_current(self._window)
|
| 445 |
+
elif self._backend == "egl":
|
| 446 |
+
from OpenGL.EGL import eglMakeCurrent
|
| 447 |
+
eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
|
| 448 |
+
elif self._backend == "osmesa":
|
| 449 |
+
from OpenGL.osmesa import OSMesaMakeCurrent
|
| 450 |
+
OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
|
| 451 |
+
|
| 452 |
+
if self._vao is not None:
|
| 453 |
+
gl.glBindVertexArray(self._vao)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def _compile_shader(source: str, shader_type: int) -> int:
|
| 457 |
+
"""Compile a shader and return its ID."""
|
| 458 |
+
shader = gl.glCreateShader(shader_type)
|
| 459 |
+
gl.glShaderSource(shader, source)
|
| 460 |
+
gl.glCompileShader(shader)
|
| 461 |
+
|
| 462 |
+
if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
|
| 463 |
+
error = gl.glGetShaderInfoLog(shader).decode()
|
| 464 |
+
gl.glDeleteShader(shader)
|
| 465 |
+
raise RuntimeError(f"Shader compilation failed:\n{error}")
|
| 466 |
+
|
| 467 |
+
return shader
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def _create_program(vertex_source: str, fragment_source: str) -> int:
|
| 471 |
+
"""Create and link a shader program."""
|
| 472 |
+
vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
|
| 473 |
+
try:
|
| 474 |
+
fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
|
| 475 |
+
except RuntimeError:
|
| 476 |
+
gl.glDeleteShader(vertex_shader)
|
| 477 |
+
raise
|
| 478 |
+
|
| 479 |
+
program = gl.glCreateProgram()
|
| 480 |
+
gl.glAttachShader(program, vertex_shader)
|
| 481 |
+
gl.glAttachShader(program, fragment_shader)
|
| 482 |
+
gl.glLinkProgram(program)
|
| 483 |
+
|
| 484 |
+
gl.glDeleteShader(vertex_shader)
|
| 485 |
+
gl.glDeleteShader(fragment_shader)
|
| 486 |
+
|
| 487 |
+
if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
|
| 488 |
+
error = gl.glGetProgramInfoLog(program).decode()
|
| 489 |
+
gl.glDeleteProgram(program)
|
| 490 |
+
raise RuntimeError(f"Program linking failed:\n{error}")
|
| 491 |
+
|
| 492 |
+
return program
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def _render_shader_batch(
|
| 496 |
+
fragment_code: str,
|
| 497 |
+
width: int,
|
| 498 |
+
height: int,
|
| 499 |
+
image_batches: list[list[np.ndarray]],
|
| 500 |
+
floats: list[float],
|
| 501 |
+
ints: list[int],
|
| 502 |
+
bools: list[bool] | None = None,
|
| 503 |
+
curves: list[np.ndarray] | None = None,
|
| 504 |
+
) -> list[list[np.ndarray]]:
|
| 505 |
+
"""
|
| 506 |
+
Render a fragment shader for multiple batches efficiently.
|
| 507 |
+
|
| 508 |
+
Compiles shader once, reuses framebuffer/textures across batches.
|
| 509 |
+
Supports multi-pass rendering via #pragma passes N directive.
|
| 510 |
+
|
| 511 |
+
Args:
|
| 512 |
+
fragment_code: User's fragment shader code
|
| 513 |
+
width: Output width
|
| 514 |
+
height: Output height
|
| 515 |
+
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
| 516 |
+
floats: List of float uniforms
|
| 517 |
+
ints: List of int uniforms
|
| 518 |
+
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
|
| 519 |
+
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
|
| 520 |
+
|
| 521 |
+
Returns:
|
| 522 |
+
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
| 523 |
+
"""
|
| 524 |
+
import time
|
| 525 |
+
start_time = time.perf_counter()
|
| 526 |
+
|
| 527 |
+
if not image_batches:
|
| 528 |
+
return []
|
| 529 |
+
|
| 530 |
+
ctx = GLContext()
|
| 531 |
+
ctx.make_current()
|
| 532 |
+
|
| 533 |
+
# Convert from GLSL ES to desktop GLSL 330
|
| 534 |
+
fragment_source = _convert_es_to_desktop(fragment_code)
|
| 535 |
+
|
| 536 |
+
# Detect how many outputs the shader actually uses
|
| 537 |
+
num_outputs = _detect_output_count(fragment_code)
|
| 538 |
+
|
| 539 |
+
# Detect multi-pass rendering
|
| 540 |
+
num_passes = _detect_pass_count(fragment_code)
|
| 541 |
+
|
| 542 |
+
if bools is None:
|
| 543 |
+
bools = []
|
| 544 |
+
if curves is None:
|
| 545 |
+
curves = []
|
| 546 |
+
|
| 547 |
+
# Track resources for cleanup
|
| 548 |
+
program = None
|
| 549 |
+
fbo = None
|
| 550 |
+
output_textures = []
|
| 551 |
+
input_textures = []
|
| 552 |
+
curve_textures = []
|
| 553 |
+
ping_pong_textures = []
|
| 554 |
+
ping_pong_fbos = []
|
| 555 |
+
|
| 556 |
+
num_inputs = len(image_batches[0])
|
| 557 |
+
|
| 558 |
+
try:
|
| 559 |
+
# Compile shaders (once for all batches)
|
| 560 |
+
try:
|
| 561 |
+
program = _create_program(VERTEX_SHADER, fragment_source)
|
| 562 |
+
except RuntimeError:
|
| 563 |
+
logger.error(f"Fragment shader:\n{fragment_source}")
|
| 564 |
+
raise
|
| 565 |
+
|
| 566 |
+
gl.glUseProgram(program)
|
| 567 |
+
|
| 568 |
+
# Create framebuffer with only the needed color attachments
|
| 569 |
+
fbo = gl.glGenFramebuffers(1)
|
| 570 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
| 571 |
+
|
| 572 |
+
draw_buffers = []
|
| 573 |
+
for i in range(num_outputs):
|
| 574 |
+
tex = gl.glGenTextures(1)
|
| 575 |
+
output_textures.append(tex)
|
| 576 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
| 577 |
+
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
|
| 578 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
| 579 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
| 580 |
+
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
|
| 581 |
+
draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)
|
| 582 |
+
|
| 583 |
+
gl.glDrawBuffers(num_outputs, draw_buffers)
|
| 584 |
+
|
| 585 |
+
if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
|
| 586 |
+
raise RuntimeError("Framebuffer is not complete")
|
| 587 |
+
|
| 588 |
+
# Create ping-pong resources for multi-pass rendering
|
| 589 |
+
if num_passes > 1:
|
| 590 |
+
for _ in range(2):
|
| 591 |
+
pp_tex = gl.glGenTextures(1)
|
| 592 |
+
ping_pong_textures.append(pp_tex)
|
| 593 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, pp_tex)
|
| 594 |
+
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
|
| 595 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
| 596 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
| 597 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
| 598 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
| 599 |
+
|
| 600 |
+
pp_fbo = gl.glGenFramebuffers(1)
|
| 601 |
+
ping_pong_fbos.append(pp_fbo)
|
| 602 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, pp_fbo)
|
| 603 |
+
gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, pp_tex, 0)
|
| 604 |
+
gl.glDrawBuffers(1, [gl.GL_COLOR_ATTACHMENT0])
|
| 605 |
+
|
| 606 |
+
if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
|
| 607 |
+
raise RuntimeError("Ping-pong framebuffer is not complete")
|
| 608 |
+
|
| 609 |
+
# Create input textures (reused for all batches)
|
| 610 |
+
for i in range(num_inputs):
|
| 611 |
+
tex = gl.glGenTextures(1)
|
| 612 |
+
input_textures.append(tex)
|
| 613 |
+
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
|
| 614 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
| 615 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
| 616 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
| 617 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
| 618 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
| 619 |
+
|
| 620 |
+
loc = gl.glGetUniformLocation(program, f"u_image{i}")
|
| 621 |
+
if loc >= 0:
|
| 622 |
+
gl.glUniform1i(loc, i)
|
| 623 |
+
|
| 624 |
+
# Set static uniforms (once for all batches)
|
| 625 |
+
loc = gl.glGetUniformLocation(program, "u_resolution")
|
| 626 |
+
if loc >= 0:
|
| 627 |
+
gl.glUniform2f(loc, float(width), float(height))
|
| 628 |
+
|
| 629 |
+
for i, v in enumerate(floats):
|
| 630 |
+
loc = gl.glGetUniformLocation(program, f"u_float{i}")
|
| 631 |
+
if loc >= 0:
|
| 632 |
+
gl.glUniform1f(loc, v)
|
| 633 |
+
|
| 634 |
+
for i, v in enumerate(ints):
|
| 635 |
+
loc = gl.glGetUniformLocation(program, f"u_int{i}")
|
| 636 |
+
if loc >= 0:
|
| 637 |
+
gl.glUniform1i(loc, v)
|
| 638 |
+
|
| 639 |
+
for i, v in enumerate(bools):
|
| 640 |
+
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
|
| 641 |
+
if loc >= 0:
|
| 642 |
+
gl.glUniform1i(loc, 1 if v else 0)
|
| 643 |
+
|
| 644 |
+
# Create 1D LUT textures for curves (bound after image texture units)
|
| 645 |
+
for i, lut in enumerate(curves):
|
| 646 |
+
tex = gl.glGenTextures(1)
|
| 647 |
+
curve_textures.append(tex)
|
| 648 |
+
unit = MAX_IMAGES + i
|
| 649 |
+
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
|
| 650 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
| 651 |
+
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
|
| 652 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
| 653 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
| 654 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
| 655 |
+
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
| 656 |
+
|
| 657 |
+
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
|
| 658 |
+
if loc >= 0:
|
| 659 |
+
gl.glUniform1i(loc, unit)
|
| 660 |
+
|
| 661 |
+
# Get u_pass uniform location for multi-pass
|
| 662 |
+
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
| 663 |
+
|
| 664 |
+
gl.glViewport(0, 0, width, height)
|
| 665 |
+
gl.glDisable(gl.GL_BLEND) # Ensure no alpha blending - write output directly
|
| 666 |
+
|
| 667 |
+
# Process each batch
|
| 668 |
+
all_batch_outputs = []
|
| 669 |
+
for images in image_batches:
|
| 670 |
+
# Update input textures with this batch's images
|
| 671 |
+
for i, img in enumerate(images):
|
| 672 |
+
gl.glActiveTexture(gl.GL_TEXTURE0 + i)
|
| 673 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])
|
| 674 |
+
|
| 675 |
+
# Flip vertically for GL coordinates, ensure RGBA
|
| 676 |
+
h, w, c = img.shape
|
| 677 |
+
if c == 3:
|
| 678 |
+
img_upload = np.empty((h, w, 4), dtype=np.float32)
|
| 679 |
+
img_upload[:, :, :3] = img[::-1, :, :]
|
| 680 |
+
img_upload[:, :, 3] = 1.0
|
| 681 |
+
else:
|
| 682 |
+
img_upload = np.ascontiguousarray(img[::-1, :, :])
|
| 683 |
+
|
| 684 |
+
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload)
|
| 685 |
+
|
| 686 |
+
if num_passes == 1:
|
| 687 |
+
# Single pass - render directly to output FBO
|
| 688 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
| 689 |
+
if pass_loc >= 0:
|
| 690 |
+
gl.glUniform1i(pass_loc, 0)
|
| 691 |
+
gl.glClearColor(0, 0, 0, 0)
|
| 692 |
+
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
|
| 693 |
+
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
| 694 |
+
else:
|
| 695 |
+
# Multi-pass rendering with ping-pong
|
| 696 |
+
for p in range(num_passes):
|
| 697 |
+
is_last_pass = (p == num_passes - 1)
|
| 698 |
+
|
| 699 |
+
# Set pass uniform
|
| 700 |
+
if pass_loc >= 0:
|
| 701 |
+
gl.glUniform1i(pass_loc, p)
|
| 702 |
+
|
| 703 |
+
if is_last_pass:
|
| 704 |
+
# Last pass renders to the main output FBO
|
| 705 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
|
| 706 |
+
else:
|
| 707 |
+
# Intermediate passes render to ping-pong FBO
|
| 708 |
+
target_fbo = ping_pong_fbos[p % 2]
|
| 709 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, target_fbo)
|
| 710 |
+
|
| 711 |
+
# Set input texture for this pass
|
| 712 |
+
gl.glActiveTexture(gl.GL_TEXTURE0)
|
| 713 |
+
if p == 0:
|
| 714 |
+
# First pass reads from original input
|
| 715 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[0])
|
| 716 |
+
else:
|
| 717 |
+
# Subsequent passes read from previous pass output
|
| 718 |
+
source_tex = ping_pong_textures[(p - 1) % 2]
|
| 719 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, source_tex)
|
| 720 |
+
|
| 721 |
+
gl.glClearColor(0, 0, 0, 0)
|
| 722 |
+
gl.glClear(gl.GL_COLOR_BUFFER_BIT)
|
| 723 |
+
gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
|
| 724 |
+
|
| 725 |
+
# Read back outputs for this batch
|
| 726 |
+
# (glGetTexImage is synchronous, implicitly waits for rendering)
|
| 727 |
+
batch_outputs = []
|
| 728 |
+
for tex in output_textures:
|
| 729 |
+
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
| 730 |
+
data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
|
| 731 |
+
img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
|
| 732 |
+
batch_outputs.append(img[::-1, :, :].copy())
|
| 733 |
+
|
| 734 |
+
# Pad with black images for unused outputs
|
| 735 |
+
black_img = np.zeros((height, width, 4), dtype=np.float32)
|
| 736 |
+
for _ in range(num_outputs, MAX_OUTPUTS):
|
| 737 |
+
batch_outputs.append(black_img)
|
| 738 |
+
|
| 739 |
+
all_batch_outputs.append(batch_outputs)
|
| 740 |
+
|
| 741 |
+
elapsed = (time.perf_counter() - start_time) * 1000
|
| 742 |
+
num_batches = len(image_batches)
|
| 743 |
+
pass_info = f", {num_passes} passes" if num_passes > 1 else ""
|
| 744 |
+
logger.info(f"GLSL shader executed in {elapsed:.1f}ms ({num_batches} batch{'es' if num_batches != 1 else ''}, {width}x{height}{pass_info})")
|
| 745 |
+
|
| 746 |
+
return all_batch_outputs
|
| 747 |
+
|
| 748 |
+
finally:
|
| 749 |
+
# Unbind before deleting
|
| 750 |
+
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
|
| 751 |
+
gl.glUseProgram(0)
|
| 752 |
+
|
| 753 |
+
for tex in input_textures:
|
| 754 |
+
gl.glDeleteTextures(int(tex))
|
| 755 |
+
for tex in curve_textures:
|
| 756 |
+
gl.glDeleteTextures(int(tex))
|
| 757 |
+
for tex in output_textures:
|
| 758 |
+
gl.glDeleteTextures(int(tex))
|
| 759 |
+
for tex in ping_pong_textures:
|
| 760 |
+
gl.glDeleteTextures(int(tex))
|
| 761 |
+
if fbo is not None:
|
| 762 |
+
gl.glDeleteFramebuffers(1, [fbo])
|
| 763 |
+
for pp_fbo in ping_pong_fbos:
|
| 764 |
+
gl.glDeleteFramebuffers(1, [pp_fbo])
|
| 765 |
+
if program is not None:
|
| 766 |
+
gl.glDeleteProgram(program)
|
| 767 |
+
|
| 768 |
+
class GLSLShader(io.ComfyNode):
|
| 769 |
+
|
| 770 |
+
@classmethod
|
| 771 |
+
def define_schema(cls) -> io.Schema:
|
| 772 |
+
image_template = io.Autogrow.TemplatePrefix(
|
| 773 |
+
io.Image.Input("image"),
|
| 774 |
+
prefix="image",
|
| 775 |
+
min=1,
|
| 776 |
+
max=MAX_IMAGES,
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
float_template = io.Autogrow.TemplatePrefix(
|
| 780 |
+
io.Float.Input("float", default=0.0),
|
| 781 |
+
prefix="u_float",
|
| 782 |
+
min=0,
|
| 783 |
+
max=MAX_UNIFORMS,
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
int_template = io.Autogrow.TemplatePrefix(
|
| 787 |
+
io.Int.Input("int", default=0),
|
| 788 |
+
prefix="u_int",
|
| 789 |
+
min=0,
|
| 790 |
+
max=MAX_UNIFORMS,
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
bool_template = io.Autogrow.TemplatePrefix(
|
| 794 |
+
io.Boolean.Input("bool", default=False),
|
| 795 |
+
prefix="u_bool",
|
| 796 |
+
min=0,
|
| 797 |
+
max=MAX_BOOLS,
|
| 798 |
+
)
|
| 799 |
+
|
| 800 |
+
curve_template = io.Autogrow.TemplatePrefix(
|
| 801 |
+
io.Curve.Input("curve"),
|
| 802 |
+
prefix="u_curve",
|
| 803 |
+
min=0,
|
| 804 |
+
max=MAX_CURVES,
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
return io.Schema(
|
| 808 |
+
node_id="GLSLShader",
|
| 809 |
+
display_name="GLSL Shader",
|
| 810 |
+
category="image/shader",
|
| 811 |
+
description=(
|
| 812 |
+
"Apply GLSL ES fragment shaders to images. "
|
| 813 |
+
"u_resolution (vec2) is always available."
|
| 814 |
+
),
|
| 815 |
+
is_experimental=True,
|
| 816 |
+
has_intermediate_output=True,
|
| 817 |
+
inputs=[
|
| 818 |
+
io.String.Input(
|
| 819 |
+
"fragment_shader",
|
| 820 |
+
default=DEFAULT_FRAGMENT_SHADER,
|
| 821 |
+
multiline=True,
|
| 822 |
+
tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
|
| 823 |
+
),
|
| 824 |
+
io.DynamicCombo.Input(
|
| 825 |
+
"size_mode",
|
| 826 |
+
options=[
|
| 827 |
+
io.DynamicCombo.Option("from_input", []),
|
| 828 |
+
io.DynamicCombo.Option(
|
| 829 |
+
"custom",
|
| 830 |
+
[
|
| 831 |
+
io.Int.Input(
|
| 832 |
+
"width",
|
| 833 |
+
default=512,
|
| 834 |
+
min=1,
|
| 835 |
+
max=nodes.MAX_RESOLUTION,
|
| 836 |
+
),
|
| 837 |
+
io.Int.Input(
|
| 838 |
+
"height",
|
| 839 |
+
default=512,
|
| 840 |
+
min=1,
|
| 841 |
+
max=nodes.MAX_RESOLUTION,
|
| 842 |
+
),
|
| 843 |
+
],
|
| 844 |
+
),
|
| 845 |
+
],
|
| 846 |
+
tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
|
| 847 |
+
),
|
| 848 |
+
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
| 849 |
+
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
| 850 |
+
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
| 851 |
+
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
|
| 852 |
+
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
|
| 853 |
+
],
|
| 854 |
+
outputs=[
|
| 855 |
+
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
| 856 |
+
io.Image.Output(display_name="IMAGE1", tooltip="Available via layout(location = 1) out vec4 fragColor1 in the shader code"),
|
| 857 |
+
io.Image.Output(display_name="IMAGE2", tooltip="Available via layout(location = 2) out vec4 fragColor2 in the shader code"),
|
| 858 |
+
io.Image.Output(display_name="IMAGE3", tooltip="Available via layout(location = 3) out vec4 fragColor3 in the shader code"),
|
| 859 |
+
],
|
| 860 |
+
)
|
| 861 |
+
|
| 862 |
+
@classmethod
|
| 863 |
+
def execute(
|
| 864 |
+
cls,
|
| 865 |
+
fragment_shader: str,
|
| 866 |
+
size_mode: SizeModeInput,
|
| 867 |
+
images: io.Autogrow.Type,
|
| 868 |
+
floats: io.Autogrow.Type = None,
|
| 869 |
+
ints: io.Autogrow.Type = None,
|
| 870 |
+
bools: io.Autogrow.Type = None,
|
| 871 |
+
curves: io.Autogrow.Type = None,
|
| 872 |
+
**kwargs,
|
| 873 |
+
) -> io.NodeOutput:
|
| 874 |
+
|
| 875 |
+
image_list = [v for v in images.values() if v is not None]
|
| 876 |
+
float_list = (
|
| 877 |
+
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
| 878 |
+
)
|
| 879 |
+
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
| 880 |
+
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
|
| 881 |
+
|
| 882 |
+
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
|
| 883 |
+
|
| 884 |
+
if not image_list:
|
| 885 |
+
raise ValueError("At least one input image is required")
|
| 886 |
+
|
| 887 |
+
# Determine output dimensions
|
| 888 |
+
if size_mode["size_mode"] == "custom":
|
| 889 |
+
out_width = size_mode["width"]
|
| 890 |
+
out_height = size_mode["height"]
|
| 891 |
+
else:
|
| 892 |
+
out_height, out_width = image_list[0].shape[1:3]
|
| 893 |
+
|
| 894 |
+
batch_size = image_list[0].shape[0]
|
| 895 |
+
|
| 896 |
+
# Prepare batches
|
| 897 |
+
image_batches = []
|
| 898 |
+
for batch_idx in range(batch_size):
|
| 899 |
+
batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list]
|
| 900 |
+
image_batches.append(batch_images)
|
| 901 |
+
|
| 902 |
+
all_batch_outputs = _render_shader_batch(
|
| 903 |
+
fragment_shader,
|
| 904 |
+
out_width,
|
| 905 |
+
out_height,
|
| 906 |
+
image_batches,
|
| 907 |
+
float_list,
|
| 908 |
+
int_list,
|
| 909 |
+
bool_list,
|
| 910 |
+
curve_luts,
|
| 911 |
+
)
|
| 912 |
+
|
| 913 |
+
# Collect outputs into tensors
|
| 914 |
+
all_outputs = [[] for _ in range(MAX_OUTPUTS)]
|
| 915 |
+
for batch_outputs in all_batch_outputs:
|
| 916 |
+
for i, out_img in enumerate(batch_outputs):
|
| 917 |
+
all_outputs[i].append(torch.from_numpy(out_img))
|
| 918 |
+
|
| 919 |
+
output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
|
| 920 |
+
return io.NodeOutput(
|
| 921 |
+
*output_tensors,
|
| 922 |
+
ui=cls._build_ui_output(image_list, output_tensors[0]),
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
@classmethod
|
| 926 |
+
def _build_ui_output(
|
| 927 |
+
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
|
| 928 |
+
) -> dict[str, list]:
|
| 929 |
+
"""Build UI output with input and output images for client-side shader execution."""
|
| 930 |
+
input_images_ui = []
|
| 931 |
+
for img in image_list:
|
| 932 |
+
input_images_ui.extend(ui.ImageSaveHelper.save_images(
|
| 933 |
+
img,
|
| 934 |
+
filename_prefix="GLSLShader_input",
|
| 935 |
+
folder_type=io.FolderType.temp,
|
| 936 |
+
cls=None,
|
| 937 |
+
compress_level=1,
|
| 938 |
+
))
|
| 939 |
+
|
| 940 |
+
output_images_ui = ui.ImageSaveHelper.save_images(
|
| 941 |
+
output_batch,
|
| 942 |
+
filename_prefix="GLSLShader_output",
|
| 943 |
+
folder_type=io.FolderType.temp,
|
| 944 |
+
cls=None,
|
| 945 |
+
compress_level=1,
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
return {"input_images": input_images_ui, "images": output_images_ui}
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
class GLSLExtension(ComfyExtension):
|
| 952 |
+
@override
|
| 953 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 954 |
+
return [GLSLShader]
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
async def comfy_entrypoint() -> GLSLExtension:
|
| 958 |
+
return GLSLExtension()
|
ComfyUI/comfy_extras/nodes_hidream.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import override
|
| 2 |
+
|
| 3 |
+
import folder_paths
|
| 4 |
+
import comfy.sd
|
| 5 |
+
import comfy.model_management
|
| 6 |
+
from comfy_api.latest import ComfyExtension, io
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class QuadrupleCLIPLoader(io.ComfyNode):
|
| 10 |
+
@classmethod
|
| 11 |
+
def define_schema(cls):
|
| 12 |
+
return io.Schema(
|
| 13 |
+
node_id="QuadrupleCLIPLoader",
|
| 14 |
+
category="advanced/loaders",
|
| 15 |
+
description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct",
|
| 16 |
+
inputs=[
|
| 17 |
+
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
|
| 18 |
+
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
|
| 19 |
+
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
|
| 20 |
+
io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")),
|
| 21 |
+
],
|
| 22 |
+
outputs=[
|
| 23 |
+
io.Clip.Output(),
|
| 24 |
+
]
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
@classmethod
|
| 28 |
+
def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4):
|
| 29 |
+
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
| 30 |
+
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
| 31 |
+
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
| 32 |
+
clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
|
| 33 |
+
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
| 34 |
+
return io.NodeOutput(clip)
|
| 35 |
+
|
| 36 |
+
class CLIPTextEncodeHiDream(io.ComfyNode):
|
| 37 |
+
@classmethod
|
| 38 |
+
def define_schema(cls):
|
| 39 |
+
return io.Schema(
|
| 40 |
+
node_id="CLIPTextEncodeHiDream",
|
| 41 |
+
search_aliases=["hidream prompt"],
|
| 42 |
+
category="advanced/conditioning",
|
| 43 |
+
inputs=[
|
| 44 |
+
io.Clip.Input("clip"),
|
| 45 |
+
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
| 46 |
+
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
|
| 47 |
+
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
| 48 |
+
io.String.Input("llama", multiline=True, dynamic_prompts=True),
|
| 49 |
+
],
|
| 50 |
+
outputs=[
|
| 51 |
+
io.Conditioning.Output(),
|
| 52 |
+
]
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
@classmethod
|
| 56 |
+
def execute(cls, clip, clip_l, clip_g, t5xxl, llama):
|
| 57 |
+
tokens = clip.tokenize(clip_g)
|
| 58 |
+
tokens["l"] = clip.tokenize(clip_l)["l"]
|
| 59 |
+
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
| 60 |
+
tokens["llama"] = clip.tokenize(llama)["llama"]
|
| 61 |
+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class HiDreamExtension(ComfyExtension):
|
| 65 |
+
@override
|
| 66 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 67 |
+
return [
|
| 68 |
+
QuadrupleCLIPLoader,
|
| 69 |
+
CLIPTextEncodeHiDream,
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
async def comfy_entrypoint() -> HiDreamExtension:
|
| 74 |
+
return HiDreamExtension()
|
ComfyUI/comfy_extras/nodes_hooks.py
ADDED
|
@@ -0,0 +1,750 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import TYPE_CHECKING, Union
|
| 3 |
+
import logging
|
| 4 |
+
import torch
|
| 5 |
+
from collections.abc import Iterable
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
from comfy.sd import CLIP
|
| 9 |
+
|
| 10 |
+
import comfy.hooks
|
| 11 |
+
import comfy.sd
|
| 12 |
+
import comfy.utils
|
| 13 |
+
import folder_paths
|
| 14 |
+
|
| 15 |
+
###########################################
|
| 16 |
+
# Mask, Combine, and Hook Conditioning
|
| 17 |
+
#------------------------------------------
|
| 18 |
+
class PairConditioningSetProperties:
|
| 19 |
+
NodeId = 'PairConditioningSetProperties'
|
| 20 |
+
NodeName = 'Cond Pair Set Props'
|
| 21 |
+
@classmethod
|
| 22 |
+
def INPUT_TYPES(s):
|
| 23 |
+
return {
|
| 24 |
+
"required": {
|
| 25 |
+
"positive_NEW": ("CONDITIONING", ),
|
| 26 |
+
"negative_NEW": ("CONDITIONING", ),
|
| 27 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
| 28 |
+
"set_cond_area": (["default", "mask bounds"],),
|
| 29 |
+
},
|
| 30 |
+
"optional": {
|
| 31 |
+
"mask": ("MASK", ),
|
| 32 |
+
"hooks": ("HOOKS",),
|
| 33 |
+
"timesteps": ("TIMESTEPS_RANGE",),
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
EXPERIMENTAL = True
|
| 38 |
+
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
| 39 |
+
RETURN_NAMES = ("positive", "negative")
|
| 40 |
+
CATEGORY = "advanced/hooks/cond pair"
|
| 41 |
+
FUNCTION = "set_properties"
|
| 42 |
+
|
| 43 |
+
def set_properties(self, positive_NEW, negative_NEW,
|
| 44 |
+
strength: float, set_cond_area: str,
|
| 45 |
+
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
| 46 |
+
final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW],
|
| 47 |
+
strength=strength, set_cond_area=set_cond_area,
|
| 48 |
+
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
| 49 |
+
return (final_positive, final_negative)
|
| 50 |
+
|
| 51 |
+
class PairConditioningSetPropertiesAndCombine:
|
| 52 |
+
NodeId = 'PairConditioningSetPropertiesAndCombine'
|
| 53 |
+
NodeName = 'Cond Pair Set Props Combine'
|
| 54 |
+
@classmethod
|
| 55 |
+
def INPUT_TYPES(s):
|
| 56 |
+
return {
|
| 57 |
+
"required": {
|
| 58 |
+
"positive": ("CONDITIONING", ),
|
| 59 |
+
"negative": ("CONDITIONING", ),
|
| 60 |
+
"positive_NEW": ("CONDITIONING", ),
|
| 61 |
+
"negative_NEW": ("CONDITIONING", ),
|
| 62 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
| 63 |
+
"set_cond_area": (["default", "mask bounds"],),
|
| 64 |
+
},
|
| 65 |
+
"optional": {
|
| 66 |
+
"mask": ("MASK", ),
|
| 67 |
+
"hooks": ("HOOKS",),
|
| 68 |
+
"timesteps": ("TIMESTEPS_RANGE",),
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
EXPERIMENTAL = True
|
| 73 |
+
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
| 74 |
+
RETURN_NAMES = ("positive", "negative")
|
| 75 |
+
CATEGORY = "advanced/hooks/cond pair"
|
| 76 |
+
FUNCTION = "set_properties"
|
| 77 |
+
|
| 78 |
+
def set_properties(self, positive, negative, positive_NEW, negative_NEW,
|
| 79 |
+
strength: float, set_cond_area: str,
|
| 80 |
+
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
| 81 |
+
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
|
| 82 |
+
strength=strength, set_cond_area=set_cond_area,
|
| 83 |
+
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
| 84 |
+
return (final_positive, final_negative)
|
| 85 |
+
|
| 86 |
+
class ConditioningSetProperties:
|
| 87 |
+
NodeId = 'ConditioningSetProperties'
|
| 88 |
+
NodeName = 'Cond Set Props'
|
| 89 |
+
@classmethod
|
| 90 |
+
def INPUT_TYPES(s):
|
| 91 |
+
return {
|
| 92 |
+
"required": {
|
| 93 |
+
"cond_NEW": ("CONDITIONING", ),
|
| 94 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
| 95 |
+
"set_cond_area": (["default", "mask bounds"],),
|
| 96 |
+
},
|
| 97 |
+
"optional": {
|
| 98 |
+
"mask": ("MASK", ),
|
| 99 |
+
"hooks": ("HOOKS",),
|
| 100 |
+
"timesteps": ("TIMESTEPS_RANGE",),
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
EXPERIMENTAL = True
|
| 105 |
+
RETURN_TYPES = ("CONDITIONING",)
|
| 106 |
+
CATEGORY = "advanced/hooks/cond single"
|
| 107 |
+
FUNCTION = "set_properties"
|
| 108 |
+
|
| 109 |
+
def set_properties(self, cond_NEW,
|
| 110 |
+
strength: float, set_cond_area: str,
|
| 111 |
+
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
| 112 |
+
(final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW],
|
| 113 |
+
strength=strength, set_cond_area=set_cond_area,
|
| 114 |
+
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
| 115 |
+
return (final_cond,)
|
| 116 |
+
|
| 117 |
+
class ConditioningSetPropertiesAndCombine:
|
| 118 |
+
NodeId = 'ConditioningSetPropertiesAndCombine'
|
| 119 |
+
NodeName = 'Cond Set Props Combine'
|
| 120 |
+
@classmethod
|
| 121 |
+
def INPUT_TYPES(s):
|
| 122 |
+
return {
|
| 123 |
+
"required": {
|
| 124 |
+
"cond": ("CONDITIONING", ),
|
| 125 |
+
"cond_NEW": ("CONDITIONING", ),
|
| 126 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
| 127 |
+
"set_cond_area": (["default", "mask bounds"],),
|
| 128 |
+
},
|
| 129 |
+
"optional": {
|
| 130 |
+
"mask": ("MASK", ),
|
| 131 |
+
"hooks": ("HOOKS",),
|
| 132 |
+
"timesteps": ("TIMESTEPS_RANGE",),
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
EXPERIMENTAL = True
|
| 137 |
+
RETURN_TYPES = ("CONDITIONING",)
|
| 138 |
+
CATEGORY = "advanced/hooks/cond single"
|
| 139 |
+
FUNCTION = "set_properties"
|
| 140 |
+
|
| 141 |
+
def set_properties(self, cond, cond_NEW,
|
| 142 |
+
strength: float, set_cond_area: str,
|
| 143 |
+
mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
|
| 144 |
+
(final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW],
|
| 145 |
+
strength=strength, set_cond_area=set_cond_area,
|
| 146 |
+
mask=mask, hooks=hooks, timesteps_range=timesteps)
|
| 147 |
+
return (final_cond,)
|
| 148 |
+
|
| 149 |
+
class PairConditioningCombine:
|
| 150 |
+
NodeId = 'PairConditioningCombine'
|
| 151 |
+
NodeName = 'Cond Pair Combine'
|
| 152 |
+
@classmethod
|
| 153 |
+
def INPUT_TYPES(s):
|
| 154 |
+
return {
|
| 155 |
+
"required": {
|
| 156 |
+
"positive_A": ("CONDITIONING",),
|
| 157 |
+
"negative_A": ("CONDITIONING",),
|
| 158 |
+
"positive_B": ("CONDITIONING",),
|
| 159 |
+
"negative_B": ("CONDITIONING",),
|
| 160 |
+
},
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
EXPERIMENTAL = True
|
| 164 |
+
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
| 165 |
+
RETURN_NAMES = ("positive", "negative")
|
| 166 |
+
CATEGORY = "advanced/hooks/cond pair"
|
| 167 |
+
FUNCTION = "combine"
|
| 168 |
+
|
| 169 |
+
def combine(self, positive_A, negative_A, positive_B, negative_B):
|
| 170 |
+
final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
|
| 171 |
+
return (final_positive, final_negative,)
|
| 172 |
+
|
| 173 |
+
class PairConditioningSetDefaultAndCombine:
|
| 174 |
+
NodeId = 'PairConditioningSetDefaultCombine'
|
| 175 |
+
NodeName = 'Cond Pair Set Default Combine'
|
| 176 |
+
@classmethod
|
| 177 |
+
def INPUT_TYPES(s):
|
| 178 |
+
return {
|
| 179 |
+
"required": {
|
| 180 |
+
"positive": ("CONDITIONING",),
|
| 181 |
+
"negative": ("CONDITIONING",),
|
| 182 |
+
"positive_DEFAULT": ("CONDITIONING",),
|
| 183 |
+
"negative_DEFAULT": ("CONDITIONING",),
|
| 184 |
+
},
|
| 185 |
+
"optional": {
|
| 186 |
+
"hooks": ("HOOKS",),
|
| 187 |
+
}
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
EXPERIMENTAL = True
|
| 191 |
+
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
| 192 |
+
RETURN_NAMES = ("positive", "negative")
|
| 193 |
+
CATEGORY = "advanced/hooks/cond pair"
|
| 194 |
+
FUNCTION = "set_default_and_combine"
|
| 195 |
+
|
| 196 |
+
def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
|
| 197 |
+
hooks: comfy.hooks.HookGroup=None):
|
| 198 |
+
final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
|
| 199 |
+
hooks=hooks)
|
| 200 |
+
return (final_positive, final_negative)
|
| 201 |
+
|
| 202 |
+
class ConditioningSetDefaultAndCombine:
|
| 203 |
+
NodeId = 'ConditioningSetDefaultCombine'
|
| 204 |
+
NodeName = 'Cond Set Default Combine'
|
| 205 |
+
@classmethod
|
| 206 |
+
def INPUT_TYPES(s):
|
| 207 |
+
return {
|
| 208 |
+
"required": {
|
| 209 |
+
"cond": ("CONDITIONING",),
|
| 210 |
+
"cond_DEFAULT": ("CONDITIONING",),
|
| 211 |
+
},
|
| 212 |
+
"optional": {
|
| 213 |
+
"hooks": ("HOOKS",),
|
| 214 |
+
}
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
EXPERIMENTAL = True
|
| 218 |
+
RETURN_TYPES = ("CONDITIONING",)
|
| 219 |
+
CATEGORY = "advanced/hooks/cond single"
|
| 220 |
+
FUNCTION = "set_default_and_combine"
|
| 221 |
+
|
| 222 |
+
def set_default_and_combine(self, cond, cond_DEFAULT,
|
| 223 |
+
hooks: comfy.hooks.HookGroup=None):
|
| 224 |
+
(final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
|
| 225 |
+
hooks=hooks)
|
| 226 |
+
return (final_conditioning,)
|
| 227 |
+
|
| 228 |
+
class SetClipHooks:
|
| 229 |
+
NodeId = 'SetClipHooks'
|
| 230 |
+
NodeName = 'Set CLIP Hooks'
|
| 231 |
+
@classmethod
|
| 232 |
+
def INPUT_TYPES(s):
|
| 233 |
+
return {
|
| 234 |
+
"required": {
|
| 235 |
+
"clip": ("CLIP",),
|
| 236 |
+
"apply_to_conds": ("BOOLEAN", {"default": True, "advanced": True}),
|
| 237 |
+
"schedule_clip": ("BOOLEAN", {"default": False, "advanced": True})
|
| 238 |
+
},
|
| 239 |
+
"optional": {
|
| 240 |
+
"hooks": ("HOOKS",)
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
EXPERIMENTAL = True
|
| 245 |
+
RETURN_TYPES = ("CLIP",)
|
| 246 |
+
CATEGORY = "advanced/hooks/clip"
|
| 247 |
+
FUNCTION = "apply_hooks"
|
| 248 |
+
|
| 249 |
+
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
|
| 250 |
+
if hooks is not None:
|
| 251 |
+
clip = clip.clone(disable_dynamic=True)
|
| 252 |
+
if apply_to_conds:
|
| 253 |
+
clip.apply_hooks_to_conds = hooks
|
| 254 |
+
clip.patcher.forced_hooks = hooks.clone()
|
| 255 |
+
clip.use_clip_schedule = schedule_clip
|
| 256 |
+
if not clip.use_clip_schedule:
|
| 257 |
+
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
|
| 258 |
+
clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
|
| 259 |
+
return (clip,)
|
| 260 |
+
|
| 261 |
+
class ConditioningTimestepsRange:
|
| 262 |
+
SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"]
|
| 263 |
+
NodeId = 'ConditioningTimestepsRange'
|
| 264 |
+
NodeName = 'Timesteps Range'
|
| 265 |
+
@classmethod
|
| 266 |
+
def INPUT_TYPES(s):
|
| 267 |
+
return {
|
| 268 |
+
"required": {
|
| 269 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 270 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
| 271 |
+
},
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
EXPERIMENTAL = True
|
| 275 |
+
RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
|
| 276 |
+
RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
|
| 277 |
+
CATEGORY = "advanced/hooks"
|
| 278 |
+
FUNCTION = "create_range"
|
| 279 |
+
|
| 280 |
+
def create_range(self, start_percent: float, end_percent: float):
|
| 281 |
+
return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0))
|
| 282 |
+
#------------------------------------------
|
| 283 |
+
###########################################
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
###########################################
|
| 287 |
+
# Create Hooks
|
| 288 |
+
#------------------------------------------
|
| 289 |
+
class CreateHookLora:
|
| 290 |
+
NodeId = 'CreateHookLora'
|
| 291 |
+
NodeName = 'Create Hook LoRA'
|
| 292 |
+
def __init__(self):
|
| 293 |
+
self.loaded_lora = None
|
| 294 |
+
|
| 295 |
+
@classmethod
|
| 296 |
+
def INPUT_TYPES(s):
|
| 297 |
+
return {
|
| 298 |
+
"required": {
|
| 299 |
+
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
| 300 |
+
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
| 301 |
+
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
| 302 |
+
},
|
| 303 |
+
"optional": {
|
| 304 |
+
"prev_hooks": ("HOOKS",)
|
| 305 |
+
}
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
EXPERIMENTAL = True
|
| 309 |
+
RETURN_TYPES = ("HOOKS",)
|
| 310 |
+
CATEGORY = "advanced/hooks/create"
|
| 311 |
+
FUNCTION = "create_hook"
|
| 312 |
+
|
| 313 |
+
def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
|
| 314 |
+
if prev_hooks is None:
|
| 315 |
+
prev_hooks = comfy.hooks.HookGroup()
|
| 316 |
+
prev_hooks.clone()
|
| 317 |
+
|
| 318 |
+
if strength_model == 0 and strength_clip == 0:
|
| 319 |
+
return (prev_hooks,)
|
| 320 |
+
|
| 321 |
+
lora_path = folder_paths.get_full_path("loras", lora_name)
|
| 322 |
+
lora = None
|
| 323 |
+
if self.loaded_lora is not None:
|
| 324 |
+
if self.loaded_lora[0] == lora_path:
|
| 325 |
+
lora = self.loaded_lora[1]
|
| 326 |
+
else:
|
| 327 |
+
temp = self.loaded_lora
|
| 328 |
+
self.loaded_lora = None
|
| 329 |
+
del temp
|
| 330 |
+
|
| 331 |
+
if lora is None:
|
| 332 |
+
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
| 333 |
+
self.loaded_lora = (lora_path, lora)
|
| 334 |
+
|
| 335 |
+
hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
|
| 336 |
+
return (prev_hooks.clone_and_combine(hooks),)
|
| 337 |
+
|
| 338 |
+
class CreateHookLoraModelOnly(CreateHookLora):
|
| 339 |
+
NodeId = 'CreateHookLoraModelOnly'
|
| 340 |
+
NodeName = 'Create Hook LoRA (MO)'
|
| 341 |
+
@classmethod
|
| 342 |
+
def INPUT_TYPES(s):
|
| 343 |
+
return {
|
| 344 |
+
"required": {
|
| 345 |
+
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
| 346 |
+
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
| 347 |
+
},
|
| 348 |
+
"optional": {
|
| 349 |
+
"prev_hooks": ("HOOKS",)
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
EXPERIMENTAL = True
|
| 354 |
+
RETURN_TYPES = ("HOOKS",)
|
| 355 |
+
CATEGORY = "advanced/hooks/create"
|
| 356 |
+
FUNCTION = "create_hook_model_only"
|
| 357 |
+
|
| 358 |
+
def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
|
| 359 |
+
return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks)
|
| 360 |
+
|
| 361 |
+
class CreateHookModelAsLora:
|
| 362 |
+
NodeId = 'CreateHookModelAsLora'
|
| 363 |
+
NodeName = 'Create Hook Model as LoRA'
|
| 364 |
+
|
| 365 |
+
def __init__(self):
|
| 366 |
+
# when not None, will be in following format:
|
| 367 |
+
# (ckpt_path: str, weights_model: dict, weights_clip: dict)
|
| 368 |
+
self.loaded_weights = None
|
| 369 |
+
|
| 370 |
+
@classmethod
|
| 371 |
+
def INPUT_TYPES(s):
|
| 372 |
+
return {
|
| 373 |
+
"required": {
|
| 374 |
+
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
| 375 |
+
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
| 376 |
+
"strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
| 377 |
+
},
|
| 378 |
+
"optional": {
|
| 379 |
+
"prev_hooks": ("HOOKS",)
|
| 380 |
+
}
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
EXPERIMENTAL = True
|
| 384 |
+
RETURN_TYPES = ("HOOKS",)
|
| 385 |
+
CATEGORY = "advanced/hooks/create"
|
| 386 |
+
FUNCTION = "create_hook"
|
| 387 |
+
|
| 388 |
+
def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
|
| 389 |
+
prev_hooks: comfy.hooks.HookGroup=None):
|
| 390 |
+
if prev_hooks is None:
|
| 391 |
+
prev_hooks = comfy.hooks.HookGroup()
|
| 392 |
+
prev_hooks.clone()
|
| 393 |
+
|
| 394 |
+
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
| 395 |
+
weights_model = None
|
| 396 |
+
weights_clip = None
|
| 397 |
+
if self.loaded_weights is not None:
|
| 398 |
+
if self.loaded_weights[0] == ckpt_path:
|
| 399 |
+
weights_model = self.loaded_weights[1]
|
| 400 |
+
weights_clip = self.loaded_weights[2]
|
| 401 |
+
else:
|
| 402 |
+
temp = self.loaded_weights
|
| 403 |
+
self.loaded_weights = None
|
| 404 |
+
del temp
|
| 405 |
+
|
| 406 |
+
if weights_model is None:
|
| 407 |
+
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
| 408 |
+
weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
|
| 409 |
+
weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
|
| 410 |
+
self.loaded_weights = (ckpt_path, weights_model, weights_clip)
|
| 411 |
+
|
| 412 |
+
hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
|
| 413 |
+
strength_model=strength_model, strength_clip=strength_clip)
|
| 414 |
+
return (prev_hooks.clone_and_combine(hooks),)
|
| 415 |
+
|
| 416 |
+
class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
|
| 417 |
+
NodeId = 'CreateHookModelAsLoraModelOnly'
|
| 418 |
+
NodeName = 'Create Hook Model as LoRA (MO)'
|
| 419 |
+
@classmethod
|
| 420 |
+
def INPUT_TYPES(s):
|
| 421 |
+
return {
|
| 422 |
+
"required": {
|
| 423 |
+
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
| 424 |
+
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
| 425 |
+
},
|
| 426 |
+
"optional": {
|
| 427 |
+
"prev_hooks": ("HOOKS",)
|
| 428 |
+
}
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
EXPERIMENTAL = True
|
| 432 |
+
RETURN_TYPES = ("HOOKS",)
|
| 433 |
+
CATEGORY = "advanced/hooks/create"
|
| 434 |
+
FUNCTION = "create_hook_model_only"
|
| 435 |
+
|
| 436 |
+
def create_hook_model_only(self, ckpt_name: str, strength_model: float,
|
| 437 |
+
prev_hooks: comfy.hooks.HookGroup=None):
|
| 438 |
+
return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks)
|
| 439 |
+
#------------------------------------------
|
| 440 |
+
###########################################
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
###########################################
|
| 444 |
+
# Schedule Hooks
|
| 445 |
+
#------------------------------------------
|
| 446 |
+
class SetHookKeyframes:
|
| 447 |
+
NodeId = 'SetHookKeyframes'
|
| 448 |
+
NodeName = 'Set Hook Keyframes'
|
| 449 |
+
@classmethod
|
| 450 |
+
def INPUT_TYPES(s):
|
| 451 |
+
return {
|
| 452 |
+
"required": {
|
| 453 |
+
"hooks": ("HOOKS",),
|
| 454 |
+
},
|
| 455 |
+
"optional": {
|
| 456 |
+
"hook_kf": ("HOOK_KEYFRAMES",),
|
| 457 |
+
}
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
EXPERIMENTAL = True
|
| 461 |
+
RETURN_TYPES = ("HOOKS",)
|
| 462 |
+
CATEGORY = "advanced/hooks/scheduling"
|
| 463 |
+
FUNCTION = "set_hook_keyframes"
|
| 464 |
+
|
| 465 |
+
def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
|
| 466 |
+
if hook_kf is not None:
|
| 467 |
+
hooks = hooks.clone()
|
| 468 |
+
hooks.set_keyframes_on_hooks(hook_kf=hook_kf)
|
| 469 |
+
return (hooks,)
|
| 470 |
+
|
| 471 |
+
class CreateHookKeyframe:
|
| 472 |
+
SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"]
|
| 473 |
+
NodeId = 'CreateHookKeyframe'
|
| 474 |
+
NodeName = 'Create Hook Keyframe'
|
| 475 |
+
@classmethod
|
| 476 |
+
def INPUT_TYPES(s):
|
| 477 |
+
return {
|
| 478 |
+
"required": {
|
| 479 |
+
"strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
|
| 480 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 481 |
+
},
|
| 482 |
+
"optional": {
|
| 483 |
+
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
| 484 |
+
}
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
EXPERIMENTAL = True
|
| 488 |
+
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
| 489 |
+
RETURN_NAMES = ("HOOK_KF",)
|
| 490 |
+
CATEGORY = "advanced/hooks/scheduling"
|
| 491 |
+
FUNCTION = "create_hook_keyframe"
|
| 492 |
+
|
| 493 |
+
def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
|
| 494 |
+
if prev_hook_kf is None:
|
| 495 |
+
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
|
| 496 |
+
prev_hook_kf = prev_hook_kf.clone()
|
| 497 |
+
keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent)
|
| 498 |
+
prev_hook_kf.add(keyframe)
|
| 499 |
+
return (prev_hook_kf,)
|
| 500 |
+
|
| 501 |
+
class CreateHookKeyframesInterpolated:
|
| 502 |
+
SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"]
|
| 503 |
+
NodeId = 'CreateHookKeyframesInterpolated'
|
| 504 |
+
NodeName = 'Create Hook Keyframes Interp.'
|
| 505 |
+
@classmethod
|
| 506 |
+
def INPUT_TYPES(s):
|
| 507 |
+
return {
|
| 508 |
+
"required": {
|
| 509 |
+
"strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 510 |
+
"strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 511 |
+
"interpolation": (comfy.hooks.InterpolationMethod._LIST, ),
|
| 512 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 513 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 514 |
+
"keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}),
|
| 515 |
+
"print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}),
|
| 516 |
+
},
|
| 517 |
+
"optional": {
|
| 518 |
+
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
| 519 |
+
},
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
EXPERIMENTAL = True
|
| 523 |
+
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
| 524 |
+
RETURN_NAMES = ("HOOK_KF",)
|
| 525 |
+
CATEGORY = "advanced/hooks/scheduling"
|
| 526 |
+
FUNCTION = "create_hook_keyframes"
|
| 527 |
+
|
| 528 |
+
def create_hook_keyframes(self, strength_start: float, strength_end: float, interpolation: str,
|
| 529 |
+
start_percent: float, end_percent: float, keyframes_count: int,
|
| 530 |
+
print_keyframes=False, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
|
| 531 |
+
if prev_hook_kf is None:
|
| 532 |
+
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
|
| 533 |
+
prev_hook_kf = prev_hook_kf.clone()
|
| 534 |
+
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=keyframes_count,
|
| 535 |
+
method=comfy.hooks.InterpolationMethod.LINEAR)
|
| 536 |
+
strengths = comfy.hooks.InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=keyframes_count, method=interpolation)
|
| 537 |
+
|
| 538 |
+
is_first = True
|
| 539 |
+
for percent, strength in zip(percents, strengths):
|
| 540 |
+
guarantee_steps = 0
|
| 541 |
+
if is_first:
|
| 542 |
+
guarantee_steps = 1
|
| 543 |
+
is_first = False
|
| 544 |
+
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
|
| 545 |
+
if print_keyframes:
|
| 546 |
+
logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
|
| 547 |
+
return (prev_hook_kf,)
|
| 548 |
+
|
| 549 |
+
class CreateHookKeyframesFromFloats:
|
| 550 |
+
SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"]
|
| 551 |
+
NodeId = 'CreateHookKeyframesFromFloats'
|
| 552 |
+
NodeName = 'Create Hook Keyframes From Floats'
|
| 553 |
+
@classmethod
|
| 554 |
+
def INPUT_TYPES(s):
|
| 555 |
+
return {
|
| 556 |
+
"required": {
|
| 557 |
+
"floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
|
| 558 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 559 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 560 |
+
"print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}),
|
| 561 |
+
},
|
| 562 |
+
"optional": {
|
| 563 |
+
"prev_hook_kf": ("HOOK_KEYFRAMES",),
|
| 564 |
+
}
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
EXPERIMENTAL = True
|
| 568 |
+
RETURN_TYPES = ("HOOK_KEYFRAMES",)
|
| 569 |
+
RETURN_NAMES = ("HOOK_KF",)
|
| 570 |
+
CATEGORY = "advanced/hooks/scheduling"
|
| 571 |
+
FUNCTION = "create_hook_keyframes"
|
| 572 |
+
|
| 573 |
+
def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
|
| 574 |
+
start_percent: float, end_percent: float,
|
| 575 |
+
prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
|
| 576 |
+
if prev_hook_kf is None:
|
| 577 |
+
prev_hook_kf = comfy.hooks.HookKeyframeGroup()
|
| 578 |
+
prev_hook_kf = prev_hook_kf.clone()
|
| 579 |
+
if type(floats_strength) in (float, int):
|
| 580 |
+
floats_strength = [float(floats_strength)]
|
| 581 |
+
elif isinstance(floats_strength, Iterable):
|
| 582 |
+
pass
|
| 583 |
+
else:
|
| 584 |
+
raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
|
| 585 |
+
percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
|
| 586 |
+
method=comfy.hooks.InterpolationMethod.LINEAR)
|
| 587 |
+
|
| 588 |
+
is_first = True
|
| 589 |
+
for percent, strength in zip(percents, floats_strength):
|
| 590 |
+
guarantee_steps = 0
|
| 591 |
+
if is_first:
|
| 592 |
+
guarantee_steps = 1
|
| 593 |
+
is_first = False
|
| 594 |
+
prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
|
| 595 |
+
if print_keyframes:
|
| 596 |
+
logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
|
| 597 |
+
return (prev_hook_kf,)
|
| 598 |
+
#------------------------------------------
|
| 599 |
+
###########################################
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
class SetModelHooksOnCond:
|
| 603 |
+
@classmethod
|
| 604 |
+
def INPUT_TYPES(s):
|
| 605 |
+
return {
|
| 606 |
+
"required": {
|
| 607 |
+
"conditioning": ("CONDITIONING",),
|
| 608 |
+
"hooks": ("HOOKS",),
|
| 609 |
+
},
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
EXPERIMENTAL = True
|
| 613 |
+
RETURN_TYPES = ("CONDITIONING",)
|
| 614 |
+
CATEGORY = "advanced/hooks/manual"
|
| 615 |
+
FUNCTION = "attach_hook"
|
| 616 |
+
|
| 617 |
+
def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
|
| 618 |
+
return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
###########################################
|
| 622 |
+
# Combine Hooks
|
| 623 |
+
#------------------------------------------
|
| 624 |
+
class CombineHooks:
|
| 625 |
+
SEARCH_ALIASES = ["merge hooks"]
|
| 626 |
+
NodeId = 'CombineHooks2'
|
| 627 |
+
NodeName = 'Combine Hooks [2]'
|
| 628 |
+
@classmethod
|
| 629 |
+
def INPUT_TYPES(s):
|
| 630 |
+
return {
|
| 631 |
+
"required": {
|
| 632 |
+
},
|
| 633 |
+
"optional": {
|
| 634 |
+
"hooks_A": ("HOOKS",),
|
| 635 |
+
"hooks_B": ("HOOKS",),
|
| 636 |
+
}
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
EXPERIMENTAL = True
|
| 640 |
+
RETURN_TYPES = ("HOOKS",)
|
| 641 |
+
CATEGORY = "advanced/hooks/combine"
|
| 642 |
+
FUNCTION = "combine_hooks"
|
| 643 |
+
|
| 644 |
+
def combine_hooks(self,
|
| 645 |
+
hooks_A: comfy.hooks.HookGroup=None,
|
| 646 |
+
hooks_B: comfy.hooks.HookGroup=None):
|
| 647 |
+
candidates = [hooks_A, hooks_B]
|
| 648 |
+
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
|
| 649 |
+
|
| 650 |
+
class CombineHooksFour:
|
| 651 |
+
NodeId = 'CombineHooks4'
|
| 652 |
+
NodeName = 'Combine Hooks [4]'
|
| 653 |
+
@classmethod
|
| 654 |
+
def INPUT_TYPES(s):
|
| 655 |
+
return {
|
| 656 |
+
"required": {
|
| 657 |
+
},
|
| 658 |
+
"optional": {
|
| 659 |
+
"hooks_A": ("HOOKS",),
|
| 660 |
+
"hooks_B": ("HOOKS",),
|
| 661 |
+
"hooks_C": ("HOOKS",),
|
| 662 |
+
"hooks_D": ("HOOKS",),
|
| 663 |
+
}
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
+
EXPERIMENTAL = True
|
| 667 |
+
RETURN_TYPES = ("HOOKS",)
|
| 668 |
+
CATEGORY = "advanced/hooks/combine"
|
| 669 |
+
FUNCTION = "combine_hooks"
|
| 670 |
+
|
| 671 |
+
def combine_hooks(self,
|
| 672 |
+
hooks_A: comfy.hooks.HookGroup=None,
|
| 673 |
+
hooks_B: comfy.hooks.HookGroup=None,
|
| 674 |
+
hooks_C: comfy.hooks.HookGroup=None,
|
| 675 |
+
hooks_D: comfy.hooks.HookGroup=None):
|
| 676 |
+
candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
|
| 677 |
+
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
|
| 678 |
+
|
| 679 |
+
class CombineHooksEight:
|
| 680 |
+
NodeId = 'CombineHooks8'
|
| 681 |
+
NodeName = 'Combine Hooks [8]'
|
| 682 |
+
@classmethod
|
| 683 |
+
def INPUT_TYPES(s):
|
| 684 |
+
return {
|
| 685 |
+
"required": {
|
| 686 |
+
},
|
| 687 |
+
"optional": {
|
| 688 |
+
"hooks_A": ("HOOKS",),
|
| 689 |
+
"hooks_B": ("HOOKS",),
|
| 690 |
+
"hooks_C": ("HOOKS",),
|
| 691 |
+
"hooks_D": ("HOOKS",),
|
| 692 |
+
"hooks_E": ("HOOKS",),
|
| 693 |
+
"hooks_F": ("HOOKS",),
|
| 694 |
+
"hooks_G": ("HOOKS",),
|
| 695 |
+
"hooks_H": ("HOOKS",),
|
| 696 |
+
}
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
EXPERIMENTAL = True
|
| 700 |
+
RETURN_TYPES = ("HOOKS",)
|
| 701 |
+
CATEGORY = "advanced/hooks/combine"
|
| 702 |
+
FUNCTION = "combine_hooks"
|
| 703 |
+
|
| 704 |
+
def combine_hooks(self,
|
| 705 |
+
hooks_A: comfy.hooks.HookGroup=None,
|
| 706 |
+
hooks_B: comfy.hooks.HookGroup=None,
|
| 707 |
+
hooks_C: comfy.hooks.HookGroup=None,
|
| 708 |
+
hooks_D: comfy.hooks.HookGroup=None,
|
| 709 |
+
hooks_E: comfy.hooks.HookGroup=None,
|
| 710 |
+
hooks_F: comfy.hooks.HookGroup=None,
|
| 711 |
+
hooks_G: comfy.hooks.HookGroup=None,
|
| 712 |
+
hooks_H: comfy.hooks.HookGroup=None):
|
| 713 |
+
candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
|
| 714 |
+
return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
|
| 715 |
+
#------------------------------------------
|
| 716 |
+
###########################################
|
| 717 |
+
|
| 718 |
+
node_list = [
|
| 719 |
+
# Create
|
| 720 |
+
CreateHookLora,
|
| 721 |
+
CreateHookLoraModelOnly,
|
| 722 |
+
CreateHookModelAsLora,
|
| 723 |
+
CreateHookModelAsLoraModelOnly,
|
| 724 |
+
# Scheduling
|
| 725 |
+
SetHookKeyframes,
|
| 726 |
+
CreateHookKeyframe,
|
| 727 |
+
CreateHookKeyframesInterpolated,
|
| 728 |
+
CreateHookKeyframesFromFloats,
|
| 729 |
+
# Combine
|
| 730 |
+
CombineHooks,
|
| 731 |
+
CombineHooksFour,
|
| 732 |
+
CombineHooksEight,
|
| 733 |
+
# Attach
|
| 734 |
+
ConditioningSetProperties,
|
| 735 |
+
ConditioningSetPropertiesAndCombine,
|
| 736 |
+
PairConditioningSetProperties,
|
| 737 |
+
PairConditioningSetPropertiesAndCombine,
|
| 738 |
+
ConditioningSetDefaultAndCombine,
|
| 739 |
+
PairConditioningSetDefaultAndCombine,
|
| 740 |
+
PairConditioningCombine,
|
| 741 |
+
SetClipHooks,
|
| 742 |
+
# Other
|
| 743 |
+
ConditioningTimestepsRange,
|
| 744 |
+
]
|
| 745 |
+
NODE_CLASS_MAPPINGS = {}
|
| 746 |
+
NODE_DISPLAY_NAME_MAPPINGS = {}
|
| 747 |
+
|
| 748 |
+
for node in node_list:
|
| 749 |
+
NODE_CLASS_MAPPINGS[node.NodeId] = node
|
| 750 |
+
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
|
ComfyUI/comfy_extras/nodes_hunyuan.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nodes
|
| 2 |
+
import node_helpers
|
| 3 |
+
import torch
|
| 4 |
+
import comfy.model_management
|
| 5 |
+
from typing_extensions import override
|
| 6 |
+
from comfy_api.latest import ComfyExtension, io
|
| 7 |
+
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
| 8 |
+
from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
|
| 9 |
+
import folder_paths
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
| 13 |
+
@classmethod
|
| 14 |
+
def define_schema(cls):
|
| 15 |
+
return io.Schema(
|
| 16 |
+
node_id="CLIPTextEncodeHunyuanDiT",
|
| 17 |
+
category="advanced/conditioning",
|
| 18 |
+
inputs=[
|
| 19 |
+
io.Clip.Input("clip"),
|
| 20 |
+
io.String.Input("bert", multiline=True, dynamic_prompts=True),
|
| 21 |
+
io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
|
| 22 |
+
],
|
| 23 |
+
outputs=[
|
| 24 |
+
io.Conditioning.Output(),
|
| 25 |
+
],
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
@classmethod
|
| 29 |
+
def execute(cls, clip, bert, mt5xl) -> io.NodeOutput:
|
| 30 |
+
tokens = clip.tokenize(bert)
|
| 31 |
+
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
|
| 32 |
+
|
| 33 |
+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
| 34 |
+
|
| 35 |
+
encode = execute # TODO: remove
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class EmptyHunyuanLatentVideo(io.ComfyNode):
|
| 39 |
+
@classmethod
|
| 40 |
+
def define_schema(cls):
|
| 41 |
+
return io.Schema(
|
| 42 |
+
node_id="EmptyHunyuanLatentVideo",
|
| 43 |
+
display_name="Empty HunyuanVideo 1.0 Latent",
|
| 44 |
+
category="latent/video",
|
| 45 |
+
inputs=[
|
| 46 |
+
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 47 |
+
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 48 |
+
io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
| 49 |
+
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
| 50 |
+
],
|
| 51 |
+
outputs=[
|
| 52 |
+
io.Latent.Output(),
|
| 53 |
+
],
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
@classmethod
|
| 57 |
+
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
| 58 |
+
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
| 59 |
+
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
|
| 60 |
+
|
| 61 |
+
generate = execute # TODO: remove
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
|
| 65 |
+
@classmethod
|
| 66 |
+
def define_schema(cls):
|
| 67 |
+
schema = super().define_schema()
|
| 68 |
+
schema.node_id = "EmptyHunyuanVideo15Latent"
|
| 69 |
+
schema.display_name = "Empty HunyuanVideo 1.5 Latent"
|
| 70 |
+
return schema
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
| 74 |
+
# Using scale factor of 16 instead of 8
|
| 75 |
+
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
| 76 |
+
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16})
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class HunyuanVideo15ImageToVideo(io.ComfyNode):
|
| 80 |
+
@classmethod
|
| 81 |
+
def define_schema(cls):
|
| 82 |
+
return io.Schema(
|
| 83 |
+
node_id="HunyuanVideo15ImageToVideo",
|
| 84 |
+
category="conditioning/video_models",
|
| 85 |
+
inputs=[
|
| 86 |
+
io.Conditioning.Input("positive"),
|
| 87 |
+
io.Conditioning.Input("negative"),
|
| 88 |
+
io.Vae.Input("vae"),
|
| 89 |
+
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 90 |
+
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 91 |
+
io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
| 92 |
+
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
| 93 |
+
io.Image.Input("start_image", optional=True),
|
| 94 |
+
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
| 95 |
+
],
|
| 96 |
+
outputs=[
|
| 97 |
+
io.Conditioning.Output(display_name="positive"),
|
| 98 |
+
io.Conditioning.Output(display_name="negative"),
|
| 99 |
+
io.Latent.Output(display_name="latent"),
|
| 100 |
+
],
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
@classmethod
|
| 104 |
+
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
| 105 |
+
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
| 106 |
+
|
| 107 |
+
if start_image is not None:
|
| 108 |
+
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
| 109 |
+
|
| 110 |
+
encoded = vae.encode(start_image[:, :, :, :3])
|
| 111 |
+
concat_latent_image = torch.zeros((latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device())
|
| 112 |
+
concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded
|
| 113 |
+
|
| 114 |
+
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
| 115 |
+
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
| 116 |
+
|
| 117 |
+
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
| 118 |
+
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
| 119 |
+
|
| 120 |
+
if clip_vision_output is not None:
|
| 121 |
+
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
| 122 |
+
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
| 123 |
+
|
| 124 |
+
out_latent = {}
|
| 125 |
+
out_latent["samples"] = latent
|
| 126 |
+
return io.NodeOutput(positive, negative, out_latent)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class HunyuanVideo15SuperResolution(io.ComfyNode):
|
| 130 |
+
@classmethod
|
| 131 |
+
def define_schema(cls):
|
| 132 |
+
return io.Schema(
|
| 133 |
+
node_id="HunyuanVideo15SuperResolution",
|
| 134 |
+
inputs=[
|
| 135 |
+
io.Conditioning.Input("positive"),
|
| 136 |
+
io.Conditioning.Input("negative"),
|
| 137 |
+
io.Vae.Input("vae", optional=True),
|
| 138 |
+
io.Image.Input("start_image", optional=True),
|
| 139 |
+
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
| 140 |
+
io.Latent.Input("latent"),
|
| 141 |
+
io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01, advanced=True),
|
| 142 |
+
|
| 143 |
+
],
|
| 144 |
+
outputs=[
|
| 145 |
+
io.Conditioning.Output(display_name="positive"),
|
| 146 |
+
io.Conditioning.Output(display_name="negative"),
|
| 147 |
+
io.Latent.Output(display_name="latent"),
|
| 148 |
+
],
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
@classmethod
|
| 152 |
+
def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
| 153 |
+
in_latent = latent["samples"]
|
| 154 |
+
in_channels = in_latent.shape[1]
|
| 155 |
+
cond_latent = torch.zeros([in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
| 156 |
+
cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent
|
| 157 |
+
cond_latent[:, 2 * in_channels + 1] = 1
|
| 158 |
+
if start_image is not None:
|
| 159 |
+
start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1)
|
| 160 |
+
encoded = vae.encode(start_image[:, :, :, :3])
|
| 161 |
+
cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded
|
| 162 |
+
cond_latent[:, in_channels + 1, 0] = 1
|
| 163 |
+
|
| 164 |
+
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation})
|
| 165 |
+
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation})
|
| 166 |
+
if clip_vision_output is not None:
|
| 167 |
+
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
| 168 |
+
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
| 169 |
+
|
| 170 |
+
return io.NodeOutput(positive, negative, latent)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class LatentUpscaleModelLoader(io.ComfyNode):
|
| 174 |
+
@classmethod
|
| 175 |
+
def define_schema(cls):
|
| 176 |
+
return io.Schema(
|
| 177 |
+
node_id="LatentUpscaleModelLoader",
|
| 178 |
+
display_name="Load Latent Upscale Model",
|
| 179 |
+
category="loaders",
|
| 180 |
+
inputs=[
|
| 181 |
+
io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")),
|
| 182 |
+
],
|
| 183 |
+
outputs=[
|
| 184 |
+
io.LatentUpscaleModel.Output(),
|
| 185 |
+
],
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
@classmethod
|
| 189 |
+
def execute(cls, model_name) -> io.NodeOutput:
|
| 190 |
+
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
|
| 191 |
+
sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)
|
| 192 |
+
|
| 193 |
+
if "blocks.0.block.0.conv.weight" in sd:
|
| 194 |
+
config = {
|
| 195 |
+
"in_channels": sd["in_conv.conv.weight"].shape[1],
|
| 196 |
+
"out_channels": sd["out_conv.conv.weight"].shape[0],
|
| 197 |
+
"hidden_channels": sd["in_conv.conv.weight"].shape[0],
|
| 198 |
+
"num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]),
|
| 199 |
+
"global_residual": False,
|
| 200 |
+
}
|
| 201 |
+
model_type = "720p"
|
| 202 |
+
model = HunyuanVideo15SRModel(model_type, config)
|
| 203 |
+
model.load_sd(sd)
|
| 204 |
+
elif "up.0.block.0.conv1.conv.weight" in sd:
|
| 205 |
+
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
|
| 206 |
+
config = {
|
| 207 |
+
"z_channels": sd["conv_in.conv.weight"].shape[1],
|
| 208 |
+
"out_channels": sd["conv_out.conv.weight"].shape[0],
|
| 209 |
+
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
|
| 210 |
+
}
|
| 211 |
+
model_type = "1080p"
|
| 212 |
+
model = HunyuanVideo15SRModel(model_type, config)
|
| 213 |
+
model.load_sd(sd)
|
| 214 |
+
elif "post_upsample_res_blocks.0.conv2.bias" in sd:
|
| 215 |
+
config = json.loads(metadata["config"])
|
| 216 |
+
model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
|
| 217 |
+
model.load_state_dict(sd)
|
| 218 |
+
|
| 219 |
+
return io.NodeOutput(model)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode):
|
| 223 |
+
@classmethod
|
| 224 |
+
def define_schema(cls):
|
| 225 |
+
return io.Schema(
|
| 226 |
+
node_id="HunyuanVideo15LatentUpscaleWithModel",
|
| 227 |
+
display_name="Hunyuan Video 15 Latent Upscale With Model",
|
| 228 |
+
category="latent",
|
| 229 |
+
inputs=[
|
| 230 |
+
io.LatentUpscaleModel.Input("model"),
|
| 231 |
+
io.Latent.Input("samples"),
|
| 232 |
+
io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"),
|
| 233 |
+
io.Int.Input("width", default=1280, min=0, max=16384, step=8),
|
| 234 |
+
io.Int.Input("height", default=720, min=0, max=16384, step=8),
|
| 235 |
+
io.Combo.Input("crop", options=["disabled", "center"]),
|
| 236 |
+
],
|
| 237 |
+
outputs=[
|
| 238 |
+
io.Latent.Output(),
|
| 239 |
+
],
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
@classmethod
|
| 243 |
+
def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput:
|
| 244 |
+
if width == 0 and height == 0:
|
| 245 |
+
return io.NodeOutput(samples)
|
| 246 |
+
else:
|
| 247 |
+
if width == 0:
|
| 248 |
+
height = max(64, height)
|
| 249 |
+
width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
|
| 250 |
+
elif height == 0:
|
| 251 |
+
width = max(64, width)
|
| 252 |
+
height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
|
| 253 |
+
else:
|
| 254 |
+
width = max(64, width)
|
| 255 |
+
height = max(64, height)
|
| 256 |
+
s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop)
|
| 257 |
+
s = model.resample_latent(s)
|
| 258 |
+
return io.NodeOutput({"samples": s.cpu().float()})
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
| 262 |
+
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
| 263 |
+
"1. The main content and theme of the video."
|
| 264 |
+
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
| 265 |
+
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
| 266 |
+
"4. background environment, light, style and atmosphere."
|
| 267 |
+
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
|
| 268 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
| 269 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
|
| 273 |
+
@classmethod
|
| 274 |
+
def define_schema(cls):
|
| 275 |
+
return io.Schema(
|
| 276 |
+
node_id="TextEncodeHunyuanVideo_ImageToVideo",
|
| 277 |
+
category="advanced/conditioning",
|
| 278 |
+
inputs=[
|
| 279 |
+
io.Clip.Input("clip"),
|
| 280 |
+
io.ClipVisionOutput.Input("clip_vision_output"),
|
| 281 |
+
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
|
| 282 |
+
io.Int.Input(
|
| 283 |
+
"image_interleave",
|
| 284 |
+
default=2,
|
| 285 |
+
min=1,
|
| 286 |
+
max=512,
|
| 287 |
+
tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
|
| 288 |
+
advanced=True,
|
| 289 |
+
),
|
| 290 |
+
],
|
| 291 |
+
outputs=[
|
| 292 |
+
io.Conditioning.Output(),
|
| 293 |
+
],
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
@classmethod
|
| 297 |
+
def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput:
|
| 298 |
+
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
|
| 299 |
+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
| 300 |
+
|
| 301 |
+
encode = execute # TODO: remove
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class HunyuanImageToVideo(io.ComfyNode):
|
| 305 |
+
@classmethod
|
| 306 |
+
def define_schema(cls):
|
| 307 |
+
return io.Schema(
|
| 308 |
+
node_id="HunyuanImageToVideo",
|
| 309 |
+
category="conditioning/video_models",
|
| 310 |
+
inputs=[
|
| 311 |
+
io.Conditioning.Input("positive"),
|
| 312 |
+
io.Vae.Input("vae"),
|
| 313 |
+
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 314 |
+
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 315 |
+
io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
| 316 |
+
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
| 317 |
+
io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"], advanced=True),
|
| 318 |
+
io.Image.Input("start_image", optional=True),
|
| 319 |
+
],
|
| 320 |
+
outputs=[
|
| 321 |
+
io.Conditioning.Output(display_name="positive"),
|
| 322 |
+
io.Latent.Output(display_name="latent"),
|
| 323 |
+
],
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
@classmethod
|
| 327 |
+
def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput:
|
| 328 |
+
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
| 329 |
+
out_latent = {}
|
| 330 |
+
|
| 331 |
+
if start_image is not None:
|
| 332 |
+
start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
| 333 |
+
|
| 334 |
+
concat_latent_image = vae.encode(start_image)
|
| 335 |
+
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
| 336 |
+
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
| 337 |
+
|
| 338 |
+
if guidance_type == "v1 (concat)":
|
| 339 |
+
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
|
| 340 |
+
elif guidance_type == "v2 (replace)":
|
| 341 |
+
cond = {'guiding_frame_index': 0}
|
| 342 |
+
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
|
| 343 |
+
out_latent["noise_mask"] = mask
|
| 344 |
+
elif guidance_type == "custom":
|
| 345 |
+
cond = {"ref_latent": concat_latent_image}
|
| 346 |
+
|
| 347 |
+
positive = node_helpers.conditioning_set_values(positive, cond)
|
| 348 |
+
|
| 349 |
+
out_latent["samples"] = latent
|
| 350 |
+
return io.NodeOutput(positive, out_latent)
|
| 351 |
+
|
| 352 |
+
encode = execute # TODO: remove
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class EmptyHunyuanImageLatent(io.ComfyNode):
|
| 356 |
+
@classmethod
|
| 357 |
+
def define_schema(cls):
|
| 358 |
+
return io.Schema(
|
| 359 |
+
node_id="EmptyHunyuanImageLatent",
|
| 360 |
+
category="latent",
|
| 361 |
+
inputs=[
|
| 362 |
+
io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
| 363 |
+
io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
| 364 |
+
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
| 365 |
+
],
|
| 366 |
+
outputs=[
|
| 367 |
+
io.Latent.Output(),
|
| 368 |
+
],
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
@classmethod
|
| 372 |
+
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
|
| 373 |
+
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
| 374 |
+
return io.NodeOutput({"samples":latent})
|
| 375 |
+
|
| 376 |
+
generate = execute # TODO: remove
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class HunyuanRefinerLatent(io.ComfyNode):
|
| 380 |
+
@classmethod
|
| 381 |
+
def define_schema(cls):
|
| 382 |
+
return io.Schema(
|
| 383 |
+
node_id="HunyuanRefinerLatent",
|
| 384 |
+
inputs=[
|
| 385 |
+
io.Conditioning.Input("positive"),
|
| 386 |
+
io.Conditioning.Input("negative"),
|
| 387 |
+
io.Latent.Input("latent"),
|
| 388 |
+
io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01, advanced=True),
|
| 389 |
+
|
| 390 |
+
],
|
| 391 |
+
outputs=[
|
| 392 |
+
io.Conditioning.Output(display_name="positive"),
|
| 393 |
+
io.Conditioning.Output(display_name="negative"),
|
| 394 |
+
io.Latent.Output(display_name="latent"),
|
| 395 |
+
],
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
@classmethod
|
| 399 |
+
def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput:
|
| 400 |
+
latent = latent["samples"]
|
| 401 |
+
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
|
| 402 |
+
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
|
| 403 |
+
out_latent = {}
|
| 404 |
+
out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
| 405 |
+
return io.NodeOutput(positive, negative, out_latent)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class HunyuanExtension(ComfyExtension):
|
| 409 |
+
@override
|
| 410 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 411 |
+
return [
|
| 412 |
+
CLIPTextEncodeHunyuanDiT,
|
| 413 |
+
TextEncodeHunyuanVideo_ImageToVideo,
|
| 414 |
+
EmptyHunyuanLatentVideo,
|
| 415 |
+
EmptyHunyuanVideo15Latent,
|
| 416 |
+
HunyuanVideo15ImageToVideo,
|
| 417 |
+
HunyuanVideo15SuperResolution,
|
| 418 |
+
HunyuanVideo15LatentUpscaleWithModel,
|
| 419 |
+
LatentUpscaleModelLoader,
|
| 420 |
+
HunyuanImageToVideo,
|
| 421 |
+
EmptyHunyuanImageLatent,
|
| 422 |
+
HunyuanRefinerLatent,
|
| 423 |
+
]
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
async def comfy_entrypoint() -> HunyuanExtension:
|
| 427 |
+
return HunyuanExtension()
|
ComfyUI/comfy_extras/nodes_hunyuan3d.py
ADDED
|
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import struct
|
| 5 |
+
import numpy as np
|
| 6 |
+
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
|
| 7 |
+
import folder_paths
|
| 8 |
+
import comfy.model_management
|
| 9 |
+
from comfy.cli_args import args
|
| 10 |
+
from typing_extensions import override
|
| 11 |
+
from comfy_api.latest import ComfyExtension, IO, Types
|
| 12 |
+
from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
|
| 16 |
+
@classmethod
|
| 17 |
+
def define_schema(cls):
|
| 18 |
+
return IO.Schema(
|
| 19 |
+
node_id="EmptyLatentHunyuan3Dv2",
|
| 20 |
+
category="latent/3d",
|
| 21 |
+
inputs=[
|
| 22 |
+
IO.Int.Input("resolution", default=3072, min=1, max=8192),
|
| 23 |
+
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
| 24 |
+
],
|
| 25 |
+
outputs=[
|
| 26 |
+
IO.Latent.Output(),
|
| 27 |
+
]
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
@classmethod
|
| 31 |
+
def execute(cls, resolution, batch_size) -> IO.NodeOutput:
|
| 32 |
+
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
|
| 33 |
+
return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"})
|
| 34 |
+
|
| 35 |
+
generate = execute # TODO: remove
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Hunyuan3Dv2Conditioning(IO.ComfyNode):
|
| 39 |
+
@classmethod
|
| 40 |
+
def define_schema(cls):
|
| 41 |
+
return IO.Schema(
|
| 42 |
+
node_id="Hunyuan3Dv2Conditioning",
|
| 43 |
+
category="conditioning/video_models",
|
| 44 |
+
inputs=[
|
| 45 |
+
IO.ClipVisionOutput.Input("clip_vision_output"),
|
| 46 |
+
],
|
| 47 |
+
outputs=[
|
| 48 |
+
IO.Conditioning.Output(display_name="positive"),
|
| 49 |
+
IO.Conditioning.Output(display_name="negative"),
|
| 50 |
+
]
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
@classmethod
|
| 54 |
+
def execute(cls, clip_vision_output) -> IO.NodeOutput:
|
| 55 |
+
embeds = clip_vision_output.last_hidden_state
|
| 56 |
+
positive = [[embeds, {}]]
|
| 57 |
+
negative = [[torch.zeros_like(embeds), {}]]
|
| 58 |
+
return IO.NodeOutput(positive, negative)
|
| 59 |
+
|
| 60 |
+
encode = execute # TODO: remove
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
|
| 64 |
+
@classmethod
|
| 65 |
+
def define_schema(cls):
|
| 66 |
+
return IO.Schema(
|
| 67 |
+
node_id="Hunyuan3Dv2ConditioningMultiView",
|
| 68 |
+
category="conditioning/video_models",
|
| 69 |
+
inputs=[
|
| 70 |
+
IO.ClipVisionOutput.Input("front", optional=True),
|
| 71 |
+
IO.ClipVisionOutput.Input("left", optional=True),
|
| 72 |
+
IO.ClipVisionOutput.Input("back", optional=True),
|
| 73 |
+
IO.ClipVisionOutput.Input("right", optional=True),
|
| 74 |
+
],
|
| 75 |
+
outputs=[
|
| 76 |
+
IO.Conditioning.Output(display_name="positive"),
|
| 77 |
+
IO.Conditioning.Output(display_name="negative"),
|
| 78 |
+
]
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
@classmethod
|
| 82 |
+
def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput:
|
| 83 |
+
all_embeds = [front, left, back, right]
|
| 84 |
+
out = []
|
| 85 |
+
pos_embeds = None
|
| 86 |
+
for i, e in enumerate(all_embeds):
|
| 87 |
+
if e is not None:
|
| 88 |
+
if pos_embeds is None:
|
| 89 |
+
pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
|
| 90 |
+
out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
|
| 91 |
+
|
| 92 |
+
embeds = torch.cat(out, dim=1)
|
| 93 |
+
positive = [[embeds, {}]]
|
| 94 |
+
negative = [[torch.zeros_like(embeds), {}]]
|
| 95 |
+
return IO.NodeOutput(positive, negative)
|
| 96 |
+
|
| 97 |
+
encode = execute # TODO: remove
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class VAEDecodeHunyuan3D(IO.ComfyNode):
|
| 101 |
+
@classmethod
|
| 102 |
+
def define_schema(cls):
|
| 103 |
+
return IO.Schema(
|
| 104 |
+
node_id="VAEDecodeHunyuan3D",
|
| 105 |
+
category="latent/3d",
|
| 106 |
+
inputs=[
|
| 107 |
+
IO.Latent.Input("samples"),
|
| 108 |
+
IO.Vae.Input("vae"),
|
| 109 |
+
IO.Int.Input("num_chunks", default=8000, min=1000, max=500000, advanced=True),
|
| 110 |
+
IO.Int.Input("octree_resolution", default=256, min=16, max=512, advanced=True),
|
| 111 |
+
],
|
| 112 |
+
outputs=[
|
| 113 |
+
IO.Voxel.Output(),
|
| 114 |
+
]
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
@classmethod
|
| 118 |
+
def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput:
|
| 119 |
+
voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
|
| 120 |
+
return IO.NodeOutput(voxels)
|
| 121 |
+
|
| 122 |
+
decode = execute # TODO: remove
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def voxel_to_mesh(voxels, threshold=0.5, device=None):
|
| 126 |
+
if device is None:
|
| 127 |
+
device = torch.device("cpu")
|
| 128 |
+
voxels = voxels.to(device)
|
| 129 |
+
|
| 130 |
+
binary = (voxels > threshold).float()
|
| 131 |
+
padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
|
| 132 |
+
|
| 133 |
+
D, H, W = binary.shape
|
| 134 |
+
|
| 135 |
+
neighbors = torch.tensor([
|
| 136 |
+
[0, 0, 1],
|
| 137 |
+
[0, 0, -1],
|
| 138 |
+
[0, 1, 0],
|
| 139 |
+
[0, -1, 0],
|
| 140 |
+
[1, 0, 0],
|
| 141 |
+
[-1, 0, 0]
|
| 142 |
+
], device=device)
|
| 143 |
+
|
| 144 |
+
z, y, x = torch.meshgrid(
|
| 145 |
+
torch.arange(D, device=device),
|
| 146 |
+
torch.arange(H, device=device),
|
| 147 |
+
torch.arange(W, device=device),
|
| 148 |
+
indexing='ij'
|
| 149 |
+
)
|
| 150 |
+
voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
|
| 151 |
+
|
| 152 |
+
solid_mask = binary.flatten() > 0
|
| 153 |
+
solid_indices = voxel_indices[solid_mask]
|
| 154 |
+
|
| 155 |
+
corner_offsets = [
|
| 156 |
+
torch.tensor([
|
| 157 |
+
[0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
|
| 158 |
+
], device=device),
|
| 159 |
+
torch.tensor([
|
| 160 |
+
[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
|
| 161 |
+
], device=device),
|
| 162 |
+
torch.tensor([
|
| 163 |
+
[0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
|
| 164 |
+
], device=device),
|
| 165 |
+
torch.tensor([
|
| 166 |
+
[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
|
| 167 |
+
], device=device),
|
| 168 |
+
torch.tensor([
|
| 169 |
+
[1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
|
| 170 |
+
], device=device),
|
| 171 |
+
torch.tensor([
|
| 172 |
+
[0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
|
| 173 |
+
], device=device)
|
| 174 |
+
]
|
| 175 |
+
|
| 176 |
+
all_vertices = []
|
| 177 |
+
all_indices = []
|
| 178 |
+
|
| 179 |
+
vertex_count = 0
|
| 180 |
+
|
| 181 |
+
for face_idx, offset in enumerate(neighbors):
|
| 182 |
+
neighbor_indices = solid_indices + offset
|
| 183 |
+
|
| 184 |
+
padded_indices = neighbor_indices + 1
|
| 185 |
+
|
| 186 |
+
is_exposed = padded[
|
| 187 |
+
padded_indices[:, 0],
|
| 188 |
+
padded_indices[:, 1],
|
| 189 |
+
padded_indices[:, 2]
|
| 190 |
+
] == 0
|
| 191 |
+
|
| 192 |
+
if not is_exposed.any():
|
| 193 |
+
continue
|
| 194 |
+
|
| 195 |
+
exposed_indices = solid_indices[is_exposed]
|
| 196 |
+
|
| 197 |
+
corners = corner_offsets[face_idx].unsqueeze(0)
|
| 198 |
+
|
| 199 |
+
face_vertices = exposed_indices.unsqueeze(1) + corners
|
| 200 |
+
|
| 201 |
+
all_vertices.append(face_vertices.reshape(-1, 3))
|
| 202 |
+
|
| 203 |
+
num_faces = exposed_indices.shape[0]
|
| 204 |
+
face_indices = torch.arange(
|
| 205 |
+
vertex_count,
|
| 206 |
+
vertex_count + 4 * num_faces,
|
| 207 |
+
device=device
|
| 208 |
+
).reshape(-1, 4)
|
| 209 |
+
|
| 210 |
+
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1))
|
| 211 |
+
all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1))
|
| 212 |
+
|
| 213 |
+
vertex_count += 4 * num_faces
|
| 214 |
+
|
| 215 |
+
if len(all_vertices) > 0:
|
| 216 |
+
vertices = torch.cat(all_vertices, dim=0)
|
| 217 |
+
faces = torch.cat(all_indices, dim=0)
|
| 218 |
+
else:
|
| 219 |
+
vertices = torch.zeros((1, 3))
|
| 220 |
+
faces = torch.zeros((1, 3))
|
| 221 |
+
|
| 222 |
+
v_min = 0
|
| 223 |
+
v_max = max(voxels.shape)
|
| 224 |
+
|
| 225 |
+
vertices = vertices - (v_min + v_max) / 2
|
| 226 |
+
|
| 227 |
+
scale = (v_max - v_min) / 2
|
| 228 |
+
if scale > 0:
|
| 229 |
+
vertices = vertices / scale
|
| 230 |
+
|
| 231 |
+
vertices = torch.fliplr(vertices)
|
| 232 |
+
return vertices, faces
|
| 233 |
+
|
| 234 |
+
def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
|
| 235 |
+
if device is None:
|
| 236 |
+
device = torch.device("cpu")
|
| 237 |
+
voxels = voxels.to(device)
|
| 238 |
+
|
| 239 |
+
D, H, W = voxels.shape
|
| 240 |
+
|
| 241 |
+
padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
|
| 242 |
+
z, y, x = torch.meshgrid(
|
| 243 |
+
torch.arange(D, device=device),
|
| 244 |
+
torch.arange(H, device=device),
|
| 245 |
+
torch.arange(W, device=device),
|
| 246 |
+
indexing='ij'
|
| 247 |
+
)
|
| 248 |
+
cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
|
| 249 |
+
|
| 250 |
+
corner_offsets = torch.tensor([
|
| 251 |
+
[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
|
| 252 |
+
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
|
| 253 |
+
], device=device)
|
| 254 |
+
|
| 255 |
+
pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0)
|
| 256 |
+
z_idx, y_idx, x_idx = pos.unbind(-1)
|
| 257 |
+
corner_values = padded[z_idx, y_idx, x_idx]
|
| 258 |
+
|
| 259 |
+
corner_signs = corner_values > threshold
|
| 260 |
+
has_inside = torch.any(corner_signs, dim=1)
|
| 261 |
+
has_outside = torch.any(~corner_signs, dim=1)
|
| 262 |
+
contains_surface = has_inside & has_outside
|
| 263 |
+
|
| 264 |
+
active_cells = cell_positions[contains_surface]
|
| 265 |
+
active_signs = corner_signs[contains_surface]
|
| 266 |
+
active_values = corner_values[contains_surface]
|
| 267 |
+
|
| 268 |
+
if active_cells.shape[0] == 0:
|
| 269 |
+
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
|
| 270 |
+
|
| 271 |
+
edges = torch.tensor([
|
| 272 |
+
[0, 1], [0, 2], [0, 4], [1, 3],
|
| 273 |
+
[1, 5], [2, 3], [2, 6], [3, 7],
|
| 274 |
+
[4, 5], [4, 6], [5, 7], [6, 7]
|
| 275 |
+
], device=device)
|
| 276 |
+
|
| 277 |
+
cell_vertices = {}
|
| 278 |
+
progress = comfy.utils.ProgressBar(100)
|
| 279 |
+
|
| 280 |
+
for edge_idx, (e1, e2) in enumerate(edges):
|
| 281 |
+
progress.update(1)
|
| 282 |
+
crossing = active_signs[:, e1] != active_signs[:, e2]
|
| 283 |
+
if not crossing.any():
|
| 284 |
+
continue
|
| 285 |
+
|
| 286 |
+
cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
|
| 287 |
+
|
| 288 |
+
v1 = active_values[cell_indices, e1]
|
| 289 |
+
v2 = active_values[cell_indices, e2]
|
| 290 |
+
|
| 291 |
+
t = torch.zeros_like(v1, device=device)
|
| 292 |
+
denom = v2 - v1
|
| 293 |
+
valid = denom != 0
|
| 294 |
+
t[valid] = (threshold - v1[valid]) / denom[valid]
|
| 295 |
+
t[~valid] = 0.5
|
| 296 |
+
|
| 297 |
+
p1 = corner_offsets[e1].float()
|
| 298 |
+
p2 = corner_offsets[e2].float()
|
| 299 |
+
|
| 300 |
+
intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
|
| 301 |
+
|
| 302 |
+
for i, point in zip(cell_indices.tolist(), intersection):
|
| 303 |
+
if i not in cell_vertices:
|
| 304 |
+
cell_vertices[i] = []
|
| 305 |
+
cell_vertices[i].append(point)
|
| 306 |
+
|
| 307 |
+
# Calculate the final vertices as the average of intersection points for each cell
|
| 308 |
+
vertices = []
|
| 309 |
+
vertex_lookup = {}
|
| 310 |
+
|
| 311 |
+
vert_progress_mod = round(len(cell_vertices)/50)
|
| 312 |
+
|
| 313 |
+
for i, points in cell_vertices.items():
|
| 314 |
+
if not i % vert_progress_mod:
|
| 315 |
+
progress.update(1)
|
| 316 |
+
|
| 317 |
+
if points:
|
| 318 |
+
vertex = torch.stack(points).mean(dim=0)
|
| 319 |
+
vertex = vertex + active_cells[i].float()
|
| 320 |
+
vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
|
| 321 |
+
vertices.append(vertex)
|
| 322 |
+
|
| 323 |
+
if not vertices:
|
| 324 |
+
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
|
| 325 |
+
|
| 326 |
+
final_vertices = torch.stack(vertices)
|
| 327 |
+
|
| 328 |
+
inside_corners_mask = active_signs
|
| 329 |
+
outside_corners_mask = ~active_signs
|
| 330 |
+
|
| 331 |
+
inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
|
| 332 |
+
outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
|
| 333 |
+
|
| 334 |
+
inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
|
| 335 |
+
outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
|
| 336 |
+
|
| 337 |
+
for i in range(8):
|
| 338 |
+
mask_inside = inside_corners_mask[:, i].unsqueeze(1)
|
| 339 |
+
mask_outside = outside_corners_mask[:, i].unsqueeze(1)
|
| 340 |
+
inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
|
| 341 |
+
outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
|
| 342 |
+
|
| 343 |
+
inside_pos /= inside_counts
|
| 344 |
+
outside_pos /= outside_counts
|
| 345 |
+
gradients = inside_pos - outside_pos
|
| 346 |
+
|
| 347 |
+
pos_dirs = torch.tensor([
|
| 348 |
+
[1, 0, 0],
|
| 349 |
+
[0, 1, 0],
|
| 350 |
+
[0, 0, 1]
|
| 351 |
+
], device=device)
|
| 352 |
+
|
| 353 |
+
cross_products = [
|
| 354 |
+
torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
|
| 355 |
+
for i in range(3) for j in range(i+1, 3)
|
| 356 |
+
]
|
| 357 |
+
|
| 358 |
+
faces = []
|
| 359 |
+
all_keys = set(vertex_lookup.keys())
|
| 360 |
+
|
| 361 |
+
face_progress_mod = round(len(active_cells)/38*3)
|
| 362 |
+
|
| 363 |
+
for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
|
| 364 |
+
dir_i = pos_dirs[i]
|
| 365 |
+
dir_j = pos_dirs[j]
|
| 366 |
+
cross_product = cross_products[pair_idx]
|
| 367 |
+
|
| 368 |
+
ni_positions = active_cells + dir_i
|
| 369 |
+
nj_positions = active_cells + dir_j
|
| 370 |
+
diag_positions = active_cells + dir_i + dir_j
|
| 371 |
+
|
| 372 |
+
alignments = torch.matmul(gradients, cross_product)
|
| 373 |
+
|
| 374 |
+
valid_quads = []
|
| 375 |
+
quad_indices = []
|
| 376 |
+
|
| 377 |
+
for idx, active_cell in enumerate(active_cells):
|
| 378 |
+
if not idx % face_progress_mod:
|
| 379 |
+
progress.update(1)
|
| 380 |
+
cell_key = tuple(active_cell.tolist())
|
| 381 |
+
ni_key = tuple(ni_positions[idx].tolist())
|
| 382 |
+
nj_key = tuple(nj_positions[idx].tolist())
|
| 383 |
+
diag_key = tuple(diag_positions[idx].tolist())
|
| 384 |
+
|
| 385 |
+
if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
|
| 386 |
+
v0 = vertex_lookup[cell_key]
|
| 387 |
+
v1 = vertex_lookup[ni_key]
|
| 388 |
+
v2 = vertex_lookup[nj_key]
|
| 389 |
+
v3 = vertex_lookup[diag_key]
|
| 390 |
+
|
| 391 |
+
valid_quads.append((v0, v1, v2, v3))
|
| 392 |
+
quad_indices.append(idx)
|
| 393 |
+
|
| 394 |
+
for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
|
| 395 |
+
cell_idx = quad_indices[q_idx]
|
| 396 |
+
if alignments[cell_idx] > 0:
|
| 397 |
+
faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
|
| 398 |
+
faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
|
| 399 |
+
else:
|
| 400 |
+
faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
|
| 401 |
+
faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
|
| 402 |
+
|
| 403 |
+
if faces:
|
| 404 |
+
faces = torch.stack(faces)
|
| 405 |
+
else:
|
| 406 |
+
faces = torch.zeros((0, 3), dtype=torch.long, device=device)
|
| 407 |
+
|
| 408 |
+
v_min = 0
|
| 409 |
+
v_max = max(D, H, W)
|
| 410 |
+
|
| 411 |
+
final_vertices = final_vertices - (v_min + v_max) / 2
|
| 412 |
+
|
| 413 |
+
scale = (v_max - v_min) / 2
|
| 414 |
+
if scale > 0:
|
| 415 |
+
final_vertices = final_vertices / scale
|
| 416 |
+
|
| 417 |
+
final_vertices = torch.fliplr(final_vertices)
|
| 418 |
+
|
| 419 |
+
return final_vertices, faces
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
class VoxelToMeshBasic(IO.ComfyNode):
|
| 423 |
+
@classmethod
|
| 424 |
+
def define_schema(cls):
|
| 425 |
+
return IO.Schema(
|
| 426 |
+
node_id="VoxelToMeshBasic",
|
| 427 |
+
category="3d",
|
| 428 |
+
inputs=[
|
| 429 |
+
IO.Voxel.Input("voxel"),
|
| 430 |
+
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
| 431 |
+
],
|
| 432 |
+
outputs=[
|
| 433 |
+
IO.Mesh.Output(),
|
| 434 |
+
]
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
@classmethod
|
| 438 |
+
def execute(cls, voxel, threshold) -> IO.NodeOutput:
|
| 439 |
+
vertices = []
|
| 440 |
+
faces = []
|
| 441 |
+
for x in voxel.data:
|
| 442 |
+
v, f = voxel_to_mesh(x, threshold=threshold, device=None)
|
| 443 |
+
vertices.append(v)
|
| 444 |
+
faces.append(f)
|
| 445 |
+
|
| 446 |
+
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
| 447 |
+
|
| 448 |
+
decode = execute # TODO: remove
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class VoxelToMesh(IO.ComfyNode):
|
| 452 |
+
@classmethod
|
| 453 |
+
def define_schema(cls):
|
| 454 |
+
return IO.Schema(
|
| 455 |
+
node_id="VoxelToMesh",
|
| 456 |
+
category="3d",
|
| 457 |
+
inputs=[
|
| 458 |
+
IO.Voxel.Input("voxel"),
|
| 459 |
+
IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
|
| 460 |
+
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
| 461 |
+
],
|
| 462 |
+
outputs=[
|
| 463 |
+
IO.Mesh.Output(),
|
| 464 |
+
]
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
@classmethod
|
| 468 |
+
def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput:
|
| 469 |
+
vertices = []
|
| 470 |
+
faces = []
|
| 471 |
+
|
| 472 |
+
if algorithm == "basic":
|
| 473 |
+
mesh_function = voxel_to_mesh
|
| 474 |
+
elif algorithm == "surface net":
|
| 475 |
+
mesh_function = voxel_to_mesh_surfnet
|
| 476 |
+
|
| 477 |
+
for x in voxel.data:
|
| 478 |
+
v, f = mesh_function(x, threshold=threshold, device=None)
|
| 479 |
+
vertices.append(v)
|
| 480 |
+
faces.append(f)
|
| 481 |
+
|
| 482 |
+
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
| 483 |
+
|
| 484 |
+
decode = execute # TODO: remove
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
def save_glb(vertices, faces, filepath, metadata=None):
|
| 488 |
+
"""
|
| 489 |
+
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
| 490 |
+
|
| 491 |
+
Parameters:
|
| 492 |
+
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
|
| 493 |
+
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
|
| 494 |
+
filepath: str - Output filepath (should end with .glb)
|
| 495 |
+
"""
|
| 496 |
+
|
| 497 |
+
# Convert tensors to numpy arrays
|
| 498 |
+
vertices_np = vertices.cpu().numpy().astype(np.float32)
|
| 499 |
+
faces_np = faces.cpu().numpy().astype(np.uint32)
|
| 500 |
+
|
| 501 |
+
vertices_buffer = vertices_np.tobytes()
|
| 502 |
+
indices_buffer = faces_np.tobytes()
|
| 503 |
+
|
| 504 |
+
def pad_to_4_bytes(buffer):
|
| 505 |
+
padding_length = (4 - (len(buffer) % 4)) % 4
|
| 506 |
+
return buffer + b'\x00' * padding_length
|
| 507 |
+
|
| 508 |
+
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
|
| 509 |
+
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
|
| 510 |
+
|
| 511 |
+
buffer_data = vertices_buffer_padded + indices_buffer_padded
|
| 512 |
+
|
| 513 |
+
vertices_byte_length = len(vertices_buffer)
|
| 514 |
+
vertices_byte_offset = 0
|
| 515 |
+
indices_byte_length = len(indices_buffer)
|
| 516 |
+
indices_byte_offset = len(vertices_buffer_padded)
|
| 517 |
+
|
| 518 |
+
gltf = {
|
| 519 |
+
"asset": {"version": "2.0", "generator": "ComfyUI"},
|
| 520 |
+
"buffers": [
|
| 521 |
+
{
|
| 522 |
+
"byteLength": len(buffer_data)
|
| 523 |
+
}
|
| 524 |
+
],
|
| 525 |
+
"bufferViews": [
|
| 526 |
+
{
|
| 527 |
+
"buffer": 0,
|
| 528 |
+
"byteOffset": vertices_byte_offset,
|
| 529 |
+
"byteLength": vertices_byte_length,
|
| 530 |
+
"target": 34962 # ARRAY_BUFFER
|
| 531 |
+
},
|
| 532 |
+
{
|
| 533 |
+
"buffer": 0,
|
| 534 |
+
"byteOffset": indices_byte_offset,
|
| 535 |
+
"byteLength": indices_byte_length,
|
| 536 |
+
"target": 34963 # ELEMENT_ARRAY_BUFFER
|
| 537 |
+
}
|
| 538 |
+
],
|
| 539 |
+
"accessors": [
|
| 540 |
+
{
|
| 541 |
+
"bufferView": 0,
|
| 542 |
+
"byteOffset": 0,
|
| 543 |
+
"componentType": 5126, # FLOAT
|
| 544 |
+
"count": len(vertices_np),
|
| 545 |
+
"type": "VEC3",
|
| 546 |
+
"max": vertices_np.max(axis=0).tolist(),
|
| 547 |
+
"min": vertices_np.min(axis=0).tolist()
|
| 548 |
+
},
|
| 549 |
+
{
|
| 550 |
+
"bufferView": 1,
|
| 551 |
+
"byteOffset": 0,
|
| 552 |
+
"componentType": 5125, # UNSIGNED_INT
|
| 553 |
+
"count": faces_np.size,
|
| 554 |
+
"type": "SCALAR"
|
| 555 |
+
}
|
| 556 |
+
],
|
| 557 |
+
"meshes": [
|
| 558 |
+
{
|
| 559 |
+
"primitives": [
|
| 560 |
+
{
|
| 561 |
+
"attributes": {
|
| 562 |
+
"POSITION": 0
|
| 563 |
+
},
|
| 564 |
+
"indices": 1,
|
| 565 |
+
"mode": 4 # TRIANGLES
|
| 566 |
+
}
|
| 567 |
+
]
|
| 568 |
+
}
|
| 569 |
+
],
|
| 570 |
+
"nodes": [
|
| 571 |
+
{
|
| 572 |
+
"mesh": 0
|
| 573 |
+
}
|
| 574 |
+
],
|
| 575 |
+
"scenes": [
|
| 576 |
+
{
|
| 577 |
+
"nodes": [0]
|
| 578 |
+
}
|
| 579 |
+
],
|
| 580 |
+
"scene": 0
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
if metadata is not None:
|
| 584 |
+
gltf["asset"]["extras"] = metadata
|
| 585 |
+
|
| 586 |
+
# Convert the JSON to bytes
|
| 587 |
+
gltf_json = json.dumps(gltf).encode('utf8')
|
| 588 |
+
|
| 589 |
+
def pad_json_to_4_bytes(buffer):
|
| 590 |
+
padding_length = (4 - (len(buffer) % 4)) % 4
|
| 591 |
+
return buffer + b' ' * padding_length
|
| 592 |
+
|
| 593 |
+
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
|
| 594 |
+
|
| 595 |
+
# Create the GLB header
|
| 596 |
+
# Magic glTF
|
| 597 |
+
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
|
| 598 |
+
|
| 599 |
+
# Create JSON chunk header (chunk type 0)
|
| 600 |
+
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
|
| 601 |
+
|
| 602 |
+
# Create BIN chunk header (chunk type 1)
|
| 603 |
+
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
|
| 604 |
+
|
| 605 |
+
# Write the GLB file
|
| 606 |
+
with open(filepath, 'wb') as f:
|
| 607 |
+
f.write(glb_header)
|
| 608 |
+
f.write(json_chunk_header)
|
| 609 |
+
f.write(gltf_json_padded)
|
| 610 |
+
f.write(bin_chunk_header)
|
| 611 |
+
f.write(buffer_data)
|
| 612 |
+
|
| 613 |
+
return filepath
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
class SaveGLB(IO.ComfyNode):
|
| 617 |
+
@classmethod
|
| 618 |
+
def define_schema(cls):
|
| 619 |
+
return IO.Schema(
|
| 620 |
+
node_id="SaveGLB",
|
| 621 |
+
display_name="Save 3D Model",
|
| 622 |
+
search_aliases=["export 3d model", "save mesh"],
|
| 623 |
+
category="3d",
|
| 624 |
+
essentials_category="Basics",
|
| 625 |
+
is_output_node=True,
|
| 626 |
+
inputs=[
|
| 627 |
+
IO.MultiType.Input(
|
| 628 |
+
IO.Mesh.Input("mesh"),
|
| 629 |
+
types=[
|
| 630 |
+
IO.File3DGLB,
|
| 631 |
+
IO.File3DGLTF,
|
| 632 |
+
IO.File3DOBJ,
|
| 633 |
+
IO.File3DFBX,
|
| 634 |
+
IO.File3DSTL,
|
| 635 |
+
IO.File3DUSDZ,
|
| 636 |
+
IO.File3DAny,
|
| 637 |
+
],
|
| 638 |
+
tooltip="Mesh or 3D file to save",
|
| 639 |
+
),
|
| 640 |
+
IO.String.Input("filename_prefix", default="3d/ComfyUI"),
|
| 641 |
+
],
|
| 642 |
+
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
@classmethod
|
| 646 |
+
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
|
| 647 |
+
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
| 648 |
+
results = []
|
| 649 |
+
|
| 650 |
+
metadata = {}
|
| 651 |
+
if not args.disable_metadata:
|
| 652 |
+
if cls.hidden.prompt is not None:
|
| 653 |
+
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
| 654 |
+
if cls.hidden.extra_pnginfo is not None:
|
| 655 |
+
for x in cls.hidden.extra_pnginfo:
|
| 656 |
+
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
| 657 |
+
|
| 658 |
+
if isinstance(mesh, Types.File3D):
|
| 659 |
+
# Handle File3D input - save BytesIO data to output folder
|
| 660 |
+
ext = mesh.format or "glb"
|
| 661 |
+
f = f"{filename}_{counter:05}_.{ext}"
|
| 662 |
+
mesh.save_to(os.path.join(full_output_folder, f))
|
| 663 |
+
results.append({
|
| 664 |
+
"filename": f,
|
| 665 |
+
"subfolder": subfolder,
|
| 666 |
+
"type": "output"
|
| 667 |
+
})
|
| 668 |
+
else:
|
| 669 |
+
# Handle Mesh input - save vertices and faces as GLB
|
| 670 |
+
for i in range(mesh.vertices.shape[0]):
|
| 671 |
+
f = f"{filename}_{counter:05}_.glb"
|
| 672 |
+
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
| 673 |
+
results.append({
|
| 674 |
+
"filename": f,
|
| 675 |
+
"subfolder": subfolder,
|
| 676 |
+
"type": "output"
|
| 677 |
+
})
|
| 678 |
+
counter += 1
|
| 679 |
+
return IO.NodeOutput(ui={"3d": results})
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
class Hunyuan3dExtension(ComfyExtension):
|
| 683 |
+
@override
|
| 684 |
+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
| 685 |
+
return [
|
| 686 |
+
EmptyLatentHunyuan3Dv2,
|
| 687 |
+
Hunyuan3Dv2Conditioning,
|
| 688 |
+
Hunyuan3Dv2ConditioningMultiView,
|
| 689 |
+
VAEDecodeHunyuan3D,
|
| 690 |
+
VoxelToMeshBasic,
|
| 691 |
+
VoxelToMesh,
|
| 692 |
+
SaveGLB,
|
| 693 |
+
]
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
async def comfy_entrypoint() -> Hunyuan3dExtension:
|
| 697 |
+
return Hunyuan3dExtension()
|
ComfyUI/comfy_extras/nodes_hypernetwork.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import comfy.utils
|
| 2 |
+
import folder_paths
|
| 3 |
+
import torch
|
| 4 |
+
import logging
|
| 5 |
+
from comfy_api.latest import IO, ComfyExtension
|
| 6 |
+
from typing_extensions import override
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_hypernetwork_patch(path, strength):
|
| 10 |
+
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
| 11 |
+
activation_func = sd.get('activation_func', 'linear')
|
| 12 |
+
is_layer_norm = sd.get('is_layer_norm', False)
|
| 13 |
+
use_dropout = sd.get('use_dropout', False)
|
| 14 |
+
activate_output = sd.get('activate_output', False)
|
| 15 |
+
last_layer_dropout = sd.get('last_layer_dropout', False)
|
| 16 |
+
|
| 17 |
+
valid_activation = {
|
| 18 |
+
"linear": torch.nn.Identity,
|
| 19 |
+
"relu": torch.nn.ReLU,
|
| 20 |
+
"leakyrelu": torch.nn.LeakyReLU,
|
| 21 |
+
"elu": torch.nn.ELU,
|
| 22 |
+
"swish": torch.nn.Hardswish,
|
| 23 |
+
"tanh": torch.nn.Tanh,
|
| 24 |
+
"sigmoid": torch.nn.Sigmoid,
|
| 25 |
+
"softsign": torch.nn.Softsign,
|
| 26 |
+
"mish": torch.nn.Mish,
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
if activation_func not in valid_activation:
|
| 30 |
+
logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
out = {}
|
| 34 |
+
|
| 35 |
+
for d in sd:
|
| 36 |
+
try:
|
| 37 |
+
dim = int(d)
|
| 38 |
+
except:
|
| 39 |
+
continue
|
| 40 |
+
|
| 41 |
+
output = []
|
| 42 |
+
for index in [0, 1]:
|
| 43 |
+
attn_weights = sd[dim][index]
|
| 44 |
+
keys = attn_weights.keys()
|
| 45 |
+
|
| 46 |
+
linears = filter(lambda a: a.endswith(".weight"), keys)
|
| 47 |
+
linears = list(map(lambda a: a[:-len(".weight")], linears))
|
| 48 |
+
layers = []
|
| 49 |
+
|
| 50 |
+
i = 0
|
| 51 |
+
while i < len(linears):
|
| 52 |
+
lin_name = linears[i]
|
| 53 |
+
last_layer = (i == (len(linears) - 1))
|
| 54 |
+
penultimate_layer = (i == (len(linears) - 2))
|
| 55 |
+
|
| 56 |
+
lin_weight = attn_weights['{}.weight'.format(lin_name)]
|
| 57 |
+
lin_bias = attn_weights['{}.bias'.format(lin_name)]
|
| 58 |
+
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
|
| 59 |
+
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
|
| 60 |
+
layers.append(layer)
|
| 61 |
+
if activation_func != "linear":
|
| 62 |
+
if (not last_layer) or (activate_output):
|
| 63 |
+
layers.append(valid_activation[activation_func]())
|
| 64 |
+
if is_layer_norm:
|
| 65 |
+
i += 1
|
| 66 |
+
ln_name = linears[i]
|
| 67 |
+
ln_weight = attn_weights['{}.weight'.format(ln_name)]
|
| 68 |
+
ln_bias = attn_weights['{}.bias'.format(ln_name)]
|
| 69 |
+
ln = torch.nn.LayerNorm(ln_weight.shape[0])
|
| 70 |
+
ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
|
| 71 |
+
layers.append(ln)
|
| 72 |
+
if use_dropout:
|
| 73 |
+
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
|
| 74 |
+
layers.append(torch.nn.Dropout(p=0.3))
|
| 75 |
+
i += 1
|
| 76 |
+
|
| 77 |
+
output.append(torch.nn.Sequential(*layers))
|
| 78 |
+
out[dim] = torch.nn.ModuleList(output)
|
| 79 |
+
|
| 80 |
+
class hypernetwork_patch:
|
| 81 |
+
def __init__(self, hypernet, strength):
|
| 82 |
+
self.hypernet = hypernet
|
| 83 |
+
self.strength = strength
|
| 84 |
+
def __call__(self, q, k, v, extra_options):
|
| 85 |
+
dim = k.shape[-1]
|
| 86 |
+
if dim in self.hypernet:
|
| 87 |
+
hn = self.hypernet[dim]
|
| 88 |
+
k = k + hn[0](k) * self.strength
|
| 89 |
+
v = v + hn[1](v) * self.strength
|
| 90 |
+
|
| 91 |
+
return q, k, v
|
| 92 |
+
|
| 93 |
+
def to(self, device):
|
| 94 |
+
for d in self.hypernet.keys():
|
| 95 |
+
self.hypernet[d] = self.hypernet[d].to(device)
|
| 96 |
+
return self
|
| 97 |
+
|
| 98 |
+
return hypernetwork_patch(out, strength)
|
| 99 |
+
|
| 100 |
+
class HypernetworkLoader(IO.ComfyNode):
|
| 101 |
+
@classmethod
|
| 102 |
+
def define_schema(cls):
|
| 103 |
+
return IO.Schema(
|
| 104 |
+
node_id="HypernetworkLoader",
|
| 105 |
+
category="loaders",
|
| 106 |
+
inputs=[
|
| 107 |
+
IO.Model.Input("model"),
|
| 108 |
+
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
|
| 109 |
+
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
|
| 110 |
+
],
|
| 111 |
+
outputs=[
|
| 112 |
+
IO.Model.Output(),
|
| 113 |
+
],
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
@classmethod
|
| 117 |
+
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
|
| 118 |
+
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
| 119 |
+
model_hypernetwork = model.clone()
|
| 120 |
+
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
| 121 |
+
if patch is not None:
|
| 122 |
+
model_hypernetwork.set_model_attn1_patch(patch)
|
| 123 |
+
model_hypernetwork.set_model_attn2_patch(patch)
|
| 124 |
+
return IO.NodeOutput(model_hypernetwork)
|
| 125 |
+
|
| 126 |
+
load_hypernetwork = execute # TODO: remove
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class HyperNetworkExtension(ComfyExtension):
|
| 130 |
+
@override
|
| 131 |
+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
| 132 |
+
return [
|
| 133 |
+
HypernetworkLoader,
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
async def comfy_entrypoint() -> HyperNetworkExtension:
|
| 138 |
+
return HyperNetworkExtension()
|
ComfyUI/comfy_extras/nodes_hypertile.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Taken from: https://github.com/tfernd/HyperTile/
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing_extensions import override
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
# Use torch rng for consistency across generations
|
| 7 |
+
from torch import randint
|
| 8 |
+
from comfy_api.latest import ComfyExtension, io
|
| 9 |
+
|
| 10 |
+
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
| 11 |
+
min_value = min(min_value, value)
|
| 12 |
+
|
| 13 |
+
# All big divisors of value (inclusive)
|
| 14 |
+
divisors = [i for i in range(min_value, value + 1) if value % i == 0]
|
| 15 |
+
|
| 16 |
+
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
|
| 17 |
+
|
| 18 |
+
if len(ns) - 1 > 0:
|
| 19 |
+
idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
|
| 20 |
+
else:
|
| 21 |
+
idx = 0
|
| 22 |
+
|
| 23 |
+
return ns[idx]
|
| 24 |
+
|
| 25 |
+
class HyperTile(io.ComfyNode):
|
| 26 |
+
@classmethod
|
| 27 |
+
def define_schema(cls):
|
| 28 |
+
return io.Schema(
|
| 29 |
+
node_id="HyperTile",
|
| 30 |
+
category="model_patches/unet",
|
| 31 |
+
inputs=[
|
| 32 |
+
io.Model.Input("model"),
|
| 33 |
+
io.Int.Input("tile_size", default=256, min=1, max=2048, advanced=True),
|
| 34 |
+
io.Int.Input("swap_size", default=2, min=1, max=128, advanced=True),
|
| 35 |
+
io.Int.Input("max_depth", default=0, min=0, max=10, advanced=True),
|
| 36 |
+
io.Boolean.Input("scale_depth", default=False, advanced=True),
|
| 37 |
+
],
|
| 38 |
+
outputs=[
|
| 39 |
+
io.Model.Output(),
|
| 40 |
+
],
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
@classmethod
|
| 44 |
+
def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
|
| 45 |
+
latent_tile_size = max(32, tile_size) // 8
|
| 46 |
+
temp = None
|
| 47 |
+
|
| 48 |
+
def hypertile_in(q, k, v, extra_options):
|
| 49 |
+
nonlocal temp
|
| 50 |
+
model_chans = q.shape[-2]
|
| 51 |
+
orig_shape = extra_options['original_shape']
|
| 52 |
+
apply_to = []
|
| 53 |
+
for i in range(max_depth + 1):
|
| 54 |
+
apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
|
| 55 |
+
|
| 56 |
+
if model_chans in apply_to:
|
| 57 |
+
shape = extra_options["original_shape"]
|
| 58 |
+
aspect_ratio = shape[-1] / shape[-2]
|
| 59 |
+
|
| 60 |
+
hw = q.size(1)
|
| 61 |
+
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
| 62 |
+
|
| 63 |
+
factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
|
| 64 |
+
nh = random_divisor(h, latent_tile_size * factor, swap_size)
|
| 65 |
+
nw = random_divisor(w, latent_tile_size * factor, swap_size)
|
| 66 |
+
|
| 67 |
+
if nh * nw > 1:
|
| 68 |
+
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
| 69 |
+
temp = (nh, nw, h, w)
|
| 70 |
+
return q, k, v
|
| 71 |
+
|
| 72 |
+
return q, k, v
|
| 73 |
+
def hypertile_out(out, extra_options):
|
| 74 |
+
nonlocal temp
|
| 75 |
+
if temp is not None:
|
| 76 |
+
nh, nw, h, w = temp
|
| 77 |
+
temp = None
|
| 78 |
+
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
| 79 |
+
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
| 80 |
+
return out
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
m = model.clone()
|
| 84 |
+
m.set_model_attn1_patch(hypertile_in)
|
| 85 |
+
m.set_model_attn1_output_patch(hypertile_out)
|
| 86 |
+
return (m, )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class HyperTileExtension(ComfyExtension):
|
| 90 |
+
@override
|
| 91 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 92 |
+
return [
|
| 93 |
+
HyperTile,
|
| 94 |
+
]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
async def comfy_entrypoint() -> HyperTileExtension:
|
| 98 |
+
return HyperTileExtension()
|
ComfyUI/comfy_extras/nodes_image_compare.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nodes
|
| 2 |
+
|
| 3 |
+
from typing_extensions import override
|
| 4 |
+
from comfy_api.latest import IO, ComfyExtension
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ImageCompare(IO.ComfyNode):
|
| 8 |
+
"""Compares two images with a slider interface."""
|
| 9 |
+
|
| 10 |
+
@classmethod
|
| 11 |
+
def define_schema(cls):
|
| 12 |
+
return IO.Schema(
|
| 13 |
+
node_id="ImageCompare",
|
| 14 |
+
display_name="Image Compare",
|
| 15 |
+
description="Compares two images side by side with a slider.",
|
| 16 |
+
category="image",
|
| 17 |
+
essentials_category="Image Tools",
|
| 18 |
+
is_experimental=True,
|
| 19 |
+
is_output_node=True,
|
| 20 |
+
inputs=[
|
| 21 |
+
IO.Image.Input("image_a", optional=True),
|
| 22 |
+
IO.Image.Input("image_b", optional=True),
|
| 23 |
+
IO.ImageCompare.Input("compare_view"),
|
| 24 |
+
],
|
| 25 |
+
outputs=[],
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
@classmethod
|
| 29 |
+
def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput:
|
| 30 |
+
result = {"a_images": [], "b_images": []}
|
| 31 |
+
|
| 32 |
+
preview_node = nodes.PreviewImage()
|
| 33 |
+
|
| 34 |
+
if image_a is not None and len(image_a) > 0:
|
| 35 |
+
saved = preview_node.save_images(image_a, "comfy.compare.a")
|
| 36 |
+
result["a_images"] = saved["ui"]["images"]
|
| 37 |
+
|
| 38 |
+
if image_b is not None and len(image_b) > 0:
|
| 39 |
+
saved = preview_node.save_images(image_b, "comfy.compare.b")
|
| 40 |
+
result["b_images"] = saved["ui"]["images"]
|
| 41 |
+
|
| 42 |
+
return IO.NodeOutput(ui=result)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ImageCompareExtension(ComfyExtension):
|
| 46 |
+
@override
|
| 47 |
+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
| 48 |
+
return [
|
| 49 |
+
ImageCompare,
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
async def comfy_entrypoint() -> ImageCompareExtension:
|
| 54 |
+
return ImageCompareExtension()
|
ComfyUI/comfy_extras/nodes_images.py
ADDED
|
@@ -0,0 +1,851 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import nodes
|
| 4 |
+
import folder_paths
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import math
|
| 10 |
+
import torch
|
| 11 |
+
import comfy.utils
|
| 12 |
+
|
| 13 |
+
from server import PromptServer
|
| 14 |
+
from comfy_api.latest import ComfyExtension, IO, UI
|
| 15 |
+
from typing_extensions import override
|
| 16 |
+
|
| 17 |
+
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
|
| 18 |
+
|
| 19 |
+
MAX_RESOLUTION = nodes.MAX_RESOLUTION
|
| 20 |
+
|
| 21 |
+
class ImageCrop(IO.ComfyNode):
|
| 22 |
+
@classmethod
|
| 23 |
+
def define_schema(cls):
|
| 24 |
+
return IO.Schema(
|
| 25 |
+
node_id="ImageCrop",
|
| 26 |
+
search_aliases=["trim"],
|
| 27 |
+
display_name="Image Crop (Deprecated)",
|
| 28 |
+
category="image/transform",
|
| 29 |
+
is_deprecated=True,
|
| 30 |
+
essentials_category="Image Tools",
|
| 31 |
+
inputs=[
|
| 32 |
+
IO.Image.Input("image"),
|
| 33 |
+
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
| 34 |
+
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
| 35 |
+
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
| 36 |
+
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
| 37 |
+
],
|
| 38 |
+
outputs=[IO.Image.Output()],
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
@classmethod
|
| 42 |
+
def execute(cls, image, width, height, x, y) -> IO.NodeOutput:
|
| 43 |
+
x = min(x, image.shape[2] - 1)
|
| 44 |
+
y = min(y, image.shape[1] - 1)
|
| 45 |
+
to_x = width + x
|
| 46 |
+
to_y = height + y
|
| 47 |
+
img = image[:,y:to_y, x:to_x, :]
|
| 48 |
+
return IO.NodeOutput(img)
|
| 49 |
+
|
| 50 |
+
crop = execute # TODO: remove
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ImageCropV2(IO.ComfyNode):
|
| 54 |
+
@classmethod
|
| 55 |
+
def define_schema(cls):
|
| 56 |
+
return IO.Schema(
|
| 57 |
+
node_id="ImageCropV2",
|
| 58 |
+
search_aliases=["trim"],
|
| 59 |
+
display_name="Image Crop",
|
| 60 |
+
category="image/transform",
|
| 61 |
+
essentials_category="Image Tools",
|
| 62 |
+
has_intermediate_output=True,
|
| 63 |
+
inputs=[
|
| 64 |
+
IO.Image.Input("image"),
|
| 65 |
+
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
| 66 |
+
],
|
| 67 |
+
outputs=[IO.Image.Output()],
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
@classmethod
|
| 71 |
+
def execute(cls, image, crop_region) -> IO.NodeOutput:
|
| 72 |
+
x = crop_region.get("x", 0)
|
| 73 |
+
y = crop_region.get("y", 0)
|
| 74 |
+
width = crop_region.get("width", 512)
|
| 75 |
+
height = crop_region.get("height", 512)
|
| 76 |
+
|
| 77 |
+
x = min(x, image.shape[2] - 1)
|
| 78 |
+
y = min(y, image.shape[1] - 1)
|
| 79 |
+
to_x = width + x
|
| 80 |
+
to_y = height + y
|
| 81 |
+
img = image[:,y:to_y, x:to_x, :]
|
| 82 |
+
return IO.NodeOutput(img, ui=UI.PreviewImage(img))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class BoundingBox(IO.ComfyNode):
|
| 86 |
+
@classmethod
|
| 87 |
+
def define_schema(cls):
|
| 88 |
+
return IO.Schema(
|
| 89 |
+
node_id="PrimitiveBoundingBox",
|
| 90 |
+
display_name="Bounding Box",
|
| 91 |
+
category="utils/primitive",
|
| 92 |
+
inputs=[
|
| 93 |
+
IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
|
| 94 |
+
IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
|
| 95 |
+
IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
|
| 96 |
+
IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
|
| 97 |
+
],
|
| 98 |
+
outputs=[IO.BoundingBox.Output()],
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
@classmethod
|
| 102 |
+
def execute(cls, x, y, width, height) -> IO.NodeOutput:
|
| 103 |
+
return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height})
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class RepeatImageBatch(IO.ComfyNode):
|
| 107 |
+
@classmethod
|
| 108 |
+
def define_schema(cls):
|
| 109 |
+
return IO.Schema(
|
| 110 |
+
node_id="RepeatImageBatch",
|
| 111 |
+
search_aliases=["duplicate image", "clone image"],
|
| 112 |
+
category="image/batch",
|
| 113 |
+
inputs=[
|
| 114 |
+
IO.Image.Input("image"),
|
| 115 |
+
IO.Int.Input("amount", default=1, min=1, max=4096),
|
| 116 |
+
],
|
| 117 |
+
outputs=[IO.Image.Output()],
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
@classmethod
|
| 121 |
+
def execute(cls, image, amount) -> IO.NodeOutput:
|
| 122 |
+
s = image.repeat((amount, 1,1,1))
|
| 123 |
+
return IO.NodeOutput(s)
|
| 124 |
+
|
| 125 |
+
repeat = execute # TODO: remove
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class ImageFromBatch(IO.ComfyNode):
|
| 129 |
+
@classmethod
|
| 130 |
+
def define_schema(cls):
|
| 131 |
+
return IO.Schema(
|
| 132 |
+
node_id="ImageFromBatch",
|
| 133 |
+
search_aliases=["select image", "pick from batch", "extract image"],
|
| 134 |
+
category="image/batch",
|
| 135 |
+
inputs=[
|
| 136 |
+
IO.Image.Input("image"),
|
| 137 |
+
IO.Int.Input("batch_index", default=0, min=0, max=4095),
|
| 138 |
+
IO.Int.Input("length", default=1, min=1, max=4096),
|
| 139 |
+
],
|
| 140 |
+
outputs=[IO.Image.Output()],
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
@classmethod
|
| 144 |
+
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
|
| 145 |
+
s_in = image
|
| 146 |
+
batch_index = min(s_in.shape[0] - 1, batch_index)
|
| 147 |
+
length = min(s_in.shape[0] - batch_index, length)
|
| 148 |
+
s = s_in[batch_index:batch_index + length].clone()
|
| 149 |
+
return IO.NodeOutput(s)
|
| 150 |
+
|
| 151 |
+
frombatch = execute # TODO: remove
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class ImageAddNoise(IO.ComfyNode):
|
| 155 |
+
@classmethod
|
| 156 |
+
def define_schema(cls):
|
| 157 |
+
return IO.Schema(
|
| 158 |
+
node_id="ImageAddNoise",
|
| 159 |
+
search_aliases=["film grain"],
|
| 160 |
+
category="image",
|
| 161 |
+
inputs=[
|
| 162 |
+
IO.Image.Input("image"),
|
| 163 |
+
IO.Int.Input(
|
| 164 |
+
"seed",
|
| 165 |
+
default=0,
|
| 166 |
+
min=0,
|
| 167 |
+
max=0xFFFFFFFFFFFFFFFF,
|
| 168 |
+
control_after_generate=True,
|
| 169 |
+
tooltip="The random seed used for creating the noise.",
|
| 170 |
+
),
|
| 171 |
+
IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01),
|
| 172 |
+
],
|
| 173 |
+
outputs=[IO.Image.Output()],
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
@classmethod
|
| 177 |
+
def execute(cls, image, seed, strength) -> IO.NodeOutput:
|
| 178 |
+
generator = torch.manual_seed(seed)
|
| 179 |
+
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
|
| 180 |
+
return IO.NodeOutput(s)
|
| 181 |
+
|
| 182 |
+
repeat = execute # TODO: remove
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class SaveAnimatedWEBP(IO.ComfyNode):
|
| 186 |
+
COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
|
| 187 |
+
|
| 188 |
+
@classmethod
|
| 189 |
+
def define_schema(cls):
|
| 190 |
+
return IO.Schema(
|
| 191 |
+
node_id="SaveAnimatedWEBP",
|
| 192 |
+
category="image/animation",
|
| 193 |
+
inputs=[
|
| 194 |
+
IO.Image.Input("images"),
|
| 195 |
+
IO.String.Input("filename_prefix", default="ComfyUI"),
|
| 196 |
+
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
|
| 197 |
+
IO.Boolean.Input("lossless", default=True),
|
| 198 |
+
IO.Int.Input("quality", default=80, min=0, max=100),
|
| 199 |
+
IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())),
|
| 200 |
+
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
|
| 201 |
+
],
|
| 202 |
+
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
| 203 |
+
is_output_node=True,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
@classmethod
|
| 207 |
+
def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
|
| 208 |
+
return IO.NodeOutput(
|
| 209 |
+
ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
|
| 210 |
+
images=images,
|
| 211 |
+
filename_prefix=filename_prefix,
|
| 212 |
+
cls=cls,
|
| 213 |
+
fps=fps,
|
| 214 |
+
lossless=lossless,
|
| 215 |
+
quality=quality,
|
| 216 |
+
method=cls.COMPRESS_METHODS.get(method)
|
| 217 |
+
)
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
save_images = execute # TODO: remove
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class SaveAnimatedPNG(IO.ComfyNode):
|
| 224 |
+
|
| 225 |
+
@classmethod
|
| 226 |
+
def define_schema(cls):
|
| 227 |
+
return IO.Schema(
|
| 228 |
+
node_id="SaveAnimatedPNG",
|
| 229 |
+
category="image/animation",
|
| 230 |
+
inputs=[
|
| 231 |
+
IO.Image.Input("images"),
|
| 232 |
+
IO.String.Input("filename_prefix", default="ComfyUI"),
|
| 233 |
+
IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
|
| 234 |
+
IO.Int.Input("compress_level", default=4, min=0, max=9, advanced=True),
|
| 235 |
+
],
|
| 236 |
+
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
| 237 |
+
is_output_node=True,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
@classmethod
|
| 241 |
+
def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
|
| 242 |
+
return IO.NodeOutput(
|
| 243 |
+
ui=UI.ImageSaveHelper.get_save_animated_png_ui(
|
| 244 |
+
images=images,
|
| 245 |
+
filename_prefix=filename_prefix,
|
| 246 |
+
cls=cls,
|
| 247 |
+
fps=fps,
|
| 248 |
+
compress_level=compress_level,
|
| 249 |
+
)
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
save_images = execute # TODO: remove
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class ImageStitch(IO.ComfyNode):
|
| 256 |
+
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
|
| 257 |
+
@classmethod
|
| 258 |
+
def define_schema(cls):
|
| 259 |
+
return IO.Schema(
|
| 260 |
+
node_id="ImageStitch",
|
| 261 |
+
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
|
| 262 |
+
display_name="Image Stitch",
|
| 263 |
+
description="Stitches image2 to image1 in the specified direction.\n"
|
| 264 |
+
"If image2 is not provided, returns image1 unchanged.\n"
|
| 265 |
+
"Optional spacing can be added between images.",
|
| 266 |
+
category="image/transform",
|
| 267 |
+
inputs=[
|
| 268 |
+
IO.Image.Input("image1"),
|
| 269 |
+
IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
|
| 270 |
+
IO.Boolean.Input("match_image_size", default=True),
|
| 271 |
+
IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2, advanced=True),
|
| 272 |
+
IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white", advanced=True),
|
| 273 |
+
IO.Image.Input("image2", optional=True),
|
| 274 |
+
],
|
| 275 |
+
outputs=[IO.Image.Output()],
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
@classmethod
|
| 279 |
+
def execute(
|
| 280 |
+
cls,
|
| 281 |
+
image1,
|
| 282 |
+
direction,
|
| 283 |
+
match_image_size,
|
| 284 |
+
spacing_width,
|
| 285 |
+
spacing_color,
|
| 286 |
+
image2=None,
|
| 287 |
+
) -> IO.NodeOutput:
|
| 288 |
+
if image2 is None:
|
| 289 |
+
return IO.NodeOutput(image1)
|
| 290 |
+
|
| 291 |
+
# Handle batch size differences
|
| 292 |
+
if image1.shape[0] != image2.shape[0]:
|
| 293 |
+
max_batch = max(image1.shape[0], image2.shape[0])
|
| 294 |
+
if image1.shape[0] < max_batch:
|
| 295 |
+
image1 = torch.cat(
|
| 296 |
+
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
|
| 297 |
+
)
|
| 298 |
+
if image2.shape[0] < max_batch:
|
| 299 |
+
image2 = torch.cat(
|
| 300 |
+
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# Match image sizes if requested
|
| 304 |
+
if match_image_size:
|
| 305 |
+
h1, w1 = image1.shape[1:3]
|
| 306 |
+
h2, w2 = image2.shape[1:3]
|
| 307 |
+
aspect_ratio = w2 / h2
|
| 308 |
+
|
| 309 |
+
if direction in ["left", "right"]:
|
| 310 |
+
target_h, target_w = h1, int(h1 * aspect_ratio)
|
| 311 |
+
else: # up, down
|
| 312 |
+
target_w, target_h = w1, int(w1 / aspect_ratio)
|
| 313 |
+
|
| 314 |
+
image2 = comfy.utils.common_upscale(
|
| 315 |
+
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
|
| 316 |
+
).movedim(1, -1)
|
| 317 |
+
|
| 318 |
+
color_map = {
|
| 319 |
+
"white": 1.0,
|
| 320 |
+
"black": 0.0,
|
| 321 |
+
"red": (1.0, 0.0, 0.0),
|
| 322 |
+
"green": (0.0, 1.0, 0.0),
|
| 323 |
+
"blue": (0.0, 0.0, 1.0),
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
color_val = color_map[spacing_color]
|
| 327 |
+
|
| 328 |
+
# When not matching sizes, pad to align non-concat dimensions
|
| 329 |
+
if not match_image_size:
|
| 330 |
+
h1, w1 = image1.shape[1:3]
|
| 331 |
+
h2, w2 = image2.shape[1:3]
|
| 332 |
+
pad_value = 0.0
|
| 333 |
+
if not isinstance(color_val, tuple):
|
| 334 |
+
pad_value = color_val
|
| 335 |
+
|
| 336 |
+
if direction in ["left", "right"]:
|
| 337 |
+
# For horizontal concat, pad heights to match
|
| 338 |
+
if h1 != h2:
|
| 339 |
+
target_h = max(h1, h2)
|
| 340 |
+
if h1 < target_h:
|
| 341 |
+
pad_h = target_h - h1
|
| 342 |
+
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
| 343 |
+
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
|
| 344 |
+
if h2 < target_h:
|
| 345 |
+
pad_h = target_h - h2
|
| 346 |
+
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
|
| 347 |
+
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
|
| 348 |
+
else: # up, down
|
| 349 |
+
# For vertical concat, pad widths to match
|
| 350 |
+
if w1 != w2:
|
| 351 |
+
target_w = max(w1, w2)
|
| 352 |
+
if w1 < target_w:
|
| 353 |
+
pad_w = target_w - w1
|
| 354 |
+
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
| 355 |
+
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
|
| 356 |
+
if w2 < target_w:
|
| 357 |
+
pad_w = target_w - w2
|
| 358 |
+
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
|
| 359 |
+
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
|
| 360 |
+
|
| 361 |
+
# Ensure same number of channels
|
| 362 |
+
if image1.shape[-1] != image2.shape[-1]:
|
| 363 |
+
max_channels = max(image1.shape[-1], image2.shape[-1])
|
| 364 |
+
if image1.shape[-1] < max_channels:
|
| 365 |
+
image1 = torch.cat(
|
| 366 |
+
[
|
| 367 |
+
image1,
|
| 368 |
+
torch.ones(
|
| 369 |
+
*image1.shape[:-1],
|
| 370 |
+
max_channels - image1.shape[-1],
|
| 371 |
+
device=image1.device,
|
| 372 |
+
),
|
| 373 |
+
],
|
| 374 |
+
dim=-1,
|
| 375 |
+
)
|
| 376 |
+
if image2.shape[-1] < max_channels:
|
| 377 |
+
image2 = torch.cat(
|
| 378 |
+
[
|
| 379 |
+
image2,
|
| 380 |
+
torch.ones(
|
| 381 |
+
*image2.shape[:-1],
|
| 382 |
+
max_channels - image2.shape[-1],
|
| 383 |
+
device=image2.device,
|
| 384 |
+
),
|
| 385 |
+
],
|
| 386 |
+
dim=-1,
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
# Add spacing if specified
|
| 390 |
+
if spacing_width > 0:
|
| 391 |
+
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
|
| 392 |
+
|
| 393 |
+
if direction in ["left", "right"]:
|
| 394 |
+
spacing_shape = (
|
| 395 |
+
image1.shape[0],
|
| 396 |
+
max(image1.shape[1], image2.shape[1]),
|
| 397 |
+
spacing_width,
|
| 398 |
+
image1.shape[-1],
|
| 399 |
+
)
|
| 400 |
+
else:
|
| 401 |
+
spacing_shape = (
|
| 402 |
+
image1.shape[0],
|
| 403 |
+
spacing_width,
|
| 404 |
+
max(image1.shape[2], image2.shape[2]),
|
| 405 |
+
image1.shape[-1],
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
|
| 409 |
+
if isinstance(color_val, tuple):
|
| 410 |
+
for i, c in enumerate(color_val):
|
| 411 |
+
if i < spacing.shape[-1]:
|
| 412 |
+
spacing[..., i] = c
|
| 413 |
+
if spacing.shape[-1] == 4: # Add alpha
|
| 414 |
+
spacing[..., 3] = 1.0
|
| 415 |
+
else:
|
| 416 |
+
spacing[..., : min(3, spacing.shape[-1])] = color_val
|
| 417 |
+
if spacing.shape[-1] == 4:
|
| 418 |
+
spacing[..., 3] = 1.0
|
| 419 |
+
|
| 420 |
+
# Concatenate images
|
| 421 |
+
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
|
| 422 |
+
if spacing_width > 0:
|
| 423 |
+
images.insert(1, spacing)
|
| 424 |
+
|
| 425 |
+
concat_dim = 2 if direction in ["left", "right"] else 1
|
| 426 |
+
return IO.NodeOutput(torch.cat(images, dim=concat_dim))
|
| 427 |
+
|
| 428 |
+
stitch = execute # TODO: remove
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class ResizeAndPadImage(IO.ComfyNode):
|
| 432 |
+
@classmethod
|
| 433 |
+
def define_schema(cls):
|
| 434 |
+
return IO.Schema(
|
| 435 |
+
node_id="ResizeAndPadImage",
|
| 436 |
+
search_aliases=["fit to size"],
|
| 437 |
+
category="image/transform",
|
| 438 |
+
inputs=[
|
| 439 |
+
IO.Image.Input("image"),
|
| 440 |
+
IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
| 441 |
+
IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
| 442 |
+
IO.Combo.Input("padding_color", options=["white", "black"], advanced=True),
|
| 443 |
+
IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"], advanced=True),
|
| 444 |
+
],
|
| 445 |
+
outputs=[IO.Image.Output()],
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
@classmethod
|
| 449 |
+
def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput:
|
| 450 |
+
batch_size, orig_height, orig_width, channels = image.shape
|
| 451 |
+
|
| 452 |
+
scale_w = target_width / orig_width
|
| 453 |
+
scale_h = target_height / orig_height
|
| 454 |
+
scale = min(scale_w, scale_h)
|
| 455 |
+
|
| 456 |
+
new_width = int(orig_width * scale)
|
| 457 |
+
new_height = int(orig_height * scale)
|
| 458 |
+
|
| 459 |
+
image_permuted = image.permute(0, 3, 1, 2)
|
| 460 |
+
|
| 461 |
+
resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled")
|
| 462 |
+
|
| 463 |
+
pad_value = 0.0 if padding_color == "black" else 1.0
|
| 464 |
+
padded = torch.full(
|
| 465 |
+
(batch_size, channels, target_height, target_width),
|
| 466 |
+
pad_value,
|
| 467 |
+
dtype=image.dtype,
|
| 468 |
+
device=image.device
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
y_offset = (target_height - new_height) // 2
|
| 472 |
+
x_offset = (target_width - new_width) // 2
|
| 473 |
+
|
| 474 |
+
padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
|
| 475 |
+
|
| 476 |
+
output = padded.permute(0, 2, 3, 1)
|
| 477 |
+
return IO.NodeOutput(output)
|
| 478 |
+
|
| 479 |
+
resize_and_pad = execute # TODO: remove
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
class SaveSVGNode(IO.ComfyNode):
|
| 483 |
+
@classmethod
|
| 484 |
+
def define_schema(cls):
|
| 485 |
+
return IO.Schema(
|
| 486 |
+
node_id="SaveSVGNode",
|
| 487 |
+
search_aliases=["export vector", "save vector graphics"],
|
| 488 |
+
description="Save SVG files on disk.",
|
| 489 |
+
category="image/save",
|
| 490 |
+
inputs=[
|
| 491 |
+
IO.SVG.Input("svg"),
|
| 492 |
+
IO.String.Input(
|
| 493 |
+
"filename_prefix",
|
| 494 |
+
default="svg/ComfyUI",
|
| 495 |
+
tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
|
| 496 |
+
),
|
| 497 |
+
],
|
| 498 |
+
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
| 499 |
+
is_output_node=True,
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
@classmethod
|
| 503 |
+
def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput:
|
| 504 |
+
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
| 505 |
+
results: list[UI.SavedResult] = []
|
| 506 |
+
|
| 507 |
+
# Prepare metadata JSON
|
| 508 |
+
metadata_dict = {}
|
| 509 |
+
if cls.hidden.prompt is not None:
|
| 510 |
+
metadata_dict["prompt"] = cls.hidden.prompt
|
| 511 |
+
if cls.hidden.extra_pnginfo is not None:
|
| 512 |
+
metadata_dict.update(cls.hidden.extra_pnginfo)
|
| 513 |
+
|
| 514 |
+
# Convert metadata to JSON string
|
| 515 |
+
metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
for batch_number, svg_bytes in enumerate(svg.data):
|
| 519 |
+
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
| 520 |
+
file = f"{filename_with_batch_num}_{counter:05}_.svg"
|
| 521 |
+
|
| 522 |
+
# Read SVG content
|
| 523 |
+
svg_bytes.seek(0)
|
| 524 |
+
svg_content = svg_bytes.read().decode('utf-8')
|
| 525 |
+
|
| 526 |
+
# Inject metadata if available
|
| 527 |
+
if metadata_json:
|
| 528 |
+
# Create metadata element with CDATA section
|
| 529 |
+
metadata_element = f""" <metadata>
|
| 530 |
+
<![CDATA[
|
| 531 |
+
{metadata_json}
|
| 532 |
+
]]>
|
| 533 |
+
</metadata>
|
| 534 |
+
"""
|
| 535 |
+
# Insert metadata after opening svg tag using regex with a replacement function
|
| 536 |
+
def replacement(match):
|
| 537 |
+
# match.group(1) contains the captured <svg> tag
|
| 538 |
+
return match.group(1) + '\n' + metadata_element
|
| 539 |
+
|
| 540 |
+
# Apply the substitution
|
| 541 |
+
svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
|
| 542 |
+
|
| 543 |
+
# Write the modified SVG to file
|
| 544 |
+
with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
|
| 545 |
+
svg_file.write(svg_content.encode('utf-8'))
|
| 546 |
+
|
| 547 |
+
results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
|
| 548 |
+
counter += 1
|
| 549 |
+
return IO.NodeOutput(ui={"images": results})
|
| 550 |
+
|
| 551 |
+
save_svg = execute # TODO: remove
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
class GetImageSize(IO.ComfyNode):
|
| 555 |
+
@classmethod
|
| 556 |
+
def define_schema(cls):
|
| 557 |
+
return IO.Schema(
|
| 558 |
+
node_id="GetImageSize",
|
| 559 |
+
search_aliases=["dimensions", "resolution", "image info"],
|
| 560 |
+
display_name="Get Image Size",
|
| 561 |
+
description="Returns width and height of the image, and passes it through unchanged.",
|
| 562 |
+
category="image",
|
| 563 |
+
inputs=[
|
| 564 |
+
IO.Image.Input("image"),
|
| 565 |
+
],
|
| 566 |
+
outputs=[
|
| 567 |
+
IO.Int.Output(display_name="width"),
|
| 568 |
+
IO.Int.Output(display_name="height"),
|
| 569 |
+
IO.Int.Output(display_name="batch_size"),
|
| 570 |
+
],
|
| 571 |
+
hidden=[IO.Hidden.unique_id],
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
@classmethod
|
| 575 |
+
def execute(cls, image) -> IO.NodeOutput:
|
| 576 |
+
height = image.shape[1]
|
| 577 |
+
width = image.shape[2]
|
| 578 |
+
batch_size = image.shape[0]
|
| 579 |
+
|
| 580 |
+
# Send progress text to display size on the node
|
| 581 |
+
if cls.hidden.unique_id:
|
| 582 |
+
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
|
| 583 |
+
|
| 584 |
+
return IO.NodeOutput(width, height, batch_size)
|
| 585 |
+
|
| 586 |
+
get_size = execute # TODO: remove
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
class ImageRotate(IO.ComfyNode):
|
| 590 |
+
@classmethod
|
| 591 |
+
def define_schema(cls):
|
| 592 |
+
return IO.Schema(
|
| 593 |
+
node_id="ImageRotate",
|
| 594 |
+
display_name="Image Rotate",
|
| 595 |
+
search_aliases=["turn", "flip orientation"],
|
| 596 |
+
category="image/transform",
|
| 597 |
+
essentials_category="Image Tools",
|
| 598 |
+
inputs=[
|
| 599 |
+
IO.Image.Input("image"),
|
| 600 |
+
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
|
| 601 |
+
],
|
| 602 |
+
outputs=[IO.Image.Output()],
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
@classmethod
|
| 606 |
+
def execute(cls, image, rotation) -> IO.NodeOutput:
|
| 607 |
+
rotate_by = 0
|
| 608 |
+
if rotation.startswith("90"):
|
| 609 |
+
rotate_by = 1
|
| 610 |
+
elif rotation.startswith("180"):
|
| 611 |
+
rotate_by = 2
|
| 612 |
+
elif rotation.startswith("270"):
|
| 613 |
+
rotate_by = 3
|
| 614 |
+
|
| 615 |
+
image = torch.rot90(image, k=rotate_by, dims=[2, 1])
|
| 616 |
+
return IO.NodeOutput(image)
|
| 617 |
+
|
| 618 |
+
rotate = execute # TODO: remove
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class ImageFlip(IO.ComfyNode):
|
| 622 |
+
@classmethod
|
| 623 |
+
def define_schema(cls):
|
| 624 |
+
return IO.Schema(
|
| 625 |
+
node_id="ImageFlip",
|
| 626 |
+
search_aliases=["mirror", "reflect"],
|
| 627 |
+
category="image/transform",
|
| 628 |
+
inputs=[
|
| 629 |
+
IO.Image.Input("image"),
|
| 630 |
+
IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
|
| 631 |
+
],
|
| 632 |
+
outputs=[IO.Image.Output()],
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
@classmethod
|
| 636 |
+
def execute(cls, image, flip_method) -> IO.NodeOutput:
|
| 637 |
+
if flip_method.startswith("x"):
|
| 638 |
+
image = torch.flip(image, dims=[1])
|
| 639 |
+
elif flip_method.startswith("y"):
|
| 640 |
+
image = torch.flip(image, dims=[2])
|
| 641 |
+
|
| 642 |
+
return IO.NodeOutput(image)
|
| 643 |
+
|
| 644 |
+
flip = execute # TODO: remove
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
class ImageScaleToMaxDimension(IO.ComfyNode):
|
| 648 |
+
|
| 649 |
+
@classmethod
|
| 650 |
+
def define_schema(cls):
|
| 651 |
+
return IO.Schema(
|
| 652 |
+
node_id="ImageScaleToMaxDimension",
|
| 653 |
+
category="image/upscaling",
|
| 654 |
+
inputs=[
|
| 655 |
+
IO.Image.Input("image"),
|
| 656 |
+
IO.Combo.Input(
|
| 657 |
+
"upscale_method",
|
| 658 |
+
options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"],
|
| 659 |
+
),
|
| 660 |
+
IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
|
| 661 |
+
],
|
| 662 |
+
outputs=[IO.Image.Output()],
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
@classmethod
|
| 666 |
+
def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
|
| 667 |
+
height = image.shape[1]
|
| 668 |
+
width = image.shape[2]
|
| 669 |
+
|
| 670 |
+
if height > width:
|
| 671 |
+
width = round((width / height) * largest_size)
|
| 672 |
+
height = largest_size
|
| 673 |
+
elif width > height:
|
| 674 |
+
height = round((height / width) * largest_size)
|
| 675 |
+
width = largest_size
|
| 676 |
+
else:
|
| 677 |
+
height = largest_size
|
| 678 |
+
width = largest_size
|
| 679 |
+
|
| 680 |
+
samples = image.movedim(-1, 1)
|
| 681 |
+
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
|
| 682 |
+
s = s.movedim(1, -1)
|
| 683 |
+
return IO.NodeOutput(s)
|
| 684 |
+
|
| 685 |
+
upscale = execute # TODO: remove
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
class SplitImageToTileList(IO.ComfyNode):
|
| 689 |
+
@classmethod
|
| 690 |
+
def define_schema(cls):
|
| 691 |
+
return IO.Schema(
|
| 692 |
+
node_id="SplitImageToTileList",
|
| 693 |
+
category="image/batch",
|
| 694 |
+
search_aliases=["split image", "tile image", "slice image"],
|
| 695 |
+
display_name="Split Image into List of Tiles",
|
| 696 |
+
description="Splits an image into a batched list of tiles with a specified overlap.",
|
| 697 |
+
inputs=[
|
| 698 |
+
IO.Image.Input("image"),
|
| 699 |
+
IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION),
|
| 700 |
+
IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION),
|
| 701 |
+
IO.Int.Input("overlap", default=128, min=0, max=4096),
|
| 702 |
+
],
|
| 703 |
+
outputs=[
|
| 704 |
+
IO.Image.Output(is_output_list=True),
|
| 705 |
+
],
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
@staticmethod
|
| 709 |
+
def get_grid_coords(width, height, tile_width, tile_height, overlap):
|
| 710 |
+
coords = []
|
| 711 |
+
stride_x = round(max(tile_width * 0.25, tile_width - overlap))
|
| 712 |
+
stride_y = round(max(tile_width * 0.25, tile_height - overlap))
|
| 713 |
+
|
| 714 |
+
y = 0
|
| 715 |
+
while y < height:
|
| 716 |
+
x = 0
|
| 717 |
+
y_end = min(y + tile_height, height)
|
| 718 |
+
y_start = max(0, y_end - tile_height)
|
| 719 |
+
|
| 720 |
+
while x < width:
|
| 721 |
+
x_end = min(x + tile_width, width)
|
| 722 |
+
x_start = max(0, x_end - tile_width)
|
| 723 |
+
|
| 724 |
+
coords.append((x_start, y_start, x_end, y_end))
|
| 725 |
+
|
| 726 |
+
if x_end >= width:
|
| 727 |
+
break
|
| 728 |
+
x += stride_x
|
| 729 |
+
|
| 730 |
+
if y_end >= height:
|
| 731 |
+
break
|
| 732 |
+
y += stride_y
|
| 733 |
+
|
| 734 |
+
return coords
|
| 735 |
+
|
| 736 |
+
@classmethod
|
| 737 |
+
def execute(cls, image, tile_width, tile_height, overlap):
|
| 738 |
+
b, h, w, c = image.shape
|
| 739 |
+
coords = cls.get_grid_coords(w, h, tile_width, tile_height, overlap)
|
| 740 |
+
|
| 741 |
+
output_list = []
|
| 742 |
+
for (x_start, y_start, x_end, y_end) in coords:
|
| 743 |
+
tile = image[:, y_start:y_end, x_start:x_end, :]
|
| 744 |
+
output_list.append(tile)
|
| 745 |
+
|
| 746 |
+
return IO.NodeOutput(output_list)
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
class ImageMergeTileList(IO.ComfyNode):
|
| 750 |
+
@classmethod
|
| 751 |
+
def define_schema(cls):
|
| 752 |
+
return IO.Schema(
|
| 753 |
+
node_id="ImageMergeTileList",
|
| 754 |
+
display_name="Merge List of Tiles to Image",
|
| 755 |
+
category="image/batch",
|
| 756 |
+
search_aliases=["split image", "tile image", "slice image"],
|
| 757 |
+
is_input_list=True,
|
| 758 |
+
inputs=[
|
| 759 |
+
IO.Image.Input("image_list"),
|
| 760 |
+
IO.Int.Input("final_width", default=1024, min=64, max=32768),
|
| 761 |
+
IO.Int.Input("final_height", default=1024, min=64, max=32768),
|
| 762 |
+
IO.Int.Input("overlap", default=128, min=0, max=4096),
|
| 763 |
+
],
|
| 764 |
+
outputs=[
|
| 765 |
+
IO.Image.Output(is_output_list=False),
|
| 766 |
+
],
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
@classmethod
|
| 770 |
+
def execute(cls, image_list, final_width, final_height, overlap):
|
| 771 |
+
w = final_width[0]
|
| 772 |
+
h = final_height[0]
|
| 773 |
+
ovlp = overlap[0]
|
| 774 |
+
feather_str = 1.0
|
| 775 |
+
|
| 776 |
+
first_tile = image_list[0]
|
| 777 |
+
b, t_h, t_w, c = first_tile.shape
|
| 778 |
+
device = first_tile.device
|
| 779 |
+
dtype = first_tile.dtype
|
| 780 |
+
|
| 781 |
+
coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp)
|
| 782 |
+
|
| 783 |
+
canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
|
| 784 |
+
weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)
|
| 785 |
+
|
| 786 |
+
if ovlp > 0:
|
| 787 |
+
y_w = torch.sin(math.pi * torch.linspace(0, 1, t_h, device=device, dtype=dtype))
|
| 788 |
+
x_w = torch.sin(math.pi * torch.linspace(0, 1, t_w, device=device, dtype=dtype))
|
| 789 |
+
y_w = torch.clamp(y_w, min=1e-5)
|
| 790 |
+
x_w = torch.clamp(x_w, min=1e-5)
|
| 791 |
+
|
| 792 |
+
sine_mask = (y_w.unsqueeze(1) * x_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1)
|
| 793 |
+
flat_mask = torch.ones_like(sine_mask)
|
| 794 |
+
|
| 795 |
+
weight_mask = torch.lerp(flat_mask, sine_mask, feather_str)
|
| 796 |
+
else:
|
| 797 |
+
weight_mask = torch.ones((1, t_h, t_w, 1), device=device, dtype=dtype)
|
| 798 |
+
|
| 799 |
+
for i, (x_start, y_start, x_end, y_end) in enumerate(coords):
|
| 800 |
+
if i >= len(image_list):
|
| 801 |
+
break
|
| 802 |
+
|
| 803 |
+
tile = image_list[i]
|
| 804 |
+
|
| 805 |
+
region_h = y_end - y_start
|
| 806 |
+
region_w = x_end - x_start
|
| 807 |
+
|
| 808 |
+
real_h = min(region_h, tile.shape[1])
|
| 809 |
+
real_w = min(region_w, tile.shape[2])
|
| 810 |
+
|
| 811 |
+
y_end_actual = y_start + real_h
|
| 812 |
+
x_end_actual = x_start + real_w
|
| 813 |
+
|
| 814 |
+
tile_crop = tile[:, :real_h, :real_w, :]
|
| 815 |
+
mask_crop = weight_mask[:, :real_h, :real_w, :]
|
| 816 |
+
|
| 817 |
+
canvas[:, y_start:y_end_actual, x_start:x_end_actual, :] += tile_crop * mask_crop
|
| 818 |
+
weights[:, y_start:y_end_actual, x_start:x_end_actual, :] += mask_crop
|
| 819 |
+
|
| 820 |
+
weights[weights == 0] = 1.0
|
| 821 |
+
merged_image = canvas / weights
|
| 822 |
+
|
| 823 |
+
return IO.NodeOutput(merged_image)
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
class ImagesExtension(ComfyExtension):
|
| 827 |
+
@override
|
| 828 |
+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
| 829 |
+
return [
|
| 830 |
+
ImageCrop,
|
| 831 |
+
ImageCropV2,
|
| 832 |
+
BoundingBox,
|
| 833 |
+
RepeatImageBatch,
|
| 834 |
+
ImageFromBatch,
|
| 835 |
+
ImageAddNoise,
|
| 836 |
+
SaveAnimatedWEBP,
|
| 837 |
+
SaveAnimatedPNG,
|
| 838 |
+
SaveSVGNode,
|
| 839 |
+
ImageStitch,
|
| 840 |
+
ResizeAndPadImage,
|
| 841 |
+
GetImageSize,
|
| 842 |
+
ImageRotate,
|
| 843 |
+
ImageFlip,
|
| 844 |
+
ImageScaleToMaxDimension,
|
| 845 |
+
SplitImageToTileList,
|
| 846 |
+
ImageMergeTileList,
|
| 847 |
+
]
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
async def comfy_entrypoint() -> ImagesExtension:
|
| 851 |
+
return ImagesExtension()
|
ComfyUI/comfy_extras/nodes_ip2p.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from typing_extensions import override
|
| 4 |
+
from comfy_api.latest import ComfyExtension, io
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class InstructPixToPixConditioning(io.ComfyNode):
|
| 8 |
+
@classmethod
|
| 9 |
+
def define_schema(cls):
|
| 10 |
+
return io.Schema(
|
| 11 |
+
node_id="InstructPixToPixConditioning",
|
| 12 |
+
category="conditioning/instructpix2pix",
|
| 13 |
+
inputs=[
|
| 14 |
+
io.Conditioning.Input("positive"),
|
| 15 |
+
io.Conditioning.Input("negative"),
|
| 16 |
+
io.Vae.Input("vae"),
|
| 17 |
+
io.Image.Input("pixels"),
|
| 18 |
+
],
|
| 19 |
+
outputs=[
|
| 20 |
+
io.Conditioning.Output(display_name="positive"),
|
| 21 |
+
io.Conditioning.Output(display_name="negative"),
|
| 22 |
+
io.Latent.Output(display_name="latent"),
|
| 23 |
+
],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
@classmethod
|
| 27 |
+
def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput:
|
| 28 |
+
x = (pixels.shape[1] // 8) * 8
|
| 29 |
+
y = (pixels.shape[2] // 8) * 8
|
| 30 |
+
|
| 31 |
+
if pixels.shape[1] != x or pixels.shape[2] != y:
|
| 32 |
+
x_offset = (pixels.shape[1] % 8) // 2
|
| 33 |
+
y_offset = (pixels.shape[2] % 8) // 2
|
| 34 |
+
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
|
| 35 |
+
|
| 36 |
+
concat_latent = vae.encode(pixels)
|
| 37 |
+
|
| 38 |
+
out_latent = {}
|
| 39 |
+
out_latent["samples"] = torch.zeros_like(concat_latent)
|
| 40 |
+
|
| 41 |
+
out = []
|
| 42 |
+
for conditioning in [positive, negative]:
|
| 43 |
+
c = []
|
| 44 |
+
for t in conditioning:
|
| 45 |
+
d = t[1].copy()
|
| 46 |
+
d["concat_latent_image"] = concat_latent
|
| 47 |
+
n = [t[0], d]
|
| 48 |
+
c.append(n)
|
| 49 |
+
out.append(c)
|
| 50 |
+
return io.NodeOutput(out[0], out[1], out_latent)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class InstructPix2PixExtension(ComfyExtension):
|
| 54 |
+
@override
|
| 55 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 56 |
+
return [
|
| 57 |
+
InstructPixToPixConditioning,
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
async def comfy_entrypoint() -> InstructPix2PixExtension:
|
| 62 |
+
return InstructPix2PixExtension()
|
| 63 |
+
|
ComfyUI/comfy_extras/nodes_kandinsky5.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nodes
|
| 2 |
+
import node_helpers
|
| 3 |
+
import torch
|
| 4 |
+
import comfy.model_management
|
| 5 |
+
import comfy.utils
|
| 6 |
+
|
| 7 |
+
from typing_extensions import override
|
| 8 |
+
from comfy_api.latest import ComfyExtension, io
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Kandinsky5ImageToVideo(io.ComfyNode):
|
| 12 |
+
@classmethod
|
| 13 |
+
def define_schema(cls):
|
| 14 |
+
return io.Schema(
|
| 15 |
+
node_id="Kandinsky5ImageToVideo",
|
| 16 |
+
category="conditioning/video_models",
|
| 17 |
+
inputs=[
|
| 18 |
+
io.Conditioning.Input("positive"),
|
| 19 |
+
io.Conditioning.Input("negative"),
|
| 20 |
+
io.Vae.Input("vae"),
|
| 21 |
+
io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 22 |
+
io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
| 23 |
+
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
| 24 |
+
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
| 25 |
+
io.Image.Input("start_image", optional=True),
|
| 26 |
+
],
|
| 27 |
+
outputs=[
|
| 28 |
+
io.Conditioning.Output(display_name="positive"),
|
| 29 |
+
io.Conditioning.Output(display_name="negative"),
|
| 30 |
+
io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
|
| 31 |
+
io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
|
| 32 |
+
],
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
@classmethod
|
| 36 |
+
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
| 37 |
+
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
| 38 |
+
cond_latent_out = {}
|
| 39 |
+
if start_image is not None:
|
| 40 |
+
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
| 41 |
+
encoded = vae.encode(start_image[:, :, :, :3])
|
| 42 |
+
cond_latent_out["samples"] = encoded
|
| 43 |
+
|
| 44 |
+
mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
| 45 |
+
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
| 46 |
+
|
| 47 |
+
positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
|
| 48 |
+
negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
|
| 49 |
+
|
| 50 |
+
out_latent = {}
|
| 51 |
+
out_latent["samples"] = latent
|
| 52 |
+
return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
|
| 56 |
+
source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
|
| 57 |
+
source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
|
| 58 |
+
|
| 59 |
+
reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
|
| 60 |
+
reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
|
| 61 |
+
|
| 62 |
+
# normalization
|
| 63 |
+
normalized = (source - source_mean) / (source_std + 1e-8)
|
| 64 |
+
normalized = normalized * reference_std + reference_mean
|
| 65 |
+
|
| 66 |
+
return normalized
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class NormalizeVideoLatentStart(io.ComfyNode):
|
| 70 |
+
@classmethod
|
| 71 |
+
def define_schema(cls):
|
| 72 |
+
return io.Schema(
|
| 73 |
+
node_id="NormalizeVideoLatentStart",
|
| 74 |
+
category="conditioning/video_models",
|
| 75 |
+
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
|
| 76 |
+
inputs=[
|
| 77 |
+
io.Latent.Input("latent"),
|
| 78 |
+
io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
|
| 79 |
+
io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
|
| 80 |
+
],
|
| 81 |
+
outputs=[
|
| 82 |
+
io.Latent.Output(display_name="latent"),
|
| 83 |
+
],
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
@classmethod
|
| 87 |
+
def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
|
| 88 |
+
if latent["samples"].shape[2] <= 1:
|
| 89 |
+
return io.NodeOutput(latent)
|
| 90 |
+
s = latent.copy()
|
| 91 |
+
samples = latent["samples"].clone()
|
| 92 |
+
|
| 93 |
+
first_frames = samples[:, :, :start_frame_count]
|
| 94 |
+
reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
|
| 95 |
+
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
|
| 96 |
+
|
| 97 |
+
samples[:, :, :start_frame_count] = normalized_first_frames
|
| 98 |
+
s["samples"] = samples
|
| 99 |
+
return io.NodeOutput(s)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class CLIPTextEncodeKandinsky5(io.ComfyNode):
|
| 103 |
+
@classmethod
|
| 104 |
+
def define_schema(cls):
|
| 105 |
+
return io.Schema(
|
| 106 |
+
node_id="CLIPTextEncodeKandinsky5",
|
| 107 |
+
search_aliases=["kandinsky prompt"],
|
| 108 |
+
category="advanced/conditioning/kandinsky5",
|
| 109 |
+
inputs=[
|
| 110 |
+
io.Clip.Input("clip"),
|
| 111 |
+
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
| 112 |
+
io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
|
| 113 |
+
],
|
| 114 |
+
outputs=[
|
| 115 |
+
io.Conditioning.Output(),
|
| 116 |
+
],
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
@classmethod
|
| 120 |
+
def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
|
| 121 |
+
tokens = clip.tokenize(clip_l)
|
| 122 |
+
tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
|
| 123 |
+
|
| 124 |
+
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class Kandinsky5Extension(ComfyExtension):
|
| 128 |
+
@override
|
| 129 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 130 |
+
return [
|
| 131 |
+
Kandinsky5ImageToVideo,
|
| 132 |
+
NormalizeVideoLatentStart,
|
| 133 |
+
CLIPTextEncodeKandinsky5,
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
async def comfy_entrypoint() -> Kandinsky5Extension:
|
| 137 |
+
return Kandinsky5Extension()
|
ComfyUI/comfy_extras/nodes_latent.py
ADDED
|
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import comfy.utils
|
| 2 |
+
import comfy_extras.nodes_post_processing
|
| 3 |
+
import torch
|
| 4 |
+
import nodes
|
| 5 |
+
from typing_extensions import override
|
| 6 |
+
from comfy_api.latest import ComfyExtension, io
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
| 11 |
+
if latent.shape[1:] != target_shape[1:]:
|
| 12 |
+
latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center")
|
| 13 |
+
if repeat_batch:
|
| 14 |
+
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
|
| 15 |
+
else:
|
| 16 |
+
return latent
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LatentAdd(io.ComfyNode):
|
| 20 |
+
@classmethod
|
| 21 |
+
def define_schema(cls):
|
| 22 |
+
return io.Schema(
|
| 23 |
+
node_id="LatentAdd",
|
| 24 |
+
search_aliases=["combine latents", "sum latents"],
|
| 25 |
+
category="latent/advanced",
|
| 26 |
+
inputs=[
|
| 27 |
+
io.Latent.Input("samples1"),
|
| 28 |
+
io.Latent.Input("samples2"),
|
| 29 |
+
],
|
| 30 |
+
outputs=[
|
| 31 |
+
io.Latent.Output(),
|
| 32 |
+
],
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
@classmethod
|
| 36 |
+
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
| 37 |
+
samples_out = samples1.copy()
|
| 38 |
+
|
| 39 |
+
s1 = samples1["samples"]
|
| 40 |
+
s2 = samples2["samples"]
|
| 41 |
+
|
| 42 |
+
s2 = reshape_latent_to(s1.shape, s2)
|
| 43 |
+
samples_out["samples"] = s1 + s2
|
| 44 |
+
return io.NodeOutput(samples_out)
|
| 45 |
+
|
| 46 |
+
class LatentSubtract(io.ComfyNode):
|
| 47 |
+
@classmethod
|
| 48 |
+
def define_schema(cls):
|
| 49 |
+
return io.Schema(
|
| 50 |
+
node_id="LatentSubtract",
|
| 51 |
+
search_aliases=["difference latent", "remove features"],
|
| 52 |
+
category="latent/advanced",
|
| 53 |
+
inputs=[
|
| 54 |
+
io.Latent.Input("samples1"),
|
| 55 |
+
io.Latent.Input("samples2"),
|
| 56 |
+
],
|
| 57 |
+
outputs=[
|
| 58 |
+
io.Latent.Output(),
|
| 59 |
+
],
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
| 64 |
+
samples_out = samples1.copy()
|
| 65 |
+
|
| 66 |
+
s1 = samples1["samples"]
|
| 67 |
+
s2 = samples2["samples"]
|
| 68 |
+
|
| 69 |
+
s2 = reshape_latent_to(s1.shape, s2)
|
| 70 |
+
samples_out["samples"] = s1 - s2
|
| 71 |
+
return io.NodeOutput(samples_out)
|
| 72 |
+
|
| 73 |
+
class LatentMultiply(io.ComfyNode):
|
| 74 |
+
@classmethod
|
| 75 |
+
def define_schema(cls):
|
| 76 |
+
return io.Schema(
|
| 77 |
+
node_id="LatentMultiply",
|
| 78 |
+
search_aliases=["scale latent", "amplify latent", "latent gain"],
|
| 79 |
+
category="latent/advanced",
|
| 80 |
+
inputs=[
|
| 81 |
+
io.Latent.Input("samples"),
|
| 82 |
+
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
|
| 83 |
+
],
|
| 84 |
+
outputs=[
|
| 85 |
+
io.Latent.Output(),
|
| 86 |
+
],
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
@classmethod
|
| 90 |
+
def execute(cls, samples, multiplier) -> io.NodeOutput:
|
| 91 |
+
samples_out = samples.copy()
|
| 92 |
+
|
| 93 |
+
s1 = samples["samples"]
|
| 94 |
+
samples_out["samples"] = s1 * multiplier
|
| 95 |
+
return io.NodeOutput(samples_out)
|
| 96 |
+
|
| 97 |
+
class LatentInterpolate(io.ComfyNode):
|
| 98 |
+
@classmethod
|
| 99 |
+
def define_schema(cls):
|
| 100 |
+
return io.Schema(
|
| 101 |
+
node_id="LatentInterpolate",
|
| 102 |
+
search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
|
| 103 |
+
category="latent/advanced",
|
| 104 |
+
inputs=[
|
| 105 |
+
io.Latent.Input("samples1"),
|
| 106 |
+
io.Latent.Input("samples2"),
|
| 107 |
+
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
|
| 108 |
+
],
|
| 109 |
+
outputs=[
|
| 110 |
+
io.Latent.Output(),
|
| 111 |
+
],
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
@classmethod
|
| 115 |
+
def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
|
| 116 |
+
samples_out = samples1.copy()
|
| 117 |
+
|
| 118 |
+
s1 = samples1["samples"]
|
| 119 |
+
s2 = samples2["samples"]
|
| 120 |
+
|
| 121 |
+
s2 = reshape_latent_to(s1.shape, s2)
|
| 122 |
+
|
| 123 |
+
m1 = torch.linalg.vector_norm(s1, dim=(1))
|
| 124 |
+
m2 = torch.linalg.vector_norm(s2, dim=(1))
|
| 125 |
+
|
| 126 |
+
s1 = torch.nan_to_num(s1 / m1)
|
| 127 |
+
s2 = torch.nan_to_num(s2 / m2)
|
| 128 |
+
|
| 129 |
+
t = (s1 * ratio + s2 * (1.0 - ratio))
|
| 130 |
+
mt = torch.linalg.vector_norm(t, dim=(1))
|
| 131 |
+
st = torch.nan_to_num(t / mt)
|
| 132 |
+
|
| 133 |
+
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
| 134 |
+
return io.NodeOutput(samples_out)
|
| 135 |
+
|
| 136 |
+
class LatentConcat(io.ComfyNode):
|
| 137 |
+
@classmethod
|
| 138 |
+
def define_schema(cls):
|
| 139 |
+
return io.Schema(
|
| 140 |
+
node_id="LatentConcat",
|
| 141 |
+
search_aliases=["join latents", "stitch latents"],
|
| 142 |
+
category="latent/advanced",
|
| 143 |
+
inputs=[
|
| 144 |
+
io.Latent.Input("samples1"),
|
| 145 |
+
io.Latent.Input("samples2"),
|
| 146 |
+
io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
|
| 147 |
+
],
|
| 148 |
+
outputs=[
|
| 149 |
+
io.Latent.Output(),
|
| 150 |
+
],
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
@classmethod
|
| 154 |
+
def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
|
| 155 |
+
samples_out = samples1.copy()
|
| 156 |
+
|
| 157 |
+
s1 = samples1["samples"]
|
| 158 |
+
s2 = samples2["samples"]
|
| 159 |
+
s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0])
|
| 160 |
+
|
| 161 |
+
if "-" in dim:
|
| 162 |
+
c = (s2, s1)
|
| 163 |
+
else:
|
| 164 |
+
c = (s1, s2)
|
| 165 |
+
|
| 166 |
+
if "x" in dim:
|
| 167 |
+
dim = -1
|
| 168 |
+
elif "y" in dim:
|
| 169 |
+
dim = -2
|
| 170 |
+
elif "t" in dim:
|
| 171 |
+
dim = -3
|
| 172 |
+
|
| 173 |
+
samples_out["samples"] = torch.cat(c, dim=dim)
|
| 174 |
+
return io.NodeOutput(samples_out)
|
| 175 |
+
|
| 176 |
+
class LatentCut(io.ComfyNode):
|
| 177 |
+
@classmethod
|
| 178 |
+
def define_schema(cls):
|
| 179 |
+
return io.Schema(
|
| 180 |
+
node_id="LatentCut",
|
| 181 |
+
search_aliases=["crop latent", "slice latent", "extract region"],
|
| 182 |
+
category="latent/advanced",
|
| 183 |
+
inputs=[
|
| 184 |
+
io.Latent.Input("samples"),
|
| 185 |
+
io.Combo.Input("dim", options=["x", "y", "t"]),
|
| 186 |
+
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
|
| 187 |
+
io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
| 188 |
+
],
|
| 189 |
+
outputs=[
|
| 190 |
+
io.Latent.Output(),
|
| 191 |
+
],
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
@classmethod
|
| 195 |
+
def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
|
| 196 |
+
samples_out = samples.copy()
|
| 197 |
+
|
| 198 |
+
s1 = samples["samples"]
|
| 199 |
+
|
| 200 |
+
if "x" in dim:
|
| 201 |
+
dim = s1.ndim - 1
|
| 202 |
+
elif "y" in dim:
|
| 203 |
+
dim = s1.ndim - 2
|
| 204 |
+
elif "t" in dim:
|
| 205 |
+
dim = s1.ndim - 3
|
| 206 |
+
|
| 207 |
+
if index >= 0:
|
| 208 |
+
index = min(index, s1.shape[dim] - 1)
|
| 209 |
+
amount = min(s1.shape[dim] - index, amount)
|
| 210 |
+
else:
|
| 211 |
+
index = max(index, -s1.shape[dim])
|
| 212 |
+
amount = min(-index, amount)
|
| 213 |
+
|
| 214 |
+
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
| 215 |
+
return io.NodeOutput(samples_out)
|
| 216 |
+
|
| 217 |
+
class LatentCutToBatch(io.ComfyNode):
|
| 218 |
+
@classmethod
|
| 219 |
+
def define_schema(cls):
|
| 220 |
+
return io.Schema(
|
| 221 |
+
node_id="LatentCutToBatch",
|
| 222 |
+
search_aliases=["slice to batch", "split latent", "tile latent"],
|
| 223 |
+
category="latent/advanced",
|
| 224 |
+
inputs=[
|
| 225 |
+
io.Latent.Input("samples"),
|
| 226 |
+
io.Combo.Input("dim", options=["t", "x", "y"]),
|
| 227 |
+
io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
| 228 |
+
],
|
| 229 |
+
outputs=[
|
| 230 |
+
io.Latent.Output(),
|
| 231 |
+
],
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
@classmethod
|
| 235 |
+
def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
|
| 236 |
+
samples_out = samples.copy()
|
| 237 |
+
|
| 238 |
+
s1 = samples["samples"]
|
| 239 |
+
|
| 240 |
+
if "x" in dim:
|
| 241 |
+
dim = s1.ndim - 1
|
| 242 |
+
elif "y" in dim:
|
| 243 |
+
dim = s1.ndim - 2
|
| 244 |
+
elif "t" in dim:
|
| 245 |
+
dim = s1.ndim - 3
|
| 246 |
+
|
| 247 |
+
if dim < 2:
|
| 248 |
+
return io.NodeOutput(samples)
|
| 249 |
+
|
| 250 |
+
s = s1.movedim(dim, 1)
|
| 251 |
+
if s.shape[1] < slice_size:
|
| 252 |
+
slice_size = s.shape[1]
|
| 253 |
+
elif s.shape[1] % slice_size != 0:
|
| 254 |
+
s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
|
| 255 |
+
new_shape = [-1, slice_size] + list(s.shape[2:])
|
| 256 |
+
samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
|
| 257 |
+
return io.NodeOutput(samples_out)
|
| 258 |
+
|
| 259 |
+
class LatentBatch(io.ComfyNode):
|
| 260 |
+
@classmethod
|
| 261 |
+
def define_schema(cls):
|
| 262 |
+
return io.Schema(
|
| 263 |
+
node_id="LatentBatch",
|
| 264 |
+
search_aliases=["combine latents", "merge latents", "join latents"],
|
| 265 |
+
category="latent/batch",
|
| 266 |
+
is_deprecated=True,
|
| 267 |
+
inputs=[
|
| 268 |
+
io.Latent.Input("samples1"),
|
| 269 |
+
io.Latent.Input("samples2"),
|
| 270 |
+
],
|
| 271 |
+
outputs=[
|
| 272 |
+
io.Latent.Output(),
|
| 273 |
+
],
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
@classmethod
|
| 277 |
+
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
| 278 |
+
samples_out = samples1.copy()
|
| 279 |
+
s1 = samples1["samples"]
|
| 280 |
+
s2 = samples2["samples"]
|
| 281 |
+
|
| 282 |
+
s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
|
| 283 |
+
s = torch.cat((s1, s2), dim=0)
|
| 284 |
+
samples_out["samples"] = s
|
| 285 |
+
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
| 286 |
+
return io.NodeOutput(samples_out)
|
| 287 |
+
|
| 288 |
+
class LatentBatchSeedBehavior(io.ComfyNode):
|
| 289 |
+
@classmethod
|
| 290 |
+
def define_schema(cls):
|
| 291 |
+
return io.Schema(
|
| 292 |
+
node_id="LatentBatchSeedBehavior",
|
| 293 |
+
category="latent/advanced",
|
| 294 |
+
inputs=[
|
| 295 |
+
io.Latent.Input("samples"),
|
| 296 |
+
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
|
| 297 |
+
],
|
| 298 |
+
outputs=[
|
| 299 |
+
io.Latent.Output(),
|
| 300 |
+
],
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
@classmethod
|
| 304 |
+
def execute(cls, samples, seed_behavior) -> io.NodeOutput:
|
| 305 |
+
samples_out = samples.copy()
|
| 306 |
+
latent = samples["samples"]
|
| 307 |
+
if seed_behavior == "random":
|
| 308 |
+
if 'batch_index' in samples_out:
|
| 309 |
+
samples_out.pop('batch_index')
|
| 310 |
+
elif seed_behavior == "fixed":
|
| 311 |
+
batch_number = samples_out.get("batch_index", [0])[0]
|
| 312 |
+
samples_out["batch_index"] = [batch_number] * latent.shape[0]
|
| 313 |
+
|
| 314 |
+
return io.NodeOutput(samples_out)
|
| 315 |
+
|
| 316 |
+
class LatentApplyOperation(io.ComfyNode):
|
| 317 |
+
@classmethod
|
| 318 |
+
def define_schema(cls):
|
| 319 |
+
return io.Schema(
|
| 320 |
+
node_id="LatentApplyOperation",
|
| 321 |
+
search_aliases=["transform latent"],
|
| 322 |
+
category="latent/advanced/operations",
|
| 323 |
+
is_experimental=True,
|
| 324 |
+
inputs=[
|
| 325 |
+
io.Latent.Input("samples"),
|
| 326 |
+
io.LatentOperation.Input("operation"),
|
| 327 |
+
],
|
| 328 |
+
outputs=[
|
| 329 |
+
io.Latent.Output(),
|
| 330 |
+
],
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
@classmethod
|
| 334 |
+
def execute(cls, samples, operation) -> io.NodeOutput:
|
| 335 |
+
samples_out = samples.copy()
|
| 336 |
+
|
| 337 |
+
s1 = samples["samples"]
|
| 338 |
+
samples_out["samples"] = operation(latent=s1)
|
| 339 |
+
return io.NodeOutput(samples_out)
|
| 340 |
+
|
| 341 |
+
class LatentApplyOperationCFG(io.ComfyNode):
|
| 342 |
+
@classmethod
|
| 343 |
+
def define_schema(cls):
|
| 344 |
+
return io.Schema(
|
| 345 |
+
node_id="LatentApplyOperationCFG",
|
| 346 |
+
category="latent/advanced/operations",
|
| 347 |
+
is_experimental=True,
|
| 348 |
+
inputs=[
|
| 349 |
+
io.Model.Input("model"),
|
| 350 |
+
io.LatentOperation.Input("operation"),
|
| 351 |
+
],
|
| 352 |
+
outputs=[
|
| 353 |
+
io.Model.Output(),
|
| 354 |
+
],
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
@classmethod
|
| 358 |
+
def execute(cls, model, operation) -> io.NodeOutput:
|
| 359 |
+
m = model.clone()
|
| 360 |
+
|
| 361 |
+
def pre_cfg_function(args):
|
| 362 |
+
conds_out = args["conds_out"]
|
| 363 |
+
if len(conds_out) == 2:
|
| 364 |
+
conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
|
| 365 |
+
else:
|
| 366 |
+
conds_out[0] = operation(latent=conds_out[0])
|
| 367 |
+
return conds_out
|
| 368 |
+
|
| 369 |
+
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
| 370 |
+
return io.NodeOutput(m)
|
| 371 |
+
|
| 372 |
+
class LatentOperationTonemapReinhard(io.ComfyNode):
|
| 373 |
+
@classmethod
|
| 374 |
+
def define_schema(cls):
|
| 375 |
+
return io.Schema(
|
| 376 |
+
node_id="LatentOperationTonemapReinhard",
|
| 377 |
+
search_aliases=["hdr latent"],
|
| 378 |
+
category="latent/advanced/operations",
|
| 379 |
+
is_experimental=True,
|
| 380 |
+
inputs=[
|
| 381 |
+
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
|
| 382 |
+
],
|
| 383 |
+
outputs=[
|
| 384 |
+
io.LatentOperation.Output(),
|
| 385 |
+
],
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
@classmethod
|
| 389 |
+
def execute(cls, multiplier) -> io.NodeOutput:
|
| 390 |
+
def tonemap_reinhard(latent, **kwargs):
|
| 391 |
+
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
| 392 |
+
normalized_latent = latent / latent_vector_magnitude
|
| 393 |
+
|
| 394 |
+
dims = list(range(1, latent_vector_magnitude.ndim))
|
| 395 |
+
mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
|
| 396 |
+
std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
|
| 397 |
+
|
| 398 |
+
top = (std * 5 + mean) * multiplier
|
| 399 |
+
|
| 400 |
+
#reinhard
|
| 401 |
+
latent_vector_magnitude *= (1.0 / top)
|
| 402 |
+
new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
|
| 403 |
+
new_magnitude *= top
|
| 404 |
+
|
| 405 |
+
return normalized_latent * new_magnitude
|
| 406 |
+
return io.NodeOutput(tonemap_reinhard)
|
| 407 |
+
|
| 408 |
+
class LatentOperationSharpen(io.ComfyNode):
|
| 409 |
+
@classmethod
|
| 410 |
+
def define_schema(cls):
|
| 411 |
+
return io.Schema(
|
| 412 |
+
node_id="LatentOperationSharpen",
|
| 413 |
+
category="latent/advanced/operations",
|
| 414 |
+
is_experimental=True,
|
| 415 |
+
inputs=[
|
| 416 |
+
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1, advanced=True),
|
| 417 |
+
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1, advanced=True),
|
| 418 |
+
io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01, advanced=True),
|
| 419 |
+
],
|
| 420 |
+
outputs=[
|
| 421 |
+
io.LatentOperation.Output(),
|
| 422 |
+
],
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
@classmethod
|
| 426 |
+
def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
|
| 427 |
+
def sharpen(latent, **kwargs):
|
| 428 |
+
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
|
| 429 |
+
normalized_latent = latent / luminance
|
| 430 |
+
channels = latent.shape[1]
|
| 431 |
+
|
| 432 |
+
kernel_size = sharpen_radius * 2 + 1
|
| 433 |
+
kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
|
| 434 |
+
center = kernel_size // 2
|
| 435 |
+
|
| 436 |
+
kernel *= alpha * -10
|
| 437 |
+
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
| 438 |
+
|
| 439 |
+
padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
|
| 440 |
+
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
| 441 |
+
|
| 442 |
+
return luminance * sharpened
|
| 443 |
+
return io.NodeOutput(sharpen)
|
| 444 |
+
|
| 445 |
+
class ReplaceVideoLatentFrames(io.ComfyNode):
|
| 446 |
+
@classmethod
|
| 447 |
+
def define_schema(cls):
|
| 448 |
+
return io.Schema(
|
| 449 |
+
node_id="ReplaceVideoLatentFrames",
|
| 450 |
+
category="latent/batch",
|
| 451 |
+
inputs=[
|
| 452 |
+
io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
|
| 453 |
+
io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
|
| 454 |
+
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
|
| 455 |
+
],
|
| 456 |
+
outputs=[
|
| 457 |
+
io.Latent.Output(),
|
| 458 |
+
],
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
@classmethod
|
| 462 |
+
def execute(cls, destination, index, source=None) -> io.NodeOutput:
|
| 463 |
+
if source is None:
|
| 464 |
+
return io.NodeOutput(destination)
|
| 465 |
+
dest_frames = destination["samples"].shape[2]
|
| 466 |
+
source_frames = source["samples"].shape[2]
|
| 467 |
+
if index < 0:
|
| 468 |
+
index = dest_frames + index
|
| 469 |
+
if index > dest_frames:
|
| 470 |
+
logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
|
| 471 |
+
return io.NodeOutput(destination)
|
| 472 |
+
if index + source_frames > dest_frames:
|
| 473 |
+
logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
|
| 474 |
+
return io.NodeOutput(destination)
|
| 475 |
+
s = source.copy()
|
| 476 |
+
s_source = source["samples"]
|
| 477 |
+
s_destination = destination["samples"].clone()
|
| 478 |
+
s_destination[:, :, index:index + s_source.shape[2]] = s_source
|
| 479 |
+
s["samples"] = s_destination
|
| 480 |
+
return io.NodeOutput(s)
|
| 481 |
+
|
| 482 |
+
class LatentExtension(ComfyExtension):
|
| 483 |
+
@override
|
| 484 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 485 |
+
return [
|
| 486 |
+
LatentAdd,
|
| 487 |
+
LatentSubtract,
|
| 488 |
+
LatentMultiply,
|
| 489 |
+
LatentInterpolate,
|
| 490 |
+
LatentConcat,
|
| 491 |
+
LatentCut,
|
| 492 |
+
LatentCutToBatch,
|
| 493 |
+
LatentBatch,
|
| 494 |
+
LatentBatchSeedBehavior,
|
| 495 |
+
LatentApplyOperation,
|
| 496 |
+
LatentApplyOperationCFG,
|
| 497 |
+
LatentOperationTonemapReinhard,
|
| 498 |
+
LatentOperationSharpen,
|
| 499 |
+
ReplaceVideoLatentFrames
|
| 500 |
+
]
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
async def comfy_entrypoint() -> LatentExtension:
|
| 504 |
+
return LatentExtension()
|
ComfyUI/comfy_extras/nodes_load_3d.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nodes
|
| 2 |
+
import folder_paths
|
| 3 |
+
import os
|
| 4 |
+
import uuid
|
| 5 |
+
|
| 6 |
+
from typing_extensions import override
|
| 7 |
+
from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def normalize_path(path):
|
| 13 |
+
return path.replace('\\', '/')
|
| 14 |
+
|
| 15 |
+
class Load3D(IO.ComfyNode):
|
| 16 |
+
@classmethod
|
| 17 |
+
def define_schema(cls):
|
| 18 |
+
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
|
| 19 |
+
|
| 20 |
+
os.makedirs(input_dir, exist_ok=True)
|
| 21 |
+
|
| 22 |
+
input_path = Path(input_dir)
|
| 23 |
+
base_path = Path(folder_paths.get_input_directory())
|
| 24 |
+
|
| 25 |
+
files = [
|
| 26 |
+
normalize_path(str(file_path.relative_to(base_path)))
|
| 27 |
+
for file_path in input_path.rglob("*")
|
| 28 |
+
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'}
|
| 29 |
+
]
|
| 30 |
+
return IO.Schema(
|
| 31 |
+
node_id="Load3D",
|
| 32 |
+
display_name="Load 3D & Animation",
|
| 33 |
+
category="3d",
|
| 34 |
+
essentials_category="Basics",
|
| 35 |
+
is_experimental=True,
|
| 36 |
+
inputs=[
|
| 37 |
+
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
|
| 38 |
+
IO.Load3D.Input("image"),
|
| 39 |
+
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
| 40 |
+
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
| 41 |
+
],
|
| 42 |
+
outputs=[
|
| 43 |
+
IO.Image.Output(display_name="image"),
|
| 44 |
+
IO.Mask.Output(display_name="mask"),
|
| 45 |
+
IO.String.Output(display_name="mesh_path"),
|
| 46 |
+
IO.Image.Output(display_name="normal"),
|
| 47 |
+
IO.Load3DCamera.Output(display_name="camera_info"),
|
| 48 |
+
IO.Video.Output(display_name="recording_video"),
|
| 49 |
+
IO.File3DAny.Output(display_name="model_3d"),
|
| 50 |
+
],
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
@classmethod
|
| 54 |
+
def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
|
| 55 |
+
image_path = folder_paths.get_annotated_filepath(image['image'])
|
| 56 |
+
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
| 57 |
+
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
| 58 |
+
|
| 59 |
+
load_image_node = nodes.LoadImage()
|
| 60 |
+
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
| 61 |
+
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
| 62 |
+
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
| 63 |
+
|
| 64 |
+
video = None
|
| 65 |
+
|
| 66 |
+
if image['recording'] != "":
|
| 67 |
+
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
|
| 68 |
+
|
| 69 |
+
video = InputImpl.VideoFromFile(recording_video_path)
|
| 70 |
+
|
| 71 |
+
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
|
| 72 |
+
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
|
| 73 |
+
|
| 74 |
+
process = execute # TODO: remove
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Preview3D(IO.ComfyNode):
|
| 78 |
+
@classmethod
|
| 79 |
+
def define_schema(cls):
|
| 80 |
+
return IO.Schema(
|
| 81 |
+
node_id="Preview3D",
|
| 82 |
+
search_aliases=["view mesh", "3d viewer"],
|
| 83 |
+
display_name="Preview 3D & Animation",
|
| 84 |
+
category="3d",
|
| 85 |
+
is_experimental=True,
|
| 86 |
+
is_output_node=True,
|
| 87 |
+
inputs=[
|
| 88 |
+
IO.MultiType.Input(
|
| 89 |
+
IO.String.Input("model_file", default="", multiline=False),
|
| 90 |
+
types=[
|
| 91 |
+
IO.File3DGLB,
|
| 92 |
+
IO.File3DGLTF,
|
| 93 |
+
IO.File3DFBX,
|
| 94 |
+
IO.File3DOBJ,
|
| 95 |
+
IO.File3DSTL,
|
| 96 |
+
IO.File3DUSDZ,
|
| 97 |
+
IO.File3DAny,
|
| 98 |
+
],
|
| 99 |
+
tooltip="3D model file or path string",
|
| 100 |
+
),
|
| 101 |
+
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
| 102 |
+
IO.Image.Input("bg_image", optional=True, advanced=True),
|
| 103 |
+
],
|
| 104 |
+
outputs=[],
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
@classmethod
|
| 108 |
+
def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput:
|
| 109 |
+
if isinstance(model_file, Types.File3D):
|
| 110 |
+
filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
|
| 111 |
+
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
|
| 112 |
+
else:
|
| 113 |
+
filename = model_file
|
| 114 |
+
camera_info = kwargs.get("camera_info", None)
|
| 115 |
+
bg_image = kwargs.get("bg_image", None)
|
| 116 |
+
return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image))
|
| 117 |
+
|
| 118 |
+
process = execute # TODO: remove
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class Load3DExtension(ComfyExtension):
|
| 122 |
+
@override
|
| 123 |
+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
| 124 |
+
return [
|
| 125 |
+
Load3D,
|
| 126 |
+
Preview3D,
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
async def comfy_entrypoint() -> Load3DExtension:
|
| 131 |
+
return Load3DExtension()
|
ComfyUI/comfy_extras/nodes_logic.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import TypedDict
|
| 3 |
+
from typing_extensions import override
|
| 4 |
+
from comfy_api.latest import ComfyExtension, io
|
| 5 |
+
from comfy_api.latest import _io
|
| 6 |
+
|
| 7 |
+
# sentinel for missing inputs
|
| 8 |
+
MISSING = object()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SwitchNode(io.ComfyNode):
|
| 12 |
+
@classmethod
|
| 13 |
+
def define_schema(cls):
|
| 14 |
+
template = io.MatchType.Template("switch")
|
| 15 |
+
return io.Schema(
|
| 16 |
+
node_id="ComfySwitchNode",
|
| 17 |
+
display_name="Switch",
|
| 18 |
+
category="logic",
|
| 19 |
+
is_experimental=True,
|
| 20 |
+
inputs=[
|
| 21 |
+
io.Boolean.Input("switch"),
|
| 22 |
+
io.MatchType.Input("on_false", template=template, lazy=True),
|
| 23 |
+
io.MatchType.Input("on_true", template=template, lazy=True),
|
| 24 |
+
],
|
| 25 |
+
outputs=[
|
| 26 |
+
io.MatchType.Output(template=template, display_name="output"),
|
| 27 |
+
],
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
@classmethod
|
| 31 |
+
def check_lazy_status(cls, switch, on_false=None, on_true=None):
|
| 32 |
+
if switch and on_true is None:
|
| 33 |
+
return ["on_true"]
|
| 34 |
+
if not switch and on_false is None:
|
| 35 |
+
return ["on_false"]
|
| 36 |
+
|
| 37 |
+
@classmethod
|
| 38 |
+
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
|
| 39 |
+
return io.NodeOutput(on_true if switch else on_false)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SoftSwitchNode(io.ComfyNode):
|
| 43 |
+
@classmethod
|
| 44 |
+
def define_schema(cls):
|
| 45 |
+
template = io.MatchType.Template("switch")
|
| 46 |
+
return io.Schema(
|
| 47 |
+
node_id="ComfySoftSwitchNode",
|
| 48 |
+
display_name="Soft Switch",
|
| 49 |
+
category="logic",
|
| 50 |
+
is_experimental=True,
|
| 51 |
+
inputs=[
|
| 52 |
+
io.Boolean.Input("switch"),
|
| 53 |
+
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
|
| 54 |
+
io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
|
| 55 |
+
],
|
| 56 |
+
outputs=[
|
| 57 |
+
io.MatchType.Output(template=template, display_name="output"),
|
| 58 |
+
],
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
@classmethod
|
| 62 |
+
def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
|
| 63 |
+
# We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs.
|
| 64 |
+
# This trick allows us to ignore the value of the switch and still be able to run execute().
|
| 65 |
+
|
| 66 |
+
# One of the inputs may be missing, in which case we need to evaluate the other input
|
| 67 |
+
if on_false is MISSING:
|
| 68 |
+
return ["on_true"]
|
| 69 |
+
if on_true is MISSING:
|
| 70 |
+
return ["on_false"]
|
| 71 |
+
# Normal lazy switch operation
|
| 72 |
+
if switch and on_true is None:
|
| 73 |
+
return ["on_true"]
|
| 74 |
+
if not switch and on_false is None:
|
| 75 |
+
return ["on_false"]
|
| 76 |
+
|
| 77 |
+
@classmethod
|
| 78 |
+
def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
|
| 79 |
+
# This check happens before check_lazy_status(), so we can eliminate the case where
|
| 80 |
+
# both inputs are missing.
|
| 81 |
+
if on_false is MISSING and on_true is MISSING:
|
| 82 |
+
return "At least one of on_false or on_true must be connected to Switch node"
|
| 83 |
+
return True
|
| 84 |
+
|
| 85 |
+
@classmethod
|
| 86 |
+
def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
|
| 87 |
+
if on_true is MISSING:
|
| 88 |
+
return io.NodeOutput(on_false)
|
| 89 |
+
if on_false is MISSING:
|
| 90 |
+
return io.NodeOutput(on_true)
|
| 91 |
+
return io.NodeOutput(on_true if switch else on_false)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class CustomComboNode(io.ComfyNode):
|
| 95 |
+
"""
|
| 96 |
+
Frontend node that allows user to write their own options for a combo.
|
| 97 |
+
This is here to make sure the node has a backend-representation to avoid some annoyances.
|
| 98 |
+
"""
|
| 99 |
+
@classmethod
|
| 100 |
+
def define_schema(cls):
|
| 101 |
+
return io.Schema(
|
| 102 |
+
node_id="CustomCombo",
|
| 103 |
+
display_name="Custom Combo",
|
| 104 |
+
category="utils",
|
| 105 |
+
is_experimental=True,
|
| 106 |
+
inputs=[io.Combo.Input("choice", options=[])],
|
| 107 |
+
outputs=[
|
| 108 |
+
io.String.Output(display_name="STRING"),
|
| 109 |
+
io.Int.Output(display_name="INDEX"),
|
| 110 |
+
],
|
| 111 |
+
accept_all_inputs=True,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
@classmethod
|
| 115 |
+
def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool:
|
| 116 |
+
# NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
|
| 117 |
+
# I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
|
| 118 |
+
# I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
|
| 119 |
+
return True
|
| 120 |
+
|
| 121 |
+
@classmethod
|
| 122 |
+
def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput:
|
| 123 |
+
return io.NodeOutput(choice, index)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class DCTestNode(io.ComfyNode):
|
| 127 |
+
class DCValues(TypedDict):
|
| 128 |
+
combo: str
|
| 129 |
+
string: str
|
| 130 |
+
integer: int
|
| 131 |
+
image: io.Image.Type
|
| 132 |
+
subcombo: dict[str]
|
| 133 |
+
|
| 134 |
+
@classmethod
|
| 135 |
+
def define_schema(cls):
|
| 136 |
+
return io.Schema(
|
| 137 |
+
node_id="DCTestNode",
|
| 138 |
+
display_name="DCTest",
|
| 139 |
+
category="logic",
|
| 140 |
+
is_output_node=True,
|
| 141 |
+
inputs=[io.DynamicCombo.Input("combo", options=[
|
| 142 |
+
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
| 143 |
+
io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
|
| 144 |
+
io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
|
| 145 |
+
io.DynamicCombo.Option("option4", [
|
| 146 |
+
io.DynamicCombo.Input("subcombo", options=[
|
| 147 |
+
io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
|
| 148 |
+
io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
|
| 149 |
+
])
|
| 150 |
+
])]
|
| 151 |
+
)],
|
| 152 |
+
outputs=[io.AnyType.Output()],
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
@classmethod
|
| 156 |
+
def execute(cls, combo: DCValues) -> io.NodeOutput:
|
| 157 |
+
combo_val = combo["combo"]
|
| 158 |
+
if combo_val == "option1":
|
| 159 |
+
return io.NodeOutput(combo["string"])
|
| 160 |
+
elif combo_val == "option2":
|
| 161 |
+
return io.NodeOutput(combo["integer"])
|
| 162 |
+
elif combo_val == "option3":
|
| 163 |
+
return io.NodeOutput(combo["image"])
|
| 164 |
+
elif combo_val == "option4":
|
| 165 |
+
return io.NodeOutput(f"{combo['subcombo']}")
|
| 166 |
+
else:
|
| 167 |
+
raise ValueError(f"Invalid combo: {combo_val}")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class AutogrowNamesTestNode(io.ComfyNode):
|
| 171 |
+
@classmethod
|
| 172 |
+
def define_schema(cls):
|
| 173 |
+
template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
|
| 174 |
+
return io.Schema(
|
| 175 |
+
node_id="AutogrowNamesTestNode",
|
| 176 |
+
display_name="AutogrowNamesTest",
|
| 177 |
+
category="logic",
|
| 178 |
+
inputs=[
|
| 179 |
+
_io.Autogrow.Input("autogrow", template=template)
|
| 180 |
+
],
|
| 181 |
+
outputs=[io.String.Output()],
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
@classmethod
|
| 185 |
+
def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
|
| 186 |
+
vals = list(autogrow.values())
|
| 187 |
+
combined = ",".join([str(x) for x in vals])
|
| 188 |
+
return io.NodeOutput(combined)
|
| 189 |
+
|
| 190 |
+
class AutogrowPrefixTestNode(io.ComfyNode):
|
| 191 |
+
@classmethod
|
| 192 |
+
def define_schema(cls):
|
| 193 |
+
template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
|
| 194 |
+
return io.Schema(
|
| 195 |
+
node_id="AutogrowPrefixTestNode",
|
| 196 |
+
display_name="AutogrowPrefixTest",
|
| 197 |
+
category="logic",
|
| 198 |
+
inputs=[
|
| 199 |
+
_io.Autogrow.Input("autogrow", template=template)
|
| 200 |
+
],
|
| 201 |
+
outputs=[io.String.Output()],
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
@classmethod
|
| 205 |
+
def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
|
| 206 |
+
vals = list(autogrow.values())
|
| 207 |
+
combined = ",".join([str(x) for x in vals])
|
| 208 |
+
return io.NodeOutput(combined)
|
| 209 |
+
|
| 210 |
+
class ComboOutputTestNode(io.ComfyNode):
|
| 211 |
+
@classmethod
|
| 212 |
+
def define_schema(cls):
|
| 213 |
+
return io.Schema(
|
| 214 |
+
node_id="ComboOptionTestNode",
|
| 215 |
+
display_name="ComboOptionTest",
|
| 216 |
+
category="logic",
|
| 217 |
+
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
|
| 218 |
+
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
|
| 219 |
+
outputs=[io.Combo.Output(), io.Combo.Output()],
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
@classmethod
|
| 223 |
+
def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
|
| 224 |
+
return io.NodeOutput(combo, combo2)
|
| 225 |
+
|
| 226 |
+
class ConvertStringToComboNode(io.ComfyNode):
|
| 227 |
+
@classmethod
|
| 228 |
+
def define_schema(cls):
|
| 229 |
+
return io.Schema(
|
| 230 |
+
node_id="ConvertStringToComboNode",
|
| 231 |
+
search_aliases=["string to dropdown", "text to combo"],
|
| 232 |
+
display_name="Convert String to Combo",
|
| 233 |
+
category="logic",
|
| 234 |
+
inputs=[io.String.Input("string")],
|
| 235 |
+
outputs=[io.Combo.Output()],
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
@classmethod
|
| 239 |
+
def execute(cls, string: str) -> io.NodeOutput:
|
| 240 |
+
return io.NodeOutput(string)
|
| 241 |
+
|
| 242 |
+
class InvertBooleanNode(io.ComfyNode):
|
| 243 |
+
@classmethod
|
| 244 |
+
def define_schema(cls):
|
| 245 |
+
return io.Schema(
|
| 246 |
+
node_id="InvertBooleanNode",
|
| 247 |
+
search_aliases=["not", "toggle", "negate", "flip boolean"],
|
| 248 |
+
display_name="Invert Boolean",
|
| 249 |
+
category="logic",
|
| 250 |
+
inputs=[io.Boolean.Input("boolean")],
|
| 251 |
+
outputs=[io.Boolean.Output()],
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
@classmethod
|
| 255 |
+
def execute(cls, boolean: bool) -> io.NodeOutput:
|
| 256 |
+
return io.NodeOutput(not boolean)
|
| 257 |
+
|
| 258 |
+
class LogicExtension(ComfyExtension):
|
| 259 |
+
@override
|
| 260 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 261 |
+
return [
|
| 262 |
+
SwitchNode,
|
| 263 |
+
CustomComboNode,
|
| 264 |
+
# SoftSwitchNode,
|
| 265 |
+
# ConvertStringToComboNode,
|
| 266 |
+
# DCTestNode,
|
| 267 |
+
# AutogrowNamesTestNode,
|
| 268 |
+
# AutogrowPrefixTestNode,
|
| 269 |
+
# ComboOutputTestNode,
|
| 270 |
+
# InvertBooleanNode,
|
| 271 |
+
]
|
| 272 |
+
|
| 273 |
+
async def comfy_entrypoint() -> LogicExtension:
|
| 274 |
+
return LogicExtension()
|
ComfyUI/comfy_extras/nodes_lora_debug.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import folder_paths
|
| 2 |
+
import comfy.utils
|
| 3 |
+
import comfy.sd
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LoraLoaderBypass:
|
| 7 |
+
"""
|
| 8 |
+
Apply LoRA in bypass mode without modifying base model weights.
|
| 9 |
+
|
| 10 |
+
Bypass mode computes: output = base_forward(x) + lora_path(x)
|
| 11 |
+
This is useful for training and when model weights are offloaded.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.loaded_lora = None
|
| 16 |
+
|
| 17 |
+
@classmethod
|
| 18 |
+
def INPUT_TYPES(s):
|
| 19 |
+
return {
|
| 20 |
+
"required": {
|
| 21 |
+
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
|
| 22 |
+
"clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
|
| 23 |
+
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
|
| 24 |
+
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
|
| 25 |
+
"strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
|
| 26 |
+
}
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
RETURN_TYPES = ("MODEL", "CLIP")
|
| 30 |
+
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
|
| 31 |
+
FUNCTION = "load_lora"
|
| 32 |
+
|
| 33 |
+
CATEGORY = "loaders"
|
| 34 |
+
DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
|
| 35 |
+
EXPERIMENTAL = True
|
| 36 |
+
|
| 37 |
+
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
| 38 |
+
if strength_model == 0 and strength_clip == 0:
|
| 39 |
+
return (model, clip)
|
| 40 |
+
|
| 41 |
+
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
|
| 42 |
+
lora = None
|
| 43 |
+
if self.loaded_lora is not None:
|
| 44 |
+
if self.loaded_lora[0] == lora_path:
|
| 45 |
+
lora = self.loaded_lora[1]
|
| 46 |
+
else:
|
| 47 |
+
self.loaded_lora = None
|
| 48 |
+
|
| 49 |
+
if lora is None:
|
| 50 |
+
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
| 51 |
+
self.loaded_lora = (lora_path, lora)
|
| 52 |
+
|
| 53 |
+
model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
|
| 54 |
+
return (model_lora, clip_lora)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class LoraLoaderBypassModelOnly(LoraLoaderBypass):
|
| 58 |
+
@classmethod
|
| 59 |
+
def INPUT_TYPES(s):
|
| 60 |
+
return {"required": { "model": ("MODEL",),
|
| 61 |
+
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
| 62 |
+
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
|
| 63 |
+
}}
|
| 64 |
+
RETURN_TYPES = ("MODEL",)
|
| 65 |
+
FUNCTION = "load_lora_model_only"
|
| 66 |
+
|
| 67 |
+
def load_lora_model_only(self, model, lora_name, strength_model):
|
| 68 |
+
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
NODE_CLASS_MAPPINGS = {
|
| 72 |
+
"LoraLoaderBypass": LoraLoaderBypass,
|
| 73 |
+
"LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 77 |
+
"LoraLoaderBypass": "Load LoRA (Bypass) (For debugging)",
|
| 78 |
+
"LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only) (for debugging)",
|
| 79 |
+
}
|
ComfyUI/comfy_extras/nodes_lora_extract.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import comfy.model_management
|
| 3 |
+
import comfy.utils
|
| 4 |
+
import folder_paths
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from typing_extensions import override
|
| 9 |
+
from comfy_api.latest import ComfyExtension, io
|
| 10 |
+
from tqdm.auto import trange
|
| 11 |
+
|
| 12 |
+
CLAMP_QUANTILE = 0.99
|
| 13 |
+
|
| 14 |
+
def extract_lora(diff, rank):
|
| 15 |
+
conv2d = (len(diff.shape) == 4)
|
| 16 |
+
kernel_size = None if not conv2d else diff.size()[2:4]
|
| 17 |
+
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
| 18 |
+
out_dim, in_dim = diff.size()[0:2]
|
| 19 |
+
rank = min(rank, in_dim, out_dim)
|
| 20 |
+
|
| 21 |
+
if conv2d:
|
| 22 |
+
if conv2d_3x3:
|
| 23 |
+
diff = diff.flatten(start_dim=1)
|
| 24 |
+
else:
|
| 25 |
+
diff = diff.squeeze()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
U, S, Vh = torch.linalg.svd(diff.float())
|
| 29 |
+
U = U[:, :rank]
|
| 30 |
+
S = S[:rank]
|
| 31 |
+
U = U @ torch.diag(S)
|
| 32 |
+
Vh = Vh[:rank, :]
|
| 33 |
+
|
| 34 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
| 35 |
+
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
| 36 |
+
low_val = -hi_val
|
| 37 |
+
|
| 38 |
+
U = U.clamp(low_val, hi_val)
|
| 39 |
+
Vh = Vh.clamp(low_val, hi_val)
|
| 40 |
+
if conv2d:
|
| 41 |
+
U = U.reshape(out_dim, rank, 1, 1)
|
| 42 |
+
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
| 43 |
+
return (U, Vh)
|
| 44 |
+
|
| 45 |
+
class LORAType(Enum):
|
| 46 |
+
STANDARD = 0
|
| 47 |
+
FULL_DIFF = 1
|
| 48 |
+
|
| 49 |
+
LORA_TYPES = {"standard": LORAType.STANDARD,
|
| 50 |
+
"full_diff": LORAType.FULL_DIFF}
|
| 51 |
+
|
| 52 |
+
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
|
| 53 |
+
comfy.model_management.load_models_gpu([model_diff])
|
| 54 |
+
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
|
| 55 |
+
|
| 56 |
+
sd_keys = list(sd.keys())
|
| 57 |
+
for index in trange(len(sd_keys), unit="weight"):
|
| 58 |
+
k = sd_keys[index]
|
| 59 |
+
op_keys = sd_keys[index].rsplit('.', 1)
|
| 60 |
+
if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
|
| 61 |
+
continue
|
| 62 |
+
op = comfy.utils.get_attr(model_diff.model, op_keys[0])
|
| 63 |
+
if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
|
| 64 |
+
weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
|
| 65 |
+
else:
|
| 66 |
+
weight_diff = sd[k]
|
| 67 |
+
|
| 68 |
+
if op_keys[1] == "weight":
|
| 69 |
+
if lora_type == LORAType.STANDARD:
|
| 70 |
+
if weight_diff.ndim < 2:
|
| 71 |
+
if bias_diff:
|
| 72 |
+
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
| 73 |
+
continue
|
| 74 |
+
try:
|
| 75 |
+
out = extract_lora(weight_diff, rank)
|
| 76 |
+
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
|
| 77 |
+
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
|
| 78 |
+
except:
|
| 79 |
+
logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
|
| 80 |
+
elif lora_type == LORAType.FULL_DIFF:
|
| 81 |
+
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
| 82 |
+
|
| 83 |
+
elif bias_diff and op_keys[1] == "bias":
|
| 84 |
+
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
|
| 85 |
+
return output_sd
|
| 86 |
+
|
| 87 |
+
class LoraSave(io.ComfyNode):
|
| 88 |
+
@classmethod
|
| 89 |
+
def define_schema(cls):
|
| 90 |
+
return io.Schema(
|
| 91 |
+
node_id="LoraSave",
|
| 92 |
+
search_aliases=["export lora"],
|
| 93 |
+
display_name="Extract and Save Lora",
|
| 94 |
+
category="_for_testing",
|
| 95 |
+
inputs=[
|
| 96 |
+
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
| 97 |
+
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
|
| 98 |
+
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys()), advanced=True),
|
| 99 |
+
io.Boolean.Input("bias_diff", default=True, advanced=True),
|
| 100 |
+
io.Model.Input(
|
| 101 |
+
"model_diff",
|
| 102 |
+
tooltip="The ModelSubtract output to be converted to a lora.",
|
| 103 |
+
optional=True,
|
| 104 |
+
),
|
| 105 |
+
io.Clip.Input(
|
| 106 |
+
"text_encoder_diff",
|
| 107 |
+
tooltip="The CLIPSubtract output to be converted to a lora.",
|
| 108 |
+
optional=True,
|
| 109 |
+
),
|
| 110 |
+
],
|
| 111 |
+
is_experimental=True,
|
| 112 |
+
is_output_node=True,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
@classmethod
|
| 116 |
+
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
|
| 117 |
+
if model_diff is None and text_encoder_diff is None:
|
| 118 |
+
return io.NodeOutput()
|
| 119 |
+
|
| 120 |
+
lora_type = LORA_TYPES.get(lora_type)
|
| 121 |
+
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
| 122 |
+
|
| 123 |
+
output_sd = {}
|
| 124 |
+
if model_diff is not None:
|
| 125 |
+
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
|
| 126 |
+
if text_encoder_diff is not None:
|
| 127 |
+
output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
|
| 128 |
+
|
| 129 |
+
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
| 130 |
+
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
| 131 |
+
|
| 132 |
+
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
| 133 |
+
return io.NodeOutput()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class LoraSaveExtension(ComfyExtension):
|
| 137 |
+
@override
|
| 138 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 139 |
+
return [
|
| 140 |
+
LoraSave,
|
| 141 |
+
]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
async def comfy_entrypoint() -> LoraSaveExtension:
|
| 145 |
+
return LoraSaveExtension()
|
ComfyUI/comfy_extras/nodes_lotus.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import override
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import comfy.model_management as mm
|
| 5 |
+
from comfy_api.latest import ComfyExtension, io
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class LotusConditioning(io.ComfyNode):
|
| 9 |
+
@classmethod
|
| 10 |
+
def define_schema(cls):
|
| 11 |
+
return io.Schema(
|
| 12 |
+
node_id="LotusConditioning",
|
| 13 |
+
category="conditioning/lotus",
|
| 14 |
+
inputs=[],
|
| 15 |
+
outputs=[io.Conditioning.Output(display_name="conditioning")],
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
@classmethod
|
| 19 |
+
def execute(cls) -> io.NodeOutput:
|
| 20 |
+
device = mm.get_torch_device()
|
| 21 |
+
#lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change
|
| 22 |
+
#and getting parity with the reference implementation would otherwise require inference and 800mb of tensors
|
| 23 |
+
prompt_embeds = torch.tensor([[[-0.3134765625, -0.447509765625, -0.00823974609375, -0.22802734375, 0.1785888671875, -0.2342529296875, -0.2188720703125, -0.0089111328125, -0.31396484375, 0.196533203125, -0.055877685546875, -0.3828125, -0.0965576171875, 0.0073394775390625, -0.284423828125, 0.07470703125, -0.086181640625, -0.211181640625, 0.0599365234375, 0.10693359375, 0.0007929801940917969, -0.78076171875, -0.382568359375, -0.1851806640625, -0.140625, -0.0936279296875, -0.1229248046875, -0.152099609375, -0.203857421875, -0.2349853515625, -0.2437744140625, -0.10858154296875, -0.08990478515625, 0.08892822265625, -0.2391357421875, -0.1611328125, -0.427978515625, -0.1336669921875, -0.27685546875, -0.1781005859375, -0.3857421875, 0.251953125, -0.055999755859375, -0.0712890625, -0.00130462646484375, 0.033477783203125, -0.26416015625, 0.07171630859375, -0.0090789794921875, -0.2025146484375, -0.2763671875, -0.09869384765625, -0.45751953125, -0.23095703125, 0.004528045654296875, -0.369140625, -0.366943359375, -0.205322265625, -0.1505126953125, -0.45166015625, -0.2059326171875, 0.0168609619140625, -0.305419921875, -0.150634765625, 0.02685546875, -0.609375, -0.019012451171875, 0.050445556640625, -0.0084381103515625, -0.31005859375, -0.184326171875, -0.15185546875, 0.06732177734375, 0.150390625, -0.10919189453125, -0.08837890625, -0.50537109375, -0.389892578125, -0.0294342041015625, -0.10491943359375, -0.187255859375, -0.43212890625, -0.328125, -1.060546875, 0.011871337890625, 0.04730224609375, -0.09521484375, -0.07452392578125, -0.29296875, -0.109130859375, -0.250244140625, -0.3828125, -0.171875, -0.03399658203125, -0.15478515625, -0.1861572265625, -0.2398681640625, 0.1053466796875, -0.22314453125, -0.1932373046875, -0.18798828125, -0.430419921875, -0.05364990234375, -0.474609375, -0.261474609375, -0.1077880859375, -0.439208984375, 0.08966064453125, -0.185302734375, -0.338134765625, -0.297119140625, -0.298583984375, -0.175537109375, -0.373291015625, -0.1397705078125, -0.260498046875, -0.383544921875, -0.09979248046875, -0.319580078125, -0.06884765625, -0.4365234375, -0.183837890625, -0.393310546875, -0.002277374267578125, 0.11236572265625, -0.260498046875, -0.2242431640625, -0.19384765625, -0.51123046875, 0.03216552734375, -0.048004150390625, -0.279052734375, -0.2978515625, -0.255615234375, 0.115478515625, -4.08984375, -0.1668701171875, -0.278076171875, -0.5712890625, -0.1385498046875, -0.244384765625, -0.41455078125, -0.244140625, -0.0677490234375, -0.141357421875, -0.11590576171875, -0.1439208984375, -0.0185394287109375, -2.490234375, -0.1549072265625, -0.2305908203125, -0.3828125, -0.1173095703125, -0.08258056640625, -0.1719970703125, -0.325439453125, -0.292724609375, -0.08154296875, -0.412353515625, -0.3115234375, -0.00832366943359375, 0.00489044189453125, -0.2236328125, -0.151123046875, -0.457275390625, -0.135009765625, -0.163330078125, -0.0819091796875, 0.06689453125, 0.0209197998046875, -0.11907958984375, -0.10369873046875, -0.2998046875, -0.478759765625, -0.07940673828125, -0.01517486572265625, -0.3017578125, -0.343994140625, -0.258544921875, -0.44775390625, -0.392822265625, -0.0255584716796875, -0.2998046875, 0.10833740234375, -0.271728515625, -0.36181640625, -0.255859375, -0.2056884765625, -0.055450439453125, 0.060516357421875, -0.45751953125, -0.2322998046875, -0.1737060546875, -0.40576171875, -0.2286376953125, -0.053070068359375, -0.0283660888671875, -0.1898193359375, -4.291534423828125e-05, -0.6591796875, -0.1717529296875, -0.479736328125, -0.1400146484375, -0.40771484375, 0.154296875, 0.003101348876953125, 0.00661468505859375, -0.2073974609375, -0.493408203125, 2.171875, -0.45361328125, -0.283935546875, -0.302001953125, -0.25146484375, -0.207275390625, -0.1524658203125, -0.72998046875, -0.08203125, 0.053192138671875, -0.2685546875, 0.1834716796875, -0.270263671875, -0.091552734375, -0.08319091796875, -0.1297607421875, -0.453857421875, 0.0687255859375, 0.0268096923828125, -0.16552734375, -0.4208984375, -0.1552734375, -0.057373046875, -0.300537109375, -0.04541015625, -0.486083984375, -0.2205810546875, -0.39013671875, 0.007488250732421875, -0.005329132080078125, -0.09759521484375, -0.1448974609375, -0.21923828125, -0.429443359375, -0.40087890625, -0.19384765625, -0.064453125, -0.0306243896484375, -0.045806884765625, -0.056793212890625, 0.119384765625, -0.2073974609375, -0.356201171875, -0.168212890625, -0.291748046875, -0.289794921875, -0.205322265625, -0.419677734375, -0.478271484375, -0.2037353515625, -0.368408203125, -0.186279296875, -0.427734375, -0.1756591796875, 0.07501220703125, -0.2457275390625, -0.03692626953125, 0.003997802734375, -5.7578125, -0.01052093505859375, -0.2305908203125, -0.2252197265625, -0.197509765625, -0.1566162109375, -0.1668701171875, -0.383056640625, -0.05413818359375, 0.12188720703125, -0.369873046875, -0.0184478759765625, -0.150146484375, -0.51123046875, -0.45947265625, -0.1561279296875, 0.060455322265625, 0.043487548828125, -0.1370849609375, -0.069091796875, -0.285888671875, -0.44482421875, -0.2374267578125, -0.2191162109375, -0.434814453125, -0.0360107421875, 0.1298828125, 0.0217742919921875, -0.51220703125, -0.13525390625, -0.09381103515625, -0.276611328125, -0.171875, -0.17138671875, -0.4443359375, -0.2178955078125, -0.269775390625, -0.38623046875, -0.31591796875, -0.42333984375, -0.280029296875, -0.255615234375, -0.17041015625, 0.06268310546875, -0.1878662109375, -0.00677490234375, -0.23583984375, -0.08795166015625, -0.2232666015625, -0.1719970703125, -0.484130859375, -0.328857421875, 0.04669189453125, -0.0419921875, -0.11114501953125, 0.02313232421875, -0.0033130645751953125, -0.6005859375, 0.09051513671875, -0.1884765625, -0.262939453125, -0.375732421875, -0.525390625, -0.1170654296875, -0.3779296875, -0.242919921875, -0.419921875, 0.0665283203125, -0.343017578125, 0.06658935546875, -0.346435546875, -0.1363525390625, -0.2000732421875, -0.3837890625, 0.028167724609375, 0.043853759765625, -0.0171051025390625, -0.477294921875, -0.107421875, -0.129150390625, -0.319580078125, -0.32177734375, -0.4951171875, -0.010589599609375, -0.1778564453125, -0.40234375, -0.0810546875, 0.03314208984375, -0.13720703125, -0.31591796875, -0.048248291015625, -0.274658203125, -0.0689697265625, -0.027130126953125, -0.0953369140625, 0.146728515625, -0.38671875, -0.025390625, -0.42333984375, -0.41748046875, -0.379638671875, -0.1978759765625, -0.533203125, -0.33544921875, 0.0694580078125, -0.322998046875, -0.1876220703125, 0.0094451904296875, 0.1839599609375, -0.254150390625, -0.30078125, -0.09228515625, -0.0885009765625, 0.12371826171875, 0.1500244140625, -0.12152099609375, -0.29833984375, 0.03924560546875, -0.1470947265625, -0.1610107421875, -0.2049560546875, -0.01708984375, -0.2470703125, -0.1522216796875, -0.25830078125, 0.10870361328125, -0.302490234375, -0.2376708984375, -0.360107421875, -0.443359375, -0.0784912109375, -0.63623046875, -0.0980224609375, -0.332275390625, -0.1749267578125, -0.30859375, -0.1968994140625, -0.250244140625, -0.447021484375, -0.18408203125, -0.006908416748046875, -0.2044677734375, -0.2548828125, -0.369140625, -0.11328125, -0.1103515625, -0.27783203125, -0.325439453125, 0.01381683349609375, 0.036773681640625, -0.1458740234375, -0.34619140625, -0.232177734375, -0.0562744140625, -0.4482421875, -0.21875, -0.0855712890625, -0.276123046875, -0.1544189453125, -0.223388671875, -0.259521484375, 0.0865478515625, -0.0038013458251953125, -0.340087890625, -0.076171875, -0.25341796875, -0.0007548332214355469, -0.060455322265625, -0.352294921875, 0.035736083984375, -0.2181396484375, -0.2318115234375, -0.1707763671875, 0.018646240234375, 0.093505859375, -0.197021484375, 0.033477783203125, -0.035247802734375, 0.0440673828125, -0.2056884765625, -0.040924072265625, -0.05865478515625, 0.056884765625, -0.08807373046875, -0.10845947265625, 0.09564208984375, -0.10888671875, -0.332275390625, -0.1119384765625, -0.115478515625, 13.0234375, 0.0030040740966796875, -0.53662109375, -0.1856689453125, -0.068115234375, -0.143798828125, -0.177978515625, -0.32666015625, -0.353515625, -0.1563720703125, -0.3203125, 0.0085906982421875, -0.1043701171875, -0.365478515625, -0.303466796875, -0.34326171875, -0.410888671875, -0.03790283203125, -0.11419677734375, -0.2939453125, 0.074462890625, -0.21826171875, 0.0242767333984375, -0.226318359375, -0.353515625, -0.177734375, -0.169189453125, -0.2423095703125, -0.12115478515625, -0.07843017578125, -0.341064453125, -0.2117919921875, -0.505859375, -0.544921875, -0.3935546875, -0.10772705078125, -0.2054443359375, -0.136474609375, -0.1796875, -0.396240234375, -0.1971435546875, -0.68408203125, -0.032684326171875, -0.03863525390625, -0.0709228515625, -0.1005859375, -0.156005859375, -0.3837890625, -0.319580078125, 0.11102294921875, -0.394287109375, 0.0799560546875, -0.50341796875, -0.1572265625, 0.004131317138671875, -0.12286376953125, -0.2347412109375, -0.29150390625, -0.10321044921875, -0.286376953125, 0.018798828125, -0.152099609375, -0.321044921875, 0.0191650390625, -0.11376953125, -0.54736328125, 0.15869140625, -0.257568359375, -0.2490234375, -0.3115234375, -0.09765625, -0.350830078125, -0.36376953125, -0.0771484375, -0.2298583984375, -0.30615234375, -0.052154541015625, -0.12091064453125, -0.40283203125, -0.1649169921875, 0.0206451416015625, -0.312744140625, -0.10308837890625, -0.50341796875, -0.1754150390625, -0.2003173828125, -0.173583984375, -0.204833984375, -0.1876220703125, -0.12176513671875, -0.06201171875, -0.03485107421875, -0.20068359375, -0.21484375, -0.246337890625, -0.006587982177734375, -0.09674072265625, -0.4658203125, -0.3994140625, -0.2210693359375, -0.09588623046875, -0.126220703125, -0.09222412109375, -0.145751953125, -0.217529296875, -0.289306640625, -0.28271484375, -0.1787109375, -0.169189453125, -0.359375, -0.21826171875, -0.043792724609375, -0.205322265625, -0.2900390625, -0.055419921875, -0.1490478515625, -0.340576171875, -0.045928955078125, -0.30517578125, -0.51123046875, -0.1046142578125, -0.349853515625, -0.10882568359375, -0.16748046875, -0.267333984375, -0.122314453125, -0.0985107421875, -0.3076171875, -0.1766357421875, -0.251708984375, 0.1964111328125, -0.2220458984375, -0.2349853515625, -0.035980224609375, -0.1749267578125, -0.237060546875, -0.480224609375, -0.240234375, -0.09539794921875, -0.2481689453125, -0.389404296875, -0.1748046875, -0.370849609375, -0.010650634765625, -0.147705078125, -0.0035457611083984375, -0.32568359375, -0.29931640625, -0.1395263671875, -0.28173828125, -0.09820556640625, -0.0176239013671875, -0.05926513671875, -0.0755615234375, -0.1746826171875, -0.283203125, -0.1617431640625, -0.4404296875, 0.046234130859375, -0.183837890625, -0.052032470703125, -0.24658203125, -0.11224365234375, -0.100830078125, -0.162841796875, -0.29736328125, -0.396484375, 0.11798095703125, -0.006496429443359375, -0.32568359375, -0.347900390625, -0.04595947265625, -0.09637451171875, -0.344970703125, -0.01166534423828125, -0.346435546875, -0.2861328125, -0.1845703125, -0.276611328125, -0.01312255859375, -0.395263671875, -0.50927734375, -0.1114501953125, -0.1861572265625, -0.2158203125, -0.1812744140625, 0.055419921875, -0.294189453125, 0.06500244140625, -0.1444091796875, -0.06365966796875, -0.18408203125, -0.0091705322265625, -0.1640625, -0.1856689453125, 0.090087890625, 0.024566650390625, -0.0195159912109375, -0.5546875, -0.301025390625, -0.438232421875, -0.072021484375, 0.030517578125, -0.1490478515625, 0.04888916015625, -0.23681640625, -0.1553955078125, -0.018096923828125, -0.229736328125, -0.2919921875, -0.355712890625, -0.285400390625, -0.1756591796875, -0.08355712890625, -0.416259765625, 0.022674560546875, -0.417236328125, 0.410400390625, -0.249755859375, 0.015625, -0.033599853515625, -0.040313720703125, -0.51708984375, -0.0518798828125, -0.08843994140625, -0.2022705078125, -0.3740234375, -0.285888671875, -0.176025390625, -0.292724609375, -0.369140625, -0.08367919921875, -0.356689453125, -0.38623046875, 0.06549072265625, 0.1669921875, -0.2099609375, -0.007434844970703125, 0.12890625, -0.0040740966796875, -0.2174072265625, -0.025115966796875, -0.2364501953125, -0.1695556640625, -0.0469970703125, -0.03924560546875, -0.36181640625, -0.047515869140625, -0.3154296875, -0.275634765625, -0.25634765625, -0.061920166015625, -0.12164306640625, -0.47314453125, -0.10784912109375, -0.74755859375, -0.13232421875, -0.32421875, -0.04998779296875, -0.286376953125, 0.10345458984375, -0.1710205078125, -0.388916015625, 0.12744140625, -0.3359375, -0.302490234375, -0.238525390625, -0.1455078125, -0.15869140625, -0.2427978515625, -0.0355224609375, -0.11944580078125, -0.31298828125, 0.11456298828125, -0.287841796875, -0.5439453125, -0.3076171875, -0.08642578125, -0.2408447265625, -0.283447265625, -0.428466796875, -0.085693359375, -0.1683349609375, 0.255126953125, 0.07635498046875, -0.38623046875, -0.2025146484375, -0.1331787109375, -0.10821533203125, -0.49951171875, 0.09130859375, -0.19677734375, -0.01904296875, -0.151123046875, -0.344482421875, -0.316650390625, -0.03900146484375, 0.1397705078125, 0.1334228515625, -0.037200927734375, -0.01861572265625, -0.1351318359375, -0.07037353515625, -0.380615234375, -0.34033203125, -0.06903076171875, 0.219970703125, 0.0132598876953125, -0.15869140625, -0.6376953125, 0.158935546875, -0.5283203125, -0.2320556640625, -0.185791015625, -0.2132568359375, -0.436767578125, -0.430908203125, -0.1763916015625, -0.0007672309875488281, -0.424072265625, -0.06719970703125, -0.347900390625, -0.14453125, -0.3056640625, -0.36474609375, -0.35986328125, -0.46240234375, -0.446044921875, -0.1905517578125, -0.1114501953125, -0.42919921875, -0.0643310546875, -0.3662109375, -0.4296875, -0.10968017578125, -0.2998046875, -0.1756591796875, -0.4052734375, -0.0841064453125, -0.252197265625, -0.047393798828125, 0.00434112548828125, -0.10040283203125, -0.271484375, -0.185302734375, -0.1910400390625, 0.10260009765625, 0.01393890380859375, -0.03350830078125, -0.33935546875, -0.329345703125, 0.0574951171875, -0.18896484375, -0.17724609375, -0.42919921875, -0.26708984375, -0.4189453125, -0.149169921875, -0.265625, -0.198974609375, -0.1722412109375, 0.1563720703125, -0.20947265625, -0.267822265625, -0.06353759765625, -0.365478515625, -0.340087890625, -0.3095703125, -0.320068359375, -0.0880126953125, -0.353759765625, -0.0005812644958496094, -0.1617431640625, -0.1866455078125, -0.201416015625, -0.181396484375, -0.2349853515625, -0.384765625, -0.5244140625, 0.01227569580078125, -0.21337890625, -0.30810546875, -0.17578125, -0.3037109375, -0.52978515625, -0.1561279296875, -0.296142578125, 0.057342529296875, -0.369384765625, -0.107666015625, -0.338623046875, -0.2060546875, -0.0213775634765625, -0.394775390625, -0.219482421875, -0.125732421875, -0.03997802734375, -0.42431640625, -0.134521484375, -0.2418212890625, -0.10504150390625, 0.1552734375, 0.1126708984375, -0.1427001953125, -0.133544921875, -0.111083984375, -0.375732421875, -0.2783203125, -0.036834716796875, -0.11053466796875, 0.2471923828125, -0.2529296875, -0.56494140625, -0.374755859375, -0.326416015625, 0.2137451171875, -0.09454345703125, -0.337158203125, -0.3359375, -0.34375, -0.0999755859375, -0.388671875, 0.0103302001953125, 0.14990234375, -0.2041015625, -0.39501953125, -0.39013671875, -0.1258544921875, 0.1453857421875, -0.250732421875, -0.06732177734375, -0.10638427734375, -0.032379150390625, -0.35888671875, -0.098876953125, -0.172607421875, 0.05126953125, -0.1956787109375, -0.183837890625, -0.37060546875, 0.1556396484375, -0.34375, -0.28662109375, -0.06982421875, -0.302490234375, -0.281005859375, -0.1640625, -0.5302734375, -0.1368408203125, -0.1268310546875, -0.35302734375, -0.1473388671875, -0.45556640625, -0.35986328125, -0.273681640625, -0.2249755859375, -0.1893310546875, 0.09356689453125, -0.248291015625, -0.197998046875, -0.3525390625, -0.30126953125, -0.228271484375, -0.2421875, -0.0906982421875, 0.227783203125, -0.296875, -0.009796142578125, -0.2939453125, -0.1021728515625, -0.215576171875, -0.267822265625, -0.052642822265625, 0.203369140625, -0.1417236328125, 0.18505859375, 0.12347412109375, -0.0972900390625, -0.54052734375, -0.430419921875, -0.0906982421875, -0.5419921875, -0.22900390625, -0.0625, -0.12152099609375, -0.495849609375, -0.206787109375, -0.025848388671875, 0.039031982421875, -0.453857421875, -0.318359375, -0.426025390625, -0.3701171875, -0.2169189453125, 0.0845947265625, -0.045654296875, 0.11090087890625, 0.0012454986572265625, 0.2066650390625, -0.046356201171875, -0.2337646484375, -0.295654296875, 0.057891845703125, -0.1639404296875, -0.0535888671875, -0.2607421875, -0.1488037109375, -0.16015625, -0.54345703125, -0.2305908203125, -0.55029296875, -0.178955078125, -0.222412109375, -0.0711669921875, -0.12298583984375, -0.119140625, -0.253662109375, -0.33984375, -0.11322021484375, -0.10723876953125, -0.205078125, -0.360595703125, 0.085205078125, -0.252197265625, -0.365966796875, -0.26953125, 0.2000732421875, -0.50634765625, 0.05706787109375, -0.3115234375, 0.0242919921875, -0.1689453125, -0.2401123046875, -0.3759765625, -0.2125244140625, 0.076416015625, -0.489013671875, -0.11749267578125, -0.55908203125, -0.313232421875, -0.572265625, -0.1387939453125, -0.037078857421875, -0.385498046875, 0.0323486328125, -0.39404296875, -0.05072021484375, -0.10430908203125, -0.10919189453125, -0.28759765625, -0.37451171875, -0.016937255859375, -0.2200927734375, -0.296875, -0.0286712646484375, -0.213134765625, 0.052001953125, -0.052337646484375, -0.253662109375, 0.07269287109375, -0.2498779296875, -0.150146484375, -0.09930419921875, -0.343505859375, 0.254150390625, -0.032440185546875, -0.296142578125], [1.4111328125, 0.00757598876953125, -0.428955078125, 0.089599609375, 0.0227813720703125, -0.0350341796875, -1.0986328125, 0.194091796875, 2.115234375, -0.75439453125, 0.269287109375, -0.73486328125, -1.1025390625, -0.050262451171875, -0.5830078125, 0.0268707275390625, -0.603515625, -0.6025390625, -1.1689453125, 0.25048828125, -0.4189453125, -0.5517578125, -0.30322265625, 0.7724609375, 0.931640625, -0.1422119140625, 2.27734375, -0.56591796875, 1.013671875, -0.9638671875, -0.66796875, -0.8125, 1.3740234375, -1.060546875, -1.029296875, -1.6796875, 0.62890625, 0.49365234375, 0.671875, 0.99755859375, -1.0185546875, -0.047027587890625, -0.374267578125, 0.2354736328125, 1.4970703125, -1.5673828125, 0.448974609375, 0.2078857421875, -1.060546875, -0.171875, -0.6201171875, -0.1607666015625, 0.7548828125, -0.58935546875, -0.2052001953125, 0.060791015625, 0.200439453125, 3.154296875, -3.87890625, 2.03515625, 1.126953125, 0.1640625, -1.8447265625, 0.002620697021484375, 0.7998046875, -0.337158203125, 0.47216796875, -0.5849609375, 0.9970703125, 0.3935546875, 1.22265625, -1.5048828125, -0.65673828125, 1.1474609375, -1.73046875, -1.8701171875, 1.529296875, -0.6787109375, -1.4453125, 1.556640625, -0.327392578125, 2.986328125, -0.146240234375, -2.83984375, 0.303466796875, -0.71728515625, -0.09698486328125, -0.2423095703125, 0.6767578125, -2.197265625, -0.86279296875, -0.53857421875, -1.2236328125, 1.669921875, -1.1689453125, -0.291259765625, -0.54736328125, -0.036346435546875, 1.041015625, -1.7265625, -0.6064453125, -0.1634521484375, 0.2381591796875, 0.65087890625, -1.169921875, 1.9208984375, 0.5634765625, 0.37841796875, 0.798828125, -1.021484375, -0.4091796875, 2.275390625, -0.302734375, -1.7783203125, 1.0458984375, 1.478515625, 0.708984375, -1.541015625, -0.0006041526794433594, 1.1884765625, 2.041015625, 0.560546875, -0.1131591796875, 1.0341796875, 0.06121826171875, 2.6796875, -0.53369140625, -1.2490234375, -0.7333984375, -1.017578125, -1.0078125, 1.3212890625, -0.47607421875, -1.4189453125, 0.54052734375, -0.796875, -0.73095703125, -1.412109375, -0.94873046875, -2.2734375, -1.1220703125, -1.3837890625, -0.5087890625, -1.0380859375, -0.93603515625, -0.58349609375, -1.0703125, -1.10546875, -2.60546875, 0.062225341796875, 0.38232421875, -0.411376953125, -0.369140625, -0.9833984375, -0.7294921875, -0.181396484375, -0.47216796875, -0.56884765625, -0.11041259765625, -2.673828125, 0.27783203125, -0.857421875, 0.9296875, 1.9580078125, 0.1385498046875, -1.91796875, -1.529296875, 0.53857421875, 0.509765625, -0.90380859375, -0.0947265625, -2.083984375, 0.9228515625, -0.28564453125, -0.80859375, -0.093505859375, -0.6015625, -1.255859375, 0.6533203125, 0.327880859375, -0.07598876953125, -0.22705078125, -0.30078125, -0.5185546875, -1.6044921875, 1.5927734375, 1.416015625, -0.91796875, -0.276611328125, -0.75830078125, -1.1689453125, -1.7421875, 1.0546875, -0.26513671875, -0.03314208984375, 0.278076171875, -1.337890625, 0.055023193359375, 0.10546875, -1.064453125, 1.048828125, -1.4052734375, -1.1240234375, -0.51416015625, -1.05859375, -1.7265625, -1.1328125, 0.43310546875, -2.576171875, -2.140625, -0.79345703125, 0.50146484375, 1.96484375, 0.98583984375, 0.337646484375, -0.77978515625, 0.85498046875, -0.65185546875, -0.484375, 2.708984375, 0.55810546875, -0.147216796875, -0.5537109375, -0.75439453125, -1.736328125, 1.1259765625, -1.095703125, -0.2587890625, 2.978515625, 0.335205078125, 0.357666015625, -0.09356689453125, 0.295654296875, -0.23779296875, 1.5751953125, 0.10400390625, 1.7001953125, -0.72900390625, -1.466796875, -0.2012939453125, 0.634765625, -0.1556396484375, -2.01171875, 0.32666015625, 0.047454833984375, -0.1671142578125, -0.78369140625, -0.994140625, 0.7802734375, -0.1429443359375, -0.115234375, 0.53271484375, -0.96142578125, -0.064208984375, 1.396484375, 1.654296875, -1.6015625, -0.77392578125, 0.276123046875, -0.42236328125, 0.8642578125, 0.533203125, 0.397216796875, -1.21484375, 0.392578125, -0.501953125, -0.231689453125, 1.474609375, 1.6669921875, 1.8662109375, -1.2998046875, 0.223876953125, -0.51318359375, -0.437744140625, -1.16796875, -0.7724609375, 1.6826171875, 0.62255859375, 2.189453125, -0.599609375, -0.65576171875, -1.1005859375, -0.45263671875, -0.292236328125, 2.58203125, -1.3779296875, 0.23486328125, -1.708984375, -1.4111328125, -0.5078125, -0.8525390625, -0.90771484375, 0.861328125, -2.22265625, -1.380859375, 0.7275390625, 0.85595703125, -0.77978515625, 2.044921875, -0.430908203125, 0.78857421875, -1.21484375, -0.09130859375, 0.5146484375, -1.92578125, -0.1396484375, 0.289306640625, 0.60498046875, 0.93896484375, -0.09295654296875, -0.45751953125, -0.986328125, -0.66259765625, 1.48046875, 0.274169921875, -0.267333984375, -1.3017578125, -1.3623046875, -1.982421875, -0.86083984375, -0.41259765625, -0.2939453125, -1.91015625, 1.6826171875, 0.437255859375, 1.0029296875, 0.376220703125, -0.010467529296875, -0.82861328125, -0.513671875, -3.134765625, 1.0205078125, -1.26171875, -1.009765625, 1.0869140625, -0.95703125, 0.0103759765625, 1.642578125, 0.78564453125, 1.029296875, 0.496826171875, 1.2880859375, 0.5234375, 0.05322265625, -0.206787109375, -0.79443359375, -1.1669921875, 0.049530029296875, -0.27978515625, 0.0237884521484375, -0.74169921875, -1.068359375, 0.86083984375, 1.1787109375, 0.91064453125, -0.453857421875, -1.822265625, -0.9228515625, -0.50048828125, 0.359130859375, 0.802734375, -1.3564453125, -0.322509765625, -1.1123046875, -1.0390625, -0.52685546875, -1.291015625, -0.343017578125, -1.2109375, -0.19091796875, 2.146484375, -0.04315185546875, -0.3701171875, -2.044921875, -0.429931640625, -0.56103515625, -0.166015625, -0.4658203125, -2.29296875, -1.078125, -1.0927734375, -0.1033935546875, -0.56103515625, -0.05743408203125, -1.986328125, -0.513671875, 0.70361328125, -2.484375, -1.3037109375, -1.6650390625, 0.4814453125, -0.84912109375, -2.697265625, -0.197998046875, 0.0869140625, -0.172607421875, -1.326171875, -1.197265625, 1.23828125, -0.38720703125, -0.075927734375, 0.02569580078125, -1.2119140625, 0.09027099609375, -2.12890625, -1.640625, -0.1524658203125, 0.2373046875, 1.37109375, 2.248046875, 1.4619140625, 0.3134765625, 0.50244140625, -0.1383056640625, -1.2705078125, 0.7353515625, 0.65771484375, -0.431396484375, -1.341796875, 0.10089111328125, 0.208984375, -0.0099945068359375, 0.83203125, 1.314453125, -0.422607421875, -1.58984375, -0.6044921875, 0.23681640625, -1.60546875, -0.61083984375, -1.5615234375, 1.62890625, -0.6728515625, -0.68212890625, -0.5224609375, -0.9150390625, -0.468994140625, 0.268310546875, 0.287353515625, -0.025543212890625, 0.443603515625, 1.62109375, -1.08984375, -0.5556640625, 1.03515625, -0.31298828125, -0.041778564453125, 0.260986328125, 0.34716796875, -2.326171875, 0.228271484375, -0.85107421875, -2.255859375, 0.3486328125, -0.25830078125, -0.3671875, -0.796875, -1.115234375, 1.8369140625, -0.19775390625, -1.236328125, -0.0447998046875, 0.69921875, 1.37890625, 1.11328125, 0.0928955078125, 0.6318359375, -0.62353515625, 0.55859375, -0.286865234375, 1.5361328125, -0.391357421875, -0.052215576171875, -1.12890625, 0.55517578125, -0.28515625, -0.3603515625, 0.68896484375, 0.67626953125, 0.003070831298828125, 1.2236328125, 0.1597900390625, -1.3076171875, 0.99951171875, -2.5078125, -1.2119140625, 0.1749267578125, -1.1865234375, -1.234375, -0.1180419921875, -1.751953125, 0.033050537109375, 0.234130859375, -3.107421875, -1.0380859375, 0.61181640625, -0.87548828125, 0.3154296875, -1.103515625, 0.261474609375, -1.130859375, -0.7470703125, -0.43408203125, 1.3828125, -0.41259765625, -1.7587890625, 0.765625, 0.004852294921875, 0.135498046875, -0.76953125, -0.1314697265625, 0.400390625, 1.43359375, 0.07135009765625, 0.0645751953125, -0.5869140625, -0.5810546875, -0.2900390625, -1.3037109375, 0.1287841796875, -0.27490234375, 0.59228515625, 2.333984375, -0.54541015625, -0.556640625, 0.447265625, -0.806640625, 0.09149169921875, -0.70654296875, -0.357177734375, -1.099609375, -0.5576171875, -0.44189453125, 0.400390625, -0.666015625, -1.4619140625, 0.728515625, -1.5986328125, 0.153076171875, -0.126708984375, -2.83984375, -1.84375, -0.2469482421875, 0.677734375, 0.43701171875, 3.298828125, 1.1591796875, -0.7158203125, -0.8251953125, 0.451171875, -2.376953125, -0.58642578125, -0.86767578125, 0.0789794921875, 0.1351318359375, -0.325439453125, 0.484375, 1.166015625, -0.1610107421875, -0.15234375, -0.54638671875, -0.806640625, 0.285400390625, 0.1661376953125, -0.50146484375, -1.0478515625, 1.5751953125, 0.0313720703125, 0.2396240234375, -0.6572265625, -0.1258544921875, -1.060546875, 1.3076171875, -0.301513671875, -1.2412109375, 0.6376953125, -1.5693359375, 0.354248046875, 0.2427978515625, -0.392333984375, 0.61962890625, -0.58837890625, -1.71484375, -0.2098388671875, -0.828125, 0.330810546875, 0.16357421875, -0.2259521484375, 0.0972900390625, -0.451416015625, 1.79296875, -1.673828125, -1.58203125, -2.099609375, -0.487548828125, -0.87060546875, 0.62646484375, -1.470703125, -0.1558837890625, 0.4609375, 1.3369140625, 0.2322998046875, 0.1632080078125, 0.65966796875, 1.0810546875, 0.1041259765625, 0.63232421875, -0.32421875, -1.04296875, -1.046875, -1.3720703125, -0.8486328125, 0.1290283203125, 0.137939453125, 0.1549072265625, -1.0908203125, 0.0167694091796875, -0.31689453125, 1.390625, 0.07269287109375, 1.0390625, 1.1162109375, -0.455810546875, -0.06689453125, -0.053741455078125, 0.5048828125, -0.8408203125, -1.19921875, 0.87841796875, 0.7421875, 0.2030029296875, 0.109619140625, -0.59912109375, -1.337890625, -0.74169921875, -0.64453125, -1.326171875, 0.21044921875, -1.3583984375, -1.685546875, -0.472900390625, -0.270263671875, 0.99365234375, -0.96240234375, 1.1279296875, -0.45947265625, -0.45654296875, -0.99169921875, -3.515625, -1.9853515625, 0.73681640625, 0.92333984375, -0.56201171875, -1.4453125, -2.078125, 0.94189453125, -1.333984375, 0.0982666015625, 0.60693359375, 0.367431640625, 3.015625, -1.1357421875, -1.5634765625, 0.90234375, -0.1783447265625, 0.1802978515625, -0.317138671875, -0.513671875, 1.2353515625, -0.033203125, 1.4482421875, 1.0087890625, 0.9248046875, 0.10418701171875, 0.7626953125, -1.3798828125, 0.276123046875, 0.55224609375, 1.1005859375, -0.62158203125, -0.806640625, 0.65087890625, 0.270263671875, -0.339111328125, -0.9384765625, -0.09381103515625, -0.7216796875, 1.37890625, -0.398193359375, -0.3095703125, -1.4912109375, 0.96630859375, 0.43798828125, 0.62255859375, 0.0213470458984375, 0.235595703125, -1.2958984375, 0.0157318115234375, -0.810546875, 1.9736328125, -0.2462158203125, 0.720703125, 0.822265625, -0.755859375, -0.658203125, 0.344482421875, -2.892578125, -0.282470703125, 1.2529296875, -0.294189453125, 0.6748046875, -0.80859375, 0.9287109375, 1.27734375, -1.71875, -0.166015625, 0.47412109375, -0.41259765625, -1.3681640625, -0.978515625, -0.77978515625, -1.044921875, -0.90380859375, -0.08184814453125, -0.86181640625, -0.10772705078125, -0.299560546875, -0.4306640625, -0.47119140625, 0.95703125, 1.107421875, 0.91796875, 0.76025390625, 0.7392578125, -0.09161376953125, -0.7392578125, 0.9716796875, -0.395751953125, -0.75390625, -0.164306640625, -0.087646484375, 0.028564453125, -0.91943359375, -0.66796875, 2.486328125, 0.427734375, 0.626953125, 0.474853515625, 0.0926513671875, 0.830078125, -0.6923828125, 0.7841796875, -0.89208984375, -2.482421875, 0.034912109375, -1.3447265625, -0.475341796875, -0.286376953125, -0.732421875, 0.190673828125, -0.491455078125, -3.091796875, -1.2783203125, -0.66015625, -0.1507568359375, 0.042236328125, -1.025390625, 0.12744140625, -1.984375, -0.393798828125, -1.25, -1.140625, 1.77734375, 0.2457275390625, -0.8017578125, 0.7763671875, -0.387939453125, -0.3662109375, 1.1572265625, 0.123291015625, -0.07135009765625, 1.412109375, -0.685546875, -3.078125, 0.031524658203125, -0.70458984375, 0.78759765625, 0.433837890625, -1.861328125, -1.33203125, 2.119140625, -1.3544921875, -0.6591796875, -1.4970703125, 0.40625, -2.078125, -1.30859375, 0.050262451171875, -0.60107421875, 1.0078125, 0.05657958984375, -0.96826171875, 0.0264892578125, 0.159912109375, 0.84033203125, -1.1494140625, -0.0433349609375, -0.2034912109375, 1.09765625, -1.142578125, -0.283203125, -0.427978515625, 1.0927734375, -0.67529296875, -0.61572265625, 2.517578125, 0.84130859375, 1.8662109375, 0.1748046875, -0.407958984375, -0.029449462890625, -0.27587890625, -0.958984375, -0.10028076171875, 1.248046875, -0.0792236328125, -0.45556640625, 0.7685546875, 1.5556640625, -1.8759765625, -0.131591796875, -1.3583984375, 0.7890625, 0.80810546875, -1.0322265625, -0.53076171875, -0.1484375, -1.7841796875, -1.2470703125, 0.17138671875, -0.04864501953125, -0.80322265625, -0.0933837890625, 0.984375, 0.7001953125, 0.5380859375, 0.2022705078125, -1.1865234375, 0.5439453125, 1.1318359375, 0.79931640625, 0.32666015625, -1.26171875, 0.457763671875, 1.1591796875, -0.34423828125, 0.65771484375, 0.216552734375, 1.19140625, -0.2744140625, -0.020416259765625, -0.86376953125, 0.93017578125, 1.0556640625, 0.69873046875, -0.15087890625, -0.33056640625, 0.8505859375, 0.06890869140625, 0.359375, -0.262939453125, 0.12493896484375, 0.017059326171875, -0.98974609375, 0.5107421875, 0.2408447265625, 0.615234375, -0.62890625, 0.86962890625, -0.07427978515625, 0.85595703125, 0.300537109375, -1.072265625, -1.6064453125, -0.353515625, -0.484130859375, -0.6044921875, -0.455810546875, 0.95849609375, 1.3671875, 0.544921875, 0.560546875, 0.34521484375, -0.6513671875, -0.410400390625, -0.2021484375, -0.1656494140625, 0.073486328125, 0.84716796875, -1.7998046875, -1.0126953125, -0.1324462890625, 0.95849609375, -0.669921875, -0.79052734375, -2.193359375, -0.42529296875, -1.7275390625, -1.04296875, 0.716796875, -0.4423828125, -1.193359375, 0.61572265625, -1.5224609375, 0.62890625, -0.705078125, 0.677734375, -0.213134765625, -1.6748046875, -1.087890625, -0.65185546875, -1.1337890625, 2.314453125, -0.352783203125, -0.27001953125, -2.01953125, -1.2685546875, 0.308837890625, -0.280517578125, -1.3798828125, -1.595703125, 0.642578125, 1.693359375, -0.82470703125, -1.255859375, 0.57373046875, 1.5859375, 1.068359375, -0.876953125, 0.370849609375, 1.220703125, 0.59765625, 0.007602691650390625, 0.09326171875, -0.9521484375, -0.024932861328125, -0.94775390625, -0.299560546875, -0.002536773681640625, 1.41796875, -0.06903076171875, -1.5927734375, 0.353515625, 3.63671875, -0.765625, -1.1142578125, 0.4287109375, -0.86865234375, -0.9267578125, -0.21826171875, -1.10546875, 0.29296875, -0.225830078125, 0.5400390625, -0.45556640625, -0.68701171875, -0.79150390625, -1.0810546875, 0.25439453125, -1.2998046875, -0.494140625, -0.1510009765625, 1.5615234375, -0.4248046875, -0.486572265625, 0.45458984375, 0.047637939453125, -0.11639404296875, 0.057403564453125, 0.130126953125, -0.10125732421875, -0.56201171875, 1.4765625, -1.7451171875, 1.34765625, -0.45703125, 0.873046875, -0.056121826171875, -0.8876953125, -0.986328125, 1.5654296875, 0.49853515625, 0.55859375, -0.2198486328125, 0.62548828125, 0.2734375, -0.63671875, -0.41259765625, -1.2705078125, 0.0665283203125, 1.3369140625, 0.90283203125, -0.77685546875, -1.5, -1.8525390625, -1.314453125, -0.86767578125, -0.331787109375, 0.1590576171875, 0.94775390625, -0.1771240234375, 1.638671875, -2.17578125, 0.58740234375, 0.424560546875, -0.3466796875, 0.642578125, 0.473388671875, 0.96435546875, 1.38671875, -0.91357421875, 1.0361328125, -0.67333984375, 1.5009765625]]]).to(device)
|
| 24 |
+
|
| 25 |
+
cond = [[prompt_embeds, {}]]
|
| 26 |
+
|
| 27 |
+
return io.NodeOutput(cond)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class LotusExtension(ComfyExtension):
|
| 31 |
+
@override
|
| 32 |
+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
| 33 |
+
return [
|
| 34 |
+
LotusConditioning,
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
async def comfy_entrypoint() -> LotusExtension:
|
| 39 |
+
return LotusExtension()
|