# File size: 2,179 bytes (commit 28d5949)
import torch
from pathlib import Path
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline
from pipelines.models import TextToImageRequest
from torch import Generator
from cache_diffusion import cachify
from trt_pipeline.deploy import load_unet_trt
# SDXL_DEFAULT_CONFIG below is a local override; the library default from
# cache_diffusion.utils is deliberately not imported.

# Seeded CUDA generator. Neither load_pipeline() nor infer() in this module
# reads it (infer() builds its own per-request generator).
generator = Generator(torch.device("cuda"))
generator.manual_seed(666)

# Prompt used for the throwaway generations in load_pipeline().
prompt = "future punk robot shooting"


def _sdxl_cache_filter(name):
    """Return True for module names that cachify may cache.

    A name is rejected if it contains "down_blocks.3" or "up_blocks.2".
    """
    # NOTE(review): diffusers' SDXL UNet is typically configured with three
    # down blocks (indices 0-2); if so, "down_blocks.3" never matches and only
    # the up_blocks.2 exclusion has any effect — confirm against the model.
    return "down_blocks.3" not in name and "up_blocks.2" not in name


def _sdxl_cache_step(step):
    """Return True for odd step indices that are 13 or greater."""
    # NOTE(review): presumably these are the steps on which cachify reuses
    # cached outputs instead of recomputing — confirm in cache_diffusion.
    return (step % 2 != 0) and (step >= 13)


SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _sdxl_cache_filter,
        "select_cache_step_func": _sdxl_cache_step,
    }
]


def load_pipeline() -> StableDiffusionXLPipeline:
    """Load the SDXL pipeline, attach the TensorRT UNet, and prime the cache.

    Steps:
      1. Load fp16 safetensors weights from the local
         ``models/newdream-sdxl-20`` directory (``local_files_only=True``)
         and move the pipeline to CUDA.
      2. Pass ``pipe.unet`` to ``load_unet_trt`` with engines under
         ``./engine`` and a batch size of 1.
      3. Register the module-level ``SDXL_DEFAULT_CONFIG`` via
         ``cachify.prepare``.
      4. With caching enabled, run two 22-step generations of the
         module-level ``prompt`` and discard their output (presumably a
         warm-up), then disable caching.

    Returns:
        The prepared pipeline, with cachify caching disabled.
    """
    model_dir = "models/newdream-sdxl-20"
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_dir,
        torch_dtype=torch.float16,
        use_safetensors=True,
        local_files_only=True,
    ).to("cuda")

    # NOTE(review): presumably routes the UNet through a TensorRT engine
    # stored under ./engine — confirm in trt_pipeline.deploy.
    load_unet_trt(pipe.unet, engine_path=Path("./engine"), batch_size=1)

    cachify.prepare(pipe, SDXL_DEFAULT_CONFIG)

    # Two throwaway passes with caching on; results are not kept.
    cachify.enable(pipe)
    with cachify.infer(pipe) as cached_pipe:
        for _ in range(2):
            cached_pipe(prompt=prompt, num_inference_steps=22)
    cachify.disable(pipe)

    return pipe


def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for ``request`` with cachify caching enabled.

    Args:
        request: Supplies ``prompt``, ``negative_prompt``, ``width``,
            ``height`` and an optional ``seed``.
        pipeline: Pipeline produced by ``load_pipeline`` (already passed
            through ``cachify.prepare``).

    Returns:
        The first image of the pipeline output (22 inference steps).

    Caching is switched on for the duration of the call and always switched
    off again, even if generation raises.
    """
    # A seeded generator on the pipeline's device makes a seeded request
    # reproducible; with no seed, generator=None is passed through unchanged.
    generator = (
        None
        if request.seed is None
        else Generator(pipeline.device).manual_seed(request.seed)
    )

    cachify.enable(pipeline)
    try:
        with cachify.infer(pipeline) as cached_pipe:
            output = cached_pipe(
                prompt=request.prompt,
                negative_prompt=request.negative_prompt,
                width=request.width,
                height=request.height,
                generator=generator,
                num_inference_steps=22,
            )
    finally:
        # The pipeline object outlives this call. Previously an exception
        # (e.g. an invalid width/height) skipped this line and left caching
        # enabled on it.
        cachify.disable(pipeline)
    return output.images[0]
|