# Soon_Merger_Toolkit / app_SoonMerger.py
# Uploaded by AlekseyCalvin
# Commit: "Rename app.py to app_SoonMerger.py"
# Revision f5b827c (verified)
import gradio as gr
import torch
import os
import gc
from merge_utils import execute_mergekit
import shutil
import requests
import json
import struct
import numpy as np
import re
import yaml
from pathlib import Path
from typing import Dict, Any, Optional, List
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm
# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """Lazy reader for a ``.safetensors`` file.

    Only the JSON header is parsed when the object is created; tensor payloads
    are read from disk one at a time by :meth:`get_tensor`. Use it as a context
    manager so the underlying file handle is closed.
    """

    # safetensors dtype tag -> torch dtype. The float8 entries are added only
    # when the installed torch build exposes them.
    _DTYPE_MAP = {
        "F64": torch.float64, "F32": torch.float32, "F16": torch.float16,
        "BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32,
        "I16": torch.int16, "I8": torch.int8, "U8": torch.uint8,
        "BOOL": torch.bool,
    }
    for _tag, _attr in (("F8_E4M3", "float8_e4m3fn"), ("F8_E5M2", "float8_e5m2")):
        if hasattr(torch, _attr):
            _DTYPE_MAP[_tag] = getattr(torch, _attr)
    del _tag, _attr

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        try:
            self.header, self.header_size = self._read_header()
        except Exception:
            # Do not leak the handle when the header is truncated or invalid.
            self.file.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        """Return every tensor name in the header (``__metadata__`` excluded)."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        """Return the header's ``__metadata__`` mapping, or ``{}`` if absent."""
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """Read one tensor from disk.

        Raises:
            KeyError: if ``key`` is not a tensor entry in the header.
            ValueError: if the file ends before the tensor's payload does.
        """
        if key == "__metadata__" or key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        length = offset_end - offset_start
        # Payload offsets are relative to the end of the header; the header is
        # preceded by an 8-byte little-endian length field.
        self.file.seek(self.header_size + 8 + offset_start)
        # A bytearray gives torch a writable buffer, so in-place ops on the
        # returned tensor are safe (immutable ``bytes`` would not be).
        tensor_bytes = bytearray(length)
        if length and self.file.readinto(tensor_bytes) != length:
            raise ValueError(f"Unexpected end of file while reading tensor '{key}'")
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        """Parse the length-prefixed JSON header; return ``(header, size)``."""
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Turn a raw payload into a tensor with the header's dtype and shape."""
        tag = metadata["dtype"]
        if tag not in self._DTYPE_MAP:
            raise KeyError(f"Unsupported safetensors dtype '{tag}'")
        dtype = self._DTYPE_MAP[tag]
        shape = metadata["shape"]
        if len(tensor_bytes) == 0:
            # torch.frombuffer rejects empty buffers; build the empty tensor directly.
            return torch.empty(shape, dtype=dtype)
        return torch.frombuffer(tensor_bytes, dtype=torch.uint8).view(dtype).reshape(shape)
# --- Constants & Setup ---
# Scratch directory for downloads and intermediate files. Prefer /tmp; if it
# cannot be created (e.g. read-only or missing /tmp), fall back to a folder in
# the current working directory.
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)
# Shared Hugging Face Hub client; the helpers below use it for create_repo,
# upload_file, upload_folder and list_repo_files calls.
api = HfApi()
def cleanup_temp():
    """Reset the scratch directory to an empty state and run the garbage collector."""
    scratch_present = TempDir.exists()
    if scratch_present:
        # Drop everything left over from a previous task.
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()
def get_key_stem(key):
    """Reduce a state-dict key to a bare layer name.

    Removes every occurrence of the parameter / LoRA markers (``.weight``,
    ``.bias``, ``.lora_down`` ...) and then strips known leading namespace
    prefixes until none of them is left at the start of the key.
    """
    markers = (
        ".weight", ".bias",
        ".lora_down", ".lora_up",
        ".lora_A", ".lora_B",
        ".alpha",
    )
    for marker in markers:
        key = key.replace(marker, "")

    prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model.",
    )
    # Keep sweeping the prefix list until a full pass removes nothing.
    while True:
        before_pass = key
        for prefix in prefixes:
            key = key.removeprefix(prefix)
        if key == before_pass:
            return key
# =================================================================================
# TAB 1: MERGE & RESHARD
# =================================================================================
def parse_hf_url(url):
    """Parse a direct Hugging Face file URL into ``(repo_id, filename)``.

    Accepted shape:
    ``https://huggingface.co/{user}/{repo}/{resolve|blob}/{branch}/{path...}``.
    Query strings and fragments are dropped and percent-escapes in the file
    path are decoded.

    Returns:
        ``(repo_id, filename)`` on success, otherwise ``(None, None)`` -- also
        for URLs that name a repo but no file.
    """
    from urllib.parse import unquote

    if not isinstance(url, str) or "huggingface.co/" not in url:
        return None, None

    tail = url.split("huggingface.co/")[-1]
    # Drop "?download=true"-style query strings and "#..." fragments.
    tail = tail.split("?", 1)[0].split("#", 1)[0]
    parts = tail.split("/")

    # parts[0]=user, parts[1]=repo, parts[2]=resolve|blob, parts[3]=branch, parts[4:]=file
    if len(parts) < 5 or parts[2] not in ("resolve", "blob"):
        return None, None

    user, repo = parts[0], parts[1]
    filename = unquote("/".join(parts[4:]))
    if not user or not repo or not filename:
        return None, None
    return f"{user}/{repo}", filename
def _is_hf_host(url):
    """Return True when *url* points at huggingface.co / hf.co or a subdomain."""
    from urllib.parse import urlparse

    host = (urlparse(url).hostname or "").lower()
    return (
        host in ("huggingface.co", "hf.co")
        or host.endswith(".huggingface.co")
        or host.endswith(".hf.co")
    )


def _hub_fetch_to(repo_id, filename, token, local_path):
    """Download one repo file into TempDir and move it to *local_path*."""
    # Use the path hf_hub_download reports rather than searching TempDir by
    # basename, which could pick up an unrelated file with the same name.
    downloaded = Path(hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir))
    if downloaded.resolve() != local_path.resolve():
        shutil.move(str(downloaded), str(local_path))
    return local_path


