Upload handler.py
Browse files- handler.py +19 -16
handler.py
CHANGED
|
@@ -29,6 +29,7 @@ IS_MGPU = False
|
|
| 29 |
IS_LVRAM = False
|
| 30 |
IS_COMPILE = True
|
| 31 |
IS_WARM = True
|
|
|
|
| 32 |
IS_AUTOQ = False
|
| 33 |
IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
|
| 34 |
IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
|
|
@@ -41,11 +42,11 @@ if IS_COMPILE:
|
|
| 41 |
import torch._dynamo
|
| 42 |
torch._dynamo.config.suppress_errors = False
|
| 43 |
#torch._dynamo.config.suppress_errors = True
|
| 44 |
-
torch._inductor.config.disable_progress = False
|
| 45 |
-
torch._inductor.config.conv_1x1_as_mm = True
|
| 46 |
-
torch._inductor.config.coordinate_descent_tuning = True
|
| 47 |
-
torch._inductor.config.coordinate_descent_check_all_directions = True
|
| 48 |
-
torch._inductor.config.epilogue_fusion = False
|
| 49 |
|
| 50 |
if IS_MGPU:
|
| 51 |
import torch.distributed as dist
|
|
@@ -211,19 +212,21 @@ def load_pipeline_fast(repo_id: str, dtype: torch.dtype) -> Any:
|
|
| 211 |
pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
|
| 212 |
pipe.enable_vae_slicing()
|
| 213 |
pipe.enable_vae_tiling()
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
|
| 217 |
-
#pipe.transformer.fuse_qkv_projections()
|
| 218 |
pipe.transformer.to(memory_format=torch.channels_last)
|
| 219 |
-
#pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
|
| 220 |
-
#pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
|
| 221 |
-
if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
| 222 |
-
elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
|
| 223 |
-
#pipe.vae.fuse_qkv_projections()
|
| 224 |
pipe.vae.to(memory_format=torch.channels_last)
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
return pipe
|
| 228 |
|
| 229 |
class EndpointHandler:
|
|
|
|
| 29 |
IS_LVRAM = False
|
| 30 |
IS_COMPILE = True
|
| 31 |
IS_WARM = True
|
| 32 |
+
IS_QUANT = True
|
| 33 |
IS_AUTOQ = False
|
| 34 |
IS_CC90 = True if torch.cuda.get_device_capability() >= (9, 0) else False
|
| 35 |
IS_CC89 = True if torch.cuda.get_device_capability() >= (8, 9) else False
|
|
|
|
| 42 |
import torch._dynamo
|
| 43 |
torch._dynamo.config.suppress_errors = False
|
| 44 |
#torch._dynamo.config.suppress_errors = True
|
| 45 |
+
#torch._inductor.config.disable_progress = False
|
| 46 |
+
#torch._inductor.config.conv_1x1_as_mm = True
|
| 47 |
+
#torch._inductor.config.coordinate_descent_tuning = True
|
| 48 |
+
#torch._inductor.config.coordinate_descent_check_all_directions = True
|
| 49 |
+
#torch._inductor.config.epilogue_fusion = False
|
| 50 |
|
| 51 |
if IS_MGPU:
|
| 52 |
import torch.distributed as dist
|
|
|
|
| 212 |
pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=dtype).to("cuda")
|
| 213 |
pipe.enable_vae_slicing()
|
| 214 |
pipe.enable_vae_tiling()
|
| 215 |
+
pipe.transformer.fuse_qkv_projections()
|
| 216 |
+
pipe.vae.fuse_qkv_projections()
|
|
|
|
|
|
|
| 217 |
pipe.transformer.to(memory_format=torch.channels_last)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
pipe.vae.to(memory_format=torch.channels_last)
|
| 219 |
+
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
|
| 220 |
+
if IS_QUANT:
|
| 221 |
+
int8_dynamic_activation_int4_weight()
|
| 222 |
+
quantize_(pipe.text_encoder, int8_dynamic_activation_int8_weight())
|
| 223 |
+
quantize_(pipe.text_encoder_2, int8_dynamic_activation_int8_weight())
|
| 224 |
+
if IS_CC90: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
| 225 |
+
elif IS_CC89: quantize_(pipe.transformer, float8_dynamic_activation_float8_weight(), device="cuda")
|
| 226 |
+
else: quantize_(pipe.vae, int8_dynamic_activation_int4_weight())
|
| 227 |
+
if IS_CC90: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(granularity=PerRow()), device="cuda")
|
| 228 |
+
elif IS_CC89: quantize_(pipe.vae, float8_dynamic_activation_float8_weight(), device="cuda")
|
| 229 |
+
else: quantize_(pipe.vae, int8_dynamic_activation_int8_weight())
|
| 230 |
return pipe
|
| 231 |
|
| 232 |
class EndpointHandler:
|