|
|
import torch |
|
|
from pathlib import Path |
|
|
from PIL.Image import Image |
|
|
from diffusers import StableDiffusionXLPipeline, DDIMScheduler |
|
|
from pipelines.models import TextToImageRequest |
|
|
from torch import Generator |
|
|
from cache_diffusion import cachify |
|
|
from trt_pipeline.deploy import load_unet_trt |
|
|
from loss import SchedulerWrapper |
|
|
import numpy as np |
|
|
def pixel_filter(image: Image) -> Image:
    """Return a copy of ``image`` whose minimum-valued entries are raised by one.

    The image is converted to a NumPy array. Every element (individual channel
    value) equal to the array-wide minimum is incremented by 1, and the array
    is converted back into a PIL image.

    This is best-effort. The original ``image`` is returned unchanged when the
    adjustment cannot be applied, for example:

    - the image is empty;
    - an integer array is already saturated at its dtype maximum;
    - the dtype rejects an in-place ``+= 1``;
    - any other step fails.

    Args:
        image: PIL image to adjust.

    Returns:
        A new PIL image with the adjustment applied, or ``image`` itself.
    """
    # `Image` in this module is the PIL *class* (``from PIL.Image import Image``),
    # which has no ``fromarray`` attribute; that factory is a module-level
    # function of ``PIL.Image``. Calling ``Image.fromarray`` always raised
    # AttributeError, which hid the fact that the filter never took effect.
    from PIL.Image import fromarray

    try:
        img_array = np.array(image)
        if img_array.size == 0:
            # `.min()` is undefined on an empty array; nothing to adjust.
            return image

        # NOTE(review): this value was previously named `max_val` but has
        # always been the minimum; the computed (minimum) behaviour is kept.
        min_val = img_array.min()

        # Incrementing an integer already at its dtype maximum wraps around
        # (e.g. an all-white uint8 image would turn black); skip in that case.
        if (
            np.issubdtype(img_array.dtype, np.integer)
            and min_val == np.iinfo(img_array.dtype).max
        ):
            return image

        img_array[img_array == min_val] += 1
        return fromarray(img_array)
    except Exception:
        # Deliberately best-effort: a failure in this cosmetic filter must not
        # fail inference. `Exception` (not a bare except) lets
        # KeyboardInterrupt / SystemExit propagate.
        return image
|
|
|
|
|
|
|
|
# Module-level torch Generator on the "cuda" device, seeded with 69.
# Creating it touches CUDA at import time.
# NOTE(review): within this file `infer` binds its own local `generator`, so
# nothing visible here reads this object; confirm whether other modules
# import it before removing it.
generator = Generator(device=torch.device("cuda")).manual_seed(69)
|
|
|
|
|
# Caching configuration passed to `cachify.prepare` by this module.
# The list holds a single entry with two predicates:
#   "wildcard_or_filter_func": given a module name, returns True unless the
#       name contains "down_blocks.2", "down_blocks.3" or "up_blocks.2"
#       (plain substring match).
#   "select_cache_step_func": given a step index, returns True for odd values
#       that are >= 10 (11, 13, 15, ...).
# NOTE(review): how cache_diffusion interprets these predicates (which modules
# get wrapped, which steps reuse cached outputs) is defined in that package;
# confirm there.
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": lambda name: not any(
            excluded in name
            for excluded in ("down_blocks.2", "down_blocks.3", "up_blocks.2")
        ),
        "select_cache_step_func": lambda step: (step >= 10) and (step % 2 != 0),
    },
]
|
|
def load_pipeline() -> StableDiffusionXLPipeline:
    """Build the SDXL pipeline used by `infer`, run warm-up passes, return it.

    Steps, in order:
      1. Load fp16 safetensors weights from the local "models/newdream-sdxl-20"
         directory (no network access) and move the pipeline to CUDA.
      2. Call `load_unet_trt` on the UNet with engine path "./engine" and batch
         size 1. Presumably this swaps in a TensorRT engine; confirm in
         trt_pipeline.deploy.
      3. Prepare and enable cache_diffusion with `SDXL_DEFAULT_CONFIG`.
      4. Replace the scheduler with a `SchedulerWrapper` around a DDIMScheduler
         built from the original scheduler's config.
      5. Run four 15-step generations of "a superman" inside `cachify.infer`
         and discard the results (warm-up).
      6. Disable caching, then call `prepare_loss()` on the wrapped scheduler.

    Returns:
        The configured StableDiffusionXLPipeline.
    """
    model_path = "models/newdream-sdxl-20"
    sdxl = StableDiffusionXLPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        use_safetensors=True,
        local_files_only=True,
    )
    sdxl = sdxl.to("cuda")

    engine_dir = Path("./engine")
    load_unet_trt(sdxl.unet, engine_path=engine_dir, batch_size=1)

    cachify.prepare(sdxl, SDXL_DEFAULT_CONFIG)
    cachify.enable(sdxl)

    ddim = DDIMScheduler.from_config(sdxl.scheduler.config)
    sdxl.scheduler = SchedulerWrapper(ddim)

    # Warm-up generations; outputs are intentionally discarded.
    warmup_rounds = 4
    with cachify.infer(sdxl):
        for _round in range(warmup_rounds):
            sdxl(prompt="a superman", num_inference_steps=15)

    cachify.disable(sdxl)
    # NOTE(review): `prepare_loss` is a SchedulerWrapper method from the local
    # `loss` module; its effect is not visible from this file.
    sdxl.scheduler.prepare_loss()
    return sdxl
|
|
|
|
|
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for ``request`` with cache_diffusion enabled.

    When ``request.seed`` is None no generator is passed, so sampling is not
    reproducible. Otherwise a torch Generator on ``pipeline.device`` is seeded
    with it.

    The pipeline runs for 15 inference steps with the request's prompt,
    negative prompt, width and height. The first output image is passed
    through `pixel_filter` before being returned.

    Args:
        request: Text-to-image request (prompt, negative_prompt, width,
            height, seed).
        pipeline: Pipeline produced by `load_pipeline`.

    Returns:
        The filtered PIL image.
    """
    seeded_generator = (
        None
        if request.seed is None
        else Generator(pipeline.device).manual_seed(request.seed)
    )

    # NOTE(review): caching is prepared and enabled on every call and never
    # disabled afterwards; confirm that `cachify.prepare` is safe to call
    # repeatedly on the same pipeline.
    cachify.prepare(pipeline, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipeline)

    with cachify.infer(pipeline) as cached_pipe:
        output = cached_pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=seeded_generator,
            num_inference_steps=15,
        )
        image = output.images[0]

    return pixel_filter(image)
|
|
|