# File size: 2,538 Bytes
# ee03b1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import tensorflow as tf
from .encoder import Encoder
from .decoder import Decoder
from tensorflow.keras.layers import Dense

@tf.keras.utils.register_keras_serializable()
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer mapping (source, target) token ids to scores.

    Wires a project ``Encoder`` and ``Decoder`` together and projects the
    decoder output onto ``target_vocab_size`` scores with a ``Dense`` layer
    (no activation is configured, so the outputs are raw logits).

    The mask helpers use 1.0 for positions that must be hidden and 0.0 for
    positions that may be attended to; token id 0 is treated as padding.
    NOTE(review): how the 1/0 masks are applied lives in Encoder/Decoder,
    which are not visible here — confirm they treat 1.0 as "blocked".
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 target_vocab_size, max_tokens, dropout_rate=0.1, **kwargs):
        """Build the encoder, decoder and output projection.

        Args:
            num_layers: Passed to both the Encoder and the Decoder.
            d_model: Passed to both the Encoder and the Decoder.
            num_heads: Passed to both the Encoder and the Decoder.
            dff: Passed to both the Encoder and the Decoder.
            input_vocab_size: Source vocabulary size, passed to the Encoder.
            target_vocab_size: Target vocabulary size; passed to the Decoder
                and used as the width of the final Dense layer.
            max_tokens: Passed to both the Encoder and the Decoder.
            dropout_rate: Passed to both the Encoder and the Decoder.
            **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
        """
        super().__init__(**kwargs)
        # Hyper-parameters are kept so get_config() can rebuild the model.
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.input_vocab_size = input_vocab_size
        self.target_vocab_size = target_vocab_size
        self.max_tokens = max_tokens
        self.dropout_rate = dropout_rate

        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, max_tokens, dropout_rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                               target_vocab_size, max_tokens, dropout_rate)
        self.final_layer = Dense(target_vocab_size)

    def call(self, inputs, training=None):
        """Run a forward pass.

        Args:
            inputs: A pair ``(enc_input, dec_input)`` of token-id tensors,
                presumably shaped (batch, seq_len) — TODO confirm.
            training: Forwarded to the encoder and the decoder.

        Returns:
            Output of the final Dense layer: ``target_vocab_size`` scores for
            every decoder position.
        """
        enc_input, dec_input = inputs

        # Source padding must be hidden both from the encoder's self-attention
        # and from the decoder's attention over the encoder output, so a single
        # mask serves both (previously it was computed twice).
        enc_padding_mask = self.create_padding_mask(enc_input)

        # Decoder self-attention must hide future positions AND padding tokens
        # of the target sequence. Both masks use 1.0 = hidden, so tf.maximum
        # merges them; (seq, seq) broadcasts with (batch, 1, 1, seq) to
        # (batch, 1, seq, seq).
        look_ahead_mask = self.create_look_ahead_mask(tf.shape(dec_input)[1])
        dec_target_padding_mask = self.create_padding_mask(dec_input)
        combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        enc_output = self.encoder(enc_input, training=training,
                                  mask=enc_padding_mask)
        dec_output = self.decoder(dec_input, enc_output, training=training,
                                  look_ahead_mask=combined_mask,
                                  padding_mask=enc_padding_mask)
        return self.final_layer(dec_output)

    def create_padding_mask(self, seq):
        """Return a float32 mask that is 1.0 where ``seq`` equals 0 (padding).

        Two singleton axes are inserted after the first axis. For a rank-2
        (batch, seq_len) input the result is (batch, 1, 1, seq_len), which
        broadcasts over attention heads and query positions.
        """
        mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
        return mask[:, tf.newaxis, tf.newaxis, :]

    def create_look_ahead_mask(self, size):
        """Return a (size, size) float32 causal mask.

        ``band_part(ones, -1, 0)`` keeps the lower triangle including the
        diagonal. Subtracting it from 1 leaves 1.0 strictly above the
        diagonal, i.e. on the future positions each query must not see.
        """
        mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
        return mask

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the model."""
        config = super().get_config()
        config.update({
            'num_layers': self.num_layers,
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'dff': self.dff,
            'input_vocab_size': self.input_vocab_size,
            'target_vocab_size': self.target_vocab_size,
            'max_tokens': self.max_tokens,
            'dropout_rate': self.dropout_rate
        })
        return config