Spaces:
Running
Running
import gc
import json
import os
import re
import shutil
from pathlib import Path

import gradio as gr
import numpy as np
import requests
import torch
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm
# --- Constants & Setup ---
# Scratch directory shared by every task; wiped by cleanup_temp().
TempDir = Path("./temp_tool")
TempDir.mkdir(parents=True, exist_ok=True)
api = HfApi()
def info_log(msg, progress=None):
    """Print *msg* to stdout and return it unchanged.

    ``progress`` is accepted for call-site compatibility (Gradio progress
    callbacks) but never affected the result: the original body returned
    ``msg`` on both branches, so the dead conditional has been removed.
    """
    print(msg)
    return msg
def cleanup_temp():
    """Wipe and recreate the scratch directory, then force a GC pass."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()
# --- Utility Functions ---
def download_file(input_path, token, filename=None):
    """Fetch a model file from a direct URL or a Hugging Face repo.

    Args:
        input_path: HTTP(S) URL or a HF repo id.
        token: HF access token (ignored for plain URLs).
        filename: Target file name inside the repo. When omitted for a
            repo download, the first ``.safetensors`` file is used, falling
            back to the conventional adapter file names.

    Returns:
        Path to the downloaded file inside ``TempDir``.
    """
    local_path = TempDir / (filename if filename else "model.safetensors")
    if input_path.startswith("http"):
        print(f"Downloading from URL: {input_path}")
        response = requests.get(input_path, stream=True)
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    else:
        print(f"Downloading from Repo: {input_path}")
        if not filename:
            try:
                files = list_repo_files(repo_id=input_path, token=token)
                safetensors = [f for f in files if f.endswith(".safetensors")]
                filename = safetensors[0] if safetensors else "adapter_model.bin"
            except Exception:  # narrowed from bare except (kept best-effort fallback)
                filename = "adapter_model.safetensors"
        # BUGFIX: trust the path returned by hf_hub_download instead of
        # re-deriving it — with a nested `filename` the two can differ.
        downloaded_path = Path(hf_hub_download(
            repo_id=input_path, filename=filename, token=token,
            local_dir=TempDir, local_dir_use_symlinks=False))
        if downloaded_path != local_path:
            shutil.move(downloaded_path, local_path)
    return local_path
def get_key_stem(key):
    """
    Reduce a LoRA/base-model key to its structural stem.

    Strips the weight/bias suffixes, the LoRA up/down (A/B) and alpha
    markers, then repeatedly peels known framework prefixes so that
    Comfy / Kohya / Diffusers naming schemes align on the same stem.
    """
    # Drop suffix markers (applied in the same fixed order as before).
    for marker in (".weight", ".bias", ".lora_down", ".lora_up",
                   ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(marker, "")
    # Peel prefixes until no rule matches any more.
    known_prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_",
    )
    stripped = True
    while stripped:
        stripped = False
        for prefix in known_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key
# =================================================================================
# TAB 1: SMART MERGE (Fixes Z-Image QKV)
# =================================================================================
def load_lora_to_memory(lora_path):
    """Load a LoRA safetensors file and group its tensors by module stem.

    Returns a dict mapping each normalized stem to a dict with keys
    ``down``, ``up``, ``rank`` and ``alpha``.  Alpha falls back to the
    rank when not stored, and to 1.0 when the rank is unknown.
    """
    state_dict = load_file(lora_path, device="cpu")

    alphas, weights = {}, {}
    for key, tensor in state_dict.items():
        if "alpha" in key:
            stem = get_key_stem(key)
            alphas[stem] = tensor.item() if isinstance(tensor, torch.Tensor) else tensor
        else:
            weights[key] = tensor

    pairs = {}
    for key, tensor in weights.items():
        entry = pairs.setdefault(get_key_stem(key), {})
        if "lora_down" in key or "lora_A" in key:
            entry["down"] = tensor.float()
            entry["rank"] = tensor.shape[0]
        elif "lora_up" in key or "lora_B" in key:
            entry["up"] = tensor.float()

    # Resolve alpha per module: explicit value > rank > 1.0.
    for stem, entry in pairs.items():
        if stem in alphas:
            entry["alpha"] = alphas[stem]
        else:
            entry["alpha"] = float(entry["rank"]) if "rank" in entry else 1.0
    return pairs
def merge_shard_logic(base_path, lora_pairs, scale, output_path):
    """Merge LoRA deltas into a single base-model shard.

    Each base weight is matched to a LoRA pair by normalized stem. When
    the base model uses split attention projections (``to_q``/``to_k``/
    ``to_v``) but the LoRA was trained against a fused ``qkv`` weight,
    the computed delta is sliced into thirds along dim 0 (Q, K, V are
    assumed to be stacked there).

    Returns True (and writes *output_path*) when at least one weight was
    modified; otherwise False, with nothing written.
    """
    base_state = load_file(base_path, device="cpu")
    modified_state = {}
    has_modifications = False
    # Pre-index LoRA stems for fast membership tests.
    lora_stems = set(lora_pairs.keys())
    for k, v in base_state.items():
        base_stem = get_key_stem(k)
        # 1. Direct match on the normalized stem.
        match = lora_pairs.get(base_stem)
        # 2. Fused-QKV match (the Z-Image fix):
        #    base `attention.to_q` vs LoRA `attention.qkv`.
        chunk_idx = -1
        if not match:
            for idx, proj in enumerate(("to_q", "to_k", "to_v")):
                if proj in base_stem:
                    qkv_stem = base_stem.replace(proj, "qkv")
                    if qkv_stem in lora_stems:
                        match = lora_pairs[qkv_stem]
                        chunk_idx = idx
                    break  # mirrors the original elif chain: only the first tag is tried
        if match and "down" in match and "up" in match:
            down = match["down"]
            up = match["up"]
            # Promote 2D LoRA factors to 1x1 convs when the base is Conv2d.
            if len(v.shape) == 4 and len(down.shape) == 2:
                down = down.unsqueeze(-1).unsqueeze(-1)
                up = up.unsqueeze(-1).unsqueeze(-1)
            scaling = scale * (match["alpha"] / match["rank"])
            try:
                # Standard LoRA matmul (up @ down).
                if len(up.shape) == 4:
                    # Approximation valid for 1x1 kernels only.
                    delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                else:
                    delta = up @ down
            except RuntimeError:  # BUGFIX: narrowed from bare `except:` — only catch matmul shape errors
                delta = up.T @ down  # fallback for transposed weights
            delta = delta * scaling
            # --- QKV chunking: slice the fused delta down to this projection ---
            if chunk_idx >= 0:
                chunk_size = delta.shape[0] // 3
                start = chunk_idx * chunk_size
                delta = delta[start:start + chunk_size, ...]
            # Final shape check before applying.
            if delta.shape != v.shape:
                if delta.numel() == v.numel():
                    delta = delta.reshape(v.shape)
                else:
                    print(f"Skipping {k}: Shape mismatch Base {v.shape} vs Delta {delta.shape}")
                    modified_state[k] = v
                    continue
            # Accumulate in float32, then cast back to the shard's dtype.
            modified_state[k] = (v.float() + delta).to(v.dtype)
            has_modifications = True
        else:
            modified_state[k] = v
    if has_modifications:
        save_file(modified_state, output_path)
        return True
    return False
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, output_repo, structure_repo, private, progress=gr.Progress()):
    """Merge a LoRA into every safetensors shard of a base repo and upload.

    Optionally clones all non-weight files (configs, tokenizer, README…)
    from *structure_repo* into the output repo first, so the result is a
    complete, loadable model layout.
    """
    cleanup_temp()
    login(hf_token)
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"
    if structure_repo:
        print("Cloning structure...")
        try:
            files = list_repo_files(repo_id=structure_repo, token=hf_token)
            for f in files:
                if not f.endswith(".safetensors") and not f.endswith(".bin"):
                    try:
                        path = hf_hub_download(repo_id=structure_repo, filename=f, token=hf_token)
                        api.upload_file(path_or_fileobj=path, path_in_repo=f, repo_id=output_repo, token=hf_token)
                    except Exception:
                        # BUGFIX: narrowed from bare `except:` — structure files
                        # are best-effort, but KeyboardInterrupt must propagate.
                        pass
        except Exception as e:
            print(f"Structure clone warning: {e}")
    progress(0.1, desc="Loading LoRA...")
    lora_path = download_file(lora_input, hf_token)
    lora_pairs = load_lora_to_memory(lora_path)
    print(f"Loaded LoRA with {len(lora_pairs)} modules.")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    shards = [f for f in files if f.endswith(".safetensors")]
    if base_subfolder:
        shards = [f for f in shards if f.startswith(base_subfolder)]
    if not shards:
        return "Error: No model shards found in base repo."
    for i, shard in enumerate(shards):
        progress(0.2 + (0.8 * i / len(shards)), desc=f"Merging {shard}")
        print(f"Processing {shard}...")
        local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
        merged_path = TempDir / "merged.safetensors"
        success = merge_shard_logic(local_shard, lora_pairs, scale, merged_path)
        # Upload preserving the original directory structure; fall back to the
        # untouched shard when nothing in it matched the LoRA.
        api.upload_file(path_or_fileobj=merged_path if success else local_shard, path_in_repo=shard, repo_id=output_repo, token=hf_token)
        os.remove(local_shard)
        if merged_path.exists():
            os.remove(merged_path)
        gc.collect()
    return f"Done! Model at https://huggingface.co/{output_repo}"
# =================================================================================
# TAB 2: EXTRACT LORA
# =================================================================================
def extract_lora(model_org, model_tuned, rank, conv_rank, clamp):
    """Extract a LoRA from the difference between two checkpoints via SVD.

    Args:
        model_org: path to the original (base) safetensors file.
        model_tuned: path to the fine-tuned safetensors file.
        rank: LoRA rank for linear layers.
        conv_rank: LoRA rank for conv layers (was previously accepted but ignored).
        clamp: quantile used to clamp the extracted factors (Kohya trick).

    Returns:
        (path, message) — *path* is None when either model fails to load.
    """
    try:
        org_state = load_file(model_org, device="cpu")
        tuned_state = load_file(model_tuned, device="cpu")
    except Exception:  # narrowed from bare except
        return None, "Error: Could not load models."
    lora_sd = {}
    print("Calculating diffs and running SVD...")
    for key in tqdm(org_state.keys()):
        if key not in tuned_state:
            continue
        # Calculate diff
        mat = tuned_state[key].float() - org_state[key].float()
        # BUGFIX: 1D tensors (biases, norms) used to crash the 2-value
        # unpack below before the try block could catch anything.
        if mat.dim() < 2:
            continue
        if torch.max(torch.abs(mat)) < 1e-4:
            continue
        out_dim, in_dim = mat.shape[:2]
        is_conv = mat.dim() == 4
        # Remember kernel dims BEFORE flattening (needed to reshape Vh back).
        kh, kw = (mat.shape[2], mat.shape[3]) if is_conv else (1, 1)
        rank_to_use = min(conv_rank if is_conv else rank, in_dim, out_dim)
        if is_conv:
            mat = mat.flatten(start_dim=1)
        try:
            # SVD
            U, S, Vh = torch.linalg.svd(mat, full_matrices=False)
            U = U[:, :rank_to_use]
            S = S[:rank_to_use]
            U = U @ torch.diag(S)
            Vh = Vh[:rank_to_use, :]
            # Clamp outliers (Kohya trick)
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp)
            low_val = -hi_val
            U = U.clamp(low_val, hi_val)
            Vh = Vh.clamp(low_val, hi_val)
            # Reshape factors back to layer-shaped LoRA weights.
            if is_conv:
                U = U.reshape(out_dim, rank_to_use, 1, 1)
                # BUGFIX: reshape with the original kernel dims, not the
                # flattened matrix shape — the old code used mat.shape[0/1]
                # after flattening, so every conv layer raised and was dropped.
                Vh = Vh.reshape(rank_to_use, in_dim, kh, kw)
            else:
                U = U.reshape(out_dim, rank_to_use)
                Vh = Vh.reshape(rank_to_use, in_dim)
            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U
            lora_sd[f"{stem}.lora_down.weight"] = Vh
            lora_sd[f"{stem}.alpha"] = torch.tensor(rank_to_use).float()
        except Exception as e:
            print(f"SVD failed for {key}: {e}")
    out_path = TempDir / "extracted_lora.safetensors"
    save_file(lora_sd, out_path)
    return str(out_path), "Success"
def task_extract(hf_token, org_repo, tuned_repo, rank, output_repo):
    """Download two checkpoints, extract their LoRA diff, and upload it."""
    cleanup_temp()
    login(hf_token)
    print("Downloading Original...")
    org_path = download_file(org_repo, hf_token, "original.safetensors")
    print("Downloading Tuned...")
    tuned_path = download_file(tuned_repo, hf_token, "tuned.safetensors")
    # Same rank is used for both linear and conv layers here.
    path, msg = extract_lora(org_path, tuned_path, int(rank), int(rank), 0.99)
    if not path:
        return msg
    api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=path, path_in_repo="extracted_lora.safetensors", repo_id=output_repo, token=hf_token)
    return "Extraction Done."
# =================================================================================
# TAB 3: MERGE ADAPTERS (Post-Hoc EMA)
# =================================================================================
def merge_adapters_ema(lora_paths, beta, output_path):
    """
    EMA-style merge of several LoRA adapters into a single file.

    The first path seeds the accumulator; each subsequent adapter is
    blended in as ``acc = acc * beta + new * (1 - beta)``, i.e. a simple
    exponential moving average in file order. Alpha tensors are left
    untouched (they are expected to match across adapters).

    Returns True on success, False when no paths were given.
    """
    if not lora_paths:
        return False
    print(f"Loading base: {lora_paths[0]}")
    base_state = load_file(lora_paths[0], device="cpu")
    # Accumulate in float32 for numerical stability.
    for key in base_state:
        if base_state[key].dtype.is_floating_point:
            base_state[key] = base_state[key].float()
    for path in lora_paths[1:]:
        print(f"Merging {path}...")
        current_state = load_file(path, device="cpu")
        for key in base_state:
            if key not in current_state:
                continue
            if "alpha" in key:
                continue  # Alphas should match
            base_state[key] = base_state[key] * beta + current_state[key].float() * (1 - beta)
    save_file(base_state, output_path)
    return True
def task_merge_adapters(hf_token, lora_urls, beta, output_repo):
    """Download a comma-separated list of LoRAs, EMA-merge them, upload."""
    cleanup_temp()
    login(hf_token)
    local_paths = []
    for idx, url in enumerate(u.strip() for u in lora_urls.split(",")):
        if not url:
            continue
        print(f"Downloading Adapter {idx+1}...")
        # download_file handles both resolve URLs and repo ids.
        local_paths.append(download_file(url, hf_token, f"adapter_{idx}.safetensors"))
    out_path = TempDir / "merged_adapters.safetensors"
    if not merge_adapters_ema(local_paths, beta, out_path):
        return "Error merging adapters."
    api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out_path, path_in_repo="merged_adapters_ema.safetensors", repo_id=output_repo, token=hf_token)
    return "Adapter Merge Done."
# =================================================================================
# TAB 4: RESIZE LORA
# =================================================================================
def task_resize(hf_token, lora_input, new_rank, output_repo):
    """Resize a LoRA to a new rank via merge-then-SVD and upload it.

    Fixes over the previous version:
    - ``new_rank`` arrives from a Gradio Number as a float; it is now cast
      to int (float slice bounds raised TypeError).
    - Keys are matched by exact normalized stem instead of substring
      containment (which confused e.g. ``blocks.1`` with ``blocks.10``).
    - The old ``alpha/rank`` scaling is folded into the merged delta before
      re-decomposition, so setting the new alpha to the new rank preserves
      the LoRA's effective strength.
    - Conv (1x1) pairs are written back as 4-D tensors.
    """
    cleanup_temp()
    login(hf_token)
    new_rank = int(new_rank)
    path = download_file(lora_input, hf_token)
    state = load_file(path, device="cpu")
    new_state = {}
    print("Resizing...")
    stems = {get_key_stem(k) for k in state.keys()}
    for stem in tqdm(stems):
        down_key = up_key = alpha_key = None
        for k in state:
            if get_key_stem(k) != stem:
                continue
            if "lora_down" in k or "lora_A" in k:
                down_key = k
            elif "lora_up" in k or "lora_B" in k:
                up_key = k
            elif "alpha" in k:
                alpha_key = k
        if not (down_key and up_key):
            continue
        down = state[down_key].float()
        up = state[up_key].float()
        old_rank = down.shape[0]
        old_alpha = float(state[alpha_key]) if alpha_key is not None else float(old_rank)
        is_conv = down.dim() == 4
        if is_conv:
            # Only 1x1 kernels are supported by this reconstruction.
            merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
        else:
            merged = up @ down
        # Fold the old scaling in; new alpha == new rank then gives factor 1.
        merged = merged * (old_alpha / old_rank)
        # Re-SVD at the target rank.
        U, S, Vh = torch.linalg.svd(merged.flatten(1), full_matrices=False)
        rank_to_use = min(new_rank, S.shape[0])
        U = U[:, :rank_to_use] @ torch.diag(S[:rank_to_use])
        Vh = Vh[:rank_to_use, :]
        if is_conv:
            U = U.reshape(U.shape[0], rank_to_use, 1, 1)
            Vh = Vh.reshape(rank_to_use, down.shape[1], 1, 1)
        new_state[down_key] = Vh
        new_state[up_key] = U
        if alpha_key is not None:
            new_state[alpha_key] = torch.tensor(float(rank_to_use))
    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized_lora.safetensors", repo_id=output_repo, token=hf_token)
    return "Resize Done."
# =================================================================================
# UI
# =================================================================================
css = """
.container { max-width: 900px; margin: auto; }
"""
# BUGFIX: custom CSS must be passed to gr.Blocks(); Blocks.launch() has no
# `css` keyword, so the old `launch(css=css)` raised TypeError at startup.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧰 SOONmerge® Toolkit")
    gr.Markdown("Includes: Smart QKV Un-fusing, Post-Hoc EMA, Adapter Merging, Resizing, and Extraction.")
    with gr.Tabs():
        # --- TAB 1 ---
        with gr.Tab("Merge LoRA into Base"):
            gr.Markdown("Supports Z-Image Fused QKV LoRAs -> Split Base.")
            t1_token = gr.Textbox(label="HF Token", type="password")
            with gr.Row():
                t1_base = gr.Textbox(label="Base Model Repo", placeholder="ostris/Z-Image-De-Turbo")
                t1_sub = gr.Textbox(label="Subfolder (Optional)", placeholder="transformer")
            with gr.Row():
                t1_lora = gr.Textbox(label="LoRA Repo/URL")
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=-1, maximum=2)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Structure Repo (Optional)", placeholder="Tongyi-MAI/Z-Image-Turbo")
            t1_btn = gr.Button("Merge")
            t1_log = gr.Textbox(label="Log", interactive=False)
            # Hidden checkbox supplies the `private` flag to task_merge.
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_out, t1_struct, gr.Checkbox(value=True, visible=False)], t1_log)
        # --- TAB 2 ---
        with gr.Tab("Extract LoRA"):
            t2_token = gr.Textbox(label="HF Token", type="password")
            t2_org = gr.Textbox(label="Original Model Repo/URL")
            t2_tuned = gr.Textbox(label="Tuned Model Repo/URL")
            t2_rank = gr.Number(label="Rank", value=32)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_log = gr.Textbox(label="Log")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tuned, t2_rank, t2_out], t2_log)
        # --- TAB 3 ---
        with gr.Tab("Merge Adapters (EMA)"):
            gr.Markdown("Post-Hoc EMA Merge: Combined multiple LoRAs into one file.")
            t3_token = gr.Textbox(label="HF Token", type="password")
            t3_urls = gr.Textbox(label="LoRA URLs (comma separated)", placeholder="http://...lora1.safetensors, http://...lora2.safetensors")
            t3_beta = gr.Slider(label="Beta (Decay)", value=0.95, minimum=0.0, maximum=1.0)
            t3_out = gr.Textbox(label="Output Repo")
            t3_btn = gr.Button("Merge Adapters")
            t3_log = gr.Textbox(label="Log")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_out], t3_log)
        # --- TAB 4 ---
        with gr.Tab("Resize LoRA"):
            t4_token = gr.Textbox(label="HF Token", type="password")
            t4_in = gr.Textbox(label="LoRA Repo/URL")
            t4_rank = gr.Number(label="Target Rank", value=8)
            t4_out = gr.Textbox(label="Output Repo")
            t4_btn = gr.Button("Resize")
            t4_log = gr.Textbox(label="Log")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_out], t4_log)
if __name__ == "__main__":
    demo.queue(max_size=1).launch()