Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import os | |
| import gc | |
| import shutil | |
| import requests | |
| import json | |
| import struct | |
| import numpy as np | |
| import re | |
| import yaml | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional, List | |
| from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login | |
| from safetensors.torch import load_file, save_file | |
| from tqdm import tqdm | |
| # --- Memory Efficient Safetensors --- | |
class MemoryEfficientSafeOpen:
    """Minimal lazy reader for ``.safetensors`` files.

    Only the JSON header is parsed on construction; each tensor is read from
    disk on demand by :meth:`get_tensor`, so a multi-GB shard never has to be
    resident in memory all at once.  Use it as a context manager (or call
    :meth:`close`) so the underlying file handle is released.
    """

    # safetensors dtype code -> torch dtype.
    _DTYPE_MAP = {
        "F64": torch.float64, "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
        "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
        "U8": torch.uint8, "BOOL": torch.bool,
    }
    # FP8 codes are registered only when the installed torch build has those dtypes.
    _DTYPE_MAP.update({
        code: getattr(torch, attr)
        for code, attr in (("F8_E4M3", "float8_e4m3fn"), ("F8_E5M2", "float8_e5m2"))
        if hasattr(torch, attr)
    })

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        try:
            self.header, self.header_size = self._read_header()
        except Exception:
            # Don't leak the handle when the header is truncated or corrupt.
            self.file.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying file handle (safe to call more than once)."""
        self.file.close()

    def keys(self) -> list[str]:
        """Return the tensor names stored in the file (header metadata excluded)."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        """Return the optional ``__metadata__`` string map ({} when absent)."""
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """Read tensor ``key`` from disk and return it as a CPU tensor.

        Raises KeyError if ``key`` is not a tensor in this file and ValueError if
        the file ends before the tensor's data does.
        """
        if key == "__metadata__" or key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Data section starts right after the 8-byte length prefix and the header.
        self.file.seek(self.header_size + 8 + offset_start)
        # A bytearray (not bytes) gives the returned tensor writable backing
        # memory: torch.frombuffer on immutable bytes warns, and any in-place
        # op on the result would otherwise write into an immutable object.
        tensor_bytes = bytearray(offset_end - offset_start)
        n_read = self.file.readinto(tensor_bytes)
        if n_read != len(tensor_bytes):
            raise ValueError(f"Truncated data for tensor '{key}' in {self.filename}")
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        """Return ``(header_dict, header_size)`` from the file's current start."""
        raw_size = self.file.read(8)
        if len(raw_size) != 8:
            raise ValueError(f"{self.filename} is too small to be a safetensors file")
        header_size = struct.unpack("<Q", raw_size)[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        """Reinterpret raw little-endian bytes as a tensor of the header's dtype/shape."""
        try:
            dtype = self._DTYPE_MAP[metadata["dtype"]]
        except KeyError:
            raise KeyError(f"Unsupported safetensors dtype '{metadata['dtype']}'") from None
        shape = metadata["shape"]
        if len(tensor_bytes) == 0:
            # torch.frombuffer rejects empty buffers, but zero-element tensors are legal.
            return torch.empty(shape, dtype=dtype)
        return torch.frombuffer(tensor_bytes, dtype=torch.uint8).view(dtype).reshape(shape)
| # --- Constants & Setup --- | |
# Scratch directory for downloads and intermediate outputs.  Prefer /tmp and
# fall back to a directory under the CWD when /tmp cannot be created/written.
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:  # was a bare except, which also swallowed KeyboardInterrupt
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)
# Shared Hub client used by the upload/listing helpers below.
api = HfApi()
def cleanup_temp():
    """Wipe and recreate the shared scratch directory ``TempDir``, then run the GC.

    NOTE(review): TempDir is process-global, so concurrent Gradio sessions share
    (and wipe) each other's scratch files.
    """
    if TempDir.exists():
        # Best-effort: a file that cannot be removed must not abort the task.
        shutil.rmtree(TempDir, ignore_errors=True)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()
def get_key_stem(key):
    """Normalise a state-dict key to a bare module path for base/LoRA matching.

    Every occurrence of the parameter/adapter tokens (``.weight``, ``.bias``,
    ``.lora_down``/``.lora_up``, ``.lora_A``/``.lora_B``, ``.alpha``) is deleted,
    then known wrapper prefixes are peeled off repeatedly until none applies.
    """
    for token in (".weight", ".bias", ".lora_down", ".lora_up", ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(token, "")

    wrapper_prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model.",
    )
    # Sweep over all prefixes until a full sweep leaves the key unchanged.
    while True:
        before_sweep = key
        for prefix in wrapper_prefixes:
            key = key.removeprefix(prefix)
        if key == before_sweep:
            return key
| # ================================================================================= | |
| # TAB 1: MERGE & RESHARD (Legacy Logic) | |
| # ================================================================================= | |
def parse_hf_url(url):
    """Split a Hub file URL into ``(repo_id, filename)``.

    Accepts ``https://huggingface.co/<owner>/<repo>/resolve/<rev>/<path>`` and
    the equivalent ``/blob/`` page URL (same file).  Query string and fragment
    are dropped and the path is percent-decoded.  Anything else returns
    ``(None, None)``.  That includes dataset/space URLs, whose repo type this
    2-tuple cannot express.

    NOTE: the ``<rev>`` segment is not returned, so callers download from the
    repo's default revision.
    """
    from urllib.parse import unquote

    if not isinstance(url, str) or "huggingface.co/" not in url:
        return None, None
    path = url.split("huggingface.co/")[-1]
    path = path.split("?", 1)[0].split("#", 1)[0]
    parts = path.split("/")
    # Model-repo file URLs are always <owner>/<repo>/(resolve|blob)/<rev>/<path...>.
    if len(parts) < 5 or parts[2] not in ("resolve", "blob"):
        return None, None
    repo_id = f"{parts[0]}/{parts[1]}"
    filename = unquote("/".join(parts[4:]))
    if not parts[0] or not parts[1] or not filename:
        return None, None
    return repo_id, filename
def _fetch_hub_file(repo_id, filename, token, local_path):
    """Download ``filename`` from ``repo_id`` into TempDir and move it to ``local_path``.

    Uses the path returned by ``hf_hub_download`` rather than searching TempDir
    by basename, which could grab an unrelated file of the same name (e.g. a
    base-model shard that is also called ``model.safetensors``).
    """
    downloaded = Path(hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir))
    if downloaded.absolute() != Path(local_path).absolute():
        shutil.move(str(downloaded), str(local_path))
    return local_path


