# pk2 / src /pipeline.py
# slobers's picture
# Upload folder using huggingface_hub
# 61f4807 verified
import os
import torch
from pathlib import Path
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, AutoencoderTiny
from autoencoder_kl import AutoencoderKL
from pipelines.models import TextToImageRequest
from torch import Generator
from cache_diffusion import cachify
from trt_pipeline.deploy import load_unet_trt
from loss import SchedulerWrapper
# Module-name fragments consulted by the filter predicate in SDXL_DEFAULT_CONFIG.
no_cache_blk = ["down_blocks.2", "up_blocks.0", "mid_block"]


def _name_hits_listed_block(name):
    """Return True when *name* contains any fragment from ``no_cache_blk``."""
    for fragment in no_cache_blk:
        if fragment in name:
            return True
    return False


def _is_selected_cache_step(step):
    """Return True for even step indices starting at 8."""
    if step % 2 != 0:
        return False
    return step >= 8


# Single-entry configuration list passed to ``cachify.prepare`` in
# ``load_pipeline``: one predicate over module names, one over step indices.
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _name_hits_listed_block,
        "select_cache_step_func": _is_selected_cache_step,
    }
]

# Read eagerly at import time; a missing HOME variable raises KeyError here.
HOME = os.environ["HOME"]
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    """Step-end callback that trims the conditioning inputs at one fixed step.

    When ``step_index`` equals ``int(pipe.num_timesteps * 0.75)``, each of
    ``prompt_embeds``, ``add_text_embeds`` and ``add_time_ids`` in
    ``callback_kwargs`` is replaced by the last of its two ``chunk(2)`` pieces
    and ``pipe._guidance_scale`` is set to 1.1. On every other step the
    mapping is handed back untouched.

    Args:
        pipe: Pipeline object exposing ``num_timesteps``; its private
            ``_guidance_scale`` attribute is overwritten on the trigger step.
        step_index: Index of the step that has just finished.
        timestep: Unused; present to satisfy the callback signature.
        callback_kwargs: Mapping holding the three entries named above.

    Returns:
        The same ``callback_kwargs`` mapping (mutated in place on the
        trigger step).
    """
    trigger_step = int(pipe.num_timesteps * 0.75)
    if step_index != trigger_step:
        return callback_kwargs

    # Presumably the last chunk is the conditional half of a
    # classifier-free-guidance batch -- TODO confirm against the pipeline.
    for key in ("prompt_embeds", "add_text_embeds", "add_time_ids"):
        pieces = callback_kwargs[key].chunk(2)
        callback_kwargs[key] = pieces[-1]

    # NOTE(review): a scale above 1 normally keeps guidance active even though
    # the embeddings were just halved -- verify this is intended.
    pipe._guidance_scale = 1.1
    return callback_kwargs
def load_pipeline() -> StableDiffusionXLPipeline:
    """Build the SDXL pipeline, attach the TRT UNet and run warm-up passes.

    Steps, in order:
      1. Load ``stablediffusionapi/newdream-sdxl-20`` in float16 from
         safetensors weights.
      2. Replace the scheduler with a ``SchedulerWrapper`` around a
         ``DDIMScheduler`` built from the original scheduler config, move the
         pipeline to CUDA and call ``prepare_loss(16)`` on the wrapper.
      3. Best-effort delete ``.gitattributes`` from the engine snapshot
         directory, then pass that directory to ``load_unet_trt``.
      4. Run one plain warm-up call, prepare and enable ``cachify``, run five
         more warm-up calls with ``callback_dynamic_cfg``, then disable
         ``cachify`` again.

    Returns:
        The configured pipeline, with ``cachify`` left disabled.
    """
    warmup_prompt = "insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus"

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stablediffusionapi/newdream-sdxl-20",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    pipe.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipe.scheduler.config))
    pipe = pipe.to("cuda")
    pipe.scheduler.prepare_loss(16)

    engine_dir = Path(
        f"{HOME}/.cache/huggingface/hub/models--slobers--cancer/snapshots/209cecbed645ffa913ebaefc115029021a0fa230"
    )

    # Best-effort cleanup: a missing or undeletable file is reported, not fatal.
    # NOTE(review): presumably load_unet_trt must not see this non-engine
    # file in the directory -- confirm.
    try:
        os.remove(engine_dir / ".gitattributes")
    except OSError as err:
        # os.remove only raises OSError subclasses; anything else is a real bug
        # and should not be swallowed here.
        print(err)

    load_unet_trt(
        pipe.unet,
        engine_path=engine_dir,
        batch_size=1,
    )

    # First warm-up pass runs before cachify is prepared or enabled.
    pipe(prompt=warmup_prompt)

    cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipe)
    for _ in range(5):
        pipe(
            prompt=warmup_prompt,
            callback_on_step_end=callback_dynamic_cfg,
            callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
        )
    cachify.disable(pipe)
    return pipe
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for *request* with ``cachify`` enabled for the call.

    Args:
        request: Supplies ``prompt``, ``negative_prompt``, ``width``,
            ``height`` and an optional ``seed``.
        pipeline: The pipeline returned by ``load_pipeline``.

    Returns:
        The first image of the pipeline output.

    Raises:
        Whatever the pipeline call raises; ``cachify`` is disabled again
        before the exception propagates.
    """
    # With no seed the pipeline gets no generator at all; otherwise a
    # generator on the pipeline's device is seeded with the requested value.
    generator = (
        None
        if request.seed is None
        else Generator(pipeline.device).manual_seed(request.seed)
    )

    cachify.enable(pipeline)
    try:
        image = pipeline(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=generator,
            num_inference_steps=10,
            # NOTE(review): end_cfg is not a documented argument of the stock
            # SDXL pipeline; presumably absorbed by **kwargs -- confirm.
            end_cfg=0.5,
            eta=1.0,
            guidance_scale=5.0,
            guidance_rescale=0.0,
        ).images[0]
    finally:
        # Always restore the disabled state so a failed request cannot leave
        # caching switched on for subsequent calls.
        cachify.disable(pipeline)
    return image