Yuchan
committed on
Update AlphaS2S.py
Browse files- AlphaS2S.py +2 -2
AlphaS2S.py
CHANGED
|
@@ -172,7 +172,7 @@ class EncoderBlock(layers.Layer):
|
|
| 172 |
self.dropout2 = layers.Dropout(dropout)
|
| 173 |
def call(self, x, mask=None, training=False):
|
| 174 |
attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
|
| 175 |
-
out1 = self.norm1(
|
| 176 |
ffn_out = self.dropout2(self.ffn(out1), training=training)
|
| 177 |
return self.norm2(out1 + ffn_out)
|
| 178 |
|
|
@@ -190,7 +190,7 @@ class DecoderBlock(layers.Layer):
|
|
| 190 |
self.dropout3 = layers.Dropout(dropout)
|
| 191 |
def call(self, x, enc_out, training=False):
|
| 192 |
attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
|
| 193 |
-
out1 = self.norm1(
|
| 194 |
attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
|
| 195 |
out2 = self.norm2(out1 + attn2)
|
| 196 |
ffn_out = self.dropout3(self.ffn(out2), training=training)
|
|
|
|
| 172 |
self.dropout2 = layers.Dropout(dropout)
|
| 173 |
def call(self, x, mask=None, training=False):
    """Apply one Transformer encoder block: self-attention then feed-forward.

    Args:
        x: Input tensor (assumed (batch, seq, d_model) — TODO confirm against caller).
        mask: Optional attention mask forwarded to the multi-head attention layer.
        training: Forwarded to the dropout layers so units are only dropped
            during training.

    Returns:
        Tensor of the same shape as `x`.
    """
    # Self-attention sublayer with dropout.
    attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
    # FIX: include the residual connection (x + attn_out) before the first
    # LayerNorm. The second sublayer below already uses the residual form
    # (out1 + ffn_out), and the standard post-LN Transformer encoder does the
    # same here; without it the sublayer discards its input signal.
    out1 = self.norm1(x + attn_out)
    # Position-wise feed-forward sublayer with dropout and residual connection.
    ffn_out = self.dropout2(self.ffn(out1), training=training)
    return self.norm2(out1 + ffn_out)
|
| 178 |
|
|
|
|
| 190 |
self.dropout3 = layers.Dropout(dropout)
|
| 191 |
def call(self, x, enc_out, training=False):
|
| 192 |
attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
|
| 193 |
+
out1 = self.norm1(attn1)
|
| 194 |
attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
|
| 195 |
out2 = self.norm2(out1 + attn2)
|
| 196 |
ffn_out = self.dropout3(self.ffn(out2), training=training)
|