AlekseyCalvin committed
Commit a690cfc · verified · 1 parent: b5a221c

Create app.py

Files changed (1): app.py (+656, −0)
app.py ADDED (new file, 656 lines):
import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
import re
from pathlib import Path
from typing import Dict, Any, Optional, List
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """
    Reads safetensors metadata and tensors without mmap, keeping RAM usage low.
    """
    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype_map = {
            "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
            "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
            "U8": torch.uint8, "BOOL": torch.bool
        }
        dtype = dtype_map[metadata["dtype"]]
        shape = metadata["shape"]
        # Copy into a writable bytearray: torch.frombuffer on an immutable bytes
        # object yields a read-only tensor, which would break the in-place
        # add_() performed later during merging.
        buffer = bytearray(tensor_bytes)
        return torch.frombuffer(buffer, dtype=torch.uint8).view(dtype).reshape(shape)

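# For reference, the safetensors layout parsed above is: 8 bytes of little-endian
# uint64 header length N, then N bytes of JSON ({name: {"dtype", "shape",
# "data_offsets"}, ...}), then raw tensor data addressed by those offsets
# relative to the end of the header. A minimal usage sketch (illustrative only,
# never called by the app; the path is a hypothetical example):
def _sketch_peek_safetensors(path="some_model.safetensors"):
    with MemoryEfficientSafeOpen(path) as f:
        for k in f.keys()[:3]:
            t = f.get_tensor(k)  # reads only this tensor's byte range
            print(k, t.dtype, tuple(t.shape))
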
# --- Constants & Setup ---
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

api = HfApi()

def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()

def download_file(input_path, token, filename=None):
    local_path = TempDir / (filename if filename else "model.safetensors")
    if input_path.startswith("http"):
        print(f"Downloading {filename or input_path} from URL...")
        try:
            response = requests.get(input_path, stream=True, timeout=30)
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except Exception as e: raise ValueError(f"Download failed: {e}")
    else:
        print(f"Downloading {filename} from Hub...")
        if not filename:
            try:
                files = list_repo_files(repo_id=input_path, token=token)
                safetensors = [f for f in files if f.endswith(".safetensors")]
                filename = safetensors[0] if safetensors else "adapter_model.safetensors"
            except Exception: filename = "adapter_model.safetensors"

        try:
            hf_hub_download(repo_id=input_path, filename=filename, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
            if not (TempDir / filename).exists():
                found = list(TempDir.rglob(filename))
                if found: shutil.move(found[0], local_path)
        except Exception as e: raise ValueError(f"Hub download failed: {e}")

    return local_path

def get_key_stem(key):
    key = key.replace(".weight", "").replace(".bias", "")
    key = key.replace(".lora_down", "").replace(".lora_up", "")
    key = key.replace(".lora_A", "").replace(".lora_B", "")
    key = key.replace(".alpha", "")
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    ]
    # Strip prefixes repeatedly, since schemes can nest (e.g. "model.diffusion_model.").
    changed = True
    while changed:
        changed = False
        for p in prefixes:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key

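# Worked example (hypothetical keys): both of the following normalize to the
# same stem, which is what lets a Diffusers base key match a PEFT-style LoRA key:
#   "model.diffusion_model.blocks.0.attn.to_q.weight"    -> "blocks.0.attn.to_q"
#   "base_model.model.blocks.0.attn.to_q.lora_A.weight"  -> "blocks.0.attn.to_q"
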
# =================================================================================
# TAB 1: MERGE & RESHARD (Fixes Folder Structure & Aux Files)
# =================================================================================

def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            if stem not in pairs: pairs[stem] = {}
            if "lora_down" in k or "lora_A" in k:
                pairs[stem]["down"] = v.to(dtype=precision_dtype)
                pairs[stem]["rank"] = v.shape[0]
            elif "lora_up" in k or "lora_B" in k:
                pairs[stem]["up"] = v.to(dtype=precision_dtype)
    # If no alpha is stored, fall back to alpha == rank (i.e. a scaling of 1.0).
    for stem in pairs:
        pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
    return pairs

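# The pairs built above reconstruct the standard LoRA update
#   W' = W + scale * (alpha / rank) * (up @ down)
# so, for example, a rank-32 adapter stored with alpha = 16 contributes its
# delta at half strength before the user-facing Scale slider is applied.
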
class ShardBuffer:
    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix  # Dynamic prefix (e.g. 'diffusion_pytorch_model' or 'model')
        self.buffer = []
        self.current_bytes = 0
        self.shard_count = 0
        self.index_map = {}
        self.total_model_size = 0

    def add_tensor(self, key, tensor):
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16, so reinterpret the bits as int16 for serialization.
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        elif tensor.dtype == torch.float16:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F16"
        else:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F32"

        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape)  # plain list, so the header stays JSON-serializable
        })
        self.current_bytes += size
        self.total_model_size += size

        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        if not self.buffer: return
        self.shard_count += 1

        # ADAPTIVE NAMING: Uses the prefix detected from the base model
        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"

        # Proper Subfolder Handling
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename

        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")

        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])]
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename

        header_json = json.dumps(header).encode('utf-8')

        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])

        print(f"Uploading {path_in_repo}...")
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)

        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()

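# Sizing note: max_size_gb is a soft cap; the buffer flushes once it is reached,
# so a shard can overshoot by at most one tensor. The "{prefix}-{NNNNN}" names
# skip the usual "-of-NNNNN" suffix (the total is unknown while streaming);
# since sharded loaders resolve filenames through the index.json weight_map,
# the output should still load normally.
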
def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix="transformer"):
    """
    Copies files one-by-one from source to dest, skipping 'ignore_prefix'.
    Does NOT skip .safetensors/.bin if they are outside the ignore folder.
    """
    print(f"Scanning {src_repo} for auxiliary files...")
    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)

        for f in tqdm(files, desc="Copying Structure"):
            # 1. Skip the folder we are replacing (e.g., transformer/)
            if ignore_prefix and f.startswith(ignore_prefix):
                continue

            # 2. Skip hidden/system files
            if f.startswith("."):
                continue

            # 3. Download -> Upload -> Delete loop
            # This ensures we get VAE/TextEnc weights without disk overflow
            try:
                print(f"Copying {f}...")
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)

                api.upload_file(
                    path_or_fileobj=local,
                    path_in_repo=f,
                    repo_id=dst_repo,
                    token=token
                )

                if os.path.exists(local):
                    os.remove(local)
            except Exception as e:
                print(f"Failed to copy {f}: {e}")

    except Exception as e:
        print(f"Structure cloning error: {e}")

def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    cleanup_temp()
    login(hf_token)

    # 1. Output Setup
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e: return f"Error creating repo: {e}"

    # 2. Server-Side Structure Clone
    if structure_repo:
        ignore = base_subfolder if base_subfolder else None
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore)

    # 3. Load LoRA
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.1, desc="Downloading LoRA...")
        lora_path = download_file(lora_input, hf_token, filename="adapter.safetensors")
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e: return f"Error loading LoRA: {e}"

    # 4. Stream Process
    progress(0.2, desc="Fetching File List...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)

    # Identify valid shards in the target folder
    input_shards = []
    for f in files:
        if not f.endswith(".safetensors"): continue
        if base_subfolder and not f.startswith(base_subfolder): continue
        input_shards.append(f)

    if not input_shards: return "No base safetensors found in specified location."

    input_shards.sort()

    # --- AUTO-DETECT NAMING CONVENTION ---
    # We look at the first file to decide the naming scheme.
    # Common schemes:
    #   "diffusion_pytorch_model-00001..." -> prefix: "diffusion_pytorch_model"
    #   "model-00001..."                   -> prefix: "model"
    #   "model.safetensors"                -> prefix: "model"

    first_file = os.path.basename(input_shards[0])

    if first_file.startswith("diffusion_pytorch_model"):
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    else:
        # Default for LLMs, Text Encoders, etc.
        filename_prefix = "model"
        index_filename = "model.safetensors.index.json"

    print(f"Detected naming convention: {filename_prefix} (Index: {index_filename})")

    # Initialize Buffer with detected prefix
    buffer = ShardBuffer(shard_size, TempDir, output_repo, base_subfolder, hf_token, filename_prefix=filename_prefix)

    # The LoRA key set never changes, so build it once rather than per-tensor.
    lora_keys = set(lora_pairs.keys())

    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {shard_file}")

        local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)

        with MemoryEfficientSafeOpen(local_shard) as f:
            keys = f.keys()
            for k in keys:
                v = f.get_tensor(k)
                base_stem = get_key_stem(k)
                match = None

                # Matching Logic (Exact + Heuristic for QKV)
                if base_stem in lora_keys:
                    match = lora_pairs[base_stem]
                else:
                    if "to_q" in base_stem:
                        qkv_stem = base_stem.replace("to_q", "qkv")
                        if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
                    elif "to_k" in base_stem:
                        qkv_stem = base_stem.replace("to_k", "qkv")
                        if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
                    elif "to_v" in base_stem:
                        qkv_stem = base_stem.replace("to_v", "qkv")
                        if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]

                if match and "down" in match and "up" in match:
                    down = match["down"]
                    up = match["up"]
                    alpha = match["alpha"]
                    rank = match["rank"]
                    scaling = scale * (alpha / rank)

                    # Handle Conv 1x1 squeeze
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        down = down.unsqueeze(-1).unsqueeze(-1)
                        up = up.unsqueeze(-1).unsqueeze(-1)

                    try:
                        if len(up.shape) == 4:
                            delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                        else:
                            delta = up @ down
                    except RuntimeError:
                        delta = up.T @ down

                    delta = delta * scaling
                    valid_delta = True

                    # Shape Slicing Logic
                    if delta.shape == v.shape:
                        pass
                    elif delta.shape[0] == v.shape[0] * 3:
                        chunk = v.shape[0]
                        if "to_q" in k: delta = delta[0:chunk, ...]
                        elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
                        elif "to_v" in k: delta = delta[2*chunk:, ...]
                        else: valid_delta = False
                    elif delta.numel() == v.numel():
                        delta = delta.reshape(v.shape)
                    else:
                        valid_delta = False

                    if valid_delta:
                        v = v.to(dtype)
                        delta = delta.to(dtype)
                        v.add_(delta)
                        del delta

                if v.dtype != dtype: v = v.to(dtype)
                buffer.add_tensor(k, v)
                del v

        os.remove(local_shard)
        gc.collect()

    buffer.flush()

    # Upload Index (Using the dynamically determined index filename)
    print(f"Uploading Index: {index_filename}")

    index_data = {
        "metadata": {"total_size": buffer.total_model_size},
        "weight_map": buffer.index_map
    }

    with open(TempDir / index_filename, "w") as f:
        json.dump(index_data, f, indent=4)

    path_in_repo = f"{base_subfolder}/{index_filename}" if base_subfolder else index_filename
    api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)

    cleanup_temp()
    return f"Done! Merged into {buffer.shard_count} shards at {output_repo}"

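# Worked example of the fused-QKV slicing above (hypothetical shapes): if the
# base model stores separate to_q / to_k / to_v weights of [1024, 1024] but the
# LoRA was trained against a fused "qkv" projection, the delta is [3072, 1024];
# rows 0:1024 are applied to to_q, 1024:2048 to to_k, and 2048:3072 to to_v.
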
# =================================================================================
# TAB 2: EXTRACT LORA
# =================================================================================

def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    lora_sd = {}
    print("Calculating diffs...")
    # Context managers ensure both files are closed when extraction finishes.
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        tuned_keys = set(tuned.keys())
        for key in tqdm(org.keys()):
            if key not in tuned_keys: continue
            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()
            diff = mat_tuned - mat_org
            if torch.max(torch.abs(diff)) < 1e-4: continue

            # 1-D tensors (biases, norms) cannot be factorized into up/down pairs.
            if diff.dim() < 2: continue
            out_dim, in_dim = diff.shape[:2]
            r = min(rank, in_dim, out_dim)
            is_conv = len(diff.shape) == 4
            if is_conv: diff = diff.flatten(start_dim=1)

            try:
                U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
                U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
                U = U @ torch.diag(S)
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(dist, clamp)
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)
                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)
                stem = key.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U
                lora_sd[f"{stem}.lora_down.weight"] = Vh
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
            except RuntimeError:
                # SVD can fail to converge on ill-conditioned layers; skip them.
                pass
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)

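# The extraction above is a truncated SVD of the weight difference:
#   W_tuned - W_org ≈ U_r @ diag(S_r) @ Vh_r
# with up = U_r @ diag(S_r), down = Vh_r, and alpha = r, so the reconstructed
# delta is applied at scaling alpha / rank = 1.0. The 0.99 quantile clamp trims
# the extreme values that SVD tends to concentrate in a few entries.
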
def task_extract(hf_token, org, tun, rank, out):
    cleanup_temp()
    login(hf_token)
    try:
        p1 = download_file(org, hf_token, filename="org.safetensors")
        p2 = download_file(tun, hf_token, filename="tun.safetensors")
        f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=f, path_in_repo="extracted.safetensors", repo_id=out, token=hf_token)
        return "Done"
    except Exception as e: return f"Error: {e}"

# =================================================================================
# TAB 3: MERGE ADAPTERS (EMA) with Sigma Rel
# =================================================================================

def sigma_rel_to_gamma(sigma_rel):
    t = sigma_rel**-2
    coeffs = [1, 7, 16 - t, 12 - t]
    roots = np.roots(coeffs)
    gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
    return gamma

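# Background: this inverts the power-function EMA relation of Karras et al.
# (2024), sigma_rel^2 = (gamma + 1) / ((gamma + 2)^2 * (gamma + 3)). Writing
# t = sigma_rel^-2 and expanding gives the cubic solved above:
#   gamma^3 + 7*gamma^2 + (16 - t)*gamma + (12 - t) = 0.
# The UI default sigma_rel = 0.21 maps to gamma ≈ 1.57.
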
def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
    cleanup_temp()
    login(hf_token)

    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = []
    try:
        for i, url in enumerate(urls):
            paths.append(download_file(url, hf_token, filename=f"a_{i}.safetensors"))
    except Exception as e: return f"Download Error: {e}"

    if not paths: return "No models found"

    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float()

    gamma = None
    if sigma_rel > 0:
        gamma = sigma_rel_to_gamma(sigma_rel)

    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            # paths[0] already occupies step t=1, so this update is step t=i+2.
            # Starting at t=1 would give beta=0 and discard the first checkpoint.
            t = i + 2
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta

        curr = load_file(path, device="cpu")
        for k in base_sd:
            if k in curr and "alpha" not in k:
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)

    out = TempDir / "merged_adapters.safetensors"
    save_file(base_sd, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"

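# Illustration of the two schedules above (approximate values): with the default
# sigma_rel = 0.21 (gamma ≈ 1.57), beta_t = (1 - 1/t)**2.57 gives beta_2 ≈ 0.17,
# beta_10 ≈ 0.76, beta_100 ≈ 0.97, so early checkpoints are absorbed quickly and
# later ones blend in gently. With sigma_rel = 0, the fixed Beta slider value is
# used at every step instead.
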
# =================================================================================
# TAB 4: RESIZE
# =================================================================================

def index_sv_ratio(S, target):
    max_sv = S[0]
    min_sv = max_sv / target
    index = int(torch.sum(S > min_sv).item())
    return max(1, min(index, len(S) - 1))

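# Example: with target = 4 (the UI's "Dynamic Param" default) and singular
# values S = [10, 4, 2.4, 1, 0.2], min_sv = 10 / 4 = 2.5; two values exceed it,
# so the layer is resized to rank 2.
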
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    cleanup_temp()
    login(hf_token)
    try:
        path = download_file(lora_input, hf_token)
    except Exception as e: return f"Error: {e}"

    state = load_file(path, device="cpu")
    new_state = {}

    groups = {}
    for k in state:
        # Group down/up/alpha under one stem; ".alpha" keys carry no ".lora_"
        # segment, so strip that suffix too or they would land in their own group.
        simple = k.split(".lora_")[0].split(".alpha")[0]
        if simple not in groups: groups[simple] = {}
        if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
        if "alpha" in k: groups[simple]["alpha"] = state[k]

    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()

            # Fold the original alpha/rank scaling into the delta, since the
            # output is saved with alpha == new rank (i.e. a scaling of 1.0).
            old_rank = down.shape[0]
            old_alpha = float(g["alpha"]) if "alpha" in g else float(old_rank)
            up = up * (old_alpha / old_rank)

            if len(down.shape) == 4:
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                flat = merged.flatten(1)
            else:
                merged = up @ down
                flat = merged

            U, S, Vh = torch.linalg.svd(flat, full_matrices=False)

            target_rank = int(new_rank)
            if dynamic_method == "sv_ratio":
                target_rank = index_sv_ratio(S, dynamic_param)

            target_rank = min(target_rank, S.shape[0])

            U = U[:, :target_rank]
            S = S[:target_rank]
            U = U @ torch.diag(S)
            Vh = Vh[:target_rank, :]

            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], target_rank, 1, 1)
                Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])

            new_state[f"{stem}.lora_down.weight"] = Vh
            new_state[f"{stem}.lora_up.weight"] = U
            new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()

    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"

# =================================================================================
# UI
# =================================================================================

css = ".container { max-width: 900px; margin: auto; }"

# Custom CSS must be passed to gr.Blocks(); launch() has no css argument.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧰 SOONmerge® LoRA Toolkit")

    with gr.Tabs():
        with gr.Tab("Merge to Base + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo (Diffusers)", value="ostris/Z-Image-De-Turbo")
            t1_sub = gr.Textbox(label="Subfolder", value="transformer")
            t1_lora = gr.Textbox(label="LoRA Direct Link", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Diffusers Extras (Copies VAE/TextEnc/etc)", value="Tongyi-MAI/Z-Image-Turbo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)

        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned Model")
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)

        with gr.Tab("Merge Multiple Adapters"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.Textbox(label="URLs (comma-separated)")
            with gr.Row():
                t3_beta = gr.Slider(label="Beta", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
                t3_sigma = gr.Slider(label="Sigma Rel (Overrides Beta)", value=0.21, minimum=0.01, maximum=1.00, step=0.01)
            t3_out = gr.Textbox(label="Output Repo")
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)

        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="To Rank (Lower Only!)", value=8, minimum=1, maximum=256, step=1)
                t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=4.0)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)
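# Assumed runtime dependencies (inferred from the imports above; this commit
# does not pin them): gradio, torch, safetensors, huggingface_hub, numpy,
# requests, tqdm.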