AlekseyCalvin commited on
Commit
cb05474
·
verified ·
1 Parent(s): a8ff6b9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +771 -0
app.py ADDED
@@ -0,0 +1,771 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import gc
5
+ import shutil
6
+ import requests
7
+ import json
8
+ import struct
9
+ import numpy as np
10
+ import yaml
11
+ import subprocess
12
+ import sys
13
+ import tempfile
14
+ import re
15
+ from pathlib import Path
16
+ from typing import Dict, Any, Optional, List, Iterable
17
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
18
+ from safetensors.torch import load_file, save_file
19
+ from tqdm import tqdm
20
+
21
+ # --- Essential Imports (No try-except blocks to ensure visibility of errors) ---
22
+ from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
23
+ from mergekit.config import MergeConfiguration
24
+
25
+ # --- Constants ---
26
+ try:
27
+ TempDir = Path("/tmp/temp_tool")
28
+ os.makedirs(TempDir, exist_ok=True)
29
+ except:
30
+ TempDir = Path("./temp_tool")
31
+ os.makedirs(TempDir, exist_ok=True)
32
+
33
+ api = HfApi()
34
+
35
+ def cleanup_temp():
36
+ if TempDir.exists():
37
+ shutil.rmtree(TempDir)
38
+ os.makedirs(TempDir, exist_ok=True)
39
+ gc.collect()
40
+
41
+ # =================================================================================
42
+ # SHARED HELPERS (Tabs 1-4 & 10)
43
+ # =================================================================================
44
+
45
+ def parse_hf_url(url):
46
+ if "huggingface.co" in url and "resolve" in url:
47
+ try:
48
+ parts = url.split("huggingface.co/")[-1].split("/")
49
+ repo_id = f"{parts[0]}/{parts[1]}"
50
+ filename = "/".join(parts[4:]).split("?")[0]
51
+ return repo_id, filename
52
+ except:
53
+ return None, None
54
+ return None, None
55
+
56
+ def download_lora_smart(input_str, token):
57
+ local_path = TempDir / "adapter.safetensors"
58
+ if local_path.exists(): os.remove(local_path)
59
+
60
+ repo_id, filename = parse_hf_url(input_str)
61
+ if repo_id and filename:
62
+ try:
63
+ hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
64
+ found = list(TempDir.rglob(filename.split("/")[-1]))[0]
65
+ if found != local_path: shutil.move(found, local_path)
66
+ return local_path
67
+ except: pass
68
+ try:
69
+ if ".safetensors" in input_str and input_str.count("/") >= 2:
70
+ parts = input_str.split("/")
71
+ repo_id = f"{parts[0]}/{parts[1]}"
72
+ filename = "/".join(parts[2:])
73
+ hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
74
+ found = list(TempDir.rglob(filename.split("/")[-1]))[0]
75
+ if found != local_path: shutil.move(found, local_path)
76
+ return local_path
77
+ candidates = ["adapter_model.safetensors", "model.safetensors"]
78
+ files = list_repo_files(repo_id=input_str, token=token)
79
+ target = next((f for f in files if f in candidates), None)
80
+ if not target:
81
+ safes = [f for f in files if f.endswith(".safetensors")]
82
+ if safes: target = safes[0]
83
+ if not target: raise ValueError("No safetensors found")
84
+ hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
85
+ found = list(TempDir.rglob(target.split("/")[-1]))[0]
86
+ if found != local_path: shutil.move(found, local_path)
87
+ return local_path
88
+ except Exception as e:
89
+ if input_str.startswith("http"):
90
+ try:
91
+ headers = {"Authorization": f"Bearer {token}"} if token else {}
92
+ r = requests.get(input_str, stream=True, headers=headers, timeout=60)
93
+ r.raise_for_status()
94
+ with open(local_path, 'wb') as f:
95
+ for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
96
+ return local_path
97
+ except: pass
98
+ raise e
99
+
100
+ def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
101
+ state_dict = load_file(lora_path, device="cpu")
102
+ pairs = {}
103
+ alphas = {}
104
+ for k, v in state_dict.items():
105
+ stem = get_key_stem(k)
106
+ if "alpha" in k:
107
+ alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
108
+ else:
109
+ if stem not in pairs: pairs[stem] = {}
110
+ if "lora_down" in k or "lora_A" in k:
111
+ pairs[stem]["down"] = v.to(dtype=precision_dtype)
112
+ pairs[stem]["rank"] = v.shape[0]
113
+ elif "lora_up" in k or "lora_B" in k:
114
+ pairs[stem]["up"] = v.to(dtype=precision_dtype)
115
+ for stem in pairs:
116
+ pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
117
+ return pairs
118
+
119
+ def get_key_stem(key):
120
+ key = key.replace(".weight", "").replace(".bias", "")
121
+ key = key.replace(".lora_down", "").replace(".lora_up", "")
122
+ key = key.replace(".lora_A", "").replace(".lora_B", "")
123
+ key = key.replace(".alpha", "")
124
+ prefixes = [
125
+ "model.diffusion_model.", "diffusion_model.", "model.",
126
+ "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
127
+ ]
128
+ changed = True
129
+ while changed:
130
+ changed = False
131
+ for p in prefixes:
132
+ if key.startswith(p):
133
+ key = key[len(p):]
134
+ changed = True
135
+ return key
136
+
137
+ # =================================================================================
138
+ # TABS 1-4 LOGIC (Legacy Python Implementation)
139
+ # =================================================================================
140
+
141
+ class MemoryEfficientSafeOpen:
142
+ def __init__(self, filename):
143
+ self.filename = filename
144
+ self.file = open(filename, "rb")
145
+ self.header, self.header_size = self._read_header()
146
+ def __enter__(self): return self
147
+ def __exit__(self, exc_type, exc_val, exc_tb): self.file.close()
148
+ def keys(self) -> list[str]: return [k for k in self.header.keys() if k != "__metadata__"]
149
+ def metadata(self) -> Dict[str, str]: return self.header.get("__metadata__", {})
150
+ def get_tensor(self, key):
151
+ if key not in self.header: raise KeyError(f"Tensor '{key}' not found")
152
+ metadata = self.header[key]
153
+ start, end = metadata["data_offsets"]
154
+ self.file.seek(self.header_size + 8 + start)
155
+ return self._deserialize_tensor(self.file.read(end - start), metadata)
156
+ def _read_header(self):
157
+ header_size = struct.unpack("<Q", self.file.read(8))[0]
158
+ return json.loads(self.file.read(header_size).decode("utf-8")), header_size
159
+ def _deserialize_tensor(self, tensor_bytes, metadata):
160
+ dtype_map = {"F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool}
161
+ dtype = dtype_map[metadata["dtype"]]
162
+ shape = metadata["shape"]
163
+ return torch.frombuffer(tensor_bytes, dtype=torch.uint8).view(dtype).reshape(shape)
164
+
165
+ class ShardBuffer:
166
+ def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
167
+ self.max_bytes = int(max_size_gb * 1024**3)
168
+ self.output_dir, self.output_repo, self.subfolder, self.hf_token, self.filename_prefix = output_dir, output_repo, subfolder, hf_token, filename_prefix
169
+ self.buffer, self.current_bytes, self.shard_count, self.index_map, self.total_size = [], 0, 0, {}, 0
170
+ def add_tensor(self, key, tensor):
171
+ if tensor.dtype == torch.bfloat16: raw, dt = tensor.view(torch.int16).numpy().tobytes(), "BF16"
172
+ elif tensor.dtype == torch.float16: raw, dt = tensor.numpy().tobytes(), "F16"
173
+ else: raw, dt = tensor.numpy().tobytes(), "F32"
174
+ self.buffer.append({"key": key, "data": raw, "dtype": dt, "shape": tensor.shape})
175
+ self.current_bytes += len(raw)
176
+ self.total_size += len(raw)
177
+ if self.current_bytes >= self.max_bytes: self.flush()
178
+ def flush(self):
179
+ if not self.buffer: return
180
+ self.shard_count += 1
181
+ fname = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
182
+ header = {"__metadata__": {"format": "pt"}}
183
+ curr_off = 0
184
+ for i in self.buffer:
185
+ header[i["key"]] = {"dtype": i["dtype"], "shape": i["shape"], "data_offsets": [curr_off, curr_off + len(i["data"])]}
186
+ curr_off += len(i["data"])
187
+ self.index_map[i["key"]] = fname
188
+ out = self.output_dir / fname
189
+ header_json = json.dumps(header).encode('utf-8')
190
+ with open(out, 'wb') as f:
191
+ f.write(struct.pack('<Q', len(header_json)))
192
+ f.write(header_json)
193
+ for i in self.buffer: f.write(i["data"])
194
+ api.upload_file(path_or_fileobj=out, path_in_repo=f"{self.subfolder}/{fname}" if self.subfolder else fname, repo_id=self.output_repo, token=self.hf_token)
195
+ os.remove(out)
196
+ self.buffer, self.current_bytes = [], 0
197
+ gc.collect()
198
+
199
+ def task_merge_legacy(hf_token, base, sub, lora, scale, prec, shard, out, struct_s, priv, progress=gr.Progress()):
200
+ cleanup_temp()
201
+ if hf_token: login(hf_token.strip())
202
+ try: api.create_repo(repo_id=out, private=priv, exist_ok=True, token=hf_token)
203
+ except Exception as e: return f"Error: {e}"
204
+ if struct_s:
205
+ try:
206
+ files = api.list_repo_files(repo_id=struct_s, token=hf_token)
207
+ for f in tqdm(files, desc="Copying Structure"):
208
+ if sub and f.startswith(sub): continue
209
+ if not sub and any(f.endswith(x) for x in ['.safetensors', '.bin', '.pt', '.pth']): continue
210
+ l = hf_hub_download(repo_id=struct_s, filename=f, token=hf_token, local_dir=TempDir)
211
+ api.upload_file(path_or_fileobj=l, path_in_repo=f, repo_id=out, token=hf_token)
212
+ except: pass
213
+
214
+ files = [f for f in list_repo_files(repo_id=base, token=hf_token) if f.endswith(".safetensors")]
215
+ if sub: files = [f for f in files if f.startswith(sub)]
216
+ if not files: return "No safetensors found"
217
+
218
+ prefix = "diffusion_pytorch_model" if (sub in ["transformer", "unet"] or "diffusion_pytorch_model" in os.path.basename(files[0])) else "model"
219
+ dtype = torch.bfloat16 if prec == "bf16" else torch.float16 if prec == "fp16" else torch.float32
220
+ try: lora_pairs = load_lora_to_memory(download_lora_smart(lora, hf_token), dtype)
221
+ except Exception as e: return f"LoRA Error: {e}"
222
+
223
+ buf = ShardBuffer(shard, TempDir, out, sub, hf_token, prefix)
224
+ for i, fpath in enumerate(files):
225
+ local = hf_hub_download(repo_id=base, filename=fpath, token=hf_token, local_dir=TempDir)
226
+ with MemoryEfficientSafeOpen(local) as f:
227
+ for k in f.keys():
228
+ v = f.get_tensor(k)
229
+ stem = get_key_stem(k)
230
+ match = lora_pairs.get(stem) or lora_pairs.get(stem.replace("to_q", "qkv")) or lora_pairs.get(stem.replace("to_k", "qkv")) or lora_pairs.get(stem.replace("to_v", "qkv"))
231
+ if match:
232
+ d, u = match["down"], match["up"]
233
+ s = scale * (match["alpha"] / match["rank"])
234
+ if len(v.shape)==4 and len(d.shape)==2: d, u = d.unsqueeze(-1).unsqueeze(-1), u.unsqueeze(-1).unsqueeze(-1)
235
+ delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1) if len(up.shape)==4 else u @ d
236
+ v = v.to(dtype).add_((delta * s).to(dtype))
237
+ buf.add_tensor(k, v.to(dtype))
238
+ os.remove(local)
239
+ buf.flush()
240
+ idx = {"metadata": {"total_size": buf.total_size}, "weight_map": buf.index_map}
241
+ idx_n = f"{prefix}.safetensors.index.json"
242
+ with open(TempDir/idx_n, "w") as f: json.dump(idx, f, indent=4)
243
+ api.upload_file(path_or_fileobj=TempDir/idx_n, path_in_repo=f"{sub}/{idx_n}" if sub else idx_n, repo_id=out, token=hf_token)
244
+ return "Done"
245
+
246
+ def task_extract(hf_token, org, tun, rank, out):
247
+ cleanup_temp()
248
+ if hf_token: login(hf_token.strip())
249
+ try:
250
+ p1 = download_lora_smart(org, hf_token)
251
+ p2 = download_lora_smart(tun, hf_token)
252
+ org_f, tun_f = MemoryEfficientSafeOpen(p1), MemoryEfficientSafeOpen(p2)
253
+ lora_sd = {}
254
+ common = set(org_f.keys()) & set(tun_f.keys())
255
+ for k in tqdm(common, desc="Extracting"):
256
+ if "num_batches_tracked" in k or "running_mean" in k or "running_var" in k: continue
257
+ m1, m2 = org_f.get_tensor(k).float(), tun_f.get_tensor(k).float()
258
+ if m1.shape != m2.shape: continue
259
+ diff = m2 - m1
260
+ if torch.max(torch.abs(diff)) < 1e-4: continue
261
+ out_d, in_d = diff.shape[0], diff.shape[1] if len(diff.shape) > 1 else 1
262
+ r = min(int(rank), in_d, out_d)
263
+ if len(diff.shape)==4: diff = diff.flatten(1)
264
+ elif len(diff.shape)==1: diff = diff.unsqueeze(1)
265
+ U, S, V = torch.svd_lowrank(diff, q=r+4, niter=4)
266
+ Vh = V.t()
267
+ U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
268
+ U = U @ torch.diag(S)
269
+ dist = torch.cat([U.flatten(), Vh.flatten()])
270
+ hi_val = torch.quantile(torch.abs(dist), 0.99)
271
+ if hi_val > 0: U, Vh = U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)
272
+ if len(m1.shape)==4:
273
+ U = U.reshape(out_d, r, 1, 1)
274
+ Vh = Vh.reshape(r, in_d, m1.shape[2], m1.shape[3])
275
+ else:
276
+ U, Vh = U.reshape(out_d, r), Vh.reshape(r, in_d)
277
+ stem = k.replace(".weight", "")
278
+ lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
279
+ lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
280
+ lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
281
+ out_f = TempDir/"extracted.safetensors"
282
+ save_file(lora_sd, out_f)
283
+ api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
284
+ api.upload_file(path_or_fileobj=out_f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
285
+ return "Done"
286
+ except Exception as e: return f"Error: {e}"
287
+
288
+ def load_full_state_dict(path):
289
+ raw = load_file(path, device="cpu")
290
+ cleaned = {}
291
+ for k, v in raw.items():
292
+ if "lora_A" in k: new_k = k.replace("lora_A", "lora_down")
293
+ elif "lora_B" in k: new_k = k.replace("lora_B", "lora_up")
294
+ else: new_k = k
295
+ cleaned[new_k] = v.float()
296
+ return cleaned
297
+
298
+ def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
299
+ cleanup_temp()
300
+ if hf_token: login(hf_token.strip())
301
+ urls = [line.strip() for line in inputs_text.replace(" ", "\n").split('\n') if line.strip()]
302
+ if len(urls) < 2: return "Error: Provide at least 2 adapters."
303
+ try: weights = [float(w.strip()) for w in weight_str.split(',')] if weight_str.strip() else [1.0] * len(urls)
304
+ except: return "Error parsing weights."
305
+ if len(weights) < len(urls): weights += [1.0] * (len(urls) - len(weights))
306
+
307
+ paths = []
308
+ for url in tqdm(urls, desc="Downloading"): paths.append(download_lora_smart(url, hf_token))
309
+
310
+ merged = {}
311
+ if "Iterative EMA" in method:
312
+ base_sd = load_file(paths[0], device="cpu")
313
+ gamma = None
314
+ if sigma_rel > 0:
315
+ t_val = sigma_rel**-2
316
+ roots = np.roots([1, 7, 16 - t_val, 12 - t_val])
317
+ gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
318
+ for i, path in enumerate(paths[1:]):
319
+ current_beta = (1 - 1 / (i + 1)) ** (gamma + 1) if gamma is not None else beta
320
+ curr = load_file(path, device="cpu")
321
+ for k in base_sd:
322
+ if k in curr and "alpha" not in k:
323
+ base_sd[k] = base_sd[k].float() * current_beta + curr[k].float() * (1 - current_beta)
324
+ merged = base_sd
325
+ else:
326
+ states = [load_full_state_dict(p) for p in paths]
327
+ all_stems = set()
328
+ for s in states:
329
+ for k in s:
330
+ if "lora_" in k: all_stems.add(k.split(".lora_")[0])
331
+ for stem in tqdm(all_stems):
332
+ down_list, up_list = [], []
333
+ alpha_sum, total_delta = 0.0, None
334
+ for i, state in enumerate(states):
335
+ w = weights[i]
336
+ dk, uk, ak = f"{stem}.lora_down.weight", f"{stem}.lora_up.weight", f"{stem}.alpha"
337
+ if dk in state and uk in state:
338
+ d, u = state[dk], state[uk]
339
+ alpha_sum += state[ak].item() if ak in state else d.shape[0]
340
+ if "Concatenation" in method:
341
+ down_list.append(d); up_list.append(u * w)
342
+ elif "SVD" in method:
343
+ rank = d.shape[0]
344
+ alpha = state[ak].item() if ak in state else rank
345
+ scale = (alpha / rank) * w
346
+ delta = ((u.flatten(1) @ d.flatten(1)).reshape(u.shape[0], d.shape[1], d.shape[2], d.shape[3]) if len(d.shape)==4 else u @ d) * scale
347
+ total_delta = delta if total_delta is None else total_delta + delta
348
+ if "Concatenation" in method and down_list:
349
+ merged[f"{stem}.lora_down.weight"] = torch.cat(down_list, dim=0).contiguous()
350
+ merged[f"{stem}.lora_up.weight"] = torch.cat(up_list, dim=1).contiguous()
351
+ merged[f"{stem}.alpha"] = torch.tensor(alpha_sum)
352
+ elif "SVD" in method and total_delta is not None:
353
+ tr = int(target_rank)
354
+ flat = total_delta.flatten(1) if len(total_delta.shape)==4 else total_delta
355
+ try:
356
+ U, S, V = torch.svd_lowrank(flat, q=tr + 4, niter=4)
357
+ Vh = V.t()
358
+ U, S, Vh = U[:, :tr], S[:tr], Vh[:tr, :]
359
+ U = U @ torch.diag(S)
360
+ if len(total_delta.shape) == 4:
361
+ U = U.reshape(total_delta.shape[0], tr, 1, 1)
362
+ Vh = Vh.reshape(tr, total_delta.shape[1], total_delta.shape[2], total_delta.shape[3])
363
+ else:
364
+ U, Vh = U.reshape(total_delta.shape[0], tr), Vh.reshape(tr, total_delta.shape[1])
365
+ merged[f"{stem}.lora_down.weight"] = Vh.contiguous()
366
+ merged[f"{stem}.lora_up.weight"] = U.contiguous()
367
+ merged[f"{stem}.alpha"] = torch.tensor(tr).float()
368
+ except: pass
369
+
370
+ out = TempDir / "merged_adapters.safetensors"
371
+ if merged: save_file(merged, out)
372
+ api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
373
+ api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
374
+ return f"Success! Merged to {out_repo}"
375
+
376
+ def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
377
+ cleanup_temp()
378
+ if hf_token: login(hf_token.strip())
379
+ path = download_lora_smart(lora_input, hf_token)
380
+ state = load_file(path, device="cpu")
381
+ new_state = {}
382
+ groups = {}
383
+ for k in state:
384
+ simple = k.split(".lora_")[0]
385
+ if simple not in groups: groups[simple] = {}
386
+ if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
387
+ if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
388
+ if "alpha" in k: groups[simple]["alpha"] = state[k]
389
+
390
+ target_rank_limit = int(new_rank)
391
+ for stem, g in tqdm(groups.items()):
392
+ if "down" in g and "up" in g:
393
+ down, up = g["down"].float(), g["up"].float()
394
+ merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3]) if len(down.shape)==4 else up @ down
395
+ flat = merged.flatten(1)
396
+ U, S, V = torch.svd_lowrank(flat, q=target_rank_limit + 32)
397
+ Vh = V.t()
398
+ calc_rank = target_rank_limit
399
+ if dynamic_method == "sv_ratio":
400
+ calc_rank = int(torch.sum(S > (S[0] / dynamic_param)).item())
401
+ elif dynamic_method == "sv_cumulative":
402
+ calc_rank = int(torch.searchsorted(torch.cumsum(S, 0) / torch.sum(S), dynamic_param)) + 1
403
+ elif dynamic_method == "sv_fro":
404
+ calc_rank = int(torch.searchsorted(torch.cumsum(S.pow(2), 0) / torch.sum(S.pow(2)), dynamic_param**2)) + 1
405
+ final_rank = max(1, min(calc_rank, target_rank_limit, S.shape[0]))
406
+ U = U[:, :final_rank] @ torch.diag(S[:final_rank])
407
+ Vh = Vh[:final_rank, :]
408
+ if len(down.shape) == 4:
409
+ U = U.reshape(up.shape[0], final_rank, 1, 1)
410
+ Vh = Vh.reshape(final_rank, down.shape[1], down.shape[2], down.shape[3])
411
+ new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
412
+ new_state[f"{stem}.lora_up.weight"] = U.contiguous()
413
+ new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float()
414
+ out = TempDir / "shrunken.safetensors"
415
+ save_file(new_state, out)
416
+ api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
417
+ api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=hf_token)
418
+ return "Done"
419
+
420
+ # =================================================================================
421
+ # MERGEKIT & LOGSVIEW (TABS 5-9) - FIXED CLI LOGIC
422
+ # =================================================================================
423
+
424
+ def parse_weight(w_str):
425
+ if not w_str.strip(): return 1.0
426
+ try:
427
+ if "[" in w_str: return yaml.safe_load(w_str)
428
+ return float(w_str)
429
+ except: return 1.0
430
+
431
+ def run_mergekit_logic(config_dict, token, out_repo, private, shard_size, output_precision, tokenizer_source, chat_template, program="mergekit-yaml"):
432
+ runner = LogsViewRunner()
433
+ cleanup_temp()
434
+
435
+ # 1. Validation
436
+ try:
437
+ MergeConfiguration.model_validate(config_dict)
438
+ except Exception as e:
439
+ yield runner.log(f"Invalid Config: {e}", level="ERROR")
440
+ return
441
+
442
+ # 2. Auth & Config Save
443
+ if token:
444
+ login(token.strip())
445
+ os.environ["HF_TOKEN"] = token.strip()
446
+
447
+ if "dtype" not in config_dict: config_dict["dtype"] = output_precision
448
+ if "tokenizer_source" not in config_dict and tokenizer_source != "base":
449
+ config_dict["tokenizer_source"] = tokenizer_source
450
+
451
+ # Add chat_template if not empty
452
+ if chat_template and chat_template.strip():
453
+ config_dict["chat_template"] = chat_template.strip()
454
+
455
+ config_path = TempDir / "config.yaml"
456
+ with open(config_path, "w") as f: yaml.dump(config_dict, f, sort_keys=False)
457
+
458
+ yield runner.log(f"Config saved to {config_path}")
459
+ yield runner.log(f"YAML:\n{yaml.dump(config_dict, sort_keys=False)}")
460
+
461
+ # 3. Create Repo
462
+ try:
463
+ api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=token)
464
+ yield runner.log(f"Repo {out_repo} ready.")
465
+ except Exception as e:
466
+ yield runner.log(f"Repo Error: {e}", level="ERROR")
467
+ return
468
+
469
+ # 4. Execution
470
+ out_path = TempDir / "merge_output"
471
+
472
+ shard_arg = f"{int(float(shard_size) * 1024)}M"
473
+
474
+ cmd = [
475
+ program,
476
+ str(config_path),
477
+ str(out_path),
478
+ "--allow-crimes",
479
+ "--copy-tokenizer",
480
+ "--out-shard-size", shard_arg,
481
+ "--lazy-unpickle"
482
+ ]
483
+
484
+ if torch.cuda.is_available():
485
+ cmd.extend(["--cuda", "--low-cpu-memory"])
486
+
487
+ yield runner.log(f"Executing: {' '.join(cmd)}")
488
+ env = os.environ.copy()
489
+ env["HF_HOME"] = str(TempDir / ".cache")
490
+
491
+ yield from runner.run_command(cmd, env=env)
492
+
493
+ if runner.exit_code != 0:
494
+ yield runner.log("Merge failed.", level="ERROR")
495
+ return
496
+
497
+ # 5. Upload
498
+ yield runner.log(f"Uploading to {out_repo}...")
499
+ yield from runner.run_python(api.upload_folder, repo_id=out_repo, folder_path=out_path)
500
+ yield runner.log("Upload Complete!")
501
+
502
+ # --- UI Wrappers for Tabs 5-9 ---
503
+
504
+ def wrapper_amphinterpolative(token, method, base, t, norm, i8, flat, row, eps, m_iter, tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
505
+ params = {"normalize": norm, "int8_mask": i8}
506
+ if method in ["slerp", "nuslerp"]: params["t"] = float(t)
507
+ if method == "nuslerp": params.update({"flatten": flat, "row_wise": row})
508
+ if method == "multislerp": params["eps"] = float(eps)
509
+ if method == "karcher": params.update({"max_iter": int(m_iter), "tol": float(tol)})
510
+
511
+ config = {"merge_method": method}
512
+
513
+ if method in ["slerp", "nuslerp"]:
514
+ if not base.strip(): yield runner.log("Error: Base model required", level="ERROR"); return
515
+ config["base_model"] = base.strip()
516
+ sources = []
517
+ for m, w in [(m1,w1), (m2,w2)]:
518
+ if m.strip(): sources.append({"model": m, "parameters": {"weight": parse_weight(w)}})
519
+ config["slices"] = [{"sources": sources, "parameters": params}]
520
+ else:
521
+ if base.strip() and method == "multislerp": config["base_model"] = base.strip()
522
+ models = []
523
+ for m, w in [(m1, w1), (m2, w2), (m3, w3), (m4, w4), (m5, w5)]:
524
+ if m.strip(): models.append({"model": m, "parameters": {"weight": parse_weight(w)}})
525
+ config["models"] = models
526
+ config["parameters"] = params
527
+
528
+ yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
529
+
530
+ def wrapper_stirtie(token, method, base, norm, i8, lamb, resc, topk, m1, w1, d1, g1, e1, m2, w2, d2, g2, e2, m3, w3, d3, g3, e3, m4, w4, d4, g4, e4, out, priv, shard, prec, tok_src, chat_t):
531
+ models = []
532
+ for m, w, d, g, e in [(m1,w1,d1,g1,e1), (m2,w2,d2,g2,e2), (m3,w3,d3,g3,e3), (m4,w4,d4,g4,e4)]:
533
+ if not m.strip(): continue
534
+ p = {"weight": parse_weight(w)}
535
+ if method in ["ties", "dare_ties", "dare_linear", "breadcrumbs_ties"]: p["density"] = parse_weight(d)
536
+ if "breadcrumbs" in method: p["gamma"] = float(g)
537
+ if "della" in method: p["epsilon"] = float(e)
538
+ models.append({"model": m, "parameters": p})
539
+
540
+ g_params = {"normalize": norm, "int8_mask": i8}
541
+ if method != "sce": g_params["lambda"] = float(lamb)
542
+ if method == "dare_linear": g_params["rescale"] = resc
543
+ if method == "sce": g_params["select_topk"] = float(topk)
544
+
545
+ config = {
546
+ "merge_method": method,
547
+ "base_model": base.strip() if base.strip() else models[0]["model"],
548
+ "parameters": g_params,
549
+ "models": models
550
+ }
551
+ yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
552
+
553
+ def wrapper_specious(token, method, base, norm, i8, t, filt_w, m1, w1, f1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
554
+ models = []
555
+ if method == "passthrough":
556
+ if not m1.strip(): yield runner.log("Error: Model 1 required", level="ERROR"); return
557
+ p = {"weight": parse_weight(w1)}
558
+ if f1.strip(): p["filter"] = f1.strip()
559
+ models.append({"model": m1, "parameters": p})
560
+ else:
561
+ for m, w in [(m1,w1), (m2,w2), (m3,w3), (m4,w4), (m5,w5)]:
562
+ if m.strip(): models.append({"model": m, "parameters": {"weight": parse_weight(w)}})
563
+
564
+ config = {"merge_method": method, "parameters": {"normalize": norm, "int8_mask": i8}}
565
+ if base.strip(): config["base_model"] = base.strip()
566
+ if method == "nearswap": config["parameters"]["t"] = float(t)
567
+ if method == "model_stock": config["parameters"]["filter_wise"] = filt_w
568
+ config["models"] = models
569
+
570
+ yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
571
+
572
+ def wrapper_moer(token, base, experts, gate, dtype, out, priv, shard, prec, tok_src, chat_t):
573
+ formatted = [{"source_model": e.strip(), "positive_prompts": ["chat", "assist"]} for e in experts.split('\n') if e.strip()]
574
+ config = {
575
+ "base_model": base.strip() if base.strip() else formatted[0]["source_model"],
576
+ "gate_mode": gate,
577
+ "dtype": dtype,
578
+ "experts": formatted
579
+ }
580
+ yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-moe")
581
+
582
+ def wrapper_rawer(token, models, method, dtype, out, priv, shard, prec, tok_src, chat_t):
583
+ m_list = [m.strip() for m in models.split('\n') if m.strip()]
584
+ config = {
585
+ "models": [{"model": m, "parameters": {"weight": 1.0}} for m in m_list],
586
+ "merge_method": method,
587
+ "dtype": dtype
588
+ }
589
+ yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
590
+
591
+ # --- TAB 10 (Custom DARE) Logic ---
592
+ def task_dare_custom(token, base, ft, ratio, mask, out, priv):
593
+ cleanup_temp()
594
+ if token: login(token.strip())
595
+ try:
596
+ b_path = download_lora_smart(base, token)
597
+ f_path = download_lora_smart(ft, token)
598
+ b_sd = load_file(b_path, device="cpu")
599
+ f_sd = load_file(f_path, device="cpu")
600
+ merged = {}
601
+ common = set(b_sd.keys()) & set(f_sd.keys())
602
+ for k in tqdm(common, desc="Merging"):
603
+ tb, tf = b_sd[k], f_sd[k]
604
+ if tb.shape != tf.shape:
605
+ merged[k] = tf
606
+ continue
607
+ delta = tf.float() - tb.float()
608
+ if mask > 0:
609
+ m = torch.bernoulli(torch.full_like(delta, 1.0 - mask))
610
+ delta = (delta * m) / (1.0 - mask)
611
+ merged[k] = (tb.float() + ratio * delta).to(tb.dtype)
612
+
613
+ out_f = TempDir / "model.safetensors"
614
+ save_file(merged, out_f)
615
+ api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
616
+ api.upload_file(path_or_fileobj=out_f, path_in_repo="model.safetensors", repo_id=out, token=token)
617
+ return f"Done! {out}"
618
+ except Exception as e: return str(e)
619
+
620
+ # =================================================================================
621
+ # UI GENERATION
622
+ # =================================================================================
623
+
624
+ css = ".container { max-width: 1100px; margin: auto; }"
625
+
626
+ with gr.Blocks() as demo:
627
+ gr.HTML("""<h1><img src="https://huggingface.co/spaces/AlekseyCalvin/Soon_Merger/resolve/main/SMerger3.png" alt="SOONmerge®"> Transform Transformers for FREE!</h1>""")
628
+ gr.Markdown("# 🧰Training-Free CPU-run Model Creation Toolkit")
629
+
630
+ with gr.Tabs():
631
+ # --- TAB 1: RESTORED ---
632
+ with gr.Tab("Merge to Base Model + Reshard Output"):
633
+ t1_token = gr.Textbox(label="Token", type="password")
634
+ t1_base = gr.Textbox(label="Base Repo", value="name/repo")
635
+ t1_sub = gr.Textbox(label="Subfolder (Optional)", value="")
636
+ t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
637
+ with gr.Row():
638
+ t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
639
+ t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
640
+ t1_shard = gr.Slider(label="Max Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
641
+ t1_out = gr.Textbox(label="Output Repo")
642
+ t1_struct = gr.Textbox(label="Extras Source (copies configs/components/etc)", value="name/repo")
643
+ t1_priv = gr.Checkbox(label="Private", value=True)
644
+ gr.Button("Merge").click(task_merge_legacy, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], gr.Textbox(label="Result"))
645
+
646
+ # --- TAB 2: RESTORED ---
647
+ with gr.Tab("Extract Adapter"):
648
+ t2_token = gr.Textbox(label="Token", type="password")
649
+ t2_org = gr.Textbox(label="Original Model")
650
+ t2_tun = gr.Textbox(label="Tuned or Homologous Model")
651
+ t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
652
+ t2_out = gr.Textbox(label="Output Repo")
653
+ gr.Button("Extract").click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], gr.Textbox(label="Result"))
654
+
655
+ # --- TAB 3: RESTORED ---
656
+ with gr.Tab("Merge Adapters"):
657
+ gr.Markdown("### Batch Adapter Merging")
658
+ t3_token = gr.Textbox(label="Token", type="password")
659
+ t3_urls = gr.TextArea(label="Adapter URLs/Repos (one per line, or space-separated)")
660
+ t3_method = gr.Dropdown(["Iterative EMA (Linear w/ Beta/Sigma coefficient)", "Concatenation (MOE-like weights-stack)", "SVD Fusion (Task Arithmetic/Compressed)"], value="Iterative EMA (Linear w/ Beta/Sigma coefficient)", label="Merge Method")
661
+ with gr.Row():
662
+ t3_weights = gr.Textbox(label="Weights (comma-separated) – for Concat/SVD")
663
+ t3_rank = gr.Number(label="Target Rank – For SVD only", value=128)
664
+ with gr.Row():
665
+ t3_beta = gr.Slider(label="Beta – for linear/post-hoc EMA", value=0.95, minimum=0.01, maximum=1.00)
666
+ t3_sigma = gr.Slider(label="Sigma Rel – for linear/post-hoc EMA", value=0.21, minimum=0.01, maximum=1.00)
667
+ t3_out = gr.Textbox(label="Output Repo")
668
+ t3_priv = gr.Checkbox(label="Private Output", value=True)
669
+ gr.Button("Merge").click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], gr.Textbox(label="Result"))
670
+
671
+ # --- TAB 4: RESTORED ---
672
+ with gr.Tab("Resize Adapter"):
673
+ t4_token = gr.Textbox(label="Token", type="password")
674
+ t4_in = gr.Textbox(label="LoRA")
675
+ with gr.Row():
676
+ t4_rank = gr.Number(label="To Rank (Safety Ceiling)", value=8)
677
+ t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None", label="Dynamic Method")
678
+ t4_param = gr.Number(label="Dynamic Param", value=0.9)
679
+ gr.Markdown("### 📉 Dynamic Resizing Guide\nThese methods intelligently determine the best rank per layer.\n- **sv_ratio (Relative Strength):** Keeps features that are at least `1/Param` as strong as the main feature. **Param must be >= 2**.\n- **sv_fro (Visual Information Density):** Preserves `Param%` of total information content. **Param between 0.0 and 1.0**.\n- **sv_cumulative (Cumulative Sum):** Preserves weights that sum up to `Param%` of total strength. **Param between 0.0 and 1.0**.\n- **⚠️ Safety Ceiling:** The **'To Rank'** slider acts as a hard limit.")
680
+ t4_out = gr.Textbox(label="Output")
681
+ gr.Button("Resize").click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], gr.Textbox(label="Result"))
682
+
683
+ # --- TAB 5: Amphinterpolative ---
684
+ with gr.Tab("Amphinterpolative"):
685
+ gr.Markdown("### Spherical Interpolation Family")
686
+ t5_token = gr.Textbox(label="HF Token", type="password")
687
+ t5_method = gr.Dropdown(["slerp", "nuslerp", "multislerp", "karcher"], value="slerp", label="Method")
688
+ with gr.Row():
689
+ t5_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0)
690
+ t5_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
691
+ t5_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source")
692
+ t5_chat = gr.Textbox(label="Chat Template (write-in, default: auto)", placeholder="auto")
693
+ with gr.Row():
694
+ t5_base = gr.Textbox(label="Base Model")
695
+ t5_t = gr.Slider(0, 1, 0.5, label="t")
696
+ with gr.Row():
697
+ t5_norm = gr.Checkbox(label="Normalize", value=True); t5_i8 = gr.Checkbox(label="Int8 Mask", value=False); t5_flat = gr.Checkbox(label="NuSlerp Flatten", value=False); t5_row = gr.Checkbox(label="NuSlerp Row Wise", value=False)
698
+ with gr.Row():
699
+ t5_eps = gr.Textbox(label="eps", value="1e-8"); t5_iter = gr.Number(label="max_iter", value=10); t5_tol = gr.Textbox(label="tol", value="1e-5")
700
+ m1, w1 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0"); m2, w2 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0")
701
+ with gr.Accordion("More", open=False):
702
+ m3, w3 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0"); m4, w4 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0"); m5, w5 = gr.Textbox(label="Model 5"), gr.Textbox(label="Weight 5", value="1.0")
703
+ t5_out = gr.Textbox(label="Output Repo"); t5_priv = gr.Checkbox(label="Private", value=True)
704
+ gr.Button("Execute").click(wrapper_amphinterpolative, [t5_token, t5_method, t5_base, t5_t, t5_norm, t5_i8, t5_flat, t5_row, t5_eps, t5_iter, t5_tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, t5_out, t5_priv, t5_shard, t5_prec, t5_tok, t5_chat], LogsView())
705
+
706
+ # --- TAB 6: Stir/Tie Bases ---
707
+ with gr.Tab("Stir/Tie Bases"):
708
+ gr.Markdown("### Task Vector Family")
709
+ t6_token = gr.Textbox(label="Token", type="password")
710
+ t6_method = gr.Dropdown(["task_arithmetic", "ties", "dare_ties", "dare_linear", "della", "della_linear", "breadcrumbs", "breadcrumbs_ties", "sce"], value="ties", label="Method")
711
+ with gr.Row():
712
+ t6_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0); t6_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t6_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t6_chat = gr.Textbox(label="Chat Template", placeholder="auto")
713
+ t6_base = gr.Textbox(label="Base Model")
714
+ with gr.Row():
715
+ t6_norm = gr.Checkbox(label="Normalize", value=True); t6_i8 = gr.Checkbox(label="Int8 Mask", value=False); t6_resc = gr.Checkbox(label="Rescale", value=True); t6_lamb = gr.Number(label="Lambda", value=1.0); t6_topk = gr.Slider(0, 1, 1.0, label="Select TopK")
716
+ m1_6, w1_6 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0"); d1_6, g1_6, e1_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)
717
+ with gr.Accordion("More", open=False):
718
+ m2_6, w2_6 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0"); d2_6, g2_6, e2_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)
719
+ t6_out = gr.Textbox(label="Output Repo"); t6_priv = gr.Checkbox(label="Private", value=True)
720
+ gr.Button("Execute").click(wrapper_stirtie, [t6_token, t6_method, t6_base, t6_norm, t6_i8, t6_lamb, t6_resc, t6_topk, m1_6, w1_6, d1_6, g1_6, e1_6, m2_6, w2_6, d2_6, g2_6, e2_6, t6_out, t6_priv, t6_shard, t6_prec, t6_tok, t6_chat], LogsView())
721
+
722
+ # --- TAB 7: Specious ---
723
+ with gr.Tab("Specious"):
724
+ gr.Markdown("### Specialized Methods")
725
+ t7_token = gr.Textbox(label="Token", type="password")
726
+ t7_method = gr.Dropdown(["model_stock", "nearswap", "arcee_fusion", "passthrough", "linear"], value="model_stock", label="Method")
727
+ with gr.Row():
728
+ t7_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0); t7_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t7_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t7_chat = gr.Textbox(label="Chat Template", placeholder="auto")
729
+ t7_base = gr.Textbox(label="Base Model")
730
+ with gr.Row():
731
+ t7_norm = gr.Checkbox(label="Normalize", value=True); t7_i8 = gr.Checkbox(label="Int8 Mask", value=False); t7_t = gr.Slider(0, 1, 0.5, label="t"); t7_filt_w = gr.Checkbox(label="Filter Wise", value=False)
732
+ m1_7, w1_7, f1_7 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0"), gr.Textbox(label="Filter (Passthrough)")
733
+ m2_7, w2_7 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0")
734
+ with gr.Accordion("More", open=False):
735
+ m3_7, w3_7 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0"); m4_7, w4_7 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0"); m5_7, w5_7 = gr.Textbox(label="Model 5"), gr.Textbox(label="Weight 5", value="1.0")
736
+ t7_out = gr.Textbox(label="Output Repo"); t7_priv = gr.Checkbox(label="Private", value=True)
737
+ gr.Button("Execute").click(wrapper_specious, [t7_token, t7_method, t7_base, t7_norm, t7_i8, t7_t, t7_filt_w, m1_7, w1_7, f1_7, m2_7, w2_7, m3_7, w3_7, m4_7, w4_7, m5_7, w5_7, t7_out, t7_priv, t7_shard, t7_prec, t7_tok, t7_chat], LogsView())
738
+
739
+ # --- TAB 8: MoEr ---
740
+ with gr.Tab("MoEr"):
741
+ gr.Markdown("### Mixture of Experts")
742
+ t8_token = gr.Textbox(label="Token", type="password")
743
+ with gr.Row():
744
+ t8_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0); t8_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t8_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t8_chat = gr.Textbox(label="Chat Template", placeholder="auto")
745
+ t8_base = gr.Textbox(label="Base Model"); t8_experts = gr.TextArea(label="Experts List"); t8_gate = gr.Dropdown(["cheap_embed", "random", "hidden"], value="cheap_embed", label="Gate Mode"); t8_dtype = gr.Dropdown(["float16", "bfloat16"], value="bfloat16", label="Internal Dtype")
746
+ t8_out = gr.Textbox(label="Output Repo"); t8_priv = gr.Checkbox(label="Private", value=True)
747
+ gr.Button("Build MoE").click(wrapper_moer, [t8_token, t8_base, t8_experts, t8_gate, t8_dtype, t8_out, t8_priv, t8_shard, t8_prec, t8_tok, t8_chat], LogsView())
748
+
749
+ # --- TAB 9: Rawer ---
750
+ with gr.Tab("Rawer"):
751
+ gr.Markdown("### Raw PyTorch / Non-Transformer")
752
+ t9_token = gr.Textbox(label="Token", type="password"); t9_models = gr.TextArea(label="Models (one per line)")
753
+ with gr.Row():
754
+ t9_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0); t9_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision"); t9_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source"); t9_chat = gr.Textbox(label="Chat Template", placeholder="auto")
755
+ t9_method = gr.Dropdown(["linear", "passthrough"], value="linear", label="Method"); t9_dtype = gr.Dropdown(["float32", "float16", "bfloat16"], value="float32", label="Config Dtype")
756
+ t9_out = gr.Textbox(label="Output Repo"); t9_priv = gr.Checkbox(label="Private", value=True)
757
+ gr.Button("Merge Raw").click(wrapper_rawer, [t9_token, t9_models, t9_method, t9_dtype, t9_out, t9_priv, t9_shard, t9_prec, t9_tok, t9_chat], LogsView())
758
+
759
+ # --- TAB 10: Mario,DARE! ---
760
+ with gr.Tab("Mario,DARE!"):
761
+ gr.Markdown("### From sft-merger by [Martyn Garcia](https://github.com/martyn)")
762
+ t10_token = gr.Textbox(label="Token", type="password")
763
+ with gr.Row():
764
+ t10_base = gr.Textbox(label="Base Model"); t10_ft = gr.Textbox(label="Fine-Tuned Model")
765
+ with gr.Row():
766
+ t10_ratio = gr.Slider(0, 5, 1.0, label="Ratio"); t10_mask = gr.Slider(0, 0.99, 0.5, label="Mask Rate")
767
+ t10_out = gr.Textbox(label="Output Repo"); t10_priv = gr.Checkbox(label="Private", value=True)
768
+ gr.Button("Run").click(task_dare_custom, [t10_token, t10_base, t10_ft, t10_ratio, t10_mask, t10_out, t10_priv], gr.Textbox(label="Result"))
769
+
770
+ if __name__ == "__main__":
771
+ demo.queue().launch(css=css, ssr_mode=False)