def download_lora_smart(input_str, token):
    """Fetch a LoRA/adapter file and return its local path, ``TempDir/adapter.safetensors``.

    Resolution order:
      1. a Hub file URL (``.../resolve/<rev>/<path>``; see ``parse_hf_url``);
      2. ``owner/repo/path/to/file.safetensors``;
      3. a bare ``owner/repo`` id: ``adapter_model.safetensors``, then
         ``model.safetensors``, then the first listed ``.safetensors`` file;
      4. if 2/3 fail and the input is an http(s) URL, a plain streamed GET.
         The token is only sent when the host is huggingface.co / hf.co.

    Every call writes the same destination file (an existing one is deleted
    first), so callers that need several adapters must move each result aside.
    Raises the error from step 2/3 when every strategy fails.
    """
    from urllib.parse import urlparse

    local_path = TempDir / "adapter.safetensors"
    if local_path.exists():
        os.remove(local_path)

    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        try:
            return _fetch_hub_file(repo_id, filename, token, local_path)
        except Exception as url_error:
            # Best-effort strategy: fall through to the remaining ones.
            print(f"Hub URL download failed ({url_error}); trying other strategies.")

    try:
        if ".safetensors" in input_str and input_str.count("/") >= 2:
            parts = input_str.split("/")
            return _fetch_hub_file(f"{parts[0]}/{parts[1]}", "/".join(parts[2:]), token, local_path)
        files = list_repo_files(repo_id=input_str, token=token)
        available = set(files)
        # Honour the candidate order rather than the listing order.
        target = next((c for c in ("adapter_model.safetensors", "model.safetensors") if c in available), None)
        if not target:
            target = next((f for f in files if f.endswith(".safetensors")), None)
        if not target:
            raise ValueError("No safetensors found")
        return _fetch_hub_file(input_str, target, token, local_path)
    except Exception as hub_error:
        if input_str.startswith("http"):
            host = (urlparse(input_str).hostname or "").lower()
            is_hf_host = host in ("huggingface.co", "hf.co") or host.endswith(".huggingface.co")
            # Never leak the HF token to third-party hosts.
            headers = {"Authorization": f"Bearer {token}"} if (token and is_hf_host) else {}
            try:
                with requests.get(input_str, stream=True, headers=headers, timeout=60) as r:
                    r.raise_for_status()
                    with open(local_path, "wb") as f:
                        for chunk in r.iter_content(chunk_size=8192):
                            f.write(chunk)
                return local_path
            except Exception as http_error:
                print(f"HTTP download failed: {http_error}")
                if local_path.exists():
                    os.remove(local_path)  # don't leave a partial file behind
        raise hub_error
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """Load a LoRA file into ``{stem: {"down", "up", "rank", "alpha"}}``.

    Keys are normalised with ``get_key_stem``.  ``lora_down``/``lora_A`` and
    ``lora_up``/``lora_B`` weight tensors are cast to ``precision_dtype``.
    ``rank`` is ``down.shape[0]``.  ``alpha`` comes from the ``<stem>.alpha``
    entry, else defaults to the rank (scale 1.0).

    Only stems with BOTH factors are returned.  Other tensors in the file
    (e.g. full ``modules_to_save`` weights or factor biases) are ignored.
    Previously they produced half-filled entries that crashed the merge with a
    KeyError.
    """
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if k.endswith(".alpha"):
            # Match the suffix only: module names may legitimately contain "alpha".
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        elif k.endswith(".bias"):
            continue  # a factor bias would otherwise overwrite the factor weight
        elif "lora_down" in k or "lora_A" in k:
            pairs.setdefault(stem, {})["down"] = v.to(dtype=precision_dtype)
        elif "lora_up" in k or "lora_B" in k:
            pairs.setdefault(stem, {})["up"] = v.to(dtype=precision_dtype)
    del state_dict

    complete = {}
    for stem, pair in pairs.items():
        if "down" not in pair or "up" not in pair:
            continue
        pair["rank"] = pair["down"].shape[0]
        pair["alpha"] = float(alphas.get(stem, pair["rank"]))
        complete[stem] = pair
    return complete
class ShardBuffer:
    """Accumulate tensors and stream them to the Hub as safetensors shards.

    Tensors are serialised to raw bytes as they are added.  Once the buffered
    payload reaches ``max_size_gb``, the buffer is written to
    ``output_dir/<prefix>-NNNNN.safetensors``, uploaded to ``output_repo``
    (under ``subfolder`` if given) and deleted locally.  ``index_map`` (tensor
    name -> shard filename) and ``total_size`` collect what a
    ``*.safetensors.index.json`` needs.
    """

    # torch dtype -> safetensors dtype code.  Unlisted dtypes are rejected
    # instead of being written with a wrong label.
    _DTYPE_CODES = {
        torch.float64: "F64", torch.float32: "F32", torch.float16: "F16", torch.bfloat16: "BF16",
        torch.int64: "I64", torch.int32: "I32", torch.int16: "I16", torch.int8: "I8",
        torch.uint8: "U8", torch.bool: "BOOL",
    }

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix
        self.buffer = []
        self.current_bytes = 0
        self.shard_count = 0
        self.index_map = {}
        self.total_size = 0

    def add_tensor(self, key, tensor):
        """Serialise ``tensor`` under ``key``; flush when the shard limit is reached.

        Raises TypeError for dtypes without a safetensors code.  Previously
        e.g. int64 or float64 tensors were silently labelled "F32", which
        corrupted the shard.
        """
        dtype_str = self._DTYPE_CODES.get(tensor.dtype)
        if dtype_str is None:
            raise TypeError(f"Unsupported dtype {tensor.dtype} for tensor '{key}'")
        # Reinterpret as raw bytes (works for bf16, which numpy lacks); the
        # contiguous flatten yields the C-order layout safetensors expects.
        flat = tensor.detach().cpu().contiguous().reshape(-1)
        raw_bytes = flat.view(torch.uint8).numpy().tobytes()
        size = len(raw_bytes)
        self.buffer.append({"key": key, "data": raw_bytes, "dtype": dtype_str, "shape": list(tensor.shape)})
        self.current_bytes += size
        self.total_size += size
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def _write_shard(self, out_path):
        """Write the buffered tensors to ``out_path`` in safetensors format."""
        header = {"__metadata__": {"format": "pt"}}
        offset = 0
        for item in self.buffer:
            end = offset + len(item["data"])
            header[item["key"]] = {"dtype": item["dtype"], "shape": item["shape"], "data_offsets": [offset, end]}
            offset = end
        header_json = json.dumps(header).encode("utf-8")
        # Trailing-space padding is allowed by the format; it keeps tensor data 8-byte aligned.
        header_json += b" " * (-len(header_json) % 8)
        with open(out_path, "wb") as f:
            f.write(struct.pack("<Q", len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])

    def flush(self):
        """Write, upload and delete one shard from the buffer (no-op when empty)."""
        if not self.buffer:
            return
        self.shard_count += 1
        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename
        out_path = self.output_dir / filename
        try:
            self._write_shard(out_path)
            api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)
        finally:
            # Shards are GB-sized: never leave one behind on failure.
            if os.path.exists(out_path):
                os.remove(out_path)
        for item in self.buffer:
            self.index_map[item["key"]] = filename
        self.buffer = []
        self.current_bytes = 0
        gc.collect()
def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix=None, is_root_merge=False):
    """Copy the files of ``src_repo`` that the merge does not produce into ``dst_repo``.

    Files are processed one at a time (download to TempDir, upload, delete),
    so disk use stays at one file.  Skipped files:
      * everything inside the ``ignore_prefix`` folder, i.e. the component the
        merge rebuilds.  It is matched as a whole path segment, so
        ``transformer`` no longer also swallows ``transformer_2/``;
      * for a root merge, weight files (``.safetensors/.bin/.pt/.pth``) and
        their ``*.index.json`` shard maps, which the merge writes itself.
    Best-effort: failures are reported with print() and skipped, never raised.
    """
    weight_suffixes = (".safetensors", ".bin", ".pt", ".pth", ".safetensors.index.json", ".bin.index.json")
    skip_folder = ignore_prefix.strip("/") + "/" if ignore_prefix else None
    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)
    except Exception as e:
        print(f"Structure copy skipped: cannot list {src_repo}: {e}")
        return
    for f in tqdm(files, desc="Copying Structure"):
        if skip_folder and (f + "/").startswith(skip_folder):
            continue
        if is_root_merge and f.endswith(weight_suffixes):
            continue
        local = None
        try:
            local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)
            api.upload_file(path_or_fileobj=local, path_in_repo=f, repo_id=dst_repo, token=token)
        except Exception as e:
            print(f"Structure copy: skipped {f}: {e}")
        finally:
            if local and os.path.exists(local):
                os.remove(local)
