---
library_name: diffusers
tags:
  - fp8
  - safetensors
  - quantization
  - precision-recovery
  - diffusion
  - converted-by-gradio
---

# FP8 Model with Precision Recovery

  • Source: https://huggingface.co/LifuWang/DistillT5
  • File: model.safetensors
  • FP8 Format: E5M2
  • Correction Mode: per_tensor
  • Correction File: model-correction.safetensors
  • FP8 File: model-fp8-e5m2.safetensors

## Usage (Inference)

```python
import os

import torch
from safetensors.torch import load_file

# Load the FP8 weights and, if present, the correction factors
fp8_state = load_file("model-fp8-e5m2.safetensors")
correction_state = (
    load_file("model-correction.safetensors")
    if os.path.exists("model-correction.safetensors")
    else {}
)

# Reconstruct higher-precision weights: upcast FP8 to float32,
# then add the stored correction where one exists for the tensor
reconstructed = {}
for key, fp8_weight in fp8_state.items():
    weight = fp8_weight.to(torch.float32)
    correction_key = f"correction.{key}"
    if correction_key in correction_state:
        weight = weight + correction_state[correction_key].to(torch.float32)
    reconstructed[key] = weight

# Load into an already-instantiated model with a matching architecture
model.load_state_dict(reconstructed)
```

## Correction Modes

  • Per-Channel: Computes mean correction per output channel (best for most layers)
  • Per-Tensor: Single correction value per tensor (lightweight)
  • None: No correction (pure FP8)

FP8 tensor dtypes (`torch.float8_e5m2`) require PyTorch ≥ 2.1. For best quality, apply the correction file during inference.