import torch
import torch.nn as nn
import argparse
from safetensors.torch import load_file, save_file
from model import LocalSongModel
from pathlib import Path
class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear with a trainable low-rank (LoRA) residual.

    forward(x) = original(x) + (x @ A @ B) * (alpha / rank), where A has
    shape (in_features, rank) and B has shape (rank, out_features). B is
    initialized to zero, so the wrapper is initially an exact no-op
    relative to the wrapped layer.
    """
    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        in_f = original_linear.in_features
        out_f = original_linear.out_features
        self.lora_A = nn.Parameter(torch.zeros(in_f, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_f))
        # Standard LoRA init: Kaiming-uniform on A, zeros on B, so the
        # low-rank delta starts at exactly zero.
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)
        # Freeze the wrapped layer; only the LoRA factors should train.
        original_linear.weight.requires_grad = False
        if original_linear.bias is not None:
            original_linear.bias.requires_grad = False
    def forward(self, x):
        """Base projection plus the scaled low-rank correction."""
        base = self.original_linear(x)
        delta = (x @ self.lora_A) @ self.lora_B
        return base + delta * self.scaling
def inject_lora(model, rank=8, alpha=16.0,
                target_modules=('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'),
                device=None):
    """Replace matching nn.Linear layers of `model` with LoRALinear wrappers.

    Args:
        model: module to modify in place.
        rank: LoRA rank forwarded to LoRALinear.
        alpha: LoRA alpha forwarded to LoRALinear.
        target_modules: name substrings; a Linear is wrapped when any of
            them occurs in its dotted module name. (A tuple, not a list —
            mutable default arguments are a Python footgun.)
        device: device for the new LoRA factors; defaults to the device of
            the model's first parameter.

    Returns:
        The same model object, modified in place.
    """
    if device is None:
        device = next(model.parameters()).device
    # Snapshot the matches first: replacing modules while the
    # named_modules() generator is still live mutates the module tree
    # mid-traversal, which is fragile.
    matches = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and any(target in name for target in target_modules)
    ]
    for name, module in matches:
        *parent_path, attr_name = name.split('.')
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)
        lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
        lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
        lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
        setattr(parent, attr_name, lora_layer)
    return model
def load_lora_weights(model, lora_path, device):
    """Load lora_A / lora_B tensors from a safetensors file into the
    model's LoRALinear modules, matching entries by dotted module name.

    Modules whose keys are absent from the file are silently left at
    their initialized values.
    """
    print(f"Loading LoRA from {lora_path}")
    lora_state_dict = load_file(lora_path, device=str(device))
    loaded_count = 0
    for name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        key_a = f"{name}.lora_A"
        key_b = f"{name}.lora_B"
        if key_a in lora_state_dict and key_b in lora_state_dict:
            module.lora_A.data = lora_state_dict[key_a].to(device)
            module.lora_B.data = lora_state_dict[key_b].to(device)
            loaded_count += 2
    print(f"Loaded {loaded_count} LoRA parameters")
def merge_lora_into_model(model):
    """
    Merge LoRA weights into the base model weights.
    For each LoRALinear layer: W_merged = W_original + (lora_A @ lora_B) * scaling
    The delta is transposed before the add because nn.Linear stores its
    weight as (out_features, in_features) while lora_A @ lora_B yields
    (in_features, out_features).
    """
    print("\nMerging LoRA weights into base model...")
    merged_count = 0
    # Run the whole merge under no_grad: the original computed the delta
    # outside the no_grad block, needlessly building an autograd graph.
    with torch.no_grad():
        for name, module in model.named_modules():
            if isinstance(module, LoRALinear):
                lora_delta = (module.lora_A @ module.lora_B) * module.scaling
                module.original_linear.weight.data += lora_delta.T
                merged_count += 1
    print(f"Merged {merged_count} LoRA layers into base weights")
def extract_base_weights(model):
    """
    Extract the merged weights from LoRALinear modules back into a regular
    state dict, keyed under the pre-injection names (i.e. without the
    'original_linear' segment) so a plain model can load the result.
    """
    print("\nExtracting merged weights...")
    new_state_dict = {}
    # Merged Linear weights/biases from each LoRA wrapper.
    for name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        new_state_dict[f"{name}.weight"] = module.original_linear.weight.data
        bias = module.original_linear.bias
        if bias is not None:
            new_state_dict[f"{name}.bias"] = bias.data
    # Everything LoRA never touched passes through unchanged; skip the
    # adapter factors and the nested original_linear duplicates.
    excluded = ('lora_A', 'lora_B', 'original_linear')
    for name, param in model.named_parameters():
        if not any(token in name for token in excluded):
            new_state_dict[name] = param.data
    print(f"Extracted {len(new_state_dict)} parameters")
    return new_state_dict
def _build_model(device):
    """Construct a LocalSongModel with the fixed architecture these
    checkpoints were trained with, moved to `device`."""
    return LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    ).to(device)


def main():
    """Merge a LoRA checkpoint into a base checkpoint, save the result as a
    standalone safetensors file, then reload it to verify integrity."""
    parser = argparse.ArgumentParser(description="Merge LoRA weights into a base model checkpoint")
    parser.add_argument(
        "--base-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260.safetensors",
        help="Path to the base model checkpoint"
    )
    parser.add_argument(
        "--lora-checkpoint",
        type=str,
        default="lora.safetensors",
        help="Path to the LoRA checkpoint"
    )
    parser.add_argument(
        "--output-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260_merged_lora.safetensors",
        help="Path to save the merged checkpoint"
    )
    # Rank/alpha were hard-coded; expose them with the same defaults so
    # LoRAs trained with other hyper-parameters can be merged too. They
    # must match the training values or the merged delta is mis-scaled.
    parser.add_argument(
        "--lora-rank",
        type=int,
        default=16,
        help="LoRA rank used when the adapter was trained"
    )
    parser.add_argument(
        "--lora-alpha",
        type=float,
        default=16.0,
        help="LoRA alpha used when the adapter was trained"
    )
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Configuration
    base_checkpoint = args.base_checkpoint
    lora_checkpoint = args.lora_checkpoint
    output_checkpoint = args.output_checkpoint
    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha
    print(f"\nBase checkpoint: {base_checkpoint}")
    print(f"LoRA checkpoint: {lora_checkpoint}")
    print(f"Output checkpoint: {output_checkpoint}")
    print(f"LoRA rank: {lora_rank}, alpha: {lora_alpha}")
    # Load base model
    print("\nLoading base model...")
    model = _build_model(device)
    state_dict = load_file(base_checkpoint, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    print("Base model loaded")
    print("\nInjecting LoRA layers...")
    model = inject_lora(model, rank=lora_rank, alpha=lora_alpha, device=device)
    load_lora_weights(model, lora_checkpoint, device)
    merge_lora_into_model(model)
    merged_state_dict = extract_base_weights(model)
    print(f"\nSaving merged checkpoint to {output_checkpoint}...")
    save_file(merged_state_dict, output_checkpoint)
    print("✓ Merged checkpoint saved successfully!")
    # Round-trip the saved file through a fresh model with strict=True to
    # prove the extracted state dict is complete and correctly keyed.
    print("\nVerifying merged checkpoint...")
    test_model = _build_model(device)
    merged_loaded = load_file(output_checkpoint, device=str(device))
    test_model.load_state_dict(merged_loaded, strict=True)
    print("✓ Merged checkpoint verified successfully!")
    print(f"\nDone! You can now use '{output_checkpoint}' as a standalone checkpoint without needing LoRA.")
# Script entry point: run the merge when executed directly.
if __name__ == '__main__':
    main()
|