Soon_Merger / app.py
AlekseyCalvin's picture
Create app.py
58aaafc verified
raw
history blame
13.8 kB
import gradio as gr
import torch
import os
import gc
import json
import shutil
import requests
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
# --- Constants & Setup ---
# Scratch directory for downloaded LoRAs and shards; emptied by cleanup_temp().
TempDir = Path("./temp_merge")
TempDir.mkdir(parents=True, exist_ok=True)
# Shared Hub client; the calls made through it pass the user's token explicitly.
api = HfApi()
def info_log(msg, progress=None):
    """Print *msg* to stdout and return it unchanged.

    Args:
        msg: Message to print.
        progress: Accepted for call-site compatibility; it has no effect.

    Returns:
        The same *msg* object.
    """
    print(msg)
    return msg
def cleanup_temp():
    """Empty the scratch directory (recreating it) and trigger a garbage-collection pass."""
    scratch = TempDir
    if scratch.exists():
        shutil.rmtree(scratch)
    scratch.mkdir(parents=True, exist_ok=True)
    gc.collect()
# --- Core Logic ---
def download_lora(lora_input, hf_token):
    """Download a LoRA ``.safetensors`` file from a direct URL or a Hub repo.

    Args:
        lora_input: Either an ``http://``/``https://`` URL pointing at a
            ``.safetensors`` file, or a Hugging Face repo id. Surrounding
            whitespace is ignored.
        hf_token: Token used for Hub downloads (not sent with direct URL requests).

    Returns:
        Local path of the downloaded file inside ``TempDir`` (a ``Path`` for the
        URL case, the path returned by ``hf_hub_download`` for the repo case).

    Raises:
        requests.HTTPError: The URL download returned an error status.
        ValueError: The repo contains no ``.safetensors`` file.
    """
    lora_input = lora_input.strip()
    # Require a real scheme so repo ids that merely start with "http" (e.g.
    # "httpx-user/my-lora") are still treated as repo ids.
    if lora_input.startswith(("http://", "https://")):
        local_path = TempDir / "adapter.safetensors"
        print(f"Downloading LoRA from URL: {lora_input}")
        # The timeout keeps a stalled server from blocking the app indefinitely;
        # the context manager releases the connection even if writing fails.
        with requests.get(lora_input, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        return local_path

    print(f"Downloading LoRA from Repo: {lora_input}")
    # Standard PEFT file name first.
    try:
        return hf_hub_download(repo_id=lora_input, filename="adapter_model.safetensors", token=hf_token, local_dir=TempDir)
    except Exception:
        # Not present (or not downloadable): fall back to scanning the repo.
        # Catch Exception, not everything, so Ctrl-C / SystemExit still propagate.
        pass

    # Fallback for diffusion LoRAs, which use other file names: prefer files with
    # "adapter" in the path, then any .safetensors file.
    files = list_repo_files(repo_id=lora_input, token=hf_token)
    safe_files = [f for f in files if f.endswith(".safetensors") and "adapter" in f]
    if not safe_files:
        safe_files = [f for f in files if f.endswith(".safetensors")]
    if not safe_files:
        raise ValueError("Could not find a .safetensors file in the LoRA repo.")
    return hf_hub_download(repo_id=lora_input, filename=safe_files[0], token=hf_token, local_dir=TempDir)
def load_lora_weights(path):
    """Read every tensor of the LoRA ``.safetensors`` file at *path* onto the CPU.

    Returns the raw state dict (tensor name -> tensor). No rank or alpha
    metadata is derived here; callers work directly from the tensor names.
    """
    state_dict = load_file(path, device="cpu")
    return state_dict
# Tensor-name tags marking the LoRA down-projection (A) and up-projection (B) factors.
_LORA_A_TAGS = ("lora_A", "lora_down")
_LORA_B_TAGS = ("lora_B", "lora_up")


def _lora_owner(lora_key, tags):
    """Return the module path in front of ``.<tag>`` in *lora_key*, or None if no tag occurs."""
    for tag in tags:
        cut = lora_key.find("." + tag)
        if cut != -1:
            return lora_key[:cut]
    return None


def match_keys(base_key, lora_keys):
    """
    Find the LoRA A/B tensor names that belong to the base tensor *base_key*.

    The base key's module path is *base_key* without a trailing ``.weight``.
    A LoRA key belongs to it when the path in front of its ``.lora_A`` /
    ``.lora_down`` (A) or ``.lora_B`` / ``.lora_up`` (B) tag equals that module
    path, either exactly or with extra leading components (for example
    ``transformer.`` or ``base_model.model.``). Matching works on whole
    dot-separated components, so ``blocks.1.attn`` never matches
    ``blocks.1.attn_out`` and ``proj`` never matches ``out_proj``. Bias keys
    (``*.bias``) therefore never match.

    Exact module matches win over prefixed ones; among equally good matches
    the first in *lora_keys* wins.

    Returns:
        ``(pair_A, pair_B)``; either entry is None when no matching key exists.
    """
    module = base_key[: -len(".weight")] if base_key.endswith(".weight") else base_key
    dotted_module = "." + module
    best = {}  # side ("A"/"B") -> (rank, lora_key); lower rank is better
    for k in lora_keys:
        for side, tags in (("A", _LORA_A_TAGS), ("B", _LORA_B_TAGS)):
            owner = _lora_owner(k, tags)
            if owner is None:
                continue
            if owner == module:
                rank = 0
            elif owner.endswith(dotted_module):
                rank = 1
            else:
                continue
            if side not in best or rank < best[side][0]:
                best[side] = (rank, k)
    pair_A = best["A"][1] if "A" in best else None
    pair_B = best["B"][1] if "B" in best else None
    return pair_A, pair_B
# File extensions treated as model weights and therefore never copied as "configuration".
_WEIGHT_SUFFIXES = (
    ".safetensors",
    ".bin",
    ".pt",
    ".pth",
    ".msgpack",
    ".h5",
    ".ckpt",
    ".gguf",
    ".onnx",
)


def copy_auxiliary_files(src_repo, tgt_repo, token, subfolder=""):
    """Copy every non-weight file (configs, tokenizer, scheduler, ...) from one repo to another.

    Args:
        src_repo: Repo id to copy from.
        tgt_repo: Model repo id to copy into; it must already exist.
        token: HF token used for both downloads and the upload.
        subfolder: If non-empty, only files at or below this folder of *src_repo*
            are copied (their paths are kept unchanged in *tgt_repo*). Empty
            copies the whole repo.

    Behavior is best-effort: files that fail to download are skipped with a
    printed message, and a failed upload is printed rather than raised.
    Everything fetched is uploaded in a single commit.
    """
    # Local import: only this function builds commit operations.
    from huggingface_hub import CommitOperationAdd

    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")
    prefix = (subfolder or "").strip().strip("/")
    files = list_repo_files(repo_id=src_repo, token=token)
    files_to_copy = [
        f for f in files
        if not f.endswith(_WEIGHT_SUFFIXES)
        and (not prefix or f == prefix or f.startswith(prefix + "/"))
    ]

    # Stage the downloads under TempDir instead of deleting entries from the
    # shared HF cache, which would leave the cached blobs behind anyway.
    staging = TempDir / "aux_files"
    operations = []
    try:
        for f in tqdm(files_to_copy, desc="Copying configs"):
            try:
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=staging)
            except Exception as e:
                print(f"Skipped {f}: {e}")
                continue
            operations.append(CommitOperationAdd(path_in_repo=f, path_or_fileobj=local))

        if not operations:
            return
        # One commit for all files avoids creating a separate commit per small
        # config file.
        try:
            api.create_commit(
                repo_id=tgt_repo,
                repo_type="model",
                operations=operations,
                commit_message=f"Copy configuration files from {src_repo}",
                token=token,
            )
        except Exception as e:
            print(f"Skipped uploading {len(operations)} configuration files: {e}")
    finally:
        shutil.rmtree(staging, ignore_errors=True)
