import os
import argparse
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
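
# A LoRA update factors a weight delta as up @ down with inner rank `dim`; resizing
# to a new rank amounts to finding the best new factorization, which SVD provides.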

def resize_lora_model(model_path, output_path, new_dim, device):
    """
    Resizes the LoRA dimension of a model using SVD for optimal weight preservation.

    Args:
        model_path (str): Path to the LoRA model to resize.
        output_path (str): Path to save the new resized model.
        new_dim (int): The target new dimension (rank) for the LoRA weights.
        device (str): The device to run calculations on ('cuda' or 'cpu').
    """
print(f"Loading model from: {model_path}")
model = load_file(model_path)
new_model = {}
# --- Metadata & Weight Inspection ---
original_dim = None
alpha = None
    try:
        with safe_open(model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            if metadata:
                if 'ss_network_dim' in metadata:
                    original_dim = int(metadata['ss_network_dim'])
                    print(f"Original dimension (from metadata): {original_dim}")
                if 'ss_network_alpha' in metadata:
                    alpha = float(metadata['ss_network_alpha'])
                    print(f"Original alpha (from metadata): {alpha}")
    except Exception as e:
        print(f"Could not read metadata: {e}. Dimension and alpha will be inferred.")

    # Infer original_dim from weights if not in metadata
    if original_dim is None:
        for key in model.keys():
            if key.endswith((".lora_down.weight", ".lora_A.weight")):
                original_dim = model[key].shape[0]
                print(f"Inferred original dimension from weights: {original_dim}")
                break

    # Infer alpha from weights if not in metadata
    if alpha is None:
        for key in model.keys():
            if key.endswith(".alpha"):
                alpha = model[key].item()
                print(f"Inferred alpha from weights: {alpha}")
                break

    # Fallback for alpha if still not found
    if alpha is None and original_dim is not None:
        alpha = float(original_dim)
        print(f"Alpha not found, falling back to using dimension: {alpha}")

    # --- Tensor Processing ---
    lora_keys_to_process = set()
    for key in model.keys():
        if 'lora_' in key and key.endswith('.weight'):
            # Get the base name (e.g., "lora_unet_down_blocks_0_attentions_0_proj_in")
            base_key = key.split('.lora_')[0]
            lora_keys_to_process.add(base_key)

    if not lora_keys_to_process:
        print("Error: No LoRA weights found in the model.")
        return

    print(f"\nFound {len(lora_keys_to_process)} LoRA modules to resize...")
    for base_key in tqdm(sorted(lora_keys_to_process), desc="Resizing modules"):
        try:
            down_key, up_key = None, None
            # Determine naming convention (kohya lora_down/lora_up vs. PEFT lora_A/lora_B)
            if base_key + ".lora_down.weight" in model:
                down_key = base_key + ".lora_down.weight"
                up_key = base_key + ".lora_up.weight"
            elif base_key + ".lora_A.weight" in model:
                down_key = base_key + ".lora_A.weight"
                up_key = base_key + ".lora_B.weight"
            else:
                continue

            # Move weights to the selected device for calculation
            down_weight = model[down_key].to(device)
            up_weight = model[up_key].to(device)

            # --- SVD Resizing ---
            original_dtype = up_weight.dtype

            # Combine the two matrices to get the full weight update
            conv2d = down_weight.ndim == 4
            if conv2d:
                # For conv layers, fold the kernel dims into the channel dim,
                # remembering the original shapes so we can restore them later
                down_shape, up_shape = down_weight.shape, up_weight.shape
                down_weight = down_weight.flatten(1)
                up_weight = up_weight.flatten(1)
            full_weight = up_weight @ down_weight
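
            # full_weight is the dense update delta_W = up @ down that the LoRA adds
            # (after alpha / dim scaling) to the frozen base weight.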
            # Cast to float32 for SVD, since CPU (and some GPUs) lack bfloat16 support
            U, S, Vh = torch.linalg.svd(full_weight.to(torch.float32), full_matrices=False)

            # Truncate the SVD components, zero-padding when upsizing past the available rank
            pad = max(new_dim - S.shape[0], 0)
            U = torch.nn.functional.pad(U[:, :new_dim], (0, pad))
            S = torch.nn.functional.pad(S[:new_dim], (0, pad))
            Vh = torch.nn.functional.pad(Vh[:new_dim, :], (0, 0, 0, pad))

            # Reconstruct the new low-rank matrices
            new_down = torch.diag(S) @ Vh
            new_up = U
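
            # Keeping the top new_dim singular values gives the best rank-new_dim
            # approximation of full_weight in the Frobenius norm (Eckart-Young),
            # so the resized LoRA stays as close as possible to the original update.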

            # Reshape back to the original conv kernel shape if necessary
            if conv2d:
                new_down = new_down.reshape(new_dim, *down_shape[1:])
                new_up = new_up.reshape(up_shape[0], new_dim, *up_shape[2:])

            # Move back to CPU and original dtype for saving
            new_model[down_key] = new_down.contiguous().to(original_dtype).cpu()
            new_model[up_key] = new_up.contiguous().to(original_dtype).cpu()

            # Rescale the per-module alpha so the effective scale (alpha / dim) is preserved
            alpha_key = base_key + ".alpha"
            if alpha_key in model:
                old_alpha = model[alpha_key]
                old_dim = model[down_key].shape[0]
                new_model[alpha_key] = (old_alpha.to(torch.float32) * new_dim / old_dim).to(old_alpha.dtype)
        except KeyError:
            continue

    # Copy any remaining (non-LoRA) tensors without overwriting the
    # rescaled alpha tensors written above
    for key, value in model.items():
        if key not in new_model:
            new_model[key] = value

    # --- Save New Model ---
    new_metadata = {'ss_network_dim': str(new_dim)}
    if alpha is not None and original_dim is not None and original_dim > 0:
        # Scale alpha with the dimension so the effective strength (alpha / dim) is unchanged
        new_alpha = alpha * (new_dim / original_dim)
        new_metadata['ss_network_alpha'] = str(new_alpha)
        print(f"\nNew alpha scaled to: {new_alpha:.2f}")

    print(f"\nSaving resized model to: {output_path}")
    save_file(new_model, output_path, metadata=new_metadata)
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Resize a LoRA model to a new dimension using SVD.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("model_path", type=str, help="Path to the LoRA model (.safetensors).")
    parser.add_argument("output_path", type=str, help="Path to save the resized LoRA model.")
    parser.add_argument("new_dim", type=int, help="The new LoRA dimension (rank).")
    parser.add_argument("--device", type=str, default=None,
                        help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.")
    args = parser.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    resize_lora_model(args.model_path, args.output_path, args.new_dim, device)
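
# Example invocation (the script name and file paths are illustrative):
#   python resize_lora.py my_lora.safetensors my_lora_dim16.safetensors 16 --device cuda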