Yuchan committed
Update AlphaS2S.py

AlphaS2S.py CHANGED (+7 -3)
@@ -166,11 +166,13 @@ class EncoderBlock(layers.Layer):
         super().__init__()
         self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
         self.ffn = SwiGLU(d_model, dff)
+        self.proj = layers.Dense(d_model)
         self.norm1 = layers.LayerNormalization(epsilon=1e-6)
         self.norm2 = layers.LayerNormalization(epsilon=1e-6)
         self.dropout1 = layers.Dropout(dropout)
         self.dropout2 = layers.Dropout(dropout)
     def call(self, x, mask=None, training=False):
+        x = self.proj(x)
         attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
         out1 = self.norm1(attn_out + x)
         ffn_out = self.dropout2(self.ffn(out1), training=training)
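The new self.proj = layers.Dense(d_model) at the top of call lets the encoder block accept inputs narrower than d_model: this commit shrinks the token and position embeddings to 256 units (see the Transformer hunk below), and the projection lifts them back to model width so the attention output and the residual attn_out + x have matching widths. A minimal shape check, using toy sizes that are assumptions rather than values from AlphaS2S.py:

import tensorflow as tf
from tensorflow.keras import layers

# Toy sizes; the real d_model and num_heads come from the model config.
B, T, d_model = 2, 10, 512
emb = tf.random.normal((B, T, 256))            # 256-wide embeddings, as in the Transformer hunk

proj = layers.Dense(d_model)
mha = layers.MultiHeadAttention(num_heads=8, key_dim=d_model)
norm1 = layers.LayerNormalization(epsilon=1e-6)

x = proj(emb)                                  # (B, T, 256) -> (B, T, d_model)
attn_out = mha(x, x, x)                        # (B, T, d_model)
out1 = norm1(attn_out + x)                     # residual adds two d_model-wide tensors
print(out1.shape)                              # (2, 10, 512)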
@@ -182,6 +184,7 @@ class DecoderBlock(layers.Layer):
         self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
         self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
         self.ffn = SwiGLU(d_model, dff)
+        self.proj = layers.Dense(d_model)
         self.norm1 = layers.LayerNormalization(epsilon=1e-6)
         self.norm2 = layers.LayerNormalization(epsilon=1e-6)
         self.norm3 = layers.LayerNormalization(epsilon=1e-6)
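Both blocks keep their SwiGLU(d_model, dff) feed-forward; its definition lives elsewhere in AlphaS2S.py and is not part of this diff. For reference, a typical SwiGLU feed-forward layer follows the down(silu(gate(x)) * up(x)) formulation; the sketch below is that standard variant, not necessarily the file's actual implementation:

import tensorflow as tf
from tensorflow.keras import layers

class SwiGLU(layers.Layer):
    """Standard SwiGLU feed-forward: swish-gated expansion to dff, then projection back to d_model."""
    def __init__(self, d_model, dff):
        super().__init__()
        self.w_gate = layers.Dense(dff, use_bias=False)     # gated branch (swish/silu activation)
        self.w_up = layers.Dense(dff, use_bias=False)       # linear branch
        self.w_down = layers.Dense(d_model, use_bias=False)

    def call(self, x):
        return self.w_down(tf.nn.silu(self.w_gate(x)) * self.w_up(x))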
@@ -189,6 +192,7 @@ class DecoderBlock(layers.Layer):
         self.dropout2 = layers.Dropout(dropout)
         self.dropout3 = layers.Dropout(dropout)
     def call(self, x, enc_out, training=False):
+        x = self.proj(x)
         attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
         out1 = self.norm1(attn1 + x)
         attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
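The decoder's self-attention continues to rely on use_causal_mask=True, which recent Keras versions of MultiHeadAttention use to build the lower-triangular mask internally, so no explicit mask has to be threaded through call. A quick equivalence check against an explicitly constructed mask, with assumed toy sizes:

import tensorflow as tf
from tensorflow.keras import layers

B, T, D = 2, 5, 64
x = tf.random.normal((B, T, D))
mha = layers.MultiHeadAttention(num_heads=4, key_dim=D)

out_causal = mha(x, x, x, use_causal_mask=True)           # as in DecoderBlock.call

# Same result with an explicit lower-triangular attention_mask of shape (B, T, T).
tril = tf.cast(tf.linalg.band_part(tf.ones((T, T)), -1, 0), tf.bool)
mask = tf.broadcast_to(tril, (B, T, T))
out_explicit = mha(x, x, x, attention_mask=mask)

print(float(tf.reduce_max(tf.abs(out_causal - out_explicit))))   # ~0.0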
@@ -201,9 +205,9 @@ class Transformer(tf.keras.Model):
         super().__init__()
         self.max_len = max_len
         self.d_model = d_model
-        self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
-        self.enc_pos_embedding = layers.Embedding(max_len, d_model)
-        self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
+        self.enc_embedding = layers.Embedding(input_vocab_size, 256)
+        self.enc_pos_embedding = layers.Embedding(max_len, 256)
+        self.dec_embedding = layers.Embedding(target_vocab_size, 256)
         self.dec_pos_embedding = layers.Embedding(max_len, d_model)
         self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
         self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
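The embedding change pairs with the new projection layers: token and position embeddings are now looked up at 256 units instead of d_model, and each block's Dense(d_model) restores model width (for the first block this lifts 256 to d_model; if the blocks are stacked in the usual way, which this hunk does not show, later blocks apply a d_model-to-d_model map). The main saving is in the embedding tables, which scale with the vocabulary. A rough count with hypothetical sizes, since input_vocab_size and d_model are not visible in this diff:

# Hypothetical sizes for illustration only.
input_vocab_size, d_model, embed_dim = 32_000, 512, 256

full_width = input_vocab_size * d_model                          # 16,384,000 weights
factorized = input_vocab_size * embed_dim + embed_dim * d_model  # 8,323,072 weights
print(full_width / factorized)                                   # ~1.97x fewer embedding weights

Note that the per-block Dense layers add parameters of their own (roughly d_model**2 each for blocks whose input is already d_model wide), so the net effect depends on num_layers and the vocabulary sizes.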
|