|
|
import re |
|
|
import sys |
|
|
from modules import scripts, script_callbacks, ui_extra_networks, extra_networks, shared, sd_models, sd_vae, sd_samplers, processing |
|
|
|
|
|
|
|
|
# Maps the `op_type` names accepted by sampler keywords to the processing
# class a job must be an instance of for the keyword to apply.
operations = dict(
    txt2img=processing.StableDiffusionProcessingTxt2Img,
    img2img=processing.StableDiffusionProcessingImg2Img,
)

# Raised by sampler keywords that change width/height/hr_* attributes and
# cleared again in Script.process_batch.
needs_hr_recalc = False
|
|
|
|
|
|
|
|
def is_debug():
    """Return the "randomizer_keywords_debug" option, or False when it is not set."""
    option_values = shared.opts.data
    return option_values.get("randomizer_keywords_debug", False)
|
|
|
|
|
|
|
|
def recalc_hires_fix(p):
    """Re-run ``p.init`` after keywords changed size or Hires. fix attributes.

    Only txt2img processing objects are handled; anything else is ignored.
    The "Hires ..." entries are removed from ``p.extra_generation_params``
    before ``p.init`` is called again. With debug output enabled, the
    relevant attributes are printed before and after.
    """
    if not isinstance(p, processing.StableDiffusionProcessingTxt2Img):
        return

    def dump_hires_state(proc):
        # One line per attribute that takes part in the Hires. fix setup.
        print(f"- width: {proc.width}")
        print(f"- height: {proc.height}")
        print(f"- hr_upscaler: {proc.hr_upscaler}")
        print(f"- hr_second_pass_steps: {proc.hr_second_pass_steps}")
        print(f"- hr_scale: {proc.hr_scale}")
        print(f"- hr_resize_x: {proc.hr_resize_x}")
        print(f"- hr_resize_y: {proc.hr_resize_y}")
        print(f"- hr_upscale_to_x: {proc.hr_upscale_to_x}")
        print(f"- hr_upscale_to_y: {proc.hr_upscale_to_y}")

    if is_debug():
        print("[RandomizerKeywords] Recalculating Hires. fix")
        print("Before:")
        dump_hires_state(p)

    # Drop the previously recorded Hires entries before re-initialising.
    stale_keys = ("Hires upscale", "Hires resize", "Hires steps", "Hires upscaler")
    for key in stale_keys:
        p.extra_generation_params.pop(key, None)

    p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

    if is_debug():
        print("====================")
        print("After:")
        dump_hires_state(p)
|
|
|
|
|
|
|
|
class RandomizerKeywordConfigOption(extra_networks.ExtraNetwork):
    """Keyword that temporarily overrides one entry of ``shared.opts.data``.

    The option's current value is remembered on the first activation that
    carries parameters and is written back by :meth:`deactivate`.
    """

    def __init__(self, keyword_name, param_type, value_min=0, value_max=None, option_name=None, validate_cb=None, adjust_cb=None):
        """Configure the keyword.

        Args:
            keyword_name: Name handed to the ``ExtraNetwork`` base class.
            param_type: Callable that converts the raw keyword argument.
            value_min: Lower clamp for numeric values; ``None`` disables it.
            value_max: Upper clamp for numeric values; ``None`` disables it.
            option_name: Key in ``shared.opts.data`` to override; defaults
                to ``keyword_name``.
            validate_cb: Optional ``(value, p)`` callable; a truthy return
                value is treated as an error message.
            adjust_cb: Optional ``(value, p)`` callable returning the value
                to use instead.
        """
        super().__init__(keyword_name)
        self.param_type = param_type
        self.value_min = value_min
        self.value_max = value_max
        self.validate_cb = validate_cb
        self.adjust_cb = adjust_cb

        self.option_name = keyword_name if option_name is None else option_name

        # Whether an original value is currently stored for restoring.
        self.has_original = False
        self.original_value = None

    def activate(self, p, params_list):
        """Apply the first argument of the first entry in ``params_list`` to the option.

        Raises:
            RuntimeError: If ``validate_cb`` rejects the converted value.
        """
        if not params_list:
            return

        if not self.has_original:
            self.original_value = shared.opts.data[self.option_name]
            self.has_original = True

        value = self.param_type(params_list[0].items[0])

        if self.adjust_cb:
            value = self.adjust_cb(value, p)

        if isinstance(value, (int, float)):
            # Compare against None rather than truthiness: a bound of 0 is a
            # real bound and must not be skipped.
            if self.value_min is not None:
                value = max(value, self.value_min)
            if self.value_max is not None:
                value = min(value, self.value_max)

        if self.validate_cb:
            error = self.validate_cb(value, p)
            if error:
                # self.name is presumably set by the ExtraNetwork base class.
                raise RuntimeError(f"Validation for '{self.name}' keyword failed: {error}")

        if is_debug():
            print(f"[RandomizerKeywords] Set CONFIG option: {self.option_name} -> {value}")

        shared.opts.data[self.option_name] = value

    def deactivate(self, p):
        """Write the remembered value back to the option, if one was stored."""
        if self.has_original:
            if is_debug():
                print(f"[RandomizerKeywords] Reset CONFIG option: {self.option_name} -> {self.original_value}")

            shared.opts.data[self.option_name] = self.original_value
            self.has_original = False
            self.original_value = None
