| import re |
| import sys |
| from modules import scripts, script_callbacks, ui_extra_networks, extra_networks, shared, sd_models, sd_vae, sd_samplers, processing |
|
|
|
|
# Maps the ``op_type`` accepted by RandomizerKeywordSamplerParam to the processing
# class it is restricted to; keywords with an op_type are skipped for other job types.
operations = dict(
    txt2img=processing.StableDiffusionProcessingTxt2Img,
    img2img=processing.StableDiffusionProcessingImg2Img,
)

# Raised by RandomizerKeywordSamplerParam when width/height or an ``hr_*`` attribute
# changes; consumed and cleared by Script.process_batch.
needs_hr_recalc = False
|
|
|
|
def is_debug():
    """Return the ``randomizer_keywords_debug`` setting, or ``False`` when it is absent."""
    settings = shared.opts.data
    return settings.get("randomizer_keywords_debug", False)
|
|
|
|
def recalc_hires_fix(p):
    """Re-run ``p.init`` on a txt2img job so its Hires. fix values match changed attributes.

    Removes the "Hires upscale", "Hires resize", "Hires steps" and "Hires upscaler"
    entries from ``p.extra_generation_params`` and calls
    ``p.init(p.all_prompts, p.all_seeds, p.all_subseeds)`` again. This presumably lets
    init re-derive them from the current width/height/hr_* values; confirm against the
    webui version in use. Jobs that are not StableDiffusionProcessingTxt2Img are left
    untouched. With debug enabled, the relevant attributes are printed before and after.
    """
    watched_attrs = (
        "width",
        "height",
        "hr_upscaler",
        "hr_second_pass_steps",
        "hr_scale",
        "hr_resize_x",
        "hr_resize_y",
        "hr_upscale_to_x",
        "hr_upscale_to_y",
    )

    def print_params(p):
        for attr in watched_attrs:
            print(f"- {attr}: {getattr(p, attr)}")

    if not isinstance(p, processing.StableDiffusionProcessingTxt2Img):
        return

    if is_debug():
        print("[RandomizerKeywords] Recalculating Hires. fix")
        print("Before:")
        print_params(p)

    # Drop the previously recorded Hires. entries before re-initialising.
    for stale_key in ("Hires upscale", "Hires resize", "Hires steps", "Hires upscaler"):
        p.extra_generation_params.pop(stale_key, None)

    p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

    if is_debug():
        print("====================")
        print("After:")
        print_params(p)
|
|
|
|
class RandomizerKeywordConfigOption(extra_networks.ExtraNetwork):
    """Extra-network keyword that overrides a ``shared.opts`` setting.

    ``<keyword_name:value>`` converts ``value`` with ``param_type``, applies
    ``adjust_cb(value, p)``, and clamps numbers to ``[value_min, value_max]`` (either
    bound may be ``None``). It then runs ``validate_cb(value, p)``, which raises
    RuntimeError if it returns a message, and finally writes the value to
    ``shared.opts.data[option_name]``. The setting's value from before the first
    activation is restored by :meth:`deactivate`.
    """

    def __init__(self, keyword_name, param_type, value_min=0, value_max=None, option_name=None, validate_cb=None, adjust_cb=None):
        super().__init__(keyword_name)
        self.param_type = param_type
        self.value_min = value_min
        self.value_max = value_max
        self.validate_cb = validate_cb
        self.adjust_cb = adjust_cb

        # The settings key defaults to the keyword name itself.
        self.option_name = keyword_name if option_name is None else option_name

        # Captured once, so several activations before a deactivate still restore
        # the genuine original value.
        self.has_original = False
        self.original_value = None

    def _convert(self, raw):
        """Convert the raw keyword argument with ``param_type``.

        ``bool("false")`` is ``True``, so booleans parse the usual spellings explicitly.

        Raises:
            RuntimeError: for an unrecognised boolean spelling.
        """
        if self.param_type is bool:
            text = str(raw).strip().lower()
            if text in ("1", "true", "yes", "on"):
                return True
            if text in ("0", "false", "no", "off", ""):
                return False
            raise RuntimeError(f"Invalid boolean value for '{self.name}' keyword: {raw}")
        return self.param_type(raw)

    def _clamp(self, value):
        """Clamp int/float values to the configured bounds; other values pass through."""
        # bool is a subclass of int and must not be treated as a number here.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return value
        # Compare against None so that an explicit bound of 0 is honoured.
        if self.value_min is not None:
            value = max(value, self.value_min)
        if self.value_max is not None:
            value = min(value, self.value_max)
        return value

    def activate(self, p, params_list):
        """Apply the first keyword occurrence in ``params_list`` to ``shared.opts``.

        Raises:
            RuntimeError: if no value was given, the value cannot be parsed as a
                boolean, or ``validate_cb`` reports an error.
        """
        if not params_list:
            return

        items = params_list[0].items
        if not items:
            raise RuntimeError(f"No value given for '{self.name}' keyword")

        if not self.has_original:
            self.original_value = shared.opts.data[self.option_name]
            self.has_original = True

        value = self._convert(items[0])

        if self.adjust_cb:
            value = self.adjust_cb(value, p)

        value = self._clamp(value)

        if self.validate_cb:
            error = self.validate_cb(value, p)
            if error:
                raise RuntimeError(f"Validation for '{self.name}' keyword failed: {error}")

        if is_debug():
            print(f"[RandomizerKeywords] Set CONFIG option: {self.option_name} -> {value}")

        shared.opts.data[self.option_name] = value

    def deactivate(self, p):
        """Restore the setting captured by the first :meth:`activate`, if any."""
        if self.has_original:
            if is_debug():
                print(f"[RandomizerKeywords] Reset CONFIG option: {self.option_name} -> {self.original_value}")

            shared.opts.data[self.option_name] = self.original_value
            self.has_original = False
            self.original_value = None
