from typing import Dict, List, Any
import sys
import base64
import logging
import keras_cv
class EndpointHandler:
    """Inference handler serving keras_cv Stable Diffusion text-to-image requests.

    The model is built and warmed up once at construction time. Each call
    generates images for a prompt and returns them base64-encoded.
    """

    # Supported version string -> name of the keras_cv.models class that
    # implements it. Dict lookup compares by value, so version strings built
    # at runtime (e.g. read from config) match as well as literals do.
    _MODEL_CLASS_NAMES = {
        "1.4": "StableDiffusion",
        "2": "StableDiffusionV2",
    }

    def __init__(self, path="", version="2"):
        """Build the requested Stable Diffusion model and run one warm-up generation.

        Args:
            path: Accepted for handler-interface compatibility; not used here.
            version: Model version, "1.4" or "2". Non-string values such as
                ``2`` or ``1.4`` are normalized with ``str()``.

        Raises:
            SystemExit: If ``version`` is not supported (the process exits
                with the "not supported" message).
        """
        self.sd = self._instantiate_stable_diffusion(version)
        if isinstance(self.sd, str):
            # The helper returns an error message instead of a model for
            # unsupported versions.
            sys.exit(self.sd)
        # Throwaway generation, presumably to absorb one-time setup cost
        # before the first real request arrives.
        self.sd.text_to_image("test prompt", batch_size=1)
        logging.warning(f"Stable Diffusion v{version} is fully loaded")

    def _instantiate_stable_diffusion(self, version: str):
        """Return a 512x512 keras_cv Stable Diffusion model for ``version``.

        Returns:
            The model instance, or an error-message string
            ``"v<version> is not supported"`` when the version is unknown.
            ``__init__`` checks for the string case.
        """
        version = str(version)
        class_name = self._MODEL_CLASS_NAMES.get(version)
        if class_name is None:
            return f"v{version} is not supported"
        model_cls = getattr(keras_cv.models, class_name)
        return model_cls(img_width=512, img_height=512)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate images for a request payload and return them base64-encoded.

        Args:
            data: Request payload. ``"inputs"`` holds the prompt. The optional
                ``"batch_size"`` (default 1) sets how many images to generate.
                Both keys are popped, so ``data`` is mutated. If ``"inputs"``
                is absent, the dict itself is passed through as the prompt.

        Returns:
            Base64 text of the raw bytes (``tobytes()``) of the generated
            image array. Shape and dtype are not encoded, so the client must
            already know them (presumably (batch_size, 512, 512, 3) —
            TODO confirm against keras_cv's output).
        """
        prompt = data.pop("inputs", data)
        # JSON clients may send the batch size as a string; coerce it to int.
        batch_size = int(data.pop("batch_size", 1))
        images = self.sd.text_to_image(prompt, batch_size=batch_size)
        return base64.b64encode(images.tobytes()).decode()