---
library_name: diffusers
tags:
- fp8
- safetensors
- quantization
- precision-recovery
- diffusion
- converted-by-gradio
---
# FP8 Model with Precision Recovery
- **Source**: `https://huggingface.co/LifuWang/DistillT5`
- **File**: `model.safetensors`
- **FP8 Format**: `E5M2`
- **Correction Mode**: per_tensor
- **Correction File**: `model-correction.safetensors`
- **FP8 File**: `model-fp8-e5m2.safetensors`

## Usage (Inference)
```python
import os

import torch
from safetensors.torch import load_file

# Load FP8 model and correction factors
fp8_state = load_file("model-fp8-e5m2.safetensors")
correction_state = (
    load_file("model-correction.safetensors")
    if os.path.exists("model-correction.safetensors")
    else {}
)

# Reconstruct high-precision weights
reconstructed = {}
for key in fp8_state:
    fp8_weight = fp8_state[key].to(torch.float32)
    
    # Apply correction if available
    correction_key = f"correction.{key}"
    if correction_key in correction_state:
        correction = correction_state[correction_key].to(torch.float32)
        reconstructed[key] = fp8_weight + correction
    else:
        reconstructed[key] = fp8_weight

# Load the reconstructed weights into your instantiated model (nn.Module)
model.load_state_dict(reconstructed)
```

## Correction Modes
- **Per-Channel**: Computes mean correction per output channel (best for most layers)
- **Per-Tensor**: Single correction value per tensor (lightweight)
- **None**: No correction (pure FP8)

> Requires PyTorch ≥ 2.1 for FP8 support. For best quality, use the correction file during inference.