English
John6666 committed on
Commit
da6f0ba
·
verified ·
1 Parent(s): 3a1fdfd

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +19 -16
handler.py CHANGED
@@ -29,6 +29,7 @@ IS_MGPU = False
29
  IS_LVRAM = False
30
  IS_COMPILE = True
31
  IS_WARM = True
 
32
  IS_AUTOQ = False
33
  IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
34
  IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
@@ -41,11 +42,11 @@ if IS_COMPILE:
41
  import torch._dynamo
42
  torch._dynamo.config.suppress_errors = False
43
  #torch._dynamo.config.suppress_errors = True
44
- torch._inductor.config.disable_progress = False
45
- torch._inductor.config.conv_1x1_as_mm = True
46
- torch._inductor.config.coordinate_descent_tuning = True
47
- torch._inductor.config.coordinate_descent_check_all_directions = True
48
- torch._inductor.config.epilogue_fusion = False
49
 
50
  if IS_MGPU:
51
  import torch.distributed as dist
@@ -211,19 +212,21 @@ def load_pipeline_fast(repo_id: str, dtype: torch.dtype) -> Any:
211
  pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
212
  pipe.enable_vae_slicing()
213
  pipe.enable_vae_tiling()
214
- apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
215
- if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
216
- elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
217
- #pipe.transformer.fuse_qkv_projections()
218
  pipe.transformer.to(memory_format=torch.channels_last)
219
- #pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
220
- #pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
221
- if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
222
- elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
223
- #pipe.vae.fuse_qkv_projections()
224
  pipe.vae.to(memory_format=torch.channels_last)
225
- #pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
226
- #pipe.vae = torch.compile(pipe.vae, mode="max-autotune-no-cudagraphs")
 
 
 
 
 
 
 
 
 
227
  return pipe
228
 
229
  class EndpointHandler:
 
29
  IS_LVRAM = False
30
  IS_COMPILE = True
31
  IS_WARM = True
32
+ IS_QUANT = True
33
  IS_AUTOQ = False
34
  IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
35
  IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
 
42
  import torch._dynamo
43
  torch._dynamo.config.suppress_errors = False
44
  #torch._dynamo.config.suppress_errors = True
45
+ #torch._inductor.config.disable_progress = False
46
+ #torch._inductor.config.conv_1x1_as_mm = True
47
+ #torch._inductor.config.coordinate_descent_tuning = True
48
+ #torch._inductor.config.coordinate_descent_check_all_directions = True
49
+ #torch._inductor.config.epilogue_fusion = False
50
 
51
  if IS_MGPU:
52
  import torch.distributed as dist
 
212
  pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
213
  pipe.enable_vae_slicing()
214
  pipe.enable_vae_tiling()
215
+ pipe.transformer.fuse_qkv_projections()
216
+ pipe.vae.fuse_qkv_projections()
 
 
217
  pipe.transformer.to(memory_format=torch.channels_last)
 
 
 
 
 
218
  pipe.vae.to(memory_format=torch.channels_last)
219
+ apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
220
+ if IS_QUANT:
221
+ int8_dynamic_activation_int4_weight()
222
+ quantize_(pipe.text_encoder, int8_dynamic_activation_int8_weight())
223
+ quantize_(pipe.text_encoder_2, int8_dynamic_activation_int8_weight())
224
+ if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
225
+ elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
226
+ else: quantize_(pipe.vae, int8_dynamic_activation_int4_weight())
227
+ if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
228
+ elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
229
+ else: quantize_(pipe.vae, int8_dynamic_activation_int8_weight())
230
  return pipe
231
 
232
  class EndpointHandler: