File size: 11,458 Bytes
f4a41d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import gradio as gr
import os
import random
import sys
import yaml

from modules import scripts, script_callbacks, shared, paths
from modules.processing import Processed
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton
from modules.ui import random_symbol, reuse_symbol, gr_show
from modules.generation_parameters_copypaste import parse_generation_parameters
from pprint import pprint

def read_yaml():
    """Load and return the CloneCleaner prompt database.

    The database is the file ``prompt_tree.yml`` inside this extension's base
    directory (as reported by ``scripts.basedir()``). It is parsed with
    ``yaml.safe_load``, so only plain YAML data types are constructed.
    """
    database_path = os.path.join(scripts.basedir(), "prompt_tree.yml")
    with open(database_path, "r", encoding="utf8") as handle:
        tree = yaml.safe_load(handle)
    return tree

def get_last_params(declone_seed, gallery_index):
    """Recover the declone seed of the last generation for the "reuse seed" button.

    Reads ``params.txt`` from the WebUI data directory and parses it with the
    WebUI infotext parser. It then returns the seed CloneCleaner used: the main
    ``Seed`` when ``CC_use_main_seed`` was ``True``, otherwise ``CC_declone_seed``.
    That seed is offset to the image selected in the gallery.

    Args:
        declone_seed: Current value of the "Declone seed" field. It is returned
            unchanged when no previous seed can be recovered.
        gallery_index: Index of the selected gallery image, supplied by the
            front end. Indices 0 and 1 both map to the first image
            (presumably slot 0 is the batch grid — confirm). Negative values
            (presumably "nothing selected") also map to the first image.

    Returns:
        ``[seed, gr_show(False)]``: the seed for the number field, plus an
        update that keeps the hidden dummy component hidden.
    """
    filename = os.path.join(paths.data_path, "params.txt")
    if not os.path.exists(filename):
        # Nothing generated yet: keep the current seed.
        # Previously ``prompt`` was left unbound here and the call crashed.
        return [declone_seed, gr_show(False)]
    with open(filename, "r", encoding="utf8") as file:
        prompt = file.read()

    # Never let a negative index (no selection) shift the seed below the first image.
    offset = gallery_index - 1 if gallery_index > 0 else 0

    params = parse_generation_parameters(prompt)
    use_main = str(params.get("CC_use_main_seed", "")) == "True"
    base_seed = params.get("Seed" if use_main else "CC_declone_seed")
    if base_seed is None:
        # The last run recorded no such seed (e.g. CloneCleaner was disabled),
        # so there is nothing meaningful to reuse.
        return [declone_seed, gr_show(False)]
    # Seeds may be serialized as "123" or "123.0"; go through float to accept both.
    return [int(float(base_seed)) + offset, gr_show(False)]

def sorted_difference(a, b):
    """Return the items of ``a`` that are not in ``b``, deduplicated and sorted.

    Args:
        a: Iterable of candidate items (e.g. a dict's keys).
        b: Iterable of items to exclude. ``None`` is treated as "exclude
            nothing", since an unset multiselect dropdown may yield it.

    Returns:
        A new sorted list.
    """
    return sorted(set(a).difference(b or ()))

