# edge-06 / src/pipeline.py
# Origin note: initial commit with folder contents (agentbot, 28d5949, verified)
import torch
from pathlib import Path
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline
from pipelines.models import TextToImageRequest
from torch import Generator
from cache_diffusion import cachify
from trt_pipeline.deploy import load_unet_trt
# from cache_diffusion.utils import SDXL_DEFAULT_CONFIG
generator = Generator(torch.device("cuda")).manual_seed(666)
prompt = "future punk robot shooting"
SDXL_DEFAULT_CONFIG = [
{
"wildcard_or_filter_func": lambda name: "down_blocks.3" not in name and "up_blocks.2" not in name,
"select_cache_step_func": lambda step: (step % 2 != 0) and (step >= 13),
}]
def load_pipeline() -> StableDiffusionXLPipeline:
pipe = StableDiffusionXLPipeline.from_pretrained(
"models/newdream-sdxl-20", torch_dtype=torch.float16, use_safetensors=True, local_files_only=True
).to("cuda")
# pipe(prompt, generator=generator, num_inference_steps=22)
# pipe.fuse_qkv_projections()
# pipe.vae = torch.compile(pipe.vae, backend="cudagraphs", fullgraph=True)
# pipe.text_encoder = torch.compile(pipe.text_encoder, backend="cudagraphs", fullgraph=True)
load_unet_trt(
pipe.unet,
engine_path=Path("./engine"),
batch_size=1,
)
cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)
cachify.enable(pipe)
with cachify.infer(pipe) as cached_pipe:
cached_pipe(prompt=prompt, num_inference_steps=22)
cached_pipe(prompt=prompt, num_inference_steps=22)
cachify.disable(pipe)
return pipe
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Run one cached text-to-image generation for *request*.

    Seeds a per-request generator when ``request.seed`` is set, enables the
    diffusion cache for the duration of the call, and returns the single
    generated PIL image.
    """
    generator = (
        None
        if request.seed is None
        else Generator(pipeline.device).manual_seed(request.seed)
    )
    cachify.enable(pipeline)
    with cachify.infer(pipeline) as cached_pipe:
        result = cached_pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=generator,
            num_inference_steps=22,
        )
    cachify.disable(pipeline)
    return result.images[0]