def download_lora_smart(input_str, token):
    """Resolve a LoRA reference and download it to ``TempDir/adapter.safetensors``.

    Strategies are tried in order:
      1. a direct Hugging Face file URL (see ``parse_hf_url``);
      2. ``User/Repo/path.safetensors`` or a bare repo id (auto-discovery of
         ``adapter_model.safetensors`` / ``model.safetensors`` / first
         ``.safetensors`` file in the listing);
      3. a plain HTTP(S) download for any other ``http...`` input.

    Note: the destination name is fixed, so a second call replaces the file
    produced by the first.

    Returns:
        Path of the downloaded file.

    Raises:
        ValueError: when the repo logic and the raw HTTP fallback both fail, or
            when a repo holds no safetensors file.
        Exception: the repo-logic error is re-raised when the input is not a URL.
    """
    local_path = TempDir / "adapter.safetensors"
    if local_path.exists():
        os.remove(local_path)
    print(f"Resolving LoRA Input: {input_str}")

    # 1. Try Parse as HF URL (Most Robust Method)
    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        print(f"Detected HF URL. Repo: {repo_id}, File: {filename}")
        try:
            return _hub_fetch_to(repo_id, filename, token, local_path)
        except Exception as e:
            print(f"HF Download failed: {e}. Falling back...")

    # 2. Try as Raw Repo ID (User/Repo)
    try:
        # Check if user put "User/Repo/file.safetensors"
        if ".safetensors" in input_str and input_str.count("/") >= 2:
            parts = input_str.split("/")
            return _hub_fetch_to(f"{parts[0]}/{parts[1]}", "/".join(parts[2:]), token, local_path)

        # Standard Auto-Discovery
        candidates = ("adapter_model.safetensors", "model.safetensors")
        files = list_repo_files(repo_id=input_str, token=token)
        target = next((f for f in files if f in candidates), None)
        if not target:
            target = next((f for f in files if f.endswith(".safetensors")), None)
        if not target:
            raise ValueError("No safetensors found")
        return _hub_fetch_to(input_str, target, token, local_path)
    except Exception as e:
        # 3. Last Resort: Raw Requests (For non-HF links)
        if not input_str.startswith("http"):
            raise
        try:
            # Only attach the HF token when talking to Hugging Face itself, so
            # the credential is never sent to a third-party host.
            send_token = bool(token) and _is_hf_host(input_str)
            headers = {"Authorization": f"Bearer {token}"} if send_token else {}
            with requests.get(input_str, stream=True, headers=headers, timeout=60) as r:
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            return local_path
        except Exception as req_e:
            raise ValueError(f"All download methods failed.\nRepo Logic Error: {e}\nURL Logic Error: {req_e}") from req_e
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """Load a LoRA file and group its tensors per layer stem.

    Returns a dict ``stem -> {"down", "up", "rank", "alpha"}`` where ``down`` /
    ``up`` are cast to ``precision_dtype``, ``rank`` is the first dimension of
    the down matrix and ``alpha`` falls back to the rank (or 1.0 when the layer
    has no down matrix) if the file stores no alpha for that stem.
    """
    print(f"Loading LoRA from {lora_path}...")
    tensors = load_file(lora_path, device="cpu")

    pairs = {}
    alphas = {}
    for name, tensor in tensors.items():
        stem = get_key_stem(name)
        if "alpha" in name:
            alphas[stem] = tensor.item() if isinstance(tensor, torch.Tensor) else tensor
            continue

        entry = pairs.setdefault(stem, {})
        is_down = "lora_down" in name or "lora_A" in name
        is_up = "lora_up" in name or "lora_B" in name
        if is_down:
            entry["down"] = tensor.to(dtype=precision_dtype)
            entry["rank"] = tensor.shape[0]
        elif is_up:
            entry["up"] = tensor.to(dtype=precision_dtype)

    # Attach the alpha of each layer, defaulting to its rank.
    for stem, entry in pairs.items():
        fallback_alpha = float(entry.get("rank", 1.0))
        entry["alpha"] = alphas.get(stem, fallback_alpha)
    return pairs
class ShardBuffer:
    """Accumulates tensors in memory and writes/uploads them as safetensors shards.

    Tensors are serialized on ``add_tensor``; once the buffered payload reaches
    ``max_size_gb`` the shard is written to ``output_dir``, uploaded to
    ``output_repo`` (under ``subfolder`` if given) and deleted locally.
    ``index_map`` records tensor name -> shard filename and ``total_size``
    the total payload bytes, for building the index JSON afterwards.
    """

    # torch dtype -> (safetensors tag, dtype to reinterpret as before .numpy()).
    # bfloat16 has no numpy equivalent, so its bits are viewed as int16.
    _DTYPES = {
        torch.bfloat16: ("BF16", torch.int16),
        torch.float16: ("F16", None),
        torch.float32: ("F32", None),
        torch.float64: ("F64", None),
        torch.int64: ("I64", None),
        torch.int32: ("I32", None),
        torch.int16: ("I16", None),
        torch.int8: ("I8", None),
        torch.uint8: ("U8", None),
        torch.bool: ("BOOL", None),
    }

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix
        self.buffer = []
        self.current_bytes = 0
        self.shard_count = 0
        self.index_map = {}
        self.total_size = 0

    def add_tensor(self, key, tensor):
        """Serialize ``tensor`` into the current shard; flush when the shard is full.

        Raises:
            TypeError: if the tensor's dtype has no safetensors tag in ``_DTYPES``
                (previously such tensors were silently mislabelled as F32).
        """
        try:
            dtype_str, view_as = self._DTYPES[tensor.dtype]
        except KeyError:
            raise TypeError(f"Unsupported dtype {tensor.dtype} for tensor '{key}'") from None

        flat = tensor.detach().contiguous()
        if view_as is not None:
            flat = flat.view(view_as)
        raw_bytes = flat.numpy().tobytes()

        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape),
        })
        self.current_bytes += size
        self.total_size += size
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        """Write the buffered tensors as one shard, upload it and reset the buffer."""
        if not self.buffer:
            return
        self.shard_count += 1
        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename
        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")

        # Build the header: one entry per tensor with byte offsets into the payload.
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            end_offset = current_offset + len(item["data"])
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, end_offset],
            }
            current_offset = end_offset
            self.index_map[item["key"]] = filename
        header_json = json.dumps(header).encode('utf-8')

        out_path = self.output_dir / filename
        try:
            with open(out_path, 'wb') as f:
                f.write(struct.pack('<Q', len(header_json)))
                f.write(header_json)
                for item in self.buffer:
                    f.write(item["data"])
            print(f"Uploading {path_in_repo}...")
            api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)
        finally:
            # Never leave a multi-GB shard behind, even when the upload fails.
            if os.path.exists(out_path):
                os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()
def copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder):
    """Copy every non-weight file of ``base_repo`` (optionally one subfolder) to ``output_repo``.

    Files with a weight extension are skipped. The copy is best effort: a
    failure is printed as a warning and the remaining files are still tried.
    """
    print(f"Copying config files from {base_repo}...")
    blocked_ext = ('.safetensors', '.bin', '.pt', '.pth', '.msgpack', '.h5', '.onnx')

    # Match on a full path component so "transformer" does not also select
    # "transformer_2/...".
    sub = base_subfolder.strip('/') if base_subfolder else ""
    prefix = f"{sub}/" if sub else ""

    try:
        files = list_repo_files(repo_id=base_repo, token=hf_token)
    except Exception as e:
        print(f"Config copy warning: {e}")
        return

    for f in files:
        # Filter by subfolder if needed
        if prefix and not f.startswith(prefix):
            continue
        # Block heavy weights
        if f.endswith(blocked_ext):
            continue
        print(f"Transferring {f}...")
        try:
            local = hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=TempDir)
            # Determine path in new repo
            rel_name = f[len(prefix):]
            target_path = f"{output_subfolder}/{rel_name}" if output_subfolder else rel_name
            api.upload_file(path_or_fileobj=local, path_in_repo=target_path, repo_id=output_repo, token=hf_token)
            os.remove(local)
        except Exception as e:
            print(f"Config copy warning: {e}")
