File size: 1,071 Bytes
785c23c
44565ca
9036128
44565ca
785c23c
 
 
44565ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
785c23c
9036128
7f1c80b
83b5221
785c23c
83b5221
9036128
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from typing import Dict, List, Any
import sys
import base64
import logging
import keras_cv

class EndpointHandler:
    """Serve KerasCV Stable Diffusion text-to-image requests.

    The model is built once at construction time and warmed up with a
    throwaway generation; each call then turns a request payload into a
    base64-encoded image batch.
    """

    def __init__(self, path="", version="2"):
        """Build and warm up the Stable Diffusion model.

        Args:
            path: Not used by this handler. NOTE(review): presumably kept so
                the constructor matches the signature the serving toolkit
                calls with (e.g. a model-repository path) — confirm.
            version: Model version to load: ``"1.4"`` or ``"2"`` (numeric
                values such as ``2`` are accepted too).

        Unsupported versions terminate the process via ``sys.exit`` with an
        explanatory message.
        """
        self.sd = self._instantiate_stable_diffusion(version)

        # The factory signals an unsupported version by returning the error
        # message string instead of a model instance.
        if isinstance(self.sd, str):
            sys.exit(self.sd)

        # Throwaway generation; presumably absorbs one-time tracing/compilation
        # cost so the first real request is not slowed down.
        self.sd.text_to_image("test prompt", batch_size=1)
        # WARNING is the root logger's default threshold, so this is visible
        # without any logging configuration.
        logging.warning("Stable Diffusion v%s is fully loaded", version)

    def _instantiate_stable_diffusion(self, version: str):
        """Return a 512x512 KerasCV Stable Diffusion model for ``version``.

        Args:
            version: ``"1.4"`` or ``"2"``; non-string values are coerced with
                ``str`` so ``2`` and ``1.4`` also match.

        Returns:
            The model instance, or the message ``"v<version> is not
            supported"`` when the version is unknown.
        """
        version = str(version)
        # Compare by value: ``is`` tests object identity and fails for equal
        # strings that were built at runtime (config files, env vars, ...).
        if version == "1.4":
            return keras_cv.models.StableDiffusion(img_width=512, img_height=512)
        if version == "2":
            return keras_cv.models.StableDiffusionV2(img_width=512, img_height=512)
        return f"v{version} is not supported"

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate images for one request payload.

        Args:
            data: Request payload. ``"inputs"`` holds the text prompt (when
                absent, the payload itself is forwarded as the prompt);
                optional ``"batch_size"`` (default 1) is the number of
                images to generate. The payload is not modified.

        Returns:
            ASCII base64 encoding of the raw bytes of the generated image
            batch. No shape/dtype header is included, so clients must know
            the layout — presumably uint8 ``(batch, 512, 512, 3)``; confirm
            against the keras_cv ``text_to_image`` output.
        """
        # Read with .get rather than .pop so the caller's dict is left intact.
        prompt = data.get("inputs", data)
        # Clients may send the count as a string; normalize to an int.
        batch_size = int(data.get("batch_size", 1))

        images = self.sd.text_to_image(prompt, batch_size=batch_size)
        return base64.b64encode(images.tobytes()).decode()