File size: 3,652 Bytes
842c50e 95ddcb0 842c50e 95ddcb0 c6aeea5 842c50e c6aeea5 95ddcb0 7a0cd96 95ddcb0 c6aeea5 842c50e 7c9208c 2e53069 7c9208c 2e53069 842c50e 7a0cd96 95ddcb0 842c50e 7a0cd96 842c50e 7a0cd96 842c50e 7a0cd96 842c50e 95ddcb0 7c9208c 842c50e 95ddcb0 c6aeea5 7a0cd96 842c50e dfa159a 95ddcb0 dfa159a 95ddcb0 dfa159a 7c9208c dfa159a 7a0cd96 842c50e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from typing import Dict, List, Any
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
import base64
from io import BytesIO
import numpy as np
# from RealESRGAN import RealESRGAN
# Resolve the compute device up front; this handler is GPU-only and
# refuses to start without CUDA (fp16 pipeline below assumes it).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type != "cuda":
    raise ValueError("need to run on GPU")
class EndpointHandler():
    """Inference-endpoint wrapper around StableDiffusionInpaintPipeline.

    Loads the pipeline once at construction time and serves inpainting
    requests through ``__call__``, returning generated images as base64
    PNG strings keyed by their index.
    """

    def __init__(self, path: str = ""):
        # Load the inpainting pipeline in fp16 to halve GPU memory use.
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            path, torch_dtype=torch.float16
        )
        # Swap in the Euler-Ancestral sampler, preserving the scheduler config.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.pipe.scheduler.config
        )
        # pipe.enable_sequential_cpu_offload()
        # Move weights to the module-level device (CUDA, enforced at import).
        self.pipe.to(device)
        self.pipe.enable_xformers_memory_efficient_attention()
        # self.upscaler = RealESRGAN(device, scale=4)
        # self.upscaler.load_weights('weights/RealESRGAN_x4.pth', download=True)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Run inpainting and return base64-encoded PNG images.

        :param data: dict with an ``inputs`` prompt plus optional base64
            ``image`` / ``mask_image``, ``num_images`` (1-4, default 1),
            and sampler hyperparameters (``num_inference_steps``,
            ``guidance_scale``, ``negative_prompt``, ``height``, ``width``).
        :return: dict mapping image index ("0", "1", ...) to a base64 PNG
            string, or an error dict when ``num_images`` is out of range.
        """
        inputs = data.pop("inputs", data)
        encoded_image = data.pop("image", None)
        encoded_mask_image = data.pop("mask_image", None)
        # BUG FIX: the original defaulted to None and then compared it with
        # `>` / `<`, raising TypeError whenever the field was absent.
        # Default to a single image instead.
        num_images = data.pop("num_images", 1)
        print(f"num_image {num_images}")
        if num_images > 4 or num_images < 1:
            return {"Invalid Request": "Number of generated images must be >= 1 and <=4"}
        # hyperparamters
        num_inference_steps = data.pop("num_inference_steps", 50)
        guidance_scale = data.pop("guidance_scale", 7.5)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", None)
        width = data.pop("width", None)
        # Decode the image/mask pair only when both are present; the
        # inpainting pipeline needs them together.
        if encoded_image is not None and encoded_mask_image is not None:
            image = self.decode_base64_image(encoded_image)
            mask_image = self.decode_base64_image(encoded_mask_image)
        else:
            image = None
            mask_image = None
        # run inference pipeline
        out = self.pipe(
            inputs,
            image=image,
            mask_image=mask_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
        ).images
        # for i in range(len(out)):
        #     gen_img = Image.composite(out[i], image.resize(out[i].size), mask_image.resize(out[i].size))
        #     gen_img = self.upscaler.predict(gen_img)
        #     gen_img = Image.composite(gen_img, image.resize(gen_img.size), mask_image.resize(gen_img.size))
        #     out[i] = gen_img
        # Encode every generated PIL image as a base64 PNG string.
        json_imgs = {}
        for i, img in enumerate(out):
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            json_imgs[f"{i}"] = base64.b64encode(buffered.getvalue()).decode()
        return json_imgs

    # helper to decode input image
    def decode_base64_image(self, image_string):
        """Decode a base64 string into a PIL Image (used by ``__call__``)."""
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
|