import sys import time import modules.scripts as scripts import modules import modules.shared as shared import gradio as gr from modules.ui_components import FormColumn, FormRow from modules import script_callbacks def on_ui_settings(): section = ("sdxlHiresHack ", "SDXL Refinder Hack") shared.opts.add_option( key = "sdxl_base_model", info = shared.OptionInfo( "sd_xl_base_1.0.safetensors", "SDXL Base model", section=section) ) shared.opts.add_option( key = "sdxl_refiner_model", info = shared.OptionInfo( "sd_xl_refiner_1.0.safetensors", "SDXL refiner model", section=section) ) script_callbacks.on_ui_settings(on_ui_settings) class sdxlRefinderHack(scripts.Script): def __init__(self): self.info_base = None self.info_hr = None self.first_pass = True def title(self): return "SDXL Refinder Hack" def show(self, is_img2img): return scripts.AlwaysVisible def ui(self, is_img2img): with gr.Accordion(self.title(), open=False): gr.Markdown("will become unnecessary in the 1.6 release of A1111") if is_img2img: gr.Markdown("will not do anything in img2img") else: with FormRow(): is_enabled = gr.Checkbox(value=False, label="Enable") with FormColumn(): base_model = gr.inputs.Textbox(lines=1, label="SDXL base model name", default=getattr(shared.opts, "sdxl_base_model", "")) refinder_model = gr.inputs.Textbox(lines=1, label="SDXL refinder model name", default=getattr(shared.opts, "sdxl_refiner_model", "")) return [base_model, refinder_model, is_enabled] def before_process_batch(self, p,*args, **kwargs): print(f"\nEnabled: {args[2]}\n\ncheckpoint: {args[0]}") if args[2]: if self.first_pass: self.first_pass = False else: modules.sd_models.unload_model_weights(shared.sd_model, self.info_hr) self.info_base = modules.sd_models.get_closet_checkpoint_match(args[0]) modules.sd_models.reload_model_weights(shared.sd_model, self.info_base) p.override_settings['sd_model_checkpoint'] = self.info_base.name def before_hr(self, p, *args, **kwargs): if args[2]: modules.sd_models.unload_model_weights(shared.sd_model, self.info_base) self.info_hr = modules.sd_models.get_closet_checkpoint_match(args[1]) modules.sd_models.reload_model_weights(shared.sd_model, self.info_hr) p.override_settings['sd_model_checkpoint'] = self.info_hr.name p.extra_generation_params['base model'] = self.info_base.name