AlekseyCalvin committed (verified)
Commit be26fd7 · 1 Parent(s): 086cad3

Create app.py

Files changed (1):
  app.py (+1157, -0)

app.py ADDED
@@ -0,0 +1,1157 @@
import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
import yaml
import subprocess
import sys
import tempfile
import re
from pathlib import Path
from typing import Dict, Any, Optional, List, Iterable
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# --- Import MergeKit Config ---
try:
    from mergekit.config import MergeConfiguration
except ImportError:
    # Fallback if installation fails temporarily
    class MergeConfiguration:
        @staticmethod
        def model_validate(config): pass

# --- Constants & Setup ---
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

api = HfApi()

def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()

# =================================================================================
# SHARED HELPERS
# =================================================================================

def parse_hf_url(url):
    if "huggingface.co" in url and "resolve" in url:
        try:
            parts = url.split("huggingface.co/")[-1].split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[4:]).split("?")[0]
            return repo_id, filename
        except:
            return None, None
    return None, None

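# Worked example (illustrative URL, not one used by the app): the resolve-style link
#   https://huggingface.co/org/repo/resolve/main/weights/lora.safetensors
# splits into ["org", "repo", "resolve", "main", "weights", "lora.safetensors"],
# so parse_hf_url returns ("org/repo", "weights/lora.safetensors").
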
def download_lora_smart(input_str, token):
    # Give each input its own destination file: several tasks below (Extract,
    # Fuse Adapters) download multiple files in one run, and a single shared
    # "adapter.safetensors" path would let the second download overwrite the first.
    local_path = TempDir / f"adapter_{abs(hash(input_str))}.safetensors"
    if local_path.exists(): os.remove(local_path)

    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        try:
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        except: pass
    try:
        if ".safetensors" in input_str and input_str.count("/") >= 2:
            parts = input_str.split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[2:])
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        candidates = ["adapter_model.safetensors", "model.safetensors"]
        files = list_repo_files(repo_id=input_str, token=token)
        target = next((f for f in files if f in candidates), None)
        if not target:
            safes = [f for f in files if f.endswith(".safetensors")]
            if safes: target = safes[0]
        if not target: raise ValueError("No safetensors found")
        hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
        found = list(TempDir.rglob(target.split("/")[-1]))[0]
        if found != local_path: shutil.move(found, local_path)
        return local_path
    except Exception as e:
        if input_str.startswith("http"):
            try:
                headers = {"Authorization": f"Bearer {token}"} if token else {}
                r = requests.get(input_str, stream=True, headers=headers, timeout=60)
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
                return local_path
            except: pass
        raise e

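# Resolution order, as implemented above: full resolve URL -> "org/repo/path.safetensors"
# style string -> repo scan for common adapter filenames -> raw HTTP GET as a last resort.
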
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            if stem not in pairs: pairs[stem] = {}
            if "lora_down" in k or "lora_A" in k:
                pairs[stem]["down"] = v.to(dtype=precision_dtype)
                pairs[stem]["rank"] = v.shape[0]
            elif "lora_up" in k or "lora_B" in k:
                pairs[stem]["up"] = v.to(dtype=precision_dtype)
    for stem in pairs:
        pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
    return pairs

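# The merge code below reconstructs each full-rank delta from these pairs as
#     delta_W = (alpha / rank) * (up @ down)
# so, for example, a rank-32 adapter saved with alpha=16 contributes its factor
# product at an effective scale of 0.5 before the user's Scale slider is applied.
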
def get_key_stem(key):
    key = key.replace(".weight", "").replace(".bias", "")
    key = key.replace(".lora_down", "").replace(".lora_up", "")
    key = key.replace(".lora_A", "").replace(".lora_B", "")
    key = key.replace(".alpha", "")
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    ]
    changed = True
    while changed:
        changed = False
        for p in prefixes:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key

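# Example (hypothetical key): "base_model.model.transformer.blocks.0.attn.to_q.lora_A.weight"
# first loses the ".weight" and ".lora_A" suffixes, then the "base_model.model."
# and "transformer." prefixes are peeled off in turn, leaving "blocks.0.attn.to_q" --
# which is how base-model keys and adapter keys land on the same stem.
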
# =================================================================================
# TABS 1-4 LOGIC
# =================================================================================

class MemoryEfficientSafeOpen:
    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()
    def __enter__(self): return self
    def __exit__(self, exc_type, exc_val, exc_tb): self.file.close()
    def keys(self) -> list[str]: return [k for k in self.header.keys() if k != "__metadata__"]
    def metadata(self) -> Dict[str, str]: return self.header.get("__metadata__", {})
    def get_tensor(self, key):
        if key not in self.header: raise KeyError(f"Tensor '{key}' not found")
        metadata = self.header[key]
        start, end = metadata["data_offsets"]
        self.file.seek(self.header_size + 8 + start)
        return self._deserialize_tensor(self.file.read(end - start), metadata)
    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        return json.loads(self.file.read(header_size).decode("utf-8")), header_size
    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype_map = {"F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool}
        dtype = dtype_map[metadata["dtype"]]
        shape = metadata["shape"]
        return torch.frombuffer(tensor_bytes, dtype=torch.uint8).view(dtype).reshape(shape)

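# For reference: a safetensors file begins with an 8-byte little-endian header
# length, followed by a JSON header mapping tensor names to dtype/shape/offsets,
# e.g. (illustrative values):
#   {"__metadata__": {"format": "pt"},
#    "linear.weight": {"dtype": "F16", "shape": [4096, 4096],
#                      "data_offsets": [0, 33554432]}}
# get_tensor() therefore seeks to header_size + 8 + start and reads end - start
# bytes, loading one tensor at a time instead of the whole shard.
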
class ShardBuffer:
    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir, self.output_repo, self.subfolder, self.hf_token, self.filename_prefix = output_dir, output_repo, subfolder, hf_token, filename_prefix
        self.buffer, self.current_bytes, self.shard_count, self.index_map, self.total_size = [], 0, 0, {}, 0
    def add_tensor(self, key, tensor):
        if tensor.dtype == torch.bfloat16: raw, dt = tensor.view(torch.int16).numpy().tobytes(), "BF16"
        elif tensor.dtype == torch.float16: raw, dt = tensor.numpy().tobytes(), "F16"
        else: raw, dt = tensor.numpy().tobytes(), "F32"
        self.buffer.append({"key": key, "data": raw, "dtype": dt, "shape": tensor.shape})
        self.current_bytes += len(raw)
        self.total_size += len(raw)
        if self.current_bytes >= self.max_bytes: self.flush()
    def flush(self):
        if not self.buffer: return
        self.shard_count += 1
        fname = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        header = {"__metadata__": {"format": "pt"}}
        curr_off = 0
        for i in self.buffer:
            header[i["key"]] = {"dtype": i["dtype"], "shape": i["shape"], "data_offsets": [curr_off, curr_off + len(i["data"])]}
            curr_off += len(i["data"])
            self.index_map[i["key"]] = fname
        out = self.output_dir / fname
        header_json = json.dumps(header).encode('utf-8')
        with open(out, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for i in self.buffer: f.write(i["data"])
        api.upload_file(path_or_fileobj=out, path_in_repo=f"{self.subfolder}/{fname}" if self.subfolder else fname, repo_id=self.output_repo, token=self.hf_token)
        os.remove(out)
        self.buffer, self.current_bytes = [], 0
        gc.collect()

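# The index JSON uploaded after the final flush (see task_merge_legacy below)
# mirrors the Hub's sharded-checkpoint convention, e.g. (illustrative values):
#   {"metadata": {"total_size": 9876543210},
#    "weight_map": {"layer.0.weight": "model-00001.safetensors", ...}}
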
def task_merge_legacy(hf_token, base, sub, lora, scale, prec, shard, out, struct_s, priv, progress=gr.Progress()):
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    try: api.create_repo(repo_id=out, private=priv, exist_ok=True, token=hf_token)
    except Exception as e: return f"Error: {e}"
    if struct_s:
        try:
            files = api.list_repo_files(repo_id=struct_s, token=hf_token)
            for f in tqdm(files, desc="Copying Structure"):
                if sub and f.startswith(sub): continue
                if not sub and any(f.endswith(x) for x in ['.safetensors', '.bin', '.pt', '.pth']): continue
                l = hf_hub_download(repo_id=struct_s, filename=f, token=hf_token, local_dir=TempDir)
                api.upload_file(path_or_fileobj=l, path_in_repo=f, repo_id=out, token=hf_token)
        except: pass

    files = [f for f in list_repo_files(repo_id=base, token=hf_token) if f.endswith(".safetensors")]
    if sub: files = [f for f in files if f.startswith(sub)]
    if not files: return "No safetensors found"

    prefix = "diffusion_pytorch_model" if (sub in ["transformer", "unet"] or "diffusion_pytorch_model" in os.path.basename(files[0])) else "model"
    dtype = torch.bfloat16 if prec == "bf16" else torch.float16 if prec == "fp16" else torch.float32
    try: lora_pairs = load_lora_to_memory(download_lora_smart(lora, hf_token), dtype)
    except Exception as e: return f"LoRA Error: {e}"

    buf = ShardBuffer(shard, TempDir, out, sub, hf_token, prefix)
    for i, fpath in enumerate(files):
        local = hf_hub_download(repo_id=base, filename=fpath, token=hf_token, local_dir=TempDir)
        with MemoryEfficientSafeOpen(local) as f:
            for k in f.keys():
                v = f.get_tensor(k)
                stem = get_key_stem(k)
                match = lora_pairs.get(stem) or lora_pairs.get(stem.replace("to_q", "qkv")) or lora_pairs.get(stem.replace("to_k", "qkv")) or lora_pairs.get(stem.replace("to_v", "qkv"))
                if match:
                    d, u = match["down"], match["up"]
                    s = scale * (match["alpha"] / match["rank"])
                    if len(v.shape)==4 and len(d.shape)==2: d, u = d.unsqueeze(-1).unsqueeze(-1), u.unsqueeze(-1).unsqueeze(-1)
                    delta = (u.squeeze() @ d.squeeze()).reshape(u.shape[0], d.shape[1], 1, 1) if len(u.shape)==4 else u @ d
                    v = v.to(dtype).add_((delta * s).to(dtype))
                buf.add_tensor(k, v.to(dtype))
        os.remove(local)
    buf.flush()
    idx = {"metadata": {"total_size": buf.total_size}, "weight_map": buf.index_map}
    idx_n = f"{prefix}.safetensors.index.json"
    with open(TempDir/idx_n, "w") as f: json.dump(idx, f, indent=4)
    api.upload_file(path_or_fileobj=TempDir/idx_n, path_in_repo=f"{sub}/{idx_n}" if sub else idx_n, repo_id=out, token=hf_token)
    return "Done"

def task_extract(hf_token, org, tun, rank, out):
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    try:
        p1 = download_lora_smart(org, hf_token)
        p2 = download_lora_smart(tun, hf_token)
        org_f, tun_f = MemoryEfficientSafeOpen(p1), MemoryEfficientSafeOpen(p2)
        lora_sd = {}
        common = set(org_f.keys()) & set(tun_f.keys())
        for k in tqdm(common, desc="Extracting"):
            if "num_batches_tracked" in k or "running_mean" in k or "running_var" in k: continue
            m1, m2 = org_f.get_tensor(k).float(), tun_f.get_tensor(k).float()
            if m1.shape != m2.shape: continue
            diff = m2 - m1
            if torch.max(torch.abs(diff)) < 1e-4: continue
            out_d, in_d = diff.shape[0], diff.shape[1] if len(diff.shape) > 1 else 1
            r = min(int(rank), in_d, out_d)
            if len(diff.shape)==4: diff = diff.flatten(1)
            elif len(diff.shape)==1: diff = diff.unsqueeze(1)
            U, S, V = torch.svd_lowrank(diff, q=r+4, niter=4)
            Vh = V.t()
            U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
            U = U @ torch.diag(S)
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(torch.abs(dist), 0.99)
            if hi_val > 0: U, Vh = U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)
            if len(m1.shape)==4:
                U = U.reshape(out_d, r, 1, 1)
                Vh = Vh.reshape(r, in_d, m1.shape[2], m1.shape[3])
            else:
                U, Vh = U.reshape(out_d, r), Vh.reshape(r, in_d)
            stem = k.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
            lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        out_f = TempDir/"extracted.safetensors"
        save_file(lora_sd, out_f)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out_f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
        return "Done"
    except Exception as e: return f"Error: {e}"

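# Sketch of the extraction math above, for a 2D weight: the difference
#     diff = W_tuned - W_base
# is factored with torch.svd_lowrank (q = rank + 4 oversampling), truncated to
# the requested rank r, and folded as
#     lora_up = U[:, :r] @ diag(S[:r]),   lora_down = V.t()[:r, :]
# so that lora_up @ lora_down approximates diff at rank r; the 99th-percentile
# clamp then trims outlier factor values before saving.
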
def load_full_state_dict(path):
    raw = load_file(path, device="cpu")
    cleaned = {}
    for k, v in raw.items():
        if "lora_A" in k: new_k = k.replace("lora_A", "lora_down")
        elif "lora_B" in k: new_k = k.replace("lora_B", "lora_up")
        else: new_k = k
        cleaned[new_k] = v.float()
    return cleaned

def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    urls = [line.strip() for line in inputs_text.replace(" ", "\n").split('\n') if line.strip()]
    if len(urls) < 2: return "Error: Provide at least 2 adapters."
    try: weights = [float(w.strip()) for w in weight_str.split(',')] if weight_str.strip() else [1.0] * len(urls)
    except: return "Error parsing weights."
    if len(weights) < len(urls): weights += [1.0] * (len(urls) - len(weights))

    paths = []
    for url in tqdm(urls, desc="Downloading"): paths.append(download_lora_smart(url, hf_token))

    merged = {}
    if "Iterative EMA" in method:
        base_sd = load_file(paths[0], device="cpu")
        gamma = None
        if sigma_rel > 0:
            t_val = sigma_rel**-2
            roots = np.roots([1, 7, 16 - t_val, 12 - t_val])
            gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
        for i, path in enumerate(paths[1:]):
            current_beta = (1 - 1 / (i + 1)) ** (gamma + 1) if gamma is not None else beta
            curr = load_file(path, device="cpu")
            for k in base_sd:
                if k in curr and "alpha" not in k:
                    base_sd[k] = base_sd[k].float() * current_beta + curr[k].float() * (1 - current_beta)
        merged = base_sd
    else:
        states = [load_full_state_dict(p) for p in paths]
        all_stems = set()
        for s in states:
            for k in s:
                if "lora_" in k: all_stems.add(k.split(".lora_")[0])
        for stem in tqdm(all_stems):
            down_list, up_list = [], []
            alpha_sum, total_delta = 0.0, None
            for i, state in enumerate(states):
                w = weights[i]
                dk, uk, ak = f"{stem}.lora_down.weight", f"{stem}.lora_up.weight", f"{stem}.alpha"
                if dk in state and uk in state:
                    d, u = state[dk], state[uk]
                    alpha_sum += state[ak].item() if ak in state else d.shape[0]
                    if "Concatenation" in method:
                        down_list.append(d); up_list.append(u * w)
                    elif "SVD" in method:
                        rank = d.shape[0]
                        alpha = state[ak].item() if ak in state else rank
                        scale = (alpha / rank) * w
                        delta = ((u.flatten(1) @ d.flatten(1)).reshape(u.shape[0], d.shape[1], d.shape[2], d.shape[3]) if len(d.shape)==4 else u @ d) * scale
                        total_delta = delta if total_delta is None else total_delta + delta
            if "Concatenation" in method and down_list:
                merged[f"{stem}.lora_down.weight"] = torch.cat(down_list, dim=0).contiguous()
                merged[f"{stem}.lora_up.weight"] = torch.cat(up_list, dim=1).contiguous()
                merged[f"{stem}.alpha"] = torch.tensor(alpha_sum)
            elif "SVD" in method and total_delta is not None:
                tr = int(target_rank)
                flat = total_delta.flatten(1) if len(total_delta.shape)==4 else total_delta
                try:
                    U, S, V = torch.svd_lowrank(flat, q=tr + 4, niter=4)
                    Vh = V.t()
                    U, S, Vh = U[:, :tr], S[:tr], Vh[:tr, :]
                    U = U @ torch.diag(S)
                    if len(total_delta.shape) == 4:
                        U = U.reshape(total_delta.shape[0], tr, 1, 1)
                        Vh = Vh.reshape(tr, total_delta.shape[1], total_delta.shape[2], total_delta.shape[3])
                    else:
                        U, Vh = U.reshape(total_delta.shape[0], tr), Vh.reshape(tr, total_delta.shape[1])
                    merged[f"{stem}.lora_down.weight"] = Vh.contiguous()
                    merged[f"{stem}.lora_up.weight"] = U.contiguous()
                    merged[f"{stem}.alpha"] = torch.tensor(tr).float()
                except: pass

    out = TempDir / "merged_adapters.safetensors"
    if merged: save_file(merged, out)
    api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return f"Success! Merged to {out_repo}"

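# Iterative EMA above folds adapters in one at a time:
#     merged = merged * beta + next * (1 - beta)
# When sigma_rel > 0, beta is instead derived per step from gamma, the largest
# real non-negative root of the cubic
#     x^3 + 7x^2 + (16 - sigma_rel**-2) x + (12 - sigma_rel**-2) = 0
# via current_beta = (1 - 1/(i+1)) ** (gamma + 1), i.e. a power-function EMA
# profile parameterized by relative width rather than a fixed decay.
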
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    path = download_lora_smart(lora_input, hf_token)
    state = load_file(path, device="cpu")
    new_state = {}
    groups = {}
    for k in state:
        simple = k.split(".lora_")[0]
        if simple not in groups: groups[simple] = {}
        if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
        if "alpha" in k: groups[simple]["alpha"] = state[k]

    target_rank_limit = int(new_rank)
    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()
            merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3]) if len(down.shape)==4 else up @ down
            flat = merged.flatten(1)
            U, S, V = torch.svd_lowrank(flat, q=target_rank_limit + 32)
            Vh = V.t()
            calc_rank = target_rank_limit
            if dynamic_method == "sv_ratio":
                calc_rank = int(torch.sum(S > (S[0] / dynamic_param)).item())
            elif dynamic_method == "sv_cumulative":
                calc_rank = int(torch.searchsorted(torch.cumsum(S, 0) / torch.sum(S), dynamic_param)) + 1
            elif dynamic_method == "sv_fro":
                calc_rank = int(torch.searchsorted(torch.cumsum(S.pow(2), 0) / torch.sum(S.pow(2)), dynamic_param**2)) + 1
            final_rank = max(1, min(calc_rank, target_rank_limit, S.shape[0]))
            U = U[:, :final_rank] @ torch.diag(S[:final_rank])
            Vh = Vh[:final_rank, :]
            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], final_rank, 1, 1)
                Vh = Vh.reshape(final_rank, down.shape[1], down.shape[2], down.shape[3])
            new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
            new_state[f"{stem}.lora_up.weight"] = U.contiguous()
            new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float()
    out = TempDir / "shrunken.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"

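# Worked example of the dynamic rank rules above, assuming singular values
# S = [10, 5, 1, 0.1]:
#   sv_ratio, param 4      -> keep S > 10/4 = 2.5            -> rank 2
#   sv_cumulative, param 0.9 -> first k with cumsum(S)/sum(S) >= 0.9   -> rank 2
#   sv_fro, param 0.9      -> first k with cumsum(S^2)/sum(S^2) >= 0.81 -> rank 2
# The final rank is then clamped to [1, min(To Rank, len(S))].
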
# =================================================================================
# MERGEKIT & STREAMING LOGS
# =================================================================================

def parse_weight(w_str):
    if not w_str.strip(): return 1.0
    try:
        if "[" in w_str: return yaml.safe_load(w_str)
        return float(w_str)
    except: return 1.0

def run_mergekit_logic(config_dict, token, out_repo, private, shard_size, output_precision, tokenizer_source, chat_template, program="mergekit-yaml"):
    logs = []
    def log(msg):
        logs.append(msg)
        return "\n".join(logs)

    yield log("Starting MergeKit Process...")
    cleanup_temp()

    if chat_template and chat_template.strip():
        config_dict["chat_template"] = chat_template.strip()

    try:
        if program != "mergekit-moe":
            MergeConfiguration.model_validate(config_dict)
            yield log("Configuration Validated Successfully.")
    except Exception as e:
        yield log(f"Invalid Config: {e}")
        return

    if token:
        login(token.strip())
        os.environ["HF_TOKEN"] = token.strip()

    if "dtype" not in config_dict: config_dict["dtype"] = output_precision
    if "tokenizer_source" not in config_dict and tokenizer_source != "base":
        config_dict["tokenizer_source"] = tokenizer_source

    config_path = TempDir / "config.yaml"
    with open(config_path, "w") as f: yaml.dump(config_dict, f, sort_keys=False)

    yield log(f"Config saved to {config_path}")
    yield log(f"YAML:\n{yaml.dump(config_dict, sort_keys=False)}")

    try:
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=token)
        yield log(f"Repo {out_repo} ready.")
    except Exception as e:
        yield log(f"Repo Creation Error (might exist): {e}")

    out_path = TempDir / "merge_output"
    shard_arg = f"{int(float(shard_size) * 1024)}M"

    cmd = [
        program,
        str(config_path),
        str(out_path),
        "--allow-crimes",
        "--copy-tokenizer",
        "--out-shard-size", shard_arg,
        "--lazy-unpickle"
    ]

    if torch.cuda.is_available():
        cmd.extend(["--cuda", "--low-cpu-memory"])

    yield log(f"Executing: {' '.join(cmd)}")
    env = os.environ.copy()
    env["HF_HOME"] = str(TempDir / ".cache")
    process = subprocess.Popen(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

    for line in iter(process.stdout.readline, ""):
        yield log(line.strip())

    process.wait()
    if process.returncode != 0:
        yield log("Merge failed with exit code " + str(process.returncode))
        return

    yield log(f"Uploading to {out_repo}...")
    try:
        api.upload_folder(repo_id=out_repo, folder_path=out_path)
        yield log("Upload Complete!")
    except Exception as e:
        yield log(f"Upload failed: {e}")

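# For orientation, the wrappers below hand mergekit-yaml a config shaped like
# this hypothetical slerp example (model names invented for illustration):
#   merge_method: slerp
#   base_model: org/model-a
#   slices:
#     - sources:
#         - model: org/model-a
#           parameters: {weight: 1.0}
#         - model: org/model-b
#           parameters: {weight: 1.0}
#       parameters: {t: 0.5, normalize: true, int8_mask: false}
#   dtype: bfloat16
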
# --- UI Wrappers ---

def wrapper_amphinterpolative(token, method, base, t, norm, i8, flat, row, eps, m_iter, tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
    params = {"normalize": norm, "int8_mask": i8}
    if method in ["slerp", "nuslerp"]: params["t"] = float(t)
    if method == "nuslerp": params.update({"flatten": flat, "row_wise": row})
    if method == "multislerp": params["eps"] = float(eps)
    if method == "karcher": params.update({"max_iter": int(m_iter), "tol": float(tol)})

    config = {"merge_method": method}

    if method in ["slerp", "nuslerp"]:
        if not base.strip(): yield "Error: Base model required"; return
        config["base_model"] = base.strip()
        sources = [{"model": m, "parameters": {"weight": parse_weight(w)}} for m, w in [(m1,w1), (m2,w2)] if m.strip()]
        config["slices"] = [{"sources": sources, "parameters": params}]
    else:
        if base.strip() and method == "multislerp": config["base_model"] = base.strip()
        models = [{"model": m, "parameters": {"weight": parse_weight(w)}} for m, w in [(m1,w1), (m2,w2), (m3,w3), (m4,w4), (m5,w5)] if m.strip()]
        config["models"] = models
        config["parameters"] = params

    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")

def wrapper_stirtie(token, method, base, norm, i8, lamb, resc, topk, m1, w1, d1, g1, e1, m2, w2, d2, g2, e2, m3, w3, d3, g3, e3, m4, w4, d4, g4, e4, out, priv, shard, prec, tok_src, chat_t):
    models = []
    for m, w, d, g, e in [
        (m1, w1, d1, g1, e1),
        (m2, w2, d2, g2, e2),
        (m3, w3, d3, g3, e3),
        (m4, w4, d4, g4, e4)
    ]:
        if not m.strip(): continue
        p = {"weight": parse_weight(w)}
        if method in ["ties", "dare_ties", "dare_linear", "breadcrumbs_ties"]: p["density"] = parse_weight(d)
        if "breadcrumbs" in method: p["gamma"] = float(g)
        if "della" in method: p["epsilon"] = float(e)
        models.append({"model": m, "parameters": p})

    g_params = {"normalize": norm, "int8_mask": i8}
    if method != "sce": g_params["lambda"] = float(lamb)
    if method == "dare_linear": g_params["rescale"] = resc
    if method == "sce": g_params["select_topk"] = float(topk)

    config = {
        "merge_method": method,
        "base_model": base.strip() if base.strip() else models[0]["model"],
        "parameters": g_params,
        "models": models
    }
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")

def wrapper_specious(token, method, base, norm, i8, t, filt_w, m1, w1, f1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
    models = []
    if method == "passthrough":
        p = {"weight": parse_weight(w1)}
        if f1.strip(): p["filter"] = f1.strip()
        models.append({"model": m1, "parameters": p})
    else:
        models = [{"model": m, "parameters": {"weight": parse_weight(w)}} for m, w in [(m1,w1), (m2,w2), (m3,w3), (m4,w4), (m5,w5)] if m.strip()]

    config = {"merge_method": method, "parameters": {"normalize": norm, "int8_mask": i8}}
    if base.strip(): config["base_model"] = base.strip()
    if method == "nearswap": config["parameters"]["t"] = float(t)
    if method == "model_stock": config["parameters"]["filter_wise"] = filt_w
    config["models"] = models

    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")

def wrapper_moer(token, base, expert1, prompt1, expert2, prompt2, expert3, prompt3, expert4, prompt4, expert5, prompt5, gate, dtype, out, priv, shard, prec, tok_src, chat_t):
    experts = []
    for exp, pmt in [
        (expert1, prompt1), (expert2, prompt2), (expert3, prompt3),
        (expert4, prompt4), (expert5, prompt5)
    ]:
        if exp.strip():
            expert_entry = {"source_model": exp.strip()}
            if pmt.strip():
                prompts = [p.strip() for p in pmt.split(',') if p.strip()]
                expert_entry["positive_prompts"] = prompts
            else:
                expert_entry["positive_prompts"] = [""]
            experts.append(expert_entry)

    if len(experts) < 2:
        # This wrapper is a generator, so the error must be yielded to reach the UI.
        yield "Error: At least 2 experts required"
        return

    config = {
        "base_model": base.strip(),
        "gate_mode": gate,
        "dtype": dtype,
        "experts": experts
    }
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-moe")

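# The resulting mergekit-moe config looks like this hypothetical example
# (expert names and prompts invented for illustration):
#   base_model: org/base
#   gate_mode: cheap_embed
#   dtype: bfloat16
#   experts:
#     - source_model: org/expert-math
#       positive_prompts: ["math", "reasoning"]
#     - source_model: org/expert-writing
#       positive_prompts: ["creative", "writing"]
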
def wrapper_rawer(token, models, method, dtype, out, priv, shard, prec, tok_src, chat_t):
    models_list = [{"model": m.strip(), "parameters": {"weight": 1.0}} for m in models.split('\n') if m.strip()]
    config = {
        "models": models_list,
        "merge_method": method,
        "dtype": dtype
    }
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")

# --- TAB 10 (Custom DARE) Logic ---
def task_dare_custom(token, base, ft, ratio, mask, out, priv):
    cleanup_temp()
    if token: login(token.strip())
    try:
        b_path = download_lora_smart(base, token)
        f_path = download_lora_smart(ft, token)
        b_sd = load_file(b_path, device="cpu")
        f_sd = load_file(f_path, device="cpu")
        merged = {}
        common = set(b_sd.keys()) & set(f_sd.keys())
        for k in tqdm(common, desc="Merging"):
            tb, tf = b_sd[k], f_sd[k]
            if tb.shape != tf.shape:
                merged[k] = tf
                continue
            delta = tf.float() - tb.float()
            if mask > 0:
                m = torch.bernoulli(torch.full_like(delta, 1.0 - mask))
                delta = (delta * m) / (1.0 - mask)
            merged[k] = (tb.float() + ratio * delta).to(tb.dtype)

        out_f = TempDir / "model.safetensors"
        save_file(merged, out_f)
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_file(path_or_fileobj=out_f, path_in_repo="model.safetensors", repo_id=out, token=token)
        return f"Done! {out}"
    except Exception as e: return str(e)

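# The drop-and-rescale step above keeps the expected delta unchanged: each
# element survives with probability (1 - mask) and survivors are divided by
# (1 - mask), so
#     E[(delta * m) / (1 - mask)] = delta
# e.g. mask = 0.9 zeroes roughly 90% of the delta and scales the rest by 10x
# before `ratio` blends it back into the base weights.
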
# =================================================================================
# UI GENERATION
# =================================================================================

IFRAME_HTML = """
<div style="width: 485px; height: 485px; overflow: hidden; background-color: #050805; border: 0.5px solid #1a2f1a; display: flex; align-items: center; justify-content: center;">
    <iframe
        src="https://alekseycalvin-soonr-screenseeder.hf.space"
        style="width: 485px; height: 485px; border: none;"
        frameborder="0"
        scrolling="no">
    </iframe>
</div>
"""

css = """
@font-face {
    font-family: "K.O. Activista* Bold";
    src: url("https://st.1001fonts.net/download/font/k-o-activista.bold.ttf") format('truetype');
}

@font-face {
    font-family: "Averia Libre Regular";
    src: url("https://st.1001fonts.net/download/font/averia-libre.regular.ttf") format('truetype');
}

/* --- General & Font Settings --- */
body, .gradio-container {
    font-family: "Averia Libre Regular" !important;
    background-color: #132015 !important;
    color: #e0eecd;
}

h1, h2, h3, h4, h5, h6 {
    font-family: "K.O. Activista* Bold", sans-serif !important;
}

.container {
    max-width: 1450px;
    margin: 0 auto;
    padding: 5px !important; /* Reduced padding */
    background-color: #050805;
}

/* --- Header Section --- */
.header-container {
    display: flex;
    align-items: flex-start; /* Align top */
    margin-bottom: 5px;
    padding: 5px;
    border-bottom: 1px solid #1a2f1a;
}

.header-image {
    max-width: 340px; /* Slightly bigger as requested */
    height: auto;
    border-radius: 8px;
    margin-right: 5px;
    flex-shrink: 0;
}

.header-text {
    flex-grow: 1;
    padding-top: 8px;
}

.header-text h1 {
    font-size: 2.25rem;
    line-height: 1.0; /* Tighter for stack */
    color: #FF312A;
    margin: 0;
    text-transform: uppercase;
    text-shadow: 2px 3px 1px #000;
}

.header-text h2 {
    margin-top: 3px;
    font-size: 1.3rem;
    color: #FFF450;
    font-weight: bold;
    font-variant: small-caps;
    text-transform: none;
    text-shadow: 2px 2px 1px #000;
}

/* --- Navigation (Two Rows) --- */
.tabs {
    background: transparent;
    border-bottom: 1px solid #0F1B0F;
    margin-bottom: 2px;
}

.tab-nav {
    display: flex;
    flex-wrap: wrap !important; /* Force wrapping */
    gap: 3px;
}

.tab-nav button {
    background: #0b140b;
    border: 1px solid #1a2f1a;
    color: #ABBEAB;
    font-size: 0.96rem;
    font-weight: 600;
    padding: 2px 4px;
    border-radius: 2px;
    flex: 1 0 6%;
    text-align: center;
    min-width: 120px; /* Ensure readability */
    transition: all 0.2s;
}

.tab-nav button:hover {
    background: #16261b;
    color: #fff;
    border-color: #ff9f66;
}

.tab-nav button.selected {
    color: #ff9f66;
    border: 1px solid #ff9f66;
    background: #070E07;
    box-shadow: 0 0 5px rgba(255, 159, 102, 0.2);
}

/* --- Layout & Components (Compact) --- */
.gradio-container {
    padding: 0 !important; /* Remove default Gradio padding */
}

.block {
    background-color: #1D3A1B;
    background: linear-gradient(to right, #10200E 0%, #6D1515 50%, #1F3C1F 100%);
    border: 1px solid #1a2f1a;
    border-radius: 4px;
    box-shadow: 1px 1px 2px 1px #1D3A1B;
    padding: 2px !important; /* Reduced padding */
    margin-bottom: 2px !important;
}

/* Reduce gaps in rows and columns */
.row, .column {
    gap: 2px !important;
}

/* Input Fields */
input[type="text"], input[type="password"], input[type="number"], textarea, select {
    background-color: #111B11 !important;
    border: 1px solid #2b4532 !important;
    color: #e0eecd !important;
    border-radius: 2px;
    padding: 3px 7px !important;
    font-size: 14px !important;
}

input:focus, textarea:focus, select:focus {
    border-color: #FF6262 !important;
    outline: none;
}

/* Buttons */
button.primary {
    background: linear-gradient(135deg, #FF4646 0%, #C33939 100%);
    border: none;
    color: #050805;
    font-weight: 800;
    font-family: "K.O. Activista* Bold", sans-serif;
    text-transform: uppercase;
    letter-spacing: 1px;
    padding: 8px 16px;
    border-radius: 4px;
    width: 100%;
}

button.primary:hover {
    filter: brightness(1.2);
    box-shadow: 0 0 10px rgba(255, 159, 102, 0.4);
}

/* Sliders */
input[type=range] {
    accent-color: #ff9f66;
}

/* Accordion / Labels */
.label-wrap {
    background-color: #193219;
    border: 1px solid #1a2f1a;
    padding: 4px 7px;
    border-radius: 4px;
}

/* Specific Screensaver Styles */
.screensaver-wrapper {
    background: #000;
    border: 1px solid #1a2f1a;
    border-radius: 2px;
    overflow: hidden;
    display: flex;
    justify-content: center;
    align-items: center;
}
"""

# The custom CSS must be passed to gr.Blocks (launch() has no css parameter).
with gr.Blocks(css=css) as demo:
    # --- Header & Top Screensaver (Tabs 1-4, 9, 10) ---
    with gr.Row(elem_classes="header-container"):
        with gr.Column(scale=2):
            gr.HTML(
                """
                <div style="display: flex; align-items: center;">
                    <img src="https://huggingface.co/spaces/AlekseyCalvin/Soon_Merger/resolve/main/SMerger3.png" alt="SOONmerge®" class="header-image">
                    <div class="header-text">
                        <h1>Transform Transformers for FREE</h1>
                        <h2>🧰 Training-Free Model Tools MergeKit gui: Tabs 5-9 + MORE</h2>
                    </div>
                </div>
                """
            )
        with gr.Column(scale=1, min_width=480):
            # Top Right Screensaver (Used by Tabs 1-4, 9, 10)
            top_screensaver = gr.HTML(IFRAME_HTML, visible=False, elem_classes="screensaver-wrapper")

    # Helper to hide all screensavers on tab switch
    def hide_all_screensavers():
        return [gr.update(visible=False)] * 5

    # --- SINGLE TAB CONTAINER ---
    with gr.Tabs() as main_tabs:

        # --- 1. Merge Legacy ---
        with gr.Tab("Merge 2 Base") as t1_tab:
            gr.Markdown("### Fuse adapter (LoRA, DoRA, etc...) + base model (LLM, t2i, t2v... any!)")
            with gr.Row(variant="compact"):
                t1_token = gr.Textbox(label="Token", type="password", scale=2)
                t1_prec = gr.Dropdown(["bf16", "fp16", "float32"], value="bf16", label="Precision", scale=1)
                t1_shard = gr.Slider(label="Max Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1, scale=2)
            with gr.Row(variant="compact"):
                t1_base = gr.Textbox(label="Base Repo", value="name/repo", scale=3)
                t1_sub = gr.Textbox(label="Subfolder (Optional)", value="", scale=2)
            with gr.Row(variant="compact"):
                t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors", scale=3)
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1, scale=2)
            with gr.Row(variant="compact"):
                t1_out = gr.Textbox(label="Output Repo", scale=3)
                t1_struct = gr.Textbox(label="Extras Source", value="name/repo", scale=2)
                t1_priv = gr.Checkbox(label="Private", value=True, scale=1)
            t1_btn = gr.Button("Merge", variant="primary")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(lambda: gr.update(visible=True), outputs=top_screensaver).then(task_merge_legacy, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)

        # --- 2. Extract Adapter ---
        with gr.Tab("Extract LoRA") as t2_tab:
            with gr.Row(variant="compact"):
                t2_token = gr.Textbox(label="Token", type="password", scale=2)
                t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1, scale=1)
            with gr.Row(variant="compact"):
                t2_org = gr.Textbox(label="Original Model", scale=1)
                t2_tun = gr.Textbox(label="Tuned or Homologous Model", scale=1)
            with gr.Row(variant="compact"):
                t2_out = gr.Textbox(label="Output Repo", scale=3)
                t2_btn = gr.Button("Extract", variant="primary", scale=1)
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(lambda: gr.update(visible=True), outputs=top_screensaver).then(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)

        # --- 3. Merge Adapters ---
        with gr.Tab("Fuse Adapters") as t3_tab:
            gr.Markdown("### Batch Adapter Merging")
            with gr.Row(variant="compact"):
                t3_token = gr.Textbox(label="Token", type="password", scale=1)
                t3_method = gr.Dropdown(["Iterative EMA (Linear w/ Beta/Sigma coefficient)", "Concatenation (MOE-like weights-stack)", "SVD Fusion (Task Arithmetic/Compressed)"], value="Iterative EMA (Linear w/ Beta/Sigma coefficient)", label="Merge Method", scale=2)
            t3_urls = gr.TextArea(label="Adapter URLs/Repos (one per line, or space-separated)", lines=3)
            with gr.Row(variant="compact"):
                t3_weights = gr.Textbox(label="Weights (comma-separated) – for Concat/SVD", scale=2)
                t3_rank = gr.Number(label="Target Rank – For SVD only", value=128, scale=1)
            with gr.Row(variant="compact"):
                t3_beta = gr.Slider(label="Beta – for linear/post-hoc EMA", value=0.95, minimum=0.01, maximum=1.00)
                t3_sigma = gr.Slider(label="Sigma Rel – for linear/post-hoc EMA", value=0.21, minimum=0.01, maximum=1.00)
            with gr.Row(variant="compact"):
                t3_out = gr.Textbox(label="Output Repo", scale=3)
                t3_priv = gr.Checkbox(label="Private Output", value=True, scale=1)
            t3_btn = gr.Button("Merge", variant="primary")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(lambda: gr.update(visible=True), outputs=top_screensaver).then(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], t3_res)

        # --- 4. Resize Adapter ---
        with gr.Tab("Resize") as t4_tab:
            with gr.Row(variant="compact"):
                t4_token = gr.Textbox(label="Token", type="password", scale=1)
                t4_in = gr.Textbox(label="LoRA Input", scale=2)
            with gr.Row(variant="compact"):
                t4_rank = gr.Number(label="To Rank", value=8, scale=1)
                t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None", label="Dynamic Method", scale=1)
                t4_param = gr.Number(label="Dynamic Param", value=0.9, scale=1)
            with gr.Accordion("📉 Dynamic Resizing Guide", open=False):
                gr.Markdown("- **sv_ratio (Relative Strength):** Keeps features that are at least `1/Param` as strong as the main feature. **Param must be >= 2**.\n- **sv_fro (Visual Information Density):** Preserves `Param%` of total information content. **Param between 0.0 and 1.0**.\n- **sv_cumulative (Cumulative Sum):** Preserves weights that sum up to `Param%` of total strength. **Param between 0.0 and 1.0**.\n- **⚠️ Safety Ceiling:** The **'To Rank'** slider acts as a hard limit.")
            with gr.Row(variant="compact"):
                t4_out = gr.Textbox(label="Output Repo", scale=3)
                t4_btn = gr.Button("Resize", variant="primary", scale=1)
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(lambda: gr.update(visible=True), outputs=top_screensaver).then(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

        # --- 5. Amphinterpolative ---
        with gr.Tab("Amphinterpolate") as t5_tab:
            gr.Markdown("### Spherical Interpolation Methods Family: slerp, nuslerp, multislerp, karcher")
            with gr.Row(variant="compact"):
                t5_token = gr.Textbox(label="HF Token", type="password", scale=1)
                t5_method = gr.Dropdown(["slerp", "nuslerp", "multislerp", "karcher"], value="slerp", label="Merge Method", scale=1)
                t5_base = gr.Textbox(label="Base Model (Optional)", scale=2)

            gr.Markdown("See [MergeKit Merge Method Docs](https://github.com/arcee-ai/mergekit/blob/main/docs/merge_methods.md) for more info.")

            with gr.Row(variant="compact"):
                t5_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0)
                t5_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
                t5_t = gr.Slider(0, 1, 0.5, label="t (mix factor)")
            with gr.Row(variant="compact"):
                t5_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source")
                t5_chat = gr.Textbox(label="Chat Template", placeholder="auto")
            gr.Markdown("Built-in Chat Templates: alpaca, chatml, llama3, mistral, exaone, auto")

            with gr.Row(variant="compact"):
                t5_norm = gr.Checkbox(label="Normalize Weights", value=True)
                t5_i8 = gr.Checkbox(label="Int8 Mask", value=False)
                t5_flat = gr.Checkbox(label="Flatten Tensors", value=False)
                t5_row = gr.Checkbox(label="Row Wise", value=False)
            with gr.Accordion("Advanced Parameters (eps, iter, tol)", open=False):
                with gr.Row(variant="compact"):
                    t5_eps = gr.Textbox(label="eps (MultiSlerp)", value="1e-8")
                    t5_iter = gr.Number(label="Max Iter (Karcher)", value=10)
                    t5_tol = gr.Textbox(label="tol (Karcher)", value="1e-5")

            gr.Markdown("**MODELS**: **slerp:** 2 models exactly, one of the 2 also listed as *Base* | **nuslerp:** 2 models exactly; *Base*: optional | **multislerp:** 2+ models; *Base*: optional | **karcher:** 2+ models; *Base*: none")

            with gr.Row(variant="compact"):
                with gr.Column(scale=3): m1 = gr.Textbox(label="Model 1")
                with gr.Column(scale=1): w1 = gr.Textbox(label="Weight 1", value="1.0")
            with gr.Row(variant="compact"):
                with gr.Column(scale=3): m2 = gr.Textbox(label="Model 2")
                with gr.Column(scale=1): w2 = gr.Textbox(label="Weight 2", value="1.0")
            with gr.Accordion("More Models (3-5)", open=False):
                with gr.Row(variant="compact"): m3, w3 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0")
                with gr.Row(variant="compact"): m4, w4 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0")
                with gr.Row(variant="compact"): m5, w5 = gr.Textbox(label="Model 5"), gr.Textbox(label="Weight 5", value="1.0")
            with gr.Row(variant="compact"):
                t5_out = gr.Textbox(label="Output Repo", scale=3)
                t5_priv = gr.Checkbox(label="Private", value=True, scale=1)
            t5_btn = gr.Button("Execute", variant="primary")

            # Bottom row split: Result (Narrower) | SS (Wider)
            with gr.Row():
                with gr.Column(scale=2):
                    t5_res = gr.Textbox(label="Result", lines=10)
                with gr.Column(scale=2):
                    t5_ss = gr.HTML(IFRAME_HTML, visible=False)
            t5_btn.click(lambda: gr.update(visible=True), outputs=t5_ss).then(wrapper_amphinterpolative, [t5_token, t5_method, t5_base, t5_t, t5_norm, t5_i8, t5_flat, t5_row, t5_eps, t5_iter, t5_tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, t5_out, t5_priv, t5_shard, t5_prec, t5_tok, t5_chat], t5_res)

        # --- 6. Stir/Tie Bases ---
        with gr.Tab("Align/Tie") as t6_tab:
            gr.Markdown("### Task Vector Methods Family: task_arithmetic, ties, dare_ties, dare_linear, della, della_linear, breadcrumbs, breadcrumbs_ties, sce")
            with gr.Row(variant="compact"):
                t6_token = gr.Textbox(label="Token", type="password", scale=1)
                t6_method = gr.Dropdown(["task_arithmetic", "ties", "dare_ties", "dare_linear", "della", "della_linear", "breadcrumbs", "breadcrumbs_ties", "sce"], value="ties", label="Merge Method", scale=2)
            t6_base = gr.Textbox(label="Base Model (required)")
            gr.Markdown("See [MergeKit Merge Method Docs](https://github.com/arcee-ai/mergekit/blob/main/docs/merge_methods.md) for more info.")

            with gr.Row(variant="compact"):
                t6_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0)
                t6_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
                t6_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source")
                t6_chat = gr.Textbox(label="Chat Template", placeholder="auto")
            gr.Markdown("Built-in **Chat Templates**: alpaca, chatml, llama3, mistral, exaone, auto (default)")
            gr.Markdown("**MODELS**: These methods all accept **2 or more models**, and require one of these designated as *Base*")

            with gr.Accordion("Global Parameters (Normalize, Int8, Lambda, etc.)", open=False):
                with gr.Row(variant="compact"): t6_norm = gr.Checkbox(label="Normalize Weights", value=True); t6_i8 = gr.Checkbox(label="Int8 Mask", value=False); t6_resc = gr.Checkbox(label="Rescale (Dare_Linear)", value=True)
                with gr.Row(variant="compact"): t6_lamb = gr.Number(label="Lambda", value=1.0); t6_topk = gr.Slider(0, 1, 1.0, label="Select TopK (SCE)")
            with gr.Row(variant="compact"):
                m1_6 = gr.Textbox(label="Model 1", scale=2); w1_6 = gr.Textbox(label="Weight 1", value="1.0", scale=1); d1_6 = gr.Textbox(label="Density", value="1.0", scale=1)
            with gr.Row(variant="compact"):
                m2_6 = gr.Textbox(label="Model 2", scale=2); w2_6 = gr.Textbox(label="Weight 2", value="1.0", scale=1); d2_6 = gr.Textbox(label="Density", value="1.0", scale=1)
            with gr.Accordion("More Models & Params", open=False):
                with gr.Row(variant="compact"): g1_6 = gr.Number(label="Gamma (breadcrumbs)", value=0.01); e1_6 = gr.Number(label="Epsilon (DELLA)", value=0.15)
                with gr.Row(variant="compact"): g2_6 = gr.Number(label="Gamma (breadcrumbs)", value=0.01); e2_6 = gr.Number(label="Epsilon (DELLA)", value=0.15)
                with gr.Row(variant="compact"): m3_6, w3_6, d3_6, g3_6, e3_6 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0"), gr.Textbox(label="Density (DARE/TIES)", value="1.0"), gr.Number(label="Gamma (breadcrumbs)", value=0.01), gr.Number(label="Epsilon (DELLA)", value=0.15)
                with gr.Row(variant="compact"): m4_6, w4_6, d4_6, g4_6, e4_6 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0"), gr.Textbox(label="Density (DARE/TIES)", value="1.0"), gr.Number(label="Gamma (breadcrumbs)", value=0.01), gr.Number(label="Epsilon (DELLA)", value=0.15)
            with gr.Row(variant="compact"):
                t6_out = gr.Textbox(label="Output Repo", scale=3); t6_priv = gr.Checkbox(label="Private", value=True, scale=1)
            t6_btn = gr.Button("Execute", variant="primary")
            with gr.Row():
                with gr.Column(scale=2):
                    t6_res = gr.Textbox(label="Result", lines=10)
                with gr.Column(scale=2):
                    t6_ss = gr.HTML(IFRAME_HTML, visible=False)
            t6_btn.click(lambda: gr.update(visible=True), outputs=t6_ss).then(wrapper_stirtie, [t6_token, t6_method, t6_base, t6_norm, t6_i8, t6_lamb, t6_resc, t6_topk, m1_6, w1_6, d1_6, g1_6, e1_6, m2_6, w2_6, d2_6, g2_6, e2_6, m3_6, w3_6, d3_6, g3_6, e3_6, m4_6, w4_6, d4_6, g4_6, e4_6, t6_out, t6_priv, t6_shard, t6_prec, t6_tok, t6_chat], t6_res)

        # --- 7. Specious ---
        with gr.Tab("Specious") as t7_tab:
            gr.Markdown("### Specialized Methods: model_stock, nearswap, arcee_fusion, passthrough")
            with gr.Row(variant="compact"):
                t7_token = gr.Textbox(label="Token", type="password", scale=1)
                t7_method = gr.Dropdown(["model_stock", "nearswap", "arcee_fusion", "passthrough", "linear"], value="model_stock", label="Merge Method", scale=2)
            t7_base = gr.Textbox(label="Base Model (required for nearswap/arcee_fusion/model_stock)", placeholder="org/base-model")
            gr.Markdown("See [MergeKit Merge Method Docs](https://github.com/arcee-ai/mergekit/blob/main/docs/merge_methods.md) for more info.")

            with gr.Row(variant="compact"):
                t7_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0); t7_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t7_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t7_chat = gr.Textbox(label="Chat Template", placeholder="auto")
            gr.Markdown("Built-in **Chat Templates**: alpaca, chatml, llama3, mistral, exaone, auto (default)")

            gr.Markdown("**MODELS**: **passthrough:** 1 model acc. to Docs, but the [Examples](https://github.com/arcee-ai/mergekit/tree/main/examples) show 2+ | **nearswap/arcee_fusion:** 2 models, one also listed as *Base* | **model_stock:** 3+ models, one also listed as *Base*")

            with gr.Row(variant="compact"):
                t7_norm = gr.Checkbox(label="Normalize", value=True); t7_i8 = gr.Checkbox(label="Int8 Mask", value=False); t7_t = gr.Slider(0, 1, 0.5, label="t (Interpolation Ratio, for Nearswap)"); t7_filt_w = gr.Checkbox(label="Filter Wise (for Model_Stock)", value=False)
            with gr.Row(variant="compact"): m1_7 = gr.Textbox(label="Model 1", scale=2); w1_7 = gr.Textbox(label="Weight 1", value="1.0", scale=1); f1_7 = gr.Textbox(label="Filter Model Component", scale=1)
            with gr.Row(variant="compact"): m2_7 = gr.Textbox(label="Model 2", scale=2); w2_7 = gr.Textbox(label="Weight 2", value="1.0", scale=1)
            with gr.Accordion("More Models (3-5)", open=False):
                with gr.Row(variant="compact"): m3_7, w3_7 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0")
                with gr.Row(variant="compact"): m4_7, w4_7 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0")
                with gr.Row(variant="compact"): m5_7, w5_7 = gr.Textbox(label="Model 5"), gr.Textbox(label="Weight 5", value="1.0")
            with gr.Row(variant="compact"): t7_out = gr.Textbox(label="Output Repo", scale=3); t7_priv = gr.Checkbox(label="Private", value=True, scale=1)
            t7_btn = gr.Button("Execute", variant="primary")
            with gr.Row():
                with gr.Column(scale=2):
                    t7_res = gr.Textbox(label="Result", lines=10)
                with gr.Column(scale=2):
                    t7_ss = gr.HTML(IFRAME_HTML, visible=False)
            t7_btn.click(lambda: gr.update(visible=True), outputs=t7_ss).then(wrapper_specious, [t7_token, t7_method, t7_base, t7_norm, t7_i8, t7_t, t7_filt_w, m1_7, w1_7, f1_7, m2_7, w2_7, m3_7, w3_7, m4_7, w4_7, m5_7, w5_7, t7_out, t7_priv, t7_shard, t7_prec, t7_tok, t7_chat], t7_res)

        # --- 8. MoEr ---
        with gr.Tab("MoEr") as t8_tab:
            gr.Markdown("### Mixture of Experts: fuses self-attention & normalization layers from *Base* w/MLP layers from *Experts*")
            gr.Markdown("See [MergeKit MoE doc](https://github.com/arcee-ai/mergekit/blob/main/docs/moe.md) for more info.")

            with gr.Row(variant="compact"):
                t8_token = gr.Textbox(label="Token", type="password", scale=1)
                t8_base = gr.Textbox(label="Base Model (Required)", scale=2)
            with gr.Row(variant="compact"):
                t8_gate = gr.Dropdown(["cheap_embed", "random", "hidden"], value="cheap_embed", label="Gate Mode")
                t8_dtype = gr.Dropdown(["float16", "bfloat16"], value="bfloat16", label="Internal Dtype")
            with gr.Row(variant="compact"):
                t8_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0); t8_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t8_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t8_chat = gr.Textbox(label="Chat Template", placeholder="auto")

            gr.Markdown("#### Experts (at least 2 required). Prompts are comma-separated.")
            with gr.Row(variant="compact"): t8_expert1 = gr.Textbox(label="Expert 1", placeholder="org/expert1", scale=2); t8_prompt1 = gr.Textbox(label="Positive Prompts", placeholder="math, reasoning, logic", scale=3)
            with gr.Row(variant="compact"): t8_expert2 = gr.Textbox(label="Expert 2", placeholder="org/expert2", scale=2); t8_prompt2 = gr.Textbox(label="Positive Prompts", placeholder="creative, writing, storytelling", scale=3)
            with gr.Accordion("More Experts (3-5)", open=False):
                with gr.Row(variant="compact"): t8_expert3, t8_prompt3 = gr.Textbox(label="Expert 3 (optional)", placeholder="org/expert3"), gr.Textbox(label="Positive Prompts", placeholder="code, programming")
                with gr.Row(variant="compact"): t8_expert4, t8_prompt4 = gr.Textbox(label="Expert 4 (optional)", placeholder="org/expert4"), gr.Textbox(label="Positive Prompts", placeholder="")
                with gr.Row(variant="compact"): t8_expert5, t8_prompt5 = gr.Textbox(label="Expert 5 (optional)", placeholder="org/expert5"), gr.Textbox(label="Positive Prompts", placeholder="")
            with gr.Row(variant="compact"): t8_out = gr.Textbox(label="Output Repo", scale=3); t8_priv = gr.Checkbox(label="Private", value=True, scale=1)
            t8_btn = gr.Button("Build MoE", variant="primary")
            with gr.Row():
                with gr.Column(scale=2):
                    t8_res = gr.Textbox(label="Result", lines=10)
                with gr.Column(scale=2):
                    t8_ss = gr.HTML(IFRAME_HTML, visible=False)
            t8_btn.click(lambda: gr.update(visible=True), outputs=t8_ss).then(wrapper_moer, [t8_token, t8_base, t8_expert1, t8_prompt1, t8_expert2, t8_prompt2, t8_expert3, t8_prompt3, t8_expert4, t8_prompt4, t8_expert5, t8_prompt5, t8_gate, t8_dtype, t8_out, t8_priv, t8_shard, t8_prec, t8_tok, t8_chat], t8_res)

        # --- 9. Rawer ---
        with gr.Tab("Rawer") as t9_tab:
            gr.Markdown("### Raw PyTorch MergeKit / Non-pipeline-classed Models")
            with gr.Row(variant="compact"):
                t9_token = gr.Textbox(label="Token", type="password", scale=1)
                t9_method = gr.Dropdown(["linear", "passthrough"], value="linear", label="Merge Method", scale=1)
                t9_dtype = gr.Dropdown(["float32", "float16", "bfloat16"], value="float32", label="Config dtype", scale=1)
            t9_models = gr.TextArea(label="Models (one per line)", lines=3)
            with gr.Row(variant="compact"):
                t9_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=0.5, maximum=20.0); t9_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
            with gr.Row(variant="compact"):
                t9_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t9_chat = gr.Textbox(label="Chat Template (e.g., alpaca, chatml, auto)", placeholder="auto")
            gr.Markdown("Built-in Chat Templates: alpaca, chatml, llama3, mistral, exaone, auto")
            gr.Markdown("See [MergeKit Merge Method Docs](https://github.com/arcee-ai/mergekit/blob/main/docs/merge_methods.md) for more info.")
            with gr.Row(variant="compact"):
                t9_out = gr.Textbox(label="Output Repo", scale=3); t9_priv = gr.Checkbox(label="Private", value=True, scale=1)
            t9_btn = gr.Button("Merge Raw", variant="primary")
            t9_res = gr.Textbox(label="Result", lines=10)
            t9_btn.click(lambda: gr.update(visible=True), outputs=top_screensaver).then(wrapper_rawer, [t9_token, t9_models, t9_method, t9_dtype, t9_out, t9_priv, t9_shard, t9_prec, t9_tok, t9_chat], t9_res)

        # --- 10. Mario,DARE! ---
        with gr.Tab("Mario,Dare!") as t10_tab:
            gr.Markdown("### Model-Agnostic DARE Implementation (Drop And REscale)")
            gr.Markdown("From [sft-merger by Martyn Garcia](https://github.com/martyn)")
            t10_token = gr.Textbox(label="Token", type="password")

            gr.Markdown(
                """
                ### How DARE Works:
                1. **Compute Delta**: Difference between fine-tuned and base weights
                2. **Drop Elements**: Randomly mask out delta values based on mask rate
                3. **Rescale**: Compensate for dropped elements by rescaling remaining values
                4. **Apply**: Add scaled delta back to base model

                **Mask Rate**: 0.5 = drop 50% of delta values, 0.9 = drop 90% (more aggressive sparsification)
                """
            )

            with gr.Row(variant="compact"):
                t10_base = gr.Textbox(label="Base Model", placeholder="org/base-model"); t10_ft = gr.Textbox(label="Fine-Tuned Model", placeholder="org/fine-tuned-model")
            with gr.Row(variant="compact"):
                t10_ratio = gr.Slider(value=1.0, minimum=0.0, maximum=2.0, step=0.1, label="Merge Ratio (delta weight)"); t10_mask = gr.Slider(value=0.5, minimum=0.0, maximum=0.99, step=0.01, label="Mask Rate (drop probability)")
            t10_out = gr.Textbox(label="Output Repo"); t10_priv = gr.Checkbox(label="Private", value=True)
            t10_btn = gr.Button("Run", variant="primary")
            t10_res = gr.Textbox(label="Result")
            t10_btn.click(lambda: gr.update(visible=True), outputs=top_screensaver).then(task_dare_custom, [t10_token, t10_base, t10_ft, t10_ratio, t10_mask, t10_out, t10_priv], t10_res)

    # --- Event Listeners for Tab Switching ---
    all_screensavers = [top_screensaver, t5_ss, t6_ss, t7_ss, t8_ss]
    all_tabs = [t1_tab, t2_tab, t3_tab, t4_tab, t9_tab, t10_tab, t5_tab, t6_tab, t7_tab, t8_tab]
    for tab in all_tabs:
        tab.select(fn=hide_all_screensavers, inputs=None, outputs=all_screensavers)

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False, mcp_server=True)