def _find_lora_match(lora_pairs, base_stem):
    """Return ``(pair, chunk)`` for a base stem.

    ``chunk`` is None for a direct hit.  It is 0/1/2 when a fused ``qkv`` LoRA
    must be split into its q/k/v third (assumes the conventional [q; k; v] row
    order).  Returns ``(None, None)`` when nothing matches.
    """
    pair = lora_pairs.get(base_stem)
    if pair is not None:
        return pair, None
    for chunk, proj in enumerate(("to_q", "to_k", "to_v")):
        if proj in base_stem:
            fused = lora_pairs.get(base_stem.replace(proj, "qkv"))
            return (fused, chunk) if fused is not None else (None, None)
    return None, None


def _compute_lora_delta(up, down, target_shape, chunk=None):
    """Return the float32 product ``up @ down`` shaped like ``target_shape``, or None.

    Both factors are flattened to 2-D, so linear LoRAs and conv LoRAs share one
    path.  For conv, ``down`` is (r, in, kh, kw) and ``up`` is (out, r, 1, 1).
    If the factors do not chain, the transposed-``up`` layout is tried.
    ``chunk`` selects one third of a fused-qkv delta along dim 0.
    """
    import math

    up2 = up.float().reshape(up.shape[0], -1)
    down2 = down.float().reshape(down.shape[0], -1)
    if up2.shape[1] == down2.shape[0]:
        delta = up2 @ down2
    elif up2.shape[0] == down2.shape[0]:
        delta = up2.T @ down2
    else:
        return None
    if chunk is not None:
        if delta.shape[0] % 3:
            return None
        delta = delta.chunk(3, dim=0)[chunk]
    target_shape = tuple(target_shape)
    if len(target_shape) == 4 and tuple(delta.shape) == (target_shape[0], math.prod(target_shape[1:])):
        delta = delta.reshape(target_shape)
    return delta if tuple(delta.shape) == target_shape else None


def _download_base_shards(base_repo, subfolder, token):
    """Download the base model's ``.safetensors`` files; return their local Paths, sorted.

    With ``subfolder``, only files under ``<subfolder>/`` are used.  Without it,
    top-level files are preferred, so a diffusers repo's ``text_encoder/``,
    ``vae/``, ... weights are not folded into the merged model.  If the root
    has none, every ``.safetensors`` file is used.
    """
    files = [f for f in list_repo_files(repo_id=base_repo, token=token) if f.endswith(".safetensors")]
    if subfolder:
        selected = [f for f in files if f.startswith(f"{subfolder}/")]
    else:
        selected = [f for f in files if "/" not in f] or files
    input_dir = TempDir / "inputs"
    os.makedirs(input_dir, exist_ok=True)
    shards = []
    for f in selected:
        # Use the returned path; searching by basename could pick the wrong file.
        local = hf_hub_download(repo_id=base_repo, filename=f, token=token, local_dir=input_dir, local_dir_use_symlinks=False)
        shards.append(Path(local))
    return sorted(shards)


def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    """Bake a LoRA into a base model and upload the result as resharded safetensors.

    Steps:
      1. create ``output_repo``;
      2. load the LoRA (before touching the base weights, so a bad LoRA fails fast);
      3. optionally mirror ``structure_repo``'s other files;
      4. download the base shards of ``base_subfolder`` (or the repo root);
      5. stream every tensor through ``W + scale * alpha / rank * (up @ down)``
         wherever a LoRA pair matches;
      6. write new shards of about ``shard_size`` GB plus an index file.

    Returns a status string.  Setup errors are returned as text; errors during
    the streaming merge propagate to Gradio.  TempDir is wiped on exit.
    """
    cleanup_temp()
    if not hf_token:
        return "Error: HF Token required."
    hf_token = hf_token.strip()
    login(hf_token)
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"
    output_subfolder = base_subfolder.strip().strip("/") if base_subfolder else ""
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32

    try:
        try:
            progress(0.05, desc="Downloading LoRA...")
            lora_path = download_lora_smart(lora_input, hf_token)
            lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
        except Exception as e:
            return f"Error loading LoRA: {e}"

        if structure_repo:
            streaming_copy_structure(hf_token, structure_repo, output_repo,
                                     ignore_prefix=output_subfolder or None,
                                     is_root_merge=not output_subfolder)

        progress(0.1, desc="Downloading Input Model...")
        input_shards = _download_base_shards(base_repo, output_subfolder, hf_token)
        if not input_shards:
            return "No safetensors found."

        use_diffusers_name = output_subfolder in ["transformer", "unet"] or "diffusion_pytorch_model" in input_shards[0].name
        filename_prefix = "diffusion_pytorch_model" if use_diffusers_name else "model"
        index_filename = f"{filename_prefix}.safetensors.index.json"

        buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix)
        applied = 0
        for i, shard_file in enumerate(input_shards):
            progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {shard_file.name}")
            with MemoryEfficientSafeOpen(shard_file) as f:
                for k in f.keys():
                    v = f.get_tensor(k)
                    match, chunk = _find_lora_match(lora_pairs, get_key_stem(k))
                    if match is not None and "down" in match and "up" in match:
                        delta = _compute_lora_delta(match["up"], match["down"], v.shape, chunk)
                        if delta is not None:
                            rank = match.get("rank") or match["down"].shape[0]
                            scaling = float(scale) * (match["alpha"] / rank)
                            # Out-of-place and in float32: never write into the
                            # shard's read buffer, and avoid low-precision sums.
                            v = v.float() + delta * scaling
                            applied += 1
                        del delta
                    # NOTE(review): every tensor, integer buffers included, is cast
                    # to the requested precision (unchanged behaviour).
                    buffer.add_tensor(k, v.to(dtype))
                    del v
            os.remove(shard_file)
            gc.collect()
        buffer.flush()

        index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map}
        path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename
        with open(TempDir / index_filename, "w") as f:
            json.dump(index_data, f, indent=4)
        api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)

        message = f"Done! Merged {buffer.shard_count} shards to {output_repo}"
        if applied == 0:
            return message + " (WARNING: no LoRA layer matched the base model keys; weights are unchanged)"
        return message + f" ({applied} LoRA layers applied)"
    finally:
        cleanup_temp()