|
|
|
|
class RandomizerKeywordSamplerParam(extra_networks.ExtraNetwork):
    """Extra-network keyword that sets an attribute on the processing object ``p``.

    ``<param_name:value>`` converts ``value`` with ``param_type``, applies
    ``adjust_cb(value, p)``, and clamps numbers to ``[value_min, value_max]`` (either
    bound may be ``None``). The default ``value_min=0`` therefore clamps negative
    numbers to 0. It then runs ``validate_cb(value, p)`` and finally does
    ``setattr(p, param_name, value)``.

    When ``op_type`` is given (a key of ``operations``), the keyword is ignored for
    jobs of another type. Setting ``width``, ``height`` or any ``hr_*`` attribute
    sets the module-level ``needs_hr_recalc`` flag. :meth:`deactivate` does nothing.
    """

    def __init__(self, param_name, param_type, value_min=0, value_max=None, op_type=None, validate_cb=None, adjust_cb=None):
        super().__init__(param_name)
        self.param_type = param_type
        self.value_min = value_min
        self.value_max = value_max
        self.op_type = op_type
        self.validate_cb = validate_cb
        self.adjust_cb = adjust_cb

    def _convert(self, raw):
        """Convert the raw keyword argument with ``param_type``.

        ``bool("false")`` is ``True``, so booleans (e.g. ``tiling``, ``restore_faces``)
        parse the usual spellings explicitly.

        Raises:
            RuntimeError: for an unrecognised boolean spelling.
        """
        if self.param_type is bool:
            text = str(raw).strip().lower()
            if text in ("1", "true", "yes", "on"):
                return True
            if text in ("0", "false", "no", "off", ""):
                return False
            raise RuntimeError(f"Invalid boolean value for '{self.name}' keyword: {raw}")
        return self.param_type(raw)

    def _clamp(self, value):
        """Clamp int/float values to the configured bounds; other values pass through."""
        # bool is a subclass of int and must not be treated as a number here.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return value
        # Compare against None so that an explicit bound of 0 is honoured.
        if self.value_min is not None:
            value = max(value, self.value_min)
        if self.value_max is not None:
            value = min(value, self.value_max)
        return value

    def activate(self, p, params_list):
        """Apply the first keyword occurrence in ``params_list`` to ``p``.

        Raises:
            RuntimeError: if no value was given, the value cannot be parsed as a
                boolean, or ``validate_cb`` reports an error.
        """
        global needs_hr_recalc

        if not params_list:
            return

        if self.op_type and not isinstance(p, operations[self.op_type]):
            return

        items = params_list[0].items
        if not items:
            raise RuntimeError(f"No value given for '{self.name}' keyword")

        value = self._convert(items[0])

        if self.adjust_cb:
            value = self.adjust_cb(value, p)

        value = self._clamp(value)

        if self.validate_cb:
            error = self.validate_cb(value, p)
            if error:
                raise RuntimeError(f"Validation for '{self.name}' keyword failed: {error}")

        if is_debug():
            print(f"[RandomizerKeywords] Set SAMPLER option: {self.name} -> {value}")

        setattr(p, self.name, value)

        # These attributes feed the Hires. fix computation; Script.process_batch
        # re-runs it when this flag is set.
        if self.name in ("width", "height") or self.name.startswith("hr_"):
            needs_hr_recalc = True

    def deactivate(self, p):
        pass
