AlekseyCalvin commited on
Commit
f42b32a
·
verified ·
1 Parent(s): 5c39fe6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +1039 -0
app.py ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import gc
5
+ import shutil
6
+ import requests
7
+ import json
8
+ import struct
9
+ import numpy as np
10
+ import yaml
11
+ import subprocess
12
+ import shlex
13
+ from pathlib import Path
14
+ from typing import Dict, Any, Optional, List
15
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login, get_repo_discussions
16
+ from safetensors.torch import load_file, save_file, safe_open
17
+ from tqdm import tqdm
18
+
19
+ # --- Memory Efficient Safetensors ---
20
+ class MemoryEfficientSafeOpen:
21
+ def __init__(self, filename):
22
+ self.filename = filename
23
+ self.file = open(filename, "rb")
24
+ self.header, self.header_size = self._read_header()
25
+
26
+ def __enter__(self):
27
+ return self
28
+
29
+ def __exit__(self, exc_type, exc_val, exc_tb):
30
+ self.file.close()
31
+
32
+ def keys(self) -> list[str]:
33
+ return [k for k in self.header.keys() if k != "__metadata__"]
34
+
35
+ def metadata(self) -> Dict[str, str]:
36
+ return self.header.get("__metadata__", {})
37
+
38
+ def get_tensor(self, key):
39
+ if key not in self.header:
40
+ raise KeyError(f"Tensor '{key}' not found in the file")
41
+ metadata = self.header[key]
42
+ offset_start, offset_end = metadata["data_offsets"]
43
+ self.file.seek(self.header_size + 8 + offset_start)
44
+ tensor_bytes = self.file.read(offset_end - offset_start)
45
+ return self._deserialize_tensor(tensor_bytes, metadata)
46
+
47
+ def _read_header(self):
48
+ header_size = struct.unpack("<Q", self.file.read(8))[0]
49
+ header_json = self.file.read(header_size).decode("utf-8")
50
+ return json.loads(header_json), header_size
51
+
52
+ def _deserialize_tensor(self, tensor_bytes, metadata):
53
+ dtype_map = {
54
+ "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
55
+ "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
56
+ "U8": torch.uint8, "BOOL": torch.bool
57
+ }
58
+ dtype = dtype_map[metadata["dtype"]]
59
+ shape = metadata["shape"]
60
+ return torch.frombuffer(tensor_bytes, dtype=torch.uint8).view(dtype).reshape(shape)
61
+
62
# --- Constants & Setup ---
# Prefer /tmp for scratch space; fall back to the working directory when
# /tmp is unavailable (e.g. a read-only or restricted container).
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:  # narrowed from a bare `except:` — only filesystem errors matter here
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