| # ================================================================================= | |
| # TAB 2: EXTRACT LORA | |
| # ================================================================================= | |
def identify_and_download_model(input_str, token):
    """Download one model weight file and return its local Path.

    ``input_str`` is either a Hub file URL (see ``parse_hf_url``) or a repo id.
    For a repo id, the first existing file among
    ``transformer/diffusion_pytorch_model.safetensors``,
    ``unet/diffusion_pytorch_model.safetensors`` and ``model.safetensors`` is
    used.  Otherwise the first ``.safetensors`` whose path lacks "lora" is used.

    Each distinct input gets its own sub-directory of TempDir.  Previously two
    models sharing a relative path (e.g. both shipping
    ``transformer/diffusion_pytorch_model.safetensors``) were downloaded to the
    same file, so the second overwrote the first.

    Raises ValueError when the repo has no candidate file.
    """
    import hashlib

    input_str = input_str.strip()
    slug = re.sub(r"[^A-Za-z0-9._-]+", "_", input_str)[:80]
    digest = hashlib.sha1(input_str.encode("utf-8")).hexdigest()[:10]
    local_dir = TempDir / "models" / f"{slug}_{digest}"
    os.makedirs(local_dir, exist_ok=True)

    repo_id, filename = parse_hf_url(input_str)
    if not (repo_id and filename):
        files = list_repo_files(repo_id=input_str, token=token)
        priorities = ["transformer/diffusion_pytorch_model.safetensors", "unet/diffusion_pytorch_model.safetensors", "model.safetensors"]
        fallback = next((f for f in files if f.endswith(".safetensors") and "lora" not in f), None)
        filename = next((f for f in priorities if f in files), fallback)
        if not filename:
            raise ValueError("No model file found")
        repo_id = input_str
    # Use the returned path instead of searching TempDir by basename.
    return Path(hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=local_dir))
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """Approximate ``tuned - org`` per weight with a rank-``rank`` LoRA (randomized SVD).

    Only tensors present in both files with equal shapes are considered, and
    only 2-D (linear) and 4-D (conv) ones.  1-D tensors (biases, norm scales)
    are skipped, because a LoRA factorisation of them is not loadable as an
    adapter.  Weights whose largest absolute change is below 1e-4 are skipped.
    BatchNorm statistics are skipped as well.

    Factor entries are clamped to the ``clamp`` quantile of their absolute
    values.  Singular values are folded into ``lora_up`` and ``alpha`` is set
    to the effective rank, so the stored factors carry the full scale.

    Returns the path (str) of ``TempDir/extracted.safetensors``.
    """
    lora_sd = {}
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        keys = sorted(set(org.keys()) & set(tuned.keys()))  # sorted: deterministic output
        for key in tqdm(keys, desc="Extracting"):
            if "num_batches_tracked" in key or "running_mean" in key or "running_var" in key:
                continue
            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()
            if mat_org.shape != mat_tuned.shape or mat_org.dim() not in (2, 4):
                continue
            diff = mat_tuned - mat_org
            if torch.max(torch.abs(diff)) < 1e-4:
                continue
            out_dim, in_dim = diff.shape[0], diff.shape[1]
            r = min(rank, in_dim, out_dim)
            is_conv = diff.dim() == 4
            if is_conv:
                diff = diff.flatten(start_dim=1)
            # q must not exceed the matrix size.
            U, S, V = torch.svd_lowrank(diff, q=min(r + 4, min(diff.shape)), niter=4)
            Vh = V.t()
            U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
            U = U @ torch.diag(S)
            magnitudes = torch.cat([U.flatten(), Vh.flatten()]).abs()
            if magnitudes.numel() < 2 ** 24:
                hi_val = torch.quantile(magnitudes, clamp).item()
            else:
                # torch.quantile rejects very large inputs; numpy uses the same
                # (linear) interpolation.
                hi_val = float(np.quantile(magnitudes.numpy(), clamp))
            if hi_val > 0:
                U, Vh = U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)
            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)
            stem = key.removesuffix(".weight")
            lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
            lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
def task_extract(hf_token, org, tun, rank, out):
    """Extract ``tun - org`` as a rank-``rank`` LoRA and upload it to repo ``out``.

    ``org``/``tun`` are repo ids or Hub file URLs of single-file models (see
    ``identify_and_download_model``).  The result is uploaded as
    ``extracted_lora.safetensors`` with a 0.99 clamp quantile.  Scratch files
    (both source models can be many GB) are removed afterwards, on success and
    on failure.  Returns a status string; errors are returned, not raised.
    """
    cleanup_temp()
    token = hf_token.strip() if hf_token else None
    if token:
        login(token)
    try:
        rank = int(rank)
        if rank < 1:
            return "Error: rank must be a positive integer"
        p1 = identify_and_download_model(org, token)
        p2 = identify_and_download_model(tun, token)
        f = extract_lora_layer_by_layer(p1, p2, rank, 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=token)
        api.upload_file(path_or_fileobj=f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=token)
        return "Done! Extracted to " + out
    except Exception as e:
        return f"Error: {e}"
    finally:
        cleanup_temp()
| # ================================================================================= | |
| # TAB 3: MERGE ADAPTERS | |
| # ================================================================================= | |
def load_full_state_dict(path):
    """Load a LoRA file as float32 CPU tensors, renaming PEFT factors to kohya names.

    ``lora_A`` becomes ``lora_down`` and ``lora_B`` becomes ``lora_up``.  Every
    other key is kept verbatim.
    """
    renamed = {}
    for name, tensor in load_file(path, device="cpu").items():
        if "lora_A" in name:
            name = name.replace("lora_A", "lora_down")
        elif "lora_B" in name:
            name = name.replace("lora_B", "lora_up")
        renamed[name] = tensor.float()
    return renamed
def _ema_gamma(sigma_rel):
    """Solve sigma_rel**-2 = (g+2)**2 (g+3) / (g+1) for the largest real g >= 0 (None if none)."""
    t_val = sigma_rel ** -2
    roots = np.roots([1, 7, 16 - t_val, 12 - t_val])
    valid = roots.real[(np.abs(roots.imag) < 1e-9) & (roots.real >= 0)]
    return float(valid.max()) if valid.size else None


def _merge_adapters_ema(paths, beta, sigma_rel):
    """Blend adapters key by key: ``acc = b * acc + (1 - b) * next``, starting from paths[0].

    ``b`` is the constant ``beta``.  When ``sigma_rel > 0`` (and solvable), it
    is the power schedule ``(1 - 1/t) ** (gamma + 1)``.  ``alpha`` keys and
    tensors whose shapes differ (e.g. different ranks) are left as in the
    first adapter.
    """
    base_sd = load_file(paths[0], device="cpu")
    for k, v in base_sd.items():
        if v.dtype.is_floating_point:
            base_sd[k] = v.float()
    gamma = _ema_gamma(sigma_rel) if sigma_rel and sigma_rel > 0 else None
    # ``step`` counts the first adapter as sample 1.  The old index made the
    # first update use beta = 0, which discarded paths[0] entirely.
    for step, path in enumerate(paths[1:], start=2):
        current_beta = (1 - 1 / step) ** (gamma + 1) if gamma is not None else beta
        curr = load_file(path, device="cpu")
        for k in base_sd:
            if k not in curr or "alpha" in k or curr[k].shape != base_sd[k].shape:
                continue
            base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)
    return base_sd


def _merge_adapters_factorized(paths, method, weights, target_rank):
    """Merge adapters per LoRA stem by rank concatenation or by SVD of the summed deltas."""
    states = [load_full_state_dict(p) for p in paths]
    stems = sorted({k.split(".lora_")[0] for s in states for k in s if "lora_" in k})
    merged = {}
    for stem in tqdm(stems):
        dk, uk, ak = f"{stem}.lora_down.weight", f"{stem}.lora_up.weight", f"{stem}.alpha"
        try:
            down_list, up_list, total_delta = [], [], None
            for state, w in zip(states, weights):
                if dk not in state or uk not in state:
                    continue
                d, u = state[dk], state[uk]
                rank = d.shape[0]
                alpha = state[ak].item() if ak in state else rank
                scale = (alpha / rank) * w
                if "Concatenation" in method:
                    # Bake each adapter's own alpha/rank into its up factor.  The
                    # merged alpha then equals the total rank (scale 1.0), which
                    # keeps every component exact even when alpha/rank ratios differ.
                    down_list.append(d)
                    up_list.append(u * scale)
                elif "SVD" in method:
                    if d.dim() == 4:
                        delta = (u.flatten(1) @ d.flatten(1)).reshape(u.shape[0], *d.shape[1:])
                    else:
                        delta = u @ d
                    delta = delta * scale
                    if total_delta is not None and total_delta.shape != delta.shape:
                        print(f"Skipping incompatible adapter for {stem}: {tuple(delta.shape)} vs {tuple(total_delta.shape)}")
                        continue
                    total_delta = delta if total_delta is None else total_delta + delta
            if "Concatenation" in method and down_list:
                merged[dk] = torch.cat(down_list, dim=0).contiguous()
                merged[uk] = torch.cat(up_list, dim=1).contiguous()
                merged[ak] = torch.tensor(float(sum(d.shape[0] for d in down_list)))
            elif "SVD" in method and total_delta is not None:
                flat = total_delta.flatten(1) if total_delta.dim() == 4 else total_delta
                # The rank can't exceed the matrix dims; previously such layers were silently dropped.
                tr = max(1, min(int(target_rank), *flat.shape))
                U, S, V = torch.svd_lowrank(flat, q=min(tr + 4, min(flat.shape)), niter=4)
                Vh = V.t()
                U, S, Vh = U[:, :tr], S[:tr], Vh[:tr, :]
                U = U @ torch.diag(S)
                if total_delta.dim() == 4:
                    U = U.reshape(total_delta.shape[0], tr, 1, 1)
                    Vh = Vh.reshape(tr, *total_delta.shape[1:])
                merged[dk] = Vh.contiguous()
                merged[uk] = U.contiguous()
                merged[ak] = torch.tensor(tr).float()
        except Exception as e:
            print(f"Skipping {stem}: {e}")
    return merged


