"""``ImageCaptioningModel`` — top-level Keras model with custom train/test step.

Mirrors notebook cell 20 verbatim by default. The model owns its own loss &
accuracy trackers (rather than using compile-time metrics) because the masked
arithmetic in ``calculate_loss`` / ``calculate_accuracy`` depends on the
caption padding mask, which Keras's standard metric API can't see.

Two opt-in fixes are layered on top of the baseline. Both are controlled by
constructor arguments, and the defaults preserve the IEEE notebook's
behaviour exactly:

* ``honour_training_flag_in_test_step``

  The notebook's ``compute_loss_and_acc`` hardcodes ``training=True`` on
  both the encoder and decoder calls, even when invoked from ``test_step``.
  That means dropout is active during validation in the IEEE results.
  Setting this flag to True restores the conventional behaviour — dropout
  off in test_step — so val_loss reflects deployment behaviour and early
  stopping fires on a clean signal.

* ``correct_masked_accuracy``

  The baseline's accuracy tracker is a ``Mean`` of per-batch ratios, which
  weights a batch with few non-padding tokens the same as one with many.
  Setting this flag to True feeds the per-batch token count as
  ``sample_weight`` so the reported metric is a true global token-level
  masked accuracy.

Both knobs are off by default to keep numeric parity with the notebook; the
trainer flips them on automatically when the user opts in via
``train.honour_training_flag_in_test_step``.
"""
from __future__ import annotations
def _build_captioning_model_class():
    """Build and return the ``ImageCaptioningModel`` class.

    TensorFlow is imported inside this factory, so the ``tf`` name is local
    to the class definition rather than a module-level global.
    """
    import tensorflow as tf

    class ImageCaptioningModel(tf.keras.Model):
        """Stitches CNN encoder + Transformer encoder + Transformer decoder.

        Args:
            cnn_model: Image feature extractor. It is called outside the
                gradient tape and its variables are not handed to the
                optimizer, so ``train_step`` never updates it.
            encoder: Encoder applied to the ``cnn_model`` output.
            decoder: Decoder, called as
                ``decoder(y_input, encoder_output, training=..., mask=...)``.
            image_aug: Optional augmentation callable, applied to the images
                in ``train_step`` only (never in ``test_step``).
            honour_training_flag_in_test_step: When True, ``test_step`` runs
                the encoder/decoder with ``training=False``. When False
                (default), ``training=True`` is forced on every call to
                match the notebook.
            correct_masked_accuracy: When True, each batch's accuracy is
                weighted by its non-padding token count, so the epoch value
                is ``sum(correct) / sum(tokens)``. When False (default), it
                is an unweighted mean of per-batch ratios.
        """

        def __init__(
            self,
            cnn_model,
            encoder,
            decoder,
            image_aug=None,
            *,
            honour_training_flag_in_test_step: bool = False,
            correct_masked_accuracy: bool = False,
        ) -> None:
            super().__init__()
            self.cnn_model = cnn_model
            self.encoder = encoder
            self.decoder = decoder
            self.image_aug = image_aug
            self.honour_training_flag_in_test_step = honour_training_flag_in_test_step
            self.correct_masked_accuracy = correct_masked_accuracy
            # Own trackers instead of compile-time metrics: the masked
            # arithmetic below needs the caption padding mask.
            self.loss_tracker = tf.keras.metrics.Mean(name="loss")
            self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")

        # --- masked metrics (notebook cell 20) -----------------------------
        def calculate_loss(self, y_true, y_pred, mask):
            """Return the mean of ``self.loss`` over positions where ``mask`` is true.

            NOTE(review): presumably ``self.loss`` is the loss given to
            ``compile`` and must be unreduced (per position) for the
            masking to be meaningful — confirm against the compile call.

            Returns 0 (not NaN) when ``mask`` selects no positions.
            """
            loss = self.loss(y_true, y_pred)
            mask = tf.cast(mask, dtype=loss.dtype)
            loss *= mask
            # Same as `/` whenever at least one token is unmasked, but an
            # all-padding batch yields 0 instead of NaN loss and NaN grads.
            return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))

        def calculate_accuracy(self, y_true, y_pred, mask):
            """Return the fraction of masked-in positions whose argmax prediction is correct.

            ``axis=2`` assumes ``y_pred`` is (batch, seq, vocab) scores —
            TODO confirm against the decoder output.

            Returns 0 (not NaN) when ``mask`` selects no positions.
            """
            # argmax yields int64; cast so tf.equal also works when the
            # caption ids are int32 (a no-op for int64 captions).
            predicted = tf.cast(tf.argmax(y_pred, axis=2), y_true.dtype)
            correct = tf.math.logical_and(mask, tf.equal(y_true, predicted))
            correct = tf.cast(correct, dtype=tf.float32)
            mask = tf.cast(mask, dtype=tf.float32)
            # A NaN here would survive `sample_weight=0` in the Mean tracker
            # and poison the whole epoch, hence divide_no_nan.
            return tf.math.divide_no_nan(tf.reduce_sum(correct), tf.reduce_sum(mask))

        # --- shared loss/acc step (parity quirk: training=True hardcoded) --
        def compute_loss_and_acc(self, img_embed, captions, training=True):
            """Run encoder + decoder with teacher forcing and score the result.

            ``captions`` is sliced as ``[:, :-1]`` (decoder input) and
            ``[:, 1:]`` (targets), and token id 0 is treated as padding.

            Returns:
                ``(loss, acc, mask_count)``, where ``mask_count`` is the
                float32 number of non-padding target tokens.
            """
            # The IEEE notebook hardcoded `training=True` on encoder/decoder
            # calls even from `test_step`, which means dropout is on during
            # validation. Honouring the flag (opt-in) restores the standard
            # behaviour and gives a cleaner val_loss signal.
            effective_training = bool(training) if self.honour_training_flag_in_test_step else True
            encoder_output = self.encoder(img_embed, training=effective_training)
            y_input = captions[:, :-1]
            y_true = captions[:, 1:]
            mask = y_true != 0
            y_pred = self.decoder(y_input, encoder_output, training=effective_training, mask=mask)
            loss = self.calculate_loss(y_true, y_pred, mask)
            acc = self.calculate_accuracy(y_true, y_pred, mask)
            mask_count = tf.reduce_sum(tf.cast(mask, tf.float32))
            return loss, acc, mask_count

        def _update_trackers(self, loss, acc, mask_count):
            """Fold one batch into the epoch trackers and return the Keras logs dict."""
            self.loss_tracker.update_state(loss)
            if self.correct_masked_accuracy:
                # Weight per-batch accuracy by token count so the epoch
                # average is a true global accuracy, not a mean of ratios.
                self.acc_tracker.update_state(acc, sample_weight=mask_count)
            else:
                self.acc_tracker.update_state(acc)
            return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

        # --- Keras hooks ---------------------------------------------------
        def train_step(self, batch):
            """Augment, embed, and take one optimizer step on encoder + decoder."""
            imgs, captions = batch
            # Explicit None check: layer/model objects should not be tested
            # for truthiness.
            if self.image_aug is not None:
                imgs = self.image_aug(imgs)
            # Outside the tape: the CNN is not trained here.
            img_embed = self.cnn_model(imgs)
            with tf.GradientTape() as tape:
                loss, acc, mask_count = self.compute_loss_and_acc(img_embed, captions)
            train_vars = self.encoder.trainable_variables + self.decoder.trainable_variables
            grads = tape.gradient(loss, train_vars)
            self.optimizer.apply_gradients(zip(grads, train_vars, strict=False))
            return self._update_trackers(loss, acc, mask_count)

        def test_step(self, batch):
            """Evaluate one batch without augmentation or weight updates."""
            imgs, captions = batch
            img_embed = self.cnn_model(imgs)
            loss, acc, mask_count = self.compute_loss_and_acc(img_embed, captions, training=False)
            return self._update_trackers(loss, acc, mask_count)

        @property
        def metrics(self):
            # Listing the trackers lets Keras reset them between epochs.
            return [self.loss_tracker, self.acc_tracker]

    return ImageCaptioningModel
# Build the class once, at import time.
# NOTE(review): because the factory is called here, importing this module
# still imports TensorFlow; the factory only keeps the ``tf`` name local to
# the class definition rather than deferring the import.
ImageCaptioningModel = _build_captioning_model_class()