|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from peft import PeftModel, PeftConfig |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import numpy as np |
|
|
import os |
|
|
import gc |
|
|
from scipy.sparse.linalg import svds |
|
|
from huggingface_hub import snapshot_download |
|
|
from safetensors.torch import load_file, save_file |
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
def farms_spectral_analysis(tensor, num_patches=10, patch_size=64):
    """
    Estimate dominant spectral directions of a 2-D weight matrix.

    Implements the FARMS idea (Fixed-Aspect-Ratio Matrix Subsampling):
    random square sub-matrices are sampled and decomposed to probe the
    spectral density, then a low-rank SVD of the full matrix supplies the
    returned left/right singular bases.

    Parameters
    ----------
    tensor : torch.Tensor
        CPU tensor to analyze. Must be 2-D; anything else yields (None, None).
    num_patches : int
        Number of random square patches sampled by the FARMS probe.
    patch_size : int
        Side length of each sampled patch.

    Returns
    -------
    (u, v) : tuple[torch.Tensor, torch.Tensor] or (None, None)
        Left/right factors from ``torch.svd_lowrank`` (shapes
        ``(rows, q)`` and ``(cols, q)`` with ``q = min(32, rows, cols)``),
        or (None, None) when the input is not 2-D or the SVD fails.
    """
    if tensor.dim() != 2:
        return None, None

    rows, cols = tensor.shape
    u_list = []

    # FARMS probe: sample square patches and keep each patch's leading left
    # singular vector. NOTE(review): u_list is currently not consumed by the
    # return path below — it only exercises the sampling; confirm whether the
    # per-patch statistics were meant to feed the final subspace estimate.
    for _ in range(num_patches):
        r_start = np.random.randint(0, max(1, rows - patch_size))
        c_start = np.random.randint(0, max(1, cols - patch_size))
        patch = tensor[r_start:r_start + patch_size, c_start:c_start + patch_size]
        try:
            u, s, vh = np.linalg.svd(patch.float().numpy(), full_matrices=False)
            u_list.append(u[:, :1])
        except np.linalg.LinAlgError:
            # A single degenerate patch should not abort the whole analysis.
            continue

    # Global low-rank decomposition of the full matrix. Clamp the requested
    # rank: asking svd_lowrank for more components than min(rows, cols)
    # raises for small matrices (e.g. rank-8/16 LoRA factors), which would
    # silently disable the spectral projection downstream.
    q = min(32, rows, cols)
    try:
        u, s, v = torch.svd_lowrank(tensor.float(), q=q)
        return u, v
    except (RuntimeError, ValueError):
        return None, None
|
|
|
|
|
def spectral_aware_merge(adapters_dict, merge_ratio=0.5):
    """
    Merge adapter state dicts by aligning them in a shared spectral subspace.

    For every layer key: average the per-adapter weights, estimate the
    dominant left singular basis of the average with
    ``farms_spectral_analysis``, project each adapter weight onto that basis,
    and average the projections to form the merged weight. When the spectral
    analysis is unavailable (non-2-D weight or failed SVD), the plain
    average is used instead.

    Parameters
    ----------
    adapters_dict : dict[str, dict[str, torch.Tensor]]
        Mapping from adapter name to its state dict.
    merge_ratio : float
        Kept for interface compatibility; currently unused by the merge.

    Returns
    -------
    dict[str, torch.Tensor]
        Merged state dict (float32, contiguous, CPU); empty when
        ``adapters_dict`` is empty.
    """
    merged_state_dict = {}

    if not adapters_dict:
        return {}

    # Layer keys are taken from the first adapter; keys missing from other
    # adapters are simply merged over fewer tensors.
    all_keys = list(next(iter(adapters_dict.values())).keys())
    print(f"Starting SAMM merge on {len(all_keys)} layers...")

    for key in all_keys:
        layer_tensors = [state[key] for state in adapters_dict.values() if key in state]
        if not layer_tensors:
            continue

        # Guard: torch.stack requires identical shapes. Adapters trained with
        # different ranks would otherwise crash the whole merge, so keep only
        # tensors matching the first adapter's shape for this key. Normalize
        # to float32 on CPU so averaging is done in one consistent dtype.
        ref_shape = layer_tensors[0].shape
        layer_tensors = [t.float().cpu() for t in layer_tensors if t.shape == ref_shape]

        stack = torch.stack(layer_tensors)
        avg_weight = torch.mean(stack, dim=0)

        # Estimate the shared ("universal") left singular basis of the layer.
        u_univ, v_univ = farms_spectral_analysis(avg_weight.cpu())

        if u_univ is not None:
            # Project each adapter weight onto span(u_univ) — i.e.
            # U @ (U^T @ W) — to strip directions outside the shared
            # subspace, then average the projections.
            cleaned_tensors = []
            for w in layer_tensors:
                w_proj = torch.mm(u_univ, torch.mm(u_univ.t(), w))
                cleaned_tensors.append(w_proj)
            merged_weight = torch.mean(torch.stack(cleaned_tensors), dim=0)
        else:
            # Spectral analysis unavailable: fall back to the plain average.
            merged_weight = avg_weight

        # Contiguous layout is required downstream by safetensors' save_file.
        merged_state_dict[key] = merged_weight.contiguous()

    return merged_state_dict
