Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +1 -1
AlphaS2S.py
CHANGED
|
@@ -207,7 +207,7 @@ class Transformer(tf.keras.Model):
|
|
| 207 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 208 |
self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
|
| 209 |
self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
|
| 210 |
-
self.final_layer = layers.Dense(target_vocab_size)
|
| 211 |
def call(self, inputs, training=False):
|
| 212 |
enc_inputs = inputs["enc_inputs"]
|
| 213 |
dec_inputs = inputs["dec_inputs"]
|
|
|
|
| 207 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 208 |
self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
|
| 209 |
self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
|
| 210 |
+
self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
|
| 211 |
def call(self, inputs, training=False):
|
| 212 |
enc_inputs = inputs["enc_inputs"]
|
| 213 |
dec_inputs = inputs["dec_inputs"]
|