File size: 6,096 Bytes
9ab8b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import numpy as np
from comfy.model_patcher import ModelPatcher
from . import shared, rng_philox

class TorchHijack:
    """Proxy for the ``torch`` module that substitutes its own ``randn_like``.

    k-diffusion accepts a random sampler argument on most samplers, but not on
    all of them, so this proxy is needed to replace every use of
    ``torch.randn_like``.

    The replacement exists so that images generated in a batch come out the
    same as images generated individually.
    """

    def __init__(self, generator, randn_source, init=True):
        # Plain attribute storage; nothing is validated here.
        self.generator, self.randn_source, self.init = generator, randn_source, init

    def __getattr__(self, item):
        # Python calls __getattr__ only after normal lookup has failed, so
        # everything not defined on this object is forwarded to ``torch``.
        if item == 'randn_like':
            return self.randn_like

        if not hasattr(torch, item):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

        return getattr(torch, item)

    def randn_like(self, x):
        """Return noise shaped like ``x`` drawn from this object's generator."""
        rng = self.generator
        source = self.randn_source
        return randn_without_seed(x, generator=rng, randn_source=source)

def randn_without_seed(x, generator=None, randn_source="cpu"):
    """Generate normally distributed noise shaped like ``x`` using a previously initialized generator.

    The generator must already be seeded by the caller; this function never
    reseeds it, so successive calls keep advancing the same random stream.

    Args:
        x: Tensor whose size, dtype and device the result matches.
        generator: When ``randn_source`` is ``"nv"``, an object whose
            ``randn(shape)`` returns array data that ``torch.asarray`` accepts.
            Otherwise a ``torch.Generator``. Despite the ``None`` default, a
            generator has to be supplied: both branches dereference it.
        randn_source: ``"nv"`` selects the custom generator branch; any other
            value uses ``torch.randn`` on the generator's own device.

    Returns:
        A tensor with the size and dtype of ``x``, located on ``x.device``.
    """
    if randn_source == "nv":
        # Pass dtype explicitly: without it the result keeps the dtype of the
        # generator's array, which differs from x for non-default precisions
        # and breaks the "drop-in for randn_like" contract.
        return torch.asarray(generator.randn(x.size()), dtype=x.dtype, device=x.device)

    # Sample on the generator's device (torch requires generator and output
    # device to agree), then move the result next to x.
    noise = torch.randn(
        x.size(),
        dtype=x.dtype,
        layout=x.layout,
        device=generator.device,
        generator=generator,
    )
    return noise.to(device=x.device)

