Yuchan committed on
Commit
b826d8a
·
verified ·
1 Parent(s): f82693c

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +13 -4
AlphaS2S.py CHANGED
@@ -243,8 +243,8 @@ class LoU(layers.Layer):
243
  out = self.glu(out)
244
  return tf.cast(out, x.dtype)
245
 
246
- class Transformer(tf.keras.Model):
247
- def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
248
  super().__init__()
249
  self.max_len = max_len
250
  self.d_model = d_model
@@ -252,9 +252,9 @@ class Transformer(tf.keras.Model):
252
  self.enc_pos_embedding = layers.Embedding(max_len, d_model)
253
  self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
254
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
255
- self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
256
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
257
- self.final_layer = layers.Dense(target_vocab_size)
258
  def call(self, inputs, training=False):
259
  enc_inputs = inputs["enc_inputs"]
260
  dec_inputs = inputs["dec_inputs"]
@@ -266,3 +266,12 @@ class Transformer(tf.keras.Model):
266
  y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
267
  for layer in self.dec_layers: y = layer(y, enc_out, training=training)
268
  return self.final_layer(y)
 
 
 
 
 
 
 
 
 
 
243
  out = self.glu(out)
244
  return tf.cast(out, x.dtype)
245
 
246
+ class AlphaS2S(tf.keras.Model):
247
+ def __init__(self, num_layers, d_model, num_heads, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
248
  super().__init__()
249
  self.max_len = max_len
250
  self.d_model = d_model
 
252
  self.enc_pos_embedding = layers.Embedding(max_len, d_model)
253
  self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
254
  self.dec_pos_embedding = layers.Embedding(max_len, d_model)
255
+ self.enc_layers = [EncoderBlock(d_model, num_heads, dropout) for _ in range(num_layers)]
256
  self.dec_layers = [LoU(d_model) for _ in range(num_layers)]
257
+ self.final_layer = layers.Dense(target_vocab_size, use_bias=False)
258
  def call(self, inputs, training=False):
259
  enc_inputs = inputs["enc_inputs"]
260
  dec_inputs = inputs["dec_inputs"]
 
266
  y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
267
  for layer in self.dec_layers: y = layer(y, enc_out, training=training)
268
  return self.final_layer(y)
269
+
270
+
271
+ chat_model = AlphaS2S(num_layers=4, d_model=160, num_heads=8,
272
+ input_vocab_size=chat_vocab_size, target_vocab_size=chat_vocab_size)
273
+ dummy_input = {
274
+ "enc_inputs": tf.zeros((1, max_len), dtype=tf.int32),
275
+ "dec_inputs": tf.zeros((1, max_len), dtype=tf.int32)
276
+ }
277
+ _ = chat_model(dummy_input)