| from typing import Dict, List, Any | |
| import sys | |
| import base64 | |
| import logging | |
| import keras_cv | |
class EndpointHandler():
    """Callable request handler that serves a KerasCV Stable Diffusion model.

    The model is built once in ``__init__``; each call of the instance runs
    text-to-image generation and returns the result as base64 text.
    """

    def __init__(self, path="", version="2"):
        """Build the requested model and run one generation up front.

        Args:
            path: Accepted for interface compatibility; not used here.
            version: Stable Diffusion version to load, ``"1.4"`` or ``"2"``.

        Raises:
            SystemExit: If ``version`` is not supported. The "not supported"
                message is used as the exit status.
        """
        self.sd = self._instantiate_stable_diffusion(version)

        # The factory signals an unsupported version by returning the error
        # message as a string instead of a model object.
        if isinstance(self.sd, str):
            sys.exit(self.sd)

        # One throwaway generation before reporting readiness -- presumably
        # so the first real request does not pay the model's one-time
        # start-up cost (TODO confirm against keras_cv behaviour).
        self.sd.text_to_image("test prompt", batch_size=1)
        logging.warning("Stable Diffusion v%s is fully loaded", version)

    def _instantiate_stable_diffusion(self, version: str):
        """Return the 512x512 model for ``version``, or an error string.

        Args:
            version: ``"1.4"`` or ``"2"``.

        Returns:
            A ``keras_cv.models.StableDiffusion`` (``"1.4"``) or
            ``keras_cv.models.StableDiffusionV2`` (``"2"``) instance. For any
            other value, the string ``"v{version} is not supported"``.
        """
        # Compare by value (==), not identity (is): two equal strings are
        # not guaranteed to be the same object, so an identity test rejects
        # valid versions that were built at runtime (config, env, parsing).
        if version == "1.4":
            return keras_cv.models.StableDiffusion(img_width=512, img_height=512)
        elif version == "2":
            return keras_cv.models.StableDiffusionV2(img_width=512, img_height=512)
        else:
            return f"v{version} is not supported"

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate images for a request payload.

        Args:
            data: Request payload. ``"inputs"`` holds the prompt and the
                optional ``"batch_size"`` (default 1) is forwarded to the
                model. Both keys are popped, so ``data`` is mutated. When
                ``"inputs"`` is absent the payload itself is used as the
                prompt.

        Returns:
            Base64 text of the raw bytes produced by ``images.tobytes()``.
            Only the raw buffer is encoded, not its shape or dtype.
        """
        prompt = data.pop("inputs", data)
        batch_size = data.pop("batch_size", 1)
        images = self.sd.text_to_image(prompt, batch_size=batch_size)
        return base64.b64encode(images.tobytes()).decode()