|
|
|
|
|
|
|
|
class RandomizerKeywordSamplerParam(extra_networks.ExtraNetwork):
    """Keyword that sets an attribute of the same name on the processing object."""

    def __init__(self, param_name, param_type, value_min=0, value_max=None, op_type=None, validate_cb=None, adjust_cb=None):
        """Configure the keyword.

        Args:
            param_name: Keyword name; also the attribute set on ``p``.
            param_type: Callable that converts the raw keyword argument.
            value_min: Lower clamp for numeric values; ``None`` disables it.
            value_max: Upper clamp for numeric values; ``None`` disables it.
            op_type: Optional key of ``operations``; when given, the keyword
                only applies to processing objects of that class.
            validate_cb: Optional ``(value, p)`` callable; a truthy return
                value is treated as an error message.
            adjust_cb: Optional ``(value, p)`` callable returning the value
                to use instead.
        """
        super().__init__(param_name)
        self.param_type = param_type
        self.value_min = value_min
        self.value_max = value_max
        self.op_type = op_type
        self.validate_cb = validate_cb
        self.adjust_cb = adjust_cb

    def activate(self, p, params_list):
        """Convert, adjust, clamp and validate the argument, then set it on ``p``.

        Raises:
            RuntimeError: If ``validate_cb`` rejects the converted value.
        """
        if not params_list:
            return

        # Keywords bound to one operation are silently ignored for the other.
        if self.op_type and not isinstance(p, operations[self.op_type]):
            return

        value = self.param_type(params_list[0].items[0])

        if self.adjust_cb:
            value = self.adjust_cb(value, p)

        if isinstance(value, (int, float)):
            # Compare against None rather than truthiness: a bound of 0 is a
            # real bound (e.g. a minimum of 0) and must not be skipped.
            if self.value_min is not None:
                value = max(value, self.value_min)
            if self.value_max is not None:
                value = min(value, self.value_max)

        if self.validate_cb:
            error = self.validate_cb(value, p)
            if error:
                raise RuntimeError(f"Validation for '{self.name}' keyword failed: {error}")

        if is_debug():
            print(f"[RandomizerKeywords] Set SAMPLER option: {self.name} -> {value}")

        setattr(p, self.name, value)

        # Size and hr_* changes require the Hires. fix values to be rebuilt.
        global needs_hr_recalc
        if self.name in ("width", "height") or self.name.startswith("hr_"):
            needs_hr_recalc = True

    def deactivate(self, p):
        """Nothing to undo; the attribute is left as set on ``p``."""
        pass
|
|
|
|
|
|
|
|
def validate_sampler_name(x, p):
    """Validation callback for sampler names.

    Returns an error message when ``x`` is not the name of a sampler available
    for the kind of job ``p`` represents, otherwise ``None``.
    """
    sampler_pool = (
        sd_samplers.samplers_for_img2img
        if isinstance(p, processing.StableDiffusionProcessingImg2Img)
        else sd_samplers.samplers
    )
    known_names = {sampler.name for sampler in sampler_pool}

    return None if x in known_names else f"Invalid sampler '{x}'"
|
|
|
|
|
|
|
|
class RandomizerKeywordCheckpoint(extra_networks.ExtraNetwork):
    """Keyword (``checkpoint``) that swaps the loaded model and restores it afterwards."""

    def __init__(self):
        super().__init__("checkpoint")
        # Checkpoint info captured before the first swap; None when nothing is pending.
        self.original_checkpoint_info = None

    def activate(self, p, params_list):
        """Load the checkpoint that best matches the first keyword argument.

        Raises:
            RuntimeError: If no checkpoint matches the requested name.
        """
        if not params_list:
            return

        if self.original_checkpoint_info is None:
            self.original_checkpoint_info = shared.sd_model.sd_checkpoint_info

        first_entry = params_list[0]
        assert len(first_entry.items) > 0, "Must provide checkpoint name"

        requested = first_entry.items[0]
        match = sd_models.get_closet_checkpoint_match(requested)
        if match is None:
            raise RuntimeError(f"Unknown checkpoint: {requested}")

        if is_debug():
            print(f"[RandomizerKeywords] Set CHECKPOINT: {match.name}")

        sd_models.reload_model_weights(shared.sd_model, match)

    def deactivate(self, p):
        """Reload the checkpoint captured in ``activate``, if any."""
        previous = self.original_checkpoint_info
        if previous is None:
            return

        if is_debug():
            print(f"[RandomizerKeywords] Reset CHECKPOINT: {previous.name}")

        sd_models.reload_model_weights(shared.sd_model, previous)
        self.original_checkpoint_info = None
