"""Multi-head Transformer decoder with causal masking and cross-attention. Mirrors notebook cell 19. Two changes from the notebook, both behaviour- preserving when defaults match: 1. **Globals are now constructor arguments.** The notebook closes over ``tokenizer.vocabulary_size()`` and ``MAX_LENGTH`` from module scope. We pass them in as ``vocab_size`` and ``max_len`` so the decoder can be instantiated in tests, factories, and notebooks without setting up a global tokenizer first. 2. **Dropout rates and attention head count are configurable** with the notebook values as defaults. This costs nothing today and lets Phase 1b ablations vary them without code changes. """ from __future__ import annotations from captioning.models.embeddings import Embeddings def _build_transformer_decoder_class(): import tensorflow as tf class TransformerDecoderLayer(tf.keras.layers.Layer): """Causal self-attention + cross-attention + FFN block. Args: embed_dim: Token/positional embedding dimension. Must equal the encoder's ``embed_dim``. units: Hidden dimension of the feed-forward sub-block. num_heads: Multi-head attention heads. Notebook uses 8. vocab_size: Output projection dimension (the model emits softmax probabilities over the vocabulary). max_len: Maximum decode length, used to size positional embeddings. attention_dropout: Dropout applied inside MultiHeadAttention. Notebook uses 0.1. inner_dropout: Dropout after the first dense layer in the FFN. Notebook uses 0.3. outer_dropout: Dropout after the residual + final layernorm. Notebook uses 0.5. """ def __init__( self, embed_dim: int, units: int, num_heads: int, vocab_size: int, max_len: int, attention_dropout: float = 0.1, inner_dropout: float = 0.3, outer_dropout: float = 0.5, ) -> None: super().__init__() self.embedding = Embeddings(vocab_size, embed_dim, max_len) self.attention_1 = tf.keras.layers.MultiHeadAttention( num_heads=num_heads, key_dim=embed_dim, dropout=attention_dropout ) self.attention_2 = tf.keras.layers.MultiHeadAttention( num_heads=num_heads, key_dim=embed_dim, dropout=attention_dropout ) self.layernorm_1 = tf.keras.layers.LayerNormalization() self.layernorm_2 = tf.keras.layers.LayerNormalization() self.layernorm_3 = tf.keras.layers.LayerNormalization() self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu") self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim) self.out = tf.keras.layers.Dense(vocab_size, activation="softmax") self.dropout_1 = tf.keras.layers.Dropout(inner_dropout) self.dropout_2 = tf.keras.layers.Dropout(outer_dropout) def call(self, input_ids, encoder_output, training, mask=None): embeddings = self.embedding(input_ids) combined_mask = None padding_mask = None if mask is not None: causal_mask = self.get_causal_attention_mask(embeddings) padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32) combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32) combined_mask = tf.minimum(combined_mask, causal_mask) attn_output_1 = self.attention_1( query=embeddings, value=embeddings, key=embeddings, attention_mask=combined_mask, training=training, ) out_1 = self.layernorm_1(embeddings + attn_output_1) attn_output_2 = self.attention_2( query=out_1, value=encoder_output, key=encoder_output, attention_mask=padding_mask, training=training, ) out_2 = self.layernorm_2(out_1 + attn_output_2) ffn_out = self.ffn_layer_1(out_2) ffn_out = self.dropout_1(ffn_out, training=training) ffn_out = self.ffn_layer_2(ffn_out) ffn_out = self.layernorm_3(ffn_out + out_2) ffn_out = self.dropout_2(ffn_out, training=training) return self.out(ffn_out) def get_causal_attention_mask(self, inputs): input_shape = tf.shape(inputs) batch_size, sequence_length = input_shape[0], input_shape[1] i = tf.range(sequence_length)[:, tf.newaxis] j = tf.range(sequence_length) mask = tf.cast(i >= j, dtype="int32") mask = tf.reshape(mask, (1, input_shape[1], input_shape[1])) mult = tf.concat( [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], axis=0, ) return tf.tile(mask, mult) return TransformerDecoderLayer TransformerDecoderLayer = _build_transformer_decoder_class()