|
|
|
|
def validate_sampler_name(x, p):
    """Return an error message when *x* is not a sampler available for *p*, else ``None``.

    img2img jobs are checked against ``sd_samplers.samplers_for_img2img``; all other
    jobs against ``sd_samplers.samplers``.
    """
    is_img2img = isinstance(p, processing.StableDiffusionProcessingImg2Img)
    available = sd_samplers.samplers_for_img2img if is_img2img else sd_samplers.samplers
    valid_names = {sampler.name for sampler in available}
    return None if x in valid_names else f"Invalid sampler '{x}'"
|
|
|
|
class RandomizerKeywordCheckpoint(extra_networks.ExtraNetwork):
    """``<checkpoint:name>`` keyword: load other model weights until :meth:`deactivate`.

    The checkpoint is resolved with ``sd_models.get_closet_checkpoint_match`` (the
    function's actual name in sd_models). The checkpoint loaded before the first
    activation is reloaded on deactivation.
    """

    def __init__(self):
        super().__init__("checkpoint")
        # Checkpoint info captured before the first activation.
        self.original_checkpoint_info = None

    def activate(self, p, params_list):
        """Reload ``shared.sd_model`` with the checkpoint named by the first occurrence.

        Raises:
            RuntimeError: if no name was given or no checkpoint matches it.
        """
        if not params_list:
            return

        items = params_list[0].items
        # Explicit raise rather than assert, so the check survives ``python -O``.
        if not items:
            raise RuntimeError("Must provide checkpoint name")

        if self.original_checkpoint_info is None:
            self.original_checkpoint_info = shared.sd_model.sd_checkpoint_info

        name = items[0]
        info = sd_models.get_closet_checkpoint_match(name)
        if info is None:
            raise RuntimeError(f"Unknown checkpoint: {name}")

        if is_debug():
            print(f"[RandomizerKeywords] Set CHECKPOINT: {info.name}")

        sd_models.reload_model_weights(shared.sd_model, info)

    def deactivate(self, p):
        """Reload the checkpoint captured by the first :meth:`activate`, if any."""
        if self.original_checkpoint_info is not None:
            if is_debug():
                print(f"[RandomizerKeywords] Reset CHECKPOINT: {self.original_checkpoint_info.name}")

            sd_models.reload_model_weights(shared.sd_model, self.original_checkpoint_info)
            self.original_checkpoint_info = None
