# Soon_Merger / app77_worksBut_wDiTshardsOnly.py
# (Hugging Face page residue: author avatar "AlekseyCalvin's picture")
# Rename app.py to app77_worksBut_wDiTshardsOnly.py
# Commit 9312d5d (verified)
import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
import re
from pathlib import Path
from typing import Dict, Any, Optional, List
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm
# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """
    Reads safetensors metadata and tensors without mmap, keeping RAM usage low.

    Only the JSON header is parsed up front. Each ``get_tensor`` call seeks to that
    tensor's byte range and reads just that slice, so only the requested tensor's
    bytes are loaded per call.

    Use as a context manager (or call ``close()``) so the file handle is released.
    """

    # safetensors dtype tag -> torch dtype. Unlisted tags raise KeyError on read.
    _DTYPE_MAP = {
        "F64": torch.float64,
        "F32": torch.float32,
        "F16": torch.float16,
        "BF16": torch.bfloat16,
        "I64": torch.int64,
        "I32": torch.int32,
        "I16": torch.int16,
        "I8": torch.int8,
        "U8": torch.uint8,
        "BOOL": torch.bool,
    }

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        try:
            self.header, self.header_size = self._read_header()
        except Exception:
            # Do not leak the handle when the header is malformed.
            self.file.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying file handle; safe to call more than once."""
        if not self.file.closed:
            self.file.close()

    def keys(self) -> list[str]:
        """Return tensor names in header order, excluding the ``__metadata__`` entry."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        """Return the ``__metadata__`` dict from the header, or ``{}`` if absent."""
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """
        Read tensor ``key`` from disk and return it as a writable CPU tensor.

        Raises:
            KeyError: ``key`` is not a tensor in the file, or its dtype is unsupported.
            ValueError: the file ends before the tensor's declared byte range.
        """
        if key not in self.header or key == "__metadata__":
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # The data section starts after the 8-byte length prefix plus the JSON header.
        self.file.seek(self.header_size + 8 + offset_start)
        expected = offset_end - offset_start
        tensor_bytes = self.file.read(expected)
        if len(tensor_bytes) != expected:
            raise ValueError(
                f"Tensor '{key}' is truncated: expected {expected} bytes, got {len(tensor_bytes)}"
            )
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        """Read the little-endian u64 header length and parse the JSON header after it."""
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Reinterpret raw little-endian bytes as a tensor with the header's dtype and shape."""
        dtype = self._DTYPE_MAP[metadata["dtype"]]
        shape = metadata["shape"]
        if len(tensor_bytes) == 0:
            # torch.frombuffer rejects zero-length buffers.
            return torch.empty(shape, dtype=dtype)
        # bytearray is writable, so callers may safely modify the tensor in place
        # (frombuffer over immutable bytes yields a tensor that must not be written).
        return torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8).view(dtype).reshape(shape)
# --- Constants & Setup ---
# Scratch directory for downloads and shard staging. Prefer /tmp; fall back to a
# directory under the current working directory if /tmp cannot be created.
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

# Shared Hugging Face Hub API client used by all tabs.
api = HfApi()
def cleanup_temp():
    """Delete everything under the scratch directory, recreate it empty, then run the GC."""
    scratch = TempDir
    if scratch.exists():
        shutil.rmtree(scratch)
    scratch.mkdir(parents=True, exist_ok=True)
    gc.collect()
