Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import os | |
| import gc | |
| import re | |
| import shutil | |
| import requests | |
| import json | |
| from pathlib import Path | |
| from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login | |
| from safetensors.torch import load_file, save_file | |
| from safetensors import safe_open | |
| from tqdm import tqdm | |
# --- Constants & Setup ---
# Scratch directory for downloaded LoRA files and base shards; created as a
# side effect of importing this module.
TempDir = Path("./temp_merge")
TempDir.mkdir(parents=True, exist_ok=True)

# Shared Hub client, used below for repository creation and file uploads.
api = HfApi()
def cleanup_temp():
    """Empty the scratch directory and run a garbage-collection pass.

    Any existing ``TempDir`` tree is deleted, an empty directory is created in
    its place, and ``gc.collect()`` is invoked afterwards.
    """
    scratch_present = TempDir.exists()
    if scratch_present:
        shutil.rmtree(TempDir)
    # Always leave an (empty) scratch directory behind for the next download.
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()
| # --- Core Logic --- | |
def download_lora(lora_input, hf_token):
    """Download a LoRA ``.safetensors`` file into the scratch directory.

    Args:
        lora_input: Either a direct ``http://`` / ``https://`` URL to a weights
            file, or a Hub repo id.
        hf_token: Hub token used for repo listing and repo downloads. It is not
            sent along with direct-URL requests.

    Returns:
        Path of the downloaded file inside ``TempDir``.

    Raises:
        requests.HTTPError: The direct URL answered with an error status.
        ValueError: The repo contains no ``.safetensors`` file.
    """
    lora_input = lora_input.strip()

    # Match full URL schemes only, so a repo id that merely begins with the
    # letters "http" is still treated as a repo id.
    if lora_input.startswith(("http://", "https://")):
        local_path = TempDir / "adapter.safetensors"
        print(f"Downloading LoRA from URL: {lora_input}")
        # The timeout keeps a stalled server from hanging the worker forever;
        # the context manager releases the connection even on failure.
        with requests.get(lora_input, stream=True, timeout=(10, 120)) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        return local_path

    print(f"Downloading LoRA from Repo: {lora_input}")
    try:
        return hf_hub_download(
            repo_id=lora_input,
            filename="adapter_model.safetensors",
            token=hf_token,
            local_dir=TempDir,
        )
    except Exception as err:
        # Not a bare ``except``: Ctrl-C / SystemExit must still propagate.
        print(f"adapter_model.safetensors unavailable ({err}); scanning repo for other weights.")

    repo_files = list_repo_files(repo_id=lora_input, token=hf_token)
    safe_files = [name for name in repo_files if name.endswith(".safetensors")]
    if not safe_files:
        raise ValueError("Could not find a .safetensors file in the LoRA repo.")

    # Heuristic: prefer a file whose name mentions "fp16" or "rank", otherwise
    # fall back to the first safetensors file listed.
    target_file = next(
        (name for name in safe_files if "fp16" in name or "rank" in name),
        safe_files[0],
    )
    return hf_hub_download(
        repo_id=lora_input,
        filename=target_file,
        token=hf_token,
        local_dir=TempDir,
    )
