Upload handler.py
Browse files — handler.py (+6, −4)
handler.py
CHANGED
|
@@ -30,7 +30,7 @@ IS_LVRAM = False
|
|
| 30 |
IS_COMPILE = True
|
| 31 |
IS_WARM = True
|
| 32 |
IS_QUANT = True
|
| 33 |
-
IS_AUTOQ =
|
| 34 |
IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
|
| 35 |
IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
|
| 36 |
|
|
@@ -217,8 +217,7 @@ def load_pipeline_fast(repo_id: str, dtype: torch.dtype) -> Any:
|
|
| 217 |
pipe.transformer.to(memory_format=torch.channels_last)
|
| 218 |
pipe.vae.to(memory_format=torch.channels_last)
|
| 219 |
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
|
| 220 |
-
if IS_QUANT:
|
| 221 |
-
int8_dynamic_activation_int4_weight()
|
| 222 |
quantize_(pipe.text_encoder, int8_dynamic_activation_int8_weight())
|
| 223 |
quantize_(pipe.text_encoder_2, int8_dynamic_activation_int8_weight())
|
| 224 |
if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
|
@@ -237,7 +236,7 @@ class EndpointHandler:
|
|
| 237 |
#dtype = torch.float16 # for older nVidia GPUs
|
| 238 |
print_vram()
|
| 239 |
print("Loading pipeline...")
|
| 240 |
-
if IS_AUTOQ: self.pipeline =
|
| 241 |
elif IS_COMPILE: self.pipeline = load_pipeline_fast(repo_id, dtype)
|
| 242 |
elif IS_LVRAM and IS_CC89: self.pipeline = load_pipeline_lowvram(repo_id, dtype)
|
| 243 |
else: self.pipeline = load_pipeline_stable(repo_id, dtype)
|
|
@@ -250,6 +249,9 @@ class EndpointHandler:
|
|
| 250 |
print("Compiling pipeline...")
|
| 251 |
self.pipeline.transformer = torch.compile(self.pipeline.transformer, mode="max-autotune-no-cudagraphs")
|
| 252 |
self.pipeline.vae = torch.compile(self.pipeline.vae, mode="max-autotune-no-cudagraphs")
|
|
|
|
|
|
|
|
|
|
| 253 |
gc.collect()
|
| 254 |
torch.cuda.empty_cache()
|
| 255 |
print_vram()
|
|
|
|
| 30 |
IS_COMPILE = True
|
| 31 |
IS_WARM = True
|
| 32 |
IS_QUANT = True
|
| 33 |
+
IS_AUTOQ = True
|
| 34 |
IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
|
| 35 |
IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
|
| 36 |
|
|
|
|
| 217 |
pipe.transformer.to(memory_format=torch.channels_last)
|
| 218 |
pipe.vae.to(memory_format=torch.channels_last)
|
| 219 |
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
|
| 220 |
+
if IS_QUANT and not IS_AUTOQ:
|
|
|
|
| 221 |
quantize_(pipe.text_encoder, int8_dynamic_activation_int8_weight())
|
| 222 |
quantize_(pipe.text_encoder_2, int8_dynamic_activation_int8_weight())
|
| 223 |
if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
|
|
|
| 236 |
#dtype = torch.float16 # for older nVidia GPUs
|
| 237 |
print_vram()
|
| 238 |
print("Loading pipeline...")
|
| 239 |
+
if IS_AUTOQ: self.pipeline = load_pipeline_fast(repo_id, dtype)
|
| 240 |
elif IS_COMPILE: self.pipeline = load_pipeline_fast(repo_id, dtype)
|
| 241 |
elif IS_LVRAM and IS_CC89: self.pipeline = load_pipeline_lowvram(repo_id, dtype)
|
| 242 |
else: self.pipeline = load_pipeline_stable(repo_id, dtype)
|
|
|
|
| 249 |
print("Compiling pipeline...")
|
| 250 |
self.pipeline.transformer = torch.compile(self.pipeline.transformer, mode="max-autotune-no-cudagraphs")
|
| 251 |
self.pipeline.vae = torch.compile(self.pipeline.vae, mode="max-autotune-no-cudagraphs")
|
| 252 |
+
if IS_AUTOQ:
|
| 253 |
+
self.pipeline.transformer = autoquant(self.pipeline.transformer, error_on_unseen=False)
|
| 254 |
+
self.pipeline.vae = autoquant(self.pipeline.vae, error_on_unseen=False)
|
| 255 |
gc.collect()
|
| 256 |
torch.cuda.empty_cache()
|
| 257 |
print_vram()
|