AlekseyCalvin commited on
Commit
7c19351
·
verified ·
1 Parent(s): 1eb4a88

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +631 -0
app.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import gc
5
+ import shutil
6
+ import requests
7
+ import json
8
+ import struct
9
+ import numpy as np
10
+ import re
11
+ from pathlib import Path
12
+ from typing import Dict, Any, Optional, List
13
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
14
+ from safetensors.torch import load_file, save_file
15
+ from tqdm import tqdm
16
+
17
# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """
    Minimal streaming reader for .safetensors files.

    Reads the JSON header once, then seeks/reads individual tensors on
    demand (no mmap), so peak RAM stays near the size of a single tensor
    rather than the whole checkpoint.  Use as a context manager — or call
    close() — to release the underlying file handle.
    """

    # safetensors dtype tag -> torch dtype
    _DTYPE_MAP = {
        "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
        "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
        "U8": torch.uint8, "BOOL": torch.bool,
    }

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def close(self):
        """Explicit close for callers that do not use a `with` block."""
        self.file.close()

    def keys(self) -> list[str]:
        """Tensor names stored in the file (the metadata entry excluded)."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        """Free-form string metadata stored under the `__metadata__` key."""
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """Read a single tensor by name.  Raises KeyError if absent."""
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Data section begins after the 8-byte length field plus the header.
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # File layout: <u64 LE header length><JSON header><tensor data>
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._DTYPE_MAP[metadata["dtype"]]
        shape = metadata["shape"]
        # Zero-element tensors have an empty byte range; torch.frombuffer
        # rejects an empty buffer, so construct them directly.
        if len(tensor_bytes) == 0:
            return torch.empty(shape, dtype=dtype)
        # Copy into a writable bytearray: frombuffer on read-only bytes
        # warns and yields a tensor that must never be written in place.
        buffer = bytearray(tensor_bytes)
        return torch.frombuffer(buffer, dtype=torch.uint8).view(dtype).reshape(shape)
62
+
63
# --- Constants & Setup ---
# Prefer /tmp (fast and ephemeral on HF Spaces); fall back to a local
# directory when /tmp is unavailable.  Only filesystem errors are caught:
# the original bare `except:` would also have swallowed KeyboardInterrupt.
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)
70
+
71
# Shared Hub client used by every task; user tokens are passed per call.
api = HfApi()
72
+
73
def cleanup_temp():
    """Delete the scratch directory if present, recreate it empty, and GC."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()
78
+
79
def download_file(input_path, token, filename=None):
    """Fetch a .safetensors file into TempDir and return its local Path.

    `input_path` is either a direct HTTP(S) URL or a Hub repo id.  For a
    repo id, `filename` selects the file; when omitted, the first
    .safetensors file in the repo (or adapter_model.safetensors) is used.
    Raises ValueError on any download failure.
    """
    if input_path.startswith("http"):
        local_path = TempDir / (filename if filename else "model.safetensors")
        print(f"Downloading {input_path} from URL...")
        try:
            # Stream to disk in chunks; close the response deterministically.
            with requests.get(input_path, stream=True, timeout=30) as response:
                response.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
        except Exception as e:
            raise ValueError(f"Download failed: {e}") from e
        return local_path

    print(f"Downloading {input_path} from Hub...")
    if not filename:
        # Infer a sensible default: first safetensors file in the repo.
        try:
            files = list_repo_files(repo_id=input_path, token=token)
            safetensors = [f for f in files if f.endswith(".safetensors")]
            filename = safetensors[0] if safetensors else "adapter_model.safetensors"
        except Exception:
            filename = "adapter_model.safetensors"

    # Resolve the return path AFTER filename inference — the old code
    # resolved it first and could return a path that was never written.
    local_path = TempDir / filename
    try:
        downloaded = Path(hf_hub_download(repo_id=input_path, filename=filename,
                                          token=token, local_dir=TempDir,
                                          local_dir_use_symlinks=False))
        if downloaded != local_path:
            # hf_hub_download may nest the file (subfolder-style filenames).
            local_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(downloaded), local_path)
    except Exception as e:
        raise ValueError(f"Hub download failed: {e}") from e

    return local_path