def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix=None, is_root_merge=False):
    """Mirror the files of ``src_repo`` into ``dst_repo`` one file at a time.

    Args:
        ignore_prefix: folder (or exact path) in the source repo to skip.
        is_root_merge: when True, weight files (.safetensors/.bin/.pt/.pth)
            are skipped as well.

    Best effort: per-file failures are reported and the copy continues.
    """
    print(f"Scanning {src_repo} for structure cloning...")
    weight_ext = ('.safetensors', '.bin', '.pt', '.pth')
    skip_root = ignore_prefix.strip('/') if ignore_prefix else ""

    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)
    except Exception as e:
        print(f"Structure clone error: {e}")
        return

    for f in tqdm(files, desc="Copying Structure"):
        # Skip the ignored folder itself, matched on a path boundary.
        if skip_root and (f == skip_root or f.startswith(skip_root + "/")):
            continue
        if is_root_merge and f.endswith(weight_ext):
            continue
        try:
            local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)
            api.upload_file(path_or_fileobj=local, path_in_repo=f, repo_id=dst_repo, token=token)
            if os.path.exists(local):
                os.remove(local)
        except Exception as e:
            # Keep going, but make the failure visible instead of swallowing it.
            print(f"Structure clone warning: could not copy {f}: {e}")
def _find_lora_match(lora_pairs, base_stem):
    """Look up the LoRA pair for a base stem, falling back to a fused ``qkv`` stem."""
    match = lora_pairs.get(base_stem)
    if match:
        return match
    # QKV heuristic: a base to_q/to_k/to_v layer may be covered by one fused qkv LoRA.
    for proj in ("to_q", "to_k", "to_v"):
        if proj in base_stem:
            return lora_pairs.get(base_stem.replace(proj, "qkv"))
    return None


def _compute_lora_delta(key, weight, match, scale):
    """Return the scaled LoRA delta shaped like ``weight``, or None if it does not fit."""
    down = match.get("down")
    up = match.get("up")
    rank = match.get("rank")
    if down is None or up is None or not rank:
        # Incomplete pair (only one of up/down present in the LoRA file).
        return None
    scaling = scale * (match["alpha"] / rank)

    if weight.dim() == 4 and down.dim() == 2:
        down = down.unsqueeze(-1).unsqueeze(-1)
        up = up.unsqueeze(-1).unsqueeze(-1)

    try:
        if up.dim() == 4:
            # flatten (not squeeze) so rank-1 / single-channel layers keep 2 dims
            delta = (up.flatten(start_dim=1) @ down.flatten(start_dim=1)).reshape(up.shape[0], *down.shape[1:])
        else:
            delta = up @ down
    except RuntimeError:
        if up.dim() != 2:
            return None
        try:
            # Some files store the up matrix transposed.
            delta = up.T @ down
        except RuntimeError:
            return None
    delta = delta * scaling

    if delta.shape == weight.shape:
        return delta
    if delta.dim() > 0 and weight.dim() > 0 and delta.shape[0] == weight.shape[0] * 3:
        # Fused qkv delta: take the slice belonging to this projection.
        chunk = weight.shape[0]
        if "to_q" in key:
            return delta[0:chunk, ...]
        if "to_k" in key:
            return delta[chunk:2 * chunk, ...]
        if "to_v" in key:
            return delta[2 * chunk:, ...]
        return None
    if delta.numel() == weight.numel():
        return delta.reshape(weight.shape)
    return None


def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    """Merge a LoRA into a base model shard by shard and upload the resharded result.

    Steps: create the output repo, copy configs from the base subfolder, clone
    an optional structure repo, download the base safetensors shards, load the
    LoRA, add ``scale * alpha / rank * (up @ down)`` to every matching weight,
    stream the tensors through a ShardBuffer and finally upload the index JSON.
    All tensors are cast to the selected precision before they are written.

    Returns:
        A status string for the UI (an "Error..." message on early failures).
    """
    cleanup_temp()
    if not hf_token or not hf_token.strip():
        return "Error: HF Token required."
    # Use the stripped token everywhere, not only for login().
    hf_token = hf_token.strip()
    login(hf_token)
    if not output_repo or not output_repo.strip():
        return "Error: Output Repo cannot be empty."

    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"

    # Logic: If using a subfolder like 'transformer', we want standard diffusers naming
    output_subfolder = base_subfolder.strip().strip("/") if base_subfolder else ""
    # Path-boundary prefix so "transformer" does not also match "transformer_2/".
    subfolder_prefix = f"{output_subfolder}/" if output_subfolder else ""

    # 2. Copy Configs from Base (Aggressive Copy)
    if output_subfolder:
        copy_auxiliary_configs(hf_token, base_repo, output_subfolder, output_repo, output_subfolder)

    # 3. Clone Structure Repo
    if structure_repo:
        ignore = output_subfolder if output_subfolder else None
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix=ignore, is_root_merge=not bool(output_subfolder))

    # 4. Download Shards
    progress(0.1, desc="Downloading Input Model...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    inputs_dir = TempDir / "inputs"
    os.makedirs(inputs_dir, exist_ok=True)
    input_shards = []
    for f in files:
        if not f.endswith(".safetensors"):
            continue
        if subfolder_prefix and not f.startswith(subfolder_prefix):
            continue
        # Trust the path returned by the hub client instead of searching by
        # basename, which is ambiguous when several folders hold the same name.
        downloaded = hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=inputs_dir, local_dir_use_symlinks=False)
        input_shards.append(Path(downloaded))
    if not input_shards:
        return "No safetensors found."
    input_shards.sort()

    # --- NAMING CONVENTION ---
    # Force diffusion naming if target is transformer/unet
    diffusers_folder = output_subfolder in ["transformer", "unet", "qint4", "qint8"]
    if diffusers_folder or "diffusion_pytorch_model" in os.path.basename(input_shards[0]):
        filename_prefix = "diffusion_pytorch_model"
    else:
        filename_prefix = "model"
    index_filename = f"{filename_prefix}.safetensors.index.json"
    print(f"Naming scheme: {filename_prefix}")

    # 5. Load LoRA
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(precision, torch.float32)
    try:
        progress(0.15, desc="Downloading LoRA...")
        lora_path = download_lora_smart(lora_input, hf_token)
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e:
        return f"Error loading LoRA: {e}"

    # 6. Stream
    buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix)
    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {os.path.basename(shard_file)}")
        with MemoryEfficientSafeOpen(shard_file) as f:
            for k in f.keys():
                v = f.get_tensor(k)
                match = _find_lora_match(lora_pairs, get_key_stem(k))
                delta = _compute_lora_delta(k, v, match, scale) if match else None
                if delta is not None:
                    # Out-of-place add: never write into the buffer the tensor
                    # was deserialized from.
                    v = v.to(dtype) + delta.to(dtype)
                    del delta
                if v.dtype != dtype:
                    v = v.to(dtype)
                buffer.add_tensor(k, v)
                del v
        os.remove(shard_file)
        gc.collect()
    buffer.flush()

    print(f"Uploading Index: {index_filename} (Size: {buffer.total_size})")
    index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map}
    with open(TempDir / index_filename, "w") as f:
        json.dump(index_data, f, indent=4)
    path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename
    api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
    cleanup_temp()
    return f"Done! Merged {buffer.shard_count} shards to {output_repo}"