def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'):
    """
    Create random noise for a latent image from a seed.

    Options are read from ``model_options[shared.Options.KEY]`` of a model
    found by inspecting the callers' local variables: first a local named
    ``model`` that is a ``ModelPatcher``, then the ``model_patcher`` attribute
    of a local named ``guider`` that is a ``comfy.samplers.CFGGuider``. When
    neither yields options, ``shared.opts_default`` is used and ``device`` is
    forced to CPU.

    Side effects:
        * ``torch.manual_seed(seed)`` is called, seeding torch's global RNG.
        * ``comfy.k_diffusion.sampling.default_noise_sampler`` is replaced by a
          sampler that draws from this call's generator when options were
          found, and restored to the saved original otherwise.

    Args:
        latent_image: Tensor whose size, dtype and layout the noise matches.
        seed: Seed for the noise generator(s).
        noise_inds: Optional collection of noise indices. When given, one
            batch-size-1 sample is drawn for every index from 0 up to the
            largest requested index, and the output is assembled in the order
            of ``noise_inds`` so that equal indices receive the same noise.
        device: Device the returned noise is placed on. It is overridden with
            CPU when no options are found and with
            ``comfy.model_management.get_torch_device()`` when
            ``opts.randn_source`` is ``'gpu'``.

    Returns:
        The noise tensor: shaped like ``latent_image`` when ``noise_inds`` is
        None, otherwise the per-index samples concatenated along axis 0.
    """
    opts = None
    opts_found = False
    # First attempt: a ``model`` local somewhere up the call stack.
    model = _find_outer_instance('model', ModelPatcher)
    # The walrus assigns ``opts`` as a side effect of evaluating the condition.
    # The body runs when no model was found or the model carries no options,
    # and then switches ``model`` to the guider's patcher (or None).
    if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None:
        import comfy.samplers
        guider = _find_outer_instance('guider', comfy.samplers.CFGGuider)
        model = getattr(guider, 'model_patcher', None)
    # Evaluated purely for the walrus side effect: (re)reads the options from
    # whichever ``model`` is current after the fallback above.
    if (model is not None and (opts:=model.model_options.get(shared.Options.KEY)) is None) or opts is None:
        pass
    opts_found = opts is not None
    if not opts_found:
        # No options anywhere on the stack: use the defaults and stay on CPU.
        opts = shared.opts_default
        device = torch.device("cpu")

    if opts.randn_source == 'gpu':
        import comfy.model_management
        device = comfy.model_management.get_torch_device()

    # ``device_orig`` is where the finished noise ends up; ``device`` is where
    # it is generated (CPU for the 'cpu' source, otherwise the same device).
    device_orig = device
    device = torch.device("cpu") if opts.randn_source == "cpu" else device_orig

    # Build one RNG for ``seed``: an rng_philox.Generator for the 'nv' source,
    # otherwise a torch.Generator on the generation device.
    def get_generator(seed):
        # ``device`` and ``opts`` are only read here, never rebound.
        nonlocal device, opts
        if opts.randn_source == 'nv':
            generator = rng_philox.Generator(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        return generator

    # Return ``(generator, generator_eta)``. Both names refer to the same
    # object unless ``opts.eta_noise_seed_delta`` is positive.
    def get_generator_obj(seed):
        nonlocal opts
        # Called for its side effect of seeding torch's global RNG; the
        # returned generator is overwritten on the next line.
        generator = torch.manual_seed(seed)
        generator = generator_eta = get_generator(seed)
        if opts.eta_noise_seed_delta > 0:
            # Offset the seed for the eta generator, capped at 2**64 - 1.
            seed = min(int(seed + opts.eta_noise_seed_delta), int(0xffffffffffffffff))
            generator_eta = get_generator(seed)
        return (generator, generator_eta)

    generator, generator_eta = get_generator_obj(seed)
    randn_source = opts.randn_source

    # ========== hijack randn_like ===============
    import comfy.k_diffusion.sampling
    # Earlier approach (disabled): swap the whole ``torch`` name in the module.
    # if not hasattr(comfy.k_diffusion.sampling, 'torch_orig'):
    #     comfy.k_diffusion.sampling.torch_orig = comfy.k_diffusion.sampling.torch
    # comfy.k_diffusion.sampling.torch = TorchHijack(generator_eta, opts.randn_source)

    # Save the original sampler factory once so it can be restored later.
    if not hasattr(comfy.k_diffusion.sampling, 'default_noise_sampler_orig'):
        comfy.k_diffusion.sampling.default_noise_sampler_orig = comfy.k_diffusion.sampling.default_noise_sampler
    if opts_found:
        th = TorchHijack(generator_eta, randn_source)
        # Replacement factory: ignores ``seed`` and any extra arguments and
        # returns a sampler that draws noise shaped like ``x`` from the eta
        # generator.
        def default_noise_sampler(x, seed=None, *args, **kwargs):
            nonlocal th
            return lambda sigma, sigma_next: th.randn_like(x)
        # Marker attribute; it is not read anywhere in this function.
        default_noise_sampler.init = True
        comfy.k_diffusion.sampling.default_noise_sampler = default_noise_sampler
    else:
        # Without options, put the original sampler factory back.
        comfy.k_diffusion.sampling.default_noise_sampler = comfy.k_diffusion.sampling.default_noise_sampler_orig
    # =============================================

    if noise_inds is None:
        # Simple case: one draw covering the whole latent.
        shape = latent_image.size()
        if opts.randn_source == 'nv':
            noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device)
        else:
            noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
        noise = noise.to(device=device_orig)
        return noise

    # ``unique_inds`` is sorted, so its last element is the largest index;
    # ``inverse`` maps each entry of ``noise_inds`` to its slot in ``unique_inds``.
    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
    noises = []
    # Draw for every index up to the largest one, including indices that were
    # not requested, so the generator is advanced once per index; only the
    # requested draws are kept.
    for i in range(unique_inds[-1]+1):
        shape = [1] + list(latent_image.size())[1:]
        if opts.randn_source == 'nv':
            noise = torch.asarray(generator.randn(shape), dtype=latent_image.dtype, device=device)
        else:
            noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, device=device, generator=generator)
        noise = noise.to(device=device_orig)
        if i in unique_inds:
            noises.append(noise)
    # Reorder (and duplicate where indices repeat) to match ``noise_inds``.
    noises = [noises[i] for i in inverse]
    noises = torch.cat(noises, axis=0)
    return noises

def _find_outer_instance(target:str, target_type=None, callback=None, max_len=10):
    """Search the call stack for a local variable named ``target``.

    The walk starts at this function's own frame and follows ``f_back`` links
    outwards, inspecting at most ``max_len`` frames.

    Args:
        target: Name of the local variable to look for.
        target_type: Type (or tuple of types) the value must be an instance
            of. A name match with the wrong type is skipped and the search
            continues further out. ``None`` accepts a value of any type.
        callback: Optional callable. When given, it is invoked with the first
            frame whose locals contain ``target`` and its result is returned;
            ``target_type`` is not consulted in that case.
        max_len: Maximum number of frames to inspect, counting this one.

    Returns:
        The matching value (or the callback's result), or ``None`` when nothing
        suitable is found within ``max_len`` frames.
    """
    import inspect
    frame = inspect.currentframe()
    try:
        i = 0
        while frame is not None and i < max_len:
            if target in frame.f_locals:
                if callback is not None:
                    return callback(frame)
                found = frame.f_locals[target]
                # isinstance() raises TypeError when its second argument is
                # None, so an omitted type is treated as "match anything".
                if target_type is None or isinstance(found, target_type):
                    return found
            frame = frame.f_back
            i += 1
        return None
    finally:
        # Drop the frame reference explicitly so no reference cycle keeps the
        # callers' locals alive longer than necessary.
        del frame