def download_file(input_path, token, filename=None):
    """
    Fetch a safetensors file from a URL or a Hub repo into TempDir.

    Args:
        input_path: An ``http(s)://`` URL, or a Hub repo id such as ``"user/repo"``.
        token: Hugging Face token (may be None for public repos). For URLs it is sent
            only to huggingface.co hosts.
        filename: Local file name under TempDir (default ``"model.safetensors"``).
            For Hub inputs, this name is tried as the repo file first. If the repo
            has no such file, the first ``.safetensors`` file in the repo listing is
            downloaded instead. With no ``filename``, ``"adapter_model.safetensors"``
            is tried first.

    Returns:
        ``TempDir / (filename or "model.safetensors")``, which exists on success.

    Raises:
        ValueError: the download failed.
    """
    from urllib.parse import urlparse

    local_name = filename if filename else "model.safetensors"
    local_path = TempDir / local_name

    if input_path.startswith("http"):
        print(f"Downloading {local_name} from URL...")
        headers = {}
        host = urlparse(input_path).netloc.lower()
        # Never leak the token to arbitrary hosts.
        if token and (host == "huggingface.co" or host.endswith(".huggingface.co")):
            headers["Authorization"] = f"Bearer {token}"
        try:
            with requests.get(input_path, stream=True, timeout=30, headers=headers) as response:
                response.raise_for_status()
                with open(local_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
        except Exception as e:
            local_path.unlink(missing_ok=True)  # do not leave a truncated file behind
            raise ValueError(f"Download failed: {e}") from e
        return local_path

    print(f"Downloading {local_name} from Hub...")
    # The caller's filename is a *local* name; only use it as the repo file if it exists there.
    repo_file = filename if filename else "adapter_model.safetensors"
    try:
        repo_files = list_repo_files(repo_id=input_path, token=token)
    except Exception as e:
        print(f"Could not list files in {input_path} ({e}); trying '{repo_file}'")
        repo_files = []
    if repo_files and repo_file not in repo_files:
        candidates = [f for f in repo_files if f.endswith(".safetensors")]
        if candidates:
            repo_file = candidates[0]

    try:
        downloaded = Path(
            hf_hub_download(
                repo_id=input_path,
                filename=repo_file,
                token=token,
                local_dir=TempDir,
                local_dir_use_symlinks=False,
            )
        )
    except Exception as e:
        raise ValueError(f"Hub download failed: {e}") from e

    # hf_hub_download mirrors the repo path under local_dir; move it to the name callers use.
    if downloaded.resolve() != local_path.resolve():
        shutil.move(str(downloaded), str(local_path))
    return local_path
def get_key_stem(key):
    """
    Reduce a state-dict key to the module stem shared by base weights and LoRA factors.

    Every occurrence of the weight/bias/LoRA/alpha markers is deleted, in a fixed
    order. Then known wrapper prefixes are stripped repeatedly until none match.
    For example, ``"base_model.model.transformer.x.lora_B.weight"`` becomes ``"x"``.
    """
    for marker in (".weight", ".bias", ".lora_down", ".lora_up", ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(marker, "")

    wrapper_prefixes = (
        "model.diffusion_model.",
        "diffusion_model.",
        "model.",
        "transformer.",
        "text_encoder.",
        "lora_unet_",
        "lora_te_",
        "base_model.model.",
    )
    stripped_any = True
    while stripped_any:
        stripped_any = False
        for prefix in wrapper_prefixes:
            shorter = key.removeprefix(prefix)
            if shorter != key:
                key, stripped_any = shorter, True
    return key
# =================================================================================
# TAB 1: GREEDY STREAMING RESHARDER + SERVER-SIDE COPY
# =================================================================================
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """
    Load a LoRA safetensors file and group its factors by module stem.

    Args:
        lora_path: Path to a LoRA ``.safetensors`` file using kohya (``lora_down``/``lora_up``)
            or PEFT (``lora_A``/``lora_B``) naming.
        precision_dtype: dtype the up/down matrices are cast to.

    Returns:
        Dict mapping ``get_key_stem`` stems to dicts with:
        ``"down"``/``"up"`` tensors (when present), ``"rank"`` (``down.shape[0]``, only
        when a down matrix exists) and ``"alpha"`` (the file's alpha for that stem,
        otherwise the rank, otherwise 1.0).

    LoRA bias tensors (``*.lora_B.bias`` etc.) are skipped. ``get_key_stem`` gives a
    bias the same stem as its weight, so a bias would otherwise overwrite the matrix.
    """
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    skipped_bias = 0
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
            continue
        is_down = "lora_down" in k or "lora_A" in k
        is_up = "lora_up" in k or "lora_B" in k
        if (is_down or is_up) and k.endswith(".bias"):
            skipped_bias += 1
            continue
        entry = pairs.setdefault(stem, {})
        if is_down:
            entry["down"] = v.to(dtype=precision_dtype)
            entry["rank"] = v.shape[0]
        elif is_up:
            entry["up"] = v.to(dtype=precision_dtype)
    if skipped_bias:
        print(f"Skipped {skipped_bias} LoRA bias tensors (not supported for merging).")
    for stem, entry in pairs.items():
        entry["alpha"] = alphas.get(stem, float(entry.get("rank", 1.0)))
    return pairs
class ShardBuffer:
    """
    Collects serialized tensors in RAM and writes them out as uploaded safetensors shards.

    Packing is greedy. Before adding a tensor that would push the buffer past
    ``max_size_gb``, the current buffer is flushed, so a shard exceeds the limit only
    when one tensor alone is larger than it. Each shard is written to ``output_dir``,
    uploaded to ``output_repo`` (under ``subfolder`` if set), then deleted locally.
    ``index_map`` records tensor name -> shard filename (relative to ``subfolder``),
    for building a ``*.safetensors.index.json``.
    """

    # torch dtype -> safetensors dtype tag.
    _DTYPE_NAMES = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
    }

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.buffer = []  # pending {"key", "data", "dtype", "shape"} records
        self.current_bytes = 0  # tensor bytes currently buffered
        self.total_bytes = 0  # tensor bytes added over this buffer's lifetime
        self.shard_count = 0
        self.index_map = {}  # tensor name -> shard filename

    def add_tensor(self, key, tensor):
        """
        Serialize ``tensor`` and buffer it under ``key``, flushing first if it would not fit.

        Raises:
            TypeError: the tensor's dtype has no safetensors tag in ``_DTYPE_NAMES``.
        """
        dtype_str = self._DTYPE_NAMES.get(tensor.dtype)
        if dtype_str is None:
            raise TypeError(f"Unsupported dtype {tensor.dtype} for tensor '{key}'")
        # Reinterpret the element bytes directly; works for bf16, which numpy lacks.
        raw_bytes = tensor.detach().cpu().contiguous().reshape(-1).view(torch.uint8).numpy().tobytes()
        size = len(raw_bytes)
        if self.buffer and self.current_bytes + size > self.max_bytes:
            self.flush()
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape),
        })
        self.current_bytes += size
        self.total_bytes += size
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def _write_shard(self, out_path):
        """Write the buffered tensors to ``out_path`` in safetensors layout."""
        header = {"__metadata__": {"format": "pt"}}
        offset = 0
        for item in self.buffer:
            end = offset + len(item["data"])
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": list(item["shape"]),
                "data_offsets": [offset, end],
            }
            offset = end
        header_json = json.dumps(header).encode("utf-8")
        # Pad with spaces (allowed by the format) so the data section is 8-byte aligned.
        header_json += b" " * (-len(header_json) % 8)
        with open(out_path, "wb") as f:
            f.write(struct.pack("<Q", len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])

    def flush(self):
        """Write the buffer to a new shard, upload it, and reset the buffer. No-op when empty."""
        if not self.buffer:
            return
        self.shard_count += 1
        filename = f"model-{self.shard_count:05d}.safetensors"
        # Upload path includes the subfolder; the index map stores the bare filename.
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename
        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")
        out_path = Path(self.output_dir) / filename
        try:
            self._write_shard(out_path)
            print(f"Uploading {path_in_repo}...")
            api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)
        finally:
            # Remove the local shard even when the upload fails.
            if os.path.exists(out_path):
                os.remove(out_path)
        for item in self.buffer:
            self.index_map[item["key"]] = filename
        self.buffer = []
        self.current_bytes = 0
        gc.collect()