107
+
108
def get_key_stem(key):
    """Reduce a tensor key to a bare, comparable module stem.

    Removes LoRA parameter suffixes (.weight/.bias, lora_up/lora_down,
    lora_A/lora_B, .alpha) and then repeatedly peels known framework
    prefixes so keys from different naming schemes line up.
    """
    for marker in (".weight", ".bias", ".lora_down", ".lora_up",
                   ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(marker, "")
    known_prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model.",
    )
    stripped = True
    while stripped:
        stripped = False
        for prefix in known_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped = True
    return key
125
+
126
+ # =================================================================================
127
+ # TAB 1: GREEDY STREAMING RESHARDER + SERVER-SIDE COPY
128
+ # =================================================================================
129
+
130
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """Load a LoRA checkpoint and group its tensors by module stem.

    Returns {stem: {"down": A, "up": B, "rank": r, "alpha": a}} with the
    A/B matrices cast to `precision_dtype`.  A missing alpha defaults to
    the rank, i.e. an effective scaling factor of 1.0.
    """
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for key, tensor in state_dict.items():
        stem = get_key_stem(key)
        if "alpha" in key:
            alphas[stem] = tensor.item() if isinstance(tensor, torch.Tensor) else tensor
            continue
        entry = pairs.setdefault(stem, {})
        if "lora_down" in key or "lora_A" in key:
            entry["down"] = tensor.to(dtype=precision_dtype)
            entry["rank"] = tensor.shape[0]  # rank = rows of the down matrix
        elif "lora_up" in key or "lora_B" in key:
            entry["up"] = tensor.to(dtype=precision_dtype)
    for stem, entry in pairs.items():
        entry["alpha"] = alphas.get(stem, float(entry.get("rank", 1.0)))
    return pairs
149
+
150
class ShardBuffer:
    """Accumulates serialized tensors and writes/uploads fixed-size shards.

    Tensors are converted to raw bytes on add; once the buffered size
    exceeds `max_size_gb`, the shard is written under `output_dir`,
    uploaded to `output_repo` (inside `subfolder` when given) and removed
    locally.  `index_map` records key -> shard filename for the final
    model.safetensors.index.json.
    """

    # torch dtype -> safetensors dtype tag
    _DTYPE_TAGS = {
        torch.float32: "F32", torch.float16: "F16", torch.bfloat16: "BF16",
        torch.int64: "I64", torch.int32: "I32", torch.int16: "I16",
        torch.int8: "I8", torch.uint8: "U8", torch.bool: "BOOL",
    }

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.buffer = []           # pending {key, data, dtype, shape} entries
        self.current_bytes = 0     # serialized size of pending entries
        self.shard_count = 0
        self.index_map = {}        # tensor key -> shard filename

    def add_tensor(self, key, tensor):
        """Serialize `tensor` into the buffer; flush when the shard fills up."""
        if tensor.dtype == torch.bfloat16:
            # numpy has no bf16: reinterpret as int16 for a raw byte dump.
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        else:
            raw_bytes = tensor.numpy().tobytes()
            # Record the actual dtype: the old fallback labeled everything
            # non-fp16 as F32, corrupting any integer/bool tensors.
            dtype_str = self._DTYPE_TAGS.get(tensor.dtype, "F32")

        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape),
        })
        self.current_bytes += len(raw_bytes)
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        """Write buffered tensors as one safetensors shard, upload, reset."""
        if not self.buffer:
            return
        self.shard_count += 1

        filename = f"model-{self.shard_count:05d}.safetensors"
        # Upload path includes the subfolder (fixes the garbled placeholder
        # that uploaded every shard to the same literal name); the index
        # map keeps the relative filename.
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename

        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")

        # Build the safetensors JSON header with contiguous data offsets.
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])]
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename

        header_json = json.dumps(header).encode('utf-8')

        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])

        print(f"Uploading {path_in_repo}...")
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)

        # Free local disk and RAM before the next shard accumulates.
        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()
