SanDiegoDude committed on
Commit
491e92d
·
verified ·
1 Parent(s): ae08bd1

Upload convert_to_fp8.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. convert_to_fp8.py +108 -0
convert_to_fp8.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert the JoyAI-Image DiT transformer from bf16 to FP8 (float8_e4m3fn).
3
+
4
+ Strategy:
5
+ - 2D+ weight tensors (linear layers, conv kernels) → float8_e4m3fn
6
+ - 1D tensors (biases, norms, embeddings) → keep original dtype
7
+ - This matches ComfyUI's fp8 convention for diffusion models
8
+
9
+ The resulting checkpoint is ~16 GB instead of ~32 GB, fitting in a 4090's 24 GB VRAM.
10
+
11
+ Usage:
12
+ python convert_to_fp8.py --input ckpts_infer/transformer/transformer.safetensors \
13
+ --output ckpts_infer/transformer/transformer_fp8.safetensors
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import os
20
+ import sys
21
+ import time
22
+
23
+ import torch
24
+ from safetensors.torch import load_file, save_file
25
+
26
+
27
+ FP8_DTYPE = torch.float8_e4m3fn
28
+
29
+
30
+ def quantize_tensor(t: torch.Tensor) -> torch.Tensor:
31
+ """Quantize a tensor to FP8 E4M3, clamping to the representable range."""
32
+ finfo = torch.finfo(FP8_DTYPE)
33
+ t_clamped = t.float().clamp(finfo.min, finfo.max)
34
+ return t_clamped.to(FP8_DTYPE)
35
+
36
+
37
def should_quantize(name: str, tensor: torch.Tensor) -> bool:
    """Decide whether *tensor* should be stored as FP8.

    Only matrix-like tensors (ndim >= 2) with at least 1024 elements are
    quantized; 1D tensors (biases, norm scales) and tiny tensors keep their
    original dtype.  *name* is unused for now but kept in the signature for
    future per-layer exclusion rules.
    """
    return tensor.ndim >= 2 and tensor.numel() >= 1024
44
+
45
+
46
def convert(input_path: str, output_path: str) -> None:
    """Quantize a safetensors checkpoint to FP8 and verify the result.

    Loads every tensor from *input_path* on CPU, converts eligible tensors
    (per ``should_quantize``) to FP8 via ``quantize_tensor``, writes the mixed
    state dict to *output_path* (parent directories are created), then reloads
    the file and checks dtype/shape of every entry.

    Args:
        input_path: Source ``.safetensors`` file (e.g. bf16 weights).
        output_path: Destination ``.safetensors`` file.
    """
    print(f"Loading {input_path} ...")
    state_dict = load_file(input_path, device="cpu")

    total_tensors = len(state_dict)
    quantized_count = 0
    kept_count = 0
    original_bytes = 0
    new_bytes = 0

    print(f"Processing {total_tensors} tensors ...")
    converted = {}
    for name, tensor in state_dict.items():
        original_bytes += tensor.numel() * tensor.element_size()

        if should_quantize(name, tensor):
            converted[name] = quantize_tensor(tensor)
            quantized_count += 1
        else:
            converted[name] = tensor
            kept_count += 1

        new_bytes += converted[name].numel() * converted[name].element_size()

    print(f"  Quantized to FP8: {quantized_count}")
    print(f"  Kept original: {kept_count}")
    # Guard the ratio: an empty checkpoint would otherwise raise
    # ZeroDivisionError in the percentage calculation.
    if original_bytes:
        print(f"  Size: {original_bytes / 1e9:.2f} GB → {new_bytes / 1e9:.2f} GB "
              f"({new_bytes / original_bytes:.1%})")

    # `dirname` is empty for bare filenames; fall back to the CWD.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    print(f"Saving {output_path} ...")
    save_file(converted, output_path, metadata={"format": "fp8_e4m3fn"})

    output_size = os.path.getsize(output_path)
    print(f"  File size: {output_size / 1e9:.2f} GB")

    # Round-trip sanity check: every tensor must come back with the dtype and
    # shape we wrote.
    print("Verifying reload ...")
    reloaded = load_file(output_path, device="cpu")
    for name in converted:
        assert reloaded[name].dtype == converted[name].dtype, \
            f"Dtype mismatch for {name}: {reloaded[name].dtype} vs {converted[name].dtype}"
        assert reloaded[name].shape == converted[name].shape, \
            f"Shape mismatch for {name}"
    print("Verification passed.")
90
+
91
+
92
def main() -> None:
    """CLI entry point: parse arguments, validate the input path, and convert."""
    cli = argparse.ArgumentParser(description="Convert DiT weights to FP8")
    cli.add_argument("--input", required=True, help="Input safetensors file (bf16)")
    cli.add_argument("--output", required=True, help="Output safetensors file (fp8)")
    opts = cli.parse_args()

    # Fail fast with a clear message rather than a traceback from load_file.
    if not os.path.isfile(opts.input):
        print(f"Error: {opts.input} not found", file=sys.stderr)
        sys.exit(1)

    started = time.time()
    convert(opts.input, opts.output)
    elapsed = time.time() - started
    print(f"\nDone in {elapsed:.1f}s")
105
+
106
+
107
# Run the CLI only when executed as a script, so the module can be imported
# (e.g. to reuse quantize_tensor) without side effects.
if __name__ == "__main__":
    main()