Upload 2 files
Browse files
- handler.py +8 -8
- requirements.txt +1 -1
handler.py
CHANGED
|
@@ -9,7 +9,7 @@ import time
|
|
| 9 |
from PIL import Image
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
import torch
|
| 12 |
-
from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight, int8_dynamic_activation_int8_weight
|
| 13 |
from torchao.quantization.quant_api import PerRow
|
| 14 |
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
|
| 15 |
from transformers import T5EncoderModel
|
|
@@ -30,7 +30,7 @@ IS_LVRAM = False
|
|
| 30 |
IS_COMPILE = True
|
| 31 |
IS_WARM = True
|
| 32 |
IS_QUANT = True
|
| 33 |
-
IS_AUTOQ =
|
| 34 |
IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
|
| 35 |
IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
|
| 36 |
|
|
@@ -42,7 +42,7 @@ if IS_COMPILE:
|
|
| 42 |
import torch._dynamo
|
| 43 |
torch._dynamo.config.suppress_errors = False
|
| 44 |
#torch._dynamo.config.suppress_errors = True
|
| 45 |
-
|
| 46 |
#torch._inductor.config.conv_1x1_as_mm = True
|
| 47 |
#torch._inductor.config.coordinate_descent_tuning = True
|
| 48 |
#torch._inductor.config.coordinate_descent_check_all_directions = True
|
|
@@ -217,14 +217,14 @@ def load_pipeline_fast(repo_id: str, dtype: torch.dtype) -> Any:
|
|
| 217 |
pipe.transformer.to(memory_format=torch.channels_last)
|
| 218 |
pipe.vae.to(memory_format=torch.channels_last)
|
| 219 |
if IS_QUANT and not IS_AUTOQ:
|
| 220 |
-
quantize_(pipe.text_encoder,
|
| 221 |
-
quantize_(pipe.text_encoder_2,
|
| 222 |
if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
| 223 |
elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
|
| 224 |
-
else: quantize_(pipe.vae,
|
| 225 |
if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
| 226 |
elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
|
| 227 |
-
else: quantize_(pipe.vae,
|
| 228 |
return pipe
|
| 229 |
|
| 230 |
class EndpointHandler:
|
|
@@ -261,7 +261,7 @@ class EndpointHandler:
|
|
| 261 |
end = time.time()
|
| 262 |
print(f'Compiled in {end - start:.3f} sec.')
|
| 263 |
|
| 264 |
-
def __call__(self, data: Dict[str, Any]) ->
|
| 265 |
logger.info(f"Received incoming request with {data=}")
|
| 266 |
|
| 267 |
if "inputs" in data and isinstance(data["inputs"], str):
|
|
|
|
| 9 |
from PIL import Image
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
import torch
|
| 12 |
+
from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight, int8_dynamic_activation_int8_weight, int8_weight_only
|
| 13 |
from torchao.quantization.quant_api import PerRow
|
| 14 |
from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
|
| 15 |
from transformers import T5EncoderModel
|
|
|
|
| 30 |
IS_COMPILE = True
|
| 31 |
IS_WARM = True
|
| 32 |
IS_QUANT = True
|
| 33 |
+
IS_AUTOQ = False
|
| 34 |
IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
|
| 35 |
IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
|
| 36 |
|
|
|
|
| 42 |
import torch._dynamo
|
| 43 |
torch._dynamo.config.suppress_errors = False
|
| 44 |
#torch._dynamo.config.suppress_errors = True
|
| 45 |
+
torch._inductor.config.disable_progress = False
|
| 46 |
#torch._inductor.config.conv_1x1_as_mm = True
|
| 47 |
#torch._inductor.config.coordinate_descent_tuning = True
|
| 48 |
#torch._inductor.config.coordinate_descent_check_all_directions = True
|
|
|
|
| 217 |
pipe.transformer.to(memory_format=torch.channels_last)
|
| 218 |
pipe.vae.to(memory_format=torch.channels_last)
|
| 219 |
if IS_QUANT and not IS_AUTOQ:
|
| 220 |
+
quantize_(pipe.text_encoder, int8_weight_only())
|
| 221 |
+
quantize_(pipe.text_encoder_2, int8_weight_only())
|
| 222 |
if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
| 223 |
elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
|
| 224 |
+
else: quantize_(pipe.vae, int8_weight_only())
|
| 225 |
if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
| 226 |
elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
|
| 227 |
+
else: quantize_(pipe.vae, int8_weight_only())
|
| 228 |
return pipe
|
| 229 |
|
| 230 |
class EndpointHandler:
|
|
|
|
| 261 |
end = time.time()
|
| 262 |
print(f'Compiled in {end - start:.3f} sec.')
|
| 263 |
|
| 264 |
+
def __call__(self, data: Dict[str, Any]) -> str:
|
| 265 |
logger.info(f"Received incoming request with {data=}")
|
| 266 |
|
| 267 |
if "inputs" in data and isinstance(data["inputs"], str):
|
requirements.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
--extra-index-url https://download.pytorch.org/whl/
|
| 2 |
torch==2.6.0
|
| 3 |
torchvision
|
| 4 |
torchaudio
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu121
|
| 2 |
torch==2.6.0
|
| 3 |
torchvision
|
| 4 |
torchaudio
|