# Author: chansung
# Commit 72c9a82: "Create handler.py"
import base64
import copy
import logging
import queue
import sys
from typing import Dict, List, Any

import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
class ReusablePipePool:
    """Fixed-size pool of preloaded Stable Diffusion pipelines.

    Callers check a pipeline out with :meth:`acquire` and must hand it back
    with :meth:`release` when done. Storage is a :class:`queue.LifoQueue`, so
    the pool can be shared between threads. When every pipeline is in use,
    ``acquire`` waits for one to be released instead of failing.

    Attributes:
        original_unet: Deep copy of the first pipeline's UNet, taken at load
            time before any pipeline is handed out. Callers use it to restore
            the base UNet on a pooled pipeline. ``None`` when the pool was
            built with ``size == 0``.
    """

    def __init__(
        self,
        size,
        model_base="runwayml/stable-diffusion-v1-5"
    ):
        """Load ``size`` pipelines from ``model_base``.

        Args:
            size: Number of pipelines to load; each one is a separate
                ``from_pretrained`` load.
            model_base: Model id or path passed to
                ``StableDiffusionPipeline.from_pretrained``.
        """
        # LIFO hands out the most recently released pipeline first. This
        # matches the original list-based pop()/append() ordering.
        self._reusablePipes = queue.LifoQueue()
        self.original_unet = None
        for _ in range(size):
            pipe = StableDiffusionPipeline.from_pretrained(
                model_base, torch_dtype=torch.float16
            )
            pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                pipe.scheduler.config
            )
            if self.original_unet is None:
                # Snapshot the first pipeline's UNet before anyone can modify it.
                self.original_unet = copy.deepcopy(pipe.unet)
            self._reusablePipes.put(pipe)

    def acquire(self, timeout=None):
        """Check out a pipeline, waiting if all of them are in use.

        Args:
            timeout: Seconds to wait for a pipeline; ``None`` waits forever.

        Returns:
            A pipeline that must later be passed to :meth:`release`.

        Raises:
            queue.Empty: If ``timeout`` elapses with no pipeline available.
        """
        return self._reusablePipes.get(timeout=timeout)

    def release(self, reusablePipe):
        """Return a pipeline obtained from :meth:`acquire` to the pool."""
        self._reusablePipes.put(reusablePipe)

    def empty(self):
        """Return True if no pipeline is currently available.

        Under concurrent use the result is only a snapshot. Prefer
        :meth:`acquire` with a ``timeout`` over polling this method.
        """
        return self._reusablePipes.empty()
class EndpointHandler:
    """Request handler that renders prompts with a pool of Stable Diffusion pipelines.

    NOTE(review): the ``__init__(path)`` / ``__call__(data)`` shape looks like
    the Hugging Face Inference Endpoints custom-handler contract — confirm.
    """

    def __init__(self, path=""):
        """Load a pool of two pipelines.

        Args:
            path: Accepted for interface compatibility. It is unused here; the
                pool loads its own default ``model_base``.
        """
        self.pool = ReusablePipePool(2)

    def _generate_images(
        self,
        model_path,
        prompt,
        num_inference_steps=25,
        guidance_scale=7.5,
        num_images_per_prompt=1,
    ):
        """Run one pooled pipeline and return its images stacked into an array.

        Args:
            model_path: ``"base"`` restores the pool's original UNet. Any other
                value is passed to ``unet.load_attn_procs``.
            prompt: Text prompt forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline call.
            guidance_scale: Forwarded to the pipeline call.
            num_images_per_prompt: Forwarded to the pipeline call.

        Returns:
            ``np.ndarray`` of the per-image arrays stacked on a new leading axis.
        """
        # Take exactly one pipe. The previous `while not empty(): acquire()`
        # loop could drain (and leak) the whole pool or never terminate.
        pipe = self.pool.acquire()
        try:
            if model_path == "base":
                # Replace the UNet with a fresh copy so attention processors
                # loaded on this pooled pipe by earlier requests are dropped.
                pipe.unet = copy.deepcopy(self.pool.original_unet)
            else:
                pipe.unet.load_attn_procs(model_path)
            pipe.to("cuda")
            pil_images = pipe(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images_per_prompt,
            ).images
        finally:
            # Always hand the pipe back, even if loading or inference fails,
            # so the pool does not permanently shrink.
            self.pool.release(pipe)
        return np.stack([np.asarray(image) for image in pil_images], axis=0)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate images for a request payload and return them base64-encoded.

        Args:
            data: Request payload. All keys are optional and are read, not
                removed:
                ``"inputs"``: prompt (default ``"test image"``).
                ``"model_path"``: default ``"base"``.
                ``"num_inference_steps"``: default 25.
                ``"guidance_scale"``: default 7.5.
                ``"num_images_per_prompt"``: default 1.

        Returns:
            Base64 text of the raw bytes of the stacked image array. Shape and
            dtype are not included, so the client must already know them to
            decode. They are presumably uint8 (num_images, H, W, 3) for RGB
            output — confirm.
        """
        prompt = data.get("inputs", "test image")
        model_path = data.get("model_path", "base")
        num_inference_steps = data.get("num_inference_steps", 25)
        guidance_scale = data.get("guidance_scale", 7.5)
        num_images_per_prompt = data.get("num_images_per_prompt", 1)
        images = self._generate_images(
            model_path,
            prompt,
            num_inference_steps,
            guidance_scale,
            num_images_per_prompt,
        )
        return base64.b64encode(images.tobytes()).decode()