# Hugging Face Spaces app: text-to-image Streamlit demo (VQGAN decoder + Transformer).
# (The "Spaces: Sleeping" lines were page-status residue from the web scrape, not code.)
import os
import re
import string

import numpy as np
import streamlit as st
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, Add
from tensorflow.keras.utils import register_keras_serializable
class SelfAttention(tf.keras.layers.Layer):
    """SAGAN-style self-attention over a 2-D feature map.

    Projects the input to query/key maps (C//8 channels) and a value map
    (C channels) with 1x1 convolutions, computes full N x N spatial
    attention (N = H * W), and adds the attended values back onto the
    input as a residual connection. Needed as a custom object when loading
    the saved VQGAN decoder.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Channel count is taken from the input, so the layer adapts to any C.
        self.filters = input_shape[-1]
        # 1x1 convs: f/g produce reduced query/key maps, h the full value map.
        self.f = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.g = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.h = Conv2D(self.filters, kernel_size=1, padding='same')
        super().build(input_shape)

    def call(self, x):
        f = self.f(x)  # (B, H, W, C//8) query
        g = self.g(x)  # (B, H, W, C//8) key
        h = self.h(x)  # (B, H, W, C)    value
        shape_f = tf.shape(f)
        B, H, W = shape_f[0], shape_f[1], shape_f[2]
        # Flatten the spatial grid so attention runs over N = H*W positions.
        f_flat = tf.reshape(f, [B, H * W, self.filters // 8])
        g_flat = tf.reshape(g, [B, H * W, self.filters // 8])
        h_flat = tf.reshape(h, [B, H * W, self.filters])
        # (B, N, N) attention weights; softmax over the key axis.
        beta = tf.nn.softmax(tf.matmul(f_flat, g_flat, transpose_b=True), axis=-1)
        o = tf.matmul(beta, h_flat)  # (B, N, C) attended values
        o = tf.reshape(o, [B, H, W, self.filters])
        # Residual connection. Fix: the previous version instantiated a fresh
        # Add() layer on every call (a Keras anti-pattern); plain addition is
        # numerically identical and Add carries no weights.
        return x + o

    def get_config(self):
        # No custom constructor arguments, so the base config suffices.
        return super().get_config()
# Load the trained artifacts: training captions (used below to build the
# text-vectorizer vocabulary), the VQGAN decoder (requires SelfAttention as
# a custom object), and its latent codebook.
captions = np.load('src/caption.npy')
decoder = tf.keras.models.load_model(
    'src/epoch_78_decoder.keras', custom_objects={'SelfAttention': SelfAttention}
)
codebook = tf.convert_to_tensor(np.load('src/epoch_78_codebook.npy'), dtype=tf.float32)

# Punctuation stripped during standardization; '[' and ']' are excluded,
# presumably because they delimit special tokens in the captions.
strip_chars = (string.punctuation + "¿").replace("[", "").replace("]", "")
def custom_standardization(input_string):
    """Lowercase *input_string* and remove the characters in ``strip_chars``."""
    pattern = f"[{re.escape(strip_chars)}]"
    lowered = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowered, pattern, "")
# Caption-vectorizer hyper-parameters.
vocab_size = 18500
sequence_length = 85
vectorizer = layers.TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)
# NOTE(review): custom_standardization is defined above but is NOT passed
# here (standardize=...), so TextVectorization's default standardization is
# used - confirm this matches how the training-time vectorizer was built.
vectorizer.adapt(captions)
class TransformerEncoder(layers.Layer):
    """Transformer encoder block: multi-head self-attention followed by a
    position-wise feed-forward network, each with a residual connection and
    a post-layer-norm."""

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        # The embedding is split evenly across heads.
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=int(embed_dim / num_heads)
        )
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation="relu"), layers.Dense(embed_dim)]
        )
        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
        self.dropout_1 = layers.Dropout(0.1)

    def call(self, inputs, mask=None):
        attention_mask = None
        if mask is not None:
            # Boolean padding mask broadcastable as (batch, 1, seq_len).
            attention_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.bool)
        attended = self.attention(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=attention_mask,
        )
        normed = self.layernorm_1(inputs + self.dropout_1(attended))
        return self.layernorm_2(normed + self.dense_proj(normed))
class TransformerDecoder(layers.Layer):
    """Transformer decoder block: causal self-attention, cross-attention over
    the encoder outputs, then a position-wise feed-forward network, each with
    a residual connection and a post-layer-norm."""

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        # Both attention layers share the same head geometry.
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=int(embed_dim / num_heads)
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=int(embed_dim / num_heads)
        )
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation="relu"), layers.Dense(embed_dim)]
        )
        self.dropout_1 = layers.Dropout(0.1)
        self.dropout_2 = layers.Dropout(0.1)
        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_3 = layers.LayerNormalization(epsilon=1e-5)

    def call(self, inputs, encoder_outputs, mask=None):
        # Causal self-attention over the target sequence (no padding mask).
        self_attended = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            use_causal_mask=True,
        )
        out_1 = self.layernorm_1(inputs + self.dropout_1(self_attended))
        # Cross-attention: queries from the decoder stream, keys/values from
        # the encoder; `mask` hides the encoder's padding positions.
        cross_attended = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=mask,
            use_causal_mask=False,
        )
        out_2 = self.layernorm_2(out_1 + self.dropout_2(cross_attended))
        return self.layernorm_3(out_2 + self.dense_proj(out_2))
class PositionalEmbedding(layers.Layer):
    """Learned token embedding plus a learned absolute position embedding."""

    def __init__(self, sequence_length, input_dim, output_dim, mask_zero=True, **kwargs):
        super().__init__(**kwargs)
        # Token ids -> vectors; optionally treats id 0 as padding.
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim, mask_zero=mask_zero
        )
        # Position indices -> vectors; positions themselves are never masked.
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim, mask_zero=False
        )
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim

    def call(self, inputs):
        seq_len = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        # Position vectors broadcast over the batch dimension.
        return self.token_embeddings(inputs) + self.position_embeddings(positions)
# Transformer hyper-parameters.
embed_dim = 512
dense_dim = 2048
num_heads = 8
num_blocks = 7

# Variable-length int inputs: text tokens and (shifted) image tokens.
encoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
decoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")
# Masks
# Text padding mask: True where the token id is non-zero.
encoder_mask = tf.keras.layers.Lambda(lambda x: tf.cast(tf.not_equal(x, 0), tf.bool))(encoder_inputs)
# Same mask broadcast to (batch, 1, 1, src_len) for cross-attention.
cross_attention_mask = tf.keras.layers.Lambda(lambda x: tf.cast(x[:, tf.newaxis, tf.newaxis, :], tf.bool))(encoder_mask)
# Embeddings
encoder_embed = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
# 257 ids = 256 codebook entries + start token (id 256), with up to 256
# decoder positions - presumably; confirm against the training code.
decoder_embed = PositionalEmbedding(256, 257, embed_dim, mask_zero=False)(decoder_inputs)
# Pre-instantiate blocks
encoder_blocks = [TransformerEncoder(embed_dim, dense_dim, num_heads) for _ in range(num_blocks)]
decoder_blocks = [TransformerDecoder(embed_dim, dense_dim, num_heads) for _ in range(num_blocks)]
# Encoder
x = encoder_embed
for block in encoder_blocks:
    x = block(x, mask=encoder_mask)
encoder_outputs = x
# Decoder
x = decoder_embed
for block in decoder_blocks:
    x = block(x, encoder_outputs, mask=cross_attention_mask)
# Output layers
x = layers.LayerNormalization(epsilon=1e-5)(x)
x = layers.Dropout(0.1)(x)
# Per-position logits over the 256 codebook entries.
decoder_outputs = layers.Dense(256)(x)
transformer = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Greedy-decoding constants: id 256 starts generation; 256 tokens are
# produced (matching the 16x16 latent grid reshaped downstream).
start_token = 256
max_output_length = 256
def generate_image_tokens(input_text):
    """Greedily decode a sequence of image tokens for *input_text*.

    The text is vectorized once; the transformer then runs autoregressively,
    appending the argmax token at each step and feeding the grown sequence
    back in. Returns the generated tokens with the leading start token
    stripped.
    """
    tokenized_text = vectorizer([input_text])  # (1, text_seq_len)
    decoded_image_tokens = [start_token]
    for _ in range(max_output_length):
        decoder_input = tf.convert_to_tensor([decoded_image_tokens])
        predictions = transformer([tokenized_text, decoder_input])
        # Greedy choice: highest-scoring token at the last position.
        next_token = np.argmax(predictions[0, -1, :])
        decoded_image_tokens.append(next_token)
    # Drop the start token; callers want only the image tokens.
    return decoded_image_tokens[1:]
def get_embeddings(indices, codebook):
    """Map discrete code indices to codebook vectors shaped for the decoder.

    Looks each index up in *codebook* and reshapes the result to
    (-1, 16, 16, 256): the 16x16 spatial grid of 256-dim latents the VQGAN
    decoder consumes. Assumes len(indices) is a multiple of 16*16 and the
    codebook embedding dim is 256 — TODO confirm against the saved codebook.
    """
    flat_indices = tf.reshape(indices, [-1])
    flat_embeddings = tf.nn.embedding_lookup(codebook, flat_indices)
    # Fix: removed a dead `out_shape` computation that was never used; the
    # decoder requires this fixed layout regardless of the input's shape.
    return tf.reshape(flat_embeddings, (-1, 16, 16, 256))
# Restore the trained transformer weights (must match the architecture
# built above, including layer creation order).
transformer.load_weights('src/VQGAN_Transformer.weights.h5')

# --- Streamlit UI: prompt in, generated image out ---
user_input = st.text_input("Enter some text:", "")
if user_input != "":
    with st.spinner("Generating image..."):
        st.write(user_input)
        # Pipeline: text -> image tokens -> codebook embeddings -> image.
        output_tokens = generate_image_tokens(user_input)
        embedding = get_embeddings(output_tokens, codebook)
        image = decoder(embedding)[0].numpy()
        # Decoder output presumably lies in [0, 1] - scaled to uint8 here;
        # confirm against the decoder's final activation.
        image = np.clip(image * 255, 0, 255).astype(np.uint8)
        st.image(image, caption="Generated Image", width=512)