AlekseyCalvin commited on
Commit
df67033
·
verified ·
1 Parent(s): 744516f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +267 -317
app.py CHANGED
@@ -9,7 +9,7 @@ import struct
9
  import numpy as np
10
  import re
11
  from pathlib import Path
12
- from typing import Dict, Any, Optional
13
  from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
14
  from safetensors.torch import load_file, save_file
15
  from tqdm import tqdm
@@ -18,7 +18,6 @@ from tqdm import tqdm
18
  class MemoryEfficientSafeOpen:
19
  """
20
  Reads safetensors metadata and tensors without mmap, keeping RAM usage low.
21
- Essential for running on limited hardware.
22
  """
23
  def __init__(self, filename):
24
  self.filename = filename
@@ -62,8 +61,15 @@ class MemoryEfficientSafeOpen:
62
  return torch.frombuffer(tensor_bytes, dtype=torch.uint8).view(dtype).reshape(shape)
63
 
64
  # --- Constants & Setup ---
65
- TempDir = Path("./temp_tool")
66
- os.makedirs(TempDir, exist_ok=True)
 
 
 
 
 
 
 
67
  api = HfApi()
68
 
69
  def cleanup_temp():
@@ -72,60 +78,35 @@ def cleanup_temp():
72
  os.makedirs(TempDir, exist_ok=True)
73
  gc.collect()
74
 
75
- def verify_safetensors(path):
76
- """Checks if a file is a valid safetensors file."""
77
- try:
78
- with open(path, "rb") as f:
79
- header_size_bytes = f.read(8)
80
- if len(header_size_bytes) != 8: return False
81
- header_size = struct.unpack("<Q", header_size_bytes)[0]
82
- if header_size > os.path.getsize(path) or header_size <= 0:
83
- return False
84
- return True
85
- except:
86
- return False
87
-
88
  def download_file(input_path, token, filename=None):
89
- """Downloads a file from URL or HF Repo."""
90
  local_path = TempDir / (filename if filename else "model.safetensors")
91
-
92
  if input_path.startswith("http"):
93
- print(f"Downloading from URL: {input_path}")
94
  try:
95
  response = requests.get(input_path, stream=True, timeout=30)
96
  response.raise_for_status()
97
  with open(local_path, 'wb') as f:
98
  for chunk in response.iter_content(chunk_size=8192):
99
  f.write(chunk)
100
- except Exception as e:
101
- raise ValueError(f"Failed to download URL. Check your link. Error: {e}")
102
  else:
103
- print(f"Downloading from Repo: {input_path}")
104
  if not filename:
105
  try:
106
  files = list_repo_files(repo_id=input_path, token=token)
107
  safetensors = [f for f in files if f.endswith(".safetensors")]
108
- if safetensors:
109
- filename = safetensors[0]
110
- for f in safetensors:
111
- if "adapter" in f: filename = f
112
- else:
113
- filename = "adapter_model.bin"
114
- except Exception as e:
115
- filename = "adapter_model.safetensors"
116
 
117
  try:
118
  hf_hub_download(repo_id=input_path, filename=filename, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
119
- downloaded_path = TempDir / filename
120
- if downloaded_path != local_path:
121
- if local_path.exists(): os.remove(local_path)
122
- shutil.move(downloaded_path, local_path)
123
- except Exception as e:
124
- raise ValueError(f"Failed to download from HF Repo. Check ID/Token. Error: {e}")
125
 
126
- if not verify_safetensors(local_path):
127
- raise ValueError(f"Downloaded file is NOT a valid safetensors file. Check your URL/Repo. (File size: {os.path.getsize(local_path)} bytes)")
128
-
129
  return local_path
130
 
131
  def get_key_stem(key):
@@ -133,13 +114,10 @@ def get_key_stem(key):
133
  key = key.replace(".lora_down", "").replace(".lora_up", "")
134
  key = key.replace(".lora_A", "").replace(".lora_B", "")
135
  key = key.replace(".alpha", "")
136
-
137
  prefixes = [
138
  "model.diffusion_model.", "diffusion_model.", "model.",
139
- "transformer.", "text_encoder.", "lora_unet_", "lora_te_",
140
- "base_model.model."
141
  ]
142
-
143
  changed = True
144
  while changed:
145
  changed = False
@@ -150,149 +128,124 @@ def get_key_stem(key):
150
  return key
151
 
152
  # =================================================================================
153
- # TAB 1: UNIVERSAL MERGE (Low-Precision Optimized)
154
  # =================================================================================
155
 
156
  def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
157
- print(f"Loading LoRA from {lora_path} in {precision_dtype}...")
158
  state_dict = load_file(lora_path, device="cpu")
159
-
160
  pairs = {}
161
  alphas = {}
162
-
163
  for k, v in state_dict.items():
164
  stem = get_key_stem(k)
165
  if "alpha" in k:
166
  alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
167
  else:
168
- if stem not in pairs:
169
- pairs[stem] = {}
170
-
171
- # Cast immediately to save RAM
172
- tensor_low = v.to(dtype=precision_dtype)
173
-
174
  if "lora_down" in k or "lora_A" in k:
175
- pairs[stem]["down"] = tensor_low
176
  pairs[stem]["rank"] = v.shape[0]
177
  elif "lora_up" in k or "lora_B" in k:
178
- pairs[stem]["up"] = tensor_low
179
-
180
  for stem in pairs:
181
- if stem in alphas:
182
- pairs[stem]["alpha"] = alphas[stem]
183
- else:
184
- if "rank" in pairs[stem]:
185
- pairs[stem]["alpha"] = float(pairs[stem]["rank"])
186
- else:
187
- pairs[stem]["alpha"] = 1.0
188
-
189
  return pairs
190
 
191
- def merge_shard_logic(base_path, lora_pairs, scale, output_path, precision_dtype=torch.bfloat16):
192
- print(f"Loading base shard: {base_path}")
193
- base_state = load_file(base_path, device="cpu")
194
-
195
- lora_keys = set(lora_pairs.keys())
196
- keys_to_process = list(base_state.keys())
197
-
198
- for k in keys_to_process:
199
- v = base_state[k]
200
- base_stem = get_key_stem(k)
201
- match = None
202
 
203
- # 1. Exact Match
204
- if base_stem in lora_keys:
205
- match = lora_pairs[base_stem]
 
 
 
 
 
 
206
  else:
207
- # 2. Heuristic Match
208
- if "to_q" in base_stem:
209
- qkv_stem = base_stem.replace("to_q", "qkv")
210
- if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
211
- elif "to_k" in base_stem:
212
- qkv_stem = base_stem.replace("to_k", "qkv")
213
- if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
214
- elif "to_v" in base_stem:
215
- qkv_stem = base_stem.replace("to_v", "qkv")
216
- if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
217
-
218
- if match and "down" in match and "up" in match:
219
- down = match["down"]
220
- up = match["up"]
221
- alpha = match["alpha"]
222
- rank = match["rank"]
223
-
224
- scaling = scale * (alpha / rank)
225
 
226
- # Handle Conv 1x1 squeeze
227
- if len(v.shape) == 4 and len(down.shape) == 2:
228
- down = down.unsqueeze(-1).unsqueeze(-1)
229
- up = up.unsqueeze(-1).unsqueeze(-1)
230
-
231
- try:
232
- if len(up.shape) == 4:
233
- delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
234
- else:
235
- delta = up @ down
236
- except:
237
- delta = up.T @ down
238
-
239
- delta = delta * scaling
240
 
241
- valid_delta = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
- # --- Dynamic Reshaping / Slicing ---
244
- if delta.shape == v.shape:
245
- pass
246
- elif delta.shape[0] == v.shape[0] * 3:
247
- chunk_size = v.shape[0]
248
- if "to_q" in k:
249
- delta = delta[0:chunk_size, ...]
250
- elif "to_k" in k:
251
- delta = delta[chunk_size:2*chunk_size, ...]
252
- elif "to_v" in k:
253
- delta = delta[2*chunk_size:, ...]
254
- else:
255
- valid_delta = False
256
- elif delta.numel() == v.numel():
257
- delta = delta.reshape(v.shape)
258
- else:
259
- # print(f"Skipping {k}: Mismatch. Base: {v.shape}, Delta: {delta.shape}")
260
- valid_delta = False
261
-
262
- if valid_delta:
263
- # Optimized In-Place Addition (Zero Copy)
264
- if v.dtype != delta.dtype:
265
- delta = delta.to(v.dtype)
266
 
267
- v.add_(delta)
268
- del delta
 
269
 
270
- if len(keys_to_process) > 100 and keys_to_process.index(k) % 50 == 0:
271
- gc.collect()
272
-
273
- save_file(base_state, output_path)
274
- return True
275
 
276
- # NOTE: Arguments must match exactly with the inputs=[] list in click()
277
- def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, output_repo, structure_repo, private, progress=gr.Progress()):
278
  cleanup_temp()
279
  login(hf_token)
280
 
281
- # Determine Dtype
282
- if precision == "bf16":
283
- dtype = torch.bfloat16
284
- elif precision == "fp16":
285
- dtype = torch.float16
286
- else:
287
- dtype = torch.float32
288
-
289
- print(f"Selected Precision: {dtype}")
290
-
291
  try:
292
  api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
293
- except Exception as e:
294
- return f"Error creating repo: {e}"
295
-
296
  if structure_repo:
297
  print("Cloning structure...")
298
  try:
@@ -303,39 +256,127 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision
303
  path = hf_hub_download(repo_id=structure_repo, filename=f, token=hf_token)
304
  api.upload_file(path_or_fileobj=path, path_in_repo=f, repo_id=output_repo, token=hf_token)
305
  except: pass
306
- except Exception as e:
307
- print(f"Structure clone warning: {e}")
308
 
 
 
309
  try:
310
  progress(0.1, desc="Downloading LoRA...")
311
- lora_path = download_file(lora_input, hf_token)
312
  lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
313
- except Exception as e:
314
- return f"CRITICAL ERROR: {str(e)}"
315
-
 
316
  files = list_repo_files(repo_id=base_repo, token=hf_token)
