File size: 6,842 Bytes
12bbde9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import torch
import torch.nn as nn
import argparse
from safetensors.torch import load_file, save_file
from model import LocalSongModel
from pathlib import Path

class LoRALinear(nn.Module):
    """Wraps a frozen ``nn.Linear`` with a trainable low-rank (LoRA) residual.

    Output is ``original(x) + (x @ A @ B) * (alpha / rank)``. ``B`` starts at
    zero, so the wrapper is initially an exact no-op relative to the base layer.
    """

    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        in_features = original_linear.in_features
        out_features = original_linear.out_features
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))

        # A gets Kaiming init, B stays zero -> the delta weight is zero at start.
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)

        # Freeze the wrapped layer; only the LoRA factors remain trainable.
        for param in (self.original_linear.weight, self.original_linear.bias):
            if param is not None:
                param.requires_grad = False

    def forward(self, x):
        """Base projection plus the scaled low-rank correction."""
        delta = x @ self.lora_A @ self.lora_B
        return self.original_linear(x) + delta * self.scaling

def inject_lora(model, rank=8, alpha=16.0, target_modules=None, device=None):
    """Replace matching ``nn.Linear`` modules in ``model`` with LoRALinear wrappers.

    Args:
        model: the model to modify in place.
        rank: LoRA rank passed to each wrapper.
        alpha: LoRA alpha passed to each wrapper.
        target_modules: substrings of module names to wrap. ``None`` selects
            the default attention/MLP projection names. (Was a mutable list
            default — replaced with ``None`` sentinel to avoid the shared
            mutable-default pitfall.)
        device: device for the new LoRA parameters; defaults to the device of
            the model's first parameter.

    Returns:
        The same ``model`` object, mutated in place.
    """
    if target_modules is None:
        target_modules = ('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj')
    if device is None:
        device = next(model.parameters()).device

    # Collect matches first: replacing modules while iterating named_modules()
    # mutates the tree mid-traversal, which is fragile.
    matches = [
        name for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(t in name for t in target_modules)
    ]

    for name in matches:
        *parent_path, attr_name = name.split('.')
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)

        original = getattr(parent, attr_name)
        lora_layer = LoRALinear(original, rank=rank, alpha=alpha).to(device)
        setattr(parent, attr_name, lora_layer)

    return model

def load_lora_weights(model, lora_path, device):
    """Copy saved LoRA A/B factors from a safetensors file into the model's
    injected LoRALinear layers. Layers without matching keys are left as-is."""
    print(f"Loading LoRA from {lora_path}")
    state = load_file(lora_path, device=str(device))

    loaded_count = 0
    for name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        key_a = f"{name}.lora_A"
        key_b = f"{name}.lora_B"
        if key_a not in state or key_b not in state:
            continue
        module.lora_A.data = state[key_a].to(device)
        module.lora_B.data = state[key_b].to(device)
        loaded_count += 2

    print(f"Loaded {loaded_count} LoRA parameters")

def merge_lora_into_model(model):
    """Fold each LoRALinear's low-rank update into its frozen base weight.

    For every LoRALinear layer: W_merged = W_original + (lora_A @ lora_B).T * scaling
    (transposed because nn.Linear stores weight as (out_features, in_features)).
    """
    print("\nMerging LoRA weights into base model...")
    merged_count = 0

    with torch.no_grad():
        for module in model.modules():
            if not isinstance(module, LoRALinear):
                continue
            # (in, r) @ (r, out) -> (in, out); transpose to match weight layout.
            delta = module.lora_A @ module.lora_B
            module.original_linear.weight.data += module.scaling * delta.T
            merged_count += 1

    print(f"Merged {merged_count} LoRA layers into base weights")

def extract_base_weights(model):
    """Build a plain state dict from a LoRA-injected model.

    LoRALinear wrappers contribute their (merged) base weight/bias under the
    original parameter names; every non-LoRA parameter is copied through
    unchanged.
    NOTE(review): buffers (e.g. running stats) are not collected here —
    presumably the model has none; verify if that changes.
    """
    print("\nExtracting merged weights...")
    new_state_dict = {}

    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            base = module.original_linear
            new_state_dict[f"{name}.weight"] = base.weight.data
            if base.bias is not None:
                new_state_dict[f"{name}.bias"] = base.bias.data

    # Pass through every parameter that is not LoRA machinery.
    skip_markers = ('lora_A', 'lora_B', 'original_linear')
    for name, param in model.named_parameters():
        if not any(marker in name for marker in skip_markers):
            new_state_dict[name] = param.data

    print(f"Extracted {len(new_state_dict)} parameters")
    return new_state_dict

def _build_model(device):
    """Construct the LocalSongModel with the fixed architecture this script
    targets, moved to *device*. Shared by the merge and verification steps."""
    return LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    ).to(device)

def main():
    """Merge a LoRA checkpoint into a base checkpoint and save a standalone
    safetensors file, then reload it to verify it is self-contained."""
    parser = argparse.ArgumentParser(description="Merge LoRA weights into a base model checkpoint")
    parser.add_argument(
        "--base-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260.safetensors",
        help="Path to the base model checkpoint"
    )
    parser.add_argument(
        "--lora-checkpoint",
        type=str,
        default="lora.safetensors",
        help="Path to the LoRA checkpoint"
    )
    parser.add_argument(
        "--output-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260_merged_lora.safetensors",
        help="Path to save the merged checkpoint"
    )
    # Previously hard-coded; exposed as flags with the same defaults so
    # existing invocations behave identically.
    parser.add_argument(
        "--lora-rank",
        type=int,
        default=16,
        help="Rank used when the LoRA checkpoint was trained"
    )
    parser.add_argument(
        "--lora-alpha",
        type=float,
        default=16.0,
        help="Alpha used when the LoRA checkpoint was trained"
    )
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    base_checkpoint = args.base_checkpoint
    lora_checkpoint = args.lora_checkpoint
    output_checkpoint = args.output_checkpoint
    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha

    print(f"\nBase checkpoint: {base_checkpoint}")
    print(f"LoRA checkpoint: {lora_checkpoint}")
    print(f"Output checkpoint: {output_checkpoint}")
    print(f"LoRA rank: {lora_rank}, alpha: {lora_alpha}")

    # Load base model
    print("\nLoading base model...")
    model = _build_model(device)
    state_dict = load_file(base_checkpoint, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    print("Base model loaded")

    # Wrap target linears, load the trained LoRA factors, and fold them in.
    print("\nInjecting LoRA layers...")
    model = inject_lora(model, rank=lora_rank, alpha=lora_alpha, device=device)
    load_lora_weights(model, lora_checkpoint, device)
    merge_lora_into_model(model)
    merged_state_dict = extract_base_weights(model)

    print(f"\nSaving merged checkpoint to {output_checkpoint}...")
    save_file(merged_state_dict, output_checkpoint)
    print("✓ Merged checkpoint saved successfully!")

    # Round-trip the saved file into a fresh model to prove it stands alone.
    print("\nVerifying merged checkpoint...")
    test_model = _build_model(device)
    merged_loaded = load_file(output_checkpoint, device=str(device))
    test_model.load_state_dict(merged_loaded, strict=True)
    print("✓ Merged checkpoint verified successfully!")

    print(f"\nDone! You can now use '{output_checkpoint}' as a standalone checkpoint without needing LoRA.")

if __name__ == '__main__':
    main()