File size: 4,392 Bytes
09c7007
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization, Embedding, Dense


class TransformerEncoder(layers.Layer):
    """One Transformer encoder block: self-attention, then a feed-forward projection.

    Each of the two steps is wrapped in a residual connection followed by a
    LayerNormalization. The feed-forward projection maps back to ``embed_dim``,
    so the block's output can be added to its input and several encoders can be
    stacked on top of one another.

    Args:
        embed_dim: Size of each input token vector. It is also the output size
            of the projection and the per-head ``key_dim`` of the attention layer.
        dense_dim: Number of units in the hidden (ReLU) layer of the projection.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Hyperparameters are kept on the instance so get_config can serialize them.
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads

        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        feed_forward = [
            layers.Dense(dense_dim, activation="relu"),
            # Back to embed_dim so the residual addition in call() lines up.
            layers.Dense(embed_dim),
        ]
        self.dense_proj = keras.Sequential(feed_forward)
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs, mask=None):
        """Apply self-attention and the feed-forward projection to ``inputs``.

        Args:
            inputs: Token vectors, presumably (batch, seq_len, embed_dim).
                TODO: confirm against callers.
            mask: Optional padding mask, presumably (batch, seq_len). An axis is
                inserted at position 1 before it is used as the attention mask.

        Returns:
            A tensor shaped like ``inputs``; the residual additions require this.
        """
        # Insert a query axis so every query position shares the same key mask.
        attention_mask = None if mask is None else mask[:, tf.newaxis, :]

        # Self-attention: query, key and value are all the same tensor.
        attended = self.attention(
            query=inputs,
            key=inputs,
            value=inputs,
            attention_mask=attention_mask,
        )
        # Residual connection + normalization around the attention step.
        normed = self.layernorm_1(inputs + attended)
        # Residual connection + normalization around the feed-forward step.
        projected = self.dense_proj(normed)
        return self.layernorm_2(normed + projected)

    def get_config(self):
        """Return the base layer config extended with the constructor hyperparameters."""
        return {
            **super().get_config(),
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "dense_dim": self.dense_dim,
        }


# Using positional encoding to re-inject order information

class PositionalEmbedding(layers.Layer):
    """Token embedding plus a learned position embedding.

    Each token id is looked up in a vocabulary-sized ``Embedding`` table. Each
    position index ``0 .. length-1`` is looked up in a ``sequence_length``-sized
    table. The two results are added, which re-injects order information for the
    downstream attention layers. The layer also generates a padding mask in
    which token id 0 is treated as padding.

    Args:
        sequence_length: Number of rows in the position table, i.e. the maximum
            sequence length this layer can embed.
        input_dim: Token vocabulary size.
        output_dim: Embedding size used for both tables.
    """

    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        # One row per token id in the vocabulary.
        self.token_embeddings = layers.Embedding(input_dim=input_dim, output_dim=output_dim)
        # One row per position: this table's "vocabulary" is the set of
        # possible positions 0 .. sequence_length-1.
        self.position_embeddings = layers.Embedding(input_dim=sequence_length, output_dim=output_dim)
        # Stored so get_config can rebuild the layer.
        self.sequence_length = sequence_length
        self.input_dim = input_dim
        self.output_dim = output_dim

    def call(self, inputs):
        """Return token embeddings plus position embeddings.

        Args:
            inputs: Integer token ids, expected as (batch, seq_len).
                NOTE(review): seq_len must not exceed ``sequence_length``.
                Out-of-range positions are not checked here. For such indices,
                the embedding lookup may raise or silently return zeros,
                depending on the device.

        Returns:
            Tensor of shape (batch, seq_len, output_dim).
        """
        length = tf.shape(inputs)[-1]  # sequence length of this batch
        positions = tf.range(start=0, limit=length, delta=1)  # 0 .. length-1
        embedded_tokens = tf.reshape(self.token_embeddings(inputs), (-1, length, self.output_dim))
        # (length, output_dim) -> (1, length, output_dim) so it broadcasts over the batch.
        embedded_positions = tf.reshape(self.position_embeddings(positions), (-1, length, self.output_dim))
        # Plain tensor addition rather than building a new layers.Add() on every call.
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        """Return a boolean padding mask: True where the token id is non-zero.

        This layer is where the mask originates. It must therefore be produced
        even when no upstream mask is passed in, which is the usual case because
        the inputs are raw token ids. Keras propagates the returned mask to
        downstream layers that accept a ``mask`` argument.
        """
        return tf.math.not_equal(inputs, 0)

    # When using custom layers, this enables the layer to be reinstantiated from its config dict,
    # which is useful during model saving and loading.
    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            "output_dim": self.output_dim,
            "sequence_length": self.sequence_length,
            "input_dim": self.input_dim,
        })
        return config