Yuchan commited on
Commit
1bf639d
·
verified ·
1 Parent(s): 5a66735

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +2 -2
AlphaS2S.py CHANGED
@@ -172,7 +172,7 @@ class EncoderBlock(layers.Layer):
172
  self.dropout2 = layers.Dropout(dropout)
173
  def call(self, x, mask=None, training=False):
174
  attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
175
- out1 = self.norm1(x + attn_out)
176
  ffn_out = self.dropout2(self.ffn(out1), training=training)
177
  return self.norm2(out1 + ffn_out)
178
 
@@ -190,7 +190,7 @@ class DecoderBlock(layers.Layer):
190
  self.dropout3 = layers.Dropout(dropout)
191
  def call(self, x, enc_out, training=False):
192
  attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
193
- out1 = self.norm1(x + attn1)
194
  attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
195
  out2 = self.norm2(out1 + attn2)
196
  ffn_out = self.dropout3(self.ffn(out2), training=training)
 
172
  self.dropout2 = layers.Dropout(dropout)
173
  def call(self, x, mask=None, training=False):
174
  attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
175
+ out1 = self.norm1(attn_out)
176
  ffn_out = self.dropout2(self.ffn(out1), training=training)
177
  return self.norm2(out1 + ffn_out)
178
 
 
190
  self.dropout3 = layers.Dropout(dropout)
191
  def call(self, x, enc_out, training=False):
192
  attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
193
+ out1 = self.norm1(attn1)
194
  attn2 = self.dropout2(self.cross_mha(out1, enc_out, enc_out), training=training)
195
  out2 = self.norm2(out1 + attn2)
196
  ffn_out = self.dropout3(self.ffn(out2), training=training)