File size: 2,604 Bytes
1004df5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
from pathlib import Path
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from pipelines.models import TextToImageRequest
from torch import Generator
from cache_diffusion import cachify
from trt_pipeline.deploy import load_unet_trt
from loss import SchedulerWrapper
import numpy as np
def pixel_filter(image: Image) -> Image:
    """Increment every pixel equal to the image's minimum value by one.

    The image is converted to a NumPy array. All elements equal to the array's
    minimum are incremented by 1, and the array is converted back to a PIL
    image.

    The image is returned unchanged when:
      * it is empty;
      * its minimum already equals the largest value its integer dtype can
        hold (e.g. an all-255 ``uint8`` image), since the increment would wrap
        around to 0;
      * the in-place increment or the conversion back to an image fails with
        ``TypeError``/``ValueError`` (best-effort post-processing).

    Args:
        image: The image to filter.

    Returns:
        A new filtered image, or ``image`` itself when no change is made.
    """
    # ``fromarray`` is a function of the ``PIL.Image`` *module*; the ``Image``
    # name imported at file level is the ``PIL.Image.Image`` class, which does
    # not provide it.
    from PIL import Image as pil_image_module

    try:
        img_array = np.array(image)  # np.array copies, so in-place edits are safe
        if img_array.size == 0:
            return image
        min_val = img_array.min()
        # Incrementing the dtype's maximum would wrap (255 -> 0 for uint8).
        if (
            np.issubdtype(img_array.dtype, np.integer)
            and min_val == np.iinfo(img_array.dtype).max
        ):
            return image
        img_array[img_array == min_val] += 1
        return pil_image_module.fromarray(img_array)
    except (TypeError, ValueError):
        # Dtypes that cannot take the in-place increment, or that PIL cannot
        # convert back, fall back to the unmodified image.
        return image
# Module-level CUDA generator seeded with 69.
# NOTE(review): nothing visible in this file reads this global; ``infer``
# creates its own per-request generator. It is kept only in case another
# module imports it. Confirm, and consider removing it to avoid allocating a
# CUDA generator at import time.
generator = Generator(device=torch.device("cuda")).manual_seed(69)
def _cacheable_block(name):
    """Return True when *name* contains none of the excluded block prefixes.

    Modules whose name includes ``down_blocks.2``, ``down_blocks.3`` or
    ``up_blocks.2`` are filtered out. Matching is a plain substring test.
    NOTE(review): if the SDXL UNet only has down_blocks.0-2, the
    ``down_blocks.3`` entry never matches -- confirm.
    """
    excluded_prefixes = ("down_blocks.2", "down_blocks.3", "up_blocks.2")
    return all(prefix not in name for prefix in excluded_prefixes)


def _is_cache_step(step):
    """Return True for odd step indices that are 10 or greater.

    If step indices run 0-14 (15 inference steps), this selects 11 and 13.
    """
    return step >= 10 and step % 2 != 0


# Configuration handed to ``cachify.prepare``: one entry pairing a module-name
# filter with a step selector (presumably the steps on which cached outputs are
# reused -- semantics live in cache_diffusion).
SDXL_DEFAULT_CONFIG = [
    {
        "wildcard_or_filter_func": _cacheable_block,
        "select_cache_step_func": _is_cache_step,
    }
]
def _run_warmup_passes(pipe: StableDiffusionXLPipeline, passes: int = 4) -> None:
    """Run *passes* throwaway 15-step generations inside ``cachify.infer``.

    NOTE(review): the generations call ``pipe`` directly rather than the object
    yielded by ``cachify.infer(pipe)``, whereas ``infer`` uses the yielded
    object. If ``cachify.infer`` ever yields a distinct wrapper, warm-up would
    bypass it -- confirm against cache_diffusion.
    """
    with cachify.infer(pipe):
        for _ in range(passes):
            pipe(prompt="a superman", num_inference_steps=15)


def load_pipeline() -> StableDiffusionXLPipeline:
    """Build and warm up the SDXL pipeline served by ``infer``.

    Steps, in order:
      1. Load the local ``models/newdream-sdxl-20`` weights in fp16 (safetensors,
         no network) and move them to CUDA.
      2. Attach the UNet engine from ``./engine`` via ``load_unet_trt``
         (presumably TensorRT, per the module name) with batch size 1.
      3. Prepare and enable cachify with ``SDXL_DEFAULT_CONFIG``.
      4. Replace the scheduler with a ``SchedulerWrapper`` around a DDIM
         scheduler built from the existing scheduler config.
      5. Run four warm-up generations, then disable cachify.
      6. Call the wrapper's ``prepare_loss()``.

    Returns:
        The configured pipeline, with cachify disabled.
    """
    sdxl = StableDiffusionXLPipeline.from_pretrained(
        "models/newdream-sdxl-20",
        torch_dtype=torch.float16,
        use_safetensors=True,
        local_files_only=True,
    )
    sdxl = sdxl.to("cuda")
    load_unet_trt(sdxl.unet, engine_path=Path("./engine"), batch_size=1)

    cachify.prepare(sdxl, SDXL_DEFAULT_CONFIG)
    cachify.enable(sdxl)

    ddim_scheduler = DDIMScheduler.from_config(sdxl.scheduler.config)
    sdxl.scheduler = SchedulerWrapper(ddim_scheduler)

    _run_warmup_passes(sdxl)

    cachify.disable(sdxl)
    sdxl.scheduler.prepare_loss()
    return sdxl
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for *request* and post-process it with ``pixel_filter``.

    When ``request.seed`` is set, sampling uses a generator on
    ``pipeline.device`` seeded with it. Otherwise ``generator=None`` is passed.
    Cachify is prepared with ``SDXL_DEFAULT_CONFIG`` and enabled on every call,
    and this function does not disable it afterwards. Sampling always uses 15
    inference steps.

    Args:
        request: Prompt, negative prompt, size and optional seed.
        pipeline: Pipeline produced by ``load_pipeline``.

    Returns:
        The first generated image after ``pixel_filter``.
    """
    seeded_rng = (
        None
        if request.seed is None
        else Generator(pipeline.device).manual_seed(request.seed)
    )

    cachify.prepare(pipeline, SDXL_DEFAULT_CONFIG)
    cachify.enable(pipeline)

    with cachify.infer(pipeline) as cached_pipe:
        output = cached_pipe(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            generator=seeded_rng,
            num_inference_steps=15,
        )
        image = output.images[0]

    return pixel_filter(image)
|