AlekseyCalvin committed
Commit e134498 · verified · 1 Parent(s): 9312d5d

Create app.py

Files changed (1): app.py (+623, −0)
app.py ADDED
@@ -0,0 +1,623 @@
import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
import re
from pathlib import Path
from typing import Dict, Any, Optional, List
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """
    Reads safetensors metadata and tensors without mmap, keeping RAM usage low.
    """
    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Tensor data begins after the 8-byte header-size field and the JSON header.
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype_map = {
            "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
            "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
            "U8": torch.uint8, "BOOL": torch.bool,
        }
        dtype = dtype_map[metadata["dtype"]]
        shape = metadata["shape"]
        # bytearray gives torch.frombuffer a writable buffer, avoiding the
        # warning it raises for read-only bytes objects.
        return torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8).view(dtype).reshape(shape)

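# Illustrative sketch, not called by the app: round-trips one tensor through the
# reader above. The file layout it relies on is
#   [8-byte little-endian header length][JSON header][raw tensor bytes]
# The /tmp path is a hypothetical example location.
def _demo_safe_open():
    demo_path = "/tmp/demo.safetensors"
    save_file({"w": torch.ones(2, 2)}, demo_path)
    with MemoryEfficientSafeOpen(demo_path) as f:
        assert f.keys() == ["w"]
        assert torch.equal(f.get_tensor("w"), torch.ones(2, 2))
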
# --- Constants & Setup ---
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    # /tmp may be unwritable (e.g. some hosted environments); fall back locally.
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

api = HfApi()

def cleanup_temp():
    if TempDir.exists():
        shutil.rmtree(TempDir)
    os.makedirs(TempDir, exist_ok=True)
    gc.collect()

def download_file(input_path, token, filename=None):
    local_path = TempDir / (filename if filename else "model.safetensors")
    if input_path.startswith("http"):
        print(f"Downloading {filename} from URL...")
        try:
            response = requests.get(input_path, stream=True, timeout=30)
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except Exception as e:
            raise ValueError(f"Download failed: {e}")
    else:
        print(f"Downloading {filename} from Hub...")
        if not filename:
            try:
                files = list_repo_files(repo_id=input_path, token=token)
                safetensors = [f for f in files if f.endswith(".safetensors")]
                filename = safetensors[0] if safetensors else "adapter_model.safetensors"
            except Exception:
                filename = "adapter_model.safetensors"

        try:
            hf_hub_download(repo_id=input_path, filename=filename, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
            # Normalize wherever the file landed (exact name, or nested in a repo
            # subfolder) to local_path, so callers always get an existing path.
            downloaded = TempDir / filename
            if downloaded.exists():
                if downloaded != local_path:
                    shutil.move(downloaded, local_path)
            else:
                found = list(TempDir.rglob(Path(filename).name))
                if found:
                    shutil.move(found[0], local_path)
        except Exception as e:
            raise ValueError(f"Hub download failed: {e}")

    return local_path

def get_key_stem(key):
    key = key.replace(".weight", "").replace(".bias", "")
    key = key.replace(".lora_down", "").replace(".lora_up", "")
    key = key.replace(".lora_A", "").replace(".lora_B", "")
    key = key.replace(".alpha", "")
    prefixes = [
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    ]
    # Strip prefixes repeatedly, since they can be nested (e.g. "model." under "base_model.model.").
    changed = True
    while changed:
        changed = False
        for p in prefixes:
            if key.startswith(p):
                key = key[len(p):]
                changed = True
    return key

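# For example (illustrative keys, not tied to any particular checkpoint):
#   get_key_stem("transformer.blocks.0.attn.to_q.weight")
#       -> "blocks.0.attn.to_q"
#   get_key_stem("base_model.model.blocks.0.attn.to_q.lora_A.weight")
#       -> "blocks.0.attn.to_q"
# so base-model keys and adapter keys collapse to a shared stem for matching.
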
# =================================================================================
# TAB 1: MERGE & RESHARD (Fixes Folder Structure & Aux Files)
# =================================================================================

def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        else:
            if stem not in pairs:
                pairs[stem] = {}
            if "lora_down" in k or "lora_A" in k:
                pairs[stem]["down"] = v.to(dtype=precision_dtype)
                pairs[stem]["rank"] = v.shape[0]
            elif "lora_up" in k or "lora_B" in k:
                pairs[stem]["up"] = v.to(dtype=precision_dtype)
    for stem in pairs:
        # A missing alpha defaults to the rank, which makes the alpha/rank scaling a no-op.
        pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
    return pairs

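# Sketch of the merge math applied below (illustrative shapes, not called by the
# app): a LoRA pair contributes delta_W = scale * (alpha / rank) * (up @ down),
# added elementwise onto the matching base weight.
def _demo_lora_delta():
    rank, d_in, d_out, scale, alpha = 4, 16, 8, 1.0, 4.0
    down = torch.randn(rank, d_in)   # "lora_down" / "lora_A"
    up = torch.randn(d_out, rank)    # "lora_up" / "lora_B"
    base = torch.randn(d_out, d_in)  # base model weight
    delta = scale * (alpha / rank) * (up @ down)
    return base + delta              # merged weight, same shape as base
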
class ShardBuffer:
    """Accumulates tensors in RAM, then writes and uploads a shard each time max_size_gb is reached."""
    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.buffer = []
        self.current_bytes = 0
        self.total_bytes = 0  # running total across all shards, for the index metadata
        self.shard_count = 0
        self.index_map = {}

    def add_tensor(self, key, tensor):
        # bfloat16 has no native numpy dtype, so reinterpret it as int16 for raw export.
        if tensor.dtype == torch.bfloat16:
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
            dtype_str = "BF16"
        elif tensor.dtype == torch.float16:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F16"
        else:
            raw_bytes = tensor.numpy().tobytes()
            dtype_str = "F32"

        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape)
        })
        self.current_bytes += size
        self.total_bytes += size
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        if not self.buffer:
            return
        self.shard_count += 1

        # Naming: model-00001-of-0000X.safetensors is ideal but we don't know the total count yet.
        # We use model-00001.safetensors. The Diffusers index.json handles the mapping, so the name is flexible.
        filename = f"model-{self.shard_count:05d}.safetensors"

        # Proper subfolder handling
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename

        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")

        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])]
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename  # Index map uses the relative filename (no subfolder prefix)

        header_json = json.dumps(header).encode('utf-8')

        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])

        print(f"Uploading {path_in_repo}...")
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)

        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()

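# Sketch, not called by the app: flush() above hand-writes the safetensors byte
# layout from pre-serialized bytes instead of calling save_file. This verifies
# that layout against the reader class. The /tmp path is hypothetical.
def _demo_manual_safetensors_write():
    t = torch.arange(4, dtype=torch.float32).reshape(2, 2)
    header = {
        "__metadata__": {"format": "pt"},
        "t": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]},  # 4 floats * 4 bytes
    }
    header_json = json.dumps(header).encode("utf-8")
    path = "/tmp/demo_manual.safetensors"
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_json)))
        f.write(header_json)
        f.write(t.numpy().tobytes())
    with MemoryEfficientSafeOpen(path) as r:
        assert torch.equal(r.get_tensor("t"), t)
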
def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix="transformer"):
    """
    Copies files one-by-one from source to dest, skipping 'ignore_prefix'.
    Does NOT skip .safetensors/.bin if they are outside the ignore folder.
    """
    print(f"Scanning {src_repo} for auxiliary files...")
    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)

        for f in tqdm(files, desc="Copying Structure"):
            # 1. Skip the folder we are replacing (e.g., transformer/)
            if ignore_prefix and f.startswith(ignore_prefix):
                continue

            # 2. Skip hidden/system files
            if f.startswith("."):
                continue

            # 3. Download -> Upload -> Delete loop
            # This ensures we get VAE/TextEnc weights without disk overflow
            try:
                print(f"Copying {f}...")
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)

                api.upload_file(
                    path_or_fileobj=local,
                    path_in_repo=f,
                    repo_id=dst_repo,
                    token=token
                )

                if os.path.exists(local):
                    os.remove(local)
            except Exception as e:
                print(f"Failed to copy {f}: {e}")

    except Exception as e:
        print(f"Structure cloning error: {e}")

def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    cleanup_temp()
    login(hf_token)

    # 1. Output Setup
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e:
        return f"Error creating repo: {e}"

    # 2. Structure Clone
    if structure_repo:
        # Ignore the folder we are about to fill with new weights
        ignore = base_subfolder if base_subfolder else None
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore)

    # 3. Load LoRA
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.1, desc="Downloading LoRA...")
        lora_path = download_file(lora_input, hf_token, filename="adapter.safetensors")
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e:
        return f"Error loading LoRA: {e}"

    # 4. Stream Process
    progress(0.2, desc="Fetching File List...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    input_shards = [f for f in files if f.endswith(".safetensors")]
    if base_subfolder:
        input_shards = [f for f in input_shards if f.startswith(base_subfolder)]

    if not input_shards:
        return "No base safetensors found."

    input_shards.sort()

    # Buffer handles the logic of accumulation and uploading to the correct subfolder
    buffer = ShardBuffer(shard_size, TempDir, output_repo, base_subfolder, hf_token)
    # Membership set, computed once rather than per tensor.
    lora_keys = set(lora_pairs.keys())

    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {shard_file}")

        local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)

        with MemoryEfficientSafeOpen(local_shard) as f:
            for k in f.keys():
                v = f.get_tensor(k)
                base_stem = get_key_stem(k)
                match = None

                if base_stem in lora_keys:
                    match = lora_pairs[base_stem]
                else:
                    # Some trainers store attention LoRAs against a fused "qkv" projection;
                    # map the base model's split q/k/v keys onto that fused stem.
                    for part in ("to_q", "to_k", "to_v"):
                        if part in base_stem:
                            qkv_stem = base_stem.replace(part, "qkv")
                            if qkv_stem in lora_keys:
                                match = lora_pairs[qkv_stem]
                            break

                if match and "down" in match and "up" in match:
                    down = match["down"]
                    up = match["up"]
                    alpha = match["alpha"]
                    rank = match["rank"]
                    scaling = scale * (alpha / rank)

                    # 1x1-conv LoRAs are stored as 2D matrices; lift them to 4D.
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        down = down.unsqueeze(-1).unsqueeze(-1)
                        up = up.unsqueeze(-1).unsqueeze(-1)

                    try:
                        if len(up.shape) == 4:
                            delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                        else:
                            delta = up @ down
                    except RuntimeError:
                        delta = up.T @ down

                    delta = delta * scaling
                    valid_delta = True
                    if delta.shape == v.shape:
                        pass
                    elif delta.shape[0] == v.shape[0] * 3:
                        # Fused qkv delta: slice out the third matching this split projection.
                        chunk = v.shape[0]
                        if "to_q" in k:
                            delta = delta[0:chunk, ...]
                        elif "to_k" in k:
                            delta = delta[chunk:2 * chunk, ...]
                        elif "to_v" in k:
                            delta = delta[2 * chunk:, ...]
                        else:
                            valid_delta = False
                    elif delta.numel() == v.numel():
                        delta = delta.reshape(v.shape)
                    else:
                        valid_delta = False

                    if valid_delta:
                        v = v.to(dtype)
                        delta = delta.to(dtype)
                        v.add_(delta)
                        del delta

                if v.dtype != dtype:
                    v = v.to(dtype)
                buffer.add_tensor(k, v)
                del v

        os.remove(local_shard)
        gc.collect()

    buffer.flush()

    # Upload Index
    print("Uploading Index...")
    index_data = {"metadata": {"total_size": buffer.total_bytes}, "weight_map": buffer.index_map}
    index_name = "model.safetensors.index.json"
    with open(TempDir / index_name, "w") as f:
        json.dump(index_data, f, indent=4)

    path_in_repo = f"{base_subfolder}/{index_name}" if base_subfolder else index_name
    api.upload_file(path_or_fileobj=TempDir / index_name, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)

    cleanup_temp()
    return f"Done! Merged into {buffer.shard_count} shards at {output_repo}"

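# The uploaded index follows the standard sharded-checkpoint shape, e.g.
# (illustrative values):
# {
#     "metadata": {"total_size": 9663676416},
#     "weight_map": {"blocks.0.attn.to_q.weight": "model-00001.safetensors", ...}
# }
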
389
+ # =================================================================================
390
+ # TAB 2: EXTRACT LORA
391
+ # =================================================================================
392
+
393
+ def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
394
+ org = MemoryEfficientSafeOpen(model_org)
395
+ tuned = MemoryEfficientSafeOpen(model_tuned)
396
+ lora_sd = {}
397
+ print("Calculating diffs...")
398
+ for key in tqdm(org.keys()):
399
+ if key not in tuned.keys(): continue
400
+ mat_org = org.get_tensor(key).float()
401
+ mat_tuned = tuned.get_tensor(key).float()
402
+ diff = mat_tuned - mat_org
403
+ if torch.max(torch.abs(diff)) < 1e-4: continue
404
+
405
+ out_dim, in_dim = diff.shape[:2]
406
+ r = min(rank, in_dim, out_dim)
407
+ is_conv = len(diff.shape) == 4
408
+ if is_conv: diff = diff.flatten(start_dim=1)
409
+
410
+ try:
411
+ U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
412
+ U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
413
+ U = U @ torch.diag(S)
414
+ dist = torch.cat([U.flatten(), Vh.flatten()])
415
+ hi_val = torch.quantile(dist, clamp)
416
+ U = U.clamp(-hi_val, hi_val)
417
+ Vh = Vh.clamp(-hi_val, hi_val)
418
+ if is_conv:
419
+ U = U.reshape(out_dim, r, 1, 1)
420
+ Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
421
+ else:
422
+ U = U.reshape(out_dim, r)
423
+ Vh = Vh.reshape(r, in_dim)
424
+ stem = key.replace(".weight", "")
425
+ lora_sd[f"{stem}.lora_up.weight"] = U
426
+ lora_sd[f"{stem}.lora_down.weight"] = Vh
427
+ lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
428
+ except: pass
429
+ out = TempDir / "extracted.safetensors"
430
+ save_file(lora_sd, out)
431
+ return str(out)
432
+
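# Sketch of the extraction math above (not called by the app): SVD gives the best
# rank-r approximation of the weight diff, which factors into an up/down pair.
def _demo_svd_extract(r=4):
    diff = torch.randn(64, 32)  # pretend weight delta between tuned and original
    U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
    up = U[:, :r] @ torch.diag(S[:r])   # (64, r) "lora_up"
    down = Vh[:r, :]                    # (r, 32) "lora_down"
    # up @ down is the closest rank-r matrix to diff in the Frobenius norm.
    return torch.linalg.matrix_norm(diff - up @ down)
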
433
+ def task_extract(hf_token, org, tun, rank, out):
434
+ cleanup_temp()
435
+ login(hf_token)
436
+ try:
437
+ p1 = download_file(org, hf_token, filename="org.safetensors")
438
+ p2 = download_file(tun, hf_token, filename="tun.safetensors")
439
+ f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
440
+ api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
441
+ api.upload_file(path_or_fileobj=f, path_in_repo="extracted.safetensors", repo_id=out, token=hf_token)
442
+ return "Done"
443
+ except Exception as e: return f"Error: {e}"
444
+
445
+ # =================================================================================
446
+ # TAB 3: MERGE ADAPTERS (EMA) with Sigma Rel
447
+ # =================================================================================
448
+
449
+ def sigma_rel_to_gamma(sigma_rel):
450
+ t = sigma_rel**-2
451
+ coeffs = [1, 7, 16 - t, 12 - t]
452
+ roots = np.roots(coeffs)
453
+ gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
454
+ return gamma
455
+
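# Sketch, not called by the app: the gamma returned above drives the per-step EMA
# decay used in task_merge_adapters below, beta_t = (1 - 1/t)^(gamma + 1). With t
# starting at 1, the first merge replaces the base entirely (beta_1 = 0), and later
# checkpoints are blended in progressively more gently. Values are illustrative.
def _demo_power_ema_schedule(sigma_rel=0.2, steps=5):
    gamma = sigma_rel_to_gamma(sigma_rel)
    return [(1 - 1 / t) ** (gamma + 1) for t in range(1, steps + 1)]
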
456
+ def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
457
+ cleanup_temp()
458
+ login(hf_token)
459
+
460
+ urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
461
+ paths = []
462
+ try:
463
+ for i, url in enumerate(urls):
464
+ paths.append(download_file(url, hf_token, filename=f"a_{i}.safetensors"))
465
+ except Exception as e: return f"Download Error: {e}"
466
+
467
+ if not paths: return "No models found"
468
+
469
+ base_sd = load_file(paths[0], device="cpu")
470
+ for k in base_sd:
471
+ if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float()
472
+
473
+ gamma = None
474
+ if sigma_rel > 0:
475
+ gamma = sigma_rel_to_gamma(sigma_rel)
476
+
477
+ for i, path in enumerate(paths[1:]):
478
+ print(f"Merging {path}")
479
+ if gamma is not None:
480
+ t = i + 1
481
+ current_beta = (1 - 1 / t) ** (gamma + 1)
482
+ else:
483
+ current_beta = beta
484
+
485
+ curr = load_file(path, device="cpu")
486
+ for k in base_sd:
487
+ if k in curr and "alpha" not in k:
488
+ base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)
489
+
490
+ out = TempDir / "merged_adapters.safetensors"
491
+ save_file(base_sd, out)
492
+ api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
493
+ api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
494
+ return "Done"
495
+
496
+ # =================================================================================
497
+ # TAB 4: RESIZE
498
+ # =================================================================================
499
+
500
+ def index_sv_ratio(S, target):
501
+ max_sv = S[0]
502
+ min_sv = max_sv / target
503
+ index = int(torch.sum(S > min_sv).item())
504
+ return max(1, min(index, len(S) - 1))
505
+
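# For example (illustrative numbers): with S = [10.0, 4.0, 1.0, 0.1] and
# target = 4.0, singular values above 10.0 / 4.0 = 2.5 are kept, so the
# dynamic rank comes out as 2.
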
506
+ def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
507
+ cleanup_temp()
508
+ login(hf_token)
509
+ try:
510
+ path = download_file(lora_input, hf_token)
511
+ except Exception as e: return f"Error: {e}"
512
+
513
+ state = load_file(path, device="cpu")
514
+ new_state = {}
515
+
516
+ groups = {}
517
+ for k in state:
518
+ stem = get_key_stem(k)
519
+ simple = k.split(".lora_")[0]
520
+ if simple not in groups: groups[simple] = {}
521
+ if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
522
+ if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
523
+ if "alpha" in k: groups[simple]["alpha"] = state[k]
524
+
525
+ for stem, g in tqdm(groups.items()):
526
+ if "down" in g and "up" in g:
527
+ down, up = g["down"].float(), g["up"].float()
528
+
529
+ if len(down.shape) == 4:
530
+ merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
531
+ flat = merged.flatten(1)
532
+ else:
533
+ merged = up @ down
534
+ flat = merged
535
+
536
+ U, S, Vh = torch.linalg.svd(flat, full_matrices=False)
537
+
538
+ target_rank = int(new_rank)
539
+ if dynamic_method == "sv_ratio":
540
+ target_rank = index_sv_ratio(S, dynamic_param)
541
+
542
+ target_rank = min(target_rank, S.shape[0])
543
+
544
+ U = U[:, :target_rank]
545
+ S = S[:target_rank]
546
+ U = U @ torch.diag(S)
547
+ Vh = Vh[:target_rank, :]
548
+
549
+ if len(down.shape) == 4:
550
+ U = U.reshape(up.shape[0], target_rank, 1, 1)
551
+ Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])
552
+
553
+ new_state[f"{stem}.lora_down.weight"] = Vh
554
+ new_state[f"{stem}.lora_up.weight"] = U
555
+ new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()
556
+
557
+ out = TempDir / "resized.safetensors"
558
+ save_file(new_state, out)
559
+ api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
560
+ api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
561
+ return "Done"
562
+
563
+ # =================================================================================
564
+ # UI
565
+ # =================================================================================
566
+
567
+ css = ".container { max-width: 900px; margin: auto; }"
568
+
569
+ with gr.Blocks() as demo:
570
+ gr.Markdown("# 🧰 Universal LoRA Toolkit V14 (Complete)")
571
+
572
+ with gr.Tabs():
573
+ with gr.Tab("Merge + Reshard"):
574
+ t1_token = gr.Textbox(label="Token", type="password")
575
+ t1_base = gr.Textbox(label="Base Repo", value="ostris/Z-Image-De-Turbo")
576
+ t1_sub = gr.Textbox(label="Subfolder", value="transformer")
577
+ t1_lora = gr.Textbox(label="LoRA")
578
+ with gr.Row():
579
+ t1_scale = gr.Slider(label="Scale", value=1.0)
580
+ t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
581
+ t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.5, maximum=10.0, step=0.5)
582
+ t1_out = gr.Textbox(label="Output")
583
+ t1_struct = gr.Textbox(label="Structure Repo (Copies VAE/TextEnc)", value="Tongyi-MAI/Z-Image-Turbo")
584
+ t1_priv = gr.Checkbox(label="Private", value=True)
585
+ t1_btn = gr.Button("Merge")
586
+ t1_res = gr.Textbox(label="Result")
587
+ t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)
588
+
589
+ with gr.Tab("Extract"):
590
+ t2_token = gr.Textbox(label="Token", type="password")
591
+ t2_org = gr.Textbox(label="Original")
592
+ t2_tun = gr.Textbox(label="Tuned")
593
+ t2_rank = gr.Number(label="Rank", value=32)
594
+ t2_out = gr.Textbox(label="Output")
595
+ t2_btn = gr.Button("Extract")
596
+ t2_res = gr.Textbox(label="Result")
597
+ t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
598
+
599
+ with gr.Tab("Merge Adapters (EMA)"):
600
+ t3_token = gr.Textbox(label="Token", type="password")
601
+ t3_urls = gr.Textbox(label="URLs")
602
+ with gr.Row():
603
+ t3_beta = gr.Slider(label="Beta", value=0.9)
604
+ t3_sigma = gr.Slider(label="Sigma Rel (Overrides Beta)", value=0.0)
605
+ t3_out = gr.Textbox(label="Output")
606
+ t3_btn = gr.Button("Merge")
607
+ t3_res = gr.Textbox(label="Result")
608
+ t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)
609
+
610
+ with gr.Tab("Resize"):
611
+ t4_token = gr.Textbox(label="Token", type="password")
612
+ t4_in = gr.Textbox(label="LoRA")
613
+ with gr.Row():
614
+ t4_rank = gr.Number(label="Rank", value=8)
615
+ t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
616
+ t4_param = gr.Number(label="Dynamic Param", value=4.0)
617
+ t4_out = gr.Textbox(label="Output")
618
+ t4_btn = gr.Button("Resize")
619
+ t4_res = gr.Textbox(label="Result")
620
+ t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)
621
+
622
+ if __name__ == "__main__":
623
+ demo.queue().launch(css=css, ssr_mode=False)