import argparse
import sys

import torch
from safetensors.torch import save_file, safe_open
from tqdm import tqdm


def get_torch_dtype(dtype_str: str) -> torch.dtype:
    """Convert a precision name to a ``torch.dtype``.

    Args:
        dtype_str: One of ``"fp32"``, ``"fp16"`` or ``"bf16"``.

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        ValueError: If ``dtype_str`` is not a supported name.
    """
    dtypes = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if dtype_str not in dtypes:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    return dtypes[dtype_str]


def _svd_to_lora(delta_w: torch.Tensor, rank: int, alpha: float):
    """Factor a weight delta into ``(lora_down, lora_up)`` with a truncated SVD.

    The factors are scaled so that a loader applying the conventional LoRA
    scaling ``alpha / r`` (with ``r = lora_down.shape[0]``) reconstructs the
    rank-truncated delta: ``(alpha / r) * (up @ down) ≈ delta_w``.

    Tensors with more than two dims (e.g. conv kernels) are flattened to
    ``(out, -1)`` for the SVD. The factors are then reshaped to conv layout:
    ``lora_down`` -> ``(r, *delta_w.shape[1:])`` and
    ``lora_up`` -> ``(out, r, 1, ..., 1)``.

    Args:
        delta_w: Weight difference with at least 2 dims.
        rank: Maximum rank to keep; clamped to the largest possible rank.
        alpha: LoRA alpha that will be stored next to the factors (> 0).

    Returns:
        ``(lora_down, lora_up)`` as contiguous tensors on ``delta_w``'s device.

    Raises:
        ValueError: If ``delta_w`` has fewer than 2 dims, ``rank < 1`` or
            ``alpha <= 0``.
    """
    if delta_w.dim() < 2:
        raise ValueError(
            f"SVD needs a tensor with at least 2 dims, got shape {tuple(delta_w.shape)}"
        )
    if rank < 1:
        raise ValueError(f"rank must be >= 1, got {rank}")
    if alpha <= 0:
        raise ValueError(f"alpha must be > 0, got {alpha}")

    original_shape = delta_w.shape
    matrix = delta_w.reshape(original_shape[0], -1)
    # torch.linalg.svd does not support fp16/bf16 inputs, so factor in fp32.
    if matrix.dtype in (torch.float16, torch.bfloat16):
        matrix = matrix.float()

    # ΔW ≈ U @ diag(S) @ Vh
    U, S, Vh = torch.linalg.svd(matrix, full_matrices=False)

    # Never ask for more components than the matrix actually has.
    current_rank = min(rank, S.size(0))
    U = U[:, :current_rank]
    S = S[:current_rank]
    Vh = Vh[:current_rank, :]

    # lora_down = Vh; lora_up = U * S (column-wise). Also fold in
    # current_rank / alpha so that the loader's alpha / r scaling cancels out
    # and the applied delta keeps the magnitude of the original difference.
    lora_down = Vh
    lora_up = U * (S * (current_rank / alpha))

    if len(original_shape) > 2:
        # Conv-style LoRA: down carries the kernel, up is a 1x1(...) projection.
        lora_down = lora_down.reshape(current_rank, *original_shape[1:])
        lora_up = lora_up.reshape(
            original_shape[0], current_rank, *([1] * (len(original_shape) - 2))
        )

    return lora_down.contiguous(), lora_up.contiguous()