def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
    """Merge two or more LoRA adapters and upload ``merged_adapters.safetensors``.

    ``inputs_text`` lists adapters separated by newlines/spaces (any form
    ``download_lora_smart`` accepts).  Methods:
      * "Iterative EMA": running blend of whole state dicts (see ``_merge_adapters_ema``);
      * "Concatenation": stack the factors, so the result is the exact weighted sum;
      * "SVD": sum the weighted deltas and refactorise at ``target_rank``.
    ``weight_str`` is a comma-separated list of per-adapter weights; missing
    entries default to 1.0.  Returns a status string.
    """
    cleanup_temp()
    token = hf_token.strip() if hf_token else None
    if token:
        login(token)
    urls = [line.strip() for line in inputs_text.replace(" ", "\n").split('\n') if line.strip()]
    if len(urls) < 2:
        return "Error: Please provide at least 2 adapters."
    try:
        weights = [float(w) for w in (weight_str or "").split(",") if w.strip()]
    except ValueError:
        return "Error parsing weights."
    weights += [1.0] * (len(urls) - len(weights))

    paths = []
    try:
        for i, url in enumerate(tqdm(urls, desc="Downloading Adapters")):
            # download_lora_smart always writes (and first deletes) the same file.
            # Move each result aside, otherwise every entry of ``paths`` would
            # point at the last adapter.
            unique = TempDir / f"adapter_input_{i}.safetensors"
            shutil.move(str(download_lora_smart(url, token)), str(unique))
            paths.append(unique)
    except Exception as e:
        return f"Download Error: {e}"

    try:
        if "Iterative EMA" in method:
            merged = _merge_adapters_ema(paths, beta, sigma_rel)
        else:
            merged = _merge_adapters_factorized(paths, method, weights, target_rank)
        if not merged:
            return "Error: nothing was merged (no compatible LoRA layers found)."
        out = TempDir / "merged_adapters.safetensors"
        save_file(merged, out)
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=token)
        api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=token)
    except Exception as e:
        return f"Merge Error: {e}"
    return f"Success! Merged to {out_repo}"
| # ================================================================================= | |
| # TAB 4: RESIZE | |
| # ================================================================================= | |
def _resize_lora_pair(up, down, scale, rank_limit, dynamic_method, dynamic_param):
    """Refactorise ``scale * up @ down`` at a lower rank.

    Returns ``(new_down, new_up, rank)``.  The rank is ``rank_limit``, or fewer
    when a dynamic criterion ("sv_ratio", "sv_cumulative", "sv_fro") selects
    fewer singular values; it is never below 1.
    """
    # flatten(1) handles linear and conv factors alike.  The old squeeze()
    # broke conv kernels and rank-1 LoRAs.
    weight = (up.float().flatten(1) @ down.float().flatten(1)) * scale
    U, S, V = torch.svd_lowrank(weight, q=min(rank_limit + 32, min(weight.shape)))
    Vh = V.t()
    calc_rank = rank_limit
    if dynamic_method == "sv_ratio":
        calc_rank = int(torch.sum(S > (S[0] / dynamic_param)).item())
    elif dynamic_method == "sv_cumulative":
        calc_rank = int(torch.searchsorted(torch.cumsum(S, 0) / torch.sum(S), dynamic_param)) + 1
    elif dynamic_method == "sv_fro":
        calc_rank = int(torch.searchsorted(torch.cumsum(S.pow(2), 0) / torch.sum(S.pow(2)), dynamic_param**2)) + 1
    final_rank = max(1, min(calc_rank, rank_limit, S.shape[0]))
    new_up = U[:, :final_rank] @ torch.diag(S[:final_rank])
    new_down = Vh[:final_rank, :]
    if down.dim() == 4:
        new_up = new_up.reshape(up.shape[0], final_rank, 1, 1)
        new_down = new_down.reshape(final_rank, *down.shape[1:])
    return new_down, new_up, final_rank


def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """Shrink a LoRA's rank via SVD and upload ``shrunken.safetensors`` to ``out_repo``.

    Each pair's effective weight ``alpha / rank * up @ down`` is preserved.
    The original ``alpha`` is honoured; previously it was never paired with
    its factors, so LoRAs with alpha != rank changed strength.  The new alpha
    equals the new rank.  Output keys use kohya ``lora_down``/``lora_up``
    names in float32.  Returns "Done" or an error string.
    """
    cleanup_temp()
    token = hf_token.strip() if hf_token else None
    if token:
        login(token)
    try:
        path = download_lora_smart(lora_input, token)
        state = load_file(path, device="cpu")
        groups = {}
        for k, v in state.items():
            if k.endswith(".alpha"):
                groups.setdefault(k[: -len(".alpha")], {})["alpha"] = v
                continue
            stem = k.split(".lora_")[0]
            if "lora_down" in k or "lora_A" in k:
                groups.setdefault(stem, {})["down"] = v
            elif "lora_up" in k or "lora_B" in k:
                groups.setdefault(stem, {})["up"] = v

        rank_limit = max(1, int(new_rank))
        new_state = {}
        for stem, g in tqdm(groups.items()):
            if "down" not in g or "up" not in g:
                continue
            rank = g["down"].shape[0]
            alpha = float(g["alpha"].item()) if "alpha" in g else float(rank)
            new_down, new_up, final_rank = _resize_lora_pair(
                g["up"], g["down"], alpha / rank, rank_limit, dynamic_method, dynamic_param)
            new_state[f"{stem}.lora_down.weight"] = new_down.contiguous()
            new_state[f"{stem}.lora_up.weight"] = new_up.contiguous()
            new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float()

        out = TempDir / "shrunken.safetensors"
        save_file(new_state, out)
        api.create_repo(repo_id=out_repo, exist_ok=True, token=token)
        api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=token)
    except Exception as e:
        return f"Error: {e}"
    return "Done"
