Stack-2-9-finetuned / scripts /training /merge_lora_adapters.py
walidsobhie-code
chore: Rename MCP server to Stack2.9
c7f1596
#!/usr/bin/env python3
"""
Merge Multiple LoRA Adapters

Combines multiple LoRA adapters using weighted averaging based on success rates.
The merged adapter can be used to combine patterns learned by different users
or from different sources.

Usage:
    python merge_lora_adapters.py \
        --adapters adapter1.safetensors adapter2.safetensors \
        --weights 0.6 0.4 \
        --output merged.safetensors

    # Or with success rates (auto-computes weights proportional to success)
    python merge_lora_adapters.py \
        --adapters adapter1.safetensors adapter2.safetensors \
        --success-rates 0.85 0.65 \
        --output merged.safetensors
"""
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Optional
# Try to import required libraries
try:
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
HAS_LIBS = True
except ImportError:
HAS_LIBS = False
def load_adapter(path: str) -> dict:
    """Read a LoRA adapter from a safetensors file.

    Args:
        path: Filesystem location of the adapter file.

    Returns:
        The result of ``load_file(path)``.

    Raises:
        FileNotFoundError: If nothing exists at ``path``.
    """
    # Only hand the path to safetensors once we know something is there,
    # so the caller gets a clear "not found" message instead of a parser error.
    if os.path.exists(path):
        return load_file(path)
    raise FileNotFoundError(f"Adapter not found: {path}")
def compute_weights_from_success_rates(success_rates: list[float]) -> list[float]:
    """Compute normalized weights proportional to success rates.

    Args:
        success_rates: One non-negative success rate per adapter.

    Returns:
        A list of the same length whose entries sum to 1 (each rate divided
        by the total). If every rate is 0 the weights are equal. An empty
        input yields an empty list.

    Raises:
        ValueError: If any success rate is negative.
    """
    if not success_rates:
        # Nothing to weight; avoids dividing by len([]) below.
        return []

    # Negative rates could cancel to a zero total (silently producing equal
    # weights) or yield negative weights, so reject them up front.
    if any(rate < 0 for rate in success_rates):
        raise ValueError(f"Success rates must be non-negative, got: {success_rates}")

    total = sum(success_rates)
    if total == 0:
        # Equal weights if all success rates are 0
        return [1.0 / len(success_rates)] * len(success_rates)
    return [rate / total for rate in success_rates]
def merge_adapters_weighted(
    adapters: list[dict],
    weights: list[float],
    output_path: str
) -> dict:
    """
    Merge multiple LoRA adapters using weighted averaging and save the result.

    Algorithm: merged[key] = Σ(adapter_i[key] * w_i) / Σ(w_i), where the sum
    runs over the adapters that contain ``key`` and ``w_i`` are the provided
    weights (normalized to sum to 1 first).

    Args:
        adapters: One mapping of tensor name -> tensor per adapter.
        weights: One weight per adapter; they need not sum to 1.
        output_path: Destination passed to ``save_file``.

    Returns:
        The merged mapping of tensor name -> tensor, in sorted key order.

    Raises:
        ValueError: If the number of adapters and weights differ, no adapter
            is given, the weights sum to zero, or adapters disagree on the
            shape of a shared tensor.
    """
    if len(adapters) != len(weights):
        raise ValueError("Number of adapters must match number of weights")
    if not adapters:
        raise ValueError("At least one adapter is required")

    # Normalize weights
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")
    normalized_weights = [w / total_weight for w in weights]
    print(f"Merging {len(adapters)} adapters with weights: {normalized_weights}")

    # Start from the first adapter's keys and widen to the union whenever a
    # later adapter disagrees.
    all_keys = set(adapters[0].keys())
    for i, adapter in enumerate(adapters[1:], 1):
        adapter_keys = set(adapter.keys())
        if adapter_keys != all_keys:
            print(f"Warning: Adapter {i} has different keys. Taking union.", file=sys.stderr)
            all_keys |= adapter_keys

    # Merge each tensor. Sorting makes the order of the returned dict
    # independent of set iteration order.
    merged = {}
    for key in sorted(all_keys):
        # (tensor, weight) for every adapter that actually provides this key.
        present = [
            (adapter[key], weight)
            for adapter, weight in zip(adapters, normalized_weights)
            if key in adapter
        ]

        # Differing shapes would either fail deep inside the tensor library
        # or, worse, broadcast silently into a meaningless result.
        shapes = {tuple(tensor.shape) for tensor, _ in present}
        if len(shapes) > 1:
            raise ValueError(
                f"Shape mismatch for '{key}' across adapters: {sorted(shapes)}"
            )

        # Re-normalize over the adapters that have this key.
        total_valid = sum(weight for _, weight in present)
        if total_valid == 0:
            print(
                f"Warning: Skipping '{key}': adapters providing it have zero total weight.",
                file=sys.stderr,
            )
            continue

        # Weighted average
        merged[key] = sum(tensor * (weight / total_valid) for tensor, weight in present)

    # Save merged adapter
    save_file(merged, output_path)
    print(f"Merged adapter saved to: {output_path}")
    return merged
