How did you create the FP8 mix?
#3
pinned
by
Nuke1229 - opened
Hi, thanks for sharing. I'm interested in the FP8mix version. Could you please explain how you created FP8mix?
I used the following method:
import torch
from safetensors.torch import save_file
from safetensors import safe_open
import os
import gc
import json
def load_sensitivity_scores(path="sensitivity_report.json"):
    """Read the layer-sensitivity report produced by a prior analysis run.

    Parameters
    ----------
    path : str
        Location of the JSON sensitivity report.

    Returns
    -------
    list
        The parsed report contents; an empty list when the file is absent.
    """
    if os.path.exists(path):
        with open(path, "r") as report:
            return json.load(report)
    return []
def is_base_important(key: str) -> bool:
    """Decide whether a tensor must always stay in BF16 (never quantized).

    Input/output projections, normalization layers, and biases are kept at
    full precision: they are small relative to the model yet sensitive.
    """
    critical_markers = ('img_in', 'txt_in', 'proj_out')
    hits_marker = any(marker in key for marker in critical_markers)
    return hits_marker or 'norm' in key or key.endswith('.bias')
def quantize_comfy_native(tensor: torch.Tensor):
"""
ComfyUI ops.py ๊ท๊ฒฉ์ ๋ง๊ฒ weight์ weight_scale์ ์์ฑํฉ๋๋ค.
"""
if tensor.dtype == torch.bfloat16:
tensor = tensor.float()
max_val = tensor.abs().max().item()
if max_val == 0:
return torch.zeros_like(tensor, dtype=torch.float8_e4m3fn), torch.tensor(1.0, dtype=torch.bfloat16)
# FP8_e4m3fn max is 448.0
# weight_scale = max_val / 448.0
# weight = tensor / weight_scale (casted to fp8)
scale = max_val / 448.0
normalized = tensor / scale
quantized_weight = normalized.to(torch.float8_e4m3fn)
weight_scale = torch.tensor(scale, dtype=torch.bfloat16)
return quantized_weight, weight_scale
def main():
    """Quantize a BF16 checkpoint into ComfyUI-native mixed FP8/BF16.

    Layers flagged as critical (plus the most sensitive layers from the
    sensitivity report, up to a parameter budget) stay in BF16; every
    other tensor is stored as FP8 e4m3fn with a per-tensor BF16
    ``weight_scale`` and a ``comfy_quant`` metadata tensor, matching the
    format ComfyUI's ops.py expects.
    """
    input_path = "FireRed-Image-Edit-1.0_bf16.safetensors"
    output_path = "FireRed-Image-Edit-1.0_fp8_comfy.safetensors"
    scores = load_sensitivity_scores()
    print("Selecting layers for 22GB budget (ComfyUI Native Format)...")
    protected_keys = set()
    with safe_open(input_path, framework="pt", device="cpu") as f:
        # Materialize the key list so it stays usable after this context.
        all_keys = list(f.keys())
        for key in all_keys:
            if is_base_important(key):
                protected_keys.add(key)
        total_params = sum(f.get_tensor(k).numel() for k in all_keys)
        # A BF16 weight_scale (2 bytes) is added per quantized tensor, so
        # keep the BF16 budget slightly tight: FP8 (1 byte) + scale
        # (2 bytes/tensor) means many small tensors inflate the file a bit.
        budget_params = total_params * 0.125
        current_bf16_params = 0
        for key in protected_keys:
            current_bf16_params += f.get_tensor(key).numel()
        # Greedily protect the remaining layers in report order until the
        # budget runs out (report is assumed sorted, most sensitive first).
        for key, score, stats in scores:
            if key in protected_keys:
                continue
            tensor_size = f.get_tensor(key).numel()
            if current_bf16_params + tensor_size < budget_params:
                protected_keys.add(key)
                current_bf16_params += tensor_size
            else:
                break
    print(f"Protection plan: {len(protected_keys)} layers in BF16 (~{current_bf16_params/1e9:.2f}B params)")
    quantized_state_dict = {}
    bf16_count = 0
    fp8_count = 0
    # ComfyUI comfy_quant metadata
    quant_meta = {"format": "float8_e4m3fn", "full_precision_matrix_mult": True}
    quant_meta_bytes = json.dumps(quant_meta).encode('utf-8')
    with safe_open(input_path, framework="pt", device="cpu") as f:
        for idx, key in enumerate(all_keys):
            tensor = f.get_tensor(key)
            # Strip only a trailing "weight" to get the layer prefix.
            # (str.replace would also clobber "weight" appearing earlier
            # in the key, e.g. "weight_proj.weight" -> "_proj.".)
            prefix = key.removesuffix("weight")
            if key in protected_keys:
                quantized_state_dict[key] = tensor.to(torch.bfloat16)
                bf16_count += 1
            else:
                # ComfyUI Native Quantization: weight + weight_scale + comfy_quant
                q_weight, q_scale = quantize_comfy_native(tensor)
                quantized_state_dict[key] = q_weight
                # ops.py expects weight_scale at the same prefix
                quantized_state_dict[f"{prefix}weight_scale"] = q_scale
                # Build a fresh tensor each time so safetensors does not
                # reject the state dict for shared-memory storage.
                quantized_state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(quant_meta_bytes), dtype=torch.uint8)
                fp8_count += 1
            if (idx + 1) % 100 == 0:
                print(f"Processed {idx + 1}/{len(all_keys)} tensors...", end='\r')
    metadata = {
        "quantization": "fp8_e4m3fn",
        "format": "comfyui_native_mixed_precision",
        "full_precision_matrix_mult": "true"
    }
    print(f"\nSaving to {output_path}...")
    # save_file can be slow for many small tensors, but necessary for safetensors format
    save_file(quantized_state_dict, output_path, metadata=metadata)
    final_size_gb = os.path.getsize(output_path) / (1024**3)
    print("\nFinal Results:")
    print(f" - BF16 Tensors: {bf16_count}")
    print(f" - FP8 (Comfy Native) Tensors: {fp8_count}")
    print(f" - Final File Size: {final_size_gb:.2f} GB")
if __name__ == "__main__":
    main()
cocorang pinned discussion