File size: 6,171 Bytes
3dabe4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import logging
from os import environ
import modules.scripts as scripts
import gradio as gr
from dataclasses import dataclass
from typing import Any
from modules import script_callbacks
from modules.processing import StableDiffusionProcessing
from scripts.ui_wrapper import UIWrapper
from scripts.incant import IncantExtensionScript
from scripts.t2i_zero import T2I0ExtensionScript
from scripts.scfg import SCFGExtensionScript
from scripts.pag import PAGExtensionScript
from scripts.save_attn_maps import SaveAttentionMapsScript
from scripts.cfg_combiner import CFGCombinerScript
from scripts.smoothed_energy_guidance import SEGExtensionScript
logger = logging.getLogger(__name__)


def resolve_log_level(value, default=logging.INFO):
    """Translate a log-level setting into a numeric ``logging`` level.

    Accepts level names case-insensitively (``"debug"``, ``" Warning "``) or a
    plain non-negative integer string (``"10"``). ``None``, empty, or
    unrecognised values fall back to ``default`` instead of raising, so a typo
    in the environment variable cannot stop the extension from importing.

    Args:
        value: Raw setting, typically the ``SD_WEBUI_LOG_LEVEL`` env value.
        default: Level returned when ``value`` is missing or unrecognised.

    Returns:
        An ``int`` suitable for ``Logger.setLevel``.
    """
    if not value:
        return default
    text = str(value).strip()
    if text.isdigit():
        return int(text)
    # getLevelName maps a registered level name back to its number; for an
    # unknown name it returns the string "Level <name>" instead.
    level = logging.getLevelName(text.upper())
    return level if isinstance(level, int) else default


logger.setLevel(resolve_log_level(environ.get("SD_WEBUI_LOG_LEVEL")))
"""
Author: v0xie
GitHub URL: https://github.com/v0xie/sd-webui-incantations
"""
@dataclass(eq=False)
class SubmoduleInfo:
    """Bookkeeping for one Incantations submodule hosted by the base script.

    ``eq=False`` keeps the default identity-based ``__eq__``/``__hash__`` so
    instances remain hashable, exactly as with a plain class.

    Attributes:
        module: The submodule's UI/processing wrapper.
        module_idx: Position of the submodule in the submodule list
            (previously ignored: the constructor stored ``num_args`` here).
        num_args: Number of UI components the submodule contributed;
            -1 until its UI has been built.
        arg_idx: Offset of the submodule's first component in the flat
            script-argument list; -1 while unassigned.
    """

    module: "UIWrapper"
    module_idx: int = 0  # index in the ordered submodule list
    num_args: int = -1  # the length of arg list
    arg_idx: int = -1  # where the list of args starts
# Primary guidance submodules. This order is the order in which the base
# script builds their UI and forwards processing hooks to them.
submodules: list[SubmoduleInfo] = [
    SubmoduleInfo(module=script_class())
    for script_class in (
        SEGExtensionScript,
        SCFGExtensionScript,
        PAGExtensionScript,
        T2I0ExtensionScript,
        IncantExtensionScript,
    )
]

# Debug helpers are opt-in: defining INCANT_DEBUG (with any value, even an
# empty string) enables them.
if "INCANT_DEBUG" in environ:
    submodules.append(SubmoduleInfo(module=SaveAttentionMapsScript()))
else:
    logger.info("Incantation: Debug scripts are disabled. Set INCANT_DEBUG environment variable to enable them.")

