| |
| |
| |
| |
|
|
| from safetensors.torch import load_file, save_file |
| import os |
| import glob |
| import torch |
|
|
def main():
    """Interactively merge ``.safetensors`` shards from one directory and
    quantize the merged weights to FP8 (``torch.float8_e4m3fn``), writing a
    single combined output file into the same directory.

    Side effects: prompts on stdin, prints progress to stdout, and writes
    the quantized model to ``<shard_dir>/<output_name>``.
    """
    print("=== Shard Merger + FP8 Quantizer (E4M3FN) ===\n")

    shard_dir = input("Enter the directory path containing the .safetensors shards: ").strip()
    if not os.path.isdir(shard_dir):
        print(f"Error: Directory '{shard_dir}' not found!")
        return

    output_name = input("Enter desired output filename (e.g. VBVR-Wan2.2-fp8-e4m3.safetensors): ").strip()
    if not output_name:
        output_name = "Wan2.2-fp8-e4m3.safetensors"
    if not output_name.endswith(".safetensors"):
        output_name += ".safetensors"

    output_path = os.path.join(shard_dir, output_name)

    # Exclude a previously generated output file from the shard list so that
    # re-running the script does not merge the quantized result into itself.
    shard_files = sorted(
        f
        for f in glob.glob(os.path.join(shard_dir, "*.safetensors"))
        if os.path.abspath(f) != os.path.abspath(output_path)
    )

    if not shard_files:
        print("No .safetensors files found!")
        return

    if len(shard_files) == 1:
        print("Only one shard found - skipping merge, quantizing directly.")
        state_dict = load_file(shard_files[0], device="cpu")
    else:
        print(f"\nFound {len(shard_files)} shards:")
        for f in shard_files:
            print(f"  - {os.path.basename(f)}")

        confirm = input("\nProceed with merging? (y/n): ").strip().lower()
        if confirm != 'y':
            return

        # Merge in memory; there is no need to round-trip the full merged
        # model through a temporary file just to reload it for quantization.
        print("\nMerging shards...")
        state_dict = {}
        for shard_path in shard_files:
            print(f"  Loading {os.path.basename(shard_path)}...")
            state_dict.update(load_file(shard_path, device="cpu"))

    print("\nQuantizing to FP8 (float8_e4m3fn)... This may take 10-30 minutes depending on RAM/CPU.")
    quantized_dict = {}
    for key, tensor in state_dict.items():
        # Only floating-point weights are cast to FP8; any other dtype
        # (integer buffers, bool masks, etc.) is passed through unchanged.
        if tensor.dtype in (torch.float32, torch.float16, torch.bfloat16):
            quantized_dict[key] = tensor.to(dtype=torch.float8_e4m3fn)
        else:
            quantized_dict[key] = tensor

    print(f"Saving FP8 quantized model to:\n  {output_path}")
    save_file(quantized_dict, output_path)

    print("\nDone! FP8 single-file model ready. Load in ComfyUI/Diffusers and test prompts.")
|
|
| if __name__ == "__main__": |
| main() |