import gradio as gr
import torch
import os
import gc
import re
import shutil
import requests
import json
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm

# --- Constants & Setup ---
TempDir = Path("./temp_merge")
os.makedirs(TempDir, exist_ok=True)
api = HfApi()


def cleanup_temp():
    """Wipe and recreate the scratch directory, then force a GC pass."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()


# --- Core Logic ---
def download_lora(lora_input, hf_token):
    """Download a LoRA adapter from a Hub repo ID or a direct URL.

    Args:
        lora_input: Either a repo id ("user/repo") or a direct http(s) URL
            to a .safetensors file.
        hf_token: Hugging Face token used for Hub downloads.

    Returns:
        Local filesystem path of the downloaded .safetensors file.

    Raises:
        ValueError: If the repo contains no .safetensors file.
        requests.HTTPError / requests.Timeout: On failed URL downloads.
    """
    local_path = TempDir / "adapter.safetensors"
    if lora_input.startswith("http"):
        print(f"Downloading LoRA from URL: {lora_input}")
        # Stream to disk so large adapters never need to fit in memory.
        # FIX: added a timeout — the original call could hang indefinitely.
        response = requests.get(lora_input, stream=True, timeout=60)
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        return local_path

    print(f"Downloading LoRA from Repo: {lora_input}")
    try:
        # Happy path: PEFT-style adapters use this canonical filename.
        return hf_hub_download(repo_id=lora_input, filename="adapter_model.safetensors",
                               token=hf_token, local_dir=TempDir)
    except Exception:
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt.
        files = list_repo_files(repo_id=lora_input, token=hf_token)
        # Prioritize safetensors
        safe_files = [f for f in files if f.endswith(".safetensors")]
        if not safe_files:
            raise ValueError("Could not find a .safetensors file in the LoRA repo.")
        # Heuristic: pick the one that looks most like a model file
        target_file = safe_files[0]
        for f in safe_files:
            if "fp16" in f or "rank" in f:
                target_file = f
                break
        return hf_hub_download(repo_id=lora_input, filename=target_file,
                               token=hf_token, local_dir=TempDir)


def standardize_lora_config(lora_state_dict):
    """Convert a LoRA state dict to a standardized Diffusers-compatible format.

    Handles 'lora_down' -> 'lora_A' / 'lora_up' -> 'lora_B' renames, strips
    common ComfyUI/Kohya prefixes, and pre-multiplies weights by
    sqrt(alpha / rank) so that a plain ``B @ A`` later already includes the
    alpha scaling.

    Args:
        lora_state_dict: Raw tensor dict loaded from the adapter file.

    Returns:
        Dict of renamed (and possibly rescaled) tensors.
    """
    standardized_dict = {}
    alphas = {}
    keys = list(lora_state_dict.keys())

    # 1. First Pass: detect alpha entries.
    for key in keys:
        if "alpha" in key:
            # key example: diffusion_model.layers.24.feed_forward.w1.alpha
            stem = key.replace(".alpha", "")
            value = lora_state_dict[key]
            alphas[stem] = value.item() if isinstance(value, torch.Tensor) else value
    print(f"Found {len(alphas)} alpha keys in LoRA.")

    # 2. Second Pass: rename weight keys.
    for key in keys:
        if "alpha" in key:
            continue
        tensor = lora_state_dict[key]
        new_key = key

        # --- Conversion Logic (Inspired by Diffusers lora_conversion_utils.py) ---
        # Strip common ComfyUI/Internal prefixes.
        prefixes_to_strip = ["diffusion_model.", "model.diffusion_model.", "lora_unet_"]
        for p in prefixes_to_strip:
            if new_key.startswith(p):
                new_key = new_key[len(p):]

        # Convert lora_down/up to lora_A/B.
        if "lora_down.weight" in new_key:
            new_key = new_key.replace("lora_down.weight", "lora_A.weight")
        elif "lora_up.weight" in new_key:
            new_key = new_key.replace("lora_up.weight", "lora_B.weight")

        standardized_dict[new_key] = tensor

    # 3. Third Pass: embed alpha/rank scaling into the weights, so the merge
    #    step only needs B @ A. Scale = alpha / rank, split as sqrt() on each
    #    of A and B so the product carries the full factor exactly once.
    final_dict = {}
    for key, tensor in standardized_dict.items():
        if "lora_A.weight" in key:
            stem_suffix = ".lora_A.weight"
            is_A = True
        elif "lora_B.weight" in key:
            stem_suffix = ".lora_B.weight"
            is_A = False
        else:
            final_dict[key] = tensor
            continue

        # We stripped prefixes from the new keys, so match the alpha by
        # checking whether the (still-prefixed) alpha stem ENDS WITH the
        # structural part of this key.
        # e.g. alpha stem "diffusion_model.layers.24...w1" ends with "layers.24...w1"
        struct_part = key.replace(stem_suffix, "")
        found_alpha = None
        for a_key, a_val in alphas.items():
            if a_key.endswith(struct_part):
                found_alpha = a_val
                break

        if found_alpha:
            # Rank is the down-projection's output dim: shape[0] of A,
            # shape[1] of B.
            rank = tensor.shape[0] if is_A else tensor.shape[1]
            scale_factor = (found_alpha / rank) ** 0.5
            tensor = tensor * scale_factor
        final_dict[key] = tensor
    return final_dict


def match_keys(base_key, lora_keys):
    """Find the (lora_A, lora_B) key pair matching a base-model weight key.

    Tries an exact stem match first, then a fuzzy substring match (to absorb
    small naming differences such as an extra "processor" segment), verifying
    that the fuzzy A/B candidates come from the same block.

    Args:
        base_key: e.g. "layers.24.feed_forward.w1.weight".
        lora_keys: Iterable of standardized LoRA keys.

    Returns:
        (key_A, key_B) on success, (None, None) when no pair is found.
    """
    base_stem = base_key.replace(".weight", "")

    # Exact stem match check.
    candidate_A = f"{base_stem}.lora_A.weight"
    candidate_B = f"{base_stem}.lora_B.weight"
    if candidate_A in lora_keys and candidate_B in lora_keys:
        return candidate_A, candidate_B

    # Fuzzy match if exact fails.
    pair_A = None
    pair_B = None
    matches = [k for k in lora_keys if base_stem in k]
    for k in matches:
        if "lora_A" in k:
            pair_A = k
        elif "lora_B" in k:
            pair_B = k

    if pair_A and pair_B:
        # Verify both halves belong to the same block, e.g. ensure we don't
        # pair layer.24's A with layer.2's B.
        prefix_A = pair_A.split(".lora_A")[0]
        prefix_B = pair_B.split(".lora_B")[0]
        if prefix_A == prefix_B:
            return pair_A, pair_B
    return None, None


def copy_auxiliary_files(src_repo, tgt_repo, token):
    """Copy all non-weight files (configs, tokenizer, scheduler, ...) between repos.

    Best-effort: individual file failures are logged and skipped so a missing
    or gated file does not abort the whole merge.
    """
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
    try:
        files = list_repo_files(repo_id=src_repo, token=token)
        weight_suffixes = (".safetensors", ".bin", ".pt", ".pth", ".msgpack", ".h5")
        files_to_copy = [f for f in files if not f.endswith(weight_suffixes)]
        for f in tqdm(files_to_copy, desc="Copying configs"):
            try:
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token)
                api.upload_file(
                    path_or_fileobj=local, path_in_repo=f,
                    repo_id=tgt_repo, repo_type="model", token=token
                )
                os.remove(local)
            except Exception as e:
                print(f"Skipped {f}: {e}")
    except Exception as e:
        print(f"Error copying config files: {e}")


def run_merge(
    hf_token, base_repo, base_subfolder, structure_repo,
    lora_input, user_scale, output_repo, is_private,
    progress=gr.Progress()
):
    """Merge a LoRA into a base model shard-by-shard and upload the result.

    Pipeline: create output repo -> optionally clone config files ->
    download & standardize the LoRA -> for each base shard, apply
    ``W += (B @ A) * user_scale`` on matching weights and upload.

    Returns:
        The accumulated process log as a single string (shown in the UI).
    """
    cleanup_temp()
    logs = []
    try:
        login(hf_token)
        logs.append(f"Logged in. Target: {output_repo}")

        # 1. Create Output Repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private,
                            exist_ok=True, token=hf_token)
            logs.append("Output repository ready.")
        except Exception as e:
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate Structure
        if structure_repo.strip():
            progress(0.1, desc="Cloning Model Structure...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load and Standardize LoRA
        progress(0.2, desc="Downloading & Processing LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        raw_lora_state = load_file(lora_path, device="cpu")
        # STANDARDIZE: Convert Comfy/Kohya keys to Diffusers keys & apply Alpha.
        lora_state = standardize_lora_config(raw_lora_state)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded & standardized. Found {len(lora_keys)} tensors.")
        if len(lora_keys) > 0:
            logs.append(f"Sample key: {lora_keys[0]}")

        # 4. Identify Base Shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)
        target_shards = []
        for f in all_files:
            if not f.endswith(".safetensors"):
                continue
            if base_subfolder.strip() and not f.startswith(base_subfolder.strip("/")):
                continue
            target_shards.append(f)
        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process Shards
        total_shards = len(target_shards)
        merged_count = 0
        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + (0.6 * (idx / total_shards)),
                     desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file,
                                          token=hf_token, local_dir=TempDir)

            # Load base to CPU to keep VRAM free.
            base_tensors = load_file(local_shard, device="cpu")
            modified_tensors = {}
            has_changes = False

            for key, tensor in base_tensors.items():
                pair_A, pair_B = match_keys(key, lora_keys)
                if pair_A and pair_B:
                    w_a = lora_state[pair_A].float()
                    w_b = lora_state[pair_B].float()
                    current_tensor = tensor.float()

                    # Alpha scaling is already embedded in w_a/w_b by
                    # standardize_lora_config; only user_scale is applied here.
                    # Standard LoRA delta: B @ A.
                    try:
                        delta = (w_b @ w_a) * user_scale
                    except RuntimeError:
                        # Shape mismatch fallback: some LoRA weights are
                        # transposed relative to the base layout.
                        if w_a.shape[0] == w_b.shape[1]:
                            delta = (w_a @ w_b) * user_scale
                        else:
                            # Last ditch: try transposing B.
                            delta = (w_b.T @ w_a) * user_scale

                    if delta.shape != current_tensor.shape:
                        if delta.T.shape == current_tensor.shape:
                            delta = delta.T
                        else:
                            # Log only once per shard to avoid spam.
                            if not has_changes:
                                logs.append(f"Warning: Shape mismatch for {key}. Base: {current_tensor.shape}, Delta: {delta.shape}. Skipping.")
                            modified_tensors[key] = tensor
                            continue

                    modified_tensors[key] = (current_tensor + delta).to(tensor.dtype)
                    merged_count += 1
                    has_changes = True
                else:
                    modified_tensors[key] = tensor

            if has_changes:
                logs.append("Merging complete for shard. Saving...")
                output_path = TempDir / "processed.safetensors"
                save_file(modified_tensors, str(output_path))
                api.upload_file(path_or_fileobj=output_path, path_in_repo=shard_file,
                                repo_id=output_repo, repo_type="model", token=hf_token)
                logs.append(f"Uploaded {shard_file}")
            else:
                logs.append("No LoRA matches in this shard. Copying original...")
                api.upload_file(path_or_fileobj=local_shard, path_in_repo=shard_file,
                                repo_id=output_repo, repo_type="model", token=hf_token)

            # Per-shard cleanup so peak memory stays at one shard's worth.
            del base_tensors
            del modified_tensors
            if 'delta' in locals():
                del delta
            gc.collect()
            os.remove(local_shard)
            if os.path.exists(TempDir / "processed.safetensors"):
                os.remove(TempDir / "processed.safetensors")

        progress(1.0, desc="Done!")
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")
    except Exception as e:
        import traceback
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())
    finally:
        cleanup_temp()
    return "\n".join(logs)


# --- UI ---
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""

# FIX: `css` belongs to gr.Blocks(); Blocks.launch() has no `css` parameter,
# so the original `launch(css=css)` raised a TypeError (and the style was
# never applied).
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # ⚡ soonMERGE® for Weights & Adapters
        Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
        **New:** Auto-converts ComfyUI/Kohya LoRA formats (e.g. Z-Image) to match Diffusers base models on the fly.
        """
    )
    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
            output_repo = gr.Textbox(label="Target Output Repo", placeholder="username/Z-Image-Turbo-Merged")
        is_private = gr.Checkbox(label="Private Repo", value=True)
    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(label="Base Model Repo", placeholder="e.g. ostris/Z-Image-De-Turbo")
            base_subfolder = gr.Textbox(label="Subfolder (Optional)", placeholder="e.g. transformer",
                                        info="Only merge weights found inside this folder.")
    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(label="LoRA Source", placeholder="Repo ID OR Direct URL (http...)",
                                    info="Accepts direct .safetensors resolve links.")
            scale = gr.Slider(label="Scale", minimum=-2.0, maximum=2.0, value=1.0, step=0.1,
                              info="Global multiplier (applied on top of LoRA's internal alpha)")
    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(label="Structure Source Repo", placeholder="e.g. Tongyi-MAI/Z-Image-Turbo",
                                    info="Copies all NON-weight files from here to output.")
    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")
    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)
    submit_btn.click(
        fn=run_merge,
        inputs=[hf_token, base_repo, base_subfolder, structure_repo,
                lora_input, scale, output_repo, is_private],
        outputs=output_log
    )

if __name__ == "__main__":
    demo.queue(max_size=1).launch()