def standardize_lora_config(lora_state_dict):
    """Normalize LoRA key names and bake ``alpha / rank`` into the weights.

    * Strips the prefixes ``diffusion_model.``, ``model.diffusion_model.`` and
      ``lora_unet_`` from every key.
    * Renames ``lora_down.weight`` -> ``lora_A.weight`` and
      ``lora_up.weight`` -> ``lora_B.weight``.
    * For each module that ships a ``<module>.alpha`` entry, multiplies both
      matrices by ``sqrt(|alpha / rank|)`` (a negative sign is carried on
      ``lora_B``), so that ``B @ A`` is already scaled by ``alpha / rank``.
      Alpha entries themselves are dropped from the result.

    Args:
        lora_state_dict: Mapping of key name to tensor (alpha values may be
            tensors or plain numbers).

    Returns:
        A new dict of standardized key -> (possibly rescaled) tensor.
    """
    prefixes_to_strip = ("diffusion_model.", "model.diffusion_model.", "lora_unet_")
    alpha_suffix = ".alpha"
    a_suffix = ".lora_A.weight"
    b_suffix = ".lora_B.weight"

    def strip_prefixes(name):
        for prefix in prefixes_to_strip:
            if name.startswith(prefix):
                name = name[len(prefix):]
        return name

    # Pass 1: collect alphas, keyed by the *normalized* module path. Alpha keys
    # and weight keys go through the same prefix stripping, so the lookup below
    # is exact. A suffix match could pick up the alpha of a different module
    # whose path merely ends the same way (e.g. "noise_refiner.layers.1.x"
    # for "layers.1.x").
    alphas = {}
    for key, value in lora_state_dict.items():
        if key.endswith(alpha_suffix):
            module = strip_prefixes(key[: -len(alpha_suffix)])
            alphas[module] = value.item() if isinstance(value, torch.Tensor) else value
    print(f"Found {len(alphas)} alpha keys in LoRA.")

    # Pass 2: rename weights and fold the scale in.
    final_dict = {}
    for key, tensor in lora_state_dict.items():
        if key.endswith(alpha_suffix):
            continue

        new_key = strip_prefixes(key)
        if "lora_down.weight" in new_key:
            new_key = new_key.replace("lora_down.weight", "lora_A.weight")
        elif "lora_up.weight" in new_key:
            new_key = new_key.replace("lora_up.weight", "lora_B.weight")

        if new_key.endswith(a_suffix):
            module, is_a = new_key[: -len(a_suffix)], True
        elif new_key.endswith(b_suffix):
            module, is_a = new_key[: -len(b_suffix)], False
        else:
            final_dict[new_key] = tensor
            continue

        alpha = alphas.get(module)
        # ``is not None``: an explicit alpha of 0 must zero the update rather
        # than be mistaken for "no alpha present".
        if alpha is not None and tensor.dim() >= 2:
            # Rank is the output dim of A (down) and the input dim of B (up).
            rank = tensor.shape[0] if is_a else tensor.shape[1]
            if rank > 0:
                scale = alpha / rank
                # Split the magnitude evenly across A and B; put the sign on B
                # only so a negative alpha never yields a complex factor.
                factor = abs(scale) ** 0.5
                if scale < 0 and not is_a:
                    factor = -factor
                tensor = tensor * factor

        final_dict[new_key] = tensor

    return final_dict
def match_keys(base_key, lora_keys):
    """Find the LoRA A/B key pair that targets a base-model weight.

    An exact ``<stem>.lora_A.weight`` / ``<stem>.lora_B.weight`` pair wins.
    Otherwise a LoRA module path qualifies when it contains the base stem as a
    run of whole dot-separated segments (extra leading or trailing segments are
    tolerated). Of the qualifying modules that provide both matrices, the one
    with the shortest path is returned.

    Args:
        base_key: Base tensor name, e.g. ``layers.24.feed_forward.w1.weight``.
        lora_keys: Iterable container of standardized LoRA key names.

    Returns:
        ``(key_A, key_B)``, or ``(None, None)`` when no complete pair matches.
    """
    weight_suffix = ".weight"
    # Strip only a trailing ".weight"; ``str.replace`` would also eat the text
    # in the middle of a name.
    if base_key.endswith(weight_suffix):
        base_stem = base_key[: -len(weight_suffix)]
    else:
        base_stem = base_key
    if not base_stem:
        return None, None

    # Exact stem match check
    candidate_A = f"{base_stem}.lora_A.weight"
    candidate_B = f"{base_stem}.lora_B.weight"
    if candidate_A in lora_keys and candidate_B in lora_keys:
        return candidate_A, candidate_B

    # Fuzzy match: group candidates per LoRA module so A and B are only ever
    # paired within the same module, and an unrelated candidate cannot
    # overwrite one half of a valid pair.
    needle = f".{base_stem}."
    pairs = {}
    for k in lora_keys:
        if ".lora_A" in k:
            marker, slot = ".lora_A", 0
        elif ".lora_B" in k:
            marker, slot = ".lora_B", 1
        else:
            continue
        module = k.split(marker)[0]
        # Dot-delimited containment: "blocks.1.proj" must not match
        # "blocks.1.proj_out", nor "layers.2" match "layers.24".
        if needle not in f".{module}.":
            continue
        pairs.setdefault(module, [None, None])[slot] = k

    complete = [(module, ab) for module, ab in pairs.items() if ab[0] and ab[1]]
    if not complete:
        return None, None
    # Fewest extra characters around the stem = closest structural match;
    # ``min`` keeps the first-seen module on ties, so the result is stable.
    _, (pair_A, pair_B) = min(complete, key=lambda item: len(item[0]))
    return pair_A, pair_B
