import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
import re
from pathlib import Path
from typing import Dict, Any, Optional, List
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm
# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """Minimal streaming reader for .safetensors files (one tensor in memory at a time)."""
    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Offsets are relative to the end of the header (8-byte length prefix + JSON)
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype_map = {
            "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
            "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
            "U8": torch.uint8, "BOOL": torch.bool
        }
        dtype = dtype_map[metadata["dtype"]]
        shape = metadata["shape"]
        # Copy into a writable bytearray: frombuffer over read-only bytes yields a
        # read-only tensor, which breaks the in-place add_() used during merging
        return torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8).view(dtype).reshape(shape)
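# Format note (from the safetensors spec): a file is an 8-byte little-endian uint64
# header length, a UTF-8 JSON header mapping tensor names to {dtype, shape,
# data_offsets}, then the raw tensor bytes. A minimal usage sketch, assuming a local
# "model.safetensors" exists:
#
#   with MemoryEfficientSafeOpen("model.safetensors") as f:
#       for key in f.keys():
#           t = f.get_tensor(key)  # holds only one tensor in memory at a time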
# --- Constants & Setup ---
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)
api = HfApi()
def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()
def get_key_stem(key):
    key = key.replace(".weight", "").replace(".bias", "")
    key = key.replace(".lora_down", "").replace(".lora_up", "")
    key = key.replace(".lora_A", "").replace(".lora_B", "")
    key = key.replace(".alpha", "")
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    ]
    changed = True
    while changed:
        changed = False
        for p in prefixes:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key
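# Worked example: peft-style and plain checkpoint keys normalize to the same stem,
# which is how LoRA pairs are matched to base tensors throughout this file:
#   "transformer.blocks.0.attn.to_q.lora_A.weight" -> "blocks.0.attn.to_q"
#   "diffusion_model.blocks.0.attn.to_q.weight"    -> "blocks.0.attn.to_q"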
# =================================================================================
# TAB 1: MERGE & RESHARD
# =================================================================================
def parse_hf_url(url):
    """Parses a direct HF URL into repo_id and filename."""
    # Pattern: https://huggingface.co/{user}/{repo}/resolve/{branch}/{filename...}
    if "huggingface.co" in url and "resolve" in url:
        try:
            parts = url.split("huggingface.co/")[-1].split("/")
            # parts[0]=user, parts[1]=repo, parts[2]=resolve, parts[3]=branch, parts[4:]=file
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[4:]).split("?")[0]  # Strip query params
            return repo_id, filename
        except Exception:
            return None, None
    return None, None
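# Worked example:
#   parse_hf_url("https://huggingface.co/user/repo/resolve/main/subdir/file.safetensors")
#   -> ("user/repo", "subdir/file.safetensors")
# Any input not matching the resolve-URL pattern returns (None, None).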
def download_lora_smart(input_str, token):
    local_path = TempDir / "adapter.safetensors"
    if local_path.exists(): os.remove(local_path)
    print(f"Resolving LoRA Input: {input_str}")
    # 1. Try Parse as HF URL (Most Robust Method)
    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        print(f"Detected HF URL. Repo: {repo_id}, File: {filename}")
        try:
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            # Move to standard name
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]  # Handle subfolder downloads
            if found != local_path: shutil.move(found, local_path)
            return local_path
        except Exception as e:
            print(f"HF Download failed: {e}. Falling back...")
    # 2. Try as Raw Repo ID (User/Repo)
    try:
        # Check if the user supplied "User/Repo/file.safetensors"
        if ".safetensors" in input_str and input_str.count("/") >= 2:
            parts = input_str.split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[2:])
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        # Standard Auto-Discovery
        candidates = ["adapter_model.safetensors", "model.safetensors"]
        files = list_repo_files(repo_id=input_str, token=token)
        target = next((f for f in files if f in candidates), None)
        if not target:
            safes = [f for f in files if f.endswith(".safetensors")]
            if safes: target = safes[0]
        if not target: raise ValueError("No safetensors found")
        hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
        found = list(TempDir.rglob(target.split("/")[-1]))[0]
        if found != local_path: shutil.move(found, local_path)
        return local_path
    except Exception as e:
        # 3. Last Resort: Raw Requests (For non-HF links)
        if input_str.startswith("http"):
            try:
                headers = {"Authorization": f"Bearer {token}"} if token else {}
                r = requests.get(input_str, stream=True, headers=headers, timeout=60)
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
                return local_path
            except Exception as req_e:
                raise ValueError(f"All download methods failed.\nRepo Logic Error: {e}\nURL Logic Error: {req_e}")
        raise e
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            if stem not in pairs: pairs[stem] = {}
            if "lora_down" in k or "lora_A" in k:
                pairs[stem]["down"] = v.to(dtype=precision_dtype)
                pairs[stem]["rank"] = v.shape[0]
            elif "lora_up" in k or "lora_B" in k:
                pairs[stem]["up"] = v.to(dtype=precision_dtype)
    for stem in pairs:
        # Default alpha to the rank, making the effective scale alpha/rank = 1.0
        pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
    return pairs
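# The pairs built above feed the standard LoRA reconstruction used in task_merge:
#   W' = W + scale * (alpha / rank) * (up @ down)
# e.g. a rank-32 adapter saved with alpha = 16 contributes at half strength before
# the user-facing Scale slider is applied.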
class ShardBuffer:
    """Accumulates tensors, writes fixed-size .safetensors shards, and uploads each one."""
    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix
        self.buffer = []
        self.current_bytes = 0
        self.shard_count = 0
        self.index_map = {}
        self.total_size = 0

    def add_tensor(self, key, tensor):
        # bf16 has no numpy dtype, so reinterpret the bits as int16 for serialization
        if tensor.dtype == torch.bfloat16:
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        elif tensor.dtype == torch.float16:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F16"
        else:
            # Cast so the bytes actually match the "F32" label written to the header
            raw_bytes = tensor.to(torch.float32).numpy().tobytes()
            dtype_str = "F32"
        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape)
        })
        self.current_bytes += size
        self.total_size += size
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        if not self.buffer: return
        self.shard_count += 1
        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename
        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])]
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename
        header_json = json.dumps(header).encode('utf-8')
        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])
        print(f"Uploading {path_in_repo}...")
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)
        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()
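# The index_map accumulated by flush() becomes a standard HF sharded-checkpoint index
# (written and uploaded at the end of task_merge). Illustrative shape of the JSON:
#   {"metadata": {"total_size": 12884901888},
#    "weight_map": {"blocks.0.attn.to_q.weight": "diffusion_pytorch_model-00001.safetensors"}}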
def copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder):
    """Aggressively copy all config/misc files, only skipping heavy weights."""
    print(f"Copying config files from {base_repo}...")
    try:
        files = list_repo_files(repo_id=base_repo, token=hf_token)
        blocked_ext = ['.safetensors', '.bin', '.pt', '.pth', '.msgpack', '.h5', '.onnx']
        for f in files:
            # Filter by subfolder if needed
            if base_subfolder and not f.startswith(base_subfolder): continue
            # Block heavy weights
            if any(f.endswith(ext) for ext in blocked_ext): continue
            print(f"Transferring {f}...")
            local = hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=TempDir)
            # Determine path in the new repo
            rel_name = f[len(base_subfolder):].lstrip('/') if base_subfolder else f
            target_path = f"{output_subfolder}/{rel_name}" if output_subfolder else rel_name
            api.upload_file(path_or_fileobj=local, path_in_repo=target_path, repo_id=output_repo, token=hf_token)
            os.remove(local)
    except Exception as e:
        print(f"Config copy warning: {e}")
def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix=None, is_root_merge=False):
    print(f"Scanning {src_repo} for structure cloning...")
    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)
        for f in tqdm(files, desc="Copying Structure"):
            if ignore_prefix and f.startswith(ignore_prefix): continue
            if is_root_merge:
                if any(f.endswith(ext) for ext in ['.safetensors', '.bin', '.pt', '.pth']):
                    continue
            try:
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)
                api.upload_file(path_or_fileobj=local, path_in_repo=f, repo_id=dst_repo, token=token)
                if os.path.exists(local): os.remove(local)
            except Exception:
                pass
    except Exception as e:
        print(f"Structure clone error: {e}")
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    cleanup_temp()
    if not hf_token: return "Error: HF Token required."
    login(hf_token.strip())
    # 1. Create Output Repo
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"
    # Logic: if merging into a subfolder like 'transformer', keep standard diffusers naming
    output_subfolder = base_subfolder if base_subfolder else ""
    # 2. Copy Configs from Base (Aggressive Copy)
    if base_subfolder:
        copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder)
    # 3. Clone Structure Repo
    if structure_repo:
        ignore = output_subfolder if output_subfolder else None
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix=ignore, is_root_merge=not bool(output_subfolder))
    # 4. Download Shards
    progress(0.1, desc="Downloading Input Model...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    input_shards = []
    for f in files:
        if f.endswith(".safetensors"):
            if output_subfolder and not f.startswith(output_subfolder): continue
            local = TempDir / "inputs" / os.path.basename(f)
            os.makedirs(local.parent, exist_ok=True)
            hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=local.parent, local_dir_use_symlinks=False)
            found = list(local.parent.rglob(os.path.basename(f)))
            if found: input_shards.append(found[0])
    if not input_shards: return "No safetensors found."
    input_shards.sort()
    # --- NAMING CONVENTION ---
    # Force diffusers naming if the target is a transformer/unet subfolder
    if output_subfolder in ["transformer", "unet", "qint4", "qint8"]:
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    elif "diffusion_pytorch_model" in os.path.basename(input_shards[0]):
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    else:
        filename_prefix = "model"
        index_filename = "model.safetensors.index.json"
    print(f"Naming scheme: {filename_prefix}")
    # 5. Load LoRA
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.15, desc="Downloading LoRA...")
        lora_path = download_lora_smart(lora_input, hf_token)
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e:
        return f"Error loading LoRA: {e}"
    # 6. Stream shards through the buffer, patching matched tensors
    buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix)
    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {os.path.basename(shard_file)}")
        with MemoryEfficientSafeOpen(shard_file) as f:
            keys = f.keys()
            for k in keys:
                v = f.get_tensor(k)
                base_stem = get_key_stem(k)
                match = lora_pairs.get(base_stem)
                # QKV heuristics: fused-attention LoRAs store one "qkv" block per layer
                if not match:
                    if "to_q" in base_stem:
                        match = lora_pairs.get(base_stem.replace("to_q", "qkv"))
                    elif "to_k" in base_stem:
                        match = lora_pairs.get(base_stem.replace("to_k", "qkv"))
                    elif "to_v" in base_stem:
                        match = lora_pairs.get(base_stem.replace("to_v", "qkv"))
                if match:
                    down = match["down"]
                    up = match["up"]
                    scaling = scale * (match["alpha"] / match["rank"])
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        down = down.unsqueeze(-1).unsqueeze(-1)
                        up = up.unsqueeze(-1).unsqueeze(-1)
                    try:
                        if len(up.shape) == 4:
                            delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                        else:
                            delta = up @ down
                    except Exception:
                        delta = up.T @ down
                    delta = delta * scaling
                    valid = True
                    if delta.shape == v.shape:
                        pass
                    elif delta.shape[0] == v.shape[0] * 3:
                        # Fused QKV delta: slice out the third that matches this projection
                        chunk = v.shape[0]
                        if "to_q" in k: delta = delta[0:chunk, ...]
                        elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
                        elif "to_v" in k: delta = delta[2*chunk:, ...]
                        else: valid = False
                    elif delta.numel() == v.numel():
                        delta = delta.reshape(v.shape)
                    else:
                        valid = False
                    if valid:
                        v = v.to(dtype)
                        delta = delta.to(dtype)
                        v.add_(delta)
                    del delta
                if v.dtype != dtype: v = v.to(dtype)
                buffer.add_tensor(k, v)
                del v
        os.remove(shard_file)
        gc.collect()
    buffer.flush()
    print(f"Uploading Index: {index_filename} (Size: {buffer.total_size})")
    index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map}
    with open(TempDir / index_filename, "w") as f:
        json.dump(index_data, f, indent=4)
    path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename
    api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
    cleanup_temp()
    return f"Done! Merged {buffer.shard_count} shards to {output_repo}"
# =================================================================================
# TAB 2: EXTRACT LORA
# =================================================================================
def identify_and_download_model(input_str, token):
    """
    Smart download:
    1. If the input is a direct URL -> downloads that specific file.
    2. If the input is a Repo ID -> scans for diffusers format (unet/transformer) or standard safetensors.
    """
    print(f"Resolving model input: {input_str}")
    # --- STRATEGY A: Direct URL ---
    repo_id_from_url, filename_from_url = parse_hf_url(input_str)
    if repo_id_from_url and filename_from_url:
        print(f"Detected Direct Link. Repo: {repo_id_from_url}, File: {filename_from_url}")
        local_path = TempDir / os.path.basename(filename_from_url)
        # Clean up a previous download if the name conflicts
        if local_path.exists(): os.remove(local_path)
        try:
            hf_hub_download(repo_id=repo_id_from_url, filename=filename_from_url, token=token, local_dir=TempDir)
            # Find where it landed (handling subfolders in local_dir)
            found = list(TempDir.rglob(os.path.basename(filename_from_url)))[0]
            return found
        except Exception as e:
            print(f"URL Download failed: {e}. Trying fallback...")
    # --- STRATEGY B: Repo Discovery (Auto-Detect) ---
    # If we are here, input_str is treated as a Repo ID (e.g. "ostris/Z-Image-De-Turbo")
    print(f"Scanning Repo {input_str} for model weights...")
    try:
        files = list_repo_files(repo_id=input_str, token=token)
    except Exception as e:
        raise ValueError(f"Failed to list repo '{input_str}'. If this is a URL, ensure it is formatted correctly. Error: {e}")
    # Priority list: diffusers layouts first, then any plausible single-file checkpoint
    priorities = [
        "transformer/diffusion_pytorch_model.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "model.safetensors",
        # Fallback: any safetensors that isn't an adapter or LoRA
        lambda f: f.endswith(".safetensors") and "lora" not in f and "adapter" not in f and "extracted" not in f
    ]
    target_file = None
    for p in priorities:
        if callable(p):
            candidates = [f for f in files if p(f)]
            if candidates:
                # Pick the first candidate (repo listing order)
                target_file = candidates[0]
                break
        elif p in files:
            target_file = p
            break
    if not target_file:
        raise ValueError(f"Could not find a valid model weight file in {input_str}. Ensure it contains .safetensors weights.")
    print(f"Downloading auto-detected weight file: {target_file}")
    hf_hub_download(repo_id=input_str, filename=target_file, token=token, local_dir=TempDir)
    # Locate the actual local path
    found = list(TempDir.rglob(os.path.basename(target_file)))[0]
    return found
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    lora_sd = {}
    print("Calculating diffs & extracting LoRA...")
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        # Only keys present in both models can produce a diff
        keys = set(org.keys()).intersection(set(tuned.keys()))
        for key in tqdm(keys, desc="Extracting"):
            # Skip integer buffers/metadata
            if "num_batches_tracked" in key or "running_mean" in key or "running_var" in key:
                continue
            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()
            # Skip if shapes mismatch (shouldn't happen if the models match)
            if mat_org.shape != mat_tuned.shape: continue
            diff = mat_tuned - mat_org
            # Skip if there is no meaningful difference
            if torch.max(torch.abs(diff)) < 1e-4: continue
            out_dim = diff.shape[0]
            in_dim = diff.shape[1] if len(diff.shape) > 1 else 1
            r = min(rank, in_dim, out_dim)
            is_conv = len(diff.shape) == 4
            if is_conv: diff = diff.flatten(start_dim=1)
            elif len(diff.shape) == 1: diff = diff.unsqueeze(1)  # Handle biases if needed
            try:
                # svd_lowrank is far faster on CPU than full torch.linalg.svd;
                # clamp q so it never exceeds the matrix dimensions
                q = min(r + 4, min(diff.shape))
                U, S, V = torch.svd_lowrank(diff, q=q, niter=4)
                Vh = V.t()
                U = U[:, :r]
                S = S[:r]
                Vh = Vh[:r, :]
                # Merge S into U for the standard LoRA format (up = U @ diag(S))
                U = U @ torch.diag(S)
                # Clamp outliers at the requested quantile
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(torch.abs(dist), clamp)
                if hi_val > 0:
                    U = U.clamp(-hi_val, hi_val)
                    Vh = Vh.clamp(-hi_val, hi_val)
                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)
                stem = key.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
                lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
            except Exception as e:
                print(f"Skipping {key} due to error: {e}")
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
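# The extracted adapter uses kohya-style keys ("<stem>.lora_up.weight",
# "<stem>.lora_down.weight", "<stem>.alpha") with alpha set equal to the rank, so the
# effective scale alpha/rank is 1.0 and the pair reconstructs delta ≈ up @ down directly.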
def task_extract(hf_token, org, tun, rank, out):
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    try:
        print("Downloading Original Model...")
        p1 = identify_and_download_model(org, hf_token)
        print("Downloading Tuned Model...")
        p2 = identify_and_download_model(tun, hf_token)
        f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
        return "Done! Extracted to " + out
    except Exception as e:
        return f"Error: {e}"
# =================================================================================
# TAB 3: MERGE ADAPTERS (Multi-Method)
# =================================================================================
def load_full_state_dict(path):
    """Loads a safetensors file and normalizes keys for easier processing."""
    raw = load_file(path, device="cpu")
    cleaned = {}
    for k, v in raw.items():
        # Map peft-style keys to the standard "lora_up/lora_down" naming
        if "lora_A" in k: new_k = k.replace("lora_A", "lora_down")
        elif "lora_B" in k: new_k = k.replace("lora_B", "lora_up")
        else: new_k = k
        cleaned[new_k] = v.float()
    return cleaned
# --- Original EMA Method ---
def sigma_rel_to_gamma(sigma_rel):
    t = sigma_rel**-2
    coeffs = [1, 7, 16 - t, 12 - t]
    roots = np.roots(coeffs)
    gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
    return gamma
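# The cubic above is the power-function EMA relation from Karras et al. (2024),
# "Analyzing and Improving the Training Dynamics of Diffusion Models": the relative
# EMA width satisfies
#   sigma_rel^2 = (gamma + 1) / ((gamma + 2)^2 * (gamma + 3))
# Substituting t = sigma_rel^(-2) and expanding gives
#   gamma^3 + 7*gamma^2 + (16 - t)*gamma + (12 - t) = 0,
# whose largest non-negative real root is returned.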
def merge_lora_iterative_ema(paths, beta, sigma_rel):
    print("Executing Iterative EMA Merge (Original Method)...")
    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float()
    gamma = None
    if sigma_rel > 0:
        gamma = sigma_rel_to_gamma(sigma_rel)
    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            t = i + 1
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta
        curr = load_file(path, device="cpu")
        for k in base_sd:
            if k in curr and "alpha" not in k:
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)
    return base_sd
# --- New Concatenation Method (DiffSynth) ---
def merge_lora_concatenation(adapter_states, weights):
    """
    DiffSynth Method: Concatenates ranks.
    New Rank = sum(ranks). Lossless merging.
    """
    print("Executing Concatenation Merge (Rank Summation)...")
    merged_state = {}
    # Identify all stems (layers) present across all adapters
    all_stems = set()
    for state in adapter_states:
        for k in state.keys():
            stem = k.split(".lora_")[0]
            if "lora_" in k: all_stems.add(stem)
    for stem in tqdm(all_stems, desc="Concatenating Layers"):
        down_list = []
        up_list = []
        alpha_sum = 0.0
        for i, state in enumerate(adapter_states):
            w = weights[i]
            down_key = f"{stem}.lora_down.weight"
            up_key = f"{stem}.lora_up.weight"
            alpha_key = f"{stem}.alpha"
            if down_key in state and up_key in state:
                d = state[down_key]
                u = state[up_key] * w  # weighted contribution applied to UP
                down_list.append(d)
                up_list.append(u)
                if alpha_key in state:
                    alpha_sum += state[alpha_key].item()
                else:
                    alpha_sum += d.shape[0]
        if down_list and up_list:
            # Per the DiffSynth reference: lora_down (A) is (rank, in), so concat along
            # dim 0; lora_up (B) is (out, rank), so concat along dim 1.
            new_down = torch.cat(down_list, dim=0)  # (sum_rank, in)
            new_up = torch.cat(up_list, dim=1)      # (out, sum_rank)
            merged_state[f"{stem}.lora_down.weight"] = new_down.contiguous()
            merged_state[f"{stem}.lora_up.weight"] = new_up.contiguous()
            merged_state[f"{stem}.alpha"] = torch.tensor(alpha_sum)
    return merged_state
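# Rank bookkeeping for the concatenation method above: merging a rank-16 and a
# rank-32 adapter on one layer yields down (48, in) and up (out, 48), and
#   up_cat @ down_cat = w1 * (up1 @ down1) + w2 * (up2 @ down2)
# so the reconstruction is exact (no SVD truncation). Caveat: the summed alpha is
# applied uniformly, so the merge is only strictly lossless when the inputs share
# the same alpha/rank ratio.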
# --- New SVD/Task Arithmetic Method ---
def merge_lora_svd(adapter_states, weights, target_rank):
    """
    SVD / Task Arithmetic Method:
    1. Calculate Delta W for each adapter: dW = B @ A
    2. Sum Delta Ws: Total dW = sum(weight_i * dW_i)
    3. SVD(Total dW) -> New B, New A at target_rank
    """
    print(f"Executing SVD Merge (Target Rank: {target_rank})...")
    merged_state = {}
    all_stems = set()
    for state in adapter_states:
        for k in state.keys():
            stem = k.split(".lora_")[0]
            if "lora_" in k: all_stems.add(stem)
    for stem in tqdm(all_stems, desc="SVD Merging Layers"):
        total_delta = None
        valid_layer = False
        for i, state in enumerate(adapter_states):
            w = weights[i]
            down_key = f"{stem}.lora_down.weight"
            up_key = f"{stem}.lora_up.weight"
            alpha_key = f"{stem}.alpha"
            if down_key in state and up_key in state:
                down = state[down_key]
                up = state[up_key]
                alpha = state[alpha_key].item() if alpha_key in state else down.shape[0]
                rank = down.shape[0]
                scale = (alpha / rank) * w
                # Reconstruct the full-weight delta
                if len(down.shape) == 4:  # Conv2d
                    d_flat = down.flatten(start_dim=1)
                    u_flat = up.flatten(start_dim=1)
                    delta = (u_flat @ d_flat).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                else:
                    delta = up @ down
                delta = delta * scale
                if total_delta is None:
                    total_delta = delta
                    valid_layer = True
                else:
                    if total_delta.shape == delta.shape:
                        total_delta += delta
                    else:
                        print(f"Shape mismatch in {stem}, skipping.")
        if valid_layer and total_delta is not None:
            out_dim = total_delta.shape[0]
            in_dim = total_delta.shape[1]
            is_conv = len(total_delta.shape) == 4
            if is_conv:
                flat_delta = total_delta.flatten(start_dim=1)
            else:
                flat_delta = total_delta
            try:
                # Clamp q so svd_lowrank never asks for more components than exist,
                # and track the effective rank r actually available
                q = min(target_rank + 4, min(flat_delta.shape))
                U, S, V = torch.svd_lowrank(flat_delta, q=q, niter=4)
                r = min(target_rank, S.shape[0])
                Vh = V.t()
                U = U[:, :r]
                S = S[:r]
                Vh = Vh[:r, :]
                U = U @ torch.diag(S)
                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, total_delta.shape[2], total_delta.shape[3])
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)
                merged_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
                merged_state[f"{stem}.lora_up.weight"] = U.contiguous()
                merged_state[f"{stem}.alpha"] = torch.tensor(r).float()
            except Exception as e:
                print(f"SVD Failed for {stem}: {e}")
    return merged_state
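# Unlike concatenation, the SVD merge is lossy: truncating to target_rank keeps the
# best rank-r approximation of the summed delta in the Frobenius-norm sense
# (Eckart-Young-Mirsky), so strongly dissimilar adapters can lose detail at low ranks.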
def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    if not out_repo or not out_repo.strip():
        return "Error: Output Repo cannot be empty."
    # 1. Parse Inputs (Multi-line support)
    raw_lines = inputs_text.replace(" ", "\n").split('\n')
    urls = [line.strip() for line in raw_lines if line.strip()]
    if len(urls) < 2: return "Error: Please provide at least 2 adapters."
    # 2. Parse Weights (for SVD/Concatenation)
    try:
        if not weight_str.strip():
            weights = [1.0] * len(urls)
        else:
            weights = [float(w.strip()) for w in weight_str.split(',')]
            # Broadcast or Truncate
            if len(weights) < len(urls):
                weights += [1.0] * (len(urls) - len(weights))
            else:
                weights = weights[:len(urls)]
    except ValueError:
        return "Error parsing weights. Use format: 1.0, 0.5, 0.8"
    # 3. Download All
    paths = []
    try:
        for url in tqdm(urls, desc="Downloading Adapters"):
            paths.append(download_lora_smart(url, hf_token))
    except Exception as e:
        return f"Download Error: {e}"
    merged = None
    # 4. Execute Selected Method
    if "Iterative EMA" in method:
        # Calls the original method logic exactly
        merged = merge_lora_iterative_ema(paths, beta, sigma_rel)
    else:
        # The newer methods load everything upfront
        states = [load_full_state_dict(p) for p in paths]
        if "Concatenation" in method:
            merged = merge_lora_concatenation(states, weights)
        elif "SVD" in method:
            merged = merge_lora_svd(states, weights, int(target_rank))
    if not merged: return "Merge failed (Result empty)."
    # 5. Save & Upload
    out = TempDir / "merged_adapters.safetensors"
    save_file(merged, out)
    try:
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
        return f"Success! Merged to {out_repo}"
    except Exception as e:
        return f"Upload Error: {e}"
# =================================================================================
# TAB 4: RESIZE (CPU Optimized)
# =================================================================================
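# task_resize() below references index_sv_ratio(), which was missing from this file.
# This is a minimal sketch following the kohya-ss resize_lora convention (assumed):
# keep every singular value whose ratio to the largest one stays below `ratio`.
def index_sv_ratio(S, ratio):
    max_sv = S[0]
    min_sv = max_sv / ratio
    index = int(torch.sum(S > min_sv).item())
    return max(index, 1)  # always keep at least one component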
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    cleanup_temp()
    if not hf_token: return "Error: Token required"
    login(hf_token.strip())
    try:
        path = download_lora_smart(lora_input, hf_token)
    except Exception as e:
        return f"Error: {e}"
    state = load_file(path, device="cpu")
    new_state = {}
    groups = {}
    for k in state:
        # Strip ".alpha" first so alpha keys land in the same group as their pair
        simple = k.replace(".alpha", "").split(".lora_")[0]
        if simple not in groups: groups[simple] = {}
        if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
        if "alpha" in k: groups[simple]["alpha"] = state[k]
    print(f"Resizing {len(groups)} blocks...")
    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()
            # Preserve the effective scale: the original delta is (alpha/rank) * up @ down,
            # while the resized pair is saved with alpha = rank (scale 1.0)
            rank = down.shape[0]
            alpha = g["alpha"].item() if "alpha" in g else float(rank)
            # 1. Merge Up/Down into the full delta
            if len(down.shape) == 4:
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                flat = merged.flatten(1)
            else:
                merged = up @ down
                flat = merged
            flat = flat * (alpha / rank)
            # 2. FAST SVD (svd_lowrank)
            target_rank = int(new_rank)
            # Add buffer to q to ensure convergence
            q = min(target_rank + 10, min(flat.shape))
            U, S, V = torch.svd_lowrank(flat, q=q)
            Vh = V.t()
            # 3. Dynamic Rank Selection
            if dynamic_method == "sv_ratio":
                target_rank = index_sv_ratio(S, dynamic_param)
            # Hard limit by the user's max rank
            target_rank = min(target_rank, int(new_rank), S.shape[0])
            # 4. Truncate
            U = U[:, :target_rank]
            S = S[:target_rank]
            Vh = Vh[:target_rank, :]
            # 5. Reconstruct the Up matrix (absorb singular values)
            U = U @ torch.diag(S)
            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], target_rank, 1, 1)
                Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])
            # 6. Save (safetensors requires contiguous tensors)
            new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
            new_state[f"{stem}.lora_up.weight"] = U.contiguous()
            new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()
    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
# =================================================================================
# UI
# =================================================================================
css = ".container { max-width: 900px; margin: auto; }"
with gr.Blocks() as demo:
gr.Markdown("# 🧰SOONmerge® LoRA Toolkit")
with gr.Tabs():
with gr.Tab("Merge to Base + Reshard Output"):
t1_token = gr.Textbox(label="Token", type="password")
t1_base = gr.Textbox(label="Base Repo (Diffusers)", value="ostris/Z-Image-De-Turbo")
t1_sub = gr.Textbox(label="Subfolder", value="transformer")
t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
with gr.Row():
t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
t1_out = gr.Textbox(label="Output Repo")
t1_struct = gr.Textbox(label="Diffusers Extras (Copies VAE/TextEnc/etc)", value="Tongyi-MAI/Z-Image-Turbo")
t1_priv = gr.Checkbox(label="Private", value=True)
t1_btn = gr.Button("Merge")
t1_res = gr.Textbox(label="Result")
t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)
with gr.Tab("Extract Adapter"):
t2_token = gr.Textbox(label="Token", type="password")
t2_org = gr.Textbox(label="Original Model")
t2_tun = gr.Textbox(label="Tuned Model")
t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
t2_out = gr.Textbox(label="Output Repo")
t2_btn = gr.Button("Extract")
t2_res = gr.Textbox(label="Result")
t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
with gr.Tab("Merge Multiple Adapters"):
gr.Markdown("### Batch Adapter Merging")
t3_token = gr.Textbox(label="Token", type="password")
t3_urls = gr.TextArea(label="Adapter URLs/Repos (One per line, or space separated)", placeholder="ostris/lora1\nhttps://hf.co/user/lora2.safetensors\n...")
with gr.Row():
t3_method = gr.Dropdown(
["Iterative EMA (Original Beta/Sigma)", "Concatenation (DiffSynth - Lossless)", "SVD Merge (Task Arithmetic/Compressed)"],
value="Iterative EMA (Original Beta/Sigma)",
label="Merge Method"
)
with gr.Row():
t3_weights = gr.Textbox(label="Weights (Comma separated) - For Concat/SVD", placeholder="1.0, 0.5, 0.8...")
t3_rank = gr.Number(label="Target Rank - For SVD only", value=128, minimum=4, maximum=1024)
with gr.Row():
t3_beta = gr.Slider(label="Beta - For EMA only", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
t3_sigma = gr.Slider(label="Sigma Rel - For EMA only", value=0.21, minimum=0.01, maximum=1.00, step=0.01)
t3_out = gr.Textbox(label="Output Repo")
t3_priv = gr.Checkbox(label="Private Output", value=True)
t3_btn = gr.Button("Merge Adapters")
t3_res = gr.Textbox(label="Result")
t3_btn.click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], t3_res)
with gr.Tab("Resize Adapter"):
t4_token = gr.Textbox(label="Token", type="password")
t4_in = gr.Textbox(label="LoRA")
with gr.Row():
t4_rank = gr.Number(label="To Rank (Lower Only!)", value=8, minimum=1, maximum=256, step=1)
t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
t4_param = gr.Number(label="Dynamic Param", value=4.0)
t4_out = gr.Textbox(label="Output")
t4_btn = gr.Button("Resize")
t4_res = gr.Textbox(label="Result")
t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)
if __name__ == "__main__":
demo.queue().launch(css=css, ssr_mode=False)