File size: 11,458 Bytes
f4a41d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
import gradio as gr
import os
import random
import sys
import yaml
from modules import scripts, script_callbacks, shared, paths
from modules.processing import Processed
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton
from modules.ui import random_symbol, reuse_symbol, gr_show
from modules.generation_parameters_copypaste import parse_generation_parameters
from pprint import pprint
def read_yaml():
    """Load and return the CloneCleaner prompt database.

    The file ``prompt_tree.yml`` is looked up in the directory returned by
    ``scripts.basedir()`` (presumably this extension's root folder) and parsed
    with ``yaml.safe_load``; the parsed object is returned as-is.
    """
    database_path = os.path.join(scripts.basedir(), "prompt_tree.yml")
    with open(database_path, "r", encoding="utf8") as handle:
        return yaml.safe_load(handle)
def get_last_params(declone_seed, gallery_index):
    """Return the declone seed of the last generation for the "reuse seed" button.

    Reads ``params.txt`` from ``paths.data_path`` (the file
    ``CloneCleanerScript.postprocess`` writes) and picks the seed that was used
    for decloning: the main ``Seed`` when ``CC_use_main_seed`` was on,
    otherwise ``CC_declone_seed``. The selected gallery image shifts the seed,
    mirroring the per-image ``seed + i`` scheme used in ``process``.

    Args:
        declone_seed: current value of the declone seed field; returned
            unchanged when no usable seed can be recovered.
        gallery_index: index of the selected gallery image as reported by the
            JS side (presumably -1 when nothing is selected — TODO confirm).

    Returns:
        ``[seed, gr_show(False)]`` matching the two outputs wired in ``ui()``.
    """
    unchanged = [declone_seed, gr_show(False)]
    filename = os.path.join(paths.data_path, "params.txt")
    if not os.path.exists(filename):
        # Nothing generated yet; returning None would not match the two outputs.
        return unchanged
    with open(filename, "r", encoding="utf8") as file:
        params = parse_generation_parameters(file.read())

    try:
        index = int(gallery_index)
    except (TypeError, ValueError):
        index = 0
    # Position 0 is assumed to be the batch grid, so image k sits at index k+1.
    # Clamping also keeps a "no selection" index (-1) from yielding seed - 1.
    offset = max(index - 1, 0)

    key = "Seed" if params.get("CC_use_main_seed", "") == "True" else "CC_declone_seed"
    raw_seed = params.get(key)
    if raw_seed is None:
        # The last generation did not record this seed (e.g. CloneCleaner was
        # off); keep the user's value instead of silently resetting it to 0.
        return unchanged
    try:
        seed = int(float(raw_seed))
    except (TypeError, ValueError):
        return unchanged
    return [seed + offset, gr_show(False)]
def sorted_difference(a, b):
    """Return the distinct items of *a* that are not in *b*, sorted ascending.

    Args:
        a: iterable of hashable, mutually comparable items (e.g. dict keys).
        b: iterable of items to remove; ``None`` is treated as empty, since an
            exclusion dropdown with no selection may hand over ``None``.

    Returns:
        A new sorted list; duplicates in *a* are collapsed.
    """
    return sorted(set(a).difference(b or ()))
class CloneCleanerScript(scripts.Script):
    """Always-visible script that adds randomized "declone" tokens to prompts.

    For each image in a batch it draws a region, country, name and hair
    attributes from ``prompt_tree`` using a per-image seeded RNG, and inserts
    the resulting phrase at the start or end of that image's prompt.
    """

    prompt_tree = read_yaml()  # maybe make this an instance property later

    def title(self):
        return "CloneCleaner"

    # show menu in either txt2img or img2img
    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the CloneCleaner accordion and return the components passed to process()."""
        with gr.Accordion("CloneCleaner (beta!)", open=True):
            dummy_component = gr.Label(visible=False)
            # Materialize the dict views: dropdown choices should be plain lists.
            regions = list(self.prompt_tree["country"].keys())
            hairlength = list(self.prompt_tree["hair"]["length"].keys())
            haircolor = list(self.prompt_tree["hair"]["color"].keys())
            with FormRow():
                with FormColumn(min_width=160):
                    is_enabled = gr.Checkbox(value=True, label="Enable CloneCleaner")
                with FormColumn(elem_id="CloneCleaner_gender"):
                    gender = gr.Radio(["female", "male", "generic"], value="female", label="Male & generic not yet implemented.", elem_classes="ghosted")
                    gender.style(container=False, item_container=False)
            with FormRow(elem_id="CloneCleaner_components"):
                components = ["name", "country", "hair length", "hair style", "hair color"]
                use_components = gr.CheckboxGroup(components, label="Use declone components", value=components)
            with FormRow(elem_id="CloneCleaner_midsection"):
                with FormGroup():
                    insert_start = gr.Checkbox(value=True, label="Put declone tokens at beginning of prompt")
                    declone_weight = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.0, label="Weight of declone tokens", elem_id="CloneCleaner_slider")
                with FormGroup():
                    use_main_seed = gr.Checkbox(value=True, label="Use main image seed for decloning")
                    with FormRow(variant="compact", elem_id="CloneCleaner_seed_row", elem_classes="ghosted"):
                        declone_seed = gr.Number(label='Declone seed', value=-1, elem_id="CloneCleaner_seed")
                        declone_seed.style(container=False)
                        random_seed = ToolButton(random_symbol, elem_id="CloneCleaner_random_seed", label='Random seed')
                        reuse_seed = ToolButton(reuse_symbol, elem_id="CloneCleaner_reuse_seed", label='Reuse seed')
            with FormRow(elem_id="CloneCleaner_exclude_row"):
                exclude_regions = gr.Dropdown(choices=regions, label="Exclude regions", multiselect=True)
                exclude_hairlength = gr.Dropdown(choices=hairlength, label="Exclude hair lengths", multiselect=True)
                exclude_haircolor = gr.Dropdown(choices=haircolor, label="Exclude hair colors", multiselect=True)

            # Client-side handlers: grey out the seed row, roll a random seed,
            # and pass the selected gallery index to get_last_params.
            jstoggle = "() => {document.getElementById('CloneCleaner_seed_row').classList.toggle('ghosted')}"
            jsclickseed = "() => {setRandomSeed('CloneCleaner_seed')}"
            jsgetgalleryindex = "(x, y) => [x, selected_gallery_index()]"
            # Grey out the hair-length / hair-color exclusion dropdowns when the
            # matching component checkbox (labels[2] / labels[4]) is unchecked.
            other_jstoggles = (
                "() => {"
                "const labels = document.getElementById('CloneCleaner_components').getElementsByTagName('label');"
                "const excludelabels = document.getElementById('CloneCleaner_exclude_row').getElementsByTagName('label');"
                "excludelabels[1].classList.toggle('ghosted', !labels[2].firstChild.checked);"
                "excludelabels[2].classList.toggle('ghosted', !labels[4].firstChild.checked);"
                "}"
            )
            use_main_seed.change(fn=None, _js=jstoggle)
            random_seed.click(fn=None, _js=jsclickseed, show_progress=False, inputs=[], outputs=[])
            reuse_seed.click(fn=get_last_params, _js=jsgetgalleryindex, show_progress=False, inputs=[declone_seed, dummy_component], outputs=[declone_seed, dummy_component])
            use_components.change(fn=None, _js=other_jstoggles)

        def list_from_params_key(key, params):
            """Turn a comma-joined infotext value back into a dropdown selection."""
            regionstring = params.get(key, "")
            regions = regionstring.split(",") if regionstring else []
            return gr.update(value=regions)

        # Map components to the infotext keys written in process() so that
        # pasted generation parameters restore the UI.
        self.infotext_fields = [
            (is_enabled, "CloneCleaner enabled"),
            (gender, "CC_gender"),
            (insert_start, "CC_insert_start"),
            (declone_weight, "CC_declone_weight"),
            (use_main_seed, "CC_use_main_seed"),
            (declone_seed, "CC_declone_seed"),
            (exclude_regions, lambda params: list_from_params_key("CC_exclude_regions", params)),
            (exclude_hairlength, lambda params: list_from_params_key("CC_exclude_hairlength", params)),
            (exclude_haircolor, lambda params: list_from_params_key("CC_exclude_haircolor", params)),
        ]
        return [is_enabled, gender, insert_start, declone_weight, use_main_seed, declone_seed, use_components, exclude_regions, exclude_hairlength, exclude_haircolor]

    @staticmethod
    def _without_excluded(options, excluded, what):
        """Return sorted *options* minus *excluded*, or all options if none would remain."""
        remaining = sorted_difference(options, excluded or [])
        if remaining:
            return remaining
        # These values are drawn with rng.choice()/rng.choices() even when the
        # matching component is switched off, so an empty list would crash the
        # whole generation. Fall back to the full list instead.
        print(f"CloneCleaner: all {what} are excluded; ignoring that exclusion.")
        return sorted_difference(options, [])

    def process(self, p, is_enabled, gender, insert_start, declone_weight, use_main_seed, declone_seed, use_components, exclude_regions, exclude_hairlength, exclude_haircolor):
        """Add per-image declone tokens to ``p.all_prompts`` and record the settings.

        Arguments after ``p`` are the component values returned by ui(), in order.
        """
        if not is_enabled:
            return

        if use_main_seed:
            declone_seed = p.all_seeds[0]
        elif declone_seed == -1:
            declone_seed = int(random.randrange(4294967294))
        else:
            declone_seed = int(declone_seed)

        # Store the settings in the infotext; the keys match infotext_fields.
        extra = p.extra_generation_params
        extra["CloneCleaner enabled"] = True
        extra["CC_gender"] = gender
        extra["CC_insert_start"] = insert_start
        extra["CC_declone_weight"] = declone_weight
        extra["CC_use_main_seed"] = use_main_seed
        extra["CC_declone_seed"] = declone_seed
        if exclude_regions:
            extra["CC_exclude_regions"] = ",".join(exclude_regions)
        if exclude_hairlength:
            extra["CC_exclude_hairlength"] = ",".join(exclude_hairlength)
        if exclude_haircolor:
            extra["CC_exclude_haircolor"] = ",".join(exclude_haircolor)

        countrytree = self.prompt_tree["country"]
        hairtree = self.prompt_tree["hair"]
        regions = self._without_excluded(countrytree.keys(), exclude_regions, "regions")
        hairlengths = self._without_excluded(hairtree["length"].keys(), exclude_hairlength, "hair lengths")
        haircolors = self._without_excluded(hairtree["color"].keys(), exclude_haircolor, "hair colors")

        use_name = "name" in use_components
        use_country = "country" in use_components
        use_length = "hair length" in use_components
        use_style = "hair style" in use_components
        use_color = "hair color" in use_components

        for i, prompt in enumerate(p.all_prompts):  # for each image in batch
            # A fresh RNG per image seeded deterministically, so the same seed
            # always yields the same tokens. Keep the draw order below unchanged,
            # otherwise earlier images can no longer be reproduced.
            rng = random.Random(p.all_seeds[i] if use_main_seed else declone_seed + i)
            region = rng.choice(regions)
            countries = list(countrytree[region].keys())
            countryweights = [countrytree[region][cty]["weight"] for cty in countries]
            country = rng.choices(countries, weights=countryweights)[0]
            countrydata = countrytree[region][country]
            hairdata = countrydata.get("hair", hairtree["defaultweight"][region])
            maincolor = rng.choices(haircolors, weights=[hairdata[col] for col in haircolors])[0]
            color = rng.choice(hairtree["color"][maincolor])
            mainlength = rng.choice(hairlengths)
            length = rng.choice(hairtree["length"][mainlength])
            style = rng.choice(hairtree["style"][mainlength])
            name = rng.choice(countrydata["names"])

            inserted_prompt = ""
            if use_name or use_country:
                inserted_prompt += name if use_name else "person"
                inserted_prompt += " from " + country if use_country else ""
            if use_length or use_style or use_color:
                if inserted_prompt:
                    inserted_prompt += ", "
                if use_length:
                    inserted_prompt += length + " "
                if use_style:
                    inserted_prompt += style + " "
                if use_color:
                    inserted_prompt += color + " "
                inserted_prompt += "hair"

            if inserted_prompt:
                if declone_weight != 1:
                    inserted_prompt = f"({inserted_prompt}:{declone_weight})"
                if insert_start:
                    p.all_prompts[i] = inserted_prompt + ", " + prompt
                else:
                    p.all_prompts[i] = prompt + ", " + inserted_prompt

    def postprocess(self, p, processed, *args):
        """Rewrite params.txt so that it shows the prompt without declone tokens.

        ``args`` are the same ui() values that process() receives, so
        ``args[0]`` is the "Enable CloneCleaner" checkbox.
        """
        if args and not args[0]:
            # Disabled: leave params.txt and p.all_prompts untouched so prompt
            # changes made by styles or other scripts are not overwritten.
            return
        # Presumably infotext(p, 0) takes the prompt from p.all_prompts[0];
        # restore the user's original prompt there first.
        p.all_prompts[0] = p.prompt
        infotext = Processed(p, [], p.seed, "").infotext(p, 0)
        # Build the text before opening the file so a failure cannot leave it truncated.
        with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
            file.write(infotext)
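# The prompt database path registered below can be read via shared.opts.prompt_database_path;
# read_yaml() currently still loads the hard-coded "prompt_tree.yml".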
def on_ui_settings():
    """Register the CloneCleaner settings section and its database-path option.

    Adds the ``prompt_database_path`` option (default ``"prompt_tree.yml"``)
    under the ``("clonecleaner", "CloneCleaner")`` settings section.
    """
    shared.opts.add_option(
        "prompt_database_path",
        shared.OptionInfo(
            "prompt_tree.yml",
            "CloneCleaner prompt database path",
            section=("clonecleaner", "CloneCleaner"),
        ),
    )


# Hook the settings registration into the webui settings page.
script_callbacks.on_ui_settings(on_ui_settings)
|