File size: 1,830 Bytes
0ca1aa6
 
 
 
 
 
cfb6c44
0ca1aa6
cfb6c44
0ca1aa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfb6c44
c061d70
71f792c
 
cfb6c44
 
 
71f792c
0ca1aa6
d6dfa49
42d9957
0ca1aa6
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import base64
from io import BytesIO
from typing import Dict, Any

import torch
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionImg2ImgPipeline
import requests


# helper decoder
def decode_base64_image(image_string):
    """Decode a base64-encoded image payload into a PIL image.

    :param image_string: Base64 text (or bytes) of an encoded image file.
    :return: The PIL image opened from the decoded bytes.
    """
    raw_bytes = base64.b64decode(image_string)
    return Image.open(BytesIO(raw_bytes))


class EndpointHandler:
    """Inference handler wrapping a Stable Diffusion img2img pipeline.

    The pipeline is loaded once at construction (fp16 weights, moved to CUDA)
    and reused for every request passed to ``__call__``.
    """

    # Seconds to wait on the remote host when downloading an init image;
    # without a timeout a stalled server would block the request forever.
    _DOWNLOAD_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the img2img pipeline from *path* in half precision onto the GPU.

        :param path: Model location passed to ``from_pretrained``.
        """
        # NOTE(review): newer diffusers releases deprecate revision="fp16" in
        # favor of variant="fp16" — confirm against the installed version.
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            path, torch_dtype=torch.float16, revision="fp16"
        )
        self.pipe = self.pipe.to("cuda")

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Run img2img on one request payload.

        :param data: A dictionary with:
            - ``inputs``: the prompt(s) passed to the pipeline.
            - ``image`` (required): an ``http(s)`` URL of the init image, or
              the image itself as base64 text (a ``data:`` URI is accepted).
            - ``seed`` (optional, default 0): seed for the CUDA generator.
            - ``width`` / ``height`` (optional): bounding box the init image is
              shrunk to fit, preserving aspect ratio; a missing or non-positive
              value leaves that dimension unconstrained.
            Every other key is forwarded as a keyword argument to the pipeline.
            The caller's dictionary is not modified.
        :return: ``{"image": <base64 PNG>}`` when exactly one image is produced,
            otherwise ``{"images": [<base64 PNG>, ...]}``.
        :raises ValueError: if the ``image`` field is missing or empty.
        :raises requests.HTTPError: if downloading the init image fails.
        """
        # Copy so popping the known keys does not mutate the caller's dict.
        params = dict(data)
        prompts = params.pop("inputs", None)
        image_ref = params.pop("image", None)
        seed = int(params.pop("seed", 0))
        width = int(params.pop("width", 0) or 0)
        height = int(params.pop("height", 0) or 0)

        init_image = self._load_image(image_ref)
        self._fit_within(init_image, width, height)

        generator = torch.Generator(device="cuda").manual_seed(seed)
        images = self.pipe(
            prompts, image=init_image, generator=generator, **params
        ).images
        encoded = [self._encode_png(image) for image in images]

        if len(encoded) == 1:
            return {"image": encoded[0]}
        return {"images": encoded}

    def _load_image(self, image_ref):
        """Return the init image as RGB, downloaded from a URL or decoded from base64."""
        if not image_ref:
            raise ValueError("Request is missing the required `image` field.")
        if image_ref.lower().startswith(("http://", "https://")):
            response = requests.get(image_ref, timeout=self._DOWNLOAD_TIMEOUT)
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            # Accept bare base64 or a data URI ("data:image/png;base64,<payload>").
            if image_ref.startswith("data:"):
                image_ref = image_ref.partition(",")[2]
            image = decode_base64_image(image_ref)
        return image.convert("RGB")

    @staticmethod
    def _fit_within(image, width, height):
        """Shrink *image* in place to fit width x height; non-positive bounds are ignored."""
        if width <= 0 and height <= 0:
            # No bounds requested: keep the original size rather than
            # thumbnailing to (0, 0), which would collapse the image.
            return
        max_width = width if width > 0 else image.width
        max_height = height if height > 0 else image.height
        # thumbnail() preserves aspect ratio and never enlarges the image.
        image.thumbnail((max_width, max_height))

    @staticmethod
    def _encode_png(image):
        """Serialize *image* as PNG and return it as base64 text."""
        buffer = BytesIO()
        image.save(buffer, format="png")
        return base64.b64encode(buffer.getvalue()).decode()