|
|
|
|
class RandomizerKeywordVAE(extra_networks.ExtraNetwork):
    """``<vae:name>`` keyword: reload VAE weights until :meth:`deactivate`.

    Accepted names:
        - ``auto``/``automatic`` selects ``sd_vae.unspecified``.
        - ``none`` passes ``vae_file=None``.
        - Any other text picks the shortest ``sd_vae.vae_dict`` key containing it
          (case-insensitive).

    On deactivation, ``shared.opts.data["sd_vae"]`` is set back to the value seen
    before the first activation and ``sd_vae.reload_vae_weights()`` is called.
    """

    def __init__(self):
        super().__init__("vae")
        # ``shared.opts.sd_vae`` as it was before the first activation.
        self.has_original = False
        self.original_vae_info = None

    def find_vae(self, name: str):
        """Resolve *name* to the ``vae_file`` argument for ``sd_vae.reload_vae_weights``.

        Returns ``sd_vae.unspecified`` for ``auto``/``automatic``, the matching
        ``sd_vae.vae_dict`` value for a substring match, and ``None`` both for
        ``none`` and when nothing matches. The caller must tell the two ``None``
        cases apart.
        """
        query = name.strip().lower()
        if query in ("auto", "automatic"):
            return sd_vae.unspecified
        if query == "none":
            return None
        matches = [key for key in sorted(sd_vae.vae_dict, key=len) if query in key.lower()]
        return sd_vae.vae_dict[matches[0]] if matches else None

    def activate(self, p, params_list):
        """Load the VAE named by the first keyword occurrence.

        Raises:
            RuntimeError: if no name was given or it matches no VAE.
        """
        if not params_list:
            return

        items = params_list[0].items
        # Explicit raise rather than assert, so the check survives ``python -O``.
        if not items:
            raise RuntimeError("Must provide VAE name or 'auto' for automatic")

        if not self.has_original:
            self.original_vae_info = shared.opts.sd_vae
            self.has_original = True

        name = items[0]
        info = self.find_vae(name)
        # find_vae returns None for an explicit "none" as well as for "not found";
        # only the latter is an error. None is forwarded as vae_file, presumably
        # meaning "no external VAE" - confirm against sd_vae.reload_vae_weights.
        if info is None and name.strip().lower() != "none":
            raise RuntimeError(f"Unknown VAE: {name}")

        if is_debug():
            # Log the requested name: ``info`` may be sd_vae.unspecified, None or a
            # vae_dict value, none of which is guaranteed to have a ``.name``.
            print(f"[RandomizerKeywords] Set VAE: {name}")

        sd_vae.reload_vae_weights(shared.sd_model, vae_file=info)

    def deactivate(self, p):
        """Restore the ``sd_vae`` setting captured on first activation and reload."""
        if not self.has_original:
            return

        if is_debug():
            print(f"[RandomizerKeywords] Reset VAE: {self.original_vae_info}")

        shared.opts.data["sd_vae"] = self.original_vae_info
        sd_vae.reload_vae_weights()

        self.original_vae_info = None
        self.has_original = False
|
|
|
|
def update_extension_args(ext_name, p, value, arg_idx):
    """Overwrite one argument of an always-on extension script in ``p.script_args``.

    Args:
        ext_name: key into ``extension_classes`` (e.g. ``"additional_networks"``).
        p: processing job. img2img jobs use ``scripts.scripts_img2img``; all other
            jobs use ``scripts.scripts_txt2img``.
        value: new value to store.
        arg_idx: index relative to the script's ``args_from`` offset.

    Raises:
        RuntimeError: if no always-on script is an instance of the extension's class.
    """
    if isinstance(p, processing.StableDiffusionProcessingImg2Img):
        all_scripts = scripts.scripts_img2img.alwayson_scripts
    else:
        all_scripts = scripts.scripts_txt2img.alwayson_scripts

    script_class = extension_classes[ext_name]
    script = next((s for s in all_scripts if isinstance(s, script_class)), None)
    # Explicit raise rather than assert, so the check survives ``python -O``.
    if script is None:
        raise RuntimeError(f"Could not find script for {script_class}!")

    # script_args is stored as a tuple; copy it to a list to edit one slot.
    args = list(p.script_args)
    absolute_idx = script.args_from + arg_idx

    if is_debug():
        print(f"[RandomizerKeywords] Args in {ext_name}: {args[script.args_from:script.args_to]}")
        print(f"[RandomizerKeywords] For {ext_name}: Changed arg {arg_idx} from {args[absolute_idx]} to {value}")

    args[absolute_idx] = value
    p.script_args = tuple(args)
