import gradio as gr import torch import os import gc from merge_utils import execute_mergekit import shutil import requests import json import struct import numpy as np import re import yaml from pathlib import Path from typing import Dict, Any, Optional, List from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login from safetensors.torch import load_file, save_file from tqdm import tqdm # --- Memory Efficient Safetensors --- class MemoryEfficientSafeOpen: def __init__(self, filename): self.filename = filename self.file = open(filename, "rb") self.header, self.header_size = self._read_header() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.file.close() def keys(self) -> list[str]: return [k for k in self.header.keys() if k != "__metadata__"] def metadata(self) -> Dict[str, str]: return self.header.get("__metadata__", {}) def get_tensor(self, key): if key not in self.header: raise KeyError(f"Tensor '{key}' not found in the file") metadata = self.header[key] offset_start, offset_end = metadata["data_offsets"] self.file.seek(self.header_size + 8 + offset_start) tensor_bytes = self.file.read(offset_end - offset_start) return self._deserialize_tensor(tensor_bytes, metadata) def _read_header(self): header_size = struct.unpack("= 2: parts = input_str.split("/") repo_id = f"{parts[0]}/{parts[1]}" filename = "/".join(parts[2:]) hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir) found = list(TempDir.rglob(filename.split("/")[-1]))[0] if found != local_path: shutil.move(found, local_path) return local_path # Standard Auto-Discovery candidates = ["adapter_model.safetensors", "model.safetensors"] files = list_repo_files(repo_id=input_str, token=token) target = next((f for f in files if f in candidates), None) if not target: safes = [f for f in files if f.endswith(".safetensors")] if safes: target = safes[0] if not target: raise ValueError("No safetensors found") hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir) found = list(TempDir.rglob(target.split("/")[-1]))[0] if found != local_path: shutil.move(found, local_path) return local_path except Exception as e: # 3. Last Resort: Raw Requests (For non-HF links) if input_str.startswith("http"): try: headers = {"Authorization": f"Bearer {token}"} if token else {} r = requests.get(input_str, stream=True, headers=headers, timeout=60) r.raise_for_status() with open(local_path, 'wb') as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) return local_path except Exception as req_e: raise ValueError(f"All download methods failed.\nRepo Logic Error: {e}\nURL Logic Error: {req_e}") raise e def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16): print(f"Loading LoRA from {lora_path}...") state_dict = load_file(lora_path, device="cpu") pairs = {} alphas = {} for k, v in state_dict.items(): stem = get_key_stem(k) if "alpha" in k: alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v else: if stem not in pairs: pairs[stem] = {} if "lora_down" in k or "lora_A" in k: pairs[stem]["down"] = v.to(dtype=precision_dtype) pairs[stem]["rank"] = v.shape[0] elif "lora_up" in k or "lora_B" in k: pairs[stem]["up"] = v.to(dtype=precision_dtype) for stem in pairs: pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0))) return pairs class ShardBuffer: def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"): self.max_bytes = int(max_size_gb * 1024**3) self.output_dir = output_dir self.output_repo = output_repo self.subfolder = subfolder self.hf_token = hf_token self.filename_prefix = filename_prefix self.buffer = [] self.current_bytes = 0 self.shard_count = 0 self.index_map = {} self.total_size = 0 def add_tensor(self, key, tensor): if tensor.dtype == torch.bfloat16: raw_bytes = tensor.view(torch.int16).numpy().tobytes() dtype_str = "BF16" elif tensor.dtype == torch.float16: raw_bytes = tensor.numpy().tobytes() dtype_str = "F16" else: raw_bytes = tensor.numpy().tobytes() dtype_str = "F32" size = len(raw_bytes) self.buffer.append({ "key": key, "data": raw_bytes, "dtype": dtype_str, "shape": tensor.shape }) self.current_bytes += size self.total_size += size if self.current_bytes >= self.max_bytes: self.flush() def flush(self): if not self.buffer: return self.shard_count += 1 filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors" path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...") header = {"__metadata__": {"format": "pt"}} current_offset = 0 for item in self.buffer: header[item["key"]] = { "dtype": item["dtype"], "shape": item["shape"], "data_offsets": [current_offset, current_offset + len(item["data"])] } current_offset += len(item["data"]) self.index_map[item["key"]] = filename header_json = json.dumps(header).encode('utf-8') out_path = self.output_dir / filename with open(out_path, 'wb') as f: f.write(struct.pack(' downloads specific file. 2. If input is a Repo ID -> scans for diffusers format (unet/transformer) or standard safetensors. """ print(f"Resolving model input: {input_str}") # --- STRATEGY A: Direct URL --- repo_id_from_url, filename_from_url = parse_hf_url(input_str) if repo_id_from_url and filename_from_url: print(f"Detected Direct Link. Repo: {repo_id_from_url}, File: {filename_from_url}") local_path = TempDir / os.path.basename(filename_from_url) # Clean up previous download if name conflicts if local_path.exists(): os.remove(local_path) try: hf_hub_download(repo_id=repo_id_from_url, filename=filename_from_url, token=token, local_dir=TempDir) # Find where it landed (handling subfolders in local_dir) found = list(TempDir.rglob(os.path.basename(filename_from_url)))[0] return found except Exception as e: print(f"URL Download failed: {e}. Trying fallback...") # --- STRATEGY B: Repo Discovery (Auto-Detect) --- # If we are here, input_str is treated as a Repo ID (e.g. "ostris/Z-Image-De-Turbo") print(f"Scanning Repo {input_str} for model weights...") try: files = list_repo_files(repo_id=input_str, token=token) except Exception as e: raise ValueError(f"Failed to list repo '{input_str}'. If this is a URL, ensure it is formatted correctly. Error: {e}") # Priority list for diffusers vs single file priorities = [ "transformer/diffusion_pytorch_model.safetensors", "unet/diffusion_pytorch_model.safetensors", "model.safetensors", # Fallback to any safetensors that isn't an adapter or lora lambda f: f.endswith(".safetensors") and "lora" not in f and "adapter" not in f and "extracted" not in f ] target_file = None for p in priorities: if callable(p): candidates = [f for f in files if p(f)] if candidates: # Pick the largest file if multiple candidates (heuristic for "main" model) target_file = candidates[0] break elif p in files: target_file = p break if not target_file: raise ValueError(f"Could not find a valid model weight file in {input_str}. Ensure it contains .safetensors weights.") print(f"Downloading auto-detected weight file: {target_file}") hf_hub_download(repo_id=input_str, filename=target_file, token=token, local_dir=TempDir) # Locate actual path found = list(TempDir.rglob(os.path.basename(target_file)))[0] return found def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp): org = MemoryEfficientSafeOpen(model_org) tuned = MemoryEfficientSafeOpen(model_tuned) lora_sd = {} print("Calculating diffs & extracting LoRA...") # Get intersection of keys keys = set(org.keys()).intersection(set(tuned.keys())) for key in tqdm(keys, desc="Extracting"): # Skip integer buffers/metadata if "num_batches_tracked" in key or "running_mean" in key or "running_var" in key: continue mat_org = org.get_tensor(key).float() mat_tuned = tuned.get_tensor(key).float() # Skip if shapes mismatch (shouldn't happen if models match) if mat_org.shape != mat_tuned.shape: continue diff = mat_tuned - mat_org # Skip if no difference if torch.max(torch.abs(diff)) < 1e-4: continue out_dim = diff.shape[0] in_dim = diff.shape[1] if len(diff.shape) > 1 else 1 r = min(rank, in_dim, out_dim) is_conv = len(diff.shape) == 4 if is_conv: diff = diff.flatten(start_dim=1) elif len(diff.shape) == 1: diff = diff.unsqueeze(1) # Handle biases if needed try: # Use svd_lowrank for massive speedup on CPU vs linalg.svd U, S, V = torch.svd_lowrank(diff, q=r+4, niter=4) Vh = V.t() U = U[:, :r] S = S[:r] Vh = Vh[:r, :] # Merge S into U for standard LoRA format U = U @ torch.diag(S) # Clamp outliers dist = torch.cat([U.flatten(), Vh.flatten()]) hi_val = torch.quantile(torch.abs(dist), clamp) if hi_val > 0: U = U.clamp(-hi_val, hi_val) Vh = Vh.clamp(-hi_val, hi_val) if is_conv: U = U.reshape(out_dim, r, 1, 1) Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3]) else: U = U.reshape(out_dim, r) Vh = Vh.reshape(r, in_dim) stem = key.replace(".weight", "") lora_sd[f"{stem}.lora_up.weight"] = U.contiguous() lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous() lora_sd[f"{stem}.alpha"] = torch.tensor(r).float() except Exception as e: print(f"Skipping {key} due to error: {e}") pass out = TempDir / "extracted.safetensors" save_file(lora_sd, out) return str(out) def task_extract(hf_token, org, tun, rank, out): cleanup_temp() if hf_token: login(hf_token.strip()) try: print("Downloading Original Model...") p1 = identify_and_download_model(org, hf_token) print("Downloading Tuned Model...") p2 = identify_and_download_model(tun, hf_token) f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99) api.create_repo(repo_id=out, exist_ok=True, token=hf_token) api.upload_file(path_or_fileobj=f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token) return "Done! Extracted to " + out except Exception as e: return f"Error: {e}" # ================================================================================= # TAB 3: MERGE ADAPTERS (Multi-Method) # ================================================================================= def load_full_state_dict(path): """Loads a safetensor file and cleans keys for easier processing.""" raw = load_file(path, device="cpu") cleaned = {} for k, v in raw.items(): # Map common keys to standard "lora_up/lora_down" if "lora_A" in k: new_k = k.replace("lora_A", "lora_down") elif "lora_B" in k: new_k = k.replace("lora_B", "lora_up") else: new_k = k cleaned[new_k] = v.float() return cleaned # --- Original EMA Method --- def sigma_rel_to_gamma(sigma_rel): t = sigma_rel**-2 coeffs = [1, 7, 16 - t, 12 - t] roots = np.roots(coeffs) gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max() return gamma def merge_lora_iterative_ema(paths, beta, sigma_rel): print("Executing Iterative EMA Merge (Original Method)...") base_sd = load_file(paths[0], device="cpu") for k in base_sd: if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float() gamma = None if sigma_rel > 0: gamma = sigma_rel_to_gamma(sigma_rel) for i, path in enumerate(paths[1:]): print(f"Merging {path}") if gamma is not None: t = i + 1 current_beta = (1 - 1 / t) ** (gamma + 1) else: current_beta = beta curr = load_file(path, device="cpu") for k in base_sd: if k in curr and "alpha" not in k: base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta) return base_sd # --- New Concatenation Method (DiffSynth) --- def merge_lora_concatenation(adapter_states, weights): """ DiffSynth Method: Concatenates ranks. New Rank = sum(ranks). Lossless merging. """ print("Executing Concatenation Merge (Rank Summation)...") merged_state = {} # Identify all stems (layers) present across all adapters all_stems = set() for state in adapter_states: for k in state.keys(): stem = k.split(".lora_")[0] if "lora_" in k: all_stems.add(stem) for stem in tqdm(all_stems, desc="Concatenating Layers"): down_list = [] up_list = [] alpha_sum = 0.0 for i, state in enumerate(adapter_states): w = weights[i] down_key = f"{stem}.lora_down.weight" up_key = f"{stem}.lora_up.weight" alpha_key = f"{stem}.alpha" if down_key in state and up_key in state: d = state[down_key] u = state[up_key] * w # weighted contribution applied to UP down_list.append(d) up_list.append(u) if alpha_key in state: alpha_sum += state[alpha_key].item() else: alpha_sum += d.shape[0] if down_list and up_list: # Concat Down (A) along dim 0 (output of A, input to B) - Wait, lora_A is (rank, in) # Concat Up (B) along dim 1 (input of B) - lora_B is (out, rank) # Reference: DiffSynth code: lora_A = concat(tensors_A, dim=0), lora_B = concat(tensors_B, dim=1) new_down = torch.cat(down_list, dim=0) # (sum_rank, in) new_up = torch.cat(up_list, dim=1) # (out, sum_rank) merged_state[f"{stem}.lora_down.weight"] = new_down.contiguous() merged_state[f"{stem}.lora_up.weight"] = new_up.contiguous() merged_state[f"{stem}.alpha"] = torch.tensor(alpha_sum) return merged_state # --- New SVD/Task Arithmetic Method --- def merge_lora_svd(adapter_states, weights, target_rank): """ SVD / Task Arithmetic Method: 1. Calculate Delta W for each adapter: dW = B @ A 2. Sum Delta Ws: Total dW = sum(weight_i * dW_i) 3. SVD(Total dW) -> New B, New A at target_rank """ print(f"Executing SVD Merge (Target Rank: {target_rank})...") merged_state = {} all_stems = set() for state in adapter_states: for k in state.keys(): stem = k.split(".lora_")[0] if "lora_" in k: all_stems.add(stem) for stem in tqdm(all_stems, desc="SVD Merging Layers"): total_delta = None valid_layer = False for i, state in enumerate(adapter_states): w = weights[i] down_key = f"{stem}.lora_down.weight" up_key = f"{stem}.lora_up.weight" alpha_key = f"{stem}.alpha" if down_key in state and up_key in state: down = state[down_key] up = state[up_key] alpha = state[alpha_key].item() if alpha_key in state else down.shape[0] rank = down.shape[0] scale = (alpha / rank) * w # Reconstruct Delta if len(down.shape) == 4: # Conv2d d_flat = down.flatten(start_dim=1) u_flat = up.flatten(start_dim=1) delta = (u_flat @ d_flat).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3]) else: delta = up @ down delta = delta * scale if total_delta is None: total_delta = delta valid_layer = True else: if total_delta.shape == delta.shape: total_delta += delta else: print(f"Shape mismatch in {stem}, skipping.") if valid_layer and total_delta is not None: out_dim = total_delta.shape[0] in_dim = total_delta.shape[1] is_conv = len(total_delta.shape) == 4 if is_conv: flat_delta = total_delta.flatten(start_dim=1) else: flat_delta = total_delta try: U, S, V = torch.svd_lowrank(flat_delta, q=target_rank + 4, niter=4) Vh = V.t() U = U[:, :target_rank] S = S[:target_rank] Vh = Vh[:target_rank, :] U = U @ torch.diag(S) if is_conv: U = U.reshape(out_dim, target_rank, 1, 1) Vh = Vh.reshape(target_rank, in_dim, total_delta.shape[2], total_delta.shape[3]) else: U = U.reshape(out_dim, target_rank) Vh = Vh.reshape(target_rank, in_dim) merged_state[f"{stem}.lora_down.weight"] = Vh.contiguous() merged_state[f"{stem}.lora_up.weight"] = U.contiguous() merged_state[f"{stem}.alpha"] = torch.tensor(target_rank).float() except Exception as e: print(f"SVD Failed for {stem}: {e}") return merged_state def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private): cleanup_temp() if hf_token: login(hf_token.strip()) if not out_repo or not out_repo.strip(): return "Error: Output Repo cannot be empty." # 1. Parse Inputs (Multi-line support) raw_lines = inputs_text.replace(" ", "\n").split('\n') urls = [line.strip() for line in raw_lines if line.strip()] if len(urls) < 2: return "Error: Please provide at least 2 adapters." # 2. Parse Weights (for SVD/Concatenation) try: if not weight_str.strip(): weights = [1.0] * len(urls) else: weights = [float(w.strip()) for w in weight_str.split(',')] # Broadcast or Truncate if len(weights) < len(urls): weights += [1.0] * (len(urls) - len(weights)) else: weights = weights[:len(urls)] except: return "Error parsing weights. Use format: 1.0, 0.5, 0.8" # 3. Download All paths = [] try: for url in tqdm(urls, desc="Downloading Adapters"): paths.append(download_lora_smart(url, hf_token)) except Exception as e: return f"Download Error: {e}" merged = None # 4. Execute Selected Method if "Iterative EMA" in method: # Calls the original method logic exactly merged = merge_lora_iterative_ema(paths, beta, sigma_rel) else: # For new methods, we load everything upfront states = [load_full_state_dict(p) for p in paths] if "Concatenation" in method: merged = merge_lora_concatenation(states, weights) elif "SVD" in method: merged = merge_lora_svd(states, weights, int(target_rank)) if not merged: return "Merge failed (Result empty)." # 5. Save & Upload out = TempDir / "merged_adapters.safetensors" save_file(merged, out) try: api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token) api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token) return f"Success! Merged to {out_repo}" except Exception as e: return f"Upload Error: {e}" # ================================================================================= # TAB 4: RESIZE (CPU Optimized) # ================================================================================= def index_sv_cumulative(S, target): """Cumulative sum retention.""" original_sum = float(torch.sum(S)) cumulative_sums = torch.cumsum(S, dim=0) / original_sum index = int(torch.searchsorted(cumulative_sums, target)) + 1 index = max(1, min(index, len(S) - 1)) return index def index_sv_fro(S, target): """Frobenius norm retention (squared sum).""" S_squared = S.pow(2) S_fro_sq = float(torch.sum(S_squared)) sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq index = int(torch.searchsorted(sum_S_squared, target**2)) + 1 index = max(1, min(index, len(S) - 1)) return index def index_sv_ratio(S, target): """Ratio between max and min singular value.""" max_sv = S[0] min_sv = max_sv / target index = int(torch.sum(S > min_sv).item()) index = max(1, min(index, len(S) - 1)) return index def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo): cleanup_temp() if not hf_token: return "Error: Token required" login(hf_token.strip()) try: path = download_lora_smart(lora_input, hf_token) except Exception as e: return f"Error: {e}" state = load_file(path, device="cpu") new_state = {} groups = {} for k in state: stem = get_key_stem(k) simple = k.split(".lora_")[0] if simple not in groups: groups[simple] = {} if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k] if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k] if "alpha" in k: groups[simple]["alpha"] = state[k] print(f"Resizing {len(groups)} blocks...") # Pre-parse user settings target_rank_limit = int(new_rank) if dynamic_method == "None": dynamic_method = None for stem, g in tqdm(groups.items()): if "down" in g and "up" in g: down, up = g["down"].float(), g["up"].float() # 1. Merge Up/Down to get full weight delta if len(down.shape) == 4: merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3]) flat = merged.flatten(1) else: merged = up @ down flat = merged # 2. FAST SVD (svd_lowrank) # Use the "To Rank" input as a computational hard limit + buffer. # This ensures we don't compute expensive full SVD for massive layers. q_limit = target_rank_limit + 32 # Buffer to allow dynamic methods some wiggle room before truncation q = min(q_limit, min(flat.shape)) U, S, V = torch.svd_lowrank(flat, q=q) Vh = V.t() # 3. Dynamic Rank Selection calculated_rank = target_rank_limit if dynamic_method == "sv_ratio": calculated_rank = index_sv_ratio(S, dynamic_param) elif dynamic_method == "sv_cumulative": calculated_rank = index_sv_cumulative(S, dynamic_param) elif dynamic_method == "sv_fro": calculated_rank = index_sv_fro(S, dynamic_param) # Apply Hard Limit (User's "To Rank") final_rank = min(calculated_rank, target_rank_limit, S.shape[0]) # 4. Truncate U = U[:, :final_rank] S = S[:final_rank] Vh = Vh[:final_rank, :] # 5. Reconstruct Up Matrix (Absorb S into U) U = U @ torch.diag(S) if len(down.shape) == 4: U = U.reshape(up.shape[0], final_rank, 1, 1) Vh = Vh.reshape(final_rank, down.shape[1], down.shape[2], down.shape[3]) # 6. Save (FIX: Enforce contiguous memory layout) new_state[f"{stem}.lora_down.weight"] = Vh.contiguous() new_state[f"{stem}.lora_up.weight"] = U.contiguous() new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float() out = TempDir / "shrunken_.safetensors" # safetensors requires contiguous tensors save_file(new_state, out) api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token) api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=hf_token) return "Done" # ================================================================================= # NEW TAB 5: FULL MODEL MERGER (MergeKit GUI Wrapper) # ================================================================================= def task_full_model_merge(hf_token, models_text, method, dtype, base, weights, density, layer_ranges, tok_src, shard_size, out_repo, private): cleanup_temp() if not hf_token or not out_repo: return "Error: Token and Output Repo required." login(hf_token.strip()) model_list = [m.strip() for m in models_text.split('\n') if m.strip()] if len(model_list) < 2: return "Error: Minimum 2 models required." # Parse Weights try: w_list = [float(w.strip()) for w in weights.split(',')] if weights else [1.0] * len(model_list) except: return "Error: Weights must be comma-separated numbers." config = build_full_merge_config( method=method, models=models, base_model=base if base else model_list[0], weights=weights_text, density=density, dtype=dtype, tokenizer_source=tok_src, layer_ranges=layer_ranges ) for i, m in enumerate(model_list): m_params = {"model": m, "parameters": {"weight": w_list[i] if i < len(w_list) else 1.0}} if method.lower() in ["ties", "dare_ties", "dare_linear"]: m_params["parameters"]["density"] = density config["models"].append(m_params) out_path = TempDir / "merged_model" try: # Pass shard size to our execute_mergekit helper execute_mergekit(config, str(out_path), shard_size) api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token) api.upload_folder(folder_path=str(out_path), repo_id=out_repo, token=hf_token) return f"Success! Model merged and uploaded to {out_repo}" except Exception as e: return f"Merge Error: {e}" # ================================================================================= # NEW TAB 6: MIXTURE OF EXPERTS (MoE Creator) # ================================================================================= def task_create_moe(hf_token, dtype, shard_size, base_model, experts_text, gate_mode, tok_src, out_repo, private): cleanup_temp() if not hf_token or not out_repo: return "Error: Token and Output Repo required." login(hf_token.strip()) experts = [e.strip() for e in experts_text.split('\n') if e.strip()] if not experts: return "Error: At least one expert model is required." config = { "method": "moe", "base_model": base_model, "dtype": dtype, "tokenizer_source": tok_src, "params": {"gate_mode": gate_mode}, "experts": [{"source_model": exp} for exp in experts] } out_path = TempDir / "moe_model" try: execute_mergekit(config, str(out_path), shard_size) api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token) api.upload_folder(folder_path=str(out_path), repo_id=out_repo, token=hf_token) return f"Success! MoE model uploaded to {out_repo}" except Exception as e: return f"MoE Build Error: {e}" # ================================================================================= # UI # ================================================================================= css = ".container { max-width: 900px; margin: auto; }" with gr.Blocks() as demo: title = gr.HTML( """