# Shared Hub client used by every task below.
api = HfApi()
71
+
72
def cleanup_temp():
    """Wipe the scratch directory, recreate it empty, and force a GC pass."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    # Recreate unconditionally so callers can rely on the directory existing.
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()
77
+
78
def get_key_stem(key):
    """Normalize a checkpoint/LoRA key to a comparable "stem".

    Removes weight/bias/LoRA-factor/alpha suffix tokens (anywhere in the
    key) and then repeatedly strips known framework prefixes until none
    match, so base-model and adapter keys can be matched to each other.
    """
    # Drop parameter-role tokens (kohya and PEFT naming variants alike).
    for token in (".weight", ".bias", ".lora_down", ".lora_up",
                  ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(token, "")
    known_prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    )
    # Keep stripping until no prefix applies (prefixes can be nested).
    stripped = True
    while stripped:
        stripped = False
        for prefix in known_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key
95
+
96
+ # --- Helper Functions for Download ---
97
def parse_hf_url(url):
    """Split a Hugging Face "resolve" URL into (repo_id, filename).

    Expects URLs shaped like
    ``https://huggingface.co/<user>/<repo>/resolve/<rev>/<path/to/file>?...``.
    Returns (None, None) for anything that does not match.
    """
    if "huggingface.co" in url and "resolve" in url:
        try:
            parts = url.split("huggingface.co/")[-1].split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            # parts[2] == "resolve", parts[3] == revision; the remainder is
            # the file path — strip any query string (e.g. ?download=true).
            filename = "/".join(parts[4:]).split("?")[0]
            return repo_id, filename
        except IndexError:  # narrowed from bare `except:` — too few path segments
            return None, None
    return None, None
107
+
108
def download_lora_smart(input_str, token):
    """Download a LoRA adapter into TempDir/adapter.safetensors.

    *input_str* may be: a direct huggingface.co "resolve" URL, a
    "user/repo/path/file.safetensors" string, a bare repo id, or any
    plain HTTP URL. Strategies are tried in that order; the last Hub
    error is re-raised if the raw-HTTP fallback also fails.

    Returns:
        Path to the downloaded adapter file.
    """
    local_path = TempDir / "adapter.safetensors"
    if local_path.exists(): os.remove(local_path)

    # Strategy 1: explicit resolve-URL.
    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        try:
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            # hf_hub_download may nest the file; locate it and normalize the name.
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        except: pass  # best-effort: fall through to the next strategy
    try:
        # Strategy 2: "user/repo/sub/file.safetensors" shorthand.
        if ".safetensors" in input_str and input_str.count("/") >= 2:
            parts = input_str.split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[2:])
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        # Strategy 3: bare repo id — prefer conventional adapter filenames,
        # otherwise take the first .safetensors file in the repo.
        candidates = ["adapter_model.safetensors", "model.safetensors"]
        files = list_repo_files(repo_id=input_str, token=token)
        target = next((f for f in files if f in candidates), None)
        if not target:
            safes = [f for f in files if f.endswith(".safetensors")]
            if safes: target = safes[0]
        if not target: raise ValueError("No safetensors found")
        hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
        found = list(TempDir.rglob(target.split("/")[-1]))[0]
        if found != local_path: shutil.move(found, local_path)
        return local_path
    except Exception as e:
        # Strategy 4: raw HTTP download for non-Hub URLs.
        if input_str.startswith("http"):
            try:
                headers = {"Authorization": f"Bearer {token}"} if token else {}
                r = requests.get(input_str, stream=True, headers=headers, timeout=60)
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
                return local_path
            except: pass  # surface the original Hub error below instead
        raise e
151
+
152
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """Load a LoRA safetensors file into per-layer factor dicts.

    Returns a mapping ``stem -> {"down", "up", "rank", "alpha"}`` where stems
    come from get_key_stem(), factors are cast to *precision_dtype*, and
    alpha defaults to the rank (i.e. scaling factor 1.0) when absent.
    """
    state_dict = load_file(lora_path, device="cpu")
    pairs, alphas = {}, {}
    for key, value in state_dict.items():
        stem = get_key_stem(key)
        if "alpha" in key:
            # Alphas are scalars; unwrap tensors to plain Python numbers.
            alphas[stem] = value.item() if isinstance(value, torch.Tensor) else value
            continue
        entry = pairs.setdefault(stem, {})
        if "lora_down" in key or "lora_A" in key:
            entry["down"] = value.to(dtype=precision_dtype)
            entry["rank"] = value.shape[0]
        elif "lora_up" in key or "lora_B" in key:
            entry["up"] = value.to(dtype=precision_dtype)
    for stem, entry in pairs.items():
        entry["alpha"] = alphas.get(stem, float(entry.get("rank", 1.0)))
    return pairs
170
+
171
class ShardBuffer:
    """Accumulates serialized tensors in RAM and, whenever the buffered bytes
    exceed the configured shard size, writes a .safetensors shard by hand
    (header JSON + raw payload), uploads it to the Hub, and clears the buffer.

    Writing the format manually lets tensors be appended one at a time
    without ever materializing a full state dict.
    """

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix
        self.buffer = []           # pending {"key", "data", "dtype", "shape"} records
        self.current_bytes = 0     # bytes buffered since the last flush
        self.shard_count = 0
        self.index_map = {}        # tensor key -> shard filename, for the index json
        self.total_size = 0        # cumulative bytes across all shards

    def add_tensor(self, key, tensor):
        """Serialize one CPU tensor into the buffer; flush when the shard is full."""
        if tensor.dtype == torch.bfloat16:
            # numpy has no bfloat16 — reinterpret as int16 to grab raw bytes.
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        elif tensor.dtype == torch.float16:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F16"
        else:
            # NOTE(review): any other dtype is labeled F32 in the header —
            # callers are expected to cast to a float dtype first (task_merge does).
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F32"
        size = len(raw_bytes)
        self.buffer.append({"key": key, "data": raw_bytes, "dtype": dtype_str, "shape": tensor.shape})
        self.current_bytes += size
        self.total_size += size
        if self.current_bytes >= self.max_bytes: self.flush()

    def flush(self):
        """Write buffered tensors to a shard file, upload it, and reset the buffer."""
        if not self.buffer: return
        self.shard_count += 1
        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        # BUG FIX: this previously interpolated a literal placeholder string
        # instead of the shard filename, so every shard in a subfolder would
        # overwrite one bogus repo file.
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            # data_offsets are relative to the end of the header, per the format.
            header[item["key"]] = {"dtype": item["dtype"], "shape": list(item["shape"]), "data_offsets": [current_offset, current_offset + len(item["data"])]}
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename
        header_json = json.dumps(header).encode('utf-8')
        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))  # u64 LE header length
            f.write(header_json)
            for item in self.buffer: f.write(item["data"])
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)
        os.remove(out_path)  # free disk immediately after upload
        self.buffer = []
        self.current_bytes = 0
        gc.collect()
223
+
224
def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix=None, is_root_merge=False):
    """Copy a repo's auxiliary files from *src_repo* to *dst_repo* one at a time.

    Best-effort: any per-file failure — or a failed repo listing — is silently
    skipped so a partial copy never aborts the calling merge task.

    Args:
        ignore_prefix: skip files whose repo path starts with this prefix.
        is_root_merge: if True, also skip all weight files (the merge will
            upload its own shards at the repo root).
    """
    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)
        for f in tqdm(files, desc="Copying Structure"):
            if ignore_prefix and f.startswith(ignore_prefix): continue
            if is_root_merge:
                if any(f.endswith(ext) for ext in ['.safetensors', '.bin', '.pt', '.pth']): continue
            try:
                # Download then re-upload, deleting the local copy immediately
                # to keep disk usage flat.
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)
                api.upload_file(path_or_fileobj=local, path_in_repo=f, repo_id=dst_repo, token=token)
                if os.path.exists(local): os.remove(local)
            except: pass  # best-effort per file
    except: pass  # listing failed — nothing to copy
237
+
238
def identify_and_download_model(input_str, token):
    """Resolve *input_str* (a resolve-URL or bare repo id) to a local weights file.

    For a repo id, the main weights file is picked by priority (diffusers
    transformer/unet layouts first, then model.safetensors), falling back to
    the first .safetensors file whose name does not contain "lora".

    Returns:
        Path to the downloaded file inside TempDir.

    Raises:
        ValueError: if the repo contains no usable weights file.
    """
    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        local_path = TempDir / os.path.basename(filename)
        if local_path.exists(): os.remove(local_path)
        hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
        # hf_hub_download may place the file in a subdirectory; search for it.
        return list(TempDir.rglob(os.path.basename(filename)))[0]
    files = list_repo_files(repo_id=input_str, token=token)
    priorities = ["transformer/diffusion_pytorch_model.safetensors", "unet/diffusion_pytorch_model.safetensors", "model.safetensors"]
    target_file = next((f for f in priorities if f in files), next((f for f in files if f.endswith(".safetensors") and "lora" not in f), None))
    if not target_file: raise ValueError("No model file found")
    hf_hub_download(repo_id=input_str, filename=target_file, token=token, local_dir=TempDir)
    return list(TempDir.rglob(os.path.basename(target_file)))[0]
251
+
252
+ # =================================================================================
253
+ # TAB 1: MERGE & RESHARD
254
+ # =================================================================================
255
+
256
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    """Merge a LoRA into a (possibly sharded) base model, streaming shard by
    shard, and upload the re-sharded result to *output_repo*.

    Returns a human-readable status string for the Gradio UI (errors are
    returned as strings, not raised).
    """
    cleanup_temp()
    if not hf_token: return "Error: HF Token required."
    login(hf_token.strip())
    try: api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e: return f"Error creating repo: {e}"
    output_subfolder = base_subfolder if base_subfolder else ""
    if structure_repo:
        # Mirror the auxiliary files (configs, tokenizer, ...) from the
        # structure repo, excluding whatever this merge will overwrite.
        ignore = output_subfolder if output_subfolder else None
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix=ignore, is_root_merge=not bool(output_subfolder))
    progress(0.1, desc="Downloading Input Model...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    input_shards = []
    for f in files:
        if f.endswith(".safetensors"):
            # Restrict to the requested subfolder when one is given.
            if output_subfolder and not f.startswith(output_subfolder): continue
            local = TempDir / "inputs" / os.path.basename(f)
            os.makedirs(local.parent, exist_ok=True)
            hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=local.parent, local_dir_use_symlinks=False)
            found = list(local.parent.rglob(os.path.basename(f)))
            if found: input_shards.append(found[0])
    if not input_shards: return "No safetensors found."
    input_shards.sort()
    # Match diffusers naming when merging a transformer/unet component.
    filename_prefix = "diffusion_pytorch_model" if (output_subfolder in ["transformer", "unet"] or "diffusion_pytorch_model" in os.path.basename(input_shards[0])) else "model"
    index_filename = f"{filename_prefix}.safetensors.index.json"
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.15, desc="Downloading LoRA...")
        lora_path = download_lora_smart(lora_input, hf_token)
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e: return f"Error loading LoRA: {e}"
    buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix)
    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {os.path.basename(shard_file)}")
        # Stream one tensor at a time to keep peak memory low.
        with MemoryEfficientSafeOpen(shard_file) as f:
            for k in f.keys():
                v = f.get_tensor(k)
                base_stem = get_key_stem(k)
                match = lora_pairs.get(base_stem)
                if not match:
                    # Some LoRAs train a fused qkv projection — try mapping
                    # the split q/k/v keys onto it.
                    if "to_q" in base_stem: match = lora_pairs.get(base_stem.replace("to_q", "qkv"))
                    elif "to_k" in base_stem: match = lora_pairs.get(base_stem.replace("to_k", "qkv"))
                    elif "to_v" in base_stem: match = lora_pairs.get(base_stem.replace("to_v", "qkv"))
                if match:
                    down, up = match["down"], match["up"]
                    # Standard LoRA scaling: user scale * (alpha / rank).
                    scaling = scale * (match["alpha"] / match["rank"])
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        # Linear factors applied to a 1x1 conv weight.
                        down, up = down.unsqueeze(-1).unsqueeze(-1), up.unsqueeze(-1).unsqueeze(-1)
                    try:
                        delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1) if len(up.shape) == 4 else up @ down
                    except: delta = up.T @ down  # fallback for transposed factor layouts
                    delta = delta * scaling
                    # Only apply when shapes line up exactly; otherwise pass through.
                    if delta.shape == v.shape: v = v.to(dtype).add_(delta.to(dtype))
                    del delta
                buffer.add_tensor(k, v.to(dtype))
                del v
        os.remove(shard_file)  # reclaim disk before the next shard
        gc.collect()
    buffer.flush()
    # Write the index json mapping every tensor to its shard file.
    index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map}
    path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename
    with open(TempDir / index_filename, "w") as f: json.dump(index_data, f, indent=4)
    api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)
    cleanup_temp()
    return f"Done! Merged {buffer.shard_count} shards to {output_repo}"
321
+
322
+ # =================================================================================
323
+ # TAB 2: EXTRACT LORA
324
+ # =================================================================================
325
+
326
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """Extract a LoRA from the difference of two checkpoints via truncated SVD,
    processing one layer at a time to keep peak memory low.

    Args:
        model_org: path to the original (base) safetensors file.
        model_tuned: path to the fine-tuned safetensors file.
        rank: maximum LoRA rank per layer.
        clamp: quantile (0..1) used to clip outlier values in the factors.

    Returns:
        Path (str) of the written LoRA safetensors file in TempDir.
    """
    lora_sd = {}
    # BUG FIX: the two readers were previously opened and never closed —
    # use them as context managers so the file handles are released.
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        keys = set(org.keys()).intersection(set(tuned.keys()))
        for key in tqdm(keys, desc="Extracting"):
            # Batch-norm statistics are not learnable weight deltas — skip.
            if "num_batches_tracked" in key or "running_mean" in key or "running_var" in key: continue
            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()
            if mat_org.shape != mat_tuned.shape: continue
            diff = mat_tuned - mat_org
            # Skip layers the fine-tune left effectively untouched.
            if torch.max(torch.abs(diff)) < 1e-4: continue
            out_dim, in_dim = diff.shape[0], diff.shape[1] if len(diff.shape) > 1 else 1
            r = min(rank, in_dim, out_dim)
            is_conv = len(diff.shape) == 4
            if is_conv: diff = diff.flatten(start_dim=1)
            elif len(diff.shape) == 1: diff = diff.unsqueeze(1)
            # Randomized SVD with a small oversampling margin (q = r + 4).
            U, S, V = torch.svd_lowrank(diff, q=r+4, niter=4)
            Vh = V.t()
            U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
            U = U @ torch.diag(S)  # fold singular values into the "up" factor
            # Clip outliers at the requested quantile of |values|.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(torch.abs(dist), clamp)
            if hi_val > 0:
                U, Vh = U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)
            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)
            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
            lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
            # alpha == rank gives an effective scaling factor of 1.0.
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
364
+
365
def task_extract(hf_token, org, tun, rank, out):
    """Gradio handler: download two models, extract their LoRA difference,
    and upload it to repo *out*.

    Returns a status string; all errors are returned, never raised.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    try:
        p1 = identify_and_download_model(org, hf_token)
        p2 = identify_and_download_model(tun, hf_token)
        # Clamp quantile is fixed at 0.99 (not exposed in the UI).
        f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
        return "Done! Extracted to " + out
    except Exception as e: return f"Error: {e}"
