AlekseyCalvin commited on
Commit
e455bfb
·
verified ·
1 Parent(s): 0479976

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +581 -0
app.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import gc
5
+ import shutil
6
+ import requests
7
+ import json
8
+ import struct
9
+ import numpy as np
10
+ import re
11
+ from pathlib import Path
12
+ from typing import Dict, Any, Optional, List
13
+ from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
14
+ from safetensors.torch import load_file, save_file
15
+ from tqdm import tqdm
16
+
17
# --- Memory Efficient Safetensors ---
class MemoryEfficientSafeOpen:
    """Minimal streaming reader for .safetensors files.

    Parses the JSON header once on open, then materializes individual
    tensors on demand via seek/read, so only one tensor's bytes live in
    memory at a time. Use as a context manager to ensure the file handle
    is closed.
    """

    # safetensors dtype tag -> torch dtype (hoisted: built once, not per tensor)
    _DTYPE_MAP = {
        "F64": torch.float64,
        "F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16,
        "I64": torch.int64, "I32": torch.int32, "I16": torch.int16, "I8": torch.int8,
        "U8": torch.uint8, "BOOL": torch.bool,
    }

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> list[str]:
        """Tensor names stored in the file (the "__metadata__" entry excluded)."""
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        """File-level metadata dict, or {} when the file has none."""
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """Read tensor `key` from disk and return it as a torch.Tensor.

        Raises KeyError if the key is not present in the header.
        """
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Data section starts after the 8-byte length prefix plus the JSON header.
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # safetensors layout: u64 little-endian header length, then UTF-8 JSON.
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        dtype = self._DTYPE_MAP[metadata["dtype"]]
        shape = metadata["shape"]
        if len(tensor_bytes) == 0:
            # torch.frombuffer rejects empty buffers; build the tensor directly.
            return torch.empty(shape, dtype=dtype)
        # Copy into a writable bytearray: frombuffer over immutable `bytes`
        # yields a read-only tensor, and callers mutate these in place (add_).
        buffer = bytearray(tensor_bytes)
        return torch.frombuffer(buffer, dtype=torch.uint8).view(dtype).reshape(shape)
59
+
60
# --- Constants & Setup ---
# Scratch directory for downloads and shard assembly. Prefer /tmp (writable
# on HF Spaces); fall back to a local folder when /tmp is unavailable.
# Only OSError is caught here — a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit during startup.
try:
    TempDir = Path("/tmp/temp_tool")
    os.makedirs(TempDir, exist_ok=True)
except OSError:
    TempDir = Path("./temp_tool")
    os.makedirs(TempDir, exist_ok=True)

# Shared Hub client used by all upload/list helpers below.
api = HfApi()
69
+
70
def cleanup_temp():
    """Reset the scratch directory to an empty state and reclaim memory."""
    if TempDir.exists():
        shutil.rmtree(TempDir)
    # Recreate (with parents) so later code can assume the directory exists.
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()
75
+
76
def get_key_stem(key):
    """Normalize a checkpoint/LoRA tensor key to its bare module path.

    Strips the common weight/LoRA suffixes, then repeatedly peels known
    framework prefixes until none match, so keys from different naming
    conventions (diffusers, kohya, PEFT) collapse to the same stem.
    """
    for suffix in (".weight", ".bias", ".lora_down", ".lora_up",
                   ".lora_A", ".lora_B", ".alpha"):
        key = key.replace(suffix, "")
    known_prefixes = (
        "model.diffusion_model.", "diffusion_model.", "model.",
        "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
    )
    stripped_any = True
    while stripped_any:
        stripped_any = False
        for prefix in known_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                stripped_any = True
    return key
93
+
94
# =================================================================================
# TAB 1: MERGE & RESHARD
# =================================================================================

def parse_hf_url(url):
    """Parse a direct Hugging Face URL into (repo_id, filename).

    Expected shape:
        https://huggingface.co/{user}/{repo}/resolve/{branch}/{path...}
    Returns (None, None) for anything that does not match. An explicit
    length check replaces the original bare `except:` so unrelated errors
    are no longer silently swallowed.
    """
    if "huggingface.co" in url and "resolve" in url:
        parts = url.split("huggingface.co/")[-1].split("/")
        # parts[0]=user, parts[1]=repo, parts[2]=resolve, parts[3]=branch, parts[4:]=file
        if len(parts) < 5:
            return None, None
        repo_id = f"{parts[0]}/{parts[1]}"
        filename = "/".join(parts[4:]).split("?")[0]  # Strip query params
        return repo_id, filename
    return None, None