def extract_and_svd_lora(model_a_path: str, model_b_path: str, output_path: str, rank: int,
                         device: str, alpha: float, dtype: torch.dtype):
    """Extract ``B - A`` per weight, reduce it with SVD and save it as a LoRA file.

    For every ``*.weight`` key present in both checkpoints (excluding keys that
    already contain ``lora_``), the delta is factored by ``_svd_to_lora``. The
    result is stored under ``<base>.lora_down.weight``,
    ``<base>.lora_up.weight`` and ``<base>.alpha`` as float32 tensors.

    The following weights are skipped:
    - weights with a shape mismatch;
    - weights with fewer than 2 dims;
    - weights whose delta is exactly zero.

    A per-layer failure is reported and the layer is skipped; processing
    continues with the remaining layers.

    Args:
        model_a_path: Base model ``.safetensors`` path.
        model_b_path: Finetuned model ``.safetensors`` path.
        output_path: Destination ``.safetensors`` path.
        rank: Target LoRA rank (>= 1).
        device: Torch device used for the computation.
        alpha: LoRA alpha written to the file (> 0). The factors are scaled so
            that ``alpha / rank`` scaling reproduces the delta.
        dtype: Dtype used when loading and subtracting the weights.

    Raises:
        ValueError: If ``rank < 1`` or ``alpha <= 0``.
        SystemExit: If the checkpoints share no processable weight keys.
    """
    if rank < 1:
        raise ValueError(f"rank must be >= 1, got {rank}")
    if alpha <= 0:
        raise ValueError(f"alpha must be > 0, got {alpha}")

    print(f"Loading base model A: {model_a_path}")
    print(f"Loading finetuned model B: {model_b_path}")

    lora_tensors = {}
    skipped_low_dim = 0
    unchanged = 0
    suffix = ".weight"

    with safe_open(model_a_path, framework="pt", device="cpu") as f_a, \
            safe_open(model_b_path, framework="pt", device="cpu") as f_b:
        common_keys = set(f_a.keys()) & set(f_b.keys())

        # Only weight tensors are processed; biases and existing LoRA keys are excluded.
        weight_keys = {k for k in common_keys if k.endswith(suffix) and 'lora_' not in k}

        if not weight_keys:
            print("No common weight keys found between the two models. Exiting.")
            sys.exit(1)

        print(f"Found {len(weight_keys)} common weight keys to process.")

        for key in tqdm(sorted(weight_keys), desc="Processing Layers"):
            try:
                tensor_a = f_a.get_tensor(key).to(device=device, dtype=dtype)
                tensor_b = f_b.get_tensor(key).to(device=device, dtype=dtype)

                if tensor_a.shape != tensor_b.shape:
                    tqdm.write(
                        f"Skipping key {key} due to shape mismatch: "
                        f"A={tensor_a.shape}, B={tensor_b.shape}"
                    )
                    continue

                # SVD needs a matrix; 1-D weights (e.g. norm scales) cannot be factored.
                if tensor_a.dim() < 2:
                    skipped_low_dim += 1
                    continue

                delta_w = tensor_b - tensor_a
                del tensor_a, tensor_b

                # An all-zero delta contributes nothing; avoid storing zero factors.
                if not torch.any(delta_w):
                    unchanged += 1
                    continue

                lora_down, lora_up = _svd_to_lora(delta_w, rank, alpha)
            except Exception as e:  # best-effort per layer: report and keep going
                tqdm.write(f"Failed to process key {key}: {e}")
                continue

            # Strip only the trailing ".weight" so inner ".weight..." segments survive.
            base_name = key[:-len(suffix)]
            lora_tensors[f"{base_name}.lora_down.weight"] = lora_down.to(
                device="cpu", dtype=torch.float32
            )
            lora_tensors[f"{base_name}.lora_up.weight"] = lora_up.to(
                device="cpu", dtype=torch.float32
            )
            lora_tensors[f"{base_name}.alpha"] = torch.tensor(alpha, dtype=torch.float32)

    if skipped_low_dim:
        print(f"Skipped {skipped_low_dim} weight(s) with fewer than 2 dims.")
    if unchanged:
        print(f"Skipped {unchanged} weight(s) that are identical in both models.")

    if not lora_tensors:
        print("No tensors were processed. Output file will not be created.")
        return

    print(f"\nSaving {len(lora_tensors)} tensors to {output_path}...")
    save_file(lora_tensors, output_path)
    print("✅ Done!")


def main():
    """Parse command-line arguments and run the extraction."""
    parser = argparse.ArgumentParser(
        description="Extract and SVD a LoRA from two SafeTensors checkpoints."
    )
    parser.add_argument("model_a", type=str,
                        help="Path to the base model (A) checkpoint in .safetensors format.")
    parser.add_argument("model_b", type=str,
                        help="Path to the finetuned model (B) checkpoint in .safetensors format.")
    parser.add_argument("output", type=str,
                        help="Path to save the output LoRA file in .safetensors format.")
    parser.add_argument("--rank", type=int, required=True,
                        help="The target rank for the SVD.")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
                        help="Device to use for computation ('cuda' or 'cpu').")
    parser.add_argument("--alpha", type=float, default=1.0,
                        help="The alpha (scaling) factor for the LoRA.")
    parser.add_argument("--precision", type=str, default="fp32",
                        choices=["fp32", "fp16", "bf16"],
                        help="Precision to use for calculations.")

    args = parser.parse_args()

    if args.rank < 1:
        parser.error("--rank must be a positive integer.")
    if args.alpha <= 0:
        parser.error("--alpha must be greater than 0.")

    # Device check
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        args.device = "cpu"

    dtype = get_torch_dtype(args.precision)

    extract_and_svd_lora(args.model_a, args.model_b, args.output, args.rank,
                         args.device, args.alpha, dtype)


if __name__ == "__main__":
    main()