AlekseyCalvin commited on
Commit
2c69e73
·
verified ·
1 Parent(s): c06758a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +779 -0
app.py ADDED
@@ -0,0 +1,779 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import gc
5
+ import shutil
6
+ import requests
7
+ import json
8
+ import struct
9
+ import numpy as np
10
+ import yaml
11
+ import subprocess
12
+ import sys
13
+ import tempfile
14
+ import re
15
+ from pathlib import Path
16
+ from typing import Dict, Any, Optional, List, Iterable
17
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
18
+ from safetensors.torch import load_file, save_file
19
+ from tqdm import tqdm
20
+
21
+ # --- Essential Imports ---
22
+ from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
23
+ from mergekit.config import MergeConfiguration
24
+
25
+ # --- Constants & Setup ---
26
# Prefer the system temp location; fall back to a local folder when /tmp is
# unavailable or unwritable (e.g. restricted containers). The original used a
# bare `except:`, which also swallowed KeyboardInterrupt — narrowed to OSError,
# which covers all mkdir/permission failures.
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)
32
+
33
# Shared Hub client used by every task below; tokens are passed per call.
api = HfApi()
34
+
35
def cleanup_temp():
    """Reset the scratch directory to an empty state and trigger a GC pass."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()
40
+
41
+ # =================================================================================
42
+ # SHARED HELPERS
43
+ # =================================================================================
44
+
45
def parse_hf_url(url):
    """Split a Hugging Face *resolve* URL into (repo_id, filename).

    Expects URLs shaped like
    ``https://huggingface.co/<org>/<repo>/resolve/<rev>/<path/to/file>``;
    any query string is stripped from the filename. Returns (None, None)
    for anything that does not match.
    """
    if "huggingface.co" in url and "resolve" in url:
        try:
            parts = url.split("huggingface.co/")[-1].split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            # parts[2] == "resolve", parts[3] == revision; the remainder is
            # the file path inside the repo.
            filename = "/".join(parts[4:]).split("?")[0]
            return repo_id, filename
        except IndexError:
            # Too few path components to be a resolve link. (Was a bare
            # `except:`, which also swallowed KeyboardInterrupt.)
            return None, None
    return None, None
55
+
56
def download_lora_smart(input_str, token):
    """Resolve ``input_str`` to a local LoRA file at TempDir/adapter.safetensors.

    Tries, in order:
      1. a direct huggingface.co ``.../resolve/...`` URL,
      2. an ``org/repo/path/file.safetensors`` shorthand string,
      3. treating the string as a repo id and picking a conventional
         adapter filename (or the first .safetensors file found),
      4. a raw streamed HTTP(S) download as a last resort.
    Re-raises the repo-lookup error if every strategy fails.
    """
    local_path = TempDir / "adapter.safetensors"
    # Start clean so a stale adapter from a previous run is never reused.
    if local_path.exists(): os.remove(local_path)

    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        try:
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            # hf_hub_download may nest the file under subdirectories; locate it.
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        except: pass  # fall through to the next strategy
    try:
        # Strategy 2: "org/repo/sub/file.safetensors" shorthand.
        if ".safetensors" in input_str and input_str.count("/") >= 2:
            parts = input_str.split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[2:])
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        # Strategy 3: treat the whole string as a repo id.
        candidates = ["adapter_model.safetensors", "model.safetensors"]
        files = list_repo_files(repo_id=input_str, token=token)
        target = next((f for f in files if f in candidates), None)
        if not target:
            safes = [f for f in files if f.endswith(".safetensors")]
            if safes: target = safes[0]
        if not target: raise ValueError("No safetensors found")
        hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
        found = list(TempDir.rglob(target.split("/")[-1]))[0]
        if found != local_path: shutil.move(found, local_path)
        return local_path
    except Exception as e:
        # Strategy 4: plain streamed HTTP download of an arbitrary URL.
        if input_str.startswith("http"):
            try:
                headers = {"Authorization": f"Bearer {token}"} if token else {}
                r = requests.get(input_str, stream=True, headers=headers, timeout=60)
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
                return local_path
            except: pass
        # All strategies failed; surface the repo-lookup error.
        raise e
99
+
100
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """Load a LoRA state dict and group it as {stem: {down, up, rank, alpha}}.

    Handles both kohya-style (lora_down/lora_up) and PEFT-style
    (lora_A/lora_B) key naming. Missing alpha tensors default to the
    detected rank (or 1.0 when no rank was seen).
    """
    tensors = load_file(lora_path, device="cpu")
    grouped = {}
    alpha_by_stem = {}
    for key, value in tensors.items():
        stem = get_key_stem(key)
        if "alpha" in key:
            alpha_by_stem[stem] = value.item() if isinstance(value, torch.Tensor) else value
            continue
        entry = grouped.setdefault(stem, {})
        if "lora_down" in key or "lora_A" in key:
            entry["down"] = value.to(dtype=precision_dtype)
            entry["rank"] = value.shape[0]
        elif "lora_up" in key or "lora_B" in key:
            entry["up"] = value.to(dtype=precision_dtype)
    # Attach an alpha to every pair, falling back to rank (or 1.0).
    for entry in grouped.values():
        entry["alpha"] = alpha_by_stem.get(
            next(s for s, e in grouped.items() if e is entry),
            float(entry.get("rank", 1.0)),
        ) if False else None  # placeholder removed below
    for stem, entry in grouped.items():
        entry["alpha"] = alpha_by_stem.get(stem, float(entry.get("rank", 1.0)))
    return grouped
118
+
119
def get_key_stem(key):
    """Normalize a LoRA/base tensor key to a bare module path.

    Strips the LoRA/weight suffixes and repeatedly removes any known
    model-prefix so keys from different naming conventions collapse to a
    common stem.
    """
    for suffix in (".weight", ".bias", ".lora_down", ".lora_up",
                   ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(suffix, "")
    known_prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_",
        "base_model.model.",
    )
    # Keep stripping until no listed prefix matches (prefixes can stack).
    stripped = True
    while stripped:
        stripped = False
        for prefix in known_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key
136
+
137
+ # =================================================================================
138
+ # TABS 1-4 LOGIC (RESTORED)
139
+ # =================================================================================
140
+
141
class MemoryEfficientSafeOpen:
    """Minimal streaming reader for .safetensors files.

    Parses the JSON header once at construction, then seeks and reads
    individual tensors on demand so only the requested tensor's bytes are
    held in memory at a time. Usable as a context manager.
    """

    _DTYPE_MAP = {
        "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
        "I64": torch.int64, "I32": torch.int32, "I16": torch.int16,
        "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool,
    }

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        # "__metadata__" is header bookkeeping, not a tensor entry.
        return [name for name in self.header if name != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found")
        entry = self.header[key]
        start, end = entry["data_offsets"]
        # The data section begins after the 8-byte length prefix + header.
        self.file.seek(self.header_size + 8 + start)
        return self._deserialize_tensor(self.file.read(end - start), entry)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        return json.loads(self.file.read(header_size).decode("utf-8")), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._DTYPE_MAP[metadata["dtype"]]
        raw = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
        return raw.view(dtype).reshape(metadata["shape"])
164
+
165
class ShardBuffer:
    """Accumulates tensors, writes size-capped .safetensors shards, and
    uploads each shard to the output repo as soon as it is flushed.

    Maintains ``index_map`` (tensor name -> shard filename) and
    ``total_size`` so the caller can emit an HF ``*.index.json`` weight map.
    """
    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir, self.output_repo, self.subfolder, self.hf_token, self.filename_prefix = output_dir, output_repo, subfolder, hf_token, filename_prefix
        # buffer: pending tensor records; current_bytes: bytes pending;
        # shard_count: shards written so far; index_map: key -> shard file;
        # total_size: cumulative bytes across all shards.
        self.buffer, self.current_bytes, self.shard_count, self.index_map, self.total_size = [], 0, 0, {}, 0
    def add_tensor(self, key, tensor):
        # Serialize to raw little-endian bytes. bf16 has no numpy dtype, so
        # it is bit-reinterpreted as int16 first. Anything else is labelled
        # F32 — NOTE(review): assumes the caller already cast to
        # bf16/fp16/fp32; other dtypes would be mislabelled in the header.
        if tensor.dtype == torch.bfloat16: raw, dt = tensor.view(torch.int16).numpy().tobytes(), "BF16"
        elif tensor.dtype == torch.float16: raw, dt = tensor.numpy().tobytes(), "F16"
        else: raw, dt = tensor.numpy().tobytes(), "F32"
        self.buffer.append({"key": key, "data": raw, "dtype": dt, "shape": tensor.shape})
        self.current_bytes += len(raw)
        self.total_size += len(raw)
        # Flush as soon as the pending bytes reach the shard cap.
        if self.current_bytes >= self.max_bytes: self.flush()
    def flush(self):
        """Write buffered tensors as one shard, upload it, then reset the buffer."""
        if not self.buffer: return
        self.shard_count += 1
        fname = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        header = {"__metadata__": {"format": "pt"}}
        curr_off = 0
        # Build the safetensors JSON header with per-tensor byte offsets,
        # recording which shard each key landed in.
        for i in self.buffer:
            header[i["key"]] = {"dtype": i["dtype"], "shape": i["shape"], "data_offsets": [curr_off, curr_off + len(i["data"])]}
            curr_off += len(i["data"])
            self.index_map[i["key"]] = fname
        out = self.output_dir / fname
        header_json = json.dumps(header).encode('utf-8')
        # safetensors layout: 8-byte little-endian header length, JSON
        # header, then the concatenated tensor data.
        with open(out, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for i in self.buffer: f.write(i["data"])
        api.upload_file(path_or_fileobj=out, path_in_repo=f"{self.subfolder}/{fname}" if self.subfolder else fname, repo_id=self.output_repo, token=self.hf_token)
        os.remove(out)  # local shard no longer needed once uploaded
        self.buffer, self.current_bytes = [], 0
        gc.collect()
198
+
199
def task_merge_legacy(hf_token, base, sub, lora, scale, prec, shard, out, struct_s, priv, progress=gr.Progress()):
    """Merge a LoRA into every safetensors shard of a base repo and reshard.

    Parameters mirror the Tab 1 UI: ``base`` is the source repo, ``sub`` an
    optional subfolder filter, ``lora`` any reference accepted by
    ``download_lora_smart``, ``scale`` the merge strength, ``prec`` the
    output precision ("bf16"/"fp16"/other -> fp32), ``shard`` the max shard
    size in GB, ``out`` the destination repo, ``struct_s`` an optional repo
    whose non-weight files are copied over, ``priv`` the output privacy
    flag. Returns a short status string.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    try: api.create_repo(repo_id=out, private=priv, exist_ok=True, token=hf_token)
    except Exception as e: return f"Error: {e}"
    # Optionally mirror configs/tokenizers/etc. from a "structure" repo,
    # skipping weight files (and the subfolder we are about to rebuild).
    if struct_s:
        try:
            files = api.list_repo_files(repo_id=struct_s, token=hf_token)
            for f in tqdm(files, desc="Copying Structure"):
                if sub and f.startswith(sub): continue
                if not sub and any(f.endswith(x) for x in ['.safetensors', '.bin', '.pt', '.pth']): continue
                l = hf_hub_download(repo_id=struct_s, filename=f, token=hf_token, local_dir=TempDir)
                api.upload_file(path_or_fileobj=l, path_in_repo=f, repo_id=out, token=hf_token)
        except Exception:
            # Best-effort copy; a broken structure repo should not abort the merge.
            pass

    files = [f for f in list_repo_files(repo_id=base, token=hf_token) if f.endswith(".safetensors")]
    if sub: files = [f for f in files if f.startswith(sub)]
    if not files: return "No safetensors found"

    # Diffusers components use the diffusion_pytorch_model prefix.
    prefix = "diffusion_pytorch_model" if (sub in ["transformer", "unet"] or "diffusion_pytorch_model" in os.path.basename(files[0])) else "model"
    dtype = torch.bfloat16 if prec == "bf16" else torch.float16 if prec == "fp16" else torch.float32
    try: lora_pairs = load_lora_to_memory(download_lora_smart(lora, hf_token), dtype)
    except Exception as e: return f"LoRA Error: {e}"

    buf = ShardBuffer(shard, TempDir, out, sub, hf_token, prefix)
    for i, fpath in enumerate(files):
        local = hf_hub_download(repo_id=base, filename=fpath, token=hf_token, local_dir=TempDir)
        with MemoryEfficientSafeOpen(local) as f:
            for k in f.keys():
                v = f.get_tensor(k)
                stem = get_key_stem(k)
                # Also try fused-qkv stems so attention LoRAs keyed on "qkv" match.
                match = lora_pairs.get(stem) or lora_pairs.get(stem.replace("to_q", "qkv")) or lora_pairs.get(stem.replace("to_k", "qkv")) or lora_pairs.get(stem.replace("to_v", "qkv"))
                if match:
                    d, u = match["down"], match["up"]
                    s = scale * (match["alpha"] / match["rank"])
                    # Promote 2-D LoRA factors to 1x1 convs when the base weight is 4-D.
                    if len(v.shape) == 4 and len(d.shape) == 2: d, u = d.unsqueeze(-1).unsqueeze(-1), u.unsqueeze(-1).unsqueeze(-1)
                    # BUGFIX: the original referenced undefined names `up`/`down`
                    # here, raising NameError on every matched tensor; the
                    # locals are `u`/`d`.
                    delta = (u.squeeze() @ d.squeeze()).reshape(u.shape[0], d.shape[1], 1, 1) if len(u.shape) == 4 else u @ d
                    v = v.to(dtype).add_((delta * s).to(dtype))
                buf.add_tensor(k, v.to(dtype))
        os.remove(local)
    buf.flush()
    # Emit the HF weight-map index so loaders can find each tensor's shard.
    idx = {"metadata": {"total_size": buf.total_size}, "weight_map": buf.index_map}
    idx_n = f"{prefix}.safetensors.index.json"
    with open(TempDir/idx_n, "w") as f: json.dump(idx, f, indent=4)
    api.upload_file(path_or_fileobj=TempDir/idx_n, path_in_repo=f"{sub}/{idx_n}" if sub else idx_n, repo_id=out, token=hf_token)
    return "Done"
