# Hugging Face file-view header (page residue, not part of the module):
# uploaded by oimoyu, commit "init" 9ab8b5f (verified), file size 6.1 kB.
import torch
import numpy as np
from comfy.model_patcher import ModelPatcher
from . import shared, rng_philox
class TorchHijack:
    """Drop-in stand-in for the ``torch`` module whose ``randn_like`` uses a fixed generator.

    Some k-diffusion samplers take a noise-sampler argument and some call
    ``torch.randn_like`` directly, so swapping the module-level ``torch`` for
    an instance of this class is what routes *every* noise draw through
    ``generator``. Drawing from a dedicated generator is what lets an image
    produced inside a batch match the same image produced on its own.

    Any attribute other than ``randn_like`` is forwarded to the real
    ``torch`` module; unknown names raise ``AttributeError``.
    """

    def __init__(self, generator, randn_source, init=True):
        self.generator = generator
        self.randn_source = randn_source
        # Stored flag; not read anywhere in this class.
        self.init = init

    def __getattr__(self, item):
        # Python only calls this hook after normal lookup has failed.
        if item == 'randn_like':
            return self.randn_like
        try:
            return getattr(torch, item)
        except AttributeError:
            pass
        # Raised outside the except clause so no implicit exception chain is attached.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def randn_like(self, x):
        """Return noise shaped like ``x`` drawn from this object's generator."""
        return randn_without_seed(x, generator=self.generator, randn_source=self.randn_source)
def randn_without_seed(x, generator=None, randn_source="cpu"):
    """Draw standard-normal noise shaped like ``x`` from a previously seeded generator.

    Unlike ``torch.randn_like``, the draw comes from ``generator``, so repeated
    calls advance that generator's own stream. Seed the generator beforehand
    (e.g. with ``manual_seed()``).

    Args:
        x: Reference tensor. The result has ``x.size()`` and ``x.dtype`` and is
            placed on ``x.device`` (the torch path also copies ``x.layout``).
        generator: For ``randn_source == "nv"``, an object exposing
            ``randn(shape)`` that returns array-like data (this package uses
            ``rng_philox.Generator`` for that source). Otherwise a
            ``torch.Generator``; ``None`` falls back to torch's default
            generator and samples directly on ``x.device``.
        randn_source: ``"nv"`` selects the ``generator.randn`` path; any other
            value uses ``torch.randn``.

    Returns:
        A new tensor of standard-normal samples.

    Raises:
        ValueError: if ``randn_source == "nv"`` and no generator is given.
    """
    if randn_source == "nv":
        if generator is None:
            raise ValueError('randn_source "nv" requires a generator exposing randn(shape)')
        # Pass dtype explicitly: generator.randn() may return data whose dtype
        # differs from x (e.g. float32 data vs. fp16 latents). A randn_like
        # replacement must match x, just as the torch branch below does.
        return torch.asarray(generator.randn(x.size()), dtype=x.dtype, device=x.device)
    # A torch.Generator can only produce tensors on its own device, so sample
    # there and then move the result to x's device.
    sample_device = x.device if generator is None else generator.device
    return torch.randn(
        x.size(), dtype=x.dtype, layout=x.layout, device=sample_device, generator=generator
    ).to(device=x.device)