# =================================================================================
# TAB 2: EXTRACT LORA
# =================================================================================
def identify_and_download_model(input_str, token):
    """Resolve a model reference and download one weight file.

    1. If the input is a direct Hugging Face file URL, that file is downloaded.
    2. Otherwise the input is treated as a repo id and scanned, in priority
       order, for ``transformer/`` or ``unet/`` diffusers weights,
       ``model.safetensors``, and finally the first ``.safetensors`` file in
       the listing whose path contains none of "lora", "adapter", "extracted".

    Each repo is downloaded into its own folder under TempDir, so two models
    that share a filename (e.g. ``model.safetensors``) do not overwrite each
    other.

    Returns:
        Path of the downloaded file.

    Raises:
        ValueError: if the repo cannot be listed or holds no usable weights.
    """
    print(f"Resolving model input: {input_str}")

    def _fetch(repo_id, filename):
        dest_dir = TempDir / "models" / repo_id.replace("/", "__")
        os.makedirs(dest_dir, exist_ok=True)
        return Path(hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=dest_dir))

    # --- STRATEGY A: Direct URL ---
    repo_id_from_url, filename_from_url = parse_hf_url(input_str)
    if repo_id_from_url and filename_from_url:
        print(f"Detected Direct Link. Repo: {repo_id_from_url}, File: {filename_from_url}")
        try:
            return _fetch(repo_id_from_url, filename_from_url)
        except Exception as e:
            print(f"URL Download failed: {e}. Trying fallback...")

    # --- STRATEGY B: Repo Discovery (Auto-Detect) ---
    # If we are here, input_str is treated as a Repo ID (e.g. "ostris/Z-Image-De-Turbo")
    print(f"Scanning Repo {input_str} for model weights...")
    try:
        files = list_repo_files(repo_id=input_str, token=token)
    except Exception as e:
        raise ValueError(f"Failed to list repo '{input_str}'. If this is a URL, ensure it is formatted correctly. Error: {e}")

    # Priority list for diffusers vs single file
    priorities = [
        "transformer/diffusion_pytorch_model.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "model.safetensors",
        # Fallback to any safetensors that isn't an adapter or lora
        lambda f: f.endswith(".safetensors") and "lora" not in f and "adapter" not in f and "extracted" not in f
    ]
    target_file = None
    for p in priorities:
        if callable(p):
            # Takes the first match in listing order (file sizes are not consulted).
            target_file = next((f for f in files if p(f)), None)
        elif p in files:
            target_file = p
        if target_file:
            break
    if not target_file:
        raise ValueError(f"Could not find a valid model weight file in {input_str}. Ensure it contains .safetensors weights.")

    print(f"Downloading auto-detected weight file: {target_file}")
    return _fetch(input_str, target_file)
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """Extract a LoRA from the difference between two checkpoints.

    For every tensor present in both files, a truncated SVD of
    ``tuned - original`` is computed and stored as ``lora_up`` / ``lora_down``
    (singular values folded into ``lora_up``) with ``alpha`` equal to the rank
    used. Tensors whose largest absolute difference is below 1e-4 are skipped.

    Args:
        model_org: path of the original safetensors file.
        model_tuned: path of the fine-tuned safetensors file.
        rank: maximum LoRA rank per layer.
        clamp: quantile (0-1) of absolute values used to clamp outliers.

    Returns:
        Path (str) of the written ``extracted.safetensors``.
    """
    lora_sd = {}
    skip_tags = ("num_batches_tracked", "running_mean", "running_var")
    print("Calculating diffs & extracting LoRA...")

    # Context managers make sure both file handles are closed afterwards.
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        # Get intersection of keys; sorted so runs are reproducible.
        keys = sorted(set(org.keys()).intersection(tuned.keys()))
        for key in tqdm(keys, desc="Extracting"):
            # Skip integer buffers/metadata
            if any(tag in key for tag in skip_tags):
                continue
            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()
            # Skip if shapes mismatch (shouldn't happen if models match)
            if mat_org.shape != mat_tuned.shape:
                continue
            # Scalars (0-dim) and empty tensors cannot be factorized.
            if mat_org.dim() == 0 or mat_org.numel() == 0:
                continue
            diff = mat_tuned - mat_org
            # Skip if no difference
            if torch.max(torch.abs(diff)) < 1e-4:
                continue

            out_dim = diff.shape[0]
            in_dim = diff.shape[1] if diff.dim() > 1 else 1
            r = min(rank, in_dim, out_dim)
            is_conv = diff.dim() == 4
            if is_conv:
                diff = diff.flatten(start_dim=1)
            elif diff.dim() == 1:
                diff = diff.unsqueeze(1)  # Handle biases if needed

            try:
                # Use svd_lowrank for massive speedup on CPU vs linalg.svd
                U, S, V = torch.svd_lowrank(diff, q=r + 4, niter=4)
                Vh = V.t()
                U = U[:, :r]
                S = S[:r]
                Vh = Vh[:r, :]
                # Merge S into U for standard LoRA format
                U = U @ torch.diag(S)
                # Clamp outliers
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(torch.abs(dist), clamp)
                if hi_val > 0:
                    U = U.clamp(-hi_val, hi_val)
                    Vh = Vh.clamp(-hi_val, hi_val)
                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)
                stem = key.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
                lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
            except Exception as e:
                print(f"Skipping {key} due to error: {e}")

    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
