rahul7star commited on
Commit
10685bc
·
verified ·
1 Parent(s): c347a24

Add keybased_modelmerger.py from 2vXpSwA7/iroiro-lora

Browse files
Files changed (1) hide show
  1. keybased_modelmerger.py +144 -0
keybased_modelmerger.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from safetensors.torch import safe_open
3
+ from modules import scripts, sd_models, shared
4
+ import gradio as gr
5
+ from modules.processing import process_images
6
+
7
+
8
+ class KeyBasedModelMerger(scripts.Script):
9
+ def title(self):
10
+ return "Key-based model merging"
11
+
12
+ def ui(self, is_txt2img):
13
+ model_names = sorted(sd_models.checkpoints_list.keys(), key=str.casefold)
14
+
15
+ model_a_dropdown = gr.Dropdown(
16
+ label="Model A", choices=model_names, value=model_names[0] if model_names else None
17
+ )
18
+ model_b_dropdown = gr.Dropdown(
19
+ label="Model B", choices=model_names, value=model_names[0] if model_names else None
20
+ )
21
+ model_c_dropdown = gr.Dropdown(
22
+ label="Model C (Add difference mode用)", choices=model_names, value=model_names[0] if model_names else None
23
+ )
24
+ keys_and_alphas_textbox = gr.Textbox(
25
+ label="マージするテンソルのキーとマージ比率 (部分一致, 1行に1つ, カンマ区切り)",
26
+ lines=5,
27
+ placeholder="例:\nmodel.diffusion_model.input_blocks.0,0.5\nmodel.diffusion_model.middle_block,0.3"
28
+ )
29
+ merge_checkbox = gr.Checkbox(label="モデルのマージを有効にする", value=True)
30
+ use_gpu_checkbox = gr.Checkbox(label="GPUを使用", value=True)
31
+ batch_size_slider = gr.Slider(minimum=1, maximum=500, step=1, value=250, label="KeyMgerge_BatchSize")
32
+ merge_mode_dropdown = gr.Dropdown(
33
+ label="Merge Mode",
34
+ choices=["Normal", "Add difference (B-C to Current)", "Add difference (A + (B-C) to Current)"],
35
+ value="Normal"
36
+ )
37
+
38
+ return [model_a_dropdown, model_b_dropdown, model_c_dropdown, keys_and_alphas_textbox,
39
+ merge_checkbox, use_gpu_checkbox, batch_size_slider, merge_mode_dropdown]
40
+
41
+ def run(self, p, model_a_name, model_b_name, model_c_name, keys_and_alphas_str,
42
+ merge_enabled, use_gpu, batch_size, merge_mode):
43
+ if not model_b_name:
44
+ print("Error: Model B is not selected.")
45
+ return p
46
+
47
+ try:
48
+ # 必要なモデルファイルだけを読み込む
49
+ if merge_mode == "Normal":
50
+ model_a_filename = sd_models.checkpoints_list[model_a_name].filename
51
+ model_b_filename = sd_models.checkpoints_list[model_b_name].filename
52
+ elif merge_mode == "Add difference (B-C to Current)":
53
+ model_b_filename = sd_models.checkpoints_list[model_b_name].filename
54
+ model_c_filename = sd_models.checkpoints_list[model_c_name].filename
55
+ elif merge_mode == "Add difference (A + (B-C) to Current)":
56
+ model_a_filename = sd_models.checkpoints_list[model_a_name].filename
57
+ model_b_filename = sd_models.checkpoints_list[model_b_name].filename
58
+ model_c_filename = sd_models.checkpoints_list[model_c_name].filename
59
+ else:
60
+ raise ValueError(f"Invalid merge mode: ")
61
+
62
+ except KeyError as e:
63
+ print(f"Error: Selected model is not found in checkpoints list. ")
64
+ return p
65
+
66
+ # マージ処理
67
+ if merge_enabled:
68
+ input_keys_and_alphas = []
69
+ for line in keys_and_alphas_str.split("\n"):
70
+ if "," in line:
71
+ key_part, alpha_str = line.split(",", 1)
72
+ try:
73
+ alpha = float(alpha_str)
74
+ input_keys_and_alphas.append((key_part, alpha))
75
+ except ValueError:
76
+ print(f"Invalid alpha value in line '', skipping...")
77
+
78
+ # state_dictからキーのリストを事前に作成
79
+ model_keys = list(shared.sd_model.state_dict().keys())
80
+
81
+ # 部分一致検索を行う
82
+ final_keys_and_alphas = {}
83
+ for key_part, alpha in input_keys_and_alphas:
84
+ for model_key in model_keys:
85
+ if key_part in model_key:
86
+ final_keys_and_alphas[model_key] = alpha
87
+
88
+ # デバイスの設定 (GPUかCPUか選べるようにする)
89
+ device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
90
+
91
+ # バッチ処理でキーをまとめて処理
92
+ batched_keys = list(final_keys_and_alphas.items())
93
+
94
+ # モデルファイルを開く
95
+ if merge_mode == "Normal":
96
+ with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
97
+ safe_open(model_b_filename, framework="pt", device=device) as f_b:
98
+ self._merge_models(f_a, f_b, None, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
99
+ elif merge_mode == "Add difference (B-C to Current)":
100
+ with safe_open(model_b_filename, framework="pt", device=device) as f_b, \
101
+ safe_open(model_c_filename, framework="pt", device=device) as f_c:
102
+ self._merge_models(None, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
103
+ elif merge_mode == "Add difference (A + (B-C) to Current)":
104
+ with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
105
+ safe_open(model_b_filename, framework="pt", device=device) as f_b, \
106
+ safe_open(model_c_filename, framework="pt", device=device) as f_c:
107
+ self._merge_models(f_a, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
108
+ else:
109
+ raise ValueError(f"Invalid merge mode: ")
110
+
111
+ # 必要に応じて process_images を実行
112
+ return process_images(p)
113
+
114
+ def _merge_models(self, f_a, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device):
115
+ # バッチごとに処理
116
+ for i in range(0, len(batched_keys), batch_size):
117
+ batch = batched_keys[i:i + batch_size]
118
+
119
+ # バッチでテンソルを取得
120
+ tensors_a = [f_a.get_tensor(key) for key, _ in batch] if f_a is not None else None
121
+ tensors_b = [f_b.get_tensor(key) for key, _ in batch] if f_b is not None else None
122
+ tensors_c = [f_c.get_tensor(key) for key, _ in batch] if f_c is not None else None
123
+ alphas = [final_keys_and_alphas[key] for key, _ in batch]
124
+
125
+ # マージ処理の実行
126
+ for j, (key, alpha) in enumerate(batch):
127
+ tensor_a = tensors_a[j] if tensors_a is not None else None
128
+ tensor_b = tensors_b[j] if tensors_b is not None else None
129
+ tensor_c = tensors_c[j] if tensors_c is not None else None
130
+
131
+ if merge_mode == "Normal":
132
+ merged_tensor = torch.lerp(tensor_a, tensor_b, alpha)
133
+ print(f"NomalMerged:{alpha}:{key}")
134
+ elif merge_mode == "Add difference (B-C to Current)":
135
+ merged_tensor = shared.sd_model.state_dict()[key] + alpha * (tensor_b - tensor_c)
136
+ print(f"(B-C to Current):{alpha}:{key}")
137
+ elif merge_mode == "Add difference (A + (B-C) to Current)":
138
+ merged_tensor = tensor_a + alpha * (tensor_b - tensor_c)
139
+ print(f"(A + (B-C) to Current):{alpha}:{key}")
140
+ else:
141
+ raise ValueError(f"Invalid merge mode: ")
142
+
143
+ shared.sd_model.state_dict()[key].copy_(merged_tensor.to(device))
144
+