|
|
|
|
|
|
|
|
class RandomizerKeywordVAE(extra_networks.ExtraNetwork):
    """Keyword (``vae``) that swaps the loaded VAE and restores the configured one afterwards."""

    def __init__(self):
        super().__init__("vae")
        # Whether the configured VAE option has been captured for restoring.
        self.has_original = False
        self.original_vae_info = None

    def find_vae(self, name: str):
        """Resolve a user-supplied VAE name.

        Returns ``sd_vae.unspecified`` for "auto"/"automatic", ``None`` for
        "none" or when nothing matches, and otherwise the ``sd_vae.vae_dict``
        value whose key is the shortest one containing ``name``
        (case-insensitive).
        """
        lowered = name.lower()
        if lowered in ['auto', 'automatic']:
            return sd_vae.unspecified
        if lowered == 'none':
            return None

        needle = lowered.strip()
        # Shortest key first so the closest-looking name wins.
        choices = [key for key in sorted(sd_vae.vae_dict, key=len) if needle in key.lower()]
        if not choices:
            return None
        return sd_vae.vae_dict[choices[0]]

    def activate(self, p, params_list):
        """Load the VAE named by the first keyword argument.

        Raises:
            RuntimeError: If ``find_vae`` returns ``None`` for the name.
        """
        if not params_list:
            return

        if not self.has_original:
            self.original_vae_info = shared.opts.sd_vae
            self.has_original = True

        params = params_list[0]
        assert len(params.items) > 0, "Must provide VAE name or 'auto' for automatic"

        name = params.items[0]
        info = self.find_vae(name)
        # NOTE(review): find_vae also returns None for the literal "none", so
        # that spelling is rejected here as unknown - confirm whether intended.
        if info is None:
            raise RuntimeError(f"Unknown VAE: {name}")

        if is_debug():
            # The resolved value may not expose `.name` (it can be a plain
            # dict value or the `unspecified` marker); fall back to the value.
            print(f"[RandomizerKeywords] Set VAE: {getattr(info, 'name', info)}")

        sd_vae.reload_vae_weights(shared.sd_model, vae_file=info)

    def deactivate(self, p):
        """Restore the ``sd_vae`` option captured in ``activate`` and reload the VAE."""
        if self.has_original:
            if is_debug():
                original = self.original_vae_info
                print(f"[RandomizerKeywords] Reset VAE: {getattr(original, 'name', original)}")

            shared.opts.data["sd_vae"] = self.original_vae_info
            sd_vae.reload_vae_weights()

            # Clear the attribute this class actually uses.
            self.original_vae_info = None
            self.has_original = False
|
|
|
|
|
|
|
|
def update_extension_args(ext_name, p, value, arg_idx):
    """Overwrite one argument of an always-on extension script in ``p.script_args``.

    Args:
        ext_name: Key into ``extension_classes`` identifying the script class.
        p: Processing object; its type selects the img2img or txt2img script list.
        value: New argument value.
        arg_idx: Offset of the argument relative to the script's ``args_from``.
    """
    runner = (
        scripts.scripts_img2img
        if isinstance(p, processing.StableDiffusionProcessingImg2Img)
        else scripts.scripts_txt2img
    )

    script_class = extension_classes[ext_name]
    script = next(
        (candidate for candidate in runner.alwayson_scripts if isinstance(candidate, script_class)),
        None,
    )
    assert script, f"Could not find script for {script_class}!"

    # Work on a mutable copy; it is stored back as a tuple at the end.
    args = list(p.script_args)

    if is_debug():
        print(f"[RandomizerKeywords] Args in {ext_name}: {args[script.args_from:script.args_to]}")
        print(f"[RandomizerKeywords] For {ext_name}: Changed arg {arg_idx} from {args[script.args_from + arg_idx]} to {value}")

    slot = script.args_from + arg_idx
    args[slot] = value
    p.script_args = tuple(args)