317
- shards = [f for f in files if f.endswith(".safetensors")]
318
  if base_subfolder:
319
- shards = [f for f in shards if f.startswith(base_subfolder)]
320
 
321
- if not shards: return "Error: No safetensors found in base."
322
-
323
- for i, shard in enumerate(shards):
324
- progress(0.2 + (0.8 * i/len(shards)), desc=f"Merging {shard}")
325
- local_shard = hf_hub_download(repo_id=base_repo, filename=shard, token=hf_token, local_dir=TempDir)
326
- merged_path = TempDir / "merged.safetensors"
 
 
 
 
327
 
328
- # Merge
329
- merge_shard_logic(local_shard, lora_pairs, scale, merged_path, precision_dtype=dtype)
330
 
331
- # Upload
332
- api.upload_file(path_or_fileobj=merged_path, path_in_repo=shard, repo_id=output_repo, token=hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
 
334
  os.remove(local_shard)
335
- if merged_path.exists(): os.remove(merged_path)
336
  gc.collect()
337
-
338
- return f"Done! Model at https://huggingface.co/{output_repo}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
  # =================================================================================
341
  # TAB 2: EXTRACT LORA
@@ -345,15 +386,11 @@ def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
345
  org = MemoryEfficientSafeOpen(model_org)
346
  tuned = MemoryEfficientSafeOpen(model_tuned)
347
  lora_sd = {}
348
-
349
- print("Calculating diffs and running SVD (Layer-wise)...")
350
- keys = list(org.keys())
351
-
352
- for key in tqdm(keys):
353
  if key not in tuned.keys(): continue
354
  mat_org = org.get_tensor(key).float()
355
  mat_tuned = tuned.get_tensor(key).float()
356
-
357
  diff = mat_tuned - mat_org
358
  if torch.max(torch.abs(diff)) < 1e-4: continue
359
 
@@ -364,171 +401,93 @@ def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
364
 
365
  try:
366
  U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
367
- U = U[:, :r]
368
- S = S[:r]
369
  U = U @ torch.diag(S)
370
- Vh = Vh[:r, :]
371
-
372
  dist = torch.cat([U.flatten(), Vh.flatten()])
373
  hi_val = torch.quantile(dist, clamp)
374
  U = U.clamp(-hi_val, hi_val)
375
  Vh = Vh.clamp(-hi_val, hi_val)
376
-
377
  if is_conv:
378
  U = U.reshape(out_dim, r, 1, 1)
379
  Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
380
  else:
381
  U = U.reshape(out_dim, r)
382
  Vh = Vh.reshape(r, in_dim)
383
-
384
  stem = key.replace(".weight", "")
385
  lora_sd[f"{stem}.lora_up.weight"] = U
386
  lora_sd[f"{stem}.lora_down.weight"] = Vh
387
  lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
388
- except Exception as e:
389
- print(f"SVD failed for {key}: {e}")
390
-
391
- out_path = TempDir / "extracted_lora.safetensors"
392
- save_file(lora_sd, out_path)
393
- return str(out_path)
394
 
395
- def task_extract(hf_token, org_repo, tuned_repo, rank, output_repo):
396
  cleanup_temp()
397
  login(hf_token)
398
- print("Downloading models...")
399
  try:
400
- p1 = download_file(org_repo, hf_token, "org.safetensors")
401
- p2 = download_file(tuned_repo, hf_token, "tuned.safetensors")
402
- out = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
403
- api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
404
- api.upload_file(path_or_fileobj=out, path_in_repo="extracted_lora.safetensors", repo_id=output_repo, token=hf_token)
405
- return "Extraction Done."
406
- except Exception as e:
407
- return f"Error: {e}"
408
 
409
  # =================================================================================
410
- # TAB 3: MERGE ADAPTERS (EMA)
411
  # =================================================================================
412
 
413
- def task_merge_adapters(hf_token, lora_urls, beta, output_repo):
414
  cleanup_temp()
415
  login(hf_token)
416
- urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
417
- paths = []
418
  try:
419
- for i, url in enumerate(urls):
420
- paths.append(download_file(url, hf_token, f"adapter_{i}.safetensors"))
421
- except Exception as e:
422
- return f"Download Error: {e}"
423
-
424
- if not paths: return "No models found"
425
-
426
- base_sd = load_file(paths[0], device="cpu")
427
- for k in base_sd:
428
- if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float()
429
-
430
- for i, path in enumerate(paths[1:]):
431
- print(f"Merging {path}")
432
- curr = load_file(path, device="cpu")
433
- for k in base_sd:
434
- if k in curr and "alpha" not in k:
435
- base_sd[k] = base_sd[k] * beta + curr[k].float() * (1 - beta)
436
-
437
- out = TempDir / "merged_adapters.safetensors"
438
- save_file(base_sd, out)
439
- api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
440
- api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=output_repo, token=hf_token)
441
- return "Done"
442
 
443
  # =================================================================================
