AlekseyCalvin commited on
Commit
8197d7d
·
verified ·
1 Parent(s): e630094

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +800 -0
app.py ADDED
@@ -0,0 +1,800 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import gc
5
+ import shutil
6
+ import requests
7
+ import json
8
+ import struct
9
+ import numpy as np
10
+ import yaml
11
+ import subprocess
12
+ import sys
13
+ import tempfile
14
+ import re
15
+ from pathlib import Path
16
+ from typing import Dict, Any, Optional, List, Iterable
17
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
18
+ from safetensors.torch import load_file, save_file
19
+ from tqdm import tqdm
20
+
21
# --- Import MergeKit Config ---
# mergekit is an optional dependency used only for config validation in the
# MergeKit tabs; import its validator if available.
try:
    from mergekit.config import MergeConfiguration
except ImportError:
    # Fallback if installation fails temporarily: a no-op stand-in so the rest
    # of the app still loads (validation then silently succeeds).
    class MergeConfiguration:
        @staticmethod
        def model_validate(config): pass
29
+
30
# --- Constants & Setup ---
# Prefer /tmp (writable on HF Spaces); fall back to a local directory when the
# first path cannot be created. Narrowed from a bare `except:` so only
# filesystem failures (OSError and subclasses) trigger the fallback.
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

