# Source: extenew/extensions/CFgfade/scripts/forge_cfgfade.py
# Uploaded by dikdimon using the SD-Hub extension (commit f4a41d8, verified).
import gradio as gr
from modules import scripts
import modules.shared as shared
from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks
import torch, math
#import torchvision.transforms.functional as TF
import ldm_patched.modules.samplers as LDM
import modules_forge.forge_sampler
# TODO: add a button that prints the weighted CFG schedule to the console; better still, display the weights in a gradio LinePlot.
class CFGfadeForge(scripts.Script):
    """Forge "CFG fade" script: schedules and post-processes classifier-free guidance.

    While the script is enabled for a generation it:

    * registers ``denoiser_callback`` through ``on_cfg_denoiser``; the callback
      computes a per-step multiplier and stores it in the class attribute
      ``weight``;
    * replaces ``modules_forge.forge_sampler.sampling_function`` with
      ``CFGfadeForge.sampling_function``, which multiplies the CFG scale by
      ``weight`` (clamped to at least 1.0);
    * patches the Forge UNet with a sampler CFG function that can centre the
      predictions and apply "heuristic" CFG, Reinhard tonemapping and RescaleCFG.

    ``postprocess`` puts the saved sampling function back.
    """

    # Current CFG multiplier: written by denoiser_callback, read by sampling_function.
    weight = 1.0
    # Forge's own sampling_function, saved the first time process() replaces it.
    backup_sampling_function = None

    def __init__(self):
        # Fallback settings; process() overwrites every one of them whenever the script is enabled.
        self.boostStep = 0.0
        self.highStep = 0.5
        self.maxScale = 1.0
        self.fadeStep = 0.5
        self.zeroStep = 1.0
        self.minScale = 0.0
        self.lowCFG1 = 0.0   # 0.0 -> never force CFG 1 at the start
        self.highCFG1 = 1.0  # 1.0 -> never force CFG 1 at the end
        self.reinhard = 1.0
        self.rcfgmult = 1.0  # previously misspelled "rfcgmult", leaving rcfgmult unset
        self.centreMean = False
        self.heuristic = 0
        self.hStart = 0.0

    def title(self):
        return "CFG fade"

    def show(self, is_img2img):
        # make this extension visible in both txt2img and img2img tab.
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Build the accordion UI; the returned components match the order process() unpacks."""
        with gr.Accordion(open=False, label=self.title()):
            with gr.Row():
                enabled = gr.Checkbox(value=False, label='Enable modifications to CFG')
                cntrMean = gr.Checkbox(value=False, label='centre conds to mean')
            with gr.Row():
                lowCFG1 = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.1, label='CFG 1 until step')
                maxScale = gr.Slider(minimum=1.0, maximum=4.0, step=0.01, value=1.0, label='boost factor')
            with gr.Row():
                boostStep = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label='CFG boost start step')
                minScale = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label='fade factor')
            with gr.Row():
                highStep = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.4, label='full boost at step')
                heuristic = gr.Slider(minimum=0.0, maximum=16.0, step=0.5, value=0, label='Heuristic CFG')
            with gr.Row():
                fadeStep = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.5, label='CFG fade start step')
                hStart = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label='... start step')
            with gr.Row():
                zeroStep = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.7, label='full fade at step')
                reinhard = gr.Slider(minimum=0.0, maximum=16.0, step=0.5, value=0.0, label='Reinhard CFG')
            with gr.Row():
                highCFG1 = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.8, label='CFG 1 after step')
                rcfgmult = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label='Rescale CFG')

        self.infotext_fields = [
            # "enabled" is restored from the mere presence of its key in the infotext.
            # gr.update works on gradio 3.x and 4.x (Component.update was removed in 4.x).
            (enabled, lambda d: gr.update(value=("cfgfade_enabled" in d))),
            (cntrMean, "cfgfade_cntrMean"),
            (boostStep, "cfgfade_boostStep"),
            (highStep, "cfgfade_highStep"),
            (maxScale, "cfgfade_maxScale"),
            (fadeStep, "cfgfade_fadeStep"),
            (zeroStep, "cfgfade_zeroStep"),
            (minScale, "cfgfade_minScale"),
            (lowCFG1, "cfgfade_lowCFG1"),
            (highCFG1, "cfgfade_highCFG1"),
            (reinhard, "cfgfade_reinhard"),
            (rcfgmult, "cfgfade_rcfgmult"),
            (heuristic, "cfgfade_heuristic"),
            (hStart, "cfgfade_hStart"),
        ]
        return enabled, cntrMean, boostStep, highStep, maxScale, fadeStep, zeroStep, minScale, lowCFG1, highCFG1, reinhard, rcfgmult, heuristic, hStart

    @staticmethod
    def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None):
        """Replacement for Forge's CFG sampling function that applies the fade schedule.

        Edited from ``ldm_patched/modules/samplers`` ``sampling_function``; the only
        functional change is at the top: ``cond_scale`` is multiplied by
        ``CFGfadeForge.weight`` and clamped to at least 1.0. When the result is
        ~1.0 (and ``disable_cfg1_optimization`` is not set) ``uncond`` is passed on
        as ``None``, presumably letting ``calc_cond_uncond_batch`` skip the
        unconditional pass on steps forced to CFG 1.

        ``seed`` is accepted only for signature compatibility and is not used.
        Returns the guided denoised prediction.
        """
        if model_options is None:  # avoid a shared mutable default argument
            model_options = {}

        cond_scale = max(cond_scale * CFGfadeForge.weight, 1.0)

        edit_strength = sum((item['strength'] if 'strength' in item else 1) for item in cond)

        # Comparison kept exactly as upstream so this copy stays easy to diff.
        if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
            uncond_ = None
        else:
            uncond_ = uncond

        for fn in model_options.get("sampler_pre_cfg_function", []):
            model, cond, uncond_, x, timestep, model_options = fn(model, cond, uncond_, x, timestep, model_options)

        cond_pred, uncond_pred = LDM.calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)

        if "sampler_cfg_function" in model_options:
            args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
                    "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
            cfg_result = x - model_options["sampler_cfg_function"](args)
        elif not math.isclose(edit_strength, 1.0):
            cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale * edit_strength
        else:
            cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

        for fn in model_options.get("sampler_post_cfg_function", []):
            args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                    "sigma": timestep, "model_options": model_options, "input": x}
            cfg_result = fn(args)

        return cfg_result

    def patch(self, model):
        """Return ``(clone,)`` where *clone* is *model* with the CFG-fade sampler CFG function set.

        The installed function reads its settings from ``self`` on every call.
        """

        def sampler_cfgfade(args):
            # As built in sampling_function, args["cond"] / args["uncond"] are
            # x - cond_denoised / x - uncond_denoised, and the returned value is
            # subtracted from x to obtain the final denoised prediction.
            cond = args["cond"]
            cond_scale = args["cond_scale"]
            if cond_scale == 1.0:
                return cond
            uncond = args["uncond"]

            # Scaling heuristic/reinhard by CFGfadeForge.weight as well was tried; it only
            # looked like a win when those values were set too high, so they are used as-is.
            heuristic = self.heuristic
            reinhard = self.reinhard

            if self.centreMean:  # better after, but value here too?
                # Subtract the mean of every (sample, channel) slice, vectorised over all
                # channels (the previous loop hard-coded 4 latent channels).
                reduce_dims = tuple(range(2, cond.ndim))
                cond = cond - cond.mean(dim=reduce_dims, keepdim=True)
                uncond = uncond - uncond.mean(dim=reduce_dims, keepdim=True)

            # cond_scale weighting is already applied in sampling_function.
            thisStep = shared.state.sampling_step
            lastStep = shared.state.sampling_steps

            # Heuristic scaling: higher hcfg acts to boost contrast/detail/sharpness; low reduces.
            # cond and uncond are multiplied by q(cond_scale) / q(heuristic), where q is the
            # 99th percentile of |CFG result - its mean| at that scale.
            noisePrediction = cond - uncond
            if heuristic != 0.0 and heuristic != cond_scale and thisStep >= self.hStart * lastStep:
                base = uncond + cond_scale * noisePrediction
                heur = uncond + heuristic * noisePrediction
                # torch.quantile only accepts float/double input, hence .float().
                baseQ = torch.quantile((base - base.mean()).abs().float(), 0.99)
                heurQ = torch.quantile((heur - heur.mean()).abs().float(), 0.99)
                del base, heur
                if baseQ != 0.0 and heurQ != 0.0:
                    ratio = baseQ / heurQ
                    cond *= ratio
                    uncond *= ratio
            # end: heuristic scaling

            # Reinhard tonemap from comfy: compresses the per-position magnitude of the
            # noise prediction; afterwards cond_scale is a tensor rather than a float.
            noisePrediction = cond - uncond
            if reinhard != 0.0 and reinhard != cond_scale:
                multiplier = 1.0 / cond_scale * reinhard
                # Norm over dim 1, re-expanded with [:, None] so it broadcasts back.
                noise_pred_vector_magnitude = (torch.linalg.vector_norm(noisePrediction, dim=1) + 1e-10)[:, None]
                noisePrediction /= noise_pred_vector_magnitude

                mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
                std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
                top = (std * 3 + mean) * multiplier

                noise_pred_vector_magnitude *= (1.0 / top)
                new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
                new_magnitude *= top
                cond_scale *= new_magnitude
            # end: reinhard

            # RescaleCFG - maybe should be exclusive of other effects, but why restrict?
            result = uncond + cond_scale * noisePrediction
            if self.rcfgmult != 0.0:
                ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True)
                ro_cfg = torch.std(result, dim=(1, 2, 3), keepdim=True)
                # ro_pos/ro_cfg hold one value per sample; testing them in a plain `if`
                # raises for batch sizes > 1. Guard element-wise instead: samples with a
                # zero std keep a factor of 1, i.e. stay unscaled as before.
                usable = (ro_pos != 0.0) & (ro_cfg != 0.0)
                factor = torch.where(usable, ro_pos / ro_cfg, torch.ones_like(ro_cfg))
                result = torch.lerp(result, result * factor, self.rcfgmult)
            # end: rescaleCFG

            return result

        m = model.clone()
        m.set_model_sampler_cfg_function(sampler_cfgfade)
        return (m, )

    def denoiser_callback(self, params):
        """``on_cfg_denoiser`` callback: set ``CFGfadeForge.weight`` for the current step.

        Every step setting is a fraction of the run, converted to a step index
        against ``total_sampling_steps - 1``. The weight is ``boostWeight * fadeWeight``:

        * boostWeight is 0 before ``lowCFG1`` (the >= 1.0 clamp in
          sampling_function then gives CFG 1). It is 1 until ``boostStep``, ramps
          linearly to ``maxScale`` at ``highStep``, then stays at ``maxScale``.
        * fadeWeight is 0 after ``highCFG1`` (CFG 1 again). Otherwise it is 1 until
          ``fadeStep`` and ramps linearly to 0 at ``zeroStep``; that 0..1 value is
          then remapped onto ``[minScale, 1]``.
        """
        lastStep = params.total_sampling_steps - 1
        thisStep = params.sampling_step

        lowCFG1 = self.lowCFG1 * lastStep
        highStep = self.highStep * lastStep
        boostStep = self.boostStep * lastStep
        highCFG1 = self.highCFG1 * lastStep
        fadeStep = self.fadeStep * lastStep
        zeroStep = self.zeroStep * lastStep

        if thisStep < lowCFG1:
            boostWeight = 0.0
        elif thisStep < boostStep:
            boostWeight = 1.0
        elif thisStep < highStep:
            # boostStep <= thisStep < highStep here, so the divisor is positive.
            boostWeight = 1.0 + (self.maxScale - 1.0) * ((thisStep - boostStep) / (highStep - boostStep))
        else:
            boostWeight = self.maxScale

        if thisStep > highCFG1:
            fadeWeight = 0.0
        else:
            if thisStep < fadeStep:
                fadeWeight = 1.0
            elif thisStep < zeroStep:
                # fadeStep <= thisStep < zeroStep here, so the divisor is positive.
                fadeWeight = 1.0 - (thisStep - fadeStep) / (zeroStep - fadeStep)
            else:
                fadeWeight = 0.0
            # at this point, weight is in the range 0.0->1.0
            fadeWeight = fadeWeight * (1.0 - self.minScale) + self.minScale
            # now it is minimum->1.0

        CFGfadeForge.weight = boostWeight * fadeWeight

    def process_before_every_sampling(self, params, *script_args, **kwargs):
        """When enabled, install the CFG-fade sampler CFG function on the Forge UNet."""
        enabled = script_args[0]
        if enabled:
            unet = params.sd_model.forge_objects.unet
            params.sd_model.forge_objects.unet = self.patch(unet)[0]

    def process(self, params, *script_args, **kwargs):
        """Store the UI settings, log them to the infotext and activate the CFG schedule."""
        (enabled, cntrMean, boostStep, highStep, maxScale, fadeStep, zeroStep, minScale,
         lowCFG1, highCFG1, reinhard, rcfgmult, heuristic, hStart) = script_args
        if not enabled:
            return

        self.centreMean = cntrMean
        self.boostStep = boostStep
        self.highStep = highStep
        self.maxScale = maxScale
        self.fadeStep = fadeStep
        self.zeroStep = zeroStep
        self.minScale = minScale
        self.lowCFG1 = lowCFG1
        self.highCFG1 = highCFG1
        self.reinhard = reinhard
        self.rcfgmult = rcfgmult
        self.heuristic = heuristic
        self.hStart = hStart

        # logs, could save boost start/full only if boost factor > 1; similar for fade
        params.extra_generation_params.update(dict(
            cfgfade_enabled = enabled,
            cfgfade_cntrMean = cntrMean,
            cfgfade_boostStep = boostStep,
            cfgfade_highStep = highStep,
            cfgfade_maxScale = maxScale,
            cfgfade_fadeStep = fadeStep,
            cfgfade_zeroStep = zeroStep,
            cfgfade_minScale = minScale,
            cfgfade_lowCFG1 = lowCFG1,
            cfgfade_highCFG1 = highCFG1,
            cfgfade_reinhard = reinhard,
            cfgfade_rcfgmult = rcfgmult,
            cfgfade_heuristic = heuristic,
            cfgfade_hStart = hStart,
        ))

        # Must log the parameters before fixing minScale. It is stored relative to
        # maxScale so that a fully boosted, fully faded step has weight
        # maxScale * (minScale / maxScale) == the user's fade factor.
        self.minScale /= self.maxScale

        on_cfg_denoiser(self.denoiser_callback)
        if CFGfadeForge.backup_sampling_function is None:
            CFGfadeForge.backup_sampling_function = modules_forge.forge_sampler.sampling_function
        modules_forge.forge_sampler.sampling_function = CFGfadeForge.sampling_function

    def postprocess(self, params, processed, *args):
        """Undo process(): restore the saved sampling function and remove script callbacks.

        ``remove_current_script_callbacks()`` is presumably what unregisters the
        ``on_cfg_denoiser`` callback added in process() — confirm in modules.script_callbacks.
        """
        enabled = args[0]
        if enabled:
            if CFGfadeForge.backup_sampling_function is not None:
                modules_forge.forge_sampler.sampling_function = CFGfadeForge.backup_sampling_function
            remove_current_script_callbacks()