def server_side_copy_structure(token, src_repo, dst_repo, ignore_prefix="transformer"):
    """
    Copy every file from ``src_repo`` to ``dst_repo`` except the ``ignore_prefix`` folder.

    ``HfApi`` offers no cross-repository copy (``CommitOperationCopy`` works only
    inside one repo). So each file is downloaded into a staging folder under TempDir,
    uploaded to ``dst_repo`` at the same path, and deleted before the next file.
    Extra disk usage peaks at roughly one file.

    Args:
        token: HF token with read access to ``src_repo`` and write access to ``dst_repo``.
        src_repo: Repo whose layout (VAE, text encoders, configs, ...) is cloned.
        dst_repo: Destination repo.
        ignore_prefix: Top-level folder to skip (e.g. ``"transformer"``). A falsy value
            copies everything. Only that exact folder is skipped, so ``"text_encoder"``
            does not skip ``"text_encoder_2"``. Root-level hidden files are always skipped.

    Returns:
        List of repo paths that failed to copy (empty on full success). Failures are
        printed rather than raised.
    """
    print(f"Scanning {src_repo} for structure cloning...")
    failed = []
    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)
    except Exception as e:
        print(f"Structure cloning failed: {e}")
        return failed

    folder = ignore_prefix.strip("/") if ignore_prefix else ""
    files_to_copy = [
        f for f in files
        if not f.startswith(".")  # hidden files such as .gitattributes
        and not (folder and (f == folder or f.startswith(folder + "/")))
    ]
    print(f"Found {len(files_to_copy)} files to copy (skipping {ignore_prefix})...")

    staging = TempDir / "_structure_copy"
    for f in tqdm(files_to_copy, desc="Copying structure files"):
        try:
            local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=staging)
            api.upload_file(path_or_fileobj=local, path_in_repo=f, repo_id=dst_repo, token=token)
        except Exception as e:
            print(f"Failed to copy {f}: {e}")
            failed.append(f)
        finally:
            (staging / f).unlink(missing_ok=True)
    shutil.rmtree(staging, ignore_errors=True)
    return failed
