File size: 4,925 Bytes
3dabe4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from modules.sd_samplers_kdiffusion import KDiffusionSampler
from modules import script_callbacks, devices
from functools import wraps
from random import random
import torch

from .scaling import apply_scaling


class NoiseMethods:
    """Stateless helpers that build or transform tensors shaped like a latent."""

    @staticmethod
    def get_delta(latent: torch.Tensor) -> torch.Tensor:
        """Return ``latent`` with its overall (scalar) mean subtracted."""
        mean = torch.mean(latent)
        return torch.sub(latent, mean)

    @staticmethod
    def to_abs(latent: torch.Tensor) -> torch.Tensor:
        """Return the element-wise absolute value of ``latent``."""
        return torch.abs(latent)

    @staticmethod
    def zeros(latent: torch.Tensor) -> torch.Tensor:
        """Return a zero-filled tensor with the shape/dtype/device of ``latent``."""
        return torch.zeros_like(latent)

    @staticmethod
    def ones(latent: torch.Tensor) -> torch.Tensor:
        """Return a one-filled tensor with the shape/dtype/device of ``latent``."""
        return torch.ones_like(latent)

    @staticmethod
    def gaussian_noise(latent: torch.Tensor) -> torch.Tensor:
        """Return noise shaped like ``latent``, sampled with ``torch.rand_like``.

        NOTE: despite the name, ``torch.rand_like`` draws from a *uniform*
        distribution on [0, 1), not a Gaussian. The name is kept because
        callers refer to it.
        """
        return torch.rand_like(latent)

    @staticmethod
    def normal_noise(latent: torch.Tensor) -> torch.Tensor:
        """Return standard-normal noise shaped like ``latent``."""
        return torch.randn_like(latent)

    @staticmethod
    @torch.inference_mode()
    def multires_noise(
        latent: torch.Tensor, use_zero: bool, iterations: int = 8, discount: float = 0.4
    ) -> torch.Tensor:
        """
        Build multi-resolution noise shaped like a 4-D ``latent``.

        Starting from zeros (``use_zero=True``) or ones, each batch entry
        accumulates up to ``iterations`` layers of random normal noise. Layer
        ``i`` is generated at a resolution reduced by ``r**i`` (``r`` drawn
        uniformly from [2, 4) per layer), bilinearly upsampled back to the
        latent's spatial size and weighted by ``discount**i``. Accumulation
        for a batch entry stops early once a layer has shrunk to size 1 along
        either spatial axis. The result is divided by its standard deviation.

        Credit: Kohya_SS
        https://github.com/kohya-ss/sd-scripts/blob/v0.8.5/library/custom_train_functions.py#L448
        """

        noise = NoiseMethods.zeros(latent) if use_zero else NoiseMethods.ones(latent)
        batch_size, channels, dim2, dim3 = noise.shape

        # Generate on the latent's own device so the in-place add below can
        # never hit a cross-device mismatch.
        device = noise.device
        upsampler = torch.nn.Upsample(size=(dim2, dim3), mode="bilinear")

        for b in range(batch_size):
            for i in range(iterations):
                r = random() * 2 + 2
                shrink = r**i

                small2 = max(1, int(dim2 / shrink))
                small3 = max(1, int(dim3 / shrink))

                # Keep the low-res layer's axis order aligned with the
                # upsampler's target size so non-square latents are not
                # stretched from a transposed aspect ratio.
                layer = torch.randn(1, channels, small2, small3, device=device)
                noise[b] += (upsampler(layer) * discount**i)[0]

                if small2 == 1 or small3 == 1:
                    break

        return noise / noise.std()


def RGB_2_CbCr(r: float, g: float, b: float) -> tuple[float, float]:
    """Convert RGB channel values into a ``(cb, cr)`` chroma pair for SDXL.

    Each output is a fixed linear combination of the three inputs; only the
    two chroma components are produced (no luma term is computed here).

    Returns:
        A 2-tuple ``(cb, cr)``.
    """
    cb = -0.15 * r - 0.29 * g + 0.44 * b
    cr = 0.44 * r - 0.37 * g - 0.07 * b

    return cb, cr


