# Soon_Merger / app2.py
# Hugging Face Space by AlekseyCalvin; commit f578122: "Rename app.py to app2.py" (11.2 kB).
import gradio as gr
import torch
import os
import gc
import json
import shutil
import requests
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
# --- Constants & Setup ---
# Scratch area for downloaded adapters and shards; cleanup_temp() wipes and recreates it.
TempDir = Path("./temp_merge")
TempDir.mkdir(parents=True, exist_ok=True)
# Shared Hub client; every call site passes token=... explicitly.
api = HfApi()
def info_log(msg, progress=None):
    """Print ``msg`` to stdout and return it unchanged.

    Args:
        msg: Message to log.
        progress: Accepted for call-site compatibility only; it does not
            affect what is printed or returned.

    Returns:
        The same ``msg`` object.
    """
    print(msg)
    return msg
def cleanup_temp():
    """Reset the scratch directory to an empty state, then run the garbage collector."""
    # Drop whatever a previous run left behind (adapters, shards, processed files).
    if TempDir.exists():
        shutil.rmtree(TempDir)
    # Recreate it so later downloads and saves have a destination.
    TempDir.mkdir(parents=True, exist_ok=True)
    # Encourage release of large tensors that are no longer referenced.
    gc.collect()
# --- Core Logic ---
def download_lora(lora_input, hf_token):
    """Download a LoRA ``.safetensors`` file and return its local path.

    Args:
        lora_input: Either a direct ``http://``/``https://`` URL to a
            ``.safetensors`` file, or a Hugging Face repo ID. Surrounding
            whitespace is ignored. Hugging Face ``/blob/`` viewer links are
            rewritten to the raw ``/resolve/`` form.
        hf_token: Hub token. Used for repo downloads, and for URL downloads
            only when the host is huggingface.co / hf.co, so it is never sent
            to third-party hosts.

    Returns:
        Path to the downloaded file under ``TempDir``. URL downloads return a
        ``Path``; repo downloads return the ``str`` from ``hf_hub_download``.

    Raises:
        requests.HTTPError: The URL download returned an error status.
        ValueError: The repo contains no ``.safetensors`` file.
    """
    from urllib.parse import urlparse

    lora_input = lora_input.strip()

    if lora_input.startswith(("http://", "https://")):
        local_path = TempDir / "adapter.safetensors"
        url = lora_input
        headers = {}
        if urlparse(url).hostname in ("huggingface.co", "hf.co"):
            # ".../blob/..." is the HTML file viewer; ".../resolve/..." serves the raw bytes.
            url = url.replace("/blob/", "/resolve/", 1)
            # Authenticate only against the Hub itself (needed for private repos).
            if hf_token:
                headers["Authorization"] = f"Bearer {hf_token}"
        print(f"Downloading LoRA from URL: {url}")
        # The timeout keeps a stalled host from hanging the worker forever.
        # The `with` block releases the streamed connection.
        with requests.get(url, headers=headers, stream=True, timeout=(10, 300)) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        return local_path

    # Repo ID download: try the standard PEFT file name first.
    print(f"Downloading LoRA from Repo: {lora_input}")
    try:
        return hf_hub_download(repo_id=lora_input, filename="adapter_model.safetensors", token=hf_token, local_dir=TempDir)
    except Exception:
        # No (reachable) adapter_model.safetensors: pick a .safetensors file from the
        # listing, preferring names containing "adapter". If the failure was an
        # auth/network error, list_repo_files below raises it again.
        files = list_repo_files(repo_id=lora_input, token=hf_token)
        safe_files = [f for f in files if f.endswith(".safetensors") and "adapter" in f]
        if not safe_files:
            safe_files = [f for f in files if f.endswith(".safetensors")]
        if not safe_files:
            raise ValueError("Could not find a .safetensors file in the LoRA repo.")
        return hf_hub_download(repo_id=lora_input, filename=safe_files[0], token=hf_token, local_dir=TempDir)
def load_lora_weights(path):
    """Read every tensor of the safetensors file at ``path`` onto the CPU.

    Returns:
        dict mapping tensor name to tensor, as produced by ``load_file``.
    """
    return load_file(path, device="cpu")