def compute_adapter_stats(adapter: dict) -> dict:
    """Summarize an adapter's tensors for debugging output.

    Args:
        adapter: Mapping of tensor name -> tensor. Each tensor must expose
            ``numel()``, ``dtype`` and ``shape``.

    Returns:
        A dict with ``num_tensors`` (entry count), ``total_params`` (sum of
        ``numel()``), ``dtype_counts`` (``str(dtype)`` -> occurrences) and
        ``shape_counts`` (``str(tuple(shape))`` -> occurrences).
    """
    param_total = 0
    by_dtype = {}
    by_shape = {}

    # Only the tensors matter here; their names are not part of the summary.
    for tensor in adapter.values():
        param_total += tensor.numel()

        dtype_label = str(tensor.dtype)
        by_dtype[dtype_label] = by_dtype.get(dtype_label, 0) + 1

        # Stringified so the result stays JSON-friendly.
        shape_label = str(tuple(tensor.shape))
        by_shape[shape_label] = by_shape.get(shape_label, 0) + 1

    return {
        "num_tensors": len(adapter),
        "total_params": param_total,
        "dtype_counts": by_dtype,
        "shape_counts": by_shape,
    }
def main():
    """Command-line entry point for merging LoRA adapters.

    Parses ``--adapters``, ``--weights`` / ``--success-rates``, ``--output``
    and ``--stats``, loads every adapter, merges them with
    ``merge_adapters_weighted`` and writes ``<output>.meta.json`` recording
    the input paths and the weights passed to the merge.

    Missing libraries, inconsistent arguments, and ``OSError`` / ``ValueError``
    raised while loading, merging or writing metadata are reported as an
    ``Error: ...`` line on stderr followed by exit status 1.
    """
    parser = argparse.ArgumentParser(
        description="Merge multiple LoRA adapters using weighted averaging"
    )
    parser.add_argument(
        "--adapters",
        type=str,
        nargs="+",
        required=True,
        help="Paths to LoRA adapter safetensors files"
    )
    parser.add_argument(
        "--weights",
        type=float,
        nargs="+",
        default=None,
        help="Manual weights for each adapter (must sum to 1 or will be normalized)"
    )
    parser.add_argument(
        "--success-rates",
        type=float,
        nargs="+",
        default=None,
        help="Success rates for each adapter (weights computed proportionally)"
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output path for merged adapter"
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Print adapter statistics"
    )
    args = parser.parse_args()

    # Checked after parsing so that --help still works without torch/safetensors.
    if not HAS_LIBS:
        print("Error: Required libraries not found.", file=sys.stderr)
        print("Install with: pip install torch safetensors", file=sys.stderr)
        sys.exit(1)

    # Validate inputs
    if args.weights and args.success_rates:
        print("Error: Cannot specify both --weights and --success-rates", file=sys.stderr)
        sys.exit(1)

    if args.weights:
        if len(args.adapters) != len(args.weights):
            print("Error: Number of --adapters must match number of --weights", file=sys.stderr)
            sys.exit(1)
        weights = args.weights
    elif args.success_rates:
        if len(args.adapters) != len(args.success_rates):
            print("Error: Number of --adapters must match number of --success-rates", file=sys.stderr)
            sys.exit(1)
        # A negative rate would produce negative or misleading weights.
        if any(rate < 0 for rate in args.success_rates):
            print("Error: --success-rates must be non-negative", file=sys.stderr)
            sys.exit(1)
        weights = compute_weights_from_success_rates(args.success_rates)
        print(f"Computed weights from success rates: {weights}")
    else:
        # Equal weights
        weights = [1.0 / len(args.adapters)] * len(args.adapters)

    # Load adapters and merge. Expected failures (missing file, bad weights)
    # are reported in the same style as the argument errors above instead of
    # surfacing as a traceback.
    print(f"Loading {len(args.adapters)} adapters...")
    adapters = []
    try:
        for index, path in enumerate(args.adapters, start=1):
            print(f" Loading {index}: {path}")
            adapter = load_adapter(path)
            adapters.append(adapter)
            if args.stats:
                stats = compute_adapter_stats(adapter)
                print(f" Stats: {stats['num_tensors']} tensors, {stats['total_params']:,} params")

        # Merge
        merge_adapters_weighted(adapters, weights, args.output)
    except (OSError, ValueError) as exc:
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)

    # Print merge info
    print("\nMerge complete!")
    print(f" Output: {args.output}")
    print(f" Adapters merged: {len(args.adapters)}")

    # Save merge metadata. "weights" holds the values handed to
    # merge_adapters_weighted, which normalizes them internally.
    metadata_path = args.output + ".meta.json"
    metadata = {
        "adapters": args.adapters,
        "weights": weights,
        "num_adapters": len(args.adapters)
    }
    try:
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
    except OSError as exc:
        print(f"Error: Could not write metadata: {exc}", file=sys.stderr)
        sys.exit(1)
    print(f" Metadata: {metadata_path}")


if __name__ == "__main__":
    main()