AlekseyCalvin committed
Commit 1bdae5f · verified · 1 Parent(s): f578122

Create app.py

Files changed (1):
  1. app.py +436 -0
app.py ADDED
@@ -0,0 +1,436 @@
import gradio as gr
import torch
import os
import gc
import re
import shutil
import requests
import json
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm

# --- Constants & Setup ---
TempDir = Path("./temp_merge")
os.makedirs(TempDir, exist_ok=True)
api = HfApi()

def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()

# --- Core Logic ---

def download_lora(lora_input, hf_token):
    """Downloads a LoRA from a repo ID or a direct URL."""
    local_path = TempDir / "adapter.safetensors"

    if lora_input.startswith("http"):
        print(f"Downloading LoRA from URL: {lora_input}")
        response = requests.get(lora_input, stream=True)
        response.raise_for_status()
        with open(local_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        return local_path
    else:
        print(f"Downloading LoRA from repo: {lora_input}")
        try:
            return hf_hub_download(repo_id=lora_input, filename="adapter_model.safetensors", token=hf_token, local_dir=TempDir)
        except Exception:
            files = list_repo_files(repo_id=lora_input, token=hf_token)
            # Prioritize safetensors
            safe_files = [f for f in files if f.endswith(".safetensors")]
            if not safe_files:
                raise ValueError("Could not find a .safetensors file in the LoRA repo.")
            # Heuristic: pick the file that looks most like a model file
            target_file = safe_files[0]
            for f in safe_files:
                if "fp16" in f or "rank" in f:
                    target_file = f
                    break
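            # e.g. a (hypothetical) listing ["README.md", "extra.safetensors",
            # "my_style_rank32.safetensors"] would resolve to the "rank32" file.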
            return hf_hub_download(repo_id=lora_input, filename=target_file, token=hf_token, local_dir=TempDir)

def standardize_lora_config(lora_state_dict):
    """
    Analyzes the LoRA state dict and converts keys to a standardized Diffusers-compatible format.
    Handles 'lora_down' -> 'lora_A', prefix stripping, and alpha scaling.
    """
    standardized_dict = {}
    alphas = {}
    ranks = {}

    keys = list(lora_state_dict.keys())

    # 1. First pass: detect structure and alphas
    for key in keys:
        if "alpha" in key:
            # key example: diffusion_model.layers.24.feed_forward.w1.alpha
            stem = key.replace(".alpha", "")
            alphas[stem] = lora_state_dict[key].item() if isinstance(lora_state_dict[key], torch.Tensor) else lora_state_dict[key]

    print(f"Found {len(alphas)} alpha keys in LoRA.")

    # 2. Second pass: convert weights
    for key in keys:
        if "alpha" in key:
            continue

        tensor = lora_state_dict[key]
        new_key = key

        # --- Conversion logic (inspired by Diffusers lora_conversion_utils.py) ---

        # Strip common ComfyUI/internal prefixes
        prefixes_to_strip = ["diffusion_model.", "model.diffusion_model.", "lora_unet_"]
        for p in prefixes_to_strip:
            if new_key.startswith(p):
                new_key = new_key[len(p):]

        # Convert lora_down/up to lora_A/B
        is_down = "lora_down.weight" in new_key
        is_up = "lora_up.weight" in new_key

        if is_down:
            new_key = new_key.replace("lora_down.weight", "lora_A.weight")
            stem = key.split(".lora_down.weight")[0]
            ranks[stem] = tensor.shape[0]  # Down projection output dim is the rank
        elif is_up:
            new_key = new_key.replace("lora_up.weight", "lora_B.weight")

        # Z-Image-specific "feed_forward" vs "ff" discrepancies could be handled
        # here if necessary (the Z-Image base uses 'feed_forward', so stripping
        # the prefix is usually sufficient).

        standardized_dict[new_key] = tensor
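
    # Illustrative example of the renaming above (key pattern taken from the
    # comments; real checkpoints may differ):
    #   "diffusion_model.layers.24.feed_forward.w1.lora_down.weight"
    #       -> "layers.24.feed_forward.w1.lora_A.weight"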

    # 3. Third pass: embed scaling into the weights.
    # If we have alpha and rank, we can pre-multiply the weights so the merge
    # function just needs to do B @ A. Scale = alpha / rank.

    final_dict = {}
    for key, tensor in standardized_dict.items():
        # Find the corresponding stem to check for an alpha.
        # key is like: layers.24.feed_forward.w1.lora_A.weight
        if "lora_A.weight" in key:
            stem_suffix = ".lora_A.weight"
            is_A = True
        elif "lora_B.weight" in key:
            stem_suffix = ".lora_B.weight"
            is_A = False
        else:
            final_dict[key] = tensor
            continue

        # We would need to map the "new key" stem back to the "old key" stem to
        # find the alpha, which is tricky because we stripped prefixes. Simpler
        # approach: match the alpha by checking whether an alpha key ends with
        # the current key's structural part, e.g. alpha key
        # "diffusion_model.layers.24...w1" ends with "layers.24...w1".
        struct_part = key.replace(stem_suffix, "")

        found_alpha = None
        for a_key, a_val in alphas.items():
            if a_key.endswith(struct_part):
                found_alpha = a_val
                break

        if found_alpha is not None:
            # Rank: for lora_A it is tensor.shape[0], for lora_B it is tensor.shape[1].
            rank = tensor.shape[0] if is_A else tensor.shape[1]

            # Scale calculation: scale = alpha / rank.
            # We apply sqrt(scale) to both A and B so that B @ A is scaled by (alpha / rank).
            scale_factor = (found_alpha / rank) ** 0.5
            tensor = tensor * scale_factor
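
            # Worked example (hypothetical numbers): alpha = 16, rank = 32:
            #   scale_factor = (16 / 32) ** 0.5 ≈ 0.7071
            #   (0.7071 * B) @ (0.7071 * A) = 0.5 * (B @ A), i.e. the full
            #   alpha / rank scaling, split evenly across the two factors.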

        final_dict[key] = tensor

    return final_dict

def match_keys(base_key, lora_keys):
    """
    Robustly finds the best matching LoRA (A, B) pair for a base key.
    """
    # base_key example: layers.24.feed_forward.w1.weight
    # lora_key example: layers.24.feed_forward.w1.lora_A.weight

    base_stem = base_key.replace(".weight", "")

    pair_A = None
    pair_B = None

    # Exact stem match check
    candidate_A = f"{base_stem}.lora_A.weight"
    candidate_B = f"{base_stem}.lora_B.weight"

    if candidate_A in lora_keys and candidate_B in lora_keys:
        return candidate_A, candidate_B

    # Fuzzy match if the exact check fails.
    # This handles slight naming differences like an extra "processor" segment.
    matches = [k for k in lora_keys if base_stem in k]

    for k in matches:
        if "lora_A" in k:
            pair_A = k
        elif "lora_B" in k:
            pair_B = k

    if pair_A and pair_B:
        # Verify they belong to the same block,
        # e.g. ensure we don't match layer.24 to layer.2.
        prefix_A = pair_A.split(".lora_A")[0]
        prefix_B = pair_B.split(".lora_B")[0]
        if prefix_A == prefix_B:
            return pair_A, pair_B

    return None, None
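
# Usage sketch (keys taken from the examples in the docstring above):
#   match_keys("layers.24.feed_forward.w1.weight",
#              ["layers.24.feed_forward.w1.lora_A.weight",
#               "layers.24.feed_forward.w1.lora_B.weight"])
#   -> ("layers.24.feed_forward.w1.lora_A.weight",
#       "layers.24.feed_forward.w1.lora_B.weight")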

def copy_auxiliary_files(src_repo, tgt_repo, token):
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
    try:
        files = list_repo_files(repo_id=src_repo, token=token)
        files_to_copy = [
            f for f in files
            if not f.endswith(".safetensors")
            and not f.endswith(".bin")
            and not f.endswith(".pt")
            and not f.endswith(".pth")
            and not f.endswith(".msgpack")
            and not f.endswith(".h5")
        ]
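        # e.g. a (hypothetical) listing ["model_index.json", "vae/config.json",
        # "transformer/diffusion_pytorch_model.safetensors"] keeps only the first two.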

        for f in tqdm(files_to_copy, desc="Copying configs"):
            try:
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token)
                api.upload_file(
                    path_or_fileobj=local,
                    path_in_repo=f,
                    repo_id=tgt_repo,
                    repo_type="model",
                    token=token
                )
                os.remove(local)
            except Exception as e:
                print(f"Skipped {f}: {e}")
    except Exception as e:
        print(f"Error copying config files: {e}")

def run_merge(
    hf_token,
    base_repo,
    base_subfolder,
    structure_repo,
    lora_input,
    user_scale,
    output_repo,
    is_private,
    progress=gr.Progress()
):
    cleanup_temp()
    logs = []

    try:
        login(hf_token)
        logs.append(f"Logged in. Target: {output_repo}")

        # 1. Create the output repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token)
            logs.append("Output repository ready.")
        except Exception as e:
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate the structure
        if structure_repo.strip():
            progress(0.1, desc="Cloning Model Structure...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load and standardize the LoRA
        progress(0.2, desc="Downloading & Processing LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")

        lora_path = download_lora(lora_input, hf_token)
        raw_lora_state = load_file(lora_path, device="cpu")

        # STANDARDIZE: convert Comfy/Kohya keys to Diffusers keys & apply alpha
        lora_state = standardize_lora_config(raw_lora_state)
        lora_keys = list(lora_state.keys())

        logs.append(f"LoRA loaded & standardized. Found {len(lora_keys)} tensors.")
        if len(lora_keys) > 0:
            logs.append(f"Sample key: {lora_keys[0]}")

        # 4. Identify base shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)

        target_shards = []
        for f in all_files:
            if not f.endswith(".safetensors"):
                continue
            if base_subfolder.strip() and not f.startswith(base_subfolder.strip("/")):
                continue
            target_shards.append(f)
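        # e.g. with base_subfolder = "transformer", a (hypothetical) listing
        #   ["transformer/model-00001-of-00002.safetensors", "vae/model.safetensors"]
        # keeps only the "transformer/..." shard.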

        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process the shards
        total_shards = len(target_shards)
        merged_count = 0

        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + (0.6 * (idx / total_shards)), desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")

            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)

            # Load the base shard to CPU
            base_tensors = load_file(local_shard, device="cpu")
            modified_tensors = {}
            has_changes = False

            for key, tensor in base_tensors.items():
                pair_A, pair_B = match_keys(key, lora_keys)

                if pair_A and pair_B:
                    w_a = lora_state[pair_A].float()
                    w_b = lora_state[pair_B].float()
                    current_tensor = tensor.float()

                    # Apply the merge.
                    # Note: alpha scaling is already embedded in w_a/w_b by
                    # standardize_lora_config; only user_scale is applied here.

                    # Check shapes for a transpose requirement.
                    # Standard LoRA: B @ A
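                    # Shape sketch (usual convention; some checkpoints are
                    # transposed, hence the fallbacks below):
                    #   w_b: (out_features, rank) @ w_a: (rank, in_features)
                    #     -> delta: (out_features, in_features), matching the base weight.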
                    try:
                        delta = (w_b @ w_a) * user_scale
                    except RuntimeError:
                        # Shape mismatch fallback:
                        # sometimes LoRA weights are transposed relative to the base.
                        if w_a.shape[0] == w_b.shape[1]:
                            delta = (w_a @ w_b) * user_scale
                        else:
                            # Last ditch: try transposing B
                            delta = (w_b.T @ w_a) * user_scale

                    if delta.shape != current_tensor.shape:
                        if delta.T.shape == current_tensor.shape:
                            delta = delta.T
                        else:
                            # Log only once per shard to avoid spam
                            if not has_changes:
                                logs.append(f"Warning: Shape mismatch for {key}. Base: {current_tensor.shape}, Delta: {delta.shape}. Skipping.")
                            modified_tensors[key] = tensor
                            continue

                    modified_tensors[key] = (current_tensor + delta).to(tensor.dtype)
                    merged_count += 1
                    has_changes = True
                else:
                    modified_tensors[key] = tensor

            if has_changes:
                logs.append("Merging complete for shard. Saving...")
                output_path = TempDir / "processed.safetensors"
                save_file(modified_tensors, output_path)
                api.upload_file(path_or_fileobj=output_path, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)
                logs.append(f"Uploaded {shard_file}")
            else:
                logs.append("No LoRA matches in this shard. Copying original...")
                api.upload_file(path_or_fileobj=local_shard, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)

            # Cleanup
            del base_tensors
            del modified_tensors
            if 'delta' in locals():
                del delta
            gc.collect()
            os.remove(local_shard)
            if os.path.exists(TempDir / "processed.safetensors"):
                os.remove(TempDir / "processed.safetensors")

        progress(1.0, desc="Done!")
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")

    except Exception as e:
        import traceback
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())

    finally:
        cleanup_temp()

    return "\n".join(logs)

# --- UI ---

css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # ⚡ soonMERGE® for Weights & Adapters

        Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
        **New:** Auto-converts ComfyUI/Kohya LoRA formats (e.g. Z-Image) to match Diffusers base models on the fly.
        """
    )

    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
            output_repo = gr.Textbox(label="Target Output Repo", placeholder="username/Z-Image-Turbo-Merged")
        is_private = gr.Checkbox(label="Private Repo", value=True)

    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(label="Base Model Repo", placeholder="e.g. ostris/Z-Image-De-Turbo")
            base_subfolder = gr.Textbox(label="Subfolder (Optional)", placeholder="e.g. transformer", info="Only merge weights found inside this folder.")

    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(label="LoRA Source", placeholder="Repo ID OR Direct URL (http...)", info="Accepts direct .safetensors resolve links.")
            scale = gr.Slider(label="Scale", minimum=-2.0, maximum=2.0, value=1.0, step=0.1, info="Global multiplier (applied on top of the LoRA's internal alpha)")

    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(label="Structure Source Repo", placeholder="e.g. Tongyi-MAI/Z-Image-Turbo", info="Copies all NON-weight files from here to output.")

    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")

    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)

    submit_btn.click(
        fn=run_merge,
        inputs=[hf_token, base_repo, base_subfolder, structure_repo, lora_input, scale, output_repo, is_private],
        outputs=output_log
    )

if __name__ == "__main__":
    demo.queue(max_size=1).launch()