def match_keys(base_key, lora_keys):
    """Find the LoRA down/up weight keys that target the base tensor ``base_key``.

    Only ``*.weight`` base tensors can match. The base module path is
    ``base_key`` without its trailing ``.weight``.

    A LoRA key is parsed by splitting on ``.`` and locating the first
    component equal to ``lora_A``/``lora_down`` (down projection) or
    ``lora_B``/``lora_up`` (up projection). The components before it form the
    LoRA module path. The key matches when that path equals the base module
    path, or ends with it on a ``.`` boundary.

    Matching whole dotted components means:
    - prefixed adapter keys (``base_model.model.``, ``transformer.``) match
      bare base keys;
    - ``transformer_blocks.0`` does not match ``single_transformer_blocks.0``.

    Args:
        base_key: Tensor name from the base checkpoint,
            e.g. ``"model.layers.0.self_attn.q_proj.weight"``.
        lora_keys: Iterable of tensor names from the LoRA state dict.

    Returns:
        ``(pair_A, pair_B)``: the matching down- and up-projection keys, each
        ``None`` if absent. If several keys qualify for one role, the last one
        in ``lora_keys`` wins.
    """
    suffix = ".weight"
    if not base_key.endswith(suffix):
        # LoRA factors only update weight matrices (never biases, norms, ...).
        return None, None
    base_module = base_key[: -len(suffix)]

    pair_A = None
    pair_B = None
    for k in lora_keys:
        # Cheap C-level pre-filter; the exact component check follows.
        if base_module not in k:
            continue
        parts = k.split(".")
        marker_idx = next(
            (i for i, p in enumerate(parts) if p in ("lora_A", "lora_down", "lora_B", "lora_up")),
            None,
        )
        if marker_idx is None:
            continue  # e.g. "<module>.alpha" or other non-factor tensors
        module = ".".join(parts[:marker_idx])
        if module != base_module and not module.endswith("." + base_module):
            continue
        if parts[marker_idx] in ("lora_A", "lora_down"):
            pair_A = k
        else:
            pair_B = k
    return pair_A, pair_B
def copy_auxiliary_files(src_repo, tgt_repo, token, subfolder=""):
    """Copy every non-weight file from ``src_repo`` into ``tgt_repo``.

    Used to reconstruct a repository layout around merged weights: configs,
    tokenizer files, scheduler configs, ``model_index.json`` and so on. Files
    are copied one at a time and keep their paths. A file that fails to
    download or upload is reported on stdout and skipped (best effort)
    instead of aborting the whole copy.

    Args:
        src_repo: Repo ID to copy from.
        tgt_repo: Model repo ID to upload into.
        token: HF token used for both reading and writing.
        subfolder: If non-blank, only files inside this folder of
            ``src_repo`` are copied. Blank copies the whole repo.

    Returns:
        list[str]: Repo paths that were skipped because of an error.
    """
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
    # Weight files are never copied: merged weights come from the base repo.
    weight_suffixes = (".safetensors", ".bin", ".pt", ".pth", ".msgpack", ".h5", ".ckpt", ".gguf")
    prefix = (subfolder or "").strip().strip("/")

    files = list_repo_files(repo_id=src_repo, token=token)
    files_to_copy = [
        f for f in files
        if not f.endswith(weight_suffixes)
        and (not prefix or f.startswith(prefix + "/"))
    ]

    # Download under TempDir rather than the default shared HF cache. The path
    # removed below is then a plain file in the scratch dir, not a cache entry.
    download_dir = TempDir / "structure"
    skipped = []
    for f in tqdm(files_to_copy, desc="Copying configs"):
        try:
            local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=download_dir)
            api.upload_file(
                path_or_fileobj=local,
                path_in_repo=f,
                repo_id=tgt_repo,
                repo_type="model",
                token=token,
            )
            os.remove(local)
        except Exception as e:
            # Best effort: one unavailable file should not abort the whole copy.
            print(f"Skipped {f}: {e}")
            skipped.append(f)
    return skipped
def _select_base_shards(all_files, base_subfolder):
    """Return the ``.safetensors`` paths in ``all_files`` that lie inside ``base_subfolder``.

    A blank subfolder selects every ``.safetensors`` file. The subfolder is
    matched as a whole path component, so ``"text_encoder"`` does not also
    select ``"text_encoder_2/model.safetensors"``.
    """
    prefix = (base_subfolder or "").strip().strip("/")
    return [
        f for f in all_files
        if f.endswith(".safetensors") and (not prefix or f.startswith(prefix + "/"))
    ]


def _lora_delta(w_a, w_b, target):
    """Return the product ``w_b @ w_a`` shaped like ``target``, or ``None`` if incompatible.

    Both factors are first flattened to 2-D: ``w_a`` to ``(rank, -1)`` and
    ``w_b`` to ``(out, -1)``. This lets 4-D conv LoRA factors such as
    ``(r, in, k, k)`` / ``(out, r, 1, 1)`` go through the same path as linear
    ones. The 2-D product is returned:
    - as-is if it already has ``target.shape``;
    - transposed if its transpose has that shape;
    - reshaped if it has the same leading (output) dim and element count as
      ``target`` (e.g. a conv kernel).
    Any other combination, including mismatched ranks, yields ``None``.
    """
    if w_a.dim() < 2 or w_b.dim() < 2:
        return None
    a2 = w_a.reshape(w_a.shape[0], -1)
    b2 = w_b.reshape(w_b.shape[0], -1)
    if b2.shape[1] != a2.shape[0]:
        return None
    delta = b2 @ a2
    if delta.shape == target.shape:
        return delta
    if delta.T.shape == target.shape:
        return delta.T
    if target.dim() > 0 and target.shape[0] == delta.shape[0] and delta.numel() == target.numel():
        return delta.reshape(target.shape)
    return None


