File size: 7,430 Bytes
d083607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
#!/usr/bin/env python3
"""
Merge Multiple LoRA Adapters

Combines multiple LoRA adapters using weighted averaging based on success rates.
The merged adapter can be used to combine patterns learned by different users
or from different sources.

Usage:
    python merge_lora_adapters.py \
        --adapters adapter1.safetensors adapter2.safetensors \
        --weights 0.6 0.4 \
        --output merged.safetensors

    # Or with success rates (auto-computes weights proportional to success)
    python merge_lora_adapters.py \
        --adapters adapter1.safetensors adapter2.safetensors \
        --success-rates 0.85 0.65 \
        --output merged.safetensors
"""

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Optional

# Try to import required libraries
try:
    import torch
    import torch.nn as nn
    from safetensors.torch import load_file, save_file
    HAS_LIBS = True
except ImportError:
    HAS_LIBS = False


def load_adapter(path: str) -> dict:
    """Read a LoRA adapter's tensors from a safetensors file.

    Args:
        path: Filesystem path of the adapter file.

    Returns:
        The mapping produced by ``safetensors.torch.load_file`` for ``path``.

    Raises:
        FileNotFoundError: If nothing exists at ``path``.
    """
    adapter_exists = os.path.exists(path)
    if adapter_exists:
        return load_file(path)

    raise FileNotFoundError(f"Adapter not found: {path}")


def compute_weights_from_success_rates(success_rates: list[float]) -> list[float]:
    """Compute normalized weights proportional to success rates.

    Args:
        success_rates: One non-negative success rate per adapter.

    Returns:
        A list of weights in the same order as ``success_rates``. The weights
        sum to 1 and are proportional to the rates. If every rate is 0 the
        weights are uniform. An empty input yields an empty list.

    Raises:
        ValueError: If any success rate is negative.
    """
    # Without this guard an empty list would reach ``1.0 / len(...)`` below
    # and fail with ZeroDivisionError.
    if not success_rates:
        return []

    # Negative rates would yield negative weights, or cancel to a zero total
    # and be silently treated as "all rates are 0".
    if any(rate < 0 for rate in success_rates):
        raise ValueError("Success rates must be non-negative")

    total = sum(success_rates)
    if total == 0:
        # Equal weights if all success rates are 0
        return [1.0 / len(success_rates)] * len(success_rates)
    return [rate / total for rate in success_rates]


def merge_adapters_weighted(
    adapters: list[dict],
    weights: list[float],
    output_path: str
) -> dict:
    """
    Merge multiple LoRA adapters by weighted-averaging their tensors per key.

    The given weights are first normalized to sum to 1. For each tensor key:

        merged[key] = Σ(w_i * adapter_i[key]) / Σ(w_i)

    where both sums run only over the adapters that contain ``key``. A key
    missing from some adapters is therefore averaged over the remaining ones.

    Args:
        adapters: Adapter state dicts mapping tensor names to tensors.
        weights: One weight per adapter, in the same order as ``adapters``.
            They do not need to sum to 1.
        output_path: Destination path for the merged safetensors file.

    Returns:
        The merged mapping of tensor name to tensor that was written to
        ``output_path``.

    Raises:
        ValueError: If the number of adapters and weights differ, if no
            adapters are given, if the weights sum to zero, or if tensors
            stored under the same key have different shapes.
    """
    if len(adapters) != len(weights):
        raise ValueError("Number of adapters must match number of weights")
    if not adapters:
        raise ValueError("At least one adapter is required")

    # Normalize weights
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")
    normalized_weights = [w / total_weight for w in weights]

    print(f"Merging {len(adapters)} adapters with weights: {normalized_weights}")

    # Start from the first adapter's keys and widen to the union if others differ
    all_keys = set(adapters[0].keys())
    for i, adapter in enumerate(adapters[1:], 1):
        adapter_keys = set(adapter.keys())
        if adapter_keys != all_keys:
            print(f"Warning: Adapter {i} has different keys. Taking union.", file=sys.stderr)
            all_keys = all_keys.union(adapter_keys)

    # Merge each tensor. Keys are sorted so the returned dict has a
    # reproducible order (set iteration order is not stable between runs).
    merged = {}
    for key in sorted(all_keys):
        # Every key comes from at least one adapter, so this is never empty.
        contributions = [
            (adapter[key], weight)
            for adapter, weight in zip(adapters, normalized_weights)
            if key in adapter
        ]

        # Reject differing shapes explicitly: broadcast-compatible shapes
        # (e.g. (1, d) vs (8, d) from adapters of different rank) would
        # otherwise be combined silently into a meaningless tensor.
        expected_shape = tuple(contributions[0][0].shape)
        for tensor, _ in contributions[1:]:
            if tuple(tensor.shape) != expected_shape:
                raise ValueError(
                    f"Shape mismatch for '{key}': "
                    f"{tuple(tensor.shape)} vs {expected_shape}"
                )

        # Renormalize over the adapters that actually provide this key
        total_valid = sum(weight for _, weight in contributions)
        if total_valid == 0:
            print(
                f"Warning: Skipping '{key}': weights of the adapters providing it sum to zero.",
                file=sys.stderr,
            )
            continue

        # Weighted average
        merged[key] = sum(
            tensor * (weight / total_valid) for tensor, weight in contributions
        )

    # Save merged adapter
    save_file(merged, output_path)
    print(f"Merged adapter saved to: {output_path}")

    return merged