223
+
224
def server_side_copy_structure(token, src_repo, dst_repo, ignore_prefix="transformer"):
    """
    Clone every file of src_repo into dst_repo except those under
    ignore_prefix, using server-side copies (zero local disk usage).
    Per-file failures are printed and skipped; nothing is raised.

    NOTE(review): relies on `api.copy_file`; if the installed
    huggingface_hub version lacks that method, every copy fails and is
    merely logged — `CommitOperationCopy` may be the supported route.
    Confirm against the installed hub version.
    """
    print(f"Scanning {src_repo} for structure cloning...")
    try:
        all_files = api.list_repo_files(repo_id=src_repo, token=token)
        # Keep everything except the replaced subfolder and hidden files.
        files_to_copy = [
            name for name in all_files
            if not (ignore_prefix and name.startswith(ignore_prefix))
            and not name.startswith(".")
        ]

        print(f"Found {len(files_to_copy)} files to copy (skipping {ignore_prefix})...")

        for name in tqdm(files_to_copy, desc="Server-side Copying"):
            try:
                # API copy_file is server-side
                api.copy_file(
                    repo_id=src_repo,
                    filename=name,
                    target_repo_id=dst_repo,
                    target_filename=name,
                    token=token,
                )
            except Exception as e:
                print(f"Failed to copy {name}: {e}")
    except Exception as e:
        print(f"Structure cloning failed: {e}")
262
+
263
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    """Gradio handler: stream-merge a LoRA into a sharded base model.

    Downloads one base shard at a time, applies matching LoRA deltas
    tensor-by-tensor, and re-shards the result into `output_repo` via
    ShardBuffer, so peak disk/RAM stays near one shard.  Optionally
    clones non-transformer files from `structure_repo` server-side first.
    Returns a human-readable status string (also used for error reports).

    Note: `progress=gr.Progress()` as a default argument is gradio's
    documented injection pattern, not an accidental mutable default.
    """
    cleanup_temp()
    login(hf_token)

    # 1. Output Setup
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e: return f"Error creating repo: {e}"

    # 2. Server-Side Structure Clone
    if structure_repo:
        # If we are writing to 'transformer', we ignore existing 'transformer' files in source
        ignore = base_subfolder if base_subfolder else None
        server_side_copy_structure(hf_token, structure_repo, output_repo, ignore)

    # 3. Load LoRA (entire adapter fits in RAM; only the base is streamed).
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.1, desc="Downloading LoRA...")
        lora_path = download_file(lora_input, hf_token, filename="adapter.safetensors")
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e: return f"Error loading LoRA: {e}"

    # 4. Stream Process
    progress(0.2, desc="Fetching File List...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    input_shards = [f for f in files if f.endswith(".safetensors")]
    if base_subfolder:
        input_shards = [f for f in input_shards if f.startswith(base_subfolder)]

    if not input_shards: return "No base safetensors found."

    input_shards.sort()

    # Pass base_subfolder to buffer so it knows where to put files
    buffer = ShardBuffer(shard_size, TempDir, output_repo, base_subfolder, hf_token)

    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {shard_file}")

        local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)

        with MemoryEfficientSafeOpen(local_shard) as f:
            keys = f.keys()
            for k in keys:
                v = f.get_tensor(k)
                base_stem = get_key_stem(k)
                # NOTE(review): this set is rebuilt for every tensor; it is
                # loop-invariant and could be hoisted above the key loop.
                lora_keys = set(lora_pairs.keys())
                match = None

                # Direct stem match first; otherwise try mapping split
                # q/k/v projections onto a fused "qkv" LoRA module.
                if base_stem in lora_keys:
                    match = lora_pairs[base_stem]
                else:
                    if "to_q" in base_stem:
                        qkv_stem = base_stem.replace("to_q", "qkv")
                        if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
                    elif "to_k" in base_stem:
                        qkv_stem = base_stem.replace("to_k", "qkv")
                        if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
                    elif "to_v" in base_stem:
                        qkv_stem = base_stem.replace("to_v", "qkv")
                        if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]

                if match and "down" in match and "up" in match:
                    down = match["down"]
                    up = match["up"]
                    alpha = match["alpha"]
                    rank = match["rank"]
                    # Standard LoRA scaling: delta-W = scale * (alpha/rank) * B@A
                    scaling = scale * (alpha / rank)

                    # Promote 2D LoRA factors to 1x1-conv shape when the
                    # base weight is a conv kernel.
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        down = down.unsqueeze(-1).unsqueeze(-1)
                        up = up.unsqueeze(-1).unsqueeze(-1)

                    try:
                        if len(up.shape) == 4:
                            delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                        else:
                            delta = up @ down
                    except:
                        # Fallback for transposed factor layouts.
                        delta = up.T @ down

                    delta = delta * scaling
                    valid_delta = True
                    if delta.shape == v.shape:
                        pass
                    elif delta.shape[0] == v.shape[0] * 3:
                        # Fused qkv LoRA applied to a split q/k/v base
                        # weight: slice out the matching third.
                        chunk = v.shape[0]
                        if "to_q" in k: delta = delta[0:chunk, ...]
                        elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
                        elif "to_v" in k: delta = delta[2*chunk:, ...]
                        else: valid_delta = False
                    elif delta.numel() == v.numel():
                        delta = delta.reshape(v.shape)
                    else:
                        valid_delta = False

                    if valid_delta:
                        v = v.to(dtype)
                        delta = delta.to(dtype)
                        v.add_(delta)
                        del delta

                # Ensure every output tensor uses the requested precision.
                if v.dtype != dtype: v = v.to(dtype)
                buffer.add_tensor(k, v)
                del v

        # Drop the processed base shard before downloading the next one.
        os.remove(local_shard)
        gc.collect()

    buffer.flush()

    # Upload Index
    # NOTE(review): total_size is left at 0 rather than the summed shard
    # bytes; some loaders may rely on it — confirm downstream consumers.
    print("Uploading Index...")
    index_data = {"metadata": {"total_size": 0}, "weight_map": buffer.index_map}
    index_name = "model.safetensors.index.json"
    with open(TempDir / index_name, "w") as f:
        json.dump(index_data, f, indent=4)

    path_in_repo = f"{base_subfolder}/{index_name}" if base_subfolder else index_name
    api.upload_file(path_or_fileobj=TempDir / index_name, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)

    cleanup_temp()
    return f"Done! Merged into {buffer.shard_count} shards at {output_repo}"
