import argparse

import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file

from model import LocalSongModel


class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear with a trainable low-rank residual path."""

    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        # Low-rank factors: x @ lora_A @ lora_B has shape (..., out_features).
        self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, original_linear.out_features))
        # Standard LoRA init: A gets Kaiming-uniform, B stays zero, so the
        # wrapped layer initially reproduces the base layer's output exactly.
        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B)
        # Freeze the base layer; only the LoRA factors train.
        self.original_linear.weight.requires_grad = False
        if self.original_linear.bias is not None:
            self.original_linear.bias.requires_grad = False

    def forward(self, x):
        result = self.original_linear(x)
        lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
        return result + lora_out


def inject_lora(model, rank=8, alpha=16.0,
                target_modules=('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'),
                device=None):
    """Replace every targeted nn.Linear in the model with a LoRALinear wrapper."""
    if device is None:
        device = next(model.parameters()).device
    # Materialize the module list first, since we mutate the tree while iterating.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and any(target in name for target in target_modules):
            *parent_path, attr_name = name.split('.')
            parent = model
            for p in parent_path:
                parent = getattr(parent, p)
            lora_layer = LoRALinear(module, rank=rank, alpha=alpha).to(device)
            setattr(parent, attr_name, lora_layer)
    return model


def load_lora_weights(model, lora_path, device):
    print(f"Loading LoRA from {lora_path}")
    lora_state_dict = load_file(lora_path, device=str(device))
    loaded_count = 0
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            lora_a_key = f"{name}.lora_A"
            lora_b_key = f"{name}.lora_B"
            if lora_a_key in lora_state_dict and lora_b_key in lora_state_dict:
                module.lora_A.data = lora_state_dict[lora_a_key].to(device)
                module.lora_B.data = lora_state_dict[lora_b_key].to(device)
                loaded_count += 2
            else:
                print(f"Warning: no LoRA weights found for {name}")
    print(f"Loaded {loaded_count} LoRA parameters")


def merge_lora_into_model(model):
    """
    Merge LoRA weights into the base model weights.

    For each LoRALinear layer (nn.Linear stores weight as (out, in)):
        W_merged = W_original + ((lora_A @ lora_B) * scaling).T

    A numerical sanity check of this identity is sketched below.
    """
    print("\nMerging LoRA weights into base model...")
    merged_count = 0
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            with torch.no_grad():
                # lora_A @ lora_B is (in_features, out_features); transpose to
                # match nn.Linear's (out_features, in_features) weight layout.
                lora_delta = (module.lora_A @ module.lora_B) * module.scaling
                module.original_linear.weight.data += lora_delta.T
            merged_count += 1
    print(f"Merged {merged_count} LoRA layers into base weights")
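
# Optional sanity check (a minimal sketch, not part of the merge pipeline): on an
# *unmerged* LoRALinear, the forward pass x @ W.T + (x @ A @ B) * s should equal
# a plain linear pass through the explicitly merged weight W + s * (A @ B).T.
# The helper name `check_merge_equivalence` and the tolerance are illustrative.
def check_merge_equivalence(module: LoRALinear, x: torch.Tensor, atol: float = 1e-5) -> bool:
    """Numerically verify the merge formula on one unmerged LoRALinear."""
    with torch.no_grad():
        pre = module(x)  # frozen base output + explicit low-rank path
        w_merged = module.original_linear.weight + \
            ((module.lora_A @ module.lora_B) * module.scaling).T
        post = nn.functional.linear(x, w_merged, module.original_linear.bias)
    return torch.allclose(pre, post, atol=atol)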
""" print("\nExtracting merged weights...") new_state_dict = {} for name, module in model.named_modules(): if isinstance(module, LoRALinear): original_name_weight = f"{name}.weight" original_name_bias = f"{name}.bias" new_state_dict[original_name_weight] = module.original_linear.weight.data if module.original_linear.bias is not None: new_state_dict[original_name_bias] = module.original_linear.bias.data # Copy over all non-LoRA parameters for name, param in model.named_parameters(): if 'lora_A' not in name and 'lora_B' not in name and 'original_linear' not in name: new_state_dict[name] = param.data print(f"Extracted {len(new_state_dict)} parameters") return new_state_dict def main(): parser = argparse.ArgumentParser(description="Merge LoRA weights into a base model checkpoint") parser.add_argument( "--base-checkpoint", type=str, default="checkpoints/checkpoint_461260.safetensors", help="Path to the base model checkpoint" ) parser.add_argument( "--lora-checkpoint", type=str, default="lora.safetensors", help="Path to the LoRA checkpoint" ) parser.add_argument( "--output-checkpoint", type=str, default="checkpoints/checkpoint_461260_merged_lora.safetensors", help="Path to save the merged checkpoint" ) args = parser.parse_args() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") # Configuration base_checkpoint = args.base_checkpoint lora_checkpoint = args.lora_checkpoint output_checkpoint = args.output_checkpoint lora_rank = 16 lora_alpha = 16.0 print(f"\nBase checkpoint: {base_checkpoint}") print(f"LoRA checkpoint: {lora_checkpoint}") print(f"Output checkpoint: {output_checkpoint}") print(f"LoRA rank: {lora_rank}, alpha: {lora_alpha}") # Load base model print("\nLoading base model...") model = LocalSongModel( in_channels=8, num_groups=16, hidden_size=1024, decoder_hidden_size=2048, num_blocks=36, patch_size=(16, 1), num_classes=2304, max_tags=8, ).to(device) state_dict = load_file(base_checkpoint, device=str(device)) model.load_state_dict(state_dict, strict=True) print("Base model loaded") print("\nInjecting LoRA layers...") model = inject_lora(model, rank=lora_rank, alpha=lora_alpha, device=device) load_lora_weights(model, lora_checkpoint, device) merge_lora_into_model(model) merged_state_dict = extract_base_weights(model) print(f"\nSaving merged checkpoint to {output_checkpoint}...") save_file(merged_state_dict, output_checkpoint) print("✓ Merged checkpoint saved successfully!") print("\nVerifying merged checkpoint...") test_model = LocalSongModel( in_channels=8, num_groups=16, hidden_size=1024, decoder_hidden_size=2048, num_blocks=36, patch_size=(16, 1), num_classes=2304, max_tags=8, ).to(device) merged_loaded = load_file(output_checkpoint, device=str(device)) test_model.load_state_dict(merged_loaded, strict=True) print("✓ Merged checkpoint verified successfully!") print(f"\nDone! You can now use '{output_checkpoint}' as a standalone checkpoint without needing LoRA.") if __name__ == '__main__': main()