English
John6666 committed on
Commit
9378ccb
·
verified ·
1 Parent(s): cc5f48e

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +8 -8
  2. requirements.txt +1 -1
handler.py CHANGED
@@ -9,7 +9,7 @@ import time
9
  from PIL import Image
10
  from huggingface_hub import hf_hub_download
11
  import torch
12
- from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight, int8_dynamic_activation_int8_weight
13
  from torchao.quantization.quant_api import PerRow
14
  from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
15
  from transformers import T5EncoderModel
@@ -30,7 +30,7 @@ IS_LVRAM = False
30
  IS_COMPILE = True
31
  IS_WARM = True
32
  IS_QUANT = True
33
- IS_AUTOQ = True
34
  IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
35
  IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
36
 
@@ -42,7 +42,7 @@ if IS_COMPILE:
42
  import torch._dynamo
43
  torch._dynamo.config.suppress_errors = False
44
  #torch._dynamo.config.suppress_errors = True
45
- #torch._inductor.config.disable_progress = False
46
  #torch._inductor.config.conv_1x1_as_mm = True
47
  #torch._inductor.config.coordinate_descent_tuning = True
48
  #torch._inductor.config.coordinate_descent_check_all_directions = True
@@ -217,14 +217,14 @@ def load_pipeline_fast(repo_id: str, dtype: torch.dtype) -> Any:
217
  pipe.transformer.to(memory_format=torch.channels_last)
218
  pipe.vae.to(memory_format=torch.channels_last)
219
  if IS_QUANT and not IS_AUTOQ:
220
- quantize_(pipe.text_encoder, int8_dynamic_activation_int8_weight())
221
- quantize_(pipe.text_encoder_2, int8_dynamic_activation_int8_weight())
222
  if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
223
  elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
224
- else: quantize_(pipe.vae, int8_dynamic_activation_int4_weight())
225
  if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
226
  elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
227
- else: quantize_(pipe.vae, int8_dynamic_activation_int8_weight())
228
  return pipe
229
 
230
  class EndpointHandler:
@@ -261,7 +261,7 @@ class EndpointHandler:
261
  end = time.time()
262
  print(f'Compiled in {end - start:.3f} sec.')
263
 
264
- def __call__(self, data: Dict[str, Any]) -> Image.Image:
265
  logger.info(f"Received incoming request with {data=}")
266
 
267
  if "inputs" in data and isinstance(data["inputs"], str):
 
9
  from PIL import Image
10
  from huggingface_hub import hf_hub_download
11
  import torch
12
+ from torchao.quantization import quantize_, autoquant, int8_dynamic_activation_int4_weight, float8_dynamic_activation_float8_weight, int8_dynamic_activation_int8_weight, int8_weight_only
13
  from torchao.quantization.quant_api import PerRow
14
  from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKL, TorchAoConfig
15
  from transformers import T5EncoderModel
 
30
  IS_COMPILE = True
31
  IS_WARM = True
32
  IS_QUANT = True
33
+ IS_AUTOQ = False
34
  IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
35
  IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
36
 
 
42
  import torch._dynamo
43
  torch._dynamo.config.suppress_errors = False
44
  #torch._dynamo.config.suppress_errors = True
45
+ torch._inductor.config.disable_progress = False
46
  #torch._inductor.config.conv_1x1_as_mm = True
47
  #torch._inductor.config.coordinate_descent_tuning = True
48
  #torch._inductor.config.coordinate_descent_check_all_directions = True
 
217
  pipe.transformer.to(memory_format=torch.channels_last)
218
  pipe.vae.to(memory_format=torch.channels_last)
219
  if IS_QUANT and not IS_AUTOQ:
220
+ quantize_(pipe.text_encoder, int8_weight_only())
221
+ quantize_(pipe.text_encoder_2, int8_weight_only())
222
  if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
223
  elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
224
+ else: quantize_(pipe.vae, int8_weight_only())
225
  if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
226
  elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
227
+ else: quantize_(pipe.vae, int8_weight_only())
228
  return pipe
229
 
230
  class EndpointHandler:
 
261
  end = time.time()
262
  print(f'Compiled in {end - start:.3f} sec.')
263
 
264
+ def __call__(self, data: Dict[str, Any]) -> str:
265
  logger.info(f"Received incoming request with {data=}")
266
 
267
  if "inputs" in data and isinstance(data["inputs"], str):
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu126
2
  torch==2.6.0
3
  torchvision
4
  torchaudio
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
  torch==2.6.0
3
  torchvision
4
  torchaudio