File size: 3,659 Bytes
cafd528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import tensorflow as tf
from tensorflow.keras import layers, Model

class SwiGLU(layers.Layer):
    """Gated feed-forward block computing ``out(value * silu(gate))``.

    A single Dense layer projects the input to ``2 * d_ff`` features. The
    result is split in half along the last axis into a value half and a gate
    half. The value is multiplied by SiLU(gate) and projected back to
    ``d_model`` features.

    Args:
        d_model: Number of output features of the final projection.
        d_ff: Width of each half (value / gate) of the gated projection.
        **kwargs: Standard ``keras.layers.Layer`` arguments (e.g. ``name``,
            ``dtype``, ``trainable``), forwarded to the base class.
    """

    def __init__(self, d_model, d_ff, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can re-create the layer.
        self.d_model = d_model
        self.d_ff = d_ff
        # One fused projection produces both the value and the gate halves.
        self.proj = layers.Dense(d_ff * 2)
        self.out = layers.Dense(d_model)

    def call(self, x):
        x_proj = self.proj(x)
        # tf.split gives the first half as the value and the second as the gate.
        x_val, x_gate = tf.split(x_proj, 2, axis=-1)
        return self.out(x_val * tf.nn.silu(x_gate))

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "d_ff": self.d_ff})
        return config

class EncoderBlock(layers.Layer):
    """Post-norm encoder block: self-attention, then a SwiGLU feed-forward.

    Each sub-layer's output goes through dropout and is added to the
    sub-layer's input (residual connection). The sum is then layer-normalized.

    Args:
        d_model: Feature size of the block's input and output.
        num_heads: Number of attention heads.
        dff: Hidden width passed to the SwiGLU feed-forward sub-layer.
        dropout: Dropout rate applied to each sub-layer's output.
        **kwargs: Standard ``keras.layers.Layer`` arguments (e.g. ``name``),
            forwarded to the base class.
    """

    def __init__(self, d_model, num_heads, dff, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can re-create the layer.
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.dropout_rate = dropout
        # NOTE(review): Keras' key_dim is the size of *each* head, so this
        # projects to num_heads * d_model features instead of splitting
        # d_model across heads. Left as-is to keep existing weight shapes.
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = SwiGLU(d_model, dff)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, x, mask=None, training=False):
        """Apply the encoder block.

        Args:
            x: Input sequence; presumably (batch, seq, d_model) — confirm.
            mask: Passed unchanged as MultiHeadAttention's ``attention_mask``.
                Because the argument is named ``mask``, Keras may also fill it
                in from an upstream layer's implicit mask.
            training: Enables dropout when True.
        """
        # Self-attention sub-layer with residual connection and post-norm.
        attn_out = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
        out1 = self.norm1(x + attn_out)
        # Feed-forward sub-layer with residual connection and post-norm.
        ffn_out = self.dropout2(self.ffn(out1), training=training)
        return self.norm2(out1 + ffn_out)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "dff": self.dff,
            "dropout": self.dropout_rate,
        })
        return config

