File size: 5,144 Bytes
f4a41d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import tqdm
import torch
import comfy.k_diffusion.sampling
from tqdm.auto import trange
from comfy.k_diffusion.sampling import to_d


# Module-level flag, initialised to False. Nothing in the code visible in this
# file reads or writes it -- presumably consumed by setup/registration code
# elsewhere (TODO confirm).
INITIALIZED = False


def default_noise_sampler(x):
    """Return a noise sampler that draws standard-normal noise shaped like ``x``.

    The returned callable takes ``(sigma, sigma_next)`` to match the sampler
    call sites in this module, but ignores both arguments: every call yields a
    fresh ``torch.randn_like(x)`` (same shape, dtype and device as ``x``).

    ``torch`` is referenced directly instead of through
    ``comfy.k_diffusion.sampling.torch`` so this helper does not depend on how
    another module happens to import torch.
    """
    return lambda sigma, sigma_next: torch.randn_like(x)


@torch.no_grad()
def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
    """Sampler loop labelled "TCD", implemented as a DDPM-style ancestral step.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``; its
            return value is used as the denoised prediction.
        x: Starting latent tensor. It is NOT modified in place; a new tensor is
            returned.
        sigmas: Indexable sequence of noise levels; step ``i`` goes from
            ``sigmas[i]`` to ``sigmas[i + 1]``. The arithmetic below calls
            tensor methods on the values, so they are assumed to be 0-dim
            tensors (TODO confirm against callers).
        extra_args: Optional dict of keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step, right after
            the model call.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        noise_sampler: Optional callable ``(sigma_from, sigma_to) -> noise``.
            Defaults to ``default_noise_sampler(x)``.
        gamma: Noise is injected only when ``gamma > 0``. NOTE(review): the
            gamma-derived intermediate timestep of the reference TCD scheduler
            (diffusers ``scheduling_tcd.py``) is not applied to the update
            here; the previously computed intermediate sigma was unused and
            has been removed.

    Returns:
        The latent after the last step. If ``sigmas`` has fewer than two
        entries the loop does not run and ``x`` is returned as given.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    final_index = len(sigmas) - 2

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_from, sigma_to = sigmas[i], sigmas[i + 1]

        denoised = model(x, sigma_from * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_from, 'sigma_hat': sigma_from, 'denoised': denoised})

        # Epsilon (noise) estimate implied by the denoised prediction.
        noise_est = (x - denoised) / sigma_from

        # Rescale by 1 / sqrt(1 + sigma^2). Done out of place so neither the
        # caller's tensor nor the ``x`` handed to ``callback`` is altered.
        x_scaled = x / torch.sqrt(1.0 + sigma_from ** 2.0)

        alpha_cumprod = 1 / ((sigma_from * sigma_from) + 1)    # at sigma_from
        alpha_cumprod_prev = 1 / ((sigma_to * sigma_to) + 1)   # at sigma_to
        alpha = alpha_cumprod / alpha_cumprod_prev

        # DDPM posterior-mean update using the epsilon prediction.
        x = (1.0 / alpha).sqrt() * (x_scaled - (1 - alpha) * noise_est / (1 - alpha_cumprod).sqrt())

        # Target sigma of zero: no noise and no rescale (the factor would be 1).
        if sigma_to == 0:
            continue

        # Inject noise on every step but the last one, and only if gamma > 0.
        if gamma > 0 and i != final_index:
            noise = noise_sampler(sigma_from, sigma_to)
            variance = ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)) * (1 - alpha_cumprod / alpha_cumprod_prev)
            x = x + variance.sqrt() * noise  # scale noise by its standard deviation

        # Undo the earlier rescale, now at the target sigma.
        x = x * torch.sqrt(1.0 + sigma_to ** 2.0)

    return x


@torch.no_grad()
def sample_tcd_euler_a(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, gamma=0.3):
    """TCD sampling using a modified Euler Ancestral sampler. by @laksjdjf

    Each step takes an Euler step from ``sigmas[i]`` down to an intermediate
    ``sigma_down`` derived from ``gamma``, then re-noises up to ``sigmas[i + 1]``.

    Args:
        model: Callable invoked as ``model(x, sigma * s_in, **extra_args)``
            returning the denoised prediction. It must also expose
            ``model.inner_model.inner_model.model_sampling`` with ``timestep``,
            ``sigma`` and ``noise_scaling`` methods.
        x: Starting latent tensor. It is NOT modified in place.
        sigmas: Indexable sequence of noise levels; step ``i`` goes from
            ``sigmas[i]`` to ``sigmas[i + 1]``.
        extra_args: Optional dict of keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        noise_sampler: Optional callable ``(sigma_from, sigma_to) -> noise``.
            Defaults to ``default_noise_sampler(x)``.
        gamma: Scales the timestep used for the intermediate sigma:
            ``sigma_down = sigma((1 - gamma) * timestep(sigma_from))``, capped
            at ``sigma_to``. Noise is added only when ``gamma > 0``.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_from = sigmas[i]
        sigma_to = sigmas[i + 1]

        denoised = model(x, sigma_from * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_from, 'sigma_hat': sigma_from, 'denoised': denoised})

        model_sampling = model.inner_model.inner_model.model_sampling

        # Intermediate noise level: shrink the timestep by (1 - gamma) and map
        # it back to a sigma, never going above the scheduled target sigma.
        t = model_sampling.timestep(sigma_from)
        sigma_down = model_sampling.sigma((1 - gamma) * t)
        if sigma_down > sigma_to:
            sigma_down = sigma_to

        # Amount of noise needed to get from sigma_down back up to sigma_to.
        sigma_up = (sigma_to ** 2 - sigma_down ** 2) ** 0.5

        # Same as Euler ancestral: step along d from sigma_from to sigma_down.
        # Out of place so the caller's tensor and the ``x`` given to
        # ``callback`` are left untouched.
        d = to_d(x, sigma_from, denoised)
        x = x + d * (sigma_down - sigma_from)

        if sigma_to > 0 and gamma > 0:
            x = model_sampling.noise_scaling(sigma_up, noise_sampler(sigma_from, sigma_to), x)
    return x