Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import os | |
| import gc | |
| import shutil | |
| import requests | |
| import json | |
| import struct | |
| import numpy as np | |
| import yaml | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import re | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional, List, Iterable | |
| from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login | |
| from safetensors.torch import load_file, save_file | |
| from tqdm import tqdm | |
# --- Import MergeKit Config ---
try:
    from mergekit.config import MergeConfiguration
except ImportError:
    class MergeConfiguration:
        """Stand-in used when mergekit cannot be imported.

        ``model_validate`` accepts any config and does nothing, so invalid
        configs are not caught before the MergeKit CLI is launched.
        """

        @staticmethod
        def model_validate(config):
            return None
# --- Constants & Setup ---
# Scratch directory shared by every task; cleanup_temp() wipes and recreates it
# at the start of each task.
try:
    TempDir = Path("/tmp/temp_tool")
    TempDir.mkdir(parents=True, exist_ok=True)
except OSError:
    # /tmp missing or not writable: fall back to a directory under the CWD.
    # (Only filesystem errors are caught; a bare except would also swallow
    # KeyboardInterrupt/SystemExit during startup.)
    TempDir = Path("./temp_tool")
    TempDir.mkdir(parents=True, exist_ok=True)

# Shared Hub client used by all tasks below.
api = HfApi()
def cleanup_temp():
    """Reset the scratch directory.

    Deletes ``TempDir`` (and everything in it) when it exists, recreates it
    empty, then triggers a garbage-collection pass to release large buffers.
    """
    scratch = TempDir
    if scratch.exists():
        shutil.rmtree(scratch)
    scratch.mkdir(parents=True, exist_ok=True)
    gc.collect()
| # ================================================================================= | |
| # SHARED HELPERS (Tabs 1-4 & 10) | |
| # ================================================================================= | |
def parse_hf_url(url):
    """Split a Hugging Face model-file URL into ``(repo_id, filename)``.

    Accepts ``https://huggingface.co/<owner>/<repo>/(resolve|blob)/<revision>/<path>``.
    Any query string or fragment is ignored, and the revision segment is
    discarded.

    Returns ``(None, None)`` for anything else. That includes dataset/space
    links, whose ``owner/repo`` cannot be downloaded without a ``repo_type``;
    callers fall back to a plain HTTP download for those.
    """
    if not url or "huggingface.co/" not in url:
        return None, None
    path = url.split("huggingface.co/", 1)[1]
    path = path.split("?", 1)[0].split("#", 1)[0]
    parts = path.split("/")
    if len(parts) < 5 or parts[2] not in ("resolve", "blob"):
        return None, None
    if parts[0] in ("datasets", "spaces"):
        return None, None
    filename = "/".join(parts[4:])
    if not parts[0] or not parts[1] or not filename:
        return None, None
    return f"{parts[0]}/{parts[1]}", filename
def _next_adapter_path():
    """Return the first unused ``adapter[_N].safetensors`` path inside ``TempDir``."""
    candidate = TempDir / "adapter.safetensors"
    n = 0
    while candidate.exists():
        n += 1
        candidate = TempDir / f"adapter_{n}.safetensors"
    return candidate


def _fetch_hub_file(repo_id, filename, token, local_path):
    """Download ``filename`` from ``repo_id`` and move it to ``local_path``.

    Files are staged in a sub-directory so a download can never land on (and
    clobber) a file an earlier call already returned.
    """
    staging = TempDir / "_downloads"
    staging.mkdir(parents=True, exist_ok=True)
    downloaded = hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=staging)
    os.replace(downloaded, local_path)
    return local_path


def _fetch_http_file(url, token, local_path):
    """Stream ``url`` to ``local_path``; a partially written file is removed on failure.

    The Hugging Face token is attached only for huggingface.co / hf.co hosts,
    so it is never sent to third-party servers.
    """
    from urllib.parse import urlparse

    host = (urlparse(url).hostname or "").lower()
    is_hf_host = host in ("huggingface.co", "hf.co") or host.endswith((".huggingface.co", ".hf.co"))
    headers = {"Authorization": f"Bearer {token}"} if token and is_hf_host else {}
    try:
        with requests.get(url, stream=True, headers=headers, timeout=60) as r:
            r.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
    except Exception:
        if local_path.exists():
            local_path.unlink()
        raise
    return local_path