def _find_lora_match(base_stem, lora_pairs):
    """Return the LoRA entry for ``base_stem``, falling back to a fused ``qkv`` entry for to_q/to_k/to_v."""
    if base_stem in lora_pairs:
        return lora_pairs[base_stem]
    for proj in ("to_q", "to_k", "to_v"):
        if proj in base_stem:
            return lora_pairs.get(base_stem.replace(proj, "qkv"))
    return None


def _lora_delta_for(key, weight, match, scale):
    """
    Compute ``scale * alpha / rank * (up @ down)`` in float32, shaped like ``weight``.

    Supported cases:
    - linear (2-D) and conv (4-D) factors;
    - ``up`` stored transposed as (rank, out);
    - a fused-qkv delta applied to a separate to_q/to_k/to_v weight (row slice);
    - same-numel reshapes, e.g. a linear LoRA on a 1x1 conv.

    Returns None when the factors are malformed or the delta cannot fit ``weight``.
    """
    up = match["up"].float()
    down = match["down"].float()
    rank = match.get("rank") or 0
    if up.dim() < 2 or down.dim() < 2 or rank == 0:
        return None
    up_2d, down_2d = up.flatten(1), down.flatten(1)
    if up_2d.shape[1] == down_2d.shape[0]:
        # (out, r) @ (r, in*kh*kw) -> (out, in, kh, kw) for conv, (out, in) for linear.
        delta = (up_2d @ down_2d).reshape(up.shape[0], *down.shape[1:])
    elif up.dim() == 2 and down.dim() == 2 and up.shape[0] == down.shape[0]:
        # Some exporters store `up` transposed as (rank, out).
        delta = up.T @ down
    else:
        return None
    delta = delta * (scale * match["alpha"] / rank)

    if delta.shape == weight.shape:
        return delta
    if delta.shape[0] == weight.shape[0] * 3:
        chunk = weight.shape[0]
        for i, proj in enumerate(("to_q", "to_k", "to_v")):
            if proj in key:
                delta = delta[i * chunk:(i + 1) * chunk]
                break
        else:
            return None
        if delta.shape == weight.shape:
            return delta
    if delta.numel() == weight.numel():
        return delta.reshape(weight.shape)
    return None


