from math import ceil
import torch
from modules import scripts, script_callbacks, devices, sd_models, sd_models_config, shared
import gradio as gr
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from sgm.modules.encoders.modules import ConcatTimestepEmbedderND
from safetensors.torch import load_file, load
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from omegaconf import OmegaConf
from sgm.util import (
    disabled_train,
    get_obj_from_str,
    instantiate_from_config,
)

class Refiner(scripts.Script):
    def __init__(self):
        super().__init__()
        self.callback_set = False
        self.model = None
        self.conditioner = None
        self.base = None
        self.swapped = False
        self.model_name = ''
        self.embedder = ConcatTimestepEmbedderND(256)
        self.c_ae = None
        self.uc_ae = None
        
    def title(self):
        return "Refiner"

    def show(self, is_img2img):
        return scripts.AlwaysVisible
    
    def build_model(self):
        refiner_config = OmegaConf.load(sd_models_config.config_sdxl_refiner).model.params.network_config
        self.model = instantiate_from_config(refiner_config)
        self.model = get_obj_from_str(OPENAIUNETWRAPPER)(
            self.model, compile_model=False
        ).eval()
        self.model.to('cpu', devices.dtype_unet)
        self.model.train = disabled_train
        self.model.diffusion_model.dtype = devices.dtype_unet
        self.model.conditioning_key = 'crossattn'
        self.model.cond_stage_key = 'txt'
        self.model.parameterization = 'v'
        discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
        self.model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=devices.dtype_unet)
        for param in self.model.parameters():
            param.requires_grad = False
    
    def load_model(self, model_name):
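        filename = sd_models.checkpoints_list[model_name].filename
        if not shared.opts.disable_mmap_load_safetensors:
            # Memory-mapped load straight from disk.
            ckpt = load_file(filename)
        else:
            # Read the whole file into memory first; close the handle when done.
            with open(filename, 'rb') as f:
                ckpt = load(f.read())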
        model_type = ''
        for key in ckpt.keys():
            if 'conditioner' in key: 
                model_type = 'Refiner'
            if 'input_blocks.7.1.transformer_blocks.4.attn1.to_k.weight' in key:
                model_type = 'Base'
                break
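        if model_type != 'Refiner':
            # Not a refiner checkpoint: turn the extension off for this run.
            self.enable = False
            script_callbacks.remove_current_script_callbacks()
            if model_type == 'Base':
                print('\nSelected checkpoint is an SDXL base model; select a refiner checkpoint instead. Extension disabled!\n')
            else:
                print('\nSelected checkpoint is not an SDXL refiner. Extension disabled!\n')
            return False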
        
        print('\nLoading refiner...\n')
        self.build_model()
            
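        # Keep only the UNet weights and strip the leading 'model.' so the keys
        # match the wrapper's 'diffusion_model.*' parameter names.
        state_dict = {}
        for key in ckpt.keys():
            if 'model.diffusion_model' in key:
                state_dict[key.replace('model.d', 'd', 1)] = ckpt[key].half()
        self.model.load_state_dict(state_dict)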
        self.model_name = model_name
        return True
        
    def ui(self, is_img2img):
        with gr.Accordion(label='Refiner', open=False):
            enable = gr.Checkbox(label='Enable Refiner', value=False)
            with gr.Row():
                checkpoint = gr.Dropdown(choices=['None', *sd_models.checkpoints_list.keys()], label='Model', value='None')
                steps = gr.Slider(minimum=0, maximum=50, step=1, label='Percent of refiner steps from total sampling steps', value=20)

            gr.HTML('<p style="margin-bottom:0.8em"> It\'s recommended to keep the percentage at 20% (80% base steps, 20% refiner steps). Higher values may result in distortions. </p>')
            
        ui = [enable, checkpoint, steps]
        return ui
    
    def process(self, p, enable, checkpoint, steps):
        # Clean up any state left over from a previous (possibly interrupted) run.
        if self.base is not None or self.swapped or self.callback_set:
            self.reset(p)
        if not enable or checkpoint == 'None':
            script_callbacks.remove_current_script_callbacks()
            self.model = None
            return
        if self.model is None or self.model_name != checkpoint:
            if not self.load_model(checkpoint):
                return
        self.c_ae = self.embedder(torch.tensor(shared.opts.sdxl_refiner_high_aesthetic_score).unsqueeze(0).to(devices.device).repeat(p.batch_size, 1))
        self.uc_ae = self.embedder(torch.tensor(shared.opts.sdxl_refiner_low_aesthetic_score).unsqueeze(0).to(devices.device).repeat(p.batch_size, 1))
        p.extra_generation_params['Refiner model'] = checkpoint.rsplit('.', 1)[0]
        # Multiply before dividing so float rounding cannot push ceil() up by one.
        p.extra_generation_params['Refiner steps'] = ceil(p.steps * steps / 100)
        
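        def denoiser_callback(params: script_callbacks.CFGDenoiserParams):
            # Switch to the refiner for the last `steps` percent of the sampling steps.
            if params.sampling_step > params.total_sampling_steps * (1 - steps / 100) - 2:
                # Keep the first 2304 values of the base 'vector' conditioning and
                # append the aesthetic-score embedding the refiner expects.
                params.text_cond['vector'] = torch.cat((params.text_cond['vector'][:, :2304], self.c_ae), 1)
                params.text_uncond['vector'] = torch.cat((params.text_uncond['vector'][:, :2304], self.uc_ae), 1)
                # Keep only the last 1280 cross-attention channels for the refiner.
                params.text_cond['crossattn'] = params.text_cond['crossattn'][:, :, -1280:]
                params.text_uncond['crossattn'] = params.text_uncond['crossattn'][:, :, -1280:]
                if not self.swapped:
                    # Move the base UNet to the CPU and put the refiner UNet in its place.
                    self.base = p.sd_model.model.to('cpu', devices.dtype_unet)
                    devices.torch_gc()
                    p.sd_model.model = self.model.to(devices.device, devices.dtype_unet)
                    self.swapped = True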
        
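        def denoised_callback(params: script_callbacks.CFGDenoisedParams):
            # Near the end of sampling, restore the base UNet but keep the hooks
            # registered so the next image in the job is refined too.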
            if params.sampling_step == params.total_sampling_steps - 2:
                self.reset(p, keep_hook=True)
        
        if not self.callback_set:
            script_callbacks.on_cfg_denoiser(denoiser_callback)
            script_callbacks.on_cfg_denoised(denoised_callback)
            self.callback_set = True
    
    def reset(self, p, keep_hook=False):
        if self.model is not None:
            self.model.to('cpu', devices.dtype_unet)
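        # Restore the base UNet if it was swapped out; otherwise keep the current model.
        base = self.base if self.base is not None else p.sd_model.model
        p.sd_model.model = base.to(devices.device, devices.dtype_unet)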
        devices.torch_gc()
        self.base = None
        self.swapped = False
        if not keep_hook:
            script_callbacks.remove_current_script_callbacks()
            self.callback_set = False
        
    def postprocess(self, p, processed, enable, checkpoint, steps):
        if enable and checkpoint != 'None':
            self.reset(p)