def task_extract(hf_token, org, tun, rank, out):
    """UI task: download two models, extract their difference as a LoRA, upload it.

    Returns a status string; any exception is reported as ``"Error: ..."``.
    """
    cleanup_temp()
    if hf_token:
        login(hf_token.strip())
    try:
        print("Downloading Original Model...")
        original_path = identify_and_download_model(org, hf_token)
        print("Downloading Tuned Model...")
        tuned_path = identify_and_download_model(tun, hf_token)
        # The clamp quantile is fixed at 0.99 for this task.
        lora_file = extract_lora_layer_by_layer(original_path, tuned_path, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(
            path_or_fileobj=lora_file,
            path_in_repo="extracted_lora.safetensors",
            repo_id=out,
            token=hf_token,
        )
        return "Done! Extracted to " + out
    except Exception as e:
        return f"Error: {e}"
# =================================================================================
# TAB 3: MERGE ADAPTERS (Multi-Method)
# =================================================================================
def load_full_state_dict(path):
    """Load a safetensors file on CPU with float32 tensors and normalized key names.

    ``lora_A`` is renamed to ``lora_down`` and ``lora_B`` to ``lora_up``;
    all other keys are kept unchanged.
    """
    def _normalize(name):
        # Map PEFT-style names onto the "lora_up/lora_down" convention.
        if "lora_A" in name:
            return name.replace("lora_A", "lora_down")
        if "lora_B" in name:
            return name.replace("lora_B", "lora_up")
        return name

    raw = load_file(path, device="cpu")
    return {_normalize(name): tensor.float() for name, tensor in raw.items()}
# --- Original EMA Method ---
def sigma_rel_to_gamma(sigma_rel):
    """Return the largest real, non-negative root of the gamma cubic for ``sigma_rel``.

    The cubic is ``g^3 + 7 g^2 + (16 - t) g + (12 - t) = 0`` with
    ``t = sigma_rel ** -2``.
    """
    inv_sq = sigma_rel**-2
    cubic = [1, 7, 16 - inv_sq, 12 - inv_sq]
    candidates = np.roots(cubic)
    usable = np.isreal(candidates) & (candidates.real >= 0)
    return candidates[usable].real.max()
def merge_lora_iterative_ema(paths, beta, sigma_rel):
    """Fold a list of LoRA files into the first one with an exponential moving average.

    With ``sigma_rel > 0`` the decay follows ``(1 - 1/t) ** (gamma + 1)`` with
    ``t`` starting at 1 for the second file, otherwise the fixed ``beta`` is
    used. Keys containing "alpha" and keys missing from a later file are left
    untouched.

    Note: for the first merged file ``t == 1`` gives a decay of 0, i.e. the
    values loaded from ``paths[0]`` are fully replaced on that step.
    """
    print("Executing Iterative EMA Merge (Original Method)...")
    base_sd = load_file(paths[0], device="cpu")
    for name in base_sd:
        tensor = base_sd[name]
        if tensor.dtype.is_floating_point:
            base_sd[name] = tensor.float()

    gamma = sigma_rel_to_gamma(sigma_rel) if sigma_rel > 0 else None

    for step, path in enumerate(paths[1:], start=1):
        print(f"Merging {path}")
        if gamma is None:
            current_beta = beta
        else:
            current_beta = (1 - 1 / step) ** (gamma + 1)
        incoming = load_file(path, device="cpu")
        for name in base_sd:
            if name not in incoming or "alpha" in name:
                continue
            base_sd[name] = base_sd[name] * current_beta + incoming[name].float() * (1 - current_beta)
    return base_sd
# --- New Concatenation Method (DiffSynth) ---
def merge_lora_concatenation(adapter_states, weights):
    """Merge adapters by concatenating their ranks (DiffSynth-style).

    For each layer the down matrices are stacked along dim 0 and the up
    matrices along dim 1, so the merged rank is the sum of the input ranks.
    Each adapter's own ``alpha / rank`` factor and its user weight are folded
    into its up matrix and the merged ``alpha`` is set to the merged rank
    (effective scale 1). The merged delta therefore equals
    ``sum_i weight_i * alpha_i / rank_i * (up_i @ down_i)`` even when the
    adapters use different alpha/rank ratios.

    Args:
        adapter_states: state dicts using ``{stem}.lora_down.weight`` /
            ``{stem}.lora_up.weight`` / ``{stem}.alpha`` keys.
        weights: one multiplier per adapter.

    Returns:
        The merged state dict.
    """
    print("Executing Concatenation Merge (Rank Summation)...")
    merged_state = {}

    # Identify all stems (layers) present across all adapters
    all_stems = set()
    for state in adapter_states:
        for k in state.keys():
            if ".lora_" in k:
                all_stems.add(k.split(".lora_")[0])

    for stem in tqdm(sorted(all_stems), desc="Concatenating Layers"):
        down_key = f"{stem}.lora_down.weight"
        up_key = f"{stem}.lora_up.weight"
        alpha_key = f"{stem}.alpha"
        down_list = []
        up_list = []
        for state, w in zip(adapter_states, weights):
            if down_key not in state or up_key not in state:
                continue
            d = state[down_key]
            rank = d.shape[0]
            alpha = state[alpha_key].item() if alpha_key in state else rank
            # Bake this adapter's scale and user weight into its up matrix.
            up_list.append(state[up_key] * (w * alpha / rank))
            down_list.append(d)

        if down_list and up_list:
            # lora_down is (rank, in, ...): concat along dim 0.
            # lora_up is (out, rank, ...): concat along dim 1.
            new_down = torch.cat(down_list, dim=0)  # (sum_rank, in)
            new_up = torch.cat(up_list, dim=1)      # (out, sum_rank)
            merged_state[down_key] = new_down.contiguous()
            merged_state[up_key] = new_up.contiguous()
            # alpha == merged rank -> scale factor 1; scales already live in new_up.
            merged_state[alpha_key] = torch.tensor(float(new_down.shape[0]))
    return merged_state
# --- New SVD/Task Arithmetic Method ---
def merge_lora_svd(adapter_states, weights, target_rank):
    """Merge adapters by summing their weight deltas and re-factorizing with SVD.

    1. Per adapter: ``dW_i = weight_i * alpha_i / rank_i * (up_i @ down_i)``.
    2. ``dW = sum_i dW_i`` (adapters whose delta shape differs are skipped).
    3. Truncated SVD of ``dW`` gives the new up/down pair.

    The rank actually used per layer is ``target_rank`` capped by the layer's
    smallest matrix dimension, so small layers are kept instead of failing the
    reshape; ``alpha`` is stored equal to that rank.

    Returns:
        The merged state dict.
    """
    print(f"Executing SVD Merge (Target Rank: {target_rank})...")
    merged_state = {}

    all_stems = set()
    for state in adapter_states:
        for k in state.keys():
            if ".lora_" in k:
                all_stems.add(k.split(".lora_")[0])

    for stem in tqdm(sorted(all_stems), desc="SVD Merging Layers"):
        down_key = f"{stem}.lora_down.weight"
        up_key = f"{stem}.lora_up.weight"
        alpha_key = f"{stem}.alpha"

        total_delta = None
        for state, w in zip(adapter_states, weights):
            if down_key not in state or up_key not in state:
                continue
            down = state[down_key]
            up = state[up_key]
            rank = down.shape[0]
            alpha = state[alpha_key].item() if alpha_key in state else rank
            scale = (alpha / rank) * w

            # Reconstruct Delta
            if down.dim() == 4:  # Conv2d
                d_flat = down.flatten(start_dim=1)
                u_flat = up.flatten(start_dim=1)
                delta = (u_flat @ d_flat).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
            else:
                delta = up @ down
            delta = delta * scale

            if total_delta is None:
                total_delta = delta
            elif total_delta.shape == delta.shape:
                total_delta += delta
            else:
                print(f"Shape mismatch in {stem}, skipping.")

        if total_delta is None:
            continue

        out_dim = total_delta.shape[0]
        in_dim = total_delta.shape[1]
        is_conv = total_delta.dim() == 4
        flat_delta = total_delta.flatten(start_dim=1) if is_conv else total_delta

        try:
            # A layer cannot hold more rank than its smallest dimension.
            max_rank = min(flat_delta.shape)
            r = max(1, min(target_rank, max_rank))
            q = min(r + 4, max_rank)
            U, S, V = torch.svd_lowrank(flat_delta, q=q, niter=4)
            Vh = V.t()
            U = U[:, :r]
            S = S[:r]
            Vh = Vh[:r, :]
            # Fold the singular values into the up matrix.
            U = U @ torch.diag(S)
            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, total_delta.shape[2], total_delta.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)
            merged_state[down_key] = Vh.contiguous()
            merged_state[up_key] = U.contiguous()
            merged_state[alpha_key] = torch.tensor(r).float()
        except Exception as e:
            print(f"SVD Failed for {stem}: {e}")
    return merged_state
