import logging
from os import environ
import modules.scripts as scripts
import gradio as gr
from dataclasses import dataclass
from typing import Any

from modules import script_callbacks
from modules.processing import StableDiffusionProcessing
from scripts.ui_wrapper import UIWrapper
from scripts.incant import IncantExtensionScript
from scripts.t2i_zero import T2I0ExtensionScript
from scripts.scfg import SCFGExtensionScript
from scripts.pag import PAGExtensionScript
from scripts.save_attn_maps import SaveAttentionMapsScript
from scripts.cfg_combiner import CFGCombinerScript
from scripts.smoothed_energy_guidance import SEGExtensionScript

logger = logging.getLogger(__name__)
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO))


# Author: v0xie
# GitHub URL: https://github.com/v0xie/sd-webui-incantations


class SubmoduleInfo:
        def __init__(self, module: UIWrapper, module_idx = 0, num_args = -1, arg_idx = -1):
                self.module: UIWrapper = module
                self.module_idx: int = module_idx # position of this module in the submodule list
                self.num_args: int = num_args # the length of arg list
                self.arg_idx: int = arg_idx # where the list of args starts

# main scripts
submodules: list[SubmoduleInfo] = [
        SubmoduleInfo(module=SEGExtensionScript()),
        SubmoduleInfo(module=SCFGExtensionScript()),
        SubmoduleInfo(module=PAGExtensionScript()),
        SubmoduleInfo(module=T2I0ExtensionScript()),
        SubmoduleInfo(module=IncantExtensionScript()),
]
# debug scripts
if "INCANT_DEBUG" in environ:
        submodules.append(SubmoduleInfo(module=SaveAttentionMapsScript()))
else:
        logger.info("Incantation: Debug scripts are disabled. Set INCANT_DEBUG environment variable to enable them.")
# run these after submodules
end_submodules: list[SubmoduleInfo] = [
        SubmoduleInfo(module=CFGCombinerScript())
]
submodules = submodules + end_submodules


class IncantBaseExtensionScript(scripts.Script):
        def __init__(self):
                pass

        # Extension title in menu UI
        def title(self):
                return "Incantations"

        # Decide to show menu in txt2img or img2img
        def show(self, is_img2img):
                return scripts.AlwaysVisible

        # Set up the menu UI
        def ui(self, is_img2img):
                # setup UI
                out = []
                with gr.Accordion('Incantations', open=False):
                        for idx, module_info in enumerate(submodules):
                                module_info.module_idx = idx
                                module = module_info.module
                                module_param_list = module.setup_ui(is_img2img)
                                module_info.num_args = len(module_param_list)
                                if module_info.num_args > 0:
                                        arg_idx = len(out)
                                        module_info.arg_idx = arg_idx
                                        out.extend(module_param_list)
                # setup fields
                self.infotext_fields = []
                self.paste_field_names = []
                for module_info in submodules:
                        module = module_info.module
                        self.infotext_fields.extend(module.get_infotext_fields())
                        self.paste_field_names.extend(module.get_paste_field_names())
                return out
        
        def before_process(self, p: StableDiffusionProcessing, *args, **kwargs):
                for m in submodules:
                        m.module.before_process(p, *self.m_args(m, *args), **kwargs)

        def process(self, p: StableDiffusionProcessing, *args, **kwargs):
                for m in submodules:
                        m.module.process(p, *self.m_args(m, *args), **kwargs)

        def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
                for m in submodules:
                        m.module.before_process_batch(p, *self.m_args(m, *args), **kwargs)
        
        def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
                for m in submodules:
                        m.module.process_batch(p, *self.m_args(m, *args), **kwargs)

        def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
                for m in submodules:
                        m.module.postprocess_batch(p, *self.m_args(m, *args), **kwargs)

        def unhook_callbacks(self):
                for m in submodules:
                        m.module.unhook_callbacks()
                script_callbacks.remove_current_script_callbacks()
        
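        def m_args(self, module: SubmoduleInfo, *args):
                """Return the slice of the flat UI argument tuple that belongs to one submodule.

                ui() concatenates every submodule's components into a single list, so each
                submodule's values occupy args[arg_idx:arg_idx + num_args]. A submodule with
                no UI components keeps arg_idx == -1 and num_args == 0, which yields an
                empty tuple.
                """
                return args[module.arg_idx:module.arg_idx + module.num_args]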


# XYZ Plot
# Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py
def make_axis_options(extra_axis_options):
        xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
        current_opts = [x.label for x in xyz_grid.axis_options]
        # TODO: 
        for opt in extra_axis_options:
                if opt.label in current_opts:
                        return
        xyz_grid.axis_options.extend(extra_axis_options)


def callback_before_ui():
        try:
                for module_info in submodules:
                        module = module_info.module
                        try:
                                extra_axis_options = module.get_xyz_axis_options()
                        except NotImplementedError:
                                logger.warning(f"Module {module.title()} does not implement get_xyz_axis_options")
                                extra_axis_options = []
                        make_axis_options(extra_axis_options)
        except Exception:
                logger.exception("Incantation: Error while making axis options")

script_callbacks.on_before_ui(callback_before_ui)