| import torch | |
| from pathlib import Path | |
| from PIL.Image import Image | |
| from diffusers import StableDiffusionXLPipeline, DDIMScheduler | |
| from pipelines.models import TextToImageRequest | |
| from torch import Generator | |
| from cache_diffusion import cachify | |
| from trt_pipeline.deploy import load_unet_trt | |
| from loss import SchedulerWrapper | |
# Seeded CUDA RNG (seed 69) created at import time. It is not passed to
# load_pipeline()'s warm-up or to infer() in this file; it is kept as a
# module-level name in case other code imports it.
generator = Generator(torch.device("cuda"))
generator.manual_seed(69)

# Prompt used for the warm-up generations in load_pipeline().
prompt = "make submissions great again"
# Module-name fragments rejected by the cachify filter below.
_CACHE_EXCLUDED_BLOCKS = ("down_blocks.2", "down_blocks.3", "up_blocks.2")


def _cache_filter(name):
    """Return True when *name* contains none of ``_CACHE_EXCLUDED_BLOCKS``."""
    return not any(fragment in name for fragment in _CACHE_EXCLUDED_BLOCKS)


def _cache_step(step):
    """Return True for odd steps at or after step 10 (11, 13, 15, ...)."""
    return step % 2 != 0 and step >= 10


# cachify configuration: a single entry pairing a module-name filter with a
# step selector.
# NOTE(review): presumably the filter chooses which UNet blocks may reuse
# cached outputs and the selector chooses the steps at which the cache is
# used. Confirm against cache_diffusion.cachify.
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _cache_filter,
        "select_cache_step_func": _cache_step,
    }
]
def load_pipeline() -> StableDiffusionXLPipeline:
    """Build the SDXL pipeline with a TensorRT UNet, step caching and a wrapped DDIM scheduler.

    Steps performed:
      1. Load the fp16 safetensors checkpoint from ``models/newdream-sdxl-20``
         (local files only) and move it to CUDA.
      2. Call ``load_unet_trt`` on ``pipe.unet`` with engines from ``./engine``
         and batch size 1.
      3. Install ``SDXL_DEFAULT_CONFIG`` with ``cachify.prepare`` and enable it.
      4. Replace the scheduler with ``SchedulerWrapper`` around a
         ``DDIMScheduler`` built from the original scheduler's config.
      5. Run five 20-step warm-up generations of the module-level ``prompt``
         inside ``cachify.infer``.
      6. Disable caching and call ``prepare_loss()`` on the wrapped scheduler.

    Returns:
        The configured pipeline, with caching disabled on return.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "models/newdream-sdxl-20",
        torch_dtype=torch.float16,
        use_safetensors=True,
        local_files_only=True,
    ).to("cuda")

    load_unet_trt(
        pipe.unet,
        engine_path=Path("./engine"),
        batch_size=1,
    )

    cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipe)

    pipe.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipe.scheduler.config))

    # Warm-up: run through the handle yielded by cachify.infer(), exactly as
    # infer() does, rather than bypassing it via ``pipe``.
    # NOTE(review): presumably these runs let SchedulerWrapper gather what
    # prepare_loss() consumes. The warm-up uses 20 steps while infer() uses
    # 15; confirm SchedulerWrapper does not depend on the step count.
    with cachify.infer(pipe) as cached_pipe:
        for _ in range(5):
            cached_pipe(prompt=prompt, num_inference_steps=20)
    cachify.disable(pipe)

    pipe.scheduler.prepare_loss()
    return pipe
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate a single image for *request* with cachify step caching enabled.

    Args:
        request: Supplies ``prompt``, ``negative_prompt``, ``width``,
            ``height`` and an optional ``seed``.
        pipeline: The pipeline to run, e.g. the one built by load_pipeline().

    Returns:
        The first image of the pipeline output (15 inference steps).

    Caching is left enabled on *pipeline* after this call.
    """
    # Deterministic output only when a seed is supplied; otherwise pass None.
    seed = request.seed
    rng = None if seed is None else Generator(pipeline.device).manual_seed(seed)

    # Install the cache configuration and switch caching on for this call.
    cachify.prepare(pipeline, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipeline)

    with cachify.infer(pipeline) as cached_pipe:
        result = cached_pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=rng,
            num_inference_steps=15,
        )
        first_image = result.images[0]
    return first_image