|
|
|
|
class RandomizerKeywordExtAddNetModel(extra_networks.ExtraNetwork):
    """``<addnet_model_N:name>`` keyword for the additional_networks extension.

    ``N`` is ``index + 1``. On activation the closest LoRA name is resolved via the
    extension's ``scripts.model_util`` module. Argument 0 is set to ``True``
    (presumably the extension's enable flag - confirm) and argument
    ``3 + 4 * index`` is set to the model name.
    """

    def __init__(self, index):
        super().__init__(f"addnet_model_{index+1}")
        # 0-based model slot in the extension's argument layout.
        self.index = index

    def activate(self, p, params_list):
        """Select the LoRA named by the first occurrence for this slot.

        Raises:
            RuntimeError: if the extension's model_util is not loaded, no name was
                given, or no LoRA matches.
        """
        if not params_list:
            return

        # Only usable if the extension has already put its module in sys.modules.
        model_util = sys.modules.get("scripts.model_util")
        if not model_util:
            raise RuntimeError("Could not load additional_networks model_util")

        items = params_list[0].items
        if not items:
            raise RuntimeError(f"No LoRA name given for '{self.name}' keyword")

        value = items[0]
        name = model_util.find_closest_lora_model_name(value)
        if not name:
            raise RuntimeError(f"Could not find LoRA with name {value}")

        update_extension_args("additional_networks", p, True, 0)
        update_extension_args("additional_networks", p, name, 3 + 4 * self.index)

    def deactivate(self, p):
        pass
|
|
|
|
class RandomizerKeywordExtAddNetWeight(extra_networks.ExtraNetwork):
    """Weight keyword for additional_networks model slot ``index``.

    The keyword names and the arguments they write (``N`` is ``index + 1``):
        - ``addnet_weight_N`` (``kind=None``): arguments ``4 + 4 * index`` and
          ``5 + 4 * index``.
        - ``addnet_unet_weight_N`` (``kind="unet"``): argument ``4 + 4 * index`` only.
        - ``addnet_tenc_weight_N`` (``kind="tenc"``): argument ``5 + 4 * index`` only.

    Argument 0 is always set to ``True``, presumably the extension's enable flag
    (confirm).
    """

    def __init__(self, index, kind=None):
        if kind is None:
            name = f"addnet_weight_{index+1}"
        else:
            name = f"addnet_{kind}_weight_{index+1}"

        super().__init__(name)
        # 0-based model slot in the extension's argument layout.
        self.index = index
        self.kind = kind

    def activate(self, p, params_list):
        """Write the float weight from the first occurrence into the slot's args.

        Raises:
            RuntimeError: if no value was given.
            ValueError: if the value is not a valid float.
        """
        if not params_list:
            return

        items = params_list[0].items
        if not items:
            raise RuntimeError(f"No weight given for '{self.name}' keyword")

        value = float(items[0])

        update_extension_args("additional_networks", p, True, 0)
        if self.kind is None or self.kind == "unet":
            update_extension_args("additional_networks", p, value, 4 + 4 * self.index)
        if self.kind is None or self.kind == "tenc":
            update_extension_args("additional_networks", p, value, 5 + 4 * self.index)

    def deactivate(self, p):
        pass
|
|
|
|
class Script(scripts.Script):
    """Always-visible webui script that performs the deferred Hires. fix recalculation."""

    def title(self):
        return "Randomizer Keywords"

    def show(self, is_img2img):
        # AlwaysVisible regardless of tab; is_img2img is not consulted.
        return scripts.AlwaysVisible

    def process_batch(self, p, *args, **kwargs):
        """Run recalc_hires_fix once if a keyword raised ``needs_hr_recalc``, then clear it."""
        global needs_hr_recalc
        # Nothing pending: the flag is already False, so there is nothing to reset.
        if not needs_hr_recalc:
            return
        recalc_hires_fix(p)
        needs_hr_recalc = False
|
|
|
|
def _snap_to_multiple_of_8(x, p):
    """adjust_cb that rounds a pixel size down to a multiple of 8 (``p`` is unused)."""
    return x - (x % 8)


# Keywords that override a shared.opts setting.
config_params = [
    RandomizerKeywordConfigOption("clip_skip", int, 1, 12, option_name="CLIP_stop_at_last_layers"),
]