def copy_auxiliary_files(src_repo, tgt_repo, token):
    """Mirror every non-weight file from ``src_repo`` into ``tgt_repo``.

    Files ending in a known weight extension are left out. Each remaining file
    is downloaded, uploaded under the same repo path, and its local copy is
    removed. Failures are printed and never raised: a single bad file is
    skipped, and a failure to list the source repo ends the copy.
    """
    weight_extensions = (".safetensors", ".bin", ".pt", ".pth", ".msgpack", ".h5")
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
    try:
        repo_listing = list_repo_files(repo_id=src_repo, token=token)
        files_to_copy = [
            name for name in repo_listing if not name.endswith(weight_extensions)
        ]
        for name in tqdm(files_to_copy, desc="Copying configs"):
            try:
                downloaded = hf_hub_download(repo_id=src_repo, filename=name, token=token)
                api.upload_file(
                    path_or_fileobj=downloaded,
                    path_in_repo=name,
                    repo_id=tgt_repo,
                    repo_type="model",
                    token=token,
                )
                os.remove(downloaded)
            except Exception as file_err:
                # Best effort: one failing file must not stop the others.
                print(f"Skipped {name}: {file_err}")
    except Exception as listing_err:
        print(f"Error copying config files: {listing_err}")
def _lora_product(w_a, w_b):
    """Multiply a 2-D LoRA pair in the first orientation whose shapes line up.

    Tries ``B @ A``, then ``A @ B``, then ``B.T @ A``. Returns ``None`` when
    either tensor is not 2-D or no orientation is shape-compatible.
    """
    if w_a.dim() != 2 or w_b.dim() != 2:
        return None
    if w_b.shape[1] == w_a.shape[0]:
        return w_b @ w_a  # standard LoRA: up @ down
    if w_a.shape[1] == w_b.shape[0]:
        return w_a @ w_b  # pair stored transposed relative to the base
    if w_b.shape[0] == w_a.shape[0]:
        return w_b.T @ w_a  # only B stored transposed
    return None


def run_merge(
    hf_token,
    base_repo,
    base_subfolder,
    structure_repo,
    lora_input,
    user_scale,
    output_repo,
    is_private,
    progress=gr.Progress()
):
    """Merge a LoRA into every safetensors shard of a base repo and upload it.

    Steps: create the output repo, optionally copy non-weight files from
    ``structure_repo``, download and standardize the LoRA, then for each base
    shard add ``user_scale * (B @ A)`` to every matched weight and upload the
    shard under its original path. Shards with no match are uploaded unchanged.

    Args:
        hf_token: Hub token used for all Hub calls.
        base_repo: Repo id holding the base weights.
        base_subfolder: Optional folder; only shards inside it are processed.
        structure_repo: Optional repo whose non-weight files are copied first.
        lora_input: Repo id or direct URL of the LoRA.
        user_scale: Global multiplier applied on top of the LoRA's own alpha.
        output_repo: Repo id to create and upload into.
        is_private: Whether the output repo is created private.
        progress: Gradio progress callback.

    Returns:
        The accumulated log as one newline-joined string. Errors are reported
        in the log rather than raised.
    """
    cleanup_temp()
    logs = []
    try:
        # NOTE(review): login() stores the token beyond this call; every Hub
        # call below already passes ``token=`` explicitly, so on a shared
        # deployment this may be worth dropping. Confirm before changing.
        login(hf_token)
        logs.append(f"Logged in. Target: {output_repo}")

        # 1. Create Output Repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token)
            logs.append("Output repository ready.")
        except Exception as e:
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate Structure
        if structure_repo.strip():
            progress(0.1, desc="Cloning Model Structure...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load and Standardize LoRA
        progress(0.2, desc="Downloading & Processing LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        raw_lora_state = load_file(lora_path, device="cpu")
        # STANDARDIZE: Convert Comfy/Kohya keys to Diffusers keys & apply Alpha
        lora_state = standardize_lora_config(raw_lora_state)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded & standardized. Found {len(lora_keys)} tensors.")
        if lora_keys:
            logs.append(f"Sample key: {lora_keys[0]}")

        # 4. Identify Base Shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)
        # Compare against "<folder>/" so that "transformer" does not also
        # select "transformer_2/..." or a top-level "transformer.safetensors".
        subfolder = base_subfolder.strip().strip("/")
        folder_prefix = f"{subfolder}/" if subfolder else ""
        target_shards = [
            f for f in all_files
            if f.endswith(".safetensors") and f.startswith(folder_prefix)
        ]
        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process Shards
        total_shards = len(target_shards)
        merged_count = 0
        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + (0.6 * (idx / total_shards)), desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)
            # Load base to CPU
            base_tensors = load_file(local_shard, device="cpu")
            modified_tensors = {}
            has_changes = False
            warned = False  # emit at most one skip warning per shard
            delta = None

            for key, tensor in base_tensors.items():
                # Default: pass the tensor through untouched.
                modified_tensors[key] = tensor
                pair_A, pair_B = match_keys(key, lora_keys)
                if not (pair_A and pair_B):
                    continue

                # Alpha scaling is already embedded in the standardized
                # weights; only the user scale is applied here.
                w_a = lora_state[pair_A].float()
                w_b = lora_state[pair_B].float()
                delta = _lora_product(w_a, w_b)
                if delta is not None and delta.shape != tensor.shape and delta.T.shape == tensor.shape:
                    delta = delta.T

                if delta is None or delta.shape != tensor.shape:
                    # Skip this layer only; one odd pair must not abort the run.
                    if not warned:
                        if delta is None:
                            logs.append(
                                f"Warning: LoRA pair for {key} cannot be multiplied "
                                f"(A: {tuple(w_a.shape)}, B: {tuple(w_b.shape)}). Skipping."
                            )
                        else:
                            logs.append(f"Warning: Shape mismatch for {key}. Base: {tensor.shape}, Delta: {delta.shape}. Skipping.")
                        warned = True
                    continue

                # Accumulate in float32, then return to the shard's dtype.
                modified_tensors[key] = (tensor.float() + delta * user_scale).to(tensor.dtype)
                merged_count += 1
                has_changes = True

            if has_changes:
                logs.append(f"Merging complete for shard. Saving...")
                # Carry over the shard's header metadata; default to the
                # conventional {"format": "pt"} tag when the shard has none.
                with safe_open(local_shard, framework="pt", device="cpu") as handle:
                    shard_metadata = handle.metadata() or {"format": "pt"}
                output_path = TempDir / "processed.safetensors"
                save_file(modified_tensors, output_path, metadata=shard_metadata)
                api.upload_file(path_or_fileobj=output_path, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)
                logs.append(f"Uploaded {shard_file}")
            else:
                logs.append(f"No LoRA matches in this shard. Copying original...")
                api.upload_file(path_or_fileobj=local_shard, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)

            # cleanup: drop references before the next shard is loaded
            del base_tensors
            del modified_tensors
            delta = None
            gc.collect()
            os.remove(local_shard)
            if os.path.exists(TempDir / "processed.safetensors"):
                os.remove(TempDir / "processed.safetensors")

        progress(1.0, desc="Done!")
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")
    except Exception as e:
        import traceback
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())
    finally:
        cleanup_temp()
    return "\n".join(logs)
