# fac / gf /sd-webui-refiner /scripts /refiner.py
# dikdimon's picture
# Upload gf using SD-Hub extension
# 29a5ed9 verified
from math import ceil
import torch
from modules import scripts, script_callbacks, devices, sd_models, sd_models_config, shared
import gradio as gr
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from sgm.modules.encoders.modules import ConcatTimestepEmbedderND
from safetensors.torch import load_file, load
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from omegaconf import OmegaConf
from sgm.util import (
disabled_train,
get_obj_from_str,
instantiate_from_config,
)
class Refiner(scripts.Script):
    """Always-visible script that swaps in an SDXL refiner UNet late in sampling.

    The refiner UNet is built from ``sd_models_config.config_sdxl_refiner`` and
    kept on the CPU.  Once sampling passes the configured threshold, the model
    stored in ``p.sd_model.model`` is moved to the CPU and replaced by the
    refiner; ``reset`` restores the original model afterwards.
    """

    def __init__(self):
        super().__init__()
        # True while this script's cfg_denoiser / cfg_denoised hooks are registered.
        self.callback_set = False
        # Refiner UNet wrapper created by build_model(); None until one is loaded.
        self.model = None
        # Not read or written by any other method of this class.
        self.conditioner = None
        # The model taken out of p.sd_model.model while the refiner is swapped in.
        self.base = None
        # True while p.sd_model.model currently points at the refiner.
        self.swapped = False
        # Checkpoint name the current self.model was loaded from.
        self.model_name = ''
        self.embedder = ConcatTimestepEmbedderND(256)
        # Embedded high / low aesthetic scores, computed per run in process().
        self.c_ae = None
        self.uc_ae = None

    def title(self):
        """Return the script title."""
        return "Refiner"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def build_model(self):
        """Instantiate an empty refiner UNet wrapper on the CPU and store it in ``self.model``.

        Weights are not loaded here; ``load_model`` fills them in afterwards.
        """
        refiner_config = OmegaConf.load(sd_models_config.config_sdxl_refiner).model.params.network_config
        self.model = instantiate_from_config(refiner_config)
        self.model = get_obj_from_str(OPENAIUNETWRAPPER)(
            self.model, compile_model=False
        ).eval()
        self.model.to('cpu', devices.dtype_unet)
        # Replace train() with sgm's disabled_train helper.
        self.model.train = disabled_train
        self.model.diffusion_model.dtype = devices.dtype_unet
        self.model.conditioning_key = 'crossattn'
        self.model.cond_stage_key = 'txt'
        self.model.parameterization = 'v'
        discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
        self.model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=devices.dtype_unet)
        for param in self.model.parameters():
            param.requires_grad = False

    def load_model(self, model_name):
        """Load the checkpoint registered as ``model_name`` into a fresh refiner UNet.

        Args:
            model_name: Key into ``sd_models.checkpoints_list``.

        Returns:
            True when the checkpoint was recognised as a refiner and loaded.
            False otherwise; in that case this script's callbacks are removed
            and a message is printed.
        """
        filename = sd_models.checkpoints_list[model_name].filename
        if not shared.opts.disable_mmap_load_safetensors:
            ckpt = load_file(filename)
        else:
            # Read the whole file into memory; the context manager guarantees
            # the handle is closed even if reading or parsing fails.
            with open(filename, 'rb') as checkpoint_file:
                ckpt = load(checkpoint_file.read())

        # Classify the checkpoint by its key names: any 'conditioner' key marks
        # it as a refiner, but the listed transformer-block key marks it as a
        # base model and ends the scan immediately.
        model_type = ''
        for key in ckpt:
            if 'conditioner' in key:
                model_type = 'Refiner'
            if 'input_blocks.7.1.transformer_blocks.4.attn1.to_k.weight' in key:
                model_type = 'Base'
                break

        if model_type != 'Refiner':
            # NOTE(review): self.enable is not read anywhere in this class.
            self.enable = False
            script_callbacks.remove_current_script_callbacks()
            if model_type == 'Base':
                print('\nIt\'s Base model, use Refiner, extension disabled!\n')
            else:
                print('\nNot refiner, extension disabled!\n')
            return False

        print('\nLoading refiner...\n')
        self.build_model()

        # Keep only the UNet weights, converted to half precision.  Only the
        # first 'model.d' is rewritten, so the leading 'model.' is stripped
        # without touching any later 'model.d...' fragment inside the key.
        state_dict = {
            key.replace('model.d', 'd', 1): tensor.half()
            for key, tensor in ckpt.items()
            if 'model.diffusion_model' in key
        }
        self.model.load_state_dict(state_dict)
        self.model_name = model_name
        return True

    def ui(self, is_img2img):
        """Build the accordion UI and return ``[enable, checkpoint, steps]``."""
        with gr.Accordion(label='Refiner', open=False):
            enable = gr.Checkbox(label='Enable Refiner', value=False)
            with gr.Row():
                checkpoint = gr.Dropdown(choices=['None', *sd_models.checkpoints_list.keys()], label='Model', value='None')
                steps = gr.Slider(minimum=0, maximum=50, step=1, label='Percent of refiner steps from total sampling steps', value=20)
            gr.HTML('<p style="margin-bottom:0.8em"> It\'s recommended to keep the percentage at 20% (80% base steps, 20% refiner steps). Higher values may result in distortions. </p>')

        ui = [enable, checkpoint, steps]
        return ui

    def process(self, p, enable, checkpoint, steps):
        """Prepare the refiner for a run and register the sampling callbacks.

        Args:
            p: Processing object; ``p.sd_model.model``, ``p.batch_size``,
                ``p.steps`` and ``p.extra_generation_params`` are used.
            enable: Value of the 'Enable Refiner' checkbox.
            checkpoint: Selected checkpoint name, or the string 'None'.
            steps: Percentage of the total sampling steps given to the refiner.
        """
        # Undo whatever a previous run left behind before doing anything else.
        if self.base is not None or self.swapped or self.callback_set:
            self.reset(p)
        if not enable or checkpoint == 'None':
            script_callbacks.remove_current_script_callbacks()
            self.model = None
            return
        if self.model is None or self.model_name != checkpoint:
            if not self.load_model(checkpoint):
                return

        # Embed the configured aesthetic scores, one row per image in the batch.
        self.c_ae = self.embedder(torch.tensor(shared.opts.sdxl_refiner_high_aesthetic_score).unsqueeze(0).to(devices.device).repeat(p.batch_size, 1))
        self.uc_ae = self.embedder(torch.tensor(shared.opts.sdxl_refiner_low_aesthetic_score).unsqueeze(0).to(devices.device).repeat(p.batch_size, 1))
        p.extra_generation_params['Refiner model'] = checkpoint.rsplit('.', 1)[0]
        p.extra_generation_params['Refiner steps'] = ceil((p.steps * (steps / 100)))

        def denoiser_callback(params: script_callbacks.CFGDenoiserParams):
            """Rewrite the conditioning and swap in the refiner past the threshold."""
            if params.sampling_step > params.total_sampling_steps * (1 - steps / 100) - 2:
                # Keep the first 2304 columns of 'vector' and append the
                # embedded aesthetic score; keep the last 1280 channels of
                # 'crossattn'.
                params.text_cond['vector'] = torch.cat((params.text_cond['vector'][:, :2304], self.c_ae), 1)
                params.text_uncond['vector'] = torch.cat((params.text_uncond['vector'][:, :2304], self.uc_ae), 1)
                params.text_cond['crossattn'] = params.text_cond['crossattn'][:, :, -1280:]
                params.text_uncond['crossattn'] = params.text_uncond['crossattn'][:, :, -1280:]
                if not self.swapped:
                    # Park the current model on the CPU before moving the
                    # refiner onto the device.
                    self.base = p.sd_model.model.to('cpu', devices.dtype_unet)
                    devices.torch_gc()
                    p.sd_model.model = self.model.to(devices.device, devices.dtype_unet)
                    self.swapped = True

        def denoised_callback(params: script_callbacks.CFGDenoisedParams):
            """Restore the original model near the end, keeping the hooks registered."""
            if params.sampling_step == params.total_sampling_steps - 2:
                self.reset(p, keep_hook=True)

        if not self.callback_set:
            script_callbacks.on_cfg_denoiser(denoiser_callback)
            script_callbacks.on_cfg_denoised(denoised_callback)
            self.callback_set = True

    def reset(self, p, keep_hook=False):
        """Move the refiner back to the CPU and restore the original model.

        Args:
            p: Processing object whose ``sd_model.model`` is restored.
            keep_hook: When True, leave this script's callbacks registered.
        """
        if self.model is not None:
            self.model.to('cpu', devices.dtype_unet)
        # Test for None explicitly rather than relying on the truthiness of a
        # module object.
        base_model = self.base if self.base is not None else p.sd_model.model
        p.sd_model.model = base_model.to(devices.device, devices.dtype_unet)
        devices.torch_gc()
        self.base = None
        self.swapped = False
        if not keep_hook:
            script_callbacks.remove_current_script_callbacks()
            self.callback_set = False

    def postprocess(self, p, processed, enable, checkpoint, steps):
        """Restore the original model after a run in which the refiner was enabled."""
        if enable and checkpoint != 'None':
            self.reset(p)