"""Streamlit app: text-to-image generation with a VQGAN + Transformer pipeline.

A user caption is vectorized, fed through a seq2seq Transformer that predicts
VQGAN codebook indices autoregressively, and the predicted code embeddings are
decoded into an image by a pre-trained convolutional decoder.
"""
import streamlit as st
import numpy as np
import tensorflow as tf
import string
import re
from tensorflow.keras.utils import register_keras_serializable
from tensorflow.keras.layers import Conv2D, Add
from tensorflow import keras
from tensorflow.keras import layers
import os


@register_keras_serializable(package="Custom")
class SelfAttention(tf.keras.layers.Layer):
    """SAGAN-style self-attention over a (B, H, W, C) feature map.

    Required as a custom object so the saved ``.keras`` decoder file can be
    deserialized.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Channel count is only known at build time.
        self.filters = input_shape[-1]
        # f/g project into a C//8 bottleneck for the attention map; h keeps C.
        self.f = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.g = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.h = Conv2D(self.filters, kernel_size=1, padding='same')
        super().build(input_shape)

    def call(self, x):
        f = self.f(x)  # (B, H, W, C//8)
        g = self.g(x)  # (B, H, W, C//8)
        h = self.h(x)  # (B, H, W, C)
        shape_f = tf.shape(f)
        B, H, W = shape_f[0], shape_f[1], shape_f[2]
        # Flatten spatial dims so attention runs over all H*W positions.
        f_flat = tf.reshape(f, [B, H * W, self.filters // 8])
        g_flat = tf.reshape(g, [B, H * W, self.filters // 8])
        h_flat = tf.reshape(h, [B, H * W, self.filters])
        # (B, N, N) attention weights, then weighted sum of the h values.
        beta = tf.nn.softmax(tf.matmul(f_flat, g_flat, transpose_b=True), axis=-1)
        o = tf.matmul(beta, h_flat)  # (B, N, C)
        o = tf.reshape(o, [B, H, W, self.filters])
        return Add()([x, o])  # residual connection

    def get_config(self):
        # No constructor arguments beyond the base layer's.
        return super().get_config()


# --- Pre-trained artifacts -------------------------------------------------
# NOTE(review): these loads run on every Streamlit rerun; wrapping them in
# @st.cache_resource would avoid reloading the models on each interaction.
captions = np.load('src/caption.npy')
decoder = tf.keras.models.load_model(
    'src/epoch_78_decoder.keras',
    custom_objects={'SelfAttention': SelfAttention},
)
codebook = np.load('src/epoch_78_codebook.npy')
codebook = tf.convert_to_tensor(codebook, dtype=tf.float32)

# Punctuation to strip from captions. '[' and ']' are deliberately kept so
# bracketed special tokens (e.g. "[start]") survive standardization.
strip_chars = string.punctuation + "¿"
strip_chars = strip_chars.replace("[", "")
strip_chars = strip_chars.replace("]", "")


def custom_standardization(input_string):
    """Lowercase *input_string* and strip punctuation except square brackets."""
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(
        lowercase, f"[{re.escape(strip_chars)}]", "")


vocab_size = 18500
sequence_length = 85
# FIX: custom_standardization was defined but never passed to the vectorizer,
# so Keras silently fell back to its default standardizer — which strips
# '[' and ']' and would mangle bracketed special tokens.
vectorizer = layers.TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
    standardize=custom_standardization,
)
vectorizer.adapt(captions)


@register_keras_serializable(package="Custom")
class TransformerEncoder(layers.Layer):
    """Pre-LN-free Transformer encoder block: self-attention + FFN, each with
    a residual connection and LayerNormalization."""

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=int(embed_dim / num_heads))
        self.dense_proj = keras.Sequential([
            layers.Dense(dense_dim, activation="relu"),
            layers.Dense(embed_dim),
        ])
        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
        self.dropout_1 = layers.Dropout(0.1)

    def call(self, inputs, mask=None):
        # Broadcast the padding mask to (batch, 1, seq_len) booleans so it
        # applies to every query position.
        if mask is not None:
            mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.bool)
        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=mask)
        attention_output = self.dropout_1(attention_output)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        # FIX: required for faithful (de)serialization of a registered layer.
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads,
        })
        return config


@register_keras_serializable(package="Custom")
class TransformerDecoder(layers.Layer):
    """Transformer decoder block: causal self-attention, cross-attention over
    the encoder outputs, then an FFN — each with residual + LayerNorm."""

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=int(embed_dim / num_heads))
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=int(embed_dim / num_heads))
        self.dense_proj = keras.Sequential([
            layers.Dense(dense_dim, activation="relu"),
            layers.Dense(embed_dim),
        ])
        self.dropout_1 = layers.Dropout(0.1)
        self.dropout_2 = layers.Dropout(0.1)
        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_3 = layers.LayerNormalization(epsilon=1e-5)

    def call(self, inputs, encoder_outputs, mask=None):
        # Causal self-attention over the (image-token) target sequence.
        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs,
            attention_mask=None, use_causal_mask=True)
        attention_output_1 = self.dropout_1(attention_output_1)
        attention_output_1 = self.layernorm_1(inputs + attention_output_1)
        # Cross-attention with the encoder's padding mask only.
        attention_output_2 = self.attention_2(
            query=attention_output_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=mask,
            use_causal_mask=False)
        attention_output_2 = self.dropout_2(attention_output_2)
        attention_output_2 = self.layernorm_2(
            attention_output_1 + attention_output_2)
        proj_output = self.dense_proj(attention_output_2)
        return self.layernorm_3(attention_output_2 + proj_output)

    def get_config(self):
        # FIX: required for faithful (de)serialization of a registered layer.
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads,
        })
        return config


@register_keras_serializable(package="Custom")
class PositionalEmbedding(layers.Layer):
    """Token embedding plus learned absolute position embedding."""

    def __init__(self, sequence_length, input_dim, output_dim,
                 mask_zero=True, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim, mask_zero=mask_zero)
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim, mask_zero=False)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.mask_zero = mask_zero

    def call(self, inputs):
        # Positions 0..len-1 along the last axis; assumes the runtime
        # sequence never exceeds `sequence_length`.
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def get_config(self):
        # FIX: required for faithful (de)serialization of a registered layer.
        config = super().get_config()
        config.update({
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "mask_zero": self.mask_zero,
        })
        return config


# --- Build the text -> image-token Transformer -----------------------------
embed_dim = 512
dense_dim = 2048
num_heads = 8
num_blocks = 7

encoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
decoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")

# Padding masks: (batch, seq) bool for the encoder blocks, broadcast to
# (batch, 1, 1, seq) for cross-attention inside the decoder blocks.
encoder_mask = tf.keras.layers.Lambda(
    lambda x: tf.cast(tf.not_equal(x, 0), tf.bool))(encoder_inputs)
cross_attention_mask = tf.keras.layers.Lambda(
    lambda x: tf.cast(x[:, tf.newaxis, tf.newaxis, :], tf.bool))(encoder_mask)

# Embeddings: text side uses the caption vocabulary; image side uses the
# 256 codebook indices plus one start token (hence input_dim 257).
encoder_embed = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
decoder_embed = PositionalEmbedding(256, 257, embed_dim, mask_zero=False)(decoder_inputs)

# Pre-instantiate blocks so each has its own weights.
encoder_blocks = [TransformerEncoder(embed_dim, dense_dim, num_heads)
                  for _ in range(num_blocks)]
decoder_blocks = [TransformerDecoder(embed_dim, dense_dim, num_heads)
                  for _ in range(num_blocks)]

# Encoder stack.
x = encoder_embed
for block in encoder_blocks:
    x = block(x, mask=encoder_mask)
encoder_outputs = x

# Decoder stack.
x = decoder_embed
for block in decoder_blocks:
    x = block(x, encoder_outputs, mask=cross_attention_mask)

# Output head: logits over the 256 codebook entries.
x = layers.LayerNormalization(epsilon=1e-5)(x)
x = layers.Dropout(0.1)(x)
decoder_outputs = layers.Dense(256)(x)
transformer = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)

start_token = 256           # index reserved for the BOS token
max_output_length = 256     # 16x16 grid of codebook indices


def generate_image_tokens(input_text):
    """Greedily decode ``max_output_length`` codebook indices for *input_text*.

    Returns a list of 256 Python ints (the start token is dropped).
    """
    tokenized_text = vectorizer([input_text])  # (1, text_seq_len)
    decoded_image_tokens = [start_token]
    for _ in range(max_output_length):
        decoder_input = tf.convert_to_tensor([decoded_image_tokens])
        # Full forward pass each step (no KV cache) — O(n^2) steps overall,
        # acceptable for a 256-token sequence.
        predictions = transformer([tokenized_text, decoder_input])
        # Greedy pick for the current step; int() avoids leaking numpy
        # scalar types into the token list.
        sampled_token_index = int(np.argmax(predictions[0, -1, :]))
        decoded_image_tokens.append(sampled_token_index)
    return decoded_image_tokens[1:]


def get_embeddings(indices, codebook):
    """Map codebook *indices* to embeddings shaped (-1, 16, 16, C) for the decoder.

    Assumes 256 indices per image and a 256-dim codebook — TODO confirm
    against the VQGAN training config.
    """
    flat_indices = tf.reshape(indices, [-1])
    flat_embeddings = tf.nn.embedding_lookup(codebook, flat_indices)
    # FIX: dropped a dead `out_shape` computation that was never used;
    # the decoder expects the fixed (batch, 16, 16, 256) layout below.
    return tf.reshape(flat_embeddings, (-1, 16, 16, 256))


transformer.load_weights('src/VQGAN_Transformer.weights.h5')

# --- Streamlit UI ----------------------------------------------------------
user_input = st.text_input("Enter some text:", "")
if user_input != "":
    with st.spinner("Generating image..."):
        st.write(user_input)
        output_tokens = generate_image_tokens(user_input)
        embedding = get_embeddings(output_tokens, codebook)
        image = decoder(embedding)[0].numpy()
        # Assumes the decoder outputs values in [0, 1] — TODO confirm;
        # a [-1, 1] decoder would need (image + 1) * 127.5 instead.
        image = np.clip(image * 255, 0, 255).astype(np.uint8)
        st.image(image, caption="Generated Image", width=512)