376
+
377
+ # =================================================================================
378
+ # TAB 3: MERGE ADAPTERS
379
+ # =================================================================================
380
+
381
def load_full_state_dict(path):
    """Load a LoRA safetensors file, renaming PEFT-style lora_A/lora_B keys
    to kohya-style lora_down/lora_up and casting all tensors to float32."""
    normalized = {}
    for key, tensor in load_file(path, device="cpu").items():
        if "lora_A" in key:
            key = key.replace("lora_A", "lora_down")
        elif "lora_B" in key:
            key = key.replace("lora_B", "lora_up")
        normalized[key] = tensor.float()
    return normalized
390
+
391
def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
    """Merge several LoRA adapters into one, by EMA, concatenation, or SVD.

    *inputs_text* is a whitespace/newline-separated list of adapter sources
    (as accepted by download_lora_smart); *weight_str* is a comma-separated
    per-adapter weight list (missing entries default to 1.0).

    Returns a status string for the Gradio UI.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    urls = [line.strip() for line in inputs_text.replace(" ", "\n").split('\n') if line.strip()]
    if len(urls) < 2: return "Error: Please provide at least 2 adapters."
    try:
        weights = [float(w.strip()) for w in weight_str.split(',')] if weight_str.strip() else [1.0] * len(urls)
        # Pad missing weights with 1.0 so every adapter has one.
        if len(weights) < len(urls): weights += [1.0] * (len(urls) - len(weights))
    except: return "Error parsing weights."
    paths = []
    try:
        for url in tqdm(urls, desc="Downloading Adapters"): paths.append(download_lora_smart(url, hf_token))
    except Exception as e: return f"Download Error: {e}"
    merged = None
    if "Iterative EMA" in method:
        # Exponential moving average over adapters, in listed order.
        base_sd = load_file(paths[0], device="cpu")
        for k in base_sd:
            if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float()
        gamma = None
        if sigma_rel > 0:
            # Power-function EMA: solve for gamma from sigma_rel
            # (largest non-negative real root of the cubic below).
            t_val = sigma_rel**-2
            roots = np.roots([1, 7, 16 - t_val, 12 - t_val])
            gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
        for i, path in enumerate(paths[1:]):
            # With gamma set, beta grows toward 1 as more adapters accumulate;
            # otherwise the fixed user-supplied beta is used.
            current_beta = (1 - 1 / (i + 1)) ** (gamma + 1) if gamma is not None else beta
            curr = load_file(path, device="cpu")
            for k in base_sd:
                if k in curr and "alpha" not in k:
                    base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)
        merged = base_sd
    else:
        # Concatenation or SVD merge over normalized lora_down/lora_up keys.
        states = [load_full_state_dict(p) for p in paths]
        merged = {}
        all_stems = set()
        for s in states:
            for k in s:
                if "lora_" in k: all_stems.add(k.split(".lora_")[0])
        for stem in tqdm(all_stems):
            down_list, up_list = [], []
            alpha_sum = 0.0
            total_delta = None
            for i, state in enumerate(states):
                w = weights[i]
                dk, uk, ak = f"{stem}.lora_down.weight", f"{stem}.lora_up.weight", f"{stem}.alpha"
                if dk in state and uk in state:
                    d, u = state[dk], state[uk]
                    alpha_sum += state[ak].item() if ak in state else d.shape[0]
                    if "Concatenation" in method:
                        # Stack ranks; the per-adapter weight scales the up factor.
                        down_list.append(d)
                        up_list.append(u * w)
                    elif "SVD" in method:
                        # Materialize each adapter's full delta, then sum.
                        rank, alpha = d.shape[0], state[ak].item() if ak in state else d.shape[0]
                        scale = (alpha / rank) * w
                        delta = ((u.flatten(1) @ d.flatten(1)).reshape(u.shape[0], d.shape[1], d.shape[2], d.shape[3]) if len(d.shape)==4 else u @ d) * scale
                        total_delta = delta if total_delta is None else total_delta + delta
            if "Concatenation" in method and down_list:
                merged[f"{stem}.lora_down.weight"] = torch.cat(down_list, dim=0).contiguous()
                merged[f"{stem}.lora_up.weight"] = torch.cat(up_list, dim=1).contiguous()
                merged[f"{stem}.alpha"] = torch.tensor(alpha_sum)
            elif "SVD" in method and total_delta is not None:
                # Re-factorize the summed delta down to the target rank.
                tr = int(target_rank)
                flat = total_delta.flatten(1) if len(total_delta.shape)==4 else total_delta
                try:
                    U, S, V = torch.svd_lowrank(flat, q=tr + 4, niter=4)
                    Vh = V.t()
                    U, S, Vh = U[:, :tr], S[:tr], Vh[:tr, :]
                    U = U @ torch.diag(S)
                    if len(total_delta.shape) == 4:
                        U = U.reshape(total_delta.shape[0], tr, 1, 1)
                        Vh = Vh.reshape(tr, total_delta.shape[1], total_delta.shape[2], total_delta.shape[3])
                    else:
                        U, Vh = U.reshape(total_delta.shape[0], tr), Vh.reshape(tr, total_delta.shape[1])
                    merged[f"{stem}.lora_down.weight"] = Vh.contiguous()
                    merged[f"{stem}.lora_up.weight"] = U.contiguous()
                    merged[f"{stem}.alpha"] = torch.tensor(tr).float()
                except: pass  # SVD failure on a degenerate layer — drop that layer
    out = TempDir / "merged_adapters.safetensors"
    save_file(merged, out)
    api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return f"Success! Merged to {out_repo}"
472
+
473
+ # =================================================================================
474
+ # TAB 4: RESIZE ADAPTER
475
+ # =================================================================================
476
+
477
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """Shrink a LoRA to a lower rank via SVD, optionally choosing the rank
    per layer from the singular-value spectrum.

    dynamic_method: one of "sv_ratio" (keep SVs above S[0]/param),
    "sv_cumulative" (keep the smallest rank reaching *param* cumulative SV
    mass), "sv_fro" (same but on the Frobenius norm), or anything else to
    use the fixed *new_rank*.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    path = download_lora_smart(lora_input, hf_token)
    state = load_file(path, device="cpu")
    new_state = {}
    groups = {}
    # Group tensors by layer stem, tolerating kohya and PEFT key styles.
    for k in state:
        simple = k.split(".lora_")[0]
        if simple not in groups: groups[simple] = {}
        if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
        if "alpha" in k: groups[simple]["alpha"] = state[k]
    target_rank_limit = int(new_rank)
    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()
            # Reconstruct the full delta, then re-factorize it.
            merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3]) if len(down.shape)==4 else up @ down
            flat = merged.flatten(1)
            # Oversample by 32 so dynamic methods see enough of the spectrum.
            U, S, V = torch.svd_lowrank(flat, q=target_rank_limit + 32)
            Vh = V.t()
            calc_rank = target_rank_limit
            if dynamic_method == "sv_ratio":
                calc_rank = int(torch.sum(S > (S[0] / dynamic_param)).item())
            elif dynamic_method == "sv_cumulative":
                calc_rank = int(torch.searchsorted(torch.cumsum(S, 0) / torch.sum(S), dynamic_param)) + 1
            elif dynamic_method == "sv_fro":
                calc_rank = int(torch.searchsorted(torch.cumsum(S.pow(2), 0) / torch.sum(S.pow(2)), dynamic_param**2)) + 1
            # Never exceed the requested limit or the available spectrum.
            final_rank = max(1, min(calc_rank, target_rank_limit, S.shape[0]))
            U = U[:, :final_rank] @ torch.diag(S[:final_rank])
            Vh = Vh[:final_rank, :]
            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], final_rank, 1, 1)
                Vh = Vh.reshape(final_rank, down.shape[1], down.shape[2], down.shape[3])
            new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
            new_state[f"{stem}.lora_up.weight"] = U.contiguous()
            # alpha == rank keeps the effective scaling at 1.0.
            new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float()
    out = TempDir / "shrunken.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