444
- # TAB 4: RESIZE
445
- # =================================================================================
446
-
447
- def task_resize(hf_token, lora_input, new_rank, output_repo):
448
- cleanup_temp()
449
- login(hf_token)
450
- try:
451
- path = download_file(lora_input, hf_token)
452
- except Exception as e:
453
- return f"Download Error: {e}"
454
-
455
- state = load_file(path, device="cpu")
456
- new_state = {}
457
- print("Resizing...")
458
-
459
- groups = {}
460
- for k in state:
461
- stem = get_key_stem(k)
462
- stem_simple = k.split(".lora_")[0]
463
- if stem_simple not in groups: groups[stem_simple] = {}
464
- if "lora_down" in k or "lora_A" in k: groups[stem_simple]["down"] = state[k]
465
- if "lora_up" in k or "lora_B" in k: groups[stem_simple]["up"] = state[k]
466
-
467
- for stem, g in tqdm(groups.items()):
468
- if "down" in g and "up" in g:
469
- down, up = g["down"].float(), g["up"].float()
470
- if len(down.shape) == 4:
471
- merged = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
472
- flat = merged.flatten(1)
473
- else:
474
- merged = up @ down
475
- flat = merged
476
-
477
- U, S, Vh = torch.linalg.svd(flat, full_matrices=False)
478
- U = U[:, :new_rank]
479
- S = S[:new_rank]
480
- U = U @ torch.diag(S)
481
- Vh = Vh[:new_rank, :]
482
-
483
- if len(down.shape) == 4:
484
- U = U.reshape(up.shape[0], new_rank, 1, 1)
485
- Vh = Vh.reshape(new_rank, down.shape[1], down.shape[2], down.shape[3])
486
-
487
- new_state[f"{stem}.lora_down.weight"] = Vh
488
- new_state[f"{stem}.lora_up.weight"] = U
489
- new_state[f"{stem}.alpha"] = torch.tensor(new_rank).float()
490
-
491
- out = TempDir / "resized.safetensors"
492
- save_file(new_state, out)
493
- api.create_repo(repo_id=output_repo, exist_ok=True, token=hf_token)
494
- api.upload_file(path_or_fileobj=out, path_in_repo="resized.safetensors", repo_id=output_repo, token=hf_token)
495
- return "Done"
496
-
497
- # =================================================================================
498
- # UI Construction
499
  # =================================================================================
500
 
501
  css = ".container { max-width: 900px; margin: auto; }"
502
 
503
  with gr.Blocks() as demo:
504
- gr.Markdown("# 🧰 SOONmerge® LoRA Toolkit")
505
 
506
  with gr.Tabs():
507
- with gr.Tab("Merge (Z-Image Fix)"):
508
  t1_token = gr.Textbox(label="Token", type="password")
509
  t1_base = gr.Textbox(label="Base Repo", value="ostris/Z-Image-De-Turbo")
510
  t1_sub = gr.Textbox(label="Subfolder", value="transformer")
511
  t1_lora = gr.Textbox(label="LoRA")
512
-
513
  with gr.Row():
514
- t1_scale = gr.Slider(label="Scale", value=1.0, minimum=-1, maximum=2)
515
- t1_prec = gr.Radio(["bf16", "fp16", "float32"], label="Precision", value="bf16")
516
-
517
  t1_out = gr.Textbox(label="Output")
518
  t1_struct = gr.Textbox(label="Structure Repo", value="Tongyi-MAI/Z-Image-Turbo")
519
- # Explicitly defined checkbox to ensure correct arg count
520
- t1_private = gr.Checkbox(label="Private Repo", value=True)
521
-
522
- t1_btn = gr.Button("Merge")
523
  t1_res = gr.Textbox(label="Result")
524
-
525
- # Corrected argument count: exactly 9 inputs + 1 output
526
- t1_btn.click(
527
- task_merge,
528
- inputs=[t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_out, t1_struct, t1_private],
529
- outputs=t1_res
530
- )
531
-
532
  with gr.Tab("Extract"):
533
  t2_token = gr.Textbox(label="Token", type="password")
534
  t2_org = gr.Textbox(label="Original")
@@ -538,24 +497,15 @@ with gr.Blocks() as demo:
538
  t2_btn = gr.Button("Extract")
539
  t2_res = gr.Textbox(label="Result")
540
  t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
541
-
542
  with gr.Tab("Merge Adapters"):
543
  t3_token = gr.Textbox(label="Token", type="password")
544
- t3_urls = gr.Textbox(label="URLs (comma sep)")
545
  t3_beta = gr.Slider(label="Beta", value=0.9)
546
  t3_out = gr.Textbox(label="Output")
547
  t3_btn = gr.Button("Merge")
548
  t3_res = gr.Textbox(label="Result")
549
  t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_out], t3_res)
550
-
551
- with gr.Tab("Resize"):
552
- t4_token = gr.Textbox(label="Token", type="password")
553
- t4_in = gr.Textbox(label="LoRA")
554
- t4_rank = gr.Number(label="Rank", value=8)
555
- t4_out = gr.Textbox(label="Output")
556
- t4_btn = gr.Button("Resize")
557
- t4_res = gr.Textbox(label="Result")
558
- t4_btn.click(task_resize, [t4_token, t4_in, t4_rank, t4_out], t4_res)
559
 
