|
|
from modules.sd_samplers_kdiffusion import KDiffusionSampler
|
|
|
from modules.script_callbacks import on_script_unloaded, on_ui_settings
|
|
|
from functools import wraps
|
|
|
from random import random
|
|
|
import torch
|
|
|
|
|
|
from .scaling import apply_scaling
|
|
|
from .settings import settings
|
|
|
|
|
|
|
|
|
class NoiseMethods:
    """Stateless helpers that derive or generate tensors shaped like a given latent."""

    @staticmethod
    def get_delta(latent: torch.Tensor) -> torch.Tensor:
        """Return ``latent`` with its global mean subtracted from every element."""
        return latent - latent.mean()

    @staticmethod
    def to_abs(latent: torch.Tensor) -> torch.Tensor:
        """Return the element-wise absolute value of ``latent``."""
        return latent.abs()

    @staticmethod
    def zeros(latent: torch.Tensor) -> torch.Tensor:
        """Return an all-zero tensor matching ``latent``."""
        return torch.zeros_like(latent)

    @staticmethod
    def ones(latent: torch.Tensor) -> torch.Tensor:
        """Return an all-one tensor matching ``latent``."""
        return torch.ones_like(latent)

    @staticmethod
    def gaussian_noise(latent: torch.Tensor) -> torch.Tensor:
        """Return random values matching ``latent``.

        NOTE: despite the name, this calls ``torch.rand_like``, i.e. the
        samples are uniform on [0, 1), not Gaussian. The name is kept
        because callers refer to it.
        """
        return torch.rand_like(latent)

    @staticmethod
    def normal_noise(latent: torch.Tensor) -> torch.Tensor:
        """Return standard-normal random values matching ``latent``."""
        return torch.randn_like(latent)

    @staticmethod
    @torch.inference_mode()
    def multires_noise(
        latent: torch.Tensor,
        use_zero: bool,
        iterations: int = 10,
        discount: float = 0.8,
    ):
        """Build multi-resolution noise with the shape of a 4-D ``latent``.

        Starting from zeros (``use_zero=True``) or ones, up to ``iterations``
        layers of normal noise are generated at progressively smaller
        resolutions, bilinearly resized back to the latent's last two
        dimensions, weighted by ``discount**level`` and accumulated. The sum
        is divided by its own standard deviation before being returned.

        Credit: Kohya_SS
        https://github.com/kohya-ss/sd-scripts/blob/v0.8.5/library/custom_train_functions.py#L448
        """
        if use_zero:
            accum = NoiseMethods.zeros(latent)
        else:
            accum = NoiseMethods.ones(latent)

        target_device = latent.device
        batch, channels, dim_a, dim_b = accum.shape
        resize = torch.nn.Upsample(size=(dim_a, dim_b), mode="bilinear")
        resize = resize.to(target_device)

        for level in range(iterations):
            # A fresh shrink ratio in [2, 4) is drawn for every level.
            ratio = random() * 2 + 2
            shrink = ratio**level

            small_a = max(1, int(dim_a / shrink))
            small_b = max(1, int(dim_b / shrink))

            layer = torch.randn(batch, channels, small_a, small_b).to(target_device)
            accum += resize(layer) * discount**level

            # Stop once either spatial dimension has collapsed to one cell.
            if 1 in (small_a, small_b):
                break

        return accum / accum.std()
|
|
|
|
|
|
|
|
|
def RGB2CbCr(r: float, g: float, b: float) -> tuple[float, float]:
    """Map red/green/blue offsets to a (Cb, Cr) chroma pair, used for SDXL.

    Returns:
        ``(cb, cr)`` where each value is a fixed weighted sum of the inputs.
        The weights look like the usual YCbCr chroma coefficients rounded to
        two decimals (TODO confirm the intended standard).
    """
    blue_diff = -0.17 * r - 0.33 * g + 0.5 * b
    red_diff = 0.5 * r - 0.42 * g - 0.08 * b
    return blue_diff, red_diff
