File size: 19,397 Bytes
ca2a3d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
import torch, struct, json
from io import BytesIO
import latent_preview, comfy
from server import PromptServer
from comfy.model_base import SDXL, SVD_img2vid, Flux, WAN21, Chroma
from comfy import samplers
import numpy as np
from math import ceil
from latent_preview import TAESDPreviewerImpl
from comfy_execution.utils import get_executing_context

def slerp(val, low, high):
    """Spherically interpolate between ``low`` and ``high``.

    ``val`` is the blend weight: 0 yields ``low`` and 1 yields ``high``.
    Directions are compared along dim 1 (the norm and the dot product both reduce over that axis),
    so the inputs need at least two dimensions.

    When the inputs are, on average, almost parallel the spherical formula would divide by a
    near-zero sine, so a plain linear blend is returned instead.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)
    if dot.mean() > 0.9995:
        # Weighting must match the spherical branch below (val=0 -> low, val=1 -> high).
        return low * (1 - val) + high * val
    omega = torch.acos(dot)
    so = torch.sin(omega)
    low_weight = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1)
    high_weight = (torch.sin(val * omega) / so).unsqueeze(1)
    return low_weight * low + high_weight * high

def swarm_partial_noise(seed, latent_image):
    """Return CPU gaussian noise shaped like ``latent_image``, drawn after seeding torch's default generator with ``seed``.

    Note that ``torch.manual_seed`` reseeds the global default generator as a side effect.
    """
    rng = torch.manual_seed(seed)
    return torch.randn(
        latent_image.size(),
        generator=rng,
        dtype=latent_image.dtype,
        layout=latent_image.layout,
        device="cpu",
    )

def swarm_fixed_noise(seed, latent_image, var_seed, var_seed_strength):
    """Build per-batch-item noise for ``latent_image`` and stack it along a new leading dim.

    Without variation (``var_seed_strength`` <= 0) item ``i`` is seeded with ``seed + i``.
    With variation, every item starts from the ``seed`` noise and is slerped towards noise
    seeded with ``var_seed + i`` by ``var_seed_strength``.
    """
    use_variation = var_seed_strength > 0
    batch_count = latent_image.size()[0]
    per_item = []
    for index in range(batch_count):
        latent = latent_image[index]
        if not use_variation:
            per_item.append(swarm_partial_noise(seed + index, latent))
            continue
        base = swarm_partial_noise(seed, latent)
        variation = swarm_partial_noise(var_seed + index, latent)
        if base.ndim != 4:
            base = slerp(var_seed_strength, base, variation)
        else:
            # Video latents are B C F H W; the batch dim is already peeled off, so blend frame by frame.
            for frame in range(base.shape[1]):
                base[:, frame] = slerp(var_seed_strength, base[:, frame], variation[:, frame])
        per_item.append(base)
    return torch.stack(per_item, dim=0)

def get_preview_metadata():
    """Return the metadata dict attached to preview packets.

    Prompt and node ids come from the current executing context when it provides them,
    otherwise from the last ids recorded on the ``PromptServer`` instance.
    """
    context = get_executing_context()
    prompt_id = context.prompt_id if context is not None else None
    node_id = context.node_id if context is not None else None
    if prompt_id is None:
        prompt_id = PromptServer.instance.last_prompt_id
    if node_id is None:
        node_id = PromptServer.instance.last_node_id
    # display_node_id, parent_node_id, real_node_id? comfy_execution/progress.py has this.
    return {
        "node_id": node_id,
        "prompt_id": prompt_id,
        "display_node_id": node_id,
        "parent_node_id": node_id,
        "real_node_id": node_id,
    }

# Event type passed to PromptServer.send_sync for every Swarm preview packet.
_SWARM_PREVIEW_EVENT = 9999123

def _swarm_send_preview_packet(id, mime_type, image_bytes):
    """Frame already-encoded image bytes with preview metadata and send them to the current client.

    Packet layout: 4-byte big-endian length of the JSON metadata, the UTF-8 JSON metadata,
    then the raw image bytes.
    """
    server = PromptServer.instance
    metadata = get_preview_metadata()
    metadata["mime_type"] = mime_type
    metadata["id"] = id
    metadata_json = json.dumps(metadata).encode('utf-8')
    packet = bytearray(struct.pack(">I", len(metadata_json)))
    packet.extend(metadata_json)
    packet.extend(image_bytes)
    server.send_sync(_SWARM_PREVIEW_EVENT, packet, sid=server.client_id)

def swarm_send_extra_preview(id, image):
    """Encode ``image`` as JPEG (quality 90) and send it as preview number ``id``.

    ``image`` must expose a PIL-style ``save(fp, format=..., ...)`` method.
    """
    buffer = BytesIO()
    # compress_level is a PNG option and has no effect on JPEG output, so it is not passed.
    image.save(buffer, format="JPEG", quality=90)
    _swarm_send_preview_packet(id, "image/jpeg", buffer.getvalue())

def swarm_send_animated_preview(id, images):
    """Encode ``images`` as one animated lossy WEBP (6 frames per second) and send it as preview number ``id``.

    Does nothing when ``images`` is empty, since there is no frame to encode.
    """
    if not images:
        return
    buffer = BytesIO()
    images[0].save(buffer, save_all=True, duration=int(1000.0/6), append_images=images[1:], lossless=False, quality=60, method=0, format='WEBP')
    _swarm_send_preview_packet(id, "image/webp", buffer.getvalue())

def calculate_sigmas_scheduler(model, scheduler_name, steps, sigma_min, sigma_max, rho):
    """Compute a "karras" or "exponential" sigma schedule, or return None for any other scheduler name.

    A negative ``sigma_min`` / ``sigma_max`` means "use the value from the model's ``model_sampling`` object".
    ``rho`` is only used by the karras schedule.
    """
    model_sampling = model.get_model_object("model_sampling")
    if scheduler_name not in ("karras", "exponential"):
        return None
    low = sigma_min if sigma_min >= 0 else float(model_sampling.sigma_min)
    high = sigma_max if sigma_max >= 0 else float(model_sampling.sigma_max)
    k_sampling = comfy.k_diffusion.sampling
    if scheduler_name == "karras":
        return k_sampling.get_sigmas_karras(n=steps, sigma_min=low, sigma_max=high, rho=rho)
    return k_sampling.get_sigmas_exponential(n=steps, sigma_min=low, sigma_max=high)

def make_swarm_sampler_callback(steps, device, model, previews):
    """Create the per-step sampler callback that advances the progress bar and emits previews.

    ``previews`` selects the preview mode: "none" disables previews, "default" sends one preview
    per latent, "one" / "second" send a single fixed latent, "iterate" cycles through the latents
    by step, and "animate" sends all latents as one animation.
    """
    previewer = None
    if previews != "none":
        previewer = latent_preview.get_previewer(device, model.model.latent_format)
    pbar = comfy.utils.ProgressBar(steps)

    def callback(step, x0, x, total_steps):
        pbar.update_absolute(step + 1, total_steps, None)
        if not previewer:
            return
        is_early_step = step == 0 or (step < 3 and x0.ndim == 5 and x0.shape[1] > 8)
        if is_early_step and not isinstance(previewer, TAESDPreviewerImpl):
            # Sync copy to CPU for the first few steps to prevent reading old data (more steps for videos).
            # Later steps leave comfy free to do its async non-blocking transfer.
            x0 = x0.clone().cpu()
        if x0.ndim == 5:
            # Video shape is [batch, channels, backwards time, width, height]; previews need
            # [forwards time, channels, width, height].
            x0 = torch.flip(x0[0].permute(1, 0, 2, 3), [0])
        count = x0.shape[0]

        def decode(index):
            return previewer.decode_latent_to_preview_image("JPEG", x0[index:index+1])[1]

        if previews == "iterate":
            swarm_send_extra_preview(0, decode(step % count))
        elif previews == "animate":
            if count == 1:
                swarm_send_extra_preview(0, decode(0))
            else:
                swarm_send_animated_preview(0, [decode(i) for i in range(count)])
        elif previews == "default":
            for i in range(count):
                swarm_send_extra_preview(i, decode(i))
        elif previews == "one":
            swarm_send_extra_preview(0, decode(0))
        elif previews == "second":
            swarm_send_extra_preview(0, decode(1 % count))

    return callback


def loglinear_interp(t_steps, num_steps):
    """Resample a decreasing sequence to ``num_steps`` points, interpolating linearly in log space.

    Returns a contiguous numpy array, still in decreasing order.
    """
    # np.interp needs increasing sample values, so work on the reversed (ascending) logs.
    log_ascending = np.log(t_steps[::-1])
    source_positions = np.linspace(0, 1, len(t_steps))
    target_positions = np.linspace(0, 1, num_steps)
    resampled_logs = np.interp(target_positions, source_positions, log_ascending)
    return np.exp(resampled_logs)[::-1].copy()

# Reference sigma tables for the "align_your_steps" scheduler, keyed by model family.
# Each list is in decreasing order. SwarmKSampler.sample picks a table from the model class,
# resamples it with loglinear_interp when steps + 1 differs from the table length,
# and overwrites the final entry with 0 before use.
AYS_NOISE_LEVELS = {
    "SD1": [14.6146412293, 6.4745760956,  3.8636745985,  2.6946151520, 1.8841921177,  1.3943805092,  0.9642583904,  0.6523686016, 0.3977456272,  0.1515232662,  0.0291671582],
    "SDXL":[14.6146412293, 6.3184485287,  3.7681790315,  2.1811480769, 1.3405244945,  0.8620721141,  0.5550693289,  0.3798540708, 0.2332364134,  0.1114188177,  0.0291671582],
    "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002],
    # Flux and Wan from https://github.com/comfyanonymous/ComfyUI/pull/7584
    "Flux": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001],
    "Wan": [1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001],
    # https://github.com/comfyanonymous/ComfyUI/commit/08ff5fa08a92e0b3f23b9abec979a830a6cffb03#diff-3e4e70e402dcd9e1070ad71ef9292277f10d9faccf36a1c405c0c717a7ee6485R23
    "Chroma": [0.992, 0.99, 0.988, 0.985, 0.982, 0.978, 0.973, 0.968, 0.961, 0.953, 0.943, 0.931, 0.917, 0.9, 0.881, 0.858, 0.832, 0.802, 0.769, 0.731, 0.69, 0.646, 0.599, 0.55, 0.501, 0.451, 0.402, 0.355, 0.311, 0.27, 0.232, 0.199, 0.169, 0.143, 0.12, 0.101, 0.084, 0.07, 0.058, 0.048, 0.001]
}

def _latent_tile_starts(extent, latent_tile_size, min_overlap=32):
    """Return the start offsets of the tiles covering one axis of length ``extent``.

    An axis that fits inside a single tile gets exactly one tile at offset 0. Otherwise tiles are
    spread evenly so the last one ends at ``extent``, with at least ``min_overlap`` latent pixels
    shared between neighbours.
    """
    if extent <= latent_tile_size:
        # One tile already covers the axis; a second tile would sit at the same offset.
        return [0]
    count = ceil(extent / latent_tile_size)
    if extent % latent_tile_size == 0:
        # An exact multiple would tile with zero overlap, so add a tile to create some.
        count += 1
    overlap = (count * latent_tile_size - extent) / (count - 1)
    if overlap < min_overlap:
        count += 1
        overlap = (count * latent_tile_size - extent) / (count - 1)
    # round() corrects for float precision in the fractional overlap.
    return [round(i * latent_tile_size - i * overlap) for i in range(count)]

def split_latent_tensor(latent_tensor, tile_size=1024, scale_factor=8):
    """Cut a latent tensor into overlapping square tiles along its last two dims.

    ``tile_size`` is in image pixels and is divided by ``scale_factor`` to get the latent tile size.
    Returns a row-major list of ``((left, upper, right, lower), tile)`` pairs in latent coordinates.
    The coordinates always describe a full-size tile, even when the tensor is smaller than one
    tile and the cropped tile is therefore smaller.

    Raises ValueError when ``tile_size`` is smaller than ``scale_factor``.
    """
    latent_tile_size = tile_size // scale_factor  # Adjust tile size for latent space
    if latent_tile_size <= 0:
        raise ValueError(f"tile_size {tile_size} is too small for scale_factor {scale_factor}")
    height, width = latent_tensor.shape[-2:]
    x_starts = _latent_tile_starts(width, latent_tile_size)
    y_starts = _latent_tile_starts(height, latent_tile_size)

    tiles = []
    for y_start in y_starts:
        for x_start in x_starts:
            x_end = x_start + latent_tile_size
            y_end = y_start + latent_tile_size
            tile_tensor = latent_tensor[..., y_start:y_end, x_start:x_end]
            tiles.append(((x_start, y_start, x_end, y_end), tile_tensor))
    return tiles

def stitch_latent_tensors(original_size, tiles, scale_factor=8):
    """Blend overlapping latent tiles back into a single tensor of shape ``original_size``.

    ``tiles`` is a list of ``((left, upper, right, lower), tile)`` pairs in latent coordinates.
    Tiles are pasted row by row; every tile except the first in its row fades in from the left,
    and every tile whose ``upper`` is not 0 fades in from the top, over ``tile_width // 8`` pixels.
    The result uses the dtype and device of the first tile. ``scale_factor`` is not used.
    """
    if tiles:
        # Allocate on the tiles' own dtype/device so blending neither downcasts nor mixes devices.
        reference = tiles[0][1]
        result = torch.zeros(original_size, dtype=reference.dtype, device=reference.device)
    else:
        result = torch.zeros(original_size)

    # Tiles come in the format [(coordinates, tile), ...]; sort by upper then left.
    sorted_tiles = sorted(tiles, key=lambda x: (x[0][1], x[0][0]))

    current_row_upper = None
    for (left, upper, right, lower), tile in sorted_tiles:
        # A change in the upper coordinate marks the start of a new row.
        first_tile_in_row = current_row_upper != upper
        current_row_upper = upper

        feather = (right - left) // 8
        # Linear fade-in weights 1/feather, 2/feather, ..., 1.
        ramp = torch.tensor([(1.0 / feather) * (t + 1) for t in range(feather)], dtype=tile.dtype, device=tile.device)
        mask = torch.ones_like(tile)

        if not first_tile_in_row:  # Left feathering for tiles other than the first in the row
            cols = min(feather, mask.shape[-1])
            mask[..., :, :cols] *= ramp[:cols]

        if upper != 0:  # Top feathering for all tiles except the first row
            rows = min(feather, mask.shape[-2])
            mask[..., :rows, :] *= ramp[:rows].unsqueeze(-1)

        # Blend the tile over whatever has already been written underneath it.
        existing = result[..., upper:lower, left:right]
        result[..., upper:lower, left:right] = tile * mask + existing * (1.0 - mask)

    return result

class SwarmKSampler:
    """Sampler node with extra Swarm inputs: variation seed, sigma overrides, extra schedulers, previews and tiled sampling."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.5, "round": 0.001}),
                "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
                "scheduler": (["turbo", "align_your_steps", "ltxv", "ltxv-image"] + comfy.samplers.KSampler.SCHEDULERS, ),
                "positive": ("CONDITIONING", ),
                "negative": ("CONDITIONING", ),
                "latent_image": ("LATENT", ),
                "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                "var_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "var_seed_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.05, "round": 0.001}),
                "sigma_max": ("FLOAT", {"default": -1, "min": -1.0, "max": 1000.0, "step":0.01, "round": False}),
                "sigma_min": ("FLOAT", {"default": -1, "min": -1.0, "max": 1000.0, "step":0.01, "round": False}),
                "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
                "add_noise": (["enable", "disable"], ),
                "return_with_leftover_noise": (["disable", "enable"], ),
                "previews": (["default", "none", "one", "second", "iterate", "animate"], ),
                "tile_sample": ("BOOLEAN", {"default": False}),
                "tile_size": ("INT", {"default": 1024, "min": 256, "max": 4096}),
            }
        }

    CATEGORY = "SwarmUI/sampling"
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "run_sampling"
    DESCRIPTION = "Works like a vanilla Comfy KSamplerAdvanced, but with extra inputs for advanced features such as sigma scale, tiling, previews, etc."

    def sample(self, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, var_seed, var_seed_strength, sigma_max, sigma_min, rho, add_noise, return_with_leftover_noise, previews):
        """Sample ``latent_image`` once and return ``(latent_dict,)``.

        The returned dict is a shallow copy of ``latent_image`` with "samples" replaced by the
        sampler output. Custom sigmas are built for the "turbo", "ltxv" / "ltxv-image" and
        "align_your_steps" schedulers, and for "karras" / "exponential" when both ``sigma_min``
        and ``sigma_max`` are non-negative; otherwise ``sigmas`` stays None and comfy's own
        scheduler handling applies.
        """
        device = comfy.model_management.get_torch_device()
        latent_samples = latent_image["samples"]
        latent_samples = comfy.sample.fix_empty_latent_channels(model, latent_samples)
        disable_noise = add_noise == "disable"

        if disable_noise:
            noise = torch.zeros(latent_samples.size(), dtype=latent_samples.dtype, layout=latent_samples.layout, device="cpu")
        else:
            noise = swarm_fixed_noise(noise_seed, latent_samples, var_seed, var_seed_strength)

        noise_mask = latent_image.get("noise_mask")

        sigmas = None
        if scheduler == "turbo":
            timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
            sigmas = model.model.model_sampling.sigma(timesteps)
            sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        elif scheduler in ("ltx", "ltxv", "ltxv-image"):
            # The scheduler list offers "ltxv"; the older "ltx" spelling is still accepted.
            from comfy_extras.nodes_lt import LTXVScheduler
            sigmas = LTXVScheduler().get_sigmas(steps, 2.05, 0.95, True, 0.1, latent_image if scheduler == "ltxv-image" else None)[0]
        elif scheduler == "align_your_steps":
            # Chroma is tested before Flux so that it still matches if comfy's Chroma class
            # derives from Flux (TODO confirm against comfy.model_base).
            ays_model_types = ((SDXL, "SDXL"), (SVD_img2vid, "SVD"), (Chroma, "Chroma"), (Flux, "Flux"), (WAN21, "Wan"))
            model_type = next((name for model_class, name in ays_model_types if isinstance(model.model, model_class)), None)
            if model_type is None:
                print(f"AlignYourSteps: Unknown model type: {type(model.model)}, defaulting to SD1")
                model_type = "SD1"
            sigmas = AYS_NOISE_LEVELS[model_type][:]
            if (steps + 1) != len(sigmas):
                sigmas = loglinear_interp(sigmas, steps + 1)
            sigmas[-1] = 0
            sigmas = torch.FloatTensor(sigmas)
        elif sigma_min >= 0 and sigma_max >= 0 and scheduler in ["karras", "exponential"]:
            if sampler_name in ['dpm_2', 'dpm_2_ancestral']:
                # These samplers get one extra sigma, with the second-to-last one dropped.
                sigmas = calculate_sigmas_scheduler(model, scheduler, steps + 1, sigma_min, sigma_max, rho)
                sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
            else:
                sigmas = calculate_sigmas_scheduler(model, scheduler, steps, sigma_min, sigma_max, rho)
            sigmas = sigmas.to(device)

        out = latent_image.copy()
        if steps > 0:
            callback = make_swarm_sampler_callback(steps, device, model, previews)

            samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_samples,
                                    denoise=1.0, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step,
                                    force_full_denoise=return_with_leftover_noise == "disable", noise_mask=noise_mask, sigmas=sigmas, callback=callback, seed=noise_seed)
            out["samples"] = samples
        return (out, )

    def tiled_sample(self, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, var_seed, var_seed_strength, sigma_max, sigma_min, rho, add_noise, return_with_leftover_noise, previews, tile_size):
        """Tiled version of ``sample``: split the latent into tiles, sample each one, and stitch the results.

        Each tile is passed to ``sample`` as a bare ``{"samples": tile}`` dict, so any other
        entries of ``latent_image`` (such as "noise_mask") are not applied per tile.
        """
        out = latent_image.copy()
        latent_samples = latent_image["samples"]
        tiles = split_latent_tensor(latent_samples, tile_size=tile_size)
        resampled_tiles = []
        for coords, tile in tiles:
            resampled_tile = self.sample(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, {"samples": tile}, start_at_step, end_at_step, var_seed, var_seed_strength, sigma_max, sigma_min, rho, add_noise, return_with_leftover_noise, previews)
            resampled_tiles.append((coords, resampled_tile[0]["samples"]))
        out["samples"] = stitch_latent_tensors(latent_samples.shape, resampled_tiles)
        return (out,)

    def run_sampling(self, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, var_seed, var_seed_strength, sigma_max, sigma_min, rho, add_noise, return_with_leftover_noise, previews, tile_sample,  tile_size):
        """Node entry point: dispatch to ``tiled_sample`` when ``tile_sample`` is set, otherwise to ``sample``."""
        common_args = (model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, var_seed, var_seed_strength, sigma_max, sigma_min, rho, add_noise, return_with_leftover_noise, previews)
        if tile_sample:
            return self.tiled_sample(*common_args, tile_size)
        return self.sample(*common_args)

# Maps the node name "SwarmKSampler" to the class that implements it.
# NOTE(review): presumably read by ComfyUI's custom-node loader to register the node — confirm.
NODE_CLASS_MAPPINGS = {
    "SwarmKSampler": SwarmKSampler,
}