|
|
|
|
|
|
|
|
class RandomizerKeywordExtAddNetModel(extra_networks.ExtraNetwork):
    """Keyword ``addnet_model_<n>`` selecting the LoRA model for one additional_networks slot."""

    def __init__(self, index):
        """Create the keyword for zero-based slot ``index`` (keyword names are one-based)."""
        super().__init__(f"addnet_model_{index+1}")
        # Use the constructor argument, not a module-level loop variable.
        self.index = index

    def activate(self, p, params_list):
        """Enable the extension and set the slot's model to the closest name match.

        Raises:
            RuntimeError: If ``scripts.model_util`` is not loaded or no LoRA
                matches the requested name.
        """
        if not params_list:
            return

        model_util = sys.modules.get("scripts.model_util")
        if not model_util:
            raise RuntimeError("Could not load additional_networks model_util")

        value = params_list[0].items[0]
        name = model_util.find_closest_lora_model_name(value)
        if not name:
            raise RuntimeError(f"Could not find LoRA with name {value}")

        # Argument 0 is set to True alongside every change; the model name of
        # slot i is written at offset 3 + 4*i (presumably the extension's
        # per-slot argument layout - verify against additional_networks).
        update_extension_args("additional_networks", p, True, 0)
        update_extension_args("additional_networks", p, name, 3 + 4 * self.index)

    def deactivate(self, p):
        """Nothing to undo."""
        pass
|
|
|
|
|
|
|
|
class RandomizerKeywordExtAddNetWeight(extra_networks.ExtraNetwork):
    """Keyword setting the weight(s) of one additional_networks slot.

    With ``kind=None`` the keyword is ``addnet_weight_<n>`` and sets both
    weights; with ``kind`` "unet" or "tenc" it is ``addnet_<kind>_weight_<n>``
    and sets only that one.
    """

    def __init__(self, index, kind=None):
        """Create the keyword for zero-based slot ``index`` (keyword names are one-based)."""
        if kind is None:
            name = f"addnet_weight_{index+1}"
        else:
            name = f"addnet_{kind}_weight_{index+1}"

        super().__init__(name)
        # Use the constructor argument, not a module-level loop variable.
        self.index = index
        self.kind = kind

    def activate(self, p, params_list):
        """Enable the extension and write the float weight into the slot's argument(s)."""
        if not params_list:
            return

        value = float(params_list[0].items[0])

        # Argument 0 is set to True alongside every change; offsets 4 + 4*i
        # and 5 + 4*i receive the "unet" and "tenc" weights of slot i
        # (presumably the extension's argument layout - verify against
        # additional_networks).
        update_extension_args("additional_networks", p, True, 0)
        if self.kind is None or self.kind == "unet":
            update_extension_args("additional_networks", p, value, 4 + 4 * self.index)
        if self.kind is None or self.kind == "tenc":
            update_extension_args("additional_networks", p, value, 5 + 4 * self.index)

    def deactivate(self, p):
        """Nothing to undo."""
        pass
|
|
|
|
|
|
|
|
class Script(scripts.Script):
    """Always-visible script that rebuilds Hires. fix values when keywords request it."""

    def title(self):
        """Return the script's display name."""
        return "Randomizer Keywords"

    def show(self, is_img2img):
        """Make the script always visible, for txt2img and img2img alike."""
        return scripts.AlwaysVisible

    def process_batch(self, p, *args, **kwargs):
        """Run the pending Hires. fix recalculation, if any, and clear the flag."""
        global needs_hr_recalc

        if not needs_hr_recalc:
            return

        recalc_hires_fix(p)
        needs_hr_recalc = False
