| import os |
| import torch |
| from pathlib import Path |
| from PIL.Image import Image |
| from diffusers import StableDiffusionXLPipeline, DDIMScheduler, AutoencoderTiny |
| from autoencoder_kl import AutoencoderKL |
| from pipelines.models import TextToImageRequest |
| from torch import Generator |
|
|
| from cache_diffusion import cachify |
| from trt_pipeline.deploy import load_unet_trt |
| from loss import SchedulerWrapper |
|
|
| |
|
|
# Warm-up inputs: used only by load_pipeline() to trace/compile and prime the
# cache — they do not affect serving-time requests handled by infer().
generator = Generator(torch.device("cuda")).manual_seed(666)
prompt = "future punk robot shooting"
neg_prompt = "bloody, fire"

# UNet blocks that are never cached by cache_diffusion (the deepest
# down/mid/up blocks); everything else may reuse cached outputs on the
# selected steps below.
no_cache_blk = ["down_blocks.2", "up_blocks.0", "mid_block"]
SDXL_DEFAULT_CONFIG = [{
    # Cache a module unless its name contains one of the excluded blocks.
    "wildcard_or_filter_func": lambda name: any(blk in name for blk in no_cache_blk),
    # Steps (out of the 13 inference steps used in infer()) on which cached
    # block outputs are reused instead of recomputed.
    "select_cache_step_func": lambda step: step in [9, 11, 12],
}]
# expanduser("~") falls back gracefully when $HOME is unset (os.environ["HOME"]
# would raise KeyError); when HOME is set it returns the same value.
HOME = os.path.expanduser("~")
|
|
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    """Per-step callback that switches off classifier-free guidance late in denoising.

    At 78% of the way through the schedule, the negative-prompt half of each
    conditioning tensor is dropped (keeping only the positive half via
    ``chunk(2)[-1]``) and the pipeline's guidance scale is lowered to 0.1,
    so the remaining steps run effectively without CFG.

    Args:
        pipe: The running diffusion pipeline (reads ``num_timesteps``,
            writes ``_guidance_scale``).
        step_index: Zero-based index of the current denoising step.
        timestep: Scheduler timestep (unused).
        callback_kwargs: Mutable dict of conditioning tensors for this step.

    Returns:
        The (possibly modified) ``callback_kwargs``.
    """
    cfg_cutoff_step = int(pipe.num_timesteps * 0.78)
    if step_index == cfg_cutoff_step:
        for key in ('prompt_embeds', 'add_text_embeds', 'add_time_ids'):
            # Each tensor is [negative; positive] stacked on the batch dim;
            # keep only the positive conditioning.
            callback_kwargs[key] = callback_kwargs[key].chunk(2)[-1]
        pipe._guidance_scale = 0.1
    return callback_kwargs
|
|
def load_pipeline() -> StableDiffusionXLPipeline:
    """Build, compile, and warm up the SDXL pipeline that infer() serves with.

    The sequence below is order-sensitive:
      1. Load the fp16 newdream-sdxl-20 checkpoint onto CUDA.
      2. torch.compile the text encoders and VAE (cudagraphs backend), then
         run one generation to trigger compilation.
      3. Swap in a DDIM scheduler (wrapped in SchedulerWrapper) and replace
         the UNet with a prebuilt TensorRT engine.
      4. Warm up once more, then run five cached generations so
         cache_diffusion records block outputs before serving.

    Returns:
        The fully warmed StableDiffusionXLPipeline with caching prepared
        (but disabled; infer() re-enables it per request).
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stablediffusionapi/newdream-sdxl-20",
        torch_dtype=torch.float16,
        use_safetensors=True
    ).to("cuda")
    # Compile only the non-UNet submodules — the UNet itself is replaced
    # by a TensorRT engine further down.
    pipe.text_encoder = torch.compile(pipe.text_encoder, fullgraph=True, backend="cudagraphs")
    pipe.vae = torch.compile(pipe.vae, fullgraph=True, backend="cudagraphs")
    pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, fullgraph=True, backend="cudagraphs")
    # First warm-up run: forces torch.compile to trace and build its graphs.
    pipe(prompt, negative_prompt=neg_prompt)


    pipe.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipe.scheduler.config))
    pipe = pipe.to("cuda")

    # NOTE(review): prepare_loss(16) presumably sizes internal buffers for the
    # scheduler wrapper from loss.py — confirm against SchedulerWrapper.
    pipe.scheduler.prepare_loss(16)
    # Pinned snapshot of the prebuilt TensorRT UNet engine in the HF cache.
    ENGINE_PATH = f"{HOME}/.cache/huggingface/hub/models--jokerbit--newdream-sdxl-20-engine/snapshots/8f4557f59479ee5feb721132867da906ca7c8e44"
    try:
        # Best-effort cleanup: drop .gitattributes from the snapshot dir so
        # only engine files remain; ignore (but log) failure, e.g. if the
        # file was already removed.
        file_path = os.path.join(ENGINE_PATH, ".gitattributes")
        os.remove(file_path)
    except Exception as err:
        print(err)
        pass
    # Replace the diffusers UNet forward with the TensorRT engine.
    load_unet_trt(
        pipe.unet,
        engine_path=Path(ENGINE_PATH),
        batch_size=1,
    )


    # Warm up again now that the TRT UNet is active.
    pipe(prompt, negative_prompt=neg_prompt)
    # Prime the block-output cache: prepare with the default config, then run
    # five full generations with the same dynamic-CFG callback used at serve
    # time so the recorded cache matches serving behavior.
    cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipe)
    for _ in range(5):
        pipe(prompt, negative_prompt=neg_prompt,
             callback_on_step_end=callback_dynamic_cfg,
             callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids']
             )
    cachify.disable(pipe)
    return pipe
|
|
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for a request using the pre-warmed SDXL pipeline.

    Caching (cache_diffusion) is enabled for the duration of the generation
    and always disabled afterwards, even if the pipeline call raises —
    otherwise a failed request would leave caching enabled and corrupt
    subsequent generations.

    Args:
        request: Prompt, negative prompt, dimensions, and optional seed.
        pipeline: The warmed pipeline returned by load_pipeline().

    Returns:
        The generated PIL image.
    """
    if request.seed is None:
        generator = None
    else:
        # Seeded generator on the pipeline's device for reproducible output.
        generator = Generator(pipeline.device).manual_seed(request.seed)
    cachify.enable(pipeline)
    try:
        image = pipeline(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=generator,
            # Must stay at 13 steps: the cache config and the dynamic-CFG
            # cutoff were tuned for a 13-step schedule.
            num_inference_steps=13,
            eta=1.0,
            guidance_scale=5.0,
            guidance_rescale=0.0,
        ).images[0]
    finally:
        # Guarantee the cache is off again even when generation fails.
        cachify.disable(pipeline)
    return image
|
|