# Keywords that set attributes on the processing object.
sampler_params = [
    # Common to txt2img and img2img.
    RandomizerKeywordSamplerParam("cfg_scale", float, 1),
    RandomizerKeywordSamplerParam("seed", int, -1),
    RandomizerKeywordSamplerParam("subseed", int, -1),
    RandomizerKeywordSamplerParam("subseed_strength", float, 0),
    RandomizerKeywordSamplerParam("sampler_name", str, validate_cb=validate_sampler_name),
    RandomizerKeywordSamplerParam("steps", int, 1),
    RandomizerKeywordSamplerParam("width", int, 64, adjust_cb=_snap_to_multiple_of_8),
    RandomizerKeywordSamplerParam("height", int, 64, adjust_cb=_snap_to_multiple_of_8),
    RandomizerKeywordSamplerParam("tiling", bool),
    RandomizerKeywordSamplerParam("restore_faces", bool),
    RandomizerKeywordSamplerParam("s_churn", float),
    RandomizerKeywordSamplerParam("s_tmin", float),
    RandomizerKeywordSamplerParam("s_tmax", float),
    RandomizerKeywordSamplerParam("s_noise", float),
    RandomizerKeywordSamplerParam("eta", float, 0),
    RandomizerKeywordSamplerParam("ddim_discretize", str),
    RandomizerKeywordSamplerParam("denoising_strength", float),

    # Hires. fix (txt2img only).
    RandomizerKeywordSamplerParam("hr_scale", float, 1, op_type="txt2img"),
    RandomizerKeywordSamplerParam("hr_upscaler", str, op_type="txt2img"),
    RandomizerKeywordSamplerParam("hr_second_pass_steps", int, 1, op_type="txt2img"),
    RandomizerKeywordSamplerParam("hr_resize_x", int, 64, adjust_cb=_snap_to_multiple_of_8, op_type="txt2img"),
    RandomizerKeywordSamplerParam("hr_resize_y", int, 64, adjust_cb=_snap_to_multiple_of_8, op_type="txt2img"),

    # img2img only.
    RandomizerKeywordSamplerParam("mask_blur", float, op_type="img2img"),
    RandomizerKeywordSamplerParam("inpainting_mask_weight", float, op_type="img2img"),
]

# Keywords that swap model / VAE weights.
other_params = [
    RandomizerKeywordCheckpoint(),
    RandomizerKeywordVAE(),
]

# Filled in by on_app_started() for the supported extensions that are installed.
extension_params = []
extension_modules = {}
extension_classes = {}

# Keywords offered per supported extension, keyed by module name without ".py".
# The loop variable stays at module level under the name ``i`` on purpose.
_addnet_params = []
for i in range(5):
    _addnet_params += [
        RandomizerKeywordExtAddNetModel(i),
        RandomizerKeywordExtAddNetWeight(i),
        RandomizerKeywordExtAddNetWeight(i, "unet"),
        RandomizerKeywordExtAddNetWeight(i, "tenc"),
    ]
supported_modules = {"additional_networks": _addnet_params}

# Every registered keyword; assigned by on_app_started().
all_params = []
|
|
|
|
def on_app_started(demo, app):
    """Detect supported extensions, then register every keyword with extra_networks.

    An extension counts as present when a loaded script module is named
    ``"<extension>.py"``; its keywords are then appended to ``extension_params``.
    ``demo`` and ``app`` come from the script_callbacks hook and are not used.
    """
    global all_params

    for script_data in scripts.scripts_data:
        module = script_data.module
        for ext_name, ext_params in supported_modules.items():
            if module.__name__ != f"{ext_name}.py":
                continue
            assert ext_name not in extension_modules
            print(f"[RandomizerKeywords] Adding support for extension: {ext_name}")
            extension_modules[ext_name] = module
            extension_classes[ext_name] = script_data.script_class
            extension_params.extend(ext_params)

    all_params = config_params + sampler_params + other_params + extension_params
    keyword_names = ", ".join(param.name for param in all_params)
    print(f"[RandomizerKeywords] Supported keywords: {keyword_names}")

    for keyword in all_params:
        extra_networks.register_extra_network(keyword)
|
|
|
|
def on_ui_settings():
    """Register the ``randomizer_keywords_debug`` option (read by is_debug), default False."""
    section = ("randomizer_keywords", "Randomizer Keywords")
    debug_option = shared.OptionInfo(False, "Print debug messages", section=section)
    shared.opts.add_option("randomizer_keywords_debug", debug_option)
|
|
|
|
# Keyword registration happens once the app has started (scripts are loaded by then).
script_callbacks.on_app_started(on_app_started)
# Adds the debug toggle to the settings page.
script_callbacks.on_ui_settings(on_ui_settings)
|
|