def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    """
    Merge a LoRA into a (sharded) base model and re-shard the result into ``output_repo``.

    Steps:
    1. Create the output repo.
    2. Optionally clone every file outside ``base_subfolder`` from ``structure_repo``.
    3. Load the LoRA.
    4. Stream each base shard tensor by tensor: add the matching LoRA delta (computed
       in float32) to floating-point weights, cast floating-point tensors to
       ``precision``, and leave integer/bool tensors untouched.
    5. Repack into ``shard_size``-GB shards, uploading each as it fills.
    6. Upload an index named like the base model's (``diffusion_pytorch_model...`` for
       diffusers folders, ``model...`` otherwise).

    Returns a status string for the UI; failures are reported as strings, not raised.
    """
    cleanup_temp()
    login(hf_token)
    subfolder = (base_subfolder or "").strip("/")
    prefix = f"{subfolder}/" if subfolder else ""

    # 1. Output setup
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"

    # 2. Structure clone. Skip the folder we are about to write ourselves.
    if structure_repo:
        failed = server_side_copy_structure(hf_token, structure_repo, output_repo, subfolder or None)
        if failed:
            print(f"Warning: {len(failed)} structure files were not copied: {failed}")

    # 3. Load the LoRA. Factors stay float32 so deltas are not rounded twice;
    #    merged weights are cast to `precision` afterwards.
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.1, desc="Downloading LoRA...")
        lora_path = download_file(lora_input, hf_token, filename="adapter.safetensors")
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=torch.float32)
    except Exception as e:
        return f"Error loading LoRA: {e}"

    # 4. Locate base shards. Use an exact folder match so "transformer" does not also
    #    select "transformer_2/".
    progress(0.2, desc="Fetching File List...")
    try:
        files = list_repo_files(repo_id=base_repo, token=hf_token)
    except Exception as e:
        return f"Error listing base repo: {e}"
    input_shards = sorted(f for f in files if f.endswith(".safetensors") and f.startswith(prefix))
    if not input_shards:
        return "No base safetensors found."

    # Reuse the base folder's index filename so loaders find it
    # (diffusers: diffusion_pytorch_model.safetensors.index.json).
    base_indexes = [
        f for f in files
        if f.startswith(prefix) and "/" not in f[len(prefix):] and f.endswith(".safetensors.index.json")
    ]
    if base_indexes:
        index_name = base_indexes[0][len(prefix):]
    elif any(Path(f).name.startswith("diffusion_pytorch_model") for f in input_shards):
        index_name = "diffusion_pytorch_model.safetensors.index.json"
    else:
        index_name = "model.safetensors.index.json"

    buffer = ShardBuffer(shard_size, TempDir, output_repo, subfolder, hf_token)
    total_size = 0
    merged_count = 0
    try:
        for i, shard_file in enumerate(input_shards):
            progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {shard_file}")
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)
            with MemoryEfficientSafeOpen(local_shard) as reader:
                for k in reader.keys():
                    v = reader.get_tensor(k)
                    if v.is_floating_point():
                        # A bias has the same stem as its weight; never add a weight delta to it.
                        match = None
                        if not k.endswith(".bias"):
                            match = _find_lora_match(get_key_stem(k), lora_pairs)
                        if match and "down" in match and "up" in match:
                            delta = _lora_delta_for(k, v, match, scale)
                            if delta is None:
                                print(f"Skipping LoRA for {k}: shapes incompatible with {tuple(v.shape)}")
                            else:
                                v = v.float() + delta
                                merged_count += 1
                            del delta
                        v = v.to(dtype)
                    total_size += v.numel() * v.element_size()
                    buffer.add_tensor(k, v)
                    del v
            os.remove(local_shard)
            gc.collect()
        buffer.flush()
    except Exception as e:
        return f"Error during merge: {e}"

    # 5. Upload the index
    print(f"Applied LoRA deltas to {merged_count} tensors.")
    print("Uploading Index...")
    index_data = {"metadata": {"total_size": total_size}, "weight_map": buffer.index_map}
    index_path = TempDir / index_name
    with open(index_path, "w") as f:
        json.dump(index_data, f, indent=4)
    try:
        api.upload_file(path_or_fileobj=index_path, path_in_repo=f"{prefix}{index_name}", repo_id=output_repo, token=hf_token)
    except Exception as e:
        return f"Error uploading index: {e}"
    cleanup_temp()
    message = f"Done! Merged into {buffer.shard_count} shards at {output_repo}"
    if merged_count == 0:
        message += " (warning: no LoRA weights matched the base model; output equals the base)"
    return message
