# LocalSong / merge_lora.py
# Origin: Hugging Face repository upload ("Upload 5 files", commit 12bbde9, verified).
import torch
import torch.nn as nn
import argparse
from safetensors.torch import load_file, save_file
from model import LocalSongModel
from pathlib import Path
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank (LoRA) residual.

    The frozen base layer computes its usual projection; the adapter adds
    ``(x @ A @ B) * (alpha / rank)`` on top. ``B`` is zero-initialized, so a
    freshly wrapped layer is numerically identical to the original.
    """

    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        # Standard LoRA scaling applied to the low-rank product.
        self.scaling = alpha / rank
        # Note the (in, rank) x (rank, out) layout: A @ B is the transpose of
        # the base layer's (out, in) weight matrix.
        self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, original_linear.out_features))
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B)
        # Freeze the wrapped layer; only the adapter matrices remain trainable.
        self.original_linear.weight.requires_grad = False
        bias = self.original_linear.bias
        if bias is not None:
            bias.requires_grad = False

    def forward(self, x):
        """Return the base projection plus the scaled low-rank correction."""
        base = self.original_linear(x)
        delta = (x @ self.lora_A @ self.lora_B) * self.scaling
        return base + delta
def inject_lora(model, rank=8, alpha=16.0, target_modules=None, device=None):
    """Replace matching ``nn.Linear`` sub-modules of *model* with ``LoRALinear`` wrappers.

    Args:
        model: The module tree to modify in place.
        rank: LoRA rank for the injected adapters.
        alpha: LoRA alpha; scaling is ``alpha / rank``.
        target_modules: Substrings matched against each module's dotted name;
            ``None`` selects the standard projection/MLP layer names.
        device: Device for the new adapter matrices; defaults to the device of
            the model's first parameter.

    Returns:
        The same *model* instance, with wrappers installed.
    """
    # Fix: the original signature used a mutable list as the default argument;
    # use a None sentinel with an immutable fallback instead.
    if target_modules is None:
        target_modules = ('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj')
    if device is None:
        device = next(model.parameters()).device
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if not any(target in name for target in target_modules):
            continue
        # Walk down to the direct parent so we can swap the attribute in place.
        *parent_path, attr_name = name.split('.')
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)
        lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
        # Adapter matrices are created on CPU; move them next to the base weights.
        lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
        lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
        setattr(parent, attr_name, lora_layer)
    return model
def load_lora_weights(model, lora_path, device):
    """Load saved ``lora_A``/``lora_B`` tensors into the injected ``LoRALinear`` layers.

    Args:
        model: A model that has already been through ``inject_lora``.
        lora_path: Path to a safetensors file keyed by ``<module>.lora_A`` / ``.lora_B``.
        device: Target device for the loaded tensors.
    """
    print(f"Loading LoRA from {lora_path}")
    lora_state_dict = load_file(lora_path, device=str(device))
    loaded_count = 0
    missing = []
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            lora_a_key = f"{name}.lora_A"
            lora_b_key = f"{name}.lora_B"
            if lora_a_key in lora_state_dict and lora_b_key in lora_state_dict:
                module.lora_A.data = lora_state_dict[lora_a_key].to(device)
                module.lora_B.data = lora_state_dict[lora_b_key].to(device)
                loaded_count += 2  # one A and one B matrix per layer
            else:
                missing.append(name)
    if missing:
        # Previously these layers were skipped silently, leaving their
        # randomly-initialized lora_A in place; surface that explicitly.
        print(f"WARNING: no LoRA weights found for {len(missing)} layer(s): {missing}")
    print(f"Loaded {loaded_count} LoRA parameters")
def merge_lora_into_model(model):
    """
    Merge LoRA weights into the base model weights.
    For each LoRALinear layer: W_merged = W_original + (lora_A @ lora_B) * scaling
    """
    print("\nMerging LoRA weights into base model...")
    merged_count = 0
    # Fix: compute the delta inside no_grad as well — lora_A/lora_B are
    # trainable Parameters, so the matmul otherwise builds a useless
    # autograd graph for every layer.
    with torch.no_grad():
        for name, module in model.named_modules():
            if isinstance(module, LoRALinear):
                # (in, rank) @ (rank, out) -> (in, out); transpose to match
                # nn.Linear's (out, in) weight layout.
                lora_delta = (module.lora_A @ module.lora_B) * module.scaling
                module.original_linear.weight.data += lora_delta.T
                merged_count += 1
    print(f"Merged {merged_count} LoRA layers into base weights")
def extract_base_weights(model):
    """
    Extract the merged weights from LoRALinear modules back into a regular state dict.

    Built on ``model.state_dict()`` so registered buffers are preserved too —
    the previous ``named_parameters()``-only walk would have dropped any
    buffers and broken the strict=True verification load. Adapter tensors
    (``lora_A``/``lora_B``) are omitted (they are already merged into the base
    weights) and the ``.original_linear.`` indirection is stripped so the keys
    match the plain, un-injected model.
    """
    print("\nExtracting merged weights...")
    new_state_dict = {}
    for key, tensor in model.state_dict().items():
        # Drop the adapter matrices entirely — their delta is already merged.
        if key.endswith('.lora_A') or key.endswith('.lora_B'):
            continue
        # Map "<layer>.original_linear.weight" back to "<layer>.weight".
        key = key.replace('.original_linear.', '.')
        if key.startswith('original_linear.'):  # root-level wrapper edge case
            key = key[len('original_linear.'):]
        new_state_dict[key] = tensor
    print(f"Extracted {len(new_state_dict)} parameters")
    return new_state_dict
def _build_model(device):
    """Construct a LocalSongModel with the fixed architecture this script targets."""
    return LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    ).to(device)


def main():
    """Merge a LoRA checkpoint into a base checkpoint and save a standalone result.

    Steps: parse CLI args, load the base model, inject LoRA layers, load the
    LoRA weights, fold them into the base weights, save, then round-trip the
    saved file through a fresh model to verify it loads with strict=True.
    """
    parser = argparse.ArgumentParser(description="Merge LoRA weights into a base model checkpoint")
    parser.add_argument(
        "--base-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260.safetensors",
        help="Path to the base model checkpoint"
    )
    parser.add_argument(
        "--lora-checkpoint",
        type=str,
        default="lora.safetensors",
        help="Path to the LoRA checkpoint"
    )
    parser.add_argument(
        "--output-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260_merged_lora.safetensors",
        help="Path to save the merged checkpoint"
    )
    # Previously hard-coded in the body; must match the values the LoRA
    # checkpoint was trained with, so expose them (same defaults as before).
    parser.add_argument(
        "--lora-rank",
        type=int,
        default=16,
        help="Rank the LoRA checkpoint was trained with"
    )
    parser.add_argument(
        "--lora-alpha",
        type=float,
        default=16.0,
        help="Alpha the LoRA checkpoint was trained with"
    )
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Configuration
    base_checkpoint = args.base_checkpoint
    lora_checkpoint = args.lora_checkpoint
    output_checkpoint = args.output_checkpoint
    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha
    print(f"\nBase checkpoint: {base_checkpoint}")
    print(f"LoRA checkpoint: {lora_checkpoint}")
    print(f"Output checkpoint: {output_checkpoint}")
    print(f"LoRA rank: {lora_rank}, alpha: {lora_alpha}")
    # Load base model
    print("\nLoading base model...")
    model = _build_model(device)
    state_dict = load_file(base_checkpoint, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    print("Base model loaded")
    print("\nInjecting LoRA layers...")
    model = inject_lora(model, rank=lora_rank, alpha=lora_alpha, device=device)
    load_lora_weights(model, lora_checkpoint, device)
    merge_lora_into_model(model)
    merged_state_dict = extract_base_weights(model)
    print(f"\nSaving merged checkpoint to {output_checkpoint}...")
    save_file(merged_state_dict, output_checkpoint)
    print("✓ Merged checkpoint saved successfully!")
    # Round-trip through a fresh model to prove the merged checkpoint is a
    # complete standalone state dict (strict=True catches stray/missing keys).
    print("\nVerifying merged checkpoint...")
    test_model = _build_model(device)
    merged_loaded = load_file(output_checkpoint, device=str(device))
    test_model.load_state_dict(merged_loaded, strict=True)
    print("✓ Merged checkpoint verified successfully!")
    print(f"\nDone! You can now use '{output_checkpoint}' as a standalone checkpoint without needing LoRA.")
# Run the merge pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()