File size: 2,856 Bytes
3a2e5f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Greedy caption generation.

Mirrors notebook cell 25's ``generate_caption`` exactly. The notebook closes
over four globals (``caption_model``, ``tokenizer``, ``idx2word``,
``MAX_LENGTH``); we accept them as explicit arguments so the function is
callable from tests, scripts, FastAPI, and the parity audit.

The algorithm:
    1. CNN-encode the image.
    2. Transformer-encode the patch features.
    3. Seed the caption with ``[start]``.
    4. For each position 0 ... ``max_length - 2``:
        a. Tokenise the partial caption (``[:, :-1]`` because TextVectorization
           pads to ``max_length`` and we feed ``max_length - 1`` positions
           into the decoder).
        b. Decode and take the argmax at the current position.
        c. Stop on ``[end]``; otherwise append the predicted word.
    5. Strip the ``[start]`` prefix and return.
"""

from __future__ import annotations

from captioning.preprocessing.caption import END_TOKEN, START_TOKEN
from captioning.preprocessing.tokenizer import CaptionTokenizer


def generate_caption_greedy(
    model,
    tokenizer: CaptionTokenizer,
    image_tensor,
    max_length: int,
    *,
    add_noise: bool = False,
) -> str:
    """Generate a caption for one image using greedy (argmax) decoding.

    Args:
        model: An ``ImageCaptioningModel`` whose weights have been loaded. It
            must expose ``cnn_model``, ``encoder`` and ``decoder`` callables.
        tokenizer: Fitted ``CaptionTokenizer`` (the same one used at training).
        image_tensor: A ``[299, 299, 3]`` float tensor produced by
            ``inference.load_image_from_path`` (or ``preprocess_image_tensor``).
        max_length: Decode budget — equals ``config.model.max_length`` (40
            in the notebook). At most ``max_length - 1`` words are generated.
        add_noise: Replicates the notebook's ``add_noise`` knob; off by default.

    Returns:
        The generated caption string with the ``[start]`` sentinel removed.
        The ``[end]`` sentinel is naturally absent because the loop breaks on it.
        If no word is generated (the first prediction is ``[end]``, or
        ``max_length < 2``), the empty string is returned. The notebook
        returned the bare ``"[start]"`` in that case, because its
        ``"[start] "`` replacement only matches when a trailing space is present.
    """
    import numpy as np
    import tensorflow as tf

    img = image_tensor
    if add_noise:
        # Gaussian jitter, then min-max rescale of the noisy image to [0, 1].
        noise = tf.random.normal(img.shape) * 0.1
        img = img + noise
        img = (img - tf.reduce_min(img)) / (tf.reduce_max(img) - tf.reduce_min(img))

    # The image is encoded once; only the decoder runs inside the loop.
    img = tf.expand_dims(img, axis=0)  # add a batch axis of size 1
    img_embed = model.cnn_model(img)
    img_encoded = model.encoder(img_embed, training=False)

    y_inp = START_TOKEN
    for i in range(max_length - 1):
        # Re-tokenise the whole partial caption each step and drop the last
        # position, so the decoder sees ``max_length - 1`` tokens. This assumes
        # the tokenizer pads to ``max_length`` (see the module docstring).
        tokenized = tokenizer.encode([y_inp])[:, :-1]
        mask = tf.cast(tokenized != 0, tf.int32)  # id 0 is treated as padding
        pred = model.decoder(tokenized, img_encoded, training=False, mask=mask)

        # Greedy step: the highest-scoring token at position ``i`` is the next word.
        pred_idx = np.argmax(pred[0, i, :])
        pred_idx = tf.convert_to_tensor(pred_idx)
        pred_word = tokenizer.decode_id(pred_idx)
        if pred_word == END_TOKEN:
            break
        y_inp += " " + pred_word

    if y_inp == START_TOKEN:
        # Nothing was generated. Return an empty caption rather than the bare
        # sentinel, which the ``"[start] "`` replacement below would not strip.
        return ""
    # Kept as ``replace`` (not a strict prefix strip) so every non-empty
    # caption stays byte-identical to the notebook's output.
    return y_inp.replace(f"{START_TOKEN} ", "")