File size: 9,518 Bytes
c336648 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import math
import torch
import re
import gradio as gr
import numpy as np
import modules.scripts as scripts
import modules.images as saving
from modules import devices, processing, shared, sd_samplers_kdiffusion, sd_samplers_compvis, script_callbacks
from modules.processing import Processed
from modules.shared import opts, state
from ldm.models.diffusion import ddim
from PIL import Image
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, noise_like
# Matches an inline "<cfg_rescale:VALUE>" tag in a prompt; group 1 captures
# everything between the colon and the closing ">".
re_prompt_cfgr = re.compile(r"<cfg_rescale:([^>]+)>")
class Script(scripts.Script):
    """CFG Rescale script with an optional per-channel colour-range fix.

    While a job runs, ``CFGDenoiser.combine_denoised`` is replaced by
    :meth:`cfg_replace`; :meth:`postprocess` puts the saved original back.
    The active rescale factor (``cfg_rescale_fi``) and the flag telling
    whether that factor was supplied by an XYZ-grid axis are stored in this
    module's globals, shared with the XYZ-grid axis callback.
    """

    def __init__(self):
        """Remember the stock sampler hooks and register the XYZ-grid axis."""
        # Captured once so postprocess() can undo the monkey-patch.
        self.old_denoising = sd_samplers_kdiffusion.CFGDenoiser.combine_denoised
        self.old_schedule = ddim.DDIMSampler.make_schedule
        self.old_sample = ddim.DDIMSampler.p_sample_ddim
        # True:  process() takes the factor from the slider / prompt tag.
        # False: an XYZ-grid axis already stored the factor for the next run.
        globals()['enable_furry_cocks'] = True

        def find_module(module_names):
            """Return the first loaded script module whose name is listed, else None."""
            if isinstance(module_names, str):
                module_names = [s.strip() for s in module_names.split(",")]
            for data in scripts.scripts_data:
                if data.script_class.__module__ in module_names and hasattr(data, "module"):
                    return data.module
            return None

        def rescale_opt(p, x, xs):
            """XYZ-grid axis callback: store the factor and mark it as grid-supplied."""
            globals()['cfg_rescale_fi'] = x
            globals()['enable_furry_cocks'] = False

        xyz_grid = find_module("xyz_grid.py, xy_grid.py")
        if xyz_grid:
            axis_label = "Rescale CFG"
            # Every Script instance runs this code; register the axis only once
            # per xyz_grid module so the option is not listed several times.
            already_registered = any(
                getattr(option, "label", None) == axis_label
                for option in xyz_grid.axis_options
            )
            if not already_registered:
                xyz_grid.axis_options.append(xyz_grid.AxisOption(axis_label, float, rescale_opt))

    def title(self):
        """Return the script's display title."""
        return "CFG Rescale Extension"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of the tab."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion UI and return its components.

        Returns:
            ``[rescale, recolor, rec_strength, show_original]`` - the same
            order in which the values are passed to process()/postprocess().
        """
        with gr.Accordion("CFG Rescale", open=True, elem_id="cfg_rescale"):
            rescale = gr.Slider(label="CFG Rescale", show_label=False, minimum=0.0, maximum=1.0, step=0.01,
                                value=0.0)
            with gr.Row():
                # `value=` is the Checkbox keyword for the initial state.
                recolor = gr.Checkbox(label="Auto Color Fix", value=False)
                rec_strength = gr.Slider(label="Fix Strength", interactive=True, visible=False,
                                         elem_id=self.elem_id("rec_strength"), minimum=0.1, maximum=10.0,
                                         step=0.1, value=1.0)
                show_original = gr.Checkbox(label="Keep Original Images",
                                            elem_id=self.elem_id("show_original"), visible=False,
                                            value=False)

            def show_recolor_strength(rec_checked):
                # Both dependent controls follow the "Auto Color Fix" checkbox.
                return [gr.update(visible=rec_checked), gr.update(visible=rec_checked)]

            recolor.change(
                fn=show_recolor_strength,
                inputs=recolor,
                outputs=[rec_strength, show_original]
            )

        self.infotext_fields = [
            (rescale, "CFG Rescale"),
            (recolor, "Auto Color Fix")
        ]
        self.paste_field_names = [field_name for _, field_name in self.infotext_fields]
        return [rescale, recolor, rec_strength, show_original]

    def cfg_replace(self, x_out, conds_list, uncond, cond_scale):
        """Combine conditional and unconditional predictions, optionally rescaled.

        Installed by process() in place of ``CFGDenoiser.combine_denoised``.
        The last ``uncond.shape[0]`` entries of ``x_out`` are taken as the
        unconditional predictions. With a factor ``fi`` of 0 the weighted
        guidance terms are accumulated onto them. Otherwise the guided result
        ``xcfg`` is multiplied by ``fi * std(cond) / std(xcfg) + (1 - fi)``.

        Returns:
            Tensor shaped like the unconditional slice of ``x_out``.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)
        fi = globals()['cfg_rescale_fi']

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                guidance = (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
                if fi == 0:
                    denoised[i] += guidance
                else:
                    # NOTE(review): this branch assigns rather than accumulates,
                    # so with several conditions per image only the last one
                    # survives - confirm whether that is intended.
                    xcfg = denoised_uncond[i] + guidance
                    xrescaled = torch.std(x_out[cond_index]) / torch.std(xcfg)
                    xfinal = fi * xrescaled + (1.0 - fi)
                    denoised[i] = xfinal * xcfg

        return denoised

    def process(self, p, rescale, recolor, rec_strength, show_original):
        """Resolve the rescale factor, install cfg_replace and record infotext."""
        if globals()['enable_furry_cocks']:
            # parse <cfg_rescale:[number]> from prompt for override
            rescale_override = None

            def found(m):
                nonlocal rescale_override
                try:
                    rescale_override = float(m.group(1))
                except ValueError:
                    rescale_override = None
                return ""

            # p.prompt may be a list of prompts; strip the tag from each entry.
            if isinstance(p.prompt, list):
                p.prompt = [re_prompt_cfgr.sub(found, text) for text in p.prompt]
            else:
                p.prompt = re_prompt_cfgr.sub(found, p.prompt)

            if rescale_override is not None:
                rescale = rescale_override
            globals()['cfg_rescale_fi'] = rescale
        else:
            # rescale value is being set from xyz_grid
            rescale = globals()['cfg_rescale_fi']
        # The grid-supplied value is consumed; go back to slider/prompt mode.
        globals()['enable_furry_cocks'] = True

        sd_samplers_kdiffusion.CFGDenoiser.combine_denoised = self.cfg_replace

        if rescale > 0:
            p.extra_generation_params["CFG Rescale"] = rescale

        if recolor:
            p.extra_generation_params["Auto Color Fix Strength"] = rec_strength
            # postprocess() writes the images itself when the colour fix is on.
            p.do_not_save_samples = True

    def postprocess_batch_list(self, p, pp, rescale, recolor, rec_strength, show_original, batch_number):
        """Duplicate the batch so untouched originals can be kept next to the fixed images."""
        if not (recolor and show_original):
            return
        # The range is fixed before appending, so only the original entries are copied.
        for i in range(len(pp.images)):
            pp.images.append(pp.images[i])
            p.prompts.append(p.prompts[i])
            p.negative_prompts.append(p.negative_prompts[i])
            p.seeds.append(p.seeds[i])
            p.subseeds.append(p.subseeds[i])

    @staticmethod
    def _stretch_channel(channel, prec):
        """Stretch one 8-bit channel so its significant value range spans 0..255.

        A histogram bin is significant when it holds more than ``prec`` of all
        pixels. The channel is returned unchanged when no bin qualifies or
        when the significant range is a single value (nothing to stretch, and
        the scaling would divide by zero).
        """
        hist, _ = np.histogram(channel, 256, (0, 256))
        significant = np.flatnonzero(hist > hist.sum() * prec)
        if significant.size == 0:
            return channel
        lo = int(significant.min())
        hi = int(significant.max())
        if hi == lo:
            return channel
        return channel.point(lambda v: int(255 * (min(max(v, lo), hi) - lo) / (hi - lo)))

    @staticmethod
    def _color_fix(img, rec_strength):
        """Return an RGB copy of ``img`` with each channel range-stretched."""
        prec = 0.0005 * rec_strength
        if img.mode != "RGB":
            # split() must yield exactly three bands for the merge below.
            img = img.convert("RGB")
        r, g, b = (Script._stretch_channel(band, prec) for band in img.split())
        return Image.merge("RGB", (r, g, b))

    def postprocess(self, p, processed, rescale, recolor, rec_strength, show_original):
        """Restore the stock denoiser hook and apply/save the colour fix."""
        sd_samplers_kdiffusion.CFGDenoiser.combine_denoised = self.old_denoising
        if not recolor:
            return

        grab = 0  # index into p.all_prompts / p.all_seeds for the next save
        n_img = len(processed.images)
        for i, original in enumerate(processed.images):
            if show_original:
                check = i
                if opts.return_grid:
                    # assumes image 0 is the grid when return_grid is set
                    if i == 0:
                        continue
                    check -= 1
                # postprocess_batch_list appended a copy of every batch, so each
                # group of 2 * batch_size images is originals followed by copies;
                # only the second half is fixed.
                doit = check % (p.batch_size * 2) >= p.batch_size
            else:
                doit = (n_img > 1 and i != 0) or n_img == 1 or not opts.return_grid

            if not doit:
                continue

            res_img = self._color_fix(original, rec_strength)
            if opts.samples_save:
                ind = grab
                grab += 1
                prompt_infotext = processing.create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds,
                                                             index=ind)
                # Save images to disk
                # NOTE(review): the unmodified image is also written with the
                # "colorfix" suffix - confirm that is intended.
                for image in (original, res_img):
                    saving.save_image(image, p.outpath_samples, "", seed=p.all_seeds[ind],
                                      prompt=p.all_prompts[ind],
                                      info=prompt_infotext, p=p, suffix="colorfix")
            processed.images[i] = res_img
