# checkpoint/ts.py
# MoeZilla — commit 3dd7366 (verified): "Rename test.py to ts.py"
import torch
from safetensors.torch import load_file, save_file
# Merge a checkpoint that was split across several .safetensors shards back
# into a single .safetensors file. The default shard names follow the
# "diffusion_pytorch_model-0000K-of-0000N.safetensors" sharding pattern; call
# main() with other paths to merge a different set of shards.

SHARD_FILES = (
    "diffusion_pytorch_model-00001-of-00002.safetensors",
    "diffusion_pytorch_model-00002-of-00002.safetensors",
)
OUTPUT_FILE = "flowgram.safetensors"


def merge_state_dicts(
    shards: list[dict[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
    """Combine per-shard tensor dicts into one new dict.

    Args:
        shards: Tensor-name -> tensor mappings, one per shard, in shard order.

    Returns:
        A new dict holding every entry from every shard. The input dicts are
        not modified.

    Raises:
        ValueError: If a tensor name appears in more than one shard. Shards of
            one checkpoint are expected to hold disjoint names, so a repeat
            most likely means mismatched files were passed; letting the later
            shard silently win (as ``{**a, **b}`` would) would hide that.
    """
    merged: dict[str, torch.Tensor] = {}
    for index, shard in enumerate(shards):
        # dict key views support set intersection, giving the overlap directly.
        repeated = merged.keys() & shard.keys()
        if repeated:
            raise ValueError(
                f"shard {index} repeats tensor names already loaded "
                f"from an earlier shard: {sorted(repeated)}"
            )
        merged.update(shard)
    return merged


def main(
    shard_files: tuple[str, ...] = SHARD_FILES,
    output_file: str = OUTPUT_FILE,
) -> None:
    """Load every shard in ``shard_files``, merge them, and save to ``output_file``.

    Raises:
        ValueError: If two shards contain the same tensor name.
    """
    shards = [load_file(path) for path in shard_files]
    combined_model = merge_state_dicts(shards)
    save_file(combined_model, output_file)
    # Report the path that was actually written.
    print(f"Successfully merged the files into {output_file}")


if __name__ == "__main__":
    main()