SOONmerge® Transform Transformers for FREE!

""", elem_id="title", ) gr.Markdown("# 🧰SOONmerge® LoRA Toolkit") with gr.Tabs(): with gr.Tab("Merge to Base Model + Reshard Output"): t1_token = gr.Textbox(label="Token", type="password") t1_base = gr.Textbox(label="Base Repo", value="name/repo") t1_sub = gr.Textbox(label="Subfolder (Optional)", value="") t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors") with gr.Row(): t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1) t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision") t1_shard = gr.Slider(label="Max Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1) t1_out = gr.Textbox(label="Output Repo") t1_struct = gr.Textbox(label="Extras Source (copies configs/components/etc)", value="name/repo") t1_priv = gr.Checkbox(label="Private", value=True) t1_btn = gr.Button("Merge") t1_res = gr.Textbox(label="Result") t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res) with gr.Tab("Extract Adapter"): t2_token = gr.Textbox(label="Token", type="password") t2_org = gr.Textbox(label="Original Model") t2_tun = gr.Textbox(label="Tuned or Homologous Model") t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1) t2_out = gr.Textbox(label="Output Repo") t2_btn = gr.Button("Extract") t2_res = gr.Textbox(label="Result") t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res) with gr.Tab("Merge Adapters/Weights"): gr.Markdown("### Batch Adapter Merging") t3_token = gr.Textbox(label="Token", type="password") t3_urls = gr.TextArea(label="Adapter URLs/Repos (one per line, or space-separated)", placeholder="user/lora1\nhttps://hf.co/user/lora2.safetensors\n...") with gr.Row(): t3_method = gr.Dropdown( ["Iterative EMA (Linear w/ Beta/Sigma coefficient)", "Concatenation (MOE-like weights-stack)", "SVD Fusion (Task Arithmetic/Compressed)"], value="Iterative EMA (Linear w/ Beta/Sigma coefficient)", label="Merge Method" ) with gr.Row(): t3_weights = gr.Textbox(label="Weights (comma-separated) – for Concat/SVD", placeholder="1.0, 0.5, 0.8...") t3_rank = gr.Number(label="Target Rank – For SVD only", value=128, minimum=4, maximum=1024) with gr.Row(): t3_beta = gr.Slider(label="Beta – for linear/post-hoc EMA", value=0.95, minimum=0.01, maximum=1.00, step=0.01) t3_sigma = gr.Slider(label="Sigma Rel – for linear/post-hoc EMA", value=0.21, minimum=0.01, maximum=1.00, step=0.01) t3_out = gr.Textbox(label="Output Repo") t3_priv = gr.Checkbox(label="Private Output", value=True) t3_btn = gr.Button("Merge") t3_res = gr.Textbox(label="Result") t3_btn.click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], t3_res) with gr.Tab("Resize Adapter"): t4_token = gr.Textbox(label="Token", type="password") t4_in = gr.Textbox(label="LoRA") with gr.Row(): t4_rank = gr.Number(label="To Rank (Safety Ceiling)", value=8, minimum=1, maximum=512, step=1) t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None", label="Dynamic Method") t4_param = gr.Number(label="Dynamic Param", value=0.9) gr.Markdown( """ ### 📉 Dynamic Resizing Guide These methods intelligently determine the best rank per layer. * **sv_ratio (Relative Strength):** Keeps features that are at least `1/Param` as strong as the main feature. **Param must be >= 2**. (e.g. 2 = keep features half as strong as top). * **sv_fro (Visual Information Density):** Preserves `Param%` of the total information content (Frobenius Norm) of the layer. **Param between 0.0 and 1.0** (e.g. 0.9 = 90% info retention). * **sv_cumulative (Cumulative Sum):** Preserves weights that sum up to `Param%` of the total strength. **Param between 0.0 and 1.0**. * **⚠️ Safety Ceiling:** The **"To Rank"** slider acts as a hard limit. Even if a dynamic method wants a higher rank, it will be cut down to this number to keep file sizes small. """ ) t4_out = gr.Textbox(label="Output") t4_btn = gr.Button("Resize") t4_res = gr.Textbox(label="Result") t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res) # ================================================================================= # UPDATED TAB 5: FULL MODEL MERGER (MergeKit Engine) # ================================================================================= with gr.Tab("Full Model Merge (MergeKit)"): gr.Markdown("### 🧩 Multi-Model Weight Fusion") with gr.Row(): t5_token = gr.Textbox(label="HF Token", type="password") t5_method = gr.Dropdown(["Linear", "SLERP", "TIES", "DARE_TIES", "DARE_LINEAR", "Model_Stock"], value="TIES", label="Merge Method") t5_dtype = gr.Radio(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision") t5_models = gr.TextArea(label="Models to Merge (One Repo ID per line)", placeholder="repo/model-a\nrepo/model-b\nrepo/model-c...") with gr.Row(): t5_base = gr.Textbox(label="Base Model (Required for TIES/DARE)", placeholder="repo/base-model") t5_shard = gr.Slider(0.5, 10, 2.0, step=0.5, label="Max Shard Size (GB)") with gr.Accordion("Advanced Parametrization", open=False): with gr.Row(): t5_weights = gr.Textbox(label="Weights (Comma separated)", placeholder="1.0, 0.5, 0.3") t5_density = gr.Slider(0, 1, 0.5, label="Density (TIES/DARE)") with gr.Row(): t5_layers = gr.Textbox(label="Layer Ranges (JSON Format)", placeholder='[{"start": 0, "end": 32}]') t5_tok_src = gr.Dropdown(["base", "union", "first"], value="base", label="Tokenizer Source") t5_out = gr.Textbox(label="Output Repo (User/Repo)") t5_priv = gr.Checkbox(label="Private Output", value=True) t5_btn = gr.Button("🚀 Execute Full Merge", variant="primary") t5_res = gr.Textbox(label="Result") t5_btn.click(task_full_model_merge, [t5_token, t5_models, t5_method, t5_dtype, t5_base, gr.State(""), t5_density, t5_shard, t5_out, t5_priv], t5_res) # ================================================================================= # UPDATED TAB 6: MIXTURE OF EXPERTS (MoE Creator) # ================================================================================= with gr.Tab("Create MoE"): gr.Markdown("### 🤖 Mixture of Experts Upscaling") with gr.Row(): t6_token = gr.Textbox(label="HF Token", type="password") t6_dtype = gr.Radio(["bfloat16", "float16", "float32"], value="bfloat16", label="Precision") t6_shard = gr.Slider(0.5, 10, 2.0, label="Shard Size (GB)") t6_base = gr.Textbox(label="Base Architecture Model", placeholder="repo/backbone-model") t6_experts = gr.TextArea(label="Experts (One per line)", placeholder="repo/expert-1\nrepo/expert-2...") with gr.Accordion("MoE Hyperparameters", open=True): with gr.Row(): t6_gate_mode = gr.Dropdown(["cheap_embed", "hidden", "random"], value="cheap_embed", label="Gating Mode") t6_tok_src = gr.Dropdown(["base", "union", "first"], value="base", label="Tokenizer Source") t6_out = gr.Textbox(label="Output Repo", placeholder="User/Repo") t6_priv = gr.Checkbox(label="Private", value=True) t6_btn = gr.Button("🏗️ Build MoE", variant="primary") t6_res = gr.Textbox(label="Result") t6_btn.click(task_create_moe, [t6_token, t6_dtype, t6_shard, t6_base, t6_experts, t6_gate_mode, t6_tok_src, t6_out, t6_priv], t6_res) if __name__ == "__main__": demo.queue().launch(css=css, ssr_mode=False)