def download_lora_smart(input_str, token):
    """Fetch one ``.safetensors`` file into ``TempDir`` and return its local ``Path``.

    ``input_str`` may be any of:
      * a Hugging Face file URL (``.../resolve/...`` or ``.../blob/...``);
      * ``owner/repo/path/to/file.safetensors``;
      * a repo id ``owner/repo``. This prefers ``adapter_model.safetensors``,
        then ``model.safetensors``, then the first ``.safetensors`` file listed;
      * any other ``http(s)`` URL, downloaded directly.

    Each call writes to a fresh path (``adapter.safetensors``,
    ``adapter_1.safetensors``, ...). Several downloads in one task therefore
    no longer overwrite or delete each other.

    Raises:
        ValueError: the repo has no ``.safetensors`` file.
        Exception: the error from the last strategy tried, when none succeeds.
    """
    input_str = input_str.strip()
    local_path = _next_adapter_path()

    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        try:
            return _fetch_hub_file(repo_id, filename, token, local_path)
        except Exception:
            pass  # e.g. gated/private without Hub access: retry as plain HTTP below

    if input_str.startswith("http"):
        # Hub lookups cannot succeed for a URL parse_hf_url did not accept.
        return _fetch_http_file(input_str, token, local_path)

    if ".safetensors" in input_str and input_str.count("/") >= 2:
        owner, repo, rest = input_str.split("/", 2)
        return _fetch_hub_file(f"{owner}/{repo}", rest, token, local_path)

    files = list_repo_files(repo_id=input_str, token=token)
    preferred = ["adapter_model.safetensors", "model.safetensors"]
    target = next((c for c in preferred if c in files), None)
    if target is None:
        target = next((f for f in files if f.endswith(".safetensors")), None)
    if target is None:
        raise ValueError("No safetensors found")
    return _fetch_hub_file(input_str, target, token, local_path)
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """Load a LoRA file and group its matrices by target-module stem.

    Returns ``{stem: {"down", "up", "rank", "alpha"}}``:
      * ``stem`` comes from ``get_key_stem``;
      * ``down``/``up`` (from ``lora_down``/``lora_A`` and ``lora_up``/``lora_B``
        keys) are cast to ``precision_dtype``;
      * ``rank`` is ``down.shape[0]``;
      * ``alpha`` is the file's alpha for that stem, else the rank.

    Stems missing either matrix are dropped, so callers can index ``["down"]``
    and ``["up"]`` safely. This covers tensors whose key matches neither
    pattern.
    """
    state_dict = load_file(lora_path, device="cpu")
    pairs, alphas = {}, {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
            continue
        entry = pairs.setdefault(stem, {})
        if "lora_down" in k or "lora_A" in k:
            entry["down"] = v.to(dtype=precision_dtype)
            entry["rank"] = v.shape[0]
        elif "lora_up" in k or "lora_B" in k:
            entry["up"] = v.to(dtype=precision_dtype)
    complete = {stem: e for stem, e in pairs.items() if "down" in e and "up" in e}
    for stem, entry in complete.items():
        entry["alpha"] = alphas.get(stem, float(entry["rank"]))
    return complete
def get_key_stem(key):
    """Reduce a state-dict key to the module path a LoRA and its base weight share.

    First removes every occurrence of ``.weight``, ``.bias``, ``.lora_down``,
    ``.lora_up``, ``.lora_A``, ``.lora_B`` and ``.alpha`` (in that order).
    Then it keeps making passes over the known prefixes, stripping each one
    that matches, until a full pass strips nothing.
    """
    for token in (".weight", ".bias", ".lora_down", ".lora_up", ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(token, "")
    known_prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model.",
    )
    stripped_any = True
    while stripped_any:
        stripped_any = False
        for prefix in known_prefixes:
            trimmed = key.removeprefix(prefix)
            if trimmed != key:
                key, stripped_any = trimmed, True
    return key
| # ================================================================================= | |
| # TABS 1-4 LOGIC (RESTORED) | |
| # ================================================================================= | |
class MemoryEfficientSafeOpen:
    """Minimal lazy reader for ``.safetensors`` files.

    Only the JSON header is parsed up front. Each tensor's bytes are read from
    disk on demand in ``get_tensor``, so a large checkpoint never has to be
    fully resident in memory. Use it as a context manager, or call ``close()``,
    so the file handle is released.
    """

    # safetensors dtype tag -> torch dtype.
    _DTYPES = {
        "F64": torch.float64,
        "F32": torch.float32,
        "F16": torch.float16,
        "BF16": torch.bfloat16,
        "I64": torch.int64,
        "I32": torch.int32,
        "I16": torch.int16,
        "I8": torch.int8,
        "U8": torch.uint8,
        "BOOL": torch.bool,
    }

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        try:
            self.header, self.header_size = self._read_header()
        except Exception:
            # Don't leak the handle when the header is truncated or not JSON.
            self.file.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying file (safe to call more than once)."""
        self.file.close()

    def keys(self) -> list[str]:
        """Tensor names in header order (the ``__metadata__`` entry excluded)."""
        return [k for k in self.header if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        """The header's ``__metadata__`` mapping, or ``{}`` when absent."""
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """Read and return tensor ``key``; raises ``KeyError`` when it is not present."""
        if key == "__metadata__" or key not in self.header:
            raise KeyError(f"Tensor '{key}' not found")
        info = self.header[key]
        start, end = info["data_offsets"]
        # Offsets are relative to the end of the 8-byte length prefix + header.
        self.file.seek(self.header_size + 8 + start)
        buf = bytearray(end - start)
        if self.file.readinto(buf) != len(buf):
            raise ValueError(f"Truncated tensor data for '{key}' in {self.filename}")
        return self._deserialize_tensor(buf, info)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        return json.loads(self.file.read(header_size).decode("utf-8")), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._DTYPES[metadata["dtype"]]
        shape = metadata["shape"]
        if len(tensor_bytes) == 0:
            # torch.frombuffer rejects empty buffers; zero-element tensors are legal.
            return torch.empty(shape, dtype=dtype)
        # Back the tensor with a writable bytearray: callers modify results in
        # place (e.g. ``.add_``), which is unsafe on an immutable ``bytes``.
        if not isinstance(tensor_bytes, bytearray):
            tensor_bytes = bytearray(tensor_bytes)
        return torch.frombuffer(tensor_bytes, dtype=torch.uint8).view(dtype).reshape(shape)
class ShardBuffer:
    """Accumulate tensors and stream them to the Hub as numbered safetensors shards.

    Tensors are serialized into an in-memory buffer. Once it holds at least
    ``max_size_gb`` GiB, the buffer is:
      * written to ``<output_dir>/<filename_prefix>-NNNNN.safetensors``;
      * uploaded to ``output_repo`` (under ``subfolder`` when given);
      * deleted locally.

    ``index_map`` (tensor key -> shard filename) and ``total_size`` (bytes)
    are kept for the caller's ``*.safetensors.index.json``. Call ``flush()``
    once more at the end to write the final partial shard.
    """

    # torch dtype -> safetensors dtype tag for dtypes stored natively.
    _DTYPE_TAGS = {
        torch.float64: "F64",
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        torch.int64: "I64",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
        torch.bool: "BOOL",
    }

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir, self.output_repo, self.subfolder, self.hf_token, self.filename_prefix = output_dir, output_repo, subfolder, hf_token, filename_prefix
        self.buffer, self.current_bytes, self.shard_count, self.index_map, self.total_size = [], 0, 0, {}, 0

    def add_tensor(self, key, tensor):
        """Serialize ``tensor`` under ``key``; flush when the buffer reaches the size limit.

        The recorded dtype always matches the bytes written. Dtypes without a
        tag above are stored as float32 instead of being mislabelled.
        """
        tensor = tensor.detach().cpu().contiguous()
        tag = self._DTYPE_TAGS.get(tensor.dtype)
        if tag is None:
            tensor, tag = tensor.float(), "F32"
        # numpy has no bfloat16: reinterpret the 16-bit payload as int16 to export it.
        array = tensor.view(torch.int16).numpy() if tensor.dtype == torch.bfloat16 else tensor.numpy()
        raw = array.tobytes()
        self.buffer.append({"key": key, "data": raw, "dtype": tag, "shape": list(tensor.shape)})
        self.current_bytes += len(raw)
        self.total_size += len(raw)
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        """Write buffered tensors as one shard, upload it, delete the local file and reset the buffer."""
        if not self.buffer:
            return
        self.shard_count += 1
        fname = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        header = {"__metadata__": {"format": "pt"}}
        offset = 0
        for entry in self.buffer:
            size = len(entry["data"])
            header[entry["key"]] = {"dtype": entry["dtype"], "shape": entry["shape"], "data_offsets": [offset, offset + size]}
            offset += size
            self.index_map[entry["key"]] = fname
        header_json = json.dumps(header).encode("utf-8")
        # The format allows trailing spaces in the header; padding to 8 bytes keeps tensor data aligned.
        header_json += b" " * (-len(header_json) % 8)
        out = self.output_dir / fname
        try:
            with open(out, "wb") as f:
                f.write(struct.pack("<Q", len(header_json)))
                f.write(header_json)
                for entry in self.buffer:
                    f.write(entry["data"])
            api.upload_file(path_or_fileobj=out, path_in_repo=f"{self.subfolder}/{fname}" if self.subfolder else fname, repo_id=self.output_repo, token=self.hf_token)
        finally:
            # Never leave a multi-GB shard behind, even if the upload failed.
            if os.path.exists(out):
                os.remove(out)
        self.buffer, self.current_bytes = [], 0
        gc.collect()
def _find_lora_match(lora_pairs, stem):
    """Return ``(pair, chunk)`` for the LoRA entry that targets ``stem``.

    ``chunk`` is ``None`` for a direct match. It is 0/1/2 when a separate
    ``to_q``/``to_k``/``to_v`` weight matched a fused ``qkv`` LoRA entry.
    Returns ``(None, None)`` when nothing matches.
    """
    pair = lora_pairs.get(stem)
    if pair:
        return pair, None
    for chunk, proj in enumerate(("to_q", "to_k", "to_v")):
        if proj in stem:
            pair = lora_pairs.get(stem.replace(proj, "qkv"))
            if pair:
                return pair, chunk
    return None, None


def _lora_delta(pair, target, chunk=None):
    """Return the unscaled delta ``up @ down`` reshaped like ``target``, or ``None``.

    Flattening both factors to 2-D covers:
      * linear layers;
      * 1x1 convs trained as linear LoRA;
      * kxk convs with ``down`` of shape ``(rank, in, kh, kw)``.

    ``None`` means the pair is incomplete or its delta does not fit ``target``.
    """
    down, up = pair.get("down"), pair.get("up")
    if down is None or up is None or target.dim() < 2:
        return None
    delta = up.float().flatten(1) @ down.float().flatten(1)
    if chunk is not None and delta.shape[0] == 3 * target.shape[0]:
        # Fused qkv LoRA applied to a separate q/k/v weight: take its third.
        # NOTE(review): assumes q, k, v are stacked in that order along dim 0.
        delta = delta.chunk(3, dim=0)[chunk]
    if delta.shape != target.flatten(1).shape:
        return None
    return delta.reshape(target.shape)


def task_merge_legacy(hf_token, base, sub, lora, scale, prec, shard, out, struct_s, priv, progress=gr.Progress()):
    """Merge a LoRA into a base checkpoint and upload it as resharded safetensors.

    Steps:
      1. Create repo ``out``.
      2. Optionally copy non-weight files from ``struct_s`` (best effort).
      3. Load ``lora``.
      4. Stream every ``.safetensors`` file of ``base`` (only those under
         ``sub`` when given). Add ``scale * alpha / rank * (up @ down)`` to
         each matching weight (accumulated in float32), then cast everything
         to ``prec``.
      5. Upload through ``ShardBuffer`` in shards of at most ``shard`` GB,
         plus an index JSON.

    Weights whose LoRA delta does not fit their shape are left unchanged and
    counted as skipped. Returns a status string.
    """
    cleanup_temp()
    hf_token = hf_token.strip() if hf_token else None
    if hf_token:
        login(hf_token)
    try:
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error: {e}"
    if struct_s:
        try:
            for f in tqdm(api.list_repo_files(repo_id=struct_s, token=hf_token), desc="Copying Structure"):
                if sub and f.startswith(sub):
                    continue
                if not sub and f.endswith((".safetensors", ".bin", ".pt", ".pth")):
                    continue
                local = hf_hub_download(repo_id=struct_s, filename=f, token=hf_token, local_dir=TempDir)
                api.upload_file(path_or_fileobj=local, path_in_repo=f, repo_id=out, token=hf_token)
                os.remove(local)
        except Exception:
            # Best effort by design: a missing or placeholder structure repo must not block the merge.
            pass
    files = [f for f in list_repo_files(repo_id=base, token=hf_token) if f.endswith(".safetensors")]
    if sub:
        files = [f for f in files if f.startswith(sub)]
    if not files:
        return "No safetensors found"
    prefix = "diffusion_pytorch_model" if (sub in ["transformer", "unet"] or "diffusion_pytorch_model" in os.path.basename(files[0])) else "model"
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16}.get(prec, torch.float32)
    try:
        lora_pairs = load_lora_to_memory(download_lora_smart(lora, hf_token), dtype)
    except Exception as e:
        return f"LoRA Error: {e}"
    buf = ShardBuffer(shard, TempDir, out, sub, hf_token, prefix)
    merged_count = skipped_count = 0
    for fpath in files:
        local = hf_hub_download(repo_id=base, filename=fpath, token=hf_token, local_dir=TempDir)
        with MemoryEfficientSafeOpen(local) as f:
            for k in f.keys():
                v = f.get_tensor(k)
                # Biases share their weight's stem; LoRA deltas only apply to >=2-D weights.
                pair, chunk = _find_lora_match(lora_pairs, get_key_stem(k)) if v.dim() >= 2 else (None, None)
                if pair:
                    delta = _lora_delta(pair, v, chunk)
                    if delta is None:
                        skipped_count += 1
                    else:
                        rank = pair.get("rank") or pair["down"].shape[0]
                        s = scale * (pair.get("alpha", rank) / rank)
                        # Out-of-place float32 add: avoids low-precision rounding and in-place writes to reader buffers.
                        v = v.float() + delta * s
                        merged_count += 1
                buf.add_tensor(k, v.to(dtype))
        os.remove(local)
    buf.flush()
    idx = {"metadata": {"total_size": buf.total_size}, "weight_map": buf.index_map}
    idx_n = f"{prefix}.safetensors.index.json"
    with open(TempDir / idx_n, "w") as f:
        json.dump(idx, f, indent=4)
    api.upload_file(path_or_fileobj=TempDir / idx_n, path_in_repo=f"{sub}/{idx_n}" if sub else idx_n, repo_id=out, token=hf_token)
    return f"Done (merged LoRA into {merged_count} tensors, skipped {skipped_count} with mismatched shapes)"