111
+
112
def download_lora_smart(input_str, token):
    """Resolve `input_str` to a local LoRA file at TempDir/adapter.safetensors.

    Accepts a direct huggingface.co /resolve/ URL, a bare repo id, a
    "User/Repo/path/file.safetensors" string, or any raw HTTP URL.
    Resolution order:
      1. Direct HF URL via hf_hub_download.
      2. Repo id, with auto-discovery of a .safetensors file.
      3. Raw streaming HTTP GET as a last resort for non-HF links.
    Raises on failure (ValueError when every strategy has been tried).
    """
    local_path = TempDir / "adapter.safetensors"
    if local_path.exists(): os.remove(local_path)

    print(f"Resolving LoRA Input: {input_str}")

    # 1. Try Parse as HF URL (Most Robust Method)
    repo_id, filename = parse_hf_url(input_str)
    if repo_id and filename:
        # Fixed: this log line previously had a garbled placeholder instead
        # of the actual filename.
        print(f"Detected HF URL. Repo: {repo_id}, File: {filename}")
        try:
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            # Move to the standard name; rglob handles subfolder downloads.
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path
        except Exception as e:
            print(f"HF Download failed: {e}. Falling back...")

    # 2. Try as Raw Repo ID (User/Repo)
    try:
        # Check if user put "User/Repo/file.safetensors"
        if ".safetensors" in input_str and input_str.count("/") >= 2:
            parts = input_str.split("/")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = "/".join(parts[2:])
            hf_hub_download(repo_id=repo_id, filename=filename, token=token, local_dir=TempDir)
            found = list(TempDir.rglob(filename.split("/")[-1]))[0]
            if found != local_path: shutil.move(found, local_path)
            return local_path

        # Standard auto-discovery: prefer well-known adapter filenames,
        # otherwise fall back to the first .safetensors in the repo.
        candidates = ["adapter_model.safetensors", "model.safetensors"]
        files = list_repo_files(repo_id=input_str, token=token)
        target = next((f for f in files if f in candidates), None)
        if not target:
            safes = [f for f in files if f.endswith(".safetensors")]
            if safes: target = safes[0]

        if not target: raise ValueError("No safetensors found")

        hf_hub_download(repo_id=input_str, filename=target, token=token, local_dir=TempDir)
        found = list(TempDir.rglob(target.split("/")[-1]))[0]
        if found != local_path: shutil.move(found, local_path)
        return local_path

    except Exception as e:
        # 3. Last Resort: Raw Requests (for non-HF links)
        if input_str.startswith("http"):
            try:
                headers = {"Authorization": f"Bearer {token}"} if token else {}
                r = requests.get(input_str, stream=True, headers=headers, timeout=60)
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192): f.write(chunk)
                return local_path
            except Exception as req_e:
                raise ValueError(f"All download methods failed.\nRepo Logic Error: {e}\nURL Logic Error: {req_e}")
        raise e
171
+
172
def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
    """Load a LoRA safetensors file into {stem: {"down", "up", "rank", "alpha"}}.

    Keys are normalized with get_key_stem so they can later be matched
    against base-model tensor names. Entries missing either projection
    matrix are dropped (with a warning) instead of producing partial dicts
    that would KeyError downstream in the merge loop.
    """
    print(f"Loading LoRA from {lora_path}...")
    state_dict = load_file(lora_path, device="cpu")
    pairs = {}
    alphas = {}
    for k, v in state_dict.items():
        stem = get_key_stem(k)
        if "alpha" in k:
            # Alpha may be stored as a 0-dim tensor or a plain number.
            alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
        elif "lora_down" in k or "lora_A" in k:
            pairs.setdefault(stem, {})["down"] = v.to(dtype=precision_dtype)
            # Rank is the down-projection's output dimension.
            pairs[stem]["rank"] = v.shape[0]
        elif "lora_up" in k or "lora_B" in k:
            pairs.setdefault(stem, {})["up"] = v.to(dtype=precision_dtype)
        # Other keys (e.g. stray diff tensors) are ignored: they previously
        # created empty pair dicts that crashed the merge loop.

    complete = {}
    for stem, pair in pairs.items():
        if "down" not in pair or "up" not in pair:
            print(f"Skipping incomplete LoRA pair: {stem}")
            continue
        # Missing alpha defaults to rank, i.e. a scaling factor of 1.0.
        pair["alpha"] = alphas.get(stem, float(pair["rank"]))
        complete[stem] = pair
    return complete