def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'):
    """Create the initial sampling noise for ``latent_image`` from ``seed``.

    NOTE(review): this presumably replaces ComfyUI's stock noise preparation
    (signature-compatible hook) — confirm where it is installed.

    Options lookup: the caller's stack frames are searched (via
    ``_find_outer_instance``) for a local ``model`` that is a ``ModelPatcher``
    carrying options under ``shared.Options.KEY``. Failing that, a local
    ``guider`` that is a ``comfy.samplers.CFGGuider`` is searched, and its
    ``model_patcher``'s options are used. If neither yields options,
    ``shared.opts_default`` is used. The frame search is bounded by a fixed
    number of frames counted from the helper itself, so adding intermediate
    call frames here would shrink how far up the stack it reaches.

    Side effects (besides returning noise):
      * ``torch.manual_seed(seed)`` seeds torch's global RNG.
      * ``comfy.k_diffusion.sampling.default_noise_sampler`` is replaced by a
        sampler drawing from a dedicated "eta" generator when options were
        found. Otherwise it is restored to the original saved on first call as
        ``default_noise_sampler_orig``.

    Args:
        latent_image: tensor whose ``size()``, ``dtype`` and ``layout`` the
            noise copies. With ``noise_inds``, dim 0 is treated as the batch
            dim: each draw has shape ``[1] + size()[1:]``.
        seed: integer seed for the noise generator(s).
        noise_inds: optional sequence of non-negative sample indices. Noise is
            drawn sequentially for indices ``0..max(noise_inds)`` from one
            generator. The draws for the requested indices are returned in the
            order of ``noise_inds`` (duplicates repeat the same draw). An empty
            sequence raises ``IndexError`` (``unique_inds[-1]``).
        device: fallback sampling device. It is only honoured when options
            were found and ``randn_source`` is neither ``"cpu"`` nor ``"gpu"``
            (e.g. ``"nv"``); otherwise it is overridden below.

    Returns:
        The noise tensor, on ``device_orig``: the torch device from
        ``comfy.model_management.get_torch_device()`` for ``"gpu"``, CPU for
        ``"cpu"``, and the resolved ``device`` otherwise.
    """
    opts = None
    opts_found = False
    model = _find_outer_instance('model', ModelPatcher)
    # Entered when no options came from a ModelPatcher named `model`. If no
    # such model exists, the short-circuit leaves opts as None and
    # `or opts is None` catches it.
    if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None:
        import comfy.samplers
        guider = _find_outer_instance('guider', comfy.samplers.CFGGuider)
        model = getattr(guider, 'model_patcher', None)
        # Only the walrus assignment to `opts` matters here; the body is intentionally empty.
        if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None:
            pass
    opts_found = opts is not None
    if not opts_found:
        opts = shared.opts_default
        device = torch.device("cpu")
    if opts.randn_source == 'gpu':
        import comfy.model_management
        device = comfy.model_management.get_torch_device()
    # device_orig: where the returned noise ends up; device: where it is sampled.
    device_orig = device
    device = torch.device("cpu") if opts.randn_source == "cpu" else device_orig
    def get_generator(seed):
        # Reads `device`/`opts` as they are at call time (after the resolution above).
        nonlocal device, opts
        if opts.randn_source == 'nv':
            generator = rng_philox.Generator(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        return generator
    def get_generator_obj(seed):
        """Return (main generator, eta generator); they are the same object unless eta_noise_seed_delta > 0."""
        nonlocal opts
        # The returned value is overwritten on the next line, so only the
        # global-RNG seeding side effect remains. NOTE(review): presumably kept
        # so code using torch's default RNG stays seeded — confirm before removing.
        generator = torch.manual_seed(seed)
        generator = generator_eta = get_generator(seed)
        if opts.eta_noise_seed_delta > 0:
            # Offset seed for the eta generator, clamped to the 64-bit maximum.
            seed = min(int(seed + opts.eta_noise_seed_delta), int(0xffffffffffffffff))
            generator_eta = get_generator(seed)
        return (generator, generator_eta)
    generator, generator_eta = get_generator_obj(seed)
    randn_source = opts.randn_source
    # ========== hijack randn_like ===============
    import comfy.k_diffusion.sampling
    # if not hasattr(comfy.k_diffusion.sampling, 'torch_orig'):
    # comfy.k_diffusion.sampling.torch_orig = comfy.k_diffusion.sampling.torch
    # comfy.k_diffusion.sampling.torch = TorchHijack(generator_eta, opts.randn_source)
    # Save the untouched default_noise_sampler once so it can be restored later.
    if not hasattr(comfy.k_diffusion.sampling, 'default_noise_sampler_orig'):
        comfy.k_diffusion.sampling.default_noise_sampler_orig = comfy.k_diffusion.sampling.default_noise_sampler
    if opts_found:
        th = TorchHijack(generator_eta, randn_source)
        def default_noise_sampler(x, seed=None, *args, **kwargs):
            # `seed` and any extra arguments are accepted for signature
            # compatibility and ignored; every sample comes from the eta generator.
            nonlocal th
            return lambda sigma, sigma_next: th.randn_like(x)
        # NOTE(review): `init` marker is not read in this block; presumably lets
        # other code detect the replacement — confirm.
        default_noise_sampler.init = True
        comfy.k_diffusion.sampling.default_noise_sampler = default_noise_sampler
    else:
        comfy.k_diffusion.sampling.default_noise_sampler = comfy.k_diffusion.sampling.default_noise_sampler_orig
    # =============================================
    if noise_inds is None:
        # Whole-batch noise in one draw.
        shape = latent_image.size()
        if opts.randn_source == 'nv':
            noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device)
        else:
            noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
        noise = noise.to(device=device_orig)
        return noise
    # Per-index noise: draw one sample per index 0..max in order so each index
    # gets the same noise it would get in a full sequential run, keeping only
    # the requested ones.
    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
    noises = []
    for i in range(unique_inds[-1]+1):
        shape = [1] + list(latent_image.size())[1:]
        if opts.randn_source == 'nv':
            noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device)
        else:
            noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
        noise = noise.to(device=device_orig)
        if i in unique_inds:
            noises.append(noise)
    # `noises` is ordered like the sorted unique_inds; `inverse` maps each
    # requested index back to its position there.
    noises = [noises[i] for i in inverse]
    noises = torch.cat(noises, axis=0)
    return noises
def _find_outer_instance(target:str, target_type=None, callback=None, max_len=10):
    """Search up the call stack for a local variable named ``target``.

    Frames are inspected innermost first, starting with this function's own
    frame, for at most ``max_len`` frames. For each frame whose locals contain
    ``target``:

    * if ``callback`` is given, ``callback(frame)`` is returned immediately;
    * otherwise the local's value is returned if it is an instance of
      ``target_type``, or unconditionally when ``target_type`` is None. On a
      type mismatch the search continues outward.

    Args:
        target: name of the local variable to look for.
        target_type: type (or tuple of types) the value must match; None
            accepts any value.
        callback: optional ``callable(frame)`` whose result is returned for
            the first frame containing ``target``.
        max_len: maximum number of frames to inspect.

    Returns:
        The matched value (or callback result), or None if nothing matched
        within ``max_len`` frames.
    """
    import inspect
    frame = inspect.currentframe()
    try:
        for _ in range(max_len):
            if frame is None:
                break
            if target in frame.f_locals:
                if callback is not None:
                    return callback(frame)
                found = frame.f_locals[target]
                # target_type=None (the default) means "accept any value";
                # isinstance(found, None) would raise TypeError.
                if target_type is None or isinstance(found, target_type):
                    return found
            frame = frame.f_back
        return None
    finally:
        # Release the frame reference promptly to avoid reference cycles
        # (recommended by the inspect module documentation).
        del frame