560
  if __name__ == "__main__":
561
  demo.queue().launch(css=css, ssr_mode=False)
 
9
  import numpy as np
10
  import re
11
  from pathlib import Path
12
+ from typing import Dict, Any, Optional, List
13
  from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
14
  from safetensors.torch import load_file, save_file
15
  from tqdm import tqdm
 
18
  class MemoryEfficientSafeOpen:
19
  """
20
  Reads safetensors metadata and tensors without mmap, keeping RAM usage low.
 
21
  """
22
  def __init__(self, filename):
23
  self.filename = filename
 
61
  return torch.frombuffer(tensor_bytes, dtype=torch.uint8).view(dtype).reshape(shape)
62
 
63
  # --- Constants & Setup ---
64
+ # Use /tmp/temp_tool if possible for better ephemeral handling,
65
+ # or fall back to ./temp_tool in working dir.
66
+ try:
67
+ TempDir = Path("/tmp/temp_tool")
68
+ os.makedirs(TempDir, exist_ok=True)
69
+ except:
70
+ TempDir = Path("./temp_tool")
71
+ os.makedirs(TempDir, exist_ok=True)
72
+
73
  api = HfApi()
74
 
75
  def cleanup_temp():
 
78
  os.makedirs(TempDir, exist_ok=True)
79
  gc.collect()
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def download_file(input_path, token, filename=None):
 
82
  local_path = TempDir / (filename if filename else "model.safetensors")
 
83
  if input_path.startswith("http"):
84
+ print(f"Downloading {filename} from URL...")
85
  try:
86
  response = requests.get(input_path, stream=True, timeout=30)
87
  response.raise_for_status()
88
  with open(local_path, 'wb') as f:
89
  for chunk in response.iter_content(chunk_size=8192):
90
  f.write(chunk)
91
+ except Exception as e: raise ValueError(f"Download failed: {e}")
 
92
  else:
93
+ print(f"Downloading {filename} from Hub...")
94
  if not filename:
95
  try:
96
  files = list_repo_files(repo_id=input_path, token=token)
97
  safetensors = [f for f in files if f.endswith(".safetensors")]
98
+ filename = safetensors[0] if safetensors else "adapter_model.safetensors"
99
+ except: filename = "adapter_model.safetensors"
 
 
 
 
 
 
100
 
101
  try:
102
  hf_hub_download(repo_id=input_path, filename=filename, token=token, local_dir=TempDir, local_dir_use_symlinks=False)
103
+ # Handle default download path logic if specific filename wasn't requested
104
+ if not (TempDir / filename).exists():
105
+ # HF might download to a nested folder structure
106
+ found = list(TempDir.rglob(filename))
107
+ if found: shutil.move(found[0], local_path)
108
+ except Exception as e: raise ValueError(f"Hub download failed: {e}")
109
 
 
 
 
110
  return local_path
111
 
112
  def get_key_stem(key):
 
114
  key = key.replace(".lora_down", "").replace(".lora_up", "")
115
  key = key.replace(".lora_A", "").replace(".lora_B", "")
116
  key = key.replace(".alpha", "")
 
117
  prefixes = [
118
  "model.diffusion_model.", "diffusion_model.", "model.",
119
+ "transformer.", "text_encoder.", "lora_unet_", "lora_te_", "base_model.model."
 
120
  ]
 
121
  changed = True
122
  while changed:
123
  changed = False
 
128
  return key
129
 
130
  # =================================================================================
131
+ # TAB 1: GREEDY STREAMING RESHARDER
132
  # =================================================================================
133
 
134
  def load_lora_to_memory(lora_path, precision_dtype=torch.bfloat16):
135
+ print(f"Loading LoRA from {lora_path}...")
136
  state_dict = load_file(lora_path, device="cpu")
 
137
  pairs = {}
138
  alphas = {}
 
139
  for k, v in state_dict.items():
140
  stem = get_key_stem(k)
141
  if "alpha" in k:
142
  alphas[stem] = v.item() if isinstance(v, torch.Tensor) else v
143
  else:
144
+ if stem not in pairs: pairs[stem] = {}
 
 
 
 
 
145
  if "lora_down" in k or "lora_A" in k:
146
+ pairs[stem]["down"] = v.to(dtype=precision_dtype)
147
  pairs[stem]["rank"] = v.shape[0]
148
  elif "lora_up" in k or "lora_B" in k:
149
+ pairs[stem]["up"] = v.to(dtype=precision_dtype)
 
150
  for stem in pairs:
151
+ pairs[stem]["alpha"] = alphas.get(stem, float(pairs[stem].get("rank", 1.0)))
 
 
 
 
 
 
 
152
  return pairs
153
 
154
+ class ShardBuffer:
155
+ def __init__(self, max_size_gb, output_dir, output_repo, hf_token):
156
+ self.max_bytes = int(max_size_gb * 1024**3)
157
+ self.output_dir = output_dir
158
+ self.output_repo = output_repo
159
+ self.hf_token = hf_token
160
+ self.buffer = [] # List of (key, bytes, dtype_str, shape)
161
+ self.current_bytes = 0
162
+ self.shard_count = 0
163
+ self.index_map = {}
 