def _abs_quantile(values, q):
    """Return the ``q``-quantile of ``|values|``.

    Uses ``torch.quantile`` for normal sizes. Above its input-size limit it
    falls back to a ``kthvalue`` order statistic, which approximates the
    interpolated quantile.
    """
    flat = values.abs().flatten()
    if flat.numel() < 2 ** 24:
        return torch.quantile(flat, q)
    k = int(q * (flat.numel() - 1)) + 1
    return flat.kthvalue(k).values


def _extract_lora_pair(m1, m2, rank):
    """Approximate ``m2 - m1`` with a rank-``<= rank`` factorization.

    Returns ``(up, down, r)`` such that ``up @ down`` (flattened) approximates
    the difference. Returns ``None`` when the shapes differ, the tensors are
    scalars, or the largest absolute change is below 1e-4.
    """
    if m1.shape != m2.shape or m1.dim() == 0:
        return None
    diff = m2 - m1
    if diff.numel() == 0 or torch.max(torch.abs(diff)) < 1e-4:
        return None
    out_d = diff.shape[0]
    in_d = diff.shape[1] if diff.dim() > 1 else 1
    r = max(1, min(int(rank), in_d, out_d))
    # SVD needs a matrix: fold kernel dims into the input axis; vectors become (n, 1).
    if diff.dim() > 2:
        mat = diff.flatten(1)
    elif diff.dim() == 1:
        mat = diff.unsqueeze(1)
    else:
        mat = diff
    U, S, V = torch.svd_lowrank(mat, q=min(r + 4, *mat.shape), niter=4)
    up = U[:, :r] @ torch.diag(S[:r])
    down = V.t()[:r, :]
    # Clip outliers at the 99th percentile of |entries| across both factors.
    hi_val = _abs_quantile(torch.cat([up.flatten(), down.flatten()]), 0.99)
    if hi_val > 0:
        up, down = up.clamp(-hi_val, hi_val), down.clamp(-hi_val, hi_val)
    if m1.dim() > 2:
        up = up.reshape(out_d, r, *([1] * (m1.dim() - 2)))
        down = down.reshape(r, *m1.shape[1:])
    else:
        up, down = up.reshape(out_d, r), down.reshape(r, in_d)
    return up.contiguous(), down.contiguous(), r


def task_extract(hf_token, org, tun, rank, out):
    """Extract a LoRA approximating the difference between two checkpoints and upload it.

    Every tensor present in both files is considered, except batch-norm
    bookkeeping tensors. When it changed by more than 1e-4, its difference is
    stored as ``<key-without-.weight>.lora_up.weight`` /
    ``.lora_down.weight`` plus ``.alpha == rank``.

    Returns ``"Done"`` or an ``"Error: ..."`` string.
    """
    cleanup_temp()
    hf_token = hf_token.strip() if hf_token else None
    if hf_token:
        login(hf_token)
    try:
        # download_lora_smart may hand back the same path on every call: move
        # the first file aside so fetching the second cannot overwrite or delete it.
        org_path = TempDir / "extract_original.safetensors"
        os.replace(download_lora_smart(org, hf_token), org_path)
        tun_path = download_lora_smart(tun, hf_token)
        lora_sd = {}
        with MemoryEfficientSafeOpen(org_path) as org_f, MemoryEfficientSafeOpen(tun_path) as tun_f:
            common = sorted(set(org_f.keys()) & set(tun_f.keys()))
            for k in tqdm(common, desc="Extracting"):
                if "num_batches_tracked" in k or "running_mean" in k or "running_var" in k:
                    continue
                result = _extract_lora_pair(org_f.get_tensor(k).float(), tun_f.get_tensor(k).float(), rank)
                if result is None:
                    continue
                up, down, r = result
                stem = k.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = up
                lora_sd[f"{stem}.lora_down.weight"] = down
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        if not lora_sd:
            return "Error: No differing tensors found; nothing to extract."
        out_f = TempDir / "extracted.safetensors"
        save_file(lora_sd, out_f)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out_f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
        return "Done"
    except Exception as e:
        return f"Error: {e}"
def load_full_state_dict(path):
    """Load a LoRA file as float32 tensors with ``lora_down``/``lora_up`` naming.

    Keys containing ``lora_A`` have it replaced by ``lora_down``. Otherwise,
    keys containing ``lora_B`` have it replaced by ``lora_up``. All other keys
    are kept verbatim.
    """
    def _rename(key):
        if "lora_A" in key:
            return key.replace("lora_A", "lora_down")
        if "lora_B" in key:
            return key.replace("lora_B", "lora_up")
        return key

    return {_rename(name): tensor.float() for name, tensor in load_file(path, device="cpu").items()}
def _ema_gamma(sigma_rel):
    """Solve ``sigma_rel**-2 * (g + 1) = (g + 2)**2 * (g + 3)`` for the largest ``g >= 0``.

    Returns ``None`` when ``sigma_rel <= 0`` or no non-negative real root
    exists. The latter happens for ``sigma_rel`` above ~0.2887, the value at
    ``g = 0``.
    """
    if sigma_rel <= 0:
        return None
    t_val = sigma_rel ** -2
    roots = np.roots([1, 7, 16 - t_val, 12 - t_val])
    real = roots.real[(np.abs(roots.imag) < 1e-9) & (roots.real >= 0)]
    return float(real.max()) if real.size else None


def _merge_ema(paths, beta, sigma_rel):
    """Fold adapters into the first one with an exponential moving average.

    Tensors are averaged key by key (down and up matrices separately). Alpha
    entries, keys missing from a later adapter, and keys whose shapes
    disagree keep the running value.

    With a solvable ``sigma_rel`` the n-th adapter uses
    ``beta_n = (1 - 1/n) ** (gamma + 1)``; otherwise the fixed ``beta`` is used.
    """
    merged = load_file(paths[0], device="cpu")
    gamma = _ema_gamma(sigma_rel)
    # Adapter 1 seeds the average, so adapter n is step n. Starting at 1 would
    # give beta = 0 and discard adapter 1 entirely.
    for step, path in enumerate(paths[1:], start=2):
        current_beta = (1 - 1 / step) ** (gamma + 1) if gamma is not None else beta
        curr = load_file(path, device="cpu")
        for k in merged:
            if k in curr and "alpha" not in k and curr[k].shape == merged[k].shape:
                merged[k] = merged[k].float() * current_beta + curr[k].float() * (1 - current_beta)
        del curr
    return merged