519
+
520
+ # =================================================================================
521
+ # NEW MERGEKIT HELPERS & CLIs
522
+ # =================================================================================
523
+
524
def run_mergekit_cli(config_dict, output_path, hf_token):
    """Serialize *config_dict* to YAML and run the mergekit-yaml CLI on it.

    The token (if any) is passed through the HF_TOKEN environment variable.

    Raises:
        RuntimeError: when the CLI exits non-zero (stderr included).

    Returns:
        str(output_path) on success.
    """
    cfg_path = TempDir / "config.yaml"
    with open(cfg_path, "w") as fh:
        yaml.dump(config_dict, fh, sort_keys=False)

    env = os.environ.copy()
    if hf_token:
        env["HF_TOKEN"] = hf_token.strip()

    # Argument list form (shell=False) — nothing to quote or escape.
    cmd = [
        "mergekit-yaml",
        str(cfg_path),
        str(output_path),
        "--allow-crimes",
        "--lazy-unpickle",
        "--copy-tokenizer",
    ]

    # Echo the command and capture output for debugging.
    print(f"Running command: {' '.join(cmd)}")
    proc = subprocess.run(cmd, env=env, capture_output=True, text=True)

    if proc.returncode != 0:
        print("MergeKit stdout:", proc.stdout)
        print("MergeKit stderr:", proc.stderr)
        raise RuntimeError(f"MergeKit Error: {proc.stderr}")
    return str(output_path)