387
+
388
+ # =================================================================================
389
+ # TAB 2: EXTRACT LORA
390
+ # =================================================================================
391
+
392
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """SVD-extract a LoRA from the difference of two checkpoints.

    Streams one tensor at a time from both files, computes tuned - org,
    truncates its SVD to `rank`, clamps outliers at the `clamp` quantile,
    and stores the factors in kohya naming
    (<stem>.lora_up/.lora_down/.alpha).  Returns the saved file's path.
    """
    lora_sd = {}
    print("Calculating diffs...")
    # Context managers guarantee both file handles close (they leaked before).
    with MemoryEfficientSafeOpen(model_org) as org, MemoryEfficientSafeOpen(model_tuned) as tuned:
        tuned_keys = set(tuned.keys())  # hoisted: keys() rebuilds a list per call
        for key in tqdm(org.keys()):
            if key not in tuned_keys:
                continue
            mat_org = org.get_tensor(key).float()
            mat_tuned = tuned.get_tensor(key).float()
            diff = mat_tuned - mat_org
            # Skip layers the finetune left effectively untouched.
            if torch.max(torch.abs(diff)) < 1e-4:
                continue

            out_dim, in_dim = diff.shape[:2]
            r = min(rank, in_dim, out_dim)
            is_conv = len(diff.shape) == 4
            if is_conv:
                diff = diff.flatten(start_dim=1)

            try:
                U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
                U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
                U = U @ torch.diag(S)  # fold singular values into the up matrix
                # Clip extreme factor values at the requested quantile to
                # keep the extracted weights numerically tame.
                dist = torch.cat([U.flatten(), Vh.flatten()])
                hi_val = torch.quantile(dist, clamp)
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)
                if is_conv:
                    U = U.reshape(out_dim, r, 1, 1)
                    Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
                else:
                    U = U.reshape(out_dim, r)
                    Vh = Vh.reshape(r, in_dim)
                stem = key.replace(".weight", "")
                lora_sd[f"{stem}.lora_up.weight"] = U
                lora_sd[f"{stem}.lora_down.weight"] = Vh
                lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
            except Exception as e:
                # Keep going on a bad layer, but surface the failure
                # instead of the old silent `except: pass`.
                print(f"SVD extraction failed for {key}: {e}")
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
431
+
432
def task_extract(hf_token, org, tun, rank, out, clamp=0.99):
    """Gradio handler: download two checkpoints, extract a LoRA, upload it.

    `clamp` generalizes the previously hard-coded 0.99 quantile used to
    clip extracted factors; the default preserves existing behavior.
    Returns "Done" or an error string (never raises — Gradio displays it).
    """
    cleanup_temp()
    login(hf_token)
    try:
        p1 = download_file(org, hf_token, filename="org.safetensors")
        p2 = download_file(tun, hf_token, filename="tun.safetensors")
        extracted = extract_lora_layer_by_layer(p1, p2, int(rank), clamp)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(path_or_fileobj=extracted, path_in_repo="extracted.safetensors", repo_id=out, token=hf_token)
        return "Done"
    except Exception as e:
        return f"Error: {e}"
