Yuchan
commited on
Update AlphaS2S.py
Browse files- AlphaS2S.py +13 -4
AlphaS2S.py
CHANGED
|
@@ -243,8 +243,8 @@ class LoU(layers.Layer):
|
|
| 243 |
out = self.glu(out)
|
| 244 |
return tf.cast(out, x.dtype)
|
| 245 |
|
| 246 |
-
class
|
| 247 |
-
def __init__(self, num_layers, d_model, num_heads,
|
| 248 |
super().__init__()
|
| 249 |
self.max_len = max_len
|
| 250 |
self.d_model = d_model
|
|
@@ -252,9 +252,9 @@ class Transformer(tf.keras.Model):
|
|
| 252 |
self.enc_pos_embedding = layers.Embedding(max_len, d_model)
|
| 253 |
self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
|
| 254 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 255 |
-
self.enc_layers = [EncoderBlock(d_model, num_heads,
|
| 256 |
self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
|
| 257 |
-
self.final_layer = layers.Dense(target_vocab_size)
|
| 258 |
def call(self, inputs, training=False):
|
| 259 |
enc_inputs = inputs["enc_inputs"]
|
| 260 |
dec_inputs = inputs["dec_inputs"]
|
|
@@ -266,3 +266,12 @@ class Transformer(tf.keras.Model):
|
|
| 266 |
y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
|
| 267 |
for layer in self.dec_layers: y = layer(y, enc_out, training=training)
|
| 268 |
return self.final_layer(y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
out = self.glu(out)
|
| 244 |
return tf.cast(out, x.dtype)
|
| 245 |
|
| 246 |
+
class AlphaS2S(tf.keras.Model):
|
| 247 |
+
def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
|
| 248 |
super().__init__()
|
| 249 |
self.max_len = max_len
|
| 250 |
self.d_model = d_model
|
|
|
|
| 252 |
self.enc_pos_embedding = layers.Embedding(max_len, d_model)
|
| 253 |
self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
|
| 254 |
self.dec_pos_embedding = layers.Embedding(max_len, d_model)
|
| 255 |
+
self.enc_layers = [EncoderBlock(d_model, num_heads, dropout) for _ in range(num_layers)]
|
| 256 |
self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
|
| 257 |
+
self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
|
| 258 |
def call(self, inputs, training=False):
|
| 259 |
enc_inputs = inputs["enc_inputs"]
|
| 260 |
dec_inputs = inputs["dec_inputs"]
|
|
|
|
| 266 |
y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
|
| 267 |
for layer in self.dec_layers: y = layer(y, enc_out, training=training)
|
| 268 |
return self.final_layer(y)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
|
| 272 |
+
input_vocab_size=chat_vocab_size, target_vocab_size=chat_vocab_size)
|
| 273 |
+
dummy_input = {
|
| 274 |
+
"enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
|
| 275 |
+
"dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
|
| 276 |
+
}
|
| 277 |
+
_ = chat_model(dummy_input)
|