245
+
246
def task_extract(hf_token, org, tun, rank, out):
    """SVD-extract a LoRA from the delta between two homologous checkpoints.

    Downloads both files, computes per-tensor deltas, factorizes each at up
    to ``rank`` via randomized SVD, clamps outliers to the 99th percentile,
    and uploads ``extracted_lora.safetensors`` to ``out``. Returns a status
    string ("Done" or "Error: ...").
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    try:
        p1 = download_lora_smart(org, hf_token)
        p2 = download_lora_smart(tun, hf_token)
        # BUGFIX: the original opened both readers without ever closing
        # them (file-handle leak); context managers guarantee cleanup.
        with MemoryEfficientSafeOpen(p1) as org_f, MemoryEfficientSafeOpen(p2) as tun_f:
            lora_sd = {}
            common = set(org_f.keys()) & set(tun_f.keys())
            for k in tqdm(common, desc="Extracting"):
                # Skip batch-norm statistics; they are not trainable deltas.
                if "num_batches_tracked" in k or "running_mean" in k or "running_var" in k: continue
                m1, m2 = org_f.get_tensor(k).float(), tun_f.get_tensor(k).float()
                if m1.shape != m2.shape: continue
                diff = m2 - m1
                # Skip layers that are effectively unchanged.
                if torch.max(torch.abs(diff)) < 1e-4: continue
                out_d, in_d = diff.shape[0], diff.shape[1] if len(diff.shape) > 1 else 1
                r = min(int(rank), in_d, out_d)
                if len(diff.shape) == 4: diff = diff.flatten(1)
                elif len(diff.shape) == 1: diff = diff.unsqueeze(1)
                # Randomized SVD with small oversampling for stability; fold
                # the singular values into U.
                U, S, V = torch.svd_lowrank(diff, q=r+4, niter=4)
                Vh = V.t()
                U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
                U = U @ torch.diag(S)
                # Clamp extreme outliers to the 99th percentile of |values|.
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(torch.abs(dist), 0.99)
                if hi_val > 0: U, Vh = U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)
                if len(m1.shape) == 4:
                    # Restore conv layout: up factor is a 1x1 conv, down
                    # factor keeps the original kernel.
                    U = U.reshape(out_d, r, 1, 1)
                    Vh = Vh.reshape(r, in_d, m1.shape[2], m1.shape[3])
                else:
                    U, Vh = U.reshape(out_d, r), Vh.reshape(r, in_d)
                stem = k.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
                lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        out_f = TempDir/"extracted.safetensors"
        save_file(lora_sd, out_f)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out_f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
        return "Done"
    except Exception as e: return f"Error: {e}"
287
+
288
def load_full_state_dict(path):
    """Load a safetensors file, renaming PEFT-style lora_A/lora_B keys to
    kohya-style lora_down/lora_up and upcasting every tensor to float32."""
    normalized = {}
    for key, tensor in load_file(path, device="cpu").items():
        if "lora_A" in key:
            key = key.replace("lora_A", "lora_down")
        elif "lora_B" in key:
            key = key.replace("lora_B", "lora_up")
        normalized[key] = tensor.float()
    return normalized
297
+
298
def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
    """Merge two or more LoRA adapters with one of three strategies.

    - "Iterative EMA": running exponential average of raw tensors; when
      sigma_rel > 0, beta is scheduled from a gamma solved out of a cubic
      (post-hoc/Power-EMA style — TODO confirm against the source paper).
    - "Concatenation": stacks down/up factors so ranks add (MoE-like).
    - "SVD": accumulates weighted full-weight deltas, then re-factorizes
      each at ``target_rank``.
    Uploads merged_adapters.safetensors to ``out_repo``; returns a status string.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    # Accept newline- or space-separated adapter references.
    urls = [line.strip() for line in inputs_text.replace(" ", "\n").split('\n') if line.strip()]
    if len(urls) < 2: return "Error: Provide at least 2 adapters."
    try: weights = [float(w.strip()) for w in weight_str.split(',')] if weight_str.strip() else [1.0] * len(urls)
    except: return "Error parsing weights."
    # Pad missing weights with 1.0 so every adapter has one.
    if len(weights) < len(urls): weights += [1.0] * (len(urls) - len(weights))

    paths = []
    for url in tqdm(urls, desc="Downloading"): paths.append(download_lora_smart(url, hf_token))

    merged = {}
    if "Iterative EMA" in method:
        base_sd = load_file(paths[0], device="cpu")
        gamma = None
        if sigma_rel > 0:
            # Solve x^3 + 7x^2 + (16 - t)x + (12 - t) = 0 with t = sigma_rel^-2
            # and take the largest non-negative real root as gamma.
            t_val = sigma_rel**-2
            roots = np.roots([1, 7, 16 - t_val, 12 - t_val])
            gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
        for i, path in enumerate(paths[1:]):
            # With gamma, beta is scheduled by step index; otherwise fixed.
            current_beta = (1 - 1 / (i + 1)) ** (gamma + 1) if gamma is not None else beta
            curr = load_file(path, device="cpu")
            for k in base_sd:
                # Alphas are left untouched; only factor tensors are averaged.
                if k in curr and "alpha" not in k:
                    base_sd[k] = base_sd[k].float() * current_beta + curr[k].float() * (1 - current_beta)
        merged = base_sd
    else:
        states = [load_full_state_dict(p) for p in paths]
        # Collect every module stem that has LoRA factors in any adapter.
        all_stems = set()
        for s in states:
            for k in s:
                if "lora_" in k: all_stems.add(k.split(".lora_")[0])
        for stem in tqdm(all_stems):
            down_list, up_list = [], []
            alpha_sum, total_delta = 0.0, None
            for i, state in enumerate(states):
                w = weights[i]
                dk, uk, ak = f"{stem}.lora_down.weight", f"{stem}.lora_up.weight", f"{stem}.alpha"
                if dk in state and uk in state:
                    d, u = state[dk], state[uk]
                    alpha_sum += state[ak].item() if ak in state else d.shape[0]
                    if "Concatenation" in method:
                        # The per-adapter weight is folded into the up factor
                        # so the product u@d scales linearly.
                        down_list.append(d); up_list.append(u * w)
                    elif "SVD" in method:
                        rank = d.shape[0]
                        alpha = state[ak].item() if ak in state else rank
                        scale = (alpha / rank) * w
                        # Reconstruct the full delta; conv factors are
                        # flattened for the matmul, then reshaped back.
                        delta = ((u.flatten(1) @ d.flatten(1)).reshape(u.shape[0], d.shape[1], d.shape[2], d.shape[3]) if len(d.shape)==4 else u @ d) * scale
                        total_delta = delta if total_delta is None else total_delta + delta
            if "Concatenation" in method and down_list:
                # Stack down factors along the rank dim (0) and up factors
                # along their rank dim (1); alphas add.
                merged[f"{stem}.lora_down.weight"] = torch.cat(down_list, dim=0).contiguous()
                merged[f"{stem}.lora_up.weight"] = torch.cat(up_list, dim=1).contiguous()
                merged[f"{stem}.alpha"] = torch.tensor(alpha_sum)
            elif "SVD" in method and total_delta is not None:
                tr = int(target_rank)
                flat = total_delta.flatten(1) if len(total_delta.shape)==4 else total_delta
                try:
                    # Low-rank SVD with small oversampling; fold singular
                    # values into U.
                    U, S, V = torch.svd_lowrank(flat, q=tr + 4, niter=4)
                    Vh = V.t()
                    U, S, Vh = U[:, :tr], S[:tr], Vh[:tr, :]
                    U = U @ torch.diag(S)
                    if len(total_delta.shape) == 4:
                        U = U.reshape(total_delta.shape[0], tr, 1, 1)
                        Vh = Vh.reshape(tr, total_delta.shape[1], total_delta.shape[2], total_delta.shape[3])
                    else:
                        U, Vh = U.reshape(total_delta.shape[0], tr), Vh.reshape(tr, total_delta.shape[1])
                    merged[f"{stem}.lora_down.weight"] = Vh.contiguous()
                    merged[f"{stem}.lora_up.weight"] = U.contiguous()
                    merged[f"{stem}.alpha"] = torch.tensor(tr).float()
                except: pass  # skip layers where SVD fails (e.g. rank too large)

    out = TempDir / "merged_adapters.safetensors"
    if merged: save_file(merged, out)
    api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return f"Success! Merged to {out_repo}"