# Keep a reference to the unpatched method: cc_callback delegates to it and
# restore_callback reinstalls it.
original_callback = KDiffusionSampler.callback_state


@torch.inference_mode()
@wraps(original_callback)
def cc_callback(self, d):
    """Adjust the sampler's latent in place, then run the original callback.

    Settings are read from ``self.vec_cc``. The callback is a pass-through to
    ``original_callback`` when the settings are missing or disabled, during a
    hires pass unless ``doHR`` is set, and during an ``_ad_inner`` pass unless
    ``doAD`` is set.

    Args:
        d: The sampler's per-step state. This function reads ``d["i"]`` and
            ``d[mode]`` (plus the opposite of ``"x"``/``"denoised"`` for the
            "Cross" method) and modifies ``d[mode]`` in place.
            NOTE(review): presumably ``mode`` is always ``"x"`` or
            ``"denoised"`` - confirm against the code that fills ``vec_cc``.

    Returns:
        Whatever ``original_callback(self, d)`` returns.

    Raises:
        ValueError: If ``vec_cc["method"]`` is not a recognised noise method.
    """
    # Samplers that never had settings attached must keep working untouched
    # instead of failing with AttributeError in the middle of sampling.
    settings = getattr(self, "vec_cc", None)
    if not settings or not settings["enable"]:
        return original_callback(self, d)

    if getattr(self.p, "is_hr_pass", False) and not settings["doHR"]:
        return original_callback(self, d)

    if getattr(self.p, "_ad_inner", False) and not settings["doAD"]:
        return original_callback(self, d)

    is_xl: bool = self.p.sd_model.is_sdxl

    mode = str(settings["mode"])
    method = str(settings["method"])
    source = d[mode]

    # Pick the tensor whose values drive the offsets added to `source`.
    if "Straight" in method:
        # Cloned so the in-place edits to `source` below do not feed back.
        target = source.detach().clone()
    elif "Cross" in method:
        target = d["x" if mode == "denoised" else "denoised"].detach().clone()
    elif "Multi-Res" in method:
        target = NoiseMethods.multires_noise(source, "Abs" in method)
    elif method == "Ones":
        target = NoiseMethods.ones(source)
    elif method == "N.Random":
        target = NoiseMethods.normal_noise(source)
    elif method == "U.Random":
        target = NoiseMethods.gaussian_noise(source)
    else:
        raise ValueError(f"Unknown noise method: {method!r}")

    if "Abs" in method:
        target = NoiseMethods.to_abs(target)

    batch_size = int(source.size(0))

    bri, con, sat, r, g, b = apply_scaling(
        settings["scaling"],
        d["i"],
        settings["step"],
        settings["bri"],
        settings["con"],
        settings["sat"],
        settings["r"],
        settings["g"],
        settings["b"],
    )

    if is_xl:
        # NOTE(review): the arguments are passed as (r, b, g), not (r, g, b).
        # The original code marks this with "But why..."; the order is kept
        # deliberately, since changing it would alter the colours produced.
        cb, cr = RGB_2_CbCr(r, b, g)

    for i in range(batch_size):
        src = source[i]
        tgt = target[i]

        # Brightness
        src[0] += tgt[0] * bri
        # Contrast
        src[0] += NoiseMethods.get_delta(src[0]) * con

        if is_xl:
            # CbCr
            src[1] -= tgt[1] * cr
            src[2] += tgt[2] * cb

            # Saturation
            src[1] *= sat
            src[2] *= sat
        else:
            # R
            src[2] -= tgt[2] * r
            # G
            src[1] += tgt[1] * g
            # B
            src[3] -= tgt[3] * b

            # Saturation
            src[2] *= sat
            src[1] *= sat
            src[3] *= sat

    return original_callback(self, d)


# Monkey-patch the sampler class so its callback_state is cc_callback.
KDiffusionSampler.callback_state = cc_callback


def restore_callback():
    """Put the unpatched ``callback_state`` back on ``KDiffusionSampler``."""
    setattr(KDiffusionSampler, "callback_state", original_callback)


# Register restore_callback with the script-unloaded hook so the monkey-patch
# is undone (presumably invoked when the webui unloads this script).
script_callbacks.on_script_unloaded(restore_callback)