File size: 2,871 Bytes
29a5ed9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 | import sys
import time
import modules.scripts as scripts
import modules
import modules.shared as shared
import gradio as gr
from modules.ui_components import FormColumn, FormRow
from modules import script_callbacks
def on_ui_settings():
    """Register the SDXL base/refiner checkpoint-name options on the settings page.

    Adds two options to ``shared.opts``:

    * ``sdxl_base_model`` (default ``"sd_xl_base_1.0.safetensors"``)
    * ``sdxl_refiner_model`` (default ``"sd_xl_refiner_1.0.safetensors"``)

    Both are placed in the same settings section.
    """
    # (id, label) of the settings section. The original id carried a trailing
    # space; the id is presumably used as an identifier, so it must not contain
    # whitespace. Option values are stored by key, not by section, so this does
    # not affect saved settings.
    section = ("sdxlHiresHack", "SDXL Refinder Hack")
    shared.opts.add_option(
        key="sdxl_base_model",
        info=shared.OptionInfo(
            "sd_xl_base_1.0.safetensors",
            "SDXL Base model",
            section=section,
        ),
    )
    shared.opts.add_option(
        key="sdxl_refiner_model",
        info=shared.OptionInfo(
            "sd_xl_refiner_1.0.safetensors",
            "SDXL refiner model",
            section=section,
        ),
    )


# Hook the options into the webui settings page when it is built.
script_callbacks.on_ui_settings(on_ui_settings)
class sdxlRefinderHack(scripts.Script):
    """Swap between an SDXL base checkpoint and a refiner checkpoint.

    When enabled (txt2img only), ``before_process_batch`` loads the base
    checkpoint and ``before_hr`` loads the refiner checkpoint right before the
    hires-fix pass. Checkpoints are resolved by name through
    ``modules.sd_models.get_closet_checkpoint_match``.
    """

    def __init__(self):
        # Checkpoint info resolved for the base model in the latest batch, or None.
        self.info_base = None
        # Checkpoint info resolved for the refiner in the latest hires pass, or None.
        self.info_hr = None
        # True until the first enabled batch has loaded the base model; after
        # that, the currently loaded weights are unloaded before each base reload.
        self.first_pass = True

    def title(self):
        return "SDXL Refinder Hack"

    def show(self, is_img2img):
        # Shown as an always-visible accordion instead of a dropdown script.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion UI.

        Returns ``[base_model, refinder_model, is_enabled]`` in txt2img. In
        img2img no inputs are created and an empty list is returned, so the
        processing hooks receive no script args there.
        """
        with gr.Accordion(self.title(), open=False):
            gr.Markdown("will become unnecessary in the 1.6 release of A1111")
            if is_img2img:
                gr.Markdown("will not do anything in img2img")
                # The inputs below are never created here; returning them
                # would reference unbound names.
                return []
            with FormRow():
                is_enabled = gr.Checkbox(value=False, label="Enable")
                with FormColumn():
                    # gr.Textbox(value=...) replaces the deprecated
                    # gr.inputs.Textbox(default=...).
                    base_model = gr.Textbox(
                        lines=1,
                        label="SDXL base model name",
                        value=getattr(shared.opts, "sdxl_base_model", ""),
                    )
                    refinder_model = gr.Textbox(
                        lines=1,
                        label="SDXL refinder model name",
                        value=getattr(shared.opts, "sdxl_refiner_model", ""),
                    )
        return [base_model, refinder_model, is_enabled]

    def before_process_batch(self, p, *args, **kwargs):
        """Load the base checkpoint (name in ``args[0]``) when enabled (``args[2]``)."""
        # args is empty when ui() created no inputs (img2img).
        if len(args) < 3 or not args[2]:
            return
        base_name = args[0]
        # Resolve the name before touching any weights, so a bad name
        # leaves the currently loaded model in place.
        info = modules.sd_models.get_closet_checkpoint_match(base_name)
        if info is None:
            print(f"[{self.title()}] no checkpoint matches base model {base_name!r}; skipping swap")
            self.info_base = None
            return
        if self.first_pass:
            self.first_pass = False
        else:
            # Drop the weights loaded for the previous hires pass first.
            modules.sd_models.unload_model_weights(shared.sd_model, self.info_hr)
        self.info_base = info
        modules.sd_models.reload_model_weights(shared.sd_model, self.info_base)
        p.override_settings['sd_model_checkpoint'] = self.info_base.name

    def before_hr(self, p, *args, **kwargs):
        """Load the refiner checkpoint (name in ``args[1]``) before the hires-fix pass."""
        if len(args) < 3 or not args[2]:
            return
        refiner_name = args[1]
        info = modules.sd_models.get_closet_checkpoint_match(refiner_name)
        if info is None:
            print(f"[{self.title()}] no checkpoint matches refiner model {refiner_name!r}; skipping swap")
            self.info_hr = None
            return
        modules.sd_models.unload_model_weights(shared.sd_model, self.info_base)
        self.info_hr = info
        modules.sd_models.reload_model_weights(shared.sd_model, self.info_hr)
        p.override_settings['sd_model_checkpoint'] = self.info_hr.name
        # Record which base model produced the first pass, if one was loaded.
        if self.info_base is not None:
            p.extra_generation_params['base model'] = self.info_base.name
|