# NOTE: web-viewer residue (page header, file size, commit hash, line-number
# gutter) removed from the top of this file so it parses as Python.
import torch
import numpy as np
from comfy.model_patcher import ModelPatcher
from . import shared, rng_philox
class TorchHijack:
    """Stand-in for the ``torch`` module inside k-diffusion samplers.

    k-diffusion accepts a random_sampler argument for most samplers, but not
    for all of them, so every use of ``torch.randn_like`` has to be replaced
    instead. Routing ``randn_like`` through our own generator makes images
    generated in batches match images generated individually.
    """

    def __init__(self, generator, randn_source, init=True):
        # generator: RNG driving the noise; randn_source: 'cpu'/'gpu'/'nv'.
        self.generator = generator
        self.randn_source = randn_source
        self.init = init

    def __getattr__(self, item):
        # Only invoked for names not found through normal lookup;
        # everything except 'randn_like' is forwarded to torch itself.
        if item == 'randn_like':
            return self.randn_like
        _missing = object()
        passthrough = getattr(torch, item, _missing)
        if passthrough is not _missing:
            return passthrough
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def randn_like(self, x):
        # Drop-in replacement for torch.randn_like backed by self.generator.
        return randn_without_seed(x, generator=self.generator, randn_source=self.randn_source)
def randn_without_seed(x, generator=None, randn_source="cpu"):
    """Sample a tensor shaped like ``x`` from a standard normal distribution,
    using the previously initialized generator.

    Use either randn() or manual_seed() to initialize the generator first.
    """
    shape = x.size()
    if randn_source == "nv":
        # NVIDIA-compatible Philox generator produces an array; wrap it.
        return torch.asarray(generator.randn(shape), device=x.device)
    noise = torch.randn(shape, dtype=x.dtype, layout=x.layout, device=generator.device, generator=generator)
    return noise.to(device=x.device)
def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'):
    """
    Creates random noise shaped like ``latent_image`` from ``seed``.

    noise_inds: optional batch indices; when given, noise is drawn
    sequentially up to the highest requested index so that each index always
    receives the same draws it would get if generated individually.
    NOTE(review): the ``device`` parameter is unconditionally overwritten
    below and has no effect — confirm whether any caller relies on it.
    As a side effect, this swaps (or restores) k-diffusion's
    default_noise_sampler; see the hijack section below.
    """
    opts = None
    opts_found = False
    # Pull per-model options from an outer stack frame holding a ModelPatcher
    # named 'model'; if absent, fall back to a CFGGuider's model_patcher.
    model = _find_outer_instance('model', ModelPatcher)
    if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None:
        import comfy.samplers
        guider = _find_outer_instance('guider', comfy.samplers.CFGGuider)
        model = getattr(guider, 'model_patcher', None)
        # The walrus assignment inside the condition is the point here;
        # the body is intentionally empty.
        if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None:
            pass
    opts_found = opts is not None
    if not opts_found:
        opts = shared.opts_default  # no per-model options found: use global defaults
    # Overwrites the incoming ``device`` argument (see docstring note).
    device = torch.device("cpu")
    if opts.randn_source == 'gpu':
        import comfy.model_management
        device = comfy.model_management.get_torch_device()
    device_orig = device  # device the finished noise tensor is moved to
    device = torch.device("cpu") if opts.randn_source == "cpu" else device_orig
    def get_generator(seed):
        # Build either an NVIDIA-compatible Philox generator ('nv') or a
        # seeded torch.Generator on the selected device.
        nonlocal device, opts
        if opts.randn_source == 'nv':
            generator = rng_philox.Generator(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        return generator
    def get_generator_obj(seed):
        # Returns (generator, generator_eta); they differ only when
        # eta_noise_seed_delta shifts the seed for the eta/ancestral noise.
        nonlocal opts
        generator = torch.manual_seed(seed)  # side effect: seeds torch's global RNG; the returned handle is discarded below
        generator = generator_eta = get_generator(seed)
        if opts.eta_noise_seed_delta > 0:
            # Clamp the shifted seed to the maximum unsigned 64-bit value.
            seed = min(int(seed + opts.eta_noise_seed_delta), int(0xffffffffffffffff))
            generator_eta = get_generator(seed)
        return (generator, generator_eta)
    generator, generator_eta = get_generator_obj(seed)
    randn_source = opts.randn_source
    # ========== hijack randn_like ===============
    import comfy.k_diffusion.sampling
    # if not hasattr(comfy.k_diffusion.sampling, 'torch_orig'):
    #     comfy.k_diffusion.sampling.torch_orig = comfy.k_diffusion.sampling.torch
    # comfy.k_diffusion.sampling.torch = TorchHijack(generator_eta, opts.randn_source)
    # Keep a handle to the untouched sampler so it can be restored later.
    if not hasattr(comfy.k_diffusion.sampling, 'default_noise_sampler_orig'):
        comfy.k_diffusion.sampling.default_noise_sampler_orig = comfy.k_diffusion.sampling.default_noise_sampler
    if opts_found:
        # Replace k-diffusion's default noise sampler with one driven by our
        # eta generator so batched and individual runs see the same noise.
        th = TorchHijack(generator_eta, randn_source)
        def default_noise_sampler(x, seed=None, *args, **kwargs):
            nonlocal th
            return lambda sigma, sigma_next: th.randn_like(x)
        default_noise_sampler.init = True
        comfy.k_diffusion.sampling.default_noise_sampler = default_noise_sampler
    else:
        # No custom options in effect: restore the stock sampler.
        comfy.k_diffusion.sampling.default_noise_sampler = comfy.k_diffusion.sampling.default_noise_sampler_orig
    # =============================================
    if noise_inds is None:
        # Single pass: one noise tensor covering the whole latent batch.
        shape = latent_image.size()
        if opts.randn_source == 'nv':
            noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device)
        else:
            noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
        noise = noise.to(device=device_orig)
        return noise
    # Batched path: draw one-sample slices sequentially up to the highest
    # requested index (so draws per index are reproducible), keeping only
    # the requested ones.
    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
    noises = []
    for i in range(unique_inds[-1]+1):
        shape = [1] + list(latent_image.size())[1:]
        if opts.randn_source == 'nv':
            noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device)
        else:
            noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
        noise = noise.to(device=device_orig)
        if i in unique_inds:
            noises.append(noise)
    noises = [noises[i] for i in inverse]  # map back to original order (duplicates included)
    noises = torch.cat(noises, axis=0)
    return noises
def _find_outer_instance(target:str, target_type=None, callback=None, max_len=10):
import inspect
frame = inspect.currentframe()
i = 0
while frame and i < max_len:
if target in frame.f_locals:
if callback is not None:
return callback(frame)
else:
found = frame.f_locals[target]
if isinstance(found, target_type):
return found
frame = frame.f_back
i += 1
return None