dikdimon's picture
Upload extensions using SD-Hub extension
3dabe4a verified
from modules.sd_samplers_kdiffusion import KDiffusionSampler
from modules import script_callbacks, devices
from functools import wraps
from random import random
import torch
from .scaling import apply_scaling
class NoiseMethods:
@staticmethod
def get_delta(latent: torch.Tensor) -> torch.Tensor:
mean = torch.mean(latent)
return torch.sub(latent, mean)
@staticmethod
def to_abs(latent: torch.Tensor) -> torch.Tensor:
return torch.abs(latent)
@staticmethod
def zeros(latent: torch.Tensor) -> torch.Tensor:
return torch.zeros_like(latent)
@staticmethod
def ones(latent: torch.Tensor) -> torch.Tensor:
return torch.ones_like(latent)
@staticmethod
def gaussian_noise(latent: torch.Tensor) -> torch.Tensor:
return torch.rand_like(latent)
@staticmethod
def normal_noise(latent: torch.Tensor) -> torch.Tensor:
return torch.randn_like(latent)
@staticmethod
@torch.inference_mode()
def multires_noise(
latent: torch.Tensor, use_zero: bool, iterations: int = 8, discount: float = 0.4
):
"""
Credit: Kohya_SS
https://github.com/kohya-ss/sd-scripts/blob/v0.8.5/library/custom_train_functions.py#L448
"""
noise = NoiseMethods.zeros(latent) if use_zero else NoiseMethods.ones(latent)
batchSize, c, w, h = noise.shape
device = devices.get_optimal_device()
upsampler = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for b in range(batchSize):
for i in range(iterations):
r = random() * 2 + 2
wn = max(1, int(w / (r**i)))
hn = max(1, int(h / (r**i)))
noise[b] += (
upsampler(torch.randn(1, c, hn, wn).to(device)) * discount**i
)[0]
if wn == 1 or hn == 1:
break
return noise / noise.std()
def RGB_2_CbCr(r: float, g: float, b: float) -> float:
"""Convert RGB channels into YCbCr for SDXL"""
cb = -0.15 * r - 0.29 * g + 0.44 * b
cr = 0.44 * r - 0.37 * g - 0.07 * b
return cb, cr
original_callback = KDiffusionSampler.callback_state
@torch.inference_mode()
@wraps(original_callback)
def cc_callback(self, d):
if not self.vec_cc["enable"]:
return original_callback(self, d)
if getattr(self.p, "is_hr_pass", False) and not self.vec_cc["doHR"]:
return original_callback(self, d)
if getattr(self.p, "_ad_inner", False) and not self.vec_cc["doAD"]:
return original_callback(self, d)
is_xl: bool = self.p.sd_model.is_sdxl
mode = str(self.vec_cc["mode"])
method = str(self.vec_cc["method"])
source = d[mode]
if "Straight" in method:
target = d[mode].detach().clone()
elif "Cross" in method:
target = d["x" if mode == "denoised" else "denoised"].detach().clone()
elif "Multi-Res" in method:
target = NoiseMethods.multires_noise(d[mode], "Abs" in method)
elif method == "Ones":
target = NoiseMethods.ones(d[mode])
elif method == "N.Random":
target = NoiseMethods.normal_noise(d[mode])
elif method == "U.Random":
target = NoiseMethods.gaussian_noise(d[mode])
else:
raise ValueError
if "Abs" in method:
target = NoiseMethods.to_abs(target)
batchSize = int(d[mode].size(0))
bri, con, sat, r, g, b = apply_scaling(
self.vec_cc["scaling"],
d["i"],
self.vec_cc["step"],
self.vec_cc["bri"],
self.vec_cc["con"],
self.vec_cc["sat"],
self.vec_cc["r"],
self.vec_cc["g"],
self.vec_cc["b"],
)
if not is_xl:
for i in range(batchSize):
# Brightness
source[i][0] += target[i][0] * bri
# Contrast
source[i][0] += NoiseMethods.get_delta(source[i][0]) * con
# R
source[i][2] -= target[i][2] * r
# G
source[i][1] += target[i][1] * g
# B
source[i][3] -= target[i][3] * b
# Saturation
source[i][2] *= sat
source[i][1] *= sat
source[i][3] *= sat
else:
# But why...
cb, cr = RGB_2_CbCr(r, b, g)
for i in range(batchSize):
# Brightness
source[i][0] += target[i][0] * bri
# Contrast
source[i][0] += NoiseMethods.get_delta(source[i][0]) * con
# CbCr
source[i][1] -= target[i][1] * cr
source[i][2] += target[i][2] * cb
# Saturation
source[i][1] *= sat
source[i][2] *= sat
return original_callback(self, d)
KDiffusionSampler.callback_state = cc_callback
def restore_callback():
KDiffusionSampler.callback_state = original_callback
script_callbacks.on_script_unloaded(restore_callback)