class DecoderBlock(layers.Layer):
    """Post-norm decoder block with three sub-layers.

    The sub-layers are causal self-attention, cross-attention over the encoder
    output, and a SwiGLU feed-forward. Each sub-layer's output goes through
    dropout and is added to the sub-layer's input (residual connection). The
    sum is then layer-normalized.

    Args:
        d_model: Feature size of the block's input and output.
        num_heads: Number of attention heads (both attention sub-layers).
        dff: Hidden width passed to the SwiGLU feed-forward sub-layer.
        dropout: Dropout rate applied to each sub-layer's output.
        **kwargs: Standard ``keras.layers.Layer`` arguments (e.g. ``name``),
            forwarded to the base class.
    """

    def __init__(self, d_model, num_heads, dff, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can re-create the layer.
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.dropout_rate = dropout
        # NOTE(review): key_dim is per head in Keras (see EncoderBlock usage);
        # unchanged to preserve weight shapes.
        self.self_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.cross_mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = SwiGLU(d_model, dff)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.norm3 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.dropout3 = layers.Dropout(dropout)

    def call(self, x, enc_out, training=False, cross_attention_mask=None):
        """Apply the decoder block.

        Args:
            x: Decoder sequence; presumably (batch, tgt_len, d_model) — confirm.
            enc_out: Encoder output used as key/value for cross-attention.
            training: Enables dropout when True.
            cross_attention_mask: Optional mask passed as the cross-attention
                ``attention_mask`` (e.g. to hide padded encoder positions).
                Defaults to None, which matches the previous behavior.
        """
        # Causal self-attention: each position attends only to itself and earlier ones.
        attn1 = self.dropout1(self.self_mha(x, x, x, use_causal_mask=True), training=training)
        out1 = self.norm1(x + attn1)
        # Cross-attention: decoder states query the encoder output.
        attn2 = self.dropout2(
            self.cross_mha(out1, enc_out, enc_out, attention_mask=cross_attention_mask),
            training=training,
        )
        out2 = self.norm2(out1 + attn2)
        # Feed-forward sub-layer.
        ffn_out = self.dropout3(self.ffn(out2), training=training)
        return self.norm3(out2 + ffn_out)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "dff": self.dff,
            "dropout": self.dropout_rate,
        })
        return config

class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer with learned positional embeddings.

    Args:
        num_layers: Number of encoder blocks and, separately, decoder blocks.
        d_model: Embedding / hidden feature size.
        num_heads: Attention heads per attention sub-layer.
        dff: Hidden width of each SwiGLU feed-forward sub-layer.
        input_vocab_size: Size of the source token embedding table.
        target_vocab_size: Size of the target token embedding table and the
            number of output units of the final Dense layer.
        max_len: Size of each positional embedding table. Encoder and decoder
            sequences must be at most this long.
        dropout: Dropout rate used inside every block.
        **kwargs: Standard ``keras.Model`` arguments (e.g. ``name``),
            forwarded to the base class.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 target_vocab_size, max_len=100, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.d_model = d_model
        self.enc_embedding = layers.Embedding(input_vocab_size, d_model)
        self.enc_pos_embedding = layers.Embedding(max_len, d_model)
        self.dec_embedding = layers.Embedding(target_vocab_size, d_model)
        self.dec_pos_embedding = layers.Embedding(max_len, d_model)
        self.enc_layers = [EncoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
        self.dec_layers = [DecoderBlock(d_model, num_heads, dff, dropout) for _ in range(num_layers)]
        # Linear (no activation) projection to target-vocabulary scores.
        self.final_layer = layers.Dense(target_vocab_size)

    def _embed(self, tokens, token_embedding, pos_embedding, label):
        """Return token embeddings plus learned positional embeddings.

        Raises InvalidArgumentError if the sequence (axis 1) is longer than
        max_len.
        """
        seq_len = tf.shape(tokens)[1]
        # Positions index an Embedding with max_len rows. Out-of-range ids error
        # opaquely on CPU and silently yield zero vectors on GPU, so check first.
        tf.debugging.assert_less_equal(
            seq_len,
            self.max_len,
            message=f"{label} sequence length exceeds max_len={self.max_len}",
        )
        positions = tf.range(seq_len)[tf.newaxis, :]
        return token_embedding(tokens) + pos_embedding(positions)

    def call(self, inputs, training=False):
        """Run the encoder and decoder and return per-position output scores.

        Args:
            inputs: Dict with keys ``"enc_inputs"`` and ``"dec_inputs"``;
                presumably integer token ids of shape (batch, seq) — confirm.
            training: Enables dropout when True.

        NOTE(review): no padding mask is built here, so padded positions
        are attended to by the encoder and the cross-attention.
        """
        enc_inputs = inputs["enc_inputs"]
        dec_inputs = inputs["dec_inputs"]

        x = self._embed(enc_inputs, self.enc_embedding, self.enc_pos_embedding, "encoder")
        for layer in self.enc_layers:
            x = layer(x, training=training)
        enc_out = x

        y = self._embed(dec_inputs, self.dec_embedding, self.dec_pos_embedding, "decoder")
        for layer in self.dec_layers:
            y = layer(y, enc_out, training=training)
        return self.final_layer(y)