File size: 6,053 Bytes
3a2e5f0
 
91a1214
 
3a2e5f0
 
 
91a1214
 
 
 
3a2e5f0
 
 
91a1214
 
 
 
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
 
 
 
 
3a2e5f0
 
 
 
 
91a1214
 
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a1214
 
 
 
 
 
3a2e5f0
 
 
91a1214
3a2e5f0
 
91a1214
 
3a2e5f0
 
 
 
 
 
 
 
 
 
91a1214
3a2e5f0
 
 
 
 
91a1214
 
 
 
 
 
3a2e5f0
 
 
 
 
 
91a1214
3a2e5f0
91a1214
 
 
 
3a2e5f0
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""``ImageCaptioningModel`` — top-level Keras model with custom train/test step.

Mirrors notebook cell 20 verbatim by default. The model owns its own loss &
accuracy trackers (rather than using compile-time metrics) because the masked
arithmetic in ``calculate_loss`` / ``calculate_accuracy`` depends on the
caption padding mask, which Keras's standard metric API can't see.

Two opt-in fixes layered on top of the baseline, both controlled by the
constructor (defaults preserve the IEEE notebook quirk exactly):

* ``honour_training_flag_in_test_step``
    The notebook's ``compute_loss_and_acc`` hardcodes ``training=True`` on
    both the encoder and decoder calls, even when invoked from ``test_step``.
    That means dropout is active during validation in the IEEE results.
    Setting this flag to True restores the conventional behaviour — dropout
    off in test_step — so val_loss reflects deployment behaviour and early
    stopping fires on a clean signal.

* ``correct_masked_accuracy``
    The baseline's accuracy tracker is a ``Mean`` of per-batch ratios, which
    weights small batches the same as large ones. Setting this flag to True
    feeds the per-batch token count as ``sample_weight`` so the reported
    metric is a true global token-level masked accuracy.

