update custom handler
Browse files- __pycache__/handler.cpython-38.pyc +0 -0
- handler.py +38 -12
- requirements.txt +2 -2
__pycache__/handler.cpython-38.pyc
CHANGED
|
Binary files a/__pycache__/handler.cpython-38.pyc and b/__pycache__/handler.cpython-38.pyc differ
|
|
|
handler.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
| 1 |
from typing import Dict, List, Any
|
| 2 |
-
import base64
|
| 3 |
|
|
|
|
|
|
|
| 4 |
import math
|
| 5 |
import numpy as np
|
| 6 |
import tensorflow as tf
|
| 7 |
from tensorflow import keras
|
| 8 |
-
from keras_cv.models.
|
| 9 |
-
from keras_cv.models.
|
|
|
|
| 10 |
|
| 11 |
class GroupNormalization(tf.keras.layers.Layer):
|
| 12 |
"""GroupNormalization layer.
|
|
@@ -184,7 +186,7 @@ class ImageEncoder(keras.Sequential):
|
|
| 184 |
self.load_weights(image_encoder_weights_fpath)
|
| 185 |
|
| 186 |
class EndpointHandler():
|
| 187 |
-
def __init__(self, path=""):
|
| 188 |
self.seed = None
|
| 189 |
|
| 190 |
img_height = 512
|
|
@@ -193,15 +195,33 @@ class EndpointHandler():
|
|
| 193 |
self.img_width = round(img_width / 128) * 128
|
| 194 |
|
| 195 |
self.MAX_PROMPT_LENGTH = 77
|
| 196 |
-
self.
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
)
|
| 201 |
-
self.diffusion_model.load_weights(diffusion_model_weights_fpath)
|
| 202 |
|
| 203 |
self.image_encoder = ImageEncoder()
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
def _get_initial_diffusion_noise(self, batch_size, seed):
|
| 206 |
if seed is not None:
|
| 207 |
return tf.random.stateless_normal(
|
|
@@ -266,11 +286,17 @@ class EndpointHandler():
|
|
| 266 |
|
| 267 |
context = base64.b64decode(inputs[0])
|
| 268 |
context = np.frombuffer(context, dtype="float32")
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
unconditional_context = base64.b64decode(inputs[1])
|
| 272 |
unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
num_steps = data.pop("num_steps", 25)
|
| 276 |
unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
|
|
|
|
| 1 |
from typing import Dict, List, Any
|
|
|
|
| 2 |
|
| 3 |
+
import sys
|
| 4 |
+
import base64
|
| 5 |
import math
|
| 6 |
import numpy as np
|
| 7 |
import tensorflow as tf
|
| 8 |
from tensorflow import keras
|
| 9 |
+
from keras_cv.models.stable_diffusion.constants import _ALPHAS_CUMPROD
|
| 10 |
+
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
|
| 11 |
+
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModelV2
|
| 12 |
|
| 13 |
class GroupNormalization(tf.keras.layers.Layer):
|
| 14 |
"""GroupNormalization layer.
|
|
|
|
| 186 |
self.load_weights(image_encoder_weights_fpath)
|
| 187 |
|
| 188 |
class EndpointHandler():
|
| 189 |
+
def __init__(self, path="", version="2"):
|
| 190 |
self.seed = None
|
| 191 |
|
| 192 |
img_height = 512
|
|
|
|
| 195 |
self.img_width = round(img_width / 128) * 128
|
| 196 |
|
| 197 |
self.MAX_PROMPT_LENGTH = 77
|
| 198 |
+
self.version = version
|
| 199 |
+
self.diffusion_model = self._instantiate_diffusion_model(version)
|
| 200 |
+
if isinstance(self.diffusion_model, str):
|
| 201 |
+
sys.exit(self.diffusion_model)
|
|
|
|
|
|
|
| 202 |
|
| 203 |
self.image_encoder = ImageEncoder()
|
| 204 |
|
| 205 |
+
def _instantiate_diffusion_model(self, version: str):
|
| 206 |
+
if version == "1.4":
|
| 207 |
+
diffusion_model_weights_fpath = keras.utils.get_file(
|
| 208 |
+
origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
|
| 209 |
+
file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
|
| 210 |
+
)
|
| 211 |
+
diffusion_model = DiffusionModel(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
|
| 212 |
+
diffusion_model.load_weights(diffusion_model_weights_fpath)
|
| 213 |
+
return diffusion_model
|
| 214 |
+
elif version == "2":
|
| 215 |
+
diffusion_model_weights_fpath = keras.utils.get_file(
|
| 216 |
+
origin="https://huggingface.co/ianstenbit/keras-sd2.1/resolve/main/diffusion_model_v2_1.h5",
|
| 217 |
+
file_hash="c31730e91111f98fe0e2dbde4475d381b5287ebb9672b1821796146a25c5132d",
|
| 218 |
+
)
|
| 219 |
+
diffusion_model = DiffusionModelV2(self.img_height, self.img_width, self.MAX_PROMPT_LENGTH)
|
| 220 |
+
diffusion_model.load_weights(diffusion_model_weights_fpath)
|
| 221 |
+
return diffusion_model
|
| 222 |
+
else:
|
| 223 |
+
return f"v{version} is not supported"
|
| 224 |
+
|
| 225 |
def _get_initial_diffusion_noise(self, batch_size, seed):
|
| 226 |
if seed is not None:
|
| 227 |
return tf.random.stateless_normal(
|
|
|
|
| 286 |
|
| 287 |
context = base64.b64decode(inputs[0])
|
| 288 |
context = np.frombuffer(context, dtype="float32")
|
| 289 |
+
if self.version == "1.4":
|
| 290 |
+
context = np.reshape(context, (batch_size, 77, 768))
|
| 291 |
+
else:
|
| 292 |
+
context = np.reshape(context, (batch_size, 77, 1024))
|
| 293 |
|
| 294 |
unconditional_context = base64.b64decode(inputs[1])
|
| 295 |
unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
|
| 296 |
+
if self.version == "1.4":
|
| 297 |
+
unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
|
| 298 |
+
else:
|
| 299 |
+
unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 1024))
|
| 300 |
|
| 301 |
num_steps = data.pop("num_steps", 25)
|
| 302 |
unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)
|
requirements.txt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
keras-cv
|
| 2 |
-
tensorflow
|
| 3 |
tensorflow_datasets
|
|
|
|
| 1 |
+
keras-cv==0.4
|
| 2 |
+
tensorflow==2.11
|
| 3 |
tensorflow_datasets
|