raayraay commited on
Commit
b0a7741
·
verified ·
1 Parent(s): 998155c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -150
app.py CHANGED
@@ -7,148 +7,153 @@ import numpy as np
7
  import os
8
  import gc
9
  from scipy.sparse.linalg import svds
 
 
 
10
 
11
  # --- CORE SAMM ALGORITHM ---
12
 
13
  def farms_spectral_analysis(tensor, num_patches=10, patch_size=64):
14
- """
15
- Implements the FARMS method (Fixed-Aspect-Ratio Matrix Subsampling).
16
- Instead of analyzing the full rectangular matrix, we sample square submatrices
17
- to get a robust estimate of the spectral density and dominant directions.
18
- """""
19
- # Ensure tensor is 2D
20
- if len(tensor.shape) != 2:
21
- return None, None
22
-
23
- rows, cols = tensor.shape
24
- u_list = []
25
-
26
- # FARMS: Randomly sample square patches to avoid aspect ratio bias
27
- for _ in range(num_patches):
28
- r_start = np.random.randint(0, max(1, rows - patch_size))
29
- c_start = np.random.randint(0, max(1, cols - patch_size))
30
-
31
- # Extract patch
32
- patch = tensor[r_start:r_start+patch_size, c_start:c_start+patch_size]
33
-
34
- # Compute SVD on patch
35
- try:
36
- # We only need top components to find the "Universal Subspace"
37
- u, s, vh = np.linalg.svd(patch.float().numpy(), full_matrices=False)
38
- u_list.append(u[:, :1]) # Keep top principal direction
39
- except:
40
- continue
41
 
42
- # In a full implementation, we would aggregate these patch spectra.
43
- # For this simplified Space, we return the Full SVD guided by the hypothesis
44
- # that the top directions are stable.
45
 
46
- # Fallback to full SVD for the merging step, but using the "Universal" concept
47
- # We posit the top k singular vectors form the shared subspace.
48
  try:
49
- u, s, v = torch.svd_lowrank(tensor.float(), q=32) # Efficient randomized SVD
50
- return u, v # Returns Left (U) and Right (V) singular vectors
51
- except:
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  return None, None
53
 
54
- def spectral_aware_merge(adapters_dict, merge_ratio=0.5):
55
- """
56
- Merges adapters by aligning them in the Universal Weight Subspace.
57
- """""
58
- merged_state_dict = {}
59
-
60
- # Get the keys (layer names) from the first adapter
61
- all_keys = list(next(iter(adapters_dict.values())).keys())
62
-
63
- print(f"Starting SAMM merge on {len(all_keys)} layers...")
64
-
65
- for key in all_keys:
66
- # 1. Collect weights from all adapters for this layer
67
- layer_tensors = []
68
- for name, state in adapters_dict.items():
69
- if key in state:
70
- layer_tensors.append(state[key])
71
-
72
- if not layer_tensors:
73
- continue
74
-
75
- # Stack for analysis
76
- # Shape: (N_adapters, rows, cols)
77
- stack = torch.stack(layer_tensors)
78
- avg_weight = torch.mean(stack, dim=0)
79
-
80
- # 2. IF it's a LoRA weight (usually 'lora_A' or 'lora_B'), we try SAMM
81
- # For simplicity in this demo, we apply it to the computed Delta W or the raw weights
82
- # Here we apply the Universal Subspace Hypothesis:
83
- # "The mean is a good approximation only if we project out the noise orthogonal to the principal subspace."
84
-
85
- # Compute "Universal" basis from the average (center of the cluster)
86
- # Using the FARMS concept: the shared structure is in the dominant spectrum
87
- u_univ, v_univ = farms_spectral_analysis(avg_weight.cpu())
88
-
89
- if u_univ is not None:
90
- # Project all adapters into this subspace and re-construct
91
- # W_clean = U U^T W (Filtering out non-universal spectral noise)
92
-
93
- cleaned_tensors = []
94
- for w in layer_tensors:
95
- w = w.float().cpu()
96
- # Project onto Top-32 universal directions (Filtering)
97
- # W_proj = U @ (U.T @ W)
98
- w_proj = torch.mm(u_univ, torch.mm(u_univ.t(), w))
99
- cleaned_tensors.append(w_proj)
100
-
101
- # Average the "Cleaned" (Spectrally Aligned) weights
102
- merged_weight = torch.mean(torch.stack(cleaned_tensors), dim=0)
103
- else:
104
- # Fallback to simple average if SVD fails or vector is 1D
105
- merged_weight = avg_weight
106
-
107
- merged_state_dict[key] = merged_weight
 
 
 
108
 
109
  return merged_state_dict
110
 