# =================================================================================
# TAB 2: EXTRACT LORA
# =================================================================================
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """
    Extract a LoRA approximating ``model_tuned - model_org`` via per-layer truncated SVD.

    Which tensors are processed: every tensor present in both files with the same
    shape, 2 or 4 dims, and a max absolute difference of at least 1e-4. Everything
    else (biases, norms, unchanged layers) is skipped.

    How each is factorised: conv differences are flattened to 2-D. The difference
    becomes ``up = U[:, :r] @ diag(S[:r])`` and ``down = Vh[:r]`` with
    ``r = min(rank, out, in)``. Both factors are clamped to
    ``[-q, q]``, where ``q`` is the ``clamp`` quantile of their combined values.

    Args:
        model_org: Path to the original safetensors file.
        model_tuned: Path to the fine-tuned safetensors file.
        rank: Maximum LoRA rank.
        clamp: Quantile in (0, 1] used for clamping (e.g. 0.99).

    Returns:
        Path (str) of ``TempDir/extracted.safetensors``. Each layer has
        ``<stem>.lora_up.weight``, ``<stem>.lora_down.weight`` and
        ``<stem>.alpha`` (= r, i.e. scale 1.0).
    """
    lora_sd = {}
    skipped = 0
    print("Calculating diffs...")
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        tuned_keys = set(tuned.keys())
        for key in tqdm(org.keys()):
            if key not in tuned_keys:
                continue
            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()
            # Only matrices and conv kernels can be factorised.
            if mat_org.shape != mat_tuned.shape or mat_org.dim() not in (2, 4) or mat_org.numel() == 0:
                continue
            diff = mat_tuned - mat_org
            del mat_org, mat_tuned
            if torch.max(torch.abs(diff)) < 1e-4:
                continue
            out_dim, in_dim = diff.shape[:2]
            kernel = tuple(diff.shape[2:])
            r = min(rank, in_dim, out_dim)
            is_conv = diff.dim() == 4
            if is_conv:
                diff = diff.flatten(start_dim=1)
            try:
                U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
                U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
                U = U @ torch.diag(S)
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(dist, clamp)
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)
                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, *kernel)
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)
                stem = key.replace(".weight", "")
                # save_file requires contiguous tensors.
                lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
                lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
            except Exception as e:
                skipped += 1
                print(f"Skipping {key}: {e}")
    if skipped:
        print(f"{skipped} layers could not be factorised and were skipped.")
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
def task_extract(hf_token, org, tun, rank, out):
    """
    Download two checkpoints, extract a LoRA from their difference, and upload it.

    The LoRA is uploaded to repo ``out`` as ``extracted.safetensors``, using a clamp
    quantile of 0.99. Returns ``"Done"`` on success, otherwise ``"Error: <reason>"``.
    """
    cleanup_temp()
    login(hf_token)
    try:
        org_path = download_file(org, hf_token, filename="org.safetensors")
        tuned_path = download_file(tun, hf_token, filename="tun.safetensors")
        lora_file = extract_lora_layer_by_layer(org_path, tuned_path, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(
            path_or_fileobj=lora_file,
            path_in_repo="extracted.safetensors",
            repo_id=out,
            token=hf_token,
        )
    except Exception as e:
        return f"Error: {e}"
    return "Done"
# =================================================================================
# TAB 3: MERGE ADAPTERS (EMA) with Sigma Rel
# =================================================================================
def sigma_rel_to_gamma(sigma_rel):
    """
    Convert a relative EMA width ``sigma_rel`` into the power-function EMA exponent ``gamma``.

    Inverts sigma_rel = sqrt((gamma + 1) / ((gamma + 2)**2 * (gamma + 3))). With
    t = sigma_rel**-2 this becomes the cubic
    gamma**3 + 7*gamma**2 + (16 - t)*gamma + (12 - t) = 0, whose largest non-negative
    real root is returned.

    Raises:
        ValueError: ``sigma_rel`` is not in (0, 1/sqrt(12)]. Outside that range no
            gamma >= 0 exists.
    """
    if sigma_rel <= 0:
        raise ValueError(f"sigma_rel must be positive, got {sigma_rel}")
    t = sigma_rel**-2
    roots = np.roots([1, 7, 16 - t, 12 - t])
    # np.roots returns complex values; real roots may carry tiny imaginary noise.
    tol = 1e-9 * max(1.0, t)
    candidates = roots.real[(np.abs(roots.imag) <= tol) & (roots.real >= -tol)]
    if candidates.size == 0:
        raise ValueError(f"sigma_rel={sigma_rel} is too large; it must be <= {12 ** -0.5:.4f}")
    return max(0.0, float(candidates.max()))
def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
    """
    EMA-merge LoRA adapters (comma-separated, oldest first) and upload the result.

    The first adapter initialises the running average. Adapter k (k = 2, 3, ...) is
    blended in as ``avg = beta_k * avg + (1 - beta_k) * adapter_k``. With
    ``sigma_rel > 0``, beta_k follows the power-function EMA schedule
    ``(1 - 1/k) ** (gamma + 1)``; otherwise the fixed ``beta`` is used.
    Alpha tensors and non-floating-point tensors are kept from the first adapter.
    Keys missing from the first adapter are ignored, and tensors with mismatched
    shapes are left unmerged.

    Returns "Done" or an error string.
    """
    cleanup_temp()
    login(hf_token)
    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = []
    try:
        for i, url in enumerate(urls):
            paths.append(download_file(url, hf_token, filename=f"a_{i}.safetensors"))
    except Exception as e:
        return f"Download Error: {e}"
    if not paths:
        return "No models found"

    gamma = None
    if sigma_rel > 0:
        try:
            gamma = sigma_rel_to_gamma(sigma_rel)
        except ValueError as e:
            return f"Error: invalid Sigma Rel: {e}"

    # Input order is trusted as oldest -> newest.
    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point:
            base_sd[k] = base_sd[k].float()

    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            # paths[0] is sample t=1 (beta_1 = 0, so the average equals it);
            # paths[1 + i] is therefore sample t = i + 2.
            t = i + 2
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta
        curr = load_file(path, device="cpu")
        for k in base_sd:
            if k not in curr or "alpha" in k or not base_sd[k].is_floating_point():
                continue
            if curr[k].shape != base_sd[k].shape:
                print(f"Skipping {k}: shape {tuple(curr[k].shape)} != {tuple(base_sd[k].shape)}")
                continue
            base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)
        del curr
        gc.collect()

    out = TempDir / "merged_adapters.safetensors"
    try:
        save_file(base_sd, out)
        api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    except Exception as e:
        return f"Error: {e}"
    return "Done"