542
+
543
def parse_weight(w_str):
    """Parse a user-supplied weight string into a float or a per-layer list.

    A bracketed value like "[0, 0.5, 1]" is parsed as a YAML list (mergekit
    gradient syntax); anything else is parsed as a float. Blank or
    unparseable input falls back to 1.0.
    """
    if not w_str.strip(): return 1.0
    try:
        if "[" in w_str and "]" in w_str:
            return yaml.safe_load(w_str)
        return float(w_str)
    except Exception:  # narrowed from bare `except:` — don't swallow KeyboardInterrupt
        return 1.0
551
+
552
+ # =================================================================================
553
+ # TAB 5: AMPHINTERPOLATIVE
554
+ # =================================================================================
555
+
556
def task_amphinterpolative(token, method, base, t, norm, i8, flat, row, eps, m_iter, tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv):
    """Run a spherical-interpolation mergekit merge (slerp / nuslerp /
    multislerp / karcher) over up to five models and upload the result.

    Builds the mergekit YAML config from the UI fields, shells out via
    run_mergekit_cli, and returns a status string for the Gradio UI.
    """
    cleanup_temp()
    if token: login(token.strip())

    # Construct base params
    params = {"normalize": norm, "int8_mask": i8}
    if method in ["slerp", "nuslerp"]:
        params["t"] = float(t)
    if method == "nuslerp":
        params["flatten"] = flat
        params["row_wise"] = row
    if method == "multislerp":
        params["eps"] = float(eps)
    if method == "karcher":
        params["max_iter"] = int(m_iter)
        params["tol"] = float(tol)

    config = {
        "merge_method": method,
        "dtype": "bfloat16"
    }

    # Slerp/NuSlerp often use 'slices'
    if method in ["slerp", "nuslerp"]:
        if not base.strip(): return "Error: Base Model is mandatory for Slerp/NuSlerp."
        config["base_model"] = base.strip()

        # Build sources list
        sources = []
        for m, w in [(m1,w1), (m2,w2)]: # Slerp takes 2 models usually in the slice
            if m.strip():
                sources.append({"model": m, "parameters": {"weight": parse_weight(w)}})

        # Slerp requires slices. We define one slice for the whole model.
        config["slices"] = [{"sources": sources, "parameters": params}]
    else:
        # MultiSlerp/Karcher use 'models' list
        if base.strip() and method == "multislerp": config["base_model"] = base.strip()

        models = []
        for m, w in [(m1, w1), (m2, w2), (m3, w3), (m4, w4), (m5, w5)]:
            if m.strip():
                models.append({"model": m, "parameters": {"weight": parse_weight(w)}})
        config["models"] = models
        config["parameters"] = params

    try:
        path = run_mergekit_cli(config, TempDir / "out", token)
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_folder(folder_path=path, repo_id=out, token=token)
        return f"Success! Uploaded to {out}"
    except Exception as e: return f"Error: {str(e)}"
608
+
609
+ # =================================================================================
610
+ # TAB 6: STIR/TIE BASES
611
+ # =================================================================================
612
+
613
def task_stirtie(token, method, base, norm, i8, lamb, resc, topk, m1, w1, d1, g1, e1, m2, w2, d2, g2, e2, m3, w3, d3, g3, e3, m4, w4, d4, g4, e4, out, priv):
    """Run a task-arithmetic-family mergekit merge (ties / dare / breadcrumbs /
    della / sce ...) over up to four models and upload the result.

    Each model row carries (repo, weight, density, gamma, epsilon); only the
    parameters relevant to *method* are emitted into the YAML config.
    Returns a status string for the Gradio UI.
    """
    cleanup_temp()
    if token: login(token.strip())

    models_config = []
    # Collect models
    for m, w, d, g, e in [(m1,w1,d1,g1,e1), (m2,w2,d2,g2,e2), (m3,w3,d3,g3,e3), (m4,w4,d4,g4,e4)]:
        if not m.strip(): continue
        p = {"weight": parse_weight(w)}

        # Add specific per-model params
        if method in ["ties", "dare_ties", "dare_linear", "breadcrumbs_ties"]:
            p["density"] = parse_weight(d)
        if method in ["breadcrumbs", "breadcrumbs_ties"]:
            p["gamma"] = float(g)
        if method in ["della", "della_linear"]:
            p["epsilon"] = float(e)

        models_config.append({"model": m, "parameters": p})

    # Global Parameters
    global_params = {"normalize": norm, "int8_mask": i8}
    if method != "sce":
        global_params["lambda"] = float(lamb)
    if method == "dare_linear":
        global_params["rescale"] = resc
    if method == "sce":
        global_params["select_topk"] = float(topk)

    config = {
        "merge_method": method,
        # These methods need a base model; default to the first listed model.
        "base_model": base.strip() if base.strip() else models_config[0]["model"],
        "dtype": "bfloat16",
        "parameters": global_params,
        "models": models_config
    }

    try:
        path = run_mergekit_cli(config, TempDir / "out", token)
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_folder(folder_path=path, repo_id=out, token=token)
        return f"Success! Uploaded to {out}"
    except Exception as e: return f"Error: {str(e)}"
