AlekseyCalvin commited on
Commit
3848f5b
·
verified ·
1 Parent(s): 0a8593d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -16
app.py CHANGED
@@ -454,42 +454,111 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision
454
  # TAB 2: EXTRACT LORA
455
  # =================================================================================
456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """Extract a LoRA (low-rank) approximation of the weight delta between two models.

    For every tensor key present in both checkpoints, computes ``tuned - org``,
    takes a rank-``r`` SVD of the difference, and stores it as
    ``<stem>.lora_up.weight`` / ``<stem>.lora_down.weight`` / ``<stem>.alpha``.

    Args:
        model_org: path to the original (base) .safetensors file.
        model_tuned: path to the fine-tuned .safetensors file.
        rank: maximum LoRA rank to extract per layer.
        clamp: quantile (0..1) used to clip outlier factor values.

    Returns:
        str path of the written ``extracted.safetensors`` file in TempDir.
    """
    org = MemoryEfficientSafeOpen(model_org)
    tuned = MemoryEfficientSafeOpen(model_tuned)
    lora_sd = {}
    print("Calculating diffs...")
    for key in tqdm(org.keys()):
        if key not in tuned.keys():
            continue
        mat_org = org.get_tensor(key).float()
        mat_tuned = tuned.get_tensor(key).float()
        # Guard: mismatched shapes (different architectures) cannot be diffed.
        if mat_org.shape != mat_tuned.shape:
            continue
        diff = mat_tuned - mat_org
        # Skip layers the fine-tune did not meaningfully change.
        if torch.max(torch.abs(diff)) < 1e-4:
            continue
        # BUG FIX: 1-D tensors (biases, norm scales) have no second dimension;
        # the original `out_dim, in_dim = diff.shape[:2]` raised ValueError here,
        # outside the try, crashing the whole extraction.
        if diff.dim() < 2:
            continue
        out_dim, in_dim = diff.shape[:2]
        r = min(rank, in_dim, out_dim)
        is_conv = diff.dim() == 4
        if is_conv:
            # Fold the kernel spatial dims into the input dim for the SVD.
            diff = diff.flatten(start_dim=1)
        try:
            U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
            U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
            # Merge singular values into the "up" factor (standard LoRA layout).
            U = U @ torch.diag(S)
            # BUG FIX: quantile over |values| guarantees a non-negative bound;
            # the signed quantile could be negative and invert the clamp range.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(torch.abs(dist), clamp)
            if hi_val > 0:
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)
            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)
            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
            lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        except Exception as e:
            # Report instead of a silent bare `except: pass`, which hid real errors.
            print(f"Skipping {key}: {e}")
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
@@ -498,12 +567,16 @@ def task_extract(hf_token, org, tun, rank, out):
498
  cleanup_temp()
499
  if hf_token: login(hf_token.strip())
500
  try:
501
- p1 = download_file(org, hf_token, filename="org.safetensors")
502
- p2 = download_file(tun, hf_token, filename="tun.safetensors")
 
 
 
503
  f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
 
504
  api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
505
- api.upload_file(path_or_fileobj=f, path_in_repo="extracted.safetensors", repo_id=out, token=hf_token)
506
- return "Done"
507
  except Exception as e: return f"Error: {e}"
508
 
509
  # =================================================================================
 
454
  # TAB 2: EXTRACT LORA
455
  # =================================================================================
456
 
457
def identify_and_download_model(repo_id, token):
    """
    Smart download: checks for diffusers format (unet/transformer) vs standard safetensors.

    Scans the repo's file list, picks the best weight file by priority, downloads it,
    and returns its local filesystem path (pathlib.Path).

    Raises:
        ValueError: if no suitable .safetensors weight file exists in the repo.
    """
    from pathlib import Path  # local import: keeps this block self-contained

    print(f"Scanning {repo_id} for model weights...")
    files = list_repo_files(repo_id=repo_id, token=token)

    # Priority list for diffusers vs single file; the last entry is a predicate
    # fallback matching any non-LoRA/non-adapter safetensors file.
    priorities = [
        "transformer/diffusion_pytorch_model.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "model.safetensors",
        lambda f: f.endswith(".safetensors") and "lora" not in f and "adapter" not in f and "extracted" not in f,
    ]

    target_file = None
    for p in priorities:
        if callable(p):
            candidates = [f for f in files if p(f)]
            if candidates:
                target_file = candidates[0]
                break
        elif p in files:
            target_file = p
            break

    if not target_file:
        raise ValueError(f"Could not find a valid model weight file in {repo_id}. Ensure it contains .safetensors weights.")

    print(f"Downloading main weight file: {target_file}")
    # BUG FIX: hf_hub_download returns the exact local path of the downloaded file.
    # The previous rglob(basename)[0] search could raise IndexError or pick up a
    # stale, identically-named file left over from a prior run in TempDir.
    local_path = hf_hub_download(repo_id=repo_id, filename=target_file, token=token, local_dir=TempDir)
    return Path(local_path)
