zk-Armor commited on
Commit
ea5011c
·
verified ·
1 Parent(s): 692cc0b

Upload ComfyUI/comfy_extras

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ComfyUI/comfy_extras/chainner_models/model_loading.py +6 -0
  2. ComfyUI/comfy_extras/frame_interpolation_models/film_net.py +258 -0
  3. ComfyUI/comfy_extras/frame_interpolation_models/ifnet.py +128 -0
  4. ComfyUI/comfy_extras/nodes_ace.py +145 -0
  5. ComfyUI/comfy_extras/nodes_advanced_samplers.py +121 -0
  6. ComfyUI/comfy_extras/nodes_align_your_steps.py +70 -0
  7. ComfyUI/comfy_extras/nodes_apg.py +110 -0
  8. ComfyUI/comfy_extras/nodes_attention_multiply.py +151 -0
  9. ComfyUI/comfy_extras/nodes_audio.py +794 -0
  10. ComfyUI/comfy_extras/nodes_audio_encoder.py +62 -0
  11. ComfyUI/comfy_extras/nodes_camera_trajectory.py +239 -0
  12. ComfyUI/comfy_extras/nodes_canny.py +45 -0
  13. ComfyUI/comfy_extras/nodes_cfg.py +91 -0
  14. ComfyUI/comfy_extras/nodes_chroma_radiance.py +117 -0
  15. ComfyUI/comfy_extras/nodes_clip_sdxl.py +71 -0
  16. ComfyUI/comfy_extras/nodes_color.py +42 -0
  17. ComfyUI/comfy_extras/nodes_compositing.py +226 -0
  18. ComfyUI/comfy_extras/nodes_cond.py +68 -0
  19. ComfyUI/comfy_extras/nodes_context_windows.py +103 -0
  20. ComfyUI/comfy_extras/nodes_controlnet.py +85 -0
  21. ComfyUI/comfy_extras/nodes_cosmos.py +143 -0
  22. ComfyUI/comfy_extras/nodes_curve.py +92 -0
  23. ComfyUI/comfy_extras/nodes_custom_sampler.py +1095 -0
  24. ComfyUI/comfy_extras/nodes_dataset.py +1537 -0
  25. ComfyUI/comfy_extras/nodes_differential_diffusion.py +73 -0
  26. ComfyUI/comfy_extras/nodes_easycache.py +530 -0
  27. ComfyUI/comfy_extras/nodes_edit_model.py +38 -0
  28. ComfyUI/comfy_extras/nodes_eps.py +172 -0
  29. ComfyUI/comfy_extras/nodes_flux.py +314 -0
  30. ComfyUI/comfy_extras/nodes_frame_interpolation.py +211 -0
  31. ComfyUI/comfy_extras/nodes_freelunch.py +138 -0
  32. ComfyUI/comfy_extras/nodes_fresca.py +115 -0
  33. ComfyUI/comfy_extras/nodes_gits.py +382 -0
  34. ComfyUI/comfy_extras/nodes_glsl.py +958 -0
  35. ComfyUI/comfy_extras/nodes_hidream.py +74 -0
  36. ComfyUI/comfy_extras/nodes_hooks.py +750 -0
  37. ComfyUI/comfy_extras/nodes_hunyuan.py +427 -0
  38. ComfyUI/comfy_extras/nodes_hunyuan3d.py +697 -0
  39. ComfyUI/comfy_extras/nodes_hypernetwork.py +138 -0
  40. ComfyUI/comfy_extras/nodes_hypertile.py +98 -0
  41. ComfyUI/comfy_extras/nodes_image_compare.py +54 -0
  42. ComfyUI/comfy_extras/nodes_images.py +851 -0
  43. ComfyUI/comfy_extras/nodes_ip2p.py +63 -0
  44. ComfyUI/comfy_extras/nodes_kandinsky5.py +137 -0
  45. ComfyUI/comfy_extras/nodes_latent.py +504 -0
  46. ComfyUI/comfy_extras/nodes_load_3d.py +131 -0
  47. ComfyUI/comfy_extras/nodes_logic.py +274 -0
  48. ComfyUI/comfy_extras/nodes_lora_debug.py +79 -0
  49. ComfyUI/comfy_extras/nodes_lora_extract.py +145 -0
  50. ComfyUI/comfy_extras/nodes_lotus.py +39 -0
ComfyUI/comfy_extras/chainner_models/model_loading.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import logging
2
+ from spandrel import ModelLoader
3
+
4
+ def load_state_dict(state_dict):
5
+ logging.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
6
+ return ModelLoader().load_from_state_dict(state_dict).eval()
ComfyUI/comfy_extras/frame_interpolation_models/film_net.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FILM: Frame Interpolation for Large Motion (ECCV 2022)."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ import comfy.ops
8
+
9
+ ops = comfy.ops.disable_weight_init
10
+
11
+
12
+ class FilmConv2d(nn.Module):
13
+ """Conv2d with optional LeakyReLU and FILM-style padding."""
14
+
15
+ def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops):
16
+ super().__init__()
17
+ self.even_pad = not size % 2
18
+ self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype)
19
+ self.activation = nn.LeakyReLU(0.2) if activation else None
20
+
21
+ def forward(self, x):
22
+ if self.even_pad:
23
+ x = F.pad(x, (0, 1, 0, 1))
24
+ x = self.conv(x)
25
+ if self.activation is not None:
26
+ x = self.activation(x)
27
+ return x
28
+
29
+
30
+ def _warp_core(image, flow, grid_x, grid_y):
31
+ dtype = image.dtype
32
+ H, W = flow.shape[2], flow.shape[3]
33
+ dx = flow[:, 0].float() / (W * 0.5)
34
+ dy = flow[:, 1].float() / (H * 0.5)
35
+ grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3)
36
+ return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype)
37
+
38
+
39
+ def build_image_pyramid(image, pyramid_levels):
40
+ pyramid = [image]
41
+ for _ in range(1, pyramid_levels):
42
+ image = F.avg_pool2d(image, 2, 2)
43
+ pyramid.append(image)
44
+ return pyramid
45
+
46
+
47
+ def flow_pyramid_synthesis(residual_pyramid):
48
+ flow = residual_pyramid[-1]
49
+ flow_pyramid = [flow]
50
+ for residual_flow in residual_pyramid[:-1][::-1]:
51
+ flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow)
52
+ flow_pyramid.append(flow)
53
+ flow_pyramid.reverse()
54
+ return flow_pyramid
55
+
56
+
57
+ def multiply_pyramid(pyramid, scalar):
58
+ return [image * scalar[:, None, None, None] for image in pyramid]
59
+
60
+
61
+ def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn):
62
+ return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)]
63
+
64
+
65
+ def concatenate_pyramids(pyramid1, pyramid2):
66
+ return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)]
67
+
68
+
69
+ class SubTreeExtractor(nn.Module):
70
+ def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops):
71
+ super().__init__()
72
+ convs = []
73
+ for i in range(n_layers):
74
+ out_ch = channels << i
75
+ convs.append(nn.Sequential(
76
+ FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations),
77
+ FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations)))
78
+ in_channels = out_ch
79
+ self.convs = nn.ModuleList(convs)
80
+
81
+ def forward(self, image, n):
82
+ head = image
83
+ pyramid = []
84
+ for i, layer in enumerate(self.convs):
85
+ head = layer(head)
86
+ pyramid.append(head)
87
+ if i < n - 1:
88
+ head = F.avg_pool2d(head, 2, 2)
89
+ return pyramid
90
+
91
+
92
+ class FeatureExtractor(nn.Module):
93
+ def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops):
94
+ super().__init__()
95
+ self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations)
96
+ self.sub_levels = sub_levels
97
+
98
+ def forward(self, image_pyramid):
99
+ sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels))
100
+ for i in range(len(image_pyramid))]
101
+ feature_pyramid = []
102
+ for i in range(len(image_pyramid)):
103
+ features = sub_pyramids[i][0]
104
+ for j in range(1, self.sub_levels):
105
+ if j <= i:
106
+ features = torch.cat([features, sub_pyramids[i - j][j]], dim=1)
107
+ feature_pyramid.append(features)
108
+ # Free sub-pyramids no longer needed by future levels
109
+ if i >= self.sub_levels - 1:
110
+ sub_pyramids[i - self.sub_levels + 1] = None
111
+ return feature_pyramid
112
+
113
+
114
+ class FlowEstimator(nn.Module):
115
+ def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops):
116
+ super().__init__()
117
+ self._convs = nn.ModuleList()
118
+ for _ in range(num_convs):
119
+ self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations))
120
+ in_channels = num_filters
121
+ self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations))
122
+ self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations))
123
+
124
+ def forward(self, features_a, features_b):
125
+ net = torch.cat([features_a, features_b], dim=1)
126
+ for conv in self._convs:
127
+ net = conv(net)
128
+ return net
129
+
130
+
131
+ class PyramidFlowEstimator(nn.Module):
132
+ def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
133
+ super().__init__()
134
+ in_channels = filters << 1
135
+ predictors = []
136
+ for i in range(len(flow_convs)):
137
+ predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations))
138
+ in_channels += filters << (i + 2)
139
+ self._predictor = predictors[-1]
140
+ self._predictors = nn.ModuleList(predictors[:-1][::-1])
141
+
142
+ def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn):
143
+ levels = len(feature_pyramid_a)
144
+ v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
145
+ residuals = [v]
146
+ # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels
147
+ steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)]
148
+ steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)]
149
+ for i, predictor in steps:
150
+ v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2)
151
+ v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v))
152
+ residuals.append(v_residual)
153
+ v = v.add_(v_residual)
154
+ residuals.reverse()
155
+ return residuals
156
+
157
+
158
+ def _get_fusion_channels(level, filters):
159
+ # Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions
160
+ return (sum(filters << i for i in range(level)) + 3 + 2) * 2
161
+
162
+
163
+ class Fusion(nn.Module):
164
+ def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops):
165
+ super().__init__()
166
+ self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype)
167
+ self.convs = nn.ModuleList()
168
+ in_channels = _get_fusion_channels(n_layers, filters)
169
+ increase = 0
170
+ for i in range(n_layers)[::-1]:
171
+ num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
172
+ self.convs.append(nn.ModuleList([
173
+ FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations),
174
+ FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations),
175
+ FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)]))
176
+ in_channels = num_filters
177
+ increase = _get_fusion_channels(i, filters) - num_filters // 2
178
+
179
+ def forward(self, pyramid):
180
+ net = pyramid[-1]
181
+ for k, layers in enumerate(self.convs):
182
+ i = len(self.convs) - 1 - k
183
+ net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest"))
184
+ net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1)))
185
+ return self.output_conv(net)
186
+
187
+
188
+ class FILMNet(nn.Module):
189
+ def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4,
190
+ filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
191
+ super().__init__()
192
+ self.pyramid_levels = pyramid_levels
193
+ self.fusion_pyramid_levels = fusion_pyramid_levels
194
+ self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations)
195
+ self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations)
196
+ self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations)
197
+ self._warp_grids = {}
198
+
199
+ def get_dtype(self):
200
+ return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
201
+
202
+ def _build_warp_grids(self, H, W, device):
203
+ """Pre-compute warp grids for all pyramid levels."""
204
+ if (H, W) in self._warp_grids:
205
+ return
206
+ self._warp_grids = {} # clear old resolution grids to prevent memory leaks
207
+ for _ in range(self.pyramid_levels):
208
+ self._warp_grids[(H, W)] = (
209
+ torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device),
210
+ torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device),
211
+ )
212
+ H, W = H // 2, W // 2
213
+
214
+ def warp(self, image, flow):
215
+ grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])]
216
+ return _warp_core(image, flow, grid_x, grid_y)
217
+
218
+ def extract_features(self, img):
219
+ """Extract image and feature pyramids for a single frame. Can be cached across pairs."""
220
+ image_pyramid = build_image_pyramid(img, self.pyramid_levels)
221
+ feature_pyramid = self.extract(image_pyramid)
222
+ return image_pyramid, feature_pyramid
223
+
224
+ def forward(self, img0, img1, timestep=0.5, cache=None):
225
+ # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported)
226
+ t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep
227
+ return self.forward_multi_timestep(img0, img1, [t], cache=cache)
228
+
229
+ def forward_multi_timestep(self, img0, img1, timesteps, cache=None):
230
+ """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs."""
231
+ self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
232
+
233
+ image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0)
234
+ image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1)
235
+
236
+ fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels]
237
+ bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels]
238
+
239
+ # Build warp targets and free full pyramids (only first fpl levels needed from here)
240
+ fpl = self.fusion_pyramid_levels
241
+ p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]),
242
+ concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])]
243
+ del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1
244
+
245
+ results = []
246
+ dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype)
247
+ for idx in range(len(timesteps)):
248
+ batch_dt = dt_tensors[idx:idx + 1]
249
+ bwd_scaled = multiply_pyramid(bwd_flow, batch_dt)
250
+ fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt)
251
+ fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp)
252
+ bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp)
253
+ aligned = [torch.cat([fw, bw, bf, ff], dim=1)
254
+ for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)]
255
+ del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled
256
+ results.append(self.fuse(aligned))
257
+ del aligned
258
+ return torch.cat(results, dim=0)
ComfyUI/comfy_extras/frame_interpolation_models/ifnet.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import comfy.ops
6
+
7
+ ops = comfy.ops.disable_weight_init
8
+
9
+
10
+ def _warp(img, flow, warp_grids):
11
+ B, _, H, W = img.shape
12
+ base_grid, flow_div = warp_grids[(H, W)]
13
+ flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float()
14
+ grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1)
15
+ return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype)
16
+
17
+
18
+ class Head(nn.Module):
19
+ def __init__(self, out_ch=4, device=None, dtype=None, operations=ops):
20
+ super().__init__()
21
+ self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype)
22
+ self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
23
+ self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
24
+ self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype)
25
+ self.relu = nn.LeakyReLU(0.2, True)
26
+
27
+ def forward(self, x):
28
+ x = self.relu(self.cnn0(x))
29
+ x = self.relu(self.cnn1(x))
30
+ x = self.relu(self.cnn2(x))
31
+ return self.cnn3(x)
32
+
33
+
34
+ class ResConv(nn.Module):
35
+ def __init__(self, c, device=None, dtype=None, operations=ops):
36
+ super().__init__()
37
+ self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype)
38
+ self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype))
39
+ self.relu = nn.LeakyReLU(0.2, True)
40
+
41
+ def forward(self, x):
42
+ return self.relu(torch.addcmul(x, self.conv(x), self.beta))
43
+
44
+
45
+ class IFBlock(nn.Module):
46
+ def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops):
47
+ super().__init__()
48
+ self.conv0 = nn.Sequential(
49
+ nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)),
50
+ nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)))
51
+ self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8)))
52
+ self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2))
53
+
54
+ def forward(self, x, flow=None, scale=1):
55
+ x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
56
+ if flow is not None:
57
+ flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale)
58
+ x = torch.cat((x, flow), 1)
59
+ feat = self.convblock(self.conv0(x))
60
+ tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear")
61
+ return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:]
62
+
63
+
64
+ class IFNet(nn.Module):
65
+ def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops):
66
+ super().__init__()
67
+ self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations)
68
+ block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4
69
+ self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)])
70
+ self.scale_list = [16, 8, 4, 2, 1]
71
+ self.pad_align = 64
72
+ self._warp_grids = {}
73
+
74
+ def get_dtype(self):
75
+ return self.encode.cnn0.weight.dtype
76
+
77
+ def _build_warp_grids(self, H, W, device):
78
+ if (H, W) in self._warp_grids:
79
+ return
80
+ self._warp_grids = {} # clear old resolution grids to prevent memory leaks
81
+ grid_y, grid_x = torch.meshgrid(
82
+ torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32),
83
+ torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij")
84
+ self._warp_grids[(H, W)] = (
85
+ torch.stack((grid_x, grid_y), dim=0).unsqueeze(0),
86
+ torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device))
87
+
88
+ def warp(self, img, flow):
89
+ return _warp(img, flow, self._warp_grids)
90
+
91
+ def extract_features(self, img):
92
+ """Extract head features for a single frame. Can be cached across pairs."""
93
+ return self.encode(img)
94
+
95
+ def forward(self, img0, img1, timestep=0.5, cache=None):
96
+ if not isinstance(timestep, torch.Tensor):
97
+ timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype)
98
+
99
+ self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
100
+
101
+ B = img0.shape[0]
102
+ f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0)
103
+ f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1)
104
+ flow = mask = feat = None
105
+ warped_img0, warped_img1 = img0, img1
106
+ for i, block in enumerate(self.blocks):
107
+ if flow is None:
108
+ flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
109
+ else:
110
+ fd, mask, feat = block(
111
+ torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1),
112
+ flow, scale=self.scale_list[i])
113
+ flow = flow.add_(fd)
114
+ warped_img0 = self.warp(img0, flow[:, :2])
115
+ warped_img1 = self.warp(img1, flow[:, 2:4])
116
+ return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask))
117
+
118
+
119
+ def detect_rife_config(state_dict):
120
+ head_ch = state_dict["encode.cnn3.weight"].shape[1] # ConvTranspose2d: (in_ch, out_ch, kH, kW)
121
+ channels = []
122
+ for i in range(5):
123
+ key = f"blocks.{i}.conv0.1.0.weight"
124
+ if key in state_dict:
125
+ channels.append(state_dict[key].shape[0])
126
+ if len(channels) != 5:
127
+ raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}")
128
+ return head_ch, channels
ComfyUI/comfy_extras/nodes_ace.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing_extensions import override
3
+
4
+ import comfy.model_management
5
+ import node_helpers
6
+ from comfy_api.latest import ComfyExtension, IO
7
+
8
+
9
+ class TextEncodeAceStepAudio(IO.ComfyNode):
10
+ @classmethod
11
+ def define_schema(cls):
12
+ return IO.Schema(
13
+ node_id="TextEncodeAceStepAudio",
14
+ category="conditioning",
15
+ inputs=[
16
+ IO.Clip.Input("clip"),
17
+ IO.String.Input("tags", multiline=True, dynamic_prompts=True),
18
+ IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
19
+ IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
20
+ ],
21
+ outputs=[IO.Conditioning.Output()],
22
+ )
23
+
24
+ @classmethod
25
+ def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput:
26
+ tokens = clip.tokenize(tags, lyrics=lyrics)
27
+ conditioning = clip.encode_from_tokens_scheduled(tokens)
28
+ conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
29
+ return IO.NodeOutput(conditioning)
30
+
31
+ class TextEncodeAceStepAudio15(IO.ComfyNode):
32
+ @classmethod
33
+ def define_schema(cls):
34
+ return IO.Schema(
35
+ node_id="TextEncodeAceStepAudio1.5",
36
+ category="conditioning",
37
+ inputs=[
38
+ IO.Clip.Input("clip"),
39
+ IO.String.Input("tags", multiline=True, dynamic_prompts=True),
40
+ IO.String.Input("lyrics", multiline=True, dynamic_prompts=True),
41
+ IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
42
+ IO.Int.Input("bpm", default=120, min=10, max=300),
43
+ IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
44
+ IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
45
+ IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
46
+ IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
47
+ IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
48
+ IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
49
+ IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
50
+ IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
51
+ IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
52
+ IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
53
+ ],
54
+ outputs=[IO.Conditioning.Output()],
55
+ )
56
+
57
+ @classmethod
58
+ def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput:
59
+ tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
60
+ conditioning = clip.encode_from_tokens_scheduled(tokens)
61
+ return IO.NodeOutput(conditioning)
62
+
63
+
64
+ class EmptyAceStepLatentAudio(IO.ComfyNode):
65
+ @classmethod
66
+ def define_schema(cls):
67
+ return IO.Schema(
68
+ node_id="EmptyAceStepLatentAudio",
69
+ display_name="Empty Ace Step 1.0 Latent Audio",
70
+ category="latent/audio",
71
+ inputs=[
72
+ IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
73
+ IO.Int.Input(
74
+ "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
75
+ ),
76
+ ],
77
+ outputs=[IO.Latent.Output()],
78
+ )
79
+
80
+ @classmethod
81
+ def execute(cls, seconds, batch_size) -> IO.NodeOutput:
82
+ length = int(seconds * 44100 / 512 / 8)
83
+ latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
84
+ return IO.NodeOutput({"samples": latent, "type": "audio"})
85
+
86
+
87
+ class EmptyAceStep15LatentAudio(IO.ComfyNode):
88
+ @classmethod
89
+ def define_schema(cls):
90
+ return IO.Schema(
91
+ node_id="EmptyAceStep1.5LatentAudio",
92
+ display_name="Empty Ace Step 1.5 Latent Audio",
93
+ category="latent/audio",
94
+ inputs=[
95
+ IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
96
+ IO.Int.Input(
97
+ "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
98
+ ),
99
+ ],
100
+ outputs=[IO.Latent.Output()],
101
+ )
102
+
103
+ @classmethod
104
+ def execute(cls, seconds, batch_size) -> IO.NodeOutput:
105
+ length = round((seconds * 48000 / 1920))
106
+ latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
107
+ return IO.NodeOutput({"samples": latent, "type": "audio"})
108
+
109
+ class ReferenceAudio(IO.ComfyNode):
110
+ @classmethod
111
+ def define_schema(cls):
112
+ return IO.Schema(
113
+ node_id="ReferenceTimbreAudio",
114
+ display_name="Reference Audio",
115
+ category="advanced/conditioning/audio",
116
+ is_experimental=True,
117
+ description="This node sets the reference audio for ace step 1.5",
118
+ inputs=[
119
+ IO.Conditioning.Input("conditioning"),
120
+ IO.Latent.Input("latent", optional=True),
121
+ ],
122
+ outputs=[
123
+ IO.Conditioning.Output(),
124
+ ]
125
+ )
126
+
127
+ @classmethod
128
+ def execute(cls, conditioning, latent=None) -> IO.NodeOutput:
129
+ if latent is not None:
130
+ conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
131
+ return IO.NodeOutput(conditioning)
132
+
133
+ class AceExtension(ComfyExtension):
134
+ @override
135
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
136
+ return [
137
+ TextEncodeAceStepAudio,
138
+ EmptyAceStepLatentAudio,
139
+ TextEncodeAceStepAudio15,
140
+ EmptyAceStep15LatentAudio,
141
+ ReferenceAudio,
142
+ ]
143
+
144
+ async def comfy_entrypoint() -> AceExtension:
145
+ return AceExtension()
ComfyUI/comfy_extras/nodes_advanced_samplers.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from tqdm.auto import trange
4
+ from typing_extensions import override
5
+
6
+ import comfy.model_patcher
7
+ import comfy.samplers
8
+ import comfy.utils
9
+ from comfy.k_diffusion.sampling import to_d
10
+ from comfy_api.latest import ComfyExtension, io
11
+
12
+
13
+ @torch.no_grad()
14
+ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable=None, total_upscale=2.0, upscale_method="bislerp", upscale_steps=None):
15
+ extra_args = {} if extra_args is None else extra_args
16
+
17
+ if upscale_steps is None:
18
+ upscale_steps = max(len(sigmas) // 2 + 1, 2)
19
+ else:
20
+ upscale_steps += 1
21
+ upscale_steps = min(upscale_steps, len(sigmas) + 1)
22
+
23
+ upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:]
24
+
25
+ orig_shape = x.size()
26
+ s_in = x.new_ones([x.shape[0]])
27
+ for i in trange(len(sigmas) - 1, disable=disable):
28
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
29
+ if callback is not None:
30
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
31
+
32
+ x = denoised
33
+ if i < len(upscales):
34
+ x = comfy.utils.common_upscale(x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled")
35
+
36
+ if sigmas[i + 1] > 0:
37
+ x += sigmas[i + 1] * torch.randn_like(x)
38
+ return x
39
+
40
+
41
+ class SamplerLCMUpscale(io.ComfyNode):
42
+ UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
43
+
44
+ @classmethod
45
+ def define_schema(cls) -> io.Schema:
46
+ return io.Schema(
47
+ node_id="SamplerLCMUpscale",
48
+ category="sampling/custom_sampling/samplers",
49
+ inputs=[
50
+ io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True),
51
+ io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True),
52
+ io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
53
+ ],
54
+ outputs=[io.Sampler.Output()],
55
+ )
56
+
57
+ @classmethod
58
+ def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
59
+ if scale_steps < 0:
60
+ scale_steps = None
61
+ sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
62
+ return io.NodeOutput(sampler)
63
+
64
+
65
+ @torch.no_grad()
66
+ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
67
+ extra_args = {} if extra_args is None else extra_args
68
+
69
+ temp = [0]
70
+ def post_cfg_function(args):
71
+ temp[0] = args["uncond_denoised"]
72
+ return args["denoised"]
73
+
74
+ model_options = extra_args.get("model_options", {}).copy()
75
+ extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
76
+
77
+ s_in = x.new_ones([x.shape[0]])
78
+ for i in trange(len(sigmas) - 1, disable=disable):
79
+ sigma_hat = sigmas[i]
80
+ denoised = model(x, sigma_hat * s_in, **extra_args)
81
+ d = to_d(x - denoised + temp[0], sigmas[i], denoised)
82
+ if callback is not None:
83
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
84
+ dt = sigmas[i + 1] - sigma_hat
85
+ x = x + d * dt
86
+ return x
87
+
88
+
89
+ class SamplerEulerCFGpp(io.ComfyNode):
90
+ @classmethod
91
+ def define_schema(cls) -> io.Schema:
92
+ return io.Schema(
93
+ node_id="SamplerEulerCFGpp",
94
+ display_name="SamplerEulerCFG++",
95
+ category="_for_testing", # "sampling/custom_sampling/samplers"
96
+ inputs=[
97
+ io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
98
+ ],
99
+ outputs=[io.Sampler.Output()],
100
+ is_experimental=True,
101
+ )
102
+
103
+ @classmethod
104
+ def execute(cls, version) -> io.NodeOutput:
105
+ if version == "alternative":
106
+ sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
107
+ else:
108
+ sampler = comfy.samplers.ksampler("euler_cfg_pp")
109
+ return io.NodeOutput(sampler)
110
+
111
+
112
+ class AdvancedSamplersExtension(ComfyExtension):
113
+ @override
114
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
115
+ return [
116
+ SamplerLCMUpscale,
117
+ SamplerEulerCFGpp,
118
+ ]
119
+
120
+ async def comfy_entrypoint() -> AdvancedSamplersExtension:
121
+ return AdvancedSamplersExtension()
ComfyUI/comfy_extras/nodes_align_your_steps.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
2
+ import numpy as np
3
+ import torch
4
+ from typing_extensions import override
5
+
6
+ from comfy_api.latest import ComfyExtension, io
7
+
8
+
9
+ def loglinear_interp(t_steps, num_steps):
10
+ """
11
+ Performs log-linear interpolation of a given array of decreasing numbers.
12
+ """
13
+ xs = np.linspace(0, 1, len(t_steps))
14
+ ys = np.log(t_steps[::-1])
15
+
16
+ new_xs = np.linspace(0, 1, num_steps)
17
+ new_ys = np.interp(new_xs, xs, ys)
18
+
19
+ interped_ys = np.exp(new_ys)[::-1].copy()
20
+ return interped_ys
21
+
22
+ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
23
+ "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
24
+ "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
25
+
26
+ class AlignYourStepsScheduler(io.ComfyNode):
27
+ @classmethod
28
+ def define_schema(cls) -> io.Schema:
29
+ return io.Schema(
30
+ node_id="AlignYourStepsScheduler",
31
+ search_aliases=["AYS scheduler"],
32
+ category="sampling/custom_sampling/schedulers",
33
+ inputs=[
34
+ io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
35
+ io.Int.Input("steps", default=10, min=1, max=10000),
36
+ io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
37
+ ],
38
+ outputs=[io.Sigmas.Output()],
39
+ )
40
+
41
+ def get_sigmas(self, model_type, steps, denoise):
42
+ # Deprecated: use the V3 schema's `execute` method instead of this.
43
+ return AlignYourStepsScheduler().execute(model_type, steps, denoise).result
44
+
45
+ @classmethod
46
+ def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
47
+ total_steps = steps
48
+ if denoise < 1.0:
49
+ if denoise <= 0.0:
50
+ return io.NodeOutput(torch.FloatTensor([]))
51
+ total_steps = round(steps * denoise)
52
+
53
+ sigmas = NOISE_LEVELS[model_type][:]
54
+ if (steps + 1) != len(sigmas):
55
+ sigmas = loglinear_interp(sigmas, steps + 1)
56
+
57
+ sigmas = sigmas[-(total_steps + 1):]
58
+ sigmas[-1] = 0
59
+ return io.NodeOutput(torch.FloatTensor(sigmas))
60
+
61
+
62
+ class AlignYourStepsExtension(ComfyExtension):
63
+ @override
64
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
65
+ return [
66
+ AlignYourStepsScheduler,
67
+ ]
68
+
69
+ async def comfy_entrypoint() -> AlignYourStepsExtension:
70
+ return AlignYourStepsExtension()
ComfyUI/comfy_extras/nodes_apg.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing_extensions import override
3
+
4
+ from comfy_api.latest import ComfyExtension, io
5
+
6
+
7
+ def project(v0, v1):
8
+ v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
9
+ v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
10
+ v0_orthogonal = v0 - v0_parallel
11
+ return v0_parallel, v0_orthogonal
12
+
13
+ class APG(io.ComfyNode):
14
+ @classmethod
15
+ def define_schema(cls) -> io.Schema:
16
+ return io.Schema(
17
+ node_id="APG",
18
+ display_name="Adaptive Projected Guidance",
19
+ category="sampling/custom_sampling",
20
+ inputs=[
21
+ io.Model.Input("model"),
22
+ io.Float.Input(
23
+ "eta",
24
+ default=1.0,
25
+ min=-10.0,
26
+ max=10.0,
27
+ step=0.01,
28
+ tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
29
+ advanced=True,
30
+ ),
31
+ io.Float.Input(
32
+ "norm_threshold",
33
+ default=5.0,
34
+ min=0.0,
35
+ max=50.0,
36
+ step=0.1,
37
+ tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
38
+ advanced=True,
39
+ ),
40
+ io.Float.Input(
41
+ "momentum",
42
+ default=0.0,
43
+ min=-5.0,
44
+ max=1.0,
45
+ step=0.01,
46
+ tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
47
+ advanced=True,
48
+ ),
49
+ ],
50
+ outputs=[io.Model.Output()],
51
+ )
52
+
53
+ @classmethod
54
+ def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
55
+ running_avg = 0
56
+ prev_sigma = None
57
+
58
+ def pre_cfg_function(args):
59
+ nonlocal running_avg, prev_sigma
60
+
61
+ if len(args["conds_out"]) == 1:
62
+ return args["conds_out"]
63
+
64
+ cond = args["conds_out"][0]
65
+ uncond = args["conds_out"][1]
66
+ sigma = args["sigma"][0]
67
+ cond_scale = args["cond_scale"]
68
+
69
+ if prev_sigma is not None and sigma > prev_sigma:
70
+ running_avg = 0
71
+ prev_sigma = sigma
72
+
73
+ guidance = cond - uncond
74
+
75
+ if momentum != 0:
76
+ if not torch.is_tensor(running_avg):
77
+ running_avg = guidance
78
+ else:
79
+ running_avg = momentum * running_avg + guidance
80
+ guidance = running_avg
81
+
82
+ if norm_threshold > 0:
83
+ guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
84
+ scale = torch.minimum(
85
+ torch.ones_like(guidance_norm),
86
+ norm_threshold / guidance_norm
87
+ )
88
+ guidance = guidance * scale
89
+
90
+ guidance_parallel, guidance_orthogonal = project(guidance, cond)
91
+ modified_guidance = guidance_orthogonal + eta * guidance_parallel
92
+
93
+ modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
94
+
95
+ return [modified_cond, uncond] + args["conds_out"][2:]
96
+
97
+ m = model.clone()
98
+ m.set_model_sampler_pre_cfg_function(pre_cfg_function)
99
+ return io.NodeOutput(m)
100
+
101
+
102
+ class ApgExtension(ComfyExtension):
103
+ @override
104
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
105
+ return [
106
+ APG,
107
+ ]
108
+
109
+ async def comfy_entrypoint() -> ApgExtension:
110
+ return ApgExtension()
ComfyUI/comfy_extras/nodes_attention_multiply.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import override
2
+
3
+ from comfy_api.latest import ComfyExtension, io
4
+
5
+
6
+ def attention_multiply(attn, model, q, k, v, out):
7
+ m = model.clone()
8
+ sd = model.model_state_dict()
9
+
10
+ for key in sd:
11
+ if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)):
12
+ m.add_patches({key: (None,)}, 0.0, q)
13
+ if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)):
14
+ m.add_patches({key: (None,)}, 0.0, k)
15
+ if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)):
16
+ m.add_patches({key: (None,)}, 0.0, v)
17
+ if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)):
18
+ m.add_patches({key: (None,)}, 0.0, out)
19
+
20
+ return m
21
+
22
+
23
+ class UNetSelfAttentionMultiply(io.ComfyNode):
24
+ @classmethod
25
+ def define_schema(cls) -> io.Schema:
26
+ return io.Schema(
27
+ node_id="UNetSelfAttentionMultiply",
28
+ category="_for_testing/attention_experiments",
29
+ inputs=[
30
+ io.Model.Input("model"),
31
+ io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
32
+ io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
33
+ io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
34
+ io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
35
+ ],
36
+ outputs=[io.Model.Output()],
37
+ is_experimental=True,
38
+ )
39
+
40
+ @classmethod
41
+ def execute(cls, model, q, k, v, out) -> io.NodeOutput:
42
+ m = attention_multiply("attn1", model, q, k, v, out)
43
+ return io.NodeOutput(m)
44
+
45
+
46
+ class UNetCrossAttentionMultiply(io.ComfyNode):
47
+ @classmethod
48
+ def define_schema(cls) -> io.Schema:
49
+ return io.Schema(
50
+ node_id="UNetCrossAttentionMultiply",
51
+ category="_for_testing/attention_experiments",
52
+ inputs=[
53
+ io.Model.Input("model"),
54
+ io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
55
+ io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
56
+ io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
57
+ io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
58
+ ],
59
+ outputs=[io.Model.Output()],
60
+ is_experimental=True,
61
+ )
62
+
63
+ @classmethod
64
+ def execute(cls, model, q, k, v, out) -> io.NodeOutput:
65
+ m = attention_multiply("attn2", model, q, k, v, out)
66
+ return io.NodeOutput(m)
67
+
68
+
69
+ class CLIPAttentionMultiply(io.ComfyNode):
70
+ @classmethod
71
+ def define_schema(cls) -> io.Schema:
72
+ return io.Schema(
73
+ node_id="CLIPAttentionMultiply",
74
+ search_aliases=["clip attention scale", "text encoder attention"],
75
+ category="_for_testing/attention_experiments",
76
+ inputs=[
77
+ io.Clip.Input("clip"),
78
+ io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
79
+ io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
80
+ io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
81
+ io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
82
+ ],
83
+ outputs=[io.Clip.Output()],
84
+ is_experimental=True,
85
+ )
86
+
87
+ @classmethod
88
+ def execute(cls, clip, q, k, v, out) -> io.NodeOutput:
89
+ m = clip.clone()
90
+ sd = m.patcher.model_state_dict()
91
+
92
+ for key in sd:
93
+ if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"):
94
+ m.add_patches({key: (None,)}, 0.0, q)
95
+ if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"):
96
+ m.add_patches({key: (None,)}, 0.0, k)
97
+ if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"):
98
+ m.add_patches({key: (None,)}, 0.0, v)
99
+ if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
100
+ m.add_patches({key: (None,)}, 0.0, out)
101
+ return io.NodeOutput(m)
102
+
103
+
104
+ class UNetTemporalAttentionMultiply(io.ComfyNode):
105
+ @classmethod
106
+ def define_schema(cls) -> io.Schema:
107
+ return io.Schema(
108
+ node_id="UNetTemporalAttentionMultiply",
109
+ category="_for_testing/attention_experiments",
110
+ inputs=[
111
+ io.Model.Input("model"),
112
+ io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
113
+ io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
114
+ io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
115
+ io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
116
+ ],
117
+ outputs=[io.Model.Output()],
118
+ is_experimental=True,
119
+ )
120
+
121
+ @classmethod
122
+ def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput:
123
+ m = model.clone()
124
+ sd = model.model_state_dict()
125
+
126
+ for k in sd:
127
+ if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")):
128
+ if '.time_stack.' in k:
129
+ m.add_patches({k: (None,)}, 0.0, self_temporal)
130
+ else:
131
+ m.add_patches({k: (None,)}, 0.0, self_structural)
132
+ elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")):
133
+ if '.time_stack.' in k:
134
+ m.add_patches({k: (None,)}, 0.0, cross_temporal)
135
+ else:
136
+ m.add_patches({k: (None,)}, 0.0, cross_structural)
137
+ return io.NodeOutput(m)
138
+
139
+
140
+ class AttentionMultiplyExtension(ComfyExtension):
141
+ @override
142
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
143
+ return [
144
+ UNetSelfAttentionMultiply,
145
+ UNetCrossAttentionMultiply,
146
+ CLIPAttentionMultiply,
147
+ UNetTemporalAttentionMultiply,
148
+ ]
149
+
150
+ async def comfy_entrypoint() -> AttentionMultiplyExtension:
151
+ return AttentionMultiplyExtension()
ComfyUI/comfy_extras/nodes_audio.py ADDED
@@ -0,0 +1,794 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import av
4
+ import torchaudio
5
+ import torch
6
+ import comfy.model_management
7
+ import folder_paths
8
+ import os
9
+ import hashlib
10
+ import node_helpers
11
+ import logging
12
+ from typing_extensions import override
13
+ from comfy_api.latest import ComfyExtension, IO, UI
14
+
15
+ class EmptyLatentAudio(IO.ComfyNode):
16
+ @classmethod
17
+ def define_schema(cls):
18
+ return IO.Schema(
19
+ node_id="EmptyLatentAudio",
20
+ display_name="Empty Latent Audio",
21
+ category="latent/audio",
22
+ essentials_category="Audio",
23
+ inputs=[
24
+ IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
25
+ IO.Int.Input(
26
+ "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch.",
27
+ ),
28
+ ],
29
+ outputs=[IO.Latent.Output()],
30
+ )
31
+
32
+ @classmethod
33
+ def execute(cls, seconds, batch_size) -> IO.NodeOutput:
34
+ length = round((seconds * 44100 / 2048) / 2) * 2
35
+ latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
36
+ return IO.NodeOutput({"samples":latent, "type": "audio"})
37
+
38
+ generate = execute # TODO: remove
39
+
40
+
41
+ class ConditioningStableAudio(IO.ComfyNode):
42
+ @classmethod
43
+ def define_schema(cls):
44
+ return IO.Schema(
45
+ node_id="ConditioningStableAudio",
46
+ category="conditioning",
47
+ inputs=[
48
+ IO.Conditioning.Input("positive"),
49
+ IO.Conditioning.Input("negative"),
50
+ IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
51
+ IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
52
+ ],
53
+ outputs=[
54
+ IO.Conditioning.Output(display_name="positive"),
55
+ IO.Conditioning.Output(display_name="negative"),
56
+ ],
57
+ )
58
+
59
+ @classmethod
60
+ def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
61
+ positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
62
+ negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
63
+ return IO.NodeOutput(positive, negative)
64
+
65
+ append = execute # TODO: remove
66
+
67
+
68
+ class VAEEncodeAudio(IO.ComfyNode):
69
+ @classmethod
70
+ def define_schema(cls):
71
+ return IO.Schema(
72
+ node_id="VAEEncodeAudio",
73
+ search_aliases=["audio to latent"],
74
+ display_name="VAE Encode Audio",
75
+ category="latent/audio",
76
+ inputs=[
77
+ IO.Audio.Input("audio"),
78
+ IO.Vae.Input("vae"),
79
+ ],
80
+ outputs=[IO.Latent.Output()],
81
+ )
82
+
83
+ @classmethod
84
+ def execute(cls, vae, audio) -> IO.NodeOutput:
85
+ sample_rate = audio["sample_rate"]
86
+ vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
87
+ if vae_sample_rate != sample_rate:
88
+ waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate)
89
+ else:
90
+ waveform = audio["waveform"]
91
+
92
+ t = vae.encode(waveform.movedim(1, -1))
93
+ return IO.NodeOutput({"samples": t})
94
+
95
+ encode = execute # TODO: remove
96
+
97
+
98
+ def vae_decode_audio(vae, samples, tile=None, overlap=None):
99
+ if tile is not None:
100
+ audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1)
101
+ else:
102
+ audio = vae.decode(samples["samples"]).movedim(-1, 1)
103
+
104
+ std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
105
+ std[std < 1.0] = 1.0
106
+ audio /= std
107
+ vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
108
+ return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
109
+
110
+
111
+ class VAEDecodeAudio(IO.ComfyNode):
112
+ @classmethod
113
+ def define_schema(cls):
114
+ return IO.Schema(
115
+ node_id="VAEDecodeAudio",
116
+ search_aliases=["latent to audio"],
117
+ display_name="VAE Decode Audio",
118
+ category="latent/audio",
119
+ inputs=[
120
+ IO.Latent.Input("samples"),
121
+ IO.Vae.Input("vae"),
122
+ ],
123
+ outputs=[IO.Audio.Output()],
124
+ )
125
+
126
+ @classmethod
127
+ def execute(cls, vae, samples) -> IO.NodeOutput:
128
+ return IO.NodeOutput(vae_decode_audio(vae, samples))
129
+
130
+ decode = execute # TODO: remove
131
+
132
+
133
+ class VAEDecodeAudioTiled(IO.ComfyNode):
134
+ @classmethod
135
+ def define_schema(cls):
136
+ return IO.Schema(
137
+ node_id="VAEDecodeAudioTiled",
138
+ search_aliases=["latent to audio"],
139
+ display_name="VAE Decode Audio (Tiled)",
140
+ category="latent/audio",
141
+ inputs=[
142
+ IO.Latent.Input("samples"),
143
+ IO.Vae.Input("vae"),
144
+ IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
145
+ IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
146
+ ],
147
+ outputs=[IO.Audio.Output()],
148
+ )
149
+
150
+ @classmethod
151
+ def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
152
+ return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
153
+
154
+
155
+ class SaveAudio(IO.ComfyNode):
156
+ @classmethod
157
+ def define_schema(cls):
158
+ return IO.Schema(
159
+ node_id="SaveAudio",
160
+ search_aliases=["export flac"],
161
+ display_name="Save Audio (FLAC)",
162
+ category="audio",
163
+ essentials_category="Audio",
164
+ inputs=[
165
+ IO.Audio.Input("audio"),
166
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
167
+ ],
168
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
169
+ is_output_node=True,
170
+ )
171
+
172
+ @classmethod
173
+ def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
174
+ return IO.NodeOutput(
175
+ ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
176
+ )
177
+
178
+ save_flac = execute # TODO: remove
179
+
180
+
181
+ class SaveAudioMP3(IO.ComfyNode):
182
+ @classmethod
183
+ def define_schema(cls):
184
+ return IO.Schema(
185
+ node_id="SaveAudioMP3",
186
+ search_aliases=["export mp3"],
187
+ display_name="Save Audio (MP3)",
188
+ category="audio",
189
+ essentials_category="Audio",
190
+ inputs=[
191
+ IO.Audio.Input("audio"),
192
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
193
+ IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
194
+ ],
195
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
196
+ is_output_node=True,
197
+ )
198
+
199
+ @classmethod
200
+ def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
201
+ return IO.NodeOutput(
202
+ ui=UI.AudioSaveHelper.get_save_audio_ui(
203
+ audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
204
+ )
205
+ )
206
+
207
+ save_mp3 = execute # TODO: remove
208
+
209
+
210
+ class SaveAudioOpus(IO.ComfyNode):
211
+ @classmethod
212
+ def define_schema(cls):
213
+ return IO.Schema(
214
+ node_id="SaveAudioOpus",
215
+ search_aliases=["export opus"],
216
+ display_name="Save Audio (Opus)",
217
+ category="audio",
218
+ inputs=[
219
+ IO.Audio.Input("audio"),
220
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
221
+ IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
222
+ ],
223
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
224
+ is_output_node=True,
225
+ )
226
+
227
+ @classmethod
228
+ def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
229
+ return IO.NodeOutput(
230
+ ui=UI.AudioSaveHelper.get_save_audio_ui(
231
+ audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
232
+ )
233
+ )
234
+
235
+ save_opus = execute # TODO: remove
236
+
237
+
238
+ class PreviewAudio(IO.ComfyNode):
239
+ @classmethod
240
+ def define_schema(cls):
241
+ return IO.Schema(
242
+ node_id="PreviewAudio",
243
+ search_aliases=["play audio"],
244
+ display_name="Preview Audio",
245
+ category="audio",
246
+ inputs=[
247
+ IO.Audio.Input("audio"),
248
+ ],
249
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
250
+ is_output_node=True,
251
+ )
252
+
253
+ @classmethod
254
+ def execute(cls, audio) -> IO.NodeOutput:
255
+ return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
256
+
257
+ save_flac = execute # TODO: remove
258
+
259
+
260
+ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
261
+ """Convert audio to float 32 bits PCM format."""
262
+ if wav.dtype.is_floating_point:
263
+ return wav
264
+ elif wav.dtype == torch.int16:
265
+ return wav.float() / (2 ** 15)
266
+ elif wav.dtype == torch.int32:
267
+ return wav.float() / (2 ** 31)
268
+ raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
269
+
270
+ def load(filepath: str) -> tuple[torch.Tensor, int]:
271
+ with av.open(filepath) as af:
272
+ if not af.streams.audio:
273
+ raise ValueError("No audio stream found in the file.")
274
+
275
+ stream = af.streams.audio[0]
276
+ sr = stream.codec_context.sample_rate
277
+ n_channels = stream.channels
278
+
279
+ frames = []
280
+ length = 0
281
+ for frame in af.decode(streams=stream.index):
282
+ buf = torch.from_numpy(frame.to_ndarray())
283
+ if buf.shape[0] != n_channels:
284
+ buf = buf.view(-1, n_channels).t()
285
+
286
+ frames.append(buf)
287
+ length += buf.shape[1]
288
+
289
+ if not frames:
290
+ raise ValueError("No audio frames decoded.")
291
+
292
+ wav = torch.cat(frames, dim=1)
293
+ wav = f32_pcm(wav)
294
+ return wav, sr
295
+
296
+ class LoadAudio(IO.ComfyNode):
297
+ @classmethod
298
+ def define_schema(cls):
299
+ input_dir = folder_paths.get_input_directory()
300
+ files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
301
+ return IO.Schema(
302
+ node_id="LoadAudio",
303
+ search_aliases=["import audio", "open audio", "audio file"],
304
+ display_name="Load Audio",
305
+ category="audio",
306
+ essentials_category="Audio",
307
+ inputs=[
308
+ IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
309
+ ],
310
+ outputs=[IO.Audio.Output()],
311
+ )
312
+
313
+ @classmethod
314
+ def execute(cls, audio) -> IO.NodeOutput:
315
+ audio_path = folder_paths.get_annotated_filepath(audio)
316
+ waveform, sample_rate = load(audio_path)
317
+ audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
318
+ return IO.NodeOutput(audio)
319
+
320
+ @classmethod
321
+ def fingerprint_inputs(cls, audio):
322
+ image_path = folder_paths.get_annotated_filepath(audio)
323
+ m = hashlib.sha256()
324
+ with open(image_path, 'rb') as f:
325
+ m.update(f.read())
326
+ return m.digest().hex()
327
+
328
+ @classmethod
329
+ def validate_inputs(cls, audio):
330
+ if not folder_paths.exists_annotated_filepath(audio):
331
+ return "Invalid audio file: {}".format(audio)
332
+ return True
333
+
334
+ load = execute # TODO: remove
335
+
336
+
337
+ class RecordAudio(IO.ComfyNode):
338
+ @classmethod
339
+ def define_schema(cls):
340
+ return IO.Schema(
341
+ node_id="RecordAudio",
342
+ search_aliases=["microphone input", "audio capture", "voice input"],
343
+ display_name="Record Audio",
344
+ category="audio",
345
+ inputs=[
346
+ IO.Custom("AUDIO_RECORD").Input("audio"),
347
+ ],
348
+ outputs=[IO.Audio.Output()],
349
+ )
350
+
351
+ @classmethod
352
+ def execute(cls, audio) -> IO.NodeOutput:
353
+ audio_path = folder_paths.get_annotated_filepath(audio)
354
+
355
+ waveform, sample_rate = load(audio_path)
356
+ audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
357
+ return IO.NodeOutput(audio)
358
+
359
+ load = execute # TODO: remove
360
+
361
+
362
+ class TrimAudioDuration(IO.ComfyNode):
363
+ @classmethod
364
+ def define_schema(cls):
365
+ return IO.Schema(
366
+ node_id="TrimAudioDuration",
367
+ search_aliases=["cut audio", "audio clip", "shorten audio"],
368
+ display_name="Trim Audio Duration",
369
+ description="Trim audio tensor into chosen time range.",
370
+ category="audio",
371
+ inputs=[
372
+ IO.Audio.Input("audio"),
373
+ IO.Float.Input(
374
+ "start_index",
375
+ default=0.0,
376
+ min=-0xffffffffffffffff,
377
+ max=0xffffffffffffffff,
378
+ step=0.01,
379
+ tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
380
+ ),
381
+ IO.Float.Input(
382
+ "duration",
383
+ default=60.0,
384
+ min=0.0,
385
+ step=0.01,
386
+ tooltip="Duration in seconds",
387
+ ),
388
+ ],
389
+ outputs=[IO.Audio.Output()],
390
+ )
391
+
392
+ @classmethod
393
+ def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
394
+ waveform = audio["waveform"]
395
+ sample_rate = audio["sample_rate"]
396
+ audio_length = waveform.shape[-1]
397
+
398
+ if start_index < 0:
399
+ start_frame = audio_length + int(round(start_index * sample_rate))
400
+ else:
401
+ start_frame = int(round(start_index * sample_rate))
402
+ start_frame = max(0, min(start_frame, audio_length - 1))
403
+
404
+ end_frame = start_frame + int(round(duration * sample_rate))
405
+ end_frame = max(0, min(end_frame, audio_length))
406
+
407
+ if start_frame >= end_frame:
408
+ raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
409
+
410
+ return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
411
+
412
+ trim = execute # TODO: remove
413
+
414
+
415
+ class SplitAudioChannels(IO.ComfyNode):
416
+ @classmethod
417
+ def define_schema(cls):
418
+ return IO.Schema(
419
+ node_id="SplitAudioChannels",
420
+ search_aliases=["stereo to mono"],
421
+ display_name="Split Audio Channels",
422
+ description="Separates the audio into left and right channels.",
423
+ category="audio",
424
+ inputs=[
425
+ IO.Audio.Input("audio"),
426
+ ],
427
+ outputs=[
428
+ IO.Audio.Output(display_name="left"),
429
+ IO.Audio.Output(display_name="right"),
430
+ ],
431
+ )
432
+
433
+ @classmethod
434
+ def execute(cls, audio) -> IO.NodeOutput:
435
+ waveform = audio["waveform"]
436
+ sample_rate = audio["sample_rate"]
437
+
438
+ if waveform.shape[1] != 2:
439
+ raise ValueError("AudioSplit: Input audio has only one channel.")
440
+
441
+ left_channel = waveform[..., 0:1, :]
442
+ right_channel = waveform[..., 1:2, :]
443
+
444
+ return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
445
+
446
+ separate = execute # TODO: remove
447
+
448
+ class JoinAudioChannels(IO.ComfyNode):
449
+ @classmethod
450
+ def define_schema(cls):
451
+ return IO.Schema(
452
+ node_id="JoinAudioChannels",
453
+ display_name="Join Audio Channels",
454
+ description="Joins left and right mono audio channels into a stereo audio.",
455
+ category="audio",
456
+ inputs=[
457
+ IO.Audio.Input("audio_left"),
458
+ IO.Audio.Input("audio_right"),
459
+ ],
460
+ outputs=[
461
+ IO.Audio.Output(display_name="audio"),
462
+ ],
463
+ )
464
+
465
+ @classmethod
466
+ def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
467
+ waveform_left = audio_left["waveform"]
468
+ sample_rate_left = audio_left["sample_rate"]
469
+ waveform_right = audio_right["waveform"]
470
+ sample_rate_right = audio_right["sample_rate"]
471
+
472
+ if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1:
473
+ raise ValueError("AudioJoin: Both input audios must be mono.")
474
+
475
+ # Handle different sample rates by resampling to the higher rate
476
+ waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates(
477
+ waveform_left, sample_rate_left, waveform_right, sample_rate_right
478
+ )
479
+
480
+ # Handle different lengths by trimming to the shorter length
481
+ length_left = waveform_left.shape[-1]
482
+ length_right = waveform_right.shape[-1]
483
+
484
+ if length_left != length_right:
485
+ min_length = min(length_left, length_right)
486
+ if length_left > min_length:
487
+ logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.")
488
+ waveform_left = waveform_left[..., :min_length]
489
+ if length_right > min_length:
490
+ logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.")
491
+ waveform_right = waveform_right[..., :min_length]
492
+
493
+ # Join the channels into stereo
494
+ left_channel = waveform_left[..., 0:1, :]
495
+ right_channel = waveform_right[..., 0:1, :]
496
+ stereo_waveform = torch.cat([left_channel, right_channel], dim=1)
497
+
498
+ return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate})
499
+
500
+
501
+ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
502
+ if sample_rate_1 != sample_rate_2:
503
+ if sample_rate_1 > sample_rate_2:
504
+ waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1)
505
+ output_sample_rate = sample_rate_1
506
+ logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
507
+ else:
508
+ waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2)
509
+ output_sample_rate = sample_rate_2
510
+ logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
511
+ else:
512
+ output_sample_rate = sample_rate_1
513
+ return waveform_1, waveform_2, output_sample_rate
514
+
515
+
516
+ class AudioConcat(IO.ComfyNode):
517
+ @classmethod
518
+ def define_schema(cls):
519
+ return IO.Schema(
520
+ node_id="AudioConcat",
521
+ search_aliases=["join audio", "combine audio", "append audio"],
522
+ display_name="Audio Concat",
523
+ description="Concatenates the audio1 to audio2 in the specified direction.",
524
+ category="audio",
525
+ inputs=[
526
+ IO.Audio.Input("audio1"),
527
+ IO.Audio.Input("audio2"),
528
+ IO.Combo.Input(
529
+ "direction",
530
+ options=['after', 'before'],
531
+ default="after",
532
+ tooltip="Whether to append audio2 after or before audio1.",
533
+ )
534
+ ],
535
+ outputs=[IO.Audio.Output()],
536
+ )
537
+
538
+ @classmethod
539
+ def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
540
+ waveform_1 = audio1["waveform"]
541
+ waveform_2 = audio2["waveform"]
542
+ sample_rate_1 = audio1["sample_rate"]
543
+ sample_rate_2 = audio2["sample_rate"]
544
+
545
+ if waveform_1.shape[1] == 1:
546
+ waveform_1 = waveform_1.repeat(1, 2, 1)
547
+ logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
548
+ if waveform_2.shape[1] == 1:
549
+ waveform_2 = waveform_2.repeat(1, 2, 1)
550
+ logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")
551
+
552
+ waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
553
+
554
+ if direction == 'after':
555
+ concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
556
+ elif direction == 'before':
557
+ concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
558
+
559
+ return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
560
+
561
+ concat = execute # TODO: remove
562
+
563
+
564
+ class AudioMerge(IO.ComfyNode):
565
+ @classmethod
566
+ def define_schema(cls):
567
+ return IO.Schema(
568
+ node_id="AudioMerge",
569
+ search_aliases=["mix audio", "overlay audio", "layer audio"],
570
+ display_name="Audio Merge",
571
+ description="Combine two audio tracks by overlaying their waveforms.",
572
+ category="audio",
573
+ inputs=[
574
+ IO.Audio.Input("audio1"),
575
+ IO.Audio.Input("audio2"),
576
+ IO.Combo.Input(
577
+ "merge_method",
578
+ options=["add", "mean", "subtract", "multiply"],
579
+ tooltip="The method used to combine the audio waveforms.",
580
+ )
581
+ ],
582
+ outputs=[IO.Audio.Output()],
583
+ )
584
+
585
+ @classmethod
586
+ def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
587
+ waveform_1 = audio1["waveform"]
588
+ waveform_2 = audio2["waveform"]
589
+ sample_rate_1 = audio1["sample_rate"]
590
+ sample_rate_2 = audio2["sample_rate"]
591
+
592
+ waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
593
+
594
+ length_1 = waveform_1.shape[-1]
595
+ length_2 = waveform_2.shape[-1]
596
+
597
+ if length_2 > length_1:
598
+ logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
599
+ waveform_2 = waveform_2[..., :length_1]
600
+ elif length_2 < length_1:
601
+ logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
602
+ pad_shape = list(waveform_2.shape)
603
+ pad_shape[-1] = length_1 - length_2
604
+ pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device)
605
+ waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1)
606
+
607
+ if merge_method == "add":
608
+ waveform = waveform_1 + waveform_2
609
+ elif merge_method == "subtract":
610
+ waveform = waveform_1 - waveform_2
611
+ elif merge_method == "multiply":
612
+ waveform = waveform_1 * waveform_2
613
+ elif merge_method == "mean":
614
+ waveform = (waveform_1 + waveform_2) / 2
615
+
616
+ max_val = waveform.abs().max()
617
+ if max_val > 1.0:
618
+ waveform = waveform / max_val
619
+
620
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
621
+
622
+ merge = execute # TODO: remove
623
+
624
+
625
+ class AudioAdjustVolume(IO.ComfyNode):
626
+ @classmethod
627
+ def define_schema(cls):
628
+ return IO.Schema(
629
+ node_id="AudioAdjustVolume",
630
+ search_aliases=["audio gain", "loudness", "audio level"],
631
+ display_name="Audio Adjust Volume",
632
+ category="audio",
633
+ inputs=[
634
+ IO.Audio.Input("audio"),
635
+ IO.Int.Input(
636
+ "volume",
637
+ default=1,
638
+ min=-100,
639
+ max=100,
640
+ tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
641
+ )
642
+ ],
643
+ outputs=[IO.Audio.Output()],
644
+ )
645
+
646
+ @classmethod
647
+ def execute(cls, audio, volume) -> IO.NodeOutput:
648
+ if volume == 0:
649
+ return IO.NodeOutput(audio)
650
+ waveform = audio["waveform"]
651
+ sample_rate = audio["sample_rate"]
652
+
653
+ gain = 10 ** (volume / 20)
654
+ waveform = waveform * gain
655
+
656
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
657
+
658
+ adjust_volume = execute # TODO: remove
659
+
660
+
661
+ class EmptyAudio(IO.ComfyNode):
662
+ @classmethod
663
+ def define_schema(cls):
664
+ return IO.Schema(
665
+ node_id="EmptyAudio",
666
+ search_aliases=["blank audio"],
667
+ display_name="Empty Audio",
668
+ category="audio",
669
+ inputs=[
670
+ IO.Float.Input(
671
+ "duration",
672
+ default=60.0,
673
+ min=0.0,
674
+ max=0xffffffffffffffff,
675
+ step=0.01,
676
+ tooltip="Duration of the empty audio clip in seconds",
677
+ ),
678
+ IO.Int.Input(
679
+ "sample_rate",
680
+ default=44100,
681
+ tooltip="Sample rate of the empty audio clip.",
682
+ min=1,
683
+ max=192000,
684
+ advanced=True,
685
+ ),
686
+ IO.Int.Input(
687
+ "channels",
688
+ default=2,
689
+ min=1,
690
+ max=2,
691
+ tooltip="Number of audio channels (1 for mono, 2 for stereo).",
692
+ advanced=True,
693
+ ),
694
+ ],
695
+ outputs=[IO.Audio.Output()],
696
+ )
697
+
698
+ @classmethod
699
+ def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
700
+ num_samples = int(round(duration * sample_rate))
701
+ waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
702
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
703
+
704
+ create_empty_audio = execute # TODO: remove
705
+
706
+
707
+ class AudioEqualizer3Band(IO.ComfyNode):
708
+ @classmethod
709
+ def define_schema(cls):
710
+ return IO.Schema(
711
+ node_id="AudioEqualizer3Band",
712
+ search_aliases=["eq", "bass boost", "treble boost", "equalizer"],
713
+ display_name="Audio Equalizer (3-Band)",
714
+ category="audio",
715
+ is_experimental=True,
716
+ inputs=[
717
+ IO.Audio.Input("audio"),
718
+ IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"),
719
+ IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"),
720
+ IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"),
721
+ IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"),
722
+ IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"),
723
+ IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"),
724
+ IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"),
725
+ ],
726
+ outputs=[IO.Audio.Output()],
727
+ )
728
+
729
+ @classmethod
730
+ def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
731
+ waveform = audio["waveform"]
732
+ sample_rate = audio["sample_rate"]
733
+ eq_waveform = waveform.clone()
734
+
735
+ # 1. Apply Low Shelf (Bass)
736
+ if low_gain_dB != 0:
737
+ eq_waveform = torchaudio.functional.bass_biquad(
738
+ eq_waveform,
739
+ sample_rate,
740
+ gain=low_gain_dB,
741
+ central_freq=float(low_freq),
742
+ Q=0.707
743
+ )
744
+
745
+ # 2. Apply Peaking EQ (Mids)
746
+ if mid_gain_dB != 0:
747
+ eq_waveform = torchaudio.functional.equalizer_biquad(
748
+ eq_waveform,
749
+ sample_rate,
750
+ center_freq=float(mid_freq),
751
+ gain=mid_gain_dB,
752
+ Q=mid_q
753
+ )
754
+
755
+ # 3. Apply High Shelf (Treble)
756
+ if high_gain_dB != 0:
757
+ eq_waveform = torchaudio.functional.treble_biquad(
758
+ eq_waveform,
759
+ sample_rate,
760
+ gain=high_gain_dB,
761
+ central_freq=float(high_freq),
762
+ Q=0.707
763
+ )
764
+
765
+ return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate})
766
+
767
+
768
+ class AudioExtension(ComfyExtension):
769
+ @override
770
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
771
+ return [
772
+ EmptyLatentAudio,
773
+ VAEEncodeAudio,
774
+ VAEDecodeAudio,
775
+ VAEDecodeAudioTiled,
776
+ SaveAudio,
777
+ SaveAudioMP3,
778
+ SaveAudioOpus,
779
+ LoadAudio,
780
+ PreviewAudio,
781
+ ConditioningStableAudio,
782
+ RecordAudio,
783
+ TrimAudioDuration,
784
+ SplitAudioChannels,
785
+ JoinAudioChannels,
786
+ AudioConcat,
787
+ AudioMerge,
788
+ AudioAdjustVolume,
789
+ EmptyAudio,
790
+ AudioEqualizer3Band,
791
+ ]
792
+
793
+ async def comfy_entrypoint() -> AudioExtension:
794
+ return AudioExtension()
ComfyUI/comfy_extras/nodes_audio_encoder.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import folder_paths
2
+ import comfy.audio_encoders.audio_encoders
3
+ import comfy.utils
4
+ from typing_extensions import override
5
+ from comfy_api.latest import ComfyExtension, io
6
+
7
+
8
+ class AudioEncoderLoader(io.ComfyNode):
9
+ @classmethod
10
+ def define_schema(cls) -> io.Schema:
11
+ return io.Schema(
12
+ node_id="AudioEncoderLoader",
13
+ category="loaders",
14
+ inputs=[
15
+ io.Combo.Input(
16
+ "audio_encoder_name",
17
+ options=folder_paths.get_filename_list("audio_encoders"),
18
+ ),
19
+ ],
20
+ outputs=[io.AudioEncoder.Output()],
21
+ )
22
+
23
+ @classmethod
24
+ def execute(cls, audio_encoder_name) -> io.NodeOutput:
25
+ audio_encoder_name = folder_paths.get_full_path_or_raise("audio_encoders", audio_encoder_name)
26
+ sd = comfy.utils.load_torch_file(audio_encoder_name, safe_load=True)
27
+ audio_encoder = comfy.audio_encoders.audio_encoders.load_audio_encoder_from_sd(sd)
28
+ if audio_encoder is None:
29
+ raise RuntimeError("ERROR: audio encoder file is invalid and does not contain a valid model.")
30
+ return io.NodeOutput(audio_encoder)
31
+
32
+
33
+ class AudioEncoderEncode(io.ComfyNode):
34
+ @classmethod
35
+ def define_schema(cls) -> io.Schema:
36
+ return io.Schema(
37
+ node_id="AudioEncoderEncode",
38
+ category="conditioning",
39
+ inputs=[
40
+ io.AudioEncoder.Input("audio_encoder"),
41
+ io.Audio.Input("audio"),
42
+ ],
43
+ outputs=[io.AudioEncoderOutput.Output()],
44
+ )
45
+
46
+ @classmethod
47
+ def execute(cls, audio_encoder, audio) -> io.NodeOutput:
48
+ output = audio_encoder.encode_audio(audio["waveform"], audio["sample_rate"])
49
+ return io.NodeOutput(output)
50
+
51
+
52
+ class AudioEncoder(ComfyExtension):
53
+ @override
54
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
55
+ return [
56
+ AudioEncoderLoader,
57
+ AudioEncoderEncode,
58
+ ]
59
+
60
+
61
+ async def comfy_entrypoint() -> AudioEncoder:
62
+ return AudioEncoder()
ComfyUI/comfy_extras/nodes_camera_trajectory.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import torch
3
+ import numpy as np
4
+ from einops import rearrange
5
+ from typing_extensions import override
6
+ import comfy.model_management
7
+
8
+ from comfy_api.latest import ComfyExtension, io
9
+
10
+
11
+ CAMERA_DICT = {
12
+ "base_T_norm": 1.5,
13
+ "base_angle": np.pi/3,
14
+ "Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
15
+ "Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]},
16
+ "Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]},
17
+ "Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]},
18
+ "Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]},
19
+ "Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
20
+ "Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
21
+ "Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
22
+ "ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
23
+ }
24
+
25
+
26
+ def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
27
+
28
+ def get_relative_pose(cam_params):
29
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
30
+ """
31
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
32
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
33
+ cam_to_origin = 0
34
+ target_cam_c2w = np.array([
35
+ [1, 0, 0, 0],
36
+ [0, 1, 0, -cam_to_origin],
37
+ [0, 0, 1, 0],
38
+ [0, 0, 0, 1]
39
+ ])
40
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
41
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
42
+ ret_poses = np.array(ret_poses, dtype=np.float32)
43
+ return ret_poses
44
+
45
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
46
+ """
47
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
48
+
49
+ sample_wh_ratio = width / height
50
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
51
+
52
+ if pose_wh_ratio > sample_wh_ratio:
53
+ resized_ori_w = height * pose_wh_ratio
54
+ for cam_param in cam_params:
55
+ cam_param.fx = resized_ori_w * cam_param.fx / width
56
+ else:
57
+ resized_ori_h = width / pose_wh_ratio
58
+ for cam_param in cam_params:
59
+ cam_param.fy = resized_ori_h * cam_param.fy / height
60
+
61
+ intrinsic = np.asarray([[cam_param.fx * width,
62
+ cam_param.fy * height,
63
+ cam_param.cx * width,
64
+ cam_param.cy * height]
65
+ for cam_param in cam_params], dtype=np.float32)
66
+
67
+ K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
68
+ c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
69
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
70
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
71
+ plucker_embedding = plucker_embedding[None]
72
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
73
+ return plucker_embedding
74
+
75
+ class Camera(object):
76
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
77
+ """
78
+ def __init__(self, entry):
79
+ fx, fy, cx, cy = entry[1:5]
80
+ self.fx = fx
81
+ self.fy = fy
82
+ self.cx = cx
83
+ self.cy = cy
84
+ c2w_mat = np.array(entry[7:]).reshape(4, 4)
85
+ self.c2w_mat = c2w_mat
86
+ self.w2c_mat = np.linalg.inv(c2w_mat)
87
+
88
+ def ray_condition(K, c2w, H, W, device):
89
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
90
+ """
91
+ # c2w: B, V, 4, 4
92
+ # K: B, V, 4
93
+
94
+ B = K.shape[0]
95
+
96
+ j, i = torch.meshgrid(
97
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
98
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
99
+ indexing='ij'
100
+ )
101
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
102
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
103
+
104
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
105
+
106
+ zs = torch.ones_like(i) # [B, HxW]
107
+ xs = (i - cx) / fx * zs
108
+ ys = (j - cy) / fy * zs
109
+ zs = zs.expand_as(ys)
110
+
111
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
112
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
113
+
114
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
115
+ rays_o = c2w[..., :3, 3] # B, V, 3
116
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
117
+ # c2w @ dirctions
118
+ rays_dxo = torch.cross(rays_o, rays_d)
119
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
120
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
121
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
122
+ return plucker
123
+
124
+ def get_camera_motion(angle, T, speed, n=81):
125
+ def compute_R_form_rad_angle(angles):
126
+ theta_x, theta_y, theta_z = angles
127
+ Rx = np.array([[1, 0, 0],
128
+ [0, np.cos(theta_x), -np.sin(theta_x)],
129
+ [0, np.sin(theta_x), np.cos(theta_x)]])
130
+
131
+ Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
132
+ [0, 1, 0],
133
+ [-np.sin(theta_y), 0, np.cos(theta_y)]])
134
+
135
+ Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
136
+ [np.sin(theta_z), np.cos(theta_z), 0],
137
+ [0, 0, 1]])
138
+
139
+ R = np.dot(Rz, np.dot(Ry, Rx))
140
+ return R
141
+ RT = []
142
+ for i in range(n):
143
+ _angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
144
+ R = compute_R_form_rad_angle(_angle)
145
+ _T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
146
+ _RT = np.concatenate([R,_T], axis=1)
147
+ RT.append(_RT)
148
+ RT = np.stack(RT)
149
+ return RT
150
+
151
+ class WanCameraEmbedding(io.ComfyNode):
152
+ @classmethod
153
+ def define_schema(cls):
154
+ return io.Schema(
155
+ node_id="WanCameraEmbedding",
156
+ category="camera",
157
+ inputs=[
158
+ io.Combo.Input(
159
+ "camera_pose",
160
+ options=[
161
+ "Static",
162
+ "Pan Up",
163
+ "Pan Down",
164
+ "Pan Left",
165
+ "Pan Right",
166
+ "Zoom In",
167
+ "Zoom Out",
168
+ "Anti Clockwise (ACW)",
169
+ "ClockWise (CW)",
170
+ ],
171
+ default="Static",
172
+ ),
173
+ io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
174
+ io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
175
+ io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
176
+ io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
177
+ io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True),
178
+ io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True),
179
+ io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True),
180
+ io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True),
181
+ ],
182
+ outputs=[
183
+ io.WanCameraEmbedding.Output(display_name="camera_embedding"),
184
+ io.Int.Output(display_name="width"),
185
+ io.Int.Output(display_name="height"),
186
+ io.Int.Output(display_name="length"),
187
+ ],
188
+ )
189
+
190
+ @classmethod
191
+ def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
192
+ """
193
+ Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021)
194
+ Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
195
+ """
196
+ motion_list = [camera_pose]
197
+ speed = speed
198
+ angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
199
+ T = np.array(CAMERA_DICT[motion_list[0]]["T"])
200
+ RT = get_camera_motion(angle, T, speed, length)
201
+
202
+ trajs=[]
203
+ for cp in RT.tolist():
204
+ traj=[fx,fy,cx,cy,0,0]
205
+ traj.extend(cp[0])
206
+ traj.extend(cp[1])
207
+ traj.extend(cp[2])
208
+ traj.extend([0,0,0,1])
209
+ trajs.append(traj)
210
+
211
+ cam_params = np.array([[float(x) for x in pose] for pose in trajs])
212
+ cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
213
+ control_camera_video = process_pose_params(cam_params, width=width, height=height)
214
+ control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
215
+
216
+ control_camera_video = torch.concat(
217
+ [
218
+ torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
219
+ control_camera_video[:, :, 1:]
220
+ ], dim=2
221
+ ).transpose(1, 2)
222
+
223
+ # Reshape, transpose, and view into desired shape
224
+ b, f, c, h, w = control_camera_video.shape
225
+ control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
226
+ control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
227
+
228
+ return io.NodeOutput(control_camera_video, width, height, length)
229
+
230
+
231
+ class CameraTrajectoryExtension(ComfyExtension):
232
+ @override
233
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
234
+ return [
235
+ WanCameraEmbedding,
236
+ ]
237
+
238
+ async def comfy_entrypoint() -> CameraTrajectoryExtension:
239
+ return CameraTrajectoryExtension()
ComfyUI/comfy_extras/nodes_canny.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from kornia.filters import canny
2
+ from typing_extensions import override
3
+
4
+ import comfy.model_management
5
+ from comfy_api.latest import ComfyExtension, io
6
+ import torch
7
+
8
+
9
+ class Canny(io.ComfyNode):
10
+ @classmethod
11
+ def define_schema(cls):
12
+ return io.Schema(
13
+ node_id="Canny",
14
+ display_name="Canny",
15
+ search_aliases=["edge detection", "outline", "contour detection", "line art"],
16
+ category="image/preprocessors",
17
+ essentials_category="Image Tools",
18
+ inputs=[
19
+ io.Image.Input("image"),
20
+ io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01),
21
+ io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01),
22
+ ],
23
+ outputs=[io.Image.Output()],
24
+ )
25
+
26
+ @classmethod
27
+ def detect_edge(cls, image, low_threshold, high_threshold):
28
+ # Deprecated: use the V3 schema's `execute` method instead of this.
29
+ return cls.execute(image, low_threshold, high_threshold)
30
+
31
+ @classmethod
32
+ def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
33
+ output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
34
+ img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
35
+ return io.NodeOutput(img_out)
36
+
37
+
38
+ class CannyExtension(ComfyExtension):
39
+ @override
40
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
41
+ return [Canny]
42
+
43
+
44
+ async def comfy_entrypoint() -> CannyExtension:
45
+ return CannyExtension()
ComfyUI/comfy_extras/nodes_cfg.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import override
2
+
3
+ import torch
4
+
5
+ from comfy_api.latest import ComfyExtension, io
6
+
7
+
8
+ # https://github.com/WeichenFan/CFG-Zero-star
9
+ def optimized_scale(positive, negative):
10
+ positive_flat = positive.reshape(positive.shape[0], -1)
11
+ negative_flat = negative.reshape(negative.shape[0], -1)
12
+
13
+ # Calculate dot production
14
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
15
+
16
+ # Squared norm of uncondition
17
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
18
+
19
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
20
+ st_star = dot_product / squared_norm
21
+
22
+ return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
23
+
24
+ class CFGZeroStar(io.ComfyNode):
25
+ @classmethod
26
+ def define_schema(cls) -> io.Schema:
27
+ return io.Schema(
28
+ node_id="CFGZeroStar",
29
+ category="advanced/guidance",
30
+ inputs=[
31
+ io.Model.Input("model"),
32
+ ],
33
+ outputs=[io.Model.Output(display_name="patched_model")],
34
+ )
35
+
36
+ @classmethod
37
+ def execute(cls, model) -> io.NodeOutput:
38
+ m = model.clone()
39
+ def cfg_zero_star(args):
40
+ guidance_scale = args['cond_scale']
41
+ x = args['input']
42
+ cond_p = args['cond_denoised']
43
+ uncond_p = args['uncond_denoised']
44
+ out = args["denoised"]
45
+ alpha = optimized_scale(x - cond_p, x - uncond_p)
46
+
47
+ return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
48
+ m.set_model_sampler_post_cfg_function(cfg_zero_star)
49
+ return io.NodeOutput(m)
50
+
51
+ class CFGNorm(io.ComfyNode):
52
+ @classmethod
53
+ def define_schema(cls) -> io.Schema:
54
+ return io.Schema(
55
+ node_id="CFGNorm",
56
+ category="advanced/guidance",
57
+ inputs=[
58
+ io.Model.Input("model"),
59
+ io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
60
+ ],
61
+ outputs=[io.Model.Output(display_name="patched_model")],
62
+ is_experimental=True,
63
+ )
64
+
65
+ @classmethod
66
+ def execute(cls, model, strength) -> io.NodeOutput:
67
+ m = model.clone()
68
+ def cfg_norm(args):
69
+ cond_p = args['cond_denoised']
70
+ pred_text_ = args["denoised"]
71
+
72
+ norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
73
+ norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
74
+ scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
75
+ return pred_text_ * scale * strength
76
+
77
+ m.set_model_sampler_post_cfg_function(cfg_norm)
78
+ return io.NodeOutput(m)
79
+
80
+
81
+ class CfgExtension(ComfyExtension):
82
+ @override
83
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
84
+ return [
85
+ CFGZeroStar,
86
+ CFGNorm,
87
+ ]
88
+
89
+
90
+ async def comfy_entrypoint() -> CfgExtension:
91
+ return CfgExtension()
ComfyUI/comfy_extras/nodes_chroma_radiance.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import override
2
+ from typing import Callable
3
+
4
+ import torch
5
+
6
+ import comfy.model_management
7
+ from comfy_api.latest import ComfyExtension, io
8
+
9
+ import nodes
10
+
11
+ class EmptyChromaRadianceLatentImage(io.ComfyNode):
12
+ @classmethod
13
+ def define_schema(cls) -> io.Schema:
14
+ return io.Schema(
15
+ node_id="EmptyChromaRadianceLatentImage",
16
+ category="latent/chroma_radiance",
17
+ inputs=[
18
+ io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
19
+ io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
20
+ io.Int.Input(id="batch_size", default=1, min=1, max=4096),
21
+ ],
22
+ outputs=[io.Latent().Output()],
23
+ )
24
+
25
+ @classmethod
26
+ def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput:
27
+ latent = torch.zeros((batch_size, 3, height, width), device=comfy.model_management.intermediate_device())
28
+ return io.NodeOutput({"samples":latent})
29
+
30
+
31
+ class ChromaRadianceOptions(io.ComfyNode):
32
+ @classmethod
33
+ def define_schema(cls) -> io.Schema:
34
+ return io.Schema(
35
+ node_id="ChromaRadianceOptions",
36
+ category="model_patches/chroma_radiance",
37
+ description="Allows setting advanced options for the Chroma Radiance model.",
38
+ inputs=[
39
+ io.Model.Input(id="model"),
40
+ io.Boolean.Input(
41
+ id="preserve_wrapper",
42
+ default=True,
43
+ tooltip="When enabled, will delegate to an existing model function wrapper if it exists. Generally should be left enabled.",
44
+ ),
45
+ io.Float.Input(
46
+ id="start_sigma",
47
+ default=1.0,
48
+ min=0.0,
49
+ max=1.0,
50
+ tooltip="First sigma that these options will be in effect.",
51
+ advanced=True,
52
+ ),
53
+ io.Float.Input(
54
+ id="end_sigma",
55
+ default=0.0,
56
+ min=0.0,
57
+ max=1.0,
58
+ tooltip="Last sigma that these options will be in effect.",
59
+ advanced=True,
60
+ ),
61
+ io.Int.Input(
62
+ id="nerf_tile_size",
63
+ default=-1,
64
+ min=-1,
65
+ tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
66
+ advanced=True,
67
+ ),
68
+ ],
69
+ outputs=[io.Model.Output()],
70
+ )
71
+
72
+ @classmethod
73
+ def execute(
74
+ cls,
75
+ *,
76
+ model: io.Model.Type,
77
+ preserve_wrapper: bool,
78
+ start_sigma: float,
79
+ end_sigma: float,
80
+ nerf_tile_size: int,
81
+ ) -> io.NodeOutput:
82
+ radiance_options = {}
83
+ if nerf_tile_size >= 0:
84
+ radiance_options["nerf_tile_size"] = nerf_tile_size
85
+
86
+ if not radiance_options:
87
+ return io.NodeOutput(model)
88
+
89
+ old_wrapper = model.model_options.get("model_function_wrapper")
90
+
91
+ def model_function_wrapper(apply_model: Callable, args: dict) -> torch.Tensor:
92
+ c = args["c"].copy()
93
+ sigma = args["timestep"].max().detach().cpu().item()
94
+ if end_sigma <= sigma <= start_sigma:
95
+ transformer_options = c.get("transformer_options", {}).copy()
96
+ transformer_options["chroma_radiance_options"] = radiance_options.copy()
97
+ c["transformer_options"] = transformer_options
98
+ if not (preserve_wrapper and old_wrapper):
99
+ return apply_model(args["input"], args["timestep"], **c)
100
+ return old_wrapper(apply_model, args | {"c": c})
101
+
102
+ model = model.clone()
103
+ model.set_model_unet_function_wrapper(model_function_wrapper)
104
+ return io.NodeOutput(model)
105
+
106
+
107
+ class ChromaRadianceExtension(ComfyExtension):
108
+ @override
109
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
110
+ return [
111
+ EmptyChromaRadianceLatentImage,
112
+ ChromaRadianceOptions,
113
+ ]
114
+
115
+
116
+ async def comfy_entrypoint() -> ChromaRadianceExtension:
117
+ return ChromaRadianceExtension()
ComfyUI/comfy_extras/nodes_clip_sdxl.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import override
2
+
3
+ import nodes
4
+ from comfy_api.latest import ComfyExtension, io
5
+
6
+
7
+ class CLIPTextEncodeSDXLRefiner(io.ComfyNode):
8
+ @classmethod
9
+ def define_schema(cls):
10
+ return io.Schema(
11
+ node_id="CLIPTextEncodeSDXLRefiner",
12
+ category="advanced/conditioning",
13
+ inputs=[
14
+ io.Float.Input("ascore", default=6.0, min=0.0, max=1000.0, step=0.01),
15
+ io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
16
+ io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
17
+ io.String.Input("text", multiline=True, dynamic_prompts=True),
18
+ io.Clip.Input("clip"),
19
+ ],
20
+ outputs=[io.Conditioning.Output()],
21
+ )
22
+
23
+ @classmethod
24
+ def execute(cls, clip, ascore, width, height, text) -> io.NodeOutput:
25
+ tokens = clip.tokenize(text)
26
+ return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}))
27
+
28
+ class CLIPTextEncodeSDXL(io.ComfyNode):
29
+ @classmethod
30
+ def define_schema(cls):
31
+ return io.Schema(
32
+ node_id="CLIPTextEncodeSDXL",
33
+ category="advanced/conditioning",
34
+ inputs=[
35
+ io.Clip.Input("clip"),
36
+ io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
37
+ io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
38
+ io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True),
39
+ io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True),
40
+ io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
41
+ io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
42
+ io.String.Input("text_g", multiline=True, dynamic_prompts=True),
43
+ io.String.Input("text_l", multiline=True, dynamic_prompts=True),
44
+ ],
45
+ outputs=[io.Conditioning.Output()],
46
+ )
47
+
48
+ @classmethod
49
+ def execute(cls, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l) -> io.NodeOutput:
50
+ tokens = clip.tokenize(text_g)
51
+ tokens["l"] = clip.tokenize(text_l)["l"]
52
+ if len(tokens["l"]) != len(tokens["g"]):
53
+ empty = clip.tokenize("")
54
+ while len(tokens["l"]) < len(tokens["g"]):
55
+ tokens["l"] += empty["l"]
56
+ while len(tokens["l"]) > len(tokens["g"]):
57
+ tokens["g"] += empty["g"]
58
+ return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}))
59
+
60
+
61
+ class ClipSdxlExtension(ComfyExtension):
62
+ @override
63
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
64
+ return [
65
+ CLIPTextEncodeSDXLRefiner,
66
+ CLIPTextEncodeSDXL,
67
+ ]
68
+
69
+
70
+ async def comfy_entrypoint() -> ClipSdxlExtension:
71
+ return ClipSdxlExtension()
ComfyUI/comfy_extras/nodes_color.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import override
2
+ from comfy_api.latest import ComfyExtension, io
3
+
4
+
5
+ class ColorToRGBInt(io.ComfyNode):
6
+ @classmethod
7
+ def define_schema(cls) -> io.Schema:
8
+ return io.Schema(
9
+ node_id="ColorToRGBInt",
10
+ display_name="Color to RGB Int",
11
+ category="utils",
12
+ description="Convert a color to a RGB integer value.",
13
+ inputs=[
14
+ io.Color.Input("color"),
15
+ ],
16
+ outputs=[
17
+ io.Int.Output(display_name="rgb_int"),
18
+ ],
19
+ )
20
+
21
+ @classmethod
22
+ def execute(
23
+ cls,
24
+ color: str,
25
+ ) -> io.NodeOutput:
26
+ # expect format #RRGGBB
27
+ if len(color) != 7 or color[0] != "#":
28
+ raise ValueError("Color must be in format #RRGGBB")
29
+ r = int(color[1:3], 16)
30
+ g = int(color[3:5], 16)
31
+ b = int(color[5:7], 16)
32
+ return io.NodeOutput(r * 256 * 256 + g * 256 + b)
33
+
34
+
35
+ class ColorExtension(ComfyExtension):
36
+ @override
37
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
38
+ return [ColorToRGBInt]
39
+
40
+
41
+ async def comfy_entrypoint() -> ColorExtension:
42
+ return ColorExtension()
ComfyUI/comfy_extras/nodes_compositing.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.utils
3
+ from enum import Enum
4
+ from typing_extensions import override
5
+ from comfy_api.latest import ComfyExtension, io
6
+
7
+
8
+ def resize_mask(mask, shape):
9
+ return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
10
+
11
+ class PorterDuffMode(Enum):
12
+ ADD = 0
13
+ CLEAR = 1
14
+ DARKEN = 2
15
+ DST = 3
16
+ DST_ATOP = 4
17
+ DST_IN = 5
18
+ DST_OUT = 6
19
+ DST_OVER = 7
20
+ LIGHTEN = 8
21
+ MULTIPLY = 9
22
+ OVERLAY = 10
23
+ SCREEN = 11
24
+ SRC = 12
25
+ SRC_ATOP = 13
26
+ SRC_IN = 14
27
+ SRC_OUT = 15
28
+ SRC_OVER = 16
29
+ XOR = 17
30
+
31
+
32
+ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
33
+ # convert mask to alpha
34
+ src_alpha = 1 - src_alpha
35
+ dst_alpha = 1 - dst_alpha
36
+ # premultiply alpha
37
+ src_image = src_image * src_alpha
38
+ dst_image = dst_image * dst_alpha
39
+
40
+ # composite ops below assume alpha-premultiplied images
41
+ if mode == PorterDuffMode.ADD:
42
+ out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
43
+ out_image = torch.clamp(src_image + dst_image, 0, 1)
44
+ elif mode == PorterDuffMode.CLEAR:
45
+ out_alpha = torch.zeros_like(dst_alpha)
46
+ out_image = torch.zeros_like(dst_image)
47
+ elif mode == PorterDuffMode.DARKEN:
48
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
49
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
50
+ elif mode == PorterDuffMode.DST:
51
+ out_alpha = dst_alpha
52
+ out_image = dst_image
53
+ elif mode == PorterDuffMode.DST_ATOP:
54
+ out_alpha = src_alpha
55
+ out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image
56
+ elif mode == PorterDuffMode.DST_IN:
57
+ out_alpha = src_alpha * dst_alpha
58
+ out_image = dst_image * src_alpha
59
+ elif mode == PorterDuffMode.DST_OUT:
60
+ out_alpha = (1 - src_alpha) * dst_alpha
61
+ out_image = (1 - src_alpha) * dst_image
62
+ elif mode == PorterDuffMode.DST_OVER:
63
+ out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha
64
+ out_image = dst_image + (1 - dst_alpha) * src_image
65
+ elif mode == PorterDuffMode.LIGHTEN:
66
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
67
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image)
68
+ elif mode == PorterDuffMode.MULTIPLY:
69
+ out_alpha = src_alpha * dst_alpha
70
+ out_image = src_image * dst_image
71
+ elif mode == PorterDuffMode.OVERLAY:
72
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
73
+ out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
74
+ src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
75
+ elif mode == PorterDuffMode.SCREEN:
76
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
77
+ out_image = src_image + dst_image - src_image * dst_image
78
+ elif mode == PorterDuffMode.SRC:
79
+ out_alpha = src_alpha
80
+ out_image = src_image
81
+ elif mode == PorterDuffMode.SRC_ATOP:
82
+ out_alpha = dst_alpha
83
+ out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image
84
+ elif mode == PorterDuffMode.SRC_IN:
85
+ out_alpha = src_alpha * dst_alpha
86
+ out_image = src_image * dst_alpha
87
+ elif mode == PorterDuffMode.SRC_OUT:
88
+ out_alpha = (1 - dst_alpha) * src_alpha
89
+ out_image = (1 - dst_alpha) * src_image
90
+ elif mode == PorterDuffMode.SRC_OVER:
91
+ out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
92
+ out_image = src_image + (1 - src_alpha) * dst_image
93
+ elif mode == PorterDuffMode.XOR:
94
+ out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
95
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
96
+ else:
97
+ return None, None
98
+
99
+ # back to non-premultiplied alpha
100
+ out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image))
101
+ out_image = torch.clamp(out_image, 0, 1)
102
+ # convert alpha to mask
103
+ out_alpha = 1 - out_alpha
104
+ return out_image, out_alpha
105
+
106
+
107
+ class PorterDuffImageComposite(io.ComfyNode):
108
+ @classmethod
109
+ def define_schema(cls):
110
+ return io.Schema(
111
+ node_id="PorterDuffImageComposite",
112
+ search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"],
113
+ display_name="Porter-Duff Image Composite",
114
+ category="mask/compositing",
115
+ inputs=[
116
+ io.Image.Input("source"),
117
+ io.Mask.Input("source_alpha"),
118
+ io.Image.Input("destination"),
119
+ io.Mask.Input("destination_alpha"),
120
+ io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
121
+ ],
122
+ outputs=[
123
+ io.Image.Output(),
124
+ io.Mask.Output(),
125
+ ],
126
+ )
127
+
128
+ @classmethod
129
+ def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
130
+ batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
131
+ out_images = []
132
+ out_alphas = []
133
+
134
+ for i in range(batch_size):
135
+ src_image = source[i]
136
+ dst_image = destination[i]
137
+
138
+ assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
139
+
140
+ src_alpha = source_alpha[i].unsqueeze(2)
141
+ dst_alpha = destination_alpha[i].unsqueeze(2)
142
+
143
+ if dst_alpha.shape[:2] != dst_image.shape[:2]:
144
+ upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2)
145
+ upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
146
+ dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
147
+ if src_image.shape != dst_image.shape:
148
+ upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2)
149
+ upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
150
+ src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
151
+ if src_alpha.shape != dst_alpha.shape:
152
+ upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2)
153
+ upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
154
+ src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
155
+
156
+ out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
157
+
158
+ out_images.append(out_image)
159
+ out_alphas.append(out_alpha.squeeze(2))
160
+
161
+ return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
162
+
163
+
164
+ class SplitImageWithAlpha(io.ComfyNode):
165
+ @classmethod
166
+ def define_schema(cls):
167
+ return io.Schema(
168
+ node_id="SplitImageWithAlpha",
169
+ search_aliases=["extract alpha", "separate transparency", "remove alpha"],
170
+ display_name="Split Image with Alpha",
171
+ category="mask/compositing",
172
+ inputs=[
173
+ io.Image.Input("image"),
174
+ ],
175
+ outputs=[
176
+ io.Image.Output(),
177
+ io.Mask.Output(),
178
+ ],
179
+ )
180
+
181
+ @classmethod
182
+ def execute(cls, image: torch.Tensor) -> io.NodeOutput:
183
+ out_images = [i[:,:,:3] for i in image]
184
+ out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
185
+ return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
186
+
187
+
188
+ class JoinImageWithAlpha(io.ComfyNode):
189
+ @classmethod
190
+ def define_schema(cls):
191
+ return io.Schema(
192
+ node_id="JoinImageWithAlpha",
193
+ search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"],
194
+ display_name="Join Image with Alpha",
195
+ category="mask/compositing",
196
+ inputs=[
197
+ io.Image.Input("image"),
198
+ io.Mask.Input("alpha"),
199
+ ],
200
+ outputs=[io.Image.Output()],
201
+ )
202
+
203
+ @classmethod
204
+ def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
205
+ batch_size = min(len(image), len(alpha))
206
+ out_images = []
207
+
208
+ alpha = 1.0 - resize_mask(alpha, image.shape[1:])
209
+ for i in range(batch_size):
210
+ out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
211
+
212
+ return io.NodeOutput(torch.stack(out_images))
213
+
214
+
215
+ class CompositingExtension(ComfyExtension):
216
+ @override
217
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
218
+ return [
219
+ PorterDuffImageComposite,
220
+ SplitImageWithAlpha,
221
+ JoinImageWithAlpha,
222
+ ]
223
+
224
+
225
+ async def comfy_entrypoint() -> CompositingExtension:
226
+ return CompositingExtension()
ComfyUI/comfy_extras/nodes_cond.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import override
2
+
3
+ from comfy_api.latest import ComfyExtension, io
4
+
5
+
6
+ class CLIPTextEncodeControlnet(io.ComfyNode):
7
+ @classmethod
8
+ def define_schema(cls) -> io.Schema:
9
+ return io.Schema(
10
+ node_id="CLIPTextEncodeControlnet",
11
+ category="_for_testing/conditioning",
12
+ inputs=[
13
+ io.Clip.Input("clip"),
14
+ io.Conditioning.Input("conditioning"),
15
+ io.String.Input("text", multiline=True, dynamic_prompts=True),
16
+ ],
17
+ outputs=[io.Conditioning.Output()],
18
+ is_experimental=True,
19
+ )
20
+
21
+ @classmethod
22
+ def execute(cls, clip, conditioning, text) -> io.NodeOutput:
23
+ tokens = clip.tokenize(text)
24
+ cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
25
+ c = []
26
+ for t in conditioning:
27
+ n = [t[0], t[1].copy()]
28
+ n[1]['cross_attn_controlnet'] = cond
29
+ n[1]['pooled_output_controlnet'] = pooled
30
+ c.append(n)
31
+ return io.NodeOutput(c)
32
+
33
+ class T5TokenizerOptions(io.ComfyNode):
34
+ @classmethod
35
+ def define_schema(cls) -> io.Schema:
36
+ return io.Schema(
37
+ node_id="T5TokenizerOptions",
38
+ category="_for_testing/conditioning",
39
+ inputs=[
40
+ io.Clip.Input("clip"),
41
+ io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
42
+ io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True),
43
+ ],
44
+ outputs=[io.Clip.Output()],
45
+ is_experimental=True,
46
+ )
47
+
48
+ @classmethod
49
+ def execute(cls, clip, min_padding, min_length) -> io.NodeOutput:
50
+ clip = clip.clone()
51
+ for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
52
+ clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
53
+ clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
54
+
55
+ return io.NodeOutput(clip)
56
+
57
+
58
+ class CondExtension(ComfyExtension):
59
+ @override
60
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
61
+ return [
62
+ CLIPTextEncodeControlnet,
63
+ T5TokenizerOptions,
64
+ ]
65
+
66
+
67
+ async def comfy_entrypoint() -> CondExtension:
68
+ return CondExtension()
ComfyUI/comfy_extras/nodes_context_windows.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from comfy_api.latest import ComfyExtension, io
3
+ import comfy.context_windows
4
+ import nodes
5
+
6
+
7
+ class ContextWindowsManualNode(io.ComfyNode):
8
+ @classmethod
9
+ def define_schema(cls) -> io.Schema:
10
+ return io.Schema(
11
+ node_id="ContextWindowsManual",
12
+ display_name="Context Windows (Manual)",
13
+ category="context",
14
+ description="Manually set context windows.",
15
+ inputs=[
16
+ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
17
+ io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True),
18
+ io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True),
19
+ io.Combo.Input("context_schedule", options=[
20
+ comfy.context_windows.ContextSchedules.STATIC_STANDARD,
21
+ comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
22
+ comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
23
+ comfy.context_windows.ContextSchedules.BATCHED,
24
+ ], tooltip="The stride of the context window."),
25
+ io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
26
+ io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
27
+ io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
28
+ io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
29
+ io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
30
+ io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
31
+ io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
32
+ ],
33
+ outputs=[
34
+ io.Model.Output(tooltip="The model with context windows applied during sampling."),
35
+ ],
36
+ is_experimental=True,
37
+ )
38
+
39
+ @classmethod
40
+ def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
41
+ cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
42
+ model = model.clone()
43
+ model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
44
+ context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
45
+ fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method),
46
+ context_length=context_length,
47
+ context_overlap=context_overlap,
48
+ context_stride=context_stride,
49
+ closed_loop=closed_loop,
50
+ dim=dim,
51
+ freenoise=freenoise,
52
+ cond_retain_index_list=cond_retain_index_list,
53
+ split_conds_to_windows=split_conds_to_windows
54
+ )
55
+ # make memory usage calculation only take into account the context window latents
56
+ comfy.context_windows.create_prepare_sampling_wrapper(model)
57
+ if freenoise: # no other use for this wrapper at this time
58
+ comfy.context_windows.create_sampler_sample_wrapper(model)
59
+ return io.NodeOutput(model)
60
+
61
+ class WanContextWindowsManualNode(ContextWindowsManualNode):
62
+ @classmethod
63
+ def define_schema(cls) -> io.Schema:
64
+ schema = super().define_schema()
65
+ schema.node_id = "WanContextWindowsManual"
66
+ schema.display_name = "WAN Context Windows (Manual)"
67
+ schema.description = "Manually set context windows for WAN-like models (dim=2)."
68
+ schema.inputs = [
69
+ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
70
+ io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True),
71
+ io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True),
72
+ io.Combo.Input("context_schedule", options=[
73
+ comfy.context_windows.ContextSchedules.STATIC_STANDARD,
74
+ comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
75
+ comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
76
+ comfy.context_windows.ContextSchedules.BATCHED,
77
+ ], tooltip="The stride of the context window."),
78
+ io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
79
+ io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
80
+ io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
81
+ io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
82
+ #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
83
+ #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
84
+ ]
85
+ return schema
86
+
87
+ @classmethod
88
+ def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
89
+ cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
90
+ context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
91
+ context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
92
+ return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
93
+
94
+
95
+ class ContextWindowsExtension(ComfyExtension):
96
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
97
+ return [
98
+ ContextWindowsManualNode,
99
+ WanContextWindowsManualNode,
100
+ ]
101
+
102
+ def comfy_entrypoint():
103
+ return ContextWindowsExtension()
ComfyUI/comfy_extras/nodes_controlnet.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
2
+ import nodes
3
+ import comfy.utils
4
+ from typing_extensions import override
5
+ from comfy_api.latest import ComfyExtension, io
6
+
7
+ class SetUnionControlNetType(io.ComfyNode):
8
+ @classmethod
9
+ def define_schema(cls):
10
+ return io.Schema(
11
+ node_id="SetUnionControlNetType",
12
+ category="conditioning/controlnet",
13
+ inputs=[
14
+ io.ControlNet.Input("control_net"),
15
+ io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
16
+ ],
17
+ outputs=[
18
+ io.ControlNet.Output(),
19
+ ],
20
+ )
21
+
22
+ @classmethod
23
+ def execute(cls, control_net, type) -> io.NodeOutput:
24
+ control_net = control_net.copy()
25
+ type_number = UNION_CONTROLNET_TYPES.get(type, -1)
26
+ if type_number >= 0:
27
+ control_net.set_extra_arg("control_type", [type_number])
28
+ else:
29
+ control_net.set_extra_arg("control_type", [])
30
+
31
+ return io.NodeOutput(control_net)
32
+
33
+ set_controlnet_type = execute # TODO: remove
34
+
35
+
36
+ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
37
+ @classmethod
38
+ def define_schema(cls):
39
+ return io.Schema(
40
+ node_id="ControlNetInpaintingAliMamaApply",
41
+ search_aliases=["masked controlnet"],
42
+ category="conditioning/controlnet",
43
+ inputs=[
44
+ io.Conditioning.Input("positive"),
45
+ io.Conditioning.Input("negative"),
46
+ io.ControlNet.Input("control_net"),
47
+ io.Vae.Input("vae"),
48
+ io.Image.Input("image"),
49
+ io.Mask.Input("mask"),
50
+ io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
51
+ io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True),
52
+ io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True),
53
+ ],
54
+ outputs=[
55
+ io.Conditioning.Output(display_name="positive"),
56
+ io.Conditioning.Output(display_name="negative"),
57
+ ],
58
+ )
59
+
60
+ @classmethod
61
+ def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
62
+ extra_concat = []
63
+ if control_net.concat_mask:
64
+ mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
65
+ mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
66
+ image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
67
+ extra_concat = [mask]
68
+
69
+ result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
70
+ return io.NodeOutput(result[0], result[1])
71
+
72
+ apply_inpaint_controlnet = execute # TODO: remove
73
+
74
+
75
+ class ControlNetExtension(ComfyExtension):
76
+ @override
77
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
78
+ return [
79
+ SetUnionControlNetType,
80
+ ControlNetInpaintingAliMamaApply,
81
+ ]
82
+
83
+
84
+ async def comfy_entrypoint() -> ControlNetExtension:
85
+ return ControlNetExtension()
ComfyUI/comfy_extras/nodes_cosmos.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import override
2
+ import nodes
3
+ import torch
4
+ import comfy.model_management
5
+ import comfy.utils
6
+ import comfy.latent_formats
7
+
8
+ from comfy_api.latest import ComfyExtension, io
9
+
10
+
11
+ class EmptyCosmosLatentVideo(io.ComfyNode):
12
+ @classmethod
13
+ def define_schema(cls) -> io.Schema:
14
+ return io.Schema(
15
+ node_id="EmptyCosmosLatentVideo",
16
+ category="latent/video",
17
+ inputs=[
18
+ io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
19
+ io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
20
+ io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
21
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
22
+ ],
23
+ outputs=[io.Latent.Output()],
24
+ )
25
+
26
+ @classmethod
27
+ def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
28
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
29
+ return io.NodeOutput({"samples": latent})
30
+
31
+
32
+ def vae_encode_with_padding(vae, image, width, height, length, padding=0):
33
+ pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
34
+ pixel_len = min(pixels.shape[0], length)
35
+ padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
36
+ padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
37
+ padded_pixels[:pixel_len] = pixels[:pixel_len]
38
+ latent_len = ((pixel_len - 1) // 8) + 1
39
+ latent_temp = vae.encode(padded_pixels)
40
+ return latent_temp[:, :, :latent_len]
41
+
42
+
43
+ class CosmosImageToVideoLatent(io.ComfyNode):
44
+ @classmethod
45
+ def define_schema(cls) -> io.Schema:
46
+ return io.Schema(
47
+ node_id="CosmosImageToVideoLatent",
48
+ category="conditioning/inpaint",
49
+ inputs=[
50
+ io.Vae.Input("vae"),
51
+ io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
52
+ io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
53
+ io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
54
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
55
+ io.Image.Input("start_image", optional=True),
56
+ io.Image.Input("end_image", optional=True),
57
+ ],
58
+ outputs=[io.Latent.Output()],
59
+ )
60
+
61
+ @classmethod
62
+ def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
63
+ latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
64
+ if start_image is None and end_image is None:
65
+ out_latent = {}
66
+ out_latent["samples"] = latent
67
+ return io.NodeOutput(out_latent)
68
+
69
+ mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
70
+
71
+ if start_image is not None:
72
+ latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
73
+ latent[:, :, :latent_temp.shape[-3]] = latent_temp
74
+ mask[:, :, :latent_temp.shape[-3]] *= 0.0
75
+
76
+ if end_image is not None:
77
+ latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
78
+ latent[:, :, -latent_temp.shape[-3]:] = latent_temp
79
+ mask[:, :, -latent_temp.shape[-3]:] *= 0.0
80
+
81
+ out_latent = {}
82
+ out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
83
+ out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
84
+ return io.NodeOutput(out_latent)
85
+
86
+ class CosmosPredict2ImageToVideoLatent(io.ComfyNode):
87
+ @classmethod
88
+ def define_schema(cls) -> io.Schema:
89
+ return io.Schema(
90
+ node_id="CosmosPredict2ImageToVideoLatent",
91
+ category="conditioning/inpaint",
92
+ inputs=[
93
+ io.Vae.Input("vae"),
94
+ io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
95
+ io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
96
+ io.Int.Input("length", default=93, min=1, max=nodes.MAX_RESOLUTION, step=4),
97
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
98
+ io.Image.Input("start_image", optional=True),
99
+ io.Image.Input("end_image", optional=True),
100
+ ],
101
+ outputs=[io.Latent.Output()],
102
+ )
103
+
104
+ @classmethod
105
+ def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
106
+ latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
107
+ if start_image is None and end_image is None:
108
+ out_latent = {}
109
+ out_latent["samples"] = latent
110
+ return io.NodeOutput(out_latent)
111
+
112
+ mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
113
+
114
+ if start_image is not None:
115
+ latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
116
+ latent[:, :, :latent_temp.shape[-3]] = latent_temp
117
+ mask[:, :, :latent_temp.shape[-3]] *= 0.0
118
+
119
+ if end_image is not None:
120
+ latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
121
+ latent[:, :, -latent_temp.shape[-3]:] = latent_temp
122
+ mask[:, :, -latent_temp.shape[-3]:] *= 0.0
123
+
124
+ out_latent = {}
125
+ latent_format = comfy.latent_formats.Wan21()
126
+ latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
127
+ out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
128
+ out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
129
+ return io.NodeOutput(out_latent)
130
+
131
+
132
+ class CosmosExtension(ComfyExtension):
133
+ @override
134
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
135
+ return [
136
+ EmptyCosmosLatentVideo,
137
+ CosmosImageToVideoLatent,
138
+ CosmosPredict2ImageToVideoLatent,
139
+ ]
140
+
141
+
142
+ async def comfy_entrypoint() -> CosmosExtension:
143
+ return CosmosExtension()
ComfyUI/comfy_extras/nodes_curve.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from comfy_api.latest import ComfyExtension, io
6
+ from comfy_api.input import CurveInput
7
+ from typing_extensions import override
8
+
9
+
10
+ class CurveEditor(io.ComfyNode):
11
+ @classmethod
12
+ def define_schema(cls):
13
+ return io.Schema(
14
+ node_id="CurveEditor",
15
+ display_name="Curve Editor",
16
+ category="utils",
17
+ inputs=[
18
+ io.Curve.Input("curve"),
19
+ io.Histogram.Input("histogram", optional=True),
20
+ ],
21
+ outputs=[
22
+ io.Curve.Output("curve"),
23
+ ],
24
+ )
25
+
26
+ @classmethod
27
+ def execute(cls, curve, histogram=None) -> io.NodeOutput:
28
+ result = CurveInput.from_raw(curve)
29
+
30
+ ui = {}
31
+ if histogram is not None:
32
+ ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
33
+
34
+ return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
35
+
36
+
37
+ class ImageHistogram(io.ComfyNode):
38
+ @classmethod
39
+ def define_schema(cls):
40
+ return io.Schema(
41
+ node_id="ImageHistogram",
42
+ display_name="Image Histogram",
43
+ category="utils",
44
+ inputs=[
45
+ io.Image.Input("image"),
46
+ ],
47
+ outputs=[
48
+ io.Histogram.Output("rgb"),
49
+ io.Histogram.Output("luminance"),
50
+ io.Histogram.Output("red"),
51
+ io.Histogram.Output("green"),
52
+ io.Histogram.Output("blue"),
53
+ ],
54
+ )
55
+
56
+ @classmethod
57
+ def execute(cls, image) -> io.NodeOutput:
58
+ img = image[0].cpu().numpy()
59
+ img_uint8 = np.clip(img * 255, 0, 255).astype(np.uint8)
60
+
61
+ def bincount(data):
62
+ return np.bincount(data.ravel(), minlength=256)[:256]
63
+
64
+ hist_r = bincount(img_uint8[:, :, 0])
65
+ hist_g = bincount(img_uint8[:, :, 1])
66
+ hist_b = bincount(img_uint8[:, :, 2])
67
+
68
+ # Average of R, G, B histograms (same as Photoshop's RGB composite)
69
+ rgb = ((hist_r + hist_g + hist_b) // 3).tolist()
70
+
71
+ # ITU-R BT.709-6, Item 3.2 (p.6) — Derivation of luminance signal
72
+ # https://www.itu.int/rec/R-REC-BT.709-6-201506-I/en
73
+ lum = 0.2126 * img[:, :, 0] + 0.7152 * img[:, :, 1] + 0.0722 * img[:, :, 2]
74
+ luminance = bincount(np.clip(lum * 255, 0, 255).astype(np.uint8)).tolist()
75
+
76
+ return io.NodeOutput(
77
+ rgb,
78
+ luminance,
79
+ hist_r.tolist(),
80
+ hist_g.tolist(),
81
+ hist_b.tolist(),
82
+ )
83
+
84
+
85
+ class CurveExtension(ComfyExtension):
86
+ @override
87
+ async def get_node_list(self):
88
+ return [CurveEditor, ImageHistogram]
89
+
90
+
91
+ async def comfy_entrypoint():
92
+ return CurveExtension()
ComfyUI/comfy_extras/nodes_custom_sampler.py ADDED
@@ -0,0 +1,1095 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import comfy.samplers
3
+ import comfy.sample
4
+ from comfy.k_diffusion import sampling as k_diffusion_sampling
5
+ from comfy.k_diffusion import sa_solver
6
+ import latent_preview
7
+ import torch
8
+ import comfy.utils
9
+ import node_helpers
10
+ from typing_extensions import override
11
+ from comfy_api.latest import ComfyExtension, io
12
+ import re
13
+
14
+
15
+ class BasicScheduler(io.ComfyNode):
16
+ @classmethod
17
+ def define_schema(cls):
18
+ return io.Schema(
19
+ node_id="BasicScheduler",
20
+ category="sampling/custom_sampling/schedulers",
21
+ inputs=[
22
+ io.Model.Input("model"),
23
+ io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES),
24
+ io.Int.Input("steps", default=20, min=1, max=10000),
25
+ io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
26
+ ],
27
+ outputs=[io.Sigmas.Output()]
28
+ )
29
+
30
+ @classmethod
31
+ def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput:
32
+ total_steps = steps
33
+ if denoise < 1.0:
34
+ if denoise <= 0.0:
35
+ return io.NodeOutput(torch.FloatTensor([]))
36
+ total_steps = int(steps/denoise)
37
+
38
+ sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu()
39
+ sigmas = sigmas[-(steps + 1):]
40
+ return io.NodeOutput(sigmas)
41
+
42
+ get_sigmas = execute
43
+
44
+
45
+ class KarrasScheduler(io.ComfyNode):
46
+ @classmethod
47
+ def define_schema(cls):
48
+ return io.Schema(
49
+ node_id="KarrasScheduler",
50
+ category="sampling/custom_sampling/schedulers",
51
+ inputs=[
52
+ io.Int.Input("steps", default=20, min=1, max=10000),
53
+ io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
54
+ io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
55
+ io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
56
+ ],
57
+ outputs=[io.Sigmas.Output()]
58
+ )
59
+
60
+ @classmethod
61
+ def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput:
62
+ sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
63
+ return io.NodeOutput(sigmas)
64
+
65
+ get_sigmas = execute
66
+
67
+ class ExponentialScheduler(io.ComfyNode):
68
+ @classmethod
69
+ def define_schema(cls):
70
+ return io.Schema(
71
+ node_id="ExponentialScheduler",
72
+ category="sampling/custom_sampling/schedulers",
73
+ inputs=[
74
+ io.Int.Input("steps", default=20, min=1, max=10000),
75
+ io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
76
+ io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
77
+ ],
78
+ outputs=[io.Sigmas.Output()]
79
+ )
80
+
81
+ @classmethod
82
+ def execute(cls, steps, sigma_max, sigma_min) -> io.NodeOutput:
83
+ sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max)
84
+ return io.NodeOutput(sigmas)
85
+
86
+ get_sigmas = execute
87
+
88
+ class PolyexponentialScheduler(io.ComfyNode):
89
+ @classmethod
90
+ def define_schema(cls):
91
+ return io.Schema(
92
+ node_id="PolyexponentialScheduler",
93
+ category="sampling/custom_sampling/schedulers",
94
+ inputs=[
95
+ io.Int.Input("steps", default=20, min=1, max=10000),
96
+ io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
97
+ io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
98
+ io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
99
+ ],
100
+ outputs=[io.Sigmas.Output()]
101
+ )
102
+
103
+ @classmethod
104
+ def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput:
105
+ sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
106
+ return io.NodeOutput(sigmas)
107
+
108
+ get_sigmas = execute
109
+
110
+ class LaplaceScheduler(io.ComfyNode):
111
+ @classmethod
112
+ def define_schema(cls):
113
+ return io.Schema(
114
+ node_id="LaplaceScheduler",
115
+ category="sampling/custom_sampling/schedulers",
116
+ inputs=[
117
+ io.Int.Input("steps", default=20, min=1, max=10000),
118
+ io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
119
+ io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
120
+ io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False, advanced=True),
121
+ io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False, advanced=True),
122
+ ],
123
+ outputs=[io.Sigmas.Output()]
124
+ )
125
+
126
+ @classmethod
127
+ def execute(cls, steps, sigma_max, sigma_min, mu, beta) -> io.NodeOutput:
128
+ sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
129
+ return io.NodeOutput(sigmas)
130
+
131
+ get_sigmas = execute
132
+
133
+
134
+ class SDTurboScheduler(io.ComfyNode):
135
+ @classmethod
136
+ def define_schema(cls):
137
+ return io.Schema(
138
+ node_id="SDTurboScheduler",
139
+ category="sampling/custom_sampling/schedulers",
140
+ inputs=[
141
+ io.Model.Input("model"),
142
+ io.Int.Input("steps", default=1, min=1, max=10),
143
+ io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01),
144
+ ],
145
+ outputs=[io.Sigmas.Output()]
146
+ )
147
+
148
+ @classmethod
149
+ def execute(cls, model, steps, denoise) -> io.NodeOutput:
150
+ start_step = 10 - int(10 * denoise)
151
+ timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
152
+ sigmas = model.get_model_object("model_sampling").sigma(timesteps)
153
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
154
+ return io.NodeOutput(sigmas)
155
+
156
+ get_sigmas = execute
157
+
158
+ class BetaSamplingScheduler(io.ComfyNode):
159
+ @classmethod
160
+ def define_schema(cls):
161
+ return io.Schema(
162
+ node_id="BetaSamplingScheduler",
163
+ category="sampling/custom_sampling/schedulers",
164
+ inputs=[
165
+ io.Model.Input("model"),
166
+ io.Int.Input("steps", default=20, min=1, max=10000),
167
+ io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True),
168
+ io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True),
169
+ ],
170
+ outputs=[io.Sigmas.Output()]
171
+ )
172
+
173
+ @classmethod
174
+ def execute(cls, model, steps, alpha, beta) -> io.NodeOutput:
175
+ sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta)
176
+ return io.NodeOutput(sigmas)
177
+
178
+ get_sigmas = execute
179
+
180
+ class VPScheduler(io.ComfyNode):
181
+ @classmethod
182
+ def define_schema(cls):
183
+ return io.Schema(
184
+ node_id="VPScheduler",
185
+ category="sampling/custom_sampling/schedulers",
186
+ inputs=[
187
+ io.Int.Input("steps", default=20, min=1, max=10000),
188
+ io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), #TODO: fix default values
189
+ io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
190
+ io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False, advanced=True),
191
+ ],
192
+ outputs=[io.Sigmas.Output()]
193
+ )
194
+
195
+ @classmethod
196
+ def execute(cls, steps, beta_d, beta_min, eps_s) -> io.NodeOutput:
197
+ sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
198
+ return io.NodeOutput(sigmas)
199
+
200
+ get_sigmas = execute
201
+
202
+ class SplitSigmas(io.ComfyNode):
203
+ @classmethod
204
+ def define_schema(cls):
205
+ return io.Schema(
206
+ node_id="SplitSigmas",
207
+ category="sampling/custom_sampling/sigmas",
208
+ inputs=[
209
+ io.Sigmas.Input("sigmas"),
210
+ io.Int.Input("step", default=0, min=0, max=10000),
211
+ ],
212
+ outputs=[
213
+ io.Sigmas.Output(display_name="high_sigmas"),
214
+ io.Sigmas.Output(display_name="low_sigmas"),
215
+ ]
216
+ )
217
+
218
+ @classmethod
219
+ def execute(cls, sigmas, step) -> io.NodeOutput:
220
+ sigmas1 = sigmas[:step + 1]
221
+ sigmas2 = sigmas[step:]
222
+ return io.NodeOutput(sigmas1, sigmas2)
223
+
224
+ get_sigmas = execute
225
+
226
+ class SplitSigmasDenoise(io.ComfyNode):
227
+ @classmethod
228
+ def define_schema(cls):
229
+ return io.Schema(
230
+ node_id="SplitSigmasDenoise",
231
+ category="sampling/custom_sampling/sigmas",
232
+ inputs=[
233
+ io.Sigmas.Input("sigmas"),
234
+ io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
235
+ ],
236
+ outputs=[
237
+ io.Sigmas.Output(display_name="high_sigmas"),
238
+ io.Sigmas.Output(display_name="low_sigmas"),
239
+ ]
240
+ )
241
+
242
+ @classmethod
243
+ def execute(cls, sigmas, denoise) -> io.NodeOutput:
244
+ steps = max(sigmas.shape[-1] - 1, 0)
245
+ total_steps = round(steps * denoise)
246
+ sigmas1 = sigmas[:-(total_steps)]
247
+ sigmas2 = sigmas[-(total_steps + 1):]
248
+ return io.NodeOutput(sigmas1, sigmas2)
249
+
250
+ get_sigmas = execute
251
+
252
+ class FlipSigmas(io.ComfyNode):
253
+ @classmethod
254
+ def define_schema(cls):
255
+ return io.Schema(
256
+ node_id="FlipSigmas",
257
+ category="sampling/custom_sampling/sigmas",
258
+ inputs=[io.Sigmas.Input("sigmas")],
259
+ outputs=[io.Sigmas.Output()]
260
+ )
261
+
262
+ @classmethod
263
+ def execute(cls, sigmas) -> io.NodeOutput:
264
+ if len(sigmas) == 0:
265
+ return io.NodeOutput(sigmas)
266
+
267
+ sigmas = sigmas.flip(0)
268
+ if sigmas[0] == 0:
269
+ sigmas[0] = 0.0001
270
+ return io.NodeOutput(sigmas)
271
+
272
+ get_sigmas = execute
273
+
274
+ class SetFirstSigma(io.ComfyNode):
275
+ @classmethod
276
+ def define_schema(cls):
277
+ return io.Schema(
278
+ node_id="SetFirstSigma",
279
+ category="sampling/custom_sampling/sigmas",
280
+ inputs=[
281
+ io.Sigmas.Input("sigmas"),
282
+ io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False),
283
+ ],
284
+ outputs=[io.Sigmas.Output()]
285
+ )
286
+
287
+ @classmethod
288
+ def execute(cls, sigmas, sigma) -> io.NodeOutput:
289
+ sigmas = sigmas.clone()
290
+ sigmas[0] = sigma
291
+ return io.NodeOutput(sigmas)
292
+
293
+ set_first_sigma = execute
294
+
295
+ class ExtendIntermediateSigmas(io.ComfyNode):
296
+ @classmethod
297
+ def define_schema(cls):
298
+ return io.Schema(
299
+ node_id="ExtendIntermediateSigmas",
300
+ search_aliases=["interpolate sigmas"],
301
+ category="sampling/custom_sampling/sigmas",
302
+ inputs=[
303
+ io.Sigmas.Input("sigmas"),
304
+ io.Int.Input("steps", default=2, min=1, max=100),
305
+ io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False),
306
+ io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False),
307
+ io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']),
308
+ ],
309
+ outputs=[io.Sigmas.Output()]
310
+ )
311
+
312
+ @classmethod
313
+ def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str) -> io.NodeOutput:
314
+ if start_at_sigma < 0:
315
+ start_at_sigma = float("inf")
316
+
317
+ interpolator = {
318
+ 'linear': lambda x: x,
319
+ 'cosine': lambda x: torch.sin(x*math.pi/2),
320
+ 'sine': lambda x: 1 - torch.cos(x*math.pi/2)
321
+ }[spacing]
322
+
323
+ # linear space for our interpolation function
324
+ x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1]
325
+ computed_spacing = interpolator(x)
326
+
327
+ extended_sigmas = []
328
+ for i in range(len(sigmas) - 1):
329
+ sigma_current = sigmas[i]
330
+ sigma_next = sigmas[i+1]
331
+
332
+ extended_sigmas.append(sigma_current)
333
+
334
+ if end_at_sigma <= sigma_current <= start_at_sigma:
335
+ interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current
336
+ extended_sigmas.extend(interpolated_steps.tolist())
337
+
338
+ # Add the last sigma value
339
+ if len(sigmas) > 0:
340
+ extended_sigmas.append(sigmas[-1])
341
+
342
+ extended_sigmas = torch.FloatTensor(extended_sigmas)
343
+
344
+ return io.NodeOutput(extended_sigmas)
345
+
346
+ extend = execute
347
+
348
+
349
+ class SamplingPercentToSigma(io.ComfyNode):
350
+ @classmethod
351
+ def define_schema(cls):
352
+ return io.Schema(
353
+ node_id="SamplingPercentToSigma",
354
+ category="sampling/custom_sampling/sigmas",
355
+ inputs=[
356
+ io.Model.Input("model"),
357
+ io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001),
358
+ io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."),
359
+ ],
360
+ outputs=[io.Float.Output(display_name="sigma_value")]
361
+ )
362
+
363
+ @classmethod
364
+ def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput:
365
+ model_sampling = model.get_model_object("model_sampling")
366
+ sigma_val = model_sampling.percent_to_sigma(sampling_percent)
367
+ if return_actual_sigma:
368
+ if sampling_percent == 0.0:
369
+ sigma_val = model_sampling.sigma_max.item()
370
+ elif sampling_percent == 1.0:
371
+ sigma_val = model_sampling.sigma_min.item()
372
+ return io.NodeOutput(sigma_val)
373
+
374
+ get_sigma = execute
375
+
376
+
377
+ class KSamplerSelect(io.ComfyNode):
378
+ @classmethod
379
+ def define_schema(cls):
380
+ return io.Schema(
381
+ node_id="KSamplerSelect",
382
+ category="sampling/custom_sampling/samplers",
383
+ inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)],
384
+ outputs=[io.Sampler.Output()]
385
+ )
386
+
387
+ @classmethod
388
+ def execute(cls, sampler_name) -> io.NodeOutput:
389
+ sampler = comfy.samplers.sampler_object(sampler_name)
390
+ return io.NodeOutput(sampler)
391
+
392
+ get_sampler = execute
393
+
394
+ class SamplerDPMPP_3M_SDE(io.ComfyNode):
395
+ @classmethod
396
+ def define_schema(cls):
397
+ return io.Schema(
398
+ node_id="SamplerDPMPP_3M_SDE",
399
+ category="sampling/custom_sampling/samplers",
400
+ inputs=[
401
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
402
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
403
+ io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True),
404
+ ],
405
+ outputs=[io.Sampler.Output()]
406
+ )
407
+
408
+ @classmethod
409
+ def execute(cls, eta, s_noise, noise_device) -> io.NodeOutput:
410
+ if noise_device == 'cpu':
411
+ sampler_name = "dpmpp_3m_sde"
412
+ else:
413
+ sampler_name = "dpmpp_3m_sde_gpu"
414
+ sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise})
415
+ return io.NodeOutput(sampler)
416
+
417
+ get_sampler = execute
418
+
419
+ class SamplerDPMPP_2M_SDE(io.ComfyNode):
420
+ @classmethod
421
+ def define_schema(cls):
422
+ return io.Schema(
423
+ node_id="SamplerDPMPP_2M_SDE",
424
+ category="sampling/custom_sampling/samplers",
425
+ inputs=[
426
+ io.Combo.Input("solver_type", options=['midpoint', 'heun']),
427
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
428
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
429
+ io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True),
430
+ ],
431
+ outputs=[io.Sampler.Output()]
432
+ )
433
+
434
+ @classmethod
435
+ def execute(cls, solver_type, eta, s_noise, noise_device) -> io.NodeOutput:
436
+ if noise_device == 'cpu':
437
+ sampler_name = "dpmpp_2m_sde"
438
+ else:
439
+ sampler_name = "dpmpp_2m_sde_gpu"
440
+ sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
441
+ return io.NodeOutput(sampler)
442
+
443
+ get_sampler = execute
444
+
445
+
446
+ class SamplerDPMPP_SDE(io.ComfyNode):
447
+ @classmethod
448
+ def define_schema(cls):
449
+ return io.Schema(
450
+ node_id="SamplerDPMPP_SDE",
451
+ category="sampling/custom_sampling/samplers",
452
+ inputs=[
453
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
454
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
455
+ io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
456
+ io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True),
457
+ ],
458
+ outputs=[io.Sampler.Output()]
459
+ )
460
+
461
+ @classmethod
462
+ def execute(cls, eta, s_noise, r, noise_device) -> io.NodeOutput:
463
+ if noise_device == 'cpu':
464
+ sampler_name = "dpmpp_sde"
465
+ else:
466
+ sampler_name = "dpmpp_sde_gpu"
467
+ sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
468
+ return io.NodeOutput(sampler)
469
+
470
+ get_sampler = execute
471
+
472
+ class SamplerDPMPP_2S_Ancestral(io.ComfyNode):
473
+ @classmethod
474
+ def define_schema(cls):
475
+ return io.Schema(
476
+ node_id="SamplerDPMPP_2S_Ancestral",
477
+ category="sampling/custom_sampling/samplers",
478
+ inputs=[
479
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
480
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
481
+ ],
482
+ outputs=[io.Sampler.Output()]
483
+ )
484
+
485
+ @classmethod
486
+ def execute(cls, eta, s_noise) -> io.NodeOutput:
487
+ sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise})
488
+ return io.NodeOutput(sampler)
489
+
490
+ get_sampler = execute
491
+
492
+ class SamplerEulerAncestral(io.ComfyNode):
493
+ @classmethod
494
+ def define_schema(cls):
495
+ return io.Schema(
496
+ node_id="SamplerEulerAncestral",
497
+ category="sampling/custom_sampling/samplers",
498
+ inputs=[
499
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
500
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
501
+ ],
502
+ outputs=[io.Sampler.Output()]
503
+ )
504
+
505
+ @classmethod
506
+ def execute(cls, eta, s_noise) -> io.NodeOutput:
507
+ sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise})
508
+ return io.NodeOutput(sampler)
509
+
510
+ get_sampler = execute
511
+
512
+ class SamplerEulerAncestralCFGPP(io.ComfyNode):
513
+ @classmethod
514
+ def define_schema(cls):
515
+ return io.Schema(
516
+ node_id="SamplerEulerAncestralCFGPP",
517
+ display_name="SamplerEulerAncestralCFG++",
518
+ category="sampling/custom_sampling/samplers",
519
+ inputs=[
520
+ io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False),
521
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
522
+ ],
523
+ outputs=[io.Sampler.Output()]
524
+ )
525
+
526
+ @classmethod
527
+ def execute(cls, eta, s_noise) -> io.NodeOutput:
528
+ sampler = comfy.samplers.ksampler(
529
+ "euler_ancestral_cfg_pp",
530
+ {"eta": eta, "s_noise": s_noise})
531
+ return io.NodeOutput(sampler)
532
+
533
+ get_sampler = execute
534
+
535
+ class SamplerLMS(io.ComfyNode):
536
+ @classmethod
537
+ def define_schema(cls):
538
+ return io.Schema(
539
+ node_id="SamplerLMS",
540
+ category="sampling/custom_sampling/samplers",
541
+ inputs=[io.Int.Input("order", default=4, min=1, max=100, advanced=True)],
542
+ outputs=[io.Sampler.Output()]
543
+ )
544
+
545
+ @classmethod
546
+ def execute(cls, order) -> io.NodeOutput:
547
+ sampler = comfy.samplers.ksampler("lms", {"order": order})
548
+ return io.NodeOutput(sampler)
549
+
550
+ get_sampler = execute
551
+
552
+ class SamplerDPMAdaptative(io.ComfyNode):
553
+ @classmethod
554
+ def define_schema(cls):
555
+ return io.Schema(
556
+ node_id="SamplerDPMAdaptative",
557
+ category="sampling/custom_sampling/samplers",
558
+ inputs=[
559
+ io.Int.Input("order", default=3, min=2, max=3, advanced=True),
560
+ io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
561
+ io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
562
+ io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
563
+ io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
564
+ io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
565
+ io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
566
+ io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
567
+ io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
568
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
569
+ ],
570
+ outputs=[io.Sampler.Output()]
571
+ )
572
+
573
+ @classmethod
574
+ def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) -> io.NodeOutput:
575
+ sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff,
576
+ "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta,
577
+ "s_noise":s_noise })
578
+ return io.NodeOutput(sampler)
579
+
580
+ get_sampler = execute
581
+
582
+
583
+ class SamplerER_SDE(io.ComfyNode):
584
+ @classmethod
585
+ def define_schema(cls):
586
+ return io.Schema(
587
+ node_id="SamplerER_SDE",
588
+ category="sampling/custom_sampling/samplers",
589
+ inputs=[
590
+ io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
591
+ io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True),
592
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type.", advanced=True),
593
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
594
+ ],
595
+ outputs=[io.Sampler.Output()]
596
+ )
597
+
598
+ @classmethod
599
+ def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput:
600
+ if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0):
601
+ eta = 0
602
+ s_noise = 0
603
+
604
+ def reverse_time_sde_noise_scaler(x):
605
+ return x ** (eta + 1)
606
+
607
+ if solver_type == "ER-SDE":
608
+ # Use the default one in sample_er_sde()
609
+ noise_scaler = None
610
+ else:
611
+ noise_scaler = reverse_time_sde_noise_scaler
612
+
613
+ sampler_name = "er_sde"
614
+ sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage})
615
+ return io.NodeOutput(sampler)
616
+
617
+ get_sampler = execute
618
+
619
+
620
+ class SamplerSASolver(io.ComfyNode):
621
+ @classmethod
622
+ def define_schema(cls):
623
+ return io.Schema(
624
+ node_id="SamplerSASolver",
625
+ search_aliases=["sde"],
626
+ category="sampling/custom_sampling/samplers",
627
+ inputs=[
628
+ io.Model.Input("model"),
629
+ io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False, advanced=True),
630
+ io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001, advanced=True),
631
+ io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001, advanced=True),
632
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
633
+ io.Int.Input("predictor_order", default=3, min=1, max=6, advanced=True),
634
+ io.Int.Input("corrector_order", default=4, min=0, max=6, advanced=True),
635
+ io.Boolean.Input("use_pece", advanced=True),
636
+ io.Boolean.Input("simple_order_2", advanced=True),
637
+ ],
638
+ outputs=[io.Sampler.Output()]
639
+ )
640
+
641
+ @classmethod
642
+ def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput:
643
+ model_sampling = model.get_model_object("model_sampling")
644
+ start_sigma = model_sampling.percent_to_sigma(sde_start_percent)
645
+ end_sigma = model_sampling.percent_to_sigma(sde_end_percent)
646
+ tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta)
647
+
648
+ sampler_name = "sa_solver"
649
+ sampler = comfy.samplers.ksampler(
650
+ sampler_name,
651
+ {
652
+ "tau_func": tau_func,
653
+ "s_noise": s_noise,
654
+ "predictor_order": predictor_order,
655
+ "corrector_order": corrector_order,
656
+ "use_pece": use_pece,
657
+ "simple_order_2": simple_order_2,
658
+ },
659
+ )
660
+ return io.NodeOutput(sampler)
661
+
662
+ get_sampler = execute
663
+
664
+
665
+ class SamplerSEEDS2(io.ComfyNode):
666
+ @classmethod
667
+ def define_schema(cls):
668
+ return io.Schema(
669
+ node_id="SamplerSEEDS2",
670
+ search_aliases=["sde", "exp heun"],
671
+ category="sampling/custom_sampling/samplers",
672
+ inputs=[
673
+ io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
674
+ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength", advanced=True),
675
+ io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier", advanced=True),
676
+ io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)", advanced=True),
677
+ ],
678
+ outputs=[io.Sampler.Output()],
679
+ description=(
680
+ "This sampler node can represent multiple samplers:\n\n"
681
+ "seeds_2\n"
682
+ "- default setting\n\n"
683
+ "exp_heun_2_x0\n"
684
+ "- solver_type=phi_2, r=1.0, eta=0.0\n\n"
685
+ "exp_heun_2_x0_sde\n"
686
+ "- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0"
687
+ )
688
+ )
689
+
690
+ @classmethod
691
+ def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
692
+ sampler_name = "seeds_2"
693
+ sampler = comfy.samplers.ksampler(
694
+ sampler_name,
695
+ {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
696
+ )
697
+ return io.NodeOutput(sampler)
698
+
699
+
700
+ class Noise_EmptyNoise:
701
+ def __init__(self):
702
+ self.seed = 0
703
+
704
+ def generate_noise(self, input_latent):
705
+ latent_image = input_latent["samples"]
706
+ if latent_image.is_nested:
707
+ tensors = latent_image.unbind()
708
+ zeros = []
709
+ for t in tensors:
710
+ zeros.append(torch.zeros(t.shape, dtype=t.dtype, layout=t.layout, device="cpu"))
711
+ return comfy.nested_tensor.NestedTensor(zeros)
712
+ else:
713
+ return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
714
+
715
+
716
+ class Noise_RandomNoise:
717
+ def __init__(self, seed):
718
+ self.seed = seed
719
+
720
+ def generate_noise(self, input_latent):
721
+ latent_image = input_latent["samples"]
722
+ batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None
723
+ return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds)
724
+
725
+ class SamplerCustom(io.ComfyNode):
726
+ @classmethod
727
+ def define_schema(cls):
728
+ return io.Schema(
729
+ node_id="SamplerCustom",
730
+ category="sampling/custom_sampling",
731
+ inputs=[
732
+ io.Model.Input("model"),
733
+ io.Boolean.Input("add_noise", default=True, advanced=True),
734
+ io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
735
+ io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
736
+ io.Conditioning.Input("positive"),
737
+ io.Conditioning.Input("negative"),
738
+ io.Sampler.Input("sampler"),
739
+ io.Sigmas.Input("sigmas"),
740
+ io.Latent.Input("latent_image"),
741
+ ],
742
+ outputs=[
743
+ io.Latent.Output(display_name="output"),
744
+ io.Latent.Output(display_name="denoised_output"),
745
+ ]
746
+ )
747
+
748
+ @classmethod
749
+ def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput:
750
+ latent = latent_image
751
+ latent_image = latent["samples"]
752
+ latent = latent.copy()
753
+ latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
754
+ latent["samples"] = latent_image
755
+
756
+ if not add_noise:
757
+ noise = Noise_EmptyNoise().generate_noise(latent)
758
+ else:
759
+ noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
760
+
761
+ noise_mask = None
762
+ if "noise_mask" in latent:
763
+ noise_mask = latent["noise_mask"]
764
+
765
+ x0_output = {}
766
+ callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
767
+
768
+ disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
769
+ samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
770
+
771
+ out = latent.copy()
772
+ out.pop("downscale_ratio_spacial", None)
773
+ out["samples"] = samples
774
+ if "x0" in x0_output:
775
+ x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
776
+ if samples.is_nested:
777
+ latent_shapes = [x.shape for x in samples.unbind()]
778
+ x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
779
+ out_denoised = latent.copy()
780
+ out_denoised["samples"] = x0_out
781
+ else:
782
+ out_denoised = out
783
+ return io.NodeOutput(out, out_denoised)
784
+
785
+ sample = execute
786
+
787
+ class Guider_Basic(comfy.samplers.CFGGuider):
788
+ def set_conds(self, positive):
789
+ self.inner_set_conds({"positive": positive})
790
+
791
+ class BasicGuider(io.ComfyNode):
792
+ @classmethod
793
+ def define_schema(cls):
794
+ return io.Schema(
795
+ node_id="BasicGuider",
796
+ category="sampling/custom_sampling/guiders",
797
+ inputs=[
798
+ io.Model.Input("model"),
799
+ io.Conditioning.Input("conditioning"),
800
+ ],
801
+ outputs=[io.Guider.Output()]
802
+ )
803
+
804
+ @classmethod
805
+ def execute(cls, model, conditioning) -> io.NodeOutput:
806
+ guider = Guider_Basic(model)
807
+ guider.set_conds(conditioning)
808
+ return io.NodeOutput(guider)
809
+
810
+ get_guider = execute
811
+
812
+ class CFGGuider(io.ComfyNode):
813
+ @classmethod
814
+ def define_schema(cls):
815
+ return io.Schema(
816
+ node_id="CFGGuider",
817
+ category="sampling/custom_sampling/guiders",
818
+ inputs=[
819
+ io.Model.Input("model"),
820
+ io.Conditioning.Input("positive"),
821
+ io.Conditioning.Input("negative"),
822
+ io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
823
+ ],
824
+ outputs=[io.Guider.Output()]
825
+ )
826
+
827
+ @classmethod
828
+ def execute(cls, model, positive, negative, cfg) -> io.NodeOutput:
829
+ guider = comfy.samplers.CFGGuider(model)
830
+ guider.set_conds(positive, negative)
831
+ guider.set_cfg(cfg)
832
+ return io.NodeOutput(guider)
833
+
834
+ get_guider = execute
835
+
836
+ class Guider_DualCFG(comfy.samplers.CFGGuider):
837
+ def set_cfg(self, cfg1, cfg2, nested=False):
838
+ self.cfg1 = cfg1
839
+ self.cfg2 = cfg2
840
+ self.nested = nested
841
+
842
+ def set_conds(self, positive, middle, negative):
843
+ middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"})
844
+ self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative})
845
+
846
+ def predict_noise(self, x, timestep, model_options={}, seed=None):
847
+ negative_cond = self.conds.get("negative", None)
848
+ middle_cond = self.conds.get("middle", None)
849
+ positive_cond = self.conds.get("positive", None)
850
+
851
+ if self.nested:
852
+ out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
853
+ pred_text = comfy.samplers.cfg_function(self.inner_model, out[2], out[1], self.cfg1, x, timestep, model_options=model_options, cond=positive_cond, uncond=middle_cond)
854
+ return out[0] + self.cfg2 * (pred_text - out[0])
855
+ else:
856
+ if model_options.get("disable_cfg1_optimization", False) == False:
857
+ if math.isclose(self.cfg2, 1.0):
858
+ negative_cond = None
859
+ if math.isclose(self.cfg1, 1.0):
860
+ middle_cond = None
861
+
862
+ out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options)
863
+ return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1
864
+
865
+ class DualCFGGuider(io.ComfyNode):
866
+ @classmethod
867
+ def define_schema(cls):
868
+ return io.Schema(
869
+ node_id="DualCFGGuider",
870
+ search_aliases=["dual prompt guidance"],
871
+ category="sampling/custom_sampling/guiders",
872
+ inputs=[
873
+ io.Model.Input("model"),
874
+ io.Conditioning.Input("cond1"),
875
+ io.Conditioning.Input("cond2"),
876
+ io.Conditioning.Input("negative"),
877
+ io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
878
+ io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
879
+ io.Combo.Input("style", options=["regular", "nested"]),
880
+ ],
881
+ outputs=[io.Guider.Output()]
882
+ )
883
+
884
+ @classmethod
885
+ def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput:
886
+ guider = Guider_DualCFG(model)
887
+ guider.set_conds(cond1, cond2, negative)
888
+ guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested"))
889
+ return io.NodeOutput(guider)
890
+
891
+ get_guider = execute
892
+
893
+ class DisableNoise(io.ComfyNode):
894
+ @classmethod
895
+ def define_schema(cls):
896
+ return io.Schema(
897
+ node_id="DisableNoise",
898
+ search_aliases=["zero noise"],
899
+ category="sampling/custom_sampling/noise",
900
+ inputs=[],
901
+ outputs=[io.Noise.Output()]
902
+ )
903
+
904
+ @classmethod
905
+ def execute(cls) -> io.NodeOutput:
906
+ return io.NodeOutput(Noise_EmptyNoise())
907
+
908
+ get_noise = execute
909
+
910
+
911
+ class RandomNoise(io.ComfyNode):
912
+ @classmethod
913
+ def define_schema(cls):
914
+ return io.Schema(
915
+ node_id="RandomNoise",
916
+ category="sampling/custom_sampling/noise",
917
+ inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)],
918
+ outputs=[io.Noise.Output()]
919
+ )
920
+
921
+ @classmethod
922
+ def execute(cls, noise_seed) -> io.NodeOutput:
923
+ return io.NodeOutput(Noise_RandomNoise(noise_seed))
924
+
925
+ get_noise = execute
926
+
927
+
928
+ class SamplerCustomAdvanced(io.ComfyNode):
929
+ @classmethod
930
+ def define_schema(cls):
931
+ return io.Schema(
932
+ node_id="SamplerCustomAdvanced",
933
+ category="sampling/custom_sampling",
934
+ inputs=[
935
+ io.Noise.Input("noise"),
936
+ io.Guider.Input("guider"),
937
+ io.Sampler.Input("sampler"),
938
+ io.Sigmas.Input("sigmas"),
939
+ io.Latent.Input("latent_image"),
940
+ ],
941
+ outputs=[
942
+ io.Latent.Output(display_name="output"),
943
+ io.Latent.Output(display_name="denoised_output"),
944
+ ]
945
+ )
946
+
947
+ @classmethod
948
+ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
949
+ latent = latent_image
950
+ latent_image = latent["samples"]
951
+ latent = latent.copy()
952
+ latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None))
953
+ latent["samples"] = latent_image
954
+
955
+ noise_mask = None
956
+ if "noise_mask" in latent:
957
+ noise_mask = latent["noise_mask"]
958
+
959
+ x0_output = {}
960
+ callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
961
+
962
+ disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
963
+ samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
964
+ samples = samples.to(comfy.model_management.intermediate_device())
965
+
966
+ out = latent.copy()
967
+ out.pop("downscale_ratio_spacial", None)
968
+ out["samples"] = samples
969
+ if "x0" in x0_output:
970
+ x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
971
+ if samples.is_nested:
972
+ latent_shapes = [x.shape for x in samples.unbind()]
973
+ x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
974
+ out_denoised = latent.copy()
975
+ out_denoised["samples"] = x0_out
976
+ else:
977
+ out_denoised = out
978
+ return io.NodeOutput(out, out_denoised)
979
+
980
+ sample = execute
981
+
982
+ class AddNoise(io.ComfyNode):
983
+ @classmethod
984
+ def define_schema(cls):
985
+ return io.Schema(
986
+ node_id="AddNoise",
987
+ category="_for_testing/custom_sampling/noise",
988
+ is_experimental=True,
989
+ inputs=[
990
+ io.Model.Input("model"),
991
+ io.Noise.Input("noise"),
992
+ io.Sigmas.Input("sigmas"),
993
+ io.Latent.Input("latent_image"),
994
+ ],
995
+ outputs=[
996
+ io.Latent.Output(),
997
+ ]
998
+ )
999
+
1000
+ @classmethod
1001
+ def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput:
1002
+ if len(sigmas) == 0:
1003
+ return io.NodeOutput(latent_image)
1004
+
1005
+ latent = latent_image
1006
+ latent_image = latent["samples"]
1007
+
1008
+ noisy = noise.generate_noise(latent)
1009
+
1010
+ model_sampling = model.get_model_object("model_sampling")
1011
+ process_latent_out = model.get_model_object("process_latent_out")
1012
+ process_latent_in = model.get_model_object("process_latent_in")
1013
+
1014
+ if len(sigmas) > 1:
1015
+ scale = torch.abs(sigmas[0] - sigmas[-1])
1016
+ else:
1017
+ scale = sigmas[0]
1018
+
1019
+ if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
1020
+ latent_image = process_latent_in(latent_image)
1021
+ noisy = model_sampling.noise_scaling(scale, noisy, latent_image)
1022
+ noisy = process_latent_out(noisy)
1023
+ noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0)
1024
+
1025
+ out = latent.copy()
1026
+ out["samples"] = noisy
1027
+ return io.NodeOutput(out)
1028
+
1029
+ add_noise = execute
1030
+
1031
+ class ManualSigmas(io.ComfyNode):
1032
+ @classmethod
1033
+ def define_schema(cls):
1034
+ return io.Schema(
1035
+ node_id="ManualSigmas",
1036
+ search_aliases=["custom noise schedule", "define sigmas"],
1037
+ category="_for_testing/custom_sampling",
1038
+ is_experimental=True,
1039
+ inputs=[
1040
+ io.String.Input("sigmas", default="1, 0.5", multiline=False)
1041
+ ],
1042
+ outputs=[io.Sigmas.Output()]
1043
+ )
1044
+
1045
+ @classmethod
1046
+ def execute(cls, sigmas) -> io.NodeOutput:
1047
+ sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas)
1048
+ sigmas = [float(i) for i in sigmas]
1049
+ sigmas = torch.FloatTensor(sigmas)
1050
+ return io.NodeOutput(sigmas)
1051
+
1052
+ class CustomSamplersExtension(ComfyExtension):
1053
+ @override
1054
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
1055
+ return [
1056
+ SamplerCustom,
1057
+ BasicScheduler,
1058
+ KarrasScheduler,
1059
+ ExponentialScheduler,
1060
+ PolyexponentialScheduler,
1061
+ LaplaceScheduler,
1062
+ VPScheduler,
1063
+ BetaSamplingScheduler,
1064
+ SDTurboScheduler,
1065
+ KSamplerSelect,
1066
+ SamplerEulerAncestral,
1067
+ SamplerEulerAncestralCFGPP,
1068
+ SamplerLMS,
1069
+ SamplerDPMPP_3M_SDE,
1070
+ SamplerDPMPP_2M_SDE,
1071
+ SamplerDPMPP_SDE,
1072
+ SamplerDPMPP_2S_Ancestral,
1073
+ SamplerDPMAdaptative,
1074
+ SamplerER_SDE,
1075
+ SamplerSASolver,
1076
+ SamplerSEEDS2,
1077
+ SplitSigmas,
1078
+ SplitSigmasDenoise,
1079
+ FlipSigmas,
1080
+ SetFirstSigma,
1081
+ ExtendIntermediateSigmas,
1082
+ SamplingPercentToSigma,
1083
+ CFGGuider,
1084
+ DualCFGGuider,
1085
+ BasicGuider,
1086
+ RandomNoise,
1087
+ DisableNoise,
1088
+ AddNoise,
1089
+ SamplerCustomAdvanced,
1090
+ ManualSigmas,
1091
+ ]
1092
+
1093
+
1094
+ async def comfy_entrypoint() -> CustomSamplersExtension:
1095
+ return CustomSamplersExtension()
ComfyUI/comfy_extras/nodes_dataset.py ADDED
@@ -0,0 +1,1537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import json
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from typing_extensions import override
9
+
10
+ import folder_paths
11
+ import node_helpers
12
+ from comfy_api.latest import ComfyExtension, io
13
+
14
+
15
+ def load_and_process_images(image_files, input_dir):
16
+ """Utility function to load and process a list of images.
17
+
18
+ Args:
19
+ image_files: List of image filenames
20
+ input_dir: Base directory containing the images
21
+ resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
22
+
23
+ Returns:
24
+ torch.Tensor: Batch of processed images
25
+ """
26
+ if not image_files:
27
+ raise ValueError("No valid images found in input")
28
+
29
+ output_images = []
30
+
31
+ for file in image_files:
32
+ image_path = os.path.join(input_dir, file)
33
+ img = node_helpers.pillow(Image.open, image_path)
34
+
35
+ if img.mode == "I":
36
+ img = img.point(lambda i: i * (1 / 255))
37
+ img = img.convert("RGB")
38
+ img_array = np.array(img).astype(np.float32) / 255.0
39
+ img_tensor = torch.from_numpy(img_array)[None,]
40
+ output_images.append(img_tensor)
41
+
42
+ return output_images
43
+
44
+
45
+ class LoadImageDataSetFromFolderNode(io.ComfyNode):
46
+ @classmethod
47
+ def define_schema(cls):
48
+ return io.Schema(
49
+ node_id="LoadImageDataSetFromFolder",
50
+ display_name="Load Image Dataset from Folder",
51
+ category="dataset",
52
+ is_experimental=True,
53
+ inputs=[
54
+ io.Combo.Input(
55
+ "folder",
56
+ options=folder_paths.get_input_subfolders(),
57
+ tooltip="The folder to load images from.",
58
+ )
59
+ ],
60
+ outputs=[
61
+ io.Image.Output(
62
+ display_name="images",
63
+ is_output_list=True,
64
+ tooltip="List of loaded images",
65
+ )
66
+ ],
67
+ )
68
+
69
+ @classmethod
70
+ def execute(cls, folder):
71
+ sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
72
+ valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
73
+ image_files = [
74
+ f
75
+ for f in os.listdir(sub_input_dir)
76
+ if any(f.lower().endswith(ext) for ext in valid_extensions)
77
+ ]
78
+ output_tensor = load_and_process_images(image_files, sub_input_dir)
79
+ return io.NodeOutput(output_tensor)
80
+
81
+
82
+ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
83
+ @classmethod
84
+ def define_schema(cls):
85
+ return io.Schema(
86
+ node_id="LoadImageTextDataSetFromFolder",
87
+ display_name="Load Image and Text Dataset from Folder",
88
+ category="dataset",
89
+ is_experimental=True,
90
+ inputs=[
91
+ io.Combo.Input(
92
+ "folder",
93
+ options=folder_paths.get_input_subfolders(),
94
+ tooltip="The folder to load images from.",
95
+ )
96
+ ],
97
+ outputs=[
98
+ io.Image.Output(
99
+ display_name="images",
100
+ is_output_list=True,
101
+ tooltip="List of loaded images",
102
+ ),
103
+ io.String.Output(
104
+ display_name="texts",
105
+ is_output_list=True,
106
+ tooltip="List of text captions",
107
+ ),
108
+ ],
109
+ )
110
+
111
+ @classmethod
112
+ def execute(cls, folder):
113
+ logging.info(f"Loading images from folder: {folder}")
114
+
115
+ sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
116
+ valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
117
+
118
+ image_files = []
119
+ for item in os.listdir(sub_input_dir):
120
+ path = os.path.join(sub_input_dir, item)
121
+ if any(item.lower().endswith(ext) for ext in valid_extensions):
122
+ image_files.append(path)
123
+ elif os.path.isdir(path):
124
+ # Support kohya-ss/sd-scripts folder structure
125
+ repeat = 1
126
+ if item.split("_")[0].isdigit():
127
+ repeat = int(item.split("_")[0])
128
+ image_files.extend(
129
+ [
130
+ os.path.join(path, f)
131
+ for f in os.listdir(path)
132
+ if any(f.lower().endswith(ext) for ext in valid_extensions)
133
+ ]
134
+ * repeat
135
+ )
136
+
137
+ caption_file_path = [
138
+ f.replace(os.path.splitext(f)[1], ".txt") for f in image_files
139
+ ]
140
+ captions = []
141
+ for caption_file in caption_file_path:
142
+ caption_path = os.path.join(sub_input_dir, caption_file)
143
+ if os.path.exists(caption_path):
144
+ with open(caption_path, "r", encoding="utf-8") as f:
145
+ caption = f.read().strip()
146
+ captions.append(caption)
147
+ else:
148
+ captions.append("")
149
+
150
+ output_tensor = load_and_process_images(image_files, sub_input_dir)
151
+
152
+ logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
153
+ return io.NodeOutput(output_tensor, captions)
154
+
155
+
156
+ def save_images_to_folder(image_list, output_dir, prefix="image"):
157
+ """Utility function to save a list of image tensors to disk.
158
+
159
+ Args:
160
+ image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W])
161
+ output_dir: Directory to save images to
162
+ prefix: Filename prefix
163
+
164
+ Returns:
165
+ List of saved filenames
166
+ """
167
+ os.makedirs(output_dir, exist_ok=True)
168
+ saved_files = []
169
+
170
+ for idx, img_tensor in enumerate(image_list):
171
+ # Handle different tensor shapes
172
+ if isinstance(img_tensor, torch.Tensor):
173
+ # Remove batch dimension if present [1, H, W, C] -> [H, W, C]
174
+ if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
175
+ img_tensor = img_tensor.squeeze(0)
176
+
177
+ # If tensor is [C, H, W], permute to [H, W, C]
178
+ if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]:
179
+ if (
180
+ img_tensor.shape[0] <= 4
181
+ and img_tensor.shape[1] > 4
182
+ and img_tensor.shape[2] > 4
183
+ ):
184
+ img_tensor = img_tensor.permute(1, 2, 0)
185
+
186
+ # Convert to numpy and scale to 0-255
187
+ img_array = img_tensor.cpu().numpy()
188
+ img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8)
189
+
190
+ # Convert to PIL Image
191
+ img = Image.fromarray(img_array)
192
+ else:
193
+ raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}")
194
+
195
+ # Save image
196
+ filename = f"{prefix}_{idx:05d}.png"
197
+ filepath = os.path.join(output_dir, filename)
198
+ img.save(filepath)
199
+ saved_files.append(filename)
200
+
201
+ return saved_files
202
+
203
+
204
+ class SaveImageDataSetToFolderNode(io.ComfyNode):
205
+ @classmethod
206
+ def define_schema(cls):
207
+ return io.Schema(
208
+ node_id="SaveImageDataSetToFolder",
209
+ display_name="Save Image Dataset to Folder",
210
+ category="dataset",
211
+ is_experimental=True,
212
+ is_output_node=True,
213
+ is_input_list=True, # Receive images as list
214
+ inputs=[
215
+ io.Image.Input("images", tooltip="List of images to save."),
216
+ io.String.Input(
217
+ "folder_name",
218
+ default="dataset",
219
+ tooltip="Name of the folder to save images to (inside output directory).",
220
+ ),
221
+ io.String.Input(
222
+ "filename_prefix",
223
+ default="image",
224
+ tooltip="Prefix for saved image filenames.",
225
+ advanced=True,
226
+ ),
227
+ ],
228
+ outputs=[],
229
+ )
230
+
231
+ @classmethod
232
+ def execute(cls, images, folder_name, filename_prefix):
233
+ # Extract scalar values
234
+ folder_name = folder_name[0]
235
+ filename_prefix = filename_prefix[0]
236
+
237
+ output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
238
+ saved_files = save_images_to_folder(images, output_dir, filename_prefix)
239
+
240
+ logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
241
+ return io.NodeOutput()
242
+
243
+
244
+ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
245
+ @classmethod
246
+ def define_schema(cls):
247
+ return io.Schema(
248
+ node_id="SaveImageTextDataSetToFolder",
249
+ display_name="Save Image and Text Dataset to Folder",
250
+ category="dataset",
251
+ is_experimental=True,
252
+ is_output_node=True,
253
+ is_input_list=True, # Receive both images and texts as lists
254
+ inputs=[
255
+ io.Image.Input("images", tooltip="List of images to save."),
256
+ io.String.Input("texts", tooltip="List of text captions to save."),
257
+ io.String.Input(
258
+ "folder_name",
259
+ default="dataset",
260
+ tooltip="Name of the folder to save images to (inside output directory).",
261
+ ),
262
+ io.String.Input(
263
+ "filename_prefix",
264
+ default="image",
265
+ tooltip="Prefix for saved image filenames.",
266
+ advanced=True,
267
+ ),
268
+ ],
269
+ outputs=[],
270
+ )
271
+
272
+ @classmethod
273
+ def execute(cls, images, texts, folder_name, filename_prefix):
274
+ # Extract scalar values
275
+ folder_name = folder_name[0]
276
+ filename_prefix = filename_prefix[0]
277
+
278
+ output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
279
+ saved_files = save_images_to_folder(images, output_dir, filename_prefix)
280
+
281
+ # Save captions
282
+ for idx, (filename, caption) in enumerate(zip(saved_files, texts)):
283
+ caption_filename = filename.replace(".png", ".txt")
284
+ caption_path = os.path.join(output_dir, caption_filename)
285
+ with open(caption_path, "w", encoding="utf-8") as f:
286
+ f.write(caption)
287
+
288
+ logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
289
+ return io.NodeOutput()
290
+
291
+
292
+ # ========== Helper Functions for Transform Nodes ==========
293
+
294
+
295
+ def tensor_to_pil(img_tensor):
296
+ """Convert tensor to PIL Image."""
297
+ if img_tensor.dim() == 4 and img_tensor.shape[0] == 1:
298
+ img_tensor = img_tensor.squeeze(0)
299
+ img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
300
+ return Image.fromarray(img_array)
301
+
302
+
303
+ def pil_to_tensor(img):
304
+ """Convert PIL Image to tensor."""
305
+ img_array = np.array(img).astype(np.float32) / 255.0
306
+ return torch.from_numpy(img_array)[None,]
307
+
308
+
309
+ # ========== Base Classes for Transform Nodes ==========
310
+
311
+
312
+ class ImageProcessingNode(io.ComfyNode):
313
+ """Base class for image processing nodes that operate on images.
314
+
315
+ Child classes should set:
316
+ node_id: Unique node identifier (required)
317
+ display_name: Display name (optional, defaults to node_id)
318
+ description: Node description (optional)
319
+ extra_inputs: List of additional io.Input objects beyond "images" (optional)
320
+ is_group_process: None (auto-detect), True (group), or False (individual) (optional)
321
+ is_output_list: True (list output) or False (single output) (optional, default True)
322
+
323
+ Child classes must implement ONE of:
324
+ _process(cls, image, **kwargs) -> tensor (for single-item processing)
325
+ _group_process(cls, images, **kwargs) -> list[tensor] (for group processing)
326
+ """
327
+
328
+ node_id = None
329
+ display_name = None
330
+ description = None
331
+ extra_inputs = []
332
+ is_group_process = None # None = auto-detect, True/False = explicit
333
+ is_output_list = None # None = auto-detect based on processing mode
334
+
335
+ @classmethod
336
+ def _detect_processing_mode(cls):
337
+ """Detect whether this node uses group or individual processing.
338
+
339
+ Returns:
340
+ bool: True if group processing, False if individual processing
341
+ """
342
+ # Explicit setting takes precedence
343
+ if cls.is_group_process is not None:
344
+ return cls.is_group_process
345
+
346
+ # Check which method is overridden by looking at the defining class in MRO
347
+ base_class = ImageProcessingNode
348
+
349
+ # Find which class in MRO defines _process
350
+ process_definer = None
351
+ for klass in cls.__mro__:
352
+ if "_process" in klass.__dict__:
353
+ process_definer = klass
354
+ break
355
+
356
+ # Find which class in MRO defines _group_process
357
+ group_definer = None
358
+ for klass in cls.__mro__:
359
+ if "_group_process" in klass.__dict__:
360
+ group_definer = klass
361
+ break
362
+
363
+ # Check what was overridden (not defined in base class)
364
+ has_process = process_definer is not None and process_definer is not base_class
365
+ has_group = group_definer is not None and group_definer is not base_class
366
+
367
+ if has_process and has_group:
368
+ raise ValueError(
369
+ f"{cls.__name__}: Cannot override both _process and _group_process. "
370
+ "Override only one, or set is_group_process explicitly."
371
+ )
372
+ if not has_process and not has_group:
373
+ raise ValueError(
374
+ f"{cls.__name__}: Must override either _process or _group_process"
375
+ )
376
+
377
+ return has_group
378
+
379
+ @classmethod
380
+ def define_schema(cls):
381
+ if cls.node_id is None:
382
+ raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
383
+
384
+ is_group = cls._detect_processing_mode()
385
+
386
+ # Auto-detect is_output_list if not explicitly set
387
+ # Single processing: False (backend collects results into list)
388
+ # Group processing: True by default (can be False for single-output nodes)
389
+ output_is_list = (
390
+ cls.is_output_list if cls.is_output_list is not None else is_group
391
+ )
392
+
393
+ inputs = [
394
+ io.Image.Input(
395
+ "images",
396
+ tooltip=(
397
+ "List of images to process." if is_group else "Image to process."
398
+ ),
399
+ )
400
+ ]
401
+ inputs.extend(cls.extra_inputs)
402
+
403
+ return io.Schema(
404
+ node_id=cls.node_id,
405
+ display_name=cls.display_name or cls.node_id,
406
+ category="dataset/image",
407
+ is_experimental=True,
408
+ is_input_list=is_group, # True for group, False for individual
409
+ inputs=inputs,
410
+ outputs=[
411
+ io.Image.Output(
412
+ display_name="images",
413
+ is_output_list=output_is_list,
414
+ tooltip="Processed images",
415
+ )
416
+ ],
417
+ )
418
+
419
+ @classmethod
420
+ def execute(cls, images, **kwargs):
421
+ """Execute the node. Routes to _process or _group_process based on mode."""
422
+ is_group = cls._detect_processing_mode()
423
+
424
+ # Extract scalar values from lists for parameters
425
+ params = {}
426
+ for k, v in kwargs.items():
427
+ if isinstance(v, list) and len(v) == 1:
428
+ params[k] = v[0]
429
+ else:
430
+ params[k] = v
431
+
432
+ if is_group:
433
+ # Group processing: images is list, call _group_process
434
+ result = cls._group_process(images, **params)
435
+ else:
436
+ # Individual processing: images is single item, call _process
437
+ result = cls._process(images, **params)
438
+
439
+ return io.NodeOutput(result)
440
+
441
+ @classmethod
442
+ def _process(cls, image, **kwargs):
443
+ """Override this method for single-item processing.
444
+
445
+ Args:
446
+ image: tensor - Single image tensor
447
+ **kwargs: Additional parameters (already extracted from lists)
448
+
449
+ Returns:
450
+ tensor - Processed image
451
+ """
452
+ raise NotImplementedError(f"{cls.__name__} must implement _process method")
453
+
454
+ @classmethod
455
+ def _group_process(cls, images, **kwargs):
456
+ """Override this method for group processing.
457
+
458
+ Args:
459
+ images: list[tensor] - List of image tensors
460
+ **kwargs: Additional parameters (already extracted from lists)
461
+
462
+ Returns:
463
+ list[tensor] - Processed images
464
+ """
465
+ raise NotImplementedError(
466
+ f"{cls.__name__} must implement _group_process method"
467
+ )
468
+
469
+
470
+ class TextProcessingNode(io.ComfyNode):
471
+ """Base class for text processing nodes that operate on texts.
472
+
473
+ Child classes should set:
474
+ node_id: Unique node identifier (required)
475
+ display_name: Display name (optional, defaults to node_id)
476
+ description: Node description (optional)
477
+ extra_inputs: List of additional io.Input objects beyond "texts" (optional)
478
+ is_group_process: None (auto-detect), True (group), or False (individual) (optional)
479
+ is_output_list: True (list output) or False (single output) (optional, default True)
480
+
481
+ Child classes must implement ONE of:
482
+ _process(cls, text, **kwargs) -> str (for single-item processing)
483
+ _group_process(cls, texts, **kwargs) -> list[str] (for group processing)
484
+ """
485
+
486
+ node_id = None
487
+ display_name = None
488
+ description = None
489
+ extra_inputs = []
490
+ is_group_process = None # None = auto-detect, True/False = explicit
491
+ is_output_list = None # None = auto-detect based on processing mode
492
+
493
+ @classmethod
494
+ def _detect_processing_mode(cls):
495
+ """Detect whether this node uses group or individual processing.
496
+
497
+ Returns:
498
+ bool: True if group processing, False if individual processing
499
+ """
500
+ # Explicit setting takes precedence
501
+ if cls.is_group_process is not None:
502
+ return cls.is_group_process
503
+
504
+ # Check which method is overridden by looking at the defining class in MRO
505
+ base_class = TextProcessingNode
506
+
507
+ # Find which class in MRO defines _process
508
+ process_definer = None
509
+ for klass in cls.__mro__:
510
+ if "_process" in klass.__dict__:
511
+ process_definer = klass
512
+ break
513
+
514
+ # Find which class in MRO defines _group_process
515
+ group_definer = None
516
+ for klass in cls.__mro__:
517
+ if "_group_process" in klass.__dict__:
518
+ group_definer = klass
519
+ break
520
+
521
+ # Check what was overridden (not defined in base class)
522
+ has_process = process_definer is not None and process_definer is not base_class
523
+ has_group = group_definer is not None and group_definer is not base_class
524
+
525
+ if has_process and has_group:
526
+ raise ValueError(
527
+ f"{cls.__name__}: Cannot override both _process and _group_process. "
528
+ "Override only one, or set is_group_process explicitly."
529
+ )
530
+ if not has_process and not has_group:
531
+ raise ValueError(
532
+ f"{cls.__name__}: Must override either _process or _group_process"
533
+ )
534
+
535
+ return has_group
536
+
537
+ @classmethod
538
+ def define_schema(cls):
539
+ if cls.node_id is None:
540
+ raise NotImplementedError(f"{cls.__name__} must set node_id class variable")
541
+
542
+ is_group = cls._detect_processing_mode()
543
+
544
+ inputs = [
545
+ io.String.Input(
546
+ "texts",
547
+ tooltip="List of texts to process." if is_group else "Text to process.",
548
+ )
549
+ ]
550
+ inputs.extend(cls.extra_inputs)
551
+
552
+ return io.Schema(
553
+ node_id=cls.node_id,
554
+ display_name=cls.display_name or cls.node_id,
555
+ category="dataset/text",
556
+ is_experimental=True,
557
+ is_input_list=is_group, # True for group, False for individual
558
+ inputs=inputs,
559
+ outputs=[
560
+ io.String.Output(
561
+ display_name="texts",
562
+ is_output_list=cls.is_output_list,
563
+ tooltip="Processed texts",
564
+ )
565
+ ],
566
+ )
567
+
568
+ @classmethod
569
+ def execute(cls, texts, **kwargs):
570
+ """Execute the node. Routes to _process or _group_process based on mode."""
571
+ is_group = cls._detect_processing_mode()
572
+
573
+ # Extract scalar values from lists for parameters
574
+ params = {}
575
+ for k, v in kwargs.items():
576
+ if isinstance(v, list) and len(v) == 1:
577
+ params[k] = v[0]
578
+ else:
579
+ params[k] = v
580
+
581
+ if is_group:
582
+ # Group processing: texts is list, call _group_process
583
+ result = cls._group_process(texts, **params)
584
+ else:
585
+ # Individual processing: texts is single item, call _process
586
+ result = cls._process(texts, **params)
587
+
588
+ # Wrap result based on is_output_list
589
+ if cls.is_output_list:
590
+ # Result should already be a list (or will be for individual)
591
+ return io.NodeOutput(result if is_group else [result])
592
+ else:
593
+ # Single output - wrap in list for NodeOutput
594
+ return io.NodeOutput([result])
595
+
596
+ @classmethod
597
+ def _process(cls, text, **kwargs):
598
+ """Override this method for single-item processing.
599
+
600
+ Args:
601
+ text: str - Single text string
602
+ **kwargs: Additional parameters (already extracted from lists)
603
+
604
+ Returns:
605
+ str - Processed text
606
+ """
607
+ raise NotImplementedError(f"{cls.__name__} must implement _process method")
608
+
609
+ @classmethod
610
+ def _group_process(cls, texts, **kwargs):
611
+ """Override this method for group processing.
612
+
613
+ Args:
614
+ texts: list[str] - List of text strings
615
+ **kwargs: Additional parameters (already extracted from lists)
616
+
617
+ Returns:
618
+ list[str] - Processed texts
619
+ """
620
+ raise NotImplementedError(
621
+ f"{cls.__name__} must implement _group_process method"
622
+ )
623
+
624
+
625
+ # ========== Image Transform Nodes ==========
626
+
627
+
628
+ class ResizeImagesByShorterEdgeNode(ImageProcessingNode):
629
+ node_id = "ResizeImagesByShorterEdge"
630
+ display_name = "Resize Images by Shorter Edge"
631
+ description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio."
632
+ extra_inputs = [
633
+ io.Int.Input(
634
+ "shorter_edge",
635
+ default=512,
636
+ min=1,
637
+ max=8192,
638
+ tooltip="Target length for the shorter edge.",
639
+ ),
640
+ ]
641
+
642
+ @classmethod
643
+ def _process(cls, image, shorter_edge):
644
+ img = tensor_to_pil(image)
645
+ w, h = img.size
646
+ if w < h:
647
+ new_w = shorter_edge
648
+ new_h = int(h * (shorter_edge / w))
649
+ else:
650
+ new_h = shorter_edge
651
+ new_w = int(w * (shorter_edge / h))
652
+ img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
653
+ return pil_to_tensor(img)
654
+
655
+
656
+ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):
657
+ node_id = "ResizeImagesByLongerEdge"
658
+ display_name = "Resize Images by Longer Edge"
659
+ description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio."
660
+ extra_inputs = [
661
+ io.Int.Input(
662
+ "longer_edge",
663
+ default=1024,
664
+ min=1,
665
+ max=8192,
666
+ tooltip="Target length for the longer edge.",
667
+ ),
668
+ ]
669
+
670
+ @classmethod
671
+ def _process(cls, image, longer_edge):
672
+ resized_images = []
673
+ for image_i in image:
674
+ img = tensor_to_pil(image_i)
675
+ w, h = img.size
676
+ if w > h:
677
+ new_w = longer_edge
678
+ new_h = int(h * (longer_edge / w))
679
+ else:
680
+ new_h = longer_edge
681
+ new_w = int(w * (longer_edge / h))
682
+ img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
683
+ resized_images.append(pil_to_tensor(img))
684
+ return torch.cat(resized_images, dim=0)
685
+
686
+
687
+ class CenterCropImagesNode(ImageProcessingNode):
688
+ node_id = "CenterCropImages"
689
+ display_name = "Center Crop Images"
690
+ description = "Center crop all images to the specified dimensions."
691
+ extra_inputs = [
692
+ io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
693
+ io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
694
+ ]
695
+
696
+ @classmethod
697
+ def _process(cls, image, width, height):
698
+ img = tensor_to_pil(image)
699
+ left = max(0, (img.width - width) // 2)
700
+ top = max(0, (img.height - height) // 2)
701
+ right = min(img.width, left + width)
702
+ bottom = min(img.height, top + height)
703
+ img = img.crop((left, top, right, bottom))
704
+ return pil_to_tensor(img)
705
+
706
+
707
+ class RandomCropImagesNode(ImageProcessingNode):
708
+ node_id = "RandomCropImages"
709
+ display_name = "Random Crop Images"
710
+ description = (
711
+ "Randomly crop all images to the specified dimensions (for data augmentation)."
712
+ )
713
+ extra_inputs = [
714
+ io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."),
715
+ io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."),
716
+ io.Int.Input(
717
+ "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
718
+ ),
719
+ ]
720
+
721
+ @classmethod
722
+ def _process(cls, image, width, height, seed):
723
+ np.random.seed(seed % (2**32 - 1))
724
+ img = tensor_to_pil(image)
725
+ max_left = max(0, img.width - width)
726
+ max_top = max(0, img.height - height)
727
+ left = np.random.randint(0, max_left + 1) if max_left > 0 else 0
728
+ top = np.random.randint(0, max_top + 1) if max_top > 0 else 0
729
+ right = min(img.width, left + width)
730
+ bottom = min(img.height, top + height)
731
+ img = img.crop((left, top, right, bottom))
732
+ return pil_to_tensor(img)
733
+
734
+
735
+ class NormalizeImagesNode(ImageProcessingNode):
736
+ node_id = "NormalizeImages"
737
+ display_name = "Normalize Images"
738
+ description = "Normalize images using mean and standard deviation."
739
+ extra_inputs = [
740
+ io.Float.Input(
741
+ "mean",
742
+ default=0.5,
743
+ min=0.0,
744
+ max=1.0,
745
+ tooltip="Mean value for normalization.",
746
+ advanced=True,
747
+ ),
748
+ io.Float.Input(
749
+ "std",
750
+ default=0.5,
751
+ min=0.001,
752
+ max=1.0,
753
+ tooltip="Standard deviation for normalization.",
754
+ advanced=True,
755
+ ),
756
+ ]
757
+
758
+ @classmethod
759
+ def _process(cls, image, mean, std):
760
+ return (image - mean) / std
761
+
762
+
763
+ class AdjustBrightnessNode(ImageProcessingNode):
764
+ node_id = "AdjustBrightness"
765
+ display_name = "Adjust Brightness"
766
+ description = "Adjust brightness of all images."
767
+ extra_inputs = [
768
+ io.Float.Input(
769
+ "factor",
770
+ default=1.0,
771
+ min=0.0,
772
+ max=2.0,
773
+ tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.",
774
+ ),
775
+ ]
776
+
777
+ @classmethod
778
+ def _process(cls, image, factor):
779
+ return (image * factor).clamp(0.0, 1.0)
780
+
781
+
782
+ class AdjustContrastNode(ImageProcessingNode):
783
+ node_id = "AdjustContrast"
784
+ display_name = "Adjust Contrast"
785
+ description = "Adjust contrast of all images."
786
+ extra_inputs = [
787
+ io.Float.Input(
788
+ "factor",
789
+ default=1.0,
790
+ min=0.0,
791
+ max=2.0,
792
+ tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.",
793
+ ),
794
+ ]
795
+
796
+ @classmethod
797
+ def _process(cls, image, factor):
798
+ return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0)
799
+
800
+
801
+ class ShuffleDatasetNode(ImageProcessingNode):
802
+ node_id = "ShuffleDataset"
803
+ display_name = "Shuffle Image Dataset"
804
+ description = "Randomly shuffle the order of images in the dataset."
805
+ is_group_process = True # Requires full list to shuffle
806
+ extra_inputs = [
807
+ io.Int.Input(
808
+ "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed."
809
+ ),
810
+ ]
811
+
812
+ @classmethod
813
+ def _group_process(cls, images, seed):
814
+ np.random.seed(seed % (2**32 - 1))
815
+ indices = np.random.permutation(len(images))
816
+ return [images[i] for i in indices]
817
+
818
+
819
+ class ShuffleImageTextDatasetNode(io.ComfyNode):
820
+ """Special node that shuffles both images and texts together."""
821
+
822
+ @classmethod
823
+ def define_schema(cls):
824
+ return io.Schema(
825
+ node_id="ShuffleImageTextDataset",
826
+ display_name="Shuffle Image-Text Dataset",
827
+ category="dataset/image",
828
+ is_experimental=True,
829
+ is_input_list=True,
830
+ inputs=[
831
+ io.Image.Input("images", tooltip="List of images to shuffle."),
832
+ io.String.Input("texts", tooltip="List of texts to shuffle."),
833
+ io.Int.Input(
834
+ "seed",
835
+ default=0,
836
+ min=0,
837
+ max=0xFFFFFFFFFFFFFFFF,
838
+ tooltip="Random seed.",
839
+ ),
840
+ ],
841
+ outputs=[
842
+ io.Image.Output(
843
+ display_name="images",
844
+ is_output_list=True,
845
+ tooltip="Shuffled images",
846
+ ),
847
+ io.String.Output(
848
+ display_name="texts", is_output_list=True, tooltip="Shuffled texts"
849
+ ),
850
+ ],
851
+ )
852
+
853
+ @classmethod
854
+ def execute(cls, images, texts, seed):
855
+ seed = seed[0] # Extract scalar
856
+ np.random.seed(seed % (2**32 - 1))
857
+ indices = np.random.permutation(len(images))
858
+ shuffled_images = [images[i] for i in indices]
859
+ shuffled_texts = [texts[i] for i in indices]
860
+ return io.NodeOutput(shuffled_images, shuffled_texts)
861
+
862
+
863
+ # ========== Text Transform Nodes ==========
864
+
865
+
866
+ class TextToLowercaseNode(TextProcessingNode):
867
+ node_id = "TextToLowercase"
868
+ display_name = "Text to Lowercase"
869
+ description = "Convert all texts to lowercase."
870
+
871
+ @classmethod
872
+ def _process(cls, text):
873
+ return text.lower()
874
+
875
+
876
+ class TextToUppercaseNode(TextProcessingNode):
877
+ node_id = "TextToUppercase"
878
+ display_name = "Text to Uppercase"
879
+ description = "Convert all texts to uppercase."
880
+
881
+ @classmethod
882
+ def _process(cls, text):
883
+ return text.upper()
884
+
885
+
886
+ class TruncateTextNode(TextProcessingNode):
887
+ node_id = "TruncateText"
888
+ display_name = "Truncate Text"
889
+ description = "Truncate all texts to a maximum length."
890
+ extra_inputs = [
891
+ io.Int.Input(
892
+ "max_length", default=77, min=1, max=10000, tooltip="Maximum text length."
893
+ ),
894
+ ]
895
+
896
+ @classmethod
897
+ def _process(cls, text, max_length):
898
+ return text[:max_length]
899
+
900
+
901
+ class AddTextPrefixNode(TextProcessingNode):
902
+ node_id = "AddTextPrefix"
903
+ display_name = "Add Text Prefix"
904
+ description = "Add a prefix to all texts."
905
+ extra_inputs = [
906
+ io.String.Input("prefix", default="", tooltip="Prefix to add."),
907
+ ]
908
+
909
+ @classmethod
910
+ def _process(cls, text, prefix):
911
+ return prefix + text
912
+
913
+
914
+ class AddTextSuffixNode(TextProcessingNode):
915
+ node_id = "AddTextSuffix"
916
+ display_name = "Add Text Suffix"
917
+ description = "Add a suffix to all texts."
918
+ extra_inputs = [
919
+ io.String.Input("suffix", default="", tooltip="Suffix to add."),
920
+ ]
921
+
922
+ @classmethod
923
+ def _process(cls, text, suffix):
924
+ return text + suffix
925
+
926
+
927
+ class ReplaceTextNode(TextProcessingNode):
928
+ node_id = "ReplaceText"
929
+ display_name = "Replace Text"
930
+ description = "Replace text in all texts."
931
+ extra_inputs = [
932
+ io.String.Input("find", default="", tooltip="Text to find."),
933
+ io.String.Input("replace", default="", tooltip="Text to replace with."),
934
+ ]
935
+
936
+ @classmethod
937
+ def _process(cls, text, find, replace):
938
+ return text.replace(find, replace)
939
+
940
+
941
+ class StripWhitespaceNode(TextProcessingNode):
942
+ node_id = "StripWhitespace"
943
+ display_name = "Strip Whitespace"
944
+ description = "Strip leading and trailing whitespace from all texts."
945
+
946
+ @classmethod
947
+ def _process(cls, text):
948
+ return text.strip()
949
+
950
+
951
+ # ========== Group Processing Example Nodes ==========
952
+
953
+
954
+ class ImageDeduplicationNode(ImageProcessingNode):
955
+ """Remove duplicate or very similar images from the dataset using perceptual hashing."""
956
+
957
+ node_id = "ImageDeduplication"
958
+ display_name = "Image Deduplication"
959
+ description = "Remove duplicate or very similar images from the dataset."
960
+ is_group_process = True # Requires full list to compare images
961
+ extra_inputs = [
962
+ io.Float.Input(
963
+ "similarity_threshold",
964
+ default=0.95,
965
+ min=0.0,
966
+ max=1.0,
967
+ tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
968
+ advanced=True,
969
+ ),
970
+ ]
971
+
972
+ @classmethod
973
+ def _group_process(cls, images, similarity_threshold):
974
+ """Remove duplicate images using perceptual hashing."""
975
+ if len(images) == 0:
976
+ return []
977
+
978
+ # Compute simple perceptual hash for each image
979
+ def compute_hash(img_tensor):
980
+ """Compute a simple perceptual hash by resizing to 8x8 and comparing to average."""
981
+ img = tensor_to_pil(img_tensor)
982
+ # Resize to 8x8
983
+ img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L")
984
+ # Get pixels
985
+ pixels = list(img_small.getdata())
986
+ # Compute average
987
+ avg = sum(pixels) / len(pixels)
988
+ # Create hash (1 if above average, 0 otherwise)
989
+ hash_bits = "".join("1" if p > avg else "0" for p in pixels)
990
+ return hash_bits
991
+
992
+ def hamming_distance(hash1, hash2):
993
+ """Compute Hamming distance between two hash strings."""
994
+ return sum(c1 != c2 for c1, c2 in zip(hash1, hash2))
995
+
996
+ # Compute hashes for all images
997
+ hashes = [compute_hash(img) for img in images]
998
+
999
+ # Find duplicates
1000
+ keep_indices = []
1001
+ for i in range(len(images)):
1002
+ is_duplicate = False
1003
+ for j in keep_indices:
1004
+ # Compare hashes
1005
+ distance = hamming_distance(hashes[i], hashes[j])
1006
+ similarity = 1.0 - (distance / 64.0) # 64 bits total
1007
+ if similarity >= similarity_threshold:
1008
+ is_duplicate = True
1009
+ logging.info(
1010
+ f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping"
1011
+ )
1012
+ break
1013
+
1014
+ if not is_duplicate:
1015
+ keep_indices.append(i)
1016
+
1017
+ # Return only unique images
1018
+ unique_images = [images[i] for i in keep_indices]
1019
+ logging.info(
1020
+ f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
1021
+ )
1022
+ return unique_images
1023
+
1024
+
1025
+ class ImageGridNode(ImageProcessingNode):
1026
+ """Combine multiple images into a single grid/collage."""
1027
+
1028
+ node_id = "ImageGrid"
1029
+ display_name = "Image Grid"
1030
+ description = "Arrange multiple images into a grid layout."
1031
+ is_group_process = True # Requires full list to create grid
1032
+ is_output_list = False # Outputs single grid image
1033
+ extra_inputs = [
1034
+ io.Int.Input(
1035
+ "columns",
1036
+ default=4,
1037
+ min=1,
1038
+ max=20,
1039
+ tooltip="Number of columns in the grid.",
1040
+ ),
1041
+ io.Int.Input(
1042
+ "cell_width",
1043
+ default=256,
1044
+ min=32,
1045
+ max=2048,
1046
+ tooltip="Width of each cell in the grid.",
1047
+ advanced=True,
1048
+ ),
1049
+ io.Int.Input(
1050
+ "cell_height",
1051
+ default=256,
1052
+ min=32,
1053
+ max=2048,
1054
+ tooltip="Height of each cell in the grid.",
1055
+ advanced=True,
1056
+ ),
1057
+ io.Int.Input(
1058
+ "padding", default=4, min=0, max=50, tooltip="Padding between images.", advanced=True
1059
+ ),
1060
+ ]
1061
+
1062
+ @classmethod
1063
+ def _group_process(cls, images, columns, cell_width, cell_height, padding):
1064
+ """Arrange images into a grid."""
1065
+ if len(images) == 0:
1066
+ raise ValueError("Cannot create grid from empty image list")
1067
+
1068
+ # Calculate grid dimensions
1069
+ num_images = len(images)
1070
+ rows = (num_images + columns - 1) // columns # Ceiling division
1071
+
1072
+ # Calculate total grid size
1073
+ grid_width = columns * cell_width + (columns - 1) * padding
1074
+ grid_height = rows * cell_height + (rows - 1) * padding
1075
+
1076
+ # Create blank grid
1077
+ grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0))
1078
+
1079
+ # Place images
1080
+ for idx, img_tensor in enumerate(images):
1081
+ row = idx // columns
1082
+ col = idx % columns
1083
+
1084
+ # Convert to PIL and resize to cell size
1085
+ img = tensor_to_pil(img_tensor)
1086
+ img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS)
1087
+
1088
+ # Calculate position
1089
+ x = col * (cell_width + padding)
1090
+ y = row * (cell_height + padding)
1091
+
1092
+ # Paste into grid
1093
+ grid.paste(img, (x, y))
1094
+
1095
+ logging.info(
1096
+ f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
1097
+ )
1098
+ return pil_to_tensor(grid)
1099
+
1100
+
1101
+ class MergeImageListsNode(ImageProcessingNode):
1102
+ """Merge multiple image lists into a single list."""
1103
+
1104
+ node_id = "MergeImageLists"
1105
+ display_name = "Merge Image Lists"
1106
+ description = "Concatenate multiple image lists into one."
1107
+ is_group_process = True # Receives images as list
1108
+
1109
+ @classmethod
1110
+ def _group_process(cls, images):
1111
+ """Simply return the images list (already merged by input handling)."""
1112
+ # When multiple list inputs are connected, they're concatenated
1113
+ # For now, this is a simple pass-through
1114
+ logging.info(f"Merged image list contains {len(images)} images")
1115
+ return images
1116
+
1117
+
1118
+ class MergeTextListsNode(TextProcessingNode):
1119
+ """Merge multiple text lists into a single list."""
1120
+
1121
+ node_id = "MergeTextLists"
1122
+ display_name = "Merge Text Lists"
1123
+ description = "Concatenate multiple text lists into one."
1124
+ is_group_process = True # Receives texts as list
1125
+
1126
+ @classmethod
1127
+ def _group_process(cls, texts):
1128
+ """Simply return the texts list (already merged by input handling)."""
1129
+ # When multiple list inputs are connected, they're concatenated
1130
+ # For now, this is a simple pass-through
1131
+ logging.info(f"Merged text list contains {len(texts)} texts")
1132
+ return texts
1133
+
1134
+
1135
+ # ========== Training Dataset Nodes ==========
1136
+
1137
+
1138
+ class ResolutionBucket(io.ComfyNode):
1139
+ """Bucket latents and conditions by resolution for efficient batch training."""
1140
+
1141
+ @classmethod
1142
+ def define_schema(cls):
1143
+ return io.Schema(
1144
+ node_id="ResolutionBucket",
1145
+ display_name="Resolution Bucket",
1146
+ category="dataset",
1147
+ is_experimental=True,
1148
+ is_input_list=True,
1149
+ inputs=[
1150
+ io.Latent.Input(
1151
+ "latents",
1152
+ tooltip="List of latent dicts to bucket by resolution.",
1153
+ ),
1154
+ io.Conditioning.Input(
1155
+ "conditioning",
1156
+ tooltip="List of conditioning lists (must match latents length).",
1157
+ ),
1158
+ ],
1159
+ outputs=[
1160
+ io.Latent.Output(
1161
+ display_name="latents",
1162
+ is_output_list=True,
1163
+ tooltip="List of batched latent dicts, one per resolution bucket.",
1164
+ ),
1165
+ io.Conditioning.Output(
1166
+ display_name="conditioning",
1167
+ is_output_list=True,
1168
+ tooltip="List of condition lists, one per resolution bucket.",
1169
+ ),
1170
+ ],
1171
+ )
1172
+
1173
+ @classmethod
1174
+ def execute(cls, latents, conditioning):
1175
+ # latents: list[{"samples": tensor}] where tensor is (B, C, H, W), typically B=1
1176
+ # conditioning: list[list[cond]]
1177
+
1178
+ # Validate lengths match
1179
+ if len(latents) != len(conditioning):
1180
+ raise ValueError(
1181
+ f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)})."
1182
+ )
1183
+
1184
+ # Flatten latents and conditions to individual samples
1185
+ flat_latents = [] # list of (C, H, W) tensors
1186
+ flat_conditions = [] # list of condition lists
1187
+
1188
+ for latent_dict, cond in zip(latents, conditioning):
1189
+ samples = latent_dict["samples"] # (B, C, H, W)
1190
+ batch_size = samples.shape[0]
1191
+
1192
+ # cond is a list of conditions with length == batch_size
1193
+ for i in range(batch_size):
1194
+ flat_latents.append(samples[i]) # (C, H, W)
1195
+ flat_conditions.append(cond[i]) # single condition
1196
+
1197
+ # Group by resolution (H, W)
1198
+ buckets = {} # (H, W) -> {"latents": list, "conditions": list}
1199
+
1200
+ for latent, cond in zip(flat_latents, flat_conditions):
1201
+ # latent shape is (..., H, W) (B, C, H, W) or (B, T, C, H ,W)
1202
+ h, w = latent.shape[-2], latent.shape[-1]
1203
+ key = (h, w)
1204
+
1205
+ if key not in buckets:
1206
+ buckets[key] = {"latents": [], "conditions": []}
1207
+
1208
+ buckets[key]["latents"].append(latent)
1209
+ buckets[key]["conditions"].append(cond)
1210
+
1211
+ # Convert buckets to output format
1212
+ output_latents = [] # list[{"samples": tensor}] where tensor is (Bi, ..., H, W)
1213
+ output_conditions = [] # list[list[cond]] where each inner list has Bi conditions
1214
+
1215
+ for (h, w), bucket_data in buckets.items():
1216
+ # Stack latents into batch: list of (..., H, W) -> (Bi, ..., H, W)
1217
+ stacked_latents = torch.stack(bucket_data["latents"], dim=0)
1218
+ output_latents.append({"samples": stacked_latents})
1219
+
1220
+ # Conditions stay as list of condition lists
1221
+ output_conditions.append(bucket_data["conditions"])
1222
+
1223
+ logging.info(
1224
+ f"Resolution bucket ({h}x{w}): {len(bucket_data['latents'])} samples"
1225
+ )
1226
+
1227
+ logging.info(f"Created {len(buckets)} resolution buckets from {len(flat_latents)} samples")
1228
+ return io.NodeOutput(output_latents, output_conditions)
1229
+
1230
+
1231
+ class MakeTrainingDataset(io.ComfyNode):
1232
+ """Encode images with VAE and texts with CLIP to create a training dataset."""
1233
+ @classmethod
1234
+ def define_schema(cls):
1235
+ return io.Schema(
1236
+ node_id="MakeTrainingDataset",
1237
+ search_aliases=["encode dataset"],
1238
+ display_name="Make Training Dataset",
1239
+ category="dataset",
1240
+ is_experimental=True,
1241
+ is_input_list=True, # images and texts as lists
1242
+ inputs=[
1243
+ io.Image.Input("images", tooltip="List of images to encode."),
1244
+ io.Vae.Input(
1245
+ "vae", tooltip="VAE model for encoding images to latents."
1246
+ ),
1247
+ io.Clip.Input(
1248
+ "clip", tooltip="CLIP model for encoding text to conditioning."
1249
+ ),
1250
+ io.String.Input(
1251
+ "texts",
1252
+ optional=True,
1253
+ tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).",
1254
+ ),
1255
+ ],
1256
+ outputs=[
1257
+ io.Latent.Output(
1258
+ display_name="latents",
1259
+ is_output_list=True,
1260
+ tooltip="List of latent dicts",
1261
+ ),
1262
+ io.Conditioning.Output(
1263
+ display_name="conditioning",
1264
+ is_output_list=True,
1265
+ tooltip="List of conditioning lists",
1266
+ ),
1267
+ ],
1268
+ )
1269
+
1270
+ @classmethod
1271
+ def execute(cls, images, vae, clip, texts=None):
1272
+ # Extract scalars (vae and clip are single values wrapped in lists)
1273
+ vae = vae[0]
1274
+ clip = clip[0]
1275
+
1276
+ # Handle text list
1277
+ num_images = len(images)
1278
+
1279
+ if texts is None or len(texts) == 0:
1280
+ # Treat as [""] for unconditional training
1281
+ texts = [""]
1282
+
1283
+ if len(texts) == 1 and num_images > 1:
1284
+ # Repeat single text for all images
1285
+ texts = texts * num_images
1286
+ elif len(texts) != num_images:
1287
+ raise ValueError(
1288
+ f"Number of texts ({len(texts)}) does not match number of images ({num_images}). "
1289
+ f"Text list should have length {num_images}, 1, or 0."
1290
+ )
1291
+
1292
+ # Encode images with VAE
1293
+ logging.info(f"Encoding {num_images} images with VAE...")
1294
+ latents_list = [] # list[{"samples": tensor}]
1295
+ for img_tensor in images:
1296
+ # img_tensor is [1, H, W, 3]
1297
+ latent_tensor = vae.encode(img_tensor[:, :, :, :3])
1298
+ latents_list.append({"samples": latent_tensor})
1299
+
1300
+ # Encode texts with CLIP
1301
+ logging.info(f"Encoding {len(texts)} texts with CLIP...")
1302
+ conditioning_list = [] # list[list[cond]]
1303
+ for text in texts:
1304
+ if text == "":
1305
+ cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
1306
+ else:
1307
+ tokens = clip.tokenize(text)
1308
+ cond = clip.encode_from_tokens_scheduled(tokens)
1309
+ conditioning_list.append(cond)
1310
+
1311
+ logging.info(
1312
+ f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning."
1313
+ )
1314
+ return io.NodeOutput(latents_list, conditioning_list)
1315
+
1316
+
1317
+ class SaveTrainingDataset(io.ComfyNode):
1318
+ """Save encoded training dataset (latents + conditioning) to disk."""
1319
+ @classmethod
1320
+ def define_schema(cls):
1321
+ return io.Schema(
1322
+ node_id="SaveTrainingDataset",
1323
+ search_aliases=["export training data"],
1324
+ display_name="Save Training Dataset",
1325
+ category="dataset",
1326
+ is_experimental=True,
1327
+ is_output_node=True,
1328
+ is_input_list=True, # Receive lists
1329
+ inputs=[
1330
+ io.Latent.Input(
1331
+ "latents",
1332
+ tooltip="List of latent dicts from MakeTrainingDataset.",
1333
+ ),
1334
+ io.Conditioning.Input(
1335
+ "conditioning",
1336
+ tooltip="List of conditioning lists from MakeTrainingDataset.",
1337
+ ),
1338
+ io.String.Input(
1339
+ "folder_name",
1340
+ default="training_dataset",
1341
+ tooltip="Name of folder to save dataset (inside output directory).",
1342
+ ),
1343
+ io.Int.Input(
1344
+ "shard_size",
1345
+ default=1000,
1346
+ min=1,
1347
+ max=100000,
1348
+ tooltip="Number of samples per shard file.",
1349
+ advanced=True,
1350
+ ),
1351
+ ],
1352
+ outputs=[],
1353
+ )
1354
+
1355
+ @classmethod
1356
+ def execute(cls, latents, conditioning, folder_name, shard_size):
1357
+ # Extract scalars
1358
+ folder_name = folder_name[0]
1359
+ shard_size = shard_size[0]
1360
+
1361
+ # latents: list[{"samples": tensor}]
1362
+ # conditioning: list[list[cond]]
1363
+
1364
+ # Validate lengths match
1365
+ if len(latents) != len(conditioning):
1366
+ raise ValueError(
1367
+ f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). "
1368
+ f"Something went wrong in dataset preparation."
1369
+ )
1370
+
1371
+ # Create output directory
1372
+ output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
1373
+ os.makedirs(output_dir, exist_ok=True)
1374
+
1375
+ # Prepare data pairs
1376
+ num_samples = len(latents)
1377
+ num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division
1378
+
1379
+ logging.info(
1380
+ f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
1381
+ )
1382
+
1383
+ # Save data in shards
1384
+ for shard_idx in range(num_shards):
1385
+ start_idx = shard_idx * shard_size
1386
+ end_idx = min(start_idx + shard_size, num_samples)
1387
+
1388
+ # Get shard data (list of latent dicts and conditioning lists)
1389
+ shard_data = {
1390
+ "latents": latents[start_idx:end_idx],
1391
+ "conditioning": conditioning[start_idx:end_idx],
1392
+ }
1393
+
1394
+ # Save shard
1395
+ shard_filename = f"shard_{shard_idx:04d}.pkl"
1396
+ shard_path = os.path.join(output_dir, shard_filename)
1397
+
1398
+ with open(shard_path, "wb") as f:
1399
+ torch.save(shard_data, f)
1400
+
1401
+ logging.info(
1402
+ f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
1403
+ )
1404
+
1405
+ # Save metadata
1406
+ metadata = {
1407
+ "num_samples": num_samples,
1408
+ "num_shards": num_shards,
1409
+ "shard_size": shard_size,
1410
+ }
1411
+ metadata_path = os.path.join(output_dir, "metadata.json")
1412
+ with open(metadata_path, "w") as f:
1413
+ json.dump(metadata, f, indent=2)
1414
+
1415
+ logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
1416
+ return io.NodeOutput()
1417
+
1418
+
1419
+ class LoadTrainingDataset(io.ComfyNode):
1420
+ """Load encoded training dataset from disk."""
1421
+ @classmethod
1422
+ def define_schema(cls):
1423
+ return io.Schema(
1424
+ node_id="LoadTrainingDataset",
1425
+ search_aliases=["import dataset", "training data"],
1426
+ display_name="Load Training Dataset",
1427
+ category="dataset",
1428
+ is_experimental=True,
1429
+ inputs=[
1430
+ io.String.Input(
1431
+ "folder_name",
1432
+ default="training_dataset",
1433
+ tooltip="Name of folder containing the saved dataset (inside output directory).",
1434
+ ),
1435
+ ],
1436
+ outputs=[
1437
+ io.Latent.Output(
1438
+ display_name="latents",
1439
+ is_output_list=True,
1440
+ tooltip="List of latent dicts",
1441
+ ),
1442
+ io.Conditioning.Output(
1443
+ display_name="conditioning",
1444
+ is_output_list=True,
1445
+ tooltip="List of conditioning lists",
1446
+ ),
1447
+ ],
1448
+ )
1449
+
1450
+ @classmethod
1451
+ def execute(cls, folder_name):
1452
+ # Get dataset directory
1453
+ dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
1454
+
1455
+ if not os.path.exists(dataset_dir):
1456
+ raise ValueError(f"Dataset directory not found: {dataset_dir}")
1457
+
1458
+ # Find all shard files
1459
+ shard_files = sorted(
1460
+ [
1461
+ f
1462
+ for f in os.listdir(dataset_dir)
1463
+ if f.startswith("shard_") and f.endswith(".pkl")
1464
+ ]
1465
+ )
1466
+
1467
+ if not shard_files:
1468
+ raise ValueError(f"No shard files found in {dataset_dir}")
1469
+
1470
+ logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
1471
+
1472
+ # Load all shards
1473
+ all_latents = [] # list[{"samples": tensor}]
1474
+ all_conditioning = [] # list[list[cond]]
1475
+
1476
+ for shard_file in shard_files:
1477
+ shard_path = os.path.join(dataset_dir, shard_file)
1478
+
1479
+ with open(shard_path, "rb") as f:
1480
+ shard_data = torch.load(f)
1481
+
1482
+ all_latents.extend(shard_data["latents"])
1483
+ all_conditioning.extend(shard_data["conditioning"])
1484
+
1485
+ logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
1486
+
1487
+ logging.info(
1488
+ f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
1489
+ )
1490
+ return io.NodeOutput(all_latents, all_conditioning)
1491
+
1492
+
1493
+ # ========== Extension Setup ==========
1494
+
1495
+
1496
+ class DatasetExtension(ComfyExtension):
1497
+ @override
1498
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
1499
+ return [
1500
+ # Data loading/saving nodes
1501
+ LoadImageDataSetFromFolderNode,
1502
+ LoadImageTextDataSetFromFolderNode,
1503
+ SaveImageDataSetToFolderNode,
1504
+ SaveImageTextDataSetToFolderNode,
1505
+ # Image transform nodes
1506
+ ResizeImagesByShorterEdgeNode,
1507
+ ResizeImagesByLongerEdgeNode,
1508
+ CenterCropImagesNode,
1509
+ RandomCropImagesNode,
1510
+ NormalizeImagesNode,
1511
+ AdjustBrightnessNode,
1512
+ AdjustContrastNode,
1513
+ ShuffleDatasetNode,
1514
+ ShuffleImageTextDatasetNode,
1515
+ # Text transform nodes
1516
+ TextToLowercaseNode,
1517
+ TextToUppercaseNode,
1518
+ TruncateTextNode,
1519
+ AddTextPrefixNode,
1520
+ AddTextSuffixNode,
1521
+ ReplaceTextNode,
1522
+ StripWhitespaceNode,
1523
+ # Group processing examples
1524
+ ImageDeduplicationNode,
1525
+ ImageGridNode,
1526
+ MergeImageListsNode,
1527
+ MergeTextListsNode,
1528
+ # Training dataset nodes
1529
+ MakeTrainingDataset,
1530
+ SaveTrainingDataset,
1531
+ LoadTrainingDataset,
1532
+ ResolutionBucket,
1533
+ ]
1534
+
1535
+
1536
+ async def comfy_entrypoint() -> DatasetExtension:
1537
+ return DatasetExtension()
ComfyUI/comfy_extras/nodes_differential_diffusion.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code adapted from https://github.com/exx8/differential-diffusion
2
+
3
+ from typing_extensions import override
4
+
5
+ import torch
6
+ from comfy_api.latest import ComfyExtension, io
7
+
8
+
9
+ class DifferentialDiffusion(io.ComfyNode):
10
+ @classmethod
11
+ def define_schema(cls):
12
+ return io.Schema(
13
+ node_id="DifferentialDiffusion",
14
+ search_aliases=["inpaint gradient", "variable denoise strength"],
15
+ display_name="Differential Diffusion",
16
+ category="_for_testing",
17
+ inputs=[
18
+ io.Model.Input("model"),
19
+ io.Float.Input(
20
+ "strength",
21
+ default=1.0,
22
+ min=0.0,
23
+ max=1.0,
24
+ step=0.01,
25
+ optional=True,
26
+ ),
27
+ ],
28
+ outputs=[io.Model.Output()],
29
+ is_experimental=True,
30
+ )
31
+
32
+ @classmethod
33
+ def execute(cls, model, strength=1.0) -> io.NodeOutput:
34
+ model = model.clone()
35
+ model.set_model_denoise_mask_function(lambda *args, **kwargs: cls.forward(*args, **kwargs, strength=strength))
36
+ return io.NodeOutput(model)
37
+
38
+ @classmethod
39
+ def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict, strength: float):
40
+ model = extra_options["model"]
41
+ step_sigmas = extra_options["sigmas"]
42
+ sigma_to = model.inner_model.model_sampling.sigma_min
43
+ if step_sigmas[-1] > sigma_to:
44
+ sigma_to = step_sigmas[-1]
45
+ sigma_from = step_sigmas[0]
46
+
47
+ ts_from = model.inner_model.model_sampling.timestep(sigma_from)
48
+ ts_to = model.inner_model.model_sampling.timestep(sigma_to)
49
+ current_ts = model.inner_model.model_sampling.timestep(sigma[0])
50
+
51
+ threshold = (current_ts - ts_to) / (ts_from - ts_to)
52
+
53
+ # Generate the binary mask based on the threshold
54
+ binary_mask = (denoise_mask >= threshold).to(denoise_mask.dtype)
55
+
56
+ # Blend binary mask with the original denoise_mask using strength
57
+ if strength and strength < 1:
58
+ blended_mask = strength * binary_mask + (1 - strength) * denoise_mask
59
+ return blended_mask
60
+ else:
61
+ return binary_mask
62
+
63
+
64
+ class DifferentialDiffusionExtension(ComfyExtension):
65
+ @override
66
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
67
+ return [
68
+ DifferentialDiffusion,
69
+ ]
70
+
71
+
72
+ async def comfy_entrypoint() -> DifferentialDiffusionExtension:
73
+ return DifferentialDiffusionExtension()
ComfyUI/comfy_extras/nodes_easycache.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, Union
3
+ from comfy_api.latest import io, ComfyExtension
4
+ import comfy.patcher_extension
5
+ import logging
6
+ import torch
7
+ import comfy.model_patcher
8
+ if TYPE_CHECKING:
9
+ from uuid import UUID
10
+
11
+
12
+ def _extract_tensor(data, output_channels):
13
+ """Extract tensor from data, handling both single tensors and lists."""
14
+ if isinstance(data, list):
15
+ # LTX2 AV tensors: [video, audio]
16
+ return data[0][:, :output_channels], data[1][:, :output_channels]
17
+ return data[:, :output_channels], None
18
+
19
+
20
+ def easycache_forward_wrapper(executor, *args, **kwargs):
21
+ # get values from args
22
+ transformer_options: dict[str] = args[-1]
23
+ if not isinstance(transformer_options, dict):
24
+ transformer_options = kwargs.get("transformer_options")
25
+ if not transformer_options:
26
+ transformer_options = args[-2]
27
+ easycache: EasyCacheHolder = transformer_options["easycache"]
28
+ x, ax = _extract_tensor(args[0], easycache.output_channels)
29
+ sigmas = transformer_options["sigmas"]
30
+ uuids = transformer_options["uuids"]
31
+ if sigmas is not None and easycache.is_past_end_timestep(sigmas):
32
+ return executor(*args, **kwargs)
33
+ # prepare next x_prev
34
+ has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
35
+ next_x_prev = x
36
+ input_change = None
37
+ do_easycache = easycache.should_do_easycache(sigmas)
38
+ if do_easycache:
39
+ easycache.check_metadata(x)
40
+ # if there isn't a cache diff for current conds, we cannot skip this step
41
+ can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
42
+ # if first cond marked this step for skipping, skip it and use appropriate cached values
43
+ if easycache.skip_current_step and can_apply_cache_diff:
44
+ if easycache.verbose:
45
+ logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
46
+ result = easycache.apply_cache_diff(x, uuids)
47
+ if ax is not None:
48
+ result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
49
+ return [result, result_audio]
50
+ return result
51
+ if easycache.initial_step:
52
+ easycache.first_cond_uuid = uuids[0]
53
+ has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
54
+ easycache.initial_step = False
55
+ if has_first_cond_uuid:
56
+ if easycache.has_x_prev_subsampled():
57
+ input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
58
+ if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
59
+ approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
60
+ easycache.cumulative_change_rate += approx_output_change_rate
61
+ if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
62
+ if easycache.verbose:
63
+ logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
64
+ # other conds should also skip this step, and instead use their cached values
65
+ easycache.skip_current_step = True
66
+ result = easycache.apply_cache_diff(x, uuids)
67
+ if ax is not None:
68
+ result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
69
+ return [result, result_audio]
70
+ return result
71
+ else:
72
+ if easycache.verbose:
73
+ logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
74
+ easycache.cumulative_change_rate = 0.0
75
+
76
+ full_output: torch.Tensor = executor(*args, **kwargs)
77
+ output, audio_output = _extract_tensor(full_output, easycache.output_channels)
78
+ if has_first_cond_uuid and easycache.has_output_prev_norm():
79
+ output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
80
+ if easycache.verbose:
81
+ output_change_rate = output_change / easycache.output_prev_norm
82
+ easycache.output_change_rates.append(output_change_rate.item())
83
+ if easycache.has_relative_transformation_rate():
84
+ approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
85
+ easycache.approx_output_change_rates.append(approx_output_change_rate.item())
86
+ if easycache.verbose:
87
+ logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
88
+ if input_change is not None:
89
+ easycache.relative_transformation_rate = output_change / input_change
90
+ if easycache.verbose:
91
+ logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
92
+ # TODO: allow cache_diff to be offloaded
93
+ easycache.update_cache_diff(output, next_x_prev, uuids)
94
+ if audio_output is not None:
95
+ easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
96
+ if has_first_cond_uuid:
97
+ easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
98
+ easycache.output_prev_subsampled = easycache.subsample(output, uuids)
99
+ easycache.output_prev_norm = output.flatten().abs().mean()
100
+ if easycache.verbose:
101
+ logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
102
+ return full_output
103
+
104
+ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
105
+ # get values from args
106
+ timestep: float = args[1]
107
+ model_options: dict[str] = args[2]
108
+ easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
109
+ if easycache.is_past_end_timestep(timestep):
110
+ return executor(*args, **kwargs)
111
+ x: torch.Tensor = args[0][:, :easycache.output_channels]
112
+ # prepare next x_prev
113
+ next_x_prev = x
114
+ input_change = None
115
+ do_easycache = easycache.should_do_easycache(timestep)
116
+ if do_easycache:
117
+ easycache.check_metadata(x)
118
+ if easycache.has_x_prev_subsampled():
119
+ if easycache.has_x_prev_subsampled():
120
+ input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
121
+ if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
122
+ approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
123
+ easycache.cumulative_change_rate += approx_output_change_rate
124
+ if easycache.cumulative_change_rate < easycache.reuse_threshold:
125
+ if easycache.verbose:
126
+ logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
127
+ # other conds should also skip this step, and instead use their cached values
128
+ easycache.skip_current_step = True
129
+ return easycache.apply_cache_diff(x)
130
+ else:
131
+ if easycache.verbose:
132
+ logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
133
+ easycache.cumulative_change_rate = 0.0
134
+ output: torch.Tensor = executor(*args, **kwargs)
135
+ if easycache.has_output_prev_norm():
136
+ output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
137
+ if easycache.verbose:
138
+ output_change_rate = output_change / easycache.output_prev_norm
139
+ easycache.output_change_rates.append(output_change_rate.item())
140
+ if easycache.has_relative_transformation_rate():
141
+ approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
142
+ easycache.approx_output_change_rates.append(approx_output_change_rate.item())
143
+ if easycache.verbose:
144
+ logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
145
+ if input_change is not None:
146
+ easycache.relative_transformation_rate = output_change / input_change
147
+ if easycache.verbose:
148
+ logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
149
+ # TODO: allow cache_diff to be offloaded
150
+ easycache.update_cache_diff(output, next_x_prev)
151
+ easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
152
+ easycache.output_prev_subsampled = easycache.subsample(output)
153
+ easycache.output_prev_norm = output.flatten().abs().mean()
154
+ if easycache.verbose:
155
+ logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
156
+ return output
157
+
158
+ def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
159
+ model_options = args[-1]
160
+ easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
161
+ easycache.skip_current_step = False
162
+ # TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
163
+ return executor(*args, **kwargs)
164
+
165
+ def easycache_sample_wrapper(executor, *args, **kwargs):
166
+ """
167
+ This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end.
168
+ """
169
+ try:
170
+ guider = executor.class_obj
171
+ orig_model_options = guider.model_options
172
+ guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
173
+ # clone and prepare timesteps
174
+ guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
175
+ easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
176
+ logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
177
+ return executor(*args, **kwargs)
178
+ finally:
179
+ easycache = guider.model_options['transformer_options']['easycache']
180
+ output_change_rates = easycache.output_change_rates
181
+ approx_output_change_rates = easycache.approx_output_change_rates
182
+ if easycache.verbose:
183
+ logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
184
+ logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
185
+ total_steps = len(args[3])-1
186
+ # catch division by zero for log statement; sucks to crash after all sampling is done
187
+ try:
188
+ speedup = total_steps/(total_steps-easycache.total_steps_skipped)
189
+ except ZeroDivisionError:
190
+ speedup = 1.0
191
+ logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
192
+ easycache.reset()
193
+ guider.model_options = orig_model_options
194
+
195
+
196
+ class EasyCacheHolder:
197
+ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
198
+ self.name = "EasyCache"
199
+ self.reuse_threshold = reuse_threshold
200
+ self.start_percent = start_percent
201
+ self.end_percent = end_percent
202
+ self.subsample_factor = subsample_factor
203
+ self.offload_cache_diff = offload_cache_diff
204
+ self.verbose = verbose
205
+ # timestep values
206
+ self.start_t = 0.0
207
+ self.end_t = 0.0
208
+ # control values
209
+ self.relative_transformation_rate: float = None
210
+ self.cumulative_change_rate = 0.0
211
+ self.initial_step = True
212
+ self.skip_current_step = False
213
+ # cache values
214
+ self.first_cond_uuid = None
215
+ self.x_prev_subsampled: torch.Tensor = None
216
+ self.output_prev_subsampled: torch.Tensor = None
217
+ self.output_prev_norm: torch.Tensor = None
218
+ self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
219
+ self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
220
+ self.output_change_rates = []
221
+ self.approx_output_change_rates = []
222
+ self.total_steps_skipped = 0
223
+ # how to deal with mismatched dims
224
+ self.allow_mismatch = True
225
+ self.cut_from_start = True
226
+ self.state_metadata = None
227
+ self.output_channels = output_channels
228
+
229
+ def is_past_end_timestep(self, timestep: float) -> bool:
230
+ return not (timestep[0] > self.end_t).item()
231
+
232
+ def should_do_easycache(self, timestep: float) -> bool:
233
+ return (timestep[0] <= self.start_t).item()
234
+
235
+ def has_x_prev_subsampled(self) -> bool:
236
+ return self.x_prev_subsampled is not None
237
+
238
+ def has_output_prev_subsampled(self) -> bool:
239
+ return self.output_prev_subsampled is not None
240
+
241
+ def has_output_prev_norm(self) -> bool:
242
+ return self.output_prev_norm is not None
243
+
244
+ def has_relative_transformation_rate(self) -> bool:
245
+ return self.relative_transformation_rate is not None
246
+
247
+ def prepare_timesteps(self, model_sampling):
248
+ self.start_t = model_sampling.percent_to_sigma(self.start_percent)
249
+ self.end_t = model_sampling.percent_to_sigma(self.end_percent)
250
+ return self
251
+
252
+ def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor:
253
+ batch_offset = x.shape[0] // len(uuids)
254
+ uuid_idx = uuids.index(self.first_cond_uuid)
255
+ if self.subsample_factor > 1:
256
+ to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor]
257
+ if clone:
258
+ return to_return.clone()
259
+ return to_return
260
+ to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...]
261
+ if clone:
262
+ return to_return.clone()
263
+ return to_return
264
+
265
+ def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
266
+ return all(uuid in self.uuid_cache_diffs for uuid in uuids)
267
+
268
+ def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
269
+ if self.first_cond_uuid in uuids and not is_audio:
270
+ self.total_steps_skipped += 1
271
+ cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
272
+ batch_offset = x.shape[0] // len(uuids)
273
+ for i, uuid in enumerate(uuids):
274
+ # slice out only what is relevant to this cond
275
+ batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
276
+ # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
277
+ if x.shape[1:] != cache_diffs[uuid].shape[1:]:
278
+ if not self.allow_mismatch:
279
+ raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
280
+ slicing = []
281
+ skip_this_dim = True
282
+ for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
283
+ if skip_this_dim:
284
+ skip_this_dim = False
285
+ continue
286
+ if dim_u != dim_x:
287
+ if self.cut_from_start:
288
+ slicing.append(slice(dim_x-dim_u, None))
289
+ else:
290
+ slicing.append(slice(None, dim_u))
291
+ else:
292
+ slicing.append(slice(None))
293
+ batch_slice = batch_slice + slicing
294
+ x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
295
+ return x
296
+
297
+ def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
298
+ cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
299
+ # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
300
+ if output.shape[1:] != x.shape[1:]:
301
+ if not self.allow_mismatch:
302
+ raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good")
303
+ slicing = []
304
+ skip_dim = True
305
+ for dim_o, dim_x in zip(output.shape, x.shape):
306
+ if not skip_dim and dim_o != dim_x:
307
+ if self.cut_from_start:
308
+ slicing.append(slice(dim_x-dim_o, None))
309
+ else:
310
+ slicing.append(slice(None, dim_o))
311
+ else:
312
+ slicing.append(slice(None))
313
+ skip_dim = False
314
+ x = x[tuple(slicing)]
315
+ diff = output - x
316
+ batch_offset = diff.shape[0] // len(uuids)
317
+ for i, uuid in enumerate(uuids):
318
+ cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
319
+
320
+ def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
321
+ return self.first_cond_uuid in uuids
322
+
323
+ def check_metadata(self, x: torch.Tensor) -> bool:
324
+ metadata = (x.device, x.dtype, x.shape[1:])
325
+ if self.state_metadata is None:
326
+ self.state_metadata = metadata
327
+ return True
328
+ if metadata == self.state_metadata:
329
+ return True
330
+ logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
331
+ self.reset()
332
+ return False
333
+
334
+ def reset(self):
335
+ self.relative_transformation_rate = 0.0
336
+ self.cumulative_change_rate = 0.0
337
+ self.initial_step = True
338
+ self.skip_current_step = False
339
+ self.output_change_rates = []
340
+ self.first_cond_uuid = None
341
+ del self.x_prev_subsampled
342
+ self.x_prev_subsampled = None
343
+ del self.output_prev_subsampled
344
+ self.output_prev_subsampled = None
345
+ del self.output_prev_norm
346
+ self.output_prev_norm = None
347
+ del self.uuid_cache_diffs
348
+ self.uuid_cache_diffs = {}
349
+ del self.uuid_cache_diffs_audio
350
+ self.uuid_cache_diffs_audio = {}
351
+ self.total_steps_skipped = 0
352
+ self.state_metadata = None
353
+ return self
354
+
355
+ def clone(self):
356
+ return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
357
+
358
+
359
+ class EasyCacheNode(io.ComfyNode):
360
+ @classmethod
361
+ def define_schema(cls) -> io.Schema:
362
+ return io.Schema(
363
+ node_id="EasyCache",
364
+ display_name="EasyCache",
365
+ description="Native EasyCache implementation.",
366
+ category="advanced/debug/model",
367
+ is_experimental=True,
368
+ inputs=[
369
+ io.Model.Input("model", tooltip="The model to add EasyCache to."),
370
+ io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True),
371
+ io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache.", advanced=True),
372
+ io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache.", advanced=True),
373
+ io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True),
374
+ ],
375
+ outputs=[
376
+ io.Model.Output(tooltip="The model with EasyCache."),
377
+ ],
378
+ )
379
+
380
+ @classmethod
381
+ def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
382
+ model = model.clone()
383
+ model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
384
+ model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
385
+ model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
386
+ model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
387
+ return io.NodeOutput(model)
388
+
389
+
390
+ class LazyCacheHolder:
391
+ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
392
+ self.name = "LazyCache"
393
+ self.reuse_threshold = reuse_threshold
394
+ self.start_percent = start_percent
395
+ self.end_percent = end_percent
396
+ self.subsample_factor = subsample_factor
397
+ self.offload_cache_diff = offload_cache_diff
398
+ self.verbose = verbose
399
+ # timestep values
400
+ self.start_t = 0.0
401
+ self.end_t = 0.0
402
+ # control values
403
+ self.relative_transformation_rate: float = None
404
+ self.cumulative_change_rate = 0.0
405
+ self.initial_step = True
406
+ # cache values
407
+ self.x_prev_subsampled: torch.Tensor = None
408
+ self.output_prev_subsampled: torch.Tensor = None
409
+ self.output_prev_norm: torch.Tensor = None
410
+ self.cache_diff: torch.Tensor = None
411
+ self.output_change_rates = []
412
+ self.approx_output_change_rates = []
413
+ self.total_steps_skipped = 0
414
+ self.state_metadata = None
415
+ self.output_channels = output_channels
416
+
417
+ def has_cache_diff(self) -> bool:
418
+ return self.cache_diff is not None
419
+
420
+ def is_past_end_timestep(self, timestep: float) -> bool:
421
+ return not (timestep[0] > self.end_t).item()
422
+
423
+ def should_do_easycache(self, timestep: float) -> bool:
424
+ return (timestep[0] <= self.start_t).item()
425
+
426
+ def has_x_prev_subsampled(self) -> bool:
427
+ return self.x_prev_subsampled is not None
428
+
429
+ def has_output_prev_subsampled(self) -> bool:
430
+ return self.output_prev_subsampled is not None
431
+
432
+ def has_output_prev_norm(self) -> bool:
433
+ return self.output_prev_norm is not None
434
+
435
+ def has_relative_transformation_rate(self) -> bool:
436
+ return self.relative_transformation_rate is not None
437
+
438
+ def prepare_timesteps(self, model_sampling):
439
+ self.start_t = model_sampling.percent_to_sigma(self.start_percent)
440
+ self.end_t = model_sampling.percent_to_sigma(self.end_percent)
441
+ return self
442
+
443
+ def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
444
+ if self.subsample_factor > 1:
445
+ to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
446
+ if clone:
447
+ return to_return.clone()
448
+ return to_return
449
+ if clone:
450
+ return x.clone()
451
+ return x
452
+
453
+ def apply_cache_diff(self, x: torch.Tensor):
454
+ self.total_steps_skipped += 1
455
+ return x + self.cache_diff.to(x.device)
456
+
457
+ def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
458
+ self.cache_diff = output - x
459
+
460
+ def check_metadata(self, x: torch.Tensor) -> bool:
461
+ metadata = (x.device, x.dtype, x.shape)
462
+ if self.state_metadata is None:
463
+ self.state_metadata = metadata
464
+ return True
465
+ if metadata == self.state_metadata:
466
+ return True
467
+ logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
468
+ self.reset()
469
+ return False
470
+
471
+ def reset(self):
472
+ self.relative_transformation_rate = 0.0
473
+ self.cumulative_change_rate = 0.0
474
+ self.initial_step = True
475
+ self.output_change_rates = []
476
+ self.approx_output_change_rates = []
477
+ del self.cache_diff
478
+ self.cache_diff = None
479
+ del self.x_prev_subsampled
480
+ self.x_prev_subsampled = None
481
+ del self.output_prev_subsampled
482
+ self.output_prev_subsampled = None
483
+ del self.output_prev_norm
484
+ self.output_prev_norm = None
485
+ self.total_steps_skipped = 0
486
+ self.state_metadata = None
487
+ return self
488
+
489
+ def clone(self):
490
+ return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
491
+
492
+ class LazyCacheNode(io.ComfyNode):
493
+ @classmethod
494
+ def define_schema(cls) -> io.Schema:
495
+ return io.Schema(
496
+ node_id="LazyCache",
497
+ display_name="LazyCache",
498
+ description="A homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.",
499
+ category="advanced/debug/model",
500
+ is_experimental=True,
501
+ inputs=[
502
+ io.Model.Input("model", tooltip="The model to add LazyCache to."),
503
+ io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True),
504
+ io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache.", advanced=True),
505
+ io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache.", advanced=True),
506
+ io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True),
507
+ ],
508
+ outputs=[
509
+ io.Model.Output(tooltip="The model with LazyCache."),
510
+ ],
511
+ )
512
+
513
+ @classmethod
514
+ def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
515
+ model = model.clone()
516
+ model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
517
+ model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
518
+ model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
519
+ return io.NodeOutput(model)
520
+
521
+
522
+ class EasyCacheExtension(ComfyExtension):
523
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
524
+ return [
525
+ EasyCacheNode,
526
+ LazyCacheNode,
527
+ ]
528
+
529
+ def comfy_entrypoint():
530
+ return EasyCacheExtension()
ComfyUI/comfy_extras/nodes_edit_model.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import node_helpers
2
+ from typing_extensions import override
3
+ from comfy_api.latest import ComfyExtension, io
4
+
5
+
6
+ class ReferenceLatent(io.ComfyNode):
7
+ @classmethod
8
+ def define_schema(cls):
9
+ return io.Schema(
10
+ node_id="ReferenceLatent",
11
+ category="advanced/conditioning/edit_models",
12
+ description="This node sets the guiding latent for an edit model. If the model supports it you can chain multiple to set multiple reference images.",
13
+ inputs=[
14
+ io.Conditioning.Input("conditioning"),
15
+ io.Latent.Input("latent", optional=True),
16
+ ],
17
+ outputs=[
18
+ io.Conditioning.Output(),
19
+ ]
20
+ )
21
+
22
+ @classmethod
23
+ def execute(cls, conditioning, latent=None) -> io.NodeOutput:
24
+ if latent is not None:
25
+ conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [latent["samples"]]}, append=True)
26
+ return io.NodeOutput(conditioning)
27
+
28
+
29
+ class EditModelExtension(ComfyExtension):
30
+ @override
31
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
32
+ return [
33
+ ReferenceLatent,
34
+ ]
35
+
36
+
37
+ def comfy_entrypoint() -> EditModelExtension:
38
+ return EditModelExtension()
ComfyUI/comfy_extras/nodes_eps.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing_extensions import override
3
+
4
+ from comfy.k_diffusion.sampling import sigma_to_half_log_snr
5
+ from comfy_api.latest import ComfyExtension, io
6
+
7
+
8
+ class EpsilonScaling(io.ComfyNode):
9
+ """
10
+ Implements the Epsilon Scaling method from 'Elucidating the Exposure Bias in Diffusion Models'
11
+ (https://arxiv.org/abs/2308.15321v6).
12
+
13
+ This method mitigates exposure bias by scaling the predicted noise during sampling,
14
+ which can significantly improve sample quality. This implementation uses the "uniform schedule"
15
+ recommended by the paper for its practicality and effectiveness.
16
+ """
17
+ @classmethod
18
+ def define_schema(cls):
19
+ return io.Schema(
20
+ node_id="Epsilon Scaling",
21
+ category="model_patches/unet",
22
+ inputs=[
23
+ io.Model.Input("model"),
24
+ io.Float.Input(
25
+ "scaling_factor",
26
+ default=1.005,
27
+ min=0.5,
28
+ max=1.5,
29
+ step=0.001,
30
+ display_mode=io.NumberDisplay.number,
31
+ advanced=True,
32
+ ),
33
+ ],
34
+ outputs=[
35
+ io.Model.Output(),
36
+ ],
37
+ )
38
+
39
+ @classmethod
40
+ def execute(cls, model, scaling_factor) -> io.NodeOutput:
41
+ # Prevent division by zero, though the UI's min value should prevent this.
42
+ if scaling_factor == 0:
43
+ scaling_factor = 1e-9
44
+
45
+ def epsilon_scaling_function(args):
46
+ """
47
+ This function is applied after the CFG guidance has been calculated.
48
+ It recalculates the denoised latent by scaling the predicted noise.
49
+ """
50
+ denoised = args["denoised"]
51
+ x = args["input"]
52
+
53
+ noise_pred = x - denoised
54
+
55
+ scaled_noise_pred = noise_pred / scaling_factor
56
+
57
+ new_denoised = x - scaled_noise_pred
58
+
59
+ return new_denoised
60
+
61
+ # Clone the model patcher to avoid modifying the original model in place
62
+ model_clone = model.clone()
63
+
64
+ model_clone.set_model_sampler_post_cfg_function(epsilon_scaling_function)
65
+
66
+ return io.NodeOutput(model_clone)
67
+
68
+
69
+ def compute_tsr_rescaling_factor(
70
+ snr: torch.Tensor, tsr_k: float, tsr_variance: float
71
+ ) -> torch.Tensor:
72
+ """Compute the rescaling score ratio in Temporal Score Rescaling.
73
+
74
+ See equation (6) in https://arxiv.org/pdf/2510.01184v1.
75
+ """
76
+ posinf_mask = torch.isposinf(snr)
77
+ rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
78
+ return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
79
+
80
+
81
+ class TemporalScoreRescaling(io.ComfyNode):
82
+ @classmethod
83
+ def define_schema(cls):
84
+ return io.Schema(
85
+ node_id="TemporalScoreRescaling",
86
+ display_name="TSR - Temporal Score Rescaling",
87
+ category="model_patches/unet",
88
+ inputs=[
89
+ io.Model.Input("model"),
90
+ io.Float.Input(
91
+ "tsr_k",
92
+ tooltip=(
93
+ "Controls the rescaling strength.\n"
94
+ "Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling."
95
+ ),
96
+ default=0.95,
97
+ min=0.01,
98
+ max=100.0,
99
+ step=0.001,
100
+ display_mode=io.NumberDisplay.number,
101
+ advanced=True,
102
+ ),
103
+ io.Float.Input(
104
+ "tsr_sigma",
105
+ tooltip=(
106
+ "Controls how early rescaling takes effect.\n"
107
+ "Larger values take effect earlier."
108
+ ),
109
+ default=1.0,
110
+ min=0.01,
111
+ max=100.0,
112
+ step=0.001,
113
+ display_mode=io.NumberDisplay.number,
114
+ advanced=True,
115
+ ),
116
+ ],
117
+ outputs=[
118
+ io.Model.Output(
119
+ display_name="patched_model",
120
+ ),
121
+ ],
122
+ description=(
123
+ "[Post-CFG Function]\n"
124
+ "TSR - Temporal Score Rescaling (2510.01184)\n\n"
125
+ "Rescaling the model's score or noise to steer the sampling diversity.\n"
126
+ ),
127
+ )
128
+
129
+ @classmethod
130
+ def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
131
+ tsr_variance = tsr_sigma**2
132
+
133
+ def temporal_score_rescaling(args):
134
+ denoised = args["denoised"]
135
+ x = args["input"]
136
+ sigma = args["sigma"]
137
+ curr_model = args["model"]
138
+
139
+ # No rescaling (r = 1) or no noise
140
+ if tsr_k == 1 or sigma == 0:
141
+ return denoised
142
+
143
+ model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
144
+ half_log_snr = sigma_to_half_log_snr(sigma, model_sampling)
145
+ snr = (2 * half_log_snr).exp()
146
+
147
+ # No rescaling needed (r = 1)
148
+ if snr == 0:
149
+ return denoised
150
+
151
+ rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)
152
+
153
+ # Derived from scaled_denoised = (x - r * sigma * noise) / alpha
154
+ alpha = sigma * half_log_snr.exp()
155
+ return torch.lerp(x / alpha, denoised, rescaling_r)
156
+
157
+ m = model.clone()
158
+ m.set_model_sampler_post_cfg_function(temporal_score_rescaling)
159
+ return io.NodeOutput(m)
160
+
161
+
162
+ class EpsilonScalingExtension(ComfyExtension):
163
+ @override
164
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
165
+ return [
166
+ EpsilonScaling,
167
+ TemporalScoreRescaling,
168
+ ]
169
+
170
+
171
+ async def comfy_entrypoint() -> EpsilonScalingExtension:
172
+ return EpsilonScalingExtension()
ComfyUI/comfy_extras/nodes_flux.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import node_helpers
2
+ import comfy.utils
3
+ from typing_extensions import override
4
+ from comfy_api.latest import ComfyExtension, io
5
+ import comfy.model_management
6
+ import torch
7
+ import math
8
+ import nodes
9
+ import comfy.ldm.flux.math
10
+
11
+ class CLIPTextEncodeFlux(io.ComfyNode):
12
+ @classmethod
13
+ def define_schema(cls):
14
+ return io.Schema(
15
+ node_id="CLIPTextEncodeFlux",
16
+ category="advanced/conditioning/flux",
17
+ inputs=[
18
+ io.Clip.Input("clip"),
19
+ io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
20
+ io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
21
+ io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
22
+ ],
23
+ outputs=[
24
+ io.Conditioning.Output(),
25
+ ],
26
+ )
27
+
28
+ @classmethod
29
+ def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
30
+ tokens = clip.tokenize(clip_l)
31
+ tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
32
+
33
+ return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
34
+
35
+ encode = execute # TODO: remove
36
+
37
+ class EmptyFlux2LatentImage(io.ComfyNode):
38
+ @classmethod
39
+ def define_schema(cls):
40
+ return io.Schema(
41
+ node_id="EmptyFlux2LatentImage",
42
+ display_name="Empty Flux 2 Latent",
43
+ category="latent",
44
+ inputs=[
45
+ io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
46
+ io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
47
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
48
+ ],
49
+ outputs=[
50
+ io.Latent.Output(),
51
+ ],
52
+ )
53
+
54
+ @classmethod
55
+ def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
56
+ latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device())
57
+ return io.NodeOutput({"samples": latent})
58
+
59
+ class FluxGuidance(io.ComfyNode):
60
+ @classmethod
61
+ def define_schema(cls):
62
+ return io.Schema(
63
+ node_id="FluxGuidance",
64
+ category="advanced/conditioning/flux",
65
+ inputs=[
66
+ io.Conditioning.Input("conditioning"),
67
+ io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
68
+ ],
69
+ outputs=[
70
+ io.Conditioning.Output(),
71
+ ],
72
+ )
73
+
74
+ @classmethod
75
+ def execute(cls, conditioning, guidance) -> io.NodeOutput:
76
+ c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
77
+ return io.NodeOutput(c)
78
+
79
+ append = execute # TODO: remove
80
+
81
+
82
+ class FluxDisableGuidance(io.ComfyNode):
83
+ @classmethod
84
+ def define_schema(cls):
85
+ return io.Schema(
86
+ node_id="FluxDisableGuidance",
87
+ category="advanced/conditioning/flux",
88
+ description="This node completely disables the guidance embed on Flux and Flux like models",
89
+ inputs=[
90
+ io.Conditioning.Input("conditioning"),
91
+ ],
92
+ outputs=[
93
+ io.Conditioning.Output(),
94
+ ],
95
+ )
96
+
97
+ @classmethod
98
+ def execute(cls, conditioning) -> io.NodeOutput:
99
+ c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
100
+ return io.NodeOutput(c)
101
+
102
+ append = execute # TODO: remove
103
+
104
+
105
+ PREFERED_KONTEXT_RESOLUTIONS = [
106
+ (672, 1568),
107
+ (688, 1504),
108
+ (720, 1456),
109
+ (752, 1392),
110
+ (800, 1328),
111
+ (832, 1248),
112
+ (880, 1184),
113
+ (944, 1104),
114
+ (1024, 1024),
115
+ (1104, 944),
116
+ (1184, 880),
117
+ (1248, 832),
118
+ (1328, 800),
119
+ (1392, 752),
120
+ (1456, 720),
121
+ (1504, 688),
122
+ (1568, 672),
123
+ ]
124
+
125
+
126
+ class FluxKontextImageScale(io.ComfyNode):
127
+ @classmethod
128
+ def define_schema(cls):
129
+ return io.Schema(
130
+ node_id="FluxKontextImageScale",
131
+ category="advanced/conditioning/flux",
132
+ description="This node resizes the image to one that is more optimal for flux kontext.",
133
+ inputs=[
134
+ io.Image.Input("image"),
135
+ ],
136
+ outputs=[
137
+ io.Image.Output(),
138
+ ],
139
+ )
140
+
141
+ @classmethod
142
+ def execute(cls, image) -> io.NodeOutput:
143
+ width = image.shape[2]
144
+ height = image.shape[1]
145
+ aspect_ratio = width / height
146
+ _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
147
+ image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
148
+ return io.NodeOutput(image)
149
+
150
+ scale = execute # TODO: remove
151
+
152
+
153
+ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
154
+ @classmethod
155
+ def define_schema(cls):
156
+ return io.Schema(
157
+ node_id="FluxKontextMultiReferenceLatentMethod",
158
+ display_name="Edit Model Reference Method",
159
+ category="advanced/conditioning/flux",
160
+ inputs=[
161
+ io.Conditioning.Input("conditioning"),
162
+ io.Combo.Input(
163
+ "reference_latents_method",
164
+ options=["offset", "index", "uxo/uno", "index_timestep_zero"],
165
+ advanced=True,
166
+ ),
167
+ ],
168
+ outputs=[
169
+ io.Conditioning.Output(),
170
+ ],
171
+ is_experimental=True,
172
+ )
173
+
174
+ @classmethod
175
+ def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
176
+ if "uxo" in reference_latents_method or "uso" in reference_latents_method:
177
+ reference_latents_method = "uxo"
178
+ c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
179
+ return io.NodeOutput(c)
180
+
181
+ append = execute # TODO: remove
182
+
183
+
184
+ def generalized_time_snr_shift(t, mu: float, sigma: float):
185
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
186
+
187
+
188
+ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
189
+ a1, b1 = 8.73809524e-05, 1.89833333
190
+ a2, b2 = 0.00016927, 0.45666666
191
+
192
+ if image_seq_len > 4300:
193
+ mu = a2 * image_seq_len + b2
194
+ return float(mu)
195
+
196
+ m_200 = a2 * image_seq_len + b2
197
+ m_10 = a1 * image_seq_len + b1
198
+
199
+ a = (m_200 - m_10) / 190.0
200
+ b = m_200 - 200.0 * a
201
+ mu = a * num_steps + b
202
+
203
+ return float(mu)
204
+
205
+
206
+ def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
207
+ mu = compute_empirical_mu(image_seq_len, num_steps)
208
+ timesteps = torch.linspace(1, 0, num_steps + 1)
209
+ timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
210
+ return timesteps
211
+
212
+
213
+ class Flux2Scheduler(io.ComfyNode):
214
+ @classmethod
215
+ def define_schema(cls):
216
+ return io.Schema(
217
+ node_id="Flux2Scheduler",
218
+ category="sampling/custom_sampling/schedulers",
219
+ inputs=[
220
+ io.Int.Input("steps", default=20, min=1, max=4096),
221
+ io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
222
+ io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
223
+ ],
224
+ outputs=[
225
+ io.Sigmas.Output(),
226
+ ],
227
+ )
228
+
229
+ @classmethod
230
+ def execute(cls, steps, width, height) -> io.NodeOutput:
231
+ seq_len = (width * height / (16 * 16))
232
+ sigmas = get_schedule(steps, round(seq_len))
233
+ return io.NodeOutput(sigmas)
234
+
235
+ class KV_Attn_Input:
236
+ def __init__(self):
237
+ self.cache = {}
238
+
239
+ def __call__(self, q, k, v, extra_options, **kwargs):
240
+ reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
241
+ if len(reference_image_num_tokens) == 0:
242
+ return {}
243
+
244
+ ref_toks = sum(reference_image_num_tokens)
245
+ cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
246
+ if cache_key in self.cache:
247
+ kk, vv = self.cache[cache_key]
248
+ self.set_cache = False
249
+ return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
250
+
251
+ self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
252
+ self.set_cache = True
253
+ return {"q": q, "k": k, "v": v}
254
+
255
+ def cleanup(self):
256
+ self.cache = {}
257
+
258
+
259
+ class FluxKVCache(io.ComfyNode):
260
+ @classmethod
261
+ def define_schema(cls) -> io.Schema:
262
+ return io.Schema(
263
+ node_id="FluxKVCache",
264
+ display_name="Flux KV Cache",
265
+ description="Enables KV Cache optimization for reference images on Flux family models.",
266
+ category="",
267
+ is_experimental=True,
268
+ inputs=[
269
+ io.Model.Input("model", tooltip="The model to use KV Cache on."),
270
+ ],
271
+ outputs=[
272
+ io.Model.Output(tooltip="The patched model with KV Cache enabled."),
273
+ ],
274
+ )
275
+
276
+ @classmethod
277
+ def execute(cls, model: io.Model.Type) -> io.NodeOutput:
278
+ m = model.clone()
279
+ input_patch_obj = KV_Attn_Input()
280
+
281
+ def model_input_patch(inputs):
282
+ if len(input_patch_obj.cache) > 0:
283
+ ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
284
+ if ref_image_tokens > 0:
285
+ img = inputs["img"]
286
+ inputs["img"] = img[:, :-ref_image_tokens]
287
+ return inputs
288
+
289
+ m.set_model_attn1_patch(input_patch_obj)
290
+ m.set_model_post_input_patch(model_input_patch)
291
+ if hasattr(model.model.diffusion_model, "params"):
292
+ m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
293
+ else:
294
+ m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
295
+
296
+ return io.NodeOutput(m)
297
+
298
+ class FluxExtension(ComfyExtension):
299
+ @override
300
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
301
+ return [
302
+ CLIPTextEncodeFlux,
303
+ FluxGuidance,
304
+ FluxDisableGuidance,
305
+ FluxKontextImageScale,
306
+ FluxKontextMultiReferenceLatentMethod,
307
+ EmptyFlux2LatentImage,
308
+ Flux2Scheduler,
309
+ FluxKVCache,
310
+ ]
311
+
312
+
313
+ async def comfy_entrypoint() -> FluxExtension:
314
+ return FluxExtension()
ComfyUI/comfy_extras/nodes_frame_interpolation.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ from typing_extensions import override
4
+
5
+ import comfy.model_patcher
6
+ import comfy.utils
7
+ import folder_paths
8
+ from comfy import model_management
9
+ from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config
10
+ from comfy_extras.frame_interpolation_models.film_net import FILMNet
11
+ from comfy_api.latest import ComfyExtension, io
12
+
13
+ FrameInterpolationModel = io.Custom("INTERP_MODEL")
14
+
15
+
16
+ class FrameInterpolationModelLoader(io.ComfyNode):
17
+ @classmethod
18
+ def define_schema(cls):
19
+ return io.Schema(
20
+ node_id="FrameInterpolationModelLoader",
21
+ display_name="Load Frame Interpolation Model",
22
+ category="loaders",
23
+ inputs=[
24
+ io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"),
25
+ tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."),
26
+ ],
27
+ outputs=[
28
+ FrameInterpolationModel.Output(),
29
+ ],
30
+ )
31
+
32
+ @classmethod
33
+ def execute(cls, model_name) -> io.NodeOutput:
34
+ model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name)
35
+ sd = comfy.utils.load_torch_file(model_path, safe_load=True)
36
+
37
+ model = cls._detect_and_load(sd)
38
+ dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
39
+ model.eval().to(dtype)
40
+ patcher = comfy.model_patcher.ModelPatcher(
41
+ model,
42
+ load_device=model_management.get_torch_device(),
43
+ offload_device=model_management.unet_offload_device(),
44
+ )
45
+ return io.NodeOutput(patcher)
46
+
47
+ @classmethod
48
+ def _detect_and_load(cls, sd):
49
+ # Try FILM
50
+ if "extract.extract_sublevels.convs.0.0.conv.weight" in sd:
51
+ model = FILMNet()
52
+ model.load_state_dict(sd)
53
+ return model
54
+
55
+ # Try RIFE (needs key remapping for raw checkpoints)
56
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""})
57
+ key_map = {}
58
+ for k in sd:
59
+ for i in range(5):
60
+ if k.startswith(f"block{i}."):
61
+ key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}"
62
+ if key_map:
63
+ sd = {key_map.get(k, k): v for k, v in sd.items()}
64
+ sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))}
65
+
66
+ try:
67
+ head_ch, channels = detect_rife_config(sd)
68
+ except (KeyError, ValueError):
69
+ raise ValueError("Unrecognized frame interpolation model format")
70
+ model = IFNet(head_ch=head_ch, channels=channels)
71
+ model.load_state_dict(sd)
72
+ return model
73
+
74
+
75
+ class FrameInterpolate(io.ComfyNode):
76
+ @classmethod
77
+ def define_schema(cls):
78
+ return io.Schema(
79
+ node_id="FrameInterpolate",
80
+ display_name="Frame Interpolate",
81
+ category="image/video",
82
+ search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
83
+ inputs=[
84
+ FrameInterpolationModel.Input("interp_model"),
85
+ io.Image.Input("images"),
86
+ io.Int.Input("multiplier", default=2, min=2, max=16),
87
+ ],
88
+ outputs=[
89
+ io.Image.Output(),
90
+ ],
91
+ )
92
+
93
+ @classmethod
94
+ def execute(cls, interp_model, images, multiplier) -> io.NodeOutput:
95
+ offload_device = model_management.intermediate_device()
96
+
97
+ num_frames = images.shape[0]
98
+ if num_frames < 2 or multiplier < 2:
99
+ return io.NodeOutput(images)
100
+
101
+ model_management.load_model_gpu(interp_model)
102
+ device = interp_model.load_device
103
+ dtype = interp_model.model_dtype()
104
+ inference_model = interp_model.model
105
+
106
+ # Free VRAM for inference activations (model weights + ~20x a single frame's worth)
107
+ H, W = images.shape[1], images.shape[2]
108
+ activation_mem = H * W * 3 * images.element_size() * 20
109
+ model_management.free_memory(activation_mem, device)
110
+ align = getattr(inference_model, "pad_align", 1)
111
+
112
+ # Prepare a single padded frame on device for determining output dimensions
113
+ def prepare_frame(idx):
114
+ frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device)
115
+ if align > 1:
116
+ from comfy.ldm.common_dit import pad_to_patch_size
117
+ frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect")
118
+ return frame
119
+
120
+ # Count total interpolation passes for progress bar
121
+ total_pairs = num_frames - 1
122
+ num_interp = multiplier - 1
123
+ total_steps = total_pairs * num_interp
124
+ pbar = comfy.utils.ProgressBar(total_steps)
125
+ tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation")
126
+
127
+ batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit)
128
+ t_values = [t / multiplier for t in range(1, multiplier)]
129
+
130
+ out_dtype = model_management.intermediate_dtype()
131
+ total_out_frames = total_pairs * multiplier + 1
132
+ result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device)
133
+ result[0] = images[0].movedim(-1, 0).to(out_dtype)
134
+ out_idx = 1
135
+
136
+ # Pre-compute timestep tensor on device (padded dimensions needed)
137
+ sample = prepare_frame(0)
138
+ pH, pW = sample.shape[2], sample.shape[3]
139
+ ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1)
140
+ ts_full = ts_full.expand(-1, 1, pH, pW)
141
+ del sample
142
+
143
+ multi_fn = getattr(inference_model, "forward_multi_timestep", None)
144
+ feat_cache = {}
145
+ prev_frame = None
146
+
147
+ try:
148
+ for i in range(total_pairs):
149
+ img0_single = prev_frame if prev_frame is not None else prepare_frame(i)
150
+ img1_single = prepare_frame(i + 1)
151
+ prev_frame = img1_single
152
+
153
+ # Cache features: img1 of pair N becomes img0 of pair N+1
154
+ feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single)
155
+ feat_cache["img1"] = inference_model.extract_features(img1_single)
156
+ feat_cache["next"] = feat_cache["img1"]
157
+
158
+ used_multi = False
159
+ if multi_fn is not None:
160
+ # Models with timestep-independent flow can compute it once for all timesteps
161
+ try:
162
+ mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache)
163
+ result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype)
164
+ out_idx += num_interp
165
+ pbar.update(num_interp)
166
+ tqdm_bar.update(num_interp)
167
+ used_multi = True
168
+ except model_management.OOM_EXCEPTION:
169
+ model_management.soft_empty_cache()
170
+ multi_fn = None # fall through to single-timestep path
171
+
172
+ if not used_multi:
173
+ j = 0
174
+ while j < num_interp:
175
+ b = min(batch, num_interp - j)
176
+ try:
177
+ img0 = img0_single.expand(b, -1, -1, -1)
178
+ img1 = img1_single.expand(b, -1, -1, -1)
179
+ mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache)
180
+ result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype)
181
+ out_idx += b
182
+ pbar.update(b)
183
+ tqdm_bar.update(b)
184
+ j += b
185
+ except model_management.OOM_EXCEPTION:
186
+ if batch <= 1:
187
+ raise
188
+ batch = max(1, batch // 2)
189
+ model_management.soft_empty_cache()
190
+
191
+ result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype)
192
+ out_idx += 1
193
+ finally:
194
+ tqdm_bar.close()
195
+
196
+ # BCHW -> BHWC
197
+ result = result.movedim(1, -1).clamp_(0.0, 1.0)
198
+ return io.NodeOutput(result)
199
+
200
+
201
+ class FrameInterpolationExtension(ComfyExtension):
202
+ @override
203
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
204
+ return [
205
+ FrameInterpolationModelLoader,
206
+ FrameInterpolate,
207
+ ]
208
+
209
+
210
+ async def comfy_entrypoint() -> FrameInterpolationExtension:
211
+ return FrameInterpolationExtension()
ComfyUI/comfy_extras/nodes_freelunch.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
2
+
3
+ import torch
4
+ import logging
5
+ from typing_extensions import override
6
+ from comfy_api.latest import ComfyExtension, IO
7
+
8
+ def Fourier_filter(x, threshold, scale):
9
+ # FFT
10
+ x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
11
+ x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
12
+
13
+ B, C, H, W = x_freq.shape
14
+ mask = torch.ones((B, C, H, W), device=x.device)
15
+
16
+ crow, ccol = H // 2, W //2
17
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
18
+ x_freq = x_freq * mask
19
+
20
+ # IFFT
21
+ x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
22
+ x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
23
+
24
+ return x_filtered.to(x.dtype)
25
+
26
+
27
+ class FreeU(IO.ComfyNode):
28
+ @classmethod
29
+ def define_schema(cls):
30
+ return IO.Schema(
31
+ node_id="FreeU",
32
+ category="model_patches/unet",
33
+ inputs=[
34
+ IO.Model.Input("model"),
35
+ IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01, advanced=True),
36
+ IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01, advanced=True),
37
+ IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True),
38
+ IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True),
39
+ ],
40
+ outputs=[
41
+ IO.Model.Output(),
42
+ ],
43
+ )
44
+
45
+ @classmethod
46
+ def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
47
+ model_channels = model.model.model_config.unet_config["model_channels"]
48
+ scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
49
+ on_cpu_devices = {}
50
+
51
+ def output_block_patch(h, hsp, transformer_options):
52
+ scale = scale_dict.get(int(h.shape[1]), None)
53
+ if scale is not None:
54
+ h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
55
+ if hsp.device not in on_cpu_devices:
56
+ try:
57
+ hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
58
+ except:
59
+ logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
60
+ on_cpu_devices[hsp.device] = True
61
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
62
+ else:
63
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
64
+
65
+ return h, hsp
66
+
67
+ m = model.clone()
68
+ m.set_model_output_block_patch(output_block_patch)
69
+ return IO.NodeOutput(m)
70
+
71
+ patch = execute # TODO: remove
72
+
73
+
74
+ class FreeU_V2(IO.ComfyNode):
75
+ @classmethod
76
+ def define_schema(cls):
77
+ return IO.Schema(
78
+ node_id="FreeU_V2",
79
+ category="model_patches/unet",
80
+ inputs=[
81
+ IO.Model.Input("model"),
82
+ IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01, advanced=True),
83
+ IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01, advanced=True),
84
+ IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True),
85
+ IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True),
86
+ ],
87
+ outputs=[
88
+ IO.Model.Output(),
89
+ ],
90
+ )
91
+
92
+ @classmethod
93
+ def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
94
+ model_channels = model.model.model_config.unet_config["model_channels"]
95
+ scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
96
+ on_cpu_devices = {}
97
+
98
+ def output_block_patch(h, hsp, transformer_options):
99
+ scale = scale_dict.get(int(h.shape[1]), None)
100
+ if scale is not None:
101
+ hidden_mean = h.mean(1).unsqueeze(1)
102
+ B = hidden_mean.shape[0]
103
+ hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
104
+ hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
105
+ hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
106
+
107
+ h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
108
+
109
+ if hsp.device not in on_cpu_devices:
110
+ try:
111
+ hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
112
+ except:
113
+ logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
114
+ on_cpu_devices[hsp.device] = True
115
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
116
+ else:
117
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
118
+
119
+ return h, hsp
120
+
121
+ m = model.clone()
122
+ m.set_model_output_block_patch(output_block_patch)
123
+ return IO.NodeOutput(m)
124
+
125
+ patch = execute # TODO: remove
126
+
127
+
128
+ class FreelunchExtension(ComfyExtension):
129
+ @override
130
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
131
+ return [
132
+ FreeU,
133
+ FreeU_V2,
134
+ ]
135
+
136
+
137
+ async def comfy_entrypoint() -> FreelunchExtension:
138
+ return FreelunchExtension()
ComfyUI/comfy_extras/nodes_fresca.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code based on https://github.com/WikiChao/FreSca (MIT License)
2
+ import torch
3
+ import torch.fft as fft
4
+ from typing_extensions import override
5
+ from comfy_api.latest import ComfyExtension, io
6
+
7
+
8
+ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
9
+ """
10
+ Apply frequency-dependent scaling to an image tensor using Fourier transforms.
11
+
12
+ Parameters:
13
+ x: Input tensor of shape (B, C, H, W)
14
+ scale_low: Scaling factor for low-frequency components (default: 1.0)
15
+ scale_high: Scaling factor for high-frequency components (default: 1.5)
16
+ freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
17
+
18
+ Returns:
19
+ x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
20
+ """
21
+ # Preserve input dtype and device
22
+ dtype, device = x.dtype, x.device
23
+
24
+ # Convert to float32 for FFT computations
25
+ x = x.to(torch.float32)
26
+
27
+ # 1) Apply FFT and shift low frequencies to center
28
+ x_freq = fft.fftn(x, dim=(-2, -1))
29
+ x_freq = fft.fftshift(x_freq, dim=(-2, -1))
30
+
31
+ # Initialize mask with high-frequency scaling factor
32
+ mask = torch.ones(x_freq.shape, device=device) * scale_high
33
+ m = mask
34
+ for d in range(len(x_freq.shape) - 2):
35
+ dim = d + 2
36
+ cc = x_freq.shape[dim] // 2
37
+ f_c = min(freq_cutoff, cc)
38
+ m = m.narrow(dim, cc - f_c, f_c * 2)
39
+
40
+ # Apply low-frequency scaling factor to center region
41
+ m[:] = scale_low
42
+
43
+ # 3) Apply frequency-specific scaling
44
+ x_freq = x_freq * mask
45
+
46
+ # 4) Convert back to spatial domain
47
+ x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
48
+ x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
49
+
50
+ # 5) Restore original dtype
51
+ x_filtered = x_filtered.to(dtype)
52
+
53
+ return x_filtered
54
+
55
+
56
+ class FreSca(io.ComfyNode):
57
+ @classmethod
58
+ def define_schema(cls):
59
+ return io.Schema(
60
+ node_id="FreSca",
61
+ search_aliases=["frequency guidance"],
62
+ display_name="FreSca",
63
+ category="_for_testing",
64
+ description="Applies frequency-dependent scaling to the guidance",
65
+ inputs=[
66
+ io.Model.Input("model"),
67
+ io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01,
68
+ tooltip="Scaling factor for low-frequency components", advanced=True),
69
+ io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01,
70
+ tooltip="Scaling factor for high-frequency components", advanced=True),
71
+ io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1,
72
+ tooltip="Number of frequency indices around center to consider as low-frequency", advanced=True),
73
+ ],
74
+ outputs=[
75
+ io.Model.Output(),
76
+ ],
77
+ is_experimental=True,
78
+ )
79
+
80
+ @classmethod
81
+ def execute(cls, model, scale_low, scale_high, freq_cutoff):
82
+ def custom_cfg_function(args):
83
+ conds_out = args["conds_out"]
84
+ if len(conds_out) <= 1 or None in args["conds"][:2]:
85
+ return conds_out
86
+ cond = conds_out[0]
87
+ uncond = conds_out[1]
88
+
89
+ guidance = cond - uncond
90
+ filtered_guidance = Fourier_filter(
91
+ guidance,
92
+ scale_low=scale_low,
93
+ scale_high=scale_high,
94
+ freq_cutoff=freq_cutoff,
95
+ )
96
+ filtered_cond = filtered_guidance + uncond
97
+
98
+ return [filtered_cond, uncond] + conds_out[2:]
99
+
100
+ m = model.clone()
101
+ m.set_model_sampler_pre_cfg_function(custom_cfg_function)
102
+
103
+ return io.NodeOutput(m)
104
+
105
+
106
+ class FreScaExtension(ComfyExtension):
107
+ @override
108
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
109
+ return [
110
+ FreSca,
111
+ ]
112
+
113
+
114
+ async def comfy_entrypoint() -> FreScaExtension:
115
+ return FreScaExtension()
ComfyUI/comfy_extras/nodes_gits.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
2
+ import numpy as np
3
+ import torch
4
+ from typing_extensions import override
5
+ from comfy_api.latest import ComfyExtension, io
6
+
7
+ def loglinear_interp(t_steps, num_steps):
8
+ """
9
+ Performs log-linear interpolation of a given array of decreasing numbers.
10
+ """
11
+ xs = np.linspace(0, 1, len(t_steps))
12
+ ys = np.log(t_steps[::-1])
13
+
14
+ new_xs = np.linspace(0, 1, num_steps)
15
+ new_ys = np.interp(new_xs, xs, ys)
16
+
17
+ interped_ys = np.exp(new_ys)[::-1].copy()
18
+ return interped_ys
19
+
20
+ NOISE_LEVELS = {
21
+ 0.80: [
22
+ [14.61464119, 7.49001646, 0.02916753],
23
+ [14.61464119, 11.54541874, 6.77309084, 0.02916753],
24
+ [14.61464119, 11.54541874, 7.49001646, 3.07277966, 0.02916753],
25
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 2.05039096, 0.02916753],
26
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 2.05039096, 0.02916753],
27
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
28
+ [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
29
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
30
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
31
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753],
32
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
33
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
34
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
35
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753],
36
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753],
37
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753],
38
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753],
39
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753],
40
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.75677586, 2.84484982, 1.78698075, 0.803307, 0.02916753],
41
+ ],
42
+ 0.85: [
43
+ [14.61464119, 7.49001646, 0.02916753],
44
+ [14.61464119, 7.49001646, 1.84880662, 0.02916753],
45
+ [14.61464119, 11.54541874, 6.77309084, 1.56271636, 0.02916753],
46
+ [14.61464119, 11.54541874, 7.11996698, 3.07277966, 1.24153244, 0.02916753],
47
+ [14.61464119, 11.54541874, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753],
48
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753],
49
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753],
50
+ [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753],
51
+ [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
52
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
53
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
54
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
55
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753],
56
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.60512662, 2.6383388, 1.56271636, 0.72133851, 0.02916753],
57
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
58
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
59
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
60
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
61
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753],
62
+ ],
63
+ 0.90: [
64
+ [14.61464119, 6.77309084, 0.02916753],
65
+ [14.61464119, 7.49001646, 1.56271636, 0.02916753],
66
+ [14.61464119, 7.49001646, 3.07277966, 0.95350921, 0.02916753],
67
+ [14.61464119, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753],
68
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753],
69
+ [14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.07277966, 1.61558151, 0.69515091, 0.02916753],
70
+ [14.61464119, 12.2308979, 8.75849152, 7.11996698, 4.86714602, 3.07277966, 1.61558151, 0.69515091, 0.02916753],
71
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 2.95596409, 1.61558151, 0.69515091, 0.02916753],
72
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
73
+ [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
74
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753],
75
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753],
76
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753],
77
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753],
78
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753],
79
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
80
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
81
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753],
82
+ [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.19988537, 1.51179266, 0.89115214, 0.43325692, 0.02916753],
83
+ ],
84
+ 0.95: [
85
+ [14.61464119, 6.77309084, 0.02916753],
86
+ [14.61464119, 6.77309084, 1.56271636, 0.02916753],
87
+ [14.61464119, 7.49001646, 2.84484982, 0.89115214, 0.02916753],
88
+ [14.61464119, 7.49001646, 4.86714602, 2.36326075, 0.803307, 0.02916753],
89
+ [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753],
90
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753],
91
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
92
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
93
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753],
94
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.41535246, 0.803307, 0.38853383, 0.02916753],
95
+ [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
96
+ [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
97
+ [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
98
+ [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
99
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753],
100
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.78698075, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
101
+ [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
102
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753],
103
+ [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.75677586, 3.07277966, 2.45070267, 1.91321158, 1.46270394, 1.05362725, 0.72133851, 0.43325692, 0.19894916, 0.02916753],
104
+ ],
105
+ 1.00: [
106
+ [14.61464119, 1.56271636, 0.02916753],
107
+ [14.61464119, 6.77309084, 0.95350921, 0.02916753],
108
+ [14.61464119, 6.77309084, 2.36326075, 0.803307, 0.02916753],
109
+ [14.61464119, 7.11996698, 3.07277966, 1.56271636, 0.59516323, 0.02916753],
110
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.41535246, 0.57119018, 0.02916753],
111
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753],
112
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753],
113
+ [14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
114
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753],
115
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.54755926, 0.25053367, 0.02916753],
116
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
117
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
118
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.12350607, 1.56271636, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753],
119
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.803307, 0.50118381, 0.27464288, 0.09824532, 0.02916753],
120
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.84880662, 1.36964464, 1.01931262, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
121
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.46139455, 2.84484982, 2.19988537, 1.67050016, 1.24153244, 0.92192322, 0.64427125, 0.43325692, 0.25053367, 0.09824532, 0.02916753],
122
+ [14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
123
+ [14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
124
+ [14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
125
+ ],
126
+ 1.05: [
127
+ [14.61464119, 0.95350921, 0.02916753],
128
+ [14.61464119, 6.77309084, 0.89115214, 0.02916753],
129
+ [14.61464119, 6.77309084, 2.05039096, 0.72133851, 0.02916753],
130
+ [14.61464119, 6.77309084, 2.84484982, 1.28281462, 0.52423614, 0.02916753],
131
+ [14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.803307, 0.34370604, 0.02916753],
132
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.56271636, 0.803307, 0.34370604, 0.02916753],
133
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.52423614, 0.22545385, 0.02916753],
134
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.74807048, 0.41087446, 0.17026083, 0.02916753],
135
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.59516323, 0.34370604, 0.13792117, 0.02916753],
136
+ [14.61464119, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
137
+ [14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
138
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
139
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.72759056, 1.24153244, 0.86115354, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
140
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753],
141
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.67050016, 1.28281462, 0.95350921, 0.72133851, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
142
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.36326075, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.61951244, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
143
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
144
+ [14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
145
+ [14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.98035145, 1.61558151, 1.32549286, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
146
+ ],
147
+ 1.10: [
148
+ [14.61464119, 0.89115214, 0.02916753],
149
+ [14.61464119, 2.36326075, 0.72133851, 0.02916753],
150
+ [14.61464119, 5.85520077, 1.61558151, 0.57119018, 0.02916753],
151
+ [14.61464119, 6.77309084, 2.45070267, 1.08895338, 0.45573691, 0.02916753],
152
+ [14.61464119, 6.77309084, 2.95596409, 1.56271636, 0.803307, 0.34370604, 0.02916753],
153
+ [14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.89115214, 0.4783645, 0.19894916, 0.02916753],
154
+ [14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.08895338, 0.64427125, 0.34370604, 0.13792117, 0.02916753],
155
+ [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.54755926, 0.27464288, 0.09824532, 0.02916753],
156
+ [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.4783645, 0.25053367, 0.09824532, 0.02916753],
157
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.41535246, 0.95350921, 0.64427125, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
158
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.61558151, 1.12534678, 0.803307, 0.54755926, 0.36617002, 0.22545385, 0.09824532, 0.02916753],
159
+ [14.61464119, 7.49001646, 4.86714602, 3.32507086, 2.45070267, 1.72759056, 1.24153244, 0.89115214, 0.64427125, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
160
+ [14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.05039096, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
161
+ [14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.12350607, 1.61558151, 1.24153244, 0.95350921, 0.72133851, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
162
+ [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
163
+ [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
164
+ [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
165
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
166
+ [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
167
+ ],
168
+ 1.15: [
169
+ [14.61464119, 0.83188516, 0.02916753],
170
+ [14.61464119, 1.84880662, 0.59516323, 0.02916753],
171
+ [14.61464119, 5.85520077, 1.56271636, 0.52423614, 0.02916753],
172
+ [14.61464119, 5.85520077, 1.91321158, 0.83188516, 0.34370604, 0.02916753],
173
+ [14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.59516323, 0.25053367, 0.02916753],
174
+ [14.61464119, 5.85520077, 2.84484982, 1.51179266, 0.803307, 0.41087446, 0.17026083, 0.02916753],
175
+ [14.61464119, 5.85520077, 2.84484982, 1.56271636, 0.89115214, 0.50118381, 0.25053367, 0.09824532, 0.02916753],
176
+ [14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.12534678, 0.72133851, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
177
+ [14.61464119, 6.77309084, 3.07277966, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
178
+ [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
179
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.36964464, 0.95350921, 0.69515091, 0.4783645, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
180
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753],
181
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
182
+ [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
183
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
184
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
185
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
186
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
187
+ [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
188
+ ],
189
+ 1.20: [
190
+ [14.61464119, 0.803307, 0.02916753],
191
+ [14.61464119, 1.56271636, 0.52423614, 0.02916753],
192
+ [14.61464119, 2.36326075, 0.92192322, 0.36617002, 0.02916753],
193
+ [14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.25053367, 0.02916753],
194
+ [14.61464119, 5.85520077, 2.05039096, 0.95350921, 0.45573691, 0.17026083, 0.02916753],
195
+ [14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.64427125, 0.29807833, 0.09824532, 0.02916753],
196
+ [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.803307, 0.45573691, 0.25053367, 0.09824532, 0.02916753],
197
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.95350921, 0.59516323, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
198
+ [14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
199
+ [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.83188516, 0.59516323, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
200
+ [14.61464119, 5.85520077, 3.07277966, 1.98035145, 1.36964464, 0.95350921, 0.69515091, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
201
+ [14.61464119, 6.77309084, 3.46139455, 2.36326075, 1.56271636, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
202
+ [14.61464119, 6.77309084, 3.46139455, 2.45070267, 1.61558151, 1.162866, 0.86115354, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
203
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
204
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
205
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
206
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.20157266, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
207
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
208
+ [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
209
+ ],
210
+ 1.25: [
211
+ [14.61464119, 0.72133851, 0.02916753],
212
+ [14.61464119, 1.56271636, 0.50118381, 0.02916753],
213
+ [14.61464119, 2.05039096, 0.803307, 0.32104823, 0.02916753],
214
+ [14.61464119, 2.36326075, 0.95350921, 0.43325692, 0.17026083, 0.02916753],
215
+ [14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.27464288, 0.09824532, 0.02916753],
216
+ [14.61464119, 3.07277966, 1.51179266, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
217
+ [14.61464119, 5.85520077, 2.36326075, 1.24153244, 0.72133851, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
218
+ [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
219
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.98595673, 0.64427125, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
220
+ [14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.52423614, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
221
+ [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
222
+ [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.86115354, 0.64427125, 0.4783645, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
223
+ [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.28281462, 0.92192322, 0.69515091, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
224
+ [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
225
+ [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
226
+ [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
227
+ [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
228
+ [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
229
+ [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.46270394, 1.08895338, 0.83188516, 0.66947293, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
230
+ ],
231
+ 1.30: [
232
+ [14.61464119, 0.72133851, 0.02916753],
233
+ [14.61464119, 1.24153244, 0.43325692, 0.02916753],
234
+ [14.61464119, 1.56271636, 0.59516323, 0.22545385, 0.02916753],
235
+ [14.61464119, 1.84880662, 0.803307, 0.36617002, 0.13792117, 0.02916753],
236
+ [14.61464119, 2.36326075, 1.01931262, 0.52423614, 0.25053367, 0.09824532, 0.02916753],
237
+ [14.61464119, 2.84484982, 1.36964464, 0.74807048, 0.41087446, 0.22545385, 0.09824532, 0.02916753],
238
+ [14.61464119, 3.07277966, 1.56271636, 0.89115214, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
239
+ [14.61464119, 3.07277966, 1.61558151, 0.95350921, 0.61951244, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
240
+ [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
241
+ [14.61464119, 5.85520077, 2.45070267, 1.41535246, 0.92192322, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
242
+ [14.61464119, 5.85520077, 2.6383388, 1.56271636, 1.01931262, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
243
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
244
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.77538133, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
245
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
246
+ [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
247
+ [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
248
+ [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
249
+ [14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
250
+ [14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
251
+ ],
252
+ 1.35: [
253
+ [14.61464119, 0.69515091, 0.02916753],
254
+ [14.61464119, 0.95350921, 0.34370604, 0.02916753],
255
+ [14.61464119, 1.56271636, 0.57119018, 0.19894916, 0.02916753],
256
+ [14.61464119, 1.61558151, 0.69515091, 0.29807833, 0.09824532, 0.02916753],
257
+ [14.61464119, 1.84880662, 0.83188516, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
258
+ [14.61464119, 2.45070267, 1.162866, 0.64427125, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
259
+ [14.61464119, 2.84484982, 1.36964464, 0.803307, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753],
260
+ [14.61464119, 2.84484982, 1.41535246, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
261
+ [14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.32104823, 0.22545385, 0.17026083, 0.09824532, 0.02916753],
262
+ [14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
263
+ [14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
264
+ [14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
265
+ [14.61464119, 3.07277966, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
266
+ [14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
267
+ [14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
268
+ [14.61464119, 5.85520077, 2.45070267, 1.51179266, 1.01931262, 0.74807048, 0.57119018, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
269
+ [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
270
+ [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
271
+ [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
272
+ ],
273
+ 1.40: [
274
+ [14.61464119, 0.59516323, 0.02916753],
275
+ [14.61464119, 0.95350921, 0.34370604, 0.02916753],
276
+ [14.61464119, 1.08895338, 0.43325692, 0.13792117, 0.02916753],
277
+ [14.61464119, 1.56271636, 0.64427125, 0.27464288, 0.09824532, 0.02916753],
278
+ [14.61464119, 1.61558151, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753],
279
+ [14.61464119, 2.05039096, 0.95350921, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753],
280
+ [14.61464119, 2.45070267, 1.24153244, 0.72133851, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
281
+ [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
282
+ [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.52423614, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
283
+ [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
284
+ [14.61464119, 2.84484982, 1.41535246, 0.86115354, 0.59516323, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
285
+ [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
286
+ [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.4783645, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
287
+ [14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
288
+ [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.72133851, 0.54755926, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
289
+ [14.61464119, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
290
+ [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
291
+ [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.43325692, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
292
+ [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
293
+ ],
294
+ 1.45: [
295
+ [14.61464119, 0.59516323, 0.02916753],
296
+ [14.61464119, 0.803307, 0.25053367, 0.02916753],
297
+ [14.61464119, 0.95350921, 0.34370604, 0.09824532, 0.02916753],
298
+ [14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753],
299
+ [14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
300
+ [14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
301
+ [14.61464119, 1.91321158, 0.95350921, 0.57119018, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
302
+ [14.61464119, 2.19988537, 1.08895338, 0.64427125, 0.41087446, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
303
+ [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
304
+ [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
305
+ [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
306
+ [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
307
+ [14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
308
+ [14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
309
+ [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
310
+ [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
311
+ [14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.72133851, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
312
+ [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
313
+ [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.59516323, 0.50118381, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
314
+ ],
315
+ 1.50: [
316
+ [14.61464119, 0.54755926, 0.02916753],
317
+ [14.61464119, 0.803307, 0.25053367, 0.02916753],
318
+ [14.61464119, 0.86115354, 0.32104823, 0.09824532, 0.02916753],
319
+ [14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753],
320
+ [14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753],
321
+ [14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753],
322
+ [14.61464119, 1.61558151, 0.83188516, 0.52423614, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753],
323
+ [14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.38853383, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753],
324
+ [14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.41087446, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
325
+ [14.61464119, 1.84880662, 0.95350921, 0.61951244, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
326
+ [14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
327
+ [14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
328
+ [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
329
+ [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
330
+ [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
331
+ [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
332
+ [14.61464119, 2.45070267, 1.32549286, 0.86115354, 0.64427125, 0.50118381, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
333
+ [14.61464119, 2.45070267, 1.36964464, 0.92192322, 0.69515091, 0.54755926, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
334
+ [14.61464119, 2.45070267, 1.41535246, 0.95350921, 0.72133851, 0.57119018, 0.4783645, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753],
335
+ ],
336
+ }
337
+
338
+ class GITSScheduler(io.ComfyNode):
339
+ @classmethod
340
+ def define_schema(cls):
341
+ return io.Schema(
342
+ node_id="GITSScheduler",
343
+ category="sampling/custom_sampling/schedulers",
344
+ inputs=[
345
+ io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05, advanced=True),
346
+ io.Int.Input("steps", default=10, min=2, max=1000),
347
+ io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
348
+ ],
349
+ outputs=[
350
+ io.Sigmas.Output(),
351
+ ],
352
+ )
353
+
354
+ @classmethod
355
+ def execute(cls, coeff, steps, denoise):
356
+ total_steps = steps
357
+ if denoise < 1.0:
358
+ if denoise <= 0.0:
359
+ return io.NodeOutput(torch.FloatTensor([]))
360
+ total_steps = round(steps * denoise)
361
+
362
+ if steps <= 20:
363
+ sigmas = NOISE_LEVELS[round(coeff, 2)][steps-2][:]
364
+ else:
365
+ sigmas = NOISE_LEVELS[round(coeff, 2)][-1][:]
366
+ sigmas = loglinear_interp(sigmas, steps + 1)
367
+
368
+ sigmas = sigmas[-(total_steps + 1):]
369
+ sigmas[-1] = 0
370
+ return io.NodeOutput(torch.FloatTensor(sigmas))
371
+
372
+
373
+ class GITSSchedulerExtension(ComfyExtension):
374
+ @override
375
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
376
+ return [
377
+ GITSScheduler,
378
+ ]
379
+
380
+
381
+ async def comfy_entrypoint() -> GITSSchedulerExtension:
382
+ return GITSSchedulerExtension()
ComfyUI/comfy_extras/nodes_glsl.py ADDED
@@ -0,0 +1,958 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import re
4
+ import logging
5
+ import ctypes.util
6
+ import importlib.util
7
+ from typing import TypedDict
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ import nodes
13
+ from comfy_api.latest import ComfyExtension, io, ui
14
+ from typing_extensions import override
15
+ from utils.install_util import get_missing_requirements_message
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def _check_opengl_availability():
21
+ """Early check for OpenGL availability. Raises RuntimeError if unlikely to work."""
22
+ logger.debug("_check_opengl_availability: starting")
23
+ missing = []
24
+
25
+ # Check Python packages (using find_spec to avoid importing)
26
+ logger.debug("_check_opengl_availability: checking for glfw package")
27
+ if importlib.util.find_spec("glfw") is None:
28
+ missing.append("glfw")
29
+
30
+ logger.debug("_check_opengl_availability: checking for OpenGL package")
31
+ if importlib.util.find_spec("OpenGL") is None:
32
+ missing.append("PyOpenGL")
33
+
34
+ if missing:
35
+ raise RuntimeError(
36
+ f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n"
37
+ )
38
+
39
+ # On Linux without display, check if headless backends are available
40
+ logger.debug(f"_check_opengl_availability: platform={sys.platform}")
41
+ if sys.platform.startswith("linux"):
42
+ has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
43
+ logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}")
44
+ if not has_display:
45
+ # Check for EGL or OSMesa libraries
46
+ logger.debug("_check_opengl_availability: checking for EGL library")
47
+ has_egl = ctypes.util.find_library("EGL")
48
+ logger.debug("_check_opengl_availability: checking for OSMesa library")
49
+ has_osmesa = ctypes.util.find_library("OSMesa")
50
+
51
+ # Error disabled for CI as it fails this check
52
+ # if not has_egl and not has_osmesa:
53
+ # raise RuntimeError(
54
+ # "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n"
55
+ # "See error below for installation instructions."
56
+ # )
57
+ logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}")
58
+
59
+ logger.debug("_check_opengl_availability: completed")
60
+
61
+
62
+ # Run early check at import time
63
+ logger.debug("nodes_glsl: running _check_opengl_availability at import time")
64
+ _check_opengl_availability()
65
+
66
+ # OpenGL modules - initialized lazily when context is created
67
+ gl = None
68
+ glfw = None
69
+ EGL = None
70
+
71
+
72
+ def _import_opengl():
73
+ """Import OpenGL module. Called after context is created."""
74
+ global gl
75
+ if gl is None:
76
+ logger.debug("_import_opengl: importing OpenGL.GL")
77
+ import OpenGL.GL as _gl
78
+ gl = _gl
79
+ logger.debug("_import_opengl: import completed")
80
+ return gl
81
+
82
+
83
+ class SizeModeInput(TypedDict):
84
+ size_mode: str
85
+ width: int
86
+ height: int
87
+
88
+
89
+ MAX_IMAGES = 5 # u_image0-4
90
+ MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
91
+ MAX_BOOLS = 10 # u_bool0-9
92
+ MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
93
+ MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
94
+
95
+ # Vertex shader using gl_VertexID trick - no VBO needed.
96
+ # Draws a single triangle that covers the entire screen:
97
+ #
98
+ # (-1,3)
99
+ # /|
100
+ # / | <- visible area is the unit square from (-1,-1) to (1,1)
101
+ # / | parts outside get clipped away
102
+ # (-1,-1)---(3,-1)
103
+ #
104
+ # v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1)
105
+ VERTEX_SHADER = """#version 330 core
106
+ out vec2 v_texCoord;
107
+ void main() {
108
+ vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3));
109
+ v_texCoord = verts[gl_VertexID] * 0.5 + 0.5;
110
+ gl_Position = vec4(verts[gl_VertexID], 0, 1);
111
+ }
112
+ """
113
+
114
+ DEFAULT_FRAGMENT_SHADER = """#version 300 es
115
+ precision highp float;
116
+
117
+ uniform sampler2D u_image0;
118
+ uniform vec2 u_resolution;
119
+
120
+ in vec2 v_texCoord;
121
+ layout(location = 0) out vec4 fragColor0;
122
+
123
+ void main() {
124
+ fragColor0 = texture(u_image0, v_texCoord);
125
+ }
126
+ """
127
+
128
+
129
+ def _convert_es_to_desktop(source: str) -> str:
130
+ """Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core."""
131
+ # Remove any existing #version directive
132
+ source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE)
133
+ # Remove precision qualifiers (not needed in desktop GLSL)
134
+ source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source)
135
+ # Prepend desktop GLSL version
136
+ return "#version 330 core\n" + source
137
+
138
+
139
+ def _detect_output_count(source: str) -> int:
140
+ """Detect how many fragColor outputs are used in the shader.
141
+
142
+ Returns the count of outputs needed (1 to MAX_OUTPUTS).
143
+ """
144
+ matches = re.findall(r"fragColor(\d+)", source)
145
+ if not matches:
146
+ return 1 # Default to 1 output if none found
147
+ max_index = max(int(m) for m in matches)
148
+ return min(max_index + 1, MAX_OUTPUTS)
149
+
150
+
151
+ def _detect_pass_count(source: str) -> int:
152
+ """Detect multi-pass rendering from #pragma passes N directive.
153
+
154
+ Returns the number of passes (1 if not specified).
155
+ """
156
+ match = re.search(r'#pragma\s+passes\s+(\d+)', source)
157
+ if match:
158
+ return max(1, int(match.group(1)))
159
+ return 1
160
+
161
+
162
+ def _init_glfw():
163
+ """Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure."""
164
+ logger.debug("_init_glfw: starting")
165
+ # On macOS, glfw.init() must be called from main thread or it hangs forever
166
+ if sys.platform == "darwin":
167
+ logger.debug("_init_glfw: skipping on macOS")
168
+ raise RuntimeError("GLFW backend not supported on macOS")
169
+
170
+ logger.debug("_init_glfw: importing glfw module")
171
+ import glfw as _glfw
172
+
173
+ logger.debug("_init_glfw: calling glfw.init()")
174
+ if not _glfw.init():
175
+ raise RuntimeError("glfw.init() failed")
176
+
177
+ try:
178
+ logger.debug("_init_glfw: setting window hints")
179
+ _glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE)
180
+ _glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3)
181
+ _glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3)
182
+ _glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE)
183
+
184
+ logger.debug("_init_glfw: calling create_window()")
185
+ window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None)
186
+ if not window:
187
+ raise RuntimeError("glfw.create_window() failed")
188
+
189
+ logger.debug("_init_glfw: calling make_context_current()")
190
+ _glfw.make_context_current(window)
191
+ logger.debug("_init_glfw: completed successfully")
192
+ return window, _glfw
193
+ except Exception:
194
+ logger.debug("_init_glfw: failed, terminating glfw")
195
+ _glfw.terminate()
196
+ raise
197
+
198
+
199
+ def _init_egl():
200
+ """Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure."""
201
+ logger.debug("_init_egl: starting")
202
+ from OpenGL import EGL as _EGL
203
+ from OpenGL.EGL import (
204
+ eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext,
205
+ eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI,
206
+ eglTerminate, eglDestroyContext, eglDestroySurface,
207
+ EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE,
208
+ EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
209
+ EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE,
210
+ EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API,
211
+ )
212
+ logger.debug("_init_egl: imports completed")
213
+
214
+ display = None
215
+ context = None
216
+ surface = None
217
+
218
+ try:
219
+ logger.debug("_init_egl: calling eglGetDisplay()")
220
+ display = eglGetDisplay(EGL_DEFAULT_DISPLAY)
221
+ if display == _EGL.EGL_NO_DISPLAY:
222
+ raise RuntimeError("eglGetDisplay() failed")
223
+
224
+ logger.debug("_init_egl: calling eglInitialize()")
225
+ major, minor = _EGL.EGLint(), _EGL.EGLint()
226
+ if not eglInitialize(display, major, minor):
227
+ display = None # Not initialized, don't terminate
228
+ raise RuntimeError("eglInitialize() failed")
229
+ logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}")
230
+
231
+ config_attribs = [
232
+ EGL_SURFACE_TYPE, EGL_PBUFFER_BIT,
233
+ EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT,
234
+ EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8,
235
+ EGL_DEPTH_SIZE, 0, EGL_NONE
236
+ ]
237
+ configs = (_EGL.EGLConfig * 1)()
238
+ num_configs = _EGL.EGLint()
239
+ if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0:
240
+ raise RuntimeError("eglChooseConfig() failed")
241
+ config = configs[0]
242
+ logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}")
243
+
244
+ if not eglBindAPI(EGL_OPENGL_API):
245
+ raise RuntimeError("eglBindAPI() failed")
246
+
247
+ logger.debug("_init_egl: calling eglCreateContext()")
248
+ context_attribs = [
249
+ _EGL.EGL_CONTEXT_MAJOR_VERSION, 3,
250
+ _EGL.EGL_CONTEXT_MINOR_VERSION, 3,
251
+ _EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT,
252
+ EGL_NONE
253
+ ]
254
+ context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs)
255
+ if context == EGL_NO_CONTEXT:
256
+ raise RuntimeError("eglCreateContext() failed")
257
+
258
+ logger.debug("_init_egl: calling eglCreatePbufferSurface()")
259
+ pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE]
260
+ surface = eglCreatePbufferSurface(display, config, pbuffer_attribs)
261
+ if surface == _EGL.EGL_NO_SURFACE:
262
+ raise RuntimeError("eglCreatePbufferSurface() failed")
263
+
264
+ logger.debug("_init_egl: calling eglMakeCurrent()")
265
+ if not eglMakeCurrent(display, surface, surface, context):
266
+ raise RuntimeError("eglMakeCurrent() failed")
267
+
268
+ logger.debug("_init_egl: completed successfully")
269
+ return display, context, surface, _EGL
270
+
271
+ except Exception:
272
+ logger.debug("_init_egl: failed, cleaning up")
273
+ # Clean up any resources on failure
274
+ if surface is not None:
275
+ eglDestroySurface(display, surface)
276
+ if context is not None:
277
+ eglDestroyContext(display, context)
278
+ if display is not None:
279
+ eglTerminate(display)
280
+ raise
281
+
282
+
283
+ def _init_osmesa():
284
+ """Initialize OSMesa for software rendering. Returns (context, buffer). Raises RuntimeError on failure."""
285
+ import ctypes
286
+
287
+ logger.debug("_init_osmesa: starting")
288
+ os.environ["PYOPENGL_PLATFORM"] = "osmesa"
289
+
290
+ logger.debug("_init_osmesa: importing OpenGL.osmesa")
291
+ from OpenGL import GL as _gl
292
+ from OpenGL.osmesa import (
293
+ OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext,
294
+ OSMESA_RGBA,
295
+ )
296
+ logger.debug("_init_osmesa: imports completed")
297
+
298
+ ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None)
299
+ if not ctx:
300
+ raise RuntimeError("OSMesaCreateContextExt() failed")
301
+
302
+ width, height = 64, 64
303
+ buffer = (ctypes.c_ubyte * (width * height * 4))()
304
+
305
+ logger.debug("_init_osmesa: calling OSMesaMakeCurrent()")
306
+ if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height):
307
+ OSMesaDestroyContext(ctx)
308
+ raise RuntimeError("OSMesaMakeCurrent() failed")
309
+
310
+ logger.debug("_init_osmesa: completed successfully")
311
+ return ctx, buffer
312
+
313
+
314
+ class GLContext:
315
+ """Manages OpenGL context and resources for shader execution.
316
+
317
+ Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software).
318
+ """
319
+
320
+ _instance = None
321
+ _initialized = False
322
+
323
+ def __new__(cls):
324
+ if cls._instance is None:
325
+ cls._instance = super().__new__(cls)
326
+ return cls._instance
327
+
328
+ def __init__(self):
329
+ if GLContext._initialized:
330
+ logger.debug("GLContext.__init__: already initialized, skipping")
331
+ return
332
+
333
+ logger.debug("GLContext.__init__: starting initialization")
334
+
335
+ global glfw, EGL
336
+
337
+ import time
338
+ start = time.perf_counter()
339
+
340
+ self._backend = None
341
+ self._window = None
342
+ self._egl_display = None
343
+ self._egl_context = None
344
+ self._egl_surface = None
345
+ self._osmesa_ctx = None
346
+ self._osmesa_buffer = None
347
+ self._vao = None
348
+
349
+ # Try backends in order: GLFW → EGL → OSMesa
350
+ errors = []
351
+
352
+ logger.debug("GLContext.__init__: trying GLFW backend")
353
+ try:
354
+ self._window, glfw = _init_glfw()
355
+ self._backend = "glfw"
356
+ logger.debug("GLContext.__init__: GLFW backend succeeded")
357
+ except Exception as e:
358
+ logger.debug(f"GLContext.__init__: GLFW backend failed: {e}")
359
+ errors.append(("GLFW", e))
360
+
361
+ if self._backend is None:
362
+ logger.debug("GLContext.__init__: trying EGL backend")
363
+ try:
364
+ self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl()
365
+ self._backend = "egl"
366
+ logger.debug("GLContext.__init__: EGL backend succeeded")
367
+ except Exception as e:
368
+ logger.debug(f"GLContext.__init__: EGL backend failed: {e}")
369
+ errors.append(("EGL", e))
370
+
371
+ if self._backend is None:
372
+ logger.debug("GLContext.__init__: trying OSMesa backend")
373
+ try:
374
+ self._osmesa_ctx, self._osmesa_buffer = _init_osmesa()
375
+ self._backend = "osmesa"
376
+ logger.debug("GLContext.__init__: OSMesa backend succeeded")
377
+ except Exception as e:
378
+ logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}")
379
+ errors.append(("OSMesa", e))
380
+
381
+ if self._backend is None:
382
+ if sys.platform == "win32":
383
+ platform_help = (
384
+ "Windows: Ensure GPU drivers are installed and display is available.\n"
385
+ " CPU-only/headless mode is not supported on Windows."
386
+ )
387
+ elif sys.platform == "darwin":
388
+ platform_help = (
389
+ "macOS: GLFW is not supported.\n"
390
+ " Install OSMesa via Homebrew: brew install mesa\n"
391
+ " Then: pip install PyOpenGL PyOpenGL-accelerate"
392
+ )
393
+ else:
394
+ platform_help = (
395
+ "Linux: Install one of these backends:\n"
396
+ " Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
397
+ " Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
398
+ " Headless (CPU): sudo apt install libosmesa6"
399
+ )
400
+
401
+ error_details = "\n".join(f" {name}: {err}" for name, err in errors)
402
+ raise RuntimeError(
403
+ f"Failed to create OpenGL context.\n\n"
404
+ f"Backend errors:\n{error_details}\n\n"
405
+ f"{platform_help}"
406
+ )
407
+
408
+ # Now import OpenGL.GL (after context is current)
409
+ logger.debug("GLContext.__init__: importing OpenGL.GL")
410
+ _import_opengl()
411
+
412
+ # Create VAO (required for core profile, but OSMesa may use compat profile)
413
+ logger.debug("GLContext.__init__: creating VAO")
414
+ try:
415
+ vao = gl.glGenVertexArrays(1)
416
+ gl.glBindVertexArray(vao)
417
+ self._vao = vao # Only store after successful bind
418
+ logger.debug("GLContext.__init__: VAO created successfully")
419
+ except Exception as e:
420
+ logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
421
+ # OSMesa with older Mesa may not support VAOs
422
+ # Clean up if we created but couldn't bind
423
+ if vao:
424
+ try:
425
+ gl.glDeleteVertexArrays(1, [vao])
426
+ except Exception:
427
+ pass
428
+
429
+ elapsed = (time.perf_counter() - start) * 1000
430
+
431
+ # Log device info
432
+ renderer = gl.glGetString(gl.GL_RENDERER)
433
+ vendor = gl.glGetString(gl.GL_VENDOR)
434
+ version = gl.glGetString(gl.GL_VERSION)
435
+ renderer = renderer.decode() if renderer else "Unknown"
436
+ vendor = vendor.decode() if vendor else "Unknown"
437
+ version = version.decode() if version else "Unknown"
438
+
439
+ GLContext._initialized = True
440
+ logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
441
+
442
+ def make_current(self):
443
+ if self._backend == "glfw":
444
+ glfw.make_context_current(self._window)
445
+ elif self._backend == "egl":
446
+ from OpenGL.EGL import eglMakeCurrent
447
+ eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
448
+ elif self._backend == "osmesa":
449
+ from OpenGL.osmesa import OSMesaMakeCurrent
450
+ OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
451
+
452
+ if self._vao is not None:
453
+ gl.glBindVertexArray(self._vao)
454
+
455
+
456
+ def _compile_shader(source: str, shader_type: int) -> int:
457
+ """Compile a shader and return its ID."""
458
+ shader = gl.glCreateShader(shader_type)
459
+ gl.glShaderSource(shader, source)
460
+ gl.glCompileShader(shader)
461
+
462
+ if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
463
+ error = gl.glGetShaderInfoLog(shader).decode()
464
+ gl.glDeleteShader(shader)
465
+ raise RuntimeError(f"Shader compilation failed:\n{error}")
466
+
467
+ return shader
468
+
469
+
470
+ def _create_program(vertex_source: str, fragment_source: str) -> int:
471
+ """Create and link a shader program."""
472
+ vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
473
+ try:
474
+ fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
475
+ except RuntimeError:
476
+ gl.glDeleteShader(vertex_shader)
477
+ raise
478
+
479
+ program = gl.glCreateProgram()
480
+ gl.glAttachShader(program, vertex_shader)
481
+ gl.glAttachShader(program, fragment_shader)
482
+ gl.glLinkProgram(program)
483
+
484
+ gl.glDeleteShader(vertex_shader)
485
+ gl.glDeleteShader(fragment_shader)
486
+
487
+ if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
488
+ error = gl.glGetProgramInfoLog(program).decode()
489
+ gl.glDeleteProgram(program)
490
+ raise RuntimeError(f"Program linking failed:\n{error}")
491
+
492
+ return program
493
+
494
+
495
+ def _render_shader_batch(
496
+ fragment_code: str,
497
+ width: int,
498
+ height: int,
499
+ image_batches: list[list[np.ndarray]],
500
+ floats: list[float],
501
+ ints: list[int],
502
+ bools: list[bool] | None = None,
503
+ curves: list[np.ndarray] | None = None,
504
+ ) -> list[list[np.ndarray]]:
505
+ """
506
+ Render a fragment shader for multiple batches efficiently.
507
+
508
+ Compiles shader once, reuses framebuffer/textures across batches.
509
+ Supports multi-pass rendering via #pragma passes N directive.
510
+
511
+ Args:
512
+ fragment_code: User's fragment shader code
513
+ width: Output width
514
+ height: Output height
515
+ image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
516
+ floats: List of float uniforms
517
+ ints: List of int uniforms
518
+ bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
519
+ curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
520
+
521
+ Returns:
522
+ List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
523
+ """
524
+ import time
525
+ start_time = time.perf_counter()
526
+
527
+ if not image_batches:
528
+ return []
529
+
530
+ ctx = GLContext()
531
+ ctx.make_current()
532
+
533
+ # Convert from GLSL ES to desktop GLSL 330
534
+ fragment_source = _convert_es_to_desktop(fragment_code)
535
+
536
+ # Detect how many outputs the shader actually uses
537
+ num_outputs = _detect_output_count(fragment_code)
538
+
539
+ # Detect multi-pass rendering
540
+ num_passes = _detect_pass_count(fragment_code)
541
+
542
+ if bools is None:
543
+ bools = []
544
+ if curves is None:
545
+ curves = []
546
+
547
+ # Track resources for cleanup
548
+ program = None
549
+ fbo = None
550
+ output_textures = []
551
+ input_textures = []
552
+ curve_textures = []
553
+ ping_pong_textures = []
554
+ ping_pong_fbos = []
555
+
556
+ num_inputs = len(image_batches[0])
557
+
558
+ try:
559
+ # Compile shaders (once for all batches)
560
+ try:
561
+ program = _create_program(VERTEX_SHADER, fragment_source)
562
+ except RuntimeError:
563
+ logger.error(f"Fragment shader:\n{fragment_source}")
564
+ raise
565
+
566
+ gl.glUseProgram(program)
567
+
568
+ # Create framebuffer with only the needed color attachments
569
+ fbo = gl.glGenFramebuffers(1)
570
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
571
+
572
+ draw_buffers = []
573
+ for i in range(num_outputs):
574
+ tex = gl.glGenTextures(1)
575
+ output_textures.append(tex)
576
+ gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
577
+ gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
578
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
579
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
580
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0)
581
+ draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i)
582
+
583
+ gl.glDrawBuffers(num_outputs, draw_buffers)
584
+
585
+ if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
586
+ raise RuntimeError("Framebuffer is not complete")
587
+
588
+ # Create ping-pong resources for multi-pass rendering
589
+ if num_passes > 1:
590
+ for _ in range(2):
591
+ pp_tex = gl.glGenTextures(1)
592
+ ping_pong_textures.append(pp_tex)
593
+ gl.glBindTexture(gl.GL_TEXTURE_2D, pp_tex)
594
+ gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None)
595
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
596
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
597
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
598
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
599
+
600
+ pp_fbo = gl.glGenFramebuffers(1)
601
+ ping_pong_fbos.append(pp_fbo)
602
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, pp_fbo)
603
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, pp_tex, 0)
604
+ gl.glDrawBuffers(1, [gl.GL_COLOR_ATTACHMENT0])
605
+
606
+ if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE:
607
+ raise RuntimeError("Ping-pong framebuffer is not complete")
608
+
609
+ # Create input textures (reused for all batches)
610
+ for i in range(num_inputs):
611
+ tex = gl.glGenTextures(1)
612
+ input_textures.append(tex)
613
+ gl.glActiveTexture(gl.GL_TEXTURE0 + i)
614
+ gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
615
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
616
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
617
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
618
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
619
+
620
+ loc = gl.glGetUniformLocation(program, f"u_image{i}")
621
+ if loc >= 0:
622
+ gl.glUniform1i(loc, i)
623
+
624
+ # Set static uniforms (once for all batches)
625
+ loc = gl.glGetUniformLocation(program, "u_resolution")
626
+ if loc >= 0:
627
+ gl.glUniform2f(loc, float(width), float(height))
628
+
629
+ for i, v in enumerate(floats):
630
+ loc = gl.glGetUniformLocation(program, f"u_float{i}")
631
+ if loc >= 0:
632
+ gl.glUniform1f(loc, v)
633
+
634
+ for i, v in enumerate(ints):
635
+ loc = gl.glGetUniformLocation(program, f"u_int{i}")
636
+ if loc >= 0:
637
+ gl.glUniform1i(loc, v)
638
+
639
+ for i, v in enumerate(bools):
640
+ loc = gl.glGetUniformLocation(program, f"u_bool{i}")
641
+ if loc >= 0:
642
+ gl.glUniform1i(loc, 1 if v else 0)
643
+
644
+ # Create 1D LUT textures for curves (bound after image texture units)
645
+ for i, lut in enumerate(curves):
646
+ tex = gl.glGenTextures(1)
647
+ curve_textures.append(tex)
648
+ unit = MAX_IMAGES + i
649
+ gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
650
+ gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
651
+ gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
652
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
653
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
654
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
655
+ gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
656
+
657
+ loc = gl.glGetUniformLocation(program, f"u_curve{i}")
658
+ if loc >= 0:
659
+ gl.glUniform1i(loc, unit)
660
+
661
+ # Get u_pass uniform location for multi-pass
662
+ pass_loc = gl.glGetUniformLocation(program, "u_pass")
663
+
664
+ gl.glViewport(0, 0, width, height)
665
+ gl.glDisable(gl.GL_BLEND) # Ensure no alpha blending - write output directly
666
+
667
+ # Process each batch
668
+ all_batch_outputs = []
669
+ for images in image_batches:
670
+ # Update input textures with this batch's images
671
+ for i, img in enumerate(images):
672
+ gl.glActiveTexture(gl.GL_TEXTURE0 + i)
673
+ gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i])
674
+
675
+ # Flip vertically for GL coordinates, ensure RGBA
676
+ h, w, c = img.shape
677
+ if c == 3:
678
+ img_upload = np.empty((h, w, 4), dtype=np.float32)
679
+ img_upload[:, :, :3] = img[::-1, :, :]
680
+ img_upload[:, :, 3] = 1.0
681
+ else:
682
+ img_upload = np.ascontiguousarray(img[::-1, :, :])
683
+
684
+ gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload)
685
+
686
+ if num_passes == 1:
687
+ # Single pass - render directly to output FBO
688
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
689
+ if pass_loc >= 0:
690
+ gl.glUniform1i(pass_loc, 0)
691
+ gl.glClearColor(0, 0, 0, 0)
692
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
693
+ gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
694
+ else:
695
+ # Multi-pass rendering with ping-pong
696
+ for p in range(num_passes):
697
+ is_last_pass = (p == num_passes - 1)
698
+
699
+ # Set pass uniform
700
+ if pass_loc >= 0:
701
+ gl.glUniform1i(pass_loc, p)
702
+
703
+ if is_last_pass:
704
+ # Last pass renders to the main output FBO
705
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
706
+ else:
707
+ # Intermediate passes render to ping-pong FBO
708
+ target_fbo = ping_pong_fbos[p % 2]
709
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, target_fbo)
710
+
711
+ # Set input texture for this pass
712
+ gl.glActiveTexture(gl.GL_TEXTURE0)
713
+ if p == 0:
714
+ # First pass reads from original input
715
+ gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[0])
716
+ else:
717
+ # Subsequent passes read from previous pass output
718
+ source_tex = ping_pong_textures[(p - 1) % 2]
719
+ gl.glBindTexture(gl.GL_TEXTURE_2D, source_tex)
720
+
721
+ gl.glClearColor(0, 0, 0, 0)
722
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT)
723
+ gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3)
724
+
725
+ # Read back outputs for this batch
726
+ # (glGetTexImage is synchronous, implicitly waits for rendering)
727
+ batch_outputs = []
728
+ for tex in output_textures:
729
+ gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
730
+ data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT)
731
+ img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4)
732
+ batch_outputs.append(img[::-1, :, :].copy())
733
+
734
+ # Pad with black images for unused outputs
735
+ black_img = np.zeros((height, width, 4), dtype=np.float32)
736
+ for _ in range(num_outputs, MAX_OUTPUTS):
737
+ batch_outputs.append(black_img)
738
+
739
+ all_batch_outputs.append(batch_outputs)
740
+
741
+ elapsed = (time.perf_counter() - start_time) * 1000
742
+ num_batches = len(image_batches)
743
+ pass_info = f", {num_passes} passes" if num_passes > 1 else ""
744
+ logger.info(f"GLSL shader executed in {elapsed:.1f}ms ({num_batches} batch{'es' if num_batches != 1 else ''}, {width}x{height}{pass_info})")
745
+
746
+ return all_batch_outputs
747
+
748
+ finally:
749
+ # Unbind before deleting
750
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
751
+ gl.glUseProgram(0)
752
+
753
+ for tex in input_textures:
754
+ gl.glDeleteTextures(int(tex))
755
+ for tex in curve_textures:
756
+ gl.glDeleteTextures(int(tex))
757
+ for tex in output_textures:
758
+ gl.glDeleteTextures(int(tex))
759
+ for tex in ping_pong_textures:
760
+ gl.glDeleteTextures(int(tex))
761
+ if fbo is not None:
762
+ gl.glDeleteFramebuffers(1, [fbo])
763
+ for pp_fbo in ping_pong_fbos:
764
+ gl.glDeleteFramebuffers(1, [pp_fbo])
765
+ if program is not None:
766
+ gl.glDeleteProgram(program)
767
+
768
+ class GLSLShader(io.ComfyNode):
769
+
770
+ @classmethod
771
+ def define_schema(cls) -> io.Schema:
772
+ image_template = io.Autogrow.TemplatePrefix(
773
+ io.Image.Input("image"),
774
+ prefix="image",
775
+ min=1,
776
+ max=MAX_IMAGES,
777
+ )
778
+
779
+ float_template = io.Autogrow.TemplatePrefix(
780
+ io.Float.Input("float", default=0.0),
781
+ prefix="u_float",
782
+ min=0,
783
+ max=MAX_UNIFORMS,
784
+ )
785
+
786
+ int_template = io.Autogrow.TemplatePrefix(
787
+ io.Int.Input("int", default=0),
788
+ prefix="u_int",
789
+ min=0,
790
+ max=MAX_UNIFORMS,
791
+ )
792
+
793
+ bool_template = io.Autogrow.TemplatePrefix(
794
+ io.Boolean.Input("bool", default=False),
795
+ prefix="u_bool",
796
+ min=0,
797
+ max=MAX_BOOLS,
798
+ )
799
+
800
+ curve_template = io.Autogrow.TemplatePrefix(
801
+ io.Curve.Input("curve"),
802
+ prefix="u_curve",
803
+ min=0,
804
+ max=MAX_CURVES,
805
+ )
806
+
807
+ return io.Schema(
808
+ node_id="GLSLShader",
809
+ display_name="GLSL Shader",
810
+ category="image/shader",
811
+ description=(
812
+ "Apply GLSL ES fragment shaders to images. "
813
+ "u_resolution (vec2) is always available."
814
+ ),
815
+ is_experimental=True,
816
+ has_intermediate_output=True,
817
+ inputs=[
818
+ io.String.Input(
819
+ "fragment_shader",
820
+ default=DEFAULT_FRAGMENT_SHADER,
821
+ multiline=True,
822
+ tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)",
823
+ ),
824
+ io.DynamicCombo.Input(
825
+ "size_mode",
826
+ options=[
827
+ io.DynamicCombo.Option("from_input", []),
828
+ io.DynamicCombo.Option(
829
+ "custom",
830
+ [
831
+ io.Int.Input(
832
+ "width",
833
+ default=512,
834
+ min=1,
835
+ max=nodes.MAX_RESOLUTION,
836
+ ),
837
+ io.Int.Input(
838
+ "height",
839
+ default=512,
840
+ min=1,
841
+ max=nodes.MAX_RESOLUTION,
842
+ ),
843
+ ],
844
+ ),
845
+ ],
846
+ tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size",
847
+ ),
848
+ io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
849
+ io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
850
+ io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
851
+ io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
852
+ io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
853
+ ],
854
+ outputs=[
855
+ io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
856
+ io.Image.Output(display_name="IMAGE1", tooltip="Available via layout(location = 1) out vec4 fragColor1 in the shader code"),
857
+ io.Image.Output(display_name="IMAGE2", tooltip="Available via layout(location = 2) out vec4 fragColor2 in the shader code"),
858
+ io.Image.Output(display_name="IMAGE3", tooltip="Available via layout(location = 3) out vec4 fragColor3 in the shader code"),
859
+ ],
860
+ )
861
+
862
+ @classmethod
863
+ def execute(
864
+ cls,
865
+ fragment_shader: str,
866
+ size_mode: SizeModeInput,
867
+ images: io.Autogrow.Type,
868
+ floats: io.Autogrow.Type = None,
869
+ ints: io.Autogrow.Type = None,
870
+ bools: io.Autogrow.Type = None,
871
+ curves: io.Autogrow.Type = None,
872
+ **kwargs,
873
+ ) -> io.NodeOutput:
874
+
875
+ image_list = [v for v in images.values() if v is not None]
876
+ float_list = (
877
+ [v if v is not None else 0.0 for v in floats.values()] if floats else []
878
+ )
879
+ int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
880
+ bool_list = [v if v is not None else False for v in bools.values()] if bools else []
881
+
882
+ curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
883
+
884
+ if not image_list:
885
+ raise ValueError("At least one input image is required")
886
+
887
+ # Determine output dimensions
888
+ if size_mode["size_mode"] == "custom":
889
+ out_width = size_mode["width"]
890
+ out_height = size_mode["height"]
891
+ else:
892
+ out_height, out_width = image_list[0].shape[1:3]
893
+
894
+ batch_size = image_list[0].shape[0]
895
+
896
+ # Prepare batches
897
+ image_batches = []
898
+ for batch_idx in range(batch_size):
899
+ batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list]
900
+ image_batches.append(batch_images)
901
+
902
+ all_batch_outputs = _render_shader_batch(
903
+ fragment_shader,
904
+ out_width,
905
+ out_height,
906
+ image_batches,
907
+ float_list,
908
+ int_list,
909
+ bool_list,
910
+ curve_luts,
911
+ )
912
+
913
+ # Collect outputs into tensors
914
+ all_outputs = [[] for _ in range(MAX_OUTPUTS)]
915
+ for batch_outputs in all_batch_outputs:
916
+ for i, out_img in enumerate(batch_outputs):
917
+ all_outputs[i].append(torch.from_numpy(out_img))
918
+
919
+ output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)]
920
+ return io.NodeOutput(
921
+ *output_tensors,
922
+ ui=cls._build_ui_output(image_list, output_tensors[0]),
923
+ )
924
+
925
+ @classmethod
926
+ def _build_ui_output(
927
+ cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
928
+ ) -> dict[str, list]:
929
+ """Build UI output with input and output images for client-side shader execution."""
930
+ input_images_ui = []
931
+ for img in image_list:
932
+ input_images_ui.extend(ui.ImageSaveHelper.save_images(
933
+ img,
934
+ filename_prefix="GLSLShader_input",
935
+ folder_type=io.FolderType.temp,
936
+ cls=None,
937
+ compress_level=1,
938
+ ))
939
+
940
+ output_images_ui = ui.ImageSaveHelper.save_images(
941
+ output_batch,
942
+ filename_prefix="GLSLShader_output",
943
+ folder_type=io.FolderType.temp,
944
+ cls=None,
945
+ compress_level=1,
946
+ )
947
+
948
+ return {"input_images": input_images_ui, "images": output_images_ui}
949
+
950
+
951
+ class GLSLExtension(ComfyExtension):
952
+ @override
953
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
954
+ return [GLSLShader]
955
+
956
+
957
+ async def comfy_entrypoint() -> GLSLExtension:
958
+ return GLSLExtension()
ComfyUI/comfy_extras/nodes_hidream.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import override
2
+
3
+ import folder_paths
4
+ import comfy.sd
5
+ import comfy.model_management
6
+ from comfy_api.latest import ComfyExtension, io
7
+
8
+
9
+ class QuadrupleCLIPLoader(io.ComfyNode):
10
+ @classmethod
11
+ def define_schema(cls):
12
+ return io.Schema(
13
+ node_id="QuadrupleCLIPLoader",
14
+ category="advanced/loaders",
15
+ description="[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct",
16
+ inputs=[
17
+ io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
18
+ io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
19
+ io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
20
+ io.Combo.Input("clip_name4", options=folder_paths.get_filename_list("text_encoders")),
21
+ ],
22
+ outputs=[
23
+ io.Clip.Output(),
24
+ ]
25
+ )
26
+
27
+ @classmethod
28
+ def execute(cls, clip_name1, clip_name2, clip_name3, clip_name4):
29
+ clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
30
+ clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
31
+ clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
32
+ clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
33
+ clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
34
+ return io.NodeOutput(clip)
35
+
36
+ class CLIPTextEncodeHiDream(io.ComfyNode):
37
+ @classmethod
38
+ def define_schema(cls):
39
+ return io.Schema(
40
+ node_id="CLIPTextEncodeHiDream",
41
+ search_aliases=["hidream prompt"],
42
+ category="advanced/conditioning",
43
+ inputs=[
44
+ io.Clip.Input("clip"),
45
+ io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
46
+ io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
47
+ io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
48
+ io.String.Input("llama", multiline=True, dynamic_prompts=True),
49
+ ],
50
+ outputs=[
51
+ io.Conditioning.Output(),
52
+ ]
53
+ )
54
+
55
+ @classmethod
56
+ def execute(cls, clip, clip_l, clip_g, t5xxl, llama):
57
+ tokens = clip.tokenize(clip_g)
58
+ tokens["l"] = clip.tokenize(clip_l)["l"]
59
+ tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
60
+ tokens["llama"] = clip.tokenize(llama)["llama"]
61
+ return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
62
+
63
+
64
+ class HiDreamExtension(ComfyExtension):
65
+ @override
66
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
67
+ return [
68
+ QuadrupleCLIPLoader,
69
+ CLIPTextEncodeHiDream,
70
+ ]
71
+
72
+
73
+ async def comfy_entrypoint() -> HiDreamExtension:
74
+ return HiDreamExtension()
ComfyUI/comfy_extras/nodes_hooks.py ADDED
@@ -0,0 +1,750 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING, Union
3
+ import logging
4
+ import torch
5
+ from collections.abc import Iterable
6
+
7
+ if TYPE_CHECKING:
8
+ from comfy.sd import CLIP
9
+
10
+ import comfy.hooks
11
+ import comfy.sd
12
+ import comfy.utils
13
+ import folder_paths
14
+
15
+ ###########################################
16
+ # Mask, Combine, and Hook Conditioning
17
+ #------------------------------------------
18
+ class PairConditioningSetProperties:
19
+ NodeId = 'PairConditioningSetProperties'
20
+ NodeName = 'Cond Pair Set Props'
21
+ @classmethod
22
+ def INPUT_TYPES(s):
23
+ return {
24
+ "required": {
25
+ "positive_NEW": ("CONDITIONING", ),
26
+ "negative_NEW": ("CONDITIONING", ),
27
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
28
+ "set_cond_area": (["default", "mask bounds"],),
29
+ },
30
+ "optional": {
31
+ "mask": ("MASK", ),
32
+ "hooks": ("HOOKS",),
33
+ "timesteps": ("TIMESTEPS_RANGE",),
34
+ }
35
+ }
36
+
37
+ EXPERIMENTAL = True
38
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
39
+ RETURN_NAMES = ("positive", "negative")
40
+ CATEGORY = "advanced/hooks/cond pair"
41
+ FUNCTION = "set_properties"
42
+
43
+ def set_properties(self, positive_NEW, negative_NEW,
44
+ strength: float, set_cond_area: str,
45
+ mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
46
+ final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW],
47
+ strength=strength, set_cond_area=set_cond_area,
48
+ mask=mask, hooks=hooks, timesteps_range=timesteps)
49
+ return (final_positive, final_negative)
50
+
51
+ class PairConditioningSetPropertiesAndCombine:
52
+ NodeId = 'PairConditioningSetPropertiesAndCombine'
53
+ NodeName = 'Cond Pair Set Props Combine'
54
+ @classmethod
55
+ def INPUT_TYPES(s):
56
+ return {
57
+ "required": {
58
+ "positive": ("CONDITIONING", ),
59
+ "negative": ("CONDITIONING", ),
60
+ "positive_NEW": ("CONDITIONING", ),
61
+ "negative_NEW": ("CONDITIONING", ),
62
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
63
+ "set_cond_area": (["default", "mask bounds"],),
64
+ },
65
+ "optional": {
66
+ "mask": ("MASK", ),
67
+ "hooks": ("HOOKS",),
68
+ "timesteps": ("TIMESTEPS_RANGE",),
69
+ }
70
+ }
71
+
72
+ EXPERIMENTAL = True
73
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
74
+ RETURN_NAMES = ("positive", "negative")
75
+ CATEGORY = "advanced/hooks/cond pair"
76
+ FUNCTION = "set_properties"
77
+
78
+ def set_properties(self, positive, negative, positive_NEW, negative_NEW,
79
+ strength: float, set_cond_area: str,
80
+ mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
81
+ final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW],
82
+ strength=strength, set_cond_area=set_cond_area,
83
+ mask=mask, hooks=hooks, timesteps_range=timesteps)
84
+ return (final_positive, final_negative)
85
+
86
+ class ConditioningSetProperties:
87
+ NodeId = 'ConditioningSetProperties'
88
+ NodeName = 'Cond Set Props'
89
+ @classmethod
90
+ def INPUT_TYPES(s):
91
+ return {
92
+ "required": {
93
+ "cond_NEW": ("CONDITIONING", ),
94
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
95
+ "set_cond_area": (["default", "mask bounds"],),
96
+ },
97
+ "optional": {
98
+ "mask": ("MASK", ),
99
+ "hooks": ("HOOKS",),
100
+ "timesteps": ("TIMESTEPS_RANGE",),
101
+ }
102
+ }
103
+
104
+ EXPERIMENTAL = True
105
+ RETURN_TYPES = ("CONDITIONING",)
106
+ CATEGORY = "advanced/hooks/cond single"
107
+ FUNCTION = "set_properties"
108
+
109
+ def set_properties(self, cond_NEW,
110
+ strength: float, set_cond_area: str,
111
+ mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
112
+ (final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW],
113
+ strength=strength, set_cond_area=set_cond_area,
114
+ mask=mask, hooks=hooks, timesteps_range=timesteps)
115
+ return (final_cond,)
116
+
117
+ class ConditioningSetPropertiesAndCombine:
118
+ NodeId = 'ConditioningSetPropertiesAndCombine'
119
+ NodeName = 'Cond Set Props Combine'
120
+ @classmethod
121
+ def INPUT_TYPES(s):
122
+ return {
123
+ "required": {
124
+ "cond": ("CONDITIONING", ),
125
+ "cond_NEW": ("CONDITIONING", ),
126
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
127
+ "set_cond_area": (["default", "mask bounds"],),
128
+ },
129
+ "optional": {
130
+ "mask": ("MASK", ),
131
+ "hooks": ("HOOKS",),
132
+ "timesteps": ("TIMESTEPS_RANGE",),
133
+ }
134
+ }
135
+
136
+ EXPERIMENTAL = True
137
+ RETURN_TYPES = ("CONDITIONING",)
138
+ CATEGORY = "advanced/hooks/cond single"
139
+ FUNCTION = "set_properties"
140
+
141
+ def set_properties(self, cond, cond_NEW,
142
+ strength: float, set_cond_area: str,
143
+ mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None):
144
+ (final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW],
145
+ strength=strength, set_cond_area=set_cond_area,
146
+ mask=mask, hooks=hooks, timesteps_range=timesteps)
147
+ return (final_cond,)
148
+
149
+ class PairConditioningCombine:
150
+ NodeId = 'PairConditioningCombine'
151
+ NodeName = 'Cond Pair Combine'
152
+ @classmethod
153
+ def INPUT_TYPES(s):
154
+ return {
155
+ "required": {
156
+ "positive_A": ("CONDITIONING",),
157
+ "negative_A": ("CONDITIONING",),
158
+ "positive_B": ("CONDITIONING",),
159
+ "negative_B": ("CONDITIONING",),
160
+ },
161
+ }
162
+
163
+ EXPERIMENTAL = True
164
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
165
+ RETURN_NAMES = ("positive", "negative")
166
+ CATEGORY = "advanced/hooks/cond pair"
167
+ FUNCTION = "combine"
168
+
169
+ def combine(self, positive_A, negative_A, positive_B, negative_B):
170
+ final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],)
171
+ return (final_positive, final_negative,)
172
+
173
+ class PairConditioningSetDefaultAndCombine:
174
+ NodeId = 'PairConditioningSetDefaultCombine'
175
+ NodeName = 'Cond Pair Set Default Combine'
176
+ @classmethod
177
+ def INPUT_TYPES(s):
178
+ return {
179
+ "required": {
180
+ "positive": ("CONDITIONING",),
181
+ "negative": ("CONDITIONING",),
182
+ "positive_DEFAULT": ("CONDITIONING",),
183
+ "negative_DEFAULT": ("CONDITIONING",),
184
+ },
185
+ "optional": {
186
+ "hooks": ("HOOKS",),
187
+ }
188
+ }
189
+
190
+ EXPERIMENTAL = True
191
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
192
+ RETURN_NAMES = ("positive", "negative")
193
+ CATEGORY = "advanced/hooks/cond pair"
194
+ FUNCTION = "set_default_and_combine"
195
+
196
+ def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT,
197
+ hooks: comfy.hooks.HookGroup=None):
198
+ final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT],
199
+ hooks=hooks)
200
+ return (final_positive, final_negative)
201
+
202
+ class ConditioningSetDefaultAndCombine:
203
+ NodeId = 'ConditioningSetDefaultCombine'
204
+ NodeName = 'Cond Set Default Combine'
205
+ @classmethod
206
+ def INPUT_TYPES(s):
207
+ return {
208
+ "required": {
209
+ "cond": ("CONDITIONING",),
210
+ "cond_DEFAULT": ("CONDITIONING",),
211
+ },
212
+ "optional": {
213
+ "hooks": ("HOOKS",),
214
+ }
215
+ }
216
+
217
+ EXPERIMENTAL = True
218
+ RETURN_TYPES = ("CONDITIONING",)
219
+ CATEGORY = "advanced/hooks/cond single"
220
+ FUNCTION = "set_default_and_combine"
221
+
222
+ def set_default_and_combine(self, cond, cond_DEFAULT,
223
+ hooks: comfy.hooks.HookGroup=None):
224
+ (final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT],
225
+ hooks=hooks)
226
+ return (final_conditioning,)
227
+
228
+ class SetClipHooks:
229
+ NodeId = 'SetClipHooks'
230
+ NodeName = 'Set CLIP Hooks'
231
+ @classmethod
232
+ def INPUT_TYPES(s):
233
+ return {
234
+ "required": {
235
+ "clip": ("CLIP",),
236
+ "apply_to_conds": ("BOOLEAN", {"default": True, "advanced": True}),
237
+ "schedule_clip": ("BOOLEAN", {"default": False, "advanced": True})
238
+ },
239
+ "optional": {
240
+ "hooks": ("HOOKS",)
241
+ }
242
+ }
243
+
244
+ EXPERIMENTAL = True
245
+ RETURN_TYPES = ("CLIP",)
246
+ CATEGORY = "advanced/hooks/clip"
247
+ FUNCTION = "apply_hooks"
248
+
249
+ def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
250
+ if hooks is not None:
251
+ clip = clip.clone(disable_dynamic=True)
252
+ if apply_to_conds:
253
+ clip.apply_hooks_to_conds = hooks
254
+ clip.patcher.forced_hooks = hooks.clone()
255
+ clip.use_clip_schedule = schedule_clip
256
+ if not clip.use_clip_schedule:
257
+ clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
258
+ clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
259
+ return (clip,)
260
+
261
+ class ConditioningTimestepsRange:
262
+ SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"]
263
+ NodeId = 'ConditioningTimestepsRange'
264
+ NodeName = 'Timesteps Range'
265
+ @classmethod
266
+ def INPUT_TYPES(s):
267
+ return {
268
+ "required": {
269
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
270
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
271
+ },
272
+ }
273
+
274
+ EXPERIMENTAL = True
275
+ RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE")
276
+ RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE")
277
+ CATEGORY = "advanced/hooks"
278
+ FUNCTION = "create_range"
279
+
280
+ def create_range(self, start_percent: float, end_percent: float):
281
+ return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0))
282
+ #------------------------------------------
283
+ ###########################################
284
+
285
+
286
+ ###########################################
287
+ # Create Hooks
288
+ #------------------------------------------
289
+ class CreateHookLora:
290
+ NodeId = 'CreateHookLora'
291
+ NodeName = 'Create Hook LoRA'
292
+ def __init__(self):
293
+ self.loaded_lora = None
294
+
295
+ @classmethod
296
+ def INPUT_TYPES(s):
297
+ return {
298
+ "required": {
299
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
300
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
301
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
302
+ },
303
+ "optional": {
304
+ "prev_hooks": ("HOOKS",)
305
+ }
306
+ }
307
+
308
+ EXPERIMENTAL = True
309
+ RETURN_TYPES = ("HOOKS",)
310
+ CATEGORY = "advanced/hooks/create"
311
+ FUNCTION = "create_hook"
312
+
313
+ def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None):
314
+ if prev_hooks is None:
315
+ prev_hooks = comfy.hooks.HookGroup()
316
+ prev_hooks.clone()
317
+
318
+ if strength_model == 0 and strength_clip == 0:
319
+ return (prev_hooks,)
320
+
321
+ lora_path = folder_paths.get_full_path("loras", lora_name)
322
+ lora = None
323
+ if self.loaded_lora is not None:
324
+ if self.loaded_lora[0] == lora_path:
325
+ lora = self.loaded_lora[1]
326
+ else:
327
+ temp = self.loaded_lora
328
+ self.loaded_lora = None
329
+ del temp
330
+
331
+ if lora is None:
332
+ lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
333
+ self.loaded_lora = (lora_path, lora)
334
+
335
+ hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip)
336
+ return (prev_hooks.clone_and_combine(hooks),)
337
+
338
+ class CreateHookLoraModelOnly(CreateHookLora):
339
+ NodeId = 'CreateHookLoraModelOnly'
340
+ NodeName = 'Create Hook LoRA (MO)'
341
+ @classmethod
342
+ def INPUT_TYPES(s):
343
+ return {
344
+ "required": {
345
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
346
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
347
+ },
348
+ "optional": {
349
+ "prev_hooks": ("HOOKS",)
350
+ }
351
+ }
352
+
353
+ EXPERIMENTAL = True
354
+ RETURN_TYPES = ("HOOKS",)
355
+ CATEGORY = "advanced/hooks/create"
356
+ FUNCTION = "create_hook_model_only"
357
+
358
+ def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None):
359
+ return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks)
360
+
361
+ class CreateHookModelAsLora:
362
+ NodeId = 'CreateHookModelAsLora'
363
+ NodeName = 'Create Hook Model as LoRA'
364
+
365
+ def __init__(self):
366
+ # when not None, will be in following format:
367
+ # (ckpt_path: str, weights_model: dict, weights_clip: dict)
368
+ self.loaded_weights = None
369
+
370
+ @classmethod
371
+ def INPUT_TYPES(s):
372
+ return {
373
+ "required": {
374
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
375
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
376
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
377
+ },
378
+ "optional": {
379
+ "prev_hooks": ("HOOKS",)
380
+ }
381
+ }
382
+
383
+ EXPERIMENTAL = True
384
+ RETURN_TYPES = ("HOOKS",)
385
+ CATEGORY = "advanced/hooks/create"
386
+ FUNCTION = "create_hook"
387
+
388
+ def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float,
389
+ prev_hooks: comfy.hooks.HookGroup=None):
390
+ if prev_hooks is None:
391
+ prev_hooks = comfy.hooks.HookGroup()
392
+ prev_hooks.clone()
393
+
394
+ ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
395
+ weights_model = None
396
+ weights_clip = None
397
+ if self.loaded_weights is not None:
398
+ if self.loaded_weights[0] == ckpt_path:
399
+ weights_model = self.loaded_weights[1]
400
+ weights_clip = self.loaded_weights[2]
401
+ else:
402
+ temp = self.loaded_weights
403
+ self.loaded_weights = None
404
+ del temp
405
+
406
+ if weights_model is None:
407
+ out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
408
+ weights_model = comfy.hooks.get_patch_weights_from_model(out[0])
409
+ weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1])
410
+ self.loaded_weights = (ckpt_path, weights_model, weights_clip)
411
+
412
+ hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip,
413
+ strength_model=strength_model, strength_clip=strength_clip)
414
+ return (prev_hooks.clone_and_combine(hooks),)
415
+
416
+ class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora):
417
+ NodeId = 'CreateHookModelAsLoraModelOnly'
418
+ NodeName = 'Create Hook Model as LoRA (MO)'
419
+ @classmethod
420
+ def INPUT_TYPES(s):
421
+ return {
422
+ "required": {
423
+ "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
424
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
425
+ },
426
+ "optional": {
427
+ "prev_hooks": ("HOOKS",)
428
+ }
429
+ }
430
+
431
+ EXPERIMENTAL = True
432
+ RETURN_TYPES = ("HOOKS",)
433
+ CATEGORY = "advanced/hooks/create"
434
+ FUNCTION = "create_hook_model_only"
435
+
436
+ def create_hook_model_only(self, ckpt_name: str, strength_model: float,
437
+ prev_hooks: comfy.hooks.HookGroup=None):
438
+ return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks)
439
+ #------------------------------------------
440
+ ###########################################
441
+
442
+
443
+ ###########################################
444
+ # Schedule Hooks
445
+ #------------------------------------------
446
+ class SetHookKeyframes:
447
+ NodeId = 'SetHookKeyframes'
448
+ NodeName = 'Set Hook Keyframes'
449
+ @classmethod
450
+ def INPUT_TYPES(s):
451
+ return {
452
+ "required": {
453
+ "hooks": ("HOOKS",),
454
+ },
455
+ "optional": {
456
+ "hook_kf": ("HOOK_KEYFRAMES",),
457
+ }
458
+ }
459
+
460
+ EXPERIMENTAL = True
461
+ RETURN_TYPES = ("HOOKS",)
462
+ CATEGORY = "advanced/hooks/scheduling"
463
+ FUNCTION = "set_hook_keyframes"
464
+
465
+ def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None):
466
+ if hook_kf is not None:
467
+ hooks = hooks.clone()
468
+ hooks.set_keyframes_on_hooks(hook_kf=hook_kf)
469
+ return (hooks,)
470
+
471
+ class CreateHookKeyframe:
472
+ SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"]
473
+ NodeId = 'CreateHookKeyframe'
474
+ NodeName = 'Create Hook Keyframe'
475
+ @classmethod
476
+ def INPUT_TYPES(s):
477
+ return {
478
+ "required": {
479
+ "strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
480
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
481
+ },
482
+ "optional": {
483
+ "prev_hook_kf": ("HOOK_KEYFRAMES",),
484
+ }
485
+ }
486
+
487
+ EXPERIMENTAL = True
488
+ RETURN_TYPES = ("HOOK_KEYFRAMES",)
489
+ RETURN_NAMES = ("HOOK_KF",)
490
+ CATEGORY = "advanced/hooks/scheduling"
491
+ FUNCTION = "create_hook_keyframe"
492
+
493
+ def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
494
+ if prev_hook_kf is None:
495
+ prev_hook_kf = comfy.hooks.HookKeyframeGroup()
496
+ prev_hook_kf = prev_hook_kf.clone()
497
+ keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent)
498
+ prev_hook_kf.add(keyframe)
499
+ return (prev_hook_kf,)
500
+
501
+ class CreateHookKeyframesInterpolated:
502
+ SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"]
503
+ NodeId = 'CreateHookKeyframesInterpolated'
504
+ NodeName = 'Create Hook Keyframes Interp.'
505
+ @classmethod
506
+ def INPUT_TYPES(s):
507
+ return {
508
+ "required": {
509
+ "strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
510
+ "strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
511
+ "interpolation": (comfy.hooks.InterpolationMethod._LIST, ),
512
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
513
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
514
+ "keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}),
515
+ "print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}),
516
+ },
517
+ "optional": {
518
+ "prev_hook_kf": ("HOOK_KEYFRAMES",),
519
+ },
520
+ }
521
+
522
+ EXPERIMENTAL = True
523
+ RETURN_TYPES = ("HOOK_KEYFRAMES",)
524
+ RETURN_NAMES = ("HOOK_KF",)
525
+ CATEGORY = "advanced/hooks/scheduling"
526
+ FUNCTION = "create_hook_keyframes"
527
+
528
+ def create_hook_keyframes(self, strength_start: float, strength_end: float, interpolation: str,
529
+ start_percent: float, end_percent: float, keyframes_count: int,
530
+ print_keyframes=False, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None):
531
+ if prev_hook_kf is None:
532
+ prev_hook_kf = comfy.hooks.HookKeyframeGroup()
533
+ prev_hook_kf = prev_hook_kf.clone()
534
+ percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=keyframes_count,
535
+ method=comfy.hooks.InterpolationMethod.LINEAR)
536
+ strengths = comfy.hooks.InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=keyframes_count, method=interpolation)
537
+
538
+ is_first = True
539
+ for percent, strength in zip(percents, strengths):
540
+ guarantee_steps = 0
541
+ if is_first:
542
+ guarantee_steps = 1
543
+ is_first = False
544
+ prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
545
+ if print_keyframes:
546
+ logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
547
+ return (prev_hook_kf,)
548
+
549
+ class CreateHookKeyframesFromFloats:
550
+ SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"]
551
+ NodeId = 'CreateHookKeyframesFromFloats'
552
+ NodeName = 'Create Hook Keyframes From Floats'
553
+ @classmethod
554
+ def INPUT_TYPES(s):
555
+ return {
556
+ "required": {
557
+ "floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
558
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
559
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
560
+ "print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}),
561
+ },
562
+ "optional": {
563
+ "prev_hook_kf": ("HOOK_KEYFRAMES",),
564
+ }
565
+ }
566
+
567
+ EXPERIMENTAL = True
568
+ RETURN_TYPES = ("HOOK_KEYFRAMES",)
569
+ RETURN_NAMES = ("HOOK_KF",)
570
+ CATEGORY = "advanced/hooks/scheduling"
571
+ FUNCTION = "create_hook_keyframes"
572
+
573
+ def create_hook_keyframes(self, floats_strength: Union[float, list[float]],
574
+ start_percent: float, end_percent: float,
575
+ prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False):
576
+ if prev_hook_kf is None:
577
+ prev_hook_kf = comfy.hooks.HookKeyframeGroup()
578
+ prev_hook_kf = prev_hook_kf.clone()
579
+ if type(floats_strength) in (float, int):
580
+ floats_strength = [float(floats_strength)]
581
+ elif isinstance(floats_strength, Iterable):
582
+ pass
583
+ else:
584
+ raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.")
585
+ percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength),
586
+ method=comfy.hooks.InterpolationMethod.LINEAR)
587
+
588
+ is_first = True
589
+ for percent, strength in zip(percents, floats_strength):
590
+ guarantee_steps = 0
591
+ if is_first:
592
+ guarantee_steps = 1
593
+ is_first = False
594
+ prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
595
+ if print_keyframes:
596
+ logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
597
+ return (prev_hook_kf,)
598
+ #------------------------------------------
599
+ ###########################################
600
+
601
+
602
+ class SetModelHooksOnCond:
603
+ @classmethod
604
+ def INPUT_TYPES(s):
605
+ return {
606
+ "required": {
607
+ "conditioning": ("CONDITIONING",),
608
+ "hooks": ("HOOKS",),
609
+ },
610
+ }
611
+
612
+ EXPERIMENTAL = True
613
+ RETURN_TYPES = ("CONDITIONING",)
614
+ CATEGORY = "advanced/hooks/manual"
615
+ FUNCTION = "attach_hook"
616
+
617
+ def attach_hook(self, conditioning, hooks: comfy.hooks.HookGroup):
618
+ return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),)
619
+
620
+
621
+ ###########################################
622
+ # Combine Hooks
623
+ #------------------------------------------
624
+ class CombineHooks:
625
+ SEARCH_ALIASES = ["merge hooks"]
626
+ NodeId = 'CombineHooks2'
627
+ NodeName = 'Combine Hooks [2]'
628
+ @classmethod
629
+ def INPUT_TYPES(s):
630
+ return {
631
+ "required": {
632
+ },
633
+ "optional": {
634
+ "hooks_A": ("HOOKS",),
635
+ "hooks_B": ("HOOKS",),
636
+ }
637
+ }
638
+
639
+ EXPERIMENTAL = True
640
+ RETURN_TYPES = ("HOOKS",)
641
+ CATEGORY = "advanced/hooks/combine"
642
+ FUNCTION = "combine_hooks"
643
+
644
+ def combine_hooks(self,
645
+ hooks_A: comfy.hooks.HookGroup=None,
646
+ hooks_B: comfy.hooks.HookGroup=None):
647
+ candidates = [hooks_A, hooks_B]
648
+ return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
649
+
650
+ class CombineHooksFour:
651
+ NodeId = 'CombineHooks4'
652
+ NodeName = 'Combine Hooks [4]'
653
+ @classmethod
654
+ def INPUT_TYPES(s):
655
+ return {
656
+ "required": {
657
+ },
658
+ "optional": {
659
+ "hooks_A": ("HOOKS",),
660
+ "hooks_B": ("HOOKS",),
661
+ "hooks_C": ("HOOKS",),
662
+ "hooks_D": ("HOOKS",),
663
+ }
664
+ }
665
+
666
+ EXPERIMENTAL = True
667
+ RETURN_TYPES = ("HOOKS",)
668
+ CATEGORY = "advanced/hooks/combine"
669
+ FUNCTION = "combine_hooks"
670
+
671
+ def combine_hooks(self,
672
+ hooks_A: comfy.hooks.HookGroup=None,
673
+ hooks_B: comfy.hooks.HookGroup=None,
674
+ hooks_C: comfy.hooks.HookGroup=None,
675
+ hooks_D: comfy.hooks.HookGroup=None):
676
+ candidates = [hooks_A, hooks_B, hooks_C, hooks_D]
677
+ return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
678
+
679
+ class CombineHooksEight:
680
+ NodeId = 'CombineHooks8'
681
+ NodeName = 'Combine Hooks [8]'
682
+ @classmethod
683
+ def INPUT_TYPES(s):
684
+ return {
685
+ "required": {
686
+ },
687
+ "optional": {
688
+ "hooks_A": ("HOOKS",),
689
+ "hooks_B": ("HOOKS",),
690
+ "hooks_C": ("HOOKS",),
691
+ "hooks_D": ("HOOKS",),
692
+ "hooks_E": ("HOOKS",),
693
+ "hooks_F": ("HOOKS",),
694
+ "hooks_G": ("HOOKS",),
695
+ "hooks_H": ("HOOKS",),
696
+ }
697
+ }
698
+
699
+ EXPERIMENTAL = True
700
+ RETURN_TYPES = ("HOOKS",)
701
+ CATEGORY = "advanced/hooks/combine"
702
+ FUNCTION = "combine_hooks"
703
+
704
+ def combine_hooks(self,
705
+ hooks_A: comfy.hooks.HookGroup=None,
706
+ hooks_B: comfy.hooks.HookGroup=None,
707
+ hooks_C: comfy.hooks.HookGroup=None,
708
+ hooks_D: comfy.hooks.HookGroup=None,
709
+ hooks_E: comfy.hooks.HookGroup=None,
710
+ hooks_F: comfy.hooks.HookGroup=None,
711
+ hooks_G: comfy.hooks.HookGroup=None,
712
+ hooks_H: comfy.hooks.HookGroup=None):
713
+ candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H]
714
+ return (comfy.hooks.HookGroup.combine_all_hooks(candidates),)
715
+ #------------------------------------------
716
+ ###########################################
717
+
718
+ node_list = [
719
+ # Create
720
+ CreateHookLora,
721
+ CreateHookLoraModelOnly,
722
+ CreateHookModelAsLora,
723
+ CreateHookModelAsLoraModelOnly,
724
+ # Scheduling
725
+ SetHookKeyframes,
726
+ CreateHookKeyframe,
727
+ CreateHookKeyframesInterpolated,
728
+ CreateHookKeyframesFromFloats,
729
+ # Combine
730
+ CombineHooks,
731
+ CombineHooksFour,
732
+ CombineHooksEight,
733
+ # Attach
734
+ ConditioningSetProperties,
735
+ ConditioningSetPropertiesAndCombine,
736
+ PairConditioningSetProperties,
737
+ PairConditioningSetPropertiesAndCombine,
738
+ ConditioningSetDefaultAndCombine,
739
+ PairConditioningSetDefaultAndCombine,
740
+ PairConditioningCombine,
741
+ SetClipHooks,
742
+ # Other
743
+ ConditioningTimestepsRange,
744
+ ]
745
+ NODE_CLASS_MAPPINGS = {}
746
+ NODE_DISPLAY_NAME_MAPPINGS = {}
747
+
748
+ for node in node_list:
749
+ NODE_CLASS_MAPPINGS[node.NodeId] = node
750
+ NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
ComfyUI/comfy_extras/nodes_hunyuan.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import node_helpers
3
+ import torch
4
+ import comfy.model_management
5
+ from typing_extensions import override
6
+ from comfy_api.latest import ComfyExtension, io
7
+ from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
8
+ from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
9
+ import folder_paths
10
+ import json
11
+
12
+ class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
13
+ @classmethod
14
+ def define_schema(cls):
15
+ return io.Schema(
16
+ node_id="CLIPTextEncodeHunyuanDiT",
17
+ category="advanced/conditioning",
18
+ inputs=[
19
+ io.Clip.Input("clip"),
20
+ io.String.Input("bert", multiline=True, dynamic_prompts=True),
21
+ io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
22
+ ],
23
+ outputs=[
24
+ io.Conditioning.Output(),
25
+ ],
26
+ )
27
+
28
+ @classmethod
29
+ def execute(cls, clip, bert, mt5xl) -> io.NodeOutput:
30
+ tokens = clip.tokenize(bert)
31
+ tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
32
+
33
+ return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
34
+
35
+ encode = execute # TODO: remove
36
+
37
+
38
+ class EmptyHunyuanLatentVideo(io.ComfyNode):
39
+ @classmethod
40
+ def define_schema(cls):
41
+ return io.Schema(
42
+ node_id="EmptyHunyuanLatentVideo",
43
+ display_name="Empty HunyuanVideo 1.0 Latent",
44
+ category="latent/video",
45
+ inputs=[
46
+ io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
47
+ io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
48
+ io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
49
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
50
+ ],
51
+ outputs=[
52
+ io.Latent.Output(),
53
+ ],
54
+ )
55
+
56
+ @classmethod
57
+ def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
58
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
59
+ return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
60
+
61
+ generate = execute # TODO: remove
62
+
63
+
64
+ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
65
+ @classmethod
66
+ def define_schema(cls):
67
+ schema = super().define_schema()
68
+ schema.node_id = "EmptyHunyuanVideo15Latent"
69
+ schema.display_name = "Empty HunyuanVideo 1.5 Latent"
70
+ return schema
71
+
72
+ @classmethod
73
+ def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
74
+ # Using scale factor of 16 instead of 8
75
+ latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
76
+ return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16})
77
+
78
+
79
+ class HunyuanVideo15ImageToVideo(io.ComfyNode):
80
+ @classmethod
81
+ def define_schema(cls):
82
+ return io.Schema(
83
+ node_id="HunyuanVideo15ImageToVideo",
84
+ category="conditioning/video_models",
85
+ inputs=[
86
+ io.Conditioning.Input("positive"),
87
+ io.Conditioning.Input("negative"),
88
+ io.Vae.Input("vae"),
89
+ io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
90
+ io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
91
+ io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4),
92
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
93
+ io.Image.Input("start_image", optional=True),
94
+ io.ClipVisionOutput.Input("clip_vision_output", optional=True),
95
+ ],
96
+ outputs=[
97
+ io.Conditioning.Output(display_name="positive"),
98
+ io.Conditioning.Output(display_name="negative"),
99
+ io.Latent.Output(display_name="latent"),
100
+ ],
101
+ )
102
+
103
+ @classmethod
104
+ def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
105
+ latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
106
+
107
+ if start_image is not None:
108
+ start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
109
+
110
+ encoded = vae.encode(start_image[:, :, :, :3])
111
+ concat_latent_image = torch.zeros((latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device())
112
+ concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded
113
+
114
+ mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
115
+ mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
116
+
117
+ positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
118
+ negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
119
+
120
+ if clip_vision_output is not None:
121
+ positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
122
+ negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
123
+
124
+ out_latent = {}
125
+ out_latent["samples"] = latent
126
+ return io.NodeOutput(positive, negative, out_latent)
127
+
128
+
129
+ class HunyuanVideo15SuperResolution(io.ComfyNode):
130
+ @classmethod
131
+ def define_schema(cls):
132
+ return io.Schema(
133
+ node_id="HunyuanVideo15SuperResolution",
134
+ inputs=[
135
+ io.Conditioning.Input("positive"),
136
+ io.Conditioning.Input("negative"),
137
+ io.Vae.Input("vae", optional=True),
138
+ io.Image.Input("start_image", optional=True),
139
+ io.ClipVisionOutput.Input("clip_vision_output", optional=True),
140
+ io.Latent.Input("latent"),
141
+ io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01, advanced=True),
142
+
143
+ ],
144
+ outputs=[
145
+ io.Conditioning.Output(display_name="positive"),
146
+ io.Conditioning.Output(display_name="negative"),
147
+ io.Latent.Output(display_name="latent"),
148
+ ],
149
+ )
150
+
151
+ @classmethod
152
+ def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
153
+ in_latent = latent["samples"]
154
+ in_channels = in_latent.shape[1]
155
+ cond_latent = torch.zeros([in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]], device=comfy.model_management.intermediate_device())
156
+ cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent
157
+ cond_latent[:, 2 * in_channels + 1] = 1
158
+ if start_image is not None:
159
+ start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1)
160
+ encoded = vae.encode(start_image[:, :, :, :3])
161
+ cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded
162
+ cond_latent[:, in_channels + 1, 0] = 1
163
+
164
+ positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation})
165
+ negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation})
166
+ if clip_vision_output is not None:
167
+ positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
168
+ negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
169
+
170
+ return io.NodeOutput(positive, negative, latent)
171
+
172
+
173
+ class LatentUpscaleModelLoader(io.ComfyNode):
174
+ @classmethod
175
+ def define_schema(cls):
176
+ return io.Schema(
177
+ node_id="LatentUpscaleModelLoader",
178
+ display_name="Load Latent Upscale Model",
179
+ category="loaders",
180
+ inputs=[
181
+ io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")),
182
+ ],
183
+ outputs=[
184
+ io.LatentUpscaleModel.Output(),
185
+ ],
186
+ )
187
+
188
+ @classmethod
189
+ def execute(cls, model_name) -> io.NodeOutput:
190
+ model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
191
+ sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)
192
+
193
+ if "blocks.0.block.0.conv.weight" in sd:
194
+ config = {
195
+ "in_channels": sd["in_conv.conv.weight"].shape[1],
196
+ "out_channels": sd["out_conv.conv.weight"].shape[0],
197
+ "hidden_channels": sd["in_conv.conv.weight"].shape[0],
198
+ "num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]),
199
+ "global_residual": False,
200
+ }
201
+ model_type = "720p"
202
+ model = HunyuanVideo15SRModel(model_type, config)
203
+ model.load_sd(sd)
204
+ elif "up.0.block.0.conv1.conv.weight" in sd:
205
+ sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
206
+ config = {
207
+ "z_channels": sd["conv_in.conv.weight"].shape[1],
208
+ "out_channels": sd["conv_out.conv.weight"].shape[0],
209
+ "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
210
+ }
211
+ model_type = "1080p"
212
+ model = HunyuanVideo15SRModel(model_type, config)
213
+ model.load_sd(sd)
214
+ elif "post_upsample_res_blocks.0.conv2.bias" in sd:
215
+ config = json.loads(metadata["config"])
216
+ model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
217
+ model.load_state_dict(sd)
218
+
219
+ return io.NodeOutput(model)
220
+
221
+
222
+ class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode):
223
+ @classmethod
224
+ def define_schema(cls):
225
+ return io.Schema(
226
+ node_id="HunyuanVideo15LatentUpscaleWithModel",
227
+ display_name="Hunyuan Video 15 Latent Upscale With Model",
228
+ category="latent",
229
+ inputs=[
230
+ io.LatentUpscaleModel.Input("model"),
231
+ io.Latent.Input("samples"),
232
+ io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"),
233
+ io.Int.Input("width", default=1280, min=0, max=16384, step=8),
234
+ io.Int.Input("height", default=720, min=0, max=16384, step=8),
235
+ io.Combo.Input("crop", options=["disabled", "center"]),
236
+ ],
237
+ outputs=[
238
+ io.Latent.Output(),
239
+ ],
240
+ )
241
+
242
+ @classmethod
243
+ def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput:
244
+ if width == 0 and height == 0:
245
+ return io.NodeOutput(samples)
246
+ else:
247
+ if width == 0:
248
+ height = max(64, height)
249
+ width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
250
+ elif height == 0:
251
+ width = max(64, width)
252
+ height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
253
+ else:
254
+ width = max(64, width)
255
+ height = max(64, height)
256
+ s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop)
257
+ s = model.resample_latent(s)
258
+ return io.NodeOutput({"samples": s.cpu().float()})
259
+
260
+
261
+ PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
262
+ "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
263
+ "1. The main content and theme of the video."
264
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
265
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
266
+ "4. background environment, light, style and atmosphere."
267
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
268
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
269
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
270
+ )
271
+
272
+ class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
273
+ @classmethod
274
+ def define_schema(cls):
275
+ return io.Schema(
276
+ node_id="TextEncodeHunyuanVideo_ImageToVideo",
277
+ category="advanced/conditioning",
278
+ inputs=[
279
+ io.Clip.Input("clip"),
280
+ io.ClipVisionOutput.Input("clip_vision_output"),
281
+ io.String.Input("prompt", multiline=True, dynamic_prompts=True),
282
+ io.Int.Input(
283
+ "image_interleave",
284
+ default=2,
285
+ min=1,
286
+ max=512,
287
+ tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
288
+ advanced=True,
289
+ ),
290
+ ],
291
+ outputs=[
292
+ io.Conditioning.Output(),
293
+ ],
294
+ )
295
+
296
+ @classmethod
297
+ def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput:
298
+ tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
299
+ return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
300
+
301
+ encode = execute # TODO: remove
302
+
303
+
304
+ class HunyuanImageToVideo(io.ComfyNode):
305
+ @classmethod
306
+ def define_schema(cls):
307
+ return io.Schema(
308
+ node_id="HunyuanImageToVideo",
309
+ category="conditioning/video_models",
310
+ inputs=[
311
+ io.Conditioning.Input("positive"),
312
+ io.Vae.Input("vae"),
313
+ io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
314
+ io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
315
+ io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
316
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
317
+ io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"], advanced=True),
318
+ io.Image.Input("start_image", optional=True),
319
+ ],
320
+ outputs=[
321
+ io.Conditioning.Output(display_name="positive"),
322
+ io.Latent.Output(display_name="latent"),
323
+ ],
324
+ )
325
+
326
+ @classmethod
327
+ def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput:
328
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
329
+ out_latent = {}
330
+
331
+ if start_image is not None:
332
+ start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
333
+
334
+ concat_latent_image = vae.encode(start_image)
335
+ mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
336
+ mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
337
+
338
+ if guidance_type == "v1 (concat)":
339
+ cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
340
+ elif guidance_type == "v2 (replace)":
341
+ cond = {'guiding_frame_index': 0}
342
+ latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
343
+ out_latent["noise_mask"] = mask
344
+ elif guidance_type == "custom":
345
+ cond = {"ref_latent": concat_latent_image}
346
+
347
+ positive = node_helpers.conditioning_set_values(positive, cond)
348
+
349
+ out_latent["samples"] = latent
350
+ return io.NodeOutput(positive, out_latent)
351
+
352
+ encode = execute # TODO: remove
353
+
354
+
355
+ class EmptyHunyuanImageLatent(io.ComfyNode):
356
+ @classmethod
357
+ def define_schema(cls):
358
+ return io.Schema(
359
+ node_id="EmptyHunyuanImageLatent",
360
+ category="latent",
361
+ inputs=[
362
+ io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
363
+ io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
364
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
365
+ ],
366
+ outputs=[
367
+ io.Latent.Output(),
368
+ ],
369
+ )
370
+
371
+ @classmethod
372
+ def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
373
+ latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
374
+ return io.NodeOutput({"samples":latent})
375
+
376
+ generate = execute # TODO: remove
377
+
378
+
379
+ class HunyuanRefinerLatent(io.ComfyNode):
380
+ @classmethod
381
+ def define_schema(cls):
382
+ return io.Schema(
383
+ node_id="HunyuanRefinerLatent",
384
+ inputs=[
385
+ io.Conditioning.Input("positive"),
386
+ io.Conditioning.Input("negative"),
387
+ io.Latent.Input("latent"),
388
+ io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01, advanced=True),
389
+
390
+ ],
391
+ outputs=[
392
+ io.Conditioning.Output(display_name="positive"),
393
+ io.Conditioning.Output(display_name="negative"),
394
+ io.Latent.Output(display_name="latent"),
395
+ ],
396
+ )
397
+
398
+ @classmethod
399
+ def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput:
400
+ latent = latent["samples"]
401
+ positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
402
+ negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
403
+ out_latent = {}
404
+ out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
405
+ return io.NodeOutput(positive, negative, out_latent)
406
+
407
+
408
+ class HunyuanExtension(ComfyExtension):
409
+ @override
410
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
411
+ return [
412
+ CLIPTextEncodeHunyuanDiT,
413
+ TextEncodeHunyuanVideo_ImageToVideo,
414
+ EmptyHunyuanLatentVideo,
415
+ EmptyHunyuanVideo15Latent,
416
+ HunyuanVideo15ImageToVideo,
417
+ HunyuanVideo15SuperResolution,
418
+ HunyuanVideo15LatentUpscaleWithModel,
419
+ LatentUpscaleModelLoader,
420
+ HunyuanImageToVideo,
421
+ EmptyHunyuanImageLatent,
422
+ HunyuanRefinerLatent,
423
+ ]
424
+
425
+
426
+ async def comfy_entrypoint() -> HunyuanExtension:
427
+ return HunyuanExtension()
ComfyUI/comfy_extras/nodes_hunyuan3d.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import json
4
+ import struct
5
+ import numpy as np
6
+ from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
7
+ import folder_paths
8
+ import comfy.model_management
9
+ from comfy.cli_args import args
10
+ from typing_extensions import override
11
+ from comfy_api.latest import ComfyExtension, IO, Types
12
+ from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
13
+
14
+
15
+ class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
16
+ @classmethod
17
+ def define_schema(cls):
18
+ return IO.Schema(
19
+ node_id="EmptyLatentHunyuan3Dv2",
20
+ category="latent/3d",
21
+ inputs=[
22
+ IO.Int.Input("resolution", default=3072, min=1, max=8192),
23
+ IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
24
+ ],
25
+ outputs=[
26
+ IO.Latent.Output(),
27
+ ]
28
+ )
29
+
30
+ @classmethod
31
+ def execute(cls, resolution, batch_size) -> IO.NodeOutput:
32
+ latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
33
+ return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"})
34
+
35
+ generate = execute # TODO: remove
36
+
37
+
38
+ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
39
+ @classmethod
40
+ def define_schema(cls):
41
+ return IO.Schema(
42
+ node_id="Hunyuan3Dv2Conditioning",
43
+ category="conditioning/video_models",
44
+ inputs=[
45
+ IO.ClipVisionOutput.Input("clip_vision_output"),
46
+ ],
47
+ outputs=[
48
+ IO.Conditioning.Output(display_name="positive"),
49
+ IO.Conditioning.Output(display_name="negative"),
50
+ ]
51
+ )
52
+
53
+ @classmethod
54
+ def execute(cls, clip_vision_output) -> IO.NodeOutput:
55
+ embeds = clip_vision_output.last_hidden_state
56
+ positive = [[embeds, {}]]
57
+ negative = [[torch.zeros_like(embeds), {}]]
58
+ return IO.NodeOutput(positive, negative)
59
+
60
+ encode = execute # TODO: remove
61
+
62
+
63
+ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
64
+ @classmethod
65
+ def define_schema(cls):
66
+ return IO.Schema(
67
+ node_id="Hunyuan3Dv2ConditioningMultiView",
68
+ category="conditioning/video_models",
69
+ inputs=[
70
+ IO.ClipVisionOutput.Input("front", optional=True),
71
+ IO.ClipVisionOutput.Input("left", optional=True),
72
+ IO.ClipVisionOutput.Input("back", optional=True),
73
+ IO.ClipVisionOutput.Input("right", optional=True),
74
+ ],
75
+ outputs=[
76
+ IO.Conditioning.Output(display_name="positive"),
77
+ IO.Conditioning.Output(display_name="negative"),
78
+ ]
79
+ )
80
+
81
+ @classmethod
82
+ def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput:
83
+ all_embeds = [front, left, back, right]
84
+ out = []
85
+ pos_embeds = None
86
+ for i, e in enumerate(all_embeds):
87
+ if e is not None:
88
+ if pos_embeds is None:
89
+ pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
90
+ out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
91
+
92
+ embeds = torch.cat(out, dim=1)
93
+ positive = [[embeds, {}]]
94
+ negative = [[torch.zeros_like(embeds), {}]]
95
+ return IO.NodeOutput(positive, negative)
96
+
97
+ encode = execute # TODO: remove
98
+
99
+
100
+ class VAEDecodeHunyuan3D(IO.ComfyNode):
101
+ @classmethod
102
+ def define_schema(cls):
103
+ return IO.Schema(
104
+ node_id="VAEDecodeHunyuan3D",
105
+ category="latent/3d",
106
+ inputs=[
107
+ IO.Latent.Input("samples"),
108
+ IO.Vae.Input("vae"),
109
+ IO.Int.Input("num_chunks", default=8000, min=1000, max=500000, advanced=True),
110
+ IO.Int.Input("octree_resolution", default=256, min=16, max=512, advanced=True),
111
+ ],
112
+ outputs=[
113
+ IO.Voxel.Output(),
114
+ ]
115
+ )
116
+
117
+ @classmethod
118
+ def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput:
119
+ voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
120
+ return IO.NodeOutput(voxels)
121
+
122
+ decode = execute # TODO: remove
123
+
124
+
125
+ def voxel_to_mesh(voxels, threshold=0.5, device=None):
126
+ if device is None:
127
+ device = torch.device("cpu")
128
+ voxels = voxels.to(device)
129
+
130
+ binary = (voxels > threshold).float()
131
+ padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
132
+
133
+ D, H, W = binary.shape
134
+
135
+ neighbors = torch.tensor([
136
+ [0, 0, 1],
137
+ [0, 0, -1],
138
+ [0, 1, 0],
139
+ [0, -1, 0],
140
+ [1, 0, 0],
141
+ [-1, 0, 0]
142
+ ], device=device)
143
+
144
+ z, y, x = torch.meshgrid(
145
+ torch.arange(D, device=device),
146
+ torch.arange(H, device=device),
147
+ torch.arange(W, device=device),
148
+ indexing='ij'
149
+ )
150
+ voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
151
+
152
+ solid_mask = binary.flatten() > 0
153
+ solid_indices = voxel_indices[solid_mask]
154
+
155
+ corner_offsets = [
156
+ torch.tensor([
157
+ [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
158
+ ], device=device),
159
+ torch.tensor([
160
+ [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
161
+ ], device=device),
162
+ torch.tensor([
163
+ [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
164
+ ], device=device),
165
+ torch.tensor([
166
+ [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
167
+ ], device=device),
168
+ torch.tensor([
169
+ [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
170
+ ], device=device),
171
+ torch.tensor([
172
+ [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
173
+ ], device=device)
174
+ ]
175
+
176
+ all_vertices = []
177
+ all_indices = []
178
+
179
+ vertex_count = 0
180
+
181
+ for face_idx, offset in enumerate(neighbors):
182
+ neighbor_indices = solid_indices + offset
183
+
184
+ padded_indices = neighbor_indices + 1
185
+
186
+ is_exposed = padded[
187
+ padded_indices[:, 0],
188
+ padded_indices[:, 1],
189
+ padded_indices[:, 2]
190
+ ] == 0
191
+
192
+ if not is_exposed.any():
193
+ continue
194
+
195
+ exposed_indices = solid_indices[is_exposed]
196
+
197
+ corners = corner_offsets[face_idx].unsqueeze(0)
198
+
199
+ face_vertices = exposed_indices.unsqueeze(1) + corners
200
+
201
+ all_vertices.append(face_vertices.reshape(-1, 3))
202
+
203
+ num_faces = exposed_indices.shape[0]
204
+ face_indices = torch.arange(
205
+ vertex_count,
206
+ vertex_count + 4 * num_faces,
207
+ device=device
208
+ ).reshape(-1, 4)
209
+
210
+ all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1))
211
+ all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1))
212
+
213
+ vertex_count += 4 * num_faces
214
+
215
+ if len(all_vertices) > 0:
216
+ vertices = torch.cat(all_vertices, dim=0)
217
+ faces = torch.cat(all_indices, dim=0)
218
+ else:
219
+ vertices = torch.zeros((1, 3))
220
+ faces = torch.zeros((1, 3))
221
+
222
+ v_min = 0
223
+ v_max = max(voxels.shape)
224
+
225
+ vertices = vertices - (v_min + v_max) / 2
226
+
227
+ scale = (v_max - v_min) / 2
228
+ if scale > 0:
229
+ vertices = vertices / scale
230
+
231
+ vertices = torch.fliplr(vertices)
232
+ return vertices, faces
233
+
234
+ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
235
+ if device is None:
236
+ device = torch.device("cpu")
237
+ voxels = voxels.to(device)
238
+
239
+ D, H, W = voxels.shape
240
+
241
+ padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
242
+ z, y, x = torch.meshgrid(
243
+ torch.arange(D, device=device),
244
+ torch.arange(H, device=device),
245
+ torch.arange(W, device=device),
246
+ indexing='ij'
247
+ )
248
+ cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
249
+
250
+ corner_offsets = torch.tensor([
251
+ [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
252
+ [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
253
+ ], device=device)
254
+
255
+ pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0)
256
+ z_idx, y_idx, x_idx = pos.unbind(-1)
257
+ corner_values = padded[z_idx, y_idx, x_idx]
258
+
259
+ corner_signs = corner_values > threshold
260
+ has_inside = torch.any(corner_signs, dim=1)
261
+ has_outside = torch.any(~corner_signs, dim=1)
262
+ contains_surface = has_inside & has_outside
263
+
264
+ active_cells = cell_positions[contains_surface]
265
+ active_signs = corner_signs[contains_surface]
266
+ active_values = corner_values[contains_surface]
267
+
268
+ if active_cells.shape[0] == 0:
269
+ return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
270
+
271
+ edges = torch.tensor([
272
+ [0, 1], [0, 2], [0, 4], [1, 3],
273
+ [1, 5], [2, 3], [2, 6], [3, 7],
274
+ [4, 5], [4, 6], [5, 7], [6, 7]
275
+ ], device=device)
276
+
277
+ cell_vertices = {}
278
+ progress = comfy.utils.ProgressBar(100)
279
+
280
+ for edge_idx, (e1, e2) in enumerate(edges):
281
+ progress.update(1)
282
+ crossing = active_signs[:, e1] != active_signs[:, e2]
283
+ if not crossing.any():
284
+ continue
285
+
286
+ cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
287
+
288
+ v1 = active_values[cell_indices, e1]
289
+ v2 = active_values[cell_indices, e2]
290
+
291
+ t = torch.zeros_like(v1, device=device)
292
+ denom = v2 - v1
293
+ valid = denom != 0
294
+ t[valid] = (threshold - v1[valid]) / denom[valid]
295
+ t[~valid] = 0.5
296
+
297
+ p1 = corner_offsets[e1].float()
298
+ p2 = corner_offsets[e2].float()
299
+
300
+ intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
301
+
302
+ for i, point in zip(cell_indices.tolist(), intersection):
303
+ if i not in cell_vertices:
304
+ cell_vertices[i] = []
305
+ cell_vertices[i].append(point)
306
+
307
+ # Calculate the final vertices as the average of intersection points for each cell
308
+ vertices = []
309
+ vertex_lookup = {}
310
+
311
+ vert_progress_mod = round(len(cell_vertices)/50)
312
+
313
+ for i, points in cell_vertices.items():
314
+ if not i % vert_progress_mod:
315
+ progress.update(1)
316
+
317
+ if points:
318
+ vertex = torch.stack(points).mean(dim=0)
319
+ vertex = vertex + active_cells[i].float()
320
+ vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
321
+ vertices.append(vertex)
322
+
323
+ if not vertices:
324
+ return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
325
+
326
+ final_vertices = torch.stack(vertices)
327
+
328
+ inside_corners_mask = active_signs
329
+ outside_corners_mask = ~active_signs
330
+
331
+ inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
332
+ outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
333
+
334
+ inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
335
+ outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
336
+
337
+ for i in range(8):
338
+ mask_inside = inside_corners_mask[:, i].unsqueeze(1)
339
+ mask_outside = outside_corners_mask[:, i].unsqueeze(1)
340
+ inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
341
+ outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
342
+
343
+ inside_pos /= inside_counts
344
+ outside_pos /= outside_counts
345
+ gradients = inside_pos - outside_pos
346
+
347
+ pos_dirs = torch.tensor([
348
+ [1, 0, 0],
349
+ [0, 1, 0],
350
+ [0, 0, 1]
351
+ ], device=device)
352
+
353
+ cross_products = [
354
+ torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
355
+ for i in range(3) for j in range(i+1, 3)
356
+ ]
357
+
358
+ faces = []
359
+ all_keys = set(vertex_lookup.keys())
360
+
361
+ face_progress_mod = round(len(active_cells)/38*3)
362
+
363
+ for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
364
+ dir_i = pos_dirs[i]
365
+ dir_j = pos_dirs[j]
366
+ cross_product = cross_products[pair_idx]
367
+
368
+ ni_positions = active_cells + dir_i
369
+ nj_positions = active_cells + dir_j
370
+ diag_positions = active_cells + dir_i + dir_j
371
+
372
+ alignments = torch.matmul(gradients, cross_product)
373
+
374
+ valid_quads = []
375
+ quad_indices = []
376
+
377
+ for idx, active_cell in enumerate(active_cells):
378
+ if not idx % face_progress_mod:
379
+ progress.update(1)
380
+ cell_key = tuple(active_cell.tolist())
381
+ ni_key = tuple(ni_positions[idx].tolist())
382
+ nj_key = tuple(nj_positions[idx].tolist())
383
+ diag_key = tuple(diag_positions[idx].tolist())
384
+
385
+ if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
386
+ v0 = vertex_lookup[cell_key]
387
+ v1 = vertex_lookup[ni_key]
388
+ v2 = vertex_lookup[nj_key]
389
+ v3 = vertex_lookup[diag_key]
390
+
391
+ valid_quads.append((v0, v1, v2, v3))
392
+ quad_indices.append(idx)
393
+
394
+ for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
395
+ cell_idx = quad_indices[q_idx]
396
+ if alignments[cell_idx] > 0:
397
+ faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
398
+ faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
399
+ else:
400
+ faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
401
+ faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
402
+
403
+ if faces:
404
+ faces = torch.stack(faces)
405
+ else:
406
+ faces = torch.zeros((0, 3), dtype=torch.long, device=device)
407
+
408
+ v_min = 0
409
+ v_max = max(D, H, W)
410
+
411
+ final_vertices = final_vertices - (v_min + v_max) / 2
412
+
413
+ scale = (v_max - v_min) / 2
414
+ if scale > 0:
415
+ final_vertices = final_vertices / scale
416
+
417
+ final_vertices = torch.fliplr(final_vertices)
418
+
419
+ return final_vertices, faces
420
+
421
+
422
+ class VoxelToMeshBasic(IO.ComfyNode):
423
+ @classmethod
424
+ def define_schema(cls):
425
+ return IO.Schema(
426
+ node_id="VoxelToMeshBasic",
427
+ category="3d",
428
+ inputs=[
429
+ IO.Voxel.Input("voxel"),
430
+ IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
431
+ ],
432
+ outputs=[
433
+ IO.Mesh.Output(),
434
+ ]
435
+ )
436
+
437
+ @classmethod
438
+ def execute(cls, voxel, threshold) -> IO.NodeOutput:
439
+ vertices = []
440
+ faces = []
441
+ for x in voxel.data:
442
+ v, f = voxel_to_mesh(x, threshold=threshold, device=None)
443
+ vertices.append(v)
444
+ faces.append(f)
445
+
446
+ return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
447
+
448
+ decode = execute # TODO: remove
449
+
450
+
451
+ class VoxelToMesh(IO.ComfyNode):
452
+ @classmethod
453
+ def define_schema(cls):
454
+ return IO.Schema(
455
+ node_id="VoxelToMesh",
456
+ category="3d",
457
+ inputs=[
458
+ IO.Voxel.Input("voxel"),
459
+ IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True),
460
+ IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
461
+ ],
462
+ outputs=[
463
+ IO.Mesh.Output(),
464
+ ]
465
+ )
466
+
467
+ @classmethod
468
+ def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput:
469
+ vertices = []
470
+ faces = []
471
+
472
+ if algorithm == "basic":
473
+ mesh_function = voxel_to_mesh
474
+ elif algorithm == "surface net":
475
+ mesh_function = voxel_to_mesh_surfnet
476
+
477
+ for x in voxel.data:
478
+ v, f = mesh_function(x, threshold=threshold, device=None)
479
+ vertices.append(v)
480
+ faces.append(f)
481
+
482
+ return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
483
+
484
+ decode = execute # TODO: remove
485
+
486
+
487
+ def save_glb(vertices, faces, filepath, metadata=None):
488
+ """
489
+ Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
490
+
491
+ Parameters:
492
+ vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
493
+ faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
494
+ filepath: str - Output filepath (should end with .glb)
495
+ """
496
+
497
+ # Convert tensors to numpy arrays
498
+ vertices_np = vertices.cpu().numpy().astype(np.float32)
499
+ faces_np = faces.cpu().numpy().astype(np.uint32)
500
+
501
+ vertices_buffer = vertices_np.tobytes()
502
+ indices_buffer = faces_np.tobytes()
503
+
504
+ def pad_to_4_bytes(buffer):
505
+ padding_length = (4 - (len(buffer) % 4)) % 4
506
+ return buffer + b'\x00' * padding_length
507
+
508
+ vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
509
+ indices_buffer_padded = pad_to_4_bytes(indices_buffer)
510
+
511
+ buffer_data = vertices_buffer_padded + indices_buffer_padded
512
+
513
+ vertices_byte_length = len(vertices_buffer)
514
+ vertices_byte_offset = 0
515
+ indices_byte_length = len(indices_buffer)
516
+ indices_byte_offset = len(vertices_buffer_padded)
517
+
518
+ gltf = {
519
+ "asset": {"version": "2.0", "generator": "ComfyUI"},
520
+ "buffers": [
521
+ {
522
+ "byteLength": len(buffer_data)
523
+ }
524
+ ],
525
+ "bufferViews": [
526
+ {
527
+ "buffer": 0,
528
+ "byteOffset": vertices_byte_offset,
529
+ "byteLength": vertices_byte_length,
530
+ "target": 34962 # ARRAY_BUFFER
531
+ },
532
+ {
533
+ "buffer": 0,
534
+ "byteOffset": indices_byte_offset,
535
+ "byteLength": indices_byte_length,
536
+ "target": 34963 # ELEMENT_ARRAY_BUFFER
537
+ }
538
+ ],
539
+ "accessors": [
540
+ {
541
+ "bufferView": 0,
542
+ "byteOffset": 0,
543
+ "componentType": 5126, # FLOAT
544
+ "count": len(vertices_np),
545
+ "type": "VEC3",
546
+ "max": vertices_np.max(axis=0).tolist(),
547
+ "min": vertices_np.min(axis=0).tolist()
548
+ },
549
+ {
550
+ "bufferView": 1,
551
+ "byteOffset": 0,
552
+ "componentType": 5125, # UNSIGNED_INT
553
+ "count": faces_np.size,
554
+ "type": "SCALAR"
555
+ }
556
+ ],
557
+ "meshes": [
558
+ {
559
+ "primitives": [
560
+ {
561
+ "attributes": {
562
+ "POSITION": 0
563
+ },
564
+ "indices": 1,
565
+ "mode": 4 # TRIANGLES
566
+ }
567
+ ]
568
+ }
569
+ ],
570
+ "nodes": [
571
+ {
572
+ "mesh": 0
573
+ }
574
+ ],
575
+ "scenes": [
576
+ {
577
+ "nodes": [0]
578
+ }
579
+ ],
580
+ "scene": 0
581
+ }
582
+
583
+ if metadata is not None:
584
+ gltf["asset"]["extras"] = metadata
585
+
586
+ # Convert the JSON to bytes
587
+ gltf_json = json.dumps(gltf).encode('utf8')
588
+
589
+ def pad_json_to_4_bytes(buffer):
590
+ padding_length = (4 - (len(buffer) % 4)) % 4
591
+ return buffer + b' ' * padding_length
592
+
593
+ gltf_json_padded = pad_json_to_4_bytes(gltf_json)
594
+
595
+ # Create the GLB header
596
+ # Magic glTF
597
+ glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
598
+
599
+ # Create JSON chunk header (chunk type 0)
600
+ json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
601
+
602
+ # Create BIN chunk header (chunk type 1)
603
+ bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
604
+
605
+ # Write the GLB file
606
+ with open(filepath, 'wb') as f:
607
+ f.write(glb_header)
608
+ f.write(json_chunk_header)
609
+ f.write(gltf_json_padded)
610
+ f.write(bin_chunk_header)
611
+ f.write(buffer_data)
612
+
613
+ return filepath
614
+
615
+
616
+ class SaveGLB(IO.ComfyNode):
617
+ @classmethod
618
+ def define_schema(cls):
619
+ return IO.Schema(
620
+ node_id="SaveGLB",
621
+ display_name="Save 3D Model",
622
+ search_aliases=["export 3d model", "save mesh"],
623
+ category="3d",
624
+ essentials_category="Basics",
625
+ is_output_node=True,
626
+ inputs=[
627
+ IO.MultiType.Input(
628
+ IO.Mesh.Input("mesh"),
629
+ types=[
630
+ IO.File3DGLB,
631
+ IO.File3DGLTF,
632
+ IO.File3DOBJ,
633
+ IO.File3DFBX,
634
+ IO.File3DSTL,
635
+ IO.File3DUSDZ,
636
+ IO.File3DAny,
637
+ ],
638
+ tooltip="Mesh or 3D file to save",
639
+ ),
640
+ IO.String.Input("filename_prefix", default="3d/ComfyUI"),
641
+ ],
642
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
643
+ )
644
+
645
+ @classmethod
646
+ def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
647
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
648
+ results = []
649
+
650
+ metadata = {}
651
+ if not args.disable_metadata:
652
+ if cls.hidden.prompt is not None:
653
+ metadata["prompt"] = json.dumps(cls.hidden.prompt)
654
+ if cls.hidden.extra_pnginfo is not None:
655
+ for x in cls.hidden.extra_pnginfo:
656
+ metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
657
+
658
+ if isinstance(mesh, Types.File3D):
659
+ # Handle File3D input - save BytesIO data to output folder
660
+ ext = mesh.format or "glb"
661
+ f = f"{filename}_{counter:05}_.{ext}"
662
+ mesh.save_to(os.path.join(full_output_folder, f))
663
+ results.append({
664
+ "filename": f,
665
+ "subfolder": subfolder,
666
+ "type": "output"
667
+ })
668
+ else:
669
+ # Handle Mesh input - save vertices and faces as GLB
670
+ for i in range(mesh.vertices.shape[0]):
671
+ f = f"{filename}_{counter:05}_.glb"
672
+ save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
673
+ results.append({
674
+ "filename": f,
675
+ "subfolder": subfolder,
676
+ "type": "output"
677
+ })
678
+ counter += 1
679
+ return IO.NodeOutput(ui={"3d": results})
680
+
681
+
682
+ class Hunyuan3dExtension(ComfyExtension):
683
+ @override
684
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
685
+ return [
686
+ EmptyLatentHunyuan3Dv2,
687
+ Hunyuan3Dv2Conditioning,
688
+ Hunyuan3Dv2ConditioningMultiView,
689
+ VAEDecodeHunyuan3D,
690
+ VoxelToMeshBasic,
691
+ VoxelToMesh,
692
+ SaveGLB,
693
+ ]
694
+
695
+
696
+ async def comfy_entrypoint() -> Hunyuan3dExtension:
697
+ return Hunyuan3dExtension()
ComfyUI/comfy_extras/nodes_hypernetwork.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.utils
2
+ import folder_paths
3
+ import torch
4
+ import logging
5
+ from comfy_api.latest import IO, ComfyExtension
6
+ from typing_extensions import override
7
+
8
+
9
+ def load_hypernetwork_patch(path, strength):
10
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
11
+ activation_func = sd.get('activation_func', 'linear')
12
+ is_layer_norm = sd.get('is_layer_norm', False)
13
+ use_dropout = sd.get('use_dropout', False)
14
+ activate_output = sd.get('activate_output', False)
15
+ last_layer_dropout = sd.get('last_layer_dropout', False)
16
+
17
+ valid_activation = {
18
+ "linear": torch.nn.Identity,
19
+ "relu": torch.nn.ReLU,
20
+ "leakyrelu": torch.nn.LeakyReLU,
21
+ "elu": torch.nn.ELU,
22
+ "swish": torch.nn.Hardswish,
23
+ "tanh": torch.nn.Tanh,
24
+ "sigmoid": torch.nn.Sigmoid,
25
+ "softsign": torch.nn.Softsign,
26
+ "mish": torch.nn.Mish,
27
+ }
28
+
29
+ if activation_func not in valid_activation:
30
+ logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
31
+ return None
32
+
33
+ out = {}
34
+
35
+ for d in sd:
36
+ try:
37
+ dim = int(d)
38
+ except:
39
+ continue
40
+
41
+ output = []
42
+ for index in [0, 1]:
43
+ attn_weights = sd[dim][index]
44
+ keys = attn_weights.keys()
45
+
46
+ linears = filter(lambda a: a.endswith(".weight"), keys)
47
+ linears = list(map(lambda a: a[:-len(".weight")], linears))
48
+ layers = []
49
+
50
+ i = 0
51
+ while i < len(linears):
52
+ lin_name = linears[i]
53
+ last_layer = (i == (len(linears) - 1))
54
+ penultimate_layer = (i == (len(linears) - 2))
55
+
56
+ lin_weight = attn_weights['{}.weight'.format(lin_name)]
57
+ lin_bias = attn_weights['{}.bias'.format(lin_name)]
58
+ layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
59
+ layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
60
+ layers.append(layer)
61
+ if activation_func != "linear":
62
+ if (not last_layer) or (activate_output):
63
+ layers.append(valid_activation[activation_func]())
64
+ if is_layer_norm:
65
+ i += 1
66
+ ln_name = linears[i]
67
+ ln_weight = attn_weights['{}.weight'.format(ln_name)]
68
+ ln_bias = attn_weights['{}.bias'.format(ln_name)]
69
+ ln = torch.nn.LayerNorm(ln_weight.shape[0])
70
+ ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
71
+ layers.append(ln)
72
+ if use_dropout:
73
+ if (not last_layer) and (not penultimate_layer or last_layer_dropout):
74
+ layers.append(torch.nn.Dropout(p=0.3))
75
+ i += 1
76
+
77
+ output.append(torch.nn.Sequential(*layers))
78
+ out[dim] = torch.nn.ModuleList(output)
79
+
80
+ class hypernetwork_patch:
81
+ def __init__(self, hypernet, strength):
82
+ self.hypernet = hypernet
83
+ self.strength = strength
84
+ def __call__(self, q, k, v, extra_options):
85
+ dim = k.shape[-1]
86
+ if dim in self.hypernet:
87
+ hn = self.hypernet[dim]
88
+ k = k + hn[0](k) * self.strength
89
+ v = v + hn[1](v) * self.strength
90
+
91
+ return q, k, v
92
+
93
+ def to(self, device):
94
+ for d in self.hypernet.keys():
95
+ self.hypernet[d] = self.hypernet[d].to(device)
96
+ return self
97
+
98
+ return hypernetwork_patch(out, strength)
99
+
100
+ class HypernetworkLoader(IO.ComfyNode):
101
+ @classmethod
102
+ def define_schema(cls):
103
+ return IO.Schema(
104
+ node_id="HypernetworkLoader",
105
+ category="loaders",
106
+ inputs=[
107
+ IO.Model.Input("model"),
108
+ IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
109
+ IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
110
+ ],
111
+ outputs=[
112
+ IO.Model.Output(),
113
+ ],
114
+ )
115
+
116
+ @classmethod
117
+ def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
118
+ hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
119
+ model_hypernetwork = model.clone()
120
+ patch = load_hypernetwork_patch(hypernetwork_path, strength)
121
+ if patch is not None:
122
+ model_hypernetwork.set_model_attn1_patch(patch)
123
+ model_hypernetwork.set_model_attn2_patch(patch)
124
+ return IO.NodeOutput(model_hypernetwork)
125
+
126
+ load_hypernetwork = execute # TODO: remove
127
+
128
+
129
+ class HyperNetworkExtension(ComfyExtension):
130
+ @override
131
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
132
+ return [
133
+ HypernetworkLoader,
134
+ ]
135
+
136
+
137
+ async def comfy_entrypoint() -> HyperNetworkExtension:
138
+ return HyperNetworkExtension()
ComfyUI/comfy_extras/nodes_hypertile.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Taken from: https://github.com/tfernd/HyperTile/
2
+
3
+ import math
4
+ from typing_extensions import override
5
+ from einops import rearrange
6
+ # Use torch rng for consistency across generations
7
+ from torch import randint
8
+ from comfy_api.latest import ComfyExtension, io
9
+
10
+ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
11
+ min_value = min(min_value, value)
12
+
13
+ # All big divisors of value (inclusive)
14
+ divisors = [i for i in range(min_value, value + 1) if value % i == 0]
15
+
16
+ ns = [value // i for i in divisors[:max_options]] # has at least 1 element
17
+
18
+ if len(ns) - 1 > 0:
19
+ idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
20
+ else:
21
+ idx = 0
22
+
23
+ return ns[idx]
24
+
25
+ class HyperTile(io.ComfyNode):
26
+ @classmethod
27
+ def define_schema(cls):
28
+ return io.Schema(
29
+ node_id="HyperTile",
30
+ category="model_patches/unet",
31
+ inputs=[
32
+ io.Model.Input("model"),
33
+ io.Int.Input("tile_size", default=256, min=1, max=2048, advanced=True),
34
+ io.Int.Input("swap_size", default=2, min=1, max=128, advanced=True),
35
+ io.Int.Input("max_depth", default=0, min=0, max=10, advanced=True),
36
+ io.Boolean.Input("scale_depth", default=False, advanced=True),
37
+ ],
38
+ outputs=[
39
+ io.Model.Output(),
40
+ ],
41
+ )
42
+
43
+ @classmethod
44
+ def execute(cls, model, tile_size, swap_size, max_depth, scale_depth) -> io.NodeOutput:
45
+ latent_tile_size = max(32, tile_size) // 8
46
+ temp = None
47
+
48
+ def hypertile_in(q, k, v, extra_options):
49
+ nonlocal temp
50
+ model_chans = q.shape[-2]
51
+ orig_shape = extra_options['original_shape']
52
+ apply_to = []
53
+ for i in range(max_depth + 1):
54
+ apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
55
+
56
+ if model_chans in apply_to:
57
+ shape = extra_options["original_shape"]
58
+ aspect_ratio = shape[-1] / shape[-2]
59
+
60
+ hw = q.size(1)
61
+ h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
62
+
63
+ factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
64
+ nh = random_divisor(h, latent_tile_size * factor, swap_size)
65
+ nw = random_divisor(w, latent_tile_size * factor, swap_size)
66
+
67
+ if nh * nw > 1:
68
+ q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
69
+ temp = (nh, nw, h, w)
70
+ return q, k, v
71
+
72
+ return q, k, v
73
+ def hypertile_out(out, extra_options):
74
+ nonlocal temp
75
+ if temp is not None:
76
+ nh, nw, h, w = temp
77
+ temp = None
78
+ out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
79
+ out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
80
+ return out
81
+
82
+
83
+ m = model.clone()
84
+ m.set_model_attn1_patch(hypertile_in)
85
+ m.set_model_attn1_output_patch(hypertile_out)
86
+ return (m, )
87
+
88
+
89
+ class HyperTileExtension(ComfyExtension):
90
+ @override
91
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
92
+ return [
93
+ HyperTile,
94
+ ]
95
+
96
+
97
+ async def comfy_entrypoint() -> HyperTileExtension:
98
+ return HyperTileExtension()
ComfyUI/comfy_extras/nodes_image_compare.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+
3
+ from typing_extensions import override
4
+ from comfy_api.latest import IO, ComfyExtension
5
+
6
+
7
+ class ImageCompare(IO.ComfyNode):
8
+ """Compares two images with a slider interface."""
9
+
10
+ @classmethod
11
+ def define_schema(cls):
12
+ return IO.Schema(
13
+ node_id="ImageCompare",
14
+ display_name="Image Compare",
15
+ description="Compares two images side by side with a slider.",
16
+ category="image",
17
+ essentials_category="Image Tools",
18
+ is_experimental=True,
19
+ is_output_node=True,
20
+ inputs=[
21
+ IO.Image.Input("image_a", optional=True),
22
+ IO.Image.Input("image_b", optional=True),
23
+ IO.ImageCompare.Input("compare_view"),
24
+ ],
25
+ outputs=[],
26
+ )
27
+
28
+ @classmethod
29
+ def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput:
30
+ result = {"a_images": [], "b_images": []}
31
+
32
+ preview_node = nodes.PreviewImage()
33
+
34
+ if image_a is not None and len(image_a) > 0:
35
+ saved = preview_node.save_images(image_a, "comfy.compare.a")
36
+ result["a_images"] = saved["ui"]["images"]
37
+
38
+ if image_b is not None and len(image_b) > 0:
39
+ saved = preview_node.save_images(image_b, "comfy.compare.b")
40
+ result["b_images"] = saved["ui"]["images"]
41
+
42
+ return IO.NodeOutput(ui=result)
43
+
44
+
45
+ class ImageCompareExtension(ComfyExtension):
46
+ @override
47
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
48
+ return [
49
+ ImageCompare,
50
+ ]
51
+
52
+
53
+ async def comfy_entrypoint() -> ImageCompareExtension:
54
+ return ImageCompareExtension()
ComfyUI/comfy_extras/nodes_images.py ADDED
@@ -0,0 +1,851 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import nodes
4
+ import folder_paths
5
+
6
+ import json
7
+ import os
8
+ import re
9
+ import math
10
+ import torch
11
+ import comfy.utils
12
+
13
+ from server import PromptServer
14
+ from comfy_api.latest import ComfyExtension, IO, UI
15
+ from typing_extensions import override
16
+
17
+ SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
18
+
19
+ MAX_RESOLUTION = nodes.MAX_RESOLUTION
20
+
21
+ class ImageCrop(IO.ComfyNode):
22
+ @classmethod
23
+ def define_schema(cls):
24
+ return IO.Schema(
25
+ node_id="ImageCrop",
26
+ search_aliases=["trim"],
27
+ display_name="Image Crop (Deprecated)",
28
+ category="image/transform",
29
+ is_deprecated=True,
30
+ essentials_category="Image Tools",
31
+ inputs=[
32
+ IO.Image.Input("image"),
33
+ IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
34
+ IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
35
+ IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
36
+ IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
37
+ ],
38
+ outputs=[IO.Image.Output()],
39
+ )
40
+
41
+ @classmethod
42
+ def execute(cls, image, width, height, x, y) -> IO.NodeOutput:
43
+ x = min(x, image.shape[2] - 1)
44
+ y = min(y, image.shape[1] - 1)
45
+ to_x = width + x
46
+ to_y = height + y
47
+ img = image[:,y:to_y, x:to_x, :]
48
+ return IO.NodeOutput(img)
49
+
50
+ crop = execute # TODO: remove
51
+
52
+
53
+ class ImageCropV2(IO.ComfyNode):
54
+ @classmethod
55
+ def define_schema(cls):
56
+ return IO.Schema(
57
+ node_id="ImageCropV2",
58
+ search_aliases=["trim"],
59
+ display_name="Image Crop",
60
+ category="image/transform",
61
+ essentials_category="Image Tools",
62
+ has_intermediate_output=True,
63
+ inputs=[
64
+ IO.Image.Input("image"),
65
+ IO.BoundingBox.Input("crop_region", component="ImageCrop"),
66
+ ],
67
+ outputs=[IO.Image.Output()],
68
+ )
69
+
70
+ @classmethod
71
+ def execute(cls, image, crop_region) -> IO.NodeOutput:
72
+ x = crop_region.get("x", 0)
73
+ y = crop_region.get("y", 0)
74
+ width = crop_region.get("width", 512)
75
+ height = crop_region.get("height", 512)
76
+
77
+ x = min(x, image.shape[2] - 1)
78
+ y = min(y, image.shape[1] - 1)
79
+ to_x = width + x
80
+ to_y = height + y
81
+ img = image[:,y:to_y, x:to_x, :]
82
+ return IO.NodeOutput(img, ui=UI.PreviewImage(img))
83
+
84
+
85
+ class BoundingBox(IO.ComfyNode):
86
+ @classmethod
87
+ def define_schema(cls):
88
+ return IO.Schema(
89
+ node_id="PrimitiveBoundingBox",
90
+ display_name="Bounding Box",
91
+ category="utils/primitive",
92
+ inputs=[
93
+ IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
94
+ IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
95
+ IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
96
+ IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
97
+ ],
98
+ outputs=[IO.BoundingBox.Output()],
99
+ )
100
+
101
+ @classmethod
102
+ def execute(cls, x, y, width, height) -> IO.NodeOutput:
103
+ return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height})
104
+
105
+
106
+ class RepeatImageBatch(IO.ComfyNode):
107
+ @classmethod
108
+ def define_schema(cls):
109
+ return IO.Schema(
110
+ node_id="RepeatImageBatch",
111
+ search_aliases=["duplicate image", "clone image"],
112
+ category="image/batch",
113
+ inputs=[
114
+ IO.Image.Input("image"),
115
+ IO.Int.Input("amount", default=1, min=1, max=4096),
116
+ ],
117
+ outputs=[IO.Image.Output()],
118
+ )
119
+
120
+ @classmethod
121
+ def execute(cls, image, amount) -> IO.NodeOutput:
122
+ s = image.repeat((amount, 1,1,1))
123
+ return IO.NodeOutput(s)
124
+
125
+ repeat = execute # TODO: remove
126
+
127
+
128
+ class ImageFromBatch(IO.ComfyNode):
129
+ @classmethod
130
+ def define_schema(cls):
131
+ return IO.Schema(
132
+ node_id="ImageFromBatch",
133
+ search_aliases=["select image", "pick from batch", "extract image"],
134
+ category="image/batch",
135
+ inputs=[
136
+ IO.Image.Input("image"),
137
+ IO.Int.Input("batch_index", default=0, min=0, max=4095),
138
+ IO.Int.Input("length", default=1, min=1, max=4096),
139
+ ],
140
+ outputs=[IO.Image.Output()],
141
+ )
142
+
143
+ @classmethod
144
+ def execute(cls, image, batch_index, length) -> IO.NodeOutput:
145
+ s_in = image
146
+ batch_index = min(s_in.shape[0] - 1, batch_index)
147
+ length = min(s_in.shape[0] - batch_index, length)
148
+ s = s_in[batch_index:batch_index + length].clone()
149
+ return IO.NodeOutput(s)
150
+
151
+ frombatch = execute # TODO: remove
152
+
153
+
154
+ class ImageAddNoise(IO.ComfyNode):
155
+ @classmethod
156
+ def define_schema(cls):
157
+ return IO.Schema(
158
+ node_id="ImageAddNoise",
159
+ search_aliases=["film grain"],
160
+ category="image",
161
+ inputs=[
162
+ IO.Image.Input("image"),
163
+ IO.Int.Input(
164
+ "seed",
165
+ default=0,
166
+ min=0,
167
+ max=0xFFFFFFFFFFFFFFFF,
168
+ control_after_generate=True,
169
+ tooltip="The random seed used for creating the noise.",
170
+ ),
171
+ IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01),
172
+ ],
173
+ outputs=[IO.Image.Output()],
174
+ )
175
+
176
+ @classmethod
177
+ def execute(cls, image, seed, strength) -> IO.NodeOutput:
178
+ generator = torch.manual_seed(seed)
179
+ s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
180
+ return IO.NodeOutput(s)
181
+
182
+ repeat = execute # TODO: remove
183
+
184
+
185
+ class SaveAnimatedWEBP(IO.ComfyNode):
186
+ COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6}
187
+
188
+ @classmethod
189
+ def define_schema(cls):
190
+ return IO.Schema(
191
+ node_id="SaveAnimatedWEBP",
192
+ category="image/animation",
193
+ inputs=[
194
+ IO.Image.Input("images"),
195
+ IO.String.Input("filename_prefix", default="ComfyUI"),
196
+ IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
197
+ IO.Boolean.Input("lossless", default=True),
198
+ IO.Int.Input("quality", default=80, min=0, max=100),
199
+ IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())),
200
+ # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
201
+ ],
202
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
203
+ is_output_node=True,
204
+ )
205
+
206
+ @classmethod
207
+ def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput:
208
+ return IO.NodeOutput(
209
+ ui=UI.ImageSaveHelper.get_save_animated_webp_ui(
210
+ images=images,
211
+ filename_prefix=filename_prefix,
212
+ cls=cls,
213
+ fps=fps,
214
+ lossless=lossless,
215
+ quality=quality,
216
+ method=cls.COMPRESS_METHODS.get(method)
217
+ )
218
+ )
219
+
220
+ save_images = execute # TODO: remove
221
+
222
+
223
+ class SaveAnimatedPNG(IO.ComfyNode):
224
+
225
+ @classmethod
226
+ def define_schema(cls):
227
+ return IO.Schema(
228
+ node_id="SaveAnimatedPNG",
229
+ category="image/animation",
230
+ inputs=[
231
+ IO.Image.Input("images"),
232
+ IO.String.Input("filename_prefix", default="ComfyUI"),
233
+ IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01),
234
+ IO.Int.Input("compress_level", default=4, min=0, max=9, advanced=True),
235
+ ],
236
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
237
+ is_output_node=True,
238
+ )
239
+
240
+ @classmethod
241
+ def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput:
242
+ return IO.NodeOutput(
243
+ ui=UI.ImageSaveHelper.get_save_animated_png_ui(
244
+ images=images,
245
+ filename_prefix=filename_prefix,
246
+ cls=cls,
247
+ fps=fps,
248
+ compress_level=compress_level,
249
+ )
250
+ )
251
+
252
+ save_images = execute # TODO: remove
253
+
254
+
255
+ class ImageStitch(IO.ComfyNode):
256
+ """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
257
+ @classmethod
258
+ def define_schema(cls):
259
+ return IO.Schema(
260
+ node_id="ImageStitch",
261
+ search_aliases=["combine images", "join images", "concatenate images", "side by side"],
262
+ display_name="Image Stitch",
263
+ description="Stitches image2 to image1 in the specified direction.\n"
264
+ "If image2 is not provided, returns image1 unchanged.\n"
265
+ "Optional spacing can be added between images.",
266
+ category="image/transform",
267
+ inputs=[
268
+ IO.Image.Input("image1"),
269
+ IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"),
270
+ IO.Boolean.Input("match_image_size", default=True),
271
+ IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2, advanced=True),
272
+ IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white", advanced=True),
273
+ IO.Image.Input("image2", optional=True),
274
+ ],
275
+ outputs=[IO.Image.Output()],
276
+ )
277
+
278
+ @classmethod
279
+ def execute(
280
+ cls,
281
+ image1,
282
+ direction,
283
+ match_image_size,
284
+ spacing_width,
285
+ spacing_color,
286
+ image2=None,
287
+ ) -> IO.NodeOutput:
288
+ if image2 is None:
289
+ return IO.NodeOutput(image1)
290
+
291
+ # Handle batch size differences
292
+ if image1.shape[0] != image2.shape[0]:
293
+ max_batch = max(image1.shape[0], image2.shape[0])
294
+ if image1.shape[0] < max_batch:
295
+ image1 = torch.cat(
296
+ [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
297
+ )
298
+ if image2.shape[0] < max_batch:
299
+ image2 = torch.cat(
300
+ [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
301
+ )
302
+
303
+ # Match image sizes if requested
304
+ if match_image_size:
305
+ h1, w1 = image1.shape[1:3]
306
+ h2, w2 = image2.shape[1:3]
307
+ aspect_ratio = w2 / h2
308
+
309
+ if direction in ["left", "right"]:
310
+ target_h, target_w = h1, int(h1 * aspect_ratio)
311
+ else: # up, down
312
+ target_w, target_h = w1, int(w1 / aspect_ratio)
313
+
314
+ image2 = comfy.utils.common_upscale(
315
+ image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
316
+ ).movedim(1, -1)
317
+
318
+ color_map = {
319
+ "white": 1.0,
320
+ "black": 0.0,
321
+ "red": (1.0, 0.0, 0.0),
322
+ "green": (0.0, 1.0, 0.0),
323
+ "blue": (0.0, 0.0, 1.0),
324
+ }
325
+
326
+ color_val = color_map[spacing_color]
327
+
328
+ # When not matching sizes, pad to align non-concat dimensions
329
+ if not match_image_size:
330
+ h1, w1 = image1.shape[1:3]
331
+ h2, w2 = image2.shape[1:3]
332
+ pad_value = 0.0
333
+ if not isinstance(color_val, tuple):
334
+ pad_value = color_val
335
+
336
+ if direction in ["left", "right"]:
337
+ # For horizontal concat, pad heights to match
338
+ if h1 != h2:
339
+ target_h = max(h1, h2)
340
+ if h1 < target_h:
341
+ pad_h = target_h - h1
342
+ pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
343
+ image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
344
+ if h2 < target_h:
345
+ pad_h = target_h - h2
346
+ pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
347
+ image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
348
+ else: # up, down
349
+ # For vertical concat, pad widths to match
350
+ if w1 != w2:
351
+ target_w = max(w1, w2)
352
+ if w1 < target_w:
353
+ pad_w = target_w - w1
354
+ pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
355
+ image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
356
+ if w2 < target_w:
357
+ pad_w = target_w - w2
358
+ pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
359
+ image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
360
+
361
+ # Ensure same number of channels
362
+ if image1.shape[-1] != image2.shape[-1]:
363
+ max_channels = max(image1.shape[-1], image2.shape[-1])
364
+ if image1.shape[-1] < max_channels:
365
+ image1 = torch.cat(
366
+ [
367
+ image1,
368
+ torch.ones(
369
+ *image1.shape[:-1],
370
+ max_channels - image1.shape[-1],
371
+ device=image1.device,
372
+ ),
373
+ ],
374
+ dim=-1,
375
+ )
376
+ if image2.shape[-1] < max_channels:
377
+ image2 = torch.cat(
378
+ [
379
+ image2,
380
+ torch.ones(
381
+ *image2.shape[:-1],
382
+ max_channels - image2.shape[-1],
383
+ device=image2.device,
384
+ ),
385
+ ],
386
+ dim=-1,
387
+ )
388
+
389
+ # Add spacing if specified
390
+ if spacing_width > 0:
391
+ spacing_width = spacing_width + (spacing_width % 2) # Ensure even
392
+
393
+ if direction in ["left", "right"]:
394
+ spacing_shape = (
395
+ image1.shape[0],
396
+ max(image1.shape[1], image2.shape[1]),
397
+ spacing_width,
398
+ image1.shape[-1],
399
+ )
400
+ else:
401
+ spacing_shape = (
402
+ image1.shape[0],
403
+ spacing_width,
404
+ max(image1.shape[2], image2.shape[2]),
405
+ image1.shape[-1],
406
+ )
407
+
408
+ spacing = torch.full(spacing_shape, 0.0, device=image1.device)
409
+ if isinstance(color_val, tuple):
410
+ for i, c in enumerate(color_val):
411
+ if i < spacing.shape[-1]:
412
+ spacing[..., i] = c
413
+ if spacing.shape[-1] == 4: # Add alpha
414
+ spacing[..., 3] = 1.0
415
+ else:
416
+ spacing[..., : min(3, spacing.shape[-1])] = color_val
417
+ if spacing.shape[-1] == 4:
418
+ spacing[..., 3] = 1.0
419
+
420
+ # Concatenate images
421
+ images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
422
+ if spacing_width > 0:
423
+ images.insert(1, spacing)
424
+
425
+ concat_dim = 2 if direction in ["left", "right"] else 1
426
+ return IO.NodeOutput(torch.cat(images, dim=concat_dim))
427
+
428
+ stitch = execute # TODO: remove
429
+
430
+
431
+ class ResizeAndPadImage(IO.ComfyNode):
432
+ @classmethod
433
+ def define_schema(cls):
434
+ return IO.Schema(
435
+ node_id="ResizeAndPadImage",
436
+ search_aliases=["fit to size"],
437
+ category="image/transform",
438
+ inputs=[
439
+ IO.Image.Input("image"),
440
+ IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
441
+ IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
442
+ IO.Combo.Input("padding_color", options=["white", "black"], advanced=True),
443
+ IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"], advanced=True),
444
+ ],
445
+ outputs=[IO.Image.Output()],
446
+ )
447
+
448
+ @classmethod
449
+ def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput:
450
+ batch_size, orig_height, orig_width, channels = image.shape
451
+
452
+ scale_w = target_width / orig_width
453
+ scale_h = target_height / orig_height
454
+ scale = min(scale_w, scale_h)
455
+
456
+ new_width = int(orig_width * scale)
457
+ new_height = int(orig_height * scale)
458
+
459
+ image_permuted = image.permute(0, 3, 1, 2)
460
+
461
+ resized = comfy.utils.common_upscale(image_permuted, new_width, new_height, interpolation, "disabled")
462
+
463
+ pad_value = 0.0 if padding_color == "black" else 1.0
464
+ padded = torch.full(
465
+ (batch_size, channels, target_height, target_width),
466
+ pad_value,
467
+ dtype=image.dtype,
468
+ device=image.device
469
+ )
470
+
471
+ y_offset = (target_height - new_height) // 2
472
+ x_offset = (target_width - new_width) // 2
473
+
474
+ padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized
475
+
476
+ output = padded.permute(0, 2, 3, 1)
477
+ return IO.NodeOutput(output)
478
+
479
+ resize_and_pad = execute # TODO: remove
480
+
481
+
482
+ class SaveSVGNode(IO.ComfyNode):
483
+ @classmethod
484
+ def define_schema(cls):
485
+ return IO.Schema(
486
+ node_id="SaveSVGNode",
487
+ search_aliases=["export vector", "save vector graphics"],
488
+ description="Save SVG files on disk.",
489
+ category="image/save",
490
+ inputs=[
491
+ IO.SVG.Input("svg"),
492
+ IO.String.Input(
493
+ "filename_prefix",
494
+ default="svg/ComfyUI",
495
+ tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
496
+ ),
497
+ ],
498
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
499
+ is_output_node=True,
500
+ )
501
+
502
+ @classmethod
503
+ def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput:
504
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
505
+ results: list[UI.SavedResult] = []
506
+
507
+ # Prepare metadata JSON
508
+ metadata_dict = {}
509
+ if cls.hidden.prompt is not None:
510
+ metadata_dict["prompt"] = cls.hidden.prompt
511
+ if cls.hidden.extra_pnginfo is not None:
512
+ metadata_dict.update(cls.hidden.extra_pnginfo)
513
+
514
+ # Convert metadata to JSON string
515
+ metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
516
+
517
+
518
+ for batch_number, svg_bytes in enumerate(svg.data):
519
+ filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
520
+ file = f"{filename_with_batch_num}_{counter:05}_.svg"
521
+
522
+ # Read SVG content
523
+ svg_bytes.seek(0)
524
+ svg_content = svg_bytes.read().decode('utf-8')
525
+
526
+ # Inject metadata if available
527
+ if metadata_json:
528
+ # Create metadata element with CDATA section
529
+ metadata_element = f""" <metadata>
530
+ <![CDATA[
531
+ {metadata_json}
532
+ ]]>
533
+ </metadata>
534
+ """
535
+ # Insert metadata after opening svg tag using regex with a replacement function
536
+ def replacement(match):
537
+ # match.group(1) contains the captured <svg> tag
538
+ return match.group(1) + '\n' + metadata_element
539
+
540
+ # Apply the substitution
541
+ svg_content = re.sub(r'(<svg[^>]*>)', replacement, svg_content, flags=re.UNICODE)
542
+
543
+ # Write the modified SVG to file
544
+ with open(os.path.join(full_output_folder, file), 'wb') as svg_file:
545
+ svg_file.write(svg_content.encode('utf-8'))
546
+
547
+ results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output))
548
+ counter += 1
549
+ return IO.NodeOutput(ui={"images": results})
550
+
551
+ save_svg = execute # TODO: remove
552
+
553
+
554
+ class GetImageSize(IO.ComfyNode):
555
+ @classmethod
556
+ def define_schema(cls):
557
+ return IO.Schema(
558
+ node_id="GetImageSize",
559
+ search_aliases=["dimensions", "resolution", "image info"],
560
+ display_name="Get Image Size",
561
+ description="Returns width and height of the image, and passes it through unchanged.",
562
+ category="image",
563
+ inputs=[
564
+ IO.Image.Input("image"),
565
+ ],
566
+ outputs=[
567
+ IO.Int.Output(display_name="width"),
568
+ IO.Int.Output(display_name="height"),
569
+ IO.Int.Output(display_name="batch_size"),
570
+ ],
571
+ hidden=[IO.Hidden.unique_id],
572
+ )
573
+
574
+ @classmethod
575
+ def execute(cls, image) -> IO.NodeOutput:
576
+ height = image.shape[1]
577
+ width = image.shape[2]
578
+ batch_size = image.shape[0]
579
+
580
+ # Send progress text to display size on the node
581
+ if cls.hidden.unique_id:
582
+ PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
583
+
584
+ return IO.NodeOutput(width, height, batch_size)
585
+
586
+ get_size = execute # TODO: remove
587
+
588
+
589
+ class ImageRotate(IO.ComfyNode):
590
+ @classmethod
591
+ def define_schema(cls):
592
+ return IO.Schema(
593
+ node_id="ImageRotate",
594
+ display_name="Image Rotate",
595
+ search_aliases=["turn", "flip orientation"],
596
+ category="image/transform",
597
+ essentials_category="Image Tools",
598
+ inputs=[
599
+ IO.Image.Input("image"),
600
+ IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
601
+ ],
602
+ outputs=[IO.Image.Output()],
603
+ )
604
+
605
+ @classmethod
606
+ def execute(cls, image, rotation) -> IO.NodeOutput:
607
+ rotate_by = 0
608
+ if rotation.startswith("90"):
609
+ rotate_by = 1
610
+ elif rotation.startswith("180"):
611
+ rotate_by = 2
612
+ elif rotation.startswith("270"):
613
+ rotate_by = 3
614
+
615
+ image = torch.rot90(image, k=rotate_by, dims=[2, 1])
616
+ return IO.NodeOutput(image)
617
+
618
+ rotate = execute # TODO: remove
619
+
620
+
621
+ class ImageFlip(IO.ComfyNode):
622
+ @classmethod
623
+ def define_schema(cls):
624
+ return IO.Schema(
625
+ node_id="ImageFlip",
626
+ search_aliases=["mirror", "reflect"],
627
+ category="image/transform",
628
+ inputs=[
629
+ IO.Image.Input("image"),
630
+ IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
631
+ ],
632
+ outputs=[IO.Image.Output()],
633
+ )
634
+
635
+ @classmethod
636
+ def execute(cls, image, flip_method) -> IO.NodeOutput:
637
+ if flip_method.startswith("x"):
638
+ image = torch.flip(image, dims=[1])
639
+ elif flip_method.startswith("y"):
640
+ image = torch.flip(image, dims=[2])
641
+
642
+ return IO.NodeOutput(image)
643
+
644
+ flip = execute # TODO: remove
645
+
646
+
647
+ class ImageScaleToMaxDimension(IO.ComfyNode):
648
+
649
+ @classmethod
650
+ def define_schema(cls):
651
+ return IO.Schema(
652
+ node_id="ImageScaleToMaxDimension",
653
+ category="image/upscaling",
654
+ inputs=[
655
+ IO.Image.Input("image"),
656
+ IO.Combo.Input(
657
+ "upscale_method",
658
+ options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"],
659
+ ),
660
+ IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
661
+ ],
662
+ outputs=[IO.Image.Output()],
663
+ )
664
+
665
+ @classmethod
666
+ def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
667
+ height = image.shape[1]
668
+ width = image.shape[2]
669
+
670
+ if height > width:
671
+ width = round((width / height) * largest_size)
672
+ height = largest_size
673
+ elif width > height:
674
+ height = round((height / width) * largest_size)
675
+ width = largest_size
676
+ else:
677
+ height = largest_size
678
+ width = largest_size
679
+
680
+ samples = image.movedim(-1, 1)
681
+ s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
682
+ s = s.movedim(1, -1)
683
+ return IO.NodeOutput(s)
684
+
685
+ upscale = execute # TODO: remove
686
+
687
+
688
+ class SplitImageToTileList(IO.ComfyNode):
689
+ @classmethod
690
+ def define_schema(cls):
691
+ return IO.Schema(
692
+ node_id="SplitImageToTileList",
693
+ category="image/batch",
694
+ search_aliases=["split image", "tile image", "slice image"],
695
+ display_name="Split Image into List of Tiles",
696
+ description="Splits an image into a batched list of tiles with a specified overlap.",
697
+ inputs=[
698
+ IO.Image.Input("image"),
699
+ IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION),
700
+ IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION),
701
+ IO.Int.Input("overlap", default=128, min=0, max=4096),
702
+ ],
703
+ outputs=[
704
+ IO.Image.Output(is_output_list=True),
705
+ ],
706
+ )
707
+
708
+ @staticmethod
709
+ def get_grid_coords(width, height, tile_width, tile_height, overlap):
710
+ coords = []
711
+ stride_x = round(max(tile_width * 0.25, tile_width - overlap))
712
+ stride_y = round(max(tile_width * 0.25, tile_height - overlap))
713
+
714
+ y = 0
715
+ while y < height:
716
+ x = 0
717
+ y_end = min(y + tile_height, height)
718
+ y_start = max(0, y_end - tile_height)
719
+
720
+ while x < width:
721
+ x_end = min(x + tile_width, width)
722
+ x_start = max(0, x_end - tile_width)
723
+
724
+ coords.append((x_start, y_start, x_end, y_end))
725
+
726
+ if x_end >= width:
727
+ break
728
+ x += stride_x
729
+
730
+ if y_end >= height:
731
+ break
732
+ y += stride_y
733
+
734
+ return coords
735
+
736
+ @classmethod
737
+ def execute(cls, image, tile_width, tile_height, overlap):
738
+ b, h, w, c = image.shape
739
+ coords = cls.get_grid_coords(w, h, tile_width, tile_height, overlap)
740
+
741
+ output_list = []
742
+ for (x_start, y_start, x_end, y_end) in coords:
743
+ tile = image[:, y_start:y_end, x_start:x_end, :]
744
+ output_list.append(tile)
745
+
746
+ return IO.NodeOutput(output_list)
747
+
748
+
749
+ class ImageMergeTileList(IO.ComfyNode):
750
+ @classmethod
751
+ def define_schema(cls):
752
+ return IO.Schema(
753
+ node_id="ImageMergeTileList",
754
+ display_name="Merge List of Tiles to Image",
755
+ category="image/batch",
756
+ search_aliases=["split image", "tile image", "slice image"],
757
+ is_input_list=True,
758
+ inputs=[
759
+ IO.Image.Input("image_list"),
760
+ IO.Int.Input("final_width", default=1024, min=64, max=32768),
761
+ IO.Int.Input("final_height", default=1024, min=64, max=32768),
762
+ IO.Int.Input("overlap", default=128, min=0, max=4096),
763
+ ],
764
+ outputs=[
765
+ IO.Image.Output(is_output_list=False),
766
+ ],
767
+ )
768
+
769
+ @classmethod
770
+ def execute(cls, image_list, final_width, final_height, overlap):
771
+ w = final_width[0]
772
+ h = final_height[0]
773
+ ovlp = overlap[0]
774
+ feather_str = 1.0
775
+
776
+ first_tile = image_list[0]
777
+ b, t_h, t_w, c = first_tile.shape
778
+ device = first_tile.device
779
+ dtype = first_tile.dtype
780
+
781
+ coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp)
782
+
783
+ canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
784
+ weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)
785
+
786
+ if ovlp > 0:
787
+ y_w = torch.sin(math.pi * torch.linspace(0, 1, t_h, device=device, dtype=dtype))
788
+ x_w = torch.sin(math.pi * torch.linspace(0, 1, t_w, device=device, dtype=dtype))
789
+ y_w = torch.clamp(y_w, min=1e-5)
790
+ x_w = torch.clamp(x_w, min=1e-5)
791
+
792
+ sine_mask = (y_w.unsqueeze(1) * x_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1)
793
+ flat_mask = torch.ones_like(sine_mask)
794
+
795
+ weight_mask = torch.lerp(flat_mask, sine_mask, feather_str)
796
+ else:
797
+ weight_mask = torch.ones((1, t_h, t_w, 1), device=device, dtype=dtype)
798
+
799
+ for i, (x_start, y_start, x_end, y_end) in enumerate(coords):
800
+ if i >= len(image_list):
801
+ break
802
+
803
+ tile = image_list[i]
804
+
805
+ region_h = y_end - y_start
806
+ region_w = x_end - x_start
807
+
808
+ real_h = min(region_h, tile.shape[1])
809
+ real_w = min(region_w, tile.shape[2])
810
+
811
+ y_end_actual = y_start + real_h
812
+ x_end_actual = x_start + real_w
813
+
814
+ tile_crop = tile[:, :real_h, :real_w, :]
815
+ mask_crop = weight_mask[:, :real_h, :real_w, :]
816
+
817
+ canvas[:, y_start:y_end_actual, x_start:x_end_actual, :] += tile_crop * mask_crop
818
+ weights[:, y_start:y_end_actual, x_start:x_end_actual, :] += mask_crop
819
+
820
+ weights[weights == 0] = 1.0
821
+ merged_image = canvas / weights
822
+
823
+ return IO.NodeOutput(merged_image)
824
+
825
+
826
+ class ImagesExtension(ComfyExtension):
827
+ @override
828
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
829
+ return [
830
+ ImageCrop,
831
+ ImageCropV2,
832
+ BoundingBox,
833
+ RepeatImageBatch,
834
+ ImageFromBatch,
835
+ ImageAddNoise,
836
+ SaveAnimatedWEBP,
837
+ SaveAnimatedPNG,
838
+ SaveSVGNode,
839
+ ImageStitch,
840
+ ResizeAndPadImage,
841
+ GetImageSize,
842
+ ImageRotate,
843
+ ImageFlip,
844
+ ImageScaleToMaxDimension,
845
+ SplitImageToTileList,
846
+ ImageMergeTileList,
847
+ ]
848
+
849
+
850
+ async def comfy_entrypoint() -> ImagesExtension:
851
+ return ImagesExtension()
ComfyUI/comfy_extras/nodes_ip2p.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from typing_extensions import override
4
+ from comfy_api.latest import ComfyExtension, io
5
+
6
+
7
+ class InstructPixToPixConditioning(io.ComfyNode):
8
+ @classmethod
9
+ def define_schema(cls):
10
+ return io.Schema(
11
+ node_id="InstructPixToPixConditioning",
12
+ category="conditioning/instructpix2pix",
13
+ inputs=[
14
+ io.Conditioning.Input("positive"),
15
+ io.Conditioning.Input("negative"),
16
+ io.Vae.Input("vae"),
17
+ io.Image.Input("pixels"),
18
+ ],
19
+ outputs=[
20
+ io.Conditioning.Output(display_name="positive"),
21
+ io.Conditioning.Output(display_name="negative"),
22
+ io.Latent.Output(display_name="latent"),
23
+ ],
24
+ )
25
+
26
+ @classmethod
27
+ def execute(cls, positive, negative, pixels, vae) -> io.NodeOutput:
28
+ x = (pixels.shape[1] // 8) * 8
29
+ y = (pixels.shape[2] // 8) * 8
30
+
31
+ if pixels.shape[1] != x or pixels.shape[2] != y:
32
+ x_offset = (pixels.shape[1] % 8) // 2
33
+ y_offset = (pixels.shape[2] % 8) // 2
34
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
35
+
36
+ concat_latent = vae.encode(pixels)
37
+
38
+ out_latent = {}
39
+ out_latent["samples"] = torch.zeros_like(concat_latent)
40
+
41
+ out = []
42
+ for conditioning in [positive, negative]:
43
+ c = []
44
+ for t in conditioning:
45
+ d = t[1].copy()
46
+ d["concat_latent_image"] = concat_latent
47
+ n = [t[0], d]
48
+ c.append(n)
49
+ out.append(c)
50
+ return io.NodeOutput(out[0], out[1], out_latent)
51
+
52
+
53
+ class InstructPix2PixExtension(ComfyExtension):
54
+ @override
55
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
56
+ return [
57
+ InstructPixToPixConditioning,
58
+ ]
59
+
60
+
61
+ async def comfy_entrypoint() -> InstructPix2PixExtension:
62
+ return InstructPix2PixExtension()
63
+
ComfyUI/comfy_extras/nodes_kandinsky5.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import node_helpers
3
+ import torch
4
+ import comfy.model_management
5
+ import comfy.utils
6
+
7
+ from typing_extensions import override
8
+ from comfy_api.latest import ComfyExtension, io
9
+
10
+
11
+ class Kandinsky5ImageToVideo(io.ComfyNode):
12
+ @classmethod
13
+ def define_schema(cls):
14
+ return io.Schema(
15
+ node_id="Kandinsky5ImageToVideo",
16
+ category="conditioning/video_models",
17
+ inputs=[
18
+ io.Conditioning.Input("positive"),
19
+ io.Conditioning.Input("negative"),
20
+ io.Vae.Input("vae"),
21
+ io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
22
+ io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
23
+ io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
24
+ io.Int.Input("batch_size", default=1, min=1, max=4096),
25
+ io.Image.Input("start_image", optional=True),
26
+ ],
27
+ outputs=[
28
+ io.Conditioning.Output(display_name="positive"),
29
+ io.Conditioning.Output(display_name="negative"),
30
+ io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
31
+ io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
32
+ ],
33
+ )
34
+
35
+ @classmethod
36
+ def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
37
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
38
+ cond_latent_out = {}
39
+ if start_image is not None:
40
+ start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
41
+ encoded = vae.encode(start_image[:, :, :, :3])
42
+ cond_latent_out["samples"] = encoded
43
+
44
+ mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
45
+ mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
46
+
47
+ positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
48
+ negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
49
+
50
+ out_latent = {}
51
+ out_latent["samples"] = latent
52
+ return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
53
+
54
+
55
+ def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
56
+ source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
57
+ source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
58
+
59
+ reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
60
+ reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
61
+
62
+ # normalization
63
+ normalized = (source - source_mean) / (source_std + 1e-8)
64
+ normalized = normalized * reference_std + reference_mean
65
+
66
+ return normalized
67
+
68
+
69
+ class NormalizeVideoLatentStart(io.ComfyNode):
70
+ @classmethod
71
+ def define_schema(cls):
72
+ return io.Schema(
73
+ node_id="NormalizeVideoLatentStart",
74
+ category="conditioning/video_models",
75
+ description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
76
+ inputs=[
77
+ io.Latent.Input("latent"),
78
+ io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
79
+ io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
80
+ ],
81
+ outputs=[
82
+ io.Latent.Output(display_name="latent"),
83
+ ],
84
+ )
85
+
86
+ @classmethod
87
+ def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
88
+ if latent["samples"].shape[2] <= 1:
89
+ return io.NodeOutput(latent)
90
+ s = latent.copy()
91
+ samples = latent["samples"].clone()
92
+
93
+ first_frames = samples[:, :, :start_frame_count]
94
+ reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
95
+ normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
96
+
97
+ samples[:, :, :start_frame_count] = normalized_first_frames
98
+ s["samples"] = samples
99
+ return io.NodeOutput(s)
100
+
101
+
102
+ class CLIPTextEncodeKandinsky5(io.ComfyNode):
103
+ @classmethod
104
+ def define_schema(cls):
105
+ return io.Schema(
106
+ node_id="CLIPTextEncodeKandinsky5",
107
+ search_aliases=["kandinsky prompt"],
108
+ category="advanced/conditioning/kandinsky5",
109
+ inputs=[
110
+ io.Clip.Input("clip"),
111
+ io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
112
+ io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
113
+ ],
114
+ outputs=[
115
+ io.Conditioning.Output(),
116
+ ],
117
+ )
118
+
119
+ @classmethod
120
+ def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
121
+ tokens = clip.tokenize(clip_l)
122
+ tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
123
+
124
+ return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
125
+
126
+
127
+ class Kandinsky5Extension(ComfyExtension):
128
+ @override
129
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
130
+ return [
131
+ Kandinsky5ImageToVideo,
132
+ NormalizeVideoLatentStart,
133
+ CLIPTextEncodeKandinsky5,
134
+ ]
135
+
136
+ async def comfy_entrypoint() -> Kandinsky5Extension:
137
+ return Kandinsky5Extension()
ComfyUI/comfy_extras/nodes_latent.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import comfy.utils
2
+ import comfy_extras.nodes_post_processing
3
+ import torch
4
+ import nodes
5
+ from typing_extensions import override
6
+ from comfy_api.latest import ComfyExtension, io
7
+ import logging
8
+ import math
9
+
10
+ def reshape_latent_to(target_shape, latent, repeat_batch=True):
11
+ if latent.shape[1:] != target_shape[1:]:
12
+ latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center")
13
+ if repeat_batch:
14
+ return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
15
+ else:
16
+ return latent
17
+
18
+
19
+ class LatentAdd(io.ComfyNode):
20
+ @classmethod
21
+ def define_schema(cls):
22
+ return io.Schema(
23
+ node_id="LatentAdd",
24
+ search_aliases=["combine latents", "sum latents"],
25
+ category="latent/advanced",
26
+ inputs=[
27
+ io.Latent.Input("samples1"),
28
+ io.Latent.Input("samples2"),
29
+ ],
30
+ outputs=[
31
+ io.Latent.Output(),
32
+ ],
33
+ )
34
+
35
+ @classmethod
36
+ def execute(cls, samples1, samples2) -> io.NodeOutput:
37
+ samples_out = samples1.copy()
38
+
39
+ s1 = samples1["samples"]
40
+ s2 = samples2["samples"]
41
+
42
+ s2 = reshape_latent_to(s1.shape, s2)
43
+ samples_out["samples"] = s1 + s2
44
+ return io.NodeOutput(samples_out)
45
+
46
+ class LatentSubtract(io.ComfyNode):
47
+ @classmethod
48
+ def define_schema(cls):
49
+ return io.Schema(
50
+ node_id="LatentSubtract",
51
+ search_aliases=["difference latent", "remove features"],
52
+ category="latent/advanced",
53
+ inputs=[
54
+ io.Latent.Input("samples1"),
55
+ io.Latent.Input("samples2"),
56
+ ],
57
+ outputs=[
58
+ io.Latent.Output(),
59
+ ],
60
+ )
61
+
62
+ @classmethod
63
+ def execute(cls, samples1, samples2) -> io.NodeOutput:
64
+ samples_out = samples1.copy()
65
+
66
+ s1 = samples1["samples"]
67
+ s2 = samples2["samples"]
68
+
69
+ s2 = reshape_latent_to(s1.shape, s2)
70
+ samples_out["samples"] = s1 - s2
71
+ return io.NodeOutput(samples_out)
72
+
73
+ class LatentMultiply(io.ComfyNode):
74
+ @classmethod
75
+ def define_schema(cls):
76
+ return io.Schema(
77
+ node_id="LatentMultiply",
78
+ search_aliases=["scale latent", "amplify latent", "latent gain"],
79
+ category="latent/advanced",
80
+ inputs=[
81
+ io.Latent.Input("samples"),
82
+ io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
83
+ ],
84
+ outputs=[
85
+ io.Latent.Output(),
86
+ ],
87
+ )
88
+
89
+ @classmethod
90
+ def execute(cls, samples, multiplier) -> io.NodeOutput:
91
+ samples_out = samples.copy()
92
+
93
+ s1 = samples["samples"]
94
+ samples_out["samples"] = s1 * multiplier
95
+ return io.NodeOutput(samples_out)
96
+
97
+ class LatentInterpolate(io.ComfyNode):
98
+ @classmethod
99
+ def define_schema(cls):
100
+ return io.Schema(
101
+ node_id="LatentInterpolate",
102
+ search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
103
+ category="latent/advanced",
104
+ inputs=[
105
+ io.Latent.Input("samples1"),
106
+ io.Latent.Input("samples2"),
107
+ io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
108
+ ],
109
+ outputs=[
110
+ io.Latent.Output(),
111
+ ],
112
+ )
113
+
114
+ @classmethod
115
+ def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
116
+ samples_out = samples1.copy()
117
+
118
+ s1 = samples1["samples"]
119
+ s2 = samples2["samples"]
120
+
121
+ s2 = reshape_latent_to(s1.shape, s2)
122
+
123
+ m1 = torch.linalg.vector_norm(s1, dim=(1))
124
+ m2 = torch.linalg.vector_norm(s2, dim=(1))
125
+
126
+ s1 = torch.nan_to_num(s1 / m1)
127
+ s2 = torch.nan_to_num(s2 / m2)
128
+
129
+ t = (s1 * ratio + s2 * (1.0 - ratio))
130
+ mt = torch.linalg.vector_norm(t, dim=(1))
131
+ st = torch.nan_to_num(t / mt)
132
+
133
+ samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
134
+ return io.NodeOutput(samples_out)
135
+
136
+ class LatentConcat(io.ComfyNode):
137
+ @classmethod
138
+ def define_schema(cls):
139
+ return io.Schema(
140
+ node_id="LatentConcat",
141
+ search_aliases=["join latents", "stitch latents"],
142
+ category="latent/advanced",
143
+ inputs=[
144
+ io.Latent.Input("samples1"),
145
+ io.Latent.Input("samples2"),
146
+ io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
147
+ ],
148
+ outputs=[
149
+ io.Latent.Output(),
150
+ ],
151
+ )
152
+
153
+ @classmethod
154
+ def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
155
+ samples_out = samples1.copy()
156
+
157
+ s1 = samples1["samples"]
158
+ s2 = samples2["samples"]
159
+ s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0])
160
+
161
+ if "-" in dim:
162
+ c = (s2, s1)
163
+ else:
164
+ c = (s1, s2)
165
+
166
+ if "x" in dim:
167
+ dim = -1
168
+ elif "y" in dim:
169
+ dim = -2
170
+ elif "t" in dim:
171
+ dim = -3
172
+
173
+ samples_out["samples"] = torch.cat(c, dim=dim)
174
+ return io.NodeOutput(samples_out)
175
+
176
+ class LatentCut(io.ComfyNode):
177
+ @classmethod
178
+ def define_schema(cls):
179
+ return io.Schema(
180
+ node_id="LatentCut",
181
+ search_aliases=["crop latent", "slice latent", "extract region"],
182
+ category="latent/advanced",
183
+ inputs=[
184
+ io.Latent.Input("samples"),
185
+ io.Combo.Input("dim", options=["x", "y", "t"]),
186
+ io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
187
+ io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
188
+ ],
189
+ outputs=[
190
+ io.Latent.Output(),
191
+ ],
192
+ )
193
+
194
+ @classmethod
195
+ def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
196
+ samples_out = samples.copy()
197
+
198
+ s1 = samples["samples"]
199
+
200
+ if "x" in dim:
201
+ dim = s1.ndim - 1
202
+ elif "y" in dim:
203
+ dim = s1.ndim - 2
204
+ elif "t" in dim:
205
+ dim = s1.ndim - 3
206
+
207
+ if index >= 0:
208
+ index = min(index, s1.shape[dim] - 1)
209
+ amount = min(s1.shape[dim] - index, amount)
210
+ else:
211
+ index = max(index, -s1.shape[dim])
212
+ amount = min(-index, amount)
213
+
214
+ samples_out["samples"] = torch.narrow(s1, dim, index, amount)
215
+ return io.NodeOutput(samples_out)
216
+
217
+ class LatentCutToBatch(io.ComfyNode):
218
+ @classmethod
219
+ def define_schema(cls):
220
+ return io.Schema(
221
+ node_id="LatentCutToBatch",
222
+ search_aliases=["slice to batch", "split latent", "tile latent"],
223
+ category="latent/advanced",
224
+ inputs=[
225
+ io.Latent.Input("samples"),
226
+ io.Combo.Input("dim", options=["t", "x", "y"]),
227
+ io.Int.Input("slice_size", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
228
+ ],
229
+ outputs=[
230
+ io.Latent.Output(),
231
+ ],
232
+ )
233
+
234
+ @classmethod
235
+ def execute(cls, samples, dim, slice_size) -> io.NodeOutput:
236
+ samples_out = samples.copy()
237
+
238
+ s1 = samples["samples"]
239
+
240
+ if "x" in dim:
241
+ dim = s1.ndim - 1
242
+ elif "y" in dim:
243
+ dim = s1.ndim - 2
244
+ elif "t" in dim:
245
+ dim = s1.ndim - 3
246
+
247
+ if dim < 2:
248
+ return io.NodeOutput(samples)
249
+
250
+ s = s1.movedim(dim, 1)
251
+ if s.shape[1] < slice_size:
252
+ slice_size = s.shape[1]
253
+ elif s.shape[1] % slice_size != 0:
254
+ s = s[:, :math.floor(s.shape[1] / slice_size) * slice_size]
255
+ new_shape = [-1, slice_size] + list(s.shape[2:])
256
+ samples_out["samples"] = s.reshape(new_shape).movedim(1, dim)
257
+ return io.NodeOutput(samples_out)
258
+
259
+ class LatentBatch(io.ComfyNode):
260
+ @classmethod
261
+ def define_schema(cls):
262
+ return io.Schema(
263
+ node_id="LatentBatch",
264
+ search_aliases=["combine latents", "merge latents", "join latents"],
265
+ category="latent/batch",
266
+ is_deprecated=True,
267
+ inputs=[
268
+ io.Latent.Input("samples1"),
269
+ io.Latent.Input("samples2"),
270
+ ],
271
+ outputs=[
272
+ io.Latent.Output(),
273
+ ],
274
+ )
275
+
276
+ @classmethod
277
+ def execute(cls, samples1, samples2) -> io.NodeOutput:
278
+ samples_out = samples1.copy()
279
+ s1 = samples1["samples"]
280
+ s2 = samples2["samples"]
281
+
282
+ s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
283
+ s = torch.cat((s1, s2), dim=0)
284
+ samples_out["samples"] = s
285
+ samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
286
+ return io.NodeOutput(samples_out)
287
+
288
+ class LatentBatchSeedBehavior(io.ComfyNode):
289
+ @classmethod
290
+ def define_schema(cls):
291
+ return io.Schema(
292
+ node_id="LatentBatchSeedBehavior",
293
+ category="latent/advanced",
294
+ inputs=[
295
+ io.Latent.Input("samples"),
296
+ io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
297
+ ],
298
+ outputs=[
299
+ io.Latent.Output(),
300
+ ],
301
+ )
302
+
303
+ @classmethod
304
+ def execute(cls, samples, seed_behavior) -> io.NodeOutput:
305
+ samples_out = samples.copy()
306
+ latent = samples["samples"]
307
+ if seed_behavior == "random":
308
+ if 'batch_index' in samples_out:
309
+ samples_out.pop('batch_index')
310
+ elif seed_behavior == "fixed":
311
+ batch_number = samples_out.get("batch_index", [0])[0]
312
+ samples_out["batch_index"] = [batch_number] * latent.shape[0]
313
+
314
+ return io.NodeOutput(samples_out)
315
+
316
+ class LatentApplyOperation(io.ComfyNode):
317
+ @classmethod
318
+ def define_schema(cls):
319
+ return io.Schema(
320
+ node_id="LatentApplyOperation",
321
+ search_aliases=["transform latent"],
322
+ category="latent/advanced/operations",
323
+ is_experimental=True,
324
+ inputs=[
325
+ io.Latent.Input("samples"),
326
+ io.LatentOperation.Input("operation"),
327
+ ],
328
+ outputs=[
329
+ io.Latent.Output(),
330
+ ],
331
+ )
332
+
333
+ @classmethod
334
+ def execute(cls, samples, operation) -> io.NodeOutput:
335
+ samples_out = samples.copy()
336
+
337
+ s1 = samples["samples"]
338
+ samples_out["samples"] = operation(latent=s1)
339
+ return io.NodeOutput(samples_out)
340
+
341
+ class LatentApplyOperationCFG(io.ComfyNode):
342
+ @classmethod
343
+ def define_schema(cls):
344
+ return io.Schema(
345
+ node_id="LatentApplyOperationCFG",
346
+ category="latent/advanced/operations",
347
+ is_experimental=True,
348
+ inputs=[
349
+ io.Model.Input("model"),
350
+ io.LatentOperation.Input("operation"),
351
+ ],
352
+ outputs=[
353
+ io.Model.Output(),
354
+ ],
355
+ )
356
+
357
+ @classmethod
358
+ def execute(cls, model, operation) -> io.NodeOutput:
359
+ m = model.clone()
360
+
361
+ def pre_cfg_function(args):
362
+ conds_out = args["conds_out"]
363
+ if len(conds_out) == 2:
364
+ conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
365
+ else:
366
+ conds_out[0] = operation(latent=conds_out[0])
367
+ return conds_out
368
+
369
+ m.set_model_sampler_pre_cfg_function(pre_cfg_function)
370
+ return io.NodeOutput(m)
371
+
372
+ class LatentOperationTonemapReinhard(io.ComfyNode):
373
+ @classmethod
374
+ def define_schema(cls):
375
+ return io.Schema(
376
+ node_id="LatentOperationTonemapReinhard",
377
+ search_aliases=["hdr latent"],
378
+ category="latent/advanced/operations",
379
+ is_experimental=True,
380
+ inputs=[
381
+ io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
382
+ ],
383
+ outputs=[
384
+ io.LatentOperation.Output(),
385
+ ],
386
+ )
387
+
388
+ @classmethod
389
+ def execute(cls, multiplier) -> io.NodeOutput:
390
+ def tonemap_reinhard(latent, **kwargs):
391
+ latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
392
+ normalized_latent = latent / latent_vector_magnitude
393
+
394
+ dims = list(range(1, latent_vector_magnitude.ndim))
395
+ mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
396
+ std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
397
+
398
+ top = (std * 5 + mean) * multiplier
399
+
400
+ #reinhard
401
+ latent_vector_magnitude *= (1.0 / top)
402
+ new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
403
+ new_magnitude *= top
404
+
405
+ return normalized_latent * new_magnitude
406
+ return io.NodeOutput(tonemap_reinhard)
407
+
408
+ class LatentOperationSharpen(io.ComfyNode):
409
+ @classmethod
410
+ def define_schema(cls):
411
+ return io.Schema(
412
+ node_id="LatentOperationSharpen",
413
+ category="latent/advanced/operations",
414
+ is_experimental=True,
415
+ inputs=[
416
+ io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1, advanced=True),
417
+ io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1, advanced=True),
418
+ io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01, advanced=True),
419
+ ],
420
+ outputs=[
421
+ io.LatentOperation.Output(),
422
+ ],
423
+ )
424
+
425
+ @classmethod
426
+ def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
427
+ def sharpen(latent, **kwargs):
428
+ luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
429
+ normalized_latent = latent / luminance
430
+ channels = latent.shape[1]
431
+
432
+ kernel_size = sharpen_radius * 2 + 1
433
+ kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
434
+ center = kernel_size // 2
435
+
436
+ kernel *= alpha * -10
437
+ kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
438
+
439
+ padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
440
+ sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
441
+
442
+ return luminance * sharpened
443
+ return io.NodeOutput(sharpen)
444
+
445
+ class ReplaceVideoLatentFrames(io.ComfyNode):
446
+ @classmethod
447
+ def define_schema(cls):
448
+ return io.Schema(
449
+ node_id="ReplaceVideoLatentFrames",
450
+ category="latent/batch",
451
+ inputs=[
452
+ io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
453
+ io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
454
+ io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
455
+ ],
456
+ outputs=[
457
+ io.Latent.Output(),
458
+ ],
459
+ )
460
+
461
+ @classmethod
462
+ def execute(cls, destination, index, source=None) -> io.NodeOutput:
463
+ if source is None:
464
+ return io.NodeOutput(destination)
465
+ dest_frames = destination["samples"].shape[2]
466
+ source_frames = source["samples"].shape[2]
467
+ if index < 0:
468
+ index = dest_frames + index
469
+ if index > dest_frames:
470
+ logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
471
+ return io.NodeOutput(destination)
472
+ if index + source_frames > dest_frames:
473
+ logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
474
+ return io.NodeOutput(destination)
475
+ s = source.copy()
476
+ s_source = source["samples"]
477
+ s_destination = destination["samples"].clone()
478
+ s_destination[:, :, index:index + s_source.shape[2]] = s_source
479
+ s["samples"] = s_destination
480
+ return io.NodeOutput(s)
481
+
482
+ class LatentExtension(ComfyExtension):
483
+ @override
484
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
485
+ return [
486
+ LatentAdd,
487
+ LatentSubtract,
488
+ LatentMultiply,
489
+ LatentInterpolate,
490
+ LatentConcat,
491
+ LatentCut,
492
+ LatentCutToBatch,
493
+ LatentBatch,
494
+ LatentBatchSeedBehavior,
495
+ LatentApplyOperation,
496
+ LatentApplyOperationCFG,
497
+ LatentOperationTonemapReinhard,
498
+ LatentOperationSharpen,
499
+ ReplaceVideoLatentFrames
500
+ ]
501
+
502
+
503
+ async def comfy_entrypoint() -> LatentExtension:
504
+ return LatentExtension()
ComfyUI/comfy_extras/nodes_load_3d.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nodes
2
+ import folder_paths
3
+ import os
4
+ import uuid
5
+
6
+ from typing_extensions import override
7
+ from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types
8
+
9
+ from pathlib import Path
10
+
11
+
12
+ def normalize_path(path):
13
+ return path.replace('\\', '/')
14
+
15
+ class Load3D(IO.ComfyNode):
16
+ @classmethod
17
+ def define_schema(cls):
18
+ input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
19
+
20
+ os.makedirs(input_dir, exist_ok=True)
21
+
22
+ input_path = Path(input_dir)
23
+ base_path = Path(folder_paths.get_input_directory())
24
+
25
+ files = [
26
+ normalize_path(str(file_path.relative_to(base_path)))
27
+ for file_path in input_path.rglob("*")
28
+ if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'}
29
+ ]
30
+ return IO.Schema(
31
+ node_id="Load3D",
32
+ display_name="Load 3D & Animation",
33
+ category="3d",
34
+ essentials_category="Basics",
35
+ is_experimental=True,
36
+ inputs=[
37
+ IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
38
+ IO.Load3D.Input("image"),
39
+ IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
40
+ IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
41
+ ],
42
+ outputs=[
43
+ IO.Image.Output(display_name="image"),
44
+ IO.Mask.Output(display_name="mask"),
45
+ IO.String.Output(display_name="mesh_path"),
46
+ IO.Image.Output(display_name="normal"),
47
+ IO.Load3DCamera.Output(display_name="camera_info"),
48
+ IO.Video.Output(display_name="recording_video"),
49
+ IO.File3DAny.Output(display_name="model_3d"),
50
+ ],
51
+ )
52
+
53
+ @classmethod
54
+ def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
55
+ image_path = folder_paths.get_annotated_filepath(image['image'])
56
+ mask_path = folder_paths.get_annotated_filepath(image['mask'])
57
+ normal_path = folder_paths.get_annotated_filepath(image['normal'])
58
+
59
+ load_image_node = nodes.LoadImage()
60
+ output_image, ignore_mask = load_image_node.load_image(image=image_path)
61
+ ignore_image, output_mask = load_image_node.load_image(image=mask_path)
62
+ normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
63
+
64
+ video = None
65
+
66
+ if image['recording'] != "":
67
+ recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
68
+
69
+ video = InputImpl.VideoFromFile(recording_video_path)
70
+
71
+ file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
72
+ return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
73
+
74
+ process = execute # TODO: remove
75
+
76
+
77
+ class Preview3D(IO.ComfyNode):
78
+ @classmethod
79
+ def define_schema(cls):
80
+ return IO.Schema(
81
+ node_id="Preview3D",
82
+ search_aliases=["view mesh", "3d viewer"],
83
+ display_name="Preview 3D & Animation",
84
+ category="3d",
85
+ is_experimental=True,
86
+ is_output_node=True,
87
+ inputs=[
88
+ IO.MultiType.Input(
89
+ IO.String.Input("model_file", default="", multiline=False),
90
+ types=[
91
+ IO.File3DGLB,
92
+ IO.File3DGLTF,
93
+ IO.File3DFBX,
94
+ IO.File3DOBJ,
95
+ IO.File3DSTL,
96
+ IO.File3DUSDZ,
97
+ IO.File3DAny,
98
+ ],
99
+ tooltip="3D model file or path string",
100
+ ),
101
+ IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
102
+ IO.Image.Input("bg_image", optional=True, advanced=True),
103
+ ],
104
+ outputs=[],
105
+ )
106
+
107
+ @classmethod
108
+ def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput:
109
+ if isinstance(model_file, Types.File3D):
110
+ filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
111
+ model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
112
+ else:
113
+ filename = model_file
114
+ camera_info = kwargs.get("camera_info", None)
115
+ bg_image = kwargs.get("bg_image", None)
116
+ return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image))
117
+
118
+ process = execute # TODO: remove
119
+
120
+
121
+ class Load3DExtension(ComfyExtension):
122
+ @override
123
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
124
+ return [
125
+ Load3D,
126
+ Preview3D,
127
+ ]
128
+
129
+
130
+ async def comfy_entrypoint() -> Load3DExtension:
131
+ return Load3DExtension()
ComfyUI/comfy_extras/nodes_logic.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import TypedDict
3
+ from typing_extensions import override
4
+ from comfy_api.latest import ComfyExtension, io
5
+ from comfy_api.latest import _io
6
+
7
+ # sentinel for missing inputs
8
+ MISSING = object()
9
+
10
+
11
+ class SwitchNode(io.ComfyNode):
12
+ @classmethod
13
+ def define_schema(cls):
14
+ template = io.MatchType.Template("switch")
15
+ return io.Schema(
16
+ node_id="ComfySwitchNode",
17
+ display_name="Switch",
18
+ category="logic",
19
+ is_experimental=True,
20
+ inputs=[
21
+ io.Boolean.Input("switch"),
22
+ io.MatchType.Input("on_false", template=template, lazy=True),
23
+ io.MatchType.Input("on_true", template=template, lazy=True),
24
+ ],
25
+ outputs=[
26
+ io.MatchType.Output(template=template, display_name="output"),
27
+ ],
28
+ )
29
+
30
+ @classmethod
31
+ def check_lazy_status(cls, switch, on_false=None, on_true=None):
32
+ if switch and on_true is None:
33
+ return ["on_true"]
34
+ if not switch and on_false is None:
35
+ return ["on_false"]
36
+
37
+ @classmethod
38
+ def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
39
+ return io.NodeOutput(on_true if switch else on_false)
40
+
41
+
42
+ class SoftSwitchNode(io.ComfyNode):
43
+ @classmethod
44
+ def define_schema(cls):
45
+ template = io.MatchType.Template("switch")
46
+ return io.Schema(
47
+ node_id="ComfySoftSwitchNode",
48
+ display_name="Soft Switch",
49
+ category="logic",
50
+ is_experimental=True,
51
+ inputs=[
52
+ io.Boolean.Input("switch"),
53
+ io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
54
+ io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
55
+ ],
56
+ outputs=[
57
+ io.MatchType.Output(template=template, display_name="output"),
58
+ ],
59
+ )
60
+
61
+ @classmethod
62
+ def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING):
63
+ # We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs.
64
+ # This trick allows us to ignore the value of the switch and still be able to run execute().
65
+
66
+ # One of the inputs may be missing, in which case we need to evaluate the other input
67
+ if on_false is MISSING:
68
+ return ["on_true"]
69
+ if on_true is MISSING:
70
+ return ["on_false"]
71
+ # Normal lazy switch operation
72
+ if switch and on_true is None:
73
+ return ["on_true"]
74
+ if not switch and on_false is None:
75
+ return ["on_false"]
76
+
77
+ @classmethod
78
+ def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING):
79
+ # This check happens before check_lazy_status(), so we can eliminate the case where
80
+ # both inputs are missing.
81
+ if on_false is MISSING and on_true is MISSING:
82
+ return "At least one of on_false or on_true must be connected to Switch node"
83
+ return True
84
+
85
+ @classmethod
86
+ def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput:
87
+ if on_true is MISSING:
88
+ return io.NodeOutput(on_false)
89
+ if on_false is MISSING:
90
+ return io.NodeOutput(on_true)
91
+ return io.NodeOutput(on_true if switch else on_false)
92
+
93
+
94
+ class CustomComboNode(io.ComfyNode):
95
+ """
96
+ Frontend node that allows user to write their own options for a combo.
97
+ This is here to make sure the node has a backend-representation to avoid some annoyances.
98
+ """
99
+ @classmethod
100
+ def define_schema(cls):
101
+ return io.Schema(
102
+ node_id="CustomCombo",
103
+ display_name="Custom Combo",
104
+ category="utils",
105
+ is_experimental=True,
106
+ inputs=[io.Combo.Input("choice", options=[])],
107
+ outputs=[
108
+ io.String.Output(display_name="STRING"),
109
+ io.Int.Output(display_name="INDEX"),
110
+ ],
111
+ accept_all_inputs=True,
112
+ )
113
+
114
+ @classmethod
115
+ def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool:
116
+ # NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs.
117
+ # I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined.
118
+ # I need to skip checking that the chosen combo option is in the options list, since those are defined by the user.
119
+ return True
120
+
121
+ @classmethod
122
+ def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput:
123
+ return io.NodeOutput(choice, index)
124
+
125
+
126
+ class DCTestNode(io.ComfyNode):
127
+ class DCValues(TypedDict):
128
+ combo: str
129
+ string: str
130
+ integer: int
131
+ image: io.Image.Type
132
+ subcombo: dict[str]
133
+
134
+ @classmethod
135
+ def define_schema(cls):
136
+ return io.Schema(
137
+ node_id="DCTestNode",
138
+ display_name="DCTest",
139
+ category="logic",
140
+ is_output_node=True,
141
+ inputs=[io.DynamicCombo.Input("combo", options=[
142
+ io.DynamicCombo.Option("option1", [io.String.Input("string")]),
143
+ io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
144
+ io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
145
+ io.DynamicCombo.Option("option4", [
146
+ io.DynamicCombo.Input("subcombo", options=[
147
+ io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
148
+ io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
149
+ ])
150
+ ])]
151
+ )],
152
+ outputs=[io.AnyType.Output()],
153
+ )
154
+
155
+ @classmethod
156
+ def execute(cls, combo: DCValues) -> io.NodeOutput:
157
+ combo_val = combo["combo"]
158
+ if combo_val == "option1":
159
+ return io.NodeOutput(combo["string"])
160
+ elif combo_val == "option2":
161
+ return io.NodeOutput(combo["integer"])
162
+ elif combo_val == "option3":
163
+ return io.NodeOutput(combo["image"])
164
+ elif combo_val == "option4":
165
+ return io.NodeOutput(f"{combo['subcombo']}")
166
+ else:
167
+ raise ValueError(f"Invalid combo: {combo_val}")
168
+
169
+
170
+ class AutogrowNamesTestNode(io.ComfyNode):
171
+ @classmethod
172
+ def define_schema(cls):
173
+ template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
174
+ return io.Schema(
175
+ node_id="AutogrowNamesTestNode",
176
+ display_name="AutogrowNamesTest",
177
+ category="logic",
178
+ inputs=[
179
+ _io.Autogrow.Input("autogrow", template=template)
180
+ ],
181
+ outputs=[io.String.Output()],
182
+ )
183
+
184
+ @classmethod
185
+ def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
186
+ vals = list(autogrow.values())
187
+ combined = ",".join([str(x) for x in vals])
188
+ return io.NodeOutput(combined)
189
+
190
+ class AutogrowPrefixTestNode(io.ComfyNode):
191
+ @classmethod
192
+ def define_schema(cls):
193
+ template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
194
+ return io.Schema(
195
+ node_id="AutogrowPrefixTestNode",
196
+ display_name="AutogrowPrefixTest",
197
+ category="logic",
198
+ inputs=[
199
+ _io.Autogrow.Input("autogrow", template=template)
200
+ ],
201
+ outputs=[io.String.Output()],
202
+ )
203
+
204
+ @classmethod
205
+ def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
206
+ vals = list(autogrow.values())
207
+ combined = ",".join([str(x) for x in vals])
208
+ return io.NodeOutput(combined)
209
+
210
+ class ComboOutputTestNode(io.ComfyNode):
211
+ @classmethod
212
+ def define_schema(cls):
213
+ return io.Schema(
214
+ node_id="ComboOptionTestNode",
215
+ display_name="ComboOptionTest",
216
+ category="logic",
217
+ inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
218
+ io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
219
+ outputs=[io.Combo.Output(), io.Combo.Output()],
220
+ )
221
+
222
+ @classmethod
223
+ def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput:
224
+ return io.NodeOutput(combo, combo2)
225
+
226
+ class ConvertStringToComboNode(io.ComfyNode):
227
+ @classmethod
228
+ def define_schema(cls):
229
+ return io.Schema(
230
+ node_id="ConvertStringToComboNode",
231
+ search_aliases=["string to dropdown", "text to combo"],
232
+ display_name="Convert String to Combo",
233
+ category="logic",
234
+ inputs=[io.String.Input("string")],
235
+ outputs=[io.Combo.Output()],
236
+ )
237
+
238
+ @classmethod
239
+ def execute(cls, string: str) -> io.NodeOutput:
240
+ return io.NodeOutput(string)
241
+
242
+ class InvertBooleanNode(io.ComfyNode):
243
+ @classmethod
244
+ def define_schema(cls):
245
+ return io.Schema(
246
+ node_id="InvertBooleanNode",
247
+ search_aliases=["not", "toggle", "negate", "flip boolean"],
248
+ display_name="Invert Boolean",
249
+ category="logic",
250
+ inputs=[io.Boolean.Input("boolean")],
251
+ outputs=[io.Boolean.Output()],
252
+ )
253
+
254
+ @classmethod
255
+ def execute(cls, boolean: bool) -> io.NodeOutput:
256
+ return io.NodeOutput(not boolean)
257
+
258
+ class LogicExtension(ComfyExtension):
259
+ @override
260
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
261
+ return [
262
+ SwitchNode,
263
+ CustomComboNode,
264
+ # SoftSwitchNode,
265
+ # ConvertStringToComboNode,
266
+ # DCTestNode,
267
+ # AutogrowNamesTestNode,
268
+ # AutogrowPrefixTestNode,
269
+ # ComboOutputTestNode,
270
+ # InvertBooleanNode,
271
+ ]
272
+
273
+ async def comfy_entrypoint() -> LogicExtension:
274
+ return LogicExtension()
ComfyUI/comfy_extras/nodes_lora_debug.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import folder_paths
2
+ import comfy.utils
3
+ import comfy.sd
4
+
5
+
6
+ class LoraLoaderBypass:
7
+ """
8
+ Apply LoRA in bypass mode without modifying base model weights.
9
+
10
+ Bypass mode computes: output = base_forward(x) + lora_path(x)
11
+ This is useful for training and when model weights are offloaded.
12
+ """
13
+
14
+ def __init__(self):
15
+ self.loaded_lora = None
16
+
17
+ @classmethod
18
+ def INPUT_TYPES(s):
19
+ return {
20
+ "required": {
21
+ "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
22
+ "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}),
23
+ "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
24
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}),
25
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}),
26
+ }
27
+ }
28
+
29
+ RETURN_TYPES = ("MODEL", "CLIP")
30
+ OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
31
+ FUNCTION = "load_lora"
32
+
33
+ CATEGORY = "loaders"
34
+ DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
35
+ EXPERIMENTAL = True
36
+
37
+ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
38
+ if strength_model == 0 and strength_clip == 0:
39
+ return (model, clip)
40
+
41
+ lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
42
+ lora = None
43
+ if self.loaded_lora is not None:
44
+ if self.loaded_lora[0] == lora_path:
45
+ lora = self.loaded_lora[1]
46
+ else:
47
+ self.loaded_lora = None
48
+
49
+ if lora is None:
50
+ lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
51
+ self.loaded_lora = (lora_path, lora)
52
+
53
+ model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip)
54
+ return (model_lora, clip_lora)
55
+
56
+
57
+ class LoraLoaderBypassModelOnly(LoraLoaderBypass):
58
+ @classmethod
59
+ def INPUT_TYPES(s):
60
+ return {"required": { "model": ("MODEL",),
61
+ "lora_name": (folder_paths.get_filename_list("loras"), ),
62
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
63
+ }}
64
+ RETURN_TYPES = ("MODEL",)
65
+ FUNCTION = "load_lora_model_only"
66
+
67
+ def load_lora_model_only(self, model, lora_name, strength_model):
68
+ return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
69
+
70
+
71
+ NODE_CLASS_MAPPINGS = {
72
+ "LoraLoaderBypass": LoraLoaderBypass,
73
+ "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly,
74
+ }
75
+
76
+ NODE_DISPLAY_NAME_MAPPINGS = {
77
+ "LoraLoaderBypass": "Load LoRA (Bypass) (For debugging)",
78
+ "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only) (for debugging)",
79
+ }
ComfyUI/comfy_extras/nodes_lora_extract.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy.model_management
3
+ import comfy.utils
4
+ import folder_paths
5
+ import os
6
+ import logging
7
+ from enum import Enum
8
+ from typing_extensions import override
9
+ from comfy_api.latest import ComfyExtension, io
10
+ from tqdm.auto import trange
11
+
12
+ CLAMP_QUANTILE = 0.99
13
+
14
+ def extract_lora(diff, rank):
15
+ conv2d = (len(diff.shape) == 4)
16
+ kernel_size = None if not conv2d else diff.size()[2:4]
17
+ conv2d_3x3 = conv2d and kernel_size != (1, 1)
18
+ out_dim, in_dim = diff.size()[0:2]
19
+ rank = min(rank, in_dim, out_dim)
20
+
21
+ if conv2d:
22
+ if conv2d_3x3:
23
+ diff = diff.flatten(start_dim=1)
24
+ else:
25
+ diff = diff.squeeze()
26
+
27
+
28
+ U, S, Vh = torch.linalg.svd(diff.float())
29
+ U = U[:, :rank]
30
+ S = S[:rank]
31
+ U = U @ torch.diag(S)
32
+ Vh = Vh[:rank, :]
33
+
34
+ dist = torch.cat([U.flatten(), Vh.flatten()])
35
+ hi_val = torch.quantile(dist, CLAMP_QUANTILE)
36
+ low_val = -hi_val
37
+
38
+ U = U.clamp(low_val, hi_val)
39
+ Vh = Vh.clamp(low_val, hi_val)
40
+ if conv2d:
41
+ U = U.reshape(out_dim, rank, 1, 1)
42
+ Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
43
+ return (U, Vh)
44
+
45
+ class LORAType(Enum):
46
+ STANDARD = 0
47
+ FULL_DIFF = 1
48
+
49
+ LORA_TYPES = {"standard": LORAType.STANDARD,
50
+ "full_diff": LORAType.FULL_DIFF}
51
+
52
+ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
53
+ comfy.model_management.load_models_gpu([model_diff])
54
+ sd = model_diff.model_state_dict(filter_prefix=prefix_model)
55
+
56
+ sd_keys = list(sd.keys())
57
+ for index in trange(len(sd_keys), unit="weight"):
58
+ k = sd_keys[index]
59
+ op_keys = sd_keys[index].rsplit('.', 1)
60
+ if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
61
+ continue
62
+ op = comfy.utils.get_attr(model_diff.model, op_keys[0])
63
+ if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
64
+ weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
65
+ else:
66
+ weight_diff = sd[k]
67
+
68
+ if op_keys[1] == "weight":
69
+ if lora_type == LORAType.STANDARD:
70
+ if weight_diff.ndim < 2:
71
+ if bias_diff:
72
+ output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
73
+ continue
74
+ try:
75
+ out = extract_lora(weight_diff, rank)
76
+ output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu()
77
+ output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu()
78
+ except:
79
+ logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k))
80
+ elif lora_type == LORAType.FULL_DIFF:
81
+ output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
82
+
83
+ elif bias_diff and op_keys[1] == "bias":
84
+ output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
85
+ return output_sd
86
+
87
+ class LoraSave(io.ComfyNode):
88
+ @classmethod
89
+ def define_schema(cls):
90
+ return io.Schema(
91
+ node_id="LoraSave",
92
+ search_aliases=["export lora"],
93
+ display_name="Extract and Save Lora",
94
+ category="_for_testing",
95
+ inputs=[
96
+ io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
97
+ io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
98
+ io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys()), advanced=True),
99
+ io.Boolean.Input("bias_diff", default=True, advanced=True),
100
+ io.Model.Input(
101
+ "model_diff",
102
+ tooltip="The ModelSubtract output to be converted to a lora.",
103
+ optional=True,
104
+ ),
105
+ io.Clip.Input(
106
+ "text_encoder_diff",
107
+ tooltip="The CLIPSubtract output to be converted to a lora.",
108
+ optional=True,
109
+ ),
110
+ ],
111
+ is_experimental=True,
112
+ is_output_node=True,
113
+ )
114
+
115
+ @classmethod
116
+ def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
117
+ if model_diff is None and text_encoder_diff is None:
118
+ return io.NodeOutput()
119
+
120
+ lora_type = LORA_TYPES.get(lora_type)
121
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
122
+
123
+ output_sd = {}
124
+ if model_diff is not None:
125
+ output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff)
126
+ if text_encoder_diff is not None:
127
+ output_sd = calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff)
128
+
129
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
130
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
131
+
132
+ comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
133
+ return io.NodeOutput()
134
+
135
+
136
+ class LoraSaveExtension(ComfyExtension):
137
+ @override
138
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
139
+ return [
140
+ LoraSave,
141
+ ]
142
+
143
+
144
+ async def comfy_entrypoint() -> LoraSaveExtension:
145
+ return LoraSaveExtension()
ComfyUI/comfy_extras/nodes_lotus.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import override
2
+
3
+ import torch
4
+ import comfy.model_management as mm
5
+ from comfy_api.latest import ComfyExtension, io
6
+
7
+
8
+ class LotusConditioning(io.ComfyNode):
9
+ @classmethod
10
+ def define_schema(cls):
11
+ return io.Schema(
12
+ node_id="LotusConditioning",
13
+ category="conditioning/lotus",
14
+ inputs=[],
15
+ outputs=[io.Conditioning.Output(display_name="conditioning")],
16
+ )
17
+
18
+ @classmethod
19
+ def execute(cls) -> io.NodeOutput:
20
+ device = mm.get_torch_device()
21
+ #lotus uses a frozen encoder and null conditioning, i'm just inlining the results of that operation since it doesn't change
22
+ #and getting parity with the reference implementation would otherwise require inference and 800mb of tensors
23
+ prompt_embeds = torch.tensor([[[-0.3134765625, -0.447509765625, -0.00823974609375, -0.22802734375, 0.1785888671875, -0.2342529296875, -0.2188720703125, -0.0089111328125, -0.31396484375, 0.196533203125, -0.055877685546875, -0.3828125, -0.0965576171875, 0.0073394775390625, -0.284423828125, 0.07470703125, -0.086181640625, -0.211181640625, 0.0599365234375, 0.10693359375, 0.0007929801940917969, -0.78076171875, -0.382568359375, -0.1851806640625, -0.140625, -0.0936279296875, -0.1229248046875, -0.152099609375, -0.203857421875, -0.2349853515625, -0.2437744140625, -0.10858154296875, -0.08990478515625, 0.08892822265625, -0.2391357421875, -0.1611328125, -0.427978515625, -0.1336669921875, -0.27685546875, -0.1781005859375, -0.3857421875, 0.251953125, -0.055999755859375, -0.0712890625, -0.00130462646484375, 0.033477783203125, -0.26416015625, 0.07171630859375, -0.0090789794921875, -0.2025146484375, -0.2763671875, -0.09869384765625, -0.45751953125, -0.23095703125, 0.004528045654296875, -0.369140625, -0.366943359375, -0.205322265625, -0.1505126953125, -0.45166015625, -0.2059326171875, 0.0168609619140625, -0.305419921875, -0.150634765625, 0.02685546875, -0.609375, -0.019012451171875, 0.050445556640625, -0.0084381103515625, -0.31005859375, -0.184326171875, -0.15185546875, 0.06732177734375, 0.150390625, -0.10919189453125, -0.08837890625, -0.50537109375, -0.389892578125, -0.0294342041015625, -0.10491943359375, -0.187255859375, -0.43212890625, -0.328125, -1.060546875, 0.011871337890625, 0.04730224609375, -0.09521484375, -0.07452392578125, -0.29296875, -0.109130859375, -0.250244140625, -0.3828125, -0.171875, -0.03399658203125, -0.15478515625, -0.1861572265625, -0.2398681640625, 0.1053466796875, -0.22314453125, -0.1932373046875, -0.18798828125, -0.430419921875, -0.05364990234375, -0.474609375, -0.261474609375, -0.1077880859375, -0.439208984375, 0.08966064453125, -0.185302734375, -0.338134765625, -0.297119140625, -0.298583984375, -0.175537109375, -0.373291015625, -0.1397705078125, -0.260498046875, -0.383544921875, -0.09979248046875, -0.319580078125, -0.06884765625, -0.4365234375, -0.183837890625, -0.393310546875, -0.002277374267578125, 0.11236572265625, -0.260498046875, -0.2242431640625, -0.19384765625, -0.51123046875, 0.03216552734375, -0.048004150390625, -0.279052734375, -0.2978515625, -0.255615234375, 0.115478515625, -4.08984375, -0.1668701171875, -0.278076171875, -0.5712890625, -0.1385498046875, -0.244384765625, -0.41455078125, -0.244140625, -0.0677490234375, -0.141357421875, -0.11590576171875, -0.1439208984375, -0.0185394287109375, -2.490234375, -0.1549072265625, -0.2305908203125, -0.3828125, -0.1173095703125, -0.08258056640625, -0.1719970703125, -0.325439453125, -0.292724609375, -0.08154296875, -0.412353515625, -0.3115234375, -0.00832366943359375, 0.00489044189453125, -0.2236328125, -0.151123046875, -0.457275390625, -0.135009765625, -0.163330078125, -0.0819091796875, 0.06689453125, 0.0209197998046875, -0.11907958984375, -0.10369873046875, -0.2998046875, -0.478759765625, -0.07940673828125, -0.01517486572265625, -0.3017578125, -0.343994140625, -0.258544921875, -0.44775390625, -0.392822265625, -0.0255584716796875, -0.2998046875, 0.10833740234375, -0.271728515625, -0.36181640625, -0.255859375, -0.2056884765625, -0.055450439453125, 0.060516357421875, -0.45751953125, -0.2322998046875, -0.1737060546875, -0.40576171875, -0.2286376953125, -0.053070068359375, -0.0283660888671875, -0.1898193359375, -4.291534423828125e-05, -0.6591796875, -0.1717529296875, -0.479736328125, -0.1400146484375, -0.40771484375, 0.154296875, 0.003101348876953125, 0.00661468505859375, -0.2073974609375, -0.493408203125, 2.171875, -0.45361328125, -0.283935546875, -0.302001953125, -0.25146484375, -0.207275390625, -0.1524658203125, -0.72998046875, -0.08203125, 0.053192138671875, -0.2685546875, 0.1834716796875, -0.270263671875, -0.091552734375, -0.08319091796875, -0.1297607421875, -0.453857421875, 0.0687255859375, 0.0268096923828125, -0.16552734375, -0.4208984375, -0.1552734375, -0.057373046875, -0.300537109375, -0.04541015625, -0.486083984375, -0.2205810546875, -0.39013671875, 0.007488250732421875, -0.005329132080078125, -0.09759521484375, -0.1448974609375, -0.21923828125, -0.429443359375, -0.40087890625, -0.19384765625, -0.064453125, -0.0306243896484375, -0.045806884765625, -0.056793212890625, 0.119384765625, -0.2073974609375, -0.356201171875, -0.168212890625, -0.291748046875, -0.289794921875, -0.205322265625, -0.419677734375, -0.478271484375, -0.2037353515625, -0.368408203125, -0.186279296875, -0.427734375, -0.1756591796875, 0.07501220703125, -0.2457275390625, -0.03692626953125, 0.003997802734375, -5.7578125, -0.01052093505859375, -0.2305908203125, -0.2252197265625, -0.197509765625, -0.1566162109375, -0.1668701171875, -0.383056640625, -0.05413818359375, 0.12188720703125, -0.369873046875, -0.0184478759765625, -0.150146484375, -0.51123046875, -0.45947265625, -0.1561279296875, 0.060455322265625, 0.043487548828125, -0.1370849609375, -0.069091796875, -0.285888671875, -0.44482421875, -0.2374267578125, -0.2191162109375, -0.434814453125, -0.0360107421875, 0.1298828125, 0.0217742919921875, -0.51220703125, -0.13525390625, -0.09381103515625, -0.276611328125, -0.171875, -0.17138671875, -0.4443359375, -0.2178955078125, -0.269775390625, -0.38623046875, -0.31591796875, -0.42333984375, -0.280029296875, -0.255615234375, -0.17041015625, 0.06268310546875, -0.1878662109375, -0.00677490234375, -0.23583984375, -0.08795166015625, -0.2232666015625, -0.1719970703125, -0.484130859375, -0.328857421875, 0.04669189453125, -0.0419921875, -0.11114501953125, 0.02313232421875, -0.0033130645751953125, -0.6005859375, 0.09051513671875, -0.1884765625, -0.262939453125, -0.375732421875, -0.525390625, -0.1170654296875, -0.3779296875, -0.242919921875, -0.419921875, 0.0665283203125, -0.343017578125, 0.06658935546875, -0.346435546875, -0.1363525390625, -0.2000732421875, -0.3837890625, 0.028167724609375, 0.043853759765625, -0.0171051025390625, -0.477294921875, -0.107421875, -0.129150390625, -0.319580078125, -0.32177734375, -0.4951171875, -0.010589599609375, -0.1778564453125, -0.40234375, -0.0810546875, 0.03314208984375, -0.13720703125, -0.31591796875, -0.048248291015625, -0.274658203125, -0.0689697265625, -0.027130126953125, -0.0953369140625, 0.146728515625, -0.38671875, -0.025390625, -0.42333984375, -0.41748046875, -0.379638671875, -0.1978759765625, -0.533203125, -0.33544921875, 0.0694580078125, -0.322998046875, -0.1876220703125, 0.0094451904296875, 0.1839599609375, -0.254150390625, -0.30078125, -0.09228515625, -0.0885009765625, 0.12371826171875, 0.1500244140625, -0.12152099609375, -0.29833984375, 0.03924560546875, -0.1470947265625, -0.1610107421875, -0.2049560546875, -0.01708984375, -0.2470703125, -0.1522216796875, -0.25830078125, 0.10870361328125, -0.302490234375, -0.2376708984375, -0.360107421875, -0.443359375, -0.0784912109375, -0.63623046875, -0.0980224609375, -0.332275390625, -0.1749267578125, -0.30859375, -0.1968994140625, -0.250244140625, -0.447021484375, -0.18408203125, -0.006908416748046875, -0.2044677734375, -0.2548828125, -0.369140625, -0.11328125, -0.1103515625, -0.27783203125, -0.325439453125, 0.01381683349609375, 0.036773681640625, -0.1458740234375, -0.34619140625, -0.232177734375, -0.0562744140625, -0.4482421875, -0.21875, -0.0855712890625, -0.276123046875, -0.1544189453125, -0.223388671875, -0.259521484375, 0.0865478515625, -0.0038013458251953125, -0.340087890625, -0.076171875, -0.25341796875, -0.0007548332214355469, -0.060455322265625, -0.352294921875, 0.035736083984375, -0.2181396484375, -0.2318115234375, -0.1707763671875, 0.018646240234375, 0.093505859375, -0.197021484375, 0.033477783203125, -0.035247802734375, 0.0440673828125, -0.2056884765625, -0.040924072265625, -0.05865478515625, 0.056884765625, -0.08807373046875, -0.10845947265625, 0.09564208984375, -0.10888671875, -0.332275390625, -0.1119384765625, -0.115478515625, 13.0234375, 0.0030040740966796875, -0.53662109375, -0.1856689453125, -0.068115234375, -0.143798828125, -0.177978515625, -0.32666015625, -0.353515625, -0.1563720703125, -0.3203125, 0.0085906982421875, -0.1043701171875, -0.365478515625, -0.303466796875, -0.34326171875, -0.410888671875, -0.03790283203125, -0.11419677734375, -0.2939453125, 0.074462890625, -0.21826171875, 0.0242767333984375, -0.226318359375, -0.353515625, -0.177734375, -0.169189453125, -0.2423095703125, -0.12115478515625, -0.07843017578125, -0.341064453125, -0.2117919921875, -0.505859375, -0.544921875, -0.3935546875, -0.10772705078125, -0.2054443359375, -0.136474609375, -0.1796875, -0.396240234375, -0.1971435546875, -0.68408203125, -0.032684326171875, -0.03863525390625, -0.0709228515625, -0.1005859375, -0.156005859375, -0.3837890625, -0.319580078125, 0.11102294921875, -0.394287109375, 0.0799560546875, -0.50341796875, -0.1572265625, 0.004131317138671875, -0.12286376953125, -0.2347412109375, -0.29150390625, -0.10321044921875, -0.286376953125, 0.018798828125, -0.152099609375, -0.321044921875, 0.0191650390625, -0.11376953125, -0.54736328125, 0.15869140625, -0.257568359375, -0.2490234375, -0.3115234375, -0.09765625, -0.350830078125, -0.36376953125, -0.0771484375, -0.2298583984375, -0.30615234375, -0.052154541015625, -0.12091064453125, -0.40283203125, -0.1649169921875, 0.0206451416015625, -0.312744140625, -0.10308837890625, -0.50341796875, -0.1754150390625, -0.2003173828125, -0.173583984375, -0.204833984375, -0.1876220703125, -0.12176513671875, -0.06201171875, -0.03485107421875, -0.20068359375, -0.21484375, -0.246337890625, -0.006587982177734375, -0.09674072265625, -0.4658203125, -0.3994140625, -0.2210693359375, -0.09588623046875, -0.126220703125, -0.09222412109375, -0.145751953125, -0.217529296875, -0.289306640625, -0.28271484375, -0.1787109375, -0.169189453125, -0.359375, -0.21826171875, -0.043792724609375, -0.205322265625, -0.2900390625, -0.055419921875, -0.1490478515625, -0.340576171875, -0.045928955078125, -0.30517578125, -0.51123046875, -0.1046142578125, -0.349853515625, -0.10882568359375, -0.16748046875, -0.267333984375, -0.122314453125, -0.0985107421875, -0.3076171875, -0.1766357421875, -0.251708984375, 0.1964111328125, -0.2220458984375, -0.2349853515625, -0.035980224609375, -0.1749267578125, -0.237060546875, -0.480224609375, -0.240234375, -0.09539794921875, -0.2481689453125, -0.389404296875, -0.1748046875, -0.370849609375, -0.010650634765625, -0.147705078125, -0.0035457611083984375, -0.32568359375, -0.29931640625, -0.1395263671875, -0.28173828125, -0.09820556640625, -0.0176239013671875, -0.05926513671875, -0.0755615234375, -0.1746826171875, -0.283203125, -0.1617431640625, -0.4404296875, 0.046234130859375, -0.183837890625, -0.052032470703125, -0.24658203125, -0.11224365234375, -0.100830078125, -0.162841796875, -0.29736328125, -0.396484375, 0.11798095703125, -0.006496429443359375, -0.32568359375, -0.347900390625, -0.04595947265625, -0.09637451171875, -0.344970703125, -0.01166534423828125, -0.346435546875, -0.2861328125, -0.1845703125, -0.276611328125, -0.01312255859375, -0.395263671875, -0.50927734375, -0.1114501953125, -0.1861572265625, -0.2158203125, -0.1812744140625, 0.055419921875, -0.294189453125, 0.06500244140625, -0.1444091796875, -0.06365966796875, -0.18408203125, -0.0091705322265625, -0.1640625, -0.1856689453125, 0.090087890625, 0.024566650390625, -0.0195159912109375, -0.5546875, -0.301025390625, -0.438232421875, -0.072021484375, 0.030517578125, -0.1490478515625, 0.04888916015625, -0.23681640625, -0.1553955078125, -0.018096923828125, -0.229736328125, -0.2919921875, -0.355712890625, -0.285400390625, -0.1756591796875, -0.08355712890625, -0.416259765625, 0.022674560546875, -0.417236328125, 0.410400390625, -0.249755859375, 0.015625, -0.033599853515625, -0.040313720703125, -0.51708984375, -0.0518798828125, -0.08843994140625, -0.2022705078125, -0.3740234375, -0.285888671875, -0.176025390625, -0.292724609375, -0.369140625, -0.08367919921875, -0.356689453125, -0.38623046875, 0.06549072265625, 0.1669921875, -0.2099609375, -0.007434844970703125, 0.12890625, -0.0040740966796875, -0.2174072265625, -0.025115966796875, -0.2364501953125, -0.1695556640625, -0.0469970703125, -0.03924560546875, -0.36181640625, -0.047515869140625, -0.3154296875, -0.275634765625, -0.25634765625, -0.061920166015625, -0.12164306640625, -0.47314453125, -0.10784912109375, -0.74755859375, -0.13232421875, -0.32421875, -0.04998779296875, -0.286376953125, 0.10345458984375, -0.1710205078125, -0.388916015625, 0.12744140625, -0.3359375, -0.302490234375, -0.238525390625, -0.1455078125, -0.15869140625, -0.2427978515625, -0.0355224609375, -0.11944580078125, -0.31298828125, 0.11456298828125, -0.287841796875, -0.5439453125, -0.3076171875, -0.08642578125, -0.2408447265625, -0.283447265625, -0.428466796875, -0.085693359375, -0.1683349609375, 0.255126953125, 0.07635498046875, -0.38623046875, -0.2025146484375, -0.1331787109375, -0.10821533203125, -0.49951171875, 0.09130859375, -0.19677734375, -0.01904296875, -0.151123046875, -0.344482421875, -0.316650390625, -0.03900146484375, 0.1397705078125, 0.1334228515625, -0.037200927734375, -0.01861572265625, -0.1351318359375, -0.07037353515625, -0.380615234375, -0.34033203125, -0.06903076171875, 0.219970703125, 0.0132598876953125, -0.15869140625, -0.6376953125, 0.158935546875, -0.5283203125, -0.2320556640625, -0.185791015625, -0.2132568359375, -0.436767578125, -0.430908203125, -0.1763916015625, -0.0007672309875488281, -0.424072265625, -0.06719970703125, -0.347900390625, -0.14453125, -0.3056640625, -0.36474609375, -0.35986328125, -0.46240234375, -0.446044921875, -0.1905517578125, -0.1114501953125, -0.42919921875, -0.0643310546875, -0.3662109375, -0.4296875, -0.10968017578125, -0.2998046875, -0.1756591796875, -0.4052734375, -0.0841064453125, -0.252197265625, -0.047393798828125, 0.00434112548828125, -0.10040283203125, -0.271484375, -0.185302734375, -0.1910400390625, 0.10260009765625, 0.01393890380859375, -0.03350830078125, -0.33935546875, -0.329345703125, 0.0574951171875, -0.18896484375, -0.17724609375, -0.42919921875, -0.26708984375, -0.4189453125, -0.149169921875, -0.265625, -0.198974609375, -0.1722412109375, 0.1563720703125, -0.20947265625, -0.267822265625, -0.06353759765625, -0.365478515625, -0.340087890625, -0.3095703125, -0.320068359375, -0.0880126953125, -0.353759765625, -0.0005812644958496094, -0.1617431640625, -0.1866455078125, -0.201416015625, -0.181396484375, -0.2349853515625, -0.384765625, -0.5244140625, 0.01227569580078125, -0.21337890625, -0.30810546875, -0.17578125, -0.3037109375, -0.52978515625, -0.1561279296875, -0.296142578125, 0.057342529296875, -0.369384765625, -0.107666015625, -0.338623046875, -0.2060546875, -0.0213775634765625, -0.394775390625, -0.219482421875, -0.125732421875, -0.03997802734375, -0.42431640625, -0.134521484375, -0.2418212890625, -0.10504150390625, 0.1552734375, 0.1126708984375, -0.1427001953125, -0.133544921875, -0.111083984375, -0.375732421875, -0.2783203125, -0.036834716796875, -0.11053466796875, 0.2471923828125, -0.2529296875, -0.56494140625, -0.374755859375, -0.326416015625, 0.2137451171875, -0.09454345703125, -0.337158203125, -0.3359375, -0.34375, -0.0999755859375, -0.388671875, 0.0103302001953125, 0.14990234375, -0.2041015625, -0.39501953125, -0.39013671875, -0.1258544921875, 0.1453857421875, -0.250732421875, -0.06732177734375, -0.10638427734375, -0.032379150390625, -0.35888671875, -0.098876953125, -0.172607421875, 0.05126953125, -0.1956787109375, -0.183837890625, -0.37060546875, 0.1556396484375, -0.34375, -0.28662109375, -0.06982421875, -0.302490234375, -0.281005859375, -0.1640625, -0.5302734375, -0.1368408203125, -0.1268310546875, -0.35302734375, -0.1473388671875, -0.45556640625, -0.35986328125, -0.273681640625, -0.2249755859375, -0.1893310546875, 0.09356689453125, -0.248291015625, -0.197998046875, -0.3525390625, -0.30126953125, -0.228271484375, -0.2421875, -0.0906982421875, 0.227783203125, -0.296875, -0.009796142578125, -0.2939453125, -0.1021728515625, -0.215576171875, -0.267822265625, -0.052642822265625, 0.203369140625, -0.1417236328125, 0.18505859375, 0.12347412109375, -0.0972900390625, -0.54052734375, -0.430419921875, -0.0906982421875, -0.5419921875, -0.22900390625, -0.0625, -0.12152099609375, -0.495849609375, -0.206787109375, -0.025848388671875, 0.039031982421875, -0.453857421875, -0.318359375, -0.426025390625, -0.3701171875, -0.2169189453125, 0.0845947265625, -0.045654296875, 0.11090087890625, 0.0012454986572265625, 0.2066650390625, -0.046356201171875, -0.2337646484375, -0.295654296875, 0.057891845703125, -0.1639404296875, -0.0535888671875, -0.2607421875, -0.1488037109375, -0.16015625, -0.54345703125, -0.2305908203125, -0.55029296875, -0.178955078125, -0.222412109375, -0.0711669921875, -0.12298583984375, -0.119140625, -0.253662109375, -0.33984375, -0.11322021484375, -0.10723876953125, -0.205078125, -0.360595703125, 0.085205078125, -0.252197265625, -0.365966796875, -0.26953125, 0.2000732421875, -0.50634765625, 0.05706787109375, -0.3115234375, 0.0242919921875, -0.1689453125, -0.2401123046875, -0.3759765625, -0.2125244140625, 0.076416015625, -0.489013671875, -0.11749267578125, -0.55908203125, -0.313232421875, -0.572265625, -0.1387939453125, -0.037078857421875, -0.385498046875, 0.0323486328125, -0.39404296875, -0.05072021484375, -0.10430908203125, -0.10919189453125, -0.28759765625, -0.37451171875, -0.016937255859375, -0.2200927734375, -0.296875, -0.0286712646484375, -0.213134765625, 0.052001953125, -0.052337646484375, -0.253662109375, 0.07269287109375, -0.2498779296875, -0.150146484375, -0.09930419921875, -0.343505859375, 0.254150390625, -0.032440185546875, -0.296142578125], [1.4111328125, 0.00757598876953125, -0.428955078125, 0.089599609375, 0.0227813720703125, -0.0350341796875, -1.0986328125, 0.194091796875, 2.115234375, -0.75439453125, 0.269287109375, -0.73486328125, -1.1025390625, -0.050262451171875, -0.5830078125, 0.0268707275390625, -0.603515625, -0.6025390625, -1.1689453125, 0.25048828125, -0.4189453125, -0.5517578125, -0.30322265625, 0.7724609375, 0.931640625, -0.1422119140625, 2.27734375, -0.56591796875, 1.013671875, -0.9638671875, -0.66796875, -0.8125, 1.3740234375, -1.060546875, -1.029296875, -1.6796875, 0.62890625, 0.49365234375, 0.671875, 0.99755859375, -1.0185546875, -0.047027587890625, -0.374267578125, 0.2354736328125, 1.4970703125, -1.5673828125, 0.448974609375, 0.2078857421875, -1.060546875, -0.171875, -0.6201171875, -0.1607666015625, 0.7548828125, -0.58935546875, -0.2052001953125, 0.060791015625, 0.200439453125, 3.154296875, -3.87890625, 2.03515625, 1.126953125, 0.1640625, -1.8447265625, 0.002620697021484375, 0.7998046875, -0.337158203125, 0.47216796875, -0.5849609375, 0.9970703125, 0.3935546875, 1.22265625, -1.5048828125, -0.65673828125, 1.1474609375, -1.73046875, -1.8701171875, 1.529296875, -0.6787109375, -1.4453125, 1.556640625, -0.327392578125, 2.986328125, -0.146240234375, -2.83984375, 0.303466796875, -0.71728515625, -0.09698486328125, -0.2423095703125, 0.6767578125, -2.197265625, -0.86279296875, -0.53857421875, -1.2236328125, 1.669921875, -1.1689453125, -0.291259765625, -0.54736328125, -0.036346435546875, 1.041015625, -1.7265625, -0.6064453125, -0.1634521484375, 0.2381591796875, 0.65087890625, -1.169921875, 1.9208984375, 0.5634765625, 0.37841796875, 0.798828125, -1.021484375, -0.4091796875, 2.275390625, -0.302734375, -1.7783203125, 1.0458984375, 1.478515625, 0.708984375, -1.541015625, -0.0006041526794433594, 1.1884765625, 2.041015625, 0.560546875, -0.1131591796875, 1.0341796875, 0.06121826171875, 2.6796875, -0.53369140625, -1.2490234375, -0.7333984375, -1.017578125, -1.0078125, 1.3212890625, -0.47607421875, -1.4189453125, 0.54052734375, -0.796875, -0.73095703125, -1.412109375, -0.94873046875, -2.2734375, -1.1220703125, -1.3837890625, -0.5087890625, -1.0380859375, -0.93603515625, -0.58349609375, -1.0703125, -1.10546875, -2.60546875, 0.062225341796875, 0.38232421875, -0.411376953125, -0.369140625, -0.9833984375, -0.7294921875, -0.181396484375, -0.47216796875, -0.56884765625, -0.11041259765625, -2.673828125, 0.27783203125, -0.857421875, 0.9296875, 1.9580078125, 0.1385498046875, -1.91796875, -1.529296875, 0.53857421875, 0.509765625, -0.90380859375, -0.0947265625, -2.083984375, 0.9228515625, -0.28564453125, -0.80859375, -0.093505859375, -0.6015625, -1.255859375, 0.6533203125, 0.327880859375, -0.07598876953125, -0.22705078125, -0.30078125, -0.5185546875, -1.6044921875, 1.5927734375, 1.416015625, -0.91796875, -0.276611328125, -0.75830078125, -1.1689453125, -1.7421875, 1.0546875, -0.26513671875, -0.03314208984375, 0.278076171875, -1.337890625, 0.055023193359375, 0.10546875, -1.064453125, 1.048828125, -1.4052734375, -1.1240234375, -0.51416015625, -1.05859375, -1.7265625, -1.1328125, 0.43310546875, -2.576171875, -2.140625, -0.79345703125, 0.50146484375, 1.96484375, 0.98583984375, 0.337646484375, -0.77978515625, 0.85498046875, -0.65185546875, -0.484375, 2.708984375, 0.55810546875, -0.147216796875, -0.5537109375, -0.75439453125, -1.736328125, 1.1259765625, -1.095703125, -0.2587890625, 2.978515625, 0.335205078125, 0.357666015625, -0.09356689453125, 0.295654296875, -0.23779296875, 1.5751953125, 0.10400390625, 1.7001953125, -0.72900390625, -1.466796875, -0.2012939453125, 0.634765625, -0.1556396484375, -2.01171875, 0.32666015625, 0.047454833984375, -0.1671142578125, -0.78369140625, -0.994140625, 0.7802734375, -0.1429443359375, -0.115234375, 0.53271484375, -0.96142578125, -0.064208984375, 1.396484375, 1.654296875, -1.6015625, -0.77392578125, 0.276123046875, -0.42236328125, 0.8642578125, 0.533203125, 0.397216796875, -1.21484375, 0.392578125, -0.501953125, -0.231689453125, 1.474609375, 1.6669921875, 1.8662109375, -1.2998046875, 0.223876953125, -0.51318359375, -0.437744140625, -1.16796875, -0.7724609375, 1.6826171875, 0.62255859375, 2.189453125, -0.599609375, -0.65576171875, -1.1005859375, -0.45263671875, -0.292236328125, 2.58203125, -1.3779296875, 0.23486328125, -1.708984375, -1.4111328125, -0.5078125, -0.8525390625, -0.90771484375, 0.861328125, -2.22265625, -1.380859375, 0.7275390625, 0.85595703125, -0.77978515625, 2.044921875, -0.430908203125, 0.78857421875, -1.21484375, -0.09130859375, 0.5146484375, -1.92578125, -0.1396484375, 0.289306640625, 0.60498046875, 0.93896484375, -0.09295654296875, -0.45751953125, -0.986328125, -0.66259765625, 1.48046875, 0.274169921875, -0.267333984375, -1.3017578125, -1.3623046875, -1.982421875, -0.86083984375, -0.41259765625, -0.2939453125, -1.91015625, 1.6826171875, 0.437255859375, 1.0029296875, 0.376220703125, -0.010467529296875, -0.82861328125, -0.513671875, -3.134765625, 1.0205078125, -1.26171875, -1.009765625, 1.0869140625, -0.95703125, 0.0103759765625, 1.642578125, 0.78564453125, 1.029296875, 0.496826171875, 1.2880859375, 0.5234375, 0.05322265625, -0.206787109375, -0.79443359375, -1.1669921875, 0.049530029296875, -0.27978515625, 0.0237884521484375, -0.74169921875, -1.068359375, 0.86083984375, 1.1787109375, 0.91064453125, -0.453857421875, -1.822265625, -0.9228515625, -0.50048828125, 0.359130859375, 0.802734375, -1.3564453125, -0.322509765625, -1.1123046875, -1.0390625, -0.52685546875, -1.291015625, -0.343017578125, -1.2109375, -0.19091796875, 2.146484375, -0.04315185546875, -0.3701171875, -2.044921875, -0.429931640625, -0.56103515625, -0.166015625, -0.4658203125, -2.29296875, -1.078125, -1.0927734375, -0.1033935546875, -0.56103515625, -0.05743408203125, -1.986328125, -0.513671875, 0.70361328125, -2.484375, -1.3037109375, -1.6650390625, 0.4814453125, -0.84912109375, -2.697265625, -0.197998046875, 0.0869140625, -0.172607421875, -1.326171875, -1.197265625, 1.23828125, -0.38720703125, -0.075927734375, 0.02569580078125, -1.2119140625, 0.09027099609375, -2.12890625, -1.640625, -0.1524658203125, 0.2373046875, 1.37109375, 2.248046875, 1.4619140625, 0.3134765625, 0.50244140625, -0.1383056640625, -1.2705078125, 0.7353515625, 0.65771484375, -0.431396484375, -1.341796875, 0.10089111328125, 0.208984375, -0.0099945068359375, 0.83203125, 1.314453125, -0.422607421875, -1.58984375, -0.6044921875, 0.23681640625, -1.60546875, -0.61083984375, -1.5615234375, 1.62890625, -0.6728515625, -0.68212890625, -0.5224609375, -0.9150390625, -0.468994140625, 0.268310546875, 0.287353515625, -0.025543212890625, 0.443603515625, 1.62109375, -1.08984375, -0.5556640625, 1.03515625, -0.31298828125, -0.041778564453125, 0.260986328125, 0.34716796875, -2.326171875, 0.228271484375, -0.85107421875, -2.255859375, 0.3486328125, -0.25830078125, -0.3671875, -0.796875, -1.115234375, 1.8369140625, -0.19775390625, -1.236328125, -0.0447998046875, 0.69921875, 1.37890625, 1.11328125, 0.0928955078125, 0.6318359375, -0.62353515625, 0.55859375, -0.286865234375, 1.5361328125, -0.391357421875, -0.052215576171875, -1.12890625, 0.55517578125, -0.28515625, -0.3603515625, 0.68896484375, 0.67626953125, 0.003070831298828125, 1.2236328125, 0.1597900390625, -1.3076171875, 0.99951171875, -2.5078125, -1.2119140625, 0.1749267578125, -1.1865234375, -1.234375, -0.1180419921875, -1.751953125, 0.033050537109375, 0.234130859375, -3.107421875, -1.0380859375, 0.61181640625, -0.87548828125, 0.3154296875, -1.103515625, 0.261474609375, -1.130859375, -0.7470703125, -0.43408203125, 1.3828125, -0.41259765625, -1.7587890625, 0.765625, 0.004852294921875, 0.135498046875, -0.76953125, -0.1314697265625, 0.400390625, 1.43359375, 0.07135009765625, 0.0645751953125, -0.5869140625, -0.5810546875, -0.2900390625, -1.3037109375, 0.1287841796875, -0.27490234375, 0.59228515625, 2.333984375, -0.54541015625, -0.556640625, 0.447265625, -0.806640625, 0.09149169921875, -0.70654296875, -0.357177734375, -1.099609375, -0.5576171875, -0.44189453125, 0.400390625, -0.666015625, -1.4619140625, 0.728515625, -1.5986328125, 0.153076171875, -0.126708984375, -2.83984375, -1.84375, -0.2469482421875, 0.677734375, 0.43701171875, 3.298828125, 1.1591796875, -0.7158203125, -0.8251953125, 0.451171875, -2.376953125, -0.58642578125, -0.86767578125, 0.0789794921875, 0.1351318359375, -0.325439453125, 0.484375, 1.166015625, -0.1610107421875, -0.15234375, -0.54638671875, -0.806640625, 0.285400390625, 0.1661376953125, -0.50146484375, -1.0478515625, 1.5751953125, 0.0313720703125, 0.2396240234375, -0.6572265625, -0.1258544921875, -1.060546875, 1.3076171875, -0.301513671875, -1.2412109375, 0.6376953125, -1.5693359375, 0.354248046875, 0.2427978515625, -0.392333984375, 0.61962890625, -0.58837890625, -1.71484375, -0.2098388671875, -0.828125, 0.330810546875, 0.16357421875, -0.2259521484375, 0.0972900390625, -0.451416015625, 1.79296875, -1.673828125, -1.58203125, -2.099609375, -0.487548828125, -0.87060546875, 0.62646484375, -1.470703125, -0.1558837890625, 0.4609375, 1.3369140625, 0.2322998046875, 0.1632080078125, 0.65966796875, 1.0810546875, 0.1041259765625, 0.63232421875, -0.32421875, -1.04296875, -1.046875, -1.3720703125, -0.8486328125, 0.1290283203125, 0.137939453125, 0.1549072265625, -1.0908203125, 0.0167694091796875, -0.31689453125, 1.390625, 0.07269287109375, 1.0390625, 1.1162109375, -0.455810546875, -0.06689453125, -0.053741455078125, 0.5048828125, -0.8408203125, -1.19921875, 0.87841796875, 0.7421875, 0.2030029296875, 0.109619140625, -0.59912109375, -1.337890625, -0.74169921875, -0.64453125, -1.326171875, 0.21044921875, -1.3583984375, -1.685546875, -0.472900390625, -0.270263671875, 0.99365234375, -0.96240234375, 1.1279296875, -0.45947265625, -0.45654296875, -0.99169921875, -3.515625, -1.9853515625, 0.73681640625, 0.92333984375, -0.56201171875, -1.4453125, -2.078125, 0.94189453125, -1.333984375, 0.0982666015625, 0.60693359375, 0.367431640625, 3.015625, -1.1357421875, -1.5634765625, 0.90234375, -0.1783447265625, 0.1802978515625, -0.317138671875, -0.513671875, 1.2353515625, -0.033203125, 1.4482421875, 1.0087890625, 0.9248046875, 0.10418701171875, 0.7626953125, -1.3798828125, 0.276123046875, 0.55224609375, 1.1005859375, -0.62158203125, -0.806640625, 0.65087890625, 0.270263671875, -0.339111328125, -0.9384765625, -0.09381103515625, -0.7216796875, 1.37890625, -0.398193359375, -0.3095703125, -1.4912109375, 0.96630859375, 0.43798828125, 0.62255859375, 0.0213470458984375, 0.235595703125, -1.2958984375, 0.0157318115234375, -0.810546875, 1.9736328125, -0.2462158203125, 0.720703125, 0.822265625, -0.755859375, -0.658203125, 0.344482421875, -2.892578125, -0.282470703125, 1.2529296875, -0.294189453125, 0.6748046875, -0.80859375, 0.9287109375, 1.27734375, -1.71875, -0.166015625, 0.47412109375, -0.41259765625, -1.3681640625, -0.978515625, -0.77978515625, -1.044921875, -0.90380859375, -0.08184814453125, -0.86181640625, -0.10772705078125, -0.299560546875, -0.4306640625, -0.47119140625, 0.95703125, 1.107421875, 0.91796875, 0.76025390625, 0.7392578125, -0.09161376953125, -0.7392578125, 0.9716796875, -0.395751953125, -0.75390625, -0.164306640625, -0.087646484375, 0.028564453125, -0.91943359375, -0.66796875, 2.486328125, 0.427734375, 0.626953125, 0.474853515625, 0.0926513671875, 0.830078125, -0.6923828125, 0.7841796875, -0.89208984375, -2.482421875, 0.034912109375, -1.3447265625, -0.475341796875, -0.286376953125, -0.732421875, 0.190673828125, -0.491455078125, -3.091796875, -1.2783203125, -0.66015625, -0.1507568359375, 0.042236328125, -1.025390625, 0.12744140625, -1.984375, -0.393798828125, -1.25, -1.140625, 1.77734375, 0.2457275390625, -0.8017578125, 0.7763671875, -0.387939453125, -0.3662109375, 1.1572265625, 0.123291015625, -0.07135009765625, 1.412109375, -0.685546875, -3.078125, 0.031524658203125, -0.70458984375, 0.78759765625, 0.433837890625, -1.861328125, -1.33203125, 2.119140625, -1.3544921875, -0.6591796875, -1.4970703125, 0.40625, -2.078125, -1.30859375, 0.050262451171875, -0.60107421875, 1.0078125, 0.05657958984375, -0.96826171875, 0.0264892578125, 0.159912109375, 0.84033203125, -1.1494140625, -0.0433349609375, -0.2034912109375, 1.09765625, -1.142578125, -0.283203125, -0.427978515625, 1.0927734375, -0.67529296875, -0.61572265625, 2.517578125, 0.84130859375, 1.8662109375, 0.1748046875, -0.407958984375, -0.029449462890625, -0.27587890625, -0.958984375, -0.10028076171875, 1.248046875, -0.0792236328125, -0.45556640625, 0.7685546875, 1.5556640625, -1.8759765625, -0.131591796875, -1.3583984375, 0.7890625, 0.80810546875, -1.0322265625, -0.53076171875, -0.1484375, -1.7841796875, -1.2470703125, 0.17138671875, -0.04864501953125, -0.80322265625, -0.0933837890625, 0.984375, 0.7001953125, 0.5380859375, 0.2022705078125, -1.1865234375, 0.5439453125, 1.1318359375, 0.79931640625, 0.32666015625, -1.26171875, 0.457763671875, 1.1591796875, -0.34423828125, 0.65771484375, 0.216552734375, 1.19140625, -0.2744140625, -0.020416259765625, -0.86376953125, 0.93017578125, 1.0556640625, 0.69873046875, -0.15087890625, -0.33056640625, 0.8505859375, 0.06890869140625, 0.359375, -0.262939453125, 0.12493896484375, 0.017059326171875, -0.98974609375, 0.5107421875, 0.2408447265625, 0.615234375, -0.62890625, 0.86962890625, -0.07427978515625, 0.85595703125, 0.300537109375, -1.072265625, -1.6064453125, -0.353515625, -0.484130859375, -0.6044921875, -0.455810546875, 0.95849609375, 1.3671875, 0.544921875, 0.560546875, 0.34521484375, -0.6513671875, -0.410400390625, -0.2021484375, -0.1656494140625, 0.073486328125, 0.84716796875, -1.7998046875, -1.0126953125, -0.1324462890625, 0.95849609375, -0.669921875, -0.79052734375, -2.193359375, -0.42529296875, -1.7275390625, -1.04296875, 0.716796875, -0.4423828125, -1.193359375, 0.61572265625, -1.5224609375, 0.62890625, -0.705078125, 0.677734375, -0.213134765625, -1.6748046875, -1.087890625, -0.65185546875, -1.1337890625, 2.314453125, -0.352783203125, -0.27001953125, -2.01953125, -1.2685546875, 0.308837890625, -0.280517578125, -1.3798828125, -1.595703125, 0.642578125, 1.693359375, -0.82470703125, -1.255859375, 0.57373046875, 1.5859375, 1.068359375, -0.876953125, 0.370849609375, 1.220703125, 0.59765625, 0.007602691650390625, 0.09326171875, -0.9521484375, -0.024932861328125, -0.94775390625, -0.299560546875, -0.002536773681640625, 1.41796875, -0.06903076171875, -1.5927734375, 0.353515625, 3.63671875, -0.765625, -1.1142578125, 0.4287109375, -0.86865234375, -0.9267578125, -0.21826171875, -1.10546875, 0.29296875, -0.225830078125, 0.5400390625, -0.45556640625, -0.68701171875, -0.79150390625, -1.0810546875, 0.25439453125, -1.2998046875, -0.494140625, -0.1510009765625, 1.5615234375, -0.4248046875, -0.486572265625, 0.45458984375, 0.047637939453125, -0.11639404296875, 0.057403564453125, 0.130126953125, -0.10125732421875, -0.56201171875, 1.4765625, -1.7451171875, 1.34765625, -0.45703125, 0.873046875, -0.056121826171875, -0.8876953125, -0.986328125, 1.5654296875, 0.49853515625, 0.55859375, -0.2198486328125, 0.62548828125, 0.2734375, -0.63671875, -0.41259765625, -1.2705078125, 0.0665283203125, 1.3369140625, 0.90283203125, -0.77685546875, -1.5, -1.8525390625, -1.314453125, -0.86767578125, -0.331787109375, 0.1590576171875, 0.94775390625, -0.1771240234375, 1.638671875, -2.17578125, 0.58740234375, 0.424560546875, -0.3466796875, 0.642578125, 0.473388671875, 0.96435546875, 1.38671875, -0.91357421875, 1.0361328125, -0.67333984375, 1.5009765625]]]).to(device)
24
+
25
+ cond = [[prompt_embeds, {}]]
26
+
27
+ return io.NodeOutput(cond)
28
+
29
+
30
+ class LotusExtension(ComfyExtension):
31
+ @override
32
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
33
+ return [
34
+ LotusConditioning,
35
+ ]
36
+
37
+
38
+ async def comfy_entrypoint() -> LotusExtension:
39
+ return LotusExtension()