# Shared Hugging Face Hub client used by every task below.
api = HfApi()
39
+
40
def cleanup_temp():
    """Reset the shared temp directory to an empty state and reclaim memory."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()
45
+
46
+ # =================================================================================
47
+ # SHARED HELPERS (Tabs 1-4 & 10)
48
+ # =================================================================================
49
+
50
def parse_hf_url(url):
    """Split a Hugging Face *resolve* URL into ``(repo_id, filename)``.

    Accepts URLs of the form
    ``https://huggingface.co/<org>/<repo>/resolve/<rev>/<path...>`` and
    returns ``(repo_id, filename)`` with any query string stripped from the
    filename. Returns ``(None, None)`` for anything else.
    """
    if "huggingface.co" in url and "resolve" in url:
        try:
            parts = url.split("huggingface.co/")[-1].split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            # parts[2] == "resolve", parts[3] == revision; the rest is the file path.
            filename = "/".join(parts[4:]).split("?")[0]
            return repo_id, filename
        except IndexError:
            # Narrowed from a bare `except:` — only a too-short path is expected here.
            return None, None
    return None, None
60
+
61
def download_lora_smart(input_str, token):
    """Resolve `input_str` to a local file at TempDir/adapter.safetensors.

    Tries, in order:
      1. a Hugging Face "resolve" URL (via parse_hf_url);
      2. an "org/repo/<path>.safetensors" style path;
      3. a bare repo id — prefers adapter_model/model.safetensors, else the
         first *.safetensors file listed in the repo;
      4. last resort: a direct HTTP download of `input_str`.
    Re-raises the hub error if every strategy fails.
    """
    local_path = TempDir / "adapter.safetensors"
    # Always start clean so a stale adapter from a previous run is never reused.
    if local_path.exists(): os.remove(local_path)

    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        try:
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            # hf_hub_download may place the file under subdirectories; locate it.
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        except: pass  # best-effort: fall through to the next strategy
    try:
        if ".safetensors" in input_str and input_str.count("/") >= 2:
            # "org/repo/sub/file.safetensors" style input.
            parts = input_str.split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[2:])
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        # Bare repo id: prefer the conventional adapter filenames.
        candidates = ["adapter_model.safetensors", "model.safetensors"]
        files = list_repo_files(repo_id=input_str, token=token)
        target = next((f for f in files if f in candidates), None)
        if not target:
            safes = [f for f in files if f.endswith(".safetensors")]
            if safes: target = safes[0]
        if not target: raise ValueError("No safetensors found")
        hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
        found = list(TempDir.rglob(target.split("/")[-1]))[0]
        if found != local_path: shutil.move(found, local_path)
        return local_path
    except Exception as e:
        # Final fallback: treat the input as a direct download URL.
        if input_str.startswith("http"):
            try:
                headers = {"Authorization": f"Bearer {token}"} if token else {}
                r = requests.get(input_str, stream=True, headers=headers, timeout=60)
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
                return local_path
            except: pass  # fall through and surface the original hub error
        raise e
104
+
105
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """Load a LoRA safetensors file into ``{stem: {down, up, rank, alpha}}``.

    Keys are normalized with get_key_stem so kohya (lora_down/lora_up) and
    PEFT (lora_A/lora_B) naming schemes land on the same stem.
    """
    raw = load_file(lora_path, device="cpu")
    pairs, alphas = {}, {}
    for key, tensor in raw.items():
        stem = get_key_stem(key)
        if "alpha" in key:
            alphas[stem] = tensor.item() if isinstance(tensor, torch.Tensor) else tensor
            continue
        entry = pairs.setdefault(stem, {})
        if "lora_down" in key or "lora_A" in key:
            entry["down"] = tensor.to(dtype=precision_dtype)
            entry["rank"] = tensor.shape[0]
        elif "lora_up" in key or "lora_B" in key:
            entry["up"] = tensor.to(dtype=precision_dtype)
    # No explicit alpha stored -> fall back to the detected rank (or 1.0).
    for entry in pairs.values():
        entry["alpha"] = alphas.get_stem_default = None  # placeholder removed below
    for stem, entry in pairs.items():
        entry["alpha"] = alphas.get(stem, float(entry.get("rank", 1.0)))
    return pairs
123
+
124
def get_key_stem(key):
    """Normalize a LoRA/model tensor key down to a bare layer stem.

    Strips weight/bias/LoRA-part suffixes, then repeatedly removes any known
    framework prefix so keys from different export formats line up.
    """
    for suffix in (".weight", ".bias", ".lora_down", ".lora_up",
                   ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(suffix, "")
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    ]
    stripped = True
    while stripped:
        stripped = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key
141
+
142
+ # =================================================================================
143
+ # TABS 1-4 LOGIC (RESTORED)
144
+ # =================================================================================
145
+
146
class MemoryEfficientSafeOpen:
    """Minimal streaming reader for the safetensors file format.

    Parses the JSON header once at open time, then reads individual tensors
    on demand from their byte ranges instead of loading the whole file.
    Usable as a context manager; the file handle is closed on exit.
    """

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        # "__metadata__" is file-level metadata, not a tensor entry.
        return [name for name in self.header if name != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """Read and deserialize a single tensor by name."""
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found")
        metadata = self.header[key]
        start, end = metadata["data_offsets"]
        # Tensor data starts after the 8-byte length prefix plus the header.
        self.file.seek(self.header_size + 8 + start)
        return self._deserialize_tensor(self.file.read(end - start), metadata)

    def _read_header(self):
        # safetensors layout: u64 little-endian header length, then JSON header.
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header = json.loads(self.file.read(header_size).decode("utf-8"))
        return header, header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype_map = {
            "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
            "I64": torch.int64, "I32": torch.int32, "I16": torch.int16,
            "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool,
        }
        target_dtype = dtype_map[metadata["dtype"]]
        raw = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
        return raw.view(target_dtype).reshape(metadata["shape"])
169
+
170
class ShardBuffer:
    """Accumulates tensors and writes/uploads safetensors shards incrementally.

    Tensors are buffered as raw bytes; once the buffer crosses `max_size_gb`
    the shard is written in safetensors format, uploaded to `output_repo`,
    and deleted locally, keeping peak disk usage bounded. `index_map` records
    tensor-name -> shard-filename for building the final index.json.
    """
    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir, self.output_repo, self.subfolder, self.hf_token, self.filename_prefix = output_dir, output_repo, subfolder, hf_token, filename_prefix
        # buffer: pending tensor records; total_size: bytes across ALL shards.
        self.buffer, self.current_bytes, self.shard_count, self.index_map, self.total_size = [], 0, 0, {}, 0
    def add_tensor(self, key, tensor):
        # Serialize to raw bytes with the matching safetensors dtype tag.
        # bfloat16 has no numpy dtype, so reinterpret the bits as int16 first.
        if tensor.dtype == torch.bfloat16: raw, dt = tensor.view(torch.int16).numpy().tobytes(), "BF16"
        elif tensor.dtype == torch.float16: raw, dt = tensor.numpy().tobytes(), "F16"
        # NOTE(review): any other dtype is tagged F32 — integer tensors would
        # be mis-tagged here; confirm callers only ever pass float tensors.
        else: raw, dt = tensor.numpy().tobytes(), "F32"
        self.buffer.append({"key": key, "data": raw, "dtype": dt, "shape": tensor.shape})
        self.current_bytes += len(raw)
        self.total_size += len(raw)
        if self.current_bytes >= self.max_bytes: self.flush()
    def flush(self):
        # Write the buffered tensors as one shard, upload it, then reset.
        if not self.buffer: return
        self.shard_count += 1
        fname = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        # Build the safetensors JSON header with per-tensor byte offsets.
        header = {"__metadata__": {"format": "pt"}}
        curr_off = 0
        for i in self.buffer:
            header[i["key"]] = {"dtype": i["dtype"], "shape": i["shape"], "data_offsets": [curr_off, curr_off + len(i["data"])]}
            curr_off += len(i["data"])
            self.index_map[i["key"]] = fname
        out = self.output_dir / fname
        header_json = json.dumps(header).encode('utf-8')
        with open(out, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))  # 8-byte LE header length
            f.write(header_json)
            for i in self.buffer: f.write(i["data"])
        api.upload_file(path_or_fileobj=out, path_in_repo=f"{self.subfolder}/{fname}" if self.subfolder else fname, repo_id=self.output_repo, token=self.hf_token)
        os.remove(out)  # free local disk immediately after upload
        self.buffer, self.current_bytes = [], 0
        gc.collect()
203
+
204
def task_merge_legacy(hf_token, base, sub, lora, scale, prec, shard, out, struct_s, priv, progress=gr.Progress()):
    """Bake a LoRA into a base model's safetensors shards and re-upload resharded.

    Args:
        hf_token: HF access token (optional).
        base: source repo id holding the base safetensors.
        sub: optional subfolder inside the repo to restrict to.
        lora: LoRA repo id / URL, resolved via download_lora_smart.
        scale: extra multiplier applied on top of the alpha/rank scaling.
        prec: "bf16" / "fp16"; anything else means float32.
        shard: max output shard size in GB.
        out: destination repo id.
        struct_s: optional repo whose non-weight files are copied over first.
        priv: create the destination repo as private.

    Returns a status string.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    try: api.create_repo(repo_id=out, private=priv, exist_ok=True, token=hf_token)
    except Exception as e: return f"Error: {e}"
    if struct_s:
        # Best-effort copy of configs/tokenizers/etc. from a template repo,
        # skipping weight files (and, when a subfolder is set, that subfolder).
        try:
            files = api.list_repo_files(repo_id=struct_s, token=hf_token)
            for f in tqdm(files, desc="Copying Structure"):
                if sub and f.startswith(sub): continue
                if not sub and any(f.endswith(x) for x in ['.safetensors', '.bin', '.pt', '.pth']): continue
                l = hf_hub_download(repo_id=struct_s, filename=f, token=hf_token, local_dir=TempDir)
                api.upload_file(path_or_fileobj=l, path_in_repo=f, repo_id=out, token=hf_token)
        except Exception:
            pass  # structure copy is optional; the merge itself proceeds

    files = [f for f in list_repo_files(repo_id=base, token=hf_token) if f.endswith(".safetensors")]
    if sub: files = [f for f in files if f.startswith(sub)]
    if not files: return "No safetensors found"

    prefix = "diffusion_pytorch_model" if (sub in ["transformer", "unet"] or "diffusion_pytorch_model" in os.path.basename(files[0])) else "model"
    dtype = torch.bfloat16 if prec == "bf16" else torch.float16 if prec == "fp16" else torch.float32
    try: lora_pairs = load_lora_to_memory(download_lora_smart(lora, hf_token), dtype)
    except Exception as e: return f"LoRA Error: {e}"

    buf = ShardBuffer(shard, TempDir, out, sub, hf_token, prefix)
    for i, fpath in enumerate(files):
        local = hf_hub_download(repo_id=base, filename=fpath, token=hf_token, local_dir=TempDir)
        with MemoryEfficientSafeOpen(local) as f:
            for k in f.keys():
                v = f.get_tensor(k)
                stem = get_key_stem(k)
                # Also try fused-qkv stems so attention LoRAs keyed on "qkv" match.
                match = lora_pairs.get(stem) or lora_pairs.get(stem.replace("to_q", "qkv")) or lora_pairs.get(stem.replace("to_k", "qkv")) or lora_pairs.get(stem.replace("to_v", "qkv"))
                if match:
                    d, u = match["down"], match["up"]
                    s = scale * (match["alpha"] / match["rank"])
                    # Promote 2-D LoRA factors to 1x1 convs for 4-D base weights.
                    if len(v.shape) == 4 and len(d.shape) == 2:
                        d, u = d.unsqueeze(-1).unsqueeze(-1), u.unsqueeze(-1).unsqueeze(-1)
                    # BUGFIX: the original referenced undefined names `up`/`down`
                    # here, raising NameError on every matched tensor; the local
                    # factors are `u`/`d`.
                    delta = (u.squeeze() @ d.squeeze()).reshape(u.shape[0], d.shape[1], 1, 1) if len(u.shape) == 4 else u @ d
                    v = v.to(dtype).add_((delta * s).to(dtype))
                buf.add_tensor(k, v.to(dtype))
        os.remove(local)
    buf.flush()
    idx = {"metadata": {"total_size": buf.total_size}, "weight_map": buf.index_map}
    idx_n = f"{prefix}.safetensors.index.json"
    with open(TempDir/idx_n, "w") as f: json.dump(idx, f, indent=4)
    api.upload_file(path_or_fileobj=TempDir/idx_n, path_in_repo=f"{sub}/{idx_n}" if sub else idx_n, repo_id=out, token=hf_token)
    return "Done"
