Upload combine.py
Browse files- combine.py +51 -0
combine.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from safetensors import safe_open
|
| 4 |
+
from safetensors.torch import save_file
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def merge_safetensors(input_dir, output_file, config_file):
|
| 8 |
+
# Dictionary to store all tensors
|
| 9 |
+
merged_tensors = {}
|
| 10 |
+
|
| 11 |
+
# Load config
|
| 12 |
+
with open(config_file, 'r') as f:
|
| 13 |
+
config = json.load(f)
|
| 14 |
+
|
| 15 |
+
# Prepare metadata
|
| 16 |
+
metadata = {
|
| 17 |
+
"format": "pt",
|
| 18 |
+
"total_size": "", #str(total_size), # Notice we stringify this!
|
| 19 |
+
"_diffusers_version": config.get("_diffusers_version", ""),
|
| 20 |
+
"_class_name": config.get("_class_name", ""),
|
| 21 |
+
# Add other fields at this level
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
total_size = 0
|
| 25 |
+
|
| 26 |
+
# Iterate through all files in the input directory
|
| 27 |
+
for filename in os.listdir(input_dir):
|
| 28 |
+
if filename.endswith('.safetensors'):
|
| 29 |
+
file_path = os.path.join(input_dir, filename)
|
| 30 |
+
|
| 31 |
+
# Load tensors and metadata from each file
|
| 32 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 33 |
+
file_metadata = f.metadata()
|
| 34 |
+
if file_metadata and "__metadata__" in file_metadata:
|
| 35 |
+
total_size += int(file_metadata["__metadata__"].get("total_size", 0))
|
| 36 |
+
|
| 37 |
+
for key in f.keys():
|
| 38 |
+
tensor = f.get_tensor(key)
|
| 39 |
+
merged_tensors[key] = tensor
|
| 40 |
+
|
| 41 |
+
# Add total size to metadata
|
| 42 |
+
metadata["total_size"] = str(total_size)
|
| 43 |
+
|
| 44 |
+
# Save the merged tensors to a single file with metadata
|
| 45 |
+
save_file(merged_tensors, output_file, metadata)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
input_directory = './10_1'
|
| 49 |
+
output_file = './10_1/flux1-merge-S10_D1.safetensors'
|
| 50 |
+
config_file = './10_1/config.json'
|
| 51 |
+
merge_safetensors(input_directory, output_file, config_file)
|