codemichaeld committed on
Commit ce144ee · verified · 1 Parent(s): 93f0319

Upload README.md with huggingface_hub

Files changed (1):
  1. README.md +26 -19
README.md CHANGED
@@ -3,41 +3,48 @@ library_name: diffusers
 tags:
 - fp8
 - safetensors
-- lora
-- low-rank
 - diffusion
 - converted-by-gradio
 ---
-
-# FP8 Model with Low-Rank LoRA
-
 - **Source**: `https://huggingface.co/LifuWang/DistillT5`
 - **File**: `model.safetensors`
 - **FP8 Format**: `E5M2`
-- **LoRA Rank**: 64
-- **LoRA File**: `model-lora-r64.safetensors`
 
 ## Usage (Inference)
-
 ```python
 from safetensors.torch import load_file
 import torch
 
-# Load FP8 model
 fp8_state = load_file("model-fp8-e5m2.safetensors")
-lora_state = load_file("model-lora-r64.safetensors")
 
-# Reconstruct approximate original weights
 reconstructed = {}
 for key in fp8_state:
-    if f"lora_A.{key}" in lora_state and f"lora_B.{key}" in lora_state:
-        A = lora_state[f"lora_A.{key}"].to(torch.float32)
-        B = lora_state[f"lora_B.{key}"].to(torch.float32)
-        lora_weight = B @ A  # (out, rank) @ (rank, in) -> (out, in)
-        fp8_weight = fp8_state[key].to(torch.float32)
-        reconstructed[key] = fp8_weight + lora_weight
     else:
-        reconstructed[key] = fp8_state[key].to(torch.float32)
 ```
 
-> Requires PyTorch ≥ 2.1 for FP8 support.
 
 
 
 
 
 
 tags:
 - fp8
 - safetensors
+- quantization
+- precision-recovery
 - diffusion
 - converted-by-gradio
 ---
+# FP8 Model with Precision Recovery
 
 - **Source**: `https://huggingface.co/LifuWang/DistillT5`
 - **File**: `model.safetensors`
 - **FP8 Format**: `E5M2`
+- **Correction Mode**: per_tensor
+- **Correction File**: `model-correction.safetensors`
+- **FP8 File**: `model-fp8-e5m2.safetensors`
 
 ## Usage (Inference)
 ```python
 from safetensors.torch import load_file
+import os
 import torch
 
+# Load FP8 model and correction factors
 fp8_state = load_file("model-fp8-e5m2.safetensors")
+correction_state = load_file("model-correction.safetensors") if os.path.exists("model-correction.safetensors") else {}
 
+# Reconstruct high-precision weights
 reconstructed = {}
 for key in fp8_state:
+    fp8_weight = fp8_state[key].to(torch.float32)
+
+    # Apply correction if available
+    correction_key = f"correction.{key}"
+    if correction_key in correction_state:
+        correction = correction_state[correction_key].to(torch.float32)
+        reconstructed[key] = fp8_weight + correction
     else:
+        reconstructed[key] = fp8_weight
+
+# Use the reconstructed weights in your (already instantiated) model
+model.load_state_dict(reconstructed)
 ```
 
+## Correction Modes
+- **Per-Channel**: Computes a mean correction per output channel (best for most layers)
+- **Per-Tensor**: A single correction value per tensor (lightweight)
+- **None**: No correction (pure FP8)
+
+> Requires PyTorch ≥ 2.1 for FP8 support. For best quality, use the correction file during inference.