|
|
import math
|
|
|
import torch
|
|
|
import re
|
|
|
import gradio as gr
|
|
|
import numpy as np
|
|
|
import modules.scripts as scripts
|
|
|
import modules.images as saving
|
|
|
from modules import devices, processing, shared, sd_samplers_kdiffusion, sd_samplers_compvis, script_callbacks
|
|
|
from modules.processing import Processed
|
|
|
from modules.shared import opts, state
|
|
|
from ldm.models.diffusion import ddim
|
|
|
from PIL import Image
|
|
|
|
|
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, noise_like
|
|
|
|
|
|
# Prompt tag that overrides the CFG Rescale slider, e.g. "<cfg_rescale:0.7>".
# Group 1 captures the value text, which process() converts with float().
re_prompt_cfgr = re.compile(r"<cfg_rescale:([^>]+)>")
|
|
|
|
|
|
class Script(scripts.Script):
    """Always-visible "CFG Rescale" script with an optional "Auto Color Fix" pass.

    * CFG Rescale: while a generation runs, ``CFGDenoiser.combine_denoised`` is
      replaced by ``cfg_replace``, which blends the CFG result with a copy
      rescaled to the standard deviation of the conditional prediction.
    * Auto Color Fix: after generation, each RGB channel of the selected output
      images is linearly stretched to 0-255, ignoring sparse histogram tails.

    State shared between the XYZ-grid axis callback, ``process`` and
    ``cfg_replace`` is kept in two module globals:

    * ``cfg_rescale_fi`` - the rescale factor for the current generation.
    * ``enable_furry_cocks`` - True when ``process`` should take the factor from
      the UI slider / prompt tag; the XYZ-grid callback sets it to False so that
      the grid's value is used for the next generation instead.
    """

    def __init__(self):
        # Original implementations. ``old_denoising`` is restored in postprocess();
        # ``old_schedule`` and ``old_sample`` are not used by this class.
        self.old_denoising = sd_samplers_kdiffusion.CFGDenoiser.combine_denoised
        self.old_schedule = ddim.DDIMSampler.make_schedule
        self.old_sample = ddim.DDIMSampler.p_sample_ddim
        globals()['enable_furry_cocks'] = True

        def find_module(module_names):
            """Return the loaded module of the first script whose module name is listed.

            ``module_names`` is a list of names or a comma-separated string.
            Returns None when no matching script with a loaded module exists.
            """
            if isinstance(module_names, str):
                module_names = [s.strip() for s in module_names.split(",")]
            for data in scripts.scripts_data:
                if data.script_class.__module__ in module_names and hasattr(data, "module"):
                    return data.module
            return None

        def rescale_opt(p, x, xs):
            """XYZ-grid axis callback: use ``x`` as the rescale factor for the next generation."""
            globals()['cfg_rescale_fi'] = x
            globals()['enable_furry_cocks'] = False

        xyz_grid = find_module("xyz_grid.py, xy_grid.py")
        if xyz_grid:
            # This class may be instantiated more than once (e.g. for txt2img and for
            # img2img) while ``axis_options`` lives on the shared xyz_grid module, so
            # only add the axis if no option with this label is registered yet.
            axis_label = "Rescale CFG"
            if not any(getattr(option, "label", None) == axis_label for option in xyz_grid.axis_options):
                xyz_grid.axis_options.append(xyz_grid.AxisOption(axis_label, float, rescale_opt))

    def title(self):
        return "CFG Rescale Extension"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion and return ``[rescale, recolor, rec_strength, show_original]``.

        Those component values are what process(), postprocess_batch_list() and
        postprocess() receive as their trailing arguments.
        """
        with gr.Accordion("CFG Rescale", open=True, elem_id="cfg_rescale"):
            rescale = gr.Slider(label="CFG Rescale", show_label=False, minimum=0.0, maximum=1.0, step=0.01, value=0.0)
            with gr.Row():
                # gradio's Checkbox takes its initial state via ``value``; ``default``
                # is not one of its parameters.
                recolor = gr.Checkbox(label="Auto Color Fix", value=False)
                rec_strength = gr.Slider(label="Fix Strength", interactive=True, visible=False,
                                         elem_id=self.elem_id("rec_strength"), minimum=0.1, maximum=10.0, step=0.1,
                                         value=1.0)
                show_original = gr.Checkbox(label="Keep Original Images", elem_id=self.elem_id("show_original"),
                                            visible=False, value=False)

        def show_recolor_strength(rec_checked):
            # The strength slider and "keep originals" box only matter when the fix is on.
            return [gr.update(visible=rec_checked), gr.update(visible=rec_checked)]

        recolor.change(
            fn=show_recolor_strength,
            inputs=recolor,
            outputs=[rec_strength, show_original]
        )

        # NOTE(review): process() records "Auto Color Fix Strength", not "Auto Color Fix",
        # so pasting this script's own infotext may not restore the checkbox - confirm.
        self.infotext_fields = [
            (rescale, "CFG Rescale"),
            (recolor, "Auto Color Fix")
        ]
        self.paste_field_names = [field_name for _, field_name in self.infotext_fields]
        return [rescale, recolor, rec_strength, show_original]

    def cfg_replace(self, x_out, conds_list, uncond, cond_scale):
        """Drop-in replacement for ``CFGDenoiser.combine_denoised`` with CFG rescaling.

        Args:
            x_out: stacked denoiser outputs; the last ``uncond.shape[0]`` entries are
                the unconditional predictions, earlier entries are addressed by the
                indices in ``conds_list``.
            conds_list: for each batch item ``i``, a list of ``(cond_index, weight)``.
            uncond: unconditional conditioning; only its batch size is used here.
            cond_scale: the CFG scale.

        Returns:
            A new tensor shaped like the unconditional slice of ``x_out``.

        With the module-global factor ``fi`` (``cfg_rescale_fi``) equal to 0 this is
        plain CFG. Otherwise the CFG result ``x_cfg`` of each item is multiplied by
        ``fi * std(x_pos) / std(x_cfg) + (1 - fi)``. ``x_pos`` is the conditional
        prediction for a single cond; for several weighted conds it is their
        unguided combination ``uncond + sum(w * (x_k - uncond))``.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)
        fi = globals()['cfg_rescale_fi']

        for i, conds in enumerate(conds_list):
            # Plain CFG over every weighted cond first. The rescale is then applied
            # once to the combined result, so every cond of an AND prompt contributes.
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

            if fi == 0:
                continue

            if len(conds) == 1:
                positive = x_out[conds[0][0]]
            else:
                positive = denoised_uncond[i] + sum(
                    (x_out[cond_index] - denoised_uncond[i]) * weight for cond_index, weight in conds
                )
            rescale_factor = fi * (torch.std(positive) / torch.std(denoised[i])) + (1.0 - fi)
            denoised[i] = rescale_factor * denoised[i]

        return denoised

    def process(self, p, rescale, recolor, rec_strength, show_original):
        """Pick the rescale factor, install ``cfg_replace`` and record infotext params."""
        if globals()['enable_furry_cocks']:
            # Factor comes from the UI slider unless the prompt carries a
            # <cfg_rescale:VALUE> tag; matching tags are removed from the prompt.
            rescale_override = None

            def found(m):
                nonlocal rescale_override
                try:
                    rescale_override = float(m.group(1))
                except ValueError:
                    rescale_override = None
                return ""

            # Accept a list of prompts as well as a single prompt string.
            if isinstance(p.prompt, list):
                p.prompt = [re.sub(re_prompt_cfgr, found, prompt) for prompt in p.prompt]
            else:
                p.prompt = re.sub(re_prompt_cfgr, found, p.prompt)
            if rescale_override is not None:
                rescale = rescale_override

            globals()['cfg_rescale_fi'] = rescale
        else:
            # The XYZ-grid callback already stored the factor for this generation.
            rescale = globals()['cfg_rescale_fi']
        # Re-arm the UI/prompt path for the next generation.
        globals()['enable_furry_cocks'] = True

        sd_samplers_kdiffusion.CFGDenoiser.combine_denoised = self.cfg_replace

        if rescale > 0:
            p.extra_generation_params["CFG Rescale"] = rescale

        if recolor:
            p.extra_generation_params["Auto Color Fix Strength"] = rec_strength
            # postprocess() saves the samples itself, next to their color-fixed copy.
            p.do_not_save_samples = True

    def postprocess_batch_list(self, p, pp, rescale, recolor, rec_strength, show_original, batch_number):
        """Duplicate every image of the batch (with its prompt/seed data) to keep an unfixed copy."""
        if recolor and show_original:
            num = len(pp.images)
            for i in range(num):
                pp.images.append(pp.images[i])
                p.prompts.append(p.prompts[i])
                p.negative_prompts.append(p.negative_prompts[i])
                p.seeds.append(p.seeds[i])
                p.subseeds.append(p.subseeds[i])

    def postprocess(self, p, processed, rescale, recolor, rec_strength, show_original):
        """Restore the original denoiser and, if enabled, apply/save the auto color fix."""
        sd_samplers_kdiffusion.CFGDenoiser.combine_denoised = self.old_denoising

        def stretch_channel(channel, prec):
            """Linearly stretch one 8-bit channel so its populated range spans 0-255.

            Histogram bins holding no more than ``prec`` of all pixels are treated as
            outliers. If no bin passes that threshold, or the surviving range is a
            single value, the channel is returned unchanged instead of failing.
            """
            hist, _ = np.histogram(channel, 256, (0, 256))
            kept = np.where(hist > hist.sum() * prec)[0]
            if kept.size == 0:
                return channel
            lo = kept.min()
            hi = kept.max()
            if hi == lo:
                return channel
            return channel.point(lambda i: int(255 * (min(max(i, lo), hi) - lo) / (hi - lo)))

        def postfix(img, rec_strength):
            """Return a new RGB image with each channel of ``img`` auto-levelled."""
            prec = 0.0005 * rec_strength
            r, g, b = img.split()
            return Image.merge("RGB", (stretch_channel(r, prec), stretch_channel(g, prec), stretch_channel(b, prec)))

        if not recolor:
            return

        grab = 0
        n_img = len(processed.images)
        for i in range(n_img):
            if show_original:
                # postprocess_batch_list() appended a copy of each batch after its
                # originals, so each batch spans 2 * batch_size slots and only the
                # second half gets fixed. A returned grid is assumed to be index 0.
                check = i
                if opts.return_grid:
                    if i == 0:
                        continue
                    check -= 1
                doit = check % (p.batch_size * 2) >= p.batch_size
            else:
                # Fix every image except index 0 when it is presumably the grid.
                doit = (n_img > 1 and i != 0) or n_img == 1 or not opts.return_grid

            if not doit:
                continue

            original = processed.images[i]
            res_img = postfix(original, rec_strength)
            if opts.samples_save:
                # process() disabled the normal sample saving, so save both the
                # untouched image and the color-fixed one (tagged "colorfix") here.
                ind = grab
                grab += 1
                prompt_infotext = processing.create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds,
                                                             index=ind)
                saving.save_image(original, p.outpath_samples, "", seed=p.all_seeds[ind],
                                  prompt=p.all_prompts[ind],
                                  info=prompt_infotext, p=p)
                saving.save_image(res_img, p.outpath_samples, "", seed=p.all_seeds[ind],
                                  prompt=p.all_prompts[ind],
                                  info=prompt_infotext, p=p, suffix="colorfix")

            processed.images[i] = res_img
|
|
|
|
|
|
|
|
|
def on_infotext_pasted(infotext, params):
    """Normalize this extension's keys in pasted generation parameters (in place).

    * "CFG Rescale" defaults to 0 when the pasted text does not mention it.
    * The legacy key "CFG Rescale φ" is renamed to "CFG Rescale".
    * The legacy key "CFG Rescale phi" is renamed as well, but only when no
      "Neutral Prompt" script is loaded for txt2img (presumably that script
      owns the key then - confirm).
    * "DDIM Trailing" defaults to False.

    ``infotext`` is not used.
    """
    params.setdefault("CFG Rescale", 0)

    if "CFG Rescale φ" in params:
        params["CFG Rescale"] = params.pop("CFG Rescale φ")

    if "CFG Rescale phi" in params:
        neutral_prompt = scripts.scripts_txt2img.script("Neutral Prompt")
        if neutral_prompt is None:
            params["CFG Rescale"] = params.pop("CFG Rescale phi")

    params.setdefault("DDIM Trailing", False)
|
|
|
# Register on_infotext_pasted with webui's infotext-paste callback hook so the
# extension's keys are normalized when generation parameters are pasted.
script_callbacks.on_infotext_pasted(on_infotext_pasted)
|
|
|
|