File size: 3,652 Bytes
842c50e
 
95ddcb0
842c50e
 
 
95ddcb0
c6aeea5
842c50e
 
 
 
 
 
 
 
 
 
 
c6aeea5
 
95ddcb0
7a0cd96
95ddcb0
 
 
c6aeea5
 
842c50e
 
 
 
 
 
 
 
 
 
7c9208c
2e53069
7c9208c
 
2e53069
842c50e
7a0cd96
95ddcb0
842c50e
 
 
7a0cd96
842c50e
 
 
 
 
 
7a0cd96
 
842c50e
7a0cd96
 
 
842c50e
95ddcb0
7c9208c
842c50e
 
 
95ddcb0
 
c6aeea5
 
 
 
 
7a0cd96
842c50e
dfa159a
95ddcb0
dfa159a
95ddcb0
dfa159a
7c9208c
dfa159a
7a0cd96
842c50e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from typing import  Dict, List, Any
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
import base64
from io import BytesIO
import numpy as np
# from RealESRGAN import RealESRGAN

# Choose the torch device for the pipeline. This handler refuses to run
# without CUDA, so fail fast at import time instead of at the first request.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type != "cuda":
    raise ValueError("need to run on GPU")

class EndpointHandler():
    """Inference endpoint wrapping a Stable Diffusion inpainting pipeline.

    The pipeline is loaded once in ``__init__`` and invoked per request in
    ``__call__``; generated images are returned as base64-encoded PNG strings.
    """

    def __init__(self, path=""):
        """Load the inpainting pipeline from ``path`` onto the module-level ``device``.

        :param path: Model location passed to ``StableDiffusionInpaintPipeline.from_pretrained``.
        """
        # Load weights in half precision.
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16)
        # Swap in EulerAncestralDiscreteScheduler, reusing the loaded scheduler's config.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
        # pipe.enable_sequential_cpu_offload()
        self.pipe.to(device)
        self.pipe.enable_xformers_memory_efficient_attention()

        # self.upscaler = RealESRGAN(device, scale=4)
        # self.upscaler.load_weights('weights/RealESRGAN_x4.pth', download=True)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Run one inpainting request.

        :param data: Request dict; every recognised key is popped from it.
            ``inputs``: the prompt (if absent, ``data`` itself is used).
            ``image`` / ``mask_image``: base64-encoded source image and mask (both required).
            ``num_images``: images to generate, 1-4 (default 1).
            ``num_inference_steps`` (default 50), ``guidance_scale`` (default 7.5),
            ``negative_prompt``, ``height``, ``width``: forwarded to the pipeline.
        :return: On success, a dict mapping ``"0"``, ``"1"``, ... to base64-encoded
            PNG strings, one per generated image. On invalid input, a dict with a
            single ``"Invalid Request"`` key describing the problem.
        """
        inputs = data.pop("inputs", data)
        encoded_image = data.pop("image", None)
        encoded_mask_image = data.pop("mask_image", None)
        num_images = data.pop("num_images", None)
        print(f"num_image {num_images}")

        # hyperparameters
        num_inference_steps = data.pop("num_inference_steps", 50)
        guidance_scale = data.pop("guidance_scale", 7.5)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", None)
        width = data.pop("width", None)

        # An omitted num_images used to crash on ``None > 4``; default to one image.
        if num_images is None:
            num_images = 1
        # Reject non-integers (bool is an int subclass, so exclude it explicitly)
        # before comparing, so bad input yields an error response, not a TypeError.
        if (
            isinstance(num_images, bool)
            or not isinstance(num_images, int)
            or not 1 <= num_images <= 4
        ):
            return {"Invalid Request": "Number of generated images must be >= 1 and <=4"}

        # Inpainting requires both a source image and a mask; reject the request
        # up front instead of handing None to the pipeline.
        if encoded_image is None or encoded_mask_image is None:
            return {"Invalid Request": "Both image and mask_image must be provided"}

        image = self.decode_base64_image(encoded_image)
        mask_image = self.decode_base64_image(encoded_mask_image)

        # run inference pipeline
        out = self.pipe(inputs,
                        image=image,
                        mask_image=mask_image,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_images_per_prompt=num_images,
                        negative_prompt=negative_prompt,
                        height=height,
                        width=width
        ).images

        # for i in range(len(out)):
        #   gen_img = Image.composite(out[i], image.resize(out[i].size), mask_image.resize(out[i].size))
        #   gen_img = self.upscaler.predict(gen_img)
        #   gen_img = Image.composite(gen_img, image.resize(gen_img.size), mask_image.resize(gen_img.size))
        #   out[i] = gen_img

        # Encode every generated image as a base64 PNG, keyed by its index.
        json_imgs = {}
        for i, generated in enumerate(out):
            buffered = BytesIO()
            generated.save(buffered, format="PNG")
            json_imgs[str(i)] = base64.b64encode(buffered.getvalue()).decode()
        return json_imgs

    def decode_base64_image(self, image_string):
        """Decode a base64 string into a PIL image.

        :param image_string: Base64-encoded bytes of an image file.
        :return: The image as opened by ``PIL.Image.open``.
        """
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image