"""SDXL text-to-image pipeline backed by a TensorRT UNet and cache_diffusion."""

from pathlib import Path

import torch
from torch import Generator
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline

from pipelines.models import TextToImageRequest
from cache_diffusion import cachify
from trt_pipeline.deploy import load_unet_trt

# from cache_diffusion.utils import SDXL_DEFAULT_CONFIG

# Number of denoising steps used for both the warm-up runs and real requests.
NUM_INFERENCE_STEPS = 21

# Module-level generator; only referenced by the commented-out warm-up call
# in load_pipeline(), kept so existing importers of this name still work.
generator = Generator(torch.device("cuda")).manual_seed(666)

# Prompt used for the warm-up runs in load_pipeline().
prompt = "future punk robot shooting"


def _is_cacheable_module(name: str) -> bool:
    """Return True unless *name* contains "down_blocks.3" or "up_blocks.2"."""
    return "down_blocks.3" not in name and "up_blocks.2" not in name


def _use_cache_at_step(step: int) -> bool:
    """Return True for odd step indices that are at least 13."""
    return (step % 2 != 0) and (step >= 13)


# Configuration handed to cachify.prepare(); the key names are read by
# cache_diffusion and must stay as they are.
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _is_cacheable_module,
        "select_cache_step_func": _use_cache_at_step,
    }
]


def load_pipeline() -> StableDiffusionXLPipeline:
    """Build the SDXL pipeline, attach the TensorRT UNet and warm it up.

    Loads the fp16 weights from the local "models/newdream-sdxl-20"
    directory, moves the pipeline to CUDA, loads the UNet engine from
    "./engine", prepares cachify with SDXL_DEFAULT_CONFIG and runs two
    warm-up generations with the module-level prompt.

    Returns:
        The pipeline, with cachify disabled again after the warm-up.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "models/newdream-sdxl-20",
        torch_dtype=torch.float16,
        use_safetensors=True,
        local_files_only=True,
    ).to("cuda")

    # pipe(prompt, generator=generator, num_inference_steps=21)
    # pipe.fuse_qkv_projections()
    # pipe.vae = torch.compile(pipe.vae, backend="cudagraphs", fullgraph=True)
    # pipe.text_encoder = torch.compile(pipe.text_encoder, backend="cudagraphs", fullgraph=True)

    load_unet_trt(
        pipe.unet,
        engine_path=Path("./engine"),
        batch_size=1,
    )

    cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipe)
    try:
        # NOTE(review): both warm-up calls share one cachify.infer context;
        # confirm this matches the intended layout of the original file.
        with cachify.infer(pipe) as cached_pipe:
            cached_pipe(prompt=prompt, num_inference_steps=NUM_INFERENCE_STEPS)
            cached_pipe(prompt=prompt, num_inference_steps=NUM_INFERENCE_STEPS)
    finally:
        # Always undo enable(), even when a warm-up run raises.
        cachify.disable(pipe)
    return pipe


def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate a single image for *request* using *pipeline*.

    Args:
        request: Supplies prompt, negative_prompt, width, height and an
            optional seed.
        pipeline: Pipeline to run; presumably the object returned by
            load_pipeline() (cachify must already be prepared on it —
            TODO confirm against the caller).

    Returns:
        The first image of the pipeline output.
    """
    # No seed means the pipeline is called with generator=None.
    if request.seed is None:
        seeded_generator = None
    else:
        seeded_generator = Generator(pipeline.device).manual_seed(request.seed)

    cachify.enable(pipeline)
    try:
        with cachify.infer(pipeline) as cached_pipe:
            image = cached_pipe(
                prompt=request.prompt,
                negative_prompt=request.negative_prompt,
                width=request.width,
                height=request.height,
                generator=seeded_generator,
                num_inference_steps=NUM_INFERENCE_STEPS,
            ).images[0]
    finally:
        # The pipeline is reused across requests, so make sure a failed
        # generation does not leave cachify enabled.
        cachify.disable(pipeline)
    return image