191
+
192
class ShardBuffer:
    """Accumulates serialized tensors in memory and emits safetensors shards.

    Tensors are appended via add_tensor(); once the buffered bytes exceed
    `max_size_gb`, the buffer is written as a numbered shard, uploaded to
    `output_repo`, deleted locally, and cleared. `index_map` / `total_size`
    accumulate the data needed for the final index.json.
    """

    # torch dtype -> safetensors dtype tag. The original labeled everything
    # that wasn't f16/bf16 as "F32", silently corrupting integer/bool tensors.
    _DTYPE_STR = {
        torch.float32: "F32", torch.float16: "F16", torch.bfloat16: "BF16",
        torch.float64: "F64", torch.int64: "I64", torch.int32: "I32",
        torch.int16: "I16", torch.int8: "I8", torch.uint8: "U8",
        torch.bool: "BOOL",
    }

    def __init__(self, max_size_gb, output_dir, output_repo, subfolder, hf_token, filename_prefix="model"):
        self.max_bytes = int(max_size_gb * 1024**3)
        self.output_dir = output_dir
        self.output_repo = output_repo
        self.subfolder = subfolder
        self.hf_token = hf_token
        self.filename_prefix = filename_prefix
        self.buffer = []          # pending {key, data, dtype, shape} records
        self.current_bytes = 0    # bytes buffered since the last flush
        self.shard_count = 0      # shards written so far (1-based in filenames)
        self.index_map = {}       # tensor key -> shard filename (for index.json)
        self.total_size = 0       # total bytes across all shards

    def add_tensor(self, key, tensor):
        """Serialize `tensor` into the buffer; flush when over the size cap."""
        tensor = tensor.contiguous()  # slices/views must be compacted before raw export
        dtype_str = self._DTYPE_STR.get(tensor.dtype)
        if dtype_str is None:
            raise ValueError(f"Unsupported dtype {tensor.dtype} for key {key}")
        if tensor.dtype == torch.bfloat16:
            # numpy has no bfloat16: reinterpret as int16 for the raw bytes.
            raw_bytes = tensor.view(torch.int16).numpy().tobytes()
        else:
            raw_bytes = tensor.numpy().tobytes()

        size = len(raw_bytes)
        self.buffer.append({
            "key": key,
            "data": raw_bytes,
            "dtype": dtype_str,
            "shape": list(tensor.shape)  # plain list serializes cleanly to JSON
        })
        self.current_bytes += size
        self.total_size += size

        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self):
        """Write the buffered tensors as one shard, upload it, and reset."""
        if not self.buffer: return
        self.shard_count += 1

        filename = f"{self.filename_prefix}-{self.shard_count:05d}.safetensors"
        # Fixed: the subfolder branch previously produced a literal garbled
        # placeholder instead of the shard filename.
        path_in_repo = f"{self.subfolder}/{filename}" if self.subfolder else filename

        print(f"Flushing {path_in_repo} ({self.current_bytes / 1024**3:.2f} GB)...")

        # Build the safetensors header: contiguous data_offsets per tensor.
        header = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, current_offset + len(item["data"])]
            }
            current_offset += len(item["data"])
            self.index_map[item["key"]] = filename

        header_json = json.dumps(header).encode('utf-8')

        out_path = self.output_dir / filename
        with open(out_path, 'wb') as f:
            f.write(struct.pack('<Q', len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])

        print(f"Uploading {path_in_repo}...")
        api.upload_file(path_or_fileobj=out_path, path_in_repo=path_in_repo, repo_id=self.output_repo, token=self.hf_token)

        # Free local disk and memory before the next shard accumulates.
        os.remove(out_path)
        self.buffer = []
        self.current_bytes = 0
        gc.collect()