# =================================================================================
# TAB 4: RESIZE
# =================================================================================
def index_sv_ratio(S, target):
    """
    Choose a rank from singular values ``S`` (descending) using a max/min ratio.

    Counts the values strictly greater than ``S[0] / target``. The result is then
    clamped to ``[1, len(S) - 1]``, except that a length-1 ``S`` yields 1.
    """
    threshold = S[0] / target
    kept = int((S > threshold).sum().item())
    upper = len(S) - 1
    return max(1, min(kept, upper))
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """
    Re-factorise every LoRA module to a new rank via SVD and upload the result.

    For each module with both up and down weights:
    - Rebuild the effective delta ``(alpha / rank) * up @ down`` in float32. Alpha
      defaults to the rank when the file has no alpha for the module; conv factors
      are flattened to 2-D.
    - Take its SVD and keep ``new_rank`` components, or with
      ``dynamic_method == "sv_ratio"`` the count given by
      ``index_sv_ratio(S, dynamic_param)``.

    Output uses kohya naming (``.lora_down.weight``, ``.lora_up.weight``, ``.alpha``)
    with alpha equal to the new rank. The scale is therefore 1 and the original
    alpha/rank factor is folded into the weights. LoRA bias tensors are not carried over.

    Returns "Done" or an error string.
    """
    cleanup_temp()
    login(hf_token)
    try:
        path = download_file(lora_input, hf_token)
        state = load_file(path, device="cpu")
    except Exception as e:
        return f"Error: {e}"

    # Group factors and alphas by module name (the key before ".lora_" / ".alpha").
    groups = {}
    for k, v in state.items():
        if k.endswith(".alpha"):
            groups.setdefault(k[: -len(".alpha")], {})["alpha"] = v
            continue
        if not k.endswith(".weight"):
            continue  # e.g. lora_B.bias would otherwise overwrite the matrix
        simple = k.split(".lora_")[0]
        if "lora_down" in k or "lora_A" in k:
            groups.setdefault(simple, {})["down"] = v
        elif "lora_up" in k or "lora_B" in k:
            groups.setdefault(simple, {})["up"] = v

    fixed_rank = max(1, int(new_rank))
    new_state = {}
    for stem, g in tqdm(groups.items()):
        if "down" not in g or "up" not in g:
            continue
        down, up = g["down"].float(), g["up"].float()
        rank = down.shape[0]
        if rank == 0:
            continue
        alpha = float(g["alpha"]) if "alpha" in g else float(rank)
        try:
            # (out, r) @ (r, in*kh*kw); correct for linear, conv and rank-1 factors.
            flat = (up.flatten(1) @ down.flatten(1)) * (alpha / rank)
        except RuntimeError as e:
            print(f"Skipping {stem}: {e}")
            continue
        U, S, Vh = torch.linalg.svd(flat, full_matrices=False)

        target_rank = fixed_rank
        if dynamic_method == "sv_ratio":
            target_rank = index_sv_ratio(S, dynamic_param)
        target_rank = min(target_rank, S.shape[0])

        U = (U[:, :target_rank] @ torch.diag(S[:target_rank])).contiguous()
        Vh = Vh[:target_rank, :].contiguous()
        if down.dim() == 4:
            U = U.reshape(up.shape[0], target_rank, 1, 1)
            Vh = Vh.reshape(target_rank, *down.shape[1:])
        new_state[f"{stem}.lora_down.weight"] = Vh
        new_state[f"{stem}.lora_up.weight"] = U
        new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()

    out = TempDir / "resized.safetensors"
    try:
        save_file(new_state, out)
        api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
    except Exception as e:
        return f"Error: {e}"
    return "Done"