656
+
657
+ # =================================================================================
658
+ # TAB 7: SPECIOUS
659
+ # =================================================================================
660
+
661
def task_specious(token, method, base, norm, i8, t, filt_w, m1, w1, f1, m2, w2, m3, w3, m4, w4, m5, w5, out, priv):
    """Run a miscellaneous mergekit merge (passthrough / nearswap /
    model_stock / linear-style) over up to five models and upload the result.

    Returns a status string for the Gradio UI.
    """
    cleanup_temp()
    if token: login(token.strip())

    model_configs = []

    if method == "passthrough":
        # Passthrough takes exactly 1 model
        if not m1.strip(): return "Error: Model 1 required for passthrough"
        p = {"weight": parse_weight(w1)}
        if f1.strip(): p["filter"] = f1.strip()
        model_configs.append({"model": m1, "parameters": p})
    else:
        for m, w in [(m1,w1), (m2,w2), (m3,w3), (m4,w4), (m5,w5)]:
            if not m.strip(): continue
            model_configs.append({"model": m, "parameters": {"weight": parse_weight(w)}})

    config = {
        "merge_method": method,
        "dtype": "bfloat16",
        "parameters": {"normalize": norm, "int8_mask": i8}
    }

    # Base model is optional here (unlike slerp).
    if base.strip(): config["base_model"] = base.strip()

    if method == "nearswap":
        config["parameters"]["t"] = float(t)
    if method == "model_stock":
        config["parameters"]["filter_wise"] = filt_w

    config["models"] = model_configs

    try:
        path = run_mergekit_cli(config, TempDir / "out", token)
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_folder(folder_path=path, repo_id=out, token=token)
        return f"Success! Uploaded to {out}"
    except Exception as e: return f"Error: {str(e)}"
699
+
700
# =================================================================================
# TAB 8: MoEr (Mixture of Experts)
# =================================================================================

def task_moer(token, base, experts_text, gate_mode, dtype, out, priv):
    """Build a Mixture-of-Experts model with mergekit from a newline-separated expert list.

    Each expert gets a generic `positive_prompts` list because the mergekit
    MoE config schema requires one. Returns a status string (never raises).
    """
    cleanup_temp()
    if token:
        login(token.strip())

    experts_list = [line.strip() for line in experts_text.split('\n') if line.strip()]
    if not experts_list:
        return "Error: No experts provided."

    # Generic prompts to satisfy the schema; gating quality is governed by
    # `gate_mode`, not these placeholder strings.
    generic_prompts = ["chat", "assist", "tell me", "explain"]
    formatted_experts = [
        {"source_model": expert, "positive_prompts": list(generic_prompts)}
        for expert in experts_list
    ]

    config = {
        # Default the base to the first expert when none is supplied.
        "base_model": base.strip() if base.strip() else experts_list[0],
        "gate_mode": gate_mode,
        "dtype": dtype,
        "experts": formatted_experts,
    }

    try:
        merged_path = run_mergekit_cli(config, TempDir / "out", token)
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_folder(folder_path=merged_path, repo_id=out, token=token)
        return f"Success! Uploaded to {out}"
    except Exception as e:
        return f"Error: {str(e)}"
737
+
738
# =================================================================================
# TAB 9: Rawer (Raw PyTorch)
# =================================================================================

def task_rawer(token, models_text, method, dtype, out, priv):
    """Raw merge of a newline-separated model list with uniform weight 1.0.

    Returns a status string (never raises).
    """
    cleanup_temp()
    if token:
        login(token.strip())

    model_names = [line.strip() for line in models_text.split('\n') if line.strip()]
    if not model_names:
        return "Error: No models provided."

    # Every model contributes equally; method/dtype come straight from the UI.
    config = {
        "models": [{"model": name, "parameters": {"weight": 1.0}} for name in model_names],
        "merge_method": method,
        "dtype": dtype,
    }

    try:
        merged_path = run_mergekit_cli(config, TempDir / "out", token)
        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_folder(folder_path=merged_path, repo_id=out, token=token)
        return f"Success! Uploaded to {out}"
    except Exception as e:
        return f"Error: {str(e)}"
762
+
763
# =================================================================================
# TAB 10: MARIO, DARE! (Custom Logic)
# =================================================================================

def task_mario_dare(token, base, ft, ratio, mask, out, priv):
    """Custom DARE merge: base + ratio * (randomly-masked, rescaled task vector).

    delta = ft - base; m ~ Bernoulli(1 - mask); delta_hat = (m * delta) / (1 - mask);
    merged = base + ratio * delta_hat. Saves a single `model.safetensors` and
    uploads it to `out`. Returns a status string (never raises).
    """
    cleanup_temp()
    if token:
        login(token.strip())

    try:
        # 1. Download both checkpoints.
        print(f"Downloading Base: {base}")
        base_path = identify_and_download_model(base, token)
        print(f"Downloading FT: {ft}")
        ft_path = identify_and_download_model(ft, token)

        # 2. Load state dicts on CPU.
        base_sd = load_file(base_path, device="cpu")
        ft_sd = load_file(ft_path, device="cpu")

        merged_sd = {}
        shared_keys = set(base_sd.keys()).intersection(set(ft_sd.keys()))

        # 3. Apply the DARE logic per tensor.
        print("Merging tensors...")
        for key in tqdm(shared_keys):
            base_t = base_sd[key]
            ft_t = ft_sd[key]

            # Shape mismatch: fall back to the fine-tuned tensor wholesale.
            if base_t.shape != ft_t.shape:
                merged_sd[key] = ft_t
                continue

            delta = ft_t.float() - base_t.float()

            if mask > 0:
                # Randomly drop a `mask` fraction of delta entries, then
                # rescale survivors to preserve the expected magnitude.
                keep = torch.bernoulli(torch.full_like(delta, 1.0 - mask))
                delta = delta * keep
                delta = delta / (1.0 - mask)

            # Scale the (rescaled) task vector by `ratio` and re-anchor on base.
            fused = base_t.float() + (ratio * delta)

            # Cast back to the base tensor's original dtype.
            if base_t.dtype == torch.bfloat16:
                merged_sd[key] = fused.bfloat16()
            elif base_t.dtype == torch.float16:
                merged_sd[key] = fused.half()
            else:
                merged_sd[key] = fused

        # 4. Save and upload.
        out_path = TempDir / "model.safetensors"
        save_file(merged_sd, out_path)

        api.create_repo(repo_id=out, private=priv, exist_ok=True, token=token)
        api.upload_file(path_or_fileobj=out_path, path_in_repo="model.safetensors", repo_id=out, token=token)
        return f"Success! Uploaded to {out}"

    except Exception as e:
        return f"DARE Error: {str(e)}"
