| | from typing import Dict, List, Any |
| |
|
| | import tensorflow as tf |
| | from keras_cv.models.generative.stable_diffusion.text_encoder import TextEncoder |
| | from keras_cv.models.generative.stable_diffusion.clip_tokenizer import SimpleTokenizer |
| | from keras_cv.models.generative.stable_diffusion.constants import _UNCONDITIONAL_TOKENS |
| |
|
class EndpointHandler:
    """Inference-endpoint handler that turns text prompts into Stable Diffusion
    conditioning contexts.

    Tokenizes a prompt with the CLIP tokenizer, runs it through the keras_cv
    Stable Diffusion text encoder, and returns both the conditional and the
    unconditional (classifier-free guidance) context tensors.
    """

    # CLIP's end-of-text token id, used to pad prompts to MAX_PROMPT_LENGTH.
    END_OF_TEXT_TOKEN = 49407

    def __init__(self, path=""):
        """Build the tokenizer, text encoder, and position-id tensor.

        Args:
            path: Unused here; accepted for compatibility with the
                Hugging Face custom-handler interface.
        """
        self.MAX_PROMPT_LENGTH = 77

        self.tokenizer = SimpleTokenizer()
        self.text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)
        # Position ids 0..76, shape (1, 77), shared by every encoder call.
        self.pos_ids = tf.convert_to_tensor(
            [list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32
        )

    def _get_unconditional_context(self):
        """Encode the fixed unconditional (empty-prompt) token sequence.

        Returns:
            The text-encoder output for the unconditional tokens,
            shape (1, MAX_PROMPT_LENGTH, embedding_dim).
        """
        unconditional_tokens = tf.convert_to_tensor(
            [_UNCONDITIONAL_TOKENS], dtype=tf.int32
        )
        return self.text_encoder.predict_on_batch(
            [unconditional_tokens, self.pos_ids]
        )

    def encode_text(self, prompt):
        """Tokenize and encode a text prompt.

        Args:
            prompt: The text prompt to encode.

        Returns:
            Encoded context tensor, shape (1, MAX_PROMPT_LENGTH, embedding_dim).

        Raises:
            ValueError: If the tokenized prompt exceeds MAX_PROMPT_LENGTH tokens.
        """
        inputs = self.tokenizer.encode(prompt)
        if len(inputs) > self.MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt is too long (should be <= {self.MAX_PROMPT_LENGTH} tokens)"
            )
        # Pad with end-of-text tokens up to the fixed prompt length.
        phrase = inputs + [self.END_OF_TEXT_TOKEN] * (
            self.MAX_PROMPT_LENGTH - len(inputs)
        )
        phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)

        return self.text_encoder.predict_on_batch([phrase, self.pos_ids])

    def get_contexts(self, encoded_text, batch_size):
        """Tile the encoded prompt and the unconditional context to batch size.

        Args:
            encoded_text: Encoder output for the prompt (leading batch dim of
                1 is squeezed away before tiling).
            batch_size: Number of copies to produce.

        Returns:
            Tuple of (context, unconditional_context), each with a leading
            dimension of batch_size.
        """
        encoded_text = tf.squeeze(encoded_text)
        if encoded_text.shape.rank == 2:
            encoded_text = tf.repeat(
                tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
            )

        context = encoded_text

        # BUG FIX: was a bare `_get_unconditional_context()` (NameError);
        # it is an instance method and must be called through self.
        unconditional_context = tf.repeat(
            self._get_unconditional_context(), batch_size, axis=0
        )

        return context, unconditional_context

    def __call__(self, data: Dict[str, Any]):
        """Handle an inference request.

        Args:
            data: Request payload. "inputs" holds the prompt text and
                "batch_size" (default 1) the number of context copies.

        Returns:
            Tuple of (context, unconditional_context) tensors.
            (Previously annotated `-> str`, which was incorrect.)
        """
        # NOTE(review): falls back to the whole dict when "inputs" is missing,
        # so a malformed payload reaches the tokenizer as a dict — confirm this
        # is the intended Hugging Face handler convention.
        prompt = data.pop("inputs", data)
        batch_size = data.pop("batch_size", 1)

        # BUG FIX: were bare `encode_text(...)` / `get_contexts(...)` calls
        # (NameError); both are instance methods.
        encoded_text = self.encode_text(prompt)
        context, unconditional_context = self.get_contexts(
            encoded_text, batch_size
        )

        return context, unconditional_context