anyMODE commited on
Commit
042c5bc
·
verified ·
1 Parent(s): 54c3b42

Upload lora_dims.py

Browse files
Files changed (1) hide show
  1. lora_dims.py +91 -0
lora_dims.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from safetensors import safe_open
4
+
5
+ def get_lora_dimensions_from_directory(directory_path):
6
+ """
7
+ Scans a directory for .safetensors files and extracts the LoRA network dimension
8
+ from their metadata, falling back to inspecting tensor shapes if metadata is absent.
9
+
10
+ Args:
11
+ directory_path (str): The path to the directory to scan.
12
+ """
13
+ print(f"Scanning for LoRA models in: '{directory_path}'...\n")
14
+
15
+ found_models = 0
16
+
17
+ # Walk through the directory and its subdirectories
18
+ for root, _, files in os.walk(directory_path):
19
+ for filename in sorted(files):
20
+ if filename.lower().endswith(".safetensors"):
21
+ file_path = os.path.join(root, filename)
22
+ try:
23
+ # Use safe_open to read metadata without loading the whole file
24
+ with safe_open(file_path, framework="pt", device="cpu") as f:
25
+ metadata = f.metadata()
26
+
27
+ if not metadata:
28
+ print(f"- {filename}: No metadata found. Checking weights...")
29
+ # Fallthrough to weight checking
30
+
31
+ # LoRA training scripts like Kohya's SS store the dimension here
32
+ network_dim = metadata.get("ss_network_dim") if metadata else None
33
+
34
+ if network_dim:
35
+ print(f"- {filename}: Dimension = {network_dim} (from metadata)")
36
+ found_models += 1
37
+ else:
38
+ # Fallback: try to determine dimension from tensor shapes
39
+ dim_from_weights = None
40
+ for key in f.keys():
41
+ # Typically, the rank is the first dimension of the 'lora_down' tensor
42
+ if key.endswith("lora_down.weight"):
43
+ tensor = f.get_tensor(key)
44
+ # The shape of lora_down.weight is (rank, in_features)
45
+ dim_from_weights = tensor.shape[0]
46
+ break # Found it, no need to check other keys
47
+
48
+ # Alternative naming uses lora_B or lora_up for the up-projection
49
+ if key.endswith(("lora_B.weight", "lora_up.weight")):
50
+ tensor = f.get_tensor(key)
51
+ # The shape of lora_up/lora_B is (out_features, rank)
52
+ dim_from_weights = tensor.shape[1]
53
+ break # Found it, no need to check other keys
54
+
55
+ if dim_from_weights is not None:
56
+ print(f"- {filename}: Dimension = {dim_from_weights} (from weights)")
57
+ found_models += 1
58
+ else:
59
+ print(f"- {filename}: (Could not determine dimension from metadata or weights)")
60
+
61
+ except Exception as e:
62
+ print(f"Could not process {filename}. Error: {e}")
63
+
64
+ if found_models == 0:
65
+ print("\nNo LoRA models with dimension information were found in the specified directory.")
66
+ else:
67
+ print(f"\nScan complete. Found {found_models} models with dimension info.")
68
+
69
+ if __name__ == "__main__":
70
+ # Set up command-line argument parsing
71
+ parser = argparse.ArgumentParser(
72
+ description="Get the network dimensions of LoRA models in a directory.",
73
+ formatter_class=argparse.RawTextHelpFormatter
74
+ )
75
+
76
+ parser.add_argument(
77
+ "directory",
78
+ type=str,
79
+ help="The path to the directory containing your LoRA (.safetensors) files."
80
+ )
81
+
82
+ args = parser.parse_args()
83
+
84
+ # Check if the provided path is a valid directory
85
+ if not os.path.isdir(args.directory):
86
+ print(f"Error: The path '{args.directory}' is not a valid directory.")
87
+ else:
88
+ get_lora_dimensions_from_directory(args.directory)
89
+
90
+
91
+