|
|
|
|
|
|
|
|
# Keywords that override web UI options for the duration of a job.
config_params = [
    RandomizerKeywordConfigOption(
        "clip_skip",
        int,
        value_min=1,
        value_max=12,
        option_name="CLIP_stop_at_last_layers",
    ),
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_bool(text):
    """Convert a keyword argument to bool.

    Plain ``bool("false")`` is True because every non-empty string is truthy,
    so the common "off" spellings are recognised explicitly. Every other
    non-blank string is treated as True.
    """
    if isinstance(text, str):
        return text.strip().lower() not in ("", "0", "false", "no", "off")
    return bool(text)


def _round_down_to_multiple_of_8(x, p):
    """Adjust callback: lower ``x`` to the nearest multiple of 8."""
    return x - (x % 8)


# Keywords that set attributes directly on the processing object.
sampler_params = [
    RandomizerKeywordSamplerParam("cfg_scale", float, 1),
    RandomizerKeywordSamplerParam("seed", int, -1),
    RandomizerKeywordSamplerParam("subseed", int, -1),
    RandomizerKeywordSamplerParam("subseed_strength", float, 0),
    RandomizerKeywordSamplerParam("sampler_name", str, validate_cb=validate_sampler_name),
    RandomizerKeywordSamplerParam("steps", int, 1),
    RandomizerKeywordSamplerParam("width", int, 64, adjust_cb=_round_down_to_multiple_of_8),
    RandomizerKeywordSamplerParam("height", int, 64, adjust_cb=_round_down_to_multiple_of_8),
    RandomizerKeywordSamplerParam("tiling", _parse_bool),
    RandomizerKeywordSamplerParam("restore_faces", _parse_bool),
    RandomizerKeywordSamplerParam("s_churn", float),
    RandomizerKeywordSamplerParam("s_tmin", float),
    RandomizerKeywordSamplerParam("s_tmax", float),
    RandomizerKeywordSamplerParam("s_noise", float),
    RandomizerKeywordSamplerParam("eta", float, 0),
    RandomizerKeywordSamplerParam("ddim_discretize", str),
    RandomizerKeywordSamplerParam("denoising_strength", float),

    # txt2img only (Hires. fix).
    RandomizerKeywordSamplerParam("hr_scale", float, 1, op_type="txt2img"),
    RandomizerKeywordSamplerParam("hr_upscaler", str, op_type="txt2img"),
    RandomizerKeywordSamplerParam("hr_second_pass_steps", int, 1, op_type="txt2img"),
    RandomizerKeywordSamplerParam("hr_resize_x", int, 64, adjust_cb=_round_down_to_multiple_of_8, op_type="txt2img"),
    RandomizerKeywordSamplerParam("hr_resize_y", int, 64, adjust_cb=_round_down_to_multiple_of_8, op_type="txt2img"),

    # img2img only.
    RandomizerKeywordSamplerParam("mask_blur", float, op_type="img2img"),
    RandomizerKeywordSamplerParam("inpainting_mask_weight", float, op_type="img2img"),
]
|
|
|
|
|
|
|
|
# Keywords that swap the loaded checkpoint / VAE.
other_params = [RandomizerKeywordCheckpoint(), RandomizerKeywordVAE()]
|
|
|
|
|
|
|
|
# Keywords enabled because their extension was detected (filled in on_app_started).
extension_params = []
# Detected extension name -> its module / its script class (filled in on_app_started).
extension_modules = {}
extension_classes = {}
# Extension name -> keywords to enable when that extension is present.
supported_modules = {"additional_networks": []}

# One model keyword and three weight keywords for each of five slots.
for i in range(5):
    supported_modules["additional_networks"] += [
        RandomizerKeywordExtAddNetModel(i),
        RandomizerKeywordExtAddNetWeight(i),
        RandomizerKeywordExtAddNetWeight(i, "unet"),
        RandomizerKeywordExtAddNetWeight(i, "tenc"),
    ]

# Every keyword that gets registered; rebuilt in on_app_started.
all_params = []
|
|
|
|
|
|
|
|
def on_app_started(demo, app):
    """Detect supported extensions, then register every keyword as an extra network.

    Fills ``extension_modules``, ``extension_classes`` and ``extension_params``
    for each loaded script whose module name is ``<extension>.py``, and
    rebinds the module-level ``all_params`` to the full keyword list.
    """
    global all_params

    for script_data in scripts.scripts_data:
        for ext_name, ext_keywords in supported_modules.items():
            if script_data.module.__name__ != ext_name + ".py":
                continue

            assert ext_name not in extension_modules
            print(f"[RandomizerKeywords] Adding support for extension: {ext_name}")
            extension_modules[ext_name] = script_data.module
            extension_classes[ext_name] = script_data.script_class
            extension_params.extend(ext_keywords)

    all_params = config_params + sampler_params + other_params + extension_params
    keyword_names = [keyword.name for keyword in all_params]
    print(f"[RandomizerKeywords] Supported keywords: {', '.join(keyword_names)}")

    for keyword in all_params:
        extra_networks.register_extra_network(keyword)
|
|
|
|
|
|
|
|
def on_ui_settings():
    """Add the debug-output option under a "Randomizer Keywords" settings section."""
    debug_option = shared.OptionInfo(
        False,
        "Print debug messages",
        section=('randomizer_keywords', "Randomizer Keywords"),
    )
    shared.opts.add_option("randomizer_keywords_debug", debug_option)
|
|
|
|
|
|
|
|
# Hook into the web UI: keyword registration on app start, options on settings build.
script_callbacks.on_app_started(on_app_started)
script_callbacks.on_ui_settings(on_ui_settings)
|
|
|