"""Universal LoRA Merger & Reconstructor.

Gradio app that merges a safetensors LoRA adapter into the shards of any
base model hosted on the Hugging Face Hub, optionally clones the non-weight
repository structure (configs, tokenizer, scheduler, ...) from a donor repo,
and uploads the merged result to a new target repository. Designed for
CPU-only execution (e.g. free HF Spaces): shards are processed one at a time
and memory is released aggressively.
"""

import gradio as gr
import torch
import os
import gc
import json
import shutil
import requests
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm

# --- Constants & Setup ---
TempDir = Path("./temp_merge")
TempDir.mkdir(exist_ok=True)
api = HfApi()

# Extensions treated as model weights (never copied by structure replication).
WEIGHT_EXTENSIONS = (".safetensors", ".bin", ".pt", ".pth", ".msgpack", ".h5")


def info_log(msg, progress=None):
    """Print *msg* and return it unchanged (legacy helper, currently unused)."""
    print(msg)
    return msg


def cleanup_temp():
    """Wipe and recreate the scratch directory, then force a GC pass."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    TempDir.mkdir(exist_ok=True)
    gc.collect()


# --- Core Logic ---
def download_lora(lora_input, hf_token):
    """Download a LoRA adapter from a direct URL or a Hub repo ID.

    Args:
        lora_input: Either an http(s) URL to a .safetensors file or a
            Hugging Face repo ID such as "user/my-lora".
        hf_token: HF token used for private-repo access (may be None).

    Returns:
        Local filesystem path of the downloaded .safetensors file.

    Raises:
        ValueError: If a repo contains no .safetensors file at all.
        requests.HTTPError: If a direct URL download fails.
    """
    local_path = TempDir / "adapter.safetensors"

    if lora_input.startswith("http"):
        # Direct URL download. Stream to disk so large adapters never sit
        # fully in memory; close the response when done.
        print(f"Downloading LoRA from URL: {lora_input}")
        with requests.get(lora_input, stream=True) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return local_path

    # Repo ID download: try the conventional PEFT filename first, then fall
    # back to scanning the repo listing for any adapter-like safetensors.
    print(f"Downloading LoRA from Repo: {lora_input}")
    try:
        return hf_hub_download(
            repo_id=lora_input,
            filename="adapter_model.safetensors",
            token=hf_token,
            local_dir=TempDir,
        )
    except Exception:  # FIX: was a bare `except:` (also swallowed KeyboardInterrupt)
        files = list_repo_files(repo_id=lora_input, token=hf_token)
        safe_files = [f for f in files if f.endswith(".safetensors") and "adapter" in f]
        if not safe_files:
            safe_files = [f for f in files if f.endswith(".safetensors")]
        if not safe_files:
            raise ValueError("Could not find a .safetensors file in the LoRA repo.")
        return hf_hub_download(
            repo_id=lora_input,
            filename=safe_files[0],
            token=hf_token,
            local_dir=TempDir,
        )


def load_lora_weights(path):
    """Load a safetensors file entirely onto CPU and return its tensor dict."""
    return load_file(path, device="cpu")


def match_keys(base_key, lora_keys):
    """Find the (lora_A/lora_down, lora_B/lora_up) key pair for a base tensor key.

    FIX: base-model keys usually end in ".weight" while LoRA keys look like
    "...q_proj.lora_A.weight", so the previous raw substring test
    (`base_key in k`) essentially never matched. We strip a trailing
    ".weight" before searching; this is a strict superset of the old
    matching, so existing matches are preserved.

    Args:
        base_key: Tensor name from the base model shard.
        lora_keys: Iterable of tensor names present in the LoRA state dict.

    Returns:
        Tuple (pair_A, pair_B); either element may be None when no match
        exists. If several candidates match, the last one seen wins
        (same behavior as the original implementation).
    """
    stem = base_key[: -len(".weight")] if base_key.endswith(".weight") else base_key
    pair_A = None
    pair_B = None
    for k in lora_keys:
        if stem not in k:
            continue
        if "lora_A" in k or "lora_down" in k:
            pair_A = k
        elif "lora_B" in k or "lora_up" in k:
            pair_B = k
    return pair_A, pair_B


def copy_auxiliary_files(src_repo, tgt_repo, token, subfolder=""):
    """Copy every non-weight file from *src_repo* into *tgt_repo*.

    Used to replicate repository "infrastructure" (model_index.json,
    scheduler/tokenizer configs, etc.). Failures on individual files are
    logged and skipped — a missing config must not abort the merge.
    `subfolder` is kept for interface compatibility (currently unused).
    """
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
    files = list_repo_files(repo_id=src_repo, token=token)
    # str.endswith accepts a tuple — one call replaces the old chain of
    # six `not f.endswith(...)` conditions.
    files_to_copy = [f for f in files if not f.endswith(WEIGHT_EXTENSIONS)]
    for f in tqdm(files_to_copy, desc="Copying configs"):
        try:
            local = hf_hub_download(repo_id=src_repo, filename=f, token=token)
            api.upload_file(
                path_or_fileobj=local,
                path_in_repo=f,
                repo_id=tgt_repo,
                repo_type="model",
                token=token,
            )
            os.remove(local)
        except Exception as e:
            print(f"Skipped {f}: {e}")


def _merge_shard(base_tensors, lora_state, lora_keys, scale, logs):
    """Merge LoRA deltas into one shard's tensors.

    Applies the standard LoRA merge W' = W + scale * (B @ A) in float32,
    casting back to the original dtype. Tensors without a matching A/B pair
    (or with an irreconcilable shape mismatch) pass through unchanged.

    Returns:
        Tuple (modified_tensors, merged_count) where merged_count is the
        number of tensors that actually received a delta.
    """
    modified = {}
    merged = 0
    for key, tensor in base_tensors.items():
        pair_A, pair_B = match_keys(key, lora_keys)
        # NOTE: the original code re-ran the exact same substring search here
        # as a "fallback"; that duplication has been removed.
        if not (pair_A and pair_B):
            modified[key] = tensor
            continue

        w_a = lora_state[pair_A].float()
        w_b = lora_state[pair_B].float()
        current = tensor.float()
        delta = (w_b @ w_a) * scale

        if delta.shape != current.shape:
            # Some trainers store transposed factors; try the transpose once.
            if delta.T.shape == current.shape:
                delta = delta.T
            else:
                logs.append(f"Warning: Shape mismatch for {key}. Skipping.")
                modified[key] = tensor
                continue

        modified[key] = (current + delta).to(tensor.dtype)
        merged += 1
    return modified, merged


def run_merge(
    hf_token,
    base_repo,
    base_subfolder,
    structure_repo,
    lora_input,
    scale,
    output_repo,
    is_private,
    progress=gr.Progress(),
):
    """Full merge pipeline, wired to the UI submit button.

    Steps: create the output repo, optionally clone structure files,
    download the LoRA, merge it into every matching base shard (one shard
    at a time to bound memory), and upload the results.

    Returns:
        The accumulated log as a single string (displayed in the UI);
        errors are reported in the log rather than raised.
    """
    cleanup_temp()
    logs = []
    try:
        login(hf_token)
        logs.append(f"Logged in. Target: {output_repo}")

        # 1. Create Output Repo
        try:
            api.create_repo(
                repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token
            )
            logs.append("Output repository ready.")
        except Exception as e:
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate Structure
        if structure_repo.strip():
            progress(0.1, desc="Cloning Model Structure...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load LoRA
        progress(0.2, desc="Downloading LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        lora_state = load_lora_weights(lora_path)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded. Found {len(lora_keys)} tensors.")

        # 4. Identify Base Shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)
        prefix = base_subfolder.strip().strip("/")
        target_shards = [
            f
            for f in all_files
            if f.endswith(".safetensors") and (not prefix or f.startswith(prefix))
        ]
        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process Shards
        total_shards = len(target_shards)
        merged_count = 0
        for idx, shard_file in enumerate(target_shards):
            progress(
                0.3 + (0.6 * (idx / total_shards)),
                desc=f"Processing Shard {idx+1}/{total_shards}",
            )
            logs.append(f"--- Processing {shard_file} ---")
            local_shard = hf_hub_download(
                repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir
            )
            base_tensors = load_file(local_shard, device="cpu")

            modified_tensors, shard_merged = _merge_shard(
                base_tensors, lora_state, lora_keys, scale, logs
            )
            merged_count += shard_merged

            if shard_merged:
                logs.append("Merging complete for shard. Saving...")
                output_path = TempDir / "processed.safetensors"
                save_file(modified_tensors, str(output_path))
                api.upload_file(
                    path_or_fileobj=output_path,
                    path_in_repo=shard_file,
                    repo_id=output_repo,
                    repo_type="model",
                    token=hf_token,
                )
                logs.append(f"Uploaded {shard_file}")
            else:
                logs.append("No LoRA matches in this shard. Copying original...")
                api.upload_file(
                    path_or_fileobj=local_shard,
                    path_in_repo=shard_file,
                    repo_id=output_repo,
                    repo_type="model",
                    token=hf_token,
                )

            # Release shard memory before downloading the next one — shards
            # can be several GB and this runs on CPU-only Spaces.
            del base_tensors, modified_tensors
            gc.collect()
            os.remove(local_shard)
            processed = TempDir / "processed.safetensors"
            if processed.exists():
                processed.unlink()

        progress(1.0, desc="Done!")
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")
    except Exception as e:
        import traceback

        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())
    finally:
        cleanup_temp()

    return "\n".join(logs)


# --- UI ---
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""

# FIX: custom CSS belongs on gr.Blocks(css=...). Passing it to launch() —
# as the previous version did — raises TypeError on current Gradio releases,
# because launch() has no `css` parameter.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # ⚡ Universal LoRA Merger & Reconstructor
        Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
        Optimized for CPU-only execution on Hugging Face Spaces.
        """
    )

    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
            output_repo = gr.Textbox(
                label="Target Output Repo", placeholder="username/Z-Image-Turbo-Custom"
            )
            is_private = gr.Checkbox(label="Private Repo", value=True)

    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(
                label="Base Model Repo", placeholder="e.g. ostris/Z-Image-De-Turbo"
            )
            base_subfolder = gr.Textbox(
                label="Subfolder (Optional)",
                placeholder="e.g. transformer",
                info="Only merge weights found inside this folder.",
            )

    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(
                label="LoRA Source",
                placeholder="Repo ID OR Direct URL (http...)",
                info="Accepts direct .safetensors resolve links.",
            )
            scale = gr.Slider(label="Scale", minimum=-2.0, maximum=2.0, value=1.0, step=0.1)

    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown(
            "*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*"
        )
        structure_repo = gr.Textbox(
            label="Structure Source Repo",
            placeholder="e.g. Tongyi-MAI/Z-Image-Turbo",
            info="Copies all NON-weight files from here to output.",
        )

    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")
    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)

    submit_btn.click(
        fn=run_merge,
        inputs=[
            hf_token,
            base_repo,
            base_subfolder,
            structure_repo,
            lora_input,
            scale,
            output_repo,
            is_private,
        ],
        outputs=output_log,
    )

if __name__ == "__main__":
    demo.queue(max_size=1).launch()