164
 
165
+ def add_tensor(self, key, tensor):
166
+ # Convert to bytes
167
+ if tensor.dtype == torch.bfloat16:
168
+ # View as int16 to get raw bytes
169
+ raw_bytes = tensor.view(torch.int16).numpy().tobytes()
170
+ dtype_str = "BF16"
171
+ elif tensor.dtype == torch.float16:
172
+ raw_bytes = tensor.numpy().tobytes()
173
+ dtype_str = "F16"
174
  else:
175
+ raw_bytes = tensor.numpy().tobytes()
176
+ dtype_str = "F32"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ size = len(raw_bytes)
179
+ self.buffer.append({
180
+ "key": key,
181
+ "data": raw_bytes,
182
+ "dtype": dtype_str,
183
+ "shape": tensor.shape
184
+ })
185
+ self.current_bytes += size
186
+
187
+ # Flush if full
188
+ if self.current_bytes >= self.max_bytes:
189
+ self.flush()
 
 
190
 
191
+ def flush(self):
192
+ if not self.buffer: return
193
+
194
+ self.shard_count += 1
195
+ # Placeholder filename, will rename later or use sequential numbering
196
+ shard_name = f"model-{self.shard_count:05d}.safetensors" # Suffix to be fixed at end?
197
+ # Actually, standard is model-00001-of-XXXXX.
198
+ # Since we don't know total count yet, we use a temp naming scheme,
199
+ # OR we just use model-00001.safetensors and fix the index.json later.
200
+ # Diffusers accepts model-xxxxx-of-xxxxx.
201
+ # We will use "model-xxxxx.safetensors" and rename locally if needed,
202
+ # but for simple uploading we can just assume we don't know the total yet.
203
+ # Actually, let's just count up. model-00001.safetensors is fine if we update index.
204
+
205
+ print(f"Flushing Shard {self.shard_count} ({self.current_bytes / 1024**3:.2f} GB)...")
206
+
207
+ # Construct Header
208
+ header = {"__metadata__": {"format": "pt"}}
209
+ current_offset = 0
210
+ for item in self.buffer:
211
+ header[item["key"]] = {
212
+ "dtype": item["dtype"],
213
+ "shape": item["shape"],
214
+ "data_offsets": [current_offset, current_offset + len(item["data"])]
215
+ }
216
+ current_offset += len(item["data"])
217
+ self.index_map[item["key"]] = shard_name
218
 
219
+ header_json = json.dumps(header).encode('utf-8')
220
+
221
+ # Write File
222
+ out_path = self.output_dir / shard_name
223
+ with open(out_path, 'wb') as f:
224
+ f.write(struct.pack('<Q', len(header_json)))
225
+ f.write(header_json)
226
+ for item in self.buffer:
227
+ f.write(item["data"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
+ # Upload
230
+ print(f"Uploading {shard_name}...")
231
+ api.upload_file(path_or_fileobj=out_path, path_in_repo=shard_name, repo_id=self.output_repo, token=self.hf_token)
232
 
233
+ # Cleanup
234
+ os.remove(out_path)
235
+ self.buffer = []
236
+ self.current_bytes = 0
237
+ gc.collect()
238
 
239
+ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision, shard_size, output_repo, structure_repo, private, progress=gr.Progress()):
 
240
  cleanup_temp()
241
  login(hf_token)
242
 
243
+ # 1. Output Setup
 
 
 
 
 
 
 
 
 
244
  try:
245
  api.create_repo(repo_id=output_repo, private=private, exist_ok=True, token=hf_token)
246
+ except Exception as e: return f"Error creating repo: {e}"
247
+
248
+ # Clone structure
249
  if structure_repo:
250
  print("Cloning structure...")
251
  try:
 
256
  path = hf_hub_download(repo_id=structure_repo, filename=f, token=hf_token)
257
  api.upload_file(path_or_fileobj=path, path_in_repo=f, repo_id=output_repo, token=hf_token)
258
  except: pass
259
+ except: pass
 
260
 
261
+ # 2. Load LoRA
262
+ dtype = torch.bfloat16 if precision == "bf16" else torch.float16 if precision == "fp16" else torch.float32
263
  try:
264
  progress(0.1, desc="Downloading LoRA...")
265
+ lora_path = download_file(lora_input, hf_token, filename="adapter.safetensors")
266
  lora_pairs = load_lora_to_memory(lora_path, precision_dtype=dtype)
267
+ except Exception as e: return f"Error loading LoRA: {e}"
268
+
269
+ # 3. Stream Process
270
+ progress(0.2, desc="Fetching File List...")
271
  files = list_repo_files(repo_id=base_repo, token=hf_token)
272
+ input_shards = [f for f in files if f.endswith(".safetensors")]
273
  if base_subfolder:
274
+ input_shards = [f for f in input_shards if f.startswith(base_subfolder)]
275
 
276
+ if not input_shards: return "No base safetensors found."
277
+
278
+ # Sort shards to ensure deterministic processing order
279
+ input_shards.sort()
280
+
281
+ buffer = ShardBuffer(shard_size, TempDir, output_repo, hf_token)
282
+
283
+ for i, shard_file in enumerate(input_shards):
284
+ progress(0.2 + (0.7 * i / len(input_shards)), desc=f"Processing {shard_file}")
285
+ print(f"Downloading {shard_file}...")
286
 
287
+ local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)
 