111
  # --- GRADIO HANDLERS ---
112
 
113
  def run_samm_merge(base_model_id, lora_ids_text, hf_token):
114
- if not hf_token:
115
- return "Error: Please enter a Hugging Face Write Token."
116
 
117
- lora_ids = [x.strip() for x in lora_ids_text.split(",") if x.strip()]
118
 
119
- if len(lora_ids) < 2:
120
- return "Error: Please provide at least 2 LoRA adapters to merge."
121
 
122
- log = f"Loading {len(lora_ids)} adapters...\n"
123
- yield log
124
 
125
- try:
126
- # 1. Download/Load Adapters (Weights only to save RAM)
127
- adapters_weights = {}
128
-
129
- for lora_id in lora_ids:
130
- log += f"Fetching {lora_id}...\n"
131
- yield log
132
-
133
- # We use PEFT to download, but we manually load state_dict to avoid loading Base Model 10 times
134
- # Note: In a real large-scale deployment, we would stream this.
135
- # Here we assume LoRA weights are small enough to fit in RAM.
136
- try:
137
- # Hack: Use downloading logic from PEFT without loading base model
138
- from huggingface_hub import snapshot_download
139
- path = snapshot_download(repo_id=lora_id, token=hf_token)
140
-
141
- # Load safetensors or bin
142
- if os.path.exists(os.path.join(path, "adapter_model.safetensors")):
143
- from safetensors.torch import load_file
144
- state = load_file(os.path.join(path, "adapter_model.safetensors"))
145
- else:
146
- state = torch.load(os.path.join(path, "adapter_model.bin"), map_location="cpu")
147
 
148
- adapters_weights[lora_id] = state
149
- except Exception as e:
150
  log += f"Failed to load {lora_id}: {str(e)}\n"
151
  yield log
 
152
 
153
  # 2. Perform SAMM Merge
154
  log += "\nInitializing Spectral-Aware Model Merging (SAMM)...\n"
@@ -162,48 +167,47 @@ except Exception as e:
162
  os.makedirs(output_dir, exist_ok=True)
163
 
164
  # Save weights
165
- from safetensors.torch import save_file
166
  save_file(merged_weights, os.path.join(output_dir, "adapter_model.safetensors"))
167
 
168
  # Save config (Copy from first adapter)
169
- import json
170
  config_path = snapshot_download(repo_id=lora_ids[0], token=hf_token)
171
  with open(os.path.join(config_path, "adapter_config.json"), 'r') as f:
172
- config = json.load(f)
173
- with open(os.path.join(output_dir, "adapter_config.json"), 'w') as f:
174
- json.dump(config, f)
175
 
176
- log += f"\nSuccess! Merged LoRA saved locally to ./{output_dir}\n"
177
  log += "Ready for download or push to hub."
178
  yield log
179
 
180
- except Exception as e:
181
  yield f"Critical Error: {str(e)}"
182
 
183
  # --- UI SETUP ---
184
 
185
  with gr.Blocks(title="SAMM: Spectral-Aware Model Merging") as demo:
186
- gr.Markdown("""
187
- # 💡 SAMM: Spectral-Aware Model Merging
188
- Algorithm: Universal Weight Subspace via FARMS (Fixed-Aspect-Ratio Matrix Subsampling)
189
-
190
- This tool merges multiple LoRA adapters by identifying their shared spectral directions (the """Universal Subspace")
191
- and projecting weights into this noise-free manifold before averaging.
192
- """)
193
-
194
- with gr.Row():
195
- base_model_input = gr.Textbox(label="Base Model ID""", value="mistralai/Mistral-7B-v0.1")
196
- hf_token_input = gr.Textbox(label="HF Write Token", type="password")
197
-
198
- loras_input = gr.Textbox(label="LoRA Adapter IDs (comma separated)",
199
- placeholder="user/lora1, user/lora2, user/lora3...", lines=3)
200
-
201
- merge_btn = gr.Button("Perform Spectral Merge", variant="primary")
202
- output_log = gr.Textbox(label="Merge Logs", lines=10)
203
-
204
- merge_btn.click(fn=run_samm_merge,
205
- inputs=[base_model_input, loras_input, hf_token_input],
206
- outputs=output_log)
207
 
208
- if __name__ == "__main__":
209
- demo.queue().launch()"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import os
8
  import gc
9
  from scipy.sparse.linalg import svds
10
+ from huggingface_hub import snapshot_download
11
+ from safetensors.torch import load_file, save_file
12
+ import json
13
 
14
  # --- CORE SAMM ALGORITHM ---
15
 
