Yuchan commited on
Commit
c536d0e
·
verified ·
1 Parent(s): a6af02f

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +7 -3
AlphaS2S.py CHANGED
@@ -166,11 +166,13 @@ class EncoderBlock(layers.Layer):
166
  super().__init__()
167
  self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
168
  self.ffn = SwiGLU(d_model, dff)
 
169
  self.norm1 = layers.LayerNormalization(epsilon=1e-6)
170
  self.norm2 = layers.LayerNormalization(epsilon=1e-6)
171
  self.dropout1 = layers.Dropout(dropout)
172
  self.dropout2 = layers.Dropout(dropout)
173
  def call(self, x, mask=None, training=False):
 
174
  attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
175
  out1 = self.norm1(attn_out + x)
176
  ffn_out = self.dropout2(self.ffn(out1), training=training)
@@ -182,6 +184,7 @@ class DecoderBlock(layers.Layer):
182
  self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
183
  self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
184
  self.ffn = SwiGLU(d_model, dff)
 
185
  self.norm1 = layers.LayerNormalization(epsilon=1e-6)
186
  self.norm2 = layers.LayerNormalization(epsilon=1e-6)
187
  self.norm3 = layers.LayerNormalization(epsilon=1e-6)
@@ -189,6 +192,7 @@ class DecoderBlock(layers.Layer):
189
  self.dropout2 = layers.Dropout(dropout)
190
  self.dropout3 = layers.Dropout(dropout)
191
  def call(self, x, enc_out, training=False):
 
192
  attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
193
  out1 = self.norm1(attn1 + x)
194
  attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
@@ -201,9 +205,9 @@ class Transformer(tf.keras.Model):
201
  super().__init__()
202
  self.max_len = max_len
203
  self.d_model = d_model
204
- self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
205
- self.enc_pos_embedding = layers.Embedding(max_len, d_model)
206
- self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
207
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
208
  self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
209
  self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
 
166
  super().__init__()
167
  self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
168
  self.ffn = SwiGLU(d_model, dff)
169
+ self.proj = layers.Dense(d_model)
170
  self.norm1 = layers.LayerNormalization(epsilon=1e-6)
171
  self.norm2 = layers.LayerNormalization(epsilon=1e-6)
172
  self.dropout1 = layers.Dropout(dropout)
173
  self.dropout2 = layers.Dropout(dropout)
174
  def call(self, x, mask=None, training=False):
175
+ x = self.proj(x)
176
  attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
177
  out1 = self.norm1(attn_out + x)
178
  ffn_out = self.dropout2(self.ffn(out1), training=training)
 
184
  self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
185
  self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
186
  self.ffn = SwiGLU(d_model, dff)
187
+ self.proj = layers.Dense(d_model)
188
  self.norm1 = layers.LayerNormalization(epsilon=1e-6)
189
  self.norm2 = layers.LayerNormalization(epsilon=1e-6)
190
  self.norm3 = layers.LayerNormalization(epsilon=1e-6)
 
192
  self.dropout2 = layers.Dropout(dropout)
193
  self.dropout3 = layers.Dropout(dropout)
194
  def call(self, x, enc_out, training=False):
195
+ x = self.proj(x)
196
  attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
197
  out1 = self.norm1(attn1 + x)
198
  attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
 
205
  super().__init__()
206
  self.max_len = max_len
207
  self.d_model = d_model
208
+ self.enc_embedding = layers.Embedding(input_vocab_size, 256)
209
+ self.enc_pos_embedding = layers.Embedding(max_len, 256)
210
+ self.dec_embedding = layers.Embedding(target_vocab_size, 256)
211
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
212
  self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
213
  self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]