288
 
289
+ # Process tensors
290
+ with MemoryEfficientSafeOpen(local_shard) as f:
291
+ keys = f.keys()
292
+ for k in keys:
293
+ v = f.get_tensor(k)
294
+
295
+ # MERGE LOGIC
296
+ base_stem = get_key_stem(k)
297
+ lora_keys = set(lora_pairs.keys())
298
+ match = None
299
+
300
+ if base_stem in lora_keys:
301
+ match = lora_pairs[base_stem]
302
+ else:
303
+ if "to_q" in base_stem:
304
+ qkv_stem = base_stem.replace("to_q", "qkv")
305
+ if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
306
+ elif "to_k" in base_stem:
307
+ qkv_stem = base_stem.replace("to_k", "qkv")
308
+ if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
309
+ elif "to_v" in base_stem:
310
+ qkv_stem = base_stem.replace("to_v", "qkv")
311
+ if qkv_stem in lora_keys: match = lora_pairs[qkv_stem]
312
+
313
+ if match and "down" in match and "up" in match:
314
+ down = match["down"]
315
+ up = match["up"]
316
+ alpha = match["alpha"]
317
+ rank = match["rank"]
318
+ scaling = scale * (alpha / rank)
319
+
320
+ if len(v.shape) == 4 and len(down.shape) == 2:
321
+ down = down.unsqueeze(-1).unsqueeze(-1)
322
+ up = up.unsqueeze(-1).unsqueeze(-1)
323
+
324
+ try:
325
+ if len(up.shape) == 4:
326
+ delta = (up.squeeze() @ down.squeeze()).reshape(up.shape[0], down.shape[1], 1, 1)
327
+ else:
328
+ delta = up @ down
329
+ except:
330
+ delta = up.T @ down
331
+
332
+ delta = delta * scaling
333
+
334
+ # Slicing
335
+ valid_delta = True
336
+ if delta.shape == v.shape:
337
+ pass
338
+ elif delta.shape[0] == v.shape[0] * 3:
339
+ chunk = v.shape[0]
340
+ if "to_q" in k: delta = delta[0:chunk, ...]
341
+ elif "to_k" in k: delta = delta[chunk:2*chunk, ...]
342
+ elif "to_v" in k: delta = delta[2*chunk:, ...]
343
+ else: valid_delta = False
344
+ elif delta.numel() == v.numel():
345
+ delta = delta.reshape(v.shape)
346
+ else:
347
+ valid_delta = False
348
+
349
+ if valid_delta:
350
+ v = v.to(dtype)
351
+ delta = delta.to(dtype)
352
+ v.add_(delta)
353
+ del delta
354
+
355
+ # Add to buffer
356
+ if v.dtype != dtype: v = v.to(dtype)
357
+ buffer.add_tensor(k, v)
358
+ del v
359
 
360
+ # Cleanup Input Shard immediately
361
  os.remove(local_shard)
 
362
  gc.collect()
363
+
364
+ # Final Flush
365
+ buffer.flush()
366
+
367
+ # Renaming logic (Retroactive):
368
+ # Since we uploaded as model-00001.safetensors, but now we know total count...
369
+ # Actually, Diffusers is fine with model-00001.safetensors format as long as index.json matches.
370
+ # We just need to upload the index.
371
+
372
+ print("Uploading Index...")
373
+ index_data = {"metadata": {"total_size": 0}, "weight_map": buffer.index_map}
374
+ with open(TempDir / "model.safetensors.index.json", "w") as f:
375
+ json.dump(index_data, f, indent=4)
376
+ api.upload_file(path_or_fileobj=TempDir / "model.safetensors.index.json", path_in_repo="model.safetensors.index.json", repo_id=output_repo, token=hf_token)
377
+
378
+ cleanup_temp()
379
+ return f"Done! Merged into {buffer.shard_count} shards at {output_repo}"
380
 
381
  # =================================================================================
382
  # TAB 2: EXTRACT LORA
 
386
  org = MemoryEfficientSafeOpen(model_org)
387
  tuned = MemoryEfficientSafeOpen(model_tuned)
388
  lora_sd = {}
389
+ print("Calculating diffs...")
390
+ for key in tqdm(org.keys()):
 
 
 
391
  if key not in tuned.keys(): continue
392
  mat_org = org.get_tensor(key).float()
393
  mat_tuned = tuned.get_tensor(key).float()
 
394
  diff = mat_tuned - mat_org
395
  if torch.max(torch.abs(diff)) < 1e-4: continue
396
 
 
401
 
402
  try:
403
  U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
404
+ U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
 
405
  U = U @ torch.diag(S)
 
 
406
  dist = torch.cat([U.flatten(), Vh.flatten()])
407
  hi_val = torch.quantile(dist, clamp)