# --- UI ---
# Stylesheet handed to ``launch()`` at the bottom of the file.
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""

# Single-page form: four grouped input sections, a submit button and a
# read-only log box that receives the string returned by ``run_merge``.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # ⚡ soonMERGE® for Weights & Adapters
        Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
        **New:** Auto-converts ComfyUI/Kohya LoRA formats (e.g. Z-Image) to match Diffusers base models on the fly.
        """
    )

    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(
                label="HF Write Token",
                type="password",
                placeholder="hf_...",
            )
            output_repo = gr.Textbox(
                label="Target Output Repo",
                placeholder="username/Z-Image-Turbo-Merged",
            )
            is_private = gr.Checkbox(label="Private Repo", value=True)

    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(
                label="Base Model Repo",
                placeholder="e.g. ostris/Z-Image-De-Turbo",
            )
            base_subfolder = gr.Textbox(
                label="Subfolder (Optional)",
                placeholder="e.g. transformer",
                info="Only merge weights found inside this folder.",
            )

    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(
                label="LoRA Source",
                placeholder="Repo ID OR Direct URL (http...)",
                info="Accepts direct .safetensors resolve links.",
            )
            scale = gr.Slider(
                label="Scale",
                minimum=-2.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
                info="Global multiplier (applied on top of LoRA's internal alpha)",
            )

    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(
            label="Structure Source Repo",
            placeholder="e.g. Tongyi-MAI/Z-Image-Turbo",
            info="Copies all NON-weight files from here to output.",
        )

    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")
    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)

    # Order must follow run_merge's positional parameters.
    merge_inputs = [
        hf_token,
        base_repo,
        base_subfolder,
        structure_repo,
        lora_input,
        scale,
        output_repo,
        is_private,
    ]
    submit_btn.click(fn=run_merge, inputs=merge_inputs, outputs=output_log)

if __name__ == "__main__":
    # Queue holds at most one pending request.
    demo.queue(max_size=1).launch(css=css)