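"""Extract a LoRA from the difference between two .safetensors checkpoints via truncated SVD.

Example invocation (script name, paths, rank, and alpha below are illustrative):

    python extract_lora_svd.py base.safetensors finetuned.safetensors lora_out.safetensors \
        --rank 32 --alpha 32 --precision fp16 --device cuda
"""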
import argparse
import sys

import torch
from safetensors.torch import save_file, safe_open
from tqdm import tqdm


def get_torch_dtype(dtype_str: str):
    """Converts a string to a torch.dtype object."""
    if dtype_str == "fp32":
        return torch.float32
    if dtype_str == "fp16":
        return torch.float16
    if dtype_str == "bf16":
        return torch.bfloat16
    raise ValueError(f"Unsupported dtype: {dtype_str}")


def extract_and_svd_lora(model_a_path: str, model_b_path: str, output_path: str, rank: int, device: str,
                         alpha: float, dtype: torch.dtype):
    """
    Extracts the difference between two models, applies SVD to reduce the rank,
    and saves the result as a LoRA file.
    """
    print(f"Loading base model A: {model_a_path}")
    print(f"Loading finetuned model B: {model_b_path}")

    lora_tensors = {}

    # safe_open reads tensors lazily, so each weight is pulled from disk only when requested.
    with safe_open(model_a_path, framework="pt", device="cpu") as f_a, \
            safe_open(model_b_path, framework="pt", device="cpu") as f_b:

        keys_a = set(f_a.keys())
        keys_b = set(f_b.keys())
        common_keys = keys_a.intersection(keys_b)

        # Only process plain weight tensors present in both checkpoints; skip anything
        # that is already a LoRA factor.
        weight_keys = {k for k in common_keys if k.endswith('.weight') and 'lora_' not in k}

        if not weight_keys:
            print("No common weight keys found between the two models. Exiting.")
            sys.exit(1)

        print(f"Found {len(weight_keys)} common weight keys to process.")

        for key in tqdm(sorted(weight_keys), desc="Processing Layers"):
            try:
                tensor_a = f_a.get_tensor(key).to(device=device, dtype=dtype)
                tensor_b = f_b.get_tensor(key).to(device=device, dtype=dtype)

                if tensor_a.shape != tensor_b.shape:
                    print(f"Skipping key {key} due to shape mismatch: A={tensor_a.shape}, B={tensor_b.shape}")
                    continue

                # The LoRA approximates the difference between the finetuned and base weights.
                delta_w = tensor_b - tensor_a

                # Flatten >2D tensors (e.g. conv kernels of shape (out, in, kh, kw)) to a 2D
                # matrix so SVD can be applied; the shape is restored further below.
                original_shape = delta_w.shape
                if delta_w.dim() > 2:
                    delta_w = delta_w.view(original_shape[0], -1)

                # Truncated SVD of the weight delta. torch.linalg.svd does not support half
                # precision, so the decomposition itself is done in float32.
                U, S, Vh = torch.linalg.svd(delta_w.float(), full_matrices=False)

                # Keep only the top singular components, capped at the number available.
                current_rank = min(rank, S.size(0))
                U = U[:, :current_rank]
                S = S[:current_rank]
                Vh = Vh[:current_rank, :]

                # delta_w ≈ lora_up @ lora_down, with the singular values folded into the "up" factor.
                lora_down = Vh
                lora_up = U @ torch.diag(S)
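                # Optional sanity check (illustrative, disabled by default): relative error of
                # the rank-r approximation against the full delta.
                # rel_err = torch.linalg.norm(lora_up @ lora_down - delta_w.float()) / torch.linalg.norm(delta_w.float())
                # tqdm.write(f"{key}: rank={current_rank}, relative error={rel_err:.4f}")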

                # Conv kernels were flattened before the SVD; reshape the factors back so the
                # saved tensors have the shapes LoRA loaders expect for conv layers
                # (kohya-style: lora_down (rank, in, kh, kw), lora_up (out, rank, 1, 1)).
                if len(original_shape) > 2:
                    lora_down = lora_down.reshape(current_rank, *original_shape[1:])
                    lora_up = lora_up.reshape(original_shape[0], current_rank, 1, 1)

                base_name = key.replace('.weight', '')
                lora_down_name = f"{base_name}.lora_down.weight"
                lora_up_name = f"{base_name}.lora_up.weight"
                alpha_name = f"{base_name}.alpha"

                lora_tensors[lora_down_name] = lora_down.contiguous().cpu().to(torch.float32)
                lora_tensors[lora_up_name] = lora_up.contiguous().cpu().to(torch.float32)
                lora_tensors[alpha_name] = torch.tensor(alpha, dtype=torch.float32)

            except Exception as e:
                print(f"Failed to process key {key}: {e}")

    if not lora_tensors:
        print("No tensors were processed. Output file will not be created.")
        return

    print(f"\nSaving {len(lora_tensors)} tensors to {output_path}...")
    save_file(lora_tensors, output_path)
    print("✅ Done!")
if __name__ == "__main__": |
|
|
parser = argparse.ArgumentParser(description="Extract and SVD a LoRA from two SafeTensors checkpoints.") |
|
|
|
|
|
parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint in .safetensors format.") |
|
|
parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint in .safetensors format.") |
|
|
parser.add_argument("output", type=str, help="Path to save the output LoRA file in .safetensors format.") |
|
|
|
|
|
parser.add_argument("--rank", type=int, required=True, help="The target rank for the SVD.") |
|
|
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], |
|
|
help="Device to use for computation ('cuda' or 'cpu').") |
|
|
parser.add_argument("--alpha", type=float, default=1.0, help="The alpha (scaling) factor for the LoRA.") |
|
|
parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], |
|
|
help="Precision to use for calculations.") |

    args = parser.parse_args()

    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        args.device = "cpu"

    dtype = get_torch_dtype(args.precision)

    extract_and_svd_lora(args.model_a, args.model_b, args.output, args.rank, args.device, args.alpha, dtype)