def _select_target_shards(all_files, base_subfolder):
    """Return the ``.safetensors`` paths of *all_files* that lie inside *base_subfolder*.

    A blank subfolder selects every ``.safetensors`` file. The subfolder must
    match whole path components, so ``transformer`` selects
    ``transformer/x.safetensors`` but not ``transformer_2/x.safetensors``; a
    full file path selects just that file.
    """
    prefix = (base_subfolder or "").strip().strip("/")
    return [
        f
        for f in all_files
        if f.endswith(".safetensors")
        and (not prefix or f == prefix or f.startswith(prefix + "/"))
    ]


def _lora_delta(w_a, w_b, target_shape):
    """Return ``B @ A`` shaped like *target_shape*, or None if the factors cannot produce it.

    Trailing dimensions are flattened first. This lets conv-style LoRAs stored as
    4-D tensors reduce to a plain matmul whose result is reshaped to the
    kernel's shape. For 2-D linear factors the flatten is a no-op. A transposed
    2-D delta is also accepted, since some training libraries store the factors
    transposed relative to the base weight.
    """
    b2d = w_b.reshape(w_b.shape[0], -1)
    a2d = w_a.reshape(w_a.shape[0], -1)
    if b2d.shape[1] != a2d.shape[0]:
        return None  # inner (rank) dimensions disagree; matmul would raise
    delta = b2d @ a2d
    if delta.shape == target_shape:
        return delta
    if len(target_shape) == 2 and delta.T.shape == target_shape:
        return delta.T
    if (
        len(target_shape) > 2
        and delta.shape[0] == target_shape[0]
        and delta.numel() == target_shape.numel()
    ):
        return delta.reshape(target_shape)
    return None


