Commit 3a1fdfd · verified
John6666 committed
1 Parent(s): bb44a1f

Upload handler.py

Files changed (1)
  1. handler.py +8 -4
handler.py CHANGED
@@ -217,13 +217,13 @@ def load_pipeline_fast(repo_id: str, dtype: torch.dtype) -> Any:
     #pipe.transformer.fuse_qkv_projections()
     pipe.transformer.to(memory_format=torch.channels_last)
     #pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-    pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
+    #pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
     if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
     elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
     #pipe.vae.fuse_qkv_projections()
     pipe.vae.to(memory_format=torch.channels_last)
     #pipe.vae = torch.compile(pipe.vae, mode="max-autotune", fullgraph=True)
-    pipe.vae = torch.compile(pipe.vae, mode="max-autotune-no-cudagraphs")
+    #pipe.vae = torch.compile(pipe.vae, mode="max-autotune-no-cudagraphs")
     return pipe
 
 class EndpointHandler:
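
In the hunk above, the per-module `torch.compile` calls inside `load_pipeline_fast` are commented out (they move into `EndpointHandler.__init__`, below), while the float8 VAE quantization stays gated on GPU compute capability. A minimal sketch of how the `IS_CC90` / `IS_CC89` flags are presumably derived; the derivation is an assumption, only the `quantize_` calls come from the diff:

import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerRow

# Assumed derivation of the capability flags used in the diff:
# (9, 0) is Hopper (H100), (8, 9) is Ada (L4 / L40S / RTX 4090).
# float8 dynamic activation/weight kernels need CC >= 8.9.
cc = torch.cuda.get_device_capability()
IS_CC90 = cc == (9, 0)
IS_CC89 = cc == (8, 9)

def quantize_vae(vae: torch.nn.Module) -> None:
    # Per-row scaling granularity is enabled only on Hopper here,
    # mirroring the two branches in the diff.
    if IS_CC90:
        quantize_(vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
    elif IS_CC89:
        quantize_(vae, float8_dynamic_activation_float8_weight(), device="cuda")
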
@@ -238,11 +238,15 @@ class EndpointHandler:
         elif IS_COMPILE: self.pipeline = load_pipeline_fast(repo_id, dtype)
         elif IS_LVRAM and IS_CC89: self.pipeline = load_pipeline_lowvram(repo_id, dtype)
         else: self.pipeline = load_pipeline_stable(repo_id, dtype)
+        if IS_PARA: apply_cache_on_pipe(self.pipeline, residual_diff_threshold=0.12)
+        self.pipeline.to("cuda")
         if not IS_COMPILE:
             self.pipeline.enable_vae_slicing()
             self.pipeline.enable_vae_tiling()
-        if IS_PARA: apply_cache_on_pipe(self.pipeline, residual_diff_threshold=0.12)
-        self.pipeline.to("cuda")
+        else:
+            print("Compiling pipeline...")
+            self.pipeline.transformer = torch.compile(self.pipeline.transformer, mode="max-autotune-no-cudagraphs")
+            self.pipeline.vae = torch.compile(self.pipeline.vae, mode="max-autotune-no-cudagraphs")
         gc.collect()
         torch.cuda.empty_cache()
         print_vram()
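
This second hunk is the substance of the commit: `apply_cache_on_pipe` (the signature with `residual_diff_threshold` matches the ParaAttention first-block-cache API, presumably imported from `para_attn`) and the move to CUDA now run before compilation, and the `torch.compile` wrapping happens once here, after all module surgery, instead of inside `load_pipeline_fast`. A self-contained sketch of the same compile-last ordering, with a toy module standing in for the repo's pipeline:

import torch

# Toy module in place of pipe.transformer; the handler compiles the
# Diffusers transformer and VAE the same way.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())

# 1) Finish all module surgery first (quantization, cache hooks, device moves)...
model.to("cuda")

# 2) ...then compile last, so the optimized graph captures the final module.
#    "max-autotune-no-cudagraphs" keeps Inductor autotuning but skips CUDA
#    graph capture, matching the mode used in the diff.
model = torch.compile(model, mode="max-autotune-no-cudagraphs")

# First call triggers compilation; later calls reuse the compiled graph.
out = model(torch.randn(8, 64, device="cuda"))
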
 
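One operational note: `torch.compile` is lazy, so even with compilation requested at init, the autotuning cost is paid on the first forward pass. A hypothetical warmup, not part of this commit, that endpoint handlers often add so the first user request stays fast; the call kwargs are placeholders for whatever the pipeline actually accepts:

import torch

def warmup(pipeline) -> None:
    # Run one throwaway generation so the first-call compile/autotune cost
    # is absorbed at startup rather than on the first user request.
    with torch.inference_mode():
        pipeline(prompt="warmup", num_inference_steps=1)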