|
|
import copy |
|
|
import itertools |
|
|
import json |
|
|
from datetime import datetime |
|
|
|
|
|
import modules.scripts as scripts |
|
|
import gradio as gr |
|
|
|
|
|
from ldm.modules.diffusionmodules.openaimodel import UNetModel |
|
|
from modules import sd_models, shared, devices |
|
|
from scripts.mbw_util.preset_weights import PresetWeights |
|
|
import torch |
|
|
from natsort import natsorted |
|
|
|
|
|
from pathlib import Path |
|
|
import safetensors.torch |
|
|
|
|
|
# Preset weight table bundled with the extension (named merge recipes
# selectable from the "Preset Weights" dropdown).
presetWeights = PresetWeights()


# Process-wide singleton merge manager, stashed on `shared` so every Script
# instance (txt2img/img2img tabs) drives the same live UNet state.
shared.UNetBManager = None
|
|
|
|
|
# State-dict key prefixes of the 27 UNet sub-modules this extension merges,
# listed in natural-sort order of the (natsorted) UNet state_dict keys:
# 12 input blocks, middle_block, out, 12 output blocks, then time_embed.
# Index i here matches index i of UNetStateManager's per-block lists.
known_block_prefixes = (
    [f'input_blocks.{idx}.' for idx in range(12)]
    + ['middle_block.', 'out.']
    + [f'output_blocks.{idx}.' for idx in range(12)]
    + ['time_embed.']
)
|
|
|
|
|
class UNetStateManager(object):
    """Holds two UNet state dicts — model A (the currently loaded checkpoint)
    and model B (the merge partner) — split into 27 per-block dicts, and
    applies per-block linear interpolation between them onto the live UNet.

    Block index i everywhere corresponds to ``known_block_prefixes[i]``:
    12 input blocks, middle_block, out, 12 output blocks, time_embed.
    """

    def __init__(self, org_unet: UNetModel = None):
        super().__init__()
        # Per-block lists of {block-local layer key: tensor}; filled by map_blocks().
        self.modelB_state_dict_by_blocks = []
        self.torch_unet = org_unet  # live UNet module patched in place

        self.modelA_state_dict = None
        self.dtype = devices.dtype  # webui's working dtype (usually fp16)
        self.modelA_state_dict_by_blocks = []

        self.modelB_state_dict = None

        # Live sub-modules in known_block_prefixes order, so block i's merged
        # dict is loaded into unet_block_module_list[i].
        self.unet_block_module_list = [*self.torch_unet.input_blocks, self.torch_unet.middle_block,
                                       self.torch_unet.out, *self.torch_unet.output_blocks,
                                       self.torch_unet.time_embed]
        # Weights last loaded into the live UNet (used to skip unchanged blocks).
        self.applied_weights = [0.0] * 27

        self.enabled = False
        self.modelA_path = shared.sd_model.sd_model_checkpoint
        self.modelB_path = ''
        self.force_cpu = False  # merge on CPU in fp32 for maximum precision
        self.modelA_dtype = None
        self.modelB_dtype = None
        # Device that holds model B and performs the lerp; CPU on low VRAM.
        self.device = devices.get_cuda_device_string() if (torch.cuda.is_available() and not shared.cmd_opts.lowvram) else "cpu"

    @staticmethod
    def _first_dtype(state_dict):
        """Dtype of the first tensor in a state dict (all entries share it)."""
        return next(iter(state_dict.values())).dtype

    def _merge_tensor(self, block_index, layer_key, weight, operation_dtype, fp32_out, cpu_out):
        """lerp(A, B, weight) for one tensor of one block.

        When either source model is fp32 (``operation_dtype == torch.float32``)
        the math is done in fp32 and the result cast to ``fp32_out`` (None
        keeps fp32). Otherwise, under force_cpu the math is still upcast to
        fp32 and cast down to ``cpu_out``; on GPU it runs in native dtype.
        """
        a = self.modelA_state_dict_by_blocks[block_index][layer_key]
        b = self.modelB_state_dict_by_blocks[block_index][layer_key]
        if operation_dtype == torch.float32:
            merged = torch.lerp(a.to(torch.float32), b.to(torch.float32), weight)
            return merged if fp32_out is None else merged.to(fp32_out)
        if self.force_cpu:
            return torch.lerp(a.to(torch.float32), b.to(torch.float32), weight).to(cpu_out)
        return torch.lerp(a, b, weight)

    def _operation_dtype(self):
        """fp32 if either source model is fp32, else fp16."""
        if self.modelA_dtype == torch.float32 or self.modelB_dtype == torch.float32:
            return torch.float32
        return torch.float16

    def reload_modelA(self):
        """Re-cache model A after the webui checkpoint changed and re-apply
        the currently applied merge weights to the live UNet."""
        if not self.enabled:
            return

        if self.modelA_path == shared.sd_model.sd_model_checkpoint and self.modelA_state_dict is not None:
            return  # same checkpoint already cached, nothing to do
        self.modelA_path = shared.sd_model.sd_model_checkpoint

        del self.modelA_state_dict_by_blocks
        self.modelA_state_dict_by_blocks = []

        del self.modelA_state_dict
        torch.cuda.empty_cache()
        if self.force_cpu:
            # Read from disk so model A keeps its full on-disk precision.
            self.modelA_state_dict = self.filter_unet_state_dict(
                sd_models.read_state_dict(self.modelA_path, map_location="cpu"))
            self.map_blocks(self.modelA_state_dict, self.modelA_state_dict_by_blocks)
            self.modelA_dtype = self._first_dtype(self.modelA_state_dict)
        else:
            # Snapshot the already-loaded (possibly fp16) live UNet weights.
            self.modelA_state_dict = copy.deepcopy(self.torch_unet.state_dict())
            self.map_blocks(self.modelA_state_dict, self.modelA_state_dict_by_blocks)

        self.model_state_apply(self.applied_weights)
        print('model A reloaded')

    def load_modelB(self, modelB_path, force_cpu_checkbox, current_weights):
        """Load merge partner B, cache model A if needed, and apply weights.

        Returns True on success; False when aborted (B is the same checkpoint
        as A with nothing loaded yet, or the state dicts don't line up).
        """
        self.force_cpu = force_cpu_checkbox
        self.device = devices.get_cuda_device_string() if (torch.cuda.is_available() and not shared.cmd_opts.lowvram) else "cpu"
        if self.force_cpu:
            self.device = "cpu"
        model_info = sd_models.get_closet_checkpoint_match(modelB_path)
        checkpoint_file = model_info.filename
        self.modelB_path = checkpoint_file

        if self.modelA_path == checkpoint_file:
            # Merging a model with itself is pointless; only disable if no
            # previous model B is still loaded.
            if not self.modelB_state_dict:
                self.enabled = False
            return False

        if not self.modelA_state_dict:
            if self.force_cpu:
                self.modelA_path = shared.sd_model.sd_model_checkpoint
                self.modelA_state_dict = self.filter_unet_state_dict(
                    sd_models.read_state_dict(self.modelA_path, map_location="cpu"))
                self.map_blocks(self.modelA_state_dict, self.modelA_state_dict_by_blocks)
            else:
                self.modelA_state_dict = copy.deepcopy(self.torch_unet.state_dict())
                self.map_blocks(self.modelA_state_dict, self.modelA_state_dict_by_blocks)

        self.modelA_dtype = self._first_dtype(self.modelA_state_dict)

        if self.modelB_state_dict:
            # Drop the previous model B before loading the new one.
            del self.modelB_state_dict_by_blocks
            del self.modelB_state_dict
            torch.cuda.empty_cache()
            self.modelB_state_dict_by_blocks = []
        self.modelB_state_dict = self.filter_unet_state_dict(
            sd_models.read_state_dict(checkpoint_file, map_location=self.device))
        self.modelB_dtype = self._first_dtype(self.modelB_state_dict)
        if len(self.modelA_state_dict) != len(self.modelB_state_dict):
            print('modelA and modelB state dict have different length, aborting')
            return False
        self.map_blocks(self.modelB_state_dict, self.modelB_state_dict_by_blocks)

        self.model_state_apply(current_weights)

        print('model B loaded')
        self.enabled = True
        return True

    def model_state_apply(self, current_weights):
        """Merge all 27 blocks with ``current_weights`` and load the results
        into the corresponding live UNet sub-modules."""
        operation_dtype = self._operation_dtype()
        for i in range(27):
            cur_block_state_dict = {}
            for cur_layer_key in self.modelA_state_dict_by_blocks[i]:
                curlayer_tensor = self._merge_tensor(i, cur_layer_key, current_weights[i],
                                                     operation_dtype, self.dtype, self.dtype)
                if str(shared.device) != self.device:
                    # Merge happened off-device (e.g. CPU); move to the UNet's device.
                    curlayer_tensor = curlayer_tensor.to(shared.device)
                cur_block_state_dict[cur_layer_key] = curlayer_tensor
            self.unet_block_module_list[i].load_state_dict(cur_block_state_dict)
        self.applied_weights = current_weights

    def model_state_construct(self, current_weights):
        """Build and return a merged flat UNet state dict (full keys, not
        split by block) without touching the live model; used when saving."""
        precision_dtype = self._operation_dtype()
        result_state_dict = {}
        for i in range(27):
            for cur_layer_key in self.modelA_state_dict_by_blocks[i]:
                # fp32 results are kept in fp32 here (max-precision save path).
                curlayer_tensor = self._merge_tensor(i, cur_layer_key, current_weights[i],
                                                     precision_dtype, None, torch.float16)
                result_state_dict[known_block_prefixes[i] + cur_layer_key] = curlayer_tensor
        return result_state_dict

    def model_state_apply_modified_blocks(self, current_weights, current_model_B):
        """Re-merge and reload only the blocks whose weight changed since the
        last apply; called from process() before each generation."""
        if not self.enabled:
            return
        modelB_info = sd_models.get_closet_checkpoint_match(current_model_B)
        checkpoint_file_B = modelB_info.filename
        if checkpoint_file_B != self.modelB_path:
            print('model B changed, shouldn\'t happen')
            # Bug fix: load_modelB takes (path, force_cpu, weights); the
            # force_cpu argument was previously missing, raising TypeError.
            self.load_modelB(current_model_B, self.force_cpu, current_weights)
            return
        if self.applied_weights == current_weights:
            return
        operation_dtype = self._operation_dtype()
        for i in range(27):
            if current_weights[i] != self.applied_weights[i]:
                cur_block_state_dict = {}
                for cur_layer_key in self.modelA_state_dict_by_blocks[i]:
                    curlayer_tensor = self._merge_tensor(i, cur_layer_key, current_weights[i],
                                                         operation_dtype, self.dtype, torch.float16)
                    if str(shared.device) != self.device:
                        curlayer_tensor = curlayer_tensor.to(shared.device)
                    cur_block_state_dict[cur_layer_key] = curlayer_tensor
                self.unet_block_module_list[i].load_state_dict(cur_block_state_dict)
        self.applied_weights = current_weights

    def model_state_apply_block(self, current_weights):
        """Incremental merge of changed blocks in native dtype (no precision
        upcasting); returns the weights that are now applied."""
        if not self.enabled:
            return self.applied_weights
        for i in range(27):
            if current_weights[i] != self.applied_weights[i]:
                cur_block_state_dict = {}
                for cur_layer_key in self.modelA_state_dict_by_blocks[i]:
                    cur_block_state_dict[cur_layer_key] = torch.lerp(
                        self.modelA_state_dict_by_blocks[i][cur_layer_key],
                        self.modelB_state_dict_by_blocks[i][cur_layer_key],
                        current_weights[i])
                self.unet_block_module_list[i].load_state_dict(cur_block_state_dict)
        self.applied_weights = current_weights
        return self.applied_weights

    def filter_unet_state_dict(self, input_dict):
        """Keep only 'model.diffusion_model.*' entries of a full checkpoint
        state dict, strip that prefix, and return them in natural-sort key
        order (the order map_blocks expects)."""
        filtered_dict = {}
        for key, value in input_dict.items():
            if key.startswith('model.diffusion_model'):
                filtered_dict[key[22:]] = value  # len('model.diffusion_model.') == 22
        return {k: filtered_dict[k] for k in natsorted(filtered_dict.keys())}

    def map_blocks(self, model_state_dict_input, model_state_dict_by_blocks):
        """Split a flat UNet state dict into 27 per-block dicts (keyed by
        block-local layer names) appended to ``model_state_dict_by_blocks``."""
        if model_state_dict_by_blocks:
            print('mapping to non empty list')
            return
        model_state_dict_sorted_keys = natsorted(model_state_dict_input.keys())
        model_state_dict = {k: model_state_dict_input[k] for k in model_state_dict_sorted_keys}

        current_block_index = 0
        processing_block_dict = {}
        for key in model_state_dict:
            if not key.startswith(known_block_prefixes[current_block_index]):
                # Bug fix: bounds-check before peeking at the next prefix so
                # unexpected keys after the last block can't raise IndexError.
                if (current_block_index + 1 >= len(known_block_prefixes)
                        or not key.startswith(known_block_prefixes[current_block_index + 1])):
                    print(
                        f"unknown key {key} in statedict after block {known_block_prefixes[current_block_index]}, possible UNet structure deviation"
                    )
                    continue
                # Key belongs to the next block: close out the current one.
                model_state_dict_by_blocks.append(processing_block_dict)
                processing_block_dict = {}
                current_block_index += 1
            block_local_key = key[len(known_block_prefixes[current_block_index]):]
            processing_block_dict[block_local_key] = model_state_dict[key]

        model_state_dict_by_blocks.append(processing_block_dict)
        print('mapping complete')
        return

    def restore_original_unet(self):
        """Load the cached (unmerged) model A weights back into the live UNet."""
        self.torch_unet.load_state_dict(self.modelA_state_dict)
        return

    def unload_all(self):
        """Release both cached state dicts and disable runtime merging."""
        self.modelA_path = ''
        self.modelB_path = ''
        self.applied_weights = [0.0] * 27
        del self.modelA_state_dict
        self.modelA_state_dict = None
        del self.modelA_state_dict_by_blocks
        self.modelA_state_dict_by_blocks = []
        del self.modelB_state_dict
        self.modelB_state_dict = None
        del self.modelB_state_dict_by_blocks
        self.modelB_state_dict_by_blocks = []

        self.enabled = False
