# Cursed-Text-to-Image — src/streamlit_app.py
# (Hugging Face Space page header: uploaded by Beasto, commit 85e2b48, verified.)
import streamlit as st
import numpy as np
import tensorflow as tf
import string
import re
from tensorflow.keras.utils import register_keras_serializable
from tensorflow.keras.layers import Conv2D, Add
from tensorflow import keras
from tensorflow.keras import layers
import os
@register_keras_serializable(package="Custom")
class SelfAttention(tf.keras.layers.Layer):
    """Non-local self-attention over spatial positions, with a residual skip.

    Attribute names (f, g, h, filters) are part of the saved-model weight
    paths and must not be renamed, or the loaded decoder's weights will not
    bind to this layer.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Channel count is taken from the input; 1x1 convs produce the
        # query (f), key (g), and value (h) feature maps.
        self.filters = input_shape[-1]
        self.f = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.g = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.h = Conv2D(self.filters, kernel_size=1, padding='same')
        super().build(input_shape)

    def call(self, x):
        """Return x + attention-weighted aggregation over all H*W positions."""
        query = self.f(x)   # (B, H, W, C//8)
        key = self.g(x)     # (B, H, W, C//8)
        value = self.h(x)   # (B, H, W, C)
        dims = tf.shape(query)
        batch, height, width = dims[0], dims[1], dims[2]
        n_positions = height * width
        q_flat = tf.reshape(query, [batch, n_positions, self.filters // 8])
        k_flat = tf.reshape(key, [batch, n_positions, self.filters // 8])
        v_flat = tf.reshape(value, [batch, n_positions, self.filters])
        # Attention map over all spatial positions: (B, N, N).
        attn = tf.nn.softmax(tf.matmul(q_flat, k_flat, transpose_b=True), axis=-1)
        context = tf.matmul(attn, v_flat)  # (B, N, C)
        context = tf.reshape(context, [batch, height, width, self.filters])
        return Add()([x, context])  # residual connection

    def get_config(self):
        # No constructor arguments beyond the base Layer's.
        return super().get_config()
@st.cache_resource(show_spinner=False)
def _load_assets():
    """Load the captions, VQGAN decoder, and codebook once per process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the Keras decoder would be re-read from disk on each
    rerun. `st.cache_resource` memoizes the loaded objects across reruns.

    Returns:
        (captions ndarray, decoder keras.Model, codebook float32 tensor)
    """
    caps = np.load('src/caption.npy')
    dec = tf.keras.models.load_model(
        'src/epoch_78_decoder.keras',
        custom_objects={'SelfAttention': SelfAttention},
    )
    book = tf.convert_to_tensor(np.load('src/epoch_78_codebook.npy'), dtype=tf.float32)
    return caps, dec, book


# Preserve the original module-level names for the rest of the script.
captions, decoder, codebook = _load_assets()
# Characters to delete during standardization: all ASCII punctuation plus
# inverted question mark, except "[" and "]" (kept, presumably as special
# token delimiters).
strip_chars = "".join(ch for ch in string.punctuation + "¿" if ch not in "[]")


def custom_standardization(input_string):
    """Lowercase the input and delete every character in `strip_chars`.

    NOTE(review): this function is defined but not passed as `standardize=`
    to the TextVectorization layer below — confirm whether that is intended.
    """
    lowered = tf.strings.lower(input_string)
    pattern = f"[{re.escape(strip_chars)}]"
    return tf.strings.regex_replace(lowered, pattern, "")
# Text-vectorization hyperparameters — must match the transformer's training setup.
vocab_size = 18500
sequence_length = 85
# NOTE(review): custom_standardization above is NOT passed as `standardize=`
# here, so the layer's default lower-and-strip-punctuation behavior is used —
# confirm this matches how the training vocabulary was built.
vectorizer = layers.TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)
# Build the vocabulary from the caption corpus loaded above.
vectorizer.adapt(captions)
@register_keras_serializable(package="Custom")
class TransformerEncoder(layers.Layer):
    """Post-norm Transformer encoder block: self-attention + feed-forward.

    Args:
        embed_dim: Model/embedding dimensionality of inputs and outputs.
        dense_dim: Hidden width of the position-wise feed-forward network.
        num_heads: Number of attention heads (key_dim = embed_dim / num_heads).
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=int(embed_dim / num_heads))
        self.dense_proj = keras.Sequential([
            layers.Dense(dense_dim, activation="relu"),
            layers.Dense(embed_dim),
        ])
        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
        self.dropout_1 = layers.Dropout(0.1)

    def call(self, inputs, mask=None):
        # Broadcast the padding mask to (batch, 1, seq_len) so it masks keys
        # for every query position of the self-attention.
        if mask is not None:
            mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.bool)
        attention_output = self.attention(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=mask,
        )
        attention_output = self.dropout_1(attention_output)
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        # Fix: the layer is registered as serializable but previously did not
        # record its constructor arguments, so it could not be deserialized
        # from a saved config.
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads,
        })
        return config
