| --- |
| library_name: diffusers |
| tags: |
| - fp8 |
| - safetensors |
| - quantization |
| - precision-recovery |
| - diffusion |
| - converted-by-gradio |
| --- |
| # FP8 Model with Precision Recovery |
| - **Source**: `https://huggingface.co/LifuWang/DistillT5` |
| - **File**: `model.safetensors` |
| - **FP8 Format**: `E5M2` |
| - **Correction Mode**: per_channel |
| - **Correction File**: `model-correction.safetensors` |
| - **FP8 File**: `model-fp8-e5m2.safetensors` |
| |
| ## Usage (Inference) |
```python
import os

import torch
from safetensors.torch import load_file

# Load the FP8 model and, if present, the correction factors
fp8_state = load_file("model-fp8-e5m2.safetensors")
correction_state = (
    load_file("model-correction.safetensors")
    if os.path.exists("model-correction.safetensors")
    else {}
)

# Reconstruct higher-precision weights
reconstructed = {}
for key, tensor in fp8_state.items():
    fp8_weight = tensor.to(torch.float32)

    # Apply the correction factor if one exists for this tensor
    correction_key = f"correction.{key}"
    if correction_key in correction_state:
        correction = correction_state[correction_key].to(torch.float32)
        reconstructed[key] = fp8_weight + correction
    else:
        reconstructed[key] = fp8_weight

# Load the reconstructed weights into your instantiated model
# (`model` must be constructed separately with a matching architecture)
model.load_state_dict(reconstructed)
```
| |
| ## Correction Modes |
| - **Per-Channel**: Computes mean correction per output channel (best for most layers) |
| - **Per-Tensor**: Single correction value per tensor (lightweight) |
| - **None**: No correction (pure FP8) |
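The correction factors listed above are the mean quantization error, averaged either per output channel or over the whole tensor. A minimal sketch of how such a factor could be computed (the `compute_correction` helper is illustrative, not part of this repository):

```python
import torch


def compute_correction(original: torch.Tensor, quantized: torch.Tensor,
                       mode: str = "per_channel") -> torch.Tensor:
    """Mean quantization error, per output channel or per tensor."""
    error = original.float() - quantized.float()
    if mode == "per_channel" and error.dim() >= 2:
        # Average over all dims except dim 0 (output channels);
        # keepdim=True makes the result broadcast back onto the weight.
        dims = tuple(range(1, error.dim()))
        return error.mean(dim=dims, keepdim=True)
    # Per-tensor: a single scalar correction
    return error.mean()
```

Adding the correction back (`quantized + correction`) removes the per-channel mean error, which is why per-channel mode tends to recover the most precision for weight matrices.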
| |
> PyTorch ≥ 2.1 is required to load FP8 (`float8_e5m2`) tensors. For best quality, apply the correction file during inference as shown above.
| |