443
+
444
+ # =================================================================================
445
+ # TAB 3: MERGE ADAPTERS (EMA) with Sigma Rel
446
+ # =================================================================================
447
+
448
def sigma_rel_to_gamma(sigma_rel):
    """Convert a relative EMA width sigma_rel into a power-EMA exponent.

    Solves g^3 + 7g^2 + (16 - t)g + (12 - t) = 0 with t = sigma_rel**-2
    and returns the largest non-negative real root (post-hoc EMA
    reconstruction, Karras et al. 2024).
    """
    t = sigma_rel ** -2
    roots = np.roots([1.0, 7.0, 16.0 - t, 12.0 - t])
    candidates = roots[np.isreal(roots) & (roots.real >= 0)].real
    return candidates.max()
454
+
455
def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
    """Gradio handler: EMA-merge several LoRA adapters into one file.

    `lora_urls` is a comma-separated list ordered oldest -> newest.  When
    sigma_rel > 0, a power-EMA schedule (via sigma_rel_to_gamma) derives
    a per-step decay; otherwise the fixed `beta` is used:
        merged = merged * beta + next * (1 - beta)
    Alpha tensors are kept from the first adapter.  Returns a status
    string ("Done" or an error description).
    """
    cleanup_temp()
    login(hf_token)

    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = []
    try:
        for i, url in enumerate(urls):
            paths.append(download_file(url, hf_token, filename=f"a_{i}.safetensors"))
    except Exception as e:
        return f"Download Error: {e}"

    if not paths:
        return "No models found"

    # Input order is trusted as oldest-to-newest (kohya sorts by mtime;
    # the UI asks the user to list adapters in that order instead).
    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point:
            base_sd[k] = base_sd[k].float()  # accumulate in fp32 for stability

    gamma = None
    if sigma_rel > 0:
        gamma = sigma_rel_to_gamma(sigma_rel)  # sigma_rel overrides fixed beta

    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")

        # Per-step decay: power-EMA schedule when gamma is set, else fixed.
        if gamma is not None:
            t = i + 1
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta

        curr = load_file(path, device="cpu")
        for k in base_sd:
            # Alphas are metadata, not weights — keep the base's values.
            if k in curr and "alpha" not in k:
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)

    out = TempDir / "merged_adapters.safetensors"
    save_file(base_sd, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
501
+
502
+ # =================================================================================
503
+ # TAB 4: RESIZE
504
+ # =================================================================================
505
+
506
def index_sv_ratio(S, target):
    """Count singular values within a `target` ratio of the largest.

    Given a descending spectrum S, returns how many entries exceed
    S[0] / target, clamped to the range [1, len(S) - 1].
    """
    threshold = S[0] / target
    kept = int((S > threshold).sum().item())
    return max(1, min(kept, len(S) - 1))
511
+
512
def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """Gradio handler: change the rank of a LoRA via SVD truncation.

    Each lora_up/lora_down pair is multiplied back into a delta-W,
    re-decomposed, and truncated to `new_rank` — or, when dynamic_method
    is "sv_ratio", to the number of singular values within
    `dynamic_param` of the largest.  Uploads resized.safetensors to
    `out_repo` and returns a status string.
    """
    cleanup_temp()
    login(hf_token)
    try:
        path = download_file(lora_input, hf_token)
    except Exception as e:
        return f"Error: {e}"

    state = load_file(path, device="cpu")
    new_state = {}

    # Group down/up/alpha tensors by module stem.
    groups = {}
    for k in state:
        # .alpha keys need their own split: the old `split(".lora_")`
        # left ".alpha" attached, so alphas never joined their pair.
        if ".alpha" in k:
            simple = k.split(".alpha")[0]
        else:
            simple = k.split(".lora_")[0]
        if simple not in groups:
            groups[simple] = {}
        if "lora_down" in k or "lora_A" in k:
            groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k:
            groups[simple]["up"] = state[k]
        if "alpha" in k:
            groups[simple]["alpha"] = state[k]

    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()

            # Reconstruct the full delta-W, flattening conv kernels to 2D
            # for the SVD.
            if len(down.shape) == 4:
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                flat = merged.flatten(1)
            else:
                merged = up @ down
                flat = merged

            U, S, Vh = torch.linalg.svd(flat, full_matrices=False)

            # Rank selection: fixed rank, or dynamic from the spectrum.
            target_rank = int(new_rank)
            if dynamic_method == "sv_ratio":
                target_rank = index_sv_ratio(S, dynamic_param)

            target_rank = min(target_rank, S.shape[0])

            U = U[:, :target_rank]
            S = S[:target_rank]
            U = U @ torch.diag(S)  # fold singular values into the up matrix
            Vh = Vh[:target_rank, :]

            if len(down.shape) == 4:
                U = U.reshape(up.shape[0], target_rank, 1, 1)
                Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])

            new_state[f"{stem}.lora_down.weight"] = Vh
            new_state[f"{stem}.lora_up.weight"] = U
            # New alpha = new rank, i.e. an effective scaling of 1.0.
            new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()

    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
