| import os | |
| import torch | |
| from pathlib import Path | |
| from PIL.Image import Image | |
| from diffusers import StableDiffusionXLPipeline, DDIMScheduler, AutoencoderTiny | |
| from autoencoder_kl import AutoencoderKL | |
| from pipelines.models import TextToImageRequest | |
| from torch import Generator | |
| from cache_diffusion import cachify | |
| from trt_pipeline.deploy import load_unet_trt | |
| from loss import SchedulerWrapper | |
# Name fragments tested by the filter function in SDXL_DEFAULT_CONFIG.
no_cache_blk = ["down_blocks.2", "up_blocks.0", "mid_block"]


def _matches_no_cache_blk(name):
    """Return True when *name* contains any entry of ``no_cache_blk``."""
    for fragment in no_cache_blk:
        if fragment in name:
            return True
    return False


def _is_cache_step(step):
    """Return True for even step indices starting at 8."""
    return step % 2 == 0 and step >= 8


# Configuration handed to ``cachify.prepare``: one entry pairing a module-name
# filter with a step-selection predicate.
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _matches_no_cache_blk,
        "select_cache_step_func": _is_cache_step,
    },
]

# Home directory as given by the HOME environment variable; importing this
# module raises KeyError when HOME is not set.
HOME = os.environ["HOME"]
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    """Step-end callback that halves the text conditioning at one fixed step.

    When ``step_index`` equals ``int(pipe.num_timesteps * 0.75)``, each of
    ``prompt_embeds``, ``add_text_embeds`` and ``add_time_ids`` in
    ``callback_kwargs`` is replaced by the last of its two ``chunk(2)`` pieces
    and ``pipe._guidance_scale`` is set to 1.1. On every other step the
    arguments are passed through untouched.

    Args:
        pipe: Pipeline object; ``num_timesteps`` is read and
            ``_guidance_scale`` may be written.
        step_index: Index of the step that just finished.
        timestep: Unused.
        callback_kwargs: Mapping holding the three tensors named above.

    Returns:
        The same ``callback_kwargs`` mapping (mutated in place on the
        switch step).
    """
    switch_step = int(pipe.num_timesteps * 0.75)
    if step_index != switch_step:
        return callback_kwargs

    # Presumably the first chunk is the unconditional half and the last the
    # prompt-conditioned half — TODO confirm against the pipeline in use.
    for key in ('prompt_embeds', 'add_text_embeds', 'add_time_ids'):
        callback_kwargs[key] = callback_kwargs[key].chunk(2)[-1]

    # NOTE(review): stock diffusers treats guidance_scale > 1 as "CFG on";
    # 1.1 therefore still counts as enabled after the embeddings are halved —
    # confirm this value is intended for the pipeline variant used here.
    pipe._guidance_scale = 1.1
    return callback_kwargs
def load_pipeline() -> StableDiffusionXLPipeline:
    """Build the SDXL pipeline, attach the TRT UNet and run warm-up passes.

    Steps, in order:
      1. Load ``stablediffusionapi/newdream-sdxl-20`` in float16.
      2. Replace the scheduler with ``SchedulerWrapper`` around a
         ``DDIMScheduler`` built from the original scheduler config, move the
         pipeline to CUDA and call ``prepare_loss()`` on the scheduler.
      3. Best-effort delete ``.gitattributes`` from the engine snapshot
         directory, then pass that directory to ``load_unet_trt``.
      4. Run one plain generation, prepare and enable ``cachify``, run five
         more generations with ``callback_dynamic_cfg``, and disable
         ``cachify`` again.

    Returns:
        The pipeline after warm-up, with ``cachify`` disabled.
    """
    warmup_prompt = "insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus"

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stablediffusionapi/newdream-sdxl-20",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    pipe.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipe.scheduler.config))
    pipe = pipe.to("cuda")
    pipe.scheduler.prepare_loss()

    engine_dir = f"{HOME}/.cache/huggingface/hub/models--slobers--cancer/snapshots/209cecbed645ffa913ebaefc115029021a0fa230"

    # Best-effort cleanup: a missing or undeletable file is reported but must
    # not abort loading. os.remove only raises OSError subclasses, so the
    # handler is limited to that family instead of swallowing every Exception.
    # Presumably the loader should only see engine files — TODO confirm.
    try:
        os.remove(os.path.join(engine_dir, ".gitattributes"))
    except OSError as err:
        print(err)

    load_unet_trt(
        pipe.unet,
        engine_path=Path(engine_dir),
        batch_size=1,
    )

    # One generation before cachify is prepared, then five with it enabled.
    pipe(prompt=warmup_prompt)
    cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipe)
    try:
        for _ in range(5):
            pipe(
                prompt=warmup_prompt,
                callback_on_step_end=callback_dynamic_cfg,
                callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
            )
    finally:
        # Always leave the pipeline with caching switched off, even when a
        # warm-up pass raises.
        cachify.disable(pipe)
    return pipe
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate a single image for *request* using *pipeline*.

    A ``torch.Generator`` on ``pipeline.device`` is seeded with
    ``request.seed`` when a seed is given; otherwise no generator is passed.
    ``cachify`` is enabled for the duration of the call and is disabled again
    afterwards, including when the pipeline raises, so a failed request does
    not leave the shared pipeline in the cache-enabled state.

    Args:
        request: Supplies ``prompt``, ``negative_prompt``, ``width``,
            ``height`` and ``seed``.
        pipeline: The pipeline to run.

    Returns:
        The first image of the pipeline output.
    """
    generator = None
    if request.seed is not None:
        generator = Generator(pipeline.device).manual_seed(request.seed)

    cachify.enable(pipeline)
    try:
        # NOTE(review): ``end_cfg`` is not a documented argument of the stock
        # StableDiffusionXLPipeline — presumably consumed by a patched
        # pipeline; confirm it is not silently ignored.
        image = pipeline(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=generator,
            num_inference_steps=14,
            end_cfg=0.5,
            eta=1.0,
            guidance_scale=5.0,
            guidance_rescale=0.0,
        ).images[0]
    finally:
        cachify.disable(pipeline)
    return image