266
+
267
def copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder):
    """Copy every non-weight file (configs, tokenizers, misc) from the base
    repo into the output repo, preserving the relative layout under
    `output_subfolder`. Best-effort: any failure is printed as a warning
    rather than raised."""
    print(f"Copying config files from {base_repo}...")
    weight_exts = ('.safetensors', '.bin', '.pt', '.pth', '.msgpack', '.h5', '.onnx')
    try:
        for repo_file in list_repo_files(repo_id=base_repo, token=hf_token):
            # Restrict to the requested subfolder and skip heavy weight files.
            in_scope = not base_subfolder or repo_file.startswith(base_subfolder)
            if not in_scope or repo_file.endswith(weight_exts):
                continue

            print(f"Transferring {repo_file}...")
            local = hf_hub_download(repo_id=base_repo, filename=repo_file, token=hf_token, local_dir=TempDir)

            # Rebase the path from the source subfolder into the target one.
            if base_subfolder:
                rel_name = repo_file[len(base_subfolder):].lstrip('/')
            else:
                rel_name = repo_file
            target_path = f"{output_subfolder}/{rel_name}" if output_subfolder else rel_name

            api.upload_file(path_or_fileobj=local, path_in_repo=target_path, repo_id=output_repo, token=hf_token)
            os.remove(local)

    except Exception as e:
        print(f"Config copy warning: {e}")
293
+
294
def streaming_copy_structure(token, src_repo, dst_repo, ignore_prefix=None, is_root_merge=False):
    """Best-effort clone of a repo's files into `dst_repo`, one at a time.

    Skips files under `ignore_prefix`, and when `is_root_merge` is set also
    skips weight files so only configs/tokenizers/etc. are carried over.
    Per-file failures are logged and skipped (previously a bare
    `except: pass` hid them entirely); a top-level failure is printed.
    """
    print(f"Scanning {src_repo} for structure cloning...")
    try:
        files = api.list_repo_files(repo_id=src_repo, token=token)
        for f in tqdm(files, desc="Copying Structure"):
            if ignore_prefix and f.startswith(ignore_prefix): continue

            if is_root_merge:
                # Root merge keeps the clone weight-free.
                if any(f.endswith(ext) for ext in ['.safetensors', '.bin', '.pt', '.pth']):
                    continue

            try:
                local = hf_hub_download(repo_id=src_repo, filename=f, token=token, local_dir=TempDir)
                api.upload_file(path_or_fileobj=local, path_in_repo=f, repo_id=dst_repo, token=token)
                if os.path.exists(local): os.remove(local)
            except Exception as copy_err:
                # Keep going, but surface the failure instead of swallowing it.
                print(f"Skipping {f}: {copy_err}")
    except Exception as e: print(f"Structure clone error: {e}")