| # ================================================================================= | |
| # NEW HELPERS: MERGEKIT CLI WRAPPER | |
| # ================================================================================= | |
def run_mergekit_cli(config_str, output_path, hf_token):
    """Run ``mergekit-yaml`` on a YAML config string; return ``str(output_path)``.

    The config is written to ``TempDir/config.yaml``.  The CLI is used instead
    of mergekit's Python API to avoid keyword-signature differences between
    mergekit versions.  The token, if given, is exported to the child process
    as ``HF_TOKEN`` (plus the legacy ``HUGGING_FACE_HUB_TOKEN``) so private or
    gated models resolve.

    Raises RuntimeError if the CLI is missing or exits non-zero.  The message
    carries the tail of stderr (or stdout).
    """
    config_file = TempDir / "config.yaml"
    with open(config_file, "w", encoding="utf-8") as f:
        f.write(config_str)
    env = os.environ.copy()
    if hf_token:
        env["HF_TOKEN"] = hf_token.strip()
        env["HUGGING_FACE_HUB_TOKEN"] = hf_token.strip()
    cmd = [
        "mergekit-yaml",
        str(config_file),
        str(output_path),
        "--allow-crimes",   # allows mixing architectures if needed
        "--lazy-unpickle",  # memory optimization
        "--copy-tokenizer",
    ]
    print(f"Running command: {' '.join(cmd)}")
    try:
        result = subprocess.run(cmd, env=env, capture_output=True, text=True)
    except FileNotFoundError as e:
        raise RuntimeError("MergeKit CLI Failed:\n'mergekit-yaml' is not installed or not on PATH") from e
    if result.returncode != 0:
        # Keep the UI message bounded; the informative part of a traceback is at the end.
        details = (result.stderr or result.stdout or "").strip()[-4000:]
        raise RuntimeError(f"MergeKit CLI Failed:\n{details}")
    return str(output_path)
def upload_folder_to_hf(folder, repo_id, token, private=True):
    """Push a local folder to a (possibly new) Hub repo.

    The repo is created with the requested visibility if it does not exist
    yet, then the contents of ``folder`` are uploaded.  Returns a success
    message naming the repo.
    """
    api.create_repo(token=token, exist_ok=True, private=private, repo_id=repo_id)
    api.upload_folder(repo_id=repo_id, token=token, folder_path=folder)
    done_message = f"Success! Uploaded to {repo_id}"
    return done_message
| # ================================================================================= | |
| # TAB 5: WEIGHTED & SPARSIFIED (Linear, Ties, Dare) | |
| # ================================================================================= | |
def task_mergekit_weighted(hf_token, models_text, method, dtype, base_model, weights, density, normalize, out_repo, private):
    """Build a weighted/sparsified mergekit config, run it and upload the result.

    ``models_text`` lists one model per line.  ``weights`` is a comma-separated
    list; missing entries default to 1.0 and extra entries are ignored.
      * ``linear``: weighted merge of the listed models.
      * other methods (``ties``, ``dare_ties``, ``dare_linear``, ...): use
        ``base_model`` (default: first listed model), ``normalize``, and
        ``int8_mask`` for DARE variants.  ``density`` is set per model for
        ``ties``/``dare_*`` when non-zero.
    Returns a status string.  A malformed weight list is reported instead of
    crashing the handler.
    """
    cleanup_temp()
    if not hf_token:
        return "Error: Token required"
    login(hf_token.strip())
    model_list = [m.strip() for m in models_text.split('\n') if m.strip()]
    if not model_list:
        return "Error: No models listed"

    try:
        w_list = [float(x) for x in (weights or "").split(",") if x.strip()]
        w_list += [1.0] * (len(model_list) - len(w_list))

        c_models = []
        for model, weight in zip(model_list, w_list):
            params = {"weight": weight}
            if method != "linear" and density and method in ["dare_ties", "dare_linear", "ties"]:
                params["density"] = float(density)
            c_models.append({"model": model, "parameters": params})

        if method == "linear":
            config = {"models": c_models, "merge_method": method, "dtype": dtype}
        else:
            config = {
                "models": c_models,
                "merge_method": method,
                "base_model": base_model if base_model else model_list[0],
                "parameters": {
                    "normalize": normalize,
                    "int8_mask": "dare" in method,
                },
                "dtype": dtype,
            }
        yaml_str = yaml.dump(config, sort_keys=False)
        out_path = TempDir / "out_merged"
        run_mergekit_cli(yaml_str, out_path, hf_token)
        return upload_folder_to_hf(str(out_path), out_repo, hf_token, private)
    except Exception as e:
        return f"Error: {e}"
def task_mergekit_interp(hf_token, model_a, model_b, base_model, method, t_val, dtype, out_repo, private):
    """Interpolate two models with mergekit and upload the result.

    * ``slerp`` / ``nuslerp``: whole-model merge of A and B with ``model_a``
      as ``base_model`` and ``t_val`` passed as the global ``t`` parameter.
      Layer ranges are left to mergekit; previously ``[0, 32]`` was hard-coded,
      which breaks any model whose depth is not 32 layers.
    * ``task_arithmetic``: A with weight 1.0 and B with weight ``t_val`` on
      top of ``base_model`` (default: ``model_a``).
    Returns a status string.  Unsupported methods are reported up front.
    """
    cleanup_temp()
    if not hf_token:
        return "Error: Token required"
    login(hf_token.strip())
    try:
        t = float(t_val)
    except (TypeError, ValueError):
        return f"Error: invalid t value '{t_val}'"

    if method in ["slerp", "nuslerp"]:
        config = {
            "models": [{"model": model_a}, {"model": model_b}],
            "merge_method": method,
            "base_model": model_a,
            "parameters": {"t": t},
            "dtype": dtype,
        }
    elif method == "task_arithmetic":
        config = {
            "base_model": base_model if base_model else model_a,
            "merge_method": "task_arithmetic",
            "models": [
                {"model": model_a, "parameters": {"weight": 1.0}},
                {"model": model_b, "parameters": {"weight": t}},
            ],
            "dtype": dtype,
        }
    else:
        return f"Error: unsupported interpolation method '{method}'"

    yaml_str = yaml.dump(config, sort_keys=False)
    out_path = TempDir / "out_interp"
    try:
        run_mergekit_cli(yaml_str, out_path, hf_token)
        return upload_folder_to_hf(str(out_path), out_repo, hf_token, private)
    except Exception as e:
        return f"Error: {e}"
def task_mergekit_moe(hf_token, base_model, experts_text, gate_mode, dtype, out_repo, private):
    """Assemble a Mixture-of-Experts model with the ``mergekit-moe`` CLI and upload it.

    MoE configs (``base_model`` / ``gate_mode`` / ``experts``) have no
    ``merge_method`` and are not a ``mergekit-yaml`` config.  Previously they
    were sent to ``mergekit-yaml``, which rejects them; they are now run
    through ``mergekit-moe``.  ``base_model`` defaults to the first expert.

    NOTE(review): every expert gets a single empty positive prompt (GUI
    simplification).  Prompt-based gate modes presumably need real prompts;
    confirm with the mergekit-moe docs.
    Returns a status string.
    """
    cleanup_temp()
    if not hf_token:
        return "Error: Token required"
    token = hf_token.strip()
    login(token)
    experts = [e.strip() for e in experts_text.split('\n') if e.strip()]
    if not experts:
        return "Error: No experts listed"

    config = {
        "base_model": base_model if base_model else experts[0],
        "gate_mode": gate_mode,
        "dtype": dtype,
        "experts": [{"source_model": e, "positive_prompts": [""]} for e in experts],
    }
    out_path = TempDir / "out_moe"
    config_file = TempDir / "moe_config.yaml"
    try:
        with open(config_file, "w", encoding="utf-8") as f:
            f.write(yaml.dump(config, sort_keys=False))
        env = os.environ.copy()
        env["HF_TOKEN"] = token
        cmd = ["mergekit-moe", str(config_file), str(out_path)]
        if gate_mode == "random":
            # mergekit-moe refuses random gates unless this acknowledgement is given.
            cmd.append("--i-understand-this-is-not-useful-without-training")
        print(f"Running command: {' '.join(cmd)}")
        result = subprocess.run(cmd, env=env, capture_output=True, text=True)
        if result.returncode != 0:
            details = (result.stderr or result.stdout or "").strip()[-4000:]
            raise RuntimeError(f"MergeKit MoE CLI Failed:\n{details}")
        return upload_folder_to_hf(str(out_path), out_repo, token, private)
    except Exception as e:
        return f"Error: {e}"
