File size: 9,437 Bytes
407d097
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cd8731
 
 
407d097
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da5fd9f
407d097
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da5fd9f
407d097
 
da5fd9f
 
407d097
da5fd9f
 
 
407d097
da5fd9f
 
 
 
 
 
 
407d097
da5fd9f
 
407d097
da5fd9f
 
407d097
da5fd9f
 
407d097
da5fd9f
407d097
 
 
da5fd9f
407d097
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cd8731
407d097
 
85e2b48
407d097
3c66367
407d097
 
 
 
0af326c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import streamlit as st 
import numpy as np
import tensorflow as tf 
import string
import re
from tensorflow.keras.utils import register_keras_serializable
from tensorflow.keras.layers import Conv2D, Add
from tensorflow import keras 
from tensorflow.keras import layers
import os

@register_keras_serializable(package="Custom")
class SelfAttention(tf.keras.layers.Layer):
    """Self-attention over 2-D feature maps with a residual add (SAGAN-style).

    Sub-layers are created lazily in ``build`` so the layer adapts to the
    channel count of whatever input it receives.

    NOTE(review): the attribute names ``f``/``g``/``h`` are load-bearing —
    the saved ``epoch_78_decoder.keras`` model restores this layer's weights,
    so these names and the layer structure must not change.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Channel count is only known once an input shape is seen.
        self.filters = input_shape[-1]
        # Query (f) and key (g) projections reduce channels 8x to keep the
        # N x N attention matmul cheap; value (h) keeps the full width.
        self.f = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.g = Conv2D(self.filters // 8, kernel_size=1, padding='same')
        self.h = Conv2D(self.filters, kernel_size=1, padding='same')
        super().build(input_shape)

    def call(self, x):
        f = self.f(x)  # (B, H, W, C//8)
        g = self.g(x)
        h = self.h(x)  # (B, H, W, C)

        # Dynamic shapes so the layer works for any spatial size at runtime.
        shape_f = tf.shape(f)
        B, H, W = shape_f[0], shape_f[1], shape_f[2]

        # Flatten the spatial grid: N = H*W positions attend to each other.
        f_flat = tf.reshape(f, [B, H * W, self.filters // 8])
        g_flat = tf.reshape(g, [B, H * W, self.filters // 8])
        h_flat = tf.reshape(h, [B, H * W, self.filters])

        # Attention weights: softmax over key positions of query-key similarity.
        beta = tf.nn.softmax(tf.matmul(f_flat, g_flat, transpose_b=True), axis=-1)  # (B, N, N)

        # Weighted sum of values, restored to the spatial layout.
        o = tf.matmul(beta, h_flat)  # (B, N, C)
        o = tf.reshape(o, [B, H, W, self.filters])

        return Add()([x, o])  # Residual connection

    def get_config(self):
        config = super().get_config()
        # If you have custom arguments in __init__, add them here
        return config

# Load the caption corpus (used below to fit the text vectorizer), the
# trained VQGAN decoder, and its codebook of discrete latent embeddings.
captions = np.load('src/caption.npy')
decoder = tf.keras.models.load_model('src/epoch_78_decoder.keras',custom_objects={'SelfAttention': SelfAttention})
codebook = np.load('src/epoch_78_codebook.npy')
# Tensor form so tf.nn.embedding_lookup can index it directly.
codebook = tf.convert_to_tensor(codebook, dtype=tf.float32)

# Characters removed during text standardization: all ASCII punctuation plus
# the inverted question mark, except square brackets (kept so tokens such as
# "[start]"/"[end]" in the caption corpus survive standardization).
strip_chars = (string.punctuation + "¿").replace("[", "").replace("]", "")

def custom_standardization(input_string):
    """Lowercase *input_string* and delete every character in ``strip_chars``.

    Operates on tf string tensors; intended as a TextVectorization
    standardizer.
    """
    pattern = f"[{re.escape(strip_chars)}]"
    lowered = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowered, pattern, "")

# Tokenizer configuration — must match the settings used at training time.
vocab_size = 18500
sequence_length = 85

vectorizer = layers.TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=sequence_length,
)
# NOTE(review): custom_standardization above is never passed via
# `standardize=`, so TextVectorization's default lower-and-strip-punctuation
# standardizer is used instead — confirm this matches training preprocessing.

# Rebuild the vocabulary from the caption corpus on every app start.
vectorizer.adapt(captions)
@register_keras_serializable(package="Custom")
class TransformerEncoder(layers.Layer):
    """Transformer encoder block: multi-head self-attention plus a
    position-wise feed-forward network, each with a post-norm residual
    connection (add, then LayerNormalization).

    Args:
        embed_dim: model width of the token embeddings.
        dense_dim: hidden width of the feed-forward sub-layer.
        num_heads: attention heads; per-head key dim is embed_dim/num_heads.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=int(embed_dim/num_heads))
        self.dense_proj = keras.Sequential([
            layers.Dense(dense_dim, activation="relu"),
            layers.Dense(embed_dim)
        ])
        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
        self.dropout_1 = layers.Dropout(0.1)

    def call(self, inputs, mask=None):
        # Convert padding mask to boolean with shape (batch, 1, seq_len) so it
        # broadcasts across the query axis of the attention scores.
        if mask is not None:
            mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.bool)

        attention_output = self.attention(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=mask
        )
        attention_output = self.dropout_1(attention_output)
        proj_input = self.layernorm_1(inputs + attention_output)

        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        # Required for a @register_keras_serializable layer whose __init__
        # takes arguments: without it, deserializing a saved model containing
        # this layer fails. Purely additive — existing behavior is unchanged.
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads,
        })
        return config

@register_keras_serializable(package="Custom")
class TransformerDecoder(layers.Layer):
    """Transformer decoder block: causal self-attention, cross-attention over
    the encoder output, then a position-wise feed-forward network — each with
    a post-norm residual connection.

    Args:
        embed_dim: model width of the token embeddings.
        dense_dim: hidden width of the feed-forward sub-layer.
        num_heads: attention heads; per-head key dim is embed_dim/num_heads.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        self.attention_1 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=int(embed_dim/num_heads))
        self.attention_2 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=int(embed_dim/num_heads))
        self.dense_proj = keras.Sequential([
            layers.Dense(dense_dim, activation="relu"),
            layers.Dense(embed_dim)
        ])
        self.dropout_1 = layers.Dropout(0.1)
        self.dropout_2 = layers.Dropout(0.1)
        self.layernorm_1 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
        self.layernorm_3 = layers.LayerNormalization(epsilon=1e-5)

    def call(self, inputs, encoder_outputs, mask=None):
        # Causal self-attention: each position attends only to itself and
        # earlier positions; no padding mask is applied to the target side.
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            use_causal_mask=True
        )
        attention_output_1 = self.dropout_1(attention_output_1)
        attention_output_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention with padding mask only
        attention_output_2 = self.attention_2(
            query=attention_output_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=mask,
            use_causal_mask=False
        )
        attention_output_2 = self.dropout_2(attention_output_2)
        attention_output_2 = self.layernorm_2(attention_output_1 + attention_output_2)

        proj_output = self.dense_proj(attention_output_2)
        return self.layernorm_3(attention_output_2 + proj_output)

    def get_config(self):
        # Required for a @register_keras_serializable layer whose __init__
        # takes arguments: without it, deserializing a saved model containing
        # this layer fails. Purely additive — existing behavior is unchanged.
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "dense_dim": self.dense_dim,
            "num_heads": self.num_heads,
        })
        return config

@register_keras_serializable(package="Custom")
class PositionalEmbedding(layers.Layer):
    """Token embedding plus learned absolute position embedding.

    Args:
        sequence_length: maximum number of positions (position-table size).
        input_dim: vocabulary size for the token embedding.
        output_dim: embedding width.
        mask_zero: when True, token id 0 produces a padding mask that Keras
            propagates to downstream layers.
    """

    def __init__(self, sequence_length, input_dim, output_dim, mask_zero = True, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=input_dim, output_dim=output_dim,mask_zero=mask_zero)
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim,mask_zero=False)
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Stored so get_config can round-trip the constructor argument.
        self.mask_zero = mask_zero

    def call(self, inputs):
        # Positions 0..length-1; the position embedding broadcasts over batch.
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def get_config(self):
        # Required for a @register_keras_serializable layer whose __init__
        # takes arguments: without it, deserializing a saved model containing
        # this layer fails. Purely additive — existing behavior is unchanged.
        config = super().get_config()
        config.update({
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "mask_zero": self.mask_zero,
        })
        return config

# Model hyperparameters — must match the checkpoint loaded further below.
embed_dim = 512
dense_dim = 2048
num_heads = 8
num_blocks = 7

# Variable-length integer token sequences for text (encoder) and image
# codebook tokens (decoder).
encoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
decoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")

# Padding masks: True where the text token is non-zero. The cross-attention
# mask is expanded to (batch, 1, 1, src_len) to broadcast over heads/queries.
encoder_mask = tf.keras.layers.Lambda(lambda x: tf.cast(tf.not_equal(x, 0), tf.bool))(encoder_inputs)
cross_attention_mask = tf.keras.layers.Lambda(lambda x: tf.cast(x[:, tf.newaxis, tf.newaxis, :], tf.bool))(encoder_mask)

# Embeddings. Decoder side: 256 positions, vocab 257 — presumably the 256
# codebook entries plus the start token (id 256 below); confirm vs training.
encoder_embed = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
decoder_embed = PositionalEmbedding(256, 257, embed_dim, mask_zero=False)(decoder_inputs)

# Pre-instantiate the stacks so each block keeps its own weights.
encoder_blocks = [TransformerEncoder(embed_dim, dense_dim, num_heads) for _ in range(num_blocks)]
decoder_blocks = [TransformerDecoder(embed_dim, dense_dim, num_heads) for _ in range(num_blocks)]

# Encoder stack over the text tokens.
x = encoder_embed
for block in encoder_blocks:
    x = block(x, mask=encoder_mask)
encoder_outputs = x

# Decoder stack: causal self-attention plus cross-attention to the encoder.
x = decoder_embed
for block in decoder_blocks:
    x = block(x, encoder_outputs, mask=cross_attention_mask)

# Output head: logits over the 256 codebook tokens per position.
x = layers.LayerNormalization(epsilon=1e-5)(x)
x = layers.Dropout(0.1)(x)
decoder_outputs = layers.Dense(256)(x)

transformer = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Decoding constants: start token begins every sequence; 256 image tokens
# are generated (a 16x16 latent grid — see get_embeddings below).
start_token = 256
max_output_length = 256

def generate_image_tokens(input_text):
    """Greedily decode `max_output_length` VQGAN codebook indices for *input_text*.

    The transformer is re-run on the growing token prefix at every step and
    the argmax over the final position is appended (greedy decoding, no
    early stopping). The leading start token is dropped from the result.
    """
    # Tokenize the prompt once; shape (1, sequence_length).
    text_tokens = vectorizer([input_text])

    # The decoded sequence begins with the dedicated start token.
    tokens = [start_token]
    for _ in range(max_output_length):
        prefix = tf.convert_to_tensor([tokens])

        # Logits for every position of the current prefix.
        logits = transformer([text_tokens, prefix])

        # Greedy choice for the newest position only.
        next_token = np.argmax(logits[0, -1, :])
        tokens.append(next_token)

    # Strip the start token before handing off to the decoder.
    return tokens[1:]

def get_embeddings(indices, codebook):
    """Map discrete codebook indices to their embedding vectors.

    Args:
        indices: integer token ids; flattened before lookup, so any shape
            totalling 256 elements works (the 16x16 latent grid).
        codebook: (num_codes, code_dim) embedding table.

    Returns:
        Embeddings reshaped to (-1, 16, 16, 256) — the spatial latent grid
        the VQGAN decoder expects. Presumably code_dim == 256; confirm
        against the saved codebook.
    """
    flat_indices = tf.reshape(indices, [-1])
    flat_embeddings = tf.nn.embedding_lookup(codebook, flat_indices)

    # Fixed latent-grid shape for the decoder. (A dynamically computed
    # out_shape previously built here was dead code and has been removed.)
    return tf.reshape(flat_embeddings, (-1, 16, 16, 256))

# Restore the trained text-to-image-token transformer weights.
transformer.load_weights('src/VQGAN_Transformer.weights.h5')
user_input = st.text_input("Enter some text:", "")

# Only generate once the user has typed something.
if user_input != "":
    with st.spinner("Generating image..."):
        st.write(user_input)
        output_tokens = generate_image_tokens(user_input)
        embedding = get_embeddings(output_tokens, codebook)
        # Decoder output is presumably in [0, 1] (the *255 scaling implies
        # it) — confirm against the decoder's final activation.
        image = decoder(embedding)[0].numpy()
        image = np.clip(image * 255, 0, 255).astype(np.uint8)
        st.image(image, caption="Generated Image", width=512)