570
+
571
+ # =================================================================================
572
+ # UI
573
+ # =================================================================================
574
+
575
# Stylesheet must be passed to gr.Blocks — gradio's launch() accepts no
# css argument, so the old `launch(css=css, ...)` raised a TypeError.
css = ".container { max-width: 900px; margin: auto; }"

with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🧰 SOONmerge® LoRA Toolkit")

    with gr.Tabs():
        # Tab 1: stream-merge a LoRA into a sharded base model.
        with gr.Tab("Merge + Reshard"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo", value="ostris/Z-Image-De-Turbo")
            t1_sub = gr.Textbox(label="Subfolder", value="transformer")
            t1_lora = gr.Textbox(label="LoRA")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.5, maximum=10.0, step=0.5)
            t1_out = gr.Textbox(label="Output")
            t1_struct = gr.Textbox(label="Structure Repo (Copies VAE/TextEnc)", value="Tongyi-MAI/Z-Image-Turbo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)

        # Tab 2: SVD-extract a LoRA from two checkpoints.
        with gr.Tab("Extract"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original")
            t2_tun = gr.Textbox(label="Tuned")
            t2_rank = gr.Number(label="Rank", value=32)
            t2_out = gr.Textbox(label="Output")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)

        # Tab 3: EMA-average multiple adapters.
        with gr.Tab("Merge Adapters (EMA)"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.Textbox(label="URLs")
            with gr.Row():
                t3_beta = gr.Slider(label="Beta", value=0.9)
                t3_sigma = gr.Slider(label="Sigma Rel (Overrides Beta)", value=0.0)
            t3_out = gr.Textbox(label="Output")
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)

        # Tab 4: change a LoRA's rank via SVD truncation.
        with gr.Tab("Resize"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="Rank", value=8)
                t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=4.0)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)