def _merge_by_layer(paths, method, weights, target_rank):
    """Merge adapters per layer by rank concatenation or by SVD of the summed delta.

    Each adapter contributes ``weight * alpha / rank * (up @ down)``.

    Concatenation stacks the factors, folding each adapter's scale into its
    up matrix, and stores ``alpha == total rank`` (scale 1). This reproduces
    the weighted sum exactly.

    SVD re-factorizes the summed delta at ``min(target_rank, layer dims)``.

    Returns ``(merged_state, failed_stems)``; ``failed_stems`` are layers
    whose SVD raised and were omitted.
    """
    states = [load_full_state_dict(p) for p in paths]
    stems = sorted({k.split(".lora_")[0] for state in states for k in state if "lora_" in k})
    concat, svd = "Concatenation" in method, "SVD" in method
    merged, failed = {}, []
    for stem in tqdm(stems):
        dk, uk, ak = f"{stem}.lora_down.weight", f"{stem}.lora_up.weight", f"{stem}.alpha"
        downs, ups, total_rank, total_delta = [], [], 0, None
        for state, w in zip(states, weights):
            if dk not in state or uk not in state:
                continue
            d, u = state[dk], state[uk]
            rank = d.shape[0]
            alpha = state[ak].item() if ak in state else rank
            scale = (alpha / rank) * w
            if concat:
                downs.append(d)
                ups.append(u * scale)
                total_rank += rank
            elif svd:
                delta = (u.flatten(1) @ d.flatten(1)).reshape(u.shape[0], *d.shape[1:]) * scale
                total_delta = delta if total_delta is None else total_delta + delta
        if concat and downs:
            merged[dk] = torch.cat(downs, dim=0).contiguous()
            merged[uk] = torch.cat(ups, dim=1).contiguous()
            merged[ak] = torch.tensor(float(total_rank))
        elif svd and total_delta is not None:
            flat = total_delta.flatten(1)
            tr = max(1, min(int(target_rank), *flat.shape))
            try:
                U, S, V = torch.svd_lowrank(flat, q=min(tr + 4, *flat.shape), niter=4)
            except RuntimeError:
                failed.append(stem)
                continue
            up = U[:, :tr] @ torch.diag(S[:tr])
            down = V.t()[:tr, :]
            merged[dk] = down.reshape(tr, *total_delta.shape[1:]).contiguous()
            merged[uk] = up.reshape(total_delta.shape[0], tr, *([1] * (total_delta.dim() - 2))).contiguous()
            merged[ak] = torch.tensor(float(tr))
    return merged, failed


def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
    """Merge two or more LoRA adapters and upload ``merged_adapters.safetensors``.

    ``inputs_text`` lists adapters separated by newlines or spaces.
    ``weight_str`` holds comma-separated weights for the Concat/SVD methods;
    missing weights default to 1.0. ``method`` selects one of:
      * Iterative EMA (``beta``/``sigma_rel``);
      * Concatenation;
      * SVD fusion at ``target_rank``.

    Returns a status string.
    """
    cleanup_temp()
    hf_token = hf_token.strip() if hf_token else None
    if hf_token:
        login(hf_token)
    urls = [line.strip() for line in inputs_text.replace(" ", "\n").split("\n") if line.strip()]
    if len(urls) < 2:
        return "Error: Provide at least 2 adapters."
    try:
        weights = [float(w.strip()) for w in weight_str.split(",")] if weight_str.strip() else [1.0] * len(urls)
    except ValueError:
        return "Error parsing weights."
    if len(weights) < len(urls):
        weights += [1.0] * (len(urls) - len(weights))
    paths = []
    try:
        for n, url in enumerate(tqdm(urls, desc="Downloading")):
            # Give every input its own file: download_lora_smart may reuse one path per call.
            dest = TempDir / f"merge_input_{n}.safetensors"
            os.replace(download_lora_smart(url, hf_token), dest)
            paths.append(dest)
    except Exception as e:
        return f"Error downloading adapters: {e}"
    if "Iterative EMA" in method:
        merged, failed = _merge_ema(paths, beta, sigma_rel), []
    else:
        merged, failed = _merge_by_layer(paths, method, weights, target_rank)
    if not merged:
        return "Error: Nothing to merge (no matching LoRA layers found)."
    out = TempDir / "merged_adapters.safetensors"
    save_file(merged, out)
    api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    message = f"Success! Merged to {out_repo}"
    if failed:
        message += f" ({len(failed)} layers failed SVD and were omitted)"
    return message
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """Re-factorize each LoRA layer at a lower rank and upload ``shrunken.safetensors``.

    Only stems with both a down matrix (``lora_down``/``lora_A``) and an up
    matrix (``lora_up``/``lora_B``) are processed. For each, the effective
    delta ``(alpha / rank) * up @ down`` is re-decomposed with
    ``torch.svd_lowrank``.

    The new rank is ``new_rank``, or the per-layer rank picked by
    ``dynamic_method`` (capped by ``new_rank``):
      * ``sv_ratio``: singular values greater than ``S[0] / dynamic_param``;
      * ``sv_cumulative``: smallest rank whose cumulative singular-value sum
        reaches ``dynamic_param`` of the total;
      * ``sv_fro``: the same on squared singular values, against
        ``dynamic_param ** 2``.

    Outputs store ``alpha == new rank``, i.e. a scale of 1, which is why the
    original ``alpha / rank`` is baked into the factors.
    """
    cleanup_temp()
    hf_token = hf_token.strip() if hf_token else None
    if hf_token:
        login(hf_token)
    path = download_lora_smart(lora_input, hf_token)
    state = load_file(path, device="cpu")

    groups = {}
    for k, tensor in state.items():
        if ".lora_" in k:
            stem = k.split(".lora_")[0]
        elif k.endswith(".alpha"):
            # Alpha keys look like "<stem>.alpha" (no ".lora_" marker); attach them to their stem.
            stem = k[: -len(".alpha")]
        else:
            continue
        group = groups.setdefault(stem, {})
        if "lora_down" in k or "lora_A" in k:
            group["down"] = tensor
        elif "lora_up" in k or "lora_B" in k:
            group["up"] = tensor
        elif "alpha" in k:
            group["alpha"] = tensor

    target_rank_limit = int(new_rank)
    new_state = {}
    for stem, g in tqdm(groups.items()):
        if "down" not in g or "up" not in g:
            continue
        down, up = g["down"].float(), g["up"].float()
        old_rank = down.shape[0]
        alpha = float(g["alpha"]) if "alpha" in g else float(old_rank)
        # flatten(1) handles linear and kxk conv factors alike (squeeze+matmul broke 3x3 convs).
        flat = (up.flatten(1) @ down.flatten(1)) * (alpha / old_rank)
        U, S, V = torch.svd_lowrank(flat, q=min(target_rank_limit + 32, *flat.shape))
        Vh = V.t()
        calc_rank = target_rank_limit
        if dynamic_method == "sv_ratio":
            calc_rank = int(torch.sum(S > (S[0] / dynamic_param)).item())
        elif dynamic_method == "sv_cumulative":
            calc_rank = int(torch.searchsorted(torch.cumsum(S, 0) / torch.sum(S), dynamic_param)) + 1
        elif dynamic_method == "sv_fro":
            calc_rank = int(torch.searchsorted(torch.cumsum(S.pow(2), 0) / torch.sum(S.pow(2)), dynamic_param**2)) + 1
        final_rank = max(1, min(calc_rank, target_rank_limit, S.shape[0]))
        new_up = U[:, :final_rank] @ torch.diag(S[:final_rank])
        new_down = Vh[:final_rank, :]
        if down.dim() == 4:
            new_up = new_up.reshape(up.shape[0], final_rank, 1, 1)
            new_down = new_down.reshape(final_rank, *down.shape[1:])
        new_state[f"{stem}.lora_down.weight"] = new_down.contiguous()
        new_state[f"{stem}.lora_up.weight"] = new_up.contiguous()
        new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float()
    if not new_state:
        return "Error: No LoRA down/up pairs found to resize."
    out = TempDir / "shrunken.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
| # ================================================================================= | |
| # MERGEKIT & STREAMING LOGS (TABS 5-9) | |
| # ================================================================================= | |
def parse_weight(w_str):
    """Parse a weight/density field from the UI.

    Returns:
      * 1.0 for missing or blank input;
      * ``float(w_str)`` for a number passed directly;
      * the YAML-parsed value when the text contains ``[`` (e.g. a list such
        as ``[0, 0.5, 1]``);
      * ``float(text)`` otherwise.

    Anything unparseable also falls back to 1.0.
    """
    if w_str is None:
        return 1.0
    if isinstance(w_str, (int, float)):
        return float(w_str)
    text = str(w_str).strip()
    if not text:
        return 1.0
    try:
        return yaml.safe_load(text) if "[" in text else float(text)
    except Exception:  # ValueError / YAML errors; unlike a bare except, lets KeyboardInterrupt through
        return 1.0