def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
    """UI task: download several adapters, merge them with the chosen method, upload.

    Args:
        inputs_text: adapter links / repo ids separated by whitespace or newlines.
        method: label containing "Iterative EMA", "Concatenation" or "SVD".
        weight_str: comma-separated per-adapter weights (missing ones default
            to 1.0, extra ones are ignored); used by Concatenation and SVD.
        beta, sigma_rel: parameters of the EMA method.
        target_rank: rank for the SVD method.

    Returns:
        A status string for the UI.
    """
    cleanup_temp()
    if hf_token:
        login(hf_token.strip())
    if not out_repo or not out_repo.strip():
        return "Error: Output Repo cannot be empty."

    # 1. Parse Inputs (Multi-line support)
    urls = (inputs_text or "").split()
    if len(urls) < 2:
        return "Error: Please provide at least 2 adapters."

    # 2. Parse Weights (for SVD/Concatenation)
    try:
        if not weight_str or not weight_str.strip():
            weights = [1.0] * len(urls)
        else:
            weights = [float(w.strip()) for w in weight_str.split(',')]
            # Broadcast or Truncate
            if len(weights) < len(urls):
                weights += [1.0] * (len(urls) - len(weights))
            else:
                weights = weights[:len(urls)]
    except (ValueError, AttributeError):
        return "Error parsing weights. Use format: 1.0, 0.5, 0.8"

    # 3. Download All
    # download_lora_smart always writes to the same "adapter.safetensors", so
    # each download is moved to its own file before the next one replaces it.
    staging = TempDir / "merge_inputs"
    paths = []
    try:
        os.makedirs(staging, exist_ok=True)
        for i, url in enumerate(tqdm(urls, desc="Downloading Adapters")):
            fetched = download_lora_smart(url, hf_token)
            unique = staging / f"merge_input_{i:03d}.safetensors"
            shutil.move(str(fetched), str(unique))
            paths.append(unique)
    except Exception as e:
        return f"Download Error: {e}"

    # 4. Execute Selected Method
    merged = None
    try:
        if "Iterative EMA" in method:
            # Calls the original method logic exactly
            merged = merge_lora_iterative_ema(paths, beta, sigma_rel)
        else:
            # For new methods, we load everything upfront
            states = [load_full_state_dict(p) for p in paths]
            if "Concatenation" in method:
                merged = merge_lora_concatenation(states, weights)
            elif "SVD" in method:
                merged = merge_lora_svd(states, weights, int(target_rank))
    except Exception as e:
        return f"Merge Error: {e}"
    if not merged:
        return "Merge failed (Result empty)."

    # 5. Save & Upload
    out = TempDir / "merged_adapters.safetensors"
    try:
        save_file(merged, out)
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
        return f"Success! Merged to {out_repo}"
    except Exception as e:
        return f"Upload Error: {e}"
# =================================================================================
# TAB 4: RESIZE (CPU Optimized)
# =================================================================================
def index_sv_cumulative(S, target):
    """Pick a rank from the running share of the singular-value sum.

    Returns one more than the first position where the normalized cumulative
    sum reaches ``target``, capped at ``len(S) - 1`` and never below 1.
    """
    total = float(torch.sum(S))
    running_share = torch.cumsum(S, dim=0) / total
    position = int(torch.searchsorted(running_share, target)) + 1
    ceiling = len(S) - 1
    if position > ceiling:
        position = ceiling
    if position < 1:
        position = 1
    return position
def index_sv_fro(S, target):
    """Pick a rank from the running share of squared singular values.

    The threshold is ``target ** 2`` on the normalized cumulative sum of
    ``S ** 2``; the result is capped at ``len(S) - 1`` and never below 1.
    """
    energy = S.pow(2)
    total_energy = float(torch.sum(energy))
    running_share = torch.cumsum(energy, dim=0) / total_energy
    position = int(torch.searchsorted(running_share, target**2)) + 1
    ceiling = len(S) - 1
    if position > ceiling:
        position = ceiling
    if position < 1:
        position = 1
    return position
