Update src/pipeline.py
Browse files — src/pipeline.py (+38 −2)
src/pipeline.py
CHANGED
|
@@ -14,6 +14,12 @@ import time
|
|
| 14 |
from diffusers import FluxTransformer2DModel, DiffusionPipeline
|
| 15 |
# from torchao.quantization import quantize_,int8_weight_only
|
| 16 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
|
| 18 |
Pipeline = None
|
| 19 |
|
|
@@ -44,8 +50,38 @@ def load_pipeline() -> Pipeline:
|
|
| 44 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 45 |
torch.cuda.set_per_process_memory_fraction(0.99)
|
| 46 |
pipeline.text_encoder.to(memory_format=torch.channels_last)
|
| 47 |
-
pipeline.transformer.to(memory_format=torch.channels_last)
|
| 48 |
-
quantize_dynamic(pipeline.transformer, dtype=torch.float8_e5m2fnuz, inplace=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
pipeline.vae.to(memory_format=torch.channels_last)
|
| 51 |
pipeline.vae = torch.compile(pipeline.vae)
|
|
|
|
| 14 |
from diffusers import FluxTransformer2DModel, DiffusionPipeline
|
| 15 |
# from torchao.quantization import quantize_,int8_weight_only
|
| 16 |
import os
|
| 17 |
+
|
| 18 |
+
from torch.ao.quantization import prepare, convert
|
| 19 |
+
from torch.ao.quantization import QConfig
|
| 20 |
+
from torch.ao.quantization.observer import MinMaxObserver
|
| 21 |
+
from torch.ao.quantization.quantize import quantize_dynamic
|
| 22 |
+
|
| 23 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
|
| 24 |
Pipeline = None
|
| 25 |
|
|
|
|
| 50 |
torch.backends.cuda.matmul.allow_tf32 = True
|
| 51 |
torch.cuda.set_per_process_memory_fraction(0.99)
|
| 52 |
pipeline.text_encoder.to(memory_format=torch.channels_last)
|
| 53 |
+
# pipeline.transformer.to(memory_format=torch.channels_last)
|
| 54 |
+
# quantize_dynamic(pipeline.transformer, dtype=torch.float8_e5m2fnuz, inplace=True)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Define a custom qconfig for float8_e5m2fnuz
|
| 58 |
+
float8_observer = MinMaxObserver.with_args(dtype=torch.float8_e5m2fnuz)
|
| 59 |
+
custom_qconfig = QConfig(
|
| 60 |
+
activation=float8_observer,
|
| 61 |
+
weight=float8_observer
|
| 62 |
+
)
|
| 63 |
+
qconfig_spec = {
|
| 64 |
+
"linear": custom_qconfig,
|
| 65 |
+
"linear_1": custom_qconfig,
|
| 66 |
+
"linear_2": custom_qconfig,
|
| 67 |
+
"to_q": custom_qconfig,
|
| 68 |
+
"to_k": custom_qconfig,
|
| 69 |
+
"to_v": custom_qconfig,
|
| 70 |
+
"add_k_proj": custom_qconfig,
|
| 71 |
+
"add_v_proj": custom_qconfig,
|
| 72 |
+
"add_q_proj": custom_qconfig,
|
| 73 |
+
"proj": custom_qconfig,
|
| 74 |
+
"proj_mlp": custom_qconfig,
|
| 75 |
+
"proj_out": custom_qconfig
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
# Apply dynamic quantization to Transformer
|
| 79 |
+
pipeline.transformer = quantize_dynamic(
|
| 80 |
+
pipeline.transformer,
|
| 81 |
+
qconfig_spec=qconfig_spec, # Apply qconfig only to transformer layers
|
| 82 |
+
dtype=torch.float8_e5m2fnuz,
|
| 83 |
+
inplace=True,
|
| 84 |
+
)
|
| 85 |
|
| 86 |
pipeline.vae.to(memory_format=torch.channels_last)
|
| 87 |
pipeline.vae = torch.compile(pipeline.vae)
|