File size: 5,250 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Multi-head Transformer decoder with causal masking and cross-attention.

Mirrors notebook cell 19. Two changes from the notebook, both behaviour-
preserving when defaults match:

1. **Globals are now constructor arguments.** The notebook closes over
   ``tokenizer.vocabulary_size()`` and ``MAX_LENGTH`` from module scope.
   We pass them in as ``vocab_size`` and ``max_len`` so the decoder can be
   instantiated in tests, factories, and notebooks without setting up a
   global tokenizer first.
2. **Dropout rates and attention head count are configurable** with the
   notebook values as defaults. This costs nothing today and lets Phase 1b
   ablations vary them without code changes.
"""

from __future__ import annotations

from captioning.models.embeddings import Embeddings


def _build_transformer_decoder_class():
    """Build and return the ``TransformerDecoderLayer`` Keras layer class.

    TensorFlow is imported inside this factory, so the ``tf`` name is bound
    only in the factory's scope rather than in the module namespace.

    Returns:
        The ``TransformerDecoderLayer`` class itself (not an instance).
    """
    import tensorflow as tf

    class TransformerDecoderLayer(tf.keras.layers.Layer):
        """Causal self-attention + cross-attention + FFN block.

        Self-attention is always causally masked. A padding mask, when
        supplied to ``call``, is combined with the causal mask and is also
        applied to the cross-attention over the encoder output.

        Args:
            embed_dim: Token/positional embedding dimension. Must equal the
                encoder's ``embed_dim``.
            units: Hidden dimension of the feed-forward sub-block.
            num_heads: Multi-head attention heads. Notebook uses 8.
            vocab_size: Output projection dimension (the model emits softmax
                probabilities over the vocabulary).
            max_len: Maximum decode length, used to size positional embeddings.
            attention_dropout: Dropout applied inside MultiHeadAttention.
                Notebook uses 0.1.
            inner_dropout: Dropout after the first dense layer in the FFN.
                Notebook uses 0.3.
            outer_dropout: Dropout after the residual + final layernorm.
                Notebook uses 0.5.
        """

        def __init__(
            self,
            embed_dim: int,
            units: int,
            num_heads: int,
            vocab_size: int,
            max_len: int,
            attention_dropout: float = 0.1,
            inner_dropout: float = 0.3,
            outer_dropout: float = 0.5,
        ) -> None:
            super().__init__()
            self.embedding = Embeddings(vocab_size, embed_dim, max_len)

            # attention_1 is the masked self-attention over the token
            # embeddings; attention_2 is the cross-attention whose keys and
            # values come from the encoder output.
            self.attention_1 = tf.keras.layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim, dropout=attention_dropout
            )
            self.attention_2 = tf.keras.layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim, dropout=attention_dropout
            )

            # One post-residual layernorm per sub-block.
            self.layernorm_1 = tf.keras.layers.LayerNormalization()
            self.layernorm_2 = tf.keras.layers.LayerNormalization()
            self.layernorm_3 = tf.keras.layers.LayerNormalization()

            # Feed-forward sub-block: expand to ``units``, project back to
            # ``embed_dim`` so the residual addition is shape-compatible.
            self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
            self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)

            # Final projection to a softmax distribution over the vocabulary.
            self.out = tf.keras.layers.Dense(vocab_size, activation="softmax")

            self.dropout_1 = tf.keras.layers.Dropout(inner_dropout)
            self.dropout_2 = tf.keras.layers.Dropout(outer_dropout)

        def call(self, input_ids, encoder_output, training, mask=None):
            """Run one decoder pass and return per-position vocabulary probabilities.

            Args:
                input_ids: Token ids fed to the ``Embeddings`` layer.
                encoder_output: Encoder features used as key and value of the
                    cross-attention.
                training: Forwarded to both attention layers and both dropout
                    layers.
                mask: Optional rank-2 padding mask (indexed as
                    ``mask[:, :, newaxis]`` / ``mask[:, newaxis, :]``), cast to
                    int32 before use. Presumably shaped ``(batch, seq_len)``
                    with non-zero marking real tokens -- confirm against the
                    caller.

            Returns:
                Output of the final softmax ``Dense`` layer, i.e. a probability
                distribution over ``vocab_size`` for every decoded position.
            """
            embeddings = self.embedding(input_ids)

            # The causal mask is built unconditionally: omitting the padding
            # mask must never let a position attend to later tokens.
            combined_mask = self.get_causal_attention_mask(embeddings)
            padding_mask = None

            if mask is not None:
                # Column form: broadcast over the encoder axis in the
                # cross-attention call below.
                padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
                # Row form: broadcast over queries, then intersect with the
                # causal mask (both are 0/1, so ``minimum`` acts as AND).
                key_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
                combined_mask = tf.minimum(key_mask, combined_mask)

            # Sub-block 1: masked self-attention + residual + layernorm.
            attn_output_1 = self.attention_1(
                query=embeddings,
                value=embeddings,
                key=embeddings,
                attention_mask=combined_mask,
                training=training,
            )
            out_1 = self.layernorm_1(embeddings + attn_output_1)

            # Sub-block 2: cross-attention over the encoder output.
            attn_output_2 = self.attention_2(
                query=out_1,
                value=encoder_output,
                key=encoder_output,
                attention_mask=padding_mask,
                training=training,
            )
            out_2 = self.layernorm_2(out_1 + attn_output_2)

            # Sub-block 3: feed-forward network with inner dropout.
            ffn_out = self.ffn_layer_1(out_2)
            ffn_out = self.dropout_1(ffn_out, training=training)
            ffn_out = self.ffn_layer_2(ffn_out)

            ffn_out = self.layernorm_3(ffn_out + out_2)
            ffn_out = self.dropout_2(ffn_out, training=training)
            return self.out(ffn_out)

        def get_causal_attention_mask(self, inputs):
            """Return a lower-triangular int32 mask tiled across the batch.

            Args:
                inputs: Tensor whose first two dimensions are read as batch
                    size and sequence length.

            Returns:
                An int32 tensor of shape ``(batch, seq_len, seq_len)`` where
                entry ``[b, i, j]`` is 1 when ``i >= j`` and 0 otherwise.
            """
            input_shape = tf.shape(inputs)
            batch_size, sequence_length = input_shape[0], input_shape[1]
            # Row index i (as a column vector) against column index j: the
            # comparison broadcasts to a (seq_len, seq_len) triangle.
            i = tf.range(sequence_length)[:, tf.newaxis]
            j = tf.range(sequence_length)
            mask = tf.cast(i >= j, dtype="int32")
            mask = tf.reshape(mask, (1, sequence_length, sequence_length))
            # Tile multiples [batch_size, 1, 1]: repeat along the batch axis only.
            mult = tf.concat(
                [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
                axis=0,
            )
            return tf.tile(mask, mult)

    return TransformerDecoderLayer


# Build the layer class once at import time and expose it under its public
# name. Calling the factory here runs its ``import tensorflow`` statement, so
# importing this module imports TensorFlow.
TransformerDecoderLayer = _build_transformer_decoder_class()