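"""Resize the rank (dimension) of a LoRA .safetensors checkpoint via SVD.

Each LoRA pair is merged back into its full weight update, decomposed with
SVD, truncated to the requested rank, and re-factorized into new low-rank
matrices.

Example invocation (the script file name is illustrative):

    python resize_lora.py input.safetensors output.safetensors 16 --device cuda
"""
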
import os
import argparse
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm


def resize_lora_model(model_path, output_path, new_dim, device):
    """
    Resizes the LoRA dimension of a model using SVD for optimal weight preservation.

    Args:
        model_path (str): Path to the LoRA model to resize.
        output_path (str): Path to save the new resized model.
        new_dim (int): The target new dimension for the LoRA weights.
        device (str): The device to run calculations on ('cuda' or 'cpu').
    """
    print(f"Loading model from: {model_path}")
    model = load_file(model_path)
    new_model = {}

    # --- Metadata & Weight Inspection ---
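    # kohya-ss training scripts record the rank and alpha in the safetensors
    # header as 'ss_network_dim' / 'ss_network_alpha'; LoRAs from other
    # trainers may lack these, hence the weight-based fallbacks below.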
    original_dim = None
    alpha = None
    metadata = None
    try:
        with safe_open(model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            if metadata:
                if 'ss_network_dim' in metadata:
                    original_dim = int(metadata['ss_network_dim'])
                    print(f"Original dimension (from metadata): {original_dim}")
                if 'ss_network_alpha' in metadata:
                    alpha = float(metadata['ss_network_alpha'])
                    print(f"Original alpha (from metadata): {alpha}")
    except Exception as e:
        print(f"Could not read metadata: {e}. Dimension and alpha will be inferred.")

    # Infer original_dim from the first LoRA down-weight if not in metadata
    # (assumes a uniform rank across modules)
    if original_dim is None:
        for key in model.keys():
            if key.endswith((".lora_down.weight", ".lora_A.weight")):
                original_dim = model[key].shape[0]
                print(f"Inferred original dimension from weights: {original_dim}")
                break

    # Infer alpha from weights if not in metadata
    if alpha is None:
        for key in model.keys():
            if key.endswith(".alpha"):
                alpha = model[key].item()
                print(f"Inferred alpha from weights: {alpha}")
                break

    # Fallback for alpha if still not found: defaulting alpha to the dimension
    # makes the effective LoRA scale (alpha / dim) equal to 1
    if alpha is None and original_dim is not None:
        alpha = float(original_dim)
        print(f"Alpha not found, falling back to using dimension: {alpha}")

    # --- Tensor Processing ---
    lora_keys_to_process = set()
    for key in model.keys():
        if 'lora_' in key and key.endswith('.weight'):
            # Get the base name (e.g., "lora_unet_down_blocks_0_attentions_0_proj_in")
            base_key = key.split('.lora_')[0]
            lora_keys_to_process.add(base_key)

    if not lora_keys_to_process:
        print("Error: No LoRA weights found in the model.")
        return

    print(f"\nFound {len(lora_keys_to_process)} LoRA modules to resize...")

    for base_key in tqdm(sorted(list(lora_keys_to_process)), desc="Resizing modules"):
        try:
            down_key, up_key = None, None

            # Determine naming convention: kohya-ss checkpoints use
            # lora_down/lora_up, diffusers/PEFT-style ones use lora_A/lora_B
            if base_key + ".lora_down.weight" in model:
                down_key = base_key + ".lora_down.weight"
                up_key = base_key + ".lora_up.weight"
            elif base_key + ".lora_A.weight" in model:
                down_key = base_key + ".lora_A.weight"
                up_key = base_key + ".lora_B.weight"
            else:
                continue

            # Move weights to the selected device for calculation
            down_weight = model[down_key].to(device)
            up_weight = model[up_key].to(device)

            # --- SVD Resizing ---
            original_dtype = up_weight.dtype

            # Combine the two matrices to get the full weight update
            # Conv LoRA weights are 4D: down is (rank, in_ch, kh, kw) and up
            # is (out_ch, rank, 1, 1). Flatten the kernel dims into the input
            # dimension so the factors can be multiplied as 2D matrices.
            conv2d = down_weight.ndim == 4
            if conv2d:
                down_shape = down_weight.shape
                up_shape = up_weight.shape
                down_weight = down_weight.flatten(1)
                up_weight = up_weight.flatten(1)

            full_weight = up_weight @ down_weight
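            # full_weight has rank at most the original dim, and by the
            # Eckart-Young theorem truncating the SVD to the top new_dim
            # singular values yields the best rank-new_dim approximation in
            # the Frobenius norm, so no rank-new_dim factorization can
            # preserve the learned update better.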

            # Cast to float32: torch.linalg.svd does not support half precision
            # on CPU, and float32 also improves numerical accuracy. The reduced
            # SVD (full_matrices=False) is all the truncation below needs.
            U, S, Vh = torch.linalg.svd(full_weight.to(torch.float32), full_matrices=False)

            # Truncate the SVD factors to the target rank (slicing never pads;
            # if new_dim exceeds the effective rank, the extra singular values
            # are simply near zero)
            U = U[:, :new_dim]
            S = S[:new_dim]
            Vh = Vh[:new_dim, :]

            # Reconstruct the new low-rank matrices
            new_down = torch.diag(S) @ Vh
            new_up = U
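            # Design note: S is folded entirely into the down matrix here.
            # Splitting it as U @ diag(sqrt(S)) and diag(sqrt(S)) @ Vh would
            # balance the magnitudes of the two factors; the reconstructed
            # product is identical either way.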

            # Reshape back to the original conv kernel shape. Using the shapes
            # saved before flattening keeps 3x3 conv LoRAs correct: their
            # flattened second dim is in_ch * kh * kw, not in_ch.
            if conv2d:
                new_down = new_down.reshape(new_dim, *down_shape[1:])
                new_up = new_up.reshape(up_shape[0], new_dim, 1, 1)

            # Move back to CPU and original dtype for saving
            new_model[down_key] = new_down.contiguous().to(original_dtype).cpu()
            new_model[up_key] = new_up.contiguous().to(original_dtype).cpu()

            # Rescale the per-module alpha so the effective scale (alpha / rank)
            # is preserved at the new rank; copying it unchanged would silently
            # change the LoRA's strength
            alpha_key = base_key + ".alpha"
            if alpha_key in model:
                old_rank = model[down_key].shape[0]
                scaled = model[alpha_key].float() * (new_dim / old_rank)
                new_model[alpha_key] = scaled.to(model[alpha_key].dtype)

        except KeyError:
            continue

    # Copy any remaining tensors (e.g. modules that were skipped) without
    # overwriting the resized weights and rescaled alphas written above
    for key, value in model.items():
        if key not in new_model:
            new_model[key] = value

    # --- Save New Model ---
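    # At load time a LoRA update is applied scaled by (alpha / dim). Scaling
    # alpha by (new_dim / original_dim) keeps that ratio, and therefore the
    # effective strength of the LoRA, unchanged after resizing.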
    # Preserve any original metadata, overriding the resized fields
    new_metadata = dict(metadata) if metadata else {}
    new_metadata['ss_network_dim'] = str(new_dim)
    if alpha is not None and original_dim is not None and original_dim > 0:
        new_alpha = alpha * (new_dim / original_dim)
        new_metadata['ss_network_alpha'] = str(new_alpha)
        print(f"\nNew alpha scaled to: {new_alpha:.2f}")

    print(f"\nSaving resized model to: {output_path}")
    save_file(new_model, output_path, metadata=new_metadata)
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Resize a LoRA model to a new dimension using SVD.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("model_path", type=str, help="Path to the LoRA model (.safetensors).")
    parser.add_argument("output_path", type=str, help="Path to save the resized LoRA model.")
    parser.add_argument("new_dim", type=int, help="The new LoRA dimension (rank).")
    parser.add_argument("--device", type=str, default=None,
                        help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.")

    args = parser.parse_args()

    if args.device:
        device = args.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Using device: {device}")

    resize_lora_model(args.model_path, args.output_path, args.new_dim, device)