apoorvrajdev's picture
feat: finalize Phase 1 modular ML architecture
3a2e5f0
"""Greedy caption generation.
Mirrors notebook cell 25's ``generate_caption`` exactly. The notebook closes
over four globals (``caption_model``, ``tokenizer``, ``idx2word``,
``MAX_LENGTH``); we accept them as explicit arguments so the function is
callable from tests, scripts, FastAPI, and the parity audit.
The algorithm:
1. CNN-encode the image.
2. Transformer-encode the patch features.
3. Seed the caption with ``[start]``.
4. For each position 0 ... ``max_length - 2``:
a. Tokenise the partial caption (``[:, :-1]`` because TextVectorization
pads to ``max_length`` and we feed ``max_length - 1`` positions
into the decoder).
b. Decode and take the argmax at the current position.
c. Stop on ``[end]``; otherwise append the predicted word.
5. Strip the ``[start]`` prefix and return.
"""
from __future__ import annotations
from captioning.preprocessing.caption import END_TOKEN, START_TOKEN
from captioning.preprocessing.tokenizer import CaptionTokenizer
def generate_caption_greedy(
    model,
    tokenizer: CaptionTokenizer,
    image_tensor,
    max_length: int,
    *,
    add_noise: bool = False,
) -> str:
    """Generate a caption for one image using greedy (argmax) decoding.

    The image is encoded once; the caption is then grown one word per step by
    re-tokenising the partial caption string, running the decoder, and taking
    the highest-scoring vocabulary entry at the current position.

    Args:
        model: An ``ImageCaptioningModel`` whose weights have been loaded.
            Must expose ``cnn_model``, ``encoder`` and ``decoder`` callables.
        tokenizer: Fitted ``CaptionTokenizer`` (the same one used at training).
        image_tensor: A ``[299, 299, 3]`` float tensor produced by
            ``inference.load_image_from_path`` (or ``preprocess_image_tensor``).
        max_length: Decode budget — equals ``config.model.max_length`` (40
            in the notebook). At most ``max_length - 1`` words are generated.
        add_noise: Replicates the notebook's ``add_noise`` knob; off by default.
            When set, Gaussian noise scaled by 0.1 is added to the image and
            the result is min-max rescaled before encoding.

    Returns:
        The generated caption string with the ``[start]`` sentinel removed.
        The ``[end]`` sentinel is naturally absent because the loop breaks on
        it. If no word is generated (the model predicts ``[end]`` on the first
        step, or ``max_length`` leaves no decode steps) the result is ``""``.
    """
    # Imported lazily so importing this module does not pull in TensorFlow.
    import numpy as np
    import tensorflow as tf

    img = image_tensor
    if add_noise:
        noise = tf.random.normal(img.shape) * 0.1
        img = img + noise
        # Min-max rescale so the perturbed image spans [0, 1] again.
        img = (img - tf.reduce_min(img)) / (tf.reduce_max(img) - tf.reduce_min(img))

    # Add the batch axis, then run the two-stage image encoder exactly once.
    img = tf.expand_dims(img, axis=0)
    img_embed = model.cnn_model(img)
    img_encoded = model.encoder(img_embed, training=False)

    y_inp = START_TOKEN
    for i in range(max_length - 1):
        # Drop the last column so the decoder sees one position fewer than the
        # tokenizer emits (the tokenizer pads to ``max_length`` per the module
        # notes — the padding itself happens outside this function).
        tokenized = tokenizer.encode([y_inp])[:, :-1]
        # Positions whose token id is 0 are masked out (presumably padding).
        mask = tf.cast(tokenized != 0, tf.int32)
        pred = model.decoder(tokenized, img_encoded, training=False, mask=mask)

        # Greedy choice: best-scoring vocabulary id at position ``i``.
        pred_idx = np.argmax(pred[0, i, :])
        pred_idx = tf.convert_to_tensor(pred_idx)
        pred_word = tokenizer.decode_id(pred_idx)
        if pred_word == END_TOKEN:
            break

        y_inp += " " + pred_word

    caption = y_inp.replace(f"{START_TOKEN} ", "")
    # With zero generated words ``y_inp`` is the bare sentinel with no trailing
    # space, so the replace above cannot match it; return an empty caption
    # instead of leaking the sentinel to callers.
    if caption == START_TOKEN:
        return ""
    return caption