def index_sv_ratio(S, target):
    """Pick a rank by counting singular values above ``S[0] / target``.

    The count is capped at ``len(S) - 1`` and never below 1.
    """
    threshold = S[0] / target
    kept = int((S > threshold).sum().item())
    ceiling = len(S) - 1
    return max(1, min(kept, ceiling))
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """UI task: re-factorize every LoRA layer at a (possibly dynamic) lower rank.

    Each layer's delta ``alpha / rank * (up @ down)`` is rebuilt, decomposed
    with a low-rank SVD and truncated. ``new_rank`` is a hard upper limit; the
    optional dynamic method ("sv_ratio", "sv_cumulative", "sv_fro") may pick a
    smaller rank per layer. The original scale is folded into the new up
    matrix and ``alpha`` is stored equal to the new rank.

    Returns:
        "Done" on success, otherwise an "Error: ..." string.
    """
    cleanup_temp()
    if not hf_token:
        return "Error: Token required"
    login(hf_token.strip())
    try:
        path = download_lora_smart(lora_input, hf_token)
    except Exception as e:
        return f"Error: {e}"

    try:
        state = load_file(path, device="cpu")

        # Group tensors per layer. Alpha keys carry no ".lora_" part, so their
        # stem is derived separately to land in the same group as up/down.
        groups = {}
        for k, tensor in state.items():
            if ".lora_" in k:
                stem = k.split(".lora_")[0]
            elif "alpha" in k:
                stem = k.rsplit(".alpha", 1)[0]
            else:
                continue
            g = groups.setdefault(stem, {})
            if "lora_down" in k or "lora_A" in k:
                g["down"] = tensor
            elif "lora_up" in k or "lora_B" in k:
                g["up"] = tensor
            elif "alpha" in k:
                g["alpha"] = tensor
        print(f"Resizing {len(groups)} blocks...")

        # Pre-parse user settings
        target_rank_limit = int(new_rank)
        if dynamic_method == "None":
            dynamic_method = None
        selectors = {
            "sv_ratio": index_sv_ratio,
            "sv_cumulative": index_sv_cumulative,
            "sv_fro": index_sv_fro,
        }

        new_state = {}
        for stem, g in tqdm(groups.items()):
            if "down" not in g or "up" not in g:
                continue
            down, up = g["down"].float(), g["up"].float()
            is_conv = down.dim() == 4

            # 1. Merge Up/Down to get full weight delta (flatten, not squeeze,
            #    so kernels larger than 1x1 and rank-1 layers are handled).
            if is_conv:
                flat = up.flatten(start_dim=1) @ down.flatten(start_dim=1)
            else:
                flat = up @ down
            # Preserve the layer's original strength: delta = alpha / rank * up @ down.
            old_rank = down.shape[0]
            if "alpha" in g and old_rank:
                flat = flat * (float(g["alpha"]) / old_rank)

            # 2. FAST SVD (svd_lowrank)
            # Use the "To Rank" input as a computational hard limit + buffer.
            # This ensures we don't compute expensive full SVD for massive layers.
            q_limit = target_rank_limit + 32  # Buffer to allow dynamic methods some wiggle room before truncation
            q = min(q_limit, min(flat.shape))
            U, S, V = torch.svd_lowrank(flat, q=q)
            Vh = V.t()

            # 3. Dynamic Rank Selection
            selector = selectors.get(dynamic_method)
            calculated_rank = selector(S, dynamic_param) if selector else target_rank_limit

            # Apply Hard Limit (User's "To Rank"), keeping at least rank 1.
            final_rank = max(1, min(calculated_rank, target_rank_limit, S.shape[0]))

            # 4. Truncate
            U = U[:, :final_rank]
            S = S[:final_rank]
            Vh = Vh[:final_rank, :]

            # 5. Reconstruct Up Matrix (Absorb S into U)
            U = U @ torch.diag(S)
            if is_conv:
                U = U.reshape(up.shape[0], final_rank, 1, 1)
                Vh = Vh.reshape(final_rank, down.shape[1], down.shape[2], down.shape[3])

            # 6. Save (safetensors requires contiguous tensors)
            new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
            new_state[f"{stem}.lora_up.weight"] = U.contiguous()
            new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float()

        out = TempDir / "shrunken_.safetensors"
        save_file(new_state, out)
        api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=hf_token)
    except Exception as e:
        return f"Error: {e}"
    return "Done"
# =================================================================================
# NEW TAB 5: FULL MODEL MERGER (MergeKit GUI Wrapper)
# =================================================================================
def task_full_model_merge(hf_token, models_text, method, dtype, base, weights, density, layer_ranges, tok_src, shard_size, out_repo, private):
    """UI task: merge full models through the MergeKit wrapper and upload the result.

    Args:
        models_text: one model id per line (at least two).
        weights: comma-separated per-model weights; missing ones default to 1.0.
        density: added to each model's parameters for ties / dare_* methods.

    Returns:
        A status string for the UI.
    """
    cleanup_temp()
    if not hf_token or not out_repo:
        return "Error: Token and Output Repo required."
    login(hf_token.strip())

    model_list = [m.strip() for m in models_text.split('\n') if m.strip()]
    if len(model_list) < 2:
        return "Error: Minimum 2 models required."

    # Parse Weights
    try:
        w_list = [float(w.strip()) for w in weights.split(',')] if weights else [1.0] * len(model_list)
    except (ValueError, AttributeError):
        return "Error: Weights must be comma-separated numbers."

    # NOTE(review): build_full_merge_config is neither defined nor imported in
    # the visible part of this file -- confirm where it comes from. The call
    # previously passed the undefined names `models` and `weights_text`
    # (NameError); it now passes the parsed model list and the raw weights text.
    config = build_full_merge_config(
        method=method, models=model_list, base_model=base if base else model_list[0],
        weights=weights, density=density, dtype=dtype,
        tokenizer_source=tok_src, layer_ranges=layer_ranges
    )

    # NOTE(review): if the builder already fills config["models"] from `models`,
    # the entries appended below would duplicate them -- verify against the builder.
    uses_density = method.lower() in ["ties", "dare_ties", "dare_linear"]
    model_entries = config.setdefault("models", [])
    for i, m in enumerate(model_list):
        m_params = {"model": m, "parameters": {"weight": w_list[i] if i < len(w_list) else 1.0}}
        if uses_density:
            m_params["parameters"]["density"] = density
        model_entries.append(m_params)

    out_path = TempDir / "merged_model"
    try:
        # Pass shard size to our execute_mergekit helper
        execute_mergekit(config, str(out_path), shard_size)
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
        api.upload_folder(folder_path=str(out_path), repo_id=out_repo, token=hf_token)
        return f"Success! Model merged and uploaded to {out_repo}"
    except Exception as e:
        return f"Merge Error: {e}"
# =================================================================================
# NEW TAB 6: MIXTURE OF EXPERTS (MoE Creator)
# =================================================================================
def task_create_moe(hf_token, dtype, shard_size, base_model, experts_text, gate_mode, tok_src, out_repo, private):
    """Assemble a Mixture-of-Experts config, run the merge helper and upload.

    `experts_text` holds one expert repo per line (blank lines are dropped).
    The resulting config is handed to `execute_mergekit` together with the
    output directory and `shard_size`; the produced folder is then pushed to
    `out_repo`. Returns a status string describing success or the failure.
    """
    cleanup_temp()
    if not (hf_token and out_repo):
        return "Error: Token and Output Repo required."
    login(hf_token.strip())

    # Trim every line and discard the ones that end up empty.
    expert_ids = list(filter(None, map(str.strip, experts_text.split('\n'))))
    if len(expert_ids) == 0:
        return "Error: At least one expert model is required."

    expert_entries = []
    for repo_id in expert_ids:
        expert_entries.append({"source_model": repo_id})

    moe_config = {
        "method": "moe",
        "base_model": base_model,
        "dtype": dtype,
        "tokenizer_source": tok_src,
        "params": {"gate_mode": gate_mode},
        "experts": expert_entries,
    }

    target_dir = str(TempDir / "moe_model")
    try:
        execute_mergekit(moe_config, target_dir, shard_size)
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
        api.upload_folder(folder_path=target_dir, repo_id=out_repo, token=hf_token)
        return f"Success! MoE model uploaded to {out_repo}"
    except Exception as err:
        # Any build or upload failure is reported back to the UI as text.
        return f"MoE Build Error: {err}"
# =================================================================================
# UI
# =================================================================================
css = ".container { max-width: 900px; margin: auto; }"