def compute_adapter_stats(adapter: dict) -> dict:
    """Summarize an adapter's tensors for debugging output.

    Args:
        adapter: Mapping of tensor name to a tensor-like object exposing
            ``numel()``, ``dtype`` and ``shape``.

    Returns:
        A dict with these entries:
            "num_tensors": number of entries in ``adapter``.
            "total_params": sum of ``numel()`` over all tensors.
            "dtype_counts": ``str(dtype)`` -> number of tensors with it.
            "shape_counts": ``str(tuple(shape))`` -> number of tensors with it.
    """
    param_total = 0
    per_dtype = {}
    per_shape = {}

    # Only the tensors matter here; the key names are not part of the stats.
    for entry in adapter.values():
        param_total += entry.numel()

        dtype_label = str(entry.dtype)
        per_dtype.setdefault(dtype_label, 0)
        per_dtype[dtype_label] += 1

        shape_label = str(tuple(entry.shape))
        per_shape.setdefault(shape_label, 0)
        per_shape[shape_label] += 1

    return {
        "num_tensors": len(adapter),
        "total_params": param_total,
        "dtype_counts": per_dtype,
        "shape_counts": per_shape,
    }


def main():
    """Command-line entry point: parse arguments, merge adapters, write metadata.

    Weights come from ``--weights``, or are derived from ``--success-rates``,
    or default to equal weights when neither option is given. After merging,
    a JSON file named ``<output>.meta.json`` is written that records the
    adapter paths, the weights and the adapter count.

    Exits with status 1 (after printing a message to stderr) when the required
    libraries are missing, when the arguments are inconsistent, or when
    loading or merging fails with FileNotFoundError or ValueError.
    """
    parser = argparse.ArgumentParser(
        description="Merge multiple LoRA adapters using weighted averaging"
    )
    parser.add_argument(
        "--adapters",
        type=str,
        nargs="+",
        required=True,
        help="Paths to LoRA adapter safetensors files"
    )
    parser.add_argument(
        "--weights",
        type=float,
        nargs="+",
        default=None,
        help="Manual weights for each adapter (must sum to 1 or will be normalized)"
    )
    parser.add_argument(
        "--success-rates",
        type=float,
        nargs="+",
        default=None,
        help="Success rates for each adapter (weights computed proportionally)"
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output path for merged adapter"
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Print adapter statistics"
    )

    args = parser.parse_args()

    if not HAS_LIBS:
        print("Error: Required libraries not found.", file=sys.stderr)
        print("Install with: pip install torch safetensors", file=sys.stderr)
        sys.exit(1)

    # Validate inputs
    if args.weights and args.success_rates:
        print("Error: Cannot specify both --weights and --success-rates", file=sys.stderr)
        sys.exit(1)

    if args.weights and len(args.adapters) != len(args.weights):
        print("Error: Number of --adapters must match number of --weights", file=sys.stderr)
        sys.exit(1)

    if args.success_rates and len(args.adapters) != len(args.success_rates):
        print("Error: Number of --adapters must match number of --success-rates", file=sys.stderr)
        sys.exit(1)

    # Report expected failures (missing file, invalid weights) the same way
    # as the validation errors above instead of letting a traceback escape.
    try:
        if args.weights:
            weights = args.weights
        elif args.success_rates:
            weights = compute_weights_from_success_rates(args.success_rates)
            print(f"Computed weights from success rates: {weights}")
        else:
            # Equal weights
            weights = [1.0 / len(args.adapters)] * len(args.adapters)

        # Load adapters
        print(f"Loading {len(args.adapters)} adapters...")
        adapters = []
        for i, path in enumerate(args.adapters):
            print(f"  Loading {i+1}: {path}")
            adapter = load_adapter(path)
            adapters.append(adapter)

            if args.stats:
                stats = compute_adapter_stats(adapter)
                print(f"    Stats: {stats['num_tensors']} tensors, {stats['total_params']:,} params")

        # Merge
        merge_adapters_weighted(adapters, weights, args.output)
    except (FileNotFoundError, ValueError) as exc:
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)

    # Print merge info
    print("\nMerge complete!")
    print(f"  Output: {args.output}")
    print(f"  Adapters merged: {len(args.adapters)}")

    # Save merge metadata next to the merged adapter
    metadata_path = args.output + ".meta.json"
    metadata = {
        "adapters": args.adapters,
        "weights": weights,
        "num_adapters": len(args.adapters)
    }
    # Explicit encoding so the file does not depend on the platform default.
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    print(f"  Metadata: {metadata_path}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()