| # ================================================================================= | |
| # TAB 8: RAW PYTORCH (Passthrough / Non-Transformer) | |
| # ================================================================================= | |
def task_raw_merge(hf_token, models_text, method, dtype, out_repo, private):
    """Run a whole-model mergekit merge (e.g. ``passthrough`` or ``linear``) at unit weights.

    ``models_text`` lists one model per line; every model gets weight 1.0.  An
    empty list is rejected up front instead of handing mergekit a config with
    no models.  Returns a status string.
    """
    cleanup_temp()
    if not hf_token:
        return "Error: Token required"
    login(hf_token.strip())
    models = [m.strip() for m in models_text.split('\n') if m.strip()]
    if not models:
        return "Error: No models listed"
    config = {
        "models": [{"model": m, "parameters": {"weight": 1.0}} for m in models],
        "merge_method": method,  # passthrough, linear
        "dtype": dtype,
    }
    yaml_str = yaml.dump(config, sort_keys=False)
    out_path = TempDir / "out_raw"
    try:
        run_mergekit_cli(yaml_str, out_path, hf_token)
        return upload_folder_to_hf(str(out_path), out_repo, hf_token, private)
    except Exception as e:
        return f"Error: {e}"
def task_dare_soonr(hf_token, base_model, ft_model, ratio, mask_rate, out_repo, private):
    """DARE-merge a fine-tune into its base and upload ``model.safetensors`` to ``out_repo``.

    For every floating tensor present in both models with matching dtype and
    shape, the result is ``base + ratio * delta``.  Here ``delta = ft - base``;
    each element is dropped with probability ``mask_rate`` and the survivors
    are rescaled by ``1 / (1 - mask_rate)``.  Merged tensors keep the base dtype.

    The output contains every tensor of the fine-tuned file.  Keys missing from
    the base, non-floating tensors and dtype/shape mismatches are copied from
    the fine-tune unchanged.  Previously, fine-tune-only keys were dropped and
    integer tensors were rewritten as float32.

    Single-file models only (see ``identify_and_download_model``).  Returns a
    status string.
    """
    cleanup_temp()
    if not hf_token:
        return "Error: Token required"
    token = hf_token.strip()
    login(token)
    try:
        mask_rate = float(mask_rate)
        ratio = float(ratio)
        if not 0.0 <= mask_rate < 1.0:
            return "DARE Error: mask_rate must be in [0, 1)"
        print("Downloading Base...")
        base_path = identify_and_download_model(base_model, token)
        print("Downloading FT...")
        ft_path = identify_and_download_model(ft_model, token)
        print("Loading Tensors...")
        base_sd = load_file(base_path, device="cpu")
        ft_sd = load_file(ft_path, device="cpu")

        merged_sd = {}
        print("Merging...")
        for key in tqdm(sorted(ft_sd)):
            ft_t = ft_sd[key]
            base_t = base_sd.get(key)
            if (base_t is None or base_t.dtype != ft_t.dtype or base_t.shape != ft_t.shape
                    or not base_t.dtype.is_floating_point):
                merged_sd[key] = ft_t  # fallback: keep the fine-tuned tensor as-is
                continue
            delta = ft_t.float() - base_t.float()
            if mask_rate > 0.0:
                # Bernoulli keep-mask, then rescale so the expected delta is unchanged.
                mask = torch.bernoulli(torch.full_like(delta, 1.0 - mask_rate))
                delta = delta * mask * (1.0 / (1.0 - mask_rate))
            merged_sd[key] = (base_t.float() + delta * ratio).to(base_t.dtype)

        out_path = TempDir / "dare_merged.safetensors"
        # "format": "pt" metadata is what transformers checks when loading safetensors.
        save_file(merged_sd, out_path, metadata={"format": "pt"})
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=token)
        api.upload_file(path_or_fileobj=out_path, path_in_repo="model.safetensors", repo_id=out_repo, token=token)
        return f"Success! Uploaded to {out_repo}"
    except Exception as e:
        return f"DARE Error: {e}"