250
+
251
def task_extract(hf_token, org, tun, rank, out):
    """Extract a low-rank (LoRA) adapter from the diff of two checkpoints.

    Downloads both models, computes (tuned - original) for every shared
    tensor, compresses each diff with truncated SVD at `rank`, and uploads
    the resulting adapter to `out`. Returns a status string.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    try:
        p1 = download_lora_smart(org, hf_token)
        p2 = download_lora_smart(tun, hf_token)
        lora_sd = {}
        # FIX: use context managers so both readers are closed even on error
        # (the original leaked the two open file handles).
        with MemoryEfficientSafeOpen(p1) as org_f, MemoryEfficientSafeOpen(p2) as tun_f:
            common = set(org_f.keys()) & set(tun_f.keys())
            for k in tqdm(common, desc="Extracting"):
                # Skip BatchNorm running statistics — not adapter material.
                if "num_batches_tracked" in k or "running_mean" in k or "running_var" in k: continue
                m1, m2 = org_f.get_tensor(k).float(), tun_f.get_tensor(k).float()
                if m1.shape != m2.shape: continue
                diff = m2 - m1
                if torch.max(torch.abs(diff)) < 1e-4: continue  # effectively unchanged
                out_d, in_d = diff.shape[0], diff.shape[1] if len(diff.shape) > 1 else 1
                r = min(int(rank), in_d, out_d)
                # SVD operates on 2-D matrices: flatten convs, lift 1-D tensors.
                if len(diff.shape) == 4: diff = diff.flatten(1)
                elif len(diff.shape) == 1: diff = diff.unsqueeze(1)
                U, S, V = torch.svd_lowrank(diff, q=r+4, niter=4)
                Vh = V.t()
                U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
                U = U @ torch.diag(S)  # fold singular values into the up factor
                # Clamp extreme outliers at the 99th percentile for stability.
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(torch.abs(dist), 0.99)
                if hi_val > 0: U, Vh = U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)
                if len(m1.shape) == 4:
                    U = U.reshape(out_d, r, 1, 1)
                    Vh = Vh.reshape(r, in_d, m1.shape[2], m1.shape[3])
                else:
                    U, Vh = U.reshape(out_d, r), Vh.reshape(r, in_d)
                stem = k.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
                lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        out_f = TempDir/"extracted.safetensors"
        save_file(lora_sd, out_f)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out_f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
        return "Done"
    except Exception as e: return f"Error: {e}"
292
+
293
def load_full_state_dict(path):
    """Load a safetensors state dict, renaming PEFT-style ``lora_A``/``lora_B``
    keys to kohya-style ``lora_down``/``lora_up`` and casting to float32."""
    renames = (("lora_A", "lora_down"), ("lora_B", "lora_up"))
    result = {}
    for key, tensor in load_file(path, device="cpu").items():
        for old, new in renames:
            if old in key:
                key = key.replace(old, new)
                break
        result[key] = tensor.float()
    return result
302
+
303
def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
    """Merge several LoRA adapters into one file and upload it.

    `method` selects one of three strategies (matched by substring):
      * "Iterative EMA": exponential moving average over the raw state dicts.
      * "Concatenation": stack down/up factors (ranks add up).
      * "SVD": sum the reconstructed full-weight deltas, then re-compress
        each layer at `target_rank`.
    Returns a status string.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    urls = [line.strip() for line in inputs_text.replace(" ", "\n").split('\n') if line.strip()]
    if len(urls) < 2: return "Error: Provide at least 2 adapters."
    try: weights = [float(w.strip()) for w in weight_str.split(',')] if weight_str.strip() else [1.0] * len(urls)
    except: return "Error parsing weights."
    # Pad missing weights with 1.0 so every adapter has one.
    if len(weights) < len(urls): weights += [1.0] * (len(urls) - len(weights))

    paths = []
    for url in tqdm(urls, desc="Downloading"): paths.append(download_lora_smart(url, hf_token))

    merged = {}
    if "Iterative EMA" in method:
        base_sd = load_file(paths[0], device="cpu")
        gamma = None
        if sigma_rel > 0:
            # Derive the EMA exponent gamma from sigma_rel (post-hoc EMA style):
            # solve the cubic and keep the largest non-negative real root.
            t_val = sigma_rel**-2
            roots = np.roots([1, 7, 16 - t_val, 12 - t_val])
            gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
        for i, path in enumerate(paths[1:]):
            # With gamma set, beta grows toward 1 as more adapters are folded
            # in; otherwise use the fixed user-supplied beta.
            current_beta = (1 - 1 / (i + 1)) ** (gamma + 1) if gamma is not None else beta
            curr = load_file(path, device="cpu")
            for k in base_sd:
                if k in curr and "alpha" not in k:
                    base_sd[k] = base_sd[k].float() * current_beta + curr[k].float() * (1 - current_beta)
        merged = base_sd
    else:
        states = [load_full_state_dict(p) for p in paths]
        # Collect every layer stem that has LoRA factors in any adapter.
        all_stems = set()
        for s in states:
            for k in s:
                if "lora_" in k: all_stems.add(k.split(".lora_")[0])
        for stem in tqdm(all_stems):
            down_list, up_list = [], []
            alpha_sum, total_delta = 0.0, None
            for i, state in enumerate(states):
                w = weights[i]
                dk, uk, ak = f"{stem}.lora_down.weight", f"{stem}.lora_up.weight", f"{stem}.alpha"
                if dk in state and uk in state:
                    d, u = state[dk], state[uk]
                    alpha_sum += state[ak].item() if ak in state else d.shape[0]
                    if "Concatenation" in method:
                        # Stack factors; the per-adapter weight is folded into "up".
                        down_list.append(d); up_list.append(u * w)
                    elif "SVD" in method:
                        rank = d.shape[0]
                        alpha = state[ak].item() if ak in state else rank
                        scale = (alpha / rank) * w
                        # Reconstruct this adapter's full delta (convs are
                        # flattened for the matmul, then reshaped back).
                        delta = ((u.flatten(1) @ d.flatten(1)).reshape(u.shape[0], d.shape[1], d.shape[2], d.shape[3]) if len(d.shape)==4 else u @ d) * scale
                        total_delta = delta if total_delta is None else total_delta + delta
            if "Concatenation" in method and down_list:
                merged[f"{stem}.lora_down.weight"] = torch.cat(down_list, dim=0).contiguous()
                merged[f"{stem}.lora_up.weight"] = torch.cat(up_list, dim=1).contiguous()
                merged[f"{stem}.alpha"] = torch.tensor(alpha_sum)
            elif "SVD" in method and total_delta is not None:
                tr = int(target_rank)
                flat = total_delta.flatten(1) if len(total_delta.shape)==4 else total_delta
                try:
                    U, S, V = torch.svd_lowrank(flat, q=tr + 4, niter=4)
                    Vh = V.t()
                    U, S, Vh = U[:, :tr], S[:tr], Vh[:tr, :]
                    U = U @ torch.diag(S)  # fold singular values into the up factor
                    if len(total_delta.shape) == 4:
                        U = U.reshape(total_delta.shape[0], tr, 1, 1)
                        Vh = Vh.reshape(tr, total_delta.shape[1], total_delta.shape[2], total_delta.shape[3])
                    else:
                        U, Vh = U.reshape(total_delta.shape[0], tr), Vh.reshape(tr, total_delta.shape[1])
                    merged[f"{stem}.lora_down.weight"] = Vh.contiguous()
                    merged[f"{stem}.lora_up.weight"] = U.contiguous()
                    merged[f"{stem}.alpha"] = torch.tensor(tr).float()
                except: pass  # skip layers where the low-rank SVD fails

    out = TempDir / "merged_adapters.safetensors"
    if merged: save_file(merged, out)
    api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return f"Success! Merged to {out_repo}"