|
|
|
|
|
|
|
|
class Script(scripts.Script):
    """Always-visible script that adds the "Runtime Block Merge" accordion to
    the txt2img/img2img UI and drives the shared UNetStateManager so block-
    weighted merges happen on the live UNet at generation time."""

    def __init__(self) -> None:
        super().__init__()
        if shared.UNetBManager is None:
            try:
                shared.UNetBManager = UNetStateManager(shared.sd_model.model.diffusion_model)
            except AttributeError:
                # shared.sd_model may not be ready during startup; process()
                # creates the manager lazily on first use instead.
                shared.UNetBManager = None
        from modules.call_queue import wrap_queued_call

        def reload_modelA_checkpoint():
            # Keep the cached model A in sync when the user switches the main
            # webui checkpoint.
            if shared.opts.sd_model_checkpoint == shared.sd_model.sd_checkpoint_info.title:
                return
            sd_models.reload_model_weights()
            shared.UNetBManager.reload_modelA()

        shared.opts.onchange("sd_model_checkpoint",
                             wrap_queued_call(reload_modelA_checkpoint), call=False)

    def title(self):
        return "Runtime block merging for UNet"

    def show(self, is_img2img):
        # Render on both tabs; the accordion itself starts collapsed.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion UI and wire all event handlers.

        Returns the components forwarded to process(): 27 weight sliders in
        natural (state-dict) order, the model B dropdown, and the enabled flag.
        """
        process_script_params = []
        with gr.Accordion('Runtime Block Merge', open=False):
            hidden_title = gr.Textbox(label='Runtime Block Merge Title', value='Runtime Block Merge',
                                      visible=False, interactive=False)
            with gr.Row():
                enabled = gr.Checkbox(label='Enable', value=False, interactive=False)
                unload_button = gr.Button(value='Unload and Disable', elem_id="rbm_unload", visible=False)
                experimental_range_checkbox = gr.Checkbox(label='Enable Experimental Range', value=False)
                force_cpu_checkbox = gr.Checkbox(label='Force CPU (Max Precision)', value=True, interactive=True)
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        dd_preset_weight = gr.Dropdown(label="Preset Weights",
                                                       choices=presetWeights.get_preset_name_list())
                        config_paste_button = gr.Button(value='Generate Merge Block Weighted Config\u2199\ufe0f',
                                                        elem_id="rbm_config_paste",
                                                        title="Paste Current Block Configs Into Weight Command. Useful for copying to \"Merge Block Weighted\" extension")
                    weight_command_textbox = gr.Textbox(label="Weight Command",
                                                        placeholder="Input weight command, then press enter. \nExample: base:0.5, in00:1, out09:0.8, time_embed:0, out:0")
                with gr.Row():
                    model_B = gr.Dropdown(label="Model B", choices=sd_models.checkpoint_tiles())
                    refresh_button = gr.Button(variant='tool', value='\U0001f504', elem_id='rbm_modelb_refresh')
                with gr.Row():
                    sl_TIME_EMBED = gr.Slider(label="TIME_EMBED", minimum=0, maximum=1, step=0.01, value=0)
                    sl_OUT = gr.Slider(label="OUT", minimum=0, maximum=1, step=0.01, value=0)
                with gr.Row():
                    with gr.Column(min_width=100):
                        sl_IN_00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_IN_11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.01, value=0.5)
                    with gr.Column(min_width=100):
                        # Invisible spacers align M00 with the bottom of the
                        # IN/OUT columns in the three-column layout.
                        for _ in range(11):
                            gr.Slider(visible=False)
                        sl_M_00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.01, value=0.5,
                                            elem_id="mbw_sl_M00")
                    with gr.Column(min_width=100):
                        sl_OUT_11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.01, value=0.5)
                        sl_OUT_00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.01, value=0.5)

            sl_INPUT = [
                sl_IN_00, sl_IN_01, sl_IN_02, sl_IN_03, sl_IN_04, sl_IN_05,
                sl_IN_06, sl_IN_07, sl_IN_08, sl_IN_09, sl_IN_10, sl_IN_11]
            sl_MID = [sl_M_00]
            sl_OUTPUT = [
                sl_OUT_00, sl_OUT_01, sl_OUT_02, sl_OUT_03, sl_OUT_04, sl_OUT_05,
                sl_OUT_06, sl_OUT_07, sl_OUT_08, sl_OUT_09, sl_OUT_10, sl_OUT_11]
            # Natural (state-dict / known_block_prefixes) order: used for merging.
            sl_ALL_nat = [*sl_INPUT, *sl_MID, sl_OUT, *sl_OUTPUT, sl_TIME_EMBED]
            # GUI / preset / recipe order: used for weight commands and presets.
            sl_ALL = [*sl_INPUT, *sl_MID, *sl_OUTPUT, sl_TIME_EMBED, sl_OUT]

            def parse_weight_str_to_list(weightstr, current_weights):
                """Parse a weight command into a list of 25 or 27 floats in
                GUI (sl_ALL) order, or None on any validation failure.

                Supports 'key:value' commands (with optional BASE to preset
                all 27) and plain comma-separated 25/27-value lists.
                """
                weightstr = weightstr[:500]  # defensive cap on user input
                if ':' in weightstr:
                    weightstr = weightstr.replace(' ', '')
                    cmd_segments = weightstr.split(',')
                    constructed_json_segments = [f'"{key.upper()}":{value}' for key, value in
                                                 [x.split(':') for x in cmd_segments]]
                    constructed_json = '{' + ','.join(constructed_json_segments) + '}'
                    try:
                        parsed_json = json.loads(constructed_json)
                    except Exception as e:
                        print(e)
                        return None
                    # Slider name -> index in sl_ALL (GUI order).
                    slider_names = ([f'IN{i:02}' for i in range(12)] + ['M00']
                                    + [f'OUT{i:02}' for i in range(12)]
                                    + ['TIME_EMBED', 'OUT'])
                    weight_name_map = {name: idx for idx, name in enumerate(slider_names)}
                    extra_commands = ['BASE']

                    for key, value in parsed_json.items():
                        if key not in weight_name_map and key not in extra_commands:
                            print(f'invalid key: {key}')
                            return None
                        if not (isinstance(value, (float, int))) or value < -1 or value > 2:
                            print(f'{key} value {value} out of range')
                            return None

                    weight_list = current_weights
                    if 'BASE' in parsed_json:
                        # BASE presets every block before individual overrides.
                        weight_list = [float(parsed_json['BASE'])] * 27
                        del parsed_json['BASE']
                    for key, value in parsed_json.items():
                        weight_list[weight_name_map[key]] = value
                    return weight_list
                else:
                    # Plain list form: 25 (legacy MBW) or 27 floats.
                    _list = [x.strip() for x in weightstr.split(",")]
                    if len(_list) != 25 and len(_list) != 27:
                        return None
                    validated_float_weight_list = []
                    for x in _list:
                        try:
                            validated_float_weight_list.append(float(x))
                        except ValueError:
                            return None
                    return validated_float_weight_list

            def weight_list_to_updates(weight_list):
                # Map a parsed list onto 27 slider updates; a 25-long legacy
                # list leaves TIME_EMBED and OUT unchanged.
                if not weight_list:
                    return [gr.update() for _ in range(27)]
                if len(weight_list) == 25:
                    weight_list.extend([gr.update(), gr.update()])
                return weight_list

            def handle_modelB_load(modelB, force_cpu_checkbox, *slALL):
                """Load model B and flip the UI into merged mode on success."""
                if modelB is None:
                    return None, False, gr.update(interactive=True), gr.update(visible=False), gr.update(visible=False)
                load_flag = shared.UNetBManager.load_modelB(modelB, force_cpu_checkbox, slALL)
                if load_flag:
                    return modelB, True, gr.update(interactive=False), gr.update(visible=True), gr.update(visible=True)
                else:
                    return None, False, gr.update(interactive=True), gr.update(visible=False), gr.update(visible=False)

            def handle_unload():
                """Restore the unmerged UNet and reset the UI controls."""
                shared.UNetBManager.restore_original_unet()
                shared.UNetBManager.unload_all()
                return None, False, gr.update(interactive=True), gr.update(visible=False), gr.update(visible=False)

            def on_weight_command_submit(command_str, *current_weights):
                return weight_list_to_updates(
                    parse_weight_str_to_list(command_str, list(current_weights)))

            def on_change_dd_preset_weight(preset_weight_name, *current_weights):
                _weights = presetWeights.find_weight_by_name(preset_weight_name)
                return weight_list_to_updates(
                    parse_weight_str_to_list(_weights, list(current_weights)))

            def update_slider_range(experimental_range_flag):
                # Experimental mode widens every slider to [-1, 2].
                if experimental_range_flag:
                    return [gr.update(minimum=-1, maximum=2) for _ in sl_ALL]
                else:
                    return [gr.update(minimum=0, maximum=1) for _ in sl_ALL]

            def on_config_paste(*current_weights):
                # Emit the legacy 25-value MBW config string (GUI order,
                # without TIME_EMBED and OUT).
                slALL_str = [str(sl) for sl in current_weights]
                old_config_str = ','.join(slALL_str[:25])
                return old_config_str

            def refresh_modelB_dropdown():
                return gr.update(choices=sd_models.checkpoint_tiles())

            weight_command_textbox.submit(
                fn=on_weight_command_submit,
                inputs=[weight_command_textbox, *sl_ALL],
                outputs=sl_ALL
            )
            dd_preset_weight.change(
                fn=on_change_dd_preset_weight,
                inputs=[dd_preset_weight, *sl_ALL],
                outputs=sl_ALL
            )
            experimental_range_checkbox.change(fn=update_slider_range, inputs=[experimental_range_checkbox],
                                               outputs=sl_ALL)
            config_paste_button.click(fn=on_config_paste, inputs=[*sl_ALL], outputs=[weight_command_textbox])
            refresh_button.click(
                fn=refresh_modelB_dropdown,
                inputs=None,
                outputs=[model_B]
            )

            process_script_params.extend(sl_ALL_nat)
            process_script_params.append(model_B)
            process_script_params.append(enabled)

            with gr.Row():
                output_mode_radio = gr.Radio(label="Output Mode", choices=["Max Precision", "Runtime Snapshot"],
                                             value="Max Precision", type="value", interactive=True)
                position_id_fix_radio = gr.Radio(label="Skip/Reset CLIP position_ids",
                                                 choices=["Keep Original", "Fix"], value="Keep Original", type="value", interactive=True)
                output_format_radio = gr.Radio(label="Output Format",
                                               choices=[".ckpt", ".safetensors"], value=".ckpt", type="value",
                                               interactive=True)
            with gr.Row():
                output_recipe_checkbox = gr.Checkbox(label="Output Recipe", value=True, interactive=True)

            with gr.Row():
                save_checkpoint_name_textbox = gr.Textbox(label="New Checkpoint Name")
                save_checkpoint_button = gr.Button(value="Save Runtime Checkpoint", elem_id="mbw_save_checkpoint_button", variant='primary', interactive=True, visible=False, )

            def on_save_checkpoint(output_mode_radio, position_id_fix_radio, output_format_radio,
                                   save_checkpoint_name, output_recipe_checkbox, *weights,
                                   ):
                """Save the merged model next to model A, optionally with a
                .recipe.txt; returns the (possibly auto-named) checkpoint name."""
                current_weights_nat = weights[:27]   # natural order, for merging
                weights_output_recipe = weights[27:]  # GUI order, for the recipe file
                if not save_checkpoint_name:
                    # Auto-name with a timestamp when the textbox is empty.
                    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
                    save_checkpoint_name = f"mbw_{timestamp_str}"
                save_checkpoint_namewext = save_checkpoint_name + output_format_radio
                loaded_sd_model_path = Path(shared.sd_model.sd_model_checkpoint)
                model_ext = loaded_sd_model_path.suffix
                if model_ext == '.ckpt':
                    model_A_raw_state_dict = torch.load(shared.sd_model.sd_model_checkpoint, map_location='cpu')
                    if 'state_dict' in model_A_raw_state_dict:
                        model_A_raw_state_dict = model_A_raw_state_dict['state_dict']
                elif model_ext == '.safetensors':
                    model_A_raw_state_dict = safetensors.torch.load_file(shared.sd_model.sd_model_checkpoint, device="cpu")
                else:
                    # Bug fix: previously fell through with model_A_raw_state_dict
                    # unbound, raising NameError below.
                    print(f'unsupported checkpoint extension {model_ext}, aborting save')
                    return gr.update()
                save_checkpoint_path = Path(shared.sd_model.sd_model_checkpoint).parent / save_checkpoint_namewext

                if output_mode_radio == 'Runtime Snapshot':
                    # Save exactly what the live UNet currently holds.
                    snapshot_state_dict = shared.sd_model.model.diffusion_model.state_dict()
                elif output_mode_radio == 'Max Precision':
                    # Re-merge from the cached full-precision sources.
                    snapshot_state_dict = shared.UNetBManager.model_state_construct(current_weights_nat)

                snapshot_state_dict_prefixed = {'model.diffusion_model.' + key: value for key, value in
                                                snapshot_state_dict.items()}
                if not set(snapshot_state_dict_prefixed.keys()).issubset(set(model_A_raw_state_dict.keys())):
                    print(
                        'warning: snapshot state_dict keys are not subset of model A state_dict keys, possible structural deviation')

                # Merged UNet overrides model A's entries; everything else
                # (VAE, text encoder, ...) is carried over from model A.
                combined_state_dict = {**model_A_raw_state_dict, **snapshot_state_dict_prefixed}
                if position_id_fix_radio == 'Fix':
                    combined_state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = torch.tensor([list(range(77))], dtype=torch.int64)

                if output_format_radio == '.ckpt':
                    state_dict_save = {'state_dict': combined_state_dict}
                    torch.save(state_dict_save, save_checkpoint_path)
                elif output_format_radio == '.safetensors':
                    safetensors.torch.save_file(combined_state_dict, save_checkpoint_path)

                if output_recipe_checkbox:
                    recipe_path = Path(shared.sd_model.sd_model_checkpoint).parent / f"{save_checkpoint_name}.recipe.txt"
                    with open(recipe_path, 'w') as f:
                        f.write(f"modelA={shared.sd_model.sd_model_checkpoint}\n")
                        f.write(f"modelB={shared.UNetBManager.modelB_path}\n")
                        f.write(f"position_id_fix={position_id_fix_radio}\n")
                        f.write(f"output_mode={output_mode_radio}\n")
                        f.write(f"{','.join([str(w) for w in weights_output_recipe])}\n")

                return gr.update(value=save_checkpoint_name)

            def on_change_force_cpu(force_cpu_flag):
                # Max Precision output needs the CPU fp32 caches, so it is
                # only offered while Force CPU is on.
                if not force_cpu_flag:
                    return gr.update(choices=["Runtime Snapshot"], value="Runtime Snapshot")
                else:
                    return gr.update(choices=["Max Precision", "Runtime Snapshot"], value="Max Precision")

            save_checkpoint_button.click(
                fn=on_save_checkpoint,
                inputs=[output_mode_radio, position_id_fix_radio, output_format_radio, save_checkpoint_name_textbox, output_recipe_checkbox, *sl_ALL_nat, *sl_ALL],
                outputs=[save_checkpoint_name_textbox],
                show_progress=True
            )
            force_cpu_checkbox.change(fn=on_change_force_cpu, inputs=[force_cpu_checkbox], outputs=[output_mode_radio])
            model_B.change(fn=handle_modelB_load, inputs=[model_B, force_cpu_checkbox, *sl_ALL_nat],
                           outputs=[model_B, enabled, force_cpu_checkbox, save_checkpoint_button, unload_button])
            unload_button.click(fn=handle_unload, inputs=[], outputs=[model_B, enabled, force_cpu_checkbox, save_checkpoint_button, unload_button])

        return process_script_params

    def process(self, p, *args):
        """Per-generation hook: re-apply any changed block weights before
        sampling starts."""
        gui_weights = args[:27]  # natural (state-dict) order — from sl_ALL_nat
        modelB = args[27]
        enabled = args[28]
        if not enabled:
            return
        if not shared.UNetBManager:
            # Manager could not be created at startup; create it lazily now.
            shared.UNetBManager = UNetStateManager(shared.sd_model.model.diffusion_model)
        shared.UNetBManager.model_state_apply_modified_blocks(gui_weights, modelB)
|
|
|