| from diffusers import AutoencoderTiny, StableDiffusionXLPipeline |
| from .hacked_sdxl_pipeline import HackedSDXLPipeline |
| import torch |
|
|
def fast_diffusion_pipeline(model_id="stabilityai/sdxl-turbo", vae_id="madebyollin/taesdxl", compile=False, device="cuda"):
    """
    Build an SDXL-Turbo pipeline tuned for fast inference.

    Loads the base pipeline in fp16, swaps in a tiny autoencoder for fast
    decoding, and optionally compiles the hot paths.

    :param model_id: Hugging Face model id for the base SDXL pipeline.
    :param vae_id: Hugging Face model id for the tiny VAE used for decoding.
    :param compile: If true, does a bunch of stuff to make calls fast, but the
        first call will be very slow as a consequence.
        - If you use this, don't vary the batch size (probably)
    :param device: Device the pipeline is moved to (default: "cuda").
    :return: The configured ``HackedSDXLPipeline`` instance.
    """
    pipe = HackedSDXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe.set_progress_bar_config(disable=True)
    # Reset the encode cache — presumably consumed by HackedSDXLPipeline to
    # reuse prompt encodings between calls; verify against that class.
    pipe.cached_encode = None
    # Tiny autoencoder: much faster decode than the stock SDXL VAE.
    pipe.vae = AutoencoderTiny.from_pretrained(vae_id, torch_dtype=torch.float16)

    pipe.to(device)

    if compile:
        # NOTE(review): the `compile` parameter shadows the builtin of the same
        # name; kept as-is for backward compatibility with existing callers.
        pipe.unet = torch.compile(pipe.unet)
        pipe.vae.decode = torch.compile(pipe.vae.decode)

    return pipe
|
|