import gradio as gr
import torch
import os
import gc
import shutil
import requests
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# --- Constants & Setup ---
TempDir = Path("./temp_tool")
os.makedirs(TempDir, exist_ok=True)
api = HfApi()


def info_log(msg, progress=None):
    print(msg)
    return msg


def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()


# --- Utility Functions ---
def download_file(input_path, token, filename=None):
    """Downloads a file from a URL or an HF repo."""
    local_path = TempDir / (filename if filename else "model.safetensors")
    if input_path.startswith("http"):
        print(f"Downloading from URL: {input_path}")
        response = requests.get(input_path, stream=True)
        response.raise_for_status()
        with open(local_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    else:
        print(f"Downloading from Repo: {input_path}")
        if not filename:
            try:
                files = list_repo_files(repo_id=input_path, token=token)
                safetensors = [f for f in files if f.endswith(".safetensors")]
                filename = safetensors[0] if safetensors else "adapter_model.bin"
            except Exception:
                filename = "adapter_model.safetensors"
        hf_hub_download(repo_id=input_path, filename=filename, token=token,
                        local_dir=TempDir, local_dir_use_symlinks=False)
        downloaded_path = TempDir / filename
        if downloaded_path != local_path:
            shutil.move(downloaded_path, local_path)
    return local_path


def get_key_stem(key):
    """
    Normalizes a key to its structural stem.
    Aggressively strips known prefixes to align Comfy/Kohya/Diffusers keys.
    """
    # 1. Remove suffixes
    key = key.replace(".weight", "").replace(".bias", "")
    key = key.replace(".lora_down", "").replace(".lora_up", "")
    key = key.replace(".lora_A", "").replace(".lora_B", "")
    key = key.replace(".alpha", "")
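    # After step 2 below, keys reduce to bare module paths. Illustrative
    # examples (hypothetical keys, not from any specific checkpoint):
    #   "lora_unet_blocks_0_attn_qkv.lora_down.weight" -> "blocks_0_attn_qkv"
    #   "transformer.blocks.0.attn.to_q.weight"        -> "blocks.0.attn.to_q"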
    # 2. Remove common prefixes
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_"
    ]
    changed = True
    while changed:
        changed = False
        for p in prefixes:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key


# =================================================================================
# TAB 1: SMART MERGE (Fixes Z-Image QKV)
# =================================================================================
def load_lora_to_memory(lora_path):
    """Loads a LoRA and pre-calculates up/down pairs keyed by stem."""
    state_dict = load_file(lora_path, device="cpu")
    alphas = {}
    weights = {}
    for k, v in state_dict.items():
        if "alpha" in k:
            stem = get_key_stem(k)
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            weights[k] = v

    pairs = {}
    for k, v in weights.items():
        stem = get_key_stem(k)
        if stem not in pairs:
            pairs[stem] = {}
        if "lora_down" in k or "lora_A" in k:
            pairs[stem]["down"] = v.float()
            pairs[stem]["rank"] = v.shape[0]
        elif "lora_up" in k or "lora_B" in k:
            pairs[stem]["up"] = v.float()

    for stem in pairs:
        if stem in alphas:
            pairs[stem]["alpha"] = alphas[stem]
        elif "rank" in pairs[stem]:
            pairs[stem]["alpha"] = float(pairs[stem]["rank"])
        else:
            pairs[stem]["alpha"] = 1.0
    return pairs


def merge_shard_logic(base_path, lora_pairs, scale, output_path):
    base_state = load_file(base_path, device="cpu")
    modified_state = {}
    has_modifications = False

    # Pre-index LoRA stems for fast lookup
    lora_stems = set(lora_pairs.keys())

    for k, v in base_state.items():
        base_stem = get_key_stem(k)

        # 1. Direct match
        match = lora_pairs.get(base_stem)

        # 2. QKV match (the Z-Image fix): the base has split
        # `attention.to_q/k/v` but the LoRA has a fused `attention.qkv`.
        chunk_idx = -1
        if not match:
            for idx, proj in enumerate(("to_q", "to_k", "to_v")):
                if proj in base_stem:
                    qkv_stem = base_stem.replace(proj, "qkv")
                    if qkv_stem in lora_stems:
                        match = lora_pairs[qkv_stem]
                        chunk_idx = idx
                    break

        if match and "down" in match and "up" in match:
            down = match["down"]
            up = match["up"]

            # Handle Conv2d 1x1
            if len(v.shape) == 4 and len(down.shape) == 2:
                down = down.unsqueeze(-1).unsqueeze(-1)
                up = up.unsqueeze(-1).unsqueeze(-1)

            scaling = scale * (match["alpha"] / match["rank"])
            try:
                # Standard LoRA matmul (up @ down)
                if len(up.shape) == 4:
                    # Approximation for 1x1 convs
                    delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                else:
                    delta = up @ down
            except RuntimeError:
                delta = up.T @ down  # Fallback for transposed weights

            delta = delta * scaling

            # --- QKV chunking logic ---
            if chunk_idx >= 0:
                # The LoRA delta covers Q+K+V. We need to slice it,
                # assuming the output dim (dim 0) is stacked Q, K, V.
                total_out = delta.shape[0]
                chunk_size = total_out // 3
                start = chunk_idx * chunk_size
                end = start + chunk_size
                delta = delta[start:end, ...]
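                # Worked example (a sketch mirroring the equal-thirds
                # assumption above): for a fused qkv projection with
                # 3 * 4096 output features, delta has shape (12288, in_dim);
                # to_q takes rows 0:4096, to_k rows 4096:8192, and to_v
                # rows 8192:12288.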
# print(f"Splitting QKV for {k}: chunk {chunk_idx}") # Final Shape Check if delta.shape != v.shape: if delta.numel() == v.numel(): delta = delta.reshape(v.shape) else: print(f"Skipping {k}: Shape mismatch Base {v.shape} vs Delta {delta.shape}") modified_state[k] = v continue modified_state[k] = v.float() + delta modified_state[k] = modified_state[k].to(v.dtype) has_modifications = True else: modified_state[k] = v if has_modifications: save_file(modified_state, output_path) return True return False def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, progress=gr.Progress()): cleanup_temp() login(hf_token) try: api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token) except Exception as e: return f"Error creating repo: {e}" if structure_repo: print("Cloning structure...") try: files = list_repo_files(repo_id=structure_repo, token=hf_token) for f in files: if not f.endswith(".safetensors") and not f.endswith(".bin"): try: path = hf_hub_download(repo_id=structure_repo, filename=f, token=hf_token) api.upload_file(path_or_fileobj=path, path_in_repo=f, repo_id=output_repo, token=hf_token) except: pass except Exception as e: print(f"Structure clone warning: {e}") progress(0.1, desc="Loading LoRA...") lora_path = download_file(lora_input, hf_token) lora_pairs = load_lora_to_memory(lora_path) print(f"Loaded LoRA with {len(lora_pairs)} modules.") files = list_repo_files(repo_id=base_repo, token=hf_token) shards = [f for f in files if f.endswith(".safetensors")] if base_subfolder: shards = [f for f in shards if f.startswith(base_subfolder)] if not shards: return "Error: No model shards found in base repo." for i, shard in enumerate(shards): progress(0.2 + (0.8 * i/len(shards)), desc=f"Merging {shard}") print(f"Processing {shard}...") local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir) merged_path = TempDir / "merged.safetensors" success = merge_shard_logic(local_shard, lora_pairs, scale, merged_path) # Upload preserving directory structure api.upload_file(path_or_fileobj=merged_path if success else local_shard, path_in_repo=shard, repo_id=output_repo, token=hf_token) os.remove(local_shard) if merged_path.exists(): os.remove(merged_path) gc.collect() return f"Done! Model at https://huggingface.co/{output_repo}" # ================================================================================= # TAB 2: EXTRACT LORA # ================================================================================= def extract_lora(model_org, model_tuned, rank, conv_rank, clamp): try: org_state = load_file(model_org, device="cpu") tuned_state = load_file(model_tuned, device="cpu") except: return None, "Error: Could not load models." 
    lora_sd = {}
    print("Calculating diffs and running SVD...")
    for key in tqdm(org_state.keys()):
        if key not in tuned_state:
            continue

        # Calculate diff
        mat = tuned_state[key].float() - org_state[key].float()
        if torch.max(torch.abs(mat)) < 1e-4:
            continue
        if len(mat.shape) < 2:
            continue  # Skip biases/norms; LoRA only covers 2D/4D weights

        out_dim, in_dim = mat.shape[:2]
        is_conv = len(mat.shape) == 4
        rank_to_use = min(conv_rank if is_conv else rank, in_dim, out_dim)
        if is_conv:
            kh, kw = mat.shape[2], mat.shape[3]  # Keep kernel dims before flattening
            mat = mat.flatten(start_dim=1)

        try:
            # SVD
            U, S, Vh = torch.linalg.svd(mat, full_matrices=False)
            U = U[:, :rank_to_use]
            S = S[:rank_to_use]
            U = U @ torch.diag(S)
            Vh = Vh[:rank_to_use, :]

            # Clamp outliers (Kohya trick)
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp)
            low_val = -hi_val
            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)

            # Reshape
            if is_conv:
                U = U.reshape(out_dim, rank_to_use, 1, 1)
                Vh = Vh.reshape(rank_to_use, in_dim, kh, kw)
            else:
                U = U.reshape(out_dim, rank_to_use)
                Vh = Vh.reshape(rank_to_use, in_dim)

            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U
            lora_sd[f"{stem}.lora_down.weight"] = Vh
            lora_sd[f"{stem}.alpha"] = torch.tensor(rank_to_use).float()
        except Exception as e:
            print(f"SVD failed for {key}: {e}")

    out_path = TempDir / "extracted_lora.safetensors"
    save_file(lora_sd, out_path)
    return str(out_path), "Success"


def task_extract(hf_token, org_repo, tuned_repo, rank, output_repo):
    cleanup_temp()
    login(hf_token)
    print("Downloading Original...")
    org_path = download_file(org_repo, hf_token, "original.safetensors")
    print("Downloading Tuned...")
    tuned_path = download_file(tuned_repo, hf_token, "tuned.safetensors")
    path, msg = extract_lora(org_path, tuned_path, int(rank), int(rank), 0.99)
    if path:
        api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=path, path_in_repo="extracted_lora.safetensors",
                        repo_id=output_repo, token=hf_token)
        return "Extraction Done."
    return msg


# =================================================================================
# TAB 3: MERGE ADAPTERS (Post-Hoc EMA)
# =================================================================================
def merge_adapters_ema(lora_paths, beta, output_path):
    """
    Merges adapter checkpoints with a fixed-beta EMA, in the spirit of Kohya's
    lora_post_hoc_ema.py (which additionally supports a power-function beta
    schedule). Relies on the input list order as the temporal order.
    """
    if not lora_paths:
        return False

    print(f"Loading base: {lora_paths[0]}")
    base_state = load_file(lora_paths[0], device="cpu")

    # Convert to float32 for merging
    for k in base_state:
        if base_state[k].dtype.is_floating_point:
            base_state[k] = base_state[k].float()

    for path in lora_paths[1:]:
        print(f"Merging {path}...")
        current_state = load_file(path, device="cpu")
        # Simple EMA update: state = state * beta + new * (1 - beta)
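        # With a fixed beta, after n updates the first checkpoint keeps weight
        # beta**n and the j-th later checkpoint (j = 1..n) contributes
        # (1 - beta) * beta**(n - j), so later checkpoints dominate for small beta.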
        for k in base_state:
            if k in current_state:
                if "alpha" in k:
                    continue  # Alphas should match across checkpoints
                curr_val = current_state[k].float()
                base_state[k] = base_state[k] * beta + curr_val * (1 - beta)

    save_file(base_state, output_path)
    return True


def task_merge_adapters(hf_token, lora_urls, beta, output_repo):
    cleanup_temp()
    login(hf_token)
    urls = [url.strip() for url in lora_urls.split(",")]
    local_paths = []
    for i, url in enumerate(urls):
        if not url:
            continue
        print(f"Downloading Adapter {i + 1}...")
        # Handles both repo IDs and resolve URLs
        path = download_file(url, hf_token, f"adapter_{i}.safetensors")
        local_paths.append(path)

    out_path = TempDir / "merged_adapters.safetensors"
    success = merge_adapters_ema(local_paths, beta, out_path)
    if success:
        api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out_path, path_in_repo="merged_adapters_ema.safetensors",
                        repo_id=output_repo, token=hf_token)
        return "Adapter Merge Done."
    return "Error merging adapters."


# =================================================================================
# TAB 4: RESIZE LORA
# =================================================================================
def task_resize(hf_token, lora_input, new_rank, output_repo):
    cleanup_temp()
    login(hf_token)
    new_rank = int(new_rank)  # gr.Number yields floats; tensor slicing needs ints
    path = download_file(lora_input, hf_token)
    state = load_file(path, device="cpu")
    new_state = {}
    print("Resizing...")

    stems = set()
    for k in state.keys():
        stems.add(get_key_stem(k))

    for stem in tqdm(stems):
        down_key = None
        up_key = None
        # Fuzzy finder for the raw keys
        for k in state:
            if stem in k and ("lora_down" in k or "lora_A" in k):
                down_key = k
            if stem in k and ("lora_up" in k or "lora_B" in k):
                up_key = k

        if down_key and up_key:
            down = state[down_key].float()
            up = state[up_key].float()
            is_conv = len(down.shape) == 4
            if not is_conv:
                merged = up @ down
            else:
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)

            # Re-SVD at the target rank
            U, S, Vh = torch.linalg.svd(merged.flatten(1), full_matrices=False)
            U = U[:, :new_rank]
            S = S[:new_rank]
            U = U @ torch.diag(S)
            Vh = Vh[:new_rank, :]
            if is_conv:
                # Restore the 1x1 conv shapes expected under the original keys
                U = U.reshape(U.shape[0], new_rank, 1, 1)
                Vh = Vh.reshape(new_rank, -1, 1, 1)
            new_state[down_key] = Vh
            new_state[up_key] = U

            # Find alpha key
            for k in state:
                if stem in k and "alpha" in k:
                    new_state[k] = torch.tensor(new_rank).float()

    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized_lora.safetensors",
                    repo_id=output_repo, token=hf_token)
    return "Resize Done."
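# --- Optional local sanity check (a minimal sketch; not wired into the UI). ---
# Confirms that the truncated-SVD round trip used by task_resize reconstructs a
# product that is already rank r essentially exactly. The helper name and its
# default dimensions are hypothetical.
def _selftest_svd_roundtrip(out_dim=64, in_dim=32, r=4):
    down = torch.randn(r, in_dim)
    up = torch.randn(out_dim, r)
    merged = up @ down  # rank-r by construction
    U, S, Vh = torch.linalg.svd(merged, full_matrices=False)
    approx = (U[:, :r] @ torch.diag(S[:r])) @ Vh[:r, :]
    assert torch.allclose(merged, approx, atol=1e-4)
    return True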
# =================================================================================
# UI
# =================================================================================
css = """
.container { max-width: 900px; margin: auto; }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧰 SOONmerge® Toolkit")
    gr.Markdown("Includes: Smart QKV Un-fusing, Post-Hoc EMA, Adapter Merging, Resizing, and Extraction.")

    with gr.Tabs():
        # --- TAB 1 ---
        with gr.Tab("Merge LoRA into Base"):
            gr.Markdown("Supports Z-Image fused-QKV LoRAs -> split base.")
            t1_token = gr.Textbox(label="HF Token", type="password")
            with gr.Row():
                t1_base = gr.Textbox(label="Base Model Repo", placeholder="ostris/Z-Image-De-Turbo")
                t1_sub = gr.Textbox(label="Subfolder (Optional)", placeholder="transformer")
            with gr.Row():
                t1_lora = gr.Textbox(label="LoRA Repo/URL")
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=-1, maximum=2)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Structure Repo (Optional)", placeholder="Tongyi-MAI/Z-Image-Turbo")
            t1_private = gr.Checkbox(label="Private", value=True, visible=False)
            t1_btn = gr.Button("Merge")
            t1_log = gr.Textbox(label="Log", interactive=False)
            t1_btn.click(task_merge,
                         [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_out, t1_struct, t1_private],
                         t1_log)

        # --- TAB 2 ---
        with gr.Tab("Extract LoRA"):
            t2_token = gr.Textbox(label="HF Token", type="password")
            t2_org = gr.Textbox(label="Original Model Repo/URL")
            t2_tuned = gr.Textbox(label="Tuned Model Repo/URL")
            t2_rank = gr.Number(label="Rank", value=32)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_log = gr.Textbox(label="Log")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tuned, t2_rank, t2_out], t2_log)

        # --- TAB 3 ---
        with gr.Tab("Merge Adapters (EMA)"):
            gr.Markdown("Post-Hoc EMA Merge: combines multiple LoRAs into one file.")
            t3_token = gr.Textbox(label="HF Token", type="password")
            t3_urls = gr.Textbox(label="LoRA URLs (comma separated)",
                                 placeholder="http://...lora1.safetensors, http://...lora2.safetensors")
            t3_beta = gr.Slider(label="Beta (Decay)", value=0.95, minimum=0.0, maximum=1.0)
            t3_out = gr.Textbox(label="Output Repo")
            t3_btn = gr.Button("Merge Adapters")
            t3_log = gr.Textbox(label="Log")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_out], t3_log)

        # --- TAB 4 ---
        with gr.Tab("Resize LoRA"):
            t4_token = gr.Textbox(label="HF Token", type="password")
            t4_in = gr.Textbox(label="LoRA Repo/URL")
            t4_rank = gr.Number(label="Target Rank", value=8)
            t4_out = gr.Textbox(label="Output Repo")
            t4_btn = gr.Button("Resize")
            t4_log = gr.Textbox(label="Log")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_out], t4_log)

if __name__ == "__main__":
    demo.queue(max_size=1).launch()