def run_mergekit_logic(config_dict, token, out_repo, private, shard_size, output_precision, tokenizer_source, chat_template, program="mergekit-yaml"):
    """Run a MergeKit CLI merge for ``config_dict`` and upload the output folder.

    This is a generator: each ``yield`` is the full log accumulated so far, so
    a Gradio Textbox bound to it shows a live, growing log.

    ``config_dict`` is modified in place before being written to
    ``TempDir/config.yaml``:
      * ``chat_template`` is set when it is non-blank;
      * ``dtype`` is set when absent;
      * ``tokenizer_source`` is set when absent and not ``"base"``.

    ``MergeConfiguration`` validation is skipped for ``mergekit-moe``.
    ``shard_size`` is in GB and is passed as ``--out-shard-size <MB>M``.
    """
    logs = []

    def log(msg):
        logs.append(msg)
        return "\n".join(logs)

    yield log("Starting MergeKit Process...")
    cleanup_temp()
    if chat_template and chat_template.strip():
        config_dict["chat_template"] = chat_template.strip()
    try:
        if program != "mergekit-moe":
            MergeConfiguration.model_validate(config_dict)
        yield log("Configuration Validated Successfully.")
    except Exception as e:
        yield log(f"Invalid Config: {e}")
        return
    token = token.strip() if token else None
    if token:
        login(token)
    if "dtype" not in config_dict:
        config_dict["dtype"] = output_precision
    if "tokenizer_source" not in config_dict and tokenizer_source != "base":
        config_dict["tokenizer_source"] = tokenizer_source
    config_path = TempDir / "config.yaml"
    config_yaml = yaml.dump(config_dict, sort_keys=False)
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(config_yaml)
    yield log(f"Config saved to {config_path}")
    yield log(f"YAML:\n{config_yaml}")
    try:
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=token)
        yield log(f"Repo {out_repo} ready.")
    except Exception as e:
        yield log(f"Repo Creation Error (might exist): {e}")
    out_path = TempDir / "merge_output"
    shard_arg = f"{int(float(shard_size) * 1024)}M"
    cmd = [
        program,
        str(config_path),
        str(out_path),
        "--allow-crimes",
        "--copy-tokenizer",
        "--out-shard-size", shard_arg,
        "--lazy-unpickle",
    ]
    if torch.cuda.is_available():
        cmd.extend(["--cuda", "--low-cpu-memory"])
    yield log(f"Executing: {' '.join(cmd)}")
    env = os.environ.copy()
    env["HF_HOME"] = str(TempDir / ".cache")
    if token:
        # Scope the token to this subprocess only. Writing it into os.environ
        # (as before) leaked one user's token into every later request.
        env["HF_TOKEN"] = token
    try:
        process = subprocess.Popen(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    except OSError as e:
        yield log(f"Could not start {program}: {e}")
        return
    try:
        for line in iter(process.stdout.readline, ""):
            yield log(line.rstrip())
        process.wait()
    finally:
        # If the consumer stops iterating (e.g. the UI event is cancelled), don't orphan the merge.
        if process.poll() is None:
            process.kill()
            process.wait()
        process.stdout.close()
    if process.returncode != 0:
        yield log("Merge failed with exit code " + str(process.returncode))
        return
    yield log(f"Uploading to {out_repo}...")
    try:
        api.upload_folder(repo_id=out_repo, folder_path=out_path, token=token)
        yield log("Upload Complete!")
    except Exception as e:
        yield log(f"Upload failed: {e}")
| # --- UI Wrappers for Tabs 5-9 --- | |
def wrapper_amphinterpolative(token, method, base, t, norm, i8, flat, row, eps, m_iter, tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
    """Build a MergeKit config for the spherical-interpolation family and stream the merge log.

    slerp/nuslerp require a base model and use only the first two model slots,
    as a single slice. multislerp/karcher use every non-blank slot as a model
    list; multislerp also takes an optional base model.

    Yields an error string, without merging, when slerp/nuslerp lack a base.
    """
    pairwise = method in ["slerp", "nuslerp"]
    settings = {"normalize": norm, "int8_mask": i8}
    if pairwise:
        settings["t"] = float(t)
    if method == "nuslerp":
        settings["flatten"] = flat
        settings["row_wise"] = row
    elif method == "multislerp":
        settings["eps"] = float(eps)
    elif method == "karcher":
        settings["max_iter"] = int(m_iter)
        settings["tol"] = float(tol)

    slots = [(m1, w1), (m2, w2), (m3, w3), (m4, w4), (m5, w5)]
    config = {"merge_method": method}
    if pairwise:
        base_name = base.strip()
        if not base_name:
            yield "Error: Base model required"
            return
        config["base_model"] = base_name
        sources = []
        for name, weight in slots[:2]:
            if name.strip():
                sources.append({"model": name, "parameters": {"weight": parse_weight(weight)}})
        config["slices"] = [{"sources": sources, "parameters": settings}]
    else:
        if base.strip() and method == "multislerp":
            config["base_model"] = base.strip()
        entries = []
        for name, weight in slots:
            if name.strip():
                entries.append({"model": name, "parameters": {"weight": parse_weight(weight)}})
        config["models"] = entries
        config["parameters"] = settings
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
def wrapper_stirtie(token, method, base, norm, i8, lamb, resc, topk, m1, w1, d1, g1, e1, m2, w2, d2, g2, e2, m3, w3, d3, g3, e3, m4, w4, d4, g4, e4, out, priv, shard, prec, tok_src, chat_t):
    """Build a MergeKit config for the TIES/DARE/breadcrumbs/DELLA/SCE family and stream the log.

    Per-model parameters:
      * ``weight`` always;
      * ``density`` for ties, dare_ties, dare_linear and breadcrumbs_ties;
      * ``gamma`` for breadcrumbs methods;
      * ``epsilon`` for della methods.

    Global parameters:
      * ``lambda`` for every method except sce;
      * ``rescale`` for dare_linear;
      * ``select_topk`` for sce.

    A blank ``base`` falls back to the first listed model. Yields an error
    string instead of merging when no model is given.
    """
    models = []
    for m, w, d, g, e in [
        (m1, w1, d1, g1, e1),
        (m2, w2, d2, g2, e2),
        (m3, w3, d3, g3, e3),
        (m4, w4, d4, g4, e4),
    ]:
        name = m.strip()
        if not name:
            continue
        p = {"weight": parse_weight(w)}
        if method in ["ties", "dare_ties", "dare_linear", "breadcrumbs_ties"]:
            p["density"] = parse_weight(d)
        if "breadcrumbs" in method:
            p["gamma"] = float(g)
        if "della" in method:
            p["epsilon"] = float(e)
        models.append({"model": name, "parameters": p})
    if not models:
        # Previously models[0] raised IndexError here.
        yield "Error: At least one model required"
        return
    g_params = {"normalize": norm, "int8_mask": i8}
    if method != "sce":
        g_params["lambda"] = float(lamb)
    if method == "dare_linear":
        g_params["rescale"] = resc
    if method == "sce":
        g_params["select_topk"] = float(topk)
    config = {
        "merge_method": method,
        "base_model": base.strip() if base.strip() else models[0]["model"],
        "parameters": g_params,
        "models": models,
    }
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
def wrapper_specious(token, method, base, norm, i8, t, filt_w, m1, w1, f1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
    """Build a MergeKit config for passthrough/nearswap/model_stock-style merges and stream the log.

    ``passthrough`` uses only the first model slot, with an optional
    ``filter``. Other methods use every non-blank slot. ``t`` applies to
    nearswap and ``filter_wise`` to model_stock; the base model is included
    when given.
    """
    if method == "passthrough":
        first = {"weight": parse_weight(w1)}
        if f1.strip():
            first["filter"] = f1.strip()
        models = [{"model": m1, "parameters": first}]
    else:
        models = []
        for name, weight in ((m1, w1), (m2, w2), (m3, w3), (m4, w4), (m5, w5)):
            if name.strip():
                models.append({"model": name, "parameters": {"weight": parse_weight(weight)}})
    settings = {"normalize": norm, "int8_mask": i8}
    config = {"merge_method": method, "parameters": settings}
    base_name = base.strip()
    if base_name:
        config["base_model"] = base_name
    if method == "nearswap":
        settings["t"] = float(t)
    if method == "model_stock":
        settings["filter_wise"] = filt_w
    config["models"] = models
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
def wrapper_moer(token, base, expert1, prompt1, expert2, prompt2, expert3, prompt3, expert4, prompt4, expert5, prompt5, gate, dtype, out, priv, shard, prec, tok_src, chat_t):
    """Build a ``mergekit-moe`` config from up to five experts and stream the merge log.

    Each non-blank expert becomes ``{"source_model", "positive_prompts"}``.
    Prompts are comma-separated; a blank prompt field gives ``[""]``.

    Yields a single error string instead of merging when fewer than two
    experts or no base model are given.
    """
    experts = []
    for exp, pmt in [
        (expert1, prompt1), (expert2, prompt2), (expert3, prompt3),
        (expert4, prompt4), (expert5, prompt5),
    ]:
        if not exp.strip():
            continue
        prompts = [p.strip() for p in pmt.split(",") if p.strip()] if pmt.strip() else [""]
        experts.append({"source_model": exp.strip(), "positive_prompts": prompts})
    # This function is a generator: a ``return "..."`` value would never reach
    # the UI, so errors must be yielded.
    if len(experts) < 2:
        yield "Error: At least 2 experts required"
        return
    if not base.strip():
        yield "Error: Base model required"
        return
    config = {
        "base_model": base.strip(),
        "gate_mode": gate,
        "dtype": dtype,
        "experts": experts,
    }
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-moe")
def wrapper_rawer(token, models, method, dtype, out, priv, shard, prec, tok_src, chat_t):
    """Merge a newline-separated list of models, each with weight 1.0, and stream the log."""
    entries = []
    for line in models.split("\n"):
        name = line.strip()
        if name:
            entries.append({"model": name, "parameters": {"weight": 1.0}})
    config = {"models": entries, "merge_method": method, "dtype": dtype}
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
| # --- TAB 10 (Custom DARE) Logic --- | |
def task_dare_custom(token, base, ft, ratio, mask, out, priv):
    """Apply a DARE-style merge of a fine-tuned checkpoint onto its base and upload it.

    For each tensor present in both files with equal shapes:
      * ``delta = ft - base``;
      * a random fraction ``mask`` of the delta entries is dropped and the
        rest rescaled by ``1 / (1 - mask)``; ``mask >= 1`` drops everything;
      * the result is ``base + ratio * delta`` in base's dtype.

    Tensors whose shapes differ are taken from the fine-tuned file. Tensors
    present in only one file are copied through, so the output stays complete.

    Returns a status string (errors are returned, not raised).
    """
    cleanup_temp()
    token = token.strip() if token else None
    if token:
        login(token)
    try:
        # download_lora_smart may reuse one path per call: move the base aside
        # so the second download cannot overwrite or delete it.
        b_path = TempDir / "dare_base.safetensors"
        os.replace(download_lora_smart(base, token), b_path)
        f_path = download_lora_smart(ft, token)
        b_sd = load_file(b_path, device="cpu")
        f_sd = load_file(f_path, device="cpu")
        mask, ratio = float(mask), float(ratio)
        merged = {}
        for k in tqdm(sorted(set(b_sd) | set(f_sd)), desc="Merging"):
            tb, tf = b_sd.get(k), f_sd.get(k)
            if tb is None or tf is None:
                merged[k] = tf if tb is None else tb
                continue
            if tb.shape != tf.shape:
                merged[k] = tf
                continue
            delta = tf.float() - tb.float()
            if mask >= 1.0:
                delta = torch.zeros_like(delta)  # everything dropped; avoids dividing by zero
            elif mask > 0:
                keep = torch.bernoulli(torch.full_like(delta, 1.0 - mask))
                delta = (delta * keep) / (1.0 - mask)
            merged[k] = (tb.float() + ratio * delta).to(tb.dtype)
        out_f = TempDir / "model.safetensors"
        save_file(merged, out_f)
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_file(path_or_fileobj=out_f, path_in_repo="model.safetensors", repo_id=out, token=token)
        return f"Done! {out}"
    except Exception as e:
        return str(e)
| # ================================================================================= | |
| # UI GENERATION | |
| # ================================================================================= | |
# Page-width CSS for the app container.
# NOTE(review): the gr.Blocks() call that follows is made without css=..., so this
# string has no effect unless it is used further down (e.g. passed to launch()) — confirm.
| css = ".container { max-width: 1100px; margin: auto; }" | |
| with gr.Blocks() as demo: | |
| gr.HTML("""<h1><img src="https://huggingface.co/spaces/AlekseyCalvin/Soon_Merger/resolve/main/SMerger3.png" alt="SOONmerge®"> Transform Transformers for FREE!</h1>""") | |
| gr.Markdown("# 🧰Training-Free CPU-run Model Creation Toolkit | **MergeKit** implementation in Tabs 5-9 & MORE") | |
| with gr.Tabs(): | |
| # --- TAB 1: RESTORED --- | |
| with gr.Tab("Merge to Base Model + Reshard Output"): | |
| t1_token = gr.Textbox(label="Token", type="password") | |
| t1_base = gr.Textbox(label="Base Repo", value="name/repo") | |
| t1_sub = gr.Textbox(label="Subfolder (Optional)", value="") | |
| t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors") | |
| with gr.Row(): | |
| t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1) | |
| t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision") | |
| t1_shard = gr.Slider(label="Max Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1) | |
| t1_out = gr.Textbox(label="Output Repo") | |
| t1_struct = gr.Textbox(label="Extras Source (copies configs/components/etc)", value="name/repo") | |
| t1_priv = gr.Checkbox(label="Private", value=True) | |
| gr.Button("Merge").click(task_merge_legacy, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], gr.Textbox(label="Result")) | |
| # --- TAB 2: RESTORED --- | |
| with gr.Tab("Extract Adapter"): | |
| t2_token = gr.Textbox(label="Token", type="password") | |
| t2_org = gr.Textbox(label="Original Model") | |
| t2_tun = gr.Textbox(label="Tuned or Homologous Model") | |
| t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1) | |
| t2_out = gr.Textbox(label="Output Repo") | |
| gr.Button("Extract").click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], gr.Textbox(label="Result")) | |
# Tab 3 "Merge Adapters": batch-merges the listed adapters with one of three methods
# via task_merge_adapters_advanced (defined outside this view).
| # --- TAB 3: RESTORED --- | |
| with gr.Tab("Merge Adapters"): | |
| gr.Markdown("### Batch Adapter Merging") | |
| t3_token = gr.Textbox(label="Token", type="password") | |
# Free-form list; per the label, entries may be newline- or space-separated.
| t3_urls = gr.TextArea(label="Adapter URLs/Repos (one per line, or space-separated)") | |
| t3_method = gr.Dropdown(["Iterative EMA (Linear w/ Beta/Sigma coefficient)", "Concatenation (MOE-like weights-stack)", "SVD Fusion (Task Arithmetic/Compressed)"], value="Iterative EMA (Linear w/ Beta/Sigma coefficient)", label="Merge Method") | |
# Per the labels: weights apply to Concat/SVD, target rank to SVD only.
| with gr.Row(): | |
| t3_weights = gr.Textbox(label="Weights (comma-separated) – for Concat/SVD") | |
| t3_rank = gr.Number(label="Target Rank – For SVD only", value=128) | |
# Per the labels: beta and sigma_rel only affect the EMA method.
| with gr.Row(): | |
| t3_beta = gr.Slider(label="Beta – for linear/post-hoc EMA", value=0.95, minimum=0.01, maximum=1.00) | |
| t3_sigma = gr.Slider(label="Sigma Rel – for linear/post-hoc EMA", value=0.21, minimum=0.01, maximum=1.00) | |
| t3_out = gr.Textbox(label="Output Repo") | |
| t3_priv = gr.Checkbox(label="Private Output", value=True) | |
# Inputs are positional. Note the list order (weights, beta, sigma, rank) differs from
# the on-screen order (weights, rank, beta, sigma). The list must follow the handler's
# parameter order, not the layout.
| gr.Button("Merge").click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], gr.Textbox(label="Result")) | |
| # --- TAB 4: RESTORED --- | |
| with gr.Tab("Resize Adapter"): | |
| t4_token = gr.Textbox(label="Token", type="password") | |
| t4_in = gr.Textbox(label="LoRA") | |
| with gr.Row(): | |
| t4_rank = gr.Number(label="To Rank (Safety Ceiling)", value=8) | |
| t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None", label="Dynamic Method") | |
| t4_param = gr.Number(label="Dynamic Param", value=0.9) | |
| gr.Markdown("### 📉 Dynamic Resizing Guide\nThese methods intelligently determine the best rank per layer.\n- **sv_ratio (Relative Strength):** Keeps features that are at least `1/Param` as strong as the main feature. **Param must be >= 2**.\n- **sv_fro (Visual Information Density):** Preserves `Param%` of total information content. **Param between 0.0 and 1.0**.\n- **sv_cumulative (Cumulative Sum):** Preserves weights that sum up to `Param%` of total strength. **Param between 0.0 and 1.0**.\n- **⚠️ Safety Ceiling:** The **'To Rank'** slider acts as a hard limit.") | |
| t4_out = gr.Textbox(label="Output") | |
| gr.Button("Resize").click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], gr.Textbox(label="Result")) | |
# Tab 5 "Amphinterpolative": mergekit spherical-interpolation methods (slerp, nuslerp,
# multislerp, karcher), dispatched to wrapper_amphinterpolative (defined outside this view).
| # --- TAB 5 --- | |
| with gr.Tab("Amphinterpolative"): | |
| gr.Markdown("### Spherical Interpolation Methods Family: slerp, nuslerp, multislerp, karcher") | |
| t5_token = gr.Textbox(label="HF Token", type="password") | |
| t5_method = gr.Dropdown(["slerp", "nuslerp", "multislerp", "karcher"], value="slerp", label="Merge Method") | |
| gr.Markdown("See [MergeKit Merge Method Docs](https://github.com/arcee-ai/mergekit/blob/main/docs/merge_methods.md) for more info.") | |
# Output options. They are shown first but passed last in the .click() inputs list.
| with gr.Row(): | |
| t5_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0) | |
| t5_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision") | |
| t5_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source") | |
| t5_chat = gr.Textbox(label="Chat Template (default: auto)", placeholder="auto") | |
| gr.Markdown("Built-in Chat Templates: alpaca, chatml, llama3, mistral, exaone, auto") | |
| with gr.Row(): | |
| t5_base = gr.Textbox(label="Base Model") | |
| t5_t = gr.Slider(0, 1, 0.5, label="t") | |
# Method-specific knobs; the labels say which method uses each one.
| with gr.Row(): | |
| t5_norm = gr.Checkbox(label="Normalize Weights", value=True); t5_i8 = gr.Checkbox(label="Int8 Mask", value=False); t5_flat = gr.Checkbox(label="Flatten Tensors (NuSlerp)", value=False); t5_row = gr.Checkbox(label="Row Wise (NuSlerp)", value=False) | |
# eps and tol are Textboxes, so they arrive as strings.
# NOTE(review): presumably wrapper_amphinterpolative parses them (and the weights) to float — verify.
| with gr.Row(): | |
| t5_eps = gr.Textbox(label="eps (Stabilization Constant) (MultiSlerp)", value="1e-8"); t5_iter = gr.Number(label="Max Iterations (Karcher)", value=10); t5_tol = gr.Textbox(label="tol (Convergence Tolerance) (Karcher)", value="1e-5") | |
| gr.Markdown("**MODELS**: **slerp:** 2 models exactly, one of the 2 also listed as *Base* | **nuslerp:** 2 models exactly; *Base*: optional | **multislerp:** 2+ models; *Base*: optional | **karcher:** 2+ models; *Base*: none") | |
# Up to five (model, weight) pairs; models 3-5 are in a collapsed "More" accordion.
| m1, w1 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0"); m2, w2 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0") | |
| with gr.Accordion("More", open=False): | |
| m3, w3 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0"); m4, w4 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0"); m5, w5 = gr.Textbox(label="Model 5"), gr.Textbox(label="Weight 5", value="1.0") | |
| t5_out = gr.Textbox(label="Output Repo"); t5_priv = gr.Checkbox(label="Private", value=True) | |
| t5_btn = gr.Button("Execute") | |
| t5_res = gr.Textbox(label="Result", lines=10) | |
# Inputs are positional; this order must match wrapper_amphinterpolative's signature.
| t5_btn.click(wrapper_amphinterpolative, [t5_token, t5_method, t5_base, t5_t, t5_norm, t5_i8, t5_flat, t5_row, t5_eps, t5_iter, t5_tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, t5_out, t5_priv, t5_shard, t5_prec, t5_tok, t5_chat], t5_res) | |
# Tab 6 "Stir/Tie Bases": mergekit task-vector methods (ties, dare_*, della*,
# breadcrumbs*, sce, task_arithmetic), dispatched to wrapper_stirtie (defined outside this view).
| # --- TAB 6 --- | |
| with gr.Tab("Stir/Tie Bases"): | |
| gr.Markdown("### Task Vector Methods Family: task_arithmetic, ties, dare_ties, dare_linear, della, della_linear, breadcrumbs, breadcrumbs_ties, sce") | |
| t6_token = gr.Textbox(label="Token", type="password") | |
| t6_method = gr.Dropdown(["task_arithmetic", "ties", "dare_ties", "dare_linear", "della", "della_linear", "breadcrumbs", "breadcrumbs_ties", "sce"], value="ties", label="Merge Method") | |
| gr.Markdown("See [MergeKit Merge Method Docs](https://github.com/arcee-ai/mergekit/blob/main/docs/merge_methods.md) for more info.") | |
| with gr.Row(): | |
| t6_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0); t6_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t6_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t6_chat = gr.Textbox(label="Chat Template", placeholder="auto") | |
| gr.Markdown("Built-in **Chat Templates**: alpaca, chatml, llama3, mistral, exaone, auto (default)") | |
| t6_base = gr.Textbox(label="Base Model (required)") | |
| gr.Markdown("**MODELS**: These methods all accept **2 or more models**, and require one of these designated as *Base*") | |
# Global method parameters. Rescale is laid out before Lambda, but the inputs list
# below passes t6_lamb before t6_resc (positional order follows the handler, not the layout).
| with gr.Row(): | |
| t6_norm = gr.Checkbox(label="Normalize Weights", value=True); t6_i8 = gr.Checkbox(label="Int8 Mask", value=False); t6_resc = gr.Checkbox(label="Rescale (Dare_Linear)", value=True); t6_lamb = gr.Number(label="Lambda", value=1.0); t6_topk = gr.Slider(0, 1, 1.0, label="Select TopK (SCE)") | |
# Per-model parameters: model, weight, density (Textbox, i.e. a string), gamma, epsilon.
# Models 3-4 are in a collapsed "More" accordion.
| m1_6, w1_6 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0"); d1_6, g1_6, e1_6 = gr.Textbox(label="Density (DARE/TIES)", value="1.0"), gr.Number(label="Gamma (breadcrumbs)", value=0.01), gr.Number(label="Epsilon (DELLA)", value=0.15) | |
| m2_6, w2_6 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0"); d2_6, g2_6, e2_6 = gr.Textbox(label="Density (DARE/TIES)", value="1.0"), gr.Number(label="Gamma (breadcrumbs)", value=0.01), gr.Number(label="Epsilon (DELLA)", value=0.15) | |
| with gr.Accordion("More", open=False): | |
| m3_6, w3_6 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0"); d3_6, g3_6, e3_6 = gr.Textbox(label="Density (DARE/TIES)", value="1.0"), gr.Number(label="Gamma (breadcrumbs)", value=0.01), gr.Number(label="Epsilon (DELLA)", value=0.15) | |
| m4_6, w4_6 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0"); d4_6, g4_6, e4_6 = gr.Textbox(label="Density (DARE/TIES)", value="1.0"), gr.Number(label="Gamma (breadcrumbs)", value=0.01), gr.Number(label="Epsilon (DELLA)", value=0.15) | |
| t6_out = gr.Textbox(label="Output Repo"); t6_priv = gr.Checkbox(label="Private", value=True) | |
| t6_btn = gr.Button("Execute") | |
| t6_res = gr.Textbox(label="Result", lines=10) | |
# Inputs are positional; this order must match wrapper_stirtie's signature.
| t6_btn.click(wrapper_stirtie, [t6_token, t6_method, t6_base, t6_norm, t6_i8, t6_lamb, t6_resc, t6_topk, m1_6, w1_6, d1_6, g1_6, e1_6, m2_6, w2_6, d2_6, g2_6, e2_6, m3_6, w3_6, d3_6, g3_6, e3_6, m4_6, w4_6, d4_6, g4_6, e4_6, t6_out, t6_priv, t6_shard, t6_prec, t6_tok, t6_chat], t6_res) | |
| # --- TAB 7 --- | |
| with gr.Tab("Specious"): | |
| gr.Markdown("### Specialized Methods: model_stock, nearswap, arcee_fusion, passthrough") | |
| t7_token = gr.Textbox(label="Token", type="password") | |
| t7_method = gr.Dropdown(["model_stock", "nearswap", "arcee_fusion", "passthrough", "linear"], value="model_stock", label="Merge Method") | |
| gr.Markdown("See [MergeKit Merge Method Docs](https://github.com/arcee-ai/mergekit/blob/main/docs/merge_methods.md) for more info.") | |
| with gr.Row(): | |
| t7_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0); t7_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t7_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t7_chat = gr.Textbox(label="Chat Template", placeholder="auto") | |
| gr.Markdown("Built-in **Chat Templates**: alpaca, chatml, llama3, mistral, exaone, auto (default)") | |
| t7_base = gr.Textbox(label="Base Model (required for nearswap/arcee_fusion/model_stock)", placeholder="org/base-model") | |
| gr.Markdown("**MODELS**: **passthrough:** 1 model acc. to Docs, but [Examples](https://github.com/arcee-ai/mergekit/tree/main/examples) shows 2+ | **nearswap/arcee_fusion:** 2 models, one also listed as *Base* | **model_stock:** 3+ models, one also listed as *Base*") | |
| with gr.Row(): | |
| t7_norm = gr.Checkbox(label="Normalize", value=True); t7_i8 = gr.Checkbox(label="Int8 Mask", value=False); t7_t = gr.Slider(0, 1, 0.5, label="t (Interpolation Ratio, for Nearswap)"); t7_filt_w = gr.Checkbox(label="Filter Wise (for Model_Stock)", value=False) | |
| m1_7, w1_7, f1_7 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0"), gr.Textbox(label="Filter Model Component") | |
| m2_7, w2_7 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0") | |
| with gr.Accordion("More", open=False): | |
| m3_7, w3_7 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0"); m4_7, w4_7 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0"); m5_7, w5_7 = gr.Textbox(label="Model 5"), gr.Textbox(label="Weight 5", value="1.0") | |
| t7_out = gr.Textbox(label="Output Repo"); t7_priv = gr.Checkbox(label="Private", value=True) | |
| t7_btn = gr.Button("Execute") | |
| t7_res = gr.Textbox(label="Result", lines=10) | |
| t7_btn.click(wrapper_specious, [t7_token, t7_method, t7_base, t7_norm, t7_i8, t7_t, t7_filt_w, m1_7, w1_7, f1_7, m2_7, w2_7, m3_7, w3_7, m4_7, w4_7, m5_7, w5_7, t7_out, t7_priv, t7_shard, t7_prec, t7_tok, t7_chat], t7_res) | |
# Tab 8 "MoEr": builds a mergekit Mixture-of-Experts from a base model plus 2-5 expert
# models, via wrapper_moer (defined outside this view).
| # --- TAB 8 (MoEr) --- | |
| with gr.Tab("MoEr"): | |
| gr.Markdown("### Mixture of Experts: fuses self-attention & normalization layers from *Base* w/MLP layers from *Experts*") | |
| gr.Markdown("See [MergeKit MoE doc](https://github.com/arcee-ai/mergekit/blob/main/docs/moe.md) for more info.") | |
| t8_token = gr.Textbox(label="Token", type="password") | |
| with gr.Row(): | |
| t8_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0); t8_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t8_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t8_chat = gr.Textbox(label="Chat Template", placeholder="auto") | |
# t8_dtype ("Internal Dtype") is a separate setting from t8_prec ("Output Precision").
| t8_base = gr.Textbox(label="Base Model (Required)"); t8_gate = gr.Dropdown(["cheap_embed", "random", "hidden"], value="cheap_embed", label="Gate Mode"); t8_dtype = gr.Dropdown(["float16", "bfloat16"], value="bfloat16", label="Internal Dtype") | |
| gr.Markdown("#### Experts (at least 2 required)") | |
| gr.Markdown("Prompts are comma-separated descriptors for each expert") | |
# Each row pairs an expert repo with its comma-separated positive prompts.
# Experts 1-2 are required; 3-5 are labelled optional.
| with gr.Row(): | |
| t8_expert1 = gr.Textbox(label="Expert 1", placeholder="org/expert1") | |
| t8_prompt1 = gr.Textbox(label="Positive Prompts", placeholder="math, reasoning, logic") | |
| with gr.Row(): | |
| t8_expert2 = gr.Textbox(label="Expert 2", placeholder="org/expert2") | |
| t8_prompt2 = gr.Textbox(label="Positive Prompts", placeholder="creative, writing, storytelling") | |
| with gr.Row(): | |
| t8_expert3 = gr.Textbox(label="Expert 3 (optional)", placeholder="org/expert3") | |
| t8_prompt3 = gr.Textbox(label="Positive Prompts", placeholder="code, programming") | |
| with gr.Row(): | |
| t8_expert4 = gr.Textbox(label="Expert 4 (optional)", placeholder="org/expert4") | |
| t8_prompt4 = gr.Textbox(label="Positive Prompts", placeholder="") | |
| with gr.Row(): | |
| t8_expert5 = gr.Textbox(label="Expert 5 (optional)", placeholder="org/expert5") | |
| t8_prompt5 = gr.Textbox(label="Positive Prompts", placeholder="") | |
| t8_out = gr.Textbox(label="Output Repo"); t8_priv = gr.Checkbox(label="Private", value=True) | |
| t8_btn = gr.Button("Build MoE") | |
| t8_res = gr.Textbox(label="Result", lines=10) | |
# Inputs are positional; this order must match wrapper_moer's signature.
| t8_btn.click(wrapper_moer, [t8_token, t8_base, t8_expert1, t8_prompt1, t8_expert2, t8_prompt2, t8_expert3, t8_prompt3, t8_expert4, t8_prompt4, t8_expert5, t8_prompt5, t8_gate, t8_dtype, t8_out, t8_priv, t8_shard, t8_prec, t8_tok, t8_chat], t8_res) | |
# Tab 9 "Rawer": linear/passthrough mergekit run for raw PyTorch models,
# dispatched to wrapper_rawer (defined outside this view).
| # --- TAB 9 (Rawer) --- | |
| with gr.Tab("Rawer"): | |
| gr.Markdown("### Raw PyTorch MergeKit / Non-pipeline-classed Models") | |
# Models are entered one per line; splitting presumably happens in wrapper_rawer (verify).
| t9_token = gr.Textbox(label="Token", type="password"); t9_models = gr.TextArea(label="Models (one per line)") | |
| with gr.Row(): | |
| t9_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0); t9_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t9_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t9_chat = gr.Textbox(label="Chat Template", placeholder="auto") | |
| gr.Markdown("Built-in Chat Templates: alpaca, chatml, llama3, mistral, exaone, auto") | |
| gr.Markdown("See [MergeKit Merge Method Docs](https://github.com/arcee-ai/mergekit/blob/main/docs/merge_methods.md) for more info.") | |
# "Config dtype" (default float32) is a separate setting from "Output Precision" above.
| t9_method = gr.Dropdown(["linear", "passthrough"], value="linear", label="Merge Method"); t9_dtype = gr.Dropdown(["float32", "float16", "bfloat16"], value="float32", label="Config dtype") | |
| t9_out = gr.Textbox(label="Output Repo"); t9_priv = gr.Checkbox(label="Private", value=True) | |
| t9_btn = gr.Button("Merge Raw") | |
| t9_res = gr.Textbox(label="Result", lines=10) | |
# Inputs are positional; this order must match wrapper_rawer's signature.
| t9_btn.click(wrapper_rawer, [t9_token, t9_models, t9_method, t9_dtype, t9_out, t9_priv, t9_shard, t9_prec, t9_tok, t9_chat], t9_res) | |
# Tab 10 "Mario,DARE!": DARE (Drop And REscale) of a fine-tune's delta against its base,
# run by task_dare_custom (defined outside this view). Its return value is shown in a
# "Result" textbox created inline in .click().
| # --- TAB 10 --- | |
| with gr.Tab("Mario,DARE!"): | |
| gr.Markdown("### Model-Agnostic DARE Implementation (Drop And REscale)") | |
| gr.Markdown("From [sft-merger by Martyn Garcia](https://github.com/martyn)") | |
| t10_token = gr.Textbox(label="Token", type="password") | |
| gr.Markdown( | |
| """ | |
| ### How DARE Works: | |
| 1. **Compute Delta**: Difference between fine-tuned and base weights | |
| 2. **Drop Elements**: Randomly mask out delta values based on mask rate | |
| 3. **Rescale**: Compensate for dropped elements by rescaling remaining values | |
| 4. **Apply**: Add scaled delta back to base model | |
| **Mask Rate**: 0.5 = drop 50% of delta values, 0.9 = drop 90% (more aggressive sparsification) | |
| """ | |
| ) | |
| with gr.Row(): | |
| t10_base = gr.Textbox(label="Base Model", placeholder="org/base-model"); t10_ft = gr.Textbox(label="Fine-Tuned Model", placeholder="org/fine-tuned-model") | |
# Slider bounds: merge ratio in [0, 2] (default 1.0), mask rate in [0, 0.99] (default 0.5).
# The mask-rate maximum stays below 1.0, so some delta is always kept.
| with gr.Row(): | |
| t10_ratio = gr.Slider(value=1.0, minimum=0.0, maximum=2.0, step=0.1, label="Merge Ratio (delta weight)"); t10_mask = gr.Slider(value=0.5, minimum=0.0, maximum=0.99, step=0.01, label="Mask Rate (drop probability)") | |
| t10_out = gr.Textbox(label="Output Repo"); t10_priv = gr.Checkbox(label="Private", value=True) | |
# Inputs are positional: (token, base, fine-tuned, ratio, mask rate, output repo, private).
| gr.Button("Run").click(task_dare_custom, [t10_token, t10_base, t10_ft, t10_ratio, t10_mask, t10_out, t10_priv], gr.Textbox(label="Result")) | |
# Script entry point: enable the request queue on the Blocks app and start the server.
# `demo` and `css` are presumably defined earlier in the file (not visible here).
# NOTE(review): `css` is passed to launch(). Gradio 6 accepts it there; on Gradio 4/5
# `css` belongs on gr.Blocks(...) and launch() would reject it. Confirm the pinned version.
# ssr_mode=False disables server-side rendering; mcp_server=True also serves the app as
# an MCP server (Gradio 5.28+).
| if __name__ == "__main__": | |
| demo.queue().launch(css=css, ssr_mode=False, mcp_server=True) |