| import torch | |
| from tqdm import trange | |
| from comfy.samplers import KSAMPLER | |
@torch.no_grad()
def sample_inverse(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Step ``x`` through the ``sigmas`` schedule with first-order updates.

    For each consecutive pair ``(sigmas[i], sigmas[i + 1])`` the model is
    evaluated at ``sigmas[i]`` and its output, scaled by the sigma
    difference, is added to ``x``.

    Runs under ``torch.no_grad()`` so that no autograd graph is recorded
    across the steps.

    Args:
        model: Callable invoked as ``model(x, sigma_batch, **extra_args)``.
            NOTE(review): its output is added to ``x`` scaled by the step
            size, so it is presumably a velocity-like prediction rather
            than a denoised sample -- confirm against the model wrapper
            in use.
        x: Starting tensor; ``x.shape[0]`` is used as the batch size.
        sigmas: Indexable schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: Optional dict of keyword arguments forwarded to
            ``model``. ``None`` is treated as an empty dict.
        callback: Optional callable receiving a dict with the keys ``x``,
            ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` after each
            model evaluation, before ``x`` is updated.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        s_churn, s_tmin, s_tmax, s_noise: Accepted but not used by this
            function.

    Returns:
        The tensor obtained after the last step (``x`` itself if
        ``sigmas`` has fewer than two entries).
    """
    extra_args = {} if extra_args is None else extra_args
    # One entry per batch element so a sigma value can be broadcast to the batch.
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        model_output = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            # 'sigma' and 'sigma_hat' carry the same value here; both keys
            # are kept so existing callbacks keep working.
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': model_output})
        dt = sigmas[i + 1] - sigma
        x = x + model_output * dt
    return x
class FluxInverseSamplerNode:
    """Node that wraps ``sample_inverse`` in a ``KSAMPLER`` object."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input specification; it declares no inputs.

        Declared as a classmethod so it can be called on the class itself,
        without creating an instance.
        """
        return {
            "required": {},
            "optional": {},
        }

    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"
    CATEGORY = "flux"

    def build(self):
        """Build the sampler.

        Returns:
            A tuple whose first element is ``KSAMPLER(sample_inverse)``.
        """
        sampler = KSAMPLER(sample_inverse)
        # NOTE(review): the trailing tensor has no matching entry in
        # RETURN_TYPES; it is kept so the returned tuple is unchanged for
        # existing callers -- confirm whether it can be dropped.
        return (sampler, torch.Tensor([0]))