def _lora_alpha(lora_state, pair_A, rank):
    """Return ``(alpha / rank, alpha_key)`` when a ``<module>.alpha`` tensor is stored, else ``(1.0, None)``.

    The module path is the part of *pair_A* before ``.lora_``. Files without
    such a tensor are merged with factor 1.0. An alpha stored anywhere else
    (e.g. an adapter config file) is not read here.
    """
    cut = pair_A.find(".lora_")
    if cut == -1 or rank <= 0:
        return 1.0, None
    alpha_key = pair_A[:cut] + ".alpha"
    alpha = lora_state.get(alpha_key)
    if alpha is None or alpha.numel() != 1:
        return 1.0, None
    return float(alpha.item()) / rank, alpha_key


def _merge_shard(base_tensors, lora_state, lora_keys, scale, logs):
    """Apply every matching LoRA pair to the tensors of one shard.

    Returns ``(tensors, merged, used_keys)``:

    - the state dict to save (unmatched tensors are passed through as-is);
    - the number of merged layers;
    - the LoRA keys consumed.

    Shape mismatches are logged to *logs*, and the base tensor is kept unchanged.
    """
    out = {}
    merged = 0
    used_keys = set()
    for key, tensor in base_tensors.items():
        pair_A, pair_B = match_keys(key, lora_keys)
        if not (pair_A and pair_B):
            out[key] = tensor
            continue
        w_a = lora_state[pair_A].float()
        w_b = lora_state[pair_B].float()
        delta = _lora_delta(w_a, w_b, tensor.shape)
        if delta is None:
            logs.append(
                f"Warning: Shape mismatch for {key}. Base: {tensor.shape}, "
                f"LoRA A: {w_a.shape}, LoRA B: {w_b.shape}. Skipping."
            )
            out[key] = tensor
            continue
        factor, alpha_key = _lora_alpha(lora_state, pair_A, w_a.shape[0])
        # Compute in float32, then cast back to the shard's original dtype.
        out[key] = (tensor.float() + delta * (scale * factor)).to(tensor.dtype)
        used_keys.update(k for k in (pair_A, pair_B, alpha_key) if k)
        merged += 1
    return out, merged, used_keys


