|
|
import logging |
|
|
from os import environ |
|
|
import modules.scripts as scripts |
|
|
import gradio as gr |
|
|
from dataclasses import dataclass |
|
|
from typing import Any |
|
|
|
|
|
from modules import script_callbacks |
|
|
from modules.processing import StableDiffusionProcessing |
|
|
from scripts.ui_wrapper import UIWrapper |
|
|
from scripts.incant import IncantExtensionScript |
|
|
from scripts.t2i_zero import T2I0ExtensionScript |
|
|
from scripts.scfg import SCFGExtensionScript |
|
|
from scripts.pag import PAGExtensionScript |
|
|
from scripts.save_attn_maps import SaveAttentionMapsScript |
|
|
from scripts.cfg_combiner import CFGCombinerScript |
|
|
from scripts.smoothed_energy_guidance import SEGExtensionScript |
|
|
|
|
|
# Module-level logger. Its threshold comes from SD_WEBUI_LOG_LEVEL when that
# variable is set (Logger.setLevel accepts a level name such as "DEBUG");
# otherwise it defaults to INFO.
logger = logging.getLogger(name=__name__)
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", default=logging.INFO))
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
Author: v0xie |
|
|
GitHub URL: https://github.com/v0xie/sd-webui-incantations |
|
|
|
|
|
""" |
|
|
class SubmoduleInfo:
    """Bookkeeping for one Incantations submodule hosted by the base script.

    Attributes:
        module: The submodule instance whose hooks are forwarded to.
        module_idx: Position of the submodule in the hosting list.
        num_args: Number of UI arguments the submodule contributed
            (-1 until the UI has been built).
        arg_idx: Offset of the submodule's first argument inside the flat
            script-argument tuple (-1 until assigned).
    """

    def __init__(self, module: "UIWrapper", module_idx=0, num_args=-1, arg_idx=-1):
        self.module: UIWrapper = module
        # Previously this stored ``num_args`` by mistake, silently ignoring
        # the ``module_idx`` argument.
        self.module_idx: int = module_idx
        self.num_args: int = num_args
        self.arg_idx: int = arg_idx
|
|
|
|
|
|
|
|
# Submodules hosted by IncantBaseExtensionScript. Every processing hook is
# forwarded to them in list order.
submodules: list[SubmoduleInfo] = [
    SubmoduleInfo(module=SEGExtensionScript()),
    SubmoduleInfo(module=SCFGExtensionScript()),
    SubmoduleInfo(module=PAGExtensionScript()),
    SubmoduleInfo(module=T2I0ExtensionScript()),
    SubmoduleInfo(module=IncantExtensionScript()),
]

# Debug scripts are enabled by the mere presence of INCANT_DEBUG in the
# environment. Its value is not inspected, so even "" or "0" enables them.
if "INCANT_DEBUG" in environ:
    submodules.append(SubmoduleInfo(module=SaveAttentionMapsScript()))
else:
    logger.info("Incantation: Debug scripts are disabled. Set INCANT_DEBUG environment variable to enable them.")

# Appended after every other submodule, so their hooks are invoked last.
end_submodules: list[SubmoduleInfo] = [
    SubmoduleInfo(module=CFGCombinerScript()),
]

submodules = [*submodules, *end_submodules]
|
|
|
|
|
|
|
|
class IncantBaseExtensionScript(scripts.Script):
    """Always-visible WebUI script that hosts every Incantations submodule.

    Each submodule contributes its own UI components. These are concatenated
    into one flat argument list for this script. The ``arg_idx`` / ``num_args``
    recorded on each ``SubmoduleInfo`` during ``ui()`` identify which slice of
    the incoming ``*args`` belongs to which submodule. Every processing hook
    forwards only that slice.
    """

    def __init__(self):
        # Intentionally empty: per-submodule state lives on the module-level
        # ``submodules`` list rather than on this instance.
        pass

    def title(self):
        """Return the name displayed for this script."""
        return "Incantations"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of the tab."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build every submodule's UI inside one collapsed accordion.

        Side effects: sets ``module_idx`` and ``num_args`` on each entry of
        ``submodules``, sets ``arg_idx`` on entries that contributed at least
        one component, and rebuilds ``self.infotext_fields`` /
        ``self.paste_field_names`` from all submodules.

        Returns:
            The concatenation of every submodule's ``setup_ui`` result.
            Presumably the WebUI passes their values back as ``*args`` to the
            hooks below, which is how ``m_args`` slices them.
        """
        # NOTE(review): ``submodules`` is shared module state, so the layout
        # written by the most recent ui() call (txt2img or img2img) is the one
        # every script instance uses. That is only correct if setup_ui returns
        # the same number of components on both tabs. Confirm.
        components = []
        with gr.Accordion('Incantations', open=False):
            for position, info in enumerate(submodules):
                info.module_idx = position
                params = info.module.setup_ui(is_img2img)
                info.num_args = len(params)
                if info.num_args <= 0:
                    # No components: arg_idx is left untouched, and the
                    # zero-length slice is empty wherever it starts.
                    continue
                info.arg_idx = len(components)
                components.extend(params)

        self.infotext_fields = []
        self.paste_field_names = []
        for info in submodules:
            plugin = info.module
            self.infotext_fields.extend(plugin.get_infotext_fields())
            self.paste_field_names.extend(plugin.get_paste_field_names())

        return components

    def _forward(self, hook_name, p, /, *args, **kwargs):
        """Invoke ``hook_name`` on each submodule, in order, with its own argument slice."""
        # ``hook_name`` and ``p`` are positional-only so no keyword argument the
        # caller forwards can collide with them.
        for info in submodules:
            hook = getattr(info.module, hook_name)
            hook(p, *self.m_args(info, *args), **kwargs)

    def before_process(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._forward("before_process", p, *args, **kwargs)

    def process(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._forward("process", p, *args, **kwargs)

    def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._forward("before_process_batch", p, *args, **kwargs)

    def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._forward("process_batch", p, *args, **kwargs)

    def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
        self._forward("postprocess_batch", p, *args, **kwargs)

    def unhook_callbacks(self):
        """Let every submodule unhook itself, then drop this script's callbacks."""
        for info in submodules:
            info.module.unhook_callbacks()
        script_callbacks.remove_current_script_callbacks()

    def m_args(self, module: SubmoduleInfo, *args):
        """Return the slice of ``args`` that belongs to ``module``.

        With the constructor defaults (``arg_idx=-1``, ``num_args=-1``) or a
        module that contributed no components, the slice is empty.
        """
        start = module.arg_idx
        stop = start + module.num_args
        return args[start:stop]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_axis_options(extra_axis_options):
    """Append extra axis options to the xyz_grid script's ``axis_options``.

    Args:
        extra_axis_options: Iterable of axis-option objects, each exposing a
            ``label`` attribute like the entries already in
            ``xyz_grid.axis_options``. ``None`` or an empty iterable is a no-op.

    The batch is registered all-or-nothing. If any of its labels is already
    present, nothing is added, so registering the same batch twice does not
    create duplicates. If no xyz_grid script is loaded, a warning is logged
    and nothing is registered, instead of raising an opaque IndexError.
    """
    # Materialize once: the options are iterated twice below, which would
    # silently exhaust a generator.
    extra_axis_options = list(extra_axis_options or [])
    if not extra_axis_options:
        return

    xyz_grid = next(
        (
            x.module
            for x in scripts.scripts_data
            if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")
        ),
        None,
    )
    if xyz_grid is None:
        logger.warning("Incantation: xyz_grid script not found; skipping X/Y/Z axis options")
        return

    current_labels = {opt.label for opt in xyz_grid.axis_options}
    if any(opt.label in current_labels for opt in extra_axis_options):
        return

    xyz_grid.axis_options.extend(extra_axis_options)
|
|
|
|
|
|
|
|
def _register_module_axis_options(module):
    """Fetch one submodule's X/Y/Z axis options and pass them to make_axis_options."""
    try:
        extra_axis_options = module.get_xyz_axis_options()
    except NotImplementedError:
        logger.warning(f"Module {module.title()} does not implement get_xyz_axis_options")
        return
    make_axis_options(extra_axis_options)


def callback_before_ui():
    """Register every submodule's X/Y/Z grid axis options (``on_before_ui`` hook).

    Each submodule is handled independently, so one failing submodule no
    longer prevents the others from registering their options. Errors are
    logged rather than raised, so they cannot break UI construction. Only
    ``Exception`` is caught, which lets KeyboardInterrupt/SystemExit propagate.
    """
    for module_info in submodules:
        module = module_info.module
        try:
            _register_module_axis_options(module)
        except Exception:
            logger.exception(
                "Incantation: Error while making axis options for %s",
                type(module).__name__,
            )


script_callbacks.on_before_ui(callback_before_ui)
|
|
|