16
def farms_spectral_analysis(tensor, num_patches=10, patch_size=64):
    """Estimate the dominant spectral subspace of a 2D weight tensor.

    Simplified FARMS pass (Fixed-Aspect-Ratio Matrix Subsampling): random
    square patches are sampled and decomposed to probe the spectral density
    without aspect-ratio bias, then a randomized low-rank SVD of the full
    tensor supplies the subspace actually consumed by the merge step.

    Args:
        tensor: 2D torch tensor of shape (rows, cols). Anything else is
            rejected with (None, None).
        num_patches: number of random square patches to sample.
        patch_size: side length of each sampled patch.

    Returns:
        (U, V): left/right singular-vector matrices of shapes (rows, q) and
        (cols, q), with q = min(32, rows, cols), or (None, None) when the
        input is not 2D, has an empty dimension, or the SVD fails.
    """
    # Only 2D weight matrices have a meaningful spectrum here.
    if len(tensor.shape) != 2:
        return None, None

    rows, cols = tensor.shape
    if rows == 0 or cols == 0:
        # Empty matrix: no spectrum to analyze.
        return None, None

    patch_directions = []

    # FARMS: randomly sample square patches to avoid aspect-ratio bias.
    # When the tensor is smaller than patch_size, max(1, ...) pins the
    # start at 0 and slicing just clips to the tensor's extent.
    for _ in range(num_patches):
        r_start = np.random.randint(0, max(1, rows - patch_size))
        c_start = np.random.randint(0, max(1, cols - patch_size))

        patch = tensor[r_start:r_start + patch_size, c_start:c_start + patch_size]

        try:
            # We only need the top component of each patch to probe the
            # "Universal Subspace".
            u, _, _ = np.linalg.svd(patch.float().numpy(), full_matrices=False)
            patch_directions.append(u[:, :1])  # keep top principal direction
        except np.linalg.LinAlgError:
            # A degenerate patch is not fatal; just use fewer samples.
            continue

    # In a full implementation the patch spectra above would be aggregated.
    # For this simplified Space we return a randomized low-rank SVD of the
    # whole tensor, positing that its top directions form the shared
    # ("universal") subspace.
    q = min(32, rows, cols)  # svd_lowrank requires q <= min(rows, cols)
    try:
        u, _, v = torch.svd_lowrank(tensor.float(), q=q)  # efficient randomized SVD
        return u, v  # left (U) and right (V) singular vectors
    except RuntimeError:
        return None, None
56
 
57
def spectral_aware_merge(adapters_dict, merge_ratio=0.5):
    """Merge adapter state dicts by aligning them in a shared spectral subspace.

    For every layer: the per-adapter weights are averaged to a cluster
    center, a "universal" left-singular basis U is estimated from that
    center via farms_spectral_analysis, each adapter weight is projected
    onto the basis (W_clean = U @ U.T @ W) to filter non-universal spectral
    noise, and the cleaned weights are averaged.

    Args:
        adapters_dict: mapping adapter-name -> state dict (layer name -> tensor).
        merge_ratio: reserved for weighted merging; currently unused.

    Returns:
        State dict of merged tensors. Layers whose shapes disagree across
        adapters are skipped; layers where the spectral analysis is
        unavailable (e.g. 1D tensors) fall back to the plain average.
    """
    if not adapters_dict:
        return {}

    merged_state_dict = {}

    # Union of layer names across ALL adapters (insertion order preserved),
    # so layers present in only some adapters are still merged instead of
    # being silently dropped.
    all_keys = list(dict.fromkeys(k for state in adapters_dict.values() for k in state))

    print(f"Starting SAMM merge on {len(all_keys)} layers...")

    for key in all_keys:
        # 1. Collect this layer's weight from every adapter that has it.
        layer_tensors = [state[key] for state in adapters_dict.values() if key in state]
        if not layer_tensors:
            continue

        # 2. Average as the cluster center. If shapes disagree across
        # adapters, stacking fails -- skip the layer rather than abort
        # the entire merge.
        try:
            stack = torch.stack(layer_tensors)
        except RuntimeError:
            continue
        avg_weight = torch.mean(stack, dim=0)

        # 3. Universal Subspace Hypothesis: "the mean is a good
        # approximation only if we project out the noise orthogonal to the
        # principal subspace." Estimate that subspace from the center.
        u_univ, _ = farms_spectral_analysis(avg_weight.cpu())

        if u_univ is not None:
            # Project each adapter into the universal subspace and
            # reconstruct: W_clean = U @ (U.T @ W).
            cleaned_tensors = [
                torch.mm(u_univ, torch.mm(u_univ.t(), w.float().cpu()))
                for w in layer_tensors
            ]
            # Average the "cleaned" (spectrally aligned) weights.
            merged_weight = torch.mean(torch.stack(cleaned_tensors), dim=0)
        else:
            # Fallback to the simple average if SVD fails or tensor is 1D.
            merged_weight = avg_weight

        merged_state_dict[key] = merged_weight

    return merged_state_dict