493
+
494
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """Extract a LoRA (low-rank) delta between a base and a fine-tuned checkpoint.

    For every key present in both files, computes ``tuned - org``, approximates it
    with a rank-``r`` randomized SVD (``torch.svd_lowrank`` for CPU speed), clamps
    outlier factor values at the ``clamp`` quantile of absolute magnitudes, and
    stores the result in standard LoRA layout (``lora_up`` / ``lora_down`` / ``alpha``).

    Args:
        model_org: path to the original (base) .safetensors file.
        model_tuned: path to the fine-tuned .safetensors file.
        rank: maximum LoRA rank to extract per layer.
        clamp: quantile (0..1) used to clip outlier factor values.

    Returns:
        str path of the written ``extracted.safetensors`` file in TempDir.
    """
    org = MemoryEfficientSafeOpen(model_org)
    tuned = MemoryEfficientSafeOpen(model_tuned)
    lora_sd = {}
    print("Calculating diffs & extracting LoRA...")

    # Only process keys both checkpoints share.
    keys = set(org.keys()).intersection(set(tuned.keys()))

    for key in tqdm(keys, desc="Extracting"):
        # Skip integer buffers / batch-norm statistics — not LoRA-able weights.
        if "num_batches_tracked" in key or "running_mean" in key or "running_var" in key:
            continue

        mat_org = org.get_tensor(key).float()
        mat_tuned = tuned.get_tensor(key).float()

        # Skip if shapes mismatch (shouldn't happen if models match).
        if mat_org.shape != mat_tuned.shape:
            continue

        diff = mat_tuned - mat_org

        # Skip layers the fine-tune did not meaningfully change.
        if torch.max(torch.abs(diff)) < 1e-4:
            continue

        out_dim = diff.shape[0]
        in_dim = diff.shape[1] if len(diff.shape) > 1 else 1

        r = min(rank, in_dim, out_dim)

        is_conv = len(diff.shape) == 4
        if is_conv:
            # Fold the kernel spatial dims into the input dim for the SVD.
            diff = diff.flatten(start_dim=1)
        elif len(diff.shape) == 1:
            diff = diff.unsqueeze(1)  # Handle biases if needed, though rarely lora'd

        try:
            # Use svd_lowrank for massive speedup on CPU vs linalg.svd.
            # BUG FIX: q must not exceed either matrix dimension — with the old
            # unconditional q=r+4, narrow layers (e.g. unsqueezed 1-D tensors, or
            # out_dim < rank+4) errored and were silently dropped by the except.
            q = min(r + 4, diff.shape[0], diff.shape[1])
            U, S, V = torch.svd_lowrank(diff, q=q, niter=4)
            Vh = V.t()

            U = U[:, :r]
            S = S[:r]
            Vh = Vh[:r, :]

            # Merge S into U for standard LoRA format.
            U = U @ torch.diag(S)

            # Clamp outliers at the `clamp` quantile of absolute values.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(torch.abs(dist), clamp)
            if hi_val > 0:
                U = U.clamp(-hi_val, hi_val)
                Vh = Vh.clamp(-hi_val, hi_val)

            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)

            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U.contiguous()
            lora_sd[f"{stem}.lora_down.weight"] = Vh.contiguous()
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        except Exception as e:
            print(f"Skipping {key} due to error: {e}")

    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
 
567
  cleanup_temp()
568
  if hf_token: login(hf_token.strip())
569
  try:
570
+ print("Downloading Original Model...")
571
+ p1 = identify_and_download_model(org, hf_token)
572
+ print("Downloading Tuned Model...")
573
+ p2 = identify_and_download_model(tun, hf_token)
574
+
575
  f = extract_lora_layer_by_layer(p1, p2, int(rank), 0.99)
576
+
577
  api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
578
+ api.upload_file(path_or_fileobj=f, path_in_repo="extracted_lora.safetensors", repo_id=out, token=hf_token)
579
+ return "Done! Extracted to " + out
580
  except Exception as e: return f"Error: {e}"
581
 
582
  # =================================================================================