AlekseyCalvin committed
Commit ee40bb0 · verified · 1 Parent(s): 49dc183

Create app.py

Files changed (1): app.py (+1141, -0)
app.py ADDED
@@ -0,0 +1,1141 @@
import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
import re
import yaml
from merge_utils import execute_mergekit
from pathlib import Path
from typing import Dict, Any, Optional, List
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """Reads tensors one at a time straight from disk, so a multi-GB shard
    never has to fit in RAM all at once."""

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Tensor data starts after the 8-byte header-length field and the JSON header.
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype_map = {
            "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
            "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
            "U8": torch.uint8, "BOOL": torch.bool
        }
        dtype = dtype_map[metadata["dtype"]]
        shape = metadata["shape"]
        # Copy into a writable bytearray: frombuffer() on immutable bytes yields a
        # read-only tensor, which would break the in-place add_() used during merging.
        return torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8).view(dtype).reshape(shape)

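# Usage sketch (illustrative only, not executed here): stream keys from a shard
# and materialize one tensor at a time, keeping peak RAM near one tensor's size.
#
#   with MemoryEfficientSafeOpen("model-00001.safetensors") as shard:
#       for key in shard.keys():
#           tensor = shard.get_tensor(key)  # only this tensor is read from disk
#           ...                             # process, then let it go out of scope
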
# --- Constants & Setup ---
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

api = HfApi()

def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()

def get_key_stem(key):
    """Normalize a checkpoint/LoRA key down to its bare layer path, so base-model
    keys and adapter keys can be matched against each other."""
    key = key.replace(".weight", "").replace(".bias", "")
    key = key.replace(".lora_down", "").replace(".lora_up", "")
    key = key.replace(".lora_A", "").replace(".lora_B", "")
    key = key.replace(".alpha", "")
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    ]
    changed = True
    while changed:  # strip repeatedly, since prefixes can be nested
        changed = False
        for p in prefixes:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key

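# Example (sketch) of the normalization this enables:
#   get_key_stem("base_model.model.transformer.blocks.0.attn.to_q.lora_A.weight")
#   -> "blocks.0.attn.to_q"
#   get_key_stem("model.diffusion_model.blocks.0.attn.to_q.weight")
#   -> "blocks.0.attn.to_q"
# Both sides reduce to the same stem, which is how LoRA pairs find their base tensor.
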
# =================================================================================
# TAB 1: MERGE & RESHARD
# =================================================================================

def parse_hf_url(url):
    """Parses a direct HF URL into (repo_id, filename)."""
    # Pattern: https://huggingface.co/{user}/{repo}/resolve/{branch}/{filename...}
    if "huggingface.co" in url and "resolve" in url:
        try:
            parts = url.split("huggingface.co/")[-1].split("/")
            # parts[0]=user, parts[1]=repo, parts[2]=resolve, parts[3]=branch, parts[4:]=file
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[4:]).split("?")[0]  # strip query params
            return repo_id, filename
        except IndexError:
            return None, None
    return None, None

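# Example (sketch):
#   parse_hf_url("https://huggingface.co/user/repo/resolve/main/sub/lora.safetensors?download=true")
#   -> ("user/repo", "sub/lora.safetensors")
# Anything that doesn't match the .../resolve/... pattern returns (None, None).
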
def download_lora_smart(input_str, token, dest_name="adapter.safetensors"):
    """Resolve a LoRA from a direct HF URL, a 'user/repo/file' path, or a bare
    repo id, and place it at TempDir/dest_name."""
    local_path = TempDir / dest_name
    if local_path.exists(): os.remove(local_path)

    print(f"Resolving LoRA Input: {input_str}")

    # 1. Try to parse as an HF URL (most robust method)
    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        print(f"Detected HF URL. Repo: {repo_id}, File: {filename}")
        try:
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            # Move to the standard name (the file may land in a subfolder)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        except Exception as e:
            print(f"HF Download failed: {e}. Falling back...")

    # 2. Try as a raw repo id ("user/repo" or "user/repo/file.safetensors")
    try:
        if ".safetensors" in input_str and input_str.count("/") >= 2:
            parts = input_str.split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[2:])
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path

        # Standard auto-discovery inside the repo
        candidates = ["adapter_model.safetensors", "model.safetensors"]
        files = list_repo_files(repo_id=input_str, token=token)
        target = next((f for f in files if f in candidates), None)
        if not target:
            safes = [f for f in files if f.endswith(".safetensors")]
            if safes: target = safes[0]

        if not target: raise ValueError("No safetensors found")

        hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
        found = list(TempDir.rglob(target.split("/")[-1]))[0]
        if found != local_path: shutil.move(found, local_path)
        return local_path

    except Exception as e:
        # 3. Last resort: raw requests (for non-HF links)
        if input_str.startswith("http"):
            try:
                headers = {"Authorization": f"Bearer {token}"} if token else {}
                r = requests.get(input_str, stream=True, headers=headers, timeout=60)
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
                return local_path
            except Exception as req_e:
                raise ValueError(f"All download methods failed.\nRepo Logic Error: {e}\nURL Logic Error: {req_e}")
        raise e

def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            if stem not in pairs: pairs[stem] = {}
            if "lora_down" in k or "lora_A" in k:
                pairs[stem]["down"] = v.to(dtype=precision_dtype)
                pairs[stem]["rank"] = v.shape[0]
            elif "lora_up" in k or "lora_B" in k:
                pairs[stem]["up"] = v.to(dtype=precision_dtype)
    for stem in pairs:
        # If no explicit alpha was stored, fall back to alpha = rank (scale 1.0)
        pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
    return pairs

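# Resulting structure (sketch), one entry per LoRA'd layer:
#   pairs["blocks.0.attn.to_q"] == {
#       "down": Tensor(rank, in_features),   # lora_down / lora_A
#       "up":   Tensor(out_features, rank),  # lora_up / lora_B
#       "rank": rank,
#       "alpha": alpha,                      # defaults to rank when absent
#   }
# The effective delta applied later is (alpha / rank) * up @ down.
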
class ShardBuffer:
    """Accumulates tensors, writes a safetensors shard whenever the buffer
    exceeds max_size_gb, uploads it, and deletes the local copy."""

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix
        self.buffer = []
        self.current_bytes = 0
        self.shard_count = 0
        self.index_map = {}
        self.total_size = 0

    def add_tensor(self, key, tensor):
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16: reinterpret as int16 to get the raw bytes out.
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        elif tensor.dtype == torch.float16:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F16"
        else:
            # Labeled F32: task_merge casts tensors to the target dtype before adding.
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F32"

        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape)
        })
        self.current_bytes += size
        self.total_size += size

        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        if not self.buffer: return
        self.shard_count += 1

        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename

        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")

        # Build the safetensors JSON header: each key maps to its dtype, shape,
        # and byte range within the data section.
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])]
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename

        header_json = json.dumps(header).encode('utf-8')

        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))  # 8-byte little-endian header length
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])

        print(f"Uploading {path_in_repo}...")
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)

        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()

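# On-disk layout written by flush(), the same safetensors format this tool reads
# back in MemoryEfficientSafeOpen (sketch for a single 2x2 F16 tensor):
#   [8 bytes]  little-endian uint64 = len(header_json)
#   [header]   {"__metadata__": {"format": "pt"},
#               "layer.weight": {"dtype": "F16", "shape": [2, 2],
#                                "data_offsets": [0, 8]}}
#   [data]     8 raw bytes (4 half-precision values), in header order
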
def copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder):
    """Aggressively copy all config/misc files, skipping only heavy weights."""
    print(f"Copying config files from {base_repo}...")
    try:
        files = list_repo_files(repo_id=base_repo, token=hf_token)
        blocked_ext = ['.safetensors', '.bin', '.pt', '.pth', '.msgpack', '.h5', '.onnx']

        for f in files:
            # Filter by subfolder if needed
            if base_subfolder and not f.startswith(base_subfolder): continue

            # Block heavy weights
            if any(f.endswith(ext) for ext in blocked_ext): continue

            print(f"Transferring {f}...")
            local = hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=TempDir)

            # Determine the path in the new repo
            rel_name = f[len(base_subfolder):].lstrip('/') if base_subfolder else f
            target_path = f"{output_subfolder}/{rel_name}" if output_subfolder else rel_name

            api.upload_file(path_or_fileobj=local, path_in_repo=target_path, repo_id=output_repo, token=hf_token)
            os.remove(local)

    except Exception as e:
        print(f"Config copy warning: {e}")

def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix=None, is_root_merge=False):
    print(f"Scanning {src_repo} for structure cloning...")
    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)
        for f in tqdm(files, desc="Copying Structure"):
            if ignore_prefix and f.startswith(ignore_prefix): continue

            if is_root_merge:
                if any(f.endswith(ext) for ext in ['.safetensors', '.bin', '.pt', '.pth']):
                    continue

            try:
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)
                api.upload_file(path_or_fileobj=local, path_in_repo=f, repo_id=dst_repo, token=token)
                if os.path.exists(local): os.remove(local)
            except Exception as copy_err:
                print(f"Skipping {f}: {copy_err}")
    except Exception as e: print(f"Structure clone error: {e}")

def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    cleanup_temp()
    if not hf_token: return "Error: HF Token required."
    login(hf_token.strip())

    # 1. Create the output repo
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e: return f"Error creating repo: {e}"

    # If merging into a subfolder like 'transformer', keep standard diffusers naming
    output_subfolder = base_subfolder if base_subfolder else ""

    # 2. Copy configs from the base (aggressive copy)
    if base_subfolder:
        copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder)

    # 3. Clone the structure repo
    if structure_repo:
        ignore = output_subfolder if output_subfolder else None
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix=ignore, is_root_merge=not bool(output_subfolder))

    # 4. Download shards
    progress(0.1, desc="Downloading Input Model...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    input_shards = []

    for f in files:
        if f.endswith(".safetensors"):
            if output_subfolder and not f.startswith(output_subfolder): continue

            local = TempDir / "inputs" / os.path.basename(f)
            os.makedirs(local.parent, exist_ok=True)
            hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=local.parent, local_dir_use_symlinks=False)
            found = list(local.parent.rglob(os.path.basename(f)))
            if found: input_shards.append(found[0])

    if not input_shards: return "No safetensors found."
    input_shards.sort()

    # --- NAMING CONVENTION ---
    # Force diffusers naming if the target is a transformer/unet subfolder
    if output_subfolder in ["transformer", "unet", "qint4", "qint8"]:
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    elif "diffusion_pytorch_model" in os.path.basename(input_shards[0]):
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    else:
        filename_prefix = "model"
        index_filename = "model.safetensors.index.json"

    print(f"Naming scheme: {filename_prefix}")

    # 5. Load the LoRA
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.15, desc="Downloading LoRA...")
        lora_path = download_lora_smart(lora_input, hf_token)
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e: return f"Error loading LoRA: {e}"

    # 6. Stream base tensors through the buffer, patching matched ones
    buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix)

    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {os.path.basename(shard_file)}")

        with MemoryEfficientSafeOpen(shard_file) as f:
            keys = f.keys()
            for k in keys:
                v = f.get_tensor(k)
                base_stem = get_key_stem(k)
                match = lora_pairs.get(base_stem)

                # QKV heuristics: some LoRAs train one fused qkv projection while
                # the base stores separate to_q/to_k/to_v tensors
                if not match:
                    if "to_q" in base_stem:
                        match = lora_pairs.get(base_stem.replace("to_q", "qkv"))
                    elif "to_k" in base_stem:
                        match = lora_pairs.get(base_stem.replace("to_k", "qkv"))
                    elif "to_v" in base_stem:
                        match = lora_pairs.get(base_stem.replace("to_v", "qkv"))

                if match:
                    down = match["down"]
                    up = match["up"]
                    scaling = scale * (match["alpha"] / match["rank"])

                    # 1x1-conv LoRAs are stored 2D; lift them to 4D to match the base
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        down = down.unsqueeze(-1).unsqueeze(-1)
                        up = up.unsqueeze(-1).unsqueeze(-1)

                    try:
                        if len(up.shape) == 4:
                            delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                        else:
                            delta = up @ down
                    except RuntimeError:
                        delta = up.T @ down

                    delta = delta * scaling

                    valid = True
                    if delta.shape == v.shape: pass
                    elif delta.shape[0] == v.shape[0] * 3:
                        # Fused qkv delta: slice out the third belonging to this tensor
                        chunk = v.shape[0]
                        if "to_q" in k: delta = delta[0:chunk, ...]
                        elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
                        elif "to_v" in k: delta = delta[2*chunk:, ...]
                        else: valid = False
                    elif delta.numel() == v.numel(): delta = delta.reshape(v.shape)
                    else: valid = False

                    if valid:
                        v = v.to(dtype)
                        delta = delta.to(dtype)
                        v.add_(delta)
                        del delta

                if v.dtype != dtype: v = v.to(dtype)
                buffer.add_tensor(k, v)
                del v

        os.remove(shard_file)
        gc.collect()

    buffer.flush()

    print(f"Uploading Index: {index_filename} (Size: {buffer.total_size})")
    index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map}
    with open(TempDir / index_filename, "w") as f:
        json.dump(index_data, f, indent=4)

    path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename
    api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)

    cleanup_temp()
    return f"Done! Merged {buffer.shard_count} shards to {output_repo}"

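# The per-layer update above is plain LoRA arithmetic. Worked sketch: a rank-r
# adapter stores B (out x r) and A (r x in), and the patched weight is
#   W' = W + scale * (alpha / r) * (B @ A)
# e.g. with out = in = 4 and r = 2:
#   W = torch.zeros(4, 4); B = torch.randn(4, 2); A = torch.randn(2, 4)
#   W_prime = W + 1.0 * (2.0 / 2) * (B @ A)  # alpha=2, rank=2 -> scale factor 1
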
# =================================================================================
# TAB 2: EXTRACT LORA
# =================================================================================

def identify_and_download_model(input_str, token):
    """
    Smart download:
    1. If the input is a direct URL -> downloads that specific file.
    2. If the input is a repo ID -> scans for diffusers format (unet/transformer) or standard safetensors.
    """
    print(f"Resolving model input: {input_str}")

    # --- STRATEGY A: Direct URL ---
    repo_id_from_url, filename_from_url = parse_hf_url(input_str)

    if repo_id_from_url and filename_from_url:
        print(f"Detected Direct Link. Repo: {repo_id_from_url}, File: {filename_from_url}")
        local_path = TempDir / os.path.basename(filename_from_url)
        # Clean up a previous download if the name conflicts
        if local_path.exists(): os.remove(local_path)

        try:
            hf_hub_download(repo_id=repo_id_from_url, filename=filename_from_url, token=token, local_dir=TempDir)
            # Find where it landed (handling subfolders in local_dir)
            found = list(TempDir.rglob(os.path.basename(filename_from_url)))[0]
            return found
        except Exception as e:
            print(f"URL Download failed: {e}. Trying fallback...")

    # --- STRATEGY B: Repo Discovery (Auto-Detect) ---
    # If we are here, input_str is treated as a repo ID (e.g. "ostris/Z-Image-De-Turbo")
    print(f"Scanning Repo {input_str} for model weights...")

    try:
        files = list_repo_files(repo_id=input_str, token=token)
    except Exception as e:
        raise ValueError(f"Failed to list repo '{input_str}'. If this is a URL, ensure it is formatted correctly. Error: {e}")

    # Priority list: diffusers layouts first, then single-file checkpoints
    priorities = [
        "transformer/diffusion_pytorch_model.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "model.safetensors",
        # Fallback: any safetensors that isn't an adapter or LoRA
        lambda f: f.endswith(".safetensors") and "lora" not in f and "adapter" not in f and "extracted" not in f
    ]

    target_file = None
    for p in priorities:
        if callable(p):
            candidates = [f for f in files if p(f)]
            if candidates:
                # Take the first candidate (heuristic for the "main" model file)
                target_file = candidates[0]
                break
        elif p in files:
            target_file = p
            break

    if not target_file:
        raise ValueError(f"Could not find a valid model weight file in {input_str}. Ensure it contains .safetensors weights.")

    print(f"Downloading auto-detected weight file: {target_file}")
    hf_hub_download(repo_id=input_str, filename=target_file, token=token, local_dir=TempDir)

    # Locate the actual path
    found = list(TempDir.rglob(os.path.basename(target_file)))[0]
    return found

def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    lora_sd = {}
    print("Calculating diffs & extracting LoRA...")

    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        # Only keys present in both checkpoints can be diffed
        keys = set(org.keys()).intersection(set(tuned.keys()))

        for key in tqdm(keys, desc="Extracting"):
            # Skip integer buffers/metadata
            if "num_batches_tracked" in key or "running_mean" in key or "running_var" in key:
                continue

            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()

            # Skip if shapes mismatch (shouldn't happen if the models match)
            if mat_org.shape != mat_tuned.shape: continue

            diff = mat_tuned - mat_org

            # Skip if there is no meaningful difference
            if torch.max(torch.abs(diff)) < 1e-4: continue

            out_dim = diff.shape[0]
            in_dim = diff.shape[1] if len(diff.shape) > 1 else 1

            r = min(rank, in_dim, out_dim)

            is_conv = len(diff.shape) == 4
            if is_conv: diff = diff.flatten(start_dim=1)
            elif len(diff.shape) == 1: diff = diff.unsqueeze(1)  # handle biases if needed

            try:
                # svd_lowrank is a massive speedup on CPU vs full linalg.svd;
                # q=r+4 oversamples slightly to better capture the top-r subspace
                U, S, V = torch.svd_lowrank(diff, q=r+4, niter=4)
                Vh = V.t()

                U = U[:, :r]
                S = S[:r]
                Vh = Vh[:r, :]

                # Merge S into U for the standard LoRA format
                U = U @ torch.diag(S)

                # Clamp outliers
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(torch.abs(dist), clamp)
                if hi_val > 0:
                    U = U.clamp(-hi_val, hi_val)
                    Vh = Vh.clamp(-hi_val, hi_val)

                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)

                stem = key.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
                lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
            except Exception as e:
                print(f"Skipping {key} due to error: {e}")

    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)

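# Why SVD here (sketch): the tuned-minus-base difference dW factors as
#   dW ≈ U_r @ diag(S_r) @ Vh_r        (best rank-r approximation)
# so setting lora_up = U_r @ diag(S_r) and lora_down = Vh_r reproduces
#   lora_up @ lora_down ≈ dW
# and storing alpha = r makes the runtime scale alpha/rank equal exactly 1.
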
def task_extract(hf_token, org, tun, rank, out):
    cleanup_temp()
    if hf_token: login(hf_token.strip())
    try:
        print("Downloading Original Model...")
        p1 = identify_and_download_model(org, hf_token)
        print("Downloading Tuned Model...")
        p2 = identify_and_download_model(tun, hf_token)

        f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)

        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
        return "Done! Extracted to " + out
    except Exception as e: return f"Error: {e}"

# =================================================================================
# TAB 3: MERGE ADAPTERS (Multi-Method)
# =================================================================================

def load_full_state_dict(path):
    """Loads a safetensors file and normalizes keys for easier processing."""
    raw = load_file(path, device="cpu")
    cleaned = {}
    for k, v in raw.items():
        # Map PEFT-style keys to the standard "lora_up/lora_down" naming
        if "lora_A" in k: new_k = k.replace("lora_A", "lora_down")
        elif "lora_B" in k: new_k = k.replace("lora_B", "lora_up")
        else: new_k = k
        cleaned[new_k] = v.float()
    return cleaned

# --- Original EMA Method ---
def sigma_rel_to_gamma(sigma_rel):
    # Solve for the gamma of a power-function EMA profile with the given relative
    # width sigma_rel (post-hoc EMA, Karras et al. 2024): the real non-negative
    # root of x^3 + 7x^2 + (16 - t)x + (12 - t) = 0 with t = sigma_rel^-2
    t = sigma_rel**-2
    coeffs = [1, 7, 16 - t, 12 - t]
    roots = np.roots(coeffs)
    gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
    return gamma

def merge_lora_iterative_ema(paths, beta, sigma_rel):
    print("Executing Iterative EMA Merge (Original Method)...")
    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float()

    gamma = None
    if sigma_rel > 0:
        gamma = sigma_rel_to_gamma(sigma_rel)

    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            # Power-function schedule: beta grows toward 1 as t increases
            t = i + 1
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta

        curr = load_file(path, device="cpu")
        for k in base_sd:
            if k in curr and "alpha" not in k:
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)
    return base_sd

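# Schedule sketch when sigma_rel is set: with gamma from the cubic above,
#   beta_1 = (1 - 1/1)^(gamma+1) = 0   -> step 1 replaces the accumulator entirely
#   beta_2 = (1 - 1/2)^(gamma+1)       -> later adapters get progressively less weight
# mirroring how a power-function EMA weights a stream of checkpoints. When
# sigma_rel is 0, the fixed user-supplied beta is used at every step instead.
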
# --- New Concatenation Method (DiffSynth) ---
def merge_lora_concatenation(adapter_states, weights):
    """
    DiffSynth-style method: concatenates ranks, so new rank = sum(ranks).
    Exact when every adapter uses alpha == rank (scale 1.0); otherwise the
    summed alpha below only approximates the per-adapter scales.
    """
    print("Executing Concatenation Merge (Rank Summation)...")
    merged_state = {}

    # Identify all stems (layers) present across all adapters
    all_stems = set()
    for state in adapter_states:
        for k in state.keys():
            stem = k.split(".lora_")[0]
            if "lora_" in k: all_stems.add(stem)

    for stem in tqdm(all_stems, desc="Concatenating Layers"):
        down_list = []
        up_list = []
        alpha_sum = 0.0

        for i, state in enumerate(adapter_states):
            w = weights[i]
            down_key = f"{stem}.lora_down.weight"
            up_key = f"{stem}.lora_up.weight"
            alpha_key = f"{stem}.alpha"

            if down_key in state and up_key in state:
                d = state[down_key]
                u = state[up_key] * w  # weighted contribution applied to the up matrix

                down_list.append(d)
                up_list.append(u)

                if alpha_key in state:
                    alpha_sum += state[alpha_key].item()
                else:
                    alpha_sum += d.shape[0]

        if down_list and up_list:
            # lora_down (A) is (rank, in): concat along dim 0 stacks the ranks.
            # lora_up (B) is (out, rank): concat along dim 1 matches that stacking.
            # Reference: DiffSynth concatenates lora_A on dim 0 and lora_B on dim 1.
            new_down = torch.cat(down_list, dim=0)  # (sum_rank, in)
            new_up = torch.cat(up_list, dim=1)      # (out, sum_rank)

            merged_state[f"{stem}.lora_down.weight"] = new_down.contiguous()
            merged_state[f"{stem}.lora_up.weight"] = new_up.contiguous()
            merged_state[f"{stem}.alpha"] = torch.tensor(alpha_sum)

    return merged_state

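# Shape sketch for two rank-16 adapters on a (3072 -> 3072) linear layer:
#   down_1, down_2: (16, 3072)  ->  new_down: (32, 3072)
#   up_1,   up_2:   (3072, 16)  ->  new_up:   (3072, 32)
# By the block structure of the concatenation,
#   new_up @ new_down == up_1 @ down_1 + up_2 @ down_2
# which is why rank summation preserves both deltas.
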
# --- New SVD/Task Arithmetic Method ---
def merge_lora_svd(adapter_states, weights, target_rank):
    """
    SVD / Task Arithmetic method:
    1. Reconstruct delta W for each adapter: dW = B @ A
    2. Sum the deltas: total dW = sum(weight_i * dW_i)
    3. SVD(total dW) -> new B, new A at target_rank
    """
    print(f"Executing SVD Merge (Target Rank: {target_rank})...")
    merged_state = {}

    all_stems = set()
    for state in adapter_states:
        for k in state.keys():
            stem = k.split(".lora_")[0]
            if "lora_" in k: all_stems.add(stem)

    for stem in tqdm(all_stems, desc="SVD Merging Layers"):
        total_delta = None
        valid_layer = False

        for i, state in enumerate(adapter_states):
            w = weights[i]
            down_key = f"{stem}.lora_down.weight"
            up_key = f"{stem}.lora_up.weight"
            alpha_key = f"{stem}.alpha"

            if down_key in state and up_key in state:
                down = state[down_key]
                up = state[up_key]
                alpha = state[alpha_key].item() if alpha_key in state else down.shape[0]
                rank = down.shape[0]

                scale = (alpha / rank) * w

                # Reconstruct the delta
                if len(down.shape) == 4:  # Conv2d
                    d_flat = down.flatten(start_dim=1)
                    u_flat = up.flatten(start_dim=1)
                    delta = (u_flat @ d_flat).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                else:
                    delta = up @ down

                delta = delta * scale

                if total_delta is None:
                    total_delta = delta
                    valid_layer = True
                else:
                    if total_delta.shape == delta.shape:
                        total_delta += delta
                    else:
                        print(f"Shape mismatch in {stem}, skipping.")

        if valid_layer and total_delta is not None:
            out_dim = total_delta.shape[0]
            in_dim = total_delta.shape[1]
            is_conv = len(total_delta.shape) == 4

            if is_conv:
                flat_delta = total_delta.flatten(start_dim=1)
            else:
                flat_delta = total_delta

            try:
                U, S, V = torch.svd_lowrank(flat_delta, q=target_rank + 4, niter=4)
                Vh = V.t()

                U = U[:, :target_rank]
                S = S[:target_rank]
                Vh = Vh[:target_rank, :]

                U = U @ torch.diag(S)

                if is_conv:
                    U = U.reshape(out_dim, target_rank, 1, 1)
                    Vh = Vh.reshape(target_rank, in_dim, total_delta.shape[2], total_delta.shape[3])
                else:
                    U = U.reshape(out_dim, target_rank)
                    Vh = Vh.reshape(target_rank, in_dim)

                merged_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
                merged_state[f"{stem}.lora_up.weight"] = U.contiguous()
                merged_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()
            except Exception as e:
                print(f"SVD Failed for {stem}: {e}")

    return merged_state

def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
    cleanup_temp()
    if hf_token: login(hf_token.strip())

    if not out_repo or not out_repo.strip():
        return "Error: Output Repo cannot be empty."

    # 1. Parse inputs (multi-line and space-separated)
    raw_lines = inputs_text.replace(" ", "\n").split('\n')
    urls = [line.strip() for line in raw_lines if line.strip()]
    if len(urls) < 2: return "Error: Please provide at least 2 adapters."

    # 2. Parse weights (for SVD/Concatenation)
    try:
        if not weight_str.strip():
            weights = [1.0] * len(urls)
        else:
            weights = [float(w.strip()) for w in weight_str.split(',')]
            # Broadcast or truncate to match the adapter count
            if len(weights) < len(urls):
                weights += [1.0] * (len(urls) - len(weights))
            else:
                weights = weights[:len(urls)]
    except ValueError:
        return "Error parsing weights. Use format: 1.0, 0.5, 0.8"

    # 3. Download all adapters, each to a unique filename so they do not
    # overwrite one another in TempDir
    paths = []
    try:
        for i, url in enumerate(tqdm(urls, desc="Downloading Adapters")):
            paths.append(download_lora_smart(url, hf_token, dest_name=f"adapter_{i}.safetensors"))
    except Exception as e: return f"Download Error: {e}"

    merged = None

    # 4. Execute the selected method
    if "Iterative EMA" in method:
        # Uses the original iterative method as-is
        merged = merge_lora_iterative_ema(paths, beta, sigma_rel)
    else:
        # The other methods load everything upfront
        states = [load_full_state_dict(p) for p in paths]

        if "Concatenation" in method:
            merged = merge_lora_concatenation(states, weights)
        elif "SVD" in method:
            merged = merge_lora_svd(states, weights, int(target_rank))

    if not merged: return "Merge failed (result empty)."

    # 5. Save & upload
    out = TempDir / "merged_adapters.safetensors"
    save_file(merged, out)

    try:
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
        return f"Success! Merged to {out_repo}"
    except Exception as e: return f"Upload Error: {e}"

# =================================================================================
# TAB 4: RESIZE (CPU Optimized)
# =================================================================================

def index_sv_cumulative(S, target):
    """Smallest rank whose singular values sum to `target` of the total."""
    original_sum = float(torch.sum(S))
    cumulative_sums = torch.cumsum(S, dim=0) / original_sum
    index = int(torch.searchsorted(cumulative_sums, target)) + 1
    index = max(1, min(index, len(S) - 1))
    return index

def index_sv_fro(S, target):
    """Smallest rank retaining `target` of the Frobenius norm (squared sum)."""
    S_squared = S.pow(2)
    S_fro_sq = float(torch.sum(S_squared))
    sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
    index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
    index = max(1, min(index, len(S) - 1))
    return index

def index_sv_ratio(S, target):
    """Number of singular values within a factor of `target` of the largest."""
    max_sv = S[0]
    min_sv = max_sv / target
    index = int(torch.sum(S > min_sv).item())
    index = max(1, min(index, len(S) - 1))
    return index

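# Worked example (sketch) with S = [10.0, 5.0, 1.0, 0.1]:
#   index_sv_ratio(S, 2)        -> 1  (only 10.0 is strictly above 10.0/2 = 5.0)
#   index_sv_cumulative(S, 0.9) -> 2  ((10+5)/16.1 ≈ 0.93 first crosses 0.90)
#   index_sv_fro(S, 0.9)        -> 2  (rank 1 keeps 100/126.01 ≈ 0.794 < 0.81 = 0.9^2
#                                      of the squared mass; rank 2 keeps ≈ 0.992)
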
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    cleanup_temp()
    if not hf_token: return "Error: Token required"
    login(hf_token.strip())

    try:
        path = download_lora_smart(lora_input, hf_token)
    except Exception as e: return f"Error: {e}"

    state = load_file(path, device="cpu")
    new_state = {}

    # Group up/down/alpha tensors by layer
    groups = {}
    for k in state:
        simple = k.split(".lora_")[0]
        if simple not in groups: groups[simple] = {}
        if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
        if "alpha" in k: groups[simple]["alpha"] = state[k]

    print(f"Resizing {len(groups)} blocks...")

    # Pre-parse user settings
    target_rank_limit = int(new_rank)
    if dynamic_method == "None": dynamic_method = None

    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()

            # 1. Merge up/down to get the full weight delta
            if len(down.shape) == 4:
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                flat = merged.flatten(1)
            else:
                merged = up @ down
                flat = merged

            # 2. Fast SVD (svd_lowrank)
            # Use the "To Rank" input as a computational hard limit plus a buffer,
            # so we never compute an expensive full SVD for massive layers.
            q_limit = target_rank_limit + 32  # wiggle room for the dynamic methods before truncation
            q = min(q_limit, min(flat.shape))

            U, S, V = torch.svd_lowrank(flat, q=q)
            Vh = V.t()

            # 3. Dynamic rank selection
            calculated_rank = target_rank_limit

            if dynamic_method == "sv_ratio":
                calculated_rank = index_sv_ratio(S, dynamic_param)
            elif dynamic_method == "sv_cumulative":
                calculated_rank = index_sv_cumulative(S, dynamic_param)
            elif dynamic_method == "sv_fro":
                calculated_rank = index_sv_fro(S, dynamic_param)

            # Apply the hard limit (the user's "To Rank")
            final_rank = min(calculated_rank, target_rank_limit, S.shape[0])

            # 4. Truncate
            U = U[:, :final_rank]
            S = S[:final_rank]
            Vh = Vh[:final_rank, :]

            # 5. Reconstruct the up matrix (absorb S into U)
            U = U @ torch.diag(S)

            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], final_rank, 1, 1)
                Vh = Vh.reshape(final_rank, down.shape[1], down.shape[2], down.shape[3])

            # 6. Save (safetensors requires a contiguous memory layout)
            new_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
            new_state[f"{stem}.lora_up.weight"] = U.contiguous()
            new_state[f"{stem}.alpha"] = torch.tensor(final_rank).float()

    out = TempDir / "shrunken.safetensors"
    save_file(new_state, out)

    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="shrunken.safetensors", repo_id=out_repo, token=hf_token)
    return f"Done! Resized adapter uploaded to {out_repo}"

# =================================================================================
# NEW TAB 5: FULL MODEL MERGER (MergeKit GUI Wrapper)
# =================================================================================

def task_full_model_merge(hf_token, model_a, model_b, method, base_model, weight_a, weight_b, density, out_repo, private):
    cleanup_temp()
    if hf_token: login(hf_token.strip())

    # Construct a valid MergeKit YAML config dynamically
    config = {
        "merge_method": method.lower(),
        "base_model": base_model if base_model else model_a,
        "models": [
            {"model": model_a, "parameters": {"weight": weight_a, "density": density}},
            {"model": model_b, "parameters": {"weight": weight_b, "density": density}}
        ],
        "dtype": "float16",
        "tokenizer_source": "base"
    }

    out_path = TempDir / "merged_model"
    try:
        execute_mergekit(config, str(out_path), hf_token)
        # Push to Hub (reuse the streaming upload logic above if sharding is needed)
        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
        api.upload_folder(folder_path=str(out_path), repo_id=out_repo, token=hf_token)
        return f"Full model merged successfully to {out_repo}"
    except Exception as e:
        return f"MergeKit Error: {e}"

# =================================================================================
# NEW TAB 6: MIXTURE OF EXPERTS (MoE Creator)
# =================================================================================

def task_create_moe(hf_token, base_model, experts_list, out_repo, private):
    cleanup_temp()
    experts = [e.strip() for e in experts_list.split(",") if e.strip()]
    config = {
        "method": "moe",
        "base_model": base_model,
        "experts": [{"source_model": exp} for exp in experts],
        "gate_mode": "cheap_embed"  # memory-efficient gate init for CPU
    }
    # Placeholder: execution would mirror Tab 5 (run mergekit on the config,
    # then upload); the config above is assembled but not yet executed.
    return "MoE Model Created (Placeholder for execution logic)"

# =================================================================================
# UI
# =================================================================================

css = ".container { max-width: 900px; margin: auto; }"

# css must be passed to gr.Blocks(), not to launch()
with gr.Blocks(css=css) as demo:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/AlekseyCalvin/Soon_Merger/resolve/main/SMerger3.png" alt="SOONmerge®"> Transform Transformers for FREE!</h1>""",
        elem_id="title",
    )
    gr.Markdown("# 🧰SOONmerge® LoRA Toolkit")

    with gr.Tabs():
        with gr.Tab("Merge to Base Model + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo", value="name/repo")
            t1_sub = gr.Textbox(label="Subfolder (Optional)", value="")
            t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Max Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Extras Source (copies configs/components/etc)", value="name/repo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)

        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned or Homologous Model")
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)

        with gr.Tab("Merge Adapters/Weights"):
            gr.Markdown("### Batch Adapter Merging")
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.TextArea(label="Adapter URLs/Repos (one per line, or space-separated)", placeholder="user/lora1\nhttps://hf.co/user/lora2.safetensors\n...")

            with gr.Row():
                t3_method = gr.Dropdown(
                    ["Iterative EMA (Linear w/ Beta/Sigma coefficient)", "Concatenation (MOE-like weights-stack)", "SVD Fusion (Task Arithmetic/Compressed)"],
                    value="Iterative EMA (Linear w/ Beta/Sigma coefficient)",
                    label="Merge Method"
                )

            with gr.Row():
                t3_weights = gr.Textbox(label="Weights (comma-separated) – for Concat/SVD", placeholder="1.0, 0.5, 0.8...")
                t3_rank = gr.Number(label="Target Rank – For SVD only", value=128, minimum=4, maximum=1024)

            with gr.Row():
                t3_beta = gr.Slider(label="Beta – for linear/post-hoc EMA", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
                t3_sigma = gr.Slider(label="Sigma Rel – for linear/post-hoc EMA", value=0.21, minimum=0.01, maximum=1.00, step=0.01)

            t3_out = gr.Textbox(label="Output Repo")
            t3_priv = gr.Checkbox(label="Private Output", value=True)
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")

            t3_btn.click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], t3_res)

        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="To Rank (Safety Ceiling)", value=8, minimum=1, maximum=512, step=1)
                t4_method = gr.Dropdown(["None", "sv_ratio", "sv_fro", "sv_cumulative"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=0.9)

            gr.Markdown(
                """
                ### 📉 Dynamic Resizing Guide
                These methods determine the best rank per layer.
                * **sv_ratio (Relative Strength):** Keeps singular values at least `1/Param` as strong as the largest one. **Param must be >= 2** (e.g. 2 = keep features at least half as strong as the top one).
                * **sv_fro (Information Density):** Preserves a `Param` fraction of the layer's total information content (Frobenius norm). **Param between 0.0 and 1.0** (e.g. 0.9 = 90% retention).
                * **sv_cumulative (Cumulative Sum):** Keeps singular values that sum to a `Param` fraction of the total strength. **Param between 0.0 and 1.0**.
                * **⚠️ Safety Ceiling:** The **"To Rank"** field is a hard limit. Even if a dynamic method wants a higher rank, it is cut down to this number to keep file sizes small.
                """
            )
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

        with gr.Tab("Full Model Merge (MergeKit)"):
            gr.Markdown("### 🧩 Advanced Model Fusion (MergeKit Engine)")
            with gr.Row():
                t5_token = gr.Textbox(label="HF Token", type="password")
                t5_method = gr.Dropdown(["Linear", "SLERP", "TIES", "DARE_TIES", "DARE_LINEAR"], value="TIES", label="Merge Method")
            with gr.Row():
                t5_model_a = gr.Textbox(label="Model A (Repo ID)")
                t5_model_b = gr.Textbox(label="Model B (Repo ID)")
            t5_base = gr.Textbox(label="Base Model (Optional)", placeholder="Required for TIES/DARE")
            with gr.Row():
                t5_weight_a = gr.Slider(0, 1, 0.5, label="Weight A")
                t5_weight_b = gr.Slider(0, 1, 0.5, label="Weight B")
                t5_density = gr.Slider(0, 1, 0.5, label="Density (TIES/DARE)")
            t5_out = gr.Textbox(label="Output Repo")
            t5_priv = gr.Checkbox(label="Private", value=True)
            t5_btn = gr.Button("Execute Full Merge")
            t5_res = gr.Textbox(label="Result")

            t5_btn.click(task_full_model_merge, [t5_token, t5_model_a, t5_model_b, t5_method, t5_base, t5_weight_a, t5_weight_b, t5_density, t5_out, t5_priv], t5_res)

        with gr.Tab("Create MoE"):
            gr.Markdown("### 🤖 Mixture of Experts Upscaling")
            t6_token = gr.Textbox(label="Token", type="password")
            t6_base = gr.Textbox(label="Base Architecture Model")
            t6_experts = gr.TextArea(label="Expert Models (Comma separated Repo IDs)")
            t6_out = gr.Textbox(label="Output Repo")
            t6_priv = gr.Checkbox(label="Private", value=True)
            t6_btn = gr.Button("Build MoE")
            t6_res = gr.Textbox(label="Result")
            # Wire the button to the (placeholder) MoE task so the tab is functional
            t6_btn.click(task_create_moe, [t6_token, t6_base, t6_experts, t6_out, t6_priv], t6_res)

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)