Both knobs are off by default to keep numeric parity with the notebook; the
trainer flips them on automatically when the user opts in via
``train.honour_training_flag_in_test_step``.
"""

from __future__ import annotations


def _build_captioning_model_class():
    """Define and return the ``ImageCaptioningModel`` class.

    TensorFlow is imported inside this function, so the class body is only
    evaluated (and TensorFlow only loaded) when the factory is called.

    Returns:
        The ``ImageCaptioningModel`` class, a ``tf.keras.Model`` subclass.
    """
    import tensorflow as tf

    class ImageCaptioningModel(tf.keras.Model):
        """Stitches CNN encoder + Transformer encoder + Transformer decoder.

        The model keeps its own ``Mean`` trackers for loss and accuracy and
        exposes them through ``metrics``; ``train_step`` / ``test_step``
        update them from the masked per-batch values.

        Args:
            cnn_model: Callable applied to the image batch to produce the
                embeddings fed to ``encoder``. Its variables are not included
                in the set updated by ``train_step``.
            encoder: Called as ``encoder(img_embed, training=...)``.
            decoder: Called as
                ``decoder(y_input, encoder_output, training=..., mask=...)``.
            image_aug: Optional callable applied to the images in
                ``train_step`` only. Any falsy value (default ``None``)
                disables augmentation.
            honour_training_flag_in_test_step: When False (default) the
                encoder and decoder are always called with ``training=True``,
                including from ``test_step``. When True the ``training``
                argument of ``compute_loss_and_acc`` is respected.
            correct_masked_accuracy: When True the per-batch accuracy is fed
                to the tracker with the batch's unmasked token count as
                ``sample_weight``; when False (default) every batch counts
                equally.
        """

        def __init__(
            self,
            cnn_model,
            encoder,
            decoder,
            image_aug=None,
            *,
            honour_training_flag_in_test_step: bool = False,
            correct_masked_accuracy: bool = False,
        ) -> None:
            super().__init__()
            self.cnn_model = cnn_model
            self.encoder = encoder
            self.decoder = decoder
            self.image_aug = image_aug
            self.honour_training_flag_in_test_step = honour_training_flag_in_test_step
            self.correct_masked_accuracy = correct_masked_accuracy
            self.loss_tracker = tf.keras.metrics.Mean(name="loss")
            self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")

        # --- masked metrics (notebook cell 20) -----------------------------

        def calculate_loss(self, y_true, y_pred, mask):
            """Return ``self.loss`` averaged over the positions kept by ``mask``.

            Assumes ``self.loss`` returns one value per position (i.e. it was
            configured without reduction) -- TODO confirm against the compile
            call.

            An all-False mask yields 0 instead of NaN, so a batch made only of
            padding cannot poison the gradients or the running metric. For any
            non-empty mask the result is identical to a plain division.
            """
            loss = self.loss(y_true, y_pred)
            mask = tf.cast(mask, dtype=loss.dtype)
            loss *= mask
            return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))

        def calculate_accuracy(self, y_true, y_pred, mask):
            """Return the fraction of unmasked positions predicted correctly.

            The prediction at each position is the argmax of ``y_pred`` over
            axis 2. An all-False mask yields 0 instead of NaN; for any
            non-empty mask the result is identical to a plain division.
            """
            accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
            accuracy = tf.math.logical_and(mask, accuracy)
            accuracy = tf.cast(accuracy, dtype=tf.float32)
            mask = tf.cast(mask, dtype=tf.float32)
            return tf.math.divide_no_nan(tf.reduce_sum(accuracy), tf.reduce_sum(mask))

        # --- shared loss/acc step (parity quirk: training=True hardcoded) --

        def compute_loss_and_acc(self, img_embed, captions, training=True):
            """Run encoder + decoder and return ``(loss, acc, mask_count)``.

            ``captions[:, :-1]`` is the decoder input and ``captions[:, 1:]``
            the target; target positions equal to 0 are masked out of both
            metrics. ``mask_count`` is the number of unmasked target
            positions as a float32 scalar.
            """
            # The IEEE notebook hardcoded `training=True` on encoder/decoder
            # calls even from `test_step`, which means dropout is on during
            # validation. Honouring the flag (opt-in) restores the standard
            # behaviour and gives a cleaner val_loss signal.
            effective_training = bool(training) if self.honour_training_flag_in_test_step else True
            encoder_output = self.encoder(img_embed, training=effective_training)
            y_input = captions[:, :-1]
            y_true = captions[:, 1:]
            mask = y_true != 0
            y_pred = self.decoder(y_input, encoder_output, training=effective_training, mask=mask)
            loss = self.calculate_loss(y_true, y_pred, mask)
            acc = self.calculate_accuracy(y_true, y_pred, mask)
            mask_count = tf.reduce_sum(tf.cast(mask, tf.float32))
            return loss, acc, mask_count

        def _track_batch_metrics(self, loss, acc, mask_count):
            """Fold one batch into both trackers and return the running logs.

            Shared by ``train_step`` and ``test_step`` so the two paths cannot
            drift apart.
            """
            self.loss_tracker.update_state(loss)
            if self.correct_masked_accuracy:
                # Weight per-batch accuracy by token count so the epoch
                # average is a true global accuracy, not a mean of ratios.
                self.acc_tracker.update_state(acc, sample_weight=mask_count)
            else:
                self.acc_tracker.update_state(acc)
            return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

        # --- Keras hooks ---------------------------------------------------

        def train_step(self, batch):
            """Run one optimisation step on ``(imgs, captions)`` and return logs.

            The CNN forward pass happens outside the gradient tape and only
            the encoder's and decoder's trainable variables are updated.
            """
            imgs, captions = batch
            if self.image_aug:
                imgs = self.image_aug(imgs)
            img_embed = self.cnn_model(imgs)

            with tf.GradientTape() as tape:
                loss, acc, mask_count = self.compute_loss_and_acc(img_embed, captions)

            train_vars = self.encoder.trainable_variables + self.decoder.trainable_variables
            grads = tape.gradient(loss, train_vars)
            self.optimizer.apply_gradients(zip(grads, train_vars, strict=False))
            return self._track_batch_metrics(loss, acc, mask_count)

        def test_step(self, batch):
            """Evaluate one ``(imgs, captions)`` batch and return logs.

            No augmentation is applied. ``training=False`` is requested, but
            it only takes effect when ``honour_training_flag_in_test_step``
            is set.
            """
            imgs, captions = batch
            img_embed = self.cnn_model(imgs)
            loss, acc, mask_count = self.compute_loss_and_acc(img_embed, captions, training=False)
            return self._track_batch_metrics(loss, acc, mask_count)

        @property
        def metrics(self):
            """Return the model-owned trackers: loss first, then accuracy."""
            return [self.loss_tracker, self.acc_tracker]

    return ImageCaptioningModel


# Build the class once at import time and publish it under its public name.
# NOTE: this call runs the factory's local ``import tensorflow``, so importing
# this module requires TensorFlow to be importable.
ImageCaptioningModel = _build_captioning_model_class()