| | from typing import Dict, List, Any |
| |
|
| | import base64 |
| | import math |
| | import numpy as np |
| | import tensorflow as tf |
| | from tensorflow import keras |
| |
|
| | from keras_cv.models.generative.stable_diffusion.constants import _ALPHAS_CUMPROD |
| | from keras_cv.models.generative.stable_diffusion.diffusion_model import DiffusionModel |
| |
|
class EndpointHandler:
    """Inference endpoint wrapper around the Stable Diffusion denoising loop.

    Receives base64-encoded text-encoder embeddings, runs the DDIM diffusion
    loop with classifier-free guidance, and returns the resulting latent
    tensor as a base64 string (decoding the latent to an image is done
    elsewhere, by a separate decoder service).
    """

    def __init__(self, path=""):
        """Build the diffusion model for a fixed 512x512 output.

        Args:
            path: Unused here; kept for the hosting framework's handler API.
        """
        self.seed = None

        img_height = 512
        img_width = 512
        # Snap dimensions to multiples of 128 (the model's size granularity).
        self.img_height = round(img_height / 128) * 128
        self.img_width = round(img_width / 128) * 128

        self.MAX_PROMPT_LENGTH = 77
        self.diffusion_model = DiffusionModel(
            self.img_height, self.img_width, self.MAX_PROMPT_LENGTH
        )

    def _get_initial_diffusion_noise(self, batch_size, seed):
        """Sample the starting latent noise, reproducibly when a seed is given.

        Returns a float32 tensor of shape (batch_size, H/8, W/8, 4) — the
        UNet operates on latents downsampled 8x from pixel space.
        """
        shape = (batch_size, self.img_height // 8, self.img_width // 8, 4)
        if seed is not None:
            # stateless_normal is deterministic for a fixed seed pair.
            return tf.random.stateless_normal(shape, seed=[seed, seed])
        return tf.random.normal(shape)

    def _get_initial_alphas(self, timesteps):
        """Return (alpha_cumprod[t], alpha_cumprod[t_prev]) lists for DDIM."""
        alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
        # alpha_prev for the first timestep is 1.0 (fully un-noised).
        alphas_prev = [1.0] + alphas[:-1]
        return alphas, alphas_prev

    def _get_timestep_embedding(self, timestep, batch_size, dim=320, max_period=10000):
        """Sinusoidal timestep embedding, repeated across the batch.

        Args:
            timestep: Scalar diffusion timestep.
            batch_size: Number of rows to tile the embedding to.
            dim: Embedding width (half cosines, half sines).
            max_period: Largest sinusoid period, as in the transformer scheme.

        Returns:
            A float32 tensor of shape (batch_size, dim).
        """
        half = dim // 2
        freqs = tf.math.exp(
            -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        embedding = tf.reshape(embedding, [1, -1])
        return tf.repeat(embedding, batch_size, axis=0)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Run the DDIM denoising loop and return the final latent, base64-encoded.

        Args:
            data: Request payload. ``data["inputs"]`` holds two base64 strings:
                the conditional and unconditional text encodings, each decoding
                to a float32 buffer reshaped to (1, 77, 768). Optional keys:
                ``batch_size`` (default 1), ``num_steps`` (default 50), and
                ``unconditional_guidance_scale`` (default 7.5).

        Returns:
            Base64 string of the final latent's raw float32 bytes.
        """
        tmp_data = data.pop("inputs", data)

        context = base64.b64decode(tmp_data[0])
        context = np.frombuffer(context, dtype="float32")
        context = np.reshape(context, (1, 77, 768))

        unconditional_context = base64.b64decode(tmp_data[1])
        unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
        unconditional_context = np.reshape(unconditional_context, (1, 77, 768))

        batch_size = data.pop("batch_size", 1)
        # NOTE(review): both contexts are reshaped to batch dim 1, so
        # batch_size > 1 would mismatch them against the latent batch —
        # confirm callers only ever send batch_size == 1.
        num_steps = data.pop("num_steps", 50)
        unconditional_guidance_scale = data.pop("unconditional_guidance_scale", 7.5)

        latent = self._get_initial_diffusion_noise(batch_size, self.seed)

        # DDIM sampling: walk the selected timesteps from noisiest to cleanest.
        timesteps = tf.range(1, 1000, 1000 // num_steps)
        alphas, alphas_prev = self._get_initial_alphas(timesteps)
        progbar = keras.utils.Progbar(len(timesteps))
        iteration = 0
        for index, timestep in list(enumerate(timesteps))[::-1]:
            latent_prev = latent  # kept for the DDIM update below
            t_emb = self._get_timestep_embedding(timestep, batch_size)
            unconditional_latent = self.diffusion_model.predict_on_batch(
                [latent, t_emb, unconditional_context]
            )
            latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
            # Classifier-free guidance: push the conditional prediction away
            # from the unconditional one by the guidance scale.
            latent = unconditional_latent + unconditional_guidance_scale * (
                latent - unconditional_latent
            )
            a_t, a_prev = alphas[index], alphas_prev[index]
            # DDIM update: recover the predicted x0, then re-noise to the
            # previous (less noisy) timestep.
            pred_x0 = (latent_prev - math.sqrt(1 - a_t) * latent) / math.sqrt(a_t)
            latent = latent * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
            iteration += 1
            progbar.update(iteration)

        # BUG FIX: keras.Model.predict_on_batch returns NumPy arrays, so after
        # the loop arithmetic `latent` is an np.ndarray, which has no .numpy()
        # method — the original `latent.numpy()` raised AttributeError here.
        # np.asarray handles both ndarray (no-op) and tf.Tensor correctly.
        latent_b64 = base64.b64encode(np.asarray(latent).tobytes())
        latent_b64str = latent_b64.decode()

        return latent_b64str
| |
|