# Reforge / extensions-builtin / reForge-APGIsYourCFG / scripts / advanced_model_sampling_script.py
# DaoNGTKhoa
# Upload folder using huggingface_hub
# ecc4278 verified
import logging
import gradio as gr
from modules import scripts
from APGIsYourCFG.nodes_APGImYourCFGNow import APG_ImYourCFGNow
class APGIsNowYourCFGScript(scripts.Script):
    """reForge script that applies the ``APG_ImYourCFGNow`` patch to the UNet.

    The UI exposes an enable checkbox plus four APG parameters (scale,
    momentum, norm threshold, eta). Before every sampling pass the current UI
    values are copied onto the instance. When enabled, the UNet in
    ``p.sd_model.forge_objects`` is replaced by a patched clone, and the
    settings are recorded in ``p.extra_generation_params``.
    """

    sorting_priority = 15

    def __init__(self):
        # Initial values; also used as the default values of the UI controls.
        self.enabled = False
        self.apg_scale = 9.0
        self.apg_moment = -0.05
        self.apg_norm_thr = 15.0
        self.apg_eta = 1.0

    def title(self):
        """Return the display name, also used as the accordion label."""
        return '"APG\'s now your CFG" for reForge'

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Build the settings accordion.

        Returns:
            A tuple of Gradio components, in this order: enabled checkbox,
            scale, momentum, norm threshold, eta. The same order is
            expected in ``process_before_every_sampling``'s ``*args``.
        """
        with gr.Accordion(open=False, label=self.title()):
            gr.HTML("<p><i>Adjust the settings for APG's now your CFG.</i></p>")
            enabled = gr.Checkbox(label="Enable APG's now your CFG", value=self.enabled)
            apg_scale = gr.Slider(
                label="APG Scale",
                minimum=0.0,
                maximum=100.0,
                step=0.1,
                value=self.apg_scale,
            )
            apg_momentum = gr.Slider(
                label="APG Momentum",
                minimum=-1.5,
                maximum=0.5,
                step=0.01,
                value=self.apg_moment,
            )
            apg_norm_thr = gr.Slider(
                label="APG Norm Threshold",
                minimum=0.5,
                maximum=50.0,
                step=0.5,
                value=self.apg_norm_thr,
            )
            apg_eta = gr.Slider(
                label="APG Eta", minimum=0.0, maximum=1.0, step=0.1, value=self.apg_eta
            )
        # Mirror checkbox changes onto the instance as soon as they happen.
        enabled.change(self.update_enabled, inputs=[enabled])
        return (enabled, apg_scale, apg_momentum, apg_norm_thr, apg_eta)

    def update_enabled(self, value):
        """Store the checkbox state on the instance."""
        self.enabled = value

    def process_before_every_sampling(self, p, *args, **kwargs):
        """Apply the APG patch to the UNet right before sampling, if enabled.

        Args:
            p: The processing object. When APG is enabled, its
                ``sd_model.forge_objects.unet`` is replaced with a patched
                clone and its ``extra_generation_params`` is updated with the
                APG settings.
            *args: This script's UI values, in the order returned by ``ui``:
                enabled, scale, momentum, norm threshold, eta. Extra trailing
                values are ignored. If fewer than five are given, a warning is
                logged and nothing is changed.
        """
        if len(args) < 5:
            logging.warning(
                "Not enough arguments provided to process_before_every_sampling"
            )
            return

        (
            self.enabled,
            self.apg_scale,
            self.apg_moment,
            self.apg_norm_thr,
            self.apg_eta,
        ) = args[:5]

        if not self.enabled:
            # Nothing to patch: leave the model's current UNet as it is.
            return

        # Patch a clone rather than the currently installed UNet object.
        unet = p.sd_model.forge_objects.unet.clone()
        # NOTE(review): patch() appears to return a tuple (node-style API);
        # its first element is the patched model.
        unet = APG_ImYourCFGNow().patch(
            unet, self.apg_scale, self.apg_moment, self.apg_norm_thr, self.apg_eta
        )[0]
        p.sd_model.forge_objects.unet = unet

        # Named ``gen_params`` so the ``*args`` parameter is not shadowed.
        gen_params = {
            "apgisyourcfg_enabled": True,
            "apgisyourcfg_scale": self.apg_scale,
            "apgisyourcfg_momentum": self.apg_moment,
            "apgisyourcfg_norm_thr": self.apg_norm_thr,
            "apgisyourcfg_eta": self.apg_eta,
        }
        p.extra_generation_params.update(gen_params)

        logging.debug("WOLOLO: \"APG is now your CFG!\"")
        # Only build the summary string when it will actually be emitted.
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            str_args = ", ".join(f"{k}:\"{v}\"" for k, v in gen_params.items())
            logging.debug(str_args)