AlekseyCalvin committed
Commit 5931ee2 · verified · 1 Parent(s): c1d9d54

Update app.py

Files changed (1):
  1. app.py +241 -25
app.py CHANGED
@@ -608,9 +608,22 @@ def task_extract(hf_token, org, tun, rank, out):
     except Exception as e: return f"Error: {e}"
 
 # =================================================================================
-# TAB 3: MERGE ADAPTERS (EMA)
+# TAB 3: MERGE ADAPTERS (Multi-Method)
 # =================================================================================
 
+def load_full_state_dict(path):
+    """Loads a safetensors file and normalizes keys for easier processing."""
+    raw = load_file(path, device="cpu")
+    cleaned = {}
+    for k, v in raw.items():
+        # Map common keys to standard "lora_up"/"lora_down"
+        if "lora_A" in k: new_k = k.replace("lora_A", "lora_down")
+        elif "lora_B" in k: new_k = k.replace("lora_B", "lora_up")
+        else: new_k = k
+        cleaned[new_k] = v.float()
+    return cleaned
+
+# --- Original EMA Method ---
 def sigma_rel_to_gamma(sigma_rel):
     t = sigma_rel**-2
     coeffs = [1, 7, 16 - t, 12 - t]
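For reference, a minimal sketch (with hypothetical key names) of the normalization `load_full_state_dict` applies, so PEFT-style `lora_A`/`lora_B` checkpoints and kohya-style `lora_down`/`lora_up` checkpoints end up addressable the same way by the merge functions below:

```python
# Hypothetical PEFT-style keys; same replace rules as load_full_state_dict.
raw_keys = [
    "blocks.0.attn.to_q.lora_A.weight",  # down-projection, shape (rank, in)
    "blocks.0.attn.to_q.lora_B.weight",  # up-projection, shape (out, rank)
    "blocks.0.attn.to_q.alpha",
]
for k in raw_keys:
    if "lora_A" in k: k = k.replace("lora_A", "lora_down")
    elif "lora_B" in k: k = k.replace("lora_B", "lora_up")
    print(k)
# -> blocks.0.attn.to_q.lora_down.weight
# -> blocks.0.attn.to_q.lora_up.weight
# -> blocks.0.attn.to_q.alpha
```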
@@ -618,19 +631,8 @@ def sigma_rel_to_gamma(sigma_rel):
     gamma = roots[np.isreal(roots) & (roots.real >= 0)].real.max()
     return gamma
 
-def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
-    cleanup_temp()
-    if hf_token: login(hf_token.strip())
-
-    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
-    paths = []
-    try:
-        for i, url in enumerate(urls):
-            paths.append(download_lora_smart(url, hf_token))
-    except Exception as e: return f"Download Error: {e}"
-
-    if not paths: return "No models found"
-
+def merge_lora_iterative_ema(paths, beta, sigma_rel):
+    print("Executing Iterative EMA Merge (Original Method)...")
     base_sd = load_file(paths[0], device="cpu")
     for k in base_sd:
         if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float()
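`sigma_rel_to_gamma` solves the cubic that maps a relative EMA width σ_rel to the power-function exponent γ (the relation from post-hoc EMA reconstruction, where σ_rel² = (γ+1)/((γ+2)²(γ+3))). A standalone check of the helper, assuming the unchanged line elided between the two hunks is `roots = np.roots(coeffs)`, which the surrounding context implies:

```python
import numpy as np

def sigma_rel_to_gamma(sigma_rel):
    t = sigma_rel**-2
    coeffs = [1, 7, 16 - t, 12 - t]
    roots = np.roots(coeffs)  # assumed from the elided context line
    return roots[np.isreal(roots) & (roots.real >= 0)].real.max()

# The Sigma Rel slider below defaults to 0.21:
print(sigma_rel_to_gamma(0.21))  # ~1.57
```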
@@ -651,12 +653,210 @@ def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
     for k in base_sd:
         if k in curr and "alpha" not in k:
             base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)
+    return base_sd
+
+# --- New Concatenation Method (DiffSynth) ---
+def merge_lora_concatenation(adapter_states, weights):
+    """
+    DiffSynth Method: Concatenates ranks.
+    New Rank = sum(ranks). Lossless merging.
+    """
+    print("Executing Concatenation Merge (Rank Summation)...")
+    merged_state = {}
+
+    # Identify all stems (layers) present across all adapters
+    all_stems = set()
+    for state in adapter_states:
+        for k in state.keys():
+            stem = k.split(".lora_")[0]
+            if "lora_" in k: all_stems.add(stem)
+
+    for stem in tqdm(all_stems, desc="Concatenating Layers"):
+        down_list = []
+        up_list = []
+        alpha_sum = 0.0
+
+        for i, state in enumerate(adapter_states):
+            w = weights[i]
+            down_key = f"{stem}.lora_down.weight"
+            up_key = f"{stem}.lora_up.weight"
+            alpha_key = f"{stem}.alpha"
+
+            if down_key in state and up_key in state:
+                d = state[down_key]
+                u = state[up_key] * w  # weighted contribution applied to UP
+
+                down_list.append(d)
+                up_list.append(u)
+
+                if alpha_key in state:
+                    alpha_sum += state[alpha_key].item()
+                else:
+                    alpha_sum += d.shape[0]
+
+        if down_list and up_list:
+            # lora_down (A) is (rank, in): concatenating on dim 0 stacks the ranks.
+            # lora_up (B) is (out, rank): concatenating on dim 1 stacks the ranks.
+            # Reference: DiffSynth merges as lora_A = cat(dim=0), lora_B = cat(dim=1).
+
+            new_down = torch.cat(down_list, dim=0)  # (sum_rank, in)
+            new_up = torch.cat(up_list, dim=1)      # (out, sum_rank)
+
+            merged_state[f"{stem}.lora_down.weight"] = new_down.contiguous()
+            merged_state[f"{stem}.lora_up.weight"] = new_up.contiguous()
+            merged_state[f"{stem}.alpha"] = torch.tensor(alpha_sum)
+
+    return merged_state
+
+# --- New SVD/Task Arithmetic Method ---
+def merge_lora_svd(adapter_states, weights, target_rank):
+    """
+    SVD / Task Arithmetic Method:
+    1. Calculate Delta W for each adapter: dW = B @ A
+    2. Sum Delta Ws: Total dW = sum(weight_i * dW_i)
+    3. SVD(Total dW) -> New B, New A at target_rank
+    """
+    print(f"Executing SVD Merge (Target Rank: {target_rank})...")
+    merged_state = {}
+
+    all_stems = set()
+    for state in adapter_states:
+        for k in state.keys():
+            stem = k.split(".lora_")[0]
+            if "lora_" in k: all_stems.add(stem)
+
+    for stem in tqdm(all_stems, desc="SVD Merging Layers"):
+        total_delta = None
+        valid_layer = False
+
+        for i, state in enumerate(adapter_states):
+            w = weights[i]
+            down_key = f"{stem}.lora_down.weight"
+            up_key = f"{stem}.lora_up.weight"
+            alpha_key = f"{stem}.alpha"
+
+            if down_key in state and up_key in state:
+                down = state[down_key]
+                up = state[up_key]
+                alpha = state[alpha_key].item() if alpha_key in state else down.shape[0]
+                rank = down.shape[0]
+
+                scale = (alpha / rank) * w
+
+                # Reconstruct Delta
+                if len(down.shape) == 4:  # Conv2d
+                    d_flat = down.flatten(start_dim=1)
+                    u_flat = up.flatten(start_dim=1)
+                    delta = (u_flat @ d_flat).reshape(up.shape[0], down.shape[1], down.shape[2], down.shape[3])
+                else:
+                    delta = up @ down
+
+                delta = delta * scale
+
+                if total_delta is None:
+                    total_delta = delta
+                    valid_layer = True
+                else:
+                    if total_delta.shape == delta.shape:
+                        total_delta += delta
+                    else:
+                        print(f"Shape mismatch in {stem}, skipping.")
+
+        if valid_layer and total_delta is not None:
+            out_dim = total_delta.shape[0]
+            in_dim = total_delta.shape[1]
+            is_conv = len(total_delta.shape) == 4
+
+            if is_conv:
+                flat_delta = total_delta.flatten(start_dim=1)
+            else:
+                flat_delta = total_delta
+
+            try:
+                U, S, V = torch.svd_lowrank(flat_delta, q=target_rank + 4, niter=4)
+                Vh = V.t()
+
+                U = U[:, :target_rank]
+                S = S[:target_rank]
+                Vh = Vh[:target_rank, :]
+
+                U = U @ torch.diag(S)
+
+                if is_conv:
+                    U = U.reshape(out_dim, target_rank, 1, 1)
+                    Vh = Vh.reshape(target_rank, in_dim, total_delta.shape[2], total_delta.shape[3])
+                else:
+                    U = U.reshape(out_dim, target_rank)
+                    Vh = Vh.reshape(target_rank, in_dim)
+
+                merged_state[f"{stem}.lora_down.weight"] = Vh.contiguous()
+                merged_state[f"{stem}.lora_up.weight"] = U.contiguous()
+                merged_state[f"{stem}.alpha"] = torch.tensor(target_rank).float()
+            except Exception as e:
+                print(f"SVD Failed for {stem}: {e}")
+
+    return merged_state
+
+def task_merge_adapters_advanced(hf_token, inputs_text, method, weight_str, beta, sigma_rel, target_rank, out_repo, private):
+    cleanup_temp()
+    if hf_token: login(hf_token.strip())
+
+    if not out_repo or not out_repo.strip():
+        return "Error: Output Repo cannot be empty."
+
+    # 1. Parse Inputs (Multi-line support)
+    raw_lines = inputs_text.replace(" ", "\n").split('\n')
+    urls = [line.strip() for line in raw_lines if line.strip()]
+    if len(urls) < 2: return "Error: Please provide at least 2 adapters."
+
+    # 2. Parse Weights (for SVD/Concatenation)
+    try:
+        if not weight_str.strip():
+            weights = [1.0] * len(urls)
+        else:
+            weights = [float(w.strip()) for w in weight_str.split(',')]
+            # Broadcast or Truncate
+            if len(weights) < len(urls):
+                weights += [1.0] * (len(urls) - len(weights))
+            else:
+                weights = weights[:len(urls)]
+    except ValueError:
+        return "Error parsing weights. Use format: 1.0, 0.5, 0.8"
+
+    # 3. Download All
+    paths = []
+    try:
+        for url in tqdm(urls, desc="Downloading Adapters"):
+            paths.append(download_lora_smart(url, hf_token))
+    except Exception as e: return f"Download Error: {e}"
+
+    merged = None
+
+    # 4. Execute Selected Method
+    if "Iterative EMA" in method:
+        # Calls the original method logic exactly
+        merged = merge_lora_iterative_ema(paths, beta, sigma_rel)
+
+    else:
+        # For the new methods, load every state dict upfront
+        states = [load_full_state_dict(p) for p in paths]
+
+        if "Concatenation" in method:
+            merged = merge_lora_concatenation(states, weights)
+        elif "SVD" in method:
+            merged = merge_lora_svd(states, weights, int(target_rank))
+
+    if not merged: return "Merge failed (Result empty)."
 
+    # 5. Save & Upload
     out = TempDir / "merged_adapters.safetensors"
-    save_file(base_sd, out)
-    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
-    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
-    return "Done"
+    save_file(merged, out)
+
+    try:
+        api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
+        api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
+        return f"Success! Merged to {out_repo}"
+    except Exception as e: return f"Upload Error: {e}"
 
 # =================================================================================
 # TAB 4: RESIZE (CPU Optimized)
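The concatenation merge relies on a simple identity: stacking the down matrices along the rank axis and the up matrices along their rank axis reproduces the sum of the individual low-rank deltas, which is why this path is lossless. A self-contained check with random factors (ignoring per-adapter alpha scaling):

```python
import torch

out_dim, in_dim, r1, r2 = 64, 32, 4, 8
A1, B1 = torch.randn(r1, in_dim), torch.randn(out_dim, r1)  # adapter 1 (down, up)
A2, B2 = torch.randn(r2, in_dim), torch.randn(out_dim, r2)  # adapter 2 (down, up)

new_down = torch.cat([A1, A2], dim=0)  # (r1 + r2, in_dim)
new_up = torch.cat([B1, B2], dim=1)    # (out_dim, r1 + r2)

# (B1 | B2) @ [A1 ; A2] == B1 @ A1 + B2 @ A2
assert torch.allclose(new_up @ new_down, B1 @ A1 + B2 @ A2, atol=1e-4)
```

Note that the merged `alpha` is the sum of the input alphas; since the effective scale of a LoRA layer is alpha/rank, this preserves each adapter's original scaling exactly only when all inputs share the same alpha-to-rank ratio.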
@@ -770,18 +970,34 @@ with gr.Blocks() as demo:
             t2_btn = gr.Button("Extract")
             t2_res = gr.Textbox(label="Result")
             t2_btn.click(task_extract, [t2_token, t2_org, t2_tun, t2_rank, t2_out], t2_res)
-
+
         with gr.Tab("Merge Multiple Adapters"):
+            gr.Markdown("### Batch Adapter Merging")
             t3_token = gr.Textbox(label="Token", type="password")
-            t3_urls = gr.Textbox(label="URLs")
+            t3_urls = gr.TextArea(label="Adapter URLs/Repos (One per line, or space separated)", placeholder="ostris/lora1\nhttps://hf.co/user/lora2.safetensors\n...")
+
+            with gr.Row():
+                t3_method = gr.Dropdown(
+                    ["Iterative EMA (Original Beta/Sigma)", "Concatenation (DiffSynth - Lossless)", "SVD Merge (Task Arithmetic/Compressed)"],
+                    value="Iterative EMA (Original Beta/Sigma)",
+                    label="Merge Method"
+                )
+
+            with gr.Row():
+                t3_weights = gr.Textbox(label="Weights (Comma separated) - For Concat/SVD", placeholder="1.0, 0.5, 0.8...")
+                t3_rank = gr.Number(label="Target Rank - For SVD only", value=128, minimum=4, maximum=1024)
+
             with gr.Row():
-                t3_beta = gr.Slider(label="Beta", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
-                t3_sigma = gr.Slider(label="Sigma Rel (Overrides Beta)", value=0.21, minimum=0.01, maximum=1.00, step=0.01)
+                t3_beta = gr.Slider(label="Beta - For EMA only", value=0.95, minimum=0.01, maximum=1.00, step=0.01)
+                t3_sigma = gr.Slider(label="Sigma Rel - For EMA only", value=0.21, minimum=0.01, maximum=1.00, step=0.01)
+
             t3_out = gr.Textbox(label="Output Repo")
-            t3_btn = gr.Button("Merge")
+            t3_priv = gr.Checkbox(label="Private Output", value=True)
+            t3_btn = gr.Button("Merge Adapters")
             t3_res = gr.Textbox(label="Result")
-            t3_btn.click(task_merge_adapters, [t3_token, t3_urls, t3_beta, t3_sigma, t3_out], t3_res)
-
+
+            t3_btn.click(task_merge_adapters_advanced, [t3_token, t3_urls, t3_method, t3_weights, t3_beta, t3_sigma, t3_rank, t3_out, t3_priv], t3_res)
+
         with gr.Tab("Resize Adapter"):
             t4_token = gr.Textbox(label="Token", type="password")
             t4_in = gr.Textbox(label="LoRA")
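The SVD path compresses the summed delta back to `target_rank` with `torch.svd_lowrank`, so unlike concatenation it is an approximation. A round-trip check on a matrix that is exactly low-rank, where the truncation should be essentially exact:

```python
import torch

target_rank = 8
# Construct a delta that is exactly rank 8, like a single rank-8 LoRA's B @ A.
delta = torch.randn(64, target_rank) @ torch.randn(target_rank, 32)

U, S, V = torch.svd_lowrank(delta, q=target_rank + 4, niter=4)
up = U[:, :target_rank] @ torch.diag(S[:target_rank])  # (64, 8), like lora_up
down = V.t()[:target_rank, :]                          # (8, 32), like lora_down

assert (up @ down - delta).norm() / delta.norm() < 1e-4
```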
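For completeness, a hypothetical headless call mirroring what the Gradio wiring above passes in (`cleanup_temp`, `download_lora_smart`, `api`, and `TempDir` are defined elsewhere in app.py; the token and repo names here are placeholders):

```python
result = task_merge_adapters_advanced(
    hf_token="hf_...",
    inputs_text="user/lora-one\nuser/lora-two",
    method="SVD Merge (Task Arithmetic/Compressed)",
    weight_str="1.0, 0.6",
    beta=0.95,        # EMA-only, ignored for SVD
    sigma_rel=0.21,   # EMA-only, ignored for SVD
    target_rank=128,
    out_repo="user/merged-loras",
    private=True,
)
print(result)  # "Success! Merged to user/merged-loras" on success
```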