380
+
381
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """Re-rank (shrink) a LoRA via SVD and upload the result.

    `dynamic_method` optionally derives a per-layer rank from the singular
    values ("sv_ratio", "sv_cumulative", "sv_fro"); `new_rank` is always the
    hard upper bound. Returns a status string.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    path = download_lora_smart(lora_input, hf_token)
    state = load_file(path, device="cpu")
    new_state = {}
    # Group down/up/alpha tensors by their shared layer stem.
    groups = {}
    for k in state:
        simple = k.split(".lora_")[0]
        if simple not in groups: groups[simple] = {}
        if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
        if "alpha" in k: groups[simple]["alpha"] = state[k]

    target_rank_limit = int(new_rank)
    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()
            # Reconstruct the full delta before re-factorizing.
            # NOTE(review): the conv branch squeezes both factors, which only
            # works for 1x1 kernels — confirm for non-1x1 conv LoRAs.
            merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3]) if len(down.shape)==4 else up @ down
            flat = merged.flatten(1)
            # Oversampled low-rank SVD: q > limit leaves headroom for the
            # dynamic rank criteria below.
            U, S, V = torch.svd_lowrank(flat, q=target_rank_limit + 32)
            Vh = V.t()
            calc_rank = target_rank_limit
            if dynamic_method == "sv_ratio":
                # Keep singular values within a factor dynamic_param of the largest.
                calc_rank = int(torch.sum(S > (S[0] / dynamic_param)).item())
            elif dynamic_method == "sv_cumulative":
                # Keep enough values to cover dynamic_param of the total S mass.
                calc_rank = int(torch.searchsorted(torch.cumsum(S, 0) / torch.sum(S), dynamic_param)) + 1
            elif dynamic_method == "sv_fro":
                # Frobenius-norm criterion: energy fraction in S^2.
                calc_rank = int(torch.searchsorted(torch.cumsum(S.pow(2), 0) / torch.sum(S.pow(2)), dynamic_param**2)) + 1
            final_rank = max(1, min(calc_rank, target_rank_limit, S.shape[0]))
            U = U[:, :final_rank] @ torch.diag(S[:final_rank])
            Vh = Vh[:final_rank, :]
            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], final_rank, 1, 1)
                Vh = Vh.reshape(final_rank, down.shape[1], down.shape[2], down.shape[3])
            new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
            new_state[f"{stem}.lora_up.weight"] = U.contiguous()
            new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float()
    out = TempDir / "shrunken.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
424
+
425
+ # =================================================================================
426
+ # MERGEKIT & STREAMING LOGS (TABS 5-9)
427
+ # =================================================================================
428
+
429
def parse_weight(w_str):
    """Parse a UI weight field.

    Blank input -> 1.0; a string containing '[' is parsed as YAML (per-layer
    weight lists); anything else is parsed as a float. Any parse failure
    falls back to 1.0 instead of raising.
    """
    if not w_str.strip(): return 1.0
    try:
        if "[" in w_str: return yaml.safe_load(w_str)
        return float(w_str)
    except Exception:  # narrowed from bare `except:` — keep KeyboardInterrupt alive
        return 1.0
435
+
436
def run_mergekit_logic(config_dict, token, out_repo, private, shard_size, output_precision, tokenizer_source, chat_template, program="mergekit-yaml"):
    """Validate a mergekit config, run the CLI merge, and upload the result.

    Generator: yields the full accumulated log text after every step/line so
    a gr.Textbox can stream progress. `program` selects the CLI entry point
    ("mergekit-yaml" or "mergekit-moe").
    """
    # Using generator for streaming logs directly to a Textbox, bypassing component issues
    logs = []
    def log(msg):
        logs.append(msg)
        return "\n".join(logs)

    yield log("Starting MergeKit Process...")

    cleanup_temp()

    if chat_template and chat_template.strip():
        config_dict["chat_template"] = chat_template.strip()

    # Validation (mergekit-moe configs use a different schema, so skip them).
    try:
        if program != "mergekit-moe":
            MergeConfiguration.model_validate(config_dict)
        yield log("Configuration Validated Successfully.")
    except Exception as e:
        yield log(f"Invalid Config: {e}")
        return

    if token:
        login(token.strip())
        os.environ["HF_TOKEN"] = token.strip()  # also expose to the subprocess

    # Fill defaults without overriding anything the caller set explicitly.
    if "dtype" not in config_dict: config_dict["dtype"] = output_precision
    if "tokenizer_source" not in config_dict and tokenizer_source != "base":
        config_dict["tokenizer_source"] = tokenizer_source

    config_path = TempDir / "config.yaml"
    with open(config_path, "w") as f: yaml.dump(config_dict, f, sort_keys=False)

    yield log(f"Config saved to {config_path}")
    yield log(f"YAML:\n{yaml.dump(config_dict, sort_keys=False)}")

    try:
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=token)
        yield log(f"Repo {out_repo} ready.")
    except Exception as e:
        yield log(f"Repo Creation Error (might exist): {e}")

    out_path = TempDir / "merge_output"
    # mergekit expects the shard size in megabytes, e.g. "2048M".
    shard_arg = f"{int(float(shard_size) * 1024)}M"

    cmd = [
        program,
        str(config_path),
        str(out_path),
        "--allow-crimes",
        "--copy-tokenizer",
        "--out-shard-size", shard_arg,
        "--lazy-unpickle"
    ]

    if torch.cuda.is_available():
        cmd.extend(["--cuda", "--low-cpu-memory"])

    yield log(f"Executing: {' '.join(cmd)}")

    env = os.environ.copy()
    env["HF_HOME"] = str(TempDir / ".cache")  # keep the hub cache inside TempDir

    # Run process, streaming combined stdout/stderr line by line.
    process = subprocess.Popen(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

    for line in iter(process.stdout.readline, ""):
        yield log(line.strip())

    process.wait()

    if process.returncode != 0:
        yield log("Merge failed with exit code " + str(process.returncode))
        return

    yield log(f"Uploading to {out_repo}...")
    try:
        api.upload_folder(repo_id=out_repo, folder_path=out_path)
        yield log("Upload Complete!")
    except Exception as e:
        yield log(f"Upload failed: {e}")
518
+
519
+ # --- UI Wrappers for Tabs 5-9 ---
520
+
521
def wrapper_amphinterpolative(token, method, base, t, norm, i8, flat, row, eps, m_iter, tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
    """Build a mergekit config for interpolative methods (slerp / nuslerp /
    multislerp / karcher) and stream the merge logs."""
    params = {"normalize": norm, "int8_mask": i8}
    slerp_family = method in ("slerp", "nuslerp")
    if slerp_family:
        params["t"] = float(t)
    if method == "nuslerp":
        params["flatten"] = flat
        params["row_wise"] = row
    elif method == "multislerp":
        params["eps"] = float(eps)
    elif method == "karcher":
        params["max_iter"] = int(m_iter)
        params["tol"] = float(tol)

    config = {"merge_method": method}

    if slerp_family:
        # slerp/nuslerp interpolate two models against a required base.
        if not base.strip():
            yield "Error: Base model required"
            return
        config["base_model"] = base.strip()
        sources = []
        for model, weight in ((m1, w1), (m2, w2)):
            if model.strip():
                sources.append({"model": model, "parameters": {"weight": parse_weight(weight)}})
        config["slices"] = [{"sources": sources, "parameters": params}]
    else:
        if base.strip() and method == "multislerp":
            config["base_model"] = base.strip()
        config["models"] = [
            {"model": model, "parameters": {"weight": parse_weight(weight)}}
            for model, weight in ((m1, w1), (m2, w2), (m3, w3), (m4, w4), (m5, w5))
            if model.strip()
        ]
        config["parameters"] = params

    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
542
+
543
def wrapper_stirtie(token, method, base, norm, i8, lamb, resc, topk, m1, w1, d1, g1, e1, m2, w2, d2, g2, e2, m3, w3, d3, g3, e3, m4, w4, d4, g4, e4, out, priv, shard, prec, tok_src, chat_t):
    """Build a mergekit config for TIES/DARE/breadcrumbs/della-style sparse
    merges and stream the merge logs."""
    density_methods = ("ties", "dare_ties", "dare_linear", "breadcrumbs_ties")
    models = []
    for model, weight, density, gamma, epsilon in (
        (m1, w1, d1, g1, e1),
        (m2, w2, d2, g2, e2),
        (m3, w3, d3, g3, e3),
        (m4, w4, d4, g4, e4),
    ):
        if not model.strip():
            continue
        per_model = {"weight": parse_weight(weight)}
        if method in density_methods:
            per_model["density"] = parse_weight(density)
        if "breadcrumbs" in method:
            per_model["gamma"] = float(gamma)
        if "della" in method:
            per_model["epsilon"] = float(epsilon)
        models.append({"model": model, "parameters": per_model})

    global_params = {"normalize": norm, "int8_mask": i8}
    if method != "sce":
        global_params["lambda"] = float(lamb)
    if method == "dare_linear":
        global_params["rescale"] = resc
    if method == "sce":
        global_params["select_topk"] = float(topk)

    config = {
        "merge_method": method,
        # Default the base model to the first listed model when left blank.
        "base_model": base.strip() if base.strip() else models[0]["model"],
        "parameters": global_params,
        "models": models,
    }
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
571
+
572
def wrapper_specious(token, method, base, norm, i8, t, filt_w, m1, w1, f1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv, shard, prec, tok_src, chat_t):
    """Build a mergekit config for passthrough / nearswap / model_stock-style
    merges and stream the merge logs."""
    if method == "passthrough":
        # Passthrough takes a single model, optionally restricted by a filter.
        per_model = {"weight": parse_weight(w1)}
        if f1.strip():
            per_model["filter"] = f1.strip()
        models = [{"model": m1, "parameters": per_model}]
    else:
        models = [
            {"model": model, "parameters": {"weight": parse_weight(weight)}}
            for model, weight in ((m1, w1), (m2, w2), (m3, w3), (m4, w4), (m5, w5))
            if model.strip()
        ]

    config = {"merge_method": method, "parameters": {"normalize": norm, "int8_mask": i8}}
    if base.strip():
        config["base_model"] = base.strip()
    if method == "nearswap":
        config["parameters"]["t"] = float(t)
    if method == "model_stock":
        config["parameters"]["filter_wise"] = filt_w
    config["models"] = models

    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
588
+
589
def wrapper_moer(token, base, experts, gate, dtype, out, priv, shard, prec, tok_src, chat_t):
    """Build a mergekit-moe config from a newline-separated expert list and
    stream the merge logs."""
    expert_entries = []
    for line in experts.split('\n'):
        name = line.strip()
        if name:
            # Placeholder routing prompts; mergekit-moe requires positive_prompts per expert.
            expert_entries.append({"source_model": name, "positive_prompts": ["chat", "assist"]})
    config = {
        # Default the base to the first expert when left blank.
        "base_model": base.strip() if base.strip() else expert_entries[0]["source_model"],
        "gate_mode": gate,
        "dtype": dtype,
        "experts": expert_entries,
    }
    # Routed through the mergekit-moe CLI rather than mergekit-yaml.
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-moe")
599
+
600
def wrapper_rawer(token, models, method, dtype, out, priv, shard, prec, tok_src, chat_t):
    """Build a minimal mergekit config (all unit weights) from a
    newline-separated model list and stream the merge logs."""
    entries = []
    for line in models.split('\n'):
        name = line.strip()
        if name:
            entries.append({"model": name, "parameters": {"weight": 1.0}})
    config = {
        "models": entries,
        "merge_method": method,
        "dtype": dtype,
    }
    yield from run_mergekit_logic(config, token, out, priv, shard, prec, tok_src, chat_t, program="mergekit-yaml")
608
+
609
+ # --- TAB 10 (Custom DARE) Logic ---
610
def task_dare_custom(token, base, ft, ratio, mask, out, priv):
    """Custom DARE-style merge of two full checkpoints.

    Computes delta = (finetuned - base); with `mask` > 0, randomly drops a
    `mask` fraction of delta entries and rescales survivors by 1/(1-mask)
    (drop-and-rescale), then adds `ratio` * delta back onto the base.
    Returns a status string.
    """
    cleanup_temp()
    if token: login(token.strip())
    try:
        b_path = download_lora_smart(base, token)
        f_path = download_lora_smart(ft, token)
        b_sd = load_file(b_path, device="cpu")
        f_sd = load_file(f_path, device="cpu")
        merged = {}
        common = set(b_sd.keys()) & set(f_sd.keys())
        for k in tqdm(common, desc="Merging"):
            tb, tf = b_sd[k], f_sd[k]
            if tb.shape != tf.shape:
                # Shape mismatch: keep the finetuned tensor unchanged.
                merged[k] = tf
                continue
            delta = tf.float() - tb.float()
            if mask > 0:
                # Bernoulli keep-mask with keep-prob (1 - mask), rescaled so
                # the expected delta magnitude stays unchanged.
                m = torch.bernoulli(torch.full_like(delta, 1.0 - mask))
                delta = (delta * m) / (1.0 - mask)
            merged[k] = (tb.float() + ratio * delta).to(tb.dtype)

        out_f = TempDir / "model.safetensors"
        save_file(merged, out_f)
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_file(path_or_fileobj=out_f, path_in_repo="model.safetensors", repo_id=out, token=token)
        return f"Done! {out}"
    except Exception as e: return str(e)
637
+
638
# =================================================================================
# UI GENERATION
# =================================================================================

css = ".container { max-width: 1100px; margin: auto; }"

# FIX: `css` must be given to the gr.Blocks() constructor. Blocks.launch() has no
# `css` parameter, so the original `launch(css=css, ...)` call would fail with a
# TypeError and the stylesheet was never applied.
with gr.Blocks(css=css) as demo:
    gr.HTML("""<h1><img src="https://huggingface.co/spaces/AlekseyCalvin/Soon_Merger/resolve/main/SMerger3.png" alt="SOONmerge®"> Transform Transformers for FREE!</h1>""")
    gr.Markdown("# 🧰Training-Free CPU-run Model Creation Toolkit")

    with gr.Tabs():
        # --- TAB 1: RESTORED ---
        # Merge a LoRA into a full base model and re-shard the result.
        with gr.Tab("Merge to Base Model + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo", value="name/repo")
            t1_sub = gr.Textbox(label="Subfolder (Optional)", value="")
            t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
            t1_shard = gr.Slider(label="Max Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Extras Source (copies configs/components/etc)", value="name/repo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            gr.Button("Merge").click(task_merge_legacy, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], gr.Textbox(label="Result"))

        # --- TAB 2: RESTORED ---
        # Extract a low-rank adapter as the difference between two homologous models.
        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned or Homologous Model")
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            gr.Button("Extract").click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], gr.Textbox(label="Result"))

        # --- TAB 3: RESTORED ---
        # Batch-merge multiple adapters via EMA, concatenation, or SVD fusion.
        with gr.Tab("Merge Adapters"):
            gr.Markdown("### Batch Adapter Merging")
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.TextArea(label="Adapter URLs/Repos (one per line, or space-separated)")
            t3_method = gr.Dropdown(["Iterative EMA (Linear w/ Beta/Sigma coefficient)", "Concatenation (MOE-like weights-stack)", "SVD Fusion (Task Arithmetic/Compressed)"], value="Iterative EMA (Linear w/ Beta/Sigma coefficient)", label="Merge Method")
            with gr.Row():
                t3_weights = gr.Textbox(label="Weights (comma-separated) – for Concat/SVD")
                t3_rank = gr.Number(label="Target Rank – For SVD only", value=128)
            with gr.Row():
                t3_beta = gr.Slider(label="Beta – for linear/post-hoc EMA", value=0.95, minimum=0.01, maximum=1.00)
                t3_sigma = gr.Slider(label="Sigma Rel – for linear/post-hoc EMA", value=0.21, minimum=0.01, maximum=1.00)
            t3_out = gr.Textbox(label="Output Repo")
            t3_priv = gr.Checkbox(label="Private Output", value=True)
            gr.Button("Merge").click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], gr.Textbox(label="Result"))

        # --- TAB 4: RESTORED ---
        # Resize (re-rank) an existing adapter, with optional dynamic per-layer ranking.
        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="To Rank (Safety Ceiling)", value=8)
                t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=0.9)
            gr.Markdown("### 📉 Dynamic Resizing Guide\nThese methods intelligently determine the best rank per layer.\n- **sv_ratio (Relative Strength):** Keeps features that are at least `1/Param` as strong as the main feature. **Param must be >= 2**.\n- **sv_fro (Visual Information Density):** Preserves `Param%` of total information content. **Param between 0.0 and 1.0**.\n- **sv_cumulative (Cumulative Sum):** Preserves weights that sum up to `Param%` of total strength. **Param between 0.0 and 1.0**.\n- **⚠️ Safety Ceiling:** The **'To Rank'** slider acts as a hard limit.")
            t4_out = gr.Textbox(label="Output")
            gr.Button("Resize").click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], gr.Textbox(label="Result"))

        # --- TAB 5 ---
        # Spherical interpolation merges (slerp family) of full models.
        with gr.Tab("Amphinterpolative"):
            gr.Markdown("### Spherical Interpolation Family")
            t5_token = gr.Textbox(label="HF Token", type="password")
            t5_method = gr.Dropdown(["slerp", "nuslerp", "multislerp", "karcher"], value="slerp", label="Method")
            # NOTE(review): row grouping reconstructed from a mangled diff —
            # shard/precision share a row; tokenizer source and chat template follow.
            with gr.Row():
                t5_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0)
                t5_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
            t5_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source")
            t5_chat = gr.Textbox(label="Chat Template (write-in, default: auto)", placeholder="auto")
            with gr.Row():
                t5_base = gr.Textbox(label="Base Model")
                t5_t = gr.Slider(0, 1, 0.5, label="t")
            with gr.Row():
                t5_norm = gr.Checkbox(label="Normalize", value=True)
                t5_i8 = gr.Checkbox(label="Int8 Mask", value=False)
                t5_flat = gr.Checkbox(label="NuSlerp Flatten", value=False)
                t5_row = gr.Checkbox(label="NuSlerp Row Wise", value=False)
            with gr.Row():
                t5_eps = gr.Textbox(label="eps", value="1e-8")
                t5_iter = gr.Number(label="max_iter", value=10)
                t5_tol = gr.Textbox(label="tol", value="1e-5")
            m1 = gr.Textbox(label="Model 1")
            w1 = gr.Textbox(label="Weight 1", value="1.0")
            m2 = gr.Textbox(label="Model 2")
            w2 = gr.Textbox(label="Weight 2", value="1.0")
            with gr.Accordion("More", open=False):
                m3 = gr.Textbox(label="Model 3")
                w3 = gr.Textbox(label="Weight 3", value="1.0")
                m4 = gr.Textbox(label="Model 4")
                w4 = gr.Textbox(label="Weight 4", value="1.0")
                m5 = gr.Textbox(label="Model 5")
                w5 = gr.Textbox(label="Weight 5", value="1.0")
            t5_out = gr.Textbox(label="Output Repo")
            t5_priv = gr.Checkbox(label="Private", value=True)
            t5_btn = gr.Button("Execute")
            t5_res = gr.Textbox(label="Result", lines=10)
            t5_btn.click(wrapper_amphinterpolative, [t5_token, t5_method, t5_base, t5_t, t5_norm, t5_i8, t5_flat, t5_row, t5_eps, t5_iter, t5_tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, t5_out, t5_priv, t5_shard, t5_prec, t5_tok, t5_chat], t5_res)

        # --- TAB 6 ---
        # Task-vector merges (ties/dare/della/breadcrumbs/sce) of full models.
        with gr.Tab("Stir/Tie Bases"):
            gr.Markdown("### Task Vector Family")
            t6_token = gr.Textbox(label="Token", type="password")
            t6_method = gr.Dropdown(["task_arithmetic", "ties", "dare_ties", "dare_linear", "della", "della_linear", "breadcrumbs", "breadcrumbs_ties", "sce"], value="ties", label="Method")
            with gr.Row():
                t6_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0)
                t6_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
                t6_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source")
                t6_chat = gr.Textbox(label="Chat Template", placeholder="auto")
            t6_base = gr.Textbox(label="Base Model")
            with gr.Row():
                t6_norm = gr.Checkbox(label="Normalize", value=True)
                t6_i8 = gr.Checkbox(label="Int8 Mask", value=False)
                t6_resc = gr.Checkbox(label="Rescale", value=True)
                t6_lamb = gr.Number(label="Lambda", value=1.0)
                t6_topk = gr.Slider(0, 1, 1.0, label="Select TopK")
            m1_6 = gr.Textbox(label="Model 1")
            w1_6 = gr.Textbox(label="Weight 1", value="1.0")
            d1_6 = gr.Textbox(label="Density", value="1.0")
            g1_6 = gr.Number(label="Gamma", value=0.01)
            e1_6 = gr.Number(label="Epsilon", value=0.15)
            with gr.Accordion("More", open=False):
                m2_6 = gr.Textbox(label="Model 2")
                w2_6 = gr.Textbox(label="Weight 2", value="1.0")
                d2_6 = gr.Textbox(label="Density", value="1.0")
                g2_6 = gr.Number(label="Gamma", value=0.01)
                e2_6 = gr.Number(label="Epsilon", value=0.15)
                m3_6 = gr.Textbox(label="Model 3")
                w3_6 = gr.Textbox(label="Weight 3", value="1.0")
                d3_6 = gr.Textbox(label="Density", value="1.0")
                g3_6 = gr.Number(label="Gamma", value=0.01)
                e3_6 = gr.Number(label="Epsilon", value=0.15)
                m4_6 = gr.Textbox(label="Model 4")
                w4_6 = gr.Textbox(label="Weight 4", value="1.0")
                d4_6 = gr.Textbox(label="Density", value="1.0")
                g4_6 = gr.Number(label="Gamma", value=0.01)
                e4_6 = gr.Number(label="Epsilon", value=0.15)
            t6_out = gr.Textbox(label="Output Repo")
            t6_priv = gr.Checkbox(label="Private", value=True)
            t6_btn = gr.Button("Execute")
            t6_res = gr.Textbox(label="Result", lines=10)
            t6_btn.click(wrapper_stirtie, [t6_token, t6_method, t6_base, t6_norm, t6_i8, t6_lamb, t6_resc, t6_topk, m1_6, w1_6, d1_6, g1_6, e1_6, m2_6, w2_6, d2_6, g2_6, e2_6, m3_6, w3_6, d3_6, g3_6, e3_6, m4_6, w4_6, d4_6, g4_6, e4_6, t6_out, t6_priv, t6_shard, t6_prec, t6_tok, t6_chat], t6_res)

        # --- TAB 7 ---
        # Specialized merge methods (model_stock, nearswap, fusion, passthrough, linear).
        with gr.Tab("Specious"):
            gr.Markdown("### Specialized Methods")
            t7_token = gr.Textbox(label="Token", type="password")
            t7_method = gr.Dropdown(["model_stock", "nearswap", "arcee_fusion", "passthrough", "linear"], value="model_stock", label="Method")
            with gr.Row():
                t7_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0)
                t7_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
                t7_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source")
                t7_chat = gr.Textbox(label="Chat Template", placeholder="auto")
            t7_base = gr.Textbox(label="Base Model")
            with gr.Row():
                t7_norm = gr.Checkbox(label="Normalize", value=True)
                t7_i8 = gr.Checkbox(label="Int8 Mask", value=False)
                t7_t = gr.Slider(0, 1, 0.5, label="t")
                t7_filt_w = gr.Checkbox(label="Filter Wise", value=False)
            m1_7 = gr.Textbox(label="Model 1")
            w1_7 = gr.Textbox(label="Weight 1", value="1.0")
            f1_7 = gr.Textbox(label="Filter (Passthrough)")
            m2_7 = gr.Textbox(label="Model 2")
            w2_7 = gr.Textbox(label="Weight 2", value="1.0")
            with gr.Accordion("More", open=False):
                m3_7 = gr.Textbox(label="Model 3")
                w3_7 = gr.Textbox(label="Weight 3", value="1.0")
                m4_7 = gr.Textbox(label="Model 4")
                w4_7 = gr.Textbox(label="Weight 4", value="1.0")
                m5_7 = gr.Textbox(label="Model 5")
                w5_7 = gr.Textbox(label="Weight 5", value="1.0")
            t7_out = gr.Textbox(label="Output Repo")
            t7_priv = gr.Checkbox(label="Private", value=True)
            t7_btn = gr.Button("Execute")
            t7_res = gr.Textbox(label="Result", lines=10)
            t7_btn.click(wrapper_specious, [t7_token, t7_method, t7_base, t7_norm, t7_i8, t7_t, t7_filt_w, m1_7, w1_7, f1_7, m2_7, w2_7, m3_7, w3_7, m4_7, w4_7, m5_7, w5_7, t7_out, t7_priv, t7_shard, t7_prec, t7_tok, t7_chat], t7_res)

        # --- TAB 8 (MoEr) ---
        # Assemble a Mixture-of-Experts model from a base plus a list of experts.
        with gr.Tab("MoEr"):
            gr.Markdown("### Mixture of Experts")
            t8_token = gr.Textbox(label="Token", type="password")
            with gr.Row():
                t8_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0)
                t8_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
                t8_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source")
                t8_chat = gr.Textbox(label="Chat Template", placeholder="auto")
            t8_base = gr.Textbox(label="Base Model")
            t8_experts = gr.TextArea(label="Experts List")
            t8_gate = gr.Dropdown(["cheap_embed", "random", "hidden"], value="cheap_embed", label="Gate Mode")
            t8_dtype = gr.Dropdown(["float16", "bfloat16"], value="bfloat16", label="Internal Dtype")
            t8_out = gr.Textbox(label="Output Repo")
            t8_priv = gr.Checkbox(label="Private", value=True)
            t8_btn = gr.Button("Build MoE")
            t8_res = gr.Textbox(label="Result", lines=10)
            t8_btn.click(wrapper_moer, [t8_token, t8_base, t8_experts, t8_gate, t8_dtype, t8_out, t8_priv, t8_shard, t8_prec, t8_tok, t8_chat], t8_res)

        # --- TAB 9 (Rawer) ---
        # Raw PyTorch merging for checkpoints without a transformers layout.
        with gr.Tab("Rawer"):
            gr.Markdown("### Raw PyTorch / Non-Transformer")
            t9_token = gr.Textbox(label="Token", type="password")
            t9_models = gr.TextArea(label="Models (one per line)")
            with gr.Row():
                t9_shard = gr.Slider(label="Max Shard Size (GB)", value=5.0, minimum=1.0, maximum=20.0)
                t9_prec = gr.Dropdown(["float16", "bfloat16", "float32"], value="bfloat16", label="Output Precision")
                t9_tok = gr.Dropdown(["base", "union", "model:path"], value="base", label="Tokenizer Source")
                t9_chat = gr.Textbox(label="Chat Template", placeholder="auto")
            t9_method = gr.Dropdown(["linear", "passthrough"], value="linear", label="Method")
            t9_dtype = gr.Dropdown(["float32", "float16", "bfloat16"], value="float32", label="Config Dtype")
            t9_out = gr.Textbox(label="Output Repo")
            t9_priv = gr.Checkbox(label="Private", value=True)
            t9_btn = gr.Button("Merge Raw")
            t9_res = gr.Textbox(label="Result", lines=10)
            t9_btn.click(wrapper_rawer, [t9_token, t9_models, t9_method, t9_dtype, t9_out, t9_priv, t9_shard, t9_prec, t9_tok, t9_chat], t9_res)

        # --- TAB 10 ---
        # Custom DARE (Drop And REscale) merge of a fine-tune onto its base.
        with gr.Tab("Mario,DARE!"):
            t10_token = gr.Textbox(label="Token", type="password")
            with gr.Row():
                t10_base = gr.Textbox(label="Base Model")
                t10_ft = gr.Textbox(label="Fine-Tuned Model")
            with gr.Row():
                t10_ratio = gr.Slider(0, 5, 1.0, label="Ratio")
                t10_mask = gr.Slider(0, 0.99, 0.5, label="Mask Rate")
            t10_out = gr.Textbox(label="Output Repo")
            t10_priv = gr.Checkbox(label="Private", value=True)
            gr.Button("Run").click(task_dare_custom, [t10_token, t10_base, t10_ft, t10_ratio, t10_mask, t10_out, t10_priv], gr.Textbox(label="Result"))

if __name__ == "__main__":
    # css moved to the gr.Blocks() constructor above; launch() does not accept it.
    demo.queue().launch(ssr_mode=False)