311
+
312
def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
    """Merge a LoRA into a (sharded) base model and upload the resharded result.

    Streams the base model shard-by-shard through MemoryEfficientSafeOpen,
    applies matched LoRA deltas tensor-by-tensor, and re-emits shards of
    `shard_size` GB via ShardBuffer, uploading each as it completes so peak
    disk/RAM stays bounded. Returns a human-readable status string (errors
    are returned as strings, not raised, for the Gradio UI).
    NOTE(review): `progress=gr.Progress()` as a default arg is the Gradio
    injection convention — presumably intentional; confirm against the
    Gradio version in use.
    """
    cleanup_temp()
    if not hf_token: return "Error: HF Token required."
    login(hf_token.strip())

    # 1. Ensure the destination repo exists before any uploads.
    try:
        api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
    except Exception as e: return f"Error creating repo: {e}"

    # Logic: If using a subfolder like 'transformer', we want standard diffusers naming
    output_subfolder = base_subfolder if base_subfolder else ""

    # 2. Copy Configs from Base (Aggressive Copy)
    if base_subfolder:
        copy_auxiliary_configs(hf_token, base_repo, base_subfolder, output_repo, output_subfolder)

    # 3. Clone Structure Repo (VAE/text-encoder/etc.), skipping the subfolder
    # we are about to regenerate.
    if structure_repo:
        ignore = output_subfolder if output_subfolder else None
        streaming_copy_structure(hf_token, structure_repo, output_repo, ignore_prefix=ignore, is_root_merge=not bool(output_subfolder))

    # 4. Download Shards of the base model (restricted to the subfolder).
    progress(0.1, desc="Downloading Input Model...")
    files = list_repo_files(repo_id=base_repo, token=hf_token)
    input_shards = []

    for f in files:
        if f.endswith(".safetensors"):
            if output_subfolder and not f.startswith(output_subfolder): continue

            local = TempDir / "inputs" / os.path.basename(f)
            os.makedirs(local.parent, exist_ok=True)
            # NOTE(review): local_dir_use_symlinks is deprecated in recent
            # huggingface_hub releases — confirm against the pinned version.
            hf_hub_download(repo_id=base_repo, filename=f, token=hf_token, local_dir=local.parent, local_dir_use_symlinks=False)
            found = list(local.parent.rglob(os.path.basename(f)))
            if found: input_shards.append(found[0])

    if not input_shards: return "No safetensors found."
    # Lexicographic sort keeps numbered shards in order.
    input_shards.sort()

    # --- NAMING CONVENTION ---
    # Force diffusion naming if target is transformer/unet
    if output_subfolder in ["transformer", "unet", "qint4", "qint8"]:
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    elif "diffusion_pytorch_model" in os.path.basename(input_shards[0]):
        filename_prefix = "diffusion_pytorch_model"
        index_filename = "diffusion_pytorch_model.safetensors.index.json"
    else:
        filename_prefix = "model"
        index_filename = "model.safetensors.index.json"

    print(f"Naming scheme: {filename_prefix}")

    # 5. Load LoRA pairs into RAM at the requested working precision.
    dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
    try:
        progress(0.15, desc="Downloading LoRA...")
        lora_path = download_lora_smart(lora_input, hf_token)
        lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
    except Exception as e: return f"Error loading LoRA: {e}"

    # 6. Stream: read each input shard tensor-by-tensor, apply deltas, rebuffer.
    buffer = ShardBuffer(shard_size, TempDir, output_repo, output_subfolder, hf_token, filename_prefix=filename_prefix)

    for i, shard_file in enumerate(input_shards):
        progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {os.path.basename(shard_file)}")

        with MemoryEfficientSafeOpen(shard_file) as f:
            keys = f.keys()
            for k in keys:
                v = f.get_tensor(k)
                base_stem = get_key_stem(k)
                match = lora_pairs.get(base_stem)

                # QKV Heuristics: if no direct match, try mapping a split
                # q/k/v base key onto a fused "qkv" LoRA module name.
                if not match:
                    if "to_q" in base_stem:
                        qkv = base_stem.replace("to_q", "qkv")
                        match = lora_pairs.get(qkv)
                    elif "to_k" in base_stem:
                        qkv = base_stem.replace("to_k", "qkv")
                        match = lora_pairs.get(qkv)
                    elif "to_v" in base_stem:
                        qkv = base_stem.replace("to_v", "qkv")
                        match = lora_pairs.get(qkv)

                if match:
                    down = match["down"]
                    up = match["up"]
                    # Standard LoRA scaling: user scale * (alpha / rank).
                    scaling = scale * (match["alpha"] / match["rank"])

                    # Linear LoRA applied to a conv weight: add 1x1 spatial dims.
                    if len(v.shape) == 4 and len(down.shape) == 2:
                        down = down.unsqueeze(-1).unsqueeze(-1)
                        up = up.unsqueeze(-1).unsqueeze(-1)

                    # Compute delta = up @ down; transposed fallback for
                    # checkpoints storing the factors the other way around.
                    try:
                        if len(up.shape) == 4:
                            delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
                        else:
                            delta = up @ down
                    except: delta = up.T @ down

                    delta = delta * scaling

                    # Shape reconciliation: exact match, fused-qkv slice
                    # (delta rows = 3x base rows), or flat numel match.
                    valid = True
                    if delta.shape == v.shape: pass
                    elif delta.shape[0] == v.shape[0] * 3:
                        chunk = v.shape[0]
                        if "to_q" in k: delta = delta[0:chunk, ...]
                        elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
                        elif "to_v" in k: delta = delta[2*chunk:, ...]
                        else: valid = False
                    elif delta.numel() == v.numel(): delta = delta.reshape(v.shape)
                    else: valid = False

                    if valid:
                        v = v.to(dtype)
                        delta = delta.to(dtype)
                        # In-place add keeps peak memory at one tensor + delta.
                        v.add_(delta)
                        del delta

                # Every tensor (matched or not) is emitted at working precision.
                if v.dtype != dtype: v = v.to(dtype)
                buffer.add_tensor(k, v)
                del v

        # Input shard is fully consumed; free its disk space immediately.
        os.remove(shard_file)
        gc.collect()

    # Emit whatever remains below the shard-size threshold.
    buffer.flush()

    # 7. Write and upload the weight-map index for the sharded output.
    print(f"Uploading Index: {index_filename} (Size: {buffer.total_size})")
    index_data = {"metadata": {"total_size": buffer.total_size}, "weight_map": buffer.index_map}
    with open(TempDir / index_filename, "w") as f:
        json.dump(index_data, f, indent=4)

    path_in_repo = f"{output_subfolder}/{index_filename}" if output_subfolder else index_filename
    api.upload_file(path_or_fileobj=TempDir / index_filename, path_in_repo=path_in_repo, repo_id=output_repo, token=hf_token)

    cleanup_temp()
    return f"Done! Merged {buffer.shard_count} shards to {output_repo}"