|
|
|
|
|
|
|
|
|
# Keep a handle on the stock sampler callback: cc_callback delegates to it
# and restore_callback puts it back on KDiffusionSampler.
original_callback = KDiffusionSampler.callback_state
|
|
|
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
@wraps(original_callback)
def cc_callback(self, d):
    """Adjust the sampler's latent in place, then run the stock callback.

    Meant to replace ``KDiffusionSampler.callback_state``. The settings are
    read from ``self.vec_cc``; when the feature is disabled, or the current
    pass is a hires-fix / ADetailer pass that is not opted in, the call is
    forwarded to the original callback untouched.

    Args:
        self: The sampler instance. Reads ``self.vec_cc`` (settings dict) and
            ``self.p`` (processing object; ``is_hr_pass``, ``_ad_inner`` and
            ``sd_model.is_sdxl`` are consulted).
        d: The callback state dict. ``d[mode]`` is modified in place and
            ``d["i"]`` is passed to ``apply_scaling`` (presumably the current
            step index; verify against the sampler).

    Returns:
        Whatever ``original_callback(self, d)`` returns.

    Raises:
        ValueError: If ``self.vec_cc["method"]`` matches no known method.
    """
    opts = self.vec_cc

    if not opts["enable"]:
        return original_callback(self, d)

    if getattr(self.p, "is_hr_pass", False) and not opts["doHR"]:
        return original_callback(self, d)

    if getattr(self.p, "_ad_inner", False) and not opts["doAD"]:
        return original_callback(self, d)

    is_xl: bool = self.p.sd_model.is_sdxl

    mode = str(opts["mode"])
    method = str(opts["method"])
    source: torch.Tensor = d[mode]

    # Pick the tensor whose channels get scaled and mixed into ``source``.
    if "Straight" in method:
        target = source.detach().clone()
    elif "Cross" in method:
        target = d["x" if mode == "denoised" else "denoised"].detach().clone()
    elif "Multi-Res" in method:
        target = NoiseMethods.multires_noise(source, "Abs" in method)
    elif method == "Ones":
        target = NoiseMethods.ones(source)
    elif method == "N.Random":
        target = NoiseMethods.normal_noise(source)
    elif method == "U.Random":
        target = NoiseMethods.gaussian_noise(source)
    else:
        raise ValueError(f"Unsupported vec_cc noise method: {method!r}")

    if "Abs" in method:
        target = NoiseMethods.to_abs(target)

    bri, con, sat, r, g, b = apply_scaling(
        opts["scaling"],
        d["i"],
        opts["step"],
        opts["bri"],
        opts["con"],
        opts["sat"],
        opts["r"],
        opts["g"],
        opts["b"],
    )

    if is_xl:
        cb, cr = RGB2CbCr(r, g, b)

    # Samples are processed one at a time so that the contrast delta uses the
    # mean of each sample's own channel 0, not a batch-wide mean.
    for idx in range(int(source.size(0))):
        sample = source[idx]
        noise = target[idx]

        # Brightness and contrast act on channel 0 for both model families.
        sample[0] += noise[0] * bri
        sample[0] += NoiseMethods.get_delta(sample[0]) * con

        if is_xl:
            # SDXL: color offsets go through the Cb/Cr pair on channels 2/1.
            sample[1] -= noise[1] * cr
            sample[2] -= noise[2] * cb

            sample[1] *= sat
            sample[2] *= sat
        else:
            # Non-SDXL: r, g, b are applied to channels 2, 1, 3 respectively.
            sample[2] -= noise[2] * r
            sample[1] += noise[1] * g
            sample[3] -= noise[3] * b

            sample[2] *= sat
            sample[1] *= sat
            sample[3] *= sat

    return original_callback(self, d)
|
|
|
|
|
|
|
|
|
def restore_callback() -> None:
    """Undo the hook by putting the saved stock callback back on ``KDiffusionSampler``."""
    KDiffusionSampler.callback_state = original_callback
|
|
|
|
|
|
|
|
|
def hook_callbacks() -> None:
    """Install ``cc_callback`` on ``KDiffusionSampler`` and register the module's callbacks."""
    # Class-level replacement of the sampler callback with the wrapper.
    KDiffusionSampler.callback_state = cc_callback

    # Register the un-hook routine and the settings builder with the host's
    # script-callback registry.
    on_script_unloaded(restore_callback)
    on_ui_settings(settings)
|
|
|
|