edge-15 / src /pipeline.py
agentbot's picture
Initial commit with folder contents
1004df5 verified
import torch
from pathlib import Path
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from pipelines.models import TextToImageRequest
from torch import Generator
from cache_diffusion import cachify
from trt_pipeline.deploy import load_unet_trt
from loss import SchedulerWrapper
import numpy as np
def pixel_filter(image: Image) -> Image:
    """Return a copy of *image* whose minimum-valued samples are bumped by one.

    The image is converted to a NumPy array. Every element equal to the
    array's global minimum (across all channels) is incremented by 1, and the
    result is converted back to a PIL image.

    The original image is returned unchanged in these cases:
      * incrementing would overflow an integer dtype (e.g. a uint8 image whose
        minimum is already 255);
      * the image cannot be round-tripped through NumPy. Examples are an empty
        image, where ``min()`` raises ``ValueError``, or a dtype/mode that
        NumPy or Pillow rejects with ``TypeError``.

    Args:
        image: PIL image to filter.

    Returns:
        A new, filtered PIL image, or ``image`` itself when filtering is not
        possible.
    """
    # The module-level name ``Image`` is the PIL *class*, which has no
    # ``fromarray``; that factory lives on the ``PIL.Image`` module.
    from PIL import Image as pil_image

    try:
        img_array = np.array(image)
        min_val = img_array.min()
        # ``+= 1`` on a saturated integer value wraps around (255 -> 0 for
        # uint8), which would invert the image instead of nudging it.
        if (
            np.issubdtype(img_array.dtype, np.integer)
            and min_val == np.iinfo(img_array.dtype).max
        ):
            return image
        img_array[img_array == min_val] += 1
        return pil_image.fromarray(img_array)
    except (TypeError, ValueError):
        # Best-effort filter: fall back to the unmodified image.
        return image
# Seed for the module-level CUDA generator below.
_DEFAULT_SEED = 69
# Module-level seeded CUDA generator, created at import time.
# NOTE(review): nothing visible in this module reads it (``infer`` builds its
# own per-request generator). It is kept in case other modules import it.
generator = Generator(device=torch.device("cuda")).manual_seed(_DEFAULT_SEED)
# Name fragments whose blocks are excluded from caching. The names are
# matched by cache_diffusion's cachify; presumably they are UNet
# block/module names.
_UNCACHED_BLOCK_MARKERS = ("down_blocks.2", "down_blocks.3", "up_blocks.2")


def _cacheable_block(name):
    """Return True when *name* contains none of ``_UNCACHED_BLOCK_MARKERS``."""
    return all(marker not in name for marker in _UNCACHED_BLOCK_MARKERS)


def _cache_this_step(step):
    """Return True for odd step indices of 10 or more.

    With the 15-step runs used in this module, that selects steps 11 and 13.
    """
    return (step % 2 != 0) and (step >= 10)


# cache_diffusion configuration consumed by ``cachify.prepare``: one entry
# pairing a block-name filter with a per-step cache selector.
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _cacheable_block,
        "select_cache_step_func": _cache_this_step,
    },
]
def load_pipeline() -> StableDiffusionXLPipeline:
    """Build and warm up the SDXL pipeline used by ``infer``.

    Steps, in order (the order is significant and preserved):
      1. Load the local fp16 safetensors checkpoint
         ``models/newdream-sdxl-20`` and move it to CUDA.
      2. Hand the UNet to ``load_unet_trt`` with the engines under
         ``./engine`` and batch size 1.
      3. Prepare and enable cache_diffusion caching with
         ``SDXL_DEFAULT_CONFIG``.
      4. Replace the scheduler with a ``SchedulerWrapper`` around a
         ``DDIMScheduler`` built from the original scheduler's config.
      5. Run four 15-step warm-up generations inside ``cachify.infer``.
      6. Disable caching, then call ``prepare_loss()`` on the wrapped
         scheduler (semantics are defined in ``loss.SchedulerWrapper``).

    Returns:
        The configured pipeline, with caching disabled.
    """
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "models/newdream-sdxl-20",
        torch_dtype=torch.float16,
        use_safetensors=True,
        local_files_only=True,
    )
    pipeline = pipeline.to("cuda")

    load_unet_trt(pipeline.unet, engine_path=Path("./engine"), batch_size=1)

    cachify.prepare(pipeline, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipeline)

    ddim = DDIMScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler = SchedulerWrapper(ddim)

    # Warm-up passes; the outputs are discarded.
    with cachify.infer(pipeline):
        for _warmup_round in range(4):
            pipeline(prompt="a superman", num_inference_steps=15)

    cachify.disable(pipeline)
    pipeline.scheduler.prepare_loss()
    return pipeline
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for *request* and return it after ``pixel_filter``.

    When ``request.seed`` is set, a generator on ``pipeline.device`` is seeded
    with it; otherwise generation runs unseeded. Caching is prepared and
    enabled on every call, because ``load_pipeline`` leaves it disabled after
    warm-up.

    Args:
        request: Prompt, negative prompt, width, height and optional seed.
        pipeline: Pipeline returned by ``load_pipeline``.

    Returns:
        The first generated image, passed through ``pixel_filter``.
    """
    seed = request.seed
    rng = None if seed is None else Generator(pipeline.device).manual_seed(seed)

    # NOTE(review): prepare runs on every request; confirm cachify.prepare is
    # idempotent (i.e. it does not re-wrap already-cached modules).
    cachify.prepare(pipeline, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipeline)

    with cachify.infer(pipeline) as cached_pipeline:
        output = cached_pipeline(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=rng,
            num_inference_steps=15,
        )
        first_image = output.images[0]

    return pixel_filter(first_image)