from typing import Any, Dict, Tuple

import tensorflow as tf
from keras_cv.models.generative.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.generative.stable_diffusion.constants import _UNCONDITIONAL_TOKENS
from keras_cv.models.generative.stable_diffusion.text_encoder import TextEncoder

# CLIP end-of-text token id, used to pad prompts up to MAX_PROMPT_LENGTH.
_END_OF_TEXT_TOKEN = 49407


class EndpointHandler:
    """Inference handler that turns a text prompt into Stable Diffusion
    conditioning tensors (conditional and unconditional text contexts).

    The handler tokenizes the prompt with the CLIP tokenizer, runs it through
    the keras_cv Stable Diffusion text encoder, and tiles the resulting
    embeddings to the requested batch size.
    """

    def __init__(self, path: str = "") -> None:
        """Build the tokenizer, text encoder, and position-id tensor.

        Args:
            path: Unused here; kept for compatibility with the Hugging Face
                inference-endpoint handler interface, which passes a model path.
        """
        self.MAX_PROMPT_LENGTH = 77
        self.tokenizer = SimpleTokenizer()
        self.text_encoder = TextEncoder(self.MAX_PROMPT_LENGTH)
        # Position ids 0..76, shape (1, 77) — second input of the text encoder.
        self.pos_ids = tf.convert_to_tensor(
            [list(range(self.MAX_PROMPT_LENGTH))], dtype=tf.int32
        )

    def _get_unconditional_context(self) -> tf.Tensor:
        """Encode the fixed unconditional (empty-prompt) token sequence.

        Returns:
            The text-encoder output for the unconditional tokens,
            shape (1, MAX_PROMPT_LENGTH, embedding_dim).
        """
        unconditional_tokens = tf.convert_to_tensor(
            [_UNCONDITIONAL_TOKENS], dtype=tf.int32
        )
        unconditional_context = self.text_encoder.predict_on_batch(
            [unconditional_tokens, self.pos_ids]
        )
        return unconditional_context

    def encode_text(self, prompt: str) -> tf.Tensor:
        """Tokenize and encode a text prompt.

        Args:
            prompt: The text prompt to encode.

        Returns:
            The text-encoder output for the prompt,
            shape (1, MAX_PROMPT_LENGTH, embedding_dim).

        Raises:
            ValueError: If the tokenized prompt exceeds MAX_PROMPT_LENGTH tokens.
        """
        # Tokenize prompt (i.e. starting context)
        inputs = self.tokenizer.encode(prompt)
        if len(inputs) > self.MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt is too long (should be <= {self.MAX_PROMPT_LENGTH} tokens)"
            )
        # Pad with the end-of-text token up to the fixed sequence length.
        phrase = inputs + [_END_OF_TEXT_TOKEN] * (self.MAX_PROMPT_LENGTH - len(inputs))
        phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
        context = self.text_encoder.predict_on_batch([phrase, self.pos_ids])
        return context

    def get_contexts(
        self, encoded_text: tf.Tensor, batch_size: int
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Tile encoded text and the unconditional context to a batch.

        Args:
            encoded_text: Output of ``encode_text``.
            batch_size: Number of copies to produce along the batch axis.

        Returns:
            A ``(context, unconditional_context)`` tuple, each with a leading
            batch dimension of ``batch_size``.
        """
        encoded_text = tf.squeeze(encoded_text)
        if encoded_text.shape.rank == 2:
            # Re-add the batch axis and repeat to the requested batch size.
            encoded_text = tf.repeat(
                tf.expand_dims(encoded_text, axis=0), batch_size, axis=0
            )
        context = encoded_text
        # BUG FIX: was a bare `_get_unconditional_context()` call (NameError);
        # it is an instance method and must be called via `self`.
        unconditional_context = tf.repeat(
            self._get_unconditional_context(), batch_size, axis=0
        )
        return context, unconditional_context

    def __call__(self, data: Dict[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
        """Handle an inference request.

        Args:
            data: Request payload. Expected keys: ``"inputs"`` (the prompt
                string) and optionally ``"batch_size"`` (defaults to 1). If
                ``"inputs"`` is absent, the whole payload is used as the prompt
                — NOTE(review): that fallback looks intentional for raw-string
                payloads; confirm against the serving framework's contract.

        Returns:
            A ``(context, unconditional_context)`` tensor tuple.
        """
        # get inputs
        prompt = data.pop("inputs", data)
        batch_size = data.pop("batch_size", 1)
        # BUG FIX: these were bare-name calls (NameError); they are instance
        # methods and must be called via `self`. Also corrected the return
        # annotation, which claimed `str` while a tensor tuple is returned.
        encoded_text = self.encode_text(prompt)
        context, unconditional_context = self.get_contexts(encoded_text, batch_size)
        return context, unconditional_context