---
library_name: diffusers
tags:
- fp8
- safetensors
- quantization
- precision-recovery
- diffusion
- converted-by-gradio
---

# FP8 Model with Precision Recovery

- **Source**: `https://huggingface.co/LifuWang/DistillT5`
- **File**: `model.safetensors`
- **FP8 Format**: `E5M2`
- **Correction Mode**: `per_tensor`
- **Correction File**: `model-correction.safetensors`
- **FP8 File**: `model-fp8-e5m2.safetensors`

## Usage (Inference)

```python
import os

import torch
from safetensors.torch import load_file

# Load the FP8 weights and, if present, the correction factors
fp8_state = load_file("model-fp8-e5m2.safetensors")
correction_path = "model-correction.safetensors"
correction_state = load_file(correction_path) if os.path.exists(correction_path) else {}

# Reconstruct higher-precision (float32) weights
reconstructed = {}
for key, tensor in fp8_state.items():
    fp8_weight = tensor.to(torch.float32)

    # Apply the additive correction if one exists for this tensor
    correction_key = f"correction.{key}"
    if correction_key in correction_state:
        correction = correction_state[correction_key].to(torch.float32)
        reconstructed[key] = fp8_weight + correction
    else:
        reconstructed[key] = fp8_weight

# Load the reconstructed weights into your model.
# `model` is not defined here: it must be an already-instantiated module
# whose architecture matches the source checkpoint.
model.load_state_dict(reconstructed)
```

## Correction Modes

- **Per-Channel**: Computes a separate mean correction for each output channel (finer-grained, but a larger correction file)
- **Per-Tensor**: A single correction value per tensor (lightweight)
- **None**: No correction (pure FP8)

> Requires PyTorch ≥ 2.1 for FP8 (`torch.float8_e5m2`) support. For best quality, apply the correction file during inference.
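
A quick way to confirm that your PyTorch build meets this requirement before loading the FP8 file; a minimal sketch in which `fp8_e5m2_available` is a hypothetical helper name:

```python
import torch


def fp8_e5m2_available() -> bool:
    """Return True if this PyTorch build exposes torch.float8_e5m2 and can cast to it."""
    if not hasattr(torch, "float8_e5m2"):
        return False  # dtype absent: PyTorch older than 2.1
    try:
        torch.zeros(1).to(torch.float8_e5m2)
    except (RuntimeError, TypeError):
        return False
    return True


print(f"PyTorch {torch.__version__}, E5M2 available: {fp8_e5m2_available()}")
```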