# =================================================================================
# UI
# =================================================================================
css = ".container { max-width: 900px; margin: auto; }"

# Sliders get explicit ranges/steps: Gradio's default (0-100, step 1) made Scale,
# Beta and Sigma Rel integer-only, which is invalid for EMA beta and sigma_rel.
with gr.Blocks() as demo:
    gr.Markdown("# 🧰 SOONmerge® LoRA Toolkit")
    with gr.Tabs():
        with gr.Tab("Merge + Reshard"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo", value="ostris/Z-Image-De-Turbo")
            t1_sub = gr.Textbox(label="Subfolder", value="transformer")
            t1_lora = gr.Textbox(label="LoRA")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=-2.0, maximum=4.0, step=0.05)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
            t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.5, maximum=10.0, step=0.5)
            t1_out = gr.Textbox(label="Output")
            t1_struct = gr.Textbox(label="Structure Repo (Copies VAE/TextEnc)", value="Tongyi-MAI/Z-Image-Turbo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)
        with gr.Tab("Extract"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original")
            t2_tun = gr.Textbox(label="Tuned")
            t2_rank = gr.Number(label="Rank", value=32, precision=0)
            t2_out = gr.Textbox(label="Output")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
        with gr.Tab("Merge Adapters (EMA)"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.Textbox(label="URLs")
            with gr.Row():
                t3_beta = gr.Slider(label="Beta", value=0.9, minimum=0.0, maximum=1.0, step=0.01)
                # sigma_rel must be <= 1/sqrt(12) ~= 0.2887; 0 disables it (fixed Beta is used).
                t3_sigma = gr.Slider(label="Sigma Rel (Overrides Beta)", value=0.0, minimum=0.0, maximum=0.285, step=0.005)
            t3_out = gr.Textbox(label="Output")
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)
        with gr.Tab("Resize"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="Rank", value=8, precision=0)
                t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=4.0)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

if __name__ == "__main__":
    demo.queue().launch(css=css, ssr_mode=False)