import gradio as gr import torch import os import gc import shutil import requests import json import struct import numpy as np import re from pathlib import Path from typing import Dict, Any, Optional, List from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login from safetensors.torch import load_file, save_file from tqdm import tqdm # --- Memory Efficient Safetensors --- class MemoryEfficientSafeOpen: """ Reads safetensors metadata and tensors without mmap, keeping RAM usage low. """ def __init__(self, filename): self.filename = filename self.file = open(filename, "rb") self.header, self.header_size = self._read_header() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.file.close() def keys(self) -> list[str]: return [k for k in self.header.keys() if k != "__metadata__"] def metadata(self) -> Dict[str, str]: return self.header.get("__metadata__", {}) def get_tensor(self, key): if key not in self.header: raise KeyError(f"Tensor '{key}' not found in the file") metadata = self.header[key] offset_start, offset_end = metadata["data_offsets"] self.file.seek(self.header_size + 8 + offset_start) tensor_bytes = self.file.read(offset_end - offset_start) return self._deserialize_tensor(tensor_bytes, metadata) def _read_header(self): header_size = struct.unpack("= self.max_bytes: self.flush() def flush(self): if not self.buffer: return self.shard_count += 1 # Naming: prefix-0000X.safetensors # This is standard for indexed loading. filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors" # Proper Subfolder Handling path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...") header = {"__metadata__": {"format": "pt"}} current_offset = 0 for item in self.buffer: header[item["key"]] = { "dtype": item["dtype"], "shape": item["shape"], "data_offsets": [current_offset, current_offset + len(item["data"])] } current_offset += len(item["data"]) self.index_map[item["key"]] = filename # Relative filename for index header_json = json.dumps(header).encode('utf-8') out_path = self.output_dir / filename with open(out_path, 'wb') as f: f.write(struct.pack(' force 'diffusion_pytorch_model' if output_subfolder in ["transformer", "unet"]: filename_prefix = "diffusion_pytorch_model" index_filename = "diffusion_pytorch_model.safetensors.index.json" # 2. Check input file naming -> adopt input convention elif "diffusion_pytorch_model" in os.path.basename(input_shards[0]): filename_prefix = "diffusion_pytorch_model" index_filename = "diffusion_pytorch_model.safetensors.index.json" # 3. Default to LLM style else: filename_prefix = "model" index_filename = "model.safetensors.index.json" print(f"Naming scheme: {filename_prefix} (Index: {index_filename})") # 4. Load LoRA dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32 try: progress(0.15, desc="Downloading LoRA...") lora_path = download_lora_smart(lora_input, hf_token) lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype) except Exception as e: return f"Error loading LoRA: {e}" # 5. Stream Process buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix) for i, shard_file in enumerate(input_shards): progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {os.path.basename(shard_file)}") with MemoryEfficientSafeOpen(shard_file) as f: keys = f.keys() for k in keys: v = f.get_tensor(k) base_stem = get_key_stem(k) lora_keys = set(lora_pairs.keys()) match = None if base_stem in lora_keys: match = lora_pairs[base_stem] # QKV Heuristics (Z-Image/Flux specific) if not match: if "to_q" in base_stem: qkv_stem = base_stem.replace("to_q", "qkv") if qkv_stem in lora_keys: match = lora_pairs[qkv_stem] elif "to_k" in base_stem: qkv_stem = base_stem.replace("to_k", "qkv") if qkv_stem in lora_keys: match = lora_pairs[qkv_stem] elif "to_v" in base_stem: qkv_stem = base_stem.replace("to_v", "qkv") if qkv_stem in lora_keys: match = lora_pairs[qkv_stem] if match and "down" in match and "up" in match: down = match["down"] up = match["up"] scaling = scale * (match["alpha"] / match["rank"]) if len(v.shape) == 4 and len(down.shape) == 2: down = down.unsqueeze(-1).unsqueeze(-1) up = up.unsqueeze(-1).unsqueeze(-1) try: if len(up.shape) == 4: delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1) else: delta = up @ down except: delta = up.T @ down delta = delta * scaling valid_delta = True if delta.shape == v.shape: pass elif delta.shape[0] == v.shape[0] * 3: chunk = v.shape[0] if "to_q" in k: delta = delta[0:chunk, ...] elif "to_k" in k: delta = delta[chunk:2*chunk, ...] elif "to_v" in k: delta = delta[2*chunk:, ...] else: valid_delta = False elif delta.numel() == v.numel(): delta = delta.reshape(v.shape) else: valid_delta = False if valid_delta: v = v.to(dtype) delta = delta.to(dtype) v.add_(delta) del delta if v.dtype != dtype: v = v.to(dtype) buffer.add_tensor(k, v) del v os.remove(shard_file) gc.collect() buffer.flush() # 6. Upload Index (Now using correct total_size) print(f"Uploading Index: {index_filename} (Total Size: {buffer.total_size})") index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map} with open(TempDir / index_filename, "w") as f: json.dump(index_data, f, indent=4) path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token) cleanup_temp() return f"Done! Merged into {buffer.shard_count} shards at {output_repo}" # ================================================================================= # TAB 2: EXTRACT LORA # ================================================================================= def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp): org = MemoryEfficientSafeOpen(model_org) tuned = MemoryEfficientSafeOpen(model_tuned) lora_sd = {} print("Calculating diffs...") for key in tqdm(org.keys()): if key not in tuned.keys(): continue mat_org = org.get_tensor(key).float() mat_tuned = tuned.get_tensor(key).float() diff = mat_tuned - mat_org if torch.max(torch.abs(diff)) < 1e-4: continue out_dim, in_dim = diff.shape[:2] r = min(rank, in_dim, out_dim) is_conv = len(diff.shape) == 4 if is_conv: diff = diff.flatten(start_dim=1) try: U, S, Vh = torch.linalg.svd(diff, full_matrices=False) U, S, Vh = U[:, :r], S[:r], Vh[:r, :] U = U @ torch.diag(S) dist = torch.cat([U.flatten(), Vh.flatten()]) hi_val = torch.quantile(dist, clamp) U = U.clamp(-hi_val, hi_val) Vh = Vh.clamp(-hi_val, hi_val) if is_conv: U = U.reshape(out_dim, r, 1, 1) Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3]) else: U = U.reshape(out_dim, r) Vh = Vh.reshape(r, in_dim) stem = key.replace(".weight", "") lora_sd[f"{stem}.lora_up.weight"] = U lora_sd[f"{stem}.lora_down.weight"] = Vh lora_sd[f"{stem}.alpha"] = torch.tensor(r).float() except: pass out = TempDir / "extracted.safetensors" save_file(lora_sd, out) return str(out) def task_extract(hf_token, org, tun, rank, out): cleanup_temp() login(hf_token) try: p1 = download_file(org, hf_token, filename="org.safetensors") p2 = download_file(tun, hf_token, filename="tun.safetensors") f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99) api.create_repo(repo_id=out, exist_ok=True, token=hf_token) api.upload_file(path_or_fileobj=f, path_in_repo="extracted.safetensors", repo_id=out, token=hf_token) return "Done" except Exception as e: return f"Error: {e}" # ================================================================================= # TAB 3: MERGE ADAPTERS (EMA) with Sigma Rel # ================================================================================= def sigma_rel_to_gamma(sigma_rel): t = sigma_rel**-2 coeffs = [1, 7, 16 - t, 12 - t] roots = np.roots(coeffs) gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max() return gamma def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo): cleanup_temp() login(hf_token) urls = [u.strip() for u in lora_urls.split(",") if u.strip()] paths = [] try: for i, url in enumerate(urls): paths.append(download_file(url, hf_token, filename=f"a_{i}.safetensors")) except Exception as e: return f"Download Error: {e}" if not paths: return "No models found" base_sd = load_file(paths[0], device="cpu") for k in base_sd: if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float() gamma = None if sigma_rel > 0: gamma = sigma_rel_to_gamma(sigma_rel) for i, path in enumerate(paths[1:]): print(f"Merging {path}") if gamma is not None: t = i + 1 current_beta = (1 - 1 / t) ** (gamma + 1) else: current_beta = beta curr = load_file(path, device="cpu") for k in base_sd: if k in curr and "alpha" not in k: base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta) out = TempDir / "merged_adapters.safetensors" save_file(base_sd, out) api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token) api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token) return "Done" # ================================================================================= # TAB 4: RESIZE # ================================================================================= def index_sv_ratio(S, target): max_sv = S[0] min_sv = max_sv / target index = int(torch.sum(S > min_sv).item()) return max(1, min(index, len(S) - 1)) def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo): cleanup_temp() login(hf_token) try: path = download_file(lora_input, hf_token) except Exception as e: return f"Error: {e}" state = load_file(path, device="cpu") new_state = {} groups = {} for k in state: stem = get_key_stem(k) simple = k.split(".lora_")[0] if simple not in groups: groups[simple] = {} if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k] if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k] if "alpha" in k: groups[simple]["alpha"] = state[k] for stem, g in tqdm(groups.items()): if "down" in g and "up" in g: down, up = g["down"].float(), g["up"].float() if len(down.shape) == 4: merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3]) flat = merged.flatten(1) else: merged = up @ down flat = merged U, S, Vh = torch.linalg.svd(flat, full_matrices=False) target_rank = int(new_rank) if dynamic_method == "sv_ratio": target_rank = index_sv_ratio(S, dynamic_param) target_rank = min(target_rank, S.shape[0]) U = U[:, :target_rank] S = S[:target_rank] U = U @ torch.diag(S) Vh = Vh[:target_rank, :] if len(down.shape) == 4: U = U.reshape(up.shape[0], target_rank, 1, 1) Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3]) new_state[f"{stem}.lora_down.weight"] = Vh new_state[f"{stem}.lora_up.weight"] = U new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float() out = TempDir / "resized.safetensors" save_file(new_state, out) api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token) api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token) return "Done" # ================================================================================= # UI # ================================================================================= css = ".container { max-width: 900px; margin: auto; }" with gr.Blocks() as demo: gr.Markdown("# 🧰SOONmerge® LoRA Toolkit") with gr.Tabs(): with gr.Tab("Merge to Base + Reshard Output"): t1_token = gr.Textbox(label="Token", type="password") t1_base = gr.Textbox(label="Base Repo (Diffusers)", value="ostris/Z-Image-De-Turbo") t1_sub = gr.Textbox(label="Subfolder", value="transformer") t1_lora = gr.Textbox(label="LoRA Direct Link", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors") with gr.Row(): t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1) t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision") t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1) t1_out = gr.Textbox(label="Output Repo") t1_struct = gr.Textbox(label="Diffusers Extras (Copies VAE/TextEnc/etc)", value="Tongyi-MAI/Z-Image-Turbo") t1_priv = gr.Checkbox(label="Private", value=True) t1_btn = gr.Button("Merge") t1_res = gr.Textbox(label="Result") t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res) with gr.Tab("Extract Adapter"): t2_token = gr.Textbox(label="Token", type="password") t2_org = gr.Textbox(label="Original Model") t2_tun = gr.Textbox(label="Tuned Model") t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1) t2_out = gr.Textbox(label="Output Repo") t2_btn = gr.Button("Extract") t2_res = gr.Textbox(label="Result") t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res) with gr.Tab("Merge Multiple Adapters"): t3_token = gr.Textbox(label="Token", type="password") t3_urls = gr.Textbox(label="URLs") with gr.Row(): t3_beta = gr.Slider(label="Beta", value=0.95, minimum=0.01, maximum=1.00, step=0.01) t3_sigma = gr.Slider(label="Sigma Rel (Overrides Beta)", value=0.21, minimum=0.01, maximum=1.00, step=0.01) t3_out = gr.Textbox(label="Output Repo") t3_btn = gr.Button("Merge") t3_res = gr.Textbox(label="Result") t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res) with gr.Tab("Resize Adapter"): t4_token = gr.Textbox(label="Token", type="password") t4_in = gr.Textbox(label="LoRA") with gr.Row(): t4_rank = gr.Number(label="To Rank (Lower Only!)", value=8, minimum=1, maximum=256, step=1) t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method") t4_param = gr.Number(label="Dynamic Param", value=4.0) t4_out = gr.Textbox(label="Output") t4_btn = gr.Button("Resize") t4_res = gr.Textbox(label="Result") t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res) if __name__ == "__main__": demo.queue().launch(css=css, ssr_mode=False)