def on_infotext_pasted(infotext, params):
    """Normalise pasted generation parameters for this extension (in place).

    Args:
        infotext: Unused here; part of the callback signature.
        params: Mapping of parameter names to values; mutated in place.

    Effects on ``params``:
        * ``"CFG Rescale"`` defaults to ``0`` when absent.
        * A ``"CFG Rescale φ"`` entry is moved to ``"CFG Rescale"``.
        * A ``"CFG Rescale phi"`` entry is moved to ``"CFG Rescale"`` only when
          no txt2img script titled "Neutral Prompt" is found (presumably that
          extension uses the same key - TODO confirm).
        * ``"DDIM Trailing"`` defaults to ``False`` when absent.
    """
    params.setdefault("CFG Rescale", 0)

    unicode_key = "CFG Rescale φ"
    if unicode_key in params:
        params["CFG Rescale"] = params.pop(unicode_key)

    ascii_key = "CFG Rescale phi"
    if ascii_key in params and scripts.scripts_txt2img.script("Neutral Prompt") is None:
        params["CFG Rescale"] = params.pop(ascii_key)

    params.setdefault("DDIM Trailing", False)
# Hand the handler above to the webui's script_callbacks hook of the same name
# (presumably invoked each time generation parameters are pasted - confirm).
script_callbacks.on_infotext_pasted(on_infotext_pasted)
|