File size: 2,871 Bytes
29a5ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import sys
import time
import modules.scripts as scripts
import modules
import modules.shared as shared
import gradio as gr
from modules.ui_components import FormColumn, FormRow
from modules import script_callbacks


def on_ui_settings():
    """Register the SDXL base/refiner checkpoint-name options in the Settings UI.

    Adds two string options, ``sdxl_base_model`` and ``sdxl_refiner_model``,
    defaulting to the stock SDXL 1.0 file names. ``sdxlRefinderHack.ui`` reads
    them (via ``shared.opts``) as the initial values of its textboxes.
    """
    # (section id, display label). The id must not contain whitespace: it is
    # presumably used to build the settings tab's HTML element id, and the
    # original value had a stray trailing space.
    section = ("sdxlHiresHack", "SDXL Refinder Hack")
    shared.opts.add_option(
        key="sdxl_base_model",
        info=shared.OptionInfo(
            "sd_xl_base_1.0.safetensors",
            "SDXL Base model",
            section=section,
        ),
    )

    shared.opts.add_option(
        key="sdxl_refiner_model",
        info=shared.OptionInfo(
            "sd_xl_refiner_1.0.safetensors",
            "SDXL refiner model",
            section=section,
        ),
    )


# Hook the options into the webui's settings-page construction.
script_callbacks.on_ui_settings(on_ui_settings)

class sdxlRefinderHack(scripts.Script):
    """Swap between an SDXL base and refiner checkpoint within one txt2img job.

    When enabled, the base checkpoint is (re)loaded before every batch and the
    refiner checkpoint is loaded right before the hires-fix pass, so that pass
    runs on the refiner. Script arguments, in the order returned by ``ui``,
    are ``(base_model_name, refiner_model_name, is_enabled)``.
    """

    def __init__(self):
        # Checkpoint info (from get_closet_checkpoint_match) of the base model
        # resolved for the most recent enabled batch.
        self.info_base = None
        # Checkpoint info of the refiner while this script has it swapped in;
        # reset to None once the base has been swapped back.
        self.info_hr = None
        # True until the first enabled batch has been processed.
        self.first_pass = True

    def title(self):
        return "SDXL Refinder Hack"

    def show(self, is_img2img):
        # Always-visible script: rendered in its own accordion, not the dropdown.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion UI and return its controls as script args.

        Returns ``[base_model, refinder_model, is_enabled]`` for txt2img and
        None for img2img, where no controls are created.
        """
        with gr.Accordion(self.title(), open=False):
            gr.Markdown("will become unnecessary in the 1.6 release of A1111")
            if is_img2img:
                gr.Markdown("will not do anything in img2img")
                # No controls, so the hooks presumably receive no script args
                # here; _script_args handles that case.
                return None

            with FormRow():
                is_enabled = gr.Checkbox(value=False, label="Enable")
            with FormColumn():
                # gr.Textbox(value=...) replaces the deprecated
                # gr.inputs.Textbox(default=...); labels are kept unchanged.
                base_model = gr.Textbox(
                    lines=1,
                    label="SDXL base model name",
                    value=getattr(shared.opts, "sdxl_base_model", ""),
                )
                refinder_model = gr.Textbox(
                    lines=1,
                    label="SDXL refinder model name",
                    value=getattr(shared.opts, "sdxl_refiner_model", ""),
                )
            return [base_model, refinder_model, is_enabled]

    @staticmethod
    def _script_args(args):
        """Return ``(base_name, refiner_name, enabled)``, or None if args are missing.

        Args are missing when the UI created no controls (the img2img case).
        """
        if len(args) < 3:
            return None
        return args[0], args[1], args[2]

    @staticmethod
    def _find_checkpoint(name, role):
        """Resolve a checkpoint name, raising ValueError if nothing matches.

        Resolving before any model is unloaded means a typo in the settings
        fails cleanly instead of leaving an unexpected model loaded.
        """
        info = modules.sd_models.get_closet_checkpoint_match(name)
        if info is None:
            raise ValueError(
                f"SDXL Refinder Hack: no checkpoint matches {role} model name {name!r}"
            )
        return info

    def before_process_batch(self, p, *args, **kwargs):
        """Make sure the base checkpoint is loaded before each batch."""
        settings = self._script_args(args)
        if settings is None:
            return
        base_name, _refiner_name, enabled = settings
        if not enabled:
            return

        print(f"\nEnabled: {enabled}\n\ncheckpoint: {base_name}")
        self.info_base = self._find_checkpoint(base_name, "base")

        if self.first_pass:
            self.first_pass = False
        elif self.info_hr is not None:
            # A previous hires pass left the refiner loaded; swap it out. When
            # no hires pass ran, the base is still loaded and nothing is unloaded.
            modules.sd_models.unload_model_weights(shared.sd_model, self.info_hr)
            self.info_hr = None

        modules.sd_models.reload_model_weights(shared.sd_model, self.info_base)
        p.override_settings['sd_model_checkpoint'] = self.info_base.name

    def before_hr(self, p, *args, **kwargs):
        """Swap the base checkpoint for the refiner right before the hires pass."""
        settings = self._script_args(args)
        if settings is None:
            return
        _base_name, refiner_name, enabled = settings
        if not enabled:
            return

        # Resolve first so an unknown name aborts before the base is unloaded.
        info_hr = self._find_checkpoint(refiner_name, "refiner")
        modules.sd_models.unload_model_weights(shared.sd_model, self.info_base)
        self.info_hr = info_hr
        modules.sd_models.reload_model_weights(shared.sd_model, self.info_hr)
        p.override_settings['sd_model_checkpoint'] = self.info_hr.name
        if self.info_base is not None:
            # Record which base model produced the first pass in the image info.
            p.extra_generation_params['base model'] = self.info_base.name