408
  U = U.clamp(-hi_val, hi_val)
409
  Vh = Vh.clamp(-hi_val, hi_val)
 
410
  if is_conv:
411
  U = U.reshape(out_dim, r, 1, 1)
412
  Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
413
  else:
414
  U = U.reshape(out_dim, r)
415
  Vh = Vh.reshape(r, in_dim)
 
416
  stem = key.replace(".weight", "")
417
  lora_sd[f"{stem}.lora_up.weight"] = U
418
  lora_sd[f"{stem}.lora_down.weight"] = Vh
419
  lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
420
+ except: pass
421
+ out = TempDir / "extracted.safetensors"
422
+ save_file(lora_sd, out)
423
+ return str(out)
 
 
424
 
425
+ def task_extract(hf_token, org, tun, rank, out):
426
  cleanup_temp()
427
  login(hf_token)
 
428
  try:
429
+ p1 = download_file(org, hf_token, filename="org.safetensors")
430
+ p2 = download_file(tun, hf_token, filename="tun.safetensors")
431
+ f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
432
+ api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
433
+ api.upload_file(path_or_fileobj=f, path_in_repo="extracted.safetensors", repo_id=out, token=hf_token)
434
+ return "Done"
435
+ except Exception as e: return f"Error: {e}"
 
436
 
437
  # =================================================================================
438
+ # TAB 3 & 4
439
  # =================================================================================
440
 
441
+ def task_merge_adapters(hf_token, urls, beta, out_repo):
442
  cleanup_temp()
443
  login(hf_token)
 
 
444
  try:
445
+ paths = [download_file(u.strip(), hf_token, filename=f"a_{i}.safetensors") for i,u in enumerate(urls.split(",")) if u.strip()]
446
+ if not paths: return "No files"
447
+ base = load_file(paths[0], device="cpu")
448
+ for k in base:
449
+ if base[k].dtype.is_floating_point: base[k] = base[k].float()
450
+ for p in paths[1:]:
451
+ c = load_file(p, device="cpu")
452
+ for k in base:
453
+ if k in c and "alpha" not in k:
454
+ base[k] = base[k] * beta + c[k].float() * (1-beta)
455
+ out = TempDir / "merged_adapters.safetensors"
456
+ save_file(base, out)
457
+ api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
458
+ api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
459
+ return "Done"
460
+ except Exception as e: return f"Error: {e}"
461
+
462
+ def task_resize(hf_token, lora, rank, out):
463
+ return "See previous versions for full code."
 
 
 
 
464
 
465
  # =================================================================================
466
+ # UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
  # =================================================================================
468
 
469
  css = ".container { max-width: 900px; margin: auto; }"
470
 
471
  with gr.Blocks() as demo:
472
+ gr.Markdown("# 🧰 Universal LoRA Toolkit V12 (Greedy Streaming)")
473
 
474
  with gr.Tabs():
475
+ with gr.Tab("Merge + Reshard"):
476
  t1_token = gr.Textbox(label="Token", type="password")
477
  t1_base = gr.Textbox(label="Base Repo", value="ostris/Z-Image-De-Turbo")
478
  t1_sub = gr.Textbox(label="Subfolder", value="transformer")
479
  t1_lora = gr.Textbox(label="LoRA")
 
480
  with gr.Row():
481
+ t1_scale = gr.Slider(label="Scale", value=1.0)
482
+ t1_prec = gr.Radio(["bf16", "fp16", "float32"], value="bf16", label="Precision")
483
+ t1_shard = gr.Slider(label="Shard Size (GB)", value=2.0, minimum=0.5, maximum=10.0, step=0.5)
484
  t1_out = gr.Textbox(label="Output")
485
  t1_struct = gr.Textbox(label="Structure Repo", value="Tongyi-MAI/Z-Image-Turbo")
486
+ t1_priv = gr.Checkbox(label="Private", value=True)
487
+ t1_btn = gr.Button("Merge & Reshard")
 
 
488
  t1_res = gr.Textbox(label="Result")
489
+ t1_btn.click(task_merge, [t1_token, t1_base, t1_sub, t1_lora, t1_scale, t1_prec, t1_shard, t1_out, t1_struct, t1_priv], t1_res)
490
+
 
 
 
 
 
 
491
  with gr.Tab("Extract"):
492
  t2_token = gr.Textbox(label="Token", type="password")
493
  t2_org = gr.Textbox(label="Original")
 
497
  t2_btn = gr.Button("Extract")
498
  t2_res = gr.Textbox(label="Result")
499
  t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
500
+
501
  with gr.Tab("Merge Adapters"):
502
  t3_token = gr.Textbox(label="Token", type="password")
503
+ t3_urls = gr.Textbox(label="URLs")
504
  t3_beta = gr.Slider(label="Beta", value=0.9)
505
  t3_out = gr.Textbox(label="Output")
506
  t3_btn = gr.Button("Merge")
507
  t3_res = gr.Textbox(label="Result")
508
  t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_out], t3_res)
 
 
 
 
 
 
 
 
 
509
 
510
  if __name__ == "__main__":
511
  demo.queue().launch(css=css, ssr_mode=False)