How did you create the FP8 mix?

#3
by Nuke1229 - opened

Hi, thanks for sharing. I'm interested in the FP8 mix version. Could you please explain how you created the FP8 mix?

I used the following script:

import gc
import json
import math
import os

import torch
from safetensors import safe_open
from safetensors.torch import save_file

def load_sensitivity_scores(path="sensitivity_report.json"):
    """Load the per-layer sensitivity report written as JSON.

    Args:
        path: Location of the JSON report.

    Returns:
        The decoded JSON value, or an empty list when *path* does not exist so
        that callers can treat "no report" as "no scored layers".
    """
    if not os.path.exists(path):
        return []
    # JSON is UTF-8 by specification; without an explicit encoding, open() uses
    # the locale default (e.g. cp949 on Korean Windows) and can mis-decode it.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def is_base_important(key: str) -> bool:
    """Return True if the tensor named *key* must always stay in BF16.

    A key qualifies when it ends with '.bias' or contains any of the substrings
    'img_in', 'txt_in', 'proj_out' or 'norm'.
    """
    always_kept = ('img_in', 'txt_in', 'proj_out', 'norm')
    return key.endswith('.bias') or any(marker in key for marker in always_kept)

def quantize_comfy_native(tensor: torch.Tensor):
    """
    ComfyUI ops.py ๊ทœ๊ฒฉ์— ๋งž๊ฒŒ weight์™€ weight_scale์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
    """
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    
    max_val = tensor.abs().max().item()
    if max_val == 0:
        return torch.zeros_like(tensor, dtype=torch.float8_e4m3fn), torch.tensor(1.0, dtype=torch.bfloat16)
    
    # FP8_e4m3fn max is 448.0
    # weight_scale = max_val / 448.0
    # weight = tensor / weight_scale (casted to fp8)
    scale = max_val / 448.0
    normalized = tensor / scale
    quantized_weight = normalized.to(torch.float8_e4m3fn)
    weight_scale = torch.tensor(scale, dtype=torch.bfloat16)
    
    return quantized_weight, weight_scale

def _tensor_numel(handle, key):
    """Return the number of elements of tensor *key*, read from the file header only."""
    # get_slice() exposes the shape without materializing the data, unlike get_tensor().
    return math.prod(handle.get_slice(key).get_shape())


def _plan_protected_keys(handle, all_keys, scores):
    """Choose which tensors stay in BF16.

    Returns ``(protected_keys, bf16_params)``. The set holds every key accepted by
    ``is_base_important``, plus scored keys taken greedily in the order of
    *scores* while the BF16 parameter count stays below 12.5% of all
    parameters. It stops at the first scored key that would not fit. Scored keys
    that are absent from the file are skipped.
    """
    numel = {key: _tensor_numel(handle, key) for key in all_keys}
    protected_keys = {key for key in all_keys if is_base_important(key)}

    # The budget is a fraction of the parameter count, not a byte size. Each FP8
    # tensor also gets a BF16 weight_scale (2 bytes per tensor) on top of its
    # 1-byte-per-element payload, so many FP8 tensors grow the file slightly.
    # The budget is kept a little tight to leave room for that.
    budget_params = sum(numel.values()) * 0.125

    bf16_params = sum(numel[key] for key in protected_keys)
    # NOTE(review): assumes each entry is a (key, score, stats) triple, presumably
    # ordered most-sensitive first by the tool that wrote the report.
    for key, _score, _stats in scores:
        if key in protected_keys or key not in numel:
            continue
        tensor_size = numel[key]
        if bf16_params + tensor_size < budget_params:
            protected_keys.add(key)
            bf16_params += tensor_size
        else:
            break
    return protected_keys, bf16_params


def _is_quantizable(key, tensor):
    """Return True if *tensor* can be written as an FP8 weight with sibling scale/metadata keys."""
    # Sibling keys are built by replacing the trailing 'weight' of the key, so
    # only floating-point '<layer>.weight' matrices qualify. Anything else
    # (vectors, non-'.weight' parameters, integer buffers) stays unquantized.
    return key.endswith(".weight") and tensor.is_floating_point() and tensor.ndim >= 2


def main():
    """Convert the BF16 checkpoint into a mixed BF16/FP8 file in ComfyUI's native layout.

    Sensitive tensors (see ``_plan_protected_keys``) are kept in BF16. Every
    other quantizable weight is stored as FP8 together with a ``weight_scale``
    and a ``comfy_quant`` JSON blob under the same layer prefix.
    """
    input_path = "FireRed-Image-Edit-1.0_bf16.safetensors"
    output_path = "FireRed-Image-Edit-1.0_fp8_comfy.safetensors"

    scores = load_sensitivity_scores()

    print(f"Selecting layers for 22GB budget (ComfyUI Native Format)...")

    with safe_open(input_path, framework="pt", device="cpu") as f:
        all_keys = list(f.keys())
        protected_keys, current_bf16_params = _plan_protected_keys(f, all_keys, scores)

    print(f"Protection plan: {len(protected_keys)} layers in BF16 (~{current_bf16_params/1e9:.2f}B params)")

    quantized_state_dict = {}
    bf16_count = 0
    fp8_count = 0

    # ComfyUI comfy_quant metadata, stored per layer as a uint8 tensor of JSON bytes.
    quant_meta = {"format": "float8_e4m3fn", "full_precision_matrix_mult": True}
    quant_meta_bytes = json.dumps(quant_meta).encode('utf-8')

    with safe_open(input_path, framework="pt", device="cpu") as f:
        for idx, key in enumerate(all_keys):
            tensor = f.get_tensor(key)

            if key in protected_keys or not _is_quantizable(key, tensor):
                # Keep full precision; only floating tensors are cast, so integer
                # buffers are not silently converted.
                quantized_state_dict[key] = tensor.to(torch.bfloat16) if tensor.is_floating_point() else tensor
                bf16_count += 1
            else:
                # Layer prefix: strip only the trailing 'weight' (key ends with
                # '.weight' here) so other occurrences inside the key survive.
                prefix = key[: -len("weight")]

                # ComfyUI native quantization: weight + weight_scale + comfy_quant.
                q_weight, q_scale = quantize_comfy_native(tensor)
                quantized_state_dict[key] = q_weight
                # ops.py expects weight_scale at the same prefix.
                quantized_state_dict[f"{prefix}weight_scale"] = q_scale
                # Build a fresh tensor every time. Reusing one object would make
                # entries share storage, which save_file rejects.
                quantized_state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(quant_meta_bytes), dtype=torch.uint8)
                fp8_count += 1

            if (idx + 1) % 100 == 0:
                print(f"Processed {idx + 1}/{len(all_keys)} tensors...", end='\r')

    metadata = {
        "quantization": "fp8_e4m3fn",
        "format": "comfyui_native_mixed_precision",
        "full_precision_matrix_mult": "true"
    }

    print(f"\nSaving to {output_path}...")
    # save_file can be slow for many small tensors, but necessary for safetensors format
    save_file(quantized_state_dict, output_path, metadata=metadata)

    final_size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"\nFinal Results:")
    print(f"  - BF16 Tensors: {bf16_count}")
    print(f"  - FP8 (Comfy Native) Tensors: {fp8_count}")
    print(f"  - Final File Size: {final_size_gb:.2f} GB")

# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
cocorang pinned this discussion

Sign up or log in to comment