Yuchan committed on
Commit
cafd528
·
verified ·
1 Parent(s): c80c688

Create AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +71 -0
AlphaS2S.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import layers, Model
3
+
4
class SwiGLU(layers.Layer):
    """Gated feed-forward layer computing ``out(value * silu(gate))``.

    A single dense projection produces ``2 * d_ff`` features, which are split
    in half along the last axis: the first half is used as the value and the
    second half as the gate.

    Args:
        d_model: Number of units of the final output projection.
        d_ff: Width of each of the value and gate halves.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # One fused projection yields both halves; they are separated in call().
        fused_units = d_ff * 2
        self.proj = layers.Dense(fused_units)
        self.out = layers.Dense(d_model)

    def call(self, x):
        """Apply the gated feed-forward transform to ``x``."""
        value, gate = tf.split(self.proj(x), 2, axis=-1)
        gated = value * tf.nn.silu(gate)
        return self.out(gated)
13
+
14
class EncoderBlock(layers.Layer):
    """Encoder block: self-attention followed by a SwiGLU feed-forward layer.

    Each sub-layer output goes through dropout, is added to the sub-layer
    input, and the sum is layer-normalized (normalization after the residual).

    Args:
        d_model: Feature size passed to the attention layer and the SwiGLU.
        num_heads: Number of attention heads.
        dff: Hidden width handed to the SwiGLU feed-forward layer.
        dropout: Dropout rate applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super().__init__()
        # NOTE(review): Keras documents key_dim as the per-head size, so every
        # head is d_model wide here (not d_model // num_heads) -- confirm this
        # is intended before changing it, since it affects weight shapes.
        self.mha = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=d_model,
        )
        self.ffn = SwiGLU(d_model, dff)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, x, mask=None, training=False):
        """Run the block on ``x``.

        Args:
            x: Input tensor; used as query, value and key of the attention.
            mask: Optional tensor forwarded as ``attention_mask``.
            training: Forwarded to the dropout layers.

        Returns:
            The layer-normalized output of the feed-forward sub-layer.
        """
        # Self-attention sub-layer with residual connection.
        attended = self.mha(x, x, x, attention_mask=mask)
        attended = self.dropout1(attended, training=training)
        residual = self.norm1(x + attended)

        # Feed-forward sub-layer with residual connection.
        transformed = self.ffn(residual)
        transformed = self.dropout2(transformed, training=training)
        return self.norm2(residual + transformed)
28
+
29
class DecoderBlock(layers.Layer):
    """Decoder block: causal self-attention, cross-attention, then SwiGLU.

    Each of the three sub-layer outputs goes through dropout, is added to the
    sub-layer input, and the sum is layer-normalized.

    Args:
        d_model: Feature size passed to the attention layers and the SwiGLU.
        num_heads: Number of heads for both attention layers.
        dff: Hidden width handed to the SwiGLU feed-forward layer.
        dropout: Dropout rate applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super().__init__()
        # NOTE(review): Keras documents key_dim as the per-head size, so every
        # head is d_model wide here (not d_model // num_heads) -- confirm this
        # is intended before changing it, since it affects weight shapes.
        self.self_mha = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=d_model,
        )
        self.cross_mha = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=d_model,
        )
        self.ffn = SwiGLU(d_model, dff)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.norm3 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.dropout3 = layers.Dropout(dropout)

    def call(self, x, enc_out, training=False):
        """Run the block on decoder states ``x`` and encoder output ``enc_out``.

        Args:
            x: Decoder input tensor for this block.
            enc_out: Encoder output; used as value and key of cross-attention.
            training: Forwarded to the dropout layers.

        Returns:
            The layer-normalized output of the feed-forward sub-layer.
        """
        # Self-attention restricted by a causal mask.
        self_attended = self.self_mha(x, x, x, use_causal_mask=True)
        self_attended = self.dropout1(self_attended, training=training)
        after_self = self.norm1(x + self_attended)

        # Cross-attention over the encoder output; no mask is passed here.
        cross_attended = self.cross_mha(after_self, enc_out, enc_out)
        cross_attended = self.dropout2(cross_attended, training=training)
        after_cross = self.norm2(after_self + cross_attended)

        # Feed-forward sub-layer.
        transformed = self.ffn(after_cross)
        transformed = self.dropout3(transformed, training=training)
        return self.norm3(after_cross + transformed)
48
+
49
class Transformer(tf.keras.Model):
    """Encoder-decoder model with learned positional embeddings.

    Token and position embeddings are summed, passed through ``num_layers``
    encoder blocks and ``num_layers`` decoder blocks, and projected to
    ``target_vocab_size`` outputs per decoder position.

    Args:
        num_layers: Number of encoder blocks and of decoder blocks.
        d_model: Embedding size and block feature size.
        num_heads: Number of attention heads in every block.
        dff: Hidden width handed to each block's feed-forward layer.
        input_vocab_size: Size of the encoder token embedding table.
        target_vocab_size: Size of the decoder token embedding table and of
            the final projection.
        max_len: Size of both positional embedding tables, i.e. the longest
            sequence the model can embed.
        dropout: Dropout rate forwarded to every block.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, max_len=100, dropout=0.1):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model
        self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
        self.enc_pos_embedding = layers.Embedding(max_len, d_model)
        self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
        self.dec_pos_embedding = layers.Embedding(max_len, d_model)
        self.enc_layers = [
            EncoderBlock(d_model, num_heads, dff, dropout)
            for _ in range(num_layers)
        ]
        self.dec_layers = [
            DecoderBlock(d_model, num_heads, dff, dropout)
            for _ in range(num_layers)
        ]
        self.final_layer = layers.Dense(target_vocab_size)

    def _embed(self, token_ids, token_embedding, pos_embedding):
        """Return token + position embeddings, checking the length fits."""
        seq_len = tf.shape(token_ids)[1]
        # Positions >= max_len fall outside the positional table. The lookup
        # fails on CPU but silently yields zeros on GPU, so fail loudly on
        # every device instead of degrading quietly.
        tf.debugging.assert_less_equal(
            seq_len,
            self.max_len,
            message="Sequence length exceeds max_len of the positional embedding.",
        )
        positions = tf.range(seq_len)[tf.newaxis, :]
        return token_embedding(token_ids) + pos_embedding(positions)

    def call(self, inputs, training=False):
        """Compute the final projection for every decoder position.

        Args:
            inputs: Mapping with keys ``"enc_inputs"`` and ``"dec_inputs"``.
                Both are indexed on axis 1 for the sequence length and fed to
                the embedding layers (assumes integer id tensors of shape
                (batch, seq) -- TODO confirm against the data pipeline).
            training: Forwarded to every encoder and decoder block.

        Returns:
            Output of the final dense layer applied to the decoder states.

        Raises:
            tf.errors.InvalidArgumentError: If either sequence is longer than
                ``max_len``.
        """
        enc_inputs = inputs["enc_inputs"]
        dec_inputs = inputs["dec_inputs"]

        # NOTE(review): no attention mask is passed to the encoder blocks, so
        # padded positions (if any) are attended to -- confirm this is intended.
        x = self._embed(enc_inputs, self.enc_embedding, self.enc_pos_embedding)
        for layer in self.enc_layers:
            x = layer(x, training=training)
        enc_out = x

        y = self._embed(dec_inputs, self.dec_embedding, self.dec_pos_embedding)
        for layer in self.dec_layers:
            y = layer(y, enc_out, training=training)
        return self.final_layer(y)