def run_merge(
    hf_token,
    base_repo,
    base_subfolder,
    structure_repo,
    lora_input,
    scale,
    output_repo,
    is_private,
    progress=gr.Progress()
):
    """Merge a LoRA into the safetensors shards of a base repo and upload the result.

    Args:
        hf_token: HF token with write access. Every Hub call receives it explicitly.
        base_repo: Repo id whose ``.safetensors`` shards receive the LoRA.
        base_subfolder: Optional folder inside *base_repo* restricting which shards are processed.
        structure_repo: Optional repo whose non-weight files are copied into the output first.
        lora_input: LoRA repo id or direct URL (see ``download_lora``).
        scale: Multiplier applied to every LoRA delta.
        output_repo: Target repo id; created if missing.
        is_private: Whether a newly created output repo is private.
        progress: Gradio progress tracker.

    Returns:
        The accumulated log text for the UI. Errors are reported in the log, not raised.
    """
    cleanup_temp()
    logs = []
    # Textboxes may return None or pasted whitespace.
    hf_token = (hf_token or "").strip()
    base_repo = (base_repo or "").strip()
    structure_repo = (structure_repo or "").strip()
    lora_input = (lora_input or "").strip()
    output_repo = (output_repo or "").strip()
    try:
        missing = [
            label
            for label, value in (
                ("HF Write Token", hf_token),
                ("Base Model Repo", base_repo),
                ("LoRA Source", lora_input),
                ("Target Output Repo", output_repo),
            )
            if not value
        ]
        if missing:
            return "Error: missing required field(s): " + ", ".join(missing)

        # Validate the token without login(). login() saves the token to the
        # Space's local HF credentials, where it would outlive this request and
        # could be picked up by later calls made without an explicit token.
        user = api.whoami(token=hf_token)
        logs.append(f"Authenticated as {user.get('name', 'unknown')}. Target: {output_repo}")

        # 1. Create output repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token)
            logs.append("Output repository ready.")
        except Exception as e:
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate structure (if requested)
        if structure_repo:
            progress(0.1, desc="Cloning Model Structure (Configs)...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load LoRA
        progress(0.2, desc="Downloading LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        lora_state = load_lora_weights(lora_path)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded. Found {len(lora_keys)} tensors.")

        # 4. Identify base shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)
        target_shards = _select_target_shards(all_files, base_subfolder)
        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process shards one at a time so only one shard is in RAM.
        total_shards = len(target_shards)
        merged_count = 0
        used_lora_keys = set()
        processed_path = TempDir / "processed.safetensors"
        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + (0.6 * (idx / total_shards)), desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)

            # save_file() writes no header metadata unless it is passed explicitly.
            # Carry the shard's own metadata (e.g. a "format" entry) over to the
            # merged file so the output header matches the input.
            with safe_open(local_shard, framework="pt") as handle:
                metadata = handle.metadata()
            base_tensors = load_file(local_shard, device="cpu")
            modified_tensors, shard_merged, shard_used = _merge_shard(
                base_tensors, lora_state, lora_keys, scale, logs
            )
            merged_count += shard_merged
            used_lora_keys |= shard_used

            if shard_merged:
                logs.append("Merging complete for shard. Saving...")
                save_file(modified_tensors, processed_path, metadata=metadata)
                upload_path = processed_path
            else:
                # Nothing changed: upload the downloaded file instead of re-serializing it.
                logs.append("No LoRA matches in this shard. Copying original...")
                upload_path = local_shard
            # The data is on disk now; release the tensors before uploading.
            del base_tensors, modified_tensors
            gc.collect()

            api.upload_file(
                path_or_fileobj=upload_path,
                path_in_repo=shard_file,  # keep original structure
                repo_id=output_repo,
                repo_type="model",
                token=hf_token
            )
            if shard_merged:
                logs.append(f"Uploaded {shard_file}")
            os.remove(local_shard)
            if processed_path.exists():
                processed_path.unlink()

        progress(1.0, desc="Done!")
        unused = len(set(lora_keys) - used_lora_keys)
        if merged_count == 0:
            logs.append("\nWARNING: No LoRA layers matched any base weight; the output holds unmodified base shards.")
        elif unused:
            logs.append(f"\nNote: {unused} of {len(lora_keys)} LoRA tensors were not applied to any base weight.")
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")
    except Exception as e:
        import traceback
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())
    finally:
        cleanup_temp()
    return "\n".join(logs)
# --- UI ---
# Custom CSS handed to gr.Blocks.
# NOTE(review): no component below sets elem_classes="container"/"header", so
# these rules appear unused — confirm.
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""
# UI definition. Gradio lays components out in the order they are created
# inside each `with` context, so statement order here is significant.
# NOTE(review): the original indentation was lost; the nesting of each
# Checkbox/Slider/Textbox inside its gr.Row/gr.Group below is reconstructed — confirm.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(
        """
# ⚡ Universal LoRA Merger & Reconstructor
Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
Optimized for CPU-only execution on Hugging Face Spaces.
"""
    )
    # Section 1: token (masked input) and destination repo settings.
    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(label="HF Write Token", type="password", placeholder="hf_...")
            output_repo = gr.Textbox(label="Target Output Repo", placeholder="username/Z-Image-Turbo-Custom")
            is_private = gr.Checkbox(label="Private Repo", value=True)
    # Section 2: repo (and optional subfolder) whose safetensors shards are merged.
    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(label="Base Model Repo", placeholder="e.g. ostris/Z-Image-De-Turbo")
            base_subfolder = gr.Textbox(label="Subfolder (Optional)", placeholder="e.g. transformer", info="Only merge weights found inside this folder.")
    # Section 3: LoRA source and the scale multiplier (slider range -2.0 to 2.0).
    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(label="LoRA Source", placeholder="Repo ID OR Direct URL (http...)", info="Accepts direct .safetensors resolve links.")
            scale = gr.Slider(label="Scale", minimum=-2.0, maximum=2.0, value=1.0, step=0.1)
    # Section 4: optional repo whose non-weight files are copied to the output.
    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(label="Structure Source Repo", placeholder="e.g. Tongyi-MAI/Z-Image-Turbo", info="Copies all NON-weight files from here to output.")
    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")
    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)
    # The inputs list is passed positionally, so its order must match run_merge's
    # parameter order (progress is left to its default).
    submit_btn.click(
        fn=run_merge,
        inputs=[hf_token, base_repo, base_subfolder, structure_repo, lora_input, scale, output_repo, is_private],
        outputs=output_log
    )
if __name__ == "__main__":
    # Enable the request queue with max_size=1, then start the server.
    queued_app = demo.queue(max_size=1)
    queued_app.launch()