import tensorflow as tf
from tensorflow.keras import layers
class SwiGLU(layers.Layer):
    """SwiGLU feed-forward block: out = W_o((x W_v) * silu(x W_g))."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        # One fused projection produces both the value and gate halves.
        self.proj = layers.Dense(d_ff * 2)
        self.out = layers.Dense(d_model)

    def call(self, x):
        x_proj = self.proj(x)
        # Split the fused projection into value and gate streams.
        x_val, x_gate = tf.split(x_proj, 2, axis=-1)
        # SiLU (a.k.a. swish) gates the value stream.
        return self.out(x_val * tf.nn.silu(x_gate))
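# Quick shape sketch for SwiGLU (illustrative sizes, not from the original):
# the hidden width is d_ff, but the input and output widths are both d_model,
# so the block drops into a residual branch unchanged.
#
#   ff = SwiGLU(d_model=64, d_ff=256)
#   y = ff(tf.zeros((2, 10, 64)))  # y.shape == (2, 10, 64)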
class EncoderBlock(layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super().__init__()
        # key_dim is the per-head size, so split d_model across the heads;
        # key_dim=d_model would inflate the attention weights num_heads-fold.
        self.mha = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads)
        self.ffn = SwiGLU(d_model, dff)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, x, mask=None, training=False):
        # Post-norm residual layout: sublayer -> dropout -> add -> LayerNorm.
        attn_out = self.dropout1(
            self.mha(x, x, x, attention_mask=mask), training=training)
        out1 = self.norm1(x + attn_out)
        ffn_out = self.dropout2(self.ffn(out1), training=training)
        return self.norm2(out1 + ffn_out)
class DecoderBlock(layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout=0.1):
        super().__init__()
        self.self_mha = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads)
        self.cross_mha = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads)
        self.ffn = SwiGLU(d_model, dff)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.norm3 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.dropout3 = layers.Dropout(dropout)

    def call(self, x, enc_out, training=False):
        # Masked self-attention: use_causal_mask keeps each position from
        # attending to later positions (requires TF >= 2.10).
        attn1 = self.dropout1(
            self.self_mha(x, x, x, use_causal_mask=True), training=training)
        out1 = self.norm1(x + attn1)
        # Cross-attention: decoder queries attend over the encoder outputs.
        attn2 = self.dropout2(
            self.cross_mha(out1, enc_out, enc_out), training=training)
        out2 = self.norm2(out1 + attn2)
        ffn_out = self.dropout3(self.ffn(out2), training=training)
        return self.norm3(out2 + ffn_out)
class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff,
                 input_vocab_size, target_vocab_size,
                 max_len=100, dropout=0.1):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model
        self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
        self.enc_pos_embedding = layers.Embedding(max_len, d_model)
        self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
        self.dec_pos_embedding = layers.Embedding(max_len, d_model)
        self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout)
                           for _ in range(num_layers)]
        self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout)
                           for _ in range(num_layers)]
        self.final_layer = layers.Dense(target_vocab_size)

    def call(self, inputs, training=False):
        enc_inputs = inputs["enc_inputs"]
        dec_inputs = inputs["dec_inputs"]
        # Learned positional embeddings, indexed 0..seq_len-1 per batch.
        enc_pos = tf.range(tf.shape(enc_inputs)[1])[tf.newaxis, :]
        dec_pos = tf.range(tf.shape(dec_inputs)[1])[tf.newaxis, :]
        x = self.enc_embedding(enc_inputs) + self.enc_pos_embedding(enc_pos)
        for layer in self.enc_layers:
            x = layer(x, training=training)
        enc_out = x
        y = self.dec_embedding(dec_inputs) + self.dec_pos_embedding(dec_pos)
        for layer in self.dec_layers:
            y = layer(y, enc_out, training=training)
        # Unnormalized logits over the target vocabulary.
        return self.final_layer(y)
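
# Minimal smoke test (a sketch: the hyperparameters and random token ids
# below are illustrative, not from the original listing).
if __name__ == "__main__":
    model = Transformer(num_layers=2, d_model=64, num_heads=4, dff=128,
                        input_vocab_size=1000, target_vocab_size=1000,
                        max_len=100)
    enc = tf.random.uniform((2, 12), maxval=1000, dtype=tf.int32)
    dec = tf.random.uniform((2, 9), maxval=1000, dtype=tf.int32)
    logits = model({"enc_inputs": enc, "dec_inputs": dec}, training=False)
    # One logit vector over the target vocabulary per decoder position.
    print(logits.shape)  # (2, 9, 1000)
    # For training, pair these logits with
    # tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True).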