"""Warm-up and serving entry points for a TensorRT-accelerated SDXL pipeline.

``load_pipeline`` builds, compiles, and warms up the pipeline once at startup;
``infer`` serves individual text-to-image requests against the warm pipeline,
with deepcache enabled only for the duration of each call.
"""

import os
from pathlib import Path

import torch
from torch import Generator
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, AutoencoderTiny

from autoencoder_kl import AutoencoderKL
from pipelines.models import TextToImageRequest
from cache_diffusion import cachify
from trt_pipeline.deploy import load_unet_trt
from loss import SchedulerWrapper
# from cache_diffusion.utils import SDXL_DEFAULT_CONFIG

# NOTE(review): this module-level generator is never passed to any pipeline
# call below — confirm whether it is still needed or is dead state.
generator = Generator(torch.device("cuda")).manual_seed(666)

# Fixed prompts used solely for warm-up / tracing passes in load_pipeline().
prompt = "future punk robot shooting"
neg_prompt = "bloody, fire"

# UNet sub-modules that must never be served from the cache.
no_cache_blk = ["down_blocks.2", "up_blocks.0", "mid_block"]

# Cache every block except those in `no_cache_blk`, and only on steps
# 9, 11 and 12 (of the 13-step schedule `infer` uses).
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": lambda name: any(
            blk in name for blk in no_cache_blk
        ),
        "select_cache_step_func": lambda step: step in (9, 11, 12),
    }
]

HOME = os.environ["HOME"]


def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    """Step-end callback that switches off classifier-free guidance late on.

    At ~78% of the way through the schedule, keep only the conditional half
    of each batched embedding (``chunk(2)[-1]``) and drop the guidance scale
    to 0.1 — the final denoising steps gain little from CFG, so this roughly
    halves UNet work for the remaining steps.

    Returns the (possibly modified) ``callback_kwargs`` dict, as required by
    the diffusers ``callback_on_step_end`` contract.
    """
    if step_index == int(pipe.num_timesteps * 0.78):
        callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
        callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
        callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
        pipe._guidance_scale = 0.1
    return callback_kwargs


def load_pipeline() -> StableDiffusionXLPipeline:
    """Build, compile, and warm up the SDXL pipeline.

    Steps:
      1. Load fp16 weights and move everything to CUDA.
      2. ``torch.compile`` the text encoders and VAE (cudagraphs backend),
         then run one generation to trigger compilation.
      3. Swap in a DDIM scheduler wrapped by ``SchedulerWrapper``.
      4. Replace the UNet with a prebuilt TensorRT engine and re-warm.
      5. Prepare the deepcache and run a few cached generations so every
         code path used by ``infer`` is traced before serving.

    Returns:
        The ready-to-serve pipeline (cache disabled until ``infer`` runs).
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stablediffusionapi/newdream-sdxl-20",
        torch_dtype=torch.float16,
        use_safetensors=True,
    ).to("cuda")

    # Compile the cheap-but-hot submodules; the UNet is replaced by TRT below.
    pipe.text_encoder = torch.compile(pipe.text_encoder, fullgraph=True, backend="cudagraphs")
    pipe.vae = torch.compile(pipe.vae, fullgraph=True, backend="cudagraphs")
    pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, fullgraph=True, backend="cudagraphs")

    # First warm-up pass: trigger torch.compile tracing.
    pipe(prompt, negative_prompt=neg_prompt)

    pipe.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipe.scheduler.config))
    pipe = pipe.to("cuda")
    pipe.scheduler.prepare_loss(16)

    engine_path = (
        f"{HOME}/.cache/huggingface/hub/"
        "models--jokerbit--newdream-sdxl-20-engine/snapshots/"
        "8f4557f59479ee5feb721132867da906ca7c8e44"
    )

    # Best-effort cleanup of the snapshot's .gitattributes; narrowed from a
    # blanket `except Exception` — os.remove only raises OSError subclasses,
    # and it is fine if the file is already gone.
    try:
        os.remove(os.path.join(engine_path, ".gitattributes"))
    except OSError as err:
        print(err)

    load_unet_trt(
        pipe.unet,
        engine_path=Path(engine_path),
        batch_size=1,
    )

    # Second warm-up pass: trace with the TRT UNet in place.
    pipe(prompt, negative_prompt=neg_prompt)

    # Warm up the cached (deepcache) code paths exactly as `infer` will use them.
    cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipe)
    for _ in range(5):
        pipe(
            prompt,
            negative_prompt=neg_prompt,
            callback_on_step_end=callback_dynamic_cfg,
            callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
        )
    cachify.disable(pipe)
    return pipe


def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for ``request`` using the warmed-up ``pipeline``.

    Seeds a fresh CUDA generator when ``request.seed`` is set (otherwise the
    run is unseeded), and enables the deepcache only for the duration of the
    call — the ``finally`` guarantees the cache is disabled even if the
    pipeline raises, so one failed request cannot poison later ones.
    """
    if request.seed is None:
        generator = None
    else:
        generator = Generator(pipeline.device).manual_seed(request.seed)

    cachify.enable(pipeline)
    try:
        image = pipeline(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=generator,
            num_inference_steps=13,
            eta=1.0,
            guidance_scale=5.0,
            guidance_rescale=0.0,
        ).images[0]
    finally:
        cachify.disable(pipeline)
    return image