831
+
832
# =================================================================================
# UI GENERATION
# =================================================================================

css = ".container { max-width: 1100px; margin: auto; }"

# FIX: custom CSS must be supplied to the gr.Blocks() constructor; previously
# `css` was defined but never applied (and was wrongly passed to launch()).
with gr.Blocks(css=css) as demo:
    gr.HTML("""<h1><img src="https://huggingface.co/spaces/AlekseyCalvin/Soon_Merger/resolve/main/SMerger3.png" alt="SOONmerge®"> Transform Transformers for FREE!</h1>""")
    gr.Markdown("# 🧰Training-Free CPU-run Model Creation Toolkit")

    with gr.Tabs():
        # --- TAB 1 (PRESERVED): LoRA-to-base merge + resharded upload ---
        with gr.Tab("Merge to Base Model + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo", value="name/repo")
            t1_sub = gr.Textbox(label="Subfolder", value="")
            t1_lora = gr.Textbox(label="LoRA", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(0, 3, 1, step=0.1, label="Scale")
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(0.1, 10, 2, label="Shard GB")
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Extras Source", value="name/repo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)

        # --- TAB 2 (PRESERVED): adapter extraction from a tuned model ---
        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original")
            t2_tun = gr.Textbox(label="Tuned")
            t2_rank = gr.Number(label="Rank", value=32)
            t2_out = gr.Textbox(label="Output")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)

        # --- TAB 3 (PRESERVED): multi-adapter merging ---
        with gr.Tab("Merge Adapters"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.TextArea(label="URLs")
            t3_method = gr.Dropdown(["Iterative EMA", "Concatenation", "SVD Fusion"], value="Iterative EMA")
            t3_weights = gr.Textbox(label="Weights")
            t3_rank = gr.Number(label="Rank", value=128)
            with gr.Row():
                t3_beta = gr.Slider(0.01, 1, 0.95, label="Beta")
                t3_sigma = gr.Slider(0.01, 1, 0.21, label="Sigma")
            t3_out = gr.Textbox(label="Output")
            t3_priv = gr.Checkbox(label="Private", value=True)
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], t3_res)

        # --- TAB 4 (PRESERVED): adapter rank resizing ---
        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            t4_rank = gr.Number(label="To Rank", value=8)
            t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None")
            t4_param = gr.Number(label="Param", value=0.9)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

        # --- TAB 5: AMPHINTERPOLATIVE (spherical interpolation family) ---
        with gr.Tab("Amphinterpolative"):
            gr.Markdown("### Spherical Interpolation Family")
            t5_token = gr.Textbox(label="HF Token", type="password")
            t5_method = gr.Dropdown(["slerp", "nuslerp", "multislerp", "karcher"], value="slerp", label="Method")
            with gr.Row():
                t5_base = gr.Textbox(label="Base Model (Mandatory for slerp/nuslerp)")
                t5_t = gr.Slider(0, 1, 0.5, label="t (Interpolation)")
            with gr.Row():
                t5_norm = gr.Checkbox(label="Normalize", value=True)
                t5_i8 = gr.Checkbox(label="Int8 Mask", value=False)
                t5_flat = gr.Checkbox(label="NuSlerp Flatten", value=False)
                t5_row = gr.Checkbox(label="NuSlerp Row Wise", value=False)
            with gr.Row():
                t5_eps = gr.Textbox(label="eps (MultiSlerp)", value="1e-8")
                t5_iter = gr.Number(label="max_iter (Karcher)", value=10)
                t5_tol = gr.Textbox(label="tol (Karcher)", value="1e-5")

            with gr.Row():
                m1, w1 = gr.Textbox(label="Model 1"), gr.Textbox(label="Weight 1", value="1.0")
                m2, w2 = gr.Textbox(label="Model 2"), gr.Textbox(label="Weight 2", value="1.0")
            with gr.Accordion("More Models (MultiSlerp/Karcher)", open=False):
                with gr.Row():
                    m3, w3 = gr.Textbox(label="Model 3"), gr.Textbox(label="Weight 3", value="1.0")
                    m4, w4 = gr.Textbox(label="Model 4"), gr.Textbox(label="Weight 4", value="1.0")
                    m5, w5 = gr.Textbox(label="Model 5"), gr.Textbox(label="Weight 5", value="1.0")

            t5_out = gr.Textbox(label="Output Repo")
            t5_priv = gr.Checkbox(label="Private", value=True)
            t5_btn = gr.Button("Execute Amphinterpolative Merge")
            t5_res = gr.Textbox(label="Result")
            t5_btn.click(task_amphinterpolative, [t5_token, t5_method, t5_base, t5_t, t5_norm, t5_i8, t5_flat, t5_row, t5_eps, t5_iter, t5_tol, m1, w1, m2, w2, m3, w3, m4, w4, m5, w5, t5_out, t5_priv], t5_res)

        # --- TAB 6: STIR/TIE BASES (task-vector family) ---
        with gr.Tab("Stir/Tie Bases"):
            gr.Markdown("### Task Vector Family")
            t6_token = gr.Textbox(label="Token", type="password")
            t6_method = gr.Dropdown(["task_arithmetic", "ties", "dare_ties", "dare_linear", "della", "della_linear", "breadcrumbs", "breadcrumbs_ties", "sce"], value="ties", label="Method")
            t6_base = gr.Textbox(label="Base Model")
            with gr.Row():
                t6_norm = gr.Checkbox(label="Normalize", value=True)
                t6_i8 = gr.Checkbox(label="Int8 Mask", value=False)
                t6_resc = gr.Checkbox(label="Rescale (Dare Linear)", value=True)
            with gr.Row():
                t6_lamb = gr.Number(label="Lambda", value=1.0)
                t6_topk = gr.Slider(0, 1, 1.0, label="Select TopK (SCE)")

            with gr.Row():
                m1_6, w1_6 = gr.Textbox(label="M1"), gr.Textbox(label="W1", value="1.0")
                d1_6, g1_6, e1_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)
            with gr.Row():
                m2_6, w2_6 = gr.Textbox(label="M2"), gr.Textbox(label="W2", value="1.0")
                d2_6, g2_6, e2_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)
            with gr.Accordion("More Models", open=False):
                with gr.Row():
                    m3_6, w3_6 = gr.Textbox(label="M3"), gr.Textbox(label="W3", value="1.0")
                    d3_6, g3_6, e3_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)
                with gr.Row():
                    m4_6, w4_6 = gr.Textbox(label="M4"), gr.Textbox(label="W4", value="1.0")
                    d4_6, g4_6, e4_6 = gr.Textbox(label="Density", value="1.0"), gr.Number(label="Gamma", value=0.01), gr.Number(label="Epsilon", value=0.15)

            t6_out = gr.Textbox(label="Output Repo")
            t6_priv = gr.Checkbox(label="Private", value=True)
            t6_btn = gr.Button("Execute Stir/Tie Merge")
            t6_res = gr.Textbox(label="Result")
            t6_btn.click(task_stirtie, [t6_token, t6_method, t6_base, t6_norm, t6_i8, t6_lamb, t6_resc, t6_topk, m1_6, w1_6, d1_6, g1_6, e1_6, m2_6, w2_6, d2_6, g2_6, e2_6, m3_6, w3_6, d3_6, g3_6, e3_6, m4_6, w4_6, d4_6, g4_6, e4_6, t6_out, t6_priv], t6_res)

        # --- TAB 7: SPECIOUS (specialized methods) ---
        with gr.Tab("Specious"):
            gr.Markdown("### Specialized Methods")
            t7_token = gr.Textbox(label="Token", type="password")
            t7_method = gr.Dropdown(["model_stock", "nearswap", "arcee_fusion", "passthrough", "linear"], value="model_stock", label="Method")
            t7_base = gr.Textbox(label="Base Model (Optional depending on method)")
            with gr.Row():
                t7_norm = gr.Checkbox(label="Normalize", value=True)
                t7_i8 = gr.Checkbox(label="Int8 Mask", value=False)
                t7_t = gr.Slider(0, 1, 0.5, label="t (Nearswap)")
                t7_filt_w = gr.Checkbox(label="Filter Wise (Model Stock)", value=False)

            with gr.Row():
                m1_7, w1_7 = gr.Textbox(label="M1"), gr.Textbox(label="W1", value="1.0")
                f1_7 = gr.Textbox(label="Filter (Passthrough only)", placeholder="e.g. down_proj")
            with gr.Row():
                m2_7, w2_7 = gr.Textbox(label="M2"), gr.Textbox(label="W2", value="1.0")
            with gr.Accordion("More Models", open=False):
                m3_7, w3_7 = gr.Textbox(label="M3"), gr.Textbox(label="W3", value="1.0")
                m4_7, w4_7 = gr.Textbox(label="M4"), gr.Textbox(label="W4", value="1.0")
                m5_7, w5_7 = gr.Textbox(label="M5"), gr.Textbox(label="W5", value="1.0")

            t7_out = gr.Textbox(label="Output Repo")
            t7_priv = gr.Checkbox(label="Private", value=True)
            t7_btn = gr.Button("Execute Specious Merge")
            t7_res = gr.Textbox(label="Result")
            t7_btn.click(task_specious, [t7_token, t7_method, t7_base, t7_norm, t7_i8, t7_t, t7_filt_w, m1_7, w1_7, f1_7, m2_7, w2_7, m3_7, w3_7, m4_7, w4_7, m5_7, w5_7, t7_out, t7_priv], t7_res)

        # --- TAB 8: MoEr (mixture of experts) ---
        with gr.Tab("MoEr"):
            gr.Markdown("### Mixture of Experts (MergeKit)")
            t8_token = gr.Textbox(label="Token", type="password")
            t8_base = gr.Textbox(label="Base Model")
            t8_experts = gr.TextArea(label="Experts List (one per line)")
            with gr.Row():
                t8_gate = gr.Dropdown(["cheap_embed", "random", "hidden"], value="cheap_embed", label="Gate Mode")
                t8_dtype = gr.Dropdown(["float16", "bfloat16"], value="bfloat16", label="Dtype")
            t8_out = gr.Textbox(label="Output Repo")
            t8_priv = gr.Checkbox(label="Private", value=True)
            t8_btn = gr.Button("Build MoE")
            t8_res = gr.Textbox(label="Result")
            t8_btn.click(task_moer, [t8_token, t8_base, t8_experts, t8_gate, t8_dtype, t8_out, t8_priv], t8_res)

        # --- TAB 9: Rawer (raw PyTorch / non-transformer) ---
        with gr.Tab("Rawer"):
            gr.Markdown("### Raw PyTorch / Non-Transformer")
            t9_token = gr.Textbox(label="Token", type="password")
            t9_models = gr.TextArea(label="Models (one per line)")
            t9_method = gr.Dropdown(["linear", "passthrough"], value="linear", label="Method")
            t9_dtype = gr.Dropdown(["float32", "float16", "bfloat16"], value="float32", label="Dtype")
            t9_out = gr.Textbox(label="Output Repo")
            t9_priv = gr.Checkbox(label="Private", value=True)
            t9_btn = gr.Button("Merge Raw")
            t9_res = gr.Textbox(label="Result")
            t9_btn.click(task_rawer, [t9_token, t9_models, t9_method, t9_dtype, t9_out, t9_priv], t9_res)

        # --- TAB 10: MARIO, DARE! (custom DARE implementation) ---
        with gr.Tab("Mario,DARE!"):
            gr.Markdown("### Custom DARE Implementation")
            t10_token = gr.Textbox(label="Token", type="password")
            with gr.Row():
                t10_base = gr.Textbox(label="Base Model")
                t10_ft = gr.Textbox(label="Fine-Tuned Model")
            with gr.Row():
                t10_ratio = gr.Slider(0, 5, 1.0, label="Ratio (Lambda)")
                t10_mask = gr.Slider(0, 0.99, 0.5, label="Mask Rate (Drop)")
            t10_out = gr.Textbox(label="Output Repo")
            t10_priv = gr.Checkbox(label="Private", value=True)
            t10_btn = gr.Button("Run Mario,DARE!")
            t10_res = gr.Textbox(label="Result")
            t10_btn.click(task_mario_dare, [t10_token, t10_base, t10_ft, t10_ratio, t10_mask, t10_out, t10_priv], t10_res)
1037
+
1038
if __name__ == "__main__":
    # FIX: Blocks.launch() has no `css` parameter — passing css=css raised
    # TypeError at startup. Custom CSS belongs on the gr.Blocks() constructor.
    demo.queue().launch(ssr_mode=False)