File size: 3,173 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""``build_caption_model(config, vocab_size)`` — single place to wire layers.

Mirrors notebook cell 21::

    encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
    decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
    cnn_model = CNN_Encoder()
    caption_model = ImageCaptioningModel(
        cnn_model=cnn_model,
        encoder=encoder,
        decoder=decoder,
        image_aug=image_augmentation,
    )

Pulling this into a factory function isolates "how layers are wired" from
"what hyperparameters they use", so Phase 1b ablations and Phase 5 model
swaps only touch this file.
"""

from __future__ import annotations

from captioning.config.schema import AppConfig
from captioning.models.captioning_model import ImageCaptioningModel
from captioning.models.encoder_cnn import build_cnn_encoder
from captioning.models.transformer_decoder import TransformerDecoderLayer
from captioning.models.transformer_encoder import TransformerEncoderLayer
from captioning.preprocessing.augmentation import default_image_augmentation


def build_caption_model(
    config: AppConfig,
    vocab_size: int,
    *,
    use_augmentation: bool = True,
) -> ImageCaptioningModel:
    """Construct a ready-to-compile ``ImageCaptioningModel``.

    Args:
        config: Validated app config. ``config.model`` supplies the layer
            hyperparameters; ``config.train.honour_training_flag_in_test_step``
            toggles the opt-in test-step / masked-accuracy behaviours.
        vocab_size: Comes from the *fitted* tokenizer
            (``CaptionTokenizer.vocabulary_size``). The factory does not own
            tokenizer state — callers fit the tokenizer first, pass the size in.
        use_augmentation: If True (default), wires
            ``default_image_augmentation()`` for ``train_step``. Inference and
            evaluation paths are expected to pass False.

    Returns:
        An uncompiled ``ImageCaptioningModel``. Caller is responsible for
        ``model.compile(optimizer=..., loss=...)``.

    Raises:
        ValueError: If ``vocab_size`` is less than 1 (e.g. the tokenizer was
            never fitted). Checked up front, before any layer is constructed.
    """
    # Fail fast: a non-positive vocabulary size almost certainly means the
    # tokenizer was not fitted, and the decoder (which receives vocab_size,
    # presumably to size its output projection) would otherwise be built
    # with a degenerate shape.
    if vocab_size < 1:
        raise ValueError(
            f"vocab_size must be a positive integer, got {vocab_size!r}; "
            "fit the tokenizer before building the model."
        )

    m = config.model

    encoder = TransformerEncoderLayer(m.embedding_dim, m.encoder_num_heads)
    decoder = TransformerDecoderLayer(
        embed_dim=m.embedding_dim,
        units=m.units,
        num_heads=m.decoder_num_heads,
        vocab_size=vocab_size,
        max_len=m.max_length,
        attention_dropout=m.decoder_attention_dropout,
        inner_dropout=m.decoder_dropout_inner,
        outer_dropout=m.decoder_dropout_outer,
    )
    cnn = build_cnn_encoder()
    aug = default_image_augmentation() if use_augmentation else None

    # A single opt-in YAML flag drives BOTH ``honour_training_flag_in_test_step``
    # and ``correct_masked_accuracy``, which keeps the YAML surface minimal.
    # With the flag off, this factory keeps producing notebook-parity models.
    # The masked-accuracy correction is a better-weighted average of the same
    # per-batch numbers. It can still change the reported value, so it stays
    # behind the same opt-in.
    honour_flag = bool(config.train.honour_training_flag_in_test_step)
    return ImageCaptioningModel(
        cnn_model=cnn,
        encoder=encoder,
        decoder=decoder,
        image_aug=aug,
        honour_training_flag_in_test_step=honour_flag,
        correct_masked_accuracy=honour_flag,
    )