Spaces:
Configuration error
Configuration error
File size: 5,250 Bytes
3a2e5f0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | """Multi-head Transformer decoder with causal masking and cross-attention.
Mirrors notebook cell 19. Two changes from the notebook, both behaviour-
preserving when defaults match:
1. **Globals are now constructor arguments.** The notebook closes over
``tokenizer.vocabulary_size()`` and ``MAX_LENGTH`` from module scope.
We pass them in as ``vocab_size`` and ``max_len`` so the decoder can be
instantiated in tests, factories, and notebooks without setting up a
global tokenizer first.
2. **Dropout rates and attention head count are configurable** with the
notebook values as defaults. This costs nothing today and lets Phase 1b
ablations vary them without code changes.
"""
from __future__ import annotations
from captioning.models.embeddings import Embeddings
def _build_transformer_decoder_class():
    """Build and return the ``TransformerDecoderLayer`` class.

    TensorFlow is imported inside this factory, so the import happens when
    the factory is called rather than when this function is defined.
    """
    import tensorflow as tf

    class TransformerDecoderLayer(tf.keras.layers.Layer):
        """Causal self-attention + cross-attention + FFN block.

        The self-attention is always causal. When a padding ``mask`` is
        supplied (the way the notebook calls the decoder) the result is
        identical to the notebook implementation; when ``mask`` is ``None``
        the causal mask is still applied instead of being dropped.

        Args:
            embed_dim: Token/positional embedding dimension. Must equal the
                encoder's ``embed_dim``.
            units: Hidden dimension of the feed-forward sub-block.
            num_heads: Multi-head attention heads. Notebook uses 8.
            vocab_size: Output projection dimension (the model emits softmax
                probabilities over the vocabulary).
            max_len: Maximum decode length, used to size positional
                embeddings.
            attention_dropout: Dropout applied inside MultiHeadAttention.
                Notebook uses 0.1.
            inner_dropout: Dropout after the first dense layer in the FFN.
                Notebook uses 0.3.
            outer_dropout: Dropout after the residual + final layernorm.
                Notebook uses 0.5.
        """

        def __init__(
            self,
            embed_dim: int,
            units: int,
            num_heads: int,
            vocab_size: int,
            max_len: int,
            attention_dropout: float = 0.1,
            inner_dropout: float = 0.3,
            outer_dropout: float = 0.5,
        ) -> None:
            super().__init__()
            self.embedding = Embeddings(vocab_size, embed_dim, max_len)
            # attention_1 is the self-attention, attention_2 the
            # cross-attention over the encoder output.
            self.attention_1 = tf.keras.layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim, dropout=attention_dropout
            )
            self.attention_2 = tf.keras.layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim, dropout=attention_dropout
            )
            self.layernorm_1 = tf.keras.layers.LayerNormalization()
            self.layernorm_2 = tf.keras.layers.LayerNormalization()
            self.layernorm_3 = tf.keras.layers.LayerNormalization()
            self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
            self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)
            self.out = tf.keras.layers.Dense(vocab_size, activation="softmax")
            self.dropout_1 = tf.keras.layers.Dropout(inner_dropout)
            self.dropout_2 = tf.keras.layers.Dropout(outer_dropout)

        def call(self, input_ids, encoder_output, training=None, mask=None):
            """Run the decoder block and return per-position vocabulary probabilities.

            Args:
                input_ids: Token ids passed to the embedding layer
                    (presumably ``(batch, seq_len)`` -- confirm against
                    ``Embeddings``).
                encoder_output: Encoder features used as key and value of
                    the cross-attention.
                training: Forwarded to the attention and dropout layers.
                mask: Optional rank-2 padding mask over the decoder tokens;
                    non-zero entries mark positions that may be attended to.

            Returns:
                The softmax output of the final ``Dense(vocab_size)`` layer.
            """
            embeddings = self.embedding(input_ids)

            # The causal mask is built unconditionally: previously it was
            # only created when a padding mask was passed, so ``mask=None``
            # silently produced non-causal self-attention.
            self_attention_mask = self.get_causal_attention_mask(embeddings)
            cross_attention_mask = None
            if mask is not None:
                # (batch, seq_len, 1): broadcasts across the encoder axis of
                # the cross-attention scores (kept as in the notebook).
                cross_attention_mask = tf.cast(
                    mask[:, :, tf.newaxis], dtype=tf.int32
                )
                # (batch, 1, seq_len): hides padded key positions; the
                # element-wise minimum intersects it with the causal mask.
                key_padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
                self_attention_mask = tf.minimum(
                    key_padding_mask, self_attention_mask
                )

            attn_output_1 = self.attention_1(
                query=embeddings,
                value=embeddings,
                key=embeddings,
                attention_mask=self_attention_mask,
                training=training,
            )
            out_1 = self.layernorm_1(embeddings + attn_output_1)

            attn_output_2 = self.attention_2(
                query=out_1,
                value=encoder_output,
                key=encoder_output,
                attention_mask=cross_attention_mask,
                training=training,
            )
            out_2 = self.layernorm_2(out_1 + attn_output_2)

            # Feed-forward sub-block with a residual connection around it.
            ffn_out = self.ffn_layer_1(out_2)
            ffn_out = self.dropout_1(ffn_out, training=training)
            ffn_out = self.ffn_layer_2(ffn_out)
            ffn_out = self.layernorm_3(ffn_out + out_2)
            ffn_out = self.dropout_2(ffn_out, training=training)
            return self.out(ffn_out)

        def get_causal_attention_mask(self, inputs):
            """Return an int32 lower-triangular mask tiled over the batch.

            The result has shape ``(batch, seq_len, seq_len)``, where batch
            and seq_len are the first two dimensions of ``inputs``. Entry
            ``[b, i, j]`` is 1 when ``i >= j`` and 0 otherwise.
            """
            input_shape = tf.shape(inputs)
            batch_size, sequence_length = input_shape[0], input_shape[1]
            i = tf.range(sequence_length)[:, tf.newaxis]
            j = tf.range(sequence_length)
            mask = tf.cast(i >= j, dtype="int32")
            mask = tf.reshape(mask, (1, sequence_length, sequence_length))
            # Tile multiples are [batch_size, 1, 1], built as a tensor
            # because the batch size is only known dynamically.
            mult = tf.concat(
                [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
                axis=0,
            )
            return tf.tile(mask, mult)

    return TransformerDecoderLayer
# Build the class once at module level and expose it under its public name.
# NOTE: the factory runs at import time, so importing this module also
# imports tensorflow.
TransformerDecoderLayer = _build_transformer_decoder_class()
|