# merge_and_quantize_to_fp8.py
# Merges sharded .safetensors -> single file -> quantizes to FP8 E4M3FN
# Usage: python merge_and_quantize_to_fp8.py
# Prompts for input dir and output name
from safetensors.torch import load_file, save_file
import os
import glob
import torch
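
# Added sketch (not wired into main): per-tensor scaled FP8 quantization, an
# alternative to the plain cast used in Step 2 below. Scaling each tensor so
# its absmax maps to the E4M3 max (448) preserves more dynamic range, but the
# scale must be stored and re-applied at load time by the runtime.
# `quantize_fp8_scaled` is a hypothetical helper name, not part of any library.
def quantize_fp8_scaled(tensor: torch.Tensor):
    """Return (fp8_tensor, scale) where tensor ~= fp8_tensor.float() * scale."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    absmax = tensor.abs().max().float()
    scale = torch.clamp(absmax / fp8_max, min=1e-12)  # avoid divide-by-zero
    scaled = (tensor.float() / scale).clamp(-fp8_max, fp8_max)
    return scaled.to(torch.float8_e4m3fn), scale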

def main():
    print("=== Shard Merger + FP8 Quantizer (E4M3FN) ===\n")

    shard_dir = input("Enter the directory path containing the .safetensors shards: ").strip()
    if not os.path.isdir(shard_dir):
        print(f"Error: Directory '{shard_dir}' not found!")
        return

    output_name = input("Enter desired output filename (e.g. VBVR-Wan2.2-fp8-e4m3.safetensors): ").strip()
    if not output_name:
        output_name = "Wan2.2-fp8-e4m3.safetensors"
    if not output_name.endswith(".safetensors"):
        output_name += ".safetensors"
    output_path = os.path.join(shard_dir, output_name)

    # Step 1: Find and sort shards
    shard_files = sorted(glob.glob(os.path.join(shard_dir, "*.safetensors")))
    # Skip a previous output or leftover temp file so they don't get merged back in
    shard_files = [f for f in shard_files
                   if os.path.basename(f) not in (output_name, "temp_merged.safetensors")]
    if not shard_files:
        print("No .safetensors files found!")
        return
    if len(shard_files) == 1:
        print("Only one shard found; skipping merge, quantizing directly.")
        merged_path = shard_files[0]
    else:
        print(f"\nFound {len(shard_files)} shards:")
        for f in shard_files:
            print(f"  - {os.path.basename(f)}")
        confirm = input("\nProceed with merging? (y/n): ").strip().lower()
        if confirm != 'y':
            return

        print("\nMerging shards...")
        # Sharded checkpoints have disjoint key sets, so a plain dict update
        # is enough; if keys ever collided, later shards would win.
        merged_state_dict = {}
        for shard_path in shard_files:
            print(f"  Loading {os.path.basename(shard_path)}...")
            state_dict = load_file(shard_path, device="cpu")
            merged_state_dict.update(state_dict)

        temp_merged = os.path.join(shard_dir, "temp_merged.safetensors")
        print(f"Saving temporary merged file: {temp_merged}")
        save_file(merged_state_dict, temp_merged)
        merged_path = temp_merged
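        # Added sketch: if the checkpoint shipped with a Hugging Face-style
        # index file (*.safetensors.index.json), cross-check that every key
        # listed in its "weight_map" made it into the merge. Purely advisory.
        index_files = glob.glob(os.path.join(shard_dir, "*.safetensors.index.json"))
        if index_files:
            import json  # local import; only needed for this optional check
            with open(index_files[0], "r", encoding="utf-8") as fh:
                expected_keys = set(json.load(fh).get("weight_map", {}).keys())
            missing = expected_keys - set(merged_state_dict.keys())
            if missing:
                print(f"Warning: {len(missing)} keys from the index are missing after merge!")
            else:
                print("Index check passed: all indexed keys present in merge.")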

    # Step 2: Quantize to FP8
    print("\nQuantizing to FP8 (float8_e4m3fn)... This may take 10–30 minutes depending on RAM/CPU.")
    quantized_dict = {}
    state_dict = load_file(merged_path, device="cpu")
    for key, tensor in state_dict.items():
        if tensor.dtype in (torch.float32, torch.float16, torch.bfloat16):
            # Simple cast: most community fp8 Wan models use this or similar
            quantized_dict[key] = tensor.to(dtype=torch.float8_e4m3fn)
        else:
            quantized_dict[key] = tensor  # keep as-is if already quantized/other
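
    # Added sketch: tally tensor dtypes so an unexpected mix (e.g. tensors
    # that skipped the cast) is visible before writing the file.
    dtype_counts = {}
    for t in quantized_dict.values():
        dtype_counts[str(t.dtype)] = dtype_counts.get(str(t.dtype), 0) + 1
    print("Tensor dtypes after quantization:", dtype_counts)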
print(f"Saving FP8 quantized model to:\n {output_path}")
save_file(quantized_dict, output_path)
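
    # Added sketch: report on-disk sizes; the fp8 output should be roughly
    # half the size of an fp16/bf16 source (both files exist at this point,
    # since the temp file has not been cleaned up yet).
    merged_gb = os.path.getsize(merged_path) / 1e9
    output_gb = os.path.getsize(output_path) / 1e9
    print(f"Size: {merged_gb:.2f} GB -> {output_gb:.2f} GB")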

    # Optional: clean up temp file
    if merged_path != shard_files[0] and os.path.exists(merged_path):
        os.remove(merged_path)
        print("Cleaned up temporary merged file.")

    print("\nDone! FP8 single-file model ready. Load in ComfyUI/Diffusers and test prompts.")


if __name__ == "__main__":
    main()