# Gradio UI: one tab per task function. Each button's `click` passes its input
# components positionally, so every inputs list must follow the parameter
# order of the function it calls.
# NOTE(review): Row/Accordion nesting below was reconstructed from statement
# order -- confirm it matches the intended layout.
with gr.Blocks() as demo:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/AlekseyCalvin/Soon_Merger/resolve/main/SMerger3.png" alt="SOONmerge®"> Transform Transformers for FREE!</h1>""",
        elem_id="title",
    )
    gr.Markdown("# 🧰SOONmerge® LoRA Toolkit")
    with gr.Tabs():
        with gr.Tab("Merge to Base Model + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo", value="name/repo")
            t1_sub = gr.Textbox(label="Subfolder (Optional)", value="")
            t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Max Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Extras Source (copies configs/components/etc)", value="name/repo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)
        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned or Homologous Model")
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
        with gr.Tab("Merge Adapters/Weights"):
            gr.Markdown("### Batch Adapter Merging")
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.TextArea(label="Adapter URLs/Repos (one per line, or space-separated)", placeholder="user/lora1\nhttps://hf.co/user/lora2.safetensors\n...")
            with gr.Row():
                t3_method = gr.Dropdown(
                    ["Iterative EMA (Linear w/ Beta/Sigma coefficient)", "Concatenation (MOE-like weights-stack)", "SVD Fusion (Task Arithmetic/Compressed)"],
                    value="Iterative EMA (Linear w/ Beta/Sigma coefficient)",
                    label="Merge Method"
                )
            with gr.Row():
                t3_weights = gr.Textbox(label="Weights (comma-separated) – for Concat/SVD", placeholder="1.0, 0.5, 0.8...")
                t3_rank = gr.Number(label="Target Rank – For SVD only", value=128, minimum=4, maximum=1024)
            with gr.Row():
                t3_beta = gr.Slider(label="Beta – for linear/post-hoc EMA", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
                t3_sigma = gr.Slider(label="Sigma Rel – for linear/post-hoc EMA", value=0.21, minimum=0.01, maximum=1.00, step=0.01)
            t3_out = gr.Textbox(label="Output Repo")
            t3_priv = gr.Checkbox(label="Private Output", value=True)
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], t3_res)
        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="To Rank (Safety Ceiling)", value=8, minimum=1, maximum=512, step=1)
                t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=0.9)
            # The literal below is kept flush-left so the Markdown text passed
            # to Gradio is unchanged.
            gr.Markdown(
                """
### 📉 Dynamic Resizing Guide
These methods intelligently determine the best rank per layer.
* **sv_ratio (Relative Strength):** Keeps features that are at least `1/Param` as strong as the main feature. **Param must be >= 2**. (e.g. 2 = keep features half as strong as top).
* **sv_fro (Visual Information Density):** Preserves `Param%` of the total information content (Frobenius Norm) of the layer. **Param between 0.0 and 1.0** (e.g. 0.9 = 90% info retention).
* **sv_cumulative (Cumulative Sum):** Preserves weights that sum up to `Param%` of the total strength. **Param between 0.0 and 1.0**.
* **⚠️ Safety Ceiling:** The **"To Rank"** slider acts as a hard limit. Even if a dynamic method wants a higher rank, it will be cut down to this number to keep file sizes small.
"""
            )
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)
        # =================================================================================
        # UPDATED TAB 5: FULL MODEL MERGER (MergeKit Engine)
        # =================================================================================
        with gr.Tab("Full Model Merge (MergeKit)"):
            gr.Markdown("### 🧩 Multi-Model Weight Fusion")
            with gr.Row():
                t5_token = gr.Textbox(label="HF Token", type="password")
                t5_method = gr.Dropdown(["Linear", "SLERP", "TIES", "DARE_TIES", "DARE_LINEAR", "Model_Stock"], value="TIES", label="Merge Method")
                t5_dtype = gr.Radio(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
            t5_models = gr.TextArea(label="Models to Merge (One Repo ID per line)", placeholder="repo/model-a\nrepo/model-b\nrepo/model-c...")
            with gr.Row():
                t5_base = gr.Textbox(label="Base Model (Required for TIES/DARE)", placeholder="repo/base-model")
                t5_shard = gr.Slider(0.5, 10, 2.0, step=0.5, label="Max Shard Size (GB)")
            with gr.Accordion("Advanced Parametrization", open=False):
                with gr.Row():
                    t5_weights = gr.Textbox(label="Weights (Comma separated)", placeholder="1.0, 0.5, 0.3")
                    t5_density = gr.Slider(0, 1, 0.5, label="Density (TIES/DARE)")
                with gr.Row():
                    t5_layers = gr.Textbox(label="Layer Ranges (JSON Format)", placeholder='[{"start": 0, "end": 32}]')
                    t5_tok_src = gr.Dropdown(["base", "union", "first"], value="base", label="Tokenizer Source")
            t5_out = gr.Textbox(label="Output Repo (User/Repo)")
            t5_priv = gr.Checkbox(label="Private Output", value=True)
            t5_btn = gr.Button("🚀 Execute Full Merge", variant="primary")
            t5_res = gr.Textbox(label="Result")
            # task_full_model_merge takes 12 positional parameters:
            # (hf_token, models_text, method, dtype, base, weights, density,
            #  layer_ranges, tok_src, shard_size, out_repo, private).
            # The previous wiring sent only 10 inputs, with a blank State in
            # the weights slot and no layer-range/tokenizer inputs, which
            # shifted shard size, repo and privacy onto the wrong parameters.
            t5_btn.click(
                task_full_model_merge,
                [
                    t5_token, t5_models, t5_method, t5_dtype, t5_base,
                    t5_weights, t5_density, t5_layers, t5_tok_src,
                    t5_shard, t5_out, t5_priv,
                ],
                t5_res,
            )
        # =================================================================================
        # UPDATED TAB 6: MIXTURE OF EXPERTS (MoE Creator)
        # =================================================================================
        with gr.Tab("Create MoE"):
            gr.Markdown("### 🤖 Mixture of Experts Upscaling")
            with gr.Row():
                t6_token = gr.Textbox(label="HF Token", type="password")
                t6_dtype = gr.Radio(["bfloat16", "float16", "float32"], value="bfloat16", label="Precision")
                t6_shard = gr.Slider(0.5, 10, 2.0, label="Shard Size (GB)")
            t6_base = gr.Textbox(label="Base Architecture Model", placeholder="repo/backbone-model")
            t6_experts = gr.TextArea(label="Experts (One per line)", placeholder="repo/expert-1\nrepo/expert-2...")
            with gr.Accordion("MoE Hyperparameters", open=True):
                with gr.Row():
                    t6_gate_mode = gr.Dropdown(["cheap_embed", "hidden", "random"], value="cheap_embed", label="Gating Mode")
                    t6_tok_src = gr.Dropdown(["base", "union", "first"], value="base", label="Tokenizer Source")
            t6_out = gr.Textbox(label="Output Repo", placeholder="User/Repo")
            t6_priv = gr.Checkbox(label="Private", value=True)
            t6_btn = gr.Button("🏗️ Build MoE", variant="primary")
            t6_res = gr.Textbox(label="Result")
            t6_btn.click(task_create_moe, [t6_token, t6_dtype, t6_shard, t6_base, t6_experts, t6_gate_mode, t6_tok_src, t6_out, t6_priv], t6_res)

if __name__ == "__main__":
    demo.queue().launch(css=css, ssr_mode=False)