class CloneCleanerScript(scripts.Script):
    """WebUI script that splices randomized identity tokens into each prompt of a batch.

    For every image it seeds a private ``random.Random`` and draws a region,
    country, name, hair length, hair style and hair color from the prompt
    database. It then inserts the selected tokens at the start or end of that
    image's prompt, optionally with an attention weight.
    """

    # Prompt database shared by all instances; loaded once when the class is defined.
    prompt_tree = read_yaml()   # maybe make this an instance property later

    def title(self):
        """Return the script's display name."""
        return "CloneCleaner"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` for both txt2img and img2img."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the CloneCleaner accordion and wire its client-side behavior.

        Sets ``self.infotext_fields`` so that pasted generation parameters
        restore the controls.

        Returns:
            The components whose values are passed, in this order, to
            ``process``.
        """
        with gr.Accordion("CloneCleaner (beta!)", open=True):
            # Invisible component; only used as an extra input/output slot by the reuse-seed button.
            dummy_component = gr.Label(visible=False)
            # Materialize the dict views so the dropdowns receive plain lists of choices.
            regions = list(self.prompt_tree["country"].keys())
            hairlength = list(self.prompt_tree["hair"]["length"].keys())
            haircolor = list(self.prompt_tree["hair"]["color"].keys())
            with FormRow():
                with FormColumn(min_width=160):
                    is_enabled = gr.Checkbox(value=True, label="Enable CloneCleaner")
                with FormColumn(elem_id="CloneCleaner_gender"):
                    gender = gr.Radio(["female", "male", "generic"], value="female", label="Male & generic not yet implemented.", elem_classes="ghosted")
                    gender.style(container=False, item_container=False)
            with FormRow(elem_id="CloneCleaner_components"):
                components = ["name", "country", "hair length", "hair style", "hair color"]
                use_components = gr.CheckboxGroup(components, label="Use declone components", value=components)
            with FormRow(elem_id="CloneCleaner_midsection"):
                with FormGroup():
                    insert_start = gr.Checkbox(value=True, label="Put declone tokens at beginning of prompt")
                    declone_weight = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.0, label="Weight of declone tokens", elem_id="CloneCleaner_slider")
                with FormGroup():
                    use_main_seed = gr.Checkbox(value=True, label="Use main image seed for decloning")
                    # Starts with the 'ghosted' class because use_main_seed defaults to True;
                    # jstoggle below flips the class whenever that checkbox changes.
                    with FormRow(variant="compact", elem_id="CloneCleaner_seed_row", elem_classes="ghosted"):
                        declone_seed = gr.Number(label='Declone seed', value=-1, elem_id="CloneCleaner_seed")
                        declone_seed.style(container=False)
                        random_seed = ToolButton(random_symbol, elem_id="CloneCleaner_random_seed", label='Random seed')
                        reuse_seed = ToolButton(reuse_symbol, elem_id="CloneCleaner_reuse_seed", label='Reuse seed')
            with FormRow(elem_id="CloneCleaner_exclude_row"):
                exclude_regions = gr.Dropdown(choices=regions, label="Exclude regions", multiselect=True)
                exclude_hairlength = gr.Dropdown(choices=hairlength, label="Exclude hair lengths", multiselect=True)
                exclude_haircolor = gr.Dropdown(choices=haircolor, label="Exclude hair colors", multiselect=True)

        # Client-side handlers: JavaScript run in the browser (fn=None means no Python callback).
        jstoggle = "() => {document.getElementById('CloneCleaner_seed_row').classList.toggle('ghosted')}"
        jsclickseed = "() => {setRandomSeed('CloneCleaner_seed')}"
        # Replaces the dummy input with the selected gallery index before get_last_params runs.
        jsgetgalleryindex = "(x, y) => [x, selected_gallery_index()]"
        # Toggles 'ghosted' on the hair-length / hair-color exclusion dropdowns according to the
        # "hair length" (labels[2]) and "hair color" (labels[4]) component checkboxes.
        other_jstoggles = (
            "() => {"
            "const labels = document.getElementById('CloneCleaner_components').getElementsByTagName('label');"
            "const excludelabels = document.getElementById('CloneCleaner_exclude_row').getElementsByTagName('label');"
            "excludelabels[1].classList.toggle('ghosted', !labels[2].firstChild.checked);"
            "excludelabels[2].classList.toggle('ghosted', !labels[4].firstChild.checked);"
            "}"
        )
        use_main_seed.change(fn=None, _js=jstoggle)
        random_seed.click(fn=None, _js=jsclickseed, show_progress=False, inputs=[], outputs=[])
        reuse_seed.click(fn=get_last_params, _js=jsgetgalleryindex, show_progress=False, inputs=[declone_seed, dummy_component], outputs=[declone_seed, dummy_component])
        use_components.change(fn=None, _js=other_jstoggles)

        def list_from_params_key(key, params):
            # Exclusion lists are stored comma-joined in the infotext (see process).
            regionstring = params.get(key, "")
            regions = regionstring.split(",") if regionstring else []
            return gr.update(value = regions)

        self.infotext_fields = [
            (is_enabled, "CloneCleaner enabled"),
            (gender, "CC_gender"),
            (insert_start, "CC_insert_start"),
            (declone_weight, "CC_declone_weight"),
            (use_main_seed, "CC_use_main_seed"),
            (declone_seed, "CC_declone_seed"),
            (exclude_regions, lambda params:list_from_params_key("CC_exclude_regions", params)),
            (exclude_hairlength, lambda params:list_from_params_key("CC_exclude_hairlength", params)),
            (exclude_haircolor, lambda params:list_from_params_key("CC_exclude_haircolor", params))
        ]
        return [is_enabled, gender, insert_start, declone_weight, use_main_seed, declone_seed, use_components, exclude_regions, exclude_hairlength, exclude_haircolor]

    @staticmethod
    def _allowed(options, excluded):
        """Return ``options`` minus ``excluded``, sorted; all options if nothing would remain."""
        # ``excluded`` may be None/empty when its dropdown has no selection. Falling back to
        # the full set avoids crashing rng.choice on an empty list when everything is excluded.
        return sorted_difference(options, excluded or []) or sorted(options)

    @staticmethod
    def _weighted_choice(rng, items, weights):
        """Pick one of ``items`` by ``weights``; pick uniformly if no weight is positive."""
        if sum(weights) > 0:
            return rng.choices(items, weights=weights)[0]
        # random.choices raises ValueError on a zero total weight.
        return rng.choice(items)

    def process(self, p, is_enabled, gender, insert_start, declone_weight, use_main_seed, declone_seed, use_components, exclude_regions, exclude_hairlength, exclude_haircolor):
        """Insert randomized declone tokens into every prompt in ``p.all_prompts``.

        The arguments after ``p`` are the values of the components returned by
        ``ui``, in the same order.

        For image ``i``, the token draw is seeded with ``p.all_seeds[i]`` (when
        ``use_main_seed``) or with ``declone_seed + i``. The same settings and
        seeds therefore reproduce the same tokens. The chosen settings are also
        recorded in ``p.extra_generation_params``.
        """
        if not is_enabled:
            return

        if use_main_seed:
            declone_seed = p.all_seeds[0]
        elif declone_seed == -1:
            # -1 requests a fresh random declone seed.
            declone_seed = int(random.randrange(4294967294))
        else:
            declone_seed = int(declone_seed)

        # Record the settings in the infotext. ui() maps these keys back to its controls via
        # infotext_fields, and get_last_params() reads the seed keys.
        p.extra_generation_params["CloneCleaner enabled"] = True
        p.extra_generation_params["CC_gender"] = gender
        p.extra_generation_params["CC_insert_start"] = insert_start
        p.extra_generation_params["CC_declone_weight"] = declone_weight
        p.extra_generation_params["CC_use_main_seed"] = use_main_seed
        p.extra_generation_params["CC_declone_seed"] = declone_seed
        if exclude_regions:
            p.extra_generation_params["CC_exclude_regions"] = ",".join(exclude_regions)
        if exclude_hairlength:
            p.extra_generation_params["CC_exclude_hairlength"] = ",".join(exclude_hairlength)
        if exclude_haircolor:
            p.extra_generation_params["CC_exclude_haircolor"] = ",".join(exclude_haircolor)

        countrytree = self.prompt_tree["country"]
        hairtree = self.prompt_tree["hair"]

        regions = self._allowed(countrytree.keys(), exclude_regions)
        hairlengths = self._allowed(hairtree["length"].keys(), exclude_hairlength)
        haircolors = self._allowed(hairtree["color"].keys(), exclude_haircolor)

        use_name = "name" in use_components
        use_country = "country" in use_components
        use_length = "hair length" in use_components
        use_style = "hair style" in use_components
        use_color = "hair color" in use_components

        for i, prompt in enumerate(p.all_prompts):      # for each image in batch
            # Private per-image generator: the draws are reproducible from the seed alone and
            # leave the global ``random`` state untouched. The draw ORDER below is what makes
            # a seed reproduce the same tokens, so do not reorder these calls.
            seed = p.all_seeds[i] if use_main_seed else declone_seed + i
            rng = random.Random(seed)

            region = rng.choice(regions)
            countries = list(countrytree[region].keys())
            countryweights = [countrytree[region][cty]["weight"] for cty in countries]
            country = self._weighted_choice(rng, countries, countryweights)

            countrydata = countrytree[region][country]
            # Per-country hair weights override the region default; look the default up only when needed.
            hairdata = countrydata["hair"] if "hair" in countrydata else hairtree["defaultweight"][region]
            maincolor = self._weighted_choice(rng, haircolors, [hairdata.get(col, 0) for col in haircolors])
            color = rng.choice(hairtree["color"][maincolor])
            mainlength = rng.choice(hairlengths)
            length = rng.choice(hairtree["length"][mainlength])
            style = rng.choice(hairtree["style"][mainlength])
            name = rng.choice(countrydata["names"])

            inserted_prompt = ""

            if use_name or use_country:
                inserted_prompt += name if use_name else "person"
                if use_country:
                    inserted_prompt += " from " + country

            if use_length or use_style or use_color:
                if inserted_prompt:
                    inserted_prompt += ", "
                if use_length:
                    inserted_prompt += length + " "
                if use_style:
                    inserted_prompt += style + " "
                if use_color:
                    inserted_prompt += color + " "
                inserted_prompt += "hair"

            if inserted_prompt:
                if declone_weight != 1:
                    # Attention-weight syntax "(text:weight)"; a weight of 1 is left implicit.
                    inserted_prompt = f"({inserted_prompt}:{declone_weight})"

                if insert_start:
                    p.all_prompts[i] = inserted_prompt + ", " + prompt
                else:
                    p.all_prompts[i] = prompt + ", " + inserted_prompt

    def postprocess(self, p, processed, *args):
        """Rewrite ``params.txt`` with ``p.prompt`` as the first image's prompt.

        Does nothing when CloneCleaner was disabled for this run, so the file is
        not overwritten with a prompt that lacks the WebUI's own prompt edits.
        """
        # Assumes ``args`` mirror process()'s arguments, so args[0] is is_enabled — confirm.
        # With no args at all, keep the original (always-write) behavior.
        if args and not args[0]:
            return
        # NOTE(review): mutates p.all_prompts in place, as before; presumably intentional so
        # the saved infotext shows the prompt without declone tokens — confirm.
        p.all_prompts[0] = p.prompt
        infotext = Processed(p, [], p.seed, "").infotext(p, 0)
        # Build the text first: opening with "w" truncates the file, so a failure while
        # building the infotext would otherwise leave params.txt empty.
        with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
            file.write(infotext)

# The option below is meant to be read via shared.opts.prompt_database_path.
# NOTE(review): read_yaml() currently hard-codes "prompt_tree.yml" and does not consult it.
def on_ui_settings():
    """Register CloneCleaner's options on the WebUI settings page."""
    section = ("clonecleaner", "CloneCleaner")
    database_option = shared.OptionInfo(
        "prompt_tree.yml",
        "CloneCleaner prompt database path",
        section=section,
    )
    shared.opts.add_option("prompt_database_path", database_option)


# Hook on_ui_settings into the WebUI's settings-construction callbacks.
script_callbacks.on_ui_settings(on_ui_settings)