Spaces:
Runtime error
Runtime error
Float8DynamicActivation quantization
Browse files- optimization.py +4 -0
- requirements.txt +1 -0
optimization.py
CHANGED
|
@@ -8,6 +8,8 @@ from typing import ParamSpec
|
|
| 8 |
import spaces
|
| 9 |
import torch
|
| 10 |
from torch.utils._pytree import tree_map_only
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from optimization_utils import capture_component_call
|
| 13 |
from optimization_utils import aoti_compile
|
|
@@ -46,6 +48,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
| 46 |
|
| 47 |
pipeline.transformer.fuse_qkv_projections()
|
| 48 |
|
|
|
|
|
|
|
| 49 |
exported = torch.export.export(
|
| 50 |
mod=pipeline.transformer,
|
| 51 |
args=call.args,
|
|
|
|
| 8 |
import spaces
|
| 9 |
import torch
|
| 10 |
from torch.utils._pytree import tree_map_only
|
| 11 |
+
from torchao.quantization import quantize_
|
| 12 |
+
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
|
| 13 |
|
| 14 |
from optimization_utils import capture_component_call
|
| 15 |
from optimization_utils import aoti_compile
|
|
|
|
| 48 |
|
| 49 |
pipeline.transformer.fuse_qkv_projections()
|
| 50 |
|
| 51 |
+
quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
|
| 52 |
+
|
| 53 |
exported = torch.export.export(
|
| 54 |
mod=pipeline.transformer,
|
| 55 |
args=call.args,
|
requirements.txt
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
transformers
|
| 2 |
git+https://github.com/huggingface/diffusers.git
|
| 3 |
accelerate
|
|
|
|
| 1 |
+
torchao
|
| 2 |
transformers
|
| 3 |
git+https://github.com/huggingface/diffusers.git
|
| 4 |
accelerate
|