| # ================================================================================= | |
| # UI | |
| # ================================================================================= | |
# Page-level CSS for the app container.
# NOTE(review): this is handed to demo.queue().launch(css=...) in the __main__
# guard rather than to gr.Blocks(); that only works on Gradio versions whose
# launch() accepts `css` — confirm against the pinned Gradio version.
css = ".container { max-width: 1100px; margin: auto; }"
# Build the whole UI: one tab per tool, each wiring its inputs to a task_* handler.
# NOTE(review): this layout was recovered from a copy with flattened
# indentation; which components sit inside each gr.Row() is inferred — confirm
# against the deployed Space.
with gr.Blocks() as demo:
    # Header banner. `title` is not referenced again in this part of the file.
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/AlekseyCalvin/Soon_Merger/resolve/main/SMerger3.png" alt="SOONmerge®"> Transform Transformers for FREE!</h1>""",
        elem_id="title",
    )
    gr.Markdown("# 🧰Training-Free CPU-run Model Creation Toolkit")
    with gr.Tabs():
        # Tab: merge a LoRA into a base model; handled by task_merge.
        with gr.Tab("Merge into Base Model"):
            with gr.Row():
                t1_token = gr.Textbox(label="Token", type="password")
            with gr.Row():
                t1_base = gr.Textbox(label="Base Repo", value="name/repo")
                t1_sub = gr.Textbox(label="Subfolder (Optional)", value="")
                t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Max Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Extras Source (copies configs/components/etc)", value="name/repo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            # The order of the input list must match task_merge's positional parameters.
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)
        # Tab: extract an adapter from an original/tuned model pair; handled by task_extract.
        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned or Homologous Model")
            # NOTE(review): gr.Number without `precision` delivers a float
            # (e.g. 32.0); confirm task_extract casts the rank to int.
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
        # Tab: combine several adapters; handled by task_merge_adapters_advanced.
        with gr.Tab("Merge Adapters"):
            gr.Markdown("### Batch Adapter Merging")
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.TextArea(label="Adapter URLs/Repos (one per line, or space-separated)", placeholder="user/lora1\nhttps://hf.co/user/lora2.safetensors\n...")
            with gr.Row():
                # The handler receives the full label text below as the method value.
                t3_method = gr.Dropdown(
                    ["Iterative EMA (Linear w/ Beta/Sigma coefficient)", "Concatenation (MOE-like weights-stack)", "SVD Fusion (Task Arithmetic/Compressed)"],
                    value="Iterative EMA (Linear w/ Beta/Sigma coefficient)",
                    label="Merge Method"
                )
            with gr.Row():
                t3_weights = gr.Textbox(label="Weights (comma-separated) – for Concat/SVD", placeholder="1.0, 0.5, 0.8...")
                # NOTE(review): float value (no `precision`); confirm the handler casts it.
                t3_rank = gr.Number(label="Target Rank – For SVD only", value=128, minimum=1, maximum=1024)
            with gr.Row():
                t3_beta = gr.Slider(label="Beta – for linear/post-hoc EMA", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
                t3_sigma = gr.Slider(label="Sigma Rel – for linear/post-hoc EMA", value=0.21, minimum=0.01, maximum=1.00, step=0.01)
            t3_out = gr.Textbox(label="Output Repo")
            t3_priv = gr.Checkbox(label="Private Output", value=True)
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            # Input order differs from on-screen order (beta/sigma before rank);
            # it must match task_merge_adapters_advanced's parameter order.
            t3_btn.click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], t3_res)
        # Tab: reduce an adapter's rank, optionally per layer; handled by task_resize.
        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                # NOTE(review): float value (no `precision`); confirm task_resize casts it.
                t4_rank = gr.Number(label="To Rank (Safety Ceiling)", value=8, minimum=1, maximum=512, step=1)
                t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=0.9)
            # User guide for the dynamic methods. It calls "To Rank" a slider,
            # although the control above is a gr.Number.
            gr.Markdown(
                """
### 📉 Dynamic Resizing Guide
These methods intelligently determine the best rank per layer.
* **sv_ratio (Relative Strength):** Keeps features that are at least `1/Param` as strong as the main feature. **Param must be >= 2**. (e.g. 2 = keep features half as strong as top).
* **sv_fro (Visual Information Density):** Preserves `Param%` of the total information content (Frobenius Norm) of the layer. **Param between 0.0 and 1.0** (e.g. 0.9 = 90% info retention).
* **sv_cumulative (Cumulative Sum):** Preserves weights that sum up to `Param%` of the total strength. **Param between 0.0 and 1.0**.
* **⚠️ Safety Ceiling:** The **"To Rank"** slider acts as a hard limit. Even if a dynamic method wants a higher rank, it will be cut down to this number to keep file sizes small.
"""
            )
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)
        # Tab: weighted multi-model MergeKit merges; handled by task_mergekit_weighted.
        with gr.Tab("Stir/Tie Bases"):
            gr.Markdown("### Linear, TIES, dare-TIES, Model Stock")
            t5_token = gr.Textbox(label="HF Token", type="password")
            with gr.Row():
                t5_method = gr.Dropdown(["linear", "ties", "dare_ties", "dare_linear", "model_stock"], value="linear", label="Method")
                t5_dtype = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Dtype")
            t5_models = gr.TextArea(label="Models (one per line)")
            with gr.Row():
                t5_base = gr.Textbox(label="Base Model (Optional)")
                t5_weights = gr.Textbox(label="Weights (comma sep)", placeholder="1.0, 0.5")
            with gr.Row():
                # Free text; parsing into a number is left to the handler.
                t5_density = gr.Textbox(label="Density (for DARE/TIES)", placeholder="0.5")
                t5_norm = gr.Checkbox(label="Normalize", value=True)
            t5_out = gr.Textbox(label="Output Repo")
            t5_priv = gr.Checkbox(label="Private", value=True)
            t5_btn = gr.Button("Run MergeKit (CLI)")
            t5_res = gr.Textbox(label="Result")
            t5_btn.click(task_mergekit_weighted, [t5_token, t5_models, t5_method, t5_dtype, t5_base, t5_weights, t5_density, t5_norm, t5_out, t5_priv], t5_res)
        # Tab: two-model interpolation merges; handled by task_mergekit_interp.
        with gr.Tab("Amphinterpolative"):
            gr.Markdown("### Slerp, Task Arithmetic, NuSlerp")
            t6_token = gr.Textbox(label="HF Token", type="password")
            with gr.Row():
                t6_model_a = gr.Textbox(label="Model A")
                t6_model_b = gr.Textbox(label="Model B")
            with gr.Row():
                t6_method = gr.Dropdown(["slerp", "nuslerp", "task_arithmetic"], value="slerp", label="Method")
                t6_dtype = gr.Dropdown(["float16", "bfloat16"], value="bfloat16", label="Dtype")
                # Free text, passed to the handler as a string.
                t6_t = gr.Textbox(label="t (Interpolation factor)", value="0.5")
            t6_base = gr.Textbox(label="Base Model (for Task Arithmetic)", placeholder="Same as A usually")
            t6_out = gr.Textbox(label="Output Repo")
            t6_priv = gr.Checkbox(label="Private", value=True)
            t6_btn = gr.Button("Run MergeKit (CLI)")
            t6_res = gr.Textbox(label="Result")
            # Input order (base before method) must match task_mergekit_interp's parameters.
            t6_btn.click(task_mergekit_interp, [t6_token, t6_model_a, t6_model_b, t6_base, t6_method, t6_t, t6_dtype, t6_out, t6_priv], t6_res)
        # Tab: build a mixture-of-experts model; handled by task_mergekit_moe.
        with gr.Tab("MoEr"):
            gr.Markdown("### Mixture of Experts Construction")
            t7_token = gr.Textbox(label="HF Token", type="password")
            t7_base = gr.Textbox(label="Base Model")
            t7_experts = gr.TextArea(label="Experts (one per line)")
            with gr.Row():
                t7_gate = gr.Dropdown(["cheap_embed", "random", "hidden"], value="cheap_embed", label="Gate Mode")
                t7_dtype = gr.Dropdown(["float16", "bfloat16"], value="bfloat16", label="Dtype")
            t7_out = gr.Textbox(label="Output Repo")
            t7_priv = gr.Checkbox(label="Private", value=True)
            t7_btn = gr.Button("Build MoE (CLI)")
            t7_res = gr.Textbox(label="Result")
            t7_btn.click(task_mergekit_moe, [t7_token, t7_base, t7_experts, t7_gate, t7_dtype, t7_out, t7_priv], t7_res)
        # Tab: equal-weight linear/passthrough MergeKit config; handled by task_raw_merge.
        with gr.Tab("Rawer"):
            gr.Markdown("### Raw PyTorch MergeKit / Non-pipeline-classed")
            t8_token = gr.Textbox(label="HF Token", type="password")
            t8_models = gr.TextArea(label="Models (one per line)")
            t8_method = gr.Dropdown(["linear", "passthrough"], value="linear", label="Method")
            t8_dtype = gr.Dropdown(["float32", "float16", "bfloat16"], value="float32", label="dtype")
            t8_out = gr.Textbox(label="Output Repo")
            t8_priv = gr.Checkbox(label="Private", value=True)
            t8_btn = gr.Button("Merge")
            t8_res = gr.Textbox(label="Result")
            t8_btn.click(task_raw_merge, [t8_token, t8_models, t8_method, t8_dtype, t8_out, t8_priv], t8_res)
        # Tab: in-process DARE merge of a fine-tune into its base; handled by task_dare_soonr.
        with gr.Tab("Mario,DARE!"):
            gr.Markdown("### From sft-merger by [Martyn Garcia](https://github.com/martyn)")
            t9_token = gr.Textbox(label="HF Token", type="password")
            with gr.Row():
                t9_base = gr.Textbox(label="Base Model")
                t9_ft = gr.Textbox(label="Fine-Tuned Model")
            with gr.Row():
                # Positional Slider args are (minimum, maximum, value).
                t9_ratio = gr.Slider(0, 2, 1.0, label="Ratio")
                # Capped at 0.99 so the handler's 1 / (1 - mask_rate) rescale stays finite.
                t9_mask = gr.Slider(0, 0.99, 0.5, label="Mask Rate (Drop)")
            t9_out = gr.Textbox(label="Output Repo")
            t9_priv = gr.Checkbox(label="Private", value=True)
            t9_btn = gr.Button("Run DARE Custom")
            t9_res = gr.Textbox(label="Result")
            t9_btn.click(task_dare_soonr, [t9_token, t9_base, t9_ft, t9_ratio, t9_mask, t9_out, t9_priv], t9_res)
if __name__ == "__main__":
    # queue() returns the same Blocks app with request queuing enabled.
    # Launch it with the page CSS and server-side rendering disabled.
    app = demo.queue()
    app.launch(css=css, ssr_mode=False)