"""Extract a LoRA from the difference of two .safetensors checkpoints via truncated SVD."""
import torch
import argparse
from safetensors.torch import save_file, safe_open
from tqdm import tqdm
import sys
def get_torch_dtype(dtype_str: str):
    """Map a precision name ('fp32' | 'fp16' | 'bf16') to the matching torch.dtype.

    Raises:
        ValueError: if dtype_str is not one of the supported names.
    """
    lookup = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    dtype = lookup.get(dtype_str)
    if dtype is None:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    return dtype
def extract_and_svd_lora(model_a_path: str, model_b_path: str, output_path: str, rank: int, device: str, alpha: float,
                         dtype: torch.dtype):
    """
    Extract the weight difference between two checkpoints, reduce its rank with
    a truncated SVD, and save the result as a LoRA .safetensors file.

    Args:
        model_a_path: Path to the base model (A) .safetensors checkpoint.
        model_b_path: Path to the finetuned model (B) .safetensors checkpoint.
        output_path: Where to write the resulting LoRA .safetensors file.
        rank: Target rank for the truncated SVD (clamped per-layer to the
            maximum achievable rank of that layer's delta).
        device: Device used for the computation ("cuda" or "cpu").
        alpha: LoRA alpha (scaling) value stored alongside each layer pair.
        dtype: Working dtype used to load and diff the weights.
    """
    print(f"Loading base model A: {model_a_path}")
    print(f"Loading finetuned model B: {model_b_path}")
    lora_tensors = {}
    with safe_open(model_a_path, framework="pt", device="cpu") as f_a, \
            safe_open(model_b_path, framework="pt", device="cpu") as f_b:
        keys_a = set(f_a.keys())
        keys_b = set(f_b.keys())
        common_keys = keys_a.intersection(keys_b)
        # Filter for processable layers (typically linear and conv weights);
        # exclude biases, non-weight tensors, and pre-existing LoRA tensors.
        weight_keys = {k for k in common_keys if k.endswith('.weight') and 'lora_' not in k}
        if not weight_keys:
            print("No common weight keys found between the two models. Exiting.")
            sys.exit(1)
        print(f"Found {len(weight_keys)} common weight keys to process.")
        # Main processing loop with progress bar
        for key in tqdm(sorted(weight_keys), desc="Processing Layers"):
            try:
                # Load tensors and move to the selected device and dtype
                tensor_a = f_a.get_tensor(key).to(device=device, dtype=dtype)
                tensor_b = f_b.get_tensor(key).to(device=device, dtype=dtype)
                if tensor_a.shape != tensor_b.shape:
                    print(f"Skipping key {key} due to shape mismatch: A={tensor_a.shape}, B={tensor_b.shape}")
                    continue
                # The delta weight the LoRA pair should approximate.
                delta_w = tensor_b - tensor_a
                # SVD works on 2D matrices: flatten conv / ND tensors to
                # (out_channels, everything_else).
                original_shape = delta_w.shape
                if delta_w.dim() > 2:
                    delta_w = delta_w.view(original_shape[0], -1)
                # --- Core SVD logic: ΔW ≈ U @ diag(S) @ Vh ---
                # BUGFIX: torch.linalg.svd has no fp16/bf16 kernel on CPU (and
                # half support on CUDA is spotty), so with --precision fp16 every
                # layer used to fail inside the except below and the script
                # produced an empty result. Always decompose in fp32.
                U, S, Vh = torch.linalg.svd(delta_w.to(torch.float32), full_matrices=False)
                # Truncate to the requested rank (never above the achievable rank).
                current_rank = min(rank, S.size(0))
                U = U[:, :current_rank]
                S = S[:current_rank]
                Vh = Vh[:current_rank, :]
                # Decompose into LoRA matrices:
                #   lora_down (A) = Vh, lora_up (B) = U @ diag(S).
                # Folding S into lora_up preserves the delta's magnitude.
                lora_down = Vh
                # Broadcast multiply is equivalent to U @ diag(S) without
                # materializing an (r, r) diagonal matrix.
                lora_up = U * S.unsqueeze(0)
                # NOTE(review): for conv layers both factors are kept in 2D
                # matrix form; some LoRA consumers expect 4D conv factors —
                # confirm against the intended loader.
                base_name = key.replace('.weight', '')
                # Store tensors on CPU in fp32 for saving.
                lora_tensors[f"{base_name}.lora_down.weight"] = lora_down.contiguous().cpu().to(torch.float32)
                lora_tensors[f"{base_name}.lora_up.weight"] = lora_up.contiguous().cpu().to(torch.float32)
                lora_tensors[f"{base_name}.alpha"] = torch.tensor(alpha).to(torch.float32)
            except Exception as e:
                # Best-effort: report the failure and keep processing the rest.
                print(f"Failed to process key {key}: {e}")
    # Save the final LoRA file
    if not lora_tensors:
        print("No tensors were processed. Output file will not be created.")
        return
    print(f"\nSaving {len(lora_tensors)} tensors to {output_path}...")
    save_file(lora_tensors, output_path)
    print("✅ Done!")
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Extract and SVD a LoRA from two SafeTensors checkpoints.")
    cli.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint in .safetensors format.")
    cli.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint in .safetensors format.")
    cli.add_argument("output", type=str, help="Path to save the output LoRA file in .safetensors format.")
    cli.add_argument("--rank", type=int, required=True, help="The target rank for the SVD.")
    cli.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
                     help="Device to use for computation ('cuda' or 'cpu').")
    cli.add_argument("--alpha", type=float, default=1.0, help="The alpha (scaling) factor for the LoRA.")
    cli.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
                     help="Precision to use for calculations.")
    args = cli.parse_args()

    # Fall back gracefully when CUDA was requested but is not present.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        device = "cpu"

    extract_and_svd_lora(args.model_a, args.model_b, args.output, args.rank, device, args.alpha,
                         get_torch_dtype(args.precision))