|
|
import torch |
|
|
import torch.nn as nn |
|
|
import argparse |
|
|
from safetensors.torch import load_file, save_file |
|
|
from model import LocalSongModel |
|
|
from pathlib import Path |
|
|
|
|
|
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` and add a trainable low-rank update.

    The effective transform is ``original(x) + (x @ A @ B) * (alpha / rank)``
    where ``A`` (in_features x rank) and ``B`` (rank x out_features) are the
    only trainable parameters. ``B`` is initialized to zero, so the wrapper
    is an exact no-op until training moves it.
    """

    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        # Standard LoRA scaling applied to the low-rank branch.
        self.scaling = alpha / rank

        # Low-rank factors: A maps input -> rank, B maps rank -> output.
        self.lora_A = nn.Parameter(torch.zeros(original_linear.in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, original_linear.out_features))

        # A gets Kaiming init, B stays zero => the initial delta is exactly zero.
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)

        # Freeze the wrapped layer; only the LoRA factors should train.
        self.original_linear.weight.requires_grad = False
        bias = self.original_linear.bias
        if bias is not None:
            bias.requires_grad = False

    def forward(self, x):
        """Return the frozen linear output plus the scaled low-rank update."""
        base = self.original_linear(x)
        delta = (x @ self.lora_A @ self.lora_B) * self.scaling
        return base + delta
|
|
|
|
|
def inject_lora(
    model,
    rank=8,
    alpha=16.0,
    target_modules=('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj'),
    device=None,
):
    """Replace matching ``nn.Linear`` modules in-place with ``LoRALinear`` wrappers.

    Args:
        model: Module to modify (mutated in place and also returned).
        rank: LoRA rank for every injected layer.
        alpha: LoRA alpha; the branch scaling is ``alpha / rank``.
        target_modules: Substrings matched against each module's dotted name;
            a linear layer is wrapped when any substring occurs in its name.
            (Immutable tuple default — a list default would be shared mutable
            state across calls.)
        device: Device for the new LoRA parameters; defaults to the device of
            the model's first parameter.

    Returns:
        The same ``model`` instance, with LoRA layers injected.
    """
    if device is None:
        device = next(model.parameters()).device

    # Collect targets before replacing anything: mutating the module tree
    # while iterating named_modules() can skip or revisit entries.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(t in name for t in target_modules)
    ]

    for name, module in targets:
        # Walk the dotted path to the parent module that owns this linear.
        *parent_path, attr_name = name.split('.')
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)

        lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
        lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
        lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
        setattr(parent, attr_name, lora_layer)

    return model
|
|
|
|
|
def load_lora_weights(model, lora_path, device):
    """Load LoRA A/B tensors from a safetensors file into injected layers.

    Keys are expected as ``<module_name>.lora_A`` / ``<module_name>.lora_B``
    for every ``LoRALinear`` module in ``model``. Layers whose keys are
    absent are left at their initial (no-op) values and reported, so a
    rank or naming mismatch does not pass silently.

    Args:
        model: Model that already had :func:`inject_lora` applied.
        lora_path: Path to the LoRA safetensors checkpoint.
        device: Device the loaded tensors are moved to.
    """
    print(f"Loading LoRA from {lora_path}")
    lora_state_dict = load_file(lora_path, device=str(device))

    loaded_count = 0
    skipped = []
    for name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        lora_a_key = f"{name}.lora_A"
        lora_b_key = f"{name}.lora_B"
        if lora_a_key in lora_state_dict and lora_b_key in lora_state_dict:
            module.lora_A.data = lora_state_dict[lora_a_key].to(device)
            module.lora_B.data = lora_state_dict[lora_b_key].to(device)
            loaded_count += 2  # one A and one B tensor per layer
        else:
            skipped.append(name)

    print(f"Loaded {loaded_count} LoRA parameters")
    if skipped:
        # A non-empty skip list usually means the checkpoint was trained with
        # different target modules or a different module naming scheme.
        print(f"WARNING: no LoRA weights found for {len(skipped)} layers: {skipped}")
|
|
|
|
|
def merge_lora_into_model(model):
    """
    Merge LoRA weights into the base model weights, in place.

    For each LoRALinear layer: W_merged = W_original + (lora_A @ lora_B).T * scaling
    (the transpose maps the (in, out) delta onto nn.Linear's (out, in) weight
    layout). After merging, the wrapped linears produce the adapted output
    even if the LoRA branches were stripped away.
    """
    print("\nMerging LoRA weights into base model...")
    merged_count = 0

    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            # Compute the delta inside no_grad: lora_A/lora_B require grad,
            # and building an autograd graph here is pure waste.
            with torch.no_grad():
                lora_delta = (module.lora_A @ module.lora_B) * module.scaling
                module.original_linear.weight.data += lora_delta.T
            merged_count += 1

    print(f"Merged {merged_count} LoRA layers into base weights")
|
|
|
|
|
def extract_base_weights(model):
    """
    Collect the merged weights from LoRALinear modules into a plain state dict.

    Each LoRALinear wrapper contributes its inner (now merged) linear's weight
    and bias under the wrapper's own dotted name, so the resulting dict keys
    match the original LoRA-free architecture. Every other parameter is copied
    through unchanged, while the lora_A/lora_B factors and the duplicated
    ``original_linear.*`` entries are dropped.
    """
    print("\nExtracting merged weights...")
    new_state_dict = {}

    # LoRA wrappers: re-home the inner linear's tensors under the wrapper name.
    for name, module in model.named_modules():
        if not isinstance(module, LoRALinear):
            continue
        inner = module.original_linear
        new_state_dict[f"{name}.weight"] = inner.weight.data
        if inner.bias is not None:
            new_state_dict[f"{name}.bias"] = inner.bias.data

    # Everything else: copy through, skipping LoRA internals (original_linear.*
    # entries are already present above under the wrapper's own name).
    excluded = ('lora_A', 'lora_B', 'original_linear')
    for name, param in model.named_parameters():
        if not any(token in name for token in excluded):
            new_state_dict[name] = param.data

    print(f"Extracted {len(new_state_dict)} parameters")
    return new_state_dict
|
|
|
|
|
def _build_model(device):
    """Construct a LocalSongModel with the fixed architecture this checkpoint
    family was trained with, and move it to ``device``."""
    return LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    ).to(device)


def main():
    """CLI entry point: load base checkpoint, inject + load LoRA, merge the
    low-rank deltas into the base weights, save, and round-trip verify."""
    parser = argparse.ArgumentParser(description="Merge LoRA weights into a base model checkpoint")
    parser.add_argument(
        "--base-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260.safetensors",
        help="Path to the base model checkpoint"
    )
    parser.add_argument(
        "--lora-checkpoint",
        type=str,
        default="lora.safetensors",
        help="Path to the LoRA checkpoint"
    )
    parser.add_argument(
        "--output-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260_merged_lora.safetensors",
        help="Path to save the merged checkpoint"
    )
    # Rank/alpha were previously hard-coded; exposing them (with the same
    # defaults) supports LoRA checkpoints trained with other settings.
    parser.add_argument(
        "--lora-rank",
        type=int,
        default=16,
        help="Rank the LoRA checkpoint was trained with"
    )
    parser.add_argument(
        "--lora-alpha",
        type=float,
        default=16.0,
        help="Alpha the LoRA checkpoint was trained with"
    )
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    base_checkpoint = args.base_checkpoint
    lora_checkpoint = args.lora_checkpoint
    output_checkpoint = args.output_checkpoint
    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha

    print(f"\nBase checkpoint: {base_checkpoint}")
    print(f"LoRA checkpoint: {lora_checkpoint}")
    print(f"Output checkpoint: {output_checkpoint}")
    print(f"LoRA rank: {lora_rank}, alpha: {lora_alpha}")

    print("\nLoading base model...")
    model = _build_model(device)

    state_dict = load_file(base_checkpoint, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    print("Base model loaded")

    print("\nInjecting LoRA layers...")
    model = inject_lora(model, rank=lora_rank, alpha=lora_alpha, device=device)

    load_lora_weights(model, lora_checkpoint, device)

    merge_lora_into_model(model)

    merged_state_dict = extract_base_weights(model)

    print(f"\nSaving merged checkpoint to {output_checkpoint}...")
    save_file(merged_state_dict, output_checkpoint)
    print("✓ Merged checkpoint saved successfully!")

    # Round-trip check: a fresh LoRA-free model must load the merged file with
    # strict=True, proving the state dict matches the base architecture.
    print("\nVerifying merged checkpoint...")
    test_model = _build_model(device)
    merged_loaded = load_file(output_checkpoint, device=str(device))
    test_model.load_state_dict(merged_loaded, strict=True)
    print("✓ Merged checkpoint verified successfully!")

    print(f"\nDone! You can now use '{output_checkpoint}' as a standalone checkpoint without needing LoRA.")


if __name__ == '__main__':
    main()
|
|
|