375
+
376
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """Re-factorize a LoRA at a (possibly dynamically chosen) lower rank.

    ``new_rank`` is a hard ceiling; ``dynamic_method`` optionally picks a
    per-layer rank from the singular-value spectrum:
      - "sv_ratio":      keep values with S > S[0]/param
      - "sv_cumulative": smallest rank whose cumulative S reaches param
      - "sv_fro":        same on squared S (Frobenius energy), vs param^2
    Uploads shrunken.safetensors to ``out_repo`` and returns "Done".
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    path = download_lora_smart(lora_input, hf_token)
    state = load_file(path, device="cpu")
    new_state = {}
    # Group down/up/alpha tensors by module stem.
    groups = {}
    for k in state:
        simple = k.split(".lora_")[0]
        if simple not in groups: groups[simple] = {}
        if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
        if "alpha" in k: groups[simple]["alpha"] = state[k]

    target_rank_limit = int(new_rank)
    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            # NOTE(review): the stored alpha is not folded into the product
            # here; the output alpha is reset to the final rank below —
            # confirm this matches the intended effective scale.
            down, up = g["down"].float(), g["up"].float()
            merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3]) if len(down.shape)==4 else up @ down
            flat = merged.flatten(1)
            # Oversample well past the ceiling so dynamic methods can see
            # enough of the spectrum.
            U, S, V = torch.svd_lowrank(flat, q=target_rank_limit + 32)
            Vh = V.t()
            calc_rank = target_rank_limit
            if dynamic_method == "sv_ratio":
                calc_rank = int(torch.sum(S > (S[0] / dynamic_param)).item())
            elif dynamic_method == "sv_cumulative":
                calc_rank = int(torch.searchsorted(torch.cumsum(S, 0) / torch.sum(S), dynamic_param)) + 1
            elif dynamic_method == "sv_fro":
                calc_rank = int(torch.searchsorted(torch.cumsum(S.pow(2), 0) / torch.sum(S.pow(2)), dynamic_param**2)) + 1
            # Clamp to [1, ceiling] and to the number of singular values kept.
            final_rank = max(1, min(calc_rank, target_rank_limit, S.shape[0]))
            U = U[:, :final_rank] @ torch.diag(S[:final_rank])
            Vh = Vh[:final_rank, :]
            if len(down.shape) == 4:
                # Restore conv layout; U becomes a 1x1 conv, Vh keeps the
                # original kernel dims.
                U = U.reshape(up.shape[0], final_rank, 1, 1)
                Vh = Vh.reshape(final_rank, down.shape[1], down.shape[2], down.shape[3])
            new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
            new_state[f"{stem}.lora_up.weight"] = U.contiguous()
            new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float()
    out = TempDir / "shrunken.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
419
+
420
+ # =================================================================================
421
+ # MERGEKIT & LOGSVIEW (TABS 5-9) - CLI LOGIC
422
+ # =================================================================================
423
+
424
def parse_weight(w_str):
    """Parse a weight field from the UI.

    Blank -> 1.0; strings containing "[" are parsed as YAML (per-layer
    weight gradients); otherwise parsed as float. Malformed input falls
    back to 1.0 rather than raising.
    """
    if not w_str.strip(): return 1.0
    try:
        if "[" in w_str: return yaml.safe_load(w_str)
        return float(w_str)
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        return 1.0
430
+
431
def run_mergekit_logic(config_dict, token, out_repo, private, shard_size, output_precision, tokenizer_source, chat_template, program="mergekit-yaml"):
    """Validate a mergekit config, run the CLI merge, and upload the result.

    Generator: yields Log entries consumed by the LogsView component.
    ``program`` selects the CLI entry point ("mergekit-yaml" or
    "mergekit-moe"). Mutates ``config_dict`` in place with defaults.
    """
    runner = LogsViewRunner()
    cleanup_temp()

    if chat_template and chat_template.strip():
        config_dict["chat_template"] = chat_template.strip()

    try:
        # MoE configs use a different schema that MergeConfiguration rejects.
        if program != "mergekit-moe":
            MergeConfiguration.model_validate(config_dict)
    except Exception as e:
        yield runner.log(f"Invalid Config: {e}", level="ERROR")
        return

    if token:
        login(token.strip())
        os.environ["HF_TOKEN"] = token.strip()

    # Fill defaults only when the caller's config did not set them.
    if "dtype" not in config_dict: config_dict["dtype"] = output_precision
    if "tokenizer_source" not in config_dict and tokenizer_source != "base":
        config_dict["tokenizer_source"] = tokenizer_source

    config_path = TempDir / "config.yaml"
    with open(config_path, "w") as f: yaml.dump(config_dict, f, sort_keys=False)

    yield runner.log(f"Config saved to {config_path}")
    yield runner.log(f"YAML:\n{yaml.dump(config_dict, sort_keys=False)}")

    try:
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=token)
        yield runner.log(f"Repo {out_repo} ready.")
    except Exception as e:
        yield runner.log(f"Repo Error: {e}", level="ERROR")
        return

    out_path = TempDir / "merge_output"
    # Convert the UI's GB value to the "<n>M" string mergekit expects.
    shard_arg = f"{int(float(shard_size) * 1024)}M"

    cmd = [
        program,
        str(config_path),
        str(out_path),
        "--allow-crimes",
        "--copy-tokenizer",
        "--out-shard-size", shard_arg,
        "--lazy-unpickle"
    ]

    if torch.cuda.is_available():
        cmd.extend(["--cuda", "--low-cpu-memory"])

    yield runner.log(f"Executing: {' '.join(cmd)}")
    env = os.environ.copy()
    # Keep the HF cache inside the scratch dir so cleanup_temp reclaims it.
    env["HF_HOME"] = str(TempDir / ".cache")

    yield from runner.run_command(cmd, env=env)

    if runner.exit_code != 0:
        yield runner.log("Merge failed.", level="ERROR")
        return

    yield runner.log(f"Uploading to {out_repo}...")
    yield from runner.run_python(api.upload_folder, repo_id=out_repo, folder_path=out_path)
    yield runner.log("Upload Complete!")
495
+
496
+ # --- UI Wrappers for Tabs 5-9 ---
497
+
498
def wrapper_amphinterpolative(token, method, base, t, norm, i8, flat, row, eps, m_iter, tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
    """Build a mergekit config for interpolative methods (slerp, nuslerp,
    multislerp, karcher) and stream the merge logs to the UI."""
    # Method-specific parameter block.
    params = {"normalize": norm, "int8_mask": i8}
    if method in ["slerp", "nuslerp"]: params["t"] = float(t)
    if method == "nuslerp": params.update({"flatten": flat, "row_wise": row})
    if method == "multislerp": params["eps"] = float(eps)
    if method == "karcher": params.update({"max_iter": int(m_iter), "tol": float(tol)})

    config = {"merge_method": method}

    if method in ["slerp", "nuslerp"]:
        if not base.strip():
            # BUGFIX: the original yielded `runner.log(...)` but `runner` was
            # never defined in this scope (NameError); create one so the
            # error actually reaches the log view.
            yield LogsViewRunner().log("Error: Base model required", level="ERROR")
            return
        config["base_model"] = base.strip()
        # slerp-style methods interpolate exactly two sources via slices.
        sources = [{"model": m, "parameters": {"weight": parse_weight(w)}} for m, w in [(m1,w1), (m2,w2)] if m.strip()]
        config["slices"] = [{"sources": sources, "parameters": params}]
    else:
        if base.strip() and method == "multislerp": config["base_model"] = base.strip()
        models = [{"model": m, "parameters": {"weight": parse_weight(w)}} for m, w in [(m1,w1), (m2,w2), (m3,w3), (m4,w4), (m5,w5)] if m.strip()]
        config["models"] = models
        config["parameters"] = params

    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
519
+
520
def wrapper_stirtie(token, method, base, norm, i8, lamb, resc, topk, m1, w1, d1, g1, e1, m2, w2, d2, g2, e2, m3, w3, d3, g3, e3, m4, w4, d4, g4, e4, out, priv, shard, prec, tok_src, chat_t):
    """Config builder for TIES/DARE/breadcrumbs/della/SCE-style merges
    (up to 4 models, each with weight/density/gamma/epsilon fields)."""
    models = []
    # Explicit loop over the 4 sets of model inputs
    for m, w, d, g, e in [
        (m1, w1, d1, g1, e1),
        (m2, w2, d2, g2, e2),
        (m3, w3, d3, g3, e3),
        (m4, w4, d4, g4, e4)
    ]:
        if not m.strip(): continue
        p = {"weight": parse_weight(w)}
        # Per-model knobs only apply to the methods that define them.
        if method in ["ties", "dare_ties", "dare_linear", "breadcrumbs_ties"]: p["density"] = parse_weight(d)
        if "breadcrumbs" in method: p["gamma"] = float(g)
        if "della" in method: p["epsilon"] = float(e)
        models.append({"model": m, "parameters": p})

    g_params = {"normalize": norm, "int8_mask": i8}
    if method != "sce": g_params["lambda"] = float(lamb)
    if method == "dare_linear": g_params["rescale"] = resc
    if method == "sce": g_params["select_topk"] = float(topk)

    # NOTE(review): if base is blank and every model field is blank,
    # models[0] raises IndexError — confirm the UI prevents that state.
    config = {
        "merge_method": method,
        "base_model": base.strip() if base.strip() else models[0]["model"],
        "parameters": g_params,
        "models": models
    }
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
548
+
549
def wrapper_specious(token, method, base, norm, i8, t, filt_w, m1, w1, f1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
    """Config builder for passthrough/nearswap/model_stock-style merges
    (up to 5 models)."""
    models = []
    if method == "passthrough":
        # Passthrough takes a single model, optionally with a tensor filter.
        p = {"weight": parse_weight(w1)}
        if f1.strip(): p["filter"] = f1.strip()
        models.append({"model": m1, "parameters": p})
    else:
        models = [{"model": m, "parameters": {"weight": parse_weight(w)}} for m, w in [(m1,w1), (m2,w2), (m3,w3), (m4,w4), (m5,w5)] if m.strip()]

    config = {"merge_method": method, "parameters": {"normalize": norm, "int8_mask": i8}}
    if base.strip(): config["base_model"] = base.strip()
    # Method-specific global parameters.
    if method == "nearswap": config["parameters"]["t"] = float(t)
    if method == "model_stock": config["parameters"]["filter_wise"] = filt_w
    config["models"] = models

    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
565
+
566
def wrapper_moer(token, base, experts, gate, dtype, out, priv, shard, prec, tok_src, chat_t):
    """Config builder for MoE assembly via the mergekit-moe CLI.

    ``experts`` is a newline-separated list of repo ids; each expert gets
    placeholder positive_prompts. Falls back to the first expert as base.
    """
    formatted = [{"source_model": e.strip(), "positive_prompts": ["chat", "assist"]} for e in experts.split('\n') if e.strip()]
    # NOTE(review): an empty experts box makes formatted[0] raise IndexError
    # when base is blank — confirm the UI requires at least one expert.
    config = {
        "base_model": base.strip() if base.strip() else formatted[0]["source_model"],
        "gate_mode": gate,
        "dtype": dtype,
        "experts": formatted
    }
    # Uses mergekit-moe CLI
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-moe")
576
+
577
def wrapper_rawer(token, models, method, dtype, out, priv, shard, prec, tok_src, chat_t):
    """Raw config builder: one model repo per line, every model weighted 1.0."""
    entries = []
    for line in models.split('\n'):
        name = line.strip()
        if name:
            entries.append({"model": name, "parameters": {"weight": 1.0}})
    config = {
        "models": entries,
        "merge_method": method,
        "dtype": dtype
    }
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
585
+
586
+ # --- TAB 10 (Custom DARE) Logic ---
587
def task_dare_custom(token, base, ft, ratio, mask, out, priv):
    """DARE-style merge of two full checkpoints.

    Computes delta = ft - base per tensor, randomly drops delta entries
    with probability ``mask`` and rescales survivors by 1/(1-mask), then
    adds the delta back scaled by ``ratio``. Mismatched-shape tensors are
    taken from the fine-tune as-is. Uploads model.safetensors to ``out``;
    returns a status string.
    """
    cleanup_temp()
    if token: login(token.strip())
    try:
        b_path = download_lora_smart(base, token)
        f_path = download_lora_smart(ft, token)
        b_sd = load_file(b_path, device="cpu")
        f_sd = load_file(f_path, device="cpu")
        merged = {}
        common = set(b_sd.keys()) & set(f_sd.keys())
        for k in tqdm(common, desc="Merging"):
            tb, tf = b_sd[k], f_sd[k]
            if tb.shape != tf.shape:
                merged[k] = tf
                continue
            delta = tf.float() - tb.float()
            if mask > 0:
                # Bernoulli keep-mask with keep prob (1 - mask); rescale the
                # survivors so the expected delta is unchanged.
                m = torch.bernoulli(torch.full_like(delta, 1.0 - mask))
                delta = (delta * m) / (1.0 - mask)
            merged[k] = (tb.float() + ratio * delta).to(tb.dtype)

        out_f = TempDir / "model.safetensors"
        save_file(merged, out_f)
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_file(path_or_fileobj=out_f, path_in_repo="model.safetensors", repo_id=out, token=token)
        return f"Done! {out}"
    except Exception as e: return str(e)
614
+
615
+ # =================================================================================
616
+ # UI GENERATION
617
+ # =================================================================================
618
+
619
# NOTE(review): defined but not passed to gr.Blocks() in the visible code —
# confirm whether this stylesheet is still wired up.
css = ".container { max-width: 1100px; margin: auto; }"
620
+
621
+ with gr.Blocks() as demo:
622
+ gr.HTML("""<h1><img src="https://huggingface.co/spaces/AlekseyCalvin/Soon_Merger/resolve/main/SMerger3.png" alt="SOONmerge®"> Transform Transformers for FREE!</h1>""")
623
+ gr.Markdown("# 🧰Training-Free CPU-run Model Creation Toolkit")
624
+
625
+ with gr.Tabs():
626
+ # --- TAB 1: RESTORED ---
627
+ with gr.Tab("Merge to Base Model + Reshard Output"):
628
+ t1_token = gr.Textbox(label="Token", type="password")
629
+ t1_base = gr.Textbox(label="Base Repo", value="name/repo")
630
+ t1_sub = gr.Textbox(label="Subfolder (Optional)", value="")
631
+ t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
632
+ with gr.Row():
633
+ t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
634
+ t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
635
+ t1_shard = gr.Slider(label="Max Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
636
+ t1_out = gr.Textbox(label="Output Repo")
637
+ t1_struct = gr.Textbox(label="Extras Source (copies configs/components/etc)", value="name/repo")
638
+ t1_priv = gr.Checkbox(label="Private", value=True)
639
+ gr.Button("Merge").click(task_merge_legacy, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], gr.Textbox(label="Result"))
640
+
641
+ # --- TAB 2: RESTORED ---
642
+ with gr.Tab("Extract Adapter"):
643
+ t2_token = gr.Textbox(label="Token", type="password")
644
+ t2_org = gr.Textbox(label="Original Model")
645
+ t2_tun = gr.Textbox(label="Tuned or Homologous Model")
646
+ t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
647
+ t2_out = gr.Textbox(label="Output Repo")
648
+ gr.Button("Extract").click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], gr.Textbox(label="Result"))
649
+
650
+ # --- TAB 3: RESTORED ---
651
+ with gr.Tab("Merge Adapters"):
652
+ gr.Markdown("### Batch Adapter Merging")
653
+ t3_token = gr.Textbox(label="Token", type="password")
654
+ t3_urls = gr.TextArea(label="Adapter URLs/Repos (one per line, or space-separated)")
655
+ t3_method = gr.Dropdown(["Iterative EMA (Linear w/ Beta/Sigma coefficient)", "Concatenation (MOE-like weights-stack)", "SVD Fusion (Task Arithmetic/Compressed)"], value="Iterative EMA (Linear w/ Beta/Sigma coefficient)", label="Merge Method")
656
+ with gr.Row():
657
+ t3_weights = gr.Textbox(label="Weights (comma-separated) – for Concat/SVD")
658
+ t3_rank = gr.Number(label="Target Rank – For SVD only", value=128)
659
+ with gr.Row():
660
+ t3_beta = gr.Slider(label="Beta – for linear/post-hoc EMA", value=0.95, minimum=0.01, maximum=1.00)
661
+ t3_sigma = gr.Slider(label="Sigma Rel – for linear/post-hoc EMA", value=0.21, minimum=0.01, maximum=1.00)
662
+ t3_out = gr.Textbox(label="Output Repo")
663
+ t3_priv = gr.Checkbox(label="Private Output", value=True)
664
+ gr.Button("Merge").click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], gr.Textbox(label="Result"))
665
+
666
+ # --- TAB 4: RESTORED ---
667
+ with gr.Tab("Resize Adapter"):
668
+ t4_token = gr.Textbox(label="Token", type="password")
669
+ t4_in = gr.Textbox(label="LoRA")
670
+ with gr.Row():
671
+ t4_rank = gr.Number(label="To Rank (Safety Ceiling)", value=8)
672
+ t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None", label="Dynamic Method")
673
+ t4_param = gr.Number(label="Dynamic Param", value=0.9)
674
+ gr.Markdown("### 📉 Dynamic Resizing Guide\nThese methods intelligently determine the best rank per layer.\n- **sv_ratio (Relative Strength):** Keeps features that are at least `1/Param` as strong as the main feature. **Param must be >= 2**.\n- **sv_fro (Visual Information Density):** Preserves `Param%` of total information content. **Param between 0.0 and 1.0**.\n- **sv_cumulative (Cumulative Sum):** Preserves weights that sum up to `Param%` of total strength. **Param between 0.0 and 1.0**.\n- **⚠️ Safety Ceiling:** The **'To Rank'** slider acts as a hard limit.")
675
+ t4_out = gr.Textbox(label="Output")
676
+ gr.Button("Resize").click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], gr.Textbox(label="Result"))
677
+
678
+ # --- TAB 5 ---
679
+ with gr.Tab("Amphinterpolative"):
680
+ gr.Markdown("### Spherical Interpolation Family")
681
+ t5_token = gr.Textbox(label="HF Token", type="password")
682
+ t5_method = gr.Dropdown(["slerp", "nuslerp", "multislerp", "karcher"], value="slerp", label="Method")
683
+ with gr.Row():
684
+ t5_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0)
685
+ t5_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
686
+ t5_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source")
687
+ t5_chat = gr.Textbox(label="Chat Template (write-in, default: auto)", placeholder="auto")
688
+ with gr.Row():
689
+ t5_base = gr.Textbox(label="Base Model")
690
+ t5_t = gr.Slider(0, 1, 0.5, label="t")
691
+ with gr.Row():
692
+ t5_norm = gr.Checkbox(label="Normalize", value=True); t5_i8 = gr.Checkbox(label="Int8 Mask", value=False); t5_flat = gr.Checkbox(label="NuSlerp Flatten", value=False); t5_row = gr.Checkbox(label="NuSlerp Row Wise", value=False)
693
+ with gr.Row():
694
+ t5_eps = gr.Textbox(label="eps", value="1e-8"); t5_iter = gr.Number(label="max_iter", value=10); t5_tol = gr.Textbox(label="tol", value="1e-5")
695
+ m1, w1 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0"); m2, w2 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0")
696
+ with gr.Accordion("More", open=False):
697
+ m3, w3 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0"); m4, w4 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0"); m5, w5 = gr.Textbox(label="Model 5"), gr.Textbox(label="Weight 5", value="1.0")
698
+ t5_out = gr.Textbox(label="Output Repo"); t5_priv = gr.Checkbox(label="Private", value=True)
699
+ t5_btn = gr.Button("Execute")
700
+ t5_logs = LogsView()
701
+ t5_btn.click(wrapper_amphinterpolative, [t5_token, t5_method, t5_base, t5_t, t5_norm, t5_i8, t5_flat, t5_row, t5_eps, t5_iter, t5_tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, t5_out, t5_priv, t5_shard, t5_prec, t5_tok, t5_chat], t5_logs)
702
+
703
+ # --- TAB 6 ---
704
+ with gr.Tab("Stir/Tie Bases"):
705
+ gr.Markdown("### Task Vector Family")
706
+ t6_token = gr.Textbox(label="Token", type="password")
707
+ t6_method = gr.Dropdown(["task_arithmetic", "ties", "dare_ties", "dare_linear", "della", "della_linear", "breadcrumbs", "breadcrumbs_ties", "sce"], value="ties", label="Method")
708
+ with gr.Row():
709
+ t6_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0); t6_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t6_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t6_chat = gr.Textbox(label="Chat Template", placeholder="auto")
710
+ t6_base = gr.Textbox(label="Base Model")
711
+ with gr.Row():
712
+ t6_norm = gr.Checkbox(label="Normalize", value=True); t6_i8 = gr.Checkbox(label="Int8 Mask", value=False); t6_resc = gr.Checkbox(label="Rescale", value=True); t6_lamb = gr.Number(label="Lambda", value=1.0); t6_topk = gr.Slider(0, 1, 1.0, label="Select TopK")
713
+ m1_6, w1_6 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0"); d1_6, g1_6, e1_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)
714
+ with gr.Accordion("More", open=False):
715
+ m2_6, w2_6 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0"); d2_6, g2_6, e2_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)
716
+ # Corrected UI Layout: Added missing textboxes for Models 3 & 4
717
+ m3_6, w3_6 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0"); d3_6, g3_6, e3_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)
718
+ m4_6, w4_6 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0"); d4_6, g4_6, e4_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)
719
+ t6_out = gr.Textbox(label="Output Repo"); t6_priv = gr.Checkbox(label="Private", value=True)
720
+ t6_btn = gr.Button("Execute")
721
+ t6_logs = LogsView()
722
+ # The .click list now matches the wrapper function signature perfectly (34 arguments)
723
+ t6_btn.click(wrapper_stirtie, [t6_token, t6_method, t6_base, t6_norm, t6_i8, t6_lamb, t6_resc, t6_topk, m1_6, w1_6, d1_6, g1_6, e1_6, m2_6, w2_6, d2_6, g2_6, e2_6, m3_6, w3_6, d3_6, g3_6, e3_6, m4_6, w4_6, d4_6, g4_6, e4_6, t6_out, t6_priv, t6_shard, t6_prec, t6_tok, t6_chat], t6_logs)
724
+
725
+ # --- TAB 7 ---
726
+ with gr.Tab("Specious"):
727
+ gr.Markdown("### Specialized Methods")
728
+ t7_token = gr.Textbox(label="Token", type="password")
729
+ t7_method = gr.Dropdown(["model_stock", "nearswap", "arcee_fusion", "passthrough", "linear"], value="model_stock", label="Method")
730
+ with gr.Row():
731
+ t7_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0); t7_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t7_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t7_chat = gr.Textbox(label="Chat Template", placeholder="auto")
732
+ t7_base = gr.Textbox(label="Base Model")
733
+ with gr.Row():
734
+ t7_norm = gr.Checkbox(label="Normalize", value=True); t7_i8 = gr.Checkbox(label="Int8 Mask", value=False); t7_t = gr.Slider(0, 1, 0.5, label="t"); t7_filt_w = gr.Checkbox(label="Filter Wise", value=False)
735
+ m1_7, w1_7, f1_7 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0"), gr.Textbox(label="Filter (Passthrough)")
736
+ m2_7, w2_7 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0")
737
+ with gr.Accordion("More", open=False):
738
+ m3_7, w3_7 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0"); m4_7, w4_7 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0"); m5_7, w5_7 = gr.Textbox(label="Model 5"), gr.Textbox(label="Weight 5", value="1.0")
739
+ t7_out = gr.Textbox(label="Output Repo"); t7_priv = gr.Checkbox(label="Private", value=True)
740
+ t7_btn = gr.Button("Execute")
741
+ t7_logs = LogsView()
742
+ t7_btn.click(wrapper_specious, [t7_token, t7_method, t7_base, t7_norm, t7_i8, t7_t, t7_filt_w, m1_7, w1_7, f1_7, m2_7, w2_7, m3_7, w3_7, m4_7, w4_7, m5_7, w5_7, t7_out, t7_priv, t7_shard, t7_prec, t7_tok, t7_chat], t7_logs)
743
+
744
+ # --- TAB 8 (MoEr) ---
745
+ with gr.Tab("MoEr"):
746
+ gr.Markdown("### Mixture of Experts")
747
+ t8_token = gr.Textbox(label="Token", type="password")
748
+ with gr.Row():
749
+ t8_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0); t8_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t8_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t8_chat = gr.Textbox(label="Chat Template", placeholder="auto")
750
+ t8_base = gr.Textbox(label="Base Model"); t8_experts = gr.TextArea(label="Experts List"); t8_gate = gr.Dropdown(["cheap_embed", "random", "hidden"], value="cheap_embed", label="Gate Mode"); t8_dtype = gr.Dropdown(["float16", "bfloat16"], value="bfloat16", label="Internal Dtype")
751
+ t8_out = gr.Textbox(label="Output Repo"); t8_priv = gr.Checkbox(label="Private", value=True)
752
+ t8_btn = gr.Button("Build MoE")
753
+ t8_logs = LogsView()
754
+ t8_btn.click(wrapper_moer, [t8_token, t8_base, t8_experts, t8_gate, t8_dtype, t8_out, t8_priv, t8_shard, t8_prec, t8_tok, t8_chat], t8_logs)
755
+
756
+ # --- TAB 9 (Rawer) ---
757
+ with gr.Tab("Rawer"):
758
+ gr.Markdown("### Raw PyTorch / Non-Transformer")
759
+ t9_token = gr.Textbox(label="Token", type="password"); t9_models = gr.TextArea(label="Models (one per line)")
760
+ with gr.Row():
761
+ t9_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0); t9_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t9_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t9_chat = gr.Textbox(label="Chat Template", placeholder="auto")
762
+ t9_method = gr.Dropdown(["linear", "passthrough"], value="linear", label="Method"); t9_dtype = gr.Dropdown(["float32", "float16", "bfloat16"], value="float32", label="Config Dtype")
763
+ t9_out = gr.Textbox(label="Output Repo"); t9_priv = gr.Checkbox(label="Private", value=True)
764
+ t9_btn = gr.Button("Merge Raw")
765
+ t9_logs = LogsView()
766
+ t9_btn.click(wrapper_rawer, [t9_token, t9_models, t9_method, t9_dtype, t9_out, t9_priv, t9_shard, t9_prec, t9_tok, t9_chat], t9_logs)
767
+
768
+ # --- TAB 10 ---
769
+ with gr.Tab("Mario,DARE!"):
770
+ t10_token = gr.Textbox(label="Token", type="password")
771
+ with gr.Row():
772
+ t10_base = gr.Textbox(label="Base Model"); t10_ft = gr.Textbox(label="Fine-Tuned Model")
773
+ with gr.Row():
774
+ t10_ratio = gr.Slider(0, 5, 1.0, label="Ratio"); t10_mask = gr.Slider(0, 0.99, 0.5, label="Mask Rate")
775
+ t10_out = gr.Textbox(label="Output Repo"); t10_priv = gr.Checkbox(label="Private", value=True)
776
+ gr.Button("Run").click(task_dare_custom, [t10_token, t10_base, t10_ft, t10_ratio, t10_mask, t10_out, t10_priv], gr.Textbox(label="Result"))
777
+
778
+ if __name__ == "__main__":
779
+ demo.queue().launch(css=css, ssr_mode=False, debug=True)