codemichaeld committed on
Commit 9f1faf9 · verified · 1 Parent(s): 2bb0e42

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +35 -15
README.md CHANGED
@@ -3,36 +3,56 @@ library_name: diffusers
  tags:
  - fp8
  - safetensors
- - lora
- - low-rank
+ - precision-recovery
  - diffusion
  - converted-by-gradio
  ---
- # FP8 Model with Low-Rank LoRA
+ # FP8 Model with Precision Recovery
  - **Source**: `https://huggingface.co/LifuWang/DistillT5`
  - **File**: `model.safetensors`
  - **FP8 Format**: `E5M2`
- - **LoRA Rank**: 128
  - **Architecture**: text_encoder
- - **LoRA File**: `model-lora-r128.safetensors`
+ - **Precision Recovery Type**: LoRA
+ - **Precision Recovery File**: `model-lora-r64-text_encoder.safetensors` (if available)
  - **FP8 File**: `model-fp8-e5m2.safetensors`
+
  ## Usage (Inference)
  ```python
  from safetensors.torch import load_file
  import torch
+ import os
+
  # Load FP8 model
  fp8_state = load_file("model-fp8-e5m2.safetensors")
- lora_state = load_file("model-lora-r128.safetensors")
- # Reconstruct approximate original weights
+
+ # Load the precision-recovery file if it is present
+ recovery_state = {}
+ if os.path.exists("model-lora-r64-text_encoder.safetensors"):
+     recovery_state = load_file("model-lora-r64-text_encoder.safetensors")
+
+ # Reconstruct high-precision weights
  reconstructed = {}
  for key in fp8_state:
-     if f"lora_A.{key}" in lora_state and f"lora_B.{key}" in lora_state:
-         A = lora_state[f"lora_A.{key}"].to(torch.float32)
-         B = lora_state[f"lora_B.{key}"].to(torch.float32)
-         lora_weight = B @ A  # (out_features, rank) @ (rank, in_features) -> (out_features, in_features)
-         fp8_weight = fp8_state[key].to(torch.float32)
-         reconstructed[key] = fp8_weight + lora_weight
+     # Dequantize FP8 to the target precision
+     fp_weight = fp8_state[key].to(torch.float32)
+
+     if recovery_state:
+         # LoRA approach: add the low-rank error correction B @ A
+         if f"lora_A.{key}" in recovery_state and f"lora_B.{key}" in recovery_state:
+             A = recovery_state[f"lora_A.{key}"].to(torch.float32)
+             B = recovery_state[f"lora_B.{key}"].to(torch.float32)
+             error_correction = B @ A  # (out_features, rank) @ (rank, in_features)
+             reconstructed[key] = fp_weight + error_correction
+         # Correction-factor approach: add a dense correction tensor
+         elif f"correction.{key}" in recovery_state:
+             correction = recovery_state[f"correction.{key}"].to(torch.float32)
+             reconstructed[key] = fp_weight + correction
+         else:
+             reconstructed[key] = fp_weight
      else:
-         reconstructed[key] = fp8_state[key].to(torch.float32)
+         reconstructed[key] = fp_weight
+
+ print("Model reconstructed with FP8 error recovery")
  ```
- > Requires PyTorch ≥ 2.1 for FP8 support.
+
+ > **Note**: This precision recovery targets FP8 quantization errors.
+ > Average quantization error: 0.052733
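
The conversion pipeline that produced `model-fp8-e5m2.safetensors` and the recovery file is not part of this commit. As a rough illustration only, the sketch below quantizes each weight to FP8 E5M2 (PyTorch ≥ 2.1) and derives a rank-64 LoRA correction from the quantization residual with a truncated SVD; the rank, file names, and the SVD-based method are assumptions inferred from the README, not the uploader's actual script.

```python
import torch
from safetensors.torch import load_file, save_file

RANK = 64  # assumed from the "r64" file suffix; actual conversion settings are undocumented


def low_rank_correction(residual: torch.Tensor, rank: int):
    """Return (A, B) such that B @ A approximates `residual`, via truncated SVD."""
    U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
    r = min(rank, S.numel())
    B = (U[:, :r] * S[:r]).contiguous()  # (out_features, r)
    A = Vh[:r, :].contiguous()           # (r, in_features)
    return A, B


# Original high-precision checkpoint from the source repository
original = load_file("model.safetensors")

fp8_state, recovery = {}, {}
for key, w in original.items():
    w32 = w.to(torch.float32)
    w_fp8 = w32.to(torch.float8_e5m2)    # FP8 E5M2 cast (PyTorch >= 2.1)
    fp8_state[key] = w_fp8
    residual = w32 - w_fp8.to(torch.float32)
    if residual.dim() == 2:              # low-rank correction only for 2-D weight matrices
        A, B = low_rank_correction(residual, RANK)
        recovery[f"lora_A.{key}"] = A
        recovery[f"lora_B.{key}"] = B

save_file(fp8_state, "model-fp8-e5m2.safetensors")
save_file(recovery, "model-lora-r64-text_encoder.safetensors")
```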
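
The README reports an average quantization error of 0.052733 without defining the metric. One plausible reading is the mean absolute difference between the original weights and their FP8 round-trip, as in this hedged sketch (the exact metric used by the converter is an assumption):

```python
import torch
from safetensors.torch import load_file

original = load_file("model.safetensors")             # high-precision source weights
fp8_state = load_file("model-fp8-e5m2.safetensors")   # FP8 E5M2 copy from this repo

total_abs_err = 0.0
total_elems = 0
for key, w in original.items():
    if key not in fp8_state:
        continue
    # Element-wise absolute error between original and dequantized FP8 weights
    diff = (w.to(torch.float32) - fp8_state[key].to(torch.float32)).abs()
    total_abs_err += diff.sum().item()
    total_elems += diff.numel()

print(f"Average quantization error: {total_abs_err / total_elems:.6f}")
```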