def _alpha_factor(lora_state, pair_A, rank):
    """Return ``alpha / rank`` when the adapter stores ``<module>.alpha`` for this factor, else 1.0.

    ``<module>`` is ``pair_A`` truncated before its ``lora_A``/``lora_down``
    component.
    """
    parts = pair_A.split(".")
    for i, part in enumerate(parts):
        if part in ("lora_A", "lora_down"):
            alpha_key = ".".join(parts[:i] + ["alpha"])
            break
    else:
        return 1.0
    alpha = lora_state.get(alpha_key)
    if alpha is None or rank == 0:
        return 1.0
    return float(alpha) / rank


def _merge_shard_tensors(base_tensors, lora_state, lora_keys, scale, logs):
    """Add matching LoRA deltas to one shard's tensors.

    Returns:
        ``(merged_tensors, merged_count)``. Tensors without a usable LoRA pair
        are passed through as the same objects. Shape mismatches are logged
        and skipped.
    """
    merged = {}
    merged_count = 0
    for key, tensor in base_tensors.items():
        pair_A, pair_B = match_keys(key, lora_keys)
        if not (pair_A and pair_B):
            merged[key] = tensor
            continue
        w_a = lora_state[pair_A].float()
        w_b = lora_state[pair_B].float()
        current_tensor = tensor.float()
        delta = _lora_delta(w_a, w_b, current_tensor)
        if delta is None:
            logs.append(f"Warning: Shape mismatch for {key}. Skipping.")
            merged[key] = tensor
            continue
        factor = scale * _alpha_factor(lora_state, pair_A, w_a.shape[0])
        merged[key] = (current_tensor + delta * factor).to(tensor.dtype)
        merged_count += 1
    return merged, merged_count


def run_merge(
    hf_token,
    base_repo,
    base_subfolder,
    structure_repo,
    lora_input,
    scale,
    output_repo,
    is_private,
    progress=gr.Progress()
):
    """Merge a LoRA into a base model's safetensors shards and upload the result.

    Pipeline:
    1. Create ``output_repo``.
    2. Optionally copy non-weight files from ``structure_repo``.
    3. Download the LoRA.
    4. For each ``.safetensors`` shard of ``base_repo`` (limited to
       ``base_subfolder`` when given), add
       ``scale * (alpha / rank) * (B @ A)`` to every matching weight.
       ``alpha / rank`` is 1 when the LoRA stores no ``.alpha`` tensors.
    5. Upload each shard under its original path. Shards without matches are
       uploaded unchanged.

    Args:
        hf_token: HF write token; passed explicitly to every Hub call.
        base_repo: Repo ID of the base weights.
        base_subfolder: Optional folder inside ``base_repo`` to restrict to.
        structure_repo: Optional repo whose non-weight files are copied over.
        lora_input: LoRA repo ID or direct URL (see ``download_lora``).
        scale: Multiplier applied to every LoRA delta.
        output_repo: Repo ID that receives the merged model.
        is_private: Visibility used when creating ``output_repo``.
        progress: Gradio progress tracker.

    Returns:
        str: The accumulated log, one message per line. Errors are reported in
        the log rather than raised.
    """
    import traceback

    # Normalize free-text inputs once; textboxes often carry stray whitespace.
    hf_token = (hf_token or "").strip()
    base_repo = (base_repo or "").strip()
    structure_repo = (structure_repo or "").strip()
    lora_input = (lora_input or "").strip()
    output_repo = (output_repo or "").strip()

    missing = [
        label
        for label, value in (
            ("HF Write Token", hf_token),
            ("Target Output Repo", output_repo),
            ("Base Model Repo", base_repo),
            ("LoRA Source", lora_input),
        )
        if not value
    ]
    if missing:
        return "Error: missing required field(s): " + ", ".join(missing)

    cleanup_temp()
    logs = []
    try:
        # Deliberately no login(): every Hub call below passes token=hf_token.
        # login() would persist this user's token on the server, where later
        # requests (possibly from other users of the Space) could pick it up.
        logs.append(f"Using provided token. Target: {output_repo}")

        # 1. Create Output Repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token)
        except Exception as e:
            logs.append(f"Error creating repo: {e}")
            return "\n".join(logs)
        logs.append("Output repository ready.")

        # 2. Replicate Structure
        if structure_repo:
            progress(0.1, desc="Cloning Model Structure...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load LoRA
        progress(0.2, desc="Downloading LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        lora_state = load_lora_weights(lora_path)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded. Found {len(lora_keys)} tensors.")
        if not any(k.endswith(".alpha") for k in lora_keys):
            # NOTE(review): PEFT keeps lora_alpha in adapter_config.json, which is not read here.
            logs.append(
                "Note: LoRA has no '.alpha' tensors, so deltas are scaled by Scale alone "
                "(for PEFT adapters, set Scale to lora_alpha / r to reproduce PEFT's scaling)."
            )

        # 4. Identify Base Shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)
        target_shards = _select_base_shards(all_files, base_subfolder)
        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process Shards
        # Shards go to their own subdirectory so they cannot overwrite the
        # adapter file download_lora() placed in TempDir.
        shard_dir = TempDir / "base"
        output_path = TempDir / "processed.safetensors"
        total_shards = len(target_shards)
        merged_count = 0
        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + 0.6 * (idx / total_shards), desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=shard_dir)
            # Keep the shard's header metadata (e.g. {"format": "pt"}). save_file
            # would otherwise write none, and some loaders check it.
            with safe_open(local_shard, framework="pt") as handle:
                metadata = handle.metadata()
            base_tensors = load_file(local_shard, device="cpu")
            modified_tensors, shard_merged = _merge_shard_tensors(
                base_tensors, lora_state, lora_keys, scale, logs
            )
            merged_count += shard_merged

            if shard_merged:
                logs.append("Merging complete for shard. Saving...")
                save_file(modified_tensors, output_path, metadata=metadata)
                api.upload_file(path_or_fileobj=output_path, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)
                logs.append(f"Uploaded {shard_file}")
                os.remove(output_path)
            else:
                logs.append("No LoRA matches in this shard. Copying original...")
                api.upload_file(path_or_fileobj=local_shard, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)

            # Release this shard's tensors and its file before fetching the next one.
            del base_tensors, modified_tensors
            gc.collect()
            os.remove(local_shard)

        progress(1.0, desc="Done!")
        if merged_count == 0:
            logs.append(
                "\nWARNING: No LoRA weights matched any base tensor; "
                "the uploaded shards are unmodified copies of the base."
            )
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")
    except Exception as e:
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())
    finally:
        cleanup_temp()
    return "\n".join(logs)
