import torch
import argparse
from safetensors.torch import save_file, safe_open
from tqdm import tqdm
import sys


def get_torch_dtype(dtype_str: str):
    """Converts a string to a torch.dtype object."""
    if dtype_str == "fp32":
        return torch.float32
    if dtype_str == "fp16":
        return torch.float16
    if dtype_str == "bf16":
        return torch.bfloat16
    raise ValueError(f"Unsupported dtype: {dtype_str}")


def extract_and_svd_lora(model_a_path: str, model_b_path: str, output_path: str, rank: int, device: str, alpha: float,
                         dtype: torch.dtype):
    """
    Extracts the difference between two models, applies SVD to reduce the rank,
    and saves the result as a LoRA file.
    """
    print(f"Loading base model A: {model_a_path}")
    print(f"Loading finetuned model B: {model_b_path}")

    lora_tensors = {}

    with safe_open(model_a_path, framework="pt", device="cpu") as f_a, \
            safe_open(model_b_path, framework="pt", device="cpu") as f_b:

        keys_a = set(f_a.keys())
        keys_b = set(f_b.keys())
        common_keys = keys_a.intersection(keys_b)

        # Filter for processable layers (typically linear and conv weights)
        # We exclude biases and non-weight tensors.
        weight_keys = {k for k in common_keys if k.endswith('.weight') and 'lora_' not in k}

        if not weight_keys:
            print("No common weight keys found between the two models. Exiting.")
            sys.exit(1)

        print(f"Found {len(weight_keys)} common weight keys to process.")

        # Main processing loop with progress bar
        for key in tqdm(sorted(list(weight_keys)), desc="Processing Layers"):
            try:
                # Load tensors and move to the selected device and dtype
                tensor_a = f_a.get_tensor(key).to(device=device, dtype=dtype)
                tensor_b = f_b.get_tensor(key).to(device=device, dtype=dtype)

                if tensor_a.shape != tensor_b.shape:
                    print(f"Skipping key {key} due to shape mismatch: A={tensor_a.shape}, B={tensor_b.shape}")
                    continue

                # Calculate the difference (delta weight)
                delta_w = tensor_b - tensor_a

                # SVD works on 2D matrices. Reshape conv layers and other ND tensors.
                original_shape = delta_w.shape
                if delta_w.dim() > 2:
                    delta_w = delta_w.view(original_shape[0], -1)

                # --- Core SVD Logic ---
                # ΔW ≈ U * S * Vh
                # U: Left singular vectors
                # S: Singular values (a 1D vector)
                # Vh: Right singular vectors (transposed)
                # torch.linalg.svd does not support half-precision inputs, so run
                # the decomposition in float32 regardless of the --precision flag.
                U, S, Vh = torch.linalg.svd(delta_w.to(torch.float32), full_matrices=False)

                # Truncate to the desired rank
                current_rank = min(rank, S.size(0))  # Ensure rank is not > possible rank
                U = U[:, :current_rank]
                S = S[:current_rank]
                Vh = Vh[:current_rank, :]

                # --- Decompose into LoRA A and B matrices ---
                # LoRA A (lora_down) is Vh
                # LoRA B (lora_up) is U * S
                # We scale lora_up by the singular values to retain the magnitude
                lora_down = Vh
                lora_up = U @ torch.diag(S)
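                # Optional sanity check (a sketch, left commented out so the loop's
                # behavior is unchanged): at full rank, lora_up @ lora_down equals
                # delta_w exactly; with truncation the relative error indicates how
                # much of the weight update the chosen rank captures.
                # rel_err = (torch.linalg.norm(lora_up @ lora_down - delta_w.to(lora_up.dtype))
                #            / torch.linalg.norm(delta_w.to(lora_up.dtype)))
                # tqdm.write(f"{key}: rank {current_rank}, relative SVD error {rel_err:.4f}")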

                # The factors are kept in matrix form: for conv layers, lora_down is
                # (rank, in_channels * k_h * k_w) and lora_up is (out_channels, rank).
                # Loaders that expect 4-D conv LoRA weights would need to reshape
                # these back to (rank, in_channels, k_h, k_w) and
                # (out_channels, rank, 1, 1) respectively.

                # Create LoRA tensor names
                base_name = key.replace('.weight', '')
                lora_down_name = f"{base_name}.lora_down.weight"
                lora_up_name = f"{base_name}.lora_up.weight"
                alpha_name = f"{base_name}.alpha"

                # Store tensors, moving them to CPU for saving
                lora_tensors[lora_down_name] = lora_down.contiguous().cpu().to(torch.float32)
                lora_tensors[lora_up_name] = lora_up.contiguous().cpu().to(torch.float32)
                lora_tensors[alpha_name] = torch.tensor(alpha, dtype=torch.float32)

            except Exception as e:
                print(f"Failed to process key {key}: {e}")

    # Save the final LoRA file
    if not lora_tensors:
        print("No tensors were processed. Output file will not be created.")
        return

    print(f"\nSaving {len(lora_tensors)} tensors to {output_path}...")
    save_file(lora_tensors, output_path)
    print("✅ Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract and SVD a LoRA from two SafeTensors checkpoints.")

    parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint in .safetensors format.")
    parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint in .safetensors format.")
    parser.add_argument("output", type=str, help="Path to save the output LoRA file in .safetensors format.")

    parser.add_argument("--rank", type=int, required=True, help="The target rank for the SVD.")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
                        help="Device to use for computation ('cuda' or 'cpu').")
    parser.add_argument("--alpha", type=float, default=1.0, help="The alpha (scaling) factor for the LoRA.")
    parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
                        help="Precision to use for calculations.")

    args = parser.parse_args()

    # Device check
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        args.device = "cpu"

    dtype = get_torch_dtype(args.precision)

    extract_and_svd_lora(args.model_a, args.model_b, args.output, args.rank, args.device, args.alpha, dtype)
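
# Example invocation (assuming the script is saved as extract_lora_svd.py;
# the checkpoint paths below are placeholders):
#
#   python extract_lora_svd.py base.safetensors finetuned.safetensors lora.safetensors \
#       --rank 32 --alpha 32 --precision fp32 --device cuda
#
# Note on alpha: many LoRA runtimes apply the update as
#   W' = W_base + (alpha / rank) * lora_up @ lora_down
# Since lora_up already absorbs the singular values here, setting alpha equal
# to the chosen rank reproduces the extracted delta at full strength, while
# smaller alphas attenuate it.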