@register_keras_serializable(package="Custom")
class TransformerDecoder(layers.Layer):
    """Post-norm Transformer decoder block: causal self-attention,
    cross-attention over encoder outputs, then a feed-forward network.

    Args:
        embed_dim: Model/embedding dimensionality of inputs and outputs.
        dense_dim: Hidden width of the position-wise feed-forward network.
        num_heads: Number of attention heads (key_dim = embed_dim / num_heads).
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=int(embed_dim / num_heads))
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=int(embed_dim / num_heads))
        self.dense_proj = keras.Sequential([
            layers.Dense(dense_dim, activation="relu"),
            layers.Dense(embed_dim),
        ])
        self.dropout_1 = layers.Dropout(0.1)
        self.dropout_2 = layers.Dropout(0.1)
        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_3 = layers.LayerNormalization(epsilon=1e-5)

    def call(self, inputs, encoder_outputs, mask=None):
        # Masked (causal) self-attention over the already-generated tokens.
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            use_causal_mask=True,
        )
        attention_output_1 = self.dropout_1(attention_output_1)
        attention_output_1 = self.layernorm_1(inputs + attention_output_1)
        # Cross-attention with the encoder's padding mask only (no causality).
        attention_output_2 = self.attention_2(
            query=attention_output_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=mask,
            use_causal_mask=False,
        )
        attention_output_2 = self.dropout_2(attention_output_2)
        attention_output_2 = self.layernorm_2(attention_output_1 + attention_output_2)
        proj_output = self.dense_proj(attention_output_2)
        return self.layernorm_3(attention_output_2 + proj_output)

    def get_config(self):
        # Fix: the layer is registered as serializable but previously did not
        # record its constructor arguments, so it could not be deserialized
        # from a saved config.
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads,
        })
        return config
@register_keras_serializable(package="Custom")
class PositionalEmbedding(layers.Layer):
    """Sum of learned token embeddings and learned position embeddings.

    Args:
        sequence_length: Maximum sequence length (size of the position table).
        input_dim: Vocabulary size for the token embedding.
        output_dim: Embedding dimensionality.
        mask_zero: Whether token id 0 is treated as padding by the token
            embedding (positions are never masked).
    """

    def __init__(self, sequence_length, input_dim, output_dim, mask_zero=True, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim, mask_zero=mask_zero)
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim, mask_zero=False)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.mask_zero = mask_zero

    def call(self, inputs):
        # Positions 0..length-1 for the current (possibly shorter) sequence.
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def get_config(self):
        # Fix: the layer is registered as serializable but previously did not
        # record its constructor arguments, so it could not be deserialized
        # from a saved config.
        config = super().get_config()
        config.update({
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "mask_zero": self.mask_zero,
        })
        return config
# Transformer hyperparameters — must match the saved weights loaded below.
embed_dim = 512
dense_dim = 2048
num_heads = 8
num_blocks = 7
encoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
decoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")
# Masks
# Padding mask: True wherever the encoder token is not the pad id (0).
encoder_mask = tf.keras.layers.Lambda(lambda x: tf.cast(tf.not_equal(x, 0), tf.bool))(encoder_inputs)
# Same mask broadcast to (batch, 1, 1, src_len) for decoder cross-attention.
cross_attention_mask = tf.keras.layers.Lambda(lambda x: tf.cast(x[:, tf.newaxis, tf.newaxis, :], tf.bool))(encoder_mask)
# Embeddings
encoder_embed = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
# Decoder side: 256 positions over a 257-id vocabulary — presumably 256
# codebook ids plus the start token (id 256); confirm against training code.
decoder_embed = PositionalEmbedding(256, 257, embed_dim, mask_zero=False)(decoder_inputs)
# Pre-instantiate blocks
encoder_blocks = [TransformerEncoder(embed_dim, dense_dim, num_heads) for _ in range(num_blocks)]
decoder_blocks = [TransformerDecoder(embed_dim, dense_dim, num_heads) for _ in range(num_blocks)]
# Encoder
x = encoder_embed
for block in encoder_blocks:
    x = block(x, mask=encoder_mask)
encoder_outputs = x
# Decoder
x = decoder_embed
for block in decoder_blocks:
    x = block(x, encoder_outputs, mask=cross_attention_mask)
# Output layers
x = layers.LayerNormalization(epsilon=1e-5)(x)
x = layers.Dropout(0.1)(x)
# Per-position logits over the 256 codebook entries.
decoder_outputs = layers.Dense(256)(x)
transformer = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
# Decoding constants: the start-of-sequence id and how many image tokens to emit.
start_token = 256
max_output_length = 256


def generate_image_tokens(input_text):
    """Greedy-decode a sequence of image codebook token ids for `input_text`.

    Runs the full transformer once per generated token (no KV cache) and
    always takes the argmax of the last position's logits.
    """
    text_tokens = vectorizer([input_text])  # (1, sequence_length)
    generated = [start_token]
    for _ in range(max_output_length):
        step_input = tf.convert_to_tensor([generated])
        logits = transformer([text_tokens, step_input])
        next_token = np.argmax(logits[0, -1, :])
        generated.append(next_token)
    # Drop the leading start token before returning.
    return generated[1:]
def get_embeddings(indices, codebook):
    """Map codebook indices to the (batch, 16, 16, 256) grid the decoder expects.

    Args:
        indices: Integer token ids (any shape; flattened internally). With the
            256 tokens produced by generate_image_tokens this yields one
            16x16 latent grid.
        codebook: Embedding table; rows are looked up by token id.

    Returns:
        A tensor reshaped to (-1, 16, 16, 256).
    """
    flat_indices = tf.reshape(indices, [-1])
    flat_embeddings = tf.nn.embedding_lookup(codebook, flat_indices)
    # Fix: removed the unused `out_shape` computation (dead code); the decoder
    # input shape is fixed, so the reshape target is hard-coded.
    return tf.reshape(flat_embeddings, (-1, 16, 16, 256))
# Restore the trained text-to-image-token transformer weights.
transformer.load_weights('src/VQGAN_Transformer.weights.h5')
user_input = st.text_input("Enter some text:", "")
if user_input != "":
    with st.spinner("Generating image..."):
        st.write(user_input)
        # Pipeline: text -> discrete image tokens -> codebook embeddings -> image.
        output_tokens = generate_image_tokens(user_input)
        embedding = get_embeddings(output_tokens, codebook)
        image = decoder(embedding)[0].numpy()
        # NOTE(review): assumes the decoder emits values in [0, 1] — scaling
        # by 255 and clipping to 8-bit pixels; confirm against the VQGAN's
        # output activation.
        image = np.clip(image * 255, 0, 255).astype(np.uint8)
        st.image(image, caption="Generated Image", width=512)