# Soon_Merger / app.py
import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
from pathlib import Path
from typing import Dict
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm
# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """
    Reads safetensors metadata and tensors without mmap, keeping RAM usage low.
    """
    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Tensor data starts after the 8-byte header-size field and the JSON header.
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype_map = {
            "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
            "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
            "U8": torch.uint8, "BOOL": torch.bool
        }
        dtype = dtype_map[metadata["dtype"]]
        shape = metadata["shape"]
        # Copy into a writable bytearray: torch.frombuffer on read-only bytes warns,
        # and downstream in-place ops (v.add_) would otherwise hit read-only memory.
        return torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8).view(dtype).reshape(shape)
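# Usage sketch: stream tensors one at a time from a local .safetensors file
# ("model.safetensors" below is a placeholder path):
#   with MemoryEfficientSafeOpen("model.safetensors") as f:
#       for key in f.keys():
#           t = f.get_tensor(key)  # only this tensor's bytes are read into RAM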
# --- Constants & Setup ---
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

api = HfApi()

def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()
def download_file(input_path, token, filename=None):
    if input_path.startswith("http"):
        local_path = TempDir / (filename if filename else "model.safetensors")
        print(f"Downloading {filename} from URL...")
        try:
            response = requests.get(input_path, stream=True, timeout=30)
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except Exception as e:
            raise ValueError(f"Download failed: {e}")
    else:
        print(f"Downloading {filename} from Hub...")
        if not filename:
            try:
                files = list_repo_files(repo_id=input_path, token=token)
                safetensors = [f for f in files if f.endswith(".safetensors")]
                filename = safetensors[0] if safetensors else "adapter_model.safetensors"
            except Exception:
                filename = "adapter_model.safetensors"
        try:
            # hf_hub_download returns the real on-disk path (which may sit in a
            # subfolder of TempDir when the repo filename contains one), so use
            # that directly instead of guessing the location.
            downloaded = hf_hub_download(repo_id=input_path, filename=filename, token=token,
                                         local_dir=TempDir, local_dir_use_symlinks=False)
            local_path = Path(downloaded)
        except Exception as e:
            raise ValueError(f"Hub download failed: {e}")
    return local_path
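# Usage sketch ("some-user/some-lora" is a hypothetical repo id):
#   p = download_file("https://example.com/lora.safetensors", token, filename="lora.safetensors")
#   p = download_file("some-user/some-lora", token)  # grabs the first .safetensors in the repo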
def get_key_stem(key):
    key = key.replace(".weight", "").replace(".bias", "")
    key = key.replace(".lora_down", "").replace(".lora_up", "")
    key = key.replace(".lora_A", "").replace(".lora_B", "")
    key = key.replace(".alpha", "")
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    ]
    changed = True
    while changed:
        changed = False
        for p in prefixes:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key
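# Example (sketch): both of these normalize to the same stem, which is how LoRA keys
# get matched against base-model keys regardless of the training framework's naming:
#   get_key_stem("transformer.blocks.0.attn.qkv.lora_A.weight")            -> "blocks.0.attn.qkv"
#   get_key_stem("diffusion_model.blocks.0.attn.qkv.lora_down.weight")     -> "blocks.0.attn.qkv"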
# =================================================================================
# TAB 1: MERGE & RESHARD (Fixes Folder Structure & Aux Files)
# =================================================================================
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            if stem not in pairs:
                pairs[stem] = {}
            if "lora_down" in k or "lora_A" in k:
                pairs[stem]["down"] = v.to(dtype=precision_dtype)
                pairs[stem]["rank"] = v.shape[0]
            elif "lora_up" in k or "lora_B" in k:
                pairs[stem]["up"] = v.to(dtype=precision_dtype)
    for stem in pairs:
        # When no alpha is stored, default to the rank so alpha/rank scales to 1.0.
        pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
    return pairs
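# Downstream, each pair is applied as the standard LoRA update (sketch using the
# dict fields built above; W stands for the matching base weight):
#   delta = pair["up"] @ pair["down"]                  # (out, r) @ (r, in) -> (out, in)
#   W_merged = W + scale * (pair["alpha"] / pair["rank"]) * delta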
class ShardBuffer:
    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix  # Dynamic prefix (e.g. 'diffusion_pytorch_model' or 'model')
        self.buffer = []
        self.current_bytes = 0
        self.shard_count = 0
        self.index_map = {}
        self.total_model_size = 0

    def add_tensor(self, key, tensor):
        # The caller casts everything to the target float dtype before adding,
        # so only bf16/f16/f32 reach this point.
        if tensor.dtype == torch.bfloat16:
            # numpy has no bfloat16; reinterpret the bits as int16 before serializing
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        elif tensor.dtype == torch.float16:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F16"
        else:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F32"
        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape)
        })
        self.current_bytes += size
        self.total_model_size += size
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        if not self.buffer:
            return
        self.shard_count += 1
        # ADAPTIVE NAMING: Uses the prefix detected from the base model
        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        # Proper Subfolder Handling
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename
        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])]
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename
        header_json = json.dumps(header).encode('utf-8')
        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])
        print(f"Uploading {path_in_repo}...")
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)
        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()
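# Shards are named "{prefix}-00001.safetensors" rather than the usual
# "{prefix}-00001-of-00005.safetensors": the total count is unknown while streaming,
# and loaders resolve filenames through the index json's weight_map anyway.
# Usage sketch ("user/output-repo" is a placeholder):
#   buf = ShardBuffer(2.0, TempDir, "user/output-repo", "transformer", token,
#                     filename_prefix="diffusion_pytorch_model")
#   buf.add_tensor("blocks.0.attn.qkv.weight", tensor)  # auto-flushes at the size cap
#   buf.flush()                                         # write + upload the final partial shard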
def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix="transformer"):
    """
    Copies files one-by-one from source to dest, skipping 'ignore_prefix'.
    Does NOT skip .safetensors/.bin if they are outside the ignore folder.
    """
    print(f"Scanning {src_repo} for auxiliary files...")
    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)
        for f in tqdm(files, desc="Copying Structure"):
            # 1. Skip the folder we are replacing (e.g., transformer/).
            #    Match on "prefix/" so e.g. "transformer_config.json" is not skipped.
            if ignore_prefix and f.startswith(ignore_prefix.rstrip("/") + "/"):
                continue
            # 2. Skip hidden/system files
            if f.startswith("."):
                continue
            # 3. Download -> Upload -> Delete loop
            # This ensures we get VAE/TextEnc weights without disk overflow
            try:
                print(f"Copying {f}...")
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)
                api.upload_file(
                    path_or_fileobj=local,
                    path_in_repo=f,
                    repo_id=dst_repo,
                    token=token
                )
                if os.path.exists(local):
                    os.remove(local)
            except Exception as e:
                print(f"Failed to copy {f}: {e}")
    except Exception as e:
        print(f"Structure cloning error: {e}")
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    cleanup_temp()
    login(hf_token)
    # 1. Output Setup
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"
    # 2. Server-Side Structure Clone
    if structure_repo:
        ignore = base_subfolder if base_subfolder else None
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore)
    # 3. Load LoRA
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.1, desc="Downloading LoRA...")
        lora_path = download_file(lora_input, hf_token, filename="adapter.safetensors")
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e:
        return f"Error loading LoRA: {e}"
    # 4. Stream Process
    progress(0.2, desc="Fetching File List...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    # Identify valid shards in the target folder (match "subfolder/" exactly so
    # sibling folders with the same prefix are not picked up)
    input_shards = []
    for f in files:
        if not f.endswith(".safetensors"):
            continue
        if base_subfolder and not f.startswith(base_subfolder.rstrip("/") + "/"):
            continue
        input_shards.append(f)
    if not input_shards:
        return "No base safetensors found in specified location."
    input_shards.sort()
    # --- AUTO-DETECT NAMING CONVENTION ---
    # We look at the first file to decide the naming scheme.
    # Common schemes:
    #   "diffusion_pytorch_model-00001..." -> prefix: "diffusion_pytorch_model"
    #   "model-00001..."                   -> prefix: "model"
    #   "model.safetensors"                -> prefix: "model"
    first_file = os.path.basename(input_shards[0])
    if first_file.startswith("diffusion_pytorch_model"):
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    else:
        # Default for LLMs, Text Encoders, etc.
        filename_prefix = "model"
        index_filename = "model.safetensors.index.json"
    print(f"Detected naming convention: {filename_prefix} (Index: {index_filename})")
    # Initialize Buffer with detected prefix
    buffer = ShardBuffer(shard_size, TempDir, output_repo, base_subfolder, hf_token, filename_prefix=filename_prefix)
    lora_keys = set(lora_pairs.keys())
    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {shard_file}")
        local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)
        with MemoryEfficientSafeOpen(local_shard) as f:
            for k in f.keys():
                v = f.get_tensor(k)
                base_stem = get_key_stem(k)
                match = None
                # Matching Logic (Exact + Heuristic for fused QKV)
                if base_stem in lora_keys:
                    match = lora_pairs[base_stem]
                else:
                    # Some LoRAs train a single fused "qkv" projection while the base
                    # model stores separate to_q/to_k/to_v weights; map them across.
                    for part in ("to_q", "to_k", "to_v"):
                        if part in base_stem:
                            qkv_stem = base_stem.replace(part, "qkv")
                            if qkv_stem in lora_keys:
                                match = lora_pairs[qkv_stem]
                            break
                if match and "down" in match and "up" in match:
                    down = match["down"]
                    up = match["up"]
                    alpha = match["alpha"]
                    rank = match["rank"]
                    scaling = scale * (alpha / rank)
                    # Handle Conv 1x1 squeeze
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        down = down.unsqueeze(-1).unsqueeze(-1)
                        up = up.unsqueeze(-1).unsqueeze(-1)
                    try:
                        if len(up.shape) == 4:
                            delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                        else:
                            delta = up @ down
                    except RuntimeError:
                        delta = up.T @ down
                    delta = delta * scaling
                    valid_delta = True
                    # Shape Slicing Logic: a fused qkv delta has 3x the base rows;
                    # slice out the q/k/v third that matches this tensor.
                    if delta.shape == v.shape:
                        pass
                    elif delta.shape[0] == v.shape[0] * 3:
                        chunk = v.shape[0]
                        if "to_q" in k:
                            delta = delta[0:chunk, ...]
                        elif "to_k" in k:
                            delta = delta[chunk:2*chunk, ...]
                        elif "to_v" in k:
                            delta = delta[2*chunk:, ...]
                        else:
                            valid_delta = False
                    elif delta.numel() == v.numel():
                        delta = delta.reshape(v.shape)
                    else:
                        valid_delta = False
                    if valid_delta:
                        v = v.to(dtype)
                        delta = delta.to(dtype)
                        v.add_(delta)
                    del delta
                if v.dtype != dtype:
                    v = v.to(dtype)
                buffer.add_tensor(k, v)
                del v
        os.remove(local_shard)
        gc.collect()
    buffer.flush()
buffer.flush()
# Upload Index (Using the dynamically determined index filename)
print(f"Uploading Index: {index_filename}")
index_data = {
"metadata": {"total_size": buffer.total_model_size},
"weight_map": buffer.index_map
}
with open(TempDir / index_filename, "w") as f:
json.dump(index_data, f, indent=4)
path_in_repo = f"{base_subfolder}/{index_filename}" if base_subfolder else index_filename
api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
cleanup_temp()
return f"Done! Merged into {buffer.shard_count} shards at {output_repo}"
# =================================================================================
# TAB 2: EXTRACT LORA
# =================================================================================
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    lora_sd = {}
    print("Calculating diffs...")
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        tuned_keys = set(tuned.keys())
        for key in tqdm(org.keys()):
            if key not in tuned_keys:
                continue
            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()
            diff = mat_tuned - mat_org
            if diff.dim() < 2:
                continue  # biases / norm vectors have no low-rank factorization
            if torch.max(torch.abs(diff)) < 1e-4:
                continue
            out_dim, in_dim = diff.shape[:2]
            r = min(rank, in_dim, out_dim)
            is_conv = diff.dim() == 4
            if is_conv:
                diff = diff.flatten(start_dim=1)
            try:
                U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
                U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
                U = U @ torch.diag(S)
                # Clamp extreme values at the requested quantile to tame SVD outliers
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(dist, clamp)
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)
                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)
                stem = key.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U
                lora_sd[f"{stem}.lora_down.weight"] = Vh
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
            except Exception as e:
                print(f"Skipping {key}: {e}")
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
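# Sanity check (sketch): merging the extracted LoRA back at scale 1.0 should
# approximately reproduce the tuned weights, up to the rank-r SVD truncation and
# the quantile clamping above:
#   W_tuned ~= W_org + lora_up @ lora_down    (alpha == rank, so alpha/rank == 1)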
def task_extract(hf_token, org, tun, rank, out):
    cleanup_temp()
    login(hf_token)
    try:
        p1 = download_file(org, hf_token, filename="org.safetensors")
        p2 = download_file(tun, hf_token, filename="tun.safetensors")
        f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=f, path_in_repo="extracted.safetensors", repo_id=out, token=hf_token)
        return "Done"
    except Exception as e:
        return f"Error: {e}"
# =================================================================================
# TAB 3: MERGE ADAPTERS (EMA) with Sigma Rel
# =================================================================================
def sigma_rel_to_gamma(sigma_rel):
    # Invert the power-function-EMA relation sigma_rel^2 = (g+1) / ((g+2)^2 (g+3))
    # (Karras et al., post-hoc EMA): with t = sigma_rel^-2 this becomes the cubic
    # g^3 + 7g^2 + (16 - t)g + (12 - t) = 0, whose largest non-negative real root is gamma.
    t = sigma_rel**-2
    coeffs = [1, 7, 16 - t, 12 - t]
    roots = np.roots(coeffs)
    gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
    return gamma
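# Reference points from the post-hoc EMA paper: sigma_rel = 0.05 -> gamma ~ 16.97,
# sigma_rel = 0.10 -> gamma ~ 6.94; larger sigma_rel means a smaller gamma and a
# flatter, wider averaging profile over the checkpoint sequence.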
def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
    cleanup_temp()
    login(hf_token)
    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = []
    try:
        for i, url in enumerate(urls):
            paths.append(download_file(url, hf_token, filename=f"a_{i}.safetensors"))
    except Exception as e:
        return f"Download Error: {e}"
    if not paths:
        return "No models found"
    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point:
            base_sd[k] = base_sd[k].float()
    gamma = None
    if sigma_rel > 0:
        gamma = sigma_rel_to_gamma(sigma_rel)
    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            # The running average already holds checkpoint 1, so the first merged
            # checkpoint is t = 2 in the beta_t = (1 - 1/t)^(gamma+1) schedule.
            # (t = 1 would give beta = 0 and discard the first model entirely.)
            t = i + 2
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta
        curr = load_file(path, device="cpu")
        for k in base_sd:
            if k in curr and "alpha" not in k:
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)
    out = TempDir / "merged_adapters.safetensors"
    save_file(base_sd, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
# =================================================================================
# TAB 4: RESIZE
# =================================================================================
def index_sv_ratio(S, target):
    # Keep every singular value within `target`x of the largest one.
    max_sv = S[0]
    min_sv = max_sv / target
    index = int(torch.sum(S > min_sv).item())
    return max(1, min(index, len(S) - 1))
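# Example (sketch): with singular values S = [10, 4, 1, 0.05] and sv_ratio target 4.0,
# the cutoff is 10 / 4 = 2.5, so two values survive and the adapter is resized to rank 2.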
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    cleanup_temp()
    login(hf_token)
    try:
        path = download_file(lora_input, hf_token)
    except Exception as e:
        return f"Error: {e}"
    state = load_file(path, device="cpu")
    new_state = {}
    groups = {}
    for k in state:
        # Group up/down/alpha under a common stem (alpha keys have no ".lora_" part).
        simple = k.replace(".alpha", "") if ".alpha" in k else k.split(".lora_")[0]
        if simple not in groups:
            groups[simple] = {}
        if "lora_down" in k or "lora_A" in k:
            groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k:
            groups[simple]["up"] = state[k]
        if "alpha" in k:
            groups[simple]["alpha"] = state[k]
    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()
            # Fold the original alpha/rank scaling into the merged delta so the
            # resized adapter (saved with alpha == new rank, i.e. scale 1) stays equivalent.
            old_rank = down.shape[0]
            old_alpha = float(g["alpha"]) if "alpha" in g else float(old_rank)
            if len(down.shape) == 4:
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                flat = merged.flatten(1)
            else:
                merged = up @ down
                flat = merged
            flat = flat * (old_alpha / old_rank)
            U, S, Vh = torch.linalg.svd(flat, full_matrices=False)
            target_rank = int(new_rank)
            if dynamic_method == "sv_ratio":
                target_rank = index_sv_ratio(S, dynamic_param)
            target_rank = min(target_rank, S.shape[0])
            U = U[:, :target_rank]
            S = S[:target_rank]
            U = U @ torch.diag(S)
            Vh = Vh[:target_rank, :]
            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], target_rank, 1, 1)
                Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])
            new_state[f"{stem}.lora_down.weight"] = Vh
            new_state[f"{stem}.lora_up.weight"] = U
            new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()
    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
# =================================================================================
# UI
# =================================================================================
css = ".container { max-width: 900px; margin: auto; }"

with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧰SOONmerge® LoRA Toolkit")
    with gr.Tabs():
        with gr.Tab("Merge to Base + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo (Diffusers)", value="ostris/Z-Image-De-Turbo")
            t1_sub = gr.Textbox(label="Subfolder", value="transformer")
            t1_lora = gr.Textbox(label="LoRA Direct Link", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Diffusers Extras (Copies VAE/TextEnc/etc)", value="Tongyi-MAI/Z-Image-Turbo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)
        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned Model")
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
        with gr.Tab("Merge Multiple Adapters"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.Textbox(label="URLs")
            with gr.Row():
                t3_beta = gr.Slider(label="Beta", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
                t3_sigma = gr.Slider(label="Sigma Rel (0 = use fixed Beta)", value=0.21, minimum=0.0, maximum=1.00, step=0.01)
            t3_out = gr.Textbox(label="Output Repo")
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)
        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="To Rank (Lower Only!)", value=8, minimum=1, maximum=256, step=1)
                t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=4.0)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

if __name__ == "__main__":
    # css belongs to gr.Blocks (set above); launch() accepts no css argument.
    demo.queue().launch(ssr_mode=False)