116
 
117
  # --- GRADIO HANDLERS ---
118
 
119
  def run_samm_merge(base_model_id, lora_ids_text, hf_token):
120
+ if not hf_token:
121
+ return "Error: Please enter a Hugging Face Write Token."
122
 
123
+ lora_ids = [x.strip() for x in lora_ids_text.split(",") if x.strip()]
124
 
125
+ if len(lora_ids) < 2:
126
+ return "Error: Please provide at least 2 LoRA adapters to merge."
127
 
128
+ log = f"Loading {len(lora_ids)} adapters...\n"
129
+ yield log
130
 
131
+ try:
132
+ # 1. Download/Load Adapters (Weights only to save RAM)
133
+ adapters_weights = {}
134
+
135
+ for lora_id in lora_ids:
136
+ log += f"Fetching {lora_id}...\n"
137
+ yield log
138
+
139
+ # We use PEFT to download, but we manually load state_dict to avoid loading Base Model 10 times
140
+ # Note: In a real large-scale deployment, we would stream this.
141
+ # Here we assume LoRA weights are small enough to fit in RAM.
142
+ try:
143
+ # Hack: Use downloading logic from PEFT without loading base model
144
+ path = snapshot_download(repo_id=lora_id, token=hf_token)
145
+
146
+ # Load safetensors or bin
147
+ if os.path.exists(os.path.join(path, "adapter_model.safetensors")):
148
+ state = load_file(os.path.join(path, "adapter_model.safetensors"))
149
+ else:
150
+ state = torch.load(os.path.join(path, "adapter_model.bin"), map_location="cpu")
 
 
151
 
152
+ adapters_weights[lora_id] = state
153
+ except Exception as e:
154
  log += f"Failed to load {lora_id}: {str(e)}\n"
155
  yield log
156
+ continue # Skip this adapter if it fails
157
 
158
  # 2. Perform SAMM Merge
159
  log += "\nInitializing Spectral-Aware Model Merging (SAMM)...\n"
 
167
  os.makedirs(output_dir, exist_ok=True)
168
 
169
  # Save weights
 
170
  save_file(merged_weights, os.path.join(output_dir, "adapter_model.safetensors"))
171
 
172
  # Save config (Copy from first adapter)
 
173
  config_path = snapshot_download(repo_id=lora_ids[0], token=hf_token)
174
  with open(os.path.join(config_path, "adapter_config.json"), 'r') as f:
175
+ config = json.load(f)
176
+ with open(os.path.join(output_dir, "adapter_config.json"), 'w') as f:
177
+ json.dump(config, f)
178
 
179
+ log += f"\nSuccess! Merged LoRA saved locally to ./{output_dir}\n"
180
  log += "Ready for download or push to hub."
181
  yield log
182
 
183
+ except Exception as e:
184
  yield f"Critical Error: {str(e)}"
185
 
186
  # --- UI SETUP ---
187
 
188
# Build the Gradio UI: base-model / token / adapter-list inputs, wired to
# run_samm_merge so its yielded log strings stream into the output textbox.
with gr.Blocks(title="SAMM: Spectral-Aware Model Merging") as demo:
    gr.Markdown("""
    # 💡 SAMM: Spectral-Aware Model Merging
    Algorithm: Universal Weight Subspace via FARMS (Fixed-Aspect-Ratio Matrix Subsampling)

    This tool merges multiple LoRA adapters by identifying their shared spectral directions (the "Universal Subspace")
    and projecting weights into this noise-free manifold before averaging.
    """)

    with gr.Row():
        # NOTE(review): base_model_input is collected and passed to the
        # handler, but the visible handler code does not appear to load the
        # base model -- confirm whether it is used downstream.
        base_model_input = gr.Textbox(label="Base Model ID", value="mistralai/Mistral-7B-v0.1")
        hf_token_input = gr.Textbox(label="HF Write Token", type="password")

    loras_input = gr.Textbox(label="LoRA Adapter IDs (comma separated)",
                             placeholder="user/lora1, user/lora2, user/lora3...", lines=3)

    merge_btn = gr.Button("Perform Spectral Merge", variant="primary")
    output_log = gr.Textbox(label="Merge Logs", lines=10)

    merge_btn.click(fn=run_samm_merge,
                    inputs=[base_model_input, loras_input, hf_token_input],
                    outputs=output_log)

if __name__ == "__main__":
    # queue() enables streaming of the generator handler's yielded logs.
    demo.queue().launch()
213
+