452
+
453
# =================================================================================
# TAB 4: RESIZE (CPU Optimized)
# =================================================================================

def task_resize(hf_token, lora_input, new_rank, dynamic_method, dynamic_param, out_repo):
    """Downsize a LoRA to `new_rank` via truncated randomized SVD and upload it.

    Each down/up pair is merged to a full delta matrix, decomposed with
    torch.svd_lowrank, truncated, and re-split into new down/up factors.
    Returns a status string for the Gradio UI.
    NOTE(review): `dynamic_method` / `dynamic_param` are accepted for UI
    compatibility but not currently applied (sv_ratio is unimplemented).
    """
    cleanup_temp()
    if not hf_token: return "Error: Token required"
    login(hf_token.strip())

    try:
        path = download_lora_smart(lora_input, hf_token)
    except Exception as e: return f"Error: {e}"

    state = load_file(path, device="cpu")
    new_state = {}

    # Group down/up (or A/B) tensors by the key prefix before ".lora_".
    # NOTE(review): ".alpha" keys contain no ".lora_" segment, so they land
    # in their own group and the original alpha scaling is never folded into
    # the merged delta; the output is always written with alpha == new rank.
    groups = {}
    for k in state:
        simple = k.split(".lora_")[0]
        if simple not in groups: groups[simple] = {}
        if "lora_down" in k or "lora_A" in k: groups[simple]["down"] = state[k]
        if "lora_up" in k or "lora_B" in k: groups[simple]["up"] = state[k]
        if "alpha" in k: groups[simple]["alpha"] = state[k]

    print(f"Resizing {len(groups)} blocks...")
    for stem, g in tqdm(groups.items()):
        if "down" in g and "up" in g:
            down, up = g["down"].float(), g["up"].float()

            if len(down.shape) == 4:
                # Conv LoRA: collapse the (assumed 1x1) spatial dims to merge,
                # then flatten all non-output dims for the SVD.
                merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
                flat = merged.flatten(1)
            else:
                merged = up @ down
                flat = merged

            # FAST SVD (svd_lowrank): oversample q slightly for convergence,
            # bounded by the matrix's smaller dimension. Clamp target_rank to
            # q so the truncated factors and the written alpha stay consistent
            # even when the requested rank exceeds what the matrix supports.
            q = min(int(new_rank) + 10, min(flat.shape))
            target_rank = min(int(new_rank), q)

            U, S, V = torch.svd_lowrank(flat, q=q)
            # V from svd_lowrank is (N, q); we need Vh with shape (q, N).
            Vh = V.t()

            # Exact truncation to the target rank.
            U = U[:, :target_rank]
            S = S[:target_rank]
            Vh = Vh[:target_rank, :]

            # Fold the singular values into the up-projection factor.
            U = U @ torch.diag(S)

            if len(down.shape) == 4:
                # Restore conv layout with 1x1 kernels on the up side.
                U = U.reshape(up.shape[0], target_rank, 1, 1)
                Vh = Vh.reshape(target_rank, down.shape[1], down.shape[2], down.shape[3])

            new_state[f"{stem}.lora_down.weight"] = Vh
            new_state[f"{stem}.lora_up.weight"] = U
            # alpha == rank means an effective scaling factor of 1.0.
            new_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()

    out = TempDir / "resized.safetensors"
    save_file(new_state, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
520
+
521
# =================================================================================
# UI
# =================================================================================

# Custom stylesheet for the app container.
css = ".container { max-width: 900px; margin: auto; }"

# NOTE(review): the "Extract Adapter" and "Merge Multiple Adapters" tabs wire
# their buttons to `task_extract` and `task_merge_adapters`, which are not
# defined anywhere in this file — building this Blocks context will raise
# NameError at import time until those functions are added (or the tabs removed).
with gr.Blocks() as demo:
    gr.Markdown("# 🧰SOONmerge® LoRA Toolkit")

    with gr.Tabs():
        # Tab 1: merge a LoRA into a base model and reshard -> task_merge.
        with gr.Tab("Merge to Base + Reshard Output"):
            t1_token = gr.Textbox(label="Token", type="password")
            t1_base = gr.Textbox(label="Base Repo (Diffusers)", value="ostris/Z-Image-De-Turbo")
            t1_sub = gr.Textbox(label="Subfolder", value="transformer")
            t1_lora = gr.Textbox(label="LoRA Direct Link or Repo", value="https://huggingface.co/GuangyuanSD/Z-Image-Re-Turbo-LoRA/resolve/main/Z-image_re_turbo_lora_8steps_rank_32_v1_fp16.safetensors")
            with gr.Row():
                t1_scale = gr.Slider(label="Scale", value=1.0, minimum=0, maximum=3.0, step=0.1)
                t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
                t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.1, maximum=10.0, step=0.1)
            t1_out = gr.Textbox(label="Output Repo")
            t1_struct = gr.Textbox(label="Diffusers Extras (Copies VAE/TextEnc/etc)", value="Tongyi-MAI/Z-Image-Turbo")
            t1_priv = gr.Checkbox(label="Private", value=True)
            t1_btn = gr.Button("Merge")
            t1_res = gr.Textbox(label="Result")
            t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)

        # Tab 2: extract a LoRA from a tuned model.
        # NOTE(review): `task_extract` is undefined in this file.
        with gr.Tab("Extract Adapter"):
            t2_token = gr.Textbox(label="Token", type="password")
            t2_org = gr.Textbox(label="Original Model")
            t2_tun = gr.Textbox(label="Tuned Model")
            t2_rank = gr.Number(label="Extract At Rank", value=32, minimum=1, maximum=1024, step=1)
            t2_out = gr.Textbox(label="Output Repo")
            t2_btn = gr.Button("Extract")
            t2_res = gr.Textbox(label="Result")
            t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)

        # Tab 3: merge several adapters together.
        # NOTE(review): `task_merge_adapters` is undefined in this file.
        with gr.Tab("Merge Multiple Adapters"):
            t3_token = gr.Textbox(label="Token", type="password")
            t3_urls = gr.Textbox(label="URLs")
            with gr.Row():
                t3_beta = gr.Slider(label="Beta", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
                t3_sigma = gr.Slider(label="Sigma Rel (Overrides Beta)", value=0.21, minimum=0.01, maximum=1.00, step=0.01)
            t3_out = gr.Textbox(label="Output Repo")
            t3_btn = gr.Button("Merge")
            t3_res = gr.Textbox(label="Result")
            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)

        # Tab 4: shrink a LoRA's rank -> task_resize.
        with gr.Tab("Resize Adapter"):
            t4_token = gr.Textbox(label="Token", type="password")
            t4_in = gr.Textbox(label="LoRA")
            with gr.Row():
                t4_rank = gr.Number(label="To Rank (Lower Only!)", value=8, minimum=1, maximum=256, step=1)
                t4_method = gr.Dropdown(["None", "sv_ratio"], value="None", label="Dynamic Method")
                t4_param = gr.Number(label="Dynamic Param", value=4.0)
            t4_out = gr.Textbox(label="Output")
            t4_btn = gr.Button("Resize")
            t4_res = gr.Textbox(label="Result")
            t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_method, t4_param, t4_out], t4_res)

if __name__ == "__main__":
    # NOTE(review): Gradio normally takes css via gr.Blocks(css=...), not
    # launch(); confirm launch(css=...) is accepted by the pinned version.
    demo.queue().launch(css=css, ssr_mode=False)