---
library_name: diffusers
tags:
- fp8
- safetensors
- precision-recovery
- diffusion
- converted-by-gradio
---
# FP8 Model with Precision Recovery
- **Source**: `https://huggingface.co/LifuWang/DistillT5`
- **File**: `model.safetensors`
- **FP8 Format**: `E5M2`
- **Architecture**: all
- **Precision Recovery Type**: LoRA
- **Precision Recovery File**: `model-lora-r64-all.safetensors` if available
- **FP8 File**: `model-fp8-e5m2.safetensors`
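E5M2 stores 1 sign bit, 5 exponent bits, and 2 mantissa bits. Its 5-bit exponent (bias 15) is the same as IEEE float16's, so an E5M2 value is the upper byte of a float16 bit pattern. The following minimal NumPy sketch uses that property to simulate an E5M2 round trip, so you can see the rounding this model's weights went through. `quantize_e5m2` is a hypothetical helper and is not how the converter produced `model-fp8-e5m2.safetensors`. It rounds float32 to float16 first, which can double-round in rare tie cases, and it does not preserve NaN payloads.

```python
import numpy as np


def quantize_e5m2(x: np.ndarray) -> np.ndarray:
    """Simulate an FP8 E5M2 round trip (round-to-nearest-even), returning float32.

    E5M2 shares float16's 5-bit exponent, so keeping the top 8 bits of a float16
    bit pattern (1 sign + 5 exponent + 2 mantissa) yields an E5M2 value.
    Assumes finite inputs; float32 -> float16 -> E5M2 may double-round in rare ties.
    """
    h = np.asarray(x, dtype=np.float32).astype(np.float16)
    bits = h.view(np.uint16).astype(np.uint32)
    # Round to nearest, ties to even, on the 8 low bits that are discarded.
    bias = 0x7F + ((bits >> 8) & 1)
    rounded = ((bits + bias) & 0xFF00).astype(np.uint16)
    return rounded.view(np.float16).astype(np.float32)


# E5M2 values between 1 and 2 are 1.0, 1.25, 1.5, 1.75
print(quantize_e5m2(np.array([1.0, 0.3, 1.1, 1.125, 1.375])))
```

With only 2 mantissa bits, the worst-case relative rounding error for normal values is 2^-3 = 12.5%. That coarse rounding is why a precision-recovery file is worth shipping alongside the FP8 weights.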

## Usage (Inference)
```python
import os

import torch
from safetensors.torch import load_file

FP8_FILE = "model-fp8-e5m2.safetensors"
RECOVERY_FILE = "model-lora-r64-all.safetensors"

# Load the FP8 (E5M2) weights
fp8_state = load_file(FP8_FILE)

# Load the precision-recovery file only if it exists on disk
recovery_state = {}
if os.path.exists(RECOVERY_FILE):
    recovery_state = load_file(RECOVERY_FILE)

# Reconstruct higher-precision weights
reconstructed = {}
for key, tensor in fp8_state.items():
    # Dequantize FP8 to float32
    fp_weight = tensor.to(torch.float32)

    lora_a_key, lora_b_key = f"lora_A.{key}", f"lora_B.{key}"
    correction_key = f"correction.{key}"

    if lora_a_key in recovery_state and lora_b_key in recovery_state:
        # LoRA approach: add the low-rank error correction B @ A
        A = recovery_state[lora_a_key].to(torch.float32)
        B = recovery_state[lora_b_key].to(torch.float32)
        reconstructed[key] = fp_weight + B @ A
    elif correction_key in recovery_state:
        # Correction-factor approach: add a dense correction tensor
        correction = recovery_state[correction_key].to(torch.float32)
        reconstructed[key] = fp_weight + correction
    else:
        # No correction stored for this tensor
        reconstructed[key] = fp_weight

print(f"Reconstructed {len(reconstructed)} tensors with FP8 error recovery")
```
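The card does not say how the rank-64 LoRA factors were computed. One common way to build such a correction, assumed here, is a truncated SVD of the quantization residual `W - dequant(fp8(W))`. By the Eckart–Young theorem, this gives the best rank-r approximation of the residual in the Frobenius norm. The sketch uses factor shapes (`A` is `(r, in)`, `B` is `(out, r)`) that match the `B @ A` product in the usage code above. `lora_from_residual` is a hypothetical helper. In the usage lines, `np.round(w * 4) / 4` is a stand-in quantizer, not E5M2.

```python
import numpy as np


def lora_from_residual(w: np.ndarray, w_q: np.ndarray, rank: int = 64):
    """Fit a rank-`rank` correction so that w_q + B @ A approximates w.

    w:   original high-precision weight, shape (out_features, in_features)
    w_q: the same weight after quantization, cast back to float32
    Returns (A, B) with shapes (r, in_features) and (out_features, r).
    """
    residual = (w - w_q).astype(np.float64)
    u, s, vt = np.linalg.svd(residual, full_matrices=False)
    r = min(rank, s.shape[0])
    b = (u[:, :r] * s[:r]).astype(np.float32)  # (out, r): left vectors scaled by singular values
    a = vt[:r, :].astype(np.float32)            # (r, in): right singular vectors
    return a, b


rng = np.random.default_rng(0)
w = rng.standard_normal((256, 128)).astype(np.float32)
w_q = (np.round(w * 4) / 4).astype(np.float32)  # stand-in quantizer (hypothetical)

a, b = lora_from_residual(w, w_q, rank=64)
before = np.linalg.norm(w - w_q)
after = np.linalg.norm(w - (w_q + b @ a))
print(f"residual norm: {before:.3f} -> {after:.3f}")
```

The rank trades file size against accuracy. At full rank the correction reproduces the residual exactly. A low rank such as 64 keeps the recovery file small and removes only the dominant error directions.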

> **Note**: The precision-recovery weights compensate for errors introduced by FP8 quantization.
>
> Average quantization error: 0.052733
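The card does not define how the average quantization error was measured. One plausible definition is the mean, over tensors, of each tensor's mean absolute difference between original and FP8-dequantized weights. A minimal sketch of that definition follows. `average_quantization_error` is a hypothetical helper, and it may not reproduce the 0.052733 figure exactly.

```python
import numpy as np


def average_quantization_error(original: dict, dequantized: dict) -> float:
    """Mean over shared keys of each tensor's mean absolute error (assumed metric)."""
    errors = [
        float(np.mean(np.abs(np.asarray(original[k], dtype=np.float32)
                             - np.asarray(dequantized[k], dtype=np.float32))))
        for k in original
        if k in dequantized
    ]
    return sum(errors) / len(errors) if errors else 0.0


orig = {"a": np.array([1.0, 2.0]), "b": np.array([0.0, 0.0])}
deq = {"a": np.array([1.5, 2.0]), "b": np.array([0.0, 1.0])}
print(average_quantization_error(orig, deq))
```

Averaging per-tensor means weights every tensor equally regardless of size. A size-weighted mean over all elements is an equally reasonable alternative and can give a different number.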