John6666 committed
Commit 5eb864d (verified) · Parent(s): da6f0ba

Upload handler.py

Files changed (1):
handler.py +6 -4
handler.py CHANGED
@@ -30,7 +30,7 @@ IS_LVRAM = False
 IS_COMPILE = True
 IS_WARM = True
 IS_QUANT = True
-IS_AUTOQ = False
+IS_AUTOQ = True
 IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
 IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
 
@@ -217,8 +217,7 @@ def load_pipeline_fast(repo_id: str, dtype: torch.dtype) -> Any:
     pipe.transformer.to(memory_format=torch.channels_last)
     pipe.vae.to(memory_format=torch.channels_last)
     apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
-    if IS_QUANT:
-        int8_dynamic_activation_int4_weight()
+    if IS_QUANT and not IS_AUTOQ:
         quantize_(pipe.text_encoder, int8_dynamic_activation_int8_weight())
         quantize_(pipe.text_encoder_2, int8_dynamic_activation_int8_weight())
         if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
@@ -237,7 +236,7 @@ class EndpointHandler:
         #dtype = torch.float16 # for older nVidia GPUs
         print_vram()
         print("Loading pipeline...")
-        if IS_AUTOQ: self.pipeline = load_pipeline_autoquant(repo_id, dtype)
+        if IS_AUTOQ: self.pipeline = load_pipeline_fast(repo_id, dtype)
         elif IS_COMPILE: self.pipeline = load_pipeline_fast(repo_id, dtype)
         elif IS_LVRAM and IS_CC89: self.pipeline = load_pipeline_lowvram(repo_id, dtype)
         else: self.pipeline = load_pipeline_stable(repo_id, dtype)
@@ -250,6 +249,9 @@ class EndpointHandler:
             print("Compiling pipeline...")
             self.pipeline.transformer = torch.compile(self.pipeline.transformer, mode="max-autotune-no-cudagraphs")
             self.pipeline.vae = torch.compile(self.pipeline.vae, mode="max-autotune-no-cudagraphs")
+        if IS_AUTOQ:
+            self.pipeline.transformer = autoquant(self.pipeline.transformer, error_on_unseen=False)
+            self.pipeline.vae = autoquant(self.pipeline.vae, error_on_unseen=False)
         gc.collect()
         torch.cuda.empty_cache()
         print_vram()
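
The net effect of the commit: IS_AUTOQ is switched on, the stray int8_dynamic_activation_int4_weight() call (which constructed a config without applying it) is dropped, and the hand-picked quantize_() configs now run only when autoquant is disabled. For reference, a minimal self-contained sketch of the int8 dynamic-activation, int8-weight path used for the text encoders; the toy model here is an assumption standing in for pipe.text_encoder.

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# Toy stand-in for pipe.text_encoder (assumption: any nn.Module with Linear
# layers is handled the same way by quantize_()).
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
).to("cuda")

# quantize_() rewrites the module in place: Linear weights are stored as int8
# and activations are quantized dynamically at run time.
quantize_(model, int8_dynamic_activation_int8_weight())

x = torch.randn(2, 64, device="cuda")
print(model(x).shape)  # torch.Size([2, 64])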
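
On Hopper-class GPUs (the IS_CC90 check, compute capability 9.0 and up), the transformer is instead quantized to float8 with per-row weight scales. A hedged sketch of that branch, again on a toy module; torchao expects bfloat16 weights for the PerRow granularity.

import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.granularity import PerRow

# Toy stand-in for pipe.transformer.
model = torch.nn.Linear(128, 128).to(device="cuda", dtype=torch.bfloat16)

# PerRow keeps one float8 scale per output row of the weight matrix instead
# of a single per-tensor scale; fp8 matmul kernels need CC >= 8.9, and the
# handler gates this branch at CC >= 9.0.
if torch.cuda.get_device_capability() >= (9, 0):
    quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))

x = torch.randn(2, 128, device="cuda", dtype=torch.bfloat16)
print(model(x).shape)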
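
The new IS_AUTOQ path loads through load_pipeline_fast like the IS_COMPILE path, then layers torchao's autoquant over the already-compiled transformer and VAE, matching torchao's documented wrapping order of autoquant(torch.compile(model, ...)). A minimal sketch under those assumptions:

import torch
from torchao import autoquant

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
).to("cuda")

# Same wrapping order as the handler: compile first, then autoquant.
model = torch.compile(model, mode="max-autotune-no-cudagraphs")
# error_on_unseen=False: don't raise for input shapes autoquant has not
# benchmarked; such calls fall back to the unquantized path.
model = autoquant(model, error_on_unseen=False)

# The first forward pass doubles as calibration: autoquant records shapes,
# benchmarks candidate quantized kernels per layer, and keeps the fastest.
x = torch.randn(2, 64, device="cuda")
print(model(x).shape)

Because autoquant defers kernel selection to runtime benchmarking, the first request after a cold start pays the calibration cost, while later requests with the same shapes reuse the chosen kernels.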