Beasto committed on
Commit
da5fd9f
·
verified ·
1 Parent(s): cf28d18

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +22 -37
src/streamlit_app.py CHANGED
@@ -128,28 +128,8 @@ class TransformerDecoder(layers.Layer):
128
  self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
129
  self.layernorm_3 = layers.LayerNormalization(epsilon=1e-5)
130
 
131
- # def get_causal_attention_mask(self, inputs):
132
- # seq_len = tf.shape(inputs)[1]
133
- # causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len), dtype=tf.bool), -1, 0)
134
- # return causal_mask[tf.newaxis, :, :] # (1, seq_len, seq_len)
135
-
136
  def call(self, inputs, encoder_outputs, mask=None):
137
- # Padding mask: (batch_size, 1, seq_len)
138
- # if mask is not None:
139
- # padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.bool)
140
- # else:
141
- # padding_mask = None
142
-
143
- # # Causal mask: (1, seq_len, seq_len)
144
- # causal_mask = self.get_causal_attention_mask(inputs)
145
-
146
- # # Combine masks for self-attention
147
- # if padding_mask is not None:
148
- # combined_mask = tf.logical_and(padding_mask, causal_mask)
149
- # else:
150
- # combined_mask = causal_mask
151
-
152
- # Self-attention with combined mask
153
  attention_output_1 = self.attention_1(
154
  query=inputs,
155
  value=inputs,
@@ -197,33 +177,38 @@ embed_dim = 512
197
  dense_dim = 2048
198
  num_heads = 8
199
  num_blocks = 7
200
-
201
  encoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
202
  decoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")
203
-
204
- # Paddincfg masks
205
  encoder_mask = tf.keras.layers.Lambda(lambda x: tf.cast(tf.not_equal(x, 0), tf.bool))(encoder_inputs)
206
- cross_attention_mask = tf.keras.layers.Lambda(lambda x: tf.cast(x[:, tf.newaxis, tf.newaxis, :], tf.bool))(encoder_mask)
207
-
208
- # Embeddings
209
  encoder_embed = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
210
- decoder_embed = PositionalEmbedding(256, 257, embed_dim,mask_zero=False)(decoder_inputs)
211
-
212
- # Encoder blocks
 
 
 
 
213
  x = encoder_embed
214
- for _ in range(num_blocks):
215
- x = TransformerEncoder(embed_dim, dense_dim, num_heads)(x, mask=encoder_mask)
216
  encoder_outputs = x
217
-
218
- # Decoder blocks
219
  x = decoder_embed
220
- for _ in range(num_blocks):
221
- x = TransformerDecoder(embed_dim, dense_dim, num_heads)(x, encoder_outputs, mask=cross_attention_mask)
222
 
223
- # Final layers
224
  x = layers.LayerNormalization(epsilon=1e-5)(x)
225
  x = layers.Dropout(0.1)(x)
226
  decoder_outputs = layers.Dense(256)(x)
 
227
  transformer = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
228
 
229
  start_token = 256
 
128
  self.layernorm_2 = layers.LayerNormalization(epsilon=1e-5)
129
  self.layernorm_3 = layers.LayerNormalization(epsilon=1e-5)
130
 
 
 
 
 
 
131
  def call(self, inputs, encoder_outputs, mask=None):
132
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  attention_output_1 = self.attention_1(
134
  query=inputs,
135
  value=inputs,
 
177
  dense_dim = 2048
178
  num_heads = 8
179
  num_blocks = 7
180
+
181
  encoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
182
  decoder_inputs = tf.keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")
183
+
184
+ # Masks
185
  encoder_mask = tf.keras.layers.Lambda(lambda x: tf.cast(tf.not_equal(x, 0), tf.bool))(encoder_inputs)
186
+ cross_attention_mask = tf.keras.layers.Lambda(lambda x: tf.cast(x[:, tf.newaxis, tf.newaxis, :], tf.bool))(encoder_mask)
187
+
188
+ # Embeddings
189
  encoder_embed = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(encoder_inputs)
190
+ decoder_embed = PositionalEmbedding(256, 257, embed_dim, mask_zero=False)(decoder_inputs)
191
+
192
+ # Pre-instantiate blocks
193
+ encoder_blocks = [TransformerEncoder(embed_dim, dense_dim, num_heads) for _ in range(num_blocks)]
194
+ decoder_blocks = [TransformerDecoder(embed_dim, dense_dim, num_heads) for _ in range(num_blocks)]
195
+
196
+ # Encoder
197
  x = encoder_embed
198
+ for block in encoder_blocks:
199
+ x = block(x, mask=encoder_mask)
200
  encoder_outputs = x
201
+
202
+ # Decoder
203
  x = decoder_embed
204
+ for block in decoder_blocks:
205
+ x = block(x, encoder_outputs, mask=cross_attention_mask)
206
 
207
+ # Output layers
208
  x = layers.LayerNormalization(epsilon=1e-5)(x)
209
  x = layers.Dropout(0.1)(x)
210
  decoder_outputs = layers.Dense(256)(x)
211
+
212
  transformer = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
213
 
214
  start_token = 256