# --- UI ---
# Stylesheet forwarded to demo.launch(css=...) in the __main__ guard.
# NOTE(review): no component below sets elem_classes, so these selectors
# appear to match nothing as written; confirm before relying on them.
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""
# NOTE: Removed 'css' and 'theme' from gr.Blocks() to be compatible with latest Gradio versions.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# ⚡ Universal LoRA Merger & Reconstructor
Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
Optimized for CPU-only execution on Hugging Face Spaces.
"""
    )
    # Section 1: credentials and the destination repo (created by run_merge via api.create_repo).
    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
            output_repo = gr.Textbox(label="Target Output Repo", placeholder="username/Z-Image-Turbo-Custom")
            # Forwarded as `private=` to api.create_repo.
            is_private = gr.Checkbox(label="Private Repo", value=True)
    # Section 2: where the .safetensors shards to merge into are read from.
    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(label="Base Model Repo", placeholder="e.g. ostris/Z-Image-De-Turbo")
            base_subfolder = gr.Textbox(label="Subfolder (Optional)", placeholder="e.g. transformer", info="Only merge weights found inside this folder.")
    # Section 3: the adapter and the multiplier applied to each LoRA delta.
    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(label="LoRA Source", placeholder="Repo ID OR Direct URL (http...)", info="Accepts direct .safetensors resolve links.")
            scale = gr.Slider(label="Scale", minimum=-2.0, maximum=2.0, value=1.0, step=0.1)
    # Section 4: when non-blank, run_merge copies this repo's non-weight files into the output.
    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(label="Structure Source Repo", placeholder="e.g. Tongyi-MAI/Z-Image-Turbo", info="Copies all NON-weight files from here to output.")
    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")
    # Displays the log string returned by run_merge.
    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)
    # `inputs` are passed positionally, so this order must match run_merge's
    # parameters (hf_token, base_repo, base_subfolder, structure_repo,
    # lora_input, scale, output_repo, is_private). run_merge's `progress`
    # parameter is not wired here and keeps its gr.Progress() default.
    submit_btn.click(
        fn=run_merge,
        inputs=[hf_token, base_repo, base_subfolder, structure_repo, lora_input, scale, output_repo, is_private],
        outputs=output_log
    )
if __name__ == "__main__":
    # Enable request queuing (max_size=1), then start the app.
    # The stylesheet is supplied at launch time rather than to gr.Blocks().
    queued_demo = demo.queue(max_size=1)
    queued_demo.launch(css=css)