|
|
|
|
|
|
|
|
|
|
|
def run_samm_merge(base_model_id, lora_ids_text, hf_token):
    """
    Gradio generator: download LoRA adapters, merge them with SAMM, and save
    the result locally, streaming cumulative progress text to the UI.

    Parameters
    ----------
    base_model_id : str
        Base model repo id. Currently informational only — the merge
        operates on the adapter weights alone.
    lora_ids_text : str
        Comma-separated Hugging Face repo ids of the LoRA adapters.
    hf_token : str
        Hugging Face token used for the downloads.

    Yields
    ------
    str
        Cumulative log text for the output textbox.
    """
    # BUG FIX: this function is a generator (it contains `yield`), so a bare
    # `return "..."` is swallowed by StopIteration and the user would never
    # see the error message in the UI. Yield the message, then return.
    if not hf_token:
        yield "Error: Please enter a Hugging Face Write Token."
        return

    lora_ids = [x.strip() for x in lora_ids_text.split(",") if x.strip()]

    if len(lora_ids) < 2:
        yield "Error: Please provide at least 2 LoRA adapters to merge."
        return

    log = f"Loading {len(lora_ids)} adapters...\n"
    yield log

    try:
        adapters_weights = {}

        for lora_id in lora_ids:
            log += f"Fetching {lora_id}...\n"
            yield log

            try:
                path = snapshot_download(repo_id=lora_id, token=hf_token)

                # Prefer the safetensors file; fall back to the legacy
                # pickle checkpoint when it is absent.
                st_path = os.path.join(path, "adapter_model.safetensors")
                if os.path.exists(st_path):
                    state = load_file(st_path)
                else:
                    # SECURITY NOTE: torch.load unpickles arbitrary code from
                    # the downloaded repo — only merge trusted adapter ids.
                    state = torch.load(os.path.join(path, "adapter_model.bin"),
                                       map_location="cpu")

                adapters_weights[lora_id] = state
            except Exception as e:
                # Best-effort: a single failed download must not abort the
                # whole run; report it and move on.
                log += f"Failed to load {lora_id}: {str(e)}\n"
                yield log
                continue

        # Guard: failed downloads may leave fewer than two adapters, in
        # which case there is nothing meaningful to merge.
        if len(adapters_weights) < 2:
            log += "\nError: fewer than 2 adapters loaded successfully; aborting merge.\n"
            yield log
            return

        log += "\nInitializing Spectral-Aware Model Merging (SAMM)...\n"
        log += "Applying FARMS (Fixed-Aspect-Ratio Matrix Subsampling) to identify Universal Subspace...\n"
        yield log

        merged_weights = spectral_aware_merge(adapters_weights)

        output_dir = "merged_samm_lora"
        os.makedirs(output_dir, exist_ok=True)

        save_file(merged_weights, os.path.join(output_dir, "adapter_model.safetensors"))

        # Reuse the first adapter's config so the merged LoRA stays loadable
        # by PEFT. NOTE(review): assumes all adapters share compatible
        # configs (rank, target modules) — not validated here.
        config_path = snapshot_download(repo_id=lora_ids[0], token=hf_token)
        with open(os.path.join(config_path, "adapter_config.json"), 'r') as f:
            config = json.load(f)
        with open(os.path.join(output_dir, "adapter_config.json"), 'w') as f:
            json.dump(config, f)

        log += f"\nSuccess! Merged LoRA saved locally to ./{output_dir}\n"
        log += "Ready for download or push to hub."
        yield log

    except Exception as e:
        yield f"Critical Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI wiring (runs at import time) --------------------------------
# Declarative layout; `demo` is launched only under the __main__ guard below.
with gr.Blocks(title="SAMM: Spectral-Aware Model Merging") as demo:
    gr.Markdown("""
    # 💡 SAMM: Spectral-Aware Model Merging
    Algorithm: Universal Weight Subspace via FARMS (Fixed-Aspect-Ratio Matrix Subsampling)

    This tool merges multiple LoRA adapters by identifying their shared spectral directions (the "Universal Subspace")
    and projecting weights into this noise-free manifold before averaging.
    """)

    with gr.Row():
        # NOTE(review): base_model_input is passed to run_samm_merge but the
        # merge operates on adapter weights only — confirm it is intentional.
        base_model_input = gr.Textbox(label="Base Model ID", value="mistralai/Mistral-7B-v0.1")
        hf_token_input = gr.Textbox(label="HF Write Token", type="password")

    loras_input = gr.Textbox(label="LoRA Adapter IDs (comma separated)",
                             placeholder="user/lora1, user/lora2, user/lora3...", lines=3)

    merge_btn = gr.Button("Perform Spectral Merge", variant="primary")
    output_log = gr.Textbox(label="Merge Logs", lines=10)

    # run_samm_merge is a generator, so the log textbox streams updates.
    merge_btn.click(fn=run_samm_merge,
                    inputs=[base_model_input, loras_input, hf_token_input],
                    outputs=output_log)

if __name__ == "__main__":
    # queue() enables streaming of the generator's incremental yields.
    demo.queue().launch()
|
|
|