# Submodules that must be placed after all of the above.
end_submodules: list[SubmoduleInfo] = [SubmoduleInfo(module=CFGCombinerScript())]
submodules = [*submodules, *end_submodules]
class IncantBaseExtensionScript(scripts.Script):
    """Single always-visible webui script that hosts every Incantations submodule.

    ``ui`` concatenates each submodule's components into one flat list and
    records where each submodule's slice begins. Every lifecycle hook then
    calls the same-named method on each submodule, in list order, handing it
    only its own slice of the positional script arguments.
    """

    def __init__(self):
        # No per-instance state; submodule bookkeeping lives on SubmoduleInfo.
        pass

    # Extension title in menu UI
    def title(self):
        return "Incantations"

    # Shown on both txt2img and img2img, outside the script dropdown
    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build all submodule UIs inside one accordion.

        Side effects: sets ``module_idx``, ``num_args`` and (for submodules
        with at least one component) ``arg_idx`` on every SubmoduleInfo, and
        rebuilds ``self.infotext_fields`` / ``self.paste_field_names``.

        Returns:
            The flat list of components from all submodules, in order.
        """
        components = []
        with gr.Accordion('Incantations', open=False):
            for position, info in enumerate(submodules):
                info.module_idx = position
                params = info.module.setup_ui(is_img2img)
                info.num_args = len(params)
                if info.num_args > 0:
                    # This submodule's slice starts where the list currently ends.
                    info.arg_idx = len(components)
                    components.extend(params)

        self.infotext_fields = []
        self.paste_field_names = []
        for info in submodules:
            wrapper = info.module
            self.infotext_fields.extend(wrapper.get_infotext_fields())
            self.paste_field_names.extend(wrapper.get_paste_field_names())
        return components

    def _dispatch(self, hook, p, args, kwargs):
        """Invoke method ``hook`` on each submodule with its own slice of ``args``."""
        for info in submodules:
            getattr(info.module, hook)(p, *self.m_args(info, *args), **kwargs)

    def before_process(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._dispatch("before_process", p, args, kwargs)

    def process(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._dispatch("process", p, args, kwargs)

    def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._dispatch("before_process_batch", p, args, kwargs)

    def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._dispatch("process_batch", p, args, kwargs)

    def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._dispatch("postprocess_batch", p, args, kwargs)

    def unhook_callbacks(self):
        """Let each submodule unhook itself, then drop this script's webui callbacks."""
        for info in submodules:
            info.module.unhook_callbacks()
        script_callbacks.remove_current_script_callbacks()

    def m_args(self, module: SubmoduleInfo, *args):
        """Return the part of ``args`` that belongs to ``module``.

        A submodule with no components has ``num_args == 0`` and gets an
        empty tuple.
        """
        start = module.arg_idx
        return args[start:start + module.num_args]
# XYZ Plot
# Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py
def make_axis_options(extra_axis_options):
    """Register extra axis options with the built-in XYZ Plot script.

    Finds the loaded ``xyz_grid`` script module in ``scripts.scripts_data``.
    From ``extra_axis_options`` it appends each option whose ``label`` is not
    already registered, including labels repeated within the same batch.
    Calling this again with the same options therefore adds nothing, which
    makes repeated invocations (e.g. a UI reload) safe. Previously, a single
    pre-existing label caused the whole batch to be dropped.

    Args:
        extra_axis_options: Iterable of axis option objects exposing a
            ``label`` attribute. ``None`` or an empty collection is a no-op.

    If the XYZ Plot script is not loaded, a warning is logged and nothing is
    registered (previously this surfaced as a bare ``IndexError``).
    """
    if not extra_axis_options:
        return
    xyz_grid = next(
        (
            data.module
            for data in scripts.scripts_data
            if data.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")
        ),
        None,
    )
    if xyz_grid is None:
        logger.warning("Incantation: XYZ Plot script not found; skipping extra axis options")
        return
    registered = {opt.label for opt in xyz_grid.axis_options}
    new_options = []
    for opt in extra_axis_options:
        if opt.label in registered:
            continue
        registered.add(opt.label)  # also dedupes labels within this batch
        new_options.append(opt)
    xyz_grid.axis_options.extend(new_options)
def callback_before_ui():
    """Register every submodule's XYZ Plot axis options.

    Each submodule is handled independently. A submodule that does not
    implement ``get_xyz_axis_options`` is skipped with a warning. One that
    raises, or whose options fail to register, is logged and skipped. In
    either case the remaining submodules still contribute their options;
    previously a single failure aborted the whole loop.

    Only ``Exception`` is caught (not a bare ``except``), so
    ``KeyboardInterrupt``/``SystemExit`` still propagate.
    """
    for module_info in submodules:
        module = module_info.module
        try:
            extra_axis_options = module.get_xyz_axis_options()
        except NotImplementedError:
            logger.warning(f"Module {module.title()} does not implement get_xyz_axis_options")
            continue
        except Exception:
            logger.exception("Incantation: Error while getting axis options from %s", type(module).__name__)
            continue
        try:
            make_axis_options(extra_axis_options)
        except Exception:
            logger.exception("Incantation: Error while making axis options")
# Register callback_before_ui with the webui's on_before_ui hook so the
# submodules' XYZ Plot axis options get added. Presumably this fires before
# the UI is constructed, as the hook name suggests; confirm against the webui
# script_callbacks module.
script_callbacks.on_before_ui(callback_before_ui)
|