from math import ceil

import torch
from modules import scripts, script_callbacks, devices, sd_models, sd_models_config, shared
import gradio as gr
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from sgm.modules.encoders.modules import ConcatTimestepEmbedderND
from safetensors.torch import load_file, load
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from omegaconf import OmegaConf
from sgm.util import (
    disabled_train,
    get_obj_from_str,
    instantiate_from_config,
)


class Refiner(scripts.Script):
    """Swap an SDXL refiner UNet in for the last part of sampling.

    When enabled, a CFG-denoiser callback replaces ``p.sd_model.model`` with the
    refiner UNet for roughly the final ``steps`` percent of sampling steps and
    rewrites the text conditioning into the form the refiner consumes. The base
    UNet is parked on the CPU meanwhile and restored by ``reset``.
    """

    # A checkpoint containing a key with this substring is rejected as a base model.
    _BASE_MODEL_KEY = 'input_blocks.7.1.transformer_blocks.4.attn1.to_k.weight'
    # A checkpoint with a key containing this substring (and no base key) is accepted as a refiner.
    _REFINER_KEY_MARKER = 'conditioner'
    # Leading columns of the base 'vector' conditioning that are kept; the embedded
    # aesthetic score is concatenated after them.
    _VECTOR_KEEP = 2304
    # Trailing features of the base 'crossattn' conditioning that are fed to the refiner.
    _CROSSATTN_KEEP = 1280

    def __init__(self):
        super().__init__()
        self.callback_set = False  # True while our CFG callbacks are registered
        self.model = None          # wrapped refiner UNet, built lazily by build_model
        self.conditioner = None
        self.base = None           # base UNet parked here while the refiner is swapped in
        self.swapped = False       # True while p.sd_model.model is the refiner
        self.model_name = ''       # checkpoint name self.model was loaded from
        self.embedder = ConcatTimestepEmbedderND(256)
        self.c_ae = None           # embedded high aesthetic score (conditional branch)
        self.uc_ae = None          # embedded low aesthetic score (unconditional branch)

    def title(self):
        return "Refiner"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def build_model(self):
        """Instantiate an untrained refiner UNet from the SDXL refiner config, on CPU."""
        refiner_config = OmegaConf.load(sd_models_config.config_sdxl_refiner).model.params.network_config
        self.model = instantiate_from_config(refiner_config)
        self.model = get_obj_from_str(OPENAIUNETWRAPPER)(
            self.model, compile_model=False
        ).eval()
        self.model.to('cpu', devices.dtype_unet)
        self.model.train = disabled_train
        self.model.diffusion_model.dtype = devices.dtype_unet
        # NOTE(review): these attributes presumably mimic what the webui expects on
        # p.sd_model.model, since this wrapper stands in for it — confirm.
        self.model.conditioning_key = 'crossattn'
        self.model.cond_stage_key = 'txt'
        self.model.parameterization = 'v'
        discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
        self.model.alphas_cumprod = torch.asarray(
            discretization.alphas_cumprod, device=devices.device, dtype=devices.dtype_unet
        )
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _read_checkpoint(filename):
        """Load a safetensors file into a dict of tensors, honouring the mmap option."""
        if not shared.opts.disable_mmap_load_safetensors:
            return load_file(filename)
        # Read the whole file into memory instead of memory-mapping it; the
        # context manager guarantees the handle is closed.
        with open(filename, 'rb') as f:
            return load(f.read())

    @classmethod
    def _detect_model_type(cls, keys):
        """Return 'Base', 'Refiner' or '' for a checkpoint with the given keys."""
        keys = list(keys)
        if any(cls._BASE_MODEL_KEY in key for key in keys):
            return 'Base'
        if any(cls._REFINER_KEY_MARKER in key for key in keys):
            return 'Refiner'
        return ''

    def load_model(self, model_name):
        """Build a refiner UNet and load the weights of checkpoint ``model_name`` into it.

        Returns True on success. If the checkpoint is not recognised as a refiner,
        this script's callbacks are removed, a message is printed and False is returned.
        """
        ckpt = self._read_checkpoint(sd_models.checkpoints_list[model_name].filename)

        model_type = self._detect_model_type(ckpt.keys())
        if model_type != 'Refiner':
            self.enable = False
            script_callbacks.remove_current_script_callbacks()
            if model_type == 'Base':
                print('\nIt\'s Base model, use Refiner, extension disabled!\n')
            else:
                print('\nNot refiner, extension disabled!\n')
            return False

        print('\nLoading refiner...\n')
        self.build_model()
        # Keep only the UNet weights and strip the leading 'model.' so keys match the
        # wrapper. Cast straight to the UNet dtype: going through fp16 first would
        # needlessly lose precision when the UNet runs in fp32/bf16.
        state_dict = {
            key.replace('model.d', 'd'): tensor.to(devices.dtype_unet)
            for key, tensor in ckpt.items()
            if 'model.diffusion_model' in key
        }
        self.model.load_state_dict(state_dict)
        self.model_name = model_name
        return True

    def ui(self, is_img2img):
        with gr.Accordion(label='Refiner', open=False):
            enable = gr.Checkbox(label='Enable Refiner', value=False)
            with gr.Row():
                checkpoint = gr.Dropdown(
                    choices=['None', *sd_models.checkpoints_list.keys()],
                    label='Model',
                    value='None',
                )
                steps = gr.Slider(
                    minimum=0,
                    maximum=50,
                    step=1,
                    label='Percent of refiner steps from total sampling steps',
                    value=20,
                )
            gr.HTML(
                '\nIt\'s recommended to keep the percentage at 20% (80% base steps, 20% refiner steps). '
                'Higher values may result in distortions.\n'
            )
        ui = [enable, checkpoint, steps]
        return ui

    def process(self, p, enable, checkpoint, steps):
        """Prepare the refiner for this generation and register the swap callbacks.

        ``steps`` is the percentage (0-50) of sampling steps handed to the refiner.
        """
        # Undo anything left over from a previous (possibly interrupted) generation.
        if self.base is not None or self.swapped or self.callback_set:
            self.reset(p)

        if not enable or checkpoint == 'None':
            script_callbacks.remove_current_script_callbacks()
            self.model = None
            return

        if self.model is None or self.model_name != checkpoint:
            if not self.load_model(checkpoint):
                return

        # One embedded aesthetic score per batch item, for the cond and uncond branches.
        self.c_ae = self.embedder(
            torch.tensor(shared.opts.sdxl_refiner_high_aesthetic_score)
            .unsqueeze(0).to(devices.device).repeat(p.batch_size, 1)
        )
        self.uc_ae = self.embedder(
            torch.tensor(shared.opts.sdxl_refiner_low_aesthetic_score)
            .unsqueeze(0).to(devices.device).repeat(p.batch_size, 1)
        )

        p.extra_generation_params['Refiner model'] = checkpoint.rsplit('.', 1)[0]
        p.extra_generation_params['Refiner steps'] = ceil(p.steps * (steps / 100))

        def denoiser_callback(params: script_callbacks.CFGDenoiserParams):
            # NOTE(review): the "- 2" offset presumably compensates for how the
            # webui numbers sampling steps — confirm.
            if params.sampling_step > params.total_sampling_steps * (1 - steps / 100) - 2:
                # Rewrite the base conditioning into the refiner's layout.
                params.text_cond['vector'] = torch.cat(
                    (params.text_cond['vector'][:, :self._VECTOR_KEEP], self.c_ae), 1
                )
                params.text_uncond['vector'] = torch.cat(
                    (params.text_uncond['vector'][:, :self._VECTOR_KEEP], self.uc_ae), 1
                )
                params.text_cond['crossattn'] = params.text_cond['crossattn'][:, :, -self._CROSSATTN_KEEP:]
                params.text_uncond['crossattn'] = params.text_uncond['crossattn'][:, :, -self._CROSSATTN_KEEP:]
                if not self.swapped:
                    # Park the base UNet on the CPU and put the refiner on the device.
                    self.base = p.sd_model.model.to('cpu', devices.dtype_unet)
                    devices.torch_gc()
                    p.sd_model.model = self.model.to(devices.device, devices.dtype_unet)
                    self.swapped = True

        def denoised_callback(params: script_callbacks.CFGDenoisedParams):
            # Restore the base UNet two steps before the end, but keep the hooks
            # registered for the rest of this generation.
            if params.sampling_step == params.total_sampling_steps - 2:
                self.reset(p, keep_hook=True)

        if not self.callback_set:
            script_callbacks.on_cfg_denoiser(denoiser_callback)
            script_callbacks.on_cfg_denoised(denoised_callback)
            self.callback_set = True

    def reset(self, p, keep_hook=False):
        """Move the refiner back to CPU and restore the base UNet on ``p.sd_model``.

        Unregisters this script's callbacks unless ``keep_hook`` is True.
        """
        if self.model is not None or self.base is not None:
            if self.model is not None:
                self.model.to('cpu', devices.dtype_unet)
            # Explicit None check: don't rely on nn.Module truthiness.
            unet = self.base if self.base is not None else p.sd_model.model
            p.sd_model.model = unet.to(devices.device, devices.dtype_unet)
            devices.torch_gc()
        self.base = None
        self.swapped = False
        if not keep_hook:
            script_callbacks.remove_current_script_callbacks()
            self.callback_set = False

    def postprocess(self, p, processed, enable, checkpoint, steps):
        if enable and checkpoint != 'None':
            self.reset(p)