Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import os | |
| import gc | |
| import json | |
| import shutil | |
| import requests | |
| from pathlib import Path | |
| from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login | |
| from safetensors.torch import load_file, save_file | |
| from safetensors import safe_open | |
| from tqdm import tqdm | |
# --- Constants & Setup ---
# Scratch directory that receives the downloaded LoRA file and base shards.
# cleanup_temp() removes and recreates it.
TempDir = Path("./temp_merge")
TempDir.mkdir(parents=True, exist_ok=True)

# Hub client shared by the repo-creation and upload calls in this file.
api = HfApi()
def info_log(msg, progress=None):
    """Print ``msg`` to stdout and return it unchanged.

    Args:
        msg: The message to print.
        progress: Accepted so existing call sites keep working; it is not
            used. The result never depended on it.

    Returns:
        The same ``msg`` object that was passed in.
    """
    print(msg)
    return msg
def cleanup_temp():
    """Empty the scratch directory and trigger a garbage-collection pass.

    The directory referenced by ``TempDir`` is deleted if present, then
    created again so later downloads always have a place to land.
    """
    scratch_present = TempDir.exists()
    if scratch_present:
        shutil.rmtree(TempDir)
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()
| # --- Core Logic --- | |
def download_lora(lora_input, hf_token):
    """Download a LoRA adapter from a direct URL or a Hub repo ID.

    Args:
        lora_input: Either an ``http://`` / ``https://`` URL pointing at a
            ``.safetensors`` file, or a Hub repo ID. Surrounding whitespace
            is ignored.
        hf_token: Token forwarded to the Hub calls. It is not sent with
            direct URL downloads.

    Returns:
        The local file location: a ``Path`` inside ``TempDir`` for URL
        downloads, otherwise whatever ``hf_hub_download`` returns.

    Raises:
        requests.HTTPError: The URL answered with an error status.
        ValueError: The repo contains no ``.safetensors`` file.
    """
    source = lora_input.strip()

    # Match full scheme prefixes so a repo ID that merely starts with "http"
    # is not mistaken for a URL.
    if source.lower().startswith(("http://", "https://")):
        local_path = TempDir / "adapter.safetensors"
        print(f"Downloading LoRA from URL: {source}")
        # The timeout keeps a stalled server from hanging the job forever;
        # the context manager releases the streamed connection.
        with requests.get(source, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        return local_path

    print(f"Downloading LoRA from Repo: {source}")
    try:
        return hf_hub_download(
            repo_id=source,
            filename="adapter_model.safetensors",
            token=hf_token,
            local_dir=TempDir,
        )
    except Exception as err:
        # Best effort: any failure on the conventional file name falls back
        # to scanning the repo. Unlike a bare ``except``, this lets
        # KeyboardInterrupt / SystemExit propagate.
        print(f"adapter_model.safetensors unavailable ({err}); scanning repo.")

    files = list_repo_files(repo_id=source, token=hf_token)
    all_safetensors = [name for name in files if name.endswith(".safetensors")]
    # Prefer files that look like adapters, then any safetensors file.
    safe_files = [name for name in all_safetensors if "adapter" in name] or all_safetensors
    if not safe_files:
        raise ValueError("Could not find a .safetensors file in the LoRA repo.")
    return hf_hub_download(
        repo_id=source,
        filename=safe_files[0],
        token=hf_token,
        local_dir=TempDir,
    )
def load_lora_weights(path):
    """Read the safetensors file at ``path`` onto the CPU.

    Returns:
        The object produced by ``load_file(path, device="cpu")``.
    """
    return load_file(path, device="cpu")
| def match_keys(base_key, lora_keys): | |
| matches = {} | |
| candidates = [k for k in lora_keys if base_key in k] | |
| pair_A = None | |
| pair_B = None | |
| for k in candidates: | |
| if "lora_A" in k or "lora_down" in k: | |
| pair_A = k | |
| elif "lora_B" in k or "lora_up" in k: | |
| pair_B = k | |
| return pair_A, pair_B | |
def copy_auxiliary_files(src_repo, tgt_repo, token, subfolder=""):
    """Copy every non-weight file from ``src_repo`` into ``tgt_repo``.

    Files ending in a known weight extension are left out. Each remaining
    file is downloaded, uploaded to the same path in the target model repo,
    and the downloaded copy is then removed. A failure on one file is
    printed and the loop moves on to the next.

    Args:
        src_repo: Repo ID to read files from.
        tgt_repo: Model repo ID to upload files to.
        token: Token passed to every Hub call.
        subfolder: Currently unused.
    """
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")

    weight_suffixes = (".safetensors", ".bin", ".pt", ".pth", ".msgpack", ".h5")
    repo_listing = list_repo_files(repo_id=src_repo, token=token)
    files_to_copy = [
        name for name in repo_listing if not name.endswith(weight_suffixes)
    ]

    for name in tqdm(files_to_copy, desc="Copying configs"):
        try:
            downloaded = hf_hub_download(repo_id=src_repo, filename=name, token=token)
            api.upload_file(
                path_or_fileobj=downloaded,
                path_in_repo=name,
                repo_id=tgt_repo,
                repo_type="model",
                token=token,
            )
            os.remove(downloaded)
        except Exception as e:
            print(f"Skipped {name}: {e}")
def _select_base_shards(repo_files, base_subfolder):
    """Return the ``.safetensors`` paths that live inside ``base_subfolder``.

    Whitespace and slashes around the folder name are ignored. The folder is
    matched as a whole path segment, so ``transformer`` does not also select
    ``transformer_2/...``. An empty folder selects every safetensors file.
    """
    folder = base_subfolder.strip().strip("/")
    prefix = f"{folder}/" if folder else ""
    return [
        name for name in repo_files
        if name.endswith(".safetensors") and name.startswith(prefix)
    ]


def _merge_lora_into_tensors(base_tensors, lora_state, lora_keys, scale, logs):
    """Apply matching LoRA deltas to one shard's tensors.

    Returns ``(tensors, merged_count)``. Tensors without a usable LoRA pair
    are passed through unchanged; skipped pairs are reported in ``logs``.
    """
    result = {}
    merged_count = 0
    for key, tensor in base_tensors.items():
        result[key] = tensor  # pass-through unless a merge succeeds below
        pair_A, pair_B = match_keys(key, lora_keys)
        if not (pair_A and pair_B):
            continue

        w_a = lora_state[pair_A].float()
        w_b = lora_state[pair_B].float()
        # ``w_b @ w_a`` is only meaningful for two matrices with a shared
        # inner dimension; anything else would raise and abort the whole
        # merge, so skip it like any other shape mismatch.
        if w_a.dim() != 2 or w_b.dim() != 2 or w_b.shape[1] != w_a.shape[0]:
            logs.append(f"Warning: Unsupported LoRA factor shapes for {key}. Skipping.")
            continue

        delta = (w_b @ w_a) * scale
        if delta.shape != tensor.shape:
            if delta.T.shape == tensor.shape:
                delta = delta.T
            else:
                logs.append(f"Warning: Shape mismatch for {key}. Skipping.")
                continue

        # Add in float32, then return to the checkpoint's own dtype.
        result[key] = (tensor.float() + delta).to(tensor.dtype)
        merged_count += 1
    return result, merged_count


def run_merge(
    hf_token,
    base_repo,
    base_subfolder,
    structure_repo,
    lora_input,
    scale,
    output_repo,
    is_private,
    progress=gr.Progress()
):
    """Merge a LoRA into a base model shard by shard and upload the result.

    Steps: create the output repo, optionally copy non-weight files from
    ``structure_repo``, download the LoRA, then for every safetensors shard
    of the base repo download it, add ``(B @ A) * scale`` to each matched
    tensor, and upload the shard under its original path. Shards without a
    match are uploaded unchanged. ``scale`` is the only multiplier applied
    to the delta.

    Args:
        hf_token: Hub token used for every Hub call.
        base_repo: Repo ID of the base model.
        base_subfolder: Optional folder inside ``base_repo`` to restrict to.
        structure_repo: Optional repo ID whose non-weight files are copied.
        lora_input: Repo ID or direct URL of the LoRA file.
        scale: Multiplier applied to each LoRA delta.
        output_repo: Repo ID that receives the merged model.
        is_private: Whether the output repo is created private.
        progress: Gradio progress callback.

    Returns:
        The accumulated log as one newline-joined string. Errors are
        reported in the log instead of being raised.
    """
    cleanup_temp()
    logs = []
    try:
        # NOTE(review): every Hub call below already passes ``token``
        # explicitly. login() presumably also stores the token for later
        # use - confirm that is acceptable on a shared Space.
        login(hf_token)
        logs.append(f"Logged in. Target: {output_repo}")

        # 1. Create Output Repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token)
            logs.append("Output repository ready.")
        except Exception as e:
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate Structure
        if structure_repo.strip():
            progress(0.1, desc="Cloning Model Structure...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load LoRA
        progress(0.2, desc="Downloading LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        lora_state = load_lora_weights(lora_path)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded. Found {len(lora_keys)} tensors.")

        # 4. Identify Base Shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)
        target_shards = _select_base_shards(all_files, base_subfolder)
        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process Shards
        total_shards = len(target_shards)
        merged_count = 0
        output_path = TempDir / "processed.safetensors"
        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + (0.6 * (idx / total_shards)), desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)
            base_tensors = load_file(local_shard, device="cpu")

            modified_tensors, shard_merged = _merge_lora_into_tensors(
                base_tensors, lora_state, lora_keys, scale, logs
            )
            merged_count += shard_merged

            if shard_merged:
                logs.append(f"Merging complete for shard. Saving...")
                # Carry the source header metadata (e.g. the "format" entry)
                # over to the rewritten shard; save_file writes none unless
                # it is given.
                with safe_open(local_shard, framework="pt", device="cpu") as handle:
                    shard_metadata = handle.metadata()
                save_file(modified_tensors, output_path, metadata=shard_metadata or {"format": "pt"})
                api.upload_file(path_or_fileobj=output_path, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)
                logs.append(f"Uploaded {shard_file}")
            else:
                logs.append(f"No LoRA matches in this shard. Copying original...")
                api.upload_file(path_or_fileobj=local_shard, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)

            # Release this shard's memory and disk space before the next one.
            del base_tensors
            del modified_tensors
            gc.collect()
            os.remove(local_shard)
            if os.path.exists(output_path):
                os.remove(output_path)

        progress(1.0, desc="Done!")
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")
    except Exception as e:
        import traceback
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())
    finally:
        cleanup_temp()
    return "\n".join(logs)
# --- UI ---
# NOTE(review): nesting below is reconstructed from a flattened listing;
# confirm the row/group layout against the running app.
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""

# NOTE: Removed 'css' and 'theme' from gr.Blocks() to be compatible with latest Gradio versions.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # ⚡ Universal LoRA Merger & Reconstructor
        Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
        Optimized for CPU-only execution on Hugging Face Spaces.
        """
    )

    # Section 1: credentials and destination repo.
    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(
                label="HF Write Token",
                type="password",
                placeholder="hf_...",
            )
            output_repo = gr.Textbox(
                label="Target Output Repo",
                placeholder="username/Z-Image-Turbo-Custom",
            )
        is_private = gr.Checkbox(label="Private Repo", value=True)

    # Section 2: where the base weights come from.
    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(
                label="Base Model Repo",
                placeholder="e.g. ostris/Z-Image-De-Turbo",
            )
            base_subfolder = gr.Textbox(
                label="Subfolder (Optional)",
                placeholder="e.g. transformer",
                info="Only merge weights found inside this folder.",
            )

    # Section 3: the LoRA to merge and its strength.
    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(
                label="LoRA Source",
                placeholder="Repo ID OR Direct URL (http...)",
                info="Accepts direct .safetensors resolve links.",
            )
            scale = gr.Slider(
                label="Scale",
                minimum=-2.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
            )

    # Section 4: optional source for the non-weight files.
    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(
            label="Structure Source Repo",
            placeholder="e.g. Tongyi-MAI/Z-Image-Turbo",
            info="Copies all NON-weight files from here to output.",
        )

    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")
    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)

    # The input order must match run_merge's positional parameters.
    submit_btn.click(
        fn=run_merge,
        inputs=[
            hf_token,
            base_repo,
            base_subfolder,
            structure_repo,
            lora_input,
            scale,
            output_repo,
            is_private,
        ],
        outputs=output_log,
    )
if __name__ == "__main__":
    